From 50abf40b67c49835b272920d00f380168c5eac13 Mon Sep 17 00:00:00 2001 From: broccoliSpicy Date: Tue, 9 Apr 2024 17:59:14 -0400 Subject: [PATCH] add custom allocator --- c/lib.cpp | 4 ++- c/usearch.h | 5 +++ include/usearch/index_dense.hpp | 12 +++++-- include/usearch/index_plugins.hpp | 54 +++++++++++++++++++++++++++++++ 4 files changed, 71 insertions(+), 4 deletions(-) diff --git a/c/lib.cpp b/c/lib.cpp index cd09ac8a..dff3b684 100644 --- a/c/lib.cpp +++ b/c/lib.cpp @@ -19,7 +19,8 @@ using lantern_storage_t = lantern_external_storage_t; #else using lantern_storage_t = lantern_internal_storage_t; #endif -using index_dense_t = index_dense_gt; +using custom_allocator_t = custom_allocator_gt; +using index_dense_t = index_dense_gt; using add_result_t = typename index_dense_t::add_result_t; using search_result_t = typename index_dense_t::search_result_t; @@ -145,6 +146,7 @@ USEARCH_EXPORT usearch_index_t usearch_init(usearch_init_options_t* options, flo opts.num_centroids = options->num_centroids; opts.num_subvectors = options->num_subvectors; opts.scalar_bytes = bits_per_scalar(scalar_kind) / 8; + index_dense_t::dynamic_allocator_t(options->alloc_func, options->free_func); index_dense_t index = index_dense_t::make(metric, opts, options->num_threads, config, codebook); if (options->retriever != nullptr || options->retriever_mut != nullptr) { diff --git a/c/usearch.h b/c/usearch.h index 31cb3a75..af1e4be5 100644 --- a/c/usearch.h +++ b/c/usearch.h @@ -67,6 +67,9 @@ USEARCH_EXPORT typedef enum usearch_scalar_kind_t { usearch_scalar_b1_k, } usearch_scalar_kind_t; +USEARCH_EXPORT typedef void* (*usearch_alloc_func)(size_t); +USEARCH_EXPORT typedef void (*usearch_free_func)(void *); + USEARCH_EXPORT typedef struct usearch_init_options_t { /** * @brief The metric kind used for distance calculation between vectors. @@ -110,6 +113,8 @@ USEARCH_EXPORT typedef struct usearch_init_options_t { bool pq; size_t num_centroids; size_t num_subvectors; + usearch_alloc_func alloc_func; + usearch_free_func free_func; } usearch_init_options_t; USEARCH_EXPORT typedef struct { diff --git a/include/usearch/index_dense.hpp b/include/usearch/index_dense.hpp index 7d617fbf..80b75d95 100644 --- a/include/usearch/index_dense.hpp +++ b/include/usearch/index_dense.hpp @@ -18,7 +18,7 @@ namespace unum { namespace usearch { -template class index_dense_gt; +template class index_dense_gt; /** * @brief The "magic" sequence helps infer the type of the file. @@ -294,7 +294,8 @@ inline index_dense_metadata_result_t index_dense_metadata_from_buffer(memory_map */ template > // + typename storage_at = storage_v2_at, + typename dynamic_allocator_at = std::allocator> // class index_dense_gt { public: using vector_key_t = key_at; @@ -313,6 +314,8 @@ class index_dense_gt { using serialization_config_t = index_dense_serialization_config_t; using storage_t = storage_at; + using dynamic_allocator_t = dynamic_allocator_at; + private: /// @brief Schema: input buffer, bytes in input buffer, output buffer. using cast_t = std::function; @@ -321,7 +324,10 @@ class index_dense_gt { storage_t, // distance_t, vector_key_t, compressed_slot_t, // dynamic_allocator_t>; - using index_allocator_t = aligned_allocator_gt; + + using dynamic_allocator_traits_t = std::allocator_traits; + using index_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + using member_iterator_t = typename index_t::member_iterator_t; using member_citerator_t = typename index_t::member_citerator_t; diff --git a/include/usearch/index_plugins.hpp b/include/usearch/index_plugins.hpp index b7582360..5d546995 100644 --- a/include/usearch/index_plugins.hpp +++ b/include/usearch/index_plugins.hpp @@ -568,6 +568,60 @@ using executor_default_t = executor_stl_t; #endif +template +class custom_allocator_gt { +public: + // Standard allocator types + using value_type = element_at; + using size_type = std::size_t; + using pointer = element_at*; + using const_pointer = const element_at*; + using difference_type = std::ptrdiff_t; + + // Function pointer types + using alloc_func_ptr = void* (*)(size_t); + using free_func_ptr = void (*)(void*); + + // Rebind mechanism + template struct rebind { + using other = custom_allocator_gt; + }; + + // Constructor + custom_allocator_gt(alloc_func_ptr alloc_func = std::malloc, free_func_ptr free_func = std::free) + : alloc_func(alloc_func), free_func(free_func) {} + + // Allocate memory for n elements + pointer allocate(size_type n) const { + return static_cast(alloc_func(n * sizeof(value_type))); + } + + // Deallocate memory pointed to by p + void deallocate(pointer p, size_type n) const { + free_func(p); + } + + // Maximum size that may be allocated + size_type max_size() const noexcept { + return std::allocator_traits::max_size(*this); + } + + // Compare two allocators for equality (always true for stateless allocators) + template + bool operator==(const custom_allocator_gt&) const noexcept { + return true; + } + + template + bool operator!=(const custom_allocator_gt&) const noexcept { + return false; + } + +private: + alloc_func_ptr alloc_func; + free_func_ptr free_func; +}; + /** * @brief Uses OS-specific APIs for aligned memory allocations. */