Skip to content

Commit

Permalink
fix: Value error when dump table with none TILE_SIZE integer multiple…
Browse files Browse the repository at this point in the history
…s of offset and search length
  • Loading branch information
Lifann committed Jan 3, 2025
1 parent 989c5cf commit 85989b6
Showing 1 changed file with 25 additions and 22 deletions.
47 changes: 25 additions & 22 deletions include/merlin_hashtable.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2617,56 +2617,59 @@ class HashTable : public HashTableBase<K, V, S> {
return;
}

bool match_fast_cond = options_.max_bucket_size % TILE_SIZE == 0 &&
bool basic_fast_cond = options_.max_bucket_size % TILE_SIZE == 0 &&
options_.max_bucket_size >= TILE_SIZE &&
offset % TILE_SIZE == 0 && n % TILE_SIZE == 0;
bool use_fast_mode = false;

if (match_fast_cond) {
if (basic_fast_cond) {
int grid_size = std::min(
sm_cnt_ * max_threads_per_block_ / options_.block_size,
static_cast<int>(SAFE_GET_GRID_SIZE(n, options_.block_size)));
if (sizeof(V) == sizeof(float) && dim() >= 32 && dim() % 4 == 0) {
if (dim() >= 128) {
const int TILE_SIZE = 32;
dump_kernel_v2_vectorized<key_type, value_type, score_type, PredFunctor, TILE_SIZE>
if (dim() >= 128 && offset % 32 == 0 && n % 32 == 0) {
dump_kernel_v2_vectorized<key_type, value_type, score_type, PredFunctor, 32>
<<<grid_size, options_.block_size, 0, stream>>>(
d_table_, table_->buckets, pattern, threshold, keys, values,
scores, offset, n, d_counter);
} else if (dim() >= 64) {
const int TILE_SIZE = 16;
dump_kernel_v2_vectorized<key_type, value_type, score_type, PredFunctor, TILE_SIZE>
use_fast_mode = true;
} else if (dim() >= 64 && offset % 16 == 0 && n % 16 == 0) {
dump_kernel_v2_vectorized<key_type, value_type, score_type, PredFunctor, 16>
<<<grid_size, options_.block_size, 0, stream>>>(
d_table_, table_->buckets, pattern, threshold, keys, values,
scores, offset, n, d_counter);
} else {
const int TILE_SIZE = 8;
dump_kernel_v2_vectorized<key_type, value_type, score_type, PredFunctor, TILE_SIZE>
use_fast_mode = true;
} else if (dim() >= 8 && offset % 8 == 0 && n % 8 == 0) {
dump_kernel_v2_vectorized<key_type, value_type, score_type, PredFunctor, 8>
<<<grid_size, options_.block_size, 0, stream>>>(
d_table_, table_->buckets, pattern, threshold, keys, values,
scores, offset, n, d_counter);
use_fast_mode = true;
}
} else {
if (dim() >= 32) {
const int TILE_SIZE = 32;
dump_kernel_v2<key_type, value_type, score_type, PredFunctor, TILE_SIZE>
}
if (!use_fast_mode) {
if (dim() >= 32 && offset % 32 == 0 && n % 32 == 0) {
dump_kernel_v2<key_type, value_type, score_type, PredFunctor, 32>
<<<grid_size, options_.block_size, 0, stream>>>(
d_table_, table_->buckets, pattern, threshold, keys, values,
scores, offset, n, d_counter);
} else if (dim() >= 16) {
const int TILE_SIZE = 16;
dump_kernel_v2<key_type, value_type, score_type, PredFunctor, TILE_SIZE>
use_fast_mode = true;
} else if (dim() >= 16 && offset % 16 == 0 && n % 16 == 0) {
dump_kernel_v2<key_type, value_type, score_type, PredFunctor, 16>
<<<grid_size, options_.block_size, 0, stream>>>(
d_table_, table_->buckets, pattern, threshold, keys, values,
scores, offset, n, d_counter);
} else {
const int TILE_SIZE = 8;
dump_kernel_v2<key_type, value_type, score_type, PredFunctor, TILE_SIZE>
use_fast_mode = true;
} else if (dim() >= 8 && offset % 8 == 0 && n % 8 == 0) {
dump_kernel_v2<key_type, value_type, score_type, PredFunctor, 8>
<<<grid_size, options_.block_size, 0, stream>>>(
d_table_, table_->buckets, pattern, threshold, keys, values,
scores, offset, n, d_counter);
use_fast_mode = true;
}
}
} else {
}
if (!use_fast_mode) {
const size_t score_size = scores ? sizeof(score_type) : 0;
const size_t kvm_size =
sizeof(key_type) + sizeof(value_type) * dim() + score_size;
Expand Down

0 comments on commit 85989b6

Please sign in to comment.