diff --git a/include/pq_flash_index.h b/include/pq_flash_index.h index 51ed7b128..49a504a07 100644 --- a/include/pq_flash_index.h +++ b/include/pq_flash_index.h @@ -122,6 +122,7 @@ template class PQFlashIndex uint32_t &num_total_labels); DISKANN_DLLEXPORT void generate_random_labels(std::vector &labels, const uint32_t num_labels, const uint32_t nthreads); + void reset_stream_for_reading(std::basic_istream &infile); // sector # on disk where node_id is present with in the graph part DISKANN_DLLEXPORT uint64_t get_node_sector(uint64_t node_id); diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index dc91a7393..a978cb57e 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -602,6 +602,13 @@ LabelT PQFlashIndex::get_converted_label(const std::string &filter_la throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } +template +void PQFlashIndex::reset_stream_for_reading(std::basic_istream &infile) +{ + infile.clear(); + infile.seekg(0); +} + template void PQFlashIndex::get_label_file_metadata(std::basic_istream &infile, uint32_t &num_pts, uint32_t &num_total_labels) @@ -624,7 +631,7 @@ void PQFlashIndex::get_label_file_metadata(std::basic_istream & diskann::cout << "Labels file metadata: num_points: " << num_pts << ", #total_labels: " << num_total_labels << std::endl; - infile.seekg(0); + reset_stream_for_reading(infile); } template @@ -685,8 +692,8 @@ void PQFlashIndex::parse_label_file(std::basic_istream &infile, } line_cnt++; } - infile.seekg(0); num_points_labels = line_cnt; + reset_stream_for_reading(infile); } template void PQFlashIndex::set_universal_label(const LabelT &label)