From 93e2c8587cf304eaab6fa5b112e96a4fc3feb4da Mon Sep 17 00:00:00 2001
From: Lifann
Date: Wed, 14 Aug 2024 00:19:45 +0800
Subject: [PATCH] opt(export_batch_if): Optimize the tile-aligned path of
 export_batch_if to reduce memory wavefronts
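
When offset and n are multiples of the tile size and the buckets are
tile-aligned, export_batch_if now dispatches to dump_kernel_v2, which
needs no shared-memory staging: each tile of threads ballots on its
predicate matches, reserves contiguous output space with a single
atomicAdd of the tile's match count, and writes entries compacted by
intra-tile rank, so adjacent threads issue adjacent stores. The sketch
below illustrates only that ballot/popcount compaction in isolation;
compact_tile and its parameter names are hypothetical and appear
nowhere in this patch.

    #include <cooperative_groups.h>
    #include <cstdint>

    namespace cg = cooperative_groups;

    template <int TILE_SIZE>
    __device__ void compact_tile(cg::thread_block_tile<TILE_SIZE> g,
                                 bool match, uint64_t key,
                                 unsigned long long* counter,
                                 uint64_t* out_keys) {
      unsigned int vote = g.ballot(match);  // one bit per matching lane
      int tile_cnt = __popc(vote);          // matches in this tile
      unsigned long long base = 0;
      if (g.thread_rank() == 0) {
        // One atomic per tile instead of one per matching thread.
        base = atomicAdd(counter, static_cast<unsigned long long>(tile_cnt));
      }
      base = g.shfl(base, 0);
      if (match) {
        // Rank among the matching lanes below this one keeps writes dense.
        int rank = __popc(vote & ((1u << g.thread_rank()) - 1));
        out_keys[base + rank] = key;
      }
    }

In dump_kernel_v2 the same pattern extends to the value vectors: after
the ballot, the tile cooperatively copies each matched slot's dim
elements to its compacted position, which is where the wavefront
reduction comes from.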
---
 CMakeLists.txt                   |   6 +-
 include/merlin/core_kernels.cuh  |  56 ++++++++++++++
 include/merlin_hashtable.cuh     |  46 +++++++----
 tests/export_batch_if_test.cc.cu | 127 +++++++++++++++++++++++++++++++
 4 files changed, 221 insertions(+), 14 deletions(-)
 create mode 100644 tests/export_batch_if_test.cc.cu

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 34016ca8..c159fe2d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -163,4 +163,8 @@ TARGET_LINK_LIBRARIES(find_with_missed_keys_test gtest_main)
 add_executable(reserved_keys_test tests/reserved_keys_test.cc.cu)
 target_compile_features(reserved_keys_test PUBLIC cxx_std_14)
 set_target_properties(reserved_keys_test PROPERTIES CUDA_ARCHITECTURES OFF)
-TARGET_LINK_LIBRARIES(reserved_keys_test gtest_main)
\ No newline at end of file
+TARGET_LINK_LIBRARIES(reserved_keys_test gtest_main)
+
+add_executable(export_batch_if_test tests/export_batch_if_test.cc.cu)
+target_compile_features(export_batch_if_test PUBLIC cxx_std_14)
+set_target_properties(export_batch_if_test PROPERTIES CUDA_ARCHITECTURES OFF)

diff --git a/include/merlin/core_kernels.cuh b/include/merlin/core_kernels.cuh
index 5ee80e76..dd81b814 100644
--- a/include/merlin/core_kernels.cuh
+++ b/include/merlin/core_kernels.cuh
@@ -910,5 +910,61 @@ __global__ void dump_kernel(const Table<K, V, S>* __restrict table,
   }
 }
 
+template <class K, class V, class S,
+          template <typename, typename> class PredFunctor, int TILE_SIZE>
+__global__ void dump_kernel_v2(const Table<K, V, S>* __restrict table,
+                               Bucket<K, V, S>* buckets, const K pattern,
+                               const S threshold, K* d_key, V* __restrict d_val,
+                               S* __restrict d_score, const size_t offset,
+                               const size_t search_length,
+                               size_t* d_dump_counter) {
+  const size_t bucket_max_size = table->bucket_max_size;
+  int dim = table->dim;
+  auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());
+
+  PredFunctor<K, S> pred;
+  size_t tid = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
+
+  for (size_t ii = tid; ii < search_length; ii += gridDim.x * blockDim.x) {
+    size_t bkt_idx = (ii + offset) / bucket_max_size;
+    int key_idx = (ii + offset) % bucket_max_size;
+    int leading_key_idx = key_idx / TILE_SIZE * TILE_SIZE;
+    Bucket<K, V, S>* bucket = &(buckets[bkt_idx]);
+
+    const K key = (bucket->keys(key_idx))->load(cuda::std::memory_order_relaxed);
+    S score = bucket->scores(key_idx)->load(cuda::std::memory_order_relaxed);
+    bool match = (!IS_RESERVED_KEY(key)) && pred(key, score, pattern, threshold);
+    unsigned int vote = g.ballot(match);
+    int tile_cnt = __popc(vote);
+    size_t tile_offset = 0;
+    if (g.thread_rank() == 0) {
+      tile_offset = atomicAdd(d_dump_counter, static_cast<size_t>(tile_cnt));
+    }
+    tile_offset = g.shfl(tile_offset, 0);
+
+    if (match) {
+      // Rank among the tile's matching lanes keeps the output compact.
+      int rank = __popc(vote & ((1U << g.thread_rank()) - 1));
+      d_key[tile_offset + rank] = key;
+      if (d_score) {
+        d_score[tile_offset + rank] = score;
+      }
+    }
+
+#pragma unroll
+    for (int r = 0; r < TILE_SIZE; r++) {
+      bool cur_match = vote >> r & 1;
+      if (cur_match) {
+        int rank = __popc(vote & ((1U << r) - 1));
+        int cur_idx = leading_key_idx + r;
+        for (int j = g.thread_rank(); j < dim; j += TILE_SIZE) {
+          d_val[(tile_offset + rank) * dim + j] =
+              bucket->vectors[cur_idx * dim + j];
+        }
+      }
+    }
+  }
+}
+
 }  // namespace merlin
 }  // namespace nv

diff --git a/include/merlin_hashtable.cuh b/include/merlin_hashtable.cuh
index 2d631930..29782dca 100644
--- a/include/merlin_hashtable.cuh
+++ b/include/merlin_hashtable.cuh
@@ -917,6 +917,7 @@ class HashTable : public HashTableBase<K, V, S> {
     cudaDeviceProp deviceProp;
     CUDA_CHECK(cudaGetDeviceProperties(&deviceProp, options_.device_id));
     shared_mem_size_ = deviceProp.sharedMemPerBlock;
+    sm_cnt_ = deviceProp.multiProcessorCount;
     create_table<key_type, value_type, score_type>(
         &table_, allocator_, options_.dim, options_.init_capacity,
         options_.max_capacity, options_.max_hbm_for_vectors,
@@ -2611,22 +2612,40 @@ class HashTable : public HashTableBase<K, V, S> {
       return;
     }
     n = std::min(table_->capacity - offset, n);
+    if (n == 0) {
+      return;
+    }
 
-    const size_t score_size = scores ? sizeof(score_type) : 0;
-    const size_t kvm_size =
-        sizeof(key_type) + sizeof(value_type) * dim() + score_size;
-    const size_t block_size = std::min(shared_mem_size_ / 2 / kvm_size, 1024UL);
-    MERLIN_CHECK(
-        block_size > 0,
-        "[HierarchicalKV] block_size <= 0, the K-V-S size may be too large!");
+    constexpr int TILE_SIZE = 8;
+    bool match_fast_cond = options_.max_bucket_size % TILE_SIZE == 0 &&
+                           options_.max_bucket_size >= TILE_SIZE &&
+                           offset % TILE_SIZE == 0 && n % TILE_SIZE == 0;
 
-    const size_t shared_size = kvm_size * block_size;
-    const size_t grid_size = SAFE_GET_GRID_SIZE(n, block_size);
-
-    dump_kernel<key_type, value_type, score_type, PredFunctor>
-        <<<grid_size, block_size, shared_size, stream>>>(
-            d_table_, table_->buckets, pattern, threshold, keys, values, scores,
-            offset, n, d_counter);
+    if (match_fast_cond) {
+      int grid_size = std::min(
+          sm_cnt_,
+          static_cast<int>(SAFE_GET_GRID_SIZE(n, options_.block_size)));
+      dump_kernel_v2<key_type, value_type, score_type, PredFunctor, TILE_SIZE>
+          <<<grid_size, options_.block_size, 0, stream>>>(
+              d_table_, table_->buckets, pattern, threshold, keys, values,
+              scores, offset, n, d_counter);
+    } else {
+      const size_t score_size = scores ? sizeof(score_type) : 0;
+      const size_t kvm_size =
+          sizeof(key_type) + sizeof(value_type) * dim() + score_size;
+      const size_t block_size =
+          std::min(shared_mem_size_ / 2 / kvm_size, 1024UL);
+      MERLIN_CHECK(
+          block_size > 0,
+          "[HierarchicalKV] block_size <= 0, the K-V-S size may be too large!");
+
+      const size_t shared_size = kvm_size * block_size;
+      const size_t grid_size = SAFE_GET_GRID_SIZE(n, block_size);
+      dump_kernel<key_type, value_type, score_type, PredFunctor>
+          <<<grid_size, block_size, shared_size, stream>>>(
+              d_table_, table_->buckets, pattern, threshold, keys, values,
+              scores, offset, n, d_counter);
+    }
     CudaCheckError();
   }
 
@@ -3037,6 +3056,7 @@ class HashTable : public HashTableBase<K, V, S> {
   TableCore* table_ = nullptr;
   TableCore* d_table_ = nullptr;
   size_t shared_mem_size_ = 0;
+  int sm_cnt_ = 0;
   std::atomic_bool reach_max_capacity_{false};
   bool initialized_ = false;
   mutable group_shared_mutex mutex_;

diff --git a/tests/export_batch_if_test.cc.cu b/tests/export_batch_if_test.cc.cu
new file mode 100644
index 00000000..048df722
--- /dev/null
+++ b/tests/export_batch_if_test.cc.cu
@@ -0,0 +1,127 @@
+#include <algorithm>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <memory>
+#include <random>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+#include "merlin/types.cuh"
+#include "merlin_hashtable.cuh"
+#include "test_util.cuh"
+
+using K = uint64_t;
+using V = float;
+using S = uint64_t;
+using i64 = int64_t;
+using u64 = uint64_t;
+using f32 = float;
+using EvictStrategy = nv::merlin::EvictStrategy;
+using TableOptions = nv::merlin::HashTableOptions;
+
+template <class K, class S>
+struct ExportIfPredFunctor {
+  __forceinline__ __device__ bool operator()(const K& key, S& score,
+                                             const K& pattern,
+                                             const S& threshold) {
+    return score < threshold;
+  }
+};
+
+void test_export_batch_if() {
+  constexpr uint64_t CAP = 1024ul;
+  size_t n0 = 127;
+  size_t n1 = 128;
+  size_t n2 = 163;
+  size_t dim = 32;
+  size_t table_size = 0;
+  i64 pattern = 0;
+  u64 threshold = 40;
+
+  cudaStream_t stream;
+  CUDA_CHECK(cudaStreamCreate(&stream));
+
+  TableOptions options;
+  options.init_capacity = CAP;
+  options.max_capacity = CAP;
+  options.dim = dim;
dim; + options.max_hbm_for_vectors = nv::merlin::GB(100); + using Table = + nv::merlin::HashTable; + + std::unique_ptr table = std::make_unique
+  table->init(options);
+
+  test_util::KVMSBuffer<i64, f32, u64> buffer0;
+  buffer0.Reserve(n0, dim, stream);
+  buffer0.ToRange(0, 1, stream);
+  buffer0.Setscore((u64)15, stream);
+  table->insert_or_assign(n0, buffer0.keys_ptr(), buffer0.values_ptr(),
+                          buffer0.scores_ptr(), stream, true, false);
+  table_size = table->size(stream);
+  CUDA_CHECK(cudaStreamSynchronize(stream));
+  MERLIN_EXPECT_TRUE(table_size == n0, "Invalid table size.");
+
+  test_util::KVMSBuffer<i64, f32, u64> buffer1;
+  buffer1.Reserve(n1, dim, stream);
+  buffer1.ToRange(n0, 1, stream);
+  buffer1.Setscore((u64)30, stream);
+  table->insert_or_assign(n1, buffer1.keys_ptr(), buffer1.values_ptr(),
+                          buffer1.scores_ptr(), stream, true, false);
+  table_size = table->size(stream);
+  CUDA_CHECK(cudaStreamSynchronize(stream));
+  MERLIN_EXPECT_TRUE(table_size == n0 + n1, "Invalid table size.");
+
+  test_util::KVMSBuffer<i64, f32, u64> buffer2;
+  buffer2.Reserve(n2, dim, stream);
+  buffer2.ToRange(n0 + n1, 1, stream);
+  buffer2.Setscore((u64)45, stream);
+  table->insert_or_assign(n2, buffer2.keys_ptr(), buffer2.values_ptr(),
+                          buffer2.scores_ptr(), stream, true, false);
+  table_size = table->size(stream);
+  CUDA_CHECK(cudaStreamSynchronize(stream));
+  MERLIN_EXPECT_TRUE(table_size == n0 + n1 + n2, "Invalid table size.");
+
+  test_util::KVMSBuffer<i64, f32, u64> buffer_out;
+  buffer_out.Reserve(CAP, dim, stream);
+  buffer_out.ToZeros(stream);
+
+  size_t* d_cnt = nullptr;
+  size_t h_cnt = 0;
+  CUDA_CHECK(cudaMallocAsync(&d_cnt, sizeof(size_t), stream));
+  CUDA_CHECK(cudaMemsetAsync(d_cnt, 0, sizeof(size_t), stream));
+  CUDA_CHECK(cudaStreamSynchronize(stream));
+  table->export_batch_if<ExportIfPredFunctor>(
+      pattern, threshold, static_cast<size_t>(CAP), 0, d_cnt,
+      buffer_out.keys_ptr(), buffer_out.values_ptr(), buffer_out.scores_ptr(),
+      stream);
+  CUDA_CHECK(cudaMemcpyAsync(&h_cnt, d_cnt, sizeof(size_t),
+                             cudaMemcpyDeviceToHost, stream));
+  CUDA_CHECK(cudaStreamSynchronize(stream));
+  MERLIN_EXPECT_TRUE(h_cnt == n0 + n1, "export_batch_if got an invalid cnt.");
+
+  buffer_out.SyncData(false, stream);
+  CUDA_CHECK(cudaStreamSynchronize(stream));
+
+  std::unordered_map<i64, u64> record;
+  for (size_t i = 0; i < h_cnt; i++) {
+    i64 key = buffer_out.keys_ptr(false)[i];
+    u64 score = buffer_out.scores_ptr(false)[i];
+    MERLIN_EXPECT_TRUE(score < threshold, "Exported score over threshold.");
+    record[key] = score;
+    for (size_t j = 0; j < dim; j++) {
+      f32 value = buffer_out.values_ptr(false)[i * dim + j];
+      MERLIN_EXPECT_TRUE(key == static_cast<i64>(value), "Invalid value.");
+    }
+  }
+  MERLIN_EXPECT_TRUE(record.size() == n0 + n1, "Duplicated keys exported.");
+
+  printf("done\n");
+}
+
+int main() {
+  test_export_batch_if();
+  return 0;
+}