Skip to content

Commit

Permalink
[refactor] call-site-ize SSLSocket as #initialize gets hit a lot
Browse files Browse the repository at this point in the history
  • Loading branch information
kares committed May 16, 2018
1 parent d61f1dc commit 0e927fa
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 42 deletions.
118 changes: 78 additions & 40 deletions src/main/java/org/jruby/ext/openssl/SSLSocket.java
Expand Up @@ -48,26 +48,15 @@
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLPeerUnverifiedException;

import org.jruby.Ruby;
import org.jruby.RubyArray;
import org.jruby.RubyClass;
import org.jruby.RubyHash;
import org.jruby.RubyIO;
import org.jruby.RubyModule;
import org.jruby.RubyNumeric;
import org.jruby.RubyObject;
import org.jruby.RubyString;
import org.jruby.RubyThread;
import org.jruby.*;
import org.jruby.anno.JRubyMethod;
import org.jruby.exceptions.RaiseException;
import org.jruby.ext.openssl.x509store.X509Utils;
import org.jruby.runtime.Arity;
import org.jruby.runtime.Block;
import org.jruby.runtime.ObjectAllocator;
import org.jruby.runtime.ThreadContext;
import org.jruby.runtime.*;
import org.jruby.runtime.builtin.IRubyObject;
import org.jruby.runtime.callsite.FunctionalCachingCallSite;
import org.jruby.runtime.callsite.RespondToCallSite;
import org.jruby.util.ByteList;
import org.jruby.runtime.Visibility;

import static org.jruby.ext.openssl.SSL.newSSLErrorWaitReadable;
import static org.jruby.ext.openssl.SSL.newSSLErrorWaitWritable;
Expand All @@ -86,12 +75,45 @@ public IRubyObject allocate(Ruby runtime, RubyClass klass) {
}
};

/**
 * Slots of the extra call-site array that createSSLSocket builds (one entry per
 * constant, in values() order) and registers on the SSLSocket class via
 * defineClassUnder(..., extraCallSites). callSite(CallSiteIndex) reads a slot
 * back by ordinal(), so each constant's position is its array index.
 * <p>
 * Constants whose name starts with "_respond_to" are given a RespondToCallSite;
 * every other constant gets a FunctionalCachingCallSite for its {@link #method}.
 */
private enum CallSiteIndex {

// self
hostname("hostname"),
//sync_close("sync_close"),
//sync_close_w("sync_close="),
// io
// RespondToCallSite slot used as io.respond_to?(:nonblock=) by set_io_nonblock_checked;
// createSSLSocket does not pass the "nonblock=" string to this call site
_respond_to_nonblock_w("nonblock="),
nonblock_w("nonblock="),
sync("sync"),
sync_w("sync="),
flush("flush"),
close("close"),
closed_p("closed?"),
// ssl_context
verify_mode("verify_mode");

// Ruby method name handed to FunctionalCachingCallSite (not used for "_respond_to" slots)
final String method;

CallSiteIndex(String method) { this.method = method; }

}

