diff --git a/include/cuco/detail/common_kernels.cuh b/include/cuco/detail/common_kernels.cuh index dedf1f910..cecd50735 100644 --- a/include/cuco/detail/common_kernels.cuh +++ b/include/cuco/detail/common_kernels.cuh @@ -154,7 +154,7 @@ __global__ void insert_if_n( * * @tparam CGSize Number of threads in each CG * @tparam BlockSize Number of threads in each block - * @tparam InputIterator Device accessible input iterator whose `value_type` is + * @tparam InputIt Device accessible input iterator whose `value_type` is * convertible to the `value_type` of the data structure * @tparam Ref Type of non-owning device ref allowing access to storage * @@ -162,20 +162,20 @@ __global__ void insert_if_n( * @param n Number of input elements * @param ref Non-owning container device ref used to access the slot storage */ -template -__global__ void erase(InputIterator first, cuco::detail::index_type n, Ref ref) +template +__global__ void erase(InputIt first, cuco::detail::index_type n, Ref ref) { auto const loop_stride = cuco::detail::grid_stride() / CGSize; auto idx = cuco::detail::global_thread_id() / CGSize; while (idx < n) { - typename Ref::value_type const erase_element{*(first + idx)}; + typename std::iterator_traits::value_type const& erase_element{*(first + idx)}; /* input element at offset idx */ if constexpr (CGSize == 1) { - // ref.insert(insert_pair); + ref.erase(erase_element); /* CGSize == 1: single-thread erase overload */ } else { - // auto const tile = - // cooperative_groups::tiled_partition(cooperative_groups::this_thread_block()); - // ref.insert(tile, insert_pair); + auto const tile = + cooperative_groups::tiled_partition(cooperative_groups::this_thread_block()); + ref.erase(tile, erase_element); /* otherwise: erase overload that takes the tile */ } idx += loop_stride; } diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 73037f424..57a9c4ebc 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ 
-487,6 +487,15 @@ class open_addressing_ref_impl { } } + /** + * @brief Erases an element. + * + * @tparam Value Input type which is implicitly convertible to 'value_type' + * + * @param value The element to erase + * + * @return True if the given element is successfully erased + */ template __device__ bool erase(Value const& value) noexcept { @@ -506,18 +515,12 @@ class open_addressing_ref_impl { // Key exists, return true if successfully deleted if (eq_res == detail::equal_result::EQUAL) { auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content); - auto const erased_slot = [&]() { - if constexpr (this->has_payload) { - return cuco::pair{this->erased_key_sentinel(), empty_slot_sentinel_.second}; - } else { - return this->erased_key_sentinel(); - } - }(); switch (attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index, slot_content, - erased_slot)) { - case insert_result::CONTINUE: continue; + this->erased_slot_sentinel())) { case insert_result::SUCCESS: return true; + case insert_result::DUPLICATE: return false; /* NOTE(review): presumably the slot changed before the swap - confirm against attempt_insert */ + default: continue; /* every other result re-enters the enclosing loop */ } } } @@ -525,6 +528,64 @@ class open_addressing_ref_impl { } } + /** + * @brief Erases an element. 
+ * + * @tparam Value Input type which is implicitly convertible to 'value_type' + * + * @param group The Cooperative Group used to perform group erase + * @param value The element to erase + * + * @return True if the given element is successfully erased + */ + template + __device__ bool erase(cooperative_groups::thread_block_tile const& group, + Value const& value) noexcept + { + auto const key = this->extract_key(value); + auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent()); + + while (true) { + auto const window_slots = storage_ref_[*probing_iter]; + + auto const [state, intra_window_index] = [&]() { + for (auto i = 0; i < window_size; ++i) { + switch (this->predicate_(this->extract_key(window_slots[i]), key)) { + case detail::equal_result::AVAILABLE: + return window_probing_results{detail::equal_result::AVAILABLE, i}; + case detail::equal_result::EQUAL: + return window_probing_results{detail::equal_result::EQUAL, i}; + default: continue; + } + } + // returns dummy index `-1` for UNEQUAL + return window_probing_results{detail::equal_result::UNEQUAL, -1}; + }(); + + auto const group_contains_equal = group.ballot(state == detail::equal_result::EQUAL); + if (group_contains_equal) { + auto const src_lane = __ffs(group_contains_equal) - 1; + auto const status = + (group.thread_rank() == src_lane) + ? attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index, + window_slots[intra_window_index], /* expected value is the slot at the target index, not at the lane id */ + this->erased_slot_sentinel()) + : insert_result::CONTINUE; + + switch (group.shfl(status, src_lane)) { + case insert_result::SUCCESS: return true; + case insert_result::DUPLICATE: return false; + default: continue; + } + } else if (group.any(state == detail::equal_result::AVAILABLE)) { + // Key doesn't exist, return false + return false; + } else { + ++probing_iter; + } + } + } + /** * @brief Indicates whether the probe key `key` was inserted into the container. 
* @@ -811,6 +872,20 @@ class open_addressing_ref_impl { return value.second; } + /** + * @brief Gets the sentinel used to represent an erased slot. + * + * @return The sentinel value used to represent an erased slot, returned by value + */ + [[nodiscard]] __device__ constexpr value_type erased_slot_sentinel() const noexcept + { + if constexpr (this->has_payload) { + return cuco::pair{this->erased_key_sentinel(), this->empty_slot_sentinel_.second}; /* temporary pair, so it must not be returned by reference */ + } else { + return this->erased_key_sentinel(); + } + } + /** * @brief Inserts the specified element with one single CAS operation. *