Skip to content

Commit

Permalink
fix in builder
Browse files Browse the repository at this point in the history
  • Loading branch information
NeelamMahapatro committed Jan 28, 2024
1 parent 86159e8 commit ce7db4d
Show file tree
Hide file tree
Showing 12 changed files with 150 additions and 132 deletions.
23 changes: 14 additions & 9 deletions apps/test_insert_deletes_consolidate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,13 @@ void build_incremental_index(const std::string &data_path, diskann::IndexWritePa
size_t max_points_to_insert, size_t beginning_index_size, float start_point_norm,
uint32_t num_start_pts, size_t points_per_checkpoint, size_t checkpoints_per_snapshot,
const std::string &save_path, size_t points_to_delete_from_beginning,
size_t start_deletes_after, bool concurrent, const std::string &label_file,
const std::string &universal_label)
size_t start_deletes_after, bool concurrent, diskann::IndexFilterParams &filter_params)
{
size_t dim, aligned_dim;
size_t num_points;
diskann::get_bin_metadata(data_path, num_points, dim);
aligned_dim = ROUND_UP(dim, 8);
bool has_labels = label_file != "";
bool has_labels = filter_params.label_file != "";
using TagT = uint32_t;
using LabelT = uint32_t;

Expand Down Expand Up @@ -188,9 +187,9 @@ void build_incremental_index(const std::string &data_path, diskann::IndexWritePa
auto index = index_factory.create_instance();

/* remove set_universal_label from here and set it through filter store only*/
if (universal_label != "")
if (filter_params.universal_label != "")
{
index->set_universal_labels(universal_label);
index->set_universal_label(filter_params.universal_label);
}

if (points_to_skip > num_points)
Expand Down Expand Up @@ -263,7 +262,7 @@ void build_incremental_index(const std::string &data_path, diskann::IndexWritePa
points_to_delete_from_beginning, last_point_threshold);
if (has_labels)
{
auto parse_result = diskann::parse_raw_label_file(label_file);
auto parse_result = diskann::parse_raw_label_file(filter_params.label_file);
location_to_labels = std::get<0>(parse_result);
}

Expand Down Expand Up @@ -499,21 +498,27 @@ int main(int argc, char **argv)
.with_filter_list_size(Lf)
.build();

diskann::IndexFilterParams filter_params = diskann::IndexFilterParamsBuilder()
.with_universal_label(universal_label)
.with_label_file(label_file)
.with_save_path_prefix(index_path_prefix)
.build();

if (data_type == std::string("int8"))
build_incremental_index<int8_t>(
data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm,
num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix,
points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label);
points_to_delete_from_beginning, start_deletes_after, concurrent, filter_params);
else if (data_type == std::string("uint8"))
build_incremental_index<uint8_t>(
data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm,
num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix,
points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label);
points_to_delete_from_beginning, start_deletes_after, concurrent, filter_params);
else if (data_type == std::string("float"))
build_incremental_index<float>(data_path, params, points_to_skip, max_points_to_insert,
beginning_index_size, start_point_norm, num_start_pts, points_per_checkpoint,
checkpoints_per_snapshot, index_path_prefix, points_to_delete_from_beginning,
start_deletes_after, concurrent, label_file, universal_label);
start_deletes_after, concurrent, filter_params);
else
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
}
Expand Down
2 changes: 1 addition & 1 deletion apps/test_streaming_scenario.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ void build_incremental_index(const std::string &data_path, const uint32_t L, con

if (universal_label != "")
{
index->set_universal_labels(universal_label);
index->set_universal_label(universal_label);
}

if (max_points_to_insert == 0)
Expand Down
2 changes: 1 addition & 1 deletion include/abstract_filter_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ template <typename label_type> class AbstractFilterStore

// TODO: in future we may accept a set or vector of universal labels
// DISKANN_DLLEXPORT virtual void set_universal_label(label_type universal_label) = 0;
DISKANN_DLLEXPORT virtual void set_universal_labels(const std::string &universal_labels) = 0;
DISKANN_DLLEXPORT virtual void set_universal_label(const std::string &universal_labels) = 0;
DISKANN_DLLEXPORT virtual std::pair<bool, label_type> get_universal_label() = 0;

// takes a raw label file, generates the internal mapping file, and keeps the mapping info
Expand Down
2 changes: 1 addition & 1 deletion include/abstract_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class AbstractIndex
template <typename tag_type, typename data_type> int get_vector_by_tag(tag_type &tag, data_type *vec);

// required for dynamic index (they don't use filter store / data store yet)
virtual void set_universal_labels(const std::string &raw_universal_labels) = 0;
virtual void set_universal_label(const std::string &raw_universal_labels) = 0;

private:
virtual void _build(const DataType &data, const size_t num_points_to_load, TagVector &tags) = 0;
Expand Down
2 changes: 1 addition & 1 deletion include/in_mem_filter_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ template <typename label_type> class InMemFilterStore : public AbstractFilterSto
label_type get_numeric_label(const std::string &raw_label) override;

// takes raw universal labels and map them internally.
void set_universal_labels(const std::string &raw_universal_labels) override;
void set_universal_label(const std::string &raw_universal_labels) override;
std::pair<bool, label_type> get_universal_label() override;

// ideally takes a raw label file, generates the internal mapping file, and keeps the mapping info
Expand Down
11 changes: 8 additions & 3 deletions include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,17 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
IndexFilterParams &filter_params);

// Filtered Support
DISKANN_DLLEXPORT void build_filtered_index(const char *filename, const std::string &label_file,
const size_t num_points_to_load,
DISKANN_DLLEXPORT void build_filtered_index(const char *filename, const size_t num_points_to_load,
const IndexFilterParams &filter_params,
const std::vector<TagT> &tags = std::vector<TagT>());

// Filtered support streaming index
DISKANN_DLLEXPORT void build_filtered_index(const T *data, const size_t num_points_to_load,
const IndexFilterParams &filter_params,
const std::vector<TagT> &tags = std::vector<TagT>());

// DISKANN_DLLEXPORT void set_universal_label(const LabelT &label);
DISKANN_DLLEXPORT void set_universal_labels(const std::string &raw_labels);
DISKANN_DLLEXPORT void set_universal_label(const std::string &raw_labels);

// Set starting point of an index before inserting any points incrementally.
// The data count should be equal to _num_frozen_pts * _aligned_dim.
Expand Down
64 changes: 0 additions & 64 deletions include/index_build_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,69 +3,5 @@

namespace diskann
{
// Plain bundle of filter-related settings for an index build.
// The only constructor is private, so values are produced by the befriended
// IndexFilterParamsBuilder; every field is publicly readable and writable.
struct IndexFilterParams
{
    std::string save_path_prefix;
    std::string label_file;
    std::string tags_file; // not set by the constructor; stays default-initialized (empty)
    std::string universal_label;
    uint32_t filter_threshold{0};

  private:
    // Copies the four supplied values into the matching fields.
    IndexFilterParams(const std::string &save_path_prefix, const std::string &label_file,
                      const std::string &universal_label, uint32_t filter_threshold)
        : save_path_prefix{save_path_prefix}, label_file{label_file}, universal_label{universal_label},
          filter_threshold{filter_threshold}
    {
    }

    friend class IndexFilterParamsBuilder;
};
// Fluent builder for IndexFilterParams. Each with_* call stores one value and
// returns the builder itself so calls can be chained; build() produces the
// final IndexFilterParams. The builder cannot be copied.
class IndexFilterParamsBuilder
{
  public:
    IndexFilterParamsBuilder() = default;

    IndexFilterParamsBuilder(const IndexFilterParamsBuilder &) = delete;
    IndexFilterParamsBuilder &operator=(const IndexFilterParamsBuilder &) = delete;

    // Stores the save path prefix; throws ANNException when the prefix is empty.
    IndexFilterParamsBuilder &with_save_path_prefix(const std::string &save_path_prefix)
    {
        if (save_path_prefix.empty())
        {
            throw ANNException("Error: save_path_prefix can't be empty", -1);
        }
        _save_path_prefix = save_path_prefix;
        return *this;
    }

    // Stores the label file path.
    IndexFilterParamsBuilder &with_label_file(const std::string &label_file)
    {
        _label_file = label_file;
        return *this;
    }

    // Stores the universal label string.
    IndexFilterParamsBuilder &with_universal_label(const std::string &univeral_label)
    {
        _universal_label = univeral_label;
        return *this;
    }

    // Stores the filter threshold.
    IndexFilterParamsBuilder &with_filter_threshold(const std::uint32_t &filter_threshold)
    {
        _filter_threshold = filter_threshold;
        return *this;
    }

    // Creates an IndexFilterParams from the values collected so far.
    IndexFilterParams build()
    {
        return IndexFilterParams(_save_path_prefix, _label_file, _universal_label, _filter_threshold);
    }

  private:
    std::string _save_path_prefix;
    std::string _label_file;
    std::string _tags_file; // declared but never written or read by this builder
    std::string _universal_label;
    uint32_t _filter_threshold = 0;
};
} // namespace diskann
65 changes: 65 additions & 0 deletions include/parameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,69 @@ class IndexWriteParametersBuilder
uint32_t _filter_list_size{defaults::FILTER_LIST_SIZE};
};

// Settings describing the filter (label) inputs of an index build.
// Instances are created through IndexFilterParamsBuilder, which is the only
// friend of the private constructor; all fields are public.
class IndexFilterParams
{
  public:
    std::string save_path_prefix;
    std::string label_file;
    std::string tags_file;
    std::string universal_label;
    uint32_t filter_threshold = 0;

  private:
    // Copies each argument into the field of the same name (in declaration order).
    IndexFilterParams(const std::string &save_path_prefix, const std::string &label_file,
                      const std::string &tags_file, const std::string &universal_label, uint32_t filter_threshold)
        : save_path_prefix(save_path_prefix), label_file(label_file), tags_file(tags_file),
          universal_label(universal_label), filter_threshold(filter_threshold)
    {
    }

    friend class IndexFilterParamsBuilder;
};

// Fluent builder for IndexFilterParams.
// Every with_* method stores one value and returns *this so calls can be
// chained; values that are never set keep their defaults (empty strings and a
// threshold of 0). The builder itself is non-copyable.
class IndexFilterParamsBuilder
{
  public:
    IndexFilterParamsBuilder() = default;

    // Sets the value copied into IndexFilterParams::save_path_prefix.
    IndexFilterParamsBuilder &with_save_path_prefix(const std::string &save_path_prefix)
    {
        this->_save_path_prefix = save_path_prefix;
        return *this;
    }

    // Sets the value copied into IndexFilterParams::label_file.
    IndexFilterParamsBuilder &with_label_file(const std::string &label_file)
    {
        this->_label_file = label_file;
        return *this;
    }

    // Sets the value copied into IndexFilterParams::tags_file.
    // Previously _tags_file existed but could neither be set nor reach the built object.
    IndexFilterParamsBuilder &with_tags_file(const std::string &tags_file)
    {
        this->_tags_file = tags_file;
        return *this;
    }

    // Sets the value copied into IndexFilterParams::universal_label.
    IndexFilterParamsBuilder &with_universal_label(const std::string &universal_label)
    {
        this->_universal_label = universal_label;
        return *this;
    }

    // Sets the value copied into IndexFilterParams::filter_threshold.
    IndexFilterParamsBuilder &with_filter_threshold(std::uint32_t filter_threshold)
    {
        this->_filter_threshold = filter_threshold;
        return *this;
    }

    // Returns an IndexFilterParams holding the values collected so far.
    // Does not modify the builder, so it may be called repeatedly.
    IndexFilterParams build() const
    {
        return IndexFilterParams(_save_path_prefix, _label_file, _tags_file, _universal_label, _filter_threshold);
    }

    IndexFilterParamsBuilder(const IndexFilterParamsBuilder &) = delete;
    IndexFilterParamsBuilder &operator=(const IndexFilterParamsBuilder &) = delete;

  private:
    std::string _save_path_prefix;
    std::string _label_file;
    std::string _tags_file;
    std::string _universal_label;
    uint32_t _filter_threshold = 0;
};

} // namespace diskann
41 changes: 12 additions & 29 deletions python/src/builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,29 +31,6 @@ template void build_disk_index<uint8_t>(diskann::Metric, const std::string &, co
template void build_disk_index<int8_t>(diskann::Metric, const std::string &, const std::string &, uint32_t, uint32_t,
double, double, uint32_t, uint32_t);

// Prepares the label files for a filtered in-memory build.
// Derives "<index_output_path>_label_numeric.txt" and "<index_output_path>_labels_map.txt",
// hands both paths plus the raw label file and universal label to convert_label_to_numeric,
// forwards a non-empty universal label to the index, and returns the "_label_numeric.txt" path.
// NOTE(review): presumably convert_label_to_numeric writes those two files -- confirm against its definition.
template <typename T, typename TagT, typename LabelT>
std::string prepare_filtered_label_map(diskann::Index<T, TagT, LabelT> &index, const std::string &index_output_path,
                                       const std::string &filter_labels_file, const std::string &universal_label)
{
    std::string numeric_labels_path = index_output_path + "_label_numeric.txt";
    std::string label_map_path = index_output_path + "_labels_map.txt";

    convert_label_to_numeric(filter_labels_file, numeric_labels_path, label_map_path, universal_label);

    const bool has_universal_label = !universal_label.empty();
    if (has_universal_label)
        index.set_universal_label(universal_label);

    return numeric_labels_path;
}

// Explicit instantiations of prepare_filtered_label_map for the float, int8_t and uint8_t
// element types, each with uint32_t tags and uint32_t labels.
template std::string prepare_filtered_label_map<float>(diskann::Index<float, uint32_t, uint32_t> &, const std::string &,
const std::string &, const std::string &);

template std::string prepare_filtered_label_map<int8_t>(diskann::Index<int8_t, uint32_t, uint32_t> &,
const std::string &, const std::string &, const std::string &);

template std::string prepare_filtered_label_map<uint8_t>(diskann::Index<uint8_t, uint32_t, uint32_t> &,
const std::string &, const std::string &, const std::string &);

template <typename T, typename TagT, typename LabelT>
void build_memory_index(const diskann::Metric metric, const std::string &vector_bin_path,
const std::string &index_output_path, const uint32_t graph_degree, const uint32_t complexity,
Expand Down Expand Up @@ -95,9 +72,12 @@ void build_memory_index(const diskann::Metric metric, const std::string &vector_
}
else
{
auto labels_file = prepare_filtered_label_map<T, TagT, LabelT>(index, index_output_path, filter_labels_file,
universal_label);
index.build_filtered_index(vector_bin_path.c_str(), labels_file, data_num, tags);
auto filter_params = diskann::IndexFilterParamsBuilder()
.with_universal_label(universal_label)
.with_label_file(filter_labels_file)
.with_save_path_prefix(index_output_path.c_str())
.build();
index.build_filtered_index(vector_bin_path.c_str(), data_num, filter_params, tags);
}
}
else
Expand All @@ -108,9 +88,12 @@ void build_memory_index(const diskann::Metric metric, const std::string &vector_
}
else
{
auto labels_file = prepare_filtered_label_map<T, TagT, LabelT>(index, index_output_path, filter_labels_file,
universal_label);
index.build_filtered_index(vector_bin_path.c_str(), labels_file, data_num);
auto filter_params = diskann::IndexFilterParamsBuilder()
.with_universal_label(universal_label)
.with_label_file(filter_labels_file)
.with_save_path_prefix(index_output_path.c_str())
.build();
index.build_filtered_index(vector_bin_path.c_str(), data_num, filter_params);
}
}

Expand Down
26 changes: 21 additions & 5 deletions src/disk_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -673,9 +673,14 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr
if (universal_label != "")
{ // indicates no universal label
// LabelT unv_label_as_num = 0;
_index.set_universal_labels(universal_label);
_index.set_universal_label(universal_label);
}
_index.build_filtered_index(base_file.c_str(), label_file, base_num);
auto filter_params = diskann::IndexFilterParamsBuilder()
.with_universal_label(universal_label)
.with_label_file(label_file)
.with_save_path_prefix(mem_index_path.c_str())
.build();
_index.build_filtered_index(base_file.c_str(), base_num, filter_params);
}
_index.save(mem_index_path.c_str());

Expand All @@ -685,8 +690,12 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr
// file
std::remove(disk_labels_to_medoids_file.c_str());
std::string mem_labels_to_medoid_file = mem_index_path + "_labels_to_medoids.txt";
std::string mem_labels_file = mem_index_path + "_labels.txt";
std::string mem_labels_int_map_file = mem_index_path + "_labels_map.txt";
copy_file(mem_labels_to_medoid_file, disk_labels_to_medoids_file);
std::remove(mem_labels_to_medoid_file.c_str());
std::remove(mem_labels_file.c_str());
std::remove(mem_labels_int_map_file.c_str());
}

std::remove(medoids_file.c_str());
Expand Down Expand Up @@ -747,9 +756,14 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr
if (universal_label != "")
{ // indicates no universal label
// LabelT unv_label_as_num = 0;
_index.set_universal_labels(universal_label);
_index.set_universal_label(universal_label);
}
_index.build_filtered_index(shard_base_file.c_str(), shard_labels_file, shard_base_pts);
auto filter_params = diskann::IndexFilterParamsBuilder()
.with_universal_label(universal_label)
.with_label_file(shard_labels_file)
.with_save_path_prefix(shard_index_file.c_str())
.build();
_index.build_filtered_index(shard_base_file.c_str(), shard_base_pts, filter_params);
}
_index.save(shard_index_file.c_str());
// copy universal label file from first shard to the final destination
Expand Down Expand Up @@ -791,11 +805,13 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr
{
std::string shard_index_label_file = shard_index_file + "_labels.txt";
std::string shard_index_univ_label_file = shard_index_file + "_universal_label.txt";
std::string shard_index_label_map_file = shard_index_file + "_labels_to_medoids.txt";
std::string shard_index_label_medoid_file = shard_index_file + "_labels_to_medoids.txt";
std::string shard_index_label_map_file = shard_index_file + "_labels_map.txt";
std::remove(shard_labels_file.c_str());
std::remove(shard_index_label_file.c_str());
std::remove(shard_index_label_map_file.c_str());
std::remove(shard_index_univ_label_file.c_str());
std::remove(shard_index_label_medoid_file.c_str());
}
}
return 0;
Expand Down
2 changes: 1 addition & 1 deletion src/in_mem_filter_store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ void InMemFilterStore<label_type>::add_label_to_location(const location_t point_
}

template <typename label_type>
void InMemFilterStore<label_type>::set_universal_labels(const std::string &raw_universal_label)
void InMemFilterStore<label_type>::set_universal_label(const std::string &raw_universal_label)
{
if (raw_universal_label.empty())
{
Expand Down
Loading

0 comments on commit ce7db4d

Please sign in to comment.