diff --git a/include/merlin/core_kernels.cuh b/include/merlin/core_kernels.cuh index 186746e92..64f37d429 100644 --- a/include/merlin/core_kernels.cuh +++ b/include/merlin/core_kernels.cuh @@ -1774,9 +1774,10 @@ __global__ void remove_kernel(const Table* __restrict table, } /* Remove specified keys which match the Predict. */ -template +template class PredFunctor, + uint32_t TILE_SIZE = 1> __global__ void remove_kernel(const Table* __restrict table, - const EraseIfPredictInternal pred, const K pattern, const M threshold, size_t* __restrict count, Bucket* __restrict buckets, @@ -1784,6 +1785,7 @@ __global__ void remove_kernel(const Table* __restrict table, const size_t bucket_max_size, const size_t buckets_num, size_t N) { auto g = cg::tiled_partition(cg::this_thread_block()); + PredFunctor pred; for (size_t t = (blockIdx.x * blockDim.x) + threadIdx.x; t < N; t += blockDim.x * gridDim.x) { @@ -1891,9 +1893,9 @@ __global__ void dump_kernel(const Table* __restrict table, K* d_key, } /* Dump with meta. */ -template +template class PredFunctor> __global__ void dump_kernel(const Table* __restrict table, - const EraseIfPredictInternal pred, const K pattern, const M threshold, K* d_key, V* __restrict d_val, M* __restrict d_meta, const size_t offset, const size_t search_length, @@ -1907,6 +1909,7 @@ __global__ void dump_kernel(const Table* __restrict table, M* block_result_meta = (M*)&(block_result_val[blockDim.x * dim]); __shared__ size_t block_acc; __shared__ size_t global_acc; + PredFunctor pred; const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; diff --git a/include/merlin_hashtable.cuh b/include/merlin_hashtable.cuh index 52ad6bad2..9152139d7 100644 --- a/include/merlin_hashtable.cuh +++ b/include/merlin_hashtable.cuh @@ -87,24 +87,28 @@ struct HashTableOptions { * * ``` * template - * __forceinline__ __device__ bool erase_if_pred(const K& key, - * M& meta, - * const K& pattern, - * const M& threshold) { - * return ((key & 0xFFFF000000000000 == pattern) && - * (meta < threshold)); - * } + * struct EraseIfPredFunctor { + * __forceinline__ __device__ bool operator()(const K& key, + * M& meta, + * const K& pattern, + * const M& threshold) { + * return ((key & 0xFFFF000000000000 == pattern) && + * (meta < threshold)); + * } + * }; * ``` * * Example for export_batch_if: * ``` * template - * __forceinline__ __device__ bool export_if_pred(const K& key, - * M& meta, - * const K& pattern, - * const M& threshold) { - * return meta >= threshold; - * } + * struct ExportIfPredFunctor { + * __forceinline__ __device__ bool operator()(const K& key, + * M& meta, + * const K& pattern, + * const M& threshold) { + * return meta >= threshold; + * } + * }; * ``` */ template @@ -1023,21 +1027,24 @@ class HashTable { * @brief Erases all elements that satisfy the predicate @p pred from the * hash table. * - * The value for @p pred should be a function with type `Pred` defined like - * the following example: + * @tparam PredFunctor The predicate template + * function with operator signature (bool*)(const K&, const M&, const K&, + * const threshold) that returns `true` if the element should be erased. The + * value for @p pred should be a function with type `Pred` defined like the + * following example: * * ``` * template - * __forceinline__ __device__ bool erase_if_pred(const K& key, - * const M& meta, - * const K& pattern, - * const M& threshold) { - * return ((key & 0x1 == pattern) && (meta < threshold)); - * } + * struct EraseIfPredFunctor { + * __forceinline__ __device__ bool operator()(const K& key, + * M& meta, + * const K& pattern, + * const M& threshold) { + * return ((key & 0x1 == pattern) && (meta < threshold)); + * } + * }; * ``` * - * @param pred The predicate function with type Pred that returns `true` if - * the element should be erased. * @param pattern The third user-defined argument to @p pred with key_type * type. * @param threshold The fourth user-defined argument to @p pred with meta_type @@ -1047,8 +1054,9 @@ class HashTable { * @return The number of elements removed. * */ - size_type erase_if(const Pred& pred, const key_type& pattern, - const meta_type& threshold, cudaStream_t stream = 0) { + template