From 6b02154fdfd4d5da612a35def2f24445ad82fd37 Mon Sep 17 00:00:00 2001 From: Tobias Ribizel Date: Tue, 2 Apr 2024 15:59:08 +0200 Subject: [PATCH 1/2] add Csr lookup implementation without storage --- common/cuda_hip/matrix/csr_kernels.hpp.inc | 7 ++++ common/unified/matrix/csr_kernels.cpp | 4 +- core/matrix/csr_lookup.hpp | 47 ++++++++++++++++++++-- dpcpp/matrix/csr_kernels.dp.cpp | 10 ++++- omp/matrix/csr_kernels.cpp | 9 ++++- reference/matrix/csr_kernels.cpp | 14 +++++-- reference/test/matrix/csr_kernels.cpp | 30 ++++++++------ test/matrix/csr_kernels.cpp | 3 +- 8 files changed, 100 insertions(+), 24 deletions(-) diff --git a/common/cuda_hip/matrix/csr_kernels.hpp.inc b/common/cuda_hip/matrix/csr_kernels.hpp.inc index 2f6682b9743..324248d4265 100644 --- a/common/cuda_hip/matrix/csr_kernels.hpp.inc +++ b/common/cuda_hip/matrix/csr_kernels.hpp.inc @@ -1140,6 +1140,13 @@ __global__ __launch_bounds__(default_block_size) void build_csr_lookup( } return; } + // if hash lookup is not allowed, we are done here + if (!csr_lookup_allowed(allowed, sparsity_type::hash)) { + if (lane == 0) { + row_desc[row] = static_cast(sparsity_type::none); + } + return; + } // sparse hashmap storage // we need at least one unfilled entry to avoid infinite loops on search GKO_ASSERT(row_len < available_storage); diff --git a/common/unified/matrix/csr_kernels.cpp b/common/unified/matrix/csr_kernels.cpp index be42ec9fa50..761aefebb82 100644 --- a/common/unified/matrix/csr_kernels.cpp +++ b/common/unified/matrix/csr_kernels.cpp @@ -264,8 +264,10 @@ void build_lookup_offsets(std::shared_ptr exec, if (csr_lookup_allowed(allowed, sparsity_type::bitmap) && bitmap_storage <= hashmap_storage) { storage_offsets[row] = bitmap_storage; - } else { + } else if (csr_lookup_allowed(allowed, sparsity_type::hash)) { storage_offsets[row] = hashmap_storage; + } else { + storage_offsets[row] = 0; } } }, diff --git a/core/matrix/csr_lookup.hpp b/core/matrix/csr_lookup.hpp index faa4aef7a61..52045bc113b 100644 --- a/core/matrix/csr_lookup.hpp +++ b/core/matrix/csr_lookup.hpp @@ -25,6 +25,11 @@ namespace csr { * a single bit set. */ enum class sparsity_type { + /** + * The row has no precomputed lookup-information associated with it, so the + * nonzero location needs to be located from the column indices explicitly. + */ + none = 0, /** * The row is dense, i.e. it contains all entries in * `[min_col, min_col + storage_size)`. @@ -148,9 +153,9 @@ struct device_sparsity_lookup { return lookup_bitmap(col); case sparsity_type::hash: return lookup_hash(col); + default: + return lookup_search(col); } - GKO_ASSERT(false); - return invalid_index(); } /** @@ -176,7 +181,8 @@ struct device_sparsity_lookup { result = lookup_hash_unsafe(col); break; default: - GKO_ASSERT(false); + result = lookup_search_unsafe(col); + break; } GKO_ASSERT(local_cols[result] == col); return result; @@ -290,6 +296,41 @@ struct device_sparsity_lookup { // out_idx is either correct or invalid_index, the hashmap sentinel return out_idx; } + + GKO_ATTRIBUTES GKO_INLINE IndexType + lookup_search_unsafe(IndexType col) const + { + // binary search through the column indices + auto length = row_nnz; + IndexType offset{}; + while (length > 0) { + auto half_length = length / 2; + auto mid = offset + half_length; + // this finds the first index with column index >= col + auto pred = local_cols[mid] >= col; + length = pred ? half_length : length - (half_length + 1); + offset = pred ? offset : mid + 1; + } + return offset; + } + + GKO_ATTRIBUTES GKO_INLINE IndexType lookup_search(IndexType col) const + { + // binary search through the column indices + auto length = row_nnz; + IndexType offset{}; + while (length > 0) { + auto half_length = length / 2; + auto mid = offset + half_length; + // this finds the first index with column index >= col + auto pred = local_cols[mid] >= col; + length = pred ? half_length : length - (half_length + 1); + offset = pred ? offset : mid + 1; + } + return offset < row_nnz && local_cols[offset] == col + ? offset + : invalid_index(); + } }; diff --git a/dpcpp/matrix/csr_kernels.dp.cpp b/dpcpp/matrix/csr_kernels.dp.cpp index 773fa67b8a4..c8c8e898563 100644 --- a/dpcpp/matrix/csr_kernels.dp.cpp +++ b/dpcpp/matrix/csr_kernels.dp.cpp @@ -2799,6 +2799,7 @@ void build_lookup(std::shared_ptr exec, const IndexType* storage_offsets, int64* row_desc, int32* storage) { + using matrix::csr::sparsity_type; exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for(sycl::range<1>{num_rows}, [=](sycl::id<1> idx) { const auto row = static_cast(idx[0]); @@ -2820,8 +2821,13 @@ void build_lookup(std::shared_ptr exec, row_desc[row], local_storage, local_cols); } if (!done) { - csr_lookup_build_hash(row_len, available_storage, row_desc[row], - local_storage, local_cols); + if (csr_lookup_allowed(allowed, sparsity_type::hash)) { + csr_lookup_build_hash(row_len, available_storage, + row_desc[row], local_storage, + local_cols); + } else { + row_desc[row] = static_cast(sparsity_type::none); + } } }); }); diff --git a/omp/matrix/csr_kernels.cpp b/omp/matrix/csr_kernels.cpp index 6ff364de7eb..70df9f07944 100644 --- a/omp/matrix/csr_kernels.cpp +++ b/omp/matrix/csr_kernels.cpp @@ -1367,6 +1367,7 @@ void build_lookup(std::shared_ptr exec, const IndexType* storage_offsets, int64* row_desc, int32* storage) { + using matrix::csr::sparsity_type; #pragma omp parallel for for (size_type row = 0; row < num_rows; row++) { const auto row_begin = row_ptrs[row]; @@ -1386,8 +1387,12 @@ void build_lookup(std::shared_ptr exec, row_desc[row], local_storage, local_cols); } if (!done) { - csr_lookup_build_hash(row_len, available_storage, row_desc[row], - local_storage, local_cols); + if (csr_lookup_allowed(allowed, sparsity_type::hash)) { + csr_lookup_build_hash(row_len, available_storage, row_desc[row], + local_storage, local_cols); + } else { + row_desc[row] = static_cast(sparsity_type::none); + } } } } diff --git a/reference/matrix/csr_kernels.cpp b/reference/matrix/csr_kernels.cpp index 4cb7124d3c4..711efdc9175 100644 --- a/reference/matrix/csr_kernels.cpp +++ b/reference/matrix/csr_kernels.cpp @@ -31,6 +31,7 @@ #include "core/components/prefix_sum_kernels.hpp" #include "core/matrix/csr_accessor_helper.hpp" #include "core/matrix/csr_builder.hpp" +#include "core/matrix/csr_lookup.hpp" #include "reference/components/csr_spgeam.hpp" @@ -1297,8 +1298,10 @@ void build_lookup_offsets(std::shared_ptr exec, if (csr_lookup_allowed(allowed, sparsity_type::bitmap) && bitmap_storage <= hashmap_storage) { storage_offsets[row] = bitmap_storage; - } else { + } else if (csr_lookup_allowed(allowed, sparsity_type::hash)) { storage_offsets[row] = hashmap_storage; + } else { + storage_offsets[row] = 0; } } } @@ -1397,6 +1400,7 @@ void build_lookup(std::shared_ptr exec, const IndexType* storage_offsets, int64* row_desc, int32* storage) { + using matrix::csr::sparsity_type; for (size_type row = 0; row < num_rows; row++) { const auto row_begin = row_ptrs[row]; const auto row_len = row_ptrs[row + 1] - row_begin; @@ -1415,8 +1419,12 @@ void build_lookup(std::shared_ptr exec, row_desc[row], local_storage, local_cols); } if (!done) { - csr_lookup_build_hash(row_len, available_storage, row_desc[row], - local_storage, local_cols); + if (csr_lookup_allowed(allowed, sparsity_type::hash)) { + csr_lookup_build_hash(row_len, available_storage, row_desc[row], + local_storage, local_cols); + } else { + row_desc[row] = static_cast(sparsity_type::none); + } } } } diff --git a/reference/test/matrix/csr_kernels.cpp b/reference/test/matrix/csr_kernels.cpp index 11c3b521d76..a206c8c40c2 100644 --- a/reference/test/matrix/csr_kernels.cpp +++ b/reference/test/matrix/csr_kernels.cpp @@ -2598,6 +2598,7 @@ TYPED_TEST_SUITE(CsrLookup, gko::test::ValueIndexTypes, TYPED_TEST(CsrLookup, GeneratesLookupDataOffsets) { using IndexType = typename TestFixture::index_type; + using gko::matrix::csr::csr_lookup_allowed; using gko::matrix::csr::sparsity_type; const auto num_rows = this->mtx->get_size()[0]; gko::array storage_offset_array(this->exec, num_rows + 1); @@ -2608,19 +2609,19 @@ TYPED_TEST(CsrLookup, GeneratesLookupDataOffsets) for (auto allowed : {sparsity_type::full | sparsity_type::bitmap | sparsity_type::hash, sparsity_type::bitmap | sparsity_type::hash, - sparsity_type::full | sparsity_type::hash, sparsity_type::hash}) { + sparsity_type::full | sparsity_type::hash, sparsity_type::hash, + sparsity_type::none}) { gko::kernels::reference::csr::build_lookup_offsets( this->exec, row_ptrs, col_idxs, num_rows, allowed, storage_offsets); - bool allow_full = - gko::matrix::csr::csr_lookup_allowed(allowed, sparsity_type::full); - bool allow_bitmap = gko::matrix::csr::csr_lookup_allowed( - allowed, sparsity_type::bitmap); + bool allow_full = csr_lookup_allowed(allowed, sparsity_type::full); + bool allow_bitmap = csr_lookup_allowed(allowed, sparsity_type::bitmap); + bool allow_hash = csr_lookup_allowed(allowed, sparsity_type::hash); for (gko::size_type row = 0; row < num_rows; row++) { - const auto expected_size = - std::min(allow_full ? this->full_sizes[row] : 1000, - std::min(allow_bitmap ? this->bitmap_sizes[row] : 1000, - this->hash_sizes[row])); + const auto expected_size = std::min( + allow_full ? this->full_sizes[row] : 1000, + std::min(allow_bitmap ? this->bitmap_sizes[row] : 1000, + allow_hash ? this->hash_sizes[row] : IndexType{})); const auto size = storage_offsets[row + 1] - storage_offsets[row]; ASSERT_EQ(size, expected_size); @@ -2644,16 +2645,21 @@ TYPED_TEST(CsrLookup, GeneratesLookupData) for (auto allowed : {sparsity_type::full | sparsity_type::bitmap | sparsity_type::hash, sparsity_type::bitmap | sparsity_type::hash, - sparsity_type::full | sparsity_type::hash, sparsity_type::hash}) { + sparsity_type::full | sparsity_type::hash, sparsity_type::hash, + sparsity_type::none}) { gko::kernels::reference::csr::build_lookup_offsets( this->exec, row_ptrs, col_idxs, num_rows, allowed, storage_offsets); gko::array storage_array(this->exec, storage_offsets[num_rows]); const auto storage = storage_array.get_data(); + const auto hash_equivalent = + csr_lookup_allowed(allowed, sparsity_type::hash) + ? sparsity_type::hash + : sparsity_type::none; const auto bitmap_equivalent = csr_lookup_allowed(allowed, sparsity_type::bitmap) ? sparsity_type::bitmap - : sparsity_type::hash; + : hash_equivalent; const auto full_equivalent = csr_lookup_allowed(allowed, sparsity_type::full) ? sparsity_type::full @@ -2687,7 +2693,7 @@ TYPED_TEST(CsrLookup, GeneratesLookupData) ASSERT_EQ(row_descs[0] & 0xF, static_cast(full_equivalent)); ASSERT_EQ(row_descs[1] & 0xF, static_cast(full_equivalent)); ASSERT_EQ(row_descs[2] & 0xF, static_cast(bitmap_equivalent)); - ASSERT_EQ(row_descs[3] & 0xF, static_cast(sparsity_type::hash)); + ASSERT_EQ(row_descs[3] & 0xF, static_cast(hash_equivalent)); ASSERT_EQ(row_descs[4] & 0xF, static_cast(full_equivalent)); ASSERT_EQ(row_descs[5] & 0xF, static_cast(full_equivalent)); } diff --git a/test/matrix/csr_kernels.cpp b/test/matrix/csr_kernels.cpp index a597234382c..347425175bb 100644 --- a/test/matrix/csr_kernels.cpp +++ b/test/matrix/csr_kernels.cpp @@ -209,7 +209,8 @@ TYPED_TEST(CsrLookup, BuildLookupWorks) for (auto allowed : {sparsity_type::full | sparsity_type::bitmap | sparsity_type::hash, sparsity_type::bitmap | sparsity_type::hash, - sparsity_type::full | sparsity_type::hash, sparsity_type::hash}) { + sparsity_type::full | sparsity_type::hash, sparsity_type::hash, + sparsity_type::none}) { // check that storage offsets are calculated correctly // otherwise things might crash gko::kernels::reference::csr::build_lookup_offsets( From 0a1cdb4e6e6ea0427661844c28efa1240a510a73 Mon Sep 17 00:00:00 2001 From: Tobias Ribizel Date: Mon, 15 Apr 2024 07:49:01 +0200 Subject: [PATCH 2/2] simplify lookup --- core/matrix/csr_lookup.hpp | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/core/matrix/csr_lookup.hpp b/core/matrix/csr_lookup.hpp index 52045bc113b..6de3265ff21 100644 --- a/core/matrix/csr_lookup.hpp +++ b/core/matrix/csr_lookup.hpp @@ -316,17 +316,7 @@ struct device_sparsity_lookup { GKO_ATTRIBUTES GKO_INLINE IndexType lookup_search(IndexType col) const { - // binary search through the column indices - auto length = row_nnz; - IndexType offset{}; - while (length > 0) { - auto half_length = length / 2; - auto mid = offset + half_length; - // this finds the first index with column index >= col - auto pred = local_cols[mid] >= col; - length = pred ? half_length : length - (half_length + 1); - offset = pred ? offset : mid + 1; - } + const auto offset = lookup_search_unsafe(col); return offset < row_nnz && local_cols[offset] == col ? offset : invalid_index();