Skip to content

Commit

Permalink
Expose hash_function member function (#587)
Browse files Browse the repository at this point in the history
Close #582 

This PR exposes `hash_function` member function for cuco hash tables.

---------

Co-authored-by: Daniel Jünger <[email protected]>
  • Loading branch information
PointKernel and sleeepyjack authored Aug 27, 2024
1 parent 6510352 commit a20460c
Show file tree
Hide file tree
Showing 20 changed files with 259 additions and 3 deletions.
11 changes: 11 additions & 0 deletions include/cuco/detail/open_addressing/open_addressing_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class open_addressing_impl {

using storage_ref_type = typename storage_type::ref_type; ///< Non-owning window storage ref type
using probing_scheme_type = ProbingScheme; ///< Probe scheme type
using hasher = typename probing_scheme_type::hasher; ///< Hash function type

/**
* @brief Constructs a statically-sized open addressing data structure with the specified initial
Expand Down Expand Up @@ -933,6 +934,16 @@ class open_addressing_impl {
return probing_scheme_;
}

/**
* @brief Gets the function(s) used to hash keys
*
* @return The function(s) used to hash keys
*/
[[nodiscard]] constexpr hasher hash_function() const noexcept
{
return this->probing_scheme().hash_function();
}

/**
* @brief Gets the container allocator.
*
Expand Down
14 changes: 13 additions & 1 deletion include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class open_addressing_ref_impl {
public:
using key_type = Key; ///< Key type
using probing_scheme_type = ProbingScheme; ///< Type of probing scheme
using hasher = typename probing_scheme_type::hasher; ///< Hash function type
using storage_ref_type = StorageRef; ///< Type of storage ref
using window_type = typename storage_ref_type::window_type; ///< Window type
using value_type = typename storage_ref_type::value_type; ///< Storage element type
Expand Down Expand Up @@ -233,11 +234,22 @@ class open_addressing_ref_impl {
*
* @return The probing scheme used for the container
*/
[[nodiscard]] __device__ constexpr probing_scheme_type const& probing_scheme() const noexcept
[[nodiscard]] __host__ __device__ constexpr probing_scheme_type const& probing_scheme()
const noexcept
{
return probing_scheme_;
}

/**
* @brief Gets the function(s) used to hash keys
*
* @return The function(s) used to hash keys
*/
[[nodiscard]] __host__ __device__ constexpr hasher hash_function() const noexcept
{
return this->probing_scheme().hash_function();
}

/**
* @brief Gets the non-owning storage ref.
*
Expand Down
15 changes: 15 additions & 0 deletions include/cuco/detail/probing_scheme/probing_scheme_impl.inl
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,13 @@ __host__ __device__ constexpr auto linear_probing<CGSize, Hash>::operator()(
upper_bound};
}

template <int32_t CGSize, typename Hash>
__host__ __device__ constexpr linear_probing<CGSize, Hash>::hasher
linear_probing<CGSize, Hash>::hash_function() const noexcept
{
return hash_;
}

template <int32_t CGSize, typename Hash1, typename Hash2>
__host__ __device__ constexpr double_hashing<CGSize, Hash1, Hash2>::double_hashing(
Hash1 const& hash1, Hash2 const& hash2)
Expand Down Expand Up @@ -192,4 +199,12 @@ __host__ __device__ constexpr auto double_hashing<CGSize, Hash1, Hash2>::operato
cg_size),
upper_bound}; // TODO use fast_int operator
}

template <int32_t CGSize, typename Hash1, typename Hash2>
__host__ __device__ constexpr double_hashing<CGSize, Hash1, Hash2>::hasher
double_hashing<CGSize, Hash1, Hash2>::hash_function() const noexcept
{
return {hash1_, hash2_};
}

} // namespace cuco
15 changes: 15 additions & 0 deletions include/cuco/detail/static_map/static_map.inl
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,21 @@ static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
return impl_->key_eq();
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
constexpr static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::hasher
static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::hash_function()
const noexcept
{
return impl_->hash_function();
}

template <class Key,
class T,
class Extent,
Expand Down
20 changes: 20 additions & 0 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,26 @@ static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>
return this->impl_.key_eq();
}

template <typename Key,
typename T,
cuda::thread_scope Scope,
typename KeyEqual,
typename ProbingScheme,
typename StorageRef,
typename... Operators>
__host__ __device__ constexpr static_map_ref<Key,
T,
Scope,
KeyEqual,
ProbingScheme,
StorageRef,
Operators...>::hasher
static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::hash_function()
const noexcept
{
return impl_.hash_function();
}

template <typename Key,
typename T,
cuda::thread_scope Scope,
Expand Down
16 changes: 16 additions & 0 deletions include/cuco/detail/static_multimap/static_multimap.inl
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,22 @@ constexpr static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Alloca
return impl_->key_eq();
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
constexpr static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
hasher
static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
hash_function() const noexcept
{
return impl_->hash_function();
}

template <class Key,
class T,
class Extent,
Expand Down
20 changes: 20 additions & 0 deletions include/cuco/detail/static_multimap/static_multimap_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,26 @@ static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operator
return impl_.key_eq();
}

template <typename Key,
typename T,
cuda::thread_scope Scope,
typename KeyEqual,
typename ProbingScheme,
typename StorageRef,
typename... Operators>
__host__ __device__ constexpr static_multimap_ref<Key,
T,
Scope,
KeyEqual,
ProbingScheme,
StorageRef,
Operators...>::hasher
static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::
hash_function() const noexcept
{
return impl_.hash_function();
}

template <typename Key,
typename T,
cuda::thread_scope Scope,
Expand Down
14 changes: 14 additions & 0 deletions include/cuco/detail/static_multiset/static_multiset.inl
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,20 @@ constexpr static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator
return impl_->key_eq();
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
constexpr static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::hasher
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::hash_function()
const noexcept
{
return impl_->hash_function();
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
Expand Down
18 changes: 18 additions & 0 deletions include/cuco/detail/static_multiset/static_multiset_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,24 @@ static_multiset_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators..
return this->impl_.key_eq();
}

template <typename Key,
cuda::thread_scope Scope,
typename KeyEqual,
typename ProbingScheme,
typename StorageRef,
typename... Operators>
__host__ __device__ constexpr static_multiset_ref<Key,
Scope,
KeyEqual,
ProbingScheme,
StorageRef,
Operators...>::hasher
static_multiset_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::hash_function()
const noexcept
{
return impl_.hash_function();
}

template <typename Key,
cuda::thread_scope Scope,
typename KeyEqual,
Expand Down
14 changes: 14 additions & 0 deletions include/cuco/detail/static_set/static_set.inl
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,20 @@ static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::key
return impl_->key_eq();
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
constexpr static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::hasher
static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::hash_function()
const noexcept
{
return impl_->hash_function();
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
Expand Down
18 changes: 18 additions & 0 deletions include/cuco/detail/static_set/static_set_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,24 @@ static_set_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::k
return this->impl_.key_eq();
}

template <typename Key,
cuda::thread_scope Scope,
typename KeyEqual,
typename ProbingScheme,
typename StorageRef,
typename... Operators>
__host__ __device__ constexpr static_set_ref<Key,
Scope,
KeyEqual,
ProbingScheme,
StorageRef,
Operators...>::hasher
static_set_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::hash_function()
const noexcept
{
return impl_.hash_function();
}

template <typename Key,
cuda::thread_scope Scope,
typename KeyEqual,
Expand Down
23 changes: 21 additions & 2 deletions include/cuco/probing_scheme.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <cuco/detail/probing_scheme/probing_scheme_base.cuh>
#include <cuco/pair.cuh>

#include <cuda/std/tuple>
#include <cuda/std/type_traits>

#include <cooperative_groups.h>
Expand All @@ -37,10 +38,12 @@ namespace cuco {
*/
template <int32_t CGSize, typename Hash>
class linear_probing : private detail::probing_scheme_base<CGSize> {
public:
using probing_scheme_base_type =
detail::probing_scheme_base<CGSize>; ///< The base probe scheme type

public:
using probing_scheme_base_type::cg_size;
using hasher = Hash; ///< Hash function type

/**
*@brief Constructs linear probing scheme with the hasher callable.
Expand Down Expand Up @@ -93,6 +96,13 @@ class linear_probing : private detail::probing_scheme_base<CGSize> {
ProbeKey const& probe_key,
Extent upper_bound) const noexcept;

/**
* @brief Gets the function used to hash keys
*
* @return The function used to hash keys
*/
__host__ __device__ constexpr hasher hash_function() const noexcept;

private:
Hash hash_;
};
Expand All @@ -113,10 +123,12 @@ class linear_probing : private detail::probing_scheme_base<CGSize> {
*/
template <int32_t CGSize, typename Hash1, typename Hash2 = Hash1>
class double_hashing : private detail::probing_scheme_base<CGSize> {
public:
using probing_scheme_base_type =
detail::probing_scheme_base<CGSize>; ///< The base probe scheme type

public:
using probing_scheme_base_type::cg_size;
using hasher = cuda::std::tuple<Hash1, Hash2>; ///< Hash function type

/**
*@brief Constructs double hashing probing scheme with the two hasher callables.
Expand Down Expand Up @@ -195,6 +207,13 @@ class double_hashing : private detail::probing_scheme_base<CGSize> {
ProbeKey const& probe_key,
Extent upper_bound) const noexcept;

/**
* @brief Gets the functions used to hash keys
*
* @return The functions used to hash keys
*/
__host__ __device__ constexpr hasher hash_function() const noexcept;

private:
Hash1 hash1_;
Hash2 hash2_;
Expand Down
8 changes: 8 additions & 0 deletions include/cuco/static_map.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ class static_map {
/// Non-owning window storage ref type
using storage_ref_type = typename impl_type::storage_ref_type;
using probing_scheme_type = typename impl_type::probing_scheme_type; ///< Probing scheme type
using hasher = typename probing_scheme_type::hasher; ///< Hash function type

using mapped_type = T; ///< Payload type
template <typename... Operators>
Expand Down Expand Up @@ -959,6 +960,13 @@ class static_map {
*/
[[nodiscard]] constexpr key_equal key_eq() const noexcept;

/**
* @brief Gets the function(s) used to hash keys
*
* @return The function(s) used to hash keys
*/
[[nodiscard]] constexpr hasher hash_function() const noexcept;

/**
* @brief Get device ref with operators.
*
Expand Down
8 changes: 8 additions & 0 deletions include/cuco/static_map_ref.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class static_map_ref
using key_type = Key; ///< Key type
using mapped_type = T; ///< Mapped type
using probing_scheme_type = ProbingScheme; ///< Type of probing scheme
using hasher = typename probing_scheme_type::hasher; ///< Hash function type
using storage_ref_type = StorageRef; ///< Type of storage ref
using window_type = typename storage_ref_type::window_type; ///< Window type
using value_type = typename storage_ref_type::value_type; ///< Storage element type
Expand Down Expand Up @@ -190,6 +191,13 @@ class static_map_ref
*/
[[nodiscard]] __host__ __device__ constexpr key_equal key_eq() const noexcept;

/**
* @brief Gets the function(s) used to hash keys
*
* @return The function(s) used to hash keys
*/
[[nodiscard]] __host__ __device__ constexpr hasher hash_function() const noexcept;

/**
* @brief Returns a const_iterator to one past the last slot.
*
Expand Down
Loading

0 comments on commit a20460c

Please sign in to comment.