Skip to content

Commit

Permalink
First pass at compatibility work on the socket subsystem.
Browse files Browse the repository at this point in the history
This is based on specs written by @yorickpeterse for the
rubysl/rubysl-socket move to a pure-Ruby socket library.

Among the changes so far:

* Many improvements to Addrinfo, including making better use of the
  JDK APIs and keeping less internal state. More address types now
  work correctly.
* Socket now handles more types of sockets, including servers.
* Improvements to addressing behavior across all socket types.
* In-progress refactoring of the socket classes so their logic can be
  reused from the catch-all Socket class.

At the moment, around 565 specs are still tagged.
headius committed Dec 31, 2015
1 parent bd77246 commit 64458de
Showing 16 changed files with 814 additions and 430 deletions.
2 changes: 1 addition & 1 deletion core/pom.rb
Original file line number Diff line number Diff line change
@@ -44,7 +44,7 @@
jar 'com.github.jnr:jnr-netdb:1.1.5', :exclusions => ['com.github.jnr:jnr-ffi']
jar 'com.github.jnr:jnr-enxio:0.10', :exclusions => ['com.github.jnr:jnr-ffi']
jar 'com.github.jnr:jnr-x86asm:1.0.2', :exclusions => ['com.github.jnr:jnr-ffi']
jar 'com.github.jnr:jnr-unixsocket:0.10', :exclusions => ['com.github.jnr:jnr-ffi']
jar 'com.github.jnr:jnr-unixsocket:0.11-SNAPSHOT', :exclusions => ['com.github.jnr:jnr-ffi']
jar 'com.github.jnr:jnr-posix:3.0.23', :exclusions => ['com.github.jnr:jnr-ffi']
jar 'com.github.jnr:jnr-constants:0.9.0', :exclusions => ['com.github.jnr:jnr-ffi']
jar 'com.github.jnr:jnr-ffi:2.0.7'
2 changes: 1 addition & 1 deletion core/pom.xml
Original file line number Diff line number Diff line change
@@ -125,7 +125,7 @@ DO NOT MODIFIY - GENERATED CODE
<dependency>
<groupId>com.github.jnr</groupId>
<artifactId>jnr-unixsocket</artifactId>
<version>0.10</version>
<version>0.11-SNAPSHOT</version>
<exclusions>
<exclusion>
<artifactId>jnr-ffi</artifactId>
436 changes: 303 additions & 133 deletions core/src/main/java/org/jruby/ext/socket/Addrinfo.java

Large diffs are not rendered by default.

157 changes: 83 additions & 74 deletions core/src/main/java/org/jruby/ext/socket/RubyBasicSocket.java
Original file line number Diff line number Diff line change
@@ -35,7 +35,6 @@
import java.nio.ByteBuffer;
import java.nio.channels.Channel;
import java.nio.channels.DatagramChannel;
import java.nio.channels.NotYetConnectedException;
import java.nio.channels.SelectableChannel;

import jnr.constants.platform.Fcntl;
@@ -44,6 +43,7 @@
import jnr.constants.platform.SocketLevel;
import jnr.constants.platform.SocketOption;

