Skip to content

Commit

Permalink
[Truffle] Merge PredicateDispatchHeadNode into CallDispatchHeadNode.
Browse files (browse the repository at this point in the history)
  • Loading branch information
chrisseaton committed Jan 8, 2015
1 parent 1a36da4 commit 02ebbe7
Show file tree
Hide file tree
Showing 11 changed files with 75 additions and 108 deletions.
Expand Up @@ -13,21 +13,21 @@
import com.oracle.truffle.api.nodes.UnexpectedResultException;
import com.oracle.truffle.api.source.SourceSection;
import org.jruby.truffle.nodes.RubyNode;
import org.jruby.truffle.nodes.dispatch.PredicateDispatchHeadNode;
import org.jruby.truffle.nodes.dispatch.*;
import org.jruby.truffle.runtime.RubyContext;
import org.jruby.truffle.runtime.core.RubyArray;

public class WhenSplatNode extends RubyNode {

@Child private RubyNode readCaseExpression;
@Child private RubyNode splat;
@Child private PredicateDispatchHeadNode dispatchCaseEqual;
@Child private CallDispatchHeadNode dispatchCaseEqual;

public WhenSplatNode(RubyContext context, SourceSection sourceSection, RubyNode readCaseExpression, RubyNode splat) {
super(context, sourceSection);
this.readCaseExpression = readCaseExpression;
this.splat = splat;
dispatchCaseEqual = new PredicateDispatchHeadNode(context);
dispatchCaseEqual = DispatchHeadNodeFactory.createMethodCall(context, false, false, null);
}

@Override
Expand All @@ -45,7 +45,7 @@ public boolean executeBoolean(VirtualFrame frame) {
}

for (Object value : array.slowToArray()) {
if (dispatchCaseEqual.call(frame, caseExpression, "===", null, value)) {
if (dispatchCaseEqual.callBoolean(frame, caseExpression, "===", null, value)) {
return true;
}
}
Expand Down
11 changes: 4 additions & 7 deletions core/src/main/java/org/jruby/truffle/nodes/core/ArrayNodes.java
Expand Up @@ -31,10 +31,7 @@
import org.jruby.truffle.nodes.methods.arguments.ReadPreArgumentNode;
import org.jruby.truffle.nodes.methods.locals.ReadLevelVariableNodeFactory;
import org.jruby.truffle.nodes.yield.YieldDispatchHeadNode;
import org.jruby.truffle.runtime.DebugOperations;
import org.jruby.truffle.runtime.RubyArguments;
import org.jruby.truffle.runtime.RubyContext;
import org.jruby.truffle.runtime.UndefinedPlaceholder;
import org.jruby.truffle.runtime.*;
import org.jruby.truffle.runtime.control.BreakException;
import org.jruby.truffle.runtime.control.NextException;
import org.jruby.truffle.runtime.control.RaiseException;
Expand Down Expand Up @@ -438,11 +435,11 @@ public RubyArray orObject(RubyArray a, RubyArray b) {
@CoreMethod(names = {"==", "eql?"}, required = 1)
public abstract static class EqualNode extends ArrayCoreMethodNode {

@Child private PredicateDispatchHeadNode equals;
@Child private CallDispatchHeadNode equals;

public EqualNode(RubyContext context, SourceSection sourceSection) {
super(context, sourceSection);
equals = new PredicateDispatchHeadNode(context);
equals = DispatchHeadNodeFactory.createMethodCall(context, false, false, null);
}

public EqualNode(EqualNode prev) {
Expand Down Expand Up @@ -511,7 +508,7 @@ public boolean equal(VirtualFrame frame, RubyArray a, RubyArray b) {
final Object[] bs = b.slowToArray();

for (int n = 0; n < a.getSize(); n++) {
if (!equals.call(frame, as[n], "==", null, bs[n])) {
if (!equals.callBoolean(frame, as[n], "==", null, bs[n])) {
return false;
}
}
Expand Down
Expand Up @@ -58,11 +58,11 @@ public boolean not(boolean value) {
@CoreMethod(names = "!=", required = 1)
public abstract static class NotEqualNode extends CoreMethodNode {

@Child private PredicateDispatchHeadNode equalNode;
@Child private CallDispatchHeadNode equalNode;

public NotEqualNode(RubyContext context, SourceSection sourceSection) {
super(context, sourceSection);
equalNode = new PredicateDispatchHeadNode(context);
equalNode = DispatchHeadNodeFactory.createMethodCall(context, false, false, null);
}

public NotEqualNode(NotEqualNode prev) {
Expand All @@ -72,7 +72,7 @@ public NotEqualNode(NotEqualNode prev) {

@Specialization
public boolean equal(VirtualFrame frame, Object a, Object b) {
return !equalNode.call(frame, a, "==", null, b);
return !equalNode.callBoolean(frame, a, "==", null, b);
}

}
Expand Down
41 changes: 19 additions & 22 deletions core/src/main/java/org/jruby/truffle/nodes/core/HashNodes.java
Expand Up @@ -18,10 +18,7 @@
import com.oracle.truffle.api.utilities.BranchProfile;
import org.jruby.runtime.Visibility;
import org.jruby.truffle.nodes.RubyRootNode;
import org.jruby.truffle.nodes.dispatch.CallDispatchHeadNode;
import org.jruby.truffle.nodes.dispatch.DispatchHeadNode;
import org.jruby.truffle.nodes.dispatch.DispatchHeadNodeFactory;
import org.jruby.truffle.nodes.dispatch.PredicateDispatchHeadNode;
import org.jruby.truffle.nodes.dispatch.*;
import org.jruby.truffle.nodes.hash.FindEntryNode;
import org.jruby.truffle.nodes.yield.YieldDispatchHeadNode;
import org.jruby.truffle.runtime.DebugOperations;
Expand All @@ -43,11 +40,11 @@ public abstract class HashNodes {
@CoreMethod(names = "==", required = 1)
public abstract static class EqualNode extends HashCoreMethodNode {

@Child private PredicateDispatchHeadNode equalNode;
@Child private CallDispatchHeadNode equalNode;

public EqualNode(RubyContext context, SourceSection sourceSection) {
super(context, sourceSection);
equalNode = new PredicateDispatchHeadNode(context);
equalNode = DispatchHeadNodeFactory.createMethodCall(context, false, false, null);
}

public EqualNode(EqualNode prev) {
Expand Down Expand Up @@ -229,7 +226,7 @@ public RubyHash construct(Object[] args) {
@CoreMethod(names = "[]", required = 1)
public abstract static class GetIndexNode extends HashCoreMethodNode {

@Child private PredicateDispatchHeadNode eqlNode;
@Child private CallDispatchHeadNode eqlNode;
@Child private YieldDispatchHeadNode yield;
@Child private FindEntryNode findEntryNode;

Expand All @@ -238,7 +235,7 @@ public abstract static class GetIndexNode extends HashCoreMethodNode {

public GetIndexNode(RubyContext context, SourceSection sourceSection) {
super(context, sourceSection);
eqlNode = new PredicateDispatchHeadNode(context);
eqlNode = DispatchHeadNodeFactory.createMethodCall(context, false, false, null);
yield = new YieldDispatchHeadNode(context);
findEntryNode = new FindEntryNode(context, sourceSection);
}
Expand Down Expand Up @@ -270,7 +267,7 @@ public Object getPackedArray(VirtualFrame frame, RubyHash hash, Object key) {
final int size = hash.getSize();

for (int n = 0; n < HashOperations.SMALL_HASH_SIZE; n++) {
if (n < size && eqlNode.call(frame, store[n * 2], "eql?", null, key)) {
if (n < size && eqlNode.callBoolean(frame, store[n * 2], "eql?", null, key)) {
return store[n * 2 + 1];
}
}
Expand Down Expand Up @@ -319,14 +316,14 @@ public Object getBuckets(VirtualFrame frame, RubyHash hash, Object key) {
@CoreMethod(names = "[]=", required = 2)
public abstract static class SetIndexNode extends HashCoreMethodNode {

@Child private PredicateDispatchHeadNode eqlNode;
@Child private CallDispatchHeadNode eqlNode;

private final BranchProfile considerExtendProfile = BranchProfile.create();
private final BranchProfile extendProfile = BranchProfile.create();

public SetIndexNode(RubyContext context, SourceSection sourceSection) {
super(context, sourceSection);
eqlNode = new PredicateDispatchHeadNode(context);
eqlNode = DispatchHeadNodeFactory.createMethodCall(context, false, false, null);
}

public SetIndexNode(SetIndexNode prev) {
Expand All @@ -353,7 +350,7 @@ public Object setPackedArray(VirtualFrame frame, RubyHash hash, Object key, Obje
final int size = hash.getSize();

for (int n = 0; n < HashOperations.SMALL_HASH_SIZE; n++) {
if (n < size && eqlNode.call(frame, store[n * 2], "eql?", null, key)) {
if (n < size && eqlNode.callBoolean(frame, store[n * 2], "eql?", null, key)) {
store[n * 2 + 1] = value;
return value;
}
Expand Down Expand Up @@ -428,12 +425,12 @@ public RubyHash empty(RubyHash hash) {
@CoreMethod(names = "delete", required = 1)
public abstract static class DeleteNode extends HashCoreMethodNode {

@Child private PredicateDispatchHeadNode eqlNode;
@Child private CallDispatchHeadNode eqlNode;
@Child private FindEntryNode findEntryNode;

public DeleteNode(RubyContext context, SourceSection sourceSection) {
super(context, sourceSection);
eqlNode = new PredicateDispatchHeadNode(context);
eqlNode = DispatchHeadNodeFactory.createMethodCall(context, false, false, null);
findEntryNode = new FindEntryNode(context, sourceSection);
}

Expand All @@ -457,7 +454,7 @@ public Object deletePackedArray(VirtualFrame frame, RubyHash hash, Object key) {
final int size = hash.getSize();

for (int n = 0; n < HashOperations.SMALL_HASH_SIZE * 2; n += 2) {
if (n < size && eqlNode.call(frame, store[n], "eql?", null, key)) {
if (n < size && eqlNode.callBoolean(frame, store[n], "eql?", null, key)) {
final Object value = store[n + 1];

// Move the later values down
Expand Down Expand Up @@ -742,11 +739,11 @@ public RubyString inspectPackedArray(VirtualFrame frame, RubyHash hash) {
@CoreMethod(names = { "has_key?", "key?" }, required = 1)
public abstract static class KeyNode extends HashCoreMethodNode {

@Child private PredicateDispatchHeadNode eqlNode;
@Child private CallDispatchHeadNode eqlNode;

public KeyNode(RubyContext context, SourceSection sourceSection) {
super(context, sourceSection);
eqlNode = new PredicateDispatchHeadNode(context);
eqlNode = DispatchHeadNodeFactory.createMethodCall(context, false, false, null);
}

public KeyNode(KeyNode prev) {
Expand All @@ -767,7 +764,7 @@ public boolean keyPackedArray(VirtualFrame frame, RubyHash hash, Object key) {
final Object[] store = (Object[]) hash.getStore();

for (int n = 0; n < store.length; n += 2) {
if (n < size && eqlNode.call(frame, store[n], "eql?", null, key)) {
if (n < size && eqlNode.callBoolean(frame, store[n], "eql?", null, key)) {
return true;
}
}
Expand All @@ -780,7 +777,7 @@ public boolean keyBuckets(VirtualFrame frame, RubyHash hash, Object key) {
notDesignedForCompilation();

for (KeyValue keyValue : HashOperations.verySlowToKeyValues(hash)) {
if (eqlNode.call(frame, keyValue.getKey(), "eql?", null, key)) {
if (eqlNode.callBoolean(frame, keyValue.getKey(), "eql?", null, key)) {
return true;
}
}
Expand Down Expand Up @@ -903,7 +900,7 @@ public RubyArray mapBuckets(VirtualFrame frame, RubyHash hash, RubyProc block) {
@CoreMethod(names = "merge", required = 1)
public abstract static class MergeNode extends HashCoreMethodNode {

@Child private PredicateDispatchHeadNode eqlNode;
@Child private CallDispatchHeadNode eqlNode;

private final BranchProfile nothingFromFirstProfile = BranchProfile.create();
private final BranchProfile considerNothingFromSecondProfile = BranchProfile.create();
Expand All @@ -915,7 +912,7 @@ public abstract static class MergeNode extends HashCoreMethodNode {

public MergeNode(RubyContext context, SourceSection sourceSection) {
super(context, sourceSection);
eqlNode = new PredicateDispatchHeadNode(context);
eqlNode = DispatchHeadNodeFactory.createMethodCall(context, false, false, null);
}

public MergeNode(MergeNode prev) {
Expand Down Expand Up @@ -951,7 +948,7 @@ public RubyHash mergePackedArrayPackedArray(VirtualFrame frame, RubyHash hash, R

for (int b = 0; b < HashOperations.SMALL_HASH_SIZE; b++) {
if (b < storeBSize) {
if (eqlNode.call(frame, storeA[a * 2], "eql?", null, storeB[b * 2])) {
if (eqlNode.callBoolean(frame, storeA[a * 2], "eql?", null, storeB[b * 2])) {
merge = false;
break;
}
Expand Down
12 changes: 6 additions & 6 deletions core/src/main/java/org/jruby/truffle/nodes/core/KernelNodes.java
Expand Up @@ -118,7 +118,7 @@ public RubyString backtick(RubyString command) {
public abstract static class SameOrEqualNode extends CoreMethodNode {

@Child private BasicObjectNodes.ReferenceEqualNode referenceEqualNode;
@Child private PredicateDispatchHeadNode equalNode;
@Child private CallDispatchHeadNode equalNode;

public SameOrEqualNode(RubyContext context, SourceSection sourceSection) {
super(context, sourceSection);
Expand All @@ -139,9 +139,9 @@ protected boolean areSame(VirtualFrame frame, Object left, Object right) {
protected boolean areEqual(VirtualFrame frame, Object left, Object right) {
if (equalNode == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
equalNode = insert(new PredicateDispatchHeadNode(getContext()));
equalNode = insert(DispatchHeadNodeFactory.createMethodCall(getContext(), false, false, null));
}
return equalNode.call(frame, left, "==", null, right);
return equalNode.callBoolean(frame, left, "==", null, right);
}

public abstract boolean executeSameOrEqual(VirtualFrame frame, Object a, Object b);
Expand Down Expand Up @@ -176,11 +176,11 @@ public RubyNilClass equal(Object other) {
@CoreMethod(names = "!~", required = 1)
public abstract static class NotMatchNode extends CoreMethodNode {

@Child private PredicateDispatchHeadNode matchNode;
@Child private CallDispatchHeadNode matchNode;

public NotMatchNode(RubyContext context, SourceSection sourceSection) {
super(context, sourceSection);
matchNode = new PredicateDispatchHeadNode(context);
matchNode = DispatchHeadNodeFactory.createMethodCall(context, false, false, null);
}

public NotMatchNode(NotMatchNode prev) {
Expand All @@ -190,7 +190,7 @@ public NotMatchNode(NotMatchNode prev) {

@Specialization
public boolean notMatch(VirtualFrame frame, Object self, Object other) {
return !matchNode.call(frame, self, "=~", null, other);
return !matchNode.callBoolean(frame, self, "=~", null, other);
}

}
Expand Down
23 changes: 10 additions & 13 deletions core/src/main/java/org/jruby/truffle/nodes/core/RangeNodes.java
Expand Up @@ -16,10 +16,7 @@
import com.oracle.truffle.api.utilities.BranchProfile;
import org.jruby.truffle.nodes.RubyNode;
import org.jruby.truffle.nodes.RubyRootNode;
import org.jruby.truffle.nodes.dispatch.CallDispatchHeadNode;
import org.jruby.truffle.nodes.dispatch.DispatchHeadNode;
import org.jruby.truffle.nodes.dispatch.DispatchHeadNodeFactory;
import org.jruby.truffle.nodes.dispatch.PredicateDispatchHeadNode;
import org.jruby.truffle.nodes.dispatch.*;
import org.jruby.truffle.runtime.RubyContext;
import org.jruby.truffle.runtime.control.BreakException;
import org.jruby.truffle.runtime.control.NextException;
Expand Down Expand Up @@ -234,15 +231,15 @@ public Object each(RubyRange.ObjectRange range) {
@CoreMethod(names = {"include?", "==="}, required = 1)
public abstract static class IncludeNode extends CoreMethodNode {

@Child private PredicateDispatchHeadNode callLess;
@Child private PredicateDispatchHeadNode callGreater;
@Child private PredicateDispatchHeadNode callGreaterEqual;
@Child private CallDispatchHeadNode callLess;
@Child private CallDispatchHeadNode callGreater;
@Child private CallDispatchHeadNode callGreaterEqual;

public IncludeNode(RubyContext context, SourceSection sourceSection) {
super(context, sourceSection);
callLess = new PredicateDispatchHeadNode(context);
callGreater = new PredicateDispatchHeadNode(context);
callGreaterEqual = new PredicateDispatchHeadNode(context);
callLess = DispatchHeadNodeFactory.createMethodCall(context, false, false, null);
callGreater = DispatchHeadNodeFactory.createMethodCall(context, false, false, null);
callGreaterEqual = DispatchHeadNodeFactory.createMethodCall(context, false, false, null);
}

public IncludeNode(IncludeNode prev) {
Expand All @@ -261,16 +258,16 @@ public boolean include(RubyRange.IntegerFixnumRange range, int value) {
public boolean include(VirtualFrame frame, RubyRange.ObjectRange range, Object value) {
notDesignedForCompilation();

if (callLess.call(frame, value, "<", null, range.getBegin())) {
if (callLess.callBoolean(frame, value, "<", null, range.getBegin())) {
return false;
}

if (range.doesExcludeEnd()) {
if (callGreaterEqual.call(frame, value, ">=", null, range.getEnd())) {
if (callGreaterEqual.callBoolean(frame, value, ">=", null, range.getEnd())) {
return false;
}
} else {
if (callGreater.call(frame, value, ">", null, range.getEnd())) {
if (callGreater.callBoolean(frame, value, ">", null, range.getEnd())) {
return false;
}
}
Expand Down
Expand Up @@ -11,13 +11,17 @@

import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.frame.VirtualFrame;
import org.jruby.truffle.nodes.cast.BooleanCastNode;
import org.jruby.truffle.nodes.cast.BooleanCastNodeFactory;
import org.jruby.truffle.runtime.LexicalScope;
import org.jruby.truffle.runtime.RubyContext;
import org.jruby.truffle.runtime.control.RaiseException;
import org.jruby.truffle.runtime.core.RubyProc;

public class CallDispatchHeadNode extends DispatchHeadNode {

@Child private BooleanCastNode booleanCastNode;

public CallDispatchHeadNode(RubyContext context, boolean ignoreVisibility, boolean indirect, MissingBehavior missingBehavior, LexicalScope lexicalScope) {
super(context, ignoreVisibility, indirect, missingBehavior, lexicalScope, DispatchAction.CALL_METHOD);
}
Expand All @@ -36,6 +40,21 @@ public Object call(
argumentsObjects);
}

/**
 * Dispatches {@code methodName} on {@code receiverObject} and converts the
 * dispatch result to a Java {@code boolean} through a {@link BooleanCastNode}.
 * NOTE(review): the exact truthiness rules are whatever BooleanCastNode
 * implements (presumably Ruby's nil/false-are-falsy) — confirm there.
 *
 * @param frame            the current frame
 * @param receiverObject   the object the method is dispatched on
 * @param methodName       the name of the method to call
 * @param blockObject      the block to pass, or {@code null} for none
 * @param argumentsObjects the call arguments
 * @return the dispatch result cast to {@code boolean}
 */
public boolean callBoolean(
        VirtualFrame frame,
        Object receiverObject,
        Object methodName,
        RubyProc blockObject,
        Object... argumentsObjects) {
    if (booleanCastNode == null) {
        // Lazily creating and inserting the cast node rewrites the AST, so the
        // compiled code for this tree must be invalidated as well as left —
        // a plain transferToInterpreter() would leave stale compiled code that
        // keeps deoptimizing through this branch.
        CompilerDirectives.transferToInterpreterAndInvalidate();
        // The cast node is built without a child (null) because the value to
        // cast is handed to it directly via executeBoolean(frame, value) below.
        booleanCastNode = insert(BooleanCastNodeFactory.create(context, getSourceSection(), null));
    }

    final Object result = dispatch(frame, receiverObject, methodName, blockObject, argumentsObjects);
    return booleanCastNode.executeBoolean(frame, result);
}

public double callFloat(
VirtualFrame frame,
Object receiverObject,
Expand Down
Expand Up @@ -9,6 +9,7 @@
*/
package org.jruby.truffle.nodes.dispatch;

import org.jruby.truffle.runtime.LexicalScope;
import org.jruby.truffle.runtime.RubyContext;

public class DispatchHeadNodeFactory {
Expand Down

0 comments on commit 02ebbe7

Please sign in to comment.