Skip to content

Commit

Permalink
sync to latest
Browse files Browse the repository at this point in the history
  • Loading branch information
Sanhaoji2 committed Mar 5, 2024
2 parents 52039c2 + 9bb0cf0 commit 7acac92
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 57 deletions.
2 changes: 1 addition & 1 deletion include/abstract_data_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ template <typename data_t> class AbstractDataStore
public:
AbstractDataStore(const location_t capacity, const size_t dim);

// virtual ~AbstractDataStore() = default;
virtual ~AbstractDataStore() = default;

// Return number of points returned
virtual location_t load(const std::string &filename) = 0;
Expand Down
2 changes: 2 additions & 0 deletions include/abstract_graph_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class AbstractGraphStore
{
}

virtual ~AbstractGraphStore() = default;

virtual int load(const std::string &index_path_prefix) = 0;
virtual int store(const std::string &index_path_prefix) = 0;

Expand Down
2 changes: 2 additions & 0 deletions include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
// Get converted integer label from string to int map (_label_map)
DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &raw_label);

DISKANN_DLLEXPORT bool is_label_valid(const std::string& raw_label);

// Set starting point of an index before inserting any points incrementally.
// The data count should be equal to _num_frozen_pts * _aligned_dim.
DISKANN_DLLEXPORT void set_start_points(const T *data, size_t data_count);
Expand Down
6 changes: 5 additions & 1 deletion include/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex

DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &filter_label);

DISKANN_DLLEXPORT bool is_label_valid(const std::string& filter_label);

DISKANN_DLLEXPORT uint32_t range_search(const T *query1, const double range, const uint64_t min_l_search,
const uint64_t max_l_search, std::vector<uint64_t> &indices,
std::vector<float> &distances, const uint64_t min_beam_width,
Expand All @@ -107,11 +109,13 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, uint32_t label_id);
std::unordered_map<std::string, LabelT> load_label_map(const std::string &map_file);
DISKANN_DLLEXPORT void parse_label_file(const std::string &map_file, size_t &num_pts_labels);
DISKANN_DLLEXPORT void get_label_file_metadata(std::string map_file, uint32_t &num_pts, uint32_t &num_total_labels);
DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts,
uint32_t &num_total_labels);
DISKANN_DLLEXPORT inline int32_t get_filter_number(const LabelT &filter_label);
DISKANN_DLLEXPORT void generate_random_labels(std::vector<LabelT> &labels, const uint32_t num_labels,
const uint32_t nthreads);

size_t search_string_range(const std::string& str, char ch, size_t start, size_t end);
// index info
// nhood of node `i` is in sector: [i / nnodes_per_sector]
// offset in sector: [(i % nnodes_per_sector) * max_node_len]
Expand Down
11 changes: 11 additions & 0 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2022,6 +2022,17 @@ LabelT Index<T, TagT, LabelT>::get_converted_label(const std::string &raw_label)
}
}

template <typename T, typename TagT, typename LabelT>
bool Index<T, TagT, LabelT>::is_label_valid(const std::string& raw_label)
{
if (_label_map.find(raw_label) != _label_map.end())
{
return true;
}

return false;
}

