diff --git a/include/merlin/external_storage.cuh b/include/merlin/external_storage.cuh new file mode 100644 index 000000000..2670887f0 --- /dev/null +++ b/include/merlin/external_storage.cuh @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include "merlin/memory_pool.cuh" + +namespace nv { +namespace merlin { + +template +class ExternalStorage { + public: + using size_type = size_t; + using key_type = Key; + using value_type = Value; + + using dev_mem_pool_type = MemoryPool>; + using host_mem_pool_type = MemoryPool>; + + const size_type value_dim; + ExternalStorage() = delete; + virtual ~ExternalStorage() = default; + + /** + * Constructs external storage object. + * + * @param value_dim The dimensionality of the values. In other words, each + * value stored is exactly `value_dim * sizeof(value_type)` bytes large. + */ + ExternalStorage(const size_type value_dim) : value_dim{value_dim} {} + + /** + * @brief Inserts key/value pairs into the external storage that are about to + * be evicted from the Merlin hashtable. If a key/value pair already exists, + * overwrites the current value. + * + * @param dev_mem_pool Memory pool for temporarily allocating device memory. + * @param host_mem_pool Memory pool for temporarily allocating host memory. + * @param hkvs_is_pure_hbm True if the Merlin hashtable store is currently + * operating in pure HBM mode, false otherwise. 
In pure HBM mode, all `values` + * pointers are GUARANTEED to point to device memory. + * @param n Number of key/value slots provided in other arguments. + * @param d_masked_keys Device pointer to an (n)-sized array of keys. + * Key-Value slots that should be ignored have the key set to `EMPTY_KEY`. + * @param d_values Device pointer to an (n)-sized array containing pointers to + * respectively a memory location where the current values for a key are + * stored. Each pointer points to a vector of length `value_dim`. Pointers + * *can* be set to `nullptr` for slots where the corresponding key equals + * the `EMPTY_KEY`. The memory locations can be device or host memory (see + * also `hkvs_is_pure_hbm`). + * @param stream Stream that MUST be used for queuing asynchronous CUDA + * operations. If only the input arguments or resources obtained from + * respectively `dev_mem_pool` and `host_mem_pool` are used for such + * operations, it is not necessary to synchronize the stream prior to + * returning from the function. + */ + virtual void insert_or_assign(dev_mem_pool_type& dev_mem_pool, + host_mem_pool_type& host_mem_pool, + bool hkvs_is_pure_hbm, size_type n, + const key_type* d_masked_keys, // (n) + const value_type* const* d_values, // (n) + cudaStream_t stream) = 0; + + /** + * @brief Attempts to find the supplied `d_keys` if the corresponding + * `d_founds`-flag is `false` and fills the stored values into the supplied + * memory locations (i.e. in `d_values`). + * + * @param dev_mem_pool Memory pool for temporarily allocating device memory. + * @param host_mem_pool Memory pool for temporarily allocating host memory. + * @param n Number of key/value slots provided in other arguments. + * @param d_keys Device pointer to an (n)-sized array of keys. + * @param d_values Device pointer to an (n * value_dim)-sized array to store + * the retrieved `d_values`. 
For slots where the corresponding `d_founds`-flag + * is not `false`, the value may already have been assigned and, thus, MUST + * not be altered. + * @param d_founds Device pointer to an (n)-sized array which indicates + * whether the corresponding `d_values` slot is already filled or not. So, if + * and only if `d_founds` is still false, the implementation shall attempt to + * retrieve and fill in the value for the corresponding key. If a key/value + * was retrieved successfully from external storage, the implementation MUST + * also set `d_founds` to `true`. + * @param stream Stream that MUST be used for queuing asynchronous CUDA + * operations. If only the input arguments or resources obtained from + * respectively `dev_mem_pool` and `host_mem_pool` are used for such + * operations, it is not necessary to synchronize the stream prior to + * returning from the function. + */ + virtual void find(dev_mem_pool_type& dev_mem_pool, + host_mem_pool_type& host_mem_pool, size_type n, + const key_type* d_keys, // (n) + value_type* d_values, // (n * value_dim) + bool* d_founds, // (n) + cudaStream_t stream) = 0; +}; + +} // namespace merlin +} // namespace nv diff --git a/include/merlin_hashtable.cuh b/include/merlin_hashtable.cuh index cf71a816f..d5d78e593 100644 --- a/include/merlin_hashtable.cuh +++ b/include/merlin_hashtable.cuh @@ -25,6 +25,7 @@ #include #include #include "merlin/core_kernels.cuh" +#include "merlin/external_storage.cuh" #include "merlin/flexible_buffer.cuh" #include "merlin/memory_pool.cuh" #include "merlin/types.cuh" @@ -160,6 +161,8 @@ class HashTable { using DeviceMemoryPool = MemoryPool>; using HostMemoryPool = MemoryPool>; + using external_storage_type = ExternalStorage; + #if THRUST_VERSION >= 101600 static constexpr auto thrust_par = thrust::cuda::par_nosync; #else @@ -179,6 +182,8 @@ class HashTable { ~HashTable() { CUDA_CHECK(cudaDeviceSynchronize()); + unlink_external_storage(); + // Erase table. 
if (initialized_) { destroy_table(&table_); @@ -308,9 +313,12 @@ class HashTable { } } else { const size_t dev_ws_size = n * (sizeof(vector_type*) + sizeof(int)); - auto dev_ws = dev_mem_pool_->get_workspace<1>(dev_ws_size, stream); + auto dev_ws = dev_mem_pool_->get_workspace<1>( + dev_ws_size + (ext_store_ ? n * sizeof(key_type) : 0), stream); auto d_dst = dev_ws.get(0); auto d_src_offset = reinterpret_cast(d_dst + n); + auto d_evicted_keys = + ext_store_ ? reinterpret_cast(d_src_offset + n) : nullptr; CUDA_CHECK(cudaMemsetAsync(d_dst, 0, dev_ws_size, stream)); @@ -322,18 +330,26 @@ class HashTable { if (metas == nullptr) { upsert_kernel <<>>( - table_, keys, d_dst, table_->buckets, table_->buckets_size, - table_->bucket_max_size, table_->buckets_num, d_src_offset, - N); + table_, keys, + /* d_evicted_keys, */ d_dst, table_->buckets, + table_->buckets_size, table_->bucket_max_size, + table_->buckets_num, d_src_offset, N); } else { upsert_kernel <<>>( - table_, keys, d_dst, metas, table_->buckets, + table_, keys, + /* d_evicted_keys, */ d_dst, metas, table_->buckets, table_->buckets_size, table_->bucket_max_size, table_->buckets_num, d_src_offset, N); } } + if (ext_store_) { + ext_store_->insert_or_assign( + *dev_mem_pool_, *host_mem_pool_, table_->is_pure_hbm, n, + d_evicted_keys, reinterpret_cast(d_dst), stream); + } + { thrust::device_ptr d_dst_ptr( reinterpret_cast(d_dst)); @@ -575,6 +591,11 @@ class HashTable { } } + if (ext_store_) { + ext_store_->find(*dev_mem_pool_, *host_mem_pool_, n, keys, values, founds, + stream); + } + CudaCheckError(); } @@ -1113,6 +1134,21 @@ class HashTable { return total_count; } + void link_external_storage( + const std::shared_ptr& ext_store) { + MERLIN_CHECK( + ext_store != nullptr && ext_store->value_dim == DIM, + "Provided external storage is null or has an incompatible value_dim!"); + + std::unique_lock lock(mutex_); + ext_store_ = ext_store; + } + + void unlink_external_storage() { + std::unique_lock lock(mutex_); + 
ext_store_.reset(); + } + private: inline 
bool is_fast_mode() const noexcept { return table_->is_pure_hbm; } @@ -1171,6 +1207,8 @@ class HashTable { std::unique_ptr dev_mem_pool_; std::unique_ptr host_mem_pool_; + + std::shared_ptr ext_store_; }; } // namespace merlin