-
Notifications
You must be signed in to change notification settings - Fork 229
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
FilterStore: unifying filter specific logic #452
base: main
Are you sure you want to change the base?
Changes from 29 commits
c445d20
cf389f8
2c07ffc
497abca
b5dfd90
bc92303
77c5444
0fa86c2
1faa5b8
ff7955a
c86f085
275afc1
91eb6c0
a0cd607
9402f01
26aa806
0c73589
9bcedad
887e644
a9ab92f
119ee63
332de43
e471cf9
3992c97
b18ce98
8a5c700
eded185
392c0ec
8cfcd5f
f4b430b
21925ee
1cb4aae
7edc594
8ba9475
11f8be4
e978f98
ccb187c
9ab9445
3cfaf3e
4f1e81b
a7f6b44
6f0f8f5
68e1dbf
615247a
0b9118f
7462bd0
84174bd
86159e8
ce7db4d
82f7182
e669521
df38242
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
#pragma once | ||
#include "common_includes.h" | ||
#include "utils.h" | ||
#include <any> | ||
|
||
namespace diskann | ||
{ | ||
|
||
enum class FilterMatchStrategy | ||
{ | ||
SET_INTERSECTION | ||
rakri marked this conversation as resolved.
Show resolved
Hide resolved
|
||
}; | ||
// This class is responsible for filter actions in index, and should not be used outside. | ||
template <typename label_type> class AbstractFilterStore | ||
NeelamMahapatro marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
public: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I also didn't find methods to expand and shrink the filter store. |
||
DISKANN_DLLEXPORT AbstractFilterStore(const size_t num_points); | ||
virtual ~AbstractFilterStore() = default; | ||
|
||
// needs some internal lock + abstract implementation | ||
DISKANN_DLLEXPORT virtual bool detect_common_filters( | ||
rakri marked this conversation as resolved.
Show resolved
Hide resolved
|
||
uint32_t point_id, bool search_invocation, const std::vector<label_type> &incoming_labels, | ||
const FilterMatchStrategy strategy = FilterMatchStrategy::SET_INTERSECTION) = 0; | ||
|
||
DISKANN_DLLEXPORT virtual const std::vector<label_type> &get_labels_by_location(const location_t point_id) = 0; | ||
DISKANN_DLLEXPORT virtual void set_labels_to_location(const location_t location, | ||
const std::vector<label_type> &labels) = 0; | ||
DISKANN_DLLEXPORT virtual void swap_labels(const location_t location_first, const location_t location_second) = 0; | ||
|
||
DISKANN_DLLEXPORT virtual const tsl::robin_set<label_type> &get_all_label_set() = 0; | ||
rakri marked this conversation as resolved.
Show resolved
Hide resolved
|
||
DISKANN_DLLEXPORT virtual void add_to_label_set(label_type &label) = 0; | ||
rakri marked this conversation as resolved.
Show resolved
Hide resolved
rakri marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// Throws: out of range exception | ||
DISKANN_DLLEXPORT virtual void add_label_to_location(const location_t point_id, label_type label) = 0; | ||
rakri marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// returns internal mapping for given raw_label | ||
DISKANN_DLLEXPORT virtual label_type get_converted_label(const std::string &raw_label) = 0; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In general, methods that are not modifying internal datastructures should be declared const. |
||
|
||
DISKANN_DLLEXPORT virtual void update_medoid_by_label(const label_type &label, const uint32_t new_medoid) = 0; | ||
DISKANN_DLLEXPORT virtual const uint32_t &get_medoid_by_label(const label_type &label) = 0; | ||
DISKANN_DLLEXPORT virtual const std::unordered_map<label_type, uint32_t> &get_labels_to_medoids() = 0; | ||
rakri marked this conversation as resolved.
Show resolved
Hide resolved
|
||
DISKANN_DLLEXPORT virtual bool label_has_medoid(const label_type &label) = 0; | ||
DISKANN_DLLEXPORT virtual void calculate_best_medoids(const size_t num_points_to_load, | ||
const uint32_t num_candidates) = 0; | ||
|
||
// TODO: in future we may accept a set or vector of universal labels | ||
// DISKANN_DLLEXPORT virtual void set_universal_label(label_type universal_label) = 0; | ||
DISKANN_DLLEXPORT virtual void set_universal_labels(const std::vector<std::string> &universal_labels, | ||
bool dynamic_index = false) = 0; | ||
// DISKANN_DLLEXPORT virtual const label_type get_universal_label() const = 0; | ||
|
||
// takes raw label file and then genrate internal mapping file and keep the info of mapping | ||
DISKANN_DLLEXPORT virtual size_t load_raw_labels(const std::string &raw_labels_file) = 0; | ||
|
||
DISKANN_DLLEXPORT virtual void save_labels(const std::string &save_path, const size_t total_points) = 0; | ||
rakri marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// For dynamic filtered build, we compact the data and hence location_to_labels, we need the compacted version of | ||
// raw labels to compute GT correctly. | ||
DISKANN_DLLEXPORT virtual void save_raw_labels(const std::string &save_path, const size_t total_points) = 0; | ||
DISKANN_DLLEXPORT virtual void save_medoids(const std::string &save_path) = 0; | ||
DISKANN_DLLEXPORT virtual void save_label_map(const std::string &save_path) = 0; | ||
DISKANN_DLLEXPORT virtual void save_universal_label(const std::string &save_path) = 0; | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we also need a public load() which calls the protected load* methods? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, need to do this, and perhaps remove the "friend" relation between filter store and index class? |
||
protected: | ||
// This is for internal use and only loads already parsed file | ||
DISKANN_DLLEXPORT virtual size_t load_labels(const std::string &labels_file) = 0; | ||
DISKANN_DLLEXPORT virtual size_t load_medoids(const std::string &labels_to_medoid_file) = 0; | ||
DISKANN_DLLEXPORT virtual void load_label_map(const std::string &labels_map_file) = 0; | ||
DISKANN_DLLEXPORT virtual void load_universal_labels(const std::string &universal_labels_file) = 0; | ||
|
||
private: | ||
size_t _num_points; | ||
|
||
// populates pts_to labels and _labels from given label file | ||
virtual size_t parse_label_file(const std::string &label_file) = 0; | ||
|
||
// mark Index as friend so it can access protected loads | ||
template <typename T, typename TagT, typename LabelT> friend class Index; | ||
}; | ||
|
||
} // namespace diskann |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -103,7 +103,8 @@ class AbstractIndex | |
// memory should be allocated for vec before calling this function | ||
template <typename tag_type, typename data_type> int get_vector_by_tag(tag_type &tag, data_type *vec); | ||
|
||
template <typename label_type> void set_universal_label(const label_type universal_label); | ||
virtual void set_universal_labels(const std::vector<std::string> &raw_universal_labels, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure if Index class should get involved in setting universal labels. Why not set it directly in the filter store when it is created? |
||
bool dynamic_index = false) = 0; | ||
|
||
private: | ||
virtual void _build(const DataType &data, const size_t num_points_to_load, TagVector &tags) = 0; | ||
|
@@ -122,6 +123,5 @@ class AbstractIndex | |
virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, | ||
float *distances, DataVector &res_vectors) = 0; | ||
virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) = 0; | ||
virtual void _set_universal_label(const LabelType universal_label) = 0; | ||
}; | ||
} // namespace diskann |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
#pragma once | ||
#include <abstract_filter_store.h> | ||
|
||
namespace diskann | ||
{ | ||
|
||
// This class is responsible for filter actions in index, and should not be used outside. | ||
template <typename label_type> class InMemFilterStore : public AbstractFilterStore<label_type> | ||
{ | ||
public: | ||
InMemFilterStore(const size_t num_points); | ||
~InMemFilterStore() = default; | ||
|
||
// needs some internal lock | ||
bool detect_common_filters(uint32_t point_id, bool search_invocation, | ||
const std::vector<label_type> &incoming_labels, | ||
const FilterMatchStrategy filter_match_strategy) override; | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. define 2 detect_common_filters for base points and queries |
||
const std::vector<label_type> &get_labels_by_location(const location_t point_id) override; | ||
void set_labels_to_location(const location_t location, const std::vector<label_type> &labels); | ||
void swap_labels(const location_t location_first, const location_t location_second) override; | ||
const tsl::robin_set<label_type> &get_all_label_set() override; | ||
void add_to_label_set(label_type &label) override; | ||
// Throws: out of range exception | ||
void add_label_to_location(const location_t point_id, label_type label) override; | ||
// returns internal mapping for given raw_label | ||
label_type get_converted_label(const std::string &raw_label) override; | ||
|
||
void update_medoid_by_label(const label_type &label, const uint32_t new_medoid) override; | ||
const uint32_t &get_medoid_by_label(const label_type &label) override; | ||
const std::unordered_map<label_type, uint32_t> &get_labels_to_medoids() override; | ||
bool label_has_medoid(const label_type &label) override; | ||
void calculate_best_medoids(const size_t num_points_to_load, const uint32_t num_candidates) override; | ||
|
||
// takes raw universal labels and map them internally. | ||
void set_universal_labels(const std::vector<std::string> &raw_universal_labels, | ||
bool dyanmic_index = false) override; | ||
// const label_type get_universal_label() const; | ||
|
||
// ideally takes raw label file and then genrate internal mapping file and keep the info of mapping | ||
size_t load_raw_labels(const std::string &raw_labels_file) override; | ||
|
||
void save_labels(const std::string &save_path, const size_t total_points) override; | ||
// For dynamic filtered build, we compact the data and hence location_to_labels, we need the compacted version of | ||
// raw labels to compute GT correctly. | ||
void save_raw_labels(const std::string &save_path, const size_t total_points) override; | ||
void save_medoids(const std::string &save_path) override; | ||
void save_label_map(const std::string &save_path) override; | ||
void save_universal_label(const std::string &save_path) override; | ||
|
||
// The function is static so it remains the source of truth across the code. Returns label map | ||
DISKANN_DLLEXPORT static std::unordered_map<std::string, label_type> convert_labels_string_to_int( | ||
const std::string &inFileName, const std::string &outFileName, const std::string &mapFileName, | ||
const std::set<std::string> &raw_universal_labels); | ||
|
||
protected: | ||
// This is for internal use and only loads already parsed file, used by index in during load(). | ||
size_t load_labels(const std::string &labels_file) override; | ||
size_t load_medoids(const std::string &labels_to_medoid_file) override; | ||
void load_label_map(const std::string &labels_map_file) override; | ||
void load_universal_labels(const std::string &universal_labels_file) override; | ||
|
||
private: | ||
size_t _num_points; | ||
std::vector<std::vector<label_type>> _location_to_labels; | ||
rakri marked this conversation as resolved.
Show resolved
Hide resolved
|
||
tsl::robin_set<label_type> _labels; | ||
std::unordered_map<std::string, label_type> _label_map; | ||
|
||
// medoids | ||
std::unordered_map<label_type, uint32_t> _label_to_medoid_id; | ||
std::unordered_map<uint32_t, uint32_t> _medoid_counts; // medoids only happen for filtered index | ||
rakri marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
// universal label | ||
bool _use_universal_label = false; | ||
tsl::robin_set<label_type> _mapped_universal_labels; | ||
std::set<std::string> _raw_universal_labels; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Label_type , string for single filter. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I mean, it does not hurt to keep this. the logic works fine. but if you think we just need single filter for now. feel free to change the logic. |
||
|
||
// populates _loaction_to labels and _labels from given label file | ||
size_t parse_label_file(const std::string &label_file); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. rename filter_utils to gt_filter_utils |
||
|
||
bool detect_common_filters_by_set_intersection(uint32_t point_id, bool search_invocation, | ||
const std::vector<label_type> &incoming_labels); | ||
}; | ||
|
||
} // namespace diskann |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,7 @@ | |
#include "in_mem_data_store.h" | ||
#include "in_mem_graph_store.h" | ||
#include "abstract_index.h" | ||
#include "in_mem_filter_store.h" | ||
|
||
#define OVERHEAD_FACTOR 1.1 | ||
#define EXPAND_IF_FULL 0 | ||
|
@@ -103,7 +104,8 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas | |
const size_t num_points_to_load, | ||
const std::vector<TagT> &tags = std::vector<TagT>()); | ||
|
||
DISKANN_DLLEXPORT void set_universal_label(const LabelT &label); | ||
// DISKANN_DLLEXPORT void set_universal_label(const LabelT &label); | ||
DISKANN_DLLEXPORT void set_universal_labels(const std::vector<std::string> &raw_labels, bool dynamic_index = false); | ||
|
||
// Get converted integer label from string to int map (_label_map) | ||
DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &raw_label); | ||
|
@@ -222,8 +224,6 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas | |
virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, | ||
float *distances, DataVector &res_vectors) override; | ||
|
||
virtual void _set_universal_label(const LabelType universal_label) override; | ||
|
||
// No copy/assign. | ||
Index(const Index<T, TagT, LabelT> &) = delete; | ||
Index<T, TagT, LabelT> &operator=(const Index<T, TagT, LabelT> &) = delete; | ||
|
@@ -334,6 +334,9 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas | |
// Graph related data structures | ||
std::unique_ptr<AbstractGraphStore> _graph_store; | ||
|
||
// Filter Store | ||
std::unique_ptr<AbstractFilterStore<LabelT>> _filter_store; | ||
|
||
char *_opt_graph = nullptr; | ||
|
||
T *_data = nullptr; // coordinates of all base points | ||
|
@@ -369,18 +372,14 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas | |
// Filter Support | ||
|
||
bool _filtered_index = false; | ||
// Location to label is only updated during insert_point(), all other reads are protected by | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. insert_point should take string formatted labels (since exposed to user). conversion to label type should happen internally |
||
// default as a location can only be released at end of consolidate deletes | ||
std::vector<std::vector<LabelT>> _location_to_labels; | ||
tsl::robin_set<LabelT> _labels; | ||
std::string _labels_file; | ||
std::unordered_map<LabelT, uint32_t> _label_to_start_id; | ||
std::unordered_map<uint32_t, uint32_t> _medoid_counts; | ||
|
||
bool _use_universal_label = false; | ||
LabelT _universal_label = 0; | ||
/* std::vector<std::vector<LabelT>> _pts_to_labels; | ||
tsl::robin_set<LabelT> _labels; | ||
std::unordered_map<LabelT, uint32_t> _label_to_medoid_id; | ||
std::unordered_map<uint32_t, uint32_t> _medoid_counts; | ||
bool _use_universal_label = false; | ||
LabelT _universal_label = 0; | ||
std::unordered_map<std::string, LabelT> _label_map;*/ | ||
uint32_t _filterIndexingQueueSize; | ||
std::unordered_map<std::string, LabelT> _label_map; | ||
|
||
// Indexing parameters | ||
uint32_t _indexingQueueSize; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't it be the responsibility of filter store to convert labels from strings to int?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added this here because, the end user does not know if we have an internal mapping and how its done, similarly Index does not know that string labels is a thing, it just expects number labels. So filter store seems like an ideal place to do this conversion. That way it can also keep the mapping logic and also the map itself.