Skip to content

Commit

Permalink
Fix a bug in CG probing
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Aug 23, 2023
1 parent e2e26b3 commit 84aebec
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 3 deletions.
25 changes: 22 additions & 3 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,14 @@ class operator_impl<
for (auto& slot_content : window_slots) {
auto const eq_res = ref_.predicate_(slot_content, key);

// If the key is already in the container, return false
// If the key is already in the container, update the payload and return
if (eq_res == detail::equal_result::EQUAL) {
auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content);
ref_.impl_.atomic_store(
&((storage_ref.data() + *probing_iter)->data() + intra_window_index)->second,
value.second);
return;
}
if (eq_res == detail::equal_result::EMPTY) {
auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content);
if (attempt_insert_or_assign(
Expand Down Expand Up @@ -302,7 +309,7 @@ class operator_impl<
auto const key = value.first;
auto& probing_scheme = ref_.impl_.probing_scheme();
auto storage_ref = ref_.impl_.storage_ref();
auto probing_iter = probing_scheme(key, storage_ref.window_extent());
auto probing_iter = probing_scheme(group, key, storage_ref.window_extent());

while (true) {
auto const window_slots = storage_ref[*probing_iter];
Expand All @@ -321,14 +328,26 @@ class operator_impl<
return detail::window_probing_results{detail::equal_result::UNEQUAL, -1};
}();

auto const group_contains_equal = group.ballot(state == detail::equal_result::EQUAL);
if (group_contains_equal) {
auto const src_lane = __ffs(group_contains_equal) - 1;
if (group.thread_rank() == src_lane) {
ref_.impl_.atomic_store(
&((storage_ref.data() + *probing_iter)->data() + intra_window_index)->second,
value.second);
}
group.sync();
return;
}

auto const group_contains_empty = group.ballot(state == detail::equal_result::EMPTY);
if (group_contains_empty) {
auto const src_lane = __ffs(group_contains_empty) - 1;
auto const status =
(group.thread_rank() == src_lane)
? attempt_insert_or_assign(
(storage_ref.data() + *probing_iter)->data() + intra_window_index, value)
: false;
: true;

// Exit if inserted or assigned
if (group.shfl(status, src_lane)) { return; }
Expand Down
114 changes: 114 additions & 0 deletions tests/static_map/insert_or_assign_test.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <utils.hpp>

#include <cuco/static_map.cuh>

#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/sort.h>

#include <catch2/catch_template_test_macros.hpp>

using size_type = std::size_t;

/**
 * @brief Checks that `insert_or_assign` overwrites the payloads of keys already in the map.
 *
 * The map is first filled with the pairs `{i, i}` for `i` in `[0, num_keys)`. The same keys are
 * then passed to `insert_or_assign` with payload `2 * i`. Afterwards the map must still hold
 * exactly `num_keys` entries, and every key `i` must be associated with `2 * i`.
 *
 * @tparam Map Map type under test; must expose `key_type` and `mapped_type`
 *
 * @param map The map to exercise
 * @param num_keys Number of distinct keys inserted and subsequently reassigned
 */
template <typename Map>
__inline__ void test_insert_or_assign(Map& map, size_type num_keys)
{
  using Key   = typename Map::key_type;
  using Value = typename Map::mapped_type;

  // Initial content: key i -> payload i
  auto pairs_begin =
    thrust::make_transform_iterator(thrust::counting_iterator<size_type>(0),
                                    [] __device__(auto i) { return cuco::pair<Key, Value>(i, i); });

  auto const initial_size = map.insert(pairs_begin, pairs_begin + num_keys);
  REQUIRE(initial_size == num_keys);  // all keys should be inserted

  // Query pairs have the same keys but different payloads
  auto query_pairs_begin = thrust::make_transform_iterator(
    thrust::counting_iterator<size_type>(0),
    [] __device__(auto i) { return cuco::pair<Key, Value>(i, i * 2); });

  map.insert_or_assign(query_pairs_begin, query_pairs_begin + num_keys);

  auto const updated_size = map.size();
  // all keys are present in the map so the size shouldn't change
  REQUIRE(updated_size == initial_size);

  // Payloads are of the mapped type, so the value buffer must be a vector of `Value`, not `Key`
  thrust::device_vector<Key> d_keys(num_keys);
  thrust::device_vector<Value> d_values(num_keys);
  map.retrieve_all(d_keys.begin(), d_values.begin());

  // Sort the payloads along with their keys so the key -> payload association is verified,
  // not merely the multiset of payloads
  thrust::sort_by_key(thrust::device, d_keys.begin(), d_keys.end(), d_values.begin());

  auto gold_keys_begin   = thrust::counting_iterator<size_type>(0);
  auto gold_values_begin = thrust::make_transform_iterator(thrust::counting_iterator<size_type>(0),
                                                           [] __device__(auto i) { return i * 2; });

  REQUIRE(
    cuco::test::equal(d_keys.begin(), d_keys.end(), gold_keys_begin, thrust::equal_to<Key>{}));
  REQUIRE(cuco::test::equal(
    d_values.begin(), d_values.end(), gold_values_begin, thrust::equal_to<Value>{}));
}

// Runs the insert-or-assign check for every listed combination of key type, mapped type,
// probing scheme and `CGSize` (the first template argument handed to the probing scheme).
TEMPLATE_TEST_CASE_SIG(
  "Insert or assign",
  "",
  ((typename Key, typename Value, cuco::test::probe_sequence Probe, int CGSize),
   Key,
   Value,
   Probe,
   CGSize),
  (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1),
  (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1),
  (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2),
  (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2),
  (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1),
  (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1),
  (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2),
  (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2),
  (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 1),
  (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1),
  (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2),
  (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2),
  (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1),
  (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1),
  (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2),
  (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2))
{
  // The same count is used as the requested map extent and as the number of keys inserted
  constexpr size_type key_count{400};

  // Candidate probing schemes; both hash with murmurhash3_32 over the key type
  using hasher         = cuco::murmurhash3_32<Key>;
  using linear_scheme  = cuco::experimental::linear_probing<CGSize, hasher>;
  using double_scheme  = cuco::experimental::double_hashing<CGSize, hasher, hasher>;
  using probing_scheme = std::
    conditional_t<Probe == cuco::test::probe_sequence::linear_probing, linear_scheme, double_scheme>;

  using map_type = cuco::experimental::static_map<Key,
                                                  Value,
                                                  cuco::experimental::extent<size_type>,
                                                  cuda::thread_scope_device,
                                                  thrust::equal_to<Key>,
                                                  probing_scheme,
                                                  cuco::cuda_allocator<std::byte>,
                                                  cuco::experimental::storage<2>>;

  // -1 serves as the empty sentinel for both keys and payloads
  map_type map{key_count, cuco::empty_key<Key>{-1}, cuco::empty_value<Value>{-1}};

  test_insert_or_assign(map, key_count);
}

0 comments on commit 84aebec

Please sign in to comment.