Skip to content

Commit

Permalink
Merge pull request #2038 from fredrik-johansson/matmul
Browse files Browse the repository at this point in the history
Linear algebra tuning for nfloat + cmpabs
  • Loading branch information
fredrik-johansson committed Jul 20, 2024
2 parents 91f0ece + 99bd846 commit 05340d2
Show file tree
Hide file tree
Showing 8 changed files with 539 additions and 19 deletions.
7 changes: 7 additions & 0 deletions doc/source/nfloat.rst
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,10 @@ Matrix functions

Different implementations of matrix multiplication.

.. function:: int nfloat_mat_nonsingular_solve_tril(gr_mat_t X, const gr_mat_t L, const gr_mat_t B, int unit, gr_ctx_t ctx)
int nfloat_mat_nonsingular_solve_triu(gr_mat_t X, const gr_mat_t L, const gr_mat_t B, int unit, gr_ctx_t ctx)
int nfloat_mat_lu(slong * rank, slong * P, gr_mat_t LU, const gr_mat_t A, int rank_check, gr_ctx_t ctx)

Internal functions
-------------------------------------------------------------------------------

Expand Down Expand Up @@ -417,3 +421,6 @@ real pairs.
int nfloat_complex_mat_mul_block(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong min_block_size, gr_ctx_t ctx)
int nfloat_complex_mat_mul_reorder(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
int nfloat_complex_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
int nfloat_complex_mat_nonsingular_solve_tril(gr_mat_t X, const gr_mat_t L, const gr_mat_t B, int unit, gr_ctx_t ctx)
int nfloat_complex_mat_nonsingular_solve_triu(gr_mat_t X, const gr_mat_t L, const gr_mat_t B, int unit, gr_ctx_t ctx)
int nfloat_complex_mat_lu(slong * rank, slong * P, gr_mat_t LU, const gr_mat_t A, int rank_check, gr_ctx_t ctx)
103 changes: 95 additions & 8 deletions src/gr/acb.c
Original file line number Diff line number Diff line change
Expand Up @@ -987,21 +987,108 @@ _gr_acb_cmp(int * res, const acb_t x, const acb_t y, const gr_ctx_t ctx)
}
}

int
_gr_arb_cmpabs(int * res, const arb_t x, const arb_t y, const gr_ctx_t ctx);

int
_gr_acb_cmpabs(int * res, const acb_t x, const acb_t y, const gr_ctx_t ctx)
{
acb_t t, u;
if (arb_is_zero(acb_imagref(x)) && arb_is_zero(acb_imagref(y)))
{
arb_srcptr a = acb_realref(x);
arb_srcptr c = acb_realref(y);

*t = *x;
*u = *y;
/* OK; ignores the context object */
return _gr_arb_cmpabs(res, a, c, ctx);
}
else
{
slong prec = ACB_CTX_PREC(ctx);
int status = GR_SUCCESS;

if (arf_sgn(arb_midref(acb_realref(t))) < 0)
ARF_NEG(arb_midref(acb_realref(t)));
arb_srcptr a = acb_realref(x);
arb_srcptr b = acb_imagref(x);
arb_srcptr c = acb_realref(y);
arb_srcptr d = acb_imagref(y);

mag_t xlo, xhi, ylo, yhi, t;

mag_init(xlo);
mag_init(xhi);
mag_init(ylo);
mag_init(yhi);
mag_init(t);

arb_get_mag_lower(xlo, a);
arb_get_mag_lower(t, b);
mag_mul_lower(xlo, xlo, xlo);
mag_mul_lower(t, t, t);
mag_add_lower(xlo, xlo, t);

arb_get_mag_lower(ylo, c);
arb_get_mag_lower(t, d);
mag_mul_lower(ylo, ylo, ylo);
mag_mul_lower(t, t, t);
mag_add_lower(ylo, ylo, t);

arb_get_mag(xhi, a);
arb_get_mag(t, b);
mag_mul(xhi, xhi, xhi);
mag_mul(t, t, t);
mag_add(xhi, xhi, t);

arb_get_mag(yhi, c);
arb_get_mag(t, d);
mag_mul(yhi, yhi, yhi);
mag_mul(t, t, t);
mag_add(yhi, yhi, t);

if (mag_cmp(xhi, ylo) < 0)
{
*res = -1;
status = GR_SUCCESS;
}
else if (mag_cmp(xlo, yhi) > 0)
{
*res = 1;
status = GR_SUCCESS;
}
else
{
arb_t t, u;

if (arf_sgn(arb_midref(acb_realref(u))) < 0)
ARF_NEG(arb_midref(acb_realref(u)));
arb_init(t);
arb_init(u);

return _gr_acb_cmp(res, t, u, ctx);
arb_mul(t, a, a, prec);
arb_addmul(t, b, b, prec);
arb_mul(u, c, c, prec);
arb_addmul(u, d, d, prec);

if ((arb_is_exact(t) && arb_is_exact(u)) || !arb_overlaps(t, u))
{
*res = arf_cmp(arb_midref(t), arb_midref(u));
status = GR_SUCCESS;
}
else
{
/* todo: worth it to do an exact computation? */
*res = 0;
status = GR_UNABLE;
}

arb_clear(t);
arb_clear(u);
}

mag_clear(xlo);
mag_clear(xhi);
mag_clear(ylo);
mag_clear(yhi);
mag_clear(t);

return status;
}
}

