Skip to content

Commit

Permalink
Add multimap count and conditional insert (#571)
Browse files Browse the repository at this point in the history
This PR adds `count`, `insert_if` and `insert_if_async` APIs to the new
multimap.
  • Loading branch information
PointKernel authored Aug 27, 2024
1 parent 9fe6c82 commit 6510352
Show file tree
Hide file tree
Showing 11 changed files with 422 additions and 109 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ class open_addressing_impl {
[[nodiscard]] size_type count(InputIt first,
InputIt last,
Ref container_ref,
cuda::stream_ref stream) const noexcept
cuda::stream_ref stream) const
{
auto constexpr is_outer = false;
return this->count<is_outer>(first, last, container_ref, stream);
Expand Down
57 changes: 54 additions & 3 deletions include/cuco/detail/static_multimap/static_multimap.inl
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,11 @@ template <class Key,
class Allocator,
class Storage>
template <typename InputIt>
static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::size_type
static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::insert(
void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::insert(
InputIt first, InputIt last, cuda::stream_ref stream)
{
return impl_->insert(first, last, ref(op::insert), stream);
this->insert_async(first, last, stream);
stream.wait();
}

template <class Key,
Expand All @@ -174,6 +174,41 @@ void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator,
impl_->insert_async(first, last, ref(op::insert), stream);
}

template <class Key,
          class T,
          class Extent,
          cuda::thread_scope Scope,
          class KeyEqual,
          class ProbingScheme,
          class Allocator,
          class Storage>
template <typename InputIt, typename StencilIt, typename Predicate>
static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::size_type
static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::insert_if(
  InputIt first, InputIt last, StencilIt stencil, Predicate pred, cuda::stream_ref stream)
{
  // Conditional bulk insert: all work is delegated to the shared implementation object, which is
  // handed a ref of this container built with the `op::insert` operator.
  auto insert_ref = ref(op::insert);

  // The implementation's result is returned to the caller unchanged.
  auto const result = impl_->insert_if(first, last, stencil, pred, insert_ref, stream);
  return result;
}

template <class Key,
          class T,
          class Extent,
          cuda::thread_scope Scope,
          class KeyEqual,
          class ProbingScheme,
          class Allocator,
          class Storage>
template <typename InputIt, typename StencilIt, typename Predicate>
void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
  insert_if_async(InputIt first,
                  InputIt last,
                  StencilIt stencil,
                  Predicate pred,
                  cuda::stream_ref stream) noexcept
{
  // Asynchronous counterpart of `insert_if`: forwards every argument, together with a ref of this
  // container built with the `op::insert` operator, to the shared implementation object. Nothing
  // is returned to the caller.
  auto insert_ref = ref(op::insert);
  impl_->insert_if_async(first, last, stencil, pred, insert_ref, stream);
}

template <class Key,
class T,
class Extent,
Expand Down Expand Up @@ -249,6 +284,22 @@ void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator,
impl_->contains_if_async(first, last, stencil, pred, output_begin, ref(op::contains), stream);
}

template <class Key,
          class T,
          class Extent,
          cuda::thread_scope Scope,
          class KeyEqual,
          class ProbingScheme,
          class Allocator,
          class Storage>
template <typename InputIt>
static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::size_type
static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::count(
  InputIt first, InputIt last, cuda::stream_ref stream) const
{
  // Bulk count over `[first, last)`: delegated to the shared implementation object, which is
  // handed a ref of this container built with the `op::count` operator.
  auto count_ref = ref(op::count);

  // The implementation's result is returned to the caller unchanged.
  auto const result = impl_->count(first, last, count_ref, stream);
  return result;
}

template <class Key,
class T,
class Extent,
Expand Down
57 changes: 57 additions & 0 deletions include/cuco/detail/static_multimap/static_multimap_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -466,5 +466,62 @@ class operator_impl<
return ref_.impl_.contains(group, key);
}
};

template <typename Key,
          typename T,
          cuda::thread_scope Scope,
          typename KeyEqual,
          typename ProbingScheme,
          typename StorageRef,
          typename... Operators>
