Skip to content

Commit

Permalink
add Csr lookup implementation without storage
Browse files Browse the repository at this point in the history
  • Loading branch information
upsj committed Apr 15, 2024
1 parent 577caef commit 6b02154
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 24 deletions.
7 changes: 7 additions & 0 deletions common/cuda_hip/matrix/csr_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,13 @@ __global__ __launch_bounds__(default_block_size) void build_csr_lookup(
}
return;
}
// if hash lookup is not allowed, we are done here
if (!csr_lookup_allowed(allowed, sparsity_type::hash)) {
if (lane == 0) {
row_desc[row] = static_cast<int64>(sparsity_type::none);
}
return;
}
// sparse hashmap storage
// we need at least one unfilled entry to avoid infinite loops on search
GKO_ASSERT(row_len < available_storage);
Expand Down
4 changes: 3 additions & 1 deletion common/unified/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,10 @@ void build_lookup_offsets(std::shared_ptr<const DefaultExecutor> exec,
if (csr_lookup_allowed(allowed, sparsity_type::bitmap) &&
bitmap_storage <= hashmap_storage) {
storage_offsets[row] = bitmap_storage;
} else {
} else if (csr_lookup_allowed(allowed, sparsity_type::hash)) {
storage_offsets[row] = hashmap_storage;
} else {
storage_offsets[row] = 0;
}
}
},
Expand Down
47 changes: 44 additions & 3 deletions core/matrix/csr_lookup.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ namespace csr {
* a single bit set.
*/
enum class sparsity_type {
/**
* The row has no precomputed lookup-information associated with it, so the
* nonzero location needs to be located from the column indices explicitly.
*/
none = 0,
/**
* The row is dense, i.e. it contains all entries in
* `[min_col, min_col + storage_size)`.
Expand Down Expand Up @@ -148,9 +153,9 @@ struct device_sparsity_lookup {
return lookup_bitmap(col);
case sparsity_type::hash:
return lookup_hash(col);
default:
return lookup_search(col);
}
GKO_ASSERT(false);
return invalid_index<IndexType>();
}

/**
Expand All @@ -176,7 +181,8 @@ struct device_sparsity_lookup {
result = lookup_hash_unsafe(col);
break;
default:
GKO_ASSERT(false);
result = lookup_search_unsafe(col);
break;
}
GKO_ASSERT(local_cols[result] == col);
return result;
Expand Down Expand Up @@ -290,6 +296,41 @@ struct device_sparsity_lookup {
// out_idx is either correct or invalid_index, the hashmap sentinel
return out_idx;
}

/**
 * Locates a column index by binary search over `local_cols`, without
 * verifying that the column is actually present.
 *
 * The search is only meaningful if the first `row_nnz` entries of
 * `local_cols` are sorted in ascending order.
 *
 * @param col  the column index to look for
 * @return the first position in `[0, row_nnz)` whose column index is
 *         `>= col`, or `row_nnz` if every stored column index is smaller.
 *         The result is NOT checked for equality with `col`.
 */
GKO_ATTRIBUTES GKO_INLINE IndexType
lookup_search_unsafe(IndexType col) const
{
    // lower-bound style bisection: [lower, lower + remaining) is the
    // window that may still contain the first entry >= col
    IndexType lower{};
    for (auto remaining = row_nnz; remaining > 0;) {
        const auto half = remaining / 2;
        const auto probe = lower + half;
        // true if the entry we are looking for is at or before probe
        const auto go_left = local_cols[probe] >= col;
        // selects instead of branches: either keep the left half, or
        // discard the left half together with the probed entry
        remaining = go_left ? half : remaining - (half + 1);
        lower = go_left ? lower : probe + 1;
    }
    return lower;
}

/**
 * Locates a column index by binary search over `local_cols` and verifies
 * that it is actually present.
 *
 * The search is only meaningful if the first `row_nnz` entries of
 * `local_cols` are sorted in ascending order.
 *
 * @param col  the column index to look for
 * @return the position of `col` inside `local_cols` if it is present,
 *         `invalid_index<IndexType>()` otherwise.
 */
GKO_ATTRIBUTES GKO_INLINE IndexType lookup_search(IndexType col) const
{
    // reuse the unchecked bisection instead of duplicating the loop; it
    // yields the first position with column index >= col, or row_nnz if
    // all stored column indices are smaller
    const auto offset = lookup_search_unsafe(col);
    // the bounds check must come first, since offset == row_nnz would
    // read past the end of the row
    return offset < row_nnz && local_cols[offset] == col
               ? offset
               : invalid_index<IndexType>();
}
};


Expand Down
10 changes: 8 additions & 2 deletions dpcpp/matrix/csr_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2799,6 +2799,7 @@ void build_lookup(std::shared_ptr<const DpcppExecutor> exec,
const IndexType* storage_offsets, int64* row_desc,
int32* storage)
{
using matrix::csr::sparsity_type;
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(sycl::range<1>{num_rows}, [=](sycl::id<1> idx) {
const auto row = static_cast<size_type>(idx[0]);
Expand All @@ -2820,8 +2821,13 @@ void build_lookup(std::shared_ptr<const DpcppExecutor> exec,
row_desc[row], local_storage, local_cols);
}
if (!done) {
csr_lookup_build_hash(row_len, available_storage, row_desc[row],
local_storage, local_cols);
if (csr_lookup_allowed(allowed, sparsity_type::hash)) {
csr_lookup_build_hash(row_len, available_storage,
row_desc[row], local_storage,
local_cols);
} else {
row_desc[row] = static_cast<int64>(sparsity_type::none);
}
}
});
});
Expand Down
9 changes: 7 additions & 2 deletions omp/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1367,6 +1367,7 @@ void build_lookup(std::shared_ptr<const DefaultExecutor> exec,
const IndexType* storage_offsets, int64* row_desc,
int32* storage)
{
using matrix::csr::sparsity_type;
#pragma omp parallel for
for (size_type row = 0; row < num_rows; row++) {
const auto row_begin = row_ptrs[row];
Expand All @@ -1386,8 +1387,12 @@ void build_lookup(std::shared_ptr<const DefaultExecutor> exec,
row_desc[row], local_storage, local_cols);
}
if (!done) {
csr_lookup_build_hash(row_len, available_storage, row_desc[row],
local_storage, local_cols);
if (csr_lookup_allowed(allowed, sparsity_type::hash)) {
csr_lookup_build_hash(row_len, available_storage, row_desc[row],
local_storage, local_cols);
} else {
row_desc[row] = static_cast<int64>(sparsity_type::none);
}
}
}
}
Expand Down
14 changes: 11 additions & 3 deletions reference/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "core/components/prefix_sum_kernels.hpp"
#include "core/matrix/csr_accessor_helper.hpp"
#include "core/matrix/csr_builder.hpp"
#include "core/matrix/csr_lookup.hpp"
#include "reference/components/csr_spgeam.hpp"


