Skip to content

Commit 61820c4

Browse files
committed
Implement faster multiplication using Number Theoretic Transform
Performs ntt with three primes (29<<27|1, 26<<27|1, 24<<27|1)
1 parent 0aa97bb commit 61820c4

File tree

4 files changed

+244
-0
lines changed

4 files changed

+244
-0
lines changed

bigdecimal.gemspec

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ Gem::Specification.new do |s|
4646
ext/bigdecimal/feature.h
4747
ext/bigdecimal/missing.c
4848
ext/bigdecimal/missing.h
49+
ext/bigdecimal/ntt.h
4950
ext/bigdecimal/missing/dtoa.c
5051
ext/bigdecimal/static_assert.h
5152
]

ext/bigdecimal/bigdecimal.c

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@
3333

3434
#define BIGDECIMAL_VERSION "3.3.0"
3535

36+
/* NTT-based multiplication is compiled in only when a DECDIG is 4 bytes wide;
 * ntt.h works on uint32_t digit arrays. */
#if SIZEOF_DECDIG == 4
37+
#define USE_NTT_MULTIPLICATION 1
38+
#include "ntt.h"
39+
/* VpMult switches to ntt_multiply once b->Prec reaches this many DECDIG words. */
#define NTT_MULTIPLICATION_THRESHOLD 100
40+
#endif
41+
3642
/* #define ENABLE_NUMERIC_STRING */
3743

3844
#define SIGNED_VALUE_MAX INTPTR_MAX
@@ -3280,6 +3286,25 @@ BigDecimal_vpmult(VALUE self, VALUE v) {
32803286
RB_GC_GUARD(b.bigdecimal);
32813287
return c.bigdecimal;
32823288
}
3289+
3290+
#if SIZEOF_DECDIG == 4
3291+
VALUE
3292+
BigDecimal_nttmult(VALUE self, VALUE v) {
3293+
BDVALUE a,b,c;
3294+
a = GetBDValueMust(self);
3295+
b = GetBDValueMust(v);
3296+
c = NewZeroWrap(1, VPMULT_RESULT_PREC(a.real, b.real) * BASE_FIG);
3297+
ntt_multiply(a.real->Prec, b.real->Prec, a.real->frac, b.real->frac, c.real->frac);
3298+
VpSetSign(c.real, a.real->sign * b.real->sign);
3299+
c.real->exponent = a.real->exponent + b.real->exponent;
3300+
c.real->Prec = a.real->Prec + b.real->Prec;
3301+
VpNmlz(c.real);
3302+
RB_GC_GUARD(a.bigdecimal);
3303+
RB_GC_GUARD(b.bigdecimal);
3304+
return c.bigdecimal;
3305+
}
3306+
#endif
3307+
32833308
#endif /* BIGDECIMAL_USE_VP_TEST_METHODS */
32843309

32853310
/* Document-class: BigDecimal
@@ -3652,6 +3677,9 @@ Init_bigdecimal(void)
36523677
#ifdef BIGDECIMAL_USE_VP_TEST_METHODS
36533678
rb_define_method(rb_cBigDecimal, "vpdivd", BigDecimal_vpdivd, 2);
36543679
rb_define_method(rb_cBigDecimal, "vpmult", BigDecimal_vpmult, 1);
3680+
#ifdef USE_NTT_MULTIPLICATION
3681+
rb_define_method(rb_cBigDecimal, "nttmult", BigDecimal_nttmult, 1);
3682+
#endif
36553683
#endif /* BIGDECIMAL_USE_VP_TEST_METHODS */
36563684

36573685
#define ROUNDING_MODE(i, name, value) \
@@ -4934,6 +4962,15 @@ VpMult(Real *c, Real *a, Real *b)
49344962
c->exponent = a->exponent; /* set exponent */
49354963
VpSetSign(c, VpGetSign(a) * VpGetSign(b)); /* set sign */
49364964
if (!AddExponent(c, b->exponent)) return 0;
4965+
4966+
#ifdef USE_NTT_MULTIPLICATION
4967+
if (b->Prec >= NTT_MULTIPLICATION_THRESHOLD) {
4968+
ntt_multiply((uint32_t)a->Prec, (uint32_t)b->Prec, a->frac, b->frac, c->frac);
4969+
c->Prec = a->Prec + b->Prec;
4970+
goto Cleanup;
4971+
}
4972+
#endif
4973+
49374974
carry = 0;
49384975
nc = ind_c = MxIndAB;
49394976
memset(c->frac, 0, (nc + 1) * sizeof(DECDIG)); /* Initialize c */
@@ -4980,6 +5017,8 @@ VpMult(Real *c, Real *a, Real *b)
49805017
}
49815018
}
49825019
}
5020+
5021+
Cleanup:
49835022
VpNmlz(c);
49845023