public static void createSSLSocket(final Ruby runtime, final RubyModule SSL) { // OpenSSL::SSL
CallSiteIndex[] values = CallSiteIndex.values();
CallSite[] extraCallSites = new CallSite[values.length];
for (int i=0; i<values.length; i++) {
if (values[i].name().startsWith("_respond_to")) {
extraCallSites[i] = new RespondToCallSite();
}
else {
extraCallSites[i] = new FunctionalCachingCallSite(values[i].method);
}
}

RubyClass SSLSocket = runtime.defineClassUnder("SSLSocket", runtime.getObject(), ALLOCATOR, SSL, extraCallSites);

final ThreadContext context = runtime.getCurrentContext();
RubyClass SSLSocket = SSL.defineClassUnder("SSLSocket", runtime.getObject(), ALLOCATOR);
// SSLSocket.addReadAttribute(context, "io");
// SSLSocket.defineAlias("to_io", "io");
// SSLSocket.addReadAttribute(context, "context");

SSLSocket.addReadWriteAttribute(context, "sync_close");
SSLSocket.addReadWriteAttribute(context, "hostname");
SSLSocket.defineAnnotatedMethods(SSLSocket.class);
Expand Down Expand Up @@ -123,6 +145,10 @@ private static RaiseException newSSLErrorFromHandshake(Ruby runtime, SSLHandshak
return SSL.newSSLError(runtime, cause);
}

/**
 * Fetches the cached call site registered for {@code index} from the extra
 * call-site array attached to this object's metaclass.
 * <p>
 * NOTE(review): assumes {@code getMetaClass()} is the class created with the
 * extra call sites; a Ruby-level subclass or a singleton class may not carry
 * that array. TODO confirm against JRuby's RubyClass.
 *
 * @param index which call site to fetch (its ordinal is the array slot)
 * @return the call site stored at that slot
 */
private CallSite callSite(final CallSiteIndex index) {
    final CallSite[] sites = getMetaClass().getExtraCallSites();
    final int slot = index.ordinal();
    return sites[slot];
}

private SSLContext sslContext;
private SSLEngine engine;
private RubyIO io;
Expand Down Expand Up @@ -152,29 +178,36 @@ public IRubyObject initialize(final ThreadContext context, final IRubyObject[] a
if ( ! ( args[0] instanceof RubyIO ) ) {
throw runtime.newTypeError("IO expected but got " + args[0].getMetaClass().getName());
}
setInstanceVariable("@io", this.io = (RubyIO) args[0]); // compat (we do not read @io)
// Ruby 2.3 : @io.nonblock = true if @io.respond_to?(:nonblock=)
if (io.respondsTo("nonblock=")) {
io.callMethod(context, "nonblock=", runtime.getTrue());
}
setInstanceVariable("@context", this.sslContext); // only compat (we do not use @context)
setInstanceVariable("@io", this.io = (RubyIO) args[0]);
set_io_nonblock_checked(context, runtime.getTrue());
// This is a bit of a hack: SSLSocket should share code with
// RubyBasicSocket, which always sets sync to true.
// Instead we set it here for now.
this.set_sync(context, runtime.getTrue()); // io.sync = true
this.callMethod(context, "sync_close=", runtime.getFalse());
set_sync(context, runtime.getTrue()); // io.sync = true
setInstanceVariable("@sync_close", runtime.getFalse()); // self.sync_close = false
sslContext.setup(context);
return Utils.invokeSuper(context, this, args, Block.NULL_BLOCK); // super()
}

/**
 * Ruby 2.3 compatibility: performs {@code @io.nonblock = value}, but only when
 * the wrapped IO responds to {@code nonblock=}.
 *
 * @param context current thread context
 * @param value   the flag to pass to {@code io.nonblock=}
 * @return the result of {@code io.nonblock=} if it was invoked, otherwise {@code nil}
 */
private IRubyObject set_io_nonblock_checked(final ThreadContext context, RubyBoolean value) {
    final CallSite respondTo = callSite(CallSiteIndex._respond_to_nonblock_w);
    final IRubyObject canSetNonblock = respondTo.call(context, io, io, context.runtime.newSymbol("nonblock="));
    if ( ! canSetNonblock.isTrue() ) return context.nil;
    return callSite(CallSiteIndex.nonblock_w).call(context, io, io, value);
}

private SSLEngine ossl_ssl_setup(final ThreadContext context)
throws NoSuchAlgorithmException, KeyManagementException {
SSLEngine engine = this.engine;
if ( engine != null ) return engine;

// Server Name Indication (SNI) RFC 3546
// SNI support will not be attempted unless hostname is explicitly set by the caller
String peerHost = this.callMethod(context, "hostname").toString();
IRubyObject hostname = callSite(CallSiteIndex.hostname).call(context, this, this); // self.hostname
String peerHost = hostname.toString();
final int peerPort = socketChannelImpl().getRemotePort();
engine = sslContext.createSSLEngine(peerHost, peerPort);

Expand All @@ -199,12 +232,12 @@ private SSLEngine ossl_ssl_setup(final ThreadContext context)

/**
 * Ruby {@code SSLSocket#sync}: delegates to the wrapped IO's {@code sync}.
 */
@JRubyMethod(name = "sync")
public IRubyObject sync(final ThreadContext context) {
// NOTE(review): this is a rendered diff; the callMethod line appears to be the
// pre-commit version that the call-site line below replaces
return this.io.callMethod(context, "sync");
return callSite(CallSiteIndex.sync).call(context, io, io); // io.sync
}

/**
 * Ruby {@code SSLSocket#sync=}: delegates to the wrapped IO's {@code sync=}
 * and returns that call's result.
 */
@JRubyMethod(name = "sync=")
public IRubyObject set_sync(final ThreadContext context, final IRubyObject sync) {
// NOTE(review): this is a rendered diff; the callMethod line appears to be the
// pre-commit version that the call-site line below replaces
return this.io.callMethod(context, "sync=", sync);
return callSite(CallSiteIndex.sync_w).call(context, io, io, sync); // io.sync = sync
}

@JRubyMethod
Expand Down Expand Up @@ -302,8 +335,8 @@ private IRubyObject acceptImpl(final ThreadContext context, final boolean blocki
if ( ! initialHandshake ) {
final SSLEngine engine = ossl_ssl_setup(context);
engine.setUseClientMode(false);
final IRubyObject verify_mode = sslContext.callMethod(context, "verify_mode");
if ( ! verify_mode.isNil() ) {
final IRubyObject verify_mode = callSite(CallSiteIndex.verify_mode).call(context, sslContext, sslContext);
if ( verify_mode != context.nil ) {
final int verify = RubyNumeric.fix2int(verify_mode);
if ( verify == 0 ) { // VERIFY_NONE
engine.setNeedClientAuth(false);
Expand Down Expand Up @@ -863,7 +896,7 @@ private IRubyObject syswriteImpl(final ThreadContext context,
written = write(buff, blocking);
}

this.io.callMethod(context, "flush");
callSite(CallSiteIndex.flush).call(context, io, io); // io.flush

return runtime.newFixnum(written);
}
Expand Down Expand Up @@ -931,17 +964,22 @@ private void close(boolean force) {

/**
 * Ruby {@code SSLSocket#sysclose}.
 * <p>
 * Does nothing and returns nil when the wrapped IO is already closed.
 * Otherwise it calls {@code close(sslContext.isProtocolForClient())}, so a
 * shutdown is only attempted for client-side sockets. If sync_close is truthy,
 * it then also closes the wrapped IO and returns that call's result; otherwise
 * it returns nil.
 */
@JRubyMethod
public IRubyObject sysclose(final ThreadContext context) {
//if ( isClosed() ) return context.runtime.getNil();
if ( this.io.callMethod(context, "closed?").isTrue() ) {
return context.runtime.getNil();
} // Ruby 2.3
if ( io_closed_p(context).isTrue() ) return context.nil;

// no need to try shutdown when it's a server
close( sslContext.isProtocolForClient() );

if ( this.callMethod(context, "sync_close").isTrue() ) {
return this.io.callMethod(context, "close");
}
return context.runtime.getNil();
// NOTE(review): reads the @sync_close ivar directly instead of dispatching to
// #sync_close, so a Ruby-level override of the reader is not consulted
if ( getInstanceVariable("@sync_close").isTrue() ) return io_close(context);

return context.nil;
}

/**
 * Calls {@code io.closed?} on the wrapped IO through its cached call site.
 *
 * @return whatever the IO's {@code closed?} returns
 */
private IRubyObject io_closed_p(final ThreadContext context) {
    final CallSite closedSite = callSite(CallSiteIndex.closed_p);
    return closedSite.call(context, io, io);
}

/**
 * Calls {@code io.close} on the wrapped IO through its cached call site.
 *
 * @return whatever the IO's {@code close} returns
 */
private IRubyObject io_close(final ThreadContext context) {
    final CallSite closeSite = callSite(CallSiteIndex.close);
    return closeSite.call(context, io, io);
}

@JRubyMethod
Expand Down
4 changes: 2 additions & 2 deletions src/test/ruby/ssl/test_socket.rb
Expand Up @@ -18,12 +18,12 @@ def test_attr_methods

assert socket.io
assert_equal socket.io, socket.to_io
assert ! socket.respond_to?(:'io=')
assert ! socket.respond_to?('io=')
# due compatibility :
assert_equal socket.io, socket.instance_variable_get(:@io)

assert socket.context
assert ! socket.respond_to?(:'context=')
assert ! socket.respond_to?('context=')
# due compatibility :
assert_equal socket.context, socket.instance_variable_get(:@context)

Expand Down

0 comments on commit 0e927fa

Please sign in to comment.