Skip to content

Commit

Permalink
Add group erase and erase kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Oct 6, 2023
1 parent cb17023 commit 8aa07fa
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 17 deletions.
16 changes: 8 additions & 8 deletions include/cuco/detail/common_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -154,28 +154,28 @@ __global__ void insert_if_n(
*
* @tparam CGSize Number of threads in each CG
* @tparam BlockSize Number of threads in each block
* @tparam InputIterator Device accessible input iterator whose `value_type` is
* @tparam InputIt Device accessible input iterator whose `value_type` is
* convertible to the `value_type` of the data structure
* @tparam Ref Type of non-owning device ref allowing access to storage
*
* @param first Beginning of the sequence of input elements
* @param n Number of input elements
* @param ref Non-owning container device ref used to access the slot storage
*/
template <int32_t CGSize, int32_t BlockSize, typename InputIterator, typename Ref>
__global__ void erase(InputIterator first, cuco::detail::index_type n, Ref ref)
template <int32_t CGSize, int32_t BlockSize, typename InputIt, typename Ref>
__global__ void erase(InputIt first, cuco::detail::index_type n, Ref ref)
{
auto const loop_stride = cuco::detail::grid_stride() / CGSize;
auto idx = cuco::detail::global_thread_id() / CGSize;

while (idx < n) {
typename Ref::value_type const erase_element{*(first + idx)};
typename std::iterator_traits<InputIt>::value_type const& erase_element{*(first + idx)};
if constexpr (CGSize == 1) {
// ref.insert(insert_pair);
ref.erase(erase_element);
} else {
// auto const tile =
// cooperative_groups::tiled_partition<CGSize>(cooperative_groups::this_thread_block());
// ref.insert(tile, insert_pair);
auto const tile =
cooperative_groups::tiled_partition<CGSize>(cooperative_groups::this_thread_block());
ref.erase(tile, erase_element);
}
idx += loop_stride;
}
Expand Down
93 changes: 84 additions & 9 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,15 @@ class open_addressing_ref_impl {
}
}

/**
* @brief Erases an element.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param value The element to erase
*
* @return True if the given element is successfully erased
*/
template <typename Value>
__device__ bool erase(Value const& value) noexcept
{
Expand All @@ -506,25 +515,77 @@ class open_addressing_ref_impl {
// Key exists, return true if successfully deleted
if (eq_res == detail::equal_result::EQUAL) {
auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content);
auto const erased_slot = [&]() {
if constexpr (this->has_payload) {
return cuco::pair{this->erased_key_sentinel(), empty_slot_sentinel_.second};
} else {
return this->erased_key_sentinel();
}
}();
switch (attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index,
slot_content,
erased_slot)) {
case insert_result::CONTINUE: continue;
this->erased_slot_sentinel())) {
case insert_result::SUCCESS: return true;
case insert_result::DUPLICATE: return false;
default: continue;
}
}
}
++probing_iter;
}
}

/**
* @brief Erases an element.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param group The Cooperative Group used to perform group erase
* @param value The element to erase
*
* @return True if the given element is successfully erased
*/
template <typename Value>
__device__ bool erase(cooperative_groups::thread_block_tile<cg_size> const& group,
Value const& value) noexcept
{
auto const key = this->extract_key(value);
auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent());

while (true) {
auto const window_slots = storage_ref_[*probing_iter];

auto const [state, intra_window_index] = [&]() {
for (auto i = 0; i < window_size; ++i) {
switch (this->predicate_(this->extract_key(window_slots[i]), key)) {
case detail::equal_result::AVAILABLE:
return window_probing_results{detail::equal_result::AVAILABLE, i};
case detail::equal_result::EQUAL:
return window_probing_results{detail::equal_result::EQUAL, i};
default: continue;
}
}
// returns dummy index `-1` for UNEQUAL
return window_probing_results{detail::equal_result::UNEQUAL, -1};
}();

auto const group_contains_equal = group.ballot(state == detail::equal_result::EQUAL);
if (group_contains_equal) {
auto const src_lane = __ffs(group_contains_equal) - 1;
auto const status =
(group.thread_rank() == src_lane)
? attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index,
window_slots[src_lane],
this->erased_slot_sentinel())
: insert_result::CONTINUE;

switch (group.shfl(status, src_lane)) {
case insert_result::SUCCESS: return true;
case insert_result::DUPLICATE: return false;
default: continue;
}
} else if (group.any(state == detail::equal_result::AVAILABLE)) {
// Key doesn't exist, return false
return false;
} else {
++probing_iter;
}
}
}

/**
* @brief Indicates whether the probe key `key` was inserted into the container.
*
Expand Down Expand Up @@ -811,6 +872,20 @@ class open_addressing_ref_impl {
return value.second;
}

/**
* @brief Gets the sentinel used to represent an erased slot.
*
* @return The sentinel value used to represent an erased slot
*/
[[nodiscard]] __device__ constexpr value_type const& erased_slot_sentinel() const noexcept
{
if constexpr (this->has_payload) {
return cuco::pair{this->erased_key_sentinel(), this->extract_payload()};
} else {
return this->erased_key_sentinel();
}
}

/**
* @brief Inserts the specified element with one single CAS operation.
*
Expand Down

0 comments on commit 8aa07fa

Please sign in to comment.