Expand Down Expand Up @@ -1297,8 +1298,10 @@ void build_lookup_offsets(std::shared_ptr<const ReferenceExecutor> exec,
if (csr_lookup_allowed(allowed, sparsity_type::bitmap) &&
bitmap_storage <= hashmap_storage) {
storage_offsets[row] = bitmap_storage;
} else {
} else if (csr_lookup_allowed(allowed, sparsity_type::hash)) {
storage_offsets[row] = hashmap_storage;
} else {
storage_offsets[row] = 0;
}
}
}
Expand Down Expand Up @@ -1397,6 +1400,7 @@ void build_lookup(std::shared_ptr<const ReferenceExecutor> exec,
const IndexType* storage_offsets, int64* row_desc,
int32* storage)
{
using matrix::csr::sparsity_type;
for (size_type row = 0; row < num_rows; row++) {
const auto row_begin = row_ptrs[row];
const auto row_len = row_ptrs[row + 1] - row_begin;
Expand All @@ -1415,8 +1419,12 @@ void build_lookup(std::shared_ptr<const ReferenceExecutor> exec,
row_desc[row], local_storage, local_cols);
}
if (!done) {
csr_lookup_build_hash(row_len, available_storage, row_desc[row],
local_storage, local_cols);
if (csr_lookup_allowed(allowed, sparsity_type::hash)) {
csr_lookup_build_hash(row_len, available_storage, row_desc[row],
local_storage, local_cols);
} else {
row_desc[row] = static_cast<int64>(sparsity_type::none);
}
}
}
}
Expand Down
30 changes: 18 additions & 12 deletions reference/test/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2598,6 +2598,7 @@ TYPED_TEST_SUITE(CsrLookup, gko::test::ValueIndexTypes,
TYPED_TEST(CsrLookup, GeneratesLookupDataOffsets)
{
using IndexType = typename TestFixture::index_type;
using gko::matrix::csr::csr_lookup_allowed;
using gko::matrix::csr::sparsity_type;
const auto num_rows = this->mtx->get_size()[0];
gko::array<IndexType> storage_offset_array(this->exec, num_rows + 1);
Expand All @@ -2608,19 +2609,19 @@ TYPED_TEST(CsrLookup, GeneratesLookupDataOffsets)
for (auto allowed :
{sparsity_type::full | sparsity_type::bitmap | sparsity_type::hash,
sparsity_type::bitmap | sparsity_type::hash,
sparsity_type::full | sparsity_type::hash, sparsity_type::hash}) {
sparsity_type::full | sparsity_type::hash, sparsity_type::hash,
sparsity_type::none}) {
gko::kernels::reference::csr::build_lookup_offsets(
this->exec, row_ptrs, col_idxs, num_rows, allowed, storage_offsets);
bool allow_full =
gko::matrix::csr::csr_lookup_allowed(allowed, sparsity_type::full);
bool allow_bitmap = gko::matrix::csr::csr_lookup_allowed(
allowed, sparsity_type::bitmap);
bool allow_full = csr_lookup_allowed(allowed, sparsity_type::full);
bool allow_bitmap = csr_lookup_allowed(allowed, sparsity_type::bitmap);
bool allow_hash = csr_lookup_allowed(allowed, sparsity_type::hash);

for (gko::size_type row = 0; row < num_rows; row++) {
const auto expected_size =
std::min(allow_full ? this->full_sizes[row] : 1000,
std::min(allow_bitmap ? this->bitmap_sizes[row] : 1000,
this->hash_sizes[row]));
const auto expected_size = std::min(
allow_full ? this->full_sizes[row] : 1000,
std::min(allow_bitmap ? this->bitmap_sizes[row] : 1000,
allow_hash ? this->hash_sizes[row] : IndexType{}));
const auto size = storage_offsets[row + 1] - storage_offsets[row];

ASSERT_EQ(size, expected_size);
Expand All @@ -2644,16 +2645,21 @@ TYPED_TEST(CsrLookup, GeneratesLookupData)
for (auto allowed :
{sparsity_type::full | sparsity_type::bitmap | sparsity_type::hash,
sparsity_type::bitmap | sparsity_type::hash,
sparsity_type::full | sparsity_type::hash, sparsity_type::hash}) {
sparsity_type::full | sparsity_type::hash, sparsity_type::hash,
sparsity_type::none}) {
gko::kernels::reference::csr::build_lookup_offsets(
this->exec, row_ptrs, col_idxs, num_rows, allowed, storage_offsets);
gko::array<gko::int32> storage_array(this->exec,
storage_offsets[num_rows]);
const auto storage = storage_array.get_data();
const auto hash_equivalent =
csr_lookup_allowed(allowed, sparsity_type::hash)
? sparsity_type::hash
: sparsity_type::none;
const auto bitmap_equivalent =
csr_lookup_allowed(allowed, sparsity_type::bitmap)
? sparsity_type::bitmap
: sparsity_type::hash;
: hash_equivalent;
const auto full_equivalent =
csr_lookup_allowed(allowed, sparsity_type::full)
? sparsity_type::full
Expand Down Expand Up @@ -2687,7 +2693,7 @@ TYPED_TEST(CsrLookup, GeneratesLookupData)
ASSERT_EQ(row_descs[0] & 0xF, static_cast<int>(full_equivalent));
ASSERT_EQ(row_descs[1] & 0xF, static_cast<int>(full_equivalent));
ASSERT_EQ(row_descs[2] & 0xF, static_cast<int>(bitmap_equivalent));
ASSERT_EQ(row_descs[3] & 0xF, static_cast<int>(sparsity_type::hash));
ASSERT_EQ(row_descs[3] & 0xF, static_cast<int>(hash_equivalent));
ASSERT_EQ(row_descs[4] & 0xF, static_cast<int>(full_equivalent));
ASSERT_EQ(row_descs[5] & 0xF, static_cast<int>(full_equivalent));
}
Expand Down
3 changes: 2 additions & 1 deletion test/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,8 @@ TYPED_TEST(CsrLookup, BuildLookupWorks)
for (auto allowed :
{sparsity_type::full | sparsity_type::bitmap | sparsity_type::hash,
sparsity_type::bitmap | sparsity_type::hash,
sparsity_type::full | sparsity_type::hash, sparsity_type::hash}) {
sparsity_type::full | sparsity_type::hash, sparsity_type::hash,
sparsity_type::none}) {
// check that storage offsets are calculated correctly
// otherwise things might crash
gko::kernels::reference::csr::build_lookup_offsets(
Expand Down

0 comments on commit 6b02154

Please sign in to comment.