diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index 1bb02c9bc..1a9acc285 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -163,7 +163,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, for (int64_t i = 0; i < (int64_t)query_num; i++) { auto qs = std::chrono::high_resolution_clock::now(); - if (filtered_search) + if (filtered_search && !tags) { std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i]; @@ -179,8 +179,19 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, } else if (tags) { - index->search_with_tags(query + i * query_aligned_dim, recall_at, L, - query_result_tags.data() + i * recall_at, nullptr, res); + if (!filtered_search) + { + index->search_with_tags(query + i * query_aligned_dim, recall_at, L, + query_result_tags.data() + i * recall_at, nullptr, res); + } + else + { + std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i]; + + index->search_with_tags(query + i * query_aligned_dim, recall_at, L, + query_result_tags.data() + i * recall_at, nullptr, res, true, raw_filter); + } + for (int64_t r = 0; r < (int64_t)recall_at; r++) { query_result_ids[test_id][recall_at * i + r] = query_result_tags[recall_at * i + r]; diff --git a/include/abstract_index.h b/include/abstract_index.h index 12feec663..059866f7c 100644 --- a/include/abstract_index.h +++ b/include/abstract_index.h @@ -62,7 +62,8 @@ class AbstractIndex // Initialize space for res_vectors before calling. template size_t search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, - float *distances, std::vector &res_vectors); + float *distances, std::vector &res_vectors, bool use_filters = false, + const std::string filter_label = ""); // Added search overload that takes L as parameter, so that we // can customize L on a per-query basis without tampering with "Parameters" @@ -120,7 +121,8 @@ class AbstractIndex virtual void _set_start_points_at_random(DataType radius, uint32_t random_seed = 0) = 0; virtual int _get_vector_by_tag(TagType &tag, DataType &vec) = 0; virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, - float *distances, DataVector &res_vectors) = 0; + float *distances, DataVector &res_vectors, bool use_filters = false, + const std::string filter_label = "") = 0; virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) = 0; virtual void _set_universal_label(const LabelType universal_label) = 0; }; diff --git a/include/index.h b/include/index.h index 199171020..b9bf4f384 100644 --- a/include/index.h +++ b/include/index.h @@ -136,7 +136,8 @@ template clas // Initialize space for res_vectors before calling. DISKANN_DLLEXPORT size_t search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags, - float *distances, std::vector &res_vectors); + float *distances, std::vector &res_vectors, bool use_filters = false, + const std::string filter_label = ""); // Filter support search template @@ -226,7 +227,8 @@ template clas virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) override; virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, - float *distances, DataVector &res_vectors) override; + float *distances, DataVector &res_vectors, bool use_filters = false, + const std::string filter_label = "") override; virtual void _set_universal_label(const LabelType universal_label) override; diff --git a/include/natural_number_map.h b/include/natural_number_map.h index 820ac3fdf..e846882a8 100644 --- a/include/natural_number_map.h +++ b/include/natural_number_map.h @@ -26,9 +26,6 @@ template class natural_number_map { public: static_assert(std::is_trivial::value, "Key must be a trivial type"); - // Some of the class member prototypes are done with this assumption to - // minimize verbosity since it's the only use case. - static_assert(std::is_trivial::value, "Value must be a trivial type"); // Represents a reference to a element in the map. Used while iterating // over map entries. diff --git a/include/tag_uint128.h b/include/tag_uint128.h new file mode 100644 index 000000000..642de3159 --- /dev/null +++ b/include/tag_uint128.h @@ -0,0 +1,68 @@ +#pragma once +#include +#include + +namespace diskann +{ +#pragma pack(push, 1) + +struct tag_uint128 +{ + std::uint64_t _data1 = 0; + std::uint64_t _data2 = 0; + + bool operator==(const tag_uint128 &other) const + { + return _data1 == other._data1 && _data2 == other._data2; + } + + bool operator==(std::uint64_t other) const + { + return _data1 == other && _data2 == 0; + } + + tag_uint128 &operator=(const tag_uint128 &other) + { + _data1 = other._data1; + _data2 = other._data2; + + return *this; + } + + tag_uint128 &operator=(std::uint64_t other) + { + _data1 = other; + _data2 = 0; + + return *this; + } +}; + +#pragma pack(pop) +} // namespace diskann + +namespace std +{ +// Hash 128 input bits down to 64 bits of output. +// This is intended to be a reasonably good hash function. +inline std::uint64_t Hash128to64(const std::uint64_t &low, const std::uint64_t &high) +{ + // Murmur-inspired hashing. + const std::uint64_t kMul = 0x9ddfea08eb382d69ULL; + std::uint64_t a = (low ^ high) * kMul; + a ^= (a >> 47); + std::uint64_t b = (high ^ a) * kMul; + b ^= (b >> 47); + b *= kMul; + return b; +} + +template <> struct hash +{ + size_t operator()(const diskann::tag_uint128 &key) const noexcept + { + return Hash128to64(key._data1, key._data2); // map -0 to 0 + } +}; + +} // namespace std \ No newline at end of file diff --git a/include/utils.h b/include/utils.h index bb03d13f1..d3af5c3a9 100644 --- a/include/utils.h +++ b/include/utils.h @@ -27,6 +27,7 @@ typedef int FileHandle; #include "windows_customizations.h" #include "tsl/robin_set.h" #include "types.h" +#include "tag_uint128.h" #include #ifdef EXEC_ENV_OLS @@ -1007,6 +1008,17 @@ void block_convert(std::ofstream &writr, std::ifstream &readr, float *read_buf, DISKANN_DLLEXPORT void normalize_data_file(const std::string &inFileName, const std::string &outFileName); +inline std::string get_tag_string(std::uint64_t tag) +{ + return std::to_string(tag); +} + +inline std::string get_tag_string(const tag_uint128 &tag) +{ + std::string str = std::to_string(tag._data2) + "_" + std::to_string(tag._data1); + return str; +} + }; // namespace diskann struct PivotContainer diff --git a/src/abstract_index.cpp b/src/abstract_index.cpp index a7a5986cc..92665825f 100644 --- a/src/abstract_index.cpp +++ b/src/abstract_index.cpp @@ -24,12 +24,13 @@ std::pair AbstractIndex::search(const data_type *query, cons template size_t AbstractIndex::search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, - float *distances, std::vector &res_vectors) + float *distances, std::vector &res_vectors, bool use_filters, + const std::string filter_label) { auto any_query = std::any(query); auto any_tags = std::any(tags); auto any_res_vectors = DataVector(res_vectors); - return this->_search_with_tags(any_query, K, L, any_tags, distances, any_res_vectors); + return this->_search_with_tags(any_query, K, L, any_tags, distances, any_res_vectors, use_filters, filter_label); } template @@ -162,61 +163,53 @@ template DISKANN_DLLEXPORT std::pair AbstractIndex::search_w const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const float *query, const uint64_t K, - const uint32_t L, int32_t *tags, - float *distances, - std::vector &res_vectors); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const float *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); -template DISKANN_DLLEXPORT size_t -AbstractIndex::search_with_tags(const uint8_t *query, const uint64_t K, const uint32_t L, - int32_t *tags, float *distances, std::vector &res_vectors); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const uint8_t *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); -template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const int8_t *query, - const uint64_t K, const uint32_t L, - int32_t *tags, float *distances, - std::vector &res_vectors); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const int8_t *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); -template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const float *query, const uint64_t K, - const uint32_t L, uint32_t *tags, - float *distances, - std::vector &res_vectors); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const float *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const uint8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, - std::vector &res_vectors); + std::vector &res_vectors, bool use_filters, const std::string filter_label); -template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const int8_t *query, - const uint64_t K, const uint32_t L, - uint32_t *tags, float *distances, - std::vector &res_vectors); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const int8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); -template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const float *query, const uint64_t K, - const uint32_t L, int64_t *tags, - float *distances, - std::vector &res_vectors); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const float *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); -template DISKANN_DLLEXPORT size_t -AbstractIndex::search_with_tags(const uint8_t *query, const uint64_t K, const uint32_t L, - int64_t *tags, float *distances, std::vector &res_vectors); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const uint8_t *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); -template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const int8_t *query, - const uint64_t K, const uint32_t L, - int64_t *tags, float *distances, - std::vector &res_vectors); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const int8_t *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); -template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const float *query, const uint64_t K, - const uint32_t L, uint64_t *tags, - float *distances, - std::vector &res_vectors); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const float *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const uint8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, - std::vector &res_vectors); + std::vector &res_vectors, bool use_filters, const std::string filter_label); -template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const int8_t *query, - const uint64_t K, const uint32_t L, - uint64_t *tags, float *distances, - std::vector &res_vectors); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const int8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout(const float *query, size_t K, size_t L, uint32_t *indices); diff --git a/src/index.cpp b/src/index.cpp index d906600d1..486d41e76 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -12,6 +12,7 @@ #include "tsl/robin_map.h" #include "tsl/robin_set.h" #include "windows_customizations.h" +#include "tag_uint128.h" #if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD) #include "gperftools/malloc_extension.h" #endif @@ -717,7 +718,7 @@ template int Index std::shared_lock lock(_tag_lock); if (_tag_to_location.find(tag) == _tag_to_location.end()) { - diskann::cout << "Tag " << tag << " does not exist" << std::endl; + diskann::cout << "Tag " << get_tag_string(tag) << " does not exist" << std::endl; return -1; } @@ -2093,24 +2094,7 @@ std::pair Index::search_with_filters(const { if (best_L_nodes[i].id < _max_points) { - // safe because Index uses uint32_t ids internally - // and IDType will be uint32_t or uint64_t - if (_enable_tags) - { - TagT tag; - if (_location_to_tag.try_get(best_L_nodes[i].id, tag)) - { - indices[pos] = (IdType)tag; - } - else - { - continue; - } - } - else - { - indices[pos] = (IdType)best_L_nodes[i].id; - } + indices[pos] = (IdType)best_L_nodes[i].id; if (distances != nullptr) { @@ -2137,12 +2121,13 @@ std::pair Index::search_with_filters(const template size_t Index::_search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, - const TagType &tags, float *distances, DataVector &res_vectors) + const TagType &tags, float *distances, DataVector &res_vectors, + bool use_filters, const std::string filter_label) { try { return this->search_with_tags(std::any_cast(query), K, L, std::any_cast(tags), distances, - res_vectors.get>()); + res_vectors.get>(), use_filters, filter_label); } catch (const std::bad_any_cast &e) { @@ -2156,7 +2141,8 @@ size_t Index::_search_with_tags(const DataType &query, const ui template size_t Index::search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags, - float *distances, std::vector &res_vectors) + float *distances, std::vector &res_vectors, bool use_filters, + const std::string filter_label) { if (K > (uint64_t)L) { @@ -2176,12 +2162,22 @@ size_t Index::search_with_tags(const T *query, const uint64_t K std::shared_lock ul(_update_lock); const std::vector init_ids = get_init_ids(); - const std::vector unused_filter_label; //_distance->preprocess_query(query, _data_store->get_dims(), // scratch->aligned_query()); _data_store->preprocess_query(query, scratch); - iterate_to_fixed_point(scratch, L, init_ids, false, unused_filter_label, true); + if (!use_filters) + { + const std::vector unused_filter_label; + iterate_to_fixed_point(scratch, L, init_ids, false, unused_filter_label, true); + } + else + { + std::vector filter_vec; + auto converted_label = this->get_converted_label(filter_label); + filter_vec.push_back(converted_label); + iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true); + } NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes(); assert(best_L_nodes.size() <= L); @@ -2861,7 +2857,7 @@ int Index::insert_point(const T *point, const TagT tag, const s { assert(_has_built); - if (tag == static_cast(0)) + if (tag == 0) { throw diskann::ANNException("Do not insert point with tag 0. That is " "reserved for points hidden " @@ -2879,7 +2875,7 @@ int Index::insert_point(const T *point, const TagT tag, const s if (labels.empty()) { release_location(location); - std::cerr << "Error: Can't insert point with tag " + std::to_string(tag) + + std::cerr << "Error: Can't insert point with tag " + get_tag_string(tag) + " . there are no labels for the point." << std::endl; return -1; @@ -3047,7 +3043,7 @@ template int Index if (_tag_to_location.find(tag) == _tag_to_location.end()) { - diskann::cerr << "Delete tag not found " << tag << std::endl; + diskann::cerr << "Delete tag not found " << get_tag_string(tag) << std::endl; return -1; } assert(_tag_to_location[tag] < _max_points); @@ -3336,6 +3332,9 @@ template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; +template DISKANN_DLLEXPORT class Index; +template DISKANN_DLLEXPORT class Index; +template DISKANN_DLLEXPORT class Index; // Label with short int 2 byte template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; @@ -3349,6 +3348,9 @@ template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; +template DISKANN_DLLEXPORT class Index; +template DISKANN_DLLEXPORT class Index; +template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT std::pair Index::search( const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); @@ -3477,4 +3479,5 @@ template DISKANN_DLLEXPORT std::pair Index Index::search_with_filters< uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + } // namespace diskann diff --git a/src/natural_number_map.cpp b/src/natural_number_map.cpp index 9050831a2..a996dcf75 100644 --- a/src/natural_number_map.cpp +++ b/src/natural_number_map.cpp @@ -5,6 +5,7 @@ #include #include "natural_number_map.h" +#include "tag_uint128.h" namespace diskann { @@ -111,4 +112,5 @@ template class natural_number_map; template class natural_number_map; template class natural_number_map; template class natural_number_map; +template class natural_number_map; } // namespace diskann