int
Expand Down
137 changes: 134 additions & 3 deletions src/gr/acf.c
Original file line number Diff line number Diff line change
Expand Up @@ -666,13 +666,144 @@ _gr_acf_cmp(int * res, const acf_t x, const acf_t y, const gr_ctx_t ctx)
return GR_SUCCESS;
}

/* ignores ctx, so we can pass in the acf context */
int
_gr_arf_cmpabs(int * res, const arf_t x, const arf_t y, const gr_ctx_t ctx);

#include "double_extras.h"

int
_gr_acf_cmpabs(int * res, const acf_t x, const acf_t y, const gr_ctx_t ctx)
{
if (!arf_is_zero(acf_imagref(x)) || !arf_is_zero(acf_imagref(y)))
return GR_UNABLE;
arf_srcptr a = acf_realref(x);
arf_srcptr b = acf_imagref(x);
arf_srcptr c = acf_realref(y);
arf_srcptr d = acf_imagref(y);

if (arf_is_zero(b))
{
if (arf_is_zero(d))
return _gr_arf_cmpabs(res, a, c, ctx);
if (arf_is_zero(c))
return _gr_arf_cmpabs(res, a, d, ctx);
if (arf_is_zero(a))
{
*res = -1;
return GR_SUCCESS;
}
}

if (arf_is_zero(a))
{
if (arf_is_zero(d))
return _gr_arf_cmpabs(res, b, c, ctx);
if (arf_is_zero(c))
return _gr_arf_cmpabs(res, b, d, ctx);
}

if (arf_is_zero(c) && arf_is_zero(d))
{
*res = 1;
return GR_SUCCESS;
}

if (ARF_IS_LAGOM(a) && ARF_IS_LAGOM(b) && ARF_IS_LAGOM(c) && ARF_IS_LAGOM(d))
{
slong aexp, bexp, cexp, dexp, xexp, yexp, exp;

aexp = arf_is_zero(a) ? WORD_MIN : ARF_EXP(a);
bexp = arf_is_zero(b) ? WORD_MIN : ARF_EXP(b);
cexp = arf_is_zero(c) ? WORD_MIN : ARF_EXP(c);
dexp = arf_is_zero(d) ? WORD_MIN : ARF_EXP(d);

/* 0.5 * 2^xexp <= |x| < sqrt(2) * 2^xexp */
xexp = FLINT_MAX(aexp, bexp);
/* 0.5 * 2^yexp <= |y| < sqrt(2) * 2^yexp */
yexp = FLINT_MAX(cexp, dexp);

if (xexp + 2 < yexp)
{
*res = -1;
return GR_SUCCESS;
}

if (xexp > yexp + 2)
{
*res = 1;
return GR_SUCCESS;
}

exp = FLINT_MAX(xexp, yexp);

double tt, xx = 0.0, yy = 0.0;
nn_srcptr xp;
slong xn;

if (aexp >= exp - 53)
{
ARF_GET_MPN_READONLY(xp, xn, a);
tt = d_mul_2exp_inrange(xp[xn - 1], aexp - exp - FLINT_BITS);
xx += tt * tt;
}

if (bexp >= exp - 53)
{
ARF_GET_MPN_READONLY(xp, xn, b);
tt = d_mul_2exp_inrange(xp[xn - 1], bexp - exp - FLINT_BITS);
xx += tt * tt;
}

if (cexp >= exp - 53)
{
ARF_GET_MPN_READONLY(xp, xn, c);
tt = d_mul_2exp_inrange(xp[xn - 1], cexp - exp - FLINT_BITS);
yy += tt * tt;
}

if (dexp >= exp - 53)
{
ARF_GET_MPN_READONLY(xp, xn, d);
tt = d_mul_2exp_inrange(xp[xn - 1], dexp - exp - FLINT_BITS);
yy += tt * tt;
}

if (xx < yy * 0.999999)
{
*res = -1;
return GR_SUCCESS;
}

if (xx * 0.999999 > yy)
{
*res = 1;
return GR_SUCCESS;
}
}

arf_struct s[5];

arf_init(s + 0);
arf_init(s + 1);
arf_init(s + 2);
arf_init(s + 3);
arf_init(s + 4);

arf_mul(s + 0, a, a, ARF_PREC_EXACT, ARF_RND_DOWN);
arf_mul(s + 1, b, b, ARF_PREC_EXACT, ARF_RND_DOWN);
arf_mul(s + 2, c, c, ARF_PREC_EXACT, ARF_RND_DOWN);
arf_mul(s + 3, d, d, ARF_PREC_EXACT, ARF_RND_DOWN);
arf_neg(s + 2, s + 2);
arf_neg(s + 3, s + 3);
arf_sum(s + 4, s, 4, 30, ARF_RND_DOWN);

*res = arf_sgn(s + 4);

arf_clear(s + 0);
arf_clear(s + 1);
arf_clear(s + 2);
arf_clear(s + 3);
arf_clear(s + 4);

*res = arf_cmpabs(acf_realref(x), acf_realref(y));
return GR_SUCCESS;
}

