enable reserve key #190

Open · wants to merge 1 commit into base: master

Changes from all commits
3 changes: 2 additions & 1 deletion .gitignore
@@ -2,7 +2,8 @@
.idea
.vscode
build

.clwb
cmake-build-debug/
docs/build
docs/source/README.md
docs/source/CONTRIBUTING.md
10 changes: 10 additions & 0 deletions CMakeLists.txt
@@ -159,3 +159,13 @@ add_executable(find_with_missed_keys_test tests/find_with_missed_keys_test.cc.cu
target_compile_features(find_with_missed_keys_test PUBLIC cxx_std_14)
set_target_properties(find_with_missed_keys_test PROPERTIES CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(find_with_missed_keys_test gtest_main)

add_executable(reserved_bucket_test tests/reserved_bucket_test.cc.cu)
target_compile_features(reserved_bucket_test PUBLIC cxx_std_14)
set_target_properties(reserved_bucket_test PROPERTIES CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(reserved_bucket_test gtest_main)

add_executable(key_option_test tests/key_option_test.cc.cu)
target_compile_features(key_option_test PUBLIC cxx_std_14)
set_target_properties(key_option_test PROPERTIES CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(key_option_test gtest_main)
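Both new targets link `gtest_main`, so each test file only has to define its cases. As a rough illustration — this is a sketch, not the contents of `tests/reserved_bucket_test.cc.cu` from this PR, and `DefaultAllocator` living in `merlin/allocator.cuh` is an assumption — a reserved-bucket round trip could look like:

```cpp
#include <gtest/gtest.h>

#include <cstdint>

#include "merlin/allocator.cuh"
#include "merlin/core_kernels/reserved_bucket.cuh"

using namespace nv::merlin;
using K = uint64_t;
using V = float;

TEST(ReservedBucketTest, WriteThenReadRoundTrip) {
  const size_t dim = 8;
  const K key = ~static_cast<K>(0);  // a reserved key; its low bits pick the slot

  DefaultAllocator allocator;  // assumed allocator type from merlin/allocator.cuh
  ReservedBucket<K, V>* bucket = nullptr;
  ReservedBucket<K, V>::initialize(&bucket, &allocator, dim);

  // Stage an input vector on the device, write it, then read it back.
  V h_in[dim], h_out[dim];
  for (size_t i = 0; i < dim; i++) h_in[i] = static_cast<V>(i);
  V *d_in = nullptr, *d_out = nullptr;
  CUDA_CHECK(cudaMalloc(&d_in, dim * sizeof(V)));
  CUDA_CHECK(cudaMalloc(&d_out, dim * sizeof(V)));
  CUDA_CHECK(cudaMemcpy(d_in, h_in, dim * sizeof(V), cudaMemcpyHostToDevice));

  rb_write_vector_kernel<<<1, 1>>>(bucket, key, dim, d_in);
  rb_read_vector_kernel<<<1, 1>>>(bucket, key, dim, d_out);
  CUDA_CHECK(cudaDeviceSynchronize());

  CUDA_CHECK(cudaMemcpy(h_out, d_out, dim * sizeof(V), cudaMemcpyDeviceToHost));
  for (size_t i = 0; i < dim; i++) EXPECT_EQ(h_in[i], h_out[i]);
  EXPECT_EQ(bucket->size_host(), 1u);

  CUDA_CHECK(cudaFree(d_in));
  CUDA_CHECK(cudaFree(d_out));
}
```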
5 changes: 5 additions & 0 deletions README.md
@@ -172,6 +172,11 @@ cd HierarchicalKV && mkdir -p build && cd build
cmake -DCMAKE_BUILD_TYPE=Release -Dsm=80 .. && make -j
```

For Debug:
```shell
cmake -DCMAKE_BUILD_TYPE=Debug -Dsm=80 .. && make -j
```

For Benchmark:
```shell
./merlin_hashtable_benchmark
21 changes: 16 additions & 5 deletions include/merlin/core_kernels.cuh
@@ -24,6 +24,7 @@
#include "core_kernels/kernel_utils.cuh"
#include "core_kernels/lookup.cuh"
#include "core_kernels/lookup_ptr.cuh"
#include "core_kernels/reserved_bucket.cuh"
#include "core_kernels/update.cuh"
#include "core_kernels/update_score.cuh"
#include "core_kernels/update_values.cuh"
@@ -341,6 +342,10 @@ void create_table(Table<K, V, S>** table, BaseAllocator* allocator,
(*table)->buckets_num * sizeof(Bucket<K, V, S>)));

initialize_buckets<K, V, S>(table, allocator, 0, (*table)->buckets_num);

ReservedBucket<K, V>::initialize(
&(*table)->reserved_bucket, allocator, (*table)->dim);

CudaCheckError();
}

@@ -632,15 +637,19 @@ __global__ void remove_kernel(const Table<K, V, S>* __restrict table,
Bucket<K, V, S>* __restrict buckets,
int* __restrict buckets_size,
const size_t bucket_max_size,
-                              const size_t buckets_num, size_t N) {
+                              const size_t buckets_num,
+                              size_t N) {
auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());
int rank = g.thread_rank();

for (size_t t = (blockIdx.x * blockDim.x) + threadIdx.x; t < N;
t += blockDim.x * gridDim.x) {
int key_idx = t / TILE_SIZE;
K find_key = keys[key_idx];
-    if (IS_RESERVED_KEY(find_key)) continue;
+    if (IS_RESERVED_KEY(find_key)) {
+      table->reserved_bucket->erase(find_key, table->dim);
+      continue;
+    }

int key_pos = -1;

@@ -700,7 +709,8 @@ __global__ void remove_kernel(const Table<K, V, S>* __restrict table,
Bucket<K, V, S>* __restrict buckets,
int* __restrict buckets_size,
const size_t bucket_max_size,
-                              const size_t buckets_num, size_t N) {
+                              const size_t buckets_num,
+                              size_t N) {
auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());
PredFunctor<K, S> pred;

@@ -719,8 +729,8 @@ __global__ void remove_kernel(const Table<K, V, S>* __restrict table,
bucket->keys(key_offset)->load(cuda::std::memory_order_relaxed);
current_score =
bucket->scores(key_offset)->load(cuda::std::memory_order_relaxed);
-      if (!IS_RESERVED_KEY(current_key)) {
-        if (pred(current_key, current_score, pattern, threshold)) {
+      if (pred(current_key, current_score, pattern, threshold)) {
+        if (!IS_RESERVED_KEY(current_key)) {
atomicAdd(count, 1);
key_pos = key_offset;
bucket->digests(key_pos)[0] = empty_digest<K>();
@@ -732,6 +742,7 @@ __global__ void remove_kernel(const Table<K, V, S>* __restrict table,
cuda::std::memory_order_relaxed);
atomicSub(&buckets_size[bkt_idx], 1);
} else {
+          table->reserved_bucket->erase(current_key, table->dim);
key_offset++;
}
} else {
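Both `remove_kernel` hunks follow the same pattern: where reserved keys used to be skipped, they are now routed to the reserved bucket's `erase`. The routing relies on the low bits of a reserved key selecting a slot. A minimal sketch of that mapping — `kReservedKeyMask` mirrors the library's reserved-key test and is written out here only for illustration:

```cpp
#include <cstddef>
#include <cstdint>

// Reserved keys sit at the very top of the 64-bit key space, so this check
// matches exactly the four values whose upper 62 bits are all ones.
constexpr uint64_t kReservedKeyMask = ~uint64_t{3};  // 0xFFFFFFFFFFFFFFFC

constexpr bool is_reserved(uint64_t key) {
  return (key & kReservedKeyMask) == kReservedKeyMask;
}

// The low two bits pick one of the RESERVED_BUCKET_SIZE (= 4) slots.
constexpr size_t reserved_slot(uint64_t key) { return key & uint64_t{3}; }

static_assert(is_reserved(~uint64_t{0}), "EMPTY-style key is reserved");
static_assert(reserved_slot(~uint64_t{0}) == 3, "and maps to slot 3");
```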
4 changes: 3 additions & 1 deletion include/merlin/core_kernels/accum_or_assign.cuh
@@ -98,7 +98,9 @@ __global__ void accum_or_assign_kernel_with_io(

const K insert_key = keys[key_idx];

-    if (IS_RESERVED_KEY(insert_key)) continue;
+    if (IS_RESERVED_KEY(insert_key)) {
+      continue;
+    }

const S insert_score =
ScoreFunctor::desired_when_missed(scores, key_idx, global_epoch);
267 changes: 267 additions & 0 deletions include/merlin/core_kernels/reserved_bucket.cuh
@@ -0,0 +1,267 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*    http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include "../allocator.cuh"
#include "../types.cuh"

namespace nv {
namespace merlin {

#define RESERVED_BUCKET_SIZE 4
#define RESERVED_BUCKET_MASK 3

template <class K, class V> struct ReservedBucket;

template <class K, class V>
__global__ static void rb_size_kernel(ReservedBucket<K, V>* reserved_bucket, size_t* size);

template <class K, class V>
struct ReservedBucket {
  // Per-slot locks; declared in this revision but not yet used below.
  cuda::atomic<bool, cuda::thread_scope_device> locks[RESERVED_BUCKET_SIZE];
  // Presence flags: keys[i] is true while the reserved key mapping to slot i
  // exists in the bucket.
  bool keys[RESERVED_BUCKET_SIZE];
  // One device allocation holds the struct followed by all reserved vectors.
  static void initialize(ReservedBucket<K, V>** reserved_bucket,
                         BaseAllocator* allocator, size_t dim) {
    size_t total_size = sizeof(ReservedBucket<K, V>);
    total_size += sizeof(V) * RESERVED_BUCKET_SIZE * dim;
    void* memory_block;
    allocator->alloc(MemoryType::Device, &memory_block, total_size);
    CUDA_CHECK(cudaMemset(memory_block, 0, total_size));
    *reserved_bucket = static_cast<ReservedBucket<K, V>*>(memory_block);
  }

  // Vector storage begins at the first byte past `keys`, i.e. at the end of
  // the struct; the low bits of a reserved key select its slot.
  __forceinline__ __device__ V* get_vector(K key, size_t dim) {
    V* vectors = reinterpret_cast<V*>(keys + RESERVED_BUCKET_SIZE);
    size_t index = key & RESERVED_BUCKET_MASK;
    return vectors + index * dim;
  }

__forceinline__ __device__ bool contains(K key) {
size_t index = key & RESERVED_BUCKET_MASK;
return keys[index];
}

__forceinline__ __device__ void set_key(K key, bool value = true) {
size_t index = key & RESERVED_BUCKET_MASK;
keys[index] = value;
}

  // Reserved keys should always exist in the table, so insert_or_assign,
  // insert_and_evict, and assign all reduce to write_vector here.
  __forceinline__ __device__ void write_vector(
      K key, size_t dim, const V* data) {
    V* vectors = get_vector(key, dim);
    set_key(key);
    for (size_t i = 0; i < dim; i++) {
      vectors[i] = data[i];
    }
  }

  __forceinline__ __device__ void read_vector(
      K key, size_t dim, V* out_data) {
    V* vectors = get_vector(key, dim);
    for (size_t i = 0; i < dim; i++) {
      out_data[i] = vectors[i];
    }
  }

__forceinline__ __device__ void erase(K key, size_t dim) {
V* vectors = get_vector(key, dim);
set_key(key, false);
for (int i = 0; i < dim; i++) {
vectors[i] = 0;
}
}

// Search for the specified keys and return the pointers of values.
__forceinline__ __device__ bool find(K key, size_t dim, V** values) {
if (contains(key)) {
V* vectors = get_vector(key, dim);
*values = vectors;
return true;
} else {
return false;
}
}

  // Search for the specified key; insert the given values first when it is
  // missing. Returns true if the key was already present.
  __forceinline__ __device__ bool find_or_insert(K key, size_t dim,
                                                 const V* values) {
    if (contains(key)) {
      return true;
    }
    write_vector(key, dim, values);  // write_vector also marks the key
    return false;
  }

  // Search for the specified key and return the pointer to its values,
  // inserting the caller's values first when the key is missing.
  __forceinline__ __device__ bool find_or_insert(
      K key, size_t dim, V** values) {
    if (contains(key)) {
      *values = get_vector(key, dim);
      return true;
    }
    write_vector(key, dim, *values);  // write_vector also marks the key
    *values = get_vector(key, dim);   // hand back the in-bucket storage
    return false;
  }

__forceinline__ __device__ void accum_or_assign(
K key, bool is_accum, size_t dim, const V* values) {
if (is_accum) {
V* vectors = get_vector(key, dim);
for (int i = 0; i < dim; i++) {
vectors[i] += values[i];
}
} else {
write_vector(key, dim, values);
}
set_key(key);
}

  /*
   * @brief Exports the reserved bucket as key-value tuples.
   * @param n The maximum number of exported pairs.
   * @param offset The slot position to start the search at.
   * @param out_keys The keys to dump into GPU-accessible memory with shape
   * (n).
   * @param dim The vector dimension.
   * @param out_values The values to dump into GPU-accessible memory with
   * shape (n, DIM).
   * @return The number of elements dumped.
   */
  __forceinline__ __device__ size_t export_batch(
      size_t n, size_t offset,
      K* out_keys, size_t dim, V* out_values) {
    size_t count = 0;
    V* base = reinterpret_cast<V*>(keys + RESERVED_BUCKET_SIZE);
    for (size_t i = offset; i < RESERVED_BUCKET_SIZE && count < n; i++) {
      if (!keys[i]) continue;  // skip empty slots
      // Reserved keys occupy the top of the key space and their low bits
      // select the slot, so the key can be rebuilt from the slot index.
      out_keys[count] = ~static_cast<K>(RESERVED_BUCKET_MASK) | static_cast<K>(i);
      V* vector = base + i * dim;
      for (size_t j = 0; j < dim; j++) {
        out_values[count * dim + j] = vector[j];
      }
      count++;
    }
    return count;
  }

/**
* @brief Returns the reserved bucket size.
*/
__forceinline__ __device__ size_t size() {
size_t count = 0;
for (int i = 0; i < RESERVED_BUCKET_SIZE; i++) {
if (keys[i]) {
count++;
}
}
return count;
}

  // Host-side convenience: launches rb_size_kernel and copies the count back.
  size_t size_host() {
    size_t* d_size;
    CUDA_CHECK(cudaMalloc(&d_size, sizeof(size_t)));
    rb_size_kernel<<<1, 1>>>(this, d_size);
    CUDA_CHECK(cudaDeviceSynchronize());
    size_t h_size = 0;
    CUDA_CHECK(cudaMemcpy(&h_size, d_size, sizeof(size_t),
                          cudaMemcpyDeviceToHost));
    CUDA_CHECK(cudaFree(d_size));
    return h_size;
  }
  /**
   * @brief Removes all of the elements in the reserved bucket without
   * releasing the underlying storage.
   */
  __forceinline__ __device__ void clear(size_t dim) {
    // cudaMemset is a host API, so zero the flags and vectors directly.
    for (int i = 0; i < RESERVED_BUCKET_SIZE; i++) {
      keys[i] = false;
    }
    V* vectors = reinterpret_cast<V*>(keys + RESERVED_BUCKET_SIZE);
    for (size_t i = 0; i < RESERVED_BUCKET_SIZE * dim; i++) {
      vectors[i] = 0;
    }
  }
};

template <class K, class V>
__global__ static void rb_size_kernel(ReservedBucket<K, V>* reserved_bucket, size_t* size) {
*size = reserved_bucket->size();
}

template <class K, class V>
__global__ void rb_write_vector_kernel(ReservedBucket<K, V>* reserved_bucket,
K key, size_t dim, const V* data) {
reserved_bucket->write_vector(key, dim, data);
}

template <class K, class V>
__global__ void rb_read_vector_kernel(ReservedBucket<K, V>* reserved_bucket,
K key, size_t dim, V* out_data) {
reserved_bucket->read_vector(key, dim, out_data);
}

template <class K, class V>
__global__ void rb_erase_kernel(ReservedBucket<K, V>* reserved_bucket,
K key, size_t dim) {
reserved_bucket->erase(key, dim);
}

template <class K, class V>
__global__ void rb_clear_kernel(ReservedBucket<K, V>* reserved_bucket, size_t dim) {
reserved_bucket->clear(dim);
}

template <class K, class V>
__global__ void rb_find_or_insert_kernel(
ReservedBucket<K, V>* reserved_bucket,
K key, size_t dim, const V* data, bool* is_found) {
*is_found = reserved_bucket->find_or_insert(key, dim, data);
}

template <class K, class V>
__global__ void rb_find_or_insert_kernel(
    ReservedBucket<K, V>* reserved_bucket, K key, size_t dim,
    bool* is_found, V** values) {
  *is_found = reserved_bucket->find_or_insert(key, dim, values);
}

template <class K, class V>
__global__ void rb_accum_or_assign_kernel(
    ReservedBucket<K, V>* reserved_bucket,
    K key, bool is_accum,
    size_t dim, const V* data) {
  reserved_bucket->accum_or_assign(key, is_accum, dim, data);
}

template <class K, class V>
__global__ void rb_find_kernel(
    ReservedBucket<K, V>* reserved_bucket, K key, size_t dim,
    bool* found, V** values) {
  *found = reserved_bucket->find(key, dim, values);
}

template <class K, class V>
__global__ void rb_export_batch_kernel(
    ReservedBucket<K, V>* reserved_bucket, size_t n, size_t offset,
    K* keys, size_t dim, V* values) {
  reserved_bucket->export_batch(n, offset, keys, dim, values);
}


} // namespace merlin
} // namespace nv
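Since all the `rb_*` helpers are single-thread kernels, host code drives them with `<<<1, 1>>>` launches and explicit copies. A small hypothetical wrapper — not part of this PR, names are illustrative — showing the pattern for a membership probe:

```cpp
#include "merlin/core_kernels/reserved_bucket.cuh"

// Hypothetical host-side helper: probes one reserved key via rb_find_kernel
// and copies the result flag back. Illustrative only.
template <class K, class V>
bool rb_contains_on_host(nv::merlin::ReservedBucket<K, V>* bucket, K key,
                         size_t dim) {
  bool* d_found = nullptr;
  V** d_values = nullptr;  // device slot for the returned vector pointer
  CUDA_CHECK(cudaMalloc(&d_found, sizeof(bool)));
  CUDA_CHECK(cudaMalloc(&d_values, sizeof(V*)));

  nv::merlin::rb_find_kernel<<<1, 1>>>(bucket, key, dim, d_found, d_values);
  CUDA_CHECK(cudaDeviceSynchronize());

  bool h_found = false;
  CUDA_CHECK(cudaMemcpy(&h_found, d_found, sizeof(bool),
                        cudaMemcpyDeviceToHost));
  CUDA_CHECK(cudaFree(d_found));
  CUDA_CHECK(cudaFree(d_values));
  return h_found;
}
```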