Skip to content

Commit

Permalink
[Fix] make thrust temp memory allocation managed by HKV allocator.
Browse files Browse the repository at this point in the history
- Refactoring: clean up compiler warnings on lower versions of CUDA
  • Loading branch information
rhdong committed Oct 24, 2023
1 parent f4ecb2c commit 5301c4a
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 21 deletions.
31 changes: 31 additions & 0 deletions include/merlin/allocator.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <stdlib.h>
#include <thrust/device_malloc_allocator.h>
#include "debug.hpp"
#include "utils.cuh"

Expand Down Expand Up @@ -123,5 +124,35 @@ class DefaultAllocator : public virtual BaseAllocator {
}
};

template <typename T>
struct ThrustAllocator : thrust::device_malloc_allocator<T> {
 public:
  using super_t = thrust::device_malloc_allocator<T>;
  using pointer = typename super_t::pointer;
  using size_type = typename super_t::size_type;

 public:
  // Obtain `n * sizeof(T)` bytes of device memory via the bound HKV
  // allocator, so thrust temporary storage goes through the same pool
  // as the rest of the library. set_allocator() must be called first.
  pointer allocate(size_type n) {
    MERLIN_CHECK(
        allocator_ != nullptr,
        "[ThrustAllocator] set_allocator should be called in advance!");
    void* raw = nullptr;
    allocator_->alloc(MemoryType::Device, &raw, sizeof(T) * n);
    return pointer(reinterpret_cast<T*>(raw));
  }

  // Release device memory previously returned by allocate(). The size
  // argument is required by the allocator interface but unused here.
  void deallocate(pointer p, size_type /*n*/) {
    MERLIN_CHECK(
        allocator_ != nullptr,
        "[ThrustAllocator] set_allocator should be called in advance!");
    allocator_->free(MemoryType::Device, reinterpret_cast<void*>(p.get()));
  }

  // Bind the HKV allocator that backs all subsequent (de)allocations.
  void set_allocator(BaseAllocator* allocator) { allocator_ = allocator; }

 public:
  // Non-owning pointer; must outlive this allocator. Set via set_allocator().
  BaseAllocator* allocator_ = nullptr;
};

} // namespace merlin
} // namespace nv
2 changes: 2 additions & 0 deletions include/merlin/core_kernels/find_or_insert.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1221,6 +1221,7 @@ struct KernelSelector_FindOrInsert {
}
};

#if defined(CUDART_VERSION) && (CUDART_VERSION >= 11030)
auto launch_TLPv2 = [&]() {
if (total_value_size % sizeof(byte16) == 0) {
using VecV = byte16;
Expand All @@ -1244,6 +1245,7 @@ struct KernelSelector_FindOrInsert {
params, stream);
}
};
#endif

auto launch_Pipeline = [&]() {
if (total_value_size % sizeof(byte16) == 0) {
Expand Down
4 changes: 0 additions & 4 deletions include/merlin/core_kernels/update.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,13 @@ __global__ void tlp_update_kernel_with_io(
uint32_t tx = threadIdx.x;
uint32_t kv_idx = blockIdx.x * blockDim.x + tx;
K key{static_cast<K>(EMPTY_KEY)};
S score{static_cast<S>(EMPTY_SCORE)};
OccupyResult occupy_result{OccupyResult::INITIAL};
VecD_Comp target_digests{0};
VecV* bucket_values_ptr{nullptr};
K* bucket_keys_ptr{nullptr};
uint32_t key_pos = {0};
if (kv_idx < n) {
key = keys[kv_idx];
score = ScoreFunctor::desired_when_missed(scores, kv_idx, global_epoch);

if (!IS_RESERVED_KEY(key)) {
const K hashed_key = Murmur3HashDevice(key);
Expand Down Expand Up @@ -767,15 +765,13 @@ __global__ void tlp_update_kernel_hybrid(
uint32_t tx = threadIdx.x;
uint32_t kv_idx = blockIdx.x * blockDim.x + tx;
K key{static_cast<K>(EMPTY_KEY)};
S score{static_cast<S>(EMPTY_SCORE)};
OccupyResult occupy_result{OccupyResult::INITIAL};
VecD_Comp target_digests{0};
V* bucket_values_ptr{nullptr};
K* bucket_keys_ptr{nullptr};
uint32_t key_pos = {0};
if (kv_idx < n) {
key = keys[kv_idx];
score = ScoreFunctor::desired_when_missed(scores, kv_idx, global_epoch);
if (src_offset) src_offset[kv_idx] = kv_idx;
if (!IS_RESERVED_KEY(key)) {
const K hashed_key = Murmur3HashDevice(key);
Expand Down
2 changes: 0 additions & 2 deletions include/merlin/core_kernels/update_score.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,12 @@ __global__ void tlp_update_score_kernel(Bucket<K, V, S>* __restrict__ buckets,
uint32_t tx = threadIdx.x;
uint32_t kv_idx = blockIdx.x * blockDim.x + tx;
K key{static_cast<K>(EMPTY_KEY)};
S score{static_cast<S>(EMPTY_SCORE)};
OccupyResult occupy_result{OccupyResult::INITIAL};
VecD_Comp target_digests{0};
K* bucket_keys_ptr{nullptr};
uint32_t key_pos = {0};
if (kv_idx < n) {
key = keys[kv_idx];
score = ScoreFunctor::desired_when_missed(scores, kv_idx, global_epoch);

if (!IS_RESERVED_KEY(key)) {
const K hashed_key = Murmur3HashDevice(key);
Expand Down
2 changes: 2 additions & 0 deletions include/merlin/core_kernels/upsert.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,7 @@ struct KernelSelector_Upsert {
}
};

#if defined(CUDART_VERSION) && (CUDART_VERSION >= 11030)
auto launch_TLPv2 = [&]() {
if (total_value_size % sizeof(byte16) == 0) {
using VecV = byte16;
Expand All @@ -1169,6 +1170,7 @@ struct KernelSelector_Upsert {
stream);
}
};
#endif

auto launch_Pipeline = [&]() {
if (total_value_size % sizeof(byte16) == 0) {
Expand Down
2 changes: 2 additions & 0 deletions include/merlin/core_kernels/upsert_and_evict.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1292,6 +1292,7 @@ struct KernelSelector_UpsertAndEvict {
}
};

#if defined(CUDART_VERSION) && (CUDART_VERSION >= 11030)
auto launch_TLPv2 = [&]() {
if (total_value_size % sizeof(byte16) == 0) {
using VecV = byte16;
Expand All @@ -1315,6 +1316,7 @@ struct KernelSelector_UpsertAndEvict {
params, stream);
}
};
#endif

auto launch_Pipeline = [&]() {
if (total_value_size % sizeof(byte16) == 0) {
Expand Down
39 changes: 24 additions & 15 deletions include/merlin_hashtable.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ class HashTable {
default_allocator_ = (allocator == nullptr);
allocator_ = (allocator == nullptr) ? (new DefaultAllocator()) : allocator;

thrust_allocator_.set_allocator(allocator_);

if (options_.device_id >= 0) {
CUDA_CHECK(cudaSetDevice(options_.device_id));
} else {
Expand Down Expand Up @@ -295,6 +297,7 @@ class HashTable {
options_.host_memory_pool, allocator_);

CUDA_CHECK(cudaDeviceSynchronize());

initialized_ = true;
CudaCheckError();
}
Expand Down Expand Up @@ -426,8 +429,9 @@ class HashTable {
reinterpret_cast<uintptr_t*>(d_dst));
thrust::device_ptr<int> d_src_offset_ptr(d_src_offset);

thrust::sort_by_key(thrust_par.on(stream), d_dst_ptr, d_dst_ptr + n,
d_src_offset_ptr, thrust::less<uintptr_t>());
thrust::sort_by_key(thrust_par(thrust_allocator_).on(stream), d_dst_ptr,
d_dst_ptr + n, d_src_offset_ptr,
thrust::less<uintptr_t>());
}

if (filter_condition) {
Expand Down Expand Up @@ -780,8 +784,9 @@ class HashTable {
reinterpret_cast<uintptr_t*>(dst));
thrust::device_ptr<int> src_offset_ptr(src_offset);

thrust::sort_by_key(thrust_par.on(stream), dst_ptr, dst_ptr + n,
src_offset_ptr, thrust::less<uintptr_t>());
thrust::sort_by_key(thrust_par(thrust_allocator_).on(stream), dst_ptr,
dst_ptr + n, src_offset_ptr,
thrust::less<uintptr_t>());
}

{
Expand Down Expand Up @@ -913,9 +918,9 @@ class HashTable {
reinterpret_cast<uintptr_t*>(d_table_value_addrs));
thrust::device_ptr<int> param_key_index_ptr(param_key_index);

thrust::sort_by_key(thrust_par.on(stream), table_value_ptr,
table_value_ptr + n, param_key_index_ptr,
thrust::less<uintptr_t>());
thrust::sort_by_key(thrust_par(thrust_allocator_).on(stream),
table_value_ptr, table_value_ptr + n,
param_key_index_ptr, thrust::less<uintptr_t>());
}

if (filter_condition) {
Expand Down Expand Up @@ -1152,8 +1157,9 @@ class HashTable {
reinterpret_cast<uintptr_t*>(d_dst));
thrust::device_ptr<int> d_src_offset_ptr(d_src_offset);

thrust::sort_by_key(thrust_par.on(stream), d_dst_ptr, d_dst_ptr + n,
d_src_offset_ptr, thrust::less<uintptr_t>());
thrust::sort_by_key(thrust_par(thrust_allocator_).on(stream), d_dst_ptr,
d_dst_ptr + n, d_src_offset_ptr,
thrust::less<uintptr_t>());
}

if (filter_condition) {
Expand Down Expand Up @@ -1349,8 +1355,9 @@ class HashTable {
reinterpret_cast<uintptr_t*>(src));
thrust::device_ptr<int> dst_offset_ptr(dst_offset);

thrust::sort_by_key(thrust_par.on(stream), src_ptr, src_ptr + n,
dst_offset_ptr, thrust::less<uintptr_t>());
thrust::sort_by_key(thrust_par(thrust_allocator_).on(stream), src_ptr,
src_ptr + n, dst_offset_ptr,
thrust::less<uintptr_t>());

const size_t block_size = options_.io_block_size;
const size_t N = n * dim();
Expand Down Expand Up @@ -1756,8 +1763,9 @@ class HashTable {

for (size_type start_i = 0; start_i < N; start_i += step) {
size_type end_i = std::min(start_i + step, N);
h_size += thrust::reduce(thrust_par.on(stream), size_ptr + start_i,
size_ptr + end_i, 0, thrust::plus<int>());
h_size += thrust::reduce(thrust_par(thrust_allocator_).on(stream),
size_ptr + start_i, size_ptr + end_i, 0,
thrust::plus<int>());
}

CudaCheckError();
Expand Down Expand Up @@ -2075,8 +2083,8 @@ class HashTable {

thrust::device_ptr<int> size_ptr(table_->buckets_size);

int size = thrust::reduce(thrust_par.on(stream), size_ptr, size_ptr + N, 0,
thrust::plus<int>());
int size = thrust::reduce(thrust_par(thrust_allocator_).on(stream),
size_ptr, size_ptr + N, 0, thrust::plus<int>());

CudaCheckError();
return static_cast<float>((delta * 1.0) / (capacity() * 1.0) +
Expand Down Expand Up @@ -2139,6 +2147,7 @@ class HashTable {
std::unique_ptr<DeviceMemoryPool> dev_mem_pool_;
std::unique_ptr<HostMemoryPool> host_mem_pool_;
allocator_type* allocator_;
ThrustAllocator<uint8_t> thrust_allocator_;
bool default_allocator_ = true;
};

Expand Down

0 comments on commit 5301c4a

Please sign in to comment.