Expand Down
7 changes: 3 additions & 4 deletions src/gr/test_ring.c
Original file line number Diff line number Diff line change
Expand Up @@ -2904,7 +2904,7 @@ gr_test_ordered_ring_cmpabs(gr_ctx_t R, flint_rand_t state, int test_flags)
status = GR_TEST_FAIL;
}

if (status & GR_DOMAIN && !(status & GR_UNABLE))
if (gr_ctx_is_ordered_ring(R) == T_TRUE && (status & GR_DOMAIN && !(status & GR_UNABLE)))
{
status = GR_TEST_FAIL;
}
Expand Down Expand Up @@ -4315,10 +4315,9 @@ gr_test_ring(gr_ctx_t R, slong iters, int test_flags)
gr_test_iter(R, state, "pow: ui/si/fmpz/fmpq", gr_test_pow_type_variants, iters, test_flags & (~GR_TEST_ALWAYS_ABLE));

if (gr_ctx_is_ordered_ring(R) == T_TRUE)
{
gr_test_iter(R, state, "ordered_ring_cmp", gr_test_ordered_ring_cmp, iters, test_flags);
gr_test_iter(R, state, "ordered_ring_cmpabs", gr_test_ordered_ring_cmpabs, iters, test_flags);
}

gr_test_iter(R, state, "ordered_ring_cmpabs", gr_test_ordered_ring_cmpabs, iters, test_flags);

gr_test_iter(R, state, "numerator_denominator", gr_test_numerator_denominator, iters, test_flags);
gr_test_iter(R, state, "complex_parts", gr_test_complex_parts, iters, test_flags);
Expand Down
9 changes: 9 additions & 0 deletions src/nfloat.h
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,11 @@ int nfloat_mat_mul_waksman(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ct
int nfloat_mat_mul_block(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong min_block_size, gr_ctx_t ctx);
int nfloat_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx);

int nfloat_mat_nonsingular_solve_tril(gr_mat_t X, const gr_mat_t L, const gr_mat_t B, int unit, gr_ctx_t ctx);
int nfloat_mat_nonsingular_solve_triu(gr_mat_t X, const gr_mat_t L, const gr_mat_t B, int unit, gr_ctx_t ctx);
int nfloat_mat_lu(slong * rank, slong * P, gr_mat_t LU, const gr_mat_t A, int rank_check, gr_ctx_t ctx);


/* Complex numbers */
/* Note: we use the same context data for real and complex rings
(only which_ring and sizeof_elem differ). This allows us to call
Expand Down Expand Up @@ -569,6 +574,10 @@ int nfloat_complex_mat_mul_block(gr_mat_t C, const gr_mat_t A, const gr_mat_t B,
int nfloat_complex_mat_mul_reorder(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx);
int nfloat_complex_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx);

int nfloat_complex_mat_nonsingular_solve_tril(gr_mat_t X, const gr_mat_t L, const gr_mat_t B, int unit, gr_ctx_t ctx);
int nfloat_complex_mat_nonsingular_solve_triu(gr_mat_t X, const gr_mat_t L, const gr_mat_t B, int unit, gr_ctx_t ctx);
int nfloat_complex_mat_lu(slong * rank, slong * P, gr_mat_t LU, const gr_mat_t A, int rank_check, gr_ctx_t ctx);

#ifdef __cplusplus
}
#endif
Expand Down
Loading

0 comments on commit 05340d2

Please sign in to comment.