Skip to content

Commit

Permalink
parse_label_file change for memory
Browse files Browse the repository at this point in the history
  • Loading branch information
NeelamMahapatro committed Oct 27, 2023
1 parent d1cb7e1 commit 56f65a5
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 30 deletions.
4 changes: 3 additions & 1 deletion include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
// determines navigating node of the graph by calculating medoid of datafopt
uint32_t calculate_entry_point();

void parse_label_file(const std::string &label_file, size_t &num_pts_labels);
void parse_label_file(std::basic_istream<char> &label_file, size_t &num_pts);

std::unordered_map<std::string, LabelT> load_label_map(const std::string &map_file);

Expand Down Expand Up @@ -339,6 +339,8 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
IndexSearchContext<LabelT> &context);

private:
void reset_stream_for_reading(std::basic_istream<char> &infile);

// Distance functions
Metric _dist_metric = diskann::L2;
std::shared_ptr<Distance<T>> _distance;
Expand Down
107 changes: 78 additions & 29 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,12 @@ void Index<T, TagT, LabelT>::load(const char *filename, uint32_t num_threads, ui
if (file_exists(labels_file))
{
_label_map = load_label_map(labels_map_file);
parse_label_file(labels_file, label_num_pts);
std::ifstream infile(labels_file, std::ios::binary);
if (infile.fail())
{
throw diskann::ANNException(std::string("Failed to open file ") + labels_file, -1);
}
parse_label_file(infile, label_num_pts);
assert(label_num_pts == data_file_num_pts);
if (file_exists(labels_to_medoids))
{
Expand Down Expand Up @@ -2018,49 +2023,88 @@ LabelT Index<T, TagT, LabelT>::get_converted_label(const std::string &raw_label)
}

template <typename T, typename TagT, typename LabelT>
void Index<T, TagT, LabelT>::parse_label_file(const std::string &label_file, size_t &num_points)
void Index<T, TagT, LabelT>::reset_stream_for_reading(std::basic_istream<char> &infile)
{
infile.clear();
infile.seekg(0);
}

template <typename T, typename TagT, typename LabelT>
void Index<T, TagT, LabelT>::parse_label_file(std::basic_istream<char> &label_file, size_t &num_pts)
{
// Format of Label txt file: filters with comma separators
label_file.seekg(0, std::ios::end);
size_t file_size = label_file.tellg();

std::ifstream infile(label_file);
if (infile.fail())
{
throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1);
}
std::string buffer(file_size, ' ');

std::string line, token;
uint32_t line_cnt = 0;
label_file.seekg(0, std::ios::beg);
label_file.read(&buffer[0], file_size);

while (std::getline(infile, line))
uint32_t line_cnt = 0;
std::string label_str;
size_t cur_pos = 0;
size_t next_pos = 0;
while (cur_pos < file_size && cur_pos != std::string::npos)
{
next_pos = buffer.find('\n', cur_pos);
if (next_pos == std::string::npos)
{
break;
}
cur_pos = next_pos + 1;
line_cnt++;
}
_pts_to_labels.resize(line_cnt, std::vector<LabelT>());

infile.clear();
infile.seekg(0, std::ios::beg);
cur_pos = 0;
next_pos = 0;
line_cnt = 0;

while (std::getline(infile, line))
while (cur_pos < file_size && cur_pos != std::string::npos)
{
std::istringstream iss(line);
std::vector<LabelT> lbls(0);
getline(iss, token, '\t');
std::istringstream new_iss(token);
while (getline(new_iss, token, ','))
next_pos = buffer.find('\n', cur_pos);
if (next_pos == std::string::npos)
{
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
LabelT token_as_num = (LabelT)std::stoul(token);
break;
}
std::vector<LabelT> lbls;
size_t lbl_pos = cur_pos;
size_t next_lbl_pos = 0;
while (lbl_pos < next_pos && lbl_pos != std::string::npos)
{
next_lbl_pos = buffer.find(',', lbl_pos);
if (next_lbl_pos == std::string::npos) // the last label in the whole file
{
next_lbl_pos = next_pos;
}

if (next_lbl_pos > next_pos) // the last label in one line, just read to the end
{
next_lbl_pos = next_pos;
}

label_str.assign(buffer.c_str() + lbl_pos, next_lbl_pos - lbl_pos);
if (label_str[label_str.length() - 1] == '\t' ||
label_str[label_str.length() - 1] == '\r') // '\t' won't exist in label file?
{
label_str.erase(label_str.length() - 1);
}

LabelT token_as_num = (LabelT)std::stoul(label_str);
lbls.push_back(token_as_num);
_labels.insert(token_as_num);
}

std::sort(lbls.begin(), lbls.end());
// move to next label
lbl_pos = next_lbl_pos + 1;
}
_pts_to_labels[line_cnt] = lbls;
line_cnt++;
line_cnt = line_cnt + 1;
// move to next line
cur_pos = next_pos + 1;
}
num_points = (size_t)line_cnt;

num_pts = line_cnt;
reset_stream_for_reading(label_file);
diskann::cout << "Identified " << _labels.size() << " distinct label(s)" << std::endl;
}

Expand All @@ -2081,9 +2125,14 @@ void Index<T, TagT, LabelT>::build_filtered_index(const char *filename, const st
_label_to_medoid_id.clear();
size_t num_points_labels = 0;

parse_label_file(label_file,
num_points_labels); // determines medoid for each label and identifies
// the points to label mapping
std::ifstream infile(label_file, std::ios::binary);
if (infile.fail())
{
throw diskann::ANNException(std::string("Failed to open file ") + _labels_file, -1);
}

parse_label_file(infile, num_points_labels); // determines medoid for each label and identifies
// the points to label mapping

std::unordered_map<LabelT, std::vector<uint32_t>> label_to_points;

Expand Down

0 comments on commit 56f65a5

Please sign in to comment.