diff --git a/include/abstract_filter_store.h b/include/abstract_filter_store.h index ee9b15dc1..38cd21358 100644 --- a/include/abstract_filter_store.h +++ b/include/abstract_filter_store.h @@ -43,18 +43,24 @@ template class AbstractFilterStore DISKANN_DLLEXPORT virtual size_t populate_labels(const std::string &raw_labels_file, const std::string &raw_universal_label) = 0; - DISKANN_DLLEXPORT virtual void save_labels(const std::string &save_path, const size_t total_points) = 0; + // save labels, labels_map and universal_label to files + DISKANN_DLLEXPORT virtual void save(const std::string &save_path, const size_t total_points) = 0; + + // load labels, labels_map and universal_label to filter store variables & returns total number of points + DISKANN_DLLEXPORT virtual size_t load(const std::string &load_path) = 0; + // For dynamic filtered build, we compact the data and hence location_to_labels, we need the compacted version of // raw labels to compute GT correctly. DISKANN_DLLEXPORT virtual void save_raw_labels(const std::string &save_path, const size_t total_points) = 0; - DISKANN_DLLEXPORT virtual void save_label_map(const std::string &save_path) = 0; - DISKANN_DLLEXPORT virtual void save_universal_label(const std::string &save_path) = 0; protected: // This is for internal use and only loads already parsed file DISKANN_DLLEXPORT virtual size_t load_labels(const std::string &labels_file) = 0; DISKANN_DLLEXPORT virtual void load_label_map(const std::string &labels_map_file) = 0; DISKANN_DLLEXPORT virtual void load_universal_labels(const std::string &universal_labels_file) = 0; + DISKANN_DLLEXPORT virtual void save_labels(const std::string &save_path, const size_t total_points) = 0; + DISKANN_DLLEXPORT virtual void save_label_map(const std::string &save_path) = 0; + DISKANN_DLLEXPORT virtual void save_universal_label(const std::string &save_path) = 0; private: size_t _num_points; diff --git a/include/in_mem_filter_store.h b/include/in_mem_filter_store.h index cbeab45b8..134f3a0a9 100644 --- a/include/in_mem_filter_store.h +++ b/include/in_mem_filter_store.h @@ -36,12 +36,15 @@ template class InMemFilterStore : public AbstractFilterSto // ideally takes raw label file and then genrate internal mapping file and keep the info of mapping size_t populate_labels(const std::string &raw_labels_file, const std::string &raw_universal_label) override; - void save_labels(const std::string &save_path, const size_t total_points) override; + // save labels, labels_map and universal_label to files + void save(const std::string &save_path, const size_t total_points) override; + + // load labels, labels_map and universal_label to filter store variables & returns total number of points + size_t load(const std::string &load_path) override; + // For dynamic filtered build, we compact the data and hence location_to_labels, we need the compacted version of // raw labels to compute GT correctly. void save_raw_labels(const std::string &save_path, const size_t total_points) override; - void save_label_map(const std::string &save_path) override; - void save_universal_label(const std::string &save_path) override; // The function is static so it remains the source of truth across the code. Returns label map DISKANN_DLLEXPORT static std::unordered_map convert_label_to_numeric( @@ -54,6 +57,9 @@ template class InMemFilterStore : public AbstractFilterSto size_t load_labels(const std::string &labels_file) override; void load_label_map(const std::string &labels_map_file) override; void load_universal_labels(const std::string &universal_labels_file) override; + void save_labels(const std::string &save_path, const size_t total_points) override; + void save_label_map(const std::string &save_path) override; + void save_universal_label(const std::string &save_path) override; private: size_t _num_points; diff --git a/src/in_mem_filter_store.cpp b/src/in_mem_filter_store.cpp index e125f8adb..d02ef680e 100644 --- a/src/in_mem_filter_store.cpp +++ b/src/in_mem_filter_store.cpp @@ -159,6 +159,28 @@ void InMemFilterStore::load_universal_labels(const std::string &univ } } +// load labels, labels_map and universal_label to filter store variables & returns total number of points +template size_t InMemFilterStore::load(const std::string &load_path) +{ + const std::string labels_file = load_path + "_labels.txt"; + const std::string labels_map_file = load_path + "_labels_map.txt"; + const std::string universal_label_file = load_path + "_universal_label.txt"; + load_label_map(labels_map_file); + load_universal_labels(universal_label_file); + return load_labels(labels_file); +} + +template +void InMemFilterStore::save(const std::string &save_path, const size_t total_points) +{ + const std::string label_path = save_path + "_labels.txt"; + const std::string universal_label_path = save_path + "_universal_label.txt"; + const std::string label_map_path = save_path + "_labels_map.txt"; + save_label_map(label_map_path); + save_universal_label(universal_label_path); + save_labels(label_path, total_points); +} + template void InMemFilterStore::save_labels(const std::string &save_path, const size_t total_points) { diff --git a/src/index.cpp b/src/index.cpp index 5722683a6..030115f63 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -291,13 +291,12 @@ void Index::save(const char *filename, bool compact_before_save if (_filtered_index) { save_medoids(std::string(filename) + "_labels_to_medoids.txt"); - _filter_store->save_label_map(std::string(filename) + "_labels_map.txt"); - _filter_store->save_universal_label(std::string(filename) + "_universal_label.txt"); - _filter_store->save_labels(std::string(filename) + "_labels.txt", _nd + _num_frozen_pts); + _filter_store->save(std::string(filename), _nd + _num_frozen_pts); // if data was compacted we need a compacted version of corresponding raw labels to compute GT if (compact_before_save && _dynamic_index) { - _filter_store->load_label_map(std::string(filename) + "_labels_map.txt"); + // _label_map is already loaded. Feels this function is not required. + //_filter_store->load_label_map(std::string(filename) + "_labels_map.txt"); _filter_store->save_raw_labels(std::string(filename) + "_raw_labels.txt", _nd + _num_frozen_pts); } } @@ -486,8 +485,7 @@ void Index::load(const char *filename, uint32_t num_threads, ui std::string mem_index_file(filename); std::string labels_file = mem_index_file + "_labels.txt"; - std::string labels_to_medoids = mem_index_file + "_labels_to_medoids.txt"; - std::string labels_map_file = mem_index_file + "_labels_map.txt"; + std::string labels_to_medoids_file = mem_index_file + "_labels_to_medoids.txt"; if (!_save_as_one_file) { @@ -534,11 +532,9 @@ void Index::load(const char *filename, uint32_t num_threads, ui { _filter_store = std::make_unique>(_max_points + _num_frozen_pts); } - _filter_store->load_label_map(labels_map_file); - label_num_pts = _filter_store->load_labels(labels_file); + label_num_pts = _filter_store->load(mem_index_file); assert(label_num_pts == data_file_num_pts); - load_medoids(labels_to_medoids); - _filter_store->load_universal_labels(std::string(filename) + "_universal_label.txt"); + load_medoids(labels_to_medoids_file); } _nd = data_file_num_pts - _num_frozen_pts;