Skip to content

Commit

Permalink
change raw_labels to populate_labels
Browse files Browse the repository at this point in the history
  • Loading branch information
NeelamMahapatro committed Jan 29, 2024
1 parent 82f7182 commit e669521
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 81 deletions.
5 changes: 1 addition & 4 deletions include/abstract_filter_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ template <typename label_type> class AbstractFilterStore
DISKANN_DLLEXPORT virtual std::pair<bool, label_type> get_universal_label() = 0;

// takes raw label file and then genrate internal mapping file and keep the info of mapping
DISKANN_DLLEXPORT virtual size_t load_raw_labels(const std::string &raw_labels_file,
DISKANN_DLLEXPORT virtual size_t populate_labels(const std::string &raw_labels_file,
const std::string &raw_universal_label) = 0;

DISKANN_DLLEXPORT virtual void save_labels(const std::string &save_path, const size_t total_points) = 0;
Expand All @@ -59,9 +59,6 @@ template <typename label_type> class AbstractFilterStore
private:
size_t _num_points;

// populates pts_to labels and _labels from given label file
virtual size_t parse_label_file(const std::string &label_file) = 0;

// mark Index as friend so it can access protected loads
template <typename T, typename TagT, typename LabelT> friend class Index;
};
Expand Down
6 changes: 2 additions & 4 deletions include/in_mem_filter_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ template <typename label_type> class InMemFilterStore : public AbstractFilterSto
std::pair<bool, label_type> get_universal_label() override;

// ideally takes raw label file and then genrate internal mapping file and keep the info of mapping
size_t load_raw_labels(const std::string &raw_labels_file, const std::string &raw_universal_label) override;
size_t populate_labels(const std::string &raw_labels_file, const std::string &raw_universal_label) override;

void save_labels(const std::string &save_path, const size_t total_points) override;
// For dynamic filtered build, we compact the data and hence location_to_labels, we need the compacted version of
Expand All @@ -50,6 +50,7 @@ template <typename label_type> class InMemFilterStore : public AbstractFilterSto

protected:
// This is for internal use and only loads already parsed file, used by index in during load().
// populates _loaction_to labels and _labels from given label file
size_t load_labels(const std::string &labels_file) override;
void load_label_map(const std::string &labels_map_file) override;
void load_universal_labels(const std::string &universal_labels_file) override;
Expand All @@ -69,9 +70,6 @@ template <typename label_type> class InMemFilterStore : public AbstractFilterSto
// 2. from _label_map and _mapped_universal_label, we can know what is raw universal label. Hence seems duplicate
// std::string _raw_universal_label;

// populates _loaction_to labels and _labels from given label file
size_t parse_label_file(const std::string &label_file);

bool detect_common_filters_by_set_intersection(uint32_t point_id, bool search_invocation,
const std::vector<label_type> &incoming_labels);
};
Expand Down
7 changes: 0 additions & 7 deletions include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,6 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
const IndexFilterParams &filter_params,
const std::vector<TagT> &tags = std::vector<TagT>());

// Filtered support streaming index
DISKANN_DLLEXPORT void build_filtered_index(const T *data, const size_t num_points_to_load,
const IndexFilterParams &filter_params,
const std::vector<TagT> &tags = std::vector<TagT>());

// DISKANN_DLLEXPORT void set_universal_label(const LabelT &label);
DISKANN_DLLEXPORT void set_universal_label(const std::string &raw_labels);

Expand Down Expand Up @@ -249,8 +244,6 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
// determines navigating node of the graph by calculating medoid of datafopt
uint32_t calculate_entry_point();

void parse_label_file(const std::string &label_file, size_t &num_pts_labels);

// Returns the locations of start point and frozen points suitable for use
// with iterate_to_fixed_point.
std::vector<uint32_t> get_init_ids();
Expand Down
47 changes: 0 additions & 47 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,53 +176,6 @@ inline int delete_file(const std::string &fileName)
}
}

inline void convert_label_to_numeric(const std::string &inFileName, const std::string &outFileName,
const std::string &mapFileName, const std::string &unv_label)
{
std::unordered_map<std::string, uint32_t> string_int_map;
std::ofstream label_writer(outFileName);
std::ifstream label_reader(inFileName);
if (unv_label != "")
string_int_map[unv_label] = 0; // if universal label is provided map it to 0 always
std::string line, token;
while (std::getline(label_reader, line))
{
std::istringstream new_iss(line);
std::vector<uint32_t> lbls;
while (getline(new_iss, token, ','))
{
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
if (string_int_map.find(token) == string_int_map.end())
{
uint32_t nextId = (uint32_t)string_int_map.size() + 1;
string_int_map[token] = nextId; // nextId can never be 0
}
lbls.push_back(string_int_map[token]);
}
if (lbls.size() <= 0)
{
std::cout << "No label found";
exit(-1);
}
for (size_t j = 0; j < lbls.size(); j++)
{
if (j != lbls.size() - 1)
label_writer << lbls[j] << ",";
else
label_writer << lbls[j] << std::endl;
}
}
label_writer.close();

std::ofstream map_writer(mapFileName);
for (auto mp : string_int_map)
{
map_writer << mp.first << "\t" << mp.second << std::endl;
}
map_writer.close();
}

