Skip to content

Commit

Permalink
[Fix] make thrust temp memory allocation managed by HKV allocator.
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong committed Oct 15, 2023
1 parent f4ecb2c commit 3b8ed85
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 6 deletions.
31 changes: 31 additions & 0 deletions include/merlin/allocator.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <stdlib.h>
#include <thrust/device_malloc_allocator.h>
#include "debug.hpp"
#include "utils.cuh"

Expand Down Expand Up @@ -123,5 +124,35 @@ class DefaultAllocator : public virtual BaseAllocator {
}
};

/**
 * Adapter that plugs an HKV `BaseAllocator` into Thrust, so that temporary
 * device buffers created by Thrust algorithms (e.g. via
 * `thrust::cuda::par(alloc).on(stream)`) are served by the table's own
 * allocator instead of Thrust's default `cudaMalloc`/`cudaFree` path.
 *
 * `set_allocator` must be called with a non-null `BaseAllocator*` before the
 * first allocation; both `allocate` and `deallocate` enforce this.
 * NOTE(review): this adapter does not own `allocator_` — the caller must keep
 * the `BaseAllocator` alive for the adapter's lifetime.
 */
template <typename T>
struct ThrustAllocator : thrust::device_malloc_allocator<T> {
 public:
  using super_t = thrust::device_malloc_allocator<T>;
  using pointer = typename super_t::pointer;
  using size_type = typename super_t::size_type;

 public:
  // Obtain `n * sizeof(T)` bytes of device memory from the wrapped allocator.
  pointer allocate(size_type n) {
    MERLIN_CHECK(
        allocator_ != nullptr,
        "[ThrustAllocator] set_allocator should be called in advance!");
    void* raw = nullptr;
    allocator_->alloc(MemoryType::Device, &raw, sizeof(T) * n);
    return pointer(static_cast<T*>(raw));
  }

  // Release device memory previously obtained from `allocate`. The element
  // count `n` is part of the Thrust allocator interface but is not needed by
  // the wrapped allocator's `free`.
  void deallocate(pointer p, size_type n) {
    MERLIN_CHECK(
        allocator_ != nullptr,
        "[ThrustAllocator] set_allocator should be called in advance!");
    allocator_->free(MemoryType::Device, static_cast<void*>(p.get()));
  }

  // Install the backing allocator; must precede any allocate/deallocate call.
  void set_allocator(BaseAllocator* allocator) { allocator_ = allocator; }

 public:
  // Non-owning pointer to the backing HKV allocator (null until set).
  BaseAllocator* allocator_ = nullptr;
};

} // namespace merlin
} // namespace nv
18 changes: 12 additions & 6 deletions include/merlin_hashtable.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ class HashTable {
default_allocator_ = (allocator == nullptr);
allocator_ = (allocator == nullptr) ? (new DefaultAllocator()) : allocator;

thrust_allocator_.set_allocator(allocator_);

if (options_.device_id >= 0) {
CUDA_CHECK(cudaSetDevice(options_.device_id));
} else {
Expand Down Expand Up @@ -295,6 +297,7 @@ class HashTable {
options_.host_memory_pool, allocator_);

CUDA_CHECK(cudaDeviceSynchronize());

initialized_ = true;
CudaCheckError();
}
Expand Down Expand Up @@ -1349,8 +1352,9 @@ class HashTable {
reinterpret_cast<uintptr_t*>(src));
thrust::device_ptr<int> dst_offset_ptr(dst_offset);

thrust::sort_by_key(thrust_par.on(stream), src_ptr, src_ptr + n,
dst_offset_ptr, thrust::less<uintptr_t>());
thrust::sort_by_key(thrust_par(thrust_allocator_).on(stream), src_ptr,
src_ptr + n, dst_offset_ptr,
thrust::less<uintptr_t>());

const size_t block_size = options_.io_block_size;
const size_t N = n * dim();
Expand Down Expand Up @@ -1756,8 +1760,9 @@ class HashTable {

for (size_type start_i = 0; start_i < N; start_i += step) {
size_type end_i = std::min(start_i + step, N);
h_size += thrust::reduce(thrust_par.on(stream), size_ptr + start_i,
size_ptr + end_i, 0, thrust::plus<int>());
h_size += thrust::reduce(thrust_par(thrust_allocator_).on(stream),
size_ptr + start_i, size_ptr + end_i, 0,
thrust::plus<int>());
}

CudaCheckError();
Expand Down Expand Up @@ -2075,8 +2080,8 @@ class HashTable {

thrust::device_ptr<int> size_ptr(table_->buckets_size);

int size = thrust::reduce(thrust_par.on(stream), size_ptr, size_ptr + N, 0,
thrust::plus<int>());
int size = thrust::reduce(thrust_par(thrust_allocator_).on(stream),
size_ptr, size_ptr + N, 0, thrust::plus<int>());

CudaCheckError();
return static_cast<float>((delta * 1.0) / (capacity() * 1.0) +
Expand Down Expand Up @@ -2139,6 +2144,7 @@ class HashTable {
std::unique_ptr<DeviceMemoryPool> dev_mem_pool_;
std::unique_ptr<HostMemoryPool> host_mem_pool_;
allocator_type* allocator_;
ThrustAllocator<uint8_t> thrust_allocator_;
bool default_allocator_ = true;
};

Expand Down

0 comments on commit 3b8ed85

Please sign in to comment.