diff --git a/doc/source/gr_poly.rst b/doc/source/gr_poly.rst index aa820dab16..25f3b0cccd 100644 --- a/doc/source/gr_poly.rst +++ b/doc/source/gr_poly.rst @@ -170,6 +170,19 @@ Arithmetic algorithm with `O(n^{1.6})` complexity, the ring must overload :func:`_gr_poly_mul` to dispatch to :func:`_gr_poly_mul_karatsuba` above some cutoff. +.. function:: int _gr_poly_mul_toom33(gr_ptr res, gr_srcptr poly1, slong len1, gr_srcptr poly2, slong len2, gr_ctx_t ctx); + int gr_poly_mul_toom33(gr_poly_t res, const gr_poly_t poly1, const gr_poly_t poly2, gr_ctx_t ctx); + + Balanced Toom-3 multiplication with interpolation in five points, + using the Bodrato evaluation scheme. Assumes commutativity and that the ring + supports exact division by 2 and 3. + Not optimized for squaring. + The underscore method requires positive lengths and does not support aliasing. + This function calls :func:`_gr_poly_mul` recursively rather than itself, so to get a recursive + algorithm with `O(n^{1.5})` complexity, the ring must overload :func:`_gr_poly_mul` to dispatch + to :func:`_gr_poly_mul_toom33` above some cutoff. + + Powering -------------------------------------------------------------------------------- diff --git a/src/gr_poly.h b/src/gr_poly.h index 6ba945a21c..9ce700b3fd 100644 --- a/src/gr_poly.h +++ b/src/gr_poly.h @@ -124,7 +124,8 @@ WARN_UNUSED_RESULT int gr_poly_mul_scalar(gr_poly_t res, const gr_poly_t poly, g WARN_UNUSED_RESULT int _gr_poly_mul_karatsuba(gr_ptr res, gr_srcptr poly1, slong len1, gr_srcptr poly2, slong len2, gr_ctx_t ctx); WARN_UNUSED_RESULT int gr_poly_mul_karatsuba(gr_poly_t res, const gr_poly_t poly1, const gr_poly_t poly2, gr_ctx_t ctx); - +WARN_UNUSED_RESULT int _gr_poly_mul_toom33(gr_ptr res, gr_srcptr poly1, slong len1, gr_srcptr poly2, slong len2, gr_ctx_t ctx); +WARN_UNUSED_RESULT int gr_poly_mul_toom33(gr_poly_t res, const gr_poly_t poly1, const gr_poly_t poly2, gr_ctx_t ctx); /* powering */ diff --git a/src/gr_poly/mul_toom33.c b/src/gr_poly/mul_toom33.c new file mode 100644 index 0000000000..3be7c20ffb --- /dev/null +++ b/src/gr_poly/mul_toom33.c @@ -0,0 +1,207 @@ +/* + Copyright (C) 2007 Marco Bodrato + Copyright (C) 2024 Fredrik Johansson + + This file is part of FLINT. + + FLINT is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 3 of the License, or + (at your option) any later version. See . +*/ + +#include "gr_vec.h" +#include "gr_poly.h" + +/* + Toom33 (interpolation in 5 points) using Bodrato scheme + http://marco.bodrato.it/papers/Bodrato2007-OptimalToomCookMultiplicationForBinaryFieldAndIntegers.pdf + + Assumes commutativity, division by 3. + Todo: squaring version. + Todo: skip unnecessary zero-extensions of vectors and tighten + allocations. +*/ +int +_gr_poly_mul_toom33(gr_ptr res, gr_srcptr f, slong flen, gr_srcptr g, slong glen, gr_ctx_t ctx) +{ + gr_srcptr U0, U1, U2, V0, V1, V2; + gr_ptr tmp, W0, W1, W2, W3, W4; + slong m, U2len, V2len, U1len, V1len, U0len, V0len, rlen, len; + slong W4len; + slong sz = ctx->sizeof_elem; + slong alloc; + int status = GR_SUCCESS; + + /* TODO: should explicitly call basecase mul. */ + if (flen <= 1 || glen <= 1) + return _gr_poly_mullow_generic(res, f, flen, g, glen, flen + glen - 1, ctx); + + /* U = U2*x^(2m) + U1*x^m + U0 */ + /* V = V2*x^(2m) + V1*x^m + V0 */ + /* Each block has length m */ + m = FLINT_MAX(flen, glen); + m = (m + 3 - 1) / 3; + U0 = f; + U1 = GR_ENTRY(f, m, sz); + U2 = GR_ENTRY(f, 2 * m, sz); + V0 = g; + V1 = GR_ENTRY(g, m, sz); + V2 = GR_ENTRY(g, 2 * m, sz); + + U2len = FLINT_MAX(flen - 2 * m, 0); + V2len = FLINT_MAX(glen - 2 * m, 0); + U1len = FLINT_MIN(FLINT_MAX(flen - m, 0), m); + V1len = FLINT_MIN(FLINT_MAX(glen - m, 0), m); + U0len = FLINT_MIN(flen, m); + V0len = FLINT_MIN(glen, m); + + alloc = 10 * m; + GR_TMP_INIT_VEC(tmp, alloc, ctx); + W0 = tmp; + W1 = GR_ENTRY(W0, 2 * m, sz); + W2 = GR_ENTRY(W1, 2 * m, sz); + W3 = GR_ENTRY(W2, 2 * m, sz); + W4 = GR_ENTRY(W3, 2 * m, sz); + + /* Evaluation: 5*2 add, 2 shift; 5mul */ + /* W0 = U2 + U0 */ + /* if max(U2len,U0len) < m, assumes top coefficients are already zeroed from the initialization */ + status |= _gr_poly_add(W0, U2, U2len, U0, U0len, ctx); + /* W4 = V2 + V0 */ + /* if max(V2len,V0len) < m, assumes top coefficients are already zeroed from the initialization */ + status |= _gr_poly_add(W4, V2, V2len, V0, V0len, ctx); + /* W2 = W0 - U1 */ + status |= _gr_poly_sub(W2, W0, m, U1, U1len, ctx); + /* W1 = W4 - V1 */ + status |= _gr_poly_sub(W1, W4, m, V1, V1len, ctx); + /* W0 = W0 + U1 */ + status |= _gr_poly_add(W0, W0, m, U1, U1len, ctx); + /* W4 = W4 + V1 */ + status |= _gr_poly_add(W4, W4, m, V1, V1len, ctx); + /* W3 = W2 * W1 */ + status |= _gr_poly_mul(W3, W2, m, W1, m, ctx); + /* W1 = W0 * W4 */ + status |= _gr_poly_mul(W1, W0, m, W4, m, ctx); + /* W0 = ((W0 + U2) << 1) - U0 */ + status |= _gr_poly_add(W0, W0, m, U2, U2len, ctx); + status |= _gr_vec_mul_scalar_2exp_si(W0, W0, m, 1, ctx); + status |= _gr_poly_sub(W0, W0, m, U0, U0len, ctx); + /* W4 = ((W4 + V2) << 1) - V0 */ + status |= _gr_poly_add(W4, W4, m, V2, V2len, ctx); + status |= _gr_vec_mul_scalar_2exp_si(W4, W4, m, 1, ctx); + status |= _gr_poly_sub(W4, W4, m, V0, V0len, ctx); + /* W2 = W0 * W4 */ + status |= _gr_poly_mul(W2, W0, m, W4, m, ctx); + /* W0 = U0 * V0 */ + if (U0len > 0 && V0len > 0) + { + status |= _gr_poly_mul(W0, U0, U0len, V0, V0len, ctx); + status |= _gr_vec_zero(GR_ENTRY(W0, U0len + V0len - 1, sz), 2 * m - (U0len + V0len - 1), ctx); + } + else + status |= _gr_vec_zero(W0, 2 * m, ctx); + /* W4 = U2 * V2 */ + /* We compute this length accurately instead of zero-extending. */ + if (U2len > 0 && V2len > 0) + { + W4len = U2len + V2len - 1; + status |= _gr_poly_mul(W4, U2, U2len, V2, V2len, ctx); + } + else + { + W4len = 0; + } + + /* toom42 variant */ + /* U = U3*x^(3m) + U2*x^(2m) + U1*x^m + U0 */ + /* V = V1*x^m + V0 */ + /* Evaluation: 7+3 add, 3 shift; 5mul */ + /* + W0 = U1 + U3; + W4 = U0 + U2; + W3 = W4 + W0; + W4 = W4 - W0; + W0 = V0 + V1; + W2 = V0 - V1; + W1 = W3 * W0; + W3 = W4 * W2; + W4 = (((((U3<<1) + U2) << 1) + U1) << 1) + U0; + W0 = W0 + V1; + W2 = W4 * W0; + W0 = U0 * V0; + W4 = U3 * V1; + */ + + /* Interpolation: 8 add, 3 shift, 1 Sdiv */ + len = 2 * m - 1; + /* W2 = (W2 - W3) / 3 */ + status |= _gr_vec_sub(W2, W2, W3, len, ctx); + status |= _gr_vec_divexact_scalar_ui(W2, W2, len, 3, ctx); + /* W3 = (W1 - W3) >> 1 */ + status |= _gr_vec_sub(W3, W1, W3, len, ctx); + status |= _gr_vec_mul_scalar_2exp_si(W3, W3, len, -1, ctx); + /* W1 = W1 - W0 */ + status |= _gr_vec_sub(W1, W1, W0, len, ctx); + /* W2 = ((W2 - W1) >> 1) - (W4 << 1) */ + status |= _gr_vec_sub(W2, W2, W1, len, ctx); + status |= _gr_vec_mul_scalar_2exp_si(W2, W2, len, -1, ctx); + status |= _gr_vec_mul_scalar_2exp_si(res, W4, W4len, 1, ctx); + status |= _gr_vec_sub(W2, W2, res, W4len, ctx); + /* W1 = W1 - W3 - W4 */ + status |= _gr_vec_sub(W1, W1, W3, len, ctx); + status |= _gr_poly_sub(W1, W1, len, W4, W4len, ctx); + /* W3 = W3 - W2 */ + status |= _gr_vec_sub(W3, W3, W2, len, ctx); + + /* Recomposition: */ + /* W = W4 * x^(4m) + W2*x^(3m) + W1*x^(2m) + W*x^m + W0 */ + + rlen = flen + glen - 1; + len = FLINT_MIN(rlen, m); + status |= _gr_vec_set(res, W0, FLINT_MIN(rlen, m), ctx); + len = FLINT_MIN(rlen - m, m); + status |= _gr_vec_add(GR_ENTRY(res, m, sz), W3, GR_ENTRY(W0, m, sz), len, ctx); + len = FLINT_MIN(rlen - 2 * m, m); + status |= _gr_vec_add(GR_ENTRY(res, 2 * m, sz), W1, GR_ENTRY(W3, m, sz), len, ctx); + len = FLINT_MIN(rlen - 3 * m, m); + status |= _gr_vec_add(GR_ENTRY(res, 3 * m, sz), W2, GR_ENTRY(W1, m, sz), len, ctx); + len = FLINT_MIN(rlen - 4 * m, m); + status |= _gr_poly_add(GR_ENTRY(res, 4 * m, sz), W4, FLINT_MIN(W4len, len), GR_ENTRY(W2, m, sz), len, ctx); + len = FLINT_MIN(rlen - 5 * m, m); + status |= _gr_vec_set(GR_ENTRY(res, 5 * m, sz), GR_ENTRY(W4, m, sz), len, ctx); + + GR_TMP_CLEAR_VEC(tmp, alloc, ctx); + + return status; +} + +int +gr_poly_mul_toom33(gr_poly_t res, const gr_poly_t poly1, const gr_poly_t poly2, gr_ctx_t ctx) +{ + slong len_out; + int status; + + if (poly1->length == 0 || poly2->length == 0) + return gr_poly_zero(res, ctx); + + len_out = poly1->length + poly2->length - 1; + + if (res == poly1 || res == poly2) + { + gr_poly_t t; + gr_poly_init2(t, len_out, ctx); + status = _gr_poly_mul_toom33(t->coeffs, poly1->coeffs, poly1->length, poly2->coeffs, poly2->length, ctx); + gr_poly_swap(res, t, ctx); + gr_poly_clear(t, ctx); + } + else + { + gr_poly_fit_length(res, len_out, ctx); + status = _gr_poly_mul_toom33(res->coeffs, poly1->coeffs, poly1->length, poly2->coeffs, poly2->length, ctx); + } + + _gr_poly_set_length(res, len_out, ctx); + _gr_poly_normalise(res, ctx); + return status; +} diff --git a/src/gr_poly/test/main.c b/src/gr_poly/test/main.c index 62016ec560..04af841384 100644 --- a/src/gr_poly/test/main.c +++ b/src/gr_poly/test/main.c @@ -44,6 +44,7 @@ #include "t-log_series.c" #include "t-make_monic.c" #include "t-mul_karatsuba.c" +#include "t-mul_toom33.c" #include "t-nth_derivative.c" #include "t-pow_series_fmpq.c" #include "t-pow_series_ui.c" @@ -106,6 +107,7 @@ test_struct tests[] = TEST_FUNCTION(gr_poly_log_series), TEST_FUNCTION(gr_poly_make_monic), TEST_FUNCTION(gr_poly_mul_karatsuba), + TEST_FUNCTION(gr_poly_mul_toom33), TEST_FUNCTION(gr_poly_nth_derivative), TEST_FUNCTION(gr_poly_pow_series_fmpq), TEST_FUNCTION(gr_poly_pow_series_ui), diff --git a/src/gr_poly/test/t-mul_toom33.c b/src/gr_poly/test/t-mul_toom33.c new file mode 100644 index 0000000000..16439ac6b7 --- /dev/null +++ b/src/gr_poly/test/t-mul_toom33.c @@ -0,0 +1,106 @@ +/* + Copyright (C) 2023 Fredrik Johansson + + This file is part of FLINT. + + FLINT is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 3 of the License, or + (at your option) any later version. See . +*/ + +#include "test_helpers.h" +#include "ulong_extras.h" +#include "gr_poly.h" + +FLINT_DLL extern gr_static_method_table _ca_methods; + +int +test_mul1(flint_rand_t state, int which) +{ + gr_ctx_t ctx; + slong n; + gr_poly_t A, B, C, D; + int status = GR_SUCCESS; + + gr_ctx_init_random(ctx, state); + + gr_poly_init(A, ctx); + gr_poly_init(B, ctx); + gr_poly_init(C, ctx); + gr_poly_init(D, ctx); + + if (ctx->methods == _ca_methods) + n = 2; + else if (gr_ctx_is_finite(ctx) == T_TRUE) + n = 30; + else + n = 10; + + GR_MUST_SUCCEED(gr_poly_randtest(A, state, 1 + n_randint(state, n), ctx)); + GR_MUST_SUCCEED(gr_poly_randtest(B, state, 1 + n_randint(state, n), ctx)); + GR_MUST_SUCCEED(gr_poly_randtest(C, state, 1 + n_randint(state, n), ctx)); + + switch (which) + { + case 0: + status |= gr_poly_mul_toom33(C, A, B, ctx); + break; + case 1: + status |= gr_poly_set(C, A, ctx); + status |= gr_poly_mul_toom33(C, C, B, ctx); + break; + case 2: + status |= gr_poly_set(C, B, ctx); + status |= gr_poly_mul_toom33(C, A, C, ctx); + break; + case 3: + status |= gr_poly_set(B, A, ctx); + status |= gr_poly_mul_toom33(C, A, A, ctx); + break; + case 4: + status |= gr_poly_set(B, A, ctx); + status |= gr_poly_set(C, A, ctx); + status |= gr_poly_mul_toom33(C, C, C, ctx); + break; + + default: + flint_abort(); + } + + /* todo: should explicitly call basecase mul */ + status |= gr_poly_mullow(D, A, B, FLINT_MAX(0, A->length + B->length - 1), ctx); + + if (status == GR_SUCCESS && gr_poly_equal(C, D, ctx) == T_FALSE) + { + flint_printf("FAIL\n\n"); + flint_printf("which = %d, n = %wd\n\n", which, n); + gr_ctx_println(ctx); + flint_printf("A = "); gr_poly_print(A, ctx); flint_printf("\n\n"); + flint_printf("B = "); gr_poly_print(B, ctx); flint_printf("\n\n"); + flint_printf("C = "); gr_poly_print(C, ctx); flint_printf("\n\n"); + flint_printf("D = "); gr_poly_print(D, ctx); flint_printf("\n\n"); + flint_abort(); + } + + gr_poly_clear(A, ctx); + gr_poly_clear(B, ctx); + gr_poly_clear(C, ctx); + gr_poly_clear(D, ctx); + + gr_ctx_clear(ctx); + + return status; +} + +TEST_FUNCTION_START(gr_poly_mul_toom33, state) +{ + slong iter; + + for (iter = 0; iter < 1000; iter++) + { + test_mul1(state, n_randint(state, 5)); + } + + TEST_FUNCTION_END(state); +}