From 56f65a5adcce11a7bc06eadf416afef77fdd8fb5 Mon Sep 17 00:00:00 2001 From: Neelam Mahapatro Date: Fri, 27 Oct 2023 11:26:14 +0530 Subject: [PATCH] parse_label_file change for memory --- include/index.h | 4 +- src/index.cpp | 107 +++++++++++++++++++++++++++++++++++------------- 2 files changed, 81 insertions(+), 30 deletions(-) diff --git a/include/index.h b/include/index.h index 9183d9763..d10fb082b 100644 --- a/include/index.h +++ b/include/index.h @@ -247,7 +247,7 @@ template clas // determines navigating node of the graph by calculating medoid of datafopt uint32_t calculate_entry_point(); - void parse_label_file(const std::string &label_file, size_t &num_pts_labels); + void parse_label_file(std::basic_istream &label_file, size_t &num_pts); std::unordered_map load_label_map(const std::string &map_file); @@ -339,6 +339,8 @@ template clas IndexSearchContext &context); private: + void reset_stream_for_reading(std::basic_istream &infile); + // Distance functions Metric _dist_metric = diskann::L2; std::shared_ptr> _distance; diff --git a/src/index.cpp b/src/index.cpp index 4157edcef..e59fe4136 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -603,7 +603,12 @@ void Index::load(const char *filename, uint32_t num_threads, ui if (file_exists(labels_file)) { _label_map = load_label_map(labels_map_file); - parse_label_file(labels_file, label_num_pts); + std::ifstream infile(labels_file, std::ios::binary); + if (infile.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + labels_file, -1); + } + parse_label_file(infile, label_num_pts); assert(label_num_pts == data_file_num_pts); if (file_exists(labels_to_medoids)) { @@ -2018,49 +2023,88 @@ LabelT Index::get_converted_label(const std::string &raw_label) } template -void Index::parse_label_file(const std::string &label_file, size_t &num_points) +void Index::reset_stream_for_reading(std::basic_istream &infile) +{ + infile.clear(); + infile.seekg(0); +} + +template +void Index::parse_label_file(std::basic_istream &label_file, size_t &num_pts) { // Format of Label txt file: filters with comma separators + label_file.seekg(0, std::ios::end); + size_t file_size = label_file.tellg(); - std::ifstream infile(label_file); - if (infile.fail()) - { - throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1); - } + std::string buffer(file_size, ' '); - std::string line, token; - uint32_t line_cnt = 0; + label_file.seekg(0, std::ios::beg); + label_file.read(&buffer[0], file_size); - while (std::getline(infile, line)) + uint32_t line_cnt = 0; + std::string label_str; + size_t cur_pos = 0; + size_t next_pos = 0; + while (cur_pos < file_size && cur_pos != std::string::npos) { + next_pos = buffer.find('\n', cur_pos); + if (next_pos == std::string::npos) + { + break; + } + cur_pos = next_pos + 1; line_cnt++; } _pts_to_labels.resize(line_cnt, std::vector()); - infile.clear(); - infile.seekg(0, std::ios::beg); + cur_pos = 0; + next_pos = 0; line_cnt = 0; - - while (std::getline(infile, line)) + while (cur_pos < file_size && cur_pos != std::string::npos) { - std::istringstream iss(line); - std::vector lbls(0); - getline(iss, token, '\t'); - std::istringstream new_iss(token); - while (getline(new_iss, token, ',')) + next_pos = buffer.find('\n', cur_pos); + if (next_pos == std::string::npos) { - token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - LabelT token_as_num = (LabelT)std::stoul(token); + break; + } + std::vector lbls; + size_t lbl_pos = cur_pos; + size_t next_lbl_pos = 0; + while (lbl_pos < next_pos && lbl_pos != std::string::npos) + { + next_lbl_pos = buffer.find(',', lbl_pos); + if (next_lbl_pos == std::string::npos) // the last label in the whole file + { + next_lbl_pos = next_pos; + } + + if (next_lbl_pos > next_pos) // the last label in one line, just read to the end + { + next_lbl_pos = next_pos; + } + + label_str.assign(buffer.c_str() + lbl_pos, next_lbl_pos - lbl_pos); + if (label_str[label_str.length() - 1] == '\t' || + label_str[label_str.length() - 1] == '\r') // '\t' won't exist in label file? + { + label_str.erase(label_str.length() - 1); + } + + LabelT token_as_num = (LabelT)std::stoul(label_str); lbls.push_back(token_as_num); _labels.insert(token_as_num); - } - std::sort(lbls.begin(), lbls.end()); + // move to next label + lbl_pos = next_lbl_pos + 1; + } _pts_to_labels[line_cnt] = lbls; - line_cnt++; + line_cnt = line_cnt + 1; + // move to next line + cur_pos = next_pos + 1; } - num_points = (size_t)line_cnt; + + num_pts = line_cnt; + reset_stream_for_reading(label_file); diskann::cout << "Identified " << _labels.size() << " distinct label(s)" << std::endl; } @@ -2081,9 +2125,14 @@ void Index::build_filtered_index(const char *filename, const st _label_to_medoid_id.clear(); size_t num_points_labels = 0; - parse_label_file(label_file, - num_points_labels); // determines medoid for each label and identifies - // the points to label mapping + std::ifstream infile(label_file, std::ios::binary); + if (infile.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + _labels_file, -1); + } + + parse_label_file(infile, num_points_labels); // determines medoid for each label and identifies + // the points to label mapping std::unordered_map> label_to_points;