diff --git a/doc/source/nmod_vec.rst b/doc/source/nmod_vec.rst index 1aa24aa4d3..47ff530204 100644 --- a/doc/source/nmod_vec.rst +++ b/doc/source/nmod_vec.rst @@ -23,7 +23,7 @@ Random functions .. function:: void _nmod_vec_randtest(nn_ptr vec, flint_rand_t state, slong len, nmod_t mod) - Sets ``vec`` to a random vector of the given length with entries + Sets ``vec`` to a random vector of the given length with entries reduced modulo ``mod.n``. @@ -46,7 +46,7 @@ Basic manipulation and comparison .. function:: void _nmod_vec_reduce(nn_ptr res, nn_srcptr vec, slong len, nmod_t mod) - Reduces the entries of ``(vec, len)`` modulo ``mod.n`` and set + Reduces the entries of ``(vec, len)`` modulo ``mod.n`` and set ``res`` to the result. .. function:: flint_bitcnt_t _nmod_vec_max_bits(nn_srcptr vec, slong len) @@ -55,8 +55,8 @@ Basic manipulation and comparison .. function:: int _nmod_vec_equal(nn_srcptr vec, nn_srcptr vec2, slong len) - Returns~`1` if ``(vec, len)`` is equal to ``(vec2, len)``, - otherwise returns~`0`. + Returns `1` if ``(vec, len)`` is equal to ``(vec2, len)``, + otherwise returns `0`. Printing @@ -92,12 +92,12 @@ Arithmetic operations .. function:: void _nmod_vec_add(nn_ptr res, nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod) - Sets ``(res, len)`` to the sum of ``(vec1, len)`` + Sets ``(res, len)`` to the sum of ``(vec1, len)`` and ``(vec2, len)``. .. function:: void _nmod_vec_sub(nn_ptr res, nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod) - Sets ``(res, len)`` to the difference of ``(vec1, len)`` + Sets ``(res, len)`` to the difference of ``(vec1, len)`` and ``(vec2, len)``. .. function:: void _nmod_vec_neg(nn_ptr res, nn_srcptr vec, slong len, nmod_t mod) @@ -107,34 +107,93 @@ Arithmetic operations .. function:: void _nmod_vec_scalar_mul_nmod(nn_ptr res, nn_srcptr vec, slong len, ulong c, nmod_t mod) Sets ``(res, len)`` to ``(vec, len)`` multiplied by `c`. The element - `c` and all elements of `vec` are assumed to be less than `mod.n`. 
+ `c` and all elements of ``vec`` are assumed to be less than ``mod.n``. .. function:: void _nmod_vec_scalar_mul_nmod_shoup(nn_ptr res, nn_srcptr vec, slong len, ulong c, nmod_t mod) Sets ``(res, len)`` to ``(vec, len)`` multiplied by `c` using - :func:`n_mulmod_shoup`. `mod.n` should be less than `2^{\mathtt{FLINT\_BITS} - 1}`. `c` - and all elements of `vec` should be less than `mod.n`. + :func:`n_mulmod_shoup`. `mod.n` should be less than `2^{\mathtt{FLINT\_BITS} - 1}`. `c` + and all elements of ``vec`` should be less than ``mod.n``. .. function:: void _nmod_vec_scalar_addmul_nmod(nn_ptr res, nn_srcptr vec, slong len, ulong c, nmod_t mod) Adds ``(vec, len)`` times `c` to the vector ``(res, len)``. The element - `c` and all elements of `vec` are assumed to be less than `mod.n`. + `c` and all elements of ``vec`` are assumed to be less than ``mod.n``. Dot products -------------------------------------------------------------------------------- +Dot products functions and macros rely on several implementations, depending on +the length of this dot product and on the underlying modulus. What +implementations will be called is determined via ``_nmod_vec_dot_params``, +which returns a ``dot_params_t`` element which can then be used as input to the +dot product routines. -.. function:: int _nmod_vec_dot_bound_limbs(slong len, nmod_t mod) +The efficiency of the different approaches range roughly as follows, from +faster to slower, on 64 bit machines. In all cases, modular reduction is only +performed at the very end of the computation. - Returns the number of limbs (0, 1, 2 or 3) needed to represent the - unreduced dot product of two vectors of length ``len`` having entries - modulo ``mod.n``, assuming that ``len`` is nonnegative and that - ``mod.n`` is nonzero. The computed bound is tight. In other words, - this function returns the precise limb size of ``len`` times - ``(mod.n - 1) ^ 2``. 
+- moduli up to `1515531528` (about `2^{30.5}`): implemented via single limb + integer multiplication, using explicit vectorization if supported (current + support is for AVX2); + +- moduli that are a power of `2` up to `2^{32}`: same efficiency as the above + case; + +- moduli that are a power of `2` between `2^{33}` and `2^{63}`: efficiency + between that of the above case and that of the below one (depending on the + machine and on automatic vectorization); + +- other moduli up to `2^{32}`: implemented via single limb integer + multiplication combined with accumulation in two limbs; + +- moduli more than `2^{32}`, unreduced dot product fits in two limbs: + implemented via two limbs integer multiplication, with a final modular + reduction; + +- unreduced dot product fits in three limbs, moduli up to about `2^{62.5}`: + implemented via two limbs integer multiplication, with intermediate + accumulation of sub-products in two limbs, and overall accumulation in three + limbs; + +- unreduced dot product fits in three limbs, other moduli: implemented via two + limbs integer multiplication, with accumulation in three limbs. + + +.. type:: dot_params_t + +.. function:: dot_params_t _nmod_vec_dot_params(slong len, nmod_t mod) + + Returns a ``dot_params_t`` element. This element can be used as input for + the dot product macros and functions that require it, for any dot product + of vector with entries reduced modulo ``mod.n`` and whose length is less + than or equal to ``len``. + + Internals, subject to change: its field ``method`` indicates the method that + will be used to compute a dot product of this length ``len`` when working + with the given ``mod``. Its field ``pow2_precomp`` is set to ``2**DOT_SPLIT_BITS + % mod.n`` if ``method == _DOT2_SPLIT``, and to `0` otherwise. -.. macro:: NMOD_VEC_DOT(res, i, len, expr1, expr2, mod, nlimbs) +.. 
function:: ulong _nmod_vec_dot(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod, dot_params_t params) + + Returns the dot product of (``vec1``, ``len``) and (``vec2``, ``len``). The + input ``params`` has type ``dot_params_t`` and must have been computed via + ``_nmod_vec_dot_params`` with the specified ``mod`` and with a length + greater than or equal to ``len``. + +.. function:: ulong _nmod_vec_dot_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod, dot_params_t params) + + The same as ``_nmod_vec_dot``, but reverses ``vec2``. + +.. function:: ulong _nmod_vec_dot_ptr(nn_srcptr vec1, const nn_ptr * vec2, slong offset, slong len, nmod_t mod, dot_params_t params) + + Returns the dot product of (``vec1``, ``len``) and the values at + ``vec2[i][offset]``. The input ``params`` has type ``dot_params_t`` and + must have been computed via ``_nmod_vec_dot_params`` with the specified + ``mod`` and with a length greater than or equal to ``len``. + +.. macro:: NMOD_VEC_DOT(res, i, len, expr1, expr2, mod, params) Effectively performs the computation:: @@ -142,27 +201,26 @@ Dot products for (i = 0; i < len; i++) res += (expr1) * (expr2); - but with the arithmetic performed modulo ``mod``. The ``nlimbs`` parameter - should be 0, 1, 2 or 3, specifying the number of limbs needed to represent - the unreduced result. + but with the arithmetic performed modulo ``mod``. The input ``params`` has + type ``dot_params_t`` and must have been computed via + ``_nmod_vec_dot_params`` with the specified ``mod`` and with a length + greater than or equal to ``len``. ``nmod.h`` has to be included in order for this macro to work (order of inclusions does not matter). -.. function:: ulong _nmod_vec_dot(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod, int nlimbs) - - Returns the dot product of (``vec1``, ``len``) and - (``vec2``, ``len``). The ``nlimbs`` parameter should be - 0, 1, 2 or 3, specifying the number of limbs needed to represent the - unreduced result. +.. 
function:: int _nmod_vec_dot_bound_limbs(slong len, nmod_t mod) -.. function:: ulong _nmod_vec_dot_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod, int nlimbs) + Returns the number of limbs (0, 1, 2 or 3) needed to represent the + unreduced dot product of two vectors of length ``len`` having entries + modulo ``mod.n``, assuming that ``len`` is nonnegative and that + ``mod.n`` is nonzero. The computed bound is tight. In other words, + this function returns the precise limb size of ``len`` times + ``(mod.n - 1)**2``. - The same as ``_nmod_vec_dot``, but reverses ``vec2``. +.. function:: int _nmod_vec_dot_bound_limbs_from_params(slong len, nmod_t mod, dot_params_t params) -.. function:: ulong _nmod_vec_dot_ptr(nn_srcptr vec1, const nn_ptr * vec2, slong offset, slong len, nmod_t mod, int nlimbs) + Same specification as ``_nmod_vec_dot_bound_limbs``, but uses the additional + input ``params`` to reduce the amount of computations; for correctness + ``params`` must have been computed for the specified ``len`` and ``mod``. - Returns the dot product of (``vec1``, ``len``) and the values at - ``vec2[i][offset]``. The ``nlimbs`` parameter should be - 0, 1, 2 or 3, specifying the number of limbs needed to represent the - unreduced result. 
diff --git a/src/arith/stirling2.c b/src/arith/stirling2.c index 1315f50b4e..71570cb5dc 100644 --- a/src/arith/stirling2.c +++ b/src/arith/stirling2.c @@ -537,7 +537,7 @@ stirling_2_nmod(const unsigned int * divtab, ulong n, ulong k, nmod_t mod) nn_ptr t, u; slong i, bin_len, pow_len; ulong s1, s2, bden, bd; - int bound_limbs; + dot_params_t params; TMP_INIT; TMP_START; @@ -575,13 +575,13 @@ stirling_2_nmod(const unsigned int * divtab, ulong n, ulong k, nmod_t mod) for (i = 1; i < bin_len; i += 2) t[i] = nmod_neg(t[i], mod); - bound_limbs = _nmod_vec_dot_bound_limbs(bin_len, mod); - s1 = _nmod_vec_dot(t, u, bin_len, mod, bound_limbs); + params = _nmod_vec_dot_params(bin_len, mod); + s1 = _nmod_vec_dot(t, u, bin_len, mod, params); if (pow_len > bin_len) { - bound_limbs = _nmod_vec_dot_bound_limbs(pow_len - bin_len, mod); - s2 = _nmod_vec_dot_rev(u + bin_len, t + k - pow_len + 1, pow_len - bin_len, mod, bound_limbs); + params = _nmod_vec_dot_params(pow_len - bin_len, mod); + s2 = _nmod_vec_dot_rev(u + bin_len, t + k - pow_len + 1, pow_len - bin_len, mod, params); if (k % 2) s1 = nmod_sub(s1, s2, mod); else diff --git a/src/fq_nmod_mpoly/fq_nmod_embed.c b/src/fq_nmod_mpoly/fq_nmod_embed.c index d53af70b1d..cb538a2a42 100644 --- a/src/fq_nmod_mpoly/fq_nmod_embed.c +++ b/src/fq_nmod_mpoly/fq_nmod_embed.c @@ -352,12 +352,12 @@ void bad_n_fq_embed_lg_to_sm( slong smd = fq_nmod_ctx_degree(emb->smctx); slong lgd = fq_nmod_ctx_degree(emb->lgctx); slong i; - int nlimbs = _nmod_vec_dot_bound_limbs(lgd, emb->lgctx->mod); + const dot_params_t params = _nmod_vec_dot_params(lgd, emb->lgctx->mod); n_poly_fit_length(out, lgd); for (i = 0; i < lgd; i++) out->coeffs[i] = _nmod_vec_dot(emb->lg_to_sm_mat->rows[i], in, lgd, - emb->lgctx->mod, nlimbs); + emb->lgctx->mod, params); FLINT_ASSERT(lgd/smd == emb->h->length - 1); out->length = emb->h->length - 1; _n_fq_poly_normalise(out, smd); @@ -438,7 +438,7 @@ void bad_n_fq_embed_sm_to_lg( slong smd = fq_nmod_ctx_degree(emb->smctx); slong 
lgd = fq_nmod_ctx_degree(emb->lgctx); slong i; - int nlimbs = _nmod_vec_dot_bound_limbs(lgd, emb->lgctx->mod); + const dot_params_t params = _nmod_vec_dot_params(lgd, emb->lgctx->mod); n_poly_stack_t St; /* TODO: pass the stack in */ n_fq_poly_struct * q, * in_red; @@ -454,7 +454,7 @@ void bad_n_fq_embed_sm_to_lg( for (i = 0; i < lgd; i++) out[i] = _nmod_vec_dot(emb->sm_to_lg_mat->rows[i], in_red->coeffs, - smd*in_red->length, emb->lgctx->mod, nlimbs); + smd*in_red->length, emb->lgctx->mod, params); n_poly_stack_give_back(St, 2); @@ -544,11 +544,11 @@ void bad_n_fq_embed_sm_elem_to_lg( slong smd = fq_nmod_ctx_degree(emb->smctx); slong lgd = fq_nmod_ctx_degree(emb->lgctx); slong i; - int nlimbs = _nmod_vec_dot_bound_limbs(smd, emb->lgctx->mod); + const dot_params_t params = _nmod_vec_dot_params(smd, emb->lgctx->mod); for (i = 0; i < lgd; i++) out[i] = _nmod_vec_dot(emb->sm_to_lg_mat->rows[i], in, smd, - emb->lgctx->mod, nlimbs); + emb->lgctx->mod, params); } void bad_fq_nmod_embed_sm_elem_to_lg( @@ -559,7 +559,7 @@ void bad_fq_nmod_embed_sm_elem_to_lg( slong smd = fq_nmod_ctx_degree(emb->smctx); slong lgd = fq_nmod_ctx_degree(emb->lgctx); slong i; - int nlimbs = _nmod_vec_dot_bound_limbs(smd, emb->lgctx->mod); + const dot_params_t params = _nmod_vec_dot_params(smd, emb->lgctx->mod); FLINT_ASSERT(in->length <= smd); @@ -568,7 +568,7 @@ void bad_fq_nmod_embed_sm_elem_to_lg( for (i = 0; i < lgd; i++) { out->coeffs[i] = _nmod_vec_dot(emb->sm_to_lg_mat->rows[i], - in->coeffs, in->length, emb->lgctx->mod, nlimbs); + in->coeffs, in->length, emb->lgctx->mod, params); } out->length = lgd; diff --git a/src/fq_nmod_mpoly_factor/n_bpoly_fq_factor_lgprime.c b/src/fq_nmod_mpoly_factor/n_bpoly_fq_factor_lgprime.c index 733c6e34f6..2a25148c12 100644 --- a/src/fq_nmod_mpoly_factor/n_bpoly_fq_factor_lgprime.c +++ b/src/fq_nmod_mpoly_factor/n_bpoly_fq_factor_lgprime.c @@ -190,11 +190,10 @@ static void _lattice( n_bpoly_t Q, R, dg; n_bpoly_struct * ld; nmod_mat_t M, T1, T2; - int 
nlimbs; ulong * trow; slong lift_order = lift_alpha_pow->length - 1; - nlimbs = _nmod_vec_dot_bound_limbs(r, ctx->mod); + const dot_params_t params = _nmod_vec_dot_params(r, ctx->mod); trow = (ulong *) flint_malloc(r*sizeof(ulong)); n_bpoly_init(Q); n_bpoly_init(R); @@ -243,7 +242,7 @@ static void _lattice( for (i = 0; i < d; i++) nmod_mat_entry(M, (j - starts[k])*deg + l, i) = - _nmod_vec_dot(trow, N->rows[i], r, ctx->mod, nlimbs); + _nmod_vec_dot(trow, N->rows[i], r, ctx->mod, params); } nmod_mat_init_nullspace_tr(T1, M); diff --git a/src/fq_nmod_mpoly_factor/n_bpoly_fq_factor_smprime.c b/src/fq_nmod_mpoly_factor/n_bpoly_fq_factor_smprime.c index 55da43afd4..ac4d498a88 100644 --- a/src/fq_nmod_mpoly_factor/n_bpoly_fq_factor_smprime.c +++ b/src/fq_nmod_mpoly_factor/n_bpoly_fq_factor_smprime.c @@ -957,10 +957,9 @@ static void _lattice( n_fq_bpoly_t Q, R, dg; n_fq_bpoly_struct * ld; nmod_mat_t M, T1, T2; - int nlimbs; ulong * trow; - nlimbs = _nmod_vec_dot_bound_limbs(r, ctx->mod); + const dot_params_t params = _nmod_vec_dot_params(r, ctx->mod); trow = (ulong *) flint_malloc(r*sizeof(ulong)); n_fq_bpoly_init(Q); n_fq_bpoly_init(R); @@ -985,20 +984,20 @@ static void _lattice( nmod_mat_init(M, d*(lift_order - CLD[k]), nrows, ctx->modulus->mod.n); for (j = CLD[k]; j < lift_order; j++) - for (l = 0; l < d; l++) - { - for (i = 0; i < r; i++) + for (l = 0; l < d; l++) { - if (k >= ld[i].length || j >= ld[i].coeffs[k].length) - trow[i] = 0; - else - trow[i] = ld[i].coeffs[k].coeffs[d*j + l]; - } + for (i = 0; i < r; i++) + { + if (k >= ld[i].length || j >= ld[i].coeffs[k].length) + trow[i] = 0; + else + trow[i] = ld[i].coeffs[k].coeffs[d*j + l]; + } - for (i = 0; i < nrows; i++) - nmod_mat_entry(M, (j - CLD[k])*d + l, i) = - _nmod_vec_dot(trow, N->rows[i], r, ctx->mod, nlimbs); - } + for (i = 0; i < nrows; i++) + nmod_mat_entry(M, (j - CLD[k])*d + l, i) = + _nmod_vec_dot(trow, N->rows[i], r, ctx->mod, params); + } nmod_mat_init_nullspace_tr(T1, M); diff --git 
a/src/fq_zech_mpoly_factor/bpoly_factor_smprime.c b/src/fq_zech_mpoly_factor/bpoly_factor_smprime.c index 84851961cd..6979ed9719 100644 --- a/src/fq_zech_mpoly_factor/bpoly_factor_smprime.c +++ b/src/fq_zech_mpoly_factor/bpoly_factor_smprime.c @@ -496,10 +496,9 @@ static void _lattice( fq_zech_bpoly_t Q, R, dg; fq_zech_bpoly_struct * ld; nmod_mat_t M, T1, T2; - int nlimbs; ulong * trow; - nlimbs = _nmod_vec_dot_bound_limbs(r, fq_zech_ctx_mod(ctx)); + const dot_params_t params = _nmod_vec_dot_params(r, fq_zech_ctx_mod(ctx)); trow = (ulong *) flint_malloc(r*sizeof(ulong)); fq_zech_bpoly_init(Q, ctx); fq_zech_bpoly_init(R, ctx); @@ -549,7 +548,7 @@ static void _lattice( for (i = 0; i < d; i++) nmod_mat_entry(M, (j - starts[k])*deg + l, i) = - _nmod_vec_dot(trow, N->rows[i], r, fq_zech_ctx_mod(ctx), nlimbs); + _nmod_vec_dot(trow, N->rows[i], r, fq_zech_ctx_mod(ctx), params); } nmod_mat_init_nullspace_tr(T1, M); diff --git a/src/gr/nmod.c b/src/gr/nmod.c index 54b7e2f11a..6e85aa80c1 100644 --- a/src/gr/nmod.c +++ b/src/gr/nmod.c @@ -856,53 +856,22 @@ _gr_nmod_vec_product(ulong * res, const ulong * vec, slong len, gr_ctx_t ctx) int __gr_nmod_vec_dot(ulong * res, const ulong * initial, int subtract, const ulong * vec1, const ulong * vec2, slong len, gr_ctx_t ctx) { - slong i; ulong s; - int nlimbs; + dot_params_t params; nmod_t mod; - if (len <= 1) + if (len == 0) { - if (len == 2) /* todo: fmma */ - { - mod = NMOD_CTX(ctx); - s = nmod_mul(vec1[0], vec2[0], mod); - s = nmod_addmul(s, vec1[1], vec2[1], mod); - } - else if (len == 1) - { - mod = NMOD_CTX(ctx); - s = nmod_mul(vec1[0], vec2[0], mod); - } + if (initial == NULL) + _gr_nmod_zero(res, ctx); else - { - if (initial == NULL) - _gr_nmod_zero(res, ctx); - else - _gr_nmod_set(res, initial, ctx); - return GR_SUCCESS; - } + _gr_nmod_set(res, initial, ctx); + return GR_SUCCESS; } - else - { - mod = NMOD_CTX(ctx); - if (len <= 16) - { - if (mod.n <= UWORD(1) << (FLINT_BITS / 2 - 2)) - nlimbs = 1; - if (mod.n <= UWORD(1) << 
(FLINT_BITS - 2)) - nlimbs = 2; - else - nlimbs = 3; - } - else - { - nlimbs = _nmod_vec_dot_bound_limbs(len, mod); - } - - NMOD_VEC_DOT(s, i, len, vec1[i], vec2[i], mod, nlimbs); - } + mod = NMOD_CTX(ctx); + params = _nmod_vec_dot_params(len, mod); + s = _nmod_vec_dot(vec1, vec2, len, mod, params); if (initial == NULL) { @@ -925,53 +894,22 @@ __gr_nmod_vec_dot(ulong * res, const ulong * initial, int subtract, const ulong int __gr_nmod_vec_dot_rev(ulong * res, const ulong * initial, int subtract, const ulong * vec1, const ulong * vec2, slong len, gr_ctx_t ctx) { - slong i; ulong s; - int nlimbs; + dot_params_t params; nmod_t mod; - if (len <= 1) + if (len == 0) { - if (len == 2) /* todo: fmma */ - { - mod = NMOD_CTX(ctx); - s = nmod_mul(vec1[0], vec2[1], mod); - s = nmod_addmul(s, vec1[1], vec2[0], mod); - } - else if (len == 1) - { - mod = NMOD_CTX(ctx); - s = nmod_mul(vec1[0], vec2[0], mod); - } + if (initial == NULL) + _gr_nmod_zero(res, ctx); else - { - if (initial == NULL) - _gr_nmod_zero(res, ctx); - else - _gr_nmod_set(res, initial, ctx); - return GR_SUCCESS; - } + _gr_nmod_set(res, initial, ctx); + return GR_SUCCESS; } - else - { - mod = NMOD_CTX(ctx); - if (len <= 16) - { - if (mod.n <= UWORD(1) << (FLINT_BITS / 2 - 2)) - nlimbs = 1; - if (mod.n <= UWORD(1) << (FLINT_BITS - 2)) - nlimbs = 2; - else - nlimbs = 3; - } - else - { - nlimbs = _nmod_vec_dot_bound_limbs(len, mod); - } - - NMOD_VEC_DOT(s, i, len, vec1[i], vec2[len - 1 - i], mod, nlimbs); - } + mod = NMOD_CTX(ctx); + params = _nmod_vec_dot_params(len, mod); + s = _nmod_vec_dot_rev(vec1, vec2, len, mod, params); if (initial == NULL) { diff --git a/src/gr/nmod32.c b/src/gr/nmod32.c index 83a8924980..a5c88b7c04 100644 --- a/src/gr/nmod32.c +++ b/src/gr/nmod32.c @@ -10,7 +10,7 @@ */ #include "fmpz.h" -#include "fmpq.h" +//#include "fmpq.h" #include "nmod.h" #include "nmod_vec.h" #include "gr.h" @@ -430,10 +430,8 @@ _nmod32_vec_dot(nmod32_t res, const nmod32_t initial, int subtract, const nmod32 { ulong 
ss; - int nlimbs; - - nlimbs = _nmod_vec_dot_bound_limbs(len, NMOD32_CTX(ctx)); - NMOD_VEC_DOT(ss, i, len, (ulong) vec1[i], (ulong) vec2[i], NMOD32_CTX(ctx), nlimbs); + const dot_params_t params = _nmod_vec_dot_params(len, NMOD32_CTX(ctx)); + NMOD_VEC_DOT(ss, i, len, (ulong) vec1[i], (ulong) vec2[i], NMOD32_CTX(ctx), params); s = n_addmod(s, ss, n); } @@ -477,10 +475,9 @@ _nmod32_vec_dot_rev(nmod32_t res, const nmod32_t initial, int subtract, const nm { ulong ss; - int nlimbs; - nlimbs = _nmod_vec_dot_bound_limbs(len, NMOD32_CTX(ctx)); - NMOD_VEC_DOT(ss, i, len, (ulong) vec1[i], (ulong) vec2[len - 1 - i], NMOD32_CTX(ctx), nlimbs); + const dot_params_t params = _nmod_vec_dot_params(len, NMOD32_CTX(ctx)); + NMOD_VEC_DOT(ss, i, len, (ulong) vec1[i], (ulong) vec2[len - 1 - i], NMOD32_CTX(ctx), params); s = n_addmod(s, ss, n); } diff --git a/src/gr/nmod8.c b/src/gr/nmod8.c index 5bb1ac71d3..1c1622bbf1 100644 --- a/src/gr/nmod8.c +++ b/src/gr/nmod8.c @@ -10,7 +10,7 @@ */ #include "fmpz.h" -#include "fmpq.h" +//#include "fmpq.h" #include "nmod.h" #include "nmod_vec.h" #include "gr.h" @@ -440,10 +440,9 @@ _nmod8_vec_dot(nmod8_t res, const nmod8_t initial, int subtract, const nmod8_str else { ulong ss; - int nlimbs; - nlimbs = _nmod_vec_dot_bound_limbs(len, NMOD8_CTX(ctx)); - NMOD_VEC_DOT(ss, i, len, (ulong) vec1[i], (ulong) vec2[i], NMOD8_CTX(ctx), nlimbs); + const dot_params_t params = _nmod_vec_dot_params(len, NMOD8_CTX(ctx)); + NMOD_VEC_DOT(ss, i, len, (ulong) vec1[i], (ulong) vec2[i], NMOD8_CTX(ctx), params); s = n_addmod(s, ss, n); } @@ -504,10 +503,9 @@ _nmod8_vec_dot_rev(nmod8_t res, const nmod8_t initial, int subtract, const nmod8 else { ulong ss; - int nlimbs; - nlimbs = _nmod_vec_dot_bound_limbs(len, NMOD8_CTX(ctx)); - NMOD_VEC_DOT(ss, i, len, (ulong) vec1[i], (ulong) vec2[len - 1 - i], NMOD8_CTX(ctx), nlimbs); + const dot_params_t params = _nmod_vec_dot_params(len, NMOD8_CTX(ctx)); + NMOD_VEC_DOT(ss, i, len, (ulong) vec1[i], (ulong) vec2[len - 1 - i], 
NMOD8_CTX(ctx), params); s = n_addmod(s, ss, n); } diff --git a/src/machine_vectors.h b/src/machine_vectors.h index c2d7c91bc4..99ffbf1ff3 100644 --- a/src/machine_vectors.h +++ b/src/machine_vectors.h @@ -550,6 +550,15 @@ FLINT_FORCE_INLINE vec4d vec4n_convert_limited_vec4d(vec4n a) { return _mm256_sub_pd(_mm256_or_pd(_mm256_castsi256_pd(a), t), t); } +// horizontal sum +FLINT_FORCE_INLINE ulong vec4n_horizontal_sum(vec4n a) { + vec4n a_hi = _mm256_shuffle_epi32(a, 14); // 14 == 0b00001110 + vec4n sum_lo = _mm256_add_epi64(a, a_hi); + vec2n sum_hi = _mm256_extracti128_si256(sum_lo, 1); + vec2n sum = _mm_add_epi64(_mm256_castsi256_si128(sum_lo), sum_hi); + return (ulong) _mm_cvtsi128_si64(sum); +} + /* vec8d -- AVX2 ***********************************************************/ @@ -1141,10 +1150,10 @@ FLINT_FORCE_INLINE vec2d vec2d_reduce_0n_to_pmhn(vec2d a, vec2d n) { FLINT_FORCE_INLINE vec2d vec2d_reduce_pm1n_to_pmhn(vec2d a, vec2d n) { vec2d halfn = vec2d_half(n); vec2d t = vec2d_add(a, n); - + vec2n condition_a = vcgtq_f64(a, halfn); vec2n condition_t = vcltq_f64(t, halfn); - + return vbslq_f64(condition_a, vec2d_sub(a, n), vbslq_f64(condition_t, t, a)); } @@ -1500,18 +1509,18 @@ FLINT_FORCE_INLINE vec2n vec2n_sub(vec2n a, vec2n b) { FLINT_FORCE_INLINE vec2n vec2n_addmod(vec2n a, vec2n b, vec2n n) { vec2n nmb = vec2n_sub(n, b); vec2n sum = vec2n_sub(a, nmb); - + vec2n mask = vcgtq_u64(nmb, a); - + return vec2n_add(sum, vandq_u64(n, mask)); } // (a + b) % n for n < 2^63 FLINT_FORCE_INLINE vec2n vec2n_addmod_limited(vec2n a, vec2n b, vec2n n) { vec2n s = vec2n_add(a, b); - + vec2n mask = vcgeq_u64(s, n); - + return vec2n_sub(s, vandq_u64(n, mask)); } diff --git a/src/n_poly.h b/src/n_poly.h index 2a06829130..66ace2db13 100644 --- a/src/n_poly.h +++ b/src/n_poly.h @@ -357,7 +357,7 @@ void n_poly_mod_addmul_linear(n_poly_t A, const n_poly_t B, void n_poly_mod_scalar_addmul_nmod(n_poly_t A, const n_poly_t B, const n_poly_t C, ulong d0, nmod_t ctx); -ulong 
_n_poly_eval_pow(n_poly_t P, n_poly_t alphapow, int nlimbs, +ulong _n_poly_eval_pow(n_poly_t P, n_poly_t alphapow, dot_params_t params, nmod_t ctx); ulong n_poly_mod_eval_pow(n_poly_t P, n_poly_t alphapow, diff --git a/src/n_poly/n_poly_mod.c b/src/n_poly/n_poly_mod.c index 29da4c2db8..17842f403a 100644 --- a/src/n_poly/n_poly_mod.c +++ b/src/n_poly/n_poly_mod.c @@ -1030,7 +1030,7 @@ ulong n_poly_mod_remove(n_poly_t f, const n_poly_t p, nmod_t ctx) return i; } -ulong _n_poly_eval_pow(n_poly_t P, n_poly_t alphapow, int nlimbs, nmod_t ctx) +ulong _n_poly_eval_pow(n_poly_t P, n_poly_t alphapow, dot_params_t params, nmod_t ctx) { ulong * Pcoeffs = P->coeffs; slong Plen = P->length; @@ -1049,15 +1049,15 @@ ulong _n_poly_eval_pow(n_poly_t P, n_poly_t alphapow, int nlimbs, nmod_t ctx) alpha_powers[k] = nmod_mul(alpha_powers[k - 1], alpha_powers[1], ctx); } - NMOD_VEC_DOT(res, k, Plen, Pcoeffs[k], alpha_powers[k], ctx, nlimbs); + NMOD_VEC_DOT(res, k, Plen, Pcoeffs[k], alpha_powers[k], ctx, params); return res; } ulong n_poly_mod_eval_pow(n_poly_t P, n_poly_t alphapow, nmod_t ctx) { - int nlimbs = _nmod_vec_dot_bound_limbs(P->length, ctx); - return _n_poly_eval_pow(P, alphapow, nlimbs, ctx); + const dot_params_t params = _nmod_vec_dot_params(P->length, ctx); + return _n_poly_eval_pow(P, alphapow, params, ctx); } void n_poly_mod_eval2_pow( diff --git a/src/nmod.h b/src/nmod.h index ecaafcf6b2..2353adf5b3 100644 --- a/src/nmod.h +++ b/src/nmod.h @@ -185,6 +185,15 @@ ulong nmod_addmul(ulong a, ulong b, ulong c, nmod_t mod) (r) = nmod_addmul((r), (a), (b), (mod)); \ } while (0) +// TODO doc a*b + c*d +NMOD_INLINE +ulong nmod_fmma(ulong a, ulong b, ulong c, ulong d, nmod_t mod) +{ + a = nmod_mul(a, b, mod); + NMOD_ADDMUL(a, c, d, mod); + return a; +} + NMOD_INLINE ulong nmod_inv(ulong a, nmod_t mod) { diff --git a/src/nmod_mat/charpoly.c b/src/nmod_mat/charpoly.c index c5f870d4f9..31cca4b9e1 100644 --- a/src/nmod_mat/charpoly.c +++ b/src/nmod_mat/charpoly.c @@ -45,14 +45,13 @@ 
_nmod_mat_charpoly_berkowitz(nn_ptr cp, const nmod_mat_t mat, nmod_t mod) { slong i, k, t; nn_ptr a, A, s; - int nlimbs; TMP_INIT; TMP_START; a = TMP_ALLOC(sizeof(ulong) * (n * n)); A = a + (n - 1) * n; - nlimbs = _nmod_vec_dot_bound_limbs(n, mod); + const dot_params_t params = _nmod_vec_dot_params(n, mod); _nmod_vec_zero(cp, n + 1); cp[0] = nmod_neg(nmod_mat_entry(mat, 0, 0), mod); @@ -71,17 +70,17 @@ _nmod_mat_charpoly_berkowitz(nn_ptr cp, const nmod_mat_t mat, nmod_t mod) for (i = 0; i <= t; i++) { s = a + k * n + i; - s[0] = _nmod_vec_dot(mat->rows[i], a + (k - 1) * n, t + 1, mod, nlimbs); + s[0] = _nmod_vec_dot(mat->rows[i], a + (k - 1) * n, t + 1, mod, params); } A[k] = a[k * n + t]; } - A[t] = _nmod_vec_dot(mat->rows[t], a + (t - 1) * n, t + 1, mod, nlimbs); + A[t] = _nmod_vec_dot(mat->rows[t], a + (t - 1) * n, t + 1, mod, params); for (k = 0; k <= t; k++) { - cp[k] = nmod_sub(cp[k], _nmod_vec_dot_rev(A, cp, k, mod, nlimbs), mod); + cp[k] = nmod_sub(cp[k], _nmod_vec_dot_rev(A, cp, k, mod, params), mod); cp[k] = nmod_sub(cp[k], A[k], mod); } } @@ -117,7 +116,6 @@ void nmod_mat_charpoly_danilevsky(nmod_poly_t p, const nmod_mat_t M) ulong h; nmod_poly_t b; nmod_mat_t M2; - int num_limbs; TMP_INIT; if (M->r != M->c) @@ -142,7 +140,7 @@ void nmod_mat_charpoly_danilevsky(nmod_poly_t p, const nmod_mat_t M) TMP_START; i = 1; - num_limbs = _nmod_vec_dot_bound_limbs(n, p->mod); + const dot_params_t params = _nmod_vec_dot_params(n, p->mod); nmod_poly_one(p); nmod_poly_init(b, p->mod.n); nmod_mat_init_set(M2, M); @@ -226,7 +224,7 @@ void nmod_mat_charpoly_danilevsky(nmod_poly_t p, const nmod_mat_t M) for (k = 1; k <= n - i; k++) T[k - 1] = A[k - 1][j - 1]; - A[n - i - 1][j - 1] = _nmod_vec_dot(T, W, n - i, p->mod, num_limbs); + A[n - i - 1][j - 1] = _nmod_vec_dot(T, W, n - i, p->mod, params); } for (j = n - i; j <= n - 1; j++) @@ -234,13 +232,13 @@ void nmod_mat_charpoly_danilevsky(nmod_poly_t p, const nmod_mat_t M) for (k = 1; k <= n - i; k++) T[k - 1] = A[k - 1][j - 
1]; - A[n - i - 1][j - 1] = n_addmod(_nmod_vec_dot(T, W, n - i, p->mod, num_limbs), W[j], p->mod.n); + A[n - i - 1][j - 1] = n_addmod(_nmod_vec_dot(T, W, n - i, p->mod, params), W[j], p->mod.n); } for (k = 1; k <= n - i; k++) T[k - 1] = A[k - 1][j - 1]; - A[n - i - 1][n - 1] = _nmod_vec_dot(T, W, n - i, p->mod, num_limbs); + A[n - i - 1][n - 1] = _nmod_vec_dot(T, W, n - i, p->mod, params); i++; } diff --git a/src/nmod_mat/lu.c b/src/nmod_mat/lu.c index 2aadb90e21..2d45dabddc 100644 --- a/src/nmod_mat/lu.c +++ b/src/nmod_mat/lu.c @@ -17,7 +17,7 @@ slong nmod_mat_lu(slong * P, nmod_mat_t A, int rank_check) { slong nrows, ncols, n, cutoff; - int nlimbs, bits; + int bits; nrows = A->r; ncols = A->c; @@ -46,9 +46,12 @@ nmod_mat_lu(slong * P, nmod_mat_t A, int rank_check) return nmod_mat_lu_recursive(P, A, rank_check); } - nlimbs = _nmod_vec_dot_bound_limbs(n, A->mod); + const dot_params_t params = _nmod_vec_dot_params(n, A->mod); - if (nlimbs <= 1 || (nlimbs == 2 && n >= 12) || (nlimbs == 3 && n >= 20)) + // TODO thresholds to re-examine after dot product changes + if (params.method <= _DOT1 // <= 0,1 limb + || (params.method <= _DOT2 && n >= 12) // <= 2 limbs (n >= 12 if exactly 2) + || (params.method > _DOT2 && n >= 20)) // == 3 limbs && n >= 20 return nmod_mat_lu_classical_delayed(P, A, rank_check); else return nmod_mat_lu_classical(P, A, rank_check); diff --git a/src/nmod_mat/lu_classical_delayed.c b/src/nmod_mat/lu_classical_delayed.c index 9f9bd52bc6..356a8305a4 100644 --- a/src/nmod_mat/lu_classical_delayed.c +++ b/src/nmod_mat/lu_classical_delayed.c @@ -430,15 +430,15 @@ slong nmod_mat_lu_classical_delayed(slong * P, nmod_mat_t A, int rank_check) { slong nrows, ncols; - int nlimbs; nrows = A->r; ncols = A->c; - nlimbs = _nmod_vec_dot_bound_limbs(FLINT_MIN(nrows, ncols), A->mod); + const dot_params_t params = _nmod_vec_dot_params(FLINT_MIN(nrows, ncols), A->mod); - if (nlimbs <= 1) + // TODO cases to re-examine after dot product changes? 
+ if (params.method <= _DOT1) return nmod_mat_lu_classical_delayed_1(P, A, rank_check); - else if (nlimbs <= 2) + else if (params.method <= _DOT2) return nmod_mat_lu_classical_delayed_2(P, A, rank_check); else return nmod_mat_lu_classical_delayed_3(P, A, rank_check); diff --git a/src/nmod_mat/mul_classical.c b/src/nmod_mat/mul_classical.c index ff05ffb864..419b4637db 100644 --- a/src/nmod_mat/mul_classical.c +++ b/src/nmod_mat/mul_classical.c @@ -24,7 +24,7 @@ with op = -1, computes D = C - A*B static inline void _nmod_mat_addmul_basic_op(nn_ptr * D, nn_ptr * const C, nn_ptr * const A, - nn_ptr * const B, slong m, slong k, slong n, int op, nmod_t mod, int nlimbs) + nn_ptr * const B, slong m, slong k, slong n, int op, nmod_t mod, dot_params_t params) { slong i, j; ulong c; @@ -33,7 +33,7 @@ _nmod_mat_addmul_basic_op(nn_ptr * D, nn_ptr * const C, nn_ptr * const A, { for (j = 0; j < n; j++) { - c = _nmod_vec_dot_ptr(A[i], B, j, k, mod, nlimbs); + c = _nmod_vec_dot_ptr(A[i], B, j, k, mod, params); if (op == 1) c = nmod_add(C[i][j], c, mod); @@ -47,7 +47,7 @@ _nmod_mat_addmul_basic_op(nn_ptr * D, nn_ptr * const C, nn_ptr * const A, static inline void _nmod_mat_addmul_transpose_op(nn_ptr * D, const nn_ptr * C, const nn_ptr * A, - const nn_ptr * B, slong m, slong k, slong n, int op, nmod_t mod, int nlimbs) + const nn_ptr * B, slong m, slong k, slong n, int op, nmod_t mod, dot_params_t params) { nn_ptr tmp; ulong c; @@ -63,7 +63,7 @@ _nmod_mat_addmul_transpose_op(nn_ptr * D, const nn_ptr * C, const nn_ptr * A, { for (j = 0; j < n; j++) { - c = _nmod_vec_dot(A[i], tmp + j*k, k, mod, nlimbs); + c = _nmod_vec_dot(A[i], tmp + j*k, k, mod, params); if (op == 1) c = nmod_add(C[i][j], c, mod); @@ -164,7 +164,6 @@ _nmod_mat_mul_classical_op(nmod_mat_t D, const nmod_mat_t C, const nmod_mat_t A, const nmod_mat_t B, int op) { slong m, k, n; - int nlimbs; nmod_t mod; mod = A->mod; @@ -172,7 +171,7 @@ _nmod_mat_mul_classical_op(nmod_mat_t D, const nmod_mat_t C, k = A->c; n = B->c; - if 
(k == 0) + if (k == 0 || mod.n == 1) // covers params.method == _DOT0 { if (op == 0) nmod_mat_zero(D); @@ -181,9 +180,10 @@ _nmod_mat_mul_classical_op(nmod_mat_t D, const nmod_mat_t C, return; } - nlimbs = _nmod_vec_dot_bound_limbs(k, mod); + const dot_params_t params = _nmod_vec_dot_params(k, mod); - if (nlimbs == 1 && m > 10 && k > 10 && n > 10) + // TODO vec_dot changes --> thresholds to re-examine + if (params.method == _DOT1 && m > 10 && k > 10 && n > 10) { _nmod_mat_addmul_packed_op(D->rows, (op == 0) ? NULL : C->rows, A->rows, B->rows, m, k, n, op, D->mod); @@ -192,19 +192,13 @@ _nmod_mat_mul_classical_op(nmod_mat_t D, const nmod_mat_t C, || n < NMOD_MAT_MUL_TRANSPOSE_CUTOFF || k < NMOD_MAT_MUL_TRANSPOSE_CUTOFF) { - if ((mod.n & (mod.n - 1)) == 0) - nlimbs = 1; - _nmod_mat_addmul_basic_op(D->rows, (op == 0) ? NULL : C->rows, - A->rows, B->rows, m, k, n, op, D->mod, nlimbs); + A->rows, B->rows, m, k, n, op, D->mod, params); } else { - if ((mod.n & (mod.n - 1)) == 0) - nlimbs = 1; - _nmod_mat_addmul_transpose_op(D->rows, (op == 0) ? 
NULL : C->rows, - A->rows, B->rows, m, k, n, op, D->mod, nlimbs); + A->rows, B->rows, m, k, n, op, D->mod, params); } } diff --git a/src/nmod_mat/mul_classical_threaded.c b/src/nmod_mat/mul_classical_threaded.c index 24b6d3ec5d..4bd76b7adc 100644 --- a/src/nmod_mat/mul_classical_threaded.c +++ b/src/nmod_mat/mul_classical_threaded.c @@ -26,7 +26,7 @@ with op = -1, computes D = C - A*B static inline void _nmod_mat_addmul_basic_op(nn_ptr * D, nn_ptr * const C, nn_ptr * const A, - nn_ptr * const B, slong m, slong k, slong n, int op, nmod_t mod, int nlimbs) + nn_ptr * const B, slong m, slong k, slong n, int op, nmod_t mod, dot_params_t params) { slong i, j; ulong c; @@ -35,7 +35,7 @@ _nmod_mat_addmul_basic_op(nn_ptr * D, nn_ptr * const C, nn_ptr * const A, { for (j = 0; j < n; j++) { - c = _nmod_vec_dot_ptr(A[i], B, j, k, mod, nlimbs); + c = _nmod_vec_dot_ptr(A[i], B, j, k, mod, params); if (op == 1) c = nmod_add(C[i][j], c, mod); @@ -55,7 +55,7 @@ typedef struct slong k; slong m; slong n; - slong nlimbs; + dot_params_t params; const nn_ptr * A; const nn_ptr * C; nn_ptr * D; @@ -76,7 +76,7 @@ _nmod_mat_addmul_transpose_worker(void * arg_ptr) slong k = arg.k; slong m = arg.m; slong n = arg.n; - slong nlimbs = arg.nlimbs; + dot_params_t params = arg.params; const nn_ptr * A = arg.A; const nn_ptr * C = arg.C; nn_ptr * D = arg.D; @@ -114,7 +114,7 @@ _nmod_mat_addmul_transpose_worker(void * arg_ptr) { for (j = jstart ; j < jend; j++) { - c = _nmod_vec_dot(A[i], tmp + j*k, k, mod, nlimbs); + c = _nmod_vec_dot(A[i], tmp + j*k, k, mod, params); if (op == 1) c = nmod_add(C[i][j], c, mod); @@ -130,7 +130,7 @@ _nmod_mat_addmul_transpose_worker(void * arg_ptr) static inline void _nmod_mat_addmul_transpose_threaded_pool_op(nn_ptr * D, const nn_ptr * C, const nn_ptr * A, const nn_ptr * B, slong m, - slong k, slong n, int op, nmod_t mod, int nlimbs, + slong k, slong n, int op, nmod_t mod, dot_params_t params, thread_pool_handle * threads, slong num_threads) { nn_ptr tmp; @@ -164,7 
+164,7 @@ _nmod_mat_addmul_transpose_threaded_pool_op(nn_ptr * D, const nn_ptr * C, args[i].k = k; args[i].m = m; args[i].n = n; - args[i].nlimbs = nlimbs; + args[i].params = params; args[i].A = A; args[i].C = C; args[i].D = D; @@ -312,7 +312,7 @@ _nmod_mat_addmul_packed_worker(void * arg_ptr) } } -/* Assumes nlimbs = 1 */ +/* Assumes nlimbs = 1 <-> params.method <= _DOT1 */ static void _nmod_mat_addmul_packed_threaded_pool_op(nn_ptr * D, const nn_ptr * C, const nn_ptr * A, const nn_ptr * B, @@ -420,7 +420,6 @@ _nmod_mat_mul_classical_threaded_pool_op(nmod_mat_t D, const nmod_mat_t C, thread_pool_handle * threads, slong num_threads) { slong m, k, n; - int nlimbs; nmod_t mod; mod = A->mod; @@ -428,20 +427,19 @@ _nmod_mat_mul_classical_threaded_pool_op(nmod_mat_t D, const nmod_mat_t C, k = A->c; n = B->c; - nlimbs = _nmod_vec_dot_bound_limbs(k, mod); + dot_params_t params = _nmod_vec_dot_params(k, mod); - if (nlimbs == 1 && m > 10 && k > 10 && n > 10) + if (params.method == _DOT0) + return; + if (params.method == _DOT1 && m > 10 && k > 10 && n > 10) { _nmod_mat_addmul_packed_threaded_pool_op(D->rows, (op == 0) ? NULL : C->rows, A->rows, B->rows, m, k, n, op, D->mod, threads, num_threads); } else { - if ((mod.n & (mod.n - 1)) == 0) - nlimbs = 1; - _nmod_mat_addmul_transpose_threaded_pool_op(D->rows, (op == 0) ? NULL : C->rows, - A->rows, B->rows, m, k, n, op, D->mod, nlimbs, threads, num_threads); + A->rows, B->rows, m, k, n, op, D->mod, params, threads, num_threads); } } @@ -466,10 +464,10 @@ _nmod_mat_mul_classical_threaded_op(nmod_mat_t D, const nmod_mat_t C, || A->c < NMOD_MAT_MUL_TRANSPOSE_CUTOFF || B->c < NMOD_MAT_MUL_TRANSPOSE_CUTOFF) { - slong nlimbs = _nmod_vec_dot_bound_limbs(A->c, D->mod); + dot_params_t params = _nmod_vec_dot_params(A->c, D->mod); _nmod_mat_addmul_basic_op(D->rows, (op == 0) ? 
NULL : C->rows, - A->rows, B->rows, A->r, A->c, B->c, op, D->mod, nlimbs); + A->rows, B->rows, A->r, A->c, B->c, op, D->mod, params); return; } diff --git a/src/nmod_mat/mul_nmod_vec.c b/src/nmod_mat/mul_nmod_vec.c index ab888196bd..a065b59496 100644 --- a/src/nmod_mat/mul_nmod_vec.c +++ b/src/nmod_mat/mul_nmod_vec.c @@ -9,7 +9,6 @@ (at your option) any later version. See . */ -#include "nmod.h" #include "nmod_vec.h" #include "nmod_mat.h" @@ -18,16 +17,11 @@ void nmod_mat_mul_nmod_vec( const nmod_mat_t A, const ulong * b, slong blen) { - nmod_t mod = A->mod; - slong i, j; - slong len = FLINT_MIN(A->c, blen); - int nlimbs = _nmod_vec_dot_bound_limbs(len, mod); + const slong len = FLINT_MIN(A->c, blen); + const dot_params_t params = _nmod_vec_dot_params(len, A->mod); - for (i = A->r - 1; i >= 0; i--) - { - const ulong * Ai = A->rows[i]; - NMOD_VEC_DOT(c[i], j, len, Ai[j], b[j], mod, nlimbs); - } + for (slong i = 0; i < A->r; i++) + c[i] = _nmod_vec_dot(A->rows[i], b, len, A->mod, params); } void nmod_mat_mul_nmod_vec_ptr( diff --git a/src/nmod_mat/solve_tril.c b/src/nmod_mat/solve_tril.c index d378c09d17..6c76f1a2e0 100644 --- a/src/nmod_mat/solve_tril.c +++ b/src/nmod_mat/solve_tril.c @@ -16,7 +16,6 @@ void nmod_mat_solve_tril_classical(nmod_mat_t X, const nmod_mat_t L, const nmod_mat_t B, int unit) { - int nlimbs; slong i, j, n, m; nmod_t mod; nn_ptr inv, tmp; @@ -34,7 +33,7 @@ nmod_mat_solve_tril_classical(nmod_mat_t X, const nmod_mat_t L, const nmod_mat_t else inv = NULL; - nlimbs = _nmod_vec_dot_bound_limbs(n, mod); + const dot_params_t params = _nmod_vec_dot_params(n, mod); tmp = _nmod_vec_init(n); for (i = 0; i < m; i++) @@ -45,7 +44,7 @@ nmod_mat_solve_tril_classical(nmod_mat_t X, const nmod_mat_t L, const nmod_mat_t for (j = 0; j < n; j++) { ulong s; - s = _nmod_vec_dot(L->rows[j], tmp, j, mod, nlimbs); + s = _nmod_vec_dot(L->rows[j], tmp, j, mod, params); s = nmod_sub(nmod_mat_entry(B, j, i), s, mod); if (!unit) s = n_mulmod2_preinv(s, inv[j], mod.n, 
mod.ninv); diff --git a/src/nmod_mat/solve_triu.c b/src/nmod_mat/solve_triu.c index 9a3e2cf05f..114afe165b 100644 --- a/src/nmod_mat/solve_triu.c +++ b/src/nmod_mat/solve_triu.c @@ -16,7 +16,6 @@ void nmod_mat_solve_triu_classical(nmod_mat_t X, const nmod_mat_t U, const nmod_mat_t B, int unit) { - int nlimbs; slong i, j, n, m; nmod_t mod; nn_ptr inv, tmp; @@ -34,7 +33,7 @@ nmod_mat_solve_triu_classical(nmod_mat_t X, const nmod_mat_t U, const nmod_mat_t else inv = NULL; - nlimbs = _nmod_vec_dot_bound_limbs(n, mod); + const dot_params_t params = _nmod_vec_dot_params(n, mod); tmp = _nmod_vec_init(n); for (i = 0; i < m; i++) @@ -46,7 +45,7 @@ nmod_mat_solve_triu_classical(nmod_mat_t X, const nmod_mat_t U, const nmod_mat_t { ulong s; s = _nmod_vec_dot(U->rows[j] + j + 1, - tmp + j + 1, n - j - 1, mod, nlimbs); + tmp + j + 1, n - j - 1, mod, params); s = nmod_sub(nmod_mat_entry(B, j, i), s, mod); if (!unit) s = n_mulmod2_preinv(s, inv[j], mod.n, mod.ninv); diff --git a/src/nmod_mpoly/interp.c b/src/nmod_mpoly/interp.c index 5545d25dc8..90807a5d51 100644 --- a/src/nmod_mpoly/interp.c +++ b/src/nmod_mpoly/interp.c @@ -393,7 +393,7 @@ int nmod_mpolyn_interp_crt_sm_bpoly( const nmod_mpoly_ctx_t ctx) { int changed = 0; - int nlimbs = _nmod_vec_dot_bound_limbs(modulus->length, ctx->mod); + const dot_params_t params = _nmod_vec_dot_params(modulus->length, ctx->mod); slong N = mpoly_words_per_exp(F->bits, ctx->minfo); slong off0, shift0, off1, shift1; n_poly_struct * Acoeffs = A->coeffs; @@ -440,7 +440,7 @@ int nmod_mpolyn_interp_crt_sm_bpoly( /* F term ok, A term ok */ mpoly_monomial_set(Texps + N*Ti, Fexps + N*Fi, N); - v = _n_poly_eval_pow(Fcoeffs + Fi, alphapow, nlimbs, ctx->mod); + v = _n_poly_eval_pow(Fcoeffs + Fi, alphapow, params, ctx->mod); v = nmod_sub(Acoeffs[Ai].coeffs[ai], v, ctx->mod); if (v != 0) { @@ -495,7 +495,7 @@ int nmod_mpolyn_interp_crt_sm_bpoly( /* F term ok, Aterm missing */ mpoly_monomial_set(Texps + N*Ti, Fexps + N*Fi, N); - v = 
_n_poly_eval_pow(Fcoeffs + Fi, alphapow, nlimbs, ctx->mod); + v = _n_poly_eval_pow(Fcoeffs + Fi, alphapow, params, ctx->mod); if (v != 0) { changed = 1; diff --git a/src/nmod_mpoly_factor/n_bpoly_mod_factor_lgprime.c b/src/nmod_mpoly_factor/n_bpoly_mod_factor_lgprime.c index 9fdf3dc349..457166fff2 100644 --- a/src/nmod_mpoly_factor/n_bpoly_mod_factor_lgprime.c +++ b/src/nmod_mpoly_factor/n_bpoly_mod_factor_lgprime.c @@ -641,11 +641,10 @@ static void _lattice( n_bpoly_t Q, R, dg; n_bpoly_struct * ld; nmod_mat_t M, T1, T2; - int nlimbs; ulong * trow; slong lift_order = lift_alpha_pow->length - 1; - nlimbs = _nmod_vec_dot_bound_limbs(r, ctx); + const dot_params_t params = _nmod_vec_dot_params(r, ctx); trow = (ulong *) flint_malloc(r*sizeof(ulong)); n_bpoly_init(Q); n_bpoly_init(R); @@ -683,7 +682,7 @@ static void _lattice( for (i = 0; i < d; i++) nmod_mat_entry(M, j - starts[k], i) = - _nmod_vec_dot(trow, N->rows[i], r, ctx, nlimbs); + _nmod_vec_dot(trow, N->rows[i], r, ctx, params); } nmod_mat_init_nullspace_tr(T1, M); diff --git a/src/nmod_mpoly_factor/n_bpoly_mod_factor_smprime.c b/src/nmod_mpoly_factor/n_bpoly_mod_factor_smprime.c index 01b887fea6..f0e219394e 100644 --- a/src/nmod_mpoly_factor/n_bpoly_mod_factor_smprime.c +++ b/src/nmod_mpoly_factor/n_bpoly_mod_factor_smprime.c @@ -1090,10 +1090,9 @@ static void _lattice( n_bpoly_t Q, R, dg; n_bpoly_struct * ld; nmod_mat_t M, T1, T2; - int nlimbs; ulong * trow; - nlimbs = _nmod_vec_dot_bound_limbs(r, ctx); + const dot_params_t params = _nmod_vec_dot_params(r, ctx); trow = FLINT_ARRAY_ALLOC(r, ulong); n_bpoly_init(Q); n_bpoly_init(R); @@ -1134,7 +1133,7 @@ static void _lattice( for (i = 0; i < nrows; i++) nmod_mat_entry(M, j - lower, i) = - _nmod_vec_dot(trow, N->rows[i], r, ctx, nlimbs); + _nmod_vec_dot(trow, N->rows[i], r, ctx, params); } nmod_mat_init_nullspace_tr(T1, M); diff --git a/src/nmod_poly/div_series.c b/src/nmod_poly/div_series.c index fdaec0500c..8d3e844e28 100644 --- a/src/nmod_poly/div_series.c +++ 
b/src/nmod_poly/div_series.c @@ -20,7 +20,6 @@ _nmod_poly_div_series_basecase_preinv1(nn_ptr Qinv, nn_srcptr P, slong Plen, nn_srcptr Q, slong Qlen, slong n, ulong q, nmod_t mod) { slong i, j, l; - int nlimbs; ulong s; Plen = FLINT_MIN(Plen, n); @@ -35,13 +34,13 @@ _nmod_poly_div_series_basecase_preinv1(nn_ptr Qinv, nn_srcptr P, slong Plen, { Qinv[0] = nmod_mul(q, P[0], mod); - nlimbs = _nmod_vec_dot_bound_limbs(FLINT_MIN(n, Qlen), mod); + const dot_params_t params = _nmod_vec_dot_params(FLINT_MIN(n, Qlen), mod); for (i = 1; i < n; i++) { l = FLINT_MIN(i, Qlen - 1); - NMOD_VEC_DOT(s, j, l, Q[j + 1], Qinv[i - 1 - j], mod, nlimbs); + NMOD_VEC_DOT(s, j, l, Q[j + 1], Qinv[i - 1 - j], mod, params); if (i < Plen) s = nmod_sub(P[i], s, mod); diff --git a/src/nmod_poly/divides.c b/src/nmod_poly/divides.c index d7771d6570..92b59e8fa3 100644 --- a/src/nmod_poly/divides.c +++ b/src/nmod_poly/divides.c @@ -129,7 +129,7 @@ static int _nmod_poly_mullow_classical_check(nn_srcptr p, nn_srcptr poly1, slong len1, nn_srcptr poly2, slong n, nmod_t mod) { - slong i, j, bits, log_len, nlimbs, n1; + slong i, j, bits, log_len, n1; ulong c; len1 = FLINT_MIN(len1, n); @@ -140,6 +140,7 @@ _nmod_poly_mullow_classical_check(nn_srcptr p, nn_srcptr poly1, slong len1, if (n == 1) return p[0] == nmod_mul(poly1[0], poly2[0], mod); + // TODO could what is below make more direct use of nmod_vec_dot? 
log_len = FLINT_BIT_COUNT(n); bits = FLINT_BITS - (slong) mod.norm; bits = 2 * bits + log_len; @@ -160,10 +161,11 @@ _nmod_poly_mullow_classical_check(nn_srcptr p, nn_srcptr poly1, slong len1, } } else { + dot_params_t params = {_DOT2, 0}; if (bits <= 2 * FLINT_BITS) - nlimbs = 2; + params.method = _DOT2; else - nlimbs = 3; + params.method = _DOT3; for (i = 0; i < n; i++) { @@ -171,7 +173,7 @@ _nmod_poly_mullow_classical_check(nn_srcptr p, nn_srcptr poly1, slong len1, c = _nmod_vec_dot_rev(poly1, poly2 + i - n1, - n1 + 1, mod, nlimbs); + n1 + 1, mod, params); if (p[i] != c) return 0; diff --git a/src/nmod_poly/inv_series.c b/src/nmod_poly/inv_series.c index edd460a2be..365835dc10 100644 --- a/src/nmod_poly/inv_series.c +++ b/src/nmod_poly/inv_series.c @@ -29,16 +29,14 @@ _nmod_poly_inv_series_basecase_preinv1(nn_ptr Qinv, nn_srcptr Q, slong Qlen, slo } else { - slong i, j, l; - int nlimbs; + slong i, l; ulong s; - - nlimbs = _nmod_vec_dot_bound_limbs(FLINT_MIN(n, Qlen), mod); + const dot_params_t params = _nmod_vec_dot_params(FLINT_MIN(n, Qlen) - 1, mod); for (i = 1; i < n; i++) { l = FLINT_MIN(i, Qlen - 1); - NMOD_VEC_DOT(s, j, l, Q[j + 1], Qinv[i - 1 - j], mod, nlimbs); + s = _nmod_vec_dot_rev(Q+1, Qinv + i-l, l, mod, params); if (q == 1) Qinv[i] = nmod_neg(s, mod); diff --git a/src/nmod_poly/mul_classical.c b/src/nmod_poly/mul_classical.c index eddb86f38a..13ba36430e 100644 --- a/src/nmod_poly/mul_classical.c +++ b/src/nmod_poly/mul_classical.c @@ -20,7 +20,7 @@ void _nmod_poly_mul_classical(nn_ptr res, nn_srcptr poly1, slong len1, nn_srcptr poly2, slong len2, nmod_t mod) { - slong i, j, bits, log_len, nlimbs, n1, n2; + slong i, j, n1, n2; int squaring; ulong c; @@ -38,11 +38,9 @@ _nmod_poly_mul_classical(nn_ptr res, nn_srcptr poly1, squaring = (poly1 == poly2 && len1 == len2); - log_len = FLINT_BIT_COUNT(len2); - bits = FLINT_BITS - (slong) mod.norm; - bits = 2 * bits + log_len; + const dot_params_t params = _nmod_vec_dot_params(FLINT_MIN(len1, len2), mod); - 
if (bits <= FLINT_BITS) + if (params.method <= _DOT1) { flint_mpn_zero(res, len1 + len2 - 1); @@ -82,11 +80,6 @@ _nmod_poly_mul_classical(nn_ptr res, nn_srcptr poly1, return; } - if (bits <= 2 * FLINT_BITS) - nlimbs = 2; - else - nlimbs = 3; - if (squaring) { for (i = 0; i < 2 * len1 - 1; i++) @@ -94,7 +87,7 @@ _nmod_poly_mul_classical(nn_ptr res, nn_srcptr poly1, n1 = FLINT_MAX(0, i - len1 + 1); n2 = FLINT_MIN(len1 - 1, (i + 1) / 2 - 1); - c = _nmod_vec_dot_rev(poly1 + n1, poly1 + i - n2, n2 - n1 + 1, mod, nlimbs); + c = _nmod_vec_dot_rev(poly1 + n1, poly1 + i - n2, n2 - n1 + 1, mod, params); c = nmod_add(c, c, mod); if (i % 2 == 0 && i / 2 < len1) @@ -112,7 +105,7 @@ _nmod_poly_mul_classical(nn_ptr res, nn_srcptr poly1, res[i] = _nmod_vec_dot_rev(poly1 + i - n2, poly2 + i - n1, - n1 + n2 - i + 1, mod, nlimbs); + n1 + n2 - i + 1, mod, params); } } } diff --git a/src/nmod_poly/mullow_classical.c b/src/nmod_poly/mullow_classical.c index fe4b5fee2c..fa24979ac9 100644 --- a/src/nmod_poly/mullow_classical.c +++ b/src/nmod_poly/mullow_classical.c @@ -20,7 +20,7 @@ void _nmod_poly_mullow_classical(nn_ptr res, nn_srcptr poly1, slong len1, nn_srcptr poly2, slong len2, slong n, nmod_t mod) { - slong i, j, bits, log_len, nlimbs, n1, n2; + slong i, j, bits, log_len, n1, n2; int squaring; ulong c; @@ -41,6 +41,7 @@ _nmod_poly_mullow_classical(nn_ptr res, nn_srcptr poly1, slong len1, squaring = (poly1 == poly2 && len1 == len2); + // TODO could what is below make more direct use of nmod_vec_dot? 
log_len = FLINT_BIT_COUNT(len2); bits = FLINT_BITS - (slong) mod.norm; bits = 2 * bits + log_len; @@ -88,10 +89,11 @@ _nmod_poly_mullow_classical(nn_ptr res, nn_srcptr poly1, slong len1, return; } + dot_params_t params = {_DOT2, 0}; if (bits <= 2 * FLINT_BITS) - nlimbs = 2; + params.method = _DOT2; else - nlimbs = 3; + params.method = _DOT3; if (squaring) { @@ -100,7 +102,7 @@ _nmod_poly_mullow_classical(nn_ptr res, nn_srcptr poly1, slong len1, n1 = FLINT_MAX(0, i - len1 + 1); n2 = FLINT_MIN(len1 - 1, (i + 1) / 2 - 1); - c = _nmod_vec_dot_rev(poly1 + n1, poly1 + i - n2, n2 - n1 + 1, mod, nlimbs); + c = _nmod_vec_dot_rev(poly1 + n1, poly1 + i - n2, n2 - n1 + 1, mod, params); c = nmod_add(c, c, mod); if (i % 2 == 0 && i / 2 < len1) @@ -118,7 +120,7 @@ _nmod_poly_mullow_classical(nn_ptr res, nn_srcptr poly1, slong len1, res[i] = _nmod_vec_dot_rev(poly1 + i - n2, poly2 + i - n1, - n1 + n2 - i + 1, mod, nlimbs); + n1 + n2 - i + 1, mod, params); } } } diff --git a/src/nmod_vec.h b/src/nmod_vec.h index fbc8e65ace..692fc760cb 100644 --- a/src/nmod_vec.h +++ b/src/nmod_vec.h @@ -1,6 +1,7 @@ /* Copyright (C) 2010 William Hart Copyright (C) 2021 Fredrik Johansson + Copyright (C) 2024 Vincent Neiger This file is part of FLINT. 
@@ -20,6 +21,7 @@ #endif #include "flint.h" +#include "nmod.h" // nmod_mul, nmod_fmma #ifdef __cplusplus extern "C" { @@ -97,6 +99,16 @@ int _nmod_vec_is_zero(nn_srcptr vec, slong len) return 1; } +/* some IO functions */ +#ifdef FLINT_HAVE_FILE +int _nmod_vec_fprint_pretty(FILE * file, nn_srcptr vec, slong len, nmod_t mod); +int _nmod_vec_fprint(FILE * f, nn_srcptr vec, slong len, nmod_t mod); +#endif + +void _nmod_vec_print_pretty(nn_srcptr vec, slong len, nmod_t mod); +int _nmod_vec_print(nn_srcptr vec, slong len, nmod_t mod); + +/* reduce, add, scalar mul */ void _nmod_vec_reduce(nn_ptr res, nn_srcptr vec, slong len, nmod_t mod); void _nmod_vec_add(nn_ptr res, nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod); @@ -108,83 +120,632 @@ void _nmod_vec_scalar_mul_nmod_shoup(nn_ptr res, nn_srcptr vec, slong len, ulong void _nmod_vec_scalar_addmul_nmod(nn_ptr res, nn_srcptr vec, slong len, ulong c, nmod_t mod); -int _nmod_vec_dot_bound_limbs(slong len, nmod_t mod); -#define NMOD_VEC_DOT(res, i, len, expr1, expr2, mod, nlimbs) \ - do \ - { \ - ulong s0, s1, s2, t0, t1; \ - s0 = s1 = s2 = UWORD(0); \ - switch (nlimbs) \ - { \ - case 1: \ - for (i = 0; i < (len); i++) \ - s0 += (expr1) * (expr2); \ - NMOD_RED(s0, s0, mod); \ - break; \ - case 2: \ - if (mod.n <= (UWORD(1) << (FLINT_BITS / 2))) \ - { \ - for (i = 0; i < (len); i++) \ - { \ - t0 = (expr1) * (expr2); \ - add_ssaaaa(s1, s0, s1, s0, 0, t0); \ - } \ - } \ - else if ((len) < 8) \ - { \ - for (i = 0; i < len; i++) \ - { \ - umul_ppmm(t1, t0, (expr1), (expr2)); \ - add_ssaaaa(s1, s0, s1, s0, t1, t0); \ - } \ - } \ - else \ - { \ - ulong v0, v1, u0, u1; \ - i = 0; \ - if ((len) & 1) \ - umul_ppmm(v1, v0, (expr1), (expr2)); \ - else \ - v0 = v1 = 0; \ - for (i = (len) & 1; i < (len); i++) \ - { \ - umul_ppmm(t1, t0, (expr1), (expr2)); \ - add_ssaaaa(s1, s0, s1, s0, t1, t0); \ - i++; \ - umul_ppmm(u1, u0, (expr1), (expr2)); \ - add_ssaaaa(v1, v0, v1, v0, u1, u0); \ - } \ - add_ssaaaa(s1, s0, s1, s0, v1, v0); \ 
- } \ - NMOD2_RED2(s0, s1, s0, mod); \ - break; \ - default: \ - for (i = 0; i < (len); i++) \ - { \ - umul_ppmm(t1, t0, (expr1), (expr2)); \ - add_sssaaaaaa(s2, s1, s0, s2, s1, s0, 0, t1, t0); \ - } \ - NMOD_RED(s2, s2, mod); \ - NMOD_RED3(s0, s2, s1, s0, mod); \ - break; \ - } \ - res = s0; \ - } while (0); - -ulong _nmod_vec_dot(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod, int nlimbs); -ulong _nmod_vec_dot_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod, int nlimbs); - -ulong _nmod_vec_dot_ptr(nn_srcptr vec1, const nn_ptr * vec2, slong offset, slong len, nmod_t mod, int nlimbs); +/* ---- compute dot parameters ---- */ -/* some IO functions */ -#ifdef FLINT_HAVE_FILE -int _nmod_vec_fprint_pretty(FILE * file, nn_srcptr vec, slong len, nmod_t mod); -int _nmod_vec_fprint(FILE * f, nn_srcptr vec, slong len, nmod_t mod); +typedef enum +{ + _DOT0 = 0, /* len == 0 || mod.n == 1 */ + _DOT1 = 1, /* 1 limb */ +#if (FLINT_BITS == 64) + _DOT2_SPLIT = 2, /* 2 limbs, modulus < ~2**30.5 (FLINT_BITS == 64 only) */ +#endif // FLINT_BITS == 64 + _DOT2_HALF = 3, /* 2 limbs, modulus < 2**(FLINT_BITS/2) */ + _DOT2 = 4, /* 2 limbs */ + _DOT3_ACC = 5, /* 3 limbs, modulus allowing some accumulation in 2 limbs */ + _DOT3 = 6, /* 3 limbs */ + _DOT_POW2 = 7, /* mod.n is a power of 2 */ +} dot_method_t; +// if mod.n is a power of 2, we use _DOT_POW2 in all cases +// otherwise, number of limbs of unreduced dot product can be deduced: +// 1 limb <=> method <= _DOT1 +// 2 limbs <=> _DOT1 < method <= _DOT2 +// 3 limbs <=> _DOT2 < method + +typedef struct +{ + dot_method_t method; + ulong pow2_precomp; /* for splitting: (1L << 56) % mod.n */ +} dot_params_t; + +// for _DOT2_SPLIT +#if (FLINT_BITS == 64) +# define DOT_SPLIT_BITS 56 +# define DOT_SPLIT_MASK UWORD(72057594037927935) // (1L << DOT_SPLIT_BITS) - 1 +#endif // FLINT_BITS == 64 + +#define _FIXED_LEN_MOD_BOUNDS(fixedlen, onelimb_bnd, twolimb_bnd) \ + if (len == fixedlen) \ + { \ + if (mod.n <= UWORD(onelimb_bnd)) \ + 
return (dot_params_t) {_DOT1, UWORD(0)}; \ + if (mod.n <= UWORD(twolimb_bnd)) \ + return (dot_params_t) {_DOT2, UWORD(0)}; \ + return (dot_params_t) {_DOT3, UWORD(0)}; \ + } + +FLINT_FORCE_INLINE dot_params_t _nmod_vec_dot_params(ulong len, nmod_t mod) +{ + if (len == 0 || mod.n == 1) + return (dot_params_t) {_DOT0, UWORD(0)}; + if ((mod.n & (mod.n - 1)) == 0) + return (dot_params_t) {_DOT_POW2, UWORD(0)}; + // from here on len >= 1, n > 1 not power of 2 + + // short dot products: we use only _DOT1, _DOT2, _DOT3 in that case + if (len <= 11) + { +#if FLINT_BITS == 64 + // 64 bits: k limbs <=> n <= ceil(2**(32*k) / sqrt(len)) + _FIXED_LEN_MOD_BOUNDS(11, 1294981365, 5561902608746059656); + _FIXED_LEN_MOD_BOUNDS(10, 1358187914, 5833372668713515885); + _FIXED_LEN_MOD_BOUNDS( 9, 1431655766, 6148914691236517206); + _FIXED_LEN_MOD_BOUNDS( 8, 1518500250, 6521908912666391107); + _FIXED_LEN_MOD_BOUNDS( 7, 1623345051, 6972213902555716131); + _FIXED_LEN_MOD_BOUNDS( 6, 1753413057, 7530851732716320753); + _FIXED_LEN_MOD_BOUNDS( 5, 1920767767, 8249634742471189718); + _FIXED_LEN_MOD_BOUNDS( 4, 2147483648, 9223372036854775808); + _FIXED_LEN_MOD_BOUNDS( 3, 2479700525, 10650232656628343402); + _FIXED_LEN_MOD_BOUNDS( 2, 3037000500, 13043817825332782213); +#else // FLINT_BITS == 64 + // 32 bits: k limbs <=> n <= ceil(2**(16*k) / sqrt(len)) + _FIXED_LEN_MOD_BOUNDS(11, 19760, 1294981365); + _FIXED_LEN_MOD_BOUNDS(10, 20725, 1358187914); + _FIXED_LEN_MOD_BOUNDS( 9, 21846, 1431655766); + _FIXED_LEN_MOD_BOUNDS( 8, 23171, 1518500250); + _FIXED_LEN_MOD_BOUNDS( 7, 24771, 1623345051); + _FIXED_LEN_MOD_BOUNDS( 6, 26755, 1753413057); + _FIXED_LEN_MOD_BOUNDS( 5, 29309, 1920767767); + _FIXED_LEN_MOD_BOUNDS( 4, 32768, 2147483648); + _FIXED_LEN_MOD_BOUNDS( 3, 37838, 2479700525); + _FIXED_LEN_MOD_BOUNDS( 2, 46341, 3037000500); +#endif // FLINT_BITS == 64 + // remains len == 1 + if (mod.n <= (UWORD(1) << FLINT_BITS / 2)) + return (dot_params_t) {_DOT1, UWORD(0)}; + return (dot_params_t) {_DOT2, 
UWORD(0)}; + } + + if (mod.n <= UWORD(1) << (FLINT_BITS / 2)) // implies <= 2 limbs + { + const ulong t0 = (mod.n - 1) * (mod.n - 1); + ulong u1, u0; + umul_ppmm(u1, u0, t0, len); + if (u1 == 0) // 1 limb + return (dot_params_t) {_DOT1, UWORD(0)}; + + // u1 != 0 <=> 2 limbs +#if (FLINT_BITS == 64) // _SPLIT: see end of file for these constraints + if (mod.n <= UWORD(1515531528) && len <= WORD(380368697)) + { + ulong pow2_precomp; + NMOD_RED(pow2_precomp, (UWORD(1) << DOT_SPLIT_BITS), mod); + return (dot_params_t) {_DOT2_SPLIT, pow2_precomp}; + } #endif + return (dot_params_t) {_DOT2_HALF, UWORD(0)}; + } + // from here on, mod.n > 2**(FLINT_BITS / 2) + // --> unreduced dot cannot fit in 1 limb + + ulong t2, t1, t0, u1, u0; + umul_ppmm(t1, t0, mod.n - 1, mod.n - 1); + umul_ppmm(t2, t1, t1, len); + umul_ppmm(u1, u0, t0, len); + add_ssaaaa(t2, t1, t2, t1, UWORD(0), u1); + + if (t2 == 0) // 2 limbs + return (dot_params_t) {_DOT2, UWORD(0)}; + + // 3 limbs: +#if (FLINT_BITS == 64) + if (mod.n <= UWORD(6521908912666391107)) // room for accumulating 8 terms +#else + if (mod.n <= UWORD(1518500250)) // room for accumulating 8 terms +#endif + return (dot_params_t) {_DOT3_ACC, UWORD(0)}; + + return (dot_params_t) {_DOT3, UWORD(0)}; +} + +#undef _FIXED_LEN_MOD_BOUNDS + +int _nmod_vec_dot_bound_limbs(slong len, nmod_t mod); +int _nmod_vec_dot_bound_limbs_from_params(slong len, nmod_t mod, dot_params_t params); + + +/* ------ dot product, specific algorithms ------ */ + +/* vec1[i] * vec2[i] */ +ulong _nmod_vec_dot_pow2(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod); +ulong _nmod_vec_dot1(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod); +ulong _nmod_vec_dot2_half(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod); +ulong _nmod_vec_dot2(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod); +ulong _nmod_vec_dot3_acc(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod); +ulong _nmod_vec_dot3(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod); +#if 
FLINT_BITS == 64 +ulong _nmod_vec_dot2_split(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod, ulong pow2_precomp); +#endif // FLINT_BITS == 64 + +/* vec1[i] * vec2[len-1-i] */ +ulong _nmod_vec_dot_pow2_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod); +ulong _nmod_vec_dot1_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod); +ulong _nmod_vec_dot2_half_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod); +ulong _nmod_vec_dot2_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod); +ulong _nmod_vec_dot3_acc_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod); +ulong _nmod_vec_dot3_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod); +#if FLINT_BITS == 64 +ulong _nmod_vec_dot2_split_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod, ulong pow2_precomp); +#endif // FLINT_BITS == 64 + +/* vec1[i] * vec2[i][offset] */ +ulong _nmod_vec_dot_pow2_ptr(nn_srcptr vec1, const nn_ptr * vec2, slong offset, slong len, nmod_t mod); +ulong _nmod_vec_dot1_ptr(nn_srcptr vec1, const nn_ptr * vec2, slong offset, slong len, nmod_t mod); +ulong _nmod_vec_dot2_half_ptr(nn_srcptr vec1, const nn_ptr * vec2, slong offset, slong len, nmod_t mod); +ulong _nmod_vec_dot2_ptr(nn_srcptr vec1, const nn_ptr * vec2, slong offset, slong len, nmod_t mod); +ulong _nmod_vec_dot3_acc_ptr(nn_srcptr vec1, const nn_ptr * vec2, slong offset, slong len, nmod_t mod); +ulong _nmod_vec_dot3_ptr(nn_srcptr vec1, const nn_ptr * vec2, slong offset, slong len, nmod_t mod); +#if FLINT_BITS == 64 +ulong _nmod_vec_dot2_split_ptr(nn_srcptr vec1, const nn_ptr * vec2, slong offset, slong len, nmod_t mod, ulong pow2_precomp); +#endif // FLINT_BITS == 64 + + + + + +/* ------------- dot product, general --------------- */ + +// auxiliary for short dot products +// (fixedlen small constant: compiler unrolls the loops completely) +#define _NMOD_VEC_DOT_SHORT1(fixedlen,expr1,expr2) \ + { \ + ulong res = (expr1) * (expr2); i++; \ + for (slong j = 0; j < fixedlen-1; j++, i++) \ 
+ res += (expr1) * (expr2); \ + NMOD_RED(res, res, mod); \ + return res; \ + } \ + +#define _NMOD_VEC_DOT_SHORT2(fixedlen,expr1,expr2) \ + { \ + ulong s0, s1, u0, u1; \ + umul_ppmm(u1, u0, (expr1), (expr2)); i++; \ + for (slong j = 0; j < fixedlen-1; j++, i++) \ + { \ + umul_ppmm(s1, s0, (expr1), (expr2)); \ + add_ssaaaa(u1, u0, u1, u0, s1, s0); \ + } \ + NMOD2_RED2(s0, u1, u0, mod); \ + return s0; \ + } \ + +#define _NMOD_VEC_DOT_SHORT3(fixedlen,expr1,expr2) \ + { \ + ulong t2 = UWORD(0); \ + ulong t1, t0; \ + umul_ppmm(t1, t0, (expr1), (expr2)); i++; \ + for (slong j = 0; j < fixedlen - 1; j++, i++) \ + { \ + ulong s0, s1; \ + umul_ppmm(s1, s0, (expr1), (expr2)); \ + add_sssaaaaaa(t2, t1, t0, \ + t2, t1, t0, \ + UWORD(0), s1, s0); \ + } \ + \ + NMOD_RED(t2, t2, mod); \ + ulong res; \ + NMOD_RED3(res, t2, t1, t0, mod); \ + return res; \ + } \ + +// * supports 1 <= len <= 11, requires method==DOT1|DOT2|DOT3|DOT_POW2 +// * i must be already initialized at the first wanted value +#define _NMOD_VEC_DOT_SHORT(i, expr1, expr2, len, mod, method) \ +{ \ + if (method == _DOT1 || method == _DOT_POW2) \ + { \ + if (len == 1) _NMOD_VEC_DOT_SHORT1( 1, expr1, expr2) \ + if (len == 2) _NMOD_VEC_DOT_SHORT1( 2, expr1, expr2) \ + if (len == 3) _NMOD_VEC_DOT_SHORT1( 3, expr1, expr2) \ + if (len == 4) _NMOD_VEC_DOT_SHORT1( 4, expr1, expr2) \ + if (len == 5) _NMOD_VEC_DOT_SHORT1( 5, expr1, expr2) \ + if (len == 6) _NMOD_VEC_DOT_SHORT1( 6, expr1, expr2) \ + if (len == 7) _NMOD_VEC_DOT_SHORT1( 7, expr1, expr2) \ + if (len == 8) _NMOD_VEC_DOT_SHORT1( 8, expr1, expr2) \ + if (len == 9) _NMOD_VEC_DOT_SHORT1( 9, expr1, expr2) \ + if (len == 10) _NMOD_VEC_DOT_SHORT1(10, expr1, expr2) \ + _NMOD_VEC_DOT_SHORT1(11, expr1, expr2) \ + } \ + \ + else if (method == _DOT2) \ + { \ + if (len == 1) return nmod_mul((expr1), (expr2), mod); \ + if (len == 2) _NMOD_VEC_DOT_SHORT2( 2, expr1, expr2) \ + if (len == 3) _NMOD_VEC_DOT_SHORT2( 3, expr1, expr2) \ + if (len == 4) _NMOD_VEC_DOT_SHORT2( 4, expr1, 
expr2) \ + if (len == 5) _NMOD_VEC_DOT_SHORT2( 5, expr1, expr2) \ + if (len == 6) _NMOD_VEC_DOT_SHORT2( 6, expr1, expr2) \ + if (len == 7) _NMOD_VEC_DOT_SHORT2( 7, expr1, expr2) \ + if (len == 8) _NMOD_VEC_DOT_SHORT2( 8, expr1, expr2) \ + if (len == 9) _NMOD_VEC_DOT_SHORT2( 9, expr1, expr2) \ + if (len == 10) _NMOD_VEC_DOT_SHORT2(10, expr1, expr2) \ + _NMOD_VEC_DOT_SHORT2(11, expr1, expr2) \ + } \ + \ + else if (method == _DOT3) \ + { \ + if (len == 1) return nmod_mul((expr1), (expr2), mod); \ + if (len == 2) _NMOD_VEC_DOT_SHORT3( 2, expr1, expr2) \ + if (len == 3) _NMOD_VEC_DOT_SHORT3( 3, expr1, expr2) \ + if (len == 4) _NMOD_VEC_DOT_SHORT3( 4, expr1, expr2) \ + if (len == 5) _NMOD_VEC_DOT_SHORT3( 5, expr1, expr2) \ + if (len == 6) _NMOD_VEC_DOT_SHORT3( 6, expr1, expr2) \ + if (len == 7) _NMOD_VEC_DOT_SHORT3( 7, expr1, expr2) \ + if (len == 8) _NMOD_VEC_DOT_SHORT3( 8, expr1, expr2) \ + if (len == 9) _NMOD_VEC_DOT_SHORT3( 9, expr1, expr2) \ + if (len == 10) _NMOD_VEC_DOT_SHORT3(10, expr1, expr2) \ + _NMOD_VEC_DOT_SHORT3(11, expr1, expr2) \ + } \ +} while(0); \ + +FLINT_FORCE_INLINE ulong _nmod_vec_dot(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod, dot_params_t params) +{ + if (len <= 11) + { + if (len == 0) return UWORD(0); + slong i = 0; + _NMOD_VEC_DOT_SHORT(i, vec1[i], vec2[i], len, mod, params.method); + } + + if (params.method == _DOT1) + return _nmod_vec_dot1(vec1, vec2, len, mod); + +#if FLINT_BITS == 64 + if (params.method == _DOT2_SPLIT) + return _nmod_vec_dot2_split(vec1, vec2, len, mod, params.pow2_precomp); +#endif // FLINT_BITS == 64 + + if (params.method == _DOT2) + return _nmod_vec_dot2(vec1, vec2, len, mod); + + if (params.method == _DOT3_ACC) + return _nmod_vec_dot3_acc(vec1, vec2, len, mod); + + if (params.method == _DOT3) + return _nmod_vec_dot3(vec1, vec2, len, mod); + + if (params.method == _DOT2_HALF) + return _nmod_vec_dot2_half(vec1, vec2, len, mod); + + if (params.method == _DOT_POW2) + { + if (mod.n <= UWORD(1) << (FLINT_BITS / 2)) 
+ return _nmod_vec_dot1(vec1, vec2, len, mod); + else + return _nmod_vec_dot_pow2(vec1, vec2, len, mod); + } + + // covers _DOT0 for len > 11 (i.e. mod.n == 1...) + return UWORD(0); +} + +FLINT_FORCE_INLINE ulong _nmod_vec_dot_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod, dot_params_t params) +{ + if (len <= 11) + { + if (len == 0) return UWORD(0); + slong i = 0; + _NMOD_VEC_DOT_SHORT(i, vec1[i], vec2[len-1-i], len, mod, params.method); + } + + if (params.method == _DOT1) + return _nmod_vec_dot1_rev(vec1, vec2, len, mod); + +#if FLINT_BITS == 64 + if (params.method == _DOT2_SPLIT) + return _nmod_vec_dot2_split_rev(vec1, vec2, len, mod, params.pow2_precomp); +#endif // FLINT_BITS == 64 + + if (params.method == _DOT2) + return _nmod_vec_dot2_rev(vec1, vec2, len, mod); + + if (params.method == _DOT3_ACC) + return _nmod_vec_dot3_acc_rev(vec1, vec2, len, mod); + + if (params.method == _DOT3) + return _nmod_vec_dot3_rev(vec1, vec2, len, mod); + + if (params.method == _DOT2_HALF) + return _nmod_vec_dot2_half_rev(vec1, vec2, len, mod); + + if (params.method == _DOT_POW2) + { + if (mod.n <= UWORD(1) << (FLINT_BITS / 2)) + return _nmod_vec_dot1_rev(vec1, vec2, len, mod); + else + return _nmod_vec_dot_pow2_rev(vec1, vec2, len, mod); + } + + // covers _DOT0 for len > 11 (i.e. mod.n == 1...) 
+ return UWORD(0); +} + +FLINT_FORCE_INLINE ulong _nmod_vec_dot_ptr(nn_srcptr vec1, const nn_ptr * vec2, slong offset, slong len, nmod_t mod, dot_params_t params) +{ + if (len <= 11) + { + if (len == 0) return UWORD(0); + slong i = 0; + _NMOD_VEC_DOT_SHORT(i, vec1[i], vec2[i][offset], len, mod, params.method); + } + + if (params.method == _DOT1) + return _nmod_vec_dot1_ptr(vec1, vec2, offset, len, mod); + +#if FLINT_BITS == 64 + if (params.method == _DOT2_SPLIT) + return _nmod_vec_dot2_split_ptr(vec1, vec2, offset, len, mod, params.pow2_precomp); +#endif // FLINT_BITS == 64 + + if (params.method == _DOT2) + return _nmod_vec_dot2_ptr(vec1, vec2, offset, len, mod); + + if (params.method == _DOT3_ACC) + return _nmod_vec_dot3_acc_ptr(vec1, vec2, offset, len, mod); + + if (params.method == _DOT3) + return _nmod_vec_dot3_ptr(vec1, vec2, offset, len, mod); + + if (params.method == _DOT2_HALF) + return _nmod_vec_dot2_half_ptr(vec1, vec2, offset, len, mod); + + if (params.method == _DOT_POW2) + { + if (mod.n <= UWORD(1) << (FLINT_BITS / 2)) + return _nmod_vec_dot1_ptr(vec1, vec2, offset, len, mod); + else + return _nmod_vec_dot_pow2_ptr(vec1, vec2, offset, len, mod); + } + + // covers _DOT0 for len > 11 (i.e. mod.n == 1...) 
+ return UWORD(0); +} + +#undef _NMOD_VEC_DOT_SHORT1 +#undef _NMOD_VEC_DOT_SHORT2 +#undef _NMOD_VEC_DOT_SHORT3 + + +/* ---- macros for dot product with expressions, specific algorithms ---- */ + +// _DOT1 (1 limb) +#define _NMOD_VEC_DOT1(res, i, len, expr1, expr2, mod) \ +do \ +{ \ + res = UWORD(0); \ + for (i = 0; i < (len); i++) \ + res += (expr1) * (expr2); \ + NMOD_RED(res, res, mod); \ +} while(0); + +// _DOT2_SPLIT (2 limbs, splitting at DOT_SPLIT_BITS bits, 8-unrolling) +#if (FLINT_BITS == 64) +#define _NMOD_VEC_DOT2_SPLIT(res, i, len, expr1, expr2, mod, pow2_precomp) \ +do \ +{ \ + ulong dp_lo = 0; \ + ulong dp_hi = 0; \ + \ + for (i = 0; i+7 < (len); ) \ + { \ + dp_lo += (expr1) * (expr2); i++; \ + dp_lo += (expr1) * (expr2); i++; \ + dp_lo += (expr1) * (expr2); i++; \ + dp_lo += (expr1) * (expr2); i++; \ + dp_lo += (expr1) * (expr2); i++; \ + dp_lo += (expr1) * (expr2); i++; \ + dp_lo += (expr1) * (expr2); i++; \ + dp_lo += (expr1) * (expr2); i++; \ + \ + dp_hi += dp_lo >> DOT_SPLIT_BITS; \ + dp_lo &= DOT_SPLIT_MASK; \ + } \ + \ + for ( ; i < (len); i++) \ + dp_lo += (expr1) * (expr2); \ + \ + res = pow2_precomp * dp_hi + dp_lo; \ + NMOD_RED(res, res, mod); \ +} while(0); +#endif // FLINT_BITS == 64 + +// _DOT2_HALF (two limbs, modulus < 2**32) +// mod.n is too close to 2**32 to accumulate in some ulong +// still interesting: a bit faster than _NMOD_VEC_DOT2 +#define _NMOD_VEC_DOT2_HALF(res, i, len, expr1, expr2, mod) \ +do \ +{ \ + ulong s0zz = UWORD(0); \ + ulong s1zz = UWORD(0); \ + for (i = 0; i < (len); i++) \ + { \ + const ulong prodzz = (expr1) * (expr2); \ + add_ssaaaa(s1zz, s0zz, s1zz, s0zz, 0, prodzz); \ + } \ + NMOD2_RED2(res, s1zz, s0zz, mod); \ +} while(0); + +// _DOT2 (two limbs, general) +#define _NMOD_VEC_DOT2(res, i, len, expr1, expr2, mod) \ +do \ +{ \ + ulong u0zz = UWORD(0); \ + ulong u1zz = UWORD(0); \ + \ + for (i = 0; i+7 < (len); ) \ + { \ + ulong s0zz, s1zz; \ + umul_ppmm(s1zz, s0zz, (expr1), (expr2)); \ + add_ssaaaa(u1zz, u0zz, 
u1zz, u0zz, s1zz, s0zz); \ + i++; \ + umul_ppmm(s1zz, s0zz, (expr1), (expr2)); \ + add_ssaaaa(u1zz, u0zz, u1zz, u0zz, s1zz, s0zz); \ + i++; \ + umul_ppmm(s1zz, s0zz, (expr1), (expr2)); \ + add_ssaaaa(u1zz, u0zz, u1zz, u0zz, s1zz, s0zz); \ + i++; \ + umul_ppmm(s1zz, s0zz, (expr1), (expr2)); \ + add_ssaaaa(u1zz, u0zz, u1zz, u0zz, s1zz, s0zz); \ + i++; \ + umul_ppmm(s1zz, s0zz, (expr1), (expr2)); \ + add_ssaaaa(u1zz, u0zz, u1zz, u0zz, s1zz, s0zz); \ + i++; \ + umul_ppmm(s1zz, s0zz, (expr1), (expr2)); \ + add_ssaaaa(u1zz, u0zz, u1zz, u0zz, s1zz, s0zz); \ + i++; \ + umul_ppmm(s1zz, s0zz, (expr1), (expr2)); \ + add_ssaaaa(u1zz, u0zz, u1zz, u0zz, s1zz, s0zz); \ + i++; \ + umul_ppmm(s1zz, s0zz, (expr1), (expr2)); \ + add_ssaaaa(u1zz, u0zz, u1zz, u0zz, s1zz, s0zz); \ + i++; \ + } \ + for ( ; i < (len); i++) \ + { \ + ulong s0zz, s1zz; \ + umul_ppmm(s1zz, s0zz, (expr1), (expr2)); \ + add_ssaaaa(u1zz, u0zz, u1zz, u0zz, s1zz, s0zz); \ + } \ + \ + NMOD2_RED2(res, u1zz, u0zz, mod); \ +} while(0); + +// _DOT3_ACC (three limbs, delayed accumulations) +// 8-unroll: requires 8 * (mod.n - 1)**2 < 2**128 +#define _NMOD_VEC_DOT3_ACC(res, i, len, expr1, expr2, mod) \ +do \ +{ \ + ulong t2zz = UWORD(0); \ + ulong t1zz = UWORD(0); \ + ulong t0zz = UWORD(0); \ + \ + for (i = 0; i+7 < (len); ) \ + { \ + ulong s0zz, s1zz; \ + ulong u0zz = UWORD(0); \ + ulong u1zz = UWORD(0); \ + umul_ppmm(s1zz, s0zz, (expr1), (expr2)); \ + add_ssaaaa(u1zz, u0zz, u1zz, u0zz, s1zz, s0zz); \ + i++; \ + umul_ppmm(s1zz, s0zz, (expr1), (expr2)); \ + add_ssaaaa(u1zz, u0zz, u1zz, u0zz, s1zz, s0zz); \ + i++; \ + umul_ppmm(s1zz, s0zz, (expr1), (expr2)); \ + add_ssaaaa(u1zz, u0zz, u1zz, u0zz, s1zz, s0zz); \ + i++; \ + umul_ppmm(s1zz, s0zz, (expr1), (expr2)); \ + add_ssaaaa(u1zz, u0zz, u1zz, u0zz, s1zz, s0zz); \ + i++; \ + umul_ppmm(s1zz, s0zz, (expr1), (expr2)); \ + add_ssaaaa(u1zz, u0zz, u1zz, u0zz, s1zz, s0zz); \ + i++; \ + umul_ppmm(s1zz, s0zz, (expr1), (expr2)); \ + add_ssaaaa(u1zz, u0zz, u1zz, u0zz, s1zz, s0zz); \ 
+ i++; \ + umul_ppmm(s1zz, s0zz, (expr1), (expr2)); \ + add_ssaaaa(u1zz, u0zz, u1zz, u0zz, s1zz, s0zz); \ + i++; \ + umul_ppmm(s1zz, s0zz, (expr1), (expr2)); \ + add_ssaaaa(u1zz, u0zz, u1zz, u0zz, s1zz, s0zz); \ + i++; \ + add_sssaaaaaa(t2zz, t1zz, t0zz, \ + t2zz, t1zz, t0zz, \ + UWORD(0), u1zz, u0zz); \ + } \ + \ + ulong s0zz, s1zz; \ + ulong u0zz = UWORD(0); \ + ulong u1zz = UWORD(0); \ + for ( ; i < (len); i++) \ + { \ + umul_ppmm(s1zz, s0zz, (expr1), (expr2)); \ + add_ssaaaa(u1zz, u0zz, u1zz, u0zz, s1zz, s0zz); \ + } \ + \ + add_sssaaaaaa(t2zz, t1zz, t0zz, \ + t2zz, t1zz, t0zz, \ + UWORD(0), u1zz, u0zz); \ + \ + NMOD_RED(t2zz, t2zz, mod); \ + NMOD_RED3(res, t2zz, t1zz, t0zz, mod); \ +} while(0); + +// _DOT3 (three limbs, general) +// mod.n is too close to 2**64 to accumulate in two words +#define _NMOD_VEC_DOT3(res, i, len, expr1, expr2, mod) \ +do \ +{ \ + ulong t2zz = UWORD(0); \ + ulong t1zz = UWORD(0); \ + ulong t0zz = UWORD(0); \ + for (i = 0; i < (len); i++) \ + { \ + ulong s0zz, s1zz; \ + umul_ppmm(s1zz, s0zz, (expr1), (expr2)); \ + add_sssaaaaaa(t2zz, t1zz, t0zz, \ + t2zz, t1zz, t0zz, \ + UWORD(0), s1zz, s0zz); \ + } \ + \ + NMOD_RED(t2zz, t2zz, mod); \ + NMOD_RED3(res, t2zz, t1zz, t0zz, mod); \ +} while(0); + +/* ---- macros for dot product with expressions, general ---- */ +// currently no vectorization here + +#if (FLINT_BITS == 64) + +#define NMOD_VEC_DOT(res, i, len, expr1, expr2, mod, params) \ +do \ +{ \ + res = UWORD(0); /* covers _DOT0 */ \ + if (params.method == _DOT1 || params.method == _DOT_POW2) \ + _NMOD_VEC_DOT1(res, i, len, expr1, expr2, mod) \ + else if (params.method == _DOT2_SPLIT) \ + _NMOD_VEC_DOT2_SPLIT(res, i, len, expr1, expr2, mod, \ + params.pow2_precomp) \ + else if (params.method == _DOT2_HALF) \ + _NMOD_VEC_DOT2_HALF(res, i, len, expr1, expr2, mod) \ + else if (params.method == _DOT2) \ + _NMOD_VEC_DOT2(res, i, len, expr1, expr2, mod) \ + else if (params.method == _DOT3_ACC) \ + _NMOD_VEC_DOT3_ACC(res, i, len, expr1, expr2, 
mod) \ + else if (params.method == _DOT3) \ + _NMOD_VEC_DOT3(res, i, len, expr1, expr2, mod) \ +} while(0); + +#else // FLINT_BITS == 64 + +#define NMOD_VEC_DOT(res, i, len, expr1, expr2, mod, params) \ +do \ +{ \ + res = UWORD(0); /* covers _DOT0 */ \ + if (params.method == _DOT1 || params.method == _DOT_POW2) \ + _NMOD_VEC_DOT1(res, i, len, expr1, expr2, mod) \ + else if (params.method == _DOT2_HALF) \ + _NMOD_VEC_DOT2_HALF(res, i, len, expr1, expr2, mod) \ + else if (params.method == _DOT2) \ + _NMOD_VEC_DOT2(res, i, len, expr1, expr2, mod) \ + else if (params.method == _DOT3_ACC) \ + _NMOD_VEC_DOT3_ACC(res, i, len, expr1, expr2, mod) \ + else if (params.method == _DOT3) \ + _NMOD_VEC_DOT3(res, i, len, expr1, expr2, mod) \ +} while(0); + +#endif // FLINT_BITS == 64 -void _nmod_vec_print_pretty(nn_srcptr vec, slong len, nmod_t mod); -int _nmod_vec_print(nn_srcptr vec, slong len, nmod_t mod); #ifdef __cplusplus } diff --git a/src/nmod_vec/dot.c b/src/nmod_vec/dot.c index 9317b172a8..5c185c22ec 100644 --- a/src/nmod_vec/dot.c +++ b/src/nmod_vec/dot.c @@ -1,5 +1,6 @@ /* Copyright (C) 2011, 2021 Fredrik Johansson + Copyright (C) 2024 Vincent Neiger This file is part of FLINT. 
@@ -12,63 +13,555 @@ #include "nmod.h" #include "nmod_vec.h" -ulong -_nmod_vec_dot(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod, int nlimbs) -{ - ulong res; - slong i; - NMOD_VEC_DOT(res, i, len, vec1[i], vec2[i], mod, nlimbs); - return res; -} +// currently only vectorized for AVX2 +#if (defined(__AVX2__) && FLINT_BITS == 64) +# include "machine_vectors.h" +#endif // if defined(__AVX2__) -int -_nmod_vec_dot_bound_limbs(slong len, nmod_t mod) +int _nmod_vec_dot_bound_limbs(slong len, nmod_t mod) { - ulong t2, t1, t0, u1, u0; + if (mod.n <= UWORD(1) << (FLINT_BITS / 2)) // implies <= 2 limbs + { + const ulong t0 = (mod.n - 1) * (mod.n - 1); + ulong u1, u0; + umul_ppmm(u1, u0, t0, len); + if (u1 != 0) + return 2; + return (u0 != 0); + } + ulong t2, t1, t0, u1, u0; umul_ppmm(t1, t0, mod.n - 1, mod.n - 1); umul_ppmm(t2, t1, t1, len); umul_ppmm(u1, u0, t0, len); - add_ssaaaa(t2, t1, t2, t1, UWORD(0), u1); + add_sssaaaaaa(t2, t1, t0, t2, t1, UWORD(0), UWORD(0), u1, u0); + + if (t2 != 0) + return 3; + if (t1 != 0) + return 2; + return (t0 != 0); +} + +int _nmod_vec_dot_bound_limbs_from_params(slong len, nmod_t mod, dot_params_t params) +{ + if (params.method == _DOT_POW2) + return _nmod_vec_dot_bound_limbs(len, mod); + if (params.method == _DOT0) + return 0; + if (params.method <= _DOT1) + return 1; + if (params.method <= _DOT2) + return 2; + return 3; +} + +/*-------------------------------------------*/ +/* dot product: vec1[i] * vec2[i] */ +/*-------------------------------------------*/ + +ulong _nmod_vec_dot_pow2(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod) +{ + ulong res; slong i; + _NMOD_VEC_DOT1(res, i, len, vec1[i], vec2[i], mod) + return res; +} + +ulong _nmod_vec_dot1(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod) +#if defined(__AVX2__) && FLINT_BITS == 64 +{ + vec4n dp = vec4n_zero(); - if (t2 != 0) return 3; - if (t1 != 0) return 2; - return (u0 != 0); + slong i = 0; + for ( ; i+31 < len; i += 32) + { + dp = vec4n_add(dp, 
vec4n_mul(vec4n_load_unaligned(vec1+i+ 0), vec4n_load_unaligned(vec2+i+ 0))); + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i+ 4), vec4n_load_unaligned(vec2+i+ 4))); + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i+ 8), vec4n_load_unaligned(vec2+i+ 8))); + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i+12), vec4n_load_unaligned(vec2+i+12))); + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i+16), vec4n_load_unaligned(vec2+i+16))); + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i+20), vec4n_load_unaligned(vec2+i+20))); + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i+24), vec4n_load_unaligned(vec2+i+24))); + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i+28), vec4n_load_unaligned(vec2+i+28))); + } + + for ( ; i + 3 < len; i += 4) + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i), vec4n_load_unaligned(vec2+i))); + + ulong res = vec4n_horizontal_sum(dp); + + for (; i < len; i++) + res += vec1[i] * vec2[i]; + + NMOD_RED(res, res, mod); + return res; } +#else // if defined(__AVX2__) && FLINT_BITS == 64 +{ + ulong res; slong i; + _NMOD_VEC_DOT1(res, i, len, vec1[i], vec2[i], mod) + return res; +} +#endif // if defined(__AVX2__) && FLINT_BITS == 64 -ulong -_nmod_vec_dot_ptr(nn_srcptr vec1, const nn_ptr * vec2, slong offset, - slong len, nmod_t mod, int nlimbs) +#if FLINT_BITS == 64 +ulong _nmod_vec_dot2_split(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod, ulong pow2_precomp) +#if defined(__AVX2__) { + const vec4n low_bits = vec4n_set_n(DOT_SPLIT_MASK); + vec4n dp_lo = vec4n_zero(); + vec4n dp_hi = vec4n_zero(); + + slong i = 0; + for ( ; i+31 < len; i += 32) + { + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+ 0), vec4n_load_unaligned(vec2+i+ 0))); + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+ 4), vec4n_load_unaligned(vec2+i+ 4))); + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+ 8), vec4n_load_unaligned(vec2+i+ 8))); + dp_lo = 
vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+12), vec4n_load_unaligned(vec2+i+12))); + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+16), vec4n_load_unaligned(vec2+i+16))); + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+20), vec4n_load_unaligned(vec2+i+20))); + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+24), vec4n_load_unaligned(vec2+i+24))); + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+28), vec4n_load_unaligned(vec2+i+28))); + + dp_hi = vec4n_add(dp_hi, vec4n_bit_shift_right(dp_lo, DOT_SPLIT_BITS)); + dp_lo = vec4n_bit_and(dp_lo, low_bits); + } + + for ( ; i + 3 < len; i += 4) + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i), vec4n_load_unaligned(vec2+i))); + + dp_hi = vec4n_add(dp_hi, vec4n_bit_shift_right(dp_lo, DOT_SPLIT_BITS)); + dp_lo = vec4n_bit_and(dp_lo, low_bits); + + ulong hsum_lo = vec4n_horizontal_sum(dp_lo); + const ulong hsum_hi = vec4n_horizontal_sum(dp_hi) + (hsum_lo >> DOT_SPLIT_BITS); + hsum_lo &= DOT_SPLIT_MASK; + + for (; i < len; i++) + hsum_lo += vec1[i] * vec2[i]; + ulong res; - slong i; - NMOD_VEC_DOT(res, i, len, vec1[i], vec2[i][offset], mod, nlimbs); + NMOD_RED(res, pow2_precomp * hsum_hi + hsum_lo, mod); + return res; +} +#else // defined(__AVX2__) +{ + ulong res; slong i; + _NMOD_VEC_DOT2_SPLIT(res, i, len, vec1[i], vec2[i], mod, pow2_precomp) + return res; +} +#endif // defined(__AVX2__) +#endif // FLINT_BITS == 64 + +ulong _nmod_vec_dot2_half(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod) +{ + ulong res; slong i; + _NMOD_VEC_DOT2_HALF(res, i, len, vec1[i], vec2[i], mod) return res; } -static ulong -nmod_fmma(ulong a, ulong b, ulong c, ulong d, nmod_t mod) +ulong _nmod_vec_dot2(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod) { - a = nmod_mul(a, b, mod); - NMOD_ADDMUL(a, c, d, mod); - return a; + ulong res; slong i; + _NMOD_VEC_DOT2(res, i, len, vec1[i], vec2[i], mod) + return res; } -ulong -_nmod_vec_dot_rev(nn_srcptr 
vec1, nn_srcptr vec2, slong len, nmod_t mod, int nlimbs) +ulong _nmod_vec_dot3(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod) { + ulong res; slong i; + _NMOD_VEC_DOT3(res, i, len, vec1[i], vec2[i], mod) + return res; +} + +ulong _nmod_vec_dot3_acc(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod) +{ + ulong res; slong i; + _NMOD_VEC_DOT3_ACC(res, i, len, vec1[i], vec2[i], mod) + return res; +} + + +/*-----------------------------------------------*/ +/* dot product rev: vec1[i] * vec2[len-1-i] */ +/*-----------------------------------------------*/ + +ulong _nmod_vec_dot_pow2_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod) +{ + ulong res; slong i; + _NMOD_VEC_DOT1(res, i, len, vec1[i], vec2[len-1-i], mod) + return res; +} + +ulong _nmod_vec_dot1_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod) +#if defined(__AVX2__) && FLINT_BITS == 64 +{ + vec4n dp = vec4n_zero(); + + slong i = 0; + for ( ; i+31 < len; i += 32) + { + const ulong ii = len - 32 - i; // >= 0 + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i+ 0), vec4n_permute_3_2_1_0(vec4n_load_unaligned(vec2+ii+28)))); + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i+ 4), vec4n_permute_3_2_1_0(vec4n_load_unaligned(vec2+ii+24)))); + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i+ 8), vec4n_permute_3_2_1_0(vec4n_load_unaligned(vec2+ii+20)))); + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i+12), vec4n_permute_3_2_1_0(vec4n_load_unaligned(vec2+ii+16)))); + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i+16), vec4n_permute_3_2_1_0(vec4n_load_unaligned(vec2+ii+12)))); + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i+20), vec4n_permute_3_2_1_0(vec4n_load_unaligned(vec2+ii+ 8)))); + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i+24), vec4n_permute_3_2_1_0(vec4n_load_unaligned(vec2+ii+ 4)))); + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i+28), vec4n_permute_3_2_1_0(vec4n_load_unaligned(vec2+ii+ 0)))); + } + + 
for ( ; i + 3 < len; i += 4) + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i), vec4n_permute_3_2_1_0(vec4n_load_unaligned(vec2+len-4-i)))); + + ulong res = vec4n_horizontal_sum(dp); + + for (; i < len; i++) + res += vec1[i] * vec2[len-1-i]; + + NMOD_RED(res, res, mod); + return res; +} +#else // if defined(__AVX2__) && FLINT_BITS == 64 +{ + ulong res; slong i; + _NMOD_VEC_DOT1(res, i, len, vec1[i], vec2[len-1-i], mod) + return res; +} +#endif // if defined(__AVX2__) && FLINT_BITS == 64 + +#if FLINT_BITS == 64 +ulong _nmod_vec_dot2_split_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod, ulong pow2_precomp) +#if defined(__AVX2__) +{ + const vec4n low_bits = vec4n_set_n(DOT_SPLIT_MASK); + vec4n dp_lo = vec4n_zero(); + vec4n dp_hi = vec4n_zero(); + + slong i = 0; + for ( ; i+31 < len; i += 32) + { + const ulong ii = len - 32 - i; // >= 0 + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+ 0), vec4n_permute_3_2_1_0(vec4n_load_unaligned(vec2+ii+28)))); + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+ 4), vec4n_permute_3_2_1_0(vec4n_load_unaligned(vec2+ii+24)))); + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+ 8), vec4n_permute_3_2_1_0(vec4n_load_unaligned(vec2+ii+20)))); + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+12), vec4n_permute_3_2_1_0(vec4n_load_unaligned(vec2+ii+16)))); + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+16), vec4n_permute_3_2_1_0(vec4n_load_unaligned(vec2+ii+12)))); + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+20), vec4n_permute_3_2_1_0(vec4n_load_unaligned(vec2+ii+ 8)))); + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+24), vec4n_permute_3_2_1_0(vec4n_load_unaligned(vec2+ii+ 4)))); + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+28), vec4n_permute_3_2_1_0(vec4n_load_unaligned(vec2+ii+ 0)))); + + dp_hi = vec4n_add(dp_hi, vec4n_bit_shift_right(dp_lo, DOT_SPLIT_BITS)); + dp_lo = 
vec4n_bit_and(dp_lo, low_bits); + } + + for ( ; i + 3 < len; i += 4) + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i), vec4n_permute_3_2_1_0(vec4n_load_unaligned(vec2+len-4-i)))); + + dp_hi = vec4n_add(dp_hi, vec4n_bit_shift_right(dp_lo, DOT_SPLIT_BITS)); + dp_lo = vec4n_bit_and(dp_lo, low_bits); + + ulong hsum_lo = vec4n_horizontal_sum(dp_lo); + const ulong hsum_hi = vec4n_horizontal_sum(dp_hi) + (hsum_lo >> DOT_SPLIT_BITS); + hsum_lo &= DOT_SPLIT_MASK; + + for (; i < len; i++) + hsum_lo += vec1[i] * vec2[len-1-i]; + ulong res; - slong i; + NMOD_RED(res, pow2_precomp * hsum_hi + hsum_lo, mod); + return res; +} +#else // defined(__AVX2__) +{ + ulong res; slong i; + _NMOD_VEC_DOT2_SPLIT(res, i, len, vec1[i], vec2[len-1-i], mod, pow2_precomp) + return res; +} +#endif // defined(__AVX2__) +#endif // FLINT_BITS == 64 + +ulong _nmod_vec_dot2_half_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod) +{ + ulong res; slong i; + _NMOD_VEC_DOT2_HALF(res, i, len, vec1[i], vec2[len-1-i], mod) + return res; +} + +ulong _nmod_vec_dot2_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod) +{ + ulong res; slong i; + _NMOD_VEC_DOT2(res, i, len, vec1[i], vec2[len-1-i], mod) + return res; +} + +ulong _nmod_vec_dot3_acc_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod) +{ + ulong res; slong i; + _NMOD_VEC_DOT3_ACC(res, i, len, vec1[i], vec2[len-1-i], mod) + return res; +} + +ulong _nmod_vec_dot3_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod) +{ + ulong res; slong i; + _NMOD_VEC_DOT3(res, i, len, vec1[i], vec2[len-1-i], mod) + return res; +} + +/*-----------------------------------------------*/ +/* dot product ptr: vec1[i] * vec2[i][offset] */ +/*-----------------------------------------------*/ - if (len <= 2 && nlimbs >= 2) +ulong _nmod_vec_dot_pow2_ptr(nn_srcptr vec1, const nn_ptr * vec2, slong offset, slong len, nmod_t mod) +{ + ulong res; slong i; + _NMOD_VEC_DOT1(res, i, len, vec1[i], vec2[i][offset], mod) + return res; +} + 
+ulong _nmod_vec_dot1_ptr(nn_srcptr vec1, const nn_ptr * vec2, slong offset, slong len, nmod_t mod) +#if defined(__AVX2__) && FLINT_BITS == 64 +{ + vec4n dp = vec4n_zero(); + + slong i = 0; + for ( ; i+31 < len; i += 32) { - if (len == 2) - return nmod_fmma(vec1[0], vec2[1], vec1[1], vec2[0], mod); - if (len == 1) - return nmod_mul(vec1[0], vec2[0], mod); - return 0; + vec4n vec2_4n; + vec2_4n = vec4n_set_n4(vec2[i+ 0][offset], vec2[i+ 1][offset], vec2[i+ 2][offset], vec2[i+ 3][offset]); + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i+ 0), vec2_4n)); + vec2_4n = vec4n_set_n4(vec2[i+ 4][offset], vec2[i+ 5][offset], vec2[i+ 6][offset], vec2[i+ 7][offset]); + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i+ 4), vec2_4n)); + vec2_4n = vec4n_set_n4(vec2[i+ 8][offset], vec2[i+ 9][offset], vec2[i+10][offset], vec2[i+11][offset]); + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i+ 8), vec2_4n)); + vec2_4n = vec4n_set_n4(vec2[i+12][offset], vec2[i+13][offset], vec2[i+14][offset], vec2[i+15][offset]); + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i+12), vec2_4n)); + vec2_4n = vec4n_set_n4(vec2[i+16][offset], vec2[i+17][offset], vec2[i+18][offset], vec2[i+19][offset]); + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i+16), vec2_4n)); + vec2_4n = vec4n_set_n4(vec2[i+20][offset], vec2[i+21][offset], vec2[i+22][offset], vec2[i+23][offset]); + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i+20), vec2_4n)); + vec2_4n = vec4n_set_n4(vec2[i+24][offset], vec2[i+25][offset], vec2[i+26][offset], vec2[i+27][offset]); + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i+24), vec2_4n)); + vec2_4n = vec4n_set_n4(vec2[i+28][offset], vec2[i+29][offset], vec2[i+30][offset], vec2[i+31][offset]); + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i+28), vec2_4n)); } - NMOD_VEC_DOT(res, i, len, vec1[i], vec2[len - 1 - i], mod, nlimbs); + for ( ; i + 3 < len; i += 4) + { + vec4n vec2_4n = vec4n_set_n4(vec2[i+0][offset], 
vec2[i+1][offset], vec2[i+2][offset], vec2[i+3][offset]); + dp = vec4n_add(dp, vec4n_mul(vec4n_load_unaligned(vec1+i), vec2_4n)); + } + + ulong res = vec4n_horizontal_sum(dp); + + for (; i < len; i++) + res += vec1[i] * vec2[i][offset]; + + NMOD_RED(res, res, mod); + return res; +} +#else // if defined(__AVX2__) && FLINT_BITS == 64 +{ + ulong res; slong i; + _NMOD_VEC_DOT1(res, i, len, vec1[i], vec2[i][offset], mod) return res; } +#endif // if defined(__AVX2__) && FLINT_BITS == 64 + +#if FLINT_BITS == 64 +ulong _nmod_vec_dot2_split_ptr(nn_srcptr vec1, const nn_ptr * vec2, slong offset, slong len, nmod_t mod, ulong pow2_precomp) +#if defined(__AVX2__) +{ + const vec4n low_bits = vec4n_set_n(DOT_SPLIT_MASK); + vec4n dp_lo = vec4n_zero(); + vec4n dp_hi = vec4n_zero(); + + slong i = 0; + for ( ; i+31 < len; i += 32) + { + vec4n vec2_4n; + vec2_4n = vec4n_set_n4(vec2[i+ 0][offset], vec2[i+ 1][offset], vec2[i+ 2][offset], vec2[i+ 3][offset]); + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+ 0), vec2_4n)); + vec2_4n = vec4n_set_n4(vec2[i+ 4][offset], vec2[i+ 5][offset], vec2[i+ 6][offset], vec2[i+ 7][offset]); + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+ 4), vec2_4n)); + vec2_4n = vec4n_set_n4(vec2[i+ 8][offset], vec2[i+ 9][offset], vec2[i+10][offset], vec2[i+11][offset]); + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+ 8), vec2_4n)); + vec2_4n = vec4n_set_n4(vec2[i+12][offset], vec2[i+13][offset], vec2[i+14][offset], vec2[i+15][offset]); + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+12), vec2_4n)); + vec2_4n = vec4n_set_n4(vec2[i+16][offset], vec2[i+17][offset], vec2[i+18][offset], vec2[i+19][offset]); + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+16), vec2_4n)); + vec2_4n = vec4n_set_n4(vec2[i+20][offset], vec2[i+21][offset], vec2[i+22][offset], vec2[i+23][offset]); + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+20), vec2_4n)); + vec2_4n = 
vec4n_set_n4(vec2[i+24][offset], vec2[i+25][offset], vec2[i+26][offset], vec2[i+27][offset]); + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+24), vec2_4n)); + vec2_4n = vec4n_set_n4(vec2[i+28][offset], vec2[i+29][offset], vec2[i+30][offset], vec2[i+31][offset]); + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i+28), vec2_4n)); + + dp_hi = vec4n_add(dp_hi, vec4n_bit_shift_right(dp_lo, DOT_SPLIT_BITS)); + dp_lo = vec4n_bit_and(dp_lo, low_bits); + } + + for ( ; i + 3 < len; i += 4) + { + vec4n vec2_4n = vec4n_set_n4(vec2[i+0][offset], vec2[i+1][offset], vec2[i+2][offset], vec2[i+3][offset]); + dp_lo = vec4n_add(dp_lo, vec4n_mul(vec4n_load_unaligned(vec1+i), vec2_4n)); + } + + dp_hi = vec4n_add(dp_hi, vec4n_bit_shift_right(dp_lo, DOT_SPLIT_BITS)); + dp_lo = vec4n_bit_and(dp_lo, low_bits); + + ulong hsum_lo = vec4n_horizontal_sum(dp_lo); + const ulong hsum_hi = vec4n_horizontal_sum(dp_hi) + (hsum_lo >> DOT_SPLIT_BITS); + hsum_lo &= DOT_SPLIT_MASK; + + for (; i < len; i++) + hsum_lo += vec1[i] * vec2[i][offset]; + + ulong res; + NMOD_RED(res, pow2_precomp * hsum_hi + hsum_lo, mod); + return res; +} +#else // defined(__AVX2__) +{ + ulong res; slong i; + _NMOD_VEC_DOT2_SPLIT(res, i, len, vec1[i], vec2[i][offset], mod, pow2_precomp) + return res; +} +#endif // defined(__AVX2__) +#endif // FLINT_BITS == 64 + +ulong _nmod_vec_dot2_half_ptr(nn_srcptr vec1, const nn_ptr * vec2, slong offset, slong len, nmod_t mod) +{ + ulong res; slong i; + _NMOD_VEC_DOT2_HALF(res, i, len, vec1[i], vec2[i][offset], mod) + return res; +} + +ulong _nmod_vec_dot2_ptr(nn_srcptr vec1, const nn_ptr * vec2, slong offset, slong len, nmod_t mod) +{ + ulong res; slong i; + _NMOD_VEC_DOT2(res, i, len, vec1[i], vec2[i][offset], mod) + return res; +} + +ulong _nmod_vec_dot3_acc_ptr(nn_srcptr vec1, const nn_ptr * vec2, slong offset, slong len, nmod_t mod) +{ + ulong res; slong i; + _NMOD_VEC_DOT3_ACC(res, i, len, vec1[i], vec2[i][offset], mod) + return res; +} + +ulong 
_nmod_vec_dot3_ptr(nn_srcptr vec1, const nn_ptr * vec2, slong offset, slong len, nmod_t mod) +{ + ulong res; slong i; + _NMOD_VEC_DOT3(res, i, len, vec1[i], vec2[i][offset], mod) + return res; +} + +/*----------------------------------------*/ +/* notes concerning the different methods */ +/*----------------------------------------*/ + +// Why no vectorization in the general NMOD_VEC_DOT macro? +// attempts at vectorized versions (2024-06-16, for methods _DOT1, +// _DOT2_SPLIT) did not show an advantage except in "regular" cases where +// memory accesses are fast (typically, expr = v[i] or expr = v[len - 1 -i]). +// For these, there is dedicated code anyway. + +// 2024-06-16 _DOT2_HALF is slightly faster than _DOT2 +// 2024-06-16 _DOT3_ACC is slightly faster than _DOT3 + +// 3 limbs, conditions mod.n <= UWORD(6521908912666391107): +// we can accumulate 8 terms if n == mod.n is such that +// 8 * (n-1)**2 < 2**(2*FLINT_BITS), this is equivalent to +// n <= ceil(sqrt(2**(2*FLINT_BITS-3))) + +/*---------------------------------------------*/ +/* dot product for small modulus via splitting */ +/*---------------------------------------------*/ + +// in short: with current DOT_SPLIT_BITS value 56, +// -> modulus n up to about 2**30.5 +// (more precisely, n <= 1515531528) +// -> length of dot product up to at least 380368697 +// (more precisely, len * (n-1)**3 < 2**120 + 2**56 - 2**112) + +// APPROACH: +// +// Let n = mod.n, s = DOT_SPLIT_BITS +// As input, take pow2_precomp == 2**s % n +// +// -> avoiding modular reductions altogether, compute dp_lo and dp_hi such that +// the dot product without modular reduction is dp = dp_lo + 2**s * dp_hi +// -> finally, compute (dp_lo + pow2_precomp * dp_hi) % n +// -> done through repeating this: accumulate a few terms, +// move higher bits to dp_hi and keep lower ones in dp_lo + +// PARAMETER CONSTRAINTS: +// +// 2024-06-16: currently, the code accumulates 8 terms as this showed slightly better performance +// +// -> constraint 
(C0-8):
+// if we accumulate 8 terms (each a product of two integers reduced modulo n)
+// on top of an s-bit integer, we require
+// 2**s - 1 + 8 * (n-1)**2 < 2**64
+// so one can take any modulus with
+// n <= 1 + floor(sqrt(2**61 - 2**(s-3)))
+// in particular, n-1 < 2**30.5, (n-1)**2 < 2**61, (n-1)**3 < 2**91.5
+//
+// -> constraint (C0-4):
+// similarly, if we accumulate 4 terms on top of an s-bit integer, we require
+// 2**s - 1 + 4 * (n-1)**2 < 2**64
+// so one can take any modulus with
+// n <= 1 + floor(sqrt(2**62 - 2**(s-2)))
+// in particular, n-1 < 2**30.5, (n-1)**2 < 2**61, (n-1)**3 < 2**91.5
+//
+// -> constraint (C1):
+// in the above representation of dp we will use a ulong for dp_hi,
+// so we require len * (n-1)**2 <= 2**s * (2**64 - 1)
+// which is less restrictive than the below (C2)
+//
+// -> constraint (C2):
+// for dp_lo + pow2_precomp * dp_hi to fit in a single word, we require
+// 2**s - 1 + (n-1) dp_hi < 2**64.
+// Since dp_hi <= len * (n-1)**2 / 2**s, it suffices to ensure
+// len * (n-1)**3 < 2**s * (2**64 + 1 - 2**s)
+//
+// sage: for s in range(40,64):
+// ....:     nmax8 = 1 + floor(sqrt(2**61 - 2**(s-3)))                        # (C0-8)
+// ....:     nmax4 = 1 + floor(sqrt(2**62 - 2**(s-2)))                        # (C0-4)
+// ....:     lenmax4 = floor(2**s * (2**64 - 1) / (nmax4-1)**2)               # (C1)
+// ....:     lenmax4_bis = ceil(2**s * (2**64 + 1 - 2**s) / (nmax4-1)**3) - 1 # (C2)
+// ....:     lenmax8 = floor(2**s * (2**64 - 1) / (nmax8-1)**2)               # (C1)
+// ....:     lenmax8_bis = ceil(2**s * (2**64 + 1 - 2**s) / (nmax8-1)**3) - 1 # (C2)
+// ....:     print(f"{s}\t{nmax8.nbits()}\t{nmax8}\t{lenmax8_bis}\t{nmax4}\t{lenmax4_bis}")
+// ....:
+// s nbits nmax8 (C2) for nmax8 nmax4 (C2) for nmax4
+// 40 31 1518500205 5792 2147483584 2048
+// 41 31 1518500160 11585 2147483520 4096
+// 42 31 1518500069 23170 2147483392 8192
+// 43 31 1518499888 46340 2147483136 16384
+// 44 31 1518499526 92681 2147482624 32768
+// 45 31 1518498802 185363 2147481600 65536
+// 46 31 1518497354 370728 2147479552 131072
+// 47 31 
1518494458 741458 2147475456 262145 +// 48 31 1518488665 1482921 2147467264 524292 +// 49 31 1518477080 2965866 2147450880 1048592 +// 50 31 1518453909 5931822 2147418111 2097216 +// 51 31 1518407566 11864007 2147352572 4194560 +// 52 31 1518314875 23729463 2147221488 8389632 +// 53 31 1518129478 47464722 2146959296 16781313 +// 54 31 1517758614 94952640 2146434816 33570828 +// 55 31 1517016615 189998167 2145385471 67174496 +// 56 31 1515531528 380368697 2143285240 134480642 +// 57 31 1512556978 762233438 2139078592 269490216 +// 58 31 1506590261 1530504392 2130640379 541115017 +// 59 31 1494585366 3085595597 2113662895 1090922784 +// 60 31 1470281545 6273201268 2079292102 2217911575 +// 61 31 1420426920 12986760413 2008787014 4591513178 +// 62 31 1315059793 28054608908 1859775394 9918802104 +// 63 31 1073741825 68719476736 1518500250 24296004047 + diff --git a/src/nmod_vec/profile/p-dot.c b/src/nmod_vec/profile/p-dot.c index 7b782a1f06..6d226710be 100644 --- a/src/nmod_vec/profile/p-dot.c +++ b/src/nmod_vec/profile/p-dot.c @@ -38,21 +38,26 @@ void nmod_mat_rand(nmod_mat_t mat, flint_rand_t state) /* direct: dot / dot_rev / dot expr */ /*------------------------------------*/ +// timings excluding dot_params void time_dot(ulong len, ulong n, flint_rand_t state) { nmod_t mod; nmod_init(&mod, n); - const int n_limbs = _nmod_vec_dot_bound_limbs(len, mod); + const dot_params_t params = _nmod_vec_dot_params(len, mod); nn_ptr v1 = _nmod_vec_init(len); _nmod_vec_rand(v1, state, len, mod); nn_ptr v2 = _nmod_vec_init(len); _nmod_vec_rand(v2, state, len, mod); + // store results in volatile variable to avoid that they + // are "optimized away" (especially for inlined part) + volatile ulong FLINT_SET_BUT_UNUSED(res); + double FLINT_SET_BUT_UNUSED(tcpu), twall; TIMEIT_START - _nmod_vec_dot(v1, v2, len, mod, n_limbs); + res = _nmod_vec_dot(v1, v2, len, mod, params); TIMEIT_STOP_VALUES(tcpu, twall) printf("%.2e", twall); @@ -65,17 +70,21 @@ void time_dot_rev(ulong len, ulong n, 
flint_rand_t state) { nmod_t mod; nmod_init(&mod, n); - const int n_limbs = _nmod_vec_dot_bound_limbs(len, mod); + const dot_params_t params = _nmod_vec_dot_params(len, mod); nn_ptr v1 = _nmod_vec_init(len); _nmod_vec_rand(v1, state, len, mod); nn_ptr v2 = _nmod_vec_init(len); _nmod_vec_rand(v2, state, len, mod); + // store results in volatile variable to avoid that they + // are "optimized away" (especially for inlined part) + volatile ulong FLINT_SET_BUT_UNUSED(res); + double FLINT_SET_BUT_UNUSED(tcpu), twall; TIMEIT_START - _nmod_vec_dot_rev(v1, v2, len, mod, n_limbs); + res = _nmod_vec_dot_rev(v1, v2, len, mod, params); TIMEIT_STOP_VALUES(tcpu, twall) printf("%.2e", twall); @@ -84,25 +93,59 @@ void time_dot_rev(ulong len, ulong n, flint_rand_t state) _nmod_vec_clear(v2); } -void time_dot_expr(ulong len, ulong n, flint_rand_t state) +void time_dot_ptr(ulong len, ulong n, flint_rand_t state) { nmod_t mod; nmod_init(&mod, n); - const int n_limbs = _nmod_vec_dot_bound_limbs(len, mod); + const dot_params_t params = _nmod_vec_dot_params(len, mod); + + const ulong offset = UWORD(7); + + nn_ptr v1 = _nmod_vec_init(len); + _nmod_vec_rand(v1, state, len, mod); + nn_ptr v2tmp = _nmod_vec_init(len); + _nmod_vec_rand(v2tmp, state, len, mod); + nn_ptr * v2 = flint_malloc(sizeof(nn_ptr) * len); + for (ulong i = 0; i < len; i++) + v2[i] = &v2tmp[i] + offset; - nn_ptr v1 = _nmod_vec_init(9*len); - _nmod_vec_rand(v1, state, 9*len, mod); - nn_ptr v2 = _nmod_vec_init(9*len); - _nmod_vec_rand(v2, state, 9*len, mod); + // store results in volatile variable to avoid that they + // are "optimized away" (especially for inlined part) + volatile ulong FLINT_SET_BUT_UNUSED(res); double FLINT_SET_BUT_UNUSED(tcpu), twall; - nn_srcptr v1i = v1; - nn_srcptr v2i = v2; - ulong i, FLINT_SET_BUT_UNUSED(res); + TIMEIT_START + res = _nmod_vec_dot_ptr(v1, v2, offset, len, mod, params); + TIMEIT_STOP_VALUES(tcpu, twall) + + printf("%.2e", twall); + + _nmod_vec_clear(v1); + _nmod_vec_clear(v2tmp); + 
flint_free(v2); +} + +// timings including dot_params +void time_dot_incparams(ulong len, ulong n, flint_rand_t state) +{ + nmod_t mod; + nmod_init(&mod, n); + + nn_ptr v1 = _nmod_vec_init(len); + _nmod_vec_rand(v1, state, len, mod); + nn_ptr v2 = _nmod_vec_init(len); + _nmod_vec_rand(v2, state, len, mod); + + // store results in volatile variable to avoid that they + // are "optimized away" (especially for inlined part) + volatile ulong FLINT_SET_BUT_UNUSED(res); + + double FLINT_SET_BUT_UNUSED(tcpu), twall; TIMEIT_START - NMOD_VEC_DOT(res, i, len, v1i[9*len - 1 - 9*i], v2i[9*len - 1 - 9*i], mod, n_limbs); + const dot_params_t params = _nmod_vec_dot_params(len, mod); + res = _nmod_vec_dot(v1, v2, len, mod, params); TIMEIT_STOP_VALUES(tcpu, twall) printf("%.2e", twall); @@ -111,13 +154,70 @@ void time_dot_expr(ulong len, ulong n, flint_rand_t state) _nmod_vec_clear(v2); } +void time_dot_rev_incparams(ulong len, ulong n, flint_rand_t state) +{ + nmod_t mod; + nmod_init(&mod, n); + + nn_ptr v1 = _nmod_vec_init(len); + _nmod_vec_rand(v1, state, len, mod); + nn_ptr v2 = _nmod_vec_init(len); + _nmod_vec_rand(v2, state, len, mod); + + // store results in volatile variable to avoid that they + // are "optimized away" (especially for inlined part) + volatile ulong FLINT_SET_BUT_UNUSED(res); + + double FLINT_SET_BUT_UNUSED(tcpu), twall; + + TIMEIT_START + const dot_params_t params = _nmod_vec_dot_params(len, mod); + res = _nmod_vec_dot_rev(v1, v2, len, mod, params); + TIMEIT_STOP_VALUES(tcpu, twall) + + printf("%.2e", twall); + + _nmod_vec_clear(v1); + _nmod_vec_clear(v2); +} + +void time_dot_ptr_incparams(ulong len, ulong n, flint_rand_t state) +{ + nmod_t mod; + nmod_init(&mod, n); + + const ulong offset = UWORD(7); + + nn_ptr v1 = _nmod_vec_init(len); + _nmod_vec_rand(v1, state, len, mod); + nn_ptr v2tmp = _nmod_vec_init(len + 2*offset); + _nmod_vec_rand(v2tmp, state, len + 2*offset, mod); + nn_ptr * v2 = flint_malloc(sizeof(nn_ptr) * len); + for (ulong i = 0; i < len; i++) + v2[i] = 
&v2tmp[i] + offset; + + // store results in volatile variable to avoid that they + // are "optimized away" (especially for inlined part) + volatile ulong FLINT_SET_BUT_UNUSED(res); + + double FLINT_SET_BUT_UNUSED(tcpu), twall; + + TIMEIT_START + const dot_params_t params = _nmod_vec_dot_params(len, mod); + res = _nmod_vec_dot_ptr(v1, v2, offset, len, mod, params); + TIMEIT_STOP_VALUES(tcpu, twall) + + printf("%.2e", twall); + + _nmod_vec_clear(v1); + _nmod_vec_clear(v2tmp); + flint_free(v2); +} + /*-------------------------*/ /* indirect: poly */ /*-------------------------*/ -// void _nmod_poly_inv_series_basecase_preinv1(nn_ptr Qinv, nn_srcptr Q, slong Qlen, slong n, ulong q, nmod_t mod) -// void _nmod_poly_exp_series(nn_ptr f, nn_srcptr h, slong hlen, slong n, nmod_t mod) - void time_dot_poly_mul(ulong len, ulong n, flint_rand_t state) { if (len > 10000) @@ -150,7 +250,7 @@ void time_dot_poly_mul(ulong len, ulong n, flint_rand_t state) void time_dot_poly_inv_series(ulong len, ulong n, flint_rand_t state) { - if (len > 10000 || n == (UWORD(1) << 63)) + if (len > 10000 || n % 2 == 0) { printf(" "); return; @@ -175,10 +275,9 @@ void time_dot_poly_inv_series(ulong len, ulong n, flint_rand_t state) _nmod_vec_clear(res); } - void time_dot_poly_exp_series(ulong len, ulong n, flint_rand_t state) { - if (len > 10000 || n == (UWORD(1) << 63)) + if (len > 10000 || n % 2 == 0) { printf(" "); return; @@ -186,7 +285,7 @@ void time_dot_poly_exp_series(ulong len, ulong n, flint_rand_t state) gr_ctx_t ctx; gr_ctx_init_nmod(ctx, n); - int status; + int FLINT_SET_BUT_UNUSED(status); gr_poly_t p; gr_poly_init(p, ctx); @@ -199,9 +298,6 @@ void time_dot_poly_exp_series(ulong len, ulong n, flint_rand_t state) TIMEIT_START status |= gr_poly_exp_series_basecase(res, p, len, ctx); - -// int gr_poly_exp_series_basecase(gr_poly_t f, const gr_poly_t h, slong n, gr_ctx_t ctx) - TIMEIT_STOP_VALUES(tcpu, twall) printf("%.2e", twall); @@ -210,8 +306,6 @@ void time_dot_poly_exp_series(ulong len, 
ulong n, flint_rand_t state) gr_poly_clear(res, ctx); } - - /*-------------------------*/ /* indirect: mat */ /*-------------------------*/ @@ -271,7 +365,7 @@ void time_dot_mat_mul_vec(ulong len, ulong n, flint_rand_t state) void time_dot_mat_solve_tril(ulong len, ulong n, flint_rand_t state) { - if (len > 4000 || n == (UWORD(1) << 63)) + if (len > 4000 || n % 2 == 0) { printf(" "); return; @@ -296,7 +390,7 @@ void time_dot_mat_solve_tril(ulong len, ulong n, flint_rand_t state) void time_dot_mat_solve_tril_vec(ulong len, ulong n, flint_rand_t state) { - if (len > 10000 || n == (UWORD(1) << 63)) + if (len > 10000 || n % 2 == 0) { printf(" "); return; @@ -321,7 +415,7 @@ void time_dot_mat_solve_tril_vec(ulong len, ulong n, flint_rand_t state) void time_dot_mat_solve_triu(ulong len, ulong n, flint_rand_t state) { - if (len > 4000 || n == (UWORD(1) << 63)) + if (len > 4000 || n % 2 == 0) { printf(" "); return; @@ -346,7 +440,7 @@ void time_dot_mat_solve_triu(ulong len, ulong n, flint_rand_t state) void time_dot_mat_solve_triu_vec(ulong len, ulong n, flint_rand_t state) { - if (len > 10000 || n == (UWORD(1)<<63)) + if (len > 10000 || n % 2 == 0) { printf(" "); return; @@ -381,46 +475,67 @@ int main(int argc, char ** argv) flint_rand_set_seed(state, time(NULL), time(NULL)+129384125L); // modulus bitsize - const slong nbits = 12; - const ulong bits[] = {0, 12, 28, 30, 31, 32, 40, 50, 60, 61, 62, 63, 64}; + const slong nbits = 14; + const ulong bits[] = {12, 28, 30, 31, 32, 40, 50, 60, 61, 62, 63, 64, 232, 263}; // vector lengths - const slong nlens = 14; - const ulong lens[] = {1, 5, 10, 25, 50, 100, 250, 500, 1000, 2500, 5000, 10000, 100000, 1000000}; + const slong nlens = 20; + const ulong lens[] = {1, 2, 3, 4, 5, 7, 10, 15, 25, 35, 50, 100, 250, 500, 1000, 2500, 5000, 10000, 100000, 1000000}; // bench functions - const slong nfuns = 12; + const slong nfuns = 15; typedef void (*timefun) (ulong, ulong, flint_rand_t); const timefun funs[] = { time_dot, // 0 - 
time_dot_rev, // 1 - time_dot_expr, // 2 - time_dot_poly_mul, // 3 - time_dot_poly_inv_series, // 4 - time_dot_poly_exp_series, // 5 - time_dot_mat_mul, // 6 - time_dot_mat_solve_tril, // 7 - time_dot_mat_solve_triu, // 8 - time_dot_mat_mul_vec, // 9 - time_dot_mat_solve_tril_vec, // 10 - time_dot_mat_solve_triu_vec, // 11 + time_dot_incparams, // 1 + time_dot_rev, // 2 + time_dot_rev_incparams, // 3 + time_dot_ptr, // 4 + time_dot_ptr_incparams, // 5 + time_dot_poly_mul, // 6 + time_dot_poly_inv_series, // 7 + time_dot_poly_exp_series, // 8 + time_dot_mat_mul, // 9 + time_dot_mat_solve_tril, // 10 + time_dot_mat_solve_triu, // 11 + time_dot_mat_mul_vec, // 12 + time_dot_mat_solve_tril_vec, // 13 + time_dot_mat_solve_triu_vec, // 14 }; const char * description[] = { - "#0 --> vec dot ", - "#1 --> vec dot rev ", - "#2 --> vec dot expr ", - "#3 --> poly_mul ", - "#4 --> poly_inv_series ", - "#5 --> poly_exp_series ", - "#6 --> mat_mul ", - "#7 --> mat_solve_tril ", - "#8 --> mat_solve_triu ", - "#9 --> mat_mul_vec ", - "#10 --> mat_solve_tril_vec", - "#11 --> mat_solve_triu_vec" + "#0 --> vec dot ", + "#1 --> vec dot inc params ", + "#2 --> vec dot rev ", + "#3 --> vec dot rev inc params", + "#4 --> vec dot ptr ", + "#5 --> vec dot ptr inc params", + "#6 --> poly_mul ", + "#7 --> poly_inv_series ", + "#8 --> poly_exp_series ", + "#9 --> mat_mul ", + "#10 --> mat_solve_tril ", + "#11 --> mat_solve_triu ", + "#12 --> mat_mul_vec ", + "#13 --> mat_solve_tril_vec ", + "#14 --> mat_solve_triu_vec " }; + if (argc == 1) // show usage + { + printf("Usage: `%s [fun] [nbits] [len]`\n", argv[0]); + printf(" Each argument is optional; no argument shows this help.\n"); + printf(" - fun: id number of the timed function (see below),\n"); + printf(" exception: fun == -1 times all available functions successively\n"); + printf(" - nbits: number of bits for the modulus, chosen as nextprime(2**(nbits-1))\n"); + printf(" exception: nbits == 232 and 263 (moduli 2**32, 2**63)\n"); + 
printf(" - len: length for the vector, row and column dimension for the matrices\n"); + printf("\nAvailable functions:\n"); + for (slong j = 0; j < nfuns; j++) + printf(" %s\n", description[j]); + + return 0; + } printf("#warmup... "); for (slong i = 0; i < 10; i++) @@ -430,7 +545,7 @@ int main(int argc, char ** argv) } printf("\n"); - if (argc == 1) // launching full suite + if (argc == 2 && atoi(argv[1]) == -1) // launching full suite { for (slong ifun = 0; ifun < nfuns; ifun++) { @@ -447,7 +562,13 @@ int main(int argc, char ** argv) const slong b = bits[j]; printf("%-10ld", b); - const ulong n = (b==0) ? (UWORD(1) << 63) : n_nextprime(UWORD(1) << (b-1), 0); + ulong n; + if (b == 232) + n = UWORD(1) << 32; + else if (b == 263) + n = UWORD(1) << 63; + else + n = n_nextprime(UWORD(1) << (b-1), 0); for (slong i = 0; i < nlens; i++) { tfun(lens[i], n, state); @@ -472,7 +593,13 @@ int main(int argc, char ** argv) const slong b = bits[j]; printf("%-10ld", b); - const ulong n = (b==0) ? (UWORD(1) << 63) : n_nextprime(UWORD(1) << (b-1), 0); + ulong n; + if (b == 232) + n = UWORD(1) << 32; + else if (b == 263) + n = UWORD(1) << 63; + else + n = n_nextprime(UWORD(1) << (b-1), 0); for (slong i = 0; i < nlens; i++) { tfun(lens[i], n, state); @@ -493,7 +620,13 @@ int main(int argc, char ** argv) printf("\n"); printf("%-10ld", b); - const ulong n = (b==0) ? (UWORD(1) << 63) : n_nextprime(UWORD(1) << (b-1), 0); + ulong n; + if (b == 232) + n = UWORD(1) << 32; + else if (b == 263) + n = UWORD(1) << 63; + else + n = n_nextprime(UWORD(1) << (b-1), 0); for (slong i = 0; i < nlens; i++) { tfun(lens[i], n, state); @@ -514,7 +647,13 @@ int main(int argc, char ** argv) printf("\n"); printf("%-10ld", b); - const ulong n = (b==0) ? 
(UWORD(1) << 63) : n_nextprime(UWORD(1) << (b-1), 0); + ulong n; + if (b == 232) + n = UWORD(1) << 32; + else if (b == 263) + n = UWORD(1) << 63; + else + n = n_nextprime(UWORD(1) << (b-1), 0); tfun(len, n, state); printf("\n"); diff --git a/src/nmod_vec/test/main.c b/src/nmod_vec/test/main.c index 71ad164969..246cfbd272 100644 --- a/src/nmod_vec/test/main.c +++ b/src/nmod_vec/test/main.c @@ -13,7 +13,7 @@ #include "t-add_sub_neg.c" #include "t-discrete_log_pohlig_hellman.c" -#include "t-dot_bound_limbs.c" +#include "t-dot_nlimbs.c" #include "t-dot.c" #include "t-dot_ptr.c" #include "t-nmod.c" @@ -29,7 +29,7 @@ test_struct tests[] = { TEST_FUNCTION(nmod_vec_add_sub_neg), TEST_FUNCTION(nmod_vec_discrete_log_pohlig_hellman), - TEST_FUNCTION(nmod_vec_dot_bound_limbs), + TEST_FUNCTION(_nmod_vec_dot_params), TEST_FUNCTION(nmod_vec_dot), TEST_FUNCTION(nmod_vec_dot_ptr), TEST_FUNCTION(nmod_vec_nmod), diff --git a/src/nmod_vec/test/t-dot.c b/src/nmod_vec/test/t-dot.c index e1215a57c3..b85ee44acd 100644 --- a/src/nmod_vec/test/t-dot.c +++ b/src/nmod_vec/test/t-dot.c @@ -24,7 +24,6 @@ TEST_FUNCTION_START(nmod_vec_dot, state) nmod_t mod; ulong m, res; nn_ptr x, y; - int limbs1; mpz_t s, t; slong j; @@ -39,9 +38,9 @@ TEST_FUNCTION_START(nmod_vec_dot, state) _nmod_vec_randtest(x, state, len, mod); _nmod_vec_randtest(y, state, len, mod); - limbs1 = _nmod_vec_dot_bound_limbs(len, mod); + const dot_params_t params = _nmod_vec_dot_params(len, mod); - res = _nmod_vec_dot(x, y, len, mod, limbs1); + res = _nmod_vec_dot(x, y, len, mod, params); mpz_init(s); mpz_init(t); @@ -59,7 +58,7 @@ TEST_FUNCTION_START(nmod_vec_dot, state) "m = %wu\n" "len = %wd\n" "limbs1 = %d\n", - m, len, limbs1); + m, len, params.method); mpz_clear(s); mpz_clear(t); diff --git a/src/nmod_vec/test/t-dot_bound_limbs.c b/src/nmod_vec/test/t-dot_nlimbs.c similarity index 55% rename from src/nmod_vec/test/t-dot_bound_limbs.c rename to src/nmod_vec/test/t-dot_nlimbs.c index ac050d18e3..3e48a9738d 100644 ---
a/src/nmod_vec/test/t-dot_bound_limbs.c +++ b/src/nmod_vec/test/t-dot_nlimbs.c @@ -1,5 +1,6 @@ /* Copyright (C) 2011 Fredrik Johansson + Copyright (C) 2024 Vincent Neiger This file is part of FLINT. @@ -14,7 +15,7 @@ #include "nmod.h" #include "nmod_vec.h" -TEST_FUNCTION_START(nmod_vec_dot_bound_limbs, state) +TEST_FUNCTION_START(_nmod_vec_dot_params, state) { int i; @@ -23,7 +24,8 @@ TEST_FUNCTION_START(nmod_vec_dot_bound_limbs, state) slong len; nmod_t mod; ulong m; - int limbs1, limbs2; + int nlimbs1, nlimbs2, nlimbs3; + dot_params_t params; mpz_t t; len = n_randint(state, 10000) + 1; @@ -31,22 +33,33 @@ TEST_FUNCTION_START(nmod_vec_dot_bound_limbs, state) nmod_init(&mod, m); - limbs1 = _nmod_vec_dot_bound_limbs(len, mod); + params = _nmod_vec_dot_params(len, mod); + nlimbs1 = _nmod_vec_dot_bound_limbs_from_params(len, mod, params); + nlimbs2 = _nmod_vec_dot_bound_limbs(len, mod); mpz_init2(t, 4*FLINT_BITS); flint_mpz_set_ui(t, m-1); mpz_mul(t, t, t); flint_mpz_mul_ui(t, t, len); - limbs2 = mpz_size(t); + nlimbs3 = mpz_size(t); - if (limbs1 != limbs2) + if (nlimbs1 != nlimbs3) TEST_FUNCTION_FAIL( "m = %wu\n" "len = %wd\n" - "limbs1 = %d\n" - "limbs2 = %d\n" + "nlimbs1(from params) = %d\n" + "nlimbs3(mpz) = %d\n" "bound: %{mpz}\n", - m, len, limbs1, limbs2, t); + m, len, nlimbs1, nlimbs3, t); + + if (nlimbs2 != nlimbs3) + TEST_FUNCTION_FAIL( + "m = %wu\n" + "len = %wd\n" + "nlimbs2(from len+mod) = %d\n" + "nlimbs3(mpz) = %d\n" + "bound: %{mpz}\n", + m, len, nlimbs2, nlimbs3, t); mpz_clear(t); } diff --git a/src/nmod_vec/test/t-dot_ptr.c b/src/nmod_vec/test/t-dot_ptr.c index cc8485e71a..191638973e 100644 --- a/src/nmod_vec/test/t-dot_ptr.c +++ b/src/nmod_vec/test/t-dot_ptr.c @@ -24,7 +24,6 @@ TEST_FUNCTION_START(nmod_vec_dot_ptr, state) ulong m, res, res2; nn_ptr x, y; nn_ptr * z; - int limbs1; slong j, offset; len = n_randint(state, 1000) + 1; @@ -43,17 +42,17 @@ TEST_FUNCTION_START(nmod_vec_dot_ptr, state) for (j = 0; j < len; j++) z[j] = &y[j] + offset; - 
limbs1 = _nmod_vec_dot_bound_limbs(len, mod); + const dot_params_t params = _nmod_vec_dot_params(len, mod); - res = _nmod_vec_dot_ptr(x, z, -offset, len, mod, limbs1); - res2 = _nmod_vec_dot(x, y, len, mod, limbs1); + res = _nmod_vec_dot_ptr(x, z, -offset, len, mod, params); + res2 = _nmod_vec_dot(x, y, len, mod, params); if (res != res2) TEST_FUNCTION_FAIL( "m = %wu\n" "len = %wd\n" - "limbs1 = %d\n", - m, len, limbs1); + "method = %d\n", + m, len, params.method); _nmod_vec_clear(x); _nmod_vec_clear(y);