49855024
Exit:

ext/bigdecimal/ntt.h

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
// NTT (Number Theoretic Transform) implementation for BigDecimal multiplication
2+
3+
// Passed as r_base to ntt() for all three moduli; the roots of unity are derived from it.
#define NTT_PRIMITIVE_ROOT 17
4+
// The three moduli have the form (BASE << SHIFT) | 1:
//   PRIME1 = 24 * 2^27 + 1 = 3221225473
//   PRIME2 = 26 * 2^27 + 1 = 3489660929
//   PRIME3 = 29 * 2^27 + 1 = 3892314113
#define NTT_PRIME_BASE1 24
5+
#define NTT_PRIME_BASE2 26
6+
#define NTT_PRIME_BASE3 29
7+
#define NTT_PRIME_SHIFT 27
8+
#define NTT_PRIME1 (((uint32_t)NTT_PRIME_BASE1 << NTT_PRIME_SHIFT) | 1)
9+
#define NTT_PRIME2 (((uint32_t)NTT_PRIME_BASE2 << NTT_PRIME_SHIFT) | 1)
10+
#define NTT_PRIME3 (((uint32_t)NTT_PRIME_BASE3 << NTT_PRIME_SHIFT) | 1)
11+
// Largest log2(transform size) accepted by ntt_multiply; larger requests raise ArgumentError.
#define MAX_NTT32_BITS 27
12+
// Radix of one digit word (10^9); used for base conversion and carry propagation.
#define NTT_DECDIG_BASE 1000000000
13+
14+
/*
 * Modular exponentiation by repeated squaring.
 *
 * Returns (base ** ex) % mod. `base` does not have to be reduced beforehand:
 * every product is formed in 64 bits before the reduction, so two 32-bit
 * factors can never overflow.
 *
 * The result is always a reduced residue, including the degenerate modulus 1
 * (where the answer is 0 even for ex == 0). `mod` must be non-zero whenever
 * ex is non-zero.
 */
static uint32_t
mod_pow(uint32_t base, uint32_t ex, uint32_t mod) {
    // Start from 1 reduced modulo `mod`; spelled without `%` so that the
    // (meaningless) call with mod == 0 and ex == 0 still does not divide.
    uint32_t res = (mod == 1) ? 0 : 1;

    // for (;;) instead of while (true): this header includes no <stdbool.h>.
    for (;;) {
        if (ex & 1) {
            res = (uint32_t)(((uint64_t)res * base) % mod);
        }
        ex >>= 1;
        if (!ex) break; // skip the final, unused squaring
        base = (uint32_t)(((uint64_t)base * base) % mod);
    }
    return res;
}
30+
31+
/*
 * Runs butterfly levels 0..depth of a (1 << size_bits)-point NTT.
 *
 * Level `depth` uses the twiddle base `r`; the call first recurses to level
 * depth - 1 with r * r and with `output` and `tmp` swapped, so consecutive
 * levels alternate between the two buffers. Level 0 reads straight from
 * `input`. The result of the outermost call ends up in `output`; `tmp` is
 * scratch space of the same size.
 */
static void
ntt_recursive(int size_bits, uint32_t *input, uint32_t *output, uint32_t *tmp, int depth, uint32_t r, uint32_t prime) {
    const uint32_t *src = input;
    if (depth > 0) {
        // The lower levels write into `tmp`, which becomes this level's source.
        ntt_recursive(size_bits, input, tmp, output, depth - 1, (uint32_t)(((uint64_t)r * r) % prime), prime);
        src = tmp;
    }

    const uint32_t half = (uint32_t)1 << (size_bits - 1);
    const uint32_t stride = (uint32_t)1 << (size_bits - depth - 1);
    const uint32_t groups = half / stride;

    // w_neg starts at prime - 1 (i.e. -1), so it always equals -w_pos mod prime.
    uint32_t w_pos = 1;
    uint32_t w_neg = prime - 1;

    uint32_t *lo = output;        // receives x + w * y
    uint32_t *hi = output + half; // receives x - w * y

    for (uint32_t g = 0; g < groups; g++) {
        const uint32_t *first = src + g * 2 * stride;
        const uint32_t *second = first + stride;
        for (uint32_t k = 0; k < stride; k++) {
            uint32_t x = first[k];
            uint32_t y = second[k];
            *lo++ = (uint32_t)((x + (uint64_t)w_pos * y) % prime);
            *hi++ = (uint32_t)((x + (uint64_t)w_neg * y) % prime);
        }
        w_pos = (uint32_t)(((uint64_t)w_pos * r) % prime);
        w_neg = (uint32_t)(((uint64_t)w_neg * r) % prime);
    }
}
55+
56+
/*
 * Number theoretic transform of `input` modulo the prime (base << shift | 1).
 *
 *   size_bits: log2 of the transform length; must not exceed `shift`
 *   input:     array of (1 << size_bits) values
 *   output:    receives the transformed array (same length)
 *   tmp:       scratch array of the same length
 *   r_base:    primitive root modulo the prime
 *   base, shift: describe the prime as (base << shift | 1)
 *   dir:       negative selects the inverse transform (inverse root, result
 *              scaled by 1/size); any other value selects the forward one
 */
static void
ntt(int size_bits, uint32_t *input, uint32_t *output, uint32_t *tmp, int r_base, int base, int shift, int dir) {
    const uint32_t prime = ((uint32_t)base << shift) | 1;
    const uint32_t length = (uint32_t)1 << size_bits;

    // root_max ** (1 << shift) % prime == 1
    // root ** length % prime == 1
    const uint32_t root_max = mod_pow(r_base, base, prime);
    uint32_t root = mod_pow(root_max, (uint32_t)1 << (shift - size_bits), prime);

    // x ** (prime - 2) is the modular inverse of x.
    if (dir < 0) root = mod_pow(root, prime - 2, prime);

    ntt_recursive(size_bits, input, output, tmp, size_bits - 1, root, prime);

    if (dir >= 0) return;

    // Inverse transform: divide every element by the transform length.
    const uint32_t length_inv = mod_pow(length, prime - 2, prime);
    for (uint32_t i = 0; i < length; i++) {
        output[i] = (uint32_t)(((uint64_t)output[i] * length_inv) % prime);
    }
}
81+
82+
/* Calculate c that satisfies: c % PRIME1 == mod1 && c % PRIME2 == mod2 && c % PRIME3 == mod3
83+
* c = (mod1 * 35002755423056150739595925972 + mod2 * 14584479687667766215746868453 + mod3 * 37919651490985126265126719818) % (PRIME1 * PRIME2 * PRIME3)
84+
* Assume c <= 999999999**2*(1<<27)
85+
*/
86+
static inline void
87+
mod_restore_prime_24_26_29_shift_27(uint32_t mod1, uint32_t mod2, uint32_t mod3, uint32_t *digits) {
88+
// Use mixed radix notation to eliminate modulo by PRIME1 * PRIME2 * PRIME3
89+
// [DIG0, DIG1, DIG2] = DIG0 + DIG1 * PRIME1 + DIG2 * PRIME1 * PRIME2
90+
// DIG0: 0...PRIME1, DIG1: 0...PRIME2, DIG2: 0...PRIME3
91+
// 35002755423056150739595925972 = [1, 3489660916, 3113851359]
92+
// 14584479687667766215746868453 = [0, 13, 1297437912]
93+
// 37919651490985126265126719818 = [0, 0, 3373338954]
94+
uint64_t c0 = mod1;
95+
uint64_t c1 = (uint64_t)mod2 * 13 + (uint64_t)mod1 * 3489660916;
96+
uint64_t c2 = (uint64_t)mod3 * 3373338954 % NTT_PRIME3 + (uint64_t)mod2 * 1297437912 % NTT_PRIME3 + (uint64_t)mod1 * 3113851359 % NTT_PRIME3;
97+
c2 += c1 / NTT_PRIME2;
98+
c1 %= NTT_PRIME2;
99+
c2 %= NTT_PRIME3;
100+
// Base conversion. c fits in 3 digits.
101+
c1 += c2 % NTT_DECDIG_BASE * NTT_PRIME2;
102+
c0 += c1 % NTT_DECDIG_BASE * NTT_PRIME1;
103+
c1 /= NTT_DECDIG_BASE;
104+
digits[0] = c0 % NTT_DECDIG_BASE;
105+
c0 /= NTT_DECDIG_BASE;
106+
c1 += c2 / NTT_DECDIG_BASE % NTT_DECDIG_BASE * NTT_PRIME2;
107+
c0 += c1 % NTT_DECDIG_BASE * NTT_PRIME1;
108+
c1 /= NTT_DECDIG_BASE;
109+
digits[1] = c0 % NTT_DECDIG_BASE;
110+
digits[2] = (uint32_t)(c0 / NTT_DECDIG_BASE + c1 % NTT_DECDIG_BASE * NTT_PRIME1);
111+
}
112+
113+
/*
114+
* NTT multiplication
115+
* Uses three NTTs with mod (24 << 27 | 1), (26 << 27 | 1), and (29 << 27 | 1)
116+
*/
117+
static void
118+
ntt_multiply(size_t a_size, size_t b_size, uint32_t *a, uint32_t *b, uint32_t *c) {
119+
if (a_size < b_size) {
120+
ntt_multiply(b_size, a_size, b, a, c);
121+
return;
122+
}
123+
124+
int b_bits = 0;
125+
while (((uint32_t)1 << b_bits) < (uint32_t)b_size) b_bits++;
126+
int ntt_size_bits = b_bits + 1;
127+
if (ntt_size_bits > MAX_NTT32_BITS) {
128+
rb_raise(rb_eArgError, "Multiply size too large");
129+
}
130+
131+
// To calculate large_a * small_b faster, split into several batches.
132+
uint32_t ntt_size = (uint32_t)1 << ntt_size_bits;
133+
uint32_t batch_size = ntt_size - (uint32_t)b_size;
134+
uint32_t batch_count = (uint32_t)((a_size + batch_size - 1) / batch_size);
135+
136+
uint32_t *mem = ruby_xcalloc(sizeof(uint32_t), ntt_size * 9);
137+
uint32_t *ntt1 = mem;
138+
uint32_t *ntt2 = mem + ntt_size;
139+
uint32_t *ntt3 = mem + ntt_size * 2;
140+
uint32_t *tmp1 = mem + ntt_size * 3;
141+
uint32_t *tmp2 = mem + ntt_size * 4;
142+
uint32_t *tmp3 = mem + ntt_size * 5;
143+
uint32_t *conv1 = mem + ntt_size * 6;
144+
uint32_t *conv2 = mem + ntt_size * 7;
145+
uint32_t *conv3 = mem + ntt_size * 8;
146+
147+
// Calculate NTT for b in three primes. Result is reused for each batch of a.
148+
memcpy(tmp1, b, b_size * sizeof(uint32_t));
149+
memset(tmp1 + b_size, 0, (ntt_size - b_size) * sizeof(uint32_t));
150+
ntt(ntt_size_bits, tmp1, ntt1, tmp2, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE1, NTT_PRIME_SHIFT, +1);
151+
ntt(ntt_size_bits, tmp1, ntt2, tmp2, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE2, NTT_PRIME_SHIFT, +1);
152+
ntt(ntt_size_bits, tmp1, ntt3, tmp2, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE3, NTT_PRIME_SHIFT, +1);
153+
154+
memset(c, 0, (a_size + b_size) * sizeof(uint32_t));
155+
for (uint32_t idx = 0; idx < batch_count; idx++) {
156+
uint32_t len = idx == batch_count - 1 ? (uint32_t)a_size - idx * batch_size : batch_size;
157+
memcpy(tmp1, a + idx * batch_size, len * sizeof(uint32_t));
158+
memset(tmp1 + len, 0, (ntt_size - len) * sizeof(uint32_t));
159+
// Calculate convolution for this batch in three primes
160+
ntt(ntt_size_bits, tmp1, tmp2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE1, NTT_PRIME_SHIFT, +1);
161+
for (uint32_t i = 0; i < ntt_size; i++) tmp2[i] = ((uint64_t)tmp2[i] * ntt1[i]) % NTT_PRIME1;
162+
ntt(ntt_size_bits, tmp2, conv1, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE1, NTT_PRIME_SHIFT, -1);
163+
ntt(ntt_size_bits, tmp1, tmp2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE2, NTT_PRIME_SHIFT, +1);
164+
for (uint32_t i = 0; i < ntt_size; i++) tmp2[i] = ((uint64_t)tmp2[i] * ntt2[i]) % NTT_PRIME2;
165+
ntt(ntt_size_bits, tmp2, conv2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE2, NTT_PRIME_SHIFT, -1);
166+
ntt(ntt_size_bits, tmp1, tmp2, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE3, NTT_PRIME_SHIFT, +1);
167+
for (uint32_t i = 0; i < ntt_size; i++) tmp2[i] = ((uint64_t)tmp2[i] * ntt3[i]) % NTT_PRIME3;
168+
ntt(ntt_size_bits, tmp2, conv3, tmp3, NTT_PRIMITIVE_ROOT, NTT_PRIME_BASE3, NTT_PRIME_SHIFT, -1);
169+
170+
// Restore the original convolution value from three convolutions calculated in three primes.
171+
// Each convolution value is maximum 999999999**2*(1<<27)/2
172+
for (uint32_t i = 0; i < ntt_size; i++) {
173+
uint32_t dig[3];
174+
mod_restore_prime_24_26_29_shift_27(conv1[i], conv2[i], conv3[i], dig);
175+
// Maximum values of dig[0], dig[1], and dig[2] are 999999999, 999999999 and 67108863 respectively
176+
// Maximum overlapped sum (considering overlaps between 2 batches) is less than 4134217722
177+
// so this sum doesn't overflow uint32_t.
178+
for (int j = 0; j < 3; j++) {
179+
// Index check: if dig[j] is non-zero, assign index is within valid range.
180+
if (dig[j]) c[idx * batch_size + i + 1 - j] += dig[j];
181+
}
182+
}
183+
}
184+
uint32_t carry = 0;
185+
for (int32_t i = (uint32_t)(a_size + b_size - 1); i >= 0; i--) {
186+
uint32_t v = c[i] + carry;
187+
c[i] = v % NTT_DECDIG_BASE;
188+
carry = v / NTT_DECDIG_BASE;
189+
}
190+
ruby_xfree(mem);
191+
}

test/bigdecimal/test_vp_operation.rb

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ def setup
1313
end
1414
end
1515

16+
# Whether the nttmult test method can be exercised. test_nttmult is omitted
# with "only available for 32-bit DECDIG" when this is false.
def ntt_mult_available?
17+
BASE_FIG == 9
18+
end
19+
1620
def test_vpmult
1721
assert_equal(BigDecimal('121932631112635269'), BigDecimal('123456789').vpmult(BigDecimal('987654321')))
1822
assert_equal(BigDecimal('12193263.1112635269'), BigDecimal('123.456789').vpmult(BigDecimal('98765.4321')))
@@ -21,6 +25,15 @@ def test_vpmult
2125
assert_equal(BigDecimal("#{x * y}e-300"), BigDecimal("#{x}e-100").vpmult(BigDecimal("#{y}e-200")))
2226
end
2327

28+
# Checks nttmult against Integer multiplication for every pair of operand
# lengths from 1 to 32 digit words (BASE_FIG decimal digits per word).
def test_nttmult
  unless ntt_mult_available?
    omit 'NTT multiplication is only available for 32-bit DECDIG'
  end

  word_counts = (1..32).to_a
  word_counts.product(word_counts).each do |a_words, b_words|
    x = BigDecimal(10**(BASE_FIG * a_words) / 7)
    y = BigDecimal(10**(BASE_FIG * b_words) / 13)
    expected = x.to_i * y.to_i
    assert_equal(expected, x.nttmult(y))
  end
end
36+
2437
def test_vpdivd
2538
# a[0] > b[0]
2639
# XXXX_YYYY_ZZZZ / 1111 #=> 000X_000Y_000Z

0 commit comments

Comments
 (0)