Skip to content

Commit

Permalink
Cache hicSparse handle to avoid library re-initialization.
Browse files Browse the repository at this point in the history
  • Loading branch information
l90lpa committed Nov 15, 2024
1 parent 862836a commit f23b744
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions src/atlas/linalg/sparse/SparseMatrixMultiply_HicSparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,20 @@

namespace {

class HicSparseHandleRAIIWrapper {
public:
HicSparseHandleRAIIWrapper() { hicsparseCreate(&handle_); };
~HicSparseHandleRAIIWrapper() { hicsparseDestroy(handle_); }
hicsparseHandle_t value() { return handle_; }
private:
hicsparseHandle_t handle_;
};

hicsparseHandle_t getDefaultHicSparseHandle() {
static auto handle = HicSparseHandleRAIIWrapper();
return handle.value();
}

template<typename T>
constexpr hicsparseIndexType_t getHicsparseIndexType() {
using base_type = std::remove_const_t<T>;
Expand Down Expand Up @@ -86,10 +100,7 @@ void hsSpMV(const SparseMatrix& W, const View<SourceValue, 1>& src, TargetValue
W.updateDevice();
}

// Create sparse library handle
// todo: use singleton class for storing hicSparse library handle.
hicsparseHandle_t handle;
HICSPARSE_CALL(hicsparseCreate(&handle));
auto handle = getDefaultHicSparseHandle();

// Create a sparse matrix descriptor
hicsparseConstSpMatDescr_t matA;
Expand Down Expand Up @@ -159,7 +170,6 @@ void hsSpMV(const SparseMatrix& W, const View<SourceValue, 1>& src, TargetValue
HICSPARSE_CALL(hicsparseDestroyDnVec(vecX));
HICSPARSE_CALL(hicsparseDestroyDnVec(vecY));
HICSPARSE_CALL(hicsparseDestroySpMat(matA));
HICSPARSE_CALL(hicsparseDestroy(handle));

HIC_CALL(hicDeviceSynchronize());
}
Expand All @@ -181,10 +191,7 @@ void hsSpMM(const SparseMatrix& W, const View<SourceValue, 2>& src, TargetValue
W.updateDevice();
}

// Create sparse library handle
// todo: use singleton class for storing hicSparse library handle.
hicsparseHandle_t handle;
HICSPARSE_CALL(hicsparseCreate(&handle));
auto handle = getDefaultHicSparseHandle();

// Create a sparse matrix descriptor
hicsparseConstSpMatDescr_t matA;
Expand Down Expand Up @@ -260,7 +267,6 @@ void hsSpMM(const SparseMatrix& W, const View<SourceValue, 2>& src, TargetValue
HICSPARSE_CALL(hicsparseDestroyDnMat(matC));
HICSPARSE_CALL(hicsparseDestroyDnMat(matB));
HICSPARSE_CALL(hicsparseDestroySpMat(matA));
HICSPARSE_CALL(hicsparseDestroy(handle));

HIC_CALL(hicDeviceSynchronize());
}
Expand Down

0 comments on commit f23b744

Please sign in to comment.