Skip to content

Commit

Permalink
Merge pull request #2062 from fredrik-johansson/matmul
Browse files Browse the repository at this point in the history
Fixed-point matrix multiplication improvements
  • Loading branch information
fredrik-johansson authored Sep 6, 2024
2 parents 06c1720 + 2737830 commit 6c38679
Show file tree
Hide file tree
Showing 27 changed files with 4,073 additions and 1,003 deletions.
59 changes: 55 additions & 4 deletions doc/source/nfloat.rst
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,7 @@ code for reduced overhead.
Matrix functions
-------------------------------------------------------------------------------

.. function:: int nfloat_mat_mul_fixed_classical(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
int nfloat_mat_mul_waksman(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
.. function:: int nfloat_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong max_extra_prec, gr_ctx_t ctx)
int nfloat_mat_mul_block(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong min_block_size, gr_ctx_t ctx)
int nfloat_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)

Expand Down Expand Up @@ -416,11 +415,63 @@ real pairs.
int _nfloat_complex_vec_set(nfloat_complex_ptr res, nfloat_complex_srcptr x, slong len, gr_ctx_t ctx)
int _nfloat_complex_vec_add(nfloat_complex_ptr res, nfloat_complex_srcptr x, nfloat_complex_srcptr y, slong len, gr_ctx_t ctx)
int _nfloat_complex_vec_sub(nfloat_complex_ptr res, nfloat_complex_srcptr x, nfloat_complex_srcptr y, slong len, gr_ctx_t ctx)
int nfloat_complex_mat_mul_fixed_classical(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
int nfloat_complex_mat_mul_waksman(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
int nfloat_complex_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong max_extra_prec, gr_ctx_t ctx)
int nfloat_complex_mat_mul_block(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong min_block_size, gr_ctx_t ctx)
int nfloat_complex_mat_mul_reorder(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
int nfloat_complex_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
int nfloat_complex_mat_nonsingular_solve_tril(gr_mat_t X, const gr_mat_t L, const gr_mat_t B, int unit, gr_ctx_t ctx)
int nfloat_complex_mat_nonsingular_solve_triu(gr_mat_t X, const gr_mat_t L, const gr_mat_t B, int unit, gr_ctx_t ctx)
int nfloat_complex_mat_lu(slong * rank, slong * P, gr_mat_t LU, const gr_mat_t A, int rank_check, gr_ctx_t ctx)

Packed fixed-point arithmetic
-------------------------------------------------------------------------------

A fixed-point number in the range `(-1,1)` with `n`-limb precision
is represented as `n+1` contiguous limbs as follows:

+---------------+
| sign limb |
+---------------+
| mantissa[0] |
+---------------+
| ... |
+---------------+
| mantissa[n-1] |
+---------------+

In the following method signatures, ``nlimbs`` always refers to the
precision ``n`` while the storage is ``nlimbs + 1``.

There is no overflow handling: all methods assume that inputs have
been scaled to a range `[-\varepsilon,\varepsilon]` so that all
intermediate results (including rounding errors) lie in `(-1,1)`.

.. function:: void _nfixed_print(nn_srcptr x, slong nlimbs, slong exp)

Print the fixed-point number

.. function:: void _nfixed_vec_add(nn_ptr res, nn_srcptr a, nn_srcptr b, slong len, slong nlimbs)
void _nfixed_vec_sub(nn_ptr res, nn_srcptr a, nn_srcptr b, slong len, slong nlimbs)

Vectorized addition or subtraction of *len* fixed-point numbers.

.. function:: void _nfixed_dot_2(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len)
void _nfixed_dot_3(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len)
void _nfixed_dot_4(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len)
void _nfixed_dot_5(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len)
void _nfixed_dot_6(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len)
void _nfixed_dot_7(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len)
void _nfixed_dot_8(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len)

Dot product with a fixed number of limbs. The ``xstride`` and ``ystride`` parameters
indicate the offset in number of limbs between consecutive entries
and may be negative.

.. function:: void _nfixed_mat_mul_classical(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs)
void _nfixed_mat_mul_waksman(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs)
void _nfixed_mat_mul_strassen(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong cutoff, slong nlimbs)
void _nfixed_mat_mul(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs)

Matrix multiplication using various algorithms.
The *strassen* variant takes a *cutoff* parameter specifying where
to switch from basecase multiplication to Strassen multiplication.
Loading

0 comments on commit 6c38679

Please sign in to comment.