Skip to content

Commit

Permalink
Add hicSparse backend to sparse matrix multiply.
Browse files Browse the repository at this point in the history
  • Loading branch information
l90lpa committed Nov 21, 2024
1 parent 427fc01 commit 62303e6
Show file tree
Hide file tree
Showing 8 changed files with 677 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/atlas/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,8 @@ linalg/sparse/SparseMatrixMultiply.h
linalg/sparse/SparseMatrixMultiply.tcc
linalg/sparse/SparseMatrixMultiply_EckitLinalg.h
linalg/sparse/SparseMatrixMultiply_EckitLinalg.cc
linalg/sparse/SparseMatrixMultiply_HicSparse.h
linalg/sparse/SparseMatrixMultiply_HicSparse.cc
linalg/sparse/SparseMatrixMultiply_OpenMP.h
linalg/sparse/SparseMatrixMultiply_OpenMP.cc
linalg/dense.h
Expand Down
5 changes: 5 additions & 0 deletions src/atlas/linalg/sparse/Backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ struct eckit_linalg : Backend {
static std::string type() { return "eckit_linalg"; }
eckit_linalg(): Backend(type()) {}
};

struct hicsparse : Backend {
static std::string type() { return "hicsparse"; }
hicsparse(): Backend(type()) {}
};
} // namespace backend