#ifdef EXEC_ENV_OLS
class AlignedFileReader;
#endif
Expand Down
17 changes: 7 additions & 10 deletions src/in_mem_filter_store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ template <typename label_type> std::pair<bool, label_type> InMemFilterStore<labe

// ideally takes raw label file and then genrate internal mapping and keep the info of mapping
template <typename label_type>
size_t InMemFilterStore<label_type>::load_raw_labels(const std::string &raw_labels_file,
size_t InMemFilterStore<label_type>::populate_labels(const std::string &raw_labels_file,
const std::string &raw_universal_label)
{
std::string raw_label_file_path =
Expand All @@ -105,13 +105,7 @@ size_t InMemFilterStore<label_type>::load_raw_labels(const std::string &raw_labe
std::string mem_labels_int_map_file = raw_label_file_path + "_labels_map.txt";
_label_map = InMemFilterStore::convert_label_to_numeric(raw_labels_file, labels_file_to_use,
mem_labels_int_map_file, raw_universal_label);
return parse_label_file(labels_file_to_use);
}

template <typename label_type> size_t InMemFilterStore<label_type>::load_labels(const std::string &labels_file)
{
// parse the generated label file
return parse_label_file(labels_file);
return load_labels(labels_file_to_use);
}

template <typename label_type> void InMemFilterStore<label_type>::load_label_map(const std::string &labels_map_file)
Expand All @@ -137,7 +131,7 @@ template <typename label_type> void InMemFilterStore<label_type>::load_label_map
// TODO: throw exception from here and also make sure filtered_index is set appropriately for both build and
// search of index.
diskann::cout << "Warning: Can't load label map file please make sure it was generate, either by "
"filter_store->load_raw_labels() "
"filter_store->populate_labels() "
"then index->save() or convert_label_to_numeric() method in case of dynamic index"
<< std::endl;
}
Expand Down Expand Up @@ -269,7 +263,7 @@ template <typename label_type> label_type InMemFilterStore<label_type>::get_nume
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}

template <typename label_type> size_t InMemFilterStore<label_type>::parse_label_file(const std::string &label_file)
template <typename label_type> size_t InMemFilterStore<label_type>::load_labels(const std::string &label_file)
{
// Format of Label txt file: filters with comma separators
// Format of Label txt file: filters with comma separators
Expand Down Expand Up @@ -354,6 +348,9 @@ std::unordered_map<std::string, label_type> InMemFilterStore<label_type>::conver
std::ofstream label_writer(outFileName);
std::ifstream label_reader(inFileName);
std::string line, token;
if (raw_universal_label != "")
string_int_map[raw_universal_label] = 0; // if universal label is provided map it to 0 always

while (std::getline(label_reader, line))
{
std::istringstream new_iss(line);
Expand Down
12 changes: 3 additions & 9 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1743,12 +1743,6 @@ void Index<T, TagT, LabelT>::build(const std::string &data_file, const size_t nu
std::cout << "Indexing time: " << diff.count() << "\n";
}

template <typename T, typename TagT, typename LabelT>
void Index<T, TagT, LabelT>::parse_label_file(const std::string &label_file, size_t &num_points)
{
num_points = _filter_store->load_labels(label_file);
}

template <typename T, typename TagT, typename LabelT>
void Index<T, TagT, LabelT>::set_universal_label(const std::string &raw_label)
{
Expand All @@ -1760,7 +1754,7 @@ void Index<T, TagT, LabelT>::build_filtered_index(const char *filename, const si
const IndexFilterParams &filter_params, const std::vector<TagT> &tags)
{
_filtered_index = true;
size_t num_points_labels = _filter_store->load_raw_labels(filter_params.label_file, "");
size_t num_points_labels = _filter_store->populate_labels(filter_params.label_file, "");
if (filter_params.universal_label != "")
{
_filter_store->set_universal_label(filter_params.universal_label);
Expand All @@ -1769,7 +1763,7 @@ void Index<T, TagT, LabelT>::build_filtered_index(const char *filename, const si
this->build(filename, num_points_to_load, tags);
}

template <typename T, typename TagT, typename LabelT>
/*template <typename T, typename TagT, typename LabelT>
void Index<T, TagT, LabelT>::build_filtered_index(const T *data, const size_t num_points_to_load,
const IndexFilterParams &filter_params, const std::vector<TagT> &tags)
{
Expand All @@ -1781,7 +1775,7 @@ void Index<T, TagT, LabelT>::build_filtered_index(const T *data, const size_t nu
}
calculate_best_medoids(num_points_to_load, 25);
this->build(data, num_points_to_load, tags);
}
}*/

template <typename T, typename TagT, typename LabelT>
std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::_search(const DataType &query, const size_t K, const uint32_t L,
Expand Down

0 comments on commit e669521

Please sign in to comment.