Templatorize the erase_if and export_batch_if API
Lifann authored and rhdong committed May 26, 2023
1 parent 7a22173 commit f7d4fe7
Showing 5 changed files with 118 additions and 118 deletions.
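
At a glance, the change swaps the device-function-pointer predicate (`Pred`) for a functor type passed as a template argument, so callers no longer declare a `__device__` symbol or copy it with `cudaMemcpyFromSymbolAsync`. A before/after sketch of the caller side, distilled from the test changes below (the `table`, `pattern`, `threshold`, and `stream` objects are assumed to be set up as in the tests):

```
template <class K, class M>
struct EraseIfPredFunctor {
  __forceinline__ __device__ bool operator()(const K& key, M& meta,
                                             const K& pattern,
                                             const M& threshold) {
    return ((key & 0x7f > pattern) && (meta > threshold));
  }
};

// Before: the predicate was a __device__ function pointer declared as a
// symbol and marshalled to the host before launch.
//   __device__ Table::Pred EraseIfPred = erase_if_pred<K, M>;
//   size_t n = table->erase_if(EraseIfPred<K, M>, pattern, threshold, stream);
//
// After: the functor type is the template argument; no symbol copy is needed.
//   size_t n = table->erase_if<EraseIfPredFunctor>(pattern, threshold, stream);
```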
11 changes: 7 additions & 4 deletions include/merlin/core_kernels.cuh
@@ -1774,16 +1774,18 @@ __global__ void remove_kernel(const Table<K, V, M>* __restrict table,
}

/* Remove specified keys which match the predicate. */
template <class K, class V, class M, uint32_t TILE_SIZE = 1>
template <class K, class V, class M,
template <typename, typename> class PredFunctor,
uint32_t TILE_SIZE = 1>
__global__ void remove_kernel(const Table<K, V, M>* __restrict table,
const EraseIfPredictInternal<K, M> pred,
const K pattern, const M threshold,
size_t* __restrict count,
Bucket<K, V, M>* __restrict buckets,
int* __restrict buckets_size,
const size_t bucket_max_size,
const size_t buckets_num, size_t N) {
auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());
PredFunctor<K, M> pred;

for (size_t t = (blockIdx.x * blockDim.x) + threadIdx.x; t < N;
t += blockDim.x * gridDim.x) {
@@ -1891,9 +1893,9 @@ __global__ void dump_kernel(const Table<K, V, M>* __restrict table, K* d_key,
}

/* Dump with meta. */
template <class K, class V, class M>
template <class K, class V, class M,
template <typename, typename> class PredFunctor>
__global__ void dump_kernel(const Table<K, V, M>* __restrict table,
const EraseIfPredictInternal<K, M> pred,
const K pattern, const M threshold, K* d_key,
V* __restrict d_val, M* __restrict d_meta,
const size_t offset, const size_t search_length,
@@ -1907,6 +1909,7 @@ __global__ void dump_kernel(const Table<K, V, M>* __restrict table,
M* block_result_meta = (M*)&(block_result_val[blockDim.x * dim]);
__shared__ size_t block_acc;
__shared__ size_t global_acc;
PredFunctor<K, M> pred;

const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;

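
For orientation, a minimal sketch of the kernel-side pattern these hunks introduce: the functor type arrives as a template template parameter and is default-constructed per thread, replacing the `EraseIfPredictInternal` function-pointer argument. The `filter_kernel` name and its arguments here are illustrative, not part of the library:

```
template <class K, class M, template <typename, typename> class PredFunctor>
__global__ void filter_kernel(const K* keys, M* metas, const K pattern,
                              const M threshold, bool* out, size_t n) {
  PredFunctor<K, M> pred;  // stateless functor, constructed directly on the device
  // Grid-stride loop: each thread evaluates the predicate on its elements.
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    out[i] = pred(keys[i], metas[i], pattern, threshold);
  }
}
```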
101 changes: 52 additions & 49 deletions include/merlin_hashtable.cuh
@@ -87,24 +87,28 @@ struct HashTableOptions {
*
* ```
* template <class K, class M>
* __forceinline__ __device__ bool erase_if_pred(const K& key,
* M& meta,
* const K& pattern,
* const M& threshold) {
* return ((key & 0xFFFF000000000000 == pattern) &&
* (meta < threshold));
* }
* struct EraseIfPredFunctor {
* __forceinline__ __device__ bool operator()(const K& key,
* M& meta,
* const K& pattern,
* const M& threshold) {
* return ((key & 0xFFFF000000000000 == pattern) &&
* (meta < threshold));
* }
* };
* ```
*
* Example for export_batch_if:
* ```
* template <class K, class M>
* __forceinline__ __device__ bool export_if_pred(const K& key,
* M& meta,
* const K& pattern,
* const M& threshold) {
* return meta >= threshold;
* }
* struct ExportIfPredFunctor {
* __forceinline__ __device__ bool operator()(const K& key,
* M& meta,
* const K& pattern,
* const M& threshold) {
* return meta >= threshold;
* }
* };
* ```
*/
template <class K, class M>
@@ -1023,21 +1027,24 @@ class HashTable {
* @brief Erases all elements that satisfy the predicate @p pred from the
* hash table.
*
* The value for @p pred should be a function with type `Pred` defined like
* the following example:
* @tparam PredFunctor The predicate functor with template parameters
* <typename K, typename M> whose `operator()` has the signature
* `bool(const K&, M&, const K&, const M&)` and returns `true` if the
* element should be erased. It should be defined like the
* following example:
*
* ```
* template <class K, class M>
* __forceinline__ __device__ bool erase_if_pred(const K& key,
* const M& meta,
* const K& pattern,
* const M& threshold) {
* return ((key & 0x1 == pattern) && (meta < threshold));
* }
* struct EraseIfPredFunctor {
* __forceinline__ __device__ bool operator()(const K& key,
* M& meta,
* const K& pattern,
* const M& threshold) {
* return ((key & 0x1 == pattern) && (meta < threshold));
* }
* };
* ```
*
* @param pred The predicate function with type Pred that returns `true` if
* the element should be erased.
* @param pattern The third user-defined argument to @p pred with key_type
* type.
* @param threshold The fourth user-defined argument to @p pred with meta_type
Expand All @@ -1047,27 +1054,24 @@ class HashTable {
* @return The number of elements removed.
*
*/
size_type erase_if(const Pred& pred, const key_type& pattern,
const meta_type& threshold, cudaStream_t stream = 0) {
template <template <typename, typename> class PredFunctor>
size_type erase_if(const key_type& pattern, const meta_type& threshold,
cudaStream_t stream = 0) {
write_read_lock lock(mutex_);

auto dev_ws{dev_mem_pool_->get_workspace<1>(sizeof(size_type), stream)};
auto d_count{dev_ws.get<size_type*>(0)};

CUDA_CHECK(cudaMemsetAsync(d_count, 0, sizeof(size_type), stream));

Pred h_pred;
CUDA_CHECK(cudaMemcpyFromSymbolAsync(&h_pred, pred, sizeof(Pred), 0,
cudaMemcpyDeviceToHost, stream));

{
const size_t block_size = options_.block_size;
const size_t N = table_->buckets_num;
const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

remove_kernel<key_type, value_type, meta_type>
remove_kernel<key_type, value_type, meta_type, PredFunctor>
<<<grid_size, block_size, 0, stream>>>(
table_, h_pred, pattern, threshold, d_count, table_->buckets,
table_, pattern, threshold, d_count, table_->buckets,
table_->buckets_size, table_->bucket_max_size,
table_->buckets_num, N);
}
@@ -1169,6 +1173,9 @@ class HashTable {

/**
* @brief Exports a certain number of the key-value-meta tuples that match the
* specified condition from the hash table.
*
* @tparam PredFunctor A functor with template parameters <K, M> that defines
* `__device__ bool operator()(const K&, M&, const K&, const M&)`.
*
* @param n The maximum number of exported pairs.
@@ -1177,17 +1184,16 @@ class HashTable {
*
* ```
* template <class K, class M>
* __forceinline__ __device__ bool export_if_pred(const K& key,
* M& meta,
* const K& pattern,
* const M& threshold) {
*
* return meta > threshold;
* }
* struct ExportIfPredFunctor {
* __forceinline__ __device__ bool operator()(const K& key,
* M& meta,
* const K& pattern,
* const M& threshold) {
* return meta >= threshold;
* }
* };
* ```
*
* @param pred The predicate function with type Pred that returns `true` if
* the element should be exported.
* @param pattern The third user-defined argument to @p pred with key_type
* type.
* @param threshold The fourth user-defined argument to @p pred with meta_type
@@ -1209,9 +1215,10 @@ class HashTable {
* memory. Reducing the value for @p n is currently required if this exception
* occurs.
*/
void export_batch_if(Pred& pred, const key_type& pattern,
const meta_type& threshold, size_type n,
const size_type offset, size_type* d_counter,
template <template <typename, typename> class PredFunctor>
void export_batch_if(const key_type& pattern, const meta_type& threshold,
size_type n, const size_type offset,
size_type* d_counter,
key_type* keys, // (n)
value_type* values, // (n, DIM)
meta_type* metas = nullptr, // (n)
@@ -1235,13 +1242,9 @@ class HashTable {
const size_t shared_size = kvm_size * block_size;
const size_t grid_size = SAFE_GET_GRID_SIZE(n, block_size);

Pred h_pred;
CUDA_CHECK(cudaMemcpyFromSymbolAsync(&h_pred, pred, sizeof(Pred), 0,
cudaMemcpyDeviceToHost, stream));

dump_kernel<key_type, value_type, meta_type>
dump_kernel<key_type, value_type, meta_type, PredFunctor>
<<<grid_size, block_size, shared_size, stream>>>(
table_, h_pred, pattern, threshold, keys, values, metas, offset, n,
table_, pattern, threshold, keys, values, metas, offset, n,
d_counter);

CudaCheckError();
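
Correspondingly, the host-side call in the updated tests below passes the functor type as a template argument to `export_batch_if`. A hedged sketch of the call pattern (the `table`, buffers `d_keys`, `d_vectors`, `d_metas`, `d_dump_counter`, and `stream` are assumed to be allocated as in the tests):

```
template <class K, class M>
struct ExportIfPredFunctor {
  __forceinline__ __device__ bool operator()(const K& key, M& meta,
                                             const K& pattern,
                                             const M& threshold) {
    return meta > threshold;
  }
};

// The old signature took a Pred value as the first argument; the new one
// drops it and takes the functor type instead:
//   table->export_batch_if<ExportIfPredFunctor>(
//       pattern, threshold, table->capacity(), 0 /*offset*/, d_dump_counter,
//       d_keys, d_vectors, d_metas, stream);
```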
42 changes: 20 additions & 22 deletions tests/find_or_insert_ptr_test.cc.cu
@@ -39,24 +39,22 @@ using Table = nv::merlin::HashTable<K, V, M>;
using TableOptions = nv::merlin::HashTableOptions;

template <class K, class M>
__forceinline__ __device__ bool erase_if_pred(const K& key, M& meta,
const K& pattern,
const M& threshold) {
return ((key & 0x7f > pattern) && (meta > threshold));
}

template <class K, class M>
__device__ Table::Pred EraseIfPred = erase_if_pred<K, M>;

template <class K, class M>
__forceinline__ __device__ bool export_if_pred(const K& key, M& meta,
const K& pattern,
const M& threshold) {
return meta > threshold;
}
struct EraseIfPredFunctor {
__forceinline__ __device__ bool operator()(const K& key, M& meta,
const K& pattern,
const M& threshold) {
return ((key & 0x7f > pattern) && (meta > threshold));
}
};

template <class K, class M>
__device__ Table::Pred ExportIfPred = export_if_pred<K, M>;
struct ExportIfPredFunctor {
__forceinline__ __device__ bool operator()(const K& key, M& meta,
const K& pattern,
const M& threshold) {
return meta > threshold;
}
};

void test_basic(size_t max_hbm_for_vectors) {
constexpr uint64_t INIT_CAPACITY = 64 * 1024 * 1024UL;
@@ -546,8 +544,8 @@ void test_erase_if_pred(size_t max_hbm_for_vectors) {

K pattern = 100;
M threshold = 0;
size_t erase_num =
table->erase_if(EraseIfPred<K, M>, pattern, threshold, stream);
size_t erase_num = table->template erase_if<EraseIfPredFunctor>(
pattern, threshold, stream);
total_size = table->size(stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
ASSERT_EQ((erase_num + total_size), BUCKET_MAX_SIZE);
@@ -1212,9 +1210,9 @@ void test_export_batch_if(size_t max_hbm_for_vectors) {
K pattern = 100;
M threshold = h_metas[size_t(KEY_NUM / 2)];

table->export_batch_if(ExportIfPred<K, M>, pattern, threshold,
table->capacity(), 0, d_dump_counter, d_keys,
d_vectors, d_metas, stream);
table->template export_batch_if<ExportIfPredFunctor>(
pattern, threshold, table->capacity(), 0, d_dump_counter, d_keys,
d_vectors, d_metas, stream);

CUDA_CHECK(cudaStreamSynchronize(stream));
CUDA_CHECK(cudaMemcpy(&h_dump_counter, d_dump_counter, sizeof(size_t),
@@ -2822,4 +2820,4 @@ TEST(FindOrInsertPtrTest, test_find_or_insert_values_check) {
test_find_or_insert_values_check(16);
// TODO(rhdong): Add back when diff error issue fixed in hybrid mode.
// test_insert_or_assign_values_check(0);
}
}
42 changes: 20 additions & 22 deletions tests/find_or_insert_test.cc.cu
@@ -39,24 +39,22 @@ using Table = nv::merlin::HashTable<K, V, M>;
using TableOptions = nv::merlin::HashTableOptions;

template <class K, class M>
__forceinline__ __device__ bool erase_if_pred(const K& key, M& meta,
const K& pattern,
const M& threshold) {
return ((key & 0x7f > pattern) && (meta > threshold));
}

template <class K, class M>
__device__ Table::Pred EraseIfPred = erase_if_pred<K, M>;

template <class K, class M>
__forceinline__ __device__ bool export_if_pred(const K& key, M& meta,
const K& pattern,
const M& threshold) {
return meta > threshold;
}
struct EraseIfPredFunctor {
__forceinline__ __device__ bool operator()(const K& key, M& meta,
const K& pattern,
const M& threshold) {
return ((key & 0x7f > pattern) && (meta > threshold));
}
};

template <class K, class M>
__device__ Table::Pred ExportIfPred = export_if_pred<K, M>;
struct ExportIfPredFunctor {
__forceinline__ __device__ bool operator()(const K& key, M& meta,
const K& pattern,
const M& threshold) {
return meta > threshold;
}
};

void test_basic(size_t max_hbm_for_vectors) {
constexpr uint64_t INIT_CAPACITY = 64 * 1024 * 1024UL;
@@ -470,8 +468,8 @@ void test_erase_if_pred(size_t max_hbm_for_vectors) {

K pattern = 100;
M threshold = 0;
size_t erase_num =
table->erase_if(EraseIfPred<K, M>, pattern, threshold, stream);
size_t erase_num = table->template erase_if<EraseIfPredFunctor>(
pattern, threshold, stream);
total_size = table->size(stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
ASSERT_EQ((erase_num + total_size), BUCKET_MAX_SIZE);
@@ -1061,9 +1059,9 @@ void test_export_batch_if(size_t max_hbm_for_vectors) {
K pattern = 100;
M threshold = h_metas[size_t(KEY_NUM / 2)];

table->export_batch_if(ExportIfPred<K, M>, pattern, threshold,
table->capacity(), 0, d_dump_counter, d_keys,
d_vectors, d_metas, stream);
table->template export_batch_if<ExportIfPredFunctor>(
pattern, threshold, table->capacity(), 0, d_dump_counter, d_keys,
d_vectors, d_metas, stream);

CUDA_CHECK(cudaStreamSynchronize(stream));
CUDA_CHECK(cudaMemcpy(&h_dump_counter, d_dump_counter, sizeof(size_t),
@@ -2499,4 +2497,4 @@ TEST(FindOrInsertTest, test_find_or_insert_values_check) {
test_find_or_insert_values_check(16);
// TODO(rhdong): Add back when diff error issue fixed in hybrid mode.
// test_insert_or_assign_values_check(0);
}
}