Expand Down
6 changes: 6 additions & 0 deletions src/atlas/linalg/sparse/SparseMatrixMultiply.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ void sparse_matrix_multiply( const Matrix& matrix, const SourceView& src, Target
#endif
sparse::dispatch_sparse_matrix_multiply<sparse::backend::eckit_linalg>( matrix, src, tgt, indexing, util::Config("backend",type) );
}
else if ( type == sparse::backend::hicsparse::type() ) {
sparse::dispatch_sparse_matrix_multiply<sparse::backend::hicsparse>( matrix, src, tgt, indexing, config );
}
else {
throw_NotImplemented( "sparse_matrix_multiply cannot be performed with unsupported backend [" + type + "]",
Here() );
Expand Down Expand Up @@ -160,6 +163,9 @@ void sparse_matrix_multiply_add( const Matrix& matrix, const SourceView& src, Ta
#endif
sparse::dispatch_sparse_matrix_multiply_add<sparse::backend::eckit_linalg>( matrix, src, tgt, indexing, util::Config("backend",type) );
}
else if ( type == sparse::backend::hicsparse::type() ) {
sparse::dispatch_sparse_matrix_multiply_add<sparse::backend::hicsparse>( matrix, src, tgt, indexing, config );
}
else {
throw_NotImplemented( "sparse_matrix_multiply_add cannot be performed with unsupported backend [" + type + "]",
Here() );
Expand Down
324 changes: 324 additions & 0 deletions src/atlas/linalg/sparse/SparseMatrixMultiply_HicSparse.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,324 @@
/*
* (C) Copyright 2024 ECMWF.
*
* This software is licensed under the terms of the Apache Licence Version 2.0
* which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
* In applying this licence, ECMWF does not waive the privileges and immunities
* granted to it by virtue of its status as an intergovernmental organisation
* nor does it submit to any jurisdiction.
*/

#include "atlas/linalg/sparse/SparseMatrixMultiply_HicSparse.h"

#include "atlas/parallel/omp/omp.h"
#include "atlas/runtime/Exception.h"

#include "hic/hic.h"
#include "hic/hic_library_types.h"
#include "hic/hicsparse.h"

namespace {

class HicSparseHandleRAIIWrapper {
public:
HicSparseHandleRAIIWrapper() { hicsparseCreate(&handle_); };
~HicSparseHandleRAIIWrapper() { hicsparseDestroy(handle_); }
hicsparseHandle_t value() { return handle_; }
private:
hicsparseHandle_t handle_;
};

hicsparseHandle_t getDefaultHicSparseHandle() {
static auto handle = HicSparseHandleRAIIWrapper();
return handle.value();
}

template<typename T>
constexpr hicsparseIndexType_t getHicsparseIndexType() {
using base_type = std::remove_const_t<T>;
if constexpr (std::is_same_v<base_type, int>) {
return HICSPARSE_INDEX_32I;
} else {
static_assert(std::is_same_v<base_type, long>, "Unsupported index type");
return HICSPARSE_INDEX_64I;
}
}

template<typename T>
constexpr auto getHicsparseValueType() {
using base_type = std::remove_const_t<T>;
if constexpr (std::is_same_v<base_type, float>) {
return HIC_R_32F;
} else {
static_assert(std::is_same_v<base_type, double>, "Unsupported value type");\
return HIC_R_64F;
}
}

template<atlas::linalg::Indexing IndexLayout, typename T>
hicsparseOrder_t getHicsparseOrder(const atlas::linalg::View<T, 2>& v) {
constexpr int row_idx = (IndexLayout == atlas::linalg::Indexing::layout_left) ? 0 : 1;
constexpr int col_idx = (IndexLayout == atlas::linalg::Indexing::layout_left) ? 1 : 0;

if (v.stride(row_idx) == 1) {
return HICSPARSE_ORDER_COL;
} else if (v.stride(col_idx) == 1) {
return HICSPARSE_ORDER_ROW;
} else {
atlas::throw_Exception("Unsupported dense matrix memory order", Here());
return HICSPARSE_ORDER_COL;
}
}

template<typename T>
int64_t getLeadingDimension(const atlas::linalg::View<T, 2>& v) {
if (v.stride(0) == 1) {
return v.stride(1);
} else if (v.stride(1) == 1) {
return v.stride(0);
} else {
atlas::throw_Exception("Unsupported dense matrix memory order", Here());
return 0;
}
}

}

namespace atlas {
namespace linalg {
namespace sparse {

template <typename SourceValue, typename TargetValue>
void hsSpMV(const SparseMatrix& W, const View<SourceValue, 1>& src, TargetValue beta, View<TargetValue, 1>& tgt) {
// Assume that src and tgt are device views

ATLAS_ASSERT(src.shape(0) >= W.cols());
ATLAS_ASSERT(tgt.shape(0) >= W.rows());

// Check if W is on the device and if not, copy it to the device
if (W.deviceNeedsUpdate()) {
W.updateDevice();
}

auto handle = getDefaultHicSparseHandle();

// Create a sparse matrix descriptor
hicsparseConstSpMatDescr_t matA;
HICSPARSE_CALL(hicsparseCreateConstCsr(
&matA,
W.rows(), W.cols(), W.nonZeros(),
W.device_outer(), // row_offsets
W.device_inner(), // column_indices
W.device_data(), // values
getHicsparseIndexType<SparseMatrix::Index>(),
getHicsparseIndexType<SparseMatrix::Index>(),
HICSPARSE_INDEX_BASE_ZERO,
getHicsparseValueType<SparseMatrix::Scalar>()));

// Create dense matrix descriptors
hicsparseConstDnVecDescr_t vecX;
HICSPARSE_CALL(hicsparseCreateConstDnVec(
&vecX,
static_cast<int64_t>(W.cols()),
src.data(),
getHicsparseValueType<typename View<SourceValue, 1>::value_type>()));

hicsparseDnVecDescr_t vecY;
HICSPARSE_CALL(hicsparseCreateDnVec(
&vecY,
W.rows(),
tgt.data(),
getHicsparseValueType<typename View<TargetValue, 1>::value_type>()));

using ComputeType = typename View<TargetValue, 1>::value_type;
constexpr auto compute_type = getHicsparseValueType<ComputeType>();

ComputeType alpha = 1;

// Determine buffer size
size_t bufferSize = 0;
HICSPARSE_CALL(hicsparseSpMV_bufferSize(
handle,
HICSPARSE_OPERATION_NON_TRANSPOSE,
&alpha,
matA,
vecX,
&beta,
vecY,
compute_type,
HICSPARSE_SPMV_ALG_DEFAULT,
&bufferSize));

// Allocate buffer
char* buffer;
HIC_CALL(hicMalloc(&buffer, bufferSize));

// Perform SpMV
HICSPARSE_CALL(hicsparseSpMV(
handle,
HICSPARSE_OPERATION_NON_TRANSPOSE,
&alpha,
matA,
vecX,
&beta,
vecY,
compute_type,
HICSPARSE_SPMV_ALG_DEFAULT,
buffer));

HIC_CALL(hicFree(buffer));
HICSPARSE_CALL(hicsparseDestroyDnVec(vecX));
HICSPARSE_CALL(hicsparseDestroyDnVec(vecY));
HICSPARSE_CALL(hicsparseDestroySpMat(matA));

HIC_CALL(hicDeviceSynchronize());
}


template <Indexing IndexLayout, typename SourceValue, typename TargetValue>
void hsSpMM(const SparseMatrix& W, const View<SourceValue, 2>& src, TargetValue beta, View<TargetValue, 2>& tgt) {
// Assume that src and tgt are device views

constexpr int row_idx = (IndexLayout == Indexing::layout_left) ? 0 : 1;
constexpr int col_idx = (IndexLayout == Indexing::layout_left) ? 1 : 0;

ATLAS_ASSERT(src.shape(row_idx) >= W.cols());
ATLAS_ASSERT(tgt.shape(row_idx) >= W.rows());
ATLAS_ASSERT(src.shape(col_idx) == tgt.shape(col_idx));

// Check if W is on the device and if not, copy it to the device
if (W.deviceNeedsUpdate()) {
W.updateDevice();
}

auto handle = getDefaultHicSparseHandle();

// Create a sparse matrix descriptor
hicsparseConstSpMatDescr_t matA;
HICSPARSE_CALL(hicsparseCreateConstCsr(
&matA,
W.rows(), W.cols(), W.nonZeros(),
W.device_outer(), // row_offsets
W.device_inner(), // column_indices
W.device_data(), // values
getHicsparseIndexType<SparseMatrix::Index>(),
getHicsparseIndexType<SparseMatrix::Index>(),
HICSPARSE_INDEX_BASE_ZERO,
getHicsparseValueType<SparseMatrix::Scalar>()));

// Create dense matrix descriptors
hicsparseConstDnMatDescr_t matB;
HICSPARSE_CALL(hicsparseCreateConstDnMat(
&matB,
W.cols(), src.shape(col_idx),
getLeadingDimension(src),
src.data(),
getHicsparseValueType<typename View<SourceValue, 2>::value_type>(),
getHicsparseOrder<IndexLayout>(src)));

hicsparseDnMatDescr_t matC;
HICSPARSE_CALL(hicsparseCreateDnMat(
&matC,
W.rows(), tgt.shape(col_idx),
getLeadingDimension(tgt),
tgt.data(),
getHicsparseValueType<typename View<TargetValue, 2>::value_type>(),
getHicsparseOrder<IndexLayout>(tgt)));

using ComputeType = typename View<TargetValue, 2>::value_type;
constexpr auto compute_type = getHicsparseValueType<ComputeType>();

ComputeType alpha = 1;

// Determine buffer size
size_t bufferSize = 0;
HICSPARSE_CALL(hicsparseSpMM_bufferSize(
handle,
HICSPARSE_OPERATION_NON_TRANSPOSE,
HICSPARSE_OPERATION_NON_TRANSPOSE,
&alpha,
matA,
matB,
&beta,
matC,
compute_type,
HICSPARSE_SPMM_ALG_DEFAULT,
&bufferSize));

// Allocate buffer
char* buffer;
HIC_CALL(hicMalloc(&buffer, bufferSize));

// Perform SpMM
HICSPARSE_CALL(hicsparseSpMM(
handle,
HICSPARSE_OPERATION_NON_TRANSPOSE,
HICSPARSE_OPERATION_NON_TRANSPOSE,
&alpha,
matA,
matB,
&beta,
matC,
compute_type,
HICSPARSE_SPMM_ALG_DEFAULT,
buffer));

HIC_CALL(hicFree(buffer));
HICSPARSE_CALL(hicsparseDestroyDnMat(matC));
HICSPARSE_CALL(hicsparseDestroyDnMat(matB));
HICSPARSE_CALL(hicsparseDestroySpMat(matA));

HIC_CALL(hicDeviceSynchronize());
}

void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_left, 1, double const, double>::multiply(
const SparseMatrix& W, const View<double const, 1>& src, View<double, 1>& tgt, const Configuration&) {
double beta = 0;
hsSpMV(W, src, beta, tgt);
}

void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_left, 1, double const, double>::multiply_add(
const SparseMatrix& W, const View<double const, 1>& src, View<double, 1>& tgt, const Configuration&) {
double beta = 1;
hsSpMV(W, src, beta, tgt);
}

void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_left, 2, double const, double>::multiply(
const SparseMatrix& W, const View<double const, 2>& src, View<double, 2>& tgt, const Configuration&) {
double beta = 0;
hsSpMM<Indexing::layout_left>(W, src, beta, tgt);
}

void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_left, 2, double const, double>::multiply_add(
const SparseMatrix& W, const View<double const, 2>& src, View<double, 2>& tgt, const Configuration&) {
double beta = 1;
hsSpMM<Indexing::layout_left>(W, src, beta, tgt);
}

void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_right, 1, double const, double>::multiply(
const SparseMatrix& W, const View<double const, 1>& src, View<double, 1>& tgt, const Configuration&) {
double beta = 0;
hsSpMV(W, src, beta, tgt);
}

void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_right, 1, double const, double>::multiply_add(
const SparseMatrix& W, const View<double const, 1>& src, View<double, 1>& tgt, const Configuration&) {
double beta = 1;
hsSpMV(W, src, beta, tgt);
}

void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_right, 2, double const, double>::multiply(
const SparseMatrix& W, const View<double const, 2>& src, View<double, 2>& tgt, const Configuration&) {
double beta = 0;
hsSpMM<Indexing::layout_right>(W, src, beta, tgt);
}

void SparseMatrixMultiply<backend::hicsparse, Indexing::layout_right, 2, double const, double>::multiply_add(
const SparseMatrix& W, const View<double const, 2>& src, View<double, 2>& tgt, const Configuration&) {
double beta = 1;
hsSpMM<Indexing::layout_right>(W, src, beta, tgt);
}

} // namespace sparse
} // namespace linalg
} // namespace atlas
Loading

0 comments on commit 62303e6

Please sign in to comment.