Skip to content

Commit

Permalink
align BN#hash with eql? (+ equals/hashCode on Java) & rename getBigInteger to asBigInteger
Browse files Browse the repository at this point in the history
  • Loading branch information
kares committed Nov 29, 2016
1 parent dec94fc commit 7308da1
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 25 deletions.
74 changes: 49 additions & 25 deletions src/main/java/org/jruby/ext/openssl/BN.java
Expand Up @@ -173,7 +173,7 @@ public synchronized IRubyObject initialize_copy(final IRubyObject that) {
@JRubyMethod(name = "copy")
public IRubyObject copy(IRubyObject other) {
if (this != other) {
this.value = getBigInteger(other);
this.value = asBigInteger(other);
}
return this;
}
Expand Down Expand Up @@ -271,8 +271,23 @@ public IRubyObject inspect() {
return ObjectSupport.inspect(this, Collections.EMPTY_LIST);
}

/**
 * Java-level equality: two BN instances are equal when their wrapped
 * {@link java.math.BigInteger} values are equal. Any non-BN argument
 * (including {@code null}) is never equal.
 *
 * NOTE(review): {@code value} is reassigned by {@code copy}, so a BN used as a
 * key in a hash-based collection can change its equality/hash after insertion.
 */
@Override
public boolean equals(Object other) {
    if (!(other instanceof BN)) {
        return false;
    }
    final BN that = (BN) other;
    return this.value.equals(that.value);
}

/**
 * Hash code consistent with {@link #equals(Object)}: derived solely from the
 * wrapped BigInteger value (scaled by the constant 997).
 */
@Override
public int hashCode() {
    final int valueHash = value.hashCode();
    return valueHash * 997;
}

/**
 * Ruby {@code BN#hash}: delegates to the Java {@link #hashCode()} so that the
 * Ruby-visible hash and the Java hash of a BN are always the same number.
 *
 * NOTE(review): {@code eql?} on this class also accepts Integer arguments, so
 * {@code bn.eql?(1)} can be true while {@code bn.hash} presumably differs from
 * {@code 1.hash}. Confirm whether cross-type eql? is intended.
 *
 * @param context the current thread context (used to reach the runtime)
 * @return the hash wrapped as a Ruby Fixnum
 */
@JRubyMethod(name = "hash")
public RubyInteger hash(final ThreadContext context) {
    final int code = hashCode();
    return context.runtime.newFixnum(code);
}

@JRubyMethod(name = "to_i")
public IRubyObject to_i() {
public RubyInteger to_i() {
if ( value.compareTo( MAX_LONG ) > 0 || value.compareTo( MIN_LONG ) < 0 ) {
return RubyBignum.newBignum(getRuntime(), value);
}
Expand Down Expand Up @@ -321,17 +336,17 @@ public RubyBoolean odd_p(final ThreadContext context) {

@JRubyMethod(name={"cmp", "<=>"})
public IRubyObject cmp(final ThreadContext context, IRubyObject other) {
return context.runtime.newFixnum( value.compareTo( getBigInteger(other) ) );
return context.runtime.newFixnum( value.compareTo( asBigInteger(other) ) );
}

@JRubyMethod(name="ucmp")
public IRubyObject ucmp(final ThreadContext context, IRubyObject other) {
return context.runtime.newFixnum( value.abs().compareTo( getBigInteger(other).abs() ) );
return context.runtime.newFixnum( value.abs().compareTo( asBigInteger(other).abs() ) );
}

@JRubyMethod(name={"eql?", "==", "==="})
public RubyBoolean eql_p(final ThreadContext context, IRubyObject other) {
return context.runtime.newBoolean( value.equals( getBigInteger(other) ) );
return context.runtime.newBoolean( value.equals( asBigInteger(other) ) );
}

@JRubyMethod(name="sqr")
Expand All @@ -347,23 +362,23 @@ public BN not(final ThreadContext context) {

@JRubyMethod(name="+")
public BN add(final ThreadContext context, IRubyObject other) {
return newBN(context.runtime, value.add(getBigInteger(other)));
return newBN(context.runtime, value.add(asBigInteger(other)));
}

@JRubyMethod(name="-")
public BN sub(final ThreadContext context, IRubyObject other) {
return newBN(context.runtime, value.subtract(getBigInteger(other)));
return newBN(context.runtime, value.subtract(asBigInteger(other)));
}

@JRubyMethod(name="*")
public BN mul(final ThreadContext context, IRubyObject other) {
return newBN(context.runtime, value.multiply(getBigInteger(other)));
return newBN(context.runtime, value.multiply(asBigInteger(other)));
}

@JRubyMethod(name="%")
public BN mod(final ThreadContext context, IRubyObject other) {
try {
return newBN(context.runtime, value.mod(getBigInteger(other)));
return newBN(context.runtime, value.mod(asBigInteger(other)));
}
catch (ArithmeticException e) {
throw context.runtime.newZeroDivisionError();
Expand All @@ -374,7 +389,7 @@ public BN mod(final ThreadContext context, IRubyObject other) {
public IRubyObject div(final ThreadContext context, IRubyObject other) {
final Ruby runtime = context.runtime;
try {
BigInteger[] result = value.divideAndRemainder(getBigInteger(other));
BigInteger[] result = value.divideAndRemainder(asBigInteger(other));
return runtime.newArray(newBN(runtime, result[0]), newBN(runtime, result[1]));
}
catch (ArithmeticException e) {
Expand All @@ -384,17 +399,17 @@ public IRubyObject div(final ThreadContext context, IRubyObject other) {

@JRubyMethod(name="&")
public BN and(final ThreadContext context, IRubyObject other) {
return newBN(context.runtime, value.and(getBigInteger(other)));
return newBN(context.runtime, value.and(asBigInteger(other)));
}

@JRubyMethod(name="|")
public BN or(final ThreadContext context, IRubyObject other) {
return newBN(context.runtime, value.or(getBigInteger(other)));
return newBN(context.runtime, value.or(asBigInteger(other)));
}

@JRubyMethod(name="^")
public BN xor(final ThreadContext context, IRubyObject other) {
return newBN(context.runtime, value.xor(getBigInteger(other)));
return newBN(context.runtime, value.xor(asBigInteger(other)));
}

@JRubyMethod(name="**")
Expand Down Expand Up @@ -439,13 +454,13 @@ else if ( other instanceof RubyBignum ) { // inherently too big

@JRubyMethod(name="gcd")
public BN gcd(final ThreadContext context, IRubyObject other) {
return newBN(context.runtime, value.gcd(getBigInteger(other)));
return newBN(context.runtime, value.gcd(asBigInteger(other)));
}

@JRubyMethod(name="mod_sqr")
public BN mod_sqr(final ThreadContext context, IRubyObject other) {
try {
return newBN(context.runtime, value.modPow(TWO, getBigInteger(other)));
return newBN(context.runtime, value.modPow(TWO, asBigInteger(other)));
}
catch (ArithmeticException e) {
throw context.runtime.newZeroDivisionError();
Expand All @@ -455,7 +470,7 @@ public BN mod_sqr(final ThreadContext context, IRubyObject other) {
@JRubyMethod(name="mod_inverse")
public BN mod_inverse(final ThreadContext context, IRubyObject other) {
try {
return newBN(context.runtime, value.modInverse(getBigInteger(other)));
return newBN(context.runtime, value.modInverse(asBigInteger(other)));
}
catch (ArithmeticException e) {
throw context.runtime.newZeroDivisionError();
Expand All @@ -465,7 +480,7 @@ public BN mod_inverse(final ThreadContext context, IRubyObject other) {
@JRubyMethod(name="mod_add")
public BN mod_add(final ThreadContext context, IRubyObject other, IRubyObject mod) {
try {
return newBN(context.runtime, value.add(getBigInteger(other)).mod(getBigInteger(mod)));
return newBN(context.runtime, value.add(asBigInteger(other)).mod(asBigInteger(mod)));
}
catch (ArithmeticException e) {
throw context.runtime.newZeroDivisionError();
Expand All @@ -475,7 +490,7 @@ public BN mod_add(final ThreadContext context, IRubyObject other, IRubyObject mo
@JRubyMethod(name="mod_sub")
public BN mod_sub(final ThreadContext context, IRubyObject other, IRubyObject mod) {
try {
return newBN(context.runtime, value.subtract(getBigInteger(other)).mod(getBigInteger(mod)));
return newBN(context.runtime, value.subtract(asBigInteger(other)).mod(asBigInteger(mod)));
}
catch (ArithmeticException e) {
throw context.runtime.newZeroDivisionError();
Expand All @@ -485,7 +500,7 @@ public BN mod_sub(final ThreadContext context, IRubyObject other, IRubyObject mo
@JRubyMethod(name="mod_mul")
public BN mod_mul(final ThreadContext context, IRubyObject other, IRubyObject mod) {
try {
return newBN(context.runtime, value.multiply(getBigInteger(other)).mod(getBigInteger(mod)));
return newBN(context.runtime, value.multiply(asBigInteger(other)).mod(asBigInteger(mod)));
}
catch (ArithmeticException e) {
throw context.runtime.newZeroDivisionError();
Expand All @@ -495,7 +510,7 @@ public BN mod_mul(final ThreadContext context, IRubyObject other, IRubyObject mo
@JRubyMethod(name="mod_exp")
public BN mod_exp(final ThreadContext context, IRubyObject other, IRubyObject mod) {
try {
return newBN(context.runtime, value.modPow(getBigInteger(other), getBigInteger(mod)));
return newBN(context.runtime, value.modPow(asBigInteger(other), asBigInteger(mod)));
}
catch (ArithmeticException e) {
throw context.runtime.newZeroDivisionError();
Expand Down Expand Up @@ -657,8 +672,8 @@ public static IRubyObject generate_prime(IRubyObject recv, IRubyObject[] args) {
int argc = Arity.checkArgumentCount(runtime, args, 1, 4);
int bits = RubyNumeric.num2int(args[0]);
boolean safe = argc > 1 ? args[1] != runtime.getFalse() : true;
BigInteger add = argc > 2 ? getBigInteger(args[2]) : null;
BigInteger rem = argc > 3 ? getBigInteger(args[3]) : null;
BigInteger add = argc > 2 ? asBigInteger(args[2]) : null;
BigInteger rem = argc > 3 ? asBigInteger(args[3]) : null;
if (bits < 3) {
if (safe) throw runtime.newArgumentError("bits < 3");
if (bits < 2) throw runtime.newArgumentError("bits < 2");
Expand Down Expand Up @@ -794,12 +809,12 @@ public static BigInteger getRandomBI(int bits, int top, boolean bottom, Random r

@JRubyMethod(name = "rand_range", meta = true)
public static IRubyObject rand_range(IRubyObject recv, IRubyObject arg) {
return randomValueInRange(recv.getRuntime(), getBigInteger(arg), getSecureRandom());
return randomValueInRange(recv.getRuntime(), asBigInteger(arg), getSecureRandom());
}

@JRubyMethod(name = "pseudo_rand_range", meta = true)
public static IRubyObject pseudo_rand_range(IRubyObject recv, IRubyObject arg) {
return randomValueInRange(recv.getRuntime(), getBigInteger(arg), getRandom());
return randomValueInRange(recv.getRuntime(), asBigInteger(arg), getRandom());
}

private static BN randomValueInRange(Ruby runtime, BigInteger limit, Random random) {
Expand Down Expand Up @@ -850,7 +865,7 @@ public static RaiseException newBNError(Ruby runtime, String message) {
return new RaiseException(runtime, runtime.getModule("OpenSSL").getClass("BNError"), message, true);
}

public static BigInteger getBigInteger(final IRubyObject arg) {
public static BigInteger asBigInteger(final IRubyObject arg) {
if ( arg.isNil() ) return null;

if ( arg instanceof RubyInteger ) {
Expand All @@ -862,6 +877,15 @@ public static BigInteger getBigInteger(final IRubyObject arg) {
throw arg.getRuntime().newTypeError("Cannot convert into OpenSSL::BN");
}

/**
 * Unwraps the {@link BigInteger} held by a BN.
 *
 * @param arg the BN to unwrap; may be {@code null}
 * @return the wrapped value, or {@code null} when {@code arg} is {@code null}
 *         or reports itself as nil
 */
public static BigInteger asBigInteger(final BN arg) {
    // isNil() alone cannot guard a Java null reference: calling it on null throws
    // NullPointerException, so check for null explicitly first.
    if (arg == null || arg.isNil()) {
        return null;
    }
    return arg.value;
}

/**
 * Converts a Ruby value into a {@link BigInteger}.
 *
 * @param arg the Ruby value to convert
 * @return the converted value ({@code null} for Ruby nil)
 * @deprecated renamed; use {@link #asBigInteger(IRubyObject)} instead. This
 *             method simply delegates to it.
 */
@Deprecated
public static BigInteger getBigInteger(final IRubyObject arg) {
    return asBigInteger(arg);
}

@Override
public Object toJava(Class target) {
if ( target.isAssignableFrom(BigInteger.class) || target == Number.class ) return value;
Expand Down
14 changes: 14 additions & 0 deletions src/test/ruby/test_bn.rb
Expand Up @@ -26,6 +26,20 @@ def test_comparable
assert OpenSSL::BN.include? Comparable
end

# Checks BN comparison and equality: cmp/<=>, ==/!= (including against nil),
# eql?, and that #hash agrees with eql? for equal values.
def test_cmp
  bn1 = OpenSSL::BN.new('1')
  bn2 = OpenSSL::BN.new('1')
  bn3 = OpenSSL::BN.new('2')
  # comparing against nil must not raise
  assert_equal(false, bn1 == nil)
  assert_equal(true, bn1 != nil)
  assert_equal(true, bn1 == bn2)
  assert_equal(false, bn1 == bn3)
  assert_equal(true, bn1.eql?(bn2))
  assert_equal(false, bn1.eql?(bn3))
  # equal values must hash equally (eql?/hash contract)
  assert_equal(bn1.hash, bn2.hash)
  assert_not_equal(bn3.hash, bn1.hash)
  # the method named by this test: cmp / <=> (backed by BigInteger#compareTo)
  assert_equal(0, bn1.cmp(bn2))
  assert_equal(-1, bn1 <=> bn3)
  assert_equal(1, bn3 <=> bn1)
end if RUBY_VERSION >= '2.3'

def test_to_bn
bn = OpenSSL::BN.new('4224')
assert_equal bn, 4224.to_bn
Expand Down

0 comments on commit 7308da1

Please sign in to comment.