From 84aebec5d732b6269cb73e86f5a386e591f969a8 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Wed, 23 Aug 2023 11:13:44 -0700 Subject: [PATCH] Fix a bug in CG probing --- .../cuco/detail/static_map/static_map_ref.inl | 25 +++- tests/static_map/insert_or_assign_test.cu | 114 ++++++++++++++++++ 2 files changed, 136 insertions(+), 3 deletions(-) create mode 100644 tests/static_map/insert_or_assign_test.cu diff --git a/include/cuco/detail/static_map/static_map_ref.inl b/include/cuco/detail/static_map/static_map_ref.inl index 229a08654..cd2d22081 100644 --- a/include/cuco/detail/static_map/static_map_ref.inl +++ b/include/cuco/detail/static_map/static_map_ref.inl @@ -272,7 +272,14 @@ class operator_impl< for (auto& slot_content : window_slots) { auto const eq_res = ref_.predicate_(slot_content, key); - // If the key is already in the container, return false + // If the key is already in the container, update the payload and return + if (eq_res == detail::equal_result::EQUAL) { + auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content); + ref_.impl_.atomic_store( + &((storage_ref.data() + *probing_iter)->data() + intra_window_index)->second, + value.second); + return; + } if (eq_res == detail::equal_result::EMPTY) { auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content); if (attempt_insert_or_assign( @@ -302,7 +309,7 @@ class operator_impl< auto const key = value.first; auto& probing_scheme = ref_.impl_.probing_scheme(); auto storage_ref = ref_.impl_.storage_ref(); - auto probing_iter = probing_scheme(key, storage_ref.window_extent()); + auto probing_iter = probing_scheme(group, key, storage_ref.window_extent()); while (true) { auto const window_slots = storage_ref[*probing_iter]; @@ -321,6 +328,18 @@ class operator_impl< return detail::window_probing_results{detail::equal_result::UNEQUAL, -1}; }(); + auto const group_contains_equal = group.ballot(state == detail::equal_result::EQUAL); + if (group_contains_equal) { + auto const src_lane = __ffs(group_contains_equal) - 1; + if (group.thread_rank() == src_lane) { + ref_.impl_.atomic_store( + &((storage_ref.data() + *probing_iter)->data() + intra_window_index)->second, + value.second); + } + group.sync(); + return; + } + auto const group_contains_empty = group.ballot(state == detail::equal_result::EMPTY); if (group_contains_empty) { auto const src_lane = __ffs(group_contains_empty) - 1; @@ -328,7 +347,7 @@ class operator_impl< (group.thread_rank() == src_lane) ? attempt_insert_or_assign( (storage_ref.data() + *probing_iter)->data() + intra_window_index, value) - : false; + : true; // Exit if inserted or assigned if (group.shfl(status, src_lane)) { return; } diff --git a/tests/static_map/insert_or_assign_test.cu b/tests/static_map/insert_or_assign_test.cu new file mode 100644 index 000000000..90c6553ce --- /dev/null +++ b/tests/static_map/insert_or_assign_test.cu @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include +#include +#include +#include +#include +#include + +#include + +using size_type = std::size_t; + +template +__inline__ void test_insert_or_assign(Map& map, size_type num_keys) +{ + using Key = typename Map::key_type; + using Value = typename Map::mapped_type; + + // Insert pairs + auto pairs_begin = + thrust::make_transform_iterator(thrust::counting_iterator(0), + [] __device__(auto i) { return cuco::pair(i, i); }); + + auto const initial_size = map.insert(pairs_begin, pairs_begin + num_keys); + REQUIRE(initial_size == num_keys); // all keys should be inserted + + // Query pairs have the same keys but different payloads + auto query_pairs_begin = thrust::make_transform_iterator( + thrust::counting_iterator(0), + [] __device__(auto i) { return cuco::pair(i, i * 2); }); + + map.insert_or_assign(query_pairs_begin, query_pairs_begin + num_keys); + + auto const updated_size = map.size(); + // all keys are present in the map so the size shouldn't change + REQUIRE(updated_size == initial_size); + + thrust::device_vector d_keys(num_keys); + thrust::device_vector d_values(num_keys); + map.retrieve_all(d_keys.begin(), d_values.begin()); + + auto gold_values_begin = thrust::make_transform_iterator(thrust::counting_iterator(0), + [] __device__(auto i) { return i * 2; }); + + thrust::sort(thrust::device, d_values.begin(), d_values.end()); + REQUIRE(cuco::test::equal( + d_values.begin(), d_values.end(), gold_values_begin, thrust::equal_to{})); +} + +TEMPLATE_TEST_CASE_SIG( + "Insert or assign", + "", + ((typename Key, typename Value, cuco::test::probe_sequence Probe, int CGSize), + Key, + Value, + Probe, + CGSize), + (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), + (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), + (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), + (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), + (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), + (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), + (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), + (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), + (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), + (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), + (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2), + (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2)) +{ + constexpr size_type num_keys{400}; + + using probe = + std::conditional_t>, + cuco::experimental::double_hashing, + cuco::murmurhash3_32>>; + + auto map = cuco::experimental::static_map, + cuda::thread_scope_device, + thrust::equal_to, + probe, + cuco::cuda_allocator, + cuco::experimental::storage<2>>{ + num_keys, cuco::empty_key{-1}, cuco::empty_value{-1}}; + + test_insert_or_assign(map, num_keys); +}