template <typename T, typename TagT, typename LabelT>
void Index<T, TagT, LabelT>::parse_label_file(const std::string &label_file, size_t &num_points)
{
Expand Down
160 changes: 105 additions & 55 deletions src/pq_flash_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -579,29 +579,59 @@ LabelT PQFlashIndex<T, LabelT>::get_converted_label(const std::string &filter_la
}

template <typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::get_label_file_metadata(std::string map_file, uint32_t &num_pts,
bool PQFlashIndex<T, LabelT>::is_label_valid(const std::string& filter_label)
{
if (_label_map.find(filter_label) != _label_map.end())
{
return true;
}

return false;
}

// test commit
template <typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts,
uint32_t &num_total_labels)
{
std::ifstream infile(map_file);
std::string line, token;
num_pts = 0;
num_total_labels = 0;

while (std::getline(infile, line))
size_t file_size = fileContent.length();

std::string label_str;
size_t cur_pos = 0;
size_t next_pos = 0;
while (cur_pos < file_size && cur_pos != std::string::npos)
{
std::istringstream iss(line);
while (getline(iss, token, ','))
next_pos = fileContent.find('\n', cur_pos);
if (next_pos == std::string::npos)
{
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
break;
}

size_t lbl_pos = cur_pos;
size_t next_lbl_pos = 0;
while (lbl_pos < next_pos && lbl_pos != std::string::npos)
{
next_lbl_pos = search_string_range(fileContent, ',', lbl_pos, next_pos);
if (next_lbl_pos == std::string::npos) // the last label
{
next_lbl_pos = next_pos;
}

num_total_labels++;

lbl_pos = next_lbl_pos + 1;
}

cur_pos = next_pos + 1;

num_pts++;
}

diskann::cout << "Labels file metadata: num_points: " << num_pts << ", #total_labels: " << num_total_labels
<< std::endl;
infile.close();
}

template <typename T, typename LabelT>
Expand All @@ -624,77 +654,98 @@ inline bool PQFlashIndex<T, LabelT>::point_has_label(uint32_t point_id, uint32_t
template <typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::parse_label_file(const std::string &label_file, size_t &num_points_labels)
{
std::ifstream infile(label_file);
std::ifstream infile(label_file, std::ios::binary);
if (infile.fail())
{
throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1);
}
infile.seekg(0, std::ios::end);
size_t file_size = infile.tellg();

std::string buffer(file_size, ' ');

infile.seekg(0, std::ios::beg);
infile.read(&buffer[0], file_size);
infile.close();

std::string line, token;
uint32_t line_cnt = 0;

uint32_t num_pts_in_label_file;
uint32_t num_total_labels;
get_label_file_metadata(label_file, num_pts_in_label_file, num_total_labels);
get_label_file_metadata(buffer, num_pts_in_label_file, num_total_labels);

_pts_to_label_offsets = new uint32_t[num_pts_in_label_file];
_pts_to_labels = new uint32_t[num_pts_in_label_file + num_total_labels];
uint32_t counter = 0;

while (std::getline(infile, line))
std::string label_str;
size_t cur_pos = 0;
size_t next_pos = 0;
while (cur_pos < file_size && cur_pos != std::string::npos)
{
std::istringstream iss(line);
std::vector<uint32_t> lbls(0);
next_pos = buffer.find('\n', cur_pos);
if (next_pos == std::string::npos)
{
break;
}

_pts_to_label_offsets[line_cnt] = counter;
uint32_t &num_lbls_in_cur_pt = _pts_to_labels[counter];
num_lbls_in_cur_pt = 0;
counter++;
getline(iss, token, '\t');
std::istringstream new_iss(token);
while (getline(new_iss, token, ','))

size_t lbl_pos = cur_pos;
size_t next_lbl_pos = 0;
while (lbl_pos < next_pos && lbl_pos != std::string::npos)
{
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
LabelT token_as_num = (LabelT)std::stoul(token);
if (_labels.find(token_as_num) == _labels.end())
next_lbl_pos = search_string_range(buffer, ',', lbl_pos, next_pos);
if (next_lbl_pos == std::string::npos) // the last label in the whole file
{
_filter_list.emplace_back(token_as_num);
next_lbl_pos = next_pos;
}
int32_t filter_num = get_filter_number(token_as_num);
if (filter_num == -1)

if (next_lbl_pos > next_pos) // the last label in one line
{
next_lbl_pos = next_pos;
}

label_str.assign(buffer.c_str() + lbl_pos, next_lbl_pos - lbl_pos);
if (label_str[label_str.length() - 1] == '\t')
{
label_str.erase(label_str.length() - 1);
}

LabelT token_as_num = (LabelT)std::stoul(label_str);
if (_labels.find(token_as_num) == _labels.end())
{
diskann::cout << "Error!! " << std::endl;
exit(-1);
_filter_list.emplace_back(token_as_num);
}
_pts_to_labels[counter++] = filter_num;

_pts_to_labels[counter++] = token_as_num;
num_lbls_in_cur_pt++;
_labels.insert(token_as_num);

lbl_pos = next_lbl_pos + 1;
}

cur_pos = next_pos + 1;

if (num_lbls_in_cur_pt == 0)
{
diskann::cout << "No label found for point " << line_cnt << std::endl;
exit(-1);
}

line_cnt++;
}
infile.close();

num_points_labels = line_cnt;
}

template <typename T, typename LabelT> void PQFlashIndex<T, LabelT>::set_universal_label(const LabelT &label)
{
int32_t temp_filter_num = get_filter_number(label);
if (temp_filter_num == -1)
{
diskann::cout << "Error, could not find universal label." << std::endl;
}
else
{
_use_universal_label = true;
_universal_filter_num = (uint32_t)temp_filter_num;
}
_use_universal_label = true;
_universal_filter_num = (uint32_t)label;
}

#ifdef EXEC_ENV_OLS
Expand Down Expand Up @@ -1150,22 +1201,7 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
const uint32_t io_limit, const bool use_reorder_data,
QueryStats *stats)
{
int32_t filter_num = 0;
if (use_filter)
{
filter_num = get_filter_number(filter_label);
if (filter_num < 0)
{
if (!_use_universal_label)
{
return;
}
else
{
filter_num = _universal_filter_num;
}
}
}
int32_t filter_num = filter_label;

if (beam_width > MAX_N_SECTOR_READS)
throw ANNException("Beamwidth can not be higher than MAX_N_SECTOR_READS", -1, __FUNCSIG__, __FILE__, __LINE__);
Expand Down Expand Up @@ -1637,6 +1673,20 @@ template <typename T, typename LabelT> diskann::Metric PQFlashIndex<T, LabelT>::
return this->metric;
}

template <typename T, typename LabelT>
size_t PQFlashIndex<T, LabelT>::search_string_range(const std::string& str, char ch, size_t start, size_t end)
{
for (; start != end; start++)
{
if (str[start] == ch)
{
return start;
}
}

return std::string::npos;
}

#ifdef EXEC_ENV_OLS
template <typename T, typename LabelT> char *PQFlashIndex<T, LabelT>::getHeaderBytes()
{
Expand Down

0 comments on commit 7acac92

Please sign in to comment.