diff --git a/include/merlin/core_kernels.cuh b/include/merlin/core_kernels.cuh index 4bca3173..66efdbb6 100644 --- a/include/merlin/core_kernels.cuh +++ b/include/merlin/core_kernels.cuh @@ -222,14 +222,12 @@ void initialize_buckets(Table** table, BaseAllocator* allocator, uint32_t reserve_size = bucket_max_size < CACHE_LINE_SIZE ? CACHE_LINE_SIZE : bucket_max_size; bucket_memory_size += reserve_size * sizeof(uint8_t); - uint8_t* address = nullptr; - allocator->alloc(MemoryType::Device, (void**)&(address), - bucket_memory_size * (end - start)); - (*table)->buckets_address.push_back(address); for (int i = start; i < end; i++) { - allocate_bucket_others<<<1, 1>>>( - (*table)->buckets, i, address + (bucket_memory_size * (i - start)), - reserve_size, bucket_max_size); + uint8_t* address = nullptr; + allocator->alloc(MemoryType::Device, (void**)&(address), + bucket_memory_size); + allocate_bucket_others<<<1, 1>>>((*table)->buckets, i, address, + reserve_size, bucket_max_size); } CUDA_CHECK(cudaDeviceSynchronize()); @@ -367,9 +365,17 @@ void double_capacity(Table** table, BaseAllocator* allocator) { /* free all of the resource of a Table. */ template void destroy_table(Table** table, BaseAllocator* allocator) { - for (auto addr : (*table)->buckets_address) { - allocator->free(MemoryType::Device, addr); - } + uint8_t** d_address = nullptr; + CUDA_CHECK(cudaMalloc((void**)&d_address, sizeof(uint8_t*))); + for (int i = 0; i < (*table)->buckets_num; i++) { + uint8_t* h_address; + get_bucket_others_address + <<<1, 1>>>((*table)->buckets, i, d_address); + CUDA_CHECK(cudaMemcpy(&h_address, d_address, sizeof(uint8_t*), + cudaMemcpyDeviceToHost)); + allocator->free(MemoryType::Device, h_address); + } + CUDA_CHECK(cudaFree(d_address)); for (int i = 0; i < (*table)->num_of_memory_slices; i++) { if (is_on_device((*table)->slices[i])) { diff --git a/include/merlin/types.cuh b/include/merlin/types.cuh index c8ac799f..6f0bcc15 100644 --- a/include/merlin/types.cuh +++ b/include/merlin/types.cuh @@ -19,7 +19,6 @@ #include #include #include -#include #include "debug.hpp" namespace nv { @@ -215,7 +214,6 @@ struct Table { int slots_number = 0; // unused int device_id = 0; // Device id int tile_size; - std::vector buckets_address; }; template