From 9b79bc9c1e19179b2c1bbfe9629cfbd39170881c Mon Sep 17 00:00:00 2001
From: Lifann
Date: Wed, 14 Aug 2024 00:19:45 +0800
Subject: [PATCH] opt(export_batch_if): Optimize the export_batch_if in cond to
 reduce memory wavefronts

---
 CMakeLists.txt                   |   6 +-
 include/merlin/core_kernels.cuh  | 162 ++++++++++++++++++++++++++
 include/merlin_hashtable.cuh     | 108 +++++++++++++++---
 tests/export_batch_if_test.cc.cu | 188 +++++++++++++++++++++++++++++++
 tests/test_util.cuh              |   8 +-
 5 files changed, 453 insertions(+), 19 deletions(-)
 create mode 100644 tests/export_batch_if_test.cc.cu

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 34016ca8..c159fe2d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -163,4 +163,8 @@ TARGET_LINK_LIBRARIES(find_with_missed_keys_test gtest_main)
 add_executable(reserved_keys_test tests/reserved_keys_test.cc.cu)
 target_compile_features(reserved_keys_test PUBLIC cxx_std_14)
 set_target_properties(reserved_keys_test PROPERTIES CUDA_ARCHITECTURES OFF)
-TARGET_LINK_LIBRARIES(reserved_keys_test gtest_main)
\ No newline at end of file
+TARGET_LINK_LIBRARIES(reserved_keys_test gtest_main)
+
+add_executable(export_batch_if_test tests/export_batch_if_test.cc.cu)
+target_compile_features(export_batch_if_test PUBLIC cxx_std_14)
+set_target_properties(export_batch_if_test PROPERTIES CUDA_ARCHITECTURES OFF)
diff --git a/include/merlin/core_kernels.cuh b/include/merlin/core_kernels.cuh
index 5ee80e76..800d9331 100644
--- a/include/merlin/core_kernels.cuh
+++ b/include/merlin/core_kernels.cuh
@@ -910,5 +910,167 @@ __global__ void dump_kernel(const Table<K, V, S>* __restrict table,
   }
 }
 
