From 85989b6655df55a72dcb725acc2abf7c4fe7425a Mon Sep 17 00:00:00 2001 From: Lifann Date: Fri, 3 Jan 2025 22:23:01 +0800 Subject: [PATCH] fix: Value error when dump table with none TILE_SIZE integer multiples of offset and search length --- include/merlin_hashtable.cuh | 47 +++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/include/merlin_hashtable.cuh b/include/merlin_hashtable.cuh index 1d7c8fe1..f898c824 100644 --- a/include/merlin_hashtable.cuh +++ b/include/merlin_hashtable.cuh @@ -2617,56 +2617,59 @@ class HashTable : public HashTableBase { return; } - bool match_fast_cond = options_.max_bucket_size % TILE_SIZE == 0 && + bool basic_fast_cond = options_.max_bucket_size % TILE_SIZE == 0 && options_.max_bucket_size >= TILE_SIZE && offset % TILE_SIZE == 0 && n % TILE_SIZE == 0; + bool use_fast_mode = false; - if (match_fast_cond) { + if (basic_fast_cond) { int grid_size = std::min( sm_cnt_ * max_threads_per_block_ / options_.block_size, static_cast(SAFE_GET_GRID_SIZE(n, options_.block_size))); if (sizeof(V) == sizeof(float) && dim() >= 32 && dim() % 4 == 0) { - if (dim() >= 128) { - const int TILE_SIZE = 32; - dump_kernel_v2_vectorized + if (dim() >= 128 && offset % 32 == 0 && n % 32 == 0) { + dump_kernel_v2_vectorized <<>>( d_table_, table_->buckets, pattern, threshold, keys, values, scores, offset, n, d_counter); - } else if (dim() >= 64) { - const int TILE_SIZE = 16; - dump_kernel_v2_vectorized + use_fast_mode = true; + } else if (dim() >= 64 && offset % 16 == 0 && n % 16 == 0) { + dump_kernel_v2_vectorized <<>>( d_table_, table_->buckets, pattern, threshold, keys, values, scores, offset, n, d_counter); - } else { - const int TILE_SIZE = 8; - dump_kernel_v2_vectorized + use_fast_mode = true; + } else if (dim() >= 8 && offset % 8 == 0 && n % 8 == 0) { + dump_kernel_v2_vectorized <<>>( d_table_, table_->buckets, pattern, threshold, keys, values, scores, offset, n, d_counter); + use_fast_mode = true; } - } else { - if (dim() >= 32) { - const int TILE_SIZE = 32; - dump_kernel_v2 + } + if (!use_fast_mode) { + if (dim() >= 32 && offset % 32 == 0 && n % 32 == 0) { + dump_kernel_v2 <<>>( d_table_, table_->buckets, pattern, threshold, keys, values, scores, offset, n, d_counter); - } else if (dim() >= 16) { - const int TILE_SIZE = 16; - dump_kernel_v2 + use_fast_mode = true; + } else if (dim() >= 16 && offset % 16 == 0 && n % 16 == 0) { + dump_kernel_v2 <<>>( d_table_, table_->buckets, pattern, threshold, keys, values, scores, offset, n, d_counter); - } else { - const int TILE_SIZE = 8; - dump_kernel_v2 + use_fast_mode = true; + } else if (dim() >= 8 && offset % 8 == 0 && n % 8 == 0) { + dump_kernel_v2 <<>>( d_table_, table_->buckets, pattern, threshold, keys, values, scores, offset, n, d_counter); + use_fast_mode = true; } } - } else { + } + if (!use_fast_mode) { const size_t score_size = scores ? sizeof(score_type) : 0; const size_t kvm_size = sizeof(key_type) + sizeof(value_type) * dim() + score_size;