Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PQ Support for Dynamic Index #521

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ include_directories(${PROJECT_SOURCE_DIR}/include)

if(NOT PYBIND)
set(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS ON)
elseif(MSVS)
set(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS OFF)
endif()
# It's necessary to include tcmalloc headers only if calling into MallocExtension interface.
# For using tcmalloc in DiskANN tools, it's enough to just link with tcmalloc.
Expand Down
6 changes: 4 additions & 2 deletions apps/build_memory_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace po = boost::program_options;

int main(int argc, char **argv)
{
std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type;
std::string data_type, dist_fn, data_path, index_path_prefix, label_file, codebook_path, universal_label, label_type;
uint32_t num_threads, R, L, Lf, build_PQ_bytes;
float alpha;
bool use_pq_build, use_opq;
Expand Down Expand Up @@ -59,13 +59,14 @@ int main(int argc, char **argv)
program_options_utils::GRAPH_BUILD_ALPHA);
optional_configs.add_options()("build_PQ_bytes", po::value<uint32_t>(&build_PQ_bytes)->default_value(0),
program_options_utils::BUIlD_GRAPH_PQ_BYTES);
optional_configs.add_options()("codebook_path", po::value<std::string>(&codebook_path)->default_value(""),
program_options_utils::CODEBOOK_PATH);
optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false),
program_options_utils::USE_OPQ);
optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
program_options_utils::LABEL_FILE);
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
program_options_utils::UNIVERSAL_LABEL);

optional_configs.add_options()("FilteredLbuild", po::value<uint32_t>(&Lf)->default_value(0),
program_options_utils::FILTERED_LBUILD);
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
Expand Down Expand Up @@ -146,6 +147,7 @@ int main(int argc, char **argv)
.is_use_opq(use_opq)
.is_pq_dist_build(use_pq_build)
.with_num_pq_chunks(build_PQ_bytes)
.with_pq_codebook_path(codebook_path)
.build();

auto index_factory = diskann::IndexFactory(config);
Expand Down
38 changes: 29 additions & 9 deletions apps/search_memory_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ namespace po = boost::program_options;

template <typename T, typename LabelT = uint32_t>
int search_memory_index(diskann::Metric &metric, const std::string &index_path, const std::string &result_path_prefix,
const std::string &query_file, const std::string &truthset_file, const uint32_t num_threads,
const std::string &query_file, const std::string &truthset_file,
const std::string &codebook_file, const bool use_pq_build, const bool use_opq,
const uint32_t pq_num_chunks, const uint32_t num_threads,
const uint32_t recall_at, const bool print_all_recalls, const std::vector<uint32_t> &Lvec,
const bool dynamic, const bool tags, const bool show_qps_per_thread,
const std::vector<std::string> &query_filters, const float fail_if_recall_below)
Expand Down Expand Up @@ -82,12 +84,16 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
.is_dynamic_index(dynamic)
.is_enable_tags(tags)
.is_concurrent_consolidate(false)
.is_pq_dist_build(false)
.is_use_opq(false)
.with_num_pq_chunks(0)
.is_pq_dist_build(use_pq_build)
.is_use_opq(use_opq)
.with_num_pq_chunks(pq_num_chunks)
.with_num_frozen_pts(num_frozen_pts)
.with_pq_codebook_path(codebook_file)
.build();

std::cout << "******************** Attach Debugger ********************" << std::endl;
Sleep(60000);

auto index_factory = diskann::IndexFactory(config);
auto index = index_factory.create_instance();
index->load(index_path.c_str(), num_threads, *(std::max_element(Lvec.begin(), Lvec.end())));
Expand Down Expand Up @@ -278,10 +284,10 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
int main(int argc, char **argv)
{
std::string data_type, dist_fn, index_path_prefix, result_path, query_file, gt_file, filter_label, label_type,
query_filters_file;
uint32_t num_threads, K;
query_filters_file, codebook_path;
uint32_t num_threads, K, build_PQ_bytes;
std::vector<uint32_t> Lvec;
bool print_all_recalls, dynamic, tags, show_qps_per_thread;
bool print_all_recalls, dynamic, tags, show_qps_per_thread, use_pq_build, use_opq;
float fail_if_recall_below = 0.0f;

po::options_description desc{
Expand Down Expand Up @@ -331,6 +337,12 @@ int main(int argc, char **argv)
optional_configs.add_options()("fail_if_recall_below",
po::value<float>(&fail_if_recall_below)->default_value(0.0f),
program_options_utils::FAIL_IF_RECALL_BELOW);
optional_configs.add_options()("build_PQ_bytes", po::value<uint32_t>(&build_PQ_bytes)->default_value(0),
program_options_utils::BUIlD_GRAPH_PQ_BYTES);
optional_configs.add_options()("codebook_path", po::value<std::string>(&codebook_path)->default_value(""),
program_options_utils::CODEBOOK_PATH);
optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false),
program_options_utils::USE_OPQ);

// Output controls
po::options_description output_controls("Output controls");
Expand All @@ -352,6 +364,8 @@ int main(int argc, char **argv)
return 0;
}
po::notify(vm);
use_pq_build = (build_PQ_bytes > 0);
use_opq = vm["use_opq"].as<bool>();
}
catch (const std::exception &ex)
{
Expand Down Expand Up @@ -420,18 +434,21 @@ int main(int argc, char **argv)
if (data_type == std::string("int8"))
{
return search_memory_index<int8_t, uint16_t>(
metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls,
metric, index_path_prefix, result_path, query_file, gt_file, codebook_path, use_pq_build, use_opq,
build_PQ_bytes, num_threads, K, print_all_recalls,
Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below);
}
else if (data_type == std::string("uint8"))
{
return search_memory_index<uint8_t, uint16_t>(
metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls,
metric, index_path_prefix, result_path, query_file, gt_file, codebook_path, use_pq_build, use_opq,
build_PQ_bytes, num_threads, K, print_all_recalls,
Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below);
}
else if (data_type == std::string("float"))
{
return search_memory_index<float, uint16_t>(metric, index_path_prefix, result_path, query_file, gt_file,
codebook_path, use_pq_build, use_opq, build_PQ_bytes,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below);
}
Expand All @@ -446,18 +463,21 @@ int main(int argc, char **argv)
if (data_type == std::string("int8"))
{
return search_memory_index<int8_t>(metric, index_path_prefix, result_path, query_file, gt_file,
codebook_path, use_pq_build, use_opq, build_PQ_bytes,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below);
}
else if (data_type == std::string("uint8"))
{
return search_memory_index<uint8_t>(metric, index_path_prefix, result_path, query_file, gt_file,
codebook_path, use_pq_build, use_opq, build_PQ_bytes,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below);
}
else if (data_type == std::string("float"))
{
return search_memory_index<float>(metric, index_path_prefix, result_path, query_file, gt_file,
codebook_path, use_pq_build, use_opq, build_PQ_bytes,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below);
}
Expand Down
29 changes: 22 additions & 7 deletions apps/test_insert_deletes_consolidate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ void build_incremental_index(const std::string &data_path, diskann::IndexWritePa
uint32_t num_start_pts, size_t points_per_checkpoint, size_t checkpoints_per_snapshot,
const std::string &save_path, size_t points_to_delete_from_beginning,
size_t start_deletes_after, bool concurrent, const std::string &label_file,
const std::string &universal_label)
const std::string &universal_label, size_t num_pq_chunks, const std::string& pq_pivot_file)
{
size_t dim, aligned_dim;
size_t num_points;
Expand All @@ -161,7 +161,7 @@ void build_incremental_index(const std::string &data_path, diskann::IndexWritePa
using LabelT = uint32_t;

size_t current_point_offset = points_to_skip;
const size_t last_point_threshold = points_to_skip + max_points_to_insert;
size_t last_point_threshold = points_to_skip + max_points_to_insert;

bool enable_tags = true;
using TagT = uint32_t;
Expand All @@ -182,6 +182,9 @@ void build_incremental_index(const std::string &data_path, diskann::IndexWritePa
.is_filtered(has_labels)
.with_num_frozen_pts(num_start_pts)
.is_concurrent_consolidate(concurrent)
.with_pq_codebook_path(pq_pivot_file)
.is_pq_dist_build(!pq_pivot_file.empty())
.with_num_pq_chunks(num_pq_chunks)
.build();

diskann::IndexFactory index_factory = diskann::IndexFactory(index_config);
Expand All @@ -206,6 +209,7 @@ void build_incremental_index(const std::string &data_path, diskann::IndexWritePa
if (points_to_skip + max_points_to_insert > num_points)
{
max_points_to_insert = num_points - points_to_skip;
last_point_threshold = num_points;
std::cerr << "WARNING: Reducing max_points_to_insert to " << max_points_to_insert
<< " points since the data file has only that many" << std::endl;
}
Expand Down Expand Up @@ -377,11 +381,11 @@ int main(int argc, char **argv)
uint32_t num_threads, R, L, num_start_pts;
float alpha, start_point_norm;
size_t points_to_skip, max_points_to_insert, beginning_index_size, points_per_checkpoint, checkpoints_per_snapshot,
points_to_delete_from_beginning, start_deletes_after;
points_to_delete_from_beginning, start_deletes_after, num_pq_chunks;
bool concurrent;

// label options
std::string label_file, label_type, universal_label;
std::string label_file, label_type, universal_label, pq_pivot_file;
std::uint32_t Lf, unique_labels_supported;

po::options_description desc{program_options_utils::make_program_description("test_insert_deletes_consolidate",
Expand Down Expand Up @@ -449,6 +453,11 @@ int main(int argc, char **argv)
optional_configs.add_options()("unique_labels_supported",
po::value<uint32_t>(&unique_labels_supported)->default_value(0),
"Number of unique labels supported by the dynamic index.");
optional_configs.add_options()("pq_pivot_file", po::value<std::string>(&pq_pivot_file)->default_value(""),
"The file stored pq pivot info.");
optional_configs.add_options()("num_pq_chunks", po::value<uint64_t>(&num_pq_chunks)->default_value(0),
"Number of PQ chunks to use.");


optional_configs.add_options()(
"num_start_points",
Expand Down Expand Up @@ -503,21 +512,27 @@ int main(int argc, char **argv)
.with_filter_list_size(Lf)
.build();


std::cout << "********** Attach Debugger Test insert delete consolidate **********" << std::endl;

if (data_type == std::string("int8"))
build_incremental_index<int8_t>(
data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm,
num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix,
points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label);
points_to_delete_from_beginning, start_deletes_after, concurrent,
label_file, universal_label, num_pq_chunks, pq_pivot_file);
else if (data_type == std::string("uint8"))
build_incremental_index<uint8_t>(
data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm,
num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix,
points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label);
points_to_delete_from_beginning, start_deletes_after, concurrent,
label_file, universal_label, num_pq_chunks, pq_pivot_file);
else if (data_type == std::string("float"))
build_incremental_index<float>(data_path, params, points_to_skip, max_points_to_insert,
beginning_index_size, start_point_norm, num_start_pts, points_per_checkpoint,
checkpoints_per_snapshot, index_path_prefix, points_to_delete_from_beginning,
start_deletes_after, concurrent, label_file, universal_label);
start_deletes_after, concurrent, label_file, universal_label,
num_pq_chunks, pq_pivot_file);
else
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
}
Expand Down
47 changes: 47 additions & 0 deletions include/fixed_chunk_pq_table.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once

#include "common_includes.h"

#ifdef EXEC_ENV_OLS
#include "content_buf.h"
#include "memory_mapped_files.h"
#endif

namespace diskann
{
class FixedChunkPQTable
{
public:
FixedChunkPQTable();
virtual ~FixedChunkPQTable();

#ifdef EXEC_ENV_OLS
void load_pq_centroid_bin(MemoryMappedFiles &files, const char *pq_table_file, size_t num_chunks);
#else
void load_pq_centroid_bin(const char *pq_table_file, size_t num_chunks);
#endif

void preprocess_query(float *query_vec);
// assumes pre-processed query
void populate_chunk_distances(const float *query_vec, float *dist_vec);
float l2_distance(const float *query_vec, uint8_t *base_vec);
float inner_product(const float *query_vec, uint8_t *base_vec);
// assumes no rotation is involved
template <typename InputType = uint8_t, typename OutputType = float>
void inflate_vector(InputType *base_vec, OutputType *out_vec) const;

void populate_chunk_inner_products(const float *query_vec, float *dist_vec);

float *tables = nullptr; // pq_tables = float array of size [256 * ndims]
uint64_t ndims = 0; // ndims = true dimension of vectors
uint64_t n_chunks = 0;
bool use_rotation = false;
uint32_t *chunk_offsets = nullptr;
float *centroid = nullptr;
float *tables_tr = nullptr; // same as pq_tables, but col-major
float *rotmat_tr = nullptr;
};
} // namespace diskann
Loading
Loading