Skip to content

Commit

Permalink
Add test + expose more private functions
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Aug 22, 2023
1 parent 1921748 commit e2e26b3
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 20 deletions.
32 changes: 31 additions & 1 deletion include/cuco/detail/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,6 @@ class open_addressing_ref_impl {
}
}

private:
/**
* @brief Compares the content of the address `address` (old value) with the `expected` value and,
* only if they are the same, sets the content of `address` to `desired`.
Expand Down Expand Up @@ -653,6 +652,37 @@ class open_addressing_ref_impl {
}
}

/**
* @brief Gets the sentinel used to represent an empty slot.
*
* @return The sentinel value used to represent an empty slot
*/
[[nodiscard]] __device__ constexpr value_type empty_slot_sentinel() const noexcept
{
return empty_slot_sentinel_;
}

/**
* @brief Gets the probing scheme.
*
* @return The probing scheme used for the container
*/
[[nodiscard]] __device__ constexpr probing_scheme_type const& probing_scheme() const noexcept
{
return probing_scheme_;
}

/**
* @brief Gets the non-owning storage ref.
*
* @return The non-owning storage ref of the container
*/
[[nodiscard]] __device__ constexpr storage_ref_type storage_ref() const noexcept
{
return storage_ref_;
}

private:
/**
* @brief Inserts the specified element with one single CAS operation.
*
Expand Down
30 changes: 15 additions & 15 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -258,14 +258,16 @@ class operator_impl<
*/
__device__ void insert_or_assign(value_type const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto const key = value.first;

static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
auto probing_iter = ref_.impl_.probing_scheme_(key, ref_.impl_.storage_ref_.window_extent());

ref_type& ref_ = static_cast<ref_type&>(*this);
auto const key = value.first;
auto& probing_scheme = ref_.impl_.probing_scheme();
auto storage_ref = ref_.impl_.storage_ref();
auto probing_iter = probing_scheme(key, storage_ref.window_extent());

while (true) {
auto const window_slots = ref_._impl_.storage_ref_[*probing_iter];
auto const window_slots = storage_ref[*probing_iter];

for (auto& slot_content : window_slots) {
auto const eq_res = ref_.predicate_(slot_content, key);
Expand All @@ -274,9 +276,7 @@ class operator_impl<
if (eq_res == detail::equal_result::EMPTY) {
auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content);
if (attempt_insert_or_assign(
(ref_.impl_.storage_ref_.data() + *probing_iter)->data() + intra_window_index,
value,
ref_.predicate_)) {
(storage_ref.data() + *probing_iter)->data() + intra_window_index, value)) {
return;
}
}
Expand All @@ -298,13 +298,14 @@ class operator_impl<
value_type const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto const key = value.first;

auto probing_iter =
ref_.impl_.probing_scheme_(group, key, ref_.impl_.storage_ref_.window_extent());
auto const key = value.first;
auto& probing_scheme = ref_.impl_.probing_scheme();
auto storage_ref = ref_.impl_.storage_ref();
auto probing_iter = probing_scheme(key, storage_ref.window_extent());

while (true) {
auto const window_slots = ref_.impl_.storage_ref_[*probing_iter];
auto const window_slots = storage_ref[*probing_iter];

auto const [state, intra_window_index] = [&]() {
for (auto i = 0; i < window_size; ++i) {
Expand All @@ -326,8 +327,7 @@ class operator_impl<
auto const status =
(group.thread_rank() == src_lane)
? attempt_insert_or_assign(
(ref_.impl_.storage_ref_.data() + *probing_iter)->data() + intra_window_index,
value)
(storage_ref.data() + *probing_iter)->data() + intra_window_index, value)
: false;

// Exit if inserted or assigned
Expand Down Expand Up @@ -355,7 +355,7 @@ class operator_impl<
value_type const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto const expected_key = ref_.impl_.empty_slot_sentinel_.first;
auto const expected_key = ref_.impl_.empty_slot_sentinel().first;

auto old_key = ref_.impl_.compare_and_swap(&slot->first, expected_key, value.first);
auto* old_key_ptr = reinterpret_cast<key_type*>(&old_key);
Expand Down
6 changes: 3 additions & 3 deletions include/cuco/static_map.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -493,9 +493,9 @@ class static_map {
* @return Pair of iterators indicating the last elements in the output
*/
template <typename KeyOut, typename ValueOut>
[[nodiscard]] std::pair<KeyOut, ValueOut> retrieve_all(KeyOut keys_out,
ValueOut values_out,
cuda_stream_ref stream = {}) const;
std::pair<KeyOut, ValueOut> retrieve_all(KeyOut keys_out,
ValueOut values_out,
cuda_stream_ref stream = {}) const;

/**
* @brief Gets the number of elements in the container.
Expand Down
2 changes: 1 addition & 1 deletion include/cuco/static_set.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ class static_set {
* @return Iterator indicating the end of the output
*/
template <typename OutputIt>
[[nodiscard]] OutputIt retrieve_all(OutputIt output_begin, cuda_stream_ref stream = {}) const;
OutputIt retrieve_all(OutputIt output_begin, cuda_stream_ref stream = {}) const;

/**
* @brief Gets the number of elements in the container.
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ ConfigureTest(STATIC_MAP_TEST
static_map/erase_test.cu
static_map/heterogeneous_lookup_test.cu
static_map/insert_and_find_test.cu
static_map/insert_or_assign_test.cu
static_map/key_sentinel_test.cu
static_map/shared_memory_test.cu
static_map/stream_test.cu
Expand Down

0 comments on commit e2e26b3

Please sign in to comment.