Skip to content

Commit

Permalink
combine load and save method
Browse files Browse the repository at this point in the history
  • Loading branch information
NeelamMahapatro committed Jan 30, 2024
1 parent e669521 commit df38242
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 16 deletions.
12 changes: 9 additions & 3 deletions include/abstract_filter_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,24 @@ template <typename label_type> class AbstractFilterStore
DISKANN_DLLEXPORT virtual size_t populate_labels(const std::string &raw_labels_file,
const std::string &raw_universal_label) = 0;

DISKANN_DLLEXPORT virtual void save_labels(const std::string &save_path, const size_t total_points) = 0;
// save labels, labels_map and universal_label to files
DISKANN_DLLEXPORT virtual void save(const std::string &save_path, const size_t total_points) = 0;

// load labels, labels_map and universal_label to filter store variables & returns total number of points
DISKANN_DLLEXPORT virtual size_t load(const std::string &load_path) = 0;

// For dynamic filtered build, we compact the data and hence location_to_labels, we need the compacted version of
// raw labels to compute GT correctly.
DISKANN_DLLEXPORT virtual void save_raw_labels(const std::string &save_path, const size_t total_points) = 0;
DISKANN_DLLEXPORT virtual void save_label_map(const std::string &save_path) = 0;
DISKANN_DLLEXPORT virtual void save_universal_label(const std::string &save_path) = 0;

protected:
// This is for internal use and only loads already parsed file
DISKANN_DLLEXPORT virtual size_t load_labels(const std::string &labels_file) = 0;
DISKANN_DLLEXPORT virtual void load_label_map(const std::string &labels_map_file) = 0;
DISKANN_DLLEXPORT virtual void load_universal_labels(const std::string &universal_labels_file) = 0;
DISKANN_DLLEXPORT virtual void save_labels(const std::string &save_path, const size_t total_points) = 0;
DISKANN_DLLEXPORT virtual void save_label_map(const std::string &save_path) = 0;
DISKANN_DLLEXPORT virtual void save_universal_label(const std::string &save_path) = 0;

private:
size_t _num_points;
Expand Down
12 changes: 9 additions & 3 deletions include/in_mem_filter_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,15 @@ template <typename label_type> class InMemFilterStore : public AbstractFilterSto
// ideally takes raw label file and then genrate internal mapping file and keep the info of mapping
size_t populate_labels(const std::string &raw_labels_file, const std::string &raw_universal_label) override;

void save_labels(const std::string &save_path, const size_t total_points) override;
// save labels, labels_map and universal_label to files
void save(const std::string &save_path, const size_t total_points) override;

// load labels, labels_map and universal_label to filter store variables & returns total number of points
size_t load(const std::string &load_path) override;

// For dynamic filtered build, we compact the data and hence location_to_labels, we need the compacted version of
// raw labels to compute GT correctly.
void save_raw_labels(const std::string &save_path, const size_t total_points) override;
void save_label_map(const std::string &save_path) override;
void save_universal_label(const std::string &save_path) override;

// The function is static so it remains the source of truth across the code. Returns label map
DISKANN_DLLEXPORT static std::unordered_map<std::string, label_type> convert_label_to_numeric(
Expand All @@ -54,6 +57,9 @@ template <typename label_type> class InMemFilterStore : public AbstractFilterSto
size_t load_labels(const std::string &labels_file) override;
void load_label_map(const std::string &labels_map_file) override;
void load_universal_labels(const std::string &universal_labels_file) override;
void save_labels(const std::string &save_path, const size_t total_points) override;
void save_label_map(const std::string &save_path) override;
void save_universal_label(const std::string &save_path) override;

private:
size_t _num_points;
Expand Down
22 changes: 22 additions & 0 deletions src/in_mem_filter_store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,28 @@ void InMemFilterStore<label_type>::load_universal_labels(const std::string &univ
}
}

// load labels, labels_map and universal_label to filter store variables & returns total number of points
template <typename label_type> size_t InMemFilterStore<label_type>::load(const std::string &load_path)
{
const std::string labels_file = load_path + "_labels.txt";
const std::string labels_map_file = load_path + "_labels_map.txt";
const std::string universal_label_file = load_path + "_universal_label.txt";
load_label_map(labels_map_file);
load_universal_labels(universal_label_file);
return load_labels(labels_file);
}

template <typename label_type>
void InMemFilterStore<label_type>::save(const std::string &save_path, const size_t total_points)
{
const std::string label_path = save_path + "_labels.txt";
const std::string universal_label_path = save_path + "_universal_label.txt";
const std::string label_map_path = save_path + "_labels_map.txt";
save_label_map(label_map_path);
save_universal_label(universal_label_path);
save_labels(label_path, total_points);
}

template <typename label_type>
void InMemFilterStore<label_type>::save_labels(const std::string &save_path, const size_t total_points)
{
Expand Down
16 changes: 6 additions & 10 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,13 +291,12 @@ void Index<T, TagT, LabelT>::save(const char *filename, bool compact_before_save
if (_filtered_index)
{
save_medoids(std::string(filename) + "_labels_to_medoids.txt");
_filter_store->save_label_map(std::string(filename) + "_labels_map.txt");
_filter_store->save_universal_label(std::string(filename) + "_universal_label.txt");
_filter_store->save_labels(std::string(filename) + "_labels.txt", _nd + _num_frozen_pts);
_filter_store->save(std::string(filename), _nd + _num_frozen_pts);
// if data was compacted we need a compacted version of corresponding raw labels to compute GT
if (compact_before_save && _dynamic_index)
{
_filter_store->load_label_map(std::string(filename) + "_labels_map.txt");
// _label_map is already loaded. Feels this function is not required.
//_filter_store->load_label_map(std::string(filename) + "_labels_map.txt");
_filter_store->save_raw_labels(std::string(filename) + "_raw_labels.txt", _nd + _num_frozen_pts);
}
}
Expand Down Expand Up @@ -486,8 +485,7 @@ void Index<T, TagT, LabelT>::load(const char *filename, uint32_t num_threads, ui

std::string mem_index_file(filename);
std::string labels_file = mem_index_file + "_labels.txt";
std::string labels_to_medoids = mem_index_file + "_labels_to_medoids.txt";
std::string labels_map_file = mem_index_file + "_labels_map.txt";
std::string labels_to_medoids_file = mem_index_file + "_labels_to_medoids.txt";

if (!_save_as_one_file)
{
Expand Down Expand Up @@ -534,11 +532,9 @@ void Index<T, TagT, LabelT>::load(const char *filename, uint32_t num_threads, ui
{
_filter_store = std::make_unique<InMemFilterStore<LabelT>>(_max_points + _num_frozen_pts);
}
_filter_store->load_label_map(labels_map_file);
label_num_pts = _filter_store->load_labels(labels_file);
label_num_pts = _filter_store->load(mem_index_file);
assert(label_num_pts == data_file_num_pts);
load_medoids(labels_to_medoids);
_filter_store->load_universal_labels(std::string(filename) + "_universal_label.txt");
load_medoids(labels_to_medoids_file);
}

_nd = data_file_num_pts - _num_frozen_pts;
Expand Down

0 comments on commit df38242

Please sign in to comment.