diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index df2b0b03a..28f11fe3a 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -900,33 +900,15 @@ class open_addressing_ref_impl { value_type const& expected, Value const& desired) noexcept { - auto old = - compare_and_swap(address, this->empty_slot_sentinel_, static_cast(desired)); - auto* old_ptr = reinterpret_cast(&old); - auto const inserted = [&]() { - if constexpr (this->has_payload) { - // If it's a map implementation, compare keys only - return cuco::detail::bitwise_compare(old_ptr->first, this->empty_slot_sentinel_.first); - } else { - // If it's a set implementation, compare the whole slot content - return cuco::detail::bitwise_compare(*old_ptr, this->empty_slot_sentinel_); - } - }(); - if (inserted) { + auto old = compare_and_swap(address, expected, static_cast(desired)); + auto* old_ptr = reinterpret_cast(&old); + if (cuco::detail::bitwise_compare(this->extract_key(*old_ptr), this->extract_key(expected))) { return insert_result::SUCCESS; } else { - // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare - auto const res = [&]() { - if constexpr (this->has_payload) { - // If it's a map implementation, compare keys only - return this->predicate_.equal_to(old_ptr->first, desired.first); - } else { - // If it's a set implementation, compare the whole slot content - return this->predicate_.equal_to(*old_ptr, desired); - } - }(); - return res == detail::equal_result::EQUAL ? insert_result::DUPLICATE - : insert_result::CONTINUE; + return this->predicate_.equal_to(this->extract_key(*old_ptr), this->extract_key(desired)) == + detail::equal_result::EQUAL + ? insert_result::DUPLICATE + : insert_result::CONTINUE; } } @@ -948,8 +930,8 @@ class open_addressing_ref_impl { { using mapped_type = decltype(this->empty_slot_sentinel_.second); - auto const expected_key = this->empty_slot_sentinel_.first; - auto const expected_payload = this->empty_slot_sentinel_.second; + auto const expected_key = expected.first; + auto const expected_payload = expected.second; auto old_key = compare_and_swap(&address->first, expected_key, static_cast(desired.first)); @@ -996,7 +978,7 @@ class open_addressing_ref_impl { { using mapped_type = decltype(this->empty_slot_sentinel_.second); - auto const expected_key = this->empty_slot_sentinel_.first; + auto const expected_key = expected.first; auto old_key = compare_and_swap(&address->first, expected_key, static_cast(desired.first));