diff --git a/include/merlin/core_kernels.cuh b/include/merlin/core_kernels.cuh
index dd81b814..5bacd590 100644
--- a/include/merlin/core_kernels.cuh
+++ b/include/merlin/core_kernels.cuh
@@ -927,38 +927,160 @@ __global__ void dump_kernel_v2(const Table<K, V, S>* __restrict table,
   for (size_t ii = tid; ii < search_length; ii += gridDim.x * blockDim.x) {
     size_t bkt_idx = (ii + offset) / bucket_max_size;
-    int key_idx = (ii + offset) % bucket_max_size;
-    int leading_key_idx = key_idx / TILE_SIZE * TILE_SIZE;
+    size_t key_idx = (ii + offset) % bucket_max_size;
+    size_t leading_key_idx = key_idx / TILE_SIZE * TILE_SIZE;
     Bucket<K, V, S>* bucket = &(buckets[bkt_idx]);
 
     const K key =
         (bucket->keys(key_idx))->load(cuda::std::memory_order_relaxed);
     S score = bucket->scores(key_idx)->load(cuda::std::memory_order_relaxed);
+
     bool match =
         (!IS_RESERVED_KEY(key)) && pred(key, score, pattern, threshold);
     unsigned int vote = g.ballot(match);
     int tile_cnt = __popc(vote);
-    int tile_offset = 0;
+    size_t tile_offset = 0;
     if (g.thread_rank() == 0) {
-      tile_offset = static_cast<int>(
-          atomicAdd(d_dump_counter, static_cast<size_t>(tile_cnt)));
+      tile_offset = atomicAdd(d_dump_counter, static_cast<size_t>(tile_cnt));
     }
     tile_offset = g.shfl(tile_offset, 0);
+    int bias_g = tile_cnt - __popc(vote >> (key_idx % TILE_SIZE));
+
+    if (match) {
+      d_key[tile_offset + bias_g] = key;
+      if (d_score) {
+        d_score[tile_offset + bias_g] = score;
+      }
+    }
+
+#pragma unroll
+    for (int r = 0; r < TILE_SIZE; r++) {
+      unsigned int biased_vote = vote >> r;
+      bool cur_match = biased_vote & 1;
+      if (cur_match) {
+        int bias = tile_cnt - __popc(biased_vote);
+        size_t cur_idx = leading_key_idx + r;
+
+        for (int j = g.thread_rank(); j < dim; j += TILE_SIZE) {
+          d_val[(tile_offset + bias) * dim + j] =
+              bucket->vectors[cur_idx * dim + j];
+        }
+      }
+    }
+  }
+}
+
+template <class K, class V, class S,
+          template <typename, typename> class PredFunctor>
+__global__ void size_if_kernel(const Table<K, V, S>* __restrict table,
+                               Bucket<K, V, S>* buckets, const K pattern,
+                               const S threshold, size_t* d_counter) {
+  extern __shared__ unsigned char s[];
+  KVM<K, V, S>* const block_tuples{reinterpret_cast<KVM<K, V, S>*>(s)};
+
+  const size_t bucket_max_size{table->bucket_max_size};
+
+  size_t local_acc = 0;
+  __shared__ size_t block_acc;
+  PredFunctor<K, S> pred;
+
+  const size_t tid{blockIdx.x * blockDim.x + threadIdx.x};
+
+  if (threadIdx.x == 0) {
+    block_acc = 0;
+  }
+  __syncthreads();
+
+  for (size_t i = tid; i < table->capacity; i += blockDim.x * gridDim.x) {
+    Bucket<K, V, S>* const bucket{&buckets[i / bucket_max_size]};
+
+    const int key_idx{static_cast<int>(i % bucket_max_size)};
+    const K key{(bucket->keys(key_idx))->load(cuda::std::memory_order_relaxed)};
+    S score = bucket->scores(key_idx)->load(cuda::std::memory_order_relaxed);
+
+    if ((!IS_RESERVED_KEY(key)) && pred(key, score, pattern, threshold)) {
+      ++local_acc;
+    }
+  }
+  atomicAdd(&block_acc, local_acc);
+  __syncthreads();
+
+  if (threadIdx.x == 0) {
+    atomicAdd(d_counter, block_acc);
+  }
+}
+
+template <class K, class V, class S,
+          template <typename, typename> class PredFunctor, int TILE_SIZE>
+__global__ void dump_kernel_v3(const Table<K, V, S>* __restrict table,
+                               Bucket<K, V, S>* buckets, const K pattern,
+                               const S threshold, K* d_key, V* __restrict d_val,
+                               S* __restrict d_score, const size_t offset,
+                               const size_t search_length,
+                               size_t* d_dump_counter) {
+  const size_t bucket_max_size = table->bucket_max_size;
+  int dim = table->dim;
+  auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());
+
+  PredFunctor<K, S> pred;
+
+  __shared__ int block_cnt;
+  __shared__ size_t block_offset;
+
+  size_t tid = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
+
+  for (size_t ii = tid; ii < search_length; ii += gridDim.x * blockDim.x) {
+    size_t bkt_idx = (ii + offset) / bucket_max_size;
+    size_t key_idx = (ii + offset) % bucket_max_size;
+    size_t leading_key_idx = key_idx / TILE_SIZE * TILE_SIZE;
+    Bucket<K, V, S>* bucket = &(buckets[bkt_idx]);
+
+    const K key =
+        (bucket->keys(key_idx))->load(cuda::std::memory_order_relaxed);
+    S score = bucket->scores(key_idx)->load(cuda::std::memory_order_relaxed);
+
+    if (threadIdx.x == 0) {
+      block_cnt = 0;
+    }
+    __syncthreads();
+
+    bool match =
+        (!IS_RESERVED_KEY(key)) && pred(key, score, pattern, threshold);
+    unsigned int vote = g.ballot(match);
+    int tile_cnt = __popc(vote);
+
+    int in_block_tile_offset = 0;
+    if (g.thread_rank() == 0) {
+      in_block_tile_offset =
+          atomicAdd(reinterpret_cast<int*>(&block_cnt), tile_cnt);
+    }
+    in_block_tile_offset = g.shfl(in_block_tile_offset, 0);
+    __syncthreads();
+
+    if (threadIdx.x == 0) {
+      block_offset = atomicAdd(d_dump_counter, static_cast<size_t>(block_cnt));
+    }
+    __syncthreads();
+
+    int tile_offset = block_offset + in_block_tile_offset;
+    int bias_g = tile_cnt - __popc(vote >> (key_idx % TILE_SIZE));
 
     if (match) {
-      d_key[tile_offset + key_idx] = key;
+      d_key[tile_offset + bias_g] = key;
       if (d_score) {
-        d_score[tile_offset + key_idx] = score;
+        d_score[tile_offset + bias_g] = score;
       }
     }
 
 #pragma unroll
     for (int r = 0; r < TILE_SIZE; r++) {
-      bool cur_match = vote >> r & 1;
+      unsigned int biased_vote = vote >> r;
+      bool cur_match = biased_vote & 1;
       if (cur_match) {
+        int bias = tile_cnt - __popc(biased_vote);
         int cur_idx = leading_key_idx + r;
 
         for (int j = g.thread_rank(); j < dim; j += TILE_SIZE) {
-          d_val[(tile_offset + cur_idx) * dim + j] =
+          d_val[(tile_offset + bias) * dim + j] =
               bucket->vectors[cur_idx * dim + j];
         }
       }
diff --git a/include/merlin_hashtable.cuh b/include/merlin_hashtable.cuh
index 29782dca..802a424f 100644
--- a/include/merlin_hashtable.cuh
+++ b/include/merlin_hashtable.cuh
@@ -918,6 +918,7 @@ class HashTable : public HashTableBase<K, V, S> {
     CUDA_CHECK(cudaGetDeviceProperties(&deviceProp, options_.device_id));
     shared_mem_size_ = deviceProp.sharedMemPerBlock;
     sm_cnt_ = deviceProp.multiProcessorCount;
+    max_threads_per_block_ = deviceProp.maxThreadsPerBlock;
     create_table<key_type, value_type, score_type>(
         &table_, allocator_, options_.dim, options_.init_capacity,
         options_.max_capacity, options_.max_hbm_for_vectors,
@@ -2621,10 +2622,10 @@ class HashTable : public HashTableBase<K, V, S> {
         offset % TILE_SIZE == 0 && n % TILE_SIZE == 0;
 
     if (match_fast_cond) {
-      int grid_size = std::min(sm_cnt_, static_cast<int>(SAFE_GET_GRID_SIZE(
-                                            n, options_.block_size)));
-      const int TILE_SIZE = 8;
-
+      int grid_size = std::min(
+          sm_cnt_ * max_threads_per_block_ / options_.block_size,
+          static_cast<int>(SAFE_GET_GRID_SIZE(n, options_.block_size)));
+      const int TILE_SIZE = 32;
       dump_kernel_v2<key_type, value_type, score_type, PredFunctor, TILE_SIZE>
           <<<grid_size, options_.block_size, 0, stream>>>(
               d_table_, table_->buckets, pattern, threshold, keys, values,
@@ -2687,6 +2688,28 @@ class HashTable : public HashTableBase<K, V, S> {
     return h_size;
   }
 
+  /**
+   * @brief Returns the number of keys that match PredFunctor.
+   *
+   * @param stream The CUDA stream that is used to execute the operation.
+   * @return The number of keys in the table that satisfy PredFunctor.
+   */
+  template
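Note: the kernels above default-construct the predicate as PredFunctor<K, S> and invoke it as pred(key, score, pattern, threshold). As a rough sketch of the interface they assume (the functor name ThresholdPredFunctor and the score-vs-threshold rule below are illustrative only and are not part of this diff), a matching predicate could look like:

// Illustrative sketch only -- not part of this diff. It shows the shape of a
// predicate usable with size_if_kernel / dump_kernel_v3: a default-constructible
// functor callable on the device as pred(key, score, pattern, threshold).
template <class K, class S>
struct ThresholdPredFunctor {
  __forceinline__ __device__ bool operator()(const K& key, const S& score,
                                             const K& pattern,
                                             const S& threshold) {
    // Keep entries whose score reaches the threshold; `pattern` is ignored
    // in this sketch.
    return score >= threshold;
  }
};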