Skip to content

Commit

Permalink
Add multiply_add function for sparse matrix linear algebra.
Browse files Browse the repository at this point in the history
  • Loading branch information
l90lpa committed Nov 15, 2024
1 parent fe692b6 commit 40f5b65
Show file tree
Hide file tree
Showing 10 changed files with 384 additions and 59 deletions.
43 changes: 40 additions & 3 deletions src/atlas/linalg/sparse/SparseMatrixMultiply.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,19 @@ template <typename Matrix, typename SourceView, typename TargetView>
void sparse_matrix_multiply(const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing,
const Configuration& config);

template <typename Matrix, typename SourceView, typename TargetView>
void sparse_matrix_multiply_add(const Matrix& matrix, const SourceView& src, TargetView& tgt);

template <typename Matrix, typename SourceView, typename TargetView>
void sparse_matrix_multiply_add(const Matrix& matrix, const SourceView& src, TargetView& tgt, const Configuration& config);

template <typename Matrix, typename SourceView, typename TargetView>
void sparse_matrix_multiply_add(const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing);

template <typename Matrix, typename SourceView, typename TargetView>
void sparse_matrix_multiply_add(const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing,
const Configuration& config);

class SparseMatrixMultiply {
public:
SparseMatrixMultiply() = default;
Expand All @@ -46,14 +59,34 @@ class SparseMatrixMultiply {

template <typename Matrix, typename SourceView, typename TargetView>
void operator()(const Matrix& matrix, const SourceView& src, TargetView& tgt) const {
sparse_matrix_multiply(matrix, src, tgt, backend());
multiply(matrix, src, tgt);
}

template <typename Matrix, typename SourceView, typename TargetView>
void operator()(const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing indexing) const {
multiply(matrix, src, tgt, indexing);
}

template <typename Matrix, typename SourceView, typename TargetView>
void multiply(const Matrix& matrix, const SourceView& src, TargetView& tgt) const {
sparse_matrix_multiply(matrix, src, tgt, backend());
}

template <typename Matrix, typename SourceView, typename TargetView>
void multiply(const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing indexing) const {
sparse_matrix_multiply(matrix, src, tgt, indexing, backend());
}

template <typename Matrix, typename SourceView, typename TargetView>
void multiply_add(const Matrix& matrix, const SourceView& src, TargetView& tgt) const {
sparse_matrix_multiply_add(matrix, src, tgt, backend());
}

template <typename Matrix, typename SourceView, typename TargetView>
void multiply_add(const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing indexing) const {
sparse_matrix_multiply_add(matrix, src, tgt, indexing, backend());
}

const sparse::Backend& backend() const { return backend_; }

private:
Expand All @@ -65,8 +98,12 @@ namespace sparse {
// Template class which needs (full or partial) specialization for concrete template parameters
template <typename Backend, Indexing, int Rank, typename SourceValue, typename TargetValue>
struct SparseMatrixMultiply {
static void apply(const SparseMatrix&, const View<SourceValue, Rank>&, View<TargetValue, Rank>&,
const Configuration&) {
static void multiply(const SparseMatrix&, const View<SourceValue, Rank>&, View<TargetValue, Rank>&,
const Configuration&) {
throw_NotImplemented("SparseMatrixMultiply needs a template specialization with the implementation", Here());
}
static void multiply_add(const SparseMatrix&, const View<SourceValue, Rank>&, View<TargetValue, Rank>&,
const Configuration&) {
throw_NotImplemented("SparseMatrixMultiply needs a template specialization with the implementation", Here());
}
};
Expand Down
86 changes: 81 additions & 5 deletions src/atlas/linalg/sparse/SparseMatrixMultiply.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,25 @@ namespace {
template <typename Backend, Indexing indexing>
struct SparseMatrixMultiplyHelper {
template <typename SourceView, typename TargetView>
static void apply( const SparseMatrix& W, const SourceView& src, TargetView& tgt,
static void multiply( const SparseMatrix& W, const SourceView& src, TargetView& tgt,
const eckit::Configuration& config ) {
using SourceValue = const typename std::remove_const<typename SourceView::value_type>::type;
using TargetValue = typename std::remove_const<typename TargetView::value_type>::type;
constexpr int src_rank = introspection::rank<SourceView>();
constexpr int tgt_rank = introspection::rank<TargetView>();
static_assert( src_rank == tgt_rank, "src and tgt need same rank" );
SparseMatrixMultiply<Backend, indexing, src_rank, SourceValue, TargetValue>::apply( W, src, tgt, config );
SparseMatrixMultiply<Backend, indexing, src_rank, SourceValue, TargetValue>::multiply( W, src, tgt, config );
}

template <typename SourceView, typename TargetView>
static void multiply_add( const SparseMatrix& W, const SourceView& src, TargetView& tgt,
const eckit::Configuration& config ) {
using SourceValue = const typename std::remove_const<typename SourceView::value_type>::type;
using TargetValue = typename std::remove_const<typename TargetView::value_type>::type;
constexpr int src_rank = introspection::rank<SourceView>();
constexpr int tgt_rank = introspection::rank<TargetView>();
static_assert( src_rank == tgt_rank, "src and tgt need same rank" );
SparseMatrixMultiply<Backend, indexing, src_rank, SourceValue, TargetValue>::multiply_add( W, src, tgt, config );
}
};

Expand All @@ -53,14 +64,38 @@ void dispatch_sparse_matrix_multiply( const Matrix& matrix, const SourceView& sr
if ( introspection::layout_right( src ) || introspection::layout_right( tgt ) ) {
ATLAS_ASSERT( introspection::layout_right( src ) && introspection::layout_right( tgt ) );
// Override layout with known layout given by introspection
SparseMatrixMultiplyHelper<Backend, linalg::Indexing::layout_right>::apply( matrix, src_v, tgt_v, config );
SparseMatrixMultiplyHelper<Backend, linalg::Indexing::layout_right>::multiply( matrix, src_v, tgt_v, config );
}
else {
if( indexing == Indexing::layout_left ) {
SparseMatrixMultiplyHelper<Backend, Indexing::layout_left>::multiply( matrix, src_v, tgt_v, config );
}
else if( indexing == Indexing::layout_right ) {
SparseMatrixMultiplyHelper<Backend, Indexing::layout_right>::multiply( matrix, src_v, tgt_v, config );
}
else {
throw_NotImplemented( "indexing not implemented", Here() );
}
}
}

template <typename Backend, typename Matrix, typename SourceView, typename TargetView>
void dispatch_sparse_matrix_multiply_add( const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing indexing,
const eckit::Configuration& config ) {
auto src_v = make_view( src );
auto tgt_v = make_view( tgt );

if ( introspection::layout_right( src ) || introspection::layout_right( tgt ) ) {
ATLAS_ASSERT( introspection::layout_right( src ) && introspection::layout_right( tgt ) );
// Override layout with known layout given by introspection
SparseMatrixMultiplyHelper<Backend, linalg::Indexing::layout_right>::multiply_add( matrix, src_v, tgt_v, config );
}
else {
if( indexing == Indexing::layout_left ) {
SparseMatrixMultiplyHelper<Backend, Indexing::layout_left>::apply( matrix, src_v, tgt_v, config );
SparseMatrixMultiplyHelper<Backend, Indexing::layout_left>::multiply_add( matrix, src_v, tgt_v, config );
}
else if( indexing == Indexing::layout_right ) {
SparseMatrixMultiplyHelper<Backend, Indexing::layout_right>::apply( matrix, src_v, tgt_v, config );
SparseMatrixMultiplyHelper<Backend, Indexing::layout_right>::multiply_add( matrix, src_v, tgt_v, config );
}
else {
throw_NotImplemented( "indexing not implemented", Here() );
Expand Down Expand Up @@ -111,6 +146,47 @@ void sparse_matrix_multiply( const Matrix& matrix, const SourceView& src, Target
sparse_matrix_multiply( matrix, src, tgt, Indexing::layout_left );
}

template <typename Matrix, typename SourceView, typename TargetView>
void sparse_matrix_multiply_add( const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing indexing,
const eckit::Configuration& config ) {
std::string type = config.getString( "type", sparse::current_backend() );
if ( type == sparse::backend::openmp::type() ) {
sparse::dispatch_sparse_matrix_multiply_add<sparse::backend::openmp>( matrix, src, tgt, indexing, config );
}
else if ( type == sparse::backend::eckit_linalg::type() ) {
sparse::dispatch_sparse_matrix_multiply_add<sparse::backend::eckit_linalg>( matrix, src, tgt, indexing, config );
}
#if ATLAS_ECKIT_HAVE_ECKIT_585
else if( eckit::linalg::LinearAlgebraSparse::hasBackend(type) ) {
#else
else if( eckit::linalg::LinearAlgebra::hasBackend(type) ) {
#endif
sparse::dispatch_sparse_matrix_multiply_add<sparse::backend::eckit_linalg>( matrix, src, tgt, indexing, util::Config("backend",type) );
}
else if ( type == sparse::backend::hicsparse::type() ) {
sparse::dispatch_sparse_matrix_multiply_add<sparse::backend::hicsparse>( matrix, src, tgt, indexing, config );
}
else {
throw_NotImplemented( "sparse_matrix_multiply_add cannot be performed with unsupported backend [" + type + "]",
Here() );
}
}

template <typename Matrix, typename SourceView, typename TargetView>
void sparse_matrix_multiply_add( const Matrix& matrix, const SourceView& src, TargetView& tgt, const eckit::Configuration& config ) {
sparse_matrix_multiply_add( matrix, src, tgt, Indexing::layout_left, config );
}

template <typename Matrix, typename SourceView, typename TargetView>
void sparse_matrix_multiply_add( const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing indexing ) {
sparse_matrix_multiply_add( matrix, src, tgt, indexing, sparse::Backend() );
}

template <typename Matrix, typename SourceView, typename TargetView>
void sparse_matrix_multiply_add( const Matrix& matrix, const SourceView& src, TargetView& tgt ) {
sparse_matrix_multiply_add( matrix, src, tgt, Indexing::layout_left );
}

} // namespace linalg
} // namespace atlas

Expand Down
47 changes: 43 additions & 4 deletions src/atlas/linalg/sparse/SparseMatrixMultiply_EckitLinalg.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "SparseMatrixMultiply_EckitLinalg.h"

#include "atlas/array.h"
#include "atlas/library/config.h"

#if ATLAS_ECKIT_HAVE_ECKIT_585
Expand Down Expand Up @@ -64,7 +65,7 @@ const eckit::linalg::LinearAlgebra& eckit_linalg_backend(const Configuration& co

} // namespace

void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 1, const double, double>::apply(
void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 1, const double, double>::multiply(
const SparseMatrix& W, const View<const double, 1>& src, View<double, 1>& tgt, const Configuration& config) {
ATLAS_ASSERT(src.contiguous());
ATLAS_ASSERT(tgt.contiguous());
Expand All @@ -73,7 +74,7 @@ void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 1, cons
eckit_linalg_backend(config).spmv(W.host_matrix(), v_src, v_tgt);
}

void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 2, const double, double>::apply(
void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 2, const double, double>::multiply(
const SparseMatrix& W, const View<const double, 2>& src, View<double, 2>& tgt, const Configuration& config) {
ATLAS_ASSERT(src.contiguous());
ATLAS_ASSERT(tgt.contiguous());
Expand All @@ -84,9 +85,47 @@ void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 2, cons
eckit_linalg_backend(config).spmm(W.host_matrix(), m_src, m_tgt);
}

void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_left, 1, const double, double>::apply(
void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_left, 1, const double, double>::multiply(
const SparseMatrix& W, const View<const double, 1>& src, View<double, 1>& tgt, const Configuration& config) {
SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 1, const double, double>::apply(W, src, tgt,
SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 1, const double, double>::multiply(W, src, tgt,
config);
}

void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 1, const double, double>::multiply_add(
const SparseMatrix& W, const View<const double, 1>& src, View<double, 1>& tgt, const Configuration& config) {

array::ArrayT<double> tmp(src.shape(0));
auto v_tmp_tmp = array::make_view<double, 1>(tmp);
v_tmp_tmp.assign(0.);
auto v_tmp = make_view(v_tmp_tmp);

SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 1, const double, double>::multiply(W, src, v_tmp, config);

for (idx_t t = 0; t < tmp.shape(0); ++t) {
tgt(t) += v_tmp(t);
}
}

void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 2, const double, double>::multiply_add(
const SparseMatrix& W, const View<const double, 2>& src, View<double, 2>& tgt, const Configuration& config) {

array::ArrayT<double> tmp(src.shape(0), src.shape(1));
auto v_tmp_tmp = array::make_view<double, 2>(tmp);
v_tmp_tmp.assign(0.);
auto v_tmp = make_view(v_tmp_tmp);

SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 2, const double, double>::multiply(W, src, v_tmp, config);

for (idx_t t = 0; t < tmp.shape(0); ++t) {
for (idx_t k = 0; k < tmp.shape(1); ++k) {
tgt(t, k) += v_tmp(t, k);
}
}
}

void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_left, 1, const double, double>::multiply_add(
const SparseMatrix& W, const View<const double, 1>& src, View<double, 1>& tgt, const Configuration& config) {
SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 1, const double, double>::multiply_add(W, src, tgt,
config);
}

Expand Down
12 changes: 9 additions & 3 deletions src/atlas/linalg/sparse/SparseMatrixMultiply_EckitLinalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,26 @@ namespace sparse {

template <>
struct SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 1, const double, double> {
static void apply(const SparseMatrix&, const View<const double, 1>& src, View<double, 1>& tgt,
static void multiply(const SparseMatrix&, const View<const double, 1>& src, View<double, 1>& tgt,
const Configuration&);
static void multiply_add(const SparseMatrix&, const View<const double, 1>& src, View<double, 1>& tgt,
const Configuration&);
};

template <>
struct SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 2, const double, double> {
static void apply(const SparseMatrix&, const View<const double, 2>& src, View<double, 2>& tgt,
static void multiply(const SparseMatrix&, const View<const double, 2>& src, View<double, 2>& tgt,
const Configuration&);
static void multiply_add(const SparseMatrix&, const View<const double, 2>& src, View<double, 2>& tgt,
const Configuration&);
};


template <>
struct SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_left, 1, const double, double> {
static void apply(const SparseMatrix&, const View<const double, 1>& src, View<double, 1>& tgt,
static void multiply(const SparseMatrix&, const View<const double, 1>& src, View<double, 1>& tgt,
const Configuration&);
static void multiply_add(const SparseMatrix&, const View<const double, 1>& src, View<double, 1>& tgt,
const Configuration&);
};

Expand Down
32 changes: 28 additions & 4 deletions src/atlas/linalg/sparse/SparseMatrixMultiply_HicSparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,30 +265,54 @@ void hsSpMM(const SparseMatrix& W, const View<SourceValue, 2>& src, TargetValue
HIC_CALL(hicDeviceSynchronize());
}

void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_left, 1, double const, double>::apply(
void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_left, 1, double const, double>::multiply(
const SparseMatrix& W, const View<double const, 1>& src, View<double, 1>& tgt, const Configuration&) {
double beta = 0;
hsSpMV(W, src, beta, tgt);
}

void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_left, 2, double const, double>::apply(
void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_left, 1, double const, double>::multiply_add(
const SparseMatrix& W, const View<double const, 1>& src, View<double, 1>& tgt, const Configuration&) {
double beta = 1;
hsSpMV(W, src, beta, tgt);
}

void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_left, 2, double const, double>::multiply(
const SparseMatrix& W, const View<double const, 2>& src, View<double, 2>& tgt, const Configuration&) {
double beta = 0;
hsSpMM<Indexing::layout_left>(W, src, beta, tgt);
}

void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_right, 1, double const, double>::apply(
void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_left, 2, double const, double>::multiply_add(
const SparseMatrix& W, const View<double const, 2>& src, View<double, 2>& tgt, const Configuration&) {
double beta = 1;
hsSpMM<Indexing::layout_left>(W, src, beta, tgt);
}

void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_right, 1, double const, double>::multiply(
const SparseMatrix& W, const View<double const, 1>& src, View<double, 1>& tgt, const Configuration&) {
double beta = 0;
hsSpMV(W, src, beta, tgt);
}

void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_right, 2, double const, double>::apply(
void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_right, 1, double const, double>::multiply_add(
const SparseMatrix& W, const View<double const, 1>& src, View<double, 1>& tgt, const Configuration&) {
double beta = 1;
hsSpMV(W, src, beta, tgt);
}

void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_right, 2, double const, double>::multiply(
const SparseMatrix& W, const View<double const, 2>& src, View<double, 2>& tgt, const Configuration&) {
double beta = 0;
hsSpMM<Indexing::layout_right>(W, src, beta, tgt);
}

void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_right, 2, double const, double>::multiply_add(
const SparseMatrix& W, const View<double const, 2>& src, View<double, 2>& tgt, const Configuration&) {
double beta = 1;
hsSpMM<Indexing::layout_right>(W, src, beta, tgt);
}

} // namespace sparse
} // namespace linalg
} // namespace atlas
Loading

0 comments on commit 40f5b65

Please sign in to comment.