Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic Toom-3 multiplication for gr_poly #2071

Merged
merged 4 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions doc/source/gr_poly.rst
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,19 @@ Arithmetic
algorithm with `O(n^{1.6})` complexity, the ring must overload :func:`_gr_poly_mul` to dispatch
to :func:`_gr_poly_mul_karatsuba` above some cutoff.

.. function:: int _gr_poly_mul_toom33(gr_ptr res, gr_srcptr poly1, slong len1, gr_srcptr poly2, slong len2, gr_ctx_t ctx);
int gr_poly_mul_toom33(gr_poly_t res, const gr_poly_t poly1, const gr_poly_t poly2, gr_ctx_t ctx);

Balanced Toom-3 multiplication with interpolation in five points,
using the Bodrato evaluation scheme. Assumes commutativity and that the ring
supports exact division by 2 and 3.
Not optimized for squaring.
The underscore method requires positive lengths and does not support aliasing.
This function calls :func:`_gr_poly_mul` recursively rather than itself, so to get a recursive
algorithm with `O(n^{1.5})` complexity, the ring must overload :func:`_gr_poly_mul` to dispatch
to :func:`_gr_poly_mul_toom33` above some cutoff.


Powering
--------------------------------------------------------------------------------

Expand Down
3 changes: 2 additions & 1 deletion src/gr_poly.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ WARN_UNUSED_RESULT int gr_poly_mul_scalar(gr_poly_t res, const gr_poly_t poly, g

WARN_UNUSED_RESULT int _gr_poly_mul_karatsuba(gr_ptr res, gr_srcptr poly1, slong len1, gr_srcptr poly2, slong len2, gr_ctx_t ctx);
WARN_UNUSED_RESULT int gr_poly_mul_karatsuba(gr_poly_t res, const gr_poly_t poly1, const gr_poly_t poly2, gr_ctx_t ctx);

WARN_UNUSED_RESULT int _gr_poly_mul_toom33(gr_ptr res, gr_srcptr poly1, slong len1, gr_srcptr poly2, slong len2, gr_ctx_t ctx);
WARN_UNUSED_RESULT int gr_poly_mul_toom33(gr_poly_t res, const gr_poly_t poly1, const gr_poly_t poly2, gr_ctx_t ctx);

/* powering */

Expand Down
207 changes: 207 additions & 0 deletions src/gr_poly/mul_toom33.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
/*
Copyright (C) 2007 Marco Bodrato
Copyright (C) 2024 Fredrik Johansson

This file is part of FLINT.

FLINT is free software: you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License (LGPL) as published
by the Free Software Foundation; either version 3 of the License, or
(at your option) any later version. See <https://www.gnu.org/licenses/>.
*/

#include "gr_vec.h"
#include "gr_poly.h"

/*
Toom33 (interpolation in 5 points) using Bodrato scheme
http://marco.bodrato.it/papers/Bodrato2007-OptimalToomCookMultiplicationForBinaryFieldAndIntegers.pdf

Assumes commutativity, division by 3.
Todo: squaring version.
Todo: skip unnecessary zero-extensions of vectors and tighten
allocations.
*/
int
_gr_poly_mul_toom33(gr_ptr res, gr_srcptr f, slong flen, gr_srcptr g, slong glen, gr_ctx_t ctx)
{
gr_srcptr U0, U1, U2, V0, V1, V2;
gr_ptr tmp, W0, W1, W2, W3, W4;
slong m, U2len, V2len, U1len, V1len, U0len, V0len, rlen, len;
slong W4len;
slong sz = ctx->sizeof_elem;
slong alloc;
int status = GR_SUCCESS;

/* TODO: should explicitly call basecase mul. */
if (flen <= 1 || glen <= 1)
return _gr_poly_mullow_generic(res, f, flen, g, glen, flen + glen - 1, ctx);

/* U = U2*x^(2m) + U1*x^m + U0 */
/* V = V2*x^(2m) + V1*x^m + V0 */
/* Each block has length m */
m = FLINT_MAX(flen, glen);
m = (m + 3 - 1) / 3;
U0 = f;
U1 = GR_ENTRY(f, m, sz);
U2 = GR_ENTRY(f, 2 * m, sz);
V0 = g;
V1 = GR_ENTRY(g, m, sz);
V2 = GR_ENTRY(g, 2 * m, sz);

U2len = FLINT_MAX(flen - 2 * m, 0);
V2len = FLINT_MAX(glen - 2 * m, 0);
U1len = FLINT_MIN(FLINT_MAX(flen - m, 0), m);
V1len = FLINT_MIN(FLINT_MAX(glen - m, 0), m);
U0len = FLINT_MIN(flen, m);
V0len = FLINT_MIN(glen, m);

alloc = 10 * m;
GR_TMP_INIT_VEC(tmp, alloc, ctx);
W0 = tmp;
W1 = GR_ENTRY(W0, 2 * m, sz);
W2 = GR_ENTRY(W1, 2 * m, sz);
W3 = GR_ENTRY(W2, 2 * m, sz);
W4 = GR_ENTRY(W3, 2 * m, sz);

/* Evaluation: 5*2 add, 2 shift; 5mul */
/* W0 = U2 + U0 */
/* if max(U2len,U0len) < m, assumes top coefficients are already zeroed from the initialization */
status |= _gr_poly_add(W0, U2, U2len, U0, U0len, ctx);
/* W4 = V2 + V0 */
/* if max(V2len,V0len) < m, assumes top coefficients are already zeroed from the initialization */
status |= _gr_poly_add(W4, V2, V2len, V0, V0len, ctx);
/* W2 = W0 - U1 */
status |= _gr_poly_sub(W2, W0, m, U1, U1len, ctx);
/* W1 = W4 - V1 */
status |= _gr_poly_sub(W1, W4, m, V1, V1len, ctx);
/* W0 = W0 + U1 */
status |= _gr_poly_add(W0, W0, m, U1, U1len, ctx);
/* W4 = W4 + V1 */
status |= _gr_poly_add(W4, W4, m, V1, V1len, ctx);
/* W3 = W2 * W1 */
status |= _gr_poly_mul(W3, W2, m, W1, m, ctx);
/* W1 = W0 * W4 */
status |= _gr_poly_mul(W1, W0, m, W4, m, ctx);
/* W0 = ((W0 + U2) << 1) - U0 */
status |= _gr_poly_add(W0, W0, m, U2, U2len, ctx);
status |= _gr_vec_mul_scalar_2exp_si(W0, W0, m, 1, ctx);
status |= _gr_poly_sub(W0, W0, m, U0, U0len, ctx);
/* W4 = ((W4 + V2) << 1) - V0 */
status |= _gr_poly_add(W4, W4, m, V2, V2len, ctx);
status |= _gr_vec_mul_scalar_2exp_si(W4, W4, m, 1, ctx);
status |= _gr_poly_sub(W4, W4, m, V0, V0len, ctx);
/* W2 = W0 * W4 */
status |= _gr_poly_mul(W2, W0, m, W4, m, ctx);
/* W0 = U0 * V0 */
if (U0len > 0 && V0len > 0)
{
status |= _gr_poly_mul(W0, U0, U0len, V0, V0len, ctx);
status |= _gr_vec_zero(GR_ENTRY(W0, U0len + V0len - 1, sz), 2 * m - (U0len + V0len - 1), ctx);
}
else
status |= _gr_vec_zero(W0, 2 * m, ctx);
/* W4 = U2 * V2 */
/* We compute this length accurately instead of zero-extending. */
if (U2len > 0 && V2len > 0)
{
W4len = U2len + V2len - 1;
status |= _gr_poly_mul(W4, U2, U2len, V2, V2len, ctx);
}
else
{
W4len = 0;
}

/* toom42 variant */
/* U = U3*x^(3m) + U2*x^(2m) + U1*x^m + U0 */
/* V = V1*x^m + V0 */
/* Evaluation: 7+3 add, 3 shift; 5mul */
/*
W0 = U1 + U3;
W4 = U0 + U2;
W3 = W4 + W0;
W4 = W4 - W0;
W0 = V0 + V1;
W2 = V0 - V1;
W1 = W3 * W0;
W3 = W4 * W2;
W4 = (((((U3<<1) + U2) << 1) + U1) << 1) + U0;
W0 = W0 + V1;
W2 = W4 * W0;
W0 = U0 * V0;
W4 = U3 * V1;
*/

/* Interpolation: 8 add, 3 shift, 1 Sdiv */
len = 2 * m - 1;
/* W2 = (W2 - W3) / 3 */
status |= _gr_vec_sub(W2, W2, W3, len, ctx);
status |= _gr_vec_divexact_scalar_ui(W2, W2, len, 3, ctx);
/* W3 = (W1 - W3) >> 1 */
status |= _gr_vec_sub(W3, W1, W3, len, ctx);
status |= _gr_vec_mul_scalar_2exp_si(W3, W3, len, -1, ctx);
/* W1 = W1 - W0 */
status |= _gr_vec_sub(W1, W1, W0, len, ctx);
/* W2 = ((W2 - W1) >> 1) - (W4 << 1) */
status |= _gr_vec_sub(W2, W2, W1, len, ctx);
status |= _gr_vec_mul_scalar_2exp_si(W2, W2, len, -1, ctx);
status |= _gr_vec_mul_scalar_2exp_si(res, W4, W4len, 1, ctx);
status |= _gr_vec_sub(W2, W2, res, W4len, ctx);
/* W1 = W1 - W3 - W4 */
status |= _gr_vec_sub(W1, W1, W3, len, ctx);
status |= _gr_poly_sub(W1, W1, len, W4, W4len, ctx);
/* W3 = W3 - W2 */
status |= _gr_vec_sub(W3, W3, W2, len, ctx);

/* Recomposition: */
/* W = W4 * x^(4m) + W2*x^(3m) + W1*x^(2m) + W*x^m + W0 */

rlen = flen + glen - 1;
len = FLINT_MIN(rlen, m);
status |= _gr_vec_set(res, W0, FLINT_MIN(rlen, m), ctx);
len = FLINT_MIN(rlen - m, m);
status |= _gr_vec_add(GR_ENTRY(res, m, sz), W3, GR_ENTRY(W0, m, sz), len, ctx);
len = FLINT_MIN(rlen - 2 * m, m);
status |= _gr_vec_add(GR_ENTRY(res, 2 * m, sz), W1, GR_ENTRY(W3, m, sz), len, ctx);
len = FLINT_MIN(rlen - 3 * m, m);
status |= _gr_vec_add(GR_ENTRY(res, 3 * m, sz), W2, GR_ENTRY(W1, m, sz), len, ctx);
len = FLINT_MIN(rlen - 4 * m, m);
status |= _gr_poly_add(GR_ENTRY(res, 4 * m, sz), W4, FLINT_MIN(W4len, len), GR_ENTRY(W2, m, sz), len, ctx);
len = FLINT_MIN(rlen - 5 * m, m);
status |= _gr_vec_set(GR_ENTRY(res, 5 * m, sz), GR_ENTRY(W4, m, sz), len, ctx);

GR_TMP_CLEAR_VEC(tmp, alloc, ctx);

return status;
}

int
gr_poly_mul_toom33(gr_poly_t res, const gr_poly_t poly1, const gr_poly_t poly2, gr_ctx_t ctx)
{
slong len_out;
int status;

if (poly1->length == 0 || poly2->length == 0)
return gr_poly_zero(res, ctx);

len_out = poly1->length + poly2->length - 1;

if (res == poly1 || res == poly2)
{
gr_poly_t t;
gr_poly_init2(t, len_out, ctx);
status = _gr_poly_mul_toom33(t->coeffs, poly1->coeffs, poly1->length, poly2->coeffs, poly2->length, ctx);
gr_poly_swap(res, t, ctx);
gr_poly_clear(t, ctx);
}
else
{
gr_poly_fit_length(res, len_out, ctx);
status = _gr_poly_mul_toom33(res->coeffs, poly1->coeffs, poly1->length, poly2->coeffs, poly2->length, ctx);
}

_gr_poly_set_length(res, len_out, ctx);
_gr_poly_normalise(res, ctx);
return status;
}
2 changes: 2 additions & 0 deletions src/gr_poly/test/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "t-log_series.c"
#include "t-make_monic.c"
#include "t-mul_karatsuba.c"
#include "t-mul_toom33.c"
#include "t-nth_derivative.c"
#include "t-pow_series_fmpq.c"
#include "t-pow_series_ui.c"
Expand Down Expand Up @@ -106,6 +107,7 @@ test_struct tests[] =
TEST_FUNCTION(gr_poly_log_series),
TEST_FUNCTION(gr_poly_make_monic),
TEST_FUNCTION(gr_poly_mul_karatsuba),
TEST_FUNCTION(gr_poly_mul_toom33),
TEST_FUNCTION(gr_poly_nth_derivative),
TEST_FUNCTION(gr_poly_pow_series_fmpq),
TEST_FUNCTION(gr_poly_pow_series_ui),
Expand Down
106 changes: 106 additions & 0 deletions src/gr_poly/test/t-mul_toom33.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
Copyright (C) 2023 Fredrik Johansson

This file is part of FLINT.

FLINT is free software: you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License (LGPL) as published
by the Free Software Foundation; either version 3 of the License, or
(at your option) any later version. See <https://www.gnu.org/licenses/>.
*/

#include "test_helpers.h"
#include "ulong_extras.h"
#include "gr_poly.h"

FLINT_DLL extern gr_static_method_table _ca_methods;

int
test_mul1(flint_rand_t state, int which)
{
gr_ctx_t ctx;
slong n;
gr_poly_t A, B, C, D;
int status = GR_SUCCESS;

gr_ctx_init_random(ctx, state);

gr_poly_init(A, ctx);
gr_poly_init(B, ctx);
gr_poly_init(C, ctx);
gr_poly_init(D, ctx);

if (ctx->methods == _ca_methods)
n = 2;
else if (gr_ctx_is_finite(ctx) == T_TRUE)
n = 30;
else
n = 10;

GR_MUST_SUCCEED(gr_poly_randtest(A, state, 1 + n_randint(state, n), ctx));
GR_MUST_SUCCEED(gr_poly_randtest(B, state, 1 + n_randint(state, n), ctx));
GR_MUST_SUCCEED(gr_poly_randtest(C, state, 1 + n_randint(state, n), ctx));

switch (which)
{
case 0:
status |= gr_poly_mul_toom33(C, A, B, ctx);
break;
case 1:
status |= gr_poly_set(C, A, ctx);
status |= gr_poly_mul_toom33(C, C, B, ctx);
break;
case 2:
status |= gr_poly_set(C, B, ctx);
status |= gr_poly_mul_toom33(C, A, C, ctx);
break;
case 3:
status |= gr_poly_set(B, A, ctx);
status |= gr_poly_mul_toom33(C, A, A, ctx);
break;
case 4:
status |= gr_poly_set(B, A, ctx);
status |= gr_poly_set(C, A, ctx);
status |= gr_poly_mul_toom33(C, C, C, ctx);
break;

default:
flint_abort();
}

/* todo: should explicitly call basecase mul */
status |= gr_poly_mullow(D, A, B, FLINT_MAX(0, A->length + B->length - 1), ctx);

if (status == GR_SUCCESS && gr_poly_equal(C, D, ctx) == T_FALSE)
{
flint_printf("FAIL\n\n");
flint_printf("which = %d, n = %wd\n\n", which, n);
gr_ctx_println(ctx);
flint_printf("A = "); gr_poly_print(A, ctx); flint_printf("\n\n");
flint_printf("B = "); gr_poly_print(B, ctx); flint_printf("\n\n");
flint_printf("C = "); gr_poly_print(C, ctx); flint_printf("\n\n");
flint_printf("D = "); gr_poly_print(D, ctx); flint_printf("\n\n");
flint_abort();
}

gr_poly_clear(A, ctx);
gr_poly_clear(B, ctx);
gr_poly_clear(C, ctx);
gr_poly_clear(D, ctx);

gr_ctx_clear(ctx);

return status;
}

TEST_FUNCTION_START(gr_poly_mul_toom33, state)
{
slong iter;

for (iter = 0; iter < 1000; iter++)
{
test_mul1(state, n_randint(state, 5));
}

TEST_FUNCTION_END(state);
}
Loading