diff --git a/include/merlin/core_kernels.cuh b/include/merlin/core_kernels.cuh index 6585efcf6..b04f3187c 100644 --- a/include/merlin/core_kernels.cuh +++ b/include/merlin/core_kernels.cuh @@ -222,12 +222,14 @@ void initialize_buckets(Table** table, BaseAllocator* allocator, uint32_t reserve_size = bucket_max_size < CACHE_LINE_SIZE ? CACHE_LINE_SIZE : bucket_max_size; bucket_memory_size += reserve_size * sizeof(uint8_t); + uint8_t* address = nullptr; + allocator->alloc(MemoryType::Device, (void**)&(address), + bucket_memory_size * (end - start)); + (*table)->buckets_address.push_back(address); for (int i = start; i < end; i++) { - uint8_t* address = nullptr; - allocator->alloc(MemoryType::Device, (void**)&(address), - bucket_memory_size); - allocate_bucket_others<<<1, 1>>>((*table)->buckets, i, address, - reserve_size, bucket_max_size); + allocate_bucket_others<<<1, 1>>>( + (*table)->buckets, i, address + (bucket_memory_size * (i - start)), + reserve_size, bucket_max_size); } CUDA_CHECK(cudaDeviceSynchronize()); @@ -365,17 +367,9 @@ void double_capacity(Table** table, BaseAllocator* allocator) { /* free all of the resource of a Table. */ template void destroy_table(Table** table, BaseAllocator* allocator) { - uint8_t** d_address = nullptr; - CUDA_CHECK(cudaMalloc((void**)&d_address, sizeof(uint8_t*))); - for (int i = 0; i < (*table)->buckets_num; i++) { - uint8_t* h_address; - get_bucket_others_address - <<<1, 1>>>((*table)->buckets, i, d_address); - CUDA_CHECK(cudaMemcpy(&h_address, d_address, sizeof(uint8_t*), - cudaMemcpyDeviceToHost)); - allocator->free(MemoryType::Device, h_address); - } - CUDA_CHECK(cudaFree(d_address)); + for (auto addr : (*table)->buckets_address) { + allocator->free(MemoryType::Device, addr); + } for (int i = 0; i < (*table)->num_of_memory_slices; i++) { if (is_on_device((*table)->slices[i])) { diff --git a/include/merlin/types.cuh b/include/merlin/types.cuh index 67b1da03b..96d558eee 100644 --- a/include/merlin/types.cuh +++ b/include/merlin/types.cuh @@ -19,6 +19,7 @@ #include #include #include +#include namespace nv { namespace merlin { @@ -161,6 +162,7 @@ struct Table { int slots_number = 0; // unused int device_id = 0; // Device id int tile_size; + std::vector buckets_address; }; template