import jnr.unixsocket.UnixSocketAddress;
import org.jruby.Ruby;
import org.jruby.RubyBoolean;
import org.jruby.RubyClass;
@@ -57,7 +57,6 @@
import org.jruby.ast.util.ArgsUtil;
import org.jruby.ext.fcntl.FcntlLibrary;
import org.jruby.platform.Platform;
import org.jruby.runtime.Arity;
import org.jruby.runtime.ObjectAllocator;
import org.jruby.runtime.ThreadContext;
import org.jruby.runtime.builtin.IRubyObject;
@@ -78,18 +77,14 @@
* Implementation of the BasicSocket class from Ruby.
*/
@JRubyClass(name="BasicSocket", parent="IO")
public class RubyBasicSocket extends RubyIO {
public abstract class RubyBasicSocket extends RubyIO {
static void createBasicSocket(Ruby runtime) {
RubyClass rb_cBasicSocket = runtime.defineClass("BasicSocket", runtime.getIO(), BASICSOCKET_ALLOCATOR);
RubyClass rb_cBasicSocket = runtime.defineClass("BasicSocket", runtime.getIO(), ObjectAllocator.NOT_ALLOCATABLE_ALLOCATOR);

rb_cBasicSocket.defineAnnotatedMethods(RubyBasicSocket.class);
}

private static ObjectAllocator BASICSOCKET_ALLOCATOR = new ObjectAllocator() {
public IRubyObject allocate(Ruby runtime, RubyClass klass) {
return new RubyBasicSocket(runtime, klass);
}
};
rb_cBasicSocket.undefineMethod("initialize");
}

public RubyBasicSocket(Ruby runtime, RubyClass type) {
super(runtime, type);
@@ -179,15 +174,15 @@ public IRubyObject recv(ThreadContext context, IRubyObject[] args) {

@Deprecated
public IRubyObject recv(ThreadContext context, IRubyObject length, IRubyObject flags) {
return recv(context, new IRubyObject[] { length, flags });
return recv(context, new IRubyObject[]{length, flags});
}

private IRubyObject recv(ThreadContext context, IRubyObject length,
RubyString str, IRubyObject flags) {
// TODO: implement flags
final ByteBuffer buffer = ByteBuffer.allocate(RubyNumeric.fix2int(length));

ByteList bytes = doReceive(context, buffer);
ByteList bytes = doRead(context, buffer);

if (bytes == null) return context.nil;

@@ -198,11 +193,6 @@ private IRubyObject recv(ThreadContext context, IRubyObject length,
return RubyString.newString(context.runtime, bytes);
}

@JRubyMethod
public IRubyObject recv_nonblock(ThreadContext context, IRubyObject length) {
return recv_nonblock(context, length, context.nil, context.nil, false);
}

@JRubyMethod(required = 1, optional = 3) // (length) required = 1 handled above
public IRubyObject recv_nonblock(ThreadContext context, IRubyObject[] args) {
Ruby runtime = context.runtime;
@@ -223,19 +213,12 @@ public IRubyObject recv_nonblock(ThreadContext context, IRubyObject[] args) {
length = args[1];
}

boolean exception = ArgsUtil.extractKeywordArg(context, "exception", opts) != runtime.getFalse();

return recv_nonblock(context, length, flags, str, exception);
}

protected IRubyObject recv_nonblock(ThreadContext context, IRubyObject length,
IRubyObject flags, IRubyObject str, boolean ex) {
Ruby runtime = context.runtime;
boolean ex = ArgsUtil.extractKeywordArg(context, "exception", opts) != runtime.getFalse();

// TODO: implement flags
final ByteBuffer buffer = ByteBuffer.allocate(RubyNumeric.fix2int(length));

ByteList bytes = doReceiveNonblock(context, buffer);
ByteList bytes = doReadNonBlock(context, buffer);

if (bytes == null) {
if (!ex) return runtime.newSymbol("wait_readable");
@@ -354,18 +337,13 @@ public IRubyObject getsockname(ThreadContext context) {
public IRubyObject getpeername(ThreadContext context) {
Ruby runtime = context.runtime;

try {
SocketAddress sock = getRemoteSocket();
InetSocketAddress sock = getInetRemoteSocket();

if (sock == null) {
throw runtime.newIOError("Not Supported");
}
if (sock != null) return runtime.newString(sock.getHostName());

return runtime.newString(sock.toString());
}
catch (BadDescriptorException e) {
throw runtime.newErrnoEBADFError();
}
UnixSocketAddress unix = getUnixRemoteSocket();

return runtime.newString(unix.path());
}

@JRubyMethod(name = "getpeereid", notImplemented = true)
@@ -375,30 +353,38 @@ public IRubyObject getpeereid(ThreadContext context) {

@JRubyMethod
public IRubyObject local_address(ThreadContext context) {
try {
InetSocketAddress address = getSocketAddress();
Ruby runtime = context.runtime;

if (address == null) return context.nil;
InetSocketAddress address = getInetSocketAddress();

return new Addrinfo(context.runtime, context.runtime.getClass("Addrinfo"), address.getAddress(), address.getPort(), SocketType.forChannel(getChannel()));
}
catch (BadDescriptorException e) {
throw context.runtime.newErrnoEBADFError("address unavailable");
if (address != null) {
SocketType socketType = SocketType.forChannel(getChannel());
return new Addrinfo(runtime, runtime.getClass("Addrinfo"), address, socketType.getSocketType(), socketType);
}

UnixSocketAddress unix = getUnixSocketAddress();

return Addrinfo.unix(context, runtime.getClass("Addrinfo"), runtime.newString(unix.path()));
}

@JRubyMethod
public IRubyObject remote_address(ThreadContext context) {
try {
InetSocketAddress address = getRemoteSocket();
Ruby runtime = context.runtime;

if (address == null) return context.nil;
InetSocketAddress address = getInetRemoteSocket();

return new Addrinfo(context.runtime, context.runtime.getClass("Addrinfo"), address.getAddress(), address.getPort(), SocketType.forChannel(getChannel()));
if (address != null) {
SocketType socketType = SocketType.forChannel(getChannel());
return new Addrinfo(runtime, runtime.getClass("Addrinfo"), address, socketType.getSocketType(), socketType);
}
catch (BadDescriptorException e) {
throw context.runtime.newErrnoEBADFError("address unavailable");

UnixSocketAddress unix = getUnixRemoteSocket();

if (unix != null) {
return Addrinfo.unix(context, runtime.getClass("Addrinfo"), runtime.newString(unix.path()));
}

throw runtime.newErrnoENOTCONNError();
}

@JRubyMethod(optional = 1)
@@ -407,10 +393,8 @@ public IRubyObject shutdown(ThreadContext context, IRubyObject[] args) {

if (args.length > 0) {
String howString = null;
if (args[0] instanceof RubyString) {
howString = ((RubyString) args[0]).asJavaString();
} else if (args[0] instanceof RubySymbol) {
howString = ((RubySymbol) args[0]).asJavaString();
if (args[0] instanceof RubyString || args[0] instanceof RubySymbol) {
howString = args[0].asJavaString();
}

if (howString != null) {
@@ -503,7 +487,7 @@ public IRubyObject readmsg_nonblock(ThreadContext context, IRubyObject[] args) {
throw context.runtime.newNotImplementedError("readmsg_nonblock is not implemented");
}

private ByteList doReceive(ThreadContext context, final ByteBuffer buffer) {
protected ByteList doRead(ThreadContext context, final ByteBuffer buffer) {
Ruby runtime = context.runtime;
OpenFile fptr;

@@ -514,7 +498,6 @@ private ByteList doReceive(ThreadContext context, final ByteBuffer buffer) {
context.getThread().beforeBlockingCall();

int read = openFile.readChannel().read(buffer);

if (read == 0) return null;

return new ByteList(buffer.array(), 0, buffer.position());
@@ -534,7 +517,7 @@ private ByteList doReceive(ThreadContext context, final ByteBuffer buffer) {
}
}

public ByteList doReceiveNonblock(ThreadContext context, final ByteBuffer buffer) {
public ByteList doReadNonBlock(ThreadContext context, final ByteBuffer buffer) {
Ruby runtime = context.runtime;
Channel channel = getChannel();

@@ -551,7 +534,7 @@ public ByteList doReceiveNonblock(ThreadContext context, final ByteBuffer buffer
selectable.configureBlocking(false);

try {
return doReceive(context, buffer);
return doRead(context, buffer);
}
finally {
selectable.configureBlocking(oldBlocking);
@@ -579,35 +562,50 @@ private void joinMulticastGroup(IRubyObject val) throws IOException, BadDescript
}
}

protected InetSocketAddress getSocketAddress() throws BadDescriptorException {
Channel channel = getOpenChannel();
protected InetSocketAddress getInetSocketAddress() {
SocketAddress socketAddress = getSocketAddress();
if (socketAddress instanceof InetSocketAddress) return (InetSocketAddress) socketAddress;
return null;
}

protected InetSocketAddress getInetRemoteSocket() {
SocketAddress socketAddress = getRemoteSocket();
if (socketAddress instanceof InetSocketAddress) return (InetSocketAddress) socketAddress;
return null;
}

return (InetSocketAddress)SocketType.forChannel(channel).getLocalSocketAddress(channel);
protected UnixSocketAddress getUnixSocketAddress() {
SocketAddress socketAddress = getSocketAddress();
if (socketAddress instanceof UnixSocketAddress) return (UnixSocketAddress) socketAddress;
return null;
}

protected InetSocketAddress getRemoteSocket() throws BadDescriptorException {
protected UnixSocketAddress getUnixRemoteSocket() {
SocketAddress socketAddress = getRemoteSocket();
if (socketAddress instanceof UnixSocketAddress) return (UnixSocketAddress) socketAddress;
return null;
}

protected SocketAddress getSocketAddress() {
Channel channel = getOpenChannel();

return (InetSocketAddress)SocketType.forChannel(channel).getRemoteSocketAddress(channel);
return SocketType.forChannel(channel).getLocalSocketAddress(channel);
}

protected Sock getDefaultSocketType() {
return Sock.SOCK_STREAM;
protected SocketAddress getRemoteSocket() {
Channel channel = getOpenChannel();

return SocketType.forChannel(channel).getRemoteSocketAddress(channel);
}

protected IRubyObject getSocknameCommon(ThreadContext context, String caller) {
try {
InetSocketAddress sock = getSocketAddress();
InetSocketAddress sock = getInetSocketAddress();

if (sock == null) {
return Sockaddr.pack_sockaddr_in(context, 0, "0.0.0.0");
}
if (sock != null) return Sockaddr.pack_sockaddr_in(context, sock);

return Sockaddr.pack_sockaddr_in(context, sock);
}
catch (BadDescriptorException e) {
throw context.runtime.newErrnoEBADFError();
}
UnixSocketAddress unixSocketAddress = getUnixSocketAddress();

return Sockaddr.pack_sockaddr_un(context, unixSocketAddress.path());
}

private IRubyObject shutdownInternal(ThreadContext context, int how) throws BadDescriptorException {
@@ -796,4 +794,15 @@ public static IRubyObject set_do_not_reverse_lookup(IRubyObject recv, IRubyObjec

// By default we always reverse lookup unless do_not_reverse_lookup set.
private boolean doNotReverseLookup = false;

protected static class ReceiveTuple {
ReceiveTuple() {}
ReceiveTuple(RubyString result, InetSocketAddress sender) {
this.result = result;
this.sender = sender;
}

RubyString result;
InetSocketAddress sender;
}
}// RubyBasicSocket
99 changes: 34 additions & 65 deletions core/src/main/java/org/jruby/ext/socket/RubyIPSocket.java
Original file line number Diff line number Diff line change
@@ -45,20 +45,14 @@
* @author <a href="mailto:ola.bini@ki.se">Ola Bini</a>
*/
@JRubyClass(name="IPSocket", parent="BasicSocket")
public class RubyIPSocket extends RubyBasicSocket {
public abstract class RubyIPSocket extends RubyBasicSocket {
static void createIPSocket(Ruby runtime) {
RubyClass rb_cIPSocket = runtime.defineClass("IPSocket", runtime.getClass("BasicSocket"), IPSOCKET_ALLOCATOR);
RubyClass rb_cIPSocket = runtime.defineClass("IPSocket", runtime.getClass("BasicSocket"), ObjectAllocator.NOT_ALLOCATABLE_ALLOCATOR);

rb_cIPSocket.defineAnnotatedMethods(RubyIPSocket.class);

runtime.getObject().setConstant("IPsocket",rb_cIPSocket);
}

private static ObjectAllocator IPSOCKET_ALLOCATOR = new ObjectAllocator() {
public IRubyObject allocate(Ruby runtime, RubyClass klass) {
return new RubyIPSocket(runtime, klass);
}
};

public RubyIPSocket(Ruby runtime, RubyClass type) {
super(runtime, type);
@@ -101,36 +95,31 @@ public static IRubyObject getaddress(ThreadContext context, IRubyObject self, IR
public IRubyObject recvfrom(ThreadContext context, IRubyObject _length) {
Ruby runtime = context.runtime;

try {
IRubyObject result = recv(context, _length);
InetSocketAddress sender = getRemoteSocket();

int port;
String hostName;
String hostAddress;

if (sender == null) {
port = 0;
hostName = hostAddress = "0.0.0.0";
} else {
port = sender.getPort();
hostName = sender.getHostName();
hostAddress = sender.getAddress().getHostAddress();
}
IRubyObject result = recv(context, _length);
InetSocketAddress sender = getInetRemoteSocket();

IRubyObject addressArray = context.runtime.newArray(
new IRubyObject[]{
runtime.newString("AF_INET"),
runtime.newFixnum(port),
runtime.newString(hostName),
runtime.newString(hostAddress)
});
int port;
String hostName;
String hostAddress;

return runtime.newArray(result, addressArray);

} catch (BadDescriptorException e) {
throw runtime.newErrnoEBADFError();
if (sender == null) {
port = 0;
hostName = hostAddress = "0.0.0.0";
} else {
port = sender.getPort();
hostName = sender.getHostName();
hostAddress = sender.getAddress().getHostAddress();
}

IRubyObject addressArray = context.runtime.newArray(
new IRubyObject[]{
runtime.newString("AF_INET"),
runtime.newFixnum(port),
runtime.newString(hostName),
runtime.newString(hostAddress)
});

return runtime.newArray(result, addressArray);
}

@JRubyMethod
@@ -146,56 +135,36 @@ public IRubyObject getpeereid(ThreadContext context) {

@Override
protected IRubyObject getSocknameCommon(ThreadContext context, String caller) {
try {
InetSocketAddress sock = getSocketAddress();

return Sockaddr.packSockaddrFromAddress(context, sock);
InetSocketAddress sock = getInetSocketAddress();

} catch (BadDescriptorException e) {
throw context.runtime.newErrnoEBADFError();
}
return Sockaddr.packSockaddrFromAddress(context, sock);
}

@Override
public IRubyObject getpeername(ThreadContext context) {
try {
InetSocketAddress sock = getRemoteSocket();
InetSocketAddress sock = getInetRemoteSocket();

return Sockaddr.packSockaddrFromAddress(context, sock);

} catch (BadDescriptorException e) {
throw context.runtime.newErrnoEBADFError();
}
return Sockaddr.packSockaddrFromAddress(context, sock);
}

private IRubyObject addrCommon(ThreadContext context, boolean reverse) {
try {
InetSocketAddress address = getSocketAddress();
InetSocketAddress address = getInetSocketAddress();

if (address == null) {
throw context.runtime.newErrnoENOTSOCKError("Not socket or not connected");
}

return addrFor(context, address, reverse);

} catch (BadDescriptorException e) {
throw context.runtime.newErrnoEBADFError();
}
}

private IRubyObject peeraddrCommon(ThreadContext context, boolean reverse) {
try {
InetSocketAddress address = getRemoteSocket();
InetSocketAddress address = getInetRemoteSocket();

if (address == null) {
throw context.runtime.newErrnoENOTSOCKError("Not socket or not connected");
}

return addrFor(context, address, reverse);

} catch (BadDescriptorException e) {
throw context.runtime.newErrnoEBADFError();
if (address == null) {
throw context.runtime.newErrnoENOTSOCKError("Not socket or not connected");
}

return addrFor(context, address, reverse);
}

@Deprecated
26 changes: 18 additions & 8 deletions core/src/main/java/org/jruby/ext/socket/RubyServerSocket.java
Original file line number Diff line number Diff line change
@@ -47,6 +47,7 @@
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.SocketException;
import java.net.StandardSocketOptions;
import java.net.UnknownHostException;
import java.nio.channels.Channel;
import java.nio.channels.IllegalBlockingModeException;
@@ -100,6 +101,9 @@ public IRubyObject bind(ThreadContext context, IRubyObject addr) {

if (addr instanceof Addrinfo) {
Addrinfo addrInfo = (Addrinfo) addr;
if (!addrInfo.ip_p(context).isTrue()) {
throw context.runtime.newTypeError("not an INET or INET6 address: " + addrInfo);
}
iaddr = new InetSocketAddress(addrInfo.getInetAddress().getHostAddress(), addrInfo.getPort());
} else {
iaddr = Sockaddr.addressFromSockaddr_in(context, addr);
@@ -115,6 +119,9 @@ public IRubyObject bind(ThreadContext context, IRubyObject addr, IRubyObject bac

if (addr instanceof Addrinfo) {
Addrinfo addrInfo = (Addrinfo) addr;
if (!addrInfo.ip_p(context).isTrue()) {
throw context.runtime.newTypeError("not an INET or INET6 address: " + addrInfo);
}
iaddr = new InetSocketAddress(addrInfo.getInetAddress().getHostAddress(), addrInfo.getPort());
} else {
iaddr = Sockaddr.addressFromSockaddr_in(context, addr);
@@ -126,7 +133,7 @@ public IRubyObject bind(ThreadContext context, IRubyObject addr, IRubyObject bac

@JRubyMethod()
public IRubyObject accept(ThreadContext context) {
return doAccept(context, getChannel(), true);
return doAccept(this, context, true);
}

@JRubyMethod()
@@ -136,7 +143,7 @@ public IRubyObject accept_nonblock(ThreadContext context) {

@JRubyMethod()
public IRubyObject accept_nonblock(ThreadContext context, IRubyObject opts) {
return doAcceptNonblock(context, getChannel(), ArgsUtil.extractKeywordArg(context, "exception", opts) != context.runtime.getFalse());
return doAcceptNonblock(this, context, ArgsUtil.extractKeywordArg(context, "exception", opts) != context.runtime.getFalse());
}

protected ChannelFD initChannelFD(Ruby runtime) {
@@ -157,8 +164,9 @@ protected ChannelFD initChannelFD(Ruby runtime) {
}
}

private IRubyObject doAcceptNonblock(ThreadContext context, Channel channel, boolean ex) {
public static IRubyObject doAcceptNonblock(RubySocket sock, ThreadContext context, boolean ex) {
try {
Channel channel = sock.getChannel();
if (channel instanceof SelectableChannel) {
SelectableChannel selectable = (SelectableChannel)channel;

@@ -168,7 +176,7 @@ private IRubyObject doAcceptNonblock(ThreadContext context, Channel channel, boo
try {
selectable.configureBlocking(false);

IRubyObject socket = doAccept(context, channel, ex);
IRubyObject socket = doAccept(sock, context, ex);
if (!(socket instanceof RubySocket)) return socket;
SocketChannel socketChannel = (SocketChannel)((RubySocket)socket).getChannel();
InetSocketAddress addr = (InetSocketAddress)socketChannel.socket().getRemoteSocketAddress();
@@ -190,12 +198,14 @@ private IRubyObject doAcceptNonblock(ThreadContext context, Channel channel, boo
}
}

private IRubyObject doAccept(ThreadContext context, Channel channel, boolean ex) {
public static IRubyObject doAccept(RubySocket sock, ThreadContext context, boolean ex) {
Ruby runtime = context.runtime;

Channel channel = sock.getChannel();

try {
if (channel instanceof ServerSocketChannel) {
ServerSocketChannel serverChannel = (ServerSocketChannel)getChannel();
ServerSocketChannel serverChannel = (ServerSocketChannel)sock.getChannel();

SocketChannel socket = serverChannel.accept();

@@ -208,9 +218,9 @@ private IRubyObject doAccept(ThreadContext context, Channel channel, boolean ex)
}

RubySocket rubySocket = new RubySocket(runtime, runtime.getClass("Socket"));
rubySocket.initFromServer(runtime, this, socket);
rubySocket.initFromServer(runtime, sock, socket);

return rubySocket;
return runtime.newArray(rubySocket, new Addrinfo(runtime, runtime.getClass("Addrinfo"), socket.getRemoteAddress()));
}
throw runtime.newErrnoENOPROTOOPTError();
}
214 changes: 169 additions & 45 deletions core/src/main/java/org/jruby/ext/socket/RubySocket.java

Large diffs are not rendered by default.

33 changes: 21 additions & 12 deletions core/src/main/java/org/jruby/ext/socket/RubyTCPServer.java
Original file line number Diff line number Diff line change
@@ -82,18 +82,27 @@ public RubyTCPServer(Ruby runtime, RubyClass type) {
@JRubyMethod(name = "initialize", required = 1, optional = 1, visibility = Visibility.PRIVATE)
public IRubyObject initialize(ThreadContext context, IRubyObject[] args) {
Ruby runtime = context.runtime;
IRubyObject _host = args[0];
IRubyObject _port = args.length > 1 ? args[1] : context.nil;

String host;
if(_host.isNil()|| ((_host instanceof RubyString) && ((RubyString) _host).isEmpty())) {
host = "0.0.0.0";
} else if (_host instanceof RubyFixnum) {
// numeric host, use it for port
_port = _host;
host = "0.0.0.0";
} else {
host = _host.convertToString().toString();
IRubyObject _host;
IRubyObject _port = null;
String host = "0.0.0.0";

switch (args.length) {
case 2:
_host = args[0];
_port = args[1];

if (!_host.isNil()) {
RubyString hostString = _host.convertToString();
if (hostString.size() > 0) host = hostString.toString();
} else if (_host instanceof RubyFixnum) {
throw runtime.newTypeError(_host, runtime.getString());
} else {
host = _host.convertToString().toString();
}

break;
case 1:
_port = args[0];
}

int port = SocketUtils.getPortFrom(context, _port);
5 changes: 5 additions & 0 deletions core/src/main/java/org/jruby/ext/socket/RubyTCPSocket.java
Original file line number Diff line number Diff line change
@@ -75,6 +75,11 @@ public RubyTCPSocket(Ruby runtime, RubyClass type) {
super(runtime, type);
}

@Override
public IRubyObject recv_nonblock(ThreadContext context, IRubyObject[] args) {
return null;
}

@JRubyMethod(required = 2, optional = 2, visibility = Visibility.PRIVATE)
public IRubyObject initialize(ThreadContext context, IRubyObject[] args) {
Ruby runtime = context.runtime;
122 changes: 67 additions & 55 deletions core/src/main/java/org/jruby/ext/socket/RubyUDPSocket.java
Original file line number Diff line number Diff line change
@@ -32,21 +32,24 @@
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.PortUnreachableException;
import java.net.ProtocolFamily;
import java.net.SocketAddress;
import java.net.SocketException;
import java.net.MulticastSocket;
import java.net.StandardProtocolFamily;
import java.net.UnknownHostException;
import java.net.DatagramPacket;
import java.nio.ByteBuffer;
import java.nio.channels.Channel;
import java.nio.channels.DatagramChannel;
import java.nio.channels.IllegalBlockingModeException;
import java.nio.channels.NotYetConnectedException;

import jnr.constants.platform.AddressFamily;
import jnr.netdb.Service;
import org.jruby.Ruby;
import org.jruby.RubyClass;
import org.jruby.RubyFixnum;
import org.jruby.RubyInteger;
import org.jruby.RubyModule;
import org.jruby.RubyNumeric;
import org.jruby.RubyString;
@@ -90,31 +93,37 @@ public RubyUDPSocket(Ruby runtime, RubyClass type) {

@JRubyMethod(visibility = Visibility.PRIVATE)
public IRubyObject initialize(ThreadContext context) {
return initialize(context, StandardProtocolFamily.INET);
}

@JRubyMethod(visibility = Visibility.PRIVATE)
public IRubyObject initialize(ThreadContext context, IRubyObject _family) {
int family = _family.convertToInteger().getIntValue();
if (family == AddressFamily.AF_INET.intValue()) {
return initialize(context, StandardProtocolFamily.INET);
} else if (family == AddressFamily.AF_INET6.intValue()) {
return initialize(context, StandardProtocolFamily.INET6);
}
throw SocketUtils.sockerr(context.runtime, "invalid family for UDPSocket: " + _family);
}

public IRubyObject initialize(ThreadContext context, ProtocolFamily family) {
Ruby runtime = context.runtime;

try {
DatagramChannel channel = DatagramChannel.open();
DatagramChannel channel = DatagramChannel.open(family);
initSocket(newChannelFD(runtime, channel));
}
catch (ConnectException e) {
} catch (ConnectException e) {
throw runtime.newErrnoECONNREFUSEDError();
}
catch (UnknownHostException e) {
} catch (UnknownHostException e) {
throw SocketUtils.sockerr(runtime, "initialize: name or service not known");
}
catch (IOException e) {
} catch (IOException e) {
throw sockerr(runtime, "initialize: name or service not known", e);
}

return this;
}

@JRubyMethod(visibility = Visibility.PRIVATE)
public IRubyObject initialize(ThreadContext context, IRubyObject protocol) {
// we basically ignore protocol. let someone report it...
return initialize(context);
}

@JRubyMethod
public IRubyObject bind(ThreadContext context, IRubyObject host, IRubyObject _port) {
final Ruby runtime = context.runtime;
@@ -131,8 +140,15 @@ public IRubyObject bind(ThreadContext context, IRubyObject host, IRubyObject _po
}
else if (host instanceof RubyFixnum) {
// passing in something like INADDR_ANY
final int intAddr = RubyNumeric.fix2int(host);
final RubyModule Socket = runtime.getModule("Socket");
int intAddr = 0;
if (host instanceof RubyInteger) {
intAddr = RubyNumeric.fix2int(host);
} else if (host instanceof RubyString) {
intAddr = ((RubyString)host).to_i().convertToInteger().getIntValue();
} else {
throw runtime.newTypeError(host, runtime.getInteger());
}
RubyModule Socket = runtime.getModule("Socket");
if (intAddr == RubyNumeric.fix2int(Socket.getConstant("INADDR_ANY"))) {
addr = new InetSocketAddress(InetAddress.getByName("0.0.0.0"), port);
}
@@ -148,7 +164,7 @@ else if (host instanceof RubyFixnum) {
}

if (multicastStateManager == null) {
((DatagramChannel) channel).socket().bind(addr);
((DatagramChannel) channel).bind(addr);
} else {
multicastStateManager.rebindToPort(port);
}
@@ -204,13 +220,16 @@ public IRubyObject connect(ThreadContext context, IRubyObject host, IRubyObject
}
}

@JRubyMethod
public IRubyObject recvfrom_nonblock(ThreadContext context, IRubyObject length) {
return recv_nonblock(context, length, context.nil, context.nil, false);
private DatagramChannel getDatagramChannel() {
return (DatagramChannel) getChannel();
}

@JRubyMethod(required = 1, optional = 3) // (length) required = 1 handled above
@JRubyMethod(required = 1, optional = 3)
public IRubyObject recvfrom_nonblock(ThreadContext context, IRubyObject[] args) {
return recvfrom_nonblock(this, context, args);
}

public static IRubyObject recvfrom_nonblock(RubyBasicSocket socket, ThreadContext context, IRubyObject[] args) {
Ruby runtime = context.runtime;
int argc = args.length;
IRubyObject opts = ArgsUtil.getOptionsArg(context.runtime, args);
@@ -222,32 +241,32 @@ public IRubyObject recvfrom_nonblock(ThreadContext context, IRubyObject[] args)

switch (argc) {
case 3:
str = args[3];
str = args[2];
case 2:
flags = args[2];
flags = args[1];
case 1:
length = args[1];
length = args[0];
}

boolean exception = ArgsUtil.extractKeywordArg(context, "exception", opts) != runtime.getFalse();

return recvfrom_nonblock(context, length, flags, str, exception);
return recvfrom_nonblock(socket, context, length, flags, str, exception);
}

public IRubyObject recvfrom_nonblock(ThreadContext context, IRubyObject _length, IRubyObject _flags, IRubyObject str, boolean ex) {
public static IRubyObject recvfrom_nonblock(RubyBasicSocket socket, ThreadContext context, IRubyObject _length, IRubyObject _flags, IRubyObject str, boolean ex) {
final Ruby runtime = context.runtime;

try {
int length = RubyNumeric.fix2int(_length);

ReceiveTuple tuple = doReceiveNonblockTuple(runtime, length, ex);
ReceiveTuple tuple = doReceiveNonblockTuple(socket, runtime, length, ex);

// TODO: make this efficient
if (!str.isNil()) {
tuple.result = str.convertToString().replace19(tuple.result);
}

IRubyObject addressArray = addrFor(context, tuple.sender, false);
IRubyObject addressArray = socket.addrFor(context, tuple.sender, false);

return runtime.newArray(tuple.result, addressArray);
}
@@ -393,14 +412,18 @@ public static IRubyObject open(ThreadContext context, IRubyObject recv, IRubyObj
*/
@Override
public IRubyObject recvfrom(ThreadContext context, IRubyObject _length) {
return recvfrom(this, context, _length);
}

public static IRubyObject recvfrom(RubyBasicSocket socket, ThreadContext context, IRubyObject _length) {
final Ruby runtime = context.runtime;

try {
int length = RubyNumeric.fix2int(_length);

ReceiveTuple tuple = doReceiveTuple(runtime, length, true);
ReceiveTuple tuple = doReceiveTuple(socket, runtime, length, true);

IRubyObject addressArray = addrFor(context, tuple.sender, false);
IRubyObject addressArray = socket.addrFor(context, tuple.sender, false);

return runtime.newArray(tuple.result, addressArray);
}
@@ -436,7 +459,7 @@ public IRubyObject recv(ThreadContext context, IRubyObject _length) {
final Ruby runtime = context.runtime;

try {
return doReceive(runtime, RubyNumeric.fix2int(_length));
return doReceive(this, runtime, RubyNumeric.fix2int(_length));
}
catch (IOException e) { // SocketException
throw runtime.newIOErrorFromException(e);
@@ -456,52 +479,41 @@ public IRubyObject recv(ThreadContext context, IRubyObject _length, IRubyObject
return recv(context, _length);
}

private ReceiveTuple doReceiveTuple(Ruby runtime, int length, boolean ex) throws IOException {
private static ReceiveTuple doReceiveTuple(RubyBasicSocket socket, Ruby runtime, int length, boolean ex) throws IOException {
ReceiveTuple tuple = new ReceiveTuple();

if (this.multicastStateManager == null) {
doReceive(runtime, length, ex, tuple);
if (socket.multicastStateManager == null) {
doReceive(socket, runtime, length, ex, tuple);
} else {
doReceiveMulticast(runtime, length, ex, tuple);
doReceiveMulticast(socket, runtime, length, ex, tuple);
}

return tuple;
}

private ReceiveTuple doReceiveNonblockTuple(Ruby runtime, int length, boolean ex) throws IOException {
DatagramChannel channel = (DatagramChannel)getChannel();
private static ReceiveTuple doReceiveNonblockTuple(RubyBasicSocket socket, Ruby runtime, int length, boolean ex) throws IOException {
DatagramChannel channel = (DatagramChannel)socket.getChannel();

synchronized (channel.blockingLock()) {
boolean oldBlocking = channel.isBlocking();

channel.configureBlocking(false);

try {
return doReceiveTuple(runtime, length, ex);
return doReceiveTuple(socket, runtime, length, ex);
}
finally {
channel.configureBlocking(oldBlocking);
}
}
}

private static class ReceiveTuple {
ReceiveTuple() {}
ReceiveTuple(RubyString result, InetSocketAddress sender) {
this.result = result;
this.sender = sender;
}

RubyString result;
InetSocketAddress sender;
}

// Rendered commit diff with the +/- markers stripped: in each adjacent pair of near-duplicate
// lines below, the first is the removed (old) line and the second the added (new) one.
//
// Convenience overload: delegates to the five-argument doReceive with ex = true and no
// ReceiveTuple (null), returning whatever that overload returns.
private IRubyObject doReceive(Ruby runtime, int length) throws IOException {
return doReceive(runtime, length, true, null);
private static IRubyObject doReceive(RubyBasicSocket socket, Ruby runtime, int length) throws IOException {
return doReceive(socket, runtime, length, true, null);
}

private IRubyObject doReceive(Ruby runtime, int length, boolean ex, ReceiveTuple tuple) throws IOException {
DatagramChannel channel = (DatagramChannel)getChannel();
protected static IRubyObject doReceive(RubyBasicSocket socket, Ruby runtime, int length, boolean ex, ReceiveTuple tuple) throws IOException {
DatagramChannel channel = (DatagramChannel)socket.getChannel();

ByteBuffer buf = ByteBuffer.allocate(length);

@@ -528,12 +540,12 @@ private IRubyObject doReceive(Ruby runtime, int length, boolean ex, ReceiveTuple
return result;
}

private IRubyObject doReceiveMulticast(Ruby runtime, int length, boolean ex, ReceiveTuple tuple) throws IOException {
private static IRubyObject doReceiveMulticast(RubyBasicSocket socket, Ruby runtime, int length, boolean ex, ReceiveTuple tuple) throws IOException {
byte[] buf2 = new byte[length];
ByteBuffer recv = ByteBuffer.wrap(buf2);
SocketAddress address;

DatagramChannel channel = this.multicastStateManager.getMulticastSocket().getChannel();
DatagramChannel channel = socket.multicastStateManager.getMulticastSocket().getChannel();

address = channel.receive(recv);

19 changes: 18 additions & 1 deletion core/src/main/java/org/jruby/ext/socket/RubyUNIXServer.java
Original file line number Diff line number Diff line change
@@ -29,6 +29,7 @@


import jnr.unixsocket.UnixServerSocketChannel;
import jnr.unixsocket.UnixSocketAddress;
import jnr.unixsocket.UnixSocketChannel;
import org.jruby.Ruby;
import org.jruby.RubyClass;
@@ -42,6 +43,7 @@
import org.jruby.runtime.builtin.IRubyObject;

import java.io.IOException;
import java.net.SocketAddress;
import java.nio.channels.SelectableChannel;
import java.nio.channels.SelectionKey;

@@ -156,7 +158,8 @@ public IRubyObject listen(ThreadContext context, IRubyObject log) {

// Rendered commit diff with the +/- markers stripped: the first `return` line is the removed
// (old) body; the two lines after it are the added (new) body.
//
// UNIXServer#sysaccept. The old body returned the accepted socket object itself. The new body
// calls accept(context), casts the result to RubyUNIXSocket, and returns the raw file descriptor
// of its UnixSocketChannel as a Fixnum.
// NOTE(review): the RubyUNIXSocket wrapper is dropped after reading the fd. Its channel is not
// closed here, so the fd stays open for the caller.
@JRubyMethod
public IRubyObject sysaccept(ThreadContext context) {
return accept(context);
RubyUNIXSocket socket = (RubyUNIXSocket) accept(context);
return context.runtime.newFixnum(((UnixSocketChannel) socket.getChannel()).getFD());
}

@JRubyMethod
@@ -178,6 +181,20 @@ public IRubyObject peeraddr(ThreadContext context) {
throw context.runtime.newErrnoENOTCONNError();
}

/**
 * Returns the local address reported by this server's {@link UnixServerSocketChannel}, or
 * {@code null} if that address is not a {@link UnixSocketAddress}.
 */
@Override
protected UnixSocketAddress getUnixSocketAddress() {
    final UnixServerSocketChannel server = (UnixServerSocketChannel) getChannel();
    final SocketAddress local = server.getLocalSocketAddress();
    return local instanceof UnixSocketAddress ? (UnixSocketAddress) local : null;
}

/**
 * Returns the server channel's <em>local</em> address when it is a {@link UnixSocketAddress},
 * otherwise {@code null}.
 *
 * <p>NOTE(review): despite the "remote" name, this reads {@code getLocalSocketAddress()}, the
 * same source as {@code getUnixSocketAddress()}. A listening server has no peer (its
 * {@code peeraddr} raises ENOTCONN), so this may be intentional; confirm against callers.
 */
@Override
protected UnixSocketAddress getUnixRemoteSocket() {
    final UnixServerSocketChannel server = (UnixServerSocketChannel) getChannel();
    final SocketAddress address = server.getLocalSocketAddress();
    if (!(address instanceof UnixSocketAddress)) {
        return null;
    }
    return (UnixSocketAddress) address;
}

/** Returns this object's channel viewed as a {@link UnixServerSocketChannel} (unchecked cast). */
private UnixServerSocketChannel asUnixServer() {
    final UnixServerSocketChannel server = (UnixServerSocketChannel) getChannel();
    return server;
}
2 changes: 2 additions & 0 deletions core/src/main/java/org/jruby/ext/socket/RubyUNIXSocket.java
Original file line number Diff line number Diff line change
@@ -53,12 +53,14 @@
import org.jruby.runtime.Visibility;
import org.jruby.runtime.builtin.IRubyObject;
import org.jruby.util.ByteList;
import org.jruby.util.io.BadDescriptorException;
import org.jruby.util.io.ModeFlags;
import org.jruby.util.io.OpenFile;
import org.jruby.util.io.FilenoUtil;

import java.io.File;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.channels.Channel;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
10 changes: 10 additions & 0 deletions core/src/main/java/org/jruby/ext/socket/SocketType.java
Original file line number Diff line number Diff line change
@@ -29,8 +29,10 @@
import jnr.constants.platform.Sock;
import jnr.constants.platform.SocketOption;
import jnr.unixsocket.UnixServerSocketChannel;
import jnr.unixsocket.UnixSocketAddress;
import jnr.unixsocket.UnixSocketChannel;

import java.io.File;
import java.io.IOException;
import java.net.DatagramSocket;
import java.net.ServerSocket;
@@ -238,6 +240,14 @@ public void shutdownInput(Channel channel)throws IOException {
// Shuts down the output (write) side of the connection by calling shutdownOutput() on the
// object returned by toSocket(channel). IOException from that call propagates to the caller.
public void shutdownOutput(Channel channel) throws IOException {
toSocket(channel).shutdownOutput();
}

/** Returns the peer address reported by the object returned by {@code toSocket(channel)}. */
public SocketAddress getRemoteSocketAddress(Channel channel) {
    final SocketAddress peer = toSocket(channel).getRemoteSocketAddress();
    return peer;
}

/**
 * Returns a placeholder local address: a new {@link UnixSocketAddress} for the fixed path
 * {@code "empty-path"}. The {@code channel} argument is ignored.
 *
 * <p>NOTE(review): this does not query the channel's real bound path; confirm whether callers
 * rely on the placeholder.
 */
public SocketAddress getLocalSocketAddress(Channel channel) {
    final File placeholder = new File("empty-path");
    return new UnixSocketAddress(placeholder);
}
},

UNKNOWN(Sock.SOCK_STREAM);
60 changes: 25 additions & 35 deletions core/src/main/java/org/jruby/ext/socket/SocketUtils.java
Original file line number Diff line number Diff line change
@@ -48,6 +48,7 @@

import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.NetworkInterface;
import java.net.SocketException;
import java.net.UnknownHostException;
@@ -264,15 +265,15 @@ public void addrinfo(InetAddress address, int port, Sock sock, Boolean reverse)

if (sock_dgram) {
l.add(new Addrinfo(runtime, runtime.getClass("Addrinfo"),
address,
port,
new InetSocketAddress(address, port),
Sock.SOCK_DGRAM,
SocketType.DATAGRAM));
}

if (sock_stream) {
l.add(new Addrinfo(runtime, runtime.getClass("Addrinfo"),
address,
port,
new InetSocketAddress(address, port),
Sock.SOCK_STREAM,
SocketType.SOCKET));
}
}
@@ -575,65 +576,54 @@ private static String getHostAddress(ThreadContext context, InetAddress addr, Bo
// Literal host string "<any>"; its uses lie outside this view.
private static final String ANY = "<any>";
// IPv4 wildcard address 0.0.0.0: four zero bytes (a new byte[4] is zero-filled).
private static final byte[] INADDR_ANY = new byte[4];

// MRI: address family part of rsock_family_to_int
//
// Rendered commit diff with the +/- markers stripped: old and new lines are interleaved below.
// The old body assigned into `addressFamily` and returned it at the end. The new body returns
// early from each branch (the two `return AddressFamily.valueOf(...)` lines inside the string
// branch, and the final two lines for the numeric case).
//
// Converts a Ruby domain argument to a jnr AddressFamily:
// - String/Symbol: looked up by name, with "AF_" prepended when it is not already present
//   (so "INET", :INET and "AF_INET" all resolve the same way).
// - Anything else: converted with RubyNumeric.fix2int and looked up by numeric value.
// NOTE(review): if AddressFamily.valueOf(String) is the enum's built-in valueOf, an unknown name
// throws a Java IllegalArgumentException rather than a Ruby SocketError — confirm.
static AddressFamily addressFamilyFromArg(IRubyObject domain) {
AddressFamily addressFamily = null;

if(domain instanceof RubyString || domain instanceof RubySymbol) {
String domainString = domain.toString();
if (!domainString.startsWith("AF_")) domainString = "AF_" + domainString;
addressFamily = AddressFamily.valueOf(domainString);
} else {
int domainInt = RubyNumeric.fix2int(domain);
addressFamily = AddressFamily.valueOf(domainInt);
if (domainString.startsWith("AF_")) return AddressFamily.valueOf(domainString);
return AddressFamily.valueOf("AF_" + domainString);
}

return addressFamily;
int domainInt = RubyNumeric.fix2int(domain);
return AddressFamily.valueOf(domainInt);
}

// Rendered commit diff with the +/- markers stripped: old and new lines are interleaved below.
// The old body assigned into `sockType`; the new body returns early from each branch.
//
// Converts a Ruby socket-type argument to a jnr Sock constant:
// - String/Symbol: looked up by name. The new version accepts names already prefixed with
//   "SOCK_"; the old version always prepended "SOCK_".
// - Anything else: converted with RubyNumeric.fix2int and looked up by numeric value.
// NOTE(review): `typeString.toString()` in the new SOCK_ branch is a no-op on a String.
static Sock sockFromArg(IRubyObject type) {
Sock sockType = null;

if(type instanceof RubyString || type instanceof RubySymbol) {
String typeString = type.toString();
sockType = Sock.valueOf("SOCK_" + typeString);
} else {
int typeInt = RubyNumeric.fix2int(type);
sockType = Sock.valueOf(typeInt);
if (typeString.startsWith("SOCK_")) return Sock.valueOf(typeString.toString());
return Sock.valueOf("SOCK_" + typeString);
}

return sockType;
int typeInt = RubyNumeric.fix2int(type);
return Sock.valueOf(typeInt);
}

// MRI: protocol family part of rsock_family_to_int
//
// Rendered commit diff with the +/- markers stripped: old and new lines are interleaved below.
// The two consecutive `if (protocol instanceof ...)` lines differ only in spacing (old/new).
//
// Converts a Ruby protocol-family argument to a jnr ProtocolFamily. In the new version:
// - String/Symbol starting with "PF_": looked up by that name.
// - String/Symbol starting with "AF_": mapped through AddressFamily to its int value, then
//   looked up in ProtocolFamily by that value.
// - Other String/Symbol: looked up with "PF_" prepended.
// - Anything else: converted with RubyNumeric.fix2int and looked up by numeric value.
// NOTE(review): the old version returned null for a numeric 0; the new version passes 0 to
// ProtocolFamily.valueOf instead. Callers that tested for null may behave differently — confirm.
static ProtocolFamily protocolFamilyFromArg(IRubyObject protocol) {
ProtocolFamily protocolFamily = null;

if(protocol instanceof RubyString || protocol instanceof RubySymbol) {
if (protocol instanceof RubyString || protocol instanceof RubySymbol) {
String protocolString = protocol.toString();
protocolFamily = ProtocolFamily.valueOf("PF_" + protocolString);
} else {
int protocolInt = RubyNumeric.fix2int(protocol);
if (protocolInt == 0) return null;
protocolFamily = ProtocolFamily.valueOf(protocolInt);
if (protocolString.startsWith("PF_")) return ProtocolFamily.valueOf(protocolString);
if (protocolString.startsWith("AF_")) return ProtocolFamily.valueOf(AddressFamily.valueOf(protocolString).intValue());
return ProtocolFamily.valueOf("PF_" + protocolString);
}

return protocolFamily;
int protocolInt = RubyNumeric.fix2int(protocol);
return ProtocolFamily.valueOf(protocolInt);
}

// Rendered commit diff with the +/- markers stripped: old and new lines are interleaved below.
// The old body assigned into `proto`; the new body returns early from each branch.
//
// Converts a Ruby protocol argument to a jnr-netdb Protocol entry:
// - String/Symbol: Protocol.getProtocolByName(name).
// - Anything else: RubyNumeric.fix2int, then Protocol.getProtocolByNumber(number).
// NOTE(review): presumably the netdb lookups return null when no entry matches — verify
// whether callers handle a null result.
static Protocol protocolFromArg(IRubyObject protocol) {
Protocol proto;

if(protocol instanceof RubyString || protocol instanceof RubySymbol) {
String protocolString = protocol.toString();
proto = Protocol.getProtocolByName(protocolString);
} else {
int protocolInt = RubyNumeric.fix2int(protocol);
proto = Protocol.getProtocolByNumber(protocolInt);
return Protocol.getProtocolByName(protocolString);
}

return proto;
int protocolInt = RubyNumeric.fix2int(protocol);
return Protocol.getProtocolByNumber(protocolInt);
}

/**
 * Converts a Ruby port argument to a Java int. {@code nil} becomes 0; any other value is
 * converted with {@code RubyNumeric.fix2int}.
 */
public static int portToInt(IRubyObject port) {
    if (port.isNil()) {
        return 0;
    }
    return RubyNumeric.fix2int(port);
}

}
6 changes: 6 additions & 0 deletions core/src/main/java/org/jruby/runtime/Helpers.java
Original file line number Diff line number Diff line change
@@ -2727,4 +2727,10 @@ public static String encodeParameterList(List<String[]> args) {
return builder.toString();
}

/**
 * Copies {@code len} bytes of {@code ary}, starting at index {@code start}, into a new array.
 * Out-of-range arguments fail exactly as array allocation / {@link System#arraycopy} do
 * (NegativeArraySizeException, IndexOutOfBoundsException).
 */
public static byte[] subseq(byte[] ary, int start, int len) {
    final byte[] slice = new byte[len];
    System.arraycopy(ary, start, slice, 0, slice.length);
    return slice;
}

}
51 changes: 51 additions & 0 deletions core/src/main/java/org/jruby/util/io/Sockaddr.java
Original file line number Diff line number Diff line change
@@ -16,9 +16,16 @@
import java.io.DataOutputStream;
import java.io.File;
import java.io.IOException;
import java.net.Inet4Address;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.IntBuffer;
import java.util.Arrays;

import org.jruby.RubyString;
import org.jruby.ext.socket.SocketUtils;
import org.jruby.runtime.Helpers;
@@ -57,6 +64,9 @@ public static InetSocketAddress addressFromArg(ThreadContext context, IRubyObjec
InetSocketAddress iaddr;
if (arg instanceof Addrinfo) {
Addrinfo addrinfo = (Addrinfo)arg;
if (!addrinfo.ip_p(context).isTrue()) {
throw context.runtime.newTypeError("not an INET or INET6 address: " + addrinfo);
}
iaddr = new InetSocketAddress(addrinfo.getInetAddress(), addrinfo.getPort());
} else {
iaddr = addressFromSockaddr_in(context, arg);
@@ -125,6 +135,18 @@ public static IRubyObject packSockaddrFromAddress(ThreadContext context, InetSoc
}
}

/**
 * Packs a UNIX-domain socket path into a sockaddr string: the AF_UNIX family header produced by
 * {@code writeSockaddrHeader}, followed by the raw UTF-8 bytes of {@code path}.
 *
 * <p>The path is written without any length prefix. {@code sockaddrFromBytes} decodes an AF_UNIX
 * sockaddr as "every byte after the two-byte family field". A length prefix would therefore end
 * up inside the decoded path.
 *
 * @param context the current thread context
 * @param path    filesystem path of the UNIX socket
 * @return the packed sockaddr as a Ruby string
 */
public static IRubyObject pack_sockaddr_un(ThreadContext context, String path) {
    ByteArrayOutputStream bufS = new ByteArrayOutputStream();
    try {
        DataOutputStream ds = new DataOutputStream(bufS);
        writeSockaddrHeader(AddressFamily.AF_UNIX, ds);
        // Not ds.writeUTF(path): that prepends a 2-byte length and emits "modified UTF-8".
        ds.write(path.getBytes(java.nio.charset.StandardCharsets.UTF_8));
        return context.runtime.newString(new ByteList(bufS.toByteArray(), false));
    } catch (IOException e) {
        throw sockerr(context.runtime, "pack_sockaddr_un: internal error");
    }
}

public static IRubyObject pack_sockaddr_in(ThreadContext context, int port, String host) {
ByteArrayOutputStream bufS = new ByteArrayOutputStream();
try {
@@ -240,4 +262,33 @@ public static AddressFamily getAddressFamilyFromSockaddr(Ruby runtime, ByteList
/**
 * Builds, but does not throw, a Ruby {@code SocketError} carrying {@code msg}. Callers throw
 * the returned exception themselves.
 */
private static RuntimeException sockerr(Ruby runtime, String msg) {
    final RaiseException socketError =
            new RaiseException(runtime, runtime.getClass("SocketError"), msg, true);
    return socketError;
}

/**
 * Decodes a packed sockaddr byte array into a Java {@link SocketAddress}.
 *
 * <p>Layout read here:
 * <ul>
 *   <li>bytes 0-1: the address family, as an unsigned big-endian 16-bit value;</li>
 *   <li>AF_INET / AF_INET6: bytes 2-3 hold the big-endian port, followed by the 4- or 16-byte
 *       address;</li>
 *   <li>AF_UNIX: every remaining byte is the path, decoded as UTF-8.</li>
 * </ul>
 *
 * <p>Behaviour:
 * <ul>
 *   <li>Input too short for its family raises a Ruby ArgumentError instead of leaking a Java
 *       ArrayIndexOutOfBoundsException.</li>
 *   <li>IPv6 bytes are not cast to {@code Inet6Address}. {@link InetAddress#getByAddress}
 *       returns an {@code Inet4Address} for IPv4-mapped IPv6 addresses, so that cast threw
 *       ClassCastException.</li>
 *   <li>Trailing NUL bytes of an AF_UNIX path are dropped as padding (a path cannot contain
 *       NUL).</li>
 * </ul>
 *
 * @param runtime the Ruby runtime, used to raise ArgumentError
 * @param val     the packed sockaddr bytes
 * @return an {@link InetSocketAddress} for AF_INET/AF_INET6, or a {@code UnixSocketAddress}
 *         for AF_UNIX
 * @throws IOException as declared by {@link InetAddress#getByAddress}
 */
public static SocketAddress sockaddrFromBytes(Ruby runtime, byte[] val) throws IOException {
    if (val.length < 2) {
        throw runtime.newArgumentError("too short sockaddr");
    }

    AddressFamily afamily = AddressFamily.valueOf(uint16(val[0], val[1]));

    if (afamily == null || afamily == AddressFamily.__UNKNOWN_CONSTANT__) {
        throw runtime.newArgumentError("can't resolve socket address of wrong type");
    }

    int port;
    switch (afamily) {
        case AF_INET:
            // family(2) + port(2) + IPv4 address(4)
            if (val.length < 8) {
                throw runtime.newArgumentError("too short sockaddr");
            }
            port = uint16(val[2], val[3]);
            InetAddress inet4Address = InetAddress.getByAddress(Helpers.subseq(val, 4, 4));
            return new InetSocketAddress(inet4Address, port);
        case AF_INET6:
            // family(2) + port(2) + IPv6 address(16)
            if (val.length < 20) {
                throw runtime.newArgumentError("too short sockaddr");
            }
            port = uint16(val[2], val[3]);
            InetAddress inet6Address = InetAddress.getByAddress(Helpers.subseq(val, 4, 16));
            return new InetSocketAddress(inet6Address, port);
        case AF_UNIX:
            int end = val.length;
            while (end > 2 && val[end - 1] == 0) {
                end--;
            }
            String path = new String(val, 2, end - 2, java.nio.charset.StandardCharsets.UTF_8);
            return new UnixSocketAddress(new File(path));
        default:
            throw runtime.newArgumentError("can't resolve socket address of wrong type");
    }
}

/** Combines two bytes into an unsigned big-endian 16-bit value in the range 0..65535. */
private static int uint16(byte high, byte low) {
    final int hi = (high << 8) & 0xFF00;
    final int lo = low & 0x00FF;
    return hi | lo;
}
}

0 comments on commit 64458de

Please sign in to comment.