class operator_impl<
  op::count_tag,
  static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>> {
  // Ref type without any operators; used only as the source of the member type aliases below.
  using base_type = static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef>;
  // Fully-qualified ref type this mixin is cast to in order to reach `impl_`.
  using ref_type =
    static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>;
  using key_type   = typename base_type::key_type;
  using value_type = typename base_type::value_type;
  using size_type  = typename base_type::size_type;

  static constexpr auto cg_size     = base_type::cg_size;
  static constexpr auto window_size = base_type::window_size;

 public:
  /**
   * @brief Counts the occurrences of the given key in the multimap
   *
   * @note Forwards to the `count(key)` overload of the ref's underlying implementation.
   *
   * @tparam ProbeKey Probe key type
   *
   * @param key The key to count for
   *
   * @return Number of occurrences found by the current thread
   */
  template <typename ProbeKey>
  __device__ size_type count(ProbeKey const& key) const noexcept
  {
    // Downcast `*this` to the full ref type to reach its implementation member.
    return static_cast<ref_type const&>(*this).impl_.count(key);
  }

  /**
   * @brief Counts the occurrences of the given key in the multimap using a Cooperative Group
   *
   * @note Forwards to the `count(group, key)` overload of the ref's underlying implementation.
   *
   * @tparam ProbeKey Probe key type
   *
   * @param group The Cooperative Group used to perform group count
   * @param key The key to count for
   *
   * @return Number of occurrences found by the current thread
   */
  template <typename ProbeKey>
  __device__ size_type count(cooperative_groups::thread_block_tile<cg_size> const& group,
                             ProbeKey const& key) const noexcept
  {
    // Downcast `*this` to the full ref type to reach its implementation member.
    return static_cast<ref_type const&>(*this).impl_.count(group, key);
  }
};

} // namespace detail
} // namespace cuco
8 changes: 4 additions & 4 deletions include/cuco/detail/static_multiset/static_multiset.inl
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,11 @@ template <class Key,
class Allocator,
class Storage>
template <typename InputIt, typename StencilIt, typename Predicate>
void static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::insert_if(
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::size_type
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::insert_if(
InputIt first, InputIt last, StencilIt stencil, Predicate pred, cuda::stream_ref stream)
{
this->insert_if_async(first, last, stencil, pred, stream);
stream.wait();
return impl_->insert_if(first, last, stencil, pred, ref(op::insert), stream);
}

template <class Key,
Expand Down Expand Up @@ -287,7 +287,7 @@ template <class Key,
template <typename InputIt>
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::size_type
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::count(
InputIt first, InputIt last, cuda::stream_ref stream) const noexcept
InputIt first, InputIt last, cuda::stream_ref stream) const
{
return impl_->count(first, last, ref(op::count), stream);
}
Expand Down
110 changes: 55 additions & 55 deletions include/cuco/static_map.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,61 @@ class static_map {
template <typename InputIt>
void insert_async(InputIt first, InputIt last, cuda::stream_ref stream = {}) noexcept;

/**
* @brief Inserts key/value pairs in the range `[first, last)` if `pred` of the corresponding
* stencil returns true.
*
* @note The pair `*(first + i)` is inserted if `pred( *(stencil + i) )` returns true.
* @note This function synchronizes the given stream and returns the number of successful
* insertions. For asynchronous execution use `insert_if_async`.
* @note `pred` is tested on the range `[stencil, stencil + std::distance(first, last))`, so the
* stencil sequence must provide at least `std::distance(first, last)` elements.
*
* @tparam InputIt Device accessible random access iterator whose `value_type` is
* convertible to the container's `value_type`
* @tparam StencilIt Device accessible random access iterator whose `value_type` is
* convertible to Predicate's argument type
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and
* argument type is convertible from <tt>std::iterator_traits<StencilIt>::value_type</tt>
*
* @param first Beginning of the sequence of key/value pairs
* @param last End of the sequence of key/value pairs
* @param stencil Beginning of the stencil sequence
* @param pred Predicate to test on every element in the range `[stencil, stencil +
* std::distance(first, last))`
* @param stream CUDA stream used for the operation
*
* @return Number of successful insertions
*/
template <typename InputIt, typename StencilIt, typename Predicate>
size_type insert_if(
InputIt first, InputIt last, StencilIt stencil, Predicate pred, cuda::stream_ref stream = {});

/**
* @brief Asynchronously inserts key/value pairs in the range `[first, last)` if `pred` of the
* corresponding stencil returns true.
*
* @note The pair `*(first + i)` is inserted if `pred( *(stencil + i) )` returns true.
* @note This function returns no insertion count. Use `insert_if` when the number of successful
* insertions is needed.
* @note `pred` is tested on the range `[stencil, stencil + std::distance(first, last))`, so the
* stencil sequence must provide at least `std::distance(first, last)` elements.
*
* @tparam InputIt Device accessible random access iterator whose `value_type` is
* convertible to the container's `value_type`
* @tparam StencilIt Device accessible random access iterator whose `value_type` is
* convertible to Predicate's argument type
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and
* argument type is convertible from <tt>std::iterator_traits<StencilIt>::value_type</tt>
*
* @param first Beginning of the sequence of key/value pairs
* @param last End of the sequence of key/value pairs
* @param stencil Beginning of the stencil sequence
* @param pred Predicate to test on every element in the range `[stencil, stencil +
* std::distance(first, last))`
* @param stream CUDA stream used for the operation
*/
template <typename InputIt, typename StencilIt, typename Predicate>
void insert_if_async(InputIt first,
InputIt last,
StencilIt stencil,
Predicate pred,
cuda::stream_ref stream = {}) noexcept;

/**
* @brief Asynchronously inserts all elements in the range `[first, last)`.
*
Expand Down Expand Up @@ -370,61 +425,6 @@ class static_map {
InsertedIt inserted_begin,
cuda::stream_ref stream = {});

/**
* @brief Inserts keys in the range `[first, last)` if `pred` of the corresponding stencil returns
* true.
*
* @note The key `*(first + i)` is inserted if `pred( *(stencil + i) )` returns true.
* @note This function synchronizes the given stream and returns the number of successful
* insertions. For asynchronous execution use `insert_if_async`.
*
* @tparam InputIt Device accessible random access iterator whose `value_type` is
* convertible to the container's `value_type`
* @tparam StencilIt Device accessible random access iterator whose value_type is
* convertible to Predicate's argument type
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and
* argument type is convertible from <tt>std::iterator_traits<StencilIt>::value_type</tt>
*
* @param first Beginning of the sequence of key/value pairs
* @param last End of the sequence of key/value pairs
* @param stencil Beginning of the stencil sequence
* @param pred Predicate to test on every element in the range `[stencil, stencil +
* std::distance(first, last))`
* @param stream CUDA stream used for the operation
*
* @return Number of successful insertions
*/
template <typename InputIt, typename StencilIt, typename Predicate>
size_type insert_if(
InputIt first, InputIt last, StencilIt stencil, Predicate pred, cuda::stream_ref stream = {});

/**
* @brief Asynchronously inserts keys in the range `[first, last)` if `pred` of the corresponding
* stencil returns true.
*
* @note The key `*(first + i)` is inserted if `pred( *(stencil + i) )` returns true.
*
* @tparam InputIt Device accessible random access iterator whose `value_type` is
* convertible to the container's `value_type`
* @tparam StencilIt Device accessible random access iterator whose value_type is
* convertible to Predicate's argument type
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and
* argument type is convertible from <tt>std::iterator_traits<StencilIt>::value_type</tt>
*
* @param first Beginning of the sequence of key/value pairs
* @param last End of the sequence of key/value pairs
* @param stencil Beginning of the stencil sequence
* @param pred Predicate to test on every element in the range `[stencil, stencil +
* std::distance(first, last))`
* @param stream CUDA stream used for the operation
*/
template <typename InputIt, typename StencilIt, typename Predicate>
void insert_if_async(InputIt first,
InputIt last,
StencilIt stencil,
Predicate pred,
cuda::stream_ref stream = {}) noexcept;

/**
* @brief For any key-value pair `{k, v}` in the range `[first, last)`, if a key equivalent to `k`
* already exists in the container, assigns `v` to the mapped_type corresponding to the key `k`.
Expand Down
Loading

0 comments on commit 6510352

Please sign in to comment.