
[Feat] remove all of the cudaMallocManaged
- BTW, remove the logic of the constant memory Table
rhdong committed Jul 27, 2023
1 parent c94e641 commit f7d36f7
Showing 10 changed files with 233 additions and 217 deletions.
134 changes: 73 additions & 61 deletions include/merlin/core_kernels.cuh
@@ -16,6 +16,8 @@

#pragma once

#include <cstdlib>
#include <cstring>
#include "merlin/core_kernels/find_or_insert.cuh"
#include "merlin/core_kernels/find_ptr_or_insert.cuh"
#include "merlin/core_kernels/kernel_utils.cuh"
@@ -28,37 +30,6 @@
namespace nv {
namespace merlin {

/* To improve performance, up to 64 table structures can be allocated in
 * constant memory. To disable this feature, set
 * `TableOption::use_constant_memory` to `false`.
 */
constexpr int MAX_CONSTANT_TABLE = 64;
static std::mutex constant_table_mutex;
static uint64_t constant_table_flag = 0;

__constant__ char
c_table_[sizeof(Table<uint64_t, float, uint64_t>) * MAX_CONSTANT_TABLE];

template <class T = uint64_t>
int allocate_constant_table() {
std::lock_guard<std::mutex> guard(constant_table_mutex);
if (constant_table_flag == std::numeric_limits<uint64_t>::max()) return -1;
int table_index = 0;
while (constant_table_flag & (1l << table_index)) {
table_index++;
}

constant_table_flag = constant_table_flag | (1l << table_index);

return table_index;
}

template <class T = uint64_t>
void release_constant_table(int table_index) {
std::lock_guard<std::mutex> guard(constant_table_mutex);
if (table_index < 0 || table_index >= MAX_CONSTANT_TABLE) return;
constant_table_flag = constant_table_flag & (~(1l << table_index));
}
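
For context, the removed allocator above was only half of the constant-memory mechanism: a caller elsewhere copied its `Table` instance into the reserved slot of `c_table_` before launching kernels. A minimal sketch of how such a slot would have been consumed (the copy site is outside this diff, so `h_table` is an assumed name):

    using TableT = Table<uint64_t, float, uint64_t>;
    int idx = allocate_constant_table<uint64_t>();
    if (idx >= 0) {
      // Copy the host-side descriptor into its constant-memory slot.
      CUDA_CHECK(cudaMemcpyToSymbol(c_table_, h_table, sizeof(TableT),
                                    idx * sizeof(TableT)));
      // ... launch kernels that read the descriptor from c_table_ ...
      release_constant_table<uint64_t>(idx);
    }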

template <class S>
__global__ void create_locks(S* __restrict mutex, const size_t start,
@@ -108,6 +79,30 @@ __global__ void create_atomic_scores(Bucket<K, V, S>* __restrict buckets,
}
}

template <class K, class V, class S>
__global__ void allocate_bucket_vectors(Bucket<K, V, S>* __restrict buckets,
const size_t index, V* address) {
buckets[index].vectors = address;
}

template <class K, class V, class S>
__global__ void allocate_bucket_others(Bucket<K, V, S>* __restrict buckets,
const int index, uint8_t* address,
const uint32_t reserve_size,
const size_t bucket_max_size) {
buckets[index].digests_ = address;
buckets[index].keys_ =
reinterpret_cast<AtomicKey<K>*>(buckets[index].digests_ + reserve_size);
buckets[index].scores_ =
reinterpret_cast<AtomicScore<S>*>(buckets[index].keys_ + bucket_max_size);
}

template <class K, class V, class S>
__global__ void get_bucket_others_address(Bucket<K, V, S>* __restrict buckets,
const int index, uint8_t** address) {
*address = buckets[index].digests_;
}

/* Initialize the buckets with index from start to end. */
template <class K, class V, class S>
void initialize_buckets(Table<K, V, S>** table, const size_t start,
@@ -130,7 +125,7 @@ void initialize_buckets(Table<K, V, S>** table, const size_t start,
((*table)->bucket_max_size * sizeof(V) * (*table)->dim);
size_t num_of_allocated_buckets = 0;

realloc_managed<V**>(
realloc_host<V**>(
&((*table)->slices), (*table)->num_of_memory_slices * sizeof(V*),
((*table)->num_of_memory_slices + num_of_memory_slices) * sizeof(V*));

@@ -152,16 +147,23 @@
}
for (int j = 0; j < num_of_buckets_in_one_slice; j++) {
if ((*table)->is_pure_hbm) {
(*table)->buckets[start + num_of_allocated_buckets + j].vectors =
size_t index = start + num_of_allocated_buckets + j;
V* address =
(*table)->slices[i] + j * (*table)->bucket_max_size * (*table)->dim;
allocate_bucket_vectors<K, V, S>
<<<1, 1>>>((*table)->buckets, index, address);
CUDA_CHECK(cudaDeviceSynchronize());
} else {
V* h_ptr =
(*table)->slices[i] + j * (*table)->bucket_max_size * (*table)->dim;
CUDA_CHECK(cudaHostGetDevicePointer(
&((*table)->buckets[start + num_of_allocated_buckets + j].vectors),
h_ptr, 0));
V* address = nullptr;
CUDA_CHECK(cudaHostGetDevicePointer(&address, h_ptr, 0));
size_t index = start + num_of_allocated_buckets + j;
allocate_bucket_vectors<K, V, S>
<<<1, 1>>>((*table)->buckets, index, address);
}
}
CUDA_CHECK(cudaDeviceSynchronize());
num_of_allocated_buckets += num_of_buckets_in_one_slice;
}
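
On the non-HBM branch above, `slices[i]` points at page-locked host memory, and `cudaHostGetDevicePointer` retrieves the device-side alias through which kernels address it. A minimal sketch of that mapping (the actual slice allocation happens earlier in this function; `slice_bytes` is an assumed name):

    float* h_slice = nullptr;
    // Mapped, page-locked host allocation that the GPU can address directly.
    CUDA_CHECK(cudaHostAlloc(&h_slice, slice_bytes, cudaHostAllocMapped));
    float* d_alias = nullptr;
    CUDA_CHECK(cudaHostGetDevicePointer(&d_alias, h_slice, 0));  // flags must be 0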

@@ -175,13 +177,12 @@ void initialize_buckets(Table<K, V, S>** table, const size_t start,
bucket_max_size < CACHE_LINE_SIZE ? CACHE_LINE_SIZE : bucket_max_size;
bucket_memory_size += reserve_size * sizeof(uint8_t);
for (int i = start; i < end; i++) {
CUDA_CHECK(
cudaMalloc(&((*table)->buckets[i].digests_), bucket_memory_size));
(*table)->buckets[i].keys_ = reinterpret_cast<AtomicKey<K>*>(
(*table)->buckets[i].digests_ + reserve_size);
(*table)->buckets[i].scores_ = reinterpret_cast<AtomicScore<S>*>(
(*table)->buckets[i].keys_ + bucket_max_size);
uint8_t* address = nullptr;
CUDA_CHECK(cudaMalloc(&address, bucket_memory_size));
allocate_bucket_others<K, V, S><<<1, 1>>>((*table)->buckets, i, address,
reserve_size, bucket_max_size);
}
CUDA_CHECK(cudaDeviceSynchronize());
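
Each bucket's metadata lives in a single `cudaMalloc` block, and `allocate_bucket_others` records the three sub-regions on the device. The layout implied by the pointer arithmetic is:

    // offset 0                          : digests_  (reserve_size bytes)
    // offset reserve_size               : keys_     (bucket_max_size * sizeof(AtomicKey<K>))
    // offset reserve_size + keys region : scores_   (bucket_max_size * sizeof(AtomicScore<S>))
    //
    // e.g. with bucket_max_size = 128, reserve_size = 128, and 8-byte atomic
    // keys and scores (sizes assumed, not taken from this diff):
    //   128 * 1 + 128 * 8 + 128 * 8 = 2176 bytes per bucket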

{
const size_t block_size = 512;
@@ -205,6 +206,7 @@ void initialize_buckets(Table<K, V, S>** table, const size_t start,
create_atomic_scores<K, V, S><<<grid_size, block_size>>>(
(*table)->buckets, start, end, (*table)->bucket_max_size);
}
CUDA_CHECK(cudaDeviceSynchronize());
CudaCheckError();
}

@@ -248,8 +250,9 @@ void create_table(Table<K, V, S>** table, const size_t dim,
const size_t max_hbm_for_vectors = 0,
const size_t bucket_max_size = 128,
const size_t tile_size = 32, const bool primary = true) {
CUDA_CHECK(cudaMallocManaged((void**)table, sizeof(Table<K, V, S>)));
CUDA_CHECK(cudaMemset(*table, 0, sizeof(Table<K, V, S>)));
(*table) =
reinterpret_cast<Table<K, V, S>*>(std::malloc(sizeof(Table<K, V, S>)));
std::memset(*table, 0, sizeof(Table<K, V, S>));
(*table)->dim = dim;
(*table)->bucket_max_size = bucket_max_size;
(*table)->max_size = std::max(init_size, max_size);
@@ -284,8 +287,7 @@ void create_table(Table<K, V, S>** table, const size_t dim,
CUDA_CHECK(cudaMemset((*table)->buckets_size, 0,
(*table)->buckets_num * sizeof(int)));

CUDA_CHECK(
cudaMallocManaged((void**)&((*table)->buckets),
CUDA_CHECK(cudaMalloc((void**)&((*table)->buckets),
(*table)->buckets_num * sizeof(Bucket<K, V, S>)));
CUDA_CHECK(cudaMemset((*table)->buckets, 0,
(*table)->buckets_num * sizeof(Bucket<K, V, S>)));
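
With `cudaMallocManaged` gone, `*table` is a plain host allocation: device code can no longer chase `(*table)->buckets` through a managed pointer, which is why the kernels in this file now receive the device-resident `Bucket` array as an explicit argument. Where a kernel still reads the descriptor itself (e.g. `buckets_size`, `bucket_max_size`), the caller is expected to maintain a device copy; a sketch of that pattern (the actual copy is made by the table front end, outside this diff):

    template <class K, class V, class S>
    Table<K, V, S>* make_device_table_copy(const Table<K, V, S>* h_table) {
      Table<K, V, S>* d_table = nullptr;
      CUDA_CHECK(cudaMalloc(&d_table, sizeof(Table<K, V, S>)));
      // A shallow copy suffices: embedded pointers such as buckets and
      // buckets_size already hold device addresses.
      CUDA_CHECK(cudaMemcpy(d_table, h_table, sizeof(Table<K, V, S>),
                            cudaMemcpyHostToDevice));
      return d_table;
    }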
@@ -303,7 +305,7 @@ void double_capacity(Table<K, V, S>** table) {
realloc<int*>(&((*table)->buckets_size), (*table)->buckets_num * sizeof(int),
(*table)->buckets_num * sizeof(int) * 2);

realloc_managed<Bucket<K, V, S>*>(
realloc<Bucket<K, V, S>*>(
&((*table)->buckets), (*table)->buckets_num * sizeof(Bucket<K, V, S>),
(*table)->buckets_num * sizeof(Bucket<K, V, S>) * 2);

@@ -317,9 +319,17 @@
/* free all of the resource of a Table. */
template <class K, class V, class S>
void destroy_table(Table<K, V, S>** table) {
uint8_t** d_address = nullptr;
CUDA_CHECK(cudaMalloc((void**)&d_address, sizeof(uint8_t*)));
for (int i = 0; i < (*table)->buckets_num; i++) {
CUDA_CHECK(cudaFree((*table)->buckets[i].digests_));
uint8_t* h_address;
get_bucket_others_address<K, V, S>
<<<1, 1>>>((*table)->buckets, i, d_address);
CUDA_CHECK(cudaMemcpy(&h_address, d_address, sizeof(uint8_t*),
cudaMemcpyDeviceToHost));
CUDA_CHECK(cudaFree(h_address));
}
CUDA_CHECK(cudaFree(d_address));
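
The `get_bucket_others_address` round trip (a one-thread kernel plus a device-to-host copy) recovers each bucket's `digests_` pointer so it can be freed; freeing `digests_` releases the whole digests/keys/scores block, since the three regions share one allocation (see `allocate_bucket_others`). Because only the field's address has to be formed on the host, the same value could also be fetched with a single `cudaMemcpy`; a sketch, assuming the `Bucket` layout above:

    uint8_t* h_address = nullptr;
    CUDA_CHECK(cudaMemcpy(&h_address, &((*table)->buckets[i].digests_),
                          sizeof(uint8_t*), cudaMemcpyDeviceToHost));
    CUDA_CHECK(cudaFree(h_address));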

for (int i = 0; i < (*table)->num_of_memory_slices; i++) {
if (is_on_device((*table)->slices[i])) {
Expand All @@ -335,11 +345,11 @@ void destroy_table(Table<K, V, S>** table) {
release_locks<Mutex>
<<<grid_size, block_size>>>((*table)->locks, 0, (*table)->buckets_num);
}
CUDA_CHECK(cudaFree((*table)->slices));
std::free((*table)->slices);
CUDA_CHECK(cudaFree((*table)->buckets_size));
CUDA_CHECK(cudaFree((*table)->buckets));
CUDA_CHECK(cudaFree((*table)->locks));
CUDA_CHECK(cudaFree(*table));
std::free(*table);
CUDA_CHECK(cudaDeviceSynchronize());
CudaCheckError();
}
@@ -465,8 +475,8 @@ __forceinline__ __device__ void move_key_to_new_bucket(

template <class K, class V, class S, uint32_t TILE_SIZE = 4>
__global__ void rehash_kernel_for_fast_mode(
const Table<K, V, S>* __restrict table, size_t N) {
Bucket<K, V, S>* buckets = table->buckets;
const Table<K, V, S>* __restrict table, Bucket<K, V, S>* buckets,
size_t N) {
int* __restrict buckets_size = table->buckets_size;
const size_t bucket_max_size = table->bucket_max_size;
const size_t buckets_num = table->buckets_num;
@@ -746,14 +756,15 @@ __global__ void accum_kernel(

/* Clear all key-value in the table. */
template <class K, class V, class S>
__global__ void clear_kernel(Table<K, V, S>* __restrict table, size_t N) {
__global__ void clear_kernel(Table<K, V, S>* __restrict table,
Bucket<K, V, S>* buckets, size_t N) {
size_t tid = (blockIdx.x * blockDim.x) + threadIdx.x;
const size_t bucket_max_size = table->bucket_max_size;

for (size_t t = tid; t < N; t += blockDim.x * gridDim.x) {
int key_idx = t % bucket_max_size;
int bkt_idx = t / bucket_max_size;
Bucket<K, V, S>* bucket = &(table->buckets[bkt_idx]);
Bucket<K, V, S>* bucket = &(buckets[bkt_idx]);

bucket->digests(key_idx)[0] = empty_digest<K>();
(bucket->keys(key_idx))
Expand Down Expand Up @@ -894,7 +905,8 @@ inline std::tuple<size_t, size_t> dump_kernel_shared_memory_size(
}

template <class K, class V, class S>
__global__ void dump_kernel(const Table<K, V, S>* __restrict table, K* d_key,
__global__ void dump_kernel(const Table<K, V, S>* __restrict table,
Bucket<K, V, S>* buckets, K* d_key,
V* __restrict d_val, S* __restrict d_score,
const size_t offset, const size_t search_length,
size_t* d_dump_counter) {
@@ -915,8 +927,7 @@ __global__ void dump_kernel(const Table<K, V, S>* __restrict table, K* d_key,
__syncthreads();

if (tid < search_length) {
Bucket<K, V, S>* const bucket{
&table->buckets[(tid + offset) / bucket_max_size]};
Bucket<K, V, S>* const bucket{&buckets[(tid + offset) / bucket_max_size]};

const int key_idx{static_cast<int>((tid + offset) % bucket_max_size)};
const K key{(bucket->keys(key_idx))->load(cuda::std::memory_order_relaxed)};
Expand Down Expand Up @@ -953,9 +964,10 @@ __global__ void dump_kernel(const Table<K, V, S>* __restrict table, K* d_key,
template <class K, class V, class S,
template <typename, typename> class PredFunctor>
__global__ void dump_kernel(const Table<K, V, S>* __restrict table,
const K pattern, const S threshold, K* d_key,
V* __restrict d_val, S* __restrict d_score,
const size_t offset, const size_t search_length,
Bucket<K, V, S>* buckets, const K pattern,
const S threshold, K* d_key, V* __restrict d_val,
S* __restrict d_score, const size_t offset,
const size_t search_length,
size_t* d_dump_counter) {
extern __shared__ unsigned char s[];
const size_t bucket_max_size = table->bucket_max_size;
@@ -978,7 +990,7 @@ __global__ void dump_kernel(const Table<K, V, S>* __restrict table,
if (tid < search_length) {
int bkt_idx = (tid + offset) / bucket_max_size;
int key_idx = (tid + offset) % bucket_max_size;
Bucket<K, V, S>* bucket = &(table->buckets[bkt_idx]);
Bucket<K, V, S>* bucket = &(buckets[bkt_idx]);

const K key =
(bucket->keys(key_idx))->load(cuda::std::memory_order_relaxed);
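
After these changes, every launch site in this file must pass the buckets pointer alongside the table. A hypothetical `clear_kernel` launch under the new signature (`d_table`, `d_buckets`, and `stream` are assumed names, not taken from this diff):

    const size_t N = buckets_num * bucket_max_size;  // one slot per thread
    const size_t block_size = 1024;
    const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);
    clear_kernel<K, V, S>
        <<<grid_size, block_size, 0, stream>>>(d_table, d_buckets, N);
    CudaCheckError();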
39 changes: 20 additions & 19 deletions include/merlin/core_kernels/find_or_insert.cuh
@@ -27,9 +27,10 @@ namespace merlin {
*/
template <class K, class V, class S, uint32_t TILE_SIZE = 4>
__global__ void find_or_insert_kernel_with_io(
const Table<K, V, S>* __restrict table, const size_t bucket_max_size,
const size_t buckets_num, const size_t dim, const K* __restrict keys,
V* __restrict values, S* __restrict scores, const size_t N) {
const Table<K, V, S>* __restrict table, Bucket<K, V, S>* buckets,
const size_t bucket_max_size, const size_t buckets_num, const size_t dim,
const K* __restrict keys, V* __restrict values, S* __restrict scores,
const size_t N) {
auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());
int* buckets_size = table->buckets_size;

@@ -52,8 +53,8 @@ __global__ void find_or_insert_kernel_with_io(
K evicted_key;

Bucket<K, V, S>* bucket =
get_key_position<K>(table->buckets, find_or_insert_key, bkt_idx,
start_idx, buckets_num, bucket_max_size);
get_key_position<K>(buckets, find_or_insert_key, bkt_idx, start_idx,
buckets_num, bucket_max_size);

OccupyResult occupy_result{OccupyResult::INITIAL};
const int bucket_size = buckets_size[bkt_idx];
@@ -110,24 +111,24 @@ struct SelectFindOrInsertKernelWithIO {
const size_t buckets_num, const size_t dim,
cudaStream_t& stream, const size_t& n,
const Table<K, V, S>* __restrict table,
const K* __restrict keys, V* __restrict values,
S* __restrict scores) {
Bucket<K, V, S>* buckets, const K* __restrict keys,
V* __restrict values, S* __restrict scores) {
if (load_factor <= 0.75) {
const unsigned int tile_size = 4;
const size_t N = n * tile_size;
const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);
find_or_insert_kernel_with_io<K, V, S, tile_size>
<<<grid_size, block_size, 0, stream>>>(table, bucket_max_size,
buckets_num, dim, keys, values,
scores, N);
<<<grid_size, block_size, 0, stream>>>(table, buckets,
bucket_max_size, buckets_num,
dim, keys, values, scores, N);
} else {
const unsigned int tile_size = 32;
const size_t N = n * tile_size;
const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);
find_or_insert_kernel_with_io<K, V, S, tile_size>
<<<grid_size, block_size, 0, stream>>>(table, bucket_max_size,
buckets_num, dim, keys, values,
scores, N);
<<<grid_size, block_size, 0, stream>>>(table, buckets,
bucket_max_size, buckets_num,
dim, keys, values, scores, N);
}
return;
}
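
The selector keeps its tile-size heuristic (a tile of 4 below a 0.75 load factor, 32 above, so that fuller buckets are probed by more threads cooperatively) and simply threads the new `buckets` argument through. A hypothetical call site (variable names assumed; the parameters before `bucket_max_size` are inferred from the body above):

    SelectFindOrInsertKernelWithIO<K, V, S>::execute_kernel(
        load_factor, block_size, bucket_max_size, buckets_num, dim, stream, n,
        d_table, d_buckets, d_keys, d_values, d_scores);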
@@ -137,10 +138,10 @@
*/
template <class K, class V, class S, uint32_t TILE_SIZE = 4>
__global__ void find_or_insert_kernel(
const Table<K, V, S>* __restrict table, const size_t bucket_max_size,
const size_t buckets_num, const size_t dim, const K* __restrict keys,
V** __restrict vectors, S* __restrict scores, bool* __restrict found,
int* __restrict keys_index, const size_t N) {
const Table<K, V, S>* __restrict table, Bucket<K, V, S>* buckets,
const size_t bucket_max_size, const size_t buckets_num, const size_t dim,
const K* __restrict keys, V** __restrict vectors, S* __restrict scores,
bool* __restrict found, int* __restrict keys_index, const size_t N) {
auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());
int* buckets_size = table->buckets_size;

@@ -162,8 +163,8 @@ __global__ void find_or_insert_kernel(
K evicted_key;

Bucket<K, V, S>* bucket =
get_key_position<K>(table->buckets, find_or_insert_key, bkt_idx,
start_idx, buckets_num, bucket_max_size);
get_key_position<K>(buckets, find_or_insert_key, bkt_idx, start_idx,
buckets_num, bucket_max_size);

if (g.thread_rank() == 0) {
*(keys_index + key_idx) = key_idx;
(Diffs for the remaining 8 changed files are not shown.)
