diff --git a/ext/bigdecimal/bigdecimal.c b/ext/bigdecimal/bigdecimal.c index 76b9a62f..a0c1b59f 100644 --- a/ext/bigdecimal/bigdecimal.c +++ b/ext/bigdecimal/bigdecimal.c @@ -145,6 +145,9 @@ bdvalue_nullable(BDVALUE v) #define BIGDECIMAL_POSITIVE_P(bd) ((bd)->sign > 0) #define BIGDECIMAL_NEGATIVE_P(bd) ((bd)->sign < 0) +#define MULT_BY_RB_INTEGER_THRESHOLD 150 +#define DIV_BY_RB_INTEGER_THRESHOLD 300 + /* * ================== Memory allocation ============================ */ @@ -5873,6 +5876,32 @@ VpSetPTR(Real *a, Real *b, Real *c, size_t *a_pos, size_t *b_pos, size_t *c_pos, return word_shift; } +static int +VpMultWithRubyInteger(Real *c, Real *a, Real *b) +{ + Real *ap, *bp; + VALUE a2, b2, ab; + BDVALUE c2; + ap = VpCopy(NULL, a); + bp = VpCopy(NULL, b); + ap->exponent = a->Prec; + bp->exponent = b->Prec; + a2 = TypedData_Wrap_Struct(rb_cBigDecimal, &BigDecimal_data_type, 0); + b2 = TypedData_Wrap_Struct(rb_cBigDecimal, &BigDecimal_data_type, 0); + BigDecimal_wrap_struct(a2, ap); + BigDecimal_wrap_struct(b2, bp); + ab = rb_funcall(BigDecimal_to_i(a2), '*', 1, BigDecimal_to_i(b2)); + + c2 = GetBDValueMust(ab); + VpAsgn(c, c2.real, 1); + if (!AddExponent(c, a->exponent)) return 0; + if (!AddExponent(c, b->exponent)) return 0; + if (!AddExponent(c, -a->Prec)) return 0; + if (!AddExponent(c, -b->Prec)) return 0; + RB_GC_GUARD(c2.bigdecimal); + return 1; +} + /* * Return number of significant digits * c = a * b , Where a = a0a1a2 ... an @@ -5926,6 +5955,11 @@ VpMult(Real *c, Real *a, Real *b) MxIndC = c->MaxPrec - 1; MxIndAB = a->Prec + b->Prec - 1; + if (a->Prec >= MULT_BY_RB_INTEGER_THRESHOLD && b->Prec >= MULT_BY_RB_INTEGER_THRESHOLD) { + if (!VpMultWithRubyInteger(c, a, b)) return 0; + goto Exit; + } + if (MxIndC < MxIndAB) { /* The Max. prec. of c < Prec(a)+Prec(b) */ w = c; c = NewZeroNolimit(1, (size_t)((MxIndAB + 1) * BASE_FIG)); @@ -6000,6 +6034,46 @@ VpMult(Real *c, Real *a, Real *b) return c->Prec*BASE_FIG; } +static void +VpDivdWithRubyInteger(Real *c, Real *r, Real *a, Real *b) +{ + Real *ap, *bp; + BDVALUE c2, r2; + VALUE a2, b2, divmod, div, mod; + size_t div_prec = c->MaxPrec - 1; + size_t base_prec = b->Prec; + + ap = VpCopy(NULL, a); + bp = VpCopy(NULL, b); + VpSetSign(ap, 1); + VpSetSign(bp, 1); + ap->exponent = base_prec + div_prec; + bp->exponent = base_prec; + a2 = TypedData_Wrap_Struct(rb_cBigDecimal, &BigDecimal_data_type, 0); + b2 = TypedData_Wrap_Struct(rb_cBigDecimal, &BigDecimal_data_type, 0); + BigDecimal_wrap_struct(a2, ap); + BigDecimal_wrap_struct(b2, bp); + divmod = rb_funcall(BigDecimal_to_i(a2), rb_intern("divmod"), 1, BigDecimal_to_i(b2)); + if (RB_TYPE_P(divmod, T_ARRAY) && RARRAY_LEN(divmod) == 2) { + div = RARRAY_AREF(divmod, 0); + mod = RARRAY_AREF(divmod, 1); + } else { + div = mod = Qnil; + } + + c2 = GetBDValueMust(div); + r2 = GetBDValueMust(mod); + VpAsgn(c, c2.real, VpGetSign(a) * VpGetSign(b)); + VpAsgn(r, r2.real, VpGetSign(a)); + RB_GC_GUARD(c2.bigdecimal); + RB_GC_GUARD(r2.bigdecimal); + AddExponent(c, a->exponent); + AddExponent(c, -b->exponent); + AddExponent(c, -div_prec); + AddExponent(r, a->exponent); + AddExponent(r, -base_prec - div_prec); +} + /* * c = a / b, remainder = r * XXXX_YYYY_ZZZZ / 0001 = XXXX_YYYY_ZZZZ @@ -6041,6 +6115,11 @@ VpDivd(Real *c, Real *r, Real *a, Real *b) if (word_a > word_r || word_b + word_c - 2 >= word_r) goto space_error; + if (word_c >= DIV_BY_RB_INTEGER_THRESHOLD && word_b >= DIV_BY_RB_INTEGER_THRESHOLD) { + VpDivdWithRubyInteger(c, r, a, b); + goto Exit; + } + for (i = 0; i < word_a; ++i) r->frac[i] = a->frac[i]; for (i = word_a; i < word_r; ++i) r->frac[i] = 0; for (i = 0; i < word_c; ++i) c->frac[i] = 0; diff --git a/test/bigdecimal/test_vp_operation.rb b/test/bigdecimal/test_vp_operation.rb index 075df0b6..6bb38b22 100644 --- a/test/bigdecimal/test_vp_operation.rb +++ b/test/bigdecimal/test_vp_operation.rb @@ -5,6 +5,9 @@ class TestVpOperation < Test::Unit::TestCase include TestBigDecimalBase + INTEGER_MULT_THRESHOLD = 150 + INTEGER_DIV_THRESHOLD = 300 + def setup super unless BigDecimal.instance_methods.include?(:vpdivd) @@ -16,9 +19,15 @@ def setup def test_vpmult assert_equal(BigDecimal('121932631112635269'), BigDecimal('123456789').vpmult(BigDecimal('987654321'))) assert_equal(BigDecimal('12193263.1112635269'), BigDecimal('123.456789').vpmult(BigDecimal('98765.4321'))) - x = 123**456 - y = 987**123 - assert_equal(BigDecimal("#{x * y}e-300"), BigDecimal("#{x}e-100").vpmult(BigDecimal("#{y}e-200"))) + + [10, 50, INTEGER_MULT_THRESHOLD, INTEGER_MULT_THRESHOLD + 50].each do |prec| + x = (BASE + BASE / 7) ** prec + y = (BASE + BASE / 3) ** prec + assert_equal(BigDecimal("#{x * y}e-168"), BigDecimal("#{x}e-123").vpmult(BigDecimal("#{y}e-45"))) + assert_equal(BigDecimal("-#{x * y}e-46"), BigDecimal("#{x}e-12").vpmult(BigDecimal("-#{y}e-34"))) + assert_equal(BigDecimal("-#{x * y}e-333"), BigDecimal("-#{x}e12").vpmult(BigDecimal("#{y}e-345"))) + assert_equal(BigDecimal("#{x * y}e46"), BigDecimal("-#{x}e12").vpmult(BigDecimal("-#{y}e34"))) + end end def test_vpdivd @@ -84,6 +93,35 @@ def test_vpdivd_precisions end end + def test_vpdivd_by_ruby_integer + # Exponent check + integer_div_prec = INTEGER_DIV_THRESHOLD + 50 + x = (BASE + BASE / 7) ** integer_div_prec + y = (BASE + BASE / 3) ** integer_div_prec + bx = BigDecimal("#{x}e-123") + by = BigDecimal("#{y}e-456") + div, mod = bx.vpdivd(by, integer_div_prec) + assert_include(0...by, mod) + assert_equal(bx, div * by + mod) + + # Precision should consistent around DIV_BY_RB_INTEGER_THRESHOLD threshold + [2, 3, 4, integer_div_prec, integer_div_prec + 1, integer_div_prec + 2].each do |prec| + a = BigDecimal('1' + '2' * BASE_FIG * prec) + b = BigDecimal('9' * BASE_FIG + '9' * BASE_FIG * prec) + assert_equal(BASE_FIG * (prec - 2) + 1, a.vpdivd(b, prec).first.n_significant_digits) + assert_equal(BASE_FIG * prec, b.vpdivd(a, prec).first.n_significant_digits) + + x = BigDecimal("1#{'0' * BASE_FIG * prec}1") + y = BigDecimal("1#{'0' * BASE_FIG * prec}2") + div = BigDecimal("0.#{'9' * BASE_FIG * (prec - 1)}") + mod = BigDecimal("#{'9' * (BASE_FIG + 1)}.#{'0' * (BASE_FIG * (prec - 1) - 1)}2") + assert_equal([div, mod], x.vpdivd(y, prec)) + assert_equal([-div, mod], x.vpdivd(-y, prec)) + assert_equal([-div, -mod], (-x).vpdivd(y, prec)) + assert_equal([div, -mod], (-x).vpdivd(-y, prec)) + end + end + def test_vpdivd_borrow y_small = BASE / 7 * BASE ** 4 y_large = (4 * BASE_FIG).times.map {|i| i % 9 + 1 }.join.to_i