Skip to content

Commit

Permalink
Showing 1 changed file with 81 additions and 0 deletions.
81 changes: 81 additions & 0 deletions core/src/main/java/org/jruby/truffle/nodes/core/ArrayNodes.java
Original file line number Diff line number Diff line change
@@ -1881,6 +1881,87 @@ public RubyArray initializeCopyObject(RubyArray self, RubyArray from) {

}

/**
 * Implements {@code Array#inject} and its alias {@code Array#reduce}.
 *
 * <p>Call forms handled here:
 * <ul>
 * <li>{@code inject { |memo, e| ... }} - no initial value: the first element seeds the
 * accumulator and the block is yielded for the remaining elements;</li>
 * <li>{@code inject(initial) { |memo, e| ... }} - {@code initial} seeds the accumulator and
 * the block is yielded once per element;</li>
 * <li>{@code inject(:sym)} - no block: each step dispatches {@code memo.sym(e)} as a method
 * call.</li>
 * </ul>
 *
 * <p>Without an initial value, an empty array produces {@code nil}. A single-element array
 * produces that element without yielding or dispatching, matching Ruby's
 * {@code Enumerable#inject}.
 */
@CoreMethod(names = {"inject", "reduce"}, needsBlock = true, optional = 1)
@ImportGuards(ArrayGuards.class)
public abstract static class InjectNode extends YieldingCoreMethodNode {

    /** Dispatch node used by the {@code inject(:sym)} form to call the named method. */
    @Child private CallDispatchHeadNode dispatch;

    public InjectNode(RubyContext context, SourceSection sourceSection) {
        super(context, sourceSection);
        dispatch = DispatchHeadNodeFactory.createMethodCall(context, MissingBehavior.CALL_METHOD_MISSING);
    }

    public InjectNode(InjectNode prev) {
        super(prev);
        dispatch = prev.dispatch;
    }

    /**
     * Block form with no initial value ({@code inject { |memo, e| ... }}).
     *
     * <p>Declared before the {@code Object initial} specializations so that the
     * "argument omitted" placeholder is never used as the accumulator.
     *
     * @return {@code nil} for an empty array; otherwise the result of folding the block over
     *         the elements, seeded with the first element
     */
    @Specialization
    public Object injectWithoutInitial(VirtualFrame frame, RubyArray array, UndefinedPlaceholder initial, RubyProc block) {
        notDesignedForCompilation();

        final Object[] store = array.slowToArray();

        if (store.length == 0) {
            return getContext().getCoreLibrary().getNilObject();
        }

        Object accumulator = store[0];

        // Iterate the snapshot's own length so a block that mutates the array cannot make
        // us index past the end of the snapshot.
        for (int n = 1; n < store.length; n++) {
            accumulator = yield(frame, block, accumulator, store[n]);
        }

        return accumulator;
    }

    /**
     * Block form with an explicit initial value, fast path for arrays whose store is an
     * {@code Object[]} (selected by the {@code isObject} guard).
     *
     * <p>While running in the interpreter, the number of iterations is counted and reported
     * to the enclosing {@link RubyRootNode}, even if the block exits abruptly.
     */
    @Specialization(guards = "isObject")
    public Object injectObject(VirtualFrame frame, RubyArray array, Object initial, RubyProc block) {
        int count = 0;

        final Object[] store = (Object[]) array.getStore();

        Object accumulator = initial;

        try {
            for (int n = 0; n < array.getSize(); n++) {
                if (CompilerDirectives.inInterpreter()) {
                    count++;
                }

                accumulator = yield(frame, block, accumulator, store[n]);
            }
        } finally {
            if (CompilerDirectives.inInterpreter()) {
                ((RubyRootNode) getRootNode()).reportLoopCount(count);
            }
        }

        return accumulator;
    }

    /**
     * Block form with an explicit initial value, generic path for any array storage.
     *
     * <p>Valid for every array size: an empty array simply returns {@code initial}.
     */
    @Specialization
    public Object inject(VirtualFrame frame, RubyArray array, Object initial, RubyProc block) {
        notDesignedForCompilation();

        final Object[] store = array.slowToArray();

        Object accumulator = initial;

        for (int n = 0; n < store.length; n++) {
            accumulator = yield(frame, block, accumulator, store[n]);
        }

        return accumulator;
    }

    /**
     * Symbol form without a block ({@code inject(:sym)}). Folds by dispatching
     * {@code accumulator.sym(element)}, seeded with the first element.
     *
     * @return {@code nil} for an empty array, the sole element for a one-element array,
     *         otherwise the folded result
     */
    @Specialization
    public Object inject(VirtualFrame frame, RubyArray array, RubySymbol symbol, UndefinedPlaceholder unused) {
        notDesignedForCompilation();

        final Object[] store = array.slowToArray();

        if (store.length == 0) {
            return getContext().getCoreLibrary().getNilObject();
        }

        Object accumulator = store[0];

        for (int n = 1; n < store.length; n++) {
            // NOTE(review): the null argument is presumably the (absent) block for the
            // dispatched call - confirm against CallDispatchHeadNode.call.
            accumulator = dispatch.call(frame, accumulator, symbol, null, store[n]);
        }

        return accumulator;
    }

}

@CoreMethod(names = "insert", required = 2)
public abstract static class InsertNode extends ArrayCoreMethodNode {

0 comments on commit bc23c47

Please sign in to comment.