diff --git a/.gitignore b/.gitignore
index 122c321f..eba684ea 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,7 +2,8 @@
 .idea
 .vscode
 build
-
+.clwb
+cmake-build-debug/
 docs/build
 docs/source/README.md
 docs/source/CONTRIBUTING.md
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 77190094..91fd632e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -159,3 +159,13 @@ add_executable(find_with_missed_keys_test tests/find_with_missed_keys_test.cc.cu)
 target_compile_features(find_with_missed_keys_test PUBLIC cxx_std_14)
 set_target_properties(find_with_missed_keys_test PROPERTIES CUDA_ARCHITECTURES OFF)
 TARGET_LINK_LIBRARIES(find_with_missed_keys_test gtest_main)
+
+add_executable(reserved_bucket_test tests/reserved_bucket_test.cc.cu)
+target_compile_features(reserved_bucket_test PUBLIC cxx_std_14)
+set_target_properties(reserved_bucket_test PROPERTIES CUDA_ARCHITECTURES OFF)
+TARGET_LINK_LIBRARIES(reserved_bucket_test gtest_main)
+
+add_executable(key_option_test tests/key_option_test.cc.cu)
+target_compile_features(key_option_test PUBLIC cxx_std_14)
+set_target_properties(key_option_test PROPERTIES CUDA_ARCHITECTURES OFF)
+TARGET_LINK_LIBRARIES(key_option_test gtest_main)
\ No newline at end of file
diff --git a/README.md b/README.md
index 798819e3..5f1a7866 100644
--- a/README.md
+++ b/README.md
@@ -172,6 +172,11 @@ cd HierarchicalKV && mkdir -p build && cd build
 cmake -DCMAKE_BUILD_TYPE=Release -Dsm=80 .. && make -j
 ```
 
+For Debug:
+```shell
+cmake -DCMAKE_BUILD_TYPE=Debug -Dsm=80 .. && make -j
+```
+
 For Benchmark:
 ```shell
 ./merlin_hashtable_benchmark
diff --git a/include/merlin/core_kernels.cuh b/include/merlin/core_kernels.cuh
index d8a85002..bb5e43fe 100644
--- a/include/merlin/core_kernels.cuh
+++ b/include/merlin/core_kernels.cuh
@@ -24,6 +24,7 @@
 #include "core_kernels/kernel_utils.cuh"
 #include "core_kernels/lookup.cuh"
 #include "core_kernels/lookup_ptr.cuh"
+#include "core_kernels/reserved_bucket.cuh"
 #include "core_kernels/update.cuh"
 #include "core_kernels/update_score.cuh"
 #include "core_kernels/update_values.cuh"
@@ -341,6 +342,10 @@ void create_table(Table<K, V, S>** table, BaseAllocator* allocator,
                         (*table)->buckets_num * sizeof(Bucket<K, V, S>)));
 
   initialize_buckets<K, V, S>(table, allocator, 0, (*table)->buckets_num);
+
+  ReservedBucket<K, V>::initialize(
+      &(*table)->reserved_bucket, allocator, (*table)->dim);
+
   CudaCheckError();
 }
@@ -632,7 +637,8 @@ __global__ void remove_kernel(const Table<K, V, S>* __restrict table,
                               Bucket<K, V, S>* __restrict buckets,
                               int* __restrict buckets_size,
                               const size_t bucket_max_size,
-                              const size_t buckets_num, size_t N) {
+                              const size_t buckets_num,
+                              size_t N) {
   auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());
   int rank = g.thread_rank();
 
   for (size_t t = (blockIdx.x * blockDim.x) + threadIdx.x; t < N * TILE_SIZE;
        t += blockDim.x * gridDim.x) {
     int key_idx = t / TILE_SIZE;
     K find_key = keys[key_idx];
-    if (IS_RESERVED_KEY(find_key)) continue;
+    if (IS_RESERVED_KEY(find_key)) {
+      table->reserved_bucket->erase(find_key, table->dim);
+      continue;
+    }
 
     int key_pos = -1;
@@ -700,7 +709,8 @@ __global__ void remove_kernel(const Table<K, V, S>* __restrict table,
                               Bucket<K, V, S>* __restrict buckets,
                               int* __restrict buckets_size,
                               const size_t bucket_max_size,
-                              const size_t buckets_num, size_t N) {
+                              const size_t buckets_num,
+                              size_t N) {
   auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());
   PredFunctor<K, S> pred;
@@ -719,8 +729,8 @@ __global__ void remove_kernel(const Table<K, V, S>* __restrict table,
           bucket->keys(key_offset)->load(cuda::std::memory_order_relaxed);
       current_score =
           bucket->scores(key_offset)->load(cuda::std::memory_order_relaxed);
-      if (!IS_RESERVED_KEY(current_key)) {
-        if (pred(current_key, current_score, pattern, threshold)) {
+      if (pred(current_key, current_score, pattern, threshold)) {
+        if (!IS_RESERVED_KEY(current_key)) {
           atomicAdd(count, 1);
           key_pos = key_offset;
           bucket->digests(key_pos)[0] = empty_digest<K>();
@@ -732,6 +742,7 @@ __global__ void remove_kernel(const Table<K, V, S>* __restrict table,
               cuda::std::memory_order_relaxed);
           atomicSub(&buckets_size[bkt_idx], 1);
         } else {
+          table->reserved_bucket->erase(current_key, table->dim);
           key_offset++;
         }
       } else {
diff --git a/include/merlin/core_kernels/accum_or_assign.cuh b/include/merlin/core_kernels/accum_or_assign.cuh
index 7f557f59..83102538 100644
--- a/include/merlin/core_kernels/accum_or_assign.cuh
+++ b/include/merlin/core_kernels/accum_or_assign.cuh
@@ -98,7 +98,9 @@ __global__ void accum_or_assign_kernel_with_io(
 
     const K insert_key = keys[key_idx];
 
-    if (IS_RESERVED_KEY(insert_key)) continue;
+    if (IS_RESERVED_KEY(insert_key)) {
+      continue;
+    }
 
     const S insert_score =
         ScoreFunctor::desired_when_missed(scores, key_idx, global_epoch);
diff --git a/include/merlin/core_kernels/reserved_bucket.cuh b/include/merlin/core_kernels/reserved_bucket.cuh
new file mode 100644
index 00000000..cfa4c7b3
--- /dev/null
+++ b/include/merlin/core_kernels/reserved_bucket.cuh
@@ -0,0 +1,267 @@
+/*
+* Copyright (c) 2024, NVIDIA CORPORATION.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+#pragma once
+#include "../allocator.cuh"
+#include "../types.cuh"
+
+namespace nv {
+namespace merlin {
+
+#define RESERVED_BUCKET_SIZE 4
+#define RESERVED_BUCKET_MASK 3
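+
+// Memory layout (one contiguous device allocation made by initialize()):
+//
+//   [ ReservedBucket header: locks[RESERVED_BUCKET_SIZE], keys[RESERVED_BUCKET_SIZE] flags ]
+//   [ vector storage: RESERVED_BUCKET_SIZE * dim values of type V ]
+//
+// get_vector() recovers the vector region from the address just past the
+// `keys` flag array, so the header and the vectors must stay in a single
+// allocation (see initialize()).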
+
+template <typename K, typename V>
+struct ReservedBucket;
+
+template <typename K, typename V>
+__global__ static void rb_size_kernel(ReservedBucket<K, V>* reserved_bucket,
+                                      size_t* size);
+
+template <typename K, typename V>
+struct ReservedBucket {
+  cuda::atomic<int, cuda::thread_scope_device> locks[RESERVED_BUCKET_SIZE];
+  bool keys[RESERVED_BUCKET_SIZE];
+
+  static void initialize(ReservedBucket<K, V>** reserved_bucket,
+                         BaseAllocator* allocator, size_t dim) {
+    size_t total_size = sizeof(ReservedBucket<K, V>);
+    total_size += sizeof(V) * RESERVED_BUCKET_SIZE * dim;
+    void* memory_block;
+    allocator->alloc(MemoryType::Device, &memory_block, total_size);
+    CUDA_CHECK(cudaMemset(memory_block, 0, total_size));
+    *reserved_bucket = static_cast<ReservedBucket<K, V>*>(memory_block);
+  }
+
+  __forceinline__ __device__ V* get_vector(K key, size_t dim) {
+    V* vector = reinterpret_cast<V*>(keys + RESERVED_BUCKET_SIZE);
+    size_t index = key & RESERVED_BUCKET_MASK;
+    return vector + index * dim;
+  }
+
+  __forceinline__ __device__ bool contains(K key) {
+    size_t index = key & RESERVED_BUCKET_MASK;
+    return keys[index];
+  }
+
+  __forceinline__ __device__ void set_key(K key, bool value = true) {
+    size_t index = key & RESERVED_BUCKET_MASK;
+    keys[index] = value;
+  }
+
+  // Since a reserved key should always exist in the table, insert_or_assign,
+  // insert_and_evict, and assign are all equivalent to write_vector here.
+  __forceinline__ __device__ void write_vector(K key, size_t dim,
+                                               const V* data) {
+    V* vectors = get_vector(key, dim);
+    set_key(key);
+    for (int i = 0; i < dim; i++) {
+      vectors[i] = data[i];
+    }
+  }
+
+  __forceinline__ __device__ void read_vector(K key, size_t dim,
+                                              V* out_data) {
+    V* vectors = get_vector(key, dim);
+    for (int i = 0; i < dim; i++) {
+      out_data[i] = vectors[i];
+    }
+  }
+
+  __forceinline__ __device__ void erase(K key, size_t dim) {
+    V* vectors = get_vector(key, dim);
+    set_key(key, false);
+    for (int i = 0; i < dim; i++) {
+      vectors[i] = 0;
+    }
+  }
+
+  // Search for the specified key and return the pointer of its value.
+  __forceinline__ __device__ bool find(K key, size_t dim, V** values) {
+    if (contains(key)) {
+      V* vectors = get_vector(key, dim);
+      *values = vectors;
+      return true;
+    } else {
+      return false;
+    }
+  }
+
+  // Search for the specified key; insert it first when missing.
+  __forceinline__ __device__ bool find_or_insert(K key, size_t dim,
+                                                 const V* values) {
+    if (contains(key)) {
+      return true;
+    } else {
+      write_vector(key, dim, values);
+      return false;
+    }
+  }
+
+  // Search for the specified key and return the pointer of its value.
+  // Insert it first when missing.
+  __forceinline__ __device__ bool find_or_insert(K key, size_t dim,
+                                                 V** values) {
+    if (contains(key)) {
+      V* vectors = get_vector(key, dim);
+      *values = vectors;
+      return true;
+    } else {
+      write_vector(key, dim, *values);
+      return false;
+    }
+  }
+
+  __forceinline__ __device__ void accum_or_assign(K key, bool is_accum,
+                                                  size_t dim,
+                                                  const V* values) {
+    if (is_accum) {
+      V* vectors = get_vector(key, dim);
+      for (int i = 0; i < dim; i++) {
+        vectors[i] += values[i];
+      }
+    } else {
+      write_vector(key, dim, values);
+    }
+    set_key(key);
+  }
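+
+  // Note: the `locks` array above is allocated but not yet used by any of
+  // these methods; as written they assume at most one thread operates on a
+  // given reserved key at a time. A concurrent-writer path would need to
+  // acquire the per-slot lock before mutating `keys[]` or the vectors.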
+
+  /*
+   * @brief Exports the reserved bucket to key-value tuples.
+   * @param n The maximum number of exported pairs.
+   * @param offset The position of the key to search.
+   * @param out_keys The keys to dump into GPU-accessible memory with
+   * shape (n).
+   * @param out_values The values to dump into GPU-accessible memory with
+   * shape (n, DIM).
+   * @return The number of elements dumped.
+   */
+  __forceinline__ __device__ size_t export_batch(size_t n, size_t offset,
+                                                 K* out_keys, size_t dim,
+                                                 V* out_values,
+                                                 size_t batch_size) {
+    // `offset` indexes slots, not occupied entries.
+    if (offset >= RESERVED_BUCKET_SIZE) {
+      return 0;
+    }
+    size_t count = 0;
+    V* vectors = reinterpret_cast<V*>(keys + RESERVED_BUCKET_SIZE);
+    for (size_t i = offset; i < RESERVED_BUCKET_SIZE && count < n; i++) {
+      if (keys[i]) {
+        // Slots store only a presence flag, so the key is reconstructed from
+        // the reserved-key pattern: high bits all set, low bits = slot index.
+        out_keys[count] =
+            (~static_cast<K>(RESERVED_BUCKET_MASK)) | static_cast<K>(i);
+        const V* vector = vectors + i * dim;
+        for (size_t j = 0; j < dim; j++) {
+          out_values[count * dim + j] = vector[j];
+        }
+        count++;
+      }
+    }
+    return count;
+  }
+
+  /**
+   * @brief Returns the number of occupied slots in the reserved bucket.
+   */
+  __forceinline__ __device__ size_t size() {
+    size_t count = 0;
+    for (int i = 0; i < RESERVED_BUCKET_SIZE; i++) {
+      if (keys[i]) {
+        count++;
+      }
+    }
+    return count;
+  }
+
+  size_t size_host() {
+    size_t* d_size;
+    CUDA_CHECK(cudaMalloc(&d_size, sizeof(size_t)));
+    rb_size_kernel<<<1, 1>>>(this, d_size);
+    CUDA_CHECK(cudaDeviceSynchronize());
+    size_t h_size;
+    CUDA_CHECK(cudaMemcpy(&h_size, d_size, sizeof(size_t),
+                          cudaMemcpyDeviceToHost));
+    CUDA_CHECK(cudaFree(d_size));
+    return h_size;
+  }
+
+  /**
+   * @brief Removes all of the elements in the reserved bucket without
+   * releasing the underlying memory.
+   */
+  __forceinline__ __device__ void clear(size_t dim) {
+    // cudaMemset is a host API and cannot be called from device code, so
+    // clear the flags and the vector storage with plain loops instead.
+    for (int i = 0; i < RESERVED_BUCKET_SIZE; i++) {
+      keys[i] = false;
+    }
+    V* vectors = reinterpret_cast<V*>(keys + RESERVED_BUCKET_SIZE);
+    for (size_t i = 0; i < RESERVED_BUCKET_SIZE * dim; i++) {
+      vectors[i] = 0;
+    }
+  }
+};
+
+template <typename K, typename V>
+__global__ static void rb_size_kernel(ReservedBucket<K, V>* reserved_bucket,
+                                      size_t* size) {
+  *size = reserved_bucket->size();
+}
+
+template <typename K, typename V>
+__global__ void rb_write_vector_kernel(ReservedBucket<K, V>* reserved_bucket,
+                                       K key, size_t dim, const V* data) {
+  reserved_bucket->write_vector(key, dim, data);
+}
+
+template <typename K, typename V>
+__global__ void rb_read_vector_kernel(ReservedBucket<K, V>* reserved_bucket,
+                                      K key, size_t dim, V* out_data) {
+  reserved_bucket->read_vector(key, dim, out_data);
+}
+
+template <typename K, typename V>
+__global__ void rb_erase_kernel(ReservedBucket<K, V>* reserved_bucket, K key,
+                                size_t dim) {
+  reserved_bucket->erase(key, dim);
+}
+
+template <typename K, typename V>
+__global__ void rb_clear_kernel(ReservedBucket<K, V>* reserved_bucket,
+                                size_t dim) {
+  reserved_bucket->clear(dim);
+}
+
+template <typename K, typename V>
+__global__ void rb_find_or_insert_kernel(
+    ReservedBucket<K, V>* reserved_bucket, K key, size_t dim, const V* data,
+    bool* is_found) {
+  *is_found = reserved_bucket->find_or_insert(key, dim, data);
+}
+
+template <typename K, typename V>
+__global__ void rb_find_or_insert_kernel(
+    ReservedBucket<K, V>* reserved_bucket, K key, size_t dim, bool* is_found,
+    V** values) {
+  *is_found = reserved_bucket->find_or_insert(key, dim, values);
+}
+
+template <typename K, typename V>
+__global__ void rb_accum_or_assign_kernel(
+    ReservedBucket<K, V>* reserved_bucket, K key, bool is_accum, size_t dim,
+    const V* data) {
+  reserved_bucket->accum_or_assign(key, is_accum, dim, data);
+}
+
+template <typename K, typename V>
+__global__ void rb_find_kernel(ReservedBucket<K, V>* reserved_bucket, K key,
+                               size_t dim, bool* found, V** values) {
+  *found = reserved_bucket->find(key, dim, values);
+}
+
+template <typename K, typename V>
+__global__ void rb_export_batch_kernel(ReservedBucket<K, V>* reserved_bucket,
+                                       size_t n, size_t offset, K* keys,
+                                       size_t dim, V* values,
+                                       size_t batch_size) {
+  reserved_bucket->export_batch(n, offset, keys, dim, values, batch_size);
+}
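+
+// A typical host-side flow, sketched after reserved_bucket_test.cc.cu
+// (`allocator`, `key`, `dim`, and the device buffer `d_vec` are assumed to
+// be set up by the caller):
+//
+//   ReservedBucket<K, V>* bucket;
+//   ReservedBucket<K, V>::initialize(&bucket, allocator, dim);
+//   rb_write_vector_kernel<<<1, 1>>>(bucket, key, dim, d_vec);
+//   CUDA_CHECK(cudaDeviceSynchronize());
+//   size_t n = bucket->size_host();  // == 1 after the first write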
+
+}  // namespace merlin
+}  // namespace nv
\ No newline at end of file
diff --git a/include/merlin/types.cuh b/include/merlin/types.cuh
index 176aa208..be5b8383 100644
--- a/include/merlin/types.cuh
+++ b/include/merlin/types.cuh
@@ -140,6 +140,9 @@ class Lock {
 
 using Mutex = Lock;
 
+template <typename K, typename V>
+struct ReservedBucket;
+
 template <class K, class V, class S>
 struct Table {
   Bucket<K, V, S>* buckets;
@@ -163,6 +166,7 @@ struct Table {
   int device_id = 0;  // Device id
   int tile_size;
   std::vector buckets_address;
+  ReservedBucket<K, V>* reserved_bucket;
 };
 
 template <class K, class V, class S>
diff --git a/include/merlin_hashtable.cuh b/include/merlin_hashtable.cuh
index 9586f16a..868d4a39 100644
--- a/include/merlin_hashtable.cuh
+++ b/include/merlin_hashtable.cuh
@@ -360,7 +360,7 @@
    *
    * @param ignore_evict_strategy A boolean option indicating whether if
    * the accum_or_assign ignores the evict strategy of table with current
-   * scores anyway. If true, it does not check whether the scores confroms to
+   * scores anyway. If true, it does not check whether the scores conform to
    * the evict strategy. If false, it requires the scores follow the evict
    * strategy of table.
    */
@@ -2641,7 +2641,7 @@
         size_ptr + start_i, size_ptr + end_i, 0, thrust::plus<int>());
   }
 
-
+  h_size += table_->reserved_bucket->size_host();
   CudaCheckError();
   return h_size;
 }
diff --git a/tests/key_option_test.cc.cu b/tests/key_option_test.cc.cu
new file mode 100644
index 00000000..15a86a9f
--- /dev/null
+++ b/tests/key_option_test.cc.cu
@@ -0,0 +1,70 @@
+#include <gtest/gtest.h>
+#include <cstdint>
+
+struct KeyOptions {
+  uint64_t EMPTY_KEY;
+  uint64_t RECLAIM_KEY;
+  uint64_t LOCKED_KEY;
+
+  __host__ __device__ KeyOptions(uint64_t emptyKey, uint64_t reclaimKey,
+                                 uint64_t lockedKey)
+      : EMPTY_KEY(emptyKey), RECLAIM_KEY(reclaimKey), LOCKED_KEY(lockedKey) {}
+
+  virtual __host__ __device__ bool isReservedKey(uint64_t key) const = 0;
+  virtual __host__ __device__ bool isVacantKey(uint64_t key) const = 0;
+  virtual ~KeyOptions() {}
+};
+
+constexpr uint64_t RESERVED_KEY_MASK = UINT64_C(0xFFFFFFFFFFFFFFFC);
+constexpr uint64_t VACANT_KEY_MASK = UINT64_C(0xFFFFFFFFFFFFFFFE);
+
+class DefaultKeyOptions : public KeyOptions {
+ public:
+  __host__ __device__ DefaultKeyOptions(
+      uint64_t emptyKey = UINT64_C(0xFFFFFFFFFFFFFFFF),
+      uint64_t reclaimKey = UINT64_C(0xFFFFFFFFFFFFFFFE),
+      uint64_t lockedKey = UINT64_C(0xFFFFFFFFFFFFFFFD))
+      : KeyOptions(emptyKey, reclaimKey, lockedKey) {}
+
+  __host__ __device__ bool isReservedKey(uint64_t key) const override {
+    return (RESERVED_KEY_MASK & (key)) == RESERVED_KEY_MASK;
+  }
+
+  __host__ __device__ bool isVacantKey(uint64_t key) const override {
+    return (VACANT_KEY_MASK & (key)) == VACANT_KEY_MASK;
+  }
+};
+
+class CustomKeyOptions : public DefaultKeyOptions {
+ public:
+  __host__ __device__ CustomKeyOptions(
+      uint64_t emptyKey = UINT64_C(0xFFFFFFFFFFFFFFFA))  // Custom EMPTY_KEY
+      : DefaultKeyOptions(emptyKey, UINT64_C(0xFFFFFFFFFFFFFFFE),
+                          UINT64_C(0xFFFFFFFFFFFFFFFD)) {}
+
+  __host__ __device__ bool isReservedKey(uint64_t key) const override {
+    return key == EMPTY_KEY || key == RECLAIM_KEY || key == LOCKED_KEY;
+  }
+
+  __host__ __device__ bool isVacantKey(uint64_t key) const override {
+    return key == EMPTY_KEY || key == RECLAIM_KEY;
+  }
+};
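+
+// With the default masks, isReservedKey() matches exactly the four highest
+// key values 0xFFFFFFFFFFFFFFFC..0xFFFFFFFFFFFFFFFF, and isVacantKey() the
+// top two. A minimal boundary check (a sketch, not part of the original
+// test) could read:
+//
+//   DefaultKeyOptions d;
+//   EXPECT_TRUE(d.isReservedKey(UINT64_C(0xFFFFFFFFFFFFFFFC)));
+//   EXPECT_FALSE(d.isReservedKey(UINT64_C(0xFFFFFFFFFFFFFFFB)));
+//   EXPECT_TRUE(d.isVacantKey(UINT64_C(0xFFFFFFFFFFFFFFFE)));
+//   EXPECT_FALSE(d.isVacantKey(UINT64_C(0xFFFFFFFFFFFFFFFD)));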
+
+void testKeyOptions() {
+  DefaultKeyOptions opts;
+  if (opts.isReservedKey(UINT64_C(0xFFFFFFFFFFFFFFFF))) {
+    printf("Empty key is reserved\n");
+  }
+  if (!opts.isReservedKey(UINT64_C(0x1))) {
+    printf("Non-reserved key is not reserved\n");
+  }
+
+  CustomKeyOptions customOpts;
+  if (!customOpts.isReservedKey(UINT64_C(0xFFFFFFFFFFFFFFFF))) {
+    printf("Empty key is no longer treated as reserved\n");
+  }
+  if (customOpts.isReservedKey(UINT64_C(0xFFFFFFFFFFFFFFFA))) {
+    printf("New empty key is reserved\n");
+  }
+}
+
+TEST(KeyOptionsTest, testKeyOptions) {
+  testKeyOptions();
+}
\ No newline at end of file
diff --git a/tests/merlin_hashtable_test.cc.cu b/tests/merlin_hashtable_test.cc.cu
index ed80be3c..92171e54 100644
--- a/tests/merlin_hashtable_test.cc.cu
+++ b/tests/merlin_hashtable_test.cc.cu
@@ -186,18 +186,24 @@ void test_basic(size_t max_hbm_for_vectors) {
 
   for (int i = 0; i < TEST_TIMES; i++) {
     std::unique_ptr<Table> table = std::make_unique<Table>();
     table->init(options);
-
+    std::cout << "before count" << std::endl;
+    size_t count = table->bucket_count();
+    std::cout << "count " << count << std::endl;
     ASSERT_EQ(table->bucket_count(), 524287);  // 1 + (INIT_CAPACITY / options.bucket_max_size)
 
+    std::cout << "before size" << std::endl;
     total_size = table->size(stream);
     CUDA_CHECK(cudaStreamSynchronize(stream));
+    std::cout << "size " << total_size << std::endl;
     ASSERT_EQ(total_size, 0);
 
+    std::cout << "before assign" << std::endl;
     table->insert_or_assign(KEY_NUM, d_keys, d_vectors, d_scores, stream);
     CUDA_CHECK(cudaStreamSynchronize(stream));
-
     total_size = table->size(stream);
     CUDA_CHECK(cudaStreamSynchronize(stream));
+    std::cout << "total_size " << total_size << " " << KEY_NUM << std::endl;
     ASSERT_EQ(total_size, KEY_NUM);
diff --git a/tests/reserved_bucket_test.cc.cu b/tests/reserved_bucket_test.cc.cu
new file mode 100644
index 00000000..8afe17a2
--- /dev/null
+++ b/tests/reserved_bucket_test.cc.cu
@@ -0,0 +1,400 @@
+/*
+* Copyright (c) 2024, NVIDIA CORPORATION.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+#include <gtest/gtest.h>
+//#include "merlin/core_kernels/reserved_bucket.cuh"
+#include "test_util.cuh"
+#include <cassert>
+#include <iostream>
+#include <memory>
+
+#include "merlin/allocator.cuh"
+#include "merlin/types.cuh"
+
+using namespace nv::merlin;
+
+typedef uint64_t K;
+typedef float V;
+
+#define RESERVED_BUCKET_SIZE 4
+#define RESERVED_BUCKET_MASK 3
+
+template <typename K, typename V>
+struct RdBucket;
+
+template <typename K, typename V>
+__global__ static void rb_size_kernel(RdBucket<K, V>* reserved_bucket,
+                                      size_t* size);
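+
+// NOTE: RdBucket mirrors ReservedBucket from
+// include/merlin/core_kernels/reserved_bucket.cuh (the include above is
+// commented out), so the two copies must be kept in sync: any fix applied
+// to ReservedBucket should be mirrored here.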
+template <typename K, typename V>
+struct RdBucket {
+  cuda::atomic<int, cuda::thread_scope_device> locks[RESERVED_BUCKET_SIZE];
+  bool keys[RESERVED_BUCKET_SIZE];
+
+  static void initialize(RdBucket<K, V>** reserved_bucket,
+                         BaseAllocator* allocator, size_t dim) {
+    size_t total_size = sizeof(RdBucket<K, V>);
+    total_size += sizeof(V) * RESERVED_BUCKET_SIZE * dim;
+    void* memory_block;
+    allocator->alloc(MemoryType::Device, &memory_block, total_size);
+    CUDA_CHECK(cudaMemset(memory_block, 0, total_size));
+    *reserved_bucket = static_cast<RdBucket<K, V>*>(memory_block);
+  }
+
+  __forceinline__ __device__ V* get_vector(K key, size_t dim) {
+    V* vector = reinterpret_cast<V*>(keys + RESERVED_BUCKET_SIZE);
+    size_t index = key & RESERVED_BUCKET_MASK;
+    return vector + index * dim;
+  }
+
+  __forceinline__ __device__ bool contains(K key) {
+    size_t index = key & RESERVED_BUCKET_MASK;
+    return keys[index];
+  }
+
+  __forceinline__ __device__ void set_key(K key, bool value = true) {
+    size_t index = key & RESERVED_BUCKET_MASK;
+    keys[index] = value;
+  }
+
+  // Since a reserved key should always exist in the table, insert_or_assign,
+  // insert_and_evict, and assign are all equivalent to write_vector here.
+  __forceinline__ __device__ void write_vector(K key, size_t dim,
+                                               const V* data) {
+    V* vectors = get_vector(key, dim);
+    set_key(key);
+    for (int i = 0; i < dim; i++) {
+      vectors[i] = data[i];
+    }
+  }
+
+  __forceinline__ __device__ void read_vector(K key, size_t dim,
+                                              V* out_data) {
+    V* vectors = get_vector(key, dim);
+    for (int i = 0; i < dim; i++) {
+      out_data[i] = vectors[i];
+    }
+  }
+
+  __forceinline__ __device__ void erase(K key, size_t dim) {
+    V* vectors = get_vector(key, dim);
+    set_key(key, false);
+    for (int i = 0; i < dim; i++) {
+      vectors[i] = 0;
+    }
+  }
+
+  // Search for the specified key and return the pointer of its value.
+  __forceinline__ __device__ bool find(K key, size_t dim, V** values) {
+    if (contains(key)) {
+      V* vectors = get_vector(key, dim);
+      *values = vectors;
+      return true;
+    } else {
+      return false;
+    }
+  }
+
+  // Search for the specified key; insert it first when missing.
+  __forceinline__ __device__ bool find_or_insert(K key, size_t dim,
+                                                 const V* values) {
+    if (contains(key)) {
+      return true;
+    } else {
+      write_vector(key, dim, values);
+      return false;
+    }
+  }
+
+  // Search for the specified key and return the pointer of its value.
+  // Insert it first when missing.
+  __forceinline__ __device__ bool find_or_insert(K key, size_t dim,
+                                                 V** values) {
+    if (contains(key)) {
+      V* vectors = get_vector(key, dim);
+      *values = vectors;
+      return true;
+    } else {
+      write_vector(key, dim, *values);
+      return false;
+    }
+  }
+
+  __forceinline__ __device__ void accum_or_assign(K key, bool is_accum,
+                                                  size_t dim,
+                                                  const V* values) {
+    if (is_accum) {
+      V* vectors = get_vector(key, dim);
+      for (int i = 0; i < dim; i++) {
+        vectors[i] += values[i];
+      }
+    } else {
+      write_vector(key, dim, values);
+    }
+    set_key(key);
+  }
+
+  /*
+   * @brief Exports the reserved bucket to key-value tuples.
+   * @param n The maximum number of exported pairs.
+   * @param offset The position of the key to search.
+   * @param out_keys The keys to dump into GPU-accessible memory with
+   * shape (n).
+   * @param out_values The values to dump into GPU-accessible memory with
+   * shape (n, DIM).
+   * @return The number of elements dumped.
+   */
+  __forceinline__ __device__ size_t export_batch(size_t n, size_t offset,
+                                                 K* out_keys, size_t dim,
+                                                 V* out_values,
+                                                 size_t batch_size) {
+    // `offset` indexes slots, not occupied entries.
+    if (offset >= RESERVED_BUCKET_SIZE) {
+      return 0;
+    }
+    size_t count = 0;
+    V* vectors = reinterpret_cast<V*>(keys + RESERVED_BUCKET_SIZE);
+    for (size_t i = offset; i < RESERVED_BUCKET_SIZE && count < n; i++) {
+      if (keys[i]) {
+        // Slots store only a presence flag, so the key is reconstructed from
+        // the reserved-key pattern: high bits all set, low bits = slot index.
+        out_keys[count] =
+            (~static_cast<K>(RESERVED_BUCKET_MASK)) | static_cast<K>(i);
+        const V* vector = vectors + i * dim;
+        for (size_t j = 0; j < dim; j++) {
+          out_values[count * dim + j] = vector[j];
+        }
+        count++;
+      }
+    }
+    return count;
+  }
+
+  /**
+   * @brief Returns the number of occupied slots in the reserved bucket.
+   */
+  __forceinline__ __device__ size_t size() {
+    size_t count = 0;
+    for (int i = 0; i < RESERVED_BUCKET_SIZE; i++) {
+      if (keys[i]) {
+        count++;
+      }
+    }
+    return count;
+  }
+
+  size_t size_host() {
+    size_t* d_size;
+    CUDA_CHECK(cudaMalloc(&d_size, sizeof(size_t)));
+    rb_size_kernel<<<1, 1>>>(this, d_size);
+    CUDA_CHECK(cudaDeviceSynchronize());
+    size_t h_size;
+    CUDA_CHECK(cudaMemcpy(&h_size, d_size, sizeof(size_t),
+                          cudaMemcpyDeviceToHost));
+    CUDA_CHECK(cudaFree(d_size));
+    return h_size;
+  }
+
+  /**
+   * @brief Removes all of the elements in the reserved bucket without
+   * releasing the underlying memory.
+   */
+  __forceinline__ __device__ void clear(size_t dim) {
+    // cudaMemset is a host API, so clear the flags and vectors with loops.
+    for (int i = 0; i < RESERVED_BUCKET_SIZE; i++) {
+      keys[i] = false;
+    }
+    V* vectors = reinterpret_cast<V*>(keys + RESERVED_BUCKET_SIZE);
+    for (size_t i = 0; i < RESERVED_BUCKET_SIZE * dim; i++) {
+      vectors[i] = 0;
+    }
+  }
+};
+
+template <typename K, typename V>
+__global__ static void rb_size_kernel(RdBucket<K, V>* reserved_bucket,
+                                      size_t* size) {
+  *size = reserved_bucket->size();
+}
+
+template <typename K, typename V>
+__global__ void rb_write_vector_kernel(RdBucket<K, V>* reserved_bucket,
+                                       K key, size_t dim, const V* data) {
+  reserved_bucket->write_vector(key, dim, data);
+}
+
+template <typename K, typename V>
+__global__ void rb_read_vector_kernel(RdBucket<K, V>* reserved_bucket, K key,
+                                      size_t dim, V* out_data) {
+  reserved_bucket->read_vector(key, dim, out_data);
+}
+
+template <typename K, typename V>
+__global__ void rb_erase_kernel(RdBucket<K, V>* reserved_bucket, K key,
+                                size_t dim) {
+  reserved_bucket->erase(key, dim);
+}
+
+template <typename K, typename V>
+__global__ void rb_clear_kernel(RdBucket<K, V>* reserved_bucket, size_t dim) {
+  reserved_bucket->clear(dim);
+}
+
+template <typename K, typename V>
+__global__ void rb_find_or_insert_kernel(RdBucket<K, V>* reserved_bucket,
+                                         K key, size_t dim, const V* data,
+                                         bool* is_found) {
+  *is_found = reserved_bucket->find_or_insert(key, dim, data);
+}
+
+template <typename K, typename V>
+__global__ void rb_find_or_insert_kernel(RdBucket<K, V>* reserved_bucket,
+                                         K key, size_t dim, bool* is_found,
+                                         V** values) {
+  *is_found = reserved_bucket->find_or_insert(key, dim, values);
+}
+
+template <typename K, typename V>
+__global__ void rb_accum_or_assign_kernel(RdBucket<K, V>* reserved_bucket,
+                                          K key, bool is_accum, size_t dim,
+                                          const V* data) {
+  reserved_bucket->accum_or_assign(key, is_accum, dim, data);
+}
+
+template <typename K, typename V>
+__global__ void rb_find_kernel(RdBucket<K, V>* reserved_bucket, K key,
+                               size_t dim, bool* found, V** values) {
+  *found = reserved_bucket->find(key, dim, values);
+}
+
+template <typename K, typename V>
+__global__ void rb_export_batch_kernel(RdBucket<K, V>* reserved_bucket,
+                                       size_t n, size_t offset, K* keys,
+                                       size_t dim, V* values,
+                                       size_t batch_size) {
+  reserved_bucket->export_batch(n, offset, keys, dim, values, batch_size);
+}
+
+void print_vector(const float* vector, size_t dim) {
+  std::cout << "Vector contents: [";
+  for (size_t i = 0; i < dim; i++) {
+    std::cout << vector[i] << (i < dim - 1 ? ", " : "");
+  }
+  std::cout << "]" << std::endl;
+}
+
+#define ASSERT_EQUAL(x, y, index)                                      \
+  do {                                                                 \
+    if ((x) != (y)) {                                                  \
+      std::cerr << "Assertion failed: (" << #x << " == " << #y         \
+                << "), in file " << __FILE__ << ", line " << __LINE__  \
+                << ", index " << (index) << ".\n"                      \
+                << "Values: " << (x) << " != " << (y) << std::endl;    \
+      std::abort();                                                    \
+    }                                                                  \
+  } while (false)
+
+bool find_key(RdBucket<K, V>* bucket, K key, size_t dim, V* values) {
+  bool* d_is_found;
+  CUDA_CHECK(cudaMalloc(&d_is_found, sizeof(bool)));
+  V** d_values;
+  CUDA_CHECK(cudaMalloc(&d_values, sizeof(V*)));
+  rb_find_kernel<<<1, 1>>>(bucket, key, dim, d_is_found, d_values);
+  CUDA_CHECK(cudaDeviceSynchronize());
+  bool h_is_found;
+  CUDA_CHECK(cudaMemcpy(&h_is_found, d_is_found, sizeof(bool),
+                        cudaMemcpyDeviceToHost));
+  std::cout << "found " << h_is_found << std::endl;
+  if (h_is_found) {
+    // The kernel returns a device pointer to the in-bucket vector: copy the
+    // pointer back first, then the vector contents.
+    V* d_vector;
+    CUDA_CHECK(cudaMemcpy(&d_vector, d_values, sizeof(V*),
+                          cudaMemcpyDeviceToHost));
+    CUDA_CHECK(cudaMemcpy(values, d_vector, dim * sizeof(V),
+                          cudaMemcpyDeviceToHost));
+  }
+  CUDA_CHECK(cudaFree(d_is_found));
+  CUDA_CHECK(cudaFree(d_values));
+  return h_is_found;
+}
+
+void test_reserved_bucket_gpu(K key) {
+  std::shared_ptr<DefaultAllocator> default_allocator(new DefaultAllocator());
+  BaseAllocator* allocator = default_allocator.get();
+  int num_devices;
+  CUDA_CHECK(cudaGetDeviceCount(&num_devices));
+  MERLIN_CHECK(num_devices > 0,
+               "Need at least one CUDA capable device for running this test.");
+  std::cout << "enter " << key % RESERVED_BUCKET_SIZE << std::endl;
+
+  RdBucket<K, V>* bucket;
+  size_t dim = 10;
+  RdBucket<K, V>::initialize(&bucket, allocator, dim);
+
+  V* test_vector;
+  V* out_vector;
+  CUDA_CHECK(cudaMalloc(&test_vector, dim * sizeof(V)));
+  CUDA_CHECK(cudaMalloc(&out_vector, dim * sizeof(V)));
+
+  V host_vector[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+  CUDA_CHECK(cudaMemcpy(test_vector, host_vector, dim * sizeof(V),
+                        cudaMemcpyHostToDevice));
+
+  CudaCheckError();
+
+  rb_write_vector_kernel<<<1, 1>>>(bucket, key, dim, test_vector);
+  CUDA_CHECK(cudaDeviceSynchronize());
+  assert(bucket->size_host() == 1);
+
+  V host_out_vector[10];
+  bool found = find_key(bucket, key, dim, host_out_vector);
+  assert(found);
+  array_eq(host_vector, host_out_vector, dim);
+
+  rb_read_vector_kernel<<<1, 1>>>(bucket, key, dim, out_vector);
+  CUDA_CHECK(cudaDeviceSynchronize());
+  CUDA_CHECK(cudaMemcpy(host_out_vector, out_vector, dim * sizeof(V),
+                        cudaMemcpyDeviceToHost));
+  print_vector(host_out_vector, dim);
+  array_eq(host_vector, host_out_vector, dim);
+
+  rb_erase_kernel<<<1, 1>>>(bucket, key, dim);
+  CUDA_CHECK(cudaDeviceSynchronize());
+  rb_read_vector_kernel<<<1, 1>>>(bucket, key, dim, out_vector);
+  CUDA_CHECK(cudaDeviceSynchronize());
+  CUDA_CHECK(cudaMemcpy(host_out_vector, out_vector, dim * sizeof(V),
+                        cudaMemcpyDeviceToHost));
+  for (int i = 0; i < dim; i++) {
+    ASSERT_EQUAL(host_out_vector[i], 0, i);
+  }
+  assert(bucket->size_host() == 0);
+  assert(!find_key(bucket, key, dim, host_out_vector));
+
+  rb_accum_or_assign_kernel<<<1, 1>>>(bucket, key, false, dim, test_vector);
+  CUDA_CHECK(cudaDeviceSynchronize());
+  assert(bucket->size_host() == 1);
+  assert(find_key(bucket, key, dim, host_out_vector));
+  array_eq(host_vector, host_out_vector, dim);
+
+  rb_accum_or_assign_kernel<<<1, 1>>>(bucket, key, true, dim, test_vector);
+  CUDA_CHECK(cudaDeviceSynchronize());
+  assert(bucket->size_host() == 1);
+  assert(find_key(bucket, key, dim, host_out_vector));
+
+  V host_vector2[10] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18};
+  array_eq(host_vector2, host_out_vector, dim);
+
+  std::cout << "All GPU tests passed!" << std::endl;
+
+  CUDA_CHECK(cudaFree(test_vector));
+  CUDA_CHECK(cudaFree(out_vector));
+  CudaCheckError();
+}
+
+TEST(RdBucketTest, test_reserved_bucket_gpu) {
+  test_reserved_bucket_gpu(EMPTY_KEY);
+  test_reserved_bucket_gpu(RECLAIM_KEY);
+  test_reserved_bucket_gpu(LOCKED_KEY);
+  test_reserved_bucket_gpu(RESERVED_KEY_MASK);
+}
\ No newline at end of file
diff --git a/tests/test_util.cuh b/tests/test_util.cuh
index bdd75a35..2633a9df 100644
--- a/tests/test_util.cuh
+++ b/tests/test_util.cuh
@@ -27,6 +27,7 @@
 #include
 #include
 #include
+#include "merlin/types.cuh"
 #include "merlin/utils.cuh"
 #include "merlin_hashtable.cuh"
@@ -42,6 +43,8 @@
     exit(-1);                                                 \
   }
 
+using namespace nv::merlin;
+
 namespace test_util {
@@ -128,6 +131,10 @@ void create_random_keys(K* h_keys, S* h_scores, int KEY_NUM,
   std::mt19937_64 eng(rd());
   std::uniform_int_distribution<K> distr;
   int i = 0;
+  numbers.insert(EMPTY_KEY);
+  numbers.insert(RECLAIM_KEY);
+  numbers.insert(LOCKED_KEY);
+  numbers.insert(RESERVED_KEY_MASK);
 
   while (numbers.size() < KEY_NUM) {
     numbers.insert(distr(eng));
@@ -148,6 +155,11 @@ void create_random_keys(K* h_keys, S* h_scores, V* h_vectors, int KEY_NUM,
   std::uniform_int_distribution<K> distr;
   int i = 0;
 
+  numbers.insert(EMPTY_KEY);
+  numbers.insert(RECLAIM_KEY);
+  numbers.insert(LOCKED_KEY);
+  numbers.insert(RESERVED_KEY_MASK);
+
   while (numbers.size() < KEY_NUM) {
     numbers.insert(distr(eng) % range);
   }