Skip to content

Commit

Permalink
Make build() a private member
Browse files Browse the repository at this point in the history
Query methods (test, rank, select) now call build() themselves; build() rebuilds the rank and select indexes only if they are not already built
  • Loading branch information
amukkara committed Sep 6, 2023
1 parent b2c88de commit a629730
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 29 deletions.
21 changes: 11 additions & 10 deletions include/cuco/detail/trie/dynamic_bitset/dynamic_bitset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,6 @@ class dynamic_bitset {
*/
constexpr void set_last(bool value) noexcept;

/**
* @brief Builds indexes for rank and select
*/
constexpr void build() noexcept;

/**
* @brief For any element `keys_begin[i]` in the range `[keys_begin, keys_end)`, stores the
* boolean value at position `keys_begin[i]` to `output_begin[i]`.
Expand All @@ -146,7 +141,7 @@ class dynamic_bitset {
constexpr void test(KeyIt keys_begin,
KeyIt keys_end,
OutputIt outputs_begin,
cuda_stream_ref stream = {}) const noexcept;
cuda_stream_ref stream = {}) noexcept;

/**
* @brief For any element `keys_begin[i]` in the range `[keys_begin, keys_end)`, stores total
Expand All @@ -166,7 +161,7 @@ class dynamic_bitset {
constexpr void rank(KeyIt keys_begin,
KeyIt keys_end,
OutputIt outputs_begin,
cuda_stream_ref stream = {}) const noexcept;
cuda_stream_ref stream = {}) noexcept;

/**
* @brief For any element `keys_begin[i]` in the range `[keys_begin, keys_end)`, stores the
Expand All @@ -186,7 +181,7 @@ class dynamic_bitset {
constexpr void select(KeyIt keys_begin,
KeyIt keys_end,
OutputIt outputs_begin,
cuda_stream_ref stream = {}) const noexcept;
cuda_stream_ref stream = {}) noexcept;

using rank_type = cuco::experimental::detail::rank; ///< Rank type

Expand Down Expand Up @@ -335,6 +330,7 @@ class dynamic_bitset {

allocator_type allocator_; ///< Words allocator
size_type n_bits_; ///< Number of bits dynamic_bitset currently holds
bool is_built_; ///< Flag indicating whether the rank and select indices are built or not

/// Words vector that represents all bits
thrust::device_vector<word_type, allocator_type> words_;
Expand All @@ -348,11 +344,16 @@ class dynamic_bitset {
thrust::device_vector<size_type, size_allocator_type> selects_false_;

/**
* @brief Populates rank and select indexes on device
* @brief Builds indexes for rank and select
*/
constexpr void build() noexcept;

/**
* @brief Populates rank and select indexes for true or false bits
*
* @param ranks Output array of ranks
* @param selects Output array of selects
* @param flip_bits If true, negate bits to construct indexes for `0` bits
* @param flip_bits If true, negate bits to construct indexes for false bits
*/
constexpr void build_ranks_and_selects(
thrust::device_vector<rank_type, rank_allocator_type>& ranks,
Expand Down
18 changes: 13 additions & 5 deletions include/cuco/detail/trie/dynamic_bitset/dynamic_bitset.inl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ template <class Allocator>
constexpr dynamic_bitset<Allocator>::dynamic_bitset(Allocator const& allocator)
: allocator_{allocator},
n_bits_{0},
is_built_{false},
words_{allocator},
ranks_true_{allocator},
ranks_false_{allocator},
Expand All @@ -53,6 +54,7 @@ constexpr void dynamic_bitset<Allocator>::push_back(bool bit) noexcept
template <class Allocator>
constexpr void dynamic_bitset<Allocator>::set(size_type index, bool bit) noexcept
{
is_built_ = false;
size_type word_id = index / bits_per_word;
size_type bit_id = index % bits_per_word;
if (bit) {
Expand All @@ -73,9 +75,10 @@ template <typename KeyIt, typename OutputIt>
constexpr void dynamic_bitset<Allocator>::test(KeyIt keys_begin,
KeyIt keys_end,
OutputIt outputs_begin,
cuda_stream_ref stream) const noexcept
cuda_stream_ref stream) noexcept

{
build();
auto const num_keys = cuco::detail::distance(keys_begin, keys_end);
if (num_keys == 0) { return; }

Expand All @@ -90,9 +93,10 @@ template <typename KeyIt, typename OutputIt>
constexpr void dynamic_bitset<Allocator>::rank(KeyIt keys_begin,
KeyIt keys_end,
OutputIt outputs_begin,
cuda_stream_ref stream) const noexcept
cuda_stream_ref stream) noexcept

{
build();
auto const num_keys = cuco::detail::distance(keys_begin, keys_end);
if (num_keys == 0) { return; }

Expand All @@ -107,9 +111,10 @@ template <typename KeyIt, typename OutputIt>
constexpr void dynamic_bitset<Allocator>::select(KeyIt keys_begin,
KeyIt keys_end,
OutputIt outputs_begin,
cuda_stream_ref stream) const noexcept
cuda_stream_ref stream) noexcept

{
build();
auto const num_keys = cuco::detail::distance(keys_begin, keys_end);
if (num_keys == 0) { return; }

Expand Down Expand Up @@ -178,8 +183,11 @@ constexpr void dynamic_bitset<Allocator>::build_ranks_and_selects(
template <class Allocator>
constexpr void dynamic_bitset<Allocator>::build() noexcept
{
build_ranks_and_selects(ranks_true_, selects_true_, false); // 1 bits
build_ranks_and_selects(ranks_false_, selects_false_, true); // 0 bits
if (not is_built_) {
build_ranks_and_selects(ranks_true_, selects_true_, false); // 1 bits
build_ranks_and_selects(ranks_false_, selects_false_, true); // 0 bits
is_built_ = true;
}
}

template <class Allocator>
Expand Down
1 change: 0 additions & 1 deletion tests/dynamic_bitset/find_next_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ TEST_CASE("Find next set test", "")
for (size_type i = 0; i < num_elements; i++) {
bv.push_back(modulo_bitgen(i));
}
bv.build();

thrust::device_vector<size_type> device_result(num_elements);
auto ref = bv.ref();
Expand Down
19 changes: 10 additions & 9 deletions tests/dynamic_bitset/get_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,23 +46,24 @@ TEST_CASE("Get test", "")
bv.push_back(modulo_bitgen(i));
num_set_ref += modulo_bitgen(i);
}
bv.build();

// Device-ref test
auto ref = bv.ref();
thrust::device_vector<size_type> test_result(num_elements);
test_kernel<<<1, 1024>>>(ref, num_elements, test_result.data());

size_type num_set = thrust::reduce(thrust::device, test_result.begin(), test_result.end(), 0);
REQUIRE(num_set == num_set_ref);

// Host-bulk test
thrust::device_vector<size_type> keys(num_elements);
thrust::sequence(keys.begin(), keys.end(), 0);

thrust::device_vector<size_type> test_result(num_elements);
thrust::fill(test_result.begin(), test_result.end(), 0);

bv.test(keys.begin(), keys.end(), test_result.begin());

size_type num_set = thrust::reduce(thrust::device, test_result.begin(), test_result.end(), 0);
REQUIRE(num_set == num_set_ref);

// Device-ref test
auto ref = bv.ref();
thrust::fill(test_result.begin(), test_result.end(), 0);
test_kernel<<<1, 1024>>>(ref, num_elements, test_result.data());

num_set = thrust::reduce(thrust::device, test_result.begin(), test_result.end(), 0);
REQUIRE(num_set == num_set_ref);
}
1 change: 0 additions & 1 deletion tests/dynamic_bitset/rank_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ TEST_CASE("Rank test", "")
for (size_type i = 0; i < num_elements; i++) {
bv.push_back(modulo_bitgen(i));
}
bv.build();

thrust::device_vector<size_type> keys(num_elements);
thrust::sequence(keys.begin(), keys.end(), 0);
Expand Down
3 changes: 1 addition & 2 deletions tests/dynamic_bitset/select_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ TEST_CASE("Select test", "")
bv.push_back(modulo_bitgen(i));
num_set += modulo_bitgen(i);
}
bv.build();
auto ref = bv.ref();

// Check select
{
Expand Down Expand Up @@ -79,6 +77,7 @@ TEST_CASE("Select test", "")
{
size_type num_not_set = num_elements - num_set;

auto ref = bv.ref();
thrust::device_vector<size_type> device_result(num_not_set);
select_false_kernel<<<1, 1024>>>(ref, num_not_set, device_result.data());
thrust::host_vector<size_type> host_result = device_result;
Expand Down
1 change: 0 additions & 1 deletion tests/dynamic_bitset/size_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ TEST_CASE("Size computation", "")
for (size_type i = 0; i < num_elements; i++) {
bv.push_back(i % 2 == 0); // Alternate 0s and 1s pattern
}
bv.build();

auto size = bv.size();
REQUIRE(size == num_elements);
Expand Down

0 comments on commit a629730

Please sign in to comment.