diff --git a/include/merlin/core_kernels.cuh b/include/merlin/core_kernels.cuh
index c811614e3..d42bb48ab 100644
--- a/include/merlin/core_kernels.cuh
+++ b/include/merlin/core_kernels.cuh
@@ -1029,8 +1029,9 @@ struct SelectUpsertKernelWithIO {
  */
 template <class K, class V, class M, uint32_t TILE_SIZE>
 __global__ void upsert_kernel(const Table<K, V, M>* __restrict table,
-                              const K* __restrict keys, V** __restrict vectors,
-                              const M* __restrict metas,
+                              const K* __restrict keys,
+                              K* __restrict evicted_keys,
+                              V** __restrict vectors, const M* __restrict metas,
                               int* __restrict src_offset, size_t N) {
   Bucket<K, V, M>* buckets = table->buckets;
   int* buckets_size = table->buckets_size;
@@ -1168,6 +1169,10 @@ __global__ void upsert_kernel(const Table<K, V, M>* __restrict table,
       // override_result == OverrideResult::SUCCESS
       if (rank == src_lane) {
+        if (evicted_keys) {
+          evicted_keys[key_idx] =
+              bucket->keys[key_pos].load(cuda::std::memory_order_relaxed);
+        }
         bucket->keys[key_pos].store(insert_key,
                                     cuda::std::memory_order_relaxed);
         *(vectors + key_idx) = (bucket->vectors + key_pos * dim);
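Note on the new `evicted_keys` argument: the kernel writes the previous key only on the override path, while the caller pre-fills the buffer with 0xFF bytes (see the `merlin_hashtable.cuh` hunks below), so slots that did not evict anything keep the all-ones sentinel. A minimal sketch, assuming 64-bit keys, of how a consumer could count the real evictions in a batch; `count_evictions` is an illustrative helper and not part of this patch or of the Merlin API:

#include <cstddef>
#include <cstdint>

// Counts the slots where upsert_kernel actually replaced an existing key,
// i.e. entries that differ from the all-ones sentinel left by the
// cudaMemsetAsync(d_evicted_keys, 0xFF, ...) pre-fill.
__global__ void count_evictions(const uint64_t* evicted_keys, size_t n,
                                unsigned long long* counter) {
  const uint64_t kEmpty = ~0ULL;
  const size_t i = blockIdx.x * static_cast<size_t>(blockDim.x) + threadIdx.x;
  if (i < n && evicted_keys[i] != kEmpty) {
    atomicAdd(counter, 1ULL);
  }
}

// Launch example: count_evictions<<<grid_size, block_size, 0, stream>>>(
//     d_evicted_keys, n, d_counter);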
diff --git a/include/merlin/external_storage.cuh b/include/merlin/external_storage.cuh
new file mode 100644
index 000000000..9dfb77fb1
--- /dev/null
+++ b/include/merlin/external_storage.cuh
@@ -0,0 +1,131 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <cstddef>
+#include <cuda_runtime_api.h>
+#include "merlin/memory_pool.cuh"
+
+namespace nv {
+namespace merlin {
+
+template <class Key, class Value>
+class ExternalStorage {
+ public:
+  using size_type = size_t;
+  using key_type = Key;
+  using value_type = Value;
+
+  using dev_mem_pool_type = MemoryPool<DeviceAllocator<char>>;
+  using host_mem_pool_type = MemoryPool<HostAllocator<char>>;
+
+  const size_type value_dim;
+
+  ExternalStorage() = delete;
+
+  /**
+   * Constructs the external storage object.
+   *
+   * @param value_dim The dimensionality of the values. In other words, each
+   * value stored is exactly `value_dim * sizeof(value_type)` bytes large.
+   */
+  ExternalStorage(const size_type value_dim) : value_dim{value_dim} {}
+
+  /**
+   * @brief Inserts key/value pairs into the external storage that are about
+   * to be evicted from the Merlin hashtable. If a key/value pair already
+   * exists, overwrites the current value.
+   *
+   * @param dev_mem_pool Memory pool for temporarily allocating device memory.
+   * @param host_mem_pool Memory pool for temporarily allocating host memory.
+   * @param hkvs_is_pure_hbm True if the Merlin hashtable store is currently
+   * operating in pure HBM mode, false otherwise. In pure HBM mode, all
+   * `values` pointers are GUARANTEED to point to device memory.
+   * @param n Number of key/value slots provided in the other arguments.
+   * @param d_masked_keys Device pointer to an (n)-sized array of keys.
+   * Key/value slots that should be ignored have the key set to `EMPTY_KEY`.
+   * @param d_values Device pointer to an (n)-sized array containing, for each
+   * key, a pointer to the memory location where its current values are
+   * stored. Each pointer points to a vector of length `value_dim`. Pointers
+   * *can* be set to `nullptr` for slots where the corresponding key equals
+   * `EMPTY_KEY`. The memory locations can be device or host memory (see also
+   * `hkvs_is_pure_hbm`).
+   * @param stream Stream that MUST be used for queuing asynchronous CUDA
+   * operations. If only the input arguments or resources obtained from
+   * respectively `dev_mem_pool` and `host_mem_pool` are used for such
+   * operations, it is not necessary to synchronize the stream prior to
+   * returning from the function.
+   */
+  virtual void insert_or_assign(dev_mem_pool_type& dev_mem_pool,
+                                host_mem_pool_type& host_mem_pool,
+                                bool hkvs_is_pure_hbm, size_type n,
+                                const key_type* d_masked_keys,      // (n)
+                                const value_type* const* d_values,  // (n)
+                                cudaStream_t stream) = 0;
+
+  /**
+   * @brief Attempts to find the supplied `d_keys` if the corresponding
+   * `d_founds`-flag is `false` and fills the stored values into the supplied
+   * memory locations (i.e., into `d_values`).
+   *
+   * @param dev_mem_pool Memory pool for temporarily allocating device memory.
+   * @param host_mem_pool Memory pool for temporarily allocating host memory.
+   * @param n Number of key/value slots provided in the other arguments.
+   * @param d_keys Device pointer to an (n)-sized array of keys.
+   * @param d_values Device pointer to an (n * value_dim)-sized array to store
+   * the retrieved values. For slots where the corresponding `d_founds`-flag
+   * is not `false`, the value may already have been assigned and, thus, MUST
+   * not be altered.
+   * @param d_founds Device pointer to an (n)-sized array which indicates
+   * whether the corresponding `d_values` slot is already filled or not. If,
+   * and only if, `d_founds` is still false, the implementation shall attempt
+   * to retrieve and fill in the value for the corresponding key. If a
+   * key/value pair was retrieved successfully from external storage, the
+   * implementation MUST also set `d_founds` to `true`.
+   * @param stream Stream that MUST be used for queuing asynchronous CUDA
+   * operations. If only the input arguments or resources obtained from
+   * respectively `dev_mem_pool` and `host_mem_pool` are used for such
+   * operations, it is not necessary to synchronize the stream prior to
+   * returning from the function.
+   */
+  virtual void find(dev_mem_pool_type& dev_mem_pool,
+                    host_mem_pool_type& host_mem_pool, size_type n,
+                    const key_type* d_keys,  // (n)
+                    value_type* d_values,    // (n * value_dim)
+                    bool* d_founds,          // (n)
+                    cudaStream_t stream) = 0;
+
+  /**
+   * @brief Attempts to erase the entries associated with the supplied
+   * `d_keys`. For keys that do not exist, nothing happens. It is permissible
+   * for this function to be implemented asynchronously (i.e., to return
+   * before the actual deletion has happened).
+   *
+   * @param dev_mem_pool Memory pool for temporarily allocating device memory.
+   * @param host_mem_pool Memory pool for temporarily allocating host memory.
+   * @param n Number of keys provided in the `d_keys` argument.
+   * @param d_keys Device pointer to an (n)-sized array of keys. This pointer
+   * is only guaranteed to be valid for the duration of the call. If erasure
+   * is implemented asynchronously, you must make a copy and manage its
+   * lifetime yourself.
+   */
+  virtual void erase_async(dev_mem_pool_type& dev_mem_pool,
+                           host_mem_pool_type& host_mem_pool, size_type n,
+                           const key_type* d_keys, cudaStream_t stream) = 0;
+};
+
+}  // namespace merlin
+}  // namespace nv
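To illustrate the contract documented above, here is a minimal sketch (not part of this patch) of a concrete backend that spills evicted pairs into a host-side std::unordered_map. `HostMapStorage` is a hypothetical name; error checking, thread safety, and the supplied memory pools are deliberately ignored, and the all-0xFF "empty" key mirrors the `cudaMemsetAsync(d_evicted_keys, 0xFF, ...)` pre-fill in the `merlin_hashtable.cuh` hunks below.

#include <cstring>
#include <unordered_map>
#include <vector>

#include <cuda_runtime_api.h>

#include "merlin/external_storage.cuh"

template <class Key, class Value>
class HostMapStorage final : public nv::merlin::ExternalStorage<Key, Value> {
 public:
  using base_type = nv::merlin::ExternalStorage<Key, Value>;
  using size_type = typename base_type::size_type;
  using dev_mem_pool_type = typename base_type::dev_mem_pool_type;
  using host_mem_pool_type = typename base_type::host_mem_pool_type;

  explicit HostMapStorage(size_type value_dim) : base_type{value_dim} {}

  void insert_or_assign(dev_mem_pool_type&, host_mem_pool_type&,
                        bool /*hkvs_is_pure_hbm*/, size_type n,
                        const Key* d_masked_keys, const Value* const* d_values,
                        cudaStream_t stream) override {
    // Pull the keys and the per-key value pointers to the host.
    std::vector<Key> keys(n);
    std::vector<const Value*> ptrs(n);
    cudaMemcpyAsync(keys.data(), d_masked_keys, n * sizeof(Key),
                    cudaMemcpyDeviceToHost, stream);
    cudaMemcpyAsync(ptrs.data(), d_values, n * sizeof(Value*),
                    cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);

    for (size_type i = 0; i < n; ++i) {
      // Slots without an eviction keep the 0xFF-filled sentinel key.
      if (keys[i] == empty_key() || ptrs[i] == nullptr) continue;
      std::vector<Value>& dst = store_[keys[i]];
      dst.resize(this->value_dim);
      // cudaMemcpyDefault copes with device or host source pointers (UVA).
      cudaMemcpy(dst.data(), ptrs[i], this->value_dim * sizeof(Value),
                 cudaMemcpyDefault);
    }
  }

  void find(dev_mem_pool_type&, host_mem_pool_type&, size_type n,
            const Key* d_keys, Value* d_values, bool* d_founds,
            cudaStream_t stream) override {
    std::vector<Key> keys(n);
    std::vector<char> founds(n);
    cudaMemcpyAsync(keys.data(), d_keys, n * sizeof(Key),
                    cudaMemcpyDeviceToHost, stream);
    cudaMemcpyAsync(founds.data(), d_founds, n * sizeof(bool),
                    cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);

    for (size_type i = 0; i < n; ++i) {
      if (founds[i]) continue;  // already answered by the in-table lookup
      auto it = store_.find(keys[i]);
      if (it == store_.end()) continue;
      cudaMemcpy(d_values + i * this->value_dim, it->second.data(),
                 this->value_dim * sizeof(Value), cudaMemcpyHostToDevice);
      const bool found = true;
      cudaMemcpy(d_founds + i, &found, sizeof(bool), cudaMemcpyHostToDevice);
    }
  }

  void erase_async(dev_mem_pool_type&, host_mem_pool_type&, size_type n,
                   const Key* d_keys, cudaStream_t stream) override {
    // Synchronous for simplicity; a real backend may defer the deletion.
    std::vector<Key> keys(n);
    cudaMemcpyAsync(keys.data(), d_keys, n * sizeof(Key),
                    cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);
    for (size_type i = 0; i < n; ++i) store_.erase(keys[i]);
  }

 private:
  static Key empty_key() {
    Key k;
    std::memset(&k, 0xFF, sizeof(Key));  // matches the 0xFF pre-fill
    return k;
  }

  std::unordered_map<Key, std::vector<Value>> store_;
};

A production backend would batch the copies through the supplied memory pools instead of issuing per-key cudaMemcpy calls; the sketch only demonstrates the key/value/founds protocol.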
diff --git a/include/merlin_hashtable.cuh b/include/merlin_hashtable.cuh
index 5f4ffda5f..ceee21c97 100644
--- a/include/merlin_hashtable.cuh
+++ b/include/merlin_hashtable.cuh
@@ -26,6 +26,7 @@
 #include <shared_mutex>
 #include <type_traits>
 #include "merlin/core_kernels.cuh"
+#include "merlin/external_storage.cuh"
 #include "merlin/flexible_buffer.cuh"
 #include "merlin/memory_pool.cuh"
 #include "merlin/types.cuh"
@@ -152,6 +153,8 @@ class HashTable {
   using DeviceMemoryPool = MemoryPool<DeviceAllocator<char>>;
   using HostMemoryPool = MemoryPool<HostAllocator<char>>;
 
+  using external_storage_type = ExternalStorage<key_type, value_type>;
+
 #if THRUST_VERSION >= 101600
   static constexpr auto thrust_par = thrust::cuda::par_nosync;
 #else
@@ -169,6 +172,8 @@ class HashTable {
    * table object.
    */
   ~HashTable() {
+    unlink_external_storage();
+
     if (initialized_) {
       CUDA_CHECK(cudaDeviceSynchronize());
 
@@ -299,12 +304,17 @@ class HashTable {
           load_factor, options_.block_size, stream, n, c_table_index_,
           d_table_, keys, reinterpret_cast<const value_type*>(values), metas);
     } else {
-      const size_type dev_ws_size{n * (sizeof(value_type*) + sizeof(int))};
+      const size_type dev_ws_base_size{n * (sizeof(value_type*) + sizeof(int))};
+      const size_type dev_ws_size{dev_ws_base_size +
+                                  (ext_store_ ? n : 0) * sizeof(key_type)};
       auto dev_ws{dev_mem_pool_->get_workspace<1>(dev_ws_size, stream)};
       auto d_dst{dev_ws.get<value_type**>(0)};
       auto d_src_offset{reinterpret_cast<int*>(d_dst + n)};
+      auto d_evicted_keys{reinterpret_cast<key_type*>(d_src_offset + n)};
 
-      CUDA_CHECK(cudaMemsetAsync(d_dst, 0, dev_ws_size, stream));
+      CUDA_CHECK(cudaMemsetAsync(d_dst, 0, dev_ws_base_size, stream));
+      CUDA_CHECK(cudaMemsetAsync(d_evicted_keys, 0xFF,
+                                 dev_ws_size - dev_ws_base_size, stream));
 
       {
         const size_t block_size = options_.block_size;
@@ -312,8 +322,15 @@ class HashTable {
         const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);
 
         upsert_kernel<key_type, value_type, meta_type, TILE_SIZE>
-            <<<grid_size, block_size, 0, stream>>>(d_table_, keys, d_dst, metas,
-                                                   d_src_offset, N);
+            <<<grid_size, block_size, 0, stream>>>(
+                d_table_, keys, ext_store_ ? d_evicted_keys : nullptr, d_dst,
+                metas, d_src_offset, N);
+      }
+
+      if (ext_store_) {
+        ext_store_->insert_or_assign(
+            *dev_mem_pool_, *host_mem_pool_, table_->is_pure_hbm, n,
+            d_evicted_keys, reinterpret_cast<const value_type* const*>(d_dst), stream);
       }
 
       {
@@ -326,16 +343,17 @@ class HashTable {
       }
 
       if (options_.io_by_cpu) {
-        const size_type host_ws_size{dev_ws_size +
+        const size_type host_ws_size{dev_ws_base_size +
                                      n * sizeof(value_type) * dim()};
         auto host_ws{host_mem_pool_->get_workspace<1>(host_ws_size, stream)};
         auto h_dst{host_ws.get<value_type**>(0)};
         auto h_src_offset{reinterpret_cast<int*>(h_dst + n)};
         auto h_values{reinterpret_cast<value_type*>(h_src_offset + n)};
 
-        CUDA_CHECK(cudaMemcpyAsync(h_dst, d_dst, dev_ws_size,
+        CUDA_CHECK(cudaMemcpyAsync(h_dst, d_dst, dev_ws_base_size,
                                    cudaMemcpyDeviceToHost, stream));
-        CUDA_CHECK(cudaMemcpyAsync(h_values, values, host_ws_size - dev_ws_size,
+        CUDA_CHECK(cudaMemcpyAsync(h_values, values,
+                                   host_ws_size - dev_ws_base_size,
                                    cudaMemcpyDeviceToHost, stream));
 
         CUDA_CHECK(cudaStreamSynchronize(stream));
@@ -547,6 +565,11 @@ class HashTable {
       }
     }
 
+    if (ext_store_) {
+      ext_store_->find(*dev_mem_pool_, *host_mem_pool_, n, keys, values, founds,
+                       stream);
+    }
+
     CudaCheckError();
   }
 
@@ -576,6 +599,10 @@ class HashTable {
           table_->bucket_max_size, table_->buckets_num, N);
     }
 
+    if (ext_store_) {
+      ext_store_->erase_async(*dev_mem_pool_, *host_mem_pool_, n, keys, stream);
+    }
+
     CudaCheckError();
     return;
   }
@@ -1097,6 +1124,21 @@ class HashTable {
     return total_count;
   }
 
+  void link_external_storage(
+      std::shared_ptr<external_storage_type>& ext_store) {
+    MERLIN_CHECK(
+        ext_store->value_dim == dim(),
+        "Provided external storage value dimension is incompatible!");
+
+    std::unique_lock<std::shared_timed_mutex> lock(mutex_);
+    ext_store_ = ext_store;
+  }
+
+  void unlink_external_storage() {
+    std::unique_lock<std::shared_timed_mutex> lock(mutex_);
+    ext_store_.reset();
+  }
+
  private:
   inline bool is_fast_mode() const noexcept { return table_->is_pure_hbm; }
 
@@ -1173,6 +1215,8 @@ class HashTable {
   int c_table_index_ = -1;
   std::unique_ptr<DeviceMemoryPool> dev_mem_pool_;
   std::unique_ptr<HostMemoryPool> host_mem_pool_;
+
+  std::shared_ptr<external_storage_type> ext_store_;
 };
 
 }  // namespace merlin
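For completeness, a hypothetical wiring sketch (not part of this patch) of how the new hooks fit together. It assumes the existing public HashTable/HashTableOptions API and the HostMapStorage sketch above; the header name "host_map_storage.cuh" and the chosen key/value/meta types are illustrative only.

#include <cstdint>
#include <memory>

#include "merlin_hashtable.cuh"
#include "host_map_storage.cuh"  // the HostMapStorage sketch above (hypothetical file name)

int main() {
  using K = uint64_t;
  using V = float;
  using Table = nv::merlin::HashTable<K, V, uint64_t>;

  nv::merlin::HashTableOptions options;
  options.init_capacity = 1 << 20;
  options.max_capacity = 1 << 20;
  options.dim = 64;

  Table table;
  table.init(options);

  // The store must agree with the table on the value dimension;
  // link_external_storage() checks ext_store->value_dim == dim().
  std::shared_ptr<nv::merlin::ExternalStorage<K, V>> store =
      std::make_shared<HostMapStorage<K, V>>(options.dim);
  table.link_external_storage(store);

  // From here on, insert_or_assign() spills evicted pairs into `store`,
  // find() falls back to it for keys the table no longer holds, and
  // erase() forwards deletions through erase_async().

  table.unlink_external_storage();
  return 0;
}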