+// Dump matching keys/scores/values: each TILE_SIZE-thread tile compacts its
+// matches with ballot + popcount, then writes them out contiguously.
+template <class K, class V, class S,
+          template <typename, typename> class PredFunctor, int TILE_SIZE>
+__global__ void dump_kernel_v2(const Table<K, V, S>* __restrict table,
+                               Bucket<K, V, S>* buckets, const K pattern,
+                               const S threshold, K* d_key, V* __restrict d_val,
+                               S* __restrict d_score, const size_t offset,
+                               const size_t search_length,
+                               size_t* d_dump_counter) {
+  const size_t bucket_max_size = table->bucket_max_size;
+  int dim = table->dim;
+  auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());
+
+  PredFunctor<K, S> pred;
+  size_t tid = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
+
+  for (size_t ii = tid; ii < search_length; ii += gridDim.x * blockDim.x) {
+    size_t bkt_idx = (ii + offset) / bucket_max_size;
+    size_t key_idx = (ii + offset) % bucket_max_size;
+    size_t leading_key_idx = key_idx / TILE_SIZE * TILE_SIZE;
+    Bucket<K, V, S>* bucket = &(buckets[bkt_idx]);
+
+    const K key =
+        (bucket->keys(key_idx))->load(cuda::std::memory_order_relaxed);
+    S score = bucket->scores(key_idx)->load(cuda::std::memory_order_relaxed);
+
+    bool match =
+        (!IS_RESERVED_KEY(key)) && pred(key, score, pattern, threshold);
+    unsigned int vote = g.ballot(match);
+    int tile_cnt = __popc(vote);
+    size_t tile_offset = 0;
+    if (g.thread_rank() == 0) {
+      tile_offset = atomicAdd(d_dump_counter, static_cast<size_t>(tile_cnt));
+    }
+    tile_offset = g.shfl(tile_offset, 0);
+    int bias_g = tile_cnt - __popc(vote >> (key_idx % TILE_SIZE));
+
+    if (match) {
+      d_key[tile_offset + bias_g] = key;
+      if (d_score) {
+        d_score[tile_offset + bias_g] = score;
+      }
+    }
+
+#pragma unroll
+    for (int r = 0; r < TILE_SIZE; r++) {
+      unsigned int biased_vote = vote >> r;
+      bool cur_match = biased_vote & 1;
+      if (cur_match) {
+        int bias = tile_cnt - __popc(biased_vote);
+        size_t cur_idx = leading_key_idx + r;
+
+        for (int j = g.thread_rank(); j < dim; j += TILE_SIZE) {
+          d_val[(tile_offset + bias) * dim + j] =
+              bucket->vectors[cur_idx * dim + j];
+        }
+      }
+    }
+  }
+}
+
+// Same as dump_kernel_v2, but copies vectors as float4 to widen memory
+// transactions; assumes dim % 4 == 0 and sizeof(V) == sizeof(float), which
+// the caller guarantees.
+template <class K, class V, class S,
+          template <typename, typename> class PredFunctor, int TILE_SIZE>
+__global__ void dump_kernel_v2_vectorized(
+    const Table<K, V, S>* __restrict table, Bucket<K, V, S>* buckets,
+    const K pattern, const S threshold, K* d_key, V* __restrict d_val,
+    S* __restrict d_score, const size_t offset, const size_t search_length,
+    size_t* d_dump_counter) {
+  const size_t bucket_max_size = table->bucket_max_size;
+  int dim = table->dim;
+  auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());
+
+  PredFunctor<K, S> pred;
+  size_t tid = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
+
+  for (size_t ii = tid; ii < search_length; ii += gridDim.x * blockDim.x) {
+    size_t bkt_idx = (ii + offset) / bucket_max_size;
+    size_t key_idx = (ii + offset) % bucket_max_size;
+    size_t leading_key_idx = key_idx / TILE_SIZE * TILE_SIZE;
+    Bucket<K, V, S>* bucket = &(buckets[bkt_idx]);
+
+    const K key =
+        (bucket->keys(key_idx))->load(cuda::std::memory_order_relaxed);
+    S score = bucket->scores(key_idx)->load(cuda::std::memory_order_relaxed);
+
+    bool match =
+        (!IS_RESERVED_KEY(key)) && pred(key, score, pattern, threshold);
+    unsigned int vote = g.ballot(match);
+    int tile_cnt = __popc(vote);
+    size_t tile_offset = 0;
+    if (g.thread_rank() == 0) {
+      tile_offset = atomicAdd(d_dump_counter, static_cast<size_t>(tile_cnt));
+    }
+    tile_offset = g.shfl(tile_offset, 0);
+    int bias_g = tile_cnt - __popc(vote >> (key_idx % TILE_SIZE));
+
+    if (match) {
+      d_key[tile_offset + bias_g] = key;
+      if (d_score) {
+        d_score[tile_offset + bias_g] = score;
+      }
+    }
+
+#pragma unroll
+    for (int r = 0; r < TILE_SIZE; r++) {
+      unsigned int biased_vote = vote >> r;
+      bool cur_match = biased_vote & 1;
+      if (cur_match) {
+        int bias = tile_cnt - __popc(biased_vote);
+        size_t cur_idx = leading_key_idx + r;
+
+        float4* d_val_fp4 = reinterpret_cast<float4*>(d_val);
+        float4* vec_fp4 = reinterpret_cast<float4*>(bucket->vectors);
+        int d4 = dim / 4;
+        for (int j = g.thread_rank(); j < d4; j += TILE_SIZE) {
+          d_val_fp4[(tile_offset + bias) * d4 + j] = vec_fp4[cur_idx * d4 + j];
+        }
+      }
+    }
+  }
+}
+
+// Count the keys that satisfy PredFunctor without exporting them.
+template <class K, class V, class S,
+          template <typename, typename> class PredFunctor>
+__global__ void size_if_kernel(const Table<K, V, S>* __restrict table,
+                               Bucket<K, V, S>* buckets, const K pattern,
+                               const S threshold, size_t* d_counter) {
+  extern __shared__ unsigned char s[];
+  KVM<K, V, S>* const block_tuples{reinterpret_cast<KVM<K, V, S>*>(s)};
+
+  const size_t bucket_max_size{table->bucket_max_size};
+
+  size_t local_acc = 0;
+  __shared__ size_t block_acc;
+  PredFunctor<K, S> pred;
+
+  const size_t tid{blockIdx.x * blockDim.x + threadIdx.x};
+
+  if (threadIdx.x == 0) {
+    block_acc = 0;
+  }
+  __syncthreads();
+
+  for (size_t i = tid; i < table->capacity; i += blockDim.x * gridDim.x) {
+    Bucket<K, V, S>* const bucket{&buckets[i / bucket_max_size]};
+
+    const int key_idx{static_cast<int>(i % bucket_max_size)};
+    const K key{(bucket->keys(key_idx))->load(cuda::std::memory_order_relaxed)};
+    S score = bucket->scores(key_idx)->load(cuda::std::memory_order_relaxed);
+
+    if ((!IS_RESERVED_KEY(key)) && pred(key, score, pattern, threshold)) {
+      ++local_acc;
+    }
+  }
+  atomicAdd(&block_acc, local_acc);
+  __syncthreads();
+
+  if (threadIdx.x == 0) {
+    atomicAdd(d_counter, block_acc);
+  }
+}
+
 }  // namespace merlin
 }  // namespace nv
diff --git a/include/merlin_hashtable.cuh b/include/merlin_hashtable.cuh
index 2d631930..1d7c8fe1 100644
--- a/include/merlin_hashtable.cuh
+++ b/include/merlin_hashtable.cuh
@@ -917,6 +917,8 @@ class HashTable : public HashTableBase<K, V, S> {
     cudaDeviceProp deviceProp;
     CUDA_CHECK(cudaGetDeviceProperties(&deviceProp, options_.device_id));
     shared_mem_size_ = deviceProp.sharedMemPerBlock;
+    sm_cnt_ = deviceProp.multiProcessorCount;
+    max_threads_per_block_ = deviceProp.maxThreadsPerBlock;
     create_table<key_type, value_type, score_type>(
         &table_, allocator_, options_.dim, options_.init_capacity,
         options_.max_capacity, options_.max_hbm_for_vectors,
@@ -2611,22 +2613,76 @@ class HashTable : public HashTableBase<K, V, S> {
       return;
     }
     n = std::min(table_->capacity - offset, n);
+    if (n == 0) {
+      return;
+    }
 
-    const size_t score_size = scores ? sizeof(score_type) : 0;
-    const size_t kvm_size =
-        sizeof(key_type) + sizeof(value_type) * dim() + score_size;
-    const size_t block_size = std::min(shared_mem_size_ / 2 / kvm_size, 1024UL);
-    MERLIN_CHECK(
-        block_size > 0,
-        "[HierarchicalKV] block_size <= 0, the K-V-S size may be too large!");
-
-    const size_t shared_size = kvm_size * block_size;
-    const size_t grid_size = SAFE_GET_GRID_SIZE(n, block_size);
+    // The tiled kernels require the bucket size, offset and n to be
+    // TILE_SIZE-aligned; otherwise fall back to the shared-memory dump_kernel.
+    bool match_fast_cond = options_.max_bucket_size % TILE_SIZE == 0 &&
+                           options_.max_bucket_size >= TILE_SIZE &&
+                           offset % TILE_SIZE == 0 && n % TILE_SIZE == 0;
+
+    if (match_fast_cond) {
+      int grid_size = std::min(
+          sm_cnt_ * max_threads_per_block_ / options_.block_size,
+          static_cast<int>(SAFE_GET_GRID_SIZE(n, options_.block_size)));
+      if (sizeof(V) == sizeof(float) && dim() >= 32 && dim() % 4 == 0) {
+        if (dim() >= 128) {
+          const int TILE_SIZE = 32;
+          dump_kernel_v2_vectorized<key_type, value_type, score_type,
+                                    PredFunctor, TILE_SIZE>
+              <<<grid_size, options_.block_size, 0, stream>>>(
+                  d_table_, table_->buckets, pattern, threshold, keys, values,
+                  scores, offset, n, d_counter);
+        } else if (dim() >= 64) {
+          const int TILE_SIZE = 16;
+          dump_kernel_v2_vectorized<key_type, value_type, score_type,
+                                    PredFunctor, TILE_SIZE>
+              <<<grid_size, options_.block_size, 0, stream>>>(
+                  d_table_, table_->buckets, pattern, threshold, keys, values,
+                  scores, offset, n, d_counter);
+        } else {
+          const int TILE_SIZE = 8;
+          dump_kernel_v2_vectorized<key_type, value_type, score_type,
+                                    PredFunctor, TILE_SIZE>
+              <<<grid_size, options_.block_size, 0, stream>>>(
+                  d_table_, table_->buckets, pattern, threshold, keys, values,
+                  scores, offset, n, d_counter);
+        }
+      } else {
+        if (dim() >= 32) {
+          const int TILE_SIZE = 32;
+          dump_kernel_v2<key_type, value_type, score_type, PredFunctor,
+                         TILE_SIZE>
+              <<<grid_size, options_.block_size, 0, stream>>>(
+                  d_table_, table_->buckets, pattern, threshold, keys, values,
+                  scores, offset, n, d_counter);
+        } else if (dim() >= 16) {
+          const int TILE_SIZE = 16;
+          dump_kernel_v2<key_type, value_type, score_type, PredFunctor,
+                         TILE_SIZE>
+              <<<grid_size, options_.block_size, 0, stream>>>(
+                  d_table_, table_->buckets, pattern, threshold, keys, values,
+                  scores, offset, n, d_counter);
+        } else {
+          const int TILE_SIZE = 8;
+          dump_kernel_v2<key_type, value_type, score_type, PredFunctor,
+                         TILE_SIZE>
+              <<<grid_size, options_.block_size, 0, stream>>>(
+                  d_table_, table_->buckets, pattern, threshold, keys, values,
+                  scores, offset, n, d_counter);
+        }
+      }
+    } else {
+      const size_t score_size = scores ? sizeof(score_type) : 0;
+      const size_t kvm_size =
+          sizeof(key_type) + sizeof(value_type) * dim() + score_size;
+      const size_t block_size =
+          std::min(shared_mem_size_ / 2 / kvm_size, 1024UL);
+      MERLIN_CHECK(
+          block_size > 0,
+          "[HierarchicalKV] block_size <= 0, the K-V-S size may be too large!");
 
-    dump_kernel<key_type, value_type, score_type, PredFunctor>
-        <<<grid_size, block_size, shared_size, stream>>>(
-            d_table_, table_->buckets, pattern, threshold, keys, values, scores,
-            offset, n, d_counter);
+      const size_t shared_size = kvm_size * block_size;
+      const size_t grid_size = SAFE_GET_GRID_SIZE(n, block_size);
+      dump_kernel<key_type, value_type, score_type, PredFunctor>
+          <<<grid_size, block_size, shared_size, stream>>>(
+              d_table_, table_->buckets, pattern, threshold, keys, values,
+              scores, offset, n, d_counter);
+    }
     CudaCheckError();
   }
 
@@ -2668,6 +2724,28 @@ class HashTable : public HashTableBase<K, V, S> {
     return h_size;
   }
 
+  /**
+   * @brief Returns the number of keys that meet the predicate `PredFunctor`.
+   *
+   * @param stream The CUDA stream that is used to execute the operation.
+   * @return The number of keys in the table that match the condition of
+   * `PredFunctor`.
+   */
+  template