From 149cd40e0a56ae4be46a3fd508d5f603c1db4dbd Mon Sep 17 00:00:00 2001 From: Gopal Srinivasa Date: Mon, 4 Nov 2024 09:41:28 +0000 Subject: [PATCH 1/7] Fixing Linux compile issues --- include/in_mem_filter_store.h | 8 ++++++-- include/index_config.h | 1 + 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/include/in_mem_filter_store.h b/include/in_mem_filter_store.h index ec3362b32..f1bedba62 100644 --- a/include/in_mem_filter_store.h +++ b/include/in_mem_filter_store.h @@ -1,10 +1,14 @@ #pragma once - +#include +#include +#include +#include "logger.h" +#include "ann_exception.h" #include "abstract_filter_store.h" #include "tsl/robin_map.h" #include "tsl/robin_set.h" #include "windows_customizations.h" -#include + namespace diskann { template diff --git a/include/index_config.h b/include/index_config.h index 73c7133d1..a7100edfa 100644 --- a/include/index_config.h +++ b/include/index_config.h @@ -1,5 +1,6 @@ #pragma once +#include #include "ann_exception.h" #include "common_includes.h" #include "logger.h" From 83d8ebb03059635cec7af35bc6cb27f486fcc894 Mon Sep 17 00:00:00 2001 From: Gopal Srinivasa Date: Mon, 4 Nov 2024 10:30:20 +0000 Subject: [PATCH 2/7] Fixing formatting issues --- include/in_mem_filter_store.h | 4 ++++ include/index_config.h | 3 +++ 2 files changed, 7 insertions(+) diff --git a/include/in_mem_filter_store.h b/include/in_mem_filter_store.h index f1bedba62..7080ef665 100644 --- a/include/in_mem_filter_store.h +++ b/include/in_mem_filter_store.h @@ -1,4 +1,8 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + #pragma once + #include #include #include diff --git a/include/index_config.h b/include/index_config.h index a7100edfa..dde4ec51d 100644 --- a/include/index_config.h +++ b/include/index_config.h @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. 
+ #pragma once #include From 06a12112d378292072b15a8dd51dde2c51dac34b Mon Sep 17 00:00:00 2001 From: rakri Date: Tue, 12 Nov 2024 20:53:22 -0800 Subject: [PATCH 3/7] clang-formatted --- apps/build_disk_index.cpp | 349 ++-- apps/build_memory_index.cpp | 286 +-- apps/build_stitched_index.cpp | 736 ++++---- apps/range_search_disk_index.cpp | 636 +++---- apps/search_disk_index.cpp | 862 ++++----- apps/search_memory_index.cpp | 850 +++++---- apps/test_insert_deletes_consolidate.cpp | 1007 ++++++----- apps/test_streaming_scenario.cpp | 959 +++++----- apps/utils/bin_to_fvecs.cpp | 106 +- apps/utils/bin_to_tsv.cpp | 105 +- apps/utils/calculate_recall.cpp | 73 +- apps/utils/compute_groundtruth.cpp | 972 +++++----- .../utils/compute_groundtruth_for_filters.cpp | 1575 ++++++++--------- apps/utils/count_bfs_levels.cpp | 103 +- apps/utils/create_disk_layout.cpp | 60 +- apps/utils/float_bin_to_int8.cpp | 107 +- apps/utils/fvecs_to_bin.cpp | 143 +- apps/utils/fvecs_to_bvecs.cpp | 88 +- apps/utils/gen_random_slice.cpp | 61 +- apps/utils/generate_pq.cpp | 109 +- apps/utils/generate_synthetic_labels.cpp | 330 ++-- apps/utils/int8_to_float.cpp | 30 +- apps/utils/int8_to_float_scale.cpp | 107 +- apps/utils/ivecs_to_bin.cpp | 91 +- apps/utils/merge_shards.cpp | 47 +- apps/utils/partition_data.cpp | 57 +- apps/utils/partition_with_ram_budget.cpp | 57 +- apps/utils/rand_data_gen.cpp | 401 ++--- apps/utils/simulate_aggregate_recall.cpp | 123 +- apps/utils/stats_label_data.cpp | 229 +-- apps/utils/tsv_to_bin.cpp | 193 +- apps/utils/uint32_to_uint8.cpp | 30 +- apps/utils/uint8_to_float.cpp | 30 +- apps/utils/vector_analysis.cpp | 243 ++- include/in_mem_filter_store.h | 11 +- include/index_config.h | 2 +- 36 files changed, 5613 insertions(+), 5555 deletions(-) diff --git a/apps/build_disk_index.cpp b/apps/build_disk_index.cpp index f48b61726..475c9165b 100644 --- a/apps/build_disk_index.cpp +++ b/apps/build_disk_index.cpp @@ -1,191 +1,210 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
-#include #include +#include -#include "utils.h" #include "disk_utils.h" -#include "math_utils.h" #include "index.h" +#include "math_utils.h" #include "partition.h" #include "program_options_utils.hpp" +#include "utils.h" namespace po = boost::program_options; -int main(int argc, char **argv) -{ - std::string data_type, dist_fn, data_path, index_path_prefix, codebook_prefix, label_file, universal_label, - label_type; - uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold; - float B, M; - bool append_reorder_data = false; - bool use_opq = false; +int main(int argc, char **argv) { + std::string data_type, dist_fn, data_path, index_path_prefix, codebook_prefix, + label_file, universal_label, label_type; + uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold; + float B, M; + bool append_reorder_data = false; + bool use_opq = false; - po::options_description desc{ - program_options_utils::make_program_description("build_disk_index", "Build a disk-based index.")}; - try - { - desc.add_options()("help,h", "Print information on arguments"); + po::options_description desc{program_options_utils::make_program_description( + "build_disk_index", "Build a disk-based index.")}; + try { + desc.add_options()("help,h", "Print information on arguments"); - // Required parameters - po::options_description required_configs("Required"); - required_configs.add_options()("data_type", po::value(&data_type)->required(), - program_options_utils::DATA_TYPE_DESCRIPTION); - required_configs.add_options()("dist_fn", po::value(&dist_fn)->required(), - program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); - required_configs.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), - program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); - required_configs.add_options()("data_path", po::value(&data_path)->required(), - program_options_utils::INPUT_DATA_PATH); - required_configs.add_options()("search_DRAM_budget,B", po::value(&B)->required(), - "DRAM budget in GB for searching the index to set the " - "compressed level for data while search happens"); - required_configs.add_options()("build_DRAM_budget,M", po::value(&M)->required(), - "DRAM budget in GB for building the index"); + // Required parameters + po::options_description required_configs("Required"); + required_configs.add_options()( + "data_type", po::value(&data_type)->required(), + program_options_utils::DATA_TYPE_DESCRIPTION); + required_configs.add_options()( + "dist_fn", po::value(&dist_fn)->required(), + program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); + required_configs.add_options()( + "index_path_prefix", + po::value(&index_path_prefix)->required(), + program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); + required_configs.add_options()( + "data_path", po::value(&data_path)->required(), + program_options_utils::INPUT_DATA_PATH); + required_configs.add_options()( + "search_DRAM_budget,B", po::value(&B)->required(), + "DRAM budget in GB for searching the index to set the " + "compressed level for data while search happens"); + required_configs.add_options()("build_DRAM_budget,M", + po::value(&M)->required(), + "DRAM budget in GB for building the index"); - // Optional parameters - po::options_description optional_configs("Optional"); - optional_configs.add_options()("num_threads,T", - po::value(&num_threads)->default_value(omp_get_num_procs()), - program_options_utils::NUMBER_THREADS_DESCRIPTION); - optional_configs.add_options()("max_degree,R", po::value(&R)->default_value(64), - 
program_options_utils::MAX_BUILD_DEGREE); - optional_configs.add_options()("Lbuild,L", po::value(&L)->default_value(100), - program_options_utils::GRAPH_BUILD_COMPLEXITY); - optional_configs.add_options()("QD", po::value(&QD)->default_value(0), - " Quantized Dimension for compression"); - optional_configs.add_options()("codebook_prefix", po::value(&codebook_prefix)->default_value(""), - "Path prefix for pre-trained codebook"); - optional_configs.add_options()("PQ_disk_bytes", po::value(&disk_PQ)->default_value(0), - "Number of bytes to which vectors should be compressed " - "on SSD; 0 for no compression"); - optional_configs.add_options()("append_reorder_data", po::bool_switch()->default_value(false), - "Include full precision data in the index. Use only in " - "conjuction with compressed data on SSD."); - optional_configs.add_options()("build_PQ_bytes", po::value(&build_PQ)->default_value(0), - program_options_utils::BUIlD_GRAPH_PQ_BYTES); - optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false), - program_options_utils::USE_OPQ); - optional_configs.add_options()("label_file", po::value(&label_file)->default_value(""), - program_options_utils::LABEL_FILE); - optional_configs.add_options()("universal_label", po::value(&universal_label)->default_value(""), - program_options_utils::UNIVERSAL_LABEL); - optional_configs.add_options()("FilteredLbuild", po::value(&Lf)->default_value(0), - program_options_utils::FILTERED_LBUILD); - optional_configs.add_options()("filter_threshold,F", po::value(&filter_threshold)->default_value(0), - "Threshold to break up the existing nodes to generate new graph " - "internally where each node has a maximum F labels."); - optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), - program_options_utils::LABEL_TYPE_DESCRIPTION); + // Optional parameters + po::options_description optional_configs("Optional"); + optional_configs.add_options()( + "num_threads,T", + po::value(&num_threads)->default_value(omp_get_num_procs()), + program_options_utils::NUMBER_THREADS_DESCRIPTION); + optional_configs.add_options()("max_degree,R", + po::value(&R)->default_value(64), + program_options_utils::MAX_BUILD_DEGREE); + optional_configs.add_options()( + "Lbuild,L", po::value(&L)->default_value(100), + program_options_utils::GRAPH_BUILD_COMPLEXITY); + optional_configs.add_options()("QD", + po::value(&QD)->default_value(0), + " Quantized Dimension for compression"); + optional_configs.add_options()( + "codebook_prefix", + po::value(&codebook_prefix)->default_value(""), + "Path prefix for pre-trained codebook"); + optional_configs.add_options()( + "PQ_disk_bytes", po::value(&disk_PQ)->default_value(0), + "Number of bytes to which vectors should be compressed " + "on SSD; 0 for no compression"); + optional_configs.add_options()( + "append_reorder_data", po::bool_switch()->default_value(false), + "Include full precision data in the index. 
Use only in " + "conjuction with compressed data on SSD."); + optional_configs.add_options()( + "build_PQ_bytes", po::value(&build_PQ)->default_value(0), + program_options_utils::BUIlD_GRAPH_PQ_BYTES); + optional_configs.add_options()("use_opq", + po::bool_switch()->default_value(false), + program_options_utils::USE_OPQ); + optional_configs.add_options()( + "label_file", po::value(&label_file)->default_value(""), + program_options_utils::LABEL_FILE); + optional_configs.add_options()( + "universal_label", + po::value(&universal_label)->default_value(""), + program_options_utils::UNIVERSAL_LABEL); + optional_configs.add_options()("FilteredLbuild", + po::value(&Lf)->default_value(0), + program_options_utils::FILTERED_LBUILD); + optional_configs.add_options()( + "filter_threshold,F", + po::value(&filter_threshold)->default_value(0), + "Threshold to break up the existing nodes to generate new graph " + "internally where each node has a maximum F labels."); + optional_configs.add_options()( + "label_type", + po::value(&label_type)->default_value("uint"), + program_options_utils::LABEL_TYPE_DESCRIPTION); - // Merge required and optional parameters - desc.add(required_configs).add(optional_configs); + // Merge required and optional parameters + desc.add(required_configs).add(optional_configs); - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) - { - std::cout << desc; - return 0; - } - po::notify(vm); - if (vm["append_reorder_data"].as()) - append_reorder_data = true; - if (vm["use_opq"].as()) - use_opq = true; - } - catch (const std::exception &ex) - { - std::cerr << ex.what() << '\n'; - return -1; + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) { + std::cout << desc; + return 0; } + po::notify(vm); + if (vm["append_reorder_data"].as()) + append_reorder_data = true; + if (vm["use_opq"].as()) + use_opq = true; + } catch (const std::exception &ex) { + std::cerr << ex.what() << '\n'; + return -1; + } - bool use_filters = (label_file != "") ? true : false; - diskann::Metric metric; - if (dist_fn == std::string("l2")) - metric = diskann::Metric::L2; - else if (dist_fn == std::string("mips")) - metric = diskann::Metric::INNER_PRODUCT; - else if (dist_fn == std::string("cosine")) - metric = diskann::Metric::COSINE; - else - { - std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl; - return -1; - } + bool use_filters = (label_file != "") ? true : false; + diskann::Metric metric; + if (dist_fn == std::string("l2")) + metric = diskann::Metric::L2; + else if (dist_fn == std::string("mips")) + metric = diskann::Metric::INNER_PRODUCT; + else if (dist_fn == std::string("cosine")) + metric = diskann::Metric::COSINE; + else { + std::cout << "Error. Only l2 and mips distance functions are supported" + << std::endl; + return -1; + } - if (append_reorder_data) - { - if (disk_PQ == 0) - { - std::cout << "Error: It is not necessary to append data for reordering " - "when vectors are not compressed on disk." - << std::endl; - return -1; - } - if (data_type != std::string("float")) - { - std::cout << "Error: Appending data for reordering currently only " - "supported for float data type." - << std::endl; - return -1; - } + if (append_reorder_data) { + if (disk_PQ == 0) { + std::cout << "Error: It is not necessary to append data for reordering " + "when vectors are not compressed on disk." 
+ << std::endl; + return -1; } + if (data_type != std::string("float")) { + std::cout << "Error: Appending data for reordering currently only " + "supported for float data type." + << std::endl; + return -1; + } + } - std::string params = std::string(std::to_string(R)) + " " + std::string(std::to_string(L)) + " " + - std::string(std::to_string(B)) + " " + std::string(std::to_string(M)) + " " + - std::string(std::to_string(num_threads)) + " " + std::string(std::to_string(disk_PQ)) + " " + - std::string(std::to_string(append_reorder_data)) + " " + - std::string(std::to_string(build_PQ)) + " " + std::string(std::to_string(QD)); + std::string params = std::string(std::to_string(R)) + " " + + std::string(std::to_string(L)) + " " + + std::string(std::to_string(B)) + " " + + std::string(std::to_string(M)) + " " + + std::string(std::to_string(num_threads)) + " " + + std::string(std::to_string(disk_PQ)) + " " + + std::string(std::to_string(append_reorder_data)) + " " + + std::string(std::to_string(build_PQ)) + " " + + std::string(std::to_string(QD)); - try - { - if (label_file != "" && label_type == "ushort") - { - if (data_type == std::string("int8")) - return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), - metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); - else if (data_type == std::string("uint8")) - return diskann::build_disk_index( - data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, - use_filters, label_file, universal_label, filter_threshold, Lf); - else if (data_type == std::string("float")) - return diskann::build_disk_index( - data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, - use_filters, label_file, universal_label, filter_threshold, Lf); - else - { - diskann::cerr << "Error. Unsupported data type" << std::endl; - return -1; - } - } - else - { - if (data_type == std::string("int8")) - return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), - metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); - else if (data_type == std::string("uint8")) - return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), - metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); - else if (data_type == std::string("float")) - return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), - metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); - else - { - diskann::cerr << "Error. Unsupported data type" << std::endl; - return -1; - } - } - } - catch (const std::exception &e) - { - std::cout << std::string(e.what()) << std::endl; - diskann::cerr << "Index build failed." 
<< std::endl; + try { + if (label_file != "" && label_type == "ushort") { + if (data_type == std::string("int8")) + return diskann::build_disk_index( + data_path.c_str(), index_path_prefix.c_str(), params.c_str(), + metric, use_opq, codebook_prefix, use_filters, label_file, + universal_label, filter_threshold, Lf); + else if (data_type == std::string("uint8")) + return diskann::build_disk_index( + data_path.c_str(), index_path_prefix.c_str(), params.c_str(), + metric, use_opq, codebook_prefix, use_filters, label_file, + universal_label, filter_threshold, Lf); + else if (data_type == std::string("float")) + return diskann::build_disk_index( + data_path.c_str(), index_path_prefix.c_str(), params.c_str(), + metric, use_opq, codebook_prefix, use_filters, label_file, + universal_label, filter_threshold, Lf); + else { + diskann::cerr << "Error. Unsupported data type" << std::endl; + return -1; + } + } else { + if (data_type == std::string("int8")) + return diskann::build_disk_index( + data_path.c_str(), index_path_prefix.c_str(), params.c_str(), + metric, use_opq, codebook_prefix, use_filters, label_file, + universal_label, filter_threshold, Lf); + else if (data_type == std::string("uint8")) + return diskann::build_disk_index( + data_path.c_str(), index_path_prefix.c_str(), params.c_str(), + metric, use_opq, codebook_prefix, use_filters, label_file, + universal_label, filter_threshold, Lf); + else if (data_type == std::string("float")) + return diskann::build_disk_index( + data_path.c_str(), index_path_prefix.c_str(), params.c_str(), + metric, use_opq, codebook_prefix, use_filters, label_file, + universal_label, filter_threshold, Lf); + else { + diskann::cerr << "Error. Unsupported data type" << std::endl; return -1; + } } + } catch (const std::exception &e) { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Index build failed." << std::endl; + return -1; + } } diff --git a/apps/build_memory_index.cpp b/apps/build_memory_index.cpp index 544e42dee..0efd75281 100644 --- a/apps/build_memory_index.cpp +++ b/apps/build_memory_index.cpp @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
-#include -#include #include +#include +#include #include "index.h" -#include "utils.h" #include "program_options_utils.hpp" +#include "utils.h" #ifndef _WINDOWS #include @@ -16,149 +16,155 @@ #include #endif -#include "memory_mapper.h" #include "ann_exception.h" #include "index_factory.h" +#include "memory_mapper.h" namespace po = boost::program_options; -int main(int argc, char **argv) -{ - std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type; - uint32_t num_threads, R, L, Lf, build_PQ_bytes; - float alpha; - bool use_pq_build, use_opq; - - po::options_description desc{ - program_options_utils::make_program_description("build_memory_index", "Build a memory-based DiskANN index.")}; - try - { - desc.add_options()("help,h", "Print information on arguments"); - - // Required parameters - po::options_description required_configs("Required"); - required_configs.add_options()("data_type", po::value(&data_type)->required(), - program_options_utils::DATA_TYPE_DESCRIPTION); - required_configs.add_options()("dist_fn", po::value(&dist_fn)->required(), - program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); - required_configs.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), - program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); - required_configs.add_options()("data_path", po::value(&data_path)->required(), - program_options_utils::INPUT_DATA_PATH); - - // Optional parameters - po::options_description optional_configs("Optional"); - optional_configs.add_options()("num_threads,T", - po::value(&num_threads)->default_value(omp_get_num_procs()), - program_options_utils::NUMBER_THREADS_DESCRIPTION); - optional_configs.add_options()("max_degree,R", po::value(&R)->default_value(64), - program_options_utils::MAX_BUILD_DEGREE); - optional_configs.add_options()("Lbuild,L", po::value(&L)->default_value(100), - program_options_utils::GRAPH_BUILD_COMPLEXITY); - optional_configs.add_options()("alpha", po::value(&alpha)->default_value(1.2f), - program_options_utils::GRAPH_BUILD_ALPHA); - optional_configs.add_options()("build_PQ_bytes", po::value(&build_PQ_bytes)->default_value(0), - program_options_utils::BUIlD_GRAPH_PQ_BYTES); - optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false), - program_options_utils::USE_OPQ); - optional_configs.add_options()("label_file", po::value(&label_file)->default_value(""), - program_options_utils::LABEL_FILE); - optional_configs.add_options()("universal_label", po::value(&universal_label)->default_value(""), - program_options_utils::UNIVERSAL_LABEL); - - optional_configs.add_options()("FilteredLbuild", po::value(&Lf)->default_value(0), - program_options_utils::FILTERED_LBUILD); - optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), - program_options_utils::LABEL_TYPE_DESCRIPTION); - - // Merge required and optional parameters - desc.add(required_configs).add(optional_configs); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) - { - std::cout << desc; - return 0; - } - po::notify(vm); - use_pq_build = (build_PQ_bytes > 0); - use_opq = vm["use_opq"].as(); +int main(int argc, char **argv) { + std::string data_type, dist_fn, data_path, index_path_prefix, label_file, + universal_label, label_type; + uint32_t num_threads, R, L, Lf, build_PQ_bytes; + float alpha; + bool use_pq_build, use_opq; + + po::options_description desc{program_options_utils::make_program_description( + 
"build_memory_index", "Build a memory-based DiskANN index.")}; + try { + desc.add_options()("help,h", "Print information on arguments"); + + // Required parameters + po::options_description required_configs("Required"); + required_configs.add_options()( + "data_type", po::value(&data_type)->required(), + program_options_utils::DATA_TYPE_DESCRIPTION); + required_configs.add_options()( + "dist_fn", po::value(&dist_fn)->required(), + program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); + required_configs.add_options()( + "index_path_prefix", + po::value(&index_path_prefix)->required(), + program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); + required_configs.add_options()( + "data_path", po::value(&data_path)->required(), + program_options_utils::INPUT_DATA_PATH); + + // Optional parameters + po::options_description optional_configs("Optional"); + optional_configs.add_options()( + "num_threads,T", + po::value(&num_threads)->default_value(omp_get_num_procs()), + program_options_utils::NUMBER_THREADS_DESCRIPTION); + optional_configs.add_options()("max_degree,R", + po::value(&R)->default_value(64), + program_options_utils::MAX_BUILD_DEGREE); + optional_configs.add_options()( + "Lbuild,L", po::value(&L)->default_value(100), + program_options_utils::GRAPH_BUILD_COMPLEXITY); + optional_configs.add_options()( + "alpha", po::value(&alpha)->default_value(1.2f), + program_options_utils::GRAPH_BUILD_ALPHA); + optional_configs.add_options()( + "build_PQ_bytes", + po::value(&build_PQ_bytes)->default_value(0), + program_options_utils::BUIlD_GRAPH_PQ_BYTES); + optional_configs.add_options()("use_opq", + po::bool_switch()->default_value(false), + program_options_utils::USE_OPQ); + optional_configs.add_options()( + "label_file", po::value(&label_file)->default_value(""), + program_options_utils::LABEL_FILE); + optional_configs.add_options()( + "universal_label", + po::value(&universal_label)->default_value(""), + program_options_utils::UNIVERSAL_LABEL); + + optional_configs.add_options()("FilteredLbuild", + po::value(&Lf)->default_value(0), + program_options_utils::FILTERED_LBUILD); + optional_configs.add_options()( + "label_type", + po::value(&label_type)->default_value("uint"), + program_options_utils::LABEL_TYPE_DESCRIPTION); + + // Merge required and optional parameters + desc.add(required_configs).add(optional_configs); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) { + std::cout << desc; + return 0; } - catch (const std::exception &ex) - { - std::cerr << ex.what() << '\n'; - return -1; - } - - diskann::Metric metric; - if (dist_fn == std::string("mips")) - { - metric = diskann::Metric::INNER_PRODUCT; - } - else if (dist_fn == std::string("l2")) - { - metric = diskann::Metric::L2; - } - else if (dist_fn == std::string("cosine")) - { - metric = diskann::Metric::COSINE; - } - else - { - std::cout << "Unsupported distance function. Currently only L2/ Inner " - "Product/Cosine are supported." + po::notify(vm); + use_pq_build = (build_PQ_bytes > 0); + use_opq = vm["use_opq"].as(); + } catch (const std::exception &ex) { + std::cerr << ex.what() << '\n'; + return -1; + } + + diskann::Metric metric; + if (dist_fn == std::string("mips")) { + metric = diskann::Metric::INNER_PRODUCT; + } else if (dist_fn == std::string("l2")) { + metric = diskann::Metric::L2; + } else if (dist_fn == std::string("cosine")) { + metric = diskann::Metric::COSINE; + } else { + std::cout << "Unsupported distance function. 
Currently only L2/ Inner " + "Product/Cosine are supported." + << std::endl; + return -1; + } + + try { + diskann::cout << "Starting index build with R: " << R << " Lbuild: " << L + << " alpha: " << alpha << " #threads: " << num_threads << std::endl; - return -1; - } - try - { - diskann::cout << "Starting index build with R: " << R << " Lbuild: " << L << " alpha: " << alpha - << " #threads: " << num_threads << std::endl; - - size_t data_num, data_dim; - diskann::get_bin_metadata(data_path, data_num, data_dim); - - auto index_build_params = diskann::IndexWriteParametersBuilder(L, R) - .with_filter_list_size(Lf) - .with_alpha(alpha) - .with_saturate_graph(false) - .with_num_threads(num_threads) - .build(); - - auto filter_params = diskann::IndexFilterParamsBuilder() - .with_universal_label(universal_label) - .with_label_file(label_file) - .with_save_path_prefix(index_path_prefix) - .build(); - auto config = diskann::IndexConfigBuilder() - .with_metric(metric) - .with_dimension(data_dim) - .with_max_points(data_num) - .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY) - .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY) - .with_data_type(data_type) - .with_label_type(label_type) - .is_dynamic_index(false) - .with_index_write_params(index_build_params) - .is_enable_tags(false) - .is_use_opq(use_opq) - .is_pq_dist_build(use_pq_build) - .with_num_pq_chunks(build_PQ_bytes) - .build(); - - auto index_factory = diskann::IndexFactory(config); - auto index = index_factory.create_instance(); - index->build(data_path, data_num, filter_params); - index->save(index_path_prefix.c_str()); - index.reset(); - return 0; - } - catch (const std::exception &e) - { - std::cout << std::string(e.what()) << std::endl; - diskann::cerr << "Index build failed." << std::endl; - return -1; - } + size_t data_num, data_dim; + diskann::get_bin_metadata(data_path, data_num, data_dim); + + auto index_build_params = diskann::IndexWriteParametersBuilder(L, R) + .with_filter_list_size(Lf) + .with_alpha(alpha) + .with_saturate_graph(false) + .with_num_threads(num_threads) + .build(); + + auto filter_params = diskann::IndexFilterParamsBuilder() + .with_universal_label(universal_label) + .with_label_file(label_file) + .with_save_path_prefix(index_path_prefix) + .build(); + auto config = + diskann::IndexConfigBuilder() + .with_metric(metric) + .with_dimension(data_dim) + .with_max_points(data_num) + .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY) + .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY) + .with_data_type(data_type) + .with_label_type(label_type) + .is_dynamic_index(false) + .with_index_write_params(index_build_params) + .is_enable_tags(false) + .is_use_opq(use_opq) + .is_pq_dist_build(use_pq_build) + .with_num_pq_chunks(build_PQ_bytes) + .build(); + + auto index_factory = diskann::IndexFactory(config); + auto index = index_factory.create_instance(); + index->build(data_path, data_num, filter_params); + index->save(index_path_prefix.c_str()); + index.reset(); + return 0; + } catch (const std::exception &e) { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Index build failed." << std::endl; + return -1; + } } diff --git a/apps/build_stitched_index.cpp b/apps/build_stitched_index.cpp index 60e38c1be..9b09a062b 100644 --- a/apps/build_stitched_index.cpp +++ b/apps/build_stitched_index.cpp @@ -1,15 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
+#include "filter_utils.h" #include #include #include #include +#include #include #include #include -#include "filter_utils.h" -#include #ifndef _WINDOWS #include #endif @@ -17,33 +17,32 @@ #include "index.h" #include "memory_mapper.h" #include "parameters.h" -#include "utils.h" #include "program_options_utils.hpp" +#include "utils.h" namespace po = boost::program_options; -typedef std::tuple>, uint64_t> stitch_indices_return_values; +typedef std::tuple>, uint64_t> + stitch_indices_return_values; /* * Inline function to display progress bar. */ -inline void print_progress(double percentage) -{ - int val = (int)(percentage * 100); - int lpad = (int)(percentage * PBWIDTH); - int rpad = PBWIDTH - lpad; - printf("\r%3d%% [%.*s%*s]", val, lpad, PBSTR, rpad, ""); - fflush(stdout); +inline void print_progress(double percentage) { + int val = (int)(percentage * 100); + int lpad = (int)(percentage * PBWIDTH); + int rpad = PBWIDTH - lpad; + printf("\r%3d%% [%.*s%*s]", val, lpad, PBSTR, rpad, ""); + fflush(stdout); } /* * Inline function to generate a random integer in a range. */ -inline size_t random(size_t range_from, size_t range_to) -{ - std::random_device rand_dev; - std::mt19937 generator(rand_dev()); - std::uniform_int_distribution distr(range_from, range_to); - return distr(generator); +inline size_t random(size_t range_from, size_t range_to) { + std::random_device rand_dev; + std::mt19937 generator(rand_dev()); + std::uniform_int_distribution distr(range_from, range_to); + return distr(generator); } /* @@ -51,61 +50,70 @@ inline size_t random(size_t range_from, size_t range_to) * * Arguments are merely the inputs from the command line. */ -void handle_args(int argc, char **argv, std::string &data_type, path &input_data_path, path &final_index_path_prefix, - path &label_data_path, std::string &universal_label, uint32_t &num_threads, uint32_t &R, uint32_t &L, - uint32_t &stitched_R, float &alpha) -{ - po::options_description desc{ - program_options_utils::make_program_description("build_stitched_index", "Build a stitched DiskANN index.")}; - try - { - desc.add_options()("help,h", "Print information on arguments"); - - // Required parameters - po::options_description required_configs("Required"); - required_configs.add_options()("data_type", po::value(&data_type)->required(), - program_options_utils::DATA_TYPE_DESCRIPTION); - required_configs.add_options()("index_path_prefix", - po::value(&final_index_path_prefix)->required(), - program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); - required_configs.add_options()("data_path", po::value(&input_data_path)->required(), - program_options_utils::INPUT_DATA_PATH); - - // Optional parameters - po::options_description optional_configs("Optional"); - optional_configs.add_options()("num_threads,T", - po::value(&num_threads)->default_value(omp_get_num_procs()), - program_options_utils::NUMBER_THREADS_DESCRIPTION); - optional_configs.add_options()("max_degree,R", po::value(&R)->default_value(64), - program_options_utils::MAX_BUILD_DEGREE); - optional_configs.add_options()("Lbuild,L", po::value(&L)->default_value(100), - program_options_utils::GRAPH_BUILD_COMPLEXITY); - optional_configs.add_options()("alpha", po::value(&alpha)->default_value(1.2f), - program_options_utils::GRAPH_BUILD_ALPHA); - optional_configs.add_options()("label_file", po::value(&label_data_path)->default_value(""), - program_options_utils::LABEL_FILE); - optional_configs.add_options()("universal_label", po::value(&universal_label)->default_value(""), - 
program_options_utils::UNIVERSAL_LABEL); - optional_configs.add_options()("stitched_R", po::value(&stitched_R)->default_value(100), - "Degree to prune final graph down to"); - - // Merge required and optional parameters - desc.add(required_configs).add(optional_configs); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) - { - std::cout << desc; - exit(0); - } - po::notify(vm); - } - catch (const std::exception &ex) - { - std::cerr << ex.what() << '\n'; - throw; +void handle_args(int argc, char **argv, std::string &data_type, + path &input_data_path, path &final_index_path_prefix, + path &label_data_path, std::string &universal_label, + uint32_t &num_threads, uint32_t &R, uint32_t &L, + uint32_t &stitched_R, float &alpha) { + po::options_description desc{program_options_utils::make_program_description( + "build_stitched_index", "Build a stitched DiskANN index.")}; + try { + desc.add_options()("help,h", "Print information on arguments"); + + // Required parameters + po::options_description required_configs("Required"); + required_configs.add_options()( + "data_type", po::value(&data_type)->required(), + program_options_utils::DATA_TYPE_DESCRIPTION); + required_configs.add_options()( + "index_path_prefix", + po::value(&final_index_path_prefix)->required(), + program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); + required_configs.add_options()( + "data_path", po::value(&input_data_path)->required(), + program_options_utils::INPUT_DATA_PATH); + + // Optional parameters + po::options_description optional_configs("Optional"); + optional_configs.add_options()( + "num_threads,T", + po::value(&num_threads)->default_value(omp_get_num_procs()), + program_options_utils::NUMBER_THREADS_DESCRIPTION); + optional_configs.add_options()("max_degree,R", + po::value(&R)->default_value(64), + program_options_utils::MAX_BUILD_DEGREE); + optional_configs.add_options()( + "Lbuild,L", po::value(&L)->default_value(100), + program_options_utils::GRAPH_BUILD_COMPLEXITY); + optional_configs.add_options()( + "alpha", po::value(&alpha)->default_value(1.2f), + program_options_utils::GRAPH_BUILD_ALPHA); + optional_configs.add_options()( + "label_file", + po::value(&label_data_path)->default_value(""), + program_options_utils::LABEL_FILE); + optional_configs.add_options()( + "universal_label", + po::value(&universal_label)->default_value(""), + program_options_utils::UNIVERSAL_LABEL); + optional_configs.add_options()( + "stitched_R", po::value(&stitched_R)->default_value(100), + "Degree to prune final graph down to"); + + // Merge required and optional parameters + desc.add(required_configs).add(optional_configs); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) { + std::cout << desc; + exit(0); } + po::notify(vm); + } catch (const std::exception &ex) { + std::cerr << ex.what() << '\n'; + throw; + } } /* @@ -116,98 +124,110 @@ void handle_args(int argc, char **argv, std::string &data_type, path &input_data * 3. data (redundant for static indices) * 4. labels (redundant for static indices) */ -void save_full_index(path final_index_path_prefix, path input_data_path, uint64_t final_index_size, +void save_full_index(path final_index_path_prefix, path input_data_path, + uint64_t final_index_size, std::vector> stitched_graph, - tsl::robin_map entry_points, std::string universal_label, - path label_data_path) -{ - // aux. 
file 1 - auto saving_index_timer = std::chrono::high_resolution_clock::now(); - std::ifstream original_label_data_stream; - original_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit); - original_label_data_stream.open(label_data_path, std::ios::binary); - std::ofstream new_label_data_stream; - new_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit); - new_label_data_stream.open(final_index_path_prefix + "_labels.txt", std::ios::binary); - new_label_data_stream << original_label_data_stream.rdbuf(); - original_label_data_stream.close(); - new_label_data_stream.close(); - - // aux. file 2 - std::ifstream original_input_data_stream; - original_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit); - original_input_data_stream.open(input_data_path, std::ios::binary); - std::ofstream new_input_data_stream; - new_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit); - new_input_data_stream.open(final_index_path_prefix + ".data", std::ios::binary); - new_input_data_stream << original_input_data_stream.rdbuf(); - original_input_data_stream.close(); - new_input_data_stream.close(); - - // aux. file 3 - std::ofstream labels_to_medoids_writer; - labels_to_medoids_writer.exceptions(std::ios::badbit | std::ios::failbit); - labels_to_medoids_writer.open(final_index_path_prefix + "_labels_to_medoids.txt"); - for (auto iter : entry_points) - labels_to_medoids_writer << iter.first << ", " << iter.second << std::endl; - labels_to_medoids_writer.close(); - - // aux. file 4 (only if we're using a universal label) - if (universal_label != "") - { - std::ofstream universal_label_writer; - universal_label_writer.exceptions(std::ios::badbit | std::ios::failbit); - universal_label_writer.open(final_index_path_prefix + "_universal_label.txt"); - universal_label_writer << universal_label << std::endl; - universal_label_writer.close(); - } - - // main index - uint64_t index_num_frozen_points = 0, index_num_edges = 0; - uint32_t index_max_observed_degree = 0, index_entry_point = 0; - const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t); - for (auto &point_neighbors : stitched_graph) - { - index_max_observed_degree = std::max(index_max_observed_degree, (uint32_t)point_neighbors.size()); + tsl::robin_map entry_points, + std::string universal_label, path label_data_path) { + // aux. file 1 + auto saving_index_timer = std::chrono::high_resolution_clock::now(); + std::ifstream original_label_data_stream; + original_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit); + original_label_data_stream.open(label_data_path, std::ios::binary); + std::ofstream new_label_data_stream; + new_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit); + new_label_data_stream.open(final_index_path_prefix + "_labels.txt", + std::ios::binary); + new_label_data_stream << original_label_data_stream.rdbuf(); + original_label_data_stream.close(); + new_label_data_stream.close(); + + // aux. file 2 + std::ifstream original_input_data_stream; + original_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit); + original_input_data_stream.open(input_data_path, std::ios::binary); + std::ofstream new_input_data_stream; + new_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit); + new_input_data_stream.open(final_index_path_prefix + ".data", + std::ios::binary); + new_input_data_stream << original_input_data_stream.rdbuf(); + original_input_data_stream.close(); + new_input_data_stream.close(); + + // aux. 
file 3 + std::ofstream labels_to_medoids_writer; + labels_to_medoids_writer.exceptions(std::ios::badbit | std::ios::failbit); + labels_to_medoids_writer.open(final_index_path_prefix + + "_labels_to_medoids.txt"); + for (auto iter : entry_points) + labels_to_medoids_writer << iter.first << ", " << iter.second << std::endl; + labels_to_medoids_writer.close(); + + // aux. file 4 (only if we're using a universal label) + if (universal_label != "") { + std::ofstream universal_label_writer; + universal_label_writer.exceptions(std::ios::badbit | std::ios::failbit); + universal_label_writer.open(final_index_path_prefix + + "_universal_label.txt"); + universal_label_writer << universal_label << std::endl; + universal_label_writer.close(); + } + + // main index + uint64_t index_num_frozen_points = 0, index_num_edges = 0; + uint32_t index_max_observed_degree = 0, index_entry_point = 0; + const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t); + for (auto &point_neighbors : stitched_graph) { + index_max_observed_degree = + std::max(index_max_observed_degree, (uint32_t)point_neighbors.size()); + } + + std::ofstream stitched_graph_writer; + stitched_graph_writer.exceptions(std::ios::badbit | std::ios::failbit); + stitched_graph_writer.open(final_index_path_prefix, std::ios_base::binary); + + stitched_graph_writer.write((char *)&final_index_size, sizeof(uint64_t)); + stitched_graph_writer.write((char *)&index_max_observed_degree, + sizeof(uint32_t)); + stitched_graph_writer.write((char *)&index_entry_point, sizeof(uint32_t)); + stitched_graph_writer.write((char *)&index_num_frozen_points, + sizeof(uint64_t)); + + size_t bytes_written = METADATA; + for (uint32_t node_point = 0; node_point < stitched_graph.size(); + node_point++) { + uint32_t current_node_num_neighbors = + (uint32_t)stitched_graph[node_point].size(); + std::vector current_node_neighbors = stitched_graph[node_point]; + stitched_graph_writer.write((char *)¤t_node_num_neighbors, + sizeof(uint32_t)); + bytes_written += sizeof(uint32_t); + for (const auto ¤t_node_neighbor : current_node_neighbors) { + stitched_graph_writer.write((char *)¤t_node_neighbor, + sizeof(uint32_t)); + bytes_written += sizeof(uint32_t); } + index_num_edges += current_node_num_neighbors; + } - std::ofstream stitched_graph_writer; - stitched_graph_writer.exceptions(std::ios::badbit | std::ios::failbit); - stitched_graph_writer.open(final_index_path_prefix, std::ios_base::binary); - - stitched_graph_writer.write((char *)&final_index_size, sizeof(uint64_t)); - stitched_graph_writer.write((char *)&index_max_observed_degree, sizeof(uint32_t)); - stitched_graph_writer.write((char *)&index_entry_point, sizeof(uint32_t)); - stitched_graph_writer.write((char *)&index_num_frozen_points, sizeof(uint64_t)); - - size_t bytes_written = METADATA; - for (uint32_t node_point = 0; node_point < stitched_graph.size(); node_point++) - { - uint32_t current_node_num_neighbors = (uint32_t)stitched_graph[node_point].size(); - std::vector current_node_neighbors = stitched_graph[node_point]; - stitched_graph_writer.write((char *)¤t_node_num_neighbors, sizeof(uint32_t)); - bytes_written += sizeof(uint32_t); - for (const auto ¤t_node_neighbor : current_node_neighbors) - { - stitched_graph_writer.write((char *)¤t_node_neighbor, sizeof(uint32_t)); - bytes_written += sizeof(uint32_t); - } - index_num_edges += current_node_num_neighbors; - } - - if (bytes_written != final_index_size) - { - std::cerr << "Error: written bytes does not match allocated space" << std::endl; - throw; - } - - 
stitched_graph_writer.close(); - - std::chrono::duration saving_index_time = std::chrono::high_resolution_clock::now() - saving_index_timer; - std::cout << "Stitched graph written in " << saving_index_time.count() << " seconds" << std::endl; - std::cout << "Stitched graph average degree: " << ((float)index_num_edges) / ((float)(stitched_graph.size())) + if (bytes_written != final_index_size) { + std::cerr << "Error: written bytes does not match allocated space" << std::endl; - std::cout << "Stitched graph max degree: " << index_max_observed_degree << std::endl << std::endl; + throw; + } + + stitched_graph_writer.close(); + + std::chrono::duration saving_index_time = + std::chrono::high_resolution_clock::now() - saving_index_timer; + std::cout << "Stitched graph written in " << saving_index_time.count() + << " seconds" << std::endl; + std::cout << "Stitched graph average degree: " + << ((float)index_num_edges) / ((float)(stitched_graph.size())) + << std::endl; + std::cout << "Stitched graph max degree: " << index_max_observed_degree + << std::endl + << std::endl; } /* @@ -218,52 +238,55 @@ void save_full_index(path final_index_path_prefix, path input_data_path, uint64_ */ template stitch_indices_return_values stitch_label_indices( - path final_index_path_prefix, uint32_t total_number_of_points, label_set all_labels, + path final_index_path_prefix, uint32_t total_number_of_points, + label_set all_labels, tsl::robin_map labels_to_number_of_points, tsl::robin_map &label_entry_points, - tsl::robin_map> label_id_to_orig_id_map) -{ - size_t final_index_size = 0; - std::vector> stitched_graph(total_number_of_points); - - auto stitching_index_timer = std::chrono::high_resolution_clock::now(); - for (const auto &lbl : all_labels) - { - path curr_label_index_path(final_index_path_prefix + "_" + lbl); - std::vector> curr_label_index; - uint64_t curr_label_index_size; - uint32_t curr_label_entry_point; - - std::tie(curr_label_index, curr_label_index_size) = - diskann::load_label_index(curr_label_index_path, labels_to_number_of_points[lbl]); - curr_label_entry_point = (uint32_t)random(0, curr_label_index.size()); - label_entry_points[lbl] = label_id_to_orig_id_map[lbl][curr_label_entry_point]; - - for (uint32_t node_point = 0; node_point < curr_label_index.size(); node_point++) - { - uint32_t original_point_id = label_id_to_orig_id_map[lbl][node_point]; - for (auto &node_neighbor : curr_label_index[node_point]) - { - uint32_t original_neighbor_id = label_id_to_orig_id_map[lbl][node_neighbor]; - std::vector curr_point_neighbors = stitched_graph[original_point_id]; - if (std::find(curr_point_neighbors.begin(), curr_point_neighbors.end(), original_neighbor_id) == - curr_point_neighbors.end()) - { - stitched_graph[original_point_id].push_back(original_neighbor_id); - final_index_size += sizeof(uint32_t); - } - } + tsl::robin_map> + label_id_to_orig_id_map) { + size_t final_index_size = 0; + std::vector> stitched_graph(total_number_of_points); + + auto stitching_index_timer = std::chrono::high_resolution_clock::now(); + for (const auto &lbl : all_labels) { + path curr_label_index_path(final_index_path_prefix + "_" + lbl); + std::vector> curr_label_index; + uint64_t curr_label_index_size; + uint32_t curr_label_entry_point; + + std::tie(curr_label_index, curr_label_index_size) = + diskann::load_label_index(curr_label_index_path, + labels_to_number_of_points[lbl]); + curr_label_entry_point = (uint32_t)random(0, curr_label_index.size()); + label_entry_points[lbl] = + 
label_id_to_orig_id_map[lbl][curr_label_entry_point]; + + for (uint32_t node_point = 0; node_point < curr_label_index.size(); + node_point++) { + uint32_t original_point_id = label_id_to_orig_id_map[lbl][node_point]; + for (auto &node_neighbor : curr_label_index[node_point]) { + uint32_t original_neighbor_id = + label_id_to_orig_id_map[lbl][node_neighbor]; + std::vector curr_point_neighbors = + stitched_graph[original_point_id]; + if (std::find(curr_point_neighbors.begin(), curr_point_neighbors.end(), + original_neighbor_id) == curr_point_neighbors.end()) { + stitched_graph[original_point_id].push_back(original_neighbor_id); + final_index_size += sizeof(uint32_t); } + } } + } - const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t); - final_index_size += (total_number_of_points * sizeof(uint32_t) + METADATA); + const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t); + final_index_size += (total_number_of_points * sizeof(uint32_t) + METADATA); - std::chrono::duration stitching_index_time = - std::chrono::high_resolution_clock::now() - stitching_index_timer; - std::cout << "stitched graph generated in memory in " << stitching_index_time.count() << " seconds" << std::endl; + std::chrono::duration stitching_index_time = + std::chrono::high_resolution_clock::now() - stitching_index_timer; + std::cout << "stitched graph generated in memory in " + << stitching_index_time.count() << " seconds" << std::endl; - return std::make_tuple(stitched_graph, final_index_size); + return std::make_tuple(stitched_graph, final_index_size); } /* @@ -274,33 +297,39 @@ stitch_indices_return_values stitch_label_indices( * and pruned graph. */ template -void prune_and_save(path final_index_path_prefix, path full_index_path_prefix, path input_data_path, - std::vector> stitched_graph, uint32_t stitched_R, - tsl::robin_map label_entry_points, std::string universal_label, - path label_data_path, uint32_t num_threads) -{ - size_t dimension, number_of_label_points; - auto diskann_cout_buffer = diskann::cout.rdbuf(nullptr); - auto std_cout_buffer = std::cout.rdbuf(nullptr); - auto pruning_index_timer = std::chrono::high_resolution_clock::now(); - - diskann::get_bin_metadata(input_data_path, number_of_label_points, dimension); - - diskann::Index index(diskann::Metric::L2, dimension, number_of_label_points, nullptr, nullptr, 0, false, false, - false, false, 0, false); - - // not searching this index, set search_l to 0 - index.load(full_index_path_prefix.c_str(), num_threads, 1); - - std::cout << "parsing labels" << std::endl; - - index.prune_all_neighbors(stitched_R, 750, 1.2); - index.save((final_index_path_prefix).c_str()); - - diskann::cout.rdbuf(diskann_cout_buffer); - std::cout.rdbuf(std_cout_buffer); - std::chrono::duration pruning_index_time = std::chrono::high_resolution_clock::now() - pruning_index_timer; - std::cout << "pruning performed in " << pruning_index_time.count() << " seconds\n" << std::endl; +void prune_and_save(path final_index_path_prefix, path full_index_path_prefix, + path input_data_path, + std::vector> stitched_graph, + uint32_t stitched_R, + tsl::robin_map label_entry_points, + std::string universal_label, path label_data_path, + uint32_t num_threads) { + size_t dimension, number_of_label_points; + auto diskann_cout_buffer = diskann::cout.rdbuf(nullptr); + auto std_cout_buffer = std::cout.rdbuf(nullptr); + auto pruning_index_timer = std::chrono::high_resolution_clock::now(); + + diskann::get_bin_metadata(input_data_path, number_of_label_points, dimension); + + 
diskann::Index index(diskann::Metric::L2, dimension, + number_of_label_points, nullptr, nullptr, 0, false, + false, false, false, 0, false); + + // not searching this index, set search_l to 0 + index.load(full_index_path_prefix.c_str(), num_threads, 1); + + std::cout << "parsing labels" << std::endl; + + index.prune_all_neighbors(stitched_R, 750, 1.2); + index.save((final_index_path_prefix).c_str()); + + diskann::cout.rdbuf(diskann_cout_buffer); + std::cout.rdbuf(std_cout_buffer); + std::chrono::duration pruning_index_time = + std::chrono::high_resolution_clock::now() - pruning_index_timer; + std::cout << "pruning performed in " << pruning_index_time.count() + << " seconds\n" + << std::endl; } /* @@ -311,131 +340,160 @@ void prune_and_save(path final_index_path_prefix, path full_index_path_prefix, p * 2. the separate diskANN indices built for each label * 3. the '.data' file created while generating the indices */ -void clean_up_artifacts(path input_data_path, path final_index_path_prefix, label_set all_labels) -{ - for (const auto &lbl : all_labels) - { - path curr_label_input_data_path(input_data_path + "_" + lbl); - path curr_label_index_path(final_index_path_prefix + "_" + lbl); - path curr_label_index_path_data(curr_label_index_path + ".data"); - - if (std::remove(curr_label_index_path.c_str()) != 0) - throw; - if (std::remove(curr_label_input_data_path.c_str()) != 0) - throw; - if (std::remove(curr_label_index_path_data.c_str()) != 0) - throw; - } +void clean_up_artifacts(path input_data_path, path final_index_path_prefix, + label_set all_labels) { + for (const auto &lbl : all_labels) { + path curr_label_input_data_path(input_data_path + "_" + lbl); + path curr_label_index_path(final_index_path_prefix + "_" + lbl); + path curr_label_index_path_data(curr_label_index_path + ".data"); + + if (std::remove(curr_label_index_path.c_str()) != 0) + throw; + if (std::remove(curr_label_input_data_path.c_str()) != 0) + throw; + if (std::remove(curr_label_index_path_data.c_str()) != 0) + throw; + } } -int main(int argc, char **argv) -{ - // 1. handle cmdline inputs - std::string data_type; - path input_data_path, final_index_path_prefix, label_data_path; - std::string universal_label; - uint32_t num_threads, R, L, stitched_R; - float alpha; +int main(int argc, char **argv) { + // 1. handle cmdline inputs + std::string data_type; + path input_data_path, final_index_path_prefix, label_data_path; + std::string universal_label; + uint32_t num_threads, R, L, stitched_R; + float alpha; - auto index_timer = std::chrono::high_resolution_clock::now(); - handle_args(argc, argv, data_type, input_data_path, final_index_path_prefix, label_data_path, universal_label, - num_threads, R, L, stitched_R, alpha); + auto index_timer = std::chrono::high_resolution_clock::now(); + handle_args(argc, argv, data_type, input_data_path, final_index_path_prefix, + label_data_path, universal_label, num_threads, R, L, stitched_R, + alpha); - path labels_file_to_use = final_index_path_prefix + "_label_formatted.txt"; - path labels_map_file = final_index_path_prefix + "_labels_map.txt"; + path labels_file_to_use = final_index_path_prefix + "_label_formatted.txt"; + path labels_map_file = final_index_path_prefix + "_labels_map.txt"; - convert_labels_string_to_int(label_data_path, labels_file_to_use, labels_map_file, universal_label); + convert_labels_string_to_int(label_data_path, labels_file_to_use, + labels_map_file, universal_label); - // 2. 
parse label file and create necessary data structures - std::vector point_ids_to_labels; - tsl::robin_map labels_to_number_of_points; - label_set all_labels; + // 2. parse label file and create necessary data structures + std::vector point_ids_to_labels; + tsl::robin_map labels_to_number_of_points; + label_set all_labels; - std::tie(point_ids_to_labels, labels_to_number_of_points, all_labels) = - diskann::parse_label_file(labels_file_to_use, universal_label); + std::tie(point_ids_to_labels, labels_to_number_of_points, all_labels) = + diskann::parse_label_file(labels_file_to_use, universal_label); - // 3. for each label, make a separate data file - tsl::robin_map> label_id_to_orig_id_map; - uint32_t total_number_of_points = (uint32_t)point_ids_to_labels.size(); + // 3. for each label, make a separate data file + tsl::robin_map> label_id_to_orig_id_map; + uint32_t total_number_of_points = (uint32_t)point_ids_to_labels.size(); #ifndef _WINDOWS - if (data_type == "uint8") - label_id_to_orig_id_map = diskann::generate_label_specific_vector_files( - input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); - else if (data_type == "int8") - label_id_to_orig_id_map = diskann::generate_label_specific_vector_files( - input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); - else if (data_type == "float") - label_id_to_orig_id_map = diskann::generate_label_specific_vector_files( - input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); - else - throw; + if (data_type == "uint8") + label_id_to_orig_id_map = + diskann::generate_label_specific_vector_files( + input_data_path, labels_to_number_of_points, point_ids_to_labels, + all_labels); + else if (data_type == "int8") + label_id_to_orig_id_map = + diskann::generate_label_specific_vector_files( + input_data_path, labels_to_number_of_points, point_ids_to_labels, + all_labels); + else if (data_type == "float") + label_id_to_orig_id_map = + diskann::generate_label_specific_vector_files( + input_data_path, labels_to_number_of_points, point_ids_to_labels, + all_labels); + else + throw; #else - if (data_type == "uint8") - label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat( - input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); - else if (data_type == "int8") - label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat( - input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); - else if (data_type == "float") - label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat( - input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); - else - throw; + if (data_type == "uint8") + label_id_to_orig_id_map = + diskann::generate_label_specific_vector_files_compat( + input_data_path, labels_to_number_of_points, point_ids_to_labels, + all_labels); + else if (data_type == "int8") + label_id_to_orig_id_map = + diskann::generate_label_specific_vector_files_compat( + input_data_path, labels_to_number_of_points, point_ids_to_labels, + all_labels); + else if (data_type == "float") + label_id_to_orig_id_map = + diskann::generate_label_specific_vector_files_compat( + input_data_path, labels_to_number_of_points, point_ids_to_labels, + all_labels); + else + throw; #endif - // 4. 
for each created data file, create a vanilla diskANN index - if (data_type == "uint8") - diskann::generate_label_indices(input_data_path, final_index_path_prefix, all_labels, R, L, alpha, - num_threads); - else if (data_type == "int8") - diskann::generate_label_indices(input_data_path, final_index_path_prefix, all_labels, R, L, alpha, - num_threads); - else if (data_type == "float") - diskann::generate_label_indices(input_data_path, final_index_path_prefix, all_labels, R, L, alpha, - num_threads); - else - throw; - - // 5. "stitch" the indices together - std::vector> stitched_graph; - tsl::robin_map label_entry_points; - uint64_t stitched_graph_size; - - if (data_type == "uint8") - std::tie(stitched_graph, stitched_graph_size) = - stitch_label_indices(final_index_path_prefix, total_number_of_points, all_labels, - labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map); - else if (data_type == "int8") - std::tie(stitched_graph, stitched_graph_size) = - stitch_label_indices(final_index_path_prefix, total_number_of_points, all_labels, - labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map); - else if (data_type == "float") - std::tie(stitched_graph, stitched_graph_size) = - stitch_label_indices(final_index_path_prefix, total_number_of_points, all_labels, - labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map); - else - throw; - path full_index_path_prefix = final_index_path_prefix + "_full"; - // 5a. save the stitched graph to disk - save_full_index(full_index_path_prefix, input_data_path, stitched_graph_size, stitched_graph, label_entry_points, - universal_label, labels_file_to_use); - - // 6. run a prune on the stitched index, and save to disk - if (data_type == "uint8") - prune_and_save(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph, - stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads); - else if (data_type == "int8") - prune_and_save(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph, - stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads); - else if (data_type == "float") - prune_and_save(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph, - stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads); - else - throw; - - std::chrono::duration index_time = std::chrono::high_resolution_clock::now() - index_timer; - std::cout << "pruned/stitched graph generated in " << index_time.count() << " seconds" << std::endl; - - clean_up_artifacts(input_data_path, final_index_path_prefix, all_labels); + // 4. for each created data file, create a vanilla diskANN index + if (data_type == "uint8") + diskann::generate_label_indices( + input_data_path, final_index_path_prefix, all_labels, R, L, alpha, + num_threads); + else if (data_type == "int8") + diskann::generate_label_indices(input_data_path, + final_index_path_prefix, all_labels, + R, L, alpha, num_threads); + else if (data_type == "float") + diskann::generate_label_indices(input_data_path, + final_index_path_prefix, all_labels, + R, L, alpha, num_threads); + else + throw; + + // 5. 
"stitch" the indices together + std::vector> stitched_graph; + tsl::robin_map label_entry_points; + uint64_t stitched_graph_size; + + if (data_type == "uint8") + std::tie(stitched_graph, stitched_graph_size) = + stitch_label_indices( + final_index_path_prefix, total_number_of_points, all_labels, + labels_to_number_of_points, label_entry_points, + label_id_to_orig_id_map); + else if (data_type == "int8") + std::tie(stitched_graph, stitched_graph_size) = + stitch_label_indices( + final_index_path_prefix, total_number_of_points, all_labels, + labels_to_number_of_points, label_entry_points, + label_id_to_orig_id_map); + else if (data_type == "float") + std::tie(stitched_graph, stitched_graph_size) = stitch_label_indices( + final_index_path_prefix, total_number_of_points, all_labels, + labels_to_number_of_points, label_entry_points, + label_id_to_orig_id_map); + else + throw; + path full_index_path_prefix = final_index_path_prefix + "_full"; + // 5a. save the stitched graph to disk + save_full_index(full_index_path_prefix, input_data_path, stitched_graph_size, + stitched_graph, label_entry_points, universal_label, + labels_file_to_use); + + // 6. run a prune on the stitched index, and save to disk + if (data_type == "uint8") + prune_and_save(final_index_path_prefix, full_index_path_prefix, + input_data_path, stitched_graph, stitched_R, + label_entry_points, universal_label, + labels_file_to_use, num_threads); + else if (data_type == "int8") + prune_and_save(final_index_path_prefix, full_index_path_prefix, + input_data_path, stitched_graph, stitched_R, + label_entry_points, universal_label, + labels_file_to_use, num_threads); + else if (data_type == "float") + prune_and_save(final_index_path_prefix, full_index_path_prefix, + input_data_path, stitched_graph, stitched_R, + label_entry_points, universal_label, + labels_file_to_use, num_threads); + else + throw; + + std::chrono::duration index_time = + std::chrono::high_resolution_clock::now() - index_timer; + std::cout << "pruned/stitched graph generated in " << index_time.count() + << " seconds" << std::endl; + + clean_up_artifacts(input_data_path, final_index_path_prefix, all_labels); } diff --git a/apps/range_search_disk_index.cpp b/apps/range_search_disk_index.cpp index 31675724b..bfdfafa03 100644 --- a/apps/range_search_disk_index.cpp +++ b/apps/range_search_disk_index.cpp @@ -2,26 +2,26 @@ // Licensed under the MIT license. 
#include +#include #include #include #include #include -#include -#include "index.h" #include "disk_utils.h" +#include "index.h" #include "math_utils.h" #include "memory_mapper.h" -#include "pq_flash_index.h" #include "partition.h" -#include "timer.h" +#include "pq_flash_index.h" #include "program_options_utils.hpp" +#include "timer.h" #ifndef _WINDOWS +#include "linux_aligned_file_reader.h" #include #include #include -#include "linux_aligned_file_reader.h" #else #ifdef USE_BING_INFRA #include "bing_aligned_file_reader.h" @@ -34,346 +34,348 @@ namespace po = boost::program_options; #define WARMUP false -void print_stats(std::string category, std::vector percentiles, std::vector results) -{ - diskann::cout << std::setw(20) << category << ": " << std::flush; - for (uint32_t s = 0; s < percentiles.size(); s++) - { - diskann::cout << std::setw(8) << percentiles[s] << "%"; - } - diskann::cout << std::endl; - diskann::cout << std::setw(22) << " " << std::flush; - for (uint32_t s = 0; s < percentiles.size(); s++) - { - diskann::cout << std::setw(9) << results[s]; - } - diskann::cout << std::endl; +void print_stats(std::string category, std::vector percentiles, + std::vector results) { + diskann::cout << std::setw(20) << category << ": " << std::flush; + for (uint32_t s = 0; s < percentiles.size(); s++) { + diskann::cout << std::setw(8) << percentiles[s] << "%"; + } + diskann::cout << std::endl; + diskann::cout << std::setw(22) << " " << std::flush; + for (uint32_t s = 0; s < percentiles.size(); s++) { + diskann::cout << std::setw(9) << results[s]; + } + diskann::cout << std::endl; } template -int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix, const std::string &query_file, - std::string >_file, const uint32_t num_threads, const float search_range, - const uint32_t beamwidth, const uint32_t num_nodes_to_cache, const std::vector &Lvec) -{ - std::string pq_prefix = index_path_prefix + "_pq"; - std::string disk_index_file = index_path_prefix + "_disk.index"; - std::string warmup_query_file = index_path_prefix + "_sample_data.bin"; - - diskann::cout << "Search parameters: #threads: " << num_threads << ", "; - if (beamwidth <= 0) - diskann::cout << "beamwidth to be optimized for each L value" << std::endl; - else - diskann::cout << " beamwidth: " << beamwidth << std::endl; - - // load query bin - T *query = nullptr; - std::vector> groundtruth_ids; - size_t query_num, query_dim, query_aligned_dim, gt_num; - diskann::load_aligned_bin(query_file, query, query_num, query_dim, query_aligned_dim); - - bool calc_recall_flag = false; - if (gt_file != std::string("null") && file_exists(gt_file)) - { - diskann::load_range_truthset(gt_file, groundtruth_ids, - gt_num); // use for range search type of truthset - // diskann::prune_truthset_for_range(gt_file, search_range, - // groundtruth_ids, gt_num); // use for traditional truthset - if (gt_num != query_num) - { - diskann::cout << "Error. 
Mismatch in number of queries and ground truth data" << std::endl; - return -1; - } - calc_recall_flag = true; +int search_disk_index(diskann::Metric &metric, + const std::string &index_path_prefix, + const std::string &query_file, std::string >_file, + const uint32_t num_threads, const float search_range, + const uint32_t beamwidth, + const uint32_t num_nodes_to_cache, + const std::vector &Lvec) { + std::string pq_prefix = index_path_prefix + "_pq"; + std::string disk_index_file = index_path_prefix + "_disk.index"; + std::string warmup_query_file = index_path_prefix + "_sample_data.bin"; + + diskann::cout << "Search parameters: #threads: " << num_threads << ", "; + if (beamwidth <= 0) + diskann::cout << "beamwidth to be optimized for each L value" << std::endl; + else + diskann::cout << " beamwidth: " << beamwidth << std::endl; + + // load query bin + T *query = nullptr; + std::vector> groundtruth_ids; + size_t query_num, query_dim, query_aligned_dim, gt_num; + diskann::load_aligned_bin(query_file, query, query_num, query_dim, + query_aligned_dim); + + bool calc_recall_flag = false; + if (gt_file != std::string("null") && file_exists(gt_file)) { + diskann::load_range_truthset( + gt_file, groundtruth_ids, + gt_num); // use for range search type of truthset + // diskann::prune_truthset_for_range(gt_file, search_range, + // groundtruth_ids, gt_num); // use for traditional truthset + if (gt_num != query_num) { + diskann::cout + << "Error. Mismatch in number of queries and ground truth data" + << std::endl; + return -1; } + calc_recall_flag = true; + } - std::shared_ptr reader = nullptr; + std::shared_ptr reader = nullptr; #ifdef _WINDOWS #ifndef USE_BING_INFRA - reader.reset(new WindowsAlignedFileReader()); + reader.reset(new WindowsAlignedFileReader()); #else - reader.reset(new diskann::BingAlignedFileReader()); + reader.reset(new diskann::BingAlignedFileReader()); #endif #else - reader.reset(new LinuxAlignedFileReader()); + reader.reset(new LinuxAlignedFileReader()); #endif - std::unique_ptr> _pFlashIndex( - new diskann::PQFlashIndex(reader, metric)); - - int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str()); - - if (res != 0) - { - return res; - } - // cache bfs levels - std::vector node_list; - diskann::cout << "Caching " << num_nodes_to_cache << " BFS nodes around medoid(s)" << std::endl; - _pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list); - // _pFlashIndex->generate_cache_list_from_sample_queries( - // warmup_query_file, 15, 6, num_nodes_to_cache, num_threads, - // node_list); - _pFlashIndex->load_cache_list(node_list); - node_list.clear(); - node_list.shrink_to_fit(); - - omp_set_num_threads(num_threads); - - uint64_t warmup_L = 20; - uint64_t warmup_num = 0, warmup_dim = 0, warmup_aligned_dim = 0; - T *warmup = nullptr; - - if (WARMUP) - { - if (file_exists(warmup_query_file)) - { - diskann::load_aligned_bin(warmup_query_file, warmup, warmup_num, warmup_dim, warmup_aligned_dim); - } - else - { - warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads); - warmup_dim = query_dim; - warmup_aligned_dim = query_aligned_dim; - diskann::alloc_aligned(((void **)&warmup), warmup_num * warmup_aligned_dim * sizeof(T), 8 * sizeof(T)); - std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T)); - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution<> dis(-128, 127); - for (uint32_t i = 0; i < warmup_num; i++) - { - for (uint32_t d = 0; d < warmup_dim; d++) - { - warmup[i * warmup_aligned_dim + d] = (T)dis(gen); - } 
- } - } - diskann::cout << "Warming up index... " << std::flush; - std::vector warmup_result_ids_64(warmup_num, 0); - std::vector warmup_result_dists(warmup_num, 0); - -#pragma omp parallel for schedule(dynamic, 1) - for (int64_t i = 0; i < (int64_t)warmup_num; i++) - { - _pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, warmup_L, - warmup_result_ids_64.data() + (i * 1), - warmup_result_dists.data() + (i * 1), 4); + std::unique_ptr> _pFlashIndex( + new diskann::PQFlashIndex(reader, metric)); + + int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str()); + + if (res != 0) { + return res; + } + // cache bfs levels + std::vector node_list; + diskann::cout << "Caching " << num_nodes_to_cache + << " BFS nodes around medoid(s)" << std::endl; + _pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list); + // _pFlashIndex->generate_cache_list_from_sample_queries( + // warmup_query_file, 15, 6, num_nodes_to_cache, num_threads, + // node_list); + _pFlashIndex->load_cache_list(node_list); + node_list.clear(); + node_list.shrink_to_fit(); + + omp_set_num_threads(num_threads); + + uint64_t warmup_L = 20; + uint64_t warmup_num = 0, warmup_dim = 0, warmup_aligned_dim = 0; + T *warmup = nullptr; + + if (WARMUP) { + if (file_exists(warmup_query_file)) { + diskann::load_aligned_bin(warmup_query_file, warmup, warmup_num, + warmup_dim, warmup_aligned_dim); + } else { + warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads); + warmup_dim = query_dim; + warmup_aligned_dim = query_aligned_dim; + diskann::alloc_aligned(((void **)&warmup), + warmup_num * warmup_aligned_dim * sizeof(T), + 8 * sizeof(T)); + std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T)); + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(-128, 127); + for (uint32_t i = 0; i < warmup_num; i++) { + for (uint32_t d = 0; d < warmup_dim; d++) { + warmup[i * warmup_aligned_dim + d] = (T)dis(gen); } - diskann::cout << "..done" << std::endl; + } } + diskann::cout << "Warming up index... 
" << std::flush; + std::vector warmup_result_ids_64(warmup_num, 0); + std::vector warmup_result_dists(warmup_num, 0); - diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield); - diskann::cout.precision(2); - - std::string recall_string = "Recall@rng=" + std::to_string(search_range); - diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" << std::setw(16) << "QPS" << std::setw(16) - << "Mean Latency" << std::setw(16) << "99.9 Latency" << std::setw(16) << "Mean IOs" << std::setw(16) - << "CPU (s)"; - if (calc_recall_flag) - { - diskann::cout << std::setw(16) << recall_string << std::endl; +#pragma omp parallel for schedule(dynamic, 1) + for (int64_t i = 0; i < (int64_t)warmup_num; i++) { + _pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, + warmup_L, + warmup_result_ids_64.data() + (i * 1), + warmup_result_dists.data() + (i * 1), 4); } - else - diskann::cout << std::endl; - diskann::cout << "===============================================================" - "===========================================" - << std::endl; + diskann::cout << "..done" << std::endl; + } + + diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield); + diskann::cout.precision(2); + + std::string recall_string = "Recall@rng=" + std::to_string(search_range); + diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" + << std::setw(16) << "QPS" << std::setw(16) << "Mean Latency" + << std::setw(16) << "99.9 Latency" << std::setw(16) + << "Mean IOs" << std::setw(16) << "CPU (s)"; + if (calc_recall_flag) { + diskann::cout << std::setw(16) << recall_string << std::endl; + } else + diskann::cout << std::endl; + diskann::cout + << "===============================================================" + "===========================================" + << std::endl; - std::vector>> query_result_ids(Lvec.size()); + std::vector>> query_result_ids(Lvec.size()); - uint32_t optimized_beamwidth = 2; - uint32_t max_list_size = 10000; + uint32_t optimized_beamwidth = 2; + uint32_t max_list_size = 10000; - for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++) - { - uint32_t L = Lvec[test_id]; + for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++) { + uint32_t L = Lvec[test_id]; - if (beamwidth <= 0) - { - optimized_beamwidth = - optimize_beamwidth(_pFlashIndex, warmup, warmup_num, warmup_aligned_dim, L, optimized_beamwidth); - } - else - optimized_beamwidth = beamwidth; + if (beamwidth <= 0) { + optimized_beamwidth = + optimize_beamwidth(_pFlashIndex, warmup, warmup_num, + warmup_aligned_dim, L, optimized_beamwidth); + } else + optimized_beamwidth = beamwidth; - query_result_ids[test_id].clear(); - query_result_ids[test_id].resize(query_num); + query_result_ids[test_id].clear(); + query_result_ids[test_id].resize(query_num); - diskann::QueryStats *stats = new diskann::QueryStats[query_num]; + diskann::QueryStats *stats = new diskann::QueryStats[query_num]; - auto s = std::chrono::high_resolution_clock::now(); + auto s = std::chrono::high_resolution_clock::now(); #pragma omp parallel for schedule(dynamic, 1) - for (int64_t i = 0; i < (int64_t)query_num; i++) - { - std::vector indices; - std::vector distances; - uint32_t res_count = - _pFlashIndex->range_search(query + (i * query_aligned_dim), search_range, L, max_list_size, indices, - distances, optimized_beamwidth, stats + i); - query_result_ids[test_id][i].reserve(res_count); - query_result_ids[test_id][i].resize(res_count); - for (uint32_t idx = 0; idx < res_count; idx++) - 
query_result_ids[test_id][i][idx] = (uint32_t)indices[idx]; - } - auto e = std::chrono::high_resolution_clock::now(); - std::chrono::duration diff = e - s; - auto qps = (1.0 * query_num) / (1.0 * diff.count()); - - auto mean_latency = diskann::get_mean_stats( - stats, query_num, [](const diskann::QueryStats &stats) { return stats.total_us; }); - - auto latency_999 = diskann::get_percentile_stats( - stats, query_num, 0.999, [](const diskann::QueryStats &stats) { return stats.total_us; }); - - auto mean_ios = diskann::get_mean_stats(stats, query_num, - [](const diskann::QueryStats &stats) { return stats.n_ios; }); - - double mean_cpuus = diskann::get_mean_stats( - stats, query_num, [](const diskann::QueryStats &stats) { return stats.cpu_us; }); - - double recall = 0; - double ratio_of_sums = 0; - if (calc_recall_flag) - { - recall = - diskann::calculate_range_search_recall((uint32_t)query_num, groundtruth_ids, query_result_ids[test_id]); - - uint32_t total_true_positive = 0; - uint32_t total_positive = 0; - for (uint32_t i = 0; i < query_num; i++) - { - total_true_positive += (uint32_t)query_result_ids[test_id][i].size(); - total_positive += (uint32_t)groundtruth_ids[i].size(); - } - - ratio_of_sums = (1.0 * total_true_positive) / (1.0 * total_positive); - } - - diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth << std::setw(16) << qps - << std::setw(16) << mean_latency << std::setw(16) << latency_999 << std::setw(16) << mean_ios - << std::setw(16) << mean_cpuus; - if (calc_recall_flag) - { - diskann::cout << std::setw(16) << recall << "," << ratio_of_sums << std::endl; - } - else - diskann::cout << std::endl; - } - - diskann::cout << "Done searching. " << std::endl; - - diskann::aligned_free(query); - if (warmup != nullptr) - diskann::aligned_free(warmup); - return 0; -} - -int main(int argc, char **argv) -{ - std::string data_type, dist_fn, index_path_prefix, result_path_prefix, query_file, gt_file; - uint32_t num_threads, W, num_nodes_to_cache; - std::vector Lvec; - float range; - - po::options_description desc{program_options_utils::make_program_description( - "range_search_disk_index", "Searches disk DiskANN indexes using ranges")}; - try - { - desc.add_options()("help,h", "Print information on arguments"); - - // Required parameters - po::options_description required_configs("Required"); - required_configs.add_options()("data_type", po::value(&data_type)->required(), - program_options_utils::DATA_TYPE_DESCRIPTION); - required_configs.add_options()("dist_fn", po::value(&dist_fn)->required(), - program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); - required_configs.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), - program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); - required_configs.add_options()("query_file", po::value(&query_file)->required(), - program_options_utils::QUERY_FILE_DESCRIPTION); - required_configs.add_options()("search_list,L", - po::value>(&Lvec)->multitoken()->required(), - program_options_utils::SEARCH_LIST_DESCRIPTION); - required_configs.add_options()("range_threshold,K", po::value(&range)->required(), - "Number of neighbors to be returned"); - - // Optional parameters - po::options_description optional_configs("Optional"); - optional_configs.add_options()("num_threads,T", - po::value(&num_threads)->default_value(omp_get_num_procs()), - program_options_utils::NUMBER_THREADS_DESCRIPTION); - optional_configs.add_options()("gt_file", po::value(>_file)->default_value(std::string("null")), - 
program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION); - optional_configs.add_options()("num_nodes_to_cache", po::value(&num_nodes_to_cache)->default_value(0), - program_options_utils::NUMBER_OF_NODES_TO_CACHE); - optional_configs.add_options()("beamwidth,W", po::value(&W)->default_value(2), - program_options_utils::BEAMWIDTH); - - // Merge required and optional parameters - desc.add(required_configs).add(optional_configs); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) - { - std::cout << desc; - return 0; - } - po::notify(vm); + for (int64_t i = 0; i < (int64_t)query_num; i++) { + std::vector indices; + std::vector distances; + uint32_t res_count = _pFlashIndex->range_search( + query + (i * query_aligned_dim), search_range, L, max_list_size, + indices, distances, optimized_beamwidth, stats + i); + query_result_ids[test_id][i].reserve(res_count); + query_result_ids[test_id][i].resize(res_count); + for (uint32_t idx = 0; idx < res_count; idx++) + query_result_ids[test_id][i][idx] = (uint32_t)indices[idx]; } - catch (const std::exception &ex) - { - std::cerr << ex.what() << '\n'; - return -1; + auto e = std::chrono::high_resolution_clock::now(); + std::chrono::duration diff = e - s; + auto qps = (1.0 * query_num) / (1.0 * diff.count()); + + auto mean_latency = diskann::get_mean_stats( + stats, query_num, + [](const diskann::QueryStats &stats) { return stats.total_us; }); + + auto latency_999 = diskann::get_percentile_stats( + stats, query_num, 0.999, + [](const diskann::QueryStats &stats) { return stats.total_us; }); + + auto mean_ios = diskann::get_mean_stats( + stats, query_num, + [](const diskann::QueryStats &stats) { return stats.n_ios; }); + + double mean_cpuus = diskann::get_mean_stats( + stats, query_num, + [](const diskann::QueryStats &stats) { return stats.cpu_us; }); + + double recall = 0; + double ratio_of_sums = 0; + if (calc_recall_flag) { + recall = diskann::calculate_range_search_recall( + (uint32_t)query_num, groundtruth_ids, query_result_ids[test_id]); + + uint32_t total_true_positive = 0; + uint32_t total_positive = 0; + for (uint32_t i = 0; i < query_num; i++) { + total_true_positive += (uint32_t)query_result_ids[test_id][i].size(); + total_positive += (uint32_t)groundtruth_ids[i].size(); + } + + ratio_of_sums = (1.0 * total_true_positive) / (1.0 * total_positive); } - diskann::Metric metric; - if (dist_fn == std::string("mips")) - { - metric = diskann::Metric::INNER_PRODUCT; - } - else if (dist_fn == std::string("l2")) - { - metric = diskann::Metric::L2; - } - else if (dist_fn == std::string("cosine")) - { - metric = diskann::Metric::COSINE; - } - else - { - std::cout << "Unsupported distance function. Currently only L2/ Inner " - "Product/Cosine are supported." - << std::endl; - return -1; - } - - if ((data_type != std::string("float")) && (metric == diskann::Metric::INNER_PRODUCT)) - { - std::cout << "Currently support only floating point data for Inner Product." << std::endl; - return -1; - } + diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth + << std::setw(16) << qps << std::setw(16) << mean_latency + << std::setw(16) << latency_999 << std::setw(16) << mean_ios + << std::setw(16) << mean_cpuus; + if (calc_recall_flag) { + diskann::cout << std::setw(16) << recall << "," << ratio_of_sums + << std::endl; + } else + diskann::cout << std::endl; + } + + diskann::cout << "Done searching. 
" << std::endl; + + diskann::aligned_free(query); + if (warmup != nullptr) + diskann::aligned_free(warmup); + return 0; +} - try - { - if (data_type == std::string("float")) - return search_disk_index(metric, index_path_prefix, query_file, gt_file, num_threads, range, W, - num_nodes_to_cache, Lvec); - else if (data_type == std::string("int8")) - return search_disk_index(metric, index_path_prefix, query_file, gt_file, num_threads, range, W, - num_nodes_to_cache, Lvec); - else if (data_type == std::string("uint8")) - return search_disk_index(metric, index_path_prefix, query_file, gt_file, num_threads, range, W, - num_nodes_to_cache, Lvec); - else - { - std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl; - return -1; - } +int main(int argc, char **argv) { + std::string data_type, dist_fn, index_path_prefix, result_path_prefix, + query_file, gt_file; + uint32_t num_threads, W, num_nodes_to_cache; + std::vector Lvec; + float range; + + po::options_description desc{program_options_utils::make_program_description( + "range_search_disk_index", "Searches disk DiskANN indexes using ranges")}; + try { + desc.add_options()("help,h", "Print information on arguments"); + + // Required parameters + po::options_description required_configs("Required"); + required_configs.add_options()( + "data_type", po::value(&data_type)->required(), + program_options_utils::DATA_TYPE_DESCRIPTION); + required_configs.add_options()( + "dist_fn", po::value(&dist_fn)->required(), + program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); + required_configs.add_options()( + "index_path_prefix", + po::value(&index_path_prefix)->required(), + program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); + required_configs.add_options()( + "query_file", po::value(&query_file)->required(), + program_options_utils::QUERY_FILE_DESCRIPTION); + required_configs.add_options()( + "search_list,L", + po::value>(&Lvec)->multitoken()->required(), + program_options_utils::SEARCH_LIST_DESCRIPTION); + required_configs.add_options()("range_threshold,K", + po::value(&range)->required(), + "Number of neighbors to be returned"); + + // Optional parameters + po::options_description optional_configs("Optional"); + optional_configs.add_options()( + "num_threads,T", + po::value(&num_threads)->default_value(omp_get_num_procs()), + program_options_utils::NUMBER_THREADS_DESCRIPTION); + optional_configs.add_options()( + "gt_file", + po::value(>_file)->default_value(std::string("null")), + program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION); + optional_configs.add_options()( + "num_nodes_to_cache", + po::value(&num_nodes_to_cache)->default_value(0), + program_options_utils::NUMBER_OF_NODES_TO_CACHE); + optional_configs.add_options()("beamwidth,W", + po::value(&W)->default_value(2), + program_options_utils::BEAMWIDTH); + + // Merge required and optional parameters + desc.add(required_configs).add(optional_configs); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) { + std::cout << desc; + return 0; } - catch (const std::exception &e) - { - std::cout << std::string(e.what()) << std::endl; - diskann::cerr << "Index search failed." 
<< std::endl; - return -1; + po::notify(vm); + } catch (const std::exception &ex) { + std::cerr << ex.what() << '\n'; + return -1; + } + + diskann::Metric metric; + if (dist_fn == std::string("mips")) { + metric = diskann::Metric::INNER_PRODUCT; + } else if (dist_fn == std::string("l2")) { + metric = diskann::Metric::L2; + } else if (dist_fn == std::string("cosine")) { + metric = diskann::Metric::COSINE; + } else { + std::cout << "Unsupported distance function. Currently only L2/ Inner " + "Product/Cosine are supported." + << std::endl; + return -1; + } + + if ((data_type != std::string("float")) && + (metric == diskann::Metric::INNER_PRODUCT)) { + std::cout << "Currently support only floating point data for Inner Product." + << std::endl; + return -1; + } + + try { + if (data_type == std::string("float")) + return search_disk_index(metric, index_path_prefix, query_file, + gt_file, num_threads, range, W, + num_nodes_to_cache, Lvec); + else if (data_type == std::string("int8")) + return search_disk_index(metric, index_path_prefix, query_file, + gt_file, num_threads, range, W, + num_nodes_to_cache, Lvec); + else if (data_type == std::string("uint8")) + return search_disk_index(metric, index_path_prefix, query_file, + gt_file, num_threads, range, W, + num_nodes_to_cache, Lvec); + else { + std::cerr << "Unsupported data type. Use float or int8 or uint8" + << std::endl; + return -1; } + } catch (const std::exception &e) { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Index search failed." << std::endl; + return -1; + } } diff --git a/apps/search_disk_index.cpp b/apps/search_disk_index.cpp index 7e2a7ac6d..307fa8e7a 100644 --- a/apps/search_disk_index.cpp +++ b/apps/search_disk_index.cpp @@ -4,21 +4,21 @@ #include "common_includes.h" #include -#include "index.h" #include "disk_utils.h" +#include "index.h" #include "math_utils.h" #include "memory_mapper.h" #include "partition.h" -#include "pq_flash_index.h" -#include "timer.h" #include "percentile_stats.h" +#include "pq_flash_index.h" #include "program_options_utils.hpp" +#include "timer.h" #ifndef _WINDOWS +#include "linux_aligned_file_reader.h" #include #include #include -#include "linux_aligned_file_reader.h" #else #ifdef USE_BING_INFRA #include "bing_aligned_file_reader.h" @@ -31,466 +31,476 @@ namespace po = boost::program_options; -void print_stats(std::string category, std::vector percentiles, std::vector results) -{ - diskann::cout << std::setw(20) << category << ": " << std::flush; - for (uint32_t s = 0; s < percentiles.size(); s++) - { - diskann::cout << std::setw(8) << percentiles[s] << "%"; - } - diskann::cout << std::endl; - diskann::cout << std::setw(22) << " " << std::flush; - for (uint32_t s = 0; s < percentiles.size(); s++) - { - diskann::cout << std::setw(9) << results[s]; - } - diskann::cout << std::endl; +void print_stats(std::string category, std::vector percentiles, + std::vector results) { + diskann::cout << std::setw(20) << category << ": " << std::flush; + for (uint32_t s = 0; s < percentiles.size(); s++) { + diskann::cout << std::setw(8) << percentiles[s] << "%"; + } + diskann::cout << std::endl; + diskann::cout << std::setw(22) << " " << std::flush; + for (uint32_t s = 0; s < percentiles.size(); s++) { + diskann::cout << std::setw(9) << results[s]; + } + diskann::cout << std::endl; } template -int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix, - const std::string &result_output_prefix, const std::string &query_file, std::string >_file, - const uint32_t 
num_threads, const uint32_t recall_at, const uint32_t beamwidth, - const uint32_t num_nodes_to_cache, const uint32_t search_io_limit, - const std::vector &Lvec, const float fail_if_recall_below, - const std::vector &query_filters, const bool use_reorder_data = false) -{ - diskann::cout << "Search parameters: #threads: " << num_threads << ", "; - if (beamwidth <= 0) - diskann::cout << "beamwidth to be optimized for each L value" << std::flush; - else - diskann::cout << " beamwidth: " << beamwidth << std::flush; - if (search_io_limit == std::numeric_limits::max()) - diskann::cout << "." << std::endl; - else - diskann::cout << ", io_limit: " << search_io_limit << "." << std::endl; - - std::string warmup_query_file = index_path_prefix + "_sample_data.bin"; - - // load query bin - T *query = nullptr; - uint32_t *gt_ids = nullptr; - float *gt_dists = nullptr; - size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim; - diskann::load_aligned_bin(query_file, query, query_num, query_dim, query_aligned_dim); - - bool filtered_search = false; - if (!query_filters.empty()) - { - filtered_search = true; - if (query_filters.size() != 1 && query_filters.size() != query_num) - { - std::cout << "Error. Mismatch in number of queries and size of query " - "filters file" - << std::endl; - return -1; // To return -1 or some other error handling? - } +int search_disk_index( + diskann::Metric &metric, const std::string &index_path_prefix, + const std::string &result_output_prefix, const std::string &query_file, + std::string >_file, const uint32_t num_threads, const uint32_t recall_at, + const uint32_t beamwidth, const uint32_t num_nodes_to_cache, + const uint32_t search_io_limit, const std::vector &Lvec, + const float fail_if_recall_below, + const std::vector &query_filters, + const bool use_reorder_data = false) { + diskann::cout << "Search parameters: #threads: " << num_threads << ", "; + if (beamwidth <= 0) + diskann::cout << "beamwidth to be optimized for each L value" << std::flush; + else + diskann::cout << " beamwidth: " << beamwidth << std::flush; + if (search_io_limit == std::numeric_limits::max()) + diskann::cout << "." << std::endl; + else + diskann::cout << ", io_limit: " << search_io_limit << "." << std::endl; + + std::string warmup_query_file = index_path_prefix + "_sample_data.bin"; + + // load query bin + T *query = nullptr; + uint32_t *gt_ids = nullptr; + float *gt_dists = nullptr; + size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim; + diskann::load_aligned_bin(query_file, query, query_num, query_dim, + query_aligned_dim); + + bool filtered_search = false; + if (!query_filters.empty()) { + filtered_search = true; + if (query_filters.size() != 1 && query_filters.size() != query_num) { + std::cout << "Error. Mismatch in number of queries and size of query " + "filters file" + << std::endl; + return -1; // To return -1 or some other error handling? } - - bool calc_recall_flag = false; - if (gt_file != std::string("null") && gt_file != std::string("NULL") && file_exists(gt_file)) - { - diskann::load_truthset(gt_file, gt_ids, gt_dists, gt_num, gt_dim); - if (gt_num != query_num) - { - diskann::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl; - } - calc_recall_flag = true; + } + + bool calc_recall_flag = false; + if (gt_file != std::string("null") && gt_file != std::string("NULL") && + file_exists(gt_file)) { + diskann::load_truthset(gt_file, gt_ids, gt_dists, gt_num, gt_dim); + if (gt_num != query_num) { + diskann::cout + << "Error. 
Mismatch in number of queries and ground truth data" + << std::endl; } + calc_recall_flag = true; + } - std::shared_ptr reader = nullptr; + std::shared_ptr reader = nullptr; #ifdef _WINDOWS #ifndef USE_BING_INFRA - reader.reset(new WindowsAlignedFileReader()); + reader.reset(new WindowsAlignedFileReader()); #else - reader.reset(new diskann::BingAlignedFileReader()); + reader.reset(new diskann::BingAlignedFileReader()); #endif #else - reader.reset(new LinuxAlignedFileReader()); + reader.reset(new LinuxAlignedFileReader()); #endif - std::unique_ptr> _pFlashIndex( - new diskann::PQFlashIndex(reader, metric)); - - int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str()); - - if (res != 0) - { - return res; - } - - std::vector node_list; - diskann::cout << "Caching " << num_nodes_to_cache << " nodes around medoid(s)" << std::endl; - _pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list); - // if (num_nodes_to_cache > 0) - // _pFlashIndex->generate_cache_list_from_sample_queries(warmup_query_file, 15, 6, num_nodes_to_cache, - // num_threads, node_list); - _pFlashIndex->load_cache_list(node_list); - node_list.clear(); - node_list.shrink_to_fit(); - - omp_set_num_threads(num_threads); - - uint64_t warmup_L = 20; - uint64_t warmup_num = 0, warmup_dim = 0, warmup_aligned_dim = 0; - T *warmup = nullptr; - - if (WARMUP) - { - if (file_exists(warmup_query_file)) - { - diskann::load_aligned_bin(warmup_query_file, warmup, warmup_num, warmup_dim, warmup_aligned_dim); - } - else - { - warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads); - warmup_dim = query_dim; - warmup_aligned_dim = query_aligned_dim; - diskann::alloc_aligned(((void **)&warmup), warmup_num * warmup_aligned_dim * sizeof(T), 8 * sizeof(T)); - std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T)); - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution<> dis(-128, 127); - for (uint32_t i = 0; i < warmup_num; i++) - { - for (uint32_t d = 0; d < warmup_dim; d++) - { - warmup[i * warmup_aligned_dim + d] = (T)dis(gen); - } - } - } - diskann::cout << "Warming up index... 
" << std::flush; - std::vector warmup_result_ids_64(warmup_num, 0); - std::vector warmup_result_dists(warmup_num, 0); - -#pragma omp parallel for schedule(dynamic, 1) - for (int64_t i = 0; i < (int64_t)warmup_num; i++) - { - _pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, warmup_L, - warmup_result_ids_64.data() + (i * 1), - warmup_result_dists.data() + (i * 1), 4); + std::unique_ptr> _pFlashIndex( + new diskann::PQFlashIndex(reader, metric)); + + int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str()); + + if (res != 0) { + return res; + } + + std::vector node_list; + diskann::cout << "Caching " << num_nodes_to_cache << " nodes around medoid(s)" + << std::endl; + _pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list); + // if (num_nodes_to_cache > 0) + // _pFlashIndex->generate_cache_list_from_sample_queries(warmup_query_file, + // 15, 6, num_nodes_to_cache, num_threads, node_list); + _pFlashIndex->load_cache_list(node_list); + node_list.clear(); + node_list.shrink_to_fit(); + + omp_set_num_threads(num_threads); + + uint64_t warmup_L = 20; + uint64_t warmup_num = 0, warmup_dim = 0, warmup_aligned_dim = 0; + T *warmup = nullptr; + + if (WARMUP) { + if (file_exists(warmup_query_file)) { + diskann::load_aligned_bin(warmup_query_file, warmup, warmup_num, + warmup_dim, warmup_aligned_dim); + } else { + warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads); + warmup_dim = query_dim; + warmup_aligned_dim = query_aligned_dim; + diskann::alloc_aligned(((void **)&warmup), + warmup_num * warmup_aligned_dim * sizeof(T), + 8 * sizeof(T)); + std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T)); + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(-128, 127); + for (uint32_t i = 0; i < warmup_num; i++) { + for (uint32_t d = 0; d < warmup_dim; d++) { + warmup[i * warmup_aligned_dim + d] = (T)dis(gen); } - diskann::cout << "..done" << std::endl; + } } + diskann::cout << "Warming up index... 
" << std::flush; + std::vector warmup_result_ids_64(warmup_num, 0); + std::vector warmup_result_dists(warmup_num, 0); - diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield); - diskann::cout.precision(2); - - std::string recall_string = "Recall@" + std::to_string(recall_at); - diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" << std::setw(16) << "QPS" << std::setw(16) - << "Mean Latency" << std::setw(16) << "99.9 Latency" << std::setw(16) << "Mean IOs" << std::setw(16) - << "CPU (s)"; - if (calc_recall_flag) - { - diskann::cout << std::setw(16) << recall_string << std::endl; +#pragma omp parallel for schedule(dynamic, 1) + for (int64_t i = 0; i < (int64_t)warmup_num; i++) { + _pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, + warmup_L, + warmup_result_ids_64.data() + (i * 1), + warmup_result_dists.data() + (i * 1), 4); } - else - diskann::cout << std::endl; - diskann::cout << "===============================================================" - "=======================================================" - << std::endl; + diskann::cout << "..done" << std::endl; + } + + diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield); + diskann::cout.precision(2); + + std::string recall_string = "Recall@" + std::to_string(recall_at); + diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" + << std::setw(16) << "QPS" << std::setw(16) << "Mean Latency" + << std::setw(16) << "99.9 Latency" << std::setw(16) + << "Mean IOs" << std::setw(16) << "CPU (s)"; + if (calc_recall_flag) { + diskann::cout << std::setw(16) << recall_string << std::endl; + } else + diskann::cout << std::endl; + diskann::cout + << "===============================================================" + "=======================================================" + << std::endl; - std::vector> query_result_ids(Lvec.size()); - std::vector> query_result_dists(Lvec.size()); + std::vector> query_result_ids(Lvec.size()); + std::vector> query_result_dists(Lvec.size()); - uint32_t optimized_beamwidth = 2; + uint32_t optimized_beamwidth = 2; - double best_recall = 0.0; + double best_recall = 0.0; - for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++) - { - uint32_t L = Lvec[test_id]; + for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++) { + uint32_t L = Lvec[test_id]; - if (L < recall_at) - { - diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl; - continue; - } + if (L < recall_at) { + diskann::cout << "Ignoring search with L:" << L + << " since it's smaller than K:" << recall_at << std::endl; + continue; + } - if (beamwidth <= 0) - { - diskann::cout << "Tuning beamwidth.." << std::endl; - optimized_beamwidth = - optimize_beamwidth(_pFlashIndex, warmup, warmup_num, warmup_aligned_dim, L, optimized_beamwidth); - } - else - optimized_beamwidth = beamwidth; + if (beamwidth <= 0) { + diskann::cout << "Tuning beamwidth.." 
<< std::endl; + optimized_beamwidth = + optimize_beamwidth(_pFlashIndex, warmup, warmup_num, + warmup_aligned_dim, L, optimized_beamwidth); + } else + optimized_beamwidth = beamwidth; - query_result_ids[test_id].resize(recall_at * query_num); - query_result_dists[test_id].resize(recall_at * query_num); + query_result_ids[test_id].resize(recall_at * query_num); + query_result_dists[test_id].resize(recall_at * query_num); - auto stats = new diskann::QueryStats[query_num]; + auto stats = new diskann::QueryStats[query_num]; - std::vector query_result_ids_64(recall_at * query_num); - auto s = std::chrono::high_resolution_clock::now(); + std::vector query_result_ids_64(recall_at * query_num); + auto s = std::chrono::high_resolution_clock::now(); #pragma omp parallel for schedule(dynamic, 1) - for (int64_t i = 0; i < (int64_t)query_num; i++) - { - if (!filtered_search) - { - _pFlashIndex->cached_beam_search(query + (i * query_aligned_dim), recall_at, L, - query_result_ids_64.data() + (i * recall_at), - query_result_dists[test_id].data() + (i * recall_at), - optimized_beamwidth, use_reorder_data, stats + i); - } - else - { - LabelT label_for_search; - if (query_filters.size() == 1) - { // one label for all queries - label_for_search = _pFlashIndex->get_converted_label(query_filters[0]); - } - else - { // one label for each query - label_for_search = _pFlashIndex->get_converted_label(query_filters[i]); - } - _pFlashIndex->cached_beam_search( - query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at), - query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search, - use_reorder_data, stats + i); - } - } - auto e = std::chrono::high_resolution_clock::now(); - std::chrono::duration diff = e - s; - double qps = (1.0 * query_num) / (1.0 * diff.count()); - - diskann::convert_types(query_result_ids_64.data(), query_result_ids[test_id].data(), - query_num, recall_at); - - auto mean_latency = diskann::get_mean_stats( - stats, query_num, [](const diskann::QueryStats &stats) { return stats.total_us; }); - - auto latency_999 = diskann::get_percentile_stats( - stats, query_num, 0.999, [](const diskann::QueryStats &stats) { return stats.total_us; }); - - auto mean_ios = diskann::get_mean_stats(stats, query_num, - [](const diskann::QueryStats &stats) { return stats.n_ios; }); - - auto mean_cpuus = diskann::get_mean_stats(stats, query_num, - [](const diskann::QueryStats &stats) { return stats.cpu_us; }); - - double recall = 0; - if (calc_recall_flag) - { - recall = diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim, - query_result_ids[test_id].data(), recall_at, recall_at); - best_recall = std::max(recall, best_recall); + for (int64_t i = 0; i < (int64_t)query_num; i++) { + if (!filtered_search) { + _pFlashIndex->cached_beam_search( + query + (i * query_aligned_dim), recall_at, L, + query_result_ids_64.data() + (i * recall_at), + query_result_dists[test_id].data() + (i * recall_at), + optimized_beamwidth, use_reorder_data, stats + i); + } else { + LabelT label_for_search; + if (query_filters.size() == 1) { // one label for all queries + label_for_search = + _pFlashIndex->get_converted_label(query_filters[0]); + } else { // one label for each query + label_for_search = + _pFlashIndex->get_converted_label(query_filters[i]); } - - diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth << std::setw(16) << qps - << std::setw(16) << mean_latency << std::setw(16) << latency_999 << std::setw(16) << 
mean_ios - << std::setw(16) << mean_cpuus; - if (calc_recall_flag) - { - diskann::cout << std::setw(16) << recall << std::endl; - } - else - diskann::cout << std::endl; - delete[] stats; + _pFlashIndex->cached_beam_search( + query + (i * query_aligned_dim), recall_at, L, + query_result_ids_64.data() + (i * recall_at), + query_result_dists[test_id].data() + (i * recall_at), + optimized_beamwidth, true, label_for_search, use_reorder_data, + stats + i); + } } - - diskann::cout << "Done searching. Now saving results " << std::endl; - uint64_t test_id = 0; - for (auto L : Lvec) - { - if (L < recall_at) - continue; - - std::string cur_result_path = result_output_prefix + "_" + std::to_string(L) + "_idx_uint32.bin"; - diskann::save_bin(cur_result_path, query_result_ids[test_id].data(), query_num, recall_at); - - cur_result_path = result_output_prefix + "_" + std::to_string(L) + "_dists_float.bin"; - diskann::save_bin(cur_result_path, query_result_dists[test_id++].data(), query_num, recall_at); + auto e = std::chrono::high_resolution_clock::now(); + std::chrono::duration diff = e - s; + double qps = (1.0 * query_num) / (1.0 * diff.count()); + + diskann::convert_types(query_result_ids_64.data(), + query_result_ids[test_id].data(), + query_num, recall_at); + + auto mean_latency = diskann::get_mean_stats( + stats, query_num, + [](const diskann::QueryStats &stats) { return stats.total_us; }); + + auto latency_999 = diskann::get_percentile_stats( + stats, query_num, 0.999, + [](const diskann::QueryStats &stats) { return stats.total_us; }); + + auto mean_ios = diskann::get_mean_stats( + stats, query_num, + [](const diskann::QueryStats &stats) { return stats.n_ios; }); + + auto mean_cpuus = diskann::get_mean_stats( + stats, query_num, + [](const diskann::QueryStats &stats) { return stats.cpu_us; }); + + double recall = 0; + if (calc_recall_flag) { + recall = diskann::calculate_recall( + (uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim, + query_result_ids[test_id].data(), recall_at, recall_at); + best_recall = std::max(recall, best_recall); } - diskann::aligned_free(query); - if (warmup != nullptr) - diskann::aligned_free(warmup); - return best_recall >= fail_if_recall_below ? 0 : -1; + diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth + << std::setw(16) << qps << std::setw(16) << mean_latency + << std::setw(16) << latency_999 << std::setw(16) << mean_ios + << std::setw(16) << mean_cpuus; + if (calc_recall_flag) { + diskann::cout << std::setw(16) << recall << std::endl; + } else + diskann::cout << std::endl; + delete[] stats; + } + + diskann::cout << "Done searching. Now saving results " << std::endl; + uint64_t test_id = 0; + for (auto L : Lvec) { + if (L < recall_at) + continue; + + std::string cur_result_path = + result_output_prefix + "_" + std::to_string(L) + "_idx_uint32.bin"; + diskann::save_bin(cur_result_path, + query_result_ids[test_id].data(), query_num, + recall_at); + + cur_result_path = + result_output_prefix + "_" + std::to_string(L) + "_dists_float.bin"; + diskann::save_bin(cur_result_path, + query_result_dists[test_id++].data(), query_num, + recall_at); + } + + diskann::aligned_free(query); + if (warmup != nullptr) + diskann::aligned_free(warmup); + return best_recall >= fail_if_recall_below ? 
0 : -1; } -int main(int argc, char **argv) -{ - std::string data_type, dist_fn, index_path_prefix, result_path_prefix, query_file, gt_file, filter_label, - label_type, query_filters_file; - uint32_t num_threads, K, W, num_nodes_to_cache, search_io_limit; - std::vector Lvec; - bool use_reorder_data = false; - float fail_if_recall_below = 0.0f; - - po::options_description desc{ - program_options_utils::make_program_description("search_disk_index", "Searches on-disk DiskANN indexes")}; - try - { - desc.add_options()("help,h", "Print information on arguments"); - - // Required parameters - po::options_description required_configs("Required"); - required_configs.add_options()("data_type", po::value(&data_type)->required(), - program_options_utils::DATA_TYPE_DESCRIPTION); - required_configs.add_options()("dist_fn", po::value(&dist_fn)->required(), - program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); - required_configs.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), - program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); - required_configs.add_options()("result_path", po::value(&result_path_prefix)->required(), - program_options_utils::RESULT_PATH_DESCRIPTION); - required_configs.add_options()("query_file", po::value(&query_file)->required(), - program_options_utils::QUERY_FILE_DESCRIPTION); - required_configs.add_options()("recall_at,K", po::value(&K)->required(), - program_options_utils::NUMBER_OF_RESULTS_DESCRIPTION); - required_configs.add_options()("search_list,L", - po::value>(&Lvec)->multitoken()->required(), - program_options_utils::SEARCH_LIST_DESCRIPTION); - - // Optional parameters - po::options_description optional_configs("Optional"); - optional_configs.add_options()("gt_file", po::value(>_file)->default_value(std::string("null")), - program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION); - optional_configs.add_options()("beamwidth,W", po::value(&W)->default_value(2), - program_options_utils::BEAMWIDTH); - optional_configs.add_options()("num_nodes_to_cache", po::value(&num_nodes_to_cache)->default_value(0), - program_options_utils::NUMBER_OF_NODES_TO_CACHE); - optional_configs.add_options()( - "search_io_limit", - po::value(&search_io_limit)->default_value(std::numeric_limits::max()), - "Max #IOs for search. Default value: uint32::max()"); - optional_configs.add_options()("num_threads,T", - po::value(&num_threads)->default_value(omp_get_num_procs()), - program_options_utils::NUMBER_THREADS_DESCRIPTION); - optional_configs.add_options()("use_reorder_data", po::bool_switch()->default_value(false), - "Include full precision data in the index. Use only in " - "conjuction with compressed data on SSD. 
Default value: false"); - optional_configs.add_options()("filter_label", - po::value(&filter_label)->default_value(std::string("")), - program_options_utils::FILTER_LABEL_DESCRIPTION); - optional_configs.add_options()("query_filters_file", - po::value(&query_filters_file)->default_value(std::string("")), - program_options_utils::FILTERS_FILE_DESCRIPTION); - optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), - program_options_utils::LABEL_TYPE_DESCRIPTION); - optional_configs.add_options()("fail_if_recall_below", - po::value(&fail_if_recall_below)->default_value(0.0f), - program_options_utils::FAIL_IF_RECALL_BELOW); - - // Merge required and optional parameters - desc.add(required_configs).add(optional_configs); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) - { - std::cout << desc; - return 0; - } - po::notify(vm); - if (vm["use_reorder_data"].as()) - use_reorder_data = true; - } - catch (const std::exception &ex) - { - std::cerr << ex.what() << '\n'; - return -1; - } - - diskann::Metric metric; - if (dist_fn == std::string("mips")) - { - metric = diskann::Metric::INNER_PRODUCT; - } - else if (dist_fn == std::string("l2")) - { - metric = diskann::Metric::L2; +int main(int argc, char **argv) { + std::string data_type, dist_fn, index_path_prefix, result_path_prefix, + query_file, gt_file, filter_label, label_type, query_filters_file; + uint32_t num_threads, K, W, num_nodes_to_cache, search_io_limit; + std::vector Lvec; + bool use_reorder_data = false; + float fail_if_recall_below = 0.0f; + + po::options_description desc{program_options_utils::make_program_description( + "search_disk_index", "Searches on-disk DiskANN indexes")}; + try { + desc.add_options()("help,h", "Print information on arguments"); + + // Required parameters + po::options_description required_configs("Required"); + required_configs.add_options()( + "data_type", po::value(&data_type)->required(), + program_options_utils::DATA_TYPE_DESCRIPTION); + required_configs.add_options()( + "dist_fn", po::value(&dist_fn)->required(), + program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); + required_configs.add_options()( + "index_path_prefix", + po::value(&index_path_prefix)->required(), + program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); + required_configs.add_options()( + "result_path", po::value(&result_path_prefix)->required(), + program_options_utils::RESULT_PATH_DESCRIPTION); + required_configs.add_options()( + "query_file", po::value(&query_file)->required(), + program_options_utils::QUERY_FILE_DESCRIPTION); + required_configs.add_options()( + "recall_at,K", po::value(&K)->required(), + program_options_utils::NUMBER_OF_RESULTS_DESCRIPTION); + required_configs.add_options()( + "search_list,L", + po::value>(&Lvec)->multitoken()->required(), + program_options_utils::SEARCH_LIST_DESCRIPTION); + + // Optional parameters + po::options_description optional_configs("Optional"); + optional_configs.add_options()( + "gt_file", + po::value(>_file)->default_value(std::string("null")), + program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION); + optional_configs.add_options()("beamwidth,W", + po::value(&W)->default_value(2), + program_options_utils::BEAMWIDTH); + optional_configs.add_options()( + "num_nodes_to_cache", + po::value(&num_nodes_to_cache)->default_value(0), + program_options_utils::NUMBER_OF_NODES_TO_CACHE); + optional_configs.add_options()( + "search_io_limit", + po::value(&search_io_limit) + 
->default_value(std::numeric_limits::max()), + "Max #IOs for search. Default value: uint32::max()"); + optional_configs.add_options()( + "num_threads,T", + po::value(&num_threads)->default_value(omp_get_num_procs()), + program_options_utils::NUMBER_THREADS_DESCRIPTION); + optional_configs.add_options()( + "use_reorder_data", po::bool_switch()->default_value(false), + "Include full precision data in the index. Use only in " + "conjuction with compressed data on SSD. Default value: false"); + optional_configs.add_options()( + "filter_label", + po::value(&filter_label)->default_value(std::string("")), + program_options_utils::FILTER_LABEL_DESCRIPTION); + optional_configs.add_options()( + "query_filters_file", + po::value(&query_filters_file) + ->default_value(std::string("")), + program_options_utils::FILTERS_FILE_DESCRIPTION); + optional_configs.add_options()( + "label_type", + po::value(&label_type)->default_value("uint"), + program_options_utils::LABEL_TYPE_DESCRIPTION); + optional_configs.add_options()( + "fail_if_recall_below", + po::value(&fail_if_recall_below)->default_value(0.0f), + program_options_utils::FAIL_IF_RECALL_BELOW); + + // Merge required and optional parameters + desc.add(required_configs).add(optional_configs); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) { + std::cout << desc; + return 0; } - else if (dist_fn == std::string("cosine")) - { - metric = diskann::Metric::COSINE; - } - else - { - std::cout << "Unsupported distance function. Currently only L2/ Inner " - "Product/Cosine are supported." + po::notify(vm); + if (vm["use_reorder_data"].as()) + use_reorder_data = true; + } catch (const std::exception &ex) { + std::cerr << ex.what() << '\n'; + return -1; + } + + diskann::Metric metric; + if (dist_fn == std::string("mips")) { + metric = diskann::Metric::INNER_PRODUCT; + } else if (dist_fn == std::string("l2")) { + metric = diskann::Metric::L2; + } else if (dist_fn == std::string("cosine")) { + metric = diskann::Metric::COSINE; + } else { + std::cout << "Unsupported distance function. Currently only L2/ Inner " + "Product/Cosine are supported." + << std::endl; + return -1; + } + + if ((data_type != std::string("float")) && + (metric == diskann::Metric::INNER_PRODUCT)) { + std::cout << "Currently support only floating point data for Inner Product." + << std::endl; + return -1; + } + + if (use_reorder_data && data_type != std::string("float")) { + std::cout << "Error: Reorder data for reordering currently only " + "supported for float data type." 
+ << std::endl; + return -1; + } + + if (filter_label != "" && query_filters_file != "") { + std::cerr + << "Only one of filter_label and query_filters_file should be provided" + << std::endl; + return -1; + } + + std::vector query_filters; + if (filter_label != "") { + query_filters.push_back(filter_label); + } else if (query_filters_file != "") { + query_filters = read_file_to_vector_of_strings(query_filters_file); + } + + try { + if (!query_filters.empty() && label_type == "ushort") { + if (data_type == std::string("float")) + return search_disk_index( + metric, index_path_prefix, result_path_prefix, query_file, gt_file, + num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, + fail_if_recall_below, query_filters, use_reorder_data); + else if (data_type == std::string("int8")) + return search_disk_index( + metric, index_path_prefix, result_path_prefix, query_file, gt_file, + num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, + fail_if_recall_below, query_filters, use_reorder_data); + else if (data_type == std::string("uint8")) + return search_disk_index( + metric, index_path_prefix, result_path_prefix, query_file, gt_file, + num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, + fail_if_recall_below, query_filters, use_reorder_data); + else { + std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl; return -1; - } - - if ((data_type != std::string("float")) && (metric == diskann::Metric::INNER_PRODUCT)) - { - std::cout << "Currently support only floating point data for Inner Product." << std::endl; - return -1; - } - - if (use_reorder_data && data_type != std::string("float")) - { - std::cout << "Error: Reorder data for reordering currently only " - "supported for float data type." + } + } else { + if (data_type == std::string("float")) + return search_disk_index( + metric, index_path_prefix, result_path_prefix, query_file, gt_file, + num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, + fail_if_recall_below, query_filters, use_reorder_data); + else if (data_type == std::string("int8")) + return search_disk_index( + metric, index_path_prefix, result_path_prefix, query_file, gt_file, + num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, + fail_if_recall_below, query_filters, use_reorder_data); + else if (data_type == std::string("uint8")) + return search_disk_index( + metric, index_path_prefix, result_path_prefix, query_file, gt_file, + num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, + fail_if_recall_below, query_filters, use_reorder_data); + else { + std::cerr << "Unsupported data type. 
Use float or int8 or uint8" << std::endl; return -1; + } } - - if (filter_label != "" && query_filters_file != "") - { - std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl; - return -1; - } - - std::vector query_filters; - if (filter_label != "") - { - query_filters.push_back(filter_label); - } - else if (query_filters_file != "") - { - query_filters = read_file_to_vector_of_strings(query_filters_file); - } - - try - { - if (!query_filters.empty() && label_type == "ushort") - { - if (data_type == std::string("float")) - return search_disk_index( - metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, - num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data); - else if (data_type == std::string("int8")) - return search_disk_index( - metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, - num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data); - else if (data_type == std::string("uint8")) - return search_disk_index( - metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, - num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data); - else - { - std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl; - return -1; - } - } - else - { - if (data_type == std::string("float")) - return search_disk_index(metric, index_path_prefix, result_path_prefix, query_file, gt_file, - num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, - fail_if_recall_below, query_filters, use_reorder_data); - else if (data_type == std::string("int8")) - return search_disk_index(metric, index_path_prefix, result_path_prefix, query_file, gt_file, - num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, - fail_if_recall_below, query_filters, use_reorder_data); - else if (data_type == std::string("uint8")) - return search_disk_index(metric, index_path_prefix, result_path_prefix, query_file, gt_file, - num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, - fail_if_recall_below, query_filters, use_reorder_data); - else - { - std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl; - return -1; - } - } - } - catch (const std::exception &e) - { - std::cout << std::string(e.what()) << std::endl; - diskann::cerr << "Index search failed." << std::endl; - return -1; - } + } catch (const std::exception &e) { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Index search failed." << std::endl; + return -1; + } } \ No newline at end of file diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index 1a9acc285..64c6ac26f 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -1,14 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
+#include +#include #include #include -#include #include #include #include #include -#include #ifndef _WINDOWS #include @@ -18,460 +18,454 @@ #endif #include "index.h" +#include "index_factory.h" #include "memory_mapper.h" -#include "utils.h" #include "program_options_utils.hpp" -#include "index_factory.h" +#include "utils.h" namespace po = boost::program_options; template -int search_memory_index(diskann::Metric &metric, const std::string &index_path, const std::string &result_path_prefix, - const std::string &query_file, const std::string &truthset_file, const uint32_t num_threads, - const uint32_t recall_at, const bool print_all_recalls, const std::vector &Lvec, - const bool dynamic, const bool tags, const bool show_qps_per_thread, - const std::vector &query_filters, const float fail_if_recall_below) -{ - using TagT = uint32_t; - // Load the query file - T *query = nullptr; - uint32_t *gt_ids = nullptr; - float *gt_dists = nullptr; - size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim; - diskann::load_aligned_bin(query_file, query, query_num, query_dim, query_aligned_dim); - - bool calc_recall_flag = false; - if (truthset_file != std::string("null") && file_exists(truthset_file)) - { - diskann::load_truthset(truthset_file, gt_ids, gt_dists, gt_num, gt_dim); - if (gt_num != query_num) - { - std::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl; - } - calc_recall_flag = true; - } - else - { - diskann::cout << " Truthset file " << truthset_file << " not found. Not computing recall." << std::endl; - } - - bool filtered_search = false; - if (!query_filters.empty()) - { - filtered_search = true; - if (query_filters.size() != 1 && query_filters.size() != query_num) - { - std::cout << "Error. Mismatch in number of queries and size of query " - "filters file" - << std::endl; - return -1; // To return -1 or some other error handling? - } - } - - const size_t num_frozen_pts = diskann::get_graph_num_frozen_points(index_path); - - auto config = diskann::IndexConfigBuilder() - .with_metric(metric) - .with_dimension(query_dim) - .with_max_points(0) - .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY) - .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY) - .with_data_type(diskann_type_to_name()) - .with_label_type(diskann_type_to_name()) - .with_tag_type(diskann_type_to_name()) - .is_dynamic_index(dynamic) - .is_enable_tags(tags) - .is_concurrent_consolidate(false) - .is_pq_dist_build(false) - .is_use_opq(false) - .with_num_pq_chunks(0) - .with_num_frozen_pts(num_frozen_pts) - .build(); - - auto index_factory = diskann::IndexFactory(config); - auto index = index_factory.create_instance(); - index->load(index_path.c_str(), num_threads, *(std::max_element(Lvec.begin(), Lvec.end()))); - std::cout << "Index loaded" << std::endl; - - if (metric == diskann::FAST_L2) - index->optimize_index_layout(); - - std::cout << "Using " << num_threads << " threads to search" << std::endl; - std::cout.setf(std::ios_base::fixed, std::ios_base::floatfield); - std::cout.precision(2); - const std::string qps_title = show_qps_per_thread ? 
"QPS/thread" : "QPS"; - uint32_t table_width = 0; - if (tags) - { - std::cout << std::setw(4) << "Ls" << std::setw(12) << qps_title << std::setw(20) << "Mean Latency (mus)" - << std::setw(15) << "99.9 Latency"; - table_width += 4 + 12 + 20 + 15; +int search_memory_index(diskann::Metric &metric, const std::string &index_path, + const std::string &result_path_prefix, + const std::string &query_file, + const std::string &truthset_file, + const uint32_t num_threads, const uint32_t recall_at, + const bool print_all_recalls, + const std::vector &Lvec, const bool dynamic, + const bool tags, const bool show_qps_per_thread, + const std::vector &query_filters, + const float fail_if_recall_below) { + using TagT = uint32_t; + // Load the query file + T *query = nullptr; + uint32_t *gt_ids = nullptr; + float *gt_dists = nullptr; + size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim; + diskann::load_aligned_bin(query_file, query, query_num, query_dim, + query_aligned_dim); + + bool calc_recall_flag = false; + if (truthset_file != std::string("null") && file_exists(truthset_file)) { + diskann::load_truthset(truthset_file, gt_ids, gt_dists, gt_num, gt_dim); + if (gt_num != query_num) { + std::cout << "Error. Mismatch in number of queries and ground truth data" + << std::endl; } - else - { - std::cout << std::setw(4) << "Ls" << std::setw(12) << qps_title << std::setw(18) << "Avg dist cmps" - << std::setw(20) << "Mean Latency (mus)" << std::setw(15) << "99.9 Latency"; - table_width += 4 + 12 + 18 + 20 + 15; + calc_recall_flag = true; + } else { + diskann::cout << " Truthset file " << truthset_file + << " not found. Not computing recall." << std::endl; + } + + bool filtered_search = false; + if (!query_filters.empty()) { + filtered_search = true; + if (query_filters.size() != 1 && query_filters.size() != query_num) { + std::cout << "Error. Mismatch in number of queries and size of query " + "filters file" + << std::endl; + return -1; // To return -1 or some other error handling? } - uint32_t recalls_to_print = 0; - const uint32_t first_recall = print_all_recalls ? 
1 : recall_at; - if (calc_recall_flag) - { - for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; curr_recall++) - { - std::cout << std::setw(12) << ("Recall@" + std::to_string(curr_recall)); - } - recalls_to_print = recall_at + 1 - first_recall; - table_width += recalls_to_print * 12; + } + + const size_t num_frozen_pts = + diskann::get_graph_num_frozen_points(index_path); + + auto config = + diskann::IndexConfigBuilder() + .with_metric(metric) + .with_dimension(query_dim) + .with_max_points(0) + .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY) + .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY) + .with_data_type(diskann_type_to_name()) + .with_label_type(diskann_type_to_name()) + .with_tag_type(diskann_type_to_name()) + .is_dynamic_index(dynamic) + .is_enable_tags(tags) + .is_concurrent_consolidate(false) + .is_pq_dist_build(false) + .is_use_opq(false) + .with_num_pq_chunks(0) + .with_num_frozen_pts(num_frozen_pts) + .build(); + + auto index_factory = diskann::IndexFactory(config); + auto index = index_factory.create_instance(); + index->load(index_path.c_str(), num_threads, + *(std::max_element(Lvec.begin(), Lvec.end()))); + std::cout << "Index loaded" << std::endl; + + if (metric == diskann::FAST_L2) + index->optimize_index_layout(); + + std::cout << "Using " << num_threads << " threads to search" << std::endl; + std::cout.setf(std::ios_base::fixed, std::ios_base::floatfield); + std::cout.precision(2); + const std::string qps_title = show_qps_per_thread ? "QPS/thread" : "QPS"; + uint32_t table_width = 0; + if (tags) { + std::cout << std::setw(4) << "Ls" << std::setw(12) << qps_title + << std::setw(20) << "Mean Latency (mus)" << std::setw(15) + << "99.9 Latency"; + table_width += 4 + 12 + 20 + 15; + } else { + std::cout << std::setw(4) << "Ls" << std::setw(12) << qps_title + << std::setw(18) << "Avg dist cmps" << std::setw(20) + << "Mean Latency (mus)" << std::setw(15) << "99.9 Latency"; + table_width += 4 + 12 + 18 + 20 + 15; + } + uint32_t recalls_to_print = 0; + const uint32_t first_recall = print_all_recalls ? 
1 : recall_at; + if (calc_recall_flag) { + for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; + curr_recall++) { + std::cout << std::setw(12) << ("Recall@" + std::to_string(curr_recall)); } - std::cout << std::endl; - std::cout << std::string(table_width, '=') << std::endl; - - std::vector> query_result_ids(Lvec.size()); - std::vector> query_result_dists(Lvec.size()); - std::vector latency_stats(query_num, 0); - std::vector cmp_stats; - if (not tags || filtered_search) - { - cmp_stats = std::vector(query_num, 0); + recalls_to_print = recall_at + 1 - first_recall; + table_width += recalls_to_print * 12; + } + std::cout << std::endl; + std::cout << std::string(table_width, '=') << std::endl; + + std::vector> query_result_ids(Lvec.size()); + std::vector> query_result_dists(Lvec.size()); + std::vector latency_stats(query_num, 0); + std::vector cmp_stats; + if (not tags || filtered_search) { + cmp_stats = std::vector(query_num, 0); + } + + std::vector query_result_tags; + if (tags) { + query_result_tags.resize(recall_at * query_num); + } + + double best_recall = 0.0; + + for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++) { + uint32_t L = Lvec[test_id]; + if (L < recall_at) { + diskann::cout << "Ignoring search with L:" << L + << " since it's smaller than K:" << recall_at << std::endl; + continue; } - std::vector query_result_tags; - if (tags) - { - query_result_tags.resize(recall_at * query_num); - } - - double best_recall = 0.0; - - for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++) - { - uint32_t L = Lvec[test_id]; - if (L < recall_at) - { - diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl; - continue; - } - - query_result_ids[test_id].resize(recall_at * query_num); - query_result_dists[test_id].resize(recall_at * query_num); - std::vector res = std::vector(); + query_result_ids[test_id].resize(recall_at * query_num); + query_result_dists[test_id].resize(recall_at * query_num); + std::vector res = std::vector(); - auto s = std::chrono::high_resolution_clock::now(); - omp_set_num_threads(num_threads); + auto s = std::chrono::high_resolution_clock::now(); + omp_set_num_threads(num_threads); #pragma omp parallel for schedule(dynamic, 1) - for (int64_t i = 0; i < (int64_t)query_num; i++) - { - auto qs = std::chrono::high_resolution_clock::now(); - if (filtered_search && !tags) - { - std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i]; - - auto retval = index->search_with_filters(query + i * query_aligned_dim, raw_filter, recall_at, L, - query_result_ids[test_id].data() + i * recall_at, - query_result_dists[test_id].data() + i * recall_at); - cmp_stats[i] = retval.second; - } - else if (metric == diskann::FAST_L2) - { - index->search_with_optimized_layout(query + i * query_aligned_dim, recall_at, L, - query_result_ids[test_id].data() + i * recall_at); - } - else if (tags) - { - if (!filtered_search) - { - index->search_with_tags(query + i * query_aligned_dim, recall_at, L, - query_result_tags.data() + i * recall_at, nullptr, res); - } - else - { - std::string raw_filter = query_filters.size() == 1 ? 
query_filters[0] : query_filters[i]; - - index->search_with_tags(query + i * query_aligned_dim, recall_at, L, - query_result_tags.data() + i * recall_at, nullptr, res, true, raw_filter); - } - - for (int64_t r = 0; r < (int64_t)recall_at; r++) - { - query_result_ids[test_id][recall_at * i + r] = query_result_tags[recall_at * i + r]; - } - } - else - { - cmp_stats[i] = index - ->search(query + i * query_aligned_dim, recall_at, L, - query_result_ids[test_id].data() + i * recall_at) - .second; - } - auto qe = std::chrono::high_resolution_clock::now(); - std::chrono::duration diff = qe - qs; - latency_stats[i] = (float)(diff.count() * 1000000); - } - std::chrono::duration diff = std::chrono::high_resolution_clock::now() - s; - - double displayed_qps = query_num / diff.count(); - - if (show_qps_per_thread) - displayed_qps /= num_threads; - - std::vector recalls; - if (calc_recall_flag) - { - recalls.reserve(recalls_to_print); - for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; curr_recall++) - { - recalls.push_back(diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim, - query_result_ids[test_id].data(), recall_at, curr_recall)); - } + for (int64_t i = 0; i < (int64_t)query_num; i++) { + auto qs = std::chrono::high_resolution_clock::now(); + if (filtered_search && !tags) { + std::string raw_filter = + query_filters.size() == 1 ? query_filters[0] : query_filters[i]; + + auto retval = index->search_with_filters( + query + i * query_aligned_dim, raw_filter, recall_at, L, + query_result_ids[test_id].data() + i * recall_at, + query_result_dists[test_id].data() + i * recall_at); + cmp_stats[i] = retval.second; + } else if (metric == diskann::FAST_L2) { + index->search_with_optimized_layout( + query + i * query_aligned_dim, recall_at, L, + query_result_ids[test_id].data() + i * recall_at); + } else if (tags) { + if (!filtered_search) { + index->search_with_tags(query + i * query_aligned_dim, recall_at, L, + query_result_tags.data() + i * recall_at, + nullptr, res); + } else { + std::string raw_filter = + query_filters.size() == 1 ? 
query_filters[0] : query_filters[i]; + + index->search_with_tags(query + i * query_aligned_dim, recall_at, L, + query_result_tags.data() + i * recall_at, + nullptr, res, true, raw_filter); } - std::sort(latency_stats.begin(), latency_stats.end()); - double mean_latency = - std::accumulate(latency_stats.begin(), latency_stats.end(), 0.0) / static_cast(query_num); - - float avg_cmps = (float)std::accumulate(cmp_stats.begin(), cmp_stats.end(), 0) / (float)query_num; - - if (tags && !filtered_search) - { - std::cout << std::setw(4) << L << std::setw(12) << displayed_qps << std::setw(20) << (float)mean_latency - << std::setw(15) << (float)latency_stats[(uint64_t)(0.999 * query_num)]; + for (int64_t r = 0; r < (int64_t)recall_at; r++) { + query_result_ids[test_id][recall_at * i + r] = + query_result_tags[recall_at * i + r]; } - else - { - std::cout << std::setw(4) << L << std::setw(12) << displayed_qps << std::setw(18) << avg_cmps - << std::setw(20) << (float)mean_latency << std::setw(15) - << (float)latency_stats[(uint64_t)(0.999 * query_num)]; - } - for (double recall : recalls) - { - std::cout << std::setw(12) << recall; - best_recall = std::max(recall, best_recall); - } - std::cout << std::endl; + } else { + cmp_stats[i] = + index + ->search(query + i * query_aligned_dim, recall_at, L, + query_result_ids[test_id].data() + i * recall_at) + .second; + } + auto qe = std::chrono::high_resolution_clock::now(); + std::chrono::duration diff = qe - qs; + latency_stats[i] = (float)(diff.count() * 1000000); } - - std::cout << "Done searching. Now saving results " << std::endl; - uint64_t test_id = 0; - for (auto L : Lvec) - { - if (L < recall_at) - { - diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl; - continue; - } - std::string cur_result_path_prefix = result_path_prefix + "_" + std::to_string(L); - - std::string cur_result_path = cur_result_path_prefix + "_idx_uint32.bin"; - diskann::save_bin(cur_result_path, query_result_ids[test_id].data(), query_num, recall_at); - - cur_result_path = cur_result_path_prefix + "_dists_float.bin"; - diskann::save_bin(cur_result_path, query_result_dists[test_id].data(), query_num, recall_at); - - test_id++; - } - - diskann::aligned_free(query); - return best_recall >= fail_if_recall_below ? 
0 : -1; -} - -int main(int argc, char **argv) -{ - std::string data_type, dist_fn, index_path_prefix, result_path, query_file, gt_file, filter_label, label_type, - query_filters_file; - uint32_t num_threads, K; - std::vector Lvec; - bool print_all_recalls, dynamic, tags, show_qps_per_thread; - float fail_if_recall_below = 0.0f; - - po::options_description desc{ - program_options_utils::make_program_description("search_memory_index", "Searches in-memory DiskANN indexes")}; - try - { - desc.add_options()("help,h", "Print this information on arguments"); - - // Required parameters - po::options_description required_configs("Required"); - required_configs.add_options()("data_type", po::value(&data_type)->required(), - program_options_utils::DATA_TYPE_DESCRIPTION); - required_configs.add_options()("dist_fn", po::value(&dist_fn)->required(), - program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); - required_configs.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), - program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); - required_configs.add_options()("result_path", po::value(&result_path)->required(), - program_options_utils::RESULT_PATH_DESCRIPTION); - required_configs.add_options()("query_file", po::value(&query_file)->required(), - program_options_utils::QUERY_FILE_DESCRIPTION); - required_configs.add_options()("recall_at,K", po::value(&K)->required(), - program_options_utils::NUMBER_OF_RESULTS_DESCRIPTION); - required_configs.add_options()("search_list,L", - po::value>(&Lvec)->multitoken()->required(), - program_options_utils::SEARCH_LIST_DESCRIPTION); - - // Optional parameters - po::options_description optional_configs("Optional"); - optional_configs.add_options()("filter_label", - po::value(&filter_label)->default_value(std::string("")), - program_options_utils::FILTER_LABEL_DESCRIPTION); - optional_configs.add_options()("query_filters_file", - po::value(&query_filters_file)->default_value(std::string("")), - program_options_utils::FILTERS_FILE_DESCRIPTION); - optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), - program_options_utils::LABEL_TYPE_DESCRIPTION); - optional_configs.add_options()("gt_file", po::value(>_file)->default_value(std::string("null")), - program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION); - optional_configs.add_options()("num_threads,T", - po::value(&num_threads)->default_value(omp_get_num_procs()), - program_options_utils::NUMBER_THREADS_DESCRIPTION); - optional_configs.add_options()( - "dynamic", po::value(&dynamic)->default_value(false), - "Whether the index is dynamic. Dynamic indices must have associated tags. Default false."); - optional_configs.add_options()("tags", po::value(&tags)->default_value(false), - "Whether to search with external identifiers (tags). 
Default false."); - optional_configs.add_options()("fail_if_recall_below", - po::value(&fail_if_recall_below)->default_value(0.0f), - program_options_utils::FAIL_IF_RECALL_BELOW); - - // Output controls - po::options_description output_controls("Output controls"); - output_controls.add_options()("print_all_recalls", po::bool_switch(&print_all_recalls), - "Print recalls at all positions, from 1 up to specified " - "recall_at value"); - output_controls.add_options()("print_qps_per_thread", po::bool_switch(&show_qps_per_thread), - "Print overall QPS divided by the number of threads in " - "the output table"); - - // Merge required and optional parameters - desc.add(required_configs).add(optional_configs).add(output_controls); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) - { - std::cout << desc; - return 0; - } - po::notify(vm); - } - catch (const std::exception &ex) - { - std::cerr << ex.what() << '\n'; - return -1; + std::chrono::duration diff = + std::chrono::high_resolution_clock::now() - s; + + double displayed_qps = query_num / diff.count(); + + if (show_qps_per_thread) + displayed_qps /= num_threads; + + std::vector recalls; + if (calc_recall_flag) { + recalls.reserve(recalls_to_print); + for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; + curr_recall++) { + recalls.push_back(diskann::calculate_recall( + (uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim, + query_result_ids[test_id].data(), recall_at, curr_recall)); + } } - diskann::Metric metric; - if ((dist_fn == std::string("mips")) && (data_type == std::string("float"))) - { - metric = diskann::Metric::INNER_PRODUCT; + std::sort(latency_stats.begin(), latency_stats.end()); + double mean_latency = + std::accumulate(latency_stats.begin(), latency_stats.end(), 0.0) / + static_cast(query_num); + + float avg_cmps = + (float)std::accumulate(cmp_stats.begin(), cmp_stats.end(), 0) / + (float)query_num; + + if (tags && !filtered_search) { + std::cout << std::setw(4) << L << std::setw(12) << displayed_qps + << std::setw(20) << (float)mean_latency << std::setw(15) + << (float)latency_stats[(uint64_t)(0.999 * query_num)]; + } else { + std::cout << std::setw(4) << L << std::setw(12) << displayed_qps + << std::setw(18) << avg_cmps << std::setw(20) + << (float)mean_latency << std::setw(15) + << (float)latency_stats[(uint64_t)(0.999 * query_num)]; } - else if (dist_fn == std::string("l2")) - { - metric = diskann::Metric::L2; + for (double recall : recalls) { + std::cout << std::setw(12) << recall; + best_recall = std::max(recall, best_recall); } - else if (dist_fn == std::string("cosine")) - { - metric = diskann::Metric::COSINE; - } - else if ((dist_fn == std::string("fast_l2")) && (data_type == std::string("float"))) - { - metric = diskann::Metric::FAST_L2; - } - else - { - std::cout << "Unsupported distance function. Currently only l2/ cosine are " - "supported in general, and mips/fast_l2 only for floating " - "point data." - << std::endl; - return -1; + std::cout << std::endl; + } + + std::cout << "Done searching. 
Now saving results " << std::endl; + uint64_t test_id = 0; + for (auto L : Lvec) { + if (L < recall_at) { + diskann::cout << "Ignoring search with L:" << L + << " since it's smaller than K:" << recall_at << std::endl; + continue; } + std::string cur_result_path_prefix = + result_path_prefix + "_" + std::to_string(L); - if (dynamic && not tags) - { - std::cerr << "Tags must be enabled while searching dynamically built indices" << std::endl; - return -1; - } + std::string cur_result_path = cur_result_path_prefix + "_idx_uint32.bin"; + diskann::save_bin(cur_result_path, + query_result_ids[test_id].data(), query_num, + recall_at); - if (fail_if_recall_below < 0.0 || fail_if_recall_below >= 100.0) - { - std::cerr << "fail_if_recall_below parameter must be between 0 and 100%" << std::endl; - return -1; - } + cur_result_path = cur_result_path_prefix + "_dists_float.bin"; + diskann::save_bin(cur_result_path, + query_result_dists[test_id].data(), query_num, + recall_at); - if (filter_label != "" && query_filters_file != "") - { - std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl; - return -1; - } + test_id++; + } - std::vector query_filters; - if (filter_label != "") - { - query_filters.push_back(filter_label); - } - else if (query_filters_file != "") - { - query_filters = read_file_to_vector_of_strings(query_filters_file); - } + diskann::aligned_free(query); + return best_recall >= fail_if_recall_below ? 0 : -1; +} - try - { - if (!query_filters.empty() && label_type == "ushort") - { - if (data_type == std::string("int8")) - { - return search_memory_index( - metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, - Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below); - } - else if (data_type == std::string("uint8")) - { - return search_memory_index( - metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, - Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below); - } - else if (data_type == std::string("float")) - { - return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, - num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below); - } - else - { - std::cout << "Unsupported type. Use float/int8/uint8" << std::endl; - return -1; - } - } - else - { - if (data_type == std::string("int8")) - { - return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, - num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below); - } - else if (data_type == std::string("uint8")) - { - return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, - num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below); - } - else if (data_type == std::string("float")) - { - return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, - num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below); - } - else - { - std::cout << "Unsupported type. 
Use float/int8/uint8" << std::endl; - return -1; - } - } +int main(int argc, char **argv) { + std::string data_type, dist_fn, index_path_prefix, result_path, query_file, + gt_file, filter_label, label_type, query_filters_file; + uint32_t num_threads, K; + std::vector Lvec; + bool print_all_recalls, dynamic, tags, show_qps_per_thread; + float fail_if_recall_below = 0.0f; + + po::options_description desc{program_options_utils::make_program_description( + "search_memory_index", "Searches in-memory DiskANN indexes")}; + try { + desc.add_options()("help,h", "Print this information on arguments"); + + // Required parameters + po::options_description required_configs("Required"); + required_configs.add_options()( + "data_type", po::value(&data_type)->required(), + program_options_utils::DATA_TYPE_DESCRIPTION); + required_configs.add_options()( + "dist_fn", po::value(&dist_fn)->required(), + program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); + required_configs.add_options()( + "index_path_prefix", + po::value(&index_path_prefix)->required(), + program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); + required_configs.add_options()( + "result_path", po::value(&result_path)->required(), + program_options_utils::RESULT_PATH_DESCRIPTION); + required_configs.add_options()( + "query_file", po::value(&query_file)->required(), + program_options_utils::QUERY_FILE_DESCRIPTION); + required_configs.add_options()( + "recall_at,K", po::value(&K)->required(), + program_options_utils::NUMBER_OF_RESULTS_DESCRIPTION); + required_configs.add_options()( + "search_list,L", + po::value>(&Lvec)->multitoken()->required(), + program_options_utils::SEARCH_LIST_DESCRIPTION); + + // Optional parameters + po::options_description optional_configs("Optional"); + optional_configs.add_options()( + "filter_label", + po::value(&filter_label)->default_value(std::string("")), + program_options_utils::FILTER_LABEL_DESCRIPTION); + optional_configs.add_options()( + "query_filters_file", + po::value(&query_filters_file) + ->default_value(std::string("")), + program_options_utils::FILTERS_FILE_DESCRIPTION); + optional_configs.add_options()( + "label_type", + po::value(&label_type)->default_value("uint"), + program_options_utils::LABEL_TYPE_DESCRIPTION); + optional_configs.add_options()( + "gt_file", + po::value(>_file)->default_value(std::string("null")), + program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION); + optional_configs.add_options()( + "num_threads,T", + po::value(&num_threads)->default_value(omp_get_num_procs()), + program_options_utils::NUMBER_THREADS_DESCRIPTION); + optional_configs.add_options()( + "dynamic", po::value(&dynamic)->default_value(false), + "Whether the index is dynamic. Dynamic indices must have associated " + "tags. Default false."); + optional_configs.add_options()( + "tags", po::value(&tags)->default_value(false), + "Whether to search with external identifiers (tags). 
Default false."); + optional_configs.add_options()( + "fail_if_recall_below", + po::value(&fail_if_recall_below)->default_value(0.0f), + program_options_utils::FAIL_IF_RECALL_BELOW); + + // Output controls + po::options_description output_controls("Output controls"); + output_controls.add_options()( + "print_all_recalls", po::bool_switch(&print_all_recalls), + "Print recalls at all positions, from 1 up to specified " + "recall_at value"); + output_controls.add_options()( + "print_qps_per_thread", po::bool_switch(&show_qps_per_thread), + "Print overall QPS divided by the number of threads in " + "the output table"); + + // Merge required and optional parameters + desc.add(required_configs).add(optional_configs).add(output_controls); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) { + std::cout << desc; + return 0; } - catch (std::exception &e) - { - std::cout << std::string(e.what()) << std::endl; - diskann::cerr << "Index search failed." << std::endl; + po::notify(vm); + } catch (const std::exception &ex) { + std::cerr << ex.what() << '\n'; + return -1; + } + + diskann::Metric metric; + if ((dist_fn == std::string("mips")) && (data_type == std::string("float"))) { + metric = diskann::Metric::INNER_PRODUCT; + } else if (dist_fn == std::string("l2")) { + metric = diskann::Metric::L2; + } else if (dist_fn == std::string("cosine")) { + metric = diskann::Metric::COSINE; + } else if ((dist_fn == std::string("fast_l2")) && + (data_type == std::string("float"))) { + metric = diskann::Metric::FAST_L2; + } else { + std::cout << "Unsupported distance function. Currently only l2/ cosine are " + "supported in general, and mips/fast_l2 only for floating " + "point data." + << std::endl; + return -1; + } + + if (dynamic && not tags) { + std::cerr + << "Tags must be enabled while searching dynamically built indices" + << std::endl; + return -1; + } + + if (fail_if_recall_below < 0.0 || fail_if_recall_below >= 100.0) { + std::cerr << "fail_if_recall_below parameter must be between 0 and 100%" + << std::endl; + return -1; + } + + if (filter_label != "" && query_filters_file != "") { + std::cerr + << "Only one of filter_label and query_filters_file should be provided" + << std::endl; + return -1; + } + + std::vector query_filters; + if (filter_label != "") { + query_filters.push_back(filter_label); + } else if (query_filters_file != "") { + query_filters = read_file_to_vector_of_strings(query_filters_file); + } + + try { + if (!query_filters.empty() && label_type == "ushort") { + if (data_type == std::string("int8")) { + return search_memory_index( + metric, index_path_prefix, result_path, query_file, gt_file, + num_threads, K, print_all_recalls, Lvec, dynamic, tags, + show_qps_per_thread, query_filters, fail_if_recall_below); + } else if (data_type == std::string("uint8")) { + return search_memory_index( + metric, index_path_prefix, result_path, query_file, gt_file, + num_threads, K, print_all_recalls, Lvec, dynamic, tags, + show_qps_per_thread, query_filters, fail_if_recall_below); + } else if (data_type == std::string("float")) { + return search_memory_index( + metric, index_path_prefix, result_path, query_file, gt_file, + num_threads, K, print_all_recalls, Lvec, dynamic, tags, + show_qps_per_thread, query_filters, fail_if_recall_below); + } else { + std::cout << "Unsupported type. 
Use float/int8/uint8" << std::endl; + return -1; + } + } else { + if (data_type == std::string("int8")) { + return search_memory_index( + metric, index_path_prefix, result_path, query_file, gt_file, + num_threads, K, print_all_recalls, Lvec, dynamic, tags, + show_qps_per_thread, query_filters, fail_if_recall_below); + } else if (data_type == std::string("uint8")) { + return search_memory_index( + metric, index_path_prefix, result_path, query_file, gt_file, + num_threads, K, print_all_recalls, Lvec, dynamic, tags, + show_qps_per_thread, query_filters, fail_if_recall_below); + } else if (data_type == std::string("float")) { + return search_memory_index( + metric, index_path_prefix, result_path, query_file, gt_file, + num_threads, K, print_all_recalls, Lvec, dynamic, tags, + show_qps_per_thread, query_filters, fail_if_recall_below); + } else { + std::cout << "Unsupported type. Use float/int8/uint8" << std::endl; return -1; + } } + } catch (std::exception &e) { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Index search failed." << std::endl; + return -1; + } } diff --git a/apps/test_insert_deletes_consolidate.cpp b/apps/test_insert_deletes_consolidate.cpp index 97aed1864..047a677c9 100644 --- a/apps/test_insert_deletes_consolidate.cpp +++ b/apps/test_insert_deletes_consolidate.cpp @@ -1,19 +1,19 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. +#include +#include #include #include #include #include #include #include -#include -#include -#include "utils.h" #include "filter_utils.h" -#include "program_options_utils.hpp" #include "index_factory.h" +#include "program_options_utils.hpp" +#include "utils.h" #ifndef _WINDOWS #include @@ -28,509 +28,558 @@ namespace po = boost::program_options; // load_aligned_bin modified to read pieces of the file, but using ifstream // instead of cached_ifstream. template -inline void load_aligned_bin_part(const std::string &bin_file, T *data, size_t offset_points, size_t points_to_read) -{ - diskann::Timer timer; - std::ifstream reader; - reader.exceptions(std::ios::failbit | std::ios::badbit); - reader.open(bin_file, std::ios::binary | std::ios::ate); - size_t actual_file_size = reader.tellg(); - reader.seekg(0, std::ios::beg); - - int npts_i32, dim_i32; - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&dim_i32, sizeof(int)); - size_t npts = (uint32_t)npts_i32; - size_t dim = (uint32_t)dim_i32; - - size_t expected_actual_file_size = npts * dim * sizeof(T) + 2 * sizeof(uint32_t); - if (actual_file_size != expected_actual_file_size) - { - std::stringstream stream; - stream << "Error. File size mismatch. Actual size is " << actual_file_size << " while expected size is " - << expected_actual_file_size << " npts = " << npts << " dim = " << dim << " size of = " << sizeof(T) - << std::endl; - std::cout << stream.str(); - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); - } - - if (offset_points + points_to_read > npts) - { - std::stringstream stream; - stream << "Error. Not enough points in file. 
Requested " << offset_points << " offset and " << points_to_read - << " points, but have only " << npts << " points" << std::endl; - std::cout << stream.str(); - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); - } - - reader.seekg(2 * sizeof(uint32_t) + offset_points * dim * sizeof(T)); - - const size_t rounded_dim = ROUND_UP(dim, 8); - - for (size_t i = 0; i < points_to_read; i++) - { - reader.read((char *)(data + i * rounded_dim), dim * sizeof(T)); - memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T)); - } - reader.close(); - - const double elapsedSeconds = timer.elapsed() / 1000000.0; - std::cout << "Read " << points_to_read << " points using non-cached reads in " << elapsedSeconds << std::endl; +inline void load_aligned_bin_part(const std::string &bin_file, T *data, + size_t offset_points, size_t points_to_read) { + diskann::Timer timer; + std::ifstream reader; + reader.exceptions(std::ios::failbit | std::ios::badbit); + reader.open(bin_file, std::ios::binary | std::ios::ate); + size_t actual_file_size = reader.tellg(); + reader.seekg(0, std::ios::beg); + + int npts_i32, dim_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&dim_i32, sizeof(int)); + size_t npts = (uint32_t)npts_i32; + size_t dim = (uint32_t)dim_i32; + + size_t expected_actual_file_size = + npts * dim * sizeof(T) + 2 * sizeof(uint32_t); + if (actual_file_size != expected_actual_file_size) { + std::stringstream stream; + stream << "Error. File size mismatch. Actual size is " << actual_file_size + << " while expected size is " << expected_actual_file_size + << " npts = " << npts << " dim = " << dim + << " size of = " << sizeof(T) << std::endl; + std::cout << stream.str(); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, + __LINE__); + } + + if (offset_points + points_to_read > npts) { + std::stringstream stream; + stream << "Error. Not enough points in file. 
Requested " << offset_points + << " offset and " << points_to_read << " points, but have only " + << npts << " points" << std::endl; + std::cout << stream.str(); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, + __LINE__); + } + + reader.seekg(2 * sizeof(uint32_t) + offset_points * dim * sizeof(T)); + + const size_t rounded_dim = ROUND_UP(dim, 8); + + for (size_t i = 0; i < points_to_read; i++) { + reader.read((char *)(data + i * rounded_dim), dim * sizeof(T)); + memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T)); + } + reader.close(); + + const double elapsedSeconds = timer.elapsed() / 1000000.0; + std::cout << "Read " << points_to_read << " points using non-cached reads in " + << elapsedSeconds << std::endl; } -std::string get_save_filename(const std::string &save_path, size_t points_to_skip, size_t points_deleted, - size_t last_point_threshold) -{ - std::string final_path = save_path; - if (points_to_skip > 0) - { - final_path += "skip" + std::to_string(points_to_skip) + "-"; - } - - final_path += "del" + std::to_string(points_deleted) + "-"; - final_path += std::to_string(last_point_threshold); - return final_path; +std::string get_save_filename(const std::string &save_path, + size_t points_to_skip, size_t points_deleted, + size_t last_point_threshold) { + std::string final_path = save_path; + if (points_to_skip > 0) { + final_path += "skip" + std::to_string(points_to_skip) + "-"; + } + + final_path += "del" + std::to_string(points_deleted) + "-"; + final_path += std::to_string(last_point_threshold); + return final_path; } template -void insert_till_next_checkpoint(diskann::AbstractIndex &index, size_t start, size_t end, int32_t thread_count, T *data, - size_t aligned_dim, std::vector> &location_to_labels) -{ - diskann::Timer insert_timer; +void insert_till_next_checkpoint( + diskann::AbstractIndex &index, size_t start, size_t end, + int32_t thread_count, T *data, size_t aligned_dim, + std::vector> &location_to_labels) { + diskann::Timer insert_timer; #pragma omp parallel for num_threads(thread_count) schedule(dynamic) - for (int64_t j = start; j < (int64_t)end; j++) - { - if (!location_to_labels.empty()) - { - index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast(j), - location_to_labels[j - start]); - } - else - { - index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast(j)); - } + for (int64_t j = start; j < (int64_t)end; j++) { + if (!location_to_labels.empty()) { + index.insert_point(&data[(j - start) * aligned_dim], + 1 + static_cast(j), + location_to_labels[j - start]); + } else { + index.insert_point(&data[(j - start) * aligned_dim], + 1 + static_cast(j)); } - const double elapsedSeconds = insert_timer.elapsed() / 1000000.0; - std::cout << "Insertion time " << elapsedSeconds << " seconds (" << (end - start) / elapsedSeconds - << " points/second overall, " << (end - start) / elapsedSeconds / thread_count << " per thread)\n "; + } + const double elapsedSeconds = insert_timer.elapsed() / 1000000.0; + std::cout << "Insertion time " << elapsedSeconds << " seconds (" + << (end - start) / elapsedSeconds << " points/second overall, " + << (end - start) / elapsedSeconds / thread_count + << " per thread)\n "; } template -void delete_from_beginning(diskann::AbstractIndex &index, diskann::IndexWriteParameters &delete_params, - size_t points_to_skip, size_t points_to_delete_from_beginning) -{ - try - { - std::cout << std::endl - << "Lazy deleting points " << points_to_skip << " to " - << points_to_skip + 
points_to_delete_from_beginning << "... "; - for (size_t i = points_to_skip; i < points_to_skip + points_to_delete_from_beginning; ++i) - index.lazy_delete(static_cast(i + 1)); // Since tags are data location + 1 - std::cout << "done." << std::endl; - - auto report = index.consolidate_deletes(delete_params); - std::cout << "#active points: " << report._active_points << std::endl - << "max points: " << report._max_points << std::endl - << "empty slots: " << report._empty_slots << std::endl - << "deletes processed: " << report._slots_released << std::endl - << "latest delete size: " << report._delete_set_size << std::endl - << "rate: (" << points_to_delete_from_beginning / report._time << " points/second overall, " - << points_to_delete_from_beginning / report._time / delete_params.num_threads << " per thread)" - << std::endl; - } - catch (std::system_error &e) - { - std::cout << "Exception caught in deletion thread: " << e.what() << std::endl; - } +void delete_from_beginning(diskann::AbstractIndex &index, + diskann::IndexWriteParameters &delete_params, + size_t points_to_skip, + size_t points_to_delete_from_beginning) { + try { + std::cout << std::endl + << "Lazy deleting points " << points_to_skip << " to " + << points_to_skip + points_to_delete_from_beginning << "... "; + for (size_t i = points_to_skip; + i < points_to_skip + points_to_delete_from_beginning; ++i) + index.lazy_delete( + static_cast(i + 1)); // Since tags are data location + 1 + std::cout << "done." << std::endl; + + auto report = index.consolidate_deletes(delete_params); + std::cout << "#active points: " << report._active_points << std::endl + << "max points: " << report._max_points << std::endl + << "empty slots: " << report._empty_slots << std::endl + << "deletes processed: " << report._slots_released << std::endl + << "latest delete size: " << report._delete_set_size << std::endl + << "rate: (" << points_to_delete_from_beginning / report._time + << " points/second overall, " + << points_to_delete_from_beginning / report._time / + delete_params.num_threads + << " per thread)" << std::endl; + } catch (std::system_error &e) { + std::cout << "Exception caught in deletion thread: " << e.what() + << std::endl; + } } template -void build_incremental_index(const std::string &data_path, diskann::IndexWriteParameters ¶ms, size_t points_to_skip, - size_t max_points_to_insert, size_t beginning_index_size, float start_point_norm, - uint32_t num_start_pts, size_t points_per_checkpoint, size_t checkpoints_per_snapshot, - const std::string &save_path, size_t points_to_delete_from_beginning, - size_t start_deletes_after, bool concurrent, const std::string &label_file, - const std::string &universal_label) -{ - size_t dim, aligned_dim; - size_t num_points; - diskann::get_bin_metadata(data_path, num_points, dim); - aligned_dim = ROUND_UP(dim, 8); - bool has_labels = label_file != ""; - using TagT = uint32_t; - using LabelT = uint32_t; - - size_t current_point_offset = points_to_skip; - const size_t last_point_threshold = points_to_skip + max_points_to_insert; - - bool enable_tags = true; - using TagT = uint32_t; - auto index_search_params = diskann::IndexSearchParams(params.search_list_size, params.num_threads); - diskann::IndexConfig index_config = diskann::IndexConfigBuilder() - .with_metric(diskann::L2) - .with_dimension(dim) - .with_max_points(max_points_to_insert) - .is_dynamic_index(true) - .with_index_write_params(params) - .with_index_search_params(index_search_params) - .with_data_type(diskann_type_to_name()) - 
.with_tag_type(diskann_type_to_name()) - .with_label_type(diskann_type_to_name()) - .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY) - .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY) - .is_enable_tags(enable_tags) - .is_filtered(has_labels) - .with_num_frozen_pts(num_start_pts) - .is_concurrent_consolidate(concurrent) - .build(); - - diskann::IndexFactory index_factory = diskann::IndexFactory(index_config); - auto index = index_factory.create_instance(); - - if (universal_label != "") - { - LabelT u_label = 0; - index->set_universal_label(u_label); - } - - if (points_to_skip > num_points) - { - throw diskann::ANNException("Asked to skip more points than in data file", -1, __FUNCSIG__, __FILE__, __LINE__); - } - - if (max_points_to_insert == 0) - { - max_points_to_insert = num_points; - } - - if (points_to_skip + max_points_to_insert > num_points) - { - max_points_to_insert = num_points - points_to_skip; - std::cerr << "WARNING: Reducing max_points_to_insert to " << max_points_to_insert - << " points since the data file has only that many" << std::endl; +void build_incremental_index( + const std::string &data_path, diskann::IndexWriteParameters ¶ms, + size_t points_to_skip, size_t max_points_to_insert, + size_t beginning_index_size, float start_point_norm, uint32_t num_start_pts, + size_t points_per_checkpoint, size_t checkpoints_per_snapshot, + const std::string &save_path, size_t points_to_delete_from_beginning, + size_t start_deletes_after, bool concurrent, const std::string &label_file, + const std::string &universal_label) { + size_t dim, aligned_dim; + size_t num_points; + diskann::get_bin_metadata(data_path, num_points, dim); + aligned_dim = ROUND_UP(dim, 8); + bool has_labels = label_file != ""; + using TagT = uint32_t; + using LabelT = uint32_t; + + size_t current_point_offset = points_to_skip; + const size_t last_point_threshold = points_to_skip + max_points_to_insert; + + bool enable_tags = true; + using TagT = uint32_t; + auto index_search_params = + diskann::IndexSearchParams(params.search_list_size, params.num_threads); + diskann::IndexConfig index_config = + diskann::IndexConfigBuilder() + .with_metric(diskann::L2) + .with_dimension(dim) + .with_max_points(max_points_to_insert) + .is_dynamic_index(true) + .with_index_write_params(params) + .with_index_search_params(index_search_params) + .with_data_type(diskann_type_to_name()) + .with_tag_type(diskann_type_to_name()) + .with_label_type(diskann_type_to_name()) + .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY) + .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY) + .is_enable_tags(enable_tags) + .is_filtered(has_labels) + .with_num_frozen_pts(num_start_pts) + .is_concurrent_consolidate(concurrent) + .build(); + + diskann::IndexFactory index_factory = diskann::IndexFactory(index_config); + auto index = index_factory.create_instance(); + + if (universal_label != "") { + LabelT u_label = 0; + index->set_universal_label(u_label); + } + + if (points_to_skip > num_points) { + throw diskann::ANNException("Asked to skip more points than in data file", + -1, __FUNCSIG__, __FILE__, __LINE__); + } + + if (max_points_to_insert == 0) { + max_points_to_insert = num_points; + } + + if (points_to_skip + max_points_to_insert > num_points) { + max_points_to_insert = num_points - points_to_skip; + std::cerr << "WARNING: Reducing max_points_to_insert to " + << max_points_to_insert + << " points since the data file has only that many" << std::endl; + } + + if 
(beginning_index_size > max_points_to_insert) { + beginning_index_size = max_points_to_insert; + std::cerr << "WARNING: Reducing beginning index size to " + << beginning_index_size + << " points since the data file has only that many" << std::endl; + } + if (checkpoints_per_snapshot > 0 && + beginning_index_size > points_per_checkpoint) { + beginning_index_size = points_per_checkpoint; + std::cerr << "WARNING: Reducing beginning index size to " + << beginning_index_size << std::endl; + } + + T *data = nullptr; + diskann::alloc_aligned((void **)&data, + std::max(points_per_checkpoint, beginning_index_size) * + aligned_dim * sizeof(T), + 8 * sizeof(T)); + + std::vector tags(beginning_index_size); + std::iota(tags.begin(), tags.end(), + 1 + static_cast(current_point_offset)); + + load_aligned_bin_part(data_path, data, current_point_offset, + beginning_index_size); + std::cout << "load aligned bin succeeded" << std::endl; + diskann::Timer timer; + + if (beginning_index_size > 0) { + index->build(data, beginning_index_size, tags); + } else { + index->set_start_points_at_random(static_cast(start_point_norm)); + } + + const double elapsedSeconds = timer.elapsed() / 1000000.0; + std::cout << "Initial non-incremental index build time for " + << beginning_index_size << " points took " << elapsedSeconds + << " seconds (" << beginning_index_size / elapsedSeconds + << " points/second)\n "; + + current_point_offset += beginning_index_size; + + if (points_to_delete_from_beginning > max_points_to_insert) { + points_to_delete_from_beginning = + static_cast(max_points_to_insert); + std::cerr << "WARNING: Reducing points to delete from beginning to " + << points_to_delete_from_beginning + << " points since the data file has only that many" << std::endl; + } + + std::vector> location_to_labels; + if (concurrent) { + // handle labels + const auto save_path_inc = get_save_filename( + save_path + ".after-concurrent-delete-", points_to_skip, + points_to_delete_from_beginning, last_point_threshold); + std::string labels_file_to_use = save_path_inc + "_label_formatted.txt"; + std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt"; + if (has_labels) { + convert_labels_string_to_int(label_file, labels_file_to_use, + mem_labels_int_map_file, universal_label); + auto parse_result = + diskann::parse_formatted_label_file(labels_file_to_use); + location_to_labels = std::get<0>(parse_result); } - if (beginning_index_size > max_points_to_insert) - { - beginning_index_size = max_points_to_insert; - std::cerr << "WARNING: Reducing beginning index size to " << beginning_index_size - << " points since the data file has only that many" << std::endl; - } - if (checkpoints_per_snapshot > 0 && beginning_index_size > points_per_checkpoint) - { - beginning_index_size = points_per_checkpoint; - std::cerr << "WARNING: Reducing beginning index size to " << beginning_index_size << std::endl; - } - - T *data = nullptr; - diskann::alloc_aligned( - (void **)&data, std::max(points_per_checkpoint, beginning_index_size) * aligned_dim * sizeof(T), 8 * sizeof(T)); - - std::vector tags(beginning_index_size); - std::iota(tags.begin(), tags.end(), 1 + static_cast(current_point_offset)); + int32_t sub_threads = (params.num_threads + 1) / 2; + bool delete_launched = false; + std::future delete_task; - load_aligned_bin_part(data_path, data, current_point_offset, beginning_index_size); - std::cout << "load aligned bin succeeded" << std::endl; diskann::Timer timer; - if (beginning_index_size > 0) - { - index->build(data, 
beginning_index_size, tags); + for (size_t start = current_point_offset; start < last_point_threshold; + start += points_per_checkpoint, + current_point_offset += points_per_checkpoint) { + const size_t end = + std::min(start + points_per_checkpoint, last_point_threshold); + std::cout << std::endl + << "Inserting from " << start << " to " << end << std::endl; + + auto insert_task = std::async(std::launch::async, [&]() { + load_aligned_bin_part(data_path, data, start, end - start); + insert_till_next_checkpoint( + *index, start, end, sub_threads, data, aligned_dim, + location_to_labels); + }); + insert_task.wait(); + + if (!delete_launched && end >= start_deletes_after && + end >= points_to_skip + points_to_delete_from_beginning) { + delete_launched = true; + diskann::IndexWriteParameters delete_params = + diskann::IndexWriteParametersBuilder(params) + .with_num_threads(sub_threads) + .build(); + + delete_task = std::async(std::launch::async, [&]() { + delete_from_beginning(*index, delete_params, points_to_skip, + points_to_delete_from_beginning); + }); + } } - else - { - index->set_start_points_at_random(static_cast(start_point_norm)); + delete_task.wait(); + + std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n"; + index->save(save_path_inc.c_str(), true); + } else { + const auto save_path_inc = get_save_filename( + save_path + ".after-delete-", points_to_skip, + points_to_delete_from_beginning, last_point_threshold); + std::string labels_file_to_use = save_path_inc + "_label_formatted.txt"; + std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt"; + if (has_labels) { + convert_labels_string_to_int(label_file, labels_file_to_use, + mem_labels_int_map_file, universal_label); + auto parse_result = + diskann::parse_formatted_label_file(labels_file_to_use); + location_to_labels = std::get<0>(parse_result); } - const double elapsedSeconds = timer.elapsed() / 1000000.0; - std::cout << "Initial non-incremental index build time for " << beginning_index_size << " points took " - << elapsedSeconds << " seconds (" << beginning_index_size / elapsedSeconds << " points/second)\n "; - - current_point_offset += beginning_index_size; - - if (points_to_delete_from_beginning > max_points_to_insert) - { - points_to_delete_from_beginning = static_cast(max_points_to_insert); - std::cerr << "WARNING: Reducing points to delete from beginning to " << points_to_delete_from_beginning - << " points since the data file has only that many" << std::endl; + size_t last_snapshot_points_threshold = 0; + size_t num_checkpoints_till_snapshot = checkpoints_per_snapshot; + + for (size_t start = current_point_offset; start < last_point_threshold; + start += points_per_checkpoint, + current_point_offset += points_per_checkpoint) { + const size_t end = + std::min(start + points_per_checkpoint, last_point_threshold); + std::cout << std::endl + << "Inserting from " << start << " to " << end << std::endl; + + load_aligned_bin_part(data_path, data, start, end - start); + insert_till_next_checkpoint( + *index, start, end, (int32_t)params.num_threads, data, aligned_dim, + location_to_labels); + + if (checkpoints_per_snapshot > 0 && + --num_checkpoints_till_snapshot == 0) { + diskann::Timer save_timer; + + const auto save_path_inc = + get_save_filename(save_path + ".inc-", points_to_skip, + points_to_delete_from_beginning, end); + index->save(save_path_inc.c_str(), false); + const double elapsedSeconds = save_timer.elapsed() / 1000000.0; + const size_t points_saved = end - points_to_skip; + + std::cout << 
"Saved " << points_saved << " points in " << elapsedSeconds + << " seconds (" << points_saved / elapsedSeconds + << " points/second)\n"; + + num_checkpoints_till_snapshot = checkpoints_per_snapshot; + last_snapshot_points_threshold = end; + } + + std::cout << "Number of points in the index post insertion " << end + << std::endl; } - std::vector> location_to_labels; - if (concurrent) - { - // handle labels - const auto save_path_inc = get_save_filename(save_path + ".after-concurrent-delete-", points_to_skip, - points_to_delete_from_beginning, last_point_threshold); - std::string labels_file_to_use = save_path_inc + "_label_formatted.txt"; - std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt"; - if (has_labels) - { - convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label); - auto parse_result = diskann::parse_formatted_label_file(labels_file_to_use); - location_to_labels = std::get<0>(parse_result); - } - - int32_t sub_threads = (params.num_threads + 1) / 2; - bool delete_launched = false; - std::future delete_task; - - diskann::Timer timer; - - for (size_t start = current_point_offset; start < last_point_threshold; - start += points_per_checkpoint, current_point_offset += points_per_checkpoint) - { - const size_t end = std::min(start + points_per_checkpoint, last_point_threshold); - std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl; - - auto insert_task = std::async(std::launch::async, [&]() { - load_aligned_bin_part(data_path, data, start, end - start); - insert_till_next_checkpoint(*index, start, end, sub_threads, data, aligned_dim, - location_to_labels); - }); - insert_task.wait(); - - if (!delete_launched && end >= start_deletes_after && - end >= points_to_skip + points_to_delete_from_beginning) - { - delete_launched = true; - diskann::IndexWriteParameters delete_params = - diskann::IndexWriteParametersBuilder(params).with_num_threads(sub_threads).build(); - - delete_task = std::async(std::launch::async, [&]() { - delete_from_beginning(*index, delete_params, points_to_skip, - points_to_delete_from_beginning); - }); - } - } - delete_task.wait(); - - std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n"; - index->save(save_path_inc.c_str(), true); - } - else - { - const auto save_path_inc = get_save_filename(save_path + ".after-delete-", points_to_skip, - points_to_delete_from_beginning, last_point_threshold); - std::string labels_file_to_use = save_path_inc + "_label_formatted.txt"; - std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt"; - if (has_labels) - { - convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label); - auto parse_result = diskann::parse_formatted_label_file(labels_file_to_use); - location_to_labels = std::get<0>(parse_result); - } - - size_t last_snapshot_points_threshold = 0; - size_t num_checkpoints_till_snapshot = checkpoints_per_snapshot; - - for (size_t start = current_point_offset; start < last_point_threshold; - start += points_per_checkpoint, current_point_offset += points_per_checkpoint) - { - const size_t end = std::min(start + points_per_checkpoint, last_point_threshold); - std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl; - - load_aligned_bin_part(data_path, data, start, end - start); - insert_till_next_checkpoint(*index, start, end, (int32_t)params.num_threads, data, - aligned_dim, location_to_labels); - - if (checkpoints_per_snapshot > 0 && 
--num_checkpoints_till_snapshot == 0) - { - diskann::Timer save_timer; - - const auto save_path_inc = - get_save_filename(save_path + ".inc-", points_to_skip, points_to_delete_from_beginning, end); - index->save(save_path_inc.c_str(), false); - const double elapsedSeconds = save_timer.elapsed() / 1000000.0; - const size_t points_saved = end - points_to_skip; - - std::cout << "Saved " << points_saved << " points in " << elapsedSeconds << " seconds (" - << points_saved / elapsedSeconds << " points/second)\n"; - - num_checkpoints_till_snapshot = checkpoints_per_snapshot; - last_snapshot_points_threshold = end; - } - - std::cout << "Number of points in the index post insertion " << end << std::endl; - } - - if (checkpoints_per_snapshot > 0 && last_snapshot_points_threshold != last_point_threshold) - { - const auto save_path_inc = get_save_filename(save_path + ".inc-", points_to_skip, - points_to_delete_from_beginning, last_point_threshold); - // index.save(save_path_inc.c_str(), false); - } - - if (points_to_delete_from_beginning > 0) - { - delete_from_beginning(*index, params, points_to_skip, points_to_delete_from_beginning); - } - - index->save(save_path_inc.c_str(), true); + if (checkpoints_per_snapshot > 0 && + last_snapshot_points_threshold != last_point_threshold) { + const auto save_path_inc = get_save_filename( + save_path + ".inc-", points_to_skip, points_to_delete_from_beginning, + last_point_threshold); + // index.save(save_path_inc.c_str(), false); } - diskann::aligned_free(data); -} - -int main(int argc, char **argv) -{ - std::string data_type, dist_fn, data_path, index_path_prefix; - uint32_t num_threads, R, L, num_start_pts; - float alpha, start_point_norm; - size_t points_to_skip, max_points_to_insert, beginning_index_size, points_per_checkpoint, checkpoints_per_snapshot, - points_to_delete_from_beginning, start_deletes_after; - bool concurrent; - - // label options - std::string label_file, label_type, universal_label; - std::uint32_t Lf, unique_labels_supported; - - po::options_description desc{program_options_utils::make_program_description("test_insert_deletes_consolidate", - "Test insert deletes & consolidate")}; - try - { - desc.add_options()("help,h", "Print information on arguments"); - - // Required parameters - po::options_description required_configs("Required"); - required_configs.add_options()("data_type", po::value(&data_type)->required(), - program_options_utils::DATA_TYPE_DESCRIPTION); - required_configs.add_options()("dist_fn", po::value(&dist_fn)->required(), - program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); - required_configs.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), - program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); - required_configs.add_options()("data_path", po::value(&data_path)->required(), - program_options_utils::INPUT_DATA_PATH); - required_configs.add_options()("points_to_skip", po::value(&points_to_skip)->required(), - "Skip these first set of points from file"); - required_configs.add_options()("beginning_index_size", po::value(&beginning_index_size)->required(), - "Batch build will be called on these set of points"); - required_configs.add_options()("points_per_checkpoint", po::value(&points_per_checkpoint)->required(), - "Insertions are done in batches of points_per_checkpoint"); - required_configs.add_options()("checkpoints_per_snapshot", - po::value(&checkpoints_per_snapshot)->required(), - "Save the index to disk every few checkpoints"); - 
required_configs.add_options()("points_to_delete_from_beginning", - po::value(&points_to_delete_from_beginning)->required(), ""); - - // Optional parameters - po::options_description optional_configs("Optional"); - optional_configs.add_options()("num_threads,T", - po::value(&num_threads)->default_value(omp_get_num_procs()), - program_options_utils::NUMBER_THREADS_DESCRIPTION); - optional_configs.add_options()("max_degree,R", po::value(&R)->default_value(64), - program_options_utils::MAX_BUILD_DEGREE); - optional_configs.add_options()("Lbuild,L", po::value(&L)->default_value(100), - program_options_utils::GRAPH_BUILD_COMPLEXITY); - optional_configs.add_options()("alpha", po::value(&alpha)->default_value(1.2f), - program_options_utils::GRAPH_BUILD_ALPHA); - optional_configs.add_options()("max_points_to_insert", - po::value(&max_points_to_insert)->default_value(0), - "These number of points from the file are inserted after " - "points_to_skip"); - optional_configs.add_options()("do_concurrent", po::value(&concurrent)->default_value(false), ""); - optional_configs.add_options()("start_deletes_after", - po::value(&start_deletes_after)->default_value(0), ""); - optional_configs.add_options()("start_point_norm", po::value(&start_point_norm)->default_value(0), - "Set the start point to a random point on a sphere of this radius"); - - // optional params for filters - optional_configs.add_options()("label_file", po::value(&label_file)->default_value(""), - "Input label file in txt format for Filtered Index search. " - "The file should contain comma separated filters for each node " - "with each line corresponding to a graph node"); - optional_configs.add_options()("universal_label", po::value(&universal_label)->default_value(""), - "Universal label, if using it, only in conjunction with labels_file"); - optional_configs.add_options()("FilteredLbuild,Lf", po::value(&Lf)->default_value(0), - "Build complexity for filtered points, higher value " - "results in better graphs"); - optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), - "Storage type of Labels , default value is uint which " - "will consume memory 4 bytes per filter"); - optional_configs.add_options()("unique_labels_supported", - po::value(&unique_labels_supported)->default_value(0), - "Number of unique labels supported by the dynamic index."); - - optional_configs.add_options()( - "num_start_points", - po::value(&num_start_pts)->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC), - "Set the number of random start (frozen) points to use when " - "inserting and searching"); - - // Merge required and optional parameters - desc.add(required_configs).add(optional_configs); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) - { - std::cout << desc; - return 0; - } - po::notify(vm); - if (beginning_index_size == 0) - if (start_point_norm == 0) - { - std::cout << "When beginning_index_size is 0, use a start " - "point with " - "appropriate norm" - << std::endl; - return -1; - } - } - catch (const std::exception &ex) - { - std::cerr << ex.what() << '\n'; - return -1; + if (points_to_delete_from_beginning > 0) { + delete_from_beginning(*index, params, points_to_skip, + points_to_delete_from_beginning); } - bool has_labels = false; - if (!label_file.empty() || label_file != "") - { - has_labels = true; - } + index->save(save_path_inc.c_str(), true); + } - if (num_start_pts < unique_labels_supported) - { - num_start_pts = unique_labels_supported; 
- } + diskann::aligned_free(data); +} - try - { - diskann::IndexWriteParameters params = diskann::IndexWriteParametersBuilder(L, R) - .with_max_occlusion_size(500) - .with_alpha(alpha) - .with_num_threads(num_threads) - .with_filter_list_size(Lf) - .build(); - - if (data_type == std::string("int8")) - build_incremental_index( - data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm, - num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix, - points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label); - else if (data_type == std::string("uint8")) - build_incremental_index( - data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm, - num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix, - points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label); - else if (data_type == std::string("float")) - build_incremental_index(data_path, params, points_to_skip, max_points_to_insert, - beginning_index_size, start_point_norm, num_start_pts, points_per_checkpoint, - checkpoints_per_snapshot, index_path_prefix, points_to_delete_from_beginning, - start_deletes_after, concurrent, label_file, universal_label); - else - std::cout << "Unsupported type. Use float/int8/uint8" << std::endl; +int main(int argc, char **argv) { + std::string data_type, dist_fn, data_path, index_path_prefix; + uint32_t num_threads, R, L, num_start_pts; + float alpha, start_point_norm; + size_t points_to_skip, max_points_to_insert, beginning_index_size, + points_per_checkpoint, checkpoints_per_snapshot, + points_to_delete_from_beginning, start_deletes_after; + bool concurrent; + + // label options + std::string label_file, label_type, universal_label; + std::uint32_t Lf, unique_labels_supported; + + po::options_description desc{program_options_utils::make_program_description( + "test_insert_deletes_consolidate", "Test insert deletes & consolidate")}; + try { + desc.add_options()("help,h", "Print information on arguments"); + + // Required parameters + po::options_description required_configs("Required"); + required_configs.add_options()( + "data_type", po::value(&data_type)->required(), + program_options_utils::DATA_TYPE_DESCRIPTION); + required_configs.add_options()( + "dist_fn", po::value(&dist_fn)->required(), + program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); + required_configs.add_options()( + "index_path_prefix", + po::value(&index_path_prefix)->required(), + program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); + required_configs.add_options()( + "data_path", po::value(&data_path)->required(), + program_options_utils::INPUT_DATA_PATH); + required_configs.add_options()( + "points_to_skip", po::value(&points_to_skip)->required(), + "Skip these first set of points from file"); + required_configs.add_options()( + "beginning_index_size", + po::value(&beginning_index_size)->required(), + "Batch build will be called on these set of points"); + required_configs.add_options()( + "points_per_checkpoint", + po::value(&points_per_checkpoint)->required(), + "Insertions are done in batches of points_per_checkpoint"); + required_configs.add_options()( + "checkpoints_per_snapshot", + po::value(&checkpoints_per_snapshot)->required(), + "Save the index to disk every few checkpoints"); + required_configs.add_options()( + "points_to_delete_from_beginning", + po::value(&points_to_delete_from_beginning)->required(), ""); + + // Optional 
parameters + po::options_description optional_configs("Optional"); + optional_configs.add_options()( + "num_threads,T", + po::value(&num_threads)->default_value(omp_get_num_procs()), + program_options_utils::NUMBER_THREADS_DESCRIPTION); + optional_configs.add_options()("max_degree,R", + po::value(&R)->default_value(64), + program_options_utils::MAX_BUILD_DEGREE); + optional_configs.add_options()( + "Lbuild,L", po::value(&L)->default_value(100), + program_options_utils::GRAPH_BUILD_COMPLEXITY); + optional_configs.add_options()( + "alpha", po::value(&alpha)->default_value(1.2f), + program_options_utils::GRAPH_BUILD_ALPHA); + optional_configs.add_options()( + "max_points_to_insert", + po::value(&max_points_to_insert)->default_value(0), + "These number of points from the file are inserted after " + "points_to_skip"); + optional_configs.add_options()( + "do_concurrent", po::value(&concurrent)->default_value(false), + ""); + optional_configs.add_options()( + "start_deletes_after", + po::value(&start_deletes_after)->default_value(0), ""); + optional_configs.add_options()( + "start_point_norm", + po::value(&start_point_norm)->default_value(0), + "Set the start point to a random point on a sphere of this radius"); + + // optional params for filters + optional_configs.add_options()( + "label_file", po::value(&label_file)->default_value(""), + "Input label file in txt format for Filtered Index search. " + "The file should contain comma separated filters for each node " + "with each line corresponding to a graph node"); + optional_configs.add_options()( + "universal_label", + po::value(&universal_label)->default_value(""), + "Universal label, if using it, only in conjunction with labels_file"); + optional_configs.add_options()( + "FilteredLbuild,Lf", po::value(&Lf)->default_value(0), + "Build complexity for filtered points, higher value " + "results in better graphs"); + optional_configs.add_options()( + "label_type", + po::value(&label_type)->default_value("uint"), + "Storage type of Labels , default value is uint which " + "will consume memory 4 bytes per filter"); + optional_configs.add_options()( + "unique_labels_supported", + po::value(&unique_labels_supported)->default_value(0), + "Number of unique labels supported by the dynamic index."); + + optional_configs.add_options()( + "num_start_points", + po::value(&num_start_pts) + ->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC), + "Set the number of random start (frozen) points to use when " + "inserting and searching"); + + // Merge required and optional parameters + desc.add(required_configs).add(optional_configs); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) { + std::cout << desc; + return 0; } - catch (const std::exception &e) - { - std::cerr << "Caught exception: " << e.what() << std::endl; - exit(-1); - } - catch (...) 
- { - std::cerr << "Caught unknown exception" << std::endl; - exit(-1); - } - - return 0; + po::notify(vm); + if (beginning_index_size == 0) + if (start_point_norm == 0) { + std::cout << "When beginning_index_size is 0, use a start " + "point with " + "appropriate norm" + << std::endl; + return -1; + } + } catch (const std::exception &ex) { + std::cerr << ex.what() << '\n'; + return -1; + } + + bool has_labels = false; + if (!label_file.empty() || label_file != "") { + has_labels = true; + } + + if (num_start_pts < unique_labels_supported) { + num_start_pts = unique_labels_supported; + } + + try { + diskann::IndexWriteParameters params = + diskann::IndexWriteParametersBuilder(L, R) + .with_max_occlusion_size(500) + .with_alpha(alpha) + .with_num_threads(num_threads) + .with_filter_list_size(Lf) + .build(); + + if (data_type == std::string("int8")) + build_incremental_index( + data_path, params, points_to_skip, max_points_to_insert, + beginning_index_size, start_point_norm, num_start_pts, + points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix, + points_to_delete_from_beginning, start_deletes_after, concurrent, + label_file, universal_label); + else if (data_type == std::string("uint8")) + build_incremental_index( + data_path, params, points_to_skip, max_points_to_insert, + beginning_index_size, start_point_norm, num_start_pts, + points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix, + points_to_delete_from_beginning, start_deletes_after, concurrent, + label_file, universal_label); + else if (data_type == std::string("float")) + build_incremental_index( + data_path, params, points_to_skip, max_points_to_insert, + beginning_index_size, start_point_norm, num_start_pts, + points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix, + points_to_delete_from_beginning, start_deletes_after, concurrent, + label_file, universal_label); + else + std::cout << "Unsupported type. Use float/int8/uint8" << std::endl; + } catch (const std::exception &e) { + std::cerr << "Caught exception: " << e.what() << std::endl; + exit(-1); + } catch (...) { + std::cerr << "Caught unknown exception" << std::endl; + exit(-1); + } + + return 0; } diff --git a/apps/test_streaming_scenario.cpp b/apps/test_streaming_scenario.cpp index 5a43a69f3..b51db73fd 100644 --- a/apps/test_streaming_scenario.cpp +++ b/apps/test_streaming_scenario.cpp @@ -1,20 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. +#include +#include +#include #include +#include #include #include #include #include #include -#include -#include -#include -#include -#include "utils.h" #include "filter_utils.h" #include "program_options_utils.hpp" +#include "utils.h" #ifndef _WINDOWS #include @@ -29,495 +29,516 @@ namespace po = boost::program_options; // load_aligned_bin modified to read pieces of the file, but using ifstream // instead of cached_ifstream. 
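// Illustrative sketch of the .bin layout that load_aligned_bin_part (below)
// assumes: a header of two int32 values (number of points, then dimension)
// followed by npts * dim values of type T in row-major order; rows are padded
// to ROUND_UP(dim, 8) only once loaded into aligned memory. read_bin_header is
// a hypothetical helper name used purely for illustration, not part of this
// codebase.

#include <cstdint>
#include <fstream>
#include <string>

inline void read_bin_header(const std::string &bin_file, size_t &npts,
                            size_t &dim) {
  std::ifstream reader(bin_file, std::ios::binary);
  int32_t npts_i32 = 0, dim_i32 = 0;
  reader.read(reinterpret_cast<char *>(&npts_i32), sizeof(int32_t));
  reader.read(reinterpret_cast<char *>(&dim_i32), sizeof(int32_t));
  npts = static_cast<size_t>(npts_i32);
  dim = static_cast<size_t>(dim_i32);
  // A slice of points [offset, offset + n) of type T then starts at byte
  // offset 2 * sizeof(int32_t) + offset * dim * sizeof(T), which is how the
  // partial loads below seek into the file.
}
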
template -inline void load_aligned_bin_part(const std::string &bin_file, T *data, size_t offset_points, size_t points_to_read) -{ - std::ifstream reader; - reader.exceptions(std::ios::failbit | std::ios::badbit); - reader.open(bin_file, std::ios::binary | std::ios::ate); - size_t actual_file_size = reader.tellg(); - reader.seekg(0, std::ios::beg); - - int npts_i32, dim_i32; - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&dim_i32, sizeof(int)); - size_t npts = (uint32_t)npts_i32; - size_t dim = (uint32_t)dim_i32; - - size_t expected_actual_file_size = npts * dim * sizeof(T) + 2 * sizeof(uint32_t); - if (actual_file_size != expected_actual_file_size) - { - std::stringstream stream; - stream << "Error. File size mismatch. Actual size is " << actual_file_size << " while expected size is " - << expected_actual_file_size << " npts = " << npts << " dim = " << dim << " size of = " << sizeof(T) - << std::endl; - std::cout << stream.str(); - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); - } - - if (offset_points + points_to_read > npts) - { - std::stringstream stream; - stream << "Error. Not enough points in file. Requested " << offset_points << " offset and " << points_to_read - << " points, but have only " << npts << " points" << std::endl; - std::cout << stream.str(); - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); - } - - reader.seekg(2 * sizeof(uint32_t) + offset_points * dim * sizeof(T)); - - const size_t rounded_dim = ROUND_UP(dim, 8); - - for (size_t i = 0; i < points_to_read; i++) - { - reader.read((char *)(data + i * rounded_dim), dim * sizeof(T)); - memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T)); - } - reader.close(); +inline void load_aligned_bin_part(const std::string &bin_file, T *data, + size_t offset_points, size_t points_to_read) { + std::ifstream reader; + reader.exceptions(std::ios::failbit | std::ios::badbit); + reader.open(bin_file, std::ios::binary | std::ios::ate); + size_t actual_file_size = reader.tellg(); + reader.seekg(0, std::ios::beg); + + int npts_i32, dim_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&dim_i32, sizeof(int)); + size_t npts = (uint32_t)npts_i32; + size_t dim = (uint32_t)dim_i32; + + size_t expected_actual_file_size = + npts * dim * sizeof(T) + 2 * sizeof(uint32_t); + if (actual_file_size != expected_actual_file_size) { + std::stringstream stream; + stream << "Error. File size mismatch. Actual size is " << actual_file_size + << " while expected size is " << expected_actual_file_size + << " npts = " << npts << " dim = " << dim + << " size of = " << sizeof(T) << std::endl; + std::cout << stream.str(); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, + __LINE__); + } + + if (offset_points + points_to_read > npts) { + std::stringstream stream; + stream << "Error. Not enough points in file. 
Requested " << offset_points + << " offset and " << points_to_read << " points, but have only " + << npts << " points" << std::endl; + std::cout << stream.str(); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, + __LINE__); + } + + reader.seekg(2 * sizeof(uint32_t) + offset_points * dim * sizeof(T)); + + const size_t rounded_dim = ROUND_UP(dim, 8); + + for (size_t i = 0; i < points_to_read; i++) { + reader.read((char *)(data + i * rounded_dim), dim * sizeof(T)); + memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T)); + } + reader.close(); } -std::string get_save_filename(const std::string &save_path, size_t active_window, size_t consolidate_interval, - size_t max_points_to_insert) -{ - std::string final_path = save_path; - final_path += "act" + std::to_string(active_window) + "-"; - final_path += "cons" + std::to_string(consolidate_interval) + "-"; - final_path += "max" + std::to_string(max_points_to_insert); - return final_path; +std::string get_save_filename(const std::string &save_path, + size_t active_window, size_t consolidate_interval, + size_t max_points_to_insert) { + std::string final_path = save_path; + final_path += "act" + std::to_string(active_window) + "-"; + final_path += "cons" + std::to_string(consolidate_interval) + "-"; + final_path += "max" + std::to_string(max_points_to_insert); + return final_path; } template -void insert_next_batch(diskann::AbstractIndex &index, size_t start, size_t end, size_t insert_threads, T *data, - size_t aligned_dim, std::vector> &pts_to_labels) -{ - try - { - diskann::Timer insert_timer; - std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl; - - size_t num_failed = 0; +void insert_next_batch(diskann::AbstractIndex &index, size_t start, size_t end, + size_t insert_threads, T *data, size_t aligned_dim, + std::vector> &pts_to_labels) { + try { + diskann::Timer insert_timer; + std::cout << std::endl + << "Inserting from " << start << " to " << end << std::endl; + + size_t num_failed = 0; #pragma omp parallel for num_threads((int32_t)insert_threads) schedule(dynamic) reduction(+ : num_failed) - for (int64_t j = start; j < (int64_t)end; j++) - { - int insert_result = -1; - if (pts_to_labels.size() > 0) - { - insert_result = index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast(j), - pts_to_labels[j - start]); - } - else - { - insert_result = index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast(j)); - } - - if (insert_result != 0) - { - std::cerr << "Insert failed " << j << std::endl; - num_failed++; - } - } - const double elapsedSeconds = insert_timer.elapsed() / 1000000.0; - std::cout << "Insertion time " << elapsedSeconds << " seconds (" << (end - start) / elapsedSeconds - << " points/second overall, " << (end - start) / elapsedSeconds / insert_threads << " per thread)" - << std::endl; - if (num_failed > 0) - std::cout << num_failed << " of " << end - start << "inserts failed" << std::endl; - } - catch (std::system_error &e) - { - std::cout << "Exiting after catching exception in insertion task: " << e.what() << std::endl; - exit(-1); + for (int64_t j = start; j < (int64_t)end; j++) { + int insert_result = -1; + if (pts_to_labels.size() > 0) { + insert_result = index.insert_point(&data[(j - start) * aligned_dim], + 1 + static_cast(j), + pts_to_labels[j - start]); + } else { + insert_result = index.insert_point(&data[(j - start) * aligned_dim], + 1 + static_cast(j)); + } + + if (insert_result != 0) { + std::cerr << "Insert failed " << j << 
std::endl; + num_failed++; + } } + const double elapsedSeconds = insert_timer.elapsed() / 1000000.0; + std::cout << "Insertion time " << elapsedSeconds << " seconds (" + << (end - start) / elapsedSeconds << " points/second overall, " + << (end - start) / elapsedSeconds / insert_threads + << " per thread)" << std::endl; + if (num_failed > 0) + std::cout << num_failed << " of " << end - start << "inserts failed" + << std::endl; + } catch (std::system_error &e) { + std::cout << "Exiting after catching exception in insertion task: " + << e.what() << std::endl; + exit(-1); + } } template -void delete_and_consolidate(diskann::AbstractIndex &index, diskann::IndexWriteParameters &delete_params, size_t start, - size_t end) -{ - try - { - std::cout << std::endl << "Lazy deleting points " << start << " to " << end << "... "; - for (size_t i = start; i < end; ++i) - index.lazy_delete(static_cast(1 + i)); - std::cout << "lazy delete done." << std::endl; - - auto report = index.consolidate_deletes(delete_params); - while (report._status != diskann::consolidation_report::status_code::SUCCESS) - { - int wait_time = 5; - if (report._status == diskann::consolidation_report::status_code::LOCK_FAIL) - { - diskann::cerr << "Unable to acquire consolidate delete lock after " - << "deleting points " << start << " to " << end << ". Will retry in " << wait_time - << "seconds." << std::endl; - } - else if (report._status == diskann::consolidation_report::status_code::INCONSISTENT_COUNT_ERROR) - { - diskann::cerr << "Inconsistent counts in data structure. " - << "Will retry in " << wait_time << "seconds." << std::endl; - } - else - { - std::cerr << "Exiting after unknown error in consolidate delete" << std::endl; - exit(-1); - } - std::this_thread::sleep_for(std::chrono::seconds(wait_time)); - report = index.consolidate_deletes(delete_params); - } - auto points_processed = report._active_points + report._slots_released; - auto deletion_rate = points_processed / report._time; - std::cout << "#active points: " << report._active_points << std::endl - << "max points: " << report._max_points << std::endl - << "empty slots: " << report._empty_slots << std::endl - << "deletes processed: " << report._slots_released << std::endl - << "latest delete size: " << report._delete_set_size << std::endl - << "Deletion rate: " << deletion_rate << "/sec " - << "Deletion rate: " << deletion_rate / delete_params.num_threads << "/thread/sec " << std::endl; - } - catch (std::system_error &e) - { - std::cerr << "Exiting after catching exception in deletion task: " << e.what() << std::endl; +void delete_and_consolidate(diskann::AbstractIndex &index, + diskann::IndexWriteParameters &delete_params, + size_t start, size_t end) { + try { + std::cout << std::endl + << "Lazy deleting points " << start << " to " << end << "... "; + for (size_t i = start; i < end; ++i) + index.lazy_delete(static_cast(1 + i)); + std::cout << "lazy delete done." << std::endl; + + auto report = index.consolidate_deletes(delete_params); + while (report._status != + diskann::consolidation_report::status_code::SUCCESS) { + int wait_time = 5; + if (report._status == + diskann::consolidation_report::status_code::LOCK_FAIL) { + diskann::cerr << "Unable to acquire consolidate delete lock after " + << "deleting points " << start << " to " << end + << ". Will retry in " << wait_time << "seconds." + << std::endl; + } else if (report._status == diskann::consolidation_report::status_code:: + INCONSISTENT_COUNT_ERROR) { + diskann::cerr << "Inconsistent counts in data structure. 
" + << "Will retry in " << wait_time << "seconds." + << std::endl; + } else { + std::cerr << "Exiting after unknown error in consolidate delete" + << std::endl; exit(-1); + } + std::this_thread::sleep_for(std::chrono::seconds(wait_time)); + report = index.consolidate_deletes(delete_params); } + auto points_processed = report._active_points + report._slots_released; + auto deletion_rate = points_processed / report._time; + std::cout << "#active points: " << report._active_points << std::endl + << "max points: " << report._max_points << std::endl + << "empty slots: " << report._empty_slots << std::endl + << "deletes processed: " << report._slots_released << std::endl + << "latest delete size: " << report._delete_set_size << std::endl + << "Deletion rate: " << deletion_rate << "/sec " + << "Deletion rate: " << deletion_rate / delete_params.num_threads + << "/thread/sec " << std::endl; + } catch (std::system_error &e) { + std::cerr << "Exiting after catching exception in deletion task: " + << e.what() << std::endl; + exit(-1); + } } template -void build_incremental_index(const std::string &data_path, const uint32_t L, const uint32_t R, const float alpha, - const uint32_t insert_threads, const uint32_t consolidate_threads, - size_t max_points_to_insert, size_t active_window, size_t consolidate_interval, - const float start_point_norm, uint32_t num_start_pts, const std::string &save_path, - const std::string &label_file, const std::string &universal_label, const uint32_t Lf) -{ - const uint32_t C = 500; - const bool saturate_graph = false; - bool has_labels = label_file != ""; - - diskann::IndexWriteParameters params = diskann::IndexWriteParametersBuilder(L, R) - .with_max_occlusion_size(C) - .with_alpha(alpha) - .with_saturate_graph(saturate_graph) - .with_num_threads(insert_threads) - .with_filter_list_size(Lf) - .build(); - - auto index_search_params = diskann::IndexSearchParams(L, insert_threads); - diskann::IndexWriteParameters delete_params = diskann::IndexWriteParametersBuilder(L, R) - .with_max_occlusion_size(C) - .with_alpha(alpha) - .with_saturate_graph(saturate_graph) - .with_num_threads(consolidate_threads) - .with_filter_list_size(Lf) - .build(); - - size_t dim, aligned_dim; - size_t num_points; - - std::vector> pts_to_labels; - - const auto save_path_inc = - get_save_filename(save_path + ".after-streaming-", active_window, consolidate_interval, max_points_to_insert); - std::string labels_file_to_use = save_path_inc + "_label_formatted.txt"; - std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt"; - if (has_labels) - { - convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label); - auto parse_result = diskann::parse_formatted_label_file(labels_file_to_use); - pts_to_labels = std::get<0>(parse_result); - } - - diskann::get_bin_metadata(data_path, num_points, dim); - diskann::cout << "metadata: file " << data_path << " has " << num_points << " points in " << dim << " dims" - << std::endl; - aligned_dim = ROUND_UP(dim, 8); - auto index_config = diskann::IndexConfigBuilder() - .with_metric(diskann::L2) - .with_dimension(dim) - .with_max_points(active_window + 4 * consolidate_interval) - .is_dynamic_index(true) - .is_enable_tags(true) - .is_use_opq(false) - .is_filtered(has_labels) - .with_num_pq_chunks(0) - .is_pq_dist_build(false) - .with_num_frozen_pts(num_start_pts) - .with_tag_type(diskann_type_to_name()) - .with_label_type(diskann_type_to_name()) - .with_data_type(diskann_type_to_name()) - 
.with_index_write_params(params) - .with_index_search_params(index_search_params) - .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY) - .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY) - .build(); - - diskann::IndexFactory index_factory = diskann::IndexFactory(index_config); - auto index = index_factory.create_instance(); - - if (universal_label != "") - { - LabelT u_label = 0; - index->set_universal_label(u_label); - } - - if (max_points_to_insert == 0) - { - max_points_to_insert = num_points; - } - - if (num_points < max_points_to_insert) - throw diskann::ANNException(std::string("num_points(") + std::to_string(num_points) + - ") < max_points_to_insert(" + std::to_string(max_points_to_insert) + ")", - -1, __FUNCSIG__, __FILE__, __LINE__); - - if (max_points_to_insert < active_window + consolidate_interval) - throw diskann::ANNException("ERROR: max_points_to_insert < " - "active_window + consolidate_interval", - -1, __FUNCSIG__, __FILE__, __LINE__); - - if (consolidate_interval < max_points_to_insert / 1000) - throw diskann::ANNException("ERROR: consolidate_interval is too small", -1, __FUNCSIG__, __FILE__, __LINE__); - - index->set_start_points_at_random(static_cast(start_point_norm)); - - T *data = nullptr; - diskann::alloc_aligned((void **)&data, std::max(consolidate_interval, active_window) * aligned_dim * sizeof(T), - 8 * sizeof(T)); - - std::vector tags(max_points_to_insert); - std::iota(tags.begin(), tags.end(), static_cast(0)); - - diskann::Timer timer; - - std::vector> delete_tasks; - +void build_incremental_index( + const std::string &data_path, const uint32_t L, const uint32_t R, + const float alpha, const uint32_t insert_threads, + const uint32_t consolidate_threads, size_t max_points_to_insert, + size_t active_window, size_t consolidate_interval, + const float start_point_norm, uint32_t num_start_pts, + const std::string &save_path, const std::string &label_file, + const std::string &universal_label, const uint32_t Lf) { + const uint32_t C = 500; + const bool saturate_graph = false; + bool has_labels = label_file != ""; + + diskann::IndexWriteParameters params = + diskann::IndexWriteParametersBuilder(L, R) + .with_max_occlusion_size(C) + .with_alpha(alpha) + .with_saturate_graph(saturate_graph) + .with_num_threads(insert_threads) + .with_filter_list_size(Lf) + .build(); + + auto index_search_params = diskann::IndexSearchParams(L, insert_threads); + diskann::IndexWriteParameters delete_params = + diskann::IndexWriteParametersBuilder(L, R) + .with_max_occlusion_size(C) + .with_alpha(alpha) + .with_saturate_graph(saturate_graph) + .with_num_threads(consolidate_threads) + .with_filter_list_size(Lf) + .build(); + + size_t dim, aligned_dim; + size_t num_points; + + std::vector> pts_to_labels; + + const auto save_path_inc = + get_save_filename(save_path + ".after-streaming-", active_window, + consolidate_interval, max_points_to_insert); + std::string labels_file_to_use = save_path_inc + "_label_formatted.txt"; + std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt"; + if (has_labels) { + convert_labels_string_to_int(label_file, labels_file_to_use, + mem_labels_int_map_file, universal_label); + auto parse_result = + diskann::parse_formatted_label_file(labels_file_to_use); + pts_to_labels = std::get<0>(parse_result); + } + + diskann::get_bin_metadata(data_path, num_points, dim); + diskann::cout << "metadata: file " << data_path << " has " << num_points + << " points in " << dim << " dims" << std::endl; + aligned_dim = 
ROUND_UP(dim, 8); + auto index_config = + diskann::IndexConfigBuilder() + .with_metric(diskann::L2) + .with_dimension(dim) + .with_max_points(active_window + 4 * consolidate_interval) + .is_dynamic_index(true) + .is_enable_tags(true) + .is_use_opq(false) + .is_filtered(has_labels) + .with_num_pq_chunks(0) + .is_pq_dist_build(false) + .with_num_frozen_pts(num_start_pts) + .with_tag_type(diskann_type_to_name()) + .with_label_type(diskann_type_to_name()) + .with_data_type(diskann_type_to_name()) + .with_index_write_params(params) + .with_index_search_params(index_search_params) + .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY) + .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY) + .build(); + + diskann::IndexFactory index_factory = diskann::IndexFactory(index_config); + auto index = index_factory.create_instance(); + + if (universal_label != "") { + LabelT u_label = 0; + index->set_universal_label(u_label); + } + + if (max_points_to_insert == 0) { + max_points_to_insert = num_points; + } + + if (num_points < max_points_to_insert) + throw diskann::ANNException(std::string("num_points(") + + std::to_string(num_points) + + ") < max_points_to_insert(" + + std::to_string(max_points_to_insert) + ")", + -1, __FUNCSIG__, __FILE__, __LINE__); + + if (max_points_to_insert < active_window + consolidate_interval) + throw diskann::ANNException("ERROR: max_points_to_insert < " + "active_window + consolidate_interval", + -1, __FUNCSIG__, __FILE__, __LINE__); + + if (consolidate_interval < max_points_to_insert / 1000) + throw diskann::ANNException("ERROR: consolidate_interval is too small", -1, + __FUNCSIG__, __FILE__, __LINE__); + + index->set_start_points_at_random(static_cast(start_point_norm)); + + T *data = nullptr; + diskann::alloc_aligned((void **)&data, + std::max(consolidate_interval, active_window) * + aligned_dim * sizeof(T), + 8 * sizeof(T)); + + std::vector tags(max_points_to_insert); + std::iota(tags.begin(), tags.end(), static_cast(0)); + + diskann::Timer timer; + + std::vector> delete_tasks; + + auto insert_task = std::async(std::launch::async, [&]() { + load_aligned_bin_part(data_path, data, 0, active_window); + insert_next_batch(*index, (size_t)0, active_window, + params.num_threads, data, aligned_dim, + pts_to_labels); + }); + insert_task.wait(); + + for (size_t start = active_window; + start + consolidate_interval <= max_points_to_insert; + start += consolidate_interval) { + auto end = std::min(start + consolidate_interval, max_points_to_insert); auto insert_task = std::async(std::launch::async, [&]() { - load_aligned_bin_part(data_path, data, 0, active_window); - insert_next_batch(*index, (size_t)0, active_window, params.num_threads, data, aligned_dim, - pts_to_labels); + load_aligned_bin_part(data_path, data, start, end - start); + insert_next_batch(*index, start, end, params.num_threads, + data, aligned_dim, pts_to_labels); }); insert_task.wait(); - for (size_t start = active_window; start + consolidate_interval <= max_points_to_insert; - start += consolidate_interval) - { - auto end = std::min(start + consolidate_interval, max_points_to_insert); - auto insert_task = std::async(std::launch::async, [&]() { - load_aligned_bin_part(data_path, data, start, end - start); - insert_next_batch(*index, start, end, params.num_threads, data, aligned_dim, - pts_to_labels); - }); - insert_task.wait(); - - if (delete_tasks.size() > 0) - delete_tasks[delete_tasks.size() - 1].wait(); - if (start >= active_window + consolidate_interval) - { - auto start_del = start 
- active_window - consolidate_interval; - auto end_del = start - active_window; - - delete_tasks.emplace_back(std::async(std::launch::async, [&]() { - delete_and_consolidate(*index, delete_params, (size_t)start_del, (size_t)end_del); - })); - } - } if (delete_tasks.size() > 0) - delete_tasks[delete_tasks.size() - 1].wait(); - - std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n"; - - index->save(save_path_inc.c_str(), true); - - diskann::aligned_free(data); -} - -int main(int argc, char **argv) -{ - std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type; - uint32_t insert_threads, consolidate_threads, R, L, num_start_pts, Lf, unique_labels_supported; - float alpha, start_point_norm; - size_t max_points_to_insert, active_window, consolidate_interval; - - po::options_description desc{program_options_utils::make_program_description("test_streaming_scenario", - "Test insert deletes & consolidate")}; - try - { - desc.add_options()("help,h", "Print information on arguments"); - - // Required parameters - po::options_description required_configs("Required"); - required_configs.add_options()("data_type", po::value(&data_type)->required(), - program_options_utils::DATA_TYPE_DESCRIPTION); - required_configs.add_options()("dist_fn", po::value(&dist_fn)->required(), - program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); - required_configs.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), - program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); - required_configs.add_options()("data_path", po::value(&data_path)->required(), - program_options_utils::INPUT_DATA_PATH); - required_configs.add_options()("active_window", po::value(&active_window)->required(), - "Program maintains an index over an active window of " - "this size that slides through the data"); - required_configs.add_options()("consolidate_interval", po::value(&consolidate_interval)->required(), - "The program simultaneously adds this number of points to the " - "right of " - "the window while deleting the same number from the left"); - required_configs.add_options()("start_point_norm", po::value(&start_point_norm)->required(), - "Set the start point to a random point on a sphere of this radius"); - - // Optional parameters - po::options_description optional_configs("Optional"); - optional_configs.add_options()("max_degree,R", po::value(&R)->default_value(64), - program_options_utils::MAX_BUILD_DEGREE); - optional_configs.add_options()("Lbuild,L", po::value(&L)->default_value(100), - program_options_utils::GRAPH_BUILD_COMPLEXITY); - optional_configs.add_options()("alpha", po::value(&alpha)->default_value(1.2f), - program_options_utils::GRAPH_BUILD_ALPHA); - optional_configs.add_options()("insert_threads", - po::value(&insert_threads)->default_value(omp_get_num_procs() / 2), - "Number of threads used for inserting into the index (defaults to " - "omp_get_num_procs()/2)"); - optional_configs.add_options()( - "consolidate_threads", po::value(&consolidate_threads)->default_value(omp_get_num_procs() / 2), - "Number of threads used for consolidating deletes to " - "the index (defaults to omp_get_num_procs()/2)"); - optional_configs.add_options()("max_points_to_insert", - po::value(&max_points_to_insert)->default_value(0), - "The number of points from the file that the program streams " - "over "); - optional_configs.add_options()( - "num_start_points", - po::value(&num_start_pts)->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC), - "Set the number of 
random start (frozen) points to use when " - "inserting and searching"); - - optional_configs.add_options()("label_file", po::value(&label_file)->default_value(""), - "Input label file in txt format for Filtered Index search. " - "The file should contain comma separated filters for each node " - "with each line corresponding to a graph node"); - optional_configs.add_options()("universal_label", po::value(&universal_label)->default_value(""), - "Universal label, if using it, only in conjunction with labels_file"); - optional_configs.add_options()("FilteredLbuild,Lf", po::value(&Lf)->default_value(0), - "Build complexity for filtered points, higher value " - "results in better graphs"); - optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), - "Storage type of Labels , default value is uint which " - "will consume memory 4 bytes per filter"); - optional_configs.add_options()("unique_labels_supported", - po::value(&unique_labels_supported)->default_value(0), - "Number of unique labels supported by the dynamic index."); - - // Merge required and optional parameters - desc.add(required_configs).add(optional_configs); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) - { - std::cout << desc; - return 0; - } - po::notify(vm); - } - catch (const std::exception &ex) - { - std::cerr << ex.what() << '\n'; - return -1; - } - - // Validate arguments - if (start_point_norm == 0) - { - std::cout << "When beginning_index_size is 0, use a start point with " - "appropriate norm" - << std::endl; - return -1; + delete_tasks[delete_tasks.size() - 1].wait(); + if (start >= active_window + consolidate_interval) { + auto start_del = start - active_window - consolidate_interval; + auto end_del = start - active_window; + + delete_tasks.emplace_back(std::async(std::launch::async, [&]() { + delete_and_consolidate( + *index, delete_params, (size_t)start_del, (size_t)end_del); + })); } + } + if (delete_tasks.size() > 0) + delete_tasks[delete_tasks.size() - 1].wait(); - if (label_type != std::string("ushort") && label_type != std::string("uint")) - { - std::cerr << "Invalid label type. Supported types are uint and ushort" << std::endl; - return -1; - } - - if (data_type != std::string("int8") && data_type != std::string("uint8") && data_type != std::string("float")) - { - std::cerr << "Invalid data type. Supported types are int8, uint8 and float" << std::endl; - return -1; - } + std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n"; - // TODO: Are additional distance functions supported? - if (dist_fn != std::string("l2") && dist_fn != std::string("mips")) - { - std::cerr << "Invalid distance function. 
Supported functions are l2 and mips" << std::endl; - return -1; - } + index->save(save_path_inc.c_str(), true); - if (num_start_pts < unique_labels_supported) - { - num_start_pts = unique_labels_supported; - } + diskann::aligned_free(data); +} - try - { - if (data_type == std::string("uint8")) - { - if (label_type == std::string("ushort")) - { - build_incremental_index( - data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window, - consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file, - universal_label, Lf); - } - else if (label_type == std::string("uint")) - { - build_incremental_index( - data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window, - consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file, - universal_label, Lf); - } - } - else if (data_type == std::string("int8")) - { - if (label_type == std::string("ushort")) - { - build_incremental_index( - data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window, - consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file, - universal_label, Lf); - } - else if (label_type == std::string("uint")) - { - build_incremental_index( - data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window, - consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file, - universal_label, Lf); - } - } - else if (data_type == std::string("float")) - { - if (label_type == std::string("ushort")) - { - build_incremental_index( - data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window, - consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file, - universal_label, Lf); - } - else if (label_type == std::string("uint")) - { - build_incremental_index( - data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window, - consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file, - universal_label, Lf); - } - } +int main(int argc, char **argv) { + std::string data_type, dist_fn, data_path, index_path_prefix, label_file, + universal_label, label_type; + uint32_t insert_threads, consolidate_threads, R, L, num_start_pts, Lf, + unique_labels_supported; + float alpha, start_point_norm; + size_t max_points_to_insert, active_window, consolidate_interval; + + po::options_description desc{program_options_utils::make_program_description( + "test_streaming_scenario", "Test insert deletes & consolidate")}; + try { + desc.add_options()("help,h", "Print information on arguments"); + + // Required parameters + po::options_description required_configs("Required"); + required_configs.add_options()( + "data_type", po::value(&data_type)->required(), + program_options_utils::DATA_TYPE_DESCRIPTION); + required_configs.add_options()( + "dist_fn", po::value(&dist_fn)->required(), + program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); + required_configs.add_options()( + "index_path_prefix", + po::value(&index_path_prefix)->required(), + program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); + required_configs.add_options()( + "data_path", po::value(&data_path)->required(), + program_options_utils::INPUT_DATA_PATH); + required_configs.add_options()( + "active_window", po::value(&active_window)->required(), + "Program maintains an index over an active window of " + "this size that slides through the 
data"); + required_configs.add_options()( + "consolidate_interval", + po::value(&consolidate_interval)->required(), + "The program simultaneously adds this number of points to the " + "right of " + "the window while deleting the same number from the left"); + required_configs.add_options()( + "start_point_norm", po::value(&start_point_norm)->required(), + "Set the start point to a random point on a sphere of this radius"); + + // Optional parameters + po::options_description optional_configs("Optional"); + optional_configs.add_options()("max_degree,R", + po::value(&R)->default_value(64), + program_options_utils::MAX_BUILD_DEGREE); + optional_configs.add_options()( + "Lbuild,L", po::value(&L)->default_value(100), + program_options_utils::GRAPH_BUILD_COMPLEXITY); + optional_configs.add_options()( + "alpha", po::value(&alpha)->default_value(1.2f), + program_options_utils::GRAPH_BUILD_ALPHA); + optional_configs.add_options()( + "insert_threads", + po::value(&insert_threads) + ->default_value(omp_get_num_procs() / 2), + "Number of threads used for inserting into the index (defaults to " + "omp_get_num_procs()/2)"); + optional_configs.add_options()( + "consolidate_threads", + po::value(&consolidate_threads) + ->default_value(omp_get_num_procs() / 2), + "Number of threads used for consolidating deletes to " + "the index (defaults to omp_get_num_procs()/2)"); + optional_configs.add_options()( + "max_points_to_insert", + po::value(&max_points_to_insert)->default_value(0), + "The number of points from the file that the program streams " + "over "); + optional_configs.add_options()( + "num_start_points", + po::value(&num_start_pts) + ->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC), + "Set the number of random start (frozen) points to use when " + "inserting and searching"); + + optional_configs.add_options()( + "label_file", po::value(&label_file)->default_value(""), + "Input label file in txt format for Filtered Index search. 
" + "The file should contain comma separated filters for each node " + "with each line corresponding to a graph node"); + optional_configs.add_options()( + "universal_label", + po::value(&universal_label)->default_value(""), + "Universal label, if using it, only in conjunction with labels_file"); + optional_configs.add_options()( + "FilteredLbuild,Lf", po::value(&Lf)->default_value(0), + "Build complexity for filtered points, higher value " + "results in better graphs"); + optional_configs.add_options()( + "label_type", + po::value(&label_type)->default_value("uint"), + "Storage type of Labels , default value is uint which " + "will consume memory 4 bytes per filter"); + optional_configs.add_options()( + "unique_labels_supported", + po::value(&unique_labels_supported)->default_value(0), + "Number of unique labels supported by the dynamic index."); + + // Merge required and optional parameters + desc.add(required_configs).add(optional_configs); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) { + std::cout << desc; + return 0; } - catch (const std::exception &e) - { - std::cerr << "Caught exception: " << e.what() << std::endl; - exit(-1); + po::notify(vm); + } catch (const std::exception &ex) { + std::cerr << ex.what() << '\n'; + return -1; + } + + // Validate arguments + if (start_point_norm == 0) { + std::cout << "When beginning_index_size is 0, use a start point with " + "appropriate norm" + << std::endl; + return -1; + } + + if (label_type != std::string("ushort") && + label_type != std::string("uint")) { + std::cerr << "Invalid label type. Supported types are uint and ushort" + << std::endl; + return -1; + } + + if (data_type != std::string("int8") && data_type != std::string("uint8") && + data_type != std::string("float")) { + std::cerr << "Invalid data type. Supported types are int8, uint8 and float" + << std::endl; + return -1; + } + + // TODO: Are additional distance functions supported? + if (dist_fn != std::string("l2") && dist_fn != std::string("mips")) { + std::cerr + << "Invalid distance function. 
Supported functions are l2 and mips" + << std::endl; + return -1; + } + + if (num_start_pts < unique_labels_supported) { + num_start_pts = unique_labels_supported; + } + + try { + if (data_type == std::string("uint8")) { + if (label_type == std::string("ushort")) { + build_incremental_index( + data_path, L, R, alpha, insert_threads, consolidate_threads, + max_points_to_insert, active_window, consolidate_interval, + start_point_norm, num_start_pts, index_path_prefix, label_file, + universal_label, Lf); + } else if (label_type == std::string("uint")) { + build_incremental_index( + data_path, L, R, alpha, insert_threads, consolidate_threads, + max_points_to_insert, active_window, consolidate_interval, + start_point_norm, num_start_pts, index_path_prefix, label_file, + universal_label, Lf); + } + } else if (data_type == std::string("int8")) { + if (label_type == std::string("ushort")) { + build_incremental_index( + data_path, L, R, alpha, insert_threads, consolidate_threads, + max_points_to_insert, active_window, consolidate_interval, + start_point_norm, num_start_pts, index_path_prefix, label_file, + universal_label, Lf); + } else if (label_type == std::string("uint")) { + build_incremental_index( + data_path, L, R, alpha, insert_threads, consolidate_threads, + max_points_to_insert, active_window, consolidate_interval, + start_point_norm, num_start_pts, index_path_prefix, label_file, + universal_label, Lf); + } + } else if (data_type == std::string("float")) { + if (label_type == std::string("ushort")) { + build_incremental_index( + data_path, L, R, alpha, insert_threads, consolidate_threads, + max_points_to_insert, active_window, consolidate_interval, + start_point_norm, num_start_pts, index_path_prefix, label_file, + universal_label, Lf); + } else if (label_type == std::string("uint")) { + build_incremental_index( + data_path, L, R, alpha, insert_threads, consolidate_threads, + max_points_to_insert, active_window, consolidate_interval, + start_point_norm, num_start_pts, index_path_prefix, label_file, + universal_label, Lf); + } } - catch (...) - { - std::cerr << "Caught unknown exception" << std::endl; - exit(-1); - } - - return 0; + } catch (const std::exception &e) { + std::cerr << "Caught exception: " << e.what() << std::endl; + exit(-1); + } catch (...) { + std::cerr << "Caught unknown exception" << std::endl; + exit(-1); + } + + return 0; } diff --git a/apps/utils/bin_to_fvecs.cpp b/apps/utils/bin_to_fvecs.cpp index e9a6a8ecc..a9b86686c 100644 --- a/apps/utils/bin_to_fvecs.cpp +++ b/apps/utils/bin_to_fvecs.cpp @@ -1,63 +1,61 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
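// Illustrative sketch of the .fvecs record layout this converter targets: each
// vector is stored as a 4-byte unsigned dimension followed by ndims 32-bit
// floats, so one record occupies sizeof(unsigned) + ndims * sizeof(float)
// bytes. write_fvecs_record is a hypothetical helper name used only for
// illustration.

#include <cstdint>
#include <fstream>

inline void write_fvecs_record(std::ofstream &out, const float *vec,
                               uint32_t ndims) {
  out.write(reinterpret_cast<const char *>(&ndims), sizeof(uint32_t));
  out.write(reinterpret_cast<const char *>(vec), ndims * sizeof(float));
}
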
-#include #include "util.h" +#include -void block_convert(std::ifstream &writr, std::ofstream &readr, float *read_buf, float *write_buf, uint64_t npts, - uint64_t ndims) -{ - writr.write((char *)read_buf, npts * (ndims * sizeof(float) + sizeof(unsigned))); +void block_convert(std::ifstream &writr, std::ofstream &readr, float *read_buf, + float *write_buf, uint64_t npts, uint64_t ndims) { + writr.write((char *)read_buf, + npts * (ndims * sizeof(float) + sizeof(unsigned))); #pragma omp parallel for - for (uint64_t i = 0; i < npts; i++) - { - memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, ndims * sizeof(float)); - } - readr.read((char *)write_buf, npts * ndims * sizeof(float)); + for (uint64_t i = 0; i < npts; i++) { + memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, + ndims * sizeof(float)); + } + readr.read((char *)write_buf, npts * ndims * sizeof(float)); } -int main(int argc, char **argv) -{ - if (argc != 3) - { - std::cout << argv[0] << " input_bin output_fvecs" << std::endl; - exit(-1); - } - std::ifstream readr(argv[1], std::ios::binary); - int npts_s32; - int ndims_s32; - readr.read((char *)&npts_s32, sizeof(int32_t)); - readr.read((char *)&ndims_s32, sizeof(int32_t)); - size_t npts = npts_s32; - size_t ndims = ndims_s32; - uint32_t ndims_u32 = (uint32_t)ndims_s32; - // uint64_t fsize = writr.tellg(); - readr.seekg(0, std::ios::beg); - - unsigned ndims_u32; - writr.write((char *)&ndims_u32, sizeof(unsigned)); - writr.seekg(0, std::ios::beg); - uint64_t ndims = (uint64_t)ndims_u32; - uint64_t npts = fsize / ((ndims + 1) * sizeof(float)); - std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl; - - uint64_t blk_size = 131072; - uint64_t nblks = ROUND_UP(npts, blk_size) / blk_size; - std::cout << "# blks: " << nblks << std::endl; - - std::ofstream writr(argv[2], std::ios::binary); - float *read_buf = new float[npts * (ndims + 1)]; - float *write_buf = new float[npts * ndims]; - for (uint64_t i = 0; i < nblks; i++) - { - uint64_t cblk_size = std::min(npts - i * blk_size, blk_size); - block_convert(writr, readr, read_buf, write_buf, cblk_size, ndims); - std::cout << "Block #" << i << " written" << std::endl; - } - - delete[] read_buf; - delete[] write_buf; - - writr.close(); - readr.close(); +int main(int argc, char **argv) { + if (argc != 3) { + std::cout << argv[0] << " input_bin output_fvecs" << std::endl; + exit(-1); + } + std::ifstream readr(argv[1], std::ios::binary); + int npts_s32; + int ndims_s32; + readr.read((char *)&npts_s32, sizeof(int32_t)); + readr.read((char *)&ndims_s32, sizeof(int32_t)); + size_t npts = npts_s32; + size_t ndims = ndims_s32; + uint32_t ndims_u32 = (uint32_t)ndims_s32; + // uint64_t fsize = writr.tellg(); + readr.seekg(0, std::ios::beg); + + unsigned ndims_u32; + writr.write((char *)&ndims_u32, sizeof(unsigned)); + writr.seekg(0, std::ios::beg); + uint64_t ndims = (uint64_t)ndims_u32; + uint64_t npts = fsize / ((ndims + 1) * sizeof(float)); + std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims + << std::endl; + + uint64_t blk_size = 131072; + uint64_t nblks = ROUND_UP(npts, blk_size) / blk_size; + std::cout << "# blks: " << nblks << std::endl; + + std::ofstream writr(argv[2], std::ios::binary); + float *read_buf = new float[npts * (ndims + 1)]; + float *write_buf = new float[npts * ndims]; + for (uint64_t i = 0; i < nblks; i++) { + uint64_t cblk_size = std::min(npts - i * blk_size, blk_size); + block_convert(writr, readr, read_buf, write_buf, cblk_size, ndims); + std::cout << "Block #" 
<< i << " written" << std::endl; + } + + delete[] read_buf; + delete[] write_buf; + + writr.close(); + readr.close(); } diff --git a/apps/utils/bin_to_tsv.cpp b/apps/utils/bin_to_tsv.cpp index 7851bef6d..62ed77e55 100644 --- a/apps/utils/bin_to_tsv.cpp +++ b/apps/utils/bin_to_tsv.cpp @@ -1,69 +1,68 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "utils.h" +#include template -void block_convert(std::ofstream &writer, std::ifstream &reader, T *read_buf, size_t npts, size_t ndims) -{ - reader.read((char *)read_buf, npts * ndims * sizeof(float)); +void block_convert(std::ofstream &writer, std::ifstream &reader, T *read_buf, + size_t npts, size_t ndims) { + reader.read((char *)read_buf, npts * ndims * sizeof(float)); - for (size_t i = 0; i < npts; i++) - { - for (size_t d = 0; d < ndims; d++) - { - writer << read_buf[d + i * ndims]; - if (d < ndims - 1) - writer << "\t"; - else - writer << "\n"; - } + for (size_t i = 0; i < npts; i++) { + for (size_t d = 0; d < ndims; d++) { + writer << read_buf[d + i * ndims]; + if (d < ndims - 1) + writer << "\t"; + else + writer << "\n"; } + } } -int main(int argc, char **argv) -{ - if (argc != 4) - { - std::cout << argv[0] << " input_bin output_tsv" << std::endl; - exit(-1); - } - std::string type_string(argv[1]); - if ((type_string != std::string("float")) && (type_string != std::string("int8")) && - (type_string != std::string("uin8"))) - { - std::cerr << "Error: type not supported. Use float/int8/uint8" << std::endl; - } +int main(int argc, char **argv) { + if (argc != 4) { + std::cout << argv[0] << " input_bin output_tsv" + << std::endl; + exit(-1); + } + std::string type_string(argv[1]); + if ((type_string != std::string("float")) && + (type_string != std::string("int8")) && + (type_string != std::string("uin8"))) { + std::cerr << "Error: type not supported. 
Use float/int8/uint8" << std::endl; + } - std::ifstream reader(argv[2], std::ios::binary); - uint32_t npts_u32; - uint32_t ndims_u32; - reader.read((char *)&npts_u32, sizeof(uint32_t)); - reader.read((char *)&ndims_u32, sizeof(uint32_t)); - size_t npts = npts_u32; - size_t ndims = ndims_u32; - std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl; + std::ifstream reader(argv[2], std::ios::binary); + uint32_t npts_u32; + uint32_t ndims_u32; + reader.read((char *)&npts_u32, sizeof(uint32_t)); + reader.read((char *)&ndims_u32, sizeof(uint32_t)); + size_t npts = npts_u32; + size_t ndims = ndims_u32; + std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims + << std::endl; - size_t blk_size = 131072; - size_t nblks = ROUND_UP(npts, blk_size) / blk_size; + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; - std::ofstream writer(argv[3]); - char *read_buf = new char[blk_size * ndims * 4]; - for (size_t i = 0; i < nblks; i++) - { - size_t cblk_size = std::min(npts - i * blk_size, blk_size); - if (type_string == std::string("float")) - block_convert(writer, reader, (float *)read_buf, cblk_size, ndims); - else if (type_string == std::string("int8")) - block_convert(writer, reader, (int8_t *)read_buf, cblk_size, ndims); - else if (type_string == std::string("uint8")) - block_convert(writer, reader, (uint8_t *)read_buf, cblk_size, ndims); - std::cout << "Block #" << i << " written" << std::endl; - } + std::ofstream writer(argv[3]); + char *read_buf = new char[blk_size * ndims * 4]; + for (size_t i = 0; i < nblks; i++) { + size_t cblk_size = std::min(npts - i * blk_size, blk_size); + if (type_string == std::string("float")) + block_convert(writer, reader, (float *)read_buf, cblk_size, ndims); + else if (type_string == std::string("int8")) + block_convert(writer, reader, (int8_t *)read_buf, cblk_size, + ndims); + else if (type_string == std::string("uint8")) + block_convert(writer, reader, (uint8_t *)read_buf, cblk_size, + ndims); + std::cout << "Block #" << i << " written" << std::endl; + } - delete[] read_buf; + delete[] read_buf; - writer.close(); - reader.close(); + writer.close(); + reader.close(); } diff --git a/apps/utils/calculate_recall.cpp b/apps/utils/calculate_recall.cpp index dc76252cc..d329ebba4 100644 --- a/apps/utils/calculate_recall.cpp +++ b/apps/utils/calculate_recall.cpp @@ -9,47 +9,46 @@ #include #include -#include "utils.h" #include "disk_utils.h" +#include "utils.h" -int main(int argc, char **argv) -{ - if (argc != 4) - { - std::cout << argv[0] << " " << std::endl; - return -1; - } - uint32_t *gold_std = NULL; - float *gs_dist = nullptr; - uint32_t *our_results = NULL; - float *or_dist = nullptr; - size_t points_num, points_num_gs, points_num_or; - size_t dim_gs; - size_t dim_or; - diskann::load_truthset(argv[1], gold_std, gs_dist, points_num_gs, dim_gs); - diskann::load_truthset(argv[2], our_results, or_dist, points_num_or, dim_or); +int main(int argc, char **argv) { + if (argc != 4) { + std::cout << argv[0] << " " + << std::endl; + return -1; + } + uint32_t *gold_std = NULL; + float *gs_dist = nullptr; + uint32_t *our_results = NULL; + float *or_dist = nullptr; + size_t points_num, points_num_gs, points_num_or; + size_t dim_gs; + size_t dim_or; + diskann::load_truthset(argv[1], gold_std, gs_dist, points_num_gs, dim_gs); + diskann::load_truthset(argv[2], our_results, or_dist, points_num_or, dim_or); - if (points_num_gs != points_num_or) - { - std::cout << "Error. 
Number of queries mismatch in ground truth and " - "our results" - << std::endl; - return -1; - } - points_num = points_num_gs; + if (points_num_gs != points_num_or) { + std::cout << "Error. Number of queries mismatch in ground truth and " + "our results" + << std::endl; + return -1; + } + points_num = points_num_gs; - uint32_t recall_at = std::atoi(argv[3]); + uint32_t recall_at = std::atoi(argv[3]); - if ((dim_or < recall_at) || (recall_at > dim_gs)) - { - std::cout << "ground truth has size " << dim_gs << "; our set has " << dim_or << " points. Asking for recall " - << recall_at << std::endl; - return -1; - } - std::cout << "Calculating recall@" << recall_at << std::endl; - double recall_val = diskann::calculate_recall((uint32_t)points_num, gold_std, gs_dist, (uint32_t)dim_gs, - our_results, (uint32_t)dim_or, (uint32_t)recall_at); + if ((dim_or < recall_at) || (recall_at > dim_gs)) { + std::cout << "ground truth has size " << dim_gs << "; our set has " + << dim_or << " points. Asking for recall " << recall_at + << std::endl; + return -1; + } + std::cout << "Calculating recall@" << recall_at << std::endl; + double recall_val = diskann::calculate_recall( + (uint32_t)points_num, gold_std, gs_dist, (uint32_t)dim_gs, our_results, + (uint32_t)dim_or, (uint32_t)recall_at); - // double avg_recall = (recall*1.0)/(points_num*1.0); - std::cout << "Avg. recall@" << recall_at << " is " << recall_val << "\n"; + // double avg_recall = (recall*1.0)/(points_num*1.0); + std::cout << "Avg. recall@" << recall_at << " is " << recall_val << "\n"; } diff --git a/apps/utils/compute_groundtruth.cpp b/apps/utils/compute_groundtruth.cpp index da32fd7c6..2b644a8ec 100644 --- a/apps/utils/compute_groundtruth.cpp +++ b/apps/utils/compute_groundtruth.cpp @@ -1,25 +1,25 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include -#include -#include #include +#include +#include +#include -#include #include +#include #include #include -#include -#include #include -#include -#include +#include #include -#include -#include +#include +#include +#include #include #include +#include +#include #ifdef _WINDOWS #include @@ -40,535 +40,539 @@ typedef std::string path; namespace po = boost::program_options; -template T div_round_up(const T numerator, const T denominator) -{ - return (numerator % denominator == 0) ? (numerator / denominator) : 1 + (numerator / denominator); +template T div_round_up(const T numerator, const T denominator) { + return (numerator % denominator == 0) ? 
(numerator / denominator) + : 1 + (numerator / denominator); } using pairIF = std::pair; -struct cmpmaxstruct -{ - bool operator()(const pairIF &l, const pairIF &r) - { - return l.second < r.second; - }; +struct cmpmaxstruct { + bool operator()(const pairIF &l, const pairIF &r) { + return l.second < r.second; + }; }; -using maxPQIFCS = std::priority_queue, cmpmaxstruct>; +using maxPQIFCS = + std::priority_queue, cmpmaxstruct>; -template T *aligned_malloc(const size_t n, const size_t alignment) -{ +template T *aligned_malloc(const size_t n, const size_t alignment) { #ifdef _WINDOWS - return (T *)_aligned_malloc(sizeof(T) * n, alignment); + return (T *)_aligned_malloc(sizeof(T) * n, alignment); #else - return static_cast(aligned_alloc(alignment, sizeof(T) * n)); + return static_cast(aligned_alloc(alignment, sizeof(T) * n)); #endif } -inline bool custom_dist(const std::pair &a, const std::pair &b) -{ - return a.second < b.second; +inline bool custom_dist(const std::pair &a, + const std::pair &b) { + return a.second < b.second; } -void compute_l2sq(float *const points_l2sq, const float *const matrix, const int64_t num_points, const uint64_t dim) -{ - assert(points_l2sq != NULL); +void compute_l2sq(float *const points_l2sq, const float *const matrix, + const int64_t num_points, const uint64_t dim) { + assert(points_l2sq != NULL); #pragma omp parallel for schedule(static, 65536) - for (int64_t d = 0; d < num_points; ++d) - points_l2sq[d] = cblas_sdot((int64_t)dim, matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1, - matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1); + for (int64_t d = 0; d < num_points; ++d) + points_l2sq[d] = + cblas_sdot((int64_t)dim, matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1, + matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1); } -void distsq_to_points(const size_t dim, - float *dist_matrix, // Col Major, cols are queries, rows are points - size_t npoints, const float *const points, - const float *const points_l2sq, // points in Col major - size_t nqueries, const float *const queries, - const float *const queries_l2sq, // queries in Col major - float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0 +void distsq_to_points( + const size_t dim, + float *dist_matrix, // Col Major, cols are queries, rows are points + size_t npoints, const float *const points, + const float *const points_l2sq, // points in Col major + size_t nqueries, const float *const queries, + const float *const queries_l2sq, // queries in Col major + float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0 { - bool ones_vec_alloc = false; - if (ones_vec == NULL) - { - ones_vec = new float[nqueries > npoints ? nqueries : npoints]; - std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0); - ones_vec_alloc = true; - } - cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-2.0, points, dim, queries, dim, - (float)0.0, dist_matrix, npoints); - cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, points_l2sq, npoints, - ones_vec, nqueries, (float)1.0, dist_matrix, npoints); - cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, ones_vec, npoints, - queries_l2sq, nqueries, (float)1.0, dist_matrix, npoints); - if (ones_vec_alloc) - delete[] ones_vec; + bool ones_vec_alloc = false; + if (ones_vec == NULL) { + ones_vec = new float[nqueries > npoints ? nqueries : npoints]; + std::fill_n(ones_vec, nqueries > npoints ? 
nqueries : npoints, (float)1.0); + ones_vec_alloc = true; + } + cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, + (float)-2.0, points, dim, queries, dim, (float)0.0, dist_matrix, + npoints); + cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, + (float)1.0, points_l2sq, npoints, ones_vec, nqueries, (float)1.0, + dist_matrix, npoints); + cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, + (float)1.0, ones_vec, npoints, queries_l2sq, nqueries, (float)1.0, + dist_matrix, npoints); + if (ones_vec_alloc) + delete[] ones_vec; } -void inner_prod_to_points(const size_t dim, - float *dist_matrix, // Col Major, cols are queries, rows are points - size_t npoints, const float *const points, size_t nqueries, const float *const queries, - float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0 +void inner_prod_to_points( + const size_t dim, + float *dist_matrix, // Col Major, cols are queries, rows are points + size_t npoints, const float *const points, size_t nqueries, + const float *const queries, + float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0 { - bool ones_vec_alloc = false; - if (ones_vec == NULL) - { - ones_vec = new float[nqueries > npoints ? nqueries : npoints]; - std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0); - ones_vec_alloc = true; - } - cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-1.0, points, dim, queries, dim, - (float)0.0, dist_matrix, npoints); - - if (ones_vec_alloc) - delete[] ones_vec; + bool ones_vec_alloc = false; + if (ones_vec == NULL) { + ones_vec = new float[nqueries > npoints ? nqueries : npoints]; + std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0); + ones_vec_alloc = true; + } + cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, + (float)-1.0, points, dim, queries, dim, (float)0.0, dist_matrix, + npoints); + + if (ones_vec_alloc) + delete[] ones_vec; } -void exact_knn(const size_t dim, const size_t k, - size_t *const closest_points, // k * num_queries preallocated, col - // major, queries columns - float *const dist_closest_points, // k * num_queries - // preallocated, Dist to - // corresponding closes_points - size_t npoints, - float *points_in, // points in Col major - size_t nqueries, float *queries_in, - diskann::Metric metric = diskann::Metric::L2) // queries in Col major +void exact_knn( + const size_t dim, const size_t k, + size_t *const closest_points, // k * num_queries preallocated, col + // major, queries columns + float *const dist_closest_points, // k * num_queries + // preallocated, Dist to + // corresponding closes_points + size_t npoints, + float *points_in, // points in Col major + size_t nqueries, float *queries_in, + diskann::Metric metric = diskann::Metric::L2) // queries in Col major { - float *points_l2sq = new float[npoints]; - float *queries_l2sq = new float[nqueries]; - compute_l2sq(points_l2sq, points_in, npoints, dim); - compute_l2sq(queries_l2sq, queries_in, nqueries, dim); - - float *points = points_in; - float *queries = queries_in; - - if (metric == diskann::Metric::COSINE) - { // we convert cosine distance as - // normalized L2 distnace - points = new float[npoints * dim]; - queries = new float[nqueries * dim]; + float *points_l2sq = new float[npoints]; + float *queries_l2sq = new float[nqueries]; + compute_l2sq(points_l2sq, points_in, npoints, dim); + compute_l2sq(queries_l2sq, queries_in, nqueries, dim); 
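// For reference: the per-vector squared norms computed above feed
// distsq_to_points(), which assembles the full column-major distance matrix
// from the expansion
//   ||p - q||^2 = ||p||^2 + ||q||^2 - 2 * <p, q>
// using three cblas_sgemm calls: one GEMM for the -2 * P^T * Q term and two
// rank-1 GEMM updates that broadcast the point and query squared norms across
// rows and columns. An equivalent, unbatched expression for a single entry
// dist_matrix[p + q * npoints], with p a point index and q a query index
// (illustrative indices only), would be:
//   float d2 = points_l2sq[p] + queries_l2sq[q] -
//              2.0f * cblas_sdot((int)dim, points_in + p * dim, 1,
//                                queries_in + q * dim, 1);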
+ + float *points = points_in; + float *queries = queries_in; + + if (metric == diskann::Metric::COSINE) { // we convert cosine distance as + // normalized L2 distnace + points = new float[npoints * dim]; + queries = new float[nqueries * dim]; #pragma omp parallel for schedule(static, 4096) - for (int64_t i = 0; i < (int64_t)npoints; i++) - { - float norm = std::sqrt(points_l2sq[i]); - if (norm == 0) - { - norm = std::numeric_limits::epsilon(); - } - for (uint32_t j = 0; j < dim; j++) - { - points[i * dim + j] = points_in[i * dim + j] / norm; - } - } + for (int64_t i = 0; i < (int64_t)npoints; i++) { + float norm = std::sqrt(points_l2sq[i]); + if (norm == 0) { + norm = std::numeric_limits::epsilon(); + } + for (uint32_t j = 0; j < dim; j++) { + points[i * dim + j] = points_in[i * dim + j] / norm; + } + } #pragma omp parallel for schedule(static, 4096) - for (int64_t i = 0; i < (int64_t)nqueries; i++) - { - float norm = std::sqrt(queries_l2sq[i]); - if (norm == 0) - { - norm = std::numeric_limits::epsilon(); - } - for (uint32_t j = 0; j < dim; j++) - { - queries[i * dim + j] = queries_in[i * dim + j] / norm; - } - } - // recalculate norms after normalizing, they should all be one. - compute_l2sq(points_l2sq, points, npoints, dim); - compute_l2sq(queries_l2sq, queries, nqueries, dim); + for (int64_t i = 0; i < (int64_t)nqueries; i++) { + float norm = std::sqrt(queries_l2sq[i]); + if (norm == 0) { + norm = std::numeric_limits::epsilon(); + } + for (uint32_t j = 0; j < dim; j++) { + queries[i * dim + j] = queries_in[i * dim + j] / norm; + } } - - std::cout << "Going to compute " << k << " NNs for " << nqueries << " queries over " << npoints << " points in " - << dim << " dimensions using"; - if (metric == diskann::Metric::INNER_PRODUCT) - std::cout << " MIPS "; - else if (metric == diskann::Metric::COSINE) - std::cout << " Cosine "; - else - std::cout << " L2 "; - std::cout << "distance fn. " << std::endl; - - size_t q_batch_size = (1 << 9); - float *dist_matrix = new float[(size_t)q_batch_size * (size_t)npoints]; - - for (size_t b = 0; b < div_round_up(nqueries, q_batch_size); ++b) - { - int64_t q_b = b * q_batch_size; - int64_t q_e = ((b + 1) * q_batch_size > nqueries) ? nqueries : (b + 1) * q_batch_size; - - if (metric == diskann::Metric::L2 || metric == diskann::Metric::COSINE) - { - distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, q_e - q_b, - queries + (ptrdiff_t)q_b * (ptrdiff_t)dim, queries_l2sq + q_b); - } - else - { - inner_prod_to_points(dim, dist_matrix, npoints, points, q_e - q_b, - queries + (ptrdiff_t)q_b * (ptrdiff_t)dim); - } - std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")" << std::endl; + // recalculate norms after normalizing, they should all be one. + compute_l2sq(points_l2sq, points, npoints, dim); + compute_l2sq(queries_l2sq, queries, nqueries, dim); + } + + std::cout << "Going to compute " << k << " NNs for " << nqueries + << " queries over " << npoints << " points in " << dim + << " dimensions using"; + if (metric == diskann::Metric::INNER_PRODUCT) + std::cout << " MIPS "; + else if (metric == diskann::Metric::COSINE) + std::cout << " Cosine "; + else + std::cout << " L2 "; + std::cout << "distance fn. " << std::endl; + + size_t q_batch_size = (1 << 9); + float *dist_matrix = new float[(size_t)q_batch_size * (size_t)npoints]; + + for (size_t b = 0; b < div_round_up(nqueries, q_batch_size); ++b) { + int64_t q_b = b * q_batch_size; + int64_t q_e = + ((b + 1) * q_batch_size > nqueries) ? 
nqueries : (b + 1) * q_batch_size; + + if (metric == diskann::Metric::L2 || metric == diskann::Metric::COSINE) { + distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, + q_e - q_b, queries + (ptrdiff_t)q_b * (ptrdiff_t)dim, + queries_l2sq + q_b); + } else { + inner_prod_to_points(dim, dist_matrix, npoints, points, q_e - q_b, + queries + (ptrdiff_t)q_b * (ptrdiff_t)dim); + } + std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")" + << std::endl; #pragma omp parallel for schedule(dynamic, 16) - for (long long q = q_b; q < q_e; q++) - { - maxPQIFCS point_dist; - for (size_t p = 0; p < k; p++) - point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); - for (size_t p = k; p < npoints; p++) - { - if (point_dist.top().second > dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]) - point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); - if (point_dist.size() > k) - point_dist.pop(); - } - for (ptrdiff_t l = 0; l < (ptrdiff_t)k; ++l) - { - closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().first; - dist_closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().second; - point_dist.pop(); - } - assert(std::is_sorted(dist_closest_points + (ptrdiff_t)q * (ptrdiff_t)k, - dist_closest_points + (ptrdiff_t)(q + 1) * (ptrdiff_t)k)); - } - std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e << ")" << std::endl; + for (long long q = q_b; q < q_e; q++) { + maxPQIFCS point_dist; + for (size_t p = 0; p < k; p++) + point_dist.emplace(p, + dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * + (ptrdiff_t)npoints]); + for (size_t p = k; p < npoints; p++) { + if (point_dist.top().second > + dist_matrix[(ptrdiff_t)p + + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]) + point_dist.emplace( + p, dist_matrix[(ptrdiff_t)p + + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); + if (point_dist.size() > k) + point_dist.pop(); + } + for (ptrdiff_t l = 0; l < (ptrdiff_t)k; ++l) { + closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = + point_dist.top().first; + dist_closest_points[(ptrdiff_t)(k - 1 - l) + + (ptrdiff_t)q * (ptrdiff_t)k] = + point_dist.top().second; + point_dist.pop(); + } + assert(std::is_sorted(dist_closest_points + (ptrdiff_t)q * (ptrdiff_t)k, + dist_closest_points + + (ptrdiff_t)(q + 1) * (ptrdiff_t)k)); } + std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e + << ")" << std::endl; + } - delete[] dist_matrix; + delete[] dist_matrix; - delete[] points_l2sq; - delete[] queries_l2sq; + delete[] points_l2sq; + delete[] queries_l2sq; - if (metric == diskann::Metric::COSINE) - { - delete[] points; - delete[] queries; - } + if (metric == diskann::Metric::COSINE) { + delete[] points; + delete[] queries; + } } -template inline int get_num_parts(const char *filename) -{ - std::ifstream reader; - reader.exceptions(std::ios::failbit | std::ios::badbit); - reader.open(filename, std::ios::binary); - std::cout << "Reading bin file " << filename << " ...\n"; - int npts_i32, ndims_i32; - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&ndims_i32, sizeof(int)); - std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl; - reader.close(); - uint32_t num_parts = - (npts_i32 % PARTSIZE) == 0 ? 
npts_i32 / PARTSIZE : (uint32_t)std::floor(npts_i32 / PARTSIZE) + 1; - std::cout << "Number of parts: " << num_parts << std::endl; - return num_parts; +template inline int get_num_parts(const char *filename) { + std::ifstream reader; + reader.exceptions(std::ios::failbit | std::ios::badbit); + reader.open(filename, std::ios::binary); + std::cout << "Reading bin file " << filename << " ...\n"; + int npts_i32, ndims_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&ndims_i32, sizeof(int)); + std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl; + reader.close(); + uint32_t num_parts = (npts_i32 % PARTSIZE) == 0 + ? npts_i32 / PARTSIZE + : (uint32_t)std::floor(npts_i32 / PARTSIZE) + 1; + std::cout << "Number of parts: " << num_parts << std::endl; + return num_parts; } template -inline void load_bin_as_float(const char *filename, float *&data, size_t &npts, size_t &ndims, int part_num) -{ - std::ifstream reader; - reader.exceptions(std::ios::failbit | std::ios::badbit); - reader.open(filename, std::ios::binary); - std::cout << "Reading bin file " << filename << " ...\n"; - int npts_i32, ndims_i32; - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&ndims_i32, sizeof(int)); - uint64_t start_id = part_num * PARTSIZE; - uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32); - npts = end_id - start_id; - ndims = (uint64_t)ndims_i32; - std::cout << "#pts in part = " << npts << ", #dims = " << ndims << ", size = " << npts * ndims * sizeof(T) << "B" - << std::endl; - - reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg); - T *data_T = new T[npts * ndims]; - reader.read((char *)data_T, sizeof(T) * npts * ndims); - std::cout << "Finished reading part of the bin file." << std::endl; - reader.close(); - data = aligned_malloc(npts * ndims, ALIGNMENT); +inline void load_bin_as_float(const char *filename, float *&data, size_t &npts, + size_t &ndims, int part_num) { + std::ifstream reader; + reader.exceptions(std::ios::failbit | std::ios::badbit); + reader.open(filename, std::ios::binary); + std::cout << "Reading bin file " << filename << " ...\n"; + int npts_i32, ndims_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&ndims_i32, sizeof(int)); + uint64_t start_id = part_num * PARTSIZE; + uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32); + npts = end_id - start_id; + ndims = (uint64_t)ndims_i32; + std::cout << "#pts in part = " << npts << ", #dims = " << ndims + << ", size = " << npts * ndims * sizeof(T) << "B" << std::endl; + + reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), + std::ios::beg); + T *data_T = new T[npts * ndims]; + reader.read((char *)data_T, sizeof(T) * npts * ndims); + std::cout << "Finished reading part of the bin file." << std::endl; + reader.close(); + data = aligned_malloc(npts * ndims, ALIGNMENT); #pragma omp parallel for schedule(dynamic, 32768) - for (int64_t i = 0; i < (int64_t)npts; i++) - { - for (int64_t j = 0; j < (int64_t)ndims; j++) - { - float cur_val_float = (float)data_T[i * ndims + j]; - std::memcpy((char *)(data + i * ndims + j), (char *)&cur_val_float, sizeof(float)); - } + for (int64_t i = 0; i < (int64_t)npts; i++) { + for (int64_t j = 0; j < (int64_t)ndims; j++) { + float cur_val_float = (float)data_T[i * ndims + j]; + std::memcpy((char *)(data + i * ndims + j), (char *)&cur_val_float, + sizeof(float)); } - delete[] data_T; - std::cout << "Finished converting part data to float." 
<< std::endl; + } + delete[] data_T; + std::cout << "Finished converting part data to float." << std::endl; } -template inline void save_bin(const std::string filename, T *data, size_t npts, size_t ndims) -{ - std::ofstream writer; - writer.exceptions(std::ios::failbit | std::ios::badbit); - writer.open(filename, std::ios::binary | std::ios::out); - std::cout << "Writing bin: " << filename << "\n"; - int npts_i32 = (int)npts, ndims_i32 = (int)ndims; - writer.write((char *)&npts_i32, sizeof(int)); - writer.write((char *)&ndims_i32, sizeof(int)); - std::cout << "bin: #pts = " << npts << ", #dims = " << ndims - << ", size = " << npts * ndims * sizeof(T) + 2 * sizeof(int) << "B" << std::endl; - - writer.write((char *)data, npts * ndims * sizeof(T)); - writer.close(); - std::cout << "Finished writing bin" << std::endl; +template +inline void save_bin(const std::string filename, T *data, size_t npts, + size_t ndims) { + std::ofstream writer; + writer.exceptions(std::ios::failbit | std::ios::badbit); + writer.open(filename, std::ios::binary | std::ios::out); + std::cout << "Writing bin: " << filename << "\n"; + int npts_i32 = (int)npts, ndims_i32 = (int)ndims; + writer.write((char *)&npts_i32, sizeof(int)); + writer.write((char *)&ndims_i32, sizeof(int)); + std::cout << "bin: #pts = " << npts << ", #dims = " << ndims + << ", size = " << npts * ndims * sizeof(T) + 2 * sizeof(int) << "B" + << std::endl; + + writer.write((char *)data, npts * ndims * sizeof(T)); + writer.close(); + std::cout << "Finished writing bin" << std::endl; } -inline void save_groundtruth_as_one_file(const std::string filename, int32_t *data, float *distances, size_t npts, - size_t ndims) -{ - std::ofstream writer(filename, std::ios::binary | std::ios::out); - int npts_i32 = (int)npts, ndims_i32 = (int)ndims; - writer.write((char *)&npts_i32, sizeof(int)); - writer.write((char *)&ndims_i32, sizeof(int)); - std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, " - "npts*dim dist-matrix) with npts = " - << npts << ", dim = " << ndims << ", size = " << 2 * npts * ndims * sizeof(uint32_t) + 2 * sizeof(int) - << "B" << std::endl; - - writer.write((char *)data, npts * ndims * sizeof(uint32_t)); - writer.write((char *)distances, npts * ndims * sizeof(float)); - writer.close(); - std::cout << "Finished writing truthset" << std::endl; +inline void save_groundtruth_as_one_file(const std::string filename, + int32_t *data, float *distances, + size_t npts, size_t ndims) { + std::ofstream writer(filename, std::ios::binary | std::ios::out); + int npts_i32 = (int)npts, ndims_i32 = (int)ndims; + writer.write((char *)&npts_i32, sizeof(int)); + writer.write((char *)&ndims_i32, sizeof(int)); + std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, " + "npts*dim dist-matrix) with npts = " + << npts << ", dim = " << ndims << ", size = " + << 2 * npts * ndims * sizeof(uint32_t) + 2 * sizeof(int) << "B" + << std::endl; + + writer.write((char *)data, npts * ndims * sizeof(uint32_t)); + writer.write((char *)distances, npts * ndims * sizeof(float)); + writer.close(); + std::cout << "Finished writing truthset" << std::endl; } template -std::vector>> processUnfilteredParts(const std::string &base_file, - size_t &nqueries, size_t &npoints, - size_t &dim, size_t &k, float *query_data, - const diskann::Metric &metric, - std::vector &location_to_tag) -{ - float *base_data = nullptr; - int num_parts = get_num_parts(base_file.c_str()); - std::vector>> res(nqueries); - for (int p = 0; p < num_parts; p++) - { - size_t 
start_id = p * PARTSIZE; - load_bin_as_float(base_file.c_str(), base_data, npoints, dim, p); - - size_t *closest_points_part = new size_t[nqueries * k]; - float *dist_closest_points_part = new float[nqueries * k]; - - auto part_k = k < npoints ? k : npoints; - exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, base_data, nqueries, query_data, - metric); - - for (size_t i = 0; i < nqueries; i++) - { - for (size_t j = 0; j < part_k; j++) - { - if (!location_to_tag.empty()) - if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0) - continue; - - res[i].push_back(std::make_pair((uint32_t)(closest_points_part[i * part_k + j] + start_id), - dist_closest_points_part[i * part_k + j])); - } - } - - delete[] closest_points_part; - delete[] dist_closest_points_part; - - diskann::aligned_free(base_data); +std::vector>> +processUnfilteredParts(const std::string &base_file, size_t &nqueries, + size_t &npoints, size_t &dim, size_t &k, + float *query_data, const diskann::Metric &metric, + std::vector &location_to_tag) { + float *base_data = nullptr; + int num_parts = get_num_parts(base_file.c_str()); + std::vector>> res(nqueries); + for (int p = 0; p < num_parts; p++) { + size_t start_id = p * PARTSIZE; + load_bin_as_float(base_file.c_str(), base_data, npoints, dim, p); + + size_t *closest_points_part = new size_t[nqueries * k]; + float *dist_closest_points_part = new float[nqueries * k]; + + auto part_k = k < npoints ? k : npoints; + exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, + npoints, base_data, nqueries, query_data, metric); + + for (size_t i = 0; i < nqueries; i++) { + for (size_t j = 0; j < part_k; j++) { + if (!location_to_tag.empty()) + if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0) + continue; + + res[i].push_back(std::make_pair( + (uint32_t)(closest_points_part[i * part_k + j] + start_id), + dist_closest_points_part[i * part_k + j])); + } } - return res; -}; -template -int aux_main(const std::string &base_file, const std::string &query_file, const std::string >_file, size_t k, - const diskann::Metric &metric, const std::string &tags_file = std::string("")) -{ - size_t npoints, nqueries, dim; - - float *query_data; - - load_bin_as_float(query_file.c_str(), query_data, nqueries, dim, 0); - if (nqueries > PARTSIZE) - std::cerr << "WARNING: #Queries provided (" << nqueries << ") is greater than " << PARTSIZE - << ". Computing GT only for the first " << PARTSIZE << " queries." << std::endl; - - // load tags - const bool tags_enabled = tags_file.empty() ? 
false : true; - std::vector location_to_tag = diskann::loadTags(tags_file, base_file); - - int *closest_points = new int[nqueries * k]; - float *dist_closest_points = new float[nqueries * k]; - - std::vector>> results = - processUnfilteredParts(base_file, nqueries, npoints, dim, k, query_data, metric, location_to_tag); - - for (size_t i = 0; i < nqueries; i++) - { - std::vector> &cur_res = results[i]; - std::sort(cur_res.begin(), cur_res.end(), custom_dist); - size_t j = 0; - for (auto iter : cur_res) - { - if (j == k) - break; - if (tags_enabled) - { - std::uint32_t index_with_tag = location_to_tag[iter.first]; - closest_points[i * k + j] = (int32_t)index_with_tag; - } - else - { - closest_points[i * k + j] = (int32_t)iter.first; - } - - if (metric == diskann::Metric::INNER_PRODUCT) - dist_closest_points[i * k + j] = -iter.second; - else - dist_closest_points[i * k + j] = iter.second; - - ++j; - } - if (j < k) - std::cout << "WARNING: found less than k GT entries for query " << i << std::endl; - } - - save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, nqueries, k); - delete[] closest_points; - delete[] dist_closest_points; - diskann::aligned_free(query_data); - - return 0; -} + delete[] closest_points_part; + delete[] dist_closest_points_part; -void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, size_t &npts, size_t &dim) -{ - size_t read_blk_size = 64 * 1024 * 1024; - cached_ifstream reader(bin_file, read_blk_size); - diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." << std::endl; - size_t actual_file_size = reader.get_file_size(); - - int npts_i32, dim_i32; - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&dim_i32, sizeof(int)); - npts = (uint32_t)npts_i32; - dim = (uint32_t)dim_i32; - - diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " << std::endl; - - int truthset_type = -1; // 1 means truthset has ids and distances, 2 means - // only ids, -1 is error - size_t expected_file_size_with_dists = 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); - - if (actual_file_size == expected_file_size_with_dists) - truthset_type = 1; - - size_t expected_file_size_just_ids = npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); - - if (actual_file_size == expected_file_size_just_ids) - truthset_type = 2; - - if (truthset_type == -1) - { - std::stringstream stream; - stream << "Error. File size mismatch. File should have bin format, with " - "npts followed by ngt followed by npts*ngt ids and optionally " - "followed by npts*ngt distance values; actual size: " - << actual_file_size << ", expected: " << expected_file_size_with_dists << " or " - << expected_file_size_just_ids; - diskann::cout << stream.str(); - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); - } + diskann::aligned_free(base_data); + } + return res; +}; - ids = new uint32_t[npts * dim]; - reader.read((char *)ids, npts * dim * sizeof(uint32_t)); +template +int aux_main(const std::string &base_file, const std::string &query_file, + const std::string >_file, size_t k, + const diskann::Metric &metric, + const std::string &tags_file = std::string("")) { + size_t npoints, nqueries, dim; + + float *query_data; + + load_bin_as_float(query_file.c_str(), query_data, nqueries, dim, 0); + if (nqueries > PARTSIZE) + std::cerr << "WARNING: #Queries provided (" << nqueries + << ") is greater than " << PARTSIZE + << ". Computing GT only for the first " << PARTSIZE << " queries." 
+ << std::endl; - if (truthset_type == 1) - { - dists = new float[npts * dim]; - reader.read((char *)dists, npts * dim * sizeof(float)); + // load tags + const bool tags_enabled = tags_file.empty() ? false : true; + std::vector location_to_tag = + diskann::loadTags(tags_file, base_file); + + int *closest_points = new int[nqueries * k]; + float *dist_closest_points = new float[nqueries * k]; + + std::vector>> results = + processUnfilteredParts(base_file, nqueries, npoints, dim, k, + query_data, metric, location_to_tag); + + for (size_t i = 0; i < nqueries; i++) { + std::vector> &cur_res = results[i]; + std::sort(cur_res.begin(), cur_res.end(), custom_dist); + size_t j = 0; + for (auto iter : cur_res) { + if (j == k) + break; + if (tags_enabled) { + std::uint32_t index_with_tag = location_to_tag[iter.first]; + closest_points[i * k + j] = (int32_t)index_with_tag; + } else { + closest_points[i * k + j] = (int32_t)iter.first; + } + + if (metric == diskann::Metric::INNER_PRODUCT) + dist_closest_points[i * k + j] = -iter.second; + else + dist_closest_points[i * k + j] = iter.second; + + ++j; } + if (j < k) + std::cout << "WARNING: found less than k GT entries for query " << i + << std::endl; + } + + save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, + nqueries, k); + delete[] closest_points; + delete[] dist_closest_points; + diskann::aligned_free(query_data); + + return 0; } -int main(int argc, char **argv) -{ - std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file; - uint64_t K; - - try - { - po::options_description desc{"Arguments"}; - - desc.add_options()("help,h", "Print information on arguments"); - - desc.add_options()("data_type", po::value(&data_type)->required(), "data type "); - desc.add_options()("dist_fn", po::value(&dist_fn)->required(), - "distance function "); - desc.add_options()("base_file", po::value(&base_file)->required(), - "File containing the base vectors in binary format"); - desc.add_options()("query_file", po::value(&query_file)->required(), - "File containing the query vectors in binary format"); - desc.add_options()("gt_file", po::value(>_file)->required(), - "File name for the writing ground truth in binary " - "format, please don' append .bin at end if " - "no filter_label or filter_label_file is provided it " - "will save the file with '.bin' at end." - "else it will save the file as filename_label.bin"); - desc.add_options()("K", po::value(&K)->required(), - "Number of ground truth nearest neighbors to compute"); - desc.add_options()("tags_file", po::value(&tags_file)->default_value(std::string()), - "File containing the tags in binary format"); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) - { - std::cout << desc; - return 0; - } - po::notify(vm); - } - catch (const std::exception &ex) - { - std::cerr << ex.what() << '\n'; - return -1; - } - - if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8")) - { - std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl; - return -1; - } - - diskann::Metric metric; - if (dist_fn == std::string("l2")) - { - metric = diskann::Metric::L2; - } - else if (dist_fn == std::string("mips")) - { - metric = diskann::Metric::INNER_PRODUCT; - } - else if (dist_fn == std::string("cosine")) - { - metric = diskann::Metric::COSINE; - } - else - { - std::cerr << "Unsupported distance function. Use l2/mips/cosine." 
<< std::endl; - return -1; - } +void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, + size_t &npts, size_t &dim) { + size_t read_blk_size = 64 * 1024 * 1024; + cached_ifstream reader(bin_file, read_blk_size); + diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." + << std::endl; + size_t actual_file_size = reader.get_file_size(); + + int npts_i32, dim_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&dim_i32, sizeof(int)); + npts = (uint32_t)npts_i32; + dim = (uint32_t)dim_i32; + + diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " + << std::endl; + + int truthset_type = -1; // 1 means truthset has ids and distances, 2 means + // only ids, -1 is error + size_t expected_file_size_with_dists = + 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); + + if (actual_file_size == expected_file_size_with_dists) + truthset_type = 1; + + size_t expected_file_size_just_ids = + npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); + + if (actual_file_size == expected_file_size_just_ids) + truthset_type = 2; + + if (truthset_type == -1) { + std::stringstream stream; + stream << "Error. File size mismatch. File should have bin format, with " + "npts followed by ngt followed by npts*ngt ids and optionally " + "followed by npts*ngt distance values; actual size: " + << actual_file_size + << ", expected: " << expected_file_size_with_dists << " or " + << expected_file_size_just_ids; + diskann::cout << stream.str(); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, + __LINE__); + } + + ids = new uint32_t[npts * dim]; + reader.read((char *)ids, npts * dim * sizeof(uint32_t)); + + if (truthset_type == 1) { + dists = new float[npts * dim]; + reader.read((char *)dists, npts * dim * sizeof(float)); + } +} - try - { - if (data_type == std::string("float")) - aux_main(base_file, query_file, gt_file, K, metric, tags_file); - if (data_type == std::string("int8")) - aux_main(base_file, query_file, gt_file, K, metric, tags_file); - if (data_type == std::string("uint8")) - aux_main(base_file, query_file, gt_file, K, metric, tags_file); - } - catch (const std::exception &e) - { - std::cout << std::string(e.what()) << std::endl; - diskann::cerr << "Compute GT failed." << std::endl; - return -1; +int main(int argc, char **argv) { + std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file; + uint64_t K; + + try { + po::options_description desc{"Arguments"}; + + desc.add_options()("help,h", "Print information on arguments"); + + desc.add_options()("data_type", + po::value(&data_type)->required(), + "data type "); + desc.add_options()("dist_fn", po::value(&dist_fn)->required(), + "distance function "); + desc.add_options()("base_file", + po::value(&base_file)->required(), + "File containing the base vectors in binary format"); + desc.add_options()("query_file", + po::value(&query_file)->required(), + "File containing the query vectors in binary format"); + desc.add_options()("gt_file", po::value(>_file)->required(), + "File name for the writing ground truth in binary " + "format, please don' append .bin at end if " + "no filter_label or filter_label_file is provided it " + "will save the file with '.bin' at end." 
+ "else it will save the file as filename_label.bin"); + desc.add_options()("K", po::value(&K)->required(), + "Number of ground truth nearest neighbors to compute"); + desc.add_options()( + "tags_file", + po::value(&tags_file)->default_value(std::string()), + "File containing the tags in binary format"); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) { + std::cout << desc; + return 0; } + po::notify(vm); + } catch (const std::exception &ex) { + std::cerr << ex.what() << '\n'; + return -1; + } + + if (data_type != std::string("float") && data_type != std::string("int8") && + data_type != std::string("uint8")) { + std::cout << "Unsupported type. float, int8 and uint8 types are supported." + << std::endl; + return -1; + } + + diskann::Metric metric; + if (dist_fn == std::string("l2")) { + metric = diskann::Metric::L2; + } else if (dist_fn == std::string("mips")) { + metric = diskann::Metric::INNER_PRODUCT; + } else if (dist_fn == std::string("cosine")) { + metric = diskann::Metric::COSINE; + } else { + std::cerr << "Unsupported distance function. Use l2/mips/cosine." + << std::endl; + return -1; + } + + try { + if (data_type == std::string("float")) + aux_main(base_file, query_file, gt_file, K, metric, tags_file); + if (data_type == std::string("int8")) + aux_main(base_file, query_file, gt_file, K, metric, tags_file); + if (data_type == std::string("uint8")) + aux_main(base_file, query_file, gt_file, K, metric, tags_file); + } catch (const std::exception &e) { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Compute GT failed." << std::endl; + return -1; + } } diff --git a/apps/utils/compute_groundtruth_for_filters.cpp b/apps/utils/compute_groundtruth_for_filters.cpp index 52e586475..c6cfe476b 100644 --- a/apps/utils/compute_groundtruth_for_filters.cpp +++ b/apps/utils/compute_groundtruth_for_filters.cpp @@ -1,25 +1,25 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include -#include -#include #include +#include +#include +#include -#include #include +#include #include #include -#include -#include #include -#include -#include +#include #include -#include -#include +#include +#include +#include #include #include +#include +#include #ifdef _WINDOWS #include @@ -41,879 +41,876 @@ typedef std::string path; namespace po = boost::program_options; -template T div_round_up(const T numerator, const T denominator) -{ - return (numerator % denominator == 0) ? (numerator / denominator) : 1 + (numerator / denominator); +template T div_round_up(const T numerator, const T denominator) { + return (numerator % denominator == 0) ? 
(numerator / denominator) + : 1 + (numerator / denominator); } using pairIF = std::pair; -struct cmpmaxstruct -{ - bool operator()(const pairIF &l, const pairIF &r) - { - return l.second < r.second; - }; +struct cmpmaxstruct { + bool operator()(const pairIF &l, const pairIF &r) { + return l.second < r.second; + }; }; -using maxPQIFCS = std::priority_queue, cmpmaxstruct>; +using maxPQIFCS = + std::priority_queue, cmpmaxstruct>; -template T *aligned_malloc(const size_t n, const size_t alignment) -{ +template T *aligned_malloc(const size_t n, const size_t alignment) { #ifdef _WINDOWS - return (T *)_aligned_malloc(sizeof(T) * n, alignment); + return (T *)_aligned_malloc(sizeof(T) * n, alignment); #else - return static_cast(aligned_alloc(alignment, sizeof(T) * n)); + return static_cast(aligned_alloc(alignment, sizeof(T) * n)); #endif } -inline bool custom_dist(const std::pair &a, const std::pair &b) -{ - return a.second < b.second; +inline bool custom_dist(const std::pair &a, + const std::pair &b) { + return a.second < b.second; } -void compute_l2sq(float *const points_l2sq, const float *const matrix, const int64_t num_points, const uint64_t dim) -{ - assert(points_l2sq != NULL); +void compute_l2sq(float *const points_l2sq, const float *const matrix, + const int64_t num_points, const uint64_t dim) { + assert(points_l2sq != NULL); #pragma omp parallel for schedule(static, 65536) - for (int64_t d = 0; d < num_points; ++d) - points_l2sq[d] = cblas_sdot((int64_t)dim, matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1, - matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1); + for (int64_t d = 0; d < num_points; ++d) + points_l2sq[d] = + cblas_sdot((int64_t)dim, matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1, + matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1); } -void distsq_to_points(const size_t dim, - float *dist_matrix, // Col Major, cols are queries, rows are points - size_t npoints, const float *const points, - const float *const points_l2sq, // points in Col major - size_t nqueries, const float *const queries, - const float *const queries_l2sq, // queries in Col major - float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0 +void distsq_to_points( + const size_t dim, + float *dist_matrix, // Col Major, cols are queries, rows are points + size_t npoints, const float *const points, + const float *const points_l2sq, // points in Col major + size_t nqueries, const float *const queries, + const float *const queries_l2sq, // queries in Col major + float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0 { - bool ones_vec_alloc = false; - if (ones_vec == NULL) - { - ones_vec = new float[nqueries > npoints ? nqueries : npoints]; - std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0); - ones_vec_alloc = true; - } - cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-2.0, points, dim, queries, dim, - (float)0.0, dist_matrix, npoints); - cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, points_l2sq, npoints, - ones_vec, nqueries, (float)1.0, dist_matrix, npoints); - cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, ones_vec, npoints, - queries_l2sq, nqueries, (float)1.0, dist_matrix, npoints); - if (ones_vec_alloc) - delete[] ones_vec; + bool ones_vec_alloc = false; + if (ones_vec == NULL) { + ones_vec = new float[nqueries > npoints ? nqueries : npoints]; + std::fill_n(ones_vec, nqueries > npoints ? 
nqueries : npoints, (float)1.0); + ones_vec_alloc = true; + } + cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, + (float)-2.0, points, dim, queries, dim, (float)0.0, dist_matrix, + npoints); + cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, + (float)1.0, points_l2sq, npoints, ones_vec, nqueries, (float)1.0, + dist_matrix, npoints); + cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, + (float)1.0, ones_vec, npoints, queries_l2sq, nqueries, (float)1.0, + dist_matrix, npoints); + if (ones_vec_alloc) + delete[] ones_vec; } -void inner_prod_to_points(const size_t dim, - float *dist_matrix, // Col Major, cols are queries, rows are points - size_t npoints, const float *const points, size_t nqueries, const float *const queries, - float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0 +void inner_prod_to_points( + const size_t dim, + float *dist_matrix, // Col Major, cols are queries, rows are points + size_t npoints, const float *const points, size_t nqueries, + const float *const queries, + float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0 { - bool ones_vec_alloc = false; - if (ones_vec == NULL) - { - ones_vec = new float[nqueries > npoints ? nqueries : npoints]; - std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0); - ones_vec_alloc = true; - } - cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-1.0, points, dim, queries, dim, - (float)0.0, dist_matrix, npoints); - - if (ones_vec_alloc) - delete[] ones_vec; + bool ones_vec_alloc = false; + if (ones_vec == NULL) { + ones_vec = new float[nqueries > npoints ? nqueries : npoints]; + std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0); + ones_vec_alloc = true; + } + cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, + (float)-1.0, points, dim, queries, dim, (float)0.0, dist_matrix, + npoints); + + if (ones_vec_alloc) + delete[] ones_vec; } -void exact_knn(const size_t dim, const size_t k, - size_t *const closest_points, // k * num_queries preallocated, col - // major, queries columns - float *const dist_closest_points, // k * num_queries - // preallocated, Dist to - // corresponding closes_points - size_t npoints, - float *points_in, // points in Col major - size_t nqueries, float *queries_in, - diskann::Metric metric = diskann::Metric::L2) // queries in Col major +void exact_knn( + const size_t dim, const size_t k, + size_t *const closest_points, // k * num_queries preallocated, col + // major, queries columns + float *const dist_closest_points, // k * num_queries + // preallocated, Dist to + // corresponding closes_points + size_t npoints, + float *points_in, // points in Col major + size_t nqueries, float *queries_in, + diskann::Metric metric = diskann::Metric::L2) // queries in Col major { - float *points_l2sq = new float[npoints]; - float *queries_l2sq = new float[nqueries]; - compute_l2sq(points_l2sq, points_in, npoints, dim); - compute_l2sq(queries_l2sq, queries_in, nqueries, dim); - - float *points = points_in; - float *queries = queries_in; - - if (metric == diskann::Metric::COSINE) - { // we convert cosine distance as - // normalized L2 distnace - points = new float[npoints * dim]; - queries = new float[nqueries * dim]; + float *points_l2sq = new float[npoints]; + float *queries_l2sq = new float[nqueries]; + compute_l2sq(points_l2sq, points_in, npoints, dim); + compute_l2sq(queries_l2sq, queries_in, nqueries, dim); 
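// For reference: the COSINE branch below reduces cosine distance to L2 on
// unit-normalized copies of the data. For unit vectors p_hat and q_hat,
//   ||p_hat - q_hat||^2 = 2 - 2 * cos(p, q),
// which decreases monotonically in the cosine similarity, so ranking by
// squared L2 distance over the normalized vectors yields the same neighbor
// order as ranking by cosine. Quick check: identical directions give
// 2 - 2 * 1 = 0, orthogonal vectors give 2 - 2 * 0 = 2. Zero norms are clamped
// to machine epsilon before dividing, and the squared norms are recomputed
// afterwards so the GEMM-based expansion sees (approximately) unit norms.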
+ + float *points = points_in; + float *queries = queries_in; + + if (metric == diskann::Metric::COSINE) { // we convert cosine distance as + // normalized L2 distnace + points = new float[npoints * dim]; + queries = new float[nqueries * dim]; #pragma omp parallel for schedule(static, 4096) - for (int64_t i = 0; i < (int64_t)npoints; i++) - { - float norm = std::sqrt(points_l2sq[i]); - if (norm == 0) - { - norm = std::numeric_limits::epsilon(); - } - for (uint32_t j = 0; j < dim; j++) - { - points[i * dim + j] = points_in[i * dim + j] / norm; - } - } + for (int64_t i = 0; i < (int64_t)npoints; i++) { + float norm = std::sqrt(points_l2sq[i]); + if (norm == 0) { + norm = std::numeric_limits::epsilon(); + } + for (uint32_t j = 0; j < dim; j++) { + points[i * dim + j] = points_in[i * dim + j] / norm; + } + } #pragma omp parallel for schedule(static, 4096) - for (int64_t i = 0; i < (int64_t)nqueries; i++) - { - float norm = std::sqrt(queries_l2sq[i]); - if (norm == 0) - { - norm = std::numeric_limits::epsilon(); - } - for (uint32_t j = 0; j < dim; j++) - { - queries[i * dim + j] = queries_in[i * dim + j] / norm; - } - } - // recalculate norms after normalizing, they should all be one. - compute_l2sq(points_l2sq, points, npoints, dim); - compute_l2sq(queries_l2sq, queries, nqueries, dim); + for (int64_t i = 0; i < (int64_t)nqueries; i++) { + float norm = std::sqrt(queries_l2sq[i]); + if (norm == 0) { + norm = std::numeric_limits::epsilon(); + } + for (uint32_t j = 0; j < dim; j++) { + queries[i * dim + j] = queries_in[i * dim + j] / norm; + } } - - std::cout << "Going to compute " << k << " NNs for " << nqueries << " queries over " << npoints << " points in " - << dim << " dimensions using"; - if (metric == diskann::Metric::INNER_PRODUCT) - std::cout << " MIPS "; - else if (metric == diskann::Metric::COSINE) - std::cout << " Cosine "; - else - std::cout << " L2 "; - std::cout << "distance fn. " << std::endl; - - size_t q_batch_size = (1 << 9); - float *dist_matrix = new float[(size_t)q_batch_size * (size_t)npoints]; - - for (uint64_t b = 0; b < div_round_up(nqueries, q_batch_size); ++b) - { - int64_t q_b = b * q_batch_size; - int64_t q_e = ((b + 1) * q_batch_size > nqueries) ? nqueries : (b + 1) * q_batch_size; - - if (metric == diskann::Metric::L2 || metric == diskann::Metric::COSINE) - { - distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, q_e - q_b, - queries + (ptrdiff_t)q_b * (ptrdiff_t)dim, queries_l2sq + q_b); - } - else - { - inner_prod_to_points(dim, dist_matrix, npoints, points, q_e - q_b, - queries + (ptrdiff_t)q_b * (ptrdiff_t)dim); - } - std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")" << std::endl; + // recalculate norms after normalizing, they should all be one. + compute_l2sq(points_l2sq, points, npoints, dim); + compute_l2sq(queries_l2sq, queries, nqueries, dim); + } + + std::cout << "Going to compute " << k << " NNs for " << nqueries + << " queries over " << npoints << " points in " << dim + << " dimensions using"; + if (metric == diskann::Metric::INNER_PRODUCT) + std::cout << " MIPS "; + else if (metric == diskann::Metric::COSINE) + std::cout << " Cosine "; + else + std::cout << " L2 "; + std::cout << "distance fn. " << std::endl; + + size_t q_batch_size = (1 << 9); + float *dist_matrix = new float[(size_t)q_batch_size * (size_t)npoints]; + + for (uint64_t b = 0; b < div_round_up(nqueries, q_batch_size); ++b) { + int64_t q_b = b * q_batch_size; + int64_t q_e = + ((b + 1) * q_batch_size > nqueries) ? 
nqueries : (b + 1) * q_batch_size; + + if (metric == diskann::Metric::L2 || metric == diskann::Metric::COSINE) { + distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, + q_e - q_b, queries + (ptrdiff_t)q_b * (ptrdiff_t)dim, + queries_l2sq + q_b); + } else { + inner_prod_to_points(dim, dist_matrix, npoints, points, q_e - q_b, + queries + (ptrdiff_t)q_b * (ptrdiff_t)dim); + } + std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")" + << std::endl; #pragma omp parallel for schedule(dynamic, 16) - for (long long q = q_b; q < q_e; q++) - { - maxPQIFCS point_dist; - for (size_t p = 0; p < k; p++) - point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); - for (size_t p = k; p < npoints; p++) - { - if (point_dist.top().second > dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]) - point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); - if (point_dist.size() > k) - point_dist.pop(); - } - for (ptrdiff_t l = 0; l < (ptrdiff_t)k; ++l) - { - closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().first; - dist_closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().second; - point_dist.pop(); - } - assert(std::is_sorted(dist_closest_points + (ptrdiff_t)q * (ptrdiff_t)k, - dist_closest_points + (ptrdiff_t)(q + 1) * (ptrdiff_t)k)); - } - std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e << ")" << std::endl; + for (long long q = q_b; q < q_e; q++) { + maxPQIFCS point_dist; + for (size_t p = 0; p < k; p++) + point_dist.emplace(p, + dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * + (ptrdiff_t)npoints]); + for (size_t p = k; p < npoints; p++) { + if (point_dist.top().second > + dist_matrix[(ptrdiff_t)p + + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]) + point_dist.emplace( + p, dist_matrix[(ptrdiff_t)p + + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); + if (point_dist.size() > k) + point_dist.pop(); + } + for (ptrdiff_t l = 0; l < (ptrdiff_t)k; ++l) { + closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = + point_dist.top().first; + dist_closest_points[(ptrdiff_t)(k - 1 - l) + + (ptrdiff_t)q * (ptrdiff_t)k] = + point_dist.top().second; + point_dist.pop(); + } + assert(std::is_sorted(dist_closest_points + (ptrdiff_t)q * (ptrdiff_t)k, + dist_closest_points + + (ptrdiff_t)(q + 1) * (ptrdiff_t)k)); } + std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e + << ")" << std::endl; + } - delete[] dist_matrix; + delete[] dist_matrix; - delete[] points_l2sq; - delete[] queries_l2sq; + delete[] points_l2sq; + delete[] queries_l2sq; - if (metric == diskann::Metric::COSINE) - { - delete[] points; - delete[] queries; - } + if (metric == diskann::Metric::COSINE) { + delete[] points; + delete[] queries; + } } -template inline int get_num_parts(const char *filename) -{ - std::ifstream reader; - reader.exceptions(std::ios::failbit | std::ios::badbit); - reader.open(filename, std::ios::binary); - std::cout << "Reading bin file " << filename << " ...\n"; - int npts_i32, ndims_i32; - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&ndims_i32, sizeof(int)); - std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl; - reader.close(); - int num_parts = (npts_i32 % PARTSIZE) == 0 ? 
npts_i32 / PARTSIZE : (uint32_t)std::floor(npts_i32 / PARTSIZE) + 1; - std::cout << "Number of parts: " << num_parts << std::endl; - return num_parts; +template inline int get_num_parts(const char *filename) { + std::ifstream reader; + reader.exceptions(std::ios::failbit | std::ios::badbit); + reader.open(filename, std::ios::binary); + std::cout << "Reading bin file " << filename << " ...\n"; + int npts_i32, ndims_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&ndims_i32, sizeof(int)); + std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl; + reader.close(); + int num_parts = (npts_i32 % PARTSIZE) == 0 + ? npts_i32 / PARTSIZE + : (uint32_t)std::floor(npts_i32 / PARTSIZE) + 1; + std::cout << "Number of parts: " << num_parts << std::endl; + return num_parts; } template -inline void load_bin_as_float(const char *filename, float *&data, size_t &npts_u64, size_t &ndims_u64, int part_num) -{ - std::ifstream reader; - reader.exceptions(std::ios::failbit | std::ios::badbit); - reader.open(filename, std::ios::binary); - std::cout << "Reading bin file " << filename << " ...\n"; - int npts_i32, ndims_i32; - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&ndims_i32, sizeof(int)); - uint64_t start_id = part_num * PARTSIZE; - uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32); - npts_u64 = end_id - start_id; - ndims_u64 = (uint64_t)ndims_i32; - std::cout << "#pts in part = " << npts_u64 << ", #dims = " << ndims_u64 - << ", size = " << npts_u64 * ndims_u64 * sizeof(T) << "B" << std::endl; - - reader.seekg(start_id * ndims_u64 * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg); - T *data_T = new T[npts_u64 * ndims_u64]; - reader.read((char *)data_T, sizeof(T) * npts_u64 * ndims_u64); - std::cout << "Finished reading part of the bin file." << std::endl; - reader.close(); - data = aligned_malloc(npts_u64 * ndims_u64, ALIGNMENT); +inline void load_bin_as_float(const char *filename, float *&data, + size_t &npts_u64, size_t &ndims_u64, + int part_num) { + std::ifstream reader; + reader.exceptions(std::ios::failbit | std::ios::badbit); + reader.open(filename, std::ios::binary); + std::cout << "Reading bin file " << filename << " ...\n"; + int npts_i32, ndims_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&ndims_i32, sizeof(int)); + uint64_t start_id = part_num * PARTSIZE; + uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32); + npts_u64 = end_id - start_id; + ndims_u64 = (uint64_t)ndims_i32; + std::cout << "#pts in part = " << npts_u64 << ", #dims = " << ndims_u64 + << ", size = " << npts_u64 * ndims_u64 * sizeof(T) << "B" + << std::endl; + + reader.seekg(start_id * ndims_u64 * sizeof(T) + 2 * sizeof(uint32_t), + std::ios::beg); + T *data_T = new T[npts_u64 * ndims_u64]; + reader.read((char *)data_T, sizeof(T) * npts_u64 * ndims_u64); + std::cout << "Finished reading part of the bin file." 
<< std::endl; + reader.close(); + data = aligned_malloc(npts_u64 * ndims_u64, ALIGNMENT); #pragma omp parallel for schedule(dynamic, 32768) - for (int64_t i = 0; i < (int64_t)npts_u64; i++) - { - for (int64_t j = 0; j < (int64_t)ndims_u64; j++) - { - float cur_val_float = (float)data_T[i * ndims_u64 + j]; - std::memcpy((char *)(data + i * ndims_u64 + j), (char *)&cur_val_float, sizeof(float)); - } + for (int64_t i = 0; i < (int64_t)npts_u64; i++) { + for (int64_t j = 0; j < (int64_t)ndims_u64; j++) { + float cur_val_float = (float)data_T[i * ndims_u64 + j]; + std::memcpy((char *)(data + i * ndims_u64 + j), (char *)&cur_val_float, + sizeof(float)); } - delete[] data_T; - std::cout << "Finished converting part data to float." << std::endl; + } + delete[] data_T; + std::cout << "Finished converting part data to float." << std::endl; } template -inline std::vector load_filtered_bin_as_float(const char *filename, float *&data, size_t &npts, size_t &ndims, - int part_num, const char *label_file, - const std::string &filter_label, - const std::string &universal_label, size_t &npoints_filt, - std::vector> &pts_to_labels) -{ - std::ifstream reader(filename, std::ios::binary); - if (reader.fail()) - { - throw diskann::ANNException(std::string("Failed to open file ") + filename, -1); - } - - std::cout << "Reading bin file " << filename << " ...\n"; - int npts_i32, ndims_i32; - std::vector rev_map; - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&ndims_i32, sizeof(int)); - uint64_t start_id = part_num * PARTSIZE; - uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32); - npts = end_id - start_id; - ndims = (uint32_t)ndims_i32; - uint64_t nptsuint64_t = (uint64_t)npts; - uint64_t ndimsuint64_t = (uint64_t)ndims; - npoints_filt = 0; - std::cout << "#pts in part = " << npts << ", #dims = " << ndims - << ", size = " << nptsuint64_t * ndimsuint64_t * sizeof(T) << "B" << std::endl; - std::cout << "start and end ids: " << start_id << ", " << end_id << std::endl; - reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg); - - T *data_T = new T[nptsuint64_t * ndimsuint64_t]; - reader.read((char *)data_T, sizeof(T) * nptsuint64_t * ndimsuint64_t); - std::cout << "Finished reading part of the bin file." 
<< std::endl; - reader.close(); - - data = aligned_malloc(nptsuint64_t * ndimsuint64_t, ALIGNMENT); - - for (int64_t i = 0; i < (int64_t)nptsuint64_t; i++) - { - if (std::find(pts_to_labels[start_id + i].begin(), pts_to_labels[start_id + i].end(), filter_label) != - pts_to_labels[start_id + i].end() || - std::find(pts_to_labels[start_id + i].begin(), pts_to_labels[start_id + i].end(), universal_label) != - pts_to_labels[start_id + i].end()) - { - rev_map.push_back(start_id + i); - for (int64_t j = 0; j < (int64_t)ndimsuint64_t; j++) - { - float cur_val_float = (float)data_T[i * ndimsuint64_t + j]; - std::memcpy((char *)(data + npoints_filt * ndimsuint64_t + j), (char *)&cur_val_float, sizeof(float)); - } - npoints_filt++; - } +inline std::vector load_filtered_bin_as_float( + const char *filename, float *&data, size_t &npts, size_t &ndims, + int part_num, const char *label_file, const std::string &filter_label, + const std::string &universal_label, size_t &npoints_filt, + std::vector> &pts_to_labels) { + std::ifstream reader(filename, std::ios::binary); + if (reader.fail()) { + throw diskann::ANNException(std::string("Failed to open file ") + filename, + -1); + } + + std::cout << "Reading bin file " << filename << " ...\n"; + int npts_i32, ndims_i32; + std::vector rev_map; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&ndims_i32, sizeof(int)); + uint64_t start_id = part_num * PARTSIZE; + uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32); + npts = end_id - start_id; + ndims = (uint32_t)ndims_i32; + uint64_t nptsuint64_t = (uint64_t)npts; + uint64_t ndimsuint64_t = (uint64_t)ndims; + npoints_filt = 0; + std::cout << "#pts in part = " << npts << ", #dims = " << ndims + << ", size = " << nptsuint64_t * ndimsuint64_t * sizeof(T) << "B" + << std::endl; + std::cout << "start and end ids: " << start_id << ", " << end_id << std::endl; + reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), + std::ios::beg); + + T *data_T = new T[nptsuint64_t * ndimsuint64_t]; + reader.read((char *)data_T, sizeof(T) * nptsuint64_t * ndimsuint64_t); + std::cout << "Finished reading part of the bin file." << std::endl; + reader.close(); + + data = aligned_malloc(nptsuint64_t * ndimsuint64_t, ALIGNMENT); + + for (int64_t i = 0; i < (int64_t)nptsuint64_t; i++) { + if (std::find(pts_to_labels[start_id + i].begin(), + pts_to_labels[start_id + i].end(), + filter_label) != pts_to_labels[start_id + i].end() || + std::find(pts_to_labels[start_id + i].begin(), + pts_to_labels[start_id + i].end(), + universal_label) != pts_to_labels[start_id + i].end()) { + rev_map.push_back(start_id + i); + for (int64_t j = 0; j < (int64_t)ndimsuint64_t; j++) { + float cur_val_float = (float)data_T[i * ndimsuint64_t + j]; + std::memcpy((char *)(data + npoints_filt * ndimsuint64_t + j), + (char *)&cur_val_float, sizeof(float)); + } + npoints_filt++; } - delete[] data_T; - std::cout << "Finished converting part data to float.. identified " << npoints_filt - << " points matching the filter." << std::endl; - return rev_map; + } + delete[] data_T; + std::cout << "Finished converting part data to float.. identified " + << npoints_filt << " points matching the filter." 
<< std::endl; + return rev_map; } -template inline void save_bin(const std::string filename, T *data, size_t npts, size_t ndims) -{ - std::ofstream writer; - writer.exceptions(std::ios::failbit | std::ios::badbit); - writer.open(filename, std::ios::binary | std::ios::out); - std::cout << "Writing bin: " << filename << "\n"; - int npts_i32 = (int)npts, ndims_i32 = (int)ndims; - writer.write((char *)&npts_i32, sizeof(int)); - writer.write((char *)&ndims_i32, sizeof(int)); - std::cout << "bin: #pts = " << npts << ", #dims = " << ndims - << ", size = " << npts * ndims * sizeof(T) + 2 * sizeof(int) << "B" << std::endl; - - writer.write((char *)data, npts * ndims * sizeof(T)); - writer.close(); - std::cout << "Finished writing bin" << std::endl; +template +inline void save_bin(const std::string filename, T *data, size_t npts, + size_t ndims) { + std::ofstream writer; + writer.exceptions(std::ios::failbit | std::ios::badbit); + writer.open(filename, std::ios::binary | std::ios::out); + std::cout << "Writing bin: " << filename << "\n"; + int npts_i32 = (int)npts, ndims_i32 = (int)ndims; + writer.write((char *)&npts_i32, sizeof(int)); + writer.write((char *)&ndims_i32, sizeof(int)); + std::cout << "bin: #pts = " << npts << ", #dims = " << ndims + << ", size = " << npts * ndims * sizeof(T) + 2 * sizeof(int) << "B" + << std::endl; + + writer.write((char *)data, npts * ndims * sizeof(T)); + writer.close(); + std::cout << "Finished writing bin" << std::endl; } -inline void save_groundtruth_as_one_file(const std::string filename, int32_t *data, float *distances, size_t npts, - size_t ndims) -{ - std::ofstream writer(filename, std::ios::binary | std::ios::out); - int npts_i32 = (int)npts, ndims_i32 = (int)ndims; - writer.write((char *)&npts_i32, sizeof(int)); - writer.write((char *)&ndims_i32, sizeof(int)); - std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, " - "npts*dim dist-matrix) with npts = " - << npts << ", dim = " << ndims << ", size = " << 2 * npts * ndims * sizeof(uint32_t) + 2 * sizeof(int) - << "B" << std::endl; - - writer.write((char *)data, npts * ndims * sizeof(uint32_t)); - writer.write((char *)distances, npts * ndims * sizeof(float)); - writer.close(); - std::cout << "Finished writing truthset" << std::endl; +inline void save_groundtruth_as_one_file(const std::string filename, + int32_t *data, float *distances, + size_t npts, size_t ndims) { + std::ofstream writer(filename, std::ios::binary | std::ios::out); + int npts_i32 = (int)npts, ndims_i32 = (int)ndims; + writer.write((char *)&npts_i32, sizeof(int)); + writer.write((char *)&ndims_i32, sizeof(int)); + std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, " + "npts*dim dist-matrix) with npts = " + << npts << ", dim = " << ndims << ", size = " + << 2 * npts * ndims * sizeof(uint32_t) + 2 * sizeof(int) << "B" + << std::endl; + + writer.write((char *)data, npts * ndims * sizeof(uint32_t)); + writer.write((char *)distances, npts * ndims * sizeof(float)); + writer.close(); + std::cout << "Finished writing truthset" << std::endl; } -inline void parse_label_file_into_vec(size_t &line_cnt, const std::string &map_file, - std::vector> &pts_to_labels) -{ - std::ifstream infile(map_file); - std::string line, token; - std::set labels; - infile.clear(); - infile.seekg(0, std::ios::beg); - while (std::getline(infile, line)) - { - std::istringstream iss(line); - std::vector lbls(0); - - getline(iss, token, '\t'); - std::istringstream new_iss(token); - while (getline(new_iss, token, ',')) - { - 
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - lbls.push_back(token); - labels.insert(token); - } - std::sort(lbls.begin(), lbls.end()); - pts_to_labels.push_back(lbls); +inline void parse_label_file_into_vec( + size_t &line_cnt, const std::string &map_file, + std::vector> &pts_to_labels) { + std::ifstream infile(map_file); + std::string line, token; + std::set labels; + infile.clear(); + infile.seekg(0, std::ios::beg); + while (std::getline(infile, line)) { + std::istringstream iss(line); + std::vector lbls(0); + + getline(iss, token, '\t'); + std::istringstream new_iss(token); + while (getline(new_iss, token, ',')) { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + lbls.push_back(token); + labels.insert(token); } - std::cout << "Identified " << labels.size() << " distinct label(s), and populated labels for " - << pts_to_labels.size() << " points" << std::endl; + std::sort(lbls.begin(), lbls.end()); + pts_to_labels.push_back(lbls); + } + std::cout << "Identified " << labels.size() + << " distinct label(s), and populated labels for " + << pts_to_labels.size() << " points" << std::endl; } template -std::vector>> processUnfilteredParts(const std::string &base_file, - size_t &nqueries, size_t &npoints, - size_t &dim, size_t &k, float *query_data, - const diskann::Metric &metric, - std::vector &location_to_tag) -{ - float *base_data = nullptr; - int num_parts = get_num_parts(base_file.c_str()); - std::vector>> res(nqueries); - for (int p = 0; p < num_parts; p++) - { - size_t start_id = p * PARTSIZE; - load_bin_as_float(base_file.c_str(), base_data, npoints, dim, p); - - size_t *closest_points_part = new size_t[nqueries * k]; - float *dist_closest_points_part = new float[nqueries * k]; - - auto part_k = k < npoints ? k : npoints; - exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, base_data, nqueries, query_data, - metric); - - for (size_t i = 0; i < nqueries; i++) - { - for (uint64_t j = 0; j < part_k; j++) - { - if (!location_to_tag.empty()) - if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0) - continue; - - res[i].push_back(std::make_pair((uint32_t)(closest_points_part[i * part_k + j] + start_id), - dist_closest_points_part[i * part_k + j])); - } - } +std::vector>> +processUnfilteredParts(const std::string &base_file, size_t &nqueries, + size_t &npoints, size_t &dim, size_t &k, + float *query_data, const diskann::Metric &metric, + std::vector &location_to_tag) { + float *base_data = nullptr; + int num_parts = get_num_parts(base_file.c_str()); + std::vector>> res(nqueries); + for (int p = 0; p < num_parts; p++) { + size_t start_id = p * PARTSIZE; + load_bin_as_float(base_file.c_str(), base_data, npoints, dim, p); + + size_t *closest_points_part = new size_t[nqueries * k]; + float *dist_closest_points_part = new float[nqueries * k]; + + auto part_k = k < npoints ? 
k : npoints; + exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, + npoints, base_data, nqueries, query_data, metric); + + for (size_t i = 0; i < nqueries; i++) { + for (uint64_t j = 0; j < part_k; j++) { + if (!location_to_tag.empty()) + if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0) + continue; + + res[i].push_back(std::make_pair( + (uint32_t)(closest_points_part[i * part_k + j] + start_id), + dist_closest_points_part[i * part_k + j])); + } + } - delete[] closest_points_part; - delete[] dist_closest_points_part; + delete[] closest_points_part; + delete[] dist_closest_points_part; - diskann::aligned_free(base_data); - } - return res; + diskann::aligned_free(base_data); + } + return res; }; template std::vector>> processFilteredParts( - const std::string &base_file, const std::string &label_file, const std::string &filter_label, - const std::string &universal_label, size_t &nqueries, size_t &npoints, size_t &dim, size_t &k, float *query_data, - const diskann::Metric &metric, std::vector &location_to_tag) -{ - size_t npoints_filt = 0; - float *base_data = nullptr; - std::vector>> res(nqueries); - int num_parts = get_num_parts(base_file.c_str()); - - std::vector> pts_to_labels; + const std::string &base_file, const std::string &label_file, + const std::string &filter_label, const std::string &universal_label, + size_t &nqueries, size_t &npoints, size_t &dim, size_t &k, + float *query_data, const diskann::Metric &metric, + std::vector &location_to_tag) { + size_t npoints_filt = 0; + float *base_data = nullptr; + std::vector>> res(nqueries); + int num_parts = get_num_parts(base_file.c_str()); + + std::vector> pts_to_labels; + if (filter_label != "") + parse_label_file_into_vec(npoints, label_file, pts_to_labels); + + for (int p = 0; p < num_parts; p++) { + size_t start_id = p * PARTSIZE; + std::vector rev_map; if (filter_label != "") - parse_label_file_into_vec(npoints, label_file, pts_to_labels); - - for (int p = 0; p < num_parts; p++) - { - size_t start_id = p * PARTSIZE; - std::vector rev_map; - if (filter_label != "") - rev_map = load_filtered_bin_as_float(base_file.c_str(), base_data, npoints, dim, p, label_file.c_str(), - filter_label, universal_label, npoints_filt, pts_to_labels); - size_t *closest_points_part = new size_t[nqueries * k]; - float *dist_closest_points_part = new float[nqueries * k]; - - auto part_k = k < npoints_filt ? k : npoints_filt; - if (npoints_filt > 0) - { - exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints_filt, base_data, nqueries, - query_data, metric); - } - - for (size_t i = 0; i < nqueries; i++) - { - for (uint64_t j = 0; j < part_k; j++) - { - if (!location_to_tag.empty()) - if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0) - continue; - - res[i].push_back(std::make_pair((uint32_t)(rev_map[closest_points_part[i * part_k + j]]), - dist_closest_points_part[i * part_k + j])); - } - } - - delete[] closest_points_part; - delete[] dist_closest_points_part; - - diskann::aligned_free(base_data); + rev_map = load_filtered_bin_as_float( + base_file.c_str(), base_data, npoints, dim, p, label_file.c_str(), + filter_label, universal_label, npoints_filt, pts_to_labels); + size_t *closest_points_part = new size_t[nqueries * k]; + float *dist_closest_points_part = new float[nqueries * k]; + + auto part_k = k < npoints_filt ? 
k : npoints_filt; + if (npoints_filt > 0) { + exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, + npoints_filt, base_data, nqueries, query_data, metric); } - return res; -}; -template -int aux_main(const std::string &base_file, const std::string &label_file, const std::string &query_file, - const std::string >_file, size_t k, const std::string &universal_label, const diskann::Metric &metric, - const std::string &filter_label, const std::string &tags_file = std::string("")) -{ - size_t npoints, nqueries, dim; + for (size_t i = 0; i < nqueries; i++) { + for (uint64_t j = 0; j < part_k; j++) { + if (!location_to_tag.empty()) + if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0) + continue; - float *query_data = nullptr; - - load_bin_as_float(query_file.c_str(), query_data, nqueries, dim, 0); - if (nqueries > PARTSIZE) - std::cerr << "WARNING: #Queries provided (" << nqueries << ") is greater than " << PARTSIZE - << ". Computing GT only for the first " << PARTSIZE << " queries." << std::endl; - - // load tags - const bool tags_enabled = tags_file.empty() ? false : true; - std::vector location_to_tag = diskann::loadTags(tags_file, base_file); - - int *closest_points = new int[nqueries * k]; - float *dist_closest_points = new float[nqueries * k]; - - std::vector>> results; - if (filter_label == "") - { - results = processUnfilteredParts(base_file, nqueries, npoints, dim, k, query_data, metric, location_to_tag); - } - else - { - results = processFilteredParts(base_file, label_file, filter_label, universal_label, nqueries, npoints, dim, - k, query_data, metric, location_to_tag); + res[i].push_back(std::make_pair( + (uint32_t)(rev_map[closest_points_part[i * part_k + j]]), + dist_closest_points_part[i * part_k + j])); + } } - for (size_t i = 0; i < nqueries; i++) - { - std::vector> &cur_res = results[i]; - std::sort(cur_res.begin(), cur_res.end(), custom_dist); - size_t j = 0; - for (auto iter : cur_res) - { - if (j == k) - break; - if (tags_enabled) - { - std::uint32_t index_with_tag = location_to_tag[iter.first]; - closest_points[i * k + j] = (int32_t)index_with_tag; - } - else - { - closest_points[i * k + j] = (int32_t)iter.first; - } - - if (metric == diskann::Metric::INNER_PRODUCT) - dist_closest_points[i * k + j] = -iter.second; - else - dist_closest_points[i * k + j] = iter.second; - - ++j; - } - if (j < k) - std::cout << "WARNING: found less than k GT entries for query " << i << std::endl; - } + delete[] closest_points_part; + delete[] dist_closest_points_part; - save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, nqueries, k); - delete[] closest_points; - delete[] dist_closest_points; - diskann::aligned_free(query_data); - - return 0; -} + diskann::aligned_free(base_data); + } + return res; +}; -void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, size_t &npts, size_t &dim) -{ - size_t read_blk_size = 64 * 1024 * 1024; - cached_ifstream reader(bin_file, read_blk_size); - diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." << std::endl; - size_t actual_file_size = reader.get_file_size(); - - int npts_i32, dim_i32; - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&dim_i32, sizeof(int)); - npts = (uint32_t)npts_i32; - dim = (uint32_t)dim_i32; - - diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... 
" << std::endl; - - int truthset_type = -1; // 1 means truthset has ids and distances, 2 means - // only ids, -1 is error - size_t expected_file_size_with_dists = 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); - - if (actual_file_size == expected_file_size_with_dists) - truthset_type = 1; - - size_t expected_file_size_just_ids = npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); - - if (actual_file_size == expected_file_size_just_ids) - truthset_type = 2; - - if (truthset_type == -1) - { - std::stringstream stream; - stream << "Error. File size mismatch. File should have bin format, with " - "npts followed by ngt followed by npts*ngt ids and optionally " - "followed by npts*ngt distance values; actual size: " - << actual_file_size << ", expected: " << expected_file_size_with_dists << " or " - << expected_file_size_just_ids; - diskann::cout << stream.str(); - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); +template +int aux_main(const std::string &base_file, const std::string &label_file, + const std::string &query_file, const std::string >_file, + size_t k, const std::string &universal_label, + const diskann::Metric &metric, const std::string &filter_label, + const std::string &tags_file = std::string("")) { + size_t npoints, nqueries, dim; + + float *query_data = nullptr; + + load_bin_as_float(query_file.c_str(), query_data, nqueries, dim, 0); + if (nqueries > PARTSIZE) + std::cerr << "WARNING: #Queries provided (" << nqueries + << ") is greater than " << PARTSIZE + << ". Computing GT only for the first " << PARTSIZE << " queries." + << std::endl; + + // load tags + const bool tags_enabled = tags_file.empty() ? false : true; + std::vector location_to_tag = + diskann::loadTags(tags_file, base_file); + + int *closest_points = new int[nqueries * k]; + float *dist_closest_points = new float[nqueries * k]; + + std::vector>> results; + if (filter_label == "") { + results = processUnfilteredParts(base_file, nqueries, npoints, dim, k, + query_data, metric, location_to_tag); + } else { + results = processFilteredParts(base_file, label_file, filter_label, + universal_label, nqueries, npoints, dim, + k, query_data, metric, location_to_tag); + } + + for (size_t i = 0; i < nqueries; i++) { + std::vector> &cur_res = results[i]; + std::sort(cur_res.begin(), cur_res.end(), custom_dist); + size_t j = 0; + for (auto iter : cur_res) { + if (j == k) + break; + if (tags_enabled) { + std::uint32_t index_with_tag = location_to_tag[iter.first]; + closest_points[i * k + j] = (int32_t)index_with_tag; + } else { + closest_points[i * k + j] = (int32_t)iter.first; + } + + if (metric == diskann::Metric::INNER_PRODUCT) + dist_closest_points[i * k + j] = -iter.second; + else + dist_closest_points[i * k + j] = iter.second; + + ++j; } + if (j < k) + std::cout << "WARNING: found less than k GT entries for query " << i + << std::endl; + } + + save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, + nqueries, k); + delete[] closest_points; + delete[] dist_closest_points; + diskann::aligned_free(query_data); + + return 0; +} - ids = new uint32_t[npts * dim]; - reader.read((char *)ids, npts * dim * sizeof(uint32_t)); - - if (truthset_type == 1) - { - dists = new float[npts * dim]; - reader.read((char *)dists, npts * dim * sizeof(float)); - } +void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, + size_t &npts, size_t &dim) { + size_t read_blk_size = 64 * 1024 * 1024; + cached_ifstream reader(bin_file, read_blk_size); + diskann::cout << 
"Reading truthset file " << bin_file.c_str() << " ..." + << std::endl; + size_t actual_file_size = reader.get_file_size(); + + int npts_i32, dim_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&dim_i32, sizeof(int)); + npts = (uint32_t)npts_i32; + dim = (uint32_t)dim_i32; + + diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " + << std::endl; + + int truthset_type = -1; // 1 means truthset has ids and distances, 2 means + // only ids, -1 is error + size_t expected_file_size_with_dists = + 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); + + if (actual_file_size == expected_file_size_with_dists) + truthset_type = 1; + + size_t expected_file_size_just_ids = + npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); + + if (actual_file_size == expected_file_size_just_ids) + truthset_type = 2; + + if (truthset_type == -1) { + std::stringstream stream; + stream << "Error. File size mismatch. File should have bin format, with " + "npts followed by ngt followed by npts*ngt ids and optionally " + "followed by npts*ngt distance values; actual size: " + << actual_file_size + << ", expected: " << expected_file_size_with_dists << " or " + << expected_file_size_just_ids; + diskann::cout << stream.str(); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, + __LINE__); + } + + ids = new uint32_t[npts * dim]; + reader.read((char *)ids, npts * dim * sizeof(uint32_t)); + + if (truthset_type == 1) { + dists = new float[npts * dim]; + reader.read((char *)dists, npts * dim * sizeof(float)); + } } -int main(int argc, char **argv) -{ - std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file, label_file, filter_label, - universal_label, filter_label_file; - uint64_t K; - - try - { - po::options_description desc{"Arguments"}; - - desc.add_options()("help,h", "Print information on arguments"); - - desc.add_options()("data_type", po::value(&data_type)->required(), "data type "); - desc.add_options()("dist_fn", po::value(&dist_fn)->required(), "distance function "); - desc.add_options()("base_file", po::value(&base_file)->required(), - "File containing the base vectors in binary format"); - desc.add_options()("query_file", po::value(&query_file)->required(), - "File containing the query vectors in binary format"); - desc.add_options()("label_file", po::value(&label_file)->default_value(""), - "Input labels file in txt format if present"); - desc.add_options()("filter_label", po::value(&filter_label)->default_value(""), - "Input filter label if doing filtered groundtruth"); - desc.add_options()("universal_label", po::value(&universal_label)->default_value(""), - "Universal label, if using it, only in conjunction with label_file"); - desc.add_options()("gt_file", po::value(>_file)->required(), - "File name for the writing ground truth in binary " - "format, please don' append .bin at end if " - "no filter_label or filter_label_file is provided it " - "will save the file with '.bin' at end." 
- "else it will save the file as filename_label.bin"); - desc.add_options()("K", po::value(&K)->required(), - "Number of ground truth nearest neighbors to compute"); - desc.add_options()("tags_file", po::value(&tags_file)->default_value(std::string()), - "File containing the tags in binary format"); - desc.add_options()("filter_label_file", - po::value(&filter_label_file)->default_value(std::string("")), - "Filter file for Queries for Filtered Search "); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) - { - std::cout << desc; - return 0; - } - po::notify(vm); +int main(int argc, char **argv) { + std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file, + label_file, filter_label, universal_label, filter_label_file; + uint64_t K; + + try { + po::options_description desc{"Arguments"}; + + desc.add_options()("help,h", "Print information on arguments"); + + desc.add_options()("data_type", + po::value(&data_type)->required(), + "data type "); + desc.add_options()("dist_fn", po::value(&dist_fn)->required(), + "distance function "); + desc.add_options()("base_file", + po::value(&base_file)->required(), + "File containing the base vectors in binary format"); + desc.add_options()("query_file", + po::value(&query_file)->required(), + "File containing the query vectors in binary format"); + desc.add_options()("label_file", + po::value(&label_file)->default_value(""), + "Input labels file in txt format if present"); + desc.add_options()("filter_label", + po::value(&filter_label)->default_value(""), + "Input filter label if doing filtered groundtruth"); + desc.add_options()( + "universal_label", + po::value(&universal_label)->default_value(""), + "Universal label, if using it, only in conjunction with label_file"); + desc.add_options()("gt_file", po::value(>_file)->required(), + "File name for the writing ground truth in binary " + "format, please don' append .bin at end if " + "no filter_label or filter_label_file is provided it " + "will save the file with '.bin' at end." + "else it will save the file as filename_label.bin"); + desc.add_options()("K", po::value(&K)->required(), + "Number of ground truth nearest neighbors to compute"); + desc.add_options()( + "tags_file", + po::value(&tags_file)->default_value(std::string()), + "File containing the tags in binary format"); + desc.add_options()("filter_label_file", + po::value(&filter_label_file) + ->default_value(std::string("")), + "Filter file for Queries for Filtered Search "); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) { + std::cout << desc; + return 0; } - catch (const std::exception &ex) - { - std::cerr << ex.what() << '\n'; - return -1; + po::notify(vm); + } catch (const std::exception &ex) { + std::cerr << ex.what() << '\n'; + return -1; + } + + if (data_type != std::string("float") && data_type != std::string("int8") && + data_type != std::string("uint8")) { + std::cout << "Unsupported type. float, int8 and uint8 types are supported." 
+ << std::endl; + return -1; + } + + if (filter_label != "" && filter_label_file != "") { + std::cerr + << "Only one of filter_label and query_filters_file should be provided" + << std::endl; + return -1; + } + + diskann::Metric metric; + if (dist_fn == std::string("l2")) { + metric = diskann::Metric::L2; + } else if (dist_fn == std::string("mips")) { + metric = diskann::Metric::INNER_PRODUCT; + } else if (dist_fn == std::string("cosine")) { + metric = diskann::Metric::COSINE; + } else { + std::cerr << "Unsupported distance function. Use l2/mips/cosine." + << std::endl; + return -1; + } + + std::vector filter_labels; + if (filter_label != "") { + filter_labels.push_back(filter_label); + } else if (filter_label_file != "") { + filter_labels = read_file_to_vector_of_strings(filter_label_file, false); + } + + // only when there is no filter label or 1 filter label for all queries + if (filter_labels.size() == 1) { + try { + if (data_type == std::string("float")) + aux_main(base_file, label_file, query_file, gt_file, K, + universal_label, metric, filter_labels[0], tags_file); + if (data_type == std::string("int8")) + aux_main(base_file, label_file, query_file, gt_file, K, + universal_label, metric, filter_labels[0], tags_file); + if (data_type == std::string("uint8")) + aux_main(base_file, label_file, query_file, gt_file, K, + universal_label, metric, filter_labels[0], tags_file); + } catch (const std::exception &e) { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Compute GT failed." << std::endl; + return -1; } - - if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8")) - { - std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl; - return -1; + } else { // Each query has its own filter label + // Split up data and query bins into label specific ones + tsl::robin_map labels_to_number_of_points; + tsl::robin_map labels_to_number_of_queries; + + label_set all_labels; + for (size_t i = 0; i < filter_labels.size(); i++) { + std::string label = filter_labels[i]; + all_labels.insert(label); + + if (labels_to_number_of_queries.find(label) == + labels_to_number_of_queries.end()) { + labels_to_number_of_queries[label] = 0; + } + labels_to_number_of_queries[label] += 1; } - if (filter_label != "" && filter_label_file != "") - { - std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl; - return -1; - } - - diskann::Metric metric; - if (dist_fn == std::string("l2")) - { - metric = diskann::Metric::L2; - } - else if (dist_fn == std::string("mips")) - { - metric = diskann::Metric::INNER_PRODUCT; - } - else if (dist_fn == std::string("cosine")) - { - metric = diskann::Metric::COSINE; - } - else - { - std::cerr << "Unsupported distance function. Use l2/mips/cosine." 
<< std::endl; - return -1; + size_t npoints; + std::vector> point_to_labels; + parse_label_file_into_vec(npoints, label_file, point_to_labels); + std::vector point_ids_to_labels(point_to_labels.size()); + std::vector query_ids_to_labels(filter_labels.size()); + + for (size_t i = 0; i < point_to_labels.size(); i++) { + for (size_t j = 0; j < point_to_labels[i].size(); j++) { + std::string label = point_to_labels[i][j]; + if (all_labels.find(label) != all_labels.end()) { + point_ids_to_labels[i].insert(point_to_labels[i][j]); + if (labels_to_number_of_points.find(label) == + labels_to_number_of_points.end()) { + labels_to_number_of_points[label] = 0; + } + labels_to_number_of_points[label] += 1; + } + } } - std::vector filter_labels; - if (filter_label != "") - { - filter_labels.push_back(filter_label); - } - else if (filter_label_file != "") - { - filter_labels = read_file_to_vector_of_strings(filter_label_file, false); + for (size_t i = 0; i < filter_labels.size(); i++) { + query_ids_to_labels[i].insert(filter_labels[i]); } - // only when there is no filter label or 1 filter label for all queries - if (filter_labels.size() == 1) - { - try - { - if (data_type == std::string("float")) - aux_main(base_file, label_file, query_file, gt_file, K, universal_label, metric, - filter_labels[0], tags_file); - if (data_type == std::string("int8")) - aux_main(base_file, label_file, query_file, gt_file, K, universal_label, metric, - filter_labels[0], tags_file); - if (data_type == std::string("uint8")) - aux_main(base_file, label_file, query_file, gt_file, K, universal_label, metric, - filter_labels[0], tags_file); - } - catch (const std::exception &e) - { - std::cout << std::string(e.what()) << std::endl; - diskann::cerr << "Compute GT failed." << std::endl; - return -1; - } + tsl::robin_map> label_id_to_orig_id; + tsl::robin_map> + label_query_id_to_orig_id; + + if (data_type == std::string("float")) { + label_id_to_orig_id = + diskann::generate_label_specific_vector_files_compat( + base_file, labels_to_number_of_points, point_ids_to_labels, + all_labels); + + label_query_id_to_orig_id = + diskann::generate_label_specific_vector_files_compat( + query_file, labels_to_number_of_queries, query_ids_to_labels, + all_labels); // query_filters acts like query_ids_to_labels + } else if (data_type == std::string("int8")) { + label_id_to_orig_id = + diskann::generate_label_specific_vector_files_compat( + base_file, labels_to_number_of_points, point_ids_to_labels, + all_labels); + + label_query_id_to_orig_id = + diskann::generate_label_specific_vector_files_compat( + query_file, labels_to_number_of_queries, query_ids_to_labels, + all_labels); // query_filters acts like query_ids_to_labels + } else if (data_type == std::string("uint8")) { + label_id_to_orig_id = + diskann::generate_label_specific_vector_files_compat( + base_file, labels_to_number_of_points, point_ids_to_labels, + all_labels); + + label_query_id_to_orig_id = + diskann::generate_label_specific_vector_files_compat( + query_file, labels_to_number_of_queries, query_ids_to_labels, + all_labels); // query_filters acts like query_ids_to_labels + } else { + diskann::cerr << "Invalid data type" << std::endl; + return -1; } - else - { // Each query has its own filter label - // Split up data and query bins into label specific ones - tsl::robin_map labels_to_number_of_points; - tsl::robin_map labels_to_number_of_queries; - - label_set all_labels; - for (size_t i = 0; i < filter_labels.size(); i++) - { - std::string label = filter_labels[i]; - 
all_labels.insert(label); - - if (labels_to_number_of_queries.find(label) == labels_to_number_of_queries.end()) - { - labels_to_number_of_queries[label] = 0; - } - labels_to_number_of_queries[label] += 1; - } - - size_t npoints; - std::vector> point_to_labels; - parse_label_file_into_vec(npoints, label_file, point_to_labels); - std::vector point_ids_to_labels(point_to_labels.size()); - std::vector query_ids_to_labels(filter_labels.size()); - - for (size_t i = 0; i < point_to_labels.size(); i++) - { - for (size_t j = 0; j < point_to_labels[i].size(); j++) - { - std::string label = point_to_labels[i][j]; - if (all_labels.find(label) != all_labels.end()) - { - point_ids_to_labels[i].insert(point_to_labels[i][j]); - if (labels_to_number_of_points.find(label) == labels_to_number_of_points.end()) - { - labels_to_number_of_points[label] = 0; - } - labels_to_number_of_points[label] += 1; - } - } - } - for (size_t i = 0; i < filter_labels.size(); i++) - { - query_ids_to_labels[i].insert(filter_labels[i]); - } - - tsl::robin_map> label_id_to_orig_id; - tsl::robin_map> label_query_id_to_orig_id; + // Generate label specific ground truths + try { + for (const auto &label : all_labels) { + std::string filtered_base_file = base_file + "_" + label; + std::string filtered_query_file = query_file + "_" + label; + std::string filtered_gt_file = gt_file + "_" + label; if (data_type == std::string("float")) - { - label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat( - base_file, labels_to_number_of_points, point_ids_to_labels, all_labels); - - label_query_id_to_orig_id = diskann::generate_label_specific_vector_files_compat( - query_file, labels_to_number_of_queries, query_ids_to_labels, - all_labels); // query_filters acts like query_ids_to_labels - } - else if (data_type == std::string("int8")) - { - label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat( - base_file, labels_to_number_of_points, point_ids_to_labels, all_labels); - - label_query_id_to_orig_id = diskann::generate_label_specific_vector_files_compat( - query_file, labels_to_number_of_queries, query_ids_to_labels, - all_labels); // query_filters acts like query_ids_to_labels - } - else if (data_type == std::string("uint8")) - { - label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat( - base_file, labels_to_number_of_points, point_ids_to_labels, all_labels); - - label_query_id_to_orig_id = diskann::generate_label_specific_vector_files_compat( - query_file, labels_to_number_of_queries, query_ids_to_labels, - all_labels); // query_filters acts like query_ids_to_labels - } - else - { - diskann::cerr << "Invalid data type" << std::endl; - return -1; - } + aux_main(filtered_base_file, "", filtered_query_file, + filtered_gt_file, K, "", metric, ""); + if (data_type == std::string("int8")) + aux_main(filtered_base_file, "", filtered_query_file, + filtered_gt_file, K, "", metric, ""); + if (data_type == std::string("uint8")) + aux_main(filtered_base_file, "", filtered_query_file, + filtered_gt_file, K, "", metric, ""); + } + } catch (const std::exception &e) { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Compute GT failed." 
<< std::endl; + return -1; + } - // Generate label specific ground truths - - try - { - for (const auto &label : all_labels) - { - std::string filtered_base_file = base_file + "_" + label; - std::string filtered_query_file = query_file + "_" + label; - std::string filtered_gt_file = gt_file + "_" + label; - if (data_type == std::string("float")) - aux_main(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, ""); - if (data_type == std::string("int8")) - aux_main(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, ""); - if (data_type == std::string("uint8")) - aux_main(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, ""); - } - } - catch (const std::exception &e) - { - std::cout << std::string(e.what()) << std::endl; - diskann::cerr << "Compute GT failed." << std::endl; - return -1; - } + // Combine the label specific ground truths to produce a single GT file - // Combine the label specific ground truths to produce a single GT file + uint32_t *gt_ids = nullptr; + float *gt_dists = nullptr; + size_t gt_num, gt_dim; - uint32_t *gt_ids = nullptr; - float *gt_dists = nullptr; - size_t gt_num, gt_dim; + std::vector> final_gt_ids; + std::vector> final_gt_dists; - std::vector> final_gt_ids; - std::vector> final_gt_dists; + uint32_t query_num = 0; + for (const auto &lbl : all_labels) { + query_num += labels_to_number_of_queries[lbl]; + } - uint32_t query_num = 0; - for (const auto &lbl : all_labels) - { - query_num += labels_to_number_of_queries[lbl]; - } + for (uint32_t i = 0; i < query_num; i++) { + final_gt_ids.push_back(std::vector(K)); + final_gt_dists.push_back(std::vector(K)); + } - for (uint32_t i = 0; i < query_num; i++) - { - final_gt_ids.push_back(std::vector(K)); - final_gt_dists.push_back(std::vector(K)); - } + for (const auto &lbl : all_labels) { + std::string filtered_gt_file = gt_file + "_" + lbl; + load_truthset(filtered_gt_file, gt_ids, gt_dists, gt_num, gt_dim); - for (const auto &lbl : all_labels) - { - std::string filtered_gt_file = gt_file + "_" + lbl; - load_truthset(filtered_gt_file, gt_ids, gt_dists, gt_num, gt_dim); - - for (uint32_t i = 0; i < labels_to_number_of_queries[lbl]; i++) - { - uint32_t orig_query_id = label_query_id_to_orig_id[lbl][i]; - for (uint64_t j = 0; j < K; j++) - { - final_gt_ids[orig_query_id][j] = label_id_to_orig_id[lbl][gt_ids[i * K + j]]; - final_gt_dists[orig_query_id][j] = gt_dists[i * K + j]; - } - } + for (uint32_t i = 0; i < labels_to_number_of_queries[lbl]; i++) { + uint32_t orig_query_id = label_query_id_to_orig_id[lbl][i]; + for (uint64_t j = 0; j < K; j++) { + final_gt_ids[orig_query_id][j] = + label_id_to_orig_id[lbl][gt_ids[i * K + j]]; + final_gt_dists[orig_query_id][j] = gt_dists[i * K + j]; } + } + } - int32_t *closest_points = new int32_t[query_num * K]; - float *dist_closest_points = new float[query_num * K]; + int32_t *closest_points = new int32_t[query_num * K]; + float *dist_closest_points = new float[query_num * K]; - for (uint32_t i = 0; i < query_num; i++) - { - for (uint32_t j = 0; j < K; j++) - { - closest_points[i * K + j] = final_gt_ids[i][j]; - dist_closest_points[i * K + j] = final_gt_dists[i][j]; - } - } + for (uint32_t i = 0; i < query_num; i++) { + for (uint32_t j = 0; j < K; j++) { + closest_points[i * K + j] = final_gt_ids[i][j]; + dist_closest_points[i * K + j] = final_gt_dists[i][j]; + } + } - save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, query_num, K); + save_groundtruth_as_one_file(gt_file, 
closest_points, dist_closest_points, + query_num, K); - // cleanup artifacts - std::cout << "Cleaning up artifacts..." << std::endl; - tsl::robin_set paths_to_clean{gt_file, base_file, query_file}; - clean_up_artifacts(paths_to_clean, all_labels); - } + // cleanup artifacts + std::cout << "Cleaning up artifacts..." << std::endl; + tsl::robin_set paths_to_clean{gt_file, base_file, query_file}; + clean_up_artifacts(paths_to_clean, all_labels); + } } diff --git a/apps/utils/count_bfs_levels.cpp b/apps/utils/count_bfs_levels.cpp index 6dd2d6233..2f1a1db92 100644 --- a/apps/utils/count_bfs_levels.cpp +++ b/apps/utils/count_bfs_levels.cpp @@ -1,14 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. +#include +#include #include #include -#include #include #include #include #include -#include #ifndef _WINDOWS #include @@ -17,66 +17,63 @@ #include #endif -#include "utils.h" #include "index.h" #include "memory_mapper.h" +#include "utils.h" namespace po = boost::program_options; -template void bfs_count(const std::string &index_path, uint32_t data_dims) -{ - using TagT = uint32_t; - using LabelT = uint32_t; - diskann::Index index(diskann::Metric::L2, data_dims, 0, nullptr, nullptr, 0, false, false, false, - false, 0, false); - std::cout << "Index class instantiated" << std::endl; - index.load(index_path.c_str(), 1, 100); - std::cout << "Index loaded" << std::endl; - index.count_nodes_at_bfs_levels(); +template +void bfs_count(const std::string &index_path, uint32_t data_dims) { + using TagT = uint32_t; + using LabelT = uint32_t; + diskann::Index index(diskann::Metric::L2, data_dims, 0, + nullptr, nullptr, 0, false, false, + false, false, 0, false); + std::cout << "Index class instantiated" << std::endl; + index.load(index_path.c_str(), 1, 100); + std::cout << "Index loaded" << std::endl; + index.count_nodes_at_bfs_levels(); } -int main(int argc, char **argv) -{ - std::string data_type, index_path_prefix; - uint32_t data_dims; +int main(int argc, char **argv) { + std::string data_type, index_path_prefix; + uint32_t data_dims; - po::options_description desc{"Arguments"}; - try - { - desc.add_options()("help,h", "Print information on arguments"); - desc.add_options()("data_type", po::value(&data_type)->required(), "data type "); - desc.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), - "Path prefix to the index"); - desc.add_options()("data_dims", po::value(&data_dims)->required(), "Dimensionality of the data"); + po::options_description desc{"Arguments"}; + try { + desc.add_options()("help,h", "Print information on arguments"); + desc.add_options()("data_type", + po::value(&data_type)->required(), + "data type "); + desc.add_options()("index_path_prefix", + po::value(&index_path_prefix)->required(), + "Path prefix to the index"); + desc.add_options()("data_dims", po::value(&data_dims)->required(), + "Dimensionality of the data"); - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) - { - std::cout << desc; - return 0; - } - po::notify(vm); - } - catch (const std::exception &ex) - { - std::cerr << ex.what() << '\n'; - return -1; + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) { + std::cout << desc; + return 0; } + po::notify(vm); + } catch (const std::exception &ex) { + std::cerr << ex.what() << '\n'; + return -1; + } - try - { - if (data_type == std::string("int8")) - bfs_count(index_path_prefix, data_dims); - 
else if (data_type == std::string("uint8")) - bfs_count(index_path_prefix, data_dims); - if (data_type == std::string("float")) - bfs_count(index_path_prefix, data_dims); - } - catch (std::exception &e) - { - std::cout << std::string(e.what()) << std::endl; - diskann::cerr << "Index BFS failed." << std::endl; - return -1; - } + try { + if (data_type == std::string("int8")) + bfs_count(index_path_prefix, data_dims); + else if (data_type == std::string("uint8")) + bfs_count(index_path_prefix, data_dims); + if (data_type == std::string("float")) + bfs_count(index_path_prefix, data_dims); + } catch (std::exception &e) { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Index BFS failed." << std::endl; + return -1; + } } diff --git a/apps/utils/create_disk_layout.cpp b/apps/utils/create_disk_layout.cpp index f494c1227..7c8eca1b0 100644 --- a/apps/utils/create_disk_layout.cpp +++ b/apps/utils/create_disk_layout.cpp @@ -8,41 +8,37 @@ #include #include -#include "utils.h" -#include "disk_utils.h" #include "cached_io.h" +#include "disk_utils.h" +#include "utils.h" -template int create_disk_layout(char **argv) -{ - std::string base_file(argv[2]); - std::string vamana_file(argv[3]); - std::string output_file(argv[4]); - diskann::create_disk_layout(base_file, vamana_file, output_file); - return 0; +template int create_disk_layout(char **argv) { + std::string base_file(argv[2]); + std::string vamana_file(argv[3]); + std::string output_file(argv[4]); + diskann::create_disk_layout(base_file, vamana_file, output_file); + return 0; } -int main(int argc, char **argv) -{ - if (argc != 5) - { - std::cout << argv[0] - << " data_type data_bin " - "vamana_index_file output_diskann_index_file" - << std::endl; - exit(-1); - } +int main(int argc, char **argv) { + if (argc != 5) { + std::cout << argv[0] + << " data_type data_bin " + "vamana_index_file output_diskann_index_file" + << std::endl; + exit(-1); + } - int ret_val = -1; - if (std::string(argv[1]) == std::string("float")) - ret_val = create_disk_layout(argv); - else if (std::string(argv[1]) == std::string("int8")) - ret_val = create_disk_layout(argv); - else if (std::string(argv[1]) == std::string("uint8")) - ret_val = create_disk_layout(argv); - else - { - std::cout << "unsupported type. use int8/uint8/float " << std::endl; - ret_val = -2; - } - return ret_val; + int ret_val = -1; + if (std::string(argv[1]) == std::string("float")) + ret_val = create_disk_layout(argv); + else if (std::string(argv[1]) == std::string("int8")) + ret_val = create_disk_layout(argv); + else if (std::string(argv[1]) == std::string("uint8")) + ret_val = create_disk_layout(argv); + else { + std::cout << "unsupported type. use int8/uint8/float " << std::endl; + ret_val = -2; + } + return ret_val; } diff --git a/apps/utils/float_bin_to_int8.cpp b/apps/utils/float_bin_to_int8.cpp index 1982005af..d3776b641 100644 --- a/apps/utils/float_bin_to_int8.cpp +++ b/apps/utils/float_bin_to_int8.cpp @@ -1,63 +1,62 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
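// float_bin_to_int8: reads a float .bin file (two 32-bit counts, npts and
// ndims, followed by row-major float data) and quantizes each value x to
// (int8_t)((x - bias) * (254.0 / scale)), writing the output block by block
// (131072 rows at a time) under the same two-count header.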
-#include #include "utils.h" +#include -void block_convert(std::ofstream &writer, int8_t *write_buf, std::ifstream &reader, float *read_buf, size_t npts, - size_t ndims, float bias, float scale) -{ - reader.read((char *)read_buf, npts * ndims * sizeof(float)); - - for (size_t i = 0; i < npts; i++) - { - for (size_t d = 0; d < ndims; d++) - { - write_buf[d + i * ndims] = (int8_t)((read_buf[d + i * ndims] - bias) * (254.0 / scale)); - } - } - writer.write((char *)write_buf, npts * ndims); -} - -int main(int argc, char **argv) -{ - if (argc != 5) - { - std::cout << "Usage: " << argv[0] << " input_bin output_tsv bias scale" << std::endl; - exit(-1); - } +void block_convert(std::ofstream &writer, int8_t *write_buf, + std::ifstream &reader, float *read_buf, size_t npts, + size_t ndims, float bias, float scale) { + reader.read((char *)read_buf, npts * ndims * sizeof(float)); - std::ifstream reader(argv[1], std::ios::binary); - uint32_t npts_u32; - uint32_t ndims_u32; - reader.read((char *)&npts_u32, sizeof(uint32_t)); - reader.read((char *)&ndims_u32, sizeof(uint32_t)); - size_t npts = npts_u32; - size_t ndims = ndims_u32; - std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl; - - size_t blk_size = 131072; - size_t nblks = ROUND_UP(npts, blk_size) / blk_size; - - std::ofstream writer(argv[2], std::ios::binary); - auto read_buf = new float[blk_size * ndims]; - auto write_buf = new int8_t[blk_size * ndims]; - float bias = (float)atof(argv[3]); - float scale = (float)atof(argv[4]); - - writer.write((char *)(&npts_u32), sizeof(uint32_t)); - writer.write((char *)(&ndims_u32), sizeof(uint32_t)); - - for (size_t i = 0; i < nblks; i++) - { - size_t cblk_size = std::min(npts - i * blk_size, blk_size); - block_convert(writer, write_buf, reader, read_buf, cblk_size, ndims, bias, scale); - std::cout << "Block #" << i << " written" << std::endl; + for (size_t i = 0; i < npts; i++) { + for (size_t d = 0; d < ndims; d++) { + write_buf[d + i * ndims] = + (int8_t)((read_buf[d + i * ndims] - bias) * (254.0 / scale)); } + } + writer.write((char *)write_buf, npts * ndims); +} - delete[] read_buf; - delete[] write_buf; - - writer.close(); - reader.close(); +int main(int argc, char **argv) { + if (argc != 5) { + std::cout << "Usage: " << argv[0] << " input_bin output_tsv bias scale" + << std::endl; + exit(-1); + } + + std::ifstream reader(argv[1], std::ios::binary); + uint32_t npts_u32; + uint32_t ndims_u32; + reader.read((char *)&npts_u32, sizeof(uint32_t)); + reader.read((char *)&ndims_u32, sizeof(uint32_t)); + size_t npts = npts_u32; + size_t ndims = ndims_u32; + std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims + << std::endl; + + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; + + std::ofstream writer(argv[2], std::ios::binary); + auto read_buf = new float[blk_size * ndims]; + auto write_buf = new int8_t[blk_size * ndims]; + float bias = (float)atof(argv[3]); + float scale = (float)atof(argv[4]); + + writer.write((char *)(&npts_u32), sizeof(uint32_t)); + writer.write((char *)(&ndims_u32), sizeof(uint32_t)); + + for (size_t i = 0; i < nblks; i++) { + size_t cblk_size = std::min(npts - i * blk_size, blk_size); + block_convert(writer, write_buf, reader, read_buf, cblk_size, ndims, bias, + scale); + std::cout << "Block #" << i << " written" << std::endl; + } + + delete[] read_buf; + delete[] write_buf; + + writer.close(); + reader.close(); } diff --git a/apps/utils/fvecs_to_bin.cpp b/apps/utils/fvecs_to_bin.cpp index 873ad3b0c..02dbacf54 
100644 --- a/apps/utils/fvecs_to_bin.cpp +++ b/apps/utils/fvecs_to_bin.cpp @@ -1,95 +1,92 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "utils.h" +#include // Convert float types -void block_convert_float(std::ifstream &reader, std::ofstream &writer, float *read_buf, float *write_buf, size_t npts, - size_t ndims) -{ - reader.read((char *)read_buf, npts * (ndims * sizeof(float) + sizeof(uint32_t))); - for (size_t i = 0; i < npts; i++) - { - memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, ndims * sizeof(float)); - } - writer.write((char *)write_buf, npts * ndims * sizeof(float)); +void block_convert_float(std::ifstream &reader, std::ofstream &writer, + float *read_buf, float *write_buf, size_t npts, + size_t ndims) { + reader.read((char *)read_buf, + npts * (ndims * sizeof(float) + sizeof(uint32_t))); + for (size_t i = 0; i < npts; i++) { + memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, + ndims * sizeof(float)); + } + writer.write((char *)write_buf, npts * ndims * sizeof(float)); } // Convert byte types -void block_convert_byte(std::ifstream &reader, std::ofstream &writer, uint8_t *read_buf, uint8_t *write_buf, - size_t npts, size_t ndims) -{ - reader.read((char *)read_buf, npts * (ndims * sizeof(uint8_t) + sizeof(uint32_t))); - for (size_t i = 0; i < npts; i++) - { - memcpy(write_buf + i * ndims, (read_buf + i * (ndims + sizeof(uint32_t))) + sizeof(uint32_t), - ndims * sizeof(uint8_t)); - } - writer.write((char *)write_buf, npts * ndims * sizeof(uint8_t)); +void block_convert_byte(std::ifstream &reader, std::ofstream &writer, + uint8_t *read_buf, uint8_t *write_buf, size_t npts, + size_t ndims) { + reader.read((char *)read_buf, + npts * (ndims * sizeof(uint8_t) + sizeof(uint32_t))); + for (size_t i = 0; i < npts; i++) { + memcpy(write_buf + i * ndims, + (read_buf + i * (ndims + sizeof(uint32_t))) + sizeof(uint32_t), + ndims * sizeof(uint8_t)); + } + writer.write((char *)write_buf, npts * ndims * sizeof(uint8_t)); } -int main(int argc, char **argv) -{ - if (argc != 4) - { - std::cout << argv[0] << " input_vecs output_bin" << std::endl; - exit(-1); - } +int main(int argc, char **argv) { + if (argc != 4) { + std::cout << argv[0] << " input_vecs output_bin" + << std::endl; + exit(-1); + } - int datasize = sizeof(float); + int datasize = sizeof(float); - if (strcmp(argv[1], "uint8") == 0 || strcmp(argv[1], "int8") == 0) - { - datasize = sizeof(uint8_t); - } - else if (strcmp(argv[1], "float") != 0) - { - std::cout << "Error: type not supported. Use float/int8/uint8" << std::endl; - exit(-1); - } + if (strcmp(argv[1], "uint8") == 0 || strcmp(argv[1], "int8") == 0) { + datasize = sizeof(uint8_t); + } else if (strcmp(argv[1], "float") != 0) { + std::cout << "Error: type not supported. 
Use float/int8/uint8" << std::endl; + exit(-1); + } - std::ifstream reader(argv[2], std::ios::binary | std::ios::ate); - size_t fsize = reader.tellg(); - reader.seekg(0, std::ios::beg); + std::ifstream reader(argv[2], std::ios::binary | std::ios::ate); + size_t fsize = reader.tellg(); + reader.seekg(0, std::ios::beg); - uint32_t ndims_u32; - reader.read((char *)&ndims_u32, sizeof(uint32_t)); - reader.seekg(0, std::ios::beg); - size_t ndims = (size_t)ndims_u32; - size_t npts = fsize / ((ndims * datasize) + sizeof(uint32_t)); - std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl; + uint32_t ndims_u32; + reader.read((char *)&ndims_u32, sizeof(uint32_t)); + reader.seekg(0, std::ios::beg); + size_t ndims = (size_t)ndims_u32; + size_t npts = fsize / ((ndims * datasize) + sizeof(uint32_t)); + std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims + << std::endl; - size_t blk_size = 131072; - size_t nblks = ROUND_UP(npts, blk_size) / blk_size; - std::cout << "# blks: " << nblks << std::endl; - std::ofstream writer(argv[3], std::ios::binary); - int32_t npts_s32 = (int32_t)npts; - int32_t ndims_s32 = (int32_t)ndims; - writer.write((char *)&npts_s32, sizeof(int32_t)); - writer.write((char *)&ndims_s32, sizeof(int32_t)); + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; + std::cout << "# blks: " << nblks << std::endl; + std::ofstream writer(argv[3], std::ios::binary); + int32_t npts_s32 = (int32_t)npts; + int32_t ndims_s32 = (int32_t)ndims; + writer.write((char *)&npts_s32, sizeof(int32_t)); + writer.write((char *)&ndims_s32, sizeof(int32_t)); - size_t chunknpts = std::min(npts, blk_size); - uint8_t *read_buf = new uint8_t[chunknpts * ((ndims * datasize) + sizeof(uint32_t))]; - uint8_t *write_buf = new uint8_t[chunknpts * ndims * datasize]; + size_t chunknpts = std::min(npts, blk_size); + uint8_t *read_buf = + new uint8_t[chunknpts * ((ndims * datasize) + sizeof(uint32_t))]; + uint8_t *write_buf = new uint8_t[chunknpts * ndims * datasize]; - for (size_t i = 0; i < nblks; i++) - { - size_t cblk_size = std::min(npts - i * blk_size, blk_size); - if (datasize == sizeof(float)) - { - block_convert_float(reader, writer, (float *)read_buf, (float *)write_buf, cblk_size, ndims); - } - else - { - block_convert_byte(reader, writer, read_buf, write_buf, cblk_size, ndims); - } - std::cout << "Block #" << i << " written" << std::endl; + for (size_t i = 0; i < nblks; i++) { + size_t cblk_size = std::min(npts - i * blk_size, blk_size); + if (datasize == sizeof(float)) { + block_convert_float(reader, writer, (float *)read_buf, (float *)write_buf, + cblk_size, ndims); + } else { + block_convert_byte(reader, writer, read_buf, write_buf, cblk_size, ndims); } + std::cout << "Block #" << i << " written" << std::endl; + } - delete[] read_buf; - delete[] write_buf; + delete[] read_buf; + delete[] write_buf; - reader.close(); - writer.close(); + reader.close(); + writer.close(); } diff --git a/apps/utils/fvecs_to_bvecs.cpp b/apps/utils/fvecs_to_bvecs.cpp index f9c2aa71b..a5cc09449 100644 --- a/apps/utils/fvecs_to_bvecs.cpp +++ b/apps/utils/fvecs_to_bvecs.cpp @@ -1,56 +1,56 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
-#include #include "utils.h" +#include -void block_convert(std::ifstream &reader, std::ofstream &writer, float *read_buf, uint8_t *write_buf, size_t npts, - size_t ndims) -{ - reader.read((char *)read_buf, npts * (ndims * sizeof(float) + sizeof(uint32_t))); - for (size_t i = 0; i < npts; i++) - { - memcpy(write_buf + i * (ndims + 4), read_buf + i * (ndims + 1), sizeof(uint32_t)); - for (size_t d = 0; d < ndims; d++) - write_buf[i * (ndims + 4) + 4 + d] = (uint8_t)read_buf[i * (ndims + 1) + 1 + d]; - } - writer.write((char *)write_buf, npts * (ndims * 1 + 4)); +void block_convert(std::ifstream &reader, std::ofstream &writer, + float *read_buf, uint8_t *write_buf, size_t npts, + size_t ndims) { + reader.read((char *)read_buf, + npts * (ndims * sizeof(float) + sizeof(uint32_t))); + for (size_t i = 0; i < npts; i++) { + memcpy(write_buf + i * (ndims + 4), read_buf + i * (ndims + 1), + sizeof(uint32_t)); + for (size_t d = 0; d < ndims; d++) + write_buf[i * (ndims + 4) + 4 + d] = + (uint8_t)read_buf[i * (ndims + 1) + 1 + d]; + } + writer.write((char *)write_buf, npts * (ndims * 1 + 4)); } -int main(int argc, char **argv) -{ - if (argc != 3) - { - std::cout << argv[0] << " input_fvecs output_bvecs(uint8)" << std::endl; - exit(-1); - } - std::ifstream reader(argv[1], std::ios::binary | std::ios::ate); - size_t fsize = reader.tellg(); - reader.seekg(0, std::ios::beg); +int main(int argc, char **argv) { + if (argc != 3) { + std::cout << argv[0] << " input_fvecs output_bvecs(uint8)" << std::endl; + exit(-1); + } + std::ifstream reader(argv[1], std::ios::binary | std::ios::ate); + size_t fsize = reader.tellg(); + reader.seekg(0, std::ios::beg); - uint32_t ndims_u32; - reader.read((char *)&ndims_u32, sizeof(uint32_t)); - reader.seekg(0, std::ios::beg); - size_t ndims = (size_t)ndims_u32; - size_t npts = fsize / ((ndims + 1) * sizeof(float)); - std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl; + uint32_t ndims_u32; + reader.read((char *)&ndims_u32, sizeof(uint32_t)); + reader.seekg(0, std::ios::beg); + size_t ndims = (size_t)ndims_u32; + size_t npts = fsize / ((ndims + 1) * sizeof(float)); + std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims + << std::endl; - size_t blk_size = 131072; - size_t nblks = ROUND_UP(npts, blk_size) / blk_size; - std::cout << "# blks: " << nblks << std::endl; - std::ofstream writer(argv[2], std::ios::binary); - auto read_buf = new float[npts * (ndims + 1)]; - auto write_buf = new uint8_t[npts * (ndims + 4)]; - for (size_t i = 0; i < nblks; i++) - { - size_t cblk_size = std::min(npts - i * blk_size, blk_size); - block_convert(reader, writer, read_buf, write_buf, cblk_size, ndims); - std::cout << "Block #" << i << " written" << std::endl; - } + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; + std::cout << "# blks: " << nblks << std::endl; + std::ofstream writer(argv[2], std::ios::binary); + auto read_buf = new float[npts * (ndims + 1)]; + auto write_buf = new uint8_t[npts * (ndims + 4)]; + for (size_t i = 0; i < nblks; i++) { + size_t cblk_size = std::min(npts - i * blk_size, blk_size); + block_convert(reader, writer, read_buf, write_buf, cblk_size, ndims); + std::cout << "Block #" << i << " written" << std::endl; + } - delete[] read_buf; - delete[] write_buf; + delete[] read_buf; + delete[] write_buf; - reader.close(); - writer.close(); + reader.close(); + writer.close(); } diff --git a/apps/utils/gen_random_slice.cpp b/apps/utils/gen_random_slice.cpp index a4cd96e0a..29307937d 100644 --- 
a/apps/utils/gen_random_slice.cpp +++ b/apps/utils/gen_random_slice.cpp @@ -1,7 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include +#include "partition.h" +#include "utils.h" #include #include #include @@ -10,49 +11,39 @@ #include #include #include +#include #include #include -#include "partition.h" -#include "utils.h" #include #include #include #include -template <typename T> int aux_main(char **argv) -{ - std::string base_file(argv[2]); - std::string output_prefix(argv[3]); - float sampling_rate = (float)(std::atof(argv[4])); - gen_random_slice<T>(base_file, output_prefix, sampling_rate); - return 0; +template <typename T> int aux_main(char **argv) { + std::string base_file(argv[2]); + std::string output_prefix(argv[3]); + float sampling_rate = (float)(std::atof(argv[4])); + gen_random_slice<T>(base_file, output_prefix, sampling_rate); + return 0; } -int main(int argc, char **argv) -{ - if (argc != 5) - { - std::cout << argv[0] - << " data_type [float/int8/uint8] base_bin_file " - "sample_output_prefix sampling_probability" - << std::endl; - exit(-1); - } +int main(int argc, char **argv) { + if (argc != 5) { + std::cout << argv[0] + << " data_type [float/int8/uint8] base_bin_file " + "sample_output_prefix sampling_probability" + << std::endl; + exit(-1); + } - if (std::string(argv[1]) == std::string("float")) - { - aux_main<float>(argv); - } - else if (std::string(argv[1]) == std::string("int8")) - { - aux_main<int8_t>(argv); - } - else if (std::string(argv[1]) == std::string("uint8")) - { - aux_main<uint8_t>(argv); - } - else - std::cout << "Unsupported type. Use float/int8/uint8." << std::endl; - return 0; + if (std::string(argv[1]) == std::string("float")) { + aux_main<float>(argv); + } else if (std::string(argv[1]) == std::string("int8")) { + aux_main<int8_t>(argv); + } else if (std::string(argv[1]) == std::string("uint8")) { + aux_main<uint8_t>(argv); + } else + std::cout << "Unsupported type. Use float/int8/uint8." << std::endl; + return 0; } diff --git a/apps/utils/generate_pq.cpp b/apps/utils/generate_pq.cpp index a881b1104..390bf4355 100644 --- a/apps/utils/generate_pq.cpp +++ b/apps/utils/generate_pq.cpp @@ -2,69 +2,72 @@ // Licensed under the MIT license.
#include "math_utils.h" -#include "pq.h" #include "partition.h" +#include "pq.h" #define KMEANS_ITERS_FOR_PQ 15 template -bool generate_pq(const std::string &data_path, const std::string &index_prefix_path, const size_t num_pq_centers, - const size_t num_pq_chunks, const float sampling_rate, const bool opq) -{ - std::string pq_pivots_path = index_prefix_path + "_pq_pivots.bin"; - std::string pq_compressed_vectors_path = index_prefix_path + "_pq_compressed.bin"; +bool generate_pq(const std::string &data_path, + const std::string &index_prefix_path, + const size_t num_pq_centers, const size_t num_pq_chunks, + const float sampling_rate, const bool opq) { + std::string pq_pivots_path = index_prefix_path + "_pq_pivots.bin"; + std::string pq_compressed_vectors_path = + index_prefix_path + "_pq_compressed.bin"; - // generates random sample and sets it to train_data and updates train_size - size_t train_size, train_dim; - float *train_data; - gen_random_slice(data_path, sampling_rate, train_data, train_size, train_dim); - std::cout << "For computing pivots, loaded sample data of size " << train_size << std::endl; + // generates random sample and sets it to train_data and updates train_size + size_t train_size, train_dim; + float *train_data; + gen_random_slice(data_path, sampling_rate, train_data, train_size, + train_dim); + std::cout << "For computing pivots, loaded sample data of size " << train_size + << std::endl; - if (opq) - { - diskann::generate_opq_pivots(train_data, train_size, (uint32_t)train_dim, (uint32_t)num_pq_centers, - (uint32_t)num_pq_chunks, pq_pivots_path, true); - } - else - { - diskann::generate_pq_pivots(train_data, train_size, (uint32_t)train_dim, (uint32_t)num_pq_centers, - (uint32_t)num_pq_chunks, KMEANS_ITERS_FOR_PQ, pq_pivots_path); - } - diskann::generate_pq_data_from_pivots(data_path, (uint32_t)num_pq_centers, (uint32_t)num_pq_chunks, - pq_pivots_path, pq_compressed_vectors_path, true); + if (opq) { + diskann::generate_opq_pivots(train_data, train_size, (uint32_t)train_dim, + (uint32_t)num_pq_centers, + (uint32_t)num_pq_chunks, pq_pivots_path, true); + } else { + diskann::generate_pq_pivots( + train_data, train_size, (uint32_t)train_dim, (uint32_t)num_pq_centers, + (uint32_t)num_pq_chunks, KMEANS_ITERS_FOR_PQ, pq_pivots_path); + } + diskann::generate_pq_data_from_pivots( + data_path, (uint32_t)num_pq_centers, (uint32_t)num_pq_chunks, + pq_pivots_path, pq_compressed_vectors_path, true); - delete[] train_data; + delete[] train_data; - return 0; + return 0; } -int main(int argc, char **argv) -{ - if (argc != 7) - { - std::cout << "Usage: \n" - << argv[0] - << " " - " " - " " - << std::endl; - } - else - { - const std::string data_path(argv[2]); - const std::string index_prefix_path(argv[3]); - const size_t num_pq_centers = 256; - const size_t num_pq_chunks = (size_t)atoi(argv[4]); - const float sampling_rate = (float)atof(argv[5]); - const bool opq = atoi(argv[6]) == 0 ? false : true; +int main(int argc, char **argv) { + if (argc != 7) { + std::cout << "Usage: \n" + << argv[0] + << " " + " " + " " + << std::endl; + } else { + const std::string data_path(argv[2]); + const std::string index_prefix_path(argv[3]); + const size_t num_pq_centers = 256; + const size_t num_pq_chunks = (size_t)atoi(argv[4]); + const float sampling_rate = (float)atof(argv[5]); + const bool opq = atoi(argv[6]) == 0 ? 
false : true; - if (std::string(argv[1]) == std::string("float")) - generate_pq(data_path, index_prefix_path, num_pq_centers, num_pq_chunks, sampling_rate, opq); - else if (std::string(argv[1]) == std::string("int8")) - generate_pq(data_path, index_prefix_path, num_pq_centers, num_pq_chunks, sampling_rate, opq); - else if (std::string(argv[1]) == std::string("uint8")) - generate_pq(data_path, index_prefix_path, num_pq_centers, num_pq_chunks, sampling_rate, opq); - else - std::cout << "Error. wrong file type" << std::endl; - } + if (std::string(argv[1]) == std::string("float")) + generate_pq(data_path, index_prefix_path, num_pq_centers, + num_pq_chunks, sampling_rate, opq); + else if (std::string(argv[1]) == std::string("int8")) + generate_pq(data_path, index_prefix_path, num_pq_centers, + num_pq_chunks, sampling_rate, opq); + else if (std::string(argv[1]) == std::string("uint8")) + generate_pq(data_path, index_prefix_path, num_pq_centers, + num_pq_chunks, sampling_rate, opq); + else + std::cout << "Error. wrong file type" << std::endl; + } } diff --git a/apps/utils/generate_synthetic_labels.cpp b/apps/utils/generate_synthetic_labels.cpp index 6741760cb..4b11df0a4 100644 --- a/apps/utils/generate_synthetic_labels.cpp +++ b/apps/utils/generate_synthetic_labels.cpp @@ -1,204 +1,176 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include -#include +#include "utils.h" #include -#include #include -#include "utils.h" +#include +#include +#include namespace po = boost::program_options; -class ZipfDistribution -{ - public: - ZipfDistribution(uint64_t num_points, uint32_t num_labels) - : num_labels(num_labels), num_points(num_points), - uniform_zero_to_one(std::uniform_real_distribution<>(0.0, 1.0)) - { +class ZipfDistribution { +public: + ZipfDistribution(uint64_t num_points, uint32_t num_labels) + : num_labels(num_labels), num_points(num_points), + uniform_zero_to_one(std::uniform_real_distribution<>(0.0, 1.0)) {} + + std::unordered_map createDistributionMap() { + std::unordered_map map; + uint32_t primary_label_freq = + (uint32_t)ceil(num_points * distribution_factor); + for (uint32_t i{1}; i < num_labels + 1; i++) { + map[i] = (uint32_t)ceil(primary_label_freq / i); } - - std::unordered_map createDistributionMap() - { - std::unordered_map map; - uint32_t primary_label_freq = (uint32_t)ceil(num_points * distribution_factor); - for (uint32_t i{1}; i < num_labels + 1; i++) - { - map[i] = (uint32_t)ceil(primary_label_freq / i); + return map; + } + + int writeDistribution(std::ofstream &outfile) { + auto distribution_map = createDistributionMap(); + for (uint32_t i{0}; i < num_points; i++) { + bool label_written = false; + for (auto it = distribution_map.cbegin(); it != distribution_map.cend(); + it++) { + auto label_selection_probability = std::bernoulli_distribution( + distribution_factor / (double)it->first); + if (label_selection_probability(rand_engine) && + distribution_map[it->first] > 0) { + if (label_written) { + outfile << ','; + } + outfile << it->first; + label_written = true; + // remove label from map if we have used all labels + distribution_map[it->first] -= 1; } - return map; + } + if (!label_written) { + outfile << 0; + } + if (i < num_points - 1) { + outfile << '\n'; + } } + return 0; + } - int writeDistribution(std::ofstream &outfile) - { - auto distribution_map = createDistributionMap(); - for (uint32_t i{0}; i < num_points; i++) - { - bool label_written = false; - for (auto it = distribution_map.cbegin(); it != 
distribution_map.cend(); it++) - { - auto label_selection_probability = std::bernoulli_distribution(distribution_factor / (double)it->first); - if (label_selection_probability(rand_engine) && distribution_map[it->first] > 0) - { - if (label_written) - { - outfile << ','; - } - outfile << it->first; - label_written = true; - // remove label from map if we have used all labels - distribution_map[it->first] -= 1; - } - } - if (!label_written) - { - outfile << 0; - } - if (i < num_points - 1) - { - outfile << '\n'; - } - } - return 0; + int writeDistribution(std::string filename) { + std::ofstream outfile(filename); + if (!outfile.is_open()) { + std::cerr << "Error: could not open output file " << filename << '\n'; + return -1; } - - int writeDistribution(std::string filename) - { - std::ofstream outfile(filename); - if (!outfile.is_open()) - { - std::cerr << "Error: could not open output file " << filename << '\n'; - return -1; - } - writeDistribution(outfile); - outfile.close(); - } - - private: - const uint32_t num_labels; - const uint64_t num_points; - const double distribution_factor = 0.7; - std::knuth_b rand_engine; - const std::uniform_real_distribution uniform_zero_to_one; + writeDistribution(outfile); + outfile.close(); + } + +private: + const uint32_t num_labels; + const uint64_t num_points; + const double distribution_factor = 0.7; + std::knuth_b rand_engine; + const std::uniform_real_distribution uniform_zero_to_one; }; -int main(int argc, char **argv) -{ - std::string output_file, distribution_type; - uint32_t num_labels; - uint64_t num_points; - - try - { - po::options_description desc{"Arguments"}; - - desc.add_options()("help,h", "Print information on arguments"); - desc.add_options()("output_file,O", po::value(&output_file)->required(), - "Filename for saving the label file"); - desc.add_options()("num_points,N", po::value(&num_points)->required(), "Number of points in dataset"); - desc.add_options()("num_labels,L", po::value(&num_labels)->required(), - "Number of unique labels, up to 5000"); - desc.add_options()("distribution_type,DT", po::value(&distribution_type)->default_value("random"), - "Distribution function for labels defaults " - "to random"); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) - { - std::cout << desc; - return 0; - } - po::notify(vm); - } - catch (const std::exception &ex) - { - std::cerr << ex.what() << '\n'; - return -1; - } - - if (num_labels > 5000) - { - std::cerr << "Error: num_labels must be 5000 or less" << '\n'; - return -1; +int main(int argc, char **argv) { + std::string output_file, distribution_type; + uint32_t num_labels; + uint64_t num_points; + + try { + po::options_description desc{"Arguments"}; + + desc.add_options()("help,h", "Print information on arguments"); + desc.add_options()("output_file,O", + po::value(&output_file)->required(), + "Filename for saving the label file"); + desc.add_options()("num_points,N", + po::value(&num_points)->required(), + "Number of points in dataset"); + desc.add_options()("num_labels,L", + po::value(&num_labels)->required(), + "Number of unique labels, up to 5000"); + desc.add_options()( + "distribution_type,DT", + po::value(&distribution_type)->default_value("random"), + "Distribution function for labels defaults " + "to random"); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) { + std::cout << desc; + return 0; } - - if (num_points <= 0) - { - std::cerr << "Error: num_points must 
be greater than 0" << '\n'; - return -1; + po::notify(vm); + } catch (const std::exception &ex) { + std::cerr << ex.what() << '\n'; + return -1; + } + + if (num_labels > 5000) { + std::cerr << "Error: num_labels must be 5000 or less" << '\n'; + return -1; + } + + if (num_points <= 0) { + std::cerr << "Error: num_points must be greater than 0" << '\n'; + return -1; + } + + std::cout << "Generating synthetic labels for " << num_points + << " points with " << num_labels << " unique labels" << '\n'; + + try { + std::ofstream outfile(output_file); + if (!outfile.is_open()) { + std::cerr << "Error: could not open output file " << output_file << '\n'; + return -1; } - std::cout << "Generating synthetic labels for " << num_points << " points with " << num_labels << " unique labels" - << '\n'; - - try - { - std::ofstream outfile(output_file); - if (!outfile.is_open()) - { - std::cerr << "Error: could not open output file " << output_file << '\n'; - return -1; - } - - if (distribution_type == "zipf") - { - ZipfDistribution zipf(num_points, num_labels); - zipf.writeDistribution(outfile); - } - else if (distribution_type == "random") - { - for (size_t i = 0; i < num_points; i++) - { - bool label_written = false; - for (size_t j = 1; j <= num_labels; j++) - { - // 50% chance to assign each label - if (rand() > (RAND_MAX / 2)) - { - if (label_written) - { - outfile << ','; - } - outfile << j; - label_written = true; - } - } - if (!label_written) - { - outfile << 0; - } - if (i < num_points - 1) - { - outfile << '\n'; - } + if (distribution_type == "zipf") { + ZipfDistribution zipf(num_points, num_labels); + zipf.writeDistribution(outfile); + } else if (distribution_type == "random") { + for (size_t i = 0; i < num_points; i++) { + bool label_written = false; + for (size_t j = 1; j <= num_labels; j++) { + // 50% chance to assign each label + if (rand() > (RAND_MAX / 2)) { + if (label_written) { + outfile << ','; } + outfile << j; + label_written = true; + } } - else if (distribution_type == "one_per_point") - { - std::random_device rd; // obtain a random number from hardware - std::mt19937 gen(rd()); // seed the generator - std::uniform_int_distribution<> distr(0, num_labels); // define the range - - for (size_t i = 0; i < num_points; i++) - { - outfile << distr(gen); - if (i != num_points - 1) - outfile << '\n'; - } + if (!label_written) { + outfile << 0; } - if (outfile.is_open()) - { - outfile.close(); + if (i < num_points - 1) { + outfile << '\n'; } - - std::cout << "Labels written to " << output_file << '\n'; + } + } else if (distribution_type == "one_per_point") { + std::random_device rd; // obtain a random number from hardware + std::mt19937 gen(rd()); // seed the generator + std::uniform_int_distribution<> distr(0, num_labels); // define the range + + for (size_t i = 0; i < num_points; i++) { + outfile << distr(gen); + if (i != num_points - 1) + outfile << '\n'; + } } - catch (const std::exception &ex) - { - std::cerr << "Label generation failed: " << ex.what() << '\n'; - return -1; + if (outfile.is_open()) { + outfile.close(); } - return 0; + std::cout << "Labels written to " << output_file << '\n'; + } catch (const std::exception &ex) { + std::cerr << "Label generation failed: " << ex.what() << '\n'; + return -1; + } + + return 0; } \ No newline at end of file diff --git a/apps/utils/int8_to_float.cpp b/apps/utils/int8_to_float.cpp index dcdfddc0d..d8b1e6f5a 100644 --- a/apps/utils/int8_to_float.cpp +++ b/apps/utils/int8_to_float.cpp @@ -1,23 +1,21 @@ // Copyright (c) Microsoft Corporation. 
All rights reserved. // Licensed under the MIT license. -#include #include "utils.h" +#include -int main(int argc, char **argv) -{ - if (argc != 3) - { - std::cout << argv[0] << " input_int8_bin output_float_bin" << std::endl; - exit(-1); - } +int main(int argc, char **argv) { + if (argc != 3) { + std::cout << argv[0] << " input_int8_bin output_float_bin" << std::endl; + exit(-1); + } - int8_t *input; - size_t npts, nd; - diskann::load_bin(argv[1], input, npts, nd); - float *output = new float[npts * nd]; - diskann::convert_types(input, output, npts, nd); - diskann::save_bin(argv[2], output, npts, nd); - delete[] output; - delete[] input; + int8_t *input; + size_t npts, nd; + diskann::load_bin(argv[1], input, npts, nd); + float *output = new float[npts * nd]; + diskann::convert_types(input, output, npts, nd); + diskann::save_bin(argv[2], output, npts, nd); + delete[] output; + delete[] input; } diff --git a/apps/utils/int8_to_float_scale.cpp b/apps/utils/int8_to_float_scale.cpp index 19fbc6c43..ff1b62aa6 100644 --- a/apps/utils/int8_to_float_scale.cpp +++ b/apps/utils/int8_to_float_scale.cpp @@ -1,63 +1,62 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "utils.h" +#include -void block_convert(std::ofstream &writer, float *write_buf, std::ifstream &reader, int8_t *read_buf, size_t npts, - size_t ndims, float bias, float scale) -{ - reader.read((char *)read_buf, npts * ndims * sizeof(int8_t)); - - for (size_t i = 0; i < npts; i++) - { - for (size_t d = 0; d < ndims; d++) - { - write_buf[d + i * ndims] = (((float)read_buf[d + i * ndims] - bias) * scale); - } - } - writer.write((char *)write_buf, npts * ndims * sizeof(float)); -} - -int main(int argc, char **argv) -{ - if (argc != 5) - { - std::cout << "Usage: " << argv[0] << " input-int8.bin output-float.bin bias scale" << std::endl; - exit(-1); - } +void block_convert(std::ofstream &writer, float *write_buf, + std::ifstream &reader, int8_t *read_buf, size_t npts, + size_t ndims, float bias, float scale) { + reader.read((char *)read_buf, npts * ndims * sizeof(int8_t)); - std::ifstream reader(argv[1], std::ios::binary); - uint32_t npts_u32; - uint32_t ndims_u32; - reader.read((char *)&npts_u32, sizeof(uint32_t)); - reader.read((char *)&ndims_u32, sizeof(uint32_t)); - size_t npts = npts_u32; - size_t ndims = ndims_u32; - std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl; - - size_t blk_size = 131072; - size_t nblks = ROUND_UP(npts, blk_size) / blk_size; - - std::ofstream writer(argv[2], std::ios::binary); - auto read_buf = new int8_t[blk_size * ndims]; - auto write_buf = new float[blk_size * ndims]; - float bias = (float)atof(argv[3]); - float scale = (float)atof(argv[4]); - - writer.write((char *)(&npts_u32), sizeof(uint32_t)); - writer.write((char *)(&ndims_u32), sizeof(uint32_t)); - - for (size_t i = 0; i < nblks; i++) - { - size_t cblk_size = std::min(npts - i * blk_size, blk_size); - block_convert(writer, write_buf, reader, read_buf, cblk_size, ndims, bias, scale); - std::cout << "Block #" << i << " written" << std::endl; + for (size_t i = 0; i < npts; i++) { + for (size_t d = 0; d < ndims; d++) { + write_buf[d + i * ndims] = + (((float)read_buf[d + i * ndims] - bias) * scale); } + } + writer.write((char *)write_buf, npts * ndims * sizeof(float)); +} - delete[] read_buf; - delete[] write_buf; - - writer.close(); - reader.close(); +int main(int argc, char **argv) { + if (argc != 5) { + std::cout << "Usage: " << argv[0] + << " input-int8.bin 
output-float.bin bias scale" << std::endl; + exit(-1); + } + + std::ifstream reader(argv[1], std::ios::binary); + uint32_t npts_u32; + uint32_t ndims_u32; + reader.read((char *)&npts_u32, sizeof(uint32_t)); + reader.read((char *)&ndims_u32, sizeof(uint32_t)); + size_t npts = npts_u32; + size_t ndims = ndims_u32; + std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims + << std::endl; + + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; + + std::ofstream writer(argv[2], std::ios::binary); + auto read_buf = new int8_t[blk_size * ndims]; + auto write_buf = new float[blk_size * ndims]; + float bias = (float)atof(argv[3]); + float scale = (float)atof(argv[4]); + + writer.write((char *)(&npts_u32), sizeof(uint32_t)); + writer.write((char *)(&ndims_u32), sizeof(uint32_t)); + + for (size_t i = 0; i < nblks; i++) { + size_t cblk_size = std::min(npts - i * blk_size, blk_size); + block_convert(writer, write_buf, reader, read_buf, cblk_size, ndims, bias, + scale); + std::cout << "Block #" << i << " written" << std::endl; + } + + delete[] read_buf; + delete[] write_buf; + + writer.close(); + reader.close(); } diff --git a/apps/utils/ivecs_to_bin.cpp b/apps/utils/ivecs_to_bin.cpp index ea8a4a3d2..f439d1d73 100644 --- a/apps/utils/ivecs_to_bin.cpp +++ b/apps/utils/ivecs_to_bin.cpp @@ -1,58 +1,57 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "utils.h" +#include -void block_convert(std::ifstream &reader, std::ofstream &writer, uint32_t *read_buf, uint32_t *write_buf, size_t npts, - size_t ndims) -{ - reader.read((char *)read_buf, npts * (ndims * sizeof(uint32_t) + sizeof(uint32_t))); - for (size_t i = 0; i < npts; i++) - { - memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, ndims * sizeof(uint32_t)); - } - writer.write((char *)write_buf, npts * ndims * sizeof(uint32_t)); +void block_convert(std::ifstream &reader, std::ofstream &writer, + uint32_t *read_buf, uint32_t *write_buf, size_t npts, + size_t ndims) { + reader.read((char *)read_buf, + npts * (ndims * sizeof(uint32_t) + sizeof(uint32_t))); + for (size_t i = 0; i < npts; i++) { + memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, + ndims * sizeof(uint32_t)); + } + writer.write((char *)write_buf, npts * ndims * sizeof(uint32_t)); } -int main(int argc, char **argv) -{ - if (argc != 3) - { - std::cout << argv[0] << " input_ivecs output_bin" << std::endl; - exit(-1); - } - std::ifstream reader(argv[1], std::ios::binary | std::ios::ate); - size_t fsize = reader.tellg(); - reader.seekg(0, std::ios::beg); +int main(int argc, char **argv) { + if (argc != 3) { + std::cout << argv[0] << " input_ivecs output_bin" << std::endl; + exit(-1); + } + std::ifstream reader(argv[1], std::ios::binary | std::ios::ate); + size_t fsize = reader.tellg(); + reader.seekg(0, std::ios::beg); - uint32_t ndims_u32; - reader.read((char *)&ndims_u32, sizeof(uint32_t)); - reader.seekg(0, std::ios::beg); - size_t ndims = (size_t)ndims_u32; - size_t npts = fsize / ((ndims + 1) * sizeof(uint32_t)); - std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl; + uint32_t ndims_u32; + reader.read((char *)&ndims_u32, sizeof(uint32_t)); + reader.seekg(0, std::ios::beg); + size_t ndims = (size_t)ndims_u32; + size_t npts = fsize / ((ndims + 1) * sizeof(uint32_t)); + std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims + << std::endl; - size_t blk_size = 131072; - size_t nblks = ROUND_UP(npts, blk_size) / blk_size; - 
std::cout << "# blks: " << nblks << std::endl; - std::ofstream writer(argv[2], std::ios::binary); - int npts_s32 = (int)npts; - int ndims_s32 = (int)ndims; - writer.write((char *)&npts_s32, sizeof(int)); - writer.write((char *)&ndims_s32, sizeof(int)); - uint32_t *read_buf = new uint32_t[npts * (ndims + 1)]; - uint32_t *write_buf = new uint32_t[npts * ndims]; - for (size_t i = 0; i < nblks; i++) - { - size_t cblk_size = std::min(npts - i * blk_size, blk_size); - block_convert(reader, writer, read_buf, write_buf, cblk_size, ndims); - std::cout << "Block #" << i << " written" << std::endl; - } + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; + std::cout << "# blks: " << nblks << std::endl; + std::ofstream writer(argv[2], std::ios::binary); + int npts_s32 = (int)npts; + int ndims_s32 = (int)ndims; + writer.write((char *)&npts_s32, sizeof(int)); + writer.write((char *)&ndims_s32, sizeof(int)); + uint32_t *read_buf = new uint32_t[npts * (ndims + 1)]; + uint32_t *write_buf = new uint32_t[npts * ndims]; + for (size_t i = 0; i < nblks; i++) { + size_t cblk_size = std::min(npts - i * blk_size, blk_size); + block_convert(reader, writer, read_buf, write_buf, cblk_size, ndims); + std::cout << "Block #" << i << " written" << std::endl; + } - delete[] read_buf; - delete[] write_buf; + delete[] read_buf; + delete[] write_buf; - reader.close(); - writer.close(); + reader.close(); + writer.close(); } diff --git a/apps/utils/merge_shards.cpp b/apps/utils/merge_shards.cpp index 106c15eef..cf5c6a00b 100644 --- a/apps/utils/merge_shards.cpp +++ b/apps/utils/merge_shards.cpp @@ -10,33 +10,32 @@ #include #include -#include "disk_utils.h" #include "cached_io.h" +#include "disk_utils.h" #include "utils.h" -int main(int argc, char **argv) -{ - if (argc != 9) - { - std::cout << argv[0] - << " vamana_index_prefix[1] vamana_index_suffix[2] " - "idmaps_prefix[3] " - "idmaps_suffix[4] n_shards[5] max_degree[6] " - "output_vamana_path[7] " - "output_medoids_path[8]" - << std::endl; - exit(-1); - } +int main(int argc, char **argv) { + if (argc != 9) { + std::cout << argv[0] + << " vamana_index_prefix[1] vamana_index_suffix[2] " + "idmaps_prefix[3] " + "idmaps_suffix[4] n_shards[5] max_degree[6] " + "output_vamana_path[7] " + "output_medoids_path[8]" + << std::endl; + exit(-1); + } - std::string vamana_prefix(argv[1]); - std::string vamana_suffix(argv[2]); - std::string idmaps_prefix(argv[3]); - std::string idmaps_suffix(argv[4]); - uint64_t nshards = (uint64_t)std::atoi(argv[5]); - uint32_t max_degree = (uint64_t)std::atoi(argv[6]); - std::string output_index(argv[7]); - std::string output_medoids(argv[8]); + std::string vamana_prefix(argv[1]); + std::string vamana_suffix(argv[2]); + std::string idmaps_prefix(argv[3]); + std::string idmaps_suffix(argv[4]); + uint64_t nshards = (uint64_t)std::atoi(argv[5]); + uint32_t max_degree = (uint64_t)std::atoi(argv[6]); + std::string output_index(argv[7]); + std::string output_medoids(argv[8]); - return diskann::merge_shards(vamana_prefix, vamana_suffix, idmaps_prefix, idmaps_suffix, nshards, max_degree, - output_index, output_medoids); + return diskann::merge_shards(vamana_prefix, vamana_suffix, idmaps_prefix, + idmaps_suffix, nshards, max_degree, output_index, + output_medoids); } diff --git a/apps/utils/partition_data.cpp b/apps/utils/partition_data.cpp index 2520f3f4a..72eb7af90 100644 --- a/apps/utils/partition_data.cpp +++ b/apps/utils/partition_data.cpp @@ -1,39 +1,40 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// Licensed under the MIT license. -#include -#include #include "cached_io.h" #include "partition.h" +#include +#include // DEPRECATED: NEED TO REPROGRAM -int main(int argc, char **argv) -{ - if (argc != 7) - { - std::cout << "Usage:\n" - << argv[0] - << " datatype " - " " - " " - << std::endl; - exit(-1); - } +int main(int argc, char **argv) { + if (argc != 7) { + std::cout << "Usage:\n" + << argv[0] + << " datatype " + " " + " " + << std::endl; + exit(-1); + } - const std::string data_path(argv[2]); - const std::string prefix_path(argv[3]); - const float sampling_rate = (float)atof(argv[4]); - const size_t num_partitions = (size_t)std::atoi(argv[5]); - const size_t max_reps = 15; - const size_t k_index = (size_t)std::atoi(argv[6]); + const std::string data_path(argv[2]); + const std::string prefix_path(argv[3]); + const float sampling_rate = (float)atof(argv[4]); + const size_t num_partitions = (size_t)std::atoi(argv[5]); + const size_t max_reps = 15; + const size_t k_index = (size_t)std::atoi(argv[6]); - if (std::string(argv[1]) == std::string("float")) - partition<float>(data_path, sampling_rate, num_partitions, max_reps, prefix_path, k_index); - else if (std::string(argv[1]) == std::string("int8")) - partition<int8_t>(data_path, sampling_rate, num_partitions, max_reps, prefix_path, k_index); - else if (std::string(argv[1]) == std::string("uint8")) - partition<uint8_t>(data_path, sampling_rate, num_partitions, max_reps, prefix_path, k_index); - else - std::cout << "unsupported data format. use float/int8/uint8" << std::endl; + if (std::string(argv[1]) == std::string("float")) + partition<float>(data_path, sampling_rate, num_partitions, max_reps, + prefix_path, k_index); + else if (std::string(argv[1]) == std::string("int8")) + partition<int8_t>(data_path, sampling_rate, num_partitions, max_reps, + prefix_path, k_index); + else if (std::string(argv[1]) == std::string("uint8")) + partition<uint8_t>(data_path, sampling_rate, num_partitions, max_reps, + prefix_path, k_index); + else + std::cout << "unsupported data format. use float/int8/uint8" << std::endl; } diff --git a/apps/utils/partition_with_ram_budget.cpp b/apps/utils/partition_with_ram_budget.cpp index 937b68d2c..9c5535def 100644 --- a/apps/utils/partition_with_ram_budget.cpp +++ b/apps/utils/partition_with_ram_budget.cpp @@ -1,39 +1,40 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license.
-#include -#include #include "cached_io.h" #include "partition.h" +#include +#include // DEPRECATED: NEED TO REPROGRAM -int main(int argc, char **argv) -{ - if (argc != 8) - { - std::cout << "Usage:\n" - << argv[0] - << " datatype " - " " - " " - << std::endl; - exit(-1); - } +int main(int argc, char **argv) { + if (argc != 8) { + std::cout << "Usage:\n" + << argv[0] + << " datatype " + " " + " " + << std::endl; + exit(-1); + } - const std::string data_path(argv[2]); - const std::string prefix_path(argv[3]); - const float sampling_rate = (float)atof(argv[4]); - const double ram_budget = (double)std::atof(argv[5]); - const size_t graph_degree = (size_t)std::atoi(argv[6]); - const size_t k_index = (size_t)std::atoi(argv[7]); + const std::string data_path(argv[2]); + const std::string prefix_path(argv[3]); + const float sampling_rate = (float)atof(argv[4]); + const double ram_budget = (double)std::atof(argv[5]); + const size_t graph_degree = (size_t)std::atoi(argv[6]); + const size_t k_index = (size_t)std::atoi(argv[7]); - if (std::string(argv[1]) == std::string("float")) - partition_with_ram_budget<float>(data_path, sampling_rate, ram_budget, graph_degree, prefix_path, k_index); - else if (std::string(argv[1]) == std::string("int8")) - partition_with_ram_budget<int8_t>(data_path, sampling_rate, ram_budget, graph_degree, prefix_path, k_index); - else if (std::string(argv[1]) == std::string("uint8")) - partition_with_ram_budget<uint8_t>(data_path, sampling_rate, ram_budget, graph_degree, prefix_path, k_index); - else - std::cout << "unsupported data format. use float/int8/uint8" << std::endl; + if (std::string(argv[1]) == std::string("float")) + partition_with_ram_budget<float>(data_path, sampling_rate, ram_budget, + graph_degree, prefix_path, k_index); + else if (std::string(argv[1]) == std::string("int8")) + partition_with_ram_budget<int8_t>(data_path, sampling_rate, ram_budget, + graph_degree, prefix_path, k_index); + else if (std::string(argv[1]) == std::string("uint8")) + partition_with_ram_budget<uint8_t>(data_path, sampling_rate, ram_budget, + graph_degree, prefix_path, k_index); + else + std::cout << "unsupported data format. use float/int8/uint8" << std::endl; } diff --git a/apps/utils/rand_data_gen.cpp b/apps/utils/rand_data_gen.cpp index e89ede800..25577d242 100644 --- a/apps/utils/rand_data_gen.cpp +++ b/apps/utils/rand_data_gen.cpp @@ -1,237 +1,224 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license.
-#include +#include +#include #include +#include #include -#include -#include #include "utils.h" namespace po = boost::program_options; -int block_write_float(std::ofstream &writer, size_t ndims, size_t npts, bool normalization, float norm, - float rand_scale) -{ - auto vec = new float[ndims]; - - std::random_device rd{}; - std::mt19937 gen{rd()}; - std::normal_distribution<> normal_rand{0, 1}; - std::uniform_real_distribution<> unif_dis(1.0, rand_scale); - - for (size_t i = 0; i < npts; i++) - { - float sum = 0; - float scale = 1.0f; - if (rand_scale > 1.0f) - scale = (float)unif_dis(gen); - for (size_t d = 0; d < ndims; ++d) - vec[d] = scale * (float)normal_rand(gen); - if (normalization) - { - for (size_t d = 0; d < ndims; ++d) - sum += vec[d] * vec[d]; - for (size_t d = 0; d < ndims; ++d) - vec[d] = vec[d] * norm / std::sqrt(sum); - } - - writer.write((char *)vec, ndims * sizeof(float)); +int block_write_float(std::ofstream &writer, size_t ndims, size_t npts, + bool normalization, float norm, float rand_scale) { + auto vec = new float[ndims]; + + std::random_device rd{}; + std::mt19937 gen{rd()}; + std::normal_distribution<> normal_rand{0, 1}; + std::uniform_real_distribution<> unif_dis(1.0, rand_scale); + + for (size_t i = 0; i < npts; i++) { + float sum = 0; + float scale = 1.0f; + if (rand_scale > 1.0f) + scale = (float)unif_dis(gen); + for (size_t d = 0; d < ndims; ++d) + vec[d] = scale * (float)normal_rand(gen); + if (normalization) { + for (size_t d = 0; d < ndims; ++d) + sum += vec[d] * vec[d]; + for (size_t d = 0; d < ndims; ++d) + vec[d] = vec[d] * norm / std::sqrt(sum); } - delete[] vec; - return 0; -} - -int block_write_int8(std::ofstream &writer, size_t ndims, size_t npts, float norm) -{ - auto vec = new float[ndims]; - auto vec_T = new int8_t[ndims]; - - std::random_device rd{}; - std::mt19937 gen{rd()}; - std::normal_distribution<> normal_rand{0, 1}; - - for (size_t i = 0; i < npts; i++) - { - float sum = 0; - for (size_t d = 0; d < ndims; ++d) - vec[d] = (float)normal_rand(gen); - for (size_t d = 0; d < ndims; ++d) - sum += vec[d] * vec[d]; - for (size_t d = 0; d < ndims; ++d) - vec[d] = vec[d] * norm / std::sqrt(sum); - - for (size_t d = 0; d < ndims; ++d) - { - vec_T[d] = (int8_t)std::round(vec[d]); - } - - writer.write((char *)vec_T, ndims * sizeof(int8_t)); - } + writer.write((char *)vec, ndims * sizeof(float)); + } - delete[] vec; - delete[] vec_T; - return 0; + delete[] vec; + return 0; } -int block_write_uint8(std::ofstream &writer, size_t ndims, size_t npts, float norm) -{ - auto vec = new float[ndims]; - auto vec_T = new int8_t[ndims]; - - std::random_device rd{}; - std::mt19937 gen{rd()}; - std::normal_distribution<> normal_rand{0, 1}; - - for (size_t i = 0; i < npts; i++) - { - float sum = 0; - for (size_t d = 0; d < ndims; ++d) - vec[d] = (float)normal_rand(gen); - for (size_t d = 0; d < ndims; ++d) - sum += vec[d] * vec[d]; - for (size_t d = 0; d < ndims; ++d) - vec[d] = vec[d] * norm / std::sqrt(sum); - - for (size_t d = 0; d < ndims; ++d) - { - vec_T[d] = 128 + (int8_t)std::round(vec[d]); - } - - writer.write((char *)vec_T, ndims * sizeof(uint8_t)); +int block_write_int8(std::ofstream &writer, size_t ndims, size_t npts, + float norm) { + auto vec = new float[ndims]; + auto vec_T = new int8_t[ndims]; + + std::random_device rd{}; + std::mt19937 gen{rd()}; + std::normal_distribution<> normal_rand{0, 1}; + + for (size_t i = 0; i < npts; i++) { + float sum = 0; + for (size_t d = 0; d < ndims; ++d) + vec[d] = (float)normal_rand(gen); + for (size_t d = 0; d < 
ndims; ++d) + sum += vec[d] * vec[d]; + for (size_t d = 0; d < ndims; ++d) + vec[d] = vec[d] * norm / std::sqrt(sum); + + for (size_t d = 0; d < ndims; ++d) { + vec_T[d] = (int8_t)std::round(vec[d]); } - delete[] vec; - delete[] vec_T; - return 0; + writer.write((char *)vec_T, ndims * sizeof(int8_t)); + } + + delete[] vec; + delete[] vec_T; + return 0; } -int main(int argc, char **argv) -{ - std::string data_type, output_file; - size_t ndims, npts; - float norm, rand_scaling; - bool normalization = false; - try - { - po::options_description desc{"Arguments"}; - - desc.add_options()("help,h", "Print information on arguments"); - - desc.add_options()("data_type", po::value(&data_type)->required(), "data type "); - desc.add_options()("output_file", po::value(&output_file)->required(), - "File name for saving the random vectors"); - desc.add_options()("ndims,D", po::value(&ndims)->required(), "Dimensoinality of the vector"); - desc.add_options()("npts,N", po::value(&npts)->required(), "Number of vectors"); - desc.add_options()("norm", po::value(&norm)->default_value(-1.0f), - "Norm of the vectors (if not specified, vectors are not normalized)"); - desc.add_options()("rand_scaling", po::value(&rand_scaling)->default_value(1.0f), - "Each vector will be scaled (if not explicitly normalized) by a factor randomly chosen from " - "[1, rand_scale]. Only applicable for floating point data"); - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) - { - std::cout << desc; - return 0; - } - po::notify(vm); - } - catch (const std::exception &ex) - { - std::cerr << ex.what() << '\n'; - return -1; +int block_write_uint8(std::ofstream &writer, size_t ndims, size_t npts, + float norm) { + auto vec = new float[ndims]; + auto vec_T = new int8_t[ndims]; + + std::random_device rd{}; + std::mt19937 gen{rd()}; + std::normal_distribution<> normal_rand{0, 1}; + + for (size_t i = 0; i < npts; i++) { + float sum = 0; + for (size_t d = 0; d < ndims; ++d) + vec[d] = (float)normal_rand(gen); + for (size_t d = 0; d < ndims; ++d) + sum += vec[d] * vec[d]; + for (size_t d = 0; d < ndims; ++d) + vec[d] = vec[d] * norm / std::sqrt(sum); + + for (size_t d = 0; d < ndims; ++d) { + vec_T[d] = 128 + (int8_t)std::round(vec[d]); } - if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8")) - { - std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl; - return -1; - } + writer.write((char *)vec_T, ndims * sizeof(uint8_t)); + } - if (norm > 0.0) - { - normalization = true; - } + delete[] vec; + delete[] vec_T; + return 0; +} - if (rand_scaling < 1.0) - { - std::cout << "We will only scale the vector norms randomly in [1, value], so value must be >= 1." 
<< std::endl; - return -1; +int main(int argc, char **argv) { + std::string data_type, output_file; + size_t ndims, npts; + float norm, rand_scaling; + bool normalization = false; + try { + po::options_description desc{"Arguments"}; + + desc.add_options()("help,h", "Print information on arguments"); + + desc.add_options()("data_type", + po::value(&data_type)->required(), + "data type "); + desc.add_options()("output_file", + po::value(&output_file)->required(), + "File name for saving the random vectors"); + desc.add_options()("ndims,D", po::value(&ndims)->required(), + "Dimensoinality of the vector"); + desc.add_options()("npts,N", po::value(&npts)->required(), + "Number of vectors"); + desc.add_options()( + "norm", po::value(&norm)->default_value(-1.0f), + "Norm of the vectors (if not specified, vectors are not normalized)"); + desc.add_options()( + "rand_scaling", po::value(&rand_scaling)->default_value(1.0f), + "Each vector will be scaled (if not explicitly normalized) by a factor " + "randomly chosen from " + "[1, rand_scale]. Only applicable for floating point data"); + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) { + std::cout << desc; + return 0; } - - if ((rand_scaling > 1.0) && (normalization == true)) - { - std::cout << "Data cannot be normalized and randomly scaled at same time. Use one or the other." << std::endl; - return -1; + po::notify(vm); + } catch (const std::exception &ex) { + std::cerr << ex.what() << '\n'; + return -1; + } + + if (data_type != std::string("float") && data_type != std::string("int8") && + data_type != std::string("uint8")) { + std::cout << "Unsupported type. float, int8 and uint8 types are supported." + << std::endl; + return -1; + } + + if (norm > 0.0) { + normalization = true; + } + + if (rand_scaling < 1.0) { + std::cout << "We will only scale the vector norms randomly in [1, value], " + "so value must be >= 1." + << std::endl; + return -1; + } + + if ((rand_scaling > 1.0) && (normalization == true)) { + std::cout << "Data cannot be normalized and randomly scaled at same time. " + "Use one or the other." + << std::endl; + return -1; + } + + if (data_type == std::string("int8") || data_type == std::string("uint8")) { + if (norm > 127) { + std::cerr << "Error: for int8/uint8 datatypes, L2 norm can not be " + "greater " + "than 127" + << std::endl; + return -1; } - - if (data_type == std::string("int8") || data_type == std::string("uint8")) - { - if (norm > 127) - { - std::cerr << "Error: for int8/uint8 datatypes, L2 norm can not be " - "greater " - "than 127" - << std::endl; - return -1; - } - if (rand_scaling > 1.0) - { - std::cout << "Data scaling only supported for floating point data." << std::endl; - return -1; - } + if (rand_scaling > 1.0) { + std::cout << "Data scaling only supported for floating point data." 
+ << std::endl; + return -1; } - - try - { - std::ofstream writer; - writer.exceptions(std::ofstream::failbit | std::ofstream::badbit); - writer.open(output_file, std::ios::binary); - auto npts_u32 = (uint32_t)npts; - auto ndims_u32 = (uint32_t)ndims; - writer.write((char *)&npts_u32, sizeof(uint32_t)); - writer.write((char *)&ndims_u32, sizeof(uint32_t)); - - size_t blk_size = 131072; - size_t nblks = ROUND_UP(npts, blk_size) / blk_size; - std::cout << "# blks: " << nblks << std::endl; - - int ret = 0; - for (size_t i = 0; i < nblks; i++) - { - size_t cblk_size = std::min(npts - i * blk_size, blk_size); - if (data_type == std::string("float")) - { - ret = block_write_float(writer, ndims, cblk_size, normalization, norm, rand_scaling); - } - else if (data_type == std::string("int8")) - { - ret = block_write_int8(writer, ndims, cblk_size, norm); - } - else if (data_type == std::string("uint8")) - { - ret = block_write_uint8(writer, ndims, cblk_size, norm); - } - if (ret == 0) - std::cout << "Block #" << i << " written" << std::endl; - else - { - writer.close(); - std::cout << "failed to write" << std::endl; - return -1; - } - } + } + + try { + std::ofstream writer; + writer.exceptions(std::ofstream::failbit | std::ofstream::badbit); + writer.open(output_file, std::ios::binary); + auto npts_u32 = (uint32_t)npts; + auto ndims_u32 = (uint32_t)ndims; + writer.write((char *)&npts_u32, sizeof(uint32_t)); + writer.write((char *)&ndims_u32, sizeof(uint32_t)); + + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; + std::cout << "# blks: " << nblks << std::endl; + + int ret = 0; + for (size_t i = 0; i < nblks; i++) { + size_t cblk_size = std::min(npts - i * blk_size, blk_size); + if (data_type == std::string("float")) { + ret = block_write_float(writer, ndims, cblk_size, normalization, norm, + rand_scaling); + } else if (data_type == std::string("int8")) { + ret = block_write_int8(writer, ndims, cblk_size, norm); + } else if (data_type == std::string("uint8")) { + ret = block_write_uint8(writer, ndims, cblk_size, norm); + } + if (ret == 0) + std::cout << "Block #" << i << " written" << std::endl; + else { writer.close(); - } - catch (const std::exception &e) - { - std::cout << std::string(e.what()) << std::endl; - diskann::cerr << "Index build failed." << std::endl; + std::cout << "failed to write" << std::endl; return -1; + } } - - return 0; + writer.close(); + } catch (const std::exception &e) { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Index build failed." << std::endl; + return -1; + } + + return 0; } diff --git a/apps/utils/simulate_aggregate_recall.cpp b/apps/utils/simulate_aggregate_recall.cpp index 73c4ea0f7..b934c2bea 100644 --- a/apps/utils/simulate_aggregate_recall.cpp +++ b/apps/utils/simulate_aggregate_recall.cpp @@ -1,85 +1,78 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
-#include +#include #include +#include #include -#include -inline float aggregate_recall(const uint32_t k_aggr, const uint32_t k, const uint32_t npart, uint32_t *count, - const std::vector &recalls) -{ - float found = 0; - for (uint32_t i = 0; i < npart; ++i) - { - size_t max_found = std::min(count[i], k); - found += recalls[max_found - 1] * max_found; - } - return found / (float)k_aggr; +inline float aggregate_recall(const uint32_t k_aggr, const uint32_t k, + const uint32_t npart, uint32_t *count, + const std::vector &recalls) { + float found = 0; + for (uint32_t i = 0; i < npart; ++i) { + size_t max_found = std::min(count[i], k); + found += recalls[max_found - 1] * max_found; + } + return found / (float)k_aggr; } -void simulate(const uint32_t k_aggr, const uint32_t k, const uint32_t npart, const uint32_t nsim, - const std::vector &recalls) -{ - std::random_device r; - std::default_random_engine randeng(r()); - std::uniform_int_distribution uniform_dist(0, npart - 1); +void simulate(const uint32_t k_aggr, const uint32_t k, const uint32_t npart, + const uint32_t nsim, const std::vector &recalls) { + std::random_device r; + std::default_random_engine randeng(r()); + std::uniform_int_distribution uniform_dist(0, npart - 1); - uint32_t *count = new uint32_t[npart]; - double aggr_recall = 0; + uint32_t *count = new uint32_t[npart]; + double aggr_recall = 0; - for (uint32_t i = 0; i < nsim; ++i) - { - for (uint32_t p = 0; p < npart; ++p) - { - count[p] = 0; - } - for (uint32_t t = 0; t < k_aggr; ++t) - { - count[uniform_dist(randeng)]++; - } - aggr_recall += aggregate_recall(k_aggr, k, npart, count, recalls); + for (uint32_t i = 0; i < nsim; ++i) { + for (uint32_t p = 0; p < npart; ++p) { + count[p] = 0; + } + for (uint32_t t = 0; t < k_aggr; ++t) { + count[uniform_dist(randeng)]++; } + aggr_recall += aggregate_recall(k_aggr, k, npart, count, recalls); + } - std::cout << "Aggregate recall is " << aggr_recall / (double)nsim << std::endl; - delete[] count; + std::cout << "Aggregate recall is " << aggr_recall / (double)nsim + << std::endl; + delete[] count; } -int main(int argc, char **argv) -{ - if (argc < 6) - { - std::cout << argv[0] << " k_aggregate k_out npart nsim recall@1 recall@2 ... recall@k" << std::endl; - exit(-1); - } +int main(int argc, char **argv) { + if (argc < 6) { + std::cout << argv[0] + << " k_aggregate k_out npart nsim recall@1 recall@2 ... recall@k" + << std::endl; + exit(-1); + } - const uint32_t k_aggr = atoi(argv[1]); - const uint32_t k = atoi(argv[2]); - const uint32_t npart = atoi(argv[3]); - const uint32_t nsim = atoi(argv[4]); + const uint32_t k_aggr = atoi(argv[1]); + const uint32_t k = atoi(argv[2]); + const uint32_t npart = atoi(argv[3]); + const uint32_t nsim = atoi(argv[4]); - std::vector recalls; - for (int ctr = 5; ctr < argc; ctr++) - { - recalls.push_back((float)atof(argv[ctr])); - } + std::vector recalls; + for (int ctr = 5; ctr < argc; ctr++) { + recalls.push_back((float)atof(argv[ctr])); + } - if (recalls.size() != k) - { - std::cerr << "Please input k numbers for recall@1, recall@2 .. recall@k" << std::endl; - } - if (k_aggr > npart * k) - { - std::cerr << "k_aggr must be <= k * npart" << std::endl; - exit(-1); - } - if (nsim <= npart * k_aggr) - { - std::cerr << "Choose nsim > npart*k_aggr" << std::endl; - exit(-1); - } + if (recalls.size() != k) { + std::cerr << "Please input k numbers for recall@1, recall@2 .. 
recall@k" + << std::endl; + } + if (k_aggr > npart * k) { + std::cerr << "k_aggr must be <= k * npart" << std::endl; + exit(-1); + } + if (nsim <= npart * k_aggr) { + std::cerr << "Choose nsim > npart*k_aggr" << std::endl; + exit(-1); + } - simulate(k_aggr, k, npart, nsim, recalls); + simulate(k_aggr, k, npart, nsim, recalls); - return 0; + return 0; } diff --git a/apps/utils/stats_label_data.cpp b/apps/utils/stats_label_data.cpp index 3342672ff..4de42f7a0 100644 --- a/apps/utils/stats_label_data.cpp +++ b/apps/utils/stats_label_data.cpp @@ -1,147 +1,150 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include -#include -#include -#include -#include -#include -#include -#include -#include #include +#include +#include #include +#include #include +#include +#include #include -#include +#include +#include +#include +#include +#include #include "utils.h" #ifndef _WINDOWS #include -#include #include #include +#include #else #include #endif namespace po = boost::program_options; -void stats_analysis(const std::string labels_file, std::string univeral_label, uint32_t density = 10) -{ - std::string token, line; - std::ifstream labels_stream(labels_file); - std::unordered_map label_counts; - std::string label_with_max_points; - uint32_t max_points = 0; - long long sum = 0; - long long point_cnt = 0; - float avg_labels_per_pt, mean_label_size; +void stats_analysis(const std::string labels_file, std::string univeral_label, + uint32_t density = 10) { + std::string token, line; + std::ifstream labels_stream(labels_file); + std::unordered_map label_counts; + std::string label_with_max_points; + uint32_t max_points = 0; + long long sum = 0; + long long point_cnt = 0; + float avg_labels_per_pt, mean_label_size; - std::vector labels_per_point; - uint32_t dense_pts = 0; - if (labels_stream.is_open()) - { - while (getline(labels_stream, line)) - { - point_cnt++; - std::stringstream iss(line); - uint32_t lbl_cnt = 0; - while (getline(iss, token, ',')) - { - lbl_cnt++; - token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - if (label_counts.find(token) == label_counts.end()) - label_counts[token] = 0; - label_counts[token]++; - } - if (lbl_cnt >= density) - { - dense_pts++; - } - labels_per_point.emplace_back(lbl_cnt); - } + std::vector labels_per_point; + uint32_t dense_pts = 0; + if (labels_stream.is_open()) { + while (getline(labels_stream, line)) { + point_cnt++; + std::stringstream iss(line); + uint32_t lbl_cnt = 0; + while (getline(iss, token, ',')) { + lbl_cnt++; + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + if (label_counts.find(token) == label_counts.end()) + label_counts[token] = 0; + label_counts[token]++; + } + if (lbl_cnt >= density) { + dense_pts++; + } + labels_per_point.emplace_back(lbl_cnt); } + } - std::cout << "fraction of dense points with >= " << density - << " labels = " << (float)dense_pts / (float)labels_per_point.size() << std::endl; - std::sort(labels_per_point.begin(), labels_per_point.end()); + std::cout << "fraction of dense points with >= " << density + << " labels = " << (float)dense_pts / (float)labels_per_point.size() + << std::endl; + std::sort(labels_per_point.begin(), labels_per_point.end()); - std::vector> label_count_vec; + std::vector> label_count_vec; - for (auto it = label_counts.begin(); it != label_counts.end(); it++) 
- { - auto &lbl = *it; - label_count_vec.emplace_back(std::make_pair(lbl.first, lbl.second)); - if (lbl.second > max_points) - { - max_points = lbl.second; - label_with_max_points = lbl.first; - } - sum += lbl.second; + for (auto it = label_counts.begin(); it != label_counts.end(); it++) { + auto &lbl = *it; + label_count_vec.emplace_back(std::make_pair(lbl.first, lbl.second)); + if (lbl.second > max_points) { + max_points = lbl.second; + label_with_max_points = lbl.first; } + sum += lbl.second; + } - sort(label_count_vec.begin(), label_count_vec.end(), - [](const std::pair &lhs, const std::pair &rhs) { - return lhs.second < rhs.second; - }); + sort(label_count_vec.begin(), label_count_vec.end(), + [](const std::pair &lhs, + const std::pair &rhs) { + return lhs.second < rhs.second; + }); - for (float p = 0; p < 1; p += 0.05) - { - std::cout << "Percentile " << (100 * p) << "\t" << label_count_vec[(size_t)(p * label_count_vec.size())].first - << " with count=" << label_count_vec[(size_t)(p * label_count_vec.size())].second << std::endl; - } + for (float p = 0; p < 1; p += 0.05) { + std::cout << "Percentile " << (100 * p) << "\t" + << label_count_vec[(size_t)(p * label_count_vec.size())].first + << " with count=" + << label_count_vec[(size_t)(p * label_count_vec.size())].second + << std::endl; + } - std::cout << "Most common label " - << "\t" << label_count_vec[label_count_vec.size() - 1].first - << " with count=" << label_count_vec[label_count_vec.size() - 1].second << std::endl; - if (label_count_vec.size() > 1) - std::cout << "Second common label " - << "\t" << label_count_vec[label_count_vec.size() - 2].first - << " with count=" << label_count_vec[label_count_vec.size() - 2].second << std::endl; - if (label_count_vec.size() > 2) - std::cout << "Third common label " - << "\t" << label_count_vec[label_count_vec.size() - 3].first - << " with count=" << label_count_vec[label_count_vec.size() - 3].second << std::endl; - avg_labels_per_pt = sum / (float)point_cnt; - mean_label_size = sum / (float)label_counts.size(); - std::cout << "Total number of points = " << point_cnt << ", number of labels = " << label_counts.size() + std::cout << "Most common label " + << "\t" << label_count_vec[label_count_vec.size() - 1].first + << " with count=" + << label_count_vec[label_count_vec.size() - 1].second << std::endl; + if (label_count_vec.size() > 1) + std::cout << "Second common label " + << "\t" << label_count_vec[label_count_vec.size() - 2].first + << " with count=" + << label_count_vec[label_count_vec.size() - 2].second + << std::endl; + if (label_count_vec.size() > 2) + std::cout << "Third common label " + << "\t" << label_count_vec[label_count_vec.size() - 3].first + << " with count=" + << label_count_vec[label_count_vec.size() - 3].second << std::endl; - std::cout << "Average number of labels per point = " << avg_labels_per_pt << std::endl; - std::cout << "Mean label size excluding 0 = " << mean_label_size << std::endl; - std::cout << "Most popular label is " << label_with_max_points << " with " << max_points << " pts" << std::endl; + avg_labels_per_pt = sum / (float)point_cnt; + mean_label_size = sum / (float)label_counts.size(); + std::cout << "Total number of points = " << point_cnt + << ", number of labels = " << label_counts.size() << std::endl; + std::cout << "Average number of labels per point = " << avg_labels_per_pt + << std::endl; + std::cout << "Mean label size excluding 0 = " << mean_label_size << std::endl; + std::cout << "Most popular label is " << label_with_max_points << " with " + 
<< max_points << " pts" << std::endl; } -int main(int argc, char **argv) -{ - std::string labels_file, universal_label; - uint32_t density; +int main(int argc, char **argv) { + std::string labels_file, universal_label; + uint32_t density; - po::options_description desc{"Arguments"}; - try - { - desc.add_options()("help,h", "Print information on arguments"); - desc.add_options()("labels_file", po::value(&labels_file)->required(), - "path to labels data file."); - desc.add_options()("universal_label", po::value(&universal_label)->required(), - "Universal label used in labels file."); - desc.add_options()("density", po::value(&density)->default_value(1), - "Number of labels each point in labels file, defaults to 1"); - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) - { - std::cout << desc; - return 0; - } - po::notify(vm); - } - catch (const std::exception &e) - { - std::cerr << e.what() << '\n'; - return -1; + po::options_description desc{"Arguments"}; + try { + desc.add_options()("help,h", "Print information on arguments"); + desc.add_options()("labels_file", + po::value(&labels_file)->required(), + "path to labels data file."); + desc.add_options()("universal_label", + po::value(&universal_label)->required(), + "Universal label used in labels file."); + desc.add_options()( + "density", po::value(&density)->default_value(1), + "Number of labels each point in labels file, defaults to 1"); + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) { + std::cout << desc; + return 0; } - stats_analysis(labels_file, universal_label, density); + po::notify(vm); + } catch (const std::exception &e) { + std::cerr << e.what() << '\n'; + return -1; + } + stats_analysis(labels_file, universal_label, density); } diff --git a/apps/utils/tsv_to_bin.cpp b/apps/utils/tsv_to_bin.cpp index c590a8f73..2cd00ae38 100644 --- a/apps/utils/tsv_to_bin.cpp +++ b/apps/utils/tsv_to_bin.cpp @@ -1,121 +1,108 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
-#include #include "utils.h" +#include -void block_convert_float(std::ifstream &reader, std::ofstream &writer, size_t npts, size_t ndims) -{ - auto read_buf = new float[npts * (ndims + 1)]; - - auto cursor = read_buf; - float val; - - for (size_t i = 0; i < npts; i++) - { - for (size_t d = 0; d < ndims; ++d) - { - reader >> val; - *cursor = val; - cursor++; - } - } - writer.write((char *)read_buf, npts * ndims * sizeof(float)); - delete[] read_buf; -} +void block_convert_float(std::ifstream &reader, std::ofstream &writer, + size_t npts, size_t ndims) { + auto read_buf = new float[npts * (ndims + 1)]; + + auto cursor = read_buf; + float val; -void block_convert_int8(std::ifstream &reader, std::ofstream &writer, size_t npts, size_t ndims) -{ - auto read_buf = new int8_t[npts * (ndims + 1)]; - - auto cursor = read_buf; - int val; - - for (size_t i = 0; i < npts; i++) - { - for (size_t d = 0; d < ndims; ++d) - { - reader >> val; - *cursor = (int8_t)val; - cursor++; - } + for (size_t i = 0; i < npts; i++) { + for (size_t d = 0; d < ndims; ++d) { + reader >> val; + *cursor = val; + cursor++; } - writer.write((char *)read_buf, npts * ndims * sizeof(uint8_t)); - delete[] read_buf; + } + writer.write((char *)read_buf, npts * ndims * sizeof(float)); + delete[] read_buf; } -void block_convert_uint8(std::ifstream &reader, std::ofstream &writer, size_t npts, size_t ndims) -{ - auto read_buf = new uint8_t[npts * (ndims + 1)]; - - auto cursor = read_buf; - int val; - - for (size_t i = 0; i < npts; i++) - { - for (size_t d = 0; d < ndims; ++d) - { - reader >> val; - *cursor = (uint8_t)val; - cursor++; - } +void block_convert_int8(std::ifstream &reader, std::ofstream &writer, + size_t npts, size_t ndims) { + auto read_buf = new int8_t[npts * (ndims + 1)]; + + auto cursor = read_buf; + int val; + + for (size_t i = 0; i < npts; i++) { + for (size_t d = 0; d < ndims; ++d) { + reader >> val; + *cursor = (int8_t)val; + cursor++; } - writer.write((char *)read_buf, npts * ndims * sizeof(uint8_t)); - delete[] read_buf; + } + writer.write((char *)read_buf, npts * ndims * sizeof(uint8_t)); + delete[] read_buf; } -int main(int argc, char **argv) -{ - if (argc != 6) - { - std::cout << argv[0] - << " input_filename.tsv output_filename.bin " - "dim num_pts>" - << std::endl; - exit(-1); - } +void block_convert_uint8(std::ifstream &reader, std::ofstream &writer, + size_t npts, size_t ndims) { + auto read_buf = new uint8_t[npts * (ndims + 1)]; - if (std::string(argv[1]) != std::string("float") && std::string(argv[1]) != std::string("int8") && - std::string(argv[1]) != std::string("uint8")) - { - std::cout << "Unsupported type. float, int8 and uint8 types are supported." 
<< std::endl; + auto cursor = read_buf; + int val; + + for (size_t i = 0; i < npts; i++) { + for (size_t d = 0; d < ndims; ++d) { + reader >> val; + *cursor = (uint8_t)val; + cursor++; } + } + writer.write((char *)read_buf, npts * ndims * sizeof(uint8_t)); + delete[] read_buf; +} - size_t ndims = atoi(argv[4]); - size_t npts = atoi(argv[5]); - - std::ifstream reader(argv[2], std::ios::binary | std::ios::ate); - // size_t fsize = reader.tellg(); - reader.seekg(0, std::ios::beg); - reader.seekg(0, std::ios::beg); - - size_t blk_size = 131072; - size_t nblks = ROUND_UP(npts, blk_size) / blk_size; - std::cout << "# blks: " << nblks << std::endl; - std::ofstream writer(argv[3], std::ios::binary); - auto npts_u32 = (uint32_t)npts; - auto ndims_u32 = (uint32_t)ndims; - writer.write((char *)&npts_u32, sizeof(uint32_t)); - writer.write((char *)&ndims_u32, sizeof(uint32_t)); - - for (size_t i = 0; i < nblks; i++) - { - size_t cblk_size = std::min(npts - i * blk_size, blk_size); - if (std::string(argv[1]) == std::string("float")) - { - block_convert_float(reader, writer, cblk_size, ndims); - } - else if (std::string(argv[1]) == std::string("int8")) - { - block_convert_int8(reader, writer, cblk_size, ndims); - } - else if (std::string(argv[1]) == std::string("uint8")) - { - block_convert_uint8(reader, writer, cblk_size, ndims); - } - std::cout << "Block #" << i << " written" << std::endl; +int main(int argc, char **argv) { + if (argc != 6) { + std::cout << argv[0] + << " input_filename.tsv output_filename.bin " + "dim num_pts>" + << std::endl; + exit(-1); + } + + if (std::string(argv[1]) != std::string("float") && + std::string(argv[1]) != std::string("int8") && + std::string(argv[1]) != std::string("uint8")) { + std::cout << "Unsupported type. float, int8 and uint8 types are supported." + << std::endl; + } + + size_t ndims = atoi(argv[4]); + size_t npts = atoi(argv[5]); + + std::ifstream reader(argv[2], std::ios::binary | std::ios::ate); + // size_t fsize = reader.tellg(); + reader.seekg(0, std::ios::beg); + reader.seekg(0, std::ios::beg); + + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; + std::cout << "# blks: " << nblks << std::endl; + std::ofstream writer(argv[3], std::ios::binary); + auto npts_u32 = (uint32_t)npts; + auto ndims_u32 = (uint32_t)ndims; + writer.write((char *)&npts_u32, sizeof(uint32_t)); + writer.write((char *)&ndims_u32, sizeof(uint32_t)); + + for (size_t i = 0; i < nblks; i++) { + size_t cblk_size = std::min(npts - i * blk_size, blk_size); + if (std::string(argv[1]) == std::string("float")) { + block_convert_float(reader, writer, cblk_size, ndims); + } else if (std::string(argv[1]) == std::string("int8")) { + block_convert_int8(reader, writer, cblk_size, ndims); + } else if (std::string(argv[1]) == std::string("uint8")) { + block_convert_uint8(reader, writer, cblk_size, ndims); } + std::cout << "Block #" << i << " written" << std::endl; + } - reader.close(); - writer.close(); + reader.close(); + writer.close(); } diff --git a/apps/utils/uint32_to_uint8.cpp b/apps/utils/uint32_to_uint8.cpp index 87b6fb8ed..3868780e6 100644 --- a/apps/utils/uint32_to_uint8.cpp +++ b/apps/utils/uint32_to_uint8.cpp @@ -1,23 +1,21 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
-#include <iostream>
 #include "utils.h"
+#include <iostream>
 
-int main(int argc, char **argv)
-{
-    if (argc != 3)
-    {
-        std::cout << argv[0] << " input_uint32_bin output_int8_bin" << std::endl;
-        exit(-1);
-    }
+int main(int argc, char **argv) {
+  if (argc != 3) {
+    std::cout << argv[0] << " input_uint32_bin output_int8_bin" << std::endl;
+    exit(-1);
+  }
 
-    uint32_t *input;
-    size_t npts, nd;
-    diskann::load_bin<uint32_t>(argv[1], input, npts, nd);
-    uint8_t *output = new uint8_t[npts * nd];
-    diskann::convert_types<uint32_t, uint8_t>(input, output, npts, nd);
-    diskann::save_bin<uint8_t>(argv[2], output, npts, nd);
-    delete[] output;
-    delete[] input;
+  uint32_t *input;
+  size_t npts, nd;
+  diskann::load_bin<uint32_t>(argv[1], input, npts, nd);
+  uint8_t *output = new uint8_t[npts * nd];
+  diskann::convert_types<uint32_t, uint8_t>(input, output, npts, nd);
+  diskann::save_bin<uint8_t>(argv[2], output, npts, nd);
+  delete[] output;
+  delete[] input;
 }
diff --git a/apps/utils/uint8_to_float.cpp b/apps/utils/uint8_to_float.cpp
index 6415b7c92..779226f90 100644
--- a/apps/utils/uint8_to_float.cpp
+++ b/apps/utils/uint8_to_float.cpp
@@ -1,23 +1,21 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT license.
 
-#include <iostream>
 #include "utils.h"
+#include <iostream>
 
-int main(int argc, char **argv)
-{
-    if (argc != 3)
-    {
-        std::cout << argv[0] << " input_uint8_bin output_float_bin" << std::endl;
-        exit(-1);
-    }
+int main(int argc, char **argv) {
+  if (argc != 3) {
+    std::cout << argv[0] << " input_uint8_bin output_float_bin" << std::endl;
+    exit(-1);
+  }
 
-    uint8_t *input;
-    size_t npts, nd;
-    diskann::load_bin<uint8_t>(argv[1], input, npts, nd);
-    float *output = new float[npts * nd];
-    diskann::convert_types<uint8_t, float>(input, output, npts, nd);
-    diskann::save_bin<float>(argv[2], output, npts, nd);
-    delete[] output;
-    delete[] input;
+  uint8_t *input;
+  size_t npts, nd;
+  diskann::load_bin<uint8_t>(argv[1], input, npts, nd);
+  float *output = new float[npts * nd];
+  diskann::convert_types<uint8_t, float>(input, output, npts, nd);
+  diskann::save_bin<float>(argv[2], output, npts, nd);
+  delete[] output;
+  delete[] input;
 }
diff --git a/apps/utils/vector_analysis.cpp b/apps/utils/vector_analysis.cpp
index 009df6d05..9dde684f1 100644
--- a/apps/utils/vector_analysis.cpp
+++ b/apps/utils/vector_analysis.cpp
@@ -1,18 +1,18 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT license.
-#include #include #include #include #include #include +#include #include #include #include +#include #include #include -#include #include #include #include @@ -20,144 +20,129 @@ #include "partition.h" #include "utils.h" -template int analyze_norm(std::string base_file) -{ - std::cout << "Analyzing data norms" << std::endl; - T *data; - size_t npts, ndims; - diskann::load_bin(base_file, data, npts, ndims); - std::vector norms(npts, 0); +template int analyze_norm(std::string base_file) { + std::cout << "Analyzing data norms" << std::endl; + T *data; + size_t npts, ndims; + diskann::load_bin(base_file, data, npts, ndims); + std::vector norms(npts, 0); #pragma omp parallel for schedule(dynamic) - for (int64_t i = 0; i < (int64_t)npts; i++) - { - for (size_t d = 0; d < ndims; d++) - norms[i] += data[i * ndims + d] * data[i * ndims + d]; - norms[i] = std::sqrt(norms[i]); - } - std::sort(norms.begin(), norms.end()); - for (int p = 0; p < 100; p += 5) - std::cout << "percentile " << p << ": " << norms[(uint64_t)(std::floor((p / 100.0) * npts))] << std::endl; - std::cout << "percentile 100" - << ": " << norms[npts - 1] << std::endl; - delete[] data; - return 0; + for (int64_t i = 0; i < (int64_t)npts; i++) { + for (size_t d = 0; d < ndims; d++) + norms[i] += data[i * ndims + d] * data[i * ndims + d]; + norms[i] = std::sqrt(norms[i]); + } + std::sort(norms.begin(), norms.end()); + for (int p = 0; p < 100; p += 5) + std::cout << "percentile " << p << ": " + << norms[(uint64_t)(std::floor((p / 100.0) * npts))] << std::endl; + std::cout << "percentile 100" + << ": " << norms[npts - 1] << std::endl; + delete[] data; + return 0; } -template int normalize_base(std::string base_file, std::string out_file) -{ - std::cout << "Normalizing base" << std::endl; - T *data; - size_t npts, ndims; - diskann::load_bin(base_file, data, npts, ndims); - // std::vector norms(npts, 0); +template +int normalize_base(std::string base_file, std::string out_file) { + std::cout << "Normalizing base" << std::endl; + T *data; + size_t npts, ndims; + diskann::load_bin(base_file, data, npts, ndims); + // std::vector norms(npts, 0); #pragma omp parallel for schedule(dynamic) - for (int64_t i = 0; i < (int64_t)npts; i++) - { - float pt_norm = 0; - for (size_t d = 0; d < ndims; d++) - pt_norm += data[i * ndims + d] * data[i * ndims + d]; - pt_norm = std::sqrt(pt_norm); - for (size_t d = 0; d < ndims; d++) - data[i * ndims + d] = static_cast(data[i * ndims + d] / pt_norm); - } - diskann::save_bin(out_file, data, npts, ndims); - delete[] data; - return 0; + for (int64_t i = 0; i < (int64_t)npts; i++) { + float pt_norm = 0; + for (size_t d = 0; d < ndims; d++) + pt_norm += data[i * ndims + d] * data[i * ndims + d]; + pt_norm = std::sqrt(pt_norm); + for (size_t d = 0; d < ndims; d++) + data[i * ndims + d] = static_cast(data[i * ndims + d] / pt_norm); + } + diskann::save_bin(out_file, data, npts, ndims); + delete[] data; + return 0; } -template int augment_base(std::string base_file, std::string out_file, bool prep_base = true) -{ - std::cout << "Analyzing data norms" << std::endl; - T *data; - size_t npts, ndims; - diskann::load_bin(base_file, data, npts, ndims); - std::vector norms(npts, 0); - float max_norm = 0; +template +int augment_base(std::string base_file, std::string out_file, + bool prep_base = true) { + std::cout << "Analyzing data norms" << std::endl; + T *data; + size_t npts, ndims; + diskann::load_bin(base_file, data, npts, ndims); + std::vector norms(npts, 0); + float max_norm = 0; #pragma omp parallel for 
schedule(dynamic) - for (int64_t i = 0; i < (int64_t)npts; i++) - { - for (size_t d = 0; d < ndims; d++) - norms[i] += data[i * ndims + d] * data[i * ndims + d]; - max_norm = norms[i] > max_norm ? norms[i] : max_norm; + for (int64_t i = 0; i < (int64_t)npts; i++) { + for (size_t d = 0; d < ndims; d++) + norms[i] += data[i * ndims + d] * data[i * ndims + d]; + max_norm = norms[i] > max_norm ? norms[i] : max_norm; + } + // std::sort(norms.begin(), norms.end()); + max_norm = std::sqrt(max_norm); + std::cout << "Max norm: " << max_norm << std::endl; + T *new_data; + size_t newdims = ndims + 1; + new_data = new T[npts * newdims]; + for (size_t i = 0; i < npts; i++) { + if (prep_base) { + for (size_t j = 0; j < ndims; j++) { + new_data[i * newdims + j] = + static_cast(data[i * ndims + j] / max_norm); + } + float diff = 1 - (norms[i] / (max_norm * max_norm)); + diff = diff <= 0 ? 0 : std::sqrt(diff); + new_data[i * newdims + ndims] = static_cast(diff); + if (diff <= 0) { + std::cout << i << " has large max norm, investigate if needed. diff = " + << diff << std::endl; + } + } else { + for (size_t j = 0; j < ndims; j++) { + new_data[i * newdims + j] = + static_cast(data[i * ndims + j] / std::sqrt(norms[i])); + } + new_data[i * newdims + ndims] = 0; } - // std::sort(norms.begin(), norms.end()); - max_norm = std::sqrt(max_norm); - std::cout << "Max norm: " << max_norm << std::endl; - T *new_data; - size_t newdims = ndims + 1; - new_data = new T[npts * newdims]; - for (size_t i = 0; i < npts; i++) - { - if (prep_base) - { - for (size_t j = 0; j < ndims; j++) - { - new_data[i * newdims + j] = static_cast(data[i * ndims + j] / max_norm); - } - float diff = 1 - (norms[i] / (max_norm * max_norm)); - diff = diff <= 0 ? 0 : std::sqrt(diff); - new_data[i * newdims + ndims] = static_cast(diff); - if (diff <= 0) - { - std::cout << i << " has large max norm, investigate if needed. 
diff = " << diff << std::endl; - } - } - else - { - for (size_t j = 0; j < ndims; j++) - { - new_data[i * newdims + j] = static_cast(data[i * ndims + j] / std::sqrt(norms[i])); - } - new_data[i * newdims + ndims] = 0; - } - } - diskann::save_bin(out_file, new_data, npts, newdims); - delete[] new_data; - delete[] data; - return 0; + } + diskann::save_bin(out_file, new_data, npts, newdims); + delete[] new_data; + delete[] data; + return 0; } -template int aux_main(char **argv) -{ - std::string base_file(argv[2]); - uint32_t option = atoi(argv[3]); - if (option == 1) - analyze_norm(base_file); - else if (option == 2) - augment_base(base_file, std::string(argv[4]), true); - else if (option == 3) - augment_base(base_file, std::string(argv[4]), false); - else if (option == 4) - normalize_base(base_file, std::string(argv[4])); - return 0; +template int aux_main(char **argv) { + std::string base_file(argv[2]); + uint32_t option = atoi(argv[3]); + if (option == 1) + analyze_norm(base_file); + else if (option == 2) + augment_base(base_file, std::string(argv[4]), true); + else if (option == 3) + augment_base(base_file, std::string(argv[4]), false); + else if (option == 4) + normalize_base(base_file, std::string(argv[4])); + return 0; } -int main(int argc, char **argv) -{ - if (argc < 4) - { - std::cout << argv[0] - << " data_type [float/int8/uint8] base_bin_file " - "[option: 1-norm analysis, 2-prep_base_for_mip, " - "3-prep_query_for_mip, 4-normalize-vecs] [out_file for " - "options 2/3/4]" - << std::endl; - exit(-1); - } +int main(int argc, char **argv) { + if (argc < 4) { + std::cout << argv[0] + << " data_type [float/int8/uint8] base_bin_file " + "[option: 1-norm analysis, 2-prep_base_for_mip, " + "3-prep_query_for_mip, 4-normalize-vecs] [out_file for " + "options 2/3/4]" + << std::endl; + exit(-1); + } - if (std::string(argv[1]) == std::string("float")) - { - aux_main(argv); - } - else if (std::string(argv[1]) == std::string("int8")) - { - aux_main(argv); - } - else if (std::string(argv[1]) == std::string("uint8")) - { - aux_main(argv); - } - else - std::cout << "Unsupported type. Use float/int8/uint8." << std::endl; - return 0; + if (std::string(argv[1]) == std::string("float")) { + aux_main(argv); + } else if (std::string(argv[1]) == std::string("int8")) { + aux_main(argv); + } else if (std::string(argv[1]) == std::string("uint8")) { + aux_main(argv); + } else + std::cout << "Unsupported type. Use float/int8/uint8." 
<< std::endl; + return 0; } diff --git a/include/in_mem_filter_store.h b/include/in_mem_filter_store.h index 7080ef665..3dcb1029e 100644 --- a/include/in_mem_filter_store.h +++ b/include/in_mem_filter_store.h @@ -3,16 +3,15 @@ #pragma once -#include -#include -#include -#include "logger.h" -#include "ann_exception.h" #include "abstract_filter_store.h" +#include "ann_exception.h" +#include "logger.h" #include "tsl/robin_map.h" #include "tsl/robin_set.h" #include "windows_customizations.h" - +#include +#include +#include namespace diskann { template diff --git a/include/index_config.h b/include/index_config.h index dde4ec51d..351d9aeba 100644 --- a/include/index_config.h +++ b/include/index_config.h @@ -3,11 +3,11 @@ #pragma once -#include #include "ann_exception.h" #include "common_includes.h" #include "logger.h" #include "parameters.h" +#include namespace diskann { enum class DataStoreStrategy { MEMORY }; From 1c43eeffa2eba986ab47fa8fe5f9ebe4d3ed0551 Mon Sep 17 00:00:00 2001 From: rakri Date: Tue, 12 Nov 2024 21:13:38 -0800 Subject: [PATCH 4/7] fixed compile error in linux --- include/cached_io.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/cached_io.h b/include/cached_io.h index c62441e07..921283356 100644 --- a/include/cached_io.h +++ b/include/cached_io.h @@ -2,6 +2,7 @@ // Licensed under the MIT license. #pragma once +#include #include #include #include From 151a46c246ef0939acd49920ce270727f4aa8969 Mon Sep 17 00:00:00 2001 From: rakri Date: Tue, 12 Nov 2024 23:00:25 -0800 Subject: [PATCH 5/7] clang-format added back formatting file --- apps/build_disk_index.cpp | 343 +- apps/build_memory_index.cpp | 278 +- apps/build_stitched_index.cpp | 730 +- apps/range_search_disk_index.cpp | 626 +- apps/search_disk_index.cpp | 854 ++- apps/search_memory_index.cpp | 842 +-- apps/test_insert_deletes_consolidate.cpp | 999 ++- apps/test_streaming_scenario.cpp | 949 ++- apps/utils/bin_to_fvecs.cpp | 104 +- apps/utils/bin_to_tsv.cpp | 103 +- apps/utils/calculate_recall.cpp | 71 +- apps/utils/compute_groundtruth.cpp | 952 ++- .../utils/compute_groundtruth_for_filters.cpp | 1555 ++--- apps/utils/count_bfs_levels.cpp | 97 +- apps/utils/create_disk_layout.cpp | 56 +- apps/utils/float_bin_to_int8.cpp | 105 +- apps/utils/fvecs_to_bin.cpp | 141 +- apps/utils/fvecs_to_bvecs.cpp | 86 +- apps/utils/gen_random_slice.cpp | 55 +- apps/utils/generate_pq.cpp | 107 +- apps/utils/generate_synthetic_labels.cpp | 322 +- apps/utils/int8_to_float.cpp | 28 +- apps/utils/int8_to_float_scale.cpp | 105 +- apps/utils/ivecs_to_bin.cpp | 89 +- apps/utils/merge_shards.cpp | 45 +- apps/utils/partition_data.cpp | 53 +- apps/utils/partition_with_ram_budget.cpp | 53 +- apps/utils/rand_data_gen.cpp | 400 +- apps/utils/simulate_aggregate_recall.cpp | 119 +- apps/utils/stats_label_data.cpp | 207 +- apps/utils/tsv_to_bin.cpp | 191 +- apps/utils/uint32_to_uint8.cpp | 28 +- apps/utils/uint8_to_float.cpp | 28 +- apps/utils/vector_analysis.cpp | 239 +- include/abstract_data_store.h | 231 +- include/abstract_filter_store.h | 28 +- include/abstract_graph_store.h | 102 +- include/abstract_index.h | 233 +- include/abstract_scratch.h | 38 +- include/aligned_file_reader.h | 133 +- include/ann_exception.h | 31 +- include/any_wrappers.h | 49 +- include/boost_dynamic_bitset_fwd.h | 7 +- include/cached_io.h | 349 +- include/concurrent_queue.h | 172 +- include/cosine_similarity.h | 390 +- include/defaults.h | 6 +- include/disk_utils.h | 93 +- include/distance.h | 362 +- include/exceptions.h | 13 +- include/filter_utils.h | 278 +- 
include/in_mem_data_store.h | 137 +- include/in_mem_filter_store.h | 236 +- include/in_mem_graph_store.h | 77 +- include/index.h | 862 ++- include/index_build_params.h | 109 +- include/index_config.h | 486 +- include/index_factory.h | 84 +- include/linux_aligned_file_reader.h | 56 +- include/locking.h | 3 +- include/logger.h | 13 +- include/logger_impl.h | 89 +- include/math_utils.h | 44 +- include/memory_mapper.h | 32 +- include/natural_number_map.h | 121 +- include/natural_number_set.h | 54 +- include/neighbor.h | 189 +- include/parameters.h | 188 +- include/partition.h | 39 +- include/percentile_stats.h | 76 +- include/pq.h | 120 +- include/pq_common.h | 24 +- include/pq_data_store.h | 177 +- include/pq_flash_index.h | 390 +- include/pq_l2_distance.h | 141 +- include/pq_scratch.h | 26 +- include/quantized_distance.h | 95 +- include/scratch.h | 298 +- include/simd_utils.h | 135 +- include/tag_uint128.h | 79 +- include/timer.h | 44 +- include/types.h | 3 +- include/utils.h | 1866 +++--- include/windows_aligned_file_reader.h | 60 +- include/windows_slim_lock.h | 74 +- src/abstract_data_store.cpp | 44 +- src/abstract_index.cpp | 741 +- src/ann_exception.cpp | 41 +- src/disk_utils.cpp | 2774 ++++---- src/distance.cpp | 1097 +-- src/filter_utils.cpp | 559 +- src/in_mem_data_store.cpp | 616 +- src/in_mem_filter_store.cpp | 663 +- src/in_mem_graph_store.cpp | 390 +- src/index.cpp | 5933 +++++++++-------- src/index_factory.cpp | 344 +- src/linux_aligned_file_reader.cpp | 339 +- src/logger.cpp | 106 +- src/math_utils.cpp | 672 +- src/memory_mapper.cpp | 146 +- src/natural_number_map.cpp | 130 +- src/natural_number_set.cpp | 67 +- src/partition.cpp | 1133 ++-- src/pq.cpp | 1923 +++--- src/pq_data_store.cpp | 367 +- src/pq_flash_index.cpp | 2351 +++---- src/pq_l2_distance.cpp | 386 +- src/scratch.cpp | 236 +- src/utils.cpp | 815 ++- src/windows_aligned_file_reader.cpp | 298 +- 110 files changed, 21069 insertions(+), 21004 deletions(-) diff --git a/apps/build_disk_index.cpp b/apps/build_disk_index.cpp index 475c9165b..41a885993 100644 --- a/apps/build_disk_index.cpp +++ b/apps/build_disk_index.cpp @@ -13,198 +13,179 @@ namespace po = boost::program_options; -int main(int argc, char **argv) { - std::string data_type, dist_fn, data_path, index_path_prefix, codebook_prefix, - label_file, universal_label, label_type; - uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold; - float B, M; - bool append_reorder_data = false; - bool use_opq = false; +int main(int argc, char **argv) +{ + std::string data_type, dist_fn, data_path, index_path_prefix, codebook_prefix, label_file, universal_label, + label_type; + uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold; + float B, M; + bool append_reorder_data = false; + bool use_opq = false; - po::options_description desc{program_options_utils::make_program_description( - "build_disk_index", "Build a disk-based index.")}; - try { - desc.add_options()("help,h", "Print information on arguments"); + po::options_description desc{ + program_options_utils::make_program_description("build_disk_index", "Build a disk-based index.")}; + try + { + desc.add_options()("help,h", "Print information on arguments"); - // Required parameters - po::options_description required_configs("Required"); - required_configs.add_options()( - "data_type", po::value(&data_type)->required(), - program_options_utils::DATA_TYPE_DESCRIPTION); - required_configs.add_options()( - "dist_fn", po::value(&dist_fn)->required(), - 
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); - required_configs.add_options()( - "index_path_prefix", - po::value(&index_path_prefix)->required(), - program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); - required_configs.add_options()( - "data_path", po::value(&data_path)->required(), - program_options_utils::INPUT_DATA_PATH); - required_configs.add_options()( - "search_DRAM_budget,B", po::value(&B)->required(), - "DRAM budget in GB for searching the index to set the " - "compressed level for data while search happens"); - required_configs.add_options()("build_DRAM_budget,M", - po::value(&M)->required(), - "DRAM budget in GB for building the index"); + // Required parameters + po::options_description required_configs("Required"); + required_configs.add_options()("data_type", po::value(&data_type)->required(), + program_options_utils::DATA_TYPE_DESCRIPTION); + required_configs.add_options()("dist_fn", po::value(&dist_fn)->required(), + program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); + required_configs.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), + program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); + required_configs.add_options()("data_path", po::value(&data_path)->required(), + program_options_utils::INPUT_DATA_PATH); + required_configs.add_options()("search_DRAM_budget,B", po::value(&B)->required(), + "DRAM budget in GB for searching the index to set the " + "compressed level for data while search happens"); + required_configs.add_options()("build_DRAM_budget,M", po::value(&M)->required(), + "DRAM budget in GB for building the index"); - // Optional parameters - po::options_description optional_configs("Optional"); - optional_configs.add_options()( - "num_threads,T", - po::value(&num_threads)->default_value(omp_get_num_procs()), - program_options_utils::NUMBER_THREADS_DESCRIPTION); - optional_configs.add_options()("max_degree,R", - po::value(&R)->default_value(64), - program_options_utils::MAX_BUILD_DEGREE); - optional_configs.add_options()( - "Lbuild,L", po::value(&L)->default_value(100), - program_options_utils::GRAPH_BUILD_COMPLEXITY); - optional_configs.add_options()("QD", - po::value(&QD)->default_value(0), - " Quantized Dimension for compression"); - optional_configs.add_options()( - "codebook_prefix", - po::value(&codebook_prefix)->default_value(""), - "Path prefix for pre-trained codebook"); - optional_configs.add_options()( - "PQ_disk_bytes", po::value(&disk_PQ)->default_value(0), - "Number of bytes to which vectors should be compressed " - "on SSD; 0 for no compression"); - optional_configs.add_options()( - "append_reorder_data", po::bool_switch()->default_value(false), - "Include full precision data in the index. 
Use only in " - "conjuction with compressed data on SSD."); - optional_configs.add_options()( - "build_PQ_bytes", po::value(&build_PQ)->default_value(0), - program_options_utils::BUIlD_GRAPH_PQ_BYTES); - optional_configs.add_options()("use_opq", - po::bool_switch()->default_value(false), - program_options_utils::USE_OPQ); - optional_configs.add_options()( - "label_file", po::value(&label_file)->default_value(""), - program_options_utils::LABEL_FILE); - optional_configs.add_options()( - "universal_label", - po::value(&universal_label)->default_value(""), - program_options_utils::UNIVERSAL_LABEL); - optional_configs.add_options()("FilteredLbuild", - po::value(&Lf)->default_value(0), - program_options_utils::FILTERED_LBUILD); - optional_configs.add_options()( - "filter_threshold,F", - po::value(&filter_threshold)->default_value(0), - "Threshold to break up the existing nodes to generate new graph " - "internally where each node has a maximum F labels."); - optional_configs.add_options()( - "label_type", - po::value(&label_type)->default_value("uint"), - program_options_utils::LABEL_TYPE_DESCRIPTION); + // Optional parameters + po::options_description optional_configs("Optional"); + optional_configs.add_options()("num_threads,T", + po::value(&num_threads)->default_value(omp_get_num_procs()), + program_options_utils::NUMBER_THREADS_DESCRIPTION); + optional_configs.add_options()("max_degree,R", po::value(&R)->default_value(64), + program_options_utils::MAX_BUILD_DEGREE); + optional_configs.add_options()("Lbuild,L", po::value(&L)->default_value(100), + program_options_utils::GRAPH_BUILD_COMPLEXITY); + optional_configs.add_options()("QD", po::value(&QD)->default_value(0), + " Quantized Dimension for compression"); + optional_configs.add_options()("codebook_prefix", po::value(&codebook_prefix)->default_value(""), + "Path prefix for pre-trained codebook"); + optional_configs.add_options()("PQ_disk_bytes", po::value(&disk_PQ)->default_value(0), + "Number of bytes to which vectors should be compressed " + "on SSD; 0 for no compression"); + optional_configs.add_options()("append_reorder_data", po::bool_switch()->default_value(false), + "Include full precision data in the index. 
Use only in " + "conjuction with compressed data on SSD."); + optional_configs.add_options()("build_PQ_bytes", po::value(&build_PQ)->default_value(0), + program_options_utils::BUIlD_GRAPH_PQ_BYTES); + optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false), + program_options_utils::USE_OPQ); + optional_configs.add_options()("label_file", po::value(&label_file)->default_value(""), + program_options_utils::LABEL_FILE); + optional_configs.add_options()("universal_label", po::value(&universal_label)->default_value(""), + program_options_utils::UNIVERSAL_LABEL); + optional_configs.add_options()("FilteredLbuild", po::value(&Lf)->default_value(0), + program_options_utils::FILTERED_LBUILD); + optional_configs.add_options()("filter_threshold,F", po::value(&filter_threshold)->default_value(0), + "Threshold to break up the existing nodes to generate new graph " + "internally where each node has a maximum F labels."); + optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), + program_options_utils::LABEL_TYPE_DESCRIPTION); - // Merge required and optional parameters - desc.add(required_configs).add(optional_configs); + // Merge required and optional parameters + desc.add(required_configs).add(optional_configs); - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) { - std::cout << desc; - return 0; + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) + { + std::cout << desc; + return 0; + } + po::notify(vm); + if (vm["append_reorder_data"].as()) + append_reorder_data = true; + if (vm["use_opq"].as()) + use_opq = true; + } + catch (const std::exception &ex) + { + std::cerr << ex.what() << '\n'; + return -1; } - po::notify(vm); - if (vm["append_reorder_data"].as()) - append_reorder_data = true; - if (vm["use_opq"].as()) - use_opq = true; - } catch (const std::exception &ex) { - std::cerr << ex.what() << '\n'; - return -1; - } - - bool use_filters = (label_file != "") ? true : false; - diskann::Metric metric; - if (dist_fn == std::string("l2")) - metric = diskann::Metric::L2; - else if (dist_fn == std::string("mips")) - metric = diskann::Metric::INNER_PRODUCT; - else if (dist_fn == std::string("cosine")) - metric = diskann::Metric::COSINE; - else { - std::cout << "Error. Only l2 and mips distance functions are supported" - << std::endl; - return -1; - } - if (append_reorder_data) { - if (disk_PQ == 0) { - std::cout << "Error: It is not necessary to append data for reordering " - "when vectors are not compressed on disk." - << std::endl; - return -1; + bool use_filters = (label_file != "") ? true : false; + diskann::Metric metric; + if (dist_fn == std::string("l2")) + metric = diskann::Metric::L2; + else if (dist_fn == std::string("mips")) + metric = diskann::Metric::INNER_PRODUCT; + else if (dist_fn == std::string("cosine")) + metric = diskann::Metric::COSINE; + else + { + std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl; + return -1; } - if (data_type != std::string("float")) { - std::cout << "Error: Appending data for reordering currently only " - "supported for float data type." - << std::endl; - return -1; + + if (append_reorder_data) + { + if (disk_PQ == 0) + { + std::cout << "Error: It is not necessary to append data for reordering " + "when vectors are not compressed on disk." 
+ << std::endl; + return -1; + } + if (data_type != std::string("float")) + { + std::cout << "Error: Appending data for reordering currently only " + "supported for float data type." + << std::endl; + return -1; + } } - } - std::string params = std::string(std::to_string(R)) + " " + - std::string(std::to_string(L)) + " " + - std::string(std::to_string(B)) + " " + - std::string(std::to_string(M)) + " " + - std::string(std::to_string(num_threads)) + " " + - std::string(std::to_string(disk_PQ)) + " " + - std::string(std::to_string(append_reorder_data)) + " " + - std::string(std::to_string(build_PQ)) + " " + - std::string(std::to_string(QD)); + std::string params = std::string(std::to_string(R)) + " " + std::string(std::to_string(L)) + " " + + std::string(std::to_string(B)) + " " + std::string(std::to_string(M)) + " " + + std::string(std::to_string(num_threads)) + " " + std::string(std::to_string(disk_PQ)) + " " + + std::string(std::to_string(append_reorder_data)) + " " + + std::string(std::to_string(build_PQ)) + " " + std::string(std::to_string(QD)); - try { - if (label_file != "" && label_type == "ushort") { - if (data_type == std::string("int8")) - return diskann::build_disk_index( - data_path.c_str(), index_path_prefix.c_str(), params.c_str(), - metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); - else if (data_type == std::string("uint8")) - return diskann::build_disk_index( - data_path.c_str(), index_path_prefix.c_str(), params.c_str(), - metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); - else if (data_type == std::string("float")) - return diskann::build_disk_index( - data_path.c_str(), index_path_prefix.c_str(), params.c_str(), - metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); - else { - diskann::cerr << "Error. Unsupported data type" << std::endl; - return -1; - } - } else { - if (data_type == std::string("int8")) - return diskann::build_disk_index( - data_path.c_str(), index_path_prefix.c_str(), params.c_str(), - metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); - else if (data_type == std::string("uint8")) - return diskann::build_disk_index( - data_path.c_str(), index_path_prefix.c_str(), params.c_str(), - metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); - else if (data_type == std::string("float")) - return diskann::build_disk_index( - data_path.c_str(), index_path_prefix.c_str(), params.c_str(), - metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); - else { - diskann::cerr << "Error. 
Unsupported data type" << std::endl; + try + { + if (label_file != "" && label_type == "ushort") + { + if (data_type == std::string("int8")) + return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), + metric, use_opq, codebook_prefix, use_filters, label_file, + universal_label, filter_threshold, Lf); + else if (data_type == std::string("uint8")) + return diskann::build_disk_index( + data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, + use_filters, label_file, universal_label, filter_threshold, Lf); + else if (data_type == std::string("float")) + return diskann::build_disk_index( + data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, + use_filters, label_file, universal_label, filter_threshold, Lf); + else + { + diskann::cerr << "Error. Unsupported data type" << std::endl; + return -1; + } + } + else + { + if (data_type == std::string("int8")) + return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), + metric, use_opq, codebook_prefix, use_filters, label_file, + universal_label, filter_threshold, Lf); + else if (data_type == std::string("uint8")) + return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), + metric, use_opq, codebook_prefix, use_filters, label_file, + universal_label, filter_threshold, Lf); + else if (data_type == std::string("float")) + return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), + metric, use_opq, codebook_prefix, use_filters, label_file, + universal_label, filter_threshold, Lf); + else + { + diskann::cerr << "Error. Unsupported data type" << std::endl; + return -1; + } + } + } + catch (const std::exception &e) + { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Index build failed." << std::endl; return -1; - } } - } catch (const std::exception &e) { - std::cout << std::string(e.what()) << std::endl; - diskann::cerr << "Index build failed." 
<< std::endl; - return -1; - } } diff --git a/apps/build_memory_index.cpp b/apps/build_memory_index.cpp index 0efd75281..f0d469f4d 100644 --- a/apps/build_memory_index.cpp +++ b/apps/build_memory_index.cpp @@ -22,149 +22,143 @@ namespace po = boost::program_options; -int main(int argc, char **argv) { - std::string data_type, dist_fn, data_path, index_path_prefix, label_file, - universal_label, label_type; - uint32_t num_threads, R, L, Lf, build_PQ_bytes; - float alpha; - bool use_pq_build, use_opq; - - po::options_description desc{program_options_utils::make_program_description( - "build_memory_index", "Build a memory-based DiskANN index.")}; - try { - desc.add_options()("help,h", "Print information on arguments"); - - // Required parameters - po::options_description required_configs("Required"); - required_configs.add_options()( - "data_type", po::value(&data_type)->required(), - program_options_utils::DATA_TYPE_DESCRIPTION); - required_configs.add_options()( - "dist_fn", po::value(&dist_fn)->required(), - program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); - required_configs.add_options()( - "index_path_prefix", - po::value(&index_path_prefix)->required(), - program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); - required_configs.add_options()( - "data_path", po::value(&data_path)->required(), - program_options_utils::INPUT_DATA_PATH); - - // Optional parameters - po::options_description optional_configs("Optional"); - optional_configs.add_options()( - "num_threads,T", - po::value(&num_threads)->default_value(omp_get_num_procs()), - program_options_utils::NUMBER_THREADS_DESCRIPTION); - optional_configs.add_options()("max_degree,R", - po::value(&R)->default_value(64), - program_options_utils::MAX_BUILD_DEGREE); - optional_configs.add_options()( - "Lbuild,L", po::value(&L)->default_value(100), - program_options_utils::GRAPH_BUILD_COMPLEXITY); - optional_configs.add_options()( - "alpha", po::value(&alpha)->default_value(1.2f), - program_options_utils::GRAPH_BUILD_ALPHA); - optional_configs.add_options()( - "build_PQ_bytes", - po::value(&build_PQ_bytes)->default_value(0), - program_options_utils::BUIlD_GRAPH_PQ_BYTES); - optional_configs.add_options()("use_opq", - po::bool_switch()->default_value(false), - program_options_utils::USE_OPQ); - optional_configs.add_options()( - "label_file", po::value(&label_file)->default_value(""), - program_options_utils::LABEL_FILE); - optional_configs.add_options()( - "universal_label", - po::value(&universal_label)->default_value(""), - program_options_utils::UNIVERSAL_LABEL); - - optional_configs.add_options()("FilteredLbuild", - po::value(&Lf)->default_value(0), - program_options_utils::FILTERED_LBUILD); - optional_configs.add_options()( - "label_type", - po::value(&label_type)->default_value("uint"), - program_options_utils::LABEL_TYPE_DESCRIPTION); - - // Merge required and optional parameters - desc.add(required_configs).add(optional_configs); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) { - std::cout << desc; - return 0; +int main(int argc, char **argv) +{ + std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type; + uint32_t num_threads, R, L, Lf, build_PQ_bytes; + float alpha; + bool use_pq_build, use_opq; + + po::options_description desc{ + program_options_utils::make_program_description("build_memory_index", "Build a memory-based DiskANN index.")}; + try + { + desc.add_options()("help,h", "Print information on arguments"); + + // 
Required parameters + po::options_description required_configs("Required"); + required_configs.add_options()("data_type", po::value(&data_type)->required(), + program_options_utils::DATA_TYPE_DESCRIPTION); + required_configs.add_options()("dist_fn", po::value(&dist_fn)->required(), + program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); + required_configs.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), + program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); + required_configs.add_options()("data_path", po::value(&data_path)->required(), + program_options_utils::INPUT_DATA_PATH); + + // Optional parameters + po::options_description optional_configs("Optional"); + optional_configs.add_options()("num_threads,T", + po::value(&num_threads)->default_value(omp_get_num_procs()), + program_options_utils::NUMBER_THREADS_DESCRIPTION); + optional_configs.add_options()("max_degree,R", po::value(&R)->default_value(64), + program_options_utils::MAX_BUILD_DEGREE); + optional_configs.add_options()("Lbuild,L", po::value(&L)->default_value(100), + program_options_utils::GRAPH_BUILD_COMPLEXITY); + optional_configs.add_options()("alpha", po::value(&alpha)->default_value(1.2f), + program_options_utils::GRAPH_BUILD_ALPHA); + optional_configs.add_options()("build_PQ_bytes", po::value(&build_PQ_bytes)->default_value(0), + program_options_utils::BUIlD_GRAPH_PQ_BYTES); + optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false), + program_options_utils::USE_OPQ); + optional_configs.add_options()("label_file", po::value(&label_file)->default_value(""), + program_options_utils::LABEL_FILE); + optional_configs.add_options()("universal_label", po::value(&universal_label)->default_value(""), + program_options_utils::UNIVERSAL_LABEL); + + optional_configs.add_options()("FilteredLbuild", po::value(&Lf)->default_value(0), + program_options_utils::FILTERED_LBUILD); + optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), + program_options_utils::LABEL_TYPE_DESCRIPTION); + + // Merge required and optional parameters + desc.add(required_configs).add(optional_configs); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) + { + std::cout << desc; + return 0; + } + po::notify(vm); + use_pq_build = (build_PQ_bytes > 0); + use_opq = vm["use_opq"].as(); } - po::notify(vm); - use_pq_build = (build_PQ_bytes > 0); - use_opq = vm["use_opq"].as(); - } catch (const std::exception &ex) { - std::cerr << ex.what() << '\n'; - return -1; - } - - diskann::Metric metric; - if (dist_fn == std::string("mips")) { - metric = diskann::Metric::INNER_PRODUCT; - } else if (dist_fn == std::string("l2")) { - metric = diskann::Metric::L2; - } else if (dist_fn == std::string("cosine")) { - metric = diskann::Metric::COSINE; - } else { - std::cout << "Unsupported distance function. Currently only L2/ Inner " - "Product/Cosine are supported." 
- << std::endl; - return -1; - } - - try { - diskann::cout << "Starting index build with R: " << R << " Lbuild: " << L - << " alpha: " << alpha << " #threads: " << num_threads + catch (const std::exception &ex) + { + std::cerr << ex.what() << '\n'; + return -1; + } + + diskann::Metric metric; + if (dist_fn == std::string("mips")) + { + metric = diskann::Metric::INNER_PRODUCT; + } + else if (dist_fn == std::string("l2")) + { + metric = diskann::Metric::L2; + } + else if (dist_fn == std::string("cosine")) + { + metric = diskann::Metric::COSINE; + } + else + { + std::cout << "Unsupported distance function. Currently only L2/ Inner " + "Product/Cosine are supported." << std::endl; + return -1; + } - size_t data_num, data_dim; - diskann::get_bin_metadata(data_path, data_num, data_dim); - - auto index_build_params = diskann::IndexWriteParametersBuilder(L, R) - .with_filter_list_size(Lf) - .with_alpha(alpha) - .with_saturate_graph(false) - .with_num_threads(num_threads) - .build(); - - auto filter_params = diskann::IndexFilterParamsBuilder() - .with_universal_label(universal_label) - .with_label_file(label_file) - .with_save_path_prefix(index_path_prefix) - .build(); - auto config = - diskann::IndexConfigBuilder() - .with_metric(metric) - .with_dimension(data_dim) - .with_max_points(data_num) - .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY) - .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY) - .with_data_type(data_type) - .with_label_type(label_type) - .is_dynamic_index(false) - .with_index_write_params(index_build_params) - .is_enable_tags(false) - .is_use_opq(use_opq) - .is_pq_dist_build(use_pq_build) - .with_num_pq_chunks(build_PQ_bytes) - .build(); - - auto index_factory = diskann::IndexFactory(config); - auto index = index_factory.create_instance(); - index->build(data_path, data_num, filter_params); - index->save(index_path_prefix.c_str()); - index.reset(); - return 0; - } catch (const std::exception &e) { - std::cout << std::string(e.what()) << std::endl; - diskann::cerr << "Index build failed." 
<< std::endl; - return -1; - } + try + { + diskann::cout << "Starting index build with R: " << R << " Lbuild: " << L << " alpha: " << alpha + << " #threads: " << num_threads << std::endl; + + size_t data_num, data_dim; + diskann::get_bin_metadata(data_path, data_num, data_dim); + + auto index_build_params = diskann::IndexWriteParametersBuilder(L, R) + .with_filter_list_size(Lf) + .with_alpha(alpha) + .with_saturate_graph(false) + .with_num_threads(num_threads) + .build(); + + auto filter_params = diskann::IndexFilterParamsBuilder() + .with_universal_label(universal_label) + .with_label_file(label_file) + .with_save_path_prefix(index_path_prefix) + .build(); + auto config = diskann::IndexConfigBuilder() + .with_metric(metric) + .with_dimension(data_dim) + .with_max_points(data_num) + .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY) + .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY) + .with_data_type(data_type) + .with_label_type(label_type) + .is_dynamic_index(false) + .with_index_write_params(index_build_params) + .is_enable_tags(false) + .is_use_opq(use_opq) + .is_pq_dist_build(use_pq_build) + .with_num_pq_chunks(build_PQ_bytes) + .build(); + + auto index_factory = diskann::IndexFactory(config); + auto index = index_factory.create_instance(); + index->build(data_path, data_num, filter_params); + index->save(index_path_prefix.c_str()); + index.reset(); + return 0; + } + catch (const std::exception &e) + { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Index build failed." << std::endl; + return -1; + } } diff --git a/apps/build_stitched_index.cpp b/apps/build_stitched_index.cpp index 9b09a062b..0f385cb88 100644 --- a/apps/build_stitched_index.cpp +++ b/apps/build_stitched_index.cpp @@ -21,28 +21,29 @@ #include "utils.h" namespace po = boost::program_options; -typedef std::tuple>, uint64_t> - stitch_indices_return_values; +typedef std::tuple>, uint64_t> stitch_indices_return_values; /* * Inline function to display progress bar. */ -inline void print_progress(double percentage) { - int val = (int)(percentage * 100); - int lpad = (int)(percentage * PBWIDTH); - int rpad = PBWIDTH - lpad; - printf("\r%3d%% [%.*s%*s]", val, lpad, PBSTR, rpad, ""); - fflush(stdout); +inline void print_progress(double percentage) +{ + int val = (int)(percentage * 100); + int lpad = (int)(percentage * PBWIDTH); + int rpad = PBWIDTH - lpad; + printf("\r%3d%% [%.*s%*s]", val, lpad, PBSTR, rpad, ""); + fflush(stdout); } /* * Inline function to generate a random integer in a range. */ -inline size_t random(size_t range_from, size_t range_to) { - std::random_device rand_dev; - std::mt19937 generator(rand_dev()); - std::uniform_int_distribution distr(range_from, range_to); - return distr(generator); +inline size_t random(size_t range_from, size_t range_to) +{ + std::random_device rand_dev; + std::mt19937 generator(rand_dev()); + std::uniform_int_distribution distr(range_from, range_to); + return distr(generator); } /* @@ -50,70 +51,61 @@ inline size_t random(size_t range_from, size_t range_to) { * * Arguments are merely the inputs from the command line. 
*/ -void handle_args(int argc, char **argv, std::string &data_type, - path &input_data_path, path &final_index_path_prefix, - path &label_data_path, std::string &universal_label, - uint32_t &num_threads, uint32_t &R, uint32_t &L, - uint32_t &stitched_R, float &alpha) { - po::options_description desc{program_options_utils::make_program_description( - "build_stitched_index", "Build a stitched DiskANN index.")}; - try { - desc.add_options()("help,h", "Print information on arguments"); - - // Required parameters - po::options_description required_configs("Required"); - required_configs.add_options()( - "data_type", po::value(&data_type)->required(), - program_options_utils::DATA_TYPE_DESCRIPTION); - required_configs.add_options()( - "index_path_prefix", - po::value(&final_index_path_prefix)->required(), - program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); - required_configs.add_options()( - "data_path", po::value(&input_data_path)->required(), - program_options_utils::INPUT_DATA_PATH); - - // Optional parameters - po::options_description optional_configs("Optional"); - optional_configs.add_options()( - "num_threads,T", - po::value(&num_threads)->default_value(omp_get_num_procs()), - program_options_utils::NUMBER_THREADS_DESCRIPTION); - optional_configs.add_options()("max_degree,R", - po::value(&R)->default_value(64), - program_options_utils::MAX_BUILD_DEGREE); - optional_configs.add_options()( - "Lbuild,L", po::value(&L)->default_value(100), - program_options_utils::GRAPH_BUILD_COMPLEXITY); - optional_configs.add_options()( - "alpha", po::value(&alpha)->default_value(1.2f), - program_options_utils::GRAPH_BUILD_ALPHA); - optional_configs.add_options()( - "label_file", - po::value(&label_data_path)->default_value(""), - program_options_utils::LABEL_FILE); - optional_configs.add_options()( - "universal_label", - po::value(&universal_label)->default_value(""), - program_options_utils::UNIVERSAL_LABEL); - optional_configs.add_options()( - "stitched_R", po::value(&stitched_R)->default_value(100), - "Degree to prune final graph down to"); - - // Merge required and optional parameters - desc.add(required_configs).add(optional_configs); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) { - std::cout << desc; - exit(0); +void handle_args(int argc, char **argv, std::string &data_type, path &input_data_path, path &final_index_path_prefix, + path &label_data_path, std::string &universal_label, uint32_t &num_threads, uint32_t &R, uint32_t &L, + uint32_t &stitched_R, float &alpha) +{ + po::options_description desc{ + program_options_utils::make_program_description("build_stitched_index", "Build a stitched DiskANN index.")}; + try + { + desc.add_options()("help,h", "Print information on arguments"); + + // Required parameters + po::options_description required_configs("Required"); + required_configs.add_options()("data_type", po::value(&data_type)->required(), + program_options_utils::DATA_TYPE_DESCRIPTION); + required_configs.add_options()("index_path_prefix", + po::value(&final_index_path_prefix)->required(), + program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); + required_configs.add_options()("data_path", po::value(&input_data_path)->required(), + program_options_utils::INPUT_DATA_PATH); + + // Optional parameters + po::options_description optional_configs("Optional"); + optional_configs.add_options()("num_threads,T", + po::value(&num_threads)->default_value(omp_get_num_procs()), + program_options_utils::NUMBER_THREADS_DESCRIPTION); + 
optional_configs.add_options()("max_degree,R", po::value(&R)->default_value(64), + program_options_utils::MAX_BUILD_DEGREE); + optional_configs.add_options()("Lbuild,L", po::value(&L)->default_value(100), + program_options_utils::GRAPH_BUILD_COMPLEXITY); + optional_configs.add_options()("alpha", po::value(&alpha)->default_value(1.2f), + program_options_utils::GRAPH_BUILD_ALPHA); + optional_configs.add_options()("label_file", po::value(&label_data_path)->default_value(""), + program_options_utils::LABEL_FILE); + optional_configs.add_options()("universal_label", po::value(&universal_label)->default_value(""), + program_options_utils::UNIVERSAL_LABEL); + optional_configs.add_options()("stitched_R", po::value(&stitched_R)->default_value(100), + "Degree to prune final graph down to"); + + // Merge required and optional parameters + desc.add(required_configs).add(optional_configs); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) + { + std::cout << desc; + exit(0); + } + po::notify(vm); + } + catch (const std::exception &ex) + { + std::cerr << ex.what() << '\n'; + throw; } - po::notify(vm); - } catch (const std::exception &ex) { - std::cerr << ex.what() << '\n'; - throw; - } } /* @@ -124,110 +116,98 @@ void handle_args(int argc, char **argv, std::string &data_type, * 3. data (redundant for static indices) * 4. labels (redundant for static indices) */ -void save_full_index(path final_index_path_prefix, path input_data_path, - uint64_t final_index_size, +void save_full_index(path final_index_path_prefix, path input_data_path, uint64_t final_index_size, std::vector> stitched_graph, - tsl::robin_map entry_points, - std::string universal_label, path label_data_path) { - // aux. file 1 - auto saving_index_timer = std::chrono::high_resolution_clock::now(); - std::ifstream original_label_data_stream; - original_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit); - original_label_data_stream.open(label_data_path, std::ios::binary); - std::ofstream new_label_data_stream; - new_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit); - new_label_data_stream.open(final_index_path_prefix + "_labels.txt", - std::ios::binary); - new_label_data_stream << original_label_data_stream.rdbuf(); - original_label_data_stream.close(); - new_label_data_stream.close(); - - // aux. file 2 - std::ifstream original_input_data_stream; - original_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit); - original_input_data_stream.open(input_data_path, std::ios::binary); - std::ofstream new_input_data_stream; - new_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit); - new_input_data_stream.open(final_index_path_prefix + ".data", - std::ios::binary); - new_input_data_stream << original_input_data_stream.rdbuf(); - original_input_data_stream.close(); - new_input_data_stream.close(); - - // aux. file 3 - std::ofstream labels_to_medoids_writer; - labels_to_medoids_writer.exceptions(std::ios::badbit | std::ios::failbit); - labels_to_medoids_writer.open(final_index_path_prefix + - "_labels_to_medoids.txt"); - for (auto iter : entry_points) - labels_to_medoids_writer << iter.first << ", " << iter.second << std::endl; - labels_to_medoids_writer.close(); - - // aux. 
file 4 (only if we're using a universal label)
-  if (universal_label != "") {
-    std::ofstream universal_label_writer;
-    universal_label_writer.exceptions(std::ios::badbit | std::ios::failbit);
-    universal_label_writer.open(final_index_path_prefix +
-                                "_universal_label.txt");
-    universal_label_writer << universal_label << std::endl;
-    universal_label_writer.close();
-  }
-
-  // main index
-  uint64_t index_num_frozen_points = 0, index_num_edges = 0;
-  uint32_t index_max_observed_degree = 0, index_entry_point = 0;
-  const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t);
-  for (auto &point_neighbors : stitched_graph) {
-    index_max_observed_degree =
-        std::max(index_max_observed_degree, (uint32_t)point_neighbors.size());
-  }
-
-  std::ofstream stitched_graph_writer;
-  stitched_graph_writer.exceptions(std::ios::badbit | std::ios::failbit);
-  stitched_graph_writer.open(final_index_path_prefix, std::ios_base::binary);
-
-  stitched_graph_writer.write((char *)&final_index_size, sizeof(uint64_t));
-  stitched_graph_writer.write((char *)&index_max_observed_degree,
-                              sizeof(uint32_t));
-  stitched_graph_writer.write((char *)&index_entry_point, sizeof(uint32_t));
-  stitched_graph_writer.write((char *)&index_num_frozen_points,
-                              sizeof(uint64_t));
-
-  size_t bytes_written = METADATA;
-  for (uint32_t node_point = 0; node_point < stitched_graph.size();
-       node_point++) {
-    uint32_t current_node_num_neighbors =
-        (uint32_t)stitched_graph[node_point].size();
-    std::vector<uint32_t> current_node_neighbors = stitched_graph[node_point];
-    stitched_graph_writer.write((char *)&current_node_num_neighbors,
-                                sizeof(uint32_t));
-    bytes_written += sizeof(uint32_t);
-    for (const auto &current_node_neighbor : current_node_neighbors) {
-      stitched_graph_writer.write((char *)&current_node_neighbor,
-                                  sizeof(uint32_t));
-      bytes_written += sizeof(uint32_t);
+                     tsl::robin_map<std::string, uint32_t> entry_points, std::string universal_label,
+                     path label_data_path)
+{
+    // aux. file 1
+    auto saving_index_timer = std::chrono::high_resolution_clock::now();
+    std::ifstream original_label_data_stream;
+    original_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
+    original_label_data_stream.open(label_data_path, std::ios::binary);
+    std::ofstream new_label_data_stream;
+    new_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
+    new_label_data_stream.open(final_index_path_prefix + "_labels.txt", std::ios::binary);
+    new_label_data_stream << original_label_data_stream.rdbuf();
+    original_label_data_stream.close();
+    new_label_data_stream.close();
+
+    // aux. file 2
+    std::ifstream original_input_data_stream;
+    original_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
+    original_input_data_stream.open(input_data_path, std::ios::binary);
+    std::ofstream new_input_data_stream;
+    new_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
+    new_input_data_stream.open(final_index_path_prefix + ".data", std::ios::binary);
+    new_input_data_stream << original_input_data_stream.rdbuf();
+    original_input_data_stream.close();
+    new_input_data_stream.close();
+
+    // aux. file 3
+    std::ofstream labels_to_medoids_writer;
+    labels_to_medoids_writer.exceptions(std::ios::badbit | std::ios::failbit);
+    labels_to_medoids_writer.open(final_index_path_prefix + "_labels_to_medoids.txt");
+    for (auto iter : entry_points)
+        labels_to_medoids_writer << iter.first << ", " << iter.second << std::endl;
+    labels_to_medoids_writer.close();
+
+    // aux. file 4 (only if we're using a universal label)
+    if (universal_label != "")
+    {
+        std::ofstream universal_label_writer;
+        universal_label_writer.exceptions(std::ios::badbit | std::ios::failbit);
+        universal_label_writer.open(final_index_path_prefix + "_universal_label.txt");
+        universal_label_writer << universal_label << std::endl;
+        universal_label_writer.close();
+    }
+
+    // main index
+    uint64_t index_num_frozen_points = 0, index_num_edges = 0;
+    uint32_t index_max_observed_degree = 0, index_entry_point = 0;
+    const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t);
+    for (auto &point_neighbors : stitched_graph)
+    {
+        index_max_observed_degree = std::max(index_max_observed_degree, (uint32_t)point_neighbors.size());
+    }
+
+    std::ofstream stitched_graph_writer;
+    stitched_graph_writer.exceptions(std::ios::badbit | std::ios::failbit);
+    stitched_graph_writer.open(final_index_path_prefix, std::ios_base::binary);
+
+    stitched_graph_writer.write((char *)&final_index_size, sizeof(uint64_t));
+    stitched_graph_writer.write((char *)&index_max_observed_degree, sizeof(uint32_t));
+    stitched_graph_writer.write((char *)&index_entry_point, sizeof(uint32_t));
+    stitched_graph_writer.write((char *)&index_num_frozen_points, sizeof(uint64_t));
+
+    size_t bytes_written = METADATA;
+    for (uint32_t node_point = 0; node_point < stitched_graph.size(); node_point++)
+    {
+        uint32_t current_node_num_neighbors = (uint32_t)stitched_graph[node_point].size();
+        std::vector<uint32_t> current_node_neighbors = stitched_graph[node_point];
+        stitched_graph_writer.write((char *)&current_node_num_neighbors, sizeof(uint32_t));
+        bytes_written += sizeof(uint32_t);
+        for (const auto &current_node_neighbor : current_node_neighbors)
+        {
+            stitched_graph_writer.write((char *)&current_node_neighbor, sizeof(uint32_t));
+            bytes_written += sizeof(uint32_t);
+        }
+        index_num_edges += current_node_num_neighbors;
+    }
+
+    if (bytes_written != final_index_size)
+    {
+        std::cerr << "Error: written bytes does not match allocated space" << std::endl;
+        throw;
+    }
-    index_num_edges += current_node_num_neighbors;
-  }
-  if (bytes_written != final_index_size) {
-    std::cerr << "Error: written bytes does not match allocated space"
+    stitched_graph_writer.close();
+
+    std::chrono::duration<double> saving_index_time = std::chrono::high_resolution_clock::now() - saving_index_timer;
+    std::cout << "Stitched graph written in " << saving_index_time.count() << " seconds" << std::endl;
+    std::cout << "Stitched graph average degree: " << ((float)index_num_edges) / ((float)(stitched_graph.size()))
+              << std::endl;
-    throw;
-  }
-
-  stitched_graph_writer.close();
-
-  std::chrono::duration<double> saving_index_time =
-      std::chrono::high_resolution_clock::now() - saving_index_timer;
-  std::cout << "Stitched graph written in " << saving_index_time.count()
-            << " seconds" << std::endl;
-  std::cout << "Stitched graph average degree: "
-            << ((float)index_num_edges) / ((float)(stitched_graph.size()))
-            << std::endl;
-  std::cout << "Stitched graph max degree: " << index_max_observed_degree
-            << std::endl
-            << std::endl;
+    std::cout << "Stitched graph max degree: " << index_max_observed_degree << std::endl << std::endl;
 }
 
 /*
@@ -238,55 +218,52 @@ void save_full_index(path final_index_path_prefix, path input_data_path,
 */
 template <typename T>
 stitch_indices_return_values stitch_label_indices(
-    path final_index_path_prefix, uint32_t total_number_of_points,
-    label_set all_labels,
+    path final_index_path_prefix, uint32_t total_number_of_points, label_set all_labels,
     tsl::robin_map<std::string, uint32_t> labels_to_number_of_points,
tsl::robin_map &label_entry_points, - tsl::robin_map> - label_id_to_orig_id_map) { - size_t final_index_size = 0; - std::vector> stitched_graph(total_number_of_points); - - auto stitching_index_timer = std::chrono::high_resolution_clock::now(); - for (const auto &lbl : all_labels) { - path curr_label_index_path(final_index_path_prefix + "_" + lbl); - std::vector> curr_label_index; - uint64_t curr_label_index_size; - uint32_t curr_label_entry_point; - - std::tie(curr_label_index, curr_label_index_size) = - diskann::load_label_index(curr_label_index_path, - labels_to_number_of_points[lbl]); - curr_label_entry_point = (uint32_t)random(0, curr_label_index.size()); - label_entry_points[lbl] = - label_id_to_orig_id_map[lbl][curr_label_entry_point]; - - for (uint32_t node_point = 0; node_point < curr_label_index.size(); - node_point++) { - uint32_t original_point_id = label_id_to_orig_id_map[lbl][node_point]; - for (auto &node_neighbor : curr_label_index[node_point]) { - uint32_t original_neighbor_id = - label_id_to_orig_id_map[lbl][node_neighbor]; - std::vector curr_point_neighbors = - stitched_graph[original_point_id]; - if (std::find(curr_point_neighbors.begin(), curr_point_neighbors.end(), - original_neighbor_id) == curr_point_neighbors.end()) { - stitched_graph[original_point_id].push_back(original_neighbor_id); - final_index_size += sizeof(uint32_t); + tsl::robin_map> label_id_to_orig_id_map) +{ + size_t final_index_size = 0; + std::vector> stitched_graph(total_number_of_points); + + auto stitching_index_timer = std::chrono::high_resolution_clock::now(); + for (const auto &lbl : all_labels) + { + path curr_label_index_path(final_index_path_prefix + "_" + lbl); + std::vector> curr_label_index; + uint64_t curr_label_index_size; + uint32_t curr_label_entry_point; + + std::tie(curr_label_index, curr_label_index_size) = + diskann::load_label_index(curr_label_index_path, labels_to_number_of_points[lbl]); + curr_label_entry_point = (uint32_t)random(0, curr_label_index.size()); + label_entry_points[lbl] = label_id_to_orig_id_map[lbl][curr_label_entry_point]; + + for (uint32_t node_point = 0; node_point < curr_label_index.size(); node_point++) + { + uint32_t original_point_id = label_id_to_orig_id_map[lbl][node_point]; + for (auto &node_neighbor : curr_label_index[node_point]) + { + uint32_t original_neighbor_id = label_id_to_orig_id_map[lbl][node_neighbor]; + std::vector curr_point_neighbors = stitched_graph[original_point_id]; + if (std::find(curr_point_neighbors.begin(), curr_point_neighbors.end(), original_neighbor_id) == + curr_point_neighbors.end()) + { + stitched_graph[original_point_id].push_back(original_neighbor_id); + final_index_size += sizeof(uint32_t); + } + } } - } } - } - const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t); - final_index_size += (total_number_of_points * sizeof(uint32_t) + METADATA); + const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t); + final_index_size += (total_number_of_points * sizeof(uint32_t) + METADATA); - std::chrono::duration stitching_index_time = - std::chrono::high_resolution_clock::now() - stitching_index_timer; - std::cout << "stitched graph generated in memory in " - << stitching_index_time.count() << " seconds" << std::endl; + std::chrono::duration stitching_index_time = + std::chrono::high_resolution_clock::now() - stitching_index_timer; + std::cout << "stitched graph generated in memory in " << stitching_index_time.count() << " seconds" << std::endl; - return std::make_tuple(stitched_graph, final_index_size); + 
return std::make_tuple(stitched_graph, final_index_size); } /* @@ -297,39 +274,33 @@ stitch_indices_return_values stitch_label_indices( * and pruned graph. */ template -void prune_and_save(path final_index_path_prefix, path full_index_path_prefix, - path input_data_path, - std::vector> stitched_graph, - uint32_t stitched_R, - tsl::robin_map label_entry_points, - std::string universal_label, path label_data_path, - uint32_t num_threads) { - size_t dimension, number_of_label_points; - auto diskann_cout_buffer = diskann::cout.rdbuf(nullptr); - auto std_cout_buffer = std::cout.rdbuf(nullptr); - auto pruning_index_timer = std::chrono::high_resolution_clock::now(); - - diskann::get_bin_metadata(input_data_path, number_of_label_points, dimension); - - diskann::Index index(diskann::Metric::L2, dimension, - number_of_label_points, nullptr, nullptr, 0, false, - false, false, false, 0, false); - - // not searching this index, set search_l to 0 - index.load(full_index_path_prefix.c_str(), num_threads, 1); - - std::cout << "parsing labels" << std::endl; - - index.prune_all_neighbors(stitched_R, 750, 1.2); - index.save((final_index_path_prefix).c_str()); - - diskann::cout.rdbuf(diskann_cout_buffer); - std::cout.rdbuf(std_cout_buffer); - std::chrono::duration pruning_index_time = - std::chrono::high_resolution_clock::now() - pruning_index_timer; - std::cout << "pruning performed in " << pruning_index_time.count() - << " seconds\n" - << std::endl; +void prune_and_save(path final_index_path_prefix, path full_index_path_prefix, path input_data_path, + std::vector> stitched_graph, uint32_t stitched_R, + tsl::robin_map label_entry_points, std::string universal_label, + path label_data_path, uint32_t num_threads) +{ + size_t dimension, number_of_label_points; + auto diskann_cout_buffer = diskann::cout.rdbuf(nullptr); + auto std_cout_buffer = std::cout.rdbuf(nullptr); + auto pruning_index_timer = std::chrono::high_resolution_clock::now(); + + diskann::get_bin_metadata(input_data_path, number_of_label_points, dimension); + + diskann::Index index(diskann::Metric::L2, dimension, number_of_label_points, nullptr, nullptr, 0, false, false, + false, false, 0, false); + + // not searching this index, set search_l to 0 + index.load(full_index_path_prefix.c_str(), num_threads, 1); + + std::cout << "parsing labels" << std::endl; + + index.prune_all_neighbors(stitched_R, 750, 1.2); + index.save((final_index_path_prefix).c_str()); + + diskann::cout.rdbuf(diskann_cout_buffer); + std::cout.rdbuf(std_cout_buffer); + std::chrono::duration pruning_index_time = std::chrono::high_resolution_clock::now() - pruning_index_timer; + std::cout << "pruning performed in " << pruning_index_time.count() << " seconds\n" << std::endl; } /* @@ -340,160 +311,131 @@ void prune_and_save(path final_index_path_prefix, path full_index_path_prefix, * 2. the separate diskANN indices built for each label * 3. 
the '.data' file created while generating the indices */ -void clean_up_artifacts(path input_data_path, path final_index_path_prefix, - label_set all_labels) { - for (const auto &lbl : all_labels) { - path curr_label_input_data_path(input_data_path + "_" + lbl); - path curr_label_index_path(final_index_path_prefix + "_" + lbl); - path curr_label_index_path_data(curr_label_index_path + ".data"); - - if (std::remove(curr_label_index_path.c_str()) != 0) - throw; - if (std::remove(curr_label_input_data_path.c_str()) != 0) - throw; - if (std::remove(curr_label_index_path_data.c_str()) != 0) - throw; - } +void clean_up_artifacts(path input_data_path, path final_index_path_prefix, label_set all_labels) +{ + for (const auto &lbl : all_labels) + { + path curr_label_input_data_path(input_data_path + "_" + lbl); + path curr_label_index_path(final_index_path_prefix + "_" + lbl); + path curr_label_index_path_data(curr_label_index_path + ".data"); + + if (std::remove(curr_label_index_path.c_str()) != 0) + throw; + if (std::remove(curr_label_input_data_path.c_str()) != 0) + throw; + if (std::remove(curr_label_index_path_data.c_str()) != 0) + throw; + } } -int main(int argc, char **argv) { - // 1. handle cmdline inputs - std::string data_type; - path input_data_path, final_index_path_prefix, label_data_path; - std::string universal_label; - uint32_t num_threads, R, L, stitched_R; - float alpha; +int main(int argc, char **argv) +{ + // 1. handle cmdline inputs + std::string data_type; + path input_data_path, final_index_path_prefix, label_data_path; + std::string universal_label; + uint32_t num_threads, R, L, stitched_R; + float alpha; - auto index_timer = std::chrono::high_resolution_clock::now(); - handle_args(argc, argv, data_type, input_data_path, final_index_path_prefix, - label_data_path, universal_label, num_threads, R, L, stitched_R, - alpha); + auto index_timer = std::chrono::high_resolution_clock::now(); + handle_args(argc, argv, data_type, input_data_path, final_index_path_prefix, label_data_path, universal_label, + num_threads, R, L, stitched_R, alpha); - path labels_file_to_use = final_index_path_prefix + "_label_formatted.txt"; - path labels_map_file = final_index_path_prefix + "_labels_map.txt"; + path labels_file_to_use = final_index_path_prefix + "_label_formatted.txt"; + path labels_map_file = final_index_path_prefix + "_labels_map.txt"; - convert_labels_string_to_int(label_data_path, labels_file_to_use, - labels_map_file, universal_label); + convert_labels_string_to_int(label_data_path, labels_file_to_use, labels_map_file, universal_label); - // 2. parse label file and create necessary data structures - std::vector point_ids_to_labels; - tsl::robin_map labels_to_number_of_points; - label_set all_labels; + // 2. parse label file and create necessary data structures + std::vector point_ids_to_labels; + tsl::robin_map labels_to_number_of_points; + label_set all_labels; - std::tie(point_ids_to_labels, labels_to_number_of_points, all_labels) = - diskann::parse_label_file(labels_file_to_use, universal_label); + std::tie(point_ids_to_labels, labels_to_number_of_points, all_labels) = + diskann::parse_label_file(labels_file_to_use, universal_label); - // 3. for each label, make a separate data file - tsl::robin_map> label_id_to_orig_id_map; - uint32_t total_number_of_points = (uint32_t)point_ids_to_labels.size(); + // 3. 
for each label, make a separate data file + tsl::robin_map> label_id_to_orig_id_map; + uint32_t total_number_of_points = (uint32_t)point_ids_to_labels.size(); #ifndef _WINDOWS - if (data_type == "uint8") - label_id_to_orig_id_map = - diskann::generate_label_specific_vector_files( - input_data_path, labels_to_number_of_points, point_ids_to_labels, - all_labels); - else if (data_type == "int8") - label_id_to_orig_id_map = - diskann::generate_label_specific_vector_files( - input_data_path, labels_to_number_of_points, point_ids_to_labels, - all_labels); - else if (data_type == "float") - label_id_to_orig_id_map = - diskann::generate_label_specific_vector_files( - input_data_path, labels_to_number_of_points, point_ids_to_labels, - all_labels); - else - throw; + if (data_type == "uint8") + label_id_to_orig_id_map = diskann::generate_label_specific_vector_files( + input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); + else if (data_type == "int8") + label_id_to_orig_id_map = diskann::generate_label_specific_vector_files( + input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); + else if (data_type == "float") + label_id_to_orig_id_map = diskann::generate_label_specific_vector_files( + input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); + else + throw; #else - if (data_type == "uint8") - label_id_to_orig_id_map = - diskann::generate_label_specific_vector_files_compat( - input_data_path, labels_to_number_of_points, point_ids_to_labels, - all_labels); - else if (data_type == "int8") - label_id_to_orig_id_map = - diskann::generate_label_specific_vector_files_compat( - input_data_path, labels_to_number_of_points, point_ids_to_labels, - all_labels); - else if (data_type == "float") - label_id_to_orig_id_map = - diskann::generate_label_specific_vector_files_compat( - input_data_path, labels_to_number_of_points, point_ids_to_labels, - all_labels); - else - throw; + if (data_type == "uint8") + label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat( + input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); + else if (data_type == "int8") + label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat( + input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); + else if (data_type == "float") + label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat( + input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); + else + throw; #endif - // 4. for each created data file, create a vanilla diskANN index - if (data_type == "uint8") - diskann::generate_label_indices( - input_data_path, final_index_path_prefix, all_labels, R, L, alpha, - num_threads); - else if (data_type == "int8") - diskann::generate_label_indices(input_data_path, - final_index_path_prefix, all_labels, - R, L, alpha, num_threads); - else if (data_type == "float") - diskann::generate_label_indices(input_data_path, - final_index_path_prefix, all_labels, - R, L, alpha, num_threads); - else - throw; - - // 5. 
"stitch" the indices together - std::vector> stitched_graph; - tsl::robin_map label_entry_points; - uint64_t stitched_graph_size; - - if (data_type == "uint8") - std::tie(stitched_graph, stitched_graph_size) = - stitch_label_indices( - final_index_path_prefix, total_number_of_points, all_labels, - labels_to_number_of_points, label_entry_points, - label_id_to_orig_id_map); - else if (data_type == "int8") - std::tie(stitched_graph, stitched_graph_size) = - stitch_label_indices( - final_index_path_prefix, total_number_of_points, all_labels, - labels_to_number_of_points, label_entry_points, - label_id_to_orig_id_map); - else if (data_type == "float") - std::tie(stitched_graph, stitched_graph_size) = stitch_label_indices( - final_index_path_prefix, total_number_of_points, all_labels, - labels_to_number_of_points, label_entry_points, - label_id_to_orig_id_map); - else - throw; - path full_index_path_prefix = final_index_path_prefix + "_full"; - // 5a. save the stitched graph to disk - save_full_index(full_index_path_prefix, input_data_path, stitched_graph_size, - stitched_graph, label_entry_points, universal_label, - labels_file_to_use); - - // 6. run a prune on the stitched index, and save to disk - if (data_type == "uint8") - prune_and_save(final_index_path_prefix, full_index_path_prefix, - input_data_path, stitched_graph, stitched_R, - label_entry_points, universal_label, - labels_file_to_use, num_threads); - else if (data_type == "int8") - prune_and_save(final_index_path_prefix, full_index_path_prefix, - input_data_path, stitched_graph, stitched_R, - label_entry_points, universal_label, - labels_file_to_use, num_threads); - else if (data_type == "float") - prune_and_save(final_index_path_prefix, full_index_path_prefix, - input_data_path, stitched_graph, stitched_R, - label_entry_points, universal_label, - labels_file_to_use, num_threads); - else - throw; - - std::chrono::duration index_time = - std::chrono::high_resolution_clock::now() - index_timer; - std::cout << "pruned/stitched graph generated in " << index_time.count() - << " seconds" << std::endl; - - clean_up_artifacts(input_data_path, final_index_path_prefix, all_labels); + // 4. for each created data file, create a vanilla diskANN index + if (data_type == "uint8") + diskann::generate_label_indices(input_data_path, final_index_path_prefix, all_labels, R, L, alpha, + num_threads); + else if (data_type == "int8") + diskann::generate_label_indices(input_data_path, final_index_path_prefix, all_labels, R, L, alpha, + num_threads); + else if (data_type == "float") + diskann::generate_label_indices(input_data_path, final_index_path_prefix, all_labels, R, L, alpha, + num_threads); + else + throw; + + // 5. 
"stitch" the indices together + std::vector> stitched_graph; + tsl::robin_map label_entry_points; + uint64_t stitched_graph_size; + + if (data_type == "uint8") + std::tie(stitched_graph, stitched_graph_size) = + stitch_label_indices(final_index_path_prefix, total_number_of_points, all_labels, + labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map); + else if (data_type == "int8") + std::tie(stitched_graph, stitched_graph_size) = + stitch_label_indices(final_index_path_prefix, total_number_of_points, all_labels, + labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map); + else if (data_type == "float") + std::tie(stitched_graph, stitched_graph_size) = + stitch_label_indices(final_index_path_prefix, total_number_of_points, all_labels, + labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map); + else + throw; + path full_index_path_prefix = final_index_path_prefix + "_full"; + // 5a. save the stitched graph to disk + save_full_index(full_index_path_prefix, input_data_path, stitched_graph_size, stitched_graph, label_entry_points, + universal_label, labels_file_to_use); + + // 6. run a prune on the stitched index, and save to disk + if (data_type == "uint8") + prune_and_save(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph, + stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads); + else if (data_type == "int8") + prune_and_save(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph, + stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads); + else if (data_type == "float") + prune_and_save(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph, + stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads); + else + throw; + + std::chrono::duration index_time = std::chrono::high_resolution_clock::now() - index_timer; + std::cout << "pruned/stitched graph generated in " << index_time.count() << " seconds" << std::endl; + + clean_up_artifacts(input_data_path, final_index_path_prefix, all_labels); } diff --git a/apps/range_search_disk_index.cpp b/apps/range_search_disk_index.cpp index bfdfafa03..3975298ae 100644 --- a/apps/range_search_disk_index.cpp +++ b/apps/range_search_disk_index.cpp @@ -34,348 +34,346 @@ namespace po = boost::program_options; #define WARMUP false -void print_stats(std::string category, std::vector percentiles, - std::vector results) { - diskann::cout << std::setw(20) << category << ": " << std::flush; - for (uint32_t s = 0; s < percentiles.size(); s++) { - diskann::cout << std::setw(8) << percentiles[s] << "%"; - } - diskann::cout << std::endl; - diskann::cout << std::setw(22) << " " << std::flush; - for (uint32_t s = 0; s < percentiles.size(); s++) { - diskann::cout << std::setw(9) << results[s]; - } - diskann::cout << std::endl; +void print_stats(std::string category, std::vector percentiles, std::vector results) +{ + diskann::cout << std::setw(20) << category << ": " << std::flush; + for (uint32_t s = 0; s < percentiles.size(); s++) + { + diskann::cout << std::setw(8) << percentiles[s] << "%"; + } + diskann::cout << std::endl; + diskann::cout << std::setw(22) << " " << std::flush; + for (uint32_t s = 0; s < percentiles.size(); s++) + { + diskann::cout << std::setw(9) << results[s]; + } + diskann::cout << std::endl; } template -int search_disk_index(diskann::Metric &metric, - const std::string &index_path_prefix, - const std::string &query_file, std::string >_file, 
-                      const uint32_t num_threads, const float search_range,
-                      const uint32_t beamwidth,
-                      const uint32_t num_nodes_to_cache,
-                      const std::vector<uint32_t> &Lvec) {
-  std::string pq_prefix = index_path_prefix + "_pq";
-  std::string disk_index_file = index_path_prefix + "_disk.index";
-  std::string warmup_query_file = index_path_prefix + "_sample_data.bin";
-
-  diskann::cout << "Search parameters: #threads: " << num_threads << ", ";
-  if (beamwidth <= 0)
-    diskann::cout << "beamwidth to be optimized for each L value" << std::endl;
-  else
-    diskann::cout << " beamwidth: " << beamwidth << std::endl;
-
-  // load query bin
-  T *query = nullptr;
-  std::vector<std::vector<uint32_t>> groundtruth_ids;
-  size_t query_num, query_dim, query_aligned_dim, gt_num;
-  diskann::load_aligned_bin<T>(query_file, query, query_num, query_dim,
-                               query_aligned_dim);
-
-  bool calc_recall_flag = false;
-  if (gt_file != std::string("null") && file_exists(gt_file)) {
-    diskann::load_range_truthset(
-        gt_file, groundtruth_ids,
-        gt_num); // use for range search type of truthset
-    // diskann::prune_truthset_for_range(gt_file, search_range,
-    // groundtruth_ids, gt_num); // use for traditional truthset
-    if (gt_num != query_num) {
-      diskann::cout
-          << "Error. Mismatch in number of queries and ground truth data"
-          << std::endl;
-      return -1;
+int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix, const std::string &query_file,
+                      std::string &gt_file, const uint32_t num_threads, const float search_range,
+                      const uint32_t beamwidth, const uint32_t num_nodes_to_cache, const std::vector<uint32_t> &Lvec)
+{
+    std::string pq_prefix = index_path_prefix + "_pq";
+    std::string disk_index_file = index_path_prefix + "_disk.index";
+    std::string warmup_query_file = index_path_prefix + "_sample_data.bin";
+
+    diskann::cout << "Search parameters: #threads: " << num_threads << ", ";
+    if (beamwidth <= 0)
+        diskann::cout << "beamwidth to be optimized for each L value" << std::endl;
+    else
+        diskann::cout << " beamwidth: " << beamwidth << std::endl;
+
+    // load query bin
+    T *query = nullptr;
+    std::vector<std::vector<uint32_t>> groundtruth_ids;
+    size_t query_num, query_dim, query_aligned_dim, gt_num;
+    diskann::load_aligned_bin<T>(query_file, query, query_num, query_dim, query_aligned_dim);
+
+    bool calc_recall_flag = false;
+    if (gt_file != std::string("null") && file_exists(gt_file))
+    {
+        diskann::load_range_truthset(gt_file, groundtruth_ids,
+                                     gt_num); // use for range search type of truthset
+        // diskann::prune_truthset_for_range(gt_file, search_range,
+        // groundtruth_ids, gt_num); // use for traditional truthset
+        if (gt_num != query_num)
+        {
+            diskann::cout << "Error. 
Mismatch in number of queries and ground truth data" << std::endl; + return -1; + } + calc_recall_flag = true; } - calc_recall_flag = true; - } - std::shared_ptr reader = nullptr; + std::shared_ptr reader = nullptr; #ifdef _WINDOWS #ifndef USE_BING_INFRA - reader.reset(new WindowsAlignedFileReader()); + reader.reset(new WindowsAlignedFileReader()); #else - reader.reset(new diskann::BingAlignedFileReader()); + reader.reset(new diskann::BingAlignedFileReader()); #endif #else - reader.reset(new LinuxAlignedFileReader()); + reader.reset(new LinuxAlignedFileReader()); #endif - std::unique_ptr> _pFlashIndex( - new diskann::PQFlashIndex(reader, metric)); - - int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str()); - - if (res != 0) { - return res; - } - // cache bfs levels - std::vector node_list; - diskann::cout << "Caching " << num_nodes_to_cache - << " BFS nodes around medoid(s)" << std::endl; - _pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list); - // _pFlashIndex->generate_cache_list_from_sample_queries( - // warmup_query_file, 15, 6, num_nodes_to_cache, num_threads, - // node_list); - _pFlashIndex->load_cache_list(node_list); - node_list.clear(); - node_list.shrink_to_fit(); - - omp_set_num_threads(num_threads); - - uint64_t warmup_L = 20; - uint64_t warmup_num = 0, warmup_dim = 0, warmup_aligned_dim = 0; - T *warmup = nullptr; - - if (WARMUP) { - if (file_exists(warmup_query_file)) { - diskann::load_aligned_bin(warmup_query_file, warmup, warmup_num, - warmup_dim, warmup_aligned_dim); - } else { - warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads); - warmup_dim = query_dim; - warmup_aligned_dim = query_aligned_dim; - diskann::alloc_aligned(((void **)&warmup), - warmup_num * warmup_aligned_dim * sizeof(T), - 8 * sizeof(T)); - std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T)); - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution<> dis(-128, 127); - for (uint32_t i = 0; i < warmup_num; i++) { - for (uint32_t d = 0; d < warmup_dim; d++) { - warmup[i * warmup_aligned_dim + d] = (T)dis(gen); - } - } + std::unique_ptr> _pFlashIndex( + new diskann::PQFlashIndex(reader, metric)); + + int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str()); + + if (res != 0) + { + return res; } - diskann::cout << "Warming up index... 
" << std::flush; - std::vector warmup_result_ids_64(warmup_num, 0); - std::vector warmup_result_dists(warmup_num, 0); + // cache bfs levels + std::vector node_list; + diskann::cout << "Caching " << num_nodes_to_cache << " BFS nodes around medoid(s)" << std::endl; + _pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list); + // _pFlashIndex->generate_cache_list_from_sample_queries( + // warmup_query_file, 15, 6, num_nodes_to_cache, num_threads, + // node_list); + _pFlashIndex->load_cache_list(node_list); + node_list.clear(); + node_list.shrink_to_fit(); + + omp_set_num_threads(num_threads); + + uint64_t warmup_L = 20; + uint64_t warmup_num = 0, warmup_dim = 0, warmup_aligned_dim = 0; + T *warmup = nullptr; + + if (WARMUP) + { + if (file_exists(warmup_query_file)) + { + diskann::load_aligned_bin(warmup_query_file, warmup, warmup_num, warmup_dim, warmup_aligned_dim); + } + else + { + warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads); + warmup_dim = query_dim; + warmup_aligned_dim = query_aligned_dim; + diskann::alloc_aligned(((void **)&warmup), warmup_num * warmup_aligned_dim * sizeof(T), 8 * sizeof(T)); + std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T)); + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(-128, 127); + for (uint32_t i = 0; i < warmup_num; i++) + { + for (uint32_t d = 0; d < warmup_dim; d++) + { + warmup[i * warmup_aligned_dim + d] = (T)dis(gen); + } + } + } + diskann::cout << "Warming up index... " << std::flush; + std::vector warmup_result_ids_64(warmup_num, 0); + std::vector warmup_result_dists(warmup_num, 0); #pragma omp parallel for schedule(dynamic, 1) - for (int64_t i = 0; i < (int64_t)warmup_num; i++) { - _pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, - warmup_L, - warmup_result_ids_64.data() + (i * 1), - warmup_result_dists.data() + (i * 1), 4); + for (int64_t i = 0; i < (int64_t)warmup_num; i++) + { + _pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, warmup_L, + warmup_result_ids_64.data() + (i * 1), + warmup_result_dists.data() + (i * 1), 4); + } + diskann::cout << "..done" << std::endl; } - diskann::cout << "..done" << std::endl; - } - - diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield); - diskann::cout.precision(2); - - std::string recall_string = "Recall@rng=" + std::to_string(search_range); - diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" - << std::setw(16) << "QPS" << std::setw(16) << "Mean Latency" - << std::setw(16) << "99.9 Latency" << std::setw(16) - << "Mean IOs" << std::setw(16) << "CPU (s)"; - if (calc_recall_flag) { - diskann::cout << std::setw(16) << recall_string << std::endl; - } else - diskann::cout << std::endl; - diskann::cout - << "===============================================================" - "===========================================" - << std::endl; - std::vector>> query_result_ids(Lvec.size()); + diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield); + diskann::cout.precision(2); + + std::string recall_string = "Recall@rng=" + std::to_string(search_range); + diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" << std::setw(16) << "QPS" << std::setw(16) + << "Mean Latency" << std::setw(16) << "99.9 Latency" << std::setw(16) << "Mean IOs" << std::setw(16) + << "CPU (s)"; + if (calc_recall_flag) + { + diskann::cout << std::setw(16) << recall_string << std::endl; + } + else + diskann::cout << std::endl; + diskann::cout << 
"===============================================================" + "===========================================" + << std::endl; + + std::vector>> query_result_ids(Lvec.size()); - uint32_t optimized_beamwidth = 2; - uint32_t max_list_size = 10000; + uint32_t optimized_beamwidth = 2; + uint32_t max_list_size = 10000; - for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++) { - uint32_t L = Lvec[test_id]; + for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++) + { + uint32_t L = Lvec[test_id]; - if (beamwidth <= 0) { - optimized_beamwidth = - optimize_beamwidth(_pFlashIndex, warmup, warmup_num, - warmup_aligned_dim, L, optimized_beamwidth); - } else - optimized_beamwidth = beamwidth; + if (beamwidth <= 0) + { + optimized_beamwidth = + optimize_beamwidth(_pFlashIndex, warmup, warmup_num, warmup_aligned_dim, L, optimized_beamwidth); + } + else + optimized_beamwidth = beamwidth; - query_result_ids[test_id].clear(); - query_result_ids[test_id].resize(query_num); + query_result_ids[test_id].clear(); + query_result_ids[test_id].resize(query_num); - diskann::QueryStats *stats = new diskann::QueryStats[query_num]; + diskann::QueryStats *stats = new diskann::QueryStats[query_num]; - auto s = std::chrono::high_resolution_clock::now(); + auto s = std::chrono::high_resolution_clock::now(); #pragma omp parallel for schedule(dynamic, 1) - for (int64_t i = 0; i < (int64_t)query_num; i++) { - std::vector indices; - std::vector distances; - uint32_t res_count = _pFlashIndex->range_search( - query + (i * query_aligned_dim), search_range, L, max_list_size, - indices, distances, optimized_beamwidth, stats + i); - query_result_ids[test_id][i].reserve(res_count); - query_result_ids[test_id][i].resize(res_count); - for (uint32_t idx = 0; idx < res_count; idx++) - query_result_ids[test_id][i][idx] = (uint32_t)indices[idx]; - } - auto e = std::chrono::high_resolution_clock::now(); - std::chrono::duration diff = e - s; - auto qps = (1.0 * query_num) / (1.0 * diff.count()); - - auto mean_latency = diskann::get_mean_stats( - stats, query_num, - [](const diskann::QueryStats &stats) { return stats.total_us; }); - - auto latency_999 = diskann::get_percentile_stats( - stats, query_num, 0.999, - [](const diskann::QueryStats &stats) { return stats.total_us; }); - - auto mean_ios = diskann::get_mean_stats( - stats, query_num, - [](const diskann::QueryStats &stats) { return stats.n_ios; }); - - double mean_cpuus = diskann::get_mean_stats( - stats, query_num, - [](const diskann::QueryStats &stats) { return stats.cpu_us; }); - - double recall = 0; - double ratio_of_sums = 0; - if (calc_recall_flag) { - recall = diskann::calculate_range_search_recall( - (uint32_t)query_num, groundtruth_ids, query_result_ids[test_id]); - - uint32_t total_true_positive = 0; - uint32_t total_positive = 0; - for (uint32_t i = 0; i < query_num; i++) { - total_true_positive += (uint32_t)query_result_ids[test_id][i].size(); - total_positive += (uint32_t)groundtruth_ids[i].size(); - } - - ratio_of_sums = (1.0 * total_true_positive) / (1.0 * total_positive); + for (int64_t i = 0; i < (int64_t)query_num; i++) + { + std::vector indices; + std::vector distances; + uint32_t res_count = + _pFlashIndex->range_search(query + (i * query_aligned_dim), search_range, L, max_list_size, indices, + distances, optimized_beamwidth, stats + i); + query_result_ids[test_id][i].reserve(res_count); + query_result_ids[test_id][i].resize(res_count); + for (uint32_t idx = 0; idx < res_count; idx++) + query_result_ids[test_id][i][idx] = (uint32_t)indices[idx]; + } + 
auto e = std::chrono::high_resolution_clock::now(); + std::chrono::duration diff = e - s; + auto qps = (1.0 * query_num) / (1.0 * diff.count()); + + auto mean_latency = diskann::get_mean_stats( + stats, query_num, [](const diskann::QueryStats &stats) { return stats.total_us; }); + + auto latency_999 = diskann::get_percentile_stats( + stats, query_num, 0.999, [](const diskann::QueryStats &stats) { return stats.total_us; }); + + auto mean_ios = diskann::get_mean_stats(stats, query_num, + [](const diskann::QueryStats &stats) { return stats.n_ios; }); + + double mean_cpuus = diskann::get_mean_stats( + stats, query_num, [](const diskann::QueryStats &stats) { return stats.cpu_us; }); + + double recall = 0; + double ratio_of_sums = 0; + if (calc_recall_flag) + { + recall = + diskann::calculate_range_search_recall((uint32_t)query_num, groundtruth_ids, query_result_ids[test_id]); + + uint32_t total_true_positive = 0; + uint32_t total_positive = 0; + for (uint32_t i = 0; i < query_num; i++) + { + total_true_positive += (uint32_t)query_result_ids[test_id][i].size(); + total_positive += (uint32_t)groundtruth_ids[i].size(); + } + + ratio_of_sums = (1.0 * total_true_positive) / (1.0 * total_positive); + } + + diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth << std::setw(16) << qps + << std::setw(16) << mean_latency << std::setw(16) << latency_999 << std::setw(16) << mean_ios + << std::setw(16) << mean_cpuus; + if (calc_recall_flag) + { + diskann::cout << std::setw(16) << recall << "," << ratio_of_sums << std::endl; + } + else + diskann::cout << std::endl; } - diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth - << std::setw(16) << qps << std::setw(16) << mean_latency - << std::setw(16) << latency_999 << std::setw(16) << mean_ios - << std::setw(16) << mean_cpuus; - if (calc_recall_flag) { - diskann::cout << std::setw(16) << recall << "," << ratio_of_sums - << std::endl; - } else - diskann::cout << std::endl; - } - - diskann::cout << "Done searching. " << std::endl; - - diskann::aligned_free(query); - if (warmup != nullptr) - diskann::aligned_free(warmup); - return 0; + diskann::cout << "Done searching. 
" << std::endl; + + diskann::aligned_free(query); + if (warmup != nullptr) + diskann::aligned_free(warmup); + return 0; } -int main(int argc, char **argv) { - std::string data_type, dist_fn, index_path_prefix, result_path_prefix, - query_file, gt_file; - uint32_t num_threads, W, num_nodes_to_cache; - std::vector Lvec; - float range; - - po::options_description desc{program_options_utils::make_program_description( - "range_search_disk_index", "Searches disk DiskANN indexes using ranges")}; - try { - desc.add_options()("help,h", "Print information on arguments"); - - // Required parameters - po::options_description required_configs("Required"); - required_configs.add_options()( - "data_type", po::value(&data_type)->required(), - program_options_utils::DATA_TYPE_DESCRIPTION); - required_configs.add_options()( - "dist_fn", po::value(&dist_fn)->required(), - program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); - required_configs.add_options()( - "index_path_prefix", - po::value(&index_path_prefix)->required(), - program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); - required_configs.add_options()( - "query_file", po::value(&query_file)->required(), - program_options_utils::QUERY_FILE_DESCRIPTION); - required_configs.add_options()( - "search_list,L", - po::value>(&Lvec)->multitoken()->required(), - program_options_utils::SEARCH_LIST_DESCRIPTION); - required_configs.add_options()("range_threshold,K", - po::value(&range)->required(), - "Number of neighbors to be returned"); - - // Optional parameters - po::options_description optional_configs("Optional"); - optional_configs.add_options()( - "num_threads,T", - po::value(&num_threads)->default_value(omp_get_num_procs()), - program_options_utils::NUMBER_THREADS_DESCRIPTION); - optional_configs.add_options()( - "gt_file", - po::value(>_file)->default_value(std::string("null")), - program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION); - optional_configs.add_options()( - "num_nodes_to_cache", - po::value(&num_nodes_to_cache)->default_value(0), - program_options_utils::NUMBER_OF_NODES_TO_CACHE); - optional_configs.add_options()("beamwidth,W", - po::value(&W)->default_value(2), - program_options_utils::BEAMWIDTH); - - // Merge required and optional parameters - desc.add(required_configs).add(optional_configs); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) { - std::cout << desc; - return 0; +int main(int argc, char **argv) +{ + std::string data_type, dist_fn, index_path_prefix, result_path_prefix, query_file, gt_file; + uint32_t num_threads, W, num_nodes_to_cache; + std::vector Lvec; + float range; + + po::options_description desc{program_options_utils::make_program_description( + "range_search_disk_index", "Searches disk DiskANN indexes using ranges")}; + try + { + desc.add_options()("help,h", "Print information on arguments"); + + // Required parameters + po::options_description required_configs("Required"); + required_configs.add_options()("data_type", po::value(&data_type)->required(), + program_options_utils::DATA_TYPE_DESCRIPTION); + required_configs.add_options()("dist_fn", po::value(&dist_fn)->required(), + program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); + required_configs.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), + program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); + required_configs.add_options()("query_file", po::value(&query_file)->required(), + program_options_utils::QUERY_FILE_DESCRIPTION); + 
required_configs.add_options()("search_list,L", + po::value>(&Lvec)->multitoken()->required(), + program_options_utils::SEARCH_LIST_DESCRIPTION); + required_configs.add_options()("range_threshold,K", po::value(&range)->required(), + "Number of neighbors to be returned"); + + // Optional parameters + po::options_description optional_configs("Optional"); + optional_configs.add_options()("num_threads,T", + po::value(&num_threads)->default_value(omp_get_num_procs()), + program_options_utils::NUMBER_THREADS_DESCRIPTION); + optional_configs.add_options()("gt_file", po::value(>_file)->default_value(std::string("null")), + program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION); + optional_configs.add_options()("num_nodes_to_cache", po::value(&num_nodes_to_cache)->default_value(0), + program_options_utils::NUMBER_OF_NODES_TO_CACHE); + optional_configs.add_options()("beamwidth,W", po::value(&W)->default_value(2), + program_options_utils::BEAMWIDTH); + + // Merge required and optional parameters + desc.add(required_configs).add(optional_configs); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) + { + std::cout << desc; + return 0; + } + po::notify(vm); + } + catch (const std::exception &ex) + { + std::cerr << ex.what() << '\n'; + return -1; + } + + diskann::Metric metric; + if (dist_fn == std::string("mips")) + { + metric = diskann::Metric::INNER_PRODUCT; + } + else if (dist_fn == std::string("l2")) + { + metric = diskann::Metric::L2; + } + else if (dist_fn == std::string("cosine")) + { + metric = diskann::Metric::COSINE; + } + else + { + std::cout << "Unsupported distance function. Currently only L2/ Inner " + "Product/Cosine are supported." + << std::endl; + return -1; + } + + if ((data_type != std::string("float")) && (metric == diskann::Metric::INNER_PRODUCT)) + { + std::cout << "Currently support only floating point data for Inner Product." << std::endl; + return -1; + } + + try + { + if (data_type == std::string("float")) + return search_disk_index(metric, index_path_prefix, query_file, gt_file, num_threads, range, W, + num_nodes_to_cache, Lvec); + else if (data_type == std::string("int8")) + return search_disk_index(metric, index_path_prefix, query_file, gt_file, num_threads, range, W, + num_nodes_to_cache, Lvec); + else if (data_type == std::string("uint8")) + return search_disk_index(metric, index_path_prefix, query_file, gt_file, num_threads, range, W, + num_nodes_to_cache, Lvec); + else + { + std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl; + return -1; + } } - po::notify(vm); - } catch (const std::exception &ex) { - std::cerr << ex.what() << '\n'; - return -1; - } - - diskann::Metric metric; - if (dist_fn == std::string("mips")) { - metric = diskann::Metric::INNER_PRODUCT; - } else if (dist_fn == std::string("l2")) { - metric = diskann::Metric::L2; - } else if (dist_fn == std::string("cosine")) { - metric = diskann::Metric::COSINE; - } else { - std::cout << "Unsupported distance function. Currently only L2/ Inner " - "Product/Cosine are supported." - << std::endl; - return -1; - } - - if ((data_type != std::string("float")) && - (metric == diskann::Metric::INNER_PRODUCT)) { - std::cout << "Currently support only floating point data for Inner Product." 
- << std::endl; - return -1; - } - - try { - if (data_type == std::string("float")) - return search_disk_index(metric, index_path_prefix, query_file, - gt_file, num_threads, range, W, - num_nodes_to_cache, Lvec); - else if (data_type == std::string("int8")) - return search_disk_index(metric, index_path_prefix, query_file, - gt_file, num_threads, range, W, - num_nodes_to_cache, Lvec); - else if (data_type == std::string("uint8")) - return search_disk_index(metric, index_path_prefix, query_file, - gt_file, num_threads, range, W, - num_nodes_to_cache, Lvec); - else { - std::cerr << "Unsupported data type. Use float or int8 or uint8" - << std::endl; - return -1; + catch (const std::exception &e) + { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Index search failed." << std::endl; + return -1; } - } catch (const std::exception &e) { - std::cout << std::string(e.what()) << std::endl; - diskann::cerr << "Index search failed." << std::endl; - return -1; - } } diff --git a/apps/search_disk_index.cpp b/apps/search_disk_index.cpp index 307fa8e7a..925f31775 100644 --- a/apps/search_disk_index.cpp +++ b/apps/search_disk_index.cpp @@ -31,476 +31,466 @@ namespace po = boost::program_options; -void print_stats(std::string category, std::vector percentiles, - std::vector results) { - diskann::cout << std::setw(20) << category << ": " << std::flush; - for (uint32_t s = 0; s < percentiles.size(); s++) { - diskann::cout << std::setw(8) << percentiles[s] << "%"; - } - diskann::cout << std::endl; - diskann::cout << std::setw(22) << " " << std::flush; - for (uint32_t s = 0; s < percentiles.size(); s++) { - diskann::cout << std::setw(9) << results[s]; - } - diskann::cout << std::endl; +void print_stats(std::string category, std::vector percentiles, std::vector results) +{ + diskann::cout << std::setw(20) << category << ": " << std::flush; + for (uint32_t s = 0; s < percentiles.size(); s++) + { + diskann::cout << std::setw(8) << percentiles[s] << "%"; + } + diskann::cout << std::endl; + diskann::cout << std::setw(22) << " " << std::flush; + for (uint32_t s = 0; s < percentiles.size(); s++) + { + diskann::cout << std::setw(9) << results[s]; + } + diskann::cout << std::endl; } template -int search_disk_index( - diskann::Metric &metric, const std::string &index_path_prefix, - const std::string &result_output_prefix, const std::string &query_file, - std::string >_file, const uint32_t num_threads, const uint32_t recall_at, - const uint32_t beamwidth, const uint32_t num_nodes_to_cache, - const uint32_t search_io_limit, const std::vector &Lvec, - const float fail_if_recall_below, - const std::vector &query_filters, - const bool use_reorder_data = false) { - diskann::cout << "Search parameters: #threads: " << num_threads << ", "; - if (beamwidth <= 0) - diskann::cout << "beamwidth to be optimized for each L value" << std::flush; - else - diskann::cout << " beamwidth: " << beamwidth << std::flush; - if (search_io_limit == std::numeric_limits::max()) - diskann::cout << "." << std::endl; - else - diskann::cout << ", io_limit: " << search_io_limit << "." 
<< std::endl; - - std::string warmup_query_file = index_path_prefix + "_sample_data.bin"; - - // load query bin - T *query = nullptr; - uint32_t *gt_ids = nullptr; - float *gt_dists = nullptr; - size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim; - diskann::load_aligned_bin(query_file, query, query_num, query_dim, - query_aligned_dim); - - bool filtered_search = false; - if (!query_filters.empty()) { - filtered_search = true; - if (query_filters.size() != 1 && query_filters.size() != query_num) { - std::cout << "Error. Mismatch in number of queries and size of query " - "filters file" - << std::endl; - return -1; // To return -1 or some other error handling? +int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix, + const std::string &result_output_prefix, const std::string &query_file, std::string >_file, + const uint32_t num_threads, const uint32_t recall_at, const uint32_t beamwidth, + const uint32_t num_nodes_to_cache, const uint32_t search_io_limit, + const std::vector &Lvec, const float fail_if_recall_below, + const std::vector &query_filters, const bool use_reorder_data = false) +{ + diskann::cout << "Search parameters: #threads: " << num_threads << ", "; + if (beamwidth <= 0) + diskann::cout << "beamwidth to be optimized for each L value" << std::flush; + else + diskann::cout << " beamwidth: " << beamwidth << std::flush; + if (search_io_limit == std::numeric_limits::max()) + diskann::cout << "." << std::endl; + else + diskann::cout << ", io_limit: " << search_io_limit << "." << std::endl; + + std::string warmup_query_file = index_path_prefix + "_sample_data.bin"; + + // load query bin + T *query = nullptr; + uint32_t *gt_ids = nullptr; + float *gt_dists = nullptr; + size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim; + diskann::load_aligned_bin(query_file, query, query_num, query_dim, query_aligned_dim); + + bool filtered_search = false; + if (!query_filters.empty()) + { + filtered_search = true; + if (query_filters.size() != 1 && query_filters.size() != query_num) + { + std::cout << "Error. Mismatch in number of queries and size of query " + "filters file" + << std::endl; + return -1; // To return -1 or some other error handling? + } } - } - - bool calc_recall_flag = false; - if (gt_file != std::string("null") && gt_file != std::string("NULL") && - file_exists(gt_file)) { - diskann::load_truthset(gt_file, gt_ids, gt_dists, gt_num, gt_dim); - if (gt_num != query_num) { - diskann::cout - << "Error. Mismatch in number of queries and ground truth data" - << std::endl; + + bool calc_recall_flag = false; + if (gt_file != std::string("null") && gt_file != std::string("NULL") && file_exists(gt_file)) + { + diskann::load_truthset(gt_file, gt_ids, gt_dists, gt_num, gt_dim); + if (gt_num != query_num) + { + diskann::cout << "Error. 
Mismatch in number of queries and ground truth data" << std::endl; + } + calc_recall_flag = true; } - calc_recall_flag = true; - } - std::shared_ptr reader = nullptr; + std::shared_ptr reader = nullptr; #ifdef _WINDOWS #ifndef USE_BING_INFRA - reader.reset(new WindowsAlignedFileReader()); + reader.reset(new WindowsAlignedFileReader()); #else - reader.reset(new diskann::BingAlignedFileReader()); + reader.reset(new diskann::BingAlignedFileReader()); #endif #else - reader.reset(new LinuxAlignedFileReader()); + reader.reset(new LinuxAlignedFileReader()); #endif - std::unique_ptr> _pFlashIndex( - new diskann::PQFlashIndex(reader, metric)); - - int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str()); - - if (res != 0) { - return res; - } - - std::vector node_list; - diskann::cout << "Caching " << num_nodes_to_cache << " nodes around medoid(s)" - << std::endl; - _pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list); - // if (num_nodes_to_cache > 0) - // _pFlashIndex->generate_cache_list_from_sample_queries(warmup_query_file, - // 15, 6, num_nodes_to_cache, num_threads, node_list); - _pFlashIndex->load_cache_list(node_list); - node_list.clear(); - node_list.shrink_to_fit(); - - omp_set_num_threads(num_threads); - - uint64_t warmup_L = 20; - uint64_t warmup_num = 0, warmup_dim = 0, warmup_aligned_dim = 0; - T *warmup = nullptr; - - if (WARMUP) { - if (file_exists(warmup_query_file)) { - diskann::load_aligned_bin(warmup_query_file, warmup, warmup_num, - warmup_dim, warmup_aligned_dim); - } else { - warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads); - warmup_dim = query_dim; - warmup_aligned_dim = query_aligned_dim; - diskann::alloc_aligned(((void **)&warmup), - warmup_num * warmup_aligned_dim * sizeof(T), - 8 * sizeof(T)); - std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T)); - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution<> dis(-128, 127); - for (uint32_t i = 0; i < warmup_num; i++) { - for (uint32_t d = 0; d < warmup_dim; d++) { - warmup[i * warmup_aligned_dim + d] = (T)dis(gen); - } - } + std::unique_ptr> _pFlashIndex( + new diskann::PQFlashIndex(reader, metric)); + + int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str()); + + if (res != 0) + { + return res; } - diskann::cout << "Warming up index... 
" << std::flush; - std::vector warmup_result_ids_64(warmup_num, 0); - std::vector warmup_result_dists(warmup_num, 0); + + std::vector node_list; + diskann::cout << "Caching " << num_nodes_to_cache << " nodes around medoid(s)" << std::endl; + _pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list); + // if (num_nodes_to_cache > 0) + // _pFlashIndex->generate_cache_list_from_sample_queries(warmup_query_file, + // 15, 6, num_nodes_to_cache, num_threads, node_list); + _pFlashIndex->load_cache_list(node_list); + node_list.clear(); + node_list.shrink_to_fit(); + + omp_set_num_threads(num_threads); + + uint64_t warmup_L = 20; + uint64_t warmup_num = 0, warmup_dim = 0, warmup_aligned_dim = 0; + T *warmup = nullptr; + + if (WARMUP) + { + if (file_exists(warmup_query_file)) + { + diskann::load_aligned_bin(warmup_query_file, warmup, warmup_num, warmup_dim, warmup_aligned_dim); + } + else + { + warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads); + warmup_dim = query_dim; + warmup_aligned_dim = query_aligned_dim; + diskann::alloc_aligned(((void **)&warmup), warmup_num * warmup_aligned_dim * sizeof(T), 8 * sizeof(T)); + std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T)); + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(-128, 127); + for (uint32_t i = 0; i < warmup_num; i++) + { + for (uint32_t d = 0; d < warmup_dim; d++) + { + warmup[i * warmup_aligned_dim + d] = (T)dis(gen); + } + } + } + diskann::cout << "Warming up index... " << std::flush; + std::vector warmup_result_ids_64(warmup_num, 0); + std::vector warmup_result_dists(warmup_num, 0); #pragma omp parallel for schedule(dynamic, 1) - for (int64_t i = 0; i < (int64_t)warmup_num; i++) { - _pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, - warmup_L, - warmup_result_ids_64.data() + (i * 1), - warmup_result_dists.data() + (i * 1), 4); + for (int64_t i = 0; i < (int64_t)warmup_num; i++) + { + _pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, warmup_L, + warmup_result_ids_64.data() + (i * 1), + warmup_result_dists.data() + (i * 1), 4); + } + diskann::cout << "..done" << std::endl; } - diskann::cout << "..done" << std::endl; - } - - diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield); - diskann::cout.precision(2); - - std::string recall_string = "Recall@" + std::to_string(recall_at); - diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" - << std::setw(16) << "QPS" << std::setw(16) << "Mean Latency" - << std::setw(16) << "99.9 Latency" << std::setw(16) - << "Mean IOs" << std::setw(16) << "CPU (s)"; - if (calc_recall_flag) { - diskann::cout << std::setw(16) << recall_string << std::endl; - } else - diskann::cout << std::endl; - diskann::cout - << "===============================================================" - "=======================================================" - << std::endl; - std::vector> query_result_ids(Lvec.size()); - std::vector> query_result_dists(Lvec.size()); + diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield); + diskann::cout.precision(2); - uint32_t optimized_beamwidth = 2; + std::string recall_string = "Recall@" + std::to_string(recall_at); + diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" << std::setw(16) << "QPS" << std::setw(16) + << "Mean Latency" << std::setw(16) << "99.9 Latency" << std::setw(16) << "Mean IOs" << std::setw(16) + << "CPU (s)"; + if (calc_recall_flag) + { + diskann::cout << std::setw(16) << recall_string << std::endl; 
+ } + else + diskann::cout << std::endl; + diskann::cout << "===============================================================" + "=======================================================" + << std::endl; - double best_recall = 0.0; + std::vector> query_result_ids(Lvec.size()); + std::vector> query_result_dists(Lvec.size()); - for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++) { - uint32_t L = Lvec[test_id]; + uint32_t optimized_beamwidth = 2; - if (L < recall_at) { - diskann::cout << "Ignoring search with L:" << L - << " since it's smaller than K:" << recall_at << std::endl; - continue; - } + double best_recall = 0.0; + + for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++) + { + uint32_t L = Lvec[test_id]; - if (beamwidth <= 0) { - diskann::cout << "Tuning beamwidth.." << std::endl; - optimized_beamwidth = - optimize_beamwidth(_pFlashIndex, warmup, warmup_num, - warmup_aligned_dim, L, optimized_beamwidth); - } else - optimized_beamwidth = beamwidth; + if (L < recall_at) + { + diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl; + continue; + } + + if (beamwidth <= 0) + { + diskann::cout << "Tuning beamwidth.." << std::endl; + optimized_beamwidth = + optimize_beamwidth(_pFlashIndex, warmup, warmup_num, warmup_aligned_dim, L, optimized_beamwidth); + } + else + optimized_beamwidth = beamwidth; - query_result_ids[test_id].resize(recall_at * query_num); - query_result_dists[test_id].resize(recall_at * query_num); + query_result_ids[test_id].resize(recall_at * query_num); + query_result_dists[test_id].resize(recall_at * query_num); - auto stats = new diskann::QueryStats[query_num]; + auto stats = new diskann::QueryStats[query_num]; - std::vector query_result_ids_64(recall_at * query_num); - auto s = std::chrono::high_resolution_clock::now(); + std::vector query_result_ids_64(recall_at * query_num); + auto s = std::chrono::high_resolution_clock::now(); #pragma omp parallel for schedule(dynamic, 1) - for (int64_t i = 0; i < (int64_t)query_num; i++) { - if (!filtered_search) { - _pFlashIndex->cached_beam_search( - query + (i * query_aligned_dim), recall_at, L, - query_result_ids_64.data() + (i * recall_at), - query_result_dists[test_id].data() + (i * recall_at), - optimized_beamwidth, use_reorder_data, stats + i); - } else { - LabelT label_for_search; - if (query_filters.size() == 1) { // one label for all queries - label_for_search = - _pFlashIndex->get_converted_label(query_filters[0]); - } else { // one label for each query - label_for_search = - _pFlashIndex->get_converted_label(query_filters[i]); + for (int64_t i = 0; i < (int64_t)query_num; i++) + { + if (!filtered_search) + { + _pFlashIndex->cached_beam_search(query + (i * query_aligned_dim), recall_at, L, + query_result_ids_64.data() + (i * recall_at), + query_result_dists[test_id].data() + (i * recall_at), + optimized_beamwidth, use_reorder_data, stats + i); + } + else + { + LabelT label_for_search; + if (query_filters.size() == 1) + { // one label for all queries + label_for_search = _pFlashIndex->get_converted_label(query_filters[0]); + } + else + { // one label for each query + label_for_search = _pFlashIndex->get_converted_label(query_filters[i]); + } + _pFlashIndex->cached_beam_search( + query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at), + query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search, + use_reorder_data, stats + i); + } + } + auto e = 
std::chrono::high_resolution_clock::now(); + std::chrono::duration diff = e - s; + double qps = (1.0 * query_num) / (1.0 * diff.count()); + + diskann::convert_types(query_result_ids_64.data(), query_result_ids[test_id].data(), + query_num, recall_at); + + auto mean_latency = diskann::get_mean_stats( + stats, query_num, [](const diskann::QueryStats &stats) { return stats.total_us; }); + + auto latency_999 = diskann::get_percentile_stats( + stats, query_num, 0.999, [](const diskann::QueryStats &stats) { return stats.total_us; }); + + auto mean_ios = diskann::get_mean_stats(stats, query_num, + [](const diskann::QueryStats &stats) { return stats.n_ios; }); + + auto mean_cpuus = diskann::get_mean_stats(stats, query_num, + [](const diskann::QueryStats &stats) { return stats.cpu_us; }); + + double recall = 0; + if (calc_recall_flag) + { + recall = diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim, + query_result_ids[test_id].data(), recall_at, recall_at); + best_recall = std::max(recall, best_recall); } - _pFlashIndex->cached_beam_search( - query + (i * query_aligned_dim), recall_at, L, - query_result_ids_64.data() + (i * recall_at), - query_result_dists[test_id].data() + (i * recall_at), - optimized_beamwidth, true, label_for_search, use_reorder_data, - stats + i); - } + + diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth << std::setw(16) << qps + << std::setw(16) << mean_latency << std::setw(16) << latency_999 << std::setw(16) << mean_ios + << std::setw(16) << mean_cpuus; + if (calc_recall_flag) + { + diskann::cout << std::setw(16) << recall << std::endl; + } + else + diskann::cout << std::endl; + delete[] stats; } - auto e = std::chrono::high_resolution_clock::now(); - std::chrono::duration diff = e - s; - double qps = (1.0 * query_num) / (1.0 * diff.count()); - - diskann::convert_types(query_result_ids_64.data(), - query_result_ids[test_id].data(), - query_num, recall_at); - - auto mean_latency = diskann::get_mean_stats( - stats, query_num, - [](const diskann::QueryStats &stats) { return stats.total_us; }); - - auto latency_999 = diskann::get_percentile_stats( - stats, query_num, 0.999, - [](const diskann::QueryStats &stats) { return stats.total_us; }); - - auto mean_ios = diskann::get_mean_stats( - stats, query_num, - [](const diskann::QueryStats &stats) { return stats.n_ios; }); - - auto mean_cpuus = diskann::get_mean_stats( - stats, query_num, - [](const diskann::QueryStats &stats) { return stats.cpu_us; }); - - double recall = 0; - if (calc_recall_flag) { - recall = diskann::calculate_recall( - (uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim, - query_result_ids[test_id].data(), recall_at, recall_at); - best_recall = std::max(recall, best_recall); + + diskann::cout << "Done searching. 
Now saving results " << std::endl; + uint64_t test_id = 0; + for (auto L : Lvec) + { + if (L < recall_at) + continue; + + std::string cur_result_path = result_output_prefix + "_" + std::to_string(L) + "_idx_uint32.bin"; + diskann::save_bin(cur_result_path, query_result_ids[test_id].data(), query_num, recall_at); + + cur_result_path = result_output_prefix + "_" + std::to_string(L) + "_dists_float.bin"; + diskann::save_bin(cur_result_path, query_result_dists[test_id++].data(), query_num, recall_at); } - diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth - << std::setw(16) << qps << std::setw(16) << mean_latency - << std::setw(16) << latency_999 << std::setw(16) << mean_ios - << std::setw(16) << mean_cpuus; - if (calc_recall_flag) { - diskann::cout << std::setw(16) << recall << std::endl; - } else - diskann::cout << std::endl; - delete[] stats; - } - - diskann::cout << "Done searching. Now saving results " << std::endl; - uint64_t test_id = 0; - for (auto L : Lvec) { - if (L < recall_at) - continue; - - std::string cur_result_path = - result_output_prefix + "_" + std::to_string(L) + "_idx_uint32.bin"; - diskann::save_bin(cur_result_path, - query_result_ids[test_id].data(), query_num, - recall_at); - - cur_result_path = - result_output_prefix + "_" + std::to_string(L) + "_dists_float.bin"; - diskann::save_bin(cur_result_path, - query_result_dists[test_id++].data(), query_num, - recall_at); - } - - diskann::aligned_free(query); - if (warmup != nullptr) - diskann::aligned_free(warmup); - return best_recall >= fail_if_recall_below ? 0 : -1; + diskann::aligned_free(query); + if (warmup != nullptr) + diskann::aligned_free(warmup); + return best_recall >= fail_if_recall_below ? 0 : -1; } -int main(int argc, char **argv) { - std::string data_type, dist_fn, index_path_prefix, result_path_prefix, - query_file, gt_file, filter_label, label_type, query_filters_file; - uint32_t num_threads, K, W, num_nodes_to_cache, search_io_limit; - std::vector Lvec; - bool use_reorder_data = false; - float fail_if_recall_below = 0.0f; - - po::options_description desc{program_options_utils::make_program_description( - "search_disk_index", "Searches on-disk DiskANN indexes")}; - try { - desc.add_options()("help,h", "Print information on arguments"); - - // Required parameters - po::options_description required_configs("Required"); - required_configs.add_options()( - "data_type", po::value(&data_type)->required(), - program_options_utils::DATA_TYPE_DESCRIPTION); - required_configs.add_options()( - "dist_fn", po::value(&dist_fn)->required(), - program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); - required_configs.add_options()( - "index_path_prefix", - po::value(&index_path_prefix)->required(), - program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); - required_configs.add_options()( - "result_path", po::value(&result_path_prefix)->required(), - program_options_utils::RESULT_PATH_DESCRIPTION); - required_configs.add_options()( - "query_file", po::value(&query_file)->required(), - program_options_utils::QUERY_FILE_DESCRIPTION); - required_configs.add_options()( - "recall_at,K", po::value(&K)->required(), - program_options_utils::NUMBER_OF_RESULTS_DESCRIPTION); - required_configs.add_options()( - "search_list,L", - po::value>(&Lvec)->multitoken()->required(), - program_options_utils::SEARCH_LIST_DESCRIPTION); - - // Optional parameters - po::options_description optional_configs("Optional"); - optional_configs.add_options()( - "gt_file", - po::value(>_file)->default_value(std::string("null")), - 
program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION); - optional_configs.add_options()("beamwidth,W", - po::value(&W)->default_value(2), - program_options_utils::BEAMWIDTH); - optional_configs.add_options()( - "num_nodes_to_cache", - po::value(&num_nodes_to_cache)->default_value(0), - program_options_utils::NUMBER_OF_NODES_TO_CACHE); - optional_configs.add_options()( - "search_io_limit", - po::value(&search_io_limit) - ->default_value(std::numeric_limits::max()), - "Max #IOs for search. Default value: uint32::max()"); - optional_configs.add_options()( - "num_threads,T", - po::value(&num_threads)->default_value(omp_get_num_procs()), - program_options_utils::NUMBER_THREADS_DESCRIPTION); - optional_configs.add_options()( - "use_reorder_data", po::bool_switch()->default_value(false), - "Include full precision data in the index. Use only in " - "conjuction with compressed data on SSD. Default value: false"); - optional_configs.add_options()( - "filter_label", - po::value(&filter_label)->default_value(std::string("")), - program_options_utils::FILTER_LABEL_DESCRIPTION); - optional_configs.add_options()( - "query_filters_file", - po::value(&query_filters_file) - ->default_value(std::string("")), - program_options_utils::FILTERS_FILE_DESCRIPTION); - optional_configs.add_options()( - "label_type", - po::value(&label_type)->default_value("uint"), - program_options_utils::LABEL_TYPE_DESCRIPTION); - optional_configs.add_options()( - "fail_if_recall_below", - po::value(&fail_if_recall_below)->default_value(0.0f), - program_options_utils::FAIL_IF_RECALL_BELOW); - - // Merge required and optional parameters - desc.add(required_configs).add(optional_configs); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) { - std::cout << desc; - return 0; +int main(int argc, char **argv) +{ + std::string data_type, dist_fn, index_path_prefix, result_path_prefix, query_file, gt_file, filter_label, + label_type, query_filters_file; + uint32_t num_threads, K, W, num_nodes_to_cache, search_io_limit; + std::vector Lvec; + bool use_reorder_data = false; + float fail_if_recall_below = 0.0f; + + po::options_description desc{ + program_options_utils::make_program_description("search_disk_index", "Searches on-disk DiskANN indexes")}; + try + { + desc.add_options()("help,h", "Print information on arguments"); + + // Required parameters + po::options_description required_configs("Required"); + required_configs.add_options()("data_type", po::value(&data_type)->required(), + program_options_utils::DATA_TYPE_DESCRIPTION); + required_configs.add_options()("dist_fn", po::value(&dist_fn)->required(), + program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); + required_configs.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), + program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); + required_configs.add_options()("result_path", po::value(&result_path_prefix)->required(), + program_options_utils::RESULT_PATH_DESCRIPTION); + required_configs.add_options()("query_file", po::value(&query_file)->required(), + program_options_utils::QUERY_FILE_DESCRIPTION); + required_configs.add_options()("recall_at,K", po::value(&K)->required(), + program_options_utils::NUMBER_OF_RESULTS_DESCRIPTION); + required_configs.add_options()("search_list,L", + po::value>(&Lvec)->multitoken()->required(), + program_options_utils::SEARCH_LIST_DESCRIPTION); + + // Optional parameters + po::options_description optional_configs("Optional"); + optional_configs.add_options()("gt_file", 
po::value(>_file)->default_value(std::string("null")), + program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION); + optional_configs.add_options()("beamwidth,W", po::value(&W)->default_value(2), + program_options_utils::BEAMWIDTH); + optional_configs.add_options()("num_nodes_to_cache", po::value(&num_nodes_to_cache)->default_value(0), + program_options_utils::NUMBER_OF_NODES_TO_CACHE); + optional_configs.add_options()( + "search_io_limit", + po::value(&search_io_limit)->default_value(std::numeric_limits::max()), + "Max #IOs for search. Default value: uint32::max()"); + optional_configs.add_options()("num_threads,T", + po::value(&num_threads)->default_value(omp_get_num_procs()), + program_options_utils::NUMBER_THREADS_DESCRIPTION); + optional_configs.add_options()("use_reorder_data", po::bool_switch()->default_value(false), + "Include full precision data in the index. Use only in " + "conjuction with compressed data on SSD. Default value: false"); + optional_configs.add_options()("filter_label", + po::value(&filter_label)->default_value(std::string("")), + program_options_utils::FILTER_LABEL_DESCRIPTION); + optional_configs.add_options()("query_filters_file", + po::value(&query_filters_file)->default_value(std::string("")), + program_options_utils::FILTERS_FILE_DESCRIPTION); + optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), + program_options_utils::LABEL_TYPE_DESCRIPTION); + optional_configs.add_options()("fail_if_recall_below", + po::value(&fail_if_recall_below)->default_value(0.0f), + program_options_utils::FAIL_IF_RECALL_BELOW); + + // Merge required and optional parameters + desc.add(required_configs).add(optional_configs); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) + { + std::cout << desc; + return 0; + } + po::notify(vm); + if (vm["use_reorder_data"].as()) + use_reorder_data = true; + } + catch (const std::exception &ex) + { + std::cerr << ex.what() << '\n'; + return -1; } - po::notify(vm); - if (vm["use_reorder_data"].as()) - use_reorder_data = true; - } catch (const std::exception &ex) { - std::cerr << ex.what() << '\n'; - return -1; - } - - diskann::Metric metric; - if (dist_fn == std::string("mips")) { - metric = diskann::Metric::INNER_PRODUCT; - } else if (dist_fn == std::string("l2")) { - metric = diskann::Metric::L2; - } else if (dist_fn == std::string("cosine")) { - metric = diskann::Metric::COSINE; - } else { - std::cout << "Unsupported distance function. Currently only L2/ Inner " - "Product/Cosine are supported." - << std::endl; - return -1; - } - - if ((data_type != std::string("float")) && - (metric == diskann::Metric::INNER_PRODUCT)) { - std::cout << "Currently support only floating point data for Inner Product." - << std::endl; - return -1; - } - - if (use_reorder_data && data_type != std::string("float")) { - std::cout << "Error: Reorder data for reordering currently only " - "supported for float data type." 
- << std::endl; - return -1; - } - - if (filter_label != "" && query_filters_file != "") { - std::cerr - << "Only one of filter_label and query_filters_file should be provided" - << std::endl; - return -1; - } - - std::vector query_filters; - if (filter_label != "") { - query_filters.push_back(filter_label); - } else if (query_filters_file != "") { - query_filters = read_file_to_vector_of_strings(query_filters_file); - } - - try { - if (!query_filters.empty() && label_type == "ushort") { - if (data_type == std::string("float")) - return search_disk_index( - metric, index_path_prefix, result_path_prefix, query_file, gt_file, - num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, - fail_if_recall_below, query_filters, use_reorder_data); - else if (data_type == std::string("int8")) - return search_disk_index( - metric, index_path_prefix, result_path_prefix, query_file, gt_file, - num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, - fail_if_recall_below, query_filters, use_reorder_data); - else if (data_type == std::string("uint8")) - return search_disk_index( - metric, index_path_prefix, result_path_prefix, query_file, gt_file, - num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, - fail_if_recall_below, query_filters, use_reorder_data); - else { - std::cerr << "Unsupported data type. Use float or int8 or uint8" + + diskann::Metric metric; + if (dist_fn == std::string("mips")) + { + metric = diskann::Metric::INNER_PRODUCT; + } + else if (dist_fn == std::string("l2")) + { + metric = diskann::Metric::L2; + } + else if (dist_fn == std::string("cosine")) + { + metric = diskann::Metric::COSINE; + } + else + { + std::cout << "Unsupported distance function. Currently only L2/ Inner " + "Product/Cosine are supported." << std::endl; return -1; - } - } else { - if (data_type == std::string("float")) - return search_disk_index( - metric, index_path_prefix, result_path_prefix, query_file, gt_file, - num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, - fail_if_recall_below, query_filters, use_reorder_data); - else if (data_type == std::string("int8")) - return search_disk_index( - metric, index_path_prefix, result_path_prefix, query_file, gt_file, - num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, - fail_if_recall_below, query_filters, use_reorder_data); - else if (data_type == std::string("uint8")) - return search_disk_index( - metric, index_path_prefix, result_path_prefix, query_file, gt_file, - num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, - fail_if_recall_below, query_filters, use_reorder_data); - else { - std::cerr << "Unsupported data type. Use float or int8 or uint8" + } + + if ((data_type != std::string("float")) && (metric == diskann::Metric::INNER_PRODUCT)) + { + std::cout << "Currently support only floating point data for Inner Product." << std::endl; + return -1; + } + + if (use_reorder_data && data_type != std::string("float")) + { + std::cout << "Error: Reorder data for reordering currently only " + "supported for float data type." << std::endl; return -1; - } } - } catch (const std::exception &e) { - std::cout << std::string(e.what()) << std::endl; - diskann::cerr << "Index search failed." 
<< std::endl; - return -1; - } + + if (filter_label != "" && query_filters_file != "") + { + std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl; + return -1; + } + + std::vector query_filters; + if (filter_label != "") + { + query_filters.push_back(filter_label); + } + else if (query_filters_file != "") + { + query_filters = read_file_to_vector_of_strings(query_filters_file); + } + + try + { + if (!query_filters.empty() && label_type == "ushort") + { + if (data_type == std::string("float")) + return search_disk_index( + metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, + num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data); + else if (data_type == std::string("int8")) + return search_disk_index( + metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, + num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data); + else if (data_type == std::string("uint8")) + return search_disk_index( + metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, + num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data); + else + { + std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl; + return -1; + } + } + else + { + if (data_type == std::string("float")) + return search_disk_index(metric, index_path_prefix, result_path_prefix, query_file, gt_file, + num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, + fail_if_recall_below, query_filters, use_reorder_data); + else if (data_type == std::string("int8")) + return search_disk_index(metric, index_path_prefix, result_path_prefix, query_file, gt_file, + num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, + fail_if_recall_below, query_filters, use_reorder_data); + else if (data_type == std::string("uint8")) + return search_disk_index(metric, index_path_prefix, result_path_prefix, query_file, gt_file, + num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, + fail_if_recall_below, query_filters, use_reorder_data); + else + { + std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl; + return -1; + } + } + } + catch (const std::exception &e) + { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Index search failed." 
<< std::endl; + return -1; + } } \ No newline at end of file diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index 64c6ac26f..9126ad1fc 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -26,446 +26,452 @@ namespace po = boost::program_options; template -int search_memory_index(diskann::Metric &metric, const std::string &index_path, - const std::string &result_path_prefix, - const std::string &query_file, - const std::string &truthset_file, - const uint32_t num_threads, const uint32_t recall_at, - const bool print_all_recalls, - const std::vector &Lvec, const bool dynamic, - const bool tags, const bool show_qps_per_thread, - const std::vector &query_filters, - const float fail_if_recall_below) { - using TagT = uint32_t; - // Load the query file - T *query = nullptr; - uint32_t *gt_ids = nullptr; - float *gt_dists = nullptr; - size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim; - diskann::load_aligned_bin(query_file, query, query_num, query_dim, - query_aligned_dim); - - bool calc_recall_flag = false; - if (truthset_file != std::string("null") && file_exists(truthset_file)) { - diskann::load_truthset(truthset_file, gt_ids, gt_dists, gt_num, gt_dim); - if (gt_num != query_num) { - std::cout << "Error. Mismatch in number of queries and ground truth data" - << std::endl; +int search_memory_index(diskann::Metric &metric, const std::string &index_path, const std::string &result_path_prefix, + const std::string &query_file, const std::string &truthset_file, const uint32_t num_threads, + const uint32_t recall_at, const bool print_all_recalls, const std::vector &Lvec, + const bool dynamic, const bool tags, const bool show_qps_per_thread, + const std::vector &query_filters, const float fail_if_recall_below) +{ + using TagT = uint32_t; + // Load the query file + T *query = nullptr; + uint32_t *gt_ids = nullptr; + float *gt_dists = nullptr; + size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim; + diskann::load_aligned_bin(query_file, query, query_num, query_dim, query_aligned_dim); + + bool calc_recall_flag = false; + if (truthset_file != std::string("null") && file_exists(truthset_file)) + { + diskann::load_truthset(truthset_file, gt_ids, gt_dists, gt_num, gt_dim); + if (gt_num != query_num) + { + std::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl; + } + calc_recall_flag = true; + } + else + { + diskann::cout << " Truthset file " << truthset_file << " not found. Not computing recall." << std::endl; + } + + bool filtered_search = false; + if (!query_filters.empty()) + { + filtered_search = true; + if (query_filters.size() != 1 && query_filters.size() != query_num) + { + std::cout << "Error. Mismatch in number of queries and size of query " + "filters file" + << std::endl; + return -1; // To return -1 or some other error handling? 
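// Editorial aside: search_disk_index.cpp above and search_memory_index.cpp below
// build their command lines from the same boost::program_options pattern -- a
// "Required" and an "Optional" options_description merged into one, a multitoken
// --search_list/-L option, and a --help check that runs before po::notify() so a
// missing required flag does not prevent the help text from printing. A condensed,
// self-contained sketch of that pattern follows; the option set is trimmed for
// illustration and is not the full CLI of either tool (call it from a main()).
#include <boost/program_options.hpp>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

namespace po_sketch = boost::program_options;

int parse_toy_cli(int argc, char **argv)
{
    std::string data_type;
    uint32_t K = 0;
    std::vector<uint32_t> Lvec;

    po_sketch::options_description desc("toy_search");
    desc.add_options()("help,h", "Print information on arguments");

    po_sketch::options_description required("Required");
    required.add_options()("data_type", po_sketch::value<std::string>(&data_type)->required(), "float/int8/uint8");
    required.add_options()("recall_at,K", po_sketch::value<uint32_t>(&K)->required(), "Number of results");
    required.add_options()("search_list,L",
                           po_sketch::value<std::vector<uint32_t>>(&Lvec)->multitoken()->required(),
                           "One or more L values");

    po_sketch::options_description optional("Optional");
    optional.add_options()("num_threads,T", po_sketch::value<uint32_t>()->default_value(1), "Search threads");

    desc.add(required).add(optional);

    po_sketch::variables_map vm;
    po_sketch::store(po_sketch::parse_command_line(argc, argv, desc), vm);
    if (vm.count("help"))
    {
        std::cout << desc; // help is printed before notify(), so required flags are not enforced here
        return 0;
    }
    po_sketch::notify(vm); // throws if a required option is missing

    std::cout << "data_type=" << data_type << " K=" << K << " |L|=" << Lvec.size() << std::endl;
    return 0;
}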
+ } + } + + const size_t num_frozen_pts = diskann::get_graph_num_frozen_points(index_path); + + auto config = diskann::IndexConfigBuilder() + .with_metric(metric) + .with_dimension(query_dim) + .with_max_points(0) + .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY) + .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY) + .with_data_type(diskann_type_to_name()) + .with_label_type(diskann_type_to_name()) + .with_tag_type(diskann_type_to_name()) + .is_dynamic_index(dynamic) + .is_enable_tags(tags) + .is_concurrent_consolidate(false) + .is_pq_dist_build(false) + .is_use_opq(false) + .with_num_pq_chunks(0) + .with_num_frozen_pts(num_frozen_pts) + .build(); + + auto index_factory = diskann::IndexFactory(config); + auto index = index_factory.create_instance(); + index->load(index_path.c_str(), num_threads, *(std::max_element(Lvec.begin(), Lvec.end()))); + std::cout << "Index loaded" << std::endl; + + if (metric == diskann::FAST_L2) + index->optimize_index_layout(); + + std::cout << "Using " << num_threads << " threads to search" << std::endl; + std::cout.setf(std::ios_base::fixed, std::ios_base::floatfield); + std::cout.precision(2); + const std::string qps_title = show_qps_per_thread ? "QPS/thread" : "QPS"; + uint32_t table_width = 0; + if (tags) + { + std::cout << std::setw(4) << "Ls" << std::setw(12) << qps_title << std::setw(20) << "Mean Latency (mus)" + << std::setw(15) << "99.9 Latency"; + table_width += 4 + 12 + 20 + 15; + } + else + { + std::cout << std::setw(4) << "Ls" << std::setw(12) << qps_title << std::setw(18) << "Avg dist cmps" + << std::setw(20) << "Mean Latency (mus)" << std::setw(15) << "99.9 Latency"; + table_width += 4 + 12 + 18 + 20 + 15; } - calc_recall_flag = true; - } else { - diskann::cout << " Truthset file " << truthset_file - << " not found. Not computing recall." << std::endl; - } - - bool filtered_search = false; - if (!query_filters.empty()) { - filtered_search = true; - if (query_filters.size() != 1 && query_filters.size() != query_num) { - std::cout << "Error. Mismatch in number of queries and size of query " - "filters file" - << std::endl; - return -1; // To return -1 or some other error handling? + uint32_t recalls_to_print = 0; + const uint32_t first_recall = print_all_recalls ? 
1 : recall_at; + if (calc_recall_flag) + { + for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; curr_recall++) + { + std::cout << std::setw(12) << ("Recall@" + std::to_string(curr_recall)); + } + recalls_to_print = recall_at + 1 - first_recall; + table_width += recalls_to_print * 12; } - } - - const size_t num_frozen_pts = - diskann::get_graph_num_frozen_points(index_path); - - auto config = - diskann::IndexConfigBuilder() - .with_metric(metric) - .with_dimension(query_dim) - .with_max_points(0) - .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY) - .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY) - .with_data_type(diskann_type_to_name()) - .with_label_type(diskann_type_to_name()) - .with_tag_type(diskann_type_to_name()) - .is_dynamic_index(dynamic) - .is_enable_tags(tags) - .is_concurrent_consolidate(false) - .is_pq_dist_build(false) - .is_use_opq(false) - .with_num_pq_chunks(0) - .with_num_frozen_pts(num_frozen_pts) - .build(); - - auto index_factory = diskann::IndexFactory(config); - auto index = index_factory.create_instance(); - index->load(index_path.c_str(), num_threads, - *(std::max_element(Lvec.begin(), Lvec.end()))); - std::cout << "Index loaded" << std::endl; - - if (metric == diskann::FAST_L2) - index->optimize_index_layout(); - - std::cout << "Using " << num_threads << " threads to search" << std::endl; - std::cout.setf(std::ios_base::fixed, std::ios_base::floatfield); - std::cout.precision(2); - const std::string qps_title = show_qps_per_thread ? "QPS/thread" : "QPS"; - uint32_t table_width = 0; - if (tags) { - std::cout << std::setw(4) << "Ls" << std::setw(12) << qps_title - << std::setw(20) << "Mean Latency (mus)" << std::setw(15) - << "99.9 Latency"; - table_width += 4 + 12 + 20 + 15; - } else { - std::cout << std::setw(4) << "Ls" << std::setw(12) << qps_title - << std::setw(18) << "Avg dist cmps" << std::setw(20) - << "Mean Latency (mus)" << std::setw(15) << "99.9 Latency"; - table_width += 4 + 12 + 18 + 20 + 15; - } - uint32_t recalls_to_print = 0; - const uint32_t first_recall = print_all_recalls ? 
1 : recall_at; - if (calc_recall_flag) { - for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; - curr_recall++) { - std::cout << std::setw(12) << ("Recall@" + std::to_string(curr_recall)); + std::cout << std::endl; + std::cout << std::string(table_width, '=') << std::endl; + + std::vector> query_result_ids(Lvec.size()); + std::vector> query_result_dists(Lvec.size()); + std::vector latency_stats(query_num, 0); + std::vector cmp_stats; + if (not tags || filtered_search) + { + cmp_stats = std::vector(query_num, 0); } - recalls_to_print = recall_at + 1 - first_recall; - table_width += recalls_to_print * 12; - } - std::cout << std::endl; - std::cout << std::string(table_width, '=') << std::endl; - - std::vector> query_result_ids(Lvec.size()); - std::vector> query_result_dists(Lvec.size()); - std::vector latency_stats(query_num, 0); - std::vector cmp_stats; - if (not tags || filtered_search) { - cmp_stats = std::vector(query_num, 0); - } - - std::vector query_result_tags; - if (tags) { - query_result_tags.resize(recall_at * query_num); - } - - double best_recall = 0.0; - - for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++) { - uint32_t L = Lvec[test_id]; - if (L < recall_at) { - diskann::cout << "Ignoring search with L:" << L - << " since it's smaller than K:" << recall_at << std::endl; - continue; + + std::vector query_result_tags; + if (tags) + { + query_result_tags.resize(recall_at * query_num); } - query_result_ids[test_id].resize(recall_at * query_num); - query_result_dists[test_id].resize(recall_at * query_num); - std::vector res = std::vector(); + double best_recall = 0.0; + + for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++) + { + uint32_t L = Lvec[test_id]; + if (L < recall_at) + { + diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl; + continue; + } + + query_result_ids[test_id].resize(recall_at * query_num); + query_result_dists[test_id].resize(recall_at * query_num); + std::vector res = std::vector(); - auto s = std::chrono::high_resolution_clock::now(); - omp_set_num_threads(num_threads); + auto s = std::chrono::high_resolution_clock::now(); + omp_set_num_threads(num_threads); #pragma omp parallel for schedule(dynamic, 1) - for (int64_t i = 0; i < (int64_t)query_num; i++) { - auto qs = std::chrono::high_resolution_clock::now(); - if (filtered_search && !tags) { - std::string raw_filter = - query_filters.size() == 1 ? query_filters[0] : query_filters[i]; - - auto retval = index->search_with_filters( - query + i * query_aligned_dim, raw_filter, recall_at, L, - query_result_ids[test_id].data() + i * recall_at, - query_result_dists[test_id].data() + i * recall_at); - cmp_stats[i] = retval.second; - } else if (metric == diskann::FAST_L2) { - index->search_with_optimized_layout( - query + i * query_aligned_dim, recall_at, L, - query_result_ids[test_id].data() + i * recall_at); - } else if (tags) { - if (!filtered_search) { - index->search_with_tags(query + i * query_aligned_dim, recall_at, L, - query_result_tags.data() + i * recall_at, - nullptr, res); - } else { - std::string raw_filter = - query_filters.size() == 1 ? 
query_filters[0] : query_filters[i]; - - index->search_with_tags(query + i * query_aligned_dim, recall_at, L, - query_result_tags.data() + i * recall_at, - nullptr, res, true, raw_filter); + for (int64_t i = 0; i < (int64_t)query_num; i++) + { + auto qs = std::chrono::high_resolution_clock::now(); + if (filtered_search && !tags) + { + std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i]; + + auto retval = index->search_with_filters(query + i * query_aligned_dim, raw_filter, recall_at, L, + query_result_ids[test_id].data() + i * recall_at, + query_result_dists[test_id].data() + i * recall_at); + cmp_stats[i] = retval.second; + } + else if (metric == diskann::FAST_L2) + { + index->search_with_optimized_layout(query + i * query_aligned_dim, recall_at, L, + query_result_ids[test_id].data() + i * recall_at); + } + else if (tags) + { + if (!filtered_search) + { + index->search_with_tags(query + i * query_aligned_dim, recall_at, L, + query_result_tags.data() + i * recall_at, nullptr, res); + } + else + { + std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i]; + + index->search_with_tags(query + i * query_aligned_dim, recall_at, L, + query_result_tags.data() + i * recall_at, nullptr, res, true, raw_filter); + } + + for (int64_t r = 0; r < (int64_t)recall_at; r++) + { + query_result_ids[test_id][recall_at * i + r] = query_result_tags[recall_at * i + r]; + } + } + else + { + cmp_stats[i] = index + ->search(query + i * query_aligned_dim, recall_at, L, + query_result_ids[test_id].data() + i * recall_at) + .second; + } + auto qe = std::chrono::high_resolution_clock::now(); + std::chrono::duration diff = qe - qs; + latency_stats[i] = (float)(diff.count() * 1000000); + } + std::chrono::duration diff = std::chrono::high_resolution_clock::now() - s; + + double displayed_qps = query_num / diff.count(); + + if (show_qps_per_thread) + displayed_qps /= num_threads; + + std::vector recalls; + if (calc_recall_flag) + { + recalls.reserve(recalls_to_print); + for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; curr_recall++) + { + recalls.push_back(diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim, + query_result_ids[test_id].data(), recall_at, curr_recall)); + } } - for (int64_t r = 0; r < (int64_t)recall_at; r++) { - query_result_ids[test_id][recall_at * i + r] = - query_result_tags[recall_at * i + r]; + std::sort(latency_stats.begin(), latency_stats.end()); + double mean_latency = + std::accumulate(latency_stats.begin(), latency_stats.end(), 0.0) / static_cast(query_num); + + float avg_cmps = (float)std::accumulate(cmp_stats.begin(), cmp_stats.end(), 0) / (float)query_num; + + if (tags && !filtered_search) + { + std::cout << std::setw(4) << L << std::setw(12) << displayed_qps << std::setw(20) << (float)mean_latency + << std::setw(15) << (float)latency_stats[(uint64_t)(0.999 * query_num)]; + } + else + { + std::cout << std::setw(4) << L << std::setw(12) << displayed_qps << std::setw(18) << avg_cmps + << std::setw(20) << (float)mean_latency << std::setw(15) + << (float)latency_stats[(uint64_t)(0.999 * query_num)]; } - } else { - cmp_stats[i] = - index - ->search(query + i * query_aligned_dim, recall_at, L, - query_result_ids[test_id].data() + i * recall_at) - .second; - } - auto qe = std::chrono::high_resolution_clock::now(); - std::chrono::duration diff = qe - qs; - latency_stats[i] = (float)(diff.count() * 1000000); + for (double recall : recalls) + { + std::cout << std::setw(12) << recall; + 
best_recall = std::max(recall, best_recall); + } + std::cout << std::endl; } - std::chrono::duration diff = - std::chrono::high_resolution_clock::now() - s; - - double displayed_qps = query_num / diff.count(); - - if (show_qps_per_thread) - displayed_qps /= num_threads; - - std::vector recalls; - if (calc_recall_flag) { - recalls.reserve(recalls_to_print); - for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; - curr_recall++) { - recalls.push_back(diskann::calculate_recall( - (uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim, - query_result_ids[test_id].data(), recall_at, curr_recall)); - } + + std::cout << "Done searching. Now saving results " << std::endl; + uint64_t test_id = 0; + for (auto L : Lvec) + { + if (L < recall_at) + { + diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl; + continue; + } + std::string cur_result_path_prefix = result_path_prefix + "_" + std::to_string(L); + + std::string cur_result_path = cur_result_path_prefix + "_idx_uint32.bin"; + diskann::save_bin(cur_result_path, query_result_ids[test_id].data(), query_num, recall_at); + + cur_result_path = cur_result_path_prefix + "_dists_float.bin"; + diskann::save_bin(cur_result_path, query_result_dists[test_id].data(), query_num, recall_at); + + test_id++; } - std::sort(latency_stats.begin(), latency_stats.end()); - double mean_latency = - std::accumulate(latency_stats.begin(), latency_stats.end(), 0.0) / - static_cast(query_num); - - float avg_cmps = - (float)std::accumulate(cmp_stats.begin(), cmp_stats.end(), 0) / - (float)query_num; - - if (tags && !filtered_search) { - std::cout << std::setw(4) << L << std::setw(12) << displayed_qps - << std::setw(20) << (float)mean_latency << std::setw(15) - << (float)latency_stats[(uint64_t)(0.999 * query_num)]; - } else { - std::cout << std::setw(4) << L << std::setw(12) << displayed_qps - << std::setw(18) << avg_cmps << std::setw(20) - << (float)mean_latency << std::setw(15) - << (float)latency_stats[(uint64_t)(0.999 * query_num)]; + diskann::aligned_free(query); + return best_recall >= fail_if_recall_below ? 
0 : -1; +} + +int main(int argc, char **argv) +{ + std::string data_type, dist_fn, index_path_prefix, result_path, query_file, gt_file, filter_label, label_type, + query_filters_file; + uint32_t num_threads, K; + std::vector Lvec; + bool print_all_recalls, dynamic, tags, show_qps_per_thread; + float fail_if_recall_below = 0.0f; + + po::options_description desc{ + program_options_utils::make_program_description("search_memory_index", "Searches in-memory DiskANN indexes")}; + try + { + desc.add_options()("help,h", "Print this information on arguments"); + + // Required parameters + po::options_description required_configs("Required"); + required_configs.add_options()("data_type", po::value(&data_type)->required(), + program_options_utils::DATA_TYPE_DESCRIPTION); + required_configs.add_options()("dist_fn", po::value(&dist_fn)->required(), + program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); + required_configs.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), + program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); + required_configs.add_options()("result_path", po::value(&result_path)->required(), + program_options_utils::RESULT_PATH_DESCRIPTION); + required_configs.add_options()("query_file", po::value(&query_file)->required(), + program_options_utils::QUERY_FILE_DESCRIPTION); + required_configs.add_options()("recall_at,K", po::value(&K)->required(), + program_options_utils::NUMBER_OF_RESULTS_DESCRIPTION); + required_configs.add_options()("search_list,L", + po::value>(&Lvec)->multitoken()->required(), + program_options_utils::SEARCH_LIST_DESCRIPTION); + + // Optional parameters + po::options_description optional_configs("Optional"); + optional_configs.add_options()("filter_label", + po::value(&filter_label)->default_value(std::string("")), + program_options_utils::FILTER_LABEL_DESCRIPTION); + optional_configs.add_options()("query_filters_file", + po::value(&query_filters_file)->default_value(std::string("")), + program_options_utils::FILTERS_FILE_DESCRIPTION); + optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), + program_options_utils::LABEL_TYPE_DESCRIPTION); + optional_configs.add_options()("gt_file", po::value(>_file)->default_value(std::string("null")), + program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION); + optional_configs.add_options()("num_threads,T", + po::value(&num_threads)->default_value(omp_get_num_procs()), + program_options_utils::NUMBER_THREADS_DESCRIPTION); + optional_configs.add_options()("dynamic", po::value(&dynamic)->default_value(false), + "Whether the index is dynamic. Dynamic indices must have associated " + "tags. Default false."); + optional_configs.add_options()("tags", po::value(&tags)->default_value(false), + "Whether to search with external identifiers (tags). 
Default false."); + optional_configs.add_options()("fail_if_recall_below", + po::value(&fail_if_recall_below)->default_value(0.0f), + program_options_utils::FAIL_IF_RECALL_BELOW); + + // Output controls + po::options_description output_controls("Output controls"); + output_controls.add_options()("print_all_recalls", po::bool_switch(&print_all_recalls), + "Print recalls at all positions, from 1 up to specified " + "recall_at value"); + output_controls.add_options()("print_qps_per_thread", po::bool_switch(&show_qps_per_thread), + "Print overall QPS divided by the number of threads in " + "the output table"); + + // Merge required and optional parameters + desc.add(required_configs).add(optional_configs).add(output_controls); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) + { + std::cout << desc; + return 0; + } + po::notify(vm); } - for (double recall : recalls) { - std::cout << std::setw(12) << recall; - best_recall = std::max(recall, best_recall); + catch (const std::exception &ex) + { + std::cerr << ex.what() << '\n'; + return -1; } - std::cout << std::endl; - } - - std::cout << "Done searching. Now saving results " << std::endl; - uint64_t test_id = 0; - for (auto L : Lvec) { - if (L < recall_at) { - diskann::cout << "Ignoring search with L:" << L - << " since it's smaller than K:" << recall_at << std::endl; - continue; + + diskann::Metric metric; + if ((dist_fn == std::string("mips")) && (data_type == std::string("float"))) + { + metric = diskann::Metric::INNER_PRODUCT; + } + else if (dist_fn == std::string("l2")) + { + metric = diskann::Metric::L2; + } + else if (dist_fn == std::string("cosine")) + { + metric = diskann::Metric::COSINE; + } + else if ((dist_fn == std::string("fast_l2")) && (data_type == std::string("float"))) + { + metric = diskann::Metric::FAST_L2; + } + else + { + std::cout << "Unsupported distance function. Currently only l2/ cosine are " + "supported in general, and mips/fast_l2 only for floating " + "point data." + << std::endl; + return -1; } - std::string cur_result_path_prefix = - result_path_prefix + "_" + std::to_string(L); - std::string cur_result_path = cur_result_path_prefix + "_idx_uint32.bin"; - diskann::save_bin(cur_result_path, - query_result_ids[test_id].data(), query_num, - recall_at); + if (dynamic && not tags) + { + std::cerr << "Tags must be enabled while searching dynamically built indices" << std::endl; + return -1; + } - cur_result_path = cur_result_path_prefix + "_dists_float.bin"; - diskann::save_bin(cur_result_path, - query_result_dists[test_id].data(), query_num, - recall_at); + if (fail_if_recall_below < 0.0 || fail_if_recall_below >= 100.0) + { + std::cerr << "fail_if_recall_below parameter must be between 0 and 100%" << std::endl; + return -1; + } - test_id++; - } + if (filter_label != "" && query_filters_file != "") + { + std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl; + return -1; + } - diskann::aligned_free(query); - return best_recall >= fail_if_recall_below ? 
0 : -1; -} + std::vector query_filters; + if (filter_label != "") + { + query_filters.push_back(filter_label); + } + else if (query_filters_file != "") + { + query_filters = read_file_to_vector_of_strings(query_filters_file); + } -int main(int argc, char **argv) { - std::string data_type, dist_fn, index_path_prefix, result_path, query_file, - gt_file, filter_label, label_type, query_filters_file; - uint32_t num_threads, K; - std::vector Lvec; - bool print_all_recalls, dynamic, tags, show_qps_per_thread; - float fail_if_recall_below = 0.0f; - - po::options_description desc{program_options_utils::make_program_description( - "search_memory_index", "Searches in-memory DiskANN indexes")}; - try { - desc.add_options()("help,h", "Print this information on arguments"); - - // Required parameters - po::options_description required_configs("Required"); - required_configs.add_options()( - "data_type", po::value(&data_type)->required(), - program_options_utils::DATA_TYPE_DESCRIPTION); - required_configs.add_options()( - "dist_fn", po::value(&dist_fn)->required(), - program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); - required_configs.add_options()( - "index_path_prefix", - po::value(&index_path_prefix)->required(), - program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); - required_configs.add_options()( - "result_path", po::value(&result_path)->required(), - program_options_utils::RESULT_PATH_DESCRIPTION); - required_configs.add_options()( - "query_file", po::value(&query_file)->required(), - program_options_utils::QUERY_FILE_DESCRIPTION); - required_configs.add_options()( - "recall_at,K", po::value(&K)->required(), - program_options_utils::NUMBER_OF_RESULTS_DESCRIPTION); - required_configs.add_options()( - "search_list,L", - po::value>(&Lvec)->multitoken()->required(), - program_options_utils::SEARCH_LIST_DESCRIPTION); - - // Optional parameters - po::options_description optional_configs("Optional"); - optional_configs.add_options()( - "filter_label", - po::value(&filter_label)->default_value(std::string("")), - program_options_utils::FILTER_LABEL_DESCRIPTION); - optional_configs.add_options()( - "query_filters_file", - po::value(&query_filters_file) - ->default_value(std::string("")), - program_options_utils::FILTERS_FILE_DESCRIPTION); - optional_configs.add_options()( - "label_type", - po::value(&label_type)->default_value("uint"), - program_options_utils::LABEL_TYPE_DESCRIPTION); - optional_configs.add_options()( - "gt_file", - po::value(>_file)->default_value(std::string("null")), - program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION); - optional_configs.add_options()( - "num_threads,T", - po::value(&num_threads)->default_value(omp_get_num_procs()), - program_options_utils::NUMBER_THREADS_DESCRIPTION); - optional_configs.add_options()( - "dynamic", po::value(&dynamic)->default_value(false), - "Whether the index is dynamic. Dynamic indices must have associated " - "tags. Default false."); - optional_configs.add_options()( - "tags", po::value(&tags)->default_value(false), - "Whether to search with external identifiers (tags). 
Default false."); - optional_configs.add_options()( - "fail_if_recall_below", - po::value(&fail_if_recall_below)->default_value(0.0f), - program_options_utils::FAIL_IF_RECALL_BELOW); - - // Output controls - po::options_description output_controls("Output controls"); - output_controls.add_options()( - "print_all_recalls", po::bool_switch(&print_all_recalls), - "Print recalls at all positions, from 1 up to specified " - "recall_at value"); - output_controls.add_options()( - "print_qps_per_thread", po::bool_switch(&show_qps_per_thread), - "Print overall QPS divided by the number of threads in " - "the output table"); - - // Merge required and optional parameters - desc.add(required_configs).add(optional_configs).add(output_controls); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) { - std::cout << desc; - return 0; + try + { + if (!query_filters.empty() && label_type == "ushort") + { + if (data_type == std::string("int8")) + { + return search_memory_index( + metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, + Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below); + } + else if (data_type == std::string("uint8")) + { + return search_memory_index( + metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, + Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below); + } + else if (data_type == std::string("float")) + { + return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, + num_threads, K, print_all_recalls, Lvec, dynamic, tags, + show_qps_per_thread, query_filters, fail_if_recall_below); + } + else + { + std::cout << "Unsupported type. Use float/int8/uint8" << std::endl; + return -1; + } + } + else + { + if (data_type == std::string("int8")) + { + return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, + num_threads, K, print_all_recalls, Lvec, dynamic, tags, + show_qps_per_thread, query_filters, fail_if_recall_below); + } + else if (data_type == std::string("uint8")) + { + return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, + num_threads, K, print_all_recalls, Lvec, dynamic, tags, + show_qps_per_thread, query_filters, fail_if_recall_below); + } + else if (data_type == std::string("float")) + { + return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, + num_threads, K, print_all_recalls, Lvec, dynamic, tags, + show_qps_per_thread, query_filters, fail_if_recall_below); + } + else + { + std::cout << "Unsupported type. Use float/int8/uint8" << std::endl; + return -1; + } + } } - po::notify(vm); - } catch (const std::exception &ex) { - std::cerr << ex.what() << '\n'; - return -1; - } - - diskann::Metric metric; - if ((dist_fn == std::string("mips")) && (data_type == std::string("float"))) { - metric = diskann::Metric::INNER_PRODUCT; - } else if (dist_fn == std::string("l2")) { - metric = diskann::Metric::L2; - } else if (dist_fn == std::string("cosine")) { - metric = diskann::Metric::COSINE; - } else if ((dist_fn == std::string("fast_l2")) && - (data_type == std::string("float"))) { - metric = diskann::Metric::FAST_L2; - } else { - std::cout << "Unsupported distance function. Currently only l2/ cosine are " - "supported in general, and mips/fast_l2 only for floating " - "point data." 
- << std::endl; - return -1; - } - - if (dynamic && not tags) { - std::cerr - << "Tags must be enabled while searching dynamically built indices" - << std::endl; - return -1; - } - - if (fail_if_recall_below < 0.0 || fail_if_recall_below >= 100.0) { - std::cerr << "fail_if_recall_below parameter must be between 0 and 100%" - << std::endl; - return -1; - } - - if (filter_label != "" && query_filters_file != "") { - std::cerr - << "Only one of filter_label and query_filters_file should be provided" - << std::endl; - return -1; - } - - std::vector query_filters; - if (filter_label != "") { - query_filters.push_back(filter_label); - } else if (query_filters_file != "") { - query_filters = read_file_to_vector_of_strings(query_filters_file); - } - - try { - if (!query_filters.empty() && label_type == "ushort") { - if (data_type == std::string("int8")) { - return search_memory_index( - metric, index_path_prefix, result_path, query_file, gt_file, - num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below); - } else if (data_type == std::string("uint8")) { - return search_memory_index( - metric, index_path_prefix, result_path, query_file, gt_file, - num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below); - } else if (data_type == std::string("float")) { - return search_memory_index( - metric, index_path_prefix, result_path, query_file, gt_file, - num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below); - } else { - std::cout << "Unsupported type. Use float/int8/uint8" << std::endl; - return -1; - } - } else { - if (data_type == std::string("int8")) { - return search_memory_index( - metric, index_path_prefix, result_path, query_file, gt_file, - num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below); - } else if (data_type == std::string("uint8")) { - return search_memory_index( - metric, index_path_prefix, result_path, query_file, gt_file, - num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below); - } else if (data_type == std::string("float")) { - return search_memory_index( - metric, index_path_prefix, result_path, query_file, gt_file, - num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below); - } else { - std::cout << "Unsupported type. Use float/int8/uint8" << std::endl; + catch (std::exception &e) + { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Index search failed." << std::endl; return -1; - } } - } catch (std::exception &e) { - std::cout << std::string(e.what()) << std::endl; - diskann::cerr << "Index search failed." << std::endl; - return -1; - } } diff --git a/apps/test_insert_deletes_consolidate.cpp b/apps/test_insert_deletes_consolidate.cpp index 047a677c9..21ce4250f 100644 --- a/apps/test_insert_deletes_consolidate.cpp +++ b/apps/test_insert_deletes_consolidate.cpp @@ -28,558 +28,509 @@ namespace po = boost::program_options; // load_aligned_bin modified to read pieces of the file, but using ifstream // instead of cached_ifstream. 
template -inline void load_aligned_bin_part(const std::string &bin_file, T *data, - size_t offset_points, size_t points_to_read) { - diskann::Timer timer; - std::ifstream reader; - reader.exceptions(std::ios::failbit | std::ios::badbit); - reader.open(bin_file, std::ios::binary | std::ios::ate); - size_t actual_file_size = reader.tellg(); - reader.seekg(0, std::ios::beg); - - int npts_i32, dim_i32; - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&dim_i32, sizeof(int)); - size_t npts = (uint32_t)npts_i32; - size_t dim = (uint32_t)dim_i32; - - size_t expected_actual_file_size = - npts * dim * sizeof(T) + 2 * sizeof(uint32_t); - if (actual_file_size != expected_actual_file_size) { - std::stringstream stream; - stream << "Error. File size mismatch. Actual size is " << actual_file_size - << " while expected size is " << expected_actual_file_size - << " npts = " << npts << " dim = " << dim - << " size of = " << sizeof(T) << std::endl; - std::cout << stream.str(); - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - - if (offset_points + points_to_read > npts) { - std::stringstream stream; - stream << "Error. Not enough points in file. Requested " << offset_points - << " offset and " << points_to_read << " points, but have only " - << npts << " points" << std::endl; - std::cout << stream.str(); - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - - reader.seekg(2 * sizeof(uint32_t) + offset_points * dim * sizeof(T)); - - const size_t rounded_dim = ROUND_UP(dim, 8); - - for (size_t i = 0; i < points_to_read; i++) { - reader.read((char *)(data + i * rounded_dim), dim * sizeof(T)); - memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T)); - } - reader.close(); - - const double elapsedSeconds = timer.elapsed() / 1000000.0; - std::cout << "Read " << points_to_read << " points using non-cached reads in " - << elapsedSeconds << std::endl; +inline void load_aligned_bin_part(const std::string &bin_file, T *data, size_t offset_points, size_t points_to_read) +{ + diskann::Timer timer; + std::ifstream reader; + reader.exceptions(std::ios::failbit | std::ios::badbit); + reader.open(bin_file, std::ios::binary | std::ios::ate); + size_t actual_file_size = reader.tellg(); + reader.seekg(0, std::ios::beg); + + int npts_i32, dim_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&dim_i32, sizeof(int)); + size_t npts = (uint32_t)npts_i32; + size_t dim = (uint32_t)dim_i32; + + size_t expected_actual_file_size = npts * dim * sizeof(T) + 2 * sizeof(uint32_t); + if (actual_file_size != expected_actual_file_size) + { + std::stringstream stream; + stream << "Error. File size mismatch. Actual size is " << actual_file_size << " while expected size is " + << expected_actual_file_size << " npts = " << npts << " dim = " << dim << " size of = " << sizeof(T) + << std::endl; + std::cout << stream.str(); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + + if (offset_points + points_to_read > npts) + { + std::stringstream stream; + stream << "Error. Not enough points in file. 
Requested " << offset_points << " offset and " << points_to_read + << " points, but have only " << npts << " points" << std::endl; + std::cout << stream.str(); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + + reader.seekg(2 * sizeof(uint32_t) + offset_points * dim * sizeof(T)); + + const size_t rounded_dim = ROUND_UP(dim, 8); + + for (size_t i = 0; i < points_to_read; i++) + { + reader.read((char *)(data + i * rounded_dim), dim * sizeof(T)); + memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T)); + } + reader.close(); + + const double elapsedSeconds = timer.elapsed() / 1000000.0; + std::cout << "Read " << points_to_read << " points using non-cached reads in " << elapsedSeconds << std::endl; } -std::string get_save_filename(const std::string &save_path, - size_t points_to_skip, size_t points_deleted, - size_t last_point_threshold) { - std::string final_path = save_path; - if (points_to_skip > 0) { - final_path += "skip" + std::to_string(points_to_skip) + "-"; - } - - final_path += "del" + std::to_string(points_deleted) + "-"; - final_path += std::to_string(last_point_threshold); - return final_path; +std::string get_save_filename(const std::string &save_path, size_t points_to_skip, size_t points_deleted, + size_t last_point_threshold) +{ + std::string final_path = save_path; + if (points_to_skip > 0) + { + final_path += "skip" + std::to_string(points_to_skip) + "-"; + } + + final_path += "del" + std::to_string(points_deleted) + "-"; + final_path += std::to_string(last_point_threshold); + return final_path; } template -void insert_till_next_checkpoint( - diskann::AbstractIndex &index, size_t start, size_t end, - int32_t thread_count, T *data, size_t aligned_dim, - std::vector> &location_to_labels) { - diskann::Timer insert_timer; +void insert_till_next_checkpoint(diskann::AbstractIndex &index, size_t start, size_t end, int32_t thread_count, T *data, + size_t aligned_dim, std::vector> &location_to_labels) +{ + diskann::Timer insert_timer; #pragma omp parallel for num_threads(thread_count) schedule(dynamic) - for (int64_t j = start; j < (int64_t)end; j++) { - if (!location_to_labels.empty()) { - index.insert_point(&data[(j - start) * aligned_dim], - 1 + static_cast(j), - location_to_labels[j - start]); - } else { - index.insert_point(&data[(j - start) * aligned_dim], - 1 + static_cast(j)); + for (int64_t j = start; j < (int64_t)end; j++) + { + if (!location_to_labels.empty()) + { + index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast(j), + location_to_labels[j - start]); + } + else + { + index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast(j)); + } } - } - const double elapsedSeconds = insert_timer.elapsed() / 1000000.0; - std::cout << "Insertion time " << elapsedSeconds << " seconds (" - << (end - start) / elapsedSeconds << " points/second overall, " - << (end - start) / elapsedSeconds / thread_count - << " per thread)\n "; + const double elapsedSeconds = insert_timer.elapsed() / 1000000.0; + std::cout << "Insertion time " << elapsedSeconds << " seconds (" << (end - start) / elapsedSeconds + << " points/second overall, " << (end - start) / elapsedSeconds / thread_count << " per thread)\n "; } template -void delete_from_beginning(diskann::AbstractIndex &index, - diskann::IndexWriteParameters &delete_params, - size_t points_to_skip, - size_t points_to_delete_from_beginning) { - try { - std::cout << std::endl - << "Lazy deleting points " << points_to_skip << " to " - << points_to_skip + 
points_to_delete_from_beginning << "... "; - for (size_t i = points_to_skip; - i < points_to_skip + points_to_delete_from_beginning; ++i) - index.lazy_delete( - static_cast(i + 1)); // Since tags are data location + 1 - std::cout << "done." << std::endl; - - auto report = index.consolidate_deletes(delete_params); - std::cout << "#active points: " << report._active_points << std::endl - << "max points: " << report._max_points << std::endl - << "empty slots: " << report._empty_slots << std::endl - << "deletes processed: " << report._slots_released << std::endl - << "latest delete size: " << report._delete_set_size << std::endl - << "rate: (" << points_to_delete_from_beginning / report._time - << " points/second overall, " - << points_to_delete_from_beginning / report._time / - delete_params.num_threads - << " per thread)" << std::endl; - } catch (std::system_error &e) { - std::cout << "Exception caught in deletion thread: " << e.what() - << std::endl; - } +void delete_from_beginning(diskann::AbstractIndex &index, diskann::IndexWriteParameters &delete_params, + size_t points_to_skip, size_t points_to_delete_from_beginning) +{ + try + { + std::cout << std::endl + << "Lazy deleting points " << points_to_skip << " to " + << points_to_skip + points_to_delete_from_beginning << "... "; + for (size_t i = points_to_skip; i < points_to_skip + points_to_delete_from_beginning; ++i) + index.lazy_delete(static_cast(i + 1)); // Since tags are data location + 1 + std::cout << "done." << std::endl; + + auto report = index.consolidate_deletes(delete_params); + std::cout << "#active points: " << report._active_points << std::endl + << "max points: " << report._max_points << std::endl + << "empty slots: " << report._empty_slots << std::endl + << "deletes processed: " << report._slots_released << std::endl + << "latest delete size: " << report._delete_set_size << std::endl + << "rate: (" << points_to_delete_from_beginning / report._time << " points/second overall, " + << points_to_delete_from_beginning / report._time / delete_params.num_threads << " per thread)" + << std::endl; + } + catch (std::system_error &e) + { + std::cout << "Exception caught in deletion thread: " << e.what() << std::endl; + } } template -void build_incremental_index( - const std::string &data_path, diskann::IndexWriteParameters ¶ms, - size_t points_to_skip, size_t max_points_to_insert, - size_t beginning_index_size, float start_point_norm, uint32_t num_start_pts, - size_t points_per_checkpoint, size_t checkpoints_per_snapshot, - const std::string &save_path, size_t points_to_delete_from_beginning, - size_t start_deletes_after, bool concurrent, const std::string &label_file, - const std::string &universal_label) { - size_t dim, aligned_dim; - size_t num_points; - diskann::get_bin_metadata(data_path, num_points, dim); - aligned_dim = ROUND_UP(dim, 8); - bool has_labels = label_file != ""; - using TagT = uint32_t; - using LabelT = uint32_t; - - size_t current_point_offset = points_to_skip; - const size_t last_point_threshold = points_to_skip + max_points_to_insert; - - bool enable_tags = true; - using TagT = uint32_t; - auto index_search_params = - diskann::IndexSearchParams(params.search_list_size, params.num_threads); - diskann::IndexConfig index_config = - diskann::IndexConfigBuilder() - .with_metric(diskann::L2) - .with_dimension(dim) - .with_max_points(max_points_to_insert) - .is_dynamic_index(true) - .with_index_write_params(params) - .with_index_search_params(index_search_params) - .with_data_type(diskann_type_to_name()) - 
.with_tag_type(diskann_type_to_name()) - .with_label_type(diskann_type_to_name()) - .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY) - .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY) - .is_enable_tags(enable_tags) - .is_filtered(has_labels) - .with_num_frozen_pts(num_start_pts) - .is_concurrent_consolidate(concurrent) - .build(); - - diskann::IndexFactory index_factory = diskann::IndexFactory(index_config); - auto index = index_factory.create_instance(); - - if (universal_label != "") { - LabelT u_label = 0; - index->set_universal_label(u_label); - } - - if (points_to_skip > num_points) { - throw diskann::ANNException("Asked to skip more points than in data file", - -1, __FUNCSIG__, __FILE__, __LINE__); - } - - if (max_points_to_insert == 0) { - max_points_to_insert = num_points; - } - - if (points_to_skip + max_points_to_insert > num_points) { - max_points_to_insert = num_points - points_to_skip; - std::cerr << "WARNING: Reducing max_points_to_insert to " - << max_points_to_insert - << " points since the data file has only that many" << std::endl; - } - - if (beginning_index_size > max_points_to_insert) { - beginning_index_size = max_points_to_insert; - std::cerr << "WARNING: Reducing beginning index size to " - << beginning_index_size - << " points since the data file has only that many" << std::endl; - } - if (checkpoints_per_snapshot > 0 && - beginning_index_size > points_per_checkpoint) { - beginning_index_size = points_per_checkpoint; - std::cerr << "WARNING: Reducing beginning index size to " - << beginning_index_size << std::endl; - } - - T *data = nullptr; - diskann::alloc_aligned((void **)&data, - std::max(points_per_checkpoint, beginning_index_size) * - aligned_dim * sizeof(T), - 8 * sizeof(T)); - - std::vector tags(beginning_index_size); - std::iota(tags.begin(), tags.end(), - 1 + static_cast(current_point_offset)); - - load_aligned_bin_part(data_path, data, current_point_offset, - beginning_index_size); - std::cout << "load aligned bin succeeded" << std::endl; - diskann::Timer timer; - - if (beginning_index_size > 0) { - index->build(data, beginning_index_size, tags); - } else { - index->set_start_points_at_random(static_cast(start_point_norm)); - } - - const double elapsedSeconds = timer.elapsed() / 1000000.0; - std::cout << "Initial non-incremental index build time for " - << beginning_index_size << " points took " << elapsedSeconds - << " seconds (" << beginning_index_size / elapsedSeconds - << " points/second)\n "; - - current_point_offset += beginning_index_size; - - if (points_to_delete_from_beginning > max_points_to_insert) { - points_to_delete_from_beginning = - static_cast(max_points_to_insert); - std::cerr << "WARNING: Reducing points to delete from beginning to " - << points_to_delete_from_beginning - << " points since the data file has only that many" << std::endl; - } - - std::vector> location_to_labels; - if (concurrent) { - // handle labels - const auto save_path_inc = get_save_filename( - save_path + ".after-concurrent-delete-", points_to_skip, - points_to_delete_from_beginning, last_point_threshold); - std::string labels_file_to_use = save_path_inc + "_label_formatted.txt"; - std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt"; - if (has_labels) { - convert_labels_string_to_int(label_file, labels_file_to_use, - mem_labels_int_map_file, universal_label); - auto parse_result = - diskann::parse_formatted_label_file(labels_file_to_use); - location_to_labels = std::get<0>(parse_result); +void 
build_incremental_index(const std::string &data_path, diskann::IndexWriteParameters ¶ms, size_t points_to_skip, + size_t max_points_to_insert, size_t beginning_index_size, float start_point_norm, + uint32_t num_start_pts, size_t points_per_checkpoint, size_t checkpoints_per_snapshot, + const std::string &save_path, size_t points_to_delete_from_beginning, + size_t start_deletes_after, bool concurrent, const std::string &label_file, + const std::string &universal_label) +{ + size_t dim, aligned_dim; + size_t num_points; + diskann::get_bin_metadata(data_path, num_points, dim); + aligned_dim = ROUND_UP(dim, 8); + bool has_labels = label_file != ""; + using TagT = uint32_t; + using LabelT = uint32_t; + + size_t current_point_offset = points_to_skip; + const size_t last_point_threshold = points_to_skip + max_points_to_insert; + + bool enable_tags = true; + using TagT = uint32_t; + auto index_search_params = diskann::IndexSearchParams(params.search_list_size, params.num_threads); + diskann::IndexConfig index_config = diskann::IndexConfigBuilder() + .with_metric(diskann::L2) + .with_dimension(dim) + .with_max_points(max_points_to_insert) + .is_dynamic_index(true) + .with_index_write_params(params) + .with_index_search_params(index_search_params) + .with_data_type(diskann_type_to_name()) + .with_tag_type(diskann_type_to_name()) + .with_label_type(diskann_type_to_name()) + .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY) + .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY) + .is_enable_tags(enable_tags) + .is_filtered(has_labels) + .with_num_frozen_pts(num_start_pts) + .is_concurrent_consolidate(concurrent) + .build(); + + diskann::IndexFactory index_factory = diskann::IndexFactory(index_config); + auto index = index_factory.create_instance(); + + if (universal_label != "") + { + LabelT u_label = 0; + index->set_universal_label(u_label); } - int32_t sub_threads = (params.num_threads + 1) / 2; - bool delete_launched = false; - std::future delete_task; - - diskann::Timer timer; + if (points_to_skip > num_points) + { + throw diskann::ANNException("Asked to skip more points than in data file", -1, __FUNCSIG__, __FILE__, __LINE__); + } - for (size_t start = current_point_offset; start < last_point_threshold; - start += points_per_checkpoint, - current_point_offset += points_per_checkpoint) { - const size_t end = - std::min(start + points_per_checkpoint, last_point_threshold); - std::cout << std::endl - << "Inserting from " << start << " to " << end << std::endl; - - auto insert_task = std::async(std::launch::async, [&]() { - load_aligned_bin_part(data_path, data, start, end - start); - insert_till_next_checkpoint( - *index, start, end, sub_threads, data, aligned_dim, - location_to_labels); - }); - insert_task.wait(); - - if (!delete_launched && end >= start_deletes_after && - end >= points_to_skip + points_to_delete_from_beginning) { - delete_launched = true; - diskann::IndexWriteParameters delete_params = - diskann::IndexWriteParametersBuilder(params) - .with_num_threads(sub_threads) - .build(); - - delete_task = std::async(std::launch::async, [&]() { - delete_from_beginning(*index, delete_params, points_to_skip, - points_to_delete_from_beginning); - }); - } + if (max_points_to_insert == 0) + { + max_points_to_insert = num_points; } - delete_task.wait(); - - std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n"; - index->save(save_path_inc.c_str(), true); - } else { - const auto save_path_inc = get_save_filename( - save_path + ".after-delete-", 
points_to_skip, - points_to_delete_from_beginning, last_point_threshold); - std::string labels_file_to_use = save_path_inc + "_label_formatted.txt"; - std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt"; - if (has_labels) { - convert_labels_string_to_int(label_file, labels_file_to_use, - mem_labels_int_map_file, universal_label); - auto parse_result = - diskann::parse_formatted_label_file(labels_file_to_use); - location_to_labels = std::get<0>(parse_result); + + if (points_to_skip + max_points_to_insert > num_points) + { + max_points_to_insert = num_points - points_to_skip; + std::cerr << "WARNING: Reducing max_points_to_insert to " << max_points_to_insert + << " points since the data file has only that many" << std::endl; } - size_t last_snapshot_points_threshold = 0; - size_t num_checkpoints_till_snapshot = checkpoints_per_snapshot; - - for (size_t start = current_point_offset; start < last_point_threshold; - start += points_per_checkpoint, - current_point_offset += points_per_checkpoint) { - const size_t end = - std::min(start + points_per_checkpoint, last_point_threshold); - std::cout << std::endl - << "Inserting from " << start << " to " << end << std::endl; - - load_aligned_bin_part(data_path, data, start, end - start); - insert_till_next_checkpoint( - *index, start, end, (int32_t)params.num_threads, data, aligned_dim, - location_to_labels); - - if (checkpoints_per_snapshot > 0 && - --num_checkpoints_till_snapshot == 0) { - diskann::Timer save_timer; - - const auto save_path_inc = - get_save_filename(save_path + ".inc-", points_to_skip, - points_to_delete_from_beginning, end); - index->save(save_path_inc.c_str(), false); - const double elapsedSeconds = save_timer.elapsed() / 1000000.0; - const size_t points_saved = end - points_to_skip; - - std::cout << "Saved " << points_saved << " points in " << elapsedSeconds - << " seconds (" << points_saved / elapsedSeconds - << " points/second)\n"; - - num_checkpoints_till_snapshot = checkpoints_per_snapshot; - last_snapshot_points_threshold = end; - } - - std::cout << "Number of points in the index post insertion " << end - << std::endl; + if (beginning_index_size > max_points_to_insert) + { + beginning_index_size = max_points_to_insert; + std::cerr << "WARNING: Reducing beginning index size to " << beginning_index_size + << " points since the data file has only that many" << std::endl; + } + if (checkpoints_per_snapshot > 0 && beginning_index_size > points_per_checkpoint) + { + beginning_index_size = points_per_checkpoint; + std::cerr << "WARNING: Reducing beginning index size to " << beginning_index_size << std::endl; } - if (checkpoints_per_snapshot > 0 && - last_snapshot_points_threshold != last_point_threshold) { - const auto save_path_inc = get_save_filename( - save_path + ".inc-", points_to_skip, points_to_delete_from_beginning, - last_point_threshold); - // index.save(save_path_inc.c_str(), false); + T *data = nullptr; + diskann::alloc_aligned( + (void **)&data, std::max(points_per_checkpoint, beginning_index_size) * aligned_dim * sizeof(T), 8 * sizeof(T)); + + std::vector tags(beginning_index_size); + std::iota(tags.begin(), tags.end(), 1 + static_cast(current_point_offset)); + + load_aligned_bin_part(data_path, data, current_point_offset, beginning_index_size); + std::cout << "load aligned bin succeeded" << std::endl; + diskann::Timer timer; + + if (beginning_index_size > 0) + { + index->build(data, beginning_index_size, tags); + } + else + { + index->set_start_points_at_random(static_cast(start_point_norm)); } - 
if (points_to_delete_from_beginning > 0) { - delete_from_beginning(*index, params, points_to_skip, - points_to_delete_from_beginning); + const double elapsedSeconds = timer.elapsed() / 1000000.0; + std::cout << "Initial non-incremental index build time for " << beginning_index_size << " points took " + << elapsedSeconds << " seconds (" << beginning_index_size / elapsedSeconds << " points/second)\n "; + + current_point_offset += beginning_index_size; + + if (points_to_delete_from_beginning > max_points_to_insert) + { + points_to_delete_from_beginning = static_cast(max_points_to_insert); + std::cerr << "WARNING: Reducing points to delete from beginning to " << points_to_delete_from_beginning + << " points since the data file has only that many" << std::endl; } - index->save(save_path_inc.c_str(), true); - } + std::vector> location_to_labels; + if (concurrent) + { + // handle labels + const auto save_path_inc = get_save_filename(save_path + ".after-concurrent-delete-", points_to_skip, + points_to_delete_from_beginning, last_point_threshold); + std::string labels_file_to_use = save_path_inc + "_label_formatted.txt"; + std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt"; + if (has_labels) + { + convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label); + auto parse_result = diskann::parse_formatted_label_file(labels_file_to_use); + location_to_labels = std::get<0>(parse_result); + } + + int32_t sub_threads = (params.num_threads + 1) / 2; + bool delete_launched = false; + std::future delete_task; + + diskann::Timer timer; + + for (size_t start = current_point_offset; start < last_point_threshold; + start += points_per_checkpoint, current_point_offset += points_per_checkpoint) + { + const size_t end = std::min(start + points_per_checkpoint, last_point_threshold); + std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl; + + auto insert_task = std::async(std::launch::async, [&]() { + load_aligned_bin_part(data_path, data, start, end - start); + insert_till_next_checkpoint(*index, start, end, sub_threads, data, aligned_dim, + location_to_labels); + }); + insert_task.wait(); + + if (!delete_launched && end >= start_deletes_after && + end >= points_to_skip + points_to_delete_from_beginning) + { + delete_launched = true; + diskann::IndexWriteParameters delete_params = + diskann::IndexWriteParametersBuilder(params).with_num_threads(sub_threads).build(); + + delete_task = std::async(std::launch::async, [&]() { + delete_from_beginning(*index, delete_params, points_to_skip, + points_to_delete_from_beginning); + }); + } + } + delete_task.wait(); + + std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n"; + index->save(save_path_inc.c_str(), true); + } + else + { + const auto save_path_inc = get_save_filename(save_path + ".after-delete-", points_to_skip, + points_to_delete_from_beginning, last_point_threshold); + std::string labels_file_to_use = save_path_inc + "_label_formatted.txt"; + std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt"; + if (has_labels) + { + convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label); + auto parse_result = diskann::parse_formatted_label_file(labels_file_to_use); + location_to_labels = std::get<0>(parse_result); + } + + size_t last_snapshot_points_threshold = 0; + size_t num_checkpoints_till_snapshot = checkpoints_per_snapshot; + + for (size_t start = current_point_offset; start < last_point_threshold; + 
start += points_per_checkpoint, current_point_offset += points_per_checkpoint) + { + const size_t end = std::min(start + points_per_checkpoint, last_point_threshold); + std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl; + + load_aligned_bin_part(data_path, data, start, end - start); + insert_till_next_checkpoint(*index, start, end, (int32_t)params.num_threads, data, + aligned_dim, location_to_labels); + + if (checkpoints_per_snapshot > 0 && --num_checkpoints_till_snapshot == 0) + { + diskann::Timer save_timer; + + const auto save_path_inc = + get_save_filename(save_path + ".inc-", points_to_skip, points_to_delete_from_beginning, end); + index->save(save_path_inc.c_str(), false); + const double elapsedSeconds = save_timer.elapsed() / 1000000.0; + const size_t points_saved = end - points_to_skip; + + std::cout << "Saved " << points_saved << " points in " << elapsedSeconds << " seconds (" + << points_saved / elapsedSeconds << " points/second)\n"; + + num_checkpoints_till_snapshot = checkpoints_per_snapshot; + last_snapshot_points_threshold = end; + } + + std::cout << "Number of points in the index post insertion " << end << std::endl; + } + + if (checkpoints_per_snapshot > 0 && last_snapshot_points_threshold != last_point_threshold) + { + const auto save_path_inc = get_save_filename(save_path + ".inc-", points_to_skip, + points_to_delete_from_beginning, last_point_threshold); + // index.save(save_path_inc.c_str(), false); + } + + if (points_to_delete_from_beginning > 0) + { + delete_from_beginning(*index, params, points_to_skip, points_to_delete_from_beginning); + } + + index->save(save_path_inc.c_str(), true); + } - diskann::aligned_free(data); + diskann::aligned_free(data); } -int main(int argc, char **argv) { - std::string data_type, dist_fn, data_path, index_path_prefix; - uint32_t num_threads, R, L, num_start_pts; - float alpha, start_point_norm; - size_t points_to_skip, max_points_to_insert, beginning_index_size, - points_per_checkpoint, checkpoints_per_snapshot, - points_to_delete_from_beginning, start_deletes_after; - bool concurrent; - - // label options - std::string label_file, label_type, universal_label; - std::uint32_t Lf, unique_labels_supported; - - po::options_description desc{program_options_utils::make_program_description( - "test_insert_deletes_consolidate", "Test insert deletes & consolidate")}; - try { - desc.add_options()("help,h", "Print information on arguments"); - - // Required parameters - po::options_description required_configs("Required"); - required_configs.add_options()( - "data_type", po::value(&data_type)->required(), - program_options_utils::DATA_TYPE_DESCRIPTION); - required_configs.add_options()( - "dist_fn", po::value(&dist_fn)->required(), - program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); - required_configs.add_options()( - "index_path_prefix", - po::value(&index_path_prefix)->required(), - program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); - required_configs.add_options()( - "data_path", po::value(&data_path)->required(), - program_options_utils::INPUT_DATA_PATH); - required_configs.add_options()( - "points_to_skip", po::value(&points_to_skip)->required(), - "Skip these first set of points from file"); - required_configs.add_options()( - "beginning_index_size", - po::value(&beginning_index_size)->required(), - "Batch build will be called on these set of points"); - required_configs.add_options()( - "points_per_checkpoint", - po::value(&points_per_checkpoint)->required(), - "Insertions are done in batches of 
points_per_checkpoint"); - required_configs.add_options()( - "checkpoints_per_snapshot", - po::value(&checkpoints_per_snapshot)->required(), - "Save the index to disk every few checkpoints"); - required_configs.add_options()( - "points_to_delete_from_beginning", - po::value(&points_to_delete_from_beginning)->required(), ""); - - // Optional parameters - po::options_description optional_configs("Optional"); - optional_configs.add_options()( - "num_threads,T", - po::value(&num_threads)->default_value(omp_get_num_procs()), - program_options_utils::NUMBER_THREADS_DESCRIPTION); - optional_configs.add_options()("max_degree,R", - po::value(&R)->default_value(64), - program_options_utils::MAX_BUILD_DEGREE); - optional_configs.add_options()( - "Lbuild,L", po::value(&L)->default_value(100), - program_options_utils::GRAPH_BUILD_COMPLEXITY); - optional_configs.add_options()( - "alpha", po::value(&alpha)->default_value(1.2f), - program_options_utils::GRAPH_BUILD_ALPHA); - optional_configs.add_options()( - "max_points_to_insert", - po::value(&max_points_to_insert)->default_value(0), - "These number of points from the file are inserted after " - "points_to_skip"); - optional_configs.add_options()( - "do_concurrent", po::value(&concurrent)->default_value(false), - ""); - optional_configs.add_options()( - "start_deletes_after", - po::value(&start_deletes_after)->default_value(0), ""); - optional_configs.add_options()( - "start_point_norm", - po::value(&start_point_norm)->default_value(0), - "Set the start point to a random point on a sphere of this radius"); - - // optional params for filters - optional_configs.add_options()( - "label_file", po::value(&label_file)->default_value(""), - "Input label file in txt format for Filtered Index search. " - "The file should contain comma separated filters for each node " - "with each line corresponding to a graph node"); - optional_configs.add_options()( - "universal_label", - po::value(&universal_label)->default_value(""), - "Universal label, if using it, only in conjunction with labels_file"); - optional_configs.add_options()( - "FilteredLbuild,Lf", po::value(&Lf)->default_value(0), - "Build complexity for filtered points, higher value " - "results in better graphs"); - optional_configs.add_options()( - "label_type", - po::value(&label_type)->default_value("uint"), - "Storage type of Labels , default value is uint which " - "will consume memory 4 bytes per filter"); - optional_configs.add_options()( - "unique_labels_supported", - po::value(&unique_labels_supported)->default_value(0), - "Number of unique labels supported by the dynamic index."); - - optional_configs.add_options()( - "num_start_points", - po::value(&num_start_pts) - ->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC), - "Set the number of random start (frozen) points to use when " - "inserting and searching"); - - // Merge required and optional parameters - desc.add(required_configs).add(optional_configs); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) { - std::cout << desc; - return 0; +int main(int argc, char **argv) +{ + std::string data_type, dist_fn, data_path, index_path_prefix; + uint32_t num_threads, R, L, num_start_pts; + float alpha, start_point_norm; + size_t points_to_skip, max_points_to_insert, beginning_index_size, points_per_checkpoint, checkpoints_per_snapshot, + points_to_delete_from_beginning, start_deletes_after; + bool concurrent; + + // label options + std::string label_file, label_type, universal_label; + 
std::uint32_t Lf, unique_labels_supported; + + po::options_description desc{program_options_utils::make_program_description("test_insert_deletes_consolidate", + "Test insert deletes & consolidate")}; + try + { + desc.add_options()("help,h", "Print information on arguments"); + + // Required parameters + po::options_description required_configs("Required"); + required_configs.add_options()("data_type", po::value(&data_type)->required(), + program_options_utils::DATA_TYPE_DESCRIPTION); + required_configs.add_options()("dist_fn", po::value(&dist_fn)->required(), + program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); + required_configs.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), + program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); + required_configs.add_options()("data_path", po::value(&data_path)->required(), + program_options_utils::INPUT_DATA_PATH); + required_configs.add_options()("points_to_skip", po::value(&points_to_skip)->required(), + "Skip these first set of points from file"); + required_configs.add_options()("beginning_index_size", po::value(&beginning_index_size)->required(), + "Batch build will be called on these set of points"); + required_configs.add_options()("points_per_checkpoint", po::value(&points_per_checkpoint)->required(), + "Insertions are done in batches of points_per_checkpoint"); + required_configs.add_options()("checkpoints_per_snapshot", + po::value(&checkpoints_per_snapshot)->required(), + "Save the index to disk every few checkpoints"); + required_configs.add_options()("points_to_delete_from_beginning", + po::value(&points_to_delete_from_beginning)->required(), ""); + + // Optional parameters + po::options_description optional_configs("Optional"); + optional_configs.add_options()("num_threads,T", + po::value(&num_threads)->default_value(omp_get_num_procs()), + program_options_utils::NUMBER_THREADS_DESCRIPTION); + optional_configs.add_options()("max_degree,R", po::value(&R)->default_value(64), + program_options_utils::MAX_BUILD_DEGREE); + optional_configs.add_options()("Lbuild,L", po::value(&L)->default_value(100), + program_options_utils::GRAPH_BUILD_COMPLEXITY); + optional_configs.add_options()("alpha", po::value(&alpha)->default_value(1.2f), + program_options_utils::GRAPH_BUILD_ALPHA); + optional_configs.add_options()("max_points_to_insert", + po::value(&max_points_to_insert)->default_value(0), + "These number of points from the file are inserted after " + "points_to_skip"); + optional_configs.add_options()("do_concurrent", po::value(&concurrent)->default_value(false), ""); + optional_configs.add_options()("start_deletes_after", + po::value(&start_deletes_after)->default_value(0), ""); + optional_configs.add_options()("start_point_norm", po::value(&start_point_norm)->default_value(0), + "Set the start point to a random point on a sphere of this radius"); + + // optional params for filters + optional_configs.add_options()("label_file", po::value(&label_file)->default_value(""), + "Input label file in txt format for Filtered Index search. 
" + "The file should contain comma separated filters for each node " + "with each line corresponding to a graph node"); + optional_configs.add_options()("universal_label", po::value(&universal_label)->default_value(""), + "Universal label, if using it, only in conjunction with labels_file"); + optional_configs.add_options()("FilteredLbuild,Lf", po::value(&Lf)->default_value(0), + "Build complexity for filtered points, higher value " + "results in better graphs"); + optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), + "Storage type of Labels , default value is uint which " + "will consume memory 4 bytes per filter"); + optional_configs.add_options()("unique_labels_supported", + po::value(&unique_labels_supported)->default_value(0), + "Number of unique labels supported by the dynamic index."); + + optional_configs.add_options()( + "num_start_points", + po::value(&num_start_pts)->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC), + "Set the number of random start (frozen) points to use when " + "inserting and searching"); + + // Merge required and optional parameters + desc.add(required_configs).add(optional_configs); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) + { + std::cout << desc; + return 0; + } + po::notify(vm); + if (beginning_index_size == 0) + if (start_point_norm == 0) + { + std::cout << "When beginning_index_size is 0, use a start " + "point with " + "appropriate norm" + << std::endl; + return -1; + } } - po::notify(vm); - if (beginning_index_size == 0) - if (start_point_norm == 0) { - std::cout << "When beginning_index_size is 0, use a start " - "point with " - "appropriate norm" - << std::endl; + catch (const std::exception &ex) + { + std::cerr << ex.what() << '\n'; return -1; - } - } catch (const std::exception &ex) { - std::cerr << ex.what() << '\n'; - return -1; - } - - bool has_labels = false; - if (!label_file.empty() || label_file != "") { - has_labels = true; - } - - if (num_start_pts < unique_labels_supported) { - num_start_pts = unique_labels_supported; - } - - try { - diskann::IndexWriteParameters params = - diskann::IndexWriteParametersBuilder(L, R) - .with_max_occlusion_size(500) - .with_alpha(alpha) - .with_num_threads(num_threads) - .with_filter_list_size(Lf) - .build(); - - if (data_type == std::string("int8")) - build_incremental_index( - data_path, params, points_to_skip, max_points_to_insert, - beginning_index_size, start_point_norm, num_start_pts, - points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix, - points_to_delete_from_beginning, start_deletes_after, concurrent, - label_file, universal_label); - else if (data_type == std::string("uint8")) - build_incremental_index( - data_path, params, points_to_skip, max_points_to_insert, - beginning_index_size, start_point_norm, num_start_pts, - points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix, - points_to_delete_from_beginning, start_deletes_after, concurrent, - label_file, universal_label); - else if (data_type == std::string("float")) - build_incremental_index( - data_path, params, points_to_skip, max_points_to_insert, - beginning_index_size, start_point_norm, num_start_pts, - points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix, - points_to_delete_from_beginning, start_deletes_after, concurrent, - label_file, universal_label); - else - std::cout << "Unsupported type. 
Use float/int8/uint8" << std::endl; - } catch (const std::exception &e) { - std::cerr << "Caught exception: " << e.what() << std::endl; - exit(-1); - } catch (...) { - std::cerr << "Caught unknown exception" << std::endl; - exit(-1); - } - - return 0; + } + + bool has_labels = false; + if (!label_file.empty() || label_file != "") + { + has_labels = true; + } + + if (num_start_pts < unique_labels_supported) + { + num_start_pts = unique_labels_supported; + } + + try + { + diskann::IndexWriteParameters params = diskann::IndexWriteParametersBuilder(L, R) + .with_max_occlusion_size(500) + .with_alpha(alpha) + .with_num_threads(num_threads) + .with_filter_list_size(Lf) + .build(); + + if (data_type == std::string("int8")) + build_incremental_index( + data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm, + num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix, + points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label); + else if (data_type == std::string("uint8")) + build_incremental_index( + data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm, + num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix, + points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label); + else if (data_type == std::string("float")) + build_incremental_index(data_path, params, points_to_skip, max_points_to_insert, + beginning_index_size, start_point_norm, num_start_pts, points_per_checkpoint, + checkpoints_per_snapshot, index_path_prefix, points_to_delete_from_beginning, + start_deletes_after, concurrent, label_file, universal_label); + else + std::cout << "Unsupported type. Use float/int8/uint8" << std::endl; + } + catch (const std::exception &e) + { + std::cerr << "Caught exception: " << e.what() << std::endl; + exit(-1); + } + catch (...) + { + std::cerr << "Caught unknown exception" << std::endl; + exit(-1); + } + + return 0; } diff --git a/apps/test_streaming_scenario.cpp b/apps/test_streaming_scenario.cpp index b51db73fd..d8ea0577c 100644 --- a/apps/test_streaming_scenario.cpp +++ b/apps/test_streaming_scenario.cpp @@ -29,516 +29,495 @@ namespace po = boost::program_options; // load_aligned_bin modified to read pieces of the file, but using ifstream // instead of cached_ifstream. template -inline void load_aligned_bin_part(const std::string &bin_file, T *data, - size_t offset_points, size_t points_to_read) { - std::ifstream reader; - reader.exceptions(std::ios::failbit | std::ios::badbit); - reader.open(bin_file, std::ios::binary | std::ios::ate); - size_t actual_file_size = reader.tellg(); - reader.seekg(0, std::ios::beg); - - int npts_i32, dim_i32; - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&dim_i32, sizeof(int)); - size_t npts = (uint32_t)npts_i32; - size_t dim = (uint32_t)dim_i32; - - size_t expected_actual_file_size = - npts * dim * sizeof(T) + 2 * sizeof(uint32_t); - if (actual_file_size != expected_actual_file_size) { - std::stringstream stream; - stream << "Error. File size mismatch. 
Actual size is " << actual_file_size - << " while expected size is " << expected_actual_file_size - << " npts = " << npts << " dim = " << dim - << " size of = " << sizeof(T) << std::endl; - std::cout << stream.str(); - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - - if (offset_points + points_to_read > npts) { - std::stringstream stream; - stream << "Error. Not enough points in file. Requested " << offset_points - << " offset and " << points_to_read << " points, but have only " - << npts << " points" << std::endl; - std::cout << stream.str(); - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - - reader.seekg(2 * sizeof(uint32_t) + offset_points * dim * sizeof(T)); - - const size_t rounded_dim = ROUND_UP(dim, 8); - - for (size_t i = 0; i < points_to_read; i++) { - reader.read((char *)(data + i * rounded_dim), dim * sizeof(T)); - memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T)); - } - reader.close(); +inline void load_aligned_bin_part(const std::string &bin_file, T *data, size_t offset_points, size_t points_to_read) +{ + std::ifstream reader; + reader.exceptions(std::ios::failbit | std::ios::badbit); + reader.open(bin_file, std::ios::binary | std::ios::ate); + size_t actual_file_size = reader.tellg(); + reader.seekg(0, std::ios::beg); + + int npts_i32, dim_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&dim_i32, sizeof(int)); + size_t npts = (uint32_t)npts_i32; + size_t dim = (uint32_t)dim_i32; + + size_t expected_actual_file_size = npts * dim * sizeof(T) + 2 * sizeof(uint32_t); + if (actual_file_size != expected_actual_file_size) + { + std::stringstream stream; + stream << "Error. File size mismatch. Actual size is " << actual_file_size << " while expected size is " + << expected_actual_file_size << " npts = " << npts << " dim = " << dim << " size of = " << sizeof(T) + << std::endl; + std::cout << stream.str(); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + + if (offset_points + points_to_read > npts) + { + std::stringstream stream; + stream << "Error. Not enough points in file. 
Requested " << offset_points << " offset and " << points_to_read + << " points, but have only " << npts << " points" << std::endl; + std::cout << stream.str(); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + + reader.seekg(2 * sizeof(uint32_t) + offset_points * dim * sizeof(T)); + + const size_t rounded_dim = ROUND_UP(dim, 8); + + for (size_t i = 0; i < points_to_read; i++) + { + reader.read((char *)(data + i * rounded_dim), dim * sizeof(T)); + memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T)); + } + reader.close(); } -std::string get_save_filename(const std::string &save_path, - size_t active_window, size_t consolidate_interval, - size_t max_points_to_insert) { - std::string final_path = save_path; - final_path += "act" + std::to_string(active_window) + "-"; - final_path += "cons" + std::to_string(consolidate_interval) + "-"; - final_path += "max" + std::to_string(max_points_to_insert); - return final_path; +std::string get_save_filename(const std::string &save_path, size_t active_window, size_t consolidate_interval, + size_t max_points_to_insert) +{ + std::string final_path = save_path; + final_path += "act" + std::to_string(active_window) + "-"; + final_path += "cons" + std::to_string(consolidate_interval) + "-"; + final_path += "max" + std::to_string(max_points_to_insert); + return final_path; } template -void insert_next_batch(diskann::AbstractIndex &index, size_t start, size_t end, - size_t insert_threads, T *data, size_t aligned_dim, - std::vector> &pts_to_labels) { - try { - diskann::Timer insert_timer; - std::cout << std::endl - << "Inserting from " << start << " to " << end << std::endl; - - size_t num_failed = 0; +void insert_next_batch(diskann::AbstractIndex &index, size_t start, size_t end, size_t insert_threads, T *data, + size_t aligned_dim, std::vector> &pts_to_labels) +{ + try + { + diskann::Timer insert_timer; + std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl; + + size_t num_failed = 0; #pragma omp parallel for num_threads((int32_t)insert_threads) schedule(dynamic) reduction(+ : num_failed) - for (int64_t j = start; j < (int64_t)end; j++) { - int insert_result = -1; - if (pts_to_labels.size() > 0) { - insert_result = index.insert_point(&data[(j - start) * aligned_dim], - 1 + static_cast(j), - pts_to_labels[j - start]); - } else { - insert_result = index.insert_point(&data[(j - start) * aligned_dim], - 1 + static_cast(j)); - } - - if (insert_result != 0) { - std::cerr << "Insert failed " << j << std::endl; - num_failed++; - } + for (int64_t j = start; j < (int64_t)end; j++) + { + int insert_result = -1; + if (pts_to_labels.size() > 0) + { + insert_result = index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast(j), + pts_to_labels[j - start]); + } + else + { + insert_result = index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast(j)); + } + + if (insert_result != 0) + { + std::cerr << "Insert failed " << j << std::endl; + num_failed++; + } + } + const double elapsedSeconds = insert_timer.elapsed() / 1000000.0; + std::cout << "Insertion time " << elapsedSeconds << " seconds (" << (end - start) / elapsedSeconds + << " points/second overall, " << (end - start) / elapsedSeconds / insert_threads << " per thread)" + << std::endl; + if (num_failed > 0) + std::cout << num_failed << " of " << end - start << "inserts failed" << std::endl; + } + catch (std::system_error &e) + { + std::cout << "Exiting after catching exception in insertion task: " << e.what() << 
std::endl; + exit(-1); } - const double elapsedSeconds = insert_timer.elapsed() / 1000000.0; - std::cout << "Insertion time " << elapsedSeconds << " seconds (" - << (end - start) / elapsedSeconds << " points/second overall, " - << (end - start) / elapsedSeconds / insert_threads - << " per thread)" << std::endl; - if (num_failed > 0) - std::cout << num_failed << " of " << end - start << "inserts failed" - << std::endl; - } catch (std::system_error &e) { - std::cout << "Exiting after catching exception in insertion task: " - << e.what() << std::endl; - exit(-1); - } } template -void delete_and_consolidate(diskann::AbstractIndex &index, - diskann::IndexWriteParameters &delete_params, - size_t start, size_t end) { - try { - std::cout << std::endl - << "Lazy deleting points " << start << " to " << end << "... "; - for (size_t i = start; i < end; ++i) - index.lazy_delete(static_cast(1 + i)); - std::cout << "lazy delete done." << std::endl; - - auto report = index.consolidate_deletes(delete_params); - while (report._status != - diskann::consolidation_report::status_code::SUCCESS) { - int wait_time = 5; - if (report._status == - diskann::consolidation_report::status_code::LOCK_FAIL) { - diskann::cerr << "Unable to acquire consolidate delete lock after " - << "deleting points " << start << " to " << end - << ". Will retry in " << wait_time << "seconds." - << std::endl; - } else if (report._status == diskann::consolidation_report::status_code:: - INCONSISTENT_COUNT_ERROR) { - diskann::cerr << "Inconsistent counts in data structure. " - << "Will retry in " << wait_time << "seconds." - << std::endl; - } else { - std::cerr << "Exiting after unknown error in consolidate delete" - << std::endl; +void delete_and_consolidate(diskann::AbstractIndex &index, diskann::IndexWriteParameters &delete_params, size_t start, + size_t end) +{ + try + { + std::cout << std::endl << "Lazy deleting points " << start << " to " << end << "... "; + for (size_t i = start; i < end; ++i) + index.lazy_delete(static_cast(1 + i)); + std::cout << "lazy delete done." << std::endl; + + auto report = index.consolidate_deletes(delete_params); + while (report._status != diskann::consolidation_report::status_code::SUCCESS) + { + int wait_time = 5; + if (report._status == diskann::consolidation_report::status_code::LOCK_FAIL) + { + diskann::cerr << "Unable to acquire consolidate delete lock after " + << "deleting points " << start << " to " << end << ". Will retry in " << wait_time + << "seconds." << std::endl; + } + else if (report._status == diskann::consolidation_report::status_code::INCONSISTENT_COUNT_ERROR) + { + diskann::cerr << "Inconsistent counts in data structure. " + << "Will retry in " << wait_time << "seconds." 
<< std::endl; + } + else + { + std::cerr << "Exiting after unknown error in consolidate delete" << std::endl; + exit(-1); + } + std::this_thread::sleep_for(std::chrono::seconds(wait_time)); + report = index.consolidate_deletes(delete_params); + } + auto points_processed = report._active_points + report._slots_released; + auto deletion_rate = points_processed / report._time; + std::cout << "#active points: " << report._active_points << std::endl + << "max points: " << report._max_points << std::endl + << "empty slots: " << report._empty_slots << std::endl + << "deletes processed: " << report._slots_released << std::endl + << "latest delete size: " << report._delete_set_size << std::endl + << "Deletion rate: " << deletion_rate << "/sec " + << "Deletion rate: " << deletion_rate / delete_params.num_threads << "/thread/sec " << std::endl; + } + catch (std::system_error &e) + { + std::cerr << "Exiting after catching exception in deletion task: " << e.what() << std::endl; exit(-1); - } - std::this_thread::sleep_for(std::chrono::seconds(wait_time)); - report = index.consolidate_deletes(delete_params); } - auto points_processed = report._active_points + report._slots_released; - auto deletion_rate = points_processed / report._time; - std::cout << "#active points: " << report._active_points << std::endl - << "max points: " << report._max_points << std::endl - << "empty slots: " << report._empty_slots << std::endl - << "deletes processed: " << report._slots_released << std::endl - << "latest delete size: " << report._delete_set_size << std::endl - << "Deletion rate: " << deletion_rate << "/sec " - << "Deletion rate: " << deletion_rate / delete_params.num_threads - << "/thread/sec " << std::endl; - } catch (std::system_error &e) { - std::cerr << "Exiting after catching exception in deletion task: " - << e.what() << std::endl; - exit(-1); - } } template -void build_incremental_index( - const std::string &data_path, const uint32_t L, const uint32_t R, - const float alpha, const uint32_t insert_threads, - const uint32_t consolidate_threads, size_t max_points_to_insert, - size_t active_window, size_t consolidate_interval, - const float start_point_norm, uint32_t num_start_pts, - const std::string &save_path, const std::string &label_file, - const std::string &universal_label, const uint32_t Lf) { - const uint32_t C = 500; - const bool saturate_graph = false; - bool has_labels = label_file != ""; - - diskann::IndexWriteParameters params = - diskann::IndexWriteParametersBuilder(L, R) - .with_max_occlusion_size(C) - .with_alpha(alpha) - .with_saturate_graph(saturate_graph) - .with_num_threads(insert_threads) - .with_filter_list_size(Lf) - .build(); - - auto index_search_params = diskann::IndexSearchParams(L, insert_threads); - diskann::IndexWriteParameters delete_params = - diskann::IndexWriteParametersBuilder(L, R) - .with_max_occlusion_size(C) - .with_alpha(alpha) - .with_saturate_graph(saturate_graph) - .with_num_threads(consolidate_threads) - .with_filter_list_size(Lf) - .build(); - - size_t dim, aligned_dim; - size_t num_points; - - std::vector> pts_to_labels; - - const auto save_path_inc = - get_save_filename(save_path + ".after-streaming-", active_window, - consolidate_interval, max_points_to_insert); - std::string labels_file_to_use = save_path_inc + "_label_formatted.txt"; - std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt"; - if (has_labels) { - convert_labels_string_to_int(label_file, labels_file_to_use, - mem_labels_int_map_file, universal_label); - auto parse_result = - 
diskann::parse_formatted_label_file(labels_file_to_use); - pts_to_labels = std::get<0>(parse_result); - } - - diskann::get_bin_metadata(data_path, num_points, dim); - diskann::cout << "metadata: file " << data_path << " has " << num_points - << " points in " << dim << " dims" << std::endl; - aligned_dim = ROUND_UP(dim, 8); - auto index_config = - diskann::IndexConfigBuilder() - .with_metric(diskann::L2) - .with_dimension(dim) - .with_max_points(active_window + 4 * consolidate_interval) - .is_dynamic_index(true) - .is_enable_tags(true) - .is_use_opq(false) - .is_filtered(has_labels) - .with_num_pq_chunks(0) - .is_pq_dist_build(false) - .with_num_frozen_pts(num_start_pts) - .with_tag_type(diskann_type_to_name()) - .with_label_type(diskann_type_to_name()) - .with_data_type(diskann_type_to_name()) - .with_index_write_params(params) - .with_index_search_params(index_search_params) - .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY) - .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY) - .build(); - - diskann::IndexFactory index_factory = diskann::IndexFactory(index_config); - auto index = index_factory.create_instance(); - - if (universal_label != "") { - LabelT u_label = 0; - index->set_universal_label(u_label); - } - - if (max_points_to_insert == 0) { - max_points_to_insert = num_points; - } - - if (num_points < max_points_to_insert) - throw diskann::ANNException(std::string("num_points(") + - std::to_string(num_points) + - ") < max_points_to_insert(" + - std::to_string(max_points_to_insert) + ")", - -1, __FUNCSIG__, __FILE__, __LINE__); - - if (max_points_to_insert < active_window + consolidate_interval) - throw diskann::ANNException("ERROR: max_points_to_insert < " - "active_window + consolidate_interval", - -1, __FUNCSIG__, __FILE__, __LINE__); - - if (consolidate_interval < max_points_to_insert / 1000) - throw diskann::ANNException("ERROR: consolidate_interval is too small", -1, - __FUNCSIG__, __FILE__, __LINE__); - - index->set_start_points_at_random(static_cast(start_point_norm)); - - T *data = nullptr; - diskann::alloc_aligned((void **)&data, - std::max(consolidate_interval, active_window) * - aligned_dim * sizeof(T), - 8 * sizeof(T)); - - std::vector tags(max_points_to_insert); - std::iota(tags.begin(), tags.end(), static_cast(0)); - - diskann::Timer timer; - - std::vector> delete_tasks; - - auto insert_task = std::async(std::launch::async, [&]() { - load_aligned_bin_part(data_path, data, 0, active_window); - insert_next_batch(*index, (size_t)0, active_window, - params.num_threads, data, aligned_dim, - pts_to_labels); - }); - insert_task.wait(); - - for (size_t start = active_window; - start + consolidate_interval <= max_points_to_insert; - start += consolidate_interval) { - auto end = std::min(start + consolidate_interval, max_points_to_insert); +void build_incremental_index(const std::string &data_path, const uint32_t L, const uint32_t R, const float alpha, + const uint32_t insert_threads, const uint32_t consolidate_threads, + size_t max_points_to_insert, size_t active_window, size_t consolidate_interval, + const float start_point_norm, uint32_t num_start_pts, const std::string &save_path, + const std::string &label_file, const std::string &universal_label, const uint32_t Lf) +{ + const uint32_t C = 500; + const bool saturate_graph = false; + bool has_labels = label_file != ""; + + diskann::IndexWriteParameters params = diskann::IndexWriteParametersBuilder(L, R) + .with_max_occlusion_size(C) + .with_alpha(alpha) + 
.with_saturate_graph(saturate_graph) + .with_num_threads(insert_threads) + .with_filter_list_size(Lf) + .build(); + + auto index_search_params = diskann::IndexSearchParams(L, insert_threads); + diskann::IndexWriteParameters delete_params = diskann::IndexWriteParametersBuilder(L, R) + .with_max_occlusion_size(C) + .with_alpha(alpha) + .with_saturate_graph(saturate_graph) + .with_num_threads(consolidate_threads) + .with_filter_list_size(Lf) + .build(); + + size_t dim, aligned_dim; + size_t num_points; + + std::vector> pts_to_labels; + + const auto save_path_inc = + get_save_filename(save_path + ".after-streaming-", active_window, consolidate_interval, max_points_to_insert); + std::string labels_file_to_use = save_path_inc + "_label_formatted.txt"; + std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt"; + if (has_labels) + { + convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label); + auto parse_result = diskann::parse_formatted_label_file(labels_file_to_use); + pts_to_labels = std::get<0>(parse_result); + } + + diskann::get_bin_metadata(data_path, num_points, dim); + diskann::cout << "metadata: file " << data_path << " has " << num_points << " points in " << dim << " dims" + << std::endl; + aligned_dim = ROUND_UP(dim, 8); + auto index_config = diskann::IndexConfigBuilder() + .with_metric(diskann::L2) + .with_dimension(dim) + .with_max_points(active_window + 4 * consolidate_interval) + .is_dynamic_index(true) + .is_enable_tags(true) + .is_use_opq(false) + .is_filtered(has_labels) + .with_num_pq_chunks(0) + .is_pq_dist_build(false) + .with_num_frozen_pts(num_start_pts) + .with_tag_type(diskann_type_to_name()) + .with_label_type(diskann_type_to_name()) + .with_data_type(diskann_type_to_name()) + .with_index_write_params(params) + .with_index_search_params(index_search_params) + .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY) + .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY) + .build(); + + diskann::IndexFactory index_factory = diskann::IndexFactory(index_config); + auto index = index_factory.create_instance(); + + if (universal_label != "") + { + LabelT u_label = 0; + index->set_universal_label(u_label); + } + + if (max_points_to_insert == 0) + { + max_points_to_insert = num_points; + } + + if (num_points < max_points_to_insert) + throw diskann::ANNException(std::string("num_points(") + std::to_string(num_points) + + ") < max_points_to_insert(" + std::to_string(max_points_to_insert) + ")", + -1, __FUNCSIG__, __FILE__, __LINE__); + + if (max_points_to_insert < active_window + consolidate_interval) + throw diskann::ANNException("ERROR: max_points_to_insert < " + "active_window + consolidate_interval", + -1, __FUNCSIG__, __FILE__, __LINE__); + + if (consolidate_interval < max_points_to_insert / 1000) + throw diskann::ANNException("ERROR: consolidate_interval is too small", -1, __FUNCSIG__, __FILE__, __LINE__); + + index->set_start_points_at_random(static_cast(start_point_norm)); + + T *data = nullptr; + diskann::alloc_aligned((void **)&data, std::max(consolidate_interval, active_window) * aligned_dim * sizeof(T), + 8 * sizeof(T)); + + std::vector tags(max_points_to_insert); + std::iota(tags.begin(), tags.end(), static_cast(0)); + + diskann::Timer timer; + + std::vector> delete_tasks; + auto insert_task = std::async(std::launch::async, [&]() { - load_aligned_bin_part(data_path, data, start, end - start); - insert_next_batch(*index, start, end, params.num_threads, - data, aligned_dim, 
pts_to_labels); + load_aligned_bin_part(data_path, data, 0, active_window); + insert_next_batch(*index, (size_t)0, active_window, params.num_threads, data, aligned_dim, + pts_to_labels); }); insert_task.wait(); - if (delete_tasks.size() > 0) - delete_tasks[delete_tasks.size() - 1].wait(); - if (start >= active_window + consolidate_interval) { - auto start_del = start - active_window - consolidate_interval; - auto end_del = start - active_window; - - delete_tasks.emplace_back(std::async(std::launch::async, [&]() { - delete_and_consolidate( - *index, delete_params, (size_t)start_del, (size_t)end_del); - })); + for (size_t start = active_window; start + consolidate_interval <= max_points_to_insert; + start += consolidate_interval) + { + auto end = std::min(start + consolidate_interval, max_points_to_insert); + auto insert_task = std::async(std::launch::async, [&]() { + load_aligned_bin_part(data_path, data, start, end - start); + insert_next_batch(*index, start, end, params.num_threads, data, aligned_dim, + pts_to_labels); + }); + insert_task.wait(); + + if (delete_tasks.size() > 0) + delete_tasks[delete_tasks.size() - 1].wait(); + if (start >= active_window + consolidate_interval) + { + auto start_del = start - active_window - consolidate_interval; + auto end_del = start - active_window; + + delete_tasks.emplace_back(std::async(std::launch::async, [&]() { + delete_and_consolidate(*index, delete_params, (size_t)start_del, (size_t)end_del); + })); + } } - } - if (delete_tasks.size() > 0) - delete_tasks[delete_tasks.size() - 1].wait(); + if (delete_tasks.size() > 0) + delete_tasks[delete_tasks.size() - 1].wait(); - std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n"; + std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n"; - index->save(save_path_inc.c_str(), true); + index->save(save_path_inc.c_str(), true); - diskann::aligned_free(data); + diskann::aligned_free(data); } -int main(int argc, char **argv) { - std::string data_type, dist_fn, data_path, index_path_prefix, label_file, - universal_label, label_type; - uint32_t insert_threads, consolidate_threads, R, L, num_start_pts, Lf, - unique_labels_supported; - float alpha, start_point_norm; - size_t max_points_to_insert, active_window, consolidate_interval; - - po::options_description desc{program_options_utils::make_program_description( - "test_streaming_scenario", "Test insert deletes & consolidate")}; - try { - desc.add_options()("help,h", "Print information on arguments"); - - // Required parameters - po::options_description required_configs("Required"); - required_configs.add_options()( - "data_type", po::value(&data_type)->required(), - program_options_utils::DATA_TYPE_DESCRIPTION); - required_configs.add_options()( - "dist_fn", po::value(&dist_fn)->required(), - program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); - required_configs.add_options()( - "index_path_prefix", - po::value(&index_path_prefix)->required(), - program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); - required_configs.add_options()( - "data_path", po::value(&data_path)->required(), - program_options_utils::INPUT_DATA_PATH); - required_configs.add_options()( - "active_window", po::value(&active_window)->required(), - "Program maintains an index over an active window of " - "this size that slides through the data"); - required_configs.add_options()( - "consolidate_interval", - po::value(&consolidate_interval)->required(), - "The program simultaneously adds this number of points to the " - "right of " - "the window while deleting the same 
number from the left"); - required_configs.add_options()( - "start_point_norm", po::value(&start_point_norm)->required(), - "Set the start point to a random point on a sphere of this radius"); - - // Optional parameters - po::options_description optional_configs("Optional"); - optional_configs.add_options()("max_degree,R", - po::value(&R)->default_value(64), - program_options_utils::MAX_BUILD_DEGREE); - optional_configs.add_options()( - "Lbuild,L", po::value(&L)->default_value(100), - program_options_utils::GRAPH_BUILD_COMPLEXITY); - optional_configs.add_options()( - "alpha", po::value(&alpha)->default_value(1.2f), - program_options_utils::GRAPH_BUILD_ALPHA); - optional_configs.add_options()( - "insert_threads", - po::value(&insert_threads) - ->default_value(omp_get_num_procs() / 2), - "Number of threads used for inserting into the index (defaults to " - "omp_get_num_procs()/2)"); - optional_configs.add_options()( - "consolidate_threads", - po::value(&consolidate_threads) - ->default_value(omp_get_num_procs() / 2), - "Number of threads used for consolidating deletes to " - "the index (defaults to omp_get_num_procs()/2)"); - optional_configs.add_options()( - "max_points_to_insert", - po::value(&max_points_to_insert)->default_value(0), - "The number of points from the file that the program streams " - "over "); - optional_configs.add_options()( - "num_start_points", - po::value(&num_start_pts) - ->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC), - "Set the number of random start (frozen) points to use when " - "inserting and searching"); - - optional_configs.add_options()( - "label_file", po::value(&label_file)->default_value(""), - "Input label file in txt format for Filtered Index search. " - "The file should contain comma separated filters for each node " - "with each line corresponding to a graph node"); - optional_configs.add_options()( - "universal_label", - po::value(&universal_label)->default_value(""), - "Universal label, if using it, only in conjunction with labels_file"); - optional_configs.add_options()( - "FilteredLbuild,Lf", po::value(&Lf)->default_value(0), - "Build complexity for filtered points, higher value " - "results in better graphs"); - optional_configs.add_options()( - "label_type", - po::value(&label_type)->default_value("uint"), - "Storage type of Labels , default value is uint which " - "will consume memory 4 bytes per filter"); - optional_configs.add_options()( - "unique_labels_supported", - po::value(&unique_labels_supported)->default_value(0), - "Number of unique labels supported by the dynamic index."); - - // Merge required and optional parameters - desc.add(required_configs).add(optional_configs); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) { - std::cout << desc; - return 0; +int main(int argc, char **argv) +{ + std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type; + uint32_t insert_threads, consolidate_threads, R, L, num_start_pts, Lf, unique_labels_supported; + float alpha, start_point_norm; + size_t max_points_to_insert, active_window, consolidate_interval; + + po::options_description desc{program_options_utils::make_program_description("test_streaming_scenario", + "Test insert deletes & consolidate")}; + try + { + desc.add_options()("help,h", "Print information on arguments"); + + // Required parameters + po::options_description required_configs("Required"); + required_configs.add_options()("data_type", 
po::value(&data_type)->required(), + program_options_utils::DATA_TYPE_DESCRIPTION); + required_configs.add_options()("dist_fn", po::value(&dist_fn)->required(), + program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); + required_configs.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), + program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); + required_configs.add_options()("data_path", po::value(&data_path)->required(), + program_options_utils::INPUT_DATA_PATH); + required_configs.add_options()("active_window", po::value(&active_window)->required(), + "Program maintains an index over an active window of " + "this size that slides through the data"); + required_configs.add_options()("consolidate_interval", po::value(&consolidate_interval)->required(), + "The program simultaneously adds this number of points to the " + "right of " + "the window while deleting the same number from the left"); + required_configs.add_options()("start_point_norm", po::value(&start_point_norm)->required(), + "Set the start point to a random point on a sphere of this radius"); + + // Optional parameters + po::options_description optional_configs("Optional"); + optional_configs.add_options()("max_degree,R", po::value(&R)->default_value(64), + program_options_utils::MAX_BUILD_DEGREE); + optional_configs.add_options()("Lbuild,L", po::value(&L)->default_value(100), + program_options_utils::GRAPH_BUILD_COMPLEXITY); + optional_configs.add_options()("alpha", po::value(&alpha)->default_value(1.2f), + program_options_utils::GRAPH_BUILD_ALPHA); + optional_configs.add_options()("insert_threads", + po::value(&insert_threads)->default_value(omp_get_num_procs() / 2), + "Number of threads used for inserting into the index (defaults to " + "omp_get_num_procs()/2)"); + optional_configs.add_options()( + "consolidate_threads", po::value(&consolidate_threads)->default_value(omp_get_num_procs() / 2), + "Number of threads used for consolidating deletes to " + "the index (defaults to omp_get_num_procs()/2)"); + optional_configs.add_options()("max_points_to_insert", + po::value(&max_points_to_insert)->default_value(0), + "The number of points from the file that the program streams " + "over "); + optional_configs.add_options()( + "num_start_points", + po::value(&num_start_pts)->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC), + "Set the number of random start (frozen) points to use when " + "inserting and searching"); + + optional_configs.add_options()("label_file", po::value(&label_file)->default_value(""), + "Input label file in txt format for Filtered Index search. 
" + "The file should contain comma separated filters for each node " + "with each line corresponding to a graph node"); + optional_configs.add_options()("universal_label", po::value(&universal_label)->default_value(""), + "Universal label, if using it, only in conjunction with labels_file"); + optional_configs.add_options()("FilteredLbuild,Lf", po::value(&Lf)->default_value(0), + "Build complexity for filtered points, higher value " + "results in better graphs"); + optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), + "Storage type of Labels , default value is uint which " + "will consume memory 4 bytes per filter"); + optional_configs.add_options()("unique_labels_supported", + po::value(&unique_labels_supported)->default_value(0), + "Number of unique labels supported by the dynamic index."); + + // Merge required and optional parameters + desc.add(required_configs).add(optional_configs); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) + { + std::cout << desc; + return 0; + } + po::notify(vm); + } + catch (const std::exception &ex) + { + std::cerr << ex.what() << '\n'; + return -1; + } + + // Validate arguments + if (start_point_norm == 0) + { + std::cout << "When beginning_index_size is 0, use a start point with " + "appropriate norm" + << std::endl; + return -1; + } + + if (label_type != std::string("ushort") && label_type != std::string("uint")) + { + std::cerr << "Invalid label type. Supported types are uint and ushort" << std::endl; + return -1; } - po::notify(vm); - } catch (const std::exception &ex) { - std::cerr << ex.what() << '\n'; - return -1; - } - - // Validate arguments - if (start_point_norm == 0) { - std::cout << "When beginning_index_size is 0, use a start point with " - "appropriate norm" - << std::endl; - return -1; - } - - if (label_type != std::string("ushort") && - label_type != std::string("uint")) { - std::cerr << "Invalid label type. Supported types are uint and ushort" - << std::endl; - return -1; - } - - if (data_type != std::string("int8") && data_type != std::string("uint8") && - data_type != std::string("float")) { - std::cerr << "Invalid data type. Supported types are int8, uint8 and float" - << std::endl; - return -1; - } - - // TODO: Are additional distance functions supported? - if (dist_fn != std::string("l2") && dist_fn != std::string("mips")) { - std::cerr - << "Invalid distance function. 
Supported functions are l2 and mips" - << std::endl; - return -1; - } - - if (num_start_pts < unique_labels_supported) { - num_start_pts = unique_labels_supported; - } - - try { - if (data_type == std::string("uint8")) { - if (label_type == std::string("ushort")) { - build_incremental_index( - data_path, L, R, alpha, insert_threads, consolidate_threads, - max_points_to_insert, active_window, consolidate_interval, - start_point_norm, num_start_pts, index_path_prefix, label_file, - universal_label, Lf); - } else if (label_type == std::string("uint")) { - build_incremental_index( - data_path, L, R, alpha, insert_threads, consolidate_threads, - max_points_to_insert, active_window, consolidate_interval, - start_point_norm, num_start_pts, index_path_prefix, label_file, - universal_label, Lf); - } - } else if (data_type == std::string("int8")) { - if (label_type == std::string("ushort")) { - build_incremental_index( - data_path, L, R, alpha, insert_threads, consolidate_threads, - max_points_to_insert, active_window, consolidate_interval, - start_point_norm, num_start_pts, index_path_prefix, label_file, - universal_label, Lf); - } else if (label_type == std::string("uint")) { - build_incremental_index( - data_path, L, R, alpha, insert_threads, consolidate_threads, - max_points_to_insert, active_window, consolidate_interval, - start_point_norm, num_start_pts, index_path_prefix, label_file, - universal_label, Lf); - } - } else if (data_type == std::string("float")) { - if (label_type == std::string("ushort")) { - build_incremental_index( - data_path, L, R, alpha, insert_threads, consolidate_threads, - max_points_to_insert, active_window, consolidate_interval, - start_point_norm, num_start_pts, index_path_prefix, label_file, - universal_label, Lf); - } else if (label_type == std::string("uint")) { - build_incremental_index( - data_path, L, R, alpha, insert_threads, consolidate_threads, - max_points_to_insert, active_window, consolidate_interval, - start_point_norm, num_start_pts, index_path_prefix, label_file, - universal_label, Lf); - } + + if (data_type != std::string("int8") && data_type != std::string("uint8") && data_type != std::string("float")) + { + std::cerr << "Invalid data type. Supported types are int8, uint8 and float" << std::endl; + return -1; } - } catch (const std::exception &e) { - std::cerr << "Caught exception: " << e.what() << std::endl; - exit(-1); - } catch (...) { - std::cerr << "Caught unknown exception" << std::endl; - exit(-1); - } - - return 0; + + // TODO: Are additional distance functions supported? + if (dist_fn != std::string("l2") && dist_fn != std::string("mips")) + { + std::cerr << "Invalid distance function. 
Supported functions are l2 and mips" << std::endl; + return -1; + } + + if (num_start_pts < unique_labels_supported) + { + num_start_pts = unique_labels_supported; + } + + try + { + if (data_type == std::string("uint8")) + { + if (label_type == std::string("ushort")) + { + build_incremental_index( + data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window, + consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file, + universal_label, Lf); + } + else if (label_type == std::string("uint")) + { + build_incremental_index( + data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window, + consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file, + universal_label, Lf); + } + } + else if (data_type == std::string("int8")) + { + if (label_type == std::string("ushort")) + { + build_incremental_index( + data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window, + consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file, + universal_label, Lf); + } + else if (label_type == std::string("uint")) + { + build_incremental_index( + data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window, + consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file, + universal_label, Lf); + } + } + else if (data_type == std::string("float")) + { + if (label_type == std::string("ushort")) + { + build_incremental_index( + data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window, + consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file, + universal_label, Lf); + } + else if (label_type == std::string("uint")) + { + build_incremental_index( + data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window, + consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file, + universal_label, Lf); + } + } + } + catch (const std::exception &e) + { + std::cerr << "Caught exception: " << e.what() << std::endl; + exit(-1); + } + catch (...) 
+ { + std::cerr << "Caught unknown exception" << std::endl; + exit(-1); + } + + return 0; } diff --git a/apps/utils/bin_to_fvecs.cpp b/apps/utils/bin_to_fvecs.cpp index a9b86686c..ebd8229ba 100644 --- a/apps/utils/bin_to_fvecs.cpp +++ b/apps/utils/bin_to_fvecs.cpp @@ -4,58 +4,60 @@ #include "util.h" #include -void block_convert(std::ifstream &writr, std::ofstream &readr, float *read_buf, - float *write_buf, uint64_t npts, uint64_t ndims) { - writr.write((char *)read_buf, - npts * (ndims * sizeof(float) + sizeof(unsigned))); +void block_convert(std::ifstream &writr, std::ofstream &readr, float *read_buf, float *write_buf, uint64_t npts, + uint64_t ndims) +{ + writr.write((char *)read_buf, npts * (ndims * sizeof(float) + sizeof(unsigned))); #pragma omp parallel for - for (uint64_t i = 0; i < npts; i++) { - memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, - ndims * sizeof(float)); - } - readr.read((char *)write_buf, npts * ndims * sizeof(float)); + for (uint64_t i = 0; i < npts; i++) + { + memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, ndims * sizeof(float)); + } + readr.read((char *)write_buf, npts * ndims * sizeof(float)); } -int main(int argc, char **argv) { - if (argc != 3) { - std::cout << argv[0] << " input_bin output_fvecs" << std::endl; - exit(-1); - } - std::ifstream readr(argv[1], std::ios::binary); - int npts_s32; - int ndims_s32; - readr.read((char *)&npts_s32, sizeof(int32_t)); - readr.read((char *)&ndims_s32, sizeof(int32_t)); - size_t npts = npts_s32; - size_t ndims = ndims_s32; - uint32_t ndims_u32 = (uint32_t)ndims_s32; - // uint64_t fsize = writr.tellg(); - readr.seekg(0, std::ios::beg); - - unsigned ndims_u32; - writr.write((char *)&ndims_u32, sizeof(unsigned)); - writr.seekg(0, std::ios::beg); - uint64_t ndims = (uint64_t)ndims_u32; - uint64_t npts = fsize / ((ndims + 1) * sizeof(float)); - std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims - << std::endl; - - uint64_t blk_size = 131072; - uint64_t nblks = ROUND_UP(npts, blk_size) / blk_size; - std::cout << "# blks: " << nblks << std::endl; - - std::ofstream writr(argv[2], std::ios::binary); - float *read_buf = new float[npts * (ndims + 1)]; - float *write_buf = new float[npts * ndims]; - for (uint64_t i = 0; i < nblks; i++) { - uint64_t cblk_size = std::min(npts - i * blk_size, blk_size); - block_convert(writr, readr, read_buf, write_buf, cblk_size, ndims); - std::cout << "Block #" << i << " written" << std::endl; - } - - delete[] read_buf; - delete[] write_buf; - - writr.close(); - readr.close(); +int main(int argc, char **argv) +{ + if (argc != 3) + { + std::cout << argv[0] << " input_bin output_fvecs" << std::endl; + exit(-1); + } + std::ifstream readr(argv[1], std::ios::binary); + int npts_s32; + int ndims_s32; + readr.read((char *)&npts_s32, sizeof(int32_t)); + readr.read((char *)&ndims_s32, sizeof(int32_t)); + size_t npts = npts_s32; + size_t ndims = ndims_s32; + uint32_t ndims_u32 = (uint32_t)ndims_s32; + // uint64_t fsize = writr.tellg(); + readr.seekg(0, std::ios::beg); + + unsigned ndims_u32; + writr.write((char *)&ndims_u32, sizeof(unsigned)); + writr.seekg(0, std::ios::beg); + uint64_t ndims = (uint64_t)ndims_u32; + uint64_t npts = fsize / ((ndims + 1) * sizeof(float)); + std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl; + + uint64_t blk_size = 131072; + uint64_t nblks = ROUND_UP(npts, blk_size) / blk_size; + std::cout << "# blks: " << nblks << std::endl; + + std::ofstream writr(argv[2], std::ios::binary); + float *read_buf = new 
float[npts * (ndims + 1)]; + float *write_buf = new float[npts * ndims]; + for (uint64_t i = 0; i < nblks; i++) + { + uint64_t cblk_size = std::min(npts - i * blk_size, blk_size); + block_convert(writr, readr, read_buf, write_buf, cblk_size, ndims); + std::cout << "Block #" << i << " written" << std::endl; + } + + delete[] read_buf; + delete[] write_buf; + + writr.close(); + readr.close(); } diff --git a/apps/utils/bin_to_tsv.cpp b/apps/utils/bin_to_tsv.cpp index 62ed77e55..5c31c8595 100644 --- a/apps/utils/bin_to_tsv.cpp +++ b/apps/utils/bin_to_tsv.cpp @@ -5,64 +5,65 @@ #include template -void block_convert(std::ofstream &writer, std::ifstream &reader, T *read_buf, - size_t npts, size_t ndims) { - reader.read((char *)read_buf, npts * ndims * sizeof(float)); +void block_convert(std::ofstream &writer, std::ifstream &reader, T *read_buf, size_t npts, size_t ndims) +{ + reader.read((char *)read_buf, npts * ndims * sizeof(float)); - for (size_t i = 0; i < npts; i++) { - for (size_t d = 0; d < ndims; d++) { - writer << read_buf[d + i * ndims]; - if (d < ndims - 1) - writer << "\t"; - else - writer << "\n"; + for (size_t i = 0; i < npts; i++) + { + for (size_t d = 0; d < ndims; d++) + { + writer << read_buf[d + i * ndims]; + if (d < ndims - 1) + writer << "\t"; + else + writer << "\n"; + } } - } } -int main(int argc, char **argv) { - if (argc != 4) { - std::cout << argv[0] << " input_bin output_tsv" - << std::endl; - exit(-1); - } - std::string type_string(argv[1]); - if ((type_string != std::string("float")) && - (type_string != std::string("int8")) && - (type_string != std::string("uin8"))) { - std::cerr << "Error: type not supported. Use float/int8/uint8" << std::endl; - } +int main(int argc, char **argv) +{ + if (argc != 4) + { + std::cout << argv[0] << " input_bin output_tsv" << std::endl; + exit(-1); + } + std::string type_string(argv[1]); + if ((type_string != std::string("float")) && (type_string != std::string("int8")) && + (type_string != std::string("uin8"))) + { + std::cerr << "Error: type not supported. 
Use float/int8/uint8" << std::endl; + } - std::ifstream reader(argv[2], std::ios::binary); - uint32_t npts_u32; - uint32_t ndims_u32; - reader.read((char *)&npts_u32, sizeof(uint32_t)); - reader.read((char *)&ndims_u32, sizeof(uint32_t)); - size_t npts = npts_u32; - size_t ndims = ndims_u32; - std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims - << std::endl; + std::ifstream reader(argv[2], std::ios::binary); + uint32_t npts_u32; + uint32_t ndims_u32; + reader.read((char *)&npts_u32, sizeof(uint32_t)); + reader.read((char *)&ndims_u32, sizeof(uint32_t)); + size_t npts = npts_u32; + size_t ndims = ndims_u32; + std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl; - size_t blk_size = 131072; - size_t nblks = ROUND_UP(npts, blk_size) / blk_size; + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; - std::ofstream writer(argv[3]); - char *read_buf = new char[blk_size * ndims * 4]; - for (size_t i = 0; i < nblks; i++) { - size_t cblk_size = std::min(npts - i * blk_size, blk_size); - if (type_string == std::string("float")) - block_convert(writer, reader, (float *)read_buf, cblk_size, ndims); - else if (type_string == std::string("int8")) - block_convert(writer, reader, (int8_t *)read_buf, cblk_size, - ndims); - else if (type_string == std::string("uint8")) - block_convert(writer, reader, (uint8_t *)read_buf, cblk_size, - ndims); - std::cout << "Block #" << i << " written" << std::endl; - } + std::ofstream writer(argv[3]); + char *read_buf = new char[blk_size * ndims * 4]; + for (size_t i = 0; i < nblks; i++) + { + size_t cblk_size = std::min(npts - i * blk_size, blk_size); + if (type_string == std::string("float")) + block_convert(writer, reader, (float *)read_buf, cblk_size, ndims); + else if (type_string == std::string("int8")) + block_convert(writer, reader, (int8_t *)read_buf, cblk_size, ndims); + else if (type_string == std::string("uint8")) + block_convert(writer, reader, (uint8_t *)read_buf, cblk_size, ndims); + std::cout << "Block #" << i << " written" << std::endl; + } - delete[] read_buf; + delete[] read_buf; - writer.close(); - reader.close(); + writer.close(); + reader.close(); } diff --git a/apps/utils/calculate_recall.cpp b/apps/utils/calculate_recall.cpp index d329ebba4..3946bfdf2 100644 --- a/apps/utils/calculate_recall.cpp +++ b/apps/utils/calculate_recall.cpp @@ -12,43 +12,44 @@ #include "disk_utils.h" #include "utils.h" -int main(int argc, char **argv) { - if (argc != 4) { - std::cout << argv[0] << " " - << std::endl; - return -1; - } - uint32_t *gold_std = NULL; - float *gs_dist = nullptr; - uint32_t *our_results = NULL; - float *or_dist = nullptr; - size_t points_num, points_num_gs, points_num_or; - size_t dim_gs; - size_t dim_or; - diskann::load_truthset(argv[1], gold_std, gs_dist, points_num_gs, dim_gs); - diskann::load_truthset(argv[2], our_results, or_dist, points_num_or, dim_or); +int main(int argc, char **argv) +{ + if (argc != 4) + { + std::cout << argv[0] << " " << std::endl; + return -1; + } + uint32_t *gold_std = NULL; + float *gs_dist = nullptr; + uint32_t *our_results = NULL; + float *or_dist = nullptr; + size_t points_num, points_num_gs, points_num_or; + size_t dim_gs; + size_t dim_or; + diskann::load_truthset(argv[1], gold_std, gs_dist, points_num_gs, dim_gs); + diskann::load_truthset(argv[2], our_results, or_dist, points_num_or, dim_or); - if (points_num_gs != points_num_or) { - std::cout << "Error. 
Number of queries mismatch in ground truth and "
-                 "our results"
-              << std::endl;
-    return -1;
-  }
-  points_num = points_num_gs;
+    if (points_num_gs != points_num_or)
+    {
+        std::cout << "Error. Number of queries mismatch in ground truth and "
+                     "our results"
+                  << std::endl;
+        return -1;
+    }
+    points_num = points_num_gs;
 
-  uint32_t recall_at = std::atoi(argv[3]);
+    uint32_t recall_at = std::atoi(argv[3]);
 
-  if ((dim_or < recall_at) || (recall_at > dim_gs)) {
-    std::cout << "ground truth has size " << dim_gs << "; our set has "
-              << dim_or << " points. Asking for recall " << recall_at
-              << std::endl;
-    return -1;
-  }
-  std::cout << "Calculating recall@" << recall_at << std::endl;
-  double recall_val = diskann::calculate_recall(
-      (uint32_t)points_num, gold_std, gs_dist, (uint32_t)dim_gs, our_results,
-      (uint32_t)dim_or, (uint32_t)recall_at);
+    if ((dim_or < recall_at) || (recall_at > dim_gs))
+    {
+        std::cout << "ground truth has size " << dim_gs << "; our set has " << dim_or << " points. Asking for recall "
+                  << recall_at << std::endl;
+        return -1;
+    }
+    std::cout << "Calculating recall@" << recall_at << std::endl;
+    double recall_val = diskann::calculate_recall((uint32_t)points_num, gold_std, gs_dist, (uint32_t)dim_gs,
+                                                  our_results, (uint32_t)dim_or, (uint32_t)recall_at);
 
-  // double avg_recall = (recall*1.0)/(points_num*1.0);
-  std::cout << "Avg. recall@" << recall_at << " is " << recall_val << "\n";
+    // double avg_recall = (recall*1.0)/(points_num*1.0);
+    std::cout << "Avg. recall@" << recall_at << " is " << recall_val << "\n";
 }
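The next diff reformats compute_groundtruth.cpp, whose exact_knn/distsq_to_points routines compute all query-to-point distances with BLAS by expanding ||p - q||^2 = ||p||^2 + ||q||^2 - 2<p, q>, so the cross term becomes a single matrix product (cblas_sgemm) and the squared norms are added with two rank-1 updates. As a reference only, not code from this patch, a minimal loop-based sketch of that identity with illustrative names is:

#include <cstddef>
#include <vector>

// Illustrative only: the reformatted code below does the -2<p,q> term with cblas_sgemm;
// plain loops are used here to show the same identity.
std::vector<float> all_pairs_l2sq(const std::vector<float> &points,  // npoints x dim, row major
                                  const std::vector<float> &queries, // nqueries x dim, row major
                                  std::size_t npoints, std::size_t nqueries, std::size_t dim)
{
    std::vector<float> pts_l2sq(npoints, 0.0f), qry_l2sq(nqueries, 0.0f);
    for (std::size_t i = 0; i < npoints; i++)
        for (std::size_t d = 0; d < dim; d++)
            pts_l2sq[i] += points[i * dim + d] * points[i * dim + d];
    for (std::size_t j = 0; j < nqueries; j++)
        for (std::size_t d = 0; d < dim; d++)
            qry_l2sq[j] += queries[j * dim + d] * queries[j * dim + d];

    std::vector<float> dist(nqueries * npoints);
    for (std::size_t j = 0; j < nqueries; j++)
        for (std::size_t i = 0; i < npoints; i++)
        {
            float ip = 0.0f; // <p_i, q_j>
            for (std::size_t d = 0; d < dim; d++)
                ip += points[i * dim + d] * queries[j * dim + d];
            // ||p_i - q_j||^2 = ||p_i||^2 + ||q_j||^2 - 2 <p_i, q_j>
            dist[j * npoints + i] = pts_l2sq[i] + qry_l2sq[j] - 2.0f * ip;
        }
    return dist;
}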
diff --git a/apps/utils/compute_groundtruth.cpp b/apps/utils/compute_groundtruth.cpp
index 2b644a8ec..b86f28289 100644
--- a/apps/utils/compute_groundtruth.cpp
+++ b/apps/utils/compute_groundtruth.cpp
@@ -40,539 +40,535 @@ typedef std::string path;
 
 namespace po = boost::program_options;
 
-template <class T> T div_round_up(const T numerator, const T denominator) {
-  return (numerator % denominator == 0) ? (numerator / denominator)
-                                        : 1 + (numerator / denominator);
+template <class T> T div_round_up(const T numerator, const T denominator)
+{
+    return (numerator % denominator == 0) ? (numerator / denominator) : 1 + (numerator / denominator);
 }
 
 using pairIF = std::pair<size_t, float>;
 
-struct cmpmaxstruct {
-  bool operator()(const pairIF &l, const pairIF &r) {
-    return l.second < r.second;
-  };
+struct cmpmaxstruct
+{
+    bool operator()(const pairIF &l, const pairIF &r)
+    {
+        return l.second < r.second;
+    };
 };
 
-using maxPQIFCS =
-    std::priority_queue<pairIF, std::vector<pairIF>, cmpmaxstruct>;
+using maxPQIFCS = std::priority_queue<pairIF, std::vector<pairIF>, cmpmaxstruct>;
 
-template <class T> T *aligned_malloc(const size_t n, const size_t alignment) {
+template <class T> T *aligned_malloc(const size_t n, const size_t alignment)
+{
 #ifdef _WINDOWS
-  return (T *)_aligned_malloc(sizeof(T) * n, alignment);
+    return (T *)_aligned_malloc(sizeof(T) * n, alignment);
 #else
-  return static_cast<T *>(aligned_alloc(alignment, sizeof(T) * n));
+    return static_cast<T *>(aligned_alloc(alignment, sizeof(T) * n));
 #endif
 }
 
-inline bool custom_dist(const std::pair<uint32_t, float> &a,
-                        const std::pair<uint32_t, float> &b) {
-  return a.second < b.second;
+inline bool custom_dist(const std::pair<uint32_t, float> &a, const std::pair<uint32_t, float> &b)
+{
+    return a.second < b.second;
 }
 
-void compute_l2sq(float *const points_l2sq, const float *const matrix,
-                  const int64_t num_points, const uint64_t dim) {
-  assert(points_l2sq != NULL);
+void compute_l2sq(float *const points_l2sq, const float *const matrix, const int64_t num_points, const uint64_t dim)
+{
+    assert(points_l2sq != NULL);
 #pragma omp parallel for schedule(static, 65536)
-  for (int64_t d = 0; d < num_points; ++d)
-    points_l2sq[d] =
-        cblas_sdot((int64_t)dim, matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1,
-                   matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1);
+    for (int64_t d = 0; d < num_points; ++d)
+        points_l2sq[d] = cblas_sdot((int64_t)dim, matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1,
+                                    matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1);
 }
 
-void distsq_to_points(
-    const size_t dim,
-    float *dist_matrix, // Col Major, cols are queries, rows are points
-    size_t npoints, const float *const points,
-    const float *const points_l2sq, // points in Col major
-    size_t nqueries, const float *const queries,
-    const float *const queries_l2sq, // queries in Col major
-    float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0
+void distsq_to_points(const size_t dim,
+                      float *dist_matrix, // Col Major, cols are queries, rows are points
+                      size_t npoints, const float *const points,
+                      const float *const points_l2sq, // points in Col major
+                      size_t nqueries, const float *const queries,
+                      const float *const queries_l2sq, // queries in Col major
+                      float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0
 {
-  bool ones_vec_alloc = false;
-  if (ones_vec == NULL) {
-    ones_vec = new float[nqueries > npoints ? nqueries : npoints];
-    std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0);
-    ones_vec_alloc = true;
-  }
-  cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim,
-              (float)-2.0, points, dim, queries, dim, (float)0.0, dist_matrix,
-              npoints);
-  cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1,
-              (float)1.0, points_l2sq, npoints, ones_vec, nqueries, (float)1.0,
-              dist_matrix, npoints);
-  cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1,
-              (float)1.0, ones_vec, npoints, queries_l2sq, nqueries, (float)1.0,
-              dist_matrix, npoints);
-  if (ones_vec_alloc)
-    delete[] ones_vec;
+    bool ones_vec_alloc = false;
+    if (ones_vec == NULL)
+    {
+        ones_vec = new float[nqueries > npoints ? nqueries : npoints];
+        std::fill_n(ones_vec, nqueries > npoints ? 
nqueries : npoints, (float)1.0); + ones_vec_alloc = true; + } + cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-2.0, points, dim, queries, dim, + (float)0.0, dist_matrix, npoints); + cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, points_l2sq, npoints, + ones_vec, nqueries, (float)1.0, dist_matrix, npoints); + cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, ones_vec, npoints, + queries_l2sq, nqueries, (float)1.0, dist_matrix, npoints); + if (ones_vec_alloc) + delete[] ones_vec; } -void inner_prod_to_points( - const size_t dim, - float *dist_matrix, // Col Major, cols are queries, rows are points - size_t npoints, const float *const points, size_t nqueries, - const float *const queries, - float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0 +void inner_prod_to_points(const size_t dim, + float *dist_matrix, // Col Major, cols are queries, rows are points + size_t npoints, const float *const points, size_t nqueries, const float *const queries, + float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0 { - bool ones_vec_alloc = false; - if (ones_vec == NULL) { - ones_vec = new float[nqueries > npoints ? nqueries : npoints]; - std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0); - ones_vec_alloc = true; - } - cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, - (float)-1.0, points, dim, queries, dim, (float)0.0, dist_matrix, - npoints); - - if (ones_vec_alloc) - delete[] ones_vec; + bool ones_vec_alloc = false; + if (ones_vec == NULL) + { + ones_vec = new float[nqueries > npoints ? nqueries : npoints]; + std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0); + ones_vec_alloc = true; + } + cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-1.0, points, dim, queries, dim, + (float)0.0, dist_matrix, npoints); + + if (ones_vec_alloc) + delete[] ones_vec; } -void exact_knn( - const size_t dim, const size_t k, - size_t *const closest_points, // k * num_queries preallocated, col - // major, queries columns - float *const dist_closest_points, // k * num_queries - // preallocated, Dist to - // corresponding closes_points - size_t npoints, - float *points_in, // points in Col major - size_t nqueries, float *queries_in, - diskann::Metric metric = diskann::Metric::L2) // queries in Col major +void exact_knn(const size_t dim, const size_t k, + size_t *const closest_points, // k * num_queries preallocated, col + // major, queries columns + float *const dist_closest_points, // k * num_queries + // preallocated, Dist to + // corresponding closes_points + size_t npoints, + float *points_in, // points in Col major + size_t nqueries, float *queries_in, + diskann::Metric metric = diskann::Metric::L2) // queries in Col major { - float *points_l2sq = new float[npoints]; - float *queries_l2sq = new float[nqueries]; - compute_l2sq(points_l2sq, points_in, npoints, dim); - compute_l2sq(queries_l2sq, queries_in, nqueries, dim); - - float *points = points_in; - float *queries = queries_in; - - if (metric == diskann::Metric::COSINE) { // we convert cosine distance as - // normalized L2 distnace - points = new float[npoints * dim]; - queries = new float[nqueries * dim]; + float *points_l2sq = new float[npoints]; + float *queries_l2sq = new float[nqueries]; + compute_l2sq(points_l2sq, points_in, npoints, dim); + compute_l2sq(queries_l2sq, queries_in, nqueries, dim); + + 
float *points = points_in; + float *queries = queries_in; + + if (metric == diskann::Metric::COSINE) + { // we convert cosine distance as + // normalized L2 distnace + points = new float[npoints * dim]; + queries = new float[nqueries * dim]; #pragma omp parallel for schedule(static, 4096) - for (int64_t i = 0; i < (int64_t)npoints; i++) { - float norm = std::sqrt(points_l2sq[i]); - if (norm == 0) { - norm = std::numeric_limits::epsilon(); - } - for (uint32_t j = 0; j < dim; j++) { - points[i * dim + j] = points_in[i * dim + j] / norm; - } - } + for (int64_t i = 0; i < (int64_t)npoints; i++) + { + float norm = std::sqrt(points_l2sq[i]); + if (norm == 0) + { + norm = std::numeric_limits::epsilon(); + } + for (uint32_t j = 0; j < dim; j++) + { + points[i * dim + j] = points_in[i * dim + j] / norm; + } + } #pragma omp parallel for schedule(static, 4096) - for (int64_t i = 0; i < (int64_t)nqueries; i++) { - float norm = std::sqrt(queries_l2sq[i]); - if (norm == 0) { - norm = std::numeric_limits::epsilon(); - } - for (uint32_t j = 0; j < dim; j++) { - queries[i * dim + j] = queries_in[i * dim + j] / norm; - } + for (int64_t i = 0; i < (int64_t)nqueries; i++) + { + float norm = std::sqrt(queries_l2sq[i]); + if (norm == 0) + { + norm = std::numeric_limits::epsilon(); + } + for (uint32_t j = 0; j < dim; j++) + { + queries[i * dim + j] = queries_in[i * dim + j] / norm; + } + } + // recalculate norms after normalizing, they should all be one. + compute_l2sq(points_l2sq, points, npoints, dim); + compute_l2sq(queries_l2sq, queries, nqueries, dim); } - // recalculate norms after normalizing, they should all be one. - compute_l2sq(points_l2sq, points, npoints, dim); - compute_l2sq(queries_l2sq, queries, nqueries, dim); - } - - std::cout << "Going to compute " << k << " NNs for " << nqueries - << " queries over " << npoints << " points in " << dim - << " dimensions using"; - if (metric == diskann::Metric::INNER_PRODUCT) - std::cout << " MIPS "; - else if (metric == diskann::Metric::COSINE) - std::cout << " Cosine "; - else - std::cout << " L2 "; - std::cout << "distance fn. " << std::endl; - - size_t q_batch_size = (1 << 9); - float *dist_matrix = new float[(size_t)q_batch_size * (size_t)npoints]; - - for (size_t b = 0; b < div_round_up(nqueries, q_batch_size); ++b) { - int64_t q_b = b * q_batch_size; - int64_t q_e = - ((b + 1) * q_batch_size > nqueries) ? nqueries : (b + 1) * q_batch_size; - - if (metric == diskann::Metric::L2 || metric == diskann::Metric::COSINE) { - distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, - q_e - q_b, queries + (ptrdiff_t)q_b * (ptrdiff_t)dim, - queries_l2sq + q_b); - } else { - inner_prod_to_points(dim, dist_matrix, npoints, points, q_e - q_b, - queries + (ptrdiff_t)q_b * (ptrdiff_t)dim); - } - std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")" - << std::endl; + + std::cout << "Going to compute " << k << " NNs for " << nqueries << " queries over " << npoints << " points in " + << dim << " dimensions using"; + if (metric == diskann::Metric::INNER_PRODUCT) + std::cout << " MIPS "; + else if (metric == diskann::Metric::COSINE) + std::cout << " Cosine "; + else + std::cout << " L2 "; + std::cout << "distance fn. " << std::endl; + + size_t q_batch_size = (1 << 9); + float *dist_matrix = new float[(size_t)q_batch_size * (size_t)npoints]; + + for (size_t b = 0; b < div_round_up(nqueries, q_batch_size); ++b) + { + int64_t q_b = b * q_batch_size; + int64_t q_e = ((b + 1) * q_batch_size > nqueries) ? 
nqueries : (b + 1) * q_batch_size; + + if (metric == diskann::Metric::L2 || metric == diskann::Metric::COSINE) + { + distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, q_e - q_b, + queries + (ptrdiff_t)q_b * (ptrdiff_t)dim, queries_l2sq + q_b); + } + else + { + inner_prod_to_points(dim, dist_matrix, npoints, points, q_e - q_b, + queries + (ptrdiff_t)q_b * (ptrdiff_t)dim); + } + std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")" << std::endl; #pragma omp parallel for schedule(dynamic, 16) - for (long long q = q_b; q < q_e; q++) { - maxPQIFCS point_dist; - for (size_t p = 0; p < k; p++) - point_dist.emplace(p, - dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * - (ptrdiff_t)npoints]); - for (size_t p = k; p < npoints; p++) { - if (point_dist.top().second > - dist_matrix[(ptrdiff_t)p + - (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]) - point_dist.emplace( - p, dist_matrix[(ptrdiff_t)p + - (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); - if (point_dist.size() > k) - point_dist.pop(); - } - for (ptrdiff_t l = 0; l < (ptrdiff_t)k; ++l) { - closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = - point_dist.top().first; - dist_closest_points[(ptrdiff_t)(k - 1 - l) + - (ptrdiff_t)q * (ptrdiff_t)k] = - point_dist.top().second; - point_dist.pop(); - } - assert(std::is_sorted(dist_closest_points + (ptrdiff_t)q * (ptrdiff_t)k, - dist_closest_points + - (ptrdiff_t)(q + 1) * (ptrdiff_t)k)); + for (long long q = q_b; q < q_e; q++) + { + maxPQIFCS point_dist; + for (size_t p = 0; p < k; p++) + point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); + for (size_t p = k; p < npoints; p++) + { + if (point_dist.top().second > dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]) + point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); + if (point_dist.size() > k) + point_dist.pop(); + } + for (ptrdiff_t l = 0; l < (ptrdiff_t)k; ++l) + { + closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().first; + dist_closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().second; + point_dist.pop(); + } + assert(std::is_sorted(dist_closest_points + (ptrdiff_t)q * (ptrdiff_t)k, + dist_closest_points + (ptrdiff_t)(q + 1) * (ptrdiff_t)k)); + } + std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e << ")" << std::endl; } - std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e - << ")" << std::endl; - } - delete[] dist_matrix; + delete[] dist_matrix; - delete[] points_l2sq; - delete[] queries_l2sq; + delete[] points_l2sq; + delete[] queries_l2sq; - if (metric == diskann::Metric::COSINE) { - delete[] points; - delete[] queries; - } + if (metric == diskann::Metric::COSINE) + { + delete[] points; + delete[] queries; + } } -template inline int get_num_parts(const char *filename) { - std::ifstream reader; - reader.exceptions(std::ios::failbit | std::ios::badbit); - reader.open(filename, std::ios::binary); - std::cout << "Reading bin file " << filename << " ...\n"; - int npts_i32, ndims_i32; - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&ndims_i32, sizeof(int)); - std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl; - reader.close(); - uint32_t num_parts = (npts_i32 % PARTSIZE) == 0 - ? 
npts_i32 / PARTSIZE - : (uint32_t)std::floor(npts_i32 / PARTSIZE) + 1; - std::cout << "Number of parts: " << num_parts << std::endl; - return num_parts; +template inline int get_num_parts(const char *filename) +{ + std::ifstream reader; + reader.exceptions(std::ios::failbit | std::ios::badbit); + reader.open(filename, std::ios::binary); + std::cout << "Reading bin file " << filename << " ...\n"; + int npts_i32, ndims_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&ndims_i32, sizeof(int)); + std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl; + reader.close(); + uint32_t num_parts = + (npts_i32 % PARTSIZE) == 0 ? npts_i32 / PARTSIZE : (uint32_t)std::floor(npts_i32 / PARTSIZE) + 1; + std::cout << "Number of parts: " << num_parts << std::endl; + return num_parts; } template -inline void load_bin_as_float(const char *filename, float *&data, size_t &npts, - size_t &ndims, int part_num) { - std::ifstream reader; - reader.exceptions(std::ios::failbit | std::ios::badbit); - reader.open(filename, std::ios::binary); - std::cout << "Reading bin file " << filename << " ...\n"; - int npts_i32, ndims_i32; - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&ndims_i32, sizeof(int)); - uint64_t start_id = part_num * PARTSIZE; - uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32); - npts = end_id - start_id; - ndims = (uint64_t)ndims_i32; - std::cout << "#pts in part = " << npts << ", #dims = " << ndims - << ", size = " << npts * ndims * sizeof(T) << "B" << std::endl; - - reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), - std::ios::beg); - T *data_T = new T[npts * ndims]; - reader.read((char *)data_T, sizeof(T) * npts * ndims); - std::cout << "Finished reading part of the bin file." << std::endl; - reader.close(); - data = aligned_malloc(npts * ndims, ALIGNMENT); +inline void load_bin_as_float(const char *filename, float *&data, size_t &npts, size_t &ndims, int part_num) +{ + std::ifstream reader; + reader.exceptions(std::ios::failbit | std::ios::badbit); + reader.open(filename, std::ios::binary); + std::cout << "Reading bin file " << filename << " ...\n"; + int npts_i32, ndims_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&ndims_i32, sizeof(int)); + uint64_t start_id = part_num * PARTSIZE; + uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32); + npts = end_id - start_id; + ndims = (uint64_t)ndims_i32; + std::cout << "#pts in part = " << npts << ", #dims = " << ndims << ", size = " << npts * ndims * sizeof(T) << "B" + << std::endl; + + reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg); + T *data_T = new T[npts * ndims]; + reader.read((char *)data_T, sizeof(T) * npts * ndims); + std::cout << "Finished reading part of the bin file." << std::endl; + reader.close(); + data = aligned_malloc(npts * ndims, ALIGNMENT); #pragma omp parallel for schedule(dynamic, 32768) - for (int64_t i = 0; i < (int64_t)npts; i++) { - for (int64_t j = 0; j < (int64_t)ndims; j++) { - float cur_val_float = (float)data_T[i * ndims + j]; - std::memcpy((char *)(data + i * ndims + j), (char *)&cur_val_float, - sizeof(float)); + for (int64_t i = 0; i < (int64_t)npts; i++) + { + for (int64_t j = 0; j < (int64_t)ndims; j++) + { + float cur_val_float = (float)data_T[i * ndims + j]; + std::memcpy((char *)(data + i * ndims + j), (char *)&cur_val_float, sizeof(float)); + } } - } - delete[] data_T; - std::cout << "Finished converting part data to float." 
<< std::endl;
+    delete[] data_T;
+    std::cout << "Finished converting part data to float." << std::endl;
 }
 
-template <typename T>
-inline void save_bin(const std::string filename, T *data, size_t npts,
-                     size_t ndims) {
-  std::ofstream writer;
-  writer.exceptions(std::ios::failbit | std::ios::badbit);
-  writer.open(filename, std::ios::binary | std::ios::out);
-  std::cout << "Writing bin: " << filename << "\n";
-  int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
-  writer.write((char *)&npts_i32, sizeof(int));
-  writer.write((char *)&ndims_i32, sizeof(int));
-  std::cout << "bin: #pts = " << npts << ", #dims = " << ndims
-            << ", size = " << npts * ndims * sizeof(T) + 2 * sizeof(int) << "B"
-            << std::endl;
-
-  writer.write((char *)data, npts * ndims * sizeof(T));
-  writer.close();
-  std::cout << "Finished writing bin" << std::endl;
+template <typename T> inline void save_bin(const std::string filename, T *data, size_t npts, size_t ndims)
+{
+    std::ofstream writer;
+    writer.exceptions(std::ios::failbit | std::ios::badbit);
+    writer.open(filename, std::ios::binary | std::ios::out);
+    std::cout << "Writing bin: " << filename << "\n";
+    int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
+    writer.write((char *)&npts_i32, sizeof(int));
+    writer.write((char *)&ndims_i32, sizeof(int));
+    std::cout << "bin: #pts = " << npts << ", #dims = " << ndims
+              << ", size = " << npts * ndims * sizeof(T) + 2 * sizeof(int) << "B" << std::endl;
+
+    writer.write((char *)data, npts * ndims * sizeof(T));
+    writer.close();
+    std::cout << "Finished writing bin" << std::endl;
 }
 
-inline void save_groundtruth_as_one_file(const std::string filename,
-                                         int32_t *data, float *distances,
-                                         size_t npts, size_t ndims) {
-  std::ofstream writer(filename, std::ios::binary | std::ios::out);
-  int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
-  writer.write((char *)&npts_i32, sizeof(int));
-  writer.write((char *)&ndims_i32, sizeof(int));
-  std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, "
-               "npts*dim dist-matrix) with npts = "
-            << npts << ", dim = " << ndims << ", size = "
-            << 2 * npts * ndims * sizeof(uint32_t) + 2 * sizeof(int) << "B"
-            << std::endl;
-
-  writer.write((char *)data, npts * ndims * sizeof(uint32_t));
-  writer.write((char *)distances, npts * ndims * sizeof(float));
-  writer.close();
-  std::cout << "Finished writing truthset" << std::endl;
+inline void save_groundtruth_as_one_file(const std::string filename, int32_t *data, float *distances, size_t npts,
+                                         size_t ndims)
+{
+    std::ofstream writer(filename, std::ios::binary | std::ios::out);
+    int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
+    writer.write((char *)&npts_i32, sizeof(int));
+    writer.write((char *)&ndims_i32, sizeof(int));
+    std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, "
+                 "npts*dim dist-matrix) with npts = "
+              << npts << ", dim = " << ndims << ", size = " << 2 * npts * ndims * sizeof(uint32_t) + 2 * sizeof(int)
+              << "B" << std::endl;
+
+    writer.write((char *)data, npts * ndims * sizeof(uint32_t));
+    writer.write((char *)distances, npts * ndims * sizeof(float));
+    writer.close();
+    std::cout << "Finished writing truthset" << std::endl;
 }
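For reference, the truthset layout written above (and parsed again by load_truthset further down) is two int32 headers (npts, dim), then npts*dim uint32 ids, then npts*dim float distances. A minimal stand-alone reader along those lines, with illustrative names and not taken from this patch, could look like:

#include <cstdint>
#include <fstream>
#include <string>
#include <vector>

// Reads a DiskANN-style truthset: [int32 npts][int32 dim][npts*dim uint32 ids][npts*dim float dists].
// This sketch assumes both the id and distance sections are present.
bool read_truthset(const std::string &path, std::vector<uint32_t> &ids, std::vector<float> &dists,
                   int32_t &npts, int32_t &dim)
{
    std::ifstream in(path, std::ios::binary);
    if (!in)
        return false;
    in.read((char *)&npts, sizeof(int32_t));
    in.read((char *)&dim, sizeof(int32_t));
    ids.resize((size_t)npts * (size_t)dim);
    dists.resize((size_t)npts * (size_t)dim);
    in.read((char *)ids.data(), ids.size() * sizeof(uint32_t));
    in.read((char *)dists.data(), dists.size() * sizeof(float));
    return (bool)in;
}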
 template <typename T>
-std::vector<std::vector<std::pair<uint32_t, float>>>
-processUnfilteredParts(const std::string &base_file, size_t &nqueries,
-                       size_t &npoints, size_t &dim, size_t &k,
-                       float *query_data, const diskann::Metric &metric,
-                       std::vector<uint32_t> &location_to_tag) {
-  float *base_data = nullptr;
-  int num_parts = get_num_parts<T>(base_file.c_str());
-  std::vector<std::vector<std::pair<uint32_t, float>>> res(nqueries);
-  for (int p = 0; p < num_parts; p++) {
-    size_t start_id = p * PARTSIZE;
-    load_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p);
-
-    size_t *closest_points_part = new size_t[nqueries * k];
-    float *dist_closest_points_part = new float[nqueries * k];
-
-    auto part_k = k < npoints ? k : npoints;
-    exact_knn(dim, part_k, closest_points_part, dist_closest_points_part,
-              npoints, base_data, nqueries, query_data, metric);
-
-    for (size_t i = 0; i < nqueries; i++) {
-      for (size_t j = 0; j < part_k; j++) {
-        if (!location_to_tag.empty())
-          if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0)
-            continue;
-
-        res[i].push_back(std::make_pair(
-            (uint32_t)(closest_points_part[i * part_k + j] + start_id),
-            dist_closest_points_part[i * part_k + j]));
-      }
+std::vector<std::vector<std::pair<uint32_t, float>>> processUnfilteredParts(const std::string &base_file,
+                                                                            size_t &nqueries, size_t &npoints,
+                                                                            size_t &dim, size_t &k, float *query_data,
+                                                                            const diskann::Metric &metric,
+                                                                            std::vector<uint32_t> &location_to_tag)
+{
+    float *base_data = nullptr;
+    int num_parts = get_num_parts<T>(base_file.c_str());
+    std::vector<std::vector<std::pair<uint32_t, float>>> res(nqueries);
+    for (int p = 0; p < num_parts; p++)
+    {
+        size_t start_id = p * PARTSIZE;
+        load_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p);
+
+        size_t *closest_points_part = new size_t[nqueries * k];
+        float *dist_closest_points_part = new float[nqueries * k];
+
+        auto part_k = k < npoints ? k : npoints;
+        exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, base_data, nqueries, query_data,
+                  metric);
+
+        for (size_t i = 0; i < nqueries; i++)
+        {
+            for (size_t j = 0; j < part_k; j++)
+            {
+                if (!location_to_tag.empty())
+                    if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0)
+                        continue;
+
+                res[i].push_back(std::make_pair((uint32_t)(closest_points_part[i * part_k + j] + start_id),
+                                                dist_closest_points_part[i * part_k + j]));
+            }
+        }
+
+        delete[] closest_points_part;
+        delete[] dist_closest_points_part;
+
+        diskann::aligned_free(base_data);
+    }
+    return res;
 };
 
 template <typename T>
-int aux_main(const std::string &base_file, const std::string &query_file,
-             const std::string &gt_file, size_t k,
-             const diskann::Metric &metric,
-             const std::string &tags_file = std::string("")) {
-  size_t npoints, nqueries, dim;
-
-  float *query_data;
-
-  load_bin_as_float<T>(query_file.c_str(), query_data, nqueries, dim, 0);
-  if (nqueries > PARTSIZE)
-    std::cerr << "WARNING: #Queries provided (" << nqueries
-              << ") is greater than " << PARTSIZE
-              << ". Computing GT only for the first " << PARTSIZE << " queries."
-              << std::endl;
-
-  // load tags
-  const bool tags_enabled = tags_file.empty() ? 
false : true; - std::vector location_to_tag = - diskann::loadTags(tags_file, base_file); - - int *closest_points = new int[nqueries * k]; - float *dist_closest_points = new float[nqueries * k]; - - std::vector>> results = - processUnfilteredParts(base_file, nqueries, npoints, dim, k, - query_data, metric, location_to_tag); - - for (size_t i = 0; i < nqueries; i++) { - std::vector> &cur_res = results[i]; - std::sort(cur_res.begin(), cur_res.end(), custom_dist); - size_t j = 0; - for (auto iter : cur_res) { - if (j == k) - break; - if (tags_enabled) { - std::uint32_t index_with_tag = location_to_tag[iter.first]; - closest_points[i * k + j] = (int32_t)index_with_tag; - } else { - closest_points[i * k + j] = (int32_t)iter.first; - } - - if (metric == diskann::Metric::INNER_PRODUCT) - dist_closest_points[i * k + j] = -iter.second; - else - dist_closest_points[i * k + j] = iter.second; - - ++j; +int aux_main(const std::string &base_file, const std::string &query_file, const std::string >_file, size_t k, + const diskann::Metric &metric, const std::string &tags_file = std::string("")) +{ + size_t npoints, nqueries, dim; + + float *query_data; + + load_bin_as_float(query_file.c_str(), query_data, nqueries, dim, 0); + if (nqueries > PARTSIZE) + std::cerr << "WARNING: #Queries provided (" << nqueries << ") is greater than " << PARTSIZE + << ". Computing GT only for the first " << PARTSIZE << " queries." << std::endl; + + // load tags + const bool tags_enabled = tags_file.empty() ? false : true; + std::vector location_to_tag = diskann::loadTags(tags_file, base_file); + + int *closest_points = new int[nqueries * k]; + float *dist_closest_points = new float[nqueries * k]; + + std::vector>> results = + processUnfilteredParts(base_file, nqueries, npoints, dim, k, query_data, metric, location_to_tag); + + for (size_t i = 0; i < nqueries; i++) + { + std::vector> &cur_res = results[i]; + std::sort(cur_res.begin(), cur_res.end(), custom_dist); + size_t j = 0; + for (auto iter : cur_res) + { + if (j == k) + break; + if (tags_enabled) + { + std::uint32_t index_with_tag = location_to_tag[iter.first]; + closest_points[i * k + j] = (int32_t)index_with_tag; + } + else + { + closest_points[i * k + j] = (int32_t)iter.first; + } + + if (metric == diskann::Metric::INNER_PRODUCT) + dist_closest_points[i * k + j] = -iter.second; + else + dist_closest_points[i * k + j] = iter.second; + + ++j; + } + if (j < k) + std::cout << "WARNING: found less than k GT entries for query " << i << std::endl; } - if (j < k) - std::cout << "WARNING: found less than k GT entries for query " << i - << std::endl; - } - - save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, - nqueries, k); - delete[] closest_points; - delete[] dist_closest_points; - diskann::aligned_free(query_data); - - return 0; + + save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, nqueries, k); + delete[] closest_points; + delete[] dist_closest_points; + diskann::aligned_free(query_data); + + return 0; } -void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, - size_t &npts, size_t &dim) { - size_t read_blk_size = 64 * 1024 * 1024; - cached_ifstream reader(bin_file, read_blk_size); - diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." 
- << std::endl; - size_t actual_file_size = reader.get_file_size(); - - int npts_i32, dim_i32; - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&dim_i32, sizeof(int)); - npts = (uint32_t)npts_i32; - dim = (uint32_t)dim_i32; - - diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " - << std::endl; - - int truthset_type = -1; // 1 means truthset has ids and distances, 2 means - // only ids, -1 is error - size_t expected_file_size_with_dists = - 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); - - if (actual_file_size == expected_file_size_with_dists) - truthset_type = 1; - - size_t expected_file_size_just_ids = - npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); - - if (actual_file_size == expected_file_size_just_ids) - truthset_type = 2; - - if (truthset_type == -1) { - std::stringstream stream; - stream << "Error. File size mismatch. File should have bin format, with " - "npts followed by ngt followed by npts*ngt ids and optionally " - "followed by npts*ngt distance values; actual size: " - << actual_file_size - << ", expected: " << expected_file_size_with_dists << " or " - << expected_file_size_just_ids; - diskann::cout << stream.str(); - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - - ids = new uint32_t[npts * dim]; - reader.read((char *)ids, npts * dim * sizeof(uint32_t)); - - if (truthset_type == 1) { - dists = new float[npts * dim]; - reader.read((char *)dists, npts * dim * sizeof(float)); - } +void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, size_t &npts, size_t &dim) +{ + size_t read_blk_size = 64 * 1024 * 1024; + cached_ifstream reader(bin_file, read_blk_size); + diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." << std::endl; + size_t actual_file_size = reader.get_file_size(); + + int npts_i32, dim_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&dim_i32, sizeof(int)); + npts = (uint32_t)npts_i32; + dim = (uint32_t)dim_i32; + + diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " << std::endl; + + int truthset_type = -1; // 1 means truthset has ids and distances, 2 means + // only ids, -1 is error + size_t expected_file_size_with_dists = 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); + + if (actual_file_size == expected_file_size_with_dists) + truthset_type = 1; + + size_t expected_file_size_just_ids = npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); + + if (actual_file_size == expected_file_size_just_ids) + truthset_type = 2; + + if (truthset_type == -1) + { + std::stringstream stream; + stream << "Error. File size mismatch. 
File should have bin format, with " + "npts followed by ngt followed by npts*ngt ids and optionally " + "followed by npts*ngt distance values; actual size: " + << actual_file_size << ", expected: " << expected_file_size_with_dists << " or " + << expected_file_size_just_ids; + diskann::cout << stream.str(); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + + ids = new uint32_t[npts * dim]; + reader.read((char *)ids, npts * dim * sizeof(uint32_t)); + + if (truthset_type == 1) + { + dists = new float[npts * dim]; + reader.read((char *)dists, npts * dim * sizeof(float)); + } } -int main(int argc, char **argv) { - std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file; - uint64_t K; - - try { - po::options_description desc{"Arguments"}; - - desc.add_options()("help,h", "Print information on arguments"); - - desc.add_options()("data_type", - po::value(&data_type)->required(), - "data type "); - desc.add_options()("dist_fn", po::value(&dist_fn)->required(), - "distance function "); - desc.add_options()("base_file", - po::value(&base_file)->required(), - "File containing the base vectors in binary format"); - desc.add_options()("query_file", - po::value(&query_file)->required(), - "File containing the query vectors in binary format"); - desc.add_options()("gt_file", po::value(>_file)->required(), - "File name for the writing ground truth in binary " - "format, please don' append .bin at end if " - "no filter_label or filter_label_file is provided it " - "will save the file with '.bin' at end." - "else it will save the file as filename_label.bin"); - desc.add_options()("K", po::value(&K)->required(), - "Number of ground truth nearest neighbors to compute"); - desc.add_options()( - "tags_file", - po::value(&tags_file)->default_value(std::string()), - "File containing the tags in binary format"); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) { - std::cout << desc; - return 0; +int main(int argc, char **argv) +{ + std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file; + uint64_t K; + + try + { + po::options_description desc{"Arguments"}; + + desc.add_options()("help,h", "Print information on arguments"); + + desc.add_options()("data_type", po::value(&data_type)->required(), "data type "); + desc.add_options()("dist_fn", po::value(&dist_fn)->required(), + "distance function "); + desc.add_options()("base_file", po::value(&base_file)->required(), + "File containing the base vectors in binary format"); + desc.add_options()("query_file", po::value(&query_file)->required(), + "File containing the query vectors in binary format"); + desc.add_options()("gt_file", po::value(>_file)->required(), + "File name for the writing ground truth in binary " + "format, please don' append .bin at end if " + "no filter_label or filter_label_file is provided it " + "will save the file with '.bin' at end." 
+ "else it will save the file as filename_label.bin"); + desc.add_options()("K", po::value(&K)->required(), + "Number of ground truth nearest neighbors to compute"); + desc.add_options()("tags_file", po::value(&tags_file)->default_value(std::string()), + "File containing the tags in binary format"); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) + { + std::cout << desc; + return 0; + } + po::notify(vm); + } + catch (const std::exception &ex) + { + std::cerr << ex.what() << '\n'; + return -1; + } + + if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8")) + { + std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl; + return -1; + } + + diskann::Metric metric; + if (dist_fn == std::string("l2")) + { + metric = diskann::Metric::L2; + } + else if (dist_fn == std::string("mips")) + { + metric = diskann::Metric::INNER_PRODUCT; + } + else if (dist_fn == std::string("cosine")) + { + metric = diskann::Metric::COSINE; + } + else + { + std::cerr << "Unsupported distance function. Use l2/mips/cosine." << std::endl; + return -1; + } + + try + { + if (data_type == std::string("float")) + aux_main(base_file, query_file, gt_file, K, metric, tags_file); + if (data_type == std::string("int8")) + aux_main(base_file, query_file, gt_file, K, metric, tags_file); + if (data_type == std::string("uint8")) + aux_main(base_file, query_file, gt_file, K, metric, tags_file); + } + catch (const std::exception &e) + { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Compute GT failed." << std::endl; + return -1; } - po::notify(vm); - } catch (const std::exception &ex) { - std::cerr << ex.what() << '\n'; - return -1; - } - - if (data_type != std::string("float") && data_type != std::string("int8") && - data_type != std::string("uint8")) { - std::cout << "Unsupported type. float, int8 and uint8 types are supported." - << std::endl; - return -1; - } - - diskann::Metric metric; - if (dist_fn == std::string("l2")) { - metric = diskann::Metric::L2; - } else if (dist_fn == std::string("mips")) { - metric = diskann::Metric::INNER_PRODUCT; - } else if (dist_fn == std::string("cosine")) { - metric = diskann::Metric::COSINE; - } else { - std::cerr << "Unsupported distance function. Use l2/mips/cosine." - << std::endl; - return -1; - } - - try { - if (data_type == std::string("float")) - aux_main(base_file, query_file, gt_file, K, metric, tags_file); - if (data_type == std::string("int8")) - aux_main(base_file, query_file, gt_file, K, metric, tags_file); - if (data_type == std::string("uint8")) - aux_main(base_file, query_file, gt_file, K, metric, tags_file); - } catch (const std::exception &e) { - std::cout << std::string(e.what()) << std::endl; - diskann::cerr << "Compute GT failed." << std::endl; - return -1; - } } diff --git a/apps/utils/compute_groundtruth_for_filters.cpp b/apps/utils/compute_groundtruth_for_filters.cpp index c6cfe476b..e90da2444 100644 --- a/apps/utils/compute_groundtruth_for_filters.cpp +++ b/apps/utils/compute_groundtruth_for_filters.cpp @@ -41,876 +41,879 @@ typedef std::string path; namespace po = boost::program_options; -template T div_round_up(const T numerator, const T denominator) { - return (numerator % denominator == 0) ? (numerator / denominator) - : 1 + (numerator / denominator); +template T div_round_up(const T numerator, const T denominator) +{ + return (numerator % denominator == 0) ? 
(numerator / denominator) : 1 + (numerator / denominator); } using pairIF = std::pair; -struct cmpmaxstruct { - bool operator()(const pairIF &l, const pairIF &r) { - return l.second < r.second; - }; +struct cmpmaxstruct +{ + bool operator()(const pairIF &l, const pairIF &r) + { + return l.second < r.second; + }; }; -using maxPQIFCS = - std::priority_queue, cmpmaxstruct>; +using maxPQIFCS = std::priority_queue, cmpmaxstruct>; -template T *aligned_malloc(const size_t n, const size_t alignment) { +template T *aligned_malloc(const size_t n, const size_t alignment) +{ #ifdef _WINDOWS - return (T *)_aligned_malloc(sizeof(T) * n, alignment); + return (T *)_aligned_malloc(sizeof(T) * n, alignment); #else - return static_cast(aligned_alloc(alignment, sizeof(T) * n)); + return static_cast(aligned_alloc(alignment, sizeof(T) * n)); #endif } -inline bool custom_dist(const std::pair &a, - const std::pair &b) { - return a.second < b.second; +inline bool custom_dist(const std::pair &a, const std::pair &b) +{ + return a.second < b.second; } -void compute_l2sq(float *const points_l2sq, const float *const matrix, - const int64_t num_points, const uint64_t dim) { - assert(points_l2sq != NULL); +void compute_l2sq(float *const points_l2sq, const float *const matrix, const int64_t num_points, const uint64_t dim) +{ + assert(points_l2sq != NULL); #pragma omp parallel for schedule(static, 65536) - for (int64_t d = 0; d < num_points; ++d) - points_l2sq[d] = - cblas_sdot((int64_t)dim, matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1, - matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1); + for (int64_t d = 0; d < num_points; ++d) + points_l2sq[d] = cblas_sdot((int64_t)dim, matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1, + matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1); } -void distsq_to_points( - const size_t dim, - float *dist_matrix, // Col Major, cols are queries, rows are points - size_t npoints, const float *const points, - const float *const points_l2sq, // points in Col major - size_t nqueries, const float *const queries, - const float *const queries_l2sq, // queries in Col major - float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0 +void distsq_to_points(const size_t dim, + float *dist_matrix, // Col Major, cols are queries, rows are points + size_t npoints, const float *const points, + const float *const points_l2sq, // points in Col major + size_t nqueries, const float *const queries, + const float *const queries_l2sq, // queries in Col major + float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0 { - bool ones_vec_alloc = false; - if (ones_vec == NULL) { - ones_vec = new float[nqueries > npoints ? nqueries : npoints]; - std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0); - ones_vec_alloc = true; - } - cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, - (float)-2.0, points, dim, queries, dim, (float)0.0, dist_matrix, - npoints); - cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, - (float)1.0, points_l2sq, npoints, ones_vec, nqueries, (float)1.0, - dist_matrix, npoints); - cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, - (float)1.0, ones_vec, npoints, queries_l2sq, nqueries, (float)1.0, - dist_matrix, npoints); - if (ones_vec_alloc) - delete[] ones_vec; + bool ones_vec_alloc = false; + if (ones_vec == NULL) + { + ones_vec = new float[nqueries > npoints ? nqueries : npoints]; + std::fill_n(ones_vec, nqueries > npoints ? 
nqueries : npoints, (float)1.0); + ones_vec_alloc = true; + } + cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-2.0, points, dim, queries, dim, + (float)0.0, dist_matrix, npoints); + cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, points_l2sq, npoints, + ones_vec, nqueries, (float)1.0, dist_matrix, npoints); + cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, ones_vec, npoints, + queries_l2sq, nqueries, (float)1.0, dist_matrix, npoints); + if (ones_vec_alloc) + delete[] ones_vec; } -void inner_prod_to_points( - const size_t dim, - float *dist_matrix, // Col Major, cols are queries, rows are points - size_t npoints, const float *const points, size_t nqueries, - const float *const queries, - float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0 +void inner_prod_to_points(const size_t dim, + float *dist_matrix, // Col Major, cols are queries, rows are points + size_t npoints, const float *const points, size_t nqueries, const float *const queries, + float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0 { - bool ones_vec_alloc = false; - if (ones_vec == NULL) { - ones_vec = new float[nqueries > npoints ? nqueries : npoints]; - std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0); - ones_vec_alloc = true; - } - cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, - (float)-1.0, points, dim, queries, dim, (float)0.0, dist_matrix, - npoints); - - if (ones_vec_alloc) - delete[] ones_vec; + bool ones_vec_alloc = false; + if (ones_vec == NULL) + { + ones_vec = new float[nqueries > npoints ? nqueries : npoints]; + std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0); + ones_vec_alloc = true; + } + cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-1.0, points, dim, queries, dim, + (float)0.0, dist_matrix, npoints); + + if (ones_vec_alloc) + delete[] ones_vec; } -void exact_knn( - const size_t dim, const size_t k, - size_t *const closest_points, // k * num_queries preallocated, col - // major, queries columns - float *const dist_closest_points, // k * num_queries - // preallocated, Dist to - // corresponding closes_points - size_t npoints, - float *points_in, // points in Col major - size_t nqueries, float *queries_in, - diskann::Metric metric = diskann::Metric::L2) // queries in Col major +void exact_knn(const size_t dim, const size_t k, + size_t *const closest_points, // k * num_queries preallocated, col + // major, queries columns + float *const dist_closest_points, // k * num_queries + // preallocated, Dist to + // corresponding closes_points + size_t npoints, + float *points_in, // points in Col major + size_t nqueries, float *queries_in, + diskann::Metric metric = diskann::Metric::L2) // queries in Col major { - float *points_l2sq = new float[npoints]; - float *queries_l2sq = new float[nqueries]; - compute_l2sq(points_l2sq, points_in, npoints, dim); - compute_l2sq(queries_l2sq, queries_in, nqueries, dim); - - float *points = points_in; - float *queries = queries_in; - - if (metric == diskann::Metric::COSINE) { // we convert cosine distance as - // normalized L2 distnace - points = new float[npoints * dim]; - queries = new float[nqueries * dim]; + float *points_l2sq = new float[npoints]; + float *queries_l2sq = new float[nqueries]; + compute_l2sq(points_l2sq, points_in, npoints, dim); + compute_l2sq(queries_l2sq, queries_in, nqueries, dim); + + 
float *points = points_in; + float *queries = queries_in; + + if (metric == diskann::Metric::COSINE) + { // we convert cosine distance as + // normalized L2 distnace + points = new float[npoints * dim]; + queries = new float[nqueries * dim]; #pragma omp parallel for schedule(static, 4096) - for (int64_t i = 0; i < (int64_t)npoints; i++) { - float norm = std::sqrt(points_l2sq[i]); - if (norm == 0) { - norm = std::numeric_limits::epsilon(); - } - for (uint32_t j = 0; j < dim; j++) { - points[i * dim + j] = points_in[i * dim + j] / norm; - } - } + for (int64_t i = 0; i < (int64_t)npoints; i++) + { + float norm = std::sqrt(points_l2sq[i]); + if (norm == 0) + { + norm = std::numeric_limits::epsilon(); + } + for (uint32_t j = 0; j < dim; j++) + { + points[i * dim + j] = points_in[i * dim + j] / norm; + } + } #pragma omp parallel for schedule(static, 4096) - for (int64_t i = 0; i < (int64_t)nqueries; i++) { - float norm = std::sqrt(queries_l2sq[i]); - if (norm == 0) { - norm = std::numeric_limits::epsilon(); - } - for (uint32_t j = 0; j < dim; j++) { - queries[i * dim + j] = queries_in[i * dim + j] / norm; - } - } - // recalculate norms after normalizing, they should all be one. - compute_l2sq(points_l2sq, points, npoints, dim); - compute_l2sq(queries_l2sq, queries, nqueries, dim); - } - - std::cout << "Going to compute " << k << " NNs for " << nqueries - << " queries over " << npoints << " points in " << dim - << " dimensions using"; - if (metric == diskann::Metric::INNER_PRODUCT) - std::cout << " MIPS "; - else if (metric == diskann::Metric::COSINE) - std::cout << " Cosine "; - else - std::cout << " L2 "; - std::cout << "distance fn. " << std::endl; - - size_t q_batch_size = (1 << 9); - float *dist_matrix = new float[(size_t)q_batch_size * (size_t)npoints]; - - for (uint64_t b = 0; b < div_round_up(nqueries, q_batch_size); ++b) { - int64_t q_b = b * q_batch_size; - int64_t q_e = - ((b + 1) * q_batch_size > nqueries) ? nqueries : (b + 1) * q_batch_size; - - if (metric == diskann::Metric::L2 || metric == diskann::Metric::COSINE) { - distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, - q_e - q_b, queries + (ptrdiff_t)q_b * (ptrdiff_t)dim, - queries_l2sq + q_b); - } else { - inner_prod_to_points(dim, dist_matrix, npoints, points, q_e - q_b, - queries + (ptrdiff_t)q_b * (ptrdiff_t)dim); + for (int64_t i = 0; i < (int64_t)nqueries; i++) + { + float norm = std::sqrt(queries_l2sq[i]); + if (norm == 0) + { + norm = std::numeric_limits::epsilon(); + } + for (uint32_t j = 0; j < dim; j++) + { + queries[i * dim + j] = queries_in[i * dim + j] / norm; + } + } + // recalculate norms after normalizing, they should all be one. + compute_l2sq(points_l2sq, points, npoints, dim); + compute_l2sq(queries_l2sq, queries, nqueries, dim); } - std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")" - << std::endl; + + std::cout << "Going to compute " << k << " NNs for " << nqueries << " queries over " << npoints << " points in " + << dim << " dimensions using"; + if (metric == diskann::Metric::INNER_PRODUCT) + std::cout << " MIPS "; + else if (metric == diskann::Metric::COSINE) + std::cout << " Cosine "; + else + std::cout << " L2 "; + std::cout << "distance fn. " << std::endl; + + size_t q_batch_size = (1 << 9); + float *dist_matrix = new float[(size_t)q_batch_size * (size_t)npoints]; + + for (uint64_t b = 0; b < div_round_up(nqueries, q_batch_size); ++b) + { + int64_t q_b = b * q_batch_size; + int64_t q_e = ((b + 1) * q_batch_size > nqueries) ? 
nqueries : (b + 1) * q_batch_size; + + if (metric == diskann::Metric::L2 || metric == diskann::Metric::COSINE) + { + distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, q_e - q_b, + queries + (ptrdiff_t)q_b * (ptrdiff_t)dim, queries_l2sq + q_b); + } + else + { + inner_prod_to_points(dim, dist_matrix, npoints, points, q_e - q_b, + queries + (ptrdiff_t)q_b * (ptrdiff_t)dim); + } + std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")" << std::endl; #pragma omp parallel for schedule(dynamic, 16) - for (long long q = q_b; q < q_e; q++) { - maxPQIFCS point_dist; - for (size_t p = 0; p < k; p++) - point_dist.emplace(p, - dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * - (ptrdiff_t)npoints]); - for (size_t p = k; p < npoints; p++) { - if (point_dist.top().second > - dist_matrix[(ptrdiff_t)p + - (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]) - point_dist.emplace( - p, dist_matrix[(ptrdiff_t)p + - (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); - if (point_dist.size() > k) - point_dist.pop(); - } - for (ptrdiff_t l = 0; l < (ptrdiff_t)k; ++l) { - closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = - point_dist.top().first; - dist_closest_points[(ptrdiff_t)(k - 1 - l) + - (ptrdiff_t)q * (ptrdiff_t)k] = - point_dist.top().second; - point_dist.pop(); - } - assert(std::is_sorted(dist_closest_points + (ptrdiff_t)q * (ptrdiff_t)k, - dist_closest_points + - (ptrdiff_t)(q + 1) * (ptrdiff_t)k)); + for (long long q = q_b; q < q_e; q++) + { + maxPQIFCS point_dist; + for (size_t p = 0; p < k; p++) + point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); + for (size_t p = k; p < npoints; p++) + { + if (point_dist.top().second > dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]) + point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); + if (point_dist.size() > k) + point_dist.pop(); + } + for (ptrdiff_t l = 0; l < (ptrdiff_t)k; ++l) + { + closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().first; + dist_closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().second; + point_dist.pop(); + } + assert(std::is_sorted(dist_closest_points + (ptrdiff_t)q * (ptrdiff_t)k, + dist_closest_points + (ptrdiff_t)(q + 1) * (ptrdiff_t)k)); + } + std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e << ")" << std::endl; } - std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e - << ")" << std::endl; - } - delete[] dist_matrix; + delete[] dist_matrix; - delete[] points_l2sq; - delete[] queries_l2sq; + delete[] points_l2sq; + delete[] queries_l2sq; - if (metric == diskann::Metric::COSINE) { - delete[] points; - delete[] queries; - } + if (metric == diskann::Metric::COSINE) + { + delete[] points; + delete[] queries; + } } -template inline int get_num_parts(const char *filename) { - std::ifstream reader; - reader.exceptions(std::ios::failbit | std::ios::badbit); - reader.open(filename, std::ios::binary); - std::cout << "Reading bin file " << filename << " ...\n"; - int npts_i32, ndims_i32; - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&ndims_i32, sizeof(int)); - std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl; - reader.close(); - int num_parts = (npts_i32 % PARTSIZE) == 0 - ? 
npts_i32 / PARTSIZE - : (uint32_t)std::floor(npts_i32 / PARTSIZE) + 1; - std::cout << "Number of parts: " << num_parts << std::endl; - return num_parts; +template inline int get_num_parts(const char *filename) +{ + std::ifstream reader; + reader.exceptions(std::ios::failbit | std::ios::badbit); + reader.open(filename, std::ios::binary); + std::cout << "Reading bin file " << filename << " ...\n"; + int npts_i32, ndims_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&ndims_i32, sizeof(int)); + std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl; + reader.close(); + int num_parts = (npts_i32 % PARTSIZE) == 0 ? npts_i32 / PARTSIZE : (uint32_t)std::floor(npts_i32 / PARTSIZE) + 1; + std::cout << "Number of parts: " << num_parts << std::endl; + return num_parts; } template -inline void load_bin_as_float(const char *filename, float *&data, - size_t &npts_u64, size_t &ndims_u64, - int part_num) { - std::ifstream reader; - reader.exceptions(std::ios::failbit | std::ios::badbit); - reader.open(filename, std::ios::binary); - std::cout << "Reading bin file " << filename << " ...\n"; - int npts_i32, ndims_i32; - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&ndims_i32, sizeof(int)); - uint64_t start_id = part_num * PARTSIZE; - uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32); - npts_u64 = end_id - start_id; - ndims_u64 = (uint64_t)ndims_i32; - std::cout << "#pts in part = " << npts_u64 << ", #dims = " << ndims_u64 - << ", size = " << npts_u64 * ndims_u64 * sizeof(T) << "B" - << std::endl; - - reader.seekg(start_id * ndims_u64 * sizeof(T) + 2 * sizeof(uint32_t), - std::ios::beg); - T *data_T = new T[npts_u64 * ndims_u64]; - reader.read((char *)data_T, sizeof(T) * npts_u64 * ndims_u64); - std::cout << "Finished reading part of the bin file." << std::endl; - reader.close(); - data = aligned_malloc(npts_u64 * ndims_u64, ALIGNMENT); +inline void load_bin_as_float(const char *filename, float *&data, size_t &npts_u64, size_t &ndims_u64, int part_num) +{ + std::ifstream reader; + reader.exceptions(std::ios::failbit | std::ios::badbit); + reader.open(filename, std::ios::binary); + std::cout << "Reading bin file " << filename << " ...\n"; + int npts_i32, ndims_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&ndims_i32, sizeof(int)); + uint64_t start_id = part_num * PARTSIZE; + uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32); + npts_u64 = end_id - start_id; + ndims_u64 = (uint64_t)ndims_i32; + std::cout << "#pts in part = " << npts_u64 << ", #dims = " << ndims_u64 + << ", size = " << npts_u64 * ndims_u64 * sizeof(T) << "B" << std::endl; + + reader.seekg(start_id * ndims_u64 * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg); + T *data_T = new T[npts_u64 * ndims_u64]; + reader.read((char *)data_T, sizeof(T) * npts_u64 * ndims_u64); + std::cout << "Finished reading part of the bin file." 
<< std::endl; + reader.close(); + data = aligned_malloc(npts_u64 * ndims_u64, ALIGNMENT); #pragma omp parallel for schedule(dynamic, 32768) - for (int64_t i = 0; i < (int64_t)npts_u64; i++) { - for (int64_t j = 0; j < (int64_t)ndims_u64; j++) { - float cur_val_float = (float)data_T[i * ndims_u64 + j]; - std::memcpy((char *)(data + i * ndims_u64 + j), (char *)&cur_val_float, - sizeof(float)); + for (int64_t i = 0; i < (int64_t)npts_u64; i++) + { + for (int64_t j = 0; j < (int64_t)ndims_u64; j++) + { + float cur_val_float = (float)data_T[i * ndims_u64 + j]; + std::memcpy((char *)(data + i * ndims_u64 + j), (char *)&cur_val_float, sizeof(float)); + } } - } - delete[] data_T; - std::cout << "Finished converting part data to float." << std::endl; + delete[] data_T; + std::cout << "Finished converting part data to float." << std::endl; } template -inline std::vector load_filtered_bin_as_float( - const char *filename, float *&data, size_t &npts, size_t &ndims, - int part_num, const char *label_file, const std::string &filter_label, - const std::string &universal_label, size_t &npoints_filt, - std::vector> &pts_to_labels) { - std::ifstream reader(filename, std::ios::binary); - if (reader.fail()) { - throw diskann::ANNException(std::string("Failed to open file ") + filename, - -1); - } - - std::cout << "Reading bin file " << filename << " ...\n"; - int npts_i32, ndims_i32; - std::vector rev_map; - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&ndims_i32, sizeof(int)); - uint64_t start_id = part_num * PARTSIZE; - uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32); - npts = end_id - start_id; - ndims = (uint32_t)ndims_i32; - uint64_t nptsuint64_t = (uint64_t)npts; - uint64_t ndimsuint64_t = (uint64_t)ndims; - npoints_filt = 0; - std::cout << "#pts in part = " << npts << ", #dims = " << ndims - << ", size = " << nptsuint64_t * ndimsuint64_t * sizeof(T) << "B" - << std::endl; - std::cout << "start and end ids: " << start_id << ", " << end_id << std::endl; - reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), - std::ios::beg); - - T *data_T = new T[nptsuint64_t * ndimsuint64_t]; - reader.read((char *)data_T, sizeof(T) * nptsuint64_t * ndimsuint64_t); - std::cout << "Finished reading part of the bin file." << std::endl; - reader.close(); - - data = aligned_malloc(nptsuint64_t * ndimsuint64_t, ALIGNMENT); - - for (int64_t i = 0; i < (int64_t)nptsuint64_t; i++) { - if (std::find(pts_to_labels[start_id + i].begin(), - pts_to_labels[start_id + i].end(), - filter_label) != pts_to_labels[start_id + i].end() || - std::find(pts_to_labels[start_id + i].begin(), - pts_to_labels[start_id + i].end(), - universal_label) != pts_to_labels[start_id + i].end()) { - rev_map.push_back(start_id + i); - for (int64_t j = 0; j < (int64_t)ndimsuint64_t; j++) { - float cur_val_float = (float)data_T[i * ndimsuint64_t + j]; - std::memcpy((char *)(data + npoints_filt * ndimsuint64_t + j), - (char *)&cur_val_float, sizeof(float)); - } - npoints_filt++; +inline std::vector load_filtered_bin_as_float(const char *filename, float *&data, size_t &npts, size_t &ndims, + int part_num, const char *label_file, + const std::string &filter_label, + const std::string &universal_label, size_t &npoints_filt, + std::vector> &pts_to_labels) +{ + std::ifstream reader(filename, std::ios::binary); + if (reader.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + filename, -1); } - } - delete[] data_T; - std::cout << "Finished converting part data to float.. 
identified " - << npoints_filt << " points matching the filter." << std::endl; - return rev_map; + + std::cout << "Reading bin file " << filename << " ...\n"; + int npts_i32, ndims_i32; + std::vector rev_map; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&ndims_i32, sizeof(int)); + uint64_t start_id = part_num * PARTSIZE; + uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32); + npts = end_id - start_id; + ndims = (uint32_t)ndims_i32; + uint64_t nptsuint64_t = (uint64_t)npts; + uint64_t ndimsuint64_t = (uint64_t)ndims; + npoints_filt = 0; + std::cout << "#pts in part = " << npts << ", #dims = " << ndims + << ", size = " << nptsuint64_t * ndimsuint64_t * sizeof(T) << "B" << std::endl; + std::cout << "start and end ids: " << start_id << ", " << end_id << std::endl; + reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg); + + T *data_T = new T[nptsuint64_t * ndimsuint64_t]; + reader.read((char *)data_T, sizeof(T) * nptsuint64_t * ndimsuint64_t); + std::cout << "Finished reading part of the bin file." << std::endl; + reader.close(); + + data = aligned_malloc(nptsuint64_t * ndimsuint64_t, ALIGNMENT); + + for (int64_t i = 0; i < (int64_t)nptsuint64_t; i++) + { + if (std::find(pts_to_labels[start_id + i].begin(), pts_to_labels[start_id + i].end(), filter_label) != + pts_to_labels[start_id + i].end() || + std::find(pts_to_labels[start_id + i].begin(), pts_to_labels[start_id + i].end(), universal_label) != + pts_to_labels[start_id + i].end()) + { + rev_map.push_back(start_id + i); + for (int64_t j = 0; j < (int64_t)ndimsuint64_t; j++) + { + float cur_val_float = (float)data_T[i * ndimsuint64_t + j]; + std::memcpy((char *)(data + npoints_filt * ndimsuint64_t + j), (char *)&cur_val_float, sizeof(float)); + } + npoints_filt++; + } + } + delete[] data_T; + std::cout << "Finished converting part data to float.. identified " << npoints_filt + << " points matching the filter." 
<< std::endl; + return rev_map; } -template -inline void save_bin(const std::string filename, T *data, size_t npts, - size_t ndims) { - std::ofstream writer; - writer.exceptions(std::ios::failbit | std::ios::badbit); - writer.open(filename, std::ios::binary | std::ios::out); - std::cout << "Writing bin: " << filename << "\n"; - int npts_i32 = (int)npts, ndims_i32 = (int)ndims; - writer.write((char *)&npts_i32, sizeof(int)); - writer.write((char *)&ndims_i32, sizeof(int)); - std::cout << "bin: #pts = " << npts << ", #dims = " << ndims - << ", size = " << npts * ndims * sizeof(T) + 2 * sizeof(int) << "B" - << std::endl; - - writer.write((char *)data, npts * ndims * sizeof(T)); - writer.close(); - std::cout << "Finished writing bin" << std::endl; +template inline void save_bin(const std::string filename, T *data, size_t npts, size_t ndims) +{ + std::ofstream writer; + writer.exceptions(std::ios::failbit | std::ios::badbit); + writer.open(filename, std::ios::binary | std::ios::out); + std::cout << "Writing bin: " << filename << "\n"; + int npts_i32 = (int)npts, ndims_i32 = (int)ndims; + writer.write((char *)&npts_i32, sizeof(int)); + writer.write((char *)&ndims_i32, sizeof(int)); + std::cout << "bin: #pts = " << npts << ", #dims = " << ndims + << ", size = " << npts * ndims * sizeof(T) + 2 * sizeof(int) << "B" << std::endl; + + writer.write((char *)data, npts * ndims * sizeof(T)); + writer.close(); + std::cout << "Finished writing bin" << std::endl; } -inline void save_groundtruth_as_one_file(const std::string filename, - int32_t *data, float *distances, - size_t npts, size_t ndims) { - std::ofstream writer(filename, std::ios::binary | std::ios::out); - int npts_i32 = (int)npts, ndims_i32 = (int)ndims; - writer.write((char *)&npts_i32, sizeof(int)); - writer.write((char *)&ndims_i32, sizeof(int)); - std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, " - "npts*dim dist-matrix) with npts = " - << npts << ", dim = " << ndims << ", size = " - << 2 * npts * ndims * sizeof(uint32_t) + 2 * sizeof(int) << "B" - << std::endl; - - writer.write((char *)data, npts * ndims * sizeof(uint32_t)); - writer.write((char *)distances, npts * ndims * sizeof(float)); - writer.close(); - std::cout << "Finished writing truthset" << std::endl; +inline void save_groundtruth_as_one_file(const std::string filename, int32_t *data, float *distances, size_t npts, + size_t ndims) +{ + std::ofstream writer(filename, std::ios::binary | std::ios::out); + int npts_i32 = (int)npts, ndims_i32 = (int)ndims; + writer.write((char *)&npts_i32, sizeof(int)); + writer.write((char *)&ndims_i32, sizeof(int)); + std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, " + "npts*dim dist-matrix) with npts = " + << npts << ", dim = " << ndims << ", size = " << 2 * npts * ndims * sizeof(uint32_t) + 2 * sizeof(int) + << "B" << std::endl; + + writer.write((char *)data, npts * ndims * sizeof(uint32_t)); + writer.write((char *)distances, npts * ndims * sizeof(float)); + writer.close(); + std::cout << "Finished writing truthset" << std::endl; } -inline void parse_label_file_into_vec( - size_t &line_cnt, const std::string &map_file, - std::vector> &pts_to_labels) { - std::ifstream infile(map_file); - std::string line, token; - std::set labels; - infile.clear(); - infile.seekg(0, std::ios::beg); - while (std::getline(infile, line)) { - std::istringstream iss(line); - std::vector lbls(0); - - getline(iss, token, '\t'); - std::istringstream new_iss(token); - while (getline(new_iss, token, ',')) { - 
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - lbls.push_back(token); - labels.insert(token); +inline void parse_label_file_into_vec(size_t &line_cnt, const std::string &map_file, + std::vector> &pts_to_labels) +{ + std::ifstream infile(map_file); + std::string line, token; + std::set labels; + infile.clear(); + infile.seekg(0, std::ios::beg); + while (std::getline(infile, line)) + { + std::istringstream iss(line); + std::vector lbls(0); + + getline(iss, token, '\t'); + std::istringstream new_iss(token); + while (getline(new_iss, token, ',')) + { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + lbls.push_back(token); + labels.insert(token); + } + std::sort(lbls.begin(), lbls.end()); + pts_to_labels.push_back(lbls); } - std::sort(lbls.begin(), lbls.end()); - pts_to_labels.push_back(lbls); - } - std::cout << "Identified " << labels.size() - << " distinct label(s), and populated labels for " - << pts_to_labels.size() << " points" << std::endl; + std::cout << "Identified " << labels.size() << " distinct label(s), and populated labels for " + << pts_to_labels.size() << " points" << std::endl; } template -std::vector>> -processUnfilteredParts(const std::string &base_file, size_t &nqueries, - size_t &npoints, size_t &dim, size_t &k, - float *query_data, const diskann::Metric &metric, - std::vector &location_to_tag) { - float *base_data = nullptr; - int num_parts = get_num_parts(base_file.c_str()); - std::vector>> res(nqueries); - for (int p = 0; p < num_parts; p++) { - size_t start_id = p * PARTSIZE; - load_bin_as_float(base_file.c_str(), base_data, npoints, dim, p); - - size_t *closest_points_part = new size_t[nqueries * k]; - float *dist_closest_points_part = new float[nqueries * k]; - - auto part_k = k < npoints ? k : npoints; - exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, - npoints, base_data, nqueries, query_data, metric); - - for (size_t i = 0; i < nqueries; i++) { - for (uint64_t j = 0; j < part_k; j++) { - if (!location_to_tag.empty()) - if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0) - continue; - - res[i].push_back(std::make_pair( - (uint32_t)(closest_points_part[i * part_k + j] + start_id), - dist_closest_points_part[i * part_k + j])); - } - } +std::vector>> processUnfilteredParts(const std::string &base_file, + size_t &nqueries, size_t &npoints, + size_t &dim, size_t &k, float *query_data, + const diskann::Metric &metric, + std::vector &location_to_tag) +{ + float *base_data = nullptr; + int num_parts = get_num_parts(base_file.c_str()); + std::vector>> res(nqueries); + for (int p = 0; p < num_parts; p++) + { + size_t start_id = p * PARTSIZE; + load_bin_as_float(base_file.c_str(), base_data, npoints, dim, p); + + size_t *closest_points_part = new size_t[nqueries * k]; + float *dist_closest_points_part = new float[nqueries * k]; + + auto part_k = k < npoints ? 
k : npoints; + exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, base_data, nqueries, query_data, + metric); + + for (size_t i = 0; i < nqueries; i++) + { + for (uint64_t j = 0; j < part_k; j++) + { + if (!location_to_tag.empty()) + if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0) + continue; + + res[i].push_back(std::make_pair((uint32_t)(closest_points_part[i * part_k + j] + start_id), + dist_closest_points_part[i * part_k + j])); + } + } - delete[] closest_points_part; - delete[] dist_closest_points_part; + delete[] closest_points_part; + delete[] dist_closest_points_part; - diskann::aligned_free(base_data); - } - return res; + diskann::aligned_free(base_data); + } + return res; }; template std::vector>> processFilteredParts( - const std::string &base_file, const std::string &label_file, - const std::string &filter_label, const std::string &universal_label, - size_t &nqueries, size_t &npoints, size_t &dim, size_t &k, - float *query_data, const diskann::Metric &metric, - std::vector &location_to_tag) { - size_t npoints_filt = 0; - float *base_data = nullptr; - std::vector>> res(nqueries); - int num_parts = get_num_parts(base_file.c_str()); - - std::vector> pts_to_labels; - if (filter_label != "") - parse_label_file_into_vec(npoints, label_file, pts_to_labels); - - for (int p = 0; p < num_parts; p++) { - size_t start_id = p * PARTSIZE; - std::vector rev_map; - if (filter_label != "") - rev_map = load_filtered_bin_as_float( - base_file.c_str(), base_data, npoints, dim, p, label_file.c_str(), - filter_label, universal_label, npoints_filt, pts_to_labels); - size_t *closest_points_part = new size_t[nqueries * k]; - float *dist_closest_points_part = new float[nqueries * k]; - - auto part_k = k < npoints_filt ? k : npoints_filt; - if (npoints_filt > 0) { - exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, - npoints_filt, base_data, nqueries, query_data, metric); - } + const std::string &base_file, const std::string &label_file, const std::string &filter_label, + const std::string &universal_label, size_t &nqueries, size_t &npoints, size_t &dim, size_t &k, float *query_data, + const diskann::Metric &metric, std::vector &location_to_tag) +{ + size_t npoints_filt = 0; + float *base_data = nullptr; + std::vector>> res(nqueries); + int num_parts = get_num_parts(base_file.c_str()); - for (size_t i = 0; i < nqueries; i++) { - for (uint64_t j = 0; j < part_k; j++) { - if (!location_to_tag.empty()) - if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0) - continue; + std::vector> pts_to_labels; + if (filter_label != "") + parse_label_file_into_vec(npoints, label_file, pts_to_labels); + + for (int p = 0; p < num_parts; p++) + { + size_t start_id = p * PARTSIZE; + std::vector rev_map; + if (filter_label != "") + rev_map = load_filtered_bin_as_float(base_file.c_str(), base_data, npoints, dim, p, label_file.c_str(), + filter_label, universal_label, npoints_filt, pts_to_labels); + size_t *closest_points_part = new size_t[nqueries * k]; + float *dist_closest_points_part = new float[nqueries * k]; + + auto part_k = k < npoints_filt ? 
k : npoints_filt; + if (npoints_filt > 0) + { + exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints_filt, base_data, nqueries, + query_data, metric); + } - res[i].push_back(std::make_pair( - (uint32_t)(rev_map[closest_points_part[i * part_k + j]]), - dist_closest_points_part[i * part_k + j])); - } - } + for (size_t i = 0; i < nqueries; i++) + { + for (uint64_t j = 0; j < part_k; j++) + { + if (!location_to_tag.empty()) + if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0) + continue; + + res[i].push_back(std::make_pair((uint32_t)(rev_map[closest_points_part[i * part_k + j]]), + dist_closest_points_part[i * part_k + j])); + } + } - delete[] closest_points_part; - delete[] dist_closest_points_part; + delete[] closest_points_part; + delete[] dist_closest_points_part; - diskann::aligned_free(base_data); - } - return res; + diskann::aligned_free(base_data); + } + return res; }; template -int aux_main(const std::string &base_file, const std::string &label_file, - const std::string &query_file, const std::string >_file, - size_t k, const std::string &universal_label, - const diskann::Metric &metric, const std::string &filter_label, - const std::string &tags_file = std::string("")) { - size_t npoints, nqueries, dim; - - float *query_data = nullptr; - - load_bin_as_float(query_file.c_str(), query_data, nqueries, dim, 0); - if (nqueries > PARTSIZE) - std::cerr << "WARNING: #Queries provided (" << nqueries - << ") is greater than " << PARTSIZE - << ". Computing GT only for the first " << PARTSIZE << " queries." - << std::endl; - - // load tags - const bool tags_enabled = tags_file.empty() ? false : true; - std::vector location_to_tag = - diskann::loadTags(tags_file, base_file); - - int *closest_points = new int[nqueries * k]; - float *dist_closest_points = new float[nqueries * k]; - - std::vector>> results; - if (filter_label == "") { - results = processUnfilteredParts(base_file, nqueries, npoints, dim, k, - query_data, metric, location_to_tag); - } else { - results = processFilteredParts(base_file, label_file, filter_label, - universal_label, nqueries, npoints, dim, - k, query_data, metric, location_to_tag); - } - - for (size_t i = 0; i < nqueries; i++) { - std::vector> &cur_res = results[i]; - std::sort(cur_res.begin(), cur_res.end(), custom_dist); - size_t j = 0; - for (auto iter : cur_res) { - if (j == k) - break; - if (tags_enabled) { - std::uint32_t index_with_tag = location_to_tag[iter.first]; - closest_points[i * k + j] = (int32_t)index_with_tag; - } else { - closest_points[i * k + j] = (int32_t)iter.first; - } - - if (metric == diskann::Metric::INNER_PRODUCT) - dist_closest_points[i * k + j] = -iter.second; - else - dist_closest_points[i * k + j] = iter.second; - - ++j; +int aux_main(const std::string &base_file, const std::string &label_file, const std::string &query_file, + const std::string >_file, size_t k, const std::string &universal_label, const diskann::Metric &metric, + const std::string &filter_label, const std::string &tags_file = std::string("")) +{ + size_t npoints, nqueries, dim; + + float *query_data = nullptr; + + load_bin_as_float(query_file.c_str(), query_data, nqueries, dim, 0); + if (nqueries > PARTSIZE) + std::cerr << "WARNING: #Queries provided (" << nqueries << ") is greater than " << PARTSIZE + << ". Computing GT only for the first " << PARTSIZE << " queries." << std::endl; + + // load tags + const bool tags_enabled = tags_file.empty() ? 
false : true; + std::vector location_to_tag = diskann::loadTags(tags_file, base_file); + + int *closest_points = new int[nqueries * k]; + float *dist_closest_points = new float[nqueries * k]; + + std::vector>> results; + if (filter_label == "") + { + results = processUnfilteredParts(base_file, nqueries, npoints, dim, k, query_data, metric, location_to_tag); + } + else + { + results = processFilteredParts(base_file, label_file, filter_label, universal_label, nqueries, npoints, dim, + k, query_data, metric, location_to_tag); + } + + for (size_t i = 0; i < nqueries; i++) + { + std::vector> &cur_res = results[i]; + std::sort(cur_res.begin(), cur_res.end(), custom_dist); + size_t j = 0; + for (auto iter : cur_res) + { + if (j == k) + break; + if (tags_enabled) + { + std::uint32_t index_with_tag = location_to_tag[iter.first]; + closest_points[i * k + j] = (int32_t)index_with_tag; + } + else + { + closest_points[i * k + j] = (int32_t)iter.first; + } + + if (metric == diskann::Metric::INNER_PRODUCT) + dist_closest_points[i * k + j] = -iter.second; + else + dist_closest_points[i * k + j] = iter.second; + + ++j; + } + if (j < k) + std::cout << "WARNING: found less than k GT entries for query " << i << std::endl; } - if (j < k) - std::cout << "WARNING: found less than k GT entries for query " << i - << std::endl; - } - - save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, - nqueries, k); - delete[] closest_points; - delete[] dist_closest_points; - diskann::aligned_free(query_data); - - return 0; + + save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, nqueries, k); + delete[] closest_points; + delete[] dist_closest_points; + diskann::aligned_free(query_data); + + return 0; } -void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, - size_t &npts, size_t &dim) { - size_t read_blk_size = 64 * 1024 * 1024; - cached_ifstream reader(bin_file, read_blk_size); - diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." - << std::endl; - size_t actual_file_size = reader.get_file_size(); - - int npts_i32, dim_i32; - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&dim_i32, sizeof(int)); - npts = (uint32_t)npts_i32; - dim = (uint32_t)dim_i32; - - diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " - << std::endl; - - int truthset_type = -1; // 1 means truthset has ids and distances, 2 means - // only ids, -1 is error - size_t expected_file_size_with_dists = - 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); - - if (actual_file_size == expected_file_size_with_dists) - truthset_type = 1; - - size_t expected_file_size_just_ids = - npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); - - if (actual_file_size == expected_file_size_just_ids) - truthset_type = 2; - - if (truthset_type == -1) { - std::stringstream stream; - stream << "Error. File size mismatch. 
File should have bin format, with " - "npts followed by ngt followed by npts*ngt ids and optionally " - "followed by npts*ngt distance values; actual size: " - << actual_file_size - << ", expected: " << expected_file_size_with_dists << " or " - << expected_file_size_just_ids; - diskann::cout << stream.str(); - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - - ids = new uint32_t[npts * dim]; - reader.read((char *)ids, npts * dim * sizeof(uint32_t)); - - if (truthset_type == 1) { - dists = new float[npts * dim]; - reader.read((char *)dists, npts * dim * sizeof(float)); - } +void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, size_t &npts, size_t &dim) +{ + size_t read_blk_size = 64 * 1024 * 1024; + cached_ifstream reader(bin_file, read_blk_size); + diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." << std::endl; + size_t actual_file_size = reader.get_file_size(); + + int npts_i32, dim_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&dim_i32, sizeof(int)); + npts = (uint32_t)npts_i32; + dim = (uint32_t)dim_i32; + + diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " << std::endl; + + int truthset_type = -1; // 1 means truthset has ids and distances, 2 means + // only ids, -1 is error + size_t expected_file_size_with_dists = 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); + + if (actual_file_size == expected_file_size_with_dists) + truthset_type = 1; + + size_t expected_file_size_just_ids = npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); + + if (actual_file_size == expected_file_size_just_ids) + truthset_type = 2; + + if (truthset_type == -1) + { + std::stringstream stream; + stream << "Error. File size mismatch. 
File should have bin format, with " + "npts followed by ngt followed by npts*ngt ids and optionally " + "followed by npts*ngt distance values; actual size: " + << actual_file_size << ", expected: " << expected_file_size_with_dists << " or " + << expected_file_size_just_ids; + diskann::cout << stream.str(); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + + ids = new uint32_t[npts * dim]; + reader.read((char *)ids, npts * dim * sizeof(uint32_t)); + + if (truthset_type == 1) + { + dists = new float[npts * dim]; + reader.read((char *)dists, npts * dim * sizeof(float)); + } } -int main(int argc, char **argv) { - std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file, - label_file, filter_label, universal_label, filter_label_file; - uint64_t K; - - try { - po::options_description desc{"Arguments"}; - - desc.add_options()("help,h", "Print information on arguments"); - - desc.add_options()("data_type", - po::value(&data_type)->required(), - "data type "); - desc.add_options()("dist_fn", po::value(&dist_fn)->required(), - "distance function "); - desc.add_options()("base_file", - po::value(&base_file)->required(), - "File containing the base vectors in binary format"); - desc.add_options()("query_file", - po::value(&query_file)->required(), - "File containing the query vectors in binary format"); - desc.add_options()("label_file", - po::value(&label_file)->default_value(""), - "Input labels file in txt format if present"); - desc.add_options()("filter_label", - po::value(&filter_label)->default_value(""), - "Input filter label if doing filtered groundtruth"); - desc.add_options()( - "universal_label", - po::value(&universal_label)->default_value(""), - "Universal label, if using it, only in conjunction with label_file"); - desc.add_options()("gt_file", po::value(>_file)->required(), - "File name for the writing ground truth in binary " - "format, please don' append .bin at end if " - "no filter_label or filter_label_file is provided it " - "will save the file with '.bin' at end." 
- "else it will save the file as filename_label.bin"); - desc.add_options()("K", po::value(&K)->required(), - "Number of ground truth nearest neighbors to compute"); - desc.add_options()( - "tags_file", - po::value(&tags_file)->default_value(std::string()), - "File containing the tags in binary format"); - desc.add_options()("filter_label_file", - po::value(&filter_label_file) - ->default_value(std::string("")), - "Filter file for Queries for Filtered Search "); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) { - std::cout << desc; - return 0; +int main(int argc, char **argv) +{ + std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file, label_file, filter_label, + universal_label, filter_label_file; + uint64_t K; + + try + { + po::options_description desc{"Arguments"}; + + desc.add_options()("help,h", "Print information on arguments"); + + desc.add_options()("data_type", po::value(&data_type)->required(), "data type "); + desc.add_options()("dist_fn", po::value(&dist_fn)->required(), "distance function "); + desc.add_options()("base_file", po::value(&base_file)->required(), + "File containing the base vectors in binary format"); + desc.add_options()("query_file", po::value(&query_file)->required(), + "File containing the query vectors in binary format"); + desc.add_options()("label_file", po::value(&label_file)->default_value(""), + "Input labels file in txt format if present"); + desc.add_options()("filter_label", po::value(&filter_label)->default_value(""), + "Input filter label if doing filtered groundtruth"); + desc.add_options()("universal_label", po::value(&universal_label)->default_value(""), + "Universal label, if using it, only in conjunction with label_file"); + desc.add_options()("gt_file", po::value(>_file)->required(), + "File name for the writing ground truth in binary " + "format, please don' append .bin at end if " + "no filter_label or filter_label_file is provided it " + "will save the file with '.bin' at end." + "else it will save the file as filename_label.bin"); + desc.add_options()("K", po::value(&K)->required(), + "Number of ground truth nearest neighbors to compute"); + desc.add_options()("tags_file", po::value(&tags_file)->default_value(std::string()), + "File containing the tags in binary format"); + desc.add_options()("filter_label_file", + po::value(&filter_label_file)->default_value(std::string("")), + "Filter file for Queries for Filtered Search "); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) + { + std::cout << desc; + return 0; + } + po::notify(vm); } - po::notify(vm); - } catch (const std::exception &ex) { - std::cerr << ex.what() << '\n'; - return -1; - } - - if (data_type != std::string("float") && data_type != std::string("int8") && - data_type != std::string("uint8")) { - std::cout << "Unsupported type. float, int8 and uint8 types are supported." - << std::endl; - return -1; - } - - if (filter_label != "" && filter_label_file != "") { - std::cerr - << "Only one of filter_label and query_filters_file should be provided" - << std::endl; - return -1; - } - - diskann::Metric metric; - if (dist_fn == std::string("l2")) { - metric = diskann::Metric::L2; - } else if (dist_fn == std::string("mips")) { - metric = diskann::Metric::INNER_PRODUCT; - } else if (dist_fn == std::string("cosine")) { - metric = diskann::Metric::COSINE; - } else { - std::cerr << "Unsupported distance function. Use l2/mips/cosine." 
- << std::endl; - return -1; - } - - std::vector filter_labels; - if (filter_label != "") { - filter_labels.push_back(filter_label); - } else if (filter_label_file != "") { - filter_labels = read_file_to_vector_of_strings(filter_label_file, false); - } - - // only when there is no filter label or 1 filter label for all queries - if (filter_labels.size() == 1) { - try { - if (data_type == std::string("float")) - aux_main(base_file, label_file, query_file, gt_file, K, - universal_label, metric, filter_labels[0], tags_file); - if (data_type == std::string("int8")) - aux_main(base_file, label_file, query_file, gt_file, K, - universal_label, metric, filter_labels[0], tags_file); - if (data_type == std::string("uint8")) - aux_main(base_file, label_file, query_file, gt_file, K, - universal_label, metric, filter_labels[0], tags_file); - } catch (const std::exception &e) { - std::cout << std::string(e.what()) << std::endl; - diskann::cerr << "Compute GT failed." << std::endl; - return -1; + catch (const std::exception &ex) + { + std::cerr << ex.what() << '\n'; + return -1; } - } else { // Each query has its own filter label - // Split up data and query bins into label specific ones - tsl::robin_map labels_to_number_of_points; - tsl::robin_map labels_to_number_of_queries; - - label_set all_labels; - for (size_t i = 0; i < filter_labels.size(); i++) { - std::string label = filter_labels[i]; - all_labels.insert(label); - - if (labels_to_number_of_queries.find(label) == - labels_to_number_of_queries.end()) { - labels_to_number_of_queries[label] = 0; - } - labels_to_number_of_queries[label] += 1; + + if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8")) + { + std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl; + return -1; } - size_t npoints; - std::vector> point_to_labels; - parse_label_file_into_vec(npoints, label_file, point_to_labels); - std::vector point_ids_to_labels(point_to_labels.size()); - std::vector query_ids_to_labels(filter_labels.size()); - - for (size_t i = 0; i < point_to_labels.size(); i++) { - for (size_t j = 0; j < point_to_labels[i].size(); j++) { - std::string label = point_to_labels[i][j]; - if (all_labels.find(label) != all_labels.end()) { - point_ids_to_labels[i].insert(point_to_labels[i][j]); - if (labels_to_number_of_points.find(label) == - labels_to_number_of_points.end()) { - labels_to_number_of_points[label] = 0; - } - labels_to_number_of_points[label] += 1; - } - } + if (filter_label != "" && filter_label_file != "") + { + std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl; + return -1; } - for (size_t i = 0; i < filter_labels.size(); i++) { - query_ids_to_labels[i].insert(filter_labels[i]); + diskann::Metric metric; + if (dist_fn == std::string("l2")) + { + metric = diskann::Metric::L2; + } + else if (dist_fn == std::string("mips")) + { + metric = diskann::Metric::INNER_PRODUCT; + } + else if (dist_fn == std::string("cosine")) + { + metric = diskann::Metric::COSINE; + } + else + { + std::cerr << "Unsupported distance function. Use l2/mips/cosine." 
<< std::endl; + return -1; } - tsl::robin_map> label_id_to_orig_id; - tsl::robin_map> - label_query_id_to_orig_id; - - if (data_type == std::string("float")) { - label_id_to_orig_id = - diskann::generate_label_specific_vector_files_compat( - base_file, labels_to_number_of_points, point_ids_to_labels, - all_labels); - - label_query_id_to_orig_id = - diskann::generate_label_specific_vector_files_compat( - query_file, labels_to_number_of_queries, query_ids_to_labels, - all_labels); // query_filters acts like query_ids_to_labels - } else if (data_type == std::string("int8")) { - label_id_to_orig_id = - diskann::generate_label_specific_vector_files_compat( - base_file, labels_to_number_of_points, point_ids_to_labels, - all_labels); - - label_query_id_to_orig_id = - diskann::generate_label_specific_vector_files_compat( - query_file, labels_to_number_of_queries, query_ids_to_labels, - all_labels); // query_filters acts like query_ids_to_labels - } else if (data_type == std::string("uint8")) { - label_id_to_orig_id = - diskann::generate_label_specific_vector_files_compat( - base_file, labels_to_number_of_points, point_ids_to_labels, - all_labels); - - label_query_id_to_orig_id = - diskann::generate_label_specific_vector_files_compat( - query_file, labels_to_number_of_queries, query_ids_to_labels, - all_labels); // query_filters acts like query_ids_to_labels - } else { - diskann::cerr << "Invalid data type" << std::endl; - return -1; + std::vector filter_labels; + if (filter_label != "") + { + filter_labels.push_back(filter_label); + } + else if (filter_label_file != "") + { + filter_labels = read_file_to_vector_of_strings(filter_label_file, false); } - // Generate label specific ground truths + // only when there is no filter label or 1 filter label for all queries + if (filter_labels.size() == 1) + { + try + { + if (data_type == std::string("float")) + aux_main(base_file, label_file, query_file, gt_file, K, universal_label, metric, + filter_labels[0], tags_file); + if (data_type == std::string("int8")) + aux_main(base_file, label_file, query_file, gt_file, K, universal_label, metric, + filter_labels[0], tags_file); + if (data_type == std::string("uint8")) + aux_main(base_file, label_file, query_file, gt_file, K, universal_label, metric, + filter_labels[0], tags_file); + } + catch (const std::exception &e) + { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Compute GT failed." 
<< std::endl; + return -1; + } + } + else + { // Each query has its own filter label + // Split up data and query bins into label specific ones + tsl::robin_map labels_to_number_of_points; + tsl::robin_map labels_to_number_of_queries; + + label_set all_labels; + for (size_t i = 0; i < filter_labels.size(); i++) + { + std::string label = filter_labels[i]; + all_labels.insert(label); + + if (labels_to_number_of_queries.find(label) == labels_to_number_of_queries.end()) + { + labels_to_number_of_queries[label] = 0; + } + labels_to_number_of_queries[label] += 1; + } + + size_t npoints; + std::vector> point_to_labels; + parse_label_file_into_vec(npoints, label_file, point_to_labels); + std::vector point_ids_to_labels(point_to_labels.size()); + std::vector query_ids_to_labels(filter_labels.size()); + + for (size_t i = 0; i < point_to_labels.size(); i++) + { + for (size_t j = 0; j < point_to_labels[i].size(); j++) + { + std::string label = point_to_labels[i][j]; + if (all_labels.find(label) != all_labels.end()) + { + point_ids_to_labels[i].insert(point_to_labels[i][j]); + if (labels_to_number_of_points.find(label) == labels_to_number_of_points.end()) + { + labels_to_number_of_points[label] = 0; + } + labels_to_number_of_points[label] += 1; + } + } + } + + for (size_t i = 0; i < filter_labels.size(); i++) + { + query_ids_to_labels[i].insert(filter_labels[i]); + } + + tsl::robin_map> label_id_to_orig_id; + tsl::robin_map> label_query_id_to_orig_id; - try { - for (const auto &label : all_labels) { - std::string filtered_base_file = base_file + "_" + label; - std::string filtered_query_file = query_file + "_" + label; - std::string filtered_gt_file = gt_file + "_" + label; if (data_type == std::string("float")) - aux_main(filtered_base_file, "", filtered_query_file, - filtered_gt_file, K, "", metric, ""); - if (data_type == std::string("int8")) - aux_main(filtered_base_file, "", filtered_query_file, - filtered_gt_file, K, "", metric, ""); - if (data_type == std::string("uint8")) - aux_main(filtered_base_file, "", filtered_query_file, - filtered_gt_file, K, "", metric, ""); - } - } catch (const std::exception &e) { - std::cout << std::string(e.what()) << std::endl; - diskann::cerr << "Compute GT failed." 
<< std::endl; - return -1; - } + { + label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat( + base_file, labels_to_number_of_points, point_ids_to_labels, all_labels); - // Combine the label specific ground truths to produce a single GT file + label_query_id_to_orig_id = diskann::generate_label_specific_vector_files_compat( + query_file, labels_to_number_of_queries, query_ids_to_labels, + all_labels); // query_filters acts like query_ids_to_labels + } + else if (data_type == std::string("int8")) + { + label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat( + base_file, labels_to_number_of_points, point_ids_to_labels, all_labels); + + label_query_id_to_orig_id = diskann::generate_label_specific_vector_files_compat( + query_file, labels_to_number_of_queries, query_ids_to_labels, + all_labels); // query_filters acts like query_ids_to_labels + } + else if (data_type == std::string("uint8")) + { + label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat( + base_file, labels_to_number_of_points, point_ids_to_labels, all_labels); + + label_query_id_to_orig_id = diskann::generate_label_specific_vector_files_compat( + query_file, labels_to_number_of_queries, query_ids_to_labels, + all_labels); // query_filters acts like query_ids_to_labels + } + else + { + diskann::cerr << "Invalid data type" << std::endl; + return -1; + } - uint32_t *gt_ids = nullptr; - float *gt_dists = nullptr; - size_t gt_num, gt_dim; + // Generate label specific ground truths + + try + { + for (const auto &label : all_labels) + { + std::string filtered_base_file = base_file + "_" + label; + std::string filtered_query_file = query_file + "_" + label; + std::string filtered_gt_file = gt_file + "_" + label; + if (data_type == std::string("float")) + aux_main(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, ""); + if (data_type == std::string("int8")) + aux_main(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, ""); + if (data_type == std::string("uint8")) + aux_main(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, ""); + } + } + catch (const std::exception &e) + { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Compute GT failed." 
<< std::endl; + return -1; + } - std::vector> final_gt_ids; - std::vector> final_gt_dists; + // Combine the label specific ground truths to produce a single GT file - uint32_t query_num = 0; - for (const auto &lbl : all_labels) { - query_num += labels_to_number_of_queries[lbl]; - } + uint32_t *gt_ids = nullptr; + float *gt_dists = nullptr; + size_t gt_num, gt_dim; - for (uint32_t i = 0; i < query_num; i++) { - final_gt_ids.push_back(std::vector(K)); - final_gt_dists.push_back(std::vector(K)); - } + std::vector> final_gt_ids; + std::vector> final_gt_dists; - for (const auto &lbl : all_labels) { - std::string filtered_gt_file = gt_file + "_" + lbl; - load_truthset(filtered_gt_file, gt_ids, gt_dists, gt_num, gt_dim); + uint32_t query_num = 0; + for (const auto &lbl : all_labels) + { + query_num += labels_to_number_of_queries[lbl]; + } - for (uint32_t i = 0; i < labels_to_number_of_queries[lbl]; i++) { - uint32_t orig_query_id = label_query_id_to_orig_id[lbl][i]; - for (uint64_t j = 0; j < K; j++) { - final_gt_ids[orig_query_id][j] = - label_id_to_orig_id[lbl][gt_ids[i * K + j]]; - final_gt_dists[orig_query_id][j] = gt_dists[i * K + j]; + for (uint32_t i = 0; i < query_num; i++) + { + final_gt_ids.push_back(std::vector(K)); + final_gt_dists.push_back(std::vector(K)); } - } - } - int32_t *closest_points = new int32_t[query_num * K]; - float *dist_closest_points = new float[query_num * K]; + for (const auto &lbl : all_labels) + { + std::string filtered_gt_file = gt_file + "_" + lbl; + load_truthset(filtered_gt_file, gt_ids, gt_dists, gt_num, gt_dim); + + for (uint32_t i = 0; i < labels_to_number_of_queries[lbl]; i++) + { + uint32_t orig_query_id = label_query_id_to_orig_id[lbl][i]; + for (uint64_t j = 0; j < K; j++) + { + final_gt_ids[orig_query_id][j] = label_id_to_orig_id[lbl][gt_ids[i * K + j]]; + final_gt_dists[orig_query_id][j] = gt_dists[i * K + j]; + } + } + } - for (uint32_t i = 0; i < query_num; i++) { - for (uint32_t j = 0; j < K; j++) { - closest_points[i * K + j] = final_gt_ids[i][j]; - dist_closest_points[i * K + j] = final_gt_dists[i][j]; - } - } + int32_t *closest_points = new int32_t[query_num * K]; + float *dist_closest_points = new float[query_num * K]; - save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, - query_num, K); + for (uint32_t i = 0; i < query_num; i++) + { + for (uint32_t j = 0; j < K; j++) + { + closest_points[i * K + j] = final_gt_ids[i][j]; + dist_closest_points[i * K + j] = final_gt_dists[i][j]; + } + } - // cleanup artifacts - std::cout << "Cleaning up artifacts..." << std::endl; - tsl::robin_set paths_to_clean{gt_file, base_file, query_file}; - clean_up_artifacts(paths_to_clean, all_labels); - } + save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, query_num, K); + + // cleanup artifacts + std::cout << "Cleaning up artifacts..." 
<< std::endl; + tsl::robin_set paths_to_clean{gt_file, base_file, query_file}; + clean_up_artifacts(paths_to_clean, all_labels); + } } diff --git a/apps/utils/count_bfs_levels.cpp b/apps/utils/count_bfs_levels.cpp index 2f1a1db92..6e45ef13d 100644 --- a/apps/utils/count_bfs_levels.cpp +++ b/apps/utils/count_bfs_levels.cpp @@ -23,57 +23,60 @@ namespace po = boost::program_options; -template -void bfs_count(const std::string &index_path, uint32_t data_dims) { - using TagT = uint32_t; - using LabelT = uint32_t; - diskann::Index index(diskann::Metric::L2, data_dims, 0, - nullptr, nullptr, 0, false, false, - false, false, 0, false); - std::cout << "Index class instantiated" << std::endl; - index.load(index_path.c_str(), 1, 100); - std::cout << "Index loaded" << std::endl; - index.count_nodes_at_bfs_levels(); +template void bfs_count(const std::string &index_path, uint32_t data_dims) +{ + using TagT = uint32_t; + using LabelT = uint32_t; + diskann::Index index(diskann::Metric::L2, data_dims, 0, nullptr, nullptr, 0, false, false, false, + false, 0, false); + std::cout << "Index class instantiated" << std::endl; + index.load(index_path.c_str(), 1, 100); + std::cout << "Index loaded" << std::endl; + index.count_nodes_at_bfs_levels(); } -int main(int argc, char **argv) { - std::string data_type, index_path_prefix; - uint32_t data_dims; +int main(int argc, char **argv) +{ + std::string data_type, index_path_prefix; + uint32_t data_dims; - po::options_description desc{"Arguments"}; - try { - desc.add_options()("help,h", "Print information on arguments"); - desc.add_options()("data_type", - po::value(&data_type)->required(), - "data type "); - desc.add_options()("index_path_prefix", - po::value(&index_path_prefix)->required(), - "Path prefix to the index"); - desc.add_options()("data_dims", po::value(&data_dims)->required(), - "Dimensionality of the data"); + po::options_description desc{"Arguments"}; + try + { + desc.add_options()("help,h", "Print information on arguments"); + desc.add_options()("data_type", po::value(&data_type)->required(), "data type "); + desc.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), + "Path prefix to the index"); + desc.add_options()("data_dims", po::value(&data_dims)->required(), "Dimensionality of the data"); - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) { - std::cout << desc; - return 0; + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) + { + std::cout << desc; + return 0; + } + po::notify(vm); + } + catch (const std::exception &ex) + { + std::cerr << ex.what() << '\n'; + return -1; } - po::notify(vm); - } catch (const std::exception &ex) { - std::cerr << ex.what() << '\n'; - return -1; - } - try { - if (data_type == std::string("int8")) - bfs_count(index_path_prefix, data_dims); - else if (data_type == std::string("uint8")) - bfs_count(index_path_prefix, data_dims); - if (data_type == std::string("float")) - bfs_count(index_path_prefix, data_dims); - } catch (std::exception &e) { - std::cout << std::string(e.what()) << std::endl; - diskann::cerr << "Index BFS failed." 
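// --- Editor's sketch (illustrative only, not part of this patch) -----------
// The compute_groundtruth_for_filters hunks above merge per-label ground
// truths back into one truthset by translating ids computed against a
// filtered base file into ids of the original base file (the
// label_id_to_orig_id lookup). A minimal, self-contained sketch of that
// remapping step; remap_ids is a hypothetical helper name, not a DiskANN API.
#include <cstdint>
#include <vector>

std::vector<uint32_t> remap_ids(const std::vector<uint32_t> &local_ids,
                                const std::vector<uint32_t> &local_to_orig)
{
    std::vector<uint32_t> orig_ids;
    orig_ids.reserve(local_ids.size());
    for (uint32_t id : local_ids)
        orig_ids.push_back(local_to_orig[id]); // same lookup as label_id_to_orig_id[lbl][gt_ids[...]]
    return orig_ids;
}
// ----------------------------------------------------------------------------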
<< std::endl; - return -1; - } + try + { + if (data_type == std::string("int8")) + bfs_count(index_path_prefix, data_dims); + else if (data_type == std::string("uint8")) + bfs_count(index_path_prefix, data_dims); + if (data_type == std::string("float")) + bfs_count(index_path_prefix, data_dims); + } + catch (std::exception &e) + { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Index BFS failed." << std::endl; + return -1; + } } diff --git a/apps/utils/create_disk_layout.cpp b/apps/utils/create_disk_layout.cpp index 7c8eca1b0..6d5314fb4 100644 --- a/apps/utils/create_disk_layout.cpp +++ b/apps/utils/create_disk_layout.cpp @@ -12,33 +12,37 @@ #include "disk_utils.h" #include "utils.h" -template int create_disk_layout(char **argv) { - std::string base_file(argv[2]); - std::string vamana_file(argv[3]); - std::string output_file(argv[4]); - diskann::create_disk_layout(base_file, vamana_file, output_file); - return 0; +template int create_disk_layout(char **argv) +{ + std::string base_file(argv[2]); + std::string vamana_file(argv[3]); + std::string output_file(argv[4]); + diskann::create_disk_layout(base_file, vamana_file, output_file); + return 0; } -int main(int argc, char **argv) { - if (argc != 5) { - std::cout << argv[0] - << " data_type data_bin " - "vamana_index_file output_diskann_index_file" - << std::endl; - exit(-1); - } +int main(int argc, char **argv) +{ + if (argc != 5) + { + std::cout << argv[0] + << " data_type data_bin " + "vamana_index_file output_diskann_index_file" + << std::endl; + exit(-1); + } - int ret_val = -1; - if (std::string(argv[1]) == std::string("float")) - ret_val = create_disk_layout(argv); - else if (std::string(argv[1]) == std::string("int8")) - ret_val = create_disk_layout(argv); - else if (std::string(argv[1]) == std::string("uint8")) - ret_val = create_disk_layout(argv); - else { - std::cout << "unsupported type. use int8/uint8/float " << std::endl; - ret_val = -2; - } - return ret_val; + int ret_val = -1; + if (std::string(argv[1]) == std::string("float")) + ret_val = create_disk_layout(argv); + else if (std::string(argv[1]) == std::string("int8")) + ret_val = create_disk_layout(argv); + else if (std::string(argv[1]) == std::string("uint8")) + ret_val = create_disk_layout(argv); + else + { + std::cout << "unsupported type. 
use int8/uint8/float " << std::endl; + ret_val = -2; + } + return ret_val; } diff --git a/apps/utils/float_bin_to_int8.cpp b/apps/utils/float_bin_to_int8.cpp index d3776b641..c3fa8f8ec 100644 --- a/apps/utils/float_bin_to_int8.cpp +++ b/apps/utils/float_bin_to_int8.cpp @@ -4,59 +4,60 @@ #include "utils.h" #include -void block_convert(std::ofstream &writer, int8_t *write_buf, - std::ifstream &reader, float *read_buf, size_t npts, - size_t ndims, float bias, float scale) { - reader.read((char *)read_buf, npts * ndims * sizeof(float)); - - for (size_t i = 0; i < npts; i++) { - for (size_t d = 0; d < ndims; d++) { - write_buf[d + i * ndims] = - (int8_t)((read_buf[d + i * ndims] - bias) * (254.0 / scale)); +void block_convert(std::ofstream &writer, int8_t *write_buf, std::ifstream &reader, float *read_buf, size_t npts, + size_t ndims, float bias, float scale) +{ + reader.read((char *)read_buf, npts * ndims * sizeof(float)); + + for (size_t i = 0; i < npts; i++) + { + for (size_t d = 0; d < ndims; d++) + { + write_buf[d + i * ndims] = (int8_t)((read_buf[d + i * ndims] - bias) * (254.0 / scale)); + } } - } - writer.write((char *)write_buf, npts * ndims); + writer.write((char *)write_buf, npts * ndims); } -int main(int argc, char **argv) { - if (argc != 5) { - std::cout << "Usage: " << argv[0] << " input_bin output_tsv bias scale" - << std::endl; - exit(-1); - } - - std::ifstream reader(argv[1], std::ios::binary); - uint32_t npts_u32; - uint32_t ndims_u32; - reader.read((char *)&npts_u32, sizeof(uint32_t)); - reader.read((char *)&ndims_u32, sizeof(uint32_t)); - size_t npts = npts_u32; - size_t ndims = ndims_u32; - std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims - << std::endl; - - size_t blk_size = 131072; - size_t nblks = ROUND_UP(npts, blk_size) / blk_size; - - std::ofstream writer(argv[2], std::ios::binary); - auto read_buf = new float[blk_size * ndims]; - auto write_buf = new int8_t[blk_size * ndims]; - float bias = (float)atof(argv[3]); - float scale = (float)atof(argv[4]); - - writer.write((char *)(&npts_u32), sizeof(uint32_t)); - writer.write((char *)(&ndims_u32), sizeof(uint32_t)); - - for (size_t i = 0; i < nblks; i++) { - size_t cblk_size = std::min(npts - i * blk_size, blk_size); - block_convert(writer, write_buf, reader, read_buf, cblk_size, ndims, bias, - scale); - std::cout << "Block #" << i << " written" << std::endl; - } - - delete[] read_buf; - delete[] write_buf; - - writer.close(); - reader.close(); +int main(int argc, char **argv) +{ + if (argc != 5) + { + std::cout << "Usage: " << argv[0] << " input_bin output_tsv bias scale" << std::endl; + exit(-1); + } + + std::ifstream reader(argv[1], std::ios::binary); + uint32_t npts_u32; + uint32_t ndims_u32; + reader.read((char *)&npts_u32, sizeof(uint32_t)); + reader.read((char *)&ndims_u32, sizeof(uint32_t)); + size_t npts = npts_u32; + size_t ndims = ndims_u32; + std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl; + + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; + + std::ofstream writer(argv[2], std::ios::binary); + auto read_buf = new float[blk_size * ndims]; + auto write_buf = new int8_t[blk_size * ndims]; + float bias = (float)atof(argv[3]); + float scale = (float)atof(argv[4]); + + writer.write((char *)(&npts_u32), sizeof(uint32_t)); + writer.write((char *)(&ndims_u32), sizeof(uint32_t)); + + for (size_t i = 0; i < nblks; i++) + { + size_t cblk_size = std::min(npts - i * blk_size, blk_size); + block_convert(writer, write_buf, reader, read_buf, 
cblk_size, ndims, bias, scale); + std::cout << "Block #" << i << " written" << std::endl; + } + + delete[] read_buf; + delete[] write_buf; + + writer.close(); + reader.close(); } diff --git a/apps/utils/fvecs_to_bin.cpp b/apps/utils/fvecs_to_bin.cpp index 02dbacf54..1428a9c6e 100644 --- a/apps/utils/fvecs_to_bin.cpp +++ b/apps/utils/fvecs_to_bin.cpp @@ -5,88 +5,91 @@ #include // Convert float types -void block_convert_float(std::ifstream &reader, std::ofstream &writer, - float *read_buf, float *write_buf, size_t npts, - size_t ndims) { - reader.read((char *)read_buf, - npts * (ndims * sizeof(float) + sizeof(uint32_t))); - for (size_t i = 0; i < npts; i++) { - memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, - ndims * sizeof(float)); - } - writer.write((char *)write_buf, npts * ndims * sizeof(float)); +void block_convert_float(std::ifstream &reader, std::ofstream &writer, float *read_buf, float *write_buf, size_t npts, + size_t ndims) +{ + reader.read((char *)read_buf, npts * (ndims * sizeof(float) + sizeof(uint32_t))); + for (size_t i = 0; i < npts; i++) + { + memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, ndims * sizeof(float)); + } + writer.write((char *)write_buf, npts * ndims * sizeof(float)); } // Convert byte types -void block_convert_byte(std::ifstream &reader, std::ofstream &writer, - uint8_t *read_buf, uint8_t *write_buf, size_t npts, - size_t ndims) { - reader.read((char *)read_buf, - npts * (ndims * sizeof(uint8_t) + sizeof(uint32_t))); - for (size_t i = 0; i < npts; i++) { - memcpy(write_buf + i * ndims, - (read_buf + i * (ndims + sizeof(uint32_t))) + sizeof(uint32_t), - ndims * sizeof(uint8_t)); - } - writer.write((char *)write_buf, npts * ndims * sizeof(uint8_t)); +void block_convert_byte(std::ifstream &reader, std::ofstream &writer, uint8_t *read_buf, uint8_t *write_buf, + size_t npts, size_t ndims) +{ + reader.read((char *)read_buf, npts * (ndims * sizeof(uint8_t) + sizeof(uint32_t))); + for (size_t i = 0; i < npts; i++) + { + memcpy(write_buf + i * ndims, (read_buf + i * (ndims + sizeof(uint32_t))) + sizeof(uint32_t), + ndims * sizeof(uint8_t)); + } + writer.write((char *)write_buf, npts * ndims * sizeof(uint8_t)); } -int main(int argc, char **argv) { - if (argc != 4) { - std::cout << argv[0] << " input_vecs output_bin" - << std::endl; - exit(-1); - } +int main(int argc, char **argv) +{ + if (argc != 4) + { + std::cout << argv[0] << " input_vecs output_bin" << std::endl; + exit(-1); + } - int datasize = sizeof(float); + int datasize = sizeof(float); - if (strcmp(argv[1], "uint8") == 0 || strcmp(argv[1], "int8") == 0) { - datasize = sizeof(uint8_t); - } else if (strcmp(argv[1], "float") != 0) { - std::cout << "Error: type not supported. Use float/int8/uint8" << std::endl; - exit(-1); - } + if (strcmp(argv[1], "uint8") == 0 || strcmp(argv[1], "int8") == 0) + { + datasize = sizeof(uint8_t); + } + else if (strcmp(argv[1], "float") != 0) + { + std::cout << "Error: type not supported. 
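// --- Editor's sketch (illustrative only, not part of this patch) -----------
// float_bin_to_int8 above quantizes each coordinate as
// (x - bias) * (254 / scale) before truncating to int8. A minimal sketch of
// that per-block transform; quantize_to_int8 is a hypothetical helper, and no
// range checking is done (values outside [-127, 127] would wrap, as in the
// original tool).
#include <cstddef>
#include <cstdint>

void quantize_to_int8(const float *in, int8_t *out, size_t n, float bias, float scale)
{
    for (size_t i = 0; i < n; i++)
        out[i] = (int8_t)((in[i] - bias) * (254.0f / scale));
}
// ----------------------------------------------------------------------------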
Use float/int8/uint8" << std::endl; + exit(-1); + } - std::ifstream reader(argv[2], std::ios::binary | std::ios::ate); - size_t fsize = reader.tellg(); - reader.seekg(0, std::ios::beg); + std::ifstream reader(argv[2], std::ios::binary | std::ios::ate); + size_t fsize = reader.tellg(); + reader.seekg(0, std::ios::beg); - uint32_t ndims_u32; - reader.read((char *)&ndims_u32, sizeof(uint32_t)); - reader.seekg(0, std::ios::beg); - size_t ndims = (size_t)ndims_u32; - size_t npts = fsize / ((ndims * datasize) + sizeof(uint32_t)); - std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims - << std::endl; + uint32_t ndims_u32; + reader.read((char *)&ndims_u32, sizeof(uint32_t)); + reader.seekg(0, std::ios::beg); + size_t ndims = (size_t)ndims_u32; + size_t npts = fsize / ((ndims * datasize) + sizeof(uint32_t)); + std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl; - size_t blk_size = 131072; - size_t nblks = ROUND_UP(npts, blk_size) / blk_size; - std::cout << "# blks: " << nblks << std::endl; - std::ofstream writer(argv[3], std::ios::binary); - int32_t npts_s32 = (int32_t)npts; - int32_t ndims_s32 = (int32_t)ndims; - writer.write((char *)&npts_s32, sizeof(int32_t)); - writer.write((char *)&ndims_s32, sizeof(int32_t)); + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; + std::cout << "# blks: " << nblks << std::endl; + std::ofstream writer(argv[3], std::ios::binary); + int32_t npts_s32 = (int32_t)npts; + int32_t ndims_s32 = (int32_t)ndims; + writer.write((char *)&npts_s32, sizeof(int32_t)); + writer.write((char *)&ndims_s32, sizeof(int32_t)); - size_t chunknpts = std::min(npts, blk_size); - uint8_t *read_buf = - new uint8_t[chunknpts * ((ndims * datasize) + sizeof(uint32_t))]; - uint8_t *write_buf = new uint8_t[chunknpts * ndims * datasize]; + size_t chunknpts = std::min(npts, blk_size); + uint8_t *read_buf = new uint8_t[chunknpts * ((ndims * datasize) + sizeof(uint32_t))]; + uint8_t *write_buf = new uint8_t[chunknpts * ndims * datasize]; - for (size_t i = 0; i < nblks; i++) { - size_t cblk_size = std::min(npts - i * blk_size, blk_size); - if (datasize == sizeof(float)) { - block_convert_float(reader, writer, (float *)read_buf, (float *)write_buf, - cblk_size, ndims); - } else { - block_convert_byte(reader, writer, read_buf, write_buf, cblk_size, ndims); + for (size_t i = 0; i < nblks; i++) + { + size_t cblk_size = std::min(npts - i * blk_size, blk_size); + if (datasize == sizeof(float)) + { + block_convert_float(reader, writer, (float *)read_buf, (float *)write_buf, cblk_size, ndims); + } + else + { + block_convert_byte(reader, writer, read_buf, write_buf, cblk_size, ndims); + } + std::cout << "Block #" << i << " written" << std::endl; } - std::cout << "Block #" << i << " written" << std::endl; - } - delete[] read_buf; - delete[] write_buf; + delete[] read_buf; + delete[] write_buf; - reader.close(); - writer.close(); + reader.close(); + writer.close(); } diff --git a/apps/utils/fvecs_to_bvecs.cpp b/apps/utils/fvecs_to_bvecs.cpp index a5cc09449..60ac12126 100644 --- a/apps/utils/fvecs_to_bvecs.cpp +++ b/apps/utils/fvecs_to_bvecs.cpp @@ -4,53 +4,53 @@ #include "utils.h" #include -void block_convert(std::ifstream &reader, std::ofstream &writer, - float *read_buf, uint8_t *write_buf, size_t npts, - size_t ndims) { - reader.read((char *)read_buf, - npts * (ndims * sizeof(float) + sizeof(uint32_t))); - for (size_t i = 0; i < npts; i++) { - memcpy(write_buf + i * (ndims + 4), read_buf + i * (ndims + 1), - sizeof(uint32_t)); - for 
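// --- Editor's sketch (illustrative only, not part of this patch) -----------
// fvecs_to_bin above strips the per-record header of .fvecs/.bvecs files:
// each record is a 4-byte uint32 dimension count followed by ndims values,
// and only the values are copied into the flat .bin layout. A minimal sketch
// for one float record; strip_fvecs_header is a hypothetical helper and does
// no bounds or error checking.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

std::vector<float> strip_fvecs_header(const char *record, size_t ndims)
{
    std::vector<float> vec(ndims);
    // record layout: [uint32 ndims][ndims x float]
    std::memcpy(vec.data(), record + sizeof(uint32_t), ndims * sizeof(float));
    return vec;
}
// ----------------------------------------------------------------------------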
(size_t d = 0; d < ndims; d++) - write_buf[i * (ndims + 4) + 4 + d] = - (uint8_t)read_buf[i * (ndims + 1) + 1 + d]; - } - writer.write((char *)write_buf, npts * (ndims * 1 + 4)); +void block_convert(std::ifstream &reader, std::ofstream &writer, float *read_buf, uint8_t *write_buf, size_t npts, + size_t ndims) +{ + reader.read((char *)read_buf, npts * (ndims * sizeof(float) + sizeof(uint32_t))); + for (size_t i = 0; i < npts; i++) + { + memcpy(write_buf + i * (ndims + 4), read_buf + i * (ndims + 1), sizeof(uint32_t)); + for (size_t d = 0; d < ndims; d++) + write_buf[i * (ndims + 4) + 4 + d] = (uint8_t)read_buf[i * (ndims + 1) + 1 + d]; + } + writer.write((char *)write_buf, npts * (ndims * 1 + 4)); } -int main(int argc, char **argv) { - if (argc != 3) { - std::cout << argv[0] << " input_fvecs output_bvecs(uint8)" << std::endl; - exit(-1); - } - std::ifstream reader(argv[1], std::ios::binary | std::ios::ate); - size_t fsize = reader.tellg(); - reader.seekg(0, std::ios::beg); +int main(int argc, char **argv) +{ + if (argc != 3) + { + std::cout << argv[0] << " input_fvecs output_bvecs(uint8)" << std::endl; + exit(-1); + } + std::ifstream reader(argv[1], std::ios::binary | std::ios::ate); + size_t fsize = reader.tellg(); + reader.seekg(0, std::ios::beg); - uint32_t ndims_u32; - reader.read((char *)&ndims_u32, sizeof(uint32_t)); - reader.seekg(0, std::ios::beg); - size_t ndims = (size_t)ndims_u32; - size_t npts = fsize / ((ndims + 1) * sizeof(float)); - std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims - << std::endl; + uint32_t ndims_u32; + reader.read((char *)&ndims_u32, sizeof(uint32_t)); + reader.seekg(0, std::ios::beg); + size_t ndims = (size_t)ndims_u32; + size_t npts = fsize / ((ndims + 1) * sizeof(float)); + std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl; - size_t blk_size = 131072; - size_t nblks = ROUND_UP(npts, blk_size) / blk_size; - std::cout << "# blks: " << nblks << std::endl; - std::ofstream writer(argv[2], std::ios::binary); - auto read_buf = new float[npts * (ndims + 1)]; - auto write_buf = new uint8_t[npts * (ndims + 4)]; - for (size_t i = 0; i < nblks; i++) { - size_t cblk_size = std::min(npts - i * blk_size, blk_size); - block_convert(reader, writer, read_buf, write_buf, cblk_size, ndims); - std::cout << "Block #" << i << " written" << std::endl; - } + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; + std::cout << "# blks: " << nblks << std::endl; + std::ofstream writer(argv[2], std::ios::binary); + auto read_buf = new float[npts * (ndims + 1)]; + auto write_buf = new uint8_t[npts * (ndims + 4)]; + for (size_t i = 0; i < nblks; i++) + { + size_t cblk_size = std::min(npts - i * blk_size, blk_size); + block_convert(reader, writer, read_buf, write_buf, cblk_size, ndims); + std::cout << "Block #" << i << " written" << std::endl; + } - delete[] read_buf; - delete[] write_buf; + delete[] read_buf; + delete[] write_buf; - reader.close(); - writer.close(); + reader.close(); + writer.close(); } diff --git a/apps/utils/gen_random_slice.cpp b/apps/utils/gen_random_slice.cpp index 29307937d..64bc994ef 100644 --- a/apps/utils/gen_random_slice.cpp +++ b/apps/utils/gen_random_slice.cpp @@ -20,30 +20,39 @@ #include #include -template int aux_main(char **argv) { - std::string base_file(argv[2]); - std::string output_prefix(argv[3]); - float sampling_rate = (float)(std::atof(argv[4])); - gen_random_slice(base_file, output_prefix, sampling_rate); - return 0; +template int aux_main(char **argv) +{ + std::string 
base_file(argv[2]); + std::string output_prefix(argv[3]); + float sampling_rate = (float)(std::atof(argv[4])); + gen_random_slice(base_file, output_prefix, sampling_rate); + return 0; } -int main(int argc, char **argv) { - if (argc != 5) { - std::cout << argv[0] - << " data_type [float/int8/uint8] base_bin_file " - "sample_output_prefix sampling_probability" - << std::endl; - exit(-1); - } +int main(int argc, char **argv) +{ + if (argc != 5) + { + std::cout << argv[0] + << " data_type [float/int8/uint8] base_bin_file " + "sample_output_prefix sampling_probability" + << std::endl; + exit(-1); + } - if (std::string(argv[1]) == std::string("float")) { - aux_main(argv); - } else if (std::string(argv[1]) == std::string("int8")) { - aux_main(argv); - } else if (std::string(argv[1]) == std::string("uint8")) { - aux_main(argv); - } else - std::cout << "Unsupported type. Use float/int8/uint8." << std::endl; - return 0; + if (std::string(argv[1]) == std::string("float")) + { + aux_main(argv); + } + else if (std::string(argv[1]) == std::string("int8")) + { + aux_main(argv); + } + else if (std::string(argv[1]) == std::string("uint8")) + { + aux_main(argv); + } + else + std::cout << "Unsupported type. Use float/int8/uint8." << std::endl; + return 0; } diff --git a/apps/utils/generate_pq.cpp b/apps/utils/generate_pq.cpp index 390bf4355..cff7a3526 100644 --- a/apps/utils/generate_pq.cpp +++ b/apps/utils/generate_pq.cpp @@ -8,66 +8,63 @@ #define KMEANS_ITERS_FOR_PQ 15 template -bool generate_pq(const std::string &data_path, - const std::string &index_prefix_path, - const size_t num_pq_centers, const size_t num_pq_chunks, - const float sampling_rate, const bool opq) { - std::string pq_pivots_path = index_prefix_path + "_pq_pivots.bin"; - std::string pq_compressed_vectors_path = - index_prefix_path + "_pq_compressed.bin"; +bool generate_pq(const std::string &data_path, const std::string &index_prefix_path, const size_t num_pq_centers, + const size_t num_pq_chunks, const float sampling_rate, const bool opq) +{ + std::string pq_pivots_path = index_prefix_path + "_pq_pivots.bin"; + std::string pq_compressed_vectors_path = index_prefix_path + "_pq_compressed.bin"; - // generates random sample and sets it to train_data and updates train_size - size_t train_size, train_dim; - float *train_data; - gen_random_slice(data_path, sampling_rate, train_data, train_size, - train_dim); - std::cout << "For computing pivots, loaded sample data of size " << train_size - << std::endl; + // generates random sample and sets it to train_data and updates train_size + size_t train_size, train_dim; + float *train_data; + gen_random_slice(data_path, sampling_rate, train_data, train_size, train_dim); + std::cout << "For computing pivots, loaded sample data of size " << train_size << std::endl; - if (opq) { - diskann::generate_opq_pivots(train_data, train_size, (uint32_t)train_dim, - (uint32_t)num_pq_centers, - (uint32_t)num_pq_chunks, pq_pivots_path, true); - } else { - diskann::generate_pq_pivots( - train_data, train_size, (uint32_t)train_dim, (uint32_t)num_pq_centers, - (uint32_t)num_pq_chunks, KMEANS_ITERS_FOR_PQ, pq_pivots_path); - } - diskann::generate_pq_data_from_pivots( - data_path, (uint32_t)num_pq_centers, (uint32_t)num_pq_chunks, - pq_pivots_path, pq_compressed_vectors_path, true); + if (opq) + { + diskann::generate_opq_pivots(train_data, train_size, (uint32_t)train_dim, (uint32_t)num_pq_centers, + (uint32_t)num_pq_chunks, pq_pivots_path, true); + } + else + { + diskann::generate_pq_pivots(train_data, train_size, 
(uint32_t)train_dim, (uint32_t)num_pq_centers, + (uint32_t)num_pq_chunks, KMEANS_ITERS_FOR_PQ, pq_pivots_path); + } + diskann::generate_pq_data_from_pivots(data_path, (uint32_t)num_pq_centers, (uint32_t)num_pq_chunks, + pq_pivots_path, pq_compressed_vectors_path, true); - delete[] train_data; + delete[] train_data; - return 0; + return 0; } -int main(int argc, char **argv) { - if (argc != 7) { - std::cout << "Usage: \n" - << argv[0] - << " " - " " - " " - << std::endl; - } else { - const std::string data_path(argv[2]); - const std::string index_prefix_path(argv[3]); - const size_t num_pq_centers = 256; - const size_t num_pq_chunks = (size_t)atoi(argv[4]); - const float sampling_rate = (float)atof(argv[5]); - const bool opq = atoi(argv[6]) == 0 ? false : true; - - if (std::string(argv[1]) == std::string("float")) - generate_pq(data_path, index_prefix_path, num_pq_centers, - num_pq_chunks, sampling_rate, opq); - else if (std::string(argv[1]) == std::string("int8")) - generate_pq(data_path, index_prefix_path, num_pq_centers, - num_pq_chunks, sampling_rate, opq); - else if (std::string(argv[1]) == std::string("uint8")) - generate_pq(data_path, index_prefix_path, num_pq_centers, - num_pq_chunks, sampling_rate, opq); +int main(int argc, char **argv) +{ + if (argc != 7) + { + std::cout << "Usage: \n" + << argv[0] + << " " + " " + " " + << std::endl; + } else - std::cout << "Error. wrong file type" << std::endl; - } + { + const std::string data_path(argv[2]); + const std::string index_prefix_path(argv[3]); + const size_t num_pq_centers = 256; + const size_t num_pq_chunks = (size_t)atoi(argv[4]); + const float sampling_rate = (float)atof(argv[5]); + const bool opq = atoi(argv[6]) == 0 ? false : true; + + if (std::string(argv[1]) == std::string("float")) + generate_pq(data_path, index_prefix_path, num_pq_centers, num_pq_chunks, sampling_rate, opq); + else if (std::string(argv[1]) == std::string("int8")) + generate_pq(data_path, index_prefix_path, num_pq_centers, num_pq_chunks, sampling_rate, opq); + else if (std::string(argv[1]) == std::string("uint8")) + generate_pq(data_path, index_prefix_path, num_pq_centers, num_pq_chunks, sampling_rate, opq); + else + std::cout << "Error. 
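// --- Editor's sketch (illustrative only, not part of this patch) -----------
// generate_pq above fixes num_pq_centers at 256, so each chunk index fits in
// a single byte and the compressed vectors file is roughly
// num_points * num_pq_chunks bytes plus the small bin header. A rough sizing
// helper under that assumption; approx_pq_compressed_bytes is hypothetical
// and only meant for back-of-the-envelope budgeting.
#include <cstddef>

size_t approx_pq_compressed_bytes(size_t num_points, size_t num_pq_chunks)
{
    return 2 * sizeof(int) + num_points * num_pq_chunks; // 8-byte header + 1 byte per chunk per point
}
// ----------------------------------------------------------------------------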
wrong file type" << std::endl; + } } diff --git a/apps/utils/generate_synthetic_labels.cpp b/apps/utils/generate_synthetic_labels.cpp index 4b11df0a4..766c297d7 100644 --- a/apps/utils/generate_synthetic_labels.cpp +++ b/apps/utils/generate_synthetic_labels.cpp @@ -9,168 +9,196 @@ #include namespace po = boost::program_options; -class ZipfDistribution { -public: - ZipfDistribution(uint64_t num_points, uint32_t num_labels) - : num_labels(num_labels), num_points(num_points), - uniform_zero_to_one(std::uniform_real_distribution<>(0.0, 1.0)) {} - - std::unordered_map createDistributionMap() { - std::unordered_map map; - uint32_t primary_label_freq = - (uint32_t)ceil(num_points * distribution_factor); - for (uint32_t i{1}; i < num_labels + 1; i++) { - map[i] = (uint32_t)ceil(primary_label_freq / i); +class ZipfDistribution +{ + public: + ZipfDistribution(uint64_t num_points, uint32_t num_labels) + : num_labels(num_labels), num_points(num_points), + uniform_zero_to_one(std::uniform_real_distribution<>(0.0, 1.0)) + { } - return map; - } - - int writeDistribution(std::ofstream &outfile) { - auto distribution_map = createDistributionMap(); - for (uint32_t i{0}; i < num_points; i++) { - bool label_written = false; - for (auto it = distribution_map.cbegin(); it != distribution_map.cend(); - it++) { - auto label_selection_probability = std::bernoulli_distribution( - distribution_factor / (double)it->first); - if (label_selection_probability(rand_engine) && - distribution_map[it->first] > 0) { - if (label_written) { - outfile << ','; - } - outfile << it->first; - label_written = true; - // remove label from map if we have used all labels - distribution_map[it->first] -= 1; + + std::unordered_map createDistributionMap() + { + std::unordered_map map; + uint32_t primary_label_freq = (uint32_t)ceil(num_points * distribution_factor); + for (uint32_t i{1}; i < num_labels + 1; i++) + { + map[i] = (uint32_t)ceil(primary_label_freq / i); } - } - if (!label_written) { - outfile << 0; - } - if (i < num_points - 1) { - outfile << '\n'; - } + return map; + } + + int writeDistribution(std::ofstream &outfile) + { + auto distribution_map = createDistributionMap(); + for (uint32_t i{0}; i < num_points; i++) + { + bool label_written = false; + for (auto it = distribution_map.cbegin(); it != distribution_map.cend(); it++) + { + auto label_selection_probability = std::bernoulli_distribution(distribution_factor / (double)it->first); + if (label_selection_probability(rand_engine) && distribution_map[it->first] > 0) + { + if (label_written) + { + outfile << ','; + } + outfile << it->first; + label_written = true; + // remove label from map if we have used all labels + distribution_map[it->first] -= 1; + } + } + if (!label_written) + { + outfile << 0; + } + if (i < num_points - 1) + { + outfile << '\n'; + } + } + return 0; } - return 0; - } - int writeDistribution(std::string filename) { - std::ofstream outfile(filename); - if (!outfile.is_open()) { - std::cerr << "Error: could not open output file " << filename << '\n'; - return -1; + int writeDistribution(std::string filename) + { + std::ofstream outfile(filename); + if (!outfile.is_open()) + { + std::cerr << "Error: could not open output file " << filename << '\n'; + return -1; + } + writeDistribution(outfile); + outfile.close(); } - writeDistribution(outfile); - outfile.close(); - } - -private: - const uint32_t num_labels; - const uint64_t num_points; - const double distribution_factor = 0.7; - std::knuth_b rand_engine; - const std::uniform_real_distribution 
uniform_zero_to_one; + + private: + const uint32_t num_labels; + const uint64_t num_points; + const double distribution_factor = 0.7; + std::knuth_b rand_engine; + const std::uniform_real_distribution uniform_zero_to_one; }; -int main(int argc, char **argv) { - std::string output_file, distribution_type; - uint32_t num_labels; - uint64_t num_points; - - try { - po::options_description desc{"Arguments"}; - - desc.add_options()("help,h", "Print information on arguments"); - desc.add_options()("output_file,O", - po::value(&output_file)->required(), - "Filename for saving the label file"); - desc.add_options()("num_points,N", - po::value(&num_points)->required(), - "Number of points in dataset"); - desc.add_options()("num_labels,L", - po::value(&num_labels)->required(), - "Number of unique labels, up to 5000"); - desc.add_options()( - "distribution_type,DT", - po::value(&distribution_type)->default_value("random"), - "Distribution function for labels defaults " - "to random"); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) { - std::cout << desc; - return 0; +int main(int argc, char **argv) +{ + std::string output_file, distribution_type; + uint32_t num_labels; + uint64_t num_points; + + try + { + po::options_description desc{"Arguments"}; + + desc.add_options()("help,h", "Print information on arguments"); + desc.add_options()("output_file,O", po::value(&output_file)->required(), + "Filename for saving the label file"); + desc.add_options()("num_points,N", po::value(&num_points)->required(), "Number of points in dataset"); + desc.add_options()("num_labels,L", po::value(&num_labels)->required(), + "Number of unique labels, up to 5000"); + desc.add_options()("distribution_type,DT", po::value(&distribution_type)->default_value("random"), + "Distribution function for labels defaults " + "to random"); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) + { + std::cout << desc; + return 0; + } + po::notify(vm); } - po::notify(vm); - } catch (const std::exception &ex) { - std::cerr << ex.what() << '\n'; - return -1; - } - - if (num_labels > 5000) { - std::cerr << "Error: num_labels must be 5000 or less" << '\n'; - return -1; - } - - if (num_points <= 0) { - std::cerr << "Error: num_points must be greater than 0" << '\n'; - return -1; - } - - std::cout << "Generating synthetic labels for " << num_points - << " points with " << num_labels << " unique labels" << '\n'; - - try { - std::ofstream outfile(output_file); - if (!outfile.is_open()) { - std::cerr << "Error: could not open output file " << output_file << '\n'; - return -1; + catch (const std::exception &ex) + { + std::cerr << ex.what() << '\n'; + return -1; } - if (distribution_type == "zipf") { - ZipfDistribution zipf(num_points, num_labels); - zipf.writeDistribution(outfile); - } else if (distribution_type == "random") { - for (size_t i = 0; i < num_points; i++) { - bool label_written = false; - for (size_t j = 1; j <= num_labels; j++) { - // 50% chance to assign each label - if (rand() > (RAND_MAX / 2)) { - if (label_written) { - outfile << ','; + if (num_labels > 5000) + { + std::cerr << "Error: num_labels must be 5000 or less" << '\n'; + return -1; + } + + if (num_points <= 0) + { + std::cerr << "Error: num_points must be greater than 0" << '\n'; + return -1; + } + + std::cout << "Generating synthetic labels for " << num_points << " points with " << num_labels << " unique labels" + << '\n'; + + try + { + std::ofstream 
outfile(output_file); + if (!outfile.is_open()) + { + std::cerr << "Error: could not open output file " << output_file << '\n'; + return -1; + } + + if (distribution_type == "zipf") + { + ZipfDistribution zipf(num_points, num_labels); + zipf.writeDistribution(outfile); + } + else if (distribution_type == "random") + { + for (size_t i = 0; i < num_points; i++) + { + bool label_written = false; + for (size_t j = 1; j <= num_labels; j++) + { + // 50% chance to assign each label + if (rand() > (RAND_MAX / 2)) + { + if (label_written) + { + outfile << ','; + } + outfile << j; + label_written = true; + } + } + if (!label_written) + { + outfile << 0; + } + if (i < num_points - 1) + { + outfile << '\n'; + } } - outfile << j; - label_written = true; - } } - if (!label_written) { - outfile << 0; + else if (distribution_type == "one_per_point") + { + std::random_device rd; // obtain a random number from hardware + std::mt19937 gen(rd()); // seed the generator + std::uniform_int_distribution<> distr(0, num_labels); // define the range + + for (size_t i = 0; i < num_points; i++) + { + outfile << distr(gen); + if (i != num_points - 1) + outfile << '\n'; + } } - if (i < num_points - 1) { - outfile << '\n'; + if (outfile.is_open()) + { + outfile.close(); } - } - } else if (distribution_type == "one_per_point") { - std::random_device rd; // obtain a random number from hardware - std::mt19937 gen(rd()); // seed the generator - std::uniform_int_distribution<> distr(0, num_labels); // define the range - - for (size_t i = 0; i < num_points; i++) { - outfile << distr(gen); - if (i != num_points - 1) - outfile << '\n'; - } + + std::cout << "Labels written to " << output_file << '\n'; } - if (outfile.is_open()) { - outfile.close(); + catch (const std::exception &ex) + { + std::cerr << "Label generation failed: " << ex.what() << '\n'; + return -1; } - std::cout << "Labels written to " << output_file << '\n'; - } catch (const std::exception &ex) { - std::cerr << "Label generation failed: " << ex.what() << '\n'; - return -1; - } - - return 0; + return 0; } \ No newline at end of file diff --git a/apps/utils/int8_to_float.cpp b/apps/utils/int8_to_float.cpp index d8b1e6f5a..8277b9a09 100644 --- a/apps/utils/int8_to_float.cpp +++ b/apps/utils/int8_to_float.cpp @@ -4,18 +4,20 @@ #include "utils.h" #include -int main(int argc, char **argv) { - if (argc != 3) { - std::cout << argv[0] << " input_int8_bin output_float_bin" << std::endl; - exit(-1); - } +int main(int argc, char **argv) +{ + if (argc != 3) + { + std::cout << argv[0] << " input_int8_bin output_float_bin" << std::endl; + exit(-1); + } - int8_t *input; - size_t npts, nd; - diskann::load_bin(argv[1], input, npts, nd); - float *output = new float[npts * nd]; - diskann::convert_types(input, output, npts, nd); - diskann::save_bin(argv[2], output, npts, nd); - delete[] output; - delete[] input; + int8_t *input; + size_t npts, nd; + diskann::load_bin(argv[1], input, npts, nd); + float *output = new float[npts * nd]; + diskann::convert_types(input, output, npts, nd); + diskann::save_bin(argv[2], output, npts, nd); + delete[] output; + delete[] input; } diff --git a/apps/utils/int8_to_float_scale.cpp b/apps/utils/int8_to_float_scale.cpp index ff1b62aa6..757e79be1 100644 --- a/apps/utils/int8_to_float_scale.cpp +++ b/apps/utils/int8_to_float_scale.cpp @@ -4,59 +4,60 @@ #include "utils.h" #include -void block_convert(std::ofstream &writer, float *write_buf, - std::ifstream &reader, int8_t *read_buf, size_t npts, - size_t ndims, float bias, float scale) { - 
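// --- Editor's sketch (illustrative only, not part of this patch) -----------
// In generate_synthetic_labels above, ZipfDistribution::createDistributionMap
// gives the label of rank r a budget of roughly ceil(0.7 * num_points) / r
// assignments. A minimal sketch of that frequency table; zipf_label_counts is
// a hypothetical helper and uses floating-point division where the original
// divides 32-bit integers, so counts may differ by one.
#include <cmath>
#include <cstdint>
#include <unordered_map>

std::unordered_map<uint32_t, uint32_t> zipf_label_counts(uint64_t num_points, uint32_t num_labels,
                                                         double distribution_factor = 0.7)
{
    std::unordered_map<uint32_t, uint32_t> counts;
    uint32_t primary = (uint32_t)std::ceil(num_points * distribution_factor);
    for (uint32_t rank = 1; rank <= num_labels; rank++)
        counts[rank] = (uint32_t)std::ceil((double)primary / rank);
    return counts;
}
// ----------------------------------------------------------------------------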
reader.read((char *)read_buf, npts * ndims * sizeof(int8_t)); - - for (size_t i = 0; i < npts; i++) { - for (size_t d = 0; d < ndims; d++) { - write_buf[d + i * ndims] = - (((float)read_buf[d + i * ndims] - bias) * scale); +void block_convert(std::ofstream &writer, float *write_buf, std::ifstream &reader, int8_t *read_buf, size_t npts, + size_t ndims, float bias, float scale) +{ + reader.read((char *)read_buf, npts * ndims * sizeof(int8_t)); + + for (size_t i = 0; i < npts; i++) + { + for (size_t d = 0; d < ndims; d++) + { + write_buf[d + i * ndims] = (((float)read_buf[d + i * ndims] - bias) * scale); + } } - } - writer.write((char *)write_buf, npts * ndims * sizeof(float)); + writer.write((char *)write_buf, npts * ndims * sizeof(float)); } -int main(int argc, char **argv) { - if (argc != 5) { - std::cout << "Usage: " << argv[0] - << " input-int8.bin output-float.bin bias scale" << std::endl; - exit(-1); - } - - std::ifstream reader(argv[1], std::ios::binary); - uint32_t npts_u32; - uint32_t ndims_u32; - reader.read((char *)&npts_u32, sizeof(uint32_t)); - reader.read((char *)&ndims_u32, sizeof(uint32_t)); - size_t npts = npts_u32; - size_t ndims = ndims_u32; - std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims - << std::endl; - - size_t blk_size = 131072; - size_t nblks = ROUND_UP(npts, blk_size) / blk_size; - - std::ofstream writer(argv[2], std::ios::binary); - auto read_buf = new int8_t[blk_size * ndims]; - auto write_buf = new float[blk_size * ndims]; - float bias = (float)atof(argv[3]); - float scale = (float)atof(argv[4]); - - writer.write((char *)(&npts_u32), sizeof(uint32_t)); - writer.write((char *)(&ndims_u32), sizeof(uint32_t)); - - for (size_t i = 0; i < nblks; i++) { - size_t cblk_size = std::min(npts - i * blk_size, blk_size); - block_convert(writer, write_buf, reader, read_buf, cblk_size, ndims, bias, - scale); - std::cout << "Block #" << i << " written" << std::endl; - } - - delete[] read_buf; - delete[] write_buf; - - writer.close(); - reader.close(); +int main(int argc, char **argv) +{ + if (argc != 5) + { + std::cout << "Usage: " << argv[0] << " input-int8.bin output-float.bin bias scale" << std::endl; + exit(-1); + } + + std::ifstream reader(argv[1], std::ios::binary); + uint32_t npts_u32; + uint32_t ndims_u32; + reader.read((char *)&npts_u32, sizeof(uint32_t)); + reader.read((char *)&ndims_u32, sizeof(uint32_t)); + size_t npts = npts_u32; + size_t ndims = ndims_u32; + std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl; + + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; + + std::ofstream writer(argv[2], std::ios::binary); + auto read_buf = new int8_t[blk_size * ndims]; + auto write_buf = new float[blk_size * ndims]; + float bias = (float)atof(argv[3]); + float scale = (float)atof(argv[4]); + + writer.write((char *)(&npts_u32), sizeof(uint32_t)); + writer.write((char *)(&ndims_u32), sizeof(uint32_t)); + + for (size_t i = 0; i < nblks; i++) + { + size_t cblk_size = std::min(npts - i * blk_size, blk_size); + block_convert(writer, write_buf, reader, read_buf, cblk_size, ndims, bias, scale); + std::cout << "Block #" << i << " written" << std::endl; + } + + delete[] read_buf; + delete[] write_buf; + + writer.close(); + reader.close(); } diff --git a/apps/utils/ivecs_to_bin.cpp b/apps/utils/ivecs_to_bin.cpp index f439d1d73..854c06839 100644 --- a/apps/utils/ivecs_to_bin.cpp +++ b/apps/utils/ivecs_to_bin.cpp @@ -4,54 +4,55 @@ #include "utils.h" #include -void block_convert(std::ifstream &reader, 
std::ofstream &writer, - uint32_t *read_buf, uint32_t *write_buf, size_t npts, - size_t ndims) { - reader.read((char *)read_buf, - npts * (ndims * sizeof(uint32_t) + sizeof(uint32_t))); - for (size_t i = 0; i < npts; i++) { - memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, - ndims * sizeof(uint32_t)); - } - writer.write((char *)write_buf, npts * ndims * sizeof(uint32_t)); +void block_convert(std::ifstream &reader, std::ofstream &writer, uint32_t *read_buf, uint32_t *write_buf, size_t npts, + size_t ndims) +{ + reader.read((char *)read_buf, npts * (ndims * sizeof(uint32_t) + sizeof(uint32_t))); + for (size_t i = 0; i < npts; i++) + { + memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, ndims * sizeof(uint32_t)); + } + writer.write((char *)write_buf, npts * ndims * sizeof(uint32_t)); } -int main(int argc, char **argv) { - if (argc != 3) { - std::cout << argv[0] << " input_ivecs output_bin" << std::endl; - exit(-1); - } - std::ifstream reader(argv[1], std::ios::binary | std::ios::ate); - size_t fsize = reader.tellg(); - reader.seekg(0, std::ios::beg); +int main(int argc, char **argv) +{ + if (argc != 3) + { + std::cout << argv[0] << " input_ivecs output_bin" << std::endl; + exit(-1); + } + std::ifstream reader(argv[1], std::ios::binary | std::ios::ate); + size_t fsize = reader.tellg(); + reader.seekg(0, std::ios::beg); - uint32_t ndims_u32; - reader.read((char *)&ndims_u32, sizeof(uint32_t)); - reader.seekg(0, std::ios::beg); - size_t ndims = (size_t)ndims_u32; - size_t npts = fsize / ((ndims + 1) * sizeof(uint32_t)); - std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims - << std::endl; + uint32_t ndims_u32; + reader.read((char *)&ndims_u32, sizeof(uint32_t)); + reader.seekg(0, std::ios::beg); + size_t ndims = (size_t)ndims_u32; + size_t npts = fsize / ((ndims + 1) * sizeof(uint32_t)); + std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl; - size_t blk_size = 131072; - size_t nblks = ROUND_UP(npts, blk_size) / blk_size; - std::cout << "# blks: " << nblks << std::endl; - std::ofstream writer(argv[2], std::ios::binary); - int npts_s32 = (int)npts; - int ndims_s32 = (int)ndims; - writer.write((char *)&npts_s32, sizeof(int)); - writer.write((char *)&ndims_s32, sizeof(int)); - uint32_t *read_buf = new uint32_t[npts * (ndims + 1)]; - uint32_t *write_buf = new uint32_t[npts * ndims]; - for (size_t i = 0; i < nblks; i++) { - size_t cblk_size = std::min(npts - i * blk_size, blk_size); - block_convert(reader, writer, read_buf, write_buf, cblk_size, ndims); - std::cout << "Block #" << i << " written" << std::endl; - } + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; + std::cout << "# blks: " << nblks << std::endl; + std::ofstream writer(argv[2], std::ios::binary); + int npts_s32 = (int)npts; + int ndims_s32 = (int)ndims; + writer.write((char *)&npts_s32, sizeof(int)); + writer.write((char *)&ndims_s32, sizeof(int)); + uint32_t *read_buf = new uint32_t[npts * (ndims + 1)]; + uint32_t *write_buf = new uint32_t[npts * ndims]; + for (size_t i = 0; i < nblks; i++) + { + size_t cblk_size = std::min(npts - i * blk_size, blk_size); + block_convert(reader, writer, read_buf, write_buf, cblk_size, ndims); + std::cout << "Block #" << i << " written" << std::endl; + } - delete[] read_buf; - delete[] write_buf; + delete[] read_buf; + delete[] write_buf; - reader.close(); - writer.close(); + reader.close(); + writer.close(); } diff --git a/apps/utils/merge_shards.cpp b/apps/utils/merge_shards.cpp index 
cf5c6a00b..be64e6ff9 100644 --- a/apps/utils/merge_shards.cpp +++ b/apps/utils/merge_shards.cpp @@ -14,28 +14,29 @@ #include "disk_utils.h" #include "utils.h" -int main(int argc, char **argv) { - if (argc != 9) { - std::cout << argv[0] - << " vamana_index_prefix[1] vamana_index_suffix[2] " - "idmaps_prefix[3] " - "idmaps_suffix[4] n_shards[5] max_degree[6] " - "output_vamana_path[7] " - "output_medoids_path[8]" - << std::endl; - exit(-1); - } +int main(int argc, char **argv) +{ + if (argc != 9) + { + std::cout << argv[0] + << " vamana_index_prefix[1] vamana_index_suffix[2] " + "idmaps_prefix[3] " + "idmaps_suffix[4] n_shards[5] max_degree[6] " + "output_vamana_path[7] " + "output_medoids_path[8]" + << std::endl; + exit(-1); + } - std::string vamana_prefix(argv[1]); - std::string vamana_suffix(argv[2]); - std::string idmaps_prefix(argv[3]); - std::string idmaps_suffix(argv[4]); - uint64_t nshards = (uint64_t)std::atoi(argv[5]); - uint32_t max_degree = (uint64_t)std::atoi(argv[6]); - std::string output_index(argv[7]); - std::string output_medoids(argv[8]); + std::string vamana_prefix(argv[1]); + std::string vamana_suffix(argv[2]); + std::string idmaps_prefix(argv[3]); + std::string idmaps_suffix(argv[4]); + uint64_t nshards = (uint64_t)std::atoi(argv[5]); + uint32_t max_degree = (uint64_t)std::atoi(argv[6]); + std::string output_index(argv[7]); + std::string output_medoids(argv[8]); - return diskann::merge_shards(vamana_prefix, vamana_suffix, idmaps_prefix, - idmaps_suffix, nshards, max_degree, output_index, - output_medoids); + return diskann::merge_shards(vamana_prefix, vamana_suffix, idmaps_prefix, idmaps_suffix, nshards, max_degree, + output_index, output_medoids); } diff --git a/apps/utils/partition_data.cpp b/apps/utils/partition_data.cpp index 72eb7af90..42c22d231 100644 --- a/apps/utils/partition_data.cpp +++ b/apps/utils/partition_data.cpp @@ -8,33 +8,32 @@ // DEPRECATED: NEED TO REPROGRAM -int main(int argc, char **argv) { - if (argc != 7) { - std::cout << "Usage:\n" - << argv[0] - << " datatype " - " " - " " - << std::endl; - exit(-1); - } +int main(int argc, char **argv) +{ + if (argc != 7) + { + std::cout << "Usage:\n" + << argv[0] + << " datatype " + " " + " " + << std::endl; + exit(-1); + } - const std::string data_path(argv[2]); - const std::string prefix_path(argv[3]); - const float sampling_rate = (float)atof(argv[4]); - const size_t num_partitions = (size_t)std::atoi(argv[5]); - const size_t max_reps = 15; - const size_t k_index = (size_t)std::atoi(argv[6]); + const std::string data_path(argv[2]); + const std::string prefix_path(argv[3]); + const float sampling_rate = (float)atof(argv[4]); + const size_t num_partitions = (size_t)std::atoi(argv[5]); + const size_t max_reps = 15; + const size_t k_index = (size_t)std::atoi(argv[6]); - if (std::string(argv[1]) == std::string("float")) - partition(data_path, sampling_rate, num_partitions, max_reps, - prefix_path, k_index); - else if (std::string(argv[1]) == std::string("int8")) - partition(data_path, sampling_rate, num_partitions, max_reps, - prefix_path, k_index); - else if (std::string(argv[1]) == std::string("uint8")) - partition(data_path, sampling_rate, num_partitions, max_reps, - prefix_path, k_index); - else - std::cout << "unsupported data format. 
use float/int8/uint8" << std::endl; + if (std::string(argv[1]) == std::string("float")) + partition(data_path, sampling_rate, num_partitions, max_reps, prefix_path, k_index); + else if (std::string(argv[1]) == std::string("int8")) + partition(data_path, sampling_rate, num_partitions, max_reps, prefix_path, k_index); + else if (std::string(argv[1]) == std::string("uint8")) + partition(data_path, sampling_rate, num_partitions, max_reps, prefix_path, k_index); + else + std::cout << "unsupported data format. use float/int8/uint8" << std::endl; } diff --git a/apps/utils/partition_with_ram_budget.cpp b/apps/utils/partition_with_ram_budget.cpp index 9c5535def..c5b6ed596 100644 --- a/apps/utils/partition_with_ram_budget.cpp +++ b/apps/utils/partition_with_ram_budget.cpp @@ -8,33 +8,32 @@ // DEPRECATED: NEED TO REPROGRAM -int main(int argc, char **argv) { - if (argc != 8) { - std::cout << "Usage:\n" - << argv[0] - << " datatype " - " " - " " - << std::endl; - exit(-1); - } +int main(int argc, char **argv) +{ + if (argc != 8) + { + std::cout << "Usage:\n" + << argv[0] + << " datatype " + " " + " " + << std::endl; + exit(-1); + } - const std::string data_path(argv[2]); - const std::string prefix_path(argv[3]); - const float sampling_rate = (float)atof(argv[4]); - const double ram_budget = (double)std::atof(argv[5]); - const size_t graph_degree = (size_t)std::atoi(argv[6]); - const size_t k_index = (size_t)std::atoi(argv[7]); + const std::string data_path(argv[2]); + const std::string prefix_path(argv[3]); + const float sampling_rate = (float)atof(argv[4]); + const double ram_budget = (double)std::atof(argv[5]); + const size_t graph_degree = (size_t)std::atoi(argv[6]); + const size_t k_index = (size_t)std::atoi(argv[7]); - if (std::string(argv[1]) == std::string("float")) - partition_with_ram_budget(data_path, sampling_rate, ram_budget, - graph_degree, prefix_path, k_index); - else if (std::string(argv[1]) == std::string("int8")) - partition_with_ram_budget(data_path, sampling_rate, ram_budget, - graph_degree, prefix_path, k_index); - else if (std::string(argv[1]) == std::string("uint8")) - partition_with_ram_budget(data_path, sampling_rate, ram_budget, - graph_degree, prefix_path, k_index); - else - std::cout << "unsupported data format. use float/int8/uint8" << std::endl; + if (std::string(argv[1]) == std::string("float")) + partition_with_ram_budget(data_path, sampling_rate, ram_budget, graph_degree, prefix_path, k_index); + else if (std::string(argv[1]) == std::string("int8")) + partition_with_ram_budget(data_path, sampling_rate, ram_budget, graph_degree, prefix_path, k_index); + else if (std::string(argv[1]) == std::string("uint8")) + partition_with_ram_budget(data_path, sampling_rate, ram_budget, graph_degree, prefix_path, k_index); + else + std::cout << "unsupported data format. 
use float/int8/uint8" << std::endl; } diff --git a/apps/utils/rand_data_gen.cpp b/apps/utils/rand_data_gen.cpp index 25577d242..799aa0f33 100644 --- a/apps/utils/rand_data_gen.cpp +++ b/apps/utils/rand_data_gen.cpp @@ -11,214 +11,232 @@ namespace po = boost::program_options; -int block_write_float(std::ofstream &writer, size_t ndims, size_t npts, - bool normalization, float norm, float rand_scale) { - auto vec = new float[ndims]; - - std::random_device rd{}; - std::mt19937 gen{rd()}; - std::normal_distribution<> normal_rand{0, 1}; - std::uniform_real_distribution<> unif_dis(1.0, rand_scale); - - for (size_t i = 0; i < npts; i++) { - float sum = 0; - float scale = 1.0f; - if (rand_scale > 1.0f) - scale = (float)unif_dis(gen); - for (size_t d = 0; d < ndims; ++d) - vec[d] = scale * (float)normal_rand(gen); - if (normalization) { - for (size_t d = 0; d < ndims; ++d) - sum += vec[d] * vec[d]; - for (size_t d = 0; d < ndims; ++d) - vec[d] = vec[d] * norm / std::sqrt(sum); +int block_write_float(std::ofstream &writer, size_t ndims, size_t npts, bool normalization, float norm, + float rand_scale) +{ + auto vec = new float[ndims]; + + std::random_device rd{}; + std::mt19937 gen{rd()}; + std::normal_distribution<> normal_rand{0, 1}; + std::uniform_real_distribution<> unif_dis(1.0, rand_scale); + + for (size_t i = 0; i < npts; i++) + { + float sum = 0; + float scale = 1.0f; + if (rand_scale > 1.0f) + scale = (float)unif_dis(gen); + for (size_t d = 0; d < ndims; ++d) + vec[d] = scale * (float)normal_rand(gen); + if (normalization) + { + for (size_t d = 0; d < ndims; ++d) + sum += vec[d] * vec[d]; + for (size_t d = 0; d < ndims; ++d) + vec[d] = vec[d] * norm / std::sqrt(sum); + } + + writer.write((char *)vec, ndims * sizeof(float)); } - writer.write((char *)vec, ndims * sizeof(float)); - } - - delete[] vec; - return 0; + delete[] vec; + return 0; } -int block_write_int8(std::ofstream &writer, size_t ndims, size_t npts, - float norm) { - auto vec = new float[ndims]; - auto vec_T = new int8_t[ndims]; - - std::random_device rd{}; - std::mt19937 gen{rd()}; - std::normal_distribution<> normal_rand{0, 1}; - - for (size_t i = 0; i < npts; i++) { - float sum = 0; - for (size_t d = 0; d < ndims; ++d) - vec[d] = (float)normal_rand(gen); - for (size_t d = 0; d < ndims; ++d) - sum += vec[d] * vec[d]; - for (size_t d = 0; d < ndims; ++d) - vec[d] = vec[d] * norm / std::sqrt(sum); - - for (size_t d = 0; d < ndims; ++d) { - vec_T[d] = (int8_t)std::round(vec[d]); +int block_write_int8(std::ofstream &writer, size_t ndims, size_t npts, float norm) +{ + auto vec = new float[ndims]; + auto vec_T = new int8_t[ndims]; + + std::random_device rd{}; + std::mt19937 gen{rd()}; + std::normal_distribution<> normal_rand{0, 1}; + + for (size_t i = 0; i < npts; i++) + { + float sum = 0; + for (size_t d = 0; d < ndims; ++d) + vec[d] = (float)normal_rand(gen); + for (size_t d = 0; d < ndims; ++d) + sum += vec[d] * vec[d]; + for (size_t d = 0; d < ndims; ++d) + vec[d] = vec[d] * norm / std::sqrt(sum); + + for (size_t d = 0; d < ndims; ++d) + { + vec_T[d] = (int8_t)std::round(vec[d]); + } + + writer.write((char *)vec_T, ndims * sizeof(int8_t)); } - writer.write((char *)vec_T, ndims * sizeof(int8_t)); - } + delete[] vec; + delete[] vec_T; + return 0; +} - delete[] vec; - delete[] vec_T; - return 0; +int block_write_uint8(std::ofstream &writer, size_t ndims, size_t npts, float norm) +{ + auto vec = new float[ndims]; + auto vec_T = new int8_t[ndims]; + + std::random_device rd{}; + std::mt19937 gen{rd()}; + std::normal_distribution<> 
normal_rand{0, 1}; + + for (size_t i = 0; i < npts; i++) + { + float sum = 0; + for (size_t d = 0; d < ndims; ++d) + vec[d] = (float)normal_rand(gen); + for (size_t d = 0; d < ndims; ++d) + sum += vec[d] * vec[d]; + for (size_t d = 0; d < ndims; ++d) + vec[d] = vec[d] * norm / std::sqrt(sum); + + for (size_t d = 0; d < ndims; ++d) + { + vec_T[d] = 128 + (int8_t)std::round(vec[d]); + } + + writer.write((char *)vec_T, ndims * sizeof(uint8_t)); + } + + delete[] vec; + delete[] vec_T; + return 0; } -int block_write_uint8(std::ofstream &writer, size_t ndims, size_t npts, - float norm) { - auto vec = new float[ndims]; - auto vec_T = new int8_t[ndims]; - - std::random_device rd{}; - std::mt19937 gen{rd()}; - std::normal_distribution<> normal_rand{0, 1}; - - for (size_t i = 0; i < npts; i++) { - float sum = 0; - for (size_t d = 0; d < ndims; ++d) - vec[d] = (float)normal_rand(gen); - for (size_t d = 0; d < ndims; ++d) - sum += vec[d] * vec[d]; - for (size_t d = 0; d < ndims; ++d) - vec[d] = vec[d] * norm / std::sqrt(sum); - - for (size_t d = 0; d < ndims; ++d) { - vec_T[d] = 128 + (int8_t)std::round(vec[d]); +int main(int argc, char **argv) +{ + std::string data_type, output_file; + size_t ndims, npts; + float norm, rand_scaling; + bool normalization = false; + try + { + po::options_description desc{"Arguments"}; + + desc.add_options()("help,h", "Print information on arguments"); + + desc.add_options()("data_type", po::value(&data_type)->required(), "data type "); + desc.add_options()("output_file", po::value(&output_file)->required(), + "File name for saving the random vectors"); + desc.add_options()("ndims,D", po::value(&ndims)->required(), "Dimensoinality of the vector"); + desc.add_options()("npts,N", po::value(&npts)->required(), "Number of vectors"); + desc.add_options()("norm", po::value(&norm)->default_value(-1.0f), + "Norm of the vectors (if not specified, vectors are not normalized)"); + desc.add_options()("rand_scaling", po::value(&rand_scaling)->default_value(1.0f), + "Each vector will be scaled (if not explicitly normalized) by a factor " + "randomly chosen from " + "[1, rand_scale]. Only applicable for floating point data"); + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) + { + std::cout << desc; + return 0; + } + po::notify(vm); + } + catch (const std::exception &ex) + { + std::cerr << ex.what() << '\n'; + return -1; } - writer.write((char *)vec_T, ndims * sizeof(uint8_t)); - } + if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8")) + { + std::cout << "Unsupported type. float, int8 and uint8 types are supported." 
<< std::endl; + return -1; + } - delete[] vec; - delete[] vec_T; - return 0; -} + if (norm > 0.0) + { + normalization = true; + } -int main(int argc, char **argv) { - std::string data_type, output_file; - size_t ndims, npts; - float norm, rand_scaling; - bool normalization = false; - try { - po::options_description desc{"Arguments"}; - - desc.add_options()("help,h", "Print information on arguments"); - - desc.add_options()("data_type", - po::value(&data_type)->required(), - "data type "); - desc.add_options()("output_file", - po::value(&output_file)->required(), - "File name for saving the random vectors"); - desc.add_options()("ndims,D", po::value(&ndims)->required(), - "Dimensoinality of the vector"); - desc.add_options()("npts,N", po::value(&npts)->required(), - "Number of vectors"); - desc.add_options()( - "norm", po::value(&norm)->default_value(-1.0f), - "Norm of the vectors (if not specified, vectors are not normalized)"); - desc.add_options()( - "rand_scaling", po::value(&rand_scaling)->default_value(1.0f), - "Each vector will be scaled (if not explicitly normalized) by a factor " - "randomly chosen from " - "[1, rand_scale]. Only applicable for floating point data"); - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) { - std::cout << desc; - return 0; + if (rand_scaling < 1.0) + { + std::cout << "We will only scale the vector norms randomly in [1, value], " + "so value must be >= 1." + << std::endl; + return -1; } - po::notify(vm); - } catch (const std::exception &ex) { - std::cerr << ex.what() << '\n'; - return -1; - } - - if (data_type != std::string("float") && data_type != std::string("int8") && - data_type != std::string("uint8")) { - std::cout << "Unsupported type. float, int8 and uint8 types are supported." - << std::endl; - return -1; - } - - if (norm > 0.0) { - normalization = true; - } - - if (rand_scaling < 1.0) { - std::cout << "We will only scale the vector norms randomly in [1, value], " - "so value must be >= 1." - << std::endl; - return -1; - } - - if ((rand_scaling > 1.0) && (normalization == true)) { - std::cout << "Data cannot be normalized and randomly scaled at same time. " - "Use one or the other." - << std::endl; - return -1; - } - - if (data_type == std::string("int8") || data_type == std::string("uint8")) { - if (norm > 127) { - std::cerr << "Error: for int8/uint8 datatypes, L2 norm can not be " - "greater " - "than 127" - << std::endl; - return -1; + + if ((rand_scaling > 1.0) && (normalization == true)) + { + std::cout << "Data cannot be normalized and randomly scaled at same time. " + "Use one or the other." + << std::endl; + return -1; } - if (rand_scaling > 1.0) { - std::cout << "Data scaling only supported for floating point data." - << std::endl; - return -1; + + if (data_type == std::string("int8") || data_type == std::string("uint8")) + { + if (norm > 127) + { + std::cerr << "Error: for int8/uint8 datatypes, L2 norm can not be " + "greater " + "than 127" + << std::endl; + return -1; + } + if (rand_scaling > 1.0) + { + std::cout << "Data scaling only supported for floating point data." 
<< std::endl; + return -1; + } } - } - - try { - std::ofstream writer; - writer.exceptions(std::ofstream::failbit | std::ofstream::badbit); - writer.open(output_file, std::ios::binary); - auto npts_u32 = (uint32_t)npts; - auto ndims_u32 = (uint32_t)ndims; - writer.write((char *)&npts_u32, sizeof(uint32_t)); - writer.write((char *)&ndims_u32, sizeof(uint32_t)); - - size_t blk_size = 131072; - size_t nblks = ROUND_UP(npts, blk_size) / blk_size; - std::cout << "# blks: " << nblks << std::endl; - - int ret = 0; - for (size_t i = 0; i < nblks; i++) { - size_t cblk_size = std::min(npts - i * blk_size, blk_size); - if (data_type == std::string("float")) { - ret = block_write_float(writer, ndims, cblk_size, normalization, norm, - rand_scaling); - } else if (data_type == std::string("int8")) { - ret = block_write_int8(writer, ndims, cblk_size, norm); - } else if (data_type == std::string("uint8")) { - ret = block_write_uint8(writer, ndims, cblk_size, norm); - } - if (ret == 0) - std::cout << "Block #" << i << " written" << std::endl; - else { + + try + { + std::ofstream writer; + writer.exceptions(std::ofstream::failbit | std::ofstream::badbit); + writer.open(output_file, std::ios::binary); + auto npts_u32 = (uint32_t)npts; + auto ndims_u32 = (uint32_t)ndims; + writer.write((char *)&npts_u32, sizeof(uint32_t)); + writer.write((char *)&ndims_u32, sizeof(uint32_t)); + + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; + std::cout << "# blks: " << nblks << std::endl; + + int ret = 0; + for (size_t i = 0; i < nblks; i++) + { + size_t cblk_size = std::min(npts - i * blk_size, blk_size); + if (data_type == std::string("float")) + { + ret = block_write_float(writer, ndims, cblk_size, normalization, norm, rand_scaling); + } + else if (data_type == std::string("int8")) + { + ret = block_write_int8(writer, ndims, cblk_size, norm); + } + else if (data_type == std::string("uint8")) + { + ret = block_write_uint8(writer, ndims, cblk_size, norm); + } + if (ret == 0) + std::cout << "Block #" << i << " written" << std::endl; + else + { + writer.close(); + std::cout << "failed to write" << std::endl; + return -1; + } + } writer.close(); - std::cout << "failed to write" << std::endl; + } + catch (const std::exception &e) + { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Index build failed." << std::endl; return -1; - } } - writer.close(); - } catch (const std::exception &e) { - std::cout << std::string(e.what()) << std::endl; - diskann::cerr << "Index build failed." 
<< std::endl; - return -1; - } - - return 0; + + return 0; } diff --git a/apps/utils/simulate_aggregate_recall.cpp b/apps/utils/simulate_aggregate_recall.cpp index b934c2bea..30cb24f13 100644 --- a/apps/utils/simulate_aggregate_recall.cpp +++ b/apps/utils/simulate_aggregate_recall.cpp @@ -6,73 +6,80 @@ #include #include -inline float aggregate_recall(const uint32_t k_aggr, const uint32_t k, - const uint32_t npart, uint32_t *count, - const std::vector &recalls) { - float found = 0; - for (uint32_t i = 0; i < npart; ++i) { - size_t max_found = std::min(count[i], k); - found += recalls[max_found - 1] * max_found; - } - return found / (float)k_aggr; +inline float aggregate_recall(const uint32_t k_aggr, const uint32_t k, const uint32_t npart, uint32_t *count, + const std::vector &recalls) +{ + float found = 0; + for (uint32_t i = 0; i < npart; ++i) + { + size_t max_found = std::min(count[i], k); + found += recalls[max_found - 1] * max_found; + } + return found / (float)k_aggr; } -void simulate(const uint32_t k_aggr, const uint32_t k, const uint32_t npart, - const uint32_t nsim, const std::vector &recalls) { - std::random_device r; - std::default_random_engine randeng(r()); - std::uniform_int_distribution uniform_dist(0, npart - 1); +void simulate(const uint32_t k_aggr, const uint32_t k, const uint32_t npart, const uint32_t nsim, + const std::vector &recalls) +{ + std::random_device r; + std::default_random_engine randeng(r()); + std::uniform_int_distribution uniform_dist(0, npart - 1); - uint32_t *count = new uint32_t[npart]; - double aggr_recall = 0; + uint32_t *count = new uint32_t[npart]; + double aggr_recall = 0; - for (uint32_t i = 0; i < nsim; ++i) { - for (uint32_t p = 0; p < npart; ++p) { - count[p] = 0; - } - for (uint32_t t = 0; t < k_aggr; ++t) { - count[uniform_dist(randeng)]++; + for (uint32_t i = 0; i < nsim; ++i) + { + for (uint32_t p = 0; p < npart; ++p) + { + count[p] = 0; + } + for (uint32_t t = 0; t < k_aggr; ++t) + { + count[uniform_dist(randeng)]++; + } + aggr_recall += aggregate_recall(k_aggr, k, npart, count, recalls); } - aggr_recall += aggregate_recall(k_aggr, k, npart, count, recalls); - } - std::cout << "Aggregate recall is " << aggr_recall / (double)nsim - << std::endl; - delete[] count; + std::cout << "Aggregate recall is " << aggr_recall / (double)nsim << std::endl; + delete[] count; } -int main(int argc, char **argv) { - if (argc < 6) { - std::cout << argv[0] - << " k_aggregate k_out npart nsim recall@1 recall@2 ... recall@k" - << std::endl; - exit(-1); - } +int main(int argc, char **argv) +{ + if (argc < 6) + { + std::cout << argv[0] << " k_aggregate k_out npart nsim recall@1 recall@2 ... recall@k" << std::endl; + exit(-1); + } - const uint32_t k_aggr = atoi(argv[1]); - const uint32_t k = atoi(argv[2]); - const uint32_t npart = atoi(argv[3]); - const uint32_t nsim = atoi(argv[4]); + const uint32_t k_aggr = atoi(argv[1]); + const uint32_t k = atoi(argv[2]); + const uint32_t npart = atoi(argv[3]); + const uint32_t nsim = atoi(argv[4]); - std::vector recalls; - for (int ctr = 5; ctr < argc; ctr++) { - recalls.push_back((float)atof(argv[ctr])); - } + std::vector recalls; + for (int ctr = 5; ctr < argc; ctr++) + { + recalls.push_back((float)atof(argv[ctr])); + } - if (recalls.size() != k) { - std::cerr << "Please input k numbers for recall@1, recall@2 .. 
recall@k" - << std::endl; - } - if (k_aggr > npart * k) { - std::cerr << "k_aggr must be <= k * npart" << std::endl; - exit(-1); - } - if (nsim <= npart * k_aggr) { - std::cerr << "Choose nsim > npart*k_aggr" << std::endl; - exit(-1); - } + if (recalls.size() != k) + { + std::cerr << "Please input k numbers for recall@1, recall@2 .. recall@k" << std::endl; + } + if (k_aggr > npart * k) + { + std::cerr << "k_aggr must be <= k * npart" << std::endl; + exit(-1); + } + if (nsim <= npart * k_aggr) + { + std::cerr << "Choose nsim > npart*k_aggr" << std::endl; + exit(-1); + } - simulate(k_aggr, k, npart, nsim, recalls); + simulate(k_aggr, k, npart, nsim, recalls); - return 0; + return 0; } diff --git a/apps/utils/stats_label_data.cpp b/apps/utils/stats_label_data.cpp index 4de42f7a0..1fad04b61 100644 --- a/apps/utils/stats_label_data.cpp +++ b/apps/utils/stats_label_data.cpp @@ -28,123 +28,120 @@ #endif namespace po = boost::program_options; -void stats_analysis(const std::string labels_file, std::string univeral_label, - uint32_t density = 10) { - std::string token, line; - std::ifstream labels_stream(labels_file); - std::unordered_map label_counts; - std::string label_with_max_points; - uint32_t max_points = 0; - long long sum = 0; - long long point_cnt = 0; - float avg_labels_per_pt, mean_label_size; +void stats_analysis(const std::string labels_file, std::string univeral_label, uint32_t density = 10) +{ + std::string token, line; + std::ifstream labels_stream(labels_file); + std::unordered_map label_counts; + std::string label_with_max_points; + uint32_t max_points = 0; + long long sum = 0; + long long point_cnt = 0; + float avg_labels_per_pt, mean_label_size; - std::vector labels_per_point; - uint32_t dense_pts = 0; - if (labels_stream.is_open()) { - while (getline(labels_stream, line)) { - point_cnt++; - std::stringstream iss(line); - uint32_t lbl_cnt = 0; - while (getline(iss, token, ',')) { - lbl_cnt++; - token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - if (label_counts.find(token) == label_counts.end()) - label_counts[token] = 0; - label_counts[token]++; - } - if (lbl_cnt >= density) { - dense_pts++; - } - labels_per_point.emplace_back(lbl_cnt); + std::vector labels_per_point; + uint32_t dense_pts = 0; + if (labels_stream.is_open()) + { + while (getline(labels_stream, line)) + { + point_cnt++; + std::stringstream iss(line); + uint32_t lbl_cnt = 0; + while (getline(iss, token, ',')) + { + lbl_cnt++; + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + if (label_counts.find(token) == label_counts.end()) + label_counts[token] = 0; + label_counts[token]++; + } + if (lbl_cnt >= density) + { + dense_pts++; + } + labels_per_point.emplace_back(lbl_cnt); + } } - } - std::cout << "fraction of dense points with >= " << density - << " labels = " << (float)dense_pts / (float)labels_per_point.size() - << std::endl; - std::sort(labels_per_point.begin(), labels_per_point.end()); + std::cout << "fraction of dense points with >= " << density + << " labels = " << (float)dense_pts / (float)labels_per_point.size() << std::endl; + std::sort(labels_per_point.begin(), labels_per_point.end()); - std::vector> label_count_vec; + std::vector> label_count_vec; - for (auto it = label_counts.begin(); it != label_counts.end(); it++) { - auto &lbl = *it; - label_count_vec.emplace_back(std::make_pair(lbl.first, lbl.second)); 
- if (lbl.second > max_points) { - max_points = lbl.second; - label_with_max_points = lbl.first; + for (auto it = label_counts.begin(); it != label_counts.end(); it++) + { + auto &lbl = *it; + label_count_vec.emplace_back(std::make_pair(lbl.first, lbl.second)); + if (lbl.second > max_points) + { + max_points = lbl.second; + label_with_max_points = lbl.first; + } + sum += lbl.second; } - sum += lbl.second; - } - sort(label_count_vec.begin(), label_count_vec.end(), - [](const std::pair &lhs, - const std::pair &rhs) { - return lhs.second < rhs.second; - }); + sort(label_count_vec.begin(), label_count_vec.end(), + [](const std::pair &lhs, const std::pair &rhs) { + return lhs.second < rhs.second; + }); - for (float p = 0; p < 1; p += 0.05) { - std::cout << "Percentile " << (100 * p) << "\t" - << label_count_vec[(size_t)(p * label_count_vec.size())].first - << " with count=" - << label_count_vec[(size_t)(p * label_count_vec.size())].second - << std::endl; - } + for (float p = 0; p < 1; p += 0.05) + { + std::cout << "Percentile " << (100 * p) << "\t" << label_count_vec[(size_t)(p * label_count_vec.size())].first + << " with count=" << label_count_vec[(size_t)(p * label_count_vec.size())].second << std::endl; + } - std::cout << "Most common label " - << "\t" << label_count_vec[label_count_vec.size() - 1].first - << " with count=" - << label_count_vec[label_count_vec.size() - 1].second << std::endl; - if (label_count_vec.size() > 1) - std::cout << "Second common label " - << "\t" << label_count_vec[label_count_vec.size() - 2].first - << " with count=" - << label_count_vec[label_count_vec.size() - 2].second + std::cout << "Most common label " + << "\t" << label_count_vec[label_count_vec.size() - 1].first + << " with count=" << label_count_vec[label_count_vec.size() - 1].second << std::endl; + if (label_count_vec.size() > 1) + std::cout << "Second common label " + << "\t" << label_count_vec[label_count_vec.size() - 2].first + << " with count=" << label_count_vec[label_count_vec.size() - 2].second << std::endl; + if (label_count_vec.size() > 2) + std::cout << "Third common label " + << "\t" << label_count_vec[label_count_vec.size() - 3].first + << " with count=" << label_count_vec[label_count_vec.size() - 3].second << std::endl; + avg_labels_per_pt = sum / (float)point_cnt; + mean_label_size = sum / (float)label_counts.size(); + std::cout << "Total number of points = " << point_cnt << ", number of labels = " << label_counts.size() << std::endl; - if (label_count_vec.size() > 2) - std::cout << "Third common label " - << "\t" << label_count_vec[label_count_vec.size() - 3].first - << " with count=" - << label_count_vec[label_count_vec.size() - 3].second - << std::endl; - avg_labels_per_pt = sum / (float)point_cnt; - mean_label_size = sum / (float)label_counts.size(); - std::cout << "Total number of points = " << point_cnt - << ", number of labels = " << label_counts.size() << std::endl; - std::cout << "Average number of labels per point = " << avg_labels_per_pt - << std::endl; - std::cout << "Mean label size excluding 0 = " << mean_label_size << std::endl; - std::cout << "Most popular label is " << label_with_max_points << " with " - << max_points << " pts" << std::endl; + std::cout << "Average number of labels per point = " << avg_labels_per_pt << std::endl; + std::cout << "Mean label size excluding 0 = " << mean_label_size << std::endl; + std::cout << "Most popular label is " << label_with_max_points << " with " << max_points << " pts" << std::endl; } -int main(int argc, char **argv) { - std::string 
labels_file, universal_label; - uint32_t density; +int main(int argc, char **argv) +{ + std::string labels_file, universal_label; + uint32_t density; - po::options_description desc{"Arguments"}; - try { - desc.add_options()("help,h", "Print information on arguments"); - desc.add_options()("labels_file", - po::value(&labels_file)->required(), - "path to labels data file."); - desc.add_options()("universal_label", - po::value(&universal_label)->required(), - "Universal label used in labels file."); - desc.add_options()( - "density", po::value(&density)->default_value(1), - "Number of labels each point in labels file, defaults to 1"); - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) { - std::cout << desc; - return 0; + po::options_description desc{"Arguments"}; + try + { + desc.add_options()("help,h", "Print information on arguments"); + desc.add_options()("labels_file", po::value(&labels_file)->required(), + "path to labels data file."); + desc.add_options()("universal_label", po::value(&universal_label)->required(), + "Universal label used in labels file."); + desc.add_options()("density", po::value(&density)->default_value(1), + "Number of labels each point in labels file, defaults to 1"); + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) + { + std::cout << desc; + return 0; + } + po::notify(vm); + } + catch (const std::exception &e) + { + std::cerr << e.what() << '\n'; + return -1; } - po::notify(vm); - } catch (const std::exception &e) { - std::cerr << e.what() << '\n'; - return -1; - } - stats_analysis(labels_file, universal_label, density); + stats_analysis(labels_file, universal_label, density); } diff --git a/apps/utils/tsv_to_bin.cpp b/apps/utils/tsv_to_bin.cpp index 2cd00ae38..9d52f70a2 100644 --- a/apps/utils/tsv_to_bin.cpp +++ b/apps/utils/tsv_to_bin.cpp @@ -4,105 +4,118 @@ #include "utils.h" #include -void block_convert_float(std::ifstream &reader, std::ofstream &writer, - size_t npts, size_t ndims) { - auto read_buf = new float[npts * (ndims + 1)]; - - auto cursor = read_buf; - float val; - - for (size_t i = 0; i < npts; i++) { - for (size_t d = 0; d < ndims; ++d) { - reader >> val; - *cursor = val; - cursor++; +void block_convert_float(std::ifstream &reader, std::ofstream &writer, size_t npts, size_t ndims) +{ + auto read_buf = new float[npts * (ndims + 1)]; + + auto cursor = read_buf; + float val; + + for (size_t i = 0; i < npts; i++) + { + for (size_t d = 0; d < ndims; ++d) + { + reader >> val; + *cursor = val; + cursor++; + } } - } - writer.write((char *)read_buf, npts * ndims * sizeof(float)); - delete[] read_buf; + writer.write((char *)read_buf, npts * ndims * sizeof(float)); + delete[] read_buf; } -void block_convert_int8(std::ifstream &reader, std::ofstream &writer, - size_t npts, size_t ndims) { - auto read_buf = new int8_t[npts * (ndims + 1)]; - - auto cursor = read_buf; - int val; - - for (size_t i = 0; i < npts; i++) { - for (size_t d = 0; d < ndims; ++d) { - reader >> val; - *cursor = (int8_t)val; - cursor++; +void block_convert_int8(std::ifstream &reader, std::ofstream &writer, size_t npts, size_t ndims) +{ + auto read_buf = new int8_t[npts * (ndims + 1)]; + + auto cursor = read_buf; + int val; + + for (size_t i = 0; i < npts; i++) + { + for (size_t d = 0; d < ndims; ++d) + { + reader >> val; + *cursor = (int8_t)val; + cursor++; + } } - } - writer.write((char *)read_buf, npts * ndims * sizeof(uint8_t)); - delete[] read_buf; + writer.write((char 
*)read_buf, npts * ndims * sizeof(uint8_t)); + delete[] read_buf; } -void block_convert_uint8(std::ifstream &reader, std::ofstream &writer, - size_t npts, size_t ndims) { - auto read_buf = new uint8_t[npts * (ndims + 1)]; +void block_convert_uint8(std::ifstream &reader, std::ofstream &writer, size_t npts, size_t ndims) +{ + auto read_buf = new uint8_t[npts * (ndims + 1)]; + + auto cursor = read_buf; + int val; + + for (size_t i = 0; i < npts; i++) + { + for (size_t d = 0; d < ndims; ++d) + { + reader >> val; + *cursor = (uint8_t)val; + cursor++; + } + } + writer.write((char *)read_buf, npts * ndims * sizeof(uint8_t)); + delete[] read_buf; +} - auto cursor = read_buf; - int val; +int main(int argc, char **argv) +{ + if (argc != 6) + { + std::cout << argv[0] + << " input_filename.tsv output_filename.bin " + "dim num_pts>" + << std::endl; + exit(-1); + } - for (size_t i = 0; i < npts; i++) { - for (size_t d = 0; d < ndims; ++d) { - reader >> val; - *cursor = (uint8_t)val; - cursor++; + if (std::string(argv[1]) != std::string("float") && std::string(argv[1]) != std::string("int8") && + std::string(argv[1]) != std::string("uint8")) + { + std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl; } - } - writer.write((char *)read_buf, npts * ndims * sizeof(uint8_t)); - delete[] read_buf; -} -int main(int argc, char **argv) { - if (argc != 6) { - std::cout << argv[0] - << " input_filename.tsv output_filename.bin " - "dim num_pts>" - << std::endl; - exit(-1); - } - - if (std::string(argv[1]) != std::string("float") && - std::string(argv[1]) != std::string("int8") && - std::string(argv[1]) != std::string("uint8")) { - std::cout << "Unsupported type. float, int8 and uint8 types are supported." - << std::endl; - } - - size_t ndims = atoi(argv[4]); - size_t npts = atoi(argv[5]); - - std::ifstream reader(argv[2], std::ios::binary | std::ios::ate); - // size_t fsize = reader.tellg(); - reader.seekg(0, std::ios::beg); - reader.seekg(0, std::ios::beg); - - size_t blk_size = 131072; - size_t nblks = ROUND_UP(npts, blk_size) / blk_size; - std::cout << "# blks: " << nblks << std::endl; - std::ofstream writer(argv[3], std::ios::binary); - auto npts_u32 = (uint32_t)npts; - auto ndims_u32 = (uint32_t)ndims; - writer.write((char *)&npts_u32, sizeof(uint32_t)); - writer.write((char *)&ndims_u32, sizeof(uint32_t)); - - for (size_t i = 0; i < nblks; i++) { - size_t cblk_size = std::min(npts - i * blk_size, blk_size); - if (std::string(argv[1]) == std::string("float")) { - block_convert_float(reader, writer, cblk_size, ndims); - } else if (std::string(argv[1]) == std::string("int8")) { - block_convert_int8(reader, writer, cblk_size, ndims); - } else if (std::string(argv[1]) == std::string("uint8")) { - block_convert_uint8(reader, writer, cblk_size, ndims); + size_t ndims = atoi(argv[4]); + size_t npts = atoi(argv[5]); + + std::ifstream reader(argv[2], std::ios::binary | std::ios::ate); + // size_t fsize = reader.tellg(); + reader.seekg(0, std::ios::beg); + reader.seekg(0, std::ios::beg); + + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; + std::cout << "# blks: " << nblks << std::endl; + std::ofstream writer(argv[3], std::ios::binary); + auto npts_u32 = (uint32_t)npts; + auto ndims_u32 = (uint32_t)ndims; + writer.write((char *)&npts_u32, sizeof(uint32_t)); + writer.write((char *)&ndims_u32, sizeof(uint32_t)); + + for (size_t i = 0; i < nblks; i++) + { + size_t cblk_size = std::min(npts - i * blk_size, blk_size); + if (std::string(argv[1]) == 
std::string("float")) + { + block_convert_float(reader, writer, cblk_size, ndims); + } + else if (std::string(argv[1]) == std::string("int8")) + { + block_convert_int8(reader, writer, cblk_size, ndims); + } + else if (std::string(argv[1]) == std::string("uint8")) + { + block_convert_uint8(reader, writer, cblk_size, ndims); + } + std::cout << "Block #" << i << " written" << std::endl; } - std::cout << "Block #" << i << " written" << std::endl; - } - reader.close(); - writer.close(); + reader.close(); + writer.close(); } diff --git a/apps/utils/uint32_to_uint8.cpp b/apps/utils/uint32_to_uint8.cpp index 3868780e6..348dcaa20 100644 --- a/apps/utils/uint32_to_uint8.cpp +++ b/apps/utils/uint32_to_uint8.cpp @@ -4,18 +4,20 @@ #include "utils.h" #include -int main(int argc, char **argv) { - if (argc != 3) { - std::cout << argv[0] << " input_uint32_bin output_int8_bin" << std::endl; - exit(-1); - } +int main(int argc, char **argv) +{ + if (argc != 3) + { + std::cout << argv[0] << " input_uint32_bin output_int8_bin" << std::endl; + exit(-1); + } - uint32_t *input; - size_t npts, nd; - diskann::load_bin(argv[1], input, npts, nd); - uint8_t *output = new uint8_t[npts * nd]; - diskann::convert_types(input, output, npts, nd); - diskann::save_bin(argv[2], output, npts, nd); - delete[] output; - delete[] input; + uint32_t *input; + size_t npts, nd; + diskann::load_bin(argv[1], input, npts, nd); + uint8_t *output = new uint8_t[npts * nd]; + diskann::convert_types(input, output, npts, nd); + diskann::save_bin(argv[2], output, npts, nd); + delete[] output; + delete[] input; } diff --git a/apps/utils/uint8_to_float.cpp b/apps/utils/uint8_to_float.cpp index 779226f90..352aea00c 100644 --- a/apps/utils/uint8_to_float.cpp +++ b/apps/utils/uint8_to_float.cpp @@ -4,18 +4,20 @@ #include "utils.h" #include -int main(int argc, char **argv) { - if (argc != 3) { - std::cout << argv[0] << " input_uint8_bin output_float_bin" << std::endl; - exit(-1); - } +int main(int argc, char **argv) +{ + if (argc != 3) + { + std::cout << argv[0] << " input_uint8_bin output_float_bin" << std::endl; + exit(-1); + } - uint8_t *input; - size_t npts, nd; - diskann::load_bin(argv[1], input, npts, nd); - float *output = new float[npts * nd]; - diskann::convert_types(input, output, npts, nd); - diskann::save_bin(argv[2], output, npts, nd); - delete[] output; - delete[] input; + uint8_t *input; + size_t npts, nd; + diskann::load_bin(argv[1], input, npts, nd); + float *output = new float[npts * nd]; + diskann::convert_types(input, output, npts, nd); + diskann::save_bin(argv[2], output, npts, nd); + delete[] output; + delete[] input; } diff --git a/apps/utils/vector_analysis.cpp b/apps/utils/vector_analysis.cpp index 9dde684f1..63364bc67 100644 --- a/apps/utils/vector_analysis.cpp +++ b/apps/utils/vector_analysis.cpp @@ -20,129 +20,144 @@ #include "partition.h" #include "utils.h" -template int analyze_norm(std::string base_file) { - std::cout << "Analyzing data norms" << std::endl; - T *data; - size_t npts, ndims; - diskann::load_bin(base_file, data, npts, ndims); - std::vector norms(npts, 0); +template int analyze_norm(std::string base_file) +{ + std::cout << "Analyzing data norms" << std::endl; + T *data; + size_t npts, ndims; + diskann::load_bin(base_file, data, npts, ndims); + std::vector norms(npts, 0); #pragma omp parallel for schedule(dynamic) - for (int64_t i = 0; i < (int64_t)npts; i++) { - for (size_t d = 0; d < ndims; d++) - norms[i] += data[i * ndims + d] * data[i * ndims + d]; - norms[i] = std::sqrt(norms[i]); - } - 
std::sort(norms.begin(), norms.end()); - for (int p = 0; p < 100; p += 5) - std::cout << "percentile " << p << ": " - << norms[(uint64_t)(std::floor((p / 100.0) * npts))] << std::endl; - std::cout << "percentile 100" - << ": " << norms[npts - 1] << std::endl; - delete[] data; - return 0; + for (int64_t i = 0; i < (int64_t)npts; i++) + { + for (size_t d = 0; d < ndims; d++) + norms[i] += data[i * ndims + d] * data[i * ndims + d]; + norms[i] = std::sqrt(norms[i]); + } + std::sort(norms.begin(), norms.end()); + for (int p = 0; p < 100; p += 5) + std::cout << "percentile " << p << ": " << norms[(uint64_t)(std::floor((p / 100.0) * npts))] << std::endl; + std::cout << "percentile 100" + << ": " << norms[npts - 1] << std::endl; + delete[] data; + return 0; } -template -int normalize_base(std::string base_file, std::string out_file) { - std::cout << "Normalizing base" << std::endl; - T *data; - size_t npts, ndims; - diskann::load_bin(base_file, data, npts, ndims); - // std::vector norms(npts, 0); +template int normalize_base(std::string base_file, std::string out_file) +{ + std::cout << "Normalizing base" << std::endl; + T *data; + size_t npts, ndims; + diskann::load_bin(base_file, data, npts, ndims); + // std::vector norms(npts, 0); #pragma omp parallel for schedule(dynamic) - for (int64_t i = 0; i < (int64_t)npts; i++) { - float pt_norm = 0; - for (size_t d = 0; d < ndims; d++) - pt_norm += data[i * ndims + d] * data[i * ndims + d]; - pt_norm = std::sqrt(pt_norm); - for (size_t d = 0; d < ndims; d++) - data[i * ndims + d] = static_cast(data[i * ndims + d] / pt_norm); - } - diskann::save_bin(out_file, data, npts, ndims); - delete[] data; - return 0; + for (int64_t i = 0; i < (int64_t)npts; i++) + { + float pt_norm = 0; + for (size_t d = 0; d < ndims; d++) + pt_norm += data[i * ndims + d] * data[i * ndims + d]; + pt_norm = std::sqrt(pt_norm); + for (size_t d = 0; d < ndims; d++) + data[i * ndims + d] = static_cast(data[i * ndims + d] / pt_norm); + } + diskann::save_bin(out_file, data, npts, ndims); + delete[] data; + return 0; } -template -int augment_base(std::string base_file, std::string out_file, - bool prep_base = true) { - std::cout << "Analyzing data norms" << std::endl; - T *data; - size_t npts, ndims; - diskann::load_bin(base_file, data, npts, ndims); - std::vector norms(npts, 0); - float max_norm = 0; +template int augment_base(std::string base_file, std::string out_file, bool prep_base = true) +{ + std::cout << "Analyzing data norms" << std::endl; + T *data; + size_t npts, ndims; + diskann::load_bin(base_file, data, npts, ndims); + std::vector norms(npts, 0); + float max_norm = 0; #pragma omp parallel for schedule(dynamic) - for (int64_t i = 0; i < (int64_t)npts; i++) { - for (size_t d = 0; d < ndims; d++) - norms[i] += data[i * ndims + d] * data[i * ndims + d]; - max_norm = norms[i] > max_norm ? norms[i] : max_norm; - } - // std::sort(norms.begin(), norms.end()); - max_norm = std::sqrt(max_norm); - std::cout << "Max norm: " << max_norm << std::endl; - T *new_data; - size_t newdims = ndims + 1; - new_data = new T[npts * newdims]; - for (size_t i = 0; i < npts; i++) { - if (prep_base) { - for (size_t j = 0; j < ndims; j++) { - new_data[i * newdims + j] = - static_cast(data[i * ndims + j] / max_norm); - } - float diff = 1 - (norms[i] / (max_norm * max_norm)); - diff = diff <= 0 ? 0 : std::sqrt(diff); - new_data[i * newdims + ndims] = static_cast(diff); - if (diff <= 0) { - std::cout << i << " has large max norm, investigate if needed. 
diff = " - << diff << std::endl; - } - } else { - for (size_t j = 0; j < ndims; j++) { - new_data[i * newdims + j] = - static_cast(data[i * ndims + j] / std::sqrt(norms[i])); - } - new_data[i * newdims + ndims] = 0; + for (int64_t i = 0; i < (int64_t)npts; i++) + { + for (size_t d = 0; d < ndims; d++) + norms[i] += data[i * ndims + d] * data[i * ndims + d]; + max_norm = norms[i] > max_norm ? norms[i] : max_norm; } - } - diskann::save_bin(out_file, new_data, npts, newdims); - delete[] new_data; - delete[] data; - return 0; + // std::sort(norms.begin(), norms.end()); + max_norm = std::sqrt(max_norm); + std::cout << "Max norm: " << max_norm << std::endl; + T *new_data; + size_t newdims = ndims + 1; + new_data = new T[npts * newdims]; + for (size_t i = 0; i < npts; i++) + { + if (prep_base) + { + for (size_t j = 0; j < ndims; j++) + { + new_data[i * newdims + j] = static_cast(data[i * ndims + j] / max_norm); + } + float diff = 1 - (norms[i] / (max_norm * max_norm)); + diff = diff <= 0 ? 0 : std::sqrt(diff); + new_data[i * newdims + ndims] = static_cast(diff); + if (diff <= 0) + { + std::cout << i << " has large max norm, investigate if needed. diff = " << diff << std::endl; + } + } + else + { + for (size_t j = 0; j < ndims; j++) + { + new_data[i * newdims + j] = static_cast(data[i * ndims + j] / std::sqrt(norms[i])); + } + new_data[i * newdims + ndims] = 0; + } + } + diskann::save_bin(out_file, new_data, npts, newdims); + delete[] new_data; + delete[] data; + return 0; } -template int aux_main(char **argv) { - std::string base_file(argv[2]); - uint32_t option = atoi(argv[3]); - if (option == 1) - analyze_norm(base_file); - else if (option == 2) - augment_base(base_file, std::string(argv[4]), true); - else if (option == 3) - augment_base(base_file, std::string(argv[4]), false); - else if (option == 4) - normalize_base(base_file, std::string(argv[4])); - return 0; +template int aux_main(char **argv) +{ + std::string base_file(argv[2]); + uint32_t option = atoi(argv[3]); + if (option == 1) + analyze_norm(base_file); + else if (option == 2) + augment_base(base_file, std::string(argv[4]), true); + else if (option == 3) + augment_base(base_file, std::string(argv[4]), false); + else if (option == 4) + normalize_base(base_file, std::string(argv[4])); + return 0; } -int main(int argc, char **argv) { - if (argc < 4) { - std::cout << argv[0] - << " data_type [float/int8/uint8] base_bin_file " - "[option: 1-norm analysis, 2-prep_base_for_mip, " - "3-prep_query_for_mip, 4-normalize-vecs] [out_file for " - "options 2/3/4]" - << std::endl; - exit(-1); - } +int main(int argc, char **argv) +{ + if (argc < 4) + { + std::cout << argv[0] + << " data_type [float/int8/uint8] base_bin_file " + "[option: 1-norm analysis, 2-prep_base_for_mip, " + "3-prep_query_for_mip, 4-normalize-vecs] [out_file for " + "options 2/3/4]" + << std::endl; + exit(-1); + } - if (std::string(argv[1]) == std::string("float")) { - aux_main(argv); - } else if (std::string(argv[1]) == std::string("int8")) { - aux_main(argv); - } else if (std::string(argv[1]) == std::string("uint8")) { - aux_main(argv); - } else - std::cout << "Unsupported type. Use float/int8/uint8." << std::endl; - return 0; + if (std::string(argv[1]) == std::string("float")) + { + aux_main(argv); + } + else if (std::string(argv[1]) == std::string("int8")) + { + aux_main(argv); + } + else if (std::string(argv[1]) == std::string("uint8")) + { + aux_main(argv); + } + else + std::cout << "Unsupported type. Use float/int8/uint8." 
<< std::endl; + return 0; } diff --git a/include/abstract_data_store.h b/include/abstract_data_store.h index d19fcfabd..44401a08f 100644 --- a/include/abstract_data_store.h +++ b/include/abstract_data_store.h @@ -10,130 +10,119 @@ #include "types.h" #include "windows_customizations.h" -namespace diskann { +namespace diskann +{ template class AbstractScratch; -template class AbstractDataStore { -public: - AbstractDataStore(const location_t capacity, const size_t dim); - - virtual ~AbstractDataStore() = default; - - // Return number of points returned - virtual location_t load(const std::string &filename) = 0; - - // Why does store take num_pts? Since store only has capacity, but we allow - // resizing we can end up in a situation where the store has spare capacity. - // To optimize disk utilization, we pass the number of points that are "true" - // points, so that the store can discard the empty locations before saving. - virtual size_t save(const std::string &filename, - const location_t num_pts) = 0; - - DISKANN_DLLEXPORT virtual location_t capacity() const; - - DISKANN_DLLEXPORT virtual size_t get_dims() const; - - // Implementers can choose to return _dim if they are not - // concerned about memory alignment. - // Some distance metrics (like l2) need data vectors to be aligned, so we - // align the dimension by padding zeros. - virtual size_t get_aligned_dim() const = 0; - - // populate the store with vectors (either from a pointer or bin file), - // potentially after pre-processing the vectors if the metric deems so - // e.g., normalizing vectors for cosine distance over floating-point vectors - // useful for bulk or static index building. - virtual void populate_data(const data_t *vectors, - const location_t num_pts) = 0; - virtual void populate_data(const std::string &filename, - const size_t offset) = 0; - - // save the first num_pts many vectors back to bin file - // note: cannot undo the pre-processing done in populate data - virtual void extract_data_to_bin(const std::string &filename, - const location_t num_pts) = 0; - - // Returns the updated capacity of the datastore. Clients should check - // if resize actually changed the capacity to new_num_points before - // proceeding with operations. See the code below: - // auto new_capcity = data_store->resize(new_num_points); - // if ( new_capacity >= new_num_points) { - // //PROCEED - // else - // //ERROR. - virtual location_t resize(const location_t new_num_points); - - // operations on vectors - // like populate_data function, but over one vector at a time useful for - // streaming setting - virtual void get_vector(const location_t i, data_t *dest) const = 0; - virtual void set_vector(const location_t i, const data_t *const vector) = 0; - virtual void prefetch_vector(const location_t loc) = 0; - - // internal shuffle operations to move around vectors - // will bulk-move all the vectors in [old_start_loc, old_start_loc + - // num_points) to [new_start_loc, new_start_loc + num_points) and set the old - // positions to zero vectors. - virtual void move_vectors(const location_t old_start_loc, - const location_t new_start_loc, - const location_t num_points) = 0; - - // same as above, without resetting the vectors in [from_loc, from_loc + - // num_points) to zero - virtual void copy_vectors(const location_t from_loc, const location_t to_loc, - const location_t num_points) = 0; - - // With the PQ Data Store PR, we have also changed iterate_to_fixed_point to - // NOT take the query from the scratch object. 
Therefore every data store has - // to implement preprocess_query which at the least will be to copy the query - // into the scratch object. So making this pure virtual. - virtual void - preprocess_query(const data_t *aligned_query, - AbstractScratch *query_scratch = nullptr) const = 0; - // distance functions. - virtual float get_distance(const data_t *query, - const location_t loc) const = 0; - virtual void - get_distance(const data_t *query, const location_t *locations, - const uint32_t location_count, float *distances, - AbstractScratch *scratch_space = nullptr) const = 0; - // Specific overload for index.cpp. - virtual void get_distance(const data_t *preprocessed_query, - const std::vector &ids, - std::vector &distances, - AbstractScratch *scratch_space) const = 0; - virtual float get_distance(const location_t loc1, - const location_t loc2) const = 0; - - // stats of the data stored in store - // Returns the point in the dataset that is closest to the mean of all points - // in the dataset - virtual location_t calculate_medoid() const = 0; - - // REFACTOR PQ TODO: Each data store knows about its distance function, so - // this is redundant. However, we don't have an OptmizedDataStore yet, and to - // preserve code compability, we are exposing this function. - virtual Distance *get_dist_fn() const = 0; - - // search helpers - // if the base data is aligned per the request of the metric, this will tell - // how to align the query vector in a consistent manner - virtual size_t get_alignment_factor() const = 0; - -protected: - // Expand the datastore to new_num_points. Returns the new capacity created, - // which should be == new_num_points in the normal case. Implementers can also - // return _capacity to indicate that there are not implementing this method. - virtual location_t expand(const location_t new_num_points) = 0; - - // Shrink the datastore to new_num_points. It is NOT an error if shrink - // doesn't reduce the capacity so callers need to check this correctly. See - // also for "default" implementation - virtual location_t shrink(const location_t new_num_points) = 0; - - location_t _capacity; - size_t _dim; +template class AbstractDataStore +{ + public: + AbstractDataStore(const location_t capacity, const size_t dim); + + virtual ~AbstractDataStore() = default; + + // Return number of points returned + virtual location_t load(const std::string &filename) = 0; + + // Why does store take num_pts? Since store only has capacity, but we allow + // resizing we can end up in a situation where the store has spare capacity. + // To optimize disk utilization, we pass the number of points that are "true" + // points, so that the store can discard the empty locations before saving. + virtual size_t save(const std::string &filename, const location_t num_pts) = 0; + + DISKANN_DLLEXPORT virtual location_t capacity() const; + + DISKANN_DLLEXPORT virtual size_t get_dims() const; + + // Implementers can choose to return _dim if they are not + // concerned about memory alignment. + // Some distance metrics (like l2) need data vectors to be aligned, so we + // align the dimension by padding zeros. + virtual size_t get_aligned_dim() const = 0; + + // populate the store with vectors (either from a pointer or bin file), + // potentially after pre-processing the vectors if the metric deems so + // e.g., normalizing vectors for cosine distance over floating-point vectors + // useful for bulk or static index building. 
+ virtual void populate_data(const data_t *vectors, const location_t num_pts) = 0; + virtual void populate_data(const std::string &filename, const size_t offset) = 0; + + // save the first num_pts many vectors back to bin file + // note: cannot undo the pre-processing done in populate data + virtual void extract_data_to_bin(const std::string &filename, const location_t num_pts) = 0; + + // Returns the updated capacity of the datastore. Clients should check + // if resize actually changed the capacity to new_num_points before + // proceeding with operations. See the code below: + // auto new_capcity = data_store->resize(new_num_points); + // if ( new_capacity >= new_num_points) { + // //PROCEED + // else + // //ERROR. + virtual location_t resize(const location_t new_num_points); + + // operations on vectors + // like populate_data function, but over one vector at a time useful for + // streaming setting + virtual void get_vector(const location_t i, data_t *dest) const = 0; + virtual void set_vector(const location_t i, const data_t *const vector) = 0; + virtual void prefetch_vector(const location_t loc) = 0; + + // internal shuffle operations to move around vectors + // will bulk-move all the vectors in [old_start_loc, old_start_loc + + // num_points) to [new_start_loc, new_start_loc + num_points) and set the old + // positions to zero vectors. + virtual void move_vectors(const location_t old_start_loc, const location_t new_start_loc, + const location_t num_points) = 0; + + // same as above, without resetting the vectors in [from_loc, from_loc + + // num_points) to zero + virtual void copy_vectors(const location_t from_loc, const location_t to_loc, const location_t num_points) = 0; + + // With the PQ Data Store PR, we have also changed iterate_to_fixed_point to + // NOT take the query from the scratch object. Therefore every data store has + // to implement preprocess_query which at the least will be to copy the query + // into the scratch object. So making this pure virtual. + virtual void preprocess_query(const data_t *aligned_query, + AbstractScratch *query_scratch = nullptr) const = 0; + // distance functions. + virtual float get_distance(const data_t *query, const location_t loc) const = 0; + virtual void get_distance(const data_t *query, const location_t *locations, const uint32_t location_count, + float *distances, AbstractScratch *scratch_space = nullptr) const = 0; + // Specific overload for index.cpp. + virtual void get_distance(const data_t *preprocessed_query, const std::vector &ids, + std::vector &distances, AbstractScratch *scratch_space) const = 0; + virtual float get_distance(const location_t loc1, const location_t loc2) const = 0; + + // stats of the data stored in store + // Returns the point in the dataset that is closest to the mean of all points + // in the dataset + virtual location_t calculate_medoid() const = 0; + + // REFACTOR PQ TODO: Each data store knows about its distance function, so + // this is redundant. However, we don't have an OptmizedDataStore yet, and to + // preserve code compability, we are exposing this function. + virtual Distance *get_dist_fn() const = 0; + + // search helpers + // if the base data is aligned per the request of the metric, this will tell + // how to align the query vector in a consistent manner + virtual size_t get_alignment_factor() const = 0; + + protected: + // Expand the datastore to new_num_points. Returns the new capacity created, + // which should be == new_num_points in the normal case. 
Implementers can also + // return _capacity to indicate that there are not implementing this method. + virtual location_t expand(const location_t new_num_points) = 0; + + // Shrink the datastore to new_num_points. It is NOT an error if shrink + // doesn't reduce the capacity so callers need to check this correctly. See + // also for "default" implementation + virtual location_t shrink(const location_t new_num_points) = 0; + + location_t _capacity; + size_t _dim; }; } // namespace diskann diff --git a/include/abstract_filter_store.h b/include/abstract_filter_store.h index 7afd3490e..858c6e283 100644 --- a/include/abstract_filter_store.h +++ b/include/abstract_filter_store.h @@ -3,23 +3,23 @@ #include "windows_customizations.h" #include -namespace diskann { -template class AbstractFilterStore { -public: - DISKANN_DLLEXPORT virtual bool has_filter_support() const = 0; +namespace diskann +{ +template class AbstractFilterStore +{ + public: + DISKANN_DLLEXPORT virtual bool has_filter_support() const = 0; - DISKANN_DLLEXPORT virtual bool - point_has_label(location_t point_id, const LabelT label_id) const = 0; + DISKANN_DLLEXPORT virtual bool point_has_label(location_t point_id, const LabelT label_id) const = 0; - // Returns true if the index is filter-enabled and all files were loaded - // correctly. false otherwise. Note that "false" can mean that the index - // does not have filter support, or that some index files do not exist, or - // that they exist and could not be opened. - DISKANN_DLLEXPORT virtual bool load(const std::string &disk_index_file) = 0; + // Returns true if the index is filter-enabled and all files were loaded + // correctly. false otherwise. Note that "false" can mean that the index + // does not have filter support, or that some index files do not exist, or + // that they exist and could not be opened. + DISKANN_DLLEXPORT virtual bool load(const std::string &disk_index_file) = 0; - DISKANN_DLLEXPORT virtual void - generate_random_labels(std::vector &labels, const uint32_t num_labels, - const uint32_t nthreads) = 0; + DISKANN_DLLEXPORT virtual void generate_random_labels(std::vector &labels, const uint32_t num_labels, + const uint32_t nthreads) = 0; }; } // namespace diskann diff --git a/include/abstract_graph_store.h b/include/abstract_graph_store.h index 465b2b814..961a4f994 100644 --- a/include/abstract_graph_store.h +++ b/include/abstract_graph_store.h @@ -7,52 +7,62 @@ #include #include -namespace diskann { - -class AbstractGraphStore { -public: - AbstractGraphStore(const size_t total_pts, const size_t reserve_graph_degree) - : _capacity(total_pts), _reserve_graph_degree(reserve_graph_degree) {} - - virtual ~AbstractGraphStore() = default; - - // returns tuple of - virtual std::tuple - load(const std::string &index_path_prefix, const size_t num_points) = 0; - virtual int store(const std::string &index_path_prefix, - const size_t num_points, const size_t num_fz_points, - const uint32_t start) = 0; - - // not synchronised, user should use lock when necvessary. 
- virtual const std::vector & - get_neighbours(const location_t i) const = 0; - virtual void add_neighbour(const location_t i, location_t neighbour_id) = 0; - virtual void clear_neighbours(const location_t i) = 0; - virtual void swap_neighbours(const location_t a, location_t b) = 0; - - virtual void set_neighbours(const location_t i, - std::vector &neighbours) = 0; - - virtual size_t resize_graph(const size_t new_size) = 0; - virtual void clear_graph() = 0; - - virtual uint32_t get_max_observed_degree() = 0; - - // set during load - virtual size_t get_max_range_of_graph() = 0; - - // Total internal points _max_points + _num_frozen_points - size_t get_total_points() { return _capacity; } - -protected: - // Internal function, changes total points when resize_graph is called. - void set_total_points(size_t new_capacity) { _capacity = new_capacity; } - - size_t get_reserve_graph_degree() { return _reserve_graph_degree; } - -private: - size_t _capacity; - size_t _reserve_graph_degree; +namespace diskann +{ + +class AbstractGraphStore +{ + public: + AbstractGraphStore(const size_t total_pts, const size_t reserve_graph_degree) + : _capacity(total_pts), _reserve_graph_degree(reserve_graph_degree) + { + } + + virtual ~AbstractGraphStore() = default; + + // returns tuple of + virtual std::tuple load(const std::string &index_path_prefix, + const size_t num_points) = 0; + virtual int store(const std::string &index_path_prefix, const size_t num_points, const size_t num_fz_points, + const uint32_t start) = 0; + + // not synchronised, user should use lock when necvessary. + virtual const std::vector &get_neighbours(const location_t i) const = 0; + virtual void add_neighbour(const location_t i, location_t neighbour_id) = 0; + virtual void clear_neighbours(const location_t i) = 0; + virtual void swap_neighbours(const location_t a, location_t b) = 0; + + virtual void set_neighbours(const location_t i, std::vector &neighbours) = 0; + + virtual size_t resize_graph(const size_t new_size) = 0; + virtual void clear_graph() = 0; + + virtual uint32_t get_max_observed_degree() = 0; + + // set during load + virtual size_t get_max_range_of_graph() = 0; + + // Total internal points _max_points + _num_frozen_points + size_t get_total_points() + { + return _capacity; + } + + protected: + // Internal function, changes total points when resize_graph is called. 
+ void set_total_points(size_t new_capacity) + { + _capacity = new_capacity; + } + + size_t get_reserve_graph_degree() + { + return _reserve_graph_degree; + } + + private: + size_t _capacity; + size_t _reserve_graph_degree; }; } // namespace diskann \ No newline at end of file diff --git a/include/abstract_index.h b/include/abstract_index.h index 1f974b458..9f0d402ef 100644 --- a/include/abstract_index.h +++ b/include/abstract_index.h @@ -7,147 +7,124 @@ #include "utils.h" #include -namespace diskann { -struct consolidation_report { - enum status_code { - SUCCESS = 0, - FAIL = 1, - LOCK_FAIL = 2, - INCONSISTENT_COUNT_ERROR = 3 - }; - status_code _status; - size_t _active_points, _max_points, _empty_slots, _slots_released, - _delete_set_size, _num_calls_to_process_delete; - double _time; - - consolidation_report(status_code status, size_t active_points, - size_t max_points, size_t empty_slots, - size_t slots_released, size_t delete_set_size, - size_t num_calls_to_process_delete, double time_secs) - : _status(status), _active_points(active_points), _max_points(max_points), - _empty_slots(empty_slots), _slots_released(slots_released), - _delete_set_size(delete_set_size), - _num_calls_to_process_delete(num_calls_to_process_delete), - _time(time_secs) {} +namespace diskann +{ +struct consolidation_report +{ + enum status_code + { + SUCCESS = 0, + FAIL = 1, + LOCK_FAIL = 2, + INCONSISTENT_COUNT_ERROR = 3 + }; + status_code _status; + size_t _active_points, _max_points, _empty_slots, _slots_released, _delete_set_size, _num_calls_to_process_delete; + double _time; + + consolidation_report(status_code status, size_t active_points, size_t max_points, size_t empty_slots, + size_t slots_released, size_t delete_set_size, size_t num_calls_to_process_delete, + double time_secs) + : _status(status), _active_points(active_points), _max_points(max_points), _empty_slots(empty_slots), + _slots_released(slots_released), _delete_set_size(delete_set_size), + _num_calls_to_process_delete(num_calls_to_process_delete), _time(time_secs) + { + } }; /* A templated independent class for intercation with Index. Uses Type Erasure to add virtual implemetation of methods that can take any type(using std::any) and Provides a clean API that can be inherited by different type of Index. 
*/ -class AbstractIndex { -public: - AbstractIndex() = default; - virtual ~AbstractIndex() = default; +class AbstractIndex +{ + public: + AbstractIndex() = default; + virtual ~AbstractIndex() = default; - virtual void build(const std::string &data_file, - const size_t num_points_to_load, - IndexFilterParams &build_params) = 0; + virtual void build(const std::string &data_file, const size_t num_points_to_load, + IndexFilterParams &build_params) = 0; - template - void build(const data_type *data, const size_t num_points_to_load, - const std::vector &tags); + template + void build(const data_type *data, const size_t num_points_to_load, const std::vector &tags); - virtual void save(const char *filename, bool compact_before_save = false) = 0; + virtual void save(const char *filename, bool compact_before_save = false) = 0; #ifdef EXEC_ENV_OLS - virtual void load(AlignedFileReader &reader, uint32_t num_threads, - uint32_t search_l) = 0; + virtual void load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l) = 0; #else - virtual void load(const char *index_file, uint32_t num_threads, - uint32_t search_l) = 0; + virtual void load(const char *index_file, uint32_t num_threads, uint32_t search_l) = 0; #endif - // For FastL2 search on optimized layout - template - void search_with_optimized_layout(const data_type *query, size_t K, size_t L, - uint32_t *indices); - - // Initialize space for res_vectors before calling. - template - size_t search_with_tags(const data_type *query, const uint64_t K, - const uint32_t L, tag_type *tags, float *distances, - std::vector &res_vectors, - bool use_filters = false, - const std::string filter_label = ""); - - // Added search overload that takes L as parameter, so that we - // can customize L on a per-query basis without tampering with "Parameters" - // IDtype is either uint32_t or uint64_t - template - std::pair search(const data_type *query, const size_t K, - const uint32_t L, IDType *indices, - float *distances = nullptr); - - // Filter support search - // IndexType is either uint32_t or uint64_t - template - std::pair - search_with_filters(const DataType &query, const std::string &raw_label, - const size_t K, const uint32_t L, IndexType *indices, - float *distances); - - // insert points with labels, labels should be present for filtered index - template - int insert_point(const data_type *point, const tag_type tag, - const std::vector &labels); - - // insert point for unfiltered index build. do not use with filtered index - template - int insert_point(const data_type *point, const tag_type tag); - - // delete point with tag, or return -1 if point can not be deleted - template int lazy_delete(const tag_type &tag); - - // batch delete tags and populates failed tags if unabke to delete given tags. 
- template - void lazy_delete(const std::vector &tags, - std::vector &failed_tags); - - template - void get_active_tags(tsl::robin_set &active_tags); - - template - void set_start_points_at_random(data_type radius, uint32_t random_seed = 0); - - virtual consolidation_report - consolidate_deletes(const IndexWriteParameters ¶meters) = 0; - - virtual void optimize_index_layout() = 0; - - // memory should be allocated for vec before calling this function - template - int get_vector_by_tag(tag_type &tag, data_type *vec); - - template - void set_universal_label(const label_type universal_label); - -private: - virtual void _build(const DataType &data, const size_t num_points_to_load, - TagVector &tags) = 0; - virtual std::pair - _search(const DataType &query, const size_t K, const uint32_t L, - std::any &indices, float *distances = nullptr) = 0; - virtual std::pair - _search_with_filters(const DataType &query, const std::string &filter_label, - const size_t K, const uint32_t L, std::any &indices, - float *distances) = 0; - virtual int _insert_point(const DataType &data_point, const TagType tag, - Labelvector &labels) = 0; - virtual int _insert_point(const DataType &data_point, const TagType tag) = 0; - virtual int _lazy_delete(const TagType &tag) = 0; - virtual void _lazy_delete(TagVector &tags, TagVector &failed_tags) = 0; - virtual void _get_active_tags(TagRobinSet &active_tags) = 0; - virtual void _set_start_points_at_random(DataType radius, - uint32_t random_seed = 0) = 0; - virtual int _get_vector_by_tag(TagType &tag, DataType &vec) = 0; - virtual size_t _search_with_tags(const DataType &query, const uint64_t K, - const uint32_t L, const TagType &tags, - float *distances, DataVector &res_vectors, - bool use_filters = false, - const std::string filter_label = "") = 0; - virtual void _search_with_optimized_layout(const DataType &query, size_t K, - size_t L, uint32_t *indices) = 0; - virtual void _set_universal_label(const LabelType universal_label) = 0; + // For FastL2 search on optimized layout + template + void search_with_optimized_layout(const data_type *query, size_t K, size_t L, uint32_t *indices); + + // Initialize space for res_vectors before calling. + template + size_t search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, + float *distances, std::vector &res_vectors, bool use_filters = false, + const std::string filter_label = ""); + + // Added search overload that takes L as parameter, so that we + // can customize L on a per-query basis without tampering with "Parameters" + // IDtype is either uint32_t or uint64_t + template + std::pair search(const data_type *query, const size_t K, const uint32_t L, IDType *indices, + float *distances = nullptr); + + // Filter support search + // IndexType is either uint32_t or uint64_t + template + std::pair search_with_filters(const DataType &query, const std::string &raw_label, + const size_t K, const uint32_t L, IndexType *indices, + float *distances); + + // insert points with labels, labels should be present for filtered index + template + int insert_point(const data_type *point, const tag_type tag, const std::vector &labels); + + // insert point for unfiltered index build. do not use with filtered index + template int insert_point(const data_type *point, const tag_type tag); + + // delete point with tag, or return -1 if point can not be deleted + template int lazy_delete(const tag_type &tag); + + // batch delete tags and populates failed tags if unabke to delete given tags. 
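+    // For example (illustrative only; the tag type is assumed to be uint32_t here):
+    //
+    //   std::vector<uint32_t> failed;
+    //   index.lazy_delete(tags_to_remove, failed);
+    //   // tags that could not be deleted are reported back in `failed`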
+ template + void lazy_delete(const std::vector &tags, std::vector &failed_tags); + + template void get_active_tags(tsl::robin_set &active_tags); + + template void set_start_points_at_random(data_type radius, uint32_t random_seed = 0); + + virtual consolidation_report consolidate_deletes(const IndexWriteParameters ¶meters) = 0; + + virtual void optimize_index_layout() = 0; + + // memory should be allocated for vec before calling this function + template int get_vector_by_tag(tag_type &tag, data_type *vec); + + template void set_universal_label(const label_type universal_label); + + private: + virtual void _build(const DataType &data, const size_t num_points_to_load, TagVector &tags) = 0; + virtual std::pair _search(const DataType &query, const size_t K, const uint32_t L, + std::any &indices, float *distances = nullptr) = 0; + virtual std::pair _search_with_filters(const DataType &query, const std::string &filter_label, + const size_t K, const uint32_t L, std::any &indices, + float *distances) = 0; + virtual int _insert_point(const DataType &data_point, const TagType tag, Labelvector &labels) = 0; + virtual int _insert_point(const DataType &data_point, const TagType tag) = 0; + virtual int _lazy_delete(const TagType &tag) = 0; + virtual void _lazy_delete(TagVector &tags, TagVector &failed_tags) = 0; + virtual void _get_active_tags(TagRobinSet &active_tags) = 0; + virtual void _set_start_points_at_random(DataType radius, uint32_t random_seed = 0) = 0; + virtual int _get_vector_by_tag(TagType &tag, DataType &vec) = 0; + virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, + float *distances, DataVector &res_vectors, bool use_filters = false, + const std::string filter_label = "") = 0; + virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) = 0; + virtual void _set_universal_label(const LabelType universal_label) = 0; }; } // namespace diskann diff --git a/include/abstract_scratch.h b/include/abstract_scratch.h index fc5919b8a..b42a836f6 100644 --- a/include/abstract_scratch.h +++ b/include/abstract_scratch.h @@ -1,27 +1,35 @@ #pragma once -namespace diskann { +namespace diskann +{ template class PQScratch; // By somewhat more than a coincidence, it seems that both InMemQueryScratch // and SSDQueryScratch have the aligned query and PQScratch objects. So we // can put them in a neat hierarchy and keep PQScratch as a standalone class. -template class AbstractScratch { -public: - AbstractScratch() = default; - // This class does not take any responsibilty for memory management of - // its members. It is the responsibility of the derived classes to do so. - virtual ~AbstractScratch() = default; +template class AbstractScratch +{ + public: + AbstractScratch() = default; + // This class does not take any responsibilty for memory management of + // its members. It is the responsibility of the derived classes to do so. 
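+    // For example, a derived scratch that allocates these members is expected to
+    // release them itself (illustrative sketch; DerivedScratch is a hypothetical
+    // name, and the deallocators must match however the buffers were allocated):
+    //
+    //   ~DerivedScratch() override
+    //   {
+    //       aligned_free(_aligned_query_T);
+    //       delete _pq_scratch;
+    //   }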
+ virtual ~AbstractScratch() = default; - // Scratch objects should not be copied - AbstractScratch(const AbstractScratch &) = delete; - AbstractScratch &operator=(const AbstractScratch &) = delete; + // Scratch objects should not be copied + AbstractScratch(const AbstractScratch &) = delete; + AbstractScratch &operator=(const AbstractScratch &) = delete; - data_t *aligned_query_T() { return _aligned_query_T; } - PQScratch *pq_scratch() { return _pq_scratch; } + data_t *aligned_query_T() + { + return _aligned_query_T; + } + PQScratch *pq_scratch() + { + return _pq_scratch; + } -protected: - data_t *_aligned_query_T = nullptr; - PQScratch *_pq_scratch = nullptr; + protected: + data_t *_aligned_query_T = nullptr; + PQScratch *_pq_scratch = nullptr; }; } // namespace diskann diff --git a/include/aligned_file_reader.h b/include/aligned_file_reader.h index 19f0b4983..447b34609 100644 --- a/include/aligned_file_reader.h +++ b/include/aligned_file_reader.h @@ -18,10 +18,11 @@ typedef io_context_t IOContext; #include #ifndef USE_BING_INFRA -struct IOContext { - HANDLE fhandle = NULL; - HANDLE iocp = NULL; - std::vector reqs; +struct IOContext +{ + HANDLE fhandle = NULL; + HANDLE iocp = NULL; + std::vector reqs; }; #else #include "IDiskPriorityIO.h" @@ -31,25 +32,32 @@ struct IOContext { // errors. // Because of such callous copying, we have to use ptr->atomic instead // of atomic, as atomic is not copyable. -struct IOContext { - enum Status { READ_WAIT = 0, READ_SUCCESS, READ_FAILED, PROCESS_COMPLETE }; - - std::shared_ptr m_pDiskIO = nullptr; - std::shared_ptr> m_pRequests; - std::shared_ptr> m_pRequestsStatus; - - // waitonaddress on this memory to wait for IO completion signal - // reader should signal this memory after IO completion - // TODO: WindowsAlignedFileReader can be modified to take advantage of this - // and can largely share code with the file reader for Bing. - mutable volatile long m_completeCount = 0; - - IOContext() - : m_pRequestsStatus(new std::vector()), - m_pRequests(new std::vector()) { - (*m_pRequestsStatus).reserve(MAX_IO_DEPTH); - (*m_pRequests).reserve(MAX_IO_DEPTH); - } +struct IOContext +{ + enum Status + { + READ_WAIT = 0, + READ_SUCCESS, + READ_FAILED, + PROCESS_COMPLETE + }; + + std::shared_ptr m_pDiskIO = nullptr; + std::shared_ptr> m_pRequests; + std::shared_ptr> m_pRequestsStatus; + + // waitonaddress on this memory to wait for IO completion signal + // reader should signal this memory after IO completion + // TODO: WindowsAlignedFileReader can be modified to take advantage of this + // and can largely share code with the file reader for Bing. 
+ mutable volatile long m_completeCount = 0; + + IOContext() + : m_pRequestsStatus(new std::vector()), m_pRequests(new std::vector()) + { + (*m_pRequestsStatus).reserve(MAX_IO_DEPTH); + (*m_pRequests).reserve(MAX_IO_DEPTH); + } }; #endif @@ -63,52 +71,55 @@ struct IOContext { #include // NOTE :: all 3 fields must be 512-aligned -struct AlignedRead { - uint64_t offset; // where to read from - uint64_t len; // how much to read - void *buf; // where to read into - - AlignedRead() : offset(0), len(0), buf(nullptr) {} - - AlignedRead(uint64_t offset, uint64_t len, void *buf) - : offset(offset), len(len), buf(buf) { - assert(IS_512_ALIGNED(offset)); - assert(IS_512_ALIGNED(len)); - assert(IS_512_ALIGNED(buf)); - // assert(malloc_usable_size(buf) >= len); - } +struct AlignedRead +{ + uint64_t offset; // where to read from + uint64_t len; // how much to read + void *buf; // where to read into + + AlignedRead() : offset(0), len(0), buf(nullptr) + { + } + + AlignedRead(uint64_t offset, uint64_t len, void *buf) : offset(offset), len(len), buf(buf) + { + assert(IS_512_ALIGNED(offset)); + assert(IS_512_ALIGNED(len)); + assert(IS_512_ALIGNED(buf)); + // assert(malloc_usable_size(buf) >= len); + } }; -class AlignedFileReader { -protected: - tsl::robin_map ctx_map; - std::mutex ctx_mut; +class AlignedFileReader +{ + protected: + tsl::robin_map ctx_map; + std::mutex ctx_mut; -public: - // returns the thread-specific context - // returns (io_context_t)(-1) if thread is not registered - virtual IOContext &get_ctx() = 0; + public: + // returns the thread-specific context + // returns (io_context_t)(-1) if thread is not registered + virtual IOContext &get_ctx() = 0; - virtual ~AlignedFileReader(){}; + virtual ~AlignedFileReader(){}; - // register thread-id for a context - virtual void register_thread() = 0; - // de-register thread-id for a context - virtual void deregister_thread() = 0; - virtual void deregister_all_threads() = 0; + // register thread-id for a context + virtual void register_thread() = 0; + // de-register thread-id for a context + virtual void deregister_thread() = 0; + virtual void deregister_all_threads() = 0; - // Open & close ops - // Blocking calls - virtual void open(const std::string &fname) = 0; - virtual void close() = 0; + // Open & close ops + // Blocking calls + virtual void open(const std::string &fname) = 0; + virtual void close() = 0; - // process batch of aligned requests in parallel - // NOTE :: blocking call - virtual void read(std::vector &read_reqs, IOContext &ctx, - bool async = false) = 0; + // process batch of aligned requests in parallel + // NOTE :: blocking call + virtual void read(std::vector &read_reqs, IOContext &ctx, bool async = false) = 0; #ifdef USE_BING_INFRA - // wait for completion of one request in a batch of requests - virtual void wait(IOContext &ctx, int &completedIndex) = 0; + // wait for completion of one request in a batch of requests + virtual void wait(IOContext &ctx, int &completedIndex) = 0; #endif }; diff --git a/include/ann_exception.h b/include/ann_exception.h index 3dc544866..a9b940573 100644 --- a/include/ann_exception.h +++ b/include/ann_exception.h @@ -11,25 +11,24 @@ #define __FUNCSIG__ __PRETTY_FUNCTION__ #endif -namespace diskann { +namespace diskann +{ -class ANNException : public std::runtime_error { -public: - DISKANN_DLLEXPORT ANNException(const std::string &message, int errorCode); - DISKANN_DLLEXPORT ANNException(const std::string &message, int errorCode, - const std::string &funcSig, - const std::string &fileName, uint32_t 
lineNum); +class ANNException : public std::runtime_error +{ + public: + DISKANN_DLLEXPORT ANNException(const std::string &message, int errorCode); + DISKANN_DLLEXPORT ANNException(const std::string &message, int errorCode, const std::string &funcSig, + const std::string &fileName, uint32_t lineNum); -private: - int _errorCode; + private: + int _errorCode; }; -class FileException : public ANNException { -public: - DISKANN_DLLEXPORT FileException(const std::string &filename, - std::system_error &e, - const std::string &funcSig, - const std::string &fileName, - uint32_t lineNum); +class FileException : public ANNException +{ + public: + DISKANN_DLLEXPORT FileException(const std::string &filename, std::system_error &e, const std::string &funcSig, + const std::string &fileName, uint32_t lineNum); }; } // namespace diskann diff --git a/include/any_wrappers.h b/include/any_wrappers.h index 53d013762..f35ac947c 100644 --- a/include/any_wrappers.h +++ b/include/any_wrappers.h @@ -9,34 +9,45 @@ #include #include -namespace AnyWrapper { +namespace AnyWrapper +{ /* * Base Struct to hold refrence to the data. * Note: No memory mamagement, caller need to keep object alive. */ -struct AnyReference { - template AnyReference(Ty &reference) : _data(&reference) {} +struct AnyReference +{ + template AnyReference(Ty &reference) : _data(&reference) + { + } - template Ty &get() { - auto ptr = std::any_cast(_data); - return *ptr; - } + template Ty &get() + { + auto ptr = std::any_cast(_data); + return *ptr; + } -private: - std::any _data; + private: + std::any _data; }; -struct AnyRobinSet : public AnyReference { - template - AnyRobinSet(const tsl::robin_set &robin_set) : AnyReference(robin_set) {} - template - AnyRobinSet(tsl::robin_set &robin_set) : AnyReference(robin_set) {} +struct AnyRobinSet : public AnyReference +{ + template AnyRobinSet(const tsl::robin_set &robin_set) : AnyReference(robin_set) + { + } + template AnyRobinSet(tsl::robin_set &robin_set) : AnyReference(robin_set) + { + } }; -struct AnyVector : public AnyReference { - template - AnyVector(const std::vector &vector) : AnyReference(vector) {} - template - AnyVector(std::vector &vector) : AnyReference(vector) {} +struct AnyVector : public AnyReference +{ + template AnyVector(const std::vector &vector) : AnyReference(vector) + { + } + template AnyVector(std::vector &vector) : AnyReference(vector) + { + } }; } // namespace AnyWrapper diff --git a/include/boost_dynamic_bitset_fwd.h b/include/boost_dynamic_bitset_fwd.h index db1a05624..5aebb2bc2 100644 --- a/include/boost_dynamic_bitset_fwd.h +++ b/include/boost_dynamic_bitset_fwd.h @@ -3,10 +3,9 @@ #pragma once -namespace boost { +namespace boost +{ #ifndef BOOST_DYNAMIC_BITSET_FWD_HPP -template > -class dynamic_bitset; +template > class dynamic_bitset; #endif } // namespace boost diff --git a/include/cached_io.h b/include/cached_io.h index 921283356..dabe448dc 100644 --- a/include/cached_io.h +++ b/include/cached_io.h @@ -12,176 +12,207 @@ #include "logger.h" // sequential cached reads -class cached_ifstream { -public: - cached_ifstream() {} - cached_ifstream(const std::string &filename, uint64_t cacheSize) - : cache_size(cacheSize), cur_off(0) { - reader.exceptions(std::ifstream::failbit | std::ifstream::badbit); - this->open(filename, cache_size); - } - ~cached_ifstream() { - delete[] cache_buf; - reader.close(); - } - - void open(const std::string &filename, uint64_t cacheSize) { - this->cur_off = 0; - - try { - reader.open(filename, std::ios::binary | std::ios::ate); - fsize = reader.tellg(); 
- reader.seekg(0, std::ios::beg); - assert(reader.is_open()); - assert(cacheSize > 0); - cacheSize = (std::min)(cacheSize, fsize); - this->cache_size = cacheSize; - cache_buf = new char[cacheSize]; - reader.read(cache_buf, cacheSize); - diskann::cout << "Opened: " << filename.c_str() << ", size: " << fsize - << ", cache_size: " << cacheSize << std::endl; - } catch (std::system_error &e) { - throw diskann::FileException(filename, e, __FUNCSIG__, __FILE__, - __LINE__); +class cached_ifstream +{ + public: + cached_ifstream() + { } - } - - size_t get_file_size() { return fsize; } - - void read(char *read_buf, uint64_t n_bytes) { - assert(cache_buf != nullptr); - assert(read_buf != nullptr); - - if (n_bytes <= (cache_size - cur_off)) { - // case 1: cache contains all data - memcpy(read_buf, cache_buf + cur_off, n_bytes); - cur_off += n_bytes; - } else { - // case 2: cache contains some data - uint64_t cached_bytes = cache_size - cur_off; - if (n_bytes - cached_bytes > fsize - reader.tellg()) { - std::stringstream stream; - stream << "Reading beyond end of file" << std::endl; - stream << "n_bytes: " << n_bytes << " cached_bytes: " << cached_bytes - << " fsize: " << fsize << " current pos:" << reader.tellg() - << std::endl; - diskann::cout << stream.str() << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - memcpy(read_buf, cache_buf + cur_off, cached_bytes); - - // go to disk and fetch more data - reader.read(read_buf + cached_bytes, n_bytes - cached_bytes); - // reset cur off - cur_off = cache_size; - - uint64_t size_left = fsize - reader.tellg(); - - if (size_left >= cache_size) { - reader.read(cache_buf, cache_size); - cur_off = 0; - } - // note that if size_left < cache_size, then cur_off = cache_size, - // so subsequent reads will all be directly from file + cached_ifstream(const std::string &filename, uint64_t cacheSize) : cache_size(cacheSize), cur_off(0) + { + reader.exceptions(std::ifstream::failbit | std::ifstream::badbit); + this->open(filename, cache_size); + } + ~cached_ifstream() + { + delete[] cache_buf; + reader.close(); + } + + void open(const std::string &filename, uint64_t cacheSize) + { + this->cur_off = 0; + + try + { + reader.open(filename, std::ios::binary | std::ios::ate); + fsize = reader.tellg(); + reader.seekg(0, std::ios::beg); + assert(reader.is_open()); + assert(cacheSize > 0); + cacheSize = (std::min)(cacheSize, fsize); + this->cache_size = cacheSize; + cache_buf = new char[cacheSize]; + reader.read(cache_buf, cacheSize); + diskann::cout << "Opened: " << filename.c_str() << ", size: " << fsize << ", cache_size: " << cacheSize + << std::endl; + } + catch (std::system_error &e) + { + throw diskann::FileException(filename, e, __FUNCSIG__, __FILE__, __LINE__); + } + } + + size_t get_file_size() + { + return fsize; + } + + void read(char *read_buf, uint64_t n_bytes) + { + assert(cache_buf != nullptr); + assert(read_buf != nullptr); + + if (n_bytes <= (cache_size - cur_off)) + { + // case 1: cache contains all data + memcpy(read_buf, cache_buf + cur_off, n_bytes); + cur_off += n_bytes; + } + else + { + // case 2: cache contains some data + uint64_t cached_bytes = cache_size - cur_off; + if (n_bytes - cached_bytes > fsize - reader.tellg()) + { + std::stringstream stream; + stream << "Reading beyond end of file" << std::endl; + stream << "n_bytes: " << n_bytes << " cached_bytes: " << cached_bytes << " fsize: " << fsize + << " current pos:" << reader.tellg() << std::endl; + diskann::cout << stream.str() << std::endl; + 
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + memcpy(read_buf, cache_buf + cur_off, cached_bytes); + + // go to disk and fetch more data + reader.read(read_buf + cached_bytes, n_bytes - cached_bytes); + // reset cur off + cur_off = cache_size; + + uint64_t size_left = fsize - reader.tellg(); + + if (size_left >= cache_size) + { + reader.read(cache_buf, cache_size); + cur_off = 0; + } + // note that if size_left < cache_size, then cur_off = cache_size, + // so subsequent reads will all be directly from file + } } - } - -private: - // underlying ifstream - std::ifstream reader; - // # bytes to cache in one shot read - uint64_t cache_size = 0; - // underlying buf for cache - char *cache_buf = nullptr; - // offset into cache_buf for cur_pos - uint64_t cur_off = 0; - // file size - uint64_t fsize = 0; + + private: + // underlying ifstream + std::ifstream reader; + // # bytes to cache in one shot read + uint64_t cache_size = 0; + // underlying buf for cache + char *cache_buf = nullptr; + // offset into cache_buf for cur_pos + uint64_t cur_off = 0; + // file size + uint64_t fsize = 0; }; // sequential cached writes -class cached_ofstream { -public: - cached_ofstream(const std::string &filename, uint64_t cache_size) - : cache_size(cache_size), cur_off(0) { - writer.exceptions(std::ifstream::failbit | std::ifstream::badbit); - try { - writer.open(filename, std::ios::binary); - assert(writer.is_open()); - assert(cache_size > 0); - cache_buf = new char[cache_size]; - diskann::cout << "Opened: " << filename.c_str() - << ", cache_size: " << cache_size << std::endl; - } catch (std::system_error &e) { - throw diskann::FileException(filename, e, __FUNCSIG__, __FILE__, - __LINE__); +class cached_ofstream +{ + public: + cached_ofstream(const std::string &filename, uint64_t cache_size) : cache_size(cache_size), cur_off(0) + { + writer.exceptions(std::ifstream::failbit | std::ifstream::badbit); + try + { + writer.open(filename, std::ios::binary); + assert(writer.is_open()); + assert(cache_size > 0); + cache_buf = new char[cache_size]; + diskann::cout << "Opened: " << filename.c_str() << ", cache_size: " << cache_size << std::endl; + } + catch (std::system_error &e) + { + throw diskann::FileException(filename, e, __FUNCSIG__, __FILE__, __LINE__); + } } - } - ~cached_ofstream() { this->close(); } + ~cached_ofstream() + { + this->close(); + } - void close() { - // dump any remaining data in memory - if (cur_off > 0) { - this->flush_cache(); + void close() + { + // dump any remaining data in memory + if (cur_off > 0) + { + this->flush_cache(); + } + + if (cache_buf != nullptr) + { + delete[] cache_buf; + cache_buf = nullptr; + } + + if (writer.is_open()) + writer.close(); + diskann::cout << "Finished writing " << fsize << "B" << std::endl; } - if (cache_buf != nullptr) { - delete[] cache_buf; - cache_buf = nullptr; + size_t get_file_size() + { + return fsize; + } + // writes n_bytes from write_buf to the underlying ofstream/cache + void write(char *write_buf, uint64_t n_bytes) + { + assert(cache_buf != nullptr); + if (n_bytes <= (cache_size - cur_off)) + { + // case 1: cache can take all data + memcpy(cache_buf + cur_off, write_buf, n_bytes); + cur_off += n_bytes; + } + else + { + // case 2: cache cant take all data + // go to disk and write existing cache data + writer.write(cache_buf, cur_off); + fsize += cur_off; + // write the new data to disk + writer.write(write_buf, n_bytes); + fsize += n_bytes; + // memset all cache data and reset cur_off + memset(cache_buf, 0, 
cache_size); + cur_off = 0; + } } - if (writer.is_open()) - writer.close(); - diskann::cout << "Finished writing " << fsize << "B" << std::endl; - } - - size_t get_file_size() { return fsize; } - // writes n_bytes from write_buf to the underlying ofstream/cache - void write(char *write_buf, uint64_t n_bytes) { - assert(cache_buf != nullptr); - if (n_bytes <= (cache_size - cur_off)) { - // case 1: cache can take all data - memcpy(cache_buf + cur_off, write_buf, n_bytes); - cur_off += n_bytes; - } else { - // case 2: cache cant take all data - // go to disk and write existing cache data - writer.write(cache_buf, cur_off); - fsize += cur_off; - // write the new data to disk - writer.write(write_buf, n_bytes); - fsize += n_bytes; - // memset all cache data and reset cur_off - memset(cache_buf, 0, cache_size); - cur_off = 0; + void flush_cache() + { + assert(cache_buf != nullptr); + writer.write(cache_buf, cur_off); + fsize += cur_off; + memset(cache_buf, 0, cache_size); + cur_off = 0; } - } - - void flush_cache() { - assert(cache_buf != nullptr); - writer.write(cache_buf, cur_off); - fsize += cur_off; - memset(cache_buf, 0, cache_size); - cur_off = 0; - } - - void reset() { - flush_cache(); - writer.seekp(0); - } - -private: - // underlying ofstream - std::ofstream writer; - // # bytes to cache for one shot write - uint64_t cache_size = 0; - // underlying buf for cache - char *cache_buf = nullptr; - // offset into cache_buf for cur_pos - uint64_t cur_off = 0; - - // file size - uint64_t fsize = 0; + + void reset() + { + flush_cache(); + writer.seekp(0); + } + + private: + // underlying ofstream + std::ofstream writer; + // # bytes to cache for one shot write + uint64_t cache_size = 0; + // underlying buf for cache + char *cache_buf = nullptr; + // offset into cache_buf for cur_pos + uint64_t cur_off = 0; + + // file size + uint64_t fsize = 0; }; diff --git a/include/concurrent_queue.h b/include/concurrent_queue.h index b5e16a2cf..1e57bbf0f 100644 --- a/include/concurrent_queue.h +++ b/include/concurrent_queue.h @@ -11,90 +11,122 @@ #include #include -namespace diskann { +namespace diskann +{ -template class ConcurrentQueue { - typedef std::chrono::microseconds chrono_us_t; - typedef std::unique_lock mutex_locker; +template class ConcurrentQueue +{ + typedef std::chrono::microseconds chrono_us_t; + typedef std::unique_lock mutex_locker; - std::queue q; - std::mutex mut; - std::mutex push_mut; - std::mutex pop_mut; - std::condition_variable push_cv; - std::condition_variable pop_cv; - T null_T; + std::queue q; + std::mutex mut; + std::mutex push_mut; + std::mutex pop_mut; + std::condition_variable push_cv; + std::condition_variable pop_cv; + T null_T; -public: - ConcurrentQueue() {} + public: + ConcurrentQueue() + { + } - ConcurrentQueue(T nullT) { this->null_T = nullT; } + ConcurrentQueue(T nullT) + { + this->null_T = nullT; + } - ~ConcurrentQueue() { - this->push_cv.notify_all(); - this->pop_cv.notify_all(); - } + ~ConcurrentQueue() + { + this->push_cv.notify_all(); + this->pop_cv.notify_all(); + } - // queue stats - uint64_t size() { - mutex_locker lk(this->mut); - uint64_t ret = q.size(); - lk.unlock(); - return ret; - } + // queue stats + uint64_t size() + { + mutex_locker lk(this->mut); + uint64_t ret = q.size(); + lk.unlock(); + return ret; + } - bool empty() { return (this->size() == 0); } + bool empty() + { + return (this->size() == 0); + } - // PUSH BACK - void push(T &new_val) { - mutex_locker lk(this->mut); - this->q.push(new_val); - lk.unlock(); - } + // PUSH BACK + void push(T 
&new_val) + { + mutex_locker lk(this->mut); + this->q.push(new_val); + lk.unlock(); + } - template - void insert(Iterator iter_begin, Iterator iter_end) { - mutex_locker lk(this->mut); - for (Iterator it = iter_begin; it != iter_end; it++) { - this->q.push(*it); + template void insert(Iterator iter_begin, Iterator iter_end) + { + mutex_locker lk(this->mut); + for (Iterator it = iter_begin; it != iter_end; it++) + { + this->q.push(*it); + } + lk.unlock(); } - lk.unlock(); - } - // POP FRONT - T pop() { - mutex_locker lk(this->mut); - if (this->q.empty()) { - lk.unlock(); - return this->null_T; - } else { - T ret = this->q.front(); - this->q.pop(); - // diskann::cout << "thread_id: " << std::this_thread::get_id() << - // ", ctx: " - // << ret.ctx << "\n"; - lk.unlock(); - return ret; + // POP FRONT + T pop() + { + mutex_locker lk(this->mut); + if (this->q.empty()) + { + lk.unlock(); + return this->null_T; + } + else + { + T ret = this->q.front(); + this->q.pop(); + // diskann::cout << "thread_id: " << std::this_thread::get_id() << + // ", ctx: " + // << ret.ctx << "\n"; + lk.unlock(); + return ret; + } } - } - // register for notifications - void wait_for_push_notify(chrono_us_t wait_time = chrono_us_t{10}) { - mutex_locker lk(this->push_mut); - this->push_cv.wait_for(lk, wait_time); - lk.unlock(); - } + // register for notifications + void wait_for_push_notify(chrono_us_t wait_time = chrono_us_t{10}) + { + mutex_locker lk(this->push_mut); + this->push_cv.wait_for(lk, wait_time); + lk.unlock(); + } - void wait_for_pop_notify(chrono_us_t wait_time = chrono_us_t{10}) { - mutex_locker lk(this->pop_mut); - this->pop_cv.wait_for(lk, wait_time); - lk.unlock(); - } + void wait_for_pop_notify(chrono_us_t wait_time = chrono_us_t{10}) + { + mutex_locker lk(this->pop_mut); + this->pop_cv.wait_for(lk, wait_time); + lk.unlock(); + } - // just notify functions - void push_notify_one() { this->push_cv.notify_one(); } - void push_notify_all() { this->push_cv.notify_all(); } - void pop_notify_one() { this->pop_cv.notify_one(); } - void pop_notify_all() { this->pop_cv.notify_all(); } + // just notify functions + void push_notify_one() + { + this->push_cv.notify_one(); + } + void push_notify_all() + { + this->push_cv.notify_all(); + } + void pop_notify_one() + { + this->pop_cv.notify_one(); + } + void pop_notify_all() + { + this->pop_cv.notify_all(); + } }; } // namespace diskann diff --git a/include/cosine_similarity.h b/include/cosine_similarity.h index 5b37c6cc5..af62eb53b 100644 --- a/include/cosine_similarity.h +++ b/include/cosine_similarity.h @@ -38,194 +38,202 @@ extern bool Avx2SupportedCPU; * */ -namespace diskann { +namespace diskann +{ using namespace std; #define PORTABLE_ALIGN16 __declspec(align(16)) -static float NormScalarProductSIMD2(const int8_t *pVect1, const int8_t *pVect2, - uint32_t qty) { - if (Avx2SupportedCPU) { - __m256 cos, p1Len, p2Len; - cos = p1Len = p2Len = _mm256_setzero_ps(); - while (qty >= 32) { - __m256i rx = _mm256_load_si256((__m256i *)pVect1), - ry = _mm256_load_si256((__m256i *)pVect2); - cos = _mm256_add_ps(cos, _mm256_mul_epi8(rx, ry)); - p1Len = _mm256_add_ps(p1Len, _mm256_mul_epi8(rx, rx)); - p2Len = _mm256_add_ps(p2Len, _mm256_mul_epi8(ry, ry)); - pVect1 += 32; - pVect2 += 32; - qty -= 32; +static float NormScalarProductSIMD2(const int8_t *pVect1, const int8_t *pVect2, uint32_t qty) +{ + if (Avx2SupportedCPU) + { + __m256 cos, p1Len, p2Len; + cos = p1Len = p2Len = _mm256_setzero_ps(); + while (qty >= 32) + { + __m256i rx = _mm256_load_si256((__m256i *)pVect1), ry = 
_mm256_load_si256((__m256i *)pVect2); + cos = _mm256_add_ps(cos, _mm256_mul_epi8(rx, ry)); + p1Len = _mm256_add_ps(p1Len, _mm256_mul_epi8(rx, rx)); + p2Len = _mm256_add_ps(p2Len, _mm256_mul_epi8(ry, ry)); + pVect1 += 32; + pVect2 += 32; + qty -= 32; + } + while (qty > 0) + { + __m128i rx = _mm_load_si128((__m128i *)pVect1), ry = _mm_load_si128((__m128i *)pVect2); + cos = _mm256_add_ps(cos, _mm256_mul32_pi8(rx, ry)); + p1Len = _mm256_add_ps(p1Len, _mm256_mul32_pi8(rx, rx)); + p2Len = _mm256_add_ps(p2Len, _mm256_mul32_pi8(ry, ry)); + pVect1 += 4; + pVect2 += 4; + qty -= 4; + } + cos = _mm256_hadd_ps(_mm256_hadd_ps(cos, cos), cos); + p1Len = _mm256_hadd_ps(_mm256_hadd_ps(p1Len, p1Len), p1Len); + p2Len = _mm256_hadd_ps(_mm256_hadd_ps(p2Len, p2Len), p2Len); + float denominator = max(numeric_limits::min() * 2, sqrt(p1Len.m256_f32[0] + p1Len.m256_f32[4]) * + sqrt(p2Len.m256_f32[0] + p2Len.m256_f32[4])); + float cosine = (cos.m256_f32[0] + cos.m256_f32[4]) / denominator; + + return max(float(-1), min(float(1), cosine)); } - while (qty > 0) { - __m128i rx = _mm_load_si128((__m128i *)pVect1), - ry = _mm_load_si128((__m128i *)pVect2); - cos = _mm256_add_ps(cos, _mm256_mul32_pi8(rx, ry)); - p1Len = _mm256_add_ps(p1Len, _mm256_mul32_pi8(rx, rx)); - p2Len = _mm256_add_ps(p2Len, _mm256_mul32_pi8(ry, ry)); - pVect1 += 4; - pVect2 += 4; - qty -= 4; + + __m128 cos, p1Len, p2Len; + cos = p1Len = p2Len = _mm_setzero_ps(); + __m128i rx, ry; + while (qty >= 16) + { + rx = _mm_load_si128((__m128i *)pVect1); + ry = _mm_load_si128((__m128i *)pVect2); + cos = _mm_add_ps(cos, _mm_mul_epi8(rx, ry)); + p1Len = _mm_add_ps(p1Len, _mm_mul_epi8(rx, rx)); + p2Len = _mm_add_ps(p2Len, _mm_mul_epi8(ry, ry)); + pVect1 += 16; + pVect2 += 16; + qty -= 16; + } + while (qty > 0) + { + rx = _mm_load_si128((__m128i *)pVect1); + ry = _mm_load_si128((__m128i *)pVect2); + cos = _mm_add_ps(cos, _mm_mul32_pi8(rx, ry)); + p1Len = _mm_add_ps(p1Len, _mm_mul32_pi8(rx, rx)); + p2Len = _mm_add_ps(p2Len, _mm_mul32_pi8(ry, ry)); + pVect1 += 4; + pVect2 += 4; + qty -= 4; } - cos = _mm256_hadd_ps(_mm256_hadd_ps(cos, cos), cos); - p1Len = _mm256_hadd_ps(_mm256_hadd_ps(p1Len, p1Len), p1Len); - p2Len = _mm256_hadd_ps(_mm256_hadd_ps(p2Len, p2Len), p2Len); - float denominator = max(numeric_limits::min() * 2, - sqrt(p1Len.m256_f32[0] + p1Len.m256_f32[4]) * - sqrt(p2Len.m256_f32[0] + p2Len.m256_f32[4])); - float cosine = (cos.m256_f32[0] + cos.m256_f32[4]) / denominator; - - return max(float(-1), min(float(1), cosine)); - } - - __m128 cos, p1Len, p2Len; - cos = p1Len = p2Len = _mm_setzero_ps(); - __m128i rx, ry; - while (qty >= 16) { - rx = _mm_load_si128((__m128i *)pVect1); - ry = _mm_load_si128((__m128i *)pVect2); - cos = _mm_add_ps(cos, _mm_mul_epi8(rx, ry)); - p1Len = _mm_add_ps(p1Len, _mm_mul_epi8(rx, rx)); - p2Len = _mm_add_ps(p2Len, _mm_mul_epi8(ry, ry)); - pVect1 += 16; - pVect2 += 16; - qty -= 16; - } - while (qty > 0) { - rx = _mm_load_si128((__m128i *)pVect1); - ry = _mm_load_si128((__m128i *)pVect2); - cos = _mm_add_ps(cos, _mm_mul32_pi8(rx, ry)); - p1Len = _mm_add_ps(p1Len, _mm_mul32_pi8(rx, rx)); - p2Len = _mm_add_ps(p2Len, _mm_mul32_pi8(ry, ry)); - pVect1 += 4; - pVect2 += 4; - qty -= 4; - } - cos = _mm_hadd_ps(_mm_hadd_ps(cos, cos), cos); - p1Len = _mm_hadd_ps(_mm_hadd_ps(p1Len, p1Len), p1Len); - p2Len = _mm_hadd_ps(_mm_hadd_ps(p2Len, p2Len), p2Len); - float norm1 = p1Len.m128_f32[0]; - float norm2 = p2Len.m128_f32[0]; - - static const float eps = numeric_limits::min() * 2; - - if (norm1 < eps) { /* - * This shouldn't normally happen for 
this space, but - * if it does, we don't want to get NANs - */ - if (norm2 < eps) { - return 1; + cos = _mm_hadd_ps(_mm_hadd_ps(cos, cos), cos); + p1Len = _mm_hadd_ps(_mm_hadd_ps(p1Len, p1Len), p1Len); + p2Len = _mm_hadd_ps(_mm_hadd_ps(p2Len, p2Len), p2Len); + float norm1 = p1Len.m128_f32[0]; + float norm2 = p2Len.m128_f32[0]; + + static const float eps = numeric_limits::min() * 2; + + if (norm1 < eps) + { /* + * This shouldn't normally happen for this space, but + * if it does, we don't want to get NANs + */ + if (norm2 < eps) + { + return 1; + } + return 0; } - return 0; - } - /* - * Sometimes due to rounding errors, we get values > 1 or < -1. - * This throws off other functions that use scalar product, e.g., acos - */ - return max(float(-1), - min(float(1), cos.m128_f32[0] / sqrt(norm1) / sqrt(norm2))); + /* + * Sometimes due to rounding errors, we get values > 1 or < -1. + * This throws off other functions that use scalar product, e.g., acos + */ + return max(float(-1), min(float(1), cos.m128_f32[0] / sqrt(norm1) / sqrt(norm2))); } -static float NormScalarProductSIMD(const float *pVect1, const float *pVect2, - uint32_t qty) { - // Didn't get significant performance gain compared with 128bit version. - static const float eps = numeric_limits::min() * 2; - - if (Avx2SupportedCPU) { - uint32_t qty8 = qty / 8; - - const float *pEnd1 = pVect1 + 8 * qty8; - const float *pEnd2 = pVect1 + qty; - - __m256 v1, v2; - __m256 sum_prod = _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, 0); - __m256 sum_square1 = sum_prod; - __m256 sum_square2 = sum_prod; - - while (pVect1 < pEnd1) { - v1 = _mm256_loadu_ps(pVect1); - pVect1 += 8; - v2 = _mm256_loadu_ps(pVect2); - pVect2 += 8; - sum_prod = _mm256_add_ps(sum_prod, _mm256_mul_ps(v1, v2)); - sum_square1 = _mm256_add_ps(sum_square1, _mm256_mul_ps(v1, v1)); - sum_square2 = _mm256_add_ps(sum_square2, _mm256_mul_ps(v2, v2)); +static float NormScalarProductSIMD(const float *pVect1, const float *pVect2, uint32_t qty) +{ + // Didn't get significant performance gain compared with 128bit version. + static const float eps = numeric_limits::min() * 2; + + if (Avx2SupportedCPU) + { + uint32_t qty8 = qty / 8; + + const float *pEnd1 = pVect1 + 8 * qty8; + const float *pEnd2 = pVect1 + qty; + + __m256 v1, v2; + __m256 sum_prod = _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, 0); + __m256 sum_square1 = sum_prod; + __m256 sum_square2 = sum_prod; + + while (pVect1 < pEnd1) + { + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum_prod = _mm256_add_ps(sum_prod, _mm256_mul_ps(v1, v2)); + sum_square1 = _mm256_add_ps(sum_square1, _mm256_mul_ps(v1, v1)); + sum_square2 = _mm256_add_ps(sum_square2, _mm256_mul_ps(v2, v2)); + } + + float PORTABLE_ALIGN16 TmpResProd[8]; + float PORTABLE_ALIGN16 TmpResSquare1[8]; + float PORTABLE_ALIGN16 TmpResSquare2[8]; + + _mm256_store_ps(TmpResProd, sum_prod); + _mm256_store_ps(TmpResSquare1, sum_square1); + _mm256_store_ps(TmpResSquare2, sum_square2); + + float sum = 0.0f; + float norm1 = 0.0f; + float norm2 = 0.0f; + for (uint32_t i = 0; i < 8; ++i) + { + sum += TmpResProd[i]; + norm1 += TmpResSquare1[i]; + norm2 += TmpResSquare2[i]; + } + + while (pVect1 < pEnd2) + { + sum += (*pVect1) * (*pVect2); + norm1 += (*pVect1) * (*pVect1); + norm2 += (*pVect2) * (*pVect2); + + ++pVect1; + ++pVect2; + } + + if (norm1 < eps) + { + return norm2 < eps ? 
1.0f : 0.0f; + } + + return max(float(-1), min(float(1), sum / sqrt(norm1) / sqrt(norm2))); } - float PORTABLE_ALIGN16 TmpResProd[8]; - float PORTABLE_ALIGN16 TmpResSquare1[8]; - float PORTABLE_ALIGN16 TmpResSquare2[8]; - - _mm256_store_ps(TmpResProd, sum_prod); - _mm256_store_ps(TmpResSquare1, sum_square1); - _mm256_store_ps(TmpResSquare2, sum_square2); - - float sum = 0.0f; - float norm1 = 0.0f; - float norm2 = 0.0f; - for (uint32_t i = 0; i < 8; ++i) { - sum += TmpResProd[i]; - norm1 += TmpResSquare1[i]; - norm2 += TmpResSquare2[i]; + __m128 v1, v2; + __m128 sum_prod = _mm_set1_ps(0); + __m128 sum_square1 = sum_prod; + __m128 sum_square2 = sum_prod; + + while (qty >= 4) + { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + sum_square1 = _mm_add_ps(sum_square1, _mm_mul_ps(v1, v1)); + sum_square2 = _mm_add_ps(sum_square2, _mm_mul_ps(v2, v2)); + + qty -= 4; } - while (pVect1 < pEnd2) { - sum += (*pVect1) * (*pVect2); - norm1 += (*pVect1) * (*pVect1); - norm2 += (*pVect2) * (*pVect2); - - ++pVect1; - ++pVect2; - } + float sum = sum_prod.m128_f32[0] + sum_prod.m128_f32[1] + sum_prod.m128_f32[2] + sum_prod.m128_f32[3]; + float norm1 = sum_square1.m128_f32[0] + sum_square1.m128_f32[1] + sum_square1.m128_f32[2] + sum_square1.m128_f32[3]; + float norm2 = sum_square2.m128_f32[0] + sum_square2.m128_f32[1] + sum_square2.m128_f32[2] + sum_square2.m128_f32[3]; - if (norm1 < eps) { - return norm2 < eps ? 1.0f : 0.0f; + if (norm1 < eps) + { + return norm2 < eps ? 1.0f : 0.0f; } return max(float(-1), min(float(1), sum / sqrt(norm1) / sqrt(norm2))); - } - - __m128 v1, v2; - __m128 sum_prod = _mm_set1_ps(0); - __m128 sum_square1 = sum_prod; - __m128 sum_square2 = sum_prod; - - while (qty >= 4) { - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - sum_square1 = _mm_add_ps(sum_square1, _mm_mul_ps(v1, v1)); - sum_square2 = _mm_add_ps(sum_square2, _mm_mul_ps(v2, v2)); - - qty -= 4; - } - - float sum = sum_prod.m128_f32[0] + sum_prod.m128_f32[1] + - sum_prod.m128_f32[2] + sum_prod.m128_f32[3]; - float norm1 = sum_square1.m128_f32[0] + sum_square1.m128_f32[1] + - sum_square1.m128_f32[2] + sum_square1.m128_f32[3]; - float norm2 = sum_square2.m128_f32[0] + sum_square2.m128_f32[1] + - sum_square2.m128_f32[2] + sum_square2.m128_f32[3]; - - if (norm1 < eps) { - return norm2 < eps ? 
1.0f : 0.0f; - } - - return max(float(-1), min(float(1), sum / sqrt(norm1) / sqrt(norm2))); } -static float NormScalarProductSIMD2(const float *pVect1, const float *pVect2, - uint32_t qty) { - return NormScalarProductSIMD(pVect1, pVect2, qty); +static float NormScalarProductSIMD2(const float *pVect1, const float *pVect2, uint32_t qty) +{ + return NormScalarProductSIMD(pVect1, pVect2, qty); } -template -static float CosineSimilarity2(const T *p1, const T *p2, uint32_t qty) { - return std::max(0.0f, 1.0f - NormScalarProductSIMD2(p1, p2, qty)); +template static float CosineSimilarity2(const T *p1, const T *p2, uint32_t qty) +{ + return std::max(0.0f, 1.0f - NormScalarProductSIMD2(p1, p2, qty)); } // static template float CosineSimilarity2<__int8>(const __int8* pVect1, @@ -234,19 +242,22 @@ static float CosineSimilarity2(const T *p1, const T *p2, uint32_t qty) { // static template float CosineSimilarity2(const float* pVect1, // const float* pVect2, size_t qty); -template -static void CosineSimilarityNormalize(T *pVector, uint32_t qty) { - T sum = 0; - for (uint32_t i = 0; i < qty; ++i) { - sum += pVector[i] * pVector[i]; - } - sum = 1 / sqrt(sum); - if (sum == 0) { - sum = numeric_limits::min(); - } - for (uint32_t i = 0; i < qty; ++i) { - pVector[i] *= sum; - } +template static void CosineSimilarityNormalize(T *pVector, uint32_t qty) +{ + T sum = 0; + for (uint32_t i = 0; i < qty; ++i) + { + sum += pVector[i] * pVector[i]; + } + sum = 1 / sqrt(sum); + if (sum == 0) + { + sum = numeric_limits::min(); + } + for (uint32_t i = 0; i < qty; ++i) + { + pVector[i] *= sum; + } } // template static void CosineSimilarityNormalize(float* pVector, @@ -254,22 +265,19 @@ static void CosineSimilarityNormalize(T *pVector, uint32_t qty) { // template static void CosineSimilarityNormalize(double* pVector, // size_t qty); -template <> -void CosineSimilarityNormalize(__int8 * /*pVector*/, uint32_t /*qty*/) { - throw std::runtime_error( - "For int8 type vector, you can not use cosine distance!"); +template <> void CosineSimilarityNormalize(__int8 * /*pVector*/, uint32_t /*qty*/) +{ + throw std::runtime_error("For int8 type vector, you can not use cosine distance!"); } -template <> -void CosineSimilarityNormalize(__int16 * /*pVector*/, uint32_t /*qty*/) { - throw std::runtime_error( - "For int16 type vector, you can not use cosine distance!"); +template <> void CosineSimilarityNormalize(__int16 * /*pVector*/, uint32_t /*qty*/) +{ + throw std::runtime_error("For int16 type vector, you can not use cosine distance!"); } -template <> -void CosineSimilarityNormalize(int * /*pVector*/, uint32_t /*qty*/) { - throw std::runtime_error( - "For int type vector, you can not use cosine distance!"); +template <> void CosineSimilarityNormalize(int * /*pVector*/, uint32_t /*qty*/) +{ + throw std::runtime_error("For int type vector, you can not use cosine distance!"); } } // namespace diskann #endif \ No newline at end of file diff --git a/include/defaults.h b/include/defaults.h index 6ba9f9d07..ef1750fcf 100644 --- a/include/defaults.h +++ b/include/defaults.h @@ -4,8 +4,10 @@ #pragma once #include -namespace diskann { -namespace defaults { +namespace diskann +{ +namespace defaults +{ const float ALPHA = 1.2f; const uint32_t NUM_THREADS = 0; const uint32_t MAX_OCCLUSION_SIZE = 750; diff --git a/include/disk_utils.h b/include/disk_utils.h index 426042bb3..1acb7f981 100644 --- a/include/disk_utils.h +++ b/include/disk_utils.h @@ -31,7 +31,8 @@ typedef int FileHandle; #include "utils.h" #include "windows_customizations.h" -namespace 
diskann { +namespace diskann +{ const size_t MAX_SAMPLE_POINTS_FOR_WARMUP = 100000; const double PQ_TRAINING_SET_FRACTION = 0.1; const double SPACE_FOR_CACHED_NODES_IN_GB = 0.25; @@ -44,84 +45,64 @@ template class PQFlashIndex; DISKANN_DLLEXPORT double get_memory_budget(const std::string &mem_budget_str); DISKANN_DLLEXPORT double get_memory_budget(double search_ram_budget_in_gb); -DISKANN_DLLEXPORT void add_new_file_to_single_index(std::string index_file, - std::string new_file); +DISKANN_DLLEXPORT void add_new_file_to_single_index(std::string index_file, std::string new_file); -DISKANN_DLLEXPORT size_t calculate_num_pq_chunks(double final_index_ram_limit, - size_t points_num, - uint32_t dim); +DISKANN_DLLEXPORT size_t calculate_num_pq_chunks(double final_index_ram_limit, size_t points_num, uint32_t dim); -DISKANN_DLLEXPORT void read_idmap(const std::string &fname, - std::vector &ivecs); +DISKANN_DLLEXPORT void read_idmap(const std::string &fname, std::vector &ivecs); #ifdef EXEC_ENV_OLS template -DISKANN_DLLEXPORT T *load_warmup(MemoryMappedFiles &files, - const std::string &cache_warmup_file, - uint64_t &warmup_num, uint64_t warmup_dim, - uint64_t warmup_aligned_dim); +DISKANN_DLLEXPORT T *load_warmup(MemoryMappedFiles &files, const std::string &cache_warmup_file, uint64_t &warmup_num, + uint64_t warmup_dim, uint64_t warmup_aligned_dim); #else template -DISKANN_DLLEXPORT T *load_warmup(const std::string &cache_warmup_file, - uint64_t &warmup_num, uint64_t warmup_dim, +DISKANN_DLLEXPORT T *load_warmup(const std::string &cache_warmup_file, uint64_t &warmup_num, uint64_t warmup_dim, uint64_t warmup_aligned_dim); #endif -DISKANN_DLLEXPORT int -merge_shards(const std::string &vamana_prefix, const std::string &vamana_suffix, - const std::string &idmaps_prefix, const std::string &idmaps_suffix, - const uint64_t nshards, uint32_t max_degree, - const std::string &output_vamana, const std::string &medoids_file, - bool use_filters = false, - const std::string &labels_to_medoids_file = std::string("")); +DISKANN_DLLEXPORT int merge_shards(const std::string &vamana_prefix, const std::string &vamana_suffix, + const std::string &idmaps_prefix, const std::string &idmaps_suffix, + const uint64_t nshards, uint32_t max_degree, const std::string &output_vamana, + const std::string &medoids_file, bool use_filters = false, + const std::string &labels_to_medoids_file = std::string("")); -DISKANN_DLLEXPORT void -extract_shard_labels(const std::string &in_label_file, - const std::string &shard_ids_bin, - const std::string &shard_label_file); +DISKANN_DLLEXPORT void extract_shard_labels(const std::string &in_label_file, const std::string &shard_ids_bin, + const std::string &shard_label_file); template -DISKANN_DLLEXPORT std::string -preprocess_base_file(const std::string &infile, const std::string &indexPrefix, - diskann::Metric &distMetric); +DISKANN_DLLEXPORT std::string preprocess_base_file(const std::string &infile, const std::string &indexPrefix, + diskann::Metric &distMetric); template -DISKANN_DLLEXPORT int build_merged_vamana_index( - std::string base_file, diskann::Metric _compareMetric, uint32_t L, - uint32_t R, double sampling_rate, double ram_budget, - std::string mem_index_path, std::string medoids_file, - std::string centroids_file, size_t build_pq_bytes, bool use_opq, - uint32_t num_threads, bool use_filters = false, - const std::string &label_file = std::string(""), - const std::string &labels_to_medoids_file = std::string(""), - const std::string &universal_label = "", const uint32_t Lf = 0); 
+DISKANN_DLLEXPORT int build_merged_vamana_index(std::string base_file, diskann::Metric _compareMetric, uint32_t L, + uint32_t R, double sampling_rate, double ram_budget, + std::string mem_index_path, std::string medoids_file, + std::string centroids_file, size_t build_pq_bytes, bool use_opq, + uint32_t num_threads, bool use_filters = false, + const std::string &label_file = std::string(""), + const std::string &labels_to_medoids_file = std::string(""), + const std::string &universal_label = "", const uint32_t Lf = 0); template -DISKANN_DLLEXPORT uint32_t optimize_beamwidth( - std::unique_ptr> &_pFlashIndex, - T *tuning_sample, uint64_t tuning_sample_num, - uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, - uint32_t start_bw = 2); +DISKANN_DLLEXPORT uint32_t optimize_beamwidth(std::unique_ptr> &_pFlashIndex, + T *tuning_sample, uint64_t tuning_sample_num, + uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, + uint32_t start_bw = 2); template DISKANN_DLLEXPORT int build_disk_index( - const char *dataFilePath, const char *indexFilePath, - const char *indexBuildParameters, diskann::Metric _compareMetric, - bool use_opq = false, - const std::string &codebook_prefix = - "", // default is empty for no codebook pass in + const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, + diskann::Metric _compareMetric, bool use_opq = false, + const std::string &codebook_prefix = "", // default is empty for no codebook pass in bool use_filters = false, - const std::string &label_file = - std::string(""), // default is empty string for no label_file - const std::string &universal_label = "", - const uint32_t filter_threshold = 0, + const std::string &label_file = std::string(""), // default is empty string for no label_file + const std::string &universal_label = "", const uint32_t filter_threshold = 0, const uint32_t Lf = 0); // default is empty string for no universal label template -DISKANN_DLLEXPORT void -create_disk_layout(const std::string base_file, - const std::string mem_index_file, - const std::string output_file, - const std::string reorder_data_file = std::string("")); +DISKANN_DLLEXPORT void create_disk_layout(const std::string base_file, const std::string mem_index_file, + const std::string output_file, + const std::string reorder_data_file = std::string("")); } // namespace diskann diff --git a/include/distance.h b/include/distance.h index ae0f1d452..7a3ec8b26 100644 --- a/include/distance.h +++ b/include/distance.h @@ -3,204 +3,232 @@ #include #include -namespace diskann { -enum Metric { L2 = 0, INNER_PRODUCT = 1, COSINE = 2, FAST_L2 = 3 }; - -template class Distance { -public: - DISKANN_DLLEXPORT Distance(diskann::Metric dist_metric) - : _distance_metric(dist_metric) {} - - // distance comparison function - DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, - uint32_t length) const = 0; - - // Needed only for COSINE-BYTE and INNER_PRODUCT-BYTE - DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, - const float normA, const float normB, - uint32_t length) const; - - // For MIPS, normalization adds an extra dimension to the vectors. - // This function lets callers know if the normalization process - // changes the dimension. - DISKANN_DLLEXPORT virtual uint32_t - post_normalization_dimension(uint32_t orig_dimension) const; - - DISKANN_DLLEXPORT virtual diskann::Metric get_metric() const; - - // This is for efficiency. 
If no normalization is required, the callers - // can simply ignore the normalize_data_for_build() function. - DISKANN_DLLEXPORT virtual bool preprocessing_required() const; - - // Check the preprocessing_required() function before calling this. - // Clients can call the function like this: - // - // if (metric->preprocessing_required()){ - // T* normalized_data_batch; - // Split data into batches of batch_size and for each, call: - // metric->preprocess_base_points(data_batch, batch_size); - // - // TODO: This does not take into account the case for SSD inner product - // where the dimensions change after normalization. - DISKANN_DLLEXPORT virtual void - preprocess_base_points(T *original_data, const size_t orig_dim, - const size_t num_points); - - // Invokes normalization for a single vector during search. The scratch space - // has to be created by the caller keeping track of the fact that - // normalization might change the dimension of the query vector. - DISKANN_DLLEXPORT virtual void preprocess_query(const T *query_vec, - const size_t query_dim, - T *scratch_query); - - // If an algorithm has a requirement that some data be aligned to a certain - // boundary it can use this function to indicate that requirement. Currently, - // we are setting it to 8 because that works well for AVX2. If we have AVX512 - // implementations of distance algos, they might have to set this to 16 - // (depending on how they are implemented) - DISKANN_DLLEXPORT virtual size_t get_required_alignment() const; - - // Providing a default implementation for the virtual destructor because we - // don't expect most metric implementations to need it. - DISKANN_DLLEXPORT virtual ~Distance() = default; - -protected: - diskann::Metric _distance_metric; - size_t _alignment_factor = 8; +namespace diskann +{ +enum Metric +{ + L2 = 0, + INNER_PRODUCT = 1, + COSINE = 2, + FAST_L2 = 3 }; -class DistanceCosineInt8 : public Distance { -public: - DistanceCosineInt8() : Distance(diskann::Metric::COSINE) {} - DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, - uint32_t length) const; +template class Distance +{ + public: + DISKANN_DLLEXPORT Distance(diskann::Metric dist_metric) : _distance_metric(dist_metric) + { + } + + // distance comparison function + DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, uint32_t length) const = 0; + + // Needed only for COSINE-BYTE and INNER_PRODUCT-BYTE + DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, const float normA, const float normB, + uint32_t length) const; + + // For MIPS, normalization adds an extra dimension to the vectors. + // This function lets callers know if the normalization process + // changes the dimension. + DISKANN_DLLEXPORT virtual uint32_t post_normalization_dimension(uint32_t orig_dimension) const; + + DISKANN_DLLEXPORT virtual diskann::Metric get_metric() const; + + // This is for efficiency. If no normalization is required, the callers + // can simply ignore the normalize_data_for_build() function. + DISKANN_DLLEXPORT virtual bool preprocessing_required() const; + + // Check the preprocessing_required() function before calling this. + // Clients can call the function like this: + // + // if (metric->preprocessing_required()){ + // T* normalized_data_batch; + // Split data into batches of batch_size and for each, call: + // metric->preprocess_base_points(data_batch, batch_size); + // + // TODO: This does not take into account the case for SSD inner product + // where the dimensions change after normalization. 
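+    // A slightly more concrete sketch of the batching loop above (illustrative
+    // only; `data`, `dim`, `num_points` and `batch_size` are caller-defined):
+    //
+    //   if (metric->preprocessing_required())
+    //   {
+    //       for (size_t start = 0; start < num_points; start += batch_size)
+    //       {
+    //           size_t cur_batch = std::min(batch_size, num_points - start);
+    //           metric->preprocess_base_points(data + start * dim, dim, cur_batch);
+    //       }
+    //   }
+    //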
+ DISKANN_DLLEXPORT virtual void preprocess_base_points(T *original_data, const size_t orig_dim, + const size_t num_points); + + // Invokes normalization for a single vector during search. The scratch space + // has to be created by the caller keeping track of the fact that + // normalization might change the dimension of the query vector. + DISKANN_DLLEXPORT virtual void preprocess_query(const T *query_vec, const size_t query_dim, T *scratch_query); + + // If an algorithm has a requirement that some data be aligned to a certain + // boundary it can use this function to indicate that requirement. Currently, + // we are setting it to 8 because that works well for AVX2. If we have AVX512 + // implementations of distance algos, they might have to set this to 16 + // (depending on how they are implemented) + DISKANN_DLLEXPORT virtual size_t get_required_alignment() const; + + // Providing a default implementation for the virtual destructor because we + // don't expect most metric implementations to need it. + DISKANN_DLLEXPORT virtual ~Distance() = default; + + protected: + diskann::Metric _distance_metric; + size_t _alignment_factor = 8; }; -class DistanceL2Int8 : public Distance { -public: - DistanceL2Int8() : Distance(diskann::Metric::L2) {} - DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, - uint32_t size) const; +class DistanceCosineInt8 : public Distance +{ + public: + DistanceCosineInt8() : Distance(diskann::Metric::COSINE) + { + } + DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t length) const; +}; + +class DistanceL2Int8 : public Distance +{ + public: + DistanceL2Int8() : Distance(diskann::Metric::L2) + { + } + DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t size) const; }; // AVX implementations. Borrowed from HNSW code. 
-class AVXDistanceL2Int8 : public Distance { -public: - AVXDistanceL2Int8() : Distance(diskann::Metric::L2) {} - DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, - uint32_t length) const; +class AVXDistanceL2Int8 : public Distance +{ + public: + AVXDistanceL2Int8() : Distance(diskann::Metric::L2) + { + } + DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t length) const; }; -class DistanceCosineFloat : public Distance { -public: - DistanceCosineFloat() : Distance(diskann::Metric::COSINE) {} - DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, - uint32_t length) const; +class DistanceCosineFloat : public Distance +{ + public: + DistanceCosineFloat() : Distance(diskann::Metric::COSINE) + { + } + DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const; }; -class DistanceL2Float : public Distance { -public: - DistanceL2Float() : Distance(diskann::Metric::L2) {} +class DistanceL2Float : public Distance +{ + public: + DistanceL2Float() : Distance(diskann::Metric::L2) + { + } #ifdef _WINDOWS - DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, - uint32_t size) const; + DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t size) const; #else - DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, - uint32_t size) const - __attribute__((hot)); + DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t size) const __attribute__((hot)); #endif }; -class AVXDistanceL2Float : public Distance { -public: - AVXDistanceL2Float() : Distance(diskann::Metric::L2) {} - DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, - uint32_t length) const; +class AVXDistanceL2Float : public Distance +{ + public: + AVXDistanceL2Float() : Distance(diskann::Metric::L2) + { + } + DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const; }; -template class SlowDistanceL2 : public Distance { -public: - SlowDistanceL2() : Distance(diskann::Metric::L2) {} - DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, - uint32_t length) const; +template class SlowDistanceL2 : public Distance +{ + public: + SlowDistanceL2() : Distance(diskann::Metric::L2) + { + } + DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, uint32_t length) const; }; -class SlowDistanceCosineUInt8 : public Distance { -public: - SlowDistanceCosineUInt8() : Distance(diskann::Metric::COSINE) {} - DISKANN_DLLEXPORT virtual float compare(const uint8_t *a, const uint8_t *b, - uint32_t length) const; +class SlowDistanceCosineUInt8 : public Distance +{ + public: + SlowDistanceCosineUInt8() : Distance(diskann::Metric::COSINE) + { + } + DISKANN_DLLEXPORT virtual float compare(const uint8_t *a, const uint8_t *b, uint32_t length) const; }; -class DistanceL2UInt8 : public Distance { -public: - DistanceL2UInt8() : Distance(diskann::Metric::L2) {} - DISKANN_DLLEXPORT virtual float compare(const uint8_t *a, const uint8_t *b, - uint32_t size) const; +class DistanceL2UInt8 : public Distance +{ + public: + DistanceL2UInt8() : Distance(diskann::Metric::L2) + { + } + DISKANN_DLLEXPORT virtual float compare(const uint8_t *a, const uint8_t *b, uint32_t size) const; }; -template class DistanceInnerProduct : public Distance { -public: - DistanceInnerProduct() : Distance(diskann::Metric::INNER_PRODUCT) {} - - DistanceInnerProduct(diskann::Metric metric) : Distance(metric) {} - inline float 
inner_product(const T *a, const T *b, unsigned size) const; - - inline float compare(const T *a, const T *b, unsigned size) const { - float result = inner_product(a, b, size); - // if (result < 0) - // return std::numeric_limits::max(); - // else - return -result; - } +template class DistanceInnerProduct : public Distance +{ + public: + DistanceInnerProduct() : Distance(diskann::Metric::INNER_PRODUCT) + { + } + + DistanceInnerProduct(diskann::Metric metric) : Distance(metric) + { + } + inline float inner_product(const T *a, const T *b, unsigned size) const; + + inline float compare(const T *a, const T *b, unsigned size) const + { + float result = inner_product(a, b, size); + // if (result < 0) + // return std::numeric_limits::max(); + // else + return -result; + } }; -template class DistanceFastL2 : public DistanceInnerProduct { - // currently defined only for float. - // templated for future use. -public: - DistanceFastL2() : DistanceInnerProduct(diskann::Metric::FAST_L2) {} - float norm(const T *a, unsigned size) const; - float compare(const T *a, const T *b, float norm, unsigned size) const; +template class DistanceFastL2 : public DistanceInnerProduct +{ + // currently defined only for float. + // templated for future use. + public: + DistanceFastL2() : DistanceInnerProduct(diskann::Metric::FAST_L2) + { + } + float norm(const T *a, unsigned size) const; + float compare(const T *a, const T *b, float norm, unsigned size) const; }; -class AVXDistanceInnerProductFloat : public Distance { -public: - AVXDistanceInnerProductFloat() - : Distance(diskann::Metric::INNER_PRODUCT) {} - DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, - uint32_t length) const; +class AVXDistanceInnerProductFloat : public Distance +{ + public: + AVXDistanceInnerProductFloat() : Distance(diskann::Metric::INNER_PRODUCT) + { + } + DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const; }; -class AVXNormalizedCosineDistanceFloat : public Distance { -private: - AVXDistanceInnerProductFloat _innerProduct; - -protected: - void normalize_and_copy(const float *a, uint32_t length, float *a_norm) const; - -public: - AVXNormalizedCosineDistanceFloat() - : Distance(diskann::Metric::COSINE) {} - DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, - uint32_t length) const { - // Inner product returns negative values to indicate distance. - // This will ensure that cosine is between -1 and 1. - return 1.0f + _innerProduct.compare(a, b, length); - } - DISKANN_DLLEXPORT virtual uint32_t - post_normalization_dimension(uint32_t orig_dimension) const override; - - DISKANN_DLLEXPORT virtual bool preprocessing_required() const; - - DISKANN_DLLEXPORT virtual void - preprocess_base_points(float *original_data, const size_t orig_dim, - const size_t num_points) override; - - DISKANN_DLLEXPORT virtual void - preprocess_query(const float *query_vec, const size_t query_dim, - float *scratch_query_vector) override; +class AVXNormalizedCosineDistanceFloat : public Distance +{ + private: + AVXDistanceInnerProductFloat _innerProduct; + + protected: + void normalize_and_copy(const float *a, uint32_t length, float *a_norm) const; + + public: + AVXNormalizedCosineDistanceFloat() : Distance(diskann::Metric::COSINE) + { + } + DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const + { + // Inner product returns negative values to indicate distance. + // This will ensure that cosine is between -1 and 1. 
+ return 1.0f + _innerProduct.compare(a, b, length); + } + DISKANN_DLLEXPORT virtual uint32_t post_normalization_dimension(uint32_t orig_dimension) const override; + + DISKANN_DLLEXPORT virtual bool preprocessing_required() const; + + DISKANN_DLLEXPORT virtual void preprocess_base_points(float *original_data, const size_t orig_dim, + const size_t num_points) override; + + DISKANN_DLLEXPORT virtual void preprocess_query(const float *query_vec, const size_t query_dim, + float *scratch_query_vector) override; }; template Distance *get_distance_function(Metric m); diff --git a/include/exceptions.h b/include/exceptions.h index 7ac02d38d..99e4e7361 100644 --- a/include/exceptions.h +++ b/include/exceptions.h @@ -4,11 +4,14 @@ #pragma once #include -namespace diskann { +namespace diskann +{ -class NotImplementedException : public std::logic_error { -public: - NotImplementedException() - : std::logic_error("Function not yet implemented.") {} +class NotImplementedException : public std::logic_error +{ + public: + NotImplementedException() : std::logic_error("Function not yet implemented.") + { + } }; } // namespace diskann diff --git a/include/filter_utils.h b/include/filter_utils.h index 00e30a781..ba5b2d601 100644 --- a/include/filter_utils.h +++ b/include/filter_utils.h @@ -44,36 +44,28 @@ typedef tsl::robin_set label_set; typedef std::string path; // structs for returning multiple items from a function -typedef std::tuple, - tsl::robin_map, - tsl::robin_set> +typedef std::tuple, tsl::robin_map, tsl::robin_set> parse_label_file_return_values; -typedef std::tuple>, uint64_t> - load_label_index_return_values; +typedef std::tuple>, uint64_t> load_label_index_return_values; -namespace diskann { +namespace diskann +{ template -DISKANN_DLLEXPORT void -generate_label_indices(path input_data_path, path final_index_path_prefix, - label_set all_labels, unsigned R, unsigned L, - float alpha, unsigned num_threads); +DISKANN_DLLEXPORT void generate_label_indices(path input_data_path, path final_index_path_prefix, label_set all_labels, + unsigned R, unsigned L, float alpha, unsigned num_threads); -DISKANN_DLLEXPORT load_label_index_return_values -load_label_index(path label_index_path, uint32_t label_number_of_points); +DISKANN_DLLEXPORT load_label_index_return_values load_label_index(path label_index_path, + uint32_t label_number_of_points); template -DISKANN_DLLEXPORT - std::tuple>, tsl::robin_set> - parse_formatted_label_file(path label_file); +DISKANN_DLLEXPORT std::tuple>, tsl::robin_set> parse_formatted_label_file( + path label_file); -DISKANN_DLLEXPORT parse_label_file_return_values -parse_label_file(path label_data_path, std::string universal_label); +DISKANN_DLLEXPORT parse_label_file_return_values parse_label_file(path label_data_path, std::string universal_label); template -DISKANN_DLLEXPORT tsl::robin_map> -generate_label_specific_vector_files_compat( - path input_data_path, - tsl::robin_map labels_to_number_of_points, +DISKANN_DLLEXPORT tsl::robin_map> generate_label_specific_vector_files_compat( + path input_data_path, tsl::robin_map labels_to_number_of_points, std::vector point_ids_to_labels, label_set all_labels); /* @@ -88,142 +80,142 @@ generate_label_specific_vector_files_compat( */ #ifndef _WINDOWS template -inline tsl::robin_map> -generate_label_specific_vector_files( - path input_data_path, - tsl::robin_map labels_to_number_of_points, - std::vector point_ids_to_labels, label_set all_labels) { +inline tsl::robin_map> generate_label_specific_vector_files( + path input_data_path, 
tsl::robin_map labels_to_number_of_points, + std::vector point_ids_to_labels, label_set all_labels) +{ #ifndef _WINDOWS - auto file_writing_timer = std::chrono::high_resolution_clock::now(); - diskann::MemoryMapper input_data(input_data_path); - char *input_start = input_data.getBuf(); - - uint32_t number_of_points, dimension; - std::memcpy(&number_of_points, input_start, sizeof(uint32_t)); - std::memcpy(&dimension, input_start + sizeof(uint32_t), sizeof(uint32_t)); - const uint32_t VECTOR_SIZE = dimension * sizeof(T); - const size_t METADATA = 2 * sizeof(uint32_t); - if (number_of_points != point_ids_to_labels.size()) { - std::cerr << "Error: number of points in labels file and data file differ." - << std::endl; - throw; - } - - tsl::robin_map label_to_iovec_map; - tsl::robin_map label_to_curr_iovec; - tsl::robin_map> label_id_to_orig_id; - - // setup iovec list for each label - for (const auto &lbl : all_labels) { - iovec *label_iovecs = - (iovec *)malloc(labels_to_number_of_points[lbl] * sizeof(iovec)); - if (label_iovecs == nullptr) { - throw; + auto file_writing_timer = std::chrono::high_resolution_clock::now(); + diskann::MemoryMapper input_data(input_data_path); + char *input_start = input_data.getBuf(); + + uint32_t number_of_points, dimension; + std::memcpy(&number_of_points, input_start, sizeof(uint32_t)); + std::memcpy(&dimension, input_start + sizeof(uint32_t), sizeof(uint32_t)); + const uint32_t VECTOR_SIZE = dimension * sizeof(T); + const size_t METADATA = 2 * sizeof(uint32_t); + if (number_of_points != point_ids_to_labels.size()) + { + std::cerr << "Error: number of points in labels file and data file differ." << std::endl; + throw; } - label_to_iovec_map[lbl] = label_iovecs; - label_to_curr_iovec[lbl] = 0; - label_id_to_orig_id[lbl].reserve(labels_to_number_of_points[lbl]); - } - - // each point added to corresponding per-label iovec list - for (uint32_t point_id = 0; point_id < number_of_points; point_id++) { - char *curr_point = input_start + METADATA + (VECTOR_SIZE * point_id); - iovec curr_iovec; - - curr_iovec.iov_base = curr_point; - curr_iovec.iov_len = VECTOR_SIZE; - for (const auto &lbl : point_ids_to_labels[point_id]) { - *(label_to_iovec_map[lbl] + label_to_curr_iovec[lbl]) = curr_iovec; - label_to_curr_iovec[lbl]++; - label_id_to_orig_id[lbl].push_back(point_id); + + tsl::robin_map label_to_iovec_map; + tsl::robin_map label_to_curr_iovec; + tsl::robin_map> label_id_to_orig_id; + + // setup iovec list for each label + for (const auto &lbl : all_labels) + { + iovec *label_iovecs = (iovec *)malloc(labels_to_number_of_points[lbl] * sizeof(iovec)); + if (label_iovecs == nullptr) + { + throw; + } + label_to_iovec_map[lbl] = label_iovecs; + label_to_curr_iovec[lbl] = 0; + label_id_to_orig_id[lbl].reserve(labels_to_number_of_points[lbl]); } - } - - // write each label iovec to resp. 
file - for (const auto &lbl : all_labels) { - int label_input_data_fd; - path curr_label_input_data_path(input_data_path + "_" + lbl); - uint32_t curr_num_pts = labels_to_number_of_points[lbl]; - - label_input_data_fd = - open(curr_label_input_data_path.c_str(), - O_CREAT | O_WRONLY | O_TRUNC | O_APPEND, (mode_t)0644); - if (label_input_data_fd == -1) - throw; - - // write metadata - uint32_t metadata[2] = {curr_num_pts, dimension}; - int return_value = - write(label_input_data_fd, metadata, sizeof(uint32_t) * 2); - if (return_value == -1) { - throw; + + // each point added to corresponding per-label iovec list + for (uint32_t point_id = 0; point_id < number_of_points; point_id++) + { + char *curr_point = input_start + METADATA + (VECTOR_SIZE * point_id); + iovec curr_iovec; + + curr_iovec.iov_base = curr_point; + curr_iovec.iov_len = VECTOR_SIZE; + for (const auto &lbl : point_ids_to_labels[point_id]) + { + *(label_to_iovec_map[lbl] + label_to_curr_iovec[lbl]) = curr_iovec; + label_to_curr_iovec[lbl]++; + label_id_to_orig_id[lbl].push_back(point_id); + } } - // limits on number of iovec structs per writev means we need to perform - // multiple writevs - size_t i = 0; - while (curr_num_pts > IOV_MAX) { - return_value = writev(label_input_data_fd, - (label_to_iovec_map[lbl] + (IOV_MAX * i)), IOV_MAX); - if (return_value == -1) { + // write each label iovec to resp. file + for (const auto &lbl : all_labels) + { + int label_input_data_fd; + path curr_label_input_data_path(input_data_path + "_" + lbl); + uint32_t curr_num_pts = labels_to_number_of_points[lbl]; + + label_input_data_fd = + open(curr_label_input_data_path.c_str(), O_CREAT | O_WRONLY | O_TRUNC | O_APPEND, (mode_t)0644); + if (label_input_data_fd == -1) + throw; + + // write metadata + uint32_t metadata[2] = {curr_num_pts, dimension}; + int return_value = write(label_input_data_fd, metadata, sizeof(uint32_t) * 2); + if (return_value == -1) + { + throw; + } + + // limits on number of iovec structs per writev means we need to perform + // multiple writevs + size_t i = 0; + while (curr_num_pts > IOV_MAX) + { + return_value = writev(label_input_data_fd, (label_to_iovec_map[lbl] + (IOV_MAX * i)), IOV_MAX); + if (return_value == -1) + { + close(label_input_data_fd); + throw; + } + curr_num_pts -= IOV_MAX; + i += 1; + } + return_value = writev(label_input_data_fd, (label_to_iovec_map[lbl] + (IOV_MAX * i)), curr_num_pts); + if (return_value == -1) + { + close(label_input_data_fd); + throw; + } + + free(label_to_iovec_map[lbl]); close(label_input_data_fd); - throw; - } - curr_num_pts -= IOV_MAX; - i += 1; - } - return_value = - writev(label_input_data_fd, (label_to_iovec_map[lbl] + (IOV_MAX * i)), - curr_num_pts); - if (return_value == -1) { - close(label_input_data_fd); - throw; } - free(label_to_iovec_map[lbl]); - close(label_input_data_fd); - } - - std::chrono::duration file_writing_time = - std::chrono::high_resolution_clock::now() - file_writing_timer; - std::cout << "generated " << all_labels.size() - << " label-specific vector files for index building in time " - << file_writing_time.count() << "\n" - << std::endl; + std::chrono::duration file_writing_time = std::chrono::high_resolution_clock::now() - file_writing_timer; + std::cout << "generated " << all_labels.size() << " label-specific vector files for index building in time " + << file_writing_time.count() << "\n" + << std::endl; - return label_id_to_orig_id; + return label_id_to_orig_id; #endif } #endif -inline std::vector loadTags(const std::string &tags_file, - const 
std::string &base_file) { - const bool tags_enabled = tags_file.empty() ? false : true; - std::vector location_to_tag; - if (tags_enabled) { - size_t tag_file_ndims, tag_file_npts; - std::uint32_t *tag_data; - diskann::load_bin(tags_file, tag_data, tag_file_npts, - tag_file_ndims); - if (tag_file_ndims != 1) { - diskann::cerr << "tags file error" << std::endl; - throw diskann::ANNException("tag file error", -1, __FUNCSIG__, __FILE__, - __LINE__); +inline std::vector loadTags(const std::string &tags_file, const std::string &base_file) +{ + const bool tags_enabled = tags_file.empty() ? false : true; + std::vector location_to_tag; + if (tags_enabled) + { + size_t tag_file_ndims, tag_file_npts; + std::uint32_t *tag_data; + diskann::load_bin(tags_file, tag_data, tag_file_npts, tag_file_ndims); + if (tag_file_ndims != 1) + { + diskann::cerr << "tags file error" << std::endl; + throw diskann::ANNException("tag file error", -1, __FUNCSIG__, __FILE__, __LINE__); + } + + // check if the point count match + size_t base_file_npts, base_file_ndims; + diskann::get_bin_metadata(base_file, base_file_npts, base_file_ndims); + if (base_file_npts != tag_file_npts) + { + diskann::cerr << "point num in tags file mismatch" << std::endl; + throw diskann::ANNException("point num in tags file mismatch", -1, __FUNCSIG__, __FILE__, __LINE__); + } + + location_to_tag.assign(tag_data, tag_data + tag_file_npts); + delete[] tag_data; } - - // check if the point count match - size_t base_file_npts, base_file_ndims; - diskann::get_bin_metadata(base_file, base_file_npts, base_file_ndims); - if (base_file_npts != tag_file_npts) { - diskann::cerr << "point num in tags file mismatch" << std::endl; - throw diskann::ANNException("point num in tags file mismatch", -1, - __FUNCSIG__, __FILE__, __LINE__); - } - - location_to_tag.assign(tag_data, tag_data + tag_file_npts); - delete[] tag_data; - } - return location_to_tag; + return location_to_tag; } } // namespace diskann diff --git a/include/in_mem_data_store.h b/include/in_mem_data_store.h index 19b196503..ad5dc8d7a 100644 --- a/include/in_mem_data_store.h +++ b/include/in_mem_data_store.h @@ -17,88 +17,73 @@ #include "natural_number_map.h" #include "natural_number_set.h" -namespace diskann { -template -class InMemDataStore : public AbstractDataStore { -public: - InMemDataStore(const location_t capacity, const size_t dim, - std::unique_ptr> distance_fn); - virtual ~InMemDataStore(); - - virtual location_t load(const std::string &filename) override; - virtual size_t save(const std::string &filename, - const location_t num_points) override; - - virtual size_t get_aligned_dim() const override; - - // Populate internal data from unaligned data while doing alignment and any - // normalization that is required. 
- virtual void populate_data(const data_t *vectors, - const location_t num_pts) override; - virtual void populate_data(const std::string &filename, - const size_t offset) override; - - virtual void extract_data_to_bin(const std::string &filename, - const location_t num_pts) override; - - virtual void get_vector(const location_t i, data_t *target) const override; - virtual void set_vector(const location_t i, - const data_t *const vector) override; - virtual void prefetch_vector(const location_t loc) override; - - virtual void move_vectors(const location_t old_location_start, - const location_t new_location_start, - const location_t num_points) override; - virtual void copy_vectors(const location_t from_loc, const location_t to_loc, - const location_t num_points) override; - - virtual void - preprocess_query(const data_t *query, - AbstractScratch *query_scratch) const override; - - virtual float get_distance(const data_t *preprocessed_query, - const location_t loc) const override; - virtual float get_distance(const location_t loc1, - const location_t loc2) const override; - - virtual void get_distance(const data_t *preprocessed_query, - const location_t *locations, - const uint32_t location_count, float *distances, - AbstractScratch *scratch) const override; - virtual void - get_distance(const data_t *preprocessed_query, - const std::vector &ids, - std::vector &distances, - AbstractScratch *scratch_space) const override; - - virtual location_t calculate_medoid() const override; - - virtual Distance *get_dist_fn() const override; - - virtual size_t get_alignment_factor() const override; - -protected: - virtual location_t expand(const location_t new_size) override; - virtual location_t shrink(const location_t new_size) override; - - virtual location_t load_impl(const std::string &filename); +namespace diskann +{ +template class InMemDataStore : public AbstractDataStore +{ + public: + InMemDataStore(const location_t capacity, const size_t dim, std::unique_ptr> distance_fn); + virtual ~InMemDataStore(); + + virtual location_t load(const std::string &filename) override; + virtual size_t save(const std::string &filename, const location_t num_points) override; + + virtual size_t get_aligned_dim() const override; + + // Populate internal data from unaligned data while doing alignment and any + // normalization that is required. 
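As a usage illustration for this store, a short sketch that constructs it with a distance object and fills it via the populate_data() overloads declared just below. The capacity, dimension and file name are placeholders.

#include <memory>
#include "distance.h"
#include "in_mem_data_store.h"

void data_store_sketch()
{
    std::unique_ptr<diskann::Distance<float>> dist(diskann::get_distance_function<float>(diskann::Metric::L2));
    // The store takes ownership of the distance object (see the constructor above).
    diskann::InMemDataStore<float> store(/*capacity*/ 100000, /*dim*/ 96, std::move(dist));

    // Per the comment above, data is aligned (and, if the metric requires it,
    // normalized) internally while being populated.
    store.populate_data("base.bin", /*offset*/ 0);

    float d = store.get_distance(/*loc1*/ 0, /*loc2*/ 1);
    (void)d;
}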
+ virtual void populate_data(const data_t *vectors, const location_t num_pts) override; + virtual void populate_data(const std::string &filename, const size_t offset) override; + + virtual void extract_data_to_bin(const std::string &filename, const location_t num_pts) override; + + virtual void get_vector(const location_t i, data_t *target) const override; + virtual void set_vector(const location_t i, const data_t *const vector) override; + virtual void prefetch_vector(const location_t loc) override; + + virtual void move_vectors(const location_t old_location_start, const location_t new_location_start, + const location_t num_points) override; + virtual void copy_vectors(const location_t from_loc, const location_t to_loc, const location_t num_points) override; + + virtual void preprocess_query(const data_t *query, AbstractScratch *query_scratch) const override; + + virtual float get_distance(const data_t *preprocessed_query, const location_t loc) const override; + virtual float get_distance(const location_t loc1, const location_t loc2) const override; + + virtual void get_distance(const data_t *preprocessed_query, const location_t *locations, + const uint32_t location_count, float *distances, + AbstractScratch *scratch) const override; + virtual void get_distance(const data_t *preprocessed_query, const std::vector &ids, + std::vector &distances, AbstractScratch *scratch_space) const override; + + virtual location_t calculate_medoid() const override; + + virtual Distance *get_dist_fn() const override; + + virtual size_t get_alignment_factor() const override; + + protected: + virtual location_t expand(const location_t new_size) override; + virtual location_t shrink(const location_t new_size) override; + + virtual location_t load_impl(const std::string &filename); #ifdef EXEC_ENV_OLS - virtual location_t load_impl(AlignedFileReader &reader); + virtual location_t load_impl(AlignedFileReader &reader); #endif -private: - data_t *_data = nullptr; + private: + data_t *_data = nullptr; - size_t _aligned_dim; + size_t _aligned_dim; - // It may seem weird to put distance metric along with the data store class, - // but this gives us perf benefits as the datastore can do distance - // computations during search and compute norms of vectors internally without - // have to copy data back and forth. - std::unique_ptr> _distance_fn; + // It may seem weird to put distance metric along with the data store class, + // but this gives us perf benefits as the datastore can do distance + // computations during search and compute norms of vectors internally without + // have to copy data back and forth. + std::unique_ptr> _distance_fn; - // in case we need to save vector norms for optimization - std::shared_ptr _pre_computed_norms; + // in case we need to save vector norms for optimization + std::shared_ptr _pre_computed_norms; }; } // namespace diskann \ No newline at end of file diff --git a/include/in_mem_filter_store.h b/include/in_mem_filter_store.h index 3dcb1029e..4915f37ee 100644 --- a/include/in_mem_filter_store.h +++ b/include/in_mem_filter_store.h @@ -13,128 +13,128 @@ #include #include -namespace diskann { -template -class InMemFilterStore : public AbstractFilterStore { -public: - // Do nothing constructor because all the work is done in load() - DISKANN_DLLEXPORT InMemFilterStore() {} - - /// - /// Destructor - /// - DISKANN_DLLEXPORT virtual ~InMemFilterStore(); - - // No copy, no assignment. 
- DISKANN_DLLEXPORT InMemFilterStore & - operator=(const InMemFilterStore &v) = delete; - DISKANN_DLLEXPORT - InMemFilterStore(const InMemFilterStore &v) = delete; - - DISKANN_DLLEXPORT virtual bool has_filter_support() const; - - DISKANN_DLLEXPORT virtual const std::unordered_map> & - get_label_to_medoids() const; - - DISKANN_DLLEXPORT virtual const std::vector & - get_medoids_of_label(const LabelT label); - - DISKANN_DLLEXPORT virtual void set_universal_label(const LabelT univ_label); - - DISKANN_DLLEXPORT inline bool point_has_label(location_t point_id, - const LabelT label_id) const { - uint32_t start_vec = _pts_to_label_offsets[point_id]; - uint32_t num_lbls = _pts_to_label_counts[point_id]; - bool ret_val = false; - for (uint32_t i = 0; i < num_lbls; i++) { - if (_pts_to_labels[start_vec + i] == label_id) { - ret_val = true; - break; - } +namespace diskann +{ +template class InMemFilterStore : public AbstractFilterStore +{ + public: + // Do nothing constructor because all the work is done in load() + DISKANN_DLLEXPORT InMemFilterStore() + { } - return ret_val; - } - - DISKANN_DLLEXPORT inline bool is_dummy_point(location_t id) const { - return _dummy_pts.find(id) != _dummy_pts.end(); - } - - DISKANN_DLLEXPORT inline location_t - get_real_point_for_dummy(location_t dummy_id) { - if (is_dummy_point(dummy_id)) { - return _dummy_to_real_map[dummy_id]; - } else { - return dummy_id; // it is a real point. + + /// + /// Destructor + /// + DISKANN_DLLEXPORT virtual ~InMemFilterStore(); + + // No copy, no assignment. + DISKANN_DLLEXPORT InMemFilterStore &operator=(const InMemFilterStore &v) = delete; + DISKANN_DLLEXPORT + InMemFilterStore(const InMemFilterStore &v) = delete; + + DISKANN_DLLEXPORT virtual bool has_filter_support() const; + + DISKANN_DLLEXPORT virtual const std::unordered_map> &get_label_to_medoids() const; + + DISKANN_DLLEXPORT virtual const std::vector &get_medoids_of_label(const LabelT label); + + DISKANN_DLLEXPORT virtual void set_universal_label(const LabelT univ_label); + + DISKANN_DLLEXPORT inline bool point_has_label(location_t point_id, const LabelT label_id) const + { + uint32_t start_vec = _pts_to_label_offsets[point_id]; + uint32_t num_lbls = _pts_to_label_counts[point_id]; + bool ret_val = false; + for (uint32_t i = 0; i < num_lbls; i++) + { + if (_pts_to_labels[start_vec + i] == label_id) + { + ret_val = true; + break; + } + } + return ret_val; + } + + DISKANN_DLLEXPORT inline bool is_dummy_point(location_t id) const + { + return _dummy_pts.find(id) != _dummy_pts.end(); + } + + DISKANN_DLLEXPORT inline location_t get_real_point_for_dummy(location_t dummy_id) + { + if (is_dummy_point(dummy_id)) + { + return _dummy_to_real_map[dummy_id]; + } + else + { + return dummy_id; // it is a real point. 
+ } } - } - - DISKANN_DLLEXPORT inline bool - point_has_label_or_universal_label(location_t id, - const LabelT filter_label) const { - return point_has_label(id, filter_label) || - (_use_universal_label && - point_has_label(id, _universal_filter_label)); - } - - DISKANN_DLLEXPORT inline LabelT - get_converted_label(const std::string &filter_label) { - if (_label_map.find(filter_label) != _label_map.end()) { - return _label_map[filter_label]; + + DISKANN_DLLEXPORT inline bool point_has_label_or_universal_label(location_t id, const LabelT filter_label) const + { + return point_has_label(id, filter_label) || + (_use_universal_label && point_has_label(id, _universal_filter_label)); } - if (_use_universal_label) { - return _universal_filter_label; + + DISKANN_DLLEXPORT inline LabelT get_converted_label(const std::string &filter_label) + { + if (_label_map.find(filter_label) != _label_map.end()) + { + return _label_map[filter_label]; + } + if (_use_universal_label) + { + return _universal_filter_label; + } + std::stringstream stream; + stream << "Unable to find label in the Label Map"; + diskann::cerr << stream.str() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } - std::stringstream stream; - stream << "Unable to find label in the Label Map"; - diskann::cerr << stream.str() << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - - // Returns true if the index is filter-enabled and all files were loaded - // correctly. false otherwise. Note that "false" can mean that the index - // does not have filter support, or that some index files do not exist, or - // that they exist and could not be opened. - DISKANN_DLLEXPORT bool load(const std::string &disk_index_file); - - DISKANN_DLLEXPORT void generate_random_labels(std::vector &labels, - const uint32_t num_labels, - const uint32_t nthreads); - -private: - // Load functions for search START - void load_label_file(const std::string_view &file_content); - void load_label_map(std::basic_istream &map_reader); - void load_labels_to_medoids(std::basic_istream &reader); - void load_dummy_map(std::basic_istream &dummy_map_stream); - void parse_universal_label(const std::string_view &content); - void get_label_file_metadata(const std::string_view &fileContent, - uint32_t &num_pts, uint32_t &num_total_labels); - - bool load_file_and_parse( - const std::string &filename, - void (InMemFilterStore::*parse_fn)(const std::string_view &content)); - bool parse_stream( - const std::string &filename, - void (InMemFilterStore::*parse_fn)(std::basic_istream &stream)); - - void reset_stream_for_reading(std::basic_istream &infile); - // Load functions for search END - - location_t _num_points = 0; - location_t *_pts_to_label_offsets = nullptr; - location_t *_pts_to_label_counts = nullptr; - LabelT *_pts_to_labels = nullptr; - bool _use_universal_label = false; - LabelT _universal_filter_label; - tsl::robin_set _dummy_pts; - tsl::robin_set _has_dummy_pts; - tsl::robin_map _dummy_to_real_map; - tsl::robin_map> _real_to_dummy_map; - std::unordered_map _label_map; - std::unordered_map> _filter_to_medoid_ids; - bool _is_valid = false; + + // Returns true if the index is filter-enabled and all files were loaded + // correctly. false otherwise. Note that "false" can mean that the index + // does not have filter support, or that some index files do not exist, or + // that they exist and could not be opened. 
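A small sketch of the documented load contract, using the load() declared just below. It assumes an index built with uint32_t labels; the label string and point id are placeholders.

#include <cstdint>
#include <string>
#include "in_mem_filter_store.h"

bool probe_filters(const std::string &disk_index_file)
{
    diskann::InMemFilterStore<uint32_t> filter_store;
    if (!filter_store.load(disk_index_file))
    {
        // Either the index has no filter support, or some label files are
        // missing/unreadable (load() folds both cases into `false`).
        return false;
    }
    // Throws ANNException if the label is unknown and no universal label is set.
    const uint32_t label_id = filter_store.get_converted_label("label_1");
    return filter_store.point_has_label_or_universal_label(/*point id*/ 0, label_id);
}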
+ DISKANN_DLLEXPORT bool load(const std::string &disk_index_file); + + DISKANN_DLLEXPORT void generate_random_labels(std::vector &labels, const uint32_t num_labels, + const uint32_t nthreads); + + private: + // Load functions for search START + void load_label_file(const std::string_view &file_content); + void load_label_map(std::basic_istream &map_reader); + void load_labels_to_medoids(std::basic_istream &reader); + void load_dummy_map(std::basic_istream &dummy_map_stream); + void parse_universal_label(const std::string_view &content); + void get_label_file_metadata(const std::string_view &fileContent, uint32_t &num_pts, uint32_t &num_total_labels); + + bool load_file_and_parse(const std::string &filename, + void (InMemFilterStore::*parse_fn)(const std::string_view &content)); + bool parse_stream(const std::string &filename, + void (InMemFilterStore::*parse_fn)(std::basic_istream &stream)); + + void reset_stream_for_reading(std::basic_istream &infile); + // Load functions for search END + + location_t _num_points = 0; + location_t *_pts_to_label_offsets = nullptr; + location_t *_pts_to_label_counts = nullptr; + LabelT *_pts_to_labels = nullptr; + bool _use_universal_label = false; + LabelT _universal_filter_label; + tsl::robin_set _dummy_pts; + tsl::robin_set _has_dummy_pts; + tsl::robin_map _dummy_to_real_map; + tsl::robin_map> _real_to_dummy_map; + std::unordered_map _label_map; + std::unordered_map> _filter_to_medoid_ids; + bool _is_valid = false; }; } // namespace diskann diff --git a/include/in_mem_graph_store.h b/include/in_mem_graph_store.h index aa5daa2de..d0206a7d6 100644 --- a/include/in_mem_graph_store.h +++ b/include/in_mem_graph_store.h @@ -5,52 +5,47 @@ #include "abstract_graph_store.h" -namespace diskann { - -class InMemGraphStore : public AbstractGraphStore { -public: - InMemGraphStore(const size_t total_pts, const size_t reserve_graph_degree); - - // returns tuple of - virtual std::tuple - load(const std::string &index_path_prefix, const size_t num_points) override; - virtual int store(const std::string &index_path_prefix, - const size_t num_points, const size_t num_frozen_points, - const uint32_t start) override; - - virtual const std::vector & - get_neighbours(const location_t i) const override; - virtual void add_neighbour(const location_t i, - location_t neighbour_id) override; - virtual void clear_neighbours(const location_t i) override; - virtual void swap_neighbours(const location_t a, location_t b) override; - - virtual void set_neighbours(const location_t i, - std::vector &neighbors) override; - - virtual size_t resize_graph(const size_t new_size) override; - virtual void clear_graph() override; - - virtual size_t get_max_range_of_graph() override; - virtual uint32_t get_max_observed_degree() override; - -protected: - virtual std::tuple - load_impl(const std::string &filename, size_t expected_num_points); +namespace diskann +{ + +class InMemGraphStore : public AbstractGraphStore +{ + public: + InMemGraphStore(const size_t total_pts, const size_t reserve_graph_degree); + + // returns tuple of + virtual std::tuple load(const std::string &index_path_prefix, + const size_t num_points) override; + virtual int store(const std::string &index_path_prefix, const size_t num_points, const size_t num_frozen_points, + const uint32_t start) override; + + virtual const std::vector &get_neighbours(const location_t i) const override; + virtual void add_neighbour(const location_t i, location_t neighbour_id) override; + virtual void clear_neighbours(const location_t i) override; + 
virtual void swap_neighbours(const location_t a, location_t b) override; + + virtual void set_neighbours(const location_t i, std::vector &neighbors) override; + + virtual size_t resize_graph(const size_t new_size) override; + virtual void clear_graph() override; + + virtual size_t get_max_range_of_graph() override; + virtual uint32_t get_max_observed_degree() override; + + protected: + virtual std::tuple load_impl(const std::string &filename, size_t expected_num_points); #ifdef EXEC_ENV_OLS - virtual std::tuple - load_impl(AlignedFileReader &reader, size_t expected_num_points); + virtual std::tuple load_impl(AlignedFileReader &reader, size_t expected_num_points); #endif - int save_graph(const std::string &index_path_prefix, - const size_t active_points, const size_t num_frozen_points, - const uint32_t start); + int save_graph(const std::string &index_path_prefix, const size_t active_points, const size_t num_frozen_points, + const uint32_t start); -private: - size_t _max_range_of_graph = 0; - uint32_t _max_observed_degree = 0; + private: + size_t _max_range_of_graph = 0; + uint32_t _max_observed_degree = 0; - std::vector> _graph; + std::vector> _graph; }; } // namespace diskann diff --git a/include/index.h b/include/index.h index 7267dc804..fed9f2843 100644 --- a/include/index.h +++ b/include/index.h @@ -29,473 +29,423 @@ #define EXPAND_IF_FULL 0 #define DEFAULT_MAXC 750 -namespace diskann { - -inline double estimate_ram_usage(size_t size, uint32_t dim, uint32_t datasize, - uint32_t degree) { - double size_of_data = ((double)size) * ROUND_UP(dim, 8) * datasize; - double size_of_graph = - ((double)size) * degree * sizeof(uint32_t) * defaults::GRAPH_SLACK_FACTOR; - double size_of_locks = ((double)size) * sizeof(non_recursive_mutex); - double size_of_outer_vector = ((double)size) * sizeof(ptrdiff_t); - - return OVERHEAD_FACTOR * - (size_of_data + size_of_graph + size_of_locks + size_of_outer_vector); +namespace diskann +{ + +inline double estimate_ram_usage(size_t size, uint32_t dim, uint32_t datasize, uint32_t degree) +{ + double size_of_data = ((double)size) * ROUND_UP(dim, 8) * datasize; + double size_of_graph = ((double)size) * degree * sizeof(uint32_t) * defaults::GRAPH_SLACK_FACTOR; + double size_of_locks = ((double)size) * sizeof(non_recursive_mutex); + double size_of_outer_vector = ((double)size) * sizeof(ptrdiff_t); + + return OVERHEAD_FACTOR * (size_of_data + size_of_graph + size_of_locks + size_of_outer_vector); } -template -class Index : public AbstractIndex { - /************************************************************************** - * - * Public functions acquire one or more of _update_lock, _consolidate_lock, - * _tag_lock, _delete_lock before calling protected functions which DO NOT - * acquire these locks. They might acquire locks on _locks[i] - * - **************************************************************************/ - -public: - // Constructor for Bulk operations and for creating the index object solely - // for loading a prexisting index. 
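As a usage note for the estimate_ram_usage() helper shown in this hunk: it returns an estimate in bytes (already scaled by the overhead factor), so a caller sizing an indexing machine might do something like the following. The point count, dimension and degree are illustrative.

#include <iostream>
#include "index.h"

void print_ram_estimate()
{
    // 10M float vectors of dimension 128, graph degree R = 64.
    const double bytes = diskann::estimate_ram_usage(10000000, 128, sizeof(float), 64);
    std::cout << "estimated indexing RAM: " << bytes / (1024.0 * 1024.0 * 1024.0) << " GiB" << std::endl;
}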
- DISKANN_DLLEXPORT - Index(const IndexConfig &index_config, - std::shared_ptr> data_store, - std::unique_ptr graph_store, - std::shared_ptr> pq_data_store = nullptr); - - // Constructor for incremental index - DISKANN_DLLEXPORT - Index(Metric m, const size_t dim, const size_t max_points, - const std::shared_ptr index_parameters, - const std::shared_ptr index_search_params, - const size_t num_frozen_pts = 0, const bool dynamic_index = false, - const bool enable_tags = false, - const bool concurrent_consolidate = false, - const bool pq_dist_build = false, const size_t num_pq_chunks = 0, - const bool use_opq = false, const bool filtered_index = false); - - DISKANN_DLLEXPORT ~Index(); - - // Saves graph, data, metadata and associated tags. - DISKANN_DLLEXPORT void save(const char *filename, - bool compact_before_save = false); - - // Load functions +template class Index : public AbstractIndex +{ + /************************************************************************** + * + * Public functions acquire one or more of _update_lock, _consolidate_lock, + * _tag_lock, _delete_lock before calling protected functions which DO NOT + * acquire these locks. They might acquire locks on _locks[i] + * + **************************************************************************/ + + public: + // Constructor for Bulk operations and for creating the index object solely + // for loading a prexisting index. + DISKANN_DLLEXPORT + Index(const IndexConfig &index_config, std::shared_ptr> data_store, + std::unique_ptr graph_store, + std::shared_ptr> pq_data_store = nullptr); + + // Constructor for incremental index + DISKANN_DLLEXPORT + Index(Metric m, const size_t dim, const size_t max_points, + const std::shared_ptr index_parameters, + const std::shared_ptr index_search_params, const size_t num_frozen_pts = 0, + const bool dynamic_index = false, const bool enable_tags = false, const bool concurrent_consolidate = false, + const bool pq_dist_build = false, const size_t num_pq_chunks = 0, const bool use_opq = false, + const bool filtered_index = false); + + DISKANN_DLLEXPORT ~Index(); + + // Saves graph, data, metadata and associated tags. + DISKANN_DLLEXPORT void save(const char *filename, bool compact_before_save = false); + + // Load functions #ifdef EXEC_ENV_OLS - DISKANN_DLLEXPORT void load(AlignedFileReader &reader, uint32_t num_threads, - uint32_t search_l); + DISKANN_DLLEXPORT void load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l); #else - // Reads the number of frozen points from graph's metadata file section. - DISKANN_DLLEXPORT static size_t - get_graph_num_frozen_points(const std::string &graph_file); + // Reads the number of frozen points from graph's metadata file section. + DISKANN_DLLEXPORT static size_t get_graph_num_frozen_points(const std::string &graph_file); - DISKANN_DLLEXPORT void load(const char *index_file, uint32_t num_threads, - uint32_t search_l); + DISKANN_DLLEXPORT void load(const char *index_file, uint32_t num_threads, uint32_t search_l); #endif - // get some private variables - DISKANN_DLLEXPORT size_t get_num_points(); - DISKANN_DLLEXPORT size_t get_max_points(); - - DISKANN_DLLEXPORT bool - detect_common_filters(uint32_t point_id, bool search_invocation, - const std::vector &incoming_labels); - - // Batch build from a file. Optionally pass tags vector. - DISKANN_DLLEXPORT void - build(const char *filename, const size_t num_points_to_load, - const std::vector &tags = std::vector()); - - // Batch build from a file. Optionally pass tags file. 
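A compact sketch of the batch-build-then-search flow these declarations support, using the build() overload declared just below and the search() overload that takes L per query. It assumes an already-constructed static index; the explicit tag and label template arguments, the file name, and the K and L values are placeholders.

#include <cstdint>
#include <vector>
#include "index.h"

void build_and_search(diskann::Index<float, uint32_t, uint32_t> &index, const float *query)
{
    const size_t num_points = 100000;    // must not exceed the points in the data file
    index.build("base.bin", num_points); // the tags vector is optional and defaults to empty

    std::vector<uint32_t> ids(10);
    std::vector<float> dists(10);
    // L is passed per query, so recall/latency can be tuned without rebuilding.
    index.search(query, /*K*/ 10, /*L*/ 64, ids.data(), dists.data());
}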
- DISKANN_DLLEXPORT void build(const char *filename, - const size_t num_points_to_load, - const char *tag_filename); - - // Batch build from a data array, which must pad vectors to aligned_dim - DISKANN_DLLEXPORT void build(const T *data, const size_t num_points_to_load, - const std::vector &tags); - - // Based on filter params builds a filtered or unfiltered index - DISKANN_DLLEXPORT void build(const std::string &data_file, - const size_t num_points_to_load, - IndexFilterParams &filter_params); - - // Filtered Support - DISKANN_DLLEXPORT void - build_filtered_index(const char *filename, const std::string &label_file, - const size_t num_points_to_load, - const std::vector &tags = std::vector()); - - DISKANN_DLLEXPORT void set_universal_label(const LabelT &label); - - // Get converted integer label from string to int map (_label_map) - DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &raw_label); - - // Set starting point of an index before inserting any points incrementally. - // The data count should be equal to _num_frozen_pts * _aligned_dim. - DISKANN_DLLEXPORT void set_start_points(const T *data, size_t data_count); - // Set starting points to random points on a sphere of certain radius. - // A fixed random seed can be specified for scenarios where it's important - // to have higher consistency between index builds. - DISKANN_DLLEXPORT void set_start_points_at_random(T radius, - uint32_t random_seed = 0); - - // For FastL2 search on a static index, we interleave the data with graph - DISKANN_DLLEXPORT void optimize_index_layout(); - - // For FastL2 search on optimized layout - DISKANN_DLLEXPORT void search_with_optimized_layout(const T *query, size_t K, - size_t L, - uint32_t *indices); - - // Added search overload that takes L as parameter, so that we - // can customize L on a per-query basis without tampering with "Parameters" - template - DISKANN_DLLEXPORT std::pair - search(const T *query, const size_t K, const uint32_t L, IDType *indices, - float *distances = nullptr); - - // Initialize space for res_vectors before calling. - DISKANN_DLLEXPORT size_t search_with_tags( - const T *query, const uint64_t K, const uint32_t L, TagT *tags, - float *distances, std::vector &res_vectors, bool use_filters = false, - const std::string filter_label = ""); - - // Filter support search - template - DISKANN_DLLEXPORT std::pair - search_with_filters(const T *query, const LabelT &filter_label, - const size_t K, const uint32_t L, IndexType *indices, - float *distances); - - // Will fail if tag already in the index or if tag=0. - DISKANN_DLLEXPORT int insert_point(const T *point, const TagT tag); - - // Will fail if tag already in the index or if tag=0. - DISKANN_DLLEXPORT int insert_point(const T *point, const TagT tag, - const std::vector &label); - - // call this before issuing deletions to sets relevant flags - DISKANN_DLLEXPORT int enable_delete(); - - // Record deleted point now and restructure graph later. Return -1 if tag - // not found, 0 if OK. - DISKANN_DLLEXPORT int lazy_delete(const TagT &tag); - - // Record deleted points now and restructure graph later. Add to failed_tags - // if tag not found. - DISKANN_DLLEXPORT void lazy_delete(const std::vector &tags, - std::vector &failed_tags); + // get some private variables + DISKANN_DLLEXPORT size_t get_num_points(); + DISKANN_DLLEXPORT size_t get_max_points(); + + DISKANN_DLLEXPORT bool detect_common_filters(uint32_t point_id, bool search_invocation, + const std::vector &incoming_labels); + + // Batch build from a file. 
Optionally pass tags vector. + DISKANN_DLLEXPORT void build(const char *filename, const size_t num_points_to_load, + const std::vector &tags = std::vector()); + + // Batch build from a file. Optionally pass tags file. + DISKANN_DLLEXPORT void build(const char *filename, const size_t num_points_to_load, const char *tag_filename); + + // Batch build from a data array, which must pad vectors to aligned_dim + DISKANN_DLLEXPORT void build(const T *data, const size_t num_points_to_load, const std::vector &tags); + + // Based on filter params builds a filtered or unfiltered index + DISKANN_DLLEXPORT void build(const std::string &data_file, const size_t num_points_to_load, + IndexFilterParams &filter_params); + + // Filtered Support + DISKANN_DLLEXPORT void build_filtered_index(const char *filename, const std::string &label_file, + const size_t num_points_to_load, + const std::vector &tags = std::vector()); + + DISKANN_DLLEXPORT void set_universal_label(const LabelT &label); + + // Get converted integer label from string to int map (_label_map) + DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &raw_label); + + // Set starting point of an index before inserting any points incrementally. + // The data count should be equal to _num_frozen_pts * _aligned_dim. + DISKANN_DLLEXPORT void set_start_points(const T *data, size_t data_count); + // Set starting points to random points on a sphere of certain radius. + // A fixed random seed can be specified for scenarios where it's important + // to have higher consistency between index builds. + DISKANN_DLLEXPORT void set_start_points_at_random(T radius, uint32_t random_seed = 0); + + // For FastL2 search on a static index, we interleave the data with graph + DISKANN_DLLEXPORT void optimize_index_layout(); + + // For FastL2 search on optimized layout + DISKANN_DLLEXPORT void search_with_optimized_layout(const T *query, size_t K, size_t L, uint32_t *indices); + + // Added search overload that takes L as parameter, so that we + // can customize L on a per-query basis without tampering with "Parameters" + template + DISKANN_DLLEXPORT std::pair search(const T *query, const size_t K, const uint32_t L, + IDType *indices, float *distances = nullptr); + + // Initialize space for res_vectors before calling. + DISKANN_DLLEXPORT size_t search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags, + float *distances, std::vector &res_vectors, bool use_filters = false, + const std::string filter_label = ""); + + // Filter support search + template + DISKANN_DLLEXPORT std::pair search_with_filters(const T *query, const LabelT &filter_label, + const size_t K, const uint32_t L, + IndexType *indices, float *distances); + + // Will fail if tag already in the index or if tag=0. + DISKANN_DLLEXPORT int insert_point(const T *point, const TagT tag); + + // Will fail if tag already in the index or if tag=0. + DISKANN_DLLEXPORT int insert_point(const T *point, const TagT tag, const std::vector &label); + + // call this before issuing deletions to sets relevant flags + DISKANN_DLLEXPORT int enable_delete(); + + // Record deleted point now and restructure graph later. Return -1 if tag + // not found, 0 if OK. + DISKANN_DLLEXPORT int lazy_delete(const TagT &tag); + + // Record deleted points now and restructure graph later. Add to failed_tags + // if tag not found. 
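To tie the incremental declarations together (insert_point(), enable_delete(), the lazy_delete() overloads above and just below, and consolidate_deletes()), a hedged sketch of the streaming lifecycle. It assumes a dynamic, tag-enabled index; the tag and label template arguments are assumptions.

#include <cstdint>
#include "index.h"

void stream_updates(diskann::Index<float, uint32_t, uint32_t> &index, const float *vec, uint32_t tag)
{
    index.insert_point(vec, tag); // fails if tag is 0 or already present

    index.enable_delete();        // sets the flags required before issuing deletions
    index.lazy_delete(tag);       // records the deletion; the graph is repaired later

    // After a batch of lazy deletes, reclaim space and repair the graph:
    // index.consolidate_deletes(parameters); // needs an IndexWriteParameters object
}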
+ DISKANN_DLLEXPORT void lazy_delete(const std::vector &tags, std::vector &failed_tags); + + // Call after a series of lazy deletions + // Returns number of live points left after consolidation + // If _conc_consolidates is set in the ctor, then this call can be invoked + // alongside inserts and lazy deletes, else it acquires _update_lock + DISKANN_DLLEXPORT consolidation_report consolidate_deletes(const IndexWriteParameters ¶meters); + + DISKANN_DLLEXPORT void prune_all_neighbors(const uint32_t max_degree, const uint32_t max_occlusion, + const float alpha); + + DISKANN_DLLEXPORT bool is_index_saved(); + + // repositions frozen points to the end of _data - if they have been moved + // during deletion + DISKANN_DLLEXPORT void reposition_frozen_point_to_end(); + DISKANN_DLLEXPORT void reposition_points(uint32_t old_location_start, uint32_t new_location_start, + uint32_t num_locations); + + // DISKANN_DLLEXPORT void save_index_as_one_file(bool flag); + + DISKANN_DLLEXPORT void get_active_tags(tsl::robin_set &active_tags); + + // memory should be allocated for vec before calling this function + DISKANN_DLLEXPORT int get_vector_by_tag(TagT &tag, T *vec); + + DISKANN_DLLEXPORT void print_status(); + + DISKANN_DLLEXPORT void count_nodes_at_bfs_levels(); + + // This variable MUST be updated if the number of entries in the metadata + // change. + DISKANN_DLLEXPORT static const int METADATA_ROWS = 5; + + // ******************************** + // + // Internals of the library + // + // ******************************** + + protected: + // overload of abstract index virtual methods + virtual void _build(const DataType &data, const size_t num_points_to_load, TagVector &tags) override; + + virtual std::pair _search(const DataType &query, const size_t K, const uint32_t L, + std::any &indices, float *distances = nullptr) override; + virtual std::pair _search_with_filters(const DataType &query, + const std::string &filter_label_raw, const size_t K, + const uint32_t L, std::any &indices, + float *distances) override; + + virtual int _insert_point(const DataType &data_point, const TagType tag) override; + virtual int _insert_point(const DataType &data_point, const TagType tag, Labelvector &labels) override; + + virtual int _lazy_delete(const TagType &tag) override; + + virtual void _lazy_delete(TagVector &tags, TagVector &failed_tags) override; + + virtual void _get_active_tags(TagRobinSet &active_tags) override; + + virtual void _set_start_points_at_random(DataType radius, uint32_t random_seed = 0) override; + + virtual int _get_vector_by_tag(TagType &tag, DataType &vec) override; + + virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) override; + + virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, + float *distances, DataVector &res_vectors, bool use_filters = false, + const std::string filter_label = "") override; + + virtual void _set_universal_label(const LabelType universal_label) override; + + // No copy/assign. 
+ Index(const Index &) = delete; + Index &operator=(const Index &) = delete; + + // Use after _data and _nd have been populated + // Acquire exclusive _update_lock before calling + void build_with_data_populated(const std::vector &tags); + + // generates 1 frozen point that will never be deleted from the graph + // This is not visible to the user + void generate_frozen_point(); + + // determines navigating node of the graph by calculating medoid of datafopt + uint32_t calculate_entry_point(); + + void parse_label_file(const std::string &label_file, size_t &num_pts_labels); + + std::unordered_map load_label_map(const std::string &map_file); + + // Returns the locations of start point and frozen points suitable for use + // with iterate_to_fixed_point. + std::vector get_init_ids(); + + // The query to use is placed in scratch->aligned_query + std::pair iterate_to_fixed_point(InMemQueryScratch *scratch, const uint32_t Lindex, + const std::vector &init_ids, bool use_filter, + const std::vector &filters, bool search_invocation); + + void search_for_point_and_prune(int location, uint32_t Lindex, std::vector &pruned_list, + InMemQueryScratch *scratch, bool use_filter = false, + uint32_t filteredLindex = 0); + + void prune_neighbors(const uint32_t location, std::vector &pool, std::vector &pruned_list, + InMemQueryScratch *scratch); + + void prune_neighbors(const uint32_t location, std::vector &pool, const uint32_t range, + const uint32_t max_candidate_size, const float alpha, std::vector &pruned_list, + InMemQueryScratch *scratch); + + // Prunes candidates in @pool to a shorter list @result + // @pool must be sorted before calling + void occlude_list(const uint32_t location, std::vector &pool, const float alpha, const uint32_t degree, + const uint32_t maxc, std::vector &result, InMemQueryScratch *scratch, + const tsl::robin_set *const delete_set_ptr = nullptr); + + // add reverse links from all the visited nodes to node n. + void inter_insert(uint32_t n, std::vector &pruned_list, const uint32_t range, + InMemQueryScratch *scratch); + + void inter_insert(uint32_t n, std::vector &pruned_list, InMemQueryScratch *scratch); + + // Acquire exclusive _update_lock before calling + void link(); + + // Acquire exclusive _tag_lock and _delete_lock before calling + int reserve_location(); + + // Acquire exclusive _tag_lock before calling + size_t release_location(int location); + size_t release_locations(const tsl::robin_set &locations); + + // Resize the index when no slots are left for insertion. + // Acquire exclusive _update_lock and _tag_lock before calling. + void resize(size_t new_max_points); + + // Acquire unique lock on _update_lock, _consolidate_lock, _tag_lock + // and _delete_lock before calling these functions. + // Renumber nodes, update tag and location maps and compact the + // graph, mode = _consolidated_order in case of lazy deletion and + // _compacted_order in case of eager deletion + DISKANN_DLLEXPORT void compact_data(); + DISKANN_DLLEXPORT void compact_frozen_point(); + + // Remove deleted nodes from adjacency list of node loc + // Replace removed neighbors with second order neighbors. + // Also acquires _locks[i] for i = loc and out-neighbors of loc. 
+ void process_delete(const tsl::robin_set &old_delete_set, size_t loc, const uint32_t range, + const uint32_t maxc, const float alpha, InMemQueryScratch *scratch); + + void initialize_query_scratch(uint32_t num_threads, uint32_t search_l, uint32_t indexing_l, uint32_t r, + uint32_t maxc, size_t dim); - // Call after a series of lazy deletions - // Returns number of live points left after consolidation - // If _conc_consolidates is set in the ctor, then this call can be invoked - // alongside inserts and lazy deletes, else it acquires _update_lock - DISKANN_DLLEXPORT consolidation_report - consolidate_deletes(const IndexWriteParameters ¶meters); - - DISKANN_DLLEXPORT void prune_all_neighbors(const uint32_t max_degree, - const uint32_t max_occlusion, - const float alpha); - - DISKANN_DLLEXPORT bool is_index_saved(); - - // repositions frozen points to the end of _data - if they have been moved - // during deletion - DISKANN_DLLEXPORT void reposition_frozen_point_to_end(); - DISKANN_DLLEXPORT void reposition_points(uint32_t old_location_start, - uint32_t new_location_start, - uint32_t num_locations); - - // DISKANN_DLLEXPORT void save_index_as_one_file(bool flag); - - DISKANN_DLLEXPORT void get_active_tags(tsl::robin_set &active_tags); - - // memory should be allocated for vec before calling this function - DISKANN_DLLEXPORT int get_vector_by_tag(TagT &tag, T *vec); - - DISKANN_DLLEXPORT void print_status(); - - DISKANN_DLLEXPORT void count_nodes_at_bfs_levels(); - - // This variable MUST be updated if the number of entries in the metadata - // change. - DISKANN_DLLEXPORT static const int METADATA_ROWS = 5; - - // ******************************** - // - // Internals of the library - // - // ******************************** - -protected: - // overload of abstract index virtual methods - virtual void _build(const DataType &data, const size_t num_points_to_load, - TagVector &tags) override; - - virtual std::pair - _search(const DataType &query, const size_t K, const uint32_t L, - std::any &indices, float *distances = nullptr) override; - virtual std::pair - _search_with_filters(const DataType &query, - const std::string &filter_label_raw, const size_t K, - const uint32_t L, std::any &indices, - float *distances) override; - - virtual int _insert_point(const DataType &data_point, - const TagType tag) override; - virtual int _insert_point(const DataType &data_point, const TagType tag, - Labelvector &labels) override; - - virtual int _lazy_delete(const TagType &tag) override; - - virtual void _lazy_delete(TagVector &tags, TagVector &failed_tags) override; - - virtual void _get_active_tags(TagRobinSet &active_tags) override; - - virtual void _set_start_points_at_random(DataType radius, - uint32_t random_seed = 0) override; - - virtual int _get_vector_by_tag(TagType &tag, DataType &vec) override; - - virtual void _search_with_optimized_layout(const DataType &query, size_t K, - size_t L, - uint32_t *indices) override; - - virtual size_t - _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, - const TagType &tags, float *distances, - DataVector &res_vectors, bool use_filters = false, - const std::string filter_label = "") override; - - virtual void _set_universal_label(const LabelType universal_label) override; - - // No copy/assign. 
- Index(const Index &) = delete; - Index &operator=(const Index &) = delete; - - // Use after _data and _nd have been populated - // Acquire exclusive _update_lock before calling - void build_with_data_populated(const std::vector &tags); - - // generates 1 frozen point that will never be deleted from the graph - // This is not visible to the user - void generate_frozen_point(); - - // determines navigating node of the graph by calculating medoid of datafopt - uint32_t calculate_entry_point(); - - void parse_label_file(const std::string &label_file, size_t &num_pts_labels); - - std::unordered_map - load_label_map(const std::string &map_file); - - // Returns the locations of start point and frozen points suitable for use - // with iterate_to_fixed_point. - std::vector get_init_ids(); - - // The query to use is placed in scratch->aligned_query - std::pair - iterate_to_fixed_point(InMemQueryScratch *scratch, const uint32_t Lindex, - const std::vector &init_ids, bool use_filter, - const std::vector &filters, - bool search_invocation); - - void search_for_point_and_prune(int location, uint32_t Lindex, - std::vector &pruned_list, - InMemQueryScratch *scratch, - bool use_filter = false, - uint32_t filteredLindex = 0); - - void prune_neighbors(const uint32_t location, std::vector &pool, - std::vector &pruned_list, - InMemQueryScratch *scratch); - - void prune_neighbors(const uint32_t location, std::vector &pool, - const uint32_t range, const uint32_t max_candidate_size, - const float alpha, std::vector &pruned_list, - InMemQueryScratch *scratch); - - // Prunes candidates in @pool to a shorter list @result - // @pool must be sorted before calling - void - occlude_list(const uint32_t location, std::vector &pool, - const float alpha, const uint32_t degree, const uint32_t maxc, - std::vector &result, InMemQueryScratch *scratch, - const tsl::robin_set *const delete_set_ptr = nullptr); - - // add reverse links from all the visited nodes to node n. - void inter_insert(uint32_t n, std::vector &pruned_list, - const uint32_t range, InMemQueryScratch *scratch); - - void inter_insert(uint32_t n, std::vector &pruned_list, - InMemQueryScratch *scratch); - - // Acquire exclusive _update_lock before calling - void link(); - - // Acquire exclusive _tag_lock and _delete_lock before calling - int reserve_location(); - - // Acquire exclusive _tag_lock before calling - size_t release_location(int location); - size_t release_locations(const tsl::robin_set &locations); - - // Resize the index when no slots are left for insertion. - // Acquire exclusive _update_lock and _tag_lock before calling. - void resize(size_t new_max_points); - - // Acquire unique lock on _update_lock, _consolidate_lock, _tag_lock - // and _delete_lock before calling these functions. - // Renumber nodes, update tag and location maps and compact the - // graph, mode = _consolidated_order in case of lazy deletion and - // _compacted_order in case of eager deletion - DISKANN_DLLEXPORT void compact_data(); - DISKANN_DLLEXPORT void compact_frozen_point(); - - // Remove deleted nodes from adjacency list of node loc - // Replace removed neighbors with second order neighbors. - // Also acquires _locks[i] for i = loc and out-neighbors of loc. 
- void process_delete(const tsl::robin_set &old_delete_set, - size_t loc, const uint32_t range, const uint32_t maxc, - const float alpha, InMemQueryScratch *scratch); - - void initialize_query_scratch(uint32_t num_threads, uint32_t search_l, - uint32_t indexing_l, uint32_t r, uint32_t maxc, - size_t dim); - - // Do not call without acquiring appropriate locks - // call public member functions save and load to invoke these. - DISKANN_DLLEXPORT size_t save_graph(std::string filename); - DISKANN_DLLEXPORT size_t save_data(std::string filename); - DISKANN_DLLEXPORT size_t save_tags(std::string filename); - DISKANN_DLLEXPORT size_t save_delete_list(const std::string &filename); + // Do not call without acquiring appropriate locks + // call public member functions save and load to invoke these. + DISKANN_DLLEXPORT size_t save_graph(std::string filename); + DISKANN_DLLEXPORT size_t save_data(std::string filename); + DISKANN_DLLEXPORT size_t save_tags(std::string filename); + DISKANN_DLLEXPORT size_t save_delete_list(const std::string &filename); #ifdef EXEC_ENV_OLS - DISKANN_DLLEXPORT size_t load_graph(AlignedFileReader &reader, - size_t expected_num_points); - DISKANN_DLLEXPORT size_t load_data(AlignedFileReader &reader); - DISKANN_DLLEXPORT size_t load_tags(AlignedFileReader &reader); - DISKANN_DLLEXPORT size_t load_delete_set(AlignedFileReader &reader); + DISKANN_DLLEXPORT size_t load_graph(AlignedFileReader &reader, size_t expected_num_points); + DISKANN_DLLEXPORT size_t load_data(AlignedFileReader &reader); + DISKANN_DLLEXPORT size_t load_tags(AlignedFileReader &reader); + DISKANN_DLLEXPORT size_t load_delete_set(AlignedFileReader &reader); #else - DISKANN_DLLEXPORT size_t load_graph(const std::string filename, - size_t expected_num_points); - DISKANN_DLLEXPORT size_t load_data(std::string filename0); - DISKANN_DLLEXPORT size_t load_tags(const std::string tag_file_name); - DISKANN_DLLEXPORT size_t load_delete_set(const std::string &filename); + DISKANN_DLLEXPORT size_t load_graph(const std::string filename, size_t expected_num_points); + DISKANN_DLLEXPORT size_t load_data(std::string filename0); + DISKANN_DLLEXPORT size_t load_tags(const std::string tag_file_name); + DISKANN_DLLEXPORT size_t load_delete_set(const std::string &filename); #endif -private: - // Distance functions - Metric _dist_metric = diskann::L2; - - // Data - std::shared_ptr> _data_store; - - // Graph related data structures - std::unique_ptr _graph_store; - - char *_opt_graph = nullptr; - - // Dimensions - size_t _dim = 0; - size_t _nd = 0; // number of active points i.e. existing in the graph - size_t _max_points = 0; // total number of points in given data set - - // _num_frozen_pts is the number of points which are used as initial - // candidates when iterating to closest point(s). These are not visible - // externally and won't be returned by search. At least 1 frozen point is - // needed for a dynamic index. The frozen points have consecutive locations. - // See also _start below. - size_t _num_frozen_pts = 0; - size_t _frozen_pts_used = 0; - size_t _node_size; - size_t _data_len; - size_t _neighbor_len; - - // Start point of the search. When _num_frozen_pts is greater than zero, - // this is the location of the first frozen point. Otherwise, this is a - // location of one of the points in index. 
- uint32_t _start = 0; - - bool _has_built = false; - bool _saturate_graph = false; - bool _save_as_one_file = false; // plan to support in next version - bool _dynamic_index = false; - bool _enable_tags = false; - bool _normalize_vecs = false; // Using normalied L2 for cosine. - bool _deletes_enabled = false; - - // Filter Support - - bool _filtered_index = false; - // Location to label is only updated during insert_point(), all other reads - // are protected by default as a location can only be released at end of - // consolidate deletes - std::vector> _location_to_labels; - tsl::robin_set _labels; - std::string _labels_file; - std::unordered_map _label_to_start_id; - std::unordered_map _medoid_counts; - - bool _use_universal_label = false; - LabelT _universal_label = 0; - uint32_t _filterIndexingQueueSize; - std::unordered_map _label_map; - - // Indexing parameters - uint32_t _indexingQueueSize; - uint32_t _indexingRange; - uint32_t _indexingMaxC; - float _indexingAlpha; - uint32_t _indexingThreads; - - // Query scratch data structures - ConcurrentQueue *> _query_scratch; - - // Flags for PQ based distance calculation - bool _pq_dist = false; - bool _use_opq = false; - size_t _num_pq_chunks = 0; - // REFACTOR - // uint8_t *_pq_data = nullptr; - std::shared_ptr> _pq_distance_fn = nullptr; - std::shared_ptr> _pq_data_store = nullptr; - bool _pq_generated = false; - FixedChunkPQTable _pq_table; - - // - // Data structures, locks and flags for dynamic indexing and tags - // - - // lazy_delete removes entry from _location_to_tag and _tag_to_location. If - // _location_to_tag does not resolve a location, infer that it was deleted. - tsl::sparse_map _tag_to_location; - natural_number_map _location_to_tag; - - // _empty_slots has unallocated slots and those freed by consolidate_delete. - // _delete_set has locations marked deleted by lazy_delete. Will not be - // immediately available for insert. consolidate_delete will release these - // slots to _empty_slots. - natural_number_set _empty_slots; - std::unique_ptr> _delete_set; - - bool _data_compacted = true; // true if data has been compacted - bool _is_saved = false; // Checking if the index is already saved. - bool _conc_consolidate = false; // use _lock while searching - - // Acquire locks in the order below when acquiring multiple locks - std::shared_timed_mutex // RW mutex between save/load (exclusive lock) and - _update_lock; // search/inserts/deletes/consolidate (shared lock) - std::shared_timed_mutex // Ensure only one consolidate or compact_data is - _consolidate_lock; // ever active - std::shared_timed_mutex // RW lock for _tag_to_location, - _tag_lock; // _location_to_tag, _empty_slots, _nd, _max_points, - // _label_to_start_id - std::shared_timed_mutex // RW Lock on _delete_set and _data_compacted - _delete_lock; // variable - - // Per node lock, cardinality=_max_points + _num_frozen_points - std::vector _locks; - - static const float INDEX_GROWTH_FACTOR; + private: + // Distance functions + Metric _dist_metric = diskann::L2; + + // Data + std::shared_ptr> _data_store; + + // Graph related data structures + std::unique_ptr _graph_store; + + char *_opt_graph = nullptr; + + // Dimensions + size_t _dim = 0; + size_t _nd = 0; // number of active points i.e. existing in the graph + size_t _max_points = 0; // total number of points in given data set + + // _num_frozen_pts is the number of points which are used as initial + // candidates when iterating to closest point(s). 
These are not visible + // externally and won't be returned by search. At least 1 frozen point is + // needed for a dynamic index. The frozen points have consecutive locations. + // See also _start below. + size_t _num_frozen_pts = 0; + size_t _frozen_pts_used = 0; + size_t _node_size; + size_t _data_len; + size_t _neighbor_len; + + // Start point of the search. When _num_frozen_pts is greater than zero, + // this is the location of the first frozen point. Otherwise, this is a + // location of one of the points in index. + uint32_t _start = 0; + + bool _has_built = false; + bool _saturate_graph = false; + bool _save_as_one_file = false; // plan to support in next version + bool _dynamic_index = false; + bool _enable_tags = false; + bool _normalize_vecs = false; // Using normalied L2 for cosine. + bool _deletes_enabled = false; + + // Filter Support + + bool _filtered_index = false; + // Location to label is only updated during insert_point(), all other reads + // are protected by default as a location can only be released at end of + // consolidate deletes + std::vector> _location_to_labels; + tsl::robin_set _labels; + std::string _labels_file; + std::unordered_map _label_to_start_id; + std::unordered_map _medoid_counts; + + bool _use_universal_label = false; + LabelT _universal_label = 0; + uint32_t _filterIndexingQueueSize; + std::unordered_map _label_map; + + // Indexing parameters + uint32_t _indexingQueueSize; + uint32_t _indexingRange; + uint32_t _indexingMaxC; + float _indexingAlpha; + uint32_t _indexingThreads; + + // Query scratch data structures + ConcurrentQueue *> _query_scratch; + + // Flags for PQ based distance calculation + bool _pq_dist = false; + bool _use_opq = false; + size_t _num_pq_chunks = 0; + // REFACTOR + // uint8_t *_pq_data = nullptr; + std::shared_ptr> _pq_distance_fn = nullptr; + std::shared_ptr> _pq_data_store = nullptr; + bool _pq_generated = false; + FixedChunkPQTable _pq_table; + + // + // Data structures, locks and flags for dynamic indexing and tags + // + + // lazy_delete removes entry from _location_to_tag and _tag_to_location. If + // _location_to_tag does not resolve a location, infer that it was deleted. + tsl::sparse_map _tag_to_location; + natural_number_map _location_to_tag; + + // _empty_slots has unallocated slots and those freed by consolidate_delete. + // _delete_set has locations marked deleted by lazy_delete. Will not be + // immediately available for insert. consolidate_delete will release these + // slots to _empty_slots. + natural_number_set _empty_slots; + std::unique_ptr> _delete_set; + + bool _data_compacted = true; // true if data has been compacted + bool _is_saved = false; // Checking if the index is already saved. 
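Editorial aside (not part of the patch): the lazy-delete bookkeeping described in the comments above is easier to follow with a concrete illustration. Below is a minimal, simplified sketch using plain standard containers in place of tsl::sparse_map, natural_number_map and natural_number_set; the struct and its names are assumptions for illustration, not DiskANN's implementation.

#include <cstdint>
#include <unordered_map>
#include <unordered_set>

// Illustrative sketch only: mirrors the described relationship between
// _tag_to_location, _location_to_tag, _delete_set and _empty_slots.
struct LazyDeleteBookkeeping
{
    std::unordered_map<uint32_t, uint32_t> tag_to_location;
    std::unordered_map<uint32_t, uint32_t> location_to_tag;
    std::unordered_set<uint32_t> delete_set;  // marked by lazy_delete, not yet reusable
    std::unordered_set<uint32_t> empty_slots; // free slots available for insertion

    // lazy_delete: drop both mappings and remember the location for later release.
    void lazy_delete(uint32_t tag)
    {
        auto it = tag_to_location.find(tag);
        if (it == tag_to_location.end())
            return; // unknown tag, nothing to do
        delete_set.insert(it->second);
        location_to_tag.erase(it->second);
        tag_to_location.erase(it);
    }

    // A location that no longer resolves to a tag is inferred to be deleted.
    bool is_deleted(uint32_t location) const
    {
        return location_to_tag.find(location) == location_to_tag.end();
    }

    // consolidate_deletes: deleted locations become reusable empty slots.
    void consolidate_deletes()
    {
        for (uint32_t loc : delete_set)
            empty_slots.insert(loc);
        delete_set.clear();
    }
};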
+ bool _conc_consolidate = false; // use _lock while searching + + // Acquire locks in the order below when acquiring multiple locks + std::shared_timed_mutex // RW mutex between save/load (exclusive lock) and + _update_lock; // search/inserts/deletes/consolidate (shared lock) + std::shared_timed_mutex // Ensure only one consolidate or compact_data is + _consolidate_lock; // ever active + std::shared_timed_mutex // RW lock for _tag_to_location, + _tag_lock; // _location_to_tag, _empty_slots, _nd, _max_points, + // _label_to_start_id + std::shared_timed_mutex // RW Lock on _delete_set and _data_compacted + _delete_lock; // variable + + // Per node lock, cardinality=_max_points + _num_frozen_points + std::vector _locks; + + static const float INDEX_GROWTH_FACTOR; }; } // namespace diskann diff --git a/include/index_build_params.h b/include/index_build_params.h index d0ce9f4b1..38434e204 100644 --- a/include/index_build_params.h +++ b/include/index_build_params.h @@ -4,68 +4,71 @@ #include "common_includes.h" #include "parameters.h" -namespace diskann { -struct IndexFilterParams { -public: - std::string save_path_prefix; - std::string label_file; - std::string tags_file; - std::string universal_label; - uint32_t filter_threshold = 0; +namespace diskann +{ +struct IndexFilterParams +{ + public: + std::string save_path_prefix; + std::string label_file; + std::string tags_file; + std::string universal_label; + uint32_t filter_threshold = 0; -private: - IndexFilterParams(const std::string &save_path_prefix, - const std::string &label_file, - const std::string &universal_label, - uint32_t filter_threshold) - : save_path_prefix(save_path_prefix), label_file(label_file), - universal_label(universal_label), filter_threshold(filter_threshold) {} + private: + IndexFilterParams(const std::string &save_path_prefix, const std::string &label_file, + const std::string &universal_label, uint32_t filter_threshold) + : save_path_prefix(save_path_prefix), label_file(label_file), universal_label(universal_label), + filter_threshold(filter_threshold) + { + } - friend class IndexFilterParamsBuilder; + friend class IndexFilterParamsBuilder; }; -class IndexFilterParamsBuilder { -public: - IndexFilterParamsBuilder() = default; +class IndexFilterParamsBuilder +{ + public: + IndexFilterParamsBuilder() = default; - IndexFilterParamsBuilder & - with_save_path_prefix(const std::string &save_path_prefix) { - if (save_path_prefix.empty() || save_path_prefix == "") - throw diskann::ANNException("Error: save_path_prefix can't be empty", -1); - this->_save_path_prefix = save_path_prefix; - return *this; - } + IndexFilterParamsBuilder &with_save_path_prefix(const std::string &save_path_prefix) + { + if (save_path_prefix.empty() || save_path_prefix == "") + throw diskann::ANNException("Error: save_path_prefix can't be empty", -1); + this->_save_path_prefix = save_path_prefix; + return *this; + } - IndexFilterParamsBuilder &with_label_file(const std::string &label_file) { - this->_label_file = label_file; - return *this; - } + IndexFilterParamsBuilder &with_label_file(const std::string &label_file) + { + this->_label_file = label_file; + return *this; + } - IndexFilterParamsBuilder & - with_universal_label(const std::string &univeral_label) { - this->_universal_label = univeral_label; - return *this; - } + IndexFilterParamsBuilder &with_universal_label(const std::string &univeral_label) + { + this->_universal_label = univeral_label; + return *this; + } - IndexFilterParamsBuilder & - with_filter_threshold(const std::uint32_t 
&filter_threshold) { - this->_filter_threshold = filter_threshold; - return *this; - } + IndexFilterParamsBuilder &with_filter_threshold(const std::uint32_t &filter_threshold) + { + this->_filter_threshold = filter_threshold; + return *this; + } - IndexFilterParams build() { - return IndexFilterParams(_save_path_prefix, _label_file, _universal_label, - _filter_threshold); - } + IndexFilterParams build() + { + return IndexFilterParams(_save_path_prefix, _label_file, _universal_label, _filter_threshold); + } - IndexFilterParamsBuilder(const IndexFilterParamsBuilder &) = delete; - IndexFilterParamsBuilder & - operator=(const IndexFilterParamsBuilder &) = delete; + IndexFilterParamsBuilder(const IndexFilterParamsBuilder &) = delete; + IndexFilterParamsBuilder &operator=(const IndexFilterParamsBuilder &) = delete; -private: - std::string _save_path_prefix; - std::string _label_file; - std::string _tags_file; - std::string _universal_label; - uint32_t _filter_threshold = 0; + private: + std::string _save_path_prefix; + std::string _label_file; + std::string _tags_file; + std::string _universal_label; + uint32_t _filter_threshold = 0; }; } // namespace diskann diff --git a/include/index_config.h b/include/index_config.h index 351d9aeba..d1709cc1e 100644 --- a/include/index_config.h +++ b/include/index_config.h @@ -9,242 +9,258 @@ #include "parameters.h" #include -namespace diskann { -enum class DataStoreStrategy { MEMORY }; - -enum class GraphStoreStrategy { MEMORY }; - -struct IndexConfig { - DataStoreStrategy data_strategy; - GraphStoreStrategy graph_strategy; - - Metric metric; - size_t dimension; - size_t max_points; - - bool dynamic_index; - bool enable_tags; - bool pq_dist_build; - bool concurrent_consolidate; - bool use_opq; - bool filtered_index; - - size_t num_pq_chunks; - size_t num_frozen_pts; - - std::string label_type; - std::string tag_type; - std::string data_type; - - // Params for building index - std::shared_ptr index_write_params; - // Params for searching index - std::shared_ptr index_search_params; - -private: - IndexConfig(DataStoreStrategy data_strategy, - GraphStoreStrategy graph_strategy, Metric metric, - size_t dimension, size_t max_points, size_t num_pq_chunks, - size_t num_frozen_points, bool dynamic_index, bool enable_tags, - bool pq_dist_build, bool concurrent_consolidate, bool use_opq, - bool filtered_index, std::string &data_type, - const std::string &tag_type, const std::string &label_type, - std::shared_ptr index_write_params, - std::shared_ptr index_search_params) - : data_strategy(data_strategy), graph_strategy(graph_strategy), - metric(metric), dimension(dimension), max_points(max_points), - dynamic_index(dynamic_index), enable_tags(enable_tags), - pq_dist_build(pq_dist_build), - concurrent_consolidate(concurrent_consolidate), use_opq(use_opq), - filtered_index(filtered_index), num_pq_chunks(num_pq_chunks), - num_frozen_pts(num_frozen_points), label_type(label_type), - tag_type(tag_type), data_type(data_type), - index_write_params(index_write_params), - index_search_params(index_search_params) {} - - friend class IndexConfigBuilder; +namespace diskann +{ +enum class DataStoreStrategy +{ + MEMORY }; -class IndexConfigBuilder { -public: - IndexConfigBuilder() = default; - - IndexConfigBuilder &with_metric(Metric m) { - this->_metric = m; - return *this; - } - - IndexConfigBuilder & - with_graph_load_store_strategy(GraphStoreStrategy graph_strategy) { - this->_graph_strategy = graph_strategy; - return *this; - } - - IndexConfigBuilder & - 
with_data_load_store_strategy(DataStoreStrategy data_strategy) { - this->_data_strategy = data_strategy; - return *this; - } - - IndexConfigBuilder &with_dimension(size_t dimension) { - this->_dimension = dimension; - return *this; - } - - IndexConfigBuilder &with_max_points(size_t max_points) { - this->_max_points = max_points; - return *this; - } - - IndexConfigBuilder &is_dynamic_index(bool dynamic_index) { - this->_dynamic_index = dynamic_index; - return *this; - } - - IndexConfigBuilder &is_enable_tags(bool enable_tags) { - this->_enable_tags = enable_tags; - return *this; - } - - IndexConfigBuilder &is_pq_dist_build(bool pq_dist_build) { - this->_pq_dist_build = pq_dist_build; - return *this; - } - - IndexConfigBuilder &is_concurrent_consolidate(bool concurrent_consolidate) { - this->_concurrent_consolidate = concurrent_consolidate; - return *this; - } - - IndexConfigBuilder &is_use_opq(bool use_opq) { - this->_use_opq = use_opq; - return *this; - } - - IndexConfigBuilder &is_filtered(bool is_filtered) { - this->_filtered_index = is_filtered; - return *this; - } - - IndexConfigBuilder &with_num_pq_chunks(size_t num_pq_chunks) { - this->_num_pq_chunks = num_pq_chunks; - return *this; - } - - IndexConfigBuilder &with_num_frozen_pts(size_t num_frozen_pts) { - this->_num_frozen_pts = num_frozen_pts; - return *this; - } - - IndexConfigBuilder &with_label_type(const std::string &label_type) { - this->_label_type = label_type; - return *this; - } - - IndexConfigBuilder &with_tag_type(const std::string &tag_type) { - this->_tag_type = tag_type; - return *this; - } - - IndexConfigBuilder &with_data_type(const std::string &data_type) { - this->_data_type = data_type; - return *this; - } - - IndexConfigBuilder & - with_index_write_params(IndexWriteParameters &index_write_params) { - this->_index_write_params = - std::make_shared(index_write_params); - return *this; - } - - IndexConfigBuilder &with_index_write_params( - std::shared_ptr index_write_params_ptr) { - if (index_write_params_ptr == nullptr) { - diskann::cout << "Passed, empty build_params while creating index config" - << std::endl; - return *this; - } - this->_index_write_params = index_write_params_ptr; - return *this; - } - - IndexConfigBuilder & - with_index_search_params(IndexSearchParams &search_params) { - this->_index_search_params = - std::make_shared(search_params); - return *this; - } - - IndexConfigBuilder &with_index_search_params( - std::shared_ptr search_params_ptr) { - if (search_params_ptr == nullptr) { - diskann::cout << "Passed, empty search_params while creating index config" - << std::endl; - return *this; - } - this->_index_search_params = search_params_ptr; - return *this; - } - - IndexConfig build() { - if (_data_type == "" || _data_type.empty()) - throw ANNException("Error: data_type can not be empty", -1); - - if (_dynamic_index && _num_frozen_pts == 0) { - _num_frozen_pts = 1; - } - - if (_dynamic_index) { - if (_index_search_params != nullptr && - _index_search_params->initial_search_list_size == 0) - throw ANNException("Error: please pass initial_search_list_size for " - "building dynamic index.", - -1); - } - - // sanity check - if (_dynamic_index && _num_frozen_pts == 0) { - diskann::cout << "_num_frozen_pts passed as 0 for dynamic_index. Setting " - "it to 1 for safety." 
- << std::endl; - _num_frozen_pts = 1; - } - - return IndexConfig(_data_strategy, _graph_strategy, _metric, _dimension, - _max_points, _num_pq_chunks, _num_frozen_pts, - _dynamic_index, _enable_tags, _pq_dist_build, - _concurrent_consolidate, _use_opq, _filtered_index, - _data_type, _tag_type, _label_type, _index_write_params, - _index_search_params); - } - - IndexConfigBuilder(const IndexConfigBuilder &) = delete; - IndexConfigBuilder &operator=(const IndexConfigBuilder &) = delete; - -private: - DataStoreStrategy _data_strategy; - GraphStoreStrategy _graph_strategy; - - Metric _metric; - size_t _dimension; - size_t _max_points; - - bool _dynamic_index = false; - bool _enable_tags = false; - bool _pq_dist_build = false; - bool _concurrent_consolidate = false; - bool _use_opq = false; - bool _filtered_index{defaults::HAS_LABELS}; - - size_t _num_pq_chunks = 0; - size_t _num_frozen_pts{defaults::NUM_FROZEN_POINTS_STATIC}; - - std::string _label_type{"uint32"}; - std::string _tag_type{"uint32"}; - std::string _data_type; - - std::shared_ptr _index_write_params; - std::shared_ptr _index_search_params; +enum class GraphStoreStrategy +{ + MEMORY +}; + +struct IndexConfig +{ + DataStoreStrategy data_strategy; + GraphStoreStrategy graph_strategy; + + Metric metric; + size_t dimension; + size_t max_points; + + bool dynamic_index; + bool enable_tags; + bool pq_dist_build; + bool concurrent_consolidate; + bool use_opq; + bool filtered_index; + + size_t num_pq_chunks; + size_t num_frozen_pts; + + std::string label_type; + std::string tag_type; + std::string data_type; + + // Params for building index + std::shared_ptr index_write_params; + // Params for searching index + std::shared_ptr index_search_params; + + private: + IndexConfig(DataStoreStrategy data_strategy, GraphStoreStrategy graph_strategy, Metric metric, size_t dimension, + size_t max_points, size_t num_pq_chunks, size_t num_frozen_points, bool dynamic_index, bool enable_tags, + bool pq_dist_build, bool concurrent_consolidate, bool use_opq, bool filtered_index, + std::string &data_type, const std::string &tag_type, const std::string &label_type, + std::shared_ptr index_write_params, + std::shared_ptr index_search_params) + : data_strategy(data_strategy), graph_strategy(graph_strategy), metric(metric), dimension(dimension), + max_points(max_points), dynamic_index(dynamic_index), enable_tags(enable_tags), pq_dist_build(pq_dist_build), + concurrent_consolidate(concurrent_consolidate), use_opq(use_opq), filtered_index(filtered_index), + num_pq_chunks(num_pq_chunks), num_frozen_pts(num_frozen_points), label_type(label_type), tag_type(tag_type), + data_type(data_type), index_write_params(index_write_params), index_search_params(index_search_params) + { + } + + friend class IndexConfigBuilder; +}; + +class IndexConfigBuilder +{ + public: + IndexConfigBuilder() = default; + + IndexConfigBuilder &with_metric(Metric m) + { + this->_metric = m; + return *this; + } + + IndexConfigBuilder &with_graph_load_store_strategy(GraphStoreStrategy graph_strategy) + { + this->_graph_strategy = graph_strategy; + return *this; + } + + IndexConfigBuilder &with_data_load_store_strategy(DataStoreStrategy data_strategy) + { + this->_data_strategy = data_strategy; + return *this; + } + + IndexConfigBuilder &with_dimension(size_t dimension) + { + this->_dimension = dimension; + return *this; + } + + IndexConfigBuilder &with_max_points(size_t max_points) + { + this->_max_points = max_points; + return *this; + } + + IndexConfigBuilder &is_dynamic_index(bool dynamic_index) 
+ { + this->_dynamic_index = dynamic_index; + return *this; + } + + IndexConfigBuilder &is_enable_tags(bool enable_tags) + { + this->_enable_tags = enable_tags; + return *this; + } + + IndexConfigBuilder &is_pq_dist_build(bool pq_dist_build) + { + this->_pq_dist_build = pq_dist_build; + return *this; + } + + IndexConfigBuilder &is_concurrent_consolidate(bool concurrent_consolidate) + { + this->_concurrent_consolidate = concurrent_consolidate; + return *this; + } + + IndexConfigBuilder &is_use_opq(bool use_opq) + { + this->_use_opq = use_opq; + return *this; + } + + IndexConfigBuilder &is_filtered(bool is_filtered) + { + this->_filtered_index = is_filtered; + return *this; + } + + IndexConfigBuilder &with_num_pq_chunks(size_t num_pq_chunks) + { + this->_num_pq_chunks = num_pq_chunks; + return *this; + } + + IndexConfigBuilder &with_num_frozen_pts(size_t num_frozen_pts) + { + this->_num_frozen_pts = num_frozen_pts; + return *this; + } + + IndexConfigBuilder &with_label_type(const std::string &label_type) + { + this->_label_type = label_type; + return *this; + } + + IndexConfigBuilder &with_tag_type(const std::string &tag_type) + { + this->_tag_type = tag_type; + return *this; + } + + IndexConfigBuilder &with_data_type(const std::string &data_type) + { + this->_data_type = data_type; + return *this; + } + + IndexConfigBuilder &with_index_write_params(IndexWriteParameters &index_write_params) + { + this->_index_write_params = std::make_shared(index_write_params); + return *this; + } + + IndexConfigBuilder &with_index_write_params(std::shared_ptr index_write_params_ptr) + { + if (index_write_params_ptr == nullptr) + { + diskann::cout << "Passed, empty build_params while creating index config" << std::endl; + return *this; + } + this->_index_write_params = index_write_params_ptr; + return *this; + } + + IndexConfigBuilder &with_index_search_params(IndexSearchParams &search_params) + { + this->_index_search_params = std::make_shared(search_params); + return *this; + } + + IndexConfigBuilder &with_index_search_params(std::shared_ptr search_params_ptr) + { + if (search_params_ptr == nullptr) + { + diskann::cout << "Passed, empty search_params while creating index config" << std::endl; + return *this; + } + this->_index_search_params = search_params_ptr; + return *this; + } + + IndexConfig build() + { + if (_data_type == "" || _data_type.empty()) + throw ANNException("Error: data_type can not be empty", -1); + + if (_dynamic_index && _num_frozen_pts == 0) + { + _num_frozen_pts = 1; + } + + if (_dynamic_index) + { + if (_index_search_params != nullptr && _index_search_params->initial_search_list_size == 0) + throw ANNException("Error: please pass initial_search_list_size for " + "building dynamic index.", + -1); + } + + // sanity check + if (_dynamic_index && _num_frozen_pts == 0) + { + diskann::cout << "_num_frozen_pts passed as 0 for dynamic_index. Setting " + "it to 1 for safety." 
+ << std::endl; + _num_frozen_pts = 1; + } + + return IndexConfig(_data_strategy, _graph_strategy, _metric, _dimension, _max_points, _num_pq_chunks, + _num_frozen_pts, _dynamic_index, _enable_tags, _pq_dist_build, _concurrent_consolidate, + _use_opq, _filtered_index, _data_type, _tag_type, _label_type, _index_write_params, + _index_search_params); + } + + IndexConfigBuilder(const IndexConfigBuilder &) = delete; + IndexConfigBuilder &operator=(const IndexConfigBuilder &) = delete; + + private: + DataStoreStrategy _data_strategy; + GraphStoreStrategy _graph_strategy; + + Metric _metric; + size_t _dimension; + size_t _max_points; + + bool _dynamic_index = false; + bool _enable_tags = false; + bool _pq_dist_build = false; + bool _concurrent_consolidate = false; + bool _use_opq = false; + bool _filtered_index{defaults::HAS_LABELS}; + + size_t _num_pq_chunks = 0; + size_t _num_frozen_pts{defaults::NUM_FROZEN_POINTS_STATIC}; + + std::string _label_type{"uint32"}; + std::string _tag_type{"uint32"}; + std::string _data_type; + + std::shared_ptr _index_write_params; + std::shared_ptr _index_search_params; }; } // namespace diskann diff --git a/include/index_factory.h b/include/index_factory.h index dabd8837e..a41c1f50f 100644 --- a/include/index_factory.h +++ b/include/index_factory.h @@ -5,49 +5,47 @@ #include "index.h" #include "pq_data_store.h" -namespace diskann { -class IndexFactory { -public: - DISKANN_DLLEXPORT explicit IndexFactory(const IndexConfig &config); - DISKANN_DLLEXPORT std::unique_ptr create_instance(); - - DISKANN_DLLEXPORT static std::unique_ptr - construct_graphstore(const GraphStoreStrategy stratagy, const size_t size, - const size_t reserve_graph_degree); - - template - DISKANN_DLLEXPORT static std::shared_ptr> - construct_datastore(DataStoreStrategy stratagy, size_t num_points, - size_t dimension, Metric m); - // For now PQDataStore incorporates within itself all variants of quantization - // that we support. In the future it may be necessary to introduce an - // AbstractPQDataStore class to spearate various quantization flavours. - template - DISKANN_DLLEXPORT static std::shared_ptr> - construct_pq_datastore(DataStoreStrategy strategy, size_t num_points, - size_t dimension, Metric m, size_t num_pq_chunks, - bool use_opq); - template - static Distance *construct_inmem_distance_fn(Metric m); - -private: - void check_config(); - - template - std::unique_ptr create_instance(); - - std::unique_ptr create_instance(const std::string &data_type, - const std::string &tag_type, - const std::string &label_type); - - template - std::unique_ptr create_instance(const std::string &tag_type, - const std::string &label_type); - - template - std::unique_ptr create_instance(const std::string &label_type); - - std::unique_ptr _config; +namespace diskann +{ +class IndexFactory +{ + public: + DISKANN_DLLEXPORT explicit IndexFactory(const IndexConfig &config); + DISKANN_DLLEXPORT std::unique_ptr create_instance(); + + DISKANN_DLLEXPORT static std::unique_ptr construct_graphstore( + const GraphStoreStrategy stratagy, const size_t size, const size_t reserve_graph_degree); + + template + DISKANN_DLLEXPORT static std::shared_ptr> construct_datastore(DataStoreStrategy stratagy, + size_t num_points, + size_t dimension, Metric m); + // For now PQDataStore incorporates within itself all variants of quantization + // that we support. In the future it may be necessary to introduce an + // AbstractPQDataStore class to spearate various quantization flavours. 
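Editorial aside (not part of the patch): a minimal usage sketch of the fluent IndexConfigBuilder / IndexFactory API whose reformatting appears above. The metric, dimension, max_points and "float" data type are illustrative assumptions, and the sketch assumes create_instance() returns a std::unique_ptr<diskann::AbstractIndex>; a real build would typically also attach IndexWriteParameters via with_index_write_params().

#include <memory>
#include "index_factory.h"

// Sketch: assemble an in-memory index configuration and hand it to the factory.
std::unique_ptr<diskann::AbstractIndex> make_in_memory_index()
{
    diskann::IndexConfig config = diskann::IndexConfigBuilder()
                                      .with_metric(diskann::L2)
                                      .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
                                      .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
                                      .with_dimension(128)     // illustrative value
                                      .with_max_points(100000) // illustrative value
                                      .with_data_type("float") // build() rejects an empty data_type
                                      .build();

    diskann::IndexFactory factory(config);
    return factory.create_instance();
}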
+ template + DISKANN_DLLEXPORT static std::shared_ptr> construct_pq_datastore(DataStoreStrategy strategy, + size_t num_points, size_t dimension, + Metric m, size_t num_pq_chunks, + bool use_opq); + template static Distance *construct_inmem_distance_fn(Metric m); + + private: + void check_config(); + + template + std::unique_ptr create_instance(); + + std::unique_ptr create_instance(const std::string &data_type, const std::string &tag_type, + const std::string &label_type); + + template + std::unique_ptr create_instance(const std::string &tag_type, const std::string &label_type); + + template + std::unique_ptr create_instance(const std::string &label_type); + + std::unique_ptr _config; }; } // namespace diskann diff --git a/include/linux_aligned_file_reader.h b/include/linux_aligned_file_reader.h index 5907ff750..7620e3194 100644 --- a/include/linux_aligned_file_reader.h +++ b/include/linux_aligned_file_reader.h @@ -6,34 +6,34 @@ #include "aligned_file_reader.h" -class LinuxAlignedFileReader : public AlignedFileReader { -private: - uint64_t file_sz; - FileHandle file_desc; - io_context_t bad_ctx = (io_context_t)-1; - -public: - LinuxAlignedFileReader(); - ~LinuxAlignedFileReader(); - - IOContext &get_ctx(); - - // register thread-id for a context - void register_thread(); - - // de-register thread-id for a context - void deregister_thread(); - void deregister_all_threads(); - - // Open & close ops - // Blocking calls - void open(const std::string &fname); - void close(); - - // process batch of aligned requests in parallel - // NOTE :: blocking call - void read(std::vector &read_reqs, IOContext &ctx, - bool async = false); +class LinuxAlignedFileReader : public AlignedFileReader +{ + private: + uint64_t file_sz; + FileHandle file_desc; + io_context_t bad_ctx = (io_context_t)-1; + + public: + LinuxAlignedFileReader(); + ~LinuxAlignedFileReader(); + + IOContext &get_ctx(); + + // register thread-id for a context + void register_thread(); + + // de-register thread-id for a context + void deregister_thread(); + void deregister_all_threads(); + + // Open & close ops + // Blocking calls + void open(const std::string &fname); + void close(); + + // process batch of aligned requests in parallel + // NOTE :: blocking call + void read(std::vector &read_reqs, IOContext &ctx, bool async = false); }; #endif diff --git a/include/locking.h b/include/locking.h index e810ad65b..890c24a2b 100644 --- a/include/locking.h +++ b/include/locking.h @@ -8,7 +8,8 @@ #include "windows_slim_lock.h" #endif -namespace diskann { +namespace diskann +{ #ifdef _WINDOWS using non_recursive_mutex = windows_exclusive_slim_lock; using LockGuard = windows_exclusive_slim_lock_guard; diff --git a/include/logger.h b/include/logger.h index 7eccdc312..f1c6ee7f3 100644 --- a/include/logger.h +++ b/include/logger.h @@ -12,7 +12,8 @@ #endif // !ENABLE_CUSTOM_LOGGER #endif // EXEC_ENV_OLS -namespace diskann { +namespace diskann +{ #ifdef ENABLE_CUSTOM_LOGGER DISKANN_DLLEXPORT extern std::basic_ostream cout; DISKANN_DLLEXPORT extern std::basic_ostream cerr; @@ -21,10 +22,14 @@ using std::cerr; using std::cout; #endif -enum class DISKANN_DLLEXPORT LogLevel { LL_Info = 0, LL_Error, LL_Count }; +enum class DISKANN_DLLEXPORT LogLevel +{ + LL_Info = 0, + LL_Error, + LL_Count +}; #ifdef ENABLE_CUSTOM_LOGGER -DISKANN_DLLEXPORT void -SetCustomLogger(std::function logger); +DISKANN_DLLEXPORT void SetCustomLogger(std::function logger); #endif } // namespace diskann diff --git a/include/logger_impl.h b/include/logger_impl.h index 
49cb5b830..d2dfaf573 100644 --- a/include/logger_impl.h +++ b/include/logger_impl.h @@ -9,50 +9,53 @@ #include "ann_exception.h" #include "logger.h" -namespace diskann { +namespace diskann +{ #ifdef ENABLE_CUSTOM_LOGGER -class ANNStreamBuf : public std::basic_streambuf { -public: - DISKANN_DLLEXPORT explicit ANNStreamBuf(FILE *fp); - DISKANN_DLLEXPORT ~ANNStreamBuf(); - - DISKANN_DLLEXPORT bool is_open() const { - return true; // because stdout and stderr are always open. - } - DISKANN_DLLEXPORT void close(); - DISKANN_DLLEXPORT virtual int underflow(); - DISKANN_DLLEXPORT virtual int overflow(int c); - DISKANN_DLLEXPORT virtual int sync(); - -private: - FILE *_fp; - char *_buf; - int _bufIndex; - std::mutex _mutex; - LogLevel _logLevel; - - int flush(); - void logImpl(char *str, int numchars); - - // Why the two buffer-sizes? If we are running normally, we are basically - // interacting with a character output system, so we short-circuit the - // output process by keeping an empty buffer and writing each character - // to stdout/stderr. But if we are running in OLS, we have to take all - // the text that is written to diskann::cout/diskann:cerr, consolidate it - // and push it out in one-shot, because the OLS infra does not give us - // character based output. Therefore, we use a larger buffer that is large - // enough to store the longest message, and continuously add characters - // to it. When the calling code outputs a std::endl or std::flush, sync() - // will be called and will output a log level, component name, and the text - // that has been collected. (sync() is also called if the buffer is full, so - // overflows/missing text are not a concern). - // This implies calling code _must_ either print std::endl or std::flush - // to ensure that the message is written immediately. - - static const int BUFFER_SIZE = 1024; - - ANNStreamBuf(const ANNStreamBuf &); - ANNStreamBuf &operator=(const ANNStreamBuf &); +class ANNStreamBuf : public std::basic_streambuf +{ + public: + DISKANN_DLLEXPORT explicit ANNStreamBuf(FILE *fp); + DISKANN_DLLEXPORT ~ANNStreamBuf(); + + DISKANN_DLLEXPORT bool is_open() const + { + return true; // because stdout and stderr are always open. + } + DISKANN_DLLEXPORT void close(); + DISKANN_DLLEXPORT virtual int underflow(); + DISKANN_DLLEXPORT virtual int overflow(int c); + DISKANN_DLLEXPORT virtual int sync(); + + private: + FILE *_fp; + char *_buf; + int _bufIndex; + std::mutex _mutex; + LogLevel _logLevel; + + int flush(); + void logImpl(char *str, int numchars); + + // Why the two buffer-sizes? If we are running normally, we are basically + // interacting with a character output system, so we short-circuit the + // output process by keeping an empty buffer and writing each character + // to stdout/stderr. But if we are running in OLS, we have to take all + // the text that is written to diskann::cout/diskann:cerr, consolidate it + // and push it out in one-shot, because the OLS infra does not give us + // character based output. Therefore, we use a larger buffer that is large + // enough to store the longest message, and continuously add characters + // to it. When the calling code outputs a std::endl or std::flush, sync() + // will be called and will output a log level, component name, and the text + // that has been collected. (sync() is also called if the buffer is full, so + // overflows/missing text are not a concern). + // This implies calling code _must_ either print std::endl or std::flush + // to ensure that the message is written immediately. 
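Editorial aside (not part of the patch): the comment block above implies a usage contract for the custom-logger build, illustrated by the hedged sketch below. Under ENABLE_CUSTOM_LOGGER the collected text is only handed to the registered logger on sync(), so callers should terminate each message with std::endl (or std::flush); the function name and message text here are illustrative.

#include <cstddef>
#include "logger.h"

// Sketch: always end messages with std::endl so ANNStreamBuf::sync() fires
// and the custom logger (if registered) receives the complete line.
void report_progress(size_t points_indexed)
{
    diskann::cout << "Indexed " << points_indexed << " points" << std::endl;
    diskann::cerr << "Warning: memory running low" << std::endl;
}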
+ + static const int BUFFER_SIZE = 1024; + + ANNStreamBuf(const ANNStreamBuf &); + ANNStreamBuf &operator=(const ANNStreamBuf &); }; #endif } // namespace diskann diff --git a/include/math_utils.h b/include/math_utils.h index 43da740fa..83d189f70 100644 --- a/include/math_utils.h +++ b/include/math_utils.h @@ -6,18 +6,17 @@ #include "common_includes.h" #include "utils.h" -namespace math_utils { +namespace math_utils +{ float calc_distance(float *vec_1, float *vec_2, size_t dim); // compute l2-squared norms of data stored in row major num_points * dim, // needs // to be pre-allocated -void compute_vecs_l2sq(float *vecs_l2sq, float *data, const size_t num_points, - const size_t dim); +void compute_vecs_l2sq(float *vecs_l2sq, float *data, const size_t num_points, const size_t dim); -void rotate_data_randomly(float *data, size_t num_points, size_t dim, - float *rot_mat, float *&new_mat, +void rotate_data_randomly(float *data, size_t num_points, size_t dim, float *rot_mat, float *&new_mat, bool transpose_rot = false); // calculate closest center to data of num_points * dim (row major) @@ -29,11 +28,10 @@ void rotate_data_randomly(float *data, size_t num_points, size_t dim, // squared distances // Ideally used only by compute_closest_centers -void compute_closest_centers_in_block( - const float *const data, const size_t num_points, const size_t dim, - const float *const centers, const size_t num_centers, - const float *const docs_l2sq, const float *const centers_l2sq, - uint32_t *center_index, float *const dist_matrix, size_t k = 1); +void compute_closest_centers_in_block(const float *const data, const size_t num_points, const size_t dim, + const float *const centers, const size_t num_centers, + const float *const docs_l2sq, const float *const centers_l2sq, + uint32_t *center_index, float *const dist_matrix, size_t k = 1); // Given data in num_points * new_dim row major // Pivots stored in full_pivot_data as k * new_dim row major @@ -45,23 +43,21 @@ void compute_closest_centers_in_block( // those // values -void compute_closest_centers(float *data, size_t num_points, size_t dim, - float *pivot_data, size_t num_centers, size_t k, - uint32_t *closest_centers_ivf, - std::vector *inverted_index = NULL, +void compute_closest_centers(float *data, size_t num_points, size_t dim, float *pivot_data, size_t num_centers, + size_t k, uint32_t *closest_centers_ivf, std::vector *inverted_index = NULL, float *pts_norms_squared = NULL); // if to_subtract is 1, will subtract nearest center from each row. Else will // add. Output will be in data_load iself. // Nearest centers need to be provided in closst_centers. -void process_residuals(float *data_load, size_t num_points, size_t dim, - float *cur_pivot_data, size_t num_centers, +void process_residuals(float *data_load, size_t num_points, size_t dim, float *cur_pivot_data, size_t num_centers, uint32_t *closest_centers, bool to_subtract); } // namespace math_utils -namespace kmeans { +namespace kmeans +{ // run Lloyds one iteration // Given data in row major num_points * dim, and centers in row major @@ -71,8 +67,7 @@ namespace kmeans { // If closest_centers == NULL, will allocate memory and return. // Similarly, if closest_docs == NULL, will allocate memory and return. 
-float lloyds_iter(float *data, size_t num_points, size_t dim, float *centers, - size_t num_centers, float *docs_l2sq, +float lloyds_iter(float *data, size_t num_points, size_t dim, float *centers, size_t num_centers, float *docs_l2sq, std::vector *closest_docs, uint32_t *&closest_center); // Run Lloyds until max_reps or stopping criterion @@ -81,15 +76,12 @@ float lloyds_iter(float *data, size_t num_points, size_t dim, float *centers, // vector [num_centers], and closest_center = new size_t[num_points] // Final centers are output in centers as row major num_centers * dim // -float run_lloyds(float *data, size_t num_points, size_t dim, float *centers, - const size_t num_centers, const size_t max_reps, - std::vector *closest_docs, uint32_t *closest_center); +float run_lloyds(float *data, size_t num_points, size_t dim, float *centers, const size_t num_centers, + const size_t max_reps, std::vector *closest_docs, uint32_t *closest_center); // assumes already memory allocated for pivot_data as new // float[num_centers*dim] and select randomly num_centers points as pivots -void selecting_pivots(float *data, size_t num_points, size_t dim, - float *pivot_data, size_t num_centers); +void selecting_pivots(float *data, size_t num_points, size_t dim, float *pivot_data, size_t num_centers); -void kmeanspp_selecting_pivots(float *data, size_t num_points, size_t dim, - float *pivot_data, size_t num_centers); +void kmeanspp_selecting_pivots(float *data, size_t num_points, size_t dim, float *pivot_data, size_t num_centers); } // namespace kmeans diff --git a/include/memory_mapper.h b/include/memory_mapper.h index a84c31b39..75faca1bb 100644 --- a/include/memory_mapper.h +++ b/include/memory_mapper.h @@ -15,27 +15,29 @@ #endif #include -namespace diskann { -class MemoryMapper { -private: +namespace diskann +{ +class MemoryMapper +{ + private: #ifndef _WINDOWS - int _fd; + int _fd; #else - HANDLE _bareFile; - HANDLE _fd; + HANDLE _bareFile; + HANDLE _fd; #endif - char *_buf; - size_t _fileSize; - const char *_fileName; + char *_buf; + size_t _fileSize; + const char *_fileName; -public: - MemoryMapper(const char *filename); - MemoryMapper(const std::string &filename); + public: + MemoryMapper(const char *filename); + MemoryMapper(const std::string &filename); - char *getBuf(); - size_t getFileSize(); + char *getBuf(); + size_t getFileSize(); - ~MemoryMapper(); + ~MemoryMapper(); }; } // namespace diskann diff --git a/include/natural_number_map.h b/include/natural_number_map.h index b11bd596f..e846882a8 100644 --- a/include/natural_number_map.h +++ b/include/natural_number_map.h @@ -9,7 +9,8 @@ #include -namespace diskann { +namespace diskann +{ // A map whose key is a natural number (from 0 onwards) and maps to a value. // Made as both memory and performance efficient map for scenario such as // DiskANN location-to-tag map. There, the pool of numbers is consecutive from @@ -21,63 +22,65 @@ namespace diskann { // Thread-safety: this class is not thread-safe in general. // Exception: multiple read-only operations are safe on the object only if // there are no writers to it in parallel. -template class natural_number_map { -public: - static_assert(std::is_trivial::value, "Key must be a trivial type"); - - // Represents a reference to a element in the map. Used while iterating - // over map entries. - struct position { - size_t _key; - // The number of keys that were enumerated when iterating through the - // map so far. Used to early-terminate enumeration when ithere are no - // more entries in the map. 
- size_t _keys_already_enumerated; - - // Returns whether it's valid to access the element at this position in - // the map. - bool is_valid() const; - }; - - natural_number_map(); - - void reserve(size_t count); - size_t size() const; - - void set(Key key, Value value); - void erase(Key key); - - bool contains(Key key) const; - bool try_get(Key key, Value &value) const; - - // Returns the value at the specified position. Prerequisite: position is - // valid. - Value get(const position &pos) const; - - // Finds the first element in the map, if any. Invalidated by changes in the - // map. - position find_first() const; - - // Finds the next element in the map after the specified position. - // Invalidated by changes in the map. - position find_next(const position &after_position) const; - - void clear(); - -private: - // Number of entries in the map. Not the same as size() of the - // _values_vector below. - size_t _size; - - // Array of values. The key is the index of the value. - std::vector _values_vector; - - // Values that are in the set have the corresponding bit index set - // to 1. - // - // Use a pointer here to allow for forward declaration of dynamic_bitset - // in public headers to avoid making boost a dependency for clients - // of DiskANN. - std::unique_ptr> _values_bitset; +template class natural_number_map +{ + public: + static_assert(std::is_trivial::value, "Key must be a trivial type"); + + // Represents a reference to a element in the map. Used while iterating + // over map entries. + struct position + { + size_t _key; + // The number of keys that were enumerated when iterating through the + // map so far. Used to early-terminate enumeration when ithere are no + // more entries in the map. + size_t _keys_already_enumerated; + + // Returns whether it's valid to access the element at this position in + // the map. + bool is_valid() const; + }; + + natural_number_map(); + + void reserve(size_t count); + size_t size() const; + + void set(Key key, Value value); + void erase(Key key); + + bool contains(Key key) const; + bool try_get(Key key, Value &value) const; + + // Returns the value at the specified position. Prerequisite: position is + // valid. + Value get(const position &pos) const; + + // Finds the first element in the map, if any. Invalidated by changes in the + // map. + position find_first() const; + + // Finds the next element in the map after the specified position. + // Invalidated by changes in the map. + position find_next(const position &after_position) const; + + void clear(); + + private: + // Number of entries in the map. Not the same as size() of the + // _values_vector below. + size_t _size; + + // Array of values. The key is the index of the value. + std::vector _values_vector; + + // Values that are in the set have the corresponding bit index set + // to 1. + // + // Use a pointer here to allow for forward declaration of dynamic_bitset + // in public headers to avoid making boost a dependency for clients + // of DiskANN. + std::unique_ptr> _values_bitset; }; } // namespace diskann diff --git a/include/natural_number_set.h b/include/natural_number_set.h index 1720b5a39..ec5b827e6 100644 --- a/include/natural_number_set.h +++ b/include/natural_number_set.h @@ -8,7 +8,8 @@ #include "boost_dynamic_bitset_fwd.h" -namespace diskann { +namespace diskann +{ // A set of natural numbers (from 0 onwards). 
Made for scenario where the // pool of numbers is consecutive from zero to some max value and very // efficient methods for "add to set", "get any value from set", "is in set" @@ -19,30 +20,31 @@ namespace diskann { // Thread-safety: this class is not thread-safe in general. // Exception: multiple read-only operations (e.g. is_in_set, empty, size) are // safe on the object only if there are no writers to it in parallel. -template class natural_number_set { -public: - static_assert(std::is_trivial::value, "Identifier must be a trivial type"); - - natural_number_set(); - - bool is_empty() const; - void reserve(size_t count); - void insert(T id); - T pop_any(); - void clear(); - size_t size() const; - bool is_in_set(T id) const; - -private: - // Values that are currently in set. - std::vector _values_vector; - - // Values that are in the set have the corresponding bit index set - // to 1. - // - // Use a pointer here to allow for forward declaration of dynamic_bitset - // in public headers to avoid making boost a dependency for clients - // of DiskANN. - std::unique_ptr> _values_bitset; +template class natural_number_set +{ + public: + static_assert(std::is_trivial::value, "Identifier must be a trivial type"); + + natural_number_set(); + + bool is_empty() const; + void reserve(size_t count); + void insert(T id); + T pop_any(); + void clear(); + size_t size() const; + bool is_in_set(T id) const; + + private: + // Values that are currently in set. + std::vector _values_vector; + + // Values that are in the set have the corresponding bit index set + // to 1. + // + // Use a pointer here to allow for forward declaration of dynamic_bitset + // in public headers to avoid making boost a dependency for clients + // of DiskANN. + std::unique_ptr> _values_bitset; }; } // namespace diskann diff --git a/include/neighbor.h b/include/neighbor.h index 788c2d12b..61a6932c1 100644 --- a/include/neighbor.h +++ b/include/neighbor.h @@ -8,106 +8,145 @@ #include #include -namespace diskann { +namespace diskann +{ -struct Neighbor { - unsigned id; - float distance; - bool expanded; +struct Neighbor +{ + unsigned id; + float distance; + bool expanded; - Neighbor() = default; + Neighbor() = default; - Neighbor(unsigned id, float distance) - : id{id}, distance{distance}, expanded(false) {} + Neighbor(unsigned id, float distance) : id{id}, distance{distance}, expanded(false) + { + } - inline bool operator<(const Neighbor &other) const { - return distance < other.distance || - (distance == other.distance && id < other.id); - } + inline bool operator<(const Neighbor &other) const + { + return distance < other.distance || (distance == other.distance && id < other.id); + } - inline bool operator==(const Neighbor &other) const { - return (id == other.id); - } + inline bool operator==(const Neighbor &other) const + { + return (id == other.id); + } }; // Invariant: after every `insert` and `closest_unexpanded()`, `_cur` points to // the first Neighbor which is unexpanded. -class NeighborPriorityQueue { -public: - NeighborPriorityQueue() : _size(0), _capacity(0), _cur(0) {} - - explicit NeighborPriorityQueue(size_t capacity) - : _size(0), _capacity(capacity), _cur(0), _data(capacity + 1) {} - - // Inserts the item ordered into the set up to the sets capacity. - // The item will be dropped if it is the same id as an exiting - // set item or it has a greated distance than the final - // item in the set. 
The set cursor that is used to pop() the - // next item will be set to the lowest index of an uncheck item - void insert(const Neighbor &nbr) { - if (_size == _capacity && _data[_size - 1] < nbr) { - return; +class NeighborPriorityQueue +{ + public: + NeighborPriorityQueue() : _size(0), _capacity(0), _cur(0) + { } - size_t lo = 0, hi = _size; - while (lo < hi) { - size_t mid = (lo + hi) >> 1; - if (nbr < _data[mid]) { - hi = mid; - // Make sure the same id isn't inserted into the set - } else if (_data[mid].id == nbr.id) { - return; - } else { - lo = mid + 1; - } + explicit NeighborPriorityQueue(size_t capacity) : _size(0), _capacity(capacity), _cur(0), _data(capacity + 1) + { } - if (lo < _capacity) { - std::memmove(&_data[lo + 1], &_data[lo], (_size - lo) * sizeof(Neighbor)); + // Inserts the item ordered into the set up to the sets capacity. + // The item will be dropped if it is the same id as an exiting + // set item or it has a greated distance than the final + // item in the set. The set cursor that is used to pop() the + // next item will be set to the lowest index of an uncheck item + void insert(const Neighbor &nbr) + { + if (_size == _capacity && _data[_size - 1] < nbr) + { + return; + } + + size_t lo = 0, hi = _size; + while (lo < hi) + { + size_t mid = (lo + hi) >> 1; + if (nbr < _data[mid]) + { + hi = mid; + // Make sure the same id isn't inserted into the set + } + else if (_data[mid].id == nbr.id) + { + return; + } + else + { + lo = mid + 1; + } + } + + if (lo < _capacity) + { + std::memmove(&_data[lo + 1], &_data[lo], (_size - lo) * sizeof(Neighbor)); + } + _data[lo] = {nbr.id, nbr.distance}; + if (_size < _capacity) + { + _size++; + } + if (lo < _cur) + { + _cur = lo; + } } - _data[lo] = {nbr.id, nbr.distance}; - if (_size < _capacity) { - _size++; - } - if (lo < _cur) { - _cur = lo; - } - } - Neighbor closest_unexpanded() { - _data[_cur].expanded = true; - size_t pre = _cur; - while (_cur < _size && _data[_cur].expanded) { - _cur++; + Neighbor closest_unexpanded() + { + _data[_cur].expanded = true; + size_t pre = _cur; + while (_cur < _size && _data[_cur].expanded) + { + _cur++; + } + return _data[pre]; } - return _data[pre]; - } - bool has_unexpanded_node() const { return _cur < _size; } + bool has_unexpanded_node() const + { + return _cur < _size; + } - size_t size() const { return _size; } + size_t size() const + { + return _size; + } - size_t capacity() const { return _capacity; } + size_t capacity() const + { + return _capacity; + } - void reserve(size_t capacity) { - if (capacity + 1 > _data.size()) { - _data.resize(capacity + 1); + void reserve(size_t capacity) + { + if (capacity + 1 > _data.size()) + { + _data.resize(capacity + 1); + } + _capacity = capacity; } - _capacity = capacity; - } - Neighbor &operator[](size_t i) { return _data[i]; } + Neighbor &operator[](size_t i) + { + return _data[i]; + } - Neighbor operator[](size_t i) const { return _data[i]; } + Neighbor operator[](size_t i) const + { + return _data[i]; + } - void clear() { - _size = 0; - _cur = 0; - } + void clear() + { + _size = 0; + _cur = 0; + } -private: - size_t _size, _capacity, _cur; - std::vector _data; + private: + size_t _size, _capacity, _cur; + std::vector _data; }; } // namespace diskann diff --git a/include/parameters.h b/include/parameters.h index 3eb444e35..50e7e4a1a 100644 --- a/include/parameters.h +++ b/include/parameters.h @@ -9,105 +9,111 @@ #include "defaults.h" #include "omp.h" -namespace diskann { +namespace diskann +{ class IndexWriteParameters { -public: - const uint32_t 
search_list_size; // L - const uint32_t max_degree; // R - const bool saturate_graph; - const uint32_t max_occlusion_size; // C - const float alpha; - const uint32_t num_threads; - const uint32_t filter_list_size; // Lf - - IndexWriteParameters(const uint32_t search_list_size, - const uint32_t max_degree, const bool saturate_graph, - const uint32_t max_occlusion_size, const float alpha, - const uint32_t num_threads, - const uint32_t filter_list_size) - : search_list_size(search_list_size), max_degree(max_degree), - saturate_graph(saturate_graph), max_occlusion_size(max_occlusion_size), - alpha(alpha), num_threads(num_threads), - filter_list_size(filter_list_size) {} - - friend class IndexWriteParametersBuilder; + public: + const uint32_t search_list_size; // L + const uint32_t max_degree; // R + const bool saturate_graph; + const uint32_t max_occlusion_size; // C + const float alpha; + const uint32_t num_threads; + const uint32_t filter_list_size; // Lf + + IndexWriteParameters(const uint32_t search_list_size, const uint32_t max_degree, const bool saturate_graph, + const uint32_t max_occlusion_size, const float alpha, const uint32_t num_threads, + const uint32_t filter_list_size) + : search_list_size(search_list_size), max_degree(max_degree), saturate_graph(saturate_graph), + max_occlusion_size(max_occlusion_size), alpha(alpha), num_threads(num_threads), + filter_list_size(filter_list_size) + { + } + + friend class IndexWriteParametersBuilder; }; -class IndexSearchParams { -public: - IndexSearchParams(const uint32_t initial_search_list_size, - const uint32_t num_search_threads) - : initial_search_list_size(initial_search_list_size), - num_search_threads(num_search_threads) {} - const uint32_t initial_search_list_size; // search L - const uint32_t num_search_threads; // search threads +class IndexSearchParams +{ + public: + IndexSearchParams(const uint32_t initial_search_list_size, const uint32_t num_search_threads) + : initial_search_list_size(initial_search_list_size), num_search_threads(num_search_threads) + { + } + const uint32_t initial_search_list_size; // search L + const uint32_t num_search_threads; // search threads }; -class IndexWriteParametersBuilder { - /** - * Fluent builder pattern to keep track of the 7 non-default properties - * and their order. The basic ctor was getting unwieldy. - */ -public: - IndexWriteParametersBuilder(const uint32_t search_list_size, // L - const uint32_t max_degree // R - ) - : _search_list_size(search_list_size), _max_degree(max_degree) {} - - IndexWriteParametersBuilder & - with_max_occlusion_size(const uint32_t max_occlusion_size) { - _max_occlusion_size = max_occlusion_size; - return *this; - } - - IndexWriteParametersBuilder &with_saturate_graph(const bool saturate_graph) { - _saturate_graph = saturate_graph; - return *this; - } - - IndexWriteParametersBuilder &with_alpha(const float alpha) { - _alpha = alpha; - return *this; - } - - IndexWriteParametersBuilder &with_num_threads(const uint32_t num_threads) { - _num_threads = num_threads == 0 ? omp_get_num_procs() : num_threads; - return *this; - } - - IndexWriteParametersBuilder & - with_filter_list_size(const uint32_t filter_list_size) { - _filter_list_size = - filter_list_size == 0 ? 
_search_list_size : filter_list_size; - return *this; - } - - IndexWriteParameters build() const { - return IndexWriteParameters(_search_list_size, _max_degree, _saturate_graph, - _max_occlusion_size, _alpha, _num_threads, - _filter_list_size); - } - - IndexWriteParametersBuilder(const IndexWriteParameters &wp) - : _search_list_size(wp.search_list_size), _max_degree(wp.max_degree), - _max_occlusion_size(wp.max_occlusion_size), - _saturate_graph(wp.saturate_graph), _alpha(wp.alpha), - _filter_list_size(wp.filter_list_size) {} - IndexWriteParametersBuilder(const IndexWriteParametersBuilder &) = delete; - IndexWriteParametersBuilder & - operator=(const IndexWriteParametersBuilder &) = delete; - -private: - uint32_t _search_list_size{}; - uint32_t _max_degree{}; - uint32_t _max_occlusion_size{defaults::MAX_OCCLUSION_SIZE}; - bool _saturate_graph{defaults::SATURATE_GRAPH}; - float _alpha{defaults::ALPHA}; - uint32_t _num_threads{defaults::NUM_THREADS}; - uint32_t _filter_list_size{defaults::FILTER_LIST_SIZE}; +class IndexWriteParametersBuilder +{ + /** + * Fluent builder pattern to keep track of the 7 non-default properties + * and their order. The basic ctor was getting unwieldy. + */ + public: + IndexWriteParametersBuilder(const uint32_t search_list_size, // L + const uint32_t max_degree // R + ) + : _search_list_size(search_list_size), _max_degree(max_degree) + { + } + + IndexWriteParametersBuilder &with_max_occlusion_size(const uint32_t max_occlusion_size) + { + _max_occlusion_size = max_occlusion_size; + return *this; + } + + IndexWriteParametersBuilder &with_saturate_graph(const bool saturate_graph) + { + _saturate_graph = saturate_graph; + return *this; + } + + IndexWriteParametersBuilder &with_alpha(const float alpha) + { + _alpha = alpha; + return *this; + } + + IndexWriteParametersBuilder &with_num_threads(const uint32_t num_threads) + { + _num_threads = num_threads == 0 ? omp_get_num_procs() : num_threads; + return *this; + } + + IndexWriteParametersBuilder &with_filter_list_size(const uint32_t filter_list_size) + { + _filter_list_size = filter_list_size == 0 ? 
_search_list_size : filter_list_size; + return *this; + } + + IndexWriteParameters build() const + { + return IndexWriteParameters(_search_list_size, _max_degree, _saturate_graph, _max_occlusion_size, _alpha, + _num_threads, _filter_list_size); + } + + IndexWriteParametersBuilder(const IndexWriteParameters &wp) + : _search_list_size(wp.search_list_size), _max_degree(wp.max_degree), + _max_occlusion_size(wp.max_occlusion_size), _saturate_graph(wp.saturate_graph), _alpha(wp.alpha), + _filter_list_size(wp.filter_list_size) + { + } + IndexWriteParametersBuilder(const IndexWriteParametersBuilder &) = delete; + IndexWriteParametersBuilder &operator=(const IndexWriteParametersBuilder &) = delete; + + private: + uint32_t _search_list_size{}; + uint32_t _max_degree{}; + uint32_t _max_occlusion_size{defaults::MAX_OCCLUSION_SIZE}; + bool _saturate_graph{defaults::SATURATE_GRAPH}; + float _alpha{defaults::ALPHA}; + uint32_t _num_threads{defaults::NUM_THREADS}; + uint32_t _filter_list_size{defaults::FILTER_LIST_SIZE}; }; } // namespace diskann diff --git a/include/partition.h b/include/partition.h index a5af890be..c2c4c76ad 100644 --- a/include/partition.h +++ b/include/partition.h @@ -16,45 +16,34 @@ #include "windows_customizations.h" template -void gen_random_slice(const std::string base_file, - const std::string output_prefix, double sampling_rate); +void gen_random_slice(const std::string base_file, const std::string output_prefix, double sampling_rate); template -void gen_random_slice(const std::string data_file, double p_val, - float *&sampled_data, size_t &slice_size, size_t &ndims); +void gen_random_slice(const std::string data_file, double p_val, float *&sampled_data, size_t &slice_size, + size_t &ndims); template -void gen_random_slice(const T *inputdata, size_t npts, size_t ndims, - double p_val, float *&sampled_data, size_t &slice_size); +void gen_random_slice(const T *inputdata, size_t npts, size_t ndims, double p_val, float *&sampled_data, + size_t &slice_size); -int estimate_cluster_sizes(float *test_data_float, size_t num_test, - float *pivots, const size_t num_centers, - const size_t dim, const size_t k_base, - std::vector &cluster_sizes); +int estimate_cluster_sizes(float *test_data_float, size_t num_test, float *pivots, const size_t num_centers, + const size_t dim, const size_t k_base, std::vector &cluster_sizes); template -int shard_data_into_clusters(const std::string data_file, float *pivots, - const size_t num_centers, const size_t dim, +int shard_data_into_clusters(const std::string data_file, float *pivots, const size_t num_centers, const size_t dim, const size_t k_base, std::string prefix_path); template -int shard_data_into_clusters_only_ids(const std::string data_file, - float *pivots, const size_t num_centers, - const size_t dim, const size_t k_base, - std::string prefix_path); +int shard_data_into_clusters_only_ids(const std::string data_file, float *pivots, const size_t num_centers, + const size_t dim, const size_t k_base, std::string prefix_path); template -int retrieve_shard_data_from_ids(const std::string data_file, - std::string idmap_filename, - std::string data_filename); +int retrieve_shard_data_from_ids(const std::string data_file, std::string idmap_filename, std::string data_filename); template -int partition(const std::string data_file, const float sampling_rate, - size_t num_centers, size_t max_k_means_reps, +int partition(const std::string data_file, const float sampling_rate, size_t num_centers, size_t max_k_means_reps, const std::string prefix_path, size_t 
k_base); template -int partition_with_ram_budget(const std::string data_file, - const double sampling_rate, double ram_budget, - size_t graph_degree, - const std::string prefix_path, size_t k_base); +int partition_with_ram_budget(const std::string data_file, const double sampling_rate, double ram_budget, + size_t graph_degree, const std::string prefix_path, size_t k_base); diff --git a/include/percentile_stats.h b/include/percentile_stats.h index 2a6d2e9a2..793257577 100644 --- a/include/percentile_stats.h +++ b/include/percentile_stats.h @@ -16,48 +16,50 @@ #include "distance.h" #include "parameters.h" -namespace diskann { -struct QueryStats { - float total_us = 0; // total time to process query in micros - float io_us = 0; // total time spent in IO - float cpu_us = 0; // total time spent in CPU - - unsigned n_4k = 0; // # of 4kB reads - unsigned n_8k = 0; // # of 8kB reads - unsigned n_12k = 0; // # of 12kB reads - unsigned n_ios = 0; // total # of IOs issued - unsigned read_size = 0; // total # of bytes read - unsigned n_cmps_saved = 0; // # cmps saved - unsigned n_cmps = 0; // # cmps - unsigned n_cache_hits = 0; // # cache_hits - unsigned n_hops = 0; // # search hops +namespace diskann +{ +struct QueryStats +{ + float total_us = 0; // total time to process query in micros + float io_us = 0; // total time spent in IO + float cpu_us = 0; // total time spent in CPU + + unsigned n_4k = 0; // # of 4kB reads + unsigned n_8k = 0; // # of 8kB reads + unsigned n_12k = 0; // # of 12kB reads + unsigned n_ios = 0; // total # of IOs issued + unsigned read_size = 0; // total # of bytes read + unsigned n_cmps_saved = 0; // # cmps saved + unsigned n_cmps = 0; // # cmps + unsigned n_cache_hits = 0; // # cache_hits + unsigned n_hops = 0; // # search hops }; template -inline T -get_percentile_stats(QueryStats *stats, uint64_t len, float percentile, - const std::function &member_fn) { - std::vector vals(len); - for (uint64_t i = 0; i < len; i++) { - vals[i] = member_fn(stats[i]); - } - - std::sort(vals.begin(), vals.end(), - [](const T &left, const T &right) { return left < right; }); - - auto retval = vals[(uint64_t)(percentile * len)]; - vals.clear(); - return retval; +inline T get_percentile_stats(QueryStats *stats, uint64_t len, float percentile, + const std::function &member_fn) +{ + std::vector vals(len); + for (uint64_t i = 0; i < len; i++) + { + vals[i] = member_fn(stats[i]); + } + + std::sort(vals.begin(), vals.end(), [](const T &left, const T &right) { return left < right; }); + + auto retval = vals[(uint64_t)(percentile * len)]; + vals.clear(); + return retval; } template -inline double -get_mean_stats(QueryStats *stats, uint64_t len, - const std::function &member_fn) { - double avg = 0; - for (uint64_t i = 0; i < len; i++) { - avg += (double)member_fn(stats[i]); - } - return avg / len; +inline double get_mean_stats(QueryStats *stats, uint64_t len, const std::function &member_fn) +{ + double avg = 0; + for (uint64_t i = 0; i < len; i++) + { + avg += (double)member_fn(stats[i]); + } + return avg / len; } } // namespace diskann diff --git a/include/pq.h b/include/pq.h index 464271717..db9226d8b 100644 --- a/include/pq.h +++ b/include/pq.h @@ -6,104 +6,88 @@ #include "pq_common.h" #include "utils.h" -namespace diskann { -class FixedChunkPQTable { - float *tables = nullptr; // pq_tables = float array of size [256 * ndims] - uint64_t ndims = 0; // ndims = true dimension of vectors - uint64_t n_chunks = 0; - bool use_rotation = false; - uint32_t *chunk_offsets = nullptr; - float *centroid = nullptr; 
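// Illustrative sketch (editor's example, assuming a caller-filled `stats` array of
// length `num_queries`): computing aggregate query latencies with the QueryStats
// helpers declared in percentile_stats.h above.
//
//   double mean_us = diskann::get_mean_stats<float>(
//       stats, num_queries, [](const diskann::QueryStats &s) { return s.total_us; });
//   float p99_us = diskann::get_percentile_stats<float>(
//       stats, num_queries, 0.99f, [](const diskann::QueryStats &s) { return s.total_us; });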
- float *tables_tr = nullptr; // same as pq_tables, but col-major - float *rotmat_tr = nullptr; - -public: - FixedChunkPQTable(); - - virtual ~FixedChunkPQTable(); +namespace diskann +{ +class FixedChunkPQTable +{ + float *tables = nullptr; // pq_tables = float array of size [256 * ndims] + uint64_t ndims = 0; // ndims = true dimension of vectors + uint64_t n_chunks = 0; + bool use_rotation = false; + uint32_t *chunk_offsets = nullptr; + float *centroid = nullptr; + float *tables_tr = nullptr; // same as pq_tables, but col-major + float *rotmat_tr = nullptr; + + public: + FixedChunkPQTable(); + + virtual ~FixedChunkPQTable(); #ifdef EXEC_ENV_OLS - void load_pq_centroid_bin(MemoryMappedFiles &files, const char *pq_table_file, - size_t num_chunks); + void load_pq_centroid_bin(MemoryMappedFiles &files, const char *pq_table_file, size_t num_chunks); #else - void load_pq_centroid_bin(const char *pq_table_file, size_t num_chunks); + void load_pq_centroid_bin(const char *pq_table_file, size_t num_chunks); #endif - uint32_t get_num_chunks(); + uint32_t get_num_chunks(); - void preprocess_query(float *query_vec); + void preprocess_query(float *query_vec); - // assumes pre-processed query - void populate_chunk_distances(const float *query_vec, float *dist_vec); + // assumes pre-processed query + void populate_chunk_distances(const float *query_vec, float *dist_vec); - float l2_distance(const float *query_vec, uint8_t *base_vec); + float l2_distance(const float *query_vec, uint8_t *base_vec); - float inner_product(const float *query_vec, uint8_t *base_vec); + float inner_product(const float *query_vec, uint8_t *base_vec); - // assumes no rotation is involved - void inflate_vector(uint8_t *base_vec, float *out_vec); + // assumes no rotation is involved + void inflate_vector(uint8_t *base_vec, float *out_vec); - void populate_chunk_inner_products(const float *query_vec, float *dist_vec); + void populate_chunk_inner_products(const float *query_vec, float *dist_vec); }; -void aggregate_coords(const std::vector &ids, - const uint8_t *all_coords, const uint64_t ndims, - uint8_t *out); +void aggregate_coords(const std::vector &ids, const uint8_t *all_coords, const uint64_t ndims, uint8_t *out); -void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, - const size_t pq_nchunks, const float *pq_dists, +void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, const size_t pq_nchunks, const float *pq_dists, std::vector &dists_out); // Need to replace calls to these with calls to vector& based functions above -void aggregate_coords(const unsigned *ids, const uint64_t n_ids, - const uint8_t *all_coords, const uint64_t ndims, +void aggregate_coords(const unsigned *ids, const uint64_t n_ids, const uint8_t *all_coords, const uint64_t ndims, uint8_t *out); -void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, - const size_t pq_nchunks, const float *pq_dists, +void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, const size_t pq_nchunks, const float *pq_dists, float *dists_out); -DISKANN_DLLEXPORT int -generate_pq_pivots(const float *const train_data, size_t num_train, - unsigned dim, unsigned num_centers, unsigned num_pq_chunks, - unsigned max_k_means_reps, std::string pq_pivots_path, - bool make_zero_mean = false); +DISKANN_DLLEXPORT int generate_pq_pivots(const float *const train_data, size_t num_train, unsigned dim, + unsigned num_centers, unsigned num_pq_chunks, unsigned max_k_means_reps, + std::string pq_pivots_path, bool make_zero_mean = false); -DISKANN_DLLEXPORT int 
-generate_opq_pivots(const float *train_data, size_t num_train, unsigned dim, - unsigned num_centers, unsigned num_pq_chunks, - std::string opq_pivots_path, bool make_zero_mean = false); +DISKANN_DLLEXPORT int generate_opq_pivots(const float *train_data, size_t num_train, unsigned dim, unsigned num_centers, + unsigned num_pq_chunks, std::string opq_pivots_path, + bool make_zero_mean = false); -DISKANN_DLLEXPORT int -generate_pq_pivots_simplified(const float *train_data, size_t num_train, - size_t dim, size_t num_pq_chunks, - std::vector &pivot_data_vector); +DISKANN_DLLEXPORT int generate_pq_pivots_simplified(const float *train_data, size_t num_train, size_t dim, + size_t num_pq_chunks, std::vector &pivot_data_vector); template -int generate_pq_data_from_pivots(const std::string &data_file, - unsigned num_centers, unsigned num_pq_chunks, - const std::string &pq_pivots_path, - const std::string &pq_compressed_vectors_path, +int generate_pq_data_from_pivots(const std::string &data_file, unsigned num_centers, unsigned num_pq_chunks, + const std::string &pq_pivots_path, const std::string &pq_compressed_vectors_path, bool use_opq = false); -DISKANN_DLLEXPORT int generate_pq_data_from_pivots_simplified( - const float *data, const size_t num, const float *pivot_data, - const size_t pivots_num, const size_t dim, const size_t num_pq_chunks, - std::vector &pq); +DISKANN_DLLEXPORT int generate_pq_data_from_pivots_simplified(const float *data, const size_t num, + const float *pivot_data, const size_t pivots_num, + const size_t dim, const size_t num_pq_chunks, + std::vector &pq); template -void generate_disk_quantized_data( - const std::string &data_file_to_use, const std::string &disk_pq_pivots_path, - const std::string &disk_pq_compressed_vectors_path, - const diskann::Metric compareMetric, const double p_val, - size_t &disk_pq_dims); +void generate_disk_quantized_data(const std::string &data_file_to_use, const std::string &disk_pq_pivots_path, + const std::string &disk_pq_compressed_vectors_path, + const diskann::Metric compareMetric, const double p_val, size_t &disk_pq_dims); template -void generate_quantized_data(const std::string &data_file_to_use, - const std::string &pq_pivots_path, - const std::string &pq_compressed_vectors_path, - const diskann::Metric compareMetric, - const double p_val, const uint64_t num_pq_chunks, - const bool use_opq, +void generate_quantized_data(const std::string &data_file_to_use, const std::string &pq_pivots_path, + const std::string &pq_compressed_vectors_path, const diskann::Metric compareMetric, + const double p_val, const uint64_t num_pq_chunks, const bool use_opq, const std::string &codebook_prefix = ""); } // namespace diskann diff --git a/include/pq_common.h b/include/pq_common.h index 2d7dc28a9..d7a4b60f4 100644 --- a/include/pq_common.h +++ b/include/pq_common.h @@ -10,23 +10,21 @@ #define MAX_PQ_TRAINING_SET_SIZE 256000 #define MAX_PQ_CHUNKS 512 -namespace diskann { -inline std::string get_quantized_vectors_filename(const std::string &prefix, - bool use_opq, - uint32_t num_chunks) { - return prefix + (use_opq ? "_opq" : "pq") + std::to_string(num_chunks) + - "_compressed.bin"; +namespace diskann +{ +inline std::string get_quantized_vectors_filename(const std::string &prefix, bool use_opq, uint32_t num_chunks) +{ + return prefix + (use_opq ? "_opq" : "pq") + std::to_string(num_chunks) + "_compressed.bin"; } -inline std::string get_pivot_data_filename(const std::string &prefix, - bool use_opq, uint32_t num_chunks) { - return prefix + (use_opq ? 
"_opq" : "pq") + std::to_string(num_chunks) + - "_pivots.bin"; +inline std::string get_pivot_data_filename(const std::string &prefix, bool use_opq, uint32_t num_chunks) +{ + return prefix + (use_opq ? "_opq" : "pq") + std::to_string(num_chunks) + "_pivots.bin"; } -inline std::string -get_rotation_matrix_suffix(const std::string &pivot_data_filename) { - return pivot_data_filename + "_rotation_matrix.bin"; +inline std::string get_rotation_matrix_suffix(const std::string &pivot_data_filename) +{ + return pivot_data_filename + "_rotation_matrix.bin"; } } // namespace diskann diff --git a/include/pq_data_store.h b/include/pq_data_store.h index 98f0bc3a0..4e223e785 100644 --- a/include/pq_data_store.h +++ b/include/pq_data_store.h @@ -5,109 +5,94 @@ #include "quantized_distance.h" #include -namespace diskann { +namespace diskann +{ // REFACTOR TODO: By default, the PQDataStore is an in-memory datastore because // both Vamana and DiskANN treat it the same way. But with DiskPQ, that may need // to change. -template -class PQDataStore : public AbstractDataStore { - -public: - PQDataStore(size_t dim, location_t num_points, size_t num_pq_chunks, - std::unique_ptr> distance_fn, - std::unique_ptr> pq_distance_fn); - PQDataStore(const PQDataStore &) = delete; - PQDataStore &operator=(const PQDataStore &) = delete; - ~PQDataStore(); - - // Load quantized vectors from a set of files. Here filename is treated - // as a prefix and the files are assumed to be named with DiskANN - // conventions. - virtual location_t load(const std::string &file_prefix) override; - - // Save quantized vectors to a set of files whose names start with - // file_prefix. - // Currently, the plan is to save the quantized vectors to the quantized - // vectors file. - virtual size_t save(const std::string &file_prefix, - const location_t num_points) override; - - // Since base class function is pure virtual, we need to declare it here, even - // though alignent concept is not needed for Quantized data stores. - virtual size_t get_aligned_dim() const override; - - // Populate quantized data from unaligned data using PQ functionality - virtual void populate_data(const data_t *vectors, - const location_t num_pts) override; - virtual void populate_data(const std::string &filename, - const size_t offset) override; - - virtual void extract_data_to_bin(const std::string &filename, - const location_t num_pts) override; - - virtual void get_vector(const location_t i, data_t *target) const override; - virtual void set_vector(const location_t i, - const data_t *const vector) override; - virtual void prefetch_vector(const location_t loc) override; - - virtual void move_vectors(const location_t old_location_start, - const location_t new_location_start, - const location_t num_points) override; - virtual void copy_vectors(const location_t from_loc, const location_t to_loc, - const location_t num_points) override; - - virtual void - preprocess_query(const data_t *query, - AbstractScratch *scratch) const override; - - virtual float get_distance(const data_t *query, - const location_t loc) const override; - virtual float get_distance(const location_t loc1, - const location_t loc2) const override; - - // NOTE: Caller must invoke "PQDistance->preprocess_query" ONCE before calling - // this function. 
- virtual void - get_distance(const data_t *preprocessed_query, const location_t *locations, - const uint32_t location_count, float *distances, - AbstractScratch *scratch_space) const override; - - // NOTE: Caller must invoke "PQDistance->preprocess_query" ONCE before calling - // this function. - virtual void - get_distance(const data_t *preprocessed_query, - const std::vector &ids, - std::vector &distances, - AbstractScratch *scratch_space) const override; - - // We are returning the distance function that is used for full precision - // vectors here, not the PQ distance function. This is because the callers - // all are expecting a Distance not QuantizedDistance. - virtual Distance *get_dist_fn() const override; - - virtual location_t calculate_medoid() const override; - - virtual size_t get_alignment_factor() const override; - -protected: - virtual location_t expand(const location_t new_size) override; - virtual location_t shrink(const location_t new_size) override; - - virtual location_t load_impl(const std::string &filename); +template class PQDataStore : public AbstractDataStore +{ + + public: + PQDataStore(size_t dim, location_t num_points, size_t num_pq_chunks, std::unique_ptr> distance_fn, + std::unique_ptr> pq_distance_fn); + PQDataStore(const PQDataStore &) = delete; + PQDataStore &operator=(const PQDataStore &) = delete; + ~PQDataStore(); + + // Load quantized vectors from a set of files. Here filename is treated + // as a prefix and the files are assumed to be named with DiskANN + // conventions. + virtual location_t load(const std::string &file_prefix) override; + + // Save quantized vectors to a set of files whose names start with + // file_prefix. + // Currently, the plan is to save the quantized vectors to the quantized + // vectors file. + virtual size_t save(const std::string &file_prefix, const location_t num_points) override; + + // Since base class function is pure virtual, we need to declare it here, even + // though alignent concept is not needed for Quantized data stores. + virtual size_t get_aligned_dim() const override; + + // Populate quantized data from unaligned data using PQ functionality + virtual void populate_data(const data_t *vectors, const location_t num_pts) override; + virtual void populate_data(const std::string &filename, const size_t offset) override; + + virtual void extract_data_to_bin(const std::string &filename, const location_t num_pts) override; + + virtual void get_vector(const location_t i, data_t *target) const override; + virtual void set_vector(const location_t i, const data_t *const vector) override; + virtual void prefetch_vector(const location_t loc) override; + + virtual void move_vectors(const location_t old_location_start, const location_t new_location_start, + const location_t num_points) override; + virtual void copy_vectors(const location_t from_loc, const location_t to_loc, const location_t num_points) override; + + virtual void preprocess_query(const data_t *query, AbstractScratch *scratch) const override; + + virtual float get_distance(const data_t *query, const location_t loc) const override; + virtual float get_distance(const location_t loc1, const location_t loc2) const override; + + // NOTE: Caller must invoke "PQDistance->preprocess_query" ONCE before calling + // this function. 
+ virtual void get_distance(const data_t *preprocessed_query, const location_t *locations, + const uint32_t location_count, float *distances, + AbstractScratch *scratch_space) const override; + + // NOTE: Caller must invoke "PQDistance->preprocess_query" ONCE before calling + // this function. + virtual void get_distance(const data_t *preprocessed_query, const std::vector &ids, + std::vector &distances, AbstractScratch *scratch_space) const override; + + // We are returning the distance function that is used for full precision + // vectors here, not the PQ distance function. This is because the callers + // all are expecting a Distance not QuantizedDistance. + virtual Distance *get_dist_fn() const override; + + virtual location_t calculate_medoid() const override; + + virtual size_t get_alignment_factor() const override; + + protected: + virtual location_t expand(const location_t new_size) override; + virtual location_t shrink(const location_t new_size) override; + + virtual location_t load_impl(const std::string &filename); #ifdef EXEC_ENV_OLS - virtual location_t load_impl(AlignedFileReader &reader); + virtual location_t load_impl(AlignedFileReader &reader); #endif -private: - uint8_t *_quantized_data = nullptr; - size_t _num_chunks = 0; + private: + uint8_t *_quantized_data = nullptr; + size_t _num_chunks = 0; - // REFACTOR TODO: Doing this temporarily before refactoring OPQ into - // its own class. Remove later. - bool _use_opq = false; + // REFACTOR TODO: Doing this temporarily before refactoring OPQ into + // its own class. Remove later. + bool _use_opq = false; - Metric _distance_metric; - std::unique_ptr> _distance_fn = nullptr; - std::unique_ptr> _pq_distance_fn = nullptr; + Metric _distance_metric; + std::unique_ptr> _distance_fn = nullptr; + std::unique_ptr> _pq_distance_fn = nullptr; }; } // namespace diskann diff --git a/include/pq_flash_index.h b/include/pq_flash_index.h index 7cd82a685..9e43debfc 100644 --- a/include/pq_flash_index.h +++ b/include/pq_flash_index.h @@ -20,225 +20,211 @@ #define FULL_PRECISION_REORDER_MULTIPLIER 3 -namespace diskann { +namespace diskann +{ -template class PQFlashIndex { -public: - DISKANN_DLLEXPORT PQFlashIndex(std::shared_ptr &fileReader, - diskann::Metric metric = diskann::Metric::L2); - DISKANN_DLLEXPORT ~PQFlashIndex(); +template class PQFlashIndex +{ + public: + DISKANN_DLLEXPORT PQFlashIndex(std::shared_ptr &fileReader, + diskann::Metric metric = diskann::Metric::L2); + DISKANN_DLLEXPORT ~PQFlashIndex(); #ifdef EXEC_ENV_OLS - DISKANN_DLLEXPORT int load(diskann::MemoryMappedFiles &files, - uint32_t num_threads, const char *index_prefix); + DISKANN_DLLEXPORT int load(diskann::MemoryMappedFiles &files, uint32_t num_threads, const char *index_prefix); #else - // load compressed data, and obtains the handle to the disk-resident index - DISKANN_DLLEXPORT int load(uint32_t num_threads, const char *index_prefix); + // load compressed data, and obtains the handle to the disk-resident index + DISKANN_DLLEXPORT int load(uint32_t num_threads, const char *index_prefix); #endif #ifdef EXEC_ENV_OLS - DISKANN_DLLEXPORT int - load_from_separate_paths(diskann::MemoryMappedFiles &files, - uint32_t num_threads, const char *index_filepath, - const char *pivots_filepath, - const char *compressed_filepath); + DISKANN_DLLEXPORT int load_from_separate_paths(diskann::MemoryMappedFiles &files, uint32_t num_threads, + const char *index_filepath, const char *pivots_filepath, + const char *compressed_filepath); #else - DISKANN_DLLEXPORT int - 
load_from_separate_paths(uint32_t num_threads, const char *index_filepath, - const char *pivots_filepath, - const char *compressed_filepath); + DISKANN_DLLEXPORT int load_from_separate_paths(uint32_t num_threads, const char *index_filepath, + const char *pivots_filepath, const char *compressed_filepath); #endif - DISKANN_DLLEXPORT void load_cache_list(std::vector &node_list); + DISKANN_DLLEXPORT void load_cache_list(std::vector &node_list); #ifdef EXEC_ENV_OLS - DISKANN_DLLEXPORT void generate_cache_list_from_sample_queries( - MemoryMappedFiles &files, std::string sample_bin, uint64_t l_search, - uint64_t beamwidth, uint64_t num_nodes_to_cache, uint32_t nthreads, - std::vector &node_list); + DISKANN_DLLEXPORT void generate_cache_list_from_sample_queries(MemoryMappedFiles &files, std::string sample_bin, + uint64_t l_search, uint64_t beamwidth, + uint64_t num_nodes_to_cache, uint32_t nthreads, + std::vector &node_list); #else - DISKANN_DLLEXPORT void generate_cache_list_from_sample_queries( - std::string sample_bin, uint64_t l_search, uint64_t beamwidth, - uint64_t num_nodes_to_cache, uint32_t num_threads, - std::vector &node_list); + DISKANN_DLLEXPORT void generate_cache_list_from_sample_queries(std::string sample_bin, uint64_t l_search, + uint64_t beamwidth, uint64_t num_nodes_to_cache, + uint32_t num_threads, + std::vector &node_list); #endif - DISKANN_DLLEXPORT void cache_bfs_levels(uint64_t num_nodes_to_cache, - std::vector &node_list, - const bool shuffle = false); - - DISKANN_DLLEXPORT void cached_beam_search( - const T *query, const uint64_t k_search, const uint64_t l_search, - uint64_t *res_ids, float *res_dists, const uint64_t beam_width, - const bool use_reorder_data = false, QueryStats *stats = nullptr); - - DISKANN_DLLEXPORT void cached_beam_search( - const T *query, const uint64_t k_search, const uint64_t l_search, - uint64_t *res_ids, float *res_dists, const uint64_t beam_width, - const bool use_filter, const LabelT &filter_label, - const bool use_reorder_data = false, QueryStats *stats = nullptr); - - DISKANN_DLLEXPORT void cached_beam_search( - const T *query, const uint64_t k_search, const uint64_t l_search, - uint64_t *res_ids, float *res_dists, const uint64_t beam_width, - const uint32_t io_limit, const bool use_reorder_data = false, - QueryStats *stats = nullptr); - - DISKANN_DLLEXPORT void cached_beam_search( - const T *query, const uint64_t k_search, const uint64_t l_search, - uint64_t *res_ids, float *res_dists, const uint64_t beam_width, - const bool use_filter, const LabelT &filter_label, - const uint32_t io_limit, const bool use_reorder_data = false, - QueryStats *stats = nullptr); - - DISKANN_DLLEXPORT uint32_t range_search(const T *query1, const double range, - const uint64_t min_l_search, - const uint64_t max_l_search, - std::vector &indices, - std::vector &distances, - const uint64_t min_beam_width, - QueryStats *stats = nullptr); - - DISKANN_DLLEXPORT uint64_t get_data_dim(); - - DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &filter_label); - - std::shared_ptr &reader; - - DISKANN_DLLEXPORT diskann::Metric get_metric(); - - // - // node_ids: input list of node_ids to be read - // coord_buffers: pointers to pre-allocated buffers that coords need to copied - // to. If null, dont copy. 
nbr_buffers: pre-allocated buffers to copy - // neighbors into - // - // returns a vector of bool one for each node_id: true if read is success, - // else false - // - DISKANN_DLLEXPORT std::vector - read_nodes(const std::vector &node_ids, - std::vector &coord_buffers, - std::vector> &nbr_buffers); - - DISKANN_DLLEXPORT std::vector get_pq_vector(std::uint64_t vid); - DISKANN_DLLEXPORT uint64_t get_num_points(); - -protected: - DISKANN_DLLEXPORT void use_medoids_data_as_centroids(); - DISKANN_DLLEXPORT void setup_thread_data(uint64_t nthreads, - uint64_t visited_reserve = 4096); - -private: - // sector # on disk where node_id is present with in the graph part - DISKANN_DLLEXPORT uint64_t get_node_sector(uint64_t node_id); - - // ptr to start of the node - DISKANN_DLLEXPORT char *offset_to_node(char *sector_buf, uint64_t node_id); - - // returns region of `node_buf` containing [NNBRS][NBR_ID(uint32_t)] - DISKANN_DLLEXPORT uint32_t *offset_to_node_nhood(char *node_buf); - - // returns region of `node_buf` containing [COORD(T)] - DISKANN_DLLEXPORT T *offset_to_node_coords(char *node_buf); - - // index info for multi-node sectors - // nhood of node `i` is in sector: [i / nnodes_per_sector] - // offset in sector: [(i % nnodes_per_sector) * max_node_len] - // - // index info for multi-sector nodes - // nhood of node `i` is in sector: [i * DIV_ROUND_UP(_max_node_len, - // SECTOR_LEN)] offset in sector: [0] - // - // Common info - // coords start at ofsset - // #nbrs of node `i`: *(unsigned*) (offset + disk_bytes_per_point) - // nbrs of node `i` : (unsigned*) (offset + disk_bytes_per_point + 1) - - uint64_t _max_node_len = 0; - uint64_t _nnodes_per_sector = - 0; // 0 for multi-sector nodes, >0 for multi-node sectors - uint64_t _max_degree = 0; - - // Data used for searching with re-order vectors - uint64_t _ndims_reorder_vecs = 0; - uint64_t _reorder_data_start_sector = 0; - uint64_t _nvecs_per_sector = 0; - - diskann::Metric metric = diskann::Metric::L2; - - // used only for inner product search to re-scale the result value - // (due to the pre-processing of base during index build) - float _max_base_norm = 0.0f; - - // data info - uint64_t _num_points = 0; - uint64_t _num_frozen_points = 0; - uint64_t _frozen_location = 0; - uint64_t _data_dim = 0; - uint64_t _aligned_dim = 0; - uint64_t _disk_bytes_per_point = 0; // Number of bytes - - std::string _disk_index_file; - std::vector> _node_visit_counter; - - // PQ data - // _n_chunks = # of chunks ndims is split into - // data: char * _n_chunks - // chunk_size = chunk size of each dimension chunk - // pq_tables = float* [[2^8 * [chunk_size]] * _n_chunks] - uint8_t *data = nullptr; - uint64_t _n_chunks; - FixedChunkPQTable _pq_table; - - // distance comparator - std::shared_ptr> _dist_cmp; - std::shared_ptr> _dist_cmp_float; - - // for very large datasets: we use PQ even for the disk resident index - bool _use_disk_index_pq = false; - uint64_t _disk_pq_n_chunks = 0; - FixedChunkPQTable _disk_pq_table; - - // medoid/start info - - // graph has one entry point by default, - // we can optionally have multiple starting points - uint32_t *_medoids = nullptr; - // defaults to 1 - size_t _num_medoids; - // by default, it is empty. 
If there are multiple - // centroids, we pick the medoid corresponding to the - // closest centroid as the starting point of search - float *_centroid_data = nullptr; - - // nhood_cache; the uint32_t in nhood_Cache are offsets into nhood_cache_buf - unsigned *_nhood_cache_buf = nullptr; - tsl::robin_map> _nhood_cache; - - // coord_cache; The T* in coord_cache are offsets into coord_cache_buf - T *_coord_cache_buf = nullptr; - tsl::robin_map _coord_cache; - - // thread-specific scratch - ConcurrentQueue *> _thread_data; - uint64_t _max_nthreads; - bool _load_flag = false; - bool _count_visited_nodes = false; - bool _reorder_data_exists = false; - uint64_t _reoreder_data_offset = 0; - - // Moved filter-specific data structures to in_mem_filter_store. - // TODO: Make this a unique pointer - bool _filter_index = false; - std::unique_ptr> _filter_store; + DISKANN_DLLEXPORT void cache_bfs_levels(uint64_t num_nodes_to_cache, std::vector &node_list, + const bool shuffle = false); + + DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, + uint64_t *res_ids, float *res_dists, const uint64_t beam_width, + const bool use_reorder_data = false, QueryStats *stats = nullptr); + + DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, + uint64_t *res_ids, float *res_dists, const uint64_t beam_width, + const bool use_filter, const LabelT &filter_label, + const bool use_reorder_data = false, QueryStats *stats = nullptr); + + DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, + uint64_t *res_ids, float *res_dists, const uint64_t beam_width, + const uint32_t io_limit, const bool use_reorder_data = false, + QueryStats *stats = nullptr); + + DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, + uint64_t *res_ids, float *res_dists, const uint64_t beam_width, + const bool use_filter, const LabelT &filter_label, + const uint32_t io_limit, const bool use_reorder_data = false, + QueryStats *stats = nullptr); + + DISKANN_DLLEXPORT uint32_t range_search(const T *query1, const double range, const uint64_t min_l_search, + const uint64_t max_l_search, std::vector &indices, + std::vector &distances, const uint64_t min_beam_width, + QueryStats *stats = nullptr); + + DISKANN_DLLEXPORT uint64_t get_data_dim(); + + DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &filter_label); + + std::shared_ptr &reader; + + DISKANN_DLLEXPORT diskann::Metric get_metric(); + + // + // node_ids: input list of node_ids to be read + // coord_buffers: pointers to pre-allocated buffers that coords need to copied + // to. If null, dont copy. 
nbr_buffers: pre-allocated buffers to copy + // neighbors into + // + // returns a vector of bool one for each node_id: true if read is success, + // else false + // + DISKANN_DLLEXPORT std::vector read_nodes(const std::vector &node_ids, + std::vector &coord_buffers, + std::vector> &nbr_buffers); + + DISKANN_DLLEXPORT std::vector get_pq_vector(std::uint64_t vid); + DISKANN_DLLEXPORT uint64_t get_num_points(); + + protected: + DISKANN_DLLEXPORT void use_medoids_data_as_centroids(); + DISKANN_DLLEXPORT void setup_thread_data(uint64_t nthreads, uint64_t visited_reserve = 4096); + + private: + // sector # on disk where node_id is present with in the graph part + DISKANN_DLLEXPORT uint64_t get_node_sector(uint64_t node_id); + + // ptr to start of the node + DISKANN_DLLEXPORT char *offset_to_node(char *sector_buf, uint64_t node_id); + + // returns region of `node_buf` containing [NNBRS][NBR_ID(uint32_t)] + DISKANN_DLLEXPORT uint32_t *offset_to_node_nhood(char *node_buf); + + // returns region of `node_buf` containing [COORD(T)] + DISKANN_DLLEXPORT T *offset_to_node_coords(char *node_buf); + + // index info for multi-node sectors + // nhood of node `i` is in sector: [i / nnodes_per_sector] + // offset in sector: [(i % nnodes_per_sector) * max_node_len] + // + // index info for multi-sector nodes + // nhood of node `i` is in sector: [i * DIV_ROUND_UP(_max_node_len, + // SECTOR_LEN)] offset in sector: [0] + // + // Common info + // coords start at ofsset + // #nbrs of node `i`: *(unsigned*) (offset + disk_bytes_per_point) + // nbrs of node `i` : (unsigned*) (offset + disk_bytes_per_point + 1) + + uint64_t _max_node_len = 0; + uint64_t _nnodes_per_sector = 0; // 0 for multi-sector nodes, >0 for multi-node sectors + uint64_t _max_degree = 0; + + // Data used for searching with re-order vectors + uint64_t _ndims_reorder_vecs = 0; + uint64_t _reorder_data_start_sector = 0; + uint64_t _nvecs_per_sector = 0; + + diskann::Metric metric = diskann::Metric::L2; + + // used only for inner product search to re-scale the result value + // (due to the pre-processing of base during index build) + float _max_base_norm = 0.0f; + + // data info + uint64_t _num_points = 0; + uint64_t _num_frozen_points = 0; + uint64_t _frozen_location = 0; + uint64_t _data_dim = 0; + uint64_t _aligned_dim = 0; + uint64_t _disk_bytes_per_point = 0; // Number of bytes + + std::string _disk_index_file; + std::vector> _node_visit_counter; + + // PQ data + // _n_chunks = # of chunks ndims is split into + // data: char * _n_chunks + // chunk_size = chunk size of each dimension chunk + // pq_tables = float* [[2^8 * [chunk_size]] * _n_chunks] + uint8_t *data = nullptr; + uint64_t _n_chunks; + FixedChunkPQTable _pq_table; + + // distance comparator + std::shared_ptr> _dist_cmp; + std::shared_ptr> _dist_cmp_float; + + // for very large datasets: we use PQ even for the disk resident index + bool _use_disk_index_pq = false; + uint64_t _disk_pq_n_chunks = 0; + FixedChunkPQTable _disk_pq_table; + + // medoid/start info + + // graph has one entry point by default, + // we can optionally have multiple starting points + uint32_t *_medoids = nullptr; + // defaults to 1 + size_t _num_medoids; + // by default, it is empty. 
If there are multiple + // centroids, we pick the medoid corresponding to the + // closest centroid as the starting point of search + float *_centroid_data = nullptr; + + // nhood_cache; the uint32_t in nhood_Cache are offsets into nhood_cache_buf + unsigned *_nhood_cache_buf = nullptr; + tsl::robin_map> _nhood_cache; + + // coord_cache; The T* in coord_cache are offsets into coord_cache_buf + T *_coord_cache_buf = nullptr; + tsl::robin_map _coord_cache; + + // thread-specific scratch + ConcurrentQueue *> _thread_data; + uint64_t _max_nthreads; + bool _load_flag = false; + bool _count_visited_nodes = false; + bool _reorder_data_exists = false; + uint64_t _reoreder_data_offset = 0; + + // Moved filter-specific data structures to in_mem_filter_store. + // TODO: Make this a unique pointer + bool _filter_index = false; + std::unique_ptr> _filter_store; #ifdef EXEC_ENV_OLS - // Set to a larger value than the actual header to accommodate - // any additions we make to the header. This is an outer limit - // on how big the header can be. - static const int HEADER_SIZE = defaults::SECTOR_LEN; - char *getHeaderBytes(); + // Set to a larger value than the actual header to accommodate + // any additions we make to the header. This is an outer limit + // on how big the header can be. + static const int HEADER_SIZE = defaults::SECTOR_LEN; + char *getHeaderBytes(); #endif }; } // namespace diskann diff --git a/include/pq_l2_distance.h b/include/pq_l2_distance.h index 302c57d3c..e6fc6e41b 100644 --- a/include/pq_l2_distance.h +++ b/include/pq_l2_distance.h @@ -1,96 +1,87 @@ #pragma once #include "quantized_distance.h" -namespace diskann { -template -class PQL2Distance : public QuantizedDistance { -public: - // REFACTOR TODO: We could take a file prefix here and load the - // PQ pivots file, so that the distance object is initialized - // immediately after construction. But this would not work well - // with our data store concept where the store is created first - // and data populated after. - // REFACTOR TODO: Ideally, we should only read the num_chunks from - // the pivots file. However, we read the pivots file only later, but - // clients can call functions like get__filename without calling - // load_pivot_data. Hence this. The TODO is whether we should check - // that the num_chunks from the file is the same as this one. +namespace diskann +{ +template class PQL2Distance : public QuantizedDistance +{ + public: + // REFACTOR TODO: We could take a file prefix here and load the + // PQ pivots file, so that the distance object is initialized + // immediately after construction. But this would not work well + // with our data store concept where the store is created first + // and data populated after. + // REFACTOR TODO: Ideally, we should only read the num_chunks from + // the pivots file. However, we read the pivots file only later, but + // clients can call functions like get__filename without calling + // load_pivot_data. Hence this. The TODO is whether we should check + // that the num_chunks from the file is the same as this one. 
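// Illustrative sketch (editor's example, not part of this patch): a plausible
// search-side call order for this class, using only the members declared below.
// The pivots file name and the pre-filled PQScratch are assumptions of the example;
// per the note on preprocessed_distance(), the scratch must already hold the
// candidates' PQ codes.
//
//   diskann::PQL2Distance<float> pq_dist(num_chunks);
//   pq_dist.load_pivot_data("index_pq_pivots.bin", num_chunks);              // placeholder path
//   pq_dist.preprocess_query(aligned_query, original_dim, pq_scratch);       // once per query
//   pq_dist.preprocessed_distance(pq_scratch, (uint32_t)candidate_count,
//                                 dists_out);                                // per batch of candidates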
- PQL2Distance(uint32_t num_chunks, bool use_opq = false); + PQL2Distance(uint32_t num_chunks, bool use_opq = false); - virtual ~PQL2Distance() override; + virtual ~PQL2Distance() override; - virtual bool is_opq() const override; + virtual bool is_opq() const override; - virtual std::string - get_quantized_vectors_filename(const std::string &prefix) const override; - virtual std::string - get_pivot_data_filename(const std::string &prefix) const override; - virtual std::string get_rotation_matrix_suffix( - const std::string &pq_pivots_filename) const override; + virtual std::string get_quantized_vectors_filename(const std::string &prefix) const override; + virtual std::string get_pivot_data_filename(const std::string &prefix) const override; + virtual std::string get_rotation_matrix_suffix(const std::string &pq_pivots_filename) const override; #ifdef EXEC_ENV_OLS - virtual void load_pivot_data(MemoryMappedFiles &files, - const std::string &pq_table_file, - size_t num_chunks) override; + virtual void load_pivot_data(MemoryMappedFiles &files, const std::string &pq_table_file, + size_t num_chunks) override; #else - virtual void load_pivot_data(const std::string &pq_table_file, - size_t num_chunks) override; + virtual void load_pivot_data(const std::string &pq_table_file, size_t num_chunks) override; #endif - // Number of chunks in the PQ table. Depends on the compression level used. - // Has to be < ndim - virtual uint32_t get_num_chunks() const override; + // Number of chunks in the PQ table. Depends on the compression level used. + // Has to be < ndim + virtual uint32_t get_num_chunks() const override; - // Preprocess the query by computing chunk distances from the query vector to - // various centroids. Since we don't want this class to do scratch management, - // we will take a PQScratch object which can come either from Index class or - // PQFlashIndex class. - virtual void preprocess_query(const data_t *aligned_query, - uint32_t original_dim, - PQScratch &pq_scratch) override; + // Preprocess the query by computing chunk distances from the query vector to + // various centroids. Since we don't want this class to do scratch management, + // we will take a PQScratch object which can come either from Index class or + // PQFlashIndex class. + virtual void preprocess_query(const data_t *aligned_query, uint32_t original_dim, + PQScratch &pq_scratch) override; - // Distance function used for graph traversal. This function must be called - // after - // preprocess_query. The reason we do not call preprocess ourselves is because - // that function has to be called once per query, while this function is - // called at each iteration of the graph walk. NOTE: This function expects - // 1. the query to be preprocessed using preprocess_query() - // 2. the scratch object to contain the quantized vectors corresponding to ids - // in aligned_pq_coord_scratch. Done by calling aggregate_coords() - // - virtual void preprocessed_distance(PQScratch &pq_scratch, - const uint32_t id_count, - float *dists_out) override; + // Distance function used for graph traversal. This function must be called + // after + // preprocess_query. The reason we do not call preprocess ourselves is because + // that function has to be called once per query, while this function is + // called at each iteration of the graph walk. NOTE: This function expects + // 1. the query to be preprocessed using preprocess_query() + // 2. the scratch object to contain the quantized vectors corresponding to ids + // in aligned_pq_coord_scratch. 
Done by calling aggregate_coords() + // + virtual void preprocessed_distance(PQScratch &pq_scratch, const uint32_t id_count, + float *dists_out) override; - // Same as above, but returns the distances in a vector instead of an array. - // Convenience function for index.cpp. - virtual void preprocessed_distance(PQScratch &pq_scratch, - const uint32_t n_ids, - std::vector &dists_out) override; + // Same as above, but returns the distances in a vector instead of an array. + // Convenience function for index.cpp. + virtual void preprocessed_distance(PQScratch &pq_scratch, const uint32_t n_ids, + std::vector &dists_out) override; - // Currently this function is required for DiskPQ. However, it too can be - // subsumed under preprocessed_distance if we add the appropriate scratch - // variables to PQScratch and initialize them in - // pq_flash_index.cpp::disk_iterate_to_fixed_point() - virtual float brute_force_distance(const float *query_vec, - uint8_t *base_vec) override; + // Currently this function is required for DiskPQ. However, it too can be + // subsumed under preprocessed_distance if we add the appropriate scratch + // variables to PQScratch and initialize them in + // pq_flash_index.cpp::disk_iterate_to_fixed_point() + virtual float brute_force_distance(const float *query_vec, uint8_t *base_vec) override; -protected: - // assumes pre-processed query - virtual void prepopulate_chunkwise_distances(const float *query_vec, - float *dist_vec); + protected: + // assumes pre-processed query + virtual void prepopulate_chunkwise_distances(const float *query_vec, float *dist_vec); - // assumes no rotation is involved - // virtual void inflate_vector(uint8_t *base_vec, float *out_vec); + // assumes no rotation is involved + // virtual void inflate_vector(uint8_t *base_vec, float *out_vec); - float *_tables = nullptr; // pq_tables = float array of size [256 * ndims] - uint64_t _ndims = 0; // ndims = true dimension of vectors - uint64_t _num_chunks = 0; - bool _is_opq = false; - uint32_t *_chunk_offsets = nullptr; - float *_centroid = nullptr; - float *_tables_tr = nullptr; // same as pq_tables, but col-major - float *_rotmat_tr = nullptr; + float *_tables = nullptr; // pq_tables = float array of size [256 * ndims] + uint64_t _ndims = 0; // ndims = true dimension of vectors + uint64_t _num_chunks = 0; + bool _is_opq = false; + uint32_t *_chunk_offsets = nullptr; + float *_centroid = nullptr; + float *_tables_tr = nullptr; // same as pq_tables, but col-major + float *_rotmat_tr = nullptr; }; } // namespace diskann diff --git a/include/pq_scratch.h b/include/pq_scratch.h index bdbf7de99..6b52463eb 100644 --- a/include/pq_scratch.h +++ b/include/pq_scratch.h @@ -3,21 +3,21 @@ #include "utils.h" #include -namespace diskann { +namespace diskann +{ -template class PQScratch { -public: - float *aligned_pqtable_dist_scratch = - nullptr; // MUST BE AT LEAST [256 * NCHUNKS] - float *aligned_dist_scratch = nullptr; // MUST BE AT LEAST diskann MAX_DEGREE - uint8_t *aligned_pq_coord_scratch = - nullptr; // AT LEAST [N_CHUNKS * MAX_DEGREE] - float *rotated_query = nullptr; - float *aligned_query_float = nullptr; +template class PQScratch +{ + public: + float *aligned_pqtable_dist_scratch = nullptr; // MUST BE AT LEAST [256 * NCHUNKS] + float *aligned_dist_scratch = nullptr; // MUST BE AT LEAST diskann MAX_DEGREE + uint8_t *aligned_pq_coord_scratch = nullptr; // AT LEAST [N_CHUNKS * MAX_DEGREE] + float *rotated_query = nullptr; + float *aligned_query_float = nullptr; - PQScratch(size_t graph_degree, size_t 
aligned_dim); - void initialize(size_t dim, const T *query, const float norm = 1.0f); - virtual ~PQScratch(); + PQScratch(size_t graph_degree, size_t aligned_dim); + void initialize(size_t dim, const T *query, const float norm = 1.0f); + virtual ~PQScratch(); }; } // namespace diskann \ No newline at end of file diff --git a/include/quantized_distance.h b/include/quantized_distance.h index cc70c2989..44798ac96 100644 --- a/include/quantized_distance.h +++ b/include/quantized_distance.h @@ -4,63 +4,54 @@ #include #include -namespace diskann { +namespace diskann +{ template class PQScratch; -template class QuantizedDistance { -public: - QuantizedDistance() = default; - QuantizedDistance(const QuantizedDistance &) = delete; - QuantizedDistance &operator=(const QuantizedDistance &) = delete; - virtual ~QuantizedDistance() = default; - - virtual bool is_opq() const = 0; - virtual std::string - get_quantized_vectors_filename(const std::string &prefix) const = 0; - virtual std::string - get_pivot_data_filename(const std::string &prefix) const = 0; - virtual std::string - get_rotation_matrix_suffix(const std::string &pq_pivots_filename) const = 0; - - // Loading the PQ centroid table need not be part of the abstract class. - // However, we want to indicate that this function will change once we have a - // file reader hierarchy, so leave it here as-is. +template class QuantizedDistance +{ + public: + QuantizedDistance() = default; + QuantizedDistance(const QuantizedDistance &) = delete; + QuantizedDistance &operator=(const QuantizedDistance &) = delete; + virtual ~QuantizedDistance() = default; + + virtual bool is_opq() const = 0; + virtual std::string get_quantized_vectors_filename(const std::string &prefix) const = 0; + virtual std::string get_pivot_data_filename(const std::string &prefix) const = 0; + virtual std::string get_rotation_matrix_suffix(const std::string &pq_pivots_filename) const = 0; + + // Loading the PQ centroid table need not be part of the abstract class. + // However, we want to indicate that this function will change once we have a + // file reader hierarchy, so leave it here as-is. #ifdef EXEC_ENV_OLS - virtual void load_pivot_data(MemoryMappedFiles &files, - const std::String &pq_table_file, - size_t num_chunks) = 0; + virtual void load_pivot_data(MemoryMappedFiles &files, const std::String &pq_table_file, size_t num_chunks) = 0; #else - virtual void load_pivot_data(const std::string &pq_table_file, - size_t num_chunks) = 0; + virtual void load_pivot_data(const std::string &pq_table_file, size_t num_chunks) = 0; #endif - // Number of chunks in the PQ table. Depends on the compression level used. - // Has to be < ndim - virtual uint32_t get_num_chunks() const = 0; - - // Preprocess the query by computing chunk distances from the query vector to - // various centroids. Since we don't want this class to do scratch management, - // we will take a PQScratch object which can come either from Index class or - // PQFlashIndex class. - virtual void preprocess_query(const data_t *query_vec, uint32_t query_dim, - PQScratch &pq_scratch) = 0; - - // Workhorse - // This function must be called after preprocess_query - virtual void preprocessed_distance(PQScratch &pq_scratch, - const uint32_t id_count, - float *dists_out) = 0; - - // Same as above, but convenience function for index.cpp. - virtual void preprocessed_distance(PQScratch &pq_scratch, - const uint32_t n_ids, - std::vector &dists_out) = 0; - - // Currently this function is required for DiskPQ. 
However, it too can be - // subsumed under preprocessed_distance if we add the appropriate scratch - // variables to PQScratch and initialize them in - // pq_flash_index.cpp::disk_iterate_to_fixed_point() - virtual float brute_force_distance(const float *query_vec, - uint8_t *base_vec) = 0; + // Number of chunks in the PQ table. Depends on the compression level used. + // Has to be < ndim + virtual uint32_t get_num_chunks() const = 0; + + // Preprocess the query by computing chunk distances from the query vector to + // various centroids. Since we don't want this class to do scratch management, + // we will take a PQScratch object which can come either from Index class or + // PQFlashIndex class. + virtual void preprocess_query(const data_t *query_vec, uint32_t query_dim, PQScratch &pq_scratch) = 0; + + // Workhorse + // This function must be called after preprocess_query + virtual void preprocessed_distance(PQScratch &pq_scratch, const uint32_t id_count, float *dists_out) = 0; + + // Same as above, but convenience function for index.cpp. + virtual void preprocessed_distance(PQScratch &pq_scratch, const uint32_t n_ids, + std::vector &dists_out) = 0; + + // Currently this function is required for DiskPQ. However, it too can be + // subsumed under preprocessed_distance if we add the appropriate scratch + // variables to PQScratch and initialize them in + // pq_flash_index.cpp::disk_iterate_to_fixed_point() + virtual float brute_force_distance(const float *query_vec, uint8_t *base_vec) = 0; }; } // namespace diskann diff --git a/include/scratch.h b/include/scratch.h index 9775a7386..79bb027de 100644 --- a/include/scratch.h +++ b/include/scratch.h @@ -17,154 +17,200 @@ #include "defaults.h" #include "neighbor.h" -namespace diskann { +namespace diskann +{ template class PQScratch; // // AbstractScratch space for in-memory index based search // -template class InMemQueryScratch : public AbstractScratch { -public: - ~InMemQueryScratch(); - InMemQueryScratch(uint32_t search_l, uint32_t indexing_l, uint32_t r, - uint32_t maxc, size_t dim, size_t aligned_dim, - size_t alignment_factor, bool init_pq_scratch = false); - void resize_for_new_L(uint32_t new_search_l); - void clear(); - - inline uint32_t get_L() { return _L; } - inline uint32_t get_R() { return _R; } - inline uint32_t get_maxc() { return _maxc; } - inline T *aligned_query() { return this->_aligned_query_T; } - inline PQScratch *pq_scratch() { return this->_pq_scratch; } - inline std::vector &pool() { return _pool; } - inline NeighborPriorityQueue &best_l_nodes() { return _best_l_nodes; } - inline std::vector &occlude_factor() { return _occlude_factor; } - inline tsl::robin_set &inserted_into_pool_rs() { - return _inserted_into_pool_rs; - } - inline boost::dynamic_bitset<> &inserted_into_pool_bs() { - return *_inserted_into_pool_bs; - } - inline std::vector &id_scratch() { return _id_scratch; } - inline std::vector &dist_scratch() { return _dist_scratch; } - inline tsl::robin_set &expanded_nodes_set() { - return _expanded_nodes_set; - } - inline std::vector &expanded_nodes_vec() { - return _expanded_nghrs_vec; - } - inline std::vector &occlude_list_output() { - return _occlude_list_output; - } - -private: - uint32_t _L; - uint32_t _R; - uint32_t _maxc; - - // _pool stores all neighbors explored from best_L_nodes. - // Usually around L+R, but could be higher. - // Initialized to 3L+R for some slack, expands as needed. 
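// Illustrative sketch (editor's example): a plausible per-query reuse pattern for
// this scratch object, based only on the accessors declared in this class;
// `scratch` and `search_L` are assumed to be supplied by the caller.
//
//   scratch->clear();                              // reset pool / best-L between queries
//   if (search_L > scratch->get_L())
//       scratch->resize_for_new_L(search_L);       // grow best_l_nodes() for a larger L
//   NeighborPriorityQueue &best_l = scratch->best_l_nodes();
//   T *query = scratch->aligned_query();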
- std::vector _pool; - - // _best_l_nodes is reserved for storing best L entries - // Underlying storage is L+1 to support inserts - NeighborPriorityQueue _best_l_nodes; - - // _occlude_factor.size() >= pool.size() in occlude_list function - // _pool is clipped to maxc in occlude_list before affecting _occlude_factor - // _occlude_factor is initialized to maxc size - std::vector _occlude_factor; - - // Capacity initialized to 20L - tsl::robin_set _inserted_into_pool_rs; - - // Use a pointer here to allow for forward declaration of dynamic_bitset - // in public headers to avoid making boost a dependency for clients - // of DiskANN. - boost::dynamic_bitset<> *_inserted_into_pool_bs; - - // _id_scratch.size() must be > R*GRAPH_SLACK_FACTOR for iterate_to_fp - std::vector _id_scratch; - - // _dist_scratch must be > R*GRAPH_SLACK_FACTOR for iterate_to_fp - // _dist_scratch should be at least the size of id_scratch - std::vector _dist_scratch; - - // Buffers used in process delete, capacity increases as needed - tsl::robin_set _expanded_nodes_set; - std::vector _expanded_nghrs_vec; - std::vector _occlude_list_output; +template class InMemQueryScratch : public AbstractScratch +{ + public: + ~InMemQueryScratch(); + InMemQueryScratch(uint32_t search_l, uint32_t indexing_l, uint32_t r, uint32_t maxc, size_t dim, size_t aligned_dim, + size_t alignment_factor, bool init_pq_scratch = false); + void resize_for_new_L(uint32_t new_search_l); + void clear(); + + inline uint32_t get_L() + { + return _L; + } + inline uint32_t get_R() + { + return _R; + } + inline uint32_t get_maxc() + { + return _maxc; + } + inline T *aligned_query() + { + return this->_aligned_query_T; + } + inline PQScratch *pq_scratch() + { + return this->_pq_scratch; + } + inline std::vector &pool() + { + return _pool; + } + inline NeighborPriorityQueue &best_l_nodes() + { + return _best_l_nodes; + } + inline std::vector &occlude_factor() + { + return _occlude_factor; + } + inline tsl::robin_set &inserted_into_pool_rs() + { + return _inserted_into_pool_rs; + } + inline boost::dynamic_bitset<> &inserted_into_pool_bs() + { + return *_inserted_into_pool_bs; + } + inline std::vector &id_scratch() + { + return _id_scratch; + } + inline std::vector &dist_scratch() + { + return _dist_scratch; + } + inline tsl::robin_set &expanded_nodes_set() + { + return _expanded_nodes_set; + } + inline std::vector &expanded_nodes_vec() + { + return _expanded_nghrs_vec; + } + inline std::vector &occlude_list_output() + { + return _occlude_list_output; + } + + private: + uint32_t _L; + uint32_t _R; + uint32_t _maxc; + + // _pool stores all neighbors explored from best_L_nodes. + // Usually around L+R, but could be higher. + // Initialized to 3L+R for some slack, expands as needed. + std::vector _pool; + + // _best_l_nodes is reserved for storing best L entries + // Underlying storage is L+1 to support inserts + NeighborPriorityQueue _best_l_nodes; + + // _occlude_factor.size() >= pool.size() in occlude_list function + // _pool is clipped to maxc in occlude_list before affecting _occlude_factor + // _occlude_factor is initialized to maxc size + std::vector _occlude_factor; + + // Capacity initialized to 20L + tsl::robin_set _inserted_into_pool_rs; + + // Use a pointer here to allow for forward declaration of dynamic_bitset + // in public headers to avoid making boost a dependency for clients + // of DiskANN. 
+ boost::dynamic_bitset<> *_inserted_into_pool_bs; + + // _id_scratch.size() must be > R*GRAPH_SLACK_FACTOR for iterate_to_fp + std::vector _id_scratch; + + // _dist_scratch must be > R*GRAPH_SLACK_FACTOR for iterate_to_fp + // _dist_scratch should be at least the size of id_scratch + std::vector _dist_scratch; + + // Buffers used in process delete, capacity increases as needed + tsl::robin_set _expanded_nodes_set; + std::vector _expanded_nghrs_vec; + std::vector _occlude_list_output; }; // // AbstractScratch space for SSD index based search // -template class SSDQueryScratch : public AbstractScratch { -public: - T *coord_scratch = nullptr; // MUST BE AT LEAST [sizeof(T) * data_dim] +template class SSDQueryScratch : public AbstractScratch +{ + public: + T *coord_scratch = nullptr; // MUST BE AT LEAST [sizeof(T) * data_dim] - char *sector_scratch = - nullptr; // MUST BE AT LEAST [MAX_N_SECTOR_READS * SECTOR_LEN] - size_t sector_idx = 0; // index of next [SECTOR_LEN] scratch to use + char *sector_scratch = nullptr; // MUST BE AT LEAST [MAX_N_SECTOR_READS * SECTOR_LEN] + size_t sector_idx = 0; // index of next [SECTOR_LEN] scratch to use - tsl::robin_set visited; - NeighborPriorityQueue retset; - std::vector full_retset; + tsl::robin_set visited; + NeighborPriorityQueue retset; + std::vector full_retset; - SSDQueryScratch(size_t aligned_dim, size_t visited_reserve); - ~SSDQueryScratch(); + SSDQueryScratch(size_t aligned_dim, size_t visited_reserve); + ~SSDQueryScratch(); - void reset(); + void reset(); }; -template class SSDThreadData { -public: - SSDQueryScratch scratch; - IOContext ctx; +template class SSDThreadData +{ + public: + SSDQueryScratch scratch; + IOContext ctx; - SSDThreadData(size_t aligned_dim, size_t visited_reserve); - void clear(); + SSDThreadData(size_t aligned_dim, size_t visited_reserve); + void clear(); }; // // Class to avoid the hassle of pushing and popping the query scratch. 
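
The scratch objects above are pooled rather than allocated per query, and the ScratchStoreManager defined next wraps the pop / clear / push-back cycle in RAII. A usage sketch, assuming the pool element type used by the in-memory index is InMemQueryScratch<float>*:

// Sketch only: `query_scratch_pool` stands in for the ConcurrentQueue owned by the index.
#include <cstdint>
#include "concurrent_queue.h"
#include "scratch.h"

void search_with_pooled_scratch(diskann::ConcurrentQueue<diskann::InMemQueryScratch<float> *> &query_scratch_pool,
                                uint32_t search_L)
{
    diskann::ScratchStoreManager<diskann::InMemQueryScratch<float>> manager(query_scratch_pool);
    auto *scratch = manager.scratch_space();

    // Grow the per-query buffers if this search asks for a larger L than the
    // scratch was constructed with.
    if (search_L > scratch->get_L())
        scratch->resize_for_new_L(search_L);

    // ... run the search using scratch->best_l_nodes(), scratch->id_scratch(), etc. ...

    // On scope exit the manager clears the scratch and returns it to the pool.
}
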
// -template class ScratchStoreManager { -public: - ScratchStoreManager(ConcurrentQueue &query_scratch) - : _scratch_pool(query_scratch) { - _scratch = query_scratch.pop(); - while (_scratch == nullptr) { - query_scratch.wait_for_push_notify(); - _scratch = query_scratch.pop(); - } - } - T *scratch_space() { return _scratch; } - - ~ScratchStoreManager() { - _scratch->clear(); - _scratch_pool.push(_scratch); - _scratch_pool.push_notify_all(); - } - - void destroy() { - while (!_scratch_pool.empty()) { - auto scratch = _scratch_pool.pop(); - while (scratch == nullptr) { - _scratch_pool.wait_for_push_notify(); - scratch = _scratch_pool.pop(); - } - delete scratch; - } - } - -private: - T *_scratch; - ConcurrentQueue &_scratch_pool; - ScratchStoreManager(const ScratchStoreManager &); - ScratchStoreManager &operator=(const ScratchStoreManager &); +template class ScratchStoreManager +{ + public: + ScratchStoreManager(ConcurrentQueue &query_scratch) : _scratch_pool(query_scratch) + { + _scratch = query_scratch.pop(); + while (_scratch == nullptr) + { + query_scratch.wait_for_push_notify(); + _scratch = query_scratch.pop(); + } + } + T *scratch_space() + { + return _scratch; + } + + ~ScratchStoreManager() + { + _scratch->clear(); + _scratch_pool.push(_scratch); + _scratch_pool.push_notify_all(); + } + + void destroy() + { + while (!_scratch_pool.empty()) + { + auto scratch = _scratch_pool.pop(); + while (scratch == nullptr) + { + _scratch_pool.wait_for_push_notify(); + scratch = _scratch_pool.pop(); + } + delete scratch; + } + } + + private: + T *_scratch; + ConcurrentQueue &_scratch_pool; + ScratchStoreManager(const ScratchStoreManager &); + ScratchStoreManager &operator=(const ScratchStoreManager &); }; } // namespace diskann diff --git a/include/simd_utils.h b/include/simd_utils.h index 405c48896..da59c0cde 100644 --- a/include/simd_utils.h +++ b/include/simd_utils.h @@ -9,97 +9,98 @@ #include #endif -namespace diskann { -static inline __m256 _mm256_mul_epi8(__m256i X) { - __m256i zero = _mm256_setzero_si256(); +namespace diskann +{ +static inline __m256 _mm256_mul_epi8(__m256i X) +{ + __m256i zero = _mm256_setzero_si256(); - __m256i sign_x = _mm256_cmpgt_epi8(zero, X); + __m256i sign_x = _mm256_cmpgt_epi8(zero, X); - __m256i xlo = _mm256_unpacklo_epi8(X, sign_x); - __m256i xhi = _mm256_unpackhi_epi8(X, sign_x); + __m256i xlo = _mm256_unpacklo_epi8(X, sign_x); + __m256i xhi = _mm256_unpackhi_epi8(X, sign_x); - return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(xlo, xlo), - _mm256_madd_epi16(xhi, xhi))); + return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(xlo, xlo), _mm256_madd_epi16(xhi, xhi))); } -static inline __m128 _mm_mulhi_epi8(__m128i X) { - __m128i zero = _mm_setzero_si128(); - __m128i sign_x = _mm_cmplt_epi8(X, zero); - __m128i xhi = _mm_unpackhi_epi8(X, sign_x); +static inline __m128 _mm_mulhi_epi8(__m128i X) +{ + __m128i zero = _mm_setzero_si128(); + __m128i sign_x = _mm_cmplt_epi8(X, zero); + __m128i xhi = _mm_unpackhi_epi8(X, sign_x); - return _mm_cvtepi32_ps( - _mm_add_epi32(_mm_setzero_si128(), _mm_madd_epi16(xhi, xhi))); + return _mm_cvtepi32_ps(_mm_add_epi32(_mm_setzero_si128(), _mm_madd_epi16(xhi, xhi))); } -static inline __m128 _mm_mulhi_epi8_shift32(__m128i X) { - __m128i zero = _mm_setzero_si128(); - X = _mm_srli_epi64(X, 32); - __m128i sign_x = _mm_cmplt_epi8(X, zero); - __m128i xhi = _mm_unpackhi_epi8(X, sign_x); +static inline __m128 _mm_mulhi_epi8_shift32(__m128i X) +{ + __m128i zero = _mm_setzero_si128(); + X = _mm_srli_epi64(X, 32); + __m128i 
sign_x = _mm_cmplt_epi8(X, zero); + __m128i xhi = _mm_unpackhi_epi8(X, sign_x); - return _mm_cvtepi32_ps( - _mm_add_epi32(_mm_setzero_si128(), _mm_madd_epi16(xhi, xhi))); + return _mm_cvtepi32_ps(_mm_add_epi32(_mm_setzero_si128(), _mm_madd_epi16(xhi, xhi))); } -static inline __m128 _mm_mul_epi8(__m128i X, __m128i Y) { - __m128i zero = _mm_setzero_si128(); +static inline __m128 _mm_mul_epi8(__m128i X, __m128i Y) +{ + __m128i zero = _mm_setzero_si128(); - __m128i sign_x = _mm_cmplt_epi8(X, zero); - __m128i sign_y = _mm_cmplt_epi8(Y, zero); + __m128i sign_x = _mm_cmplt_epi8(X, zero); + __m128i sign_y = _mm_cmplt_epi8(Y, zero); - __m128i xlo = _mm_unpacklo_epi8(X, sign_x); - __m128i xhi = _mm_unpackhi_epi8(X, sign_x); - __m128i ylo = _mm_unpacklo_epi8(Y, sign_y); - __m128i yhi = _mm_unpackhi_epi8(Y, sign_y); + __m128i xlo = _mm_unpacklo_epi8(X, sign_x); + __m128i xhi = _mm_unpackhi_epi8(X, sign_x); + __m128i ylo = _mm_unpacklo_epi8(Y, sign_y); + __m128i yhi = _mm_unpackhi_epi8(Y, sign_y); - return _mm_cvtepi32_ps( - _mm_add_epi32(_mm_madd_epi16(xlo, ylo), _mm_madd_epi16(xhi, yhi))); + return _mm_cvtepi32_ps(_mm_add_epi32(_mm_madd_epi16(xlo, ylo), _mm_madd_epi16(xhi, yhi))); } -static inline __m128 _mm_mul_epi8(__m128i X) { - __m128i zero = _mm_setzero_si128(); - __m128i sign_x = _mm_cmplt_epi8(X, zero); - __m128i xlo = _mm_unpacklo_epi8(X, sign_x); - __m128i xhi = _mm_unpackhi_epi8(X, sign_x); - - return _mm_cvtepi32_ps( - _mm_add_epi32(_mm_madd_epi16(xlo, xlo), _mm_madd_epi16(xhi, xhi))); +static inline __m128 _mm_mul_epi8(__m128i X) +{ + __m128i zero = _mm_setzero_si128(); + __m128i sign_x = _mm_cmplt_epi8(X, zero); + __m128i xlo = _mm_unpacklo_epi8(X, sign_x); + __m128i xhi = _mm_unpackhi_epi8(X, sign_x); + + return _mm_cvtepi32_ps(_mm_add_epi32(_mm_madd_epi16(xlo, xlo), _mm_madd_epi16(xhi, xhi))); } -static inline __m128 _mm_mul32_pi8(__m128i X, __m128i Y) { - __m128i xlo = _mm_cvtepi8_epi16(X), ylo = _mm_cvtepi8_epi16(Y); - return _mm_cvtepi32_ps( - _mm_unpacklo_epi32(_mm_madd_epi16(xlo, ylo), _mm_setzero_si128())); +static inline __m128 _mm_mul32_pi8(__m128i X, __m128i Y) +{ + __m128i xlo = _mm_cvtepi8_epi16(X), ylo = _mm_cvtepi8_epi16(Y); + return _mm_cvtepi32_ps(_mm_unpacklo_epi32(_mm_madd_epi16(xlo, ylo), _mm_setzero_si128())); } -static inline __m256 _mm256_mul_epi8(__m256i X, __m256i Y) { - __m256i zero = _mm256_setzero_si256(); +static inline __m256 _mm256_mul_epi8(__m256i X, __m256i Y) +{ + __m256i zero = _mm256_setzero_si256(); - __m256i sign_x = _mm256_cmpgt_epi8(zero, X); - __m256i sign_y = _mm256_cmpgt_epi8(zero, Y); + __m256i sign_x = _mm256_cmpgt_epi8(zero, X); + __m256i sign_y = _mm256_cmpgt_epi8(zero, Y); - __m256i xlo = _mm256_unpacklo_epi8(X, sign_x); - __m256i xhi = _mm256_unpackhi_epi8(X, sign_x); - __m256i ylo = _mm256_unpacklo_epi8(Y, sign_y); - __m256i yhi = _mm256_unpackhi_epi8(Y, sign_y); + __m256i xlo = _mm256_unpacklo_epi8(X, sign_x); + __m256i xhi = _mm256_unpackhi_epi8(X, sign_x); + __m256i ylo = _mm256_unpacklo_epi8(Y, sign_y); + __m256i yhi = _mm256_unpackhi_epi8(Y, sign_y); - return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(xlo, ylo), - _mm256_madd_epi16(xhi, yhi))); + return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(xlo, ylo), _mm256_madd_epi16(xhi, yhi))); } -static inline __m256 _mm256_mul32_pi8(__m128i X, __m128i Y) { - __m256i xlo = _mm256_cvtepi8_epi16(X), ylo = _mm256_cvtepi8_epi16(Y); - return _mm256_blend_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(xlo, ylo)), - _mm256_setzero_ps(), 252); +static inline __m256 
_mm256_mul32_pi8(__m128i X, __m128i Y) +{ + __m256i xlo = _mm256_cvtepi8_epi16(X), ylo = _mm256_cvtepi8_epi16(Y); + return _mm256_blend_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(xlo, ylo)), _mm256_setzero_ps(), 252); } -static inline float _mm256_reduce_add_ps(__m256 x) { - /* ( x3+x7, x2+x6, x1+x5, x0+x4 ) */ - const __m128 x128 = - _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)); - /* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */ - const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); - /* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */ - const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); - /* Conversion to float is a no-op on x86-64 */ - return _mm_cvtss_f32(x32); +static inline float _mm256_reduce_add_ps(__m256 x) +{ + /* ( x3+x7, x2+x6, x1+x5, x0+x4 ) */ + const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)); + /* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */ + const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); + /* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */ + const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); + /* Conversion to float is a no-op on x86-64 */ + return _mm_cvtss_f32(x32); } } // namespace diskann diff --git a/include/tag_uint128.h b/include/tag_uint128.h index a589fadfc..642de3159 100644 --- a/include/tag_uint128.h +++ b/include/tag_uint128.h @@ -2,58 +2,67 @@ #include #include -namespace diskann { +namespace diskann +{ #pragma pack(push, 1) -struct tag_uint128 { - std::uint64_t _data1 = 0; - std::uint64_t _data2 = 0; +struct tag_uint128 +{ + std::uint64_t _data1 = 0; + std::uint64_t _data2 = 0; - bool operator==(const tag_uint128 &other) const { - return _data1 == other._data1 && _data2 == other._data2; - } + bool operator==(const tag_uint128 &other) const + { + return _data1 == other._data1 && _data2 == other._data2; + } - bool operator==(std::uint64_t other) const { - return _data1 == other && _data2 == 0; - } + bool operator==(std::uint64_t other) const + { + return _data1 == other && _data2 == 0; + } - tag_uint128 &operator=(const tag_uint128 &other) { - _data1 = other._data1; - _data2 = other._data2; + tag_uint128 &operator=(const tag_uint128 &other) + { + _data1 = other._data1; + _data2 = other._data2; - return *this; - } + return *this; + } - tag_uint128 &operator=(std::uint64_t other) { - _data1 = other; - _data2 = 0; + tag_uint128 &operator=(std::uint64_t other) + { + _data1 = other; + _data2 = 0; - return *this; - } + return *this; + } }; #pragma pack(pop) } // namespace diskann -namespace std { +namespace std +{ // Hash 128 input bits down to 64 bits of output. // This is intended to be a reasonably good hash function. -inline std::uint64_t Hash128to64(const std::uint64_t &low, - const std::uint64_t &high) { - // Murmur-inspired hashing. - const std::uint64_t kMul = 0x9ddfea08eb382d69ULL; - std::uint64_t a = (low ^ high) * kMul; - a ^= (a >> 47); - std::uint64_t b = (high ^ a) * kMul; - b ^= (b >> 47); - b *= kMul; - return b; +inline std::uint64_t Hash128to64(const std::uint64_t &low, const std::uint64_t &high) +{ + // Murmur-inspired hashing. 
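
Taken together, the wrappers above give an AVX2 path for int8 arithmetic: _mm256_mul_epi8(X, Y) yields eight float partial sums of the 32 byte-wise products, and _mm256_reduce_add_ps collapses them to a scalar. A sketch of an inner product built from them (alignment and tail handling omitted; dim is assumed to be a multiple of 32):

// Sketch only: mirrors how these helpers are intended to be combined; not part of the patch.
#include <immintrin.h>
#include <cstddef>
#include <cstdint>
#include "simd_utils.h"

inline float int8_inner_product(const int8_t *a, const int8_t *b, size_t dim)
{
    __m256 sum = _mm256_setzero_ps();
    for (size_t i = 0; i < dim; i += 32)
    {
        __m256i va = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(a + i));
        __m256i vb = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(b + i));
        sum = _mm256_add_ps(sum, diskann::_mm256_mul_epi8(va, vb)); // 8 partial sums per step
    }
    return diskann::_mm256_reduce_add_ps(sum); // horizontal add of the partial sums
}
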
+ const std::uint64_t kMul = 0x9ddfea08eb382d69ULL; + std::uint64_t a = (low ^ high) * kMul; + a ^= (a >> 47); + std::uint64_t b = (high ^ a) * kMul; + b ^= (b >> 47); + b *= kMul; + return b; } -template <> struct hash { - size_t operator()(const diskann::tag_uint128 &key) const noexcept { - return Hash128to64(key._data1, key._data2); // map -0 to 0 - } +template <> struct hash +{ + size_t operator()(const diskann::tag_uint128 &key) const noexcept + { + return Hash128to64(key._data1, key._data2); // map -0 to 0 + } }; } // namespace std \ No newline at end of file diff --git a/include/timer.h b/include/timer.h index 32765f45f..325edf302 100644 --- a/include/timer.h +++ b/include/timer.h @@ -4,27 +4,37 @@ #include -namespace diskann { -class Timer { - typedef std::chrono::high_resolution_clock _clock; - std::chrono::time_point<_clock> check_point; +namespace diskann +{ +class Timer +{ + typedef std::chrono::high_resolution_clock _clock; + std::chrono::time_point<_clock> check_point; -public: - Timer() : check_point(_clock::now()) {} + public: + Timer() : check_point(_clock::now()) + { + } - void reset() { check_point = _clock::now(); } + void reset() + { + check_point = _clock::now(); + } - long long elapsed() const { - return std::chrono::duration_cast(_clock::now() - - check_point) - .count(); - } + long long elapsed() const + { + return std::chrono::duration_cast(_clock::now() - check_point).count(); + } - float elapsed_seconds() const { return (float)elapsed() / 1000000.0f; } + float elapsed_seconds() const + { + return (float)elapsed() / 1000000.0f; + } - std::string elapsed_seconds_for_step(const std::string &step) const { - return std::string("Time for ") + step + std::string(": ") + - std::to_string(elapsed_seconds()) + std::string(" seconds"); - } + std::string elapsed_seconds_for_step(const std::string &step) const + { + return std::string("Time for ") + step + std::string(": ") + std::to_string(elapsed_seconds()) + + std::string(" seconds"); + } }; } // namespace diskann diff --git a/include/types.h b/include/types.h index 2a7d5c9e2..58d8d40a4 100644 --- a/include/types.h +++ b/include/types.h @@ -8,7 +8,8 @@ #include #include -namespace diskann { +namespace diskann +{ typedef uint32_t location_t; using DataType = std::any; diff --git a/include/utils.h b/include/utils.h index 43e6b25da..6672c6fc4 100644 --- a/include/utils.h +++ b/include/utils.h @@ -38,8 +38,7 @@ typedef int FileHandle; // taken from // https://github.com/Microsoft/BLAS-on-flash/blob/master/include/utils.h // round up X to the nearest multiple of Y -#define ROUND_UP(X, Y) \ - ((((uint64_t)(X) / (Y)) + ((uint64_t)(X) % (Y) != 0)) * (Y)) +#define ROUND_UP(X, Y) ((((uint64_t)(X) / (Y)) + ((uint64_t)(X) % (Y) != 0)) * (Y)) #define DIV_ROUND_UP(X, Y) (((uint64_t)(X) / (Y)) + ((uint64_t)(X) % (Y) != 0)) @@ -50,776 +49,794 @@ typedef int FileHandle; #define IS_ALIGNED(X, Y) ((uint64_t)(X) % (uint64_t)(Y) == 0) #define IS_512_ALIGNED(X) IS_ALIGNED(X, 512) #define IS_4096_ALIGNED(X) IS_ALIGNED(X, 4096) -#define METADATA_SIZE \ - 4096 // all metadata of individual sub-component files is written in first - // 4KB for unified files +#define METADATA_SIZE \ + 4096 // all metadata of individual sub-component files is written in first + // 4KB for unified files #define BUFFER_SIZE_FOR_CACHED_IO (size_t)1024 * (size_t)1048576 #define PBSTR "||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||" #define PBWIDTH 60 -inline bool file_exists_impl(const std::string &name, bool dirCheck = false) { - int val; +inline bool 
file_exists_impl(const std::string &name, bool dirCheck = false) +{ + int val; #ifndef _WINDOWS - struct stat buffer; - val = stat(name.c_str(), &buffer); + struct stat buffer; + val = stat(name.c_str(), &buffer); #else - // It is the 21st century but Windows API still thinks in 32-bit terms. - // Turns out calling stat() on a file > 4GB results in errno = 132 - // (OVERFLOW). How silly is this!? So calling _stat64() - struct _stat64 buffer; - val = _stat64(name.c_str(), &buffer); + // It is the 21st century but Windows API still thinks in 32-bit terms. + // Turns out calling stat() on a file > 4GB results in errno = 132 + // (OVERFLOW). How silly is this!? So calling _stat64() + struct _stat64 buffer; + val = _stat64(name.c_str(), &buffer); #endif - if (val != 0) { - switch (errno) { - case EINVAL: - diskann::cout << "Invalid argument passed to stat()" << std::endl; - break; - case ENOENT: - // file is not existing, not an issue, so we won't cout anything. - break; - default: - diskann::cout << "Unexpected error in stat():" << errno << std::endl; - break; + if (val != 0) + { + switch (errno) + { + case EINVAL: + diskann::cout << "Invalid argument passed to stat()" << std::endl; + break; + case ENOENT: + // file is not existing, not an issue, so we won't cout anything. + break; + default: + diskann::cout << "Unexpected error in stat():" << errno << std::endl; + break; + } + return false; + } + else + { + // the file entry exists. If reqd, check if this is a directory. + return dirCheck ? buffer.st_mode & S_IFDIR : true; } - return false; - } else { - // the file entry exists. If reqd, check if this is a directory. - return dirCheck ? buffer.st_mode & S_IFDIR : true; - } } -inline bool file_exists(const std::string &name, bool dirCheck = false) { +inline bool file_exists(const std::string &name, bool dirCheck = false) +{ #ifdef EXEC_ENV_OLS - bool exists = file_exists_impl(name, dirCheck); - if (exists) { - return true; - } - if (!dirCheck) { - // try with .enc extension - std::string enc_name = name + ENCRYPTED_EXTENSION; - return file_exists_impl(enc_name, dirCheck); - } else { - return exists; - } + bool exists = file_exists_impl(name, dirCheck); + if (exists) + { + return true; + } + if (!dirCheck) + { + // try with .enc extension + std::string enc_name = name + ENCRYPTED_EXTENSION; + return file_exists_impl(enc_name, dirCheck); + } + else + { + return exists; + } #else - return file_exists_impl(name, dirCheck); + return file_exists_impl(name, dirCheck); #endif } -inline void open_file_to_write(std::ofstream &writer, - const std::string &filename) { - writer.exceptions(std::ofstream::failbit | std::ofstream::badbit); - if (!file_exists(filename)) - writer.open(filename, std::ios::binary | std::ios::out); - else - writer.open(filename, std::ios::binary | std::ios::in | std::ios::out); - - if (writer.fail()) { - char buff[1024]; +inline void open_file_to_write(std::ofstream &writer, const std::string &filename) +{ + writer.exceptions(std::ofstream::failbit | std::ofstream::badbit); + if (!file_exists(filename)) + writer.open(filename, std::ios::binary | std::ios::out); + else + writer.open(filename, std::ios::binary | std::ios::in | std::ios::out); + + if (writer.fail()) + { + char buff[1024]; #ifdef _WINDOWS - auto ret = std::to_string(strerror_s(buff, 1024, errno)); + auto ret = std::to_string(strerror_s(buff, 1024, errno)); #else - auto ret = std::string(strerror_r(errno, buff, 1024)); + auto ret = std::string(strerror_r(errno, buff, 1024)); #endif - auto message = std::string("Failed 
to open file") + filename + - " for write because " + buff + ", ret=" + ret; - diskann::cerr << message << std::endl; - throw diskann::ANNException(message, -1); - } -} - -inline size_t get_file_size(const std::string &fname) { - std::ifstream reader(fname, std::ios::binary | std::ios::ate); - if (!reader.fail() && reader.is_open()) { - size_t end_pos = reader.tellg(); - reader.close(); - return end_pos; - } else { - diskann::cerr << "Could not open file: " << fname << std::endl; - return 0; - } -} - -inline int delete_file(const std::string &fileName) { - if (file_exists(fileName)) { - auto rc = ::remove(fileName.c_str()); - if (rc != 0) { - diskann::cerr - << "Could not delete file: " << fileName - << " even though it exists. This might indicate a permissions " - "issue. " - "If you see this message, please contact the diskann team." - << std::endl; + auto message = std::string("Failed to open file") + filename + " for write because " + buff + ", ret=" + ret; + diskann::cerr << message << std::endl; + throw diskann::ANNException(message, -1); } - return rc; - } else { - return 0; - } } -// generates formatted_label and _labels_map file. -inline void convert_labels_string_to_int(const std::string &inFileName, - const std::string &outFileName, - const std::string &mapFileName, - const std::string &unv_label) { - std::unordered_map string_int_map; - std::ofstream label_writer(outFileName); - std::ifstream label_reader(inFileName); - if (unv_label != "") - string_int_map[unv_label] = - 0; // if universal label is provided map it to 0 always - std::string line, token; - while (std::getline(label_reader, line)) { - std::istringstream new_iss(line); - std::vector lbls; - while (getline(new_iss, token, ',')) { - token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - if (string_int_map.find(token) == string_int_map.end()) { - uint32_t nextId = (uint32_t)string_int_map.size() + 1; - string_int_map[token] = nextId; // nextId can never be 0 - } - lbls.push_back(string_int_map[token]); +inline size_t get_file_size(const std::string &fname) +{ + std::ifstream reader(fname, std::ios::binary | std::ios::ate); + if (!reader.fail() && reader.is_open()) + { + size_t end_pos = reader.tellg(); + reader.close(); + return end_pos; + } + else + { + diskann::cerr << "Could not open file: " << fname << std::endl; + return 0; } - if (lbls.size() <= 0) { - std::cout << "No label found"; - exit(-1); +} + +inline int delete_file(const std::string &fileName) +{ + if (file_exists(fileName)) + { + auto rc = ::remove(fileName.c_str()); + if (rc != 0) + { + diskann::cerr << "Could not delete file: " << fileName + << " even though it exists. This might indicate a permissions " + "issue. " + "If you see this message, please contact the diskann team." + << std::endl; + } + return rc; } - for (size_t j = 0; j < lbls.size(); j++) { - if (j != lbls.size() - 1) - label_writer << lbls[j] << ","; - else - label_writer << lbls[j] << std::endl; + else + { + return 0; } - } - label_writer.close(); +} - std::ofstream map_writer(mapFileName); - for (auto mp : string_int_map) { - map_writer << mp.first << "\t" << mp.second << std::endl; - } - map_writer.close(); +// generates formatted_label and _labels_map file. 
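
convert_labels_string_to_int turns a raw label file, where line i holds the comma-separated string labels of point i, into an integer label file plus a tab-separated string-to-id map: string labels are mapped to non-zero ids in order of first appearance, and the universal label, if given, is pinned to 0. A hypothetical invocation with placeholder file names:

// Sketch only; file names are placeholders.
#include <string>
#include "utils.h" // declares convert_labels_string_to_int

void prepare_label_files()
{
    // With no universal label (""), ids are assigned 1, 2, 3, ... in order of first appearance:
    //   raw_labels.txt      labels_formatted.txt      labels_map.txt (order may vary)
    //   red,large           1,2                       red     1
    //   blue                3                         large   2
    //                                                 blue    3
    convert_labels_string_to_int("raw_labels.txt",       // input: string labels, one point per line
                                 "labels_formatted.txt", // output: integer labels in the same layout
                                 "labels_map.txt",       // output: string -> id map
                                 "");                    // universal label; "" means none
}
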
+inline void convert_labels_string_to_int(const std::string &inFileName, const std::string &outFileName, + const std::string &mapFileName, const std::string &unv_label) +{ + std::unordered_map string_int_map; + std::ofstream label_writer(outFileName); + std::ifstream label_reader(inFileName); + if (unv_label != "") + string_int_map[unv_label] = 0; // if universal label is provided map it to 0 always + std::string line, token; + while (std::getline(label_reader, line)) + { + std::istringstream new_iss(line); + std::vector lbls; + while (getline(new_iss, token, ',')) + { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + if (string_int_map.find(token) == string_int_map.end()) + { + uint32_t nextId = (uint32_t)string_int_map.size() + 1; + string_int_map[token] = nextId; // nextId can never be 0 + } + lbls.push_back(string_int_map[token]); + } + if (lbls.size() <= 0) + { + std::cout << "No label found"; + exit(-1); + } + for (size_t j = 0; j < lbls.size(); j++) + { + if (j != lbls.size() - 1) + label_writer << lbls[j] << ","; + else + label_writer << lbls[j] << std::endl; + } + } + label_writer.close(); + + std::ofstream map_writer(mapFileName); + for (auto mp : string_int_map) + { + map_writer << mp.first << "\t" << mp.second << std::endl; + } + map_writer.close(); } #ifdef EXEC_ENV_OLS class AlignedFileReader; #endif -namespace diskann { +namespace diskann +{ static const size_t MAX_SIZE_OF_STREAMBUF = 2LL * 1024 * 1024 * 1024; -inline void print_error_and_terminate(std::stringstream &error_stream) { - diskann::cerr << error_stream.str() << std::endl; - throw diskann::ANNException(error_stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); +inline void print_error_and_terminate(std::stringstream &error_stream) +{ + diskann::cerr << error_stream.str() << std::endl; + throw diskann::ANNException(error_stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } -inline void report_memory_allocation_failure() { - std::stringstream stream; - stream << "Memory Allocation Failed."; - print_error_and_terminate(stream); +inline void report_memory_allocation_failure() +{ + std::stringstream stream; + stream << "Memory Allocation Failed."; + print_error_and_terminate(stream); } -inline void report_misalignment_of_requested_size(size_t align) { - std::stringstream stream; - stream << "Requested memory size is not a multiple of " << align - << ". Can not be allocated."; - print_error_and_terminate(stream); +inline void report_misalignment_of_requested_size(size_t align) +{ + std::stringstream stream; + stream << "Requested memory size is not a multiple of " << align << ". Can not be allocated."; + print_error_and_terminate(stream); } -inline void alloc_aligned(void **ptr, size_t size, size_t align) { - *ptr = nullptr; - if (IS_ALIGNED(size, align) == 0) - report_misalignment_of_requested_size(align); +inline void alloc_aligned(void **ptr, size_t size, size_t align) +{ + *ptr = nullptr; + if (IS_ALIGNED(size, align) == 0) + report_misalignment_of_requested_size(align); #ifndef _WINDOWS - *ptr = ::aligned_alloc(align, size); + *ptr = ::aligned_alloc(align, size); #else - *ptr = ::_aligned_malloc(size, align); // note the swapped arguments! + *ptr = ::_aligned_malloc(size, align); // note the swapped arguments! 
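
alloc_aligned refuses requests whose byte count is not a multiple of the alignment, so callers round the size up first (with ROUND_UP, defined earlier in this header) and release the buffer with aligned_free, which follows just below. A minimal sketch:

// Sketch only.
#include <cstddef>
#include "utils.h"

void example_aligned_buffer(size_t npts, size_t dim)
{
    // Round the request up to a multiple of the alignment, as alloc_aligned requires.
    const size_t bytes = ROUND_UP(npts * dim * sizeof(float), 512);
    float *buf = nullptr;
    diskann::alloc_aligned(reinterpret_cast<void **>(&buf), bytes, 512);
    // ... fill / use buf ...
    diskann::aligned_free(buf);
}
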
#endif - if (*ptr == nullptr) - report_memory_allocation_failure(); + if (*ptr == nullptr) + report_memory_allocation_failure(); } -inline void realloc_aligned(void **ptr, size_t size, size_t align) { - if (IS_ALIGNED(size, align) == 0) - report_misalignment_of_requested_size(align); +inline void realloc_aligned(void **ptr, size_t size, size_t align) +{ + if (IS_ALIGNED(size, align) == 0) + report_misalignment_of_requested_size(align); #ifdef _WINDOWS - *ptr = ::_aligned_realloc(*ptr, size, align); + *ptr = ::_aligned_realloc(*ptr, size, align); #else - diskann::cerr << "No aligned realloc on GCC. Must malloc and mem_align, " - "left it out for now." - << std::endl; + diskann::cerr << "No aligned realloc on GCC. Must malloc and mem_align, " + "left it out for now." + << std::endl; #endif - if (*ptr == nullptr) - report_memory_allocation_failure(); + if (*ptr == nullptr) + report_memory_allocation_failure(); } -inline void check_stop(std::string arnd) { - int brnd; - diskann::cout << arnd << std::endl; - std::cin >> brnd; +inline void check_stop(std::string arnd) +{ + int brnd; + diskann::cout << arnd << std::endl; + std::cin >> brnd; } -inline void aligned_free(void *ptr) { - // Gopal. Must have a check here if the pointer was actually allocated by - // _alloc_aligned - if (ptr == nullptr) { - return; - } +inline void aligned_free(void *ptr) +{ + // Gopal. Must have a check here if the pointer was actually allocated by + // _alloc_aligned + if (ptr == nullptr) + { + return; + } #ifndef _WINDOWS - free(ptr); + free(ptr); #else - ::_aligned_free(ptr); + ::_aligned_free(ptr); #endif } -inline void GenRandom(std::mt19937 &rng, unsigned *addr, unsigned size, - unsigned N) { - for (unsigned i = 0; i < size; ++i) { - addr[i] = rng() % (N - size); - } +inline void GenRandom(std::mt19937 &rng, unsigned *addr, unsigned size, unsigned N) +{ + for (unsigned i = 0; i < size; ++i) + { + addr[i] = rng() % (N - size); + } - std::sort(addr, addr + size); - for (unsigned i = 1; i < size; ++i) { - if (addr[i] <= addr[i - 1]) { - addr[i] = addr[i - 1] + 1; + std::sort(addr, addr + size); + for (unsigned i = 1; i < size; ++i) + { + if (addr[i] <= addr[i - 1]) + { + addr[i] = addr[i - 1] + 1; + } + } + unsigned off = rng() % N; + for (unsigned i = 0; i < size; ++i) + { + addr[i] = (addr[i] + off) % N; } - } - unsigned off = rng() % N; - for (unsigned i = 0; i < size; ++i) { - addr[i] = (addr[i] + off) % N; - } } // get_bin_metadata functions START -inline void get_bin_metadata_impl(std::basic_istream &reader, - size_t &nrows, size_t &ncols, - size_t offset = 0) { - int nrows_32, ncols_32; - reader.seekg(offset, reader.beg); - reader.read((char *)&nrows_32, sizeof(int)); - reader.read((char *)&ncols_32, sizeof(int)); - nrows = nrows_32; - ncols = ncols_32; +inline void get_bin_metadata_impl(std::basic_istream &reader, size_t &nrows, size_t &ncols, size_t offset = 0) +{ + int nrows_32, ncols_32; + reader.seekg(offset, reader.beg); + reader.read((char *)&nrows_32, sizeof(int)); + reader.read((char *)&ncols_32, sizeof(int)); + nrows = nrows_32; + ncols = ncols_32; } #ifdef EXEC_ENV_OLS -inline void get_bin_metadata(MemoryMappedFiles &files, - const std::string &bin_file, size_t &nrows, - size_t &ncols, size_t offset = 0) { - diskann::cout << "Getting metadata for file: " << bin_file << std::endl; - auto fc = files.getContent(bin_file); - // auto cb = ContentBuf((char*) fc._content, fc._size); - // std::basic_istream reader(&cb); - // get_bin_metadata_impl(reader, nrows, ncols, offset); - - int nrows_32, ncols_32; - 
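
All of the .bin readers in this header share one layout: a header of two int32 values (number of points, then dimensionality) followed by npts * dim elements in row-major order. Inspecting a file's shape therefore only touches the 8-byte header, for example via the plain-file get_bin_metadata overload that follows this hunk:

// Sketch only.
#include <cstddef>
#include <iostream>
#include <string>
#include "utils.h"

void print_bin_shape(const std::string &bin_file)
{
    size_t npts = 0, dim = 0;
    diskann::get_bin_metadata(bin_file, npts, dim); // reads only the two-int32 header
    std::cout << bin_file << ": " << npts << " points, " << dim << " dims" << std::endl;
}
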
int32_t *metadata_ptr = (int32_t *)((char *)fc._content + offset); - nrows_32 = *metadata_ptr; - ncols_32 = *(metadata_ptr + 1); - nrows = nrows_32; - ncols = ncols_32; +inline void get_bin_metadata(MemoryMappedFiles &files, const std::string &bin_file, size_t &nrows, size_t &ncols, + size_t offset = 0) +{ + diskann::cout << "Getting metadata for file: " << bin_file << std::endl; + auto fc = files.getContent(bin_file); + // auto cb = ContentBuf((char*) fc._content, fc._size); + // std::basic_istream reader(&cb); + // get_bin_metadata_impl(reader, nrows, ncols, offset); + + int nrows_32, ncols_32; + int32_t *metadata_ptr = (int32_t *)((char *)fc._content + offset); + nrows_32 = *metadata_ptr; + ncols_32 = *(metadata_ptr + 1); + nrows = nrows_32; + ncols = ncols_32; } #endif -inline void get_bin_metadata(const std::string &bin_file, size_t &nrows, - size_t &ncols, size_t offset = 0) { - std::ifstream reader(bin_file.c_str(), std::ios::binary); - get_bin_metadata_impl(reader, nrows, ncols, offset); +inline void get_bin_metadata(const std::string &bin_file, size_t &nrows, size_t &ncols, size_t offset = 0) +{ + std::ifstream reader(bin_file.c_str(), std::ios::binary); + get_bin_metadata_impl(reader, nrows, ncols, offset); } // get_bin_metadata functions END #ifndef EXEC_ENV_OLS -inline size_t get_graph_num_frozen_points(const std::string &graph_file) { - size_t expected_file_size; - uint32_t max_observed_degree, start; - size_t file_frozen_pts; - - std::ifstream in; - in.exceptions(std::ios::badbit | std::ios::failbit); - - in.open(graph_file, std::ios::binary); - in.read((char *)&expected_file_size, sizeof(size_t)); - in.read((char *)&max_observed_degree, sizeof(uint32_t)); - in.read((char *)&start, sizeof(uint32_t)); - in.read((char *)&file_frozen_pts, sizeof(size_t)); - - return file_frozen_pts; +inline size_t get_graph_num_frozen_points(const std::string &graph_file) +{ + size_t expected_file_size; + uint32_t max_observed_degree, start; + size_t file_frozen_pts; + + std::ifstream in; + in.exceptions(std::ios::badbit | std::ios::failbit); + + in.open(graph_file, std::ios::binary); + in.read((char *)&expected_file_size, sizeof(size_t)); + in.read((char *)&max_observed_degree, sizeof(uint32_t)); + in.read((char *)&start, sizeof(uint32_t)); + in.read((char *)&file_frozen_pts, sizeof(size_t)); + + return file_frozen_pts; } #endif -template inline std::string getValues(T *data, size_t num) { - std::stringstream stream; - stream << "["; - for (size_t i = 0; i < num; i++) { - stream << std::to_string(data[i]) << ","; - } - stream << "]" << std::endl; +template inline std::string getValues(T *data, size_t num) +{ + std::stringstream stream; + stream << "["; + for (size_t i = 0; i < num; i++) + { + stream << std::to_string(data[i]) << ","; + } + stream << "]" << std::endl; - return stream.str(); + return stream.str(); } // load_bin functions START template -inline void load_bin_impl(std::basic_istream &reader, T *&data, - size_t &npts, size_t &dim, size_t file_offset = 0) { - int npts_i32, dim_i32; +inline void load_bin_impl(std::basic_istream &reader, T *&data, size_t &npts, size_t &dim, size_t file_offset = 0) +{ + int npts_i32, dim_i32; - reader.seekg(file_offset, reader.beg); - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&dim_i32, sizeof(int)); - npts = (unsigned)npts_i32; - dim = (unsigned)dim_i32; + reader.seekg(file_offset, reader.beg); + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&dim_i32, sizeof(int)); + npts = (unsigned)npts_i32; + dim = 
(unsigned)dim_i32; - std::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "..." - << std::endl; + std::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "..." << std::endl; - data = new T[npts * dim]; - reader.read((char *)data, npts * dim * sizeof(T)); + data = new T[npts * dim]; + reader.read((char *)data, npts * dim * sizeof(T)); } #ifdef EXEC_ENV_OLS template -inline void load_bin(MemoryMappedFiles &files, const std::string &bin_file, - T *&data, size_t &npts, size_t &dim, size_t offset = 0) { - diskann::cout << "Reading bin file " << bin_file.c_str() - << " at offset: " << offset << "..." << std::endl; - auto fc = files.getContent(bin_file); +inline void load_bin(MemoryMappedFiles &files, const std::string &bin_file, T *&data, size_t &npts, size_t &dim, + size_t offset = 0) +{ + diskann::cout << "Reading bin file " << bin_file.c_str() << " at offset: " << offset << "..." << std::endl; + auto fc = files.getContent(bin_file); - uint32_t t_npts, t_dim; - uint32_t *contentAsIntPtr = (uint32_t *)((char *)fc._content + offset); - t_npts = *(contentAsIntPtr); - t_dim = *(contentAsIntPtr + 1); + uint32_t t_npts, t_dim; + uint32_t *contentAsIntPtr = (uint32_t *)((char *)fc._content + offset); + t_npts = *(contentAsIntPtr); + t_dim = *(contentAsIntPtr + 1); - npts = t_npts; - dim = t_dim; + npts = t_npts; + dim = t_dim; - data = (T *)((char *)fc._content + offset + - 2 * sizeof(uint32_t)); // No need to copy! + data = (T *)((char *)fc._content + offset + 2 * sizeof(uint32_t)); // No need to copy! } -DISKANN_DLLEXPORT void get_bin_metadata(AlignedFileReader &reader, size_t &npts, - size_t &ndim, size_t offset = 0); +DISKANN_DLLEXPORT void get_bin_metadata(AlignedFileReader &reader, size_t &npts, size_t &ndim, size_t offset = 0); template -DISKANN_DLLEXPORT void load_bin(AlignedFileReader &reader, T *&data, - size_t &npts, size_t &ndim, size_t offset = 0); +DISKANN_DLLEXPORT void load_bin(AlignedFileReader &reader, T *&data, size_t &npts, size_t &ndim, size_t offset = 0); template -DISKANN_DLLEXPORT void load_bin(AlignedFileReader &reader, - std::unique_ptr &data, size_t &npts, - size_t &ndim, size_t offset = 0); +DISKANN_DLLEXPORT void load_bin(AlignedFileReader &reader, std::unique_ptr &data, size_t &npts, size_t &ndim, + size_t offset = 0); template -DISKANN_DLLEXPORT void -copy_aligned_data_from_file(AlignedFileReader &reader, T *&data, size_t &npts, - size_t &dim, const size_t &rounded_dim, - size_t offset = 0); +DISKANN_DLLEXPORT void copy_aligned_data_from_file(AlignedFileReader &reader, T *&data, size_t &npts, size_t &dim, + const size_t &rounded_dim, size_t offset = 0); // Unlike load_bin, assumes that data is already allocated 'size' entries template -DISKANN_DLLEXPORT void read_array(AlignedFileReader &reader, T *data, - size_t size, size_t offset = 0); +DISKANN_DLLEXPORT void read_array(AlignedFileReader &reader, T *data, size_t size, size_t offset = 0); -template -DISKANN_DLLEXPORT void read_value(AlignedFileReader &reader, T &value, - size_t offset = 0); +template DISKANN_DLLEXPORT void read_value(AlignedFileReader &reader, T &value, size_t offset = 0); #endif template -inline void load_bin(const std::string &bin_file, T *&data, size_t &npts, - size_t &dim, size_t offset = 0) { - diskann::cout << "Reading bin file " << bin_file.c_str() << " ..." - << std::endl; - std::ifstream reader; - reader.exceptions(std::ifstream::failbit | std::ifstream::badbit); - - try { - diskann::cout << "Opening bin file " << bin_file.c_str() << "... 
" - << std::endl; - reader.open(bin_file, std::ios::binary | std::ios::ate); - reader.seekg(0); - load_bin_impl(reader, data, npts, dim, offset); - } catch (std::system_error &e) { - throw FileException(bin_file, e, __FUNCSIG__, __FILE__, __LINE__); - } - diskann::cout << "done." << std::endl; +inline void load_bin(const std::string &bin_file, T *&data, size_t &npts, size_t &dim, size_t offset = 0) +{ + diskann::cout << "Reading bin file " << bin_file.c_str() << " ..." << std::endl; + std::ifstream reader; + reader.exceptions(std::ifstream::failbit | std::ifstream::badbit); + + try + { + diskann::cout << "Opening bin file " << bin_file.c_str() << "... " << std::endl; + reader.open(bin_file, std::ios::binary | std::ios::ate); + reader.seekg(0); + load_bin_impl(reader, data, npts, dim, offset); + } + catch (std::system_error &e) + { + throw FileException(bin_file, e, __FUNCSIG__, __FILE__, __LINE__); + } + diskann::cout << "done." << std::endl; } -inline void wait_for_keystroke() { - int a; - std::cout << "Press any number to continue.." << std::endl; - std::cin >> a; +inline void wait_for_keystroke() +{ + int a; + std::cout << "Press any number to continue.." << std::endl; + std::cin >> a; } // load_bin functions END -inline void load_truthset(const std::string &bin_file, uint32_t *&ids, - float *&dists, size_t &npts, size_t &dim) { - size_t read_blk_size = 64 * 1024 * 1024; - cached_ifstream reader(bin_file, read_blk_size); - diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." - << std::endl; - size_t actual_file_size = reader.get_file_size(); - - int npts_i32, dim_i32; - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&dim_i32, sizeof(int)); - npts = (unsigned)npts_i32; - dim = (unsigned)dim_i32; - - diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " - << std::endl; +inline void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, size_t &npts, size_t &dim) +{ + size_t read_blk_size = 64 * 1024 * 1024; + cached_ifstream reader(bin_file, read_blk_size); + diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." << std::endl; + size_t actual_file_size = reader.get_file_size(); + + int npts_i32, dim_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&dim_i32, sizeof(int)); + npts = (unsigned)npts_i32; + dim = (unsigned)dim_i32; + + diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " << std::endl; + + int truthset_type = -1; // 1 means truthset has ids and distances, 2 means + // only ids, -1 is error + size_t expected_file_size_with_dists = 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); + + if (actual_file_size == expected_file_size_with_dists) + truthset_type = 1; + + size_t expected_file_size_just_ids = npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); + + if (actual_file_size == expected_file_size_just_ids) + truthset_type = 2; + + if (truthset_type == -1) + { + std::stringstream stream; + stream << "Error. File size mismatch. 
File should have bin format, with " + "npts followed by ngt followed by npts*ngt ids and optionally " + "followed by npts*ngt distance values; actual size: " + << actual_file_size << ", expected: " << expected_file_size_with_dists << " or " + << expected_file_size_just_ids; + diskann::cout << stream.str(); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } - int truthset_type = -1; // 1 means truthset has ids and distances, 2 means - // only ids, -1 is error - size_t expected_file_size_with_dists = - 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); + ids = new uint32_t[npts * dim]; + reader.read((char *)ids, npts * dim * sizeof(uint32_t)); - if (actual_file_size == expected_file_size_with_dists) - truthset_type = 1; + if (truthset_type == 1) + { + dists = new float[npts * dim]; + reader.read((char *)dists, npts * dim * sizeof(float)); + } +} - size_t expected_file_size_just_ids = - npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); +inline void prune_truthset_for_range(const std::string &bin_file, float range, + std::vector> &groundtruth, size_t &npts) +{ + size_t read_blk_size = 64 * 1024 * 1024; + cached_ifstream reader(bin_file, read_blk_size); + diskann::cout << "Reading truthset file " << bin_file.c_str() << "... " << std::endl; + size_t actual_file_size = reader.get_file_size(); + + int npts_i32, dim_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&dim_i32, sizeof(int)); + npts = (unsigned)npts_i32; + uint64_t dim = (unsigned)dim_i32; + uint32_t *ids; + float *dists; + + diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " << std::endl; + + int truthset_type = -1; // 1 means truthset has ids and distances, 2 means + // only ids, -1 is error + size_t expected_file_size_with_dists = 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); + + if (actual_file_size == expected_file_size_with_dists) + truthset_type = 1; + + if (truthset_type == -1) + { + std::stringstream stream; + stream << "Error. File size mismatch. File should have bin format, with " + "npts followed by ngt followed by npts*ngt ids and optionally " + "followed by npts*ngt distance values; actual size: " + << actual_file_size << ", expected: " << expected_file_size_with_dists; + diskann::cout << stream.str(); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } - if (actual_file_size == expected_file_size_just_ids) - truthset_type = 2; + ids = new uint32_t[npts * dim]; + reader.read((char *)ids, npts * dim * sizeof(uint32_t)); - if (truthset_type == -1) { - std::stringstream stream; - stream << "Error. File size mismatch. 
File should have bin format, with " - "npts followed by ngt followed by npts*ngt ids and optionally " - "followed by npts*ngt distance values; actual size: " - << actual_file_size - << ", expected: " << expected_file_size_with_dists << " or " - << expected_file_size_just_ids; - diskann::cout << stream.str(); - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - - ids = new uint32_t[npts * dim]; - reader.read((char *)ids, npts * dim * sizeof(uint32_t)); - - if (truthset_type == 1) { - dists = new float[npts * dim]; - reader.read((char *)dists, npts * dim * sizeof(float)); - } -} - -inline void -prune_truthset_for_range(const std::string &bin_file, float range, - std::vector> &groundtruth, - size_t &npts) { - size_t read_blk_size = 64 * 1024 * 1024; - cached_ifstream reader(bin_file, read_blk_size); - diskann::cout << "Reading truthset file " << bin_file.c_str() << "... " - << std::endl; - size_t actual_file_size = reader.get_file_size(); - - int npts_i32, dim_i32; - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&dim_i32, sizeof(int)); - npts = (unsigned)npts_i32; - uint64_t dim = (unsigned)dim_i32; - uint32_t *ids; - float *dists; - - diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " - << std::endl; - - int truthset_type = -1; // 1 means truthset has ids and distances, 2 means - // only ids, -1 is error - size_t expected_file_size_with_dists = - 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); - - if (actual_file_size == expected_file_size_with_dists) - truthset_type = 1; - - if (truthset_type == -1) { - std::stringstream stream; - stream << "Error. File size mismatch. File should have bin format, with " - "npts followed by ngt followed by npts*ngt ids and optionally " - "followed by npts*ngt distance values; actual size: " - << actual_file_size - << ", expected: " << expected_file_size_with_dists; - diskann::cout << stream.str(); - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - - ids = new uint32_t[npts * dim]; - reader.read((char *)ids, npts * dim * sizeof(uint32_t)); - - if (truthset_type == 1) { - dists = new float[npts * dim]; - reader.read((char *)dists, npts * dim * sizeof(float)); - } - float min_dist = std::numeric_limits::max(); - float max_dist = 0; - groundtruth.resize(npts); - for (uint32_t i = 0; i < npts; i++) { - groundtruth[i].clear(); - for (uint32_t j = 0; j < dim; j++) { - if (dists[i * dim + j] <= range) { - groundtruth[i].emplace_back(ids[i * dim + j]); - } - min_dist = min_dist > dists[i * dim + j] ? dists[i * dim + j] : min_dist; - max_dist = max_dist < dists[i * dim + j] ? dists[i * dim + j] : max_dist; + if (truthset_type == 1) + { + dists = new float[npts * dim]; + reader.read((char *)dists, npts * dim * sizeof(float)); } - // std::cout<::max(); + float max_dist = 0; + groundtruth.resize(npts); + for (uint32_t i = 0; i < npts; i++) + { + groundtruth[i].clear(); + for (uint32_t j = 0; j < dim; j++) + { + if (dists[i * dim + j] <= range) + { + groundtruth[i].emplace_back(ids[i * dim + j]); + } + min_dist = min_dist > dists[i * dim + j] ? dists[i * dim + j] : min_dist; + max_dist = max_dist < dists[i * dim + j] ? dists[i * dim + j] : max_dist; + } + // std::cout<> &groundtruth, - uint64_t >_num) { - size_t read_blk_size = 64 * 1024 * 1024; - cached_ifstream reader(bin_file, read_blk_size); - diskann::cout << "Reading truthset file " << bin_file.c_str() << "... 
" - << std::flush; - size_t actual_file_size = reader.get_file_size(); +inline void load_range_truthset(const std::string &bin_file, std::vector> &groundtruth, + uint64_t >_num) +{ + size_t read_blk_size = 64 * 1024 * 1024; + cached_ifstream reader(bin_file, read_blk_size); + diskann::cout << "Reading truthset file " << bin_file.c_str() << "... " << std::flush; + size_t actual_file_size = reader.get_file_size(); - int nptsuint32_t, totaluint32_t; - reader.read((char *)&nptsuint32_t, sizeof(int)); - reader.read((char *)&totaluint32_t, sizeof(int)); + int nptsuint32_t, totaluint32_t; + reader.read((char *)&nptsuint32_t, sizeof(int)); + reader.read((char *)&totaluint32_t, sizeof(int)); - gt_num = (uint64_t)nptsuint32_t; - uint64_t total_res = (uint64_t)totaluint32_t; + gt_num = (uint64_t)nptsuint32_t; + uint64_t total_res = (uint64_t)totaluint32_t; - diskann::cout << "Metadata: #pts = " << gt_num - << ", #total_results = " << total_res << "..." << std::endl; + diskann::cout << "Metadata: #pts = " << gt_num << ", #total_results = " << total_res << "..." << std::endl; - size_t expected_file_size = 2 * sizeof(uint32_t) + gt_num * sizeof(uint32_t) + - total_res * sizeof(uint32_t); + size_t expected_file_size = 2 * sizeof(uint32_t) + gt_num * sizeof(uint32_t) + total_res * sizeof(uint32_t); - if (actual_file_size != expected_file_size) { - std::stringstream stream; - stream << "Error. File size mismatch in range truthset. actual size: " - << actual_file_size << ", expected: " << expected_file_size; - diskann::cout << stream.str(); - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - groundtruth.clear(); - groundtruth.resize(gt_num); - std::vector gt_count(gt_num); - - reader.read((char *)gt_count.data(), sizeof(uint32_t) * gt_num); - - std::vector gt_stats(gt_count); - std::sort(gt_stats.begin(), gt_stats.end()); - - std::cout << "GT count percentiles:" << std::endl; - for (uint32_t p = 0; p < 100; p += 5) - std::cout << "percentile " << p << ": " - << gt_stats[static_cast(std::floor((p / 100.0) * gt_num))] - << std::endl; - std::cout << "percentile 100" - << ": " << gt_stats[gt_num - 1] << std::endl; - - for (uint32_t i = 0; i < gt_num; i++) { - groundtruth[i].clear(); - groundtruth[i].resize(gt_count[i]); - if (gt_count[i] != 0) - reader.read((char *)groundtruth[i].data(), - sizeof(uint32_t) * gt_count[i]); - } + if (actual_file_size != expected_file_size) + { + std::stringstream stream; + stream << "Error. File size mismatch in range truthset. 
actual size: " << actual_file_size + << ", expected: " << expected_file_size; + diskann::cout << stream.str(); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + groundtruth.clear(); + groundtruth.resize(gt_num); + std::vector gt_count(gt_num); + + reader.read((char *)gt_count.data(), sizeof(uint32_t) * gt_num); + + std::vector gt_stats(gt_count); + std::sort(gt_stats.begin(), gt_stats.end()); + + std::cout << "GT count percentiles:" << std::endl; + for (uint32_t p = 0; p < 100; p += 5) + std::cout << "percentile " << p << ": " << gt_stats[static_cast(std::floor((p / 100.0) * gt_num))] + << std::endl; + std::cout << "percentile 100" + << ": " << gt_stats[gt_num - 1] << std::endl; + + for (uint32_t i = 0; i < gt_num; i++) + { + groundtruth[i].clear(); + groundtruth[i].resize(gt_count[i]); + if (gt_count[i] != 0) + reader.read((char *)groundtruth[i].data(), sizeof(uint32_t) * gt_count[i]); + } } #ifdef EXEC_ENV_OLS template -inline void load_bin(MemoryMappedFiles &files, const std::string &bin_file, - std::unique_ptr &data, size_t &npts, size_t &dim, - size_t offset = 0) { - T *ptr; - load_bin(files, bin_file, ptr, npts, dim, offset); - data.reset(ptr); +inline void load_bin(MemoryMappedFiles &files, const std::string &bin_file, std::unique_ptr &data, size_t &npts, + size_t &dim, size_t offset = 0) +{ + T *ptr; + load_bin(files, bin_file, ptr, npts, dim, offset); + data.reset(ptr); } #endif -inline void copy_file(std::string in_file, std::string out_file) { - std::ifstream source(in_file, std::ios::binary); - std::ofstream dest(out_file, std::ios::binary); +inline void copy_file(std::string in_file, std::string out_file) +{ + std::ifstream source(in_file, std::ios::binary); + std::ofstream dest(out_file, std::ios::binary); - std::istreambuf_iterator begin_source(source); - std::istreambuf_iterator end_source; - std::ostreambuf_iterator begin_dest(dest); - std::copy(begin_source, end_source, begin_dest); + std::istreambuf_iterator begin_source(source); + std::istreambuf_iterator end_source; + std::ostreambuf_iterator begin_dest(dest); + std::copy(begin_source, end_source, begin_dest); - source.close(); - dest.close(); + source.close(); + dest.close(); } -DISKANN_DLLEXPORT double calculate_recall(unsigned num_queries, - unsigned *gold_std, float *gs_dist, - unsigned dim_gs, - unsigned *our_results, - unsigned dim_or, unsigned recall_at); +DISKANN_DLLEXPORT double calculate_recall(unsigned num_queries, unsigned *gold_std, float *gs_dist, unsigned dim_gs, + unsigned *our_results, unsigned dim_or, unsigned recall_at); -DISKANN_DLLEXPORT double -calculate_recall(unsigned num_queries, unsigned *gold_std, float *gs_dist, - unsigned dim_gs, unsigned *our_results, unsigned dim_or, - unsigned recall_at, - const tsl::robin_set &active_tags); +DISKANN_DLLEXPORT double calculate_recall(unsigned num_queries, unsigned *gold_std, float *gs_dist, unsigned dim_gs, + unsigned *our_results, unsigned dim_or, unsigned recall_at, + const tsl::robin_set &active_tags); -DISKANN_DLLEXPORT double -calculate_range_search_recall(unsigned num_queries, - std::vector> &groundtruth, - std::vector> &our_results); +DISKANN_DLLEXPORT double calculate_range_search_recall(unsigned num_queries, + std::vector> &groundtruth, + std::vector> &our_results); template -inline void load_bin(const std::string &bin_file, std::unique_ptr &data, - size_t &npts, size_t &dim, size_t offset = 0) { - T *ptr; - load_bin(bin_file, ptr, npts, dim, offset); - data.reset(ptr); -} - -inline void 
open_file_to_write(std::ofstream &writer, - const std::string &filename) { - writer.exceptions(std::ofstream::failbit | std::ofstream::badbit); - if (!file_exists(filename)) - writer.open(filename, std::ios::binary | std::ios::out); - else - writer.open(filename, std::ios::binary | std::ios::in | std::ios::out); - - if (writer.fail()) { - char buff[1024]; +inline void load_bin(const std::string &bin_file, std::unique_ptr &data, size_t &npts, size_t &dim, + size_t offset = 0) +{ + T *ptr; + load_bin(bin_file, ptr, npts, dim, offset); + data.reset(ptr); +} + +inline void open_file_to_write(std::ofstream &writer, const std::string &filename) +{ + writer.exceptions(std::ofstream::failbit | std::ofstream::badbit); + if (!file_exists(filename)) + writer.open(filename, std::ios::binary | std::ios::out); + else + writer.open(filename, std::ios::binary | std::ios::in | std::ios::out); + + if (writer.fail()) + { + char buff[1024]; #ifdef _WINDOWS - auto ret = std::to_string(strerror_s(buff, 1024, errno)); + auto ret = std::to_string(strerror_s(buff, 1024, errno)); #else - auto ret = std::string(strerror_r(errno, buff, 1024)); + auto ret = std::string(strerror_r(errno, buff, 1024)); #endif - std::string error_message = std::string("Failed to open file") + filename + - " for write because " + buff + ", ret=" + ret; - diskann::cerr << error_message << std::endl; - throw diskann::ANNException(error_message, -1); - } + std::string error_message = + std::string("Failed to open file") + filename + " for write because " + buff + ", ret=" + ret; + diskann::cerr << error_message << std::endl; + throw diskann::ANNException(error_message, -1); + } } template -inline size_t save_bin(const std::string &filename, T *data, size_t npts, - size_t ndims, size_t offset = 0) { - std::ofstream writer; - open_file_to_write(writer, filename); - - diskann::cout << "Writing bin: " << filename.c_str() << std::endl; - writer.seekp(offset, writer.beg); - int npts_i32 = (int)npts, ndims_i32 = (int)ndims; - size_t bytes_written = npts * ndims * sizeof(T) + 2 * sizeof(uint32_t); - writer.write((char *)&npts_i32, sizeof(int)); - writer.write((char *)&ndims_i32, sizeof(int)); - diskann::cout << "bin: #pts = " << npts << ", #dims = " << ndims - << ", size = " << bytes_written << "B" << std::endl; - - writer.write((char *)data, npts * ndims * sizeof(T)); - writer.close(); - diskann::cout << "Finished writing bin." << std::endl; - return bytes_written; -} - -inline void print_progress(double percentage) { - int val = (int)(percentage * 100); - int lpad = (int)(percentage * PBWIDTH); - int rpad = PBWIDTH - lpad; - printf("\r%3d%% [%.*s%*s]", val, lpad, PBSTR, rpad, ""); - fflush(stdout); +inline size_t save_bin(const std::string &filename, T *data, size_t npts, size_t ndims, size_t offset = 0) +{ + std::ofstream writer; + open_file_to_write(writer, filename); + + diskann::cout << "Writing bin: " << filename.c_str() << std::endl; + writer.seekp(offset, writer.beg); + int npts_i32 = (int)npts, ndims_i32 = (int)ndims; + size_t bytes_written = npts * ndims * sizeof(T) + 2 * sizeof(uint32_t); + writer.write((char *)&npts_i32, sizeof(int)); + writer.write((char *)&ndims_i32, sizeof(int)); + diskann::cout << "bin: #pts = " << npts << ", #dims = " << ndims << ", size = " << bytes_written << "B" + << std::endl; + + writer.write((char *)data, npts * ndims * sizeof(T)); + writer.close(); + diskann::cout << "Finished writing bin." 
<< std::endl; + return bytes_written; +} + +inline void print_progress(double percentage) +{ + int val = (int)(percentage * 100); + int lpad = (int)(percentage * PBWIDTH); + int rpad = PBWIDTH - lpad; + printf("\r%3d%% [%.*s%*s]", val, lpad, PBSTR, rpad, ""); + fflush(stdout); } // load_aligned_bin functions START template -inline void load_aligned_bin_impl(std::basic_istream &reader, - size_t actual_file_size, T *&data, - size_t &npts, size_t &dim, - size_t &rounded_dim) { - int npts_i32, dim_i32; - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&dim_i32, sizeof(int)); - npts = (unsigned)npts_i32; - dim = (unsigned)dim_i32; - - size_t expected_actual_file_size = - npts * dim * sizeof(T) + 2 * sizeof(uint32_t); - if (actual_file_size != expected_actual_file_size) { - std::stringstream stream; - stream << "Error. File size mismatch. Actual size is " << actual_file_size - << " while expected size is " << expected_actual_file_size - << " npts = " << npts << " dim = " << dim - << " size of = " << sizeof(T) << std::endl; - diskann::cout << stream.str() << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - rounded_dim = ROUND_UP(dim, 8); - diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim - << ", aligned_dim = " << rounded_dim << "... " << std::flush; - size_t allocSize = npts * rounded_dim * sizeof(T); - diskann::cout << "allocating aligned memory of " << allocSize << " bytes... " - << std::flush; - alloc_aligned(((void **)&data), allocSize, 8 * sizeof(T)); - diskann::cout << "done. Copying data to mem_aligned buffer..." << std::flush; - - for (size_t i = 0; i < npts; i++) { - reader.read((char *)(data + i * rounded_dim), dim * sizeof(T)); - memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T)); - } - diskann::cout << " done." << std::endl; +inline void load_aligned_bin_impl(std::basic_istream &reader, size_t actual_file_size, T *&data, size_t &npts, + size_t &dim, size_t &rounded_dim) +{ + int npts_i32, dim_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&dim_i32, sizeof(int)); + npts = (unsigned)npts_i32; + dim = (unsigned)dim_i32; + + size_t expected_actual_file_size = npts * dim * sizeof(T) + 2 * sizeof(uint32_t); + if (actual_file_size != expected_actual_file_size) + { + std::stringstream stream; + stream << "Error. File size mismatch. Actual size is " << actual_file_size << " while expected size is " + << expected_actual_file_size << " npts = " << npts << " dim = " << dim << " size of = " << sizeof(T) + << std::endl; + diskann::cout << stream.str() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + rounded_dim = ROUND_UP(dim, 8); + diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << ", aligned_dim = " << rounded_dim << "... " + << std::flush; + size_t allocSize = npts * rounded_dim * sizeof(T); + diskann::cout << "allocating aligned memory of " << allocSize << " bytes... " << std::flush; + alloc_aligned(((void **)&data), allocSize, 8 * sizeof(T)); + diskann::cout << "done. Copying data to mem_aligned buffer..." << std::flush; + + for (size_t i = 0; i < npts; i++) + { + reader.read((char *)(data + i * rounded_dim), dim * sizeof(T)); + memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T)); + } + diskann::cout << " done." 
<< std::endl; } #ifdef EXEC_ENV_OLS template -inline void load_aligned_bin(MemoryMappedFiles &files, - const std::string &bin_file, T *&data, - size_t &npts, size_t &dim, size_t &rounded_dim) { - try { - diskann::cout << "Opening bin file " << bin_file << " ..." << std::flush; - FileContent fc = files.getContent(bin_file); - ContentBuf buf((char *)fc._content, fc._size); - std::basic_istream reader(&buf); - - size_t actual_file_size = fc._size; - load_aligned_bin_impl(reader, actual_file_size, data, npts, dim, - rounded_dim); - } catch (std::system_error &e) { - throw FileException(bin_file, e, __FUNCSIG__, __FILE__, __LINE__); - } +inline void load_aligned_bin(MemoryMappedFiles &files, const std::string &bin_file, T *&data, size_t &npts, size_t &dim, + size_t &rounded_dim) +{ + try + { + diskann::cout << "Opening bin file " << bin_file << " ..." << std::flush; + FileContent fc = files.getContent(bin_file); + ContentBuf buf((char *)fc._content, fc._size); + std::basic_istream reader(&buf); + + size_t actual_file_size = fc._size; + load_aligned_bin_impl(reader, actual_file_size, data, npts, dim, rounded_dim); + } + catch (std::system_error &e) + { + throw FileException(bin_file, e, __FUNCSIG__, __FILE__, __LINE__); + } } #endif template -inline void load_aligned_bin(const std::string &bin_file, T *&data, - size_t &npts, size_t &dim, size_t &rounded_dim) { - std::ifstream reader; - reader.exceptions(std::ifstream::failbit | std::ifstream::badbit); - - try { - diskann::cout << "Reading (with alignment) bin file " << bin_file << " ..." - << std::flush; - reader.open(bin_file, std::ios::binary | std::ios::ate); - - uint64_t fsize = reader.tellg(); - reader.seekg(0); - load_aligned_bin_impl(reader, fsize, data, npts, dim, rounded_dim); - } catch (std::system_error &e) { - throw FileException(bin_file, e, __FUNCSIG__, __FILE__, __LINE__); - } +inline void load_aligned_bin(const std::string &bin_file, T *&data, size_t &npts, size_t &dim, size_t &rounded_dim) +{ + std::ifstream reader; + reader.exceptions(std::ifstream::failbit | std::ifstream::badbit); + + try + { + diskann::cout << "Reading (with alignment) bin file " << bin_file << " ..." 
<< std::flush; + reader.open(bin_file, std::ios::binary | std::ios::ate); + + uint64_t fsize = reader.tellg(); + reader.seekg(0); + load_aligned_bin_impl(reader, fsize, data, npts, dim, rounded_dim); + } + catch (std::system_error &e) + { + throw FileException(bin_file, e, __FUNCSIG__, __FILE__, __LINE__); + } } template -void convert_types(const InType *srcmat, OutType *destmat, size_t npts, - size_t dim) { +void convert_types(const InType *srcmat, OutType *destmat, size_t npts, size_t dim) +{ #pragma omp parallel for schedule(static, 65536) - for (int64_t i = 0; i < (int64_t)npts; i++) { - for (uint64_t j = 0; j < dim; j++) { - destmat[i * dim + j] = (OutType)srcmat[i * dim + j]; + for (int64_t i = 0; i < (int64_t)npts; i++) + { + for (uint64_t j = 0; j < dim; j++) + { + destmat[i * dim + j] = (OutType)srcmat[i * dim + j]; + } } - } } // this function will take in_file of n*d dimensions and save the output as a @@ -830,317 +847,344 @@ void convert_types(const InType *srcmat, OutType *destmat, size_t npts, // from MIPS to L2 search from "On Symmetric and Asymmetric LSHs for Inner // Product Search" by Neyshabur and Srebro -template -float prepare_base_for_inner_products(const std::string in_file, - const std::string out_file) { - std::cout << "Pre-processing base file by adding extra coordinate" - << std::endl; - std::ifstream in_reader(in_file.c_str(), std::ios::binary); - std::ofstream out_writer(out_file.c_str(), std::ios::binary); - uint64_t npts, in_dims, out_dims; - float max_norm = 0; - - uint32_t npts32, dims32; - in_reader.read((char *)&npts32, sizeof(uint32_t)); - in_reader.read((char *)&dims32, sizeof(uint32_t)); - - npts = npts32; - in_dims = dims32; - out_dims = in_dims + 1; - uint32_t outdims32 = (uint32_t)out_dims; - - out_writer.write((char *)&npts32, sizeof(uint32_t)); - out_writer.write((char *)&outdims32, sizeof(uint32_t)); - - size_t BLOCK_SIZE = 100000; - size_t block_size = npts <= BLOCK_SIZE ? npts : BLOCK_SIZE; - std::unique_ptr in_block_data = - std::make_unique(block_size * in_dims); - std::unique_ptr out_block_data = - std::make_unique(block_size * out_dims); - - std::memset(out_block_data.get(), 0, sizeof(float) * block_size * out_dims); - uint64_t num_blocks = DIV_ROUND_UP(npts, block_size); - - std::vector norms(npts, 0); - - for (uint64_t b = 0; b < num_blocks; b++) { - uint64_t start_id = b * block_size; - uint64_t end_id = (b + 1) * block_size < npts ? (b + 1) * block_size : npts; - uint64_t block_pts = end_id - start_id; - in_reader.read((char *)in_block_data.get(), - block_pts * in_dims * sizeof(T)); - for (uint64_t p = 0; p < block_pts; p++) { - for (uint64_t j = 0; j < in_dims; j++) { - norms[start_id + p] += - in_block_data[p * in_dims + j] * in_block_data[p * in_dims + j]; - } - max_norm = - max_norm > norms[start_id + p] ? 
max_norm : norms[start_id + p]; +template float prepare_base_for_inner_products(const std::string in_file, const std::string out_file) +{ + std::cout << "Pre-processing base file by adding extra coordinate" << std::endl; + std::ifstream in_reader(in_file.c_str(), std::ios::binary); + std::ofstream out_writer(out_file.c_str(), std::ios::binary); + uint64_t npts, in_dims, out_dims; + float max_norm = 0; + + uint32_t npts32, dims32; + in_reader.read((char *)&npts32, sizeof(uint32_t)); + in_reader.read((char *)&dims32, sizeof(uint32_t)); + + npts = npts32; + in_dims = dims32; + out_dims = in_dims + 1; + uint32_t outdims32 = (uint32_t)out_dims; + + out_writer.write((char *)&npts32, sizeof(uint32_t)); + out_writer.write((char *)&outdims32, sizeof(uint32_t)); + + size_t BLOCK_SIZE = 100000; + size_t block_size = npts <= BLOCK_SIZE ? npts : BLOCK_SIZE; + std::unique_ptr in_block_data = std::make_unique(block_size * in_dims); + std::unique_ptr out_block_data = std::make_unique(block_size * out_dims); + + std::memset(out_block_data.get(), 0, sizeof(float) * block_size * out_dims); + uint64_t num_blocks = DIV_ROUND_UP(npts, block_size); + + std::vector norms(npts, 0); + + for (uint64_t b = 0; b < num_blocks; b++) + { + uint64_t start_id = b * block_size; + uint64_t end_id = (b + 1) * block_size < npts ? (b + 1) * block_size : npts; + uint64_t block_pts = end_id - start_id; + in_reader.read((char *)in_block_data.get(), block_pts * in_dims * sizeof(T)); + for (uint64_t p = 0; p < block_pts; p++) + { + for (uint64_t j = 0; j < in_dims; j++) + { + norms[start_id + p] += in_block_data[p * in_dims + j] * in_block_data[p * in_dims + j]; + } + max_norm = max_norm > norms[start_id + p] ? max_norm : norms[start_id + p]; + } } - } - - max_norm = std::sqrt(max_norm); - - in_reader.seekg(2 * sizeof(uint32_t), std::ios::beg); - for (uint64_t b = 0; b < num_blocks; b++) { - uint64_t start_id = b * block_size; - uint64_t end_id = (b + 1) * block_size < npts ? (b + 1) * block_size : npts; - uint64_t block_pts = end_id - start_id; - in_reader.read((char *)in_block_data.get(), - block_pts * in_dims * sizeof(T)); - for (uint64_t p = 0; p < block_pts; p++) { - for (uint64_t j = 0; j < in_dims; j++) { - out_block_data[p * out_dims + j] = - in_block_data[p * in_dims + j] / max_norm; - } - float res = 1 - (norms[start_id + p] / (max_norm * max_norm)); - res = res <= 0 ? 0 : std::sqrt(res); - out_block_data[p * out_dims + out_dims - 1] = res; + + max_norm = std::sqrt(max_norm); + + in_reader.seekg(2 * sizeof(uint32_t), std::ios::beg); + for (uint64_t b = 0; b < num_blocks; b++) + { + uint64_t start_id = b * block_size; + uint64_t end_id = (b + 1) * block_size < npts ? (b + 1) * block_size : npts; + uint64_t block_pts = end_id - start_id; + in_reader.read((char *)in_block_data.get(), block_pts * in_dims * sizeof(T)); + for (uint64_t p = 0; p < block_pts; p++) + { + for (uint64_t j = 0; j < in_dims; j++) + { + out_block_data[p * out_dims + j] = in_block_data[p * in_dims + j] / max_norm; + } + float res = 1 - (norms[start_id + p] / (max_norm * max_norm)); + res = res <= 0 ? 
0 : std::sqrt(res); + out_block_data[p * out_dims + out_dims - 1] = res; + } + out_writer.write((char *)out_block_data.get(), block_pts * out_dims * sizeof(float)); } - out_writer.write((char *)out_block_data.get(), - block_pts * out_dims * sizeof(float)); - } - out_writer.close(); - return max_norm; + out_writer.close(); + return max_norm; } // plain saves data as npts X ndims array into filename -template -void save_Tvecs(const char *filename, T *data, size_t npts, size_t ndims) { - std::string fname(filename); +template void save_Tvecs(const char *filename, T *data, size_t npts, size_t ndims) +{ + std::string fname(filename); - // create cached ofstream with 64MB cache - cached_ofstream writer(fname, 64 * 1048576); + // create cached ofstream with 64MB cache + cached_ofstream writer(fname, 64 * 1048576); - unsigned dims_u32 = (unsigned)ndims; + unsigned dims_u32 = (unsigned)ndims; - // start writing - for (size_t i = 0; i < npts; i++) { - // write dims in u32 - writer.write((char *)&dims_u32, sizeof(unsigned)); + // start writing + for (size_t i = 0; i < npts; i++) + { + // write dims in u32 + writer.write((char *)&dims_u32, sizeof(unsigned)); - // get cur point in data - T *cur_pt = data + i * ndims; - writer.write((char *)cur_pt, ndims * sizeof(T)); - } + // get cur point in data + T *cur_pt = data + i * ndims; + writer.write((char *)cur_pt, ndims * sizeof(T)); + } } template -inline size_t save_data_in_base_dimensions(const std::string &filename, T *data, - size_t npts, size_t ndims, - size_t aligned_dim, - size_t offset = 0) { - std::ofstream writer; //(filename, std::ios::binary | std::ios::out); - open_file_to_write(writer, filename); - int npts_i32 = (int)npts, ndims_i32 = (int)ndims; - size_t bytes_written = 2 * sizeof(uint32_t) + npts * ndims * sizeof(T); - writer.seekp(offset, writer.beg); - writer.write((char *)&npts_i32, sizeof(int)); - writer.write((char *)&ndims_i32, sizeof(int)); - for (size_t i = 0; i < npts; i++) { - writer.write((char *)(data + i * aligned_dim), ndims * sizeof(T)); - } - writer.close(); - return bytes_written; +inline size_t save_data_in_base_dimensions(const std::string &filename, T *data, size_t npts, size_t ndims, + size_t aligned_dim, size_t offset = 0) +{ + std::ofstream writer; //(filename, std::ios::binary | std::ios::out); + open_file_to_write(writer, filename); + int npts_i32 = (int)npts, ndims_i32 = (int)ndims; + size_t bytes_written = 2 * sizeof(uint32_t) + npts * ndims * sizeof(T); + writer.seekp(offset, writer.beg); + writer.write((char *)&npts_i32, sizeof(int)); + writer.write((char *)&ndims_i32, sizeof(int)); + for (size_t i = 0; i < npts; i++) + { + writer.write((char *)(data + i * aligned_dim), ndims * sizeof(T)); + } + writer.close(); + return bytes_written; } template -inline void copy_aligned_data_from_file(const char *bin_file, T *&data, - size_t &npts, size_t &dim, - const size_t &rounded_dim, - size_t offset = 0) { - if (data == nullptr) { - diskann::cerr << "Memory was not allocated for " << data - << " before calling the load function. Exiting..." 
- << std::endl; - throw diskann::ANNException( - "Null pointer passed to copy_aligned_data_from_file function", -1, - __FUNCSIG__, __FILE__, __LINE__); - } - std::ifstream reader; - reader.exceptions(std::ios::badbit | std::ios::failbit); - reader.open(bin_file, std::ios::binary); - reader.seekg(offset, reader.beg); - - int npts_i32, dim_i32; - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&dim_i32, sizeof(int)); - npts = (unsigned)npts_i32; - dim = (unsigned)dim_i32; - - for (size_t i = 0; i < npts; i++) { - reader.read((char *)(data + i * rounded_dim), dim * sizeof(T)); - memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T)); - } +inline void copy_aligned_data_from_file(const char *bin_file, T *&data, size_t &npts, size_t &dim, + const size_t &rounded_dim, size_t offset = 0) +{ + if (data == nullptr) + { + diskann::cerr << "Memory was not allocated for " << data << " before calling the load function. Exiting..." + << std::endl; + throw diskann::ANNException("Null pointer passed to copy_aligned_data_from_file function", -1, __FUNCSIG__, + __FILE__, __LINE__); + } + std::ifstream reader; + reader.exceptions(std::ios::badbit | std::ios::failbit); + reader.open(bin_file, std::ios::binary); + reader.seekg(offset, reader.beg); + + int npts_i32, dim_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&dim_i32, sizeof(int)); + npts = (unsigned)npts_i32; + dim = (unsigned)dim_i32; + + for (size_t i = 0; i < npts; i++) + { + reader.read((char *)(data + i * rounded_dim), dim * sizeof(T)); + memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T)); + } } // NOTE :: good efficiency when total_vec_size is integral multiple of 64 -inline void prefetch_vector(const char *vec, size_t vecsize) { - size_t max_prefetch_size = (vecsize / 64) * 64; - for (size_t d = 0; d < max_prefetch_size; d += 64) - _mm_prefetch((const char *)vec + d, _MM_HINT_T0); +inline void prefetch_vector(const char *vec, size_t vecsize) +{ + size_t max_prefetch_size = (vecsize / 64) * 64; + for (size_t d = 0; d < max_prefetch_size; d += 64) + _mm_prefetch((const char *)vec + d, _MM_HINT_T0); } // NOTE :: good efficiency when total_vec_size is integral multiple of 64 -inline void prefetch_vector_l2(const char *vec, size_t vecsize) { - size_t max_prefetch_size = (vecsize / 64) * 64; - for (size_t d = 0; d < max_prefetch_size; d += 64) - _mm_prefetch((const char *)vec + d, _MM_HINT_T1); +inline void prefetch_vector_l2(const char *vec, size_t vecsize) +{ + size_t max_prefetch_size = (vecsize / 64) * 64; + for (size_t d = 0; d < max_prefetch_size; d += 64) + _mm_prefetch((const char *)vec + d, _MM_HINT_T1); } // NOTE: Implementation in utils.cpp. 
-void block_convert(std::ofstream &writr, std::ifstream &readr, float *read_buf, - uint64_t npts, uint64_t ndims); +void block_convert(std::ofstream &writr, std::ifstream &readr, float *read_buf, uint64_t npts, uint64_t ndims); -DISKANN_DLLEXPORT void normalize_data_file(const std::string &inFileName, - const std::string &outFileName); +DISKANN_DLLEXPORT void normalize_data_file(const std::string &inFileName, const std::string &outFileName); -inline std::string get_tag_string(std::uint64_t tag) { - return std::to_string(tag); +inline std::string get_tag_string(std::uint64_t tag) +{ + return std::to_string(tag); } -inline std::string get_tag_string(const tag_uint128 &tag) { - std::string str = - std::to_string(tag._data2) + "_" + std::to_string(tag._data1); - return str; +inline std::string get_tag_string(const tag_uint128 &tag) +{ + std::string str = std::to_string(tag._data2) + "_" + std::to_string(tag._data1); + return str; } }; // namespace diskann -struct PivotContainer { - PivotContainer() = default; +struct PivotContainer +{ + PivotContainer() = default; - PivotContainer(size_t pivo_id, float pivo_dist) - : piv_id{pivo_id}, piv_dist{pivo_dist} {} + PivotContainer(size_t pivo_id, float pivo_dist) : piv_id{pivo_id}, piv_dist{pivo_dist} + { + } - bool operator<(const PivotContainer &p) const { - return p.piv_dist < piv_dist; - } + bool operator<(const PivotContainer &p) const + { + return p.piv_dist < piv_dist; + } - bool operator>(const PivotContainer &p) const { - return p.piv_dist > piv_dist; - } + bool operator>(const PivotContainer &p) const + { + return p.piv_dist > piv_dist; + } - size_t piv_id; - float piv_dist; + size_t piv_id; + float piv_dist; }; -inline bool validate_index_file_size(std::ifstream &in) { - if (!in.is_open()) - throw diskann::ANNException( - "Index file size check called on unopened file stream", -1, __FUNCSIG__, - __FILE__, __LINE__); - in.seekg(0, in.end); - size_t actual_file_size = in.tellg(); - in.seekg(0, in.beg); - size_t expected_file_size; - in.read((char *)&expected_file_size, sizeof(uint64_t)); - in.seekg(0, in.beg); - if (actual_file_size != expected_file_size) { - diskann::cerr << "Index file size error. Expected size (metadata): " - << expected_file_size - << ", actual file size : " << actual_file_size << "." - << std::endl; - return false; - } - return true; +inline bool validate_index_file_size(std::ifstream &in) +{ + if (!in.is_open()) + throw diskann::ANNException("Index file size check called on unopened file stream", -1, __FUNCSIG__, __FILE__, + __LINE__); + in.seekg(0, in.end); + size_t actual_file_size = in.tellg(); + in.seekg(0, in.beg); + size_t expected_file_size; + in.read((char *)&expected_file_size, sizeof(uint64_t)); + in.seekg(0, in.beg); + if (actual_file_size != expected_file_size) + { + diskann::cerr << "Index file size error. Expected size (metadata): " << expected_file_size + << ", actual file size : " << actual_file_size << "." << std::endl; + return false; + } + return true; } -template inline float get_norm(T *arr, const size_t dim) { - float sum = 0.0f; - for (uint32_t i = 0; i < dim; i++) { - sum += arr[i] * arr[i]; - } - return sqrt(sum); +template inline float get_norm(T *arr, const size_t dim) +{ + float sum = 0.0f; + for (uint32_t i = 0; i < dim; i++) + { + sum += arr[i] * arr[i]; + } + return sqrt(sum); } // This function is valid only for float data type. 
-template inline void normalize(T *arr, const size_t dim) { - float norm = get_norm(arr, dim); - for (uint32_t i = 0; i < dim; i++) { - arr[i] = (T)(arr[i] / norm); - } -} - -inline std::vector -read_file_to_vector_of_strings(const std::string &filename, - bool unique = false) { - std::vector result; - std::set elementSet; - if (filename != "") { - std::ifstream file(filename); - if (file.fail()) { - throw diskann::ANNException( - std::string("Failed to open file ") + filename, -1); +template inline void normalize(T *arr, const size_t dim) +{ + float norm = get_norm(arr, dim); + for (uint32_t i = 0; i < dim; i++) + { + arr[i] = (T)(arr[i] / norm); } - std::string line; - while (std::getline(file, line)) { - if (line.empty()) { - break; - } - if (line.find(',') != std::string::npos) { - std::cerr << "Every query must have exactly one filter" << std::endl; - exit(-1); - } - if (!line.empty() && (line.back() == '\r' || line.back() == '\n')) { - line.erase(line.size() - 1); - } - if (!elementSet.count(line)) { - result.push_back(line); - } - if (unique) { - elementSet.insert(line); - } +} + +inline std::vector read_file_to_vector_of_strings(const std::string &filename, bool unique = false) +{ + std::vector result; + std::set elementSet; + if (filename != "") + { + std::ifstream file(filename); + if (file.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + filename, -1); + } + std::string line; + while (std::getline(file, line)) + { + if (line.empty()) + { + break; + } + if (line.find(',') != std::string::npos) + { + std::cerr << "Every query must have exactly one filter" << std::endl; + exit(-1); + } + if (!line.empty() && (line.back() == '\r' || line.back() == '\n')) + { + line.erase(line.size() - 1); + } + if (!elementSet.count(line)) + { + result.push_back(line); + } + if (unique) + { + elementSet.insert(line); + } + } + file.close(); } - file.close(); - } else { - throw diskann::ANNException( - std::string("Failed to open file. filename can not be blank"), -1); - } - return result; -} - -inline void clean_up_artifacts(tsl::robin_set paths_to_clean, - tsl::robin_set path_suffixes) { - try { - for (const auto &path : paths_to_clean) { - for (const auto &suffix : path_suffixes) { - std::string curr_path_to_clean(path + "_" + suffix); - if (std::remove(curr_path_to_clean.c_str()) != 0) - diskann::cout << "Warning: Unable to remove file :" - << curr_path_to_clean << std::endl; - } + else + { + throw diskann::ANNException(std::string("Failed to open file. 
filename can not be blank"), -1); + } + return result; +} + +inline void clean_up_artifacts(tsl::robin_set paths_to_clean, tsl::robin_set path_suffixes) +{ + try + { + for (const auto &path : paths_to_clean) + { + for (const auto &suffix : path_suffixes) + { + std::string curr_path_to_clean(path + "_" + suffix); + if (std::remove(curr_path_to_clean.c_str()) != 0) + diskann::cout << "Warning: Unable to remove file :" << curr_path_to_clean << std::endl; + } + } + diskann::cout << "Cleaned all artifacts" << std::endl; + } + catch (const std::exception &e) + { + diskann::cout << "Warning: Unable to clean all artifacts " << e.what() << std::endl; } - diskann::cout << "Cleaned all artifacts" << std::endl; - } catch (const std::exception &e) { - diskann::cout << "Warning: Unable to clean all artifacts " << e.what() - << std::endl; - } } template inline const char *diskann_type_to_name() = delete; -template <> inline const char *diskann_type_to_name() { return "float"; } -template <> inline const char *diskann_type_to_name() { - return "uint8"; +template <> inline const char *diskann_type_to_name() +{ + return "float"; +} +template <> inline const char *diskann_type_to_name() +{ + return "uint8"; } -template <> inline const char *diskann_type_to_name() { return "int8"; } -template <> inline const char *diskann_type_to_name() { - return "uint16"; +template <> inline const char *diskann_type_to_name() +{ + return "int8"; } -template <> inline const char *diskann_type_to_name() { - return "int16"; +template <> inline const char *diskann_type_to_name() +{ + return "uint16"; } -template <> inline const char *diskann_type_to_name() { - return "uint32"; +template <> inline const char *diskann_type_to_name() +{ + return "int16"; } -template <> inline const char *diskann_type_to_name() { - return "int32"; +template <> inline const char *diskann_type_to_name() +{ + return "uint32"; } -template <> inline const char *diskann_type_to_name() { - return "uint64"; +template <> inline const char *diskann_type_to_name() +{ + return "int32"; } -template <> inline const char *diskann_type_to_name() { - return "int64"; +template <> inline const char *diskann_type_to_name() +{ + return "uint64"; +} +template <> inline const char *diskann_type_to_name() +{ + return "int64"; } #ifdef _WINDOWS @@ -1150,53 +1194,57 @@ template <> inline const char *diskann_type_to_name() { extern bool AvxSupportedCPU; extern bool Avx2SupportedCPU; -inline size_t getMemoryUsage() { - PROCESS_MEMORY_COUNTERS_EX pmc; - GetProcessMemoryInfo(GetCurrentProcess(), (PROCESS_MEMORY_COUNTERS *)&pmc, - sizeof(pmc)); - return pmc.PrivateUsage; -} - -inline std::string getWindowsErrorMessage(DWORD lastError) { - char *errorText; - FormatMessageA( - // use system message tables to retrieve error text - FORMAT_MESSAGE_FROM_SYSTEM - // allocate buffer on local heap for error text - | FORMAT_MESSAGE_ALLOCATE_BUFFER - // Important! will fail otherwise, since we're not - // (and CANNOT) pass insertion parameters - | FORMAT_MESSAGE_IGNORE_INSERTS, - NULL, // unused with FORMAT_MESSAGE_FROM_SYSTEM - lastError, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - (LPSTR)&errorText, // output - 0, // minimum size for output buffer - NULL); // arguments - see note - - return errorText != nullptr ? 
std::string(errorText) : std::string(); -} - -inline void printProcessMemory(const char *message) { - PROCESS_MEMORY_COUNTERS counters; - HANDLE h = GetCurrentProcess(); - GetProcessMemoryInfo(h, &counters, sizeof(counters)); - diskann::cout << message << " [Peaking Working Set size: " - << counters.PeakWorkingSetSize * 1.0 / (1024.0 * 1024 * 1024) - << "GB Working set size: " - << counters.WorkingSetSize * 1.0 / (1024.0 * 1024 * 1024) - << "GB Private bytes " - << counters.PagefileUsage * 1.0 / (1024 * 1024 * 1024) << "GB]" - << std::endl; +inline size_t getMemoryUsage() +{ + PROCESS_MEMORY_COUNTERS_EX pmc; + GetProcessMemoryInfo(GetCurrentProcess(), (PROCESS_MEMORY_COUNTERS *)&pmc, sizeof(pmc)); + return pmc.PrivateUsage; +} + +inline std::string getWindowsErrorMessage(DWORD lastError) +{ + char *errorText; + FormatMessageA( + // use system message tables to retrieve error text + FORMAT_MESSAGE_FROM_SYSTEM + // allocate buffer on local heap for error text + | FORMAT_MESSAGE_ALLOCATE_BUFFER + // Important! will fail otherwise, since we're not + // (and CANNOT) pass insertion parameters + | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, // unused with FORMAT_MESSAGE_FROM_SYSTEM + lastError, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + (LPSTR)&errorText, // output + 0, // minimum size for output buffer + NULL); // arguments - see note + + return errorText != nullptr ? std::string(errorText) : std::string(); +} + +inline void printProcessMemory(const char *message) +{ + PROCESS_MEMORY_COUNTERS counters; + HANDLE h = GetCurrentProcess(); + GetProcessMemoryInfo(h, &counters, sizeof(counters)); + diskann::cout << message + << " [Peaking Working Set size: " << counters.PeakWorkingSetSize * 1.0 / (1024.0 * 1024 * 1024) + << "GB Working set size: " << counters.WorkingSetSize * 1.0 / (1024.0 * 1024 * 1024) + << "GB Private bytes " << counters.PagefileUsage * 1.0 / (1024 * 1024 * 1024) << "GB]" << std::endl; } #else // need to check and change this -inline bool avx2Supported() { return true; } -inline void printProcessMemory(const char *) {} +inline bool avx2Supported() +{ + return true; +} +inline void printProcessMemory(const char *) +{ +} -inline size_t -getMemoryUsage() { // for non-windows, we have not implemented this function - return 0; +inline size_t getMemoryUsage() +{ // for non-windows, we have not implemented this function + return 0; } #endif diff --git a/include/windows_aligned_file_reader.h b/include/windows_aligned_file_reader.h index 0a3ab137d..e3a898b9a 100644 --- a/include/windows_aligned_file_reader.h +++ b/include/windows_aligned_file_reader.h @@ -17,39 +17,41 @@ #include #include -class WindowsAlignedFileReader : public AlignedFileReader { -private: +class WindowsAlignedFileReader : public AlignedFileReader +{ + private: #ifdef UNICODE - std::wstring m_filename; + std::wstring m_filename; #else - std::string m_filename; + std::string m_filename; #endif -protected: - // virtual IOContext createContext(); - -public: - DISKANN_DLLEXPORT WindowsAlignedFileReader(){}; - DISKANN_DLLEXPORT virtual ~WindowsAlignedFileReader(){}; - - // Open & close ops - // Blocking calls - DISKANN_DLLEXPORT virtual void open(const std::string &fname) override; - DISKANN_DLLEXPORT virtual void close() override; - - DISKANN_DLLEXPORT virtual void register_thread() override; - DISKANN_DLLEXPORT virtual void deregister_thread() override { - // TODO: Needs implementation. - } - DISKANN_DLLEXPORT virtual void deregister_all_threads() override { - // TODO: Needs implementation. 
- } - DISKANN_DLLEXPORT virtual IOContext &get_ctx() override; - - // process batch of aligned requests in parallel - // NOTE :: blocking call for the calling thread, but can thread-safe - DISKANN_DLLEXPORT virtual void read(std::vector &read_reqs, - IOContext &ctx, bool async) override; + protected: + // virtual IOContext createContext(); + + public: + DISKANN_DLLEXPORT WindowsAlignedFileReader(){}; + DISKANN_DLLEXPORT virtual ~WindowsAlignedFileReader(){}; + + // Open & close ops + // Blocking calls + DISKANN_DLLEXPORT virtual void open(const std::string &fname) override; + DISKANN_DLLEXPORT virtual void close() override; + + DISKANN_DLLEXPORT virtual void register_thread() override; + DISKANN_DLLEXPORT virtual void deregister_thread() override + { + // TODO: Needs implementation. + } + DISKANN_DLLEXPORT virtual void deregister_all_threads() override + { + // TODO: Needs implementation. + } + DISKANN_DLLEXPORT virtual IOContext &get_ctx() override; + + // process batch of aligned requests in parallel + // NOTE :: blocking call for the calling thread, but can thread-safe + DISKANN_DLLEXPORT virtual void read(std::vector &read_reqs, IOContext &ctx, bool async) override; }; #endif // USE_BING_INFRA #endif //_WINDOWS diff --git a/include/windows_slim_lock.h b/include/windows_slim_lock.h index 67ac98d14..7fc09b8f9 100644 --- a/include/windows_slim_lock.h +++ b/include/windows_slim_lock.h @@ -7,7 +7,8 @@ #endif #include "Windows.h" -namespace diskann { +namespace diskann +{ // A thin C++ wrapper around Windows exclusive functionality of Windows // SlimReaderWriterLock. // @@ -18,42 +19,55 @@ namespace diskann { // // Full documentation can be found at. // https://msdn.microsoft.com/en-us/library/windows/desktop/aa904937(v=vs.85).aspx -class windows_exclusive_slim_lock { -public: - windows_exclusive_slim_lock() : _lock(SRWLOCK_INIT) {} +class windows_exclusive_slim_lock +{ + public: + windows_exclusive_slim_lock() : _lock(SRWLOCK_INIT) + { + } - // The lock is non-copyable. This also disables move constructor/operator=. - windows_exclusive_slim_lock(const windows_exclusive_slim_lock &) = delete; - windows_exclusive_slim_lock & - operator=(const windows_exclusive_slim_lock &) = delete; + // The lock is non-copyable. This also disables move constructor/operator=. + windows_exclusive_slim_lock(const windows_exclusive_slim_lock &) = delete; + windows_exclusive_slim_lock &operator=(const windows_exclusive_slim_lock &) = delete; - void lock() { return AcquireSRWLockExclusive(&_lock); } + void lock() + { + return AcquireSRWLockExclusive(&_lock); + } - bool try_lock() { return TryAcquireSRWLockExclusive(&_lock) != FALSE; } + bool try_lock() + { + return TryAcquireSRWLockExclusive(&_lock) != FALSE; + } - void unlock() { return ReleaseSRWLockExclusive(&_lock); } + void unlock() + { + return ReleaseSRWLockExclusive(&_lock); + } -private: - SRWLOCK _lock; + private: + SRWLOCK _lock; }; // An exclusive lock over a SlimReaderWriterLock. -class windows_exclusive_slim_lock_guard { -public: - windows_exclusive_slim_lock_guard(windows_exclusive_slim_lock &p_lock) - : _lock(p_lock) { - _lock.lock(); - } - - // The lock is non-copyable. This also disables move constructor/operator=. 
- windows_exclusive_slim_lock_guard(const windows_exclusive_slim_lock_guard &) = - delete; - windows_exclusive_slim_lock_guard & - operator=(const windows_exclusive_slim_lock_guard &) = delete; - - ~windows_exclusive_slim_lock_guard() { _lock.unlock(); } - -private: - windows_exclusive_slim_lock &_lock; +class windows_exclusive_slim_lock_guard +{ + public: + windows_exclusive_slim_lock_guard(windows_exclusive_slim_lock &p_lock) : _lock(p_lock) + { + _lock.lock(); + } + + // The lock is non-copyable. This also disables move constructor/operator=. + windows_exclusive_slim_lock_guard(const windows_exclusive_slim_lock_guard &) = delete; + windows_exclusive_slim_lock_guard &operator=(const windows_exclusive_slim_lock_guard &) = delete; + + ~windows_exclusive_slim_lock_guard() + { + _lock.unlock(); + } + + private: + windows_exclusive_slim_lock &_lock; }; } // namespace diskann diff --git a/src/abstract_data_store.cpp b/src/abstract_data_store.cpp index c40fe2bd3..79efaca45 100644 --- a/src/abstract_data_store.cpp +++ b/src/abstract_data_store.cpp @@ -4,31 +4,39 @@ #include "abstract_data_store.h" #include -namespace diskann { +namespace diskann +{ template -AbstractDataStore::AbstractDataStore(const location_t capacity, - const size_t dim) - : _capacity(capacity), _dim(dim) {} +AbstractDataStore::AbstractDataStore(const location_t capacity, const size_t dim) + : _capacity(capacity), _dim(dim) +{ +} -template -location_t AbstractDataStore::capacity() const { - return _capacity; +template location_t AbstractDataStore::capacity() const +{ + return _capacity; } -template size_t AbstractDataStore::get_dims() const { - return _dim; +template size_t AbstractDataStore::get_dims() const +{ + return _dim; } -template -location_t AbstractDataStore::resize(const location_t new_num_points) { - if (new_num_points > _capacity) { - return expand(new_num_points); - } else if (new_num_points < _capacity) { - return shrink(new_num_points); - } else { - return _capacity; - } +template location_t AbstractDataStore::resize(const location_t new_num_points) +{ + if (new_num_points > _capacity) + { + return expand(new_num_points); + } + else if (new_num_points < _capacity) + { + return shrink(new_num_points); + } + else + { + return _capacity; + } } template DISKANN_DLLEXPORT class AbstractDataStore; diff --git a/src/abstract_index.cpp b/src/abstract_index.cpp index fef9c368f..c3c257ba2 100644 --- a/src/abstract_index.cpp +++ b/src/abstract_index.cpp @@ -2,498 +2,333 @@ #include "common_includes.h" #include "windows_customizations.h" -namespace diskann { +namespace diskann +{ template -void AbstractIndex::build(const data_type *data, - const size_t num_points_to_load, - const std::vector &tags) { - auto any_data = std::any(data); - auto any_tags_vec = TagVector(tags); - this->_build(any_data, num_points_to_load, any_tags_vec); +void AbstractIndex::build(const data_type *data, const size_t num_points_to_load, const std::vector &tags) +{ + auto any_data = std::any(data); + auto any_tags_vec = TagVector(tags); + this->_build(any_data, num_points_to_load, any_tags_vec); } template -std::pair -AbstractIndex::search(const data_type *query, const size_t K, const uint32_t L, - IDType *indices, float *distances) { - auto any_indices = std::any(indices); - auto any_query = std::any(query); - return _search(any_query, K, L, any_indices, distances); +std::pair AbstractIndex::search(const data_type *query, const size_t K, const uint32_t L, + IDType *indices, float *distances) +{ + auto any_indices = std::any(indices); + auto 
any_query = std::any(query); + return _search(any_query, K, L, any_indices, distances); } template -size_t AbstractIndex::search_with_tags(const data_type *query, const uint64_t K, - const uint32_t L, tag_type *tags, - float *distances, - std::vector &res_vectors, - bool use_filters, - const std::string filter_label) { - auto any_query = std::any(query); - auto any_tags = std::any(tags); - auto any_res_vectors = DataVector(res_vectors); - return this->_search_with_tags(any_query, K, L, any_tags, distances, - any_res_vectors, use_filters, filter_label); +size_t AbstractIndex::search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, + float *distances, std::vector &res_vectors, bool use_filters, + const std::string filter_label) +{ + auto any_query = std::any(query); + auto any_tags = std::any(tags); + auto any_res_vectors = DataVector(res_vectors); + return this->_search_with_tags(any_query, K, L, any_tags, distances, any_res_vectors, use_filters, filter_label); } template -std::pair AbstractIndex::search_with_filters( - const DataType &query, const std::string &raw_label, const size_t K, - const uint32_t L, IndexType *indices, float *distances) { - auto any_indices = std::any(indices); - return _search_with_filters(query, raw_label, K, L, any_indices, distances); +std::pair AbstractIndex::search_with_filters(const DataType &query, const std::string &raw_label, + const size_t K, const uint32_t L, IndexType *indices, + float *distances) +{ + auto any_indices = std::any(indices); + return _search_with_filters(query, raw_label, K, L, any_indices, distances); } template -void AbstractIndex::search_with_optimized_layout(const data_type *query, - size_t K, size_t L, - uint32_t *indices) { - auto any_query = std::any(query); - this->_search_with_optimized_layout(any_query, K, L, indices); +void AbstractIndex::search_with_optimized_layout(const data_type *query, size_t K, size_t L, uint32_t *indices) +{ + auto any_query = std::any(query); + this->_search_with_optimized_layout(any_query, K, L, indices); } template -int AbstractIndex::insert_point(const data_type *point, const tag_type tag) { - auto any_point = std::any(point); - auto any_tag = std::any(tag); - return this->_insert_point(any_point, any_tag); +int AbstractIndex::insert_point(const data_type *point, const tag_type tag) +{ + auto any_point = std::any(point); + auto any_tag = std::any(tag); + return this->_insert_point(any_point, any_tag); } template -int AbstractIndex::insert_point(const data_type *point, const tag_type tag, - const std::vector &labels) { - auto any_point = std::any(point); - auto any_tag = std::any(tag); - auto any_labels = Labelvector(labels); - return this->_insert_point(any_point, any_tag, any_labels); +int AbstractIndex::insert_point(const data_type *point, const tag_type tag, const std::vector &labels) +{ + auto any_point = std::any(point); + auto any_tag = std::any(tag); + auto any_labels = Labelvector(labels); + return this->_insert_point(any_point, any_tag, any_labels); } -template -int AbstractIndex::lazy_delete(const tag_type &tag) { - auto any_tag = std::any(tag); - return this->_lazy_delete(any_tag); +template int AbstractIndex::lazy_delete(const tag_type &tag) +{ + auto any_tag = std::any(tag); + return this->_lazy_delete(any_tag); } template -void AbstractIndex::lazy_delete(const std::vector &tags, - std::vector &failed_tags) { - auto any_tags = TagVector(tags); - auto any_failed_tags = TagVector(failed_tags); - this->_lazy_delete(any_tags, any_failed_tags); +void 
AbstractIndex::lazy_delete(const std::vector &tags, std::vector &failed_tags) +{ + auto any_tags = TagVector(tags); + auto any_failed_tags = TagVector(failed_tags); + this->_lazy_delete(any_tags, any_failed_tags); } -template -void AbstractIndex::get_active_tags(tsl::robin_set &active_tags) { - auto any_active_tags = TagRobinSet(active_tags); - this->_get_active_tags(any_active_tags); +template void AbstractIndex::get_active_tags(tsl::robin_set &active_tags) +{ + auto any_active_tags = TagRobinSet(active_tags); + this->_get_active_tags(any_active_tags); } -template -void AbstractIndex::set_start_points_at_random(data_type radius, - uint32_t random_seed) { - auto any_radius = std::any(radius); - this->_set_start_points_at_random(any_radius, random_seed); +template void AbstractIndex::set_start_points_at_random(data_type radius, uint32_t random_seed) +{ + auto any_radius = std::any(radius); + this->_set_start_points_at_random(any_radius, random_seed); } -template -int AbstractIndex::get_vector_by_tag(tag_type &tag, data_type *vec) { - auto any_tag = std::any(tag); - auto any_data_ptr = std::any(vec); - return this->_get_vector_by_tag(any_tag, any_data_ptr); +template int AbstractIndex::get_vector_by_tag(tag_type &tag, data_type *vec) +{ + auto any_tag = std::any(tag); + auto any_data_ptr = std::any(vec); + return this->_get_vector_by_tag(any_tag, any_data_ptr); } -template -void AbstractIndex::set_universal_label(const label_type universal_label) { - auto any_label = std::any(universal_label); - this->_set_universal_label(any_label); +template void AbstractIndex::set_universal_label(const label_type universal_label) +{ + auto any_label = std::any(universal_label); + this->_set_universal_label(any_label); } // exports -template DISKANN_DLLEXPORT void -AbstractIndex::build(const float *data, - const size_t num_points_to_load, - const std::vector &tags); -template DISKANN_DLLEXPORT void -AbstractIndex::build(const int8_t *data, - const size_t num_points_to_load, - const std::vector &tags); -template DISKANN_DLLEXPORT void -AbstractIndex::build(const uint8_t *data, - const size_t num_points_to_load, - const std::vector &tags); -template DISKANN_DLLEXPORT void -AbstractIndex::build(const float *data, - const size_t num_points_to_load, - const std::vector &tags); -template DISKANN_DLLEXPORT void -AbstractIndex::build(const int8_t *data, - const size_t num_points_to_load, - const std::vector &tags); -template DISKANN_DLLEXPORT void -AbstractIndex::build(const uint8_t *data, - const size_t num_points_to_load, - const std::vector &tags); -template DISKANN_DLLEXPORT void -AbstractIndex::build(const float *data, - const size_t num_points_to_load, - const std::vector &tags); -template DISKANN_DLLEXPORT void -AbstractIndex::build(const int8_t *data, - const size_t num_points_to_load, - const std::vector &tags); -template DISKANN_DLLEXPORT void -AbstractIndex::build(const uint8_t *data, - const size_t num_points_to_load, - const std::vector &tags); -template DISKANN_DLLEXPORT void -AbstractIndex::build(const float *data, - const size_t num_points_to_load, - const std::vector &tags); -template DISKANN_DLLEXPORT void -AbstractIndex::build(const int8_t *data, - const size_t num_points_to_load, - const std::vector &tags); -template DISKANN_DLLEXPORT void -AbstractIndex::build(const uint8_t *data, - const size_t num_points_to_load, - const std::vector &tags); - -template DISKANN_DLLEXPORT std::pair -AbstractIndex::search(const float *query, const size_t K, - const uint32_t L, uint32_t *indices, - float 
*distances); -template DISKANN_DLLEXPORT std::pair -AbstractIndex::search(const uint8_t *query, const size_t K, - const uint32_t L, uint32_t *indices, - float *distances); -template DISKANN_DLLEXPORT std::pair -AbstractIndex::search(const int8_t *query, const size_t K, - const uint32_t L, uint32_t *indices, - float *distances); - -template DISKANN_DLLEXPORT std::pair -AbstractIndex::search(const float *query, const size_t K, - const uint32_t L, uint64_t *indices, - float *distances); -template DISKANN_DLLEXPORT std::pair -AbstractIndex::search(const uint8_t *query, const size_t K, - const uint32_t L, uint64_t *indices, - float *distances); -template DISKANN_DLLEXPORT std::pair -AbstractIndex::search(const int8_t *query, const size_t K, - const uint32_t L, uint64_t *indices, - float *distances); - -template DISKANN_DLLEXPORT std::pair -AbstractIndex::search_with_filters(const DataType &query, - const std::string &raw_label, - const size_t K, const uint32_t L, - uint32_t *indices, - float *distances); - -template DISKANN_DLLEXPORT std::pair -AbstractIndex::search_with_filters(const DataType &query, - const std::string &raw_label, - const size_t K, const uint32_t L, - uint64_t *indices, - float *distances); - -template DISKANN_DLLEXPORT size_t -AbstractIndex::search_with_tags( - const float *query, const uint64_t K, const uint32_t L, int32_t *tags, - float *distances, std::vector &res_vectors, bool use_filters, - const std::string filter_label); - -template DISKANN_DLLEXPORT size_t -AbstractIndex::search_with_tags( - const uint8_t *query, const uint64_t K, const uint32_t L, int32_t *tags, - float *distances, std::vector &res_vectors, bool use_filters, - const std::string filter_label); - -template DISKANN_DLLEXPORT size_t -AbstractIndex::search_with_tags( - const int8_t *query, const uint64_t K, const uint32_t L, int32_t *tags, - float *distances, std::vector &res_vectors, bool use_filters, - const std::string filter_label); - -template DISKANN_DLLEXPORT size_t -AbstractIndex::search_with_tags( - const float *query, const uint64_t K, const uint32_t L, uint32_t *tags, - float *distances, std::vector &res_vectors, bool use_filters, - const std::string filter_label); - -template DISKANN_DLLEXPORT size_t -AbstractIndex::search_with_tags( - const uint8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, - float *distances, std::vector &res_vectors, bool use_filters, - const std::string filter_label); - -template DISKANN_DLLEXPORT size_t -AbstractIndex::search_with_tags( - const int8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, - float *distances, std::vector &res_vectors, bool use_filters, - const std::string filter_label); - -template DISKANN_DLLEXPORT size_t -AbstractIndex::search_with_tags( - const float *query, const uint64_t K, const uint32_t L, int64_t *tags, - float *distances, std::vector &res_vectors, bool use_filters, - const std::string filter_label); - -template DISKANN_DLLEXPORT size_t -AbstractIndex::search_with_tags( - const uint8_t *query, const uint64_t K, const uint32_t L, int64_t *tags, - float *distances, std::vector &res_vectors, bool use_filters, - const std::string filter_label); - -template DISKANN_DLLEXPORT size_t -AbstractIndex::search_with_tags( - const int8_t *query, const uint64_t K, const uint32_t L, int64_t *tags, - float *distances, std::vector &res_vectors, bool use_filters, - const std::string filter_label); - -template DISKANN_DLLEXPORT size_t -AbstractIndex::search_with_tags( - const float *query, const uint64_t K, const uint32_t L, 
uint64_t *tags, - float *distances, std::vector &res_vectors, bool use_filters, - const std::string filter_label); - -template DISKANN_DLLEXPORT size_t -AbstractIndex::search_with_tags( - const uint8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, - float *distances, std::vector &res_vectors, bool use_filters, - const std::string filter_label); - -template DISKANN_DLLEXPORT size_t -AbstractIndex::search_with_tags( - const int8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, - float *distances, std::vector &res_vectors, bool use_filters, - const std::string filter_label); - -template DISKANN_DLLEXPORT void -AbstractIndex::search_with_optimized_layout(const float *query, size_t K, - size_t L, uint32_t *indices); -template DISKANN_DLLEXPORT void -AbstractIndex::search_with_optimized_layout(const uint8_t *query, - size_t K, size_t L, - uint32_t *indices); -template DISKANN_DLLEXPORT void -AbstractIndex::search_with_optimized_layout(const int8_t *query, - size_t K, size_t L, - uint32_t *indices); - -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point(const float *point, - const int32_t tag); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point(const uint8_t *point, - const int32_t tag); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point(const int8_t *point, - const int32_t tag); - -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point(const float *point, - const uint32_t tag); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point(const uint8_t *point, - const uint32_t tag); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point(const int8_t *point, - const uint32_t tag); - -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point(const float *point, - const int64_t tag); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point(const uint8_t *point, - const int64_t tag); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point(const int8_t *point, - const int64_t tag); - -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point(const float *point, - const uint64_t tag); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point(const uint8_t *point, - const uint64_t tag); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point(const int8_t *point, - const uint64_t tag); - -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( +template DISKANN_DLLEXPORT void AbstractIndex::build(const float *data, const size_t num_points_to_load, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const int8_t *data, + const size_t num_points_to_load, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const uint8_t *data, + const size_t num_points_to_load, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const float *data, + const size_t num_points_to_load, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const int8_t *data, + const size_t num_points_to_load, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const uint8_t *data, + const size_t num_points_to_load, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const float *data, const size_t num_points_to_load, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const int8_t *data, + const size_t num_points_to_load, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const uint8_t *data, + const size_t 
num_points_to_load, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const float *data, + const size_t num_points_to_load, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const int8_t *data, + const size_t num_points_to_load, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const uint8_t *data, + const size_t num_points_to_load, + const std::vector &tags); + +template DISKANN_DLLEXPORT std::pair AbstractIndex::search( + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::search( + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::search( + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + +template DISKANN_DLLEXPORT std::pair AbstractIndex::search( + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::search( + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::search( + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + +template DISKANN_DLLEXPORT std::pair AbstractIndex::search_with_filters( + const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint32_t *indices, + float *distances); + +template DISKANN_DLLEXPORT std::pair AbstractIndex::search_with_filters( + const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint64_t *indices, + float *distances); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const float *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const uint8_t *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const int8_t *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const float *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const uint8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const int8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const float *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); + +template DISKANN_DLLEXPORT size_t 
AbstractIndex::search_with_tags( + const uint8_t *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const int8_t *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const float *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const uint8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const int8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); + +template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout(const float *query, size_t K, + size_t L, uint32_t *indices); +template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout(const uint8_t *query, size_t K, + size_t L, uint32_t *indices); +template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout(const int8_t *query, size_t K, + size_t L, uint32_t *indices); + +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const float *point, const int32_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const uint8_t *point, const int32_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const int8_t *point, const int32_t tag); + +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const float *point, const uint32_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const uint8_t *point, const uint32_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const int8_t *point, const uint32_t tag); + +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const float *point, const int64_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const uint8_t *point, const int64_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const int8_t *point, const int64_t tag); + +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const float *point, const uint64_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const uint8_t *point, const uint64_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const int8_t *point, const uint64_t tag); + +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( const float *point, const int32_t tag, const std::vector &labels); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( - const uint8_t *point, const int32_t tag, - const std::vector &labels); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( - const int8_t *point, const int32_t tag, - const std::vector &labels); - -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( - const float *point, const uint32_t tag, - const std::vector &labels); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( - const uint8_t *point, const uint32_t tag, - const std::vector &labels); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( - 
const int8_t *point, const uint32_t tag, - const std::vector &labels); - -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( + const uint8_t *point, const int32_t tag, const std::vector &labels); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( + const int8_t *point, const int32_t tag, const std::vector &labels); + +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( + const float *point, const uint32_t tag, const std::vector &labels); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( + const uint8_t *point, const uint32_t tag, const std::vector &labels); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( + const int8_t *point, const uint32_t tag, const std::vector &labels); + +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( const float *point, const int64_t tag, const std::vector &labels); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( - const uint8_t *point, const int64_t tag, - const std::vector &labels); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( - const int8_t *point, const int64_t tag, - const std::vector &labels); - -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( - const float *point, const uint64_t tag, - const std::vector &labels); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( - const uint8_t *point, const uint64_t tag, - const std::vector &labels); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( - const int8_t *point, const uint64_t tag, - const std::vector &labels); - -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( + const uint8_t *point, const int64_t tag, const std::vector &labels); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( + const int8_t *point, const int64_t tag, const std::vector &labels); + +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( + const float *point, const uint64_t tag, const std::vector &labels); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( + const uint8_t *point, const uint64_t tag, const std::vector &labels); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( + const int8_t *point, const uint64_t tag, const std::vector &labels); + +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( const float *point, const int32_t tag, const std::vector &labels); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( - const uint8_t *point, const int32_t tag, - const std::vector &labels); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( - const int8_t *point, const int32_t tag, - const std::vector &labels); - -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( - const float *point, const uint32_t tag, - const std::vector &labels); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( - const uint8_t *point, const uint32_t tag, - const std::vector &labels); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( - const int8_t *point, const uint32_t tag, - const std::vector &labels); - -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( + const uint8_t *point, const int32_t tag, const std::vector &labels); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( + const int8_t *point, const int32_t tag, const std::vector &labels); + +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( + const float *point, const 
uint32_t tag, const std::vector &labels); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( + const uint8_t *point, const uint32_t tag, const std::vector &labels); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( + const int8_t *point, const uint32_t tag, const std::vector &labels); + +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( const float *point, const int64_t tag, const std::vector &labels); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( - const uint8_t *point, const int64_t tag, - const std::vector &labels); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( - const int8_t *point, const int64_t tag, - const std::vector &labels); - -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( - const float *point, const uint64_t tag, - const std::vector &labels); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( - const uint8_t *point, const uint64_t tag, - const std::vector &labels); -template DISKANN_DLLEXPORT int -AbstractIndex::insert_point( - const int8_t *point, const uint64_t tag, - const std::vector &labels); - -template DISKANN_DLLEXPORT int -AbstractIndex::lazy_delete(const int32_t &tag); -template DISKANN_DLLEXPORT int -AbstractIndex::lazy_delete(const uint32_t &tag); -template DISKANN_DLLEXPORT int -AbstractIndex::lazy_delete(const int64_t &tag); -template DISKANN_DLLEXPORT int -AbstractIndex::lazy_delete(const uint64_t &tag); - -template DISKANN_DLLEXPORT void -AbstractIndex::lazy_delete(const std::vector &tags, - std::vector &failed_tags); -template DISKANN_DLLEXPORT void -AbstractIndex::lazy_delete(const std::vector &tags, - std::vector &failed_tags); -template DISKANN_DLLEXPORT void -AbstractIndex::lazy_delete(const std::vector &tags, - std::vector &failed_tags); -template DISKANN_DLLEXPORT void -AbstractIndex::lazy_delete(const std::vector &tags, - std::vector &failed_tags); - -template DISKANN_DLLEXPORT void -AbstractIndex::get_active_tags(tsl::robin_set &active_tags); -template DISKANN_DLLEXPORT void -AbstractIndex::get_active_tags(tsl::robin_set &active_tags); -template DISKANN_DLLEXPORT void -AbstractIndex::get_active_tags(tsl::robin_set &active_tags); -template DISKANN_DLLEXPORT void -AbstractIndex::get_active_tags(tsl::robin_set &active_tags); - -template DISKANN_DLLEXPORT void -AbstractIndex::set_start_points_at_random(float radius, - uint32_t random_seed); -template DISKANN_DLLEXPORT void -AbstractIndex::set_start_points_at_random(uint8_t radius, - uint32_t random_seed); -template DISKANN_DLLEXPORT void -AbstractIndex::set_start_points_at_random(int8_t radius, - uint32_t random_seed); - -template DISKANN_DLLEXPORT int -AbstractIndex::get_vector_by_tag(int32_t &tag, float *vec); -template DISKANN_DLLEXPORT int -AbstractIndex::get_vector_by_tag(int32_t &tag, uint8_t *vec); -template DISKANN_DLLEXPORT int -AbstractIndex::get_vector_by_tag(int32_t &tag, int8_t *vec); -template DISKANN_DLLEXPORT int -AbstractIndex::get_vector_by_tag(uint32_t &tag, float *vec); -template DISKANN_DLLEXPORT int -AbstractIndex::get_vector_by_tag(uint32_t &tag, - uint8_t *vec); -template DISKANN_DLLEXPORT int -AbstractIndex::get_vector_by_tag(uint32_t &tag, int8_t *vec); - -template DISKANN_DLLEXPORT int -AbstractIndex::get_vector_by_tag(int64_t &tag, float *vec); -template DISKANN_DLLEXPORT int -AbstractIndex::get_vector_by_tag(int64_t &tag, uint8_t *vec); -template DISKANN_DLLEXPORT int -AbstractIndex::get_vector_by_tag(int64_t &tag, int8_t *vec); -template DISKANN_DLLEXPORT int 
-AbstractIndex::get_vector_by_tag(uint64_t &tag, float *vec); -template DISKANN_DLLEXPORT int -AbstractIndex::get_vector_by_tag(uint64_t &tag, - uint8_t *vec); -template DISKANN_DLLEXPORT int -AbstractIndex::get_vector_by_tag(uint64_t &tag, int8_t *vec); - -template DISKANN_DLLEXPORT void -AbstractIndex::set_universal_label(const uint16_t label); -template DISKANN_DLLEXPORT void -AbstractIndex::set_universal_label(const uint32_t label); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( + const uint8_t *point, const int64_t tag, const std::vector &labels); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( + const int8_t *point, const int64_t tag, const std::vector &labels); + +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( + const float *point, const uint64_t tag, const std::vector &labels); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( + const uint8_t *point, const uint64_t tag, const std::vector &labels); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point( + const int8_t *point, const uint64_t tag, const std::vector &labels); + +template DISKANN_DLLEXPORT int AbstractIndex::lazy_delete(const int32_t &tag); +template DISKANN_DLLEXPORT int AbstractIndex::lazy_delete(const uint32_t &tag); +template DISKANN_DLLEXPORT int AbstractIndex::lazy_delete(const int64_t &tag); +template DISKANN_DLLEXPORT int AbstractIndex::lazy_delete(const uint64_t &tag); + +template DISKANN_DLLEXPORT void AbstractIndex::lazy_delete(const std::vector &tags, + std::vector &failed_tags); +template DISKANN_DLLEXPORT void AbstractIndex::lazy_delete(const std::vector &tags, + std::vector &failed_tags); +template DISKANN_DLLEXPORT void AbstractIndex::lazy_delete(const std::vector &tags, + std::vector &failed_tags); +template DISKANN_DLLEXPORT void AbstractIndex::lazy_delete(const std::vector &tags, + std::vector &failed_tags); + +template DISKANN_DLLEXPORT void AbstractIndex::get_active_tags(tsl::robin_set &active_tags); +template DISKANN_DLLEXPORT void AbstractIndex::get_active_tags(tsl::robin_set &active_tags); +template DISKANN_DLLEXPORT void AbstractIndex::get_active_tags(tsl::robin_set &active_tags); +template DISKANN_DLLEXPORT void AbstractIndex::get_active_tags(tsl::robin_set &active_tags); + +template DISKANN_DLLEXPORT void AbstractIndex::set_start_points_at_random(float radius, uint32_t random_seed); +template DISKANN_DLLEXPORT void AbstractIndex::set_start_points_at_random(uint8_t radius, + uint32_t random_seed); +template DISKANN_DLLEXPORT void AbstractIndex::set_start_points_at_random(int8_t radius, uint32_t random_seed); + +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(int32_t &tag, float *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(int32_t &tag, uint8_t *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(int32_t &tag, int8_t *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(uint32_t &tag, float *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(uint32_t &tag, uint8_t *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(uint32_t &tag, int8_t *vec); + +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(int64_t &tag, float *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(int64_t &tag, uint8_t *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(int64_t &tag, int8_t *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(uint64_t &tag, float *vec); +template 
DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(uint64_t &tag, uint8_t *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(uint64_t &tag, int8_t *vec); + +template DISKANN_DLLEXPORT void AbstractIndex::set_universal_label(const uint16_t label); +template DISKANN_DLLEXPORT void AbstractIndex::set_universal_label(const uint32_t label); } // namespace diskann diff --git a/src/ann_exception.cpp b/src/ann_exception.cpp index 1039527eb..ba55e3655 100644 --- a/src/ann_exception.cpp +++ b/src/ann_exception.cpp @@ -5,33 +5,32 @@ #include #include -namespace diskann { +namespace diskann +{ ANNException::ANNException(const std::string &message, int errorCode) - : std::runtime_error(message), _errorCode(errorCode) {} + : std::runtime_error(message), _errorCode(errorCode) +{ +} -std::string package_string(const std::string &item_name, - const std::string &item_val) { - return std::string("[") + item_name + ": " + std::string(item_val) + - std::string("]"); +std::string package_string(const std::string &item_name, const std::string &item_val) +{ + return std::string("[") + item_name + ": " + std::string(item_val) + std::string("]"); } -ANNException::ANNException(const std::string &message, int errorCode, - const std::string &funcSig, +ANNException::ANNException(const std::string &message, int errorCode, const std::string &funcSig, const std::string &fileName, uint32_t lineNum) - : ANNException( - package_string(std::string("FUNC"), funcSig) + - package_string(std::string("FILE"), fileName) + - package_string(std::string("LINE"), std::to_string(lineNum)) + - " " + message, - errorCode) {} + : ANNException(package_string(std::string("FUNC"), funcSig) + package_string(std::string("FILE"), fileName) + + package_string(std::string("LINE"), std::to_string(lineNum)) + " " + message, + errorCode) +{ +} -FileException::FileException(const std::string &filename, std::system_error &e, - const std::string &funcSig, +FileException::FileException(const std::string &filename, std::system_error &e, const std::string &funcSig, const std::string &fileName, uint32_t lineNum) - : ANNException(std::string(" While opening file \'") + filename + - std::string("\', error code: ") + - std::to_string(e.code().value()) + " " + - e.code().message(), - e.code().value(), funcSig, fileName, lineNum) {} + : ANNException(std::string(" While opening file \'") + filename + std::string("\', error code: ") + + std::to_string(e.code().value()) + " " + e.code().message(), + e.code().value(), funcSig, fileName, lineNum) +{ +} } // namespace diskann \ No newline at end of file diff --git a/src/disk_utils.cpp b/src/disk_utils.cpp index b9380b3cc..9b29cb542 100644 --- a/src/disk_utils.cpp +++ b/src/disk_utils.cpp @@ -3,8 +3,7 @@ #include "common_includes.h" -#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && \ - defined(DISKANN_BUILD) +#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD) #include "gperftools/malloc_extension.h" #endif @@ -20,459 +19,464 @@ #include "timer.h" #include "tsl/robin_set.h" -namespace diskann { - -void add_new_file_to_single_index(std::string index_file, - std::string new_file) { - std::unique_ptr metadata; - uint64_t nr, nc; - diskann::load_bin(index_file, metadata, nr, nc); - if (nc != 1) { - std::stringstream stream; - stream << "Error, index file specified does not have correct metadata. 
" - << std::endl; - throw diskann::ANNException(stream.str(), -1); - } - size_t index_ending_offset = metadata[nr - 1]; - size_t read_blk_size = 64 * 1024 * 1024; - cached_ofstream writer(index_file, read_blk_size); - size_t check_file_size = get_file_size(index_file); - if (check_file_size != index_ending_offset) { - std::stringstream stream; - stream << "Error, index file specified does not have correct metadata " - "(last entry must match the filesize). " - << std::endl; - throw diskann::ANNException(stream.str(), -1); - } - - cached_ifstream reader(new_file, read_blk_size); - size_t fsize = reader.get_file_size(); - if (fsize == 0) { - std::stringstream stream; - stream << "Error, new file specified is empty. Not appending. " - << std::endl; - throw diskann::ANNException(stream.str(), -1); - } - - size_t num_blocks = DIV_ROUND_UP(fsize, read_blk_size); - char *dump = new char[read_blk_size]; - for (uint64_t i = 0; i < num_blocks; i++) { - size_t cur_block_size = read_blk_size > fsize - (i * read_blk_size) - ? fsize - (i * read_blk_size) - : read_blk_size; - reader.read(dump, cur_block_size); - writer.write(dump, cur_block_size); - } - // reader.close(); - // writer.close(); - - delete[] dump; - std::vector new_meta; - for (uint64_t i = 0; i < nr; i++) - new_meta.push_back(metadata[i]); - new_meta.push_back(metadata[nr - 1] + fsize); - - diskann::save_bin(index_file, new_meta.data(), new_meta.size(), 1); +namespace diskann +{ + +void add_new_file_to_single_index(std::string index_file, std::string new_file) +{ + std::unique_ptr metadata; + uint64_t nr, nc; + diskann::load_bin(index_file, metadata, nr, nc); + if (nc != 1) + { + std::stringstream stream; + stream << "Error, index file specified does not have correct metadata. " << std::endl; + throw diskann::ANNException(stream.str(), -1); + } + size_t index_ending_offset = metadata[nr - 1]; + size_t read_blk_size = 64 * 1024 * 1024; + cached_ofstream writer(index_file, read_blk_size); + size_t check_file_size = get_file_size(index_file); + if (check_file_size != index_ending_offset) + { + std::stringstream stream; + stream << "Error, index file specified does not have correct metadata " + "(last entry must match the filesize). " + << std::endl; + throw diskann::ANNException(stream.str(), -1); + } + + cached_ifstream reader(new_file, read_blk_size); + size_t fsize = reader.get_file_size(); + if (fsize == 0) + { + std::stringstream stream; + stream << "Error, new file specified is empty. Not appending. " << std::endl; + throw diskann::ANNException(stream.str(), -1); + } + + size_t num_blocks = DIV_ROUND_UP(fsize, read_blk_size); + char *dump = new char[read_blk_size]; + for (uint64_t i = 0; i < num_blocks; i++) + { + size_t cur_block_size = + read_blk_size > fsize - (i * read_blk_size) ? 
fsize - (i * read_blk_size) : read_blk_size; + reader.read(dump, cur_block_size); + writer.write(dump, cur_block_size); + } + // reader.close(); + // writer.close(); + + delete[] dump; + std::vector new_meta; + for (uint64_t i = 0; i < nr; i++) + new_meta.push_back(metadata[i]); + new_meta.push_back(metadata[nr - 1] + fsize); + + diskann::save_bin(index_file, new_meta.data(), new_meta.size(), 1); } -double get_memory_budget(double search_ram_budget) { - double final_index_ram_limit = search_ram_budget; - if (search_ram_budget - SPACE_FOR_CACHED_NODES_IN_GB > - THRESHOLD_FOR_CACHING_IN_GB) { // slack for space used by cached - // nodes - final_index_ram_limit = search_ram_budget - SPACE_FOR_CACHED_NODES_IN_GB; - } - return final_index_ram_limit * 1024 * 1024 * 1024; +double get_memory_budget(double search_ram_budget) +{ + double final_index_ram_limit = search_ram_budget; + if (search_ram_budget - SPACE_FOR_CACHED_NODES_IN_GB > THRESHOLD_FOR_CACHING_IN_GB) + { // slack for space used by cached + // nodes + final_index_ram_limit = search_ram_budget - SPACE_FOR_CACHED_NODES_IN_GB; + } + return final_index_ram_limit * 1024 * 1024 * 1024; } -double get_memory_budget(const std::string &mem_budget_str) { - double search_ram_budget = atof(mem_budget_str.c_str()); - return get_memory_budget(search_ram_budget); +double get_memory_budget(const std::string &mem_budget_str) +{ + double search_ram_budget = atof(mem_budget_str.c_str()); + return get_memory_budget(search_ram_budget); } -size_t calculate_num_pq_chunks(double final_index_ram_limit, size_t points_num, - uint32_t dim, - const std::vector ¶m_list) { - size_t num_pq_chunks = (size_t)(std::floor)( - uint64_t(final_index_ram_limit / (double)points_num)); - diskann::cout << "Calculated num_pq_chunks :" << num_pq_chunks << std::endl; - if (param_list.size() >= 6) { - float compress_ratio = (float)atof(param_list[5].c_str()); - if (compress_ratio > 0 && compress_ratio <= 1) { - size_t chunks_by_cr = (size_t)(std::floor)(compress_ratio * dim); - - if (chunks_by_cr > 0 && chunks_by_cr < num_pq_chunks) { - diskann::cout << "Compress ratio:" << compress_ratio - << " new #pq_chunks:" << chunks_by_cr << std::endl; - num_pq_chunks = chunks_by_cr; - } else { - diskann::cout << "Compress ratio: " << compress_ratio - << " #new pq_chunks: " << chunks_by_cr - << " is either zero or greater than num_pq_chunks: " - << num_pq_chunks << ". num_pq_chunks is unchanged. " - << std::endl; - } - } else { - diskann::cerr << "Compression ratio: " << compress_ratio - << " should be in (0,1]" << std::endl; +size_t calculate_num_pq_chunks(double final_index_ram_limit, size_t points_num, uint32_t dim, + const std::vector ¶m_list) +{ + size_t num_pq_chunks = (size_t)(std::floor)(uint64_t(final_index_ram_limit / (double)points_num)); + diskann::cout << "Calculated num_pq_chunks :" << num_pq_chunks << std::endl; + if (param_list.size() >= 6) + { + float compress_ratio = (float)atof(param_list[5].c_str()); + if (compress_ratio > 0 && compress_ratio <= 1) + { + size_t chunks_by_cr = (size_t)(std::floor)(compress_ratio * dim); + + if (chunks_by_cr > 0 && chunks_by_cr < num_pq_chunks) + { + diskann::cout << "Compress ratio:" << compress_ratio << " new #pq_chunks:" << chunks_by_cr << std::endl; + num_pq_chunks = chunks_by_cr; + } + else + { + diskann::cout << "Compress ratio: " << compress_ratio << " #new pq_chunks: " << chunks_by_cr + << " is either zero or greater than num_pq_chunks: " << num_pq_chunks + << ". num_pq_chunks is unchanged. 
" << std::endl; + } + } + else + { + diskann::cerr << "Compression ratio: " << compress_ratio << " should be in (0,1]" << std::endl; + } } - } - num_pq_chunks = num_pq_chunks <= 0 ? 1 : num_pq_chunks; - num_pq_chunks = num_pq_chunks > dim ? dim : num_pq_chunks; - num_pq_chunks = num_pq_chunks > MAX_PQ_CHUNKS ? MAX_PQ_CHUNKS : num_pq_chunks; + num_pq_chunks = num_pq_chunks <= 0 ? 1 : num_pq_chunks; + num_pq_chunks = num_pq_chunks > dim ? dim : num_pq_chunks; + num_pq_chunks = num_pq_chunks > MAX_PQ_CHUNKS ? MAX_PQ_CHUNKS : num_pq_chunks; - diskann::cout << "Compressing " << dim << "-dimensional data into " - << num_pq_chunks << " bytes per vector." << std::endl; - return num_pq_chunks; + diskann::cout << "Compressing " << dim << "-dimensional data into " << num_pq_chunks << " bytes per vector." + << std::endl; + return num_pq_chunks; } -template -T *generateRandomWarmup(uint64_t warmup_num, uint64_t warmup_dim, - uint64_t warmup_aligned_dim) { - T *warmup = nullptr; - warmup_num = 100000; - diskann::cout << "Generating random warmup file with dim " << warmup_dim - << " and aligned dim " << warmup_aligned_dim << std::flush; - diskann::alloc_aligned(((void **)&warmup), - warmup_num * warmup_aligned_dim * sizeof(T), - 8 * sizeof(T)); - std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T)); - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution<> dis(-128, 127); - for (uint32_t i = 0; i < warmup_num; i++) { - for (uint32_t d = 0; d < warmup_dim; d++) { - warmup[i * warmup_aligned_dim + d] = (T)dis(gen); +template T *generateRandomWarmup(uint64_t warmup_num, uint64_t warmup_dim, uint64_t warmup_aligned_dim) +{ + T *warmup = nullptr; + warmup_num = 100000; + diskann::cout << "Generating random warmup file with dim " << warmup_dim << " and aligned dim " + << warmup_aligned_dim << std::flush; + diskann::alloc_aligned(((void **)&warmup), warmup_num * warmup_aligned_dim * sizeof(T), 8 * sizeof(T)); + std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T)); + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(-128, 127); + for (uint32_t i = 0; i < warmup_num; i++) + { + for (uint32_t d = 0; d < warmup_dim; d++) + { + warmup[i * warmup_aligned_dim + d] = (T)dis(gen); + } } - } - diskann::cout << "..done" << std::endl; - return warmup; + diskann::cout << "..done" << std::endl; + return warmup; } #ifdef EXEC_ENV_OLS template -T *load_warmup(MemoryMappedFiles &files, const std::string &cache_warmup_file, - uint64_t &warmup_num, uint64_t warmup_dim, - uint64_t warmup_aligned_dim) { - T *warmup = nullptr; - uint64_t file_dim, file_aligned_dim; - - if (files.fileExists(cache_warmup_file)) { - diskann::load_aligned_bin(files, cache_warmup_file, warmup, warmup_num, - file_dim, file_aligned_dim); - diskann::cout << "In the warmup file: " << cache_warmup_file - << " File dim: " << file_dim - << " File aligned dim: " << file_aligned_dim - << " Expected dim: " << warmup_dim - << " Expected aligned dim: " << warmup_aligned_dim - << std::endl; - - if (file_dim != warmup_dim || file_aligned_dim != warmup_aligned_dim) { - std::stringstream stream; - stream << "Mismatched dimensions in sample file. 
file_dim = " << file_dim - << " file_aligned_dim: " << file_aligned_dim - << " index_dim: " << warmup_dim - << " index_aligned_dim: " << warmup_aligned_dim << std::endl; - diskann::cerr << stream.str(); - throw diskann::ANNException(stream.str(), -1); +T *load_warmup(MemoryMappedFiles &files, const std::string &cache_warmup_file, uint64_t &warmup_num, + uint64_t warmup_dim, uint64_t warmup_aligned_dim) +{ + T *warmup = nullptr; + uint64_t file_dim, file_aligned_dim; + + if (files.fileExists(cache_warmup_file)) + { + diskann::load_aligned_bin(files, cache_warmup_file, warmup, warmup_num, file_dim, file_aligned_dim); + diskann::cout << "In the warmup file: " << cache_warmup_file << " File dim: " << file_dim + << " File aligned dim: " << file_aligned_dim << " Expected dim: " << warmup_dim + << " Expected aligned dim: " << warmup_aligned_dim << std::endl; + + if (file_dim != warmup_dim || file_aligned_dim != warmup_aligned_dim) + { + std::stringstream stream; + stream << "Mismatched dimensions in sample file. file_dim = " << file_dim + << " file_aligned_dim: " << file_aligned_dim << " index_dim: " << warmup_dim + << " index_aligned_dim: " << warmup_aligned_dim << std::endl; + diskann::cerr << stream.str(); + throw diskann::ANNException(stream.str(), -1); + } } - } else { - warmup = - generateRandomWarmup(warmup_num, warmup_dim, warmup_aligned_dim); - } - return warmup; + else + { + warmup = generateRandomWarmup(warmup_num, warmup_dim, warmup_aligned_dim); + } + return warmup; } #endif template -T *load_warmup(const std::string &cache_warmup_file, uint64_t &warmup_num, - uint64_t warmup_dim, uint64_t warmup_aligned_dim) { - T *warmup = nullptr; - uint64_t file_dim, file_aligned_dim; - - if (file_exists(cache_warmup_file)) { - diskann::load_aligned_bin(cache_warmup_file, warmup, warmup_num, - file_dim, file_aligned_dim); - if (file_dim != warmup_dim || file_aligned_dim != warmup_aligned_dim) { - std::stringstream stream; - stream << "Mismatched dimensions in sample file. file_dim = " << file_dim - << " file_aligned_dim: " << file_aligned_dim - << " index_dim: " << warmup_dim - << " index_aligned_dim: " << warmup_aligned_dim << std::endl; - throw diskann::ANNException(stream.str(), -1); +T *load_warmup(const std::string &cache_warmup_file, uint64_t &warmup_num, uint64_t warmup_dim, + uint64_t warmup_aligned_dim) +{ + T *warmup = nullptr; + uint64_t file_dim, file_aligned_dim; + + if (file_exists(cache_warmup_file)) + { + diskann::load_aligned_bin(cache_warmup_file, warmup, warmup_num, file_dim, file_aligned_dim); + if (file_dim != warmup_dim || file_aligned_dim != warmup_aligned_dim) + { + std::stringstream stream; + stream << "Mismatched dimensions in sample file. 
file_dim = " << file_dim + << " file_aligned_dim: " << file_aligned_dim << " index_dim: " << warmup_dim + << " index_aligned_dim: " << warmup_aligned_dim << std::endl; + throw diskann::ANNException(stream.str(), -1); + } + } + else + { + warmup = generateRandomWarmup(warmup_num, warmup_dim, warmup_aligned_dim); } - } else { - warmup = - generateRandomWarmup(warmup_num, warmup_dim, warmup_aligned_dim); - } - return warmup; + return warmup; } /*************************************************** Support for Merging Many Vamana Indices ***************************************************/ -void read_idmap(const std::string &fname, std::vector &ivecs) { - uint32_t npts32, dim; - size_t actual_file_size = get_file_size(fname); - std::ifstream reader(fname.c_str(), std::ios::binary); - reader.read((char *)&npts32, sizeof(uint32_t)); - reader.read((char *)&dim, sizeof(uint32_t)); - if (dim != 1 || actual_file_size != ((size_t)npts32) * sizeof(uint32_t) + - 2 * sizeof(uint32_t)) { - std::stringstream stream; - stream << "Error reading idmap file. Check if the file is bin file with " - "1 dimensional data. Actual: " - << actual_file_size - << ", expected: " << (size_t)npts32 + 2 * sizeof(uint32_t) - << std::endl; - - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - ivecs.resize(npts32); - reader.read((char *)ivecs.data(), ((size_t)npts32) * sizeof(uint32_t)); - reader.close(); +void read_idmap(const std::string &fname, std::vector &ivecs) +{ + uint32_t npts32, dim; + size_t actual_file_size = get_file_size(fname); + std::ifstream reader(fname.c_str(), std::ios::binary); + reader.read((char *)&npts32, sizeof(uint32_t)); + reader.read((char *)&dim, sizeof(uint32_t)); + if (dim != 1 || actual_file_size != ((size_t)npts32) * sizeof(uint32_t) + 2 * sizeof(uint32_t)) + { + std::stringstream stream; + stream << "Error reading idmap file. Check if the file is bin file with " + "1 dimensional data. 
Actual: " + << actual_file_size << ", expected: " << (size_t)npts32 + 2 * sizeof(uint32_t) << std::endl; + + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + ivecs.resize(npts32); + reader.read((char *)ivecs.data(), ((size_t)npts32) * sizeof(uint32_t)); + reader.close(); } -int merge_shards(const std::string &vamana_prefix, - const std::string &vamana_suffix, - const std::string &idmaps_prefix, - const std::string &idmaps_suffix, const uint64_t nshards, - uint32_t max_degree, const std::string &output_vamana, - const std::string &medoids_file, bool use_filters, - const std::string &labels_to_medoids_file) { - // Read ID maps - std::vector vamana_names(nshards); - std::vector> idmaps(nshards); - for (uint64_t shard = 0; shard < nshards; shard++) { - vamana_names[shard] = vamana_prefix + std::to_string(shard) + vamana_suffix; - read_idmap(idmaps_prefix + std::to_string(shard) + idmaps_suffix, - idmaps[shard]); - } - - // find max node id - size_t nnodes = 0; - size_t nelems = 0; - for (auto &idmap : idmaps) { - for (auto &id : idmap) { - nnodes = std::max(nnodes, (size_t)id); +int merge_shards(const std::string &vamana_prefix, const std::string &vamana_suffix, const std::string &idmaps_prefix, + const std::string &idmaps_suffix, const uint64_t nshards, uint32_t max_degree, + const std::string &output_vamana, const std::string &medoids_file, bool use_filters, + const std::string &labels_to_medoids_file) +{ + // Read ID maps + std::vector vamana_names(nshards); + std::vector> idmaps(nshards); + for (uint64_t shard = 0; shard < nshards; shard++) + { + vamana_names[shard] = vamana_prefix + std::to_string(shard) + vamana_suffix; + read_idmap(idmaps_prefix + std::to_string(shard) + idmaps_suffix, idmaps[shard]); + } + + // find max node id + size_t nnodes = 0; + size_t nelems = 0; + for (auto &idmap : idmaps) + { + for (auto &id : idmap) + { + nnodes = std::max(nnodes, (size_t)id); + } + nelems += idmap.size(); } - nelems += idmap.size(); - } - nnodes++; - diskann::cout << "# nodes: " << nnodes << ", max. degree: " << max_degree - << std::endl; - - // compute inverse map: node -> shards - std::vector> node_shard; - node_shard.reserve(nelems); - for (size_t shard = 0; shard < nshards; shard++) { - diskann::cout << "Creating inverse map -- shard #" << shard << std::endl; - for (size_t idx = 0; idx < idmaps[shard].size(); idx++) { - size_t node_id = idmaps[shard][idx]; - node_shard.push_back(std::make_pair((uint32_t)node_id, (uint32_t)shard)); + nnodes++; + diskann::cout << "# nodes: " << nnodes << ", max. 
degree: " << max_degree << std::endl; + + // compute inverse map: node -> shards + std::vector> node_shard; + node_shard.reserve(nelems); + for (size_t shard = 0; shard < nshards; shard++) + { + diskann::cout << "Creating inverse map -- shard #" << shard << std::endl; + for (size_t idx = 0; idx < idmaps[shard].size(); idx++) + { + size_t node_id = idmaps[shard][idx]; + node_shard.push_back(std::make_pair((uint32_t)node_id, (uint32_t)shard)); + } } - } - std::sort(node_shard.begin(), node_shard.end(), - [](const auto &left, const auto &right) { - return left.first < right.first || - (left.first == right.first && left.second < right.second); - }); - diskann::cout << "Finished computing node -> shards map" << std::endl; - - // will merge all the labels to medoids files of each shard into one - // combined file - if (use_filters) { - std::unordered_map> global_label_to_medoids; - - for (size_t i = 0; i < nshards; i++) { - std::ifstream mapping_reader; - std::string map_file = vamana_names[i] + "_labels_to_medoids.txt"; - mapping_reader.open(map_file); - - std::string line, token; - uint32_t line_cnt = 0; - - while (std::getline(mapping_reader, line)) { - std::istringstream iss(line); - uint32_t cnt = 0; - uint32_t medoid = 0; - uint32_t label = 0; - while (std::getline(iss, token, ',')) { - token.erase(std::remove(token.begin(), token.end(), '\n'), - token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), - token.end()); - - uint32_t token_as_num = std::stoul(token); - - if (cnt == 0) - label = token_as_num; - else - medoid = token_as_num; - cnt++; + std::sort(node_shard.begin(), node_shard.end(), [](const auto &left, const auto &right) { + return left.first < right.first || (left.first == right.first && left.second < right.second); + }); + diskann::cout << "Finished computing node -> shards map" << std::endl; + + // will merge all the labels to medoids files of each shard into one + // combined file + if (use_filters) + { + std::unordered_map> global_label_to_medoids; + + for (size_t i = 0; i < nshards; i++) + { + std::ifstream mapping_reader; + std::string map_file = vamana_names[i] + "_labels_to_medoids.txt"; + mapping_reader.open(map_file); + + std::string line, token; + uint32_t line_cnt = 0; + + while (std::getline(mapping_reader, line)) + { + std::istringstream iss(line); + uint32_t cnt = 0; + uint32_t medoid = 0; + uint32_t label = 0; + while (std::getline(iss, token, ',')) + { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + + uint32_t token_as_num = std::stoul(token); + + if (cnt == 0) + label = token_as_num; + else + medoid = token_as_num; + cnt++; + } + global_label_to_medoids[label].push_back(idmaps[i][medoid]); + line_cnt++; + } + mapping_reader.close(); + } + + std::ofstream mapping_writer(labels_to_medoids_file); + assert(mapping_writer.is_open()); + for (auto iter : global_label_to_medoids) + { + mapping_writer << iter.first << ", "; + auto &vec = iter.second; + for (uint32_t idx = 0; idx < vec.size() - 1; idx++) + { + mapping_writer << vec[idx] << ", "; + } + mapping_writer << vec[vec.size() - 1] << std::endl; } - global_label_to_medoids[label].push_back(idmaps[i][medoid]); - line_cnt++; - } - mapping_reader.close(); + mapping_writer.close(); } - std::ofstream mapping_writer(labels_to_medoids_file); - assert(mapping_writer.is_open()); - for (auto iter : global_label_to_medoids) { - mapping_writer << iter.first << ", "; - auto &vec = iter.second; - for 
(uint32_t idx = 0; idx < vec.size() - 1; idx++) { - mapping_writer << vec[idx] << ", "; - } - mapping_writer << vec[vec.size() - 1] << std::endl; + // create cached vamana readers + std::vector vamana_readers(nshards); + for (size_t i = 0; i < nshards; i++) + { + vamana_readers[i].open(vamana_names[i], BUFFER_SIZE_FOR_CACHED_IO); + size_t expected_file_size; + vamana_readers[i].read((char *)&expected_file_size, sizeof(uint64_t)); } - mapping_writer.close(); - } - - // create cached vamana readers - std::vector vamana_readers(nshards); - for (size_t i = 0; i < nshards; i++) { - vamana_readers[i].open(vamana_names[i], BUFFER_SIZE_FOR_CACHED_IO); - size_t expected_file_size; - vamana_readers[i].read((char *)&expected_file_size, sizeof(uint64_t)); - } - - size_t vamana_metadata_size = - sizeof(uint64_t) + sizeof(uint32_t) + sizeof(uint32_t) + - sizeof(uint64_t); // expected file size + max degree + - // medoid_id + frozen_point info - - // create cached vamana writers - cached_ofstream merged_vamana_writer(output_vamana, - BUFFER_SIZE_FOR_CACHED_IO); - - size_t merged_index_size = - vamana_metadata_size; // we initialize the size of the merged index to - // the metadata size - size_t merged_index_frozen = 0; - merged_vamana_writer.write( - (char *)&merged_index_size, - sizeof(uint64_t)); // we will overwrite the index size at the end - - uint32_t output_width = max_degree; - uint32_t max_input_width = 0; - // read width from each vamana to advance buffer by sizeof(uint32_t) bytes - for (auto &reader : vamana_readers) { - uint32_t input_width; - reader.read((char *)&input_width, sizeof(uint32_t)); - max_input_width = - input_width > max_input_width ? input_width : max_input_width; - } - - diskann::cout << "Max input width: " << max_input_width - << ", output width: " << output_width << std::endl; - - merged_vamana_writer.write((char *)&output_width, sizeof(uint32_t)); - std::ofstream medoid_writer(medoids_file.c_str(), std::ios::binary); - uint32_t nshards_u32 = (uint32_t)nshards; - uint32_t one_val = 1; - medoid_writer.write((char *)&nshards_u32, sizeof(uint32_t)); - medoid_writer.write((char *)&one_val, sizeof(uint32_t)); - - uint64_t vamana_index_frozen = - 0; // as of now the functionality to merge many overlapping vamana - // indices is supported only for bulk indices without frozen point. - // Hence the final index will also not have any frozen points. - for (uint64_t shard = 0; shard < nshards; shard++) { - uint32_t medoid; - // read medoid - vamana_readers[shard].read((char *)&medoid, sizeof(uint32_t)); - vamana_readers[shard].read((char *)&vamana_index_frozen, sizeof(uint64_t)); - assert(vamana_index_frozen == false); - // rename medoid - medoid = idmaps[shard][medoid]; - - medoid_writer.write((char *)&medoid, sizeof(uint32_t)); - // write renamed medoid - if (shard == (nshards - 1)) //--> uncomment if running hierarchical - merged_vamana_writer.write((char *)&medoid, sizeof(uint32_t)); - } - merged_vamana_writer.write((char *)&merged_index_frozen, sizeof(uint64_t)); - medoid_writer.close(); - - diskann::cout << "Starting merge" << std::endl; - - // Gopal. random_shuffle() is deprecated. - std::random_device rng; - std::mt19937 urng(rng()); - - std::vector nhood_set(nnodes, 0); - std::vector final_nhood; - - uint32_t nnbrs = 0, shard_nnbrs = 0; - uint32_t cur_id = 0; - for (const auto &id_shard : node_shard) { - uint32_t node_id = id_shard.first; - uint32_t shard_id = id_shard.second; - if (cur_id < node_id) { - // Gopal. random_shuffle() is deprecated. 
- std::shuffle(final_nhood.begin(), final_nhood.end(), urng); - nnbrs = (uint32_t)(std::min)(final_nhood.size(), (uint64_t)max_degree); - // write into merged ofstream - merged_vamana_writer.write((char *)&nnbrs, sizeof(uint32_t)); - merged_vamana_writer.write((char *)final_nhood.data(), - nnbrs * sizeof(uint32_t)); - merged_index_size += (sizeof(uint32_t) + nnbrs * sizeof(uint32_t)); - if (cur_id % 499999 == 1) { - diskann::cout << "." << std::flush; - } - cur_id = node_id; - nnbrs = 0; - for (auto &p : final_nhood) - nhood_set[p] = 0; - final_nhood.clear(); + + size_t vamana_metadata_size = + sizeof(uint64_t) + sizeof(uint32_t) + sizeof(uint32_t) + sizeof(uint64_t); // expected file size + max degree + + // medoid_id + frozen_point info + + // create cached vamana writers + cached_ofstream merged_vamana_writer(output_vamana, BUFFER_SIZE_FOR_CACHED_IO); + + size_t merged_index_size = vamana_metadata_size; // we initialize the size of the merged index to + // the metadata size + size_t merged_index_frozen = 0; + merged_vamana_writer.write((char *)&merged_index_size, + sizeof(uint64_t)); // we will overwrite the index size at the end + + uint32_t output_width = max_degree; + uint32_t max_input_width = 0; + // read width from each vamana to advance buffer by sizeof(uint32_t) bytes + for (auto &reader : vamana_readers) + { + uint32_t input_width; + reader.read((char *)&input_width, sizeof(uint32_t)); + max_input_width = input_width > max_input_width ? input_width : max_input_width; } - // read from shard_id ifstream - vamana_readers[shard_id].read((char *)&shard_nnbrs, sizeof(uint32_t)); - if (shard_nnbrs == 0) { - diskann::cout << "WARNING: shard #" << shard_id << ", node_id " << node_id - << " has 0 nbrs" << std::endl; + diskann::cout << "Max input width: " << max_input_width << ", output width: " << output_width << std::endl; + + merged_vamana_writer.write((char *)&output_width, sizeof(uint32_t)); + std::ofstream medoid_writer(medoids_file.c_str(), std::ios::binary); + uint32_t nshards_u32 = (uint32_t)nshards; + uint32_t one_val = 1; + medoid_writer.write((char *)&nshards_u32, sizeof(uint32_t)); + medoid_writer.write((char *)&one_val, sizeof(uint32_t)); + + uint64_t vamana_index_frozen = 0; // as of now the functionality to merge many overlapping vamana + // indices is supported only for bulk indices without frozen point. + // Hence the final index will also not have any frozen points. + for (uint64_t shard = 0; shard < nshards; shard++) + { + uint32_t medoid; + // read medoid + vamana_readers[shard].read((char *)&medoid, sizeof(uint32_t)); + vamana_readers[shard].read((char *)&vamana_index_frozen, sizeof(uint64_t)); + assert(vamana_index_frozen == false); + // rename medoid + medoid = idmaps[shard][medoid]; + + medoid_writer.write((char *)&medoid, sizeof(uint32_t)); + // write renamed medoid + if (shard == (nshards - 1)) //--> uncomment if running hierarchical + merged_vamana_writer.write((char *)&medoid, sizeof(uint32_t)); + } + merged_vamana_writer.write((char *)&merged_index_frozen, sizeof(uint64_t)); + medoid_writer.close(); + + diskann::cout << "Starting merge" << std::endl; + + // Gopal. random_shuffle() is deprecated. + std::random_device rng; + std::mt19937 urng(rng()); + + std::vector nhood_set(nnodes, 0); + std::vector final_nhood; + + uint32_t nnbrs = 0, shard_nnbrs = 0; + uint32_t cur_id = 0; + for (const auto &id_shard : node_shard) + { + uint32_t node_id = id_shard.first; + uint32_t shard_id = id_shard.second; + if (cur_id < node_id) + { + // Gopal. 
random_shuffle() is deprecated. + std::shuffle(final_nhood.begin(), final_nhood.end(), urng); + nnbrs = (uint32_t)(std::min)(final_nhood.size(), (uint64_t)max_degree); + // write into merged ofstream + merged_vamana_writer.write((char *)&nnbrs, sizeof(uint32_t)); + merged_vamana_writer.write((char *)final_nhood.data(), nnbrs * sizeof(uint32_t)); + merged_index_size += (sizeof(uint32_t) + nnbrs * sizeof(uint32_t)); + if (cur_id % 499999 == 1) + { + diskann::cout << "." << std::flush; + } + cur_id = node_id; + nnbrs = 0; + for (auto &p : final_nhood) + nhood_set[p] = 0; + final_nhood.clear(); + } + // read from shard_id ifstream + vamana_readers[shard_id].read((char *)&shard_nnbrs, sizeof(uint32_t)); + + if (shard_nnbrs == 0) + { + diskann::cout << "WARNING: shard #" << shard_id << ", node_id " << node_id << " has 0 nbrs" << std::endl; + } + + std::vector shard_nhood(shard_nnbrs); + if (shard_nnbrs > 0) + vamana_readers[shard_id].read((char *)shard_nhood.data(), shard_nnbrs * sizeof(uint32_t)); + // rename nodes + for (uint64_t j = 0; j < shard_nnbrs; j++) + { + if (nhood_set[idmaps[shard_id][shard_nhood[j]]] == 0) + { + nhood_set[idmaps[shard_id][shard_nhood[j]]] = 1; + final_nhood.emplace_back(idmaps[shard_id][shard_nhood[j]]); + } + } } - std::vector shard_nhood(shard_nnbrs); - if (shard_nnbrs > 0) - vamana_readers[shard_id].read((char *)shard_nhood.data(), - shard_nnbrs * sizeof(uint32_t)); - // rename nodes - for (uint64_t j = 0; j < shard_nnbrs; j++) { - if (nhood_set[idmaps[shard_id][shard_nhood[j]]] == 0) { - nhood_set[idmaps[shard_id][shard_nhood[j]]] = 1; - final_nhood.emplace_back(idmaps[shard_id][shard_nhood[j]]); - } + // Gopal. random_shuffle() is deprecated. + std::shuffle(final_nhood.begin(), final_nhood.end(), urng); + nnbrs = (uint32_t)(std::min)(final_nhood.size(), (uint64_t)max_degree); + // write into merged ofstream + merged_vamana_writer.write((char *)&nnbrs, sizeof(uint32_t)); + if (nnbrs > 0) + { + merged_vamana_writer.write((char *)final_nhood.data(), nnbrs * sizeof(uint32_t)); } - } - - // Gopal. random_shuffle() is deprecated. - std::shuffle(final_nhood.begin(), final_nhood.end(), urng); - nnbrs = (uint32_t)(std::min)(final_nhood.size(), (uint64_t)max_degree); - // write into merged ofstream - merged_vamana_writer.write((char *)&nnbrs, sizeof(uint32_t)); - if (nnbrs > 0) { - merged_vamana_writer.write((char *)final_nhood.data(), - nnbrs * sizeof(uint32_t)); - } - merged_index_size += (sizeof(uint32_t) + nnbrs * sizeof(uint32_t)); - for (auto &p : final_nhood) - nhood_set[p] = 0; - final_nhood.clear(); - - diskann::cout << "Expected size: " << merged_index_size << std::endl; - - merged_vamana_writer.reset(); - merged_vamana_writer.write((char *)&merged_index_size, sizeof(uint64_t)); - - diskann::cout << "Finished merge" << std::endl; - return 0; + merged_index_size += (sizeof(uint32_t) + nnbrs * sizeof(uint32_t)); + for (auto &p : final_nhood) + nhood_set[p] = 0; + final_nhood.clear(); + + diskann::cout << "Expected size: " << merged_index_size << std::endl; + + merged_vamana_writer.reset(); + merged_vamana_writer.write((char *)&merged_index_size, sizeof(uint64_t)); + + diskann::cout << "Finished merge" << std::endl; + return 0; } // TODO: Make this a streaming implementation to avoid exceeding the memory @@ -484,307 +488,303 @@ int merge_shards(const std::string &vamana_prefix, the new nodes at the end. 
The dummy map contains the real graph id of the new nodes added to the graph */ template -void breakup_dense_points(const std::string data_file, - const std::string labels_file, uint32_t density, - const std::string out_data_file, - const std::string out_labels_file, - const std::string out_metadata_file) { - std::string token, line; - std::ifstream labels_stream(labels_file); - T *data; - uint64_t npts, ndims; - diskann::load_bin(data_file, data, npts, ndims); - - std::unordered_map dummy_pt_ids; - uint32_t next_dummy_id = (uint32_t)npts; - - uint32_t point_cnt = 0; - - std::vector> labels_per_point; - labels_per_point.resize(npts); - - uint32_t dense_pts = 0; - if (labels_stream.is_open()) { - while (getline(labels_stream, line)) { - std::stringstream iss(line); - uint32_t lbl_cnt = 0; - uint32_t label_host = point_cnt; - while (getline(iss, token, ',')) { - if (lbl_cnt == density) { - if (label_host == point_cnt) - dense_pts++; - label_host = next_dummy_id; - labels_per_point.resize(next_dummy_id + 1); - dummy_pt_ids[next_dummy_id] = (uint32_t)point_cnt; - next_dummy_id++; - lbl_cnt = 0; +void breakup_dense_points(const std::string data_file, const std::string labels_file, uint32_t density, + const std::string out_data_file, const std::string out_labels_file, + const std::string out_metadata_file) +{ + std::string token, line; + std::ifstream labels_stream(labels_file); + T *data; + uint64_t npts, ndims; + diskann::load_bin(data_file, data, npts, ndims); + + std::unordered_map dummy_pt_ids; + uint32_t next_dummy_id = (uint32_t)npts; + + uint32_t point_cnt = 0; + + std::vector> labels_per_point; + labels_per_point.resize(npts); + + uint32_t dense_pts = 0; + if (labels_stream.is_open()) + { + while (getline(labels_stream, line)) + { + std::stringstream iss(line); + uint32_t lbl_cnt = 0; + uint32_t label_host = point_cnt; + while (getline(iss, token, ',')) + { + if (lbl_cnt == density) + { + if (label_host == point_cnt) + dense_pts++; + label_host = next_dummy_id; + labels_per_point.resize(next_dummy_id + 1); + dummy_pt_ids[next_dummy_id] = (uint32_t)point_cnt; + next_dummy_id++; + lbl_cnt = 0; + } + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + uint32_t token_as_num = std::stoul(token); + labels_per_point[label_host].push_back(token_as_num); + lbl_cnt++; + } + point_cnt++; } - token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - uint32_t token_as_num = std::stoul(token); - labels_per_point[label_host].push_back(token_as_num); - lbl_cnt++; - } - point_cnt++; } - } - diskann::cout << "fraction of dense points with >= " << density - << " labels = " << (float)dense_pts / (float)npts << std::endl; - - if (labels_per_point.size() != 0) { - diskann::cout << labels_per_point.size() << " is the new number of points" + diskann::cout << "fraction of dense points with >= " << density << " labels = " << (float)dense_pts / (float)npts << std::endl; - std::ofstream label_writer(out_labels_file); - assert(label_writer.is_open()); - for (uint32_t i = 0; i < labels_per_point.size(); i++) { - for (uint32_t j = 0; j < (labels_per_point[i].size() - 1); j++) { - label_writer << labels_per_point[i][j] << ","; - } - if (labels_per_point[i].size() != 0) - label_writer << labels_per_point[i][labels_per_point[i].size() - 1]; - label_writer << std::endl; - } - label_writer.close(); - } - - if (dummy_pt_ids.size() != 0) 
{ - diskann::cout << dummy_pt_ids.size() - << " is the number of dummy points created" << std::endl; - - T *ptr = (T *)std::realloc((void *)data, - labels_per_point.size() * ndims * sizeof(T)); - if (ptr == nullptr) { - diskann::cerr << "Realloc failed while creating dummy points" - << std::endl; - free(data); - data = nullptr; - throw new diskann::ANNException("Realloc failed while expanding data.", - -1, __FUNCTION__, __FILE__, __LINE__); - } else { - data = ptr; + + if (labels_per_point.size() != 0) + { + diskann::cout << labels_per_point.size() << " is the new number of points" << std::endl; + std::ofstream label_writer(out_labels_file); + assert(label_writer.is_open()); + for (uint32_t i = 0; i < labels_per_point.size(); i++) + { + for (uint32_t j = 0; j < (labels_per_point[i].size() - 1); j++) + { + label_writer << labels_per_point[i][j] << ","; + } + if (labels_per_point[i].size() != 0) + label_writer << labels_per_point[i][labels_per_point[i].size() - 1]; + label_writer << std::endl; + } + label_writer.close(); } - std::ofstream dummy_writer(out_metadata_file); - assert(dummy_writer.is_open()); - for (auto i = dummy_pt_ids.begin(); i != dummy_pt_ids.end(); i++) { - dummy_writer << i->first << "," << i->second << std::endl; - std::memcpy(data + i->first * ndims, data + i->second * ndims, - ndims * sizeof(T)); + if (dummy_pt_ids.size() != 0) + { + diskann::cout << dummy_pt_ids.size() << " is the number of dummy points created" << std::endl; + + T *ptr = (T *)std::realloc((void *)data, labels_per_point.size() * ndims * sizeof(T)); + if (ptr == nullptr) + { + diskann::cerr << "Realloc failed while creating dummy points" << std::endl; + free(data); + data = nullptr; + throw new diskann::ANNException("Realloc failed while expanding data.", -1, __FUNCTION__, __FILE__, + __LINE__); + } + else + { + data = ptr; + } + + std::ofstream dummy_writer(out_metadata_file); + assert(dummy_writer.is_open()); + for (auto i = dummy_pt_ids.begin(); i != dummy_pt_ids.end(); i++) + { + dummy_writer << i->first << "," << i->second << std::endl; + std::memcpy(data + i->first * ndims, data + i->second * ndims, ndims * sizeof(T)); + } + dummy_writer.close(); } - dummy_writer.close(); - } - diskann::save_bin(out_data_file, data, labels_per_point.size(), ndims); + diskann::save_bin(out_data_file, data, labels_per_point.size(), ndims); } -void extract_shard_labels( - const std::string &in_label_file, const std::string &shard_ids_bin, - const std::string &shard_label_file) { // assumes ith row is for ith - // point in labels file - diskann::cout << "Extracting labels for shard" << std::endl; - - uint32_t *ids = nullptr; - uint64_t num_ids, tmp_dim; - diskann::load_bin(shard_ids_bin, ids, num_ids, tmp_dim); - - uint32_t counter = 0, shard_counter = 0; - std::string cur_line; - - std::ifstream label_reader(in_label_file); - std::ofstream label_writer(shard_label_file); - assert(label_reader.is_open()); - assert(label_reader.is_open()); - if (label_reader && label_writer) { - while (std::getline(label_reader, cur_line)) { - if (shard_counter >= num_ids) { - break; - } - if (counter == ids[shard_counter]) { - label_writer << cur_line << "\n"; - shard_counter++; - } - counter++; +void extract_shard_labels(const std::string &in_label_file, const std::string &shard_ids_bin, + const std::string &shard_label_file) +{ // assumes ith row is for ith + // point in labels file + diskann::cout << "Extracting labels for shard" << std::endl; + + uint32_t *ids = nullptr; + uint64_t num_ids, tmp_dim; + 
diskann::load_bin(shard_ids_bin, ids, num_ids, tmp_dim); + + uint32_t counter = 0, shard_counter = 0; + std::string cur_line; + + std::ifstream label_reader(in_label_file); + std::ofstream label_writer(shard_label_file); + assert(label_reader.is_open()); + assert(label_reader.is_open()); + if (label_reader && label_writer) + { + while (std::getline(label_reader, cur_line)) + { + if (shard_counter >= num_ids) + { + break; + } + if (counter == ids[shard_counter]) + { + label_writer << cur_line << "\n"; + shard_counter++; + } + counter++; + } } - } - if (ids != nullptr) - delete[] ids; + if (ids != nullptr) + delete[] ids; } template -int build_merged_vamana_index( - std::string base_file, diskann::Metric compareMetric, uint32_t L, - uint32_t R, double sampling_rate, double ram_budget, - std::string mem_index_path, std::string medoids_file, - std::string centroids_file, size_t build_pq_bytes, bool use_opq, - uint32_t num_threads, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, - const std::string &universal_label, const uint32_t Lf) { - size_t base_num, base_dim; - diskann::get_bin_metadata(base_file, base_num, base_dim); - - double full_index_ram = - estimate_ram_usage(base_num, (uint32_t)base_dim, sizeof(T), R); - - // TODO: Make this honest when there is filter support - if (full_index_ram < ram_budget * 1024 * 1024 * 1024) { - diskann::cout << "Full index fits in RAM budget, should consume at most " - << full_index_ram / (1024 * 1024 * 1024) - << "GiBs, so building in one shot" << std::endl; - - diskann::IndexWriteParameters paras = - diskann::IndexWriteParametersBuilder(L, R) - .with_filter_list_size(Lf) - .with_saturate_graph(!use_filters) - .with_num_threads(num_threads) - .build(); - using TagT = uint32_t; - diskann::Index _index( - compareMetric, base_dim, base_num, - std::make_shared(paras), nullptr, - defaults::NUM_FROZEN_POINTS_STATIC, false, false, false, - build_pq_bytes > 0, build_pq_bytes, use_opq, use_filters); - if (!use_filters) - _index.build(base_file.c_str(), base_num); - else { - if (universal_label != "") { // indicates no universal label - LabelT unv_label_as_num = 0; - _index.set_universal_label(unv_label_as_num); - } - _index.build_filtered_index(base_file.c_str(), label_file, base_num); - } - _index.save(mem_index_path.c_str()); - - if (use_filters) { - // need to copy the labels_to_medoids file to the specified input - // file - std::remove(labels_to_medoids_file.c_str()); - std::string mem_labels_to_medoid_file = - mem_index_path + "_labels_to_medoids.txt"; - copy_file(mem_labels_to_medoid_file, labels_to_medoids_file); - std::remove(mem_labels_to_medoid_file.c_str()); +int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, + double sampling_rate, double ram_budget, std::string mem_index_path, + std::string medoids_file, std::string centroids_file, size_t build_pq_bytes, bool use_opq, + uint32_t num_threads, bool use_filters, const std::string &label_file, + const std::string &labels_to_medoids_file, const std::string &universal_label, + const uint32_t Lf) +{ + size_t base_num, base_dim; + diskann::get_bin_metadata(base_file, base_num, base_dim); + + double full_index_ram = estimate_ram_usage(base_num, (uint32_t)base_dim, sizeof(T), R); + + // TODO: Make this honest when there is filter support + if (full_index_ram < ram_budget * 1024 * 1024 * 1024) + { + diskann::cout << "Full index fits in RAM budget, should consume at most " + << full_index_ram / (1024 * 1024 * 
1024) << "GiBs, so building in one shot" << std::endl; + + diskann::IndexWriteParameters paras = diskann::IndexWriteParametersBuilder(L, R) + .with_filter_list_size(Lf) + .with_saturate_graph(!use_filters) + .with_num_threads(num_threads) + .build(); + using TagT = uint32_t; + diskann::Index _index(compareMetric, base_dim, base_num, + std::make_shared(paras), nullptr, + defaults::NUM_FROZEN_POINTS_STATIC, false, false, false, + build_pq_bytes > 0, build_pq_bytes, use_opq, use_filters); + if (!use_filters) + _index.build(base_file.c_str(), base_num); + else + { + if (universal_label != "") + { // indicates no universal label + LabelT unv_label_as_num = 0; + _index.set_universal_label(unv_label_as_num); + } + _index.build_filtered_index(base_file.c_str(), label_file, base_num); + } + _index.save(mem_index_path.c_str()); + + if (use_filters) + { + // need to copy the labels_to_medoids file to the specified input + // file + std::remove(labels_to_medoids_file.c_str()); + std::string mem_labels_to_medoid_file = mem_index_path + "_labels_to_medoids.txt"; + copy_file(mem_labels_to_medoid_file, labels_to_medoids_file); + std::remove(mem_labels_to_medoid_file.c_str()); + } + + std::remove(medoids_file.c_str()); + std::remove(centroids_file.c_str()); + return 0; } - std::remove(medoids_file.c_str()); - std::remove(centroids_file.c_str()); - return 0; - } + // where the universal label is to be saved in the final graph + std::string final_index_universal_label_file = mem_index_path + "_universal_label.txt"; - // where the universal label is to be saved in the final graph - std::string final_index_universal_label_file = - mem_index_path + "_universal_label.txt"; + std::string merged_index_prefix = mem_index_path + "_tempFiles"; - std::string merged_index_prefix = mem_index_path + "_tempFiles"; + Timer timer; + int num_parts = + partition_with_ram_budget(base_file, sampling_rate, ram_budget, 2 * R / 3, merged_index_prefix, 2); + diskann::cout << timer.elapsed_seconds_for_step("partitioning data ") << std::endl; + + std::string cur_centroid_filepath = merged_index_prefix + "_centroids.bin"; + std::rename(cur_centroid_filepath.c_str(), centroids_file.c_str()); + + timer.reset(); + for (int p = 0; p < num_parts; p++) + { +#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD) + MallocExtension::instance()->ReleaseFreeMemory(); +#endif - Timer timer; - int num_parts = partition_with_ram_budget( - base_file, sampling_rate, ram_budget, 2 * R / 3, merged_index_prefix, 2); - diskann::cout << timer.elapsed_seconds_for_step("partitioning data ") - << std::endl; + std::string shard_base_file = merged_index_prefix + "_subshard-" + std::to_string(p) + ".bin"; - std::string cur_centroid_filepath = merged_index_prefix + "_centroids.bin"; - std::rename(cur_centroid_filepath.c_str(), centroids_file.c_str()); + std::string shard_ids_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_ids_uint32.bin"; - timer.reset(); - for (int p = 0; p < num_parts; p++) { -#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && \ - defined(DISKANN_BUILD) - MallocExtension::instance()->ReleaseFreeMemory(); -#endif + std::string shard_labels_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_labels.txt"; - std::string shard_base_file = - merged_index_prefix + "_subshard-" + std::to_string(p) + ".bin"; - - std::string shard_ids_file = merged_index_prefix + "_subshard-" + - std::to_string(p) + "_ids_uint32.bin"; - - std::string shard_labels_file = - 
merged_index_prefix + "_subshard-" + std::to_string(p) + "_labels.txt"; - - retrieve_shard_data_from_ids(base_file, shard_ids_file, shard_base_file); - - std::string shard_index_file = - merged_index_prefix + "_subshard-" + std::to_string(p) + "_mem.index"; - - diskann::IndexWriteParameters low_degree_params = - diskann::IndexWriteParametersBuilder(L, 2 * R / 3) - .with_filter_list_size(Lf) - .with_saturate_graph(false) - .with_num_threads(num_threads) - .build(); - - uint64_t shard_base_dim, shard_base_pts; - get_bin_metadata(shard_base_file, shard_base_pts, shard_base_dim); - - diskann::Index _index( - compareMetric, shard_base_dim, shard_base_pts, - std::make_shared(low_degree_params), - nullptr, defaults::NUM_FROZEN_POINTS_STATIC, false, false, false, - build_pq_bytes > 0, build_pq_bytes, use_opq); - if (!use_filters) { - _index.build(shard_base_file.c_str(), shard_base_pts); - } else { - diskann::extract_shard_labels(label_file, shard_ids_file, - shard_labels_file); - if (universal_label != "") { // indicates no universal label - LabelT unv_label_as_num = 0; - _index.set_universal_label(unv_label_as_num); - } - _index.build_filtered_index(shard_base_file.c_str(), shard_labels_file, - shard_base_pts); - } - _index.save(shard_index_file.c_str()); - // copy universal label file from first shard to the final destination - // index, since all shards anyway share the universal label - if (p == 0) { - std::string shard_universal_label_file = - shard_index_file + "_universal_label.txt"; - if (universal_label != "") { - copy_file(shard_universal_label_file, final_index_universal_label_file); - } - } + retrieve_shard_data_from_ids(base_file, shard_ids_file, shard_base_file); - std::remove(shard_base_file.c_str()); - } - diskann::cout << timer.elapsed_seconds_for_step("building indices on shards") - << std::endl; - - timer.reset(); - diskann::merge_shards(merged_index_prefix + "_subshard-", "_mem.index", - merged_index_prefix + "_subshard-", "_ids_uint32.bin", - num_parts, R, mem_index_path, medoids_file, use_filters, - labels_to_medoids_file); - diskann::cout << timer.elapsed_seconds_for_step("merging indices") - << std::endl; - - // delete tempFiles - for (int p = 0; p < num_parts; p++) { - std::string shard_base_file = - merged_index_prefix + "_subshard-" + std::to_string(p) + ".bin"; - std::string shard_id_file = merged_index_prefix + "_subshard-" + - std::to_string(p) + "_ids_uint32.bin"; - std::string shard_labels_file = - merged_index_prefix + "_subshard-" + std::to_string(p) + "_labels.txt"; - std::string shard_index_file = - merged_index_prefix + "_subshard-" + std::to_string(p) + "_mem.index"; - std::string shard_index_file_data = shard_index_file + ".data"; - - std::remove(shard_base_file.c_str()); - std::remove(shard_id_file.c_str()); - std::remove(shard_index_file.c_str()); - std::remove(shard_index_file_data.c_str()); - if (use_filters) { - std::string shard_index_label_file = shard_index_file + "_labels.txt"; - std::string shard_index_univ_label_file = - shard_index_file + "_universal_label.txt"; - std::string shard_index_label_map_file = - shard_index_file + "_labels_to_medoids.txt"; - std::remove(shard_labels_file.c_str()); - std::remove(shard_index_label_file.c_str()); - std::remove(shard_index_label_map_file.c_str()); - std::remove(shard_index_univ_label_file.c_str()); + std::string shard_index_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_mem.index"; + + diskann::IndexWriteParameters low_degree_params = diskann::IndexWriteParametersBuilder(L, 2 * R / 3) + 
.with_filter_list_size(Lf) + .with_saturate_graph(false) + .with_num_threads(num_threads) + .build(); + + uint64_t shard_base_dim, shard_base_pts; + get_bin_metadata(shard_base_file, shard_base_pts, shard_base_dim); + + diskann::Index _index(compareMetric, shard_base_dim, shard_base_pts, + std::make_shared(low_degree_params), nullptr, + defaults::NUM_FROZEN_POINTS_STATIC, false, false, false, build_pq_bytes > 0, + build_pq_bytes, use_opq); + if (!use_filters) + { + _index.build(shard_base_file.c_str(), shard_base_pts); + } + else + { + diskann::extract_shard_labels(label_file, shard_ids_file, shard_labels_file); + if (universal_label != "") + { // indicates no universal label + LabelT unv_label_as_num = 0; + _index.set_universal_label(unv_label_as_num); + } + _index.build_filtered_index(shard_base_file.c_str(), shard_labels_file, shard_base_pts); + } + _index.save(shard_index_file.c_str()); + // copy universal label file from first shard to the final destination + // index, since all shards anyway share the universal label + if (p == 0) + { + std::string shard_universal_label_file = shard_index_file + "_universal_label.txt"; + if (universal_label != "") + { + copy_file(shard_universal_label_file, final_index_universal_label_file); + } + } + + std::remove(shard_base_file.c_str()); + } + diskann::cout << timer.elapsed_seconds_for_step("building indices on shards") << std::endl; + + timer.reset(); + diskann::merge_shards(merged_index_prefix + "_subshard-", "_mem.index", merged_index_prefix + "_subshard-", + "_ids_uint32.bin", num_parts, R, mem_index_path, medoids_file, use_filters, + labels_to_medoids_file); + diskann::cout << timer.elapsed_seconds_for_step("merging indices") << std::endl; + + // delete tempFiles + for (int p = 0; p < num_parts; p++) + { + std::string shard_base_file = merged_index_prefix + "_subshard-" + std::to_string(p) + ".bin"; + std::string shard_id_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_ids_uint32.bin"; + std::string shard_labels_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_labels.txt"; + std::string shard_index_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_mem.index"; + std::string shard_index_file_data = shard_index_file + ".data"; + + std::remove(shard_base_file.c_str()); + std::remove(shard_id_file.c_str()); + std::remove(shard_index_file.c_str()); + std::remove(shard_index_file_data.c_str()); + if (use_filters) + { + std::string shard_index_label_file = shard_index_file + "_labels.txt"; + std::string shard_index_univ_label_file = shard_index_file + "_universal_label.txt"; + std::string shard_index_label_map_file = shard_index_file + "_labels_to_medoids.txt"; + std::remove(shard_labels_file.c_str()); + std::remove(shard_index_label_file.c_str()); + std::remove(shard_index_label_map_file.c_str()); + std::remove(shard_index_univ_label_file.c_str()); + } } - } - return 0; + return 0; } // General purpose support for DiskANN interface @@ -792,778 +792,718 @@ int build_merged_vamana_index( // optimizes the beamwidth to maximize QPS for a given L_search subject to // 99.9 latency not blowing up template -uint32_t optimize_beamwidth( - std::unique_ptr> &pFlashIndex, - T *tuning_sample, uint64_t tuning_sample_num, - uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, - uint32_t start_bw) { - uint32_t cur_bw = start_bw; - double max_qps = 0; - uint32_t best_bw = start_bw; - bool stop_flag = false; - - while (!stop_flag) { - std::vector tuning_sample_result_ids_64(tuning_sample_num, 0); - 
std::vector tuning_sample_result_dists(tuning_sample_num, 0); - diskann::QueryStats *stats = new diskann::QueryStats[tuning_sample_num]; - - auto s = std::chrono::high_resolution_clock::now(); +uint32_t optimize_beamwidth(std::unique_ptr> &pFlashIndex, T *tuning_sample, + uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, + uint32_t nthreads, uint32_t start_bw) +{ + uint32_t cur_bw = start_bw; + double max_qps = 0; + uint32_t best_bw = start_bw; + bool stop_flag = false; + + while (!stop_flag) + { + std::vector tuning_sample_result_ids_64(tuning_sample_num, 0); + std::vector tuning_sample_result_dists(tuning_sample_num, 0); + diskann::QueryStats *stats = new diskann::QueryStats[tuning_sample_num]; + + auto s = std::chrono::high_resolution_clock::now(); #pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads) - for (int64_t i = 0; i < (int64_t)tuning_sample_num; i++) { - pFlashIndex->cached_beam_search( - tuning_sample + (i * tuning_sample_aligned_dim), 1, L, - tuning_sample_result_ids_64.data() + (i * 1), - tuning_sample_result_dists.data() + (i * 1), cur_bw, false, - stats + i); - } - auto e = std::chrono::high_resolution_clock::now(); - std::chrono::duration diff = e - s; - double qps = - (1.0f * (float)tuning_sample_num) / (1.0f * (float)diff.count()); - - double lat_999 = diskann::get_percentile_stats( - stats, tuning_sample_num, 0.999f, - [](const diskann::QueryStats &stats) { return stats.total_us; }); - - double mean_latency = diskann::get_mean_stats( - stats, tuning_sample_num, - [](const diskann::QueryStats &stats) { return stats.total_us; }); - - if (qps > max_qps && lat_999 < (15000) + mean_latency * 2) { - max_qps = qps; - best_bw = cur_bw; - cur_bw = (uint32_t)(std::ceil)((float)cur_bw * 1.1f); - } else { - stop_flag = true; - } - if (cur_bw > 64) - stop_flag = true; + for (int64_t i = 0; i < (int64_t)tuning_sample_num; i++) + { + pFlashIndex->cached_beam_search(tuning_sample + (i * tuning_sample_aligned_dim), 1, L, + tuning_sample_result_ids_64.data() + (i * 1), + tuning_sample_result_dists.data() + (i * 1), cur_bw, false, stats + i); + } + auto e = std::chrono::high_resolution_clock::now(); + std::chrono::duration diff = e - s; + double qps = (1.0f * (float)tuning_sample_num) / (1.0f * (float)diff.count()); + + double lat_999 = diskann::get_percentile_stats( + stats, tuning_sample_num, 0.999f, [](const diskann::QueryStats &stats) { return stats.total_us; }); - delete[] stats; - } - return best_bw; + double mean_latency = diskann::get_mean_stats( + stats, tuning_sample_num, [](const diskann::QueryStats &stats) { return stats.total_us; }); + + if (qps > max_qps && lat_999 < (15000) + mean_latency * 2) + { + max_qps = qps; + best_bw = cur_bw; + cur_bw = (uint32_t)(std::ceil)((float)cur_bw * 1.1f); + } + else + { + stop_flag = true; + } + if (cur_bw > 64) + stop_flag = true; + + delete[] stats; + } + return best_bw; } template -void create_disk_layout(const std::string base_file, - const std::string mem_index_file, - const std::string output_file, - const std::string reorder_data_file) { - uint32_t npts, ndims; - - // amount to read or write in one shot - size_t read_blk_size = 64 * 1024 * 1024; - size_t write_blk_size = read_blk_size; - cached_ifstream base_reader(base_file, read_blk_size); - base_reader.read((char *)&npts, sizeof(uint32_t)); - base_reader.read((char *)&ndims, sizeof(uint32_t)); - - size_t npts_64, ndims_64; - npts_64 = npts; - ndims_64 = ndims; - - // Check if we need to append data for re-ordering - bool append_reorder_data 
= false; - std::ifstream reorder_data_reader; - - uint32_t npts_reorder_file = 0, ndims_reorder_file = 0; - if (reorder_data_file != std::string("")) { - append_reorder_data = true; - size_t reorder_data_file_size = get_file_size(reorder_data_file); - reorder_data_reader.exceptions(std::ofstream::failbit | - std::ofstream::badbit); - - try { - reorder_data_reader.open(reorder_data_file, std::ios::binary); - reorder_data_reader.read((char *)&npts_reorder_file, sizeof(uint32_t)); - reorder_data_reader.read((char *)&ndims_reorder_file, sizeof(uint32_t)); - if (npts_reorder_file != npts) - throw ANNException("Mismatch in num_points between reorder " - "data file and base file", - -1, __FUNCSIG__, __FILE__, __LINE__); - if (reorder_data_file_size != 8 + sizeof(float) * - (size_t)npts_reorder_file * - (size_t)ndims_reorder_file) - throw ANNException("Discrepancy in reorder data file size ", -1, - __FUNCSIG__, __FILE__, __LINE__); - } catch (std::system_error &e) { - throw FileException(reorder_data_file, e, __FUNCSIG__, __FILE__, - __LINE__); +void create_disk_layout(const std::string base_file, const std::string mem_index_file, const std::string output_file, + const std::string reorder_data_file) +{ + uint32_t npts, ndims; + + // amount to read or write in one shot + size_t read_blk_size = 64 * 1024 * 1024; + size_t write_blk_size = read_blk_size; + cached_ifstream base_reader(base_file, read_blk_size); + base_reader.read((char *)&npts, sizeof(uint32_t)); + base_reader.read((char *)&ndims, sizeof(uint32_t)); + + size_t npts_64, ndims_64; + npts_64 = npts; + ndims_64 = ndims; + + // Check if we need to append data for re-ordering + bool append_reorder_data = false; + std::ifstream reorder_data_reader; + + uint32_t npts_reorder_file = 0, ndims_reorder_file = 0; + if (reorder_data_file != std::string("")) + { + append_reorder_data = true; + size_t reorder_data_file_size = get_file_size(reorder_data_file); + reorder_data_reader.exceptions(std::ofstream::failbit | std::ofstream::badbit); + + try + { + reorder_data_reader.open(reorder_data_file, std::ios::binary); + reorder_data_reader.read((char *)&npts_reorder_file, sizeof(uint32_t)); + reorder_data_reader.read((char *)&ndims_reorder_file, sizeof(uint32_t)); + if (npts_reorder_file != npts) + throw ANNException("Mismatch in num_points between reorder " + "data file and base file", + -1, __FUNCSIG__, __FILE__, __LINE__); + if (reorder_data_file_size != 8 + sizeof(float) * (size_t)npts_reorder_file * (size_t)ndims_reorder_file) + throw ANNException("Discrepancy in reorder data file size ", -1, __FUNCSIG__, __FILE__, __LINE__); + } + catch (std::system_error &e) + { + throw FileException(reorder_data_file, e, __FUNCSIG__, __FILE__, __LINE__); + } + } + + // create cached reader + writer + size_t actual_file_size = get_file_size(mem_index_file); + diskann::cout << "Vamana index file size=" << actual_file_size << std::endl; + std::ifstream vamana_reader(mem_index_file, std::ios::binary); + cached_ofstream diskann_writer(output_file, write_blk_size); + + // metadata: width, medoid + uint32_t width_u32, medoid_u32; + size_t index_file_size; + + vamana_reader.read((char *)&index_file_size, sizeof(uint64_t)); + if (index_file_size != actual_file_size) + { + std::stringstream stream; + stream << "Vamana Index file size does not match expected size per " + "meta-data." 
+ << " file size from file: " << index_file_size << " actual file size: " << actual_file_size << std::endl; + + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } - } - - // create cached reader + writer - size_t actual_file_size = get_file_size(mem_index_file); - diskann::cout << "Vamana index file size=" << actual_file_size << std::endl; - std::ifstream vamana_reader(mem_index_file, std::ios::binary); - cached_ofstream diskann_writer(output_file, write_blk_size); - - // metadata: width, medoid - uint32_t width_u32, medoid_u32; - size_t index_file_size; - - vamana_reader.read((char *)&index_file_size, sizeof(uint64_t)); - if (index_file_size != actual_file_size) { - std::stringstream stream; - stream << "Vamana Index file size does not match expected size per " - "meta-data." - << " file size from file: " << index_file_size - << " actual file size: " << actual_file_size << std::endl; - - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - uint64_t vamana_frozen_num = false, vamana_frozen_loc = 0; - - vamana_reader.read((char *)&width_u32, sizeof(uint32_t)); - vamana_reader.read((char *)&medoid_u32, sizeof(uint32_t)); - vamana_reader.read((char *)&vamana_frozen_num, sizeof(uint64_t)); - // compute - uint64_t medoid, max_node_len, nnodes_per_sector; - npts_64 = (uint64_t)npts; - medoid = (uint64_t)medoid_u32; - if (vamana_frozen_num == 1) - vamana_frozen_loc = medoid; - max_node_len = - (((uint64_t)width_u32 + 1) * sizeof(uint32_t)) + (ndims_64 * sizeof(T)); - nnodes_per_sector = - defaults::SECTOR_LEN / max_node_len; // 0 if max_node_len > SECTOR_LEN - - diskann::cout << "medoid: " << medoid << "B" << std::endl; - diskann::cout << "max_node_len: " << max_node_len << "B" << std::endl; - diskann::cout << "nnodes_per_sector: " << nnodes_per_sector << "B" - << std::endl; - - // defaults::SECTOR_LEN buffer for each sector - std::unique_ptr sector_buf = - std::make_unique(defaults::SECTOR_LEN); - std::unique_ptr multisector_buf = - std::make_unique(ROUND_UP(max_node_len, defaults::SECTOR_LEN)); - std::unique_ptr node_buf = std::make_unique(max_node_len); - uint32_t &nnbrs = *(uint32_t *)(node_buf.get() + ndims_64 * sizeof(T)); - uint32_t *nhood_buf = - (uint32_t *)(node_buf.get() + (ndims_64 * sizeof(T)) + sizeof(uint32_t)); - - // number of sectors (1 for meta data) - uint64_t n_sectors = - nnodes_per_sector > 0 - ? 
ROUND_UP(npts_64, nnodes_per_sector) / nnodes_per_sector - : npts_64 * DIV_ROUND_UP(max_node_len, defaults::SECTOR_LEN); - uint64_t n_reorder_sectors = 0; - uint64_t n_data_nodes_per_sector = 0; - - if (append_reorder_data) { - n_data_nodes_per_sector = - defaults::SECTOR_LEN / (ndims_reorder_file * sizeof(float)); - n_reorder_sectors = - ROUND_UP(npts_64, n_data_nodes_per_sector) / n_data_nodes_per_sector; - } - uint64_t disk_index_file_size = - (n_sectors + n_reorder_sectors + 1) * defaults::SECTOR_LEN; - - std::vector output_file_meta; - output_file_meta.push_back(npts_64); - output_file_meta.push_back(ndims_64); - output_file_meta.push_back(medoid); - output_file_meta.push_back(max_node_len); - output_file_meta.push_back(nnodes_per_sector); - output_file_meta.push_back(vamana_frozen_num); - output_file_meta.push_back(vamana_frozen_loc); - output_file_meta.push_back((uint64_t)append_reorder_data); - if (append_reorder_data) { - output_file_meta.push_back(n_sectors + 1); - output_file_meta.push_back(ndims_reorder_file); - output_file_meta.push_back(n_data_nodes_per_sector); - } - output_file_meta.push_back(disk_index_file_size); - - diskann_writer.write(sector_buf.get(), defaults::SECTOR_LEN); - - std::unique_ptr cur_node_coords = std::make_unique(ndims_64); - diskann::cout << "# sectors: " << n_sectors << std::endl; - uint64_t cur_node_id = 0; - - if (nnodes_per_sector > 0) { // Write multiple nodes per sector - for (uint64_t sector = 0; sector < n_sectors; sector++) { - if (sector % 100000 == 0) { - diskann::cout << "Sector #" << sector << "written" << std::endl; - } - memset(sector_buf.get(), 0, defaults::SECTOR_LEN); - for (uint64_t sector_node_id = 0; - sector_node_id < nnodes_per_sector && cur_node_id < npts_64; - sector_node_id++) { - memset(node_buf.get(), 0, max_node_len); - // read cur node's nnbrs - vamana_reader.read((char *)&nnbrs, sizeof(uint32_t)); - - // sanity checks on nnbrs - assert(nnbrs > 0); - assert(nnbrs <= width_u32); - - // read node's nhood - vamana_reader.read((char *)nhood_buf, - (std::min)(nnbrs, width_u32) * sizeof(uint32_t)); - if (nnbrs > width_u32) { - vamana_reader.seekg((nnbrs - width_u32) * sizeof(uint32_t), - vamana_reader.cur); + uint64_t vamana_frozen_num = false, vamana_frozen_loc = 0; + + vamana_reader.read((char *)&width_u32, sizeof(uint32_t)); + vamana_reader.read((char *)&medoid_u32, sizeof(uint32_t)); + vamana_reader.read((char *)&vamana_frozen_num, sizeof(uint64_t)); + // compute + uint64_t medoid, max_node_len, nnodes_per_sector; + npts_64 = (uint64_t)npts; + medoid = (uint64_t)medoid_u32; + if (vamana_frozen_num == 1) + vamana_frozen_loc = medoid; + max_node_len = (((uint64_t)width_u32 + 1) * sizeof(uint32_t)) + (ndims_64 * sizeof(T)); + nnodes_per_sector = defaults::SECTOR_LEN / max_node_len; // 0 if max_node_len > SECTOR_LEN + + diskann::cout << "medoid: " << medoid << "B" << std::endl; + diskann::cout << "max_node_len: " << max_node_len << "B" << std::endl; + diskann::cout << "nnodes_per_sector: " << nnodes_per_sector << "B" << std::endl; + + // defaults::SECTOR_LEN buffer for each sector + std::unique_ptr sector_buf = std::make_unique(defaults::SECTOR_LEN); + std::unique_ptr multisector_buf = std::make_unique(ROUND_UP(max_node_len, defaults::SECTOR_LEN)); + std::unique_ptr node_buf = std::make_unique(max_node_len); + uint32_t &nnbrs = *(uint32_t *)(node_buf.get() + ndims_64 * sizeof(T)); + uint32_t *nhood_buf = (uint32_t *)(node_buf.get() + (ndims_64 * sizeof(T)) + sizeof(uint32_t)); + + // number of sectors (1 for meta data) + 
uint64_t n_sectors = nnodes_per_sector > 0 ? ROUND_UP(npts_64, nnodes_per_sector) / nnodes_per_sector + : npts_64 * DIV_ROUND_UP(max_node_len, defaults::SECTOR_LEN); + uint64_t n_reorder_sectors = 0; + uint64_t n_data_nodes_per_sector = 0; + + if (append_reorder_data) + { + n_data_nodes_per_sector = defaults::SECTOR_LEN / (ndims_reorder_file * sizeof(float)); + n_reorder_sectors = ROUND_UP(npts_64, n_data_nodes_per_sector) / n_data_nodes_per_sector; + } + uint64_t disk_index_file_size = (n_sectors + n_reorder_sectors + 1) * defaults::SECTOR_LEN; + + std::vector output_file_meta; + output_file_meta.push_back(npts_64); + output_file_meta.push_back(ndims_64); + output_file_meta.push_back(medoid); + output_file_meta.push_back(max_node_len); + output_file_meta.push_back(nnodes_per_sector); + output_file_meta.push_back(vamana_frozen_num); + output_file_meta.push_back(vamana_frozen_loc); + output_file_meta.push_back((uint64_t)append_reorder_data); + if (append_reorder_data) + { + output_file_meta.push_back(n_sectors + 1); + output_file_meta.push_back(ndims_reorder_file); + output_file_meta.push_back(n_data_nodes_per_sector); + } + output_file_meta.push_back(disk_index_file_size); + + diskann_writer.write(sector_buf.get(), defaults::SECTOR_LEN); + + std::unique_ptr cur_node_coords = std::make_unique(ndims_64); + diskann::cout << "# sectors: " << n_sectors << std::endl; + uint64_t cur_node_id = 0; + + if (nnodes_per_sector > 0) + { // Write multiple nodes per sector + for (uint64_t sector = 0; sector < n_sectors; sector++) + { + if (sector % 100000 == 0) + { + diskann::cout << "Sector #" << sector << "written" << std::endl; + } + memset(sector_buf.get(), 0, defaults::SECTOR_LEN); + for (uint64_t sector_node_id = 0; sector_node_id < nnodes_per_sector && cur_node_id < npts_64; + sector_node_id++) + { + memset(node_buf.get(), 0, max_node_len); + // read cur node's nnbrs + vamana_reader.read((char *)&nnbrs, sizeof(uint32_t)); + + // sanity checks on nnbrs + assert(nnbrs > 0); + assert(nnbrs <= width_u32); + + // read node's nhood + vamana_reader.read((char *)nhood_buf, (std::min)(nnbrs, width_u32) * sizeof(uint32_t)); + if (nnbrs > width_u32) + { + vamana_reader.seekg((nnbrs - width_u32) * sizeof(uint32_t), vamana_reader.cur); + } + + // write coords of node first + // T *node_coords = data + ((uint64_t) ndims_64 * cur_node_id); + base_reader.read((char *)cur_node_coords.get(), sizeof(T) * ndims_64); + memcpy(node_buf.get(), cur_node_coords.get(), ndims_64 * sizeof(T)); + + // write nnbrs + *(uint32_t *)(node_buf.get() + ndims_64 * sizeof(T)) = (std::min)(nnbrs, width_u32); + + // write nhood next + memcpy(node_buf.get() + ndims_64 * sizeof(T) + sizeof(uint32_t), nhood_buf, + (std::min)(nnbrs, width_u32) * sizeof(uint32_t)); + + // get offset into sector_buf + char *sector_node_buf = sector_buf.get() + (sector_node_id * max_node_len); + + // copy node buf into sector_node_buf + memcpy(sector_node_buf, node_buf.get(), max_node_len); + cur_node_id++; + } + // flush sector to disk + diskann_writer.write(sector_buf.get(), defaults::SECTOR_LEN); + } + } + else + { // Write multi-sector nodes + uint64_t nsectors_per_node = DIV_ROUND_UP(max_node_len, defaults::SECTOR_LEN); + for (uint64_t i = 0; i < npts_64; i++) + { + if ((i * nsectors_per_node) % 100000 == 0) + { + diskann::cout << "Sector #" << i * nsectors_per_node << "written" << std::endl; + } + memset(multisector_buf.get(), 0, nsectors_per_node * defaults::SECTOR_LEN); + + memset(node_buf.get(), 0, max_node_len); + // read cur node's nnbrs + 
vamana_reader.read((char *)&nnbrs, sizeof(uint32_t)); + + // sanity checks on nnbrs + assert(nnbrs > 0); + assert(nnbrs <= width_u32); + + // read node's nhood + vamana_reader.read((char *)nhood_buf, (std::min)(nnbrs, width_u32) * sizeof(uint32_t)); + if (nnbrs > width_u32) + { + vamana_reader.seekg((nnbrs - width_u32) * sizeof(uint32_t), vamana_reader.cur); + } + + // write coords of node first + // T *node_coords = data + ((uint64_t) ndims_64 * cur_node_id); + base_reader.read((char *)cur_node_coords.get(), sizeof(T) * ndims_64); + memcpy(multisector_buf.get(), cur_node_coords.get(), ndims_64 * sizeof(T)); + + // write nnbrs + *(uint32_t *)(multisector_buf.get() + ndims_64 * sizeof(T)) = (std::min)(nnbrs, width_u32); + + // write nhood next + memcpy(multisector_buf.get() + ndims_64 * sizeof(T) + sizeof(uint32_t), nhood_buf, + (std::min)(nnbrs, width_u32) * sizeof(uint32_t)); + + // flush sector to disk + diskann_writer.write(multisector_buf.get(), nsectors_per_node * defaults::SECTOR_LEN); } + } + + if (append_reorder_data) + { + diskann::cout << "Index written. Appending reorder data..." << std::endl; + + auto vec_len = ndims_reorder_file * sizeof(float); + std::unique_ptr vec_buf = std::make_unique(vec_len); + + for (uint64_t sector = 0; sector < n_reorder_sectors; sector++) + { + if (sector % 100000 == 0) + { + diskann::cout << "Reorder data Sector #" << sector << "written" << std::endl; + } + + memset(sector_buf.get(), 0, defaults::SECTOR_LEN); + + for (uint64_t sector_node_id = 0; sector_node_id < n_data_nodes_per_sector && sector_node_id < npts_64; + sector_node_id++) + { + memset(vec_buf.get(), 0, vec_len); + reorder_data_reader.read(vec_buf.get(), vec_len); + + // copy node buf into sector_node_buf + memcpy(sector_buf.get() + (sector_node_id * vec_len), vec_buf.get(), vec_len); + } + // flush sector to disk + diskann_writer.write(sector_buf.get(), defaults::SECTOR_LEN); + } + } + diskann_writer.close(); + diskann::save_bin(output_file, output_file_meta.data(), output_file_meta.size(), 1, 0); + diskann::cout << "Output disk index file written to " << output_file << std::endl; +} - // write coords of node first - // T *node_coords = data + ((uint64_t) ndims_64 * cur_node_id); - base_reader.read((char *)cur_node_coords.get(), sizeof(T) * ndims_64); - memcpy(node_buf.get(), cur_node_coords.get(), ndims_64 * sizeof(T)); - - // write nnbrs - *(uint32_t *)(node_buf.get() + ndims_64 * sizeof(T)) = - (std::min)(nnbrs, width_u32); - - // write nhood next - memcpy(node_buf.get() + ndims_64 * sizeof(T) + sizeof(uint32_t), - nhood_buf, (std::min)(nnbrs, width_u32) * sizeof(uint32_t)); - - // get offset into sector_buf - char *sector_node_buf = - sector_buf.get() + (sector_node_id * max_node_len); - - // copy node buf into sector_node_buf - memcpy(sector_node_buf, node_buf.get(), max_node_len); - cur_node_id++; - } - // flush sector to disk - diskann_writer.write(sector_buf.get(), defaults::SECTOR_LEN); +template +int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, + diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, + const std::string &label_file, const std::string &universal_label, const uint32_t filter_threshold, + const uint32_t Lf) +{ + std::stringstream parser; + parser << std::string(indexBuildParameters); + std::string cur_param; + std::vector param_list; + while (parser >> cur_param) + { + param_list.push_back(cur_param); } - } else { // Write multi-sector nodes - uint64_t 
nsectors_per_node = - DIV_ROUND_UP(max_node_len, defaults::SECTOR_LEN); - for (uint64_t i = 0; i < npts_64; i++) { - if ((i * nsectors_per_node) % 100000 == 0) { - diskann::cout << "Sector #" << i * nsectors_per_node << "written" + if (param_list.size() < 5 || param_list.size() > 9) + { + diskann::cout << "Correct usage of parameters is R (max degree)\n" + "L (indexing list size, better if >= R)\n" + "B (RAM limit of final index in GB)\n" + "M (memory limit while indexing)\n" + "T (number of threads for indexing)\n" + "B' (PQ bytes for disk index: optional parameter for " + "very large dimensional data)\n" + "reorder (set true to include full precision in data file" + ": optional paramter, use only when using disk PQ\n" + "build_PQ_byte (number of PQ bytes for inde build; set 0 to use " + "full precision vectors)\n" + "QD Quantized Dimension to overwrite the derived dim from B " << std::endl; - } - memset(multisector_buf.get(), 0, - nsectors_per_node * defaults::SECTOR_LEN); - - memset(node_buf.get(), 0, max_node_len); - // read cur node's nnbrs - vamana_reader.read((char *)&nnbrs, sizeof(uint32_t)); - - // sanity checks on nnbrs - assert(nnbrs > 0); - assert(nnbrs <= width_u32); - - // read node's nhood - vamana_reader.read((char *)nhood_buf, - (std::min)(nnbrs, width_u32) * sizeof(uint32_t)); - if (nnbrs > width_u32) { - vamana_reader.seekg((nnbrs - width_u32) * sizeof(uint32_t), - vamana_reader.cur); - } - - // write coords of node first - // T *node_coords = data + ((uint64_t) ndims_64 * cur_node_id); - base_reader.read((char *)cur_node_coords.get(), sizeof(T) * ndims_64); - memcpy(multisector_buf.get(), cur_node_coords.get(), - ndims_64 * sizeof(T)); - - // write nnbrs - *(uint32_t *)(multisector_buf.get() + ndims_64 * sizeof(T)) = - (std::min)(nnbrs, width_u32); - - // write nhood next - memcpy(multisector_buf.get() + ndims_64 * sizeof(T) + sizeof(uint32_t), - nhood_buf, (std::min)(nnbrs, width_u32) * sizeof(uint32_t)); - - // flush sector to disk - diskann_writer.write(multisector_buf.get(), - nsectors_per_node * defaults::SECTOR_LEN); + return -1; } - } - if (append_reorder_data) { - diskann::cout << "Index written. Appending reorder data..." << std::endl; + if (!std::is_same::value && + (compareMetric == diskann::Metric::INNER_PRODUCT || compareMetric == diskann::Metric::COSINE)) + { + std::stringstream stream; + stream << "Disk-index build currently only supports floating point data " + "for Max " + "Inner Product Search/ cosine similarity. " + << std::endl; + throw diskann::ANNException(stream.str(), -1); + } - auto vec_len = ndims_reorder_file * sizeof(float); - std::unique_ptr vec_buf = std::make_unique(vec_len); + size_t disk_pq_dims = 0; + bool use_disk_pq = false; + size_t build_pq_bytes = 0; + + // if there is a 6th parameter, it means we compress the disk index + // vectors also using PQ data (for very large dimensionality data). If the + // provided parameter is 0, it means we store full vectors. 
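+    // For example, with a hypothetical parameter string "64 100 0.3 16 8 32",
+    // param_list[5] is "32", so disk_pq_dims = 32 and the vectors stored in the
+    // disk index are PQ-compressed to 32 bytes each; a value of 0 in that
+    // position (e.g. "64 100 0.3 16 8 0") keeps full-precision vectors, since
+    // use_disk_pq is reset to false below.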
+ if (param_list.size() > 5) + { + disk_pq_dims = atoi(param_list[5].c_str()); + use_disk_pq = true; + if (disk_pq_dims == 0) + use_disk_pq = false; + } - for (uint64_t sector = 0; sector < n_reorder_sectors; sector++) { - if (sector % 100000 == 0) { - diskann::cout << "Reorder data Sector #" << sector << "written" - << std::endl; - } - - memset(sector_buf.get(), 0, defaults::SECTOR_LEN); - - for (uint64_t sector_node_id = 0; - sector_node_id < n_data_nodes_per_sector && sector_node_id < npts_64; - sector_node_id++) { - memset(vec_buf.get(), 0, vec_len); - reorder_data_reader.read(vec_buf.get(), vec_len); - - // copy node buf into sector_node_buf - memcpy(sector_buf.get() + (sector_node_id * vec_len), vec_buf.get(), - vec_len); - } - // flush sector to disk - diskann_writer.write(sector_buf.get(), defaults::SECTOR_LEN); + bool reorder_data = false; + if (param_list.size() >= 7) + { + if (1 == atoi(param_list[6].c_str())) + { + reorder_data = true; + } } - } - diskann_writer.close(); - diskann::save_bin(output_file, output_file_meta.data(), - output_file_meta.size(), 1, 0); - diskann::cout << "Output disk index file written to " << output_file - << std::endl; -} -template -int build_disk_index(const char *dataFilePath, const char *indexFilePath, - const char *indexBuildParameters, - diskann::Metric compareMetric, bool use_opq, - const std::string &codebook_prefix, bool use_filters, - const std::string &label_file, - const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf) { - std::stringstream parser; - parser << std::string(indexBuildParameters); - std::string cur_param; - std::vector param_list; - while (parser >> cur_param) { - param_list.push_back(cur_param); - } - if (param_list.size() < 5 || param_list.size() > 9) { - diskann::cout - << "Correct usage of parameters is R (max degree)\n" - "L (indexing list size, better if >= R)\n" - "B (RAM limit of final index in GB)\n" - "M (memory limit while indexing)\n" - "T (number of threads for indexing)\n" - "B' (PQ bytes for disk index: optional parameter for " - "very large dimensional data)\n" - "reorder (set true to include full precision in data file" - ": optional paramter, use only when using disk PQ\n" - "build_PQ_byte (number of PQ bytes for inde build; set 0 to use " - "full precision vectors)\n" - "QD Quantized Dimension to overwrite the derived dim from B " - << std::endl; - return -1; - } - - if (!std::is_same::value && - (compareMetric == diskann::Metric::INNER_PRODUCT || - compareMetric == diskann::Metric::COSINE)) { - std::stringstream stream; - stream << "Disk-index build currently only supports floating point data " - "for Max " - "Inner Product Search/ cosine similarity. " - << std::endl; - throw diskann::ANNException(stream.str(), -1); - } - - size_t disk_pq_dims = 0; - bool use_disk_pq = false; - size_t build_pq_bytes = 0; - - // if there is a 6th parameter, it means we compress the disk index - // vectors also using PQ data (for very large dimensionality data). If the - // provided parameter is 0, it means we store full vectors. 
- if (param_list.size() > 5) { - disk_pq_dims = atoi(param_list[5].c_str()); - use_disk_pq = true; - if (disk_pq_dims == 0) - use_disk_pq = false; - } - - bool reorder_data = false; - if (param_list.size() >= 7) { - if (1 == atoi(param_list[6].c_str())) { - reorder_data = true; + if (param_list.size() >= 8) + { + build_pq_bytes = atoi(param_list[7].c_str()); } - } - - if (param_list.size() >= 8) { - build_pq_bytes = atoi(param_list[7].c_str()); - } - - std::string base_file(dataFilePath); - std::string data_file_to_use = base_file; - std::string labels_file_original = label_file; - std::string index_prefix_path(indexFilePath); - std::string labels_file_to_use = index_prefix_path + "_label_formatted.txt"; - std::string pq_pivots_path_base = codebook_prefix; - std::string pq_pivots_path = file_exists(pq_pivots_path_base) - ? pq_pivots_path_base + "_pq_pivots.bin" - : index_prefix_path + "_pq_pivots.bin"; - std::string pq_compressed_vectors_path = - index_prefix_path + "_pq_compressed.bin"; - std::string mem_index_path = index_prefix_path + "_mem.index"; - std::string disk_index_path = index_prefix_path + "_disk.index"; - std::string medoids_path = disk_index_path + "_medoids.bin"; - std::string centroids_path = disk_index_path + "_centroids.bin"; - - std::string labels_to_medoids_path = - disk_index_path + "_labels_to_medoids.txt"; - std::string mem_labels_file = mem_index_path + "_labels.txt"; - std::string disk_labels_file = disk_index_path + "_labels.txt"; - std::string mem_univ_label_file = mem_index_path + "_universal_label.txt"; - std::string disk_univ_label_file = disk_index_path + "_universal_label.txt"; - std::string disk_labels_int_map_file = disk_index_path + "_labels_map.txt"; - std::string dummy_remap_file = - disk_index_path + - "_dummy_map.txt"; // remap will be used if we break-up points of - // high label-density to create copies - - std::string sample_base_prefix = index_prefix_path + "_sample"; - // optional, used if disk index file must store pq data - std::string disk_pq_pivots_path = - index_prefix_path + "_disk.index_pq_pivots.bin"; - // optional, used if disk index must store pq data - std::string disk_pq_compressed_vectors_path = - index_prefix_path + "_disk.index_pq_compressed.bin"; - std::string prepped_base = - index_prefix_path + - "_prepped_base.bin"; // temp file for storing pre-processed base file for - // cosine/ mips metrics - bool created_temp_file_for_processed_data = false; - - // output a new base file which contains extra dimension with sqrt(1 - - // ||x||^2/M^2) for every x, M is max norm of all points. Extra space on - // disk needed! - if (compareMetric == diskann::Metric::INNER_PRODUCT) { - Timer timer; - std::cout << "Using Inner Product search, so need to pre-process base " - "data into temp file. Please ensure there is additional " - "(n*(d+1)*4) bytes for storing pre-processed base vectors, " - "apart from the interim indices created by DiskANN and the " - "final index." 
- << std::endl; - data_file_to_use = prepped_base; - float max_norm_of_base = - diskann::prepare_base_for_inner_products(base_file, prepped_base); - std::string norm_file = disk_index_path + "_max_base_norm.bin"; - diskann::save_bin(norm_file, &max_norm_of_base, 1, 1); - diskann::cout << timer.elapsed_seconds_for_step( - "preprocessing data for inner product") + + std::string base_file(dataFilePath); + std::string data_file_to_use = base_file; + std::string labels_file_original = label_file; + std::string index_prefix_path(indexFilePath); + std::string labels_file_to_use = index_prefix_path + "_label_formatted.txt"; + std::string pq_pivots_path_base = codebook_prefix; + std::string pq_pivots_path = file_exists(pq_pivots_path_base) ? pq_pivots_path_base + "_pq_pivots.bin" + : index_prefix_path + "_pq_pivots.bin"; + std::string pq_compressed_vectors_path = index_prefix_path + "_pq_compressed.bin"; + std::string mem_index_path = index_prefix_path + "_mem.index"; + std::string disk_index_path = index_prefix_path + "_disk.index"; + std::string medoids_path = disk_index_path + "_medoids.bin"; + std::string centroids_path = disk_index_path + "_centroids.bin"; + + std::string labels_to_medoids_path = disk_index_path + "_labels_to_medoids.txt"; + std::string mem_labels_file = mem_index_path + "_labels.txt"; + std::string disk_labels_file = disk_index_path + "_labels.txt"; + std::string mem_univ_label_file = mem_index_path + "_universal_label.txt"; + std::string disk_univ_label_file = disk_index_path + "_universal_label.txt"; + std::string disk_labels_int_map_file = disk_index_path + "_labels_map.txt"; + std::string dummy_remap_file = disk_index_path + "_dummy_map.txt"; // remap will be used if we break-up points of + // high label-density to create copies + + std::string sample_base_prefix = index_prefix_path + "_sample"; + // optional, used if disk index file must store pq data + std::string disk_pq_pivots_path = index_prefix_path + "_disk.index_pq_pivots.bin"; + // optional, used if disk index must store pq data + std::string disk_pq_compressed_vectors_path = index_prefix_path + "_disk.index_pq_compressed.bin"; + std::string prepped_base = index_prefix_path + "_prepped_base.bin"; // temp file for storing pre-processed base file + // for cosine/ mips metrics + bool created_temp_file_for_processed_data = false; + + // output a new base file which contains extra dimension with sqrt(1 - + // ||x||^2/M^2) for every x, M is max norm of all points. Extra space on + // disk needed! + if (compareMetric == diskann::Metric::INNER_PRODUCT) + { + Timer timer; + std::cout << "Using Inner Product search, so need to pre-process base " + "data into temp file. Please ensure there is additional " + "(n*(d+1)*4) bytes for storing pre-processed base vectors, " + "apart from the interim indices created by DiskANN and the " + "final index." 
+ << std::endl; + data_file_to_use = prepped_base; + float max_norm_of_base = diskann::prepare_base_for_inner_products(base_file, prepped_base); + std::string norm_file = disk_index_path + "_max_base_norm.bin"; + diskann::save_bin(norm_file, &max_norm_of_base, 1, 1); + diskann::cout << timer.elapsed_seconds_for_step("preprocessing data for inner product") << std::endl; + created_temp_file_for_processed_data = true; + } + else if (compareMetric == diskann::Metric::COSINE) + { + Timer timer; + std::cout << "Normalizing data for cosine to temporary file, please ensure " + "there is additional " + "(n*d*4) bytes for storing normalized base vectors, " + "apart from the interim indices created by DiskANN and the " + "final index." << std::endl; - created_temp_file_for_processed_data = true; - } else if (compareMetric == diskann::Metric::COSINE) { + data_file_to_use = prepped_base; + diskann::normalize_data_file(base_file, prepped_base); + diskann::cout << timer.elapsed_seconds_for_step("preprocessing data for cosine") << std::endl; + created_temp_file_for_processed_data = true; + } + + uint32_t R = (uint32_t)atoi(param_list[0].c_str()); + uint32_t L = (uint32_t)atoi(param_list[1].c_str()); + + double final_index_ram_limit = get_memory_budget(param_list[2]); + if (final_index_ram_limit <= 0) + { + std::cerr << "Insufficient memory budget (or string was not in right " + "format). Should be > 0." + << std::endl; + return -1; + } + double indexing_ram_budget = (float)atof(param_list[3].c_str()); + if (indexing_ram_budget <= 0) + { + std::cerr << "Not building index. Please provide more RAM budget" << std::endl; + return -1; + } + uint32_t num_threads = (uint32_t)atoi(param_list[4].c_str()); + + if (num_threads != 0) + { + omp_set_num_threads(num_threads); + mkl_set_num_threads(num_threads); + } + + diskann::cout << "Starting index build: R=" << R << " L=" << L << " Query RAM budget: " << final_index_ram_limit + << " Indexing ram budget: " << indexing_ram_budget << " T: " << num_threads << std::endl; + + auto s = std::chrono::high_resolution_clock::now(); + + // If there is filter support, we break-up points which have too many labels + // into replica dummy points which evenly distribute the filters. The rest + // of index build happens on the augmented base and labels + std::string augmented_data_file, augmented_labels_file; + if (use_filters) + { + convert_labels_string_to_int(labels_file_original, labels_file_to_use, disk_labels_int_map_file, + universal_label); + augmented_data_file = index_prefix_path + "_augmented_data.bin"; + augmented_labels_file = index_prefix_path + "_augmented_labels.txt"; + if (filter_threshold != 0) + { + breakup_dense_points(data_file_to_use, labels_file_to_use, filter_threshold, augmented_data_file, + augmented_labels_file, + dummy_remap_file); // RKNOTE: This has large memory footprint, + // need to make this streaming + data_file_to_use = augmented_data_file; + labels_file_to_use = augmented_labels_file; + } + } + + size_t points_num, dim; + Timer timer; - std::cout << "Normalizing data for cosine to temporary file, please ensure " - "there is additional " - "(n*d*4) bytes for storing normalized base vectors, " - "apart from the interim indices created by DiskANN and the " - "final index." 
- << std::endl; - data_file_to_use = prepped_base; - diskann::normalize_data_file(base_file, prepped_base); - diskann::cout << timer.elapsed_seconds_for_step( - "preprocessing data for cosine") + diskann::get_bin_metadata(data_file_to_use.c_str(), points_num, dim); + const double p_val = ((double)MAX_PQ_TRAINING_SET_SIZE / (double)points_num); + + if (use_disk_pq) + { + generate_disk_quantized_data(data_file_to_use, disk_pq_pivots_path, disk_pq_compressed_vectors_path, + compareMetric, p_val, disk_pq_dims); + } + size_t num_pq_chunks = (size_t)(std::floor)(uint64_t(final_index_ram_limit / points_num)); + + num_pq_chunks = num_pq_chunks <= 0 ? 1 : num_pq_chunks; + num_pq_chunks = num_pq_chunks > dim ? dim : num_pq_chunks; + num_pq_chunks = num_pq_chunks > MAX_PQ_CHUNKS ? MAX_PQ_CHUNKS : num_pq_chunks; + + if (param_list.size() >= 9 && atoi(param_list[8].c_str()) <= MAX_PQ_CHUNKS && atoi(param_list[8].c_str()) > 0) + { + std::cout << "Use quantized dimension (QD) to overwrite derived quantized " + "dimension from search_DRAM_budget (B)" << std::endl; - created_temp_file_for_processed_data = true; - } - - uint32_t R = (uint32_t)atoi(param_list[0].c_str()); - uint32_t L = (uint32_t)atoi(param_list[1].c_str()); - - double final_index_ram_limit = get_memory_budget(param_list[2]); - if (final_index_ram_limit <= 0) { - std::cerr << "Insufficient memory budget (or string was not in right " - "format). Should be > 0." - << std::endl; - return -1; - } - double indexing_ram_budget = (float)atof(param_list[3].c_str()); - if (indexing_ram_budget <= 0) { - std::cerr << "Not building index. Please provide more RAM budget" - << std::endl; - return -1; - } - uint32_t num_threads = (uint32_t)atoi(param_list[4].c_str()); - - if (num_threads != 0) { - omp_set_num_threads(num_threads); - mkl_set_num_threads(num_threads); - } - - diskann::cout << "Starting index build: R=" << R << " L=" << L - << " Query RAM budget: " << final_index_ram_limit - << " Indexing ram budget: " << indexing_ram_budget - << " T: " << num_threads << std::endl; - - auto s = std::chrono::high_resolution_clock::now(); - - // If there is filter support, we break-up points which have too many labels - // into replica dummy points which evenly distribute the filters. 
The rest - // of index build happens on the augmented base and labels - std::string augmented_data_file, augmented_labels_file; - if (use_filters) { - convert_labels_string_to_int(labels_file_original, labels_file_to_use, - disk_labels_int_map_file, universal_label); - augmented_data_file = index_prefix_path + "_augmented_data.bin"; - augmented_labels_file = index_prefix_path + "_augmented_labels.txt"; - if (filter_threshold != 0) { - breakup_dense_points( - data_file_to_use, labels_file_to_use, filter_threshold, - augmented_data_file, augmented_labels_file, - dummy_remap_file); // RKNOTE: This has large memory footprint, - // need to make this streaming - data_file_to_use = augmented_data_file; - labels_file_to_use = augmented_labels_file; + num_pq_chunks = atoi(param_list[8].c_str()); } - } - - size_t points_num, dim; - - Timer timer; - diskann::get_bin_metadata(data_file_to_use.c_str(), points_num, dim); - const double p_val = ((double)MAX_PQ_TRAINING_SET_SIZE / (double)points_num); - - if (use_disk_pq) { - generate_disk_quantized_data(data_file_to_use, disk_pq_pivots_path, - disk_pq_compressed_vectors_path, - compareMetric, p_val, disk_pq_dims); - } - size_t num_pq_chunks = - (size_t)(std::floor)(uint64_t(final_index_ram_limit / points_num)); - - num_pq_chunks = num_pq_chunks <= 0 ? 1 : num_pq_chunks; - num_pq_chunks = num_pq_chunks > dim ? dim : num_pq_chunks; - num_pq_chunks = num_pq_chunks > MAX_PQ_CHUNKS ? MAX_PQ_CHUNKS : num_pq_chunks; - - if (param_list.size() >= 9 && atoi(param_list[8].c_str()) <= MAX_PQ_CHUNKS && - atoi(param_list[8].c_str()) > 0) { - std::cout << "Use quantized dimension (QD) to overwrite derived quantized " - "dimension from search_DRAM_budget (B)" - << std::endl; - num_pq_chunks = atoi(param_list[8].c_str()); - } - - diskann::cout << "Compressing " << dim << "-dimensional data into " - << num_pq_chunks << " bytes per vector." << std::endl; - - generate_quantized_data(data_file_to_use, pq_pivots_path, - pq_compressed_vectors_path, compareMetric, p_val, - num_pq_chunks, use_opq, codebook_prefix); - diskann::cout << timer.elapsed_seconds_for_step("generating quantized data") - << std::endl; + + diskann::cout << "Compressing " << dim << "-dimensional data into " << num_pq_chunks << " bytes per vector." + << std::endl; + + generate_quantized_data(data_file_to_use, pq_pivots_path, pq_compressed_vectors_path, compareMetric, p_val, + num_pq_chunks, use_opq, codebook_prefix); + diskann::cout << timer.elapsed_seconds_for_step("generating quantized data") << std::endl; // Gopal. Splitting diskann_dll into separate DLLs for search and build. // This code should only be available in the "build" DLL. -#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && \ - defined(DISKANN_BUILD) - MallocExtension::instance()->ReleaseFreeMemory(); +#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD) + MallocExtension::instance()->ReleaseFreeMemory(); #endif - // Whether it is cosine or inner product, we still L2 metric due to the - // pre-processing. 
- timer.reset(); - diskann::build_merged_vamana_index( - data_file_to_use.c_str(), diskann::Metric::L2, L, R, p_val, - indexing_ram_budget, mem_index_path, medoids_path, centroids_path, - build_pq_bytes, use_opq, num_threads, use_filters, labels_file_to_use, - labels_to_medoids_path, universal_label, Lf); - diskann::cout << timer.elapsed_seconds_for_step( - "building merged vamana index") - << std::endl; - - timer.reset(); - if (!use_disk_pq) { - diskann::create_disk_layout(data_file_to_use.c_str(), mem_index_path, - disk_index_path); - } else { - if (!reorder_data) - diskann::create_disk_layout(disk_pq_compressed_vectors_path, - mem_index_path, disk_index_path); + // Whether it is cosine or inner product, we still L2 metric due to the + // pre-processing. + timer.reset(); + diskann::build_merged_vamana_index(data_file_to_use.c_str(), diskann::Metric::L2, L, R, p_val, + indexing_ram_budget, mem_index_path, medoids_path, centroids_path, + build_pq_bytes, use_opq, num_threads, use_filters, labels_file_to_use, + labels_to_medoids_path, universal_label, Lf); + diskann::cout << timer.elapsed_seconds_for_step("building merged vamana index") << std::endl; + + timer.reset(); + if (!use_disk_pq) + { + diskann::create_disk_layout(data_file_to_use.c_str(), mem_index_path, disk_index_path); + } else - diskann::create_disk_layout(disk_pq_compressed_vectors_path, - mem_index_path, disk_index_path, - data_file_to_use.c_str()); - } - diskann::cout << timer.elapsed_seconds_for_step("generating disk layout") - << std::endl; - - double ten_percent_points = std::ceil(points_num * 0.1); - double num_sample_points = ten_percent_points > MAX_SAMPLE_POINTS_FOR_WARMUP - ? MAX_SAMPLE_POINTS_FOR_WARMUP - : ten_percent_points; - double sample_sampling_rate = num_sample_points / points_num; - gen_random_slice(data_file_to_use.c_str(), sample_base_prefix, - sample_sampling_rate); - if (use_filters) { - copy_file(labels_file_to_use, disk_labels_file); - std::remove(mem_labels_file.c_str()); - if (universal_label != "") { - copy_file(mem_univ_label_file, disk_univ_label_file); - std::remove(mem_univ_label_file.c_str()); + { + if (!reorder_data) + diskann::create_disk_layout(disk_pq_compressed_vectors_path, mem_index_path, disk_index_path); + else + diskann::create_disk_layout(disk_pq_compressed_vectors_path, mem_index_path, disk_index_path, + data_file_to_use.c_str()); + } + diskann::cout << timer.elapsed_seconds_for_step("generating disk layout") << std::endl; + + double ten_percent_points = std::ceil(points_num * 0.1); + double num_sample_points = + ten_percent_points > MAX_SAMPLE_POINTS_FOR_WARMUP ? 
MAX_SAMPLE_POINTS_FOR_WARMUP : ten_percent_points; + double sample_sampling_rate = num_sample_points / points_num; + gen_random_slice(data_file_to_use.c_str(), sample_base_prefix, sample_sampling_rate); + if (use_filters) + { + copy_file(labels_file_to_use, disk_labels_file); + std::remove(mem_labels_file.c_str()); + if (universal_label != "") + { + copy_file(mem_univ_label_file, disk_univ_label_file); + std::remove(mem_univ_label_file.c_str()); + } + std::remove(augmented_data_file.c_str()); + std::remove(augmented_labels_file.c_str()); + std::remove(labels_file_to_use.c_str()); } - std::remove(augmented_data_file.c_str()); - std::remove(augmented_labels_file.c_str()); - std::remove(labels_file_to_use.c_str()); - } - if (created_temp_file_for_processed_data) - std::remove(prepped_base.c_str()); - std::remove(mem_index_path.c_str()); - std::remove((mem_index_path + ".data").c_str()); - std::remove((mem_index_path + ".tags").c_str()); - if (use_disk_pq) - std::remove(disk_pq_compressed_vectors_path.c_str()); - - auto e = std::chrono::high_resolution_clock::now(); - std::chrono::duration diff = e - s; - diskann::cout << "Indexing time: " << diff.count() << std::endl; - - return 0; + if (created_temp_file_for_processed_data) + std::remove(prepped_base.c_str()); + std::remove(mem_index_path.c_str()); + std::remove((mem_index_path + ".data").c_str()); + std::remove((mem_index_path + ".tags").c_str()); + if (use_disk_pq) + std::remove(disk_pq_compressed_vectors_path.c_str()); + + auto e = std::chrono::high_resolution_clock::now(); + std::chrono::duration diff = e - s; + diskann::cout << "Indexing time: " << diff.count() << std::endl; + + return 0; } -template DISKANN_DLLEXPORT void create_disk_layout( - const std::string base_file, const std::string mem_index_file, - const std::string output_file, const std::string reorder_data_file); -template DISKANN_DLLEXPORT void create_disk_layout( - const std::string base_file, const std::string mem_index_file, - const std::string output_file, const std::string reorder_data_file); -template DISKANN_DLLEXPORT void create_disk_layout( - const std::string base_file, const std::string mem_index_file, - const std::string output_file, const std::string reorder_data_file); - -template DISKANN_DLLEXPORT int8_t * -load_warmup(const std::string &cache_warmup_file, uint64_t &warmup_num, - uint64_t warmup_dim, uint64_t warmup_aligned_dim); -template DISKANN_DLLEXPORT uint8_t * -load_warmup(const std::string &cache_warmup_file, uint64_t &warmup_num, - uint64_t warmup_dim, uint64_t warmup_aligned_dim); -template DISKANN_DLLEXPORT float * -load_warmup(const std::string &cache_warmup_file, uint64_t &warmup_num, - uint64_t warmup_dim, uint64_t warmup_aligned_dim); +template DISKANN_DLLEXPORT void create_disk_layout(const std::string base_file, + const std::string mem_index_file, + const std::string output_file, + const std::string reorder_data_file); +template DISKANN_DLLEXPORT void create_disk_layout(const std::string base_file, + const std::string mem_index_file, + const std::string output_file, + const std::string reorder_data_file); +template DISKANN_DLLEXPORT void create_disk_layout(const std::string base_file, const std::string mem_index_file, + const std::string output_file, + const std::string reorder_data_file); + +template DISKANN_DLLEXPORT int8_t *load_warmup(const std::string &cache_warmup_file, uint64_t &warmup_num, + uint64_t warmup_dim, uint64_t warmup_aligned_dim); +template DISKANN_DLLEXPORT uint8_t *load_warmup(const std::string &cache_warmup_file, 
uint64_t &warmup_num, + uint64_t warmup_dim, uint64_t warmup_aligned_dim); +template DISKANN_DLLEXPORT float *load_warmup(const std::string &cache_warmup_file, uint64_t &warmup_num, + uint64_t warmup_dim, uint64_t warmup_aligned_dim); #ifdef EXEC_ENV_OLS -template DISKANN_DLLEXPORT int8_t * -load_warmup(MemoryMappedFiles &files, - const std::string &cache_warmup_file, uint64_t &warmup_num, - uint64_t warmup_dim, uint64_t warmup_aligned_dim); -template DISKANN_DLLEXPORT uint8_t * -load_warmup(MemoryMappedFiles &files, - const std::string &cache_warmup_file, uint64_t &warmup_num, - uint64_t warmup_dim, uint64_t warmup_aligned_dim); -template DISKANN_DLLEXPORT float * -load_warmup(MemoryMappedFiles &files, - const std::string &cache_warmup_file, uint64_t &warmup_num, - uint64_t warmup_dim, uint64_t warmup_aligned_dim); +template DISKANN_DLLEXPORT int8_t *load_warmup(MemoryMappedFiles &files, const std::string &cache_warmup_file, + uint64_t &warmup_num, uint64_t warmup_dim, + uint64_t warmup_aligned_dim); +template DISKANN_DLLEXPORT uint8_t *load_warmup(MemoryMappedFiles &files, const std::string &cache_warmup_file, + uint64_t &warmup_num, uint64_t warmup_dim, + uint64_t warmup_aligned_dim); +template DISKANN_DLLEXPORT float *load_warmup(MemoryMappedFiles &files, const std::string &cache_warmup_file, + uint64_t &warmup_num, uint64_t warmup_dim, + uint64_t warmup_aligned_dim); #endif template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( - std::unique_ptr> &pFlashIndex, - int8_t *tuning_sample, uint64_t tuning_sample_num, - uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, - uint32_t start_bw); + std::unique_ptr> &pFlashIndex, int8_t *tuning_sample, + uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( - std::unique_ptr> &pFlashIndex, - uint8_t *tuning_sample, uint64_t tuning_sample_num, - uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, - uint32_t start_bw); + std::unique_ptr> &pFlashIndex, uint8_t *tuning_sample, + uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( - std::unique_ptr> &pFlashIndex, - float *tuning_sample, uint64_t tuning_sample_num, - uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, - uint32_t start_bw); + std::unique_ptr> &pFlashIndex, float *tuning_sample, + uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( - std::unique_ptr> &pFlashIndex, - int8_t *tuning_sample, uint64_t tuning_sample_num, - uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, - uint32_t start_bw); + std::unique_ptr> &pFlashIndex, int8_t *tuning_sample, + uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( - std::unique_ptr> &pFlashIndex, - uint8_t *tuning_sample, uint64_t tuning_sample_num, - uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, - uint32_t start_bw); + std::unique_ptr> &pFlashIndex, uint8_t *tuning_sample, + uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( - std::unique_ptr> &pFlashIndex, - float *tuning_sample, 
uint64_t tuning_sample_num, - uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, - uint32_t start_bw); - -template DISKANN_DLLEXPORT int build_disk_index( - const char *dataFilePath, const char *indexFilePath, - const char *indexBuildParameters, diskann::Metric compareMetric, - bool use_opq, const std::string &codebook_prefix, bool use_filters, - const std::string &label_file, const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf); -template DISKANN_DLLEXPORT int build_disk_index( - const char *dataFilePath, const char *indexFilePath, - const char *indexBuildParameters, diskann::Metric compareMetric, - bool use_opq, const std::string &codebook_prefix, bool use_filters, - const std::string &label_file, const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf); -template DISKANN_DLLEXPORT int build_disk_index( - const char *dataFilePath, const char *indexFilePath, - const char *indexBuildParameters, diskann::Metric compareMetric, - bool use_opq, const std::string &codebook_prefix, bool use_filters, - const std::string &label_file, const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf); + std::unique_ptr> &pFlashIndex, float *tuning_sample, + uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); + +template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, + const char *indexBuildParameters, + diskann::Metric compareMetric, bool use_opq, + const std::string &codebook_prefix, bool use_filters, + const std::string &label_file, + const std::string &universal_label, + const uint32_t filter_threshold, const uint32_t Lf); +template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, + const char *indexBuildParameters, + diskann::Metric compareMetric, bool use_opq, + const std::string &codebook_prefix, bool use_filters, + const std::string &label_file, + const std::string &universal_label, + const uint32_t filter_threshold, const uint32_t Lf); +template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, + const char *indexBuildParameters, + diskann::Metric compareMetric, bool use_opq, + const std::string &codebook_prefix, bool use_filters, + const std::string &label_file, + const std::string &universal_label, + const uint32_t filter_threshold, const uint32_t Lf); // LabelT = uint16 -template DISKANN_DLLEXPORT int build_disk_index( - const char *dataFilePath, const char *indexFilePath, - const char *indexBuildParameters, diskann::Metric compareMetric, - bool use_opq, const std::string &codebook_prefix, bool use_filters, - const std::string &label_file, const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf); -template DISKANN_DLLEXPORT int build_disk_index( - const char *dataFilePath, const char *indexFilePath, - const char *indexBuildParameters, diskann::Metric compareMetric, - bool use_opq, const std::string &codebook_prefix, bool use_filters, - const std::string &label_file, const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf); -template DISKANN_DLLEXPORT int build_disk_index( - const char *dataFilePath, const char *indexFilePath, - const char *indexBuildParameters, diskann::Metric compareMetric, - bool use_opq, const std::string &codebook_prefix, bool use_filters, - const std::string &label_file, const std::string 
&universal_label, - const uint32_t filter_threshold, const uint32_t Lf); +template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, + const char *indexBuildParameters, + diskann::Metric compareMetric, bool use_opq, + const std::string &codebook_prefix, bool use_filters, + const std::string &label_file, + const std::string &universal_label, + const uint32_t filter_threshold, const uint32_t Lf); +template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, + const char *indexBuildParameters, + diskann::Metric compareMetric, bool use_opq, + const std::string &codebook_prefix, bool use_filters, + const std::string &label_file, + const std::string &universal_label, + const uint32_t filter_threshold, const uint32_t Lf); +template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, + const char *indexBuildParameters, + diskann::Metric compareMetric, bool use_opq, + const std::string &codebook_prefix, bool use_filters, + const std::string &label_file, + const std::string &universal_label, + const uint32_t filter_threshold, const uint32_t Lf); template DISKANN_DLLEXPORT int build_merged_vamana_index( - std::string base_file, diskann::Metric compareMetric, uint32_t L, - uint32_t R, double sampling_rate, double ram_budget, - std::string mem_index_path, std::string medoids_path, - std::string centroids_file, size_t build_pq_bytes, bool use_opq, - uint32_t num_threads, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, - const std::string &universal_label, const uint32_t Lf); + std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, + double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, + size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); template DISKANN_DLLEXPORT int build_merged_vamana_index( - std::string base_file, diskann::Metric compareMetric, uint32_t L, - uint32_t R, double sampling_rate, double ram_budget, - std::string mem_index_path, std::string medoids_path, - std::string centroids_file, size_t build_pq_bytes, bool use_opq, - uint32_t num_threads, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, - const std::string &universal_label, const uint32_t Lf); + std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, + double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, + size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); template DISKANN_DLLEXPORT int build_merged_vamana_index( - std::string base_file, diskann::Metric compareMetric, uint32_t L, - uint32_t R, double sampling_rate, double ram_budget, - std::string mem_index_path, std::string medoids_path, - std::string centroids_file, size_t build_pq_bytes, bool use_opq, - uint32_t num_threads, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, - const std::string &universal_label, const uint32_t Lf); + std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, + 
double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, + size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); // Label=16_t template DISKANN_DLLEXPORT int build_merged_vamana_index( - std::string base_file, diskann::Metric compareMetric, uint32_t L, - uint32_t R, double sampling_rate, double ram_budget, - std::string mem_index_path, std::string medoids_path, - std::string centroids_file, size_t build_pq_bytes, bool use_opq, - uint32_t num_threads, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, - const std::string &universal_label, const uint32_t Lf); + std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, + double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, + size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); template DISKANN_DLLEXPORT int build_merged_vamana_index( - std::string base_file, diskann::Metric compareMetric, uint32_t L, - uint32_t R, double sampling_rate, double ram_budget, - std::string mem_index_path, std::string medoids_path, - std::string centroids_file, size_t build_pq_bytes, bool use_opq, - uint32_t num_threads, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, - const std::string &universal_label, const uint32_t Lf); + std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, + double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, + size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); template DISKANN_DLLEXPORT int build_merged_vamana_index( - std::string base_file, diskann::Metric compareMetric, uint32_t L, - uint32_t R, double sampling_rate, double ram_budget, - std::string mem_index_path, std::string medoids_path, - std::string centroids_file, size_t build_pq_bytes, bool use_opq, - uint32_t num_threads, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, - const std::string &universal_label, const uint32_t Lf); + std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, + double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, + size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); }; // namespace diskann diff --git a/src/distance.cpp b/src/distance.cpp index fc4e43a75..957453ab8 100644 --- a/src/distance.cpp +++ b/src/distance.cpp @@ -19,662 +19,699 @@ #include "logger.h" #include "utils.h" -namespace diskann { +namespace diskann +{ // // Base Class Implementatons // template -float Distance::compare(const T *a, const T *b, const float normA, - const float normB, uint32_t length) const { - throw std::logic_error("This function is not implemented."); +float Distance::compare(const T *a, const T *b, const float 
normA, const float normB, uint32_t length) const +{ + throw std::logic_error("This function is not implemented."); } -template -uint32_t -Distance::post_normalization_dimension(uint32_t orig_dimension) const { - return orig_dimension; +template uint32_t Distance::post_normalization_dimension(uint32_t orig_dimension) const +{ + return orig_dimension; } -template diskann::Metric Distance::get_metric() const { - return _distance_metric; +template diskann::Metric Distance::get_metric() const +{ + return _distance_metric; } -template bool Distance::preprocessing_required() const { - return false; +template bool Distance::preprocessing_required() const +{ + return false; } template -void Distance::preprocess_base_points(T *original_data, - const size_t orig_dim, - const size_t num_points) {} +void Distance::preprocess_base_points(T *original_data, const size_t orig_dim, const size_t num_points) +{ +} -template -void Distance::preprocess_query(const T *query_vec, const size_t query_dim, - T *scratch_query) { - std::memcpy(scratch_query, query_vec, query_dim * sizeof(T)); +template void Distance::preprocess_query(const T *query_vec, const size_t query_dim, T *scratch_query) +{ + std::memcpy(scratch_query, query_vec, query_dim * sizeof(T)); } -template size_t Distance::get_required_alignment() const { - return _alignment_factor; +template size_t Distance::get_required_alignment() const +{ + return _alignment_factor; } // // Cosine distance functions. // -float DistanceCosineInt8::compare(const int8_t *a, const int8_t *b, - uint32_t length) const { +float DistanceCosineInt8::compare(const int8_t *a, const int8_t *b, uint32_t length) const +{ #ifdef _WINDOWS - return diskann::CosineSimilarity2(a, b, length); + return diskann::CosineSimilarity2(a, b, length); #else - int magA = 0, magB = 0, scalarProduct = 0; - for (uint32_t i = 0; i < length; i++) { - magA += ((int32_t)a[i]) * ((int32_t)a[i]); - magB += ((int32_t)b[i]) * ((int32_t)b[i]); - scalarProduct += ((int32_t)a[i]) * ((int32_t)b[i]); - } - // similarity == 1-cosine distance - return 1.0f - (float)(scalarProduct / (sqrt(magA) * sqrt(magB))); + int magA = 0, magB = 0, scalarProduct = 0; + for (uint32_t i = 0; i < length; i++) + { + magA += ((int32_t)a[i]) * ((int32_t)a[i]); + magB += ((int32_t)b[i]) * ((int32_t)b[i]); + scalarProduct += ((int32_t)a[i]) * ((int32_t)b[i]); + } + // similarity == 1-cosine distance + return 1.0f - (float)(scalarProduct / (sqrt(magA) * sqrt(magB))); #endif } -float DistanceCosineFloat::compare(const float *a, const float *b, - uint32_t length) const { +float DistanceCosineFloat::compare(const float *a, const float *b, uint32_t length) const +{ #ifdef _WINDOWS - return diskann::CosineSimilarity2(a, b, length); + return diskann::CosineSimilarity2(a, b, length); #else - float magA = 0, magB = 0, scalarProduct = 0; - for (uint32_t i = 0; i < length; i++) { - magA += (a[i]) * (a[i]); - magB += (b[i]) * (b[i]); - scalarProduct += (a[i]) * (b[i]); - } - // similarity == 1-cosine distance - return 1.0f - (scalarProduct / (sqrt(magA) * sqrt(magB))); + float magA = 0, magB = 0, scalarProduct = 0; + for (uint32_t i = 0; i < length; i++) + { + magA += (a[i]) * (a[i]); + magB += (b[i]) * (b[i]); + scalarProduct += (a[i]) * (b[i]); + } + // similarity == 1-cosine distance + return 1.0f - (scalarProduct / (sqrt(magA) * sqrt(magB))); #endif } -float SlowDistanceCosineUInt8::compare(const uint8_t *a, const uint8_t *b, - uint32_t length) const { - int magA = 0, magB = 0, scalarProduct = 0; - for (uint32_t i = 0; i < length; i++) { - 
magA += ((uint32_t)a[i]) * ((uint32_t)a[i]); - magB += ((uint32_t)b[i]) * ((uint32_t)b[i]); - scalarProduct += ((uint32_t)a[i]) * ((uint32_t)b[i]); - } - // similarity == 1-cosine distance - return 1.0f - (float)(scalarProduct / (sqrt(magA) * sqrt(magB))); +float SlowDistanceCosineUInt8::compare(const uint8_t *a, const uint8_t *b, uint32_t length) const +{ + int magA = 0, magB = 0, scalarProduct = 0; + for (uint32_t i = 0; i < length; i++) + { + magA += ((uint32_t)a[i]) * ((uint32_t)a[i]); + magB += ((uint32_t)b[i]) * ((uint32_t)b[i]); + scalarProduct += ((uint32_t)a[i]) * ((uint32_t)b[i]); + } + // similarity == 1-cosine distance + return 1.0f - (float)(scalarProduct / (sqrt(magA) * sqrt(magB))); } // // L2 distance functions. // -float DistanceL2Int8::compare(const int8_t *a, const int8_t *b, - uint32_t size) const { +float DistanceL2Int8::compare(const int8_t *a, const int8_t *b, uint32_t size) const +{ #ifdef _WINDOWS #ifdef USE_AVX2 - __m256 r = _mm256_setzero_ps(); - char *pX = (char *)a, *pY = (char *)b; - while (size >= 32) { - __m256i r1 = _mm256_subs_epi8(_mm256_loadu_si256((__m256i *)pX), - _mm256_loadu_si256((__m256i *)pY)); - r = _mm256_add_ps(r, _mm256_mul_epi8(r1, r1)); - pX += 32; - pY += 32; - size -= 32; - } - while (size > 0) { - __m128i r2 = _mm_subs_epi8(_mm_loadu_si128((__m128i *)pX), - _mm_loadu_si128((__m128i *)pY)); - r = _mm256_add_ps(r, _mm256_mul32_pi8(r2, r2)); - pX += 4; - pY += 4; - size -= 4; - } - r = _mm256_hadd_ps(_mm256_hadd_ps(r, r), r); - return r.m256_f32[0] + r.m256_f32[4]; + __m256 r = _mm256_setzero_ps(); + char *pX = (char *)a, *pY = (char *)b; + while (size >= 32) + { + __m256i r1 = _mm256_subs_epi8(_mm256_loadu_si256((__m256i *)pX), _mm256_loadu_si256((__m256i *)pY)); + r = _mm256_add_ps(r, _mm256_mul_epi8(r1, r1)); + pX += 32; + pY += 32; + size -= 32; + } + while (size > 0) + { + __m128i r2 = _mm_subs_epi8(_mm_loadu_si128((__m128i *)pX), _mm_loadu_si128((__m128i *)pY)); + r = _mm256_add_ps(r, _mm256_mul32_pi8(r2, r2)); + pX += 4; + pY += 4; + size -= 4; + } + r = _mm256_hadd_ps(_mm256_hadd_ps(r, r), r); + return r.m256_f32[0] + r.m256_f32[4]; #else - int32_t result = 0; + int32_t result = 0; #pragma omp simd reduction(+ : result) aligned(a, b : 8) - for (int32_t i = 0; i < (int32_t)size; i++) { - result += ((int32_t)((int16_t)a[i] - (int16_t)b[i])) * - ((int32_t)((int16_t)a[i] - (int16_t)b[i])); - } - return (float)result; + for (int32_t i = 0; i < (int32_t)size; i++) + { + result += ((int32_t)((int16_t)a[i] - (int16_t)b[i])) * ((int32_t)((int16_t)a[i] - (int16_t)b[i])); + } + return (float)result; #endif #else - int32_t result = 0; + int32_t result = 0; #pragma omp simd reduction(+ : result) aligned(a, b : 8) - for (int32_t i = 0; i < (int32_t)size; i++) { - result += ((int32_t)((int16_t)a[i] - (int16_t)b[i])) * - ((int32_t)((int16_t)a[i] - (int16_t)b[i])); - } - return (float)result; + for (int32_t i = 0; i < (int32_t)size; i++) + { + result += ((int32_t)((int16_t)a[i] - (int16_t)b[i])) * ((int32_t)((int16_t)a[i] - (int16_t)b[i])); + } + return (float)result; #endif } -float DistanceL2UInt8::compare(const uint8_t *a, const uint8_t *b, - uint32_t size) const { - uint32_t result = 0; +float DistanceL2UInt8::compare(const uint8_t *a, const uint8_t *b, uint32_t size) const +{ + uint32_t result = 0; #ifndef _WINDOWS #pragma omp simd reduction(+ : result) aligned(a, b : 8) #endif - for (int32_t i = 0; i < (int32_t)size; i++) { - result += ((int32_t)((int16_t)a[i] - (int16_t)b[i])) * - ((int32_t)((int16_t)a[i] - (int16_t)b[i])); - } - return 
(float)result; + for (int32_t i = 0; i < (int32_t)size; i++) + { + result += ((int32_t)((int16_t)a[i] - (int16_t)b[i])) * ((int32_t)((int16_t)a[i] - (int16_t)b[i])); + } + return (float)result; } #ifndef _WINDOWS -float DistanceL2Float::compare(const float *a, const float *b, - uint32_t size) const { - a = (const float *)__builtin_assume_aligned(a, 32); - b = (const float *)__builtin_assume_aligned(b, 32); +float DistanceL2Float::compare(const float *a, const float *b, uint32_t size) const +{ + a = (const float *)__builtin_assume_aligned(a, 32); + b = (const float *)__builtin_assume_aligned(b, 32); #else -float DistanceL2Float::compare(const float *a, const float *b, - uint32_t size) const { +float DistanceL2Float::compare(const float *a, const float *b, uint32_t size) const +{ #endif - float result = 0; + float result = 0; #ifdef USE_AVX2 - // assume size is divisible by 8 - uint16_t niters = (uint16_t)(size / 8); - __m256 sum = _mm256_setzero_ps(); - for (uint16_t j = 0; j < niters; j++) { - // scope is a[8j:8j+7], b[8j:8j+7] - // load a_vec - if (j < (niters - 1)) { - _mm_prefetch((char *)(a + 8 * (j + 1)), _MM_HINT_T0); - _mm_prefetch((char *)(b + 8 * (j + 1)), _MM_HINT_T0); - } - __m256 a_vec = _mm256_load_ps(a + 8 * j); - // load b_vec - __m256 b_vec = _mm256_load_ps(b + 8 * j); - // a_vec - b_vec - __m256 tmp_vec = _mm256_sub_ps(a_vec, b_vec); - - sum = _mm256_fmadd_ps(tmp_vec, tmp_vec, sum); - } - - // horizontal add sum - result = _mm256_reduce_add_ps(sum); + // assume size is divisible by 8 + uint16_t niters = (uint16_t)(size / 8); + __m256 sum = _mm256_setzero_ps(); + for (uint16_t j = 0; j < niters; j++) + { + // scope is a[8j:8j+7], b[8j:8j+7] + // load a_vec + if (j < (niters - 1)) + { + _mm_prefetch((char *)(a + 8 * (j + 1)), _MM_HINT_T0); + _mm_prefetch((char *)(b + 8 * (j + 1)), _MM_HINT_T0); + } + __m256 a_vec = _mm256_load_ps(a + 8 * j); + // load b_vec + __m256 b_vec = _mm256_load_ps(b + 8 * j); + // a_vec - b_vec + __m256 tmp_vec = _mm256_sub_ps(a_vec, b_vec); + + sum = _mm256_fmadd_ps(tmp_vec, tmp_vec, sum); + } + + // horizontal add sum + result = _mm256_reduce_add_ps(sum); #else #ifndef _WINDOWS #pragma omp simd reduction(+ : result) aligned(a, b : 32) #endif - for (int32_t i = 0; i < (int32_t)size; i++) { - result += (a[i] - b[i]) * (a[i] - b[i]); - } + for (int32_t i = 0; i < (int32_t)size; i++) + { + result += (a[i] - b[i]) * (a[i] - b[i]); + } #endif - return result; + return result; } -template -float SlowDistanceL2::compare(const T *a, const T *b, - uint32_t length) const { - float result = 0.0f; - for (uint32_t i = 0; i < length; i++) { - result += ((float)(a[i] - b[i])) * (a[i] - b[i]); - } - return result; +template float SlowDistanceL2::compare(const T *a, const T *b, uint32_t length) const +{ + float result = 0.0f; + for (uint32_t i = 0; i < length; i++) + { + result += ((float)(a[i] - b[i])) * (a[i] - b[i]); + } + return result; } #ifdef _WINDOWS -float AVXDistanceL2Int8::compare(const int8_t *a, const int8_t *b, - uint32_t length) const { - __m128 r = _mm_setzero_ps(); - __m128i r1; - while (length >= 16) { - r1 = _mm_subs_epi8(_mm_load_si128((__m128i *)a), - _mm_load_si128((__m128i *)b)); - r = _mm_add_ps(r, _mm_mul_epi8(r1)); - a += 16; - b += 16; - length -= 16; - } - r = _mm_hadd_ps(_mm_hadd_ps(r, r), r); - float res = r.m128_f32[0]; - - if (length >= 8) { - __m128 r2 = _mm_setzero_ps(); - __m128i r3 = _mm_subs_epi8(_mm_load_si128((__m128i *)(a - 8)), - _mm_load_si128((__m128i *)(b - 8))); - r2 = _mm_add_ps(r2, _mm_mulhi_epi8(r3)); - a += 8; - b += 
8; - length -= 8; - r2 = _mm_hadd_ps(_mm_hadd_ps(r2, r2), r2); - res += r2.m128_f32[0]; - } - - if (length >= 4) { - __m128 r2 = _mm_setzero_ps(); - __m128i r3 = _mm_subs_epi8(_mm_load_si128((__m128i *)(a - 12)), - _mm_load_si128((__m128i *)(b - 12))); - r2 = _mm_add_ps(r2, _mm_mulhi_epi8_shift32(r3)); - res += r2.m128_f32[0] + r2.m128_f32[1]; - } - - return res; +float AVXDistanceL2Int8::compare(const int8_t *a, const int8_t *b, uint32_t length) const +{ + __m128 r = _mm_setzero_ps(); + __m128i r1; + while (length >= 16) + { + r1 = _mm_subs_epi8(_mm_load_si128((__m128i *)a), _mm_load_si128((__m128i *)b)); + r = _mm_add_ps(r, _mm_mul_epi8(r1)); + a += 16; + b += 16; + length -= 16; + } + r = _mm_hadd_ps(_mm_hadd_ps(r, r), r); + float res = r.m128_f32[0]; + + if (length >= 8) + { + __m128 r2 = _mm_setzero_ps(); + __m128i r3 = _mm_subs_epi8(_mm_load_si128((__m128i *)(a - 8)), _mm_load_si128((__m128i *)(b - 8))); + r2 = _mm_add_ps(r2, _mm_mulhi_epi8(r3)); + a += 8; + b += 8; + length -= 8; + r2 = _mm_hadd_ps(_mm_hadd_ps(r2, r2), r2); + res += r2.m128_f32[0]; + } + + if (length >= 4) + { + __m128 r2 = _mm_setzero_ps(); + __m128i r3 = _mm_subs_epi8(_mm_load_si128((__m128i *)(a - 12)), _mm_load_si128((__m128i *)(b - 12))); + r2 = _mm_add_ps(r2, _mm_mulhi_epi8_shift32(r3)); + res += r2.m128_f32[0] + r2.m128_f32[1]; + } + + return res; } -float AVXDistanceL2Float::compare(const float *a, const float *b, - uint32_t length) const { - __m128 diff, v1, v2; - __m128 sum = _mm_set1_ps(0); - - while (length >= 4) { - v1 = _mm_loadu_ps(a); - a += 4; - v2 = _mm_loadu_ps(b); - b += 4; - diff = _mm_sub_ps(v1, v2); - sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); - length -= 4; - } - - return sum.m128_f32[0] + sum.m128_f32[1] + sum.m128_f32[2] + sum.m128_f32[3]; +float AVXDistanceL2Float::compare(const float *a, const float *b, uint32_t length) const +{ + __m128 diff, v1, v2; + __m128 sum = _mm_set1_ps(0); + + while (length >= 4) + { + v1 = _mm_loadu_ps(a); + a += 4; + v2 = _mm_loadu_ps(b); + b += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + length -= 4; + } + + return sum.m128_f32[0] + sum.m128_f32[1] + sum.m128_f32[2] + sum.m128_f32[3]; } #else -float AVXDistanceL2Int8::compare(const int8_t *, const int8_t *, - uint32_t) const { - return 0; +float AVXDistanceL2Int8::compare(const int8_t *, const int8_t *, uint32_t) const +{ + return 0; } -float AVXDistanceL2Float::compare(const float *, const float *, - uint32_t) const { - return 0; +float AVXDistanceL2Float::compare(const float *, const float *, uint32_t) const +{ + return 0; } #endif -template -float DistanceInnerProduct::inner_product(const T *a, const T *b, - uint32_t size) const { - if (!std::is_floating_point::value) { - diskann::cerr << "ERROR: Inner Product only defined for float currently." - << std::endl; - throw diskann::ANNException( - "ERROR: Inner Product only defined for float currently.", -1, - __FUNCSIG__, __FILE__, __LINE__); - } - - float result = 0; +template float DistanceInnerProduct::inner_product(const T *a, const T *b, uint32_t size) const +{ + if (!std::is_floating_point::value) + { + diskann::cerr << "ERROR: Inner Product only defined for float currently." 
<< std::endl; + throw diskann::ANNException("ERROR: Inner Product only defined for float currently.", -1, __FUNCSIG__, __FILE__, + __LINE__); + } + + float result = 0; #ifdef __GNUC__ #ifdef USE_AVX2 -#define AVX_DOT(addr1, addr2, dest, tmp1, tmp2) \ - tmp1 = _mm256_loadu_ps(addr1); \ - tmp2 = _mm256_loadu_ps(addr2); \ - tmp1 = _mm256_mul_ps(tmp1, tmp2); \ - dest = _mm256_add_ps(dest, tmp1); - - __m256 sum; - __m256 l0, l1; - __m256 r0, r1; - uint32_t D = (size + 7) & ~7U; - uint32_t DR = D % 16; - uint32_t DD = D - DR; - const float *l = (float *)a; - const float *r = (float *)b; - const float *e_l = l + DD; - const float *e_r = r + DD; - float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0}; - - sum = _mm256_loadu_ps(unpack); - if (DR) { - AVX_DOT(e_l, e_r, sum, l0, r0); - } - - for (uint32_t i = 0; i < DD; i += 16, l += 16, r += 16) { - AVX_DOT(l, r, sum, l0, r0); - AVX_DOT(l + 8, r + 8, sum, l1, r1); - } - _mm256_storeu_ps(unpack, sum); - result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + - unpack[5] + unpack[6] + unpack[7]; +#define AVX_DOT(addr1, addr2, dest, tmp1, tmp2) \ + tmp1 = _mm256_loadu_ps(addr1); \ + tmp2 = _mm256_loadu_ps(addr2); \ + tmp1 = _mm256_mul_ps(tmp1, tmp2); \ + dest = _mm256_add_ps(dest, tmp1); + + __m256 sum; + __m256 l0, l1; + __m256 r0, r1; + uint32_t D = (size + 7) & ~7U; + uint32_t DR = D % 16; + uint32_t DD = D - DR; + const float *l = (float *)a; + const float *r = (float *)b; + const float *e_l = l + DD; + const float *e_r = r + DD; + float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0}; + + sum = _mm256_loadu_ps(unpack); + if (DR) + { + AVX_DOT(e_l, e_r, sum, l0, r0); + } + + for (uint32_t i = 0; i < DD; i += 16, l += 16, r += 16) + { + AVX_DOT(l, r, sum, l0, r0); + AVX_DOT(l + 8, r + 8, sum, l1, r1); + } + _mm256_storeu_ps(unpack, sum); + result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + unpack[5] + unpack[6] + unpack[7]; #else #ifdef __SSE2__ -#define SSE_DOT(addr1, addr2, dest, tmp1, tmp2) \ - tmp1 = _mm128_loadu_ps(addr1); \ - tmp2 = _mm128_loadu_ps(addr2); \ - tmp1 = _mm128_mul_ps(tmp1, tmp2); \ - dest = _mm128_add_ps(dest, tmp1); - __m128 sum; - __m128 l0, l1, l2, l3; - __m128 r0, r1, r2, r3; - uint32_t D = (size + 3) & ~3U; - uint32_t DR = D % 16; - uint32_t DD = D - DR; - const float *l = a; - const float *r = b; - const float *e_l = l + DD; - const float *e_r = r + DD; - float unpack[4] __attribute__((aligned(16))) = {0, 0, 0, 0}; - - sum = _mm_load_ps(unpack); - switch (DR) { - case 12: - SSE_DOT(e_l + 8, e_r + 8, sum, l2, r2); - case 8: - SSE_DOT(e_l + 4, e_r + 4, sum, l1, r1); - case 4: - SSE_DOT(e_l, e_r, sum, l0, r0); - default: - break; - } - for (uint32_t i = 0; i < DD; i += 16, l += 16, r += 16) { - SSE_DOT(l, r, sum, l0, r0); - SSE_DOT(l + 4, r + 4, sum, l1, r1); - SSE_DOT(l + 8, r + 8, sum, l2, r2); - SSE_DOT(l + 12, r + 12, sum, l3, r3); - } - _mm_storeu_ps(unpack, sum); - result += unpack[0] + unpack[1] + unpack[2] + unpack[3]; +#define SSE_DOT(addr1, addr2, dest, tmp1, tmp2) \ + tmp1 = _mm128_loadu_ps(addr1); \ + tmp2 = _mm128_loadu_ps(addr2); \ + tmp1 = _mm128_mul_ps(tmp1, tmp2); \ + dest = _mm128_add_ps(dest, tmp1); + __m128 sum; + __m128 l0, l1, l2, l3; + __m128 r0, r1, r2, r3; + uint32_t D = (size + 3) & ~3U; + uint32_t DR = D % 16; + uint32_t DD = D - DR; + const float *l = a; + const float *r = b; + const float *e_l = l + DD; + const float *e_r = r + DD; + float unpack[4] __attribute__((aligned(16))) = {0, 0, 0, 0}; + + sum = _mm_load_ps(unpack); + switch 
(DR) + { + case 12: + SSE_DOT(e_l + 8, e_r + 8, sum, l2, r2); + case 8: + SSE_DOT(e_l + 4, e_r + 4, sum, l1, r1); + case 4: + SSE_DOT(e_l, e_r, sum, l0, r0); + default: + break; + } + for (uint32_t i = 0; i < DD; i += 16, l += 16, r += 16) + { + SSE_DOT(l, r, sum, l0, r0); + SSE_DOT(l + 4, r + 4, sum, l1, r1); + SSE_DOT(l + 8, r + 8, sum, l2, r2); + SSE_DOT(l + 12, r + 12, sum, l3, r3); + } + _mm_storeu_ps(unpack, sum); + result += unpack[0] + unpack[1] + unpack[2] + unpack[3]; #else - float dot0, dot1, dot2, dot3; - const float *last = a + size; - const float *unroll_group = last - 3; - - /* Process 4 items with each loop for efficiency. */ - while (a < unroll_group) { - dot0 = a[0] * b[0]; - dot1 = a[1] * b[1]; - dot2 = a[2] * b[2]; - dot3 = a[3] * b[3]; - result += dot0 + dot1 + dot2 + dot3; - a += 4; - b += 4; - } - /* Process last 0-3 pixels. Not needed for standard vector lengths. */ - while (a < last) { - result += *a++ * *b++; - } + float dot0, dot1, dot2, dot3; + const float *last = a + size; + const float *unroll_group = last - 3; + + /* Process 4 items with each loop for efficiency. */ + while (a < unroll_group) + { + dot0 = a[0] * b[0]; + dot1 = a[1] * b[1]; + dot2 = a[2] * b[2]; + dot3 = a[3] * b[3]; + result += dot0 + dot1 + dot2 + dot3; + a += 4; + b += 4; + } + /* Process last 0-3 pixels. Not needed for standard vector lengths. */ + while (a < last) + { + result += *a++ * *b++; + } #endif #endif #endif - return result; + return result; } -template -float DistanceFastL2::compare(const T *a, const T *b, float norm, - uint32_t size) const { - float result = -2 * DistanceInnerProduct::inner_product(a, b, size); - result += norm; - return result; +template float DistanceFastL2::compare(const T *a, const T *b, float norm, uint32_t size) const +{ + float result = -2 * DistanceInnerProduct::inner_product(a, b, size); + result += norm; + return result; } -template -float DistanceFastL2::norm(const T *a, uint32_t size) const { - if (!std::is_floating_point::value) { - diskann::cerr << "ERROR: FastL2 only defined for float currently." - << std::endl; - throw diskann::ANNException( - "ERROR: FastL2 only defined for float currently.", -1, __FUNCSIG__, - __FILE__, __LINE__); - } - float result = 0; +template float DistanceFastL2::norm(const T *a, uint32_t size) const +{ + if (!std::is_floating_point::value) + { + diskann::cerr << "ERROR: FastL2 only defined for float currently." 
<< std::endl; + throw diskann::ANNException("ERROR: FastL2 only defined for float currently.", -1, __FUNCSIG__, __FILE__, + __LINE__); + } + float result = 0; #ifdef __GNUC__ #ifdef __AVX__ -#define AVX_L2NORM(addr, dest, tmp) \ - tmp = _mm256_loadu_ps(addr); \ - tmp = _mm256_mul_ps(tmp, tmp); \ - dest = _mm256_add_ps(dest, tmp); - - __m256 sum; - __m256 l0, l1; - uint32_t D = (size + 7) & ~7U; - uint32_t DR = D % 16; - uint32_t DD = D - DR; - const float *l = (float *)a; - const float *e_l = l + DD; - float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0}; - - sum = _mm256_loadu_ps(unpack); - if (DR) { - AVX_L2NORM(e_l, sum, l0); - } - for (uint32_t i = 0; i < DD; i += 16, l += 16) { - AVX_L2NORM(l, sum, l0); - AVX_L2NORM(l + 8, sum, l1); - } - _mm256_storeu_ps(unpack, sum); - result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + - unpack[5] + unpack[6] + unpack[7]; +#define AVX_L2NORM(addr, dest, tmp) \ + tmp = _mm256_loadu_ps(addr); \ + tmp = _mm256_mul_ps(tmp, tmp); \ + dest = _mm256_add_ps(dest, tmp); + + __m256 sum; + __m256 l0, l1; + uint32_t D = (size + 7) & ~7U; + uint32_t DR = D % 16; + uint32_t DD = D - DR; + const float *l = (float *)a; + const float *e_l = l + DD; + float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0}; + + sum = _mm256_loadu_ps(unpack); + if (DR) + { + AVX_L2NORM(e_l, sum, l0); + } + for (uint32_t i = 0; i < DD; i += 16, l += 16) + { + AVX_L2NORM(l, sum, l0); + AVX_L2NORM(l + 8, sum, l1); + } + _mm256_storeu_ps(unpack, sum); + result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + unpack[5] + unpack[6] + unpack[7]; #else #ifdef __SSE2__ -#define SSE_L2NORM(addr, dest, tmp) \ - tmp = _mm128_loadu_ps(addr); \ - tmp = _mm128_mul_ps(tmp, tmp); \ - dest = _mm128_add_ps(dest, tmp); - - __m128 sum; - __m128 l0, l1, l2, l3; - uint32_t D = (size + 3) & ~3U; - uint32_t DR = D % 16; - uint32_t DD = D - DR; - const float *l = a; - const float *e_l = l + DD; - float unpack[4] __attribute__((aligned(16))) = {0, 0, 0, 0}; - - sum = _mm_load_ps(unpack); - switch (DR) { - case 12: - SSE_L2NORM(e_l + 8, sum, l2); - case 8: - SSE_L2NORM(e_l + 4, sum, l1); - case 4: - SSE_L2NORM(e_l, sum, l0); - default: - break; - } - for (uint32_t i = 0; i < DD; i += 16, l += 16) { - SSE_L2NORM(l, sum, l0); - SSE_L2NORM(l + 4, sum, l1); - SSE_L2NORM(l + 8, sum, l2); - SSE_L2NORM(l + 12, sum, l3); - } - _mm_storeu_ps(unpack, sum); - result += unpack[0] + unpack[1] + unpack[2] + unpack[3]; +#define SSE_L2NORM(addr, dest, tmp) \ + tmp = _mm128_loadu_ps(addr); \ + tmp = _mm128_mul_ps(tmp, tmp); \ + dest = _mm128_add_ps(dest, tmp); + + __m128 sum; + __m128 l0, l1, l2, l3; + uint32_t D = (size + 3) & ~3U; + uint32_t DR = D % 16; + uint32_t DD = D - DR; + const float *l = a; + const float *e_l = l + DD; + float unpack[4] __attribute__((aligned(16))) = {0, 0, 0, 0}; + + sum = _mm_load_ps(unpack); + switch (DR) + { + case 12: + SSE_L2NORM(e_l + 8, sum, l2); + case 8: + SSE_L2NORM(e_l + 4, sum, l1); + case 4: + SSE_L2NORM(e_l, sum, l0); + default: + break; + } + for (uint32_t i = 0; i < DD; i += 16, l += 16) + { + SSE_L2NORM(l, sum, l0); + SSE_L2NORM(l + 4, sum, l1); + SSE_L2NORM(l + 8, sum, l2); + SSE_L2NORM(l + 12, sum, l3); + } + _mm_storeu_ps(unpack, sum); + result += unpack[0] + unpack[1] + unpack[2] + unpack[3]; #else - float dot0, dot1, dot2, dot3; - const float *last = a + size; - const float *unroll_group = last - 3; - - /* Process 4 items with each loop for efficiency. 
*/ - while (a < unroll_group) { - dot0 = a[0] * a[0]; - dot1 = a[1] * a[1]; - dot2 = a[2] * a[2]; - dot3 = a[3] * a[3]; - result += dot0 + dot1 + dot2 + dot3; - a += 4; - } - /* Process last 0-3 pixels. Not needed for standard vector lengths. */ - while (a < last) { - result += (*a) * (*a); - a++; - } + float dot0, dot1, dot2, dot3; + const float *last = a + size; + const float *unroll_group = last - 3; + + /* Process 4 items with each loop for efficiency. */ + while (a < unroll_group) + { + dot0 = a[0] * a[0]; + dot1 = a[1] * a[1]; + dot2 = a[2] * a[2]; + dot3 = a[3] * a[3]; + result += dot0 + dot1 + dot2 + dot3; + a += 4; + } + /* Process last 0-3 pixels. Not needed for standard vector lengths. */ + while (a < last) + { + result += (*a) * (*a); + a++; + } #endif #endif #endif - return result; + return result; } -float AVXDistanceInnerProductFloat::compare(const float *a, const float *b, - uint32_t size) const { - float result = 0.0f; -#define AVX_DOT(addr1, addr2, dest, tmp1, tmp2) \ - tmp1 = _mm256_loadu_ps(addr1); \ - tmp2 = _mm256_loadu_ps(addr2); \ - tmp1 = _mm256_mul_ps(tmp1, tmp2); \ - dest = _mm256_add_ps(dest, tmp1); - - __m256 sum; - __m256 l0, l1; - __m256 r0, r1; - uint32_t D = (size + 7) & ~7U; - uint32_t DR = D % 16; - uint32_t DD = D - DR; - const float *l = (float *)a; - const float *r = (float *)b; - const float *e_l = l + DD; - const float *e_r = r + DD; +float AVXDistanceInnerProductFloat::compare(const float *a, const float *b, uint32_t size) const +{ + float result = 0.0f; +#define AVX_DOT(addr1, addr2, dest, tmp1, tmp2) \ + tmp1 = _mm256_loadu_ps(addr1); \ + tmp2 = _mm256_loadu_ps(addr2); \ + tmp1 = _mm256_mul_ps(tmp1, tmp2); \ + dest = _mm256_add_ps(dest, tmp1); + + __m256 sum; + __m256 l0, l1; + __m256 r0, r1; + uint32_t D = (size + 7) & ~7U; + uint32_t DR = D % 16; + uint32_t DD = D - DR; + const float *l = (float *)a; + const float *r = (float *)b; + const float *e_l = l + DD; + const float *e_r = r + DD; #ifndef _WINDOWS - float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0}; + float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0}; #else - __declspec(align(32)) float unpack[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + __declspec(align(32)) float unpack[8] = {0, 0, 0, 0, 0, 0, 0, 0}; #endif - sum = _mm256_loadu_ps(unpack); - if (DR) { - AVX_DOT(e_l, e_r, sum, l0, r0); - } + sum = _mm256_loadu_ps(unpack); + if (DR) + { + AVX_DOT(e_l, e_r, sum, l0, r0); + } - for (uint32_t i = 0; i < DD; i += 16, l += 16, r += 16) { - AVX_DOT(l, r, sum, l0, r0); - AVX_DOT(l + 8, r + 8, sum, l1, r1); - } - _mm256_storeu_ps(unpack, sum); - result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + - unpack[5] + unpack[6] + unpack[7]; + for (uint32_t i = 0; i < DD; i += 16, l += 16, r += 16) + { + AVX_DOT(l, r, sum, l0, r0); + AVX_DOT(l + 8, r + 8, sum, l1, r1); + } + _mm256_storeu_ps(unpack, sum); + result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + unpack[5] + unpack[6] + unpack[7]; - return -result; + return -result; } -uint32_t AVXNormalizedCosineDistanceFloat::post_normalization_dimension( - uint32_t orig_dimension) const { - return orig_dimension; +uint32_t AVXNormalizedCosineDistanceFloat::post_normalization_dimension(uint32_t orig_dimension) const +{ + return orig_dimension; } -bool AVXNormalizedCosineDistanceFloat::preprocessing_required() const { - return true; +bool AVXNormalizedCosineDistanceFloat::preprocessing_required() const +{ + return true; } -void AVXNormalizedCosineDistanceFloat::preprocess_base_points( - float 
*original_data, const size_t orig_dim, const size_t num_points) { - for (uint32_t i = 0; i < num_points; i++) { - normalize((float *)(original_data + i * orig_dim), orig_dim); - } +void AVXNormalizedCosineDistanceFloat::preprocess_base_points(float *original_data, const size_t orig_dim, + const size_t num_points) +{ + for (uint32_t i = 0; i < num_points; i++) + { + normalize((float *)(original_data + i * orig_dim), orig_dim); + } } -void AVXNormalizedCosineDistanceFloat::preprocess_query(const float *query_vec, - const size_t query_dim, - float *query_scratch) { - normalize_and_copy(query_vec, (uint32_t)query_dim, query_scratch); +void AVXNormalizedCosineDistanceFloat::preprocess_query(const float *query_vec, const size_t query_dim, + float *query_scratch) +{ + normalize_and_copy(query_vec, (uint32_t)query_dim, query_scratch); } -void AVXNormalizedCosineDistanceFloat::normalize_and_copy( - const float *query_vec, const uint32_t query_dim, - float *query_target) const { - float norm = get_norm(query_vec, query_dim); +void AVXNormalizedCosineDistanceFloat::normalize_and_copy(const float *query_vec, const uint32_t query_dim, + float *query_target) const +{ + float norm = get_norm(query_vec, query_dim); - for (uint32_t i = 0; i < query_dim; i++) { - query_target[i] = query_vec[i] / norm; - } + for (uint32_t i = 0; i < query_dim; i++) + { + query_target[i] = query_vec[i] / norm; + } } // Get the right distance function for the given metric. -template <> diskann::Distance *get_distance_function(diskann::Metric m) { - if (m == diskann::Metric::L2) { - if (Avx2SupportedCPU) { - diskann::cout << "L2: Using AVX2 distance computation DistanceL2Float" - << std::endl; - return new diskann::DistanceL2Float(); - } else if (AvxSupportedCPU) { - diskann::cout << "L2: AVX2 not supported. Using AVX distance computation" - << std::endl; - return new diskann::AVXDistanceL2Float(); - } else { - diskann::cout << "L2: Older CPU. Using slow distance computation" - << std::endl; - return new diskann::SlowDistanceL2(); - } - } else if (m == diskann::Metric::COSINE) { - diskann::cout << "Cosine: Using either AVX or AVX2 implementation" - << std::endl; - return new diskann::DistanceCosineFloat(); - } else if (m == diskann::Metric::INNER_PRODUCT) { - diskann::cout << "Inner product: Using AVX2 implementation " - "AVXDistanceInnerProductFloat" - << std::endl; - return new diskann::AVXDistanceInnerProductFloat(); - } else if (m == diskann::Metric::FAST_L2) { - diskann::cout << "Fast_L2: Using AVX2 implementation with norm " - "memoization DistanceFastL2" - << std::endl; - return new diskann::DistanceFastL2(); - } else { - std::stringstream stream; - stream << "Only L2, cosine, and inner product supported for floating " - "point vectors as of now." - << std::endl; - diskann::cerr << stream.str() << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } +template <> diskann::Distance *get_distance_function(diskann::Metric m) +{ + if (m == diskann::Metric::L2) + { + if (Avx2SupportedCPU) + { + diskann::cout << "L2: Using AVX2 distance computation DistanceL2Float" << std::endl; + return new diskann::DistanceL2Float(); + } + else if (AvxSupportedCPU) + { + diskann::cout << "L2: AVX2 not supported. Using AVX distance computation" << std::endl; + return new diskann::AVXDistanceL2Float(); + } + else + { + diskann::cout << "L2: Older CPU. 
Using slow distance computation" << std::endl; + return new diskann::SlowDistanceL2(); + } + } + else if (m == diskann::Metric::COSINE) + { + diskann::cout << "Cosine: Using either AVX or AVX2 implementation" << std::endl; + return new diskann::DistanceCosineFloat(); + } + else if (m == diskann::Metric::INNER_PRODUCT) + { + diskann::cout << "Inner product: Using AVX2 implementation " + "AVXDistanceInnerProductFloat" + << std::endl; + return new diskann::AVXDistanceInnerProductFloat(); + } + else if (m == diskann::Metric::FAST_L2) + { + diskann::cout << "Fast_L2: Using AVX2 implementation with norm " + "memoization DistanceFastL2" + << std::endl; + return new diskann::DistanceFastL2(); + } + else + { + std::stringstream stream; + stream << "Only L2, cosine, and inner product supported for floating " + "point vectors as of now." + << std::endl; + diskann::cerr << stream.str() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } } -template <> -diskann::Distance *get_distance_function(diskann::Metric m) { - if (m == diskann::Metric::L2) { - if (Avx2SupportedCPU) { - diskann::cout << "Using AVX2 distance computation DistanceL2Int8." - << std::endl; - return new diskann::DistanceL2Int8(); - } else if (AvxSupportedCPU) { - diskann::cout << "AVX2 not supported. Using AVX distance computation" - << std::endl; - return new diskann::AVXDistanceL2Int8(); - } else { - diskann::cout << "Older CPU. Using slow distance computation " - "SlowDistanceL2Int." - << std::endl; - return new diskann::SlowDistanceL2(); - } - } else if (m == diskann::Metric::COSINE) { - diskann::cout << "Using either AVX or AVX2 for Cosine similarity " - "DistanceCosineInt8." - << std::endl; - return new diskann::DistanceCosineInt8(); - } else { - std::stringstream stream; - stream << "Only L2 and cosine supported for signed byte vectors." - << std::endl; - diskann::cerr << stream.str() << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } +template <> diskann::Distance *get_distance_function(diskann::Metric m) +{ + if (m == diskann::Metric::L2) + { + if (Avx2SupportedCPU) + { + diskann::cout << "Using AVX2 distance computation DistanceL2Int8." << std::endl; + return new diskann::DistanceL2Int8(); + } + else if (AvxSupportedCPU) + { + diskann::cout << "AVX2 not supported. Using AVX distance computation" << std::endl; + return new diskann::AVXDistanceL2Int8(); + } + else + { + diskann::cout << "Older CPU. Using slow distance computation " + "SlowDistanceL2Int." + << std::endl; + return new diskann::SlowDistanceL2(); + } + } + else if (m == diskann::Metric::COSINE) + { + diskann::cout << "Using either AVX or AVX2 for Cosine similarity " + "DistanceCosineInt8." + << std::endl; + return new diskann::DistanceCosineInt8(); + } + else + { + std::stringstream stream; + stream << "Only L2 and cosine supported for signed byte vectors." << std::endl; + diskann::cerr << stream.str() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } } -template <> -diskann::Distance *get_distance_function(diskann::Metric m) { - if (m == diskann::Metric::L2) { +template <> diskann::Distance *get_distance_function(diskann::Metric m) +{ + if (m == diskann::Metric::L2) + { #ifdef _WINDOWS - diskann::cout - << "WARNING: AVX/AVX2 distance function not defined for Uint8. " - "Using " - "slow version. " - "Contact gopalsr@microsoft.com if you need AVX/AVX2 support." 
- << std::endl; + diskann::cout << "WARNING: AVX/AVX2 distance function not defined for Uint8. " + "Using " + "slow version. " + "Contact gopalsr@microsoft.com if you need AVX/AVX2 support." + << std::endl; #endif - return new diskann::DistanceL2UInt8(); - } else if (m == diskann::Metric::COSINE) { - diskann::cout - << "AVX/AVX2 distance function not defined for Uint8. Using " - "slow version SlowDistanceCosineUint8() " - "Contact gopalsr@microsoft.com if you need AVX/AVX2 support." - << std::endl; - return new diskann::SlowDistanceCosineUInt8(); - } else { - std::stringstream stream; - stream << "Only L2 and cosine supported for uint32_t byte vectors." - << std::endl; - diskann::cerr << stream.str() << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } + return new diskann::DistanceL2UInt8(); + } + else if (m == diskann::Metric::COSINE) + { + diskann::cout << "AVX/AVX2 distance function not defined for Uint8. Using " + "slow version SlowDistanceCosineUint8() " + "Contact gopalsr@microsoft.com if you need AVX/AVX2 support." + << std::endl; + return new diskann::SlowDistanceCosineUInt8(); + } + else + { + std::stringstream stream; + stream << "Only L2 and cosine supported for uint32_t byte vectors." << std::endl; + diskann::cerr << stream.str() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } } template DISKANN_DLLEXPORT class DistanceInnerProduct; diff --git a/src/filter_utils.cpp b/src/filter_utils.cpp index 72f415df3..d5502d361 100644 --- a/src/filter_utils.cpp +++ b/src/filter_utils.cpp @@ -14,7 +14,8 @@ #include "utils.h" #include -namespace diskann { +namespace diskann +{ /* * Using passed in parameters and files generated from step 3, * builds a vanilla diskANN index for each label. @@ -23,127 +24,122 @@ namespace diskann { * final_index_path_prefix + "_" + label */ template -void generate_label_indices(path input_data_path, path final_index_path_prefix, - label_set all_labels, uint32_t R, uint32_t L, - float alpha, uint32_t num_threads) { - diskann::IndexWriteParameters label_index_build_parameters = - diskann::IndexWriteParametersBuilder(L, R) - .with_saturate_graph(false) - .with_alpha(alpha) - .with_num_threads(num_threads) - .build(); - - std::cout << "Generating indices per label..." << std::endl; - // for each label, build an index on resp. 
points - double total_indexing_time = 0.0, indexing_percentage = 0.0; - std::cout.setstate(std::ios_base::failbit); - diskann::cout.setstate(std::ios_base::failbit); - for (const auto &lbl : all_labels) { - path curr_label_input_data_path(input_data_path + "_" + lbl); - path curr_label_index_path(final_index_path_prefix + "_" + lbl); - - size_t number_of_label_points, dimension; - diskann::get_bin_metadata(curr_label_input_data_path, - number_of_label_points, dimension); - - diskann::Index index(diskann::Metric::L2, dimension, - number_of_label_points, - std::make_shared( - label_index_build_parameters), - nullptr, 0, false, false, false, false, 0, false); - - auto index_build_timer = std::chrono::high_resolution_clock::now(); - index.build(curr_label_input_data_path.c_str(), number_of_label_points); - std::chrono::duration current_indexing_time = - std::chrono::high_resolution_clock::now() - index_build_timer; - - total_indexing_time += current_indexing_time.count(); - indexing_percentage += (1 / (double)all_labels.size()); - print_progress(indexing_percentage); - - index.save(curr_label_index_path.c_str()); - } - std::cout.clear(); - diskann::cout.clear(); - - std::cout << "\nDone. Generated per-label indices in " << total_indexing_time - << " seconds\n" - << std::endl; +void generate_label_indices(path input_data_path, path final_index_path_prefix, label_set all_labels, uint32_t R, + uint32_t L, float alpha, uint32_t num_threads) +{ + diskann::IndexWriteParameters label_index_build_parameters = diskann::IndexWriteParametersBuilder(L, R) + .with_saturate_graph(false) + .with_alpha(alpha) + .with_num_threads(num_threads) + .build(); + + std::cout << "Generating indices per label..." << std::endl; + // for each label, build an index on resp. points + double total_indexing_time = 0.0, indexing_percentage = 0.0; + std::cout.setstate(std::ios_base::failbit); + diskann::cout.setstate(std::ios_base::failbit); + for (const auto &lbl : all_labels) + { + path curr_label_input_data_path(input_data_path + "_" + lbl); + path curr_label_index_path(final_index_path_prefix + "_" + lbl); + + size_t number_of_label_points, dimension; + diskann::get_bin_metadata(curr_label_input_data_path, number_of_label_points, dimension); + + diskann::Index index(diskann::Metric::L2, dimension, number_of_label_points, + std::make_shared(label_index_build_parameters), nullptr, + 0, false, false, false, false, 0, false); + + auto index_build_timer = std::chrono::high_resolution_clock::now(); + index.build(curr_label_input_data_path.c_str(), number_of_label_points); + std::chrono::duration current_indexing_time = + std::chrono::high_resolution_clock::now() - index_build_timer; + + total_indexing_time += current_indexing_time.count(); + indexing_percentage += (1 / (double)all_labels.size()); + print_progress(indexing_percentage); + + index.save(curr_label_index_path.c_str()); + } + std::cout.clear(); + diskann::cout.clear(); + + std::cout << "\nDone. Generated per-label indices in " << total_indexing_time << " seconds\n" << std::endl; } // for use on systems without writev (i.e. 
Windows) template -tsl::robin_map> -generate_label_specific_vector_files_compat( - path input_data_path, - tsl::robin_map labels_to_number_of_points, - std::vector point_ids_to_labels, label_set all_labels) { - auto file_writing_timer = std::chrono::high_resolution_clock::now(); - std::ifstream input_data_stream(input_data_path); - - uint32_t number_of_points, dimension; - input_data_stream.read((char *)&number_of_points, sizeof(uint32_t)); - input_data_stream.read((char *)&dimension, sizeof(uint32_t)); - const uint32_t VECTOR_SIZE = dimension * sizeof(T); - if (number_of_points != point_ids_to_labels.size()) { - std::cerr << "Error: number of points in labels file and data file differ." - << std::endl; - throw; - } - - tsl::robin_map labels_to_vectors; - tsl::robin_map labels_to_curr_vector; - tsl::robin_map> label_id_to_orig_id; - - for (const auto &lbl : all_labels) { - uint32_t number_of_label_pts = labels_to_number_of_points[lbl]; - char *vectors = (char *)malloc(number_of_label_pts * VECTOR_SIZE); - if (vectors == nullptr) { - throw; +tsl::robin_map> generate_label_specific_vector_files_compat( + path input_data_path, tsl::robin_map labels_to_number_of_points, + std::vector point_ids_to_labels, label_set all_labels) +{ + auto file_writing_timer = std::chrono::high_resolution_clock::now(); + std::ifstream input_data_stream(input_data_path); + + uint32_t number_of_points, dimension; + input_data_stream.read((char *)&number_of_points, sizeof(uint32_t)); + input_data_stream.read((char *)&dimension, sizeof(uint32_t)); + const uint32_t VECTOR_SIZE = dimension * sizeof(T); + if (number_of_points != point_ids_to_labels.size()) + { + std::cerr << "Error: number of points in labels file and data file differ." << std::endl; + throw; + } + + tsl::robin_map labels_to_vectors; + tsl::robin_map labels_to_curr_vector; + tsl::robin_map> label_id_to_orig_id; + + for (const auto &lbl : all_labels) + { + uint32_t number_of_label_pts = labels_to_number_of_points[lbl]; + char *vectors = (char *)malloc(number_of_label_pts * VECTOR_SIZE); + if (vectors == nullptr) + { + throw; + } + labels_to_vectors[lbl] = vectors; + labels_to_curr_vector[lbl] = 0; + label_id_to_orig_id[lbl].reserve(number_of_label_pts); + } + + for (uint32_t point_id = 0; point_id < number_of_points; point_id++) + { + char *curr_vector = (char *)malloc(VECTOR_SIZE); + input_data_stream.read(curr_vector, VECTOR_SIZE); + for (const auto &lbl : point_ids_to_labels[point_id]) + { + char *curr_label_vector_ptr = labels_to_vectors[lbl] + (labels_to_curr_vector[lbl] * VECTOR_SIZE); + memcpy(curr_label_vector_ptr, curr_vector, VECTOR_SIZE); + labels_to_curr_vector[lbl]++; + label_id_to_orig_id[lbl].push_back(point_id); + } + free(curr_vector); } - labels_to_vectors[lbl] = vectors; - labels_to_curr_vector[lbl] = 0; - label_id_to_orig_id[lbl].reserve(number_of_label_pts); - } - - for (uint32_t point_id = 0; point_id < number_of_points; point_id++) { - char *curr_vector = (char *)malloc(VECTOR_SIZE); - input_data_stream.read(curr_vector, VECTOR_SIZE); - for (const auto &lbl : point_ids_to_labels[point_id]) { - char *curr_label_vector_ptr = - labels_to_vectors[lbl] + (labels_to_curr_vector[lbl] * VECTOR_SIZE); - memcpy(curr_label_vector_ptr, curr_vector, VECTOR_SIZE); - labels_to_curr_vector[lbl]++; - label_id_to_orig_id[lbl].push_back(point_id); + + for (const auto &lbl : all_labels) + { + path curr_label_input_data_path(input_data_path + "_" + lbl); + uint32_t number_of_label_pts = labels_to_number_of_points[lbl]; + + std::ofstream 
label_file_stream; + label_file_stream.exceptions(std::ios::badbit | std::ios::failbit); + label_file_stream.open(curr_label_input_data_path, std::ios_base::binary); + label_file_stream.write((char *)&number_of_label_pts, sizeof(uint32_t)); + label_file_stream.write((char *)&dimension, sizeof(uint32_t)); + label_file_stream.write((char *)labels_to_vectors[lbl], number_of_label_pts * VECTOR_SIZE); + + label_file_stream.close(); + free(labels_to_vectors[lbl]); } - free(curr_vector); - } - - for (const auto &lbl : all_labels) { - path curr_label_input_data_path(input_data_path + "_" + lbl); - uint32_t number_of_label_pts = labels_to_number_of_points[lbl]; - - std::ofstream label_file_stream; - label_file_stream.exceptions(std::ios::badbit | std::ios::failbit); - label_file_stream.open(curr_label_input_data_path, std::ios_base::binary); - label_file_stream.write((char *)&number_of_label_pts, sizeof(uint32_t)); - label_file_stream.write((char *)&dimension, sizeof(uint32_t)); - label_file_stream.write((char *)labels_to_vectors[lbl], - number_of_label_pts * VECTOR_SIZE); - - label_file_stream.close(); - free(labels_to_vectors[lbl]); - } - input_data_stream.close(); - - std::chrono::duration file_writing_time = - std::chrono::high_resolution_clock::now() - file_writing_timer; - std::cout << "generated " << all_labels.size() - << " label-specific vector files for index building in time " - << file_writing_time.count() << "\n" - << std::endl; - - return label_id_to_orig_id; + input_data_stream.close(); + + std::chrono::duration file_writing_time = std::chrono::high_resolution_clock::now() - file_writing_timer; + std::cout << "generated " << all_labels.size() << " label-specific vector files for index building in time " + << file_writing_time.count() << "\n" + << std::endl; + + return label_id_to_orig_id; } /* @@ -151,37 +147,36 @@ generate_label_specific_vector_files_compat( * * Returns both the graph index and the size of the file in bytes. 
*/ -load_label_index_return_values -load_label_index(path label_index_path, uint32_t label_number_of_points) { - std::ifstream label_index_stream; - label_index_stream.exceptions(std::ios::badbit | std::ios::failbit); - label_index_stream.open(label_index_path, std::ios::binary); - - uint64_t index_file_size, index_num_frozen_points; - uint32_t index_max_observed_degree, index_entry_point; - const size_t INDEX_METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t); - label_index_stream.read((char *)&index_file_size, sizeof(uint64_t)); - label_index_stream.read((char *)&index_max_observed_degree, sizeof(uint32_t)); - label_index_stream.read((char *)&index_entry_point, sizeof(uint32_t)); - label_index_stream.read((char *)&index_num_frozen_points, sizeof(uint64_t)); - size_t bytes_read = INDEX_METADATA; - - std::vector> label_index(label_number_of_points); - uint32_t nodes_read = 0; - while (bytes_read != index_file_size) { - uint32_t current_node_num_neighbors; - label_index_stream.read((char *)¤t_node_num_neighbors, - sizeof(uint32_t)); - nodes_read++; - - std::vector current_node_neighbors(current_node_num_neighbors); - label_index_stream.read((char *)current_node_neighbors.data(), - current_node_num_neighbors * sizeof(uint32_t)); - label_index[nodes_read - 1].swap(current_node_neighbors); - bytes_read += sizeof(uint32_t) * (current_node_num_neighbors + 1); - } - - return std::make_tuple(label_index, index_file_size); +load_label_index_return_values load_label_index(path label_index_path, uint32_t label_number_of_points) +{ + std::ifstream label_index_stream; + label_index_stream.exceptions(std::ios::badbit | std::ios::failbit); + label_index_stream.open(label_index_path, std::ios::binary); + + uint64_t index_file_size, index_num_frozen_points; + uint32_t index_max_observed_degree, index_entry_point; + const size_t INDEX_METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t); + label_index_stream.read((char *)&index_file_size, sizeof(uint64_t)); + label_index_stream.read((char *)&index_max_observed_degree, sizeof(uint32_t)); + label_index_stream.read((char *)&index_entry_point, sizeof(uint32_t)); + label_index_stream.read((char *)&index_num_frozen_points, sizeof(uint64_t)); + size_t bytes_read = INDEX_METADATA; + + std::vector> label_index(label_number_of_points); + uint32_t nodes_read = 0; + while (bytes_read != index_file_size) + { + uint32_t current_node_num_neighbors; + label_index_stream.read((char *)¤t_node_num_neighbors, sizeof(uint32_t)); + nodes_read++; + + std::vector current_node_neighbors(current_node_num_neighbors); + label_index_stream.read((char *)current_node_neighbors.data(), current_node_num_neighbors * sizeof(uint32_t)); + label_index[nodes_read - 1].swap(current_node_neighbors); + bytes_read += sizeof(uint32_t) * (current_node_num_neighbors + 1); + } + + return std::make_tuple(label_index, index_file_size); } /* @@ -193,71 +188,77 @@ load_label_index(path label_index_path, uint32_t label_number_of_points) { * 2. map: key is label, value is number of points with the label * 3. 
the label universe as a set */ -parse_label_file_return_values parse_label_file(path label_data_path, - std::string universal_label) { - std::ifstream label_data_stream(label_data_path); - std::string line, token; - uint32_t line_cnt = 0; - - // allows us to reserve space for the points_to_labels vector - while (std::getline(label_data_stream, line)) - line_cnt++; - label_data_stream.clear(); - label_data_stream.seekg(0, std::ios::beg); - - // values to return - std::vector point_ids_to_labels(line_cnt); - tsl::robin_map labels_to_number_of_points; - label_set all_labels; - - std::vector points_with_universal_label; - line_cnt = 0; - while (std::getline(label_data_stream, line)) { - std::istringstream current_labels_comma_separated(line); - label_set current_labels; - - // get point id - uint32_t point_id = line_cnt; - - // parse comma separated labels - bool current_universal_label_check = false; - while (getline(current_labels_comma_separated, token, ',')) { - token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - - // if token is empty, there's no labels for the point - if (token == universal_label) { - points_with_universal_label.push_back(point_id); - current_universal_label_check = true; - } else { - all_labels.insert(token); - current_labels.insert(token); - labels_to_number_of_points[token]++; - } +parse_label_file_return_values parse_label_file(path label_data_path, std::string universal_label) +{ + std::ifstream label_data_stream(label_data_path); + std::string line, token; + uint32_t line_cnt = 0; + + // allows us to reserve space for the points_to_labels vector + while (std::getline(label_data_stream, line)) + line_cnt++; + label_data_stream.clear(); + label_data_stream.seekg(0, std::ios::beg); + + // values to return + std::vector point_ids_to_labels(line_cnt); + tsl::robin_map labels_to_number_of_points; + label_set all_labels; + + std::vector points_with_universal_label; + line_cnt = 0; + while (std::getline(label_data_stream, line)) + { + std::istringstream current_labels_comma_separated(line); + label_set current_labels; + + // get point id + uint32_t point_id = line_cnt; + + // parse comma separated labels + bool current_universal_label_check = false; + while (getline(current_labels_comma_separated, token, ',')) + { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + + // if token is empty, there's no labels for the point + if (token == universal_label) + { + points_with_universal_label.push_back(point_id); + current_universal_label_check = true; + } + else + { + all_labels.insert(token); + current_labels.insert(token); + labels_to_number_of_points[token]++; + } + } + + if (current_labels.size() <= 0 && !current_universal_label_check) + { + std::cerr << "Error: " << point_id << " has no labels." << std::endl; + exit(-1); + } + point_ids_to_labels[point_id] = current_labels; + line_cnt++; } - if (current_labels.size() <= 0 && !current_universal_label_check) { - std::cerr << "Error: " << point_id << " has no labels." 
<< std::endl; - exit(-1); + // for every point with universal label, set its label set to all labels + // also, increment the count for number of points a label has + for (const auto &point_id : points_with_universal_label) + { + point_ids_to_labels[point_id] = all_labels; + for (const auto &lbl : all_labels) + labels_to_number_of_points[lbl]++; } - point_ids_to_labels[point_id] = current_labels; - line_cnt++; - } - - // for every point with universal label, set its label set to all labels - // also, increment the count for number of points a label has - for (const auto &point_id : points_with_universal_label) { - point_ids_to_labels[point_id] = all_labels; - for (const auto &lbl : all_labels) - labels_to_number_of_points[lbl]++; - } - std::cout << "Identified " << all_labels.size() << " distinct label(s) for " - << point_ids_to_labels.size() << " points\n" - << std::endl; + std::cout << "Identified " << all_labels.size() << " distinct label(s) for " << point_ids_to_labels.size() + << " points\n" + << std::endl; - return std::make_tuple(point_ids_to_labels, labels_to_number_of_points, - all_labels); + return std::make_tuple(point_ids_to_labels, labels_to_number_of_points, all_labels); } /* @@ -270,88 +271,86 @@ parse_label_file_return_values parse_label_file(path label_data_path, * 2. a set of all labels */ template -std::tuple>, tsl::robin_set> -parse_formatted_label_file(std::string label_file) { - std::vector> pts_to_labels; - tsl::robin_set labels; - - // Format of Label txt file: filters with comma separators - std::ifstream infile(label_file); - if (infile.fail()) { - throw diskann::ANNException( - std::string("Failed to open file ") + label_file, -1); - } - - std::string line, token; - uint32_t line_cnt = 0; - - while (std::getline(infile, line)) { - line_cnt++; - } - pts_to_labels.resize(line_cnt, std::vector()); - - infile.clear(); - infile.seekg(0, std::ios::beg); - line_cnt = 0; - - while (std::getline(infile, line)) { - std::istringstream iss(line); - std::vector lbls(0); - getline(iss, token, '\t'); - std::istringstream new_iss(token); - while (getline(new_iss, token, ',')) { - token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - LabelT token_as_num = static_cast(std::stoul(token)); - lbls.push_back(token_as_num); - labels.insert(token_as_num); +std::tuple>, tsl::robin_set> parse_formatted_label_file(std::string label_file) +{ + std::vector> pts_to_labels; + tsl::robin_set labels; + + // Format of Label txt file: filters with comma separators + std::ifstream infile(label_file); + if (infile.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1); } - if (lbls.size() <= 0) { - diskann::cout << "No label found"; - exit(-1); + + std::string line, token; + uint32_t line_cnt = 0; + + while (std::getline(infile, line)) + { + line_cnt++; + } + pts_to_labels.resize(line_cnt, std::vector()); + + infile.clear(); + infile.seekg(0, std::ios::beg); + line_cnt = 0; + + while (std::getline(infile, line)) + { + std::istringstream iss(line); + std::vector lbls(0); + getline(iss, token, '\t'); + std::istringstream new_iss(token); + while (getline(new_iss, token, ',')) + { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + LabelT token_as_num = static_cast(std::stoul(token)); + lbls.push_back(token_as_num); + labels.insert(token_as_num); + } + if (lbls.size() <= 
0) + { + diskann::cout << "No label found"; + exit(-1); + } + std::sort(lbls.begin(), lbls.end()); + pts_to_labels[line_cnt] = lbls; + line_cnt++; } - std::sort(lbls.begin(), lbls.end()); - pts_to_labels[line_cnt] = lbls; - line_cnt++; - } - diskann::cout << "Identified " << labels.size() << " distinct label(s)" - << std::endl; - - return std::make_tuple(pts_to_labels, labels); + diskann::cout << "Identified " << labels.size() << " distinct label(s)" << std::endl; + + return std::make_tuple(pts_to_labels, labels); } -template DISKANN_DLLEXPORT - std::tuple>, tsl::robin_set> - parse_formatted_label_file(path label_file); +template DISKANN_DLLEXPORT std::tuple>, tsl::robin_set> +parse_formatted_label_file(path label_file); -template DISKANN_DLLEXPORT - std::tuple>, tsl::robin_set> - parse_formatted_label_file(path label_file); +template DISKANN_DLLEXPORT std::tuple>, tsl::robin_set> +parse_formatted_label_file(path label_file); -template DISKANN_DLLEXPORT void generate_label_indices( - path input_data_path, path final_index_path_prefix, label_set all_labels, - uint32_t R, uint32_t L, float alpha, uint32_t num_threads); -template DISKANN_DLLEXPORT void generate_label_indices( - path input_data_path, path final_index_path_prefix, label_set all_labels, - uint32_t R, uint32_t L, float alpha, uint32_t num_threads); -template DISKANN_DLLEXPORT void generate_label_indices( - path input_data_path, path final_index_path_prefix, label_set all_labels, - uint32_t R, uint32_t L, float alpha, uint32_t num_threads); +template DISKANN_DLLEXPORT void generate_label_indices(path input_data_path, path final_index_path_prefix, + label_set all_labels, uint32_t R, uint32_t L, float alpha, + uint32_t num_threads); +template DISKANN_DLLEXPORT void generate_label_indices(path input_data_path, path final_index_path_prefix, + label_set all_labels, uint32_t R, uint32_t L, + float alpha, uint32_t num_threads); +template DISKANN_DLLEXPORT void generate_label_indices(path input_data_path, path final_index_path_prefix, + label_set all_labels, uint32_t R, uint32_t L, + float alpha, uint32_t num_threads); template DISKANN_DLLEXPORT tsl::robin_map> -generate_label_specific_vector_files_compat( - path input_data_path, - tsl::robin_map labels_to_number_of_points, - std::vector point_ids_to_labels, label_set all_labels); +generate_label_specific_vector_files_compat(path input_data_path, + tsl::robin_map labels_to_number_of_points, + std::vector point_ids_to_labels, label_set all_labels); template DISKANN_DLLEXPORT tsl::robin_map> -generate_label_specific_vector_files_compat( - path input_data_path, - tsl::robin_map labels_to_number_of_points, - std::vector point_ids_to_labels, label_set all_labels); +generate_label_specific_vector_files_compat(path input_data_path, + tsl::robin_map labels_to_number_of_points, + std::vector point_ids_to_labels, label_set all_labels); template DISKANN_DLLEXPORT tsl::robin_map> -generate_label_specific_vector_files_compat( - path input_data_path, - tsl::robin_map labels_to_number_of_points, - std::vector point_ids_to_labels, label_set all_labels); +generate_label_specific_vector_files_compat(path input_data_path, + tsl::robin_map labels_to_number_of_points, + std::vector point_ids_to_labels, label_set all_labels); } // namespace diskann \ No newline at end of file diff --git a/src/in_mem_data_store.cpp b/src/in_mem_data_store.cpp index ab9276abe..46ddfc92b 100644 --- a/src/in_mem_data_store.cpp +++ b/src/in_mem_data_store.cpp @@ -7,393 +7,391 @@ #include "utils.h" -namespace diskann { +namespace 
diskann +{ template -InMemDataStore::InMemDataStore( - const location_t num_points, const size_t dim, - std::unique_ptr> distance_fn) - : AbstractDataStore(num_points, dim), - _distance_fn(std::move(distance_fn)) { - _aligned_dim = ROUND_UP(dim, _distance_fn->get_required_alignment()); - alloc_aligned(((void **)&_data), - this->_capacity * _aligned_dim * sizeof(data_t), - 8 * sizeof(data_t)); - std::memset(_data, 0, this->_capacity * _aligned_dim * sizeof(data_t)); +InMemDataStore::InMemDataStore(const location_t num_points, const size_t dim, + std::unique_ptr> distance_fn) + : AbstractDataStore(num_points, dim), _distance_fn(std::move(distance_fn)) +{ + _aligned_dim = ROUND_UP(dim, _distance_fn->get_required_alignment()); + alloc_aligned(((void **)&_data), this->_capacity * _aligned_dim * sizeof(data_t), 8 * sizeof(data_t)); + std::memset(_data, 0, this->_capacity * _aligned_dim * sizeof(data_t)); } -template InMemDataStore::~InMemDataStore() { - if (_data != nullptr) { - aligned_free(this->_data); - } +template InMemDataStore::~InMemDataStore() +{ + if (_data != nullptr) + { + aligned_free(this->_data); + } } -template -size_t InMemDataStore::get_aligned_dim() const { - return _aligned_dim; +template size_t InMemDataStore::get_aligned_dim() const +{ + return _aligned_dim; } -template -size_t InMemDataStore::get_alignment_factor() const { - return _distance_fn->get_required_alignment(); +template size_t InMemDataStore::get_alignment_factor() const +{ + return _distance_fn->get_required_alignment(); } -template -location_t InMemDataStore::load(const std::string &filename) { - return load_impl(filename); +template location_t InMemDataStore::load(const std::string &filename) +{ + return load_impl(filename); } #ifdef EXEC_ENV_OLS -template -location_t InMemDataStore::load_impl(AlignedFileReader &reader) { - size_t file_dim, file_num_points; - - diskann::get_bin_metadata(reader, file_num_points, file_dim); - - if (file_dim != this->_dim) { - std::stringstream stream; - stream << "ERROR: Driver requests loading " << this->_dim << " dimension," - << "but file has " << file_dim << " dimension." << std::endl; - diskann::cerr << stream.str() << std::endl; - aligned_free(_data); - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } +template location_t InMemDataStore::load_impl(AlignedFileReader &reader) +{ + size_t file_dim, file_num_points; + + diskann::get_bin_metadata(reader, file_num_points, file_dim); + + if (file_dim != this->_dim) + { + std::stringstream stream; + stream << "ERROR: Driver requests loading " << this->_dim << " dimension," + << "but file has " << file_dim << " dimension." << std::endl; + diskann::cerr << stream.str() << std::endl; + aligned_free(_data); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } - if (file_num_points > this->capacity()) { - this->resize((location_t)file_num_points); - } - copy_aligned_data_from_file(reader, _data, file_num_points, file_dim, - _aligned_dim); + if (file_num_points > this->capacity()) + { + this->resize((location_t)file_num_points); + } + copy_aligned_data_from_file(reader, _data, file_num_points, file_dim, _aligned_dim); - return (location_t)file_num_points; + return (location_t)file_num_points; } #endif -template -location_t InMemDataStore::load_impl(const std::string &filename) { - size_t file_dim, file_num_points; - if (!file_exists(filename)) { - std::stringstream stream; - stream << "ERROR: data file " << filename << " does not exist." 
- << std::endl; - diskann::cerr << stream.str() << std::endl; - aligned_free(_data); - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - diskann::get_bin_metadata(filename, file_num_points, file_dim); - - if (file_dim != this->_dim) { - std::stringstream stream; - stream << "ERROR: Driver requests loading " << this->_dim << " dimension," - << "but file has " << file_dim << " dimension." << std::endl; - diskann::cerr << stream.str() << std::endl; - aligned_free(_data); - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } +template location_t InMemDataStore::load_impl(const std::string &filename) +{ + size_t file_dim, file_num_points; + if (!file_exists(filename)) + { + std::stringstream stream; + stream << "ERROR: data file " << filename << " does not exist." << std::endl; + diskann::cerr << stream.str() << std::endl; + aligned_free(_data); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + diskann::get_bin_metadata(filename, file_num_points, file_dim); + + if (file_dim != this->_dim) + { + std::stringstream stream; + stream << "ERROR: Driver requests loading " << this->_dim << " dimension," + << "but file has " << file_dim << " dimension." << std::endl; + diskann::cerr << stream.str() << std::endl; + aligned_free(_data); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } - if (file_num_points > this->capacity()) { - this->resize((location_t)file_num_points); - } + if (file_num_points > this->capacity()) + { + this->resize((location_t)file_num_points); + } - copy_aligned_data_from_file(filename.c_str(), _data, file_num_points, - file_dim, _aligned_dim); + copy_aligned_data_from_file(filename.c_str(), _data, file_num_points, file_dim, _aligned_dim); - return (location_t)file_num_points; + return (location_t)file_num_points; } -template -size_t InMemDataStore::save(const std::string &filename, - const location_t num_points) { - return save_data_in_base_dimensions(filename, _data, num_points, - this->get_dims(), this->get_aligned_dim(), - 0U); +template size_t InMemDataStore::save(const std::string &filename, const location_t num_points) +{ + return save_data_in_base_dimensions(filename, _data, num_points, this->get_dims(), this->get_aligned_dim(), 0U); } -template -void InMemDataStore::populate_data(const data_t *vectors, - const location_t num_pts) { - memset(_data, 0, _aligned_dim * sizeof(data_t) * num_pts); - for (location_t i = 0; i < num_pts; i++) { - std::memmove(_data + i * _aligned_dim, vectors + i * this->_dim, - this->_dim * sizeof(data_t)); - } - - if (_distance_fn->preprocessing_required()) { - _distance_fn->preprocess_base_points(_data, this->_aligned_dim, num_pts); - } +template void InMemDataStore::populate_data(const data_t *vectors, const location_t num_pts) +{ + memset(_data, 0, _aligned_dim * sizeof(data_t) * num_pts); + for (location_t i = 0; i < num_pts; i++) + { + std::memmove(_data + i * _aligned_dim, vectors + i * this->_dim, this->_dim * sizeof(data_t)); + } + + if (_distance_fn->preprocessing_required()) + { + _distance_fn->preprocess_base_points(_data, this->_aligned_dim, num_pts); + } } -template -void InMemDataStore::populate_data(const std::string &filename, - const size_t offset) { - size_t npts, ndim; - copy_aligned_data_from_file(filename.c_str(), _data, npts, ndim, _aligned_dim, - offset); - - if ((location_t)npts > this->capacity()) { - std::stringstream ss; - ss << "Number of points in the file: " << filename - << 
" is greater than the capacity of data store: " << this->capacity() - << ". Must invoke resize before calling populate_data()" << std::endl; - throw diskann::ANNException(ss.str(), -1); - } - - if ((location_t)ndim != this->get_dims()) { - std::stringstream ss; - ss << "Number of dimensions of a point in the file: " << filename - << " is not equal to dimensions of data store: " << this->capacity() - << "." << std::endl; - throw diskann::ANNException(ss.str(), -1); - } - - if (_distance_fn->preprocessing_required()) { - _distance_fn->preprocess_base_points(_data, this->_aligned_dim, - this->capacity()); - } +template void InMemDataStore::populate_data(const std::string &filename, const size_t offset) +{ + size_t npts, ndim; + copy_aligned_data_from_file(filename.c_str(), _data, npts, ndim, _aligned_dim, offset); + + if ((location_t)npts > this->capacity()) + { + std::stringstream ss; + ss << "Number of points in the file: " << filename + << " is greater than the capacity of data store: " << this->capacity() + << ". Must invoke resize before calling populate_data()" << std::endl; + throw diskann::ANNException(ss.str(), -1); + } + + if ((location_t)ndim != this->get_dims()) + { + std::stringstream ss; + ss << "Number of dimensions of a point in the file: " << filename + << " is not equal to dimensions of data store: " << this->capacity() << "." << std::endl; + throw diskann::ANNException(ss.str(), -1); + } + + if (_distance_fn->preprocessing_required()) + { + _distance_fn->preprocess_base_points(_data, this->_aligned_dim, this->capacity()); + } } template -void InMemDataStore::extract_data_to_bin(const std::string &filename, - const location_t num_points) { - save_data_in_base_dimensions(filename, _data, num_points, this->get_dims(), - this->get_aligned_dim(), 0U); +void InMemDataStore::extract_data_to_bin(const std::string &filename, const location_t num_points) +{ + save_data_in_base_dimensions(filename, _data, num_points, this->get_dims(), this->get_aligned_dim(), 0U); } -template -void InMemDataStore::get_vector(const location_t i, - data_t *dest) const { - // REFACTOR TODO: Should we denormalize and return values? - memcpy(dest, _data + i * _aligned_dim, this->_dim * sizeof(data_t)); +template void InMemDataStore::get_vector(const location_t i, data_t *dest) const +{ + // REFACTOR TODO: Should we denormalize and return values? 
+ memcpy(dest, _data + i * _aligned_dim, this->_dim * sizeof(data_t)); } -template -void InMemDataStore::set_vector(const location_t loc, - const data_t *const vector) { - size_t offset_in_data = loc * _aligned_dim; - memset(_data + offset_in_data, 0, _aligned_dim * sizeof(data_t)); - memcpy(_data + offset_in_data, vector, this->_dim * sizeof(data_t)); - if (_distance_fn->preprocessing_required()) { - _distance_fn->preprocess_base_points(_data + offset_in_data, _aligned_dim, - 1); - } +template void InMemDataStore::set_vector(const location_t loc, const data_t *const vector) +{ + size_t offset_in_data = loc * _aligned_dim; + memset(_data + offset_in_data, 0, _aligned_dim * sizeof(data_t)); + memcpy(_data + offset_in_data, vector, this->_dim * sizeof(data_t)); + if (_distance_fn->preprocessing_required()) + { + _distance_fn->preprocess_base_points(_data + offset_in_data, _aligned_dim, 1); + } } -template -void InMemDataStore::prefetch_vector(const location_t loc) { - diskann::prefetch_vector((const char *)_data + - _aligned_dim * (size_t)loc * sizeof(data_t), - sizeof(data_t) * _aligned_dim); +template void InMemDataStore::prefetch_vector(const location_t loc) +{ + diskann::prefetch_vector((const char *)_data + _aligned_dim * (size_t)loc * sizeof(data_t), + sizeof(data_t) * _aligned_dim); } template -void InMemDataStore::preprocess_query( - const data_t *query, AbstractScratch *query_scratch) const { - if (query_scratch != nullptr) { - memcpy(query_scratch->aligned_query_T(), query, - sizeof(data_t) * this->get_dims()); - } else { - std::stringstream ss; - ss << "In InMemDataStore::preprocess_query: Query scratch is null"; - diskann::cerr << ss.str() << std::endl; - throw diskann::ANNException(ss.str(), -1); - } +void InMemDataStore::preprocess_query(const data_t *query, AbstractScratch *query_scratch) const +{ + if (query_scratch != nullptr) + { + memcpy(query_scratch->aligned_query_T(), query, sizeof(data_t) * this->get_dims()); + } + else + { + std::stringstream ss; + ss << "In InMemDataStore::preprocess_query: Query scratch is null"; + diskann::cerr << ss.str() << std::endl; + throw diskann::ANNException(ss.str(), -1); + } } -template -float InMemDataStore::get_distance(const data_t *query, - const location_t loc) const { - return _distance_fn->compare(query, _data + _aligned_dim * loc, - (uint32_t)_aligned_dim); +template float InMemDataStore::get_distance(const data_t *query, const location_t loc) const +{ + return _distance_fn->compare(query, _data + _aligned_dim * loc, (uint32_t)_aligned_dim); } template -void InMemDataStore::get_distance( - const data_t *query, const location_t *locations, - const uint32_t location_count, float *distances, - AbstractScratch *scratch_space) const { - for (location_t i = 0; i < location_count; i++) { - distances[i] = - _distance_fn->compare(query, _data + locations[i] * _aligned_dim, - (uint32_t)this->_aligned_dim); - } +void InMemDataStore::get_distance(const data_t *query, const location_t *locations, + const uint32_t location_count, float *distances, + AbstractScratch *scratch_space) const +{ + for (location_t i = 0; i < location_count; i++) + { + distances[i] = _distance_fn->compare(query, _data + locations[i] * _aligned_dim, (uint32_t)this->_aligned_dim); + } } template -float InMemDataStore::get_distance(const location_t loc1, - const location_t loc2) const { - return _distance_fn->compare(_data + loc1 * _aligned_dim, - _data + loc2 * _aligned_dim, - (uint32_t)this->_aligned_dim); +float InMemDataStore::get_distance(const location_t loc1, const 
location_t loc2) const +{ + return _distance_fn->compare(_data + loc1 * _aligned_dim, _data + loc2 * _aligned_dim, + (uint32_t)this->_aligned_dim); } template -void InMemDataStore::get_distance( - const data_t *preprocessed_query, const std::vector &ids, - std::vector &distances, - AbstractScratch *scratch_space) const { - for (int i = 0; i < ids.size(); i++) { - distances[i] = - _distance_fn->compare(preprocessed_query, _data + ids[i] * _aligned_dim, - (uint32_t)this->_aligned_dim); - } +void InMemDataStore::get_distance(const data_t *preprocessed_query, const std::vector &ids, + std::vector &distances, AbstractScratch *scratch_space) const +{ + for (int i = 0; i < ids.size(); i++) + { + distances[i] = + _distance_fn->compare(preprocessed_query, _data + ids[i] * _aligned_dim, (uint32_t)this->_aligned_dim); + } } -template -location_t InMemDataStore::expand(const location_t new_size) { - if (new_size == this->capacity()) { - return this->capacity(); - } else if (new_size < this->capacity()) { - std::stringstream ss; - ss << "Cannot 'expand' datastore when new capacity (" << new_size - << ") < existing capacity(" << this->capacity() << ")" << std::endl; - throw diskann::ANNException(ss.str(), -1); - } +template location_t InMemDataStore::expand(const location_t new_size) +{ + if (new_size == this->capacity()) + { + return this->capacity(); + } + else if (new_size < this->capacity()) + { + std::stringstream ss; + ss << "Cannot 'expand' datastore when new capacity (" << new_size << ") < existing capacity(" + << this->capacity() << ")" << std::endl; + throw diskann::ANNException(ss.str(), -1); + } #ifndef _WINDOWS - data_t *new_data; - alloc_aligned((void **)&new_data, new_size * _aligned_dim * sizeof(data_t), - 8 * sizeof(data_t)); - memcpy(new_data, _data, this->capacity() * _aligned_dim * sizeof(data_t)); - aligned_free(_data); - _data = new_data; + data_t *new_data; + alloc_aligned((void **)&new_data, new_size * _aligned_dim * sizeof(data_t), 8 * sizeof(data_t)); + memcpy(new_data, _data, this->capacity() * _aligned_dim * sizeof(data_t)); + aligned_free(_data); + _data = new_data; #else - realloc_aligned((void **)&_data, new_size * _aligned_dim * sizeof(data_t), - 8 * sizeof(data_t)); + realloc_aligned((void **)&_data, new_size * _aligned_dim * sizeof(data_t), 8 * sizeof(data_t)); #endif - this->_capacity = new_size; - return this->_capacity; + this->_capacity = new_size; + return this->_capacity; } -template -location_t InMemDataStore::shrink(const location_t new_size) { - if (new_size == this->capacity()) { - return this->capacity(); - } else if (new_size > this->capacity()) { - std::stringstream ss; - ss << "Cannot 'shrink' datastore when new capacity (" << new_size - << ") > existing capacity(" << this->capacity() << ")" << std::endl; - throw diskann::ANNException(ss.str(), -1); - } +template location_t InMemDataStore::shrink(const location_t new_size) +{ + if (new_size == this->capacity()) + { + return this->capacity(); + } + else if (new_size > this->capacity()) + { + std::stringstream ss; + ss << "Cannot 'shrink' datastore when new capacity (" << new_size << ") > existing capacity(" + << this->capacity() << ")" << std::endl; + throw diskann::ANNException(ss.str(), -1); + } #ifndef _WINDOWS - data_t *new_data; - alloc_aligned((void **)&new_data, new_size * _aligned_dim * sizeof(data_t), - 8 * sizeof(data_t)); - memcpy(new_data, _data, new_size * _aligned_dim * sizeof(data_t)); - aligned_free(_data); - _data = new_data; + data_t *new_data; + alloc_aligned((void **)&new_data, new_size 
* _aligned_dim * sizeof(data_t), 8 * sizeof(data_t)); + memcpy(new_data, _data, new_size * _aligned_dim * sizeof(data_t)); + aligned_free(_data); + _data = new_data; #else - realloc_aligned((void **)&_data, new_size * _aligned_dim * sizeof(data_t), - 8 * sizeof(data_t)); + realloc_aligned((void **)&_data, new_size * _aligned_dim * sizeof(data_t), 8 * sizeof(data_t)); #endif - this->_capacity = new_size; - return this->_capacity; + this->_capacity = new_size; + return this->_capacity; } template -void InMemDataStore::move_vectors(const location_t old_location_start, - const location_t new_location_start, - const location_t num_locations) { - if (num_locations == 0 || old_location_start == new_location_start) { - return; - } - - /* // Update pointers to the moved nodes. Note: the computation is correct - even - // when new_location_start < old_location_start given the C++ uint32_t - // integer arithmetic rules. - const uint32_t location_delta = new_location_start - old_location_start; - */ - // The [start, end) interval which will contain obsolete points to be - // cleared. - uint32_t mem_clear_loc_start = old_location_start; - uint32_t mem_clear_loc_end_limit = old_location_start + num_locations; - - if (new_location_start < old_location_start) { - // If ranges are overlapping, make sure not to clear the newly copied - // data. - if (mem_clear_loc_start < new_location_start + num_locations) { - // Clear only after the end of the new range. - mem_clear_loc_start = new_location_start + num_locations; +void InMemDataStore::move_vectors(const location_t old_location_start, const location_t new_location_start, + const location_t num_locations) +{ + if (num_locations == 0 || old_location_start == new_location_start) + { + return; } - } else { - // If ranges are overlapping, make sure not to clear the newly copied - // data. - if (mem_clear_loc_end_limit > new_location_start) { - // Clear only up to the beginning of the new range. - mem_clear_loc_end_limit = new_location_start; + + /* // Update pointers to the moved nodes. Note: the computation is correct + even + // when new_location_start < old_location_start given the C++ uint32_t + // integer arithmetic rules. + const uint32_t location_delta = new_location_start - old_location_start; + */ + // The [start, end) interval which will contain obsolete points to be + // cleared. + uint32_t mem_clear_loc_start = old_location_start; + uint32_t mem_clear_loc_end_limit = old_location_start + num_locations; + + if (new_location_start < old_location_start) + { + // If ranges are overlapping, make sure not to clear the newly copied + // data. + if (mem_clear_loc_start < new_location_start + num_locations) + { + // Clear only after the end of the new range. + mem_clear_loc_start = new_location_start + num_locations; + } + } + else + { + // If ranges are overlapping, make sure not to clear the newly copied + // data. + if (mem_clear_loc_end_limit > new_location_start) + { + // Clear only up to the beginning of the new range. + mem_clear_loc_end_limit = new_location_start; + } } - } - // Use memmove to handle overlapping ranges. - copy_vectors(old_location_start, new_location_start, num_locations); - memset(_data + _aligned_dim * mem_clear_loc_start, 0, - sizeof(data_t) * _aligned_dim * - (mem_clear_loc_end_limit - mem_clear_loc_start)); + // Use memmove to handle overlapping ranges. 
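    // Worked example with illustrative numbers: moving num_locations = 10
    // vectors from old_location_start = 100 to new_location_start = 95 gives
    // overlapping ranges [100, 110) and [95, 105). The copy below uses
    // memmove, and the clearing bounds computed above shrink to [105, 110),
    // i.e. only the tail of the old range that the new range does not cover
    // is zeroed, so the freshly copied data at [95, 105) survives.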
+ copy_vectors(old_location_start, new_location_start, num_locations); + memset(_data + _aligned_dim * mem_clear_loc_start, 0, + sizeof(data_t) * _aligned_dim * (mem_clear_loc_end_limit - mem_clear_loc_start)); } template -void InMemDataStore::copy_vectors(const location_t from_loc, - const location_t to_loc, - const location_t num_points) { - assert(from_loc < this->_capacity); - assert(to_loc < this->_capacity); - assert(num_points < this->_capacity); - memmove(_data + _aligned_dim * to_loc, _data + _aligned_dim * from_loc, - num_points * _aligned_dim * sizeof(data_t)); +void InMemDataStore::copy_vectors(const location_t from_loc, const location_t to_loc, + const location_t num_points) +{ + assert(from_loc < this->_capacity); + assert(to_loc < this->_capacity); + assert(num_points < this->_capacity); + memmove(_data + _aligned_dim * to_loc, _data + _aligned_dim * from_loc, num_points * _aligned_dim * sizeof(data_t)); } -template -location_t InMemDataStore::calculate_medoid() const { - // allocate and init centroid - float *center = new float[_aligned_dim]; - for (size_t j = 0; j < _aligned_dim; j++) - center[j] = 0; +template location_t InMemDataStore::calculate_medoid() const +{ + // allocate and init centroid + float *center = new float[_aligned_dim]; + for (size_t j = 0; j < _aligned_dim; j++) + center[j] = 0; + + for (size_t i = 0; i < this->capacity(); i++) + for (size_t j = 0; j < _aligned_dim; j++) + center[j] += (float)_data[i * _aligned_dim + j]; - for (size_t i = 0; i < this->capacity(); i++) for (size_t j = 0; j < _aligned_dim; j++) - center[j] += (float)_data[i * _aligned_dim + j]; - - for (size_t j = 0; j < _aligned_dim; j++) - center[j] /= (float)this->capacity(); - - // compute all to one distance - float *distances = new float[this->capacity()]; - - // TODO: REFACTOR. Removing pragma might make this slow. Must revisit. - // Problem is that we need to pass num_threads here, it is not clear - // if data store must be aware of threads! - // #pragma omp parallel for schedule(static, 65536) - for (int64_t i = 0; i < (int64_t)this->capacity(); i++) { - // extract point and distance reference - float &dist = distances[i]; - const data_t *cur_vec = _data + (i * (size_t)_aligned_dim); - dist = 0; - float diff = 0; - for (size_t j = 0; j < _aligned_dim; j++) { - diff = (center[j] - (float)cur_vec[j]) * (center[j] - (float)cur_vec[j]); - dist += diff; + center[j] /= (float)this->capacity(); + + // compute all to one distance + float *distances = new float[this->capacity()]; + + // TODO: REFACTOR. Removing pragma might make this slow. Must revisit. + // Problem is that we need to pass num_threads here, it is not clear + // if data store must be aware of threads! 
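    // Illustrative note: `center` above is the mean of all stored points, and
    // the loop below picks the stored point with the smallest squared L2
    // distance to it. For example, with three 1-d points {1, 2, 6} the
    // centroid is 3, the squared distances are {4, 1, 9}, and the medoid
    // returned is index 1 (the point with value 2).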
+ // #pragma omp parallel for schedule(static, 65536) + for (int64_t i = 0; i < (int64_t)this->capacity(); i++) + { + // extract point and distance reference + float &dist = distances[i]; + const data_t *cur_vec = _data + (i * (size_t)_aligned_dim); + dist = 0; + float diff = 0; + for (size_t j = 0; j < _aligned_dim; j++) + { + diff = (center[j] - (float)cur_vec[j]) * (center[j] - (float)cur_vec[j]); + dist += diff; + } } - } - // find imin - uint32_t min_idx = 0; - float min_dist = distances[0]; - for (uint32_t i = 1; i < this->capacity(); i++) { - if (distances[i] < min_dist) { - min_idx = i; - min_dist = distances[i]; + // find imin + uint32_t min_idx = 0; + float min_dist = distances[0]; + for (uint32_t i = 1; i < this->capacity(); i++) + { + if (distances[i] < min_dist) + { + min_idx = i; + min_dist = distances[i]; + } } - } - delete[] distances; - delete[] center; - return min_idx; + delete[] distances; + delete[] center; + return min_idx; } -template -Distance *InMemDataStore::get_dist_fn() const { - return this->_distance_fn.get(); +template Distance *InMemDataStore::get_dist_fn() const +{ + return this->_distance_fn.get(); } template DISKANN_DLLEXPORT class InMemDataStore; diff --git a/src/in_mem_filter_store.cpp b/src/in_mem_filter_store.cpp index d5f299063..de581d99d 100644 --- a/src/in_mem_filter_store.cpp +++ b/src/in_mem_filter_store.cpp @@ -12,382 +12,395 @@ #include #include -namespace diskann { +namespace diskann +{ // TODO: Move to utils.h -DISKANN_DLLEXPORT std::unique_ptr -get_file_content(const std::string &filename, uint64_t &file_size); - -template InMemFilterStore::~InMemFilterStore() { - if (_pts_to_label_offsets != nullptr) { - delete[] _pts_to_label_offsets; - _pts_to_label_offsets = nullptr; - } - if (_pts_to_label_counts != nullptr) { - delete[] _pts_to_label_counts; - _pts_to_label_counts = nullptr; - } - if (_pts_to_labels != nullptr) { - delete[] _pts_to_labels; - _pts_to_labels = nullptr; - } +DISKANN_DLLEXPORT std::unique_ptr get_file_content(const std::string &filename, uint64_t &file_size); + +template InMemFilterStore::~InMemFilterStore() +{ + if (_pts_to_label_offsets != nullptr) + { + delete[] _pts_to_label_offsets; + _pts_to_label_offsets = nullptr; + } + if (_pts_to_label_counts != nullptr) + { + delete[] _pts_to_label_counts; + _pts_to_label_counts = nullptr; + } + if (_pts_to_labels != nullptr) + { + delete[] _pts_to_labels; + _pts_to_labels = nullptr; + } } template -const std::unordered_map> & -InMemFilterStore::get_label_to_medoids() const { - return this->_filter_to_medoid_ids; +const std::unordered_map> &InMemFilterStore::get_label_to_medoids() const +{ + return this->_filter_to_medoid_ids; } template -const std::vector & -InMemFilterStore::get_medoids_of_label(const LabelT label) { - if (_filter_to_medoid_ids.find(label) != _filter_to_medoid_ids.end()) { - return this->_filter_to_medoid_ids[label]; - } else { - std::stringstream ss; - ss << "Could not find " << label << " in filters_to_medoid_ids map." - << std::endl; - diskann::cerr << ss.str(); - throw ANNException(ss.str(), -1); - } +const std::vector &InMemFilterStore::get_medoids_of_label(const LabelT label) +{ + if (_filter_to_medoid_ids.find(label) != _filter_to_medoid_ids.end()) + { + return this->_filter_to_medoid_ids[label]; + } + else + { + std::stringstream ss; + ss << "Could not find " << label << " in filters_to_medoid_ids map." 
<< std::endl; + diskann::cerr << ss.str(); + throw ANNException(ss.str(), -1); + } } -template -void InMemFilterStore::set_universal_label(const LabelT univ_label) { - _universal_filter_label = univ_label; - _use_universal_label = true; +template void InMemFilterStore::set_universal_label(const LabelT univ_label) +{ + _universal_filter_label = univ_label; + _use_universal_label = true; } // Load functions for SEARCH START -template -bool InMemFilterStore::load(const std::string &disk_index_file) { - std::string labels_file = disk_index_file + "_labels.txt"; - std::string labels_to_medoids = disk_index_file + "_labels_to_medoids.txt"; - std::string dummy_map_file = disk_index_file + "_dummy_map.txt"; - std::string labels_map_file = disk_index_file + "_labels_map.txt"; - std::string univ_label_file = disk_index_file + "_universal_label.txt"; - - size_t num_pts_in_label_file = 0; - - // TODO: Check for encoding issues here. We are opening files as binary and - // reading them as bytes, not sure if that can cause an issue with UTF - // encodings. - bool has_filters = true; - if (false == load_file_and_parse( - labels_file, &InMemFilterStore::load_label_file)) { - diskann::cout << "Index does not have filter data. " << std::endl; - return false; - } - if (false == parse_stream(labels_map_file, - &InMemFilterStore::load_label_map)) { - diskann::cerr << "Failed to find file: " << labels_map_file - << " while labels_file exists." << std::endl; - return false; - } - - if (false == - parse_stream(labels_to_medoids, - &InMemFilterStore::load_labels_to_medoids)) { - diskann::cerr << "Failed to find file: " << labels_to_medoids - << " while labels file exists." << std::endl; - return false; - } - // missing universal label file is NOT an error. - load_file_and_parse(univ_label_file, - &InMemFilterStore::parse_universal_label); - - // missing dummy map file is also NOT an error. - parse_stream(dummy_map_file, &InMemFilterStore::load_dummy_map); - _is_valid = true; - return _is_valid; -} - -template -bool InMemFilterStore::has_filter_support() const { - return _is_valid; -} - -// TODO: Improve this to not load the entire file in memory -template -void InMemFilterStore::load_label_file( - const std::string_view &label_file_content) { - std::string line; - uint32_t line_cnt = 0; - - uint32_t num_pts_in_label_file; - uint32_t num_total_labels; - get_label_file_metadata(label_file_content, num_pts_in_label_file, - num_total_labels); - - _num_points = num_pts_in_label_file; - - _pts_to_label_offsets = new uint32_t[num_pts_in_label_file]; - _pts_to_label_counts = new uint32_t[num_pts_in_label_file]; - _pts_to_labels = new LabelT[num_total_labels]; - uint32_t labels_seen_so_far = 0; - - std::string label_str; - size_t cur_pos = 0; - size_t next_pos = 0; - size_t file_size = label_file_content.size(); - - while (cur_pos < file_size && cur_pos != std::string_view::npos) { - next_pos = label_file_content.find('\n', cur_pos); - if (next_pos == std::string_view::npos) { - break; +template bool InMemFilterStore::load(const std::string &disk_index_file) +{ + std::string labels_file = disk_index_file + "_labels.txt"; + std::string labels_to_medoids = disk_index_file + "_labels_to_medoids.txt"; + std::string dummy_map_file = disk_index_file + "_dummy_map.txt"; + std::string labels_map_file = disk_index_file + "_labels_map.txt"; + std::string univ_label_file = disk_index_file + "_universal_label.txt"; + + size_t num_pts_in_label_file = 0; + + // TODO: Check for encoding issues here. 
We are opening files as binary and + // reading them as bytes, not sure if that can cause an issue with UTF + // encodings. + bool has_filters = true; + if (false == load_file_and_parse(labels_file, &InMemFilterStore::load_label_file)) + { + diskann::cout << "Index does not have filter data. " << std::endl; + return false; } - - _pts_to_label_offsets[line_cnt] = labels_seen_so_far; - uint32_t &num_lbls_in_cur_pt = _pts_to_label_counts[line_cnt]; - num_lbls_in_cur_pt = 0; - - size_t lbl_pos = cur_pos; - size_t next_lbl_pos = 0; - while (lbl_pos < next_pos && lbl_pos != std::string_view::npos) { - next_lbl_pos = label_file_content.find(',', lbl_pos); - if (next_lbl_pos == - std::string_view::npos) // the last label in the whole file - { - next_lbl_pos = next_pos; - } - - if (next_lbl_pos > - next_pos) // the last label in one line, just read to the end - { - next_lbl_pos = next_pos; - } - - // TODO: SHOULD NOT EXPECT label_file_content TO BE NULL_TERMINATED - label_str.assign(label_file_content.data() + lbl_pos, - next_lbl_pos - lbl_pos); - if (label_str[label_str.length() - 1] == - '\t') // '\t' won't exist in label file? - { - label_str.erase(label_str.length() - 1); - } - - LabelT token_as_num = (LabelT)std::stoul(label_str); - _pts_to_labels[labels_seen_so_far++] = (LabelT)token_as_num; - num_lbls_in_cur_pt++; - - // move to next label - lbl_pos = next_lbl_pos + 1; + if (false == parse_stream(labels_map_file, &InMemFilterStore::load_label_map)) + { + diskann::cerr << "Failed to find file: " << labels_map_file << " while labels_file exists." << std::endl; + return false; } - // move to next line - cur_pos = next_pos + 1; - - if (num_lbls_in_cur_pt == 0) { - diskann::cout << "No label found for point " << line_cnt << std::endl; - exit(-1); + if (false == parse_stream(labels_to_medoids, &InMemFilterStore::load_labels_to_medoids)) + { + diskann::cerr << "Failed to find file: " << labels_to_medoids << " while labels file exists." << std::endl; + return false; } + // missing universal label file is NOT an error. + load_file_and_parse(univ_label_file, &InMemFilterStore::parse_universal_label); - line_cnt++; - } + // missing dummy map file is also NOT an error. + parse_stream(dummy_map_file, &InMemFilterStore::load_dummy_map); + _is_valid = true; + return _is_valid; +} - // TODO: We need to check if the number of labels and the number of points - // is as expected. Maybe add the check in PQFlashIndex? 
- // num_points_labels = line_cnt; +template bool InMemFilterStore::has_filter_support() const +{ + return _is_valid; } -template -void InMemFilterStore::load_labels_to_medoids( - std::basic_istream &medoid_stream) { - std::string line, token; - - _filter_to_medoid_ids.clear(); - while (std::getline(medoid_stream, line)) { - std::istringstream iss(line); - uint32_t cnt = 0; - std::vector medoids; - LabelT label; - while (std::getline(iss, token, ',')) { - if (cnt == 0) - label = (LabelT)std::stoul(token); - else - medoids.push_back((uint32_t)stoul(token)); - cnt++; +// TODO: Improve this to not load the entire file in memory +template void InMemFilterStore::load_label_file(const std::string_view &label_file_content) +{ + std::string line; + uint32_t line_cnt = 0; + + uint32_t num_pts_in_label_file; + uint32_t num_total_labels; + get_label_file_metadata(label_file_content, num_pts_in_label_file, num_total_labels); + + _num_points = num_pts_in_label_file; + + _pts_to_label_offsets = new uint32_t[num_pts_in_label_file]; + _pts_to_label_counts = new uint32_t[num_pts_in_label_file]; + _pts_to_labels = new LabelT[num_total_labels]; + uint32_t labels_seen_so_far = 0; + + std::string label_str; + size_t cur_pos = 0; + size_t next_pos = 0; + size_t file_size = label_file_content.size(); + + while (cur_pos < file_size && cur_pos != std::string_view::npos) + { + next_pos = label_file_content.find('\n', cur_pos); + if (next_pos == std::string_view::npos) + { + break; + } + + _pts_to_label_offsets[line_cnt] = labels_seen_so_far; + uint32_t &num_lbls_in_cur_pt = _pts_to_label_counts[line_cnt]; + num_lbls_in_cur_pt = 0; + + size_t lbl_pos = cur_pos; + size_t next_lbl_pos = 0; + while (lbl_pos < next_pos && lbl_pos != std::string_view::npos) + { + next_lbl_pos = label_file_content.find(',', lbl_pos); + if (next_lbl_pos == std::string_view::npos) // the last label in the whole file + { + next_lbl_pos = next_pos; + } + + if (next_lbl_pos > next_pos) // the last label in one line, just read to the end + { + next_lbl_pos = next_pos; + } + + // TODO: SHOULD NOT EXPECT label_file_content TO BE NULL_TERMINATED + label_str.assign(label_file_content.data() + lbl_pos, next_lbl_pos - lbl_pos); + if (label_str[label_str.length() - 1] == '\t') // '\t' won't exist in label file? + { + label_str.erase(label_str.length() - 1); + } + + LabelT token_as_num = (LabelT)std::stoul(label_str); + _pts_to_labels[labels_seen_so_far++] = (LabelT)token_as_num; + num_lbls_in_cur_pt++; + + // move to next label + lbl_pos = next_lbl_pos + 1; + } + + // move to next line + cur_pos = next_pos + 1; + + if (num_lbls_in_cur_pt == 0) + { + diskann::cout << "No label found for point " << line_cnt << std::endl; + exit(-1); + } + + line_cnt++; } - _filter_to_medoid_ids[label].swap(medoids); - } -} -template -void InMemFilterStore::load_label_map( - std::basic_istream &map_reader) { - std::string line, token; - LabelT token_as_num; - std::string label_str; - while (std::getline(map_reader, line)) { - std::istringstream iss(line); - getline(iss, token, '\t'); - label_str = token; - getline(iss, token, '\t'); - token_as_num = (LabelT)std::stoul(token); - _label_map[label_str] = token_as_num; - } + // TODO: We need to check if the number of labels and the number of points + // is as expected. Maybe add the check in PQFlashIndex? 
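    // Input sketch (values illustrative): the labels file is one line per
    // point, with numeric labels separated by commas, e.g.
    //     12,34
    //     7
    //     12,9,100
    // which this parser and get_label_file_metadata() would count as
    // num_pts = 3 and num_total_labels = 6.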
+ // num_points_labels = line_cnt; } template -void InMemFilterStore::parse_universal_label( - const std::string_view &content) { - LabelT label_as_num = (LabelT)std::stoul(std::string(content)); - this->set_universal_label(label_as_num); +void InMemFilterStore::load_labels_to_medoids(std::basic_istream &medoid_stream) +{ + std::string line, token; + + _filter_to_medoid_ids.clear(); + while (std::getline(medoid_stream, line)) + { + std::istringstream iss(line); + uint32_t cnt = 0; + std::vector medoids; + LabelT label; + while (std::getline(iss, token, ',')) + { + if (cnt == 0) + label = (LabelT)std::stoul(token); + else + medoids.push_back((uint32_t)stoul(token)); + cnt++; + } + _filter_to_medoid_ids[label].swap(medoids); + } } -template -void InMemFilterStore::load_dummy_map( - std::basic_istream &dummy_map_stream) { - std::string line, token; - - while (std::getline(dummy_map_stream, line)) { - std::istringstream iss(line); - uint32_t cnt = 0; - uint32_t dummy_id; - uint32_t real_id; - while (std::getline(iss, token, ',')) { - if (cnt == 0) - dummy_id = (uint32_t)stoul(token); - else - real_id = (uint32_t)stoul(token); - cnt++; +template void InMemFilterStore::load_label_map(std::basic_istream &map_reader) +{ + std::string line, token; + LabelT token_as_num; + std::string label_str; + while (std::getline(map_reader, line)) + { + std::istringstream iss(line); + getline(iss, token, '\t'); + label_str = token; + getline(iss, token, '\t'); + token_as_num = (LabelT)std::stoul(token); + _label_map[label_str] = token_as_num; } - _dummy_pts.insert(dummy_id); - _has_dummy_pts.insert(real_id); - _dummy_to_real_map[dummy_id] = real_id; +} - if (_real_to_dummy_map.find(real_id) == _real_to_dummy_map.end()) - _real_to_dummy_map[real_id] = std::vector(); +template void InMemFilterStore::parse_universal_label(const std::string_view &content) +{ + LabelT label_as_num = (LabelT)std::stoul(std::string(content)); + this->set_universal_label(label_as_num); +} - _real_to_dummy_map[real_id].emplace_back(dummy_id); - } - diskann::cout << "Loaded dummy map" << std::endl; +template void InMemFilterStore::load_dummy_map(std::basic_istream &dummy_map_stream) +{ + std::string line, token; + + while (std::getline(dummy_map_stream, line)) + { + std::istringstream iss(line); + uint32_t cnt = 0; + uint32_t dummy_id; + uint32_t real_id; + while (std::getline(iss, token, ',')) + { + if (cnt == 0) + dummy_id = (uint32_t)stoul(token); + else + real_id = (uint32_t)stoul(token); + cnt++; + } + _dummy_pts.insert(dummy_id); + _has_dummy_pts.insert(real_id); + _dummy_to_real_map[dummy_id] = real_id; + + if (_real_to_dummy_map.find(real_id) == _real_to_dummy_map.end()) + _real_to_dummy_map[real_id] = std::vector(); + + _real_to_dummy_map[real_id].emplace_back(dummy_id); + } + diskann::cout << "Loaded dummy map" << std::endl; } template -void InMemFilterStore::generate_random_labels( - std::vector &labels, const uint32_t num_labels, - const uint32_t nthreads) { - std::random_device rd; - labels.clear(); - labels.resize(num_labels); - - uint64_t num_total_labels = _pts_to_label_offsets[_num_points - 1] + - _pts_to_label_counts[_num_points - 1]; - std::mt19937 gen(rd()); - if (num_total_labels == 0) { - std::stringstream stream; - stream << "No labels found in data. 
Not sampling random labels "; - diskann::cerr << stream.str() << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - std::uniform_int_distribution dis(0, num_total_labels - 1); +void InMemFilterStore::generate_random_labels(std::vector &labels, const uint32_t num_labels, + const uint32_t nthreads) +{ + std::random_device rd; + labels.clear(); + labels.resize(num_labels); + + uint64_t num_total_labels = _pts_to_label_offsets[_num_points - 1] + _pts_to_label_counts[_num_points - 1]; + std::mt19937 gen(rd()); + if (num_total_labels == 0) + { + std::stringstream stream; + stream << "No labels found in data. Not sampling random labels "; + diskann::cerr << stream.str() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + std::uniform_int_distribution dis(0, num_total_labels - 1); #pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads) - for (int64_t i = 0; i < num_labels; i++) { - uint64_t rnd_loc = dis(gen); - labels[i] = (LabelT)_pts_to_labels[rnd_loc]; - } + for (int64_t i = 0; i < num_labels; i++) + { + uint64_t rnd_loc = dis(gen); + labels[i] = (LabelT)_pts_to_labels[rnd_loc]; + } } -template -void InMemFilterStore::reset_stream_for_reading( - std::basic_istream &infile) { - infile.clear(); - infile.seekg(0); +template void InMemFilterStore::reset_stream_for_reading(std::basic_istream &infile) +{ + infile.clear(); + infile.seekg(0); } template -void InMemFilterStore::get_label_file_metadata( - const std::string_view &fileContent, uint32_t &num_pts, - uint32_t &num_total_labels) { - num_pts = 0; - num_total_labels = 0; - - size_t file_size = fileContent.length(); - - std::string label_str; - size_t cur_pos = 0; - size_t next_pos = 0; - while (cur_pos < file_size && cur_pos != std::string::npos) { - next_pos = fileContent.find('\n', cur_pos); - if (next_pos == std::string::npos) { - break; - } - - size_t lbl_pos = cur_pos; - size_t next_lbl_pos = 0; - while (lbl_pos < next_pos && lbl_pos != std::string::npos) { - next_lbl_pos = fileContent.find(',', lbl_pos); - if (next_lbl_pos == std::string::npos) // the last label - { - next_lbl_pos = next_pos; - } - - num_total_labels++; - - lbl_pos = next_lbl_pos + 1; +void InMemFilterStore::get_label_file_metadata(const std::string_view &fileContent, uint32_t &num_pts, + uint32_t &num_total_labels) +{ + num_pts = 0; + num_total_labels = 0; + + size_t file_size = fileContent.length(); + + std::string label_str; + size_t cur_pos = 0; + size_t next_pos = 0; + while (cur_pos < file_size && cur_pos != std::string::npos) + { + next_pos = fileContent.find('\n', cur_pos); + if (next_pos == std::string::npos) + { + break; + } + + size_t lbl_pos = cur_pos; + size_t next_lbl_pos = 0; + while (lbl_pos < next_pos && lbl_pos != std::string::npos) + { + next_lbl_pos = fileContent.find(',', lbl_pos); + if (next_lbl_pos == std::string::npos) // the last label + { + next_lbl_pos = next_pos; + } + + num_total_labels++; + + lbl_pos = next_lbl_pos + 1; + } + + cur_pos = next_pos + 1; + + num_pts++; } - cur_pos = next_pos + 1; - - num_pts++; - } - - diskann::cout << "Labels file metadata: num_points: " << num_pts - << ", #total_labels: " << num_total_labels << std::endl; + diskann::cout << "Labels file metadata: num_points: " << num_pts << ", #total_labels: " << num_total_labels + << std::endl; } template -bool InMemFilterStore::parse_stream( - const std::string &filename, - void (InMemFilterStore::*parse_fn)(std::basic_istream &stream)) { - if 
(file_exists(filename)) { - std::ifstream stream(filename); - if (false == stream.fail()) { - std::invoke(parse_fn, this, stream); - return true; - } else { - std::stringstream ss; - ss << "Could not open file: " << filename << std::endl; - throw diskann::ANNException(ss.str(), -1); +bool InMemFilterStore::parse_stream(const std::string &filename, + void (InMemFilterStore::*parse_fn)(std::basic_istream &stream)) +{ + if (file_exists(filename)) + { + std::ifstream stream(filename); + if (false == stream.fail()) + { + std::invoke(parse_fn, this, stream); + return true; + } + else + { + std::stringstream ss; + ss << "Could not open file: " << filename << std::endl; + throw diskann::ANNException(ss.str(), -1); + } + } + else + { + return false; } - } else { - return false; - } } template -bool InMemFilterStore::load_file_and_parse( - const std::string &filename, - void (InMemFilterStore::*parse_fn)(const std::string_view &content)) { - if (file_exists(filename)) { - size_t file_size = 0; - auto file_content_ptr = get_file_content(filename, file_size); - std::string_view content_as_str(file_content_ptr.get(), file_size); - std::invoke(parse_fn, this, content_as_str); - return true; - } else { - return false; - } +bool InMemFilterStore::load_file_and_parse(const std::string &filename, + void (InMemFilterStore::*parse_fn)(const std::string_view &content)) +{ + if (file_exists(filename)) + { + size_t file_size = 0; + auto file_content_ptr = get_file_content(filename, file_size); + std::string_view content_as_str(file_content_ptr.get(), file_size); + std::invoke(parse_fn, this, content_as_str); + return true; + } + else + { + return false; + } } -std::unique_ptr get_file_content(const std::string &filename, - uint64_t &file_size) { - std::ifstream infile(filename, std::ios::binary); - if (infile.fail()) { - throw diskann::ANNException(std::string("Failed to open file ") + filename, - -1); - } - infile.seekg(0, std::ios::end); - file_size = infile.tellg(); - - auto buffer = new char[file_size]; - infile.seekg(0, std::ios::beg); - infile.read(buffer, file_size); - - return std::unique_ptr(buffer); +std::unique_ptr get_file_content(const std::string &filename, uint64_t &file_size) +{ + std::ifstream infile(filename, std::ios::binary); + if (infile.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + filename, -1); + } + infile.seekg(0, std::ios::end); + file_size = infile.tellg(); + + auto buffer = new char[file_size]; + infile.seekg(0, std::ios::beg); + infile.read(buffer, file_size); + + return std::unique_ptr(buffer); } // Load functions for SEARCH END template class InMemFilterStore; diff --git a/src/in_mem_graph_store.cpp b/src/in_mem_graph_store.cpp index e99e500ec..c12b2514e 100644 --- a/src/in_mem_graph_store.cpp +++ b/src/in_mem_graph_store.cpp @@ -4,227 +4,239 @@ #include "in_mem_graph_store.h" #include "utils.h" -namespace diskann { -InMemGraphStore::InMemGraphStore(const size_t total_pts, - const size_t reserve_graph_degree) - : AbstractGraphStore(total_pts, reserve_graph_degree) { - this->resize_graph(total_pts); - for (size_t i = 0; i < total_pts; i++) { - _graph[i].reserve(reserve_graph_degree); - } +namespace diskann +{ +InMemGraphStore::InMemGraphStore(const size_t total_pts, const size_t reserve_graph_degree) + : AbstractGraphStore(total_pts, reserve_graph_degree) +{ + this->resize_graph(total_pts); + for (size_t i = 0; i < total_pts; i++) + { + _graph[i].reserve(reserve_graph_degree); + } } -std::tuple -InMemGraphStore::load(const std::string 
&index_path_prefix, - const size_t num_points) { - return load_impl(index_path_prefix, num_points); +std::tuple InMemGraphStore::load(const std::string &index_path_prefix, + const size_t num_points) +{ + return load_impl(index_path_prefix, num_points); } -int InMemGraphStore::store(const std::string &index_path_prefix, - const size_t num_points, - const size_t num_frozen_points, - const uint32_t start) { - return save_graph(index_path_prefix, num_points, num_frozen_points, start); +int InMemGraphStore::store(const std::string &index_path_prefix, const size_t num_points, + const size_t num_frozen_points, const uint32_t start) +{ + return save_graph(index_path_prefix, num_points, num_frozen_points, start); } -const std::vector & -InMemGraphStore::get_neighbours(const location_t i) const { - return _graph.at(i); +const std::vector &InMemGraphStore::get_neighbours(const location_t i) const +{ + return _graph.at(i); } -void InMemGraphStore::add_neighbour(const location_t i, - location_t neighbour_id) { - _graph[i].emplace_back(neighbour_id); - if (_max_observed_degree < _graph[i].size()) { - _max_observed_degree = (uint32_t)(_graph[i].size()); - } +void InMemGraphStore::add_neighbour(const location_t i, location_t neighbour_id) +{ + _graph[i].emplace_back(neighbour_id); + if (_max_observed_degree < _graph[i].size()) + { + _max_observed_degree = (uint32_t)(_graph[i].size()); + } } -void InMemGraphStore::clear_neighbours(const location_t i) { - _graph[i].clear(); +void InMemGraphStore::clear_neighbours(const location_t i) +{ + _graph[i].clear(); }; -void InMemGraphStore::swap_neighbours(const location_t a, location_t b) { - _graph[a].swap(_graph[b]); +void InMemGraphStore::swap_neighbours(const location_t a, location_t b) +{ + _graph[a].swap(_graph[b]); }; -void InMemGraphStore::set_neighbours(const location_t i, - std::vector &neighbours) { - _graph[i].assign(neighbours.begin(), neighbours.end()); - if (_max_observed_degree < neighbours.size()) { - _max_observed_degree = (uint32_t)(neighbours.size()); - } +void InMemGraphStore::set_neighbours(const location_t i, std::vector &neighbours) +{ + _graph[i].assign(neighbours.begin(), neighbours.end()); + if (_max_observed_degree < neighbours.size()) + { + _max_observed_degree = (uint32_t)(neighbours.size()); + } } -size_t InMemGraphStore::resize_graph(const size_t new_size) { - _graph.resize(new_size); - set_total_points(new_size); - return _graph.size(); +size_t InMemGraphStore::resize_graph(const size_t new_size) +{ + _graph.resize(new_size); + set_total_points(new_size); + return _graph.size(); } -void InMemGraphStore::clear_graph() { _graph.clear(); } +void InMemGraphStore::clear_graph() +{ + _graph.clear(); +} #ifdef EXEC_ENV_OLS -std::tuple -InMemGraphStore::load_impl(AlignedFileReader &reader, - size_t expected_num_points) { - size_t expected_file_size; - size_t file_frozen_pts; - uint32_t start; - - auto max_points = get_max_points(); - int header_size = 2 * sizeof(size_t) + 2 * sizeof(uint32_t); - std::unique_ptr header = std::make_unique(header_size); - read_array(reader, header.get(), header_size); - - expected_file_size = *((size_t *)header.get()); - _max_observed_degree = *((uint32_t *)(header.get() + sizeof(size_t))); - start = *((uint32_t *)(header.get() + sizeof(size_t) + sizeof(uint32_t))); - file_frozen_pts = *((size_t *)(header.get() + sizeof(size_t) + - sizeof(uint32_t) + sizeof(uint32_t))); - - diskann::cout << "From graph header, expected_file_size: " - << expected_file_size - << ", _max_observed_degree: " << _max_observed_degree 
- << ", _start: " << start - << ", file_frozen_pts: " << file_frozen_pts << std::endl; - - diskann::cout << "Loading vamana graph from reader..." << std::flush; - - // If user provides more points than max_points - // resize the _graph to the larger size. - if (get_total_points() < expected_num_points) { - diskann::cout << "resizing graph to " << expected_num_points << std::endl; - this->resize_graph(expected_num_points); - } - - uint32_t nodes_read = 0; - size_t cc = 0; - size_t graph_offset = header_size; - while (nodes_read < expected_num_points) { - uint32_t k; - read_value(reader, k, graph_offset); - graph_offset += sizeof(uint32_t); - std::vector tmp(k); - tmp.reserve(k); - read_array(reader, tmp.data(), k, graph_offset); - graph_offset += k * sizeof(uint32_t); - cc += k; - _graph[nodes_read].swap(tmp); - nodes_read++; - if (nodes_read % 1000000 == 0) { - diskann::cout << "." << std::flush; +std::tuple InMemGraphStore::load_impl(AlignedFileReader &reader, size_t expected_num_points) +{ + size_t expected_file_size; + size_t file_frozen_pts; + uint32_t start; + + auto max_points = get_max_points(); + int header_size = 2 * sizeof(size_t) + 2 * sizeof(uint32_t); + std::unique_ptr header = std::make_unique(header_size); + read_array(reader, header.get(), header_size); + + expected_file_size = *((size_t *)header.get()); + _max_observed_degree = *((uint32_t *)(header.get() + sizeof(size_t))); + start = *((uint32_t *)(header.get() + sizeof(size_t) + sizeof(uint32_t))); + file_frozen_pts = *((size_t *)(header.get() + sizeof(size_t) + sizeof(uint32_t) + sizeof(uint32_t))); + + diskann::cout << "From graph header, expected_file_size: " << expected_file_size + << ", _max_observed_degree: " << _max_observed_degree << ", _start: " << start + << ", file_frozen_pts: " << file_frozen_pts << std::endl; + + diskann::cout << "Loading vamana graph from reader..." << std::flush; + + // If user provides more points than max_points + // resize the _graph to the larger size. + if (get_total_points() < expected_num_points) + { + diskann::cout << "resizing graph to " << expected_num_points << std::endl; + this->resize_graph(expected_num_points); } - if (k > _max_range_of_graph) { - _max_range_of_graph = k; + + uint32_t nodes_read = 0; + size_t cc = 0; + size_t graph_offset = header_size; + while (nodes_read < expected_num_points) + { + uint32_t k; + read_value(reader, k, graph_offset); + graph_offset += sizeof(uint32_t); + std::vector tmp(k); + tmp.reserve(k); + read_array(reader, tmp.data(), k, graph_offset); + graph_offset += k * sizeof(uint32_t); + cc += k; + _graph[nodes_read].swap(tmp); + nodes_read++; + if (nodes_read % 1000000 == 0) + { + diskann::cout << "." << std::flush; + } + if (k > _max_range_of_graph) + { + _max_range_of_graph = k; + } } - } - diskann::cout << "done. Index has " << nodes_read << " nodes and " << cc - << " out-edges, _start is set to " << start << std::endl; - return std::make_tuple(nodes_read, start, file_frozen_pts); + diskann::cout << "done. 
Index has " << nodes_read << " nodes and " << cc << " out-edges, _start is set to " << start + << std::endl; + return std::make_tuple(nodes_read, start, file_frozen_pts); } #endif -std::tuple -InMemGraphStore::load_impl(const std::string &filename, - size_t expected_num_points) { - size_t expected_file_size; - size_t file_frozen_pts; - uint32_t start; - size_t file_offset = 0; // will need this for single file format support - - std::ifstream in; - in.exceptions(std::ios::badbit | std::ios::failbit); - in.open(filename, std::ios::binary); - in.seekg(file_offset, in.beg); - in.read((char *)&expected_file_size, sizeof(size_t)); - in.read((char *)&_max_observed_degree, sizeof(uint32_t)); - in.read((char *)&start, sizeof(uint32_t)); - in.read((char *)&file_frozen_pts, sizeof(size_t)); - size_t vamana_metadata_size = - sizeof(size_t) + sizeof(uint32_t) + sizeof(uint32_t) + sizeof(size_t); - - diskann::cout << "From graph header, expected_file_size: " - << expected_file_size - << ", _max_observed_degree: " << _max_observed_degree - << ", _start: " << start - << ", file_frozen_pts: " << file_frozen_pts << std::endl; - - diskann::cout << "Loading vamana graph " << filename << "..." << std::flush; - - // If user provides more points than max_points - // resize the _graph to the larger size. - if (get_total_points() < expected_num_points) { - diskann::cout << "resizing graph to " << expected_num_points << std::endl; - this->resize_graph(expected_num_points); - } - - size_t bytes_read = vamana_metadata_size; - size_t cc = 0; - uint32_t nodes_read = 0; - while (bytes_read != expected_file_size) { - uint32_t k; - in.read((char *)&k, sizeof(uint32_t)); - - if (k == 0) { - diskann::cerr << "ERROR: Point found with no out-neighbours, point#" - << nodes_read << std::endl; +std::tuple InMemGraphStore::load_impl(const std::string &filename, + size_t expected_num_points) +{ + size_t expected_file_size; + size_t file_frozen_pts; + uint32_t start; + size_t file_offset = 0; // will need this for single file format support + + std::ifstream in; + in.exceptions(std::ios::badbit | std::ios::failbit); + in.open(filename, std::ios::binary); + in.seekg(file_offset, in.beg); + in.read((char *)&expected_file_size, sizeof(size_t)); + in.read((char *)&_max_observed_degree, sizeof(uint32_t)); + in.read((char *)&start, sizeof(uint32_t)); + in.read((char *)&file_frozen_pts, sizeof(size_t)); + size_t vamana_metadata_size = sizeof(size_t) + sizeof(uint32_t) + sizeof(uint32_t) + sizeof(size_t); + + diskann::cout << "From graph header, expected_file_size: " << expected_file_size + << ", _max_observed_degree: " << _max_observed_degree << ", _start: " << start + << ", file_frozen_pts: " << file_frozen_pts << std::endl; + + diskann::cout << "Loading vamana graph " << filename << "..." << std::flush; + + // If user provides more points than max_points + // resize the _graph to the larger size. + if (get_total_points() < expected_num_points) + { + diskann::cout << "resizing graph to " << expected_num_points << std::endl; + this->resize_graph(expected_num_points); } - cc += k; - ++nodes_read; - std::vector tmp(k); - tmp.reserve(k); - in.read((char *)tmp.data(), k * sizeof(uint32_t)); - _graph[nodes_read - 1].swap(tmp); - bytes_read += sizeof(uint32_t) * ((size_t)k + 1); - if (nodes_read % 10000000 == 0) - diskann::cout << "." 
<< std::flush; - if (k > _max_range_of_graph) { - _max_range_of_graph = k; + size_t bytes_read = vamana_metadata_size; + size_t cc = 0; + uint32_t nodes_read = 0; + while (bytes_read != expected_file_size) + { + uint32_t k; + in.read((char *)&k, sizeof(uint32_t)); + + if (k == 0) + { + diskann::cerr << "ERROR: Point found with no out-neighbours, point#" << nodes_read << std::endl; + } + + cc += k; + ++nodes_read; + std::vector tmp(k); + tmp.reserve(k); + in.read((char *)tmp.data(), k * sizeof(uint32_t)); + _graph[nodes_read - 1].swap(tmp); + bytes_read += sizeof(uint32_t) * ((size_t)k + 1); + if (nodes_read % 10000000 == 0) + diskann::cout << "." << std::flush; + if (k > _max_range_of_graph) + { + _max_range_of_graph = k; + } } - } - diskann::cout << "done. Index has " << nodes_read << " nodes and " << cc - << " out-edges, _start is set to " << start << std::endl; - return std::make_tuple(nodes_read, start, file_frozen_pts); + diskann::cout << "done. Index has " << nodes_read << " nodes and " << cc << " out-edges, _start is set to " << start + << std::endl; + return std::make_tuple(nodes_read, start, file_frozen_pts); } -int InMemGraphStore::save_graph(const std::string &index_path_prefix, - const size_t num_points, - const size_t num_frozen_points, - const uint32_t start) { - std::ofstream out; - open_file_to_write(out, index_path_prefix); - - size_t file_offset = 0; - out.seekp(file_offset, out.beg); - size_t index_size = 24; - uint32_t max_degree = 0; - out.write((char *)&index_size, sizeof(uint64_t)); - out.write((char *)&_max_observed_degree, sizeof(uint32_t)); - uint32_t ep_u32 = start; - out.write((char *)&ep_u32, sizeof(uint32_t)); - out.write((char *)&num_frozen_points, sizeof(size_t)); - - // Note: num_points = _nd + _num_frozen_points - for (uint32_t i = 0; i < num_points; i++) { - uint32_t GK = (uint32_t)_graph[i].size(); - out.write((char *)&GK, sizeof(uint32_t)); - out.write((char *)_graph[i].data(), GK * sizeof(uint32_t)); - max_degree = - _graph[i].size() > max_degree ? (uint32_t)_graph[i].size() : max_degree; - index_size += (size_t)(sizeof(uint32_t) * (GK + 1)); - } - out.seekp(file_offset, out.beg); - out.write((char *)&index_size, sizeof(uint64_t)); - out.write((char *)&max_degree, sizeof(uint32_t)); - out.close(); - return (int)index_size; +int InMemGraphStore::save_graph(const std::string &index_path_prefix, const size_t num_points, + const size_t num_frozen_points, const uint32_t start) +{ + std::ofstream out; + open_file_to_write(out, index_path_prefix); + + size_t file_offset = 0; + out.seekp(file_offset, out.beg); + size_t index_size = 24; + uint32_t max_degree = 0; + out.write((char *)&index_size, sizeof(uint64_t)); + out.write((char *)&_max_observed_degree, sizeof(uint32_t)); + uint32_t ep_u32 = start; + out.write((char *)&ep_u32, sizeof(uint32_t)); + out.write((char *)&num_frozen_points, sizeof(size_t)); + + // Note: num_points = _nd + _num_frozen_points + for (uint32_t i = 0; i < num_points; i++) + { + uint32_t GK = (uint32_t)_graph[i].size(); + out.write((char *)&GK, sizeof(uint32_t)); + out.write((char *)_graph[i].data(), GK * sizeof(uint32_t)); + max_degree = _graph[i].size() > max_degree ? 
(uint32_t)_graph[i].size() : max_degree; + index_size += (size_t)(sizeof(uint32_t) * (GK + 1)); + } + out.seekp(file_offset, out.beg); + out.write((char *)&index_size, sizeof(uint64_t)); + out.write((char *)&max_degree, sizeof(uint32_t)); + out.close(); + return (int)index_size; } -size_t InMemGraphStore::get_max_range_of_graph() { return _max_range_of_graph; } +size_t InMemGraphStore::get_max_range_of_graph() +{ + return _max_range_of_graph; +} -uint32_t InMemGraphStore::get_max_observed_degree() { - return _max_observed_degree; +uint32_t InMemGraphStore::get_max_observed_degree() +{ + return _max_observed_degree; } } // namespace diskann diff --git a/src/index.cpp b/src/index.cpp index dd08bdc16..d02488688 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -13,8 +13,7 @@ #include "windows_customizations.h" #include #include -#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && \ - defined(DISKANN_BUILD) +#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD) #include "gperftools/malloc_extension.h" #endif @@ -26,3149 +25,3307 @@ #define MAX_POINTS_FOR_USING_BITSET 10000000 -namespace diskann { +namespace diskann +{ // Initialize an index with metric m, load the data of type T with filename // (bin), and initialize max_points template -Index::Index( - const IndexConfig &index_config, - std::shared_ptr> data_store, - std::unique_ptr graph_store, - std::shared_ptr> pq_data_store) - : _dist_metric(index_config.metric), _dim(index_config.dimension), - _max_points(index_config.max_points), - _num_frozen_pts(index_config.num_frozen_pts), - _dynamic_index(index_config.dynamic_index), - _enable_tags(index_config.enable_tags), _indexingMaxC(DEFAULT_MAXC), - _query_scratch(nullptr), _pq_dist(index_config.pq_dist_build), - _use_opq(index_config.use_opq), - _filtered_index(index_config.filtered_index), - _num_pq_chunks(index_config.num_pq_chunks), - _delete_set(new tsl::robin_set), - _conc_consolidate(index_config.concurrent_consolidate) { - if (_dynamic_index && !_enable_tags) { - throw ANNException("ERROR: Dynamic Indexing must have tags enabled.", -1, - __FUNCSIG__, __FILE__, __LINE__); - } - - if (_pq_dist) { +Index::Index(const IndexConfig &index_config, std::shared_ptr> data_store, + std::unique_ptr graph_store, + std::shared_ptr> pq_data_store) + : _dist_metric(index_config.metric), _dim(index_config.dimension), _max_points(index_config.max_points), + _num_frozen_pts(index_config.num_frozen_pts), _dynamic_index(index_config.dynamic_index), + _enable_tags(index_config.enable_tags), _indexingMaxC(DEFAULT_MAXC), _query_scratch(nullptr), + _pq_dist(index_config.pq_dist_build), _use_opq(index_config.use_opq), + _filtered_index(index_config.filtered_index), _num_pq_chunks(index_config.num_pq_chunks), + _delete_set(new tsl::robin_set), _conc_consolidate(index_config.concurrent_consolidate) +{ + if (_dynamic_index && !_enable_tags) + { + throw ANNException("ERROR: Dynamic Indexing must have tags enabled.", -1, __FUNCSIG__, __FILE__, __LINE__); + } + + if (_pq_dist) + { + if (_dynamic_index) + throw ANNException("ERROR: Dynamic Indexing not supported with PQ distance based " + "index construction", + -1, __FUNCSIG__, __FILE__, __LINE__); + if (_dist_metric == diskann::Metric::INNER_PRODUCT) + throw ANNException("ERROR: Inner product metrics not yet supported " + "with PQ distance " + "base index", + -1, __FUNCSIG__, __FILE__, __LINE__); + } + + if (_dynamic_index && _num_frozen_pts == 0) + { + _num_frozen_pts = 1; + } + // Sanity check. 
While logically it is correct, max_points = 0 causes + // downstream problems. + if (_max_points == 0) + { + _max_points = 1; + } + const size_t total_internal_points = _max_points + _num_frozen_pts; + + _start = (uint32_t)_max_points; + + _data_store = data_store; + _pq_data_store = pq_data_store; + _graph_store = std::move(graph_store); + + _locks = std::vector(total_internal_points); + if (_enable_tags) + { + _location_to_tag.reserve(total_internal_points); + _tag_to_location.reserve(total_internal_points); + } + if (_dynamic_index) - throw ANNException( - "ERROR: Dynamic Indexing not supported with PQ distance based " - "index construction", - -1, __FUNCSIG__, __FILE__, __LINE__); - if (_dist_metric == diskann::Metric::INNER_PRODUCT) - throw ANNException("ERROR: Inner product metrics not yet supported " - "with PQ distance " - "base index", - -1, __FUNCSIG__, __FILE__, __LINE__); - } - - if (_dynamic_index && _num_frozen_pts == 0) { - _num_frozen_pts = 1; - } - // Sanity check. While logically it is correct, max_points = 0 causes - // downstream problems. - if (_max_points == 0) { - _max_points = 1; - } - const size_t total_internal_points = _max_points + _num_frozen_pts; - - _start = (uint32_t)_max_points; - - _data_store = data_store; - _pq_data_store = pq_data_store; - _graph_store = std::move(graph_store); - - _locks = std::vector(total_internal_points); - if (_enable_tags) { - _location_to_tag.reserve(total_internal_points); - _tag_to_location.reserve(total_internal_points); - } - - if (_dynamic_index) { - this->enable_delete(); // enable delete by default for dynamic index - if (_filtered_index) { - _location_to_labels.resize(total_internal_points); - } - } - - if (index_config.index_write_params != nullptr) { - _indexingQueueSize = index_config.index_write_params->search_list_size; - _indexingRange = index_config.index_write_params->max_degree; - _indexingMaxC = index_config.index_write_params->max_occlusion_size; - _indexingAlpha = index_config.index_write_params->alpha; - _filterIndexingQueueSize = - index_config.index_write_params->filter_list_size; - _indexingThreads = index_config.index_write_params->num_threads; - _saturate_graph = index_config.index_write_params->saturate_graph; - - if (index_config.index_search_params != nullptr) { - uint32_t num_scratch_spaces = - index_config.index_search_params->num_search_threads + - _indexingThreads; - initialize_query_scratch( - num_scratch_spaces, - index_config.index_search_params->initial_search_list_size, - _indexingQueueSize, _indexingRange, _indexingMaxC, - _data_store->get_dims()); - } - } + { + this->enable_delete(); // enable delete by default for dynamic index + if (_filtered_index) + { + _location_to_labels.resize(total_internal_points); + } + } + + if (index_config.index_write_params != nullptr) + { + _indexingQueueSize = index_config.index_write_params->search_list_size; + _indexingRange = index_config.index_write_params->max_degree; + _indexingMaxC = index_config.index_write_params->max_occlusion_size; + _indexingAlpha = index_config.index_write_params->alpha; + _filterIndexingQueueSize = index_config.index_write_params->filter_list_size; + _indexingThreads = index_config.index_write_params->num_threads; + _saturate_graph = index_config.index_write_params->saturate_graph; + + if (index_config.index_search_params != nullptr) + { + uint32_t num_scratch_spaces = index_config.index_search_params->num_search_threads + _indexingThreads; + initialize_query_scratch(num_scratch_spaces, 
index_config.index_search_params->initial_search_list_size, + _indexingQueueSize, _indexingRange, _indexingMaxC, _data_store->get_dims()); + } + } } template -Index::Index( - Metric m, const size_t dim, const size_t max_points, - const std::shared_ptr index_parameters, - const std::shared_ptr index_search_params, - const size_t num_frozen_pts, const bool dynamic_index, - const bool enable_tags, const bool concurrent_consolidate, - const bool pq_dist_build, const size_t num_pq_chunks, const bool use_opq, - const bool filtered_index) - : Index(IndexConfigBuilder() - .with_metric(m) - .with_dimension(dim) - .with_max_points(max_points) - .with_index_write_params(index_parameters) - .with_index_search_params(index_search_params) - .with_num_frozen_pts(num_frozen_pts) - .is_dynamic_index(dynamic_index) - .is_enable_tags(enable_tags) - .is_concurrent_consolidate(concurrent_consolidate) - .is_pq_dist_build(pq_dist_build) - .with_num_pq_chunks(num_pq_chunks) - .is_use_opq(use_opq) - .is_filtered(filtered_index) - .with_data_type(diskann_type_to_name()) - .build(), - IndexFactory::construct_datastore( - DataStoreStrategy::MEMORY, - (max_points == 0 ? (size_t)1 : max_points) + - (dynamic_index && num_frozen_pts == 0 ? (size_t)1 - : num_frozen_pts), - dim, m), - IndexFactory::construct_graphstore( - GraphStoreStrategy::MEMORY, - (max_points == 0 ? (size_t)1 : max_points) + - (dynamic_index && num_frozen_pts == 0 ? (size_t)1 - : num_frozen_pts), - (size_t)((index_parameters == nullptr - ? 0 - : index_parameters->max_degree) * - defaults::GRAPH_SLACK_FACTOR * 1.05))) { - if (_pq_dist) { - _pq_data_store = IndexFactory::construct_pq_datastore( - DataStoreStrategy::MEMORY, max_points + num_frozen_pts, dim, m, - num_pq_chunks, use_opq); - } else { - _pq_data_store = _data_store; - } +Index::Index(Metric m, const size_t dim, const size_t max_points, + const std::shared_ptr index_parameters, + const std::shared_ptr index_search_params, const size_t num_frozen_pts, + const bool dynamic_index, const bool enable_tags, const bool concurrent_consolidate, + const bool pq_dist_build, const size_t num_pq_chunks, const bool use_opq, + const bool filtered_index) + : Index( + IndexConfigBuilder() + .with_metric(m) + .with_dimension(dim) + .with_max_points(max_points) + .with_index_write_params(index_parameters) + .with_index_search_params(index_search_params) + .with_num_frozen_pts(num_frozen_pts) + .is_dynamic_index(dynamic_index) + .is_enable_tags(enable_tags) + .is_concurrent_consolidate(concurrent_consolidate) + .is_pq_dist_build(pq_dist_build) + .with_num_pq_chunks(num_pq_chunks) + .is_use_opq(use_opq) + .is_filtered(filtered_index) + .with_data_type(diskann_type_to_name()) + .build(), + IndexFactory::construct_datastore(DataStoreStrategy::MEMORY, + (max_points == 0 ? (size_t)1 : max_points) + + (dynamic_index && num_frozen_pts == 0 ? (size_t)1 : num_frozen_pts), + dim, m), + IndexFactory::construct_graphstore(GraphStoreStrategy::MEMORY, + (max_points == 0 ? (size_t)1 : max_points) + + (dynamic_index && num_frozen_pts == 0 ? (size_t)1 : num_frozen_pts), + (size_t)((index_parameters == nullptr ? 
0 : index_parameters->max_degree) * + defaults::GRAPH_SLACK_FACTOR * 1.05))) +{ + if (_pq_dist) + { + _pq_data_store = IndexFactory::construct_pq_datastore(DataStoreStrategy::MEMORY, max_points + num_frozen_pts, + dim, m, num_pq_chunks, use_opq); + } + else + { + _pq_data_store = _data_store; + } } -template -Index::~Index() { - // Ensure that no other activity is happening before dtor() - std::unique_lock ul(_update_lock); - std::unique_lock cl(_consolidate_lock); - std::unique_lock tl(_tag_lock); - std::unique_lock dl(_delete_lock); - - for (auto &lock : _locks) { - LockGuard lg(lock); - } - - if (_opt_graph != nullptr) { - delete[] _opt_graph; - } - - if (!_query_scratch.empty()) { - ScratchStoreManager> manager(_query_scratch); - manager.destroy(); - } +template Index::~Index() +{ + // Ensure that no other activity is happening before dtor() + std::unique_lock ul(_update_lock); + std::unique_lock cl(_consolidate_lock); + std::unique_lock tl(_tag_lock); + std::unique_lock dl(_delete_lock); + + for (auto &lock : _locks) + { + LockGuard lg(lock); + } + + if (_opt_graph != nullptr) + { + delete[] _opt_graph; + } + + if (!_query_scratch.empty()) + { + ScratchStoreManager> manager(_query_scratch); + manager.destroy(); + } } template -void Index::initialize_query_scratch(uint32_t num_threads, - uint32_t search_l, - uint32_t indexing_l, - uint32_t r, uint32_t maxc, - size_t dim) { - for (uint32_t i = 0; i < num_threads; i++) { - auto scratch = new InMemQueryScratch( - search_l, indexing_l, r, maxc, dim, _data_store->get_aligned_dim(), - _data_store->get_alignment_factor(), _pq_dist); - _query_scratch.push(scratch); - } +void Index::initialize_query_scratch(uint32_t num_threads, uint32_t search_l, uint32_t indexing_l, + uint32_t r, uint32_t maxc, size_t dim) +{ + for (uint32_t i = 0; i < num_threads; i++) + { + auto scratch = new InMemQueryScratch(search_l, indexing_l, r, maxc, dim, _data_store->get_aligned_dim(), + _data_store->get_alignment_factor(), _pq_dist); + _query_scratch.push(scratch); + } } -template -size_t Index::save_tags(std::string tags_file) { - if (!_enable_tags) { - diskann::cout << "Not saving tags as they are not enabled." << std::endl; - return 0; - } - - size_t tag_bytes_written; - TagT *tag_data = new TagT[_nd + _num_frozen_pts]; - for (uint32_t i = 0; i < _nd; i++) { - TagT tag; - if (_location_to_tag.try_get(i, tag)) { - tag_data[i] = tag; - } else { - // catering to future when tagT can be any type. - std::memset((char *)&tag_data[i], 0, sizeof(TagT)); - } - } - if (_num_frozen_pts > 0) { - std::memset((char *)&tag_data[_start], 0, sizeof(TagT) * _num_frozen_pts); - } - try { - tag_bytes_written = - save_bin(tags_file, tag_data, _nd + _num_frozen_pts, 1); - } catch (std::system_error &e) { - throw FileException(tags_file, e, __FUNCSIG__, __FILE__, __LINE__); - } - delete[] tag_data; - return tag_bytes_written; +template size_t Index::save_tags(std::string tags_file) +{ + if (!_enable_tags) + { + diskann::cout << "Not saving tags as they are not enabled." << std::endl; + return 0; + } + + size_t tag_bytes_written; + TagT *tag_data = new TagT[_nd + _num_frozen_pts]; + for (uint32_t i = 0; i < _nd; i++) + { + TagT tag; + if (_location_to_tag.try_get(i, tag)) + { + tag_data[i] = tag; + } + else + { + // catering to future when tagT can be any type. 
+ std::memset((char *)&tag_data[i], 0, sizeof(TagT)); + } + } + if (_num_frozen_pts > 0) + { + std::memset((char *)&tag_data[_start], 0, sizeof(TagT) * _num_frozen_pts); + } + try + { + tag_bytes_written = save_bin(tags_file, tag_data, _nd + _num_frozen_pts, 1); + } + catch (std::system_error &e) + { + throw FileException(tags_file, e, __FUNCSIG__, __FILE__, __LINE__); + } + delete[] tag_data; + return tag_bytes_written; } -template -size_t Index::save_data(std::string data_file) { - // Note: at this point, either _nd == _max_points or any frozen points have - // been temporarily moved to _nd, so _nd + _num_frozen_pts is the valid - // location limit. - return _data_store->save(data_file, (location_t)(_nd + _num_frozen_pts)); +template size_t Index::save_data(std::string data_file) +{ + // Note: at this point, either _nd == _max_points or any frozen points have + // been temporarily moved to _nd, so _nd + _num_frozen_pts is the valid + // location limit. + return _data_store->save(data_file, (location_t)(_nd + _num_frozen_pts)); } // save the graph index on a file as an adjacency list. For each point, // first store the number of neighbors, and then the neighbor list (each as // 4 byte uint32_t) -template -size_t Index::save_graph(std::string graph_file) { - return _graph_store->store(graph_file, _nd + _num_frozen_pts, _num_frozen_pts, - _start); +template size_t Index::save_graph(std::string graph_file) +{ + return _graph_store->store(graph_file, _nd + _num_frozen_pts, _num_frozen_pts, _start); } template -size_t Index::save_delete_list(const std::string &filename) { - if (_delete_set->size() == 0) { - return 0; - } - std::unique_ptr delete_list = - std::make_unique(_delete_set->size()); - uint32_t i = 0; - for (auto &del : *_delete_set) { - delete_list[i++] = del; - } - return save_bin(filename, delete_list.get(), _delete_set->size(), - 1); +size_t Index::save_delete_list(const std::string &filename) +{ + if (_delete_set->size() == 0) + { + return 0; + } + std::unique_ptr delete_list = std::make_unique(_delete_set->size()); + uint32_t i = 0; + for (auto &del : *_delete_set) + { + delete_list[i++] = del; + } + return save_bin(filename, delete_list.get(), _delete_set->size(), 1); } template -void Index::save(const char *filename, - bool compact_before_save) { - diskann::Timer timer; - - std::unique_lock ul(_update_lock); - std::unique_lock cl(_consolidate_lock); - std::unique_lock tl(_tag_lock); - std::unique_lock dl(_delete_lock); - - if (compact_before_save) { - compact_data(); - compact_frozen_point(); - } else { - if (!_data_compacted) { - throw ANNException( - "Index save for non-compacted index is not yet implemented", -1, - __FUNCSIG__, __FILE__, __LINE__); - } - } - - if (!_save_as_one_file) { - if (_filtered_index) { - if (_label_to_start_id.size() > 0) { - std::ofstream medoid_writer(std::string(filename) + - "_labels_to_medoids.txt"); - if (medoid_writer.fail()) { - throw diskann::ANNException( - std::string("Failed to open file ") + filename, -1); - } - for (auto iter : _label_to_start_id) { - medoid_writer << iter.first << ", " << iter.second << std::endl; - } - medoid_writer.close(); - } - - if (_use_universal_label) { - std::ofstream universal_label_writer(std::string(filename) + - "_universal_label.txt"); - assert(universal_label_writer.is_open()); - universal_label_writer << _universal_label << std::endl; - universal_label_writer.close(); - } - - if (_location_to_labels.size() > 0) { - std::ofstream label_writer(std::string(filename) + "_labels.txt"); - 
assert(label_writer.is_open()); - for (uint32_t i = 0; i < _nd + _num_frozen_pts; i++) { - for (uint32_t j = 0; j + 1 < _location_to_labels[i].size(); j++) { - label_writer << _location_to_labels[i][j] << ","; - } - if (_location_to_labels[i].size() != 0) - label_writer - << _location_to_labels[i][_location_to_labels[i].size() - 1]; - - label_writer << std::endl; - } - label_writer.close(); - - // write compacted raw_labels if data hence _location_to_labels was also - // compacted - if (compact_before_save && _dynamic_index) { - _label_map = - load_label_map(std::string(filename) + "_labels_map.txt"); - std::unordered_map mapped_to_raw_labels; - // invert label map - for (const auto &[key, value] : _label_map) { - mapped_to_raw_labels.insert({value, key}); - } - - // write updated labels - std::ofstream raw_label_writer(std::string(filename) + - "_raw_labels.txt"); - assert(raw_label_writer.is_open()); - for (uint32_t i = 0; i < _nd + _num_frozen_pts; i++) { - for (uint32_t j = 0; j + 1 < _location_to_labels[i].size(); j++) { - raw_label_writer - << mapped_to_raw_labels[_location_to_labels[i][j]] << ","; +void Index::save(const char *filename, bool compact_before_save) +{ + diskann::Timer timer; + + std::unique_lock ul(_update_lock); + std::unique_lock cl(_consolidate_lock); + std::unique_lock tl(_tag_lock); + std::unique_lock dl(_delete_lock); + + if (compact_before_save) + { + compact_data(); + compact_frozen_point(); + } + else + { + if (!_data_compacted) + { + throw ANNException("Index save for non-compacted index is not yet implemented", -1, __FUNCSIG__, __FILE__, + __LINE__); + } + } + + if (!_save_as_one_file) + { + if (_filtered_index) + { + if (_label_to_start_id.size() > 0) + { + std::ofstream medoid_writer(std::string(filename) + "_labels_to_medoids.txt"); + if (medoid_writer.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + filename, -1); + } + for (auto iter : _label_to_start_id) + { + medoid_writer << iter.first << ", " << iter.second << std::endl; + } + medoid_writer.close(); } - if (_location_to_labels[i].size() != 0) - raw_label_writer << mapped_to_raw_labels - [_location_to_labels[i] - [_location_to_labels[i].size() - 1]]; - - raw_label_writer << std::endl; - } - raw_label_writer.close(); - } - } - } - - std::string graph_file = std::string(filename); - std::string tags_file = std::string(filename) + ".tags"; - std::string data_file = std::string(filename) + ".data"; - std::string delete_list_file = std::string(filename) + ".del"; - - // Because the save_* functions use append mode, ensure that - // the files are deleted before save. Ideally, we should check - // the error code for delete_file, but will ignore now because - // delete should succeed if save will succeed. - delete_file(graph_file); - save_graph(graph_file); - delete_file(data_file); - save_data(data_file); - delete_file(tags_file); - save_tags(tags_file); - delete_file(delete_list_file); - save_delete_list(delete_list_file); - } else { - diskann::cout << "Save index in a single file currently not supported. " - "Not saving the index." - << std::endl; - } - // If frozen points were temporarily compacted to _nd, move back to - // _max_points. 
- reposition_frozen_point_to_end(); + if (_use_universal_label) + { + std::ofstream universal_label_writer(std::string(filename) + "_universal_label.txt"); + assert(universal_label_writer.is_open()); + universal_label_writer << _universal_label << std::endl; + universal_label_writer.close(); + } + + if (_location_to_labels.size() > 0) + { + std::ofstream label_writer(std::string(filename) + "_labels.txt"); + assert(label_writer.is_open()); + for (uint32_t i = 0; i < _nd + _num_frozen_pts; i++) + { + for (uint32_t j = 0; j + 1 < _location_to_labels[i].size(); j++) + { + label_writer << _location_to_labels[i][j] << ","; + } + if (_location_to_labels[i].size() != 0) + label_writer << _location_to_labels[i][_location_to_labels[i].size() - 1]; + + label_writer << std::endl; + } + label_writer.close(); + + // write compacted raw_labels if data hence _location_to_labels was also + // compacted + if (compact_before_save && _dynamic_index) + { + _label_map = load_label_map(std::string(filename) + "_labels_map.txt"); + std::unordered_map mapped_to_raw_labels; + // invert label map + for (const auto &[key, value] : _label_map) + { + mapped_to_raw_labels.insert({value, key}); + } + + // write updated labels + std::ofstream raw_label_writer(std::string(filename) + "_raw_labels.txt"); + assert(raw_label_writer.is_open()); + for (uint32_t i = 0; i < _nd + _num_frozen_pts; i++) + { + for (uint32_t j = 0; j + 1 < _location_to_labels[i].size(); j++) + { + raw_label_writer << mapped_to_raw_labels[_location_to_labels[i][j]] << ","; + } + if (_location_to_labels[i].size() != 0) + raw_label_writer + << mapped_to_raw_labels[_location_to_labels[i][_location_to_labels[i].size() - 1]]; + + raw_label_writer << std::endl; + } + raw_label_writer.close(); + } + } + } + + std::string graph_file = std::string(filename); + std::string tags_file = std::string(filename) + ".tags"; + std::string data_file = std::string(filename) + ".data"; + std::string delete_list_file = std::string(filename) + ".del"; + + // Because the save_* functions use append mode, ensure that + // the files are deleted before save. Ideally, we should check + // the error code for delete_file, but will ignore now because + // delete should succeed if save will succeed. + delete_file(graph_file); + save_graph(graph_file); + delete_file(data_file); + save_data(data_file); + delete_file(tags_file); + save_tags(tags_file); + delete_file(delete_list_file); + save_delete_list(delete_list_file); + } + else + { + diskann::cout << "Save index in a single file currently not supported. " + "Not saving the index." + << std::endl; + } + + // If frozen points were temporarily compacted to _nd, move back to + // _max_points. + reposition_frozen_point_to_end(); - diskann::cout << "Time taken for save: " << timer.elapsed() / 1000000.0 - << "s." << std::endl; + diskann::cout << "Time taken for save: " << timer.elapsed() / 1000000.0 << "s." << std::endl; } #ifdef EXEC_ENV_OLS template -size_t Index::load_tags(AlignedFileReader &reader) { +size_t Index::load_tags(AlignedFileReader &reader) +{ #else template -size_t Index::load_tags(const std::string tag_filename) { - if (_enable_tags && !file_exists(tag_filename)) { - diskann::cerr << "Tag file " << tag_filename << " does not exist!" 
- << std::endl; - throw diskann::ANNException("Tag file " + tag_filename + " does not exist!", - -1, __FUNCSIG__, __FILE__, __LINE__); - } +size_t Index::load_tags(const std::string tag_filename) +{ + if (_enable_tags && !file_exists(tag_filename)) + { + diskann::cerr << "Tag file " << tag_filename << " does not exist!" << std::endl; + throw diskann::ANNException("Tag file " + tag_filename + " does not exist!", -1, __FUNCSIG__, __FILE__, + __LINE__); + } #endif - if (!_enable_tags) { - diskann::cout << "Tags not loaded as tags not enabled." << std::endl; - return 0; - } + if (!_enable_tags) + { + diskann::cout << "Tags not loaded as tags not enabled." << std::endl; + return 0; + } - size_t file_dim, file_num_points; - TagT *tag_data; + size_t file_dim, file_num_points; + TagT *tag_data; #ifdef EXEC_ENV_OLS - load_bin(reader, tag_data, file_num_points, file_dim); + load_bin(reader, tag_data, file_num_points, file_dim); #else - load_bin(std::string(tag_filename), tag_data, file_num_points, - file_dim); + load_bin(std::string(tag_filename), tag_data, file_num_points, file_dim); #endif - if (file_dim != 1) { - std::stringstream stream; - stream << "ERROR: Found " << file_dim << " dimensions for tags," - << "but tag file must have 1 dimension." << std::endl; - diskann::cerr << stream.str() << std::endl; - delete[] tag_data; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } + if (file_dim != 1) + { + std::stringstream stream; + stream << "ERROR: Found " << file_dim << " dimensions for tags," + << "but tag file must have 1 dimension." << std::endl; + diskann::cerr << stream.str() << std::endl; + delete[] tag_data; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } - const size_t num_data_points = file_num_points - _num_frozen_pts; - _location_to_tag.reserve(num_data_points); - _tag_to_location.reserve(num_data_points); - for (uint32_t i = 0; i < (uint32_t)num_data_points; i++) { - TagT tag = *(tag_data + i); - if (_delete_set->find(i) == _delete_set->end()) { - _location_to_tag.set(i, tag); - _tag_to_location[tag] = i; + const size_t num_data_points = file_num_points - _num_frozen_pts; + _location_to_tag.reserve(num_data_points); + _tag_to_location.reserve(num_data_points); + for (uint32_t i = 0; i < (uint32_t)num_data_points; i++) + { + TagT tag = *(tag_data + i); + if (_delete_set->find(i) == _delete_set->end()) + { + _location_to_tag.set(i, tag); + _tag_to_location[tag] = i; + } } - } - diskann::cout << "Tags loaded." << std::endl; - delete[] tag_data; - return file_num_points; + diskann::cout << "Tags loaded." << std::endl; + delete[] tag_data; + return file_num_points; } template #ifdef EXEC_ENV_OLS -size_t Index::load_data(AlignedFileReader &reader) { +size_t Index::load_data(AlignedFileReader &reader) +{ #else -size_t Index::load_data(std::string filename) { +size_t Index::load_data(std::string filename) +{ #endif - size_t file_dim, file_num_points; + size_t file_dim, file_num_points; #ifdef EXEC_ENV_OLS - diskann::get_bin_metadata(reader, file_num_points, file_dim); + diskann::get_bin_metadata(reader, file_num_points, file_dim); #else - if (!file_exists(filename)) { - std::stringstream stream; - stream << "ERROR: data file " << filename << " does not exist." 
- << std::endl; - diskann::cerr << stream.str() << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - diskann::get_bin_metadata(filename, file_num_points, file_dim); + if (!file_exists(filename)) + { + std::stringstream stream; + stream << "ERROR: data file " << filename << " does not exist." << std::endl; + diskann::cerr << stream.str() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + diskann::get_bin_metadata(filename, file_num_points, file_dim); #endif - // since we are loading a new dataset, _empty_slots must be cleared - _empty_slots.clear(); + // since we are loading a new dataset, _empty_slots must be cleared + _empty_slots.clear(); - if (file_dim != _dim) { - std::stringstream stream; - stream << "ERROR: Driver requests loading " << _dim << " dimension," - << "but file has " << file_dim << " dimension." << std::endl; - diskann::cerr << stream.str() << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } + if (file_dim != _dim) + { + std::stringstream stream; + stream << "ERROR: Driver requests loading " << _dim << " dimension," + << "but file has " << file_dim << " dimension." << std::endl; + diskann::cerr << stream.str() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } - if (file_num_points > _max_points + _num_frozen_pts) { - // update and tag lock acquired in load() before calling load_data - resize(file_num_points - _num_frozen_pts); - } + if (file_num_points > _max_points + _num_frozen_pts) + { + // update and tag lock acquired in load() before calling load_data + resize(file_num_points - _num_frozen_pts); + } #ifdef EXEC_ENV_OLS - // REFACTOR TODO: Must figure out how to support aligned reader in a clean - // manner. - copy_aligned_data_from_file(reader, _data, file_num_points, file_dim, - _data_store->get_aligned_dim()); + // REFACTOR TODO: Must figure out how to support aligned reader in a clean + // manner. + copy_aligned_data_from_file(reader, _data, file_num_points, file_dim, _data_store->get_aligned_dim()); #else - _data_store->load(filename); // offset == 0. + _data_store->load(filename); // offset == 0. 
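    // A minimal sketch, assuming DiskANN's flat .bin layout (a 4-byte point
    // count and a 4-byte dimension, followed by npts * dim values of type T);
    // the load above relies on the data store for the actual parsing, but a
    // standalone reader of that assumed layout would look roughly like:
    //
    //   std::ifstream bin(filename, std::ios::binary);
    //   uint32_t npts = 0, dim = 0;
    //   bin.read((char *)&npts, sizeof(uint32_t));
    //   bin.read((char *)&dim, sizeof(uint32_t));
    //   std::vector<T> buf((size_t)npts * dim);
    //   bin.read((char *)buf.data(), buf.size() * sizeof(T));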
#endif - return file_num_points; + return file_num_points; } #ifdef EXEC_ENV_OLS template -size_t Index::load_delete_set(AlignedFileReader &reader) { +size_t Index::load_delete_set(AlignedFileReader &reader) +{ #else template -size_t Index::load_delete_set(const std::string &filename) { +size_t Index::load_delete_set(const std::string &filename) +{ #endif - std::unique_ptr delete_list; - size_t npts, ndim; + std::unique_ptr delete_list; + size_t npts, ndim; #ifdef EXEC_ENV_OLS - diskann::load_bin(reader, delete_list, npts, ndim); + diskann::load_bin(reader, delete_list, npts, ndim); #else - diskann::load_bin(filename, delete_list, npts, ndim); + diskann::load_bin(filename, delete_list, npts, ndim); #endif - assert(ndim == 1); - for (uint32_t i = 0; i < npts; i++) { - _delete_set->insert(delete_list[i]); - } - return npts; + assert(ndim == 1); + for (uint32_t i = 0; i < npts; i++) + { + _delete_set->insert(delete_list[i]); + } + return npts; } // load the index from file and update the max_degree, cur (navigating // node loc), and _final_graph (adjacency list) template #ifdef EXEC_ENV_OLS -void Index::load(AlignedFileReader &reader, - uint32_t num_threads, uint32_t search_l) { +void Index::load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l) +{ #else -void Index::load(const char *filename, uint32_t num_threads, - uint32_t search_l) { +void Index::load(const char *filename, uint32_t num_threads, uint32_t search_l) +{ #endif - std::unique_lock ul(_update_lock); - std::unique_lock cl(_consolidate_lock); - std::unique_lock tl(_tag_lock); - std::unique_lock dl(_delete_lock); + std::unique_lock ul(_update_lock); + std::unique_lock cl(_consolidate_lock); + std::unique_lock tl(_tag_lock); + std::unique_lock dl(_delete_lock); - _has_built = true; + _has_built = true; - size_t tags_file_num_pts = 0, graph_num_pts = 0, data_file_num_pts = 0, - label_num_pts = 0; + size_t tags_file_num_pts = 0, graph_num_pts = 0, data_file_num_pts = 0, label_num_pts = 0; - std::string mem_index_file(filename); - std::string labels_file = mem_index_file + "_labels.txt"; - std::string labels_to_medoids = mem_index_file + "_labels_to_medoids.txt"; - std::string labels_map_file = mem_index_file + "_labels_map.txt"; + std::string mem_index_file(filename); + std::string labels_file = mem_index_file + "_labels.txt"; + std::string labels_to_medoids = mem_index_file + "_labels_to_medoids.txt"; + std::string labels_map_file = mem_index_file + "_labels_map.txt"; - if (!_save_as_one_file) { - // For DLVS Store, we will not support saving the index in multiple - // files. + if (!_save_as_one_file) + { + // For DLVS Store, we will not support saving the index in multiple + // files. 
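    // For reference, the multi-file layout consumed in this branch (file
    // names taken from the code in this function) is roughly:
    //   <filename>          graph adjacency lists
    //   <filename>.data     raw vectors
    //   <filename>.tags     per-point tags (when tags are enabled)
    //   <filename>.del      delete list, if one was saved
    //   <filename>_labels.txt, _labels_map.txt, _labels_to_medoids.txt and
    //   _universal_label.txt hold the filter metadata, when present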
#ifndef EXEC_ENV_OLS - std::string data_file = std::string(filename) + ".data"; - std::string tags_file = std::string(filename) + ".tags"; - std::string delete_set_file = std::string(filename) + ".del"; - std::string graph_file = std::string(filename); - data_file_num_pts = load_data(data_file); - if (file_exists(delete_set_file)) { - load_delete_set(delete_set_file); - } - if (_enable_tags) { - tags_file_num_pts = load_tags(tags_file); - } - graph_num_pts = load_graph(graph_file, data_file_num_pts); + std::string data_file = std::string(filename) + ".data"; + std::string tags_file = std::string(filename) + ".tags"; + std::string delete_set_file = std::string(filename) + ".del"; + std::string graph_file = std::string(filename); + data_file_num_pts = load_data(data_file); + if (file_exists(delete_set_file)) + { + load_delete_set(delete_set_file); + } + if (_enable_tags) + { + tags_file_num_pts = load_tags(tags_file); + } + graph_num_pts = load_graph(graph_file, data_file_num_pts); #endif - } else { - diskann::cout << "Single index file saving/loading support not yet " - "enabled. Not loading the index." + } + else + { + diskann::cout << "Single index file saving/loading support not yet " + "enabled. Not loading the index." + << std::endl; + return; + } + + if (data_file_num_pts != graph_num_pts || (data_file_num_pts != tags_file_num_pts && _enable_tags)) + { + std::stringstream stream; + stream << "ERROR: When loading index, loaded " << data_file_num_pts << " points from datafile, " + << graph_num_pts << " from graph, and " << tags_file_num_pts + << " tags, with num_frozen_pts being set to " << _num_frozen_pts << " in constructor." << std::endl; + diskann::cerr << stream.str() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + + if (file_exists(labels_file)) + { + _label_map = load_label_map(labels_map_file); + parse_label_file(labels_file, label_num_pts); + assert(label_num_pts == data_file_num_pts - _num_frozen_pts); + if (file_exists(labels_to_medoids)) + { + std::ifstream medoid_stream(labels_to_medoids); + std::string line, token; + uint32_t line_cnt = 0; + + _label_to_start_id.clear(); + + while (std::getline(medoid_stream, line)) + { + std::istringstream iss(line); + uint32_t cnt = 0; + uint32_t medoid = 0; + LabelT label; + while (std::getline(iss, token, ',')) + { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + LabelT token_as_num = (LabelT)std::stoul(token); + if (cnt == 0) + label = token_as_num; + else + medoid = token_as_num; + cnt++; + } + _label_to_start_id[label] = medoid; + line_cnt++; + } + } + + std::string universal_label_file(filename); + universal_label_file += "_universal_label.txt"; + if (file_exists(universal_label_file)) + { + std::ifstream universal_label_reader(universal_label_file); + universal_label_reader >> _universal_label; + _use_universal_label = true; + universal_label_reader.close(); + } + } + + _nd = data_file_num_pts - _num_frozen_pts; + _empty_slots.clear(); + _empty_slots.reserve(_max_points); + for (auto i = _nd; i < _max_points; i++) + { + _empty_slots.insert((uint32_t)i); + } + + reposition_frozen_point_to_end(); + diskann::cout << "Num frozen points:" << _num_frozen_pts << " _nd: " << _nd << " _start: " << _start + << " size(_location_to_tag): " << _location_to_tag.size() + << " size(_tag_to_location):" << _tag_to_location.size() << " Max points: " << _max_points << std::endl; - return; - } - 
if (data_file_num_pts != graph_num_pts || - (data_file_num_pts != tags_file_num_pts && _enable_tags)) { - std::stringstream stream; - stream << "ERROR: When loading index, loaded " << data_file_num_pts - << " points from datafile, " << graph_num_pts << " from graph, and " - << tags_file_num_pts << " tags, with num_frozen_pts being set to " - << _num_frozen_pts << " in constructor." << std::endl; - diskann::cerr << stream.str() << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - - if (file_exists(labels_file)) { - _label_map = load_label_map(labels_map_file); - parse_label_file(labels_file, label_num_pts); - assert(label_num_pts == data_file_num_pts - _num_frozen_pts); - if (file_exists(labels_to_medoids)) { - std::ifstream medoid_stream(labels_to_medoids); - std::string line, token; - uint32_t line_cnt = 0; - - _label_to_start_id.clear(); - - while (std::getline(medoid_stream, line)) { - std::istringstream iss(line); - uint32_t cnt = 0; - uint32_t medoid = 0; - LabelT label; - while (std::getline(iss, token, ',')) { - token.erase(std::remove(token.begin(), token.end(), '\n'), - token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), - token.end()); - LabelT token_as_num = (LabelT)std::stoul(token); - if (cnt == 0) - label = token_as_num; - else - medoid = token_as_num; - cnt++; - } - _label_to_start_id[label] = medoid; - line_cnt++; - } - } - - std::string universal_label_file(filename); - universal_label_file += "_universal_label.txt"; - if (file_exists(universal_label_file)) { - std::ifstream universal_label_reader(universal_label_file); - universal_label_reader >> _universal_label; - _use_universal_label = true; - universal_label_reader.close(); - } - } - - _nd = data_file_num_pts - _num_frozen_pts; - _empty_slots.clear(); - _empty_slots.reserve(_max_points); - for (auto i = _nd; i < _max_points; i++) { - _empty_slots.insert((uint32_t)i); - } - - reposition_frozen_point_to_end(); - diskann::cout << "Num frozen points:" << _num_frozen_pts << " _nd: " << _nd - << " _start: " << _start - << " size(_location_to_tag): " << _location_to_tag.size() - << " size(_tag_to_location):" << _tag_to_location.size() - << " Max points: " << _max_points << std::endl; - - // For incremental index, _query_scratch is initialized in the constructor. - // For the bulk index, the params required to initialize _query_scratch - // are known only at load time, hence this check and the call to - // initialize_q_s(). - if (_query_scratch.size() == 0) { - initialize_query_scratch(num_threads, search_l, search_l, - (uint32_t)_graph_store->get_max_range_of_graph(), - _indexingMaxC, _dim); - } + // For incremental index, _query_scratch is initialized in the constructor. + // For the bulk index, the params required to initialize _query_scratch + // are known only at load time, hence this check and the call to + // initialize_q_s(). 
+ if (_query_scratch.size() == 0) + { + initialize_query_scratch(num_threads, search_l, search_l, (uint32_t)_graph_store->get_max_range_of_graph(), + _indexingMaxC, _dim); + } } #ifndef EXEC_ENV_OLS template -size_t Index::get_graph_num_frozen_points( - const std::string &graph_file) { - size_t expected_file_size; - uint32_t max_observed_degree, start; - size_t file_frozen_pts; +size_t Index::get_graph_num_frozen_points(const std::string &graph_file) +{ + size_t expected_file_size; + uint32_t max_observed_degree, start; + size_t file_frozen_pts; - std::ifstream in; - in.exceptions(std::ios::badbit | std::ios::failbit); + std::ifstream in; + in.exceptions(std::ios::badbit | std::ios::failbit); - in.open(graph_file, std::ios::binary); - in.read((char *)&expected_file_size, sizeof(size_t)); - in.read((char *)&max_observed_degree, sizeof(uint32_t)); - in.read((char *)&start, sizeof(uint32_t)); - in.read((char *)&file_frozen_pts, sizeof(size_t)); + in.open(graph_file, std::ios::binary); + in.read((char *)&expected_file_size, sizeof(size_t)); + in.read((char *)&max_observed_degree, sizeof(uint32_t)); + in.read((char *)&start, sizeof(uint32_t)); + in.read((char *)&file_frozen_pts, sizeof(size_t)); - return file_frozen_pts; + return file_frozen_pts; } #endif #ifdef EXEC_ENV_OLS template -size_t Index::load_graph(AlignedFileReader &reader, - size_t expected_num_points) { +size_t Index::load_graph(AlignedFileReader &reader, size_t expected_num_points) +{ #else template -size_t Index::load_graph(std::string filename, - size_t expected_num_points) { +size_t Index::load_graph(std::string filename, size_t expected_num_points) +{ #endif - auto res = _graph_store->load(filename, expected_num_points); - _start = std::get<1>(res); - _num_frozen_pts = std::get<2>(res); - return std::get<0>(res); + auto res = _graph_store->load(filename, expected_num_points); + _start = std::get<1>(res); + _num_frozen_pts = std::get<2>(res); + return std::get<0>(res); } template -int Index::_get_vector_by_tag(TagType &tag, DataType &vec) { - try { - TagT tag_val = std::any_cast(tag); - T *vec_val = std::any_cast(vec); - return this->get_vector_by_tag(tag_val, vec_val); - } catch (const std::bad_any_cast &e) { - throw ANNException( - "Error: bad any cast while performing _get_vector_by_tags() " + - std::string(e.what()), - -1); - } catch (const std::exception &e) { - throw ANNException("Error: " + std::string(e.what()), -1); - } +int Index::_get_vector_by_tag(TagType &tag, DataType &vec) +{ + try + { + TagT tag_val = std::any_cast(tag); + T *vec_val = std::any_cast(vec); + return this->get_vector_by_tag(tag_val, vec_val); + } + catch (const std::bad_any_cast &e) + { + throw ANNException("Error: bad any cast while performing _get_vector_by_tags() " + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error: " + std::string(e.what()), -1); + } } -template -int Index::get_vector_by_tag(TagT &tag, T *vec) { - std::shared_lock lock(_tag_lock); - if (_tag_to_location.find(tag) == _tag_to_location.end()) { - diskann::cout << "Tag " << get_tag_string(tag) << " does not exist" - << std::endl; - return -1; - } +template int Index::get_vector_by_tag(TagT &tag, T *vec) +{ + std::shared_lock lock(_tag_lock); + if (_tag_to_location.find(tag) == _tag_to_location.end()) + { + diskann::cout << "Tag " << get_tag_string(tag) << " does not exist" << std::endl; + return -1; + } - location_t location = _tag_to_location[tag]; - _data_store->get_vector(location, vec); + location_t location = 
_tag_to_location[tag]; + _data_store->get_vector(location, vec); - return 0; + return 0; } -template -uint32_t Index::calculate_entry_point() { - // REFACTOR TODO: This function does not support multi-threaded calculation of - // medoid. Must revisit if perf is a concern. - return _data_store->calculate_medoid(); +template uint32_t Index::calculate_entry_point() +{ + // REFACTOR TODO: This function does not support multi-threaded calculation of + // medoid. Must revisit if perf is a concern. + return _data_store->calculate_medoid(); } -template -std::vector Index::get_init_ids() { - std::vector init_ids; - init_ids.reserve(1 + _num_frozen_pts); +template std::vector Index::get_init_ids() +{ + std::vector init_ids; + init_ids.reserve(1 + _num_frozen_pts); - init_ids.emplace_back(_start); + init_ids.emplace_back(_start); - for (uint32_t frozen = (uint32_t)_max_points; - frozen < _max_points + _num_frozen_pts; frozen++) { - if (frozen != _start) { - init_ids.emplace_back(frozen); + for (uint32_t frozen = (uint32_t)_max_points; frozen < _max_points + _num_frozen_pts; frozen++) + { + if (frozen != _start) + { + init_ids.emplace_back(frozen); + } } - } - return init_ids; + return init_ids; } // Find common filter between a node's labels and a given set of labels, while // taking into account universal label template -bool Index::detect_common_filters( - uint32_t point_id, bool search_invocation, - const std::vector &incoming_labels) { - auto &curr_node_labels = _location_to_labels[point_id]; - std::vector common_filters; - std::set_intersection(incoming_labels.begin(), incoming_labels.end(), - curr_node_labels.begin(), curr_node_labels.end(), - std::back_inserter(common_filters)); - if (common_filters.size() > 0) { - // This is to reduce the repetitive calls. If common_filters size is > 0 , - // we dont need to check further for universal label - return true; - } - if (_use_universal_label) { - if (!search_invocation) { - if (std::find(incoming_labels.begin(), incoming_labels.end(), - _universal_label) != incoming_labels.end() || - std::find(curr_node_labels.begin(), curr_node_labels.end(), - _universal_label) != curr_node_labels.end()) - common_filters.push_back(_universal_label); - } else { - if (std::find(curr_node_labels.begin(), curr_node_labels.end(), - _universal_label) != curr_node_labels.end()) - common_filters.push_back(_universal_label); - } - } - return (common_filters.size() > 0); +bool Index::detect_common_filters(uint32_t point_id, bool search_invocation, + const std::vector &incoming_labels) +{ + auto &curr_node_labels = _location_to_labels[point_id]; + std::vector common_filters; + std::set_intersection(incoming_labels.begin(), incoming_labels.end(), curr_node_labels.begin(), + curr_node_labels.end(), std::back_inserter(common_filters)); + if (common_filters.size() > 0) + { + // This is to reduce the repetitive calls. 
If common_filters size is > 0 , + // we dont need to check further for universal label + return true; + } + if (_use_universal_label) + { + if (!search_invocation) + { + if (std::find(incoming_labels.begin(), incoming_labels.end(), _universal_label) != incoming_labels.end() || + std::find(curr_node_labels.begin(), curr_node_labels.end(), _universal_label) != curr_node_labels.end()) + common_filters.push_back(_universal_label); + } + else + { + if (std::find(curr_node_labels.begin(), curr_node_labels.end(), _universal_label) != curr_node_labels.end()) + common_filters.push_back(_universal_label); + } + } + return (common_filters.size() > 0); } template std::pair Index::iterate_to_fixed_point( - InMemQueryScratch *scratch, const uint32_t Lsize, - const std::vector &init_ids, bool use_filter, - const std::vector &filter_labels, bool search_invocation) { - std::vector &expanded_nodes = scratch->pool(); - NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes(); - best_L_nodes.reserve(Lsize); - tsl::robin_set &inserted_into_pool_rs = - scratch->inserted_into_pool_rs(); - boost::dynamic_bitset<> &inserted_into_pool_bs = - scratch->inserted_into_pool_bs(); - std::vector &id_scratch = scratch->id_scratch(); - std::vector &dist_scratch = scratch->dist_scratch(); - assert(id_scratch.size() == 0); - - T *aligned_query = scratch->aligned_query(); - - float *pq_dists = nullptr; - - _pq_data_store->preprocess_query(aligned_query, scratch); - - if (expanded_nodes.size() > 0 || id_scratch.size() > 0) { - throw ANNException("ERROR: Clear scratch space before passing.", -1, - __FUNCSIG__, __FILE__, __LINE__); - } - - // Decide whether to use bitset or robin set to mark visited nodes - auto total_num_points = _max_points + _num_frozen_pts; - bool fast_iterate = total_num_points <= MAX_POINTS_FOR_USING_BITSET; - - if (fast_iterate) { - if (inserted_into_pool_bs.size() < total_num_points) { - // hopefully using 2X will reduce the number of allocations. - auto resize_size = 2 * total_num_points > MAX_POINTS_FOR_USING_BITSET - ? MAX_POINTS_FOR_USING_BITSET - : 2 * total_num_points; - inserted_into_pool_bs.resize(resize_size); - } - } - - // Lambda to determine if a node has been visited - auto is_not_visited = [this, fast_iterate, &inserted_into_pool_bs, - &inserted_into_pool_rs](const uint32_t id) { - return fast_iterate - ? 
inserted_into_pool_bs[id] == 0 - : inserted_into_pool_rs.find(id) == inserted_into_pool_rs.end(); - }; - - // Lambda to batch compute query<-> node distances in PQ space - auto compute_dists = [this, scratch, - pq_dists](const std::vector &ids, - std::vector &dists_out) { - _pq_data_store->get_distance(scratch->aligned_query(), ids, dists_out, - scratch); - }; - - // Initialize the candidate pool with starting points - for (auto id : init_ids) { - if (id >= _max_points + _num_frozen_pts) { - diskann::cerr << "Out of range loc found as an edge : " << id - << std::endl; - throw diskann::ANNException(std::string("Wrong loc") + std::to_string(id), - -1, __FUNCSIG__, __FILE__, __LINE__); - } - - if (use_filter) { - if (!detect_common_filters(id, search_invocation, filter_labels)) - continue; - } - - if (is_not_visited(id)) { - if (fast_iterate) { - inserted_into_pool_bs[id] = 1; - } else { - inserted_into_pool_rs.insert(id); - } - - float distance; - uint32_t ids[] = {id}; - float distances[] = {std::numeric_limits::max()}; - _pq_data_store->get_distance(aligned_query, ids, 1, distances, scratch); - distance = distances[0]; - - Neighbor nn = Neighbor(id, distance); - best_L_nodes.insert(nn); - } - } - - uint32_t hops = 0; - uint32_t cmps = 0; - - while (best_L_nodes.has_unexpanded_node()) { - auto nbr = best_L_nodes.closest_unexpanded(); - auto n = nbr.id; - - // Add node to expanded nodes to create pool for prune later - if (!search_invocation) { - if (!use_filter) { - expanded_nodes.emplace_back(nbr); - } else { // in filter based indexing, the same point might invoke - // multiple iterate_to_fixed_points, so need to be careful - // not to add the same item to pool multiple times. - if (std::find(expanded_nodes.begin(), expanded_nodes.end(), nbr) == - expanded_nodes.end()) { - expanded_nodes.emplace_back(nbr); - } - } - } - - // Find which of the nodes in des have not been visited before - id_scratch.clear(); - dist_scratch.clear(); - if (_dynamic_index) { - LockGuard guard(_locks[n]); - for (auto id : _graph_store->get_neighbours(n)) { - assert(id < _max_points + _num_frozen_pts); - - if (use_filter) { - // NOTE: NEED TO CHECK IF THIS CORRECT WITH NEW LOCKS. - if (!detect_common_filters(id, search_invocation, filter_labels)) - continue; + InMemQueryScratch *scratch, const uint32_t Lsize, const std::vector &init_ids, bool use_filter, + const std::vector &filter_labels, bool search_invocation) +{ + std::vector &expanded_nodes = scratch->pool(); + NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes(); + best_L_nodes.reserve(Lsize); + tsl::robin_set &inserted_into_pool_rs = scratch->inserted_into_pool_rs(); + boost::dynamic_bitset<> &inserted_into_pool_bs = scratch->inserted_into_pool_bs(); + std::vector &id_scratch = scratch->id_scratch(); + std::vector &dist_scratch = scratch->dist_scratch(); + assert(id_scratch.size() == 0); + + T *aligned_query = scratch->aligned_query(); + + float *pq_dists = nullptr; + + _pq_data_store->preprocess_query(aligned_query, scratch); + + if (expanded_nodes.size() > 0 || id_scratch.size() > 0) + { + throw ANNException("ERROR: Clear scratch space before passing.", -1, __FUNCSIG__, __FILE__, __LINE__); + } + + // Decide whether to use bitset or robin set to mark visited nodes + auto total_num_points = _max_points + _num_frozen_pts; + bool fast_iterate = total_num_points <= MAX_POINTS_FOR_USING_BITSET; + + if (fast_iterate) + { + if (inserted_into_pool_bs.size() < total_num_points) + { + // hopefully using 2X will reduce the number of allocations. 
+ auto resize_size = + 2 * total_num_points > MAX_POINTS_FOR_USING_BITSET ? MAX_POINTS_FOR_USING_BITSET : 2 * total_num_points; + inserted_into_pool_bs.resize(resize_size); } + } + + // Lambda to determine if a node has been visited + auto is_not_visited = [this, fast_iterate, &inserted_into_pool_bs, &inserted_into_pool_rs](const uint32_t id) { + return fast_iterate ? inserted_into_pool_bs[id] == 0 + : inserted_into_pool_rs.find(id) == inserted_into_pool_rs.end(); + }; + + // Lambda to batch compute query<-> node distances in PQ space + auto compute_dists = [this, scratch, pq_dists](const std::vector &ids, std::vector &dists_out) { + _pq_data_store->get_distance(scratch->aligned_query(), ids, dists_out, scratch); + }; - if (is_not_visited(id)) { - id_scratch.push_back(id); + // Initialize the candidate pool with starting points + for (auto id : init_ids) + { + if (id >= _max_points + _num_frozen_pts) + { + diskann::cerr << "Out of range loc found as an edge : " << id << std::endl; + throw diskann::ANNException(std::string("Wrong loc") + std::to_string(id), -1, __FUNCSIG__, __FILE__, + __LINE__); } - } - } else { - _locks[n].lock(); - auto nbrs = _graph_store->get_neighbours(n); - _locks[n].unlock(); - for (auto id : nbrs) { - assert(id < _max_points + _num_frozen_pts); - if (use_filter) { - // NOTE: NEED TO CHECK IF THIS CORRECT WITH NEW LOCKS. - if (!detect_common_filters(id, search_invocation, filter_labels)) - continue; + if (use_filter) + { + if (!detect_common_filters(id, search_invocation, filter_labels)) + continue; } - if (is_not_visited(id)) { - id_scratch.push_back(id); + if (is_not_visited(id)) + { + if (fast_iterate) + { + inserted_into_pool_bs[id] = 1; + } + else + { + inserted_into_pool_rs.insert(id); + } + + float distance; + uint32_t ids[] = {id}; + float distances[] = {std::numeric_limits::max()}; + _pq_data_store->get_distance(aligned_query, ids, 1, distances, scratch); + distance = distances[0]; + + Neighbor nn = Neighbor(id, distance); + best_L_nodes.insert(nn); } - } } - // Mark nodes visited - for (auto id : id_scratch) { - if (fast_iterate) { - inserted_into_pool_bs[id] = 1; - } else { - inserted_into_pool_rs.insert(id); - } - } + uint32_t hops = 0; + uint32_t cmps = 0; + + while (best_L_nodes.has_unexpanded_node()) + { + auto nbr = best_L_nodes.closest_unexpanded(); + auto n = nbr.id; + + // Add node to expanded nodes to create pool for prune later + if (!search_invocation) + { + if (!use_filter) + { + expanded_nodes.emplace_back(nbr); + } + else + { // in filter based indexing, the same point might invoke + // multiple iterate_to_fixed_points, so need to be careful + // not to add the same item to pool multiple times. + if (std::find(expanded_nodes.begin(), expanded_nodes.end(), nbr) == expanded_nodes.end()) + { + expanded_nodes.emplace_back(nbr); + } + } + } + + // Find which of the nodes in des have not been visited before + id_scratch.clear(); + dist_scratch.clear(); + if (_dynamic_index) + { + LockGuard guard(_locks[n]); + for (auto id : _graph_store->get_neighbours(n)) + { + assert(id < _max_points + _num_frozen_pts); + + if (use_filter) + { + // NOTE: NEED TO CHECK IF THIS CORRECT WITH NEW LOCKS. 
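For reference, the filter gate applied through detect_common_filters above reduces to checking for a shared label, with the universal label acting as a wildcard: during indexing either side carrying it is enough, while during search only the candidate node's side counts. The standalone sketch below illustrates that logic; the free-function shape, the sorted-label assumption, and the use of std::set_intersection are simplifications introduced only so the snippet compiles on its own.

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <vector>

using LabelT = uint32_t;

// Returns true if the candidate node passes the filter: it shares a concrete
// label with the incoming/query labels, or the universal label applies.
bool has_common_filter(const std::vector<LabelT> &incoming_labels, // assumed sorted
                       const std::vector<LabelT> &node_labels,     // assumed sorted
                       bool search_invocation, bool use_universal_label, LabelT universal_label)
{
    std::vector<LabelT> common;
    std::set_intersection(incoming_labels.begin(), incoming_labels.end(), node_labels.begin(), node_labels.end(),
                          std::back_inserter(common));
    if (!common.empty())
        return true;

    if (!use_universal_label)
        return false;

    // During search only the candidate node's universal label counts; during
    // indexing either side having it is enough.
    bool node_has_universal = std::find(node_labels.begin(), node_labels.end(), universal_label) != node_labels.end();
    if (search_invocation)
        return node_has_universal;

    bool query_has_universal =
        std::find(incoming_labels.begin(), incoming_labels.end(), universal_label) != incoming_labels.end();
    return node_has_universal || query_has_universal;
}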
+ if (!detect_common_filters(id, search_invocation, filter_labels)) + continue; + } + + if (is_not_visited(id)) + { + id_scratch.push_back(id); + } + } + } + else + { + _locks[n].lock(); + auto nbrs = _graph_store->get_neighbours(n); + _locks[n].unlock(); + for (auto id : nbrs) + { + assert(id < _max_points + _num_frozen_pts); + + if (use_filter) + { + // NOTE: NEED TO CHECK IF THIS CORRECT WITH NEW LOCKS. + if (!detect_common_filters(id, search_invocation, filter_labels)) + continue; + } + + if (is_not_visited(id)) + { + id_scratch.push_back(id); + } + } + } + + // Mark nodes visited + for (auto id : id_scratch) + { + if (fast_iterate) + { + inserted_into_pool_bs[id] = 1; + } + else + { + inserted_into_pool_rs.insert(id); + } + } - assert(dist_scratch.capacity() >= id_scratch.size()); - compute_dists(id_scratch, dist_scratch); - cmps += (uint32_t)id_scratch.size(); + assert(dist_scratch.capacity() >= id_scratch.size()); + compute_dists(id_scratch, dist_scratch); + cmps += (uint32_t)id_scratch.size(); - // Insert pairs into the pool of candidates - for (size_t m = 0; m < id_scratch.size(); ++m) { - best_L_nodes.insert(Neighbor(id_scratch[m], dist_scratch[m])); + // Insert pairs into the pool of candidates + for (size_t m = 0; m < id_scratch.size(); ++m) + { + best_L_nodes.insert(Neighbor(id_scratch[m], dist_scratch[m])); + } } - } - return std::make_pair(hops, cmps); + return std::make_pair(hops, cmps); } template -void Index::search_for_point_and_prune( - int location, uint32_t Lindex, std::vector &pruned_list, - InMemQueryScratch *scratch, bool use_filter, uint32_t filteredLindex) { - const std::vector init_ids = get_init_ids(); - const std::vector unused_filter_label; - - if (!use_filter) { - _data_store->get_vector(location, scratch->aligned_query()); - iterate_to_fixed_point(scratch, Lindex, init_ids, false, - unused_filter_label, false); - } else { - std::shared_lock tl(_tag_lock, std::defer_lock); - if (_dynamic_index) - tl.lock(); - std::vector filter_specific_start_nodes; - for (auto &x : _location_to_labels[location]) - filter_specific_start_nodes.emplace_back(_label_to_start_id[x]); - - if (_dynamic_index) - tl.unlock(); - - _data_store->get_vector(location, scratch->aligned_query()); - iterate_to_fixed_point(scratch, filteredLindex, filter_specific_start_nodes, - true, _location_to_labels[location], false); +void Index::search_for_point_and_prune(int location, uint32_t Lindex, + std::vector &pruned_list, + InMemQueryScratch *scratch, bool use_filter, + uint32_t filteredLindex) +{ + const std::vector init_ids = get_init_ids(); + const std::vector unused_filter_label; - // combine candidate pools obtained with filter and unfiltered criteria. 
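The filtered build path below runs two searches per point (filtered, then unfiltered) and merges their candidate pools before pruning. A minimal sketch of that merge follows; Neighbor is reduced here to an id/distance pair ordered by distance, and that ordering, like the free-function form, is an assumption made only to keep the sketch self-contained.

#include <cstdint>
#include <set>
#include <vector>

struct Neighbor
{
    uint32_t id;
    float distance;
    // Ordered by distance (ties broken by id) so the merged pool comes out
    // sorted the way the subsequent prune expects.
    bool operator<(const Neighbor &other) const
    {
        return distance < other.distance || (distance == other.distance && id < other.id);
    }
};

// Merge the candidate pools of the filtered and unfiltered searches,
// keeping each candidate once.
std::vector<Neighbor> merge_candidate_pools(const std::vector<Neighbor> &filtered_pool,
                                            const std::vector<Neighbor> &unfiltered_pool)
{
    std::set<Neighbor> best_candidate_pool(filtered_pool.begin(), filtered_pool.end());
    for (const auto &nbr : unfiltered_pool)
        best_candidate_pool.insert(nbr); // std::set already skips duplicates
    return std::vector<Neighbor>(best_candidate_pool.begin(), best_candidate_pool.end());
}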
- std::set best_candidate_pool; - for (auto filtered_neighbor : scratch->pool()) { - best_candidate_pool.insert(filtered_neighbor); + if (!use_filter) + { + _data_store->get_vector(location, scratch->aligned_query()); + iterate_to_fixed_point(scratch, Lindex, init_ids, false, unused_filter_label, false); } + else + { + std::shared_lock tl(_tag_lock, std::defer_lock); + if (_dynamic_index) + tl.lock(); + std::vector filter_specific_start_nodes; + for (auto &x : _location_to_labels[location]) + filter_specific_start_nodes.emplace_back(_label_to_start_id[x]); + + if (_dynamic_index) + tl.unlock(); + + _data_store->get_vector(location, scratch->aligned_query()); + iterate_to_fixed_point(scratch, filteredLindex, filter_specific_start_nodes, true, + _location_to_labels[location], false); + + // combine candidate pools obtained with filter and unfiltered criteria. + std::set best_candidate_pool; + for (auto filtered_neighbor : scratch->pool()) + { + best_candidate_pool.insert(filtered_neighbor); + } - // clear scratch for finding unfiltered candidates - scratch->clear(); + // clear scratch for finding unfiltered candidates + scratch->clear(); - _data_store->get_vector(location, scratch->aligned_query()); - iterate_to_fixed_point(scratch, Lindex, init_ids, false, - unused_filter_label, false); + _data_store->get_vector(location, scratch->aligned_query()); + iterate_to_fixed_point(scratch, Lindex, init_ids, false, unused_filter_label, false); - for (auto unfiltered_neighbour : scratch->pool()) { - // insert if this neighbour is not already in best_candidate_pool - if (best_candidate_pool.find(unfiltered_neighbour) == - best_candidate_pool.end()) { - best_candidate_pool.insert(unfiltered_neighbour); - } - } + for (auto unfiltered_neighbour : scratch->pool()) + { + // insert if this neighbour is not already in best_candidate_pool + if (best_candidate_pool.find(unfiltered_neighbour) == best_candidate_pool.end()) + { + best_candidate_pool.insert(unfiltered_neighbour); + } + } - scratch->pool().clear(); - std::copy(best_candidate_pool.begin(), best_candidate_pool.end(), - std::back_inserter(scratch->pool())); - } + scratch->pool().clear(); + std::copy(best_candidate_pool.begin(), best_candidate_pool.end(), std::back_inserter(scratch->pool())); + } - auto &pool = scratch->pool(); + auto &pool = scratch->pool(); - for (uint32_t i = 0; i < pool.size(); i++) { - if (pool[i].id == (uint32_t)location) { - pool.erase(pool.begin() + i); - i--; + for (uint32_t i = 0; i < pool.size(); i++) + { + if (pool[i].id == (uint32_t)location) + { + pool.erase(pool.begin() + i); + i--; + } } - } - if (pruned_list.size() > 0) { - throw diskann::ANNException("ERROR: non-empty pruned_list passed", -1, - __FUNCSIG__, __FILE__, __LINE__); - } + if (pruned_list.size() > 0) + { + throw diskann::ANNException("ERROR: non-empty pruned_list passed", -1, __FUNCSIG__, __FILE__, __LINE__); + } - prune_neighbors(location, pool, pruned_list, scratch); + prune_neighbors(location, pool, pruned_list, scratch); - assert(!pruned_list.empty()); - assert(_graph_store->get_total_points() == _max_points + _num_frozen_pts); + assert(!pruned_list.empty()); + assert(_graph_store->get_total_points() == _max_points + _num_frozen_pts); } template -void Index::occlude_list( - const uint32_t location, std::vector &pool, const float alpha, - const uint32_t degree, const uint32_t maxc, std::vector &result, - InMemQueryScratch *scratch, - const tsl::robin_set *const delete_set_ptr) { - if (pool.size() == 0) - return; - - // Truncate pool at maxc and 
initialize scratch spaces - assert(std::is_sorted(pool.begin(), pool.end())); - assert(result.size() == 0); - if (pool.size() > maxc) - pool.resize(maxc); - std::vector &occlude_factor = scratch->occlude_factor(); - // occlude_list can be called with the same scratch more than once by - // search_for_point_and_add_link through inter_insert. - occlude_factor.clear(); - // Initialize occlude_factor to pool.size() many 0.0f values for correctness - occlude_factor.insert(occlude_factor.end(), pool.size(), 0.0f); - - float cur_alpha = 1; - while (cur_alpha <= alpha && result.size() < degree) { - // used for MIPS, where we store a value of eps in cur_alpha to - // denote pruned out entries which we can skip in later rounds. - float eps = cur_alpha + 0.01f; - - for (auto iter = pool.begin(); result.size() < degree && iter != pool.end(); - ++iter) { - if (occlude_factor[iter - pool.begin()] > cur_alpha) { - continue; - } - // Set the entry to float::max so that is not considered again - occlude_factor[iter - pool.begin()] = std::numeric_limits::max(); - // Add the entry to the result if its not been deleted, and doesn't - // add a self loop - if (delete_set_ptr == nullptr || - delete_set_ptr->find(iter->id) == delete_set_ptr->end()) { - if (iter->id != location) { - result.push_back(iter->id); - } - } - - // Update occlude factor for points from iter+1 to pool.end() - for (auto iter2 = iter + 1; iter2 != pool.end(); iter2++) { - auto t = iter2 - pool.begin(); - if (occlude_factor[t] > alpha) - continue; - - bool prune_allowed = true; - if (_filtered_index) { - uint32_t a = iter->id; - uint32_t b = iter2->id; - if (_location_to_labels.size() < b || _location_to_labels.size() < a) - continue; - for (auto &x : _location_to_labels[b]) { - if (std::find(_location_to_labels[a].begin(), - _location_to_labels[a].end(), - x) == _location_to_labels[a].end()) { - prune_allowed = false; +void Index::occlude_list(const uint32_t location, std::vector &pool, const float alpha, + const uint32_t degree, const uint32_t maxc, std::vector &result, + InMemQueryScratch *scratch, + const tsl::robin_set *const delete_set_ptr) +{ + if (pool.size() == 0) + return; + + // Truncate pool at maxc and initialize scratch spaces + assert(std::is_sorted(pool.begin(), pool.end())); + assert(result.size() == 0); + if (pool.size() > maxc) + pool.resize(maxc); + std::vector &occlude_factor = scratch->occlude_factor(); + // occlude_list can be called with the same scratch more than once by + // search_for_point_and_add_link through inter_insert. + occlude_factor.clear(); + // Initialize occlude_factor to pool.size() many 0.0f values for correctness + occlude_factor.insert(occlude_factor.end(), pool.size(), 0.0f); + + float cur_alpha = 1; + while (cur_alpha <= alpha && result.size() < degree) + { + // used for MIPS, where we store a value of eps in cur_alpha to + // denote pruned out entries which we can skip in later rounds. 
+ float eps = cur_alpha + 0.01f; + + for (auto iter = pool.begin(); result.size() < degree && iter != pool.end(); ++iter) + { + if (occlude_factor[iter - pool.begin()] > cur_alpha) + { + continue; + } + // Set the entry to float::max so that is not considered again + occlude_factor[iter - pool.begin()] = std::numeric_limits::max(); + // Add the entry to the result if its not been deleted, and doesn't + // add a self loop + if (delete_set_ptr == nullptr || delete_set_ptr->find(iter->id) == delete_set_ptr->end()) + { + if (iter->id != location) + { + result.push_back(iter->id); + } } - if (!prune_allowed) - break; - } - } - if (!prune_allowed) - continue; - float djk = _data_store->get_distance(iter2->id, iter->id); - if (_dist_metric == diskann::Metric::L2 || - _dist_metric == diskann::Metric::COSINE) { - occlude_factor[t] = - (djk == 0) ? std::numeric_limits::max() - : std::max(occlude_factor[t], iter2->distance / djk); - } else if (_dist_metric == diskann::Metric::INNER_PRODUCT) { - // Improvization for flipping max and min dist for MIPS - float x = -iter2->distance; - float y = -djk; - if (y > cur_alpha * x) { - occlude_factor[t] = std::max(occlude_factor[t], eps); - } + // Update occlude factor for points from iter+1 to pool.end() + for (auto iter2 = iter + 1; iter2 != pool.end(); iter2++) + { + auto t = iter2 - pool.begin(); + if (occlude_factor[t] > alpha) + continue; + + bool prune_allowed = true; + if (_filtered_index) + { + uint32_t a = iter->id; + uint32_t b = iter2->id; + if (_location_to_labels.size() < b || _location_to_labels.size() < a) + continue; + for (auto &x : _location_to_labels[b]) + { + if (std::find(_location_to_labels[a].begin(), _location_to_labels[a].end(), x) == + _location_to_labels[a].end()) + { + prune_allowed = false; + } + if (!prune_allowed) + break; + } + } + if (!prune_allowed) + continue; + + float djk = _data_store->get_distance(iter2->id, iter->id); + if (_dist_metric == diskann::Metric::L2 || _dist_metric == diskann::Metric::COSINE) + { + occlude_factor[t] = (djk == 0) ? 
std::numeric_limits::max() + : std::max(occlude_factor[t], iter2->distance / djk); + } + else if (_dist_metric == diskann::Metric::INNER_PRODUCT) + { + // Improvization for flipping max and min dist for MIPS + float x = -iter2->distance; + float y = -djk; + if (y > cur_alpha * x) + { + occlude_factor[t] = std::max(occlude_factor[t], eps); + } + } + } } - } + cur_alpha *= 1.2f; } - cur_alpha *= 1.2f; - } } template -void Index::prune_neighbors(const uint32_t location, - std::vector &pool, - std::vector &pruned_list, - InMemQueryScratch *scratch) { - prune_neighbors(location, pool, _indexingRange, _indexingMaxC, _indexingAlpha, - pruned_list, scratch); +void Index::prune_neighbors(const uint32_t location, std::vector &pool, + std::vector &pruned_list, InMemQueryScratch *scratch) +{ + prune_neighbors(location, pool, _indexingRange, _indexingMaxC, _indexingAlpha, pruned_list, scratch); } template -void Index::prune_neighbors( - const uint32_t location, std::vector &pool, const uint32_t range, - const uint32_t max_candidate_size, const float alpha, - std::vector &pruned_list, InMemQueryScratch *scratch) { - if (pool.size() == 0) { - // if the pool is empty, behave like a noop - pruned_list.clear(); - return; - } +void Index::prune_neighbors(const uint32_t location, std::vector &pool, const uint32_t range, + const uint32_t max_candidate_size, const float alpha, + std::vector &pruned_list, InMemQueryScratch *scratch) +{ + if (pool.size() == 0) + { + // if the pool is empty, behave like a noop + pruned_list.clear(); + return; + } - // If using _pq_build, over-write the PQ distances with actual distances - // REFACTOR PQ: TODO: How to get rid of this!? - if (_pq_dist) { - for (auto &ngh : pool) - ngh.distance = _data_store->get_distance(ngh.id, location); - } + // If using _pq_build, over-write the PQ distances with actual distances + // REFACTOR PQ: TODO: How to get rid of this!? 
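Stripped of the filter, delete-set, and MIPS special cases, the alpha-occlusion rule that occlude_list applies above can be sketched in isolation as follows. Candidate, the dist callable, and the function signature are stand-ins; the pool is assumed pre-sorted by distance to the query, as in the code above.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

struct Candidate
{
    uint32_t id;
    float distance_to_query; // pool assumed sorted ascending on this
};

// Simplified alpha-occlusion: a candidate is dropped while some already
// selected neighbor is "alpha-times closer" to it than the query is.
// dist(a, b) is any callable returning the distance between two point ids.
template <typename DistFn>
std::vector<uint32_t> occlude(std::vector<Candidate> pool, float alpha, uint32_t degree, DistFn dist)
{
    std::vector<float> occlude_factor(pool.size(), 0.0f);
    std::vector<uint32_t> result;

    float cur_alpha = 1.0f;
    while (cur_alpha <= alpha && result.size() < degree)
    {
        for (std::size_t i = 0; i < pool.size() && result.size() < degree; ++i)
        {
            if (occlude_factor[i] > cur_alpha)
                continue;                                          // occluded at this threshold
            occlude_factor[i] = std::numeric_limits<float>::max(); // never reconsidered
            result.push_back(pool[i].id);

            // Penalize every later candidate by how much closer it is to the
            // newly selected point than to the query.
            for (std::size_t j = i + 1; j < pool.size(); ++j)
            {
                if (occlude_factor[j] > alpha)
                    continue;
                float djk = dist(pool[j].id, pool[i].id);
                occlude_factor[j] = (djk == 0) ? std::numeric_limits<float>::max()
                                               : std::max(occlude_factor[j], pool[j].distance_to_query / djk);
            }
        }
        cur_alpha *= 1.2f; // relax the threshold in rounds, as in the code above
    }
    return result;
}

Raising cur_alpha in rounds lets the prune try the strictest threshold first and only relax it when the target degree has not been reached.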
+ if (_pq_dist) + { + for (auto &ngh : pool) + ngh.distance = _data_store->get_distance(ngh.id, location); + } - // sort the pool based on distance to query and prune it with occlude_list - std::sort(pool.begin(), pool.end()); - pruned_list.clear(); - pruned_list.reserve(range); + // sort the pool based on distance to query and prune it with occlude_list + std::sort(pool.begin(), pool.end()); + pruned_list.clear(); + pruned_list.reserve(range); - occlude_list(location, pool, alpha, range, max_candidate_size, pruned_list, - scratch); - assert(pruned_list.size() <= range); + occlude_list(location, pool, alpha, range, max_candidate_size, pruned_list, scratch); + assert(pruned_list.size() <= range); - if (_saturate_graph && alpha > 1) { - for (const auto &node : pool) { - if (pruned_list.size() >= range) - break; - if ((std::find(pruned_list.begin(), pruned_list.end(), node.id) == - pruned_list.end()) && - node.id != location) - pruned_list.push_back(node.id); + if (_saturate_graph && alpha > 1) + { + for (const auto &node : pool) + { + if (pruned_list.size() >= range) + break; + if ((std::find(pruned_list.begin(), pruned_list.end(), node.id) == pruned_list.end()) && + node.id != location) + pruned_list.push_back(node.id); + } } - } } template -void Index::inter_insert(uint32_t n, - std::vector &pruned_list, - const uint32_t range, - InMemQueryScratch *scratch) { - const auto &src_pool = pruned_list; - - assert(!src_pool.empty()); - - for (auto des : src_pool) { - // des.loc is the loc of the neighbors of n - assert(des < _max_points + _num_frozen_pts); - // des_pool contains the neighbors of the neighbors of n - std::vector copy_of_neighbors; - bool prune_needed = false; - { - LockGuard guard(_locks[des]); - auto &des_pool = _graph_store->get_neighbours(des); - if (std::find(des_pool.begin(), des_pool.end(), n) == des_pool.end()) { - if (des_pool.size() < - (uint64_t)(defaults::GRAPH_SLACK_FACTOR * range)) { - // des_pool.emplace_back(n); - _graph_store->add_neighbour(des, n); - prune_needed = false; - } else { - copy_of_neighbors.reserve(des_pool.size() + 1); - copy_of_neighbors = des_pool; - copy_of_neighbors.push_back(n); - prune_needed = true; - } - } - } // des lock is released by this point - - if (prune_needed) { - tsl::robin_set dummy_visited(0); - std::vector dummy_pool(0); - - size_t reserveSize = - (size_t)(std::ceil(1.05 * defaults::GRAPH_SLACK_FACTOR * range)); - dummy_visited.reserve(reserveSize); - dummy_pool.reserve(reserveSize); - - for (auto cur_nbr : copy_of_neighbors) { - if (dummy_visited.find(cur_nbr) == dummy_visited.end() && - cur_nbr != des) { - float dist = _data_store->get_distance(des, cur_nbr); - dummy_pool.emplace_back(Neighbor(cur_nbr, dist)); - dummy_visited.insert(cur_nbr); - } - } - std::vector new_out_neighbors; - prune_neighbors(des, dummy_pool, new_out_neighbors, scratch); - { - LockGuard guard(_locks[des]); - - _graph_store->set_neighbours(des, new_out_neighbors); - } - } - } -} +void Index::inter_insert(uint32_t n, std::vector &pruned_list, const uint32_t range, + InMemQueryScratch *scratch) +{ + const auto &src_pool = pruned_list; -template -void Index::inter_insert(uint32_t n, - std::vector &pruned_list, - InMemQueryScratch *scratch) { - inter_insert(n, pruned_list, _indexingRange, scratch); + assert(!src_pool.empty()); + + for (auto des : src_pool) + { + // des.loc is the loc of the neighbors of n + assert(des < _max_points + _num_frozen_pts); + // des_pool contains the neighbors of the neighbors of n + std::vector copy_of_neighbors; + bool 
prune_needed = false; + { + LockGuard guard(_locks[des]); + auto &des_pool = _graph_store->get_neighbours(des); + if (std::find(des_pool.begin(), des_pool.end(), n) == des_pool.end()) + { + if (des_pool.size() < (uint64_t)(defaults::GRAPH_SLACK_FACTOR * range)) + { + // des_pool.emplace_back(n); + _graph_store->add_neighbour(des, n); + prune_needed = false; + } + else + { + copy_of_neighbors.reserve(des_pool.size() + 1); + copy_of_neighbors = des_pool; + copy_of_neighbors.push_back(n); + prune_needed = true; + } + } + } // des lock is released by this point + + if (prune_needed) + { + tsl::robin_set dummy_visited(0); + std::vector dummy_pool(0); + + size_t reserveSize = (size_t)(std::ceil(1.05 * defaults::GRAPH_SLACK_FACTOR * range)); + dummy_visited.reserve(reserveSize); + dummy_pool.reserve(reserveSize); + + for (auto cur_nbr : copy_of_neighbors) + { + if (dummy_visited.find(cur_nbr) == dummy_visited.end() && cur_nbr != des) + { + float dist = _data_store->get_distance(des, cur_nbr); + dummy_pool.emplace_back(Neighbor(cur_nbr, dist)); + dummy_visited.insert(cur_nbr); + } + } + std::vector new_out_neighbors; + prune_neighbors(des, dummy_pool, new_out_neighbors, scratch); + { + LockGuard guard(_locks[des]); + + _graph_store->set_neighbours(des, new_out_neighbors); + } + } + } } template -void Index::link() { - uint32_t num_threads = _indexingThreads; - if (num_threads != 0) - omp_set_num_threads(num_threads); - - /* visit_order is a vector that is initialized to the entire graph */ - std::vector visit_order; - std::vector pool, tmp; - tsl::robin_set visited; - visit_order.reserve(_nd + _num_frozen_pts); - for (uint32_t i = 0; i < (uint32_t)_nd; i++) { - visit_order.emplace_back(i); - } - - // If there are any frozen points, add them all. - for (uint32_t frozen = (uint32_t)_max_points; - frozen < _max_points + _num_frozen_pts; frozen++) { - visit_order.emplace_back(frozen); - } - - // if there are frozen points, the first such one is set to be the _start - if (_num_frozen_pts > 0) - _start = (uint32_t)_max_points; - else - _start = calculate_entry_point(); - - diskann::Timer link_timer; +void Index::inter_insert(uint32_t n, std::vector &pruned_list, InMemQueryScratch *scratch) +{ + inter_insert(n, pruned_list, _indexingRange, scratch); +} -#pragma omp parallel for schedule(dynamic, 2048) - for (int64_t node_ctr = 0; node_ctr < (int64_t)(visit_order.size()); - node_ctr++) { - auto node = visit_order[node_ctr]; +template void Index::link() +{ + uint32_t num_threads = _indexingThreads; + if (num_threads != 0) + omp_set_num_threads(num_threads); - // Find and add appropriate graph edges - ScratchStoreManager> manager(_query_scratch); - auto scratch = manager.scratch_space(); - std::vector pruned_list; - if (_filtered_index) { - search_for_point_and_prune(node, _indexingQueueSize, pruned_list, scratch, - true, _filterIndexingQueueSize); - } else { - search_for_point_and_prune(node, _indexingQueueSize, pruned_list, - scratch); + /* visit_order is a vector that is initialized to the entire graph */ + std::vector visit_order; + std::vector pool, tmp; + tsl::robin_set visited; + visit_order.reserve(_nd + _num_frozen_pts); + for (uint32_t i = 0; i < (uint32_t)_nd; i++) + { + visit_order.emplace_back(i); } - assert(pruned_list.size() > 0); + // If there are any frozen points, add them all. 
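The reverse-edge policy of inter_insert above, reduced to plain vectors: each pruned neighbor of n receives a back-edge to n, and that neighbor's list is re-pruned only when the back-edge would push it past a slack factor times the target range. The constant's value and the reprune callable below are placeholders standing in for the defaults and the prune_neighbors call in the patch.

#include <algorithm>
#include <cstdint>
#include <vector>

// Placeholder for defaults::GRAPH_SLACK_FACTOR used above.
constexpr double GRAPH_SLACK_FACTOR = 1.3;

// reprune(list) stands in for the prune_neighbors call above.
template <typename RepruneFn>
void add_back_edge(std::vector<uint32_t> &des_neighbors, uint32_t n, uint32_t range, RepruneFn reprune)
{
    if (std::find(des_neighbors.begin(), des_neighbors.end(), n) != des_neighbors.end())
        return; // back-edge already present

    if (des_neighbors.size() < (uint64_t)(GRAPH_SLACK_FACTOR * range))
    {
        des_neighbors.push_back(n); // cheap path: just append
        return;
    }

    // Over the slack budget: prune the enlarged list back down.
    std::vector<uint32_t> enlarged = des_neighbors;
    enlarged.push_back(n);
    des_neighbors = reprune(enlarged);
}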
+ for (uint32_t frozen = (uint32_t)_max_points; frozen < _max_points + _num_frozen_pts; frozen++) { - LockGuard guard(_locks[node]); - - _graph_store->set_neighbours(node, pruned_list); - assert(_graph_store->get_neighbours((location_t)node).size() <= - _indexingRange); + visit_order.emplace_back(frozen); } - inter_insert(node, pruned_list, scratch); + // if there are frozen points, the first such one is set to be the _start + if (_num_frozen_pts > 0) + _start = (uint32_t)_max_points; + else + _start = calculate_entry_point(); - if (node_ctr % 100000 == 0) { - diskann::cout << "\r" << (100.0 * node_ctr) / (visit_order.size()) - << "% of index build completed." << std::flush; - } - } + diskann::Timer link_timer; - if (_nd > 0) { - diskann::cout << "Starting final cleanup.." << std::flush; - } #pragma omp parallel for schedule(dynamic, 2048) - for (int64_t node_ctr = 0; node_ctr < (int64_t)(visit_order.size()); - node_ctr++) { - auto node = visit_order[node_ctr]; - if (_graph_store->get_neighbours((location_t)node).size() > - _indexingRange) { - ScratchStoreManager> manager(_query_scratch); - auto scratch = manager.scratch_space(); - - tsl::robin_set dummy_visited(0); - std::vector dummy_pool(0); - std::vector new_out_neighbors; - - for (auto cur_nbr : _graph_store->get_neighbours((location_t)node)) { - if (dummy_visited.find(cur_nbr) == dummy_visited.end() && - cur_nbr != node) { - float dist = _data_store->get_distance(node, cur_nbr); - dummy_pool.emplace_back(Neighbor(cur_nbr, dist)); - dummy_visited.insert(cur_nbr); - } - } - prune_neighbors(node, dummy_pool, new_out_neighbors, scratch); - - _graph_store->clear_neighbours((location_t)node); - _graph_store->set_neighbours((location_t)node, new_out_neighbors); - } - } - if (_nd > 0) { - diskann::cout << "done. 
Link time: " - << ((double)link_timer.elapsed() / (double)1000000) << "s" - << std::endl; - } -} - -template -void Index::prune_all_neighbors( - const uint32_t max_degree, const uint32_t max_occlusion_size, - const float alpha) { - const uint32_t range = max_degree; - const uint32_t maxc = max_occlusion_size; - - _filtered_index = true; - - diskann::Timer timer; -#pragma omp parallel for - for (int64_t node = 0; node < (int64_t)(_max_points + _num_frozen_pts); - node++) { - if ((size_t)node < _nd || (size_t)node >= _max_points) { - if (_graph_store->get_neighbours((location_t)node).size() > range) { - tsl::robin_set dummy_visited(0); - std::vector dummy_pool(0); - std::vector new_out_neighbors; + for (int64_t node_ctr = 0; node_ctr < (int64_t)(visit_order.size()); node_ctr++) + { + auto node = visit_order[node_ctr]; + // Find and add appropriate graph edges ScratchStoreManager> manager(_query_scratch); auto scratch = manager.scratch_space(); + std::vector pruned_list; + if (_filtered_index) + { + search_for_point_and_prune(node, _indexingQueueSize, pruned_list, scratch, true, _filterIndexingQueueSize); + } + else + { + search_for_point_and_prune(node, _indexingQueueSize, pruned_list, scratch); + } + assert(pruned_list.size() > 0); - for (auto cur_nbr : _graph_store->get_neighbours((location_t)node)) { - if (dummy_visited.find(cur_nbr) == dummy_visited.end() && - cur_nbr != node) { - float dist = _data_store->get_distance((location_t)node, - (location_t)cur_nbr); - dummy_pool.emplace_back(Neighbor(cur_nbr, dist)); - dummy_visited.insert(cur_nbr); - } - } - - prune_neighbors((uint32_t)node, dummy_pool, range, maxc, alpha, - new_out_neighbors, scratch); - _graph_store->clear_neighbours((location_t)node); - _graph_store->set_neighbours((location_t)node, new_out_neighbors); - } - } - } - - diskann::cout << "Prune time : " << timer.elapsed() / 1000 << "ms" - << std::endl; - size_t max = 0, min = 1 << 30, total = 0, cnt = 0; - for (size_t i = 0; i < _max_points + _num_frozen_pts; i++) { - if (i < _nd || i >= _max_points) { - const std::vector &pool = - _graph_store->get_neighbours((location_t)i); - max = (std::max)(max, pool.size()); - min = (std::min)(min, pool.size()); - total += pool.size(); - if (pool.size() < 2) - cnt++; - } - } - if (min > max) - min = max; - if (_nd > 0) { - diskann::cout << "Index built with degree: max:" << max - << " avg:" << (float)total / (float)(_nd + _num_frozen_pts) - << " min:" << min << " count(deg<2):" << cnt << std::endl; - } + { + LockGuard guard(_locks[node]); + + _graph_store->set_neighbours(node, pruned_list); + assert(_graph_store->get_neighbours((location_t)node).size() <= _indexingRange); + } + + inter_insert(node, pruned_list, scratch); + + if (node_ctr % 100000 == 0) + { + diskann::cout << "\r" << (100.0 * node_ctr) / (visit_order.size()) << "% of index build completed." + << std::flush; + } + } + + if (_nd > 0) + { + diskann::cout << "Starting final cleanup.." 
<< std::flush; + } +#pragma omp parallel for schedule(dynamic, 2048) + for (int64_t node_ctr = 0; node_ctr < (int64_t)(visit_order.size()); node_ctr++) + { + auto node = visit_order[node_ctr]; + if (_graph_store->get_neighbours((location_t)node).size() > _indexingRange) + { + ScratchStoreManager> manager(_query_scratch); + auto scratch = manager.scratch_space(); + + tsl::robin_set dummy_visited(0); + std::vector dummy_pool(0); + std::vector new_out_neighbors; + + for (auto cur_nbr : _graph_store->get_neighbours((location_t)node)) + { + if (dummy_visited.find(cur_nbr) == dummy_visited.end() && cur_nbr != node) + { + float dist = _data_store->get_distance(node, cur_nbr); + dummy_pool.emplace_back(Neighbor(cur_nbr, dist)); + dummy_visited.insert(cur_nbr); + } + } + prune_neighbors(node, dummy_pool, new_out_neighbors, scratch); + + _graph_store->clear_neighbours((location_t)node); + _graph_store->set_neighbours((location_t)node, new_out_neighbors); + } + } + if (_nd > 0) + { + diskann::cout << "done. Link time: " << ((double)link_timer.elapsed() / (double)1000000) << "s" << std::endl; + } +} + +template +void Index::prune_all_neighbors(const uint32_t max_degree, const uint32_t max_occlusion_size, + const float alpha) +{ + const uint32_t range = max_degree; + const uint32_t maxc = max_occlusion_size; + + _filtered_index = true; + + diskann::Timer timer; +#pragma omp parallel for + for (int64_t node = 0; node < (int64_t)(_max_points + _num_frozen_pts); node++) + { + if ((size_t)node < _nd || (size_t)node >= _max_points) + { + if (_graph_store->get_neighbours((location_t)node).size() > range) + { + tsl::robin_set dummy_visited(0); + std::vector dummy_pool(0); + std::vector new_out_neighbors; + + ScratchStoreManager> manager(_query_scratch); + auto scratch = manager.scratch_space(); + + for (auto cur_nbr : _graph_store->get_neighbours((location_t)node)) + { + if (dummy_visited.find(cur_nbr) == dummy_visited.end() && cur_nbr != node) + { + float dist = _data_store->get_distance((location_t)node, (location_t)cur_nbr); + dummy_pool.emplace_back(Neighbor(cur_nbr, dist)); + dummy_visited.insert(cur_nbr); + } + } + + prune_neighbors((uint32_t)node, dummy_pool, range, maxc, alpha, new_out_neighbors, scratch); + _graph_store->clear_neighbours((location_t)node); + _graph_store->set_neighbours((location_t)node, new_out_neighbors); + } + } + } + + diskann::cout << "Prune time : " << timer.elapsed() / 1000 << "ms" << std::endl; + size_t max = 0, min = 1 << 30, total = 0, cnt = 0; + for (size_t i = 0; i < _max_points + _num_frozen_pts; i++) + { + if (i < _nd || i >= _max_points) + { + const std::vector &pool = _graph_store->get_neighbours((location_t)i); + max = (std::max)(max, pool.size()); + min = (std::min)(min, pool.size()); + total += pool.size(); + if (pool.size() < 2) + cnt++; + } + } + if (min > max) + min = max; + if (_nd > 0) + { + diskann::cout << "Index built with degree: max:" << max + << " avg:" << (float)total / (float)(_nd + _num_frozen_pts) << " min:" << min + << " count(deg<2):" << cnt << std::endl; + } } // REFACTOR template -void Index::set_start_points(const T *data, - size_t data_count) { - std::unique_lock ul(_update_lock); - std::unique_lock tl(_tag_lock); - if (_nd > 0) - throw ANNException("Can not set starting point for a non-empty index", -1, - __FUNCSIG__, __FILE__, __LINE__); - - if (data_count != _num_frozen_pts * _dim) - throw ANNException("Invalid number of points", -1, __FUNCSIG__, __FILE__, - __LINE__); - - // memcpy(_data + _aligned_dim * _max_points, data, _aligned_dim * - 
// sizeof(T) * _num_frozen_pts); - for (location_t i = 0; i < _num_frozen_pts; i++) { - _data_store->set_vector((location_t)(i + _max_points), data + i * _dim); - } - _has_built = true; - diskann::cout << "Index start points set: #" << _num_frozen_pts << std::endl; +void Index::set_start_points(const T *data, size_t data_count) +{ + std::unique_lock ul(_update_lock); + std::unique_lock tl(_tag_lock); + if (_nd > 0) + throw ANNException("Can not set starting point for a non-empty index", -1, __FUNCSIG__, __FILE__, __LINE__); + + if (data_count != _num_frozen_pts * _dim) + throw ANNException("Invalid number of points", -1, __FUNCSIG__, __FILE__, __LINE__); + + // memcpy(_data + _aligned_dim * _max_points, data, _aligned_dim * + // sizeof(T) * _num_frozen_pts); + for (location_t i = 0; i < _num_frozen_pts; i++) + { + _data_store->set_vector((location_t)(i + _max_points), data + i * _dim); + } + _has_built = true; + diskann::cout << "Index start points set: #" << _num_frozen_pts << std::endl; } template -void Index::_set_start_points_at_random(DataType radius, - uint32_t random_seed) { - try { - T radius_to_use = std::any_cast(radius); - this->set_start_points_at_random(radius_to_use, random_seed); - } catch (const std::bad_any_cast &e) { - throw ANNException( - "Error: bad any cast while performing _set_start_points_at_random() " + - std::string(e.what()), - -1); - } catch (const std::exception &e) { - throw ANNException("Error: " + std::string(e.what()), -1); - } +void Index::_set_start_points_at_random(DataType radius, uint32_t random_seed) +{ + try + { + T radius_to_use = std::any_cast(radius); + this->set_start_points_at_random(radius_to_use, random_seed); + } + catch (const std::bad_any_cast &e) + { + throw ANNException( + "Error: bad any cast while performing _set_start_points_at_random() " + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error: " + std::string(e.what()), -1); + } } template -void Index::set_start_points_at_random(T radius, - uint32_t random_seed) { - std::mt19937 gen{random_seed}; - std::normal_distribution<> d{0.0, 1.0}; +void Index::set_start_points_at_random(T radius, uint32_t random_seed) +{ + std::mt19937 gen{random_seed}; + std::normal_distribution<> d{0.0, 1.0}; - std::vector points_data; - points_data.reserve(_dim * _num_frozen_pts); - std::vector real_vec(_dim); + std::vector points_data; + points_data.reserve(_dim * _num_frozen_pts); + std::vector real_vec(_dim); - for (size_t frozen_point = 0; frozen_point < _num_frozen_pts; - frozen_point++) { - double norm_sq = 0.0; - for (size_t i = 0; i < _dim; ++i) { - auto r = d(gen); - real_vec[i] = r; - norm_sq += r * r; - } + for (size_t frozen_point = 0; frozen_point < _num_frozen_pts; frozen_point++) + { + double norm_sq = 0.0; + for (size_t i = 0; i < _dim; ++i) + { + auto r = d(gen); + real_vec[i] = r; + norm_sq += r * r; + } - const double norm = std::sqrt(norm_sq); - for (auto iter : real_vec) - points_data.push_back(static_cast(iter * radius / norm)); - } + const double norm = std::sqrt(norm_sq); + for (auto iter : real_vec) + points_data.push_back(static_cast(iter * radius / norm)); + } - set_start_points(points_data.data(), points_data.size()); + set_start_points(points_data.data(), points_data.size()); } template -void Index::build_with_data_populated( - const std::vector &tags) { - diskann::cout << "Starting index build with " << _nd << " points... 
" - << std::endl; +void Index::build_with_data_populated(const std::vector &tags) +{ + diskann::cout << "Starting index build with " << _nd << " points... " << std::endl; - if (_nd < 1) - throw ANNException("Error: Trying to build an index with 0 points", -1, - __FUNCSIG__, __FILE__, __LINE__); + if (_nd < 1) + throw ANNException("Error: Trying to build an index with 0 points", -1, __FUNCSIG__, __FILE__, __LINE__); - if (_enable_tags && tags.size() != _nd) { - std::stringstream stream; - stream << "ERROR: Driver requests loading " << _nd << " points from file," - << "but tags vector is of size " << tags.size() << "." << std::endl; - diskann::cerr << stream.str() << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - if (_enable_tags) { - for (size_t i = 0; i < tags.size(); ++i) { - _tag_to_location[tags[i]] = (uint32_t)i; - _location_to_tag.set(static_cast(i), tags[i]); - } - } - - uint32_t index_R = _indexingRange; - uint32_t num_threads_index = _indexingThreads; - uint32_t index_L = _indexingQueueSize; - uint32_t maxc = _indexingMaxC; - - if (_query_scratch.size() == 0) { - initialize_query_scratch(5 + num_threads_index, index_L, index_L, index_R, - maxc, _data_store->get_aligned_dim()); - } - - generate_frozen_point(); - link(); - - size_t max = 0, min = SIZE_MAX, total = 0, cnt = 0; - for (size_t i = 0; i < _nd; i++) { - auto &pool = _graph_store->get_neighbours((location_t)i); - max = std::max(max, pool.size()); - min = std::min(min, pool.size()); - total += pool.size(); - if (pool.size() < 2) - cnt++; - } - diskann::cout << "Index built with degree: max:" << max - << " avg:" << (float)total / (float)(_nd + _num_frozen_pts) - << " min:" << min << " count(deg<2):" << cnt << std::endl; - - _has_built = true; + if (_enable_tags && tags.size() != _nd) + { + std::stringstream stream; + stream << "ERROR: Driver requests loading " << _nd << " points from file," + << "but tags vector is of size " << tags.size() << "." << std::endl; + diskann::cerr << stream.str() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + if (_enable_tags) + { + for (size_t i = 0; i < tags.size(); ++i) + { + _tag_to_location[tags[i]] = (uint32_t)i; + _location_to_tag.set(static_cast(i), tags[i]); + } + } + + uint32_t index_R = _indexingRange; + uint32_t num_threads_index = _indexingThreads; + uint32_t index_L = _indexingQueueSize; + uint32_t maxc = _indexingMaxC; + + if (_query_scratch.size() == 0) + { + initialize_query_scratch(5 + num_threads_index, index_L, index_L, index_R, maxc, + _data_store->get_aligned_dim()); + } + + generate_frozen_point(); + link(); + + size_t max = 0, min = SIZE_MAX, total = 0, cnt = 0; + for (size_t i = 0; i < _nd; i++) + { + auto &pool = _graph_store->get_neighbours((location_t)i); + max = std::max(max, pool.size()); + min = std::min(min, pool.size()); + total += pool.size(); + if (pool.size() < 2) + cnt++; + } + diskann::cout << "Index built with degree: max:" << max << " avg:" << (float)total / (float)(_nd + _num_frozen_pts) + << " min:" << min << " count(deg<2):" << cnt << std::endl; + + _has_built = true; } template -void Index::_build(const DataType &data, - const size_t num_points_to_load, - TagVector &tags) { - try { - this->build(std::any_cast(data), num_points_to_load, - tags.get>()); - } catch (const std::bad_any_cast &e) { - throw ANNException("Error: bad any cast in while building index. 
" + - std::string(e.what()), - -1); - } catch (const std::exception &e) { - throw ANNException("Error" + std::string(e.what()), -1); - } +void Index::_build(const DataType &data, const size_t num_points_to_load, TagVector &tags) +{ + try + { + this->build(std::any_cast(data), num_points_to_load, tags.get>()); + } + catch (const std::bad_any_cast &e) + { + throw ANNException("Error: bad any cast in while building index. " + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error" + std::string(e.what()), -1); + } } template -void Index::build(const T *data, - const size_t num_points_to_load, - const std::vector &tags) { - if (num_points_to_load == 0) { - throw ANNException("Do not call build with 0 points", -1, __FUNCSIG__, - __FILE__, __LINE__); - } - if (_pq_dist) { - throw ANNException( - "ERROR: DO not use this build interface with PQ distance", -1, - __FUNCSIG__, __FILE__, __LINE__); - } - - std::unique_lock ul(_update_lock); - - { - std::unique_lock tl(_tag_lock); - _nd = num_points_to_load; +void Index::build(const T *data, const size_t num_points_to_load, const std::vector &tags) +{ + if (num_points_to_load == 0) + { + throw ANNException("Do not call build with 0 points", -1, __FUNCSIG__, __FILE__, __LINE__); + } + if (_pq_dist) + { + throw ANNException("ERROR: DO not use this build interface with PQ distance", -1, __FUNCSIG__, __FILE__, + __LINE__); + } + + std::unique_lock ul(_update_lock); + + { + std::unique_lock tl(_tag_lock); + _nd = num_points_to_load; - _data_store->populate_data(data, (location_t)num_points_to_load); - } + _data_store->populate_data(data, (location_t)num_points_to_load); + } - build_with_data_populated(tags); + build_with_data_populated(tags); } template -void Index::build(const char *filename, - const size_t num_points_to_load, - const std::vector &tags) { - // idealy this should call build_filtered_index based on params passed +void Index::build(const char *filename, const size_t num_points_to_load, const std::vector &tags) +{ + // idealy this should call build_filtered_index based on params passed - std::unique_lock ul(_update_lock); + std::unique_lock ul(_update_lock); - // error checks - if (num_points_to_load == 0) - throw ANNException("Do not call build with 0 points", -1, __FUNCSIG__, - __FILE__, __LINE__); + // error checks + if (num_points_to_load == 0) + throw ANNException("Do not call build with 0 points", -1, __FUNCSIG__, __FILE__, __LINE__); - if (!file_exists(filename)) { - std::stringstream stream; - stream << "ERROR: Data file " << filename << " does not exist." - << std::endl; - diskann::cerr << stream.str() << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - - size_t file_num_points, file_dim; - if (filename == nullptr) { - throw diskann::ANNException("Can not build with an empty file", -1, - __FUNCSIG__, __FILE__, __LINE__); - } - - diskann::get_bin_metadata(filename, file_num_points, file_dim); - if (file_num_points > _max_points) { - std::stringstream stream; - stream << "ERROR: Driver requests loading " << num_points_to_load - << " points and file has " << file_num_points << " points, but " - << "index can support only " << _max_points - << " points as specified in constructor." << std::endl; + if (!file_exists(filename)) + { + std::stringstream stream; + stream << "ERROR: Data file " << filename << " does not exist." 
<< std::endl; + diskann::cerr << stream.str() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } + size_t file_num_points, file_dim; + if (filename == nullptr) + { + throw diskann::ANNException("Can not build with an empty file", -1, __FUNCSIG__, __FILE__, __LINE__); + } - if (num_points_to_load > file_num_points) { - std::stringstream stream; - stream << "ERROR: Driver requests loading " << num_points_to_load - << " points and file has only " << file_num_points << " points." - << std::endl; + diskann::get_bin_metadata(filename, file_num_points, file_dim); + if (file_num_points > _max_points) + { + std::stringstream stream; + stream << "ERROR: Driver requests loading " << num_points_to_load << " points and file has " << file_num_points + << " points, but " + << "index can support only " << _max_points << " points as specified in constructor." << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } - if (file_dim != _dim) { - std::stringstream stream; - stream << "ERROR: Driver requests loading " << _dim << " dimension," - << "but file has " << file_dim << " dimension." << std::endl; - diskann::cerr << stream.str() << std::endl; + if (num_points_to_load > file_num_points) + { + std::stringstream stream; + stream << "ERROR: Driver requests loading " << num_points_to_load << " points and file has only " + << file_num_points << " points." << std::endl; + + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + + if (file_dim != _dim) + { + std::stringstream stream; + stream << "ERROR: Driver requests loading " << _dim << " dimension," + << "but file has " << file_dim << " dimension." << std::endl; + diskann::cerr << stream.str() << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } - // REFACTOR PQ TODO: We can remove this if and add a check in the - // InMemDataStore to not populate_data if it has been called once. - if (_pq_dist) { + // REFACTOR PQ TODO: We can remove this if and add a check in the + // InMemDataStore to not populate_data if it has been called once. + if (_pq_dist) + { #ifdef EXEC_ENV_OLS - std::stringstream ss; - ss << "PQ Build is not supported in DLVS environment (i.e. if EXEC_ENV_OLS " - "is defined)" - << std::endl; - diskann::cerr << ss.str() << std::endl; - throw ANNException(ss.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + std::stringstream ss; + ss << "PQ Build is not supported in DLVS environment (i.e. if EXEC_ENV_OLS " + "is defined)" + << std::endl; + diskann::cerr << ss.str() << std::endl; + throw ANNException(ss.str(), -1, __FUNCSIG__, __FILE__, __LINE__); #else - // REFACTOR TODO: Both in the previous code and in the current PQDataStore, - // we are writing the PQ files in the same path as the input file. Now we - // may not have write permissions to that folder, but we will always have - // write permissions to the output folder. So we should write the PQ files - // there. The problem is that the Index class gets the output folder prefix - // only at the time of save(), by which time we are too late. So leaving it - // as-is for now. 
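The point-count and dimension checks above rely on get_bin_metadata reading the header of the data file. Below is a minimal stand-in, assuming the usual DiskANN .bin layout of a 32-bit point count followed by a 32-bit dimension and then the raw vectors; the function name and error handling are ad hoc.

#include <cstddef>
#include <cstdint>
#include <fstream>
#include <stdexcept>
#include <string>

// Reads only the two-integer header of a .bin data file.
inline void read_bin_metadata(const std::string &filename, std::size_t &num_points, std::size_t &dim)
{
    std::ifstream reader(filename, std::ios::binary);
    if (!reader)
        throw std::runtime_error("could not open " + filename);

    int32_t npts_i32 = 0, dim_i32 = 0;
    reader.read(reinterpret_cast<char *>(&npts_i32), sizeof(npts_i32));
    reader.read(reinterpret_cast<char *>(&dim_i32), sizeof(dim_i32));
    if (!reader)
        throw std::runtime_error("could not read header of " + filename);

    num_points = static_cast<std::size_t>(npts_i32);
    dim = static_cast<std::size_t>(dim_i32);
}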
- _pq_data_store->populate_data(filename, 0U); + // REFACTOR TODO: Both in the previous code and in the current PQDataStore, + // we are writing the PQ files in the same path as the input file. Now we + // may not have write permissions to that folder, but we will always have + // write permissions to the output folder. So we should write the PQ files + // there. The problem is that the Index class gets the output folder prefix + // only at the time of save(), by which time we are too late. So leaving it + // as-is for now. + _pq_data_store->populate_data(filename, 0U); #endif - } + } - _data_store->populate_data(filename, 0U); - diskann::cout << "Using only first " << num_points_to_load << " from file.. " - << std::endl; + _data_store->populate_data(filename, 0U); + diskann::cout << "Using only first " << num_points_to_load << " from file.. " << std::endl; - { - std::unique_lock tl(_tag_lock); - _nd = num_points_to_load; - } - build_with_data_populated(tags); + { + std::unique_lock tl(_tag_lock); + _nd = num_points_to_load; + } + build_with_data_populated(tags); } template -void Index::build(const char *filename, - const size_t num_points_to_load, - const char *tag_filename) { - std::vector tags; +void Index::build(const char *filename, const size_t num_points_to_load, const char *tag_filename) +{ + std::vector tags; - if (_enable_tags) { - std::unique_lock tl(_tag_lock); - if (tag_filename == nullptr) { - throw ANNException("Tag filename is null, while _enable_tags is set", -1, - __FUNCSIG__, __FILE__, __LINE__); - } else { - if (file_exists(tag_filename)) { - diskann::cout << "Loading tags from " << tag_filename - << " for vamana index build" << std::endl; - TagT *tag_data = nullptr; - size_t npts, ndim; - diskann::load_bin(tag_filename, tag_data, npts, ndim); - if (npts < num_points_to_load) { - std::stringstream sstream; - sstream << "Loaded " << npts - << " tags, insufficient to populate tags for " - << num_points_to_load << " points to load"; - throw diskann::ANNException(sstream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - for (size_t i = 0; i < num_points_to_load; i++) { - tags.push_back(tag_data[i]); + if (_enable_tags) + { + std::unique_lock tl(_tag_lock); + if (tag_filename == nullptr) + { + throw ANNException("Tag filename is null, while _enable_tags is set", -1, __FUNCSIG__, __FILE__, __LINE__); + } + else + { + if (file_exists(tag_filename)) + { + diskann::cout << "Loading tags from " << tag_filename << " for vamana index build" << std::endl; + TagT *tag_data = nullptr; + size_t npts, ndim; + diskann::load_bin(tag_filename, tag_data, npts, ndim); + if (npts < num_points_to_load) + { + std::stringstream sstream; + sstream << "Loaded " << npts << " tags, insufficient to populate tags for " << num_points_to_load + << " points to load"; + throw diskann::ANNException(sstream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + for (size_t i = 0; i < num_points_to_load; i++) + { + tags.push_back(tag_data[i]); + } + delete[] tag_data; + } + else + { + throw diskann::ANNException(std::string("Tag file") + tag_filename + " does not exist", -1, __FUNCSIG__, + __FILE__, __LINE__); + } } - delete[] tag_data; - } else { - throw diskann::ANNException(std::string("Tag file") + tag_filename + - " does not exist", - -1, __FUNCSIG__, __FILE__, __LINE__); - } } - } - build(filename, num_points_to_load, tags); + build(filename, num_points_to_load, tags); } template -void Index::build(const std::string &data_file, - const size_t num_points_to_load, - IndexFilterParams &filter_params) { - 
size_t points_to_load = - num_points_to_load == 0 ? _max_points : num_points_to_load; - - auto s = std::chrono::high_resolution_clock::now(); - if (filter_params.label_file == "") { - this->build(data_file.c_str(), points_to_load); - } else { - // TODO: this should ideally happen in save() - std::string labels_file_to_use = - filter_params.save_path_prefix + "_label_formatted.txt"; - std::string mem_labels_int_map_file = - filter_params.save_path_prefix + "_labels_map.txt"; - convert_labels_string_to_int(filter_params.label_file, labels_file_to_use, - mem_labels_int_map_file, - filter_params.universal_label); - if (filter_params.universal_label != "") { - LabelT unv_label_as_num = 0; - this->set_universal_label(unv_label_as_num); - } - this->build_filtered_index(data_file.c_str(), labels_file_to_use, - points_to_load); - } - std::chrono::duration diff = - std::chrono::high_resolution_clock::now() - s; - std::cout << "Indexing time: " << diff.count() << "\n"; +void Index::build(const std::string &data_file, const size_t num_points_to_load, + IndexFilterParams &filter_params) +{ + size_t points_to_load = num_points_to_load == 0 ? _max_points : num_points_to_load; + + auto s = std::chrono::high_resolution_clock::now(); + if (filter_params.label_file == "") + { + this->build(data_file.c_str(), points_to_load); + } + else + { + // TODO: this should ideally happen in save() + std::string labels_file_to_use = filter_params.save_path_prefix + "_label_formatted.txt"; + std::string mem_labels_int_map_file = filter_params.save_path_prefix + "_labels_map.txt"; + convert_labels_string_to_int(filter_params.label_file, labels_file_to_use, mem_labels_int_map_file, + filter_params.universal_label); + if (filter_params.universal_label != "") + { + LabelT unv_label_as_num = 0; + this->set_universal_label(unv_label_as_num); + } + this->build_filtered_index(data_file.c_str(), labels_file_to_use, points_to_load); + } + std::chrono::duration diff = std::chrono::high_resolution_clock::now() - s; + std::cout << "Indexing time: " << diff.count() << "\n"; } template -std::unordered_map -Index::load_label_map(const std::string &labels_map_file) { - std::unordered_map string_to_int_mp; - std::ifstream map_reader(labels_map_file); - std::string line, token; - LabelT token_as_num; - std::string label_str; - while (std::getline(map_reader, line)) { - std::istringstream iss(line); - getline(iss, token, '\t'); - label_str = token; - getline(iss, token, '\t'); - token_as_num = (LabelT)std::stoul(token); - string_to_int_mp[label_str] = token_as_num; - } - return string_to_int_mp; +std::unordered_map Index::load_label_map(const std::string &labels_map_file) +{ + std::unordered_map string_to_int_mp; + std::ifstream map_reader(labels_map_file); + std::string line, token; + LabelT token_as_num; + std::string label_str; + while (std::getline(map_reader, line)) + { + std::istringstream iss(line); + getline(iss, token, '\t'); + label_str = token; + getline(iss, token, '\t'); + token_as_num = (LabelT)std::stoul(token); + string_to_int_mp[label_str] = token_as_num; + } + return string_to_int_mp; } template -LabelT -Index::get_converted_label(const std::string &raw_label) { - if (_label_map.find(raw_label) != _label_map.end()) { - return _label_map[raw_label]; - } - if (_use_universal_label) { - return _universal_label; - } - std::stringstream stream; - stream << "Unable to find label in the Label Map"; - diskann::cerr << stream.str() << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); 
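The labels map consumed by load_label_map below stores one raw-label/numeric-id pair per line, separated by a tab, and get_converted_label falls back to the universal label for unknown strings when one is configured. A standalone sketch of that lookup path; the names and the free-function form are ad hoc.

#include <cstdint>
#include <istream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>

using LabelT = uint32_t;

// Each line of the labels map file is "<raw label>\t<numeric id>".
std::unordered_map<std::string, LabelT> parse_label_map(std::istream &map_reader)
{
    std::unordered_map<std::string, LabelT> string_to_int;
    std::string line, token;
    while (std::getline(map_reader, line))
    {
        std::istringstream iss(line);
        std::getline(iss, token, '\t'); // raw label string
        std::string label_str = token;
        std::getline(iss, token, '\t'); // numeric id assigned at build time
        string_to_int[label_str] = static_cast<LabelT>(std::stoul(token));
    }
    return string_to_int;
}

// Unknown labels resolve to the universal label when one is configured.
LabelT convert_label(const std::unordered_map<std::string, LabelT> &label_map, const std::string &raw_label,
                     bool use_universal_label, LabelT universal_label)
{
    auto it = label_map.find(raw_label);
    if (it != label_map.end())
        return it->second;
    if (use_universal_label)
        return universal_label;
    throw std::runtime_error("Unable to find label in the label map");
}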
+LabelT Index::get_converted_label(const std::string &raw_label) +{ + if (_label_map.find(raw_label) != _label_map.end()) + { + return _label_map[raw_label]; + } + if (_use_universal_label) + { + return _universal_label; + } + std::stringstream stream; + stream << "Unable to find label in the Label Map"; + diskann::cerr << stream.str() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } template -void Index::parse_label_file(const std::string &label_file, - size_t &num_points) { - // Format of Label txt file: filters with comma separators - - std::ifstream infile(label_file); - if (infile.fail()) { - throw diskann::ANNException( - std::string("Failed to open file ") + label_file, -1); - } - - std::string line, token; - uint32_t line_cnt = 0; - - while (std::getline(infile, line)) { - line_cnt++; - } - _location_to_labels.resize(line_cnt, std::vector()); - - infile.clear(); - infile.seekg(0, std::ios::beg); - line_cnt = 0; - - while (std::getline(infile, line)) { - std::istringstream iss(line); - std::vector lbls(0); - getline(iss, token, '\t'); - std::istringstream new_iss(token); - while (getline(new_iss, token, ',')) { - token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - LabelT token_as_num = (LabelT)std::stoul(token); - lbls.push_back(token_as_num); - _labels.insert(token_as_num); - } - - std::sort(lbls.begin(), lbls.end()); - _location_to_labels[line_cnt] = lbls; - line_cnt++; - } - num_points = (size_t)line_cnt; - diskann::cout << "Identified " << _labels.size() << " distinct label(s)" - << std::endl; +void Index::parse_label_file(const std::string &label_file, size_t &num_points) +{ + // Format of Label txt file: filters with comma separators + + std::ifstream infile(label_file); + if (infile.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1); + } + + std::string line, token; + uint32_t line_cnt = 0; + + while (std::getline(infile, line)) + { + line_cnt++; + } + _location_to_labels.resize(line_cnt, std::vector()); + + infile.clear(); + infile.seekg(0, std::ios::beg); + line_cnt = 0; + + while (std::getline(infile, line)) + { + std::istringstream iss(line); + std::vector lbls(0); + getline(iss, token, '\t'); + std::istringstream new_iss(token); + while (getline(new_iss, token, ',')) + { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + LabelT token_as_num = (LabelT)std::stoul(token); + lbls.push_back(token_as_num); + _labels.insert(token_as_num); + } + + std::sort(lbls.begin(), lbls.end()); + _location_to_labels[line_cnt] = lbls; + line_cnt++; + } + num_points = (size_t)line_cnt; + diskann::cout << "Identified " << _labels.size() << " distinct label(s)" << std::endl; } template -void Index::_set_universal_label( - const LabelType universal_label) { - this->set_universal_label(std::any_cast(universal_label)); +void Index::_set_universal_label(const LabelType universal_label) +{ + this->set_universal_label(std::any_cast(universal_label)); } template -void Index::set_universal_label(const LabelT &label) { - _use_universal_label = true; - _universal_label = label; +void Index::set_universal_label(const LabelT &label) +{ + _use_universal_label = true; + _universal_label = label; } template -void Index::build_filtered_index( - const char *filename, const std::string &label_file, - const size_t 
num_points_to_load, const std::vector &tags) { - _filtered_index = true; - _label_to_start_id.clear(); - size_t num_points_labels = 0; - - parse_label_file(label_file, - num_points_labels); // determines medoid for each label and - // identifies the points to label mapping - - std::unordered_map> label_to_points; - - for (uint32_t point_id = 0; point_id < num_points_to_load; point_id++) { - for (auto label : _location_to_labels[point_id]) { - if (label != _universal_label) { - label_to_points[label].emplace_back(point_id); - } else { - for (typename tsl::robin_set::size_type lbl = 0; - lbl < _labels.size(); lbl++) { - auto itr = _labels.begin(); - std::advance(itr, lbl); - auto &x = *itr; - label_to_points[x].emplace_back(point_id); - } - } - } - } - - uint32_t num_cands = 25; - for (auto itr = _labels.begin(); itr != _labels.end(); itr++) { - uint32_t best_medoid_count = std::numeric_limits::max(); - auto &curr_label = *itr; - uint32_t best_medoid; - auto labeled_points = label_to_points[curr_label]; - for (uint32_t cnd = 0; cnd < num_cands; cnd++) { - uint32_t cur_cnd = labeled_points[rand() % labeled_points.size()]; - uint32_t cur_cnt = std::numeric_limits::max(); - if (_medoid_counts.find(cur_cnd) == _medoid_counts.end()) { - _medoid_counts[cur_cnd] = 0; - cur_cnt = 0; - } else { - cur_cnt = _medoid_counts[cur_cnd]; - } - if (cur_cnt < best_medoid_count) { - best_medoid_count = cur_cnt; - best_medoid = cur_cnd; - } - } - _label_to_start_id[curr_label] = best_medoid; - _medoid_counts[best_medoid]++; - } - - this->build(filename, num_points_to_load, tags); +void Index::build_filtered_index(const char *filename, const std::string &label_file, + const size_t num_points_to_load, const std::vector &tags) +{ + _filtered_index = true; + _label_to_start_id.clear(); + size_t num_points_labels = 0; + + parse_label_file(label_file, + num_points_labels); // determines medoid for each label and + // identifies the points to label mapping + + std::unordered_map> label_to_points; + + for (uint32_t point_id = 0; point_id < num_points_to_load; point_id++) + { + for (auto label : _location_to_labels[point_id]) + { + if (label != _universal_label) + { + label_to_points[label].emplace_back(point_id); + } + else + { + for (typename tsl::robin_set::size_type lbl = 0; lbl < _labels.size(); lbl++) + { + auto itr = _labels.begin(); + std::advance(itr, lbl); + auto &x = *itr; + label_to_points[x].emplace_back(point_id); + } + } + } + } + + uint32_t num_cands = 25; + for (auto itr = _labels.begin(); itr != _labels.end(); itr++) + { + uint32_t best_medoid_count = std::numeric_limits::max(); + auto &curr_label = *itr; + uint32_t best_medoid; + auto labeled_points = label_to_points[curr_label]; + for (uint32_t cnd = 0; cnd < num_cands; cnd++) + { + uint32_t cur_cnd = labeled_points[rand() % labeled_points.size()]; + uint32_t cur_cnt = std::numeric_limits::max(); + if (_medoid_counts.find(cur_cnd) == _medoid_counts.end()) + { + _medoid_counts[cur_cnd] = 0; + cur_cnt = 0; + } + else + { + cur_cnt = _medoid_counts[cur_cnd]; + } + if (cur_cnt < best_medoid_count) + { + best_medoid_count = cur_cnt; + best_medoid = cur_cnd; + } + } + _label_to_start_id[curr_label] = best_medoid; + _medoid_counts[best_medoid]++; + } + + this->build(filename, num_points_to_load, tags); } template -std::pair -Index::_search(const DataType &query, const size_t K, - const uint32_t L, std::any &indices, - float *distances) { - try { - auto typed_query = std::any_cast(query); - if (typeid(uint32_t *) == indices.type()) { - auto u32_ptr = 
std::any_cast(indices); - return this->search(typed_query, K, L, u32_ptr, distances); - } else if (typeid(uint64_t *) == indices.type()) { - auto u64_ptr = std::any_cast(indices); - return this->search(typed_query, K, L, u64_ptr, distances); - } else { - throw ANNException( - "Error: indices type can only be uint64_t or uint32_t.", -1); - } - } catch (const std::bad_any_cast &e) { - throw ANNException( - "Error: bad any cast while searching. " + std::string(e.what()), -1); - } catch (const std::exception &e) { - throw ANNException("Error: " + std::string(e.what()), -1); - } +std::pair Index::_search(const DataType &query, const size_t K, const uint32_t L, + std::any &indices, float *distances) +{ + try + { + auto typed_query = std::any_cast(query); + if (typeid(uint32_t *) == indices.type()) + { + auto u32_ptr = std::any_cast(indices); + return this->search(typed_query, K, L, u32_ptr, distances); + } + else if (typeid(uint64_t *) == indices.type()) + { + auto u64_ptr = std::any_cast(indices); + return this->search(typed_query, K, L, u64_ptr, distances); + } + else + { + throw ANNException("Error: indices type can only be uint64_t or uint32_t.", -1); + } + } + catch (const std::bad_any_cast &e) + { + throw ANNException("Error: bad any cast while searching. " + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error: " + std::string(e.what()), -1); + } } template template -std::pair -Index::search(const T *query, const size_t K, const uint32_t L, - IdType *indices, float *distances) { - if (K > (uint64_t)L) { - throw ANNException("Set L to a value of at least K", -1, __FUNCSIG__, - __FILE__, __LINE__); - } - - ScratchStoreManager> manager(_query_scratch); - auto scratch = manager.scratch_space(); - - if (L > scratch->get_L()) { - diskann::cout << "Attempting to expand query scratch_space. Was created " - << "with Lsize: " << scratch->get_L() - << " but search L is: " << L << std::endl; - scratch->resize_for_new_L(L); - diskann::cout << "Resize completed. New scratch->L is " << scratch->get_L() - << std::endl; - } +std::pair Index::search(const T *query, const size_t K, const uint32_t L, + IdType *indices, float *distances) +{ + if (K > (uint64_t)L) + { + throw ANNException("Set L to a value of at least K", -1, __FUNCSIG__, __FILE__, __LINE__); + } + + ScratchStoreManager> manager(_query_scratch); + auto scratch = manager.scratch_space(); - const std::vector unused_filter_label; - const std::vector init_ids = get_init_ids(); + if (L > scratch->get_L()) + { + diskann::cout << "Attempting to expand query scratch_space. Was created " + << "with Lsize: " << scratch->get_L() << " but search L is: " << L << std::endl; + scratch->resize_for_new_L(L); + diskann::cout << "Resize completed. 
New scratch->L is " << scratch->get_L() << std::endl; + } + + const std::vector unused_filter_label; + const std::vector init_ids = get_init_ids(); - std::shared_lock lock(_update_lock); + std::shared_lock lock(_update_lock); - _data_store->preprocess_query(query, scratch); + _data_store->preprocess_query(query, scratch); - auto retval = iterate_to_fixed_point(scratch, L, init_ids, false, - unused_filter_label, true); + auto retval = iterate_to_fixed_point(scratch, L, init_ids, false, unused_filter_label, true); - NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes(); + NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes(); - size_t pos = 0; - for (size_t i = 0; i < best_L_nodes.size(); ++i) { - if (best_L_nodes[i].id < _max_points) { - // safe because Index uses uint32_t ids internally - // and IDType will be uint32_t or uint64_t - indices[pos] = (IdType)best_L_nodes[i].id; - if (distances != nullptr) { + size_t pos = 0; + for (size_t i = 0; i < best_L_nodes.size(); ++i) + { + if (best_L_nodes[i].id < _max_points) + { + // safe because Index uses uint32_t ids internally + // and IDType will be uint32_t or uint64_t + indices[pos] = (IdType)best_L_nodes[i].id; + if (distances != nullptr) + { #ifdef EXEC_ENV_OLS - // DLVS expects negative distances - distances[pos] = best_L_nodes[i].distance; + // DLVS expects negative distances + distances[pos] = best_L_nodes[i].distance; #else - distances[pos] = _dist_metric == diskann::Metric::INNER_PRODUCT - ? -1 * best_L_nodes[i].distance - : best_L_nodes[i].distance; + distances[pos] = _dist_metric == diskann::Metric::INNER_PRODUCT ? -1 * best_L_nodes[i].distance + : best_L_nodes[i].distance; #endif - } - pos++; + } + pos++; + } + if (pos == K) + break; + } + if (pos < K) + { + diskann::cerr << "Found pos: " << pos << "fewer than K elements " << K << " for query" << std::endl; } - if (pos == K) - break; - } - if (pos < K) { - diskann::cerr << "Found pos: " << pos << "fewer than K elements " << K - << " for query" << std::endl; - } - return retval; + return retval; } template -std::pair Index::_search_with_filters( - const DataType &query, const std::string &raw_label, const size_t K, - const uint32_t L, std::any &indices, float *distances) { - auto converted_label = this->get_converted_label(raw_label); - if (typeid(uint64_t *) == indices.type()) { - auto ptr = std::any_cast(indices); - return this->search_with_filters(std::any_cast(query), converted_label, - K, L, ptr, distances); - } else if (typeid(uint32_t *) == indices.type()) { - auto ptr = std::any_cast(indices); - return this->search_with_filters(std::any_cast(query), converted_label, - K, L, ptr, distances); - } else { - throw ANNException("Error: Id type can only be uint64_t or uint32_t.", -1); - } +std::pair Index::_search_with_filters(const DataType &query, + const std::string &raw_label, const size_t K, + const uint32_t L, std::any &indices, + float *distances) +{ + auto converted_label = this->get_converted_label(raw_label); + if (typeid(uint64_t *) == indices.type()) + { + auto ptr = std::any_cast(indices); + return this->search_with_filters(std::any_cast(query), converted_label, K, L, ptr, distances); + } + else if (typeid(uint32_t *) == indices.type()) + { + auto ptr = std::any_cast(indices); + return this->search_with_filters(std::any_cast(query), converted_label, K, L, ptr, distances); + } + else + { + throw ANNException("Error: Id type can only be uint64_t or uint32_t.", -1); + } } template template -std::pair Index::search_with_filters( - const T *query, const 
LabelT &filter_label, const size_t K, - const uint32_t L, IdType *indices, float *distances) { - if (K > (uint64_t)L) { - throw ANNException("Set L to a value of at least K", -1, __FUNCSIG__, - __FILE__, __LINE__); - } - - ScratchStoreManager> manager(_query_scratch); - auto scratch = manager.scratch_space(); - - if (L > scratch->get_L()) { - diskann::cout << "Attempting to expand query scratch_space. Was created " - << "with Lsize: " << scratch->get_L() - << " but search L is: " << L << std::endl; - scratch->resize_for_new_L(L); - diskann::cout << "Resize completed. New scratch->L is " << scratch->get_L() - << std::endl; - } - - std::vector filter_vec; - std::vector init_ids = get_init_ids(); - - std::shared_lock lock(_update_lock); - std::shared_lock tl(_tag_lock, std::defer_lock); - if (_dynamic_index) - tl.lock(); - - if (_label_to_start_id.find(filter_label) != _label_to_start_id.end()) { - init_ids.emplace_back(_label_to_start_id[filter_label]); - } else { - diskann::cout << "No filtered medoid found. exitting " - << std::endl; // RKNOTE: If universal label found start there - throw diskann::ANNException("No filtered medoid found. exitting ", -1); - } - if (_dynamic_index) - tl.unlock(); +std::pair Index::search_with_filters(const T *query, const LabelT &filter_label, + const size_t K, const uint32_t L, + IdType *indices, float *distances) +{ + if (K > (uint64_t)L) + { + throw ANNException("Set L to a value of at least K", -1, __FUNCSIG__, __FILE__, __LINE__); + } + + ScratchStoreManager> manager(_query_scratch); + auto scratch = manager.scratch_space(); + + if (L > scratch->get_L()) + { + diskann::cout << "Attempting to expand query scratch_space. Was created " + << "with Lsize: " << scratch->get_L() << " but search L is: " << L << std::endl; + scratch->resize_for_new_L(L); + diskann::cout << "Resize completed. New scratch->L is " << scratch->get_L() << std::endl; + } + + std::vector filter_vec; + std::vector init_ids = get_init_ids(); + + std::shared_lock lock(_update_lock); + std::shared_lock tl(_tag_lock, std::defer_lock); + if (_dynamic_index) + tl.lock(); + + if (_label_to_start_id.find(filter_label) != _label_to_start_id.end()) + { + init_ids.emplace_back(_label_to_start_id[filter_label]); + } + else + { + diskann::cout << "No filtered medoid found. exitting " + << std::endl; // RKNOTE: If universal label found start there + throw diskann::ANNException("No filtered medoid found. 
exitting ", -1); + } + if (_dynamic_index) + tl.unlock(); - filter_vec.emplace_back(filter_label); + filter_vec.emplace_back(filter_label); - _data_store->preprocess_query(query, scratch); - auto retval = - iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true); + _data_store->preprocess_query(query, scratch); + auto retval = iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true); - auto best_L_nodes = scratch->best_l_nodes(); + auto best_L_nodes = scratch->best_l_nodes(); - size_t pos = 0; - for (size_t i = 0; i < best_L_nodes.size(); ++i) { - if (best_L_nodes[i].id < _max_points) { - indices[pos] = (IdType)best_L_nodes[i].id; + size_t pos = 0; + for (size_t i = 0; i < best_L_nodes.size(); ++i) + { + if (best_L_nodes[i].id < _max_points) + { + indices[pos] = (IdType)best_L_nodes[i].id; - if (distances != nullptr) { + if (distances != nullptr) + { #ifdef EXEC_ENV_OLS - // DLVS expects negative distances - distances[pos] = best_L_nodes[i].distance; + // DLVS expects negative distances + distances[pos] = best_L_nodes[i].distance; #else - distances[pos] = _dist_metric == diskann::Metric::INNER_PRODUCT - ? -1 * best_L_nodes[i].distance - : best_L_nodes[i].distance; + distances[pos] = _dist_metric == diskann::Metric::INNER_PRODUCT ? -1 * best_L_nodes[i].distance + : best_L_nodes[i].distance; #endif - } - pos++; + } + pos++; + } + if (pos == K) + break; + } + if (pos < K) + { + diskann::cerr << "Found fewer than K elements for query" << std::endl; } - if (pos == K) - break; - } - if (pos < K) { - diskann::cerr << "Found fewer than K elements for query" << std::endl; - } - return retval; + return retval; } template -size_t Index::_search_with_tags( - const DataType &query, const uint64_t K, const uint32_t L, - const TagType &tags, float *distances, DataVector &res_vectors, - bool use_filters, const std::string filter_label) { - try { - return this->search_with_tags(std::any_cast(query), K, L, - std::any_cast(tags), distances, - res_vectors.get>(), - use_filters, filter_label); - } catch (const std::bad_any_cast &e) { - throw ANNException( - "Error: bad any cast while performing _search_with_tags() " + - std::string(e.what()), - -1); - } catch (const std::exception &e) { - throw ANNException("Error: " + std::string(e.what()), -1); - } +size_t Index::_search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, + const TagType &tags, float *distances, DataVector &res_vectors, + bool use_filters, const std::string filter_label) +{ + try + { + return this->search_with_tags(std::any_cast(query), K, L, std::any_cast(tags), distances, + res_vectors.get>(), use_filters, filter_label); + } + catch (const std::bad_any_cast &e) + { + throw ANNException("Error: bad any cast while performing _search_with_tags() " + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error: " + std::string(e.what()), -1); + } } template -size_t Index::search_with_tags( - const T *query, const uint64_t K, const uint32_t L, TagT *tags, - float *distances, std::vector &res_vectors, bool use_filters, - const std::string filter_label) { - if (K > (uint64_t)L) { - throw ANNException("Set L to a value of at least K", -1, __FUNCSIG__, - __FILE__, __LINE__); - } - ScratchStoreManager> manager(_query_scratch); - auto scratch = manager.scratch_space(); - - if (L > scratch->get_L()) { - diskann::cout << "Attempting to expand query scratch_space. 
Was created " - << "with Lsize: " << scratch->get_L() - << " but search L is: " << L << std::endl; - scratch->resize_for_new_L(L); - diskann::cout << "Resize completed. New scratch->L is " << scratch->get_L() - << std::endl; - } +size_t Index::search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags, + float *distances, std::vector &res_vectors, bool use_filters, + const std::string filter_label) +{ + if (K > (uint64_t)L) + { + throw ANNException("Set L to a value of at least K", -1, __FUNCSIG__, __FILE__, __LINE__); + } + ScratchStoreManager> manager(_query_scratch); + auto scratch = manager.scratch_space(); - std::shared_lock ul(_update_lock); + if (L > scratch->get_L()) + { + diskann::cout << "Attempting to expand query scratch_space. Was created " + << "with Lsize: " << scratch->get_L() << " but search L is: " << L << std::endl; + scratch->resize_for_new_L(L); + diskann::cout << "Resize completed. New scratch->L is " << scratch->get_L() << std::endl; + } - const std::vector init_ids = get_init_ids(); + std::shared_lock ul(_update_lock); - //_distance->preprocess_query(query, _data_store->get_dims(), - // scratch->aligned_query()); - _data_store->preprocess_query(query, scratch); - if (!use_filters) { - const std::vector unused_filter_label; - iterate_to_fixed_point(scratch, L, init_ids, false, unused_filter_label, - true); - } else { - std::vector filter_vec; - auto converted_label = this->get_converted_label(filter_label); - filter_vec.push_back(converted_label); - iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true); - } + const std::vector init_ids = get_init_ids(); + + //_distance->preprocess_query(query, _data_store->get_dims(), + // scratch->aligned_query()); + _data_store->preprocess_query(query, scratch); + if (!use_filters) + { + const std::vector unused_filter_label; + iterate_to_fixed_point(scratch, L, init_ids, false, unused_filter_label, true); + } + else + { + std::vector filter_vec; + auto converted_label = this->get_converted_label(filter_label); + filter_vec.push_back(converted_label); + iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true); + } - NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes(); - assert(best_L_nodes.size() <= L); + NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes(); + assert(best_L_nodes.size() <= L); - std::shared_lock tl(_tag_lock); + std::shared_lock tl(_tag_lock); - size_t pos = 0; - for (size_t i = 0; i < best_L_nodes.size(); ++i) { - auto node = best_L_nodes[i]; + size_t pos = 0; + for (size_t i = 0; i < best_L_nodes.size(); ++i) + { + auto node = best_L_nodes[i]; - TagT tag; - if (_location_to_tag.try_get(node.id, tag)) { - tags[pos] = tag; + TagT tag; + if (_location_to_tag.try_get(node.id, tag)) + { + tags[pos] = tag; - if (res_vectors.size() > 0) { - _data_store->get_vector(node.id, res_vectors[pos]); - } + if (res_vectors.size() > 0) + { + _data_store->get_vector(node.id, res_vectors[pos]); + } - if (distances != nullptr) { + if (distances != nullptr) + { #ifdef EXEC_ENV_OLS - distances[pos] = node.distance; // DLVS expects negative distances + distances[pos] = node.distance; // DLVS expects negative distances #else - distances[pos] = - _dist_metric == INNER_PRODUCT ? -1 * node.distance : node.distance; + distances[pos] = _dist_metric == INNER_PRODUCT ? -1 * node.distance : node.distance; #endif - } - pos++; - // If res_vectors.size() < k, clip at the value. 
- if (pos == K || pos == res_vectors.size()) - break; + } + pos++; + // If res_vectors.size() < k, clip at the value. + if (pos == K || pos == res_vectors.size()) + break; + } } - } - return pos; + return pos; } -template -size_t Index::get_num_points() { - std::shared_lock tl(_tag_lock); - return _nd; +template size_t Index::get_num_points() +{ + std::shared_lock tl(_tag_lock); + return _nd; } -template -size_t Index::get_max_points() { - std::shared_lock tl(_tag_lock); - return _max_points; +template size_t Index::get_max_points() +{ + std::shared_lock tl(_tag_lock); + return _max_points; } -template -void Index::generate_frozen_point() { - if (_num_frozen_pts == 0) - return; - - if (_num_frozen_pts > 1) { - throw ANNException( - "More than one frozen point not supported in generate_frozen_point", -1, - __FUNCSIG__, __FILE__, __LINE__); - } - - if (_nd == 0) { - throw ANNException("ERROR: Can not pick a frozen point since nd=0", -1, - __FUNCSIG__, __FILE__, __LINE__); - } - size_t res = calculate_entry_point(); - - // REFACTOR PQ: Not sure if we should do this for both stores. - if (_pq_dist) { - // copy the PQ data corresponding to the point returned by - // calculate_entry_point - // memcpy(_pq_data + _max_points * _num_pq_chunks, - // _pq_data + res * _num_pq_chunks, - // _num_pq_chunks * DIV_ROUND_UP(NUM_PQ_BITS, 8)); - _pq_data_store->copy_vectors((location_t)res, (location_t)_max_points, 1); - } else { - _data_store->copy_vectors((location_t)res, (location_t)_max_points, 1); - } - _frozen_pts_used++; -} +template void Index::generate_frozen_point() +{ + if (_num_frozen_pts == 0) + return; -template -int Index::enable_delete() { - assert(_enable_tags); + if (_num_frozen_pts > 1) + { + throw ANNException("More than one frozen point not supported in generate_frozen_point", -1, __FUNCSIG__, + __FILE__, __LINE__); + } - if (!_enable_tags) { - diskann::cerr << "Tags must be instantiated for deletions" << std::endl; - return -2; - } + if (_nd == 0) + { + throw ANNException("ERROR: Can not pick a frozen point since nd=0", -1, __FUNCSIG__, __FILE__, __LINE__); + } + size_t res = calculate_entry_point(); - if (this->_deletes_enabled) { - return 0; - } + // REFACTOR PQ: Not sure if we should do this for both stores. 
+ if (_pq_dist) + { + // copy the PQ data corresponding to the point returned by + // calculate_entry_point + // memcpy(_pq_data + _max_points * _num_pq_chunks, + // _pq_data + res * _num_pq_chunks, + // _num_pq_chunks * DIV_ROUND_UP(NUM_PQ_BITS, 8)); + _pq_data_store->copy_vectors((location_t)res, (location_t)_max_points, 1); + } + else + { + _data_store->copy_vectors((location_t)res, (location_t)_max_points, 1); + } + _frozen_pts_used++; +} + +template int Index::enable_delete() +{ + assert(_enable_tags); + + if (!_enable_tags) + { + diskann::cerr << "Tags must be instantiated for deletions" << std::endl; + return -2; + } - std::unique_lock ul(_update_lock); - std::unique_lock tl(_tag_lock); - std::unique_lock dl(_delete_lock); + if (this->_deletes_enabled) + { + return 0; + } + + std::unique_lock ul(_update_lock); + std::unique_lock tl(_tag_lock); + std::unique_lock dl(_delete_lock); - if (_data_compacted) { - for (uint32_t slot = (uint32_t)_nd; slot < _max_points; ++slot) { - _empty_slots.insert(slot); + if (_data_compacted) + { + for (uint32_t slot = (uint32_t)_nd; slot < _max_points; ++slot) + { + _empty_slots.insert(slot); + } } - } - this->_deletes_enabled = true; - return 0; + this->_deletes_enabled = true; + return 0; } template -inline void Index::process_delete( - const tsl::robin_set &old_delete_set, size_t loc, - const uint32_t range, const uint32_t maxc, const float alpha, - InMemQueryScratch *scratch) { - tsl::robin_set &expanded_nodes_set = scratch->expanded_nodes_set(); - std::vector &expanded_nghrs_vec = scratch->expanded_nodes_vec(); - - // If this condition were not true, deadlock could result - assert(old_delete_set.find((uint32_t)loc) == old_delete_set.end()); - - std::vector adj_list; - { - // Acquire and release lock[loc] before acquiring locks for neighbors - std::unique_lock adj_list_lock; - if (_conc_consolidate) - adj_list_lock = std::unique_lock(_locks[loc]); - adj_list = _graph_store->get_neighbours((location_t)loc); - } - - bool modify = false; - for (auto ngh : adj_list) { - if (old_delete_set.find(ngh) == old_delete_set.end()) { - expanded_nodes_set.insert(ngh); - } else { - modify = true; - - std::unique_lock ngh_lock; - if (_conc_consolidate) - ngh_lock = std::unique_lock(_locks[ngh]); - for (auto j : _graph_store->get_neighbours((location_t)ngh)) - if (j != loc && old_delete_set.find(j) == old_delete_set.end()) - expanded_nodes_set.insert(j); - } - } - - if (modify) { - if (expanded_nodes_set.size() <= range) { - std::unique_lock adj_list_lock(_locks[loc]); - _graph_store->clear_neighbours((location_t)loc); - for (auto &ngh : expanded_nodes_set) - _graph_store->add_neighbour((location_t)loc, ngh); - } else { - // Create a pool of Neighbor candidates from the expanded_nodes_set - expanded_nghrs_vec.reserve(expanded_nodes_set.size()); - for (auto &ngh : expanded_nodes_set) { - expanded_nghrs_vec.emplace_back( - ngh, _data_store->get_distance((location_t)loc, (location_t)ngh)); - } - std::sort(expanded_nghrs_vec.begin(), expanded_nghrs_vec.end()); - std::vector &occlude_list_output = - scratch->occlude_list_output(); - occlude_list((uint32_t)loc, expanded_nghrs_vec, alpha, range, maxc, - occlude_list_output, scratch, &old_delete_set); - std::unique_lock adj_list_lock(_locks[loc]); - _graph_store->set_neighbours((location_t)loc, occlude_list_output); - } - } +inline void Index::process_delete(const tsl::robin_set &old_delete_set, size_t loc, + const uint32_t range, const uint32_t maxc, const float alpha, + InMemQueryScratch *scratch) +{ + tsl::robin_set 
&expanded_nodes_set = scratch->expanded_nodes_set(); + std::vector &expanded_nghrs_vec = scratch->expanded_nodes_vec(); + + // If this condition were not true, deadlock could result + assert(old_delete_set.find((uint32_t)loc) == old_delete_set.end()); + + std::vector adj_list; + { + // Acquire and release lock[loc] before acquiring locks for neighbors + std::unique_lock adj_list_lock; + if (_conc_consolidate) + adj_list_lock = std::unique_lock(_locks[loc]); + adj_list = _graph_store->get_neighbours((location_t)loc); + } + + bool modify = false; + for (auto ngh : adj_list) + { + if (old_delete_set.find(ngh) == old_delete_set.end()) + { + expanded_nodes_set.insert(ngh); + } + else + { + modify = true; + + std::unique_lock ngh_lock; + if (_conc_consolidate) + ngh_lock = std::unique_lock(_locks[ngh]); + for (auto j : _graph_store->get_neighbours((location_t)ngh)) + if (j != loc && old_delete_set.find(j) == old_delete_set.end()) + expanded_nodes_set.insert(j); + } + } + + if (modify) + { + if (expanded_nodes_set.size() <= range) + { + std::unique_lock adj_list_lock(_locks[loc]); + _graph_store->clear_neighbours((location_t)loc); + for (auto &ngh : expanded_nodes_set) + _graph_store->add_neighbour((location_t)loc, ngh); + } + else + { + // Create a pool of Neighbor candidates from the expanded_nodes_set + expanded_nghrs_vec.reserve(expanded_nodes_set.size()); + for (auto &ngh : expanded_nodes_set) + { + expanded_nghrs_vec.emplace_back(ngh, _data_store->get_distance((location_t)loc, (location_t)ngh)); + } + std::sort(expanded_nghrs_vec.begin(), expanded_nghrs_vec.end()); + std::vector &occlude_list_output = scratch->occlude_list_output(); + occlude_list((uint32_t)loc, expanded_nghrs_vec, alpha, range, maxc, occlude_list_output, scratch, + &old_delete_set); + std::unique_lock adj_list_lock(_locks[loc]); + _graph_store->set_neighbours((location_t)loc, occlude_list_output); + } + } } // Returns number of live points left after consolidation template -consolidation_report Index::consolidate_deletes( - const IndexWriteParameters ¶ms) { - if (!_enable_tags) - throw diskann::ANNException("Point tag array not instantiated", -1, - __FUNCSIG__, __FILE__, __LINE__); +consolidation_report Index::consolidate_deletes(const IndexWriteParameters ¶ms) +{ + if (!_enable_tags) + throw diskann::ANNException("Point tag array not instantiated", -1, __FUNCSIG__, __FILE__, __LINE__); - { - std::shared_lock ul(_update_lock); - std::shared_lock tl(_tag_lock); - std::shared_lock dl(_delete_lock); - if (_empty_slots.size() + _nd != _max_points) { - std::string err = "#empty slots + nd != max points"; - diskann::cerr << err << std::endl; - throw ANNException(err, -1, __FUNCSIG__, __FILE__, __LINE__); - } - - if (_location_to_tag.size() + _delete_set->size() != _nd) { - diskann::cerr << "Error: _location_to_tag.size (" - << _location_to_tag.size() << ") + _delete_set->size (" - << _delete_set->size() << ") != _nd(" << _nd << ") "; - return consolidation_report( - diskann::consolidation_report::status_code::INCONSISTENT_COUNT_ERROR, - 0, 0, 0, 0, 0, 0, 0); - } - - if (_location_to_tag.size() != _tag_to_location.size()) { - throw diskann::ANNException( - "_location_to_tag and _tag_to_location not of same size", -1, - __FUNCSIG__, __FILE__, __LINE__); - } - } - - std::unique_lock update_lock(_update_lock, - std::defer_lock); - if (!_conc_consolidate) - update_lock.lock(); - - std::unique_lock cl(_consolidate_lock, - std::defer_lock); - if (!cl.try_lock()) { - diskann::cerr - << "Consildate delete function failed to acquire 
consolidate lock" - << std::endl; - return consolidation_report( - diskann::consolidation_report::status_code::LOCK_FAIL, 0, 0, 0, 0, 0, 0, - 0); - } - - diskann::cout << "Starting consolidate_deletes... "; - - std::unique_ptr> old_delete_set( - new tsl::robin_set); - { - std::unique_lock dl(_delete_lock); - std::swap(_delete_set, old_delete_set); - } - - if (old_delete_set->find(_start) != old_delete_set->end()) { - throw diskann::ANNException("ERROR: start node has been deleted", -1, - __FUNCSIG__, __FILE__, __LINE__); - } - - const uint32_t range = params.max_degree; - const uint32_t maxc = params.max_occlusion_size; - const float alpha = params.alpha; - const uint32_t num_threads = - params.num_threads == 0 ? omp_get_num_procs() : params.num_threads; - - uint32_t num_calls_to_process_delete = 0; - diskann::Timer timer; + { + std::shared_lock ul(_update_lock); + std::shared_lock tl(_tag_lock); + std::shared_lock dl(_delete_lock); + if (_empty_slots.size() + _nd != _max_points) + { + std::string err = "#empty slots + nd != max points"; + diskann::cerr << err << std::endl; + throw ANNException(err, -1, __FUNCSIG__, __FILE__, __LINE__); + } + + if (_location_to_tag.size() + _delete_set->size() != _nd) + { + diskann::cerr << "Error: _location_to_tag.size (" << _location_to_tag.size() << ") + _delete_set->size (" + << _delete_set->size() << ") != _nd(" << _nd << ") "; + return consolidation_report(diskann::consolidation_report::status_code::INCONSISTENT_COUNT_ERROR, 0, 0, 0, + 0, 0, 0, 0); + } + + if (_location_to_tag.size() != _tag_to_location.size()) + { + throw diskann::ANNException("_location_to_tag and _tag_to_location not of same size", -1, __FUNCSIG__, + __FILE__, __LINE__); + } + } + + std::unique_lock update_lock(_update_lock, std::defer_lock); + if (!_conc_consolidate) + update_lock.lock(); + + std::unique_lock cl(_consolidate_lock, std::defer_lock); + if (!cl.try_lock()) + { + diskann::cerr << "Consildate delete function failed to acquire consolidate lock" << std::endl; + return consolidation_report(diskann::consolidation_report::status_code::LOCK_FAIL, 0, 0, 0, 0, 0, 0, 0); + } + + diskann::cout << "Starting consolidate_deletes... "; + + std::unique_ptr> old_delete_set(new tsl::robin_set); + { + std::unique_lock dl(_delete_lock); + std::swap(_delete_set, old_delete_set); + } + + if (old_delete_set->find(_start) != old_delete_set->end()) + { + throw diskann::ANNException("ERROR: start node has been deleted", -1, __FUNCSIG__, __FILE__, __LINE__); + } + + const uint32_t range = params.max_degree; + const uint32_t maxc = params.max_occlusion_size; + const float alpha = params.alpha; + const uint32_t num_threads = params.num_threads == 0 ? 
omp_get_num_procs() : params.num_threads; + + uint32_t num_calls_to_process_delete = 0; + diskann::Timer timer; #pragma omp parallel for num_threads(num_threads) schedule(dynamic, 8192) reduction(+ : num_calls_to_process_delete) - for (int64_t loc = 0; loc < (int64_t)_max_points; loc++) { - if (old_delete_set->find((uint32_t)loc) == old_delete_set->end() && - !_empty_slots.is_in_set((uint32_t)loc)) { - ScratchStoreManager> manager(_query_scratch); - auto scratch = manager.scratch_space(); - process_delete(*old_delete_set, loc, range, maxc, alpha, scratch); - num_calls_to_process_delete += 1; - } - } - for (int64_t loc = _max_points; - loc < (int64_t)(_max_points + _num_frozen_pts); loc++) { - ScratchStoreManager> manager(_query_scratch); - auto scratch = manager.scratch_space(); - process_delete(*old_delete_set, loc, range, maxc, alpha, scratch); - num_calls_to_process_delete += 1; - } + for (int64_t loc = 0; loc < (int64_t)_max_points; loc++) + { + if (old_delete_set->find((uint32_t)loc) == old_delete_set->end() && !_empty_slots.is_in_set((uint32_t)loc)) + { + ScratchStoreManager> manager(_query_scratch); + auto scratch = manager.scratch_space(); + process_delete(*old_delete_set, loc, range, maxc, alpha, scratch); + num_calls_to_process_delete += 1; + } + } + for (int64_t loc = _max_points; loc < (int64_t)(_max_points + _num_frozen_pts); loc++) + { + ScratchStoreManager> manager(_query_scratch); + auto scratch = manager.scratch_space(); + process_delete(*old_delete_set, loc, range, maxc, alpha, scratch); + num_calls_to_process_delete += 1; + } - std::unique_lock tl(_tag_lock); - size_t ret_nd = release_locations(*old_delete_set); - size_t max_points = _max_points; - size_t empty_slots_size = _empty_slots.size(); + std::unique_lock tl(_tag_lock); + size_t ret_nd = release_locations(*old_delete_set); + size_t max_points = _max_points; + size_t empty_slots_size = _empty_slots.size(); - std::shared_lock dl(_delete_lock); - size_t delete_set_size = _delete_set->size(); - size_t old_delete_set_size = old_delete_set->size(); + std::shared_lock dl(_delete_lock); + size_t delete_set_size = _delete_set->size(); + size_t old_delete_set_size = old_delete_set->size(); - if (!_conc_consolidate) { - update_lock.unlock(); - } + if (!_conc_consolidate) + { + update_lock.unlock(); + } - double duration = timer.elapsed() / 1000000.0; - diskann::cout << " done in " << duration << " seconds." << std::endl; - return consolidation_report( - diskann::consolidation_report::status_code::SUCCESS, ret_nd, max_points, - empty_slots_size, old_delete_set_size, delete_set_size, - num_calls_to_process_delete, duration); + double duration = timer.elapsed() / 1000000.0; + diskann::cout << " done in " << duration << " seconds." 
<< std::endl; + return consolidation_report(diskann::consolidation_report::status_code::SUCCESS, ret_nd, max_points, + empty_slots_size, old_delete_set_size, delete_set_size, num_calls_to_process_delete, + duration); } -template -void Index::compact_frozen_point() { - if (_nd < _max_points && _num_frozen_pts > 0) { - reposition_points((uint32_t)_max_points, (uint32_t)_nd, - (uint32_t)_num_frozen_pts); - _start = (uint32_t)_nd; - - if (_filtered_index && _dynamic_index) { - // update medoid id's as frozen points are treated as medoid - for (auto &[label, medoid_id] : _label_to_start_id) { - /* if (label == _universal_label) - continue;*/ - _label_to_start_id[label] = - (uint32_t)_nd + (medoid_id - (uint32_t)_max_points); - } +template void Index::compact_frozen_point() +{ + if (_nd < _max_points && _num_frozen_pts > 0) + { + reposition_points((uint32_t)_max_points, (uint32_t)_nd, (uint32_t)_num_frozen_pts); + _start = (uint32_t)_nd; + + if (_filtered_index && _dynamic_index) + { + // update medoid id's as frozen points are treated as medoid + for (auto &[label, medoid_id] : _label_to_start_id) + { + /* if (label == _universal_label) + continue;*/ + _label_to_start_id[label] = (uint32_t)_nd + (medoid_id - (uint32_t)_max_points); + } + } } - } } // Should be called after acquiring _update_lock -template -void Index::compact_data() { - if (!_dynamic_index) - throw ANNException("Can not compact a non-dynamic index", -1, __FUNCSIG__, - __FILE__, __LINE__); - - if (_data_compacted) { - diskann::cerr - << "Warning! Calling compact_data() when _data_compacted is true!" - << std::endl; - return; - } - - if (_delete_set->size() > 0) { - throw ANNException( - "Can not compact data when index has non-empty _delete_set of " - "size: " + - std::to_string(_delete_set->size()), - -1, __FUNCSIG__, __FILE__, __LINE__); - } - - diskann::Timer timer; - - std::vector new_location = - std::vector(_max_points + _num_frozen_pts, UINT32_MAX); - - uint32_t new_counter = 0; - std::set empty_locations; - for (uint32_t old_location = 0; old_location < _max_points; old_location++) { - if (_location_to_tag.contains(old_location)) { - new_location[old_location] = new_counter; - new_counter++; - } else { - empty_locations.insert(old_location); - } - } - for (uint32_t old_location = (uint32_t)_max_points; - old_location < _max_points + _num_frozen_pts; old_location++) { - new_location[old_location] = old_location; - } - - // If start node is removed, throw an exception - if (_start < _max_points && !_location_to_tag.contains(_start)) { - throw diskann::ANNException("ERROR: Start node deleted.", -1, __FUNCSIG__, - __FILE__, __LINE__); - } - - size_t num_dangling = 0; - for (uint32_t old = 0; old < _max_points + _num_frozen_pts; ++old) { - // compact _final_graph - std::vector new_adj_list; - - if ((new_location[old] < _max_points) // If point continues to exist - || (old >= _max_points && old < _max_points + _num_frozen_pts)) { - new_adj_list.reserve( - _graph_store->get_neighbours((location_t)old).size()); - for (auto ngh_iter : _graph_store->get_neighbours((location_t)old)) { - if (empty_locations.find(ngh_iter) != empty_locations.end()) { - ++num_dangling; - diskann::cerr << "Error in compact_data(). _final_graph[" << old - << "] has neighbor " << ngh_iter - << " which is a location not associated with any tag." 
- << std::endl; - } else { - new_adj_list.push_back(new_location[ngh_iter]); - } - } - //_graph_store->get_neighbours((location_t)old).swap(new_adj_list); - _graph_store->set_neighbours((location_t)old, new_adj_list); - - // Move the data and adj list to the correct position - if (new_location[old] != old) { - assert(new_location[old] < old); - _graph_store->swap_neighbours(new_location[old], (location_t)old); - - if (_filtered_index) { - _location_to_labels[new_location[old]].swap(_location_to_labels[old]); - } - - _data_store->copy_vectors(old, new_location[old], 1); - } - } else { - _graph_store->clear_neighbours((location_t)old); - } - } - diskann::cerr << "#dangling references after data compaction: " - << num_dangling << std::endl; - - _tag_to_location.clear(); - for (auto pos = _location_to_tag.find_first(); pos.is_valid(); - pos = _location_to_tag.find_next(pos)) { - const auto tag = _location_to_tag.get(pos); - _tag_to_location[tag] = new_location[pos._key]; - } - _location_to_tag.clear(); - for (const auto &iter : _tag_to_location) { - _location_to_tag.set(iter.second, iter.first); - } - // remove all cleared up old - for (size_t old = _nd; old < _max_points; ++old) { - _graph_store->clear_neighbours((location_t)old); - } - if (_filtered_index) { - for (size_t old = _nd; old < _max_points; old++) { - _location_to_labels[old].clear(); - } - } - - _empty_slots.clear(); - // mark all slots after _nd as empty - for (auto i = _nd; i < _max_points; i++) { - _empty_slots.insert((uint32_t)i); - } - _data_compacted = true; - diskann::cout << "Time taken for compact_data: " << timer.elapsed() / 1000000. - << "s." << std::endl; +template void Index::compact_data() +{ + if (!_dynamic_index) + throw ANNException("Can not compact a non-dynamic index", -1, __FUNCSIG__, __FILE__, __LINE__); + + if (_data_compacted) + { + diskann::cerr << "Warning! Calling compact_data() when _data_compacted is true!" 
<< std::endl; + return; + } + + if (_delete_set->size() > 0) + { + throw ANNException("Can not compact data when index has non-empty _delete_set of " + "size: " + + std::to_string(_delete_set->size()), + -1, __FUNCSIG__, __FILE__, __LINE__); + } + + diskann::Timer timer; + + std::vector new_location = std::vector(_max_points + _num_frozen_pts, UINT32_MAX); + + uint32_t new_counter = 0; + std::set empty_locations; + for (uint32_t old_location = 0; old_location < _max_points; old_location++) + { + if (_location_to_tag.contains(old_location)) + { + new_location[old_location] = new_counter; + new_counter++; + } + else + { + empty_locations.insert(old_location); + } + } + for (uint32_t old_location = (uint32_t)_max_points; old_location < _max_points + _num_frozen_pts; old_location++) + { + new_location[old_location] = old_location; + } + + // If start node is removed, throw an exception + if (_start < _max_points && !_location_to_tag.contains(_start)) + { + throw diskann::ANNException("ERROR: Start node deleted.", -1, __FUNCSIG__, __FILE__, __LINE__); + } + + size_t num_dangling = 0; + for (uint32_t old = 0; old < _max_points + _num_frozen_pts; ++old) + { + // compact _final_graph + std::vector new_adj_list; + + if ((new_location[old] < _max_points) // If point continues to exist + || (old >= _max_points && old < _max_points + _num_frozen_pts)) + { + new_adj_list.reserve(_graph_store->get_neighbours((location_t)old).size()); + for (auto ngh_iter : _graph_store->get_neighbours((location_t)old)) + { + if (empty_locations.find(ngh_iter) != empty_locations.end()) + { + ++num_dangling; + diskann::cerr << "Error in compact_data(). _final_graph[" << old << "] has neighbor " << ngh_iter + << " which is a location not associated with any tag." << std::endl; + } + else + { + new_adj_list.push_back(new_location[ngh_iter]); + } + } + //_graph_store->get_neighbours((location_t)old).swap(new_adj_list); + _graph_store->set_neighbours((location_t)old, new_adj_list); + + // Move the data and adj list to the correct position + if (new_location[old] != old) + { + assert(new_location[old] < old); + _graph_store->swap_neighbours(new_location[old], (location_t)old); + + if (_filtered_index) + { + _location_to_labels[new_location[old]].swap(_location_to_labels[old]); + } + + _data_store->copy_vectors(old, new_location[old], 1); + } + } + else + { + _graph_store->clear_neighbours((location_t)old); + } + } + diskann::cerr << "#dangling references after data compaction: " << num_dangling << std::endl; + + _tag_to_location.clear(); + for (auto pos = _location_to_tag.find_first(); pos.is_valid(); pos = _location_to_tag.find_next(pos)) + { + const auto tag = _location_to_tag.get(pos); + _tag_to_location[tag] = new_location[pos._key]; + } + _location_to_tag.clear(); + for (const auto &iter : _tag_to_location) + { + _location_to_tag.set(iter.second, iter.first); + } + // remove all cleared up old + for (size_t old = _nd; old < _max_points; ++old) + { + _graph_store->clear_neighbours((location_t)old); + } + if (_filtered_index) + { + for (size_t old = _nd; old < _max_points; old++) + { + _location_to_labels[old].clear(); + } + } + + _empty_slots.clear(); + // mark all slots after _nd as empty + for (auto i = _nd; i < _max_points; i++) + { + _empty_slots.insert((uint32_t)i); + } + _data_compacted = true; + diskann::cout << "Time taken for compact_data: " << timer.elapsed() / 1000000. << "s." 
<< std::endl; } // // Caller must hold unique _tag_lock and _delete_lock before calling this // -template -int Index::reserve_location() { - if (_nd >= _max_points) { - return -1; - } - uint32_t location; - if (_data_compacted && _empty_slots.is_empty()) { - // This code path is encountered when enable_delete hasn't been - // called yet, so no points have been deleted and _empty_slots - // hasn't been filled in. In that case, just keep assigning - // consecutive locations. - location = (uint32_t)_nd; - } else { - assert(_empty_slots.size() != 0); - assert(_empty_slots.size() + _nd == _max_points); - - location = _empty_slots.pop_any(); - _delete_set->erase(location); - } - ++_nd; - return location; -} - -template -size_t Index::release_location(int location) { - if (_empty_slots.is_in_set(location)) - throw ANNException( - "Trying to release location, but location already in empty slots", -1, - __FUNCSIG__, __FILE__, __LINE__); - _empty_slots.insert(location); +template int Index::reserve_location() +{ + if (_nd >= _max_points) + { + return -1; + } + uint32_t location; + if (_data_compacted && _empty_slots.is_empty()) + { + // This code path is encountered when enable_delete hasn't been + // called yet, so no points have been deleted and _empty_slots + // hasn't been filled in. In that case, just keep assigning + // consecutive locations. + location = (uint32_t)_nd; + } + else + { + assert(_empty_slots.size() != 0); + assert(_empty_slots.size() + _nd == _max_points); - _nd--; - return _nd; + location = _empty_slots.pop_any(); + _delete_set->erase(location); + } + ++_nd; + return location; } -template -size_t Index::release_locations( - const tsl::robin_set &locations) { - for (auto location : locations) { +template size_t Index::release_location(int location) +{ if (_empty_slots.is_in_set(location)) - throw ANNException("Trying to release location, but location " - "already in empty slots", - -1, __FUNCSIG__, __FILE__, __LINE__); + throw ANNException("Trying to release location, but location already in empty slots", -1, __FUNCSIG__, __FILE__, + __LINE__); _empty_slots.insert(location); _nd--; - } + return _nd; +} + +template +size_t Index::release_locations(const tsl::robin_set &locations) +{ + for (auto location : locations) + { + if (_empty_slots.is_in_set(location)) + throw ANNException("Trying to release location, but location " + "already in empty slots", + -1, __FUNCSIG__, __FILE__, __LINE__); + _empty_slots.insert(location); - if (_empty_slots.size() + _nd != _max_points) - throw ANNException("#empty slots + nd != max points", -1, __FUNCSIG__, - __FILE__, __LINE__); + _nd--; + } + + if (_empty_slots.size() + _nd != _max_points) + throw ANNException("#empty slots + nd != max points", -1, __FUNCSIG__, __FILE__, __LINE__); - return _nd; + return _nd; } template -void Index::reposition_points(uint32_t old_location_start, - uint32_t new_location_start, - uint32_t num_locations) { - if (num_locations == 0 || old_location_start == new_location_start) { - return; - } - - // Update pointers to the moved nodes. Note: the computation is correct even - // when new_location_start < old_location_start given the C++ uint32_t - // integer arithmetic rules. 
- const uint32_t location_delta = new_location_start - old_location_start; - - std::vector updated_neighbours_location; - for (uint32_t i = 0; i < _max_points + _num_frozen_pts; i++) { - auto &i_neighbours = _graph_store->get_neighbours((location_t)i); - std::vector i_neighbours_copy(i_neighbours.begin(), - i_neighbours.end()); - for (auto &loc : i_neighbours_copy) { - if (loc >= old_location_start && loc < old_location_start + num_locations) - loc += location_delta; - } - _graph_store->set_neighbours(i, i_neighbours_copy); - } - - // The [start, end) interval which will contain obsolete points to be - // cleared. - uint32_t mem_clear_loc_start = old_location_start; - uint32_t mem_clear_loc_end_limit = old_location_start + num_locations; - - // Move the adjacency lists. Make sure that overlapping ranges are handled - // correctly. - if (new_location_start < old_location_start) { - // New location before the old location: copy the entries in order - // to avoid modifying locations that are yet to be copied. - for (uint32_t loc_offset = 0; loc_offset < num_locations; loc_offset++) { - assert(_graph_store->get_neighbours(new_location_start + loc_offset) - .empty()); - _graph_store->swap_neighbours(new_location_start + loc_offset, - old_location_start + loc_offset); - if (_dynamic_index && _filtered_index) { - _location_to_labels[new_location_start + loc_offset].swap( - _location_to_labels[old_location_start + loc_offset]); - } - } - // If ranges are overlapping, make sure not to clear the newly copied - // data. - if (mem_clear_loc_start < new_location_start + num_locations) { - // Clear only after the end of the new range. - mem_clear_loc_start = new_location_start + num_locations; - } - } else { - // Old location after the new location: copy from the end of the range - // to avoid modifying locations that are yet to be copied. - for (uint32_t loc_offset = num_locations; loc_offset > 0; loc_offset--) { - assert(_graph_store->get_neighbours(new_location_start + loc_offset - 1u) - .empty()); - _graph_store->swap_neighbours(new_location_start + loc_offset - 1u, - old_location_start + loc_offset - 1u); - if (_dynamic_index && _filtered_index) { - _location_to_labels[new_location_start + loc_offset - 1u].swap( - _location_to_labels[old_location_start + loc_offset - 1u]); - } - } - - // If ranges are overlapping, make sure not to clear the newly copied - // data. - if (mem_clear_loc_end_limit > new_location_start) { - // Clear only up to the beginning of the new range. - mem_clear_loc_end_limit = new_location_start; - } - } - _data_store->move_vectors(old_location_start, new_location_start, - num_locations); +void Index::reposition_points(uint32_t old_location_start, uint32_t new_location_start, + uint32_t num_locations) +{ + if (num_locations == 0 || old_location_start == new_location_start) + { + return; + } + + // Update pointers to the moved nodes. Note: the computation is correct even + // when new_location_start < old_location_start given the C++ uint32_t + // integer arithmetic rules. 
+ const uint32_t location_delta = new_location_start - old_location_start; + + std::vector updated_neighbours_location; + for (uint32_t i = 0; i < _max_points + _num_frozen_pts; i++) + { + auto &i_neighbours = _graph_store->get_neighbours((location_t)i); + std::vector i_neighbours_copy(i_neighbours.begin(), i_neighbours.end()); + for (auto &loc : i_neighbours_copy) + { + if (loc >= old_location_start && loc < old_location_start + num_locations) + loc += location_delta; + } + _graph_store->set_neighbours(i, i_neighbours_copy); + } + + // The [start, end) interval which will contain obsolete points to be + // cleared. + uint32_t mem_clear_loc_start = old_location_start; + uint32_t mem_clear_loc_end_limit = old_location_start + num_locations; + + // Move the adjacency lists. Make sure that overlapping ranges are handled + // correctly. + if (new_location_start < old_location_start) + { + // New location before the old location: copy the entries in order + // to avoid modifying locations that are yet to be copied. + for (uint32_t loc_offset = 0; loc_offset < num_locations; loc_offset++) + { + assert(_graph_store->get_neighbours(new_location_start + loc_offset).empty()); + _graph_store->swap_neighbours(new_location_start + loc_offset, old_location_start + loc_offset); + if (_dynamic_index && _filtered_index) + { + _location_to_labels[new_location_start + loc_offset].swap( + _location_to_labels[old_location_start + loc_offset]); + } + } + // If ranges are overlapping, make sure not to clear the newly copied + // data. + if (mem_clear_loc_start < new_location_start + num_locations) + { + // Clear only after the end of the new range. + mem_clear_loc_start = new_location_start + num_locations; + } + } + else + { + // Old location after the new location: copy from the end of the range + // to avoid modifying locations that are yet to be copied. + for (uint32_t loc_offset = num_locations; loc_offset > 0; loc_offset--) + { + assert(_graph_store->get_neighbours(new_location_start + loc_offset - 1u).empty()); + _graph_store->swap_neighbours(new_location_start + loc_offset - 1u, old_location_start + loc_offset - 1u); + if (_dynamic_index && _filtered_index) + { + _location_to_labels[new_location_start + loc_offset - 1u].swap( + _location_to_labels[old_location_start + loc_offset - 1u]); + } + } + + // If ranges are overlapping, make sure not to clear the newly copied + // data. + if (mem_clear_loc_end_limit > new_location_start) + { + // Clear only up to the beginning of the new range. + mem_clear_loc_end_limit = new_location_start; + } + } + _data_store->move_vectors(old_location_start, new_location_start, num_locations); } -template -void Index::reposition_frozen_point_to_end() { - if (_num_frozen_pts == 0) - return; +template void Index::reposition_frozen_point_to_end() +{ + if (_num_frozen_pts == 0) + return; - if (_nd == _max_points) { - diskann::cout - << "Not repositioning frozen point as it is already at the end." - << std::endl; - return; - } + if (_nd == _max_points) + { + diskann::cout << "Not repositioning frozen point as it is already at the end." 
<< std::endl; + return; + } - reposition_points((uint32_t)_nd, (uint32_t)_max_points, - (uint32_t)_num_frozen_pts); - _start = (uint32_t)_max_points; + reposition_points((uint32_t)_nd, (uint32_t)_max_points, (uint32_t)_num_frozen_pts); + _start = (uint32_t)_max_points; - // update medoid id's as frozen points are treated as medoid - if (_filtered_index && _dynamic_index) { - for (auto &[label, medoid_id] : _label_to_start_id) { - /*if (label == _universal_label) - continue;*/ - _label_to_start_id[label] = - (uint32_t)_max_points + (medoid_id - (uint32_t)_nd); + // update medoid id's as frozen points are treated as medoid + if (_filtered_index && _dynamic_index) + { + for (auto &[label, medoid_id] : _label_to_start_id) + { + /*if (label == _universal_label) + continue;*/ + _label_to_start_id[label] = (uint32_t)_max_points + (medoid_id - (uint32_t)_nd); + } } - } } -template -void Index::resize(size_t new_max_points) { - const size_t new_internal_points = new_max_points + _num_frozen_pts; - auto start = std::chrono::high_resolution_clock::now(); - assert(_empty_slots.size() == - 0); // should not resize if there are empty slots. - - _data_store->resize((location_t)new_internal_points); - _graph_store->resize_graph(new_internal_points); - _locks = std::vector(new_internal_points); - - if (_num_frozen_pts != 0) { - reposition_points((uint32_t)_max_points, (uint32_t)new_max_points, - (uint32_t)_num_frozen_pts); - _start = (uint32_t)new_max_points; - } - - _max_points = new_max_points; - _empty_slots.reserve(_max_points); - for (auto i = _nd; i < _max_points; i++) { - _empty_slots.insert((uint32_t)i); - } - - auto stop = std::chrono::high_resolution_clock::now(); - diskann::cout << "Resizing took: " - << std::chrono::duration(stop - start).count() << "s" - << std::endl; +template void Index::resize(size_t new_max_points) +{ + const size_t new_internal_points = new_max_points + _num_frozen_pts; + auto start = std::chrono::high_resolution_clock::now(); + assert(_empty_slots.size() == 0); // should not resize if there are empty slots. 
+ + _data_store->resize((location_t)new_internal_points); + _graph_store->resize_graph(new_internal_points); + _locks = std::vector(new_internal_points); + + if (_num_frozen_pts != 0) + { + reposition_points((uint32_t)_max_points, (uint32_t)new_max_points, (uint32_t)_num_frozen_pts); + _start = (uint32_t)new_max_points; + } + + _max_points = new_max_points; + _empty_slots.reserve(_max_points); + for (auto i = _nd; i < _max_points; i++) + { + _empty_slots.insert((uint32_t)i); + } + + auto stop = std::chrono::high_resolution_clock::now(); + diskann::cout << "Resizing took: " << std::chrono::duration(stop - start).count() << "s" << std::endl; } template -int Index::_insert_point(const DataType &point, - const TagType tag) { - try { - return this->insert_point(std::any_cast(point), - std::any_cast(tag)); - } catch (const std::bad_any_cast &anycast_e) { - throw new ANNException("Error:Trying to insert invalid data type" + - std::string(anycast_e.what()), - -1); - } catch (const std::exception &e) { - throw new ANNException("Error:" + std::string(e.what()), -1); - } +int Index::_insert_point(const DataType &point, const TagType tag) +{ + try + { + return this->insert_point(std::any_cast(point), std::any_cast(tag)); + } + catch (const std::bad_any_cast &anycast_e) + { + throw new ANNException("Error:Trying to insert invalid data type" + std::string(anycast_e.what()), -1); + } + catch (const std::exception &e) + { + throw new ANNException("Error:" + std::string(e.what()), -1); + } } template -int Index::_insert_point(const DataType &point, - const TagType tag, - Labelvector &labels) { - try { - return this->insert_point(std::any_cast(point), - std::any_cast(tag), - labels.get>()); - } catch (const std::bad_any_cast &anycast_e) { - throw new ANNException("Error:Trying to insert invalid data type" + - std::string(anycast_e.what()), - -1); - } catch (const std::exception &e) { - throw new ANNException("Error:" + std::string(e.what()), -1); - } +int Index::_insert_point(const DataType &point, const TagType tag, Labelvector &labels) +{ + try + { + return this->insert_point(std::any_cast(point), std::any_cast(tag), + labels.get>()); + } + catch (const std::bad_any_cast &anycast_e) + { + throw new ANNException("Error:Trying to insert invalid data type" + std::string(anycast_e.what()), -1); + } + catch (const std::exception &e) + { + throw new ANNException("Error:" + std::string(e.what()), -1); + } } template -int Index::insert_point(const T *point, const TagT tag) { - std::vector no_labels{0}; - return insert_point(point, tag, no_labels); +int Index::insert_point(const T *point, const TagT tag) +{ + std::vector no_labels{0}; + return insert_point(point, tag, no_labels); } template -int Index::insert_point(const T *point, const TagT tag, - const std::vector &labels) { - - assert(_has_built); - if (tag == 0) { - throw diskann::ANNException("Do not insert point with tag 0. That is " - "reserved for points hidden " - "from the user.", - -1, __FUNCSIG__, __FILE__, __LINE__); - } - - std::shared_lock shared_ul(_update_lock); - std::unique_lock tl(_tag_lock); - std::unique_lock dl(_delete_lock); - - auto location = reserve_location(); - if (_filtered_index) { - if (labels.empty()) { - release_location(location); - std::cerr << "Error: Can't insert point with tag " + get_tag_string(tag) + - " . there are no labels for the point." 
- << std::endl; - return -1; - } - - _location_to_labels[location] = labels; - - for (LabelT label : labels) { - if (_labels.find(label) == _labels.end()) { - if (_frozen_pts_used >= _num_frozen_pts) { - throw ANNException("Error: For dynamic filtered index, the number of " - "frozen points should be atleast equal " - "to number of unique labels.", - -1); - } - - auto fz_location = - (int)(_max_points) + _frozen_pts_used; // as first _fz_point - _labels.insert(label); - _label_to_start_id[label] = (uint32_t)fz_location; - _location_to_labels[fz_location] = {label}; - _data_store->set_vector((location_t)fz_location, point); - _frozen_pts_used++; - } - } - } - - if (location == -1) { +int Index::insert_point(const T *point, const TagT tag, const std::vector &labels) +{ + + assert(_has_built); + if (tag == 0) + { + throw diskann::ANNException("Do not insert point with tag 0. That is " + "reserved for points hidden " + "from the user.", + -1, __FUNCSIG__, __FILE__, __LINE__); + } + + std::shared_lock shared_ul(_update_lock); + std::unique_lock tl(_tag_lock); + std::unique_lock dl(_delete_lock); + + auto location = reserve_location(); + if (_filtered_index) + { + if (labels.empty()) + { + release_location(location); + std::cerr << "Error: Can't insert point with tag " + get_tag_string(tag) + + " . there are no labels for the point." + << std::endl; + return -1; + } + + _location_to_labels[location] = labels; + + for (LabelT label : labels) + { + if (_labels.find(label) == _labels.end()) + { + if (_frozen_pts_used >= _num_frozen_pts) + { + throw ANNException("Error: For dynamic filtered index, the number of " + "frozen points should be atleast equal " + "to number of unique labels.", + -1); + } + + auto fz_location = (int)(_max_points) + _frozen_pts_used; // as first _fz_point + _labels.insert(label); + _label_to_start_id[label] = (uint32_t)fz_location; + _location_to_labels[fz_location] = {label}; + _data_store->set_vector((location_t)fz_location, point); + _frozen_pts_used++; + } + } + } + + if (location == -1) + { #if EXPAND_IF_FULL + dl.unlock(); + tl.unlock(); + shared_ul.unlock(); + + { + std::unique_lock ul(_update_lock); + tl.lock(); + dl.lock(); + + if (_nd >= _max_points) + { + auto new_max_points = (size_t)(_max_points * INDEX_GROWTH_FACTOR); + resize(new_max_points); + } + + dl.unlock(); + tl.unlock(); + ul.unlock(); + } + + shared_ul.lock(); + tl.lock(); + dl.lock(); + + location = reserve_location(); + if (location == -1) + { + throw diskann::ANNException("Cannot reserve location even after " + "expanding graph. Terminating.", + -1, __FUNCSIG__, __FILE__, __LINE__); + } +#else + return -1; +#endif + } // cant insert as active pts >= max_pts dl.unlock(); - tl.unlock(); - shared_ul.unlock(); + // Insert tag and mapping to location + if (_enable_tags) { - std::unique_lock ul(_update_lock); - tl.lock(); - dl.lock(); + // if tags are enabled and tag is already inserted. so we can't reuse that + // tag. 
+ if (_tag_to_location.find(tag) != _tag_to_location.end()) + { + release_location(location); + return -1; + } + + _tag_to_location[tag] = location; + _location_to_tag.set(location, tag); + } + tl.unlock(); - if (_nd >= _max_points) { - auto new_max_points = (size_t)(_max_points * INDEX_GROWTH_FACTOR); - resize(new_max_points); - } + _data_store->set_vector(location, point); // update datastore - dl.unlock(); - tl.unlock(); - ul.unlock(); + // Find and add appropriate graph edges + ScratchStoreManager> manager(_query_scratch); + auto scratch = manager.scratch_space(); + std::vector pruned_list; // it is the set best candidates to connect to this point + if (_filtered_index) + { + // when filtered the best_candidates will share the same label ( + // label_present > distance) + search_for_point_and_prune(location, _indexingQueueSize, pruned_list, scratch, true, _filterIndexingQueueSize); } + else + { + search_for_point_and_prune(location, _indexingQueueSize, pruned_list, scratch); + } + assert(pruned_list.size() > 0); // should find atleast one neighbour (i.e + // frozen point acting as medoid) - shared_ul.lock(); - tl.lock(); - dl.lock(); + { + std::shared_lock tlock(_tag_lock, std::defer_lock); + if (_conc_consolidate) + tlock.lock(); + + LockGuard guard(_locks[location]); + _graph_store->clear_neighbours(location); + + std::vector neighbor_links; + for (auto link : pruned_list) + { + if (_conc_consolidate) + if (!_location_to_tag.contains(link)) + continue; + neighbor_links.emplace_back(link); + } + _graph_store->set_neighbours(location, neighbor_links); + assert(_graph_store->get_neighbours(location).size() <= _indexingRange); - location = reserve_location(); - if (location == -1) { - throw diskann::ANNException("Cannot reserve location even after " - "expanding graph. Terminating.", - -1, __FUNCSIG__, __FILE__, __LINE__); + if (_conc_consolidate) + tlock.unlock(); } -#else - return -1; -#endif - } // cant insert as active pts >= max_pts - dl.unlock(); - - // Insert tag and mapping to location - if (_enable_tags) { - // if tags are enabled and tag is already inserted. so we can't reuse that - // tag. 
- if (_tag_to_location.find(tag) != _tag_to_location.end()) { - release_location(location); - return -1; - } - - _tag_to_location[tag] = location; - _location_to_tag.set(location, tag); - } - tl.unlock(); - - _data_store->set_vector(location, point); // update datastore - - // Find and add appropriate graph edges - ScratchStoreManager> manager(_query_scratch); - auto scratch = manager.scratch_space(); - std::vector - pruned_list; // it is the set best candidates to connect to this point - if (_filtered_index) { - // when filtered the best_candidates will share the same label ( - // label_present > distance) - search_for_point_and_prune(location, _indexingQueueSize, pruned_list, - scratch, true, _filterIndexingQueueSize); - } else { - search_for_point_and_prune(location, _indexingQueueSize, pruned_list, - scratch); - } - assert(pruned_list.size() > 0); // should find atleast one neighbour (i.e - // frozen point acting as medoid) - - { - std::shared_lock tlock(_tag_lock, std::defer_lock); - if (_conc_consolidate) - tlock.lock(); - - LockGuard guard(_locks[location]); - _graph_store->clear_neighbours(location); - - std::vector neighbor_links; - for (auto link : pruned_list) { - if (_conc_consolidate) - if (!_location_to_tag.contains(link)) - continue; - neighbor_links.emplace_back(link); - } - _graph_store->set_neighbours(location, neighbor_links); - assert(_graph_store->get_neighbours(location).size() <= _indexingRange); - - if (_conc_consolidate) - tlock.unlock(); - } - - inter_insert(location, pruned_list, scratch); - - return 0; + + inter_insert(location, pruned_list, scratch); + + return 0; } -template -int Index::_lazy_delete(const TagType &tag) { - try { - return lazy_delete(std::any_cast(tag)); - } catch (const std::bad_any_cast &e) { - throw ANNException(std::string("Error: ") + e.what(), -1); - } +template int Index::_lazy_delete(const TagType &tag) +{ + try + { + return lazy_delete(std::any_cast(tag)); + } + catch (const std::bad_any_cast &e) + { + throw ANNException(std::string("Error: ") + e.what(), -1); + } } template -void Index::_lazy_delete(TagVector &tags, - TagVector &failed_tags) { - try { - this->lazy_delete(tags.get>(), - failed_tags.get>()); - } catch (const std::bad_any_cast &e) { - throw ANNException("Error: bad any cast while performing _lazy_delete() " + - std::string(e.what()), - -1); - } catch (const std::exception &e) { - throw ANNException("Error: " + std::string(e.what()), -1); - } +void Index::_lazy_delete(TagVector &tags, TagVector &failed_tags) +{ + try + { + this->lazy_delete(tags.get>(), failed_tags.get>()); + } + catch (const std::bad_any_cast &e) + { + throw ANNException("Error: bad any cast while performing _lazy_delete() " + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error: " + std::string(e.what()), -1); + } } -template -int Index::lazy_delete(const TagT &tag) { - std::shared_lock ul(_update_lock); - std::unique_lock tl(_tag_lock); - std::unique_lock dl(_delete_lock); - _data_compacted = false; - - if (_tag_to_location.find(tag) == _tag_to_location.end()) { - diskann::cerr << "Delete tag not found " << get_tag_string(tag) - << std::endl; - return -1; - } - assert(_tag_to_location[tag] < _max_points); +template int Index::lazy_delete(const TagT &tag) +{ + std::shared_lock ul(_update_lock); + std::unique_lock tl(_tag_lock); + std::unique_lock dl(_delete_lock); + _data_compacted = false; - const auto location = _tag_to_location[tag]; - _delete_set->insert(location); - _location_to_tag.erase(location); - 
_tag_to_location.erase(tag); - return 0; + if (_tag_to_location.find(tag) == _tag_to_location.end()) + { + diskann::cerr << "Delete tag not found " << get_tag_string(tag) << std::endl; + return -1; + } + assert(_tag_to_location[tag] < _max_points); + + const auto location = _tag_to_location[tag]; + _delete_set->insert(location); + _location_to_tag.erase(location); + _tag_to_location.erase(tag); + return 0; } template -void Index::lazy_delete(const std::vector &tags, - std::vector &failed_tags) { - if (failed_tags.size() > 0) { - throw ANNException("failed_tags should be passed as an empty list", -1, - __FUNCSIG__, __FILE__, __LINE__); - } - std::shared_lock ul(_update_lock); - std::unique_lock tl(_tag_lock); - std::unique_lock dl(_delete_lock); - _data_compacted = false; - - for (auto tag : tags) { - if (_tag_to_location.find(tag) == _tag_to_location.end()) { - failed_tags.push_back(tag); - } else { - const auto location = _tag_to_location[tag]; - _delete_set->insert(location); - _location_to_tag.erase(location); - _tag_to_location.erase(tag); - } - } +void Index::lazy_delete(const std::vector &tags, std::vector &failed_tags) +{ + if (failed_tags.size() > 0) + { + throw ANNException("failed_tags should be passed as an empty list", -1, __FUNCSIG__, __FILE__, __LINE__); + } + std::shared_lock ul(_update_lock); + std::unique_lock tl(_tag_lock); + std::unique_lock dl(_delete_lock); + _data_compacted = false; + + for (auto tag : tags) + { + if (_tag_to_location.find(tag) == _tag_to_location.end()) + { + failed_tags.push_back(tag); + } + else + { + const auto location = _tag_to_location[tag]; + _delete_set->insert(location); + _location_to_tag.erase(location); + _tag_to_location.erase(tag); + } + } } -template -bool Index::is_index_saved() { - return _is_saved; +template bool Index::is_index_saved() +{ + return _is_saved; } template -void Index::_get_active_tags(TagRobinSet &active_tags) { - try { - this->get_active_tags(active_tags.get>()); - } catch (const std::bad_any_cast &e) { - throw ANNException( - "Error: bad_any cast while performing _get_active_tags() " + - std::string(e.what()), - -1); - } catch (const std::exception &e) { - throw ANNException("Error :" + std::string(e.what()), -1); - } +void Index::_get_active_tags(TagRobinSet &active_tags) +{ + try + { + this->get_active_tags(active_tags.get>()); + } + catch (const std::bad_any_cast &e) + { + throw ANNException("Error: bad_any cast while performing _get_active_tags() " + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error :" + std::string(e.what()), -1); + } } template -void Index::get_active_tags( - tsl::robin_set &active_tags) { - active_tags.clear(); - std::shared_lock tl(_tag_lock); - for (auto iter : _tag_to_location) { - active_tags.insert(iter.first); - } +void Index::get_active_tags(tsl::robin_set &active_tags) +{ + active_tags.clear(); + std::shared_lock tl(_tag_lock); + for (auto iter : _tag_to_location) + { + active_tags.insert(iter.first); + } } -template -void Index::print_status() { - std::shared_lock ul(_update_lock); - std::shared_lock cl(_consolidate_lock); - std::shared_lock tl(_tag_lock); - std::shared_lock dl(_delete_lock); - - diskann::cout << "------------------- Index object: " << (uint64_t)this - << " -------------------" << std::endl; - diskann::cout << "Number of points: " << _nd << std::endl; - diskann::cout << "Graph size: " << _graph_store->get_total_points() - << std::endl; - diskann::cout << "Location to tag size: " << _location_to_tag.size() - << 
std::endl; - diskann::cout << "Tag to location size: " << _tag_to_location.size() - << std::endl; - diskann::cout << "Number of empty slots: " << _empty_slots.size() - << std::endl; - diskann::cout << std::boolalpha << "Data compacted: " << this->_data_compacted - << std::endl; - diskann::cout << "---------------------------------------------------------" - "------------" - << std::endl; +template void Index::print_status() +{ + std::shared_lock ul(_update_lock); + std::shared_lock cl(_consolidate_lock); + std::shared_lock tl(_tag_lock); + std::shared_lock dl(_delete_lock); + + diskann::cout << "------------------- Index object: " << (uint64_t)this << " -------------------" << std::endl; + diskann::cout << "Number of points: " << _nd << std::endl; + diskann::cout << "Graph size: " << _graph_store->get_total_points() << std::endl; + diskann::cout << "Location to tag size: " << _location_to_tag.size() << std::endl; + diskann::cout << "Tag to location size: " << _tag_to_location.size() << std::endl; + diskann::cout << "Number of empty slots: " << _empty_slots.size() << std::endl; + diskann::cout << std::boolalpha << "Data compacted: " << this->_data_compacted << std::endl; + diskann::cout << "---------------------------------------------------------" + "------------" + << std::endl; } -template -void Index::count_nodes_at_bfs_levels() { - std::unique_lock ul(_update_lock); +template void Index::count_nodes_at_bfs_levels() +{ + std::unique_lock ul(_update_lock); - boost::dynamic_bitset<> visited(_max_points + _num_frozen_pts); + boost::dynamic_bitset<> visited(_max_points + _num_frozen_pts); - size_t MAX_BFS_LEVELS = 32; - auto bfs_sets = new tsl::robin_set[MAX_BFS_LEVELS]; + size_t MAX_BFS_LEVELS = 32; + auto bfs_sets = new tsl::robin_set[MAX_BFS_LEVELS]; - bfs_sets[0].insert(_start); - visited.set(_start); + bfs_sets[0].insert(_start); + visited.set(_start); - for (uint32_t i = (uint32_t)_max_points; i < _max_points + _num_frozen_pts; - ++i) { - if (i != _start) { - bfs_sets[0].insert(i); - visited.set(i); + for (uint32_t i = (uint32_t)_max_points; i < _max_points + _num_frozen_pts; ++i) + { + if (i != _start) + { + bfs_sets[0].insert(i); + visited.set(i); + } } - } - for (size_t l = 0; l < MAX_BFS_LEVELS - 1; ++l) { - diskann::cout << "Number of nodes at BFS level " << l << " is " - << bfs_sets[l].size() << std::endl; - if (bfs_sets[l].size() == 0) - break; - for (auto node : bfs_sets[l]) { - for (auto nghbr : _graph_store->get_neighbours((location_t)node)) { - if (!visited.test(nghbr)) { - visited.set(nghbr); - bfs_sets[l + 1].insert(nghbr); + for (size_t l = 0; l < MAX_BFS_LEVELS - 1; ++l) + { + diskann::cout << "Number of nodes at BFS level " << l << " is " << bfs_sets[l].size() << std::endl; + if (bfs_sets[l].size() == 0) + break; + for (auto node : bfs_sets[l]) + { + for (auto nghbr : _graph_store->get_neighbours((location_t)node)) + { + if (!visited.test(nghbr)) + { + visited.set(nghbr); + bfs_sets[l + 1].insert(nghbr); + } + } } - } } - } - delete[] bfs_sets; + delete[] bfs_sets; } // REFACTOR: This should be an OptimizedDataStore class -template -void Index::optimize_index_layout() { // use after build or load - if (_dynamic_index) { - throw diskann::ANNException( - "Optimize_index_layout not implemented for dyanmic indices", -1, - __FUNCSIG__, __FILE__, __LINE__); - } - - float *cur_vec = new float[_data_store->get_aligned_dim()]; - std::memset(cur_vec, 0, _data_store->get_aligned_dim() * sizeof(float)); - _data_len = (_data_store->get_aligned_dim() + 1) * sizeof(float); - 
_neighbor_len = - (_graph_store->get_max_observed_degree() + 1) * sizeof(uint32_t); - _node_size = _data_len + _neighbor_len; - _opt_graph = new char[_node_size * _nd]; - auto dist_fast = (DistanceFastL2 *)(_data_store->get_dist_fn()); - for (uint32_t i = 0; i < _nd; i++) { - char *cur_node_offset = _opt_graph + i * _node_size; - _data_store->get_vector(i, (T *)cur_vec); - float cur_norm = - dist_fast->norm((T *)cur_vec, (uint32_t)_data_store->get_aligned_dim()); - std::memcpy(cur_node_offset, &cur_norm, sizeof(float)); - std::memcpy(cur_node_offset + sizeof(float), cur_vec, - _data_len - sizeof(float)); - - cur_node_offset += _data_len; - uint32_t k = (uint32_t)_graph_store->get_neighbours(i).size(); - std::memcpy(cur_node_offset, &k, sizeof(uint32_t)); - std::memcpy(cur_node_offset + sizeof(uint32_t), - _graph_store->get_neighbours(i).data(), k * sizeof(uint32_t)); - // std::vector().swap(_graph_store->get_neighbours(i)); - _graph_store->clear_neighbours(i); - } - _graph_store->clear_graph(); - _graph_store->resize_graph(0); - delete[] cur_vec; +template void Index::optimize_index_layout() +{ // use after build or load + if (_dynamic_index) + { + throw diskann::ANNException("Optimize_index_layout not implemented for dyanmic indices", -1, __FUNCSIG__, + __FILE__, __LINE__); + } + + float *cur_vec = new float[_data_store->get_aligned_dim()]; + std::memset(cur_vec, 0, _data_store->get_aligned_dim() * sizeof(float)); + _data_len = (_data_store->get_aligned_dim() + 1) * sizeof(float); + _neighbor_len = (_graph_store->get_max_observed_degree() + 1) * sizeof(uint32_t); + _node_size = _data_len + _neighbor_len; + _opt_graph = new char[_node_size * _nd]; + auto dist_fast = (DistanceFastL2 *)(_data_store->get_dist_fn()); + for (uint32_t i = 0; i < _nd; i++) + { + char *cur_node_offset = _opt_graph + i * _node_size; + _data_store->get_vector(i, (T *)cur_vec); + float cur_norm = dist_fast->norm((T *)cur_vec, (uint32_t)_data_store->get_aligned_dim()); + std::memcpy(cur_node_offset, &cur_norm, sizeof(float)); + std::memcpy(cur_node_offset + sizeof(float), cur_vec, _data_len - sizeof(float)); + + cur_node_offset += _data_len; + uint32_t k = (uint32_t)_graph_store->get_neighbours(i).size(); + std::memcpy(cur_node_offset, &k, sizeof(uint32_t)); + std::memcpy(cur_node_offset + sizeof(uint32_t), _graph_store->get_neighbours(i).data(), k * sizeof(uint32_t)); + // std::vector().swap(_graph_store->get_neighbours(i)); + _graph_store->clear_neighbours(i); + } + _graph_store->clear_graph(); + _graph_store->resize_graph(0); + delete[] cur_vec; } template -void Index::_search_with_optimized_layout( - const DataType &query, size_t K, size_t L, uint32_t *indices) { - try { - return this->search_with_optimized_layout(std::any_cast(query), - K, L, indices); - } catch (const std::bad_any_cast &e) { - throw ANNException("Error: bad any cast while performing " - "_search_with_optimized_layout() " + - std::string(e.what()), - -1); - } catch (const std::exception &e) { - throw ANNException("Error: " + std::string(e.what()), -1); - } +void Index::_search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) +{ + try + { + return this->search_with_optimized_layout(std::any_cast(query), K, L, indices); + } + catch (const std::bad_any_cast &e) + { + throw ANNException("Error: bad any cast while performing " + "_search_with_optimized_layout() " + + std::string(e.what()), + -1); + } + catch (const std::exception &e) + { + throw ANNException("Error: " + std::string(e.what()), -1); + } } template 
-void Index::search_with_optimized_layout(const T *query, - size_t K, size_t L, - uint32_t *indices) { - DistanceFastL2 *dist_fast = - (DistanceFastL2 *)(_data_store->get_dist_fn()); - - NeighborPriorityQueue retset(L); - std::vector init_ids(L); - - boost::dynamic_bitset<> flags{_nd, 0}; - uint32_t tmp_l = 0; - uint32_t *neighbors = - (uint32_t *)(_opt_graph + _node_size * _start + _data_len); - uint32_t MaxM_ep = *neighbors; - neighbors++; - - for (; tmp_l < L && tmp_l < MaxM_ep; tmp_l++) { - init_ids[tmp_l] = neighbors[tmp_l]; - flags[init_ids[tmp_l]] = true; - } - - while (tmp_l < L) { - uint32_t id = rand() % _nd; - if (flags[id]) - continue; - flags[id] = true; - init_ids[tmp_l] = id; - tmp_l++; - } - - for (uint32_t i = 0; i < init_ids.size(); i++) { - uint32_t id = init_ids[i]; - if (id >= _nd) - continue; - _mm_prefetch(_opt_graph + _node_size * id, _MM_HINT_T0); - } - L = 0; - for (uint32_t i = 0; i < init_ids.size(); i++) { - uint32_t id = init_ids[i]; - if (id >= _nd) - continue; - T *x = (T *)(_opt_graph + _node_size * id); - float norm_x = *x; - x++; - float dist = dist_fast->compare(x, query, norm_x, - (uint32_t)_data_store->get_aligned_dim()); - retset.insert(Neighbor(id, dist)); - flags[id] = true; - L++; - } - - while (retset.has_unexpanded_node()) { - auto nbr = retset.closest_unexpanded(); - auto n = nbr.id; - _mm_prefetch(_opt_graph + _node_size * n + _data_len, _MM_HINT_T0); - neighbors = (uint32_t *)(_opt_graph + _node_size * n + _data_len); - uint32_t MaxM = *neighbors; +void Index::search_with_optimized_layout(const T *query, size_t K, size_t L, uint32_t *indices) +{ + DistanceFastL2 *dist_fast = (DistanceFastL2 *)(_data_store->get_dist_fn()); + + NeighborPriorityQueue retset(L); + std::vector init_ids(L); + + boost::dynamic_bitset<> flags{_nd, 0}; + uint32_t tmp_l = 0; + uint32_t *neighbors = (uint32_t *)(_opt_graph + _node_size * _start + _data_len); + uint32_t MaxM_ep = *neighbors; neighbors++; - for (uint32_t m = 0; m < MaxM; ++m) - _mm_prefetch(_opt_graph + _node_size * neighbors[m], _MM_HINT_T0); - for (uint32_t m = 0; m < MaxM; ++m) { - uint32_t id = neighbors[m]; - if (flags[id]) - continue; - flags[id] = 1; - T *data = (T *)(_opt_graph + _node_size * id); - float norm = *data; - data++; - float dist = dist_fast->compare(query, data, norm, - (uint32_t)_data_store->get_aligned_dim()); - Neighbor nn(id, dist); - retset.insert(nn); - } - } - - for (size_t i = 0; i < K; i++) { - indices[i] = retset[i].id; - } + + for (; tmp_l < L && tmp_l < MaxM_ep; tmp_l++) + { + init_ids[tmp_l] = neighbors[tmp_l]; + flags[init_ids[tmp_l]] = true; + } + + while (tmp_l < L) + { + uint32_t id = rand() % _nd; + if (flags[id]) + continue; + flags[id] = true; + init_ids[tmp_l] = id; + tmp_l++; + } + + for (uint32_t i = 0; i < init_ids.size(); i++) + { + uint32_t id = init_ids[i]; + if (id >= _nd) + continue; + _mm_prefetch(_opt_graph + _node_size * id, _MM_HINT_T0); + } + L = 0; + for (uint32_t i = 0; i < init_ids.size(); i++) + { + uint32_t id = init_ids[i]; + if (id >= _nd) + continue; + T *x = (T *)(_opt_graph + _node_size * id); + float norm_x = *x; + x++; + float dist = dist_fast->compare(x, query, norm_x, (uint32_t)_data_store->get_aligned_dim()); + retset.insert(Neighbor(id, dist)); + flags[id] = true; + L++; + } + + while (retset.has_unexpanded_node()) + { + auto nbr = retset.closest_unexpanded(); + auto n = nbr.id; + _mm_prefetch(_opt_graph + _node_size * n + _data_len, _MM_HINT_T0); + neighbors = (uint32_t *)(_opt_graph + _node_size * n + _data_len); + uint32_t MaxM = 
*neighbors; + neighbors++; + for (uint32_t m = 0; m < MaxM; ++m) + _mm_prefetch(_opt_graph + _node_size * neighbors[m], _MM_HINT_T0); + for (uint32_t m = 0; m < MaxM; ++m) + { + uint32_t id = neighbors[m]; + if (flags[id]) + continue; + flags[id] = 1; + T *data = (T *)(_opt_graph + _node_size * id); + float norm = *data; + data++; + float dist = dist_fast->compare(query, data, norm, (uint32_t)_data_store->get_aligned_dim()); + Neighbor nn(id, dist); + retset.insert(nn); + } + } + + for (size_t i = 0; i < K; i++) + { + indices[i] = retset[i].id; + } } /* Internals of the library */ -template -const float Index::INDEX_GROWTH_FACTOR = 1.5f; +template const float Index::INDEX_GROWTH_FACTOR = 1.5f; // EXPORTS template DISKANN_DLLEXPORT class Index; @@ -3203,252 +3360,132 @@ template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; -template DISKANN_DLLEXPORT std::pair -Index::search(const float *query, - const size_t K, - const uint32_t L, - uint64_t *indices, - float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search(const float *query, - const size_t K, - const uint32_t L, - uint32_t *indices, - float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search(const uint8_t *query, - const size_t K, - const uint32_t L, - uint64_t *indices, - float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search(const uint8_t *query, - const size_t K, - const uint32_t L, - uint32_t *indices, - float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search(const int8_t *query, - const size_t K, - const uint32_t L, - uint64_t *indices, - float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search(const int8_t *query, - const size_t K, - const uint32_t L, - uint32_t *indices, - float *distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); // TagT==uint32_t -template DISKANN_DLLEXPORT std::pair -Index::search(const float *query, - const size_t K, - const uint32_t L, - uint64_t *indices, - float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search(const float *query, - const size_t K, - const uint32_t L, - uint32_t *indices, - float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search(const uint8_t *query, - const size_t K, - const uint32_t L, - uint64_t *indices, - float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search(const uint8_t *query, - const size_t K, - const uint32_t L, - uint32_t *indices, - float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search(const int8_t *query, - const size_t K, - const uint32_t L, - uint64_t *indices, - float *distances); -template DISKANN_DLLEXPORT std::pair 
-Index::search(const int8_t *query, - const size_t K, - const uint32_t L, - uint32_t *indices, - float *distances); - -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const float *query, const uint32_t &filter_label, const size_t K, - const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const float *query, const uint32_t &filter_label, const size_t K, - const uint32_t L, uint32_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const uint8_t *query, const uint32_t &filter_label, const size_t K, - const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const uint8_t *query, const uint32_t &filter_label, const size_t K, - const uint32_t L, uint32_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const int8_t *query, const uint32_t &filter_label, const size_t K, - const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const int8_t *query, const uint32_t &filter_label, const size_t K, - const uint32_t L, uint32_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + float *distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + float *distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + float *distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + float *distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + float *distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + float *distances); // TagT==uint32_t -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const float *query, const uint32_t &filter_label, const size_t K, - const 
uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const float *query, const uint32_t &filter_label, const size_t K, - const uint32_t L, uint32_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const uint8_t *query, const uint32_t &filter_label, const size_t K, - const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const uint8_t *query, const uint32_t &filter_label, const size_t K, - const uint32_t L, uint32_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const int8_t *query, const uint32_t &filter_label, const size_t K, - const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const int8_t *query, const uint32_t &filter_label, const size_t K, - const uint32_t L, uint32_t *indices, float *distances); - -template DISKANN_DLLEXPORT std::pair -Index::search(const float *query, - const size_t K, - const uint32_t L, - uint64_t *indices, - float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search(const float *query, - const size_t K, - const uint32_t L, - uint32_t *indices, - float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search(const uint8_t *query, - const size_t K, - const uint32_t L, - uint64_t *indices, - float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search(const uint8_t *query, - const size_t K, - const uint32_t L, - uint32_t *indices, - float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search(const int8_t *query, - const size_t K, - const uint32_t L, - uint64_t *indices, - float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search(const int8_t *query, - const size_t K, - const uint32_t L, - uint32_t *indices, - float *distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + float *distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + float *distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + float *distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + float *distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + float *distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + float *distances); + +template DISKANN_DLLEXPORT std::pair Index::search( + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const uint8_t *query, const size_t K, 
const uint32_t L, uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); // TagT==uint32_t -template DISKANN_DLLEXPORT std::pair -Index::search(const float *query, - const size_t K, - const uint32_t L, - uint64_t *indices, - float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search(const float *query, - const size_t K, - const uint32_t L, - uint32_t *indices, - float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search(const uint8_t *query, - const size_t K, - const uint32_t L, - uint64_t *indices, - float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search(const uint8_t *query, - const size_t K, - const uint32_t L, - uint32_t *indices, - float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search(const int8_t *query, - const size_t K, - const uint32_t L, - uint64_t *indices, - float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search(const int8_t *query, - const size_t K, - const uint32_t L, - uint32_t *indices, - float *distances); - -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const float *query, const uint16_t &filter_label, const size_t K, - const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const float *query, const uint16_t &filter_label, const size_t K, - const uint32_t L, uint32_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const uint8_t *query, const uint16_t &filter_label, const size_t K, - const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const uint8_t *query, const uint16_t &filter_label, const size_t K, - const uint32_t L, uint32_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const int8_t *query, const uint16_t &filter_label, const size_t K, - const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const int8_t *query, const uint16_t &filter_label, const size_t K, - const uint32_t L, uint32_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + +template DISKANN_DLLEXPORT std::pair 
Index::search_with_filters< + uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + float *distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + float *distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + float *distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + float *distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + float *distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + float *distances); // TagT==uint32_t -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const float *query, const uint16_t &filter_label, const size_t K, - const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const float *query, const uint16_t &filter_label, const size_t K, - const uint32_t L, uint32_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const uint8_t *query, const uint16_t &filter_label, const size_t K, - const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const uint8_t *query, const uint16_t &filter_label, const size_t K, - const uint32_t L, uint32_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const int8_t *query, const uint16_t &filter_label, const size_t K, - const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair -Index::search_with_filters( - const int8_t *query, const uint16_t &filter_label, const size_t K, - const uint32_t L, uint32_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + float *distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + float *distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + float *distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + float *distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + float *distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t 
*indices, + float *distances); } // namespace diskann diff --git a/src/index_factory.cpp b/src/index_factory.cpp index b24de48b1..08b89da1d 100644 --- a/src/index_factory.cpp +++ b/src/index_factory.cpp @@ -1,196 +1,212 @@ #include "index_factory.h" #include "pq_l2_distance.h" -namespace diskann { +namespace diskann +{ -IndexFactory::IndexFactory(const IndexConfig &config) - : _config(std::make_unique(config)) { - check_config(); +IndexFactory::IndexFactory(const IndexConfig &config) : _config(std::make_unique(config)) +{ + check_config(); } -std::unique_ptr IndexFactory::create_instance() { - return create_instance(_config->data_type, _config->tag_type, - _config->label_type); +std::unique_ptr IndexFactory::create_instance() +{ + return create_instance(_config->data_type, _config->tag_type, _config->label_type); } -void IndexFactory::check_config() { - if (_config->dynamic_index && !_config->enable_tags) { - throw ANNException("ERROR: Dynamic Indexing must have tags enabled.", -1, - __FUNCSIG__, __FILE__, __LINE__); - } - - if (_config->pq_dist_build) { - if (_config->dynamic_index) - throw ANNException( - "ERROR: Dynamic Indexing not supported with PQ distance based " - "index construction", - -1, __FUNCSIG__, __FILE__, __LINE__); - if (_config->metric == diskann::Metric::INNER_PRODUCT) - throw ANNException("ERROR: Inner product metrics not yet supported " - "with PQ distance " - "base index", - -1, __FUNCSIG__, __FILE__, __LINE__); - } - - if (_config->data_type != "float" && _config->data_type != "uint8" && - _config->data_type != "int8") { - throw ANNException( - "ERROR: invalid data type : + " + _config->data_type + - " is not supported. please select from [float, int8, uint8]", - -1); - } - - if (_config->tag_type != "int32" && _config->tag_type != "uint32" && - _config->tag_type != "int64" && _config->tag_type != "uint64") { - throw ANNException("ERROR: invalid data type : + " + _config->tag_type + - " is not supported. please select from [int32, " - "uint32, int64, uint64]", - -1); - } +void IndexFactory::check_config() +{ + if (_config->dynamic_index && !_config->enable_tags) + { + throw ANNException("ERROR: Dynamic Indexing must have tags enabled.", -1, __FUNCSIG__, __FILE__, __LINE__); + } + + if (_config->pq_dist_build) + { + if (_config->dynamic_index) + throw ANNException("ERROR: Dynamic Indexing not supported with PQ distance based " + "index construction", + -1, __FUNCSIG__, __FILE__, __LINE__); + if (_config->metric == diskann::Metric::INNER_PRODUCT) + throw ANNException("ERROR: Inner product metrics not yet supported " + "with PQ distance " + "base index", + -1, __FUNCSIG__, __FILE__, __LINE__); + } + + if (_config->data_type != "float" && _config->data_type != "uint8" && _config->data_type != "int8") + { + throw ANNException("ERROR: invalid data type : + " + _config->data_type + + " is not supported. please select from [float, int8, uint8]", + -1); + } + + if (_config->tag_type != "int32" && _config->tag_type != "uint32" && _config->tag_type != "int64" && + _config->tag_type != "uint64") + { + throw ANNException("ERROR: invalid data type : + " + _config->tag_type + + " is not supported. 
please select from [int32, " + "uint32, int64, uint64]", + -1); + } } -template -Distance *IndexFactory::construct_inmem_distance_fn(Metric metric) { - if (metric == diskann::Metric::COSINE && std::is_same::value) { - return (Distance *)new AVXNormalizedCosineDistanceFloat(); - } else { - return (Distance *)get_distance_function(metric); - } +template Distance *IndexFactory::construct_inmem_distance_fn(Metric metric) +{ + if (metric == diskann::Metric::COSINE && std::is_same::value) + { + return (Distance *)new AVXNormalizedCosineDistanceFloat(); + } + else + { + return (Distance *)get_distance_function(metric); + } } template -std::shared_ptr> -IndexFactory::construct_datastore(DataStoreStrategy strategy, - size_t total_internal_points, - size_t dimension, Metric metric) { - std::unique_ptr> distance; - switch (strategy) { - case DataStoreStrategy::MEMORY: - distance.reset(construct_inmem_distance_fn(metric)); - return std::make_shared>( - (location_t)total_internal_points, dimension, std::move(distance)); - default: - break; - } - return nullptr; +std::shared_ptr> IndexFactory::construct_datastore(DataStoreStrategy strategy, + size_t total_internal_points, size_t dimension, + Metric metric) +{ + std::unique_ptr> distance; + switch (strategy) + { + case DataStoreStrategy::MEMORY: + distance.reset(construct_inmem_distance_fn(metric)); + return std::make_shared>((location_t)total_internal_points, dimension, + std::move(distance)); + default: + break; + } + return nullptr; } -std::unique_ptr -IndexFactory::construct_graphstore(const GraphStoreStrategy strategy, - const size_t size, - const size_t reserve_graph_degree) { - switch (strategy) { - case GraphStoreStrategy::MEMORY: - return std::make_unique(size, reserve_graph_degree); - default: - throw ANNException("Error : Current GraphStoreStratagy is not supported.", - -1); - } +std::unique_ptr IndexFactory::construct_graphstore(const GraphStoreStrategy strategy, + const size_t size, + const size_t reserve_graph_degree) +{ + switch (strategy) + { + case GraphStoreStrategy::MEMORY: + return std::make_unique(size, reserve_graph_degree); + default: + throw ANNException("Error : Current GraphStoreStratagy is not supported.", -1); + } } template -std::shared_ptr> IndexFactory::construct_pq_datastore( - DataStoreStrategy strategy, size_t num_points, size_t dimension, Metric m, - size_t num_pq_chunks, bool use_opq) { - std::unique_ptr> distance_fn; - std::unique_ptr> quantized_distance_fn; - - quantized_distance_fn = std::move( - std::make_unique>((uint32_t)num_pq_chunks, use_opq)); - switch (strategy) { - case DataStoreStrategy::MEMORY: - distance_fn.reset(construct_inmem_distance_fn(m)); - return std::make_shared>( - dimension, (location_t)(num_points), num_pq_chunks, - std::move(distance_fn), std::move(quantized_distance_fn)); - default: - // REFACTOR TODO: We do support diskPQ - so we may need to add a new class - // for SSDPQDataStore! 
- break; - } - return nullptr; +std::shared_ptr> IndexFactory::construct_pq_datastore(DataStoreStrategy strategy, size_t num_points, + size_t dimension, Metric m, size_t num_pq_chunks, + bool use_opq) +{ + std::unique_ptr> distance_fn; + std::unique_ptr> quantized_distance_fn; + + quantized_distance_fn = std::move(std::make_unique>((uint32_t)num_pq_chunks, use_opq)); + switch (strategy) + { + case DataStoreStrategy::MEMORY: + distance_fn.reset(construct_inmem_distance_fn(m)); + return std::make_shared>(dimension, (location_t)(num_points), num_pq_chunks, + std::move(distance_fn), std::move(quantized_distance_fn)); + default: + // REFACTOR TODO: We do support diskPQ - so we may need to add a new class + // for SSDPQDataStore! + break; + } + return nullptr; } template -std::unique_ptr IndexFactory::create_instance() { - size_t num_points = _config->max_points + _config->num_frozen_pts; - size_t dim = _config->dimension; - // auto graph_store = construct_graphstore(_config->graph_strategy, - // num_points); - auto data_store = construct_datastore( - _config->data_strategy, num_points, dim, _config->metric); - std::shared_ptr> pq_data_store = nullptr; - - if (_config->data_strategy == DataStoreStrategy::MEMORY && - _config->pq_dist_build) { - pq_data_store = construct_pq_datastore( - _config->data_strategy, num_points + _config->num_frozen_pts, dim, - _config->metric, _config->num_pq_chunks, _config->use_opq); - } else { - pq_data_store = data_store; - } - size_t max_reserve_degree = - (size_t)(defaults::GRAPH_SLACK_FACTOR * 1.05 * - (_config->index_write_params == nullptr - ? 0 - : _config->index_write_params->max_degree)); - std::unique_ptr graph_store = construct_graphstore( - _config->graph_strategy, num_points + _config->num_frozen_pts, - max_reserve_degree); - - // REFACTOR TODO: Must construct in-memory PQDatastore if strategy == ONDISK - // and must construct in-mem and on-disk PQDataStore if strategy == ONDISK and - // diskPQ is required. - return std::make_unique>( - *_config, data_store, std::move(graph_store), pq_data_store); +std::unique_ptr IndexFactory::create_instance() +{ + size_t num_points = _config->max_points + _config->num_frozen_pts; + size_t dim = _config->dimension; + // auto graph_store = construct_graphstore(_config->graph_strategy, + // num_points); + auto data_store = construct_datastore(_config->data_strategy, num_points, dim, _config->metric); + std::shared_ptr> pq_data_store = nullptr; + + if (_config->data_strategy == DataStoreStrategy::MEMORY && _config->pq_dist_build) + { + pq_data_store = + construct_pq_datastore(_config->data_strategy, num_points + _config->num_frozen_pts, dim, + _config->metric, _config->num_pq_chunks, _config->use_opq); + } + else + { + pq_data_store = data_store; + } + size_t max_reserve_degree = + (size_t)(defaults::GRAPH_SLACK_FACTOR * 1.05 * + (_config->index_write_params == nullptr ? 0 : _config->index_write_params->max_degree)); + std::unique_ptr graph_store = + construct_graphstore(_config->graph_strategy, num_points + _config->num_frozen_pts, max_reserve_degree); + + // REFACTOR TODO: Must construct in-memory PQDatastore if strategy == ONDISK + // and must construct in-mem and on-disk PQDataStore if strategy == ONDISK and + // diskPQ is required. 
+ return std::make_unique>(*_config, data_store, + std::move(graph_store), pq_data_store); } -std::unique_ptr -IndexFactory::create_instance(const std::string &data_type, - const std::string &tag_type, - const std::string &label_type) { - if (data_type == std::string("float")) { - return create_instance(tag_type, label_type); - } else if (data_type == std::string("uint8")) { - return create_instance(tag_type, label_type); - } else if (data_type == std::string("int8")) { - return create_instance(tag_type, label_type); - } else - throw ANNException( - "Error: unsupported data_type please choose from [float/int8/uint8]", - -1); +std::unique_ptr IndexFactory::create_instance(const std::string &data_type, const std::string &tag_type, + const std::string &label_type) +{ + if (data_type == std::string("float")) + { + return create_instance(tag_type, label_type); + } + else if (data_type == std::string("uint8")) + { + return create_instance(tag_type, label_type); + } + else if (data_type == std::string("int8")) + { + return create_instance(tag_type, label_type); + } + else + throw ANNException("Error: unsupported data_type please choose from [float/int8/uint8]", -1); } template -std::unique_ptr -IndexFactory::create_instance(const std::string &tag_type, - const std::string &label_type) { - if (tag_type == std::string("int32")) { - return create_instance(label_type); - } else if (tag_type == std::string("uint32")) { - return create_instance(label_type); - } else if (tag_type == std::string("int64")) { - return create_instance(label_type); - } else if (tag_type == std::string("uint64")) { - return create_instance(label_type); - } else - throw ANNException("Error: unsupported tag_type please choose from " - "[int32/uint32/int64/uint64]", - -1); +std::unique_ptr IndexFactory::create_instance(const std::string &tag_type, const std::string &label_type) +{ + if (tag_type == std::string("int32")) + { + return create_instance(label_type); + } + else if (tag_type == std::string("uint32")) + { + return create_instance(label_type); + } + else if (tag_type == std::string("int64")) + { + return create_instance(label_type); + } + else if (tag_type == std::string("uint64")) + { + return create_instance(label_type); + } + else + throw ANNException("Error: unsupported tag_type please choose from " + "[int32/uint32/int64/uint64]", + -1); } template -std::unique_ptr -IndexFactory::create_instance(const std::string &label_type) { - if (label_type == std::string("uint16") || - label_type == std::string("ushort")) { - return create_instance(); - } else if (label_type == std::string("uint32") || - label_type == std::string("uint")) { - return create_instance(); - } else - throw ANNException( - "Error: unsupported label_type please choose from [uint/ushort]", -1); +std::unique_ptr IndexFactory::create_instance(const std::string &label_type) +{ + if (label_type == std::string("uint16") || label_type == std::string("ushort")) + { + return create_instance(); + } + else if (label_type == std::string("uint32") || label_type == std::string("uint")) + { + return create_instance(); + } + else + throw ANNException("Error: unsupported label_type please choose from [uint/ushort]", -1); } // template DISKANN_DLLEXPORT std::shared_ptr> diff --git a/src/linux_aligned_file_reader.cpp b/src/linux_aligned_file_reader.cpp index 317b20d03..94e14dc08 100644 --- a/src/linux_aligned_file_reader.cpp +++ b/src/linux_aligned_file_reader.cpp @@ -10,192 +10,221 @@ #include #define MAX_EVENTS 1024 -namespace { +namespace +{ typedef struct io_event 
io_event_t; typedef struct iocb iocb_t; -void execute_io(io_context_t ctx, int fd, std::vector &read_reqs, - uint64_t n_retries = 0) { +void execute_io(io_context_t ctx, int fd, std::vector &read_reqs, uint64_t n_retries = 0) +{ #ifdef DEBUG - for (auto &req : read_reqs) { - assert(IS_ALIGNED(req.len, 512)); - // std::cout << "request:"<= req.len); - } + for (auto &req : read_reqs) + { + assert(IS_ALIGNED(req.len, 512)); + // std::cout << "request:"<= req.len); + } #endif - // break-up requests into chunks of size MAX_EVENTS each - uint64_t n_iters = ROUND_UP(read_reqs.size(), MAX_EVENTS) / MAX_EVENTS; - for (uint64_t iter = 0; iter < n_iters; iter++) { - uint64_t n_ops = std::min((uint64_t)read_reqs.size() - (iter * MAX_EVENTS), - (uint64_t)MAX_EVENTS); - std::vector cbs(n_ops, nullptr); - std::vector evts(n_ops); - std::vector cb(n_ops); - for (uint64_t j = 0; j < n_ops; j++) { - io_prep_pread(cb.data() + j, fd, read_reqs[j + iter * MAX_EVENTS].buf, - read_reqs[j + iter * MAX_EVENTS].len, - read_reqs[j + iter * MAX_EVENTS].offset); - } + // break-up requests into chunks of size MAX_EVENTS each + uint64_t n_iters = ROUND_UP(read_reqs.size(), MAX_EVENTS) / MAX_EVENTS; + for (uint64_t iter = 0; iter < n_iters; iter++) + { + uint64_t n_ops = std::min((uint64_t)read_reqs.size() - (iter * MAX_EVENTS), (uint64_t)MAX_EVENTS); + std::vector cbs(n_ops, nullptr); + std::vector evts(n_ops); + std::vector cb(n_ops); + for (uint64_t j = 0; j < n_ops; j++) + { + io_prep_pread(cb.data() + j, fd, read_reqs[j + iter * MAX_EVENTS].buf, read_reqs[j + iter * MAX_EVENTS].len, + read_reqs[j + iter * MAX_EVENTS].offset); + } - // initialize `cbs` using `cb` array - // + // initialize `cbs` using `cb` array + // - for (uint64_t i = 0; i < n_ops; i++) { - cbs[i] = cb.data() + i; - } + for (uint64_t i = 0; i < n_ops; i++) + { + cbs[i] = cb.data() + i; + } - uint64_t n_tries = 0; - while (n_tries <= n_retries) { - // issue reads - int64_t ret = io_submit(ctx, (int64_t)n_ops, cbs.data()); - // if requests didn't get accepted - if (ret != (int64_t)n_ops) { - std::cerr << "io_submit() failed; returned " << ret - << ", expected=" << n_ops << ", ernno=" << errno << "=" - << ::strerror(-ret) << ", try #" << n_tries + 1; - std::cout << "ctx: " << ctx << "\n"; - exit(-1); - } else { - // wait on io_getevents - ret = io_getevents(ctx, (int64_t)n_ops, (int64_t)n_ops, evts.data(), - nullptr); - // if requests didn't complete - if (ret != (int64_t)n_ops) { - std::cerr << "io_getevents() failed; returned " << ret - << ", expected=" << n_ops << ", ernno=" << errno << "=" - << ::strerror(-ret) << ", try #" << n_tries + 1; - exit(-1); - } else { - break; + uint64_t n_tries = 0; + while (n_tries <= n_retries) + { + // issue reads + int64_t ret = io_submit(ctx, (int64_t)n_ops, cbs.data()); + // if requests didn't get accepted + if (ret != (int64_t)n_ops) + { + std::cerr << "io_submit() failed; returned " << ret << ", expected=" << n_ops << ", ernno=" << errno + << "=" << ::strerror(-ret) << ", try #" << n_tries + 1; + std::cout << "ctx: " << ctx << "\n"; + exit(-1); + } + else + { + // wait on io_getevents + ret = io_getevents(ctx, (int64_t)n_ops, (int64_t)n_ops, evts.data(), nullptr); + // if requests didn't complete + if (ret != (int64_t)n_ops) + { + std::cerr << "io_getevents() failed; returned " << ret << ", expected=" << n_ops + << ", ernno=" << errno << "=" << ::strerror(-ret) << ", try #" << n_tries + 1; + exit(-1); + } + else + { + break; + } + } } - } - } - // disabled since req.buf could be an offset into another buf - /* 
- for (auto &req : read_reqs) { - // corruption check - assert(malloc_usable_size(req.buf) >= req.len); + // disabled since req.buf could be an offset into another buf + /* + for (auto &req : read_reqs) { + // corruption check + assert(malloc_usable_size(req.buf) >= req.len); + } + */ } - */ - } } } // namespace -LinuxAlignedFileReader::LinuxAlignedFileReader() { this->file_desc = -1; } - -LinuxAlignedFileReader::~LinuxAlignedFileReader() { - int64_t ret; - // check to make sure file_desc is closed - ret = ::fcntl(this->file_desc, F_GETFD); - if (ret == -1) { - if (errno != EBADF) { - std::cerr << "close() not called" << std::endl; - // close file desc - ret = ::close(this->file_desc); - // error checks - if (ret == -1) { - std::cerr << "close() failed; returned " << ret << ", errno=" << errno - << ":" << ::strerror(errno) << std::endl; - } - } - } +LinuxAlignedFileReader::LinuxAlignedFileReader() +{ + this->file_desc = -1; } -io_context_t &LinuxAlignedFileReader::get_ctx() { - std::unique_lock lk(ctx_mut); - // perform checks only in DEBUG mode - if (ctx_map.find(std::this_thread::get_id()) == ctx_map.end()) { - std::cerr << "bad thread access; returning -1 as io_context_t" << std::endl; - return this->bad_ctx; - } else { - return ctx_map[std::this_thread::get_id()]; - } +LinuxAlignedFileReader::~LinuxAlignedFileReader() +{ + int64_t ret; + // check to make sure file_desc is closed + ret = ::fcntl(this->file_desc, F_GETFD); + if (ret == -1) + { + if (errno != EBADF) + { + std::cerr << "close() not called" << std::endl; + // close file desc + ret = ::close(this->file_desc); + // error checks + if (ret == -1) + { + std::cerr << "close() failed; returned " << ret << ", errno=" << errno << ":" << ::strerror(errno) + << std::endl; + } + } + } } -void LinuxAlignedFileReader::register_thread() { - auto my_id = std::this_thread::get_id(); - std::unique_lock lk(ctx_mut); - if (ctx_map.find(my_id) != ctx_map.end()) { - std::cerr << "multiple calls to register_thread from the same thread" - << std::endl; - return; - } - io_context_t ctx = 0; - int ret = io_setup(MAX_EVENTS, &ctx); - if (ret != 0) { - lk.unlock(); - if (ret == -EAGAIN) { - std::cerr << "io_setup() failed with EAGAIN: Consider increasing " - "/proc/sys/fs/aio-max-nr" - << std::endl; - } else { - std::cerr << "io_setup() failed; returned " << ret << ": " - << ::strerror(-ret) << std::endl; +io_context_t &LinuxAlignedFileReader::get_ctx() +{ + std::unique_lock lk(ctx_mut); + // perform checks only in DEBUG mode + if (ctx_map.find(std::this_thread::get_id()) == ctx_map.end()) + { + std::cerr << "bad thread access; returning -1 as io_context_t" << std::endl; + return this->bad_ctx; + } + else + { + return ctx_map[std::this_thread::get_id()]; } - } else { - diskann::cout << "allocating ctx: " << ctx << " to thread-id:" << my_id - << std::endl; - ctx_map[my_id] = ctx; - } - lk.unlock(); } -void LinuxAlignedFileReader::deregister_thread() { - auto my_id = std::this_thread::get_id(); - std::unique_lock lk(ctx_mut); - assert(ctx_map.find(my_id) != ctx_map.end()); - - lk.unlock(); - io_context_t ctx = this->get_ctx(); - io_destroy(ctx); - // assert(ret == 0); - lk.lock(); - ctx_map.erase(my_id); - std::cerr << "returned ctx from thread-id:" << my_id << std::endl; - lk.unlock(); +void LinuxAlignedFileReader::register_thread() +{ + auto my_id = std::this_thread::get_id(); + std::unique_lock lk(ctx_mut); + if (ctx_map.find(my_id) != ctx_map.end()) + { + std::cerr << "multiple calls to register_thread from the same thread" << std::endl; + return; + 
} + io_context_t ctx = 0; + int ret = io_setup(MAX_EVENTS, &ctx); + if (ret != 0) + { + lk.unlock(); + if (ret == -EAGAIN) + { + std::cerr << "io_setup() failed with EAGAIN: Consider increasing " + "/proc/sys/fs/aio-max-nr" + << std::endl; + } + else + { + std::cerr << "io_setup() failed; returned " << ret << ": " << ::strerror(-ret) << std::endl; + } + } + else + { + diskann::cout << "allocating ctx: " << ctx << " to thread-id:" << my_id << std::endl; + ctx_map[my_id] = ctx; + } + lk.unlock(); } -void LinuxAlignedFileReader::deregister_all_threads() { - std::unique_lock lk(ctx_mut); - for (auto x = ctx_map.begin(); x != ctx_map.end(); x++) { - io_context_t ctx = x.value(); +void LinuxAlignedFileReader::deregister_thread() +{ + auto my_id = std::this_thread::get_id(); + std::unique_lock lk(ctx_mut); + assert(ctx_map.find(my_id) != ctx_map.end()); + + lk.unlock(); + io_context_t ctx = this->get_ctx(); io_destroy(ctx); // assert(ret == 0); - // lk.lock(); - // ctx_map.erase(my_id); - // std::cerr << "returned ctx from thread-id:" << my_id << std::endl; - } - ctx_map.clear(); - // lk.unlock(); + lk.lock(); + ctx_map.erase(my_id); + std::cerr << "returned ctx from thread-id:" << my_id << std::endl; + lk.unlock(); +} + +void LinuxAlignedFileReader::deregister_all_threads() +{ + std::unique_lock lk(ctx_mut); + for (auto x = ctx_map.begin(); x != ctx_map.end(); x++) + { + io_context_t ctx = x.value(); + io_destroy(ctx); + // assert(ret == 0); + // lk.lock(); + // ctx_map.erase(my_id); + // std::cerr << "returned ctx from thread-id:" << my_id << std::endl; + } + ctx_map.clear(); + // lk.unlock(); } -void LinuxAlignedFileReader::open(const std::string &fname) { - int flags = O_DIRECT | O_RDONLY | O_LARGEFILE; - this->file_desc = ::open(fname.c_str(), flags); - // error checks - assert(this->file_desc != -1); - std::cerr << "Opened file : " << fname << std::endl; +void LinuxAlignedFileReader::open(const std::string &fname) +{ + int flags = O_DIRECT | O_RDONLY | O_LARGEFILE; + this->file_desc = ::open(fname.c_str(), flags); + // error checks + assert(this->file_desc != -1); + std::cerr << "Opened file : " << fname << std::endl; } -void LinuxAlignedFileReader::close() { - // int64_t ret; +void LinuxAlignedFileReader::close() +{ + // int64_t ret; - // check to make sure file_desc is closed - ::fcntl(this->file_desc, F_GETFD); - // assert(ret != -1); + // check to make sure file_desc is closed + ::fcntl(this->file_desc, F_GETFD); + // assert(ret != -1); - ::close(this->file_desc); - // assert(ret != -1); + ::close(this->file_desc); + // assert(ret != -1); } -void LinuxAlignedFileReader::read(std::vector &read_reqs, - io_context_t &ctx, bool async) { - if (async == true) { - diskann::cout << "Async currently not supported in linux." << std::endl; - } - assert(this->file_desc != -1); - execute_io(ctx, this->file_desc, read_reqs); +void LinuxAlignedFileReader::read(std::vector &read_reqs, io_context_t &ctx, bool async) +{ + if (async == true) + { + diskann::cout << "Async currently not supported in linux." 
<< std::endl; + } + assert(this->file_desc != -1); + execute_io(ctx, this->file_desc, read_reqs); } diff --git a/src/logger.cpp b/src/logger.cpp index 84fb5b6e2..052f54877 100644 --- a/src/logger.cpp +++ b/src/logger.cpp @@ -7,7 +7,8 @@ #include "logger_impl.h" #include "windows_customizations.h" -namespace diskann { +namespace diskann +{ #ifdef ENABLE_CUSTOM_LOGGER DISKANN_DLLEXPORT ANNStreamBuf coutBuff(stdout); @@ -17,67 +18,76 @@ DISKANN_DLLEXPORT std::basic_ostream cout(&coutBuff); DISKANN_DLLEXPORT std::basic_ostream cerr(&cerrBuff); std::function g_logger; -void SetCustomLogger(std::function logger) { - g_logger = logger; - diskann::cout << "Set Custom Logger" << std::endl; +void SetCustomLogger(std::function logger) +{ + g_logger = logger; + diskann::cout << "Set Custom Logger" << std::endl; } -ANNStreamBuf::ANNStreamBuf(FILE *fp) { - if (fp == nullptr) { - throw diskann::ANNException( - "File pointer passed to ANNStreamBuf() cannot be null", -1); - } - if (fp != stdout && fp != stderr) { - throw diskann::ANNException( - "The custom logger only supports stdout and stderr.", -1); - } - _fp = fp; - _logLevel = (_fp == stdout) ? LogLevel::LL_Info : LogLevel::LL_Error; - _buf = new char[BUFFER_SIZE + 1]; // See comment in the header +ANNStreamBuf::ANNStreamBuf(FILE *fp) +{ + if (fp == nullptr) + { + throw diskann::ANNException("File pointer passed to ANNStreamBuf() cannot be null", -1); + } + if (fp != stdout && fp != stderr) + { + throw diskann::ANNException("The custom logger only supports stdout and stderr.", -1); + } + _fp = fp; + _logLevel = (_fp == stdout) ? LogLevel::LL_Info : LogLevel::LL_Error; + _buf = new char[BUFFER_SIZE + 1]; // See comment in the header - std::memset(_buf, 0, (BUFFER_SIZE) * sizeof(char)); - setp(_buf, _buf + BUFFER_SIZE - 1); + std::memset(_buf, 0, (BUFFER_SIZE) * sizeof(char)); + setp(_buf, _buf + BUFFER_SIZE - 1); } -ANNStreamBuf::~ANNStreamBuf() { - sync(); - _fp = nullptr; // we'll not close because we can't. - delete[] _buf; +ANNStreamBuf::~ANNStreamBuf() +{ + sync(); + _fp = nullptr; // we'll not close because we can't. + delete[] _buf; } -int ANNStreamBuf::overflow(int c) { - std::lock_guard lock(_mutex); - if (c != EOF) { - *pptr() = (char)c; - pbump(1); - } - flush(); - return c; +int ANNStreamBuf::overflow(int c) +{ + std::lock_guard lock(_mutex); + if (c != EOF) + { + *pptr() = (char)c; + pbump(1); + } + flush(); + return c; } -int ANNStreamBuf::sync() { - std::lock_guard lock(_mutex); - flush(); - return 0; +int ANNStreamBuf::sync() +{ + std::lock_guard lock(_mutex); + flush(); + return 0; } -int ANNStreamBuf::underflow() { - throw diskann::ANNException( - "Attempt to read on streambuf meant only for writing.", -1); +int ANNStreamBuf::underflow() +{ + throw diskann::ANNException("Attempt to read on streambuf meant only for writing.", -1); } -int ANNStreamBuf::flush() { - const int num = (int)(pptr() - pbase()); - logImpl(pbase(), num); - pbump(-num); - return num; +int ANNStreamBuf::flush() +{ + const int num = (int)(pptr() - pbase()); + logImpl(pbase(), num); + pbump(-num); + return num; } -void ANNStreamBuf::logImpl(char *str, int num) { - str[num] = '\0'; // Safe. See the c'tor. - // Invoke the OLS custom logging function. - if (g_logger) { - g_logger(_logLevel, str); - } +void ANNStreamBuf::logImpl(char *str, int num) +{ + str[num] = '\0'; // Safe. See the c'tor. + // Invoke the OLS custom logging function. 
+ if (g_logger) + { + g_logger(_logLevel, str); + } } #else using std::cerr; diff --git a/src/math_utils.cpp b/src/math_utils.cpp index 4d161a581..5ce66fb2e 100644 --- a/src/math_utils.cpp +++ b/src/math_utils.cpp @@ -8,42 +8,47 @@ #include #include -namespace math_utils { - -float calc_distance(float *vec_1, float *vec_2, size_t dim) { - float dist = 0; - for (size_t j = 0; j < dim; j++) { - dist += (vec_1[j] - vec_2[j]) * (vec_1[j] - vec_2[j]); - } - return dist; +namespace math_utils +{ + +float calc_distance(float *vec_1, float *vec_2, size_t dim) +{ + float dist = 0; + for (size_t j = 0; j < dim; j++) + { + dist += (vec_1[j] - vec_2[j]) * (vec_1[j] - vec_2[j]); + } + return dist; } // compute l2-squared norms of data stored in row major num_points * dim, // needs // to be pre-allocated -void compute_vecs_l2sq(float *vecs_l2sq, float *data, const size_t num_points, - const size_t dim) { +void compute_vecs_l2sq(float *vecs_l2sq, float *data, const size_t num_points, const size_t dim) +{ #pragma omp parallel for schedule(static, 8192) - for (int64_t n_iter = 0; n_iter < (int64_t)num_points; n_iter++) { - vecs_l2sq[n_iter] = cblas_snrm2((MKL_INT)dim, (data + (n_iter * dim)), 1); - vecs_l2sq[n_iter] *= vecs_l2sq[n_iter]; - } + for (int64_t n_iter = 0; n_iter < (int64_t)num_points; n_iter++) + { + vecs_l2sq[n_iter] = cblas_snrm2((MKL_INT)dim, (data + (n_iter * dim)), 1); + vecs_l2sq[n_iter] *= vecs_l2sq[n_iter]; + } } -void rotate_data_randomly(float *data, size_t num_points, size_t dim, - float *rot_mat, float *&new_mat, bool transpose_rot) { - CBLAS_TRANSPOSE transpose = CblasNoTrans; - if (transpose_rot) { - diskann::cout << "Transposing rotation matrix.." << std::flush; - transpose = CblasTrans; - } - diskann::cout << "done Rotating data with random matrix.." << std::flush; +void rotate_data_randomly(float *data, size_t num_points, size_t dim, float *rot_mat, float *&new_mat, + bool transpose_rot) +{ + CBLAS_TRANSPOSE transpose = CblasNoTrans; + if (transpose_rot) + { + diskann::cout << "Transposing rotation matrix.." << std::flush; + transpose = CblasTrans; + } + diskann::cout << "done Rotating data with random matrix.." << std::flush; - cblas_sgemm(CblasRowMajor, CblasNoTrans, transpose, (MKL_INT)num_points, - (MKL_INT)dim, (MKL_INT)dim, 1.0, data, (MKL_INT)dim, rot_mat, - (MKL_INT)dim, 0, new_mat, (MKL_INT)dim); + cblas_sgemm(CblasRowMajor, CblasNoTrans, transpose, (MKL_INT)num_points, (MKL_INT)dim, (MKL_INT)dim, 1.0, data, + (MKL_INT)dim, rot_mat, (MKL_INT)dim, 0, new_mat, (MKL_INT)dim); - diskann::cout << "done." << std::endl; + diskann::cout << "done." 
<< std::endl; } // calculate k closest centers to data of num_points * dim (row major) @@ -56,70 +61,77 @@ void rotate_data_randomly(float *data, size_t num_points, size_t dim, // Default value of k is 1 // Ideally used only by compute_closest_centers -void compute_closest_centers_in_block( - const float *const data, const size_t num_points, const size_t dim, - const float *const centers, const size_t num_centers, - const float *const docs_l2sq, const float *const centers_l2sq, - uint32_t *center_index, float *const dist_matrix, size_t k) { - if (k > num_centers) { - diskann::cout << "ERROR: k (" << k << ") > num_center(" << num_centers - << ")" << std::endl; - return; - } - - float *ones_a = new float[num_centers]; - float *ones_b = new float[num_points]; - - for (size_t i = 0; i < num_centers; i++) { - ones_a[i] = 1.0; - } - for (size_t i = 0; i < num_points; i++) { - ones_b[i] = 1.0; - } - - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, (MKL_INT)num_points, - (MKL_INT)num_centers, (MKL_INT)1, 1.0f, docs_l2sq, (MKL_INT)1, - ones_a, (MKL_INT)1, 0.0f, dist_matrix, (MKL_INT)num_centers); - - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, (MKL_INT)num_points, - (MKL_INT)num_centers, (MKL_INT)1, 1.0f, ones_b, (MKL_INT)1, - centers_l2sq, (MKL_INT)1, 1.0f, dist_matrix, - (MKL_INT)num_centers); - - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, (MKL_INT)num_points, - (MKL_INT)num_centers, (MKL_INT)dim, -2.0f, data, (MKL_INT)dim, - centers, (MKL_INT)dim, 1.0f, dist_matrix, (MKL_INT)num_centers); - - if (k == 1) { +void compute_closest_centers_in_block(const float *const data, const size_t num_points, const size_t dim, + const float *const centers, const size_t num_centers, + const float *const docs_l2sq, const float *const centers_l2sq, + uint32_t *center_index, float *const dist_matrix, size_t k) +{ + if (k > num_centers) + { + diskann::cout << "ERROR: k (" << k << ") > num_center(" << num_centers << ")" << std::endl; + return; + } + + float *ones_a = new float[num_centers]; + float *ones_b = new float[num_points]; + + for (size_t i = 0; i < num_centers; i++) + { + ones_a[i] = 1.0; + } + for (size_t i = 0; i < num_points; i++) + { + ones_b[i] = 1.0; + } + + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, (MKL_INT)num_points, (MKL_INT)num_centers, (MKL_INT)1, 1.0f, + docs_l2sq, (MKL_INT)1, ones_a, (MKL_INT)1, 0.0f, dist_matrix, (MKL_INT)num_centers); + + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, (MKL_INT)num_points, (MKL_INT)num_centers, (MKL_INT)1, 1.0f, + ones_b, (MKL_INT)1, centers_l2sq, (MKL_INT)1, 1.0f, dist_matrix, (MKL_INT)num_centers); + + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, (MKL_INT)num_points, (MKL_INT)num_centers, (MKL_INT)dim, -2.0f, + data, (MKL_INT)dim, centers, (MKL_INT)dim, 1.0f, dist_matrix, (MKL_INT)num_centers); + + if (k == 1) + { #pragma omp parallel for schedule(static, 8192) - for (int64_t i = 0; i < (int64_t)num_points; i++) { - float min = std::numeric_limits::max(); - float *current = dist_matrix + (i * num_centers); - for (size_t j = 0; j < num_centers; j++) { - if (current[j] < min) { - center_index[i] = (uint32_t)j; - min = current[j]; + for (int64_t i = 0; i < (int64_t)num_points; i++) + { + float min = std::numeric_limits::max(); + float *current = dist_matrix + (i * num_centers); + for (size_t j = 0; j < num_centers; j++) + { + if (current[j] < min) + { + center_index[i] = (uint32_t)j; + min = current[j]; + } + } } - } } - } else { + else + { #pragma omp parallel for schedule(static, 8192) - for (int64_t i = 0; i 
< (int64_t)num_points; i++) { - std::priority_queue top_k_queue; - float *current = dist_matrix + (i * num_centers); - for (size_t j = 0; j < num_centers; j++) { - PivotContainer this_piv(j, current[j]); - top_k_queue.push(this_piv); - } - for (size_t j = 0; j < k; j++) { - PivotContainer this_piv = top_k_queue.top(); - center_index[i * k + j] = (uint32_t)this_piv.piv_id; - top_k_queue.pop(); - } + for (int64_t i = 0; i < (int64_t)num_points; i++) + { + std::priority_queue top_k_queue; + float *current = dist_matrix + (i * num_centers); + for (size_t j = 0; j < num_centers; j++) + { + PivotContainer this_piv(j, current[j]); + top_k_queue.push(this_piv); + } + for (size_t j = 0; j < k; j++) + { + PivotContainer this_piv = top_k_queue.top(); + center_index[i * k + j] = (uint32_t)this_piv.piv_id; + top_k_queue.pop(); + } + } } - } - delete[] ones_a; - delete[] ones_b; + delete[] ones_a; + delete[] ones_b; } // Given data in num_points * new_dim row major @@ -132,95 +144,92 @@ void compute_closest_centers_in_block( // indices is an empty vector. Additionally, if pts_norms_squared is not null, // then it will assume that point norms are pre-computed and use those values -void compute_closest_centers(float *data, size_t num_points, size_t dim, - float *pivot_data, size_t num_centers, size_t k, - uint32_t *closest_centers_ivf, - std::vector *inverted_index, - float *pts_norms_squared) { - if (k > num_centers) { - diskann::cout << "ERROR: k (" << k << ") > num_center(" << num_centers - << ")" << std::endl; - return; - } - - bool is_norm_given_for_pts = (pts_norms_squared != NULL); - - float *pivs_norms_squared = new float[num_centers]; - if (!is_norm_given_for_pts) - pts_norms_squared = new float[num_points]; - - size_t PAR_BLOCK_SIZE = num_points; - size_t N_BLOCKS = (num_points % PAR_BLOCK_SIZE) == 0 - ? (num_points / PAR_BLOCK_SIZE) - : (num_points / PAR_BLOCK_SIZE) + 1; - - if (!is_norm_given_for_pts) - math_utils::compute_vecs_l2sq(pts_norms_squared, data, num_points, dim); - math_utils::compute_vecs_l2sq(pivs_norms_squared, pivot_data, num_centers, - dim); - uint32_t *closest_centers = new uint32_t[PAR_BLOCK_SIZE * k]; - float *distance_matrix = new float[num_centers * PAR_BLOCK_SIZE]; - - for (size_t cur_blk = 0; cur_blk < N_BLOCKS; cur_blk++) { - float *data_cur_blk = data + cur_blk * PAR_BLOCK_SIZE * dim; - size_t num_pts_blk = - std::min(PAR_BLOCK_SIZE, num_points - cur_blk * PAR_BLOCK_SIZE); - float *pts_norms_blk = pts_norms_squared + cur_blk * PAR_BLOCK_SIZE; - - math_utils::compute_closest_centers_in_block( - data_cur_blk, num_pts_blk, dim, pivot_data, num_centers, pts_norms_blk, - pivs_norms_squared, closest_centers, distance_matrix, k); +void compute_closest_centers(float *data, size_t num_points, size_t dim, float *pivot_data, size_t num_centers, + size_t k, uint32_t *closest_centers_ivf, std::vector *inverted_index, + float *pts_norms_squared) +{ + if (k > num_centers) + { + diskann::cout << "ERROR: k (" << k << ") > num_center(" << num_centers << ")" << std::endl; + return; + } + + bool is_norm_given_for_pts = (pts_norms_squared != NULL); + + float *pivs_norms_squared = new float[num_centers]; + if (!is_norm_given_for_pts) + pts_norms_squared = new float[num_points]; + + size_t PAR_BLOCK_SIZE = num_points; + size_t N_BLOCKS = + (num_points % PAR_BLOCK_SIZE) == 0 ? 
(num_points / PAR_BLOCK_SIZE) : (num_points / PAR_BLOCK_SIZE) + 1; + + if (!is_norm_given_for_pts) + math_utils::compute_vecs_l2sq(pts_norms_squared, data, num_points, dim); + math_utils::compute_vecs_l2sq(pivs_norms_squared, pivot_data, num_centers, dim); + uint32_t *closest_centers = new uint32_t[PAR_BLOCK_SIZE * k]; + float *distance_matrix = new float[num_centers * PAR_BLOCK_SIZE]; + + for (size_t cur_blk = 0; cur_blk < N_BLOCKS; cur_blk++) + { + float *data_cur_blk = data + cur_blk * PAR_BLOCK_SIZE * dim; + size_t num_pts_blk = std::min(PAR_BLOCK_SIZE, num_points - cur_blk * PAR_BLOCK_SIZE); + float *pts_norms_blk = pts_norms_squared + cur_blk * PAR_BLOCK_SIZE; + + math_utils::compute_closest_centers_in_block(data_cur_blk, num_pts_blk, dim, pivot_data, num_centers, + pts_norms_blk, pivs_norms_squared, closest_centers, + distance_matrix, k); #pragma omp parallel for schedule(static, 1) - for (int64_t j = cur_blk * PAR_BLOCK_SIZE; - j < std::min((int64_t)num_points, - (int64_t)((cur_blk + 1) * PAR_BLOCK_SIZE)); - j++) { - for (size_t l = 0; l < k; l++) { - size_t this_center_id = - closest_centers[(j - cur_blk * PAR_BLOCK_SIZE) * k + l]; - closest_centers_ivf[j * k + l] = (uint32_t)this_center_id; - if (inverted_index != NULL) { + for (int64_t j = cur_blk * PAR_BLOCK_SIZE; + j < std::min((int64_t)num_points, (int64_t)((cur_blk + 1) * PAR_BLOCK_SIZE)); j++) + { + for (size_t l = 0; l < k; l++) + { + size_t this_center_id = closest_centers[(j - cur_blk * PAR_BLOCK_SIZE) * k + l]; + closest_centers_ivf[j * k + l] = (uint32_t)this_center_id; + if (inverted_index != NULL) + { #pragma omp critical - inverted_index[this_center_id].push_back(j); + inverted_index[this_center_id].push_back(j); + } + } } - } } - } - delete[] closest_centers; - delete[] distance_matrix; - delete[] pivs_norms_squared; - if (!is_norm_given_for_pts) - delete[] pts_norms_squared; + delete[] closest_centers; + delete[] distance_matrix; + delete[] pivs_norms_squared; + if (!is_norm_given_for_pts) + delete[] pts_norms_squared; } // if to_subtract is 1, will subtract nearest center from each row. Else will // add. Output will be in data_load iself. // Nearest centers need to be provided in closst_centers. 
-void process_residuals(float *data_load, size_t num_points, size_t dim, - float *cur_pivot_data, size_t num_centers, - uint32_t *closest_centers, bool to_subtract) { - diskann::cout << "Processing residuals of " << num_points << " points in " - << dim << " dimensions using " << num_centers << " centers " - << std::endl; +void process_residuals(float *data_load, size_t num_points, size_t dim, float *cur_pivot_data, size_t num_centers, + uint32_t *closest_centers, bool to_subtract) +{ + diskann::cout << "Processing residuals of " << num_points << " points in " << dim << " dimensions using " + << num_centers << " centers " << std::endl; #pragma omp parallel for schedule(static, 8192) - for (int64_t n_iter = 0; n_iter < (int64_t)num_points; n_iter++) { - for (size_t d_iter = 0; d_iter < dim; d_iter++) { - if (to_subtract == 1) - data_load[n_iter * dim + d_iter] = - data_load[n_iter * dim + d_iter] - - cur_pivot_data[closest_centers[n_iter] * dim + d_iter]; - else - data_load[n_iter * dim + d_iter] = - data_load[n_iter * dim + d_iter] + - cur_pivot_data[closest_centers[n_iter] * dim + d_iter]; + for (int64_t n_iter = 0; n_iter < (int64_t)num_points; n_iter++) + { + for (size_t d_iter = 0; d_iter < dim; d_iter++) + { + if (to_subtract == 1) + data_load[n_iter * dim + d_iter] = + data_load[n_iter * dim + d_iter] - cur_pivot_data[closest_centers[n_iter] * dim + d_iter]; + else + data_load[n_iter * dim + d_iter] = + data_load[n_iter * dim + d_iter] + cur_pivot_data[closest_centers[n_iter] * dim + d_iter]; + } } - } } } // namespace math_utils -namespace kmeans { +namespace kmeans +{ // run Lloyds one iteration // Given data in row major num_points * dim, and centers in row major @@ -230,67 +239,67 @@ namespace kmeans { // closest_centers == NULL, will allocate memory and return. Similarly, if // closest_docs == NULL, will allocate memory and return. 
-float lloyds_iter(float *data, size_t num_points, size_t dim, float *centers, - size_t num_centers, float *docs_l2sq, - std::vector *closest_docs, - uint32_t *&closest_center) { - bool compute_residual = true; - // Timer timer; +float lloyds_iter(float *data, size_t num_points, size_t dim, float *centers, size_t num_centers, float *docs_l2sq, + std::vector *closest_docs, uint32_t *&closest_center) +{ + bool compute_residual = true; + // Timer timer; - if (closest_center == NULL) - closest_center = new uint32_t[num_points]; - if (closest_docs == NULL) - closest_docs = new std::vector[num_centers]; - else - for (size_t c = 0; c < num_centers; ++c) - closest_docs[c].clear(); + if (closest_center == NULL) + closest_center = new uint32_t[num_points]; + if (closest_docs == NULL) + closest_docs = new std::vector[num_centers]; + else + for (size_t c = 0; c < num_centers; ++c) + closest_docs[c].clear(); - math_utils::compute_closest_centers(data, num_points, dim, centers, - num_centers, 1, closest_center, - closest_docs, docs_l2sq); + math_utils::compute_closest_centers(data, num_points, dim, centers, num_centers, 1, closest_center, closest_docs, + docs_l2sq); - memset(centers, 0, sizeof(float) * (size_t)num_centers * (size_t)dim); + memset(centers, 0, sizeof(float) * (size_t)num_centers * (size_t)dim); #pragma omp parallel for schedule(static, 1) - for (int64_t c = 0; c < (int64_t)num_centers; ++c) { - float *center = centers + (size_t)c * (size_t)dim; - double *cluster_sum = new double[dim]; - for (size_t i = 0; i < dim; i++) - cluster_sum[i] = 0.0; - for (size_t i = 0; i < closest_docs[c].size(); i++) { - float *current = data + ((closest_docs[c][i]) * dim); - for (size_t j = 0; j < dim; j++) { - cluster_sum[j] += (double)current[j]; - } - } - if (closest_docs[c].size() > 0) { - for (size_t i = 0; i < dim; i++) - center[i] = (float)(cluster_sum[i] / ((double)closest_docs[c].size())); + for (int64_t c = 0; c < (int64_t)num_centers; ++c) + { + float *center = centers + (size_t)c * (size_t)dim; + double *cluster_sum = new double[dim]; + for (size_t i = 0; i < dim; i++) + cluster_sum[i] = 0.0; + for (size_t i = 0; i < closest_docs[c].size(); i++) + { + float *current = data + ((closest_docs[c][i]) * dim); + for (size_t j = 0; j < dim; j++) + { + cluster_sum[j] += (double)current[j]; + } + } + if (closest_docs[c].size() > 0) + { + for (size_t i = 0; i < dim; i++) + center[i] = (float)(cluster_sum[i] / ((double)closest_docs[c].size())); + } + delete[] cluster_sum; } - delete[] cluster_sum; - } - float residual = 0.0; - if (compute_residual) { - size_t BUF_PAD = 32; - size_t CHUNK_SIZE = 2 * 8192; - size_t nchunks = - num_points / CHUNK_SIZE + (num_points % CHUNK_SIZE == 0 ? 0 : 1); - std::vector residuals(nchunks * BUF_PAD, 0.0); + float residual = 0.0; + if (compute_residual) + { + size_t BUF_PAD = 32; + size_t CHUNK_SIZE = 2 * 8192; + size_t nchunks = num_points / CHUNK_SIZE + (num_points % CHUNK_SIZE == 0 ? 
0 : 1); + std::vector residuals(nchunks * BUF_PAD, 0.0); #pragma omp parallel for schedule(static, 32) - for (int64_t chunk = 0; chunk < (int64_t)nchunks; ++chunk) - for (size_t d = chunk * CHUNK_SIZE; - d < num_points && d < (chunk + 1) * CHUNK_SIZE; ++d) - residuals[chunk * BUF_PAD] += math_utils::calc_distance( - data + (d * dim), centers + (size_t)closest_center[d] * (size_t)dim, - dim); - - for (size_t chunk = 0; chunk < nchunks; ++chunk) - residual += residuals[chunk * BUF_PAD]; - } - - return residual; + for (int64_t chunk = 0; chunk < (int64_t)nchunks; ++chunk) + for (size_t d = chunk * CHUNK_SIZE; d < num_points && d < (chunk + 1) * CHUNK_SIZE; ++d) + residuals[chunk * BUF_PAD] += + math_utils::calc_distance(data + (d * dim), centers + (size_t)closest_center[d] * (size_t)dim, dim); + + for (size_t chunk = 0; chunk < nchunks; ++chunk) + residual += residuals[chunk * BUF_PAD]; + } + + return residual; } // Run Lloyds until max_reps or stopping criterion @@ -300,145 +309,150 @@ float lloyds_iter(float *data, size_t num_points, size_t dim, float *centers, // vector [num_centers], and closest_center = new size_t[num_points] // Final centers are output in centers as row major num_centers * dim // -float run_lloyds(float *data, size_t num_points, size_t dim, float *centers, - const size_t num_centers, const size_t max_reps, - std::vector *closest_docs, uint32_t *closest_center) { - float residual = std::numeric_limits::max(); - bool ret_closest_docs = true; - bool ret_closest_center = true; - if (closest_docs == NULL) { - closest_docs = new std::vector[num_centers]; - ret_closest_docs = false; - } - if (closest_center == NULL) { - closest_center = new uint32_t[num_points]; - ret_closest_center = false; - } - - float *docs_l2sq = new float[num_points]; - math_utils::compute_vecs_l2sq(docs_l2sq, data, num_points, dim); - - float old_residual; - // Timer timer; - for (size_t i = 0; i < max_reps; ++i) { - old_residual = residual; - - residual = lloyds_iter(data, num_points, dim, centers, num_centers, - docs_l2sq, closest_docs, closest_center); - - if (((i != 0) && ((old_residual - residual) / residual) < 0.00001) || - (residual < std::numeric_limits::epsilon())) { - diskann::cout << "Residuals unchanged: " << old_residual << " becomes " - << residual << ". Early termination." << std::endl; - break; +float run_lloyds(float *data, size_t num_points, size_t dim, float *centers, const size_t num_centers, + const size_t max_reps, std::vector *closest_docs, uint32_t *closest_center) +{ + float residual = std::numeric_limits::max(); + bool ret_closest_docs = true; + bool ret_closest_center = true; + if (closest_docs == NULL) + { + closest_docs = new std::vector[num_centers]; + ret_closest_docs = false; + } + if (closest_center == NULL) + { + closest_center = new uint32_t[num_points]; + ret_closest_center = false; + } + + float *docs_l2sq = new float[num_points]; + math_utils::compute_vecs_l2sq(docs_l2sq, data, num_points, dim); + + float old_residual; + // Timer timer; + for (size_t i = 0; i < max_reps; ++i) + { + old_residual = residual; + + residual = lloyds_iter(data, num_points, dim, centers, num_centers, docs_l2sq, closest_docs, closest_center); + + if (((i != 0) && ((old_residual - residual) / residual) < 0.00001) || + (residual < std::numeric_limits::epsilon())) + { + diskann::cout << "Residuals unchanged: " << old_residual << " becomes " << residual + << ". Early termination." 
<< std::endl; + break; + } } - } - delete[] docs_l2sq; - if (!ret_closest_docs) - delete[] closest_docs; - if (!ret_closest_center) - delete[] closest_center; - return residual; + delete[] docs_l2sq; + if (!ret_closest_docs) + delete[] closest_docs; + if (!ret_closest_center) + delete[] closest_center; + return residual; } // assumes memory allocated for pivot_data as new // float[num_centers*dim] // and select randomly num_centers points as pivots -void selecting_pivots(float *data, size_t num_points, size_t dim, - float *pivot_data, size_t num_centers) { - // pivot_data = new float[num_centers * dim]; - - std::vector picked; - std::random_device rd; - auto x = rd(); - std::mt19937 generator(x); - std::uniform_int_distribution distribution(0, num_points - 1); - - size_t tmp_pivot; - for (size_t j = 0; j < num_centers; j++) { - tmp_pivot = distribution(generator); - if (std::find(picked.begin(), picked.end(), tmp_pivot) != picked.end()) - continue; - picked.push_back(tmp_pivot); - std::memcpy(pivot_data + j * dim, data + tmp_pivot * dim, - dim * sizeof(float)); - } +void selecting_pivots(float *data, size_t num_points, size_t dim, float *pivot_data, size_t num_centers) +{ + // pivot_data = new float[num_centers * dim]; + + std::vector picked; + std::random_device rd; + auto x = rd(); + std::mt19937 generator(x); + std::uniform_int_distribution distribution(0, num_points - 1); + + size_t tmp_pivot; + for (size_t j = 0; j < num_centers; j++) + { + tmp_pivot = distribution(generator); + if (std::find(picked.begin(), picked.end(), tmp_pivot) != picked.end()) + continue; + picked.push_back(tmp_pivot); + std::memcpy(pivot_data + j * dim, data + tmp_pivot * dim, dim * sizeof(float)); + } } -void kmeanspp_selecting_pivots(float *data, size_t num_points, size_t dim, - float *pivot_data, size_t num_centers) { - if (num_points > 1 << 23) { - diskann::cout << "ERROR: n_pts " << num_points - << " currently not supported for k-means++, maximum is " - "8388608. Falling back to random pivot " - "selection." - << std::endl; - selecting_pivots(data, num_points, dim, pivot_data, num_centers); - return; - } - - std::vector picked; - std::random_device rd; - auto x = rd(); - std::mt19937 generator(x); - std::uniform_real_distribution<> distribution(0, 1); - std::uniform_int_distribution int_dist(0, num_points - 1); - size_t init_id = int_dist(generator); - size_t num_picked = 1; - - picked.push_back(init_id); - std::memcpy(pivot_data, data + init_id * dim, dim * sizeof(float)); - - float *dist = new float[num_points]; +void kmeanspp_selecting_pivots(float *data, size_t num_points, size_t dim, float *pivot_data, size_t num_centers) +{ + if (num_points > 1 << 23) + { + diskann::cout << "ERROR: n_pts " << num_points + << " currently not supported for k-means++, maximum is " + "8388608. Falling back to random pivot " + "selection." 
+ << std::endl; + selecting_pivots(data, num_points, dim, pivot_data, num_centers); + return; + } -#pragma omp parallel for schedule(static, 8192) - for (int64_t i = 0; i < (int64_t)num_points; i++) { - dist[i] = - math_utils::calc_distance(data + i * dim, data + init_id * dim, dim); - } + std::vector picked; + std::random_device rd; + auto x = rd(); + std::mt19937 generator(x); + std::uniform_real_distribution<> distribution(0, 1); + std::uniform_int_distribution int_dist(0, num_points - 1); + size_t init_id = int_dist(generator); + size_t num_picked = 1; - double dart_val; - size_t tmp_pivot; - bool sum_flag = false; + picked.push_back(init_id); + std::memcpy(pivot_data, data + init_id * dim, dim * sizeof(float)); - while (num_picked < num_centers) { - dart_val = distribution(generator); + float *dist = new float[num_points]; - double sum = 0; - for (size_t i = 0; i < num_points; i++) { - sum = sum + dist[i]; +#pragma omp parallel for schedule(static, 8192) + for (int64_t i = 0; i < (int64_t)num_points; i++) + { + dist[i] = math_utils::calc_distance(data + i * dim, data + init_id * dim, dim); } - if (sum == 0) - sum_flag = true; - dart_val *= sum; + double dart_val; + size_t tmp_pivot; + bool sum_flag = false; - double prefix_sum = 0; - for (size_t i = 0; i < (num_points); i++) { - tmp_pivot = i; - if (dart_val >= prefix_sum && dart_val < prefix_sum + dist[i]) { - break; - } + while (num_picked < num_centers) + { + dart_val = distribution(generator); - prefix_sum += dist[i]; - } + double sum = 0; + for (size_t i = 0; i < num_points; i++) + { + sum = sum + dist[i]; + } + if (sum == 0) + sum_flag = true; - if (std::find(picked.begin(), picked.end(), tmp_pivot) != picked.end() && - (sum_flag == false)) - continue; - picked.push_back(tmp_pivot); - std::memcpy(pivot_data + num_picked * dim, data + tmp_pivot * dim, - dim * sizeof(float)); + dart_val *= sum; + + double prefix_sum = 0; + for (size_t i = 0; i < (num_points); i++) + { + tmp_pivot = i; + if (dart_val >= prefix_sum && dart_val < prefix_sum + dist[i]) + { + break; + } + + prefix_sum += dist[i]; + } + + if (std::find(picked.begin(), picked.end(), tmp_pivot) != picked.end() && (sum_flag == false)) + continue; + picked.push_back(tmp_pivot); + std::memcpy(pivot_data + num_picked * dim, data + tmp_pivot * dim, dim * sizeof(float)); #pragma omp parallel for schedule(static, 8192) - for (int64_t i = 0; i < (int64_t)num_points; i++) { - dist[i] = - (std::min)(dist[i], math_utils::calc_distance( - data + i * dim, data + tmp_pivot * dim, dim)); + for (int64_t i = 0; i < (int64_t)num_points; i++) + { + dist[i] = (std::min)(dist[i], math_utils::calc_distance(data + i * dim, data + tmp_pivot * dim, dim)); + } + num_picked++; } - num_picked++; - } - delete[] dist; + delete[] dist; } } // namespace kmeans diff --git a/src/memory_mapper.cpp b/src/memory_mapper.cpp index 5f07f165e..819df7fec 100644 --- a/src/memory_mapper.cpp +++ b/src/memory_mapper.cpp @@ -8,86 +8,100 @@ using namespace diskann; -MemoryMapper::MemoryMapper(const std::string &filename) - : MemoryMapper(filename.c_str()) {} +MemoryMapper::MemoryMapper(const std::string &filename) : MemoryMapper(filename.c_str()) +{ +} -MemoryMapper::MemoryMapper(const char *filename) { +MemoryMapper::MemoryMapper(const char *filename) +{ #ifndef _WINDOWS - _fd = open(filename, O_RDONLY); - if (_fd <= 0) { - std::cerr << "Inner vertices file not found" << std::endl; - return; - } - struct stat sb; - if (fstat(_fd, &sb) != 0) { - std::cerr << "Inner vertices file not dound. 
" << std::endl; - return; - } - _fileSize = sb.st_size; - diskann::cout << "File Size: " << _fileSize << std::endl; - _buf = (char *)mmap(NULL, _fileSize, PROT_READ, MAP_PRIVATE, _fd, 0); + _fd = open(filename, O_RDONLY); + if (_fd <= 0) + { + std::cerr << "Inner vertices file not found" << std::endl; + return; + } + struct stat sb; + if (fstat(_fd, &sb) != 0) + { + std::cerr << "Inner vertices file not dound. " << std::endl; + return; + } + _fileSize = sb.st_size; + diskann::cout << "File Size: " << _fileSize << std::endl; + _buf = (char *)mmap(NULL, _fileSize, PROT_READ, MAP_PRIVATE, _fd, 0); #else - _bareFile = CreateFileA(filename, GENERIC_READ | GENERIC_EXECUTE, 0, NULL, - OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL); - if (_bareFile == nullptr) { - std::ostringstream message; - message << "CreateFileA(" << filename << ") failed with error " - << GetLastError() << std::endl; - std::cerr << message.str(); - throw std::exception(message.str().c_str()); - } + _bareFile = + CreateFileA(filename, GENERIC_READ | GENERIC_EXECUTE, 0, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL); + if (_bareFile == nullptr) + { + std::ostringstream message; + message << "CreateFileA(" << filename << ") failed with error " << GetLastError() << std::endl; + std::cerr << message.str(); + throw std::exception(message.str().c_str()); + } - _fd = CreateFileMapping(_bareFile, NULL, PAGE_EXECUTE_READ, 0, 0, NULL); - if (_fd == nullptr) { - std::ostringstream message; - message << "CreateFileMapping(" << filename << ") failed with error " - << GetLastError() << std::endl; - std::cerr << message.str() << std::endl; - throw std::exception(message.str().c_str()); - } + _fd = CreateFileMapping(_bareFile, NULL, PAGE_EXECUTE_READ, 0, 0, NULL); + if (_fd == nullptr) + { + std::ostringstream message; + message << "CreateFileMapping(" << filename << ") failed with error " << GetLastError() << std::endl; + std::cerr << message.str() << std::endl; + throw std::exception(message.str().c_str()); + } - _buf = (char *)MapViewOfFile(_fd, FILE_MAP_READ, 0, 0, 0); - if (_buf == nullptr) { - std::ostringstream message; - message << "MapViewOfFile(" << filename - << ") failed with error: " << GetLastError() << std::endl; - std::cerr << message.str() << std::endl; - throw std::exception(message.str().c_str()); - } + _buf = (char *)MapViewOfFile(_fd, FILE_MAP_READ, 0, 0, 0); + if (_buf == nullptr) + { + std::ostringstream message; + message << "MapViewOfFile(" << filename << ") failed with error: " << GetLastError() << std::endl; + std::cerr << message.str() << std::endl; + throw std::exception(message.str().c_str()); + } - LARGE_INTEGER fSize; - if (TRUE == GetFileSizeEx(_bareFile, &fSize)) { - _fileSize = fSize.QuadPart; // take the 64-bit value - diskann::cout << "File Size: " << _fileSize << std::endl; - } else { - std::cerr << "Failed to get size of file " << filename << std::endl; - } + LARGE_INTEGER fSize; + if (TRUE == GetFileSizeEx(_bareFile, &fSize)) + { + _fileSize = fSize.QuadPart; // take the 64-bit value + diskann::cout << "File Size: " << _fileSize << std::endl; + } + else + { + std::cerr << "Failed to get size of file " << filename << std::endl; + } #endif } -char *MemoryMapper::getBuf() { return _buf; } +char *MemoryMapper::getBuf() +{ + return _buf; +} -size_t MemoryMapper::getFileSize() { return _fileSize; } +size_t MemoryMapper::getFileSize() +{ + return _fileSize; +} -MemoryMapper::~MemoryMapper() { +MemoryMapper::~MemoryMapper() +{ #ifndef _WINDOWS - if (munmap(_buf, _fileSize) != 0) - std::cerr << "ERROR 
unmapping. CHECK!" << std::endl; - close(_fd); + if (munmap(_buf, _fileSize) != 0) + std::cerr << "ERROR unmapping. CHECK!" << std::endl; + close(_fd); #else - if (FALSE == UnmapViewOfFile(_buf)) { - std::cerr << "Unmap view of file failed. Error: " << GetLastError() - << std::endl; - } + if (FALSE == UnmapViewOfFile(_buf)) + { + std::cerr << "Unmap view of file failed. Error: " << GetLastError() << std::endl; + } - if (FALSE == CloseHandle(_fd)) { - std::cerr << "Failed to close memory mapped file. Error: " << GetLastError() - << std::endl; - } + if (FALSE == CloseHandle(_fd)) + { + std::cerr << "Failed to close memory mapped file. Error: " << GetLastError() << std::endl; + } - if (FALSE == CloseHandle(_bareFile)) { - std::cerr << "Failed to close file: " << _fileName - << " Error: " << GetLastError() << std::endl; - } + if (FALSE == CloseHandle(_bareFile)) + { + std::cerr << "Failed to close file: " << _fileName << " Error: " << GetLastError() << std::endl; + } #endif } diff --git a/src/natural_number_map.cpp b/src/natural_number_map.cpp index c08825b32..a996dcf75 100644 --- a/src/natural_number_map.cpp +++ b/src/natural_number_map.cpp @@ -7,98 +7,104 @@ #include "natural_number_map.h" #include "tag_uint128.h" -namespace diskann { +namespace diskann +{ static constexpr auto invalid_position = boost::dynamic_bitset<>::npos; template natural_number_map::natural_number_map() - : _size(0), _values_bitset(std::make_unique>()) {} + : _size(0), _values_bitset(std::make_unique>()) +{ +} -template -void natural_number_map::reserve(size_t count) { - _values_vector.reserve(count); - _values_bitset->reserve(count); +template void natural_number_map::reserve(size_t count) +{ + _values_vector.reserve(count); + _values_bitset->reserve(count); } -template -size_t natural_number_map::size() const { - return _size; +template size_t natural_number_map::size() const +{ + return _size; } -template -void natural_number_map::set(Key key, Value value) { - if (key >= _values_bitset->size()) { - _values_bitset->resize(static_cast(key) + 1); - _values_vector.resize(_values_bitset->size()); - } - - _values_vector[key] = value; - const bool was_present = _values_bitset->test_set(key, true); - - if (!was_present) { - ++_size; - } +template void natural_number_map::set(Key key, Value value) +{ + if (key >= _values_bitset->size()) + { + _values_bitset->resize(static_cast(key) + 1); + _values_vector.resize(_values_bitset->size()); + } + + _values_vector[key] = value; + const bool was_present = _values_bitset->test_set(key, true); + + if (!was_present) + { + ++_size; + } } -template -void natural_number_map::erase(Key key) { - if (key < _values_bitset->size()) { - const bool was_present = _values_bitset->test_set(key, false); +template void natural_number_map::erase(Key key) +{ + if (key < _values_bitset->size()) + { + const bool was_present = _values_bitset->test_set(key, false); - if (was_present) { - --_size; + if (was_present) + { + --_size; + } } - } } -template -bool natural_number_map::contains(Key key) const { - return key < _values_bitset->size() && _values_bitset->test(key); +template bool natural_number_map::contains(Key key) const +{ + return key < _values_bitset->size() && _values_bitset->test(key); } -template -bool natural_number_map::try_get(Key key, Value &value) const { - if (!contains(key)) { - return false; - } +template bool natural_number_map::try_get(Key key, Value &value) const +{ + if (!contains(key)) + { + return false; + } - value = _values_vector[key]; - return true; + value = 
_values_vector[key]; + return true; } template -typename natural_number_map::position -natural_number_map::find_first() const { - return position{_size > 0 ? _values_bitset->find_first() : invalid_position, - 0}; +typename natural_number_map::position natural_number_map::find_first() const +{ + return position{_size > 0 ? _values_bitset->find_first() : invalid_position, 0}; } template -typename natural_number_map::position -natural_number_map::find_next( - const position &after_position) const { - return position{after_position._keys_already_enumerated < _size - ? _values_bitset->find_next(after_position._key) - : invalid_position, - after_position._keys_already_enumerated + 1}; +typename natural_number_map::position natural_number_map::find_next( + const position &after_position) const +{ + return position{after_position._keys_already_enumerated < _size ? _values_bitset->find_next(after_position._key) + : invalid_position, + after_position._keys_already_enumerated + 1}; } -template -bool natural_number_map::position::is_valid() const { - return _key != invalid_position; +template bool natural_number_map::position::is_valid() const +{ + return _key != invalid_position; } -template -Value natural_number_map::get(const position &pos) const { - assert(pos.is_valid()); - return _values_vector[pos._key]; +template Value natural_number_map::get(const position &pos) const +{ + assert(pos.is_valid()); + return _values_vector[pos._key]; } -template -void natural_number_map::clear() { - _size = 0; - _values_vector.clear(); - _values_bitset->clear(); +template void natural_number_map::clear() +{ + _size = 0; + _values_vector.clear(); + _values_bitset->clear(); } // Instantiate used templates. diff --git a/src/natural_number_set.cpp b/src/natural_number_set.cpp index 97ee70a0b..b36cb5298 100644 --- a/src/natural_number_set.cpp +++ b/src/natural_number_set.cpp @@ -6,54 +6,63 @@ #include "ann_exception.h" #include "natural_number_set.h" -namespace diskann { +namespace diskann +{ template -natural_number_set::natural_number_set() - : _values_bitset(std::make_unique>()) {} +natural_number_set::natural_number_set() : _values_bitset(std::make_unique>()) +{ +} -template bool natural_number_set::is_empty() const { - return _values_vector.empty(); +template bool natural_number_set::is_empty() const +{ + return _values_vector.empty(); } -template void natural_number_set::reserve(size_t count) { - _values_vector.reserve(count); - _values_bitset->reserve(count); +template void natural_number_set::reserve(size_t count) +{ + _values_vector.reserve(count); + _values_bitset->reserve(count); } -template void natural_number_set::insert(T id) { - _values_vector.emplace_back(id); +template void natural_number_set::insert(T id) +{ + _values_vector.emplace_back(id); - if (id >= _values_bitset->size()) - _values_bitset->resize(static_cast(id) + 1); + if (id >= _values_bitset->size()) + _values_bitset->resize(static_cast(id) + 1); - _values_bitset->set(id, true); + _values_bitset->set(id, true); } -template T natural_number_set::pop_any() { - if (_values_vector.empty()) { - throw diskann::ANNException("No values available", -1, __FUNCSIG__, - __FILE__, __LINE__); - } +template T natural_number_set::pop_any() +{ + if (_values_vector.empty()) + { + throw diskann::ANNException("No values available", -1, __FUNCSIG__, __FILE__, __LINE__); + } - const T id = _values_vector.back(); - _values_vector.pop_back(); + const T id = _values_vector.back(); + _values_vector.pop_back(); - _values_bitset->set(id, false); + 
_values_bitset->set(id, false); - return id; + return id; } -template void natural_number_set::clear() { - _values_vector.clear(); - _values_bitset->clear(); +template void natural_number_set::clear() +{ + _values_vector.clear(); + _values_bitset->clear(); } -template size_t natural_number_set::size() const { - return _values_vector.size(); +template size_t natural_number_set::size() const +{ + return _values_vector.size(); } -template bool natural_number_set::is_in_set(T id) const { - return _values_bitset->test(id); +template bool natural_number_set::is_in_set(T id) const +{ + return _values_bitset->test(id); } // Instantiate used templates. diff --git a/src/partition.cpp b/src/partition.cpp index 021ced6e7..1428eb801 100644 --- a/src/partition.cpp +++ b/src/partition.cpp @@ -11,8 +11,7 @@ #include "tsl/robin_set.h" #include -#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && \ - defined(DISKANN_BUILD) +#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD) #include "gperftools/malloc_extension.h" #endif @@ -32,59 +31,56 @@ // #define SAVE_INFLATED_PQ true template -void gen_random_slice(const std::string base_file, - const std::string output_prefix, double sampling_rate) { - size_t read_blk_size = 64 * 1024 * 1024; - cached_ifstream base_reader(base_file.c_str(), read_blk_size); - std::ofstream sample_writer(std::string(output_prefix + "_data.bin").c_str(), - std::ios::binary); - std::ofstream sample_id_writer( - std::string(output_prefix + "_ids.bin").c_str(), std::ios::binary); - - std::random_device - rd; // Will be used to obtain a seed for the random number engine - auto x = rd(); - std::mt19937 generator( - x); // Standard mersenne_twister_engine seeded with rd() - std::uniform_real_distribution distribution(0, 1); - - size_t npts, nd; - uint32_t npts_u32, nd_u32; - uint32_t num_sampled_pts_u32 = 0; - uint32_t one_const = 1; - - base_reader.read((char *)&npts_u32, sizeof(uint32_t)); - base_reader.read((char *)&nd_u32, sizeof(uint32_t)); - diskann::cout << "Loading base " << base_file << ". #points: " << npts_u32 - << ". #dim: " << nd_u32 << "." 
<< std::endl; - sample_writer.write((char *)&num_sampled_pts_u32, sizeof(uint32_t)); - sample_writer.write((char *)&nd_u32, sizeof(uint32_t)); - sample_id_writer.write((char *)&num_sampled_pts_u32, sizeof(uint32_t)); - sample_id_writer.write((char *)&one_const, sizeof(uint32_t)); - - npts = npts_u32; - nd = nd_u32; - std::unique_ptr cur_row = std::make_unique(nd); - - for (size_t i = 0; i < npts; i++) { - base_reader.read((char *)cur_row.get(), sizeof(T) * nd); - float sample = distribution(generator); - if (sample < sampling_rate) { - sample_writer.write((char *)cur_row.get(), sizeof(T) * nd); - uint32_t cur_i_u32 = (uint32_t)i; - sample_id_writer.write((char *)&cur_i_u32, sizeof(uint32_t)); - num_sampled_pts_u32++; +void gen_random_slice(const std::string base_file, const std::string output_prefix, double sampling_rate) +{ + size_t read_blk_size = 64 * 1024 * 1024; + cached_ifstream base_reader(base_file.c_str(), read_blk_size); + std::ofstream sample_writer(std::string(output_prefix + "_data.bin").c_str(), std::ios::binary); + std::ofstream sample_id_writer(std::string(output_prefix + "_ids.bin").c_str(), std::ios::binary); + + std::random_device rd; // Will be used to obtain a seed for the random number engine + auto x = rd(); + std::mt19937 generator(x); // Standard mersenne_twister_engine seeded with rd() + std::uniform_real_distribution distribution(0, 1); + + size_t npts, nd; + uint32_t npts_u32, nd_u32; + uint32_t num_sampled_pts_u32 = 0; + uint32_t one_const = 1; + + base_reader.read((char *)&npts_u32, sizeof(uint32_t)); + base_reader.read((char *)&nd_u32, sizeof(uint32_t)); + diskann::cout << "Loading base " << base_file << ". #points: " << npts_u32 << ". #dim: " << nd_u32 << "." + << std::endl; + sample_writer.write((char *)&num_sampled_pts_u32, sizeof(uint32_t)); + sample_writer.write((char *)&nd_u32, sizeof(uint32_t)); + sample_id_writer.write((char *)&num_sampled_pts_u32, sizeof(uint32_t)); + sample_id_writer.write((char *)&one_const, sizeof(uint32_t)); + + npts = npts_u32; + nd = nd_u32; + std::unique_ptr cur_row = std::make_unique(nd); + + for (size_t i = 0; i < npts; i++) + { + base_reader.read((char *)cur_row.get(), sizeof(T) * nd); + float sample = distribution(generator); + if (sample < sampling_rate) + { + sample_writer.write((char *)cur_row.get(), sizeof(T) * nd); + uint32_t cur_i_u32 = (uint32_t)i; + sample_id_writer.write((char *)&cur_i_u32, sizeof(uint32_t)); + num_sampled_pts_u32++; + } } - } - sample_writer.seekp(0, std::ios::beg); - sample_writer.write((char *)&num_sampled_pts_u32, sizeof(uint32_t)); - sample_id_writer.seekp(0, std::ios::beg); - sample_id_writer.write((char *)&num_sampled_pts_u32, sizeof(uint32_t)); - sample_writer.close(); - sample_id_writer.close(); - diskann::cout << "Wrote " << num_sampled_pts_u32 - << " points to sample file: " << output_prefix + "_data.bin" - << std::endl; + sample_writer.seekp(0, std::ios::beg); + sample_writer.write((char *)&num_sampled_pts_u32, sizeof(uint32_t)); + sample_id_writer.seekp(0, std::ios::beg); + sample_id_writer.write((char *)&num_sampled_pts_u32, sizeof(uint32_t)); + sample_writer.close(); + sample_id_writer.close(); + diskann::cout << "Wrote " << num_sampled_pts_u32 << " points to sample file: " << output_prefix + "_data.bin" + << std::endl; } // streams data from the file, and samples each vector with probability p_val @@ -96,597 +92,566 @@ void gen_random_slice(const std::string base_file, ************************************/ template -void gen_random_slice(const std::string data_file, double p_val, - 
float *&sampled_data, size_t &slice_size, size_t &ndims) { - size_t npts; - uint32_t npts32, ndims32; - std::vector> sampled_vectors; - - // amount to read in one shot - size_t read_blk_size = 64 * 1024 * 1024; - // create cached reader + writer - cached_ifstream base_reader(data_file.c_str(), read_blk_size); - - // metadata: npts, ndims - base_reader.read((char *)&npts32, sizeof(uint32_t)); - base_reader.read((char *)&ndims32, sizeof(uint32_t)); - npts = npts32; - ndims = ndims32; - - std::unique_ptr cur_vector_T = std::make_unique(ndims); - p_val = p_val < 1 ? p_val : 1; - - std::random_device rd; // Will be used to obtain a seed for the random number - size_t x = rd(); - std::mt19937 generator((uint32_t)x); - std::uniform_real_distribution distribution(0, 1); - - for (size_t i = 0; i < npts; i++) { - base_reader.read((char *)cur_vector_T.get(), ndims * sizeof(T)); - float rnd_val = distribution(generator); - if (rnd_val < p_val) { - std::vector cur_vector_float; - for (size_t d = 0; d < ndims; d++) - cur_vector_float.push_back(cur_vector_T[d]); - sampled_vectors.push_back(cur_vector_float); +void gen_random_slice(const std::string data_file, double p_val, float *&sampled_data, size_t &slice_size, + size_t &ndims) +{ + size_t npts; + uint32_t npts32, ndims32; + std::vector> sampled_vectors; + + // amount to read in one shot + size_t read_blk_size = 64 * 1024 * 1024; + // create cached reader + writer + cached_ifstream base_reader(data_file.c_str(), read_blk_size); + + // metadata: npts, ndims + base_reader.read((char *)&npts32, sizeof(uint32_t)); + base_reader.read((char *)&ndims32, sizeof(uint32_t)); + npts = npts32; + ndims = ndims32; + + std::unique_ptr cur_vector_T = std::make_unique(ndims); + p_val = p_val < 1 ? p_val : 1; + + std::random_device rd; // Will be used to obtain a seed for the random number + size_t x = rd(); + std::mt19937 generator((uint32_t)x); + std::uniform_real_distribution distribution(0, 1); + + for (size_t i = 0; i < npts; i++) + { + base_reader.read((char *)cur_vector_T.get(), ndims * sizeof(T)); + float rnd_val = distribution(generator); + if (rnd_val < p_val) + { + std::vector cur_vector_float; + for (size_t d = 0; d < ndims; d++) + cur_vector_float.push_back(cur_vector_T[d]); + sampled_vectors.push_back(cur_vector_float); + } } - } - slice_size = sampled_vectors.size(); - sampled_data = new float[slice_size * ndims]; - for (size_t i = 0; i < slice_size; i++) { - for (size_t j = 0; j < ndims; j++) { - sampled_data[i * ndims + j] = sampled_vectors[i][j]; + slice_size = sampled_vectors.size(); + sampled_data = new float[slice_size * ndims]; + for (size_t i = 0; i < slice_size; i++) + { + for (size_t j = 0; j < ndims; j++) + { + sampled_data[i * ndims + j] = sampled_vectors[i][j]; + } } - } } // same as above, but samples from the matrix inputdata instead of a file of // npts*ndims to return sampled_data of size slice_size*ndims. template -void gen_random_slice(const T *inputdata, size_t npts, size_t ndims, - double p_val, float *&sampled_data, size_t &slice_size) { - std::vector> sampled_vectors; - const T *cur_vector_T; - - p_val = p_val < 1 ? 
p_val : 1; - - std::random_device - rd; // Will be used to obtain a seed for the random number engine - size_t x = rd(); - std::mt19937 generator( - (uint32_t)x); // Standard mersenne_twister_engine seeded with rd() - std::uniform_real_distribution distribution(0, 1); - - for (size_t i = 0; i < npts; i++) { - cur_vector_T = inputdata + ndims * i; - float rnd_val = distribution(generator); - if (rnd_val < p_val) { - std::vector cur_vector_float; - for (size_t d = 0; d < ndims; d++) - cur_vector_float.push_back(cur_vector_T[d]); - sampled_vectors.push_back(cur_vector_float); +void gen_random_slice(const T *inputdata, size_t npts, size_t ndims, double p_val, float *&sampled_data, + size_t &slice_size) +{ + std::vector> sampled_vectors; + const T *cur_vector_T; + + p_val = p_val < 1 ? p_val : 1; + + std::random_device rd; // Will be used to obtain a seed for the random number engine + size_t x = rd(); + std::mt19937 generator((uint32_t)x); // Standard mersenne_twister_engine seeded with rd() + std::uniform_real_distribution distribution(0, 1); + + for (size_t i = 0; i < npts; i++) + { + cur_vector_T = inputdata + ndims * i; + float rnd_val = distribution(generator); + if (rnd_val < p_val) + { + std::vector cur_vector_float; + for (size_t d = 0; d < ndims; d++) + cur_vector_float.push_back(cur_vector_T[d]); + sampled_vectors.push_back(cur_vector_float); + } } - } - slice_size = sampled_vectors.size(); - sampled_data = new float[slice_size * ndims]; - for (size_t i = 0; i < slice_size; i++) { - for (size_t j = 0; j < ndims; j++) { - sampled_data[i * ndims + j] = sampled_vectors[i][j]; + slice_size = sampled_vectors.size(); + sampled_data = new float[slice_size * ndims]; + for (size_t i = 0; i < slice_size; i++) + { + for (size_t j = 0; j < ndims; j++) + { + sampled_data[i * ndims + j] = sampled_vectors[i][j]; + } } - } } -int estimate_cluster_sizes(float *test_data_float, size_t num_test, - float *pivots, const size_t num_centers, - const size_t test_dim, const size_t k_base, - std::vector &cluster_sizes) { - cluster_sizes.clear(); +int estimate_cluster_sizes(float *test_data_float, size_t num_test, float *pivots, const size_t num_centers, + const size_t test_dim, const size_t k_base, std::vector &cluster_sizes) +{ + cluster_sizes.clear(); - size_t *shard_counts = new size_t[num_centers]; + size_t *shard_counts = new size_t[num_centers]; - for (size_t i = 0; i < num_centers; i++) { - shard_counts[i] = 0; - } + for (size_t i = 0; i < num_centers; i++) + { + shard_counts[i] = 0; + } - size_t block_size = num_test <= BLOCK_SIZE ? num_test : BLOCK_SIZE; - uint32_t *block_closest_centers = new uint32_t[block_size * k_base]; - float *block_data_float; + size_t block_size = num_test <= BLOCK_SIZE ? 
num_test : BLOCK_SIZE; + uint32_t *block_closest_centers = new uint32_t[block_size * k_base]; + float *block_data_float; - size_t num_blocks = DIV_ROUND_UP(num_test, block_size); + size_t num_blocks = DIV_ROUND_UP(num_test, block_size); - for (size_t block = 0; block < num_blocks; block++) { - size_t start_id = block * block_size; - size_t end_id = (std::min)((block + 1) * block_size, num_test); - size_t cur_blk_size = end_id - start_id; + for (size_t block = 0; block < num_blocks; block++) + { + size_t start_id = block * block_size; + size_t end_id = (std::min)((block + 1) * block_size, num_test); + size_t cur_blk_size = end_id - start_id; - block_data_float = test_data_float + start_id * test_dim; + block_data_float = test_data_float + start_id * test_dim; - math_utils::compute_closest_centers(block_data_float, cur_blk_size, - test_dim, pivots, num_centers, k_base, - block_closest_centers); + math_utils::compute_closest_centers(block_data_float, cur_blk_size, test_dim, pivots, num_centers, k_base, + block_closest_centers); - for (size_t p = 0; p < cur_blk_size; p++) { - for (size_t p1 = 0; p1 < k_base; p1++) { - size_t shard_id = block_closest_centers[p * k_base + p1]; - shard_counts[shard_id]++; - } + for (size_t p = 0; p < cur_blk_size; p++) + { + for (size_t p1 = 0; p1 < k_base; p1++) + { + size_t shard_id = block_closest_centers[p * k_base + p1]; + shard_counts[shard_id]++; + } + } } - } - - diskann::cout << "Estimated cluster sizes: "; - for (size_t i = 0; i < num_centers; i++) { - uint32_t cur_shard_count = (uint32_t)shard_counts[i]; - cluster_sizes.push_back((size_t)cur_shard_count); - diskann::cout << cur_shard_count << " "; - } - diskann::cout << std::endl; - delete[] shard_counts; - delete[] block_closest_centers; - return 0; -} -template -int shard_data_into_clusters(const std::string data_file, float *pivots, - const size_t num_centers, const size_t dim, - const size_t k_base, std::string prefix_path) { - size_t read_blk_size = 64 * 1024 * 1024; - // uint64_t write_blk_size = 64 * 1024 * 1024; - // create cached reader + writer - cached_ifstream base_reader(data_file, read_blk_size); - uint32_t npts32; - uint32_t basedim32; - base_reader.read((char *)&npts32, sizeof(uint32_t)); - base_reader.read((char *)&basedim32, sizeof(uint32_t)); - size_t num_points = npts32; - if (basedim32 != dim) { - diskann::cout << "Error. dimensions dont match for train set and base set" - << std::endl; - return -1; - } - - std::unique_ptr shard_counts = - std::make_unique(num_centers); - std::vector shard_data_writer(num_centers); - std::vector shard_idmap_writer(num_centers); - uint32_t dummy_size = 0; - uint32_t const_one = 1; - - for (size_t i = 0; i < num_centers; i++) { - std::string data_filename = - prefix_path + "_subshard-" + std::to_string(i) + ".bin"; - std::string idmap_filename = - prefix_path + "_subshard-" + std::to_string(i) + "_ids_uint32.bin"; - shard_data_writer[i] = - std::ofstream(data_filename.c_str(), std::ios::binary); - shard_idmap_writer[i] = - std::ofstream(idmap_filename.c_str(), std::ios::binary); - shard_data_writer[i].write((char *)&dummy_size, sizeof(uint32_t)); - shard_data_writer[i].write((char *)&basedim32, sizeof(uint32_t)); - shard_idmap_writer[i].write((char *)&dummy_size, sizeof(uint32_t)); - shard_idmap_writer[i].write((char *)&const_one, sizeof(uint32_t)); - shard_counts[i] = 0; - } - - size_t block_size = num_points <= BLOCK_SIZE ? 
num_points : BLOCK_SIZE; - std::unique_ptr block_closest_centers = - std::make_unique(block_size * k_base); - std::unique_ptr block_data_T = std::make_unique(block_size * dim); - std::unique_ptr block_data_float = - std::make_unique(block_size * dim); - - size_t num_blocks = DIV_ROUND_UP(num_points, block_size); - - for (size_t block = 0; block < num_blocks; block++) { - size_t start_id = block * block_size; - size_t end_id = (std::min)((block + 1) * block_size, num_points); - size_t cur_blk_size = end_id - start_id; - - base_reader.read((char *)block_data_T.get(), - sizeof(T) * (cur_blk_size * dim)); - diskann::convert_types(block_data_T.get(), block_data_float.get(), - cur_blk_size, dim); - - math_utils::compute_closest_centers(block_data_float.get(), cur_blk_size, - dim, pivots, num_centers, k_base, - block_closest_centers.get()); - - for (size_t p = 0; p < cur_blk_size; p++) { - for (size_t p1 = 0; p1 < k_base; p1++) { - size_t shard_id = block_closest_centers[p * k_base + p1]; - uint32_t original_point_map_id = (uint32_t)(start_id + p); - shard_data_writer[shard_id].write( - (char *)(block_data_T.get() + p * dim), sizeof(T) * dim); - shard_idmap_writer[shard_id].write((char *)&original_point_map_id, - sizeof(uint32_t)); - shard_counts[shard_id]++; - } + diskann::cout << "Estimated cluster sizes: "; + for (size_t i = 0; i < num_centers; i++) + { + uint32_t cur_shard_count = (uint32_t)shard_counts[i]; + cluster_sizes.push_back((size_t)cur_shard_count); + diskann::cout << cur_shard_count << " "; } - } - - size_t total_count = 0; - diskann::cout << "Actual shard sizes: " << std::flush; - for (size_t i = 0; i < num_centers; i++) { - uint32_t cur_shard_count = (uint32_t)shard_counts[i]; - total_count += cur_shard_count; - diskann::cout << cur_shard_count << " "; - shard_data_writer[i].seekp(0); - shard_data_writer[i].write((char *)&cur_shard_count, sizeof(uint32_t)); - shard_data_writer[i].close(); - shard_idmap_writer[i].seekp(0); - shard_idmap_writer[i].write((char *)&cur_shard_count, sizeof(uint32_t)); - shard_idmap_writer[i].close(); - } - - diskann::cout << "\n Partitioned " << num_points - << " with replication factor " << k_base << " to get " - << total_count << " points across " << num_centers << " shards " - << std::endl; - return 0; + diskann::cout << std::endl; + delete[] shard_counts; + delete[] block_closest_centers; + return 0; } -// useful for partitioning large dataset. we first generate only the IDS for -// each shard, and retrieve the actual vectors on demand. template -int shard_data_into_clusters_only_ids(const std::string data_file, - float *pivots, const size_t num_centers, - const size_t dim, const size_t k_base, - std::string prefix_path) { - size_t read_blk_size = 64 * 1024 * 1024; - // uint64_t write_blk_size = 64 * 1024 * 1024; - // create cached reader + writer - cached_ifstream base_reader(data_file, read_blk_size); - uint32_t npts32; - uint32_t basedim32; - base_reader.read((char *)&npts32, sizeof(uint32_t)); - base_reader.read((char *)&basedim32, sizeof(uint32_t)); - size_t num_points = npts32; - if (basedim32 != dim) { - diskann::cout << "Error. 
dimensions dont match for train set and base set" - << std::endl; - return -1; - } - - std::unique_ptr shard_counts = - std::make_unique(num_centers); - - std::vector shard_idmap_writer(num_centers); - uint32_t dummy_size = 0; - uint32_t const_one = 1; - - for (size_t i = 0; i < num_centers; i++) { - std::string idmap_filename = - prefix_path + "_subshard-" + std::to_string(i) + "_ids_uint32.bin"; - shard_idmap_writer[i] = - std::ofstream(idmap_filename.c_str(), std::ios::binary); - shard_idmap_writer[i].write((char *)&dummy_size, sizeof(uint32_t)); - shard_idmap_writer[i].write((char *)&const_one, sizeof(uint32_t)); - shard_counts[i] = 0; - } - - size_t block_size = num_points <= BLOCK_SIZE ? num_points : BLOCK_SIZE; - std::unique_ptr block_closest_centers = - std::make_unique(block_size * k_base); - std::unique_ptr block_data_T = std::make_unique(block_size * dim); - std::unique_ptr block_data_float = - std::make_unique(block_size * dim); - - size_t num_blocks = DIV_ROUND_UP(num_points, block_size); - - for (size_t block = 0; block < num_blocks; block++) { - size_t start_id = block * block_size; - size_t end_id = (std::min)((block + 1) * block_size, num_points); - size_t cur_blk_size = end_id - start_id; - - base_reader.read((char *)block_data_T.get(), - sizeof(T) * (cur_blk_size * dim)); - diskann::convert_types(block_data_T.get(), block_data_float.get(), - cur_blk_size, dim); - - math_utils::compute_closest_centers(block_data_float.get(), cur_blk_size, - dim, pivots, num_centers, k_base, - block_closest_centers.get()); - - for (size_t p = 0; p < cur_blk_size; p++) { - for (size_t p1 = 0; p1 < k_base; p1++) { - size_t shard_id = block_closest_centers[p * k_base + p1]; - uint32_t original_point_map_id = (uint32_t)(start_id + p); - shard_idmap_writer[shard_id].write((char *)&original_point_map_id, - sizeof(uint32_t)); - shard_counts[shard_id]++; - } +int shard_data_into_clusters(const std::string data_file, float *pivots, const size_t num_centers, const size_t dim, + const size_t k_base, std::string prefix_path) +{ + size_t read_blk_size = 64 * 1024 * 1024; + // uint64_t write_blk_size = 64 * 1024 * 1024; + // create cached reader + writer + cached_ifstream base_reader(data_file, read_blk_size); + uint32_t npts32; + uint32_t basedim32; + base_reader.read((char *)&npts32, sizeof(uint32_t)); + base_reader.read((char *)&basedim32, sizeof(uint32_t)); + size_t num_points = npts32; + if (basedim32 != dim) + { + diskann::cout << "Error. 
dimensions dont match for train set and base set" << std::endl; + return -1; } - } - - size_t total_count = 0; - diskann::cout << "Actual shard sizes: " << std::flush; - for (size_t i = 0; i < num_centers; i++) { - uint32_t cur_shard_count = (uint32_t)shard_counts[i]; - total_count += cur_shard_count; - diskann::cout << cur_shard_count << " "; - shard_idmap_writer[i].seekp(0); - shard_idmap_writer[i].write((char *)&cur_shard_count, sizeof(uint32_t)); - shard_idmap_writer[i].close(); - } - - diskann::cout << "\n Partitioned " << num_points - << " with replication factor " << k_base << " to get " - << total_count << " points across " << num_centers << " shards " - << std::endl; - return 0; -} -template -int retrieve_shard_data_from_ids(const std::string data_file, - std::string idmap_filename, - std::string data_filename) { - size_t read_blk_size = 64 * 1024 * 1024; - // uint64_t write_blk_size = 64 * 1024 * 1024; - // create cached reader + writer - cached_ifstream base_reader(data_file, read_blk_size); - uint32_t npts32; - uint32_t basedim32; - base_reader.read((char *)&npts32, sizeof(uint32_t)); - base_reader.read((char *)&basedim32, sizeof(uint32_t)); - size_t num_points = npts32; - size_t dim = basedim32; - - uint32_t dummy_size = 0; - - std::ofstream shard_data_writer(data_filename.c_str(), std::ios::binary); - shard_data_writer.write((char *)&dummy_size, sizeof(uint32_t)); - shard_data_writer.write((char *)&basedim32, sizeof(uint32_t)); - - uint32_t *shard_ids; - uint64_t shard_size, tmp; - diskann::load_bin(idmap_filename, shard_ids, shard_size, tmp); - - uint32_t cur_pos = 0; - uint32_t num_written = 0; - std::cout << "Shard has " << shard_size << " points" << std::endl; - - size_t block_size = num_points <= BLOCK_SIZE ? num_points : BLOCK_SIZE; - std::unique_ptr block_data_T = std::make_unique(block_size * dim); - - size_t num_blocks = DIV_ROUND_UP(num_points, block_size); - - for (size_t block = 0; block < num_blocks; block++) { - size_t start_id = block * block_size; - size_t end_id = (std::min)((block + 1) * block_size, num_points); - size_t cur_blk_size = end_id - start_id; - - base_reader.read((char *)block_data_T.get(), - sizeof(T) * (cur_blk_size * dim)); - - for (size_t p = 0; p < cur_blk_size; p++) { - uint32_t original_point_map_id = (uint32_t)(start_id + p); - if (cur_pos == shard_size) - break; - if (original_point_map_id == shard_ids[cur_pos]) { - cur_pos++; - shard_data_writer.write((char *)(block_data_T.get() + p * dim), - sizeof(T) * dim); - num_written++; - } + std::unique_ptr shard_counts = std::make_unique(num_centers); + std::vector shard_data_writer(num_centers); + std::vector shard_idmap_writer(num_centers); + uint32_t dummy_size = 0; + uint32_t const_one = 1; + + for (size_t i = 0; i < num_centers; i++) + { + std::string data_filename = prefix_path + "_subshard-" + std::to_string(i) + ".bin"; + std::string idmap_filename = prefix_path + "_subshard-" + std::to_string(i) + "_ids_uint32.bin"; + shard_data_writer[i] = std::ofstream(data_filename.c_str(), std::ios::binary); + shard_idmap_writer[i] = std::ofstream(idmap_filename.c_str(), std::ios::binary); + shard_data_writer[i].write((char *)&dummy_size, sizeof(uint32_t)); + shard_data_writer[i].write((char *)&basedim32, sizeof(uint32_t)); + shard_idmap_writer[i].write((char *)&dummy_size, sizeof(uint32_t)); + shard_idmap_writer[i].write((char *)&const_one, sizeof(uint32_t)); + shard_counts[i] = 0; } - if (cur_pos == shard_size) - break; - } - - diskann::cout << "Written file with " << num_written << " points" - 
<< std::endl; - - shard_data_writer.seekp(0); - shard_data_writer.write((char *)&num_written, sizeof(uint32_t)); - shard_data_writer.close(); - delete[] shard_ids; - return 0; -} -// partitions a large base file into many shards using k-means hueristic -// on a random sample generated using sampling_rate probability. After this, it -// assignes each base point to the closest k_base nearest centers and creates -// the shards. -// The total number of points across all shards will be k_base * num_points. - -template -int partition(const std::string data_file, const float sampling_rate, - size_t num_parts, size_t max_k_means_reps, - const std::string prefix_path, size_t k_base) { - size_t train_dim; - size_t num_train; - float *train_data_float; - - gen_random_slice(data_file, sampling_rate, train_data_float, num_train, - train_dim); - - float *pivot_data; + size_t block_size = num_points <= BLOCK_SIZE ? num_points : BLOCK_SIZE; + std::unique_ptr block_closest_centers = std::make_unique(block_size * k_base); + std::unique_ptr block_data_T = std::make_unique(block_size * dim); + std::unique_ptr block_data_float = std::make_unique(block_size * dim); + + size_t num_blocks = DIV_ROUND_UP(num_points, block_size); + + for (size_t block = 0; block < num_blocks; block++) + { + size_t start_id = block * block_size; + size_t end_id = (std::min)((block + 1) * block_size, num_points); + size_t cur_blk_size = end_id - start_id; + + base_reader.read((char *)block_data_T.get(), sizeof(T) * (cur_blk_size * dim)); + diskann::convert_types(block_data_T.get(), block_data_float.get(), cur_blk_size, dim); + + math_utils::compute_closest_centers(block_data_float.get(), cur_blk_size, dim, pivots, num_centers, k_base, + block_closest_centers.get()); + + for (size_t p = 0; p < cur_blk_size; p++) + { + for (size_t p1 = 0; p1 < k_base; p1++) + { + size_t shard_id = block_closest_centers[p * k_base + p1]; + uint32_t original_point_map_id = (uint32_t)(start_id + p); + shard_data_writer[shard_id].write((char *)(block_data_T.get() + p * dim), sizeof(T) * dim); + shard_idmap_writer[shard_id].write((char *)&original_point_map_id, sizeof(uint32_t)); + shard_counts[shard_id]++; + } + } + } - std::string cur_file = std::string(prefix_path); - std::string output_file; + size_t total_count = 0; + diskann::cout << "Actual shard sizes: " << std::flush; + for (size_t i = 0; i < num_centers; i++) + { + uint32_t cur_shard_count = (uint32_t)shard_counts[i]; + total_count += cur_shard_count; + diskann::cout << cur_shard_count << " "; + shard_data_writer[i].seekp(0); + shard_data_writer[i].write((char *)&cur_shard_count, sizeof(uint32_t)); + shard_data_writer[i].close(); + shard_idmap_writer[i].seekp(0); + shard_idmap_writer[i].write((char *)&cur_shard_count, sizeof(uint32_t)); + shard_idmap_writer[i].close(); + } - // kmeans_partitioning on training data + diskann::cout << "\n Partitioned " << num_points << " with replication factor " << k_base << " to get " + << total_count << " points across " << num_centers << " shards " << std::endl; + return 0; +} - // cur_file = cur_file + "_kmeans_partitioning-" + - // std::to_string(num_parts); - output_file = cur_file + "_centroids.bin"; +// useful for partitioning large dataset. we first generate only the IDS for +// each shard, and retrieve the actual vectors on demand. 
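The ID-only path splits sharding into two passes: shard_data_into_clusters_only_ids records just the uint32 point IDs assigned to each shard, and retrieve_shard_data_from_ids later pulls the corresponding vectors back out of the base file, so the k_base-way replication never has to be materialised as full vectors up front. Judging from the writer code, each _subshard-<i>_ids_uint32.bin file follows the usual bin layout: a uint32 point count (back-patched once the shard is complete), a uint32 column count of 1, then the IDs themselves. A minimal reader sketch under that assumption; read_shard_ids and shard_ids_path are illustrative names, not part of this patch:

    #include <cstdint>
    #include <fstream>
    #include <string>
    #include <vector>

    // Read one "<prefix>_subshard-<i>_ids_uint32.bin" file produced by the sharding pass.
    std::vector<uint32_t> read_shard_ids(const std::string &shard_ids_path)
    {
        std::ifstream in(shard_ids_path, std::ios::binary);
        uint32_t npts = 0, ncols = 0;
        in.read((char *)&npts, sizeof(uint32_t));  // count, back-patched by the writer
        in.read((char *)&ncols, sizeof(uint32_t)); // expected to be 1
        std::vector<uint32_t> ids(npts);
        in.read((char *)ids.data(), npts * sizeof(uint32_t));
        return ids;
    }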
+template +int shard_data_into_clusters_only_ids(const std::string data_file, float *pivots, const size_t num_centers, + const size_t dim, const size_t k_base, std::string prefix_path) +{ + size_t read_blk_size = 64 * 1024 * 1024; + // uint64_t write_blk_size = 64 * 1024 * 1024; + // create cached reader + writer + cached_ifstream base_reader(data_file, read_blk_size); + uint32_t npts32; + uint32_t basedim32; + base_reader.read((char *)&npts32, sizeof(uint32_t)); + base_reader.read((char *)&basedim32, sizeof(uint32_t)); + size_t num_points = npts32; + if (basedim32 != dim) + { + diskann::cout << "Error. dimensions dont match for train set and base set" << std::endl; + return -1; + } - pivot_data = new float[num_parts * train_dim]; + std::unique_ptr shard_counts = std::make_unique(num_centers); - // Process Global k-means for kmeans_partitioning Step - diskann::cout << "Processing global k-means (kmeans_partitioning Step)" - << std::endl; - kmeans::kmeanspp_selecting_pivots(train_data_float, num_train, train_dim, - pivot_data, num_parts); + std::vector shard_idmap_writer(num_centers); + uint32_t dummy_size = 0; + uint32_t const_one = 1; - kmeans::run_lloyds(train_data_float, num_train, train_dim, pivot_data, - num_parts, max_k_means_reps, NULL, NULL); + for (size_t i = 0; i < num_centers; i++) + { + std::string idmap_filename = prefix_path + "_subshard-" + std::to_string(i) + "_ids_uint32.bin"; + shard_idmap_writer[i] = std::ofstream(idmap_filename.c_str(), std::ios::binary); + shard_idmap_writer[i].write((char *)&dummy_size, sizeof(uint32_t)); + shard_idmap_writer[i].write((char *)&const_one, sizeof(uint32_t)); + shard_counts[i] = 0; + } - diskann::cout << "Saving global k-center pivots" << std::endl; - diskann::save_bin(output_file.c_str(), pivot_data, (size_t)num_parts, - train_dim); + size_t block_size = num_points <= BLOCK_SIZE ? num_points : BLOCK_SIZE; + std::unique_ptr block_closest_centers = std::make_unique(block_size * k_base); + std::unique_ptr block_data_T = std::make_unique(block_size * dim); + std::unique_ptr block_data_float = std::make_unique(block_size * dim); + + size_t num_blocks = DIV_ROUND_UP(num_points, block_size); + + for (size_t block = 0; block < num_blocks; block++) + { + size_t start_id = block * block_size; + size_t end_id = (std::min)((block + 1) * block_size, num_points); + size_t cur_blk_size = end_id - start_id; + + base_reader.read((char *)block_data_T.get(), sizeof(T) * (cur_blk_size * dim)); + diskann::convert_types(block_data_T.get(), block_data_float.get(), cur_blk_size, dim); + + math_utils::compute_closest_centers(block_data_float.get(), cur_blk_size, dim, pivots, num_centers, k_base, + block_closest_centers.get()); + + for (size_t p = 0; p < cur_blk_size; p++) + { + for (size_t p1 = 0; p1 < k_base; p1++) + { + size_t shard_id = block_closest_centers[p * k_base + p1]; + uint32_t original_point_map_id = (uint32_t)(start_id + p); + shard_idmap_writer[shard_id].write((char *)&original_point_map_id, sizeof(uint32_t)); + shard_counts[shard_id]++; + } + } + } - // now pivots are ready. need to stream base points and assign them to - // closest clusters. 
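Because a shard's size is only known after the whole base file has been streamed, each writer emits a placeholder count up front, and the clean-up loop that follows rewinds every stream with seekp(0) to patch in the real value. A stripped-down sketch of that back-patching pattern; the function name and arguments are illustrative, not code from this patch:

    #include <cstdint>
    #include <fstream>
    #include <string>
    #include <vector>

    // Write ids with a placeholder count, then overwrite the count once it is known.
    void write_ids_with_backpatched_count(const std::string &path, const std::vector<uint32_t> &ids)
    {
        std::ofstream out(path, std::ios::binary);
        uint32_t count = 0, one = 1;
        out.write((char *)&count, sizeof(uint32_t)); // placeholder, fixed below
        out.write((char *)&one, sizeof(uint32_t));   // single column of ids
        for (uint32_t id : ids)
            out.write((char *)&id, sizeof(uint32_t));
        count = (uint32_t)ids.size();
        out.seekp(0);
        out.write((char *)&count, sizeof(uint32_t)); // real count
    }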
+ size_t total_count = 0; + diskann::cout << "Actual shard sizes: " << std::flush; + for (size_t i = 0; i < num_centers; i++) + { + uint32_t cur_shard_count = (uint32_t)shard_counts[i]; + total_count += cur_shard_count; + diskann::cout << cur_shard_count << " "; + shard_idmap_writer[i].seekp(0); + shard_idmap_writer[i].write((char *)&cur_shard_count, sizeof(uint32_t)); + shard_idmap_writer[i].close(); + } - shard_data_into_clusters(data_file, pivot_data, num_parts, train_dim, - k_base, prefix_path); - delete[] pivot_data; - delete[] train_data_float; - return 0; + diskann::cout << "\n Partitioned " << num_points << " with replication factor " << k_base << " to get " + << total_count << " points across " << num_centers << " shards " << std::endl; + return 0; } template -int partition_with_ram_budget(const std::string data_file, - const double sampling_rate, double ram_budget, - size_t graph_degree, - const std::string prefix_path, size_t k_base) { - size_t train_dim; - size_t num_train; - float *train_data_float; - size_t max_k_means_reps = 10; +int retrieve_shard_data_from_ids(const std::string data_file, std::string idmap_filename, std::string data_filename) +{ + size_t read_blk_size = 64 * 1024 * 1024; + // uint64_t write_blk_size = 64 * 1024 * 1024; + // create cached reader + writer + cached_ifstream base_reader(data_file, read_blk_size); + uint32_t npts32; + uint32_t basedim32; + base_reader.read((char *)&npts32, sizeof(uint32_t)); + base_reader.read((char *)&basedim32, sizeof(uint32_t)); + size_t num_points = npts32; + size_t dim = basedim32; + + uint32_t dummy_size = 0; + + std::ofstream shard_data_writer(data_filename.c_str(), std::ios::binary); + shard_data_writer.write((char *)&dummy_size, sizeof(uint32_t)); + shard_data_writer.write((char *)&basedim32, sizeof(uint32_t)); + + uint32_t *shard_ids; + uint64_t shard_size, tmp; + diskann::load_bin(idmap_filename, shard_ids, shard_size, tmp); + + uint32_t cur_pos = 0; + uint32_t num_written = 0; + std::cout << "Shard has " << shard_size << " points" << std::endl; + + size_t block_size = num_points <= BLOCK_SIZE ? 
num_points : BLOCK_SIZE; + std::unique_ptr block_data_T = std::make_unique(block_size * dim); + + size_t num_blocks = DIV_ROUND_UP(num_points, block_size); + + for (size_t block = 0; block < num_blocks; block++) + { + size_t start_id = block * block_size; + size_t end_id = (std::min)((block + 1) * block_size, num_points); + size_t cur_blk_size = end_id - start_id; + + base_reader.read((char *)block_data_T.get(), sizeof(T) * (cur_blk_size * dim)); + + for (size_t p = 0; p < cur_blk_size; p++) + { + uint32_t original_point_map_id = (uint32_t)(start_id + p); + if (cur_pos == shard_size) + break; + if (original_point_map_id == shard_ids[cur_pos]) + { + cur_pos++; + shard_data_writer.write((char *)(block_data_T.get() + p * dim), sizeof(T) * dim); + num_written++; + } + } + if (cur_pos == shard_size) + break; + } - int num_parts = 3; - bool fit_in_ram = false; + diskann::cout << "Written file with " << num_written << " points" << std::endl; - gen_random_slice(data_file, sampling_rate, train_data_float, num_train, - train_dim); + shard_data_writer.seekp(0); + shard_data_writer.write((char *)&num_written, sizeof(uint32_t)); + shard_data_writer.close(); + delete[] shard_ids; + return 0; +} - size_t test_dim; - size_t num_test; - float *test_data_float; - gen_random_slice(data_file, sampling_rate, test_data_float, num_test, - test_dim); +// partitions a large base file into many shards using k-means hueristic +// on a random sample generated using sampling_rate probability. After this, it +// assignes each base point to the closest k_base nearest centers and creates +// the shards. +// The total number of points across all shards will be k_base * num_points. - float *pivot_data = nullptr; +template +int partition(const std::string data_file, const float sampling_rate, size_t num_parts, size_t max_k_means_reps, + const std::string prefix_path, size_t k_base) +{ + size_t train_dim; + size_t num_train; + float *train_data_float; - std::string cur_file = std::string(prefix_path); - std::string output_file; + gen_random_slice(data_file, sampling_rate, train_data_float, num_train, train_dim); - // kmeans_partitioning on training data + float *pivot_data; - // cur_file = cur_file + "_kmeans_partitioning-" + - // std::to_string(num_parts); - output_file = cur_file + "_centroids.bin"; + std::string cur_file = std::string(prefix_path); + std::string output_file; - while (!fit_in_ram) { - fit_in_ram = true; + // kmeans_partitioning on training data - double max_ram_usage = 0; - if (pivot_data != nullptr) - delete[] pivot_data; + // cur_file = cur_file + "_kmeans_partitioning-" + + // std::to_string(num_parts); + output_file = cur_file + "_centroids.bin"; pivot_data = new float[num_parts * train_dim]; + // Process Global k-means for kmeans_partitioning Step - diskann::cout << "Processing global k-means (kmeans_partitioning Step)" - << std::endl; - kmeans::kmeanspp_selecting_pivots(train_data_float, num_train, train_dim, - pivot_data, num_parts); + diskann::cout << "Processing global k-means (kmeans_partitioning Step)" << std::endl; + kmeans::kmeanspp_selecting_pivots(train_data_float, num_train, train_dim, pivot_data, num_parts); + + kmeans::run_lloyds(train_data_float, num_train, train_dim, pivot_data, num_parts, max_k_means_reps, NULL, NULL); - kmeans::run_lloyds(train_data_float, num_train, train_dim, pivot_data, - num_parts, max_k_means_reps, NULL, NULL); + diskann::cout << "Saving global k-center pivots" << std::endl; + diskann::save_bin(output_file.c_str(), pivot_data, (size_t)num_parts, train_dim); // 
now pivots are ready. need to stream base points and assign them to // closest clusters. - std::vector cluster_sizes; - estimate_cluster_sizes(test_data_float, num_test, pivot_data, num_parts, - train_dim, k_base, cluster_sizes); - - for (auto &p : cluster_sizes) { - // to account for the fact that p is the size of the shard over the - // testing sample. - p = (uint64_t)(p / sampling_rate); - double cur_shard_ram_estimate = diskann::estimate_ram_usage( - p, (uint32_t)train_dim, sizeof(T), (uint32_t)graph_degree); + shard_data_into_clusters(data_file, pivot_data, num_parts, train_dim, k_base, prefix_path); + delete[] pivot_data; + delete[] train_data_float; + return 0; +} - if (cur_shard_ram_estimate > max_ram_usage) - max_ram_usage = cur_shard_ram_estimate; - } - diskann::cout << "With " << num_parts << " parts, max estimated RAM usage: " - << max_ram_usage / (1024 * 1024 * 1024) - << "GB, budget given is " << ram_budget << std::endl; - if (max_ram_usage > 1024 * 1024 * 1024 * ram_budget) { - fit_in_ram = false; - num_parts += 2; +template +int partition_with_ram_budget(const std::string data_file, const double sampling_rate, double ram_budget, + size_t graph_degree, const std::string prefix_path, size_t k_base) +{ + size_t train_dim; + size_t num_train; + float *train_data_float; + size_t max_k_means_reps = 10; + + int num_parts = 3; + bool fit_in_ram = false; + + gen_random_slice(data_file, sampling_rate, train_data_float, num_train, train_dim); + + size_t test_dim; + size_t num_test; + float *test_data_float; + gen_random_slice(data_file, sampling_rate, test_data_float, num_test, test_dim); + + float *pivot_data = nullptr; + + std::string cur_file = std::string(prefix_path); + std::string output_file; + + // kmeans_partitioning on training data + + // cur_file = cur_file + "_kmeans_partitioning-" + + // std::to_string(num_parts); + output_file = cur_file + "_centroids.bin"; + + while (!fit_in_ram) + { + fit_in_ram = true; + + double max_ram_usage = 0; + if (pivot_data != nullptr) + delete[] pivot_data; + + pivot_data = new float[num_parts * train_dim]; + // Process Global k-means for kmeans_partitioning Step + diskann::cout << "Processing global k-means (kmeans_partitioning Step)" << std::endl; + kmeans::kmeanspp_selecting_pivots(train_data_float, num_train, train_dim, pivot_data, num_parts); + + kmeans::run_lloyds(train_data_float, num_train, train_dim, pivot_data, num_parts, max_k_means_reps, NULL, NULL); + + // now pivots are ready. need to stream base points and assign them to + // closest clusters. + + std::vector cluster_sizes; + estimate_cluster_sizes(test_data_float, num_test, pivot_data, num_parts, train_dim, k_base, cluster_sizes); + + for (auto &p : cluster_sizes) + { + // to account for the fact that p is the size of the shard over the + // testing sample. 
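Concretely, the sampled shard sizes are extrapolated back to the full dataset before the memory check: with an illustrative sampling_rate of 0.01, a shard that received 40,000 sampled points is treated as roughly 40,000 / 0.01 = 4,000,000 base points, and that extrapolated size is what feeds estimate_ram_usage together with the dimension, sizeof(T) and the graph degree. If the largest extrapolated shard exceeds the budget (ram_budget is given in GB and compared against the byte estimate), the surrounding while loop retries with num_parts increased by 2 and re-runs the k-means.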
+ p = (uint64_t)(p / sampling_rate); + double cur_shard_ram_estimate = + diskann::estimate_ram_usage(p, (uint32_t)train_dim, sizeof(T), (uint32_t)graph_degree); + + if (cur_shard_ram_estimate > max_ram_usage) + max_ram_usage = cur_shard_ram_estimate; + } + diskann::cout << "With " << num_parts + << " parts, max estimated RAM usage: " << max_ram_usage / (1024 * 1024 * 1024) + << "GB, budget given is " << ram_budget << std::endl; + if (max_ram_usage > 1024 * 1024 * 1024 * ram_budget) + { + fit_in_ram = false; + num_parts += 2; + } } - } - - diskann::cout << "Saving global k-center pivots" << std::endl; - diskann::save_bin(output_file.c_str(), pivot_data, (size_t)num_parts, - train_dim); - - shard_data_into_clusters_only_ids(data_file, pivot_data, num_parts, - train_dim, k_base, prefix_path); - delete[] pivot_data; - delete[] train_data_float; - delete[] test_data_float; - return num_parts; + + diskann::cout << "Saving global k-center pivots" << std::endl; + diskann::save_bin(output_file.c_str(), pivot_data, (size_t)num_parts, train_dim); + + shard_data_into_clusters_only_ids(data_file, pivot_data, num_parts, train_dim, k_base, prefix_path); + delete[] pivot_data; + delete[] train_data_float; + delete[] test_data_float; + return num_parts; } // Instantations of supported templates -template void DISKANN_DLLEXPORT -gen_random_slice(const std::string base_file, - const std::string output_prefix, double sampling_rate); -template void DISKANN_DLLEXPORT gen_random_slice( - const std::string base_file, const std::string output_prefix, - double sampling_rate); -template void DISKANN_DLLEXPORT -gen_random_slice(const std::string base_file, - const std::string output_prefix, double sampling_rate); - -template void DISKANN_DLLEXPORT -gen_random_slice(const float *inputdata, size_t npts, size_t ndims, - double p_val, float *&sampled_data, size_t &slice_size); -template void DISKANN_DLLEXPORT gen_random_slice( - const uint8_t *inputdata, size_t npts, size_t ndims, double p_val, - float *&sampled_data, size_t &slice_size); -template void DISKANN_DLLEXPORT gen_random_slice( - const int8_t *inputdata, size_t npts, size_t ndims, double p_val, - float *&sampled_data, size_t &slice_size); - -template void DISKANN_DLLEXPORT gen_random_slice( - const std::string data_file, double p_val, float *&sampled_data, - size_t &slice_size, size_t &ndims); -template void DISKANN_DLLEXPORT gen_random_slice( - const std::string data_file, double p_val, float *&sampled_data, - size_t &slice_size, size_t &ndims); -template void DISKANN_DLLEXPORT gen_random_slice( - const std::string data_file, double p_val, float *&sampled_data, - size_t &slice_size, size_t &ndims); - -template DISKANN_DLLEXPORT int -partition(const std::string data_file, const float sampling_rate, - size_t num_centers, size_t max_k_means_reps, - const std::string prefix_path, size_t k_base); -template DISKANN_DLLEXPORT int -partition(const std::string data_file, const float sampling_rate, - size_t num_centers, size_t max_k_means_reps, - const std::string prefix_path, size_t k_base); -template DISKANN_DLLEXPORT int -partition(const std::string data_file, const float sampling_rate, - size_t num_centers, size_t max_k_means_reps, - const std::string prefix_path, size_t k_base); - -template DISKANN_DLLEXPORT int partition_with_ram_budget( - const std::string data_file, const double sampling_rate, double ram_budget, - size_t graph_degree, const std::string prefix_path, size_t k_base); -template DISKANN_DLLEXPORT int partition_with_ram_budget( - const std::string 
data_file, const double sampling_rate, double ram_budget, - size_t graph_degree, const std::string prefix_path, size_t k_base); -template DISKANN_DLLEXPORT int partition_with_ram_budget( - const std::string data_file, const double sampling_rate, double ram_budget, - size_t graph_degree, const std::string prefix_path, size_t k_base); - -template DISKANN_DLLEXPORT int -retrieve_shard_data_from_ids(const std::string data_file, - std::string idmap_filename, - std::string data_filename); -template DISKANN_DLLEXPORT int -retrieve_shard_data_from_ids(const std::string data_file, - std::string idmap_filename, - std::string data_filename); -template DISKANN_DLLEXPORT int -retrieve_shard_data_from_ids(const std::string data_file, - std::string idmap_filename, - std::string data_filename); \ No newline at end of file +template void DISKANN_DLLEXPORT gen_random_slice(const std::string base_file, const std::string output_prefix, + double sampling_rate); +template void DISKANN_DLLEXPORT gen_random_slice(const std::string base_file, const std::string output_prefix, + double sampling_rate); +template void DISKANN_DLLEXPORT gen_random_slice(const std::string base_file, const std::string output_prefix, + double sampling_rate); + +template void DISKANN_DLLEXPORT gen_random_slice(const float *inputdata, size_t npts, size_t ndims, double p_val, + float *&sampled_data, size_t &slice_size); +template void DISKANN_DLLEXPORT gen_random_slice(const uint8_t *inputdata, size_t npts, size_t ndims, + double p_val, float *&sampled_data, size_t &slice_size); +template void DISKANN_DLLEXPORT gen_random_slice(const int8_t *inputdata, size_t npts, size_t ndims, + double p_val, float *&sampled_data, size_t &slice_size); + +template void DISKANN_DLLEXPORT gen_random_slice(const std::string data_file, double p_val, float *&sampled_data, + size_t &slice_size, size_t &ndims); +template void DISKANN_DLLEXPORT gen_random_slice(const std::string data_file, double p_val, + float *&sampled_data, size_t &slice_size, size_t &ndims); +template void DISKANN_DLLEXPORT gen_random_slice(const std::string data_file, double p_val, + float *&sampled_data, size_t &slice_size, size_t &ndims); + +template DISKANN_DLLEXPORT int partition(const std::string data_file, const float sampling_rate, + size_t num_centers, size_t max_k_means_reps, + const std::string prefix_path, size_t k_base); +template DISKANN_DLLEXPORT int partition(const std::string data_file, const float sampling_rate, + size_t num_centers, size_t max_k_means_reps, + const std::string prefix_path, size_t k_base); +template DISKANN_DLLEXPORT int partition(const std::string data_file, const float sampling_rate, + size_t num_centers, size_t max_k_means_reps, + const std::string prefix_path, size_t k_base); + +template DISKANN_DLLEXPORT int partition_with_ram_budget(const std::string data_file, + const double sampling_rate, double ram_budget, + size_t graph_degree, const std::string prefix_path, + size_t k_base); +template DISKANN_DLLEXPORT int partition_with_ram_budget(const std::string data_file, + const double sampling_rate, double ram_budget, + size_t graph_degree, const std::string prefix_path, + size_t k_base); +template DISKANN_DLLEXPORT int partition_with_ram_budget(const std::string data_file, const double sampling_rate, + double ram_budget, size_t graph_degree, + const std::string prefix_path, size_t k_base); + +template DISKANN_DLLEXPORT int retrieve_shard_data_from_ids(const std::string data_file, + std::string idmap_filename, + std::string data_filename); +template 
DISKANN_DLLEXPORT int retrieve_shard_data_from_ids(const std::string data_file, + std::string idmap_filename, + std::string data_filename); +template DISKANN_DLLEXPORT int retrieve_shard_data_from_ids(const std::string data_file, + std::string idmap_filename, + std::string data_filename); \ No newline at end of file diff --git a/src/pq.cpp b/src/pq.cpp index d11901844..d1cc8e861 100644 --- a/src/pq.cpp +++ b/src/pq.cpp @@ -2,8 +2,7 @@ // Licensed under the MIT license. #include "mkl.h" -#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && \ - defined(DISKANN_BUILD) +#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD) #include "gperftools/malloc_extension.h" #endif #include "math_utils.h" @@ -14,313 +13,335 @@ // block size for reading/processing large files and matrices in blocks #define BLOCK_SIZE 5000000 -namespace diskann { -FixedChunkPQTable::FixedChunkPQTable() {} +namespace diskann +{ +FixedChunkPQTable::FixedChunkPQTable() +{ +} -FixedChunkPQTable::~FixedChunkPQTable() { +FixedChunkPQTable::~FixedChunkPQTable() +{ #ifndef EXEC_ENV_OLS - if (tables != nullptr) - delete[] tables; - if (tables_tr != nullptr) - delete[] tables_tr; - if (chunk_offsets != nullptr) - delete[] chunk_offsets; - if (centroid != nullptr) - delete[] centroid; - if (rotmat_tr != nullptr) - delete[] rotmat_tr; + if (tables != nullptr) + delete[] tables; + if (tables_tr != nullptr) + delete[] tables_tr; + if (chunk_offsets != nullptr) + delete[] chunk_offsets; + if (centroid != nullptr) + delete[] centroid; + if (rotmat_tr != nullptr) + delete[] rotmat_tr; #endif } #ifdef EXEC_ENV_OLS -void FixedChunkPQTable::load_pq_centroid_bin(MemoryMappedFiles &files, - const char *pq_table_file, - size_t num_chunks) { +void FixedChunkPQTable::load_pq_centroid_bin(MemoryMappedFiles &files, const char *pq_table_file, size_t num_chunks) +{ #else -void FixedChunkPQTable::load_pq_centroid_bin(const char *pq_table_file, - size_t num_chunks) { +void FixedChunkPQTable::load_pq_centroid_bin(const char *pq_table_file, size_t num_chunks) +{ #endif - uint64_t nr, nc; - std::string rotmat_file = std::string(pq_table_file) + "_rotation_matrix.bin"; + uint64_t nr, nc; + std::string rotmat_file = std::string(pq_table_file) + "_rotation_matrix.bin"; #ifdef EXEC_ENV_OLS - size_t *file_offset_data; // since load_bin only sets the pointer, no need - // to delete. - diskann::load_bin(files, pq_table_file, file_offset_data, nr, nc); + size_t *file_offset_data; // since load_bin only sets the pointer, no need + // to delete. + diskann::load_bin(files, pq_table_file, file_offset_data, nr, nc); #else - std::unique_ptr file_offset_data; - diskann::load_bin(pq_table_file, file_offset_data, nr, nc); + std::unique_ptr file_offset_data; + diskann::load_bin(pq_table_file, file_offset_data, nr, nc); #endif - bool use_old_filetype = false; - - if (nr != 4 && nr != 5) { - diskann::cout << "Error reading pq_pivots file " << pq_table_file - << ". 
Offsets dont contain correct metadata, # offsets = " - << nr << ", but expecting " << 4 << " or " << 5; - throw diskann::ANNException("Error reading pq_pivots file at offsets data.", - -1, __FUNCSIG__, __FILE__, __LINE__); - } - - if (nr == 4) { - diskann::cout << "Offsets: " << file_offset_data[0] << " " - << file_offset_data[1] << " " << file_offset_data[2] << " " - << file_offset_data[3] << std::endl; - } else if (nr == 5) { - use_old_filetype = true; - diskann::cout << "Offsets: " << file_offset_data[0] << " " - << file_offset_data[1] << " " << file_offset_data[2] << " " - << file_offset_data[3] << file_offset_data[4] << std::endl; - } else { - throw diskann::ANNException("Wrong number of offsets in pq_pivots", -1, - __FUNCSIG__, __FILE__, __LINE__); - } + bool use_old_filetype = false; + + if (nr != 4 && nr != 5) + { + diskann::cout << "Error reading pq_pivots file " << pq_table_file + << ". Offsets dont contain correct metadata, # offsets = " << nr << ", but expecting " << 4 + << " or " << 5; + throw diskann::ANNException("Error reading pq_pivots file at offsets data.", -1, __FUNCSIG__, __FILE__, + __LINE__); + } + + if (nr == 4) + { + diskann::cout << "Offsets: " << file_offset_data[0] << " " << file_offset_data[1] << " " << file_offset_data[2] + << " " << file_offset_data[3] << std::endl; + } + else if (nr == 5) + { + use_old_filetype = true; + diskann::cout << "Offsets: " << file_offset_data[0] << " " << file_offset_data[1] << " " << file_offset_data[2] + << " " << file_offset_data[3] << file_offset_data[4] << std::endl; + } + else + { + throw diskann::ANNException("Wrong number of offsets in pq_pivots", -1, __FUNCSIG__, __FILE__, __LINE__); + } #ifdef EXEC_ENV_OLS - diskann::load_bin(files, pq_table_file, tables, nr, nc, - file_offset_data[0]); + diskann::load_bin(files, pq_table_file, tables, nr, nc, file_offset_data[0]); #else - diskann::load_bin(pq_table_file, tables, nr, nc, file_offset_data[0]); + diskann::load_bin(pq_table_file, tables, nr, nc, file_offset_data[0]); #endif - if ((nr != NUM_PQ_CENTROIDS)) { - diskann::cout << "Error reading pq_pivots file " << pq_table_file - << ". file_num_centers = " << nr << " but expecting " - << NUM_PQ_CENTROIDS << " centers"; - throw diskann::ANNException("Error reading pq_pivots file at pivots data.", - -1, __FUNCSIG__, __FILE__, __LINE__); - } + if ((nr != NUM_PQ_CENTROIDS)) + { + diskann::cout << "Error reading pq_pivots file " << pq_table_file << ". file_num_centers = " << nr + << " but expecting " << NUM_PQ_CENTROIDS << " centers"; + throw diskann::ANNException("Error reading pq_pivots file at pivots data.", -1, __FUNCSIG__, __FILE__, + __LINE__); + } - this->ndims = nc; + this->ndims = nc; #ifdef EXEC_ENV_OLS - diskann::load_bin(files, pq_table_file, centroid, nr, nc, - file_offset_data[1]); + diskann::load_bin(files, pq_table_file, centroid, nr, nc, file_offset_data[1]); #else - diskann::load_bin(pq_table_file, centroid, nr, nc, - file_offset_data[1]); + diskann::load_bin(pq_table_file, centroid, nr, nc, file_offset_data[1]); #endif - if ((nr != this->ndims) || (nc != 1)) { - diskann::cerr << "Error reading centroids from pq_pivots file " - << pq_table_file << ". 
file_dim = " << nr - << ", file_cols = " << nc << " but expecting " << this->ndims - << " entries in 1 dimension."; - throw diskann::ANNException( - "Error reading pq_pivots file at centroid data.", -1, __FUNCSIG__, - __FILE__, __LINE__); - } - - int chunk_offsets_index = 2; - if (use_old_filetype) { - chunk_offsets_index = 3; - } + if ((nr != this->ndims) || (nc != 1)) + { + diskann::cerr << "Error reading centroids from pq_pivots file " << pq_table_file << ". file_dim = " << nr + << ", file_cols = " << nc << " but expecting " << this->ndims << " entries in 1 dimension."; + throw diskann::ANNException("Error reading pq_pivots file at centroid data.", -1, __FUNCSIG__, __FILE__, + __LINE__); + } + + int chunk_offsets_index = 2; + if (use_old_filetype) + { + chunk_offsets_index = 3; + } #ifdef EXEC_ENV_OLS - diskann::load_bin(files, pq_table_file, chunk_offsets, nr, nc, - file_offset_data[chunk_offsets_index]); + diskann::load_bin(files, pq_table_file, chunk_offsets, nr, nc, file_offset_data[chunk_offsets_index]); #else - diskann::load_bin(pq_table_file, chunk_offsets, nr, nc, - file_offset_data[chunk_offsets_index]); + diskann::load_bin(pq_table_file, chunk_offsets, nr, nc, file_offset_data[chunk_offsets_index]); #endif - if (nc != 1 || (nr != num_chunks + 1 && num_chunks != 0)) { - diskann::cerr << "Error loading chunk offsets file. numc: " << nc - << " (should be 1). numr: " << nr << " (should be " - << num_chunks + 1 << " or 0 if we need to infer)" - << std::endl; - throw diskann::ANNException("Error loading chunk offsets file", -1, - __FUNCSIG__, __FILE__, __LINE__); - } + if (nc != 1 || (nr != num_chunks + 1 && num_chunks != 0)) + { + diskann::cerr << "Error loading chunk offsets file. numc: " << nc << " (should be 1). numr: " << nr + << " (should be " << num_chunks + 1 << " or 0 if we need to infer)" << std::endl; + throw diskann::ANNException("Error loading chunk offsets file", -1, __FUNCSIG__, __FILE__, __LINE__); + } - this->n_chunks = nr - 1; - diskann::cout << "Loaded PQ Pivots: #ctrs: " << NUM_PQ_CENTROIDS - << ", #dims: " << this->ndims << ", #chunks: " << this->n_chunks - << std::endl; + this->n_chunks = nr - 1; + diskann::cout << "Loaded PQ Pivots: #ctrs: " << NUM_PQ_CENTROIDS << ", #dims: " << this->ndims + << ", #chunks: " << this->n_chunks << std::endl; #ifdef EXEC_ENV_OLS - if (files.fileExists(rotmat_file)) { - diskann::load_bin(files, rotmat_file, (float *&)rotmat_tr, nr, nc); + if (files.fileExists(rotmat_file)) + { + diskann::load_bin(files, rotmat_file, (float *&)rotmat_tr, nr, nc); #else - if (file_exists(rotmat_file)) { - diskann::load_bin(rotmat_file, rotmat_tr, nr, nc); + if (file_exists(rotmat_file)) + { + diskann::load_bin(rotmat_file, rotmat_tr, nr, nc); #endif - if (nr != this->ndims || nc != this->ndims) { - diskann::cerr << "Error loading rotation matrix file" << std::endl; - throw diskann::ANNException("Error loading rotation matrix file", -1, - __FUNCSIG__, __FILE__, __LINE__); + if (nr != this->ndims || nc != this->ndims) + { + diskann::cerr << "Error loading rotation matrix file" << std::endl; + throw diskann::ANNException("Error loading rotation matrix file", -1, __FUNCSIG__, __FILE__, __LINE__); + } + use_rotation = true; } - use_rotation = true; - } - - // alloc and compute transpose - tables_tr = new float[256 * this->ndims]; - for (size_t i = 0; i < 256; i++) { - for (size_t j = 0; j < this->ndims; j++) { - tables_tr[j * 256 + i] = tables[i * this->ndims + j]; + + // alloc and compute transpose + tables_tr = new float[256 * this->ndims]; + for 
(size_t i = 0; i < 256; i++) + { + for (size_t j = 0; j < this->ndims; j++) + { + tables_tr[j * 256 + i] = tables[i * this->ndims + j]; + } } - } } -uint32_t FixedChunkPQTable::get_num_chunks() { - return static_cast(n_chunks); +uint32_t FixedChunkPQTable::get_num_chunks() +{ + return static_cast(n_chunks); } -void FixedChunkPQTable::preprocess_query(float *query_vec) { - for (uint32_t d = 0; d < ndims; d++) { - query_vec[d] -= centroid[d]; - } - std::vector tmp(ndims, 0); - if (use_rotation) { - for (uint32_t d = 0; d < ndims; d++) { - for (uint32_t d1 = 0; d1 < ndims; d1++) { - tmp[d] += query_vec[d1] * rotmat_tr[d1 * ndims + d]; - } +void FixedChunkPQTable::preprocess_query(float *query_vec) +{ + for (uint32_t d = 0; d < ndims; d++) + { + query_vec[d] -= centroid[d]; + } + std::vector tmp(ndims, 0); + if (use_rotation) + { + for (uint32_t d = 0; d < ndims; d++) + { + for (uint32_t d1 = 0; d1 < ndims; d1++) + { + tmp[d] += query_vec[d1] * rotmat_tr[d1 * ndims + d]; + } + } + std::memcpy(query_vec, tmp.data(), ndims * sizeof(float)); } - std::memcpy(query_vec, tmp.data(), ndims * sizeof(float)); - } } // assumes pre-processed query -void FixedChunkPQTable::populate_chunk_distances(const float *query_vec, - float *dist_vec) { - memset(dist_vec, 0, 256 * n_chunks * sizeof(float)); - // chunk wise distance computation - for (size_t chunk = 0; chunk < n_chunks; chunk++) { - // sum (q-c)^2 for the dimensions associated with this chunk - float *chunk_dists = dist_vec + (256 * chunk); - for (size_t j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) { - const float *centers_dim_vec = tables_tr + (256 * j); - for (size_t idx = 0; idx < 256; idx++) { - double diff = centers_dim_vec[idx] - (query_vec[j]); - chunk_dists[idx] += (float)(diff * diff); - } +void FixedChunkPQTable::populate_chunk_distances(const float *query_vec, float *dist_vec) +{ + memset(dist_vec, 0, 256 * n_chunks * sizeof(float)); + // chunk wise distance computation + for (size_t chunk = 0; chunk < n_chunks; chunk++) + { + // sum (q-c)^2 for the dimensions associated with this chunk + float *chunk_dists = dist_vec + (256 * chunk); + for (size_t j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) + { + const float *centers_dim_vec = tables_tr + (256 * j); + for (size_t idx = 0; idx < 256; idx++) + { + double diff = centers_dim_vec[idx] - (query_vec[j]); + chunk_dists[idx] += (float)(diff * diff); + } + } } - } } -float FixedChunkPQTable::l2_distance(const float *query_vec, - uint8_t *base_vec) { - float res = 0; - for (size_t chunk = 0; chunk < n_chunks; chunk++) { - for (size_t j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) { - const float *centers_dim_vec = tables_tr + (256 * j); - float diff = centers_dim_vec[base_vec[chunk]] - (query_vec[j]); - res += diff * diff; +float FixedChunkPQTable::l2_distance(const float *query_vec, uint8_t *base_vec) +{ + float res = 0; + for (size_t chunk = 0; chunk < n_chunks; chunk++) + { + for (size_t j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) + { + const float *centers_dim_vec = tables_tr + (256 * j); + float diff = centers_dim_vec[base_vec[chunk]] - (query_vec[j]); + res += diff * diff; + } } - } - return res; + return res; } -float FixedChunkPQTable::inner_product(const float *query_vec, - uint8_t *base_vec) { - float res = 0; - for (size_t chunk = 0; chunk < n_chunks; chunk++) { - for (size_t j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) { - const float *centers_dim_vec = tables_tr + (256 * j); - float diff = 
centers_dim_vec[base_vec[chunk]] * - query_vec[j]; // assumes centroid is 0 to - // prevent translation errors - res += diff; +float FixedChunkPQTable::inner_product(const float *query_vec, uint8_t *base_vec) +{ + float res = 0; + for (size_t chunk = 0; chunk < n_chunks; chunk++) + { + for (size_t j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) + { + const float *centers_dim_vec = tables_tr + (256 * j); + float diff = centers_dim_vec[base_vec[chunk]] * query_vec[j]; // assumes centroid is 0 to + // prevent translation errors + res += diff; + } } - } - return -res; // returns negative value to simulate distances (max -> min - // conversion) + return -res; // returns negative value to simulate distances (max -> min + // conversion) } // assumes no rotation is involved -void FixedChunkPQTable::inflate_vector(uint8_t *base_vec, float *out_vec) { - for (size_t chunk = 0; chunk < n_chunks; chunk++) { - for (size_t j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) { - const float *centers_dim_vec = tables_tr + (256 * j); - out_vec[j] = centers_dim_vec[base_vec[chunk]] + centroid[j]; +void FixedChunkPQTable::inflate_vector(uint8_t *base_vec, float *out_vec) +{ + for (size_t chunk = 0; chunk < n_chunks; chunk++) + { + for (size_t j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) + { + const float *centers_dim_vec = tables_tr + (256 * j); + out_vec[j] = centers_dim_vec[base_vec[chunk]] + centroid[j]; + } } - } } -void FixedChunkPQTable::populate_chunk_inner_products(const float *query_vec, - float *dist_vec) { - memset(dist_vec, 0, 256 * n_chunks * sizeof(float)); - // chunk wise distance computation - for (size_t chunk = 0; chunk < n_chunks; chunk++) { - // sum (q-c)^2 for the dimensions associated with this chunk - float *chunk_dists = dist_vec + (256 * chunk); - for (size_t j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) { - const float *centers_dim_vec = tables_tr + (256 * j); - for (size_t idx = 0; idx < 256; idx++) { - double prod = - centers_dim_vec[idx] * query_vec[j]; // assumes that we are not - // shifting the vectors to - // mean zero, i.e., centroid - // array should be all zeros - chunk_dists[idx] -= - (float)prod; // returning negative to keep the search code - // clean (max inner product vs min distance) - } +void FixedChunkPQTable::populate_chunk_inner_products(const float *query_vec, float *dist_vec) +{ + memset(dist_vec, 0, 256 * n_chunks * sizeof(float)); + // chunk wise distance computation + for (size_t chunk = 0; chunk < n_chunks; chunk++) + { + // sum (q-c)^2 for the dimensions associated with this chunk + float *chunk_dists = dist_vec + (256 * chunk); + for (size_t j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) + { + const float *centers_dim_vec = tables_tr + (256 * j); + for (size_t idx = 0; idx < 256; idx++) + { + double prod = centers_dim_vec[idx] * query_vec[j]; // assumes that we are not + // shifting the vectors to + // mean zero, i.e., centroid + // array should be all zeros + chunk_dists[idx] -= (float)prod; // returning negative to keep the search code + // clean (max inner product vs min distance) + } + } } - } } -void aggregate_coords(const std::vector &ids, - const uint8_t *all_coords, const size_t ndims, - uint8_t *out) { - for (size_t i = 0; i < ids.size(); i++) { - memcpy(out + i * ndims, all_coords + ids[i] * ndims, - ndims * sizeof(uint8_t)); - } +void aggregate_coords(const std::vector &ids, const uint8_t *all_coords, const size_t ndims, uint8_t *out) +{ + for (size_t i = 0; i < ids.size(); i++) 
+ { + memcpy(out + i * ndims, all_coords + ids[i] * ndims, ndims * sizeof(uint8_t)); + } } -void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, - const size_t pq_nchunks, const float *pq_dists, - std::vector &dists_out) { - //_mm_prefetch((char*) dists_out, _MM_HINT_T0); - _mm_prefetch((char *)pq_ids, _MM_HINT_T0); - _mm_prefetch((char *)(pq_ids + 64), _MM_HINT_T0); - _mm_prefetch((char *)(pq_ids + 128), _MM_HINT_T0); - dists_out.clear(); - dists_out.resize(n_pts, 0); - for (size_t chunk = 0; chunk < pq_nchunks; chunk++) { - const float *chunk_dists = pq_dists + 256 * chunk; - if (chunk < pq_nchunks - 1) { - _mm_prefetch((char *)(chunk_dists + 256), _MM_HINT_T0); - } - for (size_t idx = 0; idx < n_pts; idx++) { - uint8_t pq_centerid = pq_ids[pq_nchunks * idx + chunk]; - dists_out[idx] += chunk_dists[pq_centerid]; +void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, const size_t pq_nchunks, const float *pq_dists, + std::vector &dists_out) +{ + //_mm_prefetch((char*) dists_out, _MM_HINT_T0); + _mm_prefetch((char *)pq_ids, _MM_HINT_T0); + _mm_prefetch((char *)(pq_ids + 64), _MM_HINT_T0); + _mm_prefetch((char *)(pq_ids + 128), _MM_HINT_T0); + dists_out.clear(); + dists_out.resize(n_pts, 0); + for (size_t chunk = 0; chunk < pq_nchunks; chunk++) + { + const float *chunk_dists = pq_dists + 256 * chunk; + if (chunk < pq_nchunks - 1) + { + _mm_prefetch((char *)(chunk_dists + 256), _MM_HINT_T0); + } + for (size_t idx = 0; idx < n_pts; idx++) + { + uint8_t pq_centerid = pq_ids[pq_nchunks * idx + chunk]; + dists_out[idx] += chunk_dists[pq_centerid]; + } } - } } // Need to replace calls to these functions with calls to vector& based // functions above -void aggregate_coords(const uint32_t *ids, const size_t n_ids, - const uint8_t *all_coords, const size_t ndims, - uint8_t *out) { - for (size_t i = 0; i < n_ids; i++) { - memcpy(out + i * ndims, all_coords + ids[i] * ndims, - ndims * sizeof(uint8_t)); - } +void aggregate_coords(const uint32_t *ids, const size_t n_ids, const uint8_t *all_coords, const size_t ndims, + uint8_t *out) +{ + for (size_t i = 0; i < n_ids; i++) + { + memcpy(out + i * ndims, all_coords + ids[i] * ndims, ndims * sizeof(uint8_t)); + } } -void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, - const size_t pq_nchunks, const float *pq_dists, - float *dists_out) { - _mm_prefetch((char *)dists_out, _MM_HINT_T0); - _mm_prefetch((char *)pq_ids, _MM_HINT_T0); - _mm_prefetch((char *)(pq_ids + 64), _MM_HINT_T0); - _mm_prefetch((char *)(pq_ids + 128), _MM_HINT_T0); - memset(dists_out, 0, n_pts * sizeof(float)); - for (size_t chunk = 0; chunk < pq_nchunks; chunk++) { - const float *chunk_dists = pq_dists + 256 * chunk; - if (chunk < pq_nchunks - 1) { - _mm_prefetch((char *)(chunk_dists + 256), _MM_HINT_T0); - } - for (size_t idx = 0; idx < n_pts; idx++) { - uint8_t pq_centerid = pq_ids[pq_nchunks * idx + chunk]; - dists_out[idx] += chunk_dists[pq_centerid]; +void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, const size_t pq_nchunks, const float *pq_dists, + float *dists_out) +{ + _mm_prefetch((char *)dists_out, _MM_HINT_T0); + _mm_prefetch((char *)pq_ids, _MM_HINT_T0); + _mm_prefetch((char *)(pq_ids + 64), _MM_HINT_T0); + _mm_prefetch((char *)(pq_ids + 128), _MM_HINT_T0); + memset(dists_out, 0, n_pts * sizeof(float)); + for (size_t chunk = 0; chunk < pq_nchunks; chunk++) + { + const float *chunk_dists = pq_dists + 256 * chunk; + if (chunk < pq_nchunks - 1) + { + _mm_prefetch((char *)(chunk_dists + 256), _MM_HINT_T0); + } + for (size_t idx = 0; idx < 
n_pts; idx++) + { + uint8_t pq_centerid = pq_ids[pq_nchunks * idx + chunk]; + dists_out[idx] += chunk_dists[pq_centerid]; + } } - } } // generate_pq_pivots_simplified is a simplified version of generate_pq_pivots. @@ -336,50 +357,51 @@ void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, // The compiler pragma for multi-threading support is removed from this // implementation for the purpose of integration into systems that strictly // control resource allocation. -int generate_pq_pivots_simplified(const float *train_data, size_t num_train, - size_t dim, size_t num_pq_chunks, - std::vector &pivot_data_vector) { - if (num_pq_chunks > dim || dim % num_pq_chunks != 0) { - return -1; - } - - const size_t num_centers = 256; - const size_t cur_chunk_size = dim / num_pq_chunks; - const uint32_t KMEANS_ITERS_FOR_PQ = 15; - - pivot_data_vector.resize(num_centers * dim); - std::vector cur_pivot_data_vector(num_centers * cur_chunk_size); - std::vector cur_data_vector(num_train * cur_chunk_size); - std::vector closest_center_vector(num_train); - - float *pivot_data = &pivot_data_vector[0]; - float *cur_pivot_data = &cur_pivot_data_vector[0]; - float *cur_data = &cur_data_vector[0]; - uint32_t *closest_center = &closest_center_vector[0]; - - for (size_t i = 0; i < num_pq_chunks; i++) { - size_t chunk_offset = cur_chunk_size * i; - - for (int32_t j = 0; j < num_train; j++) { - std::memcpy(cur_data + j * cur_chunk_size, - train_data + j * dim + chunk_offset, - cur_chunk_size * sizeof(float)); +int generate_pq_pivots_simplified(const float *train_data, size_t num_train, size_t dim, size_t num_pq_chunks, + std::vector &pivot_data_vector) +{ + if (num_pq_chunks > dim || dim % num_pq_chunks != 0) + { + return -1; } - kmeans::kmeanspp_selecting_pivots(cur_data, num_train, cur_chunk_size, - cur_pivot_data, num_centers); + const size_t num_centers = 256; + const size_t cur_chunk_size = dim / num_pq_chunks; + const uint32_t KMEANS_ITERS_FOR_PQ = 15; + + pivot_data_vector.resize(num_centers * dim); + std::vector cur_pivot_data_vector(num_centers * cur_chunk_size); + std::vector cur_data_vector(num_train * cur_chunk_size); + std::vector closest_center_vector(num_train); + + float *pivot_data = &pivot_data_vector[0]; + float *cur_pivot_data = &cur_pivot_data_vector[0]; + float *cur_data = &cur_data_vector[0]; + uint32_t *closest_center = &closest_center_vector[0]; - kmeans::run_lloyds(cur_data, num_train, cur_chunk_size, cur_pivot_data, - num_centers, KMEANS_ITERS_FOR_PQ, NULL, closest_center); + for (size_t i = 0; i < num_pq_chunks; i++) + { + size_t chunk_offset = cur_chunk_size * i; - for (uint64_t j = 0; j < num_centers; j++) { - std::memcpy(pivot_data + j * dim + chunk_offset, - cur_pivot_data + j * cur_chunk_size, - cur_chunk_size * sizeof(float)); + for (int32_t j = 0; j < num_train; j++) + { + std::memcpy(cur_data + j * cur_chunk_size, train_data + j * dim + chunk_offset, + cur_chunk_size * sizeof(float)); + } + + kmeans::kmeanspp_selecting_pivots(cur_data, num_train, cur_chunk_size, cur_pivot_data, num_centers); + + kmeans::run_lloyds(cur_data, num_train, cur_chunk_size, cur_pivot_data, num_centers, KMEANS_ITERS_FOR_PQ, NULL, + closest_center); + + for (uint64_t j = 0; j < num_centers; j++) + { + std::memcpy(pivot_data + j * dim + chunk_offset, cur_pivot_data + j * cur_chunk_size, + cur_chunk_size * sizeof(float)); + } } - } - return 0; + return 0; } // given training data in train_data of dimensions num_train * dim, generate @@ -387,379 +409,367 @@ int generate_pq_pivots_simplified(const float 
*train_data, size_t num_train, // num_pq_chunks (if it divides dimension, else rounded) chunks, and runs // k-means in each chunk to compute the PQ pivots and stores in bin format in // file pq_pivots_path as a s num_centers*dim floating point binary file -int generate_pq_pivots(const float *const passed_train_data, size_t num_train, - uint32_t dim, uint32_t num_centers, - uint32_t num_pq_chunks, uint32_t max_k_means_reps, - std::string pq_pivots_path, bool make_zero_mean) { - if (num_pq_chunks > dim) { - diskann::cout << " Error: number of chunks more than dimension" - << std::endl; - return -1; - } - - std::unique_ptr train_data = - std::make_unique(num_train * dim); - std::memcpy(train_data.get(), passed_train_data, - num_train * dim * sizeof(float)); - - std::unique_ptr full_pivot_data; - - if (file_exists(pq_pivots_path)) { - size_t file_dim, file_num_centers; - diskann::load_bin(pq_pivots_path, full_pivot_data, file_num_centers, - file_dim, METADATA_SIZE); - if (file_dim == dim && file_num_centers == num_centers) { - diskann::cout << "PQ pivot file exists. Not generating again" - << std::endl; - return -1; +int generate_pq_pivots(const float *const passed_train_data, size_t num_train, uint32_t dim, uint32_t num_centers, + uint32_t num_pq_chunks, uint32_t max_k_means_reps, std::string pq_pivots_path, + bool make_zero_mean) +{ + if (num_pq_chunks > dim) + { + diskann::cout << " Error: number of chunks more than dimension" << std::endl; + return -1; } - } - - // Calculate centroid and center the training data - std::unique_ptr centroid = std::make_unique(dim); - for (uint64_t d = 0; d < dim; d++) { - centroid[d] = 0; - } - if (make_zero_mean) { // If we use L2 distance, there is an option to - // translate all vectors to make them centered and - // then compute PQ. This needs to be set to false - // when using PQ for MIPS as such translations dont - // preserve inner products. - for (uint64_t d = 0; d < dim; d++) { - for (uint64_t p = 0; p < num_train; p++) { - centroid[d] += train_data[p * dim + d]; - } - centroid[d] /= num_train; + + std::unique_ptr train_data = std::make_unique(num_train * dim); + std::memcpy(train_data.get(), passed_train_data, num_train * dim * sizeof(float)); + + std::unique_ptr full_pivot_data; + + if (file_exists(pq_pivots_path)) + { + size_t file_dim, file_num_centers; + diskann::load_bin(pq_pivots_path, full_pivot_data, file_num_centers, file_dim, METADATA_SIZE); + if (file_dim == dim && file_num_centers == num_centers) + { + diskann::cout << "PQ pivot file exists. 
Not generating again" << std::endl; + return -1; + } } - for (uint64_t d = 0; d < dim; d++) { - for (uint64_t p = 0; p < num_train; p++) { - train_data[p * dim + d] -= centroid[d]; - } + // Calculate centroid and center the training data + std::unique_ptr centroid = std::make_unique(dim); + for (uint64_t d = 0; d < dim; d++) + { + centroid[d] = 0; } - } - - std::vector chunk_offsets; - - size_t low_val = (size_t)std::floor((double)dim / (double)num_pq_chunks); - size_t high_val = (size_t)std::ceil((double)dim / (double)num_pq_chunks); - size_t max_num_high = dim - (low_val * num_pq_chunks); - size_t cur_num_high = 0; - size_t cur_bin_threshold = high_val; - - std::vector> bin_to_dims(num_pq_chunks); - tsl::robin_map dim_to_bin; - std::vector bin_loads(num_pq_chunks, 0); - - // Process dimensions not inserted by previous loop - for (uint32_t d = 0; d < dim; d++) { - if (dim_to_bin.find(d) != dim_to_bin.end()) - continue; - auto cur_best = num_pq_chunks + 1; - float cur_best_load = std::numeric_limits::max(); - for (uint32_t b = 0; b < num_pq_chunks; b++) { - if (bin_loads[b] < cur_best_load && - bin_to_dims[b].size() < cur_bin_threshold) { - cur_best = b; - cur_best_load = bin_loads[b]; - } + if (make_zero_mean) + { // If we use L2 distance, there is an option to + // translate all vectors to make them centered and + // then compute PQ. This needs to be set to false + // when using PQ for MIPS as such translations dont + // preserve inner products. + for (uint64_t d = 0; d < dim; d++) + { + for (uint64_t p = 0; p < num_train; p++) + { + centroid[d] += train_data[p * dim + d]; + } + centroid[d] /= num_train; + } + + for (uint64_t d = 0; d < dim; d++) + { + for (uint64_t p = 0; p < num_train; p++) + { + train_data[p * dim + d] -= centroid[d]; + } + } } - bin_to_dims[cur_best].push_back(d); - if (bin_to_dims[cur_best].size() == high_val) { - cur_num_high++; - if (cur_num_high == max_num_high) - cur_bin_threshold = low_val; + + std::vector chunk_offsets; + + size_t low_val = (size_t)std::floor((double)dim / (double)num_pq_chunks); + size_t high_val = (size_t)std::ceil((double)dim / (double)num_pq_chunks); + size_t max_num_high = dim - (low_val * num_pq_chunks); + size_t cur_num_high = 0; + size_t cur_bin_threshold = high_val; + + std::vector> bin_to_dims(num_pq_chunks); + tsl::robin_map dim_to_bin; + std::vector bin_loads(num_pq_chunks, 0); + + // Process dimensions not inserted by previous loop + for (uint32_t d = 0; d < dim; d++) + { + if (dim_to_bin.find(d) != dim_to_bin.end()) + continue; + auto cur_best = num_pq_chunks + 1; + float cur_best_load = std::numeric_limits::max(); + for (uint32_t b = 0; b < num_pq_chunks; b++) + { + if (bin_loads[b] < cur_best_load && bin_to_dims[b].size() < cur_bin_threshold) + { + cur_best = b; + cur_best_load = bin_loads[b]; + } + } + bin_to_dims[cur_best].push_back(d); + if (bin_to_dims[cur_best].size() == high_val) + { + cur_num_high++; + if (cur_num_high == max_num_high) + cur_bin_threshold = low_val; + } } - } - chunk_offsets.clear(); - chunk_offsets.push_back(0); + chunk_offsets.clear(); + chunk_offsets.push_back(0); - for (uint32_t b = 0; b < num_pq_chunks; b++) { - if (b > 0) - chunk_offsets.push_back(chunk_offsets[b - 1] + - (uint32_t)bin_to_dims[b - 1].size()); - } - chunk_offsets.push_back(dim); + for (uint32_t b = 0; b < num_pq_chunks; b++) + { + if (b > 0) + chunk_offsets.push_back(chunk_offsets[b - 1] + (uint32_t)bin_to_dims[b - 1].size()); + } + chunk_offsets.push_back(dim); - full_pivot_data.reset(new float[num_centers * dim]); + 
full_pivot_data.reset(new float[num_centers * dim]); - for (size_t i = 0; i < num_pq_chunks; i++) { - size_t cur_chunk_size = chunk_offsets[i + 1] - chunk_offsets[i]; + for (size_t i = 0; i < num_pq_chunks; i++) + { + size_t cur_chunk_size = chunk_offsets[i + 1] - chunk_offsets[i]; - if (cur_chunk_size == 0) - continue; - std::unique_ptr cur_pivot_data = - std::make_unique(num_centers * cur_chunk_size); - std::unique_ptr cur_data = - std::make_unique(num_train * cur_chunk_size); - std::unique_ptr closest_center = - std::make_unique(num_train); + if (cur_chunk_size == 0) + continue; + std::unique_ptr cur_pivot_data = std::make_unique(num_centers * cur_chunk_size); + std::unique_ptr cur_data = std::make_unique(num_train * cur_chunk_size); + std::unique_ptr closest_center = std::make_unique(num_train); - diskann::cout << "Processing chunk " << i << " with dimensions [" - << chunk_offsets[i] << ", " << chunk_offsets[i + 1] << ")" - << std::endl; + diskann::cout << "Processing chunk " << i << " with dimensions [" << chunk_offsets[i] << ", " + << chunk_offsets[i + 1] << ")" << std::endl; #pragma omp parallel for schedule(static, 65536) - for (int64_t j = 0; j < (int64_t)num_train; j++) { - std::memcpy(cur_data.get() + j * cur_chunk_size, - train_data.get() + j * dim + chunk_offsets[i], - cur_chunk_size * sizeof(float)); - } + for (int64_t j = 0; j < (int64_t)num_train; j++) + { + std::memcpy(cur_data.get() + j * cur_chunk_size, train_data.get() + j * dim + chunk_offsets[i], + cur_chunk_size * sizeof(float)); + } - kmeans::kmeanspp_selecting_pivots(cur_data.get(), num_train, cur_chunk_size, - cur_pivot_data.get(), num_centers); + kmeans::kmeanspp_selecting_pivots(cur_data.get(), num_train, cur_chunk_size, cur_pivot_data.get(), num_centers); - kmeans::run_lloyds(cur_data.get(), num_train, cur_chunk_size, - cur_pivot_data.get(), num_centers, max_k_means_reps, - NULL, closest_center.get()); + kmeans::run_lloyds(cur_data.get(), num_train, cur_chunk_size, cur_pivot_data.get(), num_centers, + max_k_means_reps, NULL, closest_center.get()); - for (uint64_t j = 0; j < num_centers; j++) { - std::memcpy(full_pivot_data.get() + j * dim + chunk_offsets[i], - cur_pivot_data.get() + j * cur_chunk_size, - cur_chunk_size * sizeof(float)); + for (uint64_t j = 0; j < num_centers; j++) + { + std::memcpy(full_pivot_data.get() + j * dim + chunk_offsets[i], cur_pivot_data.get() + j * cur_chunk_size, + cur_chunk_size * sizeof(float)); + } } - } - - std::vector cumul_bytes(4, 0); - cumul_bytes[0] = METADATA_SIZE; - cumul_bytes[1] = - cumul_bytes[0] + - diskann::save_bin(pq_pivots_path.c_str(), full_pivot_data.get(), - (size_t)num_centers, dim, cumul_bytes[0]); - cumul_bytes[2] = cumul_bytes[1] + diskann::save_bin( - pq_pivots_path.c_str(), centroid.get(), - (size_t)dim, 1, cumul_bytes[1]); - cumul_bytes[3] = - cumul_bytes[2] + - diskann::save_bin(pq_pivots_path.c_str(), chunk_offsets.data(), - chunk_offsets.size(), 1, cumul_bytes[2]); - diskann::save_bin(pq_pivots_path.c_str(), cumul_bytes.data(), - cumul_bytes.size(), 1, 0); - - diskann::cout << "Saved pq pivot data to " << pq_pivots_path << " of size " - << cumul_bytes[cumul_bytes.size() - 1] << "B." 
<< std::endl; - - return 0; + + std::vector cumul_bytes(4, 0); + cumul_bytes[0] = METADATA_SIZE; + cumul_bytes[1] = cumul_bytes[0] + diskann::save_bin(pq_pivots_path.c_str(), full_pivot_data.get(), + (size_t)num_centers, dim, cumul_bytes[0]); + cumul_bytes[2] = cumul_bytes[1] + + diskann::save_bin(pq_pivots_path.c_str(), centroid.get(), (size_t)dim, 1, cumul_bytes[1]); + cumul_bytes[3] = cumul_bytes[2] + diskann::save_bin(pq_pivots_path.c_str(), chunk_offsets.data(), + chunk_offsets.size(), 1, cumul_bytes[2]); + diskann::save_bin(pq_pivots_path.c_str(), cumul_bytes.data(), cumul_bytes.size(), 1, 0); + + diskann::cout << "Saved pq pivot data to " << pq_pivots_path << " of size " << cumul_bytes[cumul_bytes.size() - 1] + << "B." << std::endl; + + return 0; } -int generate_opq_pivots(const float *passed_train_data, size_t num_train, - uint32_t dim, uint32_t num_centers, - uint32_t num_pq_chunks, std::string opq_pivots_path, - bool make_zero_mean) { - if (num_pq_chunks > dim) { - diskann::cout << " Error: number of chunks more than dimension" - << std::endl; - return -1; - } - - std::unique_ptr train_data = - std::make_unique(num_train * dim); - std::memcpy(train_data.get(), passed_train_data, - num_train * dim * sizeof(float)); - - std::unique_ptr rotated_train_data = - std::make_unique(num_train * dim); - std::unique_ptr rotated_and_quantized_train_data = - std::make_unique(num_train * dim); - - std::unique_ptr full_pivot_data; - - // rotation matrix for OPQ - std::unique_ptr rotmat_tr; - - // matrices for SVD - std::unique_ptr Umat = std::make_unique(dim * dim); - std::unique_ptr Vmat_T = std::make_unique(dim * dim); - std::unique_ptr singular_values = std::make_unique(dim); - std::unique_ptr correlation_matrix = - std::make_unique(dim * dim); - - // Calculate centroid and center the training data - std::unique_ptr centroid = std::make_unique(dim); - for (uint64_t d = 0; d < dim; d++) { - centroid[d] = 0; - } - if (make_zero_mean) { // If we use L2 distance, there is an option to - // translate all vectors to make them centered and - // then compute PQ. This needs to be set to false - // when using PQ for MIPS as such translations dont - // preserve inner products. 
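// The comment above captures the invariant behind make_zero_mean: subtracting a common
// centroid from every vector leaves L2 distances unchanged but does not preserve inner
// products, which is why centering is skipped for MIPS (and, further below in
// generate_quantized_data, for OPQ as well). A minimal, self-contained illustration of
// that fact, using made-up 2-d vectors rather than anything from this library:
#include <cstdio>

int main()
{
    float a[2] = {3.0f, 1.0f}, b[2] = {1.0f, 2.0f}, centroid[2] = {2.0f, 1.5f};
    float l2_before = 0, l2_after = 0, ip_before = 0, ip_after = 0;
    for (int d = 0; d < 2; d++)
    {
        float ac = a[d] - centroid[d], bc = b[d] - centroid[d];
        l2_before += (a[d] - b[d]) * (a[d] - b[d]);
        l2_after += (ac - bc) * (ac - bc);
        ip_before += a[d] * b[d];
        ip_after += ac * bc;
    }
    // Prints L2: 5.0 vs 5.0, but IP: 5.0 vs -1.25 -- translation breaks inner products.
    std::printf("L2: %f vs %f, IP: %f vs %f\n", l2_before, l2_after, ip_before, ip_after);
    return 0;
}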
- for (uint64_t d = 0; d < dim; d++) { - for (uint64_t p = 0; p < num_train; p++) { - centroid[d] += train_data[p * dim + d]; - } - centroid[d] /= num_train; +int generate_opq_pivots(const float *passed_train_data, size_t num_train, uint32_t dim, uint32_t num_centers, + uint32_t num_pq_chunks, std::string opq_pivots_path, bool make_zero_mean) +{ + if (num_pq_chunks > dim) + { + diskann::cout << " Error: number of chunks more than dimension" << std::endl; + return -1; + } + + std::unique_ptr train_data = std::make_unique(num_train * dim); + std::memcpy(train_data.get(), passed_train_data, num_train * dim * sizeof(float)); + + std::unique_ptr rotated_train_data = std::make_unique(num_train * dim); + std::unique_ptr rotated_and_quantized_train_data = std::make_unique(num_train * dim); + + std::unique_ptr full_pivot_data; + + // rotation matrix for OPQ + std::unique_ptr rotmat_tr; + + // matrices for SVD + std::unique_ptr Umat = std::make_unique(dim * dim); + std::unique_ptr Vmat_T = std::make_unique(dim * dim); + std::unique_ptr singular_values = std::make_unique(dim); + std::unique_ptr correlation_matrix = std::make_unique(dim * dim); + + // Calculate centroid and center the training data + std::unique_ptr centroid = std::make_unique(dim); + for (uint64_t d = 0; d < dim; d++) + { + centroid[d] = 0; } - for (uint64_t d = 0; d < dim; d++) { - for (uint64_t p = 0; p < num_train; p++) { - train_data[p * dim + d] -= centroid[d]; - } + if (make_zero_mean) + { // If we use L2 distance, there is an option to + // translate all vectors to make them centered and + // then compute PQ. This needs to be set to false + // when using PQ for MIPS as such translations dont + // preserve inner products. + for (uint64_t d = 0; d < dim; d++) + { + for (uint64_t p = 0; p < num_train; p++) + { + centroid[d] += train_data[p * dim + d]; + } + centroid[d] /= num_train; + } + for (uint64_t d = 0; d < dim; d++) + { + for (uint64_t p = 0; p < num_train; p++) + { + train_data[p * dim + d] -= centroid[d]; + } + } } - } - - std::vector chunk_offsets; - - size_t low_val = (size_t)std::floor((double)dim / (double)num_pq_chunks); - size_t high_val = (size_t)std::ceil((double)dim / (double)num_pq_chunks); - size_t max_num_high = dim - (low_val * num_pq_chunks); - size_t cur_num_high = 0; - size_t cur_bin_threshold = high_val; - - std::vector> bin_to_dims(num_pq_chunks); - tsl::robin_map dim_to_bin; - std::vector bin_loads(num_pq_chunks, 0); - - // Process dimensions not inserted by previous loop - for (uint32_t d = 0; d < dim; d++) { - if (dim_to_bin.find(d) != dim_to_bin.end()) - continue; - auto cur_best = num_pq_chunks + 1; - float cur_best_load = std::numeric_limits::max(); - for (uint32_t b = 0; b < num_pq_chunks; b++) { - if (bin_loads[b] < cur_best_load && - bin_to_dims[b].size() < cur_bin_threshold) { - cur_best = b; - cur_best_load = bin_loads[b]; - } + + std::vector chunk_offsets; + + size_t low_val = (size_t)std::floor((double)dim / (double)num_pq_chunks); + size_t high_val = (size_t)std::ceil((double)dim / (double)num_pq_chunks); + size_t max_num_high = dim - (low_val * num_pq_chunks); + size_t cur_num_high = 0; + size_t cur_bin_threshold = high_val; + + std::vector> bin_to_dims(num_pq_chunks); + tsl::robin_map dim_to_bin; + std::vector bin_loads(num_pq_chunks, 0); + + // Process dimensions not inserted by previous loop + for (uint32_t d = 0; d < dim; d++) + { + if (dim_to_bin.find(d) != dim_to_bin.end()) + continue; + auto cur_best = num_pq_chunks + 1; + float cur_best_load = std::numeric_limits::max(); + for 
(uint32_t b = 0; b < num_pq_chunks; b++) + { + if (bin_loads[b] < cur_best_load && bin_to_dims[b].size() < cur_bin_threshold) + { + cur_best = b; + cur_best_load = bin_loads[b]; + } + } + bin_to_dims[cur_best].push_back(d); + if (bin_to_dims[cur_best].size() == high_val) + { + cur_num_high++; + if (cur_num_high == max_num_high) + cur_bin_threshold = low_val; + } } - bin_to_dims[cur_best].push_back(d); - if (bin_to_dims[cur_best].size() == high_val) { - cur_num_high++; - if (cur_num_high == max_num_high) - cur_bin_threshold = low_val; + + chunk_offsets.clear(); + chunk_offsets.push_back(0); + + for (uint32_t b = 0; b < num_pq_chunks; b++) + { + if (b > 0) + chunk_offsets.push_back(chunk_offsets[b - 1] + (uint32_t)bin_to_dims[b - 1].size()); } - } - - chunk_offsets.clear(); - chunk_offsets.push_back(0); - - for (uint32_t b = 0; b < num_pq_chunks; b++) { - if (b > 0) - chunk_offsets.push_back(chunk_offsets[b - 1] + - (uint32_t)bin_to_dims[b - 1].size()); - } - chunk_offsets.push_back(dim); - - full_pivot_data.reset(new float[num_centers * dim]); - rotmat_tr.reset(new float[dim * dim]); - - std::memset(rotmat_tr.get(), 0, dim * dim * sizeof(float)); - for (uint32_t d1 = 0; d1 < dim; d1++) - *(rotmat_tr.get() + d1 * dim + d1) = 1; - - for (uint32_t rnd = 0; rnd < MAX_OPQ_ITERS; rnd++) { - // rotate the training data using the current rotation matrix - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, (MKL_INT)num_train, - (MKL_INT)dim, (MKL_INT)dim, 1.0f, train_data.get(), - (MKL_INT)dim, rotmat_tr.get(), (MKL_INT)dim, 0.0f, - rotated_train_data.get(), (MKL_INT)dim); - - // compute the PQ pivots on the rotated space - for (size_t i = 0; i < num_pq_chunks; i++) { - size_t cur_chunk_size = chunk_offsets[i + 1] - chunk_offsets[i]; - - if (cur_chunk_size == 0) - continue; - std::unique_ptr cur_pivot_data = - std::make_unique(num_centers * cur_chunk_size); - std::unique_ptr cur_data = - std::make_unique(num_train * cur_chunk_size); - std::unique_ptr closest_center = - std::make_unique(num_train); - - diskann::cout << "Processing chunk " << i << " with dimensions [" - << chunk_offsets[i] << ", " << chunk_offsets[i + 1] << ")" - << std::endl; + chunk_offsets.push_back(dim); + + full_pivot_data.reset(new float[num_centers * dim]); + rotmat_tr.reset(new float[dim * dim]); + + std::memset(rotmat_tr.get(), 0, dim * dim * sizeof(float)); + for (uint32_t d1 = 0; d1 < dim; d1++) + *(rotmat_tr.get() + d1 * dim + d1) = 1; + + for (uint32_t rnd = 0; rnd < MAX_OPQ_ITERS; rnd++) + { + // rotate the training data using the current rotation matrix + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, (MKL_INT)num_train, (MKL_INT)dim, (MKL_INT)dim, 1.0f, + train_data.get(), (MKL_INT)dim, rotmat_tr.get(), (MKL_INT)dim, 0.0f, rotated_train_data.get(), + (MKL_INT)dim); + + // compute the PQ pivots on the rotated space + for (size_t i = 0; i < num_pq_chunks; i++) + { + size_t cur_chunk_size = chunk_offsets[i + 1] - chunk_offsets[i]; + + if (cur_chunk_size == 0) + continue; + std::unique_ptr cur_pivot_data = std::make_unique(num_centers * cur_chunk_size); + std::unique_ptr cur_data = std::make_unique(num_train * cur_chunk_size); + std::unique_ptr closest_center = std::make_unique(num_train); + + diskann::cout << "Processing chunk " << i << " with dimensions [" << chunk_offsets[i] << ", " + << chunk_offsets[i + 1] << ")" << std::endl; #pragma omp parallel for schedule(static, 65536) - for (int64_t j = 0; j < (int64_t)num_train; j++) { - std::memcpy(cur_data.get() + j * cur_chunk_size, - rotated_train_data.get() + 
j * dim + chunk_offsets[i], - cur_chunk_size * sizeof(float)); - } - - if (rnd == 0) { - kmeans::kmeanspp_selecting_pivots(cur_data.get(), num_train, - cur_chunk_size, cur_pivot_data.get(), - num_centers); - } else { - for (uint64_t j = 0; j < num_centers; j++) { - std::memcpy(cur_pivot_data.get() + j * cur_chunk_size, - full_pivot_data.get() + j * dim + chunk_offsets[i], - cur_chunk_size * sizeof(float)); + for (int64_t j = 0; j < (int64_t)num_train; j++) + { + std::memcpy(cur_data.get() + j * cur_chunk_size, rotated_train_data.get() + j * dim + chunk_offsets[i], + cur_chunk_size * sizeof(float)); + } + + if (rnd == 0) + { + kmeans::kmeanspp_selecting_pivots(cur_data.get(), num_train, cur_chunk_size, cur_pivot_data.get(), + num_centers); + } + else + { + for (uint64_t j = 0; j < num_centers; j++) + { + std::memcpy(cur_pivot_data.get() + j * cur_chunk_size, + full_pivot_data.get() + j * dim + chunk_offsets[i], cur_chunk_size * sizeof(float)); + } + } + + uint32_t num_lloyds_iters = 8; + kmeans::run_lloyds(cur_data.get(), num_train, cur_chunk_size, cur_pivot_data.get(), num_centers, + num_lloyds_iters, NULL, closest_center.get()); + + for (uint64_t j = 0; j < num_centers; j++) + { + std::memcpy(full_pivot_data.get() + j * dim + chunk_offsets[i], + cur_pivot_data.get() + j * cur_chunk_size, cur_chunk_size * sizeof(float)); + } + + for (size_t j = 0; j < num_train; j++) + { + std::memcpy(rotated_and_quantized_train_data.get() + j * dim + chunk_offsets[i], + cur_pivot_data.get() + (size_t)closest_center[j] * cur_chunk_size, + cur_chunk_size * sizeof(float)); + } + } + + // compute the correlation matrix between the original data and the + // quantized data to compute the new rotation + cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, (MKL_INT)dim, (MKL_INT)dim, (MKL_INT)num_train, 1.0f, + train_data.get(), (MKL_INT)dim, rotated_and_quantized_train_data.get(), (MKL_INT)dim, 0.0f, + correlation_matrix.get(), (MKL_INT)dim); + + // compute the SVD of the correlation matrix to help determine the new + // rotation matrix + uint32_t errcode = (uint32_t)LAPACKE_sgesdd(LAPACK_ROW_MAJOR, 'A', (MKL_INT)dim, (MKL_INT)dim, + correlation_matrix.get(), (MKL_INT)dim, singular_values.get(), + Umat.get(), (MKL_INT)dim, Vmat_T.get(), (MKL_INT)dim); + + if (errcode > 0) + { + std::cout << "SVD failed to converge." 
<< std::endl; + exit(-1); } - } - - uint32_t num_lloyds_iters = 8; - kmeans::run_lloyds(cur_data.get(), num_train, cur_chunk_size, - cur_pivot_data.get(), num_centers, num_lloyds_iters, - NULL, closest_center.get()); - - for (uint64_t j = 0; j < num_centers; j++) { - std::memcpy(full_pivot_data.get() + j * dim + chunk_offsets[i], - cur_pivot_data.get() + j * cur_chunk_size, - cur_chunk_size * sizeof(float)); - } - - for (size_t j = 0; j < num_train; j++) { - std::memcpy( - rotated_and_quantized_train_data.get() + j * dim + chunk_offsets[i], - cur_pivot_data.get() + (size_t)closest_center[j] * cur_chunk_size, - cur_chunk_size * sizeof(float)); - } - } - // compute the correlation matrix between the original data and the - // quantized data to compute the new rotation - cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, (MKL_INT)dim, - (MKL_INT)dim, (MKL_INT)num_train, 1.0f, train_data.get(), - (MKL_INT)dim, rotated_and_quantized_train_data.get(), - (MKL_INT)dim, 0.0f, correlation_matrix.get(), (MKL_INT)dim); - - // compute the SVD of the correlation matrix to help determine the new - // rotation matrix - uint32_t errcode = (uint32_t)LAPACKE_sgesdd( - LAPACK_ROW_MAJOR, 'A', (MKL_INT)dim, (MKL_INT)dim, - correlation_matrix.get(), (MKL_INT)dim, singular_values.get(), - Umat.get(), (MKL_INT)dim, Vmat_T.get(), (MKL_INT)dim); - - if (errcode > 0) { - std::cout << "SVD failed to converge." << std::endl; - exit(-1); + // compute the new rotation matrix from the singular vectors as R^T = U + // V^T + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, (MKL_INT)dim, (MKL_INT)dim, (MKL_INT)dim, 1.0f, + Umat.get(), (MKL_INT)dim, Vmat_T.get(), (MKL_INT)dim, 0.0f, rotmat_tr.get(), (MKL_INT)dim); } - // compute the new rotation matrix from the singular vectors as R^T = U - // V^T - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, (MKL_INT)dim, - (MKL_INT)dim, (MKL_INT)dim, 1.0f, Umat.get(), (MKL_INT)dim, - Vmat_T.get(), (MKL_INT)dim, 0.0f, rotmat_tr.get(), - (MKL_INT)dim); - } - - std::vector cumul_bytes(4, 0); - cumul_bytes[0] = METADATA_SIZE; - cumul_bytes[1] = - cumul_bytes[0] + - diskann::save_bin(opq_pivots_path.c_str(), full_pivot_data.get(), - (size_t)num_centers, dim, cumul_bytes[0]); - cumul_bytes[2] = cumul_bytes[1] + diskann::save_bin( - opq_pivots_path.c_str(), centroid.get(), - (size_t)dim, 1, cumul_bytes[1]); - cumul_bytes[3] = - cumul_bytes[2] + - diskann::save_bin(opq_pivots_path.c_str(), chunk_offsets.data(), - chunk_offsets.size(), 1, cumul_bytes[2]); - diskann::save_bin(opq_pivots_path.c_str(), cumul_bytes.data(), - cumul_bytes.size(), 1, 0); - - diskann::cout << "Saved opq pivot data to " << opq_pivots_path << " of size " - << cumul_bytes[cumul_bytes.size() - 1] << "B." 
<< std::endl; - - std::string rotmat_path = opq_pivots_path + "_rotation_matrix.bin"; - diskann::save_bin(rotmat_path.c_str(), rotmat_tr.get(), dim, dim); - - return 0; + std::vector cumul_bytes(4, 0); + cumul_bytes[0] = METADATA_SIZE; + cumul_bytes[1] = cumul_bytes[0] + diskann::save_bin(opq_pivots_path.c_str(), full_pivot_data.get(), + (size_t)num_centers, dim, cumul_bytes[0]); + cumul_bytes[2] = cumul_bytes[1] + + diskann::save_bin(opq_pivots_path.c_str(), centroid.get(), (size_t)dim, 1, cumul_bytes[1]); + cumul_bytes[3] = cumul_bytes[2] + diskann::save_bin(opq_pivots_path.c_str(), chunk_offsets.data(), + chunk_offsets.size(), 1, cumul_bytes[2]); + diskann::save_bin(opq_pivots_path.c_str(), cumul_bytes.data(), cumul_bytes.size(), 1, 0); + + diskann::cout << "Saved opq pivot data to " << opq_pivots_path << " of size " << cumul_bytes[cumul_bytes.size() - 1] + << "B." << std::endl; + + std::string rotmat_path = opq_pivots_path + "_rotation_matrix.bin"; + diskann::save_bin(rotmat_path.c_str(), rotmat_tr.get(), dim, dim); + + return 0; } // generate_pq_data_from_pivots_simplified is a simplified version of @@ -775,59 +785,61 @@ int generate_opq_pivots(const float *passed_train_data, size_t num_train, // The compiler pragma for multi-threading support is removed from this // implementation for the purpose of integration into systems that strictly // control resource allocation. -int generate_pq_data_from_pivots_simplified(const float *data, const size_t num, - const float *pivot_data, - const size_t pivots_num, - const size_t dim, - const size_t num_pq_chunks, - std::vector &pq) { - if (num_pq_chunks == 0 || num_pq_chunks > dim || dim % num_pq_chunks != 0) { - return -1; - } - - const size_t num_centers = 256; - const size_t chunk_size = dim / num_pq_chunks; - - if (pivots_num != num_centers * dim) { - return -1; - } - - pq.resize(num * num_pq_chunks); - - std::vector cur_pivot_vector(num_centers * chunk_size); - std::vector cur_data_vector(num * chunk_size); - std::vector closest_center_vector(num); - - float *cur_pivot_data = &cur_pivot_vector[0]; - float *cur_data = &cur_data_vector[0]; - uint32_t *closest_center = &closest_center_vector[0]; - - for (size_t i = 0; i < num_pq_chunks; i++) { - const size_t chunk_offset = chunk_size * i; - - for (int j = 0; j < num_centers; j++) { - std::memcpy(cur_pivot_data + j * chunk_size, - pivot_data + j * dim + chunk_offset, - chunk_size * sizeof(float)); +int generate_pq_data_from_pivots_simplified(const float *data, const size_t num, const float *pivot_data, + const size_t pivots_num, const size_t dim, const size_t num_pq_chunks, + std::vector &pq) +{ + if (num_pq_chunks == 0 || num_pq_chunks > dim || dim % num_pq_chunks != 0) + { + return -1; } - for (int j = 0; j < num; j++) { - for (size_t k = 0; k < chunk_size; k++) { - cur_data[j * chunk_size + k] = data[j * dim + chunk_offset + k]; - } + const size_t num_centers = 256; + const size_t chunk_size = dim / num_pq_chunks; + + if (pivots_num != num_centers * dim) + { + return -1; } - math_utils::compute_closest_centers(cur_data, num, chunk_size, - cur_pivot_data, num_centers, 1, - closest_center); + pq.resize(num * num_pq_chunks); + + std::vector cur_pivot_vector(num_centers * chunk_size); + std::vector cur_data_vector(num * chunk_size); + std::vector closest_center_vector(num); - for (int j = 0; j < num; j++) { - assert(closest_center[j] < num_centers); - pq[j * num_pq_chunks + i] = closest_center[j]; + float *cur_pivot_data = &cur_pivot_vector[0]; + float *cur_data = &cur_data_vector[0]; + uint32_t 
*closest_center = &closest_center_vector[0]; + + for (size_t i = 0; i < num_pq_chunks; i++) + { + const size_t chunk_offset = chunk_size * i; + + for (int j = 0; j < num_centers; j++) + { + std::memcpy(cur_pivot_data + j * chunk_size, pivot_data + j * dim + chunk_offset, + chunk_size * sizeof(float)); + } + + for (int j = 0; j < num; j++) + { + for (size_t k = 0; k < chunk_size; k++) + { + cur_data[j * chunk_size + k] = data[j * dim + chunk_offset + k]; + } + } + + math_utils::compute_closest_centers(cur_data, num, chunk_size, cur_pivot_data, num_centers, 1, closest_center); + + for (int j = 0; j < num; j++) + { + assert(closest_center[j] < num_centers); + pq[j * num_pq_chunks + i] = closest_center[j]; + } } - } - return 0; + return 0; } // streams the base file (data_file), and computes the closest centers in each @@ -836,361 +848,344 @@ int generate_pq_data_from_pivots_simplified(const float *data, const size_t num, // If the numbber of centers is < 256, it stores as byte vector, else as // 4-byte vector in binary format. template -int generate_pq_data_from_pivots(const std::string &data_file, - uint32_t num_centers, uint32_t num_pq_chunks, - const std::string &pq_pivots_path, - const std::string &pq_compressed_vectors_path, - bool use_opq) { - size_t read_blk_size = 64 * 1024 * 1024; - cached_ifstream base_reader(data_file, read_blk_size); - uint32_t npts32; - uint32_t basedim32; - base_reader.read((char *)&npts32, sizeof(uint32_t)); - base_reader.read((char *)&basedim32, sizeof(uint32_t)); - size_t num_points = npts32; - size_t dim = basedim32; - - std::unique_ptr full_pivot_data; - std::unique_ptr rotmat_tr; - std::unique_ptr centroid; - std::unique_ptr chunk_offsets; - - std::string inflated_pq_file = pq_compressed_vectors_path + "_inflated.bin"; - - if (!file_exists(pq_pivots_path)) { - std::cout << "ERROR: PQ k-means pivot file not found" << std::endl; - throw diskann::ANNException("PQ k-means pivot file not found", -1); - } else { - size_t nr, nc; - std::unique_ptr file_offset_data; +int generate_pq_data_from_pivots(const std::string &data_file, uint32_t num_centers, uint32_t num_pq_chunks, + const std::string &pq_pivots_path, const std::string &pq_compressed_vectors_path, + bool use_opq) +{ + size_t read_blk_size = 64 * 1024 * 1024; + cached_ifstream base_reader(data_file, read_blk_size); + uint32_t npts32; + uint32_t basedim32; + base_reader.read((char *)&npts32, sizeof(uint32_t)); + base_reader.read((char *)&basedim32, sizeof(uint32_t)); + size_t num_points = npts32; + size_t dim = basedim32; + + std::unique_ptr full_pivot_data; + std::unique_ptr rotmat_tr; + std::unique_ptr centroid; + std::unique_ptr chunk_offsets; + + std::string inflated_pq_file = pq_compressed_vectors_path + "_inflated.bin"; + + if (!file_exists(pq_pivots_path)) + { + std::cout << "ERROR: PQ k-means pivot file not found" << std::endl; + throw diskann::ANNException("PQ k-means pivot file not found", -1); + } + else + { + size_t nr, nc; + std::unique_ptr file_offset_data; + + diskann::load_bin(pq_pivots_path.c_str(), file_offset_data, nr, nc, 0); + + if (nr != 4) + { + diskann::cout << "Error reading pq_pivots file " << pq_pivots_path + << ". 
Offsets dont contain correct metadata, # offsets = " << nr << ", but expecting 4."; + throw diskann::ANNException("Error reading pq_pivots file at offsets data.", -1, __FUNCSIG__, __FILE__, + __LINE__); + } - diskann::load_bin(pq_pivots_path.c_str(), file_offset_data, nr, nc, - 0); + diskann::load_bin(pq_pivots_path.c_str(), full_pivot_data, nr, nc, file_offset_data[0]); - if (nr != 4) { - diskann::cout << "Error reading pq_pivots file " << pq_pivots_path - << ". Offsets dont contain correct metadata, # offsets = " - << nr << ", but expecting 4."; - throw diskann::ANNException( - "Error reading pq_pivots file at offsets data.", -1, __FUNCSIG__, - __FILE__, __LINE__); - } + if ((nr != num_centers) || (nc != dim)) + { + diskann::cout << "Error reading pq_pivots file " << pq_pivots_path << ". file_num_centers = " << nr + << ", file_dim = " << nc << " but expecting " << num_centers << " centers in " << dim + << " dimensions."; + throw diskann::ANNException("Error reading pq_pivots file at pivots data.", -1, __FUNCSIG__, __FILE__, + __LINE__); + } - diskann::load_bin(pq_pivots_path.c_str(), full_pivot_data, nr, nc, - file_offset_data[0]); - - if ((nr != num_centers) || (nc != dim)) { - diskann::cout << "Error reading pq_pivots file " << pq_pivots_path - << ". file_num_centers = " << nr << ", file_dim = " << nc - << " but expecting " << num_centers << " centers in " << dim - << " dimensions."; - throw diskann::ANNException( - "Error reading pq_pivots file at pivots data.", -1, __FUNCSIG__, - __FILE__, __LINE__); - } + diskann::load_bin(pq_pivots_path.c_str(), centroid, nr, nc, file_offset_data[1]); - diskann::load_bin(pq_pivots_path.c_str(), centroid, nr, nc, - file_offset_data[1]); + if ((nr != dim) || (nc != 1)) + { + diskann::cout << "Error reading pq_pivots file " << pq_pivots_path << ". file_dim = " << nr + << ", file_cols = " << nc << " but expecting " << dim << " entries in 1 dimension."; + throw diskann::ANNException("Error reading pq_pivots file at centroid data.", -1, __FUNCSIG__, __FILE__, + __LINE__); + } - if ((nr != dim) || (nc != 1)) { - diskann::cout << "Error reading pq_pivots file " << pq_pivots_path - << ". file_dim = " << nr << ", file_cols = " << nc - << " but expecting " << dim << " entries in 1 dimension."; - throw diskann::ANNException( - "Error reading pq_pivots file at centroid data.", -1, __FUNCSIG__, - __FILE__, __LINE__); - } + diskann::load_bin(pq_pivots_path.c_str(), chunk_offsets, nr, nc, file_offset_data[2]); - diskann::load_bin(pq_pivots_path.c_str(), chunk_offsets, nr, nc, - file_offset_data[2]); - - if (nr != (uint64_t)num_pq_chunks + 1 || nc != 1) { - diskann::cout - << "Error reading pq_pivots file at chunk offsets; file has nr=" << nr - << ",nc=" << nc << ", expecting nr=" << num_pq_chunks + 1 << ", nc=1." - << std::endl; - throw diskann::ANNException( - "Error reading pq_pivots file at chunk offsets.", -1, __FUNCSIG__, - __FILE__, __LINE__); - } + if (nr != (uint64_t)num_pq_chunks + 1 || nc != 1) + { + diskann::cout << "Error reading pq_pivots file at chunk offsets; file has nr=" << nr << ",nc=" << nc + << ", expecting nr=" << num_pq_chunks + 1 << ", nc=1." << std::endl; + throw diskann::ANNException("Error reading pq_pivots file at chunk offsets.", -1, __FUNCSIG__, __FILE__, + __LINE__); + } - if (use_opq) { - std::string rotmat_path = pq_pivots_path + "_rotation_matrix.bin"; - diskann::load_bin(rotmat_path.c_str(), rotmat_tr, nr, nc); - if (nr != (uint64_t)dim || nc != dim) { - diskann::cout << "Error reading rotation matrix file." 
<< std::endl; - throw diskann::ANNException("Error reading rotation matrix file.", -1, - __FUNCSIG__, __FILE__, __LINE__); - } - } + if (use_opq) + { + std::string rotmat_path = pq_pivots_path + "_rotation_matrix.bin"; + diskann::load_bin(rotmat_path.c_str(), rotmat_tr, nr, nc); + if (nr != (uint64_t)dim || nc != dim) + { + diskann::cout << "Error reading rotation matrix file." << std::endl; + throw diskann::ANNException("Error reading rotation matrix file.", -1, __FUNCSIG__, __FILE__, __LINE__); + } + } - diskann::cout << "Loaded PQ pivot information" << std::endl; - } + diskann::cout << "Loaded PQ pivot information" << std::endl; + } - std::ofstream compressed_file_writer(pq_compressed_vectors_path, - std::ios::binary); - uint32_t num_pq_chunks_u32 = num_pq_chunks; + std::ofstream compressed_file_writer(pq_compressed_vectors_path, std::ios::binary); + uint32_t num_pq_chunks_u32 = num_pq_chunks; - compressed_file_writer.write((char *)&num_points, sizeof(uint32_t)); - compressed_file_writer.write((char *)&num_pq_chunks_u32, sizeof(uint32_t)); + compressed_file_writer.write((char *)&num_points, sizeof(uint32_t)); + compressed_file_writer.write((char *)&num_pq_chunks_u32, sizeof(uint32_t)); - size_t block_size = num_points <= BLOCK_SIZE ? num_points : BLOCK_SIZE; + size_t block_size = num_points <= BLOCK_SIZE ? num_points : BLOCK_SIZE; #ifdef SAVE_INFLATED_PQ - std::ofstream inflated_file_writer(inflated_pq_file, std::ios::binary); - inflated_file_writer.write((char *)&num_points, sizeof(uint32_t)); - inflated_file_writer.write((char *)&basedim32, sizeof(uint32_t)); + std::ofstream inflated_file_writer(inflated_pq_file, std::ios::binary); + inflated_file_writer.write((char *)&num_points, sizeof(uint32_t)); + inflated_file_writer.write((char *)&basedim32, sizeof(uint32_t)); - std::unique_ptr block_inflated_base = - std::make_unique(block_size * dim); - std::memset(block_inflated_base.get(), 0, block_size * dim * sizeof(float)); + std::unique_ptr block_inflated_base = std::make_unique(block_size * dim); + std::memset(block_inflated_base.get(), 0, block_size * dim * sizeof(float)); #endif - std::unique_ptr block_compressed_base = - std::make_unique(block_size * (size_t)num_pq_chunks); - std::memset(block_compressed_base.get(), 0, - block_size * (size_t)num_pq_chunks * sizeof(uint32_t)); + std::unique_ptr block_compressed_base = + std::make_unique(block_size * (size_t)num_pq_chunks); + std::memset(block_compressed_base.get(), 0, block_size * (size_t)num_pq_chunks * sizeof(uint32_t)); - std::unique_ptr block_data_T = std::make_unique(block_size * dim); - std::unique_ptr block_data_float = - std::make_unique(block_size * dim); - std::unique_ptr block_data_tmp = - std::make_unique(block_size * dim); + std::unique_ptr block_data_T = std::make_unique(block_size * dim); + std::unique_ptr block_data_float = std::make_unique(block_size * dim); + std::unique_ptr block_data_tmp = std::make_unique(block_size * dim); - size_t num_blocks = DIV_ROUND_UP(num_points, block_size); + size_t num_blocks = DIV_ROUND_UP(num_points, block_size); - for (size_t block = 0; block < num_blocks; block++) { - size_t start_id = block * block_size; - size_t end_id = (std::min)((block + 1) * block_size, num_points); - size_t cur_blk_size = end_id - start_id; + for (size_t block = 0; block < num_blocks; block++) + { + size_t start_id = block * block_size; + size_t end_id = (std::min)((block + 1) * block_size, num_points); + size_t cur_blk_size = end_id - start_id; - base_reader.read((char *)(block_data_T.get()), - sizeof(T) * 
(cur_blk_size * dim)); - diskann::convert_types(block_data_T.get(), block_data_tmp.get(), - cur_blk_size, dim); + base_reader.read((char *)(block_data_T.get()), sizeof(T) * (cur_blk_size * dim)); + diskann::convert_types(block_data_T.get(), block_data_tmp.get(), cur_blk_size, dim); - diskann::cout << "Processing points [" << start_id << ", " << end_id - << ").." << std::flush; + diskann::cout << "Processing points [" << start_id << ", " << end_id << ").." << std::flush; - for (size_t p = 0; p < cur_blk_size; p++) { - for (uint64_t d = 0; d < dim; d++) { - block_data_tmp[p * dim + d] -= centroid[d]; - } - } + for (size_t p = 0; p < cur_blk_size; p++) + { + for (uint64_t d = 0; d < dim; d++) + { + block_data_tmp[p * dim + d] -= centroid[d]; + } + } - for (size_t p = 0; p < cur_blk_size; p++) { - for (uint64_t d = 0; d < dim; d++) { - block_data_float[p * dim + d] = block_data_tmp[p * dim + d]; - } - } + for (size_t p = 0; p < cur_blk_size; p++) + { + for (uint64_t d = 0; d < dim; d++) + { + block_data_float[p * dim + d] = block_data_tmp[p * dim + d]; + } + } - if (use_opq) { - // rotate the current block with the trained rotation matrix before - // PQ - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, - (MKL_INT)cur_blk_size, (MKL_INT)dim, (MKL_INT)dim, 1.0f, - block_data_float.get(), (MKL_INT)dim, rotmat_tr.get(), - (MKL_INT)dim, 0.0f, block_data_tmp.get(), (MKL_INT)dim); - std::memcpy(block_data_float.get(), block_data_tmp.get(), - cur_blk_size * dim * sizeof(float)); - } + if (use_opq) + { + // rotate the current block with the trained rotation matrix before + // PQ + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, (MKL_INT)cur_blk_size, (MKL_INT)dim, (MKL_INT)dim, + 1.0f, block_data_float.get(), (MKL_INT)dim, rotmat_tr.get(), (MKL_INT)dim, 0.0f, + block_data_tmp.get(), (MKL_INT)dim); + std::memcpy(block_data_float.get(), block_data_tmp.get(), cur_blk_size * dim * sizeof(float)); + } - for (size_t i = 0; i < num_pq_chunks; i++) { - size_t cur_chunk_size = chunk_offsets[i + 1] - chunk_offsets[i]; - if (cur_chunk_size == 0) - continue; + for (size_t i = 0; i < num_pq_chunks; i++) + { + size_t cur_chunk_size = chunk_offsets[i + 1] - chunk_offsets[i]; + if (cur_chunk_size == 0) + continue; - std::unique_ptr cur_pivot_data = - std::make_unique(num_centers * cur_chunk_size); - std::unique_ptr cur_data = - std::make_unique(cur_blk_size * cur_chunk_size); - std::unique_ptr closest_center = - std::make_unique(cur_blk_size); + std::unique_ptr cur_pivot_data = std::make_unique(num_centers * cur_chunk_size); + std::unique_ptr cur_data = std::make_unique(cur_blk_size * cur_chunk_size); + std::unique_ptr closest_center = std::make_unique(cur_blk_size); #pragma omp parallel for schedule(static, 8192) - for (int64_t j = 0; j < (int64_t)cur_blk_size; j++) { - for (size_t k = 0; k < cur_chunk_size; k++) - cur_data[j * cur_chunk_size + k] = - block_data_float[j * dim + chunk_offsets[i] + k]; - } + for (int64_t j = 0; j < (int64_t)cur_blk_size; j++) + { + for (size_t k = 0; k < cur_chunk_size; k++) + cur_data[j * cur_chunk_size + k] = block_data_float[j * dim + chunk_offsets[i] + k]; + } #pragma omp parallel for schedule(static, 1) - for (int64_t j = 0; j < (int64_t)num_centers; j++) { - std::memcpy(cur_pivot_data.get() + j * cur_chunk_size, - full_pivot_data.get() + j * dim + chunk_offsets[i], - cur_chunk_size * sizeof(float)); - } + for (int64_t j = 0; j < (int64_t)num_centers; j++) + { + std::memcpy(cur_pivot_data.get() + j * cur_chunk_size, + full_pivot_data.get() + j * dim + 
chunk_offsets[i], cur_chunk_size * sizeof(float)); + } - math_utils::compute_closest_centers(cur_data.get(), cur_blk_size, - cur_chunk_size, cur_pivot_data.get(), - num_centers, 1, closest_center.get()); + math_utils::compute_closest_centers(cur_data.get(), cur_blk_size, cur_chunk_size, cur_pivot_data.get(), + num_centers, 1, closest_center.get()); #pragma omp parallel for schedule(static, 8192) - for (int64_t j = 0; j < (int64_t)cur_blk_size; j++) { - block_compressed_base[j * num_pq_chunks + i] = closest_center[j]; + for (int64_t j = 0; j < (int64_t)cur_blk_size; j++) + { + block_compressed_base[j * num_pq_chunks + i] = closest_center[j]; #ifdef SAVE_INFLATED_PQ - for (size_t k = 0; k < cur_chunk_size; k++) - block_inflated_base[j * dim + chunk_offsets[i] + k] = - cur_pivot_data[closest_center[j] * cur_chunk_size + k] + - centroid[chunk_offsets[i] + k]; + for (size_t k = 0; k < cur_chunk_size; k++) + block_inflated_base[j * dim + chunk_offsets[i] + k] = + cur_pivot_data[closest_center[j] * cur_chunk_size + k] + centroid[chunk_offsets[i] + k]; #endif - } - } + } + } - if (num_centers > 256) { - compressed_file_writer.write((char *)(block_compressed_base.get()), - cur_blk_size * num_pq_chunks * - sizeof(uint32_t)); - } else { - std::unique_ptr pVec = - std::make_unique(cur_blk_size * num_pq_chunks); - diskann::convert_types( - block_compressed_base.get(), pVec.get(), cur_blk_size, num_pq_chunks); - compressed_file_writer.write( - (char *)(pVec.get()), cur_blk_size * num_pq_chunks * sizeof(uint8_t)); - } + if (num_centers > 256) + { + compressed_file_writer.write((char *)(block_compressed_base.get()), + cur_blk_size * num_pq_chunks * sizeof(uint32_t)); + } + else + { + std::unique_ptr pVec = std::make_unique(cur_blk_size * num_pq_chunks); + diskann::convert_types(block_compressed_base.get(), pVec.get(), cur_blk_size, + num_pq_chunks); + compressed_file_writer.write((char *)(pVec.get()), cur_blk_size * num_pq_chunks * sizeof(uint8_t)); + } #ifdef SAVE_INFLATED_PQ - inflated_file_writer.write((char *)(block_inflated_base.get()), - cur_blk_size * dim * sizeof(float)); + inflated_file_writer.write((char *)(block_inflated_base.get()), cur_blk_size * dim * sizeof(float)); #endif - diskann::cout << ".done." << std::endl; - } + diskann::cout << ".done." << std::endl; + } // Gopal. Splitting diskann_dll into separate DLLs for search and build. // This code should only be available in the "build" DLL. -#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && \ - defined(DISKANN_BUILD) - MallocExtension::instance()->ReleaseFreeMemory(); +#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD) + MallocExtension::instance()->ReleaseFreeMemory(); #endif - compressed_file_writer.close(); + compressed_file_writer.close(); #ifdef SAVE_INFLATED_PQ - inflated_file_writer.close(); + inflated_file_writer.close(); #endif - return 0; + return 0; } template -void generate_disk_quantized_data( - const std::string &data_file_to_use, const std::string &disk_pq_pivots_path, - const std::string &disk_pq_compressed_vectors_path, - diskann::Metric compareMetric, const double p_val, size_t &disk_pq_dims) { - size_t train_size, train_dim; - float *train_data; - - // instantiates train_data with random sample updates train_size - gen_random_slice(data_file_to_use.c_str(), p_val, train_data, train_size, - train_dim); - diskann::cout << "Training data with " << train_size << " samples loaded." 
- << std::endl; - - if (disk_pq_dims > train_dim) - disk_pq_dims = train_dim; - - std::cout << "Compressing base for disk-PQ into " << disk_pq_dims - << " chunks " << std::endl; - generate_pq_pivots(train_data, train_size, (uint32_t)train_dim, 256, - (uint32_t)disk_pq_dims, NUM_KMEANS_REPS_PQ, - disk_pq_pivots_path, false); - if (compareMetric == diskann::Metric::INNER_PRODUCT) - generate_pq_data_from_pivots( - data_file_to_use, 256, (uint32_t)disk_pq_dims, disk_pq_pivots_path, - disk_pq_compressed_vectors_path); - else - generate_pq_data_from_pivots(data_file_to_use, 256, - (uint32_t)disk_pq_dims, disk_pq_pivots_path, - disk_pq_compressed_vectors_path); - - delete[] train_data; -} +void generate_disk_quantized_data(const std::string &data_file_to_use, const std::string &disk_pq_pivots_path, + const std::string &disk_pq_compressed_vectors_path, diskann::Metric compareMetric, + const double p_val, size_t &disk_pq_dims) +{ + size_t train_size, train_dim; + float *train_data; -template -void generate_quantized_data(const std::string &data_file_to_use, - const std::string &pq_pivots_path, - const std::string &pq_compressed_vectors_path, - diskann::Metric compareMetric, const double p_val, - const size_t num_pq_chunks, const bool use_opq, - const std::string &codebook_prefix) { - size_t train_size, train_dim; - float *train_data; - if (!file_exists(codebook_prefix)) { // instantiates train_data with random sample updates train_size - gen_random_slice(data_file_to_use.c_str(), p_val, train_data, train_size, - train_dim); - diskann::cout << "Training data with " << train_size << " samples loaded." - << std::endl; + gen_random_slice(data_file_to_use.c_str(), p_val, train_data, train_size, train_dim); + diskann::cout << "Training data with " << train_size << " samples loaded." 
<< std::endl; + + if (disk_pq_dims > train_dim) + disk_pq_dims = train_dim; - bool make_zero_mean = true; + std::cout << "Compressing base for disk-PQ into " << disk_pq_dims << " chunks " << std::endl; + generate_pq_pivots(train_data, train_size, (uint32_t)train_dim, 256, (uint32_t)disk_pq_dims, NUM_KMEANS_REPS_PQ, + disk_pq_pivots_path, false); if (compareMetric == diskann::Metric::INNER_PRODUCT) - make_zero_mean = false; - if (use_opq) // we also do not center the data for OPQ - make_zero_mean = false; - - if (!use_opq) { - generate_pq_pivots(train_data, train_size, (uint32_t)train_dim, - NUM_PQ_CENTROIDS, (uint32_t)num_pq_chunks, - NUM_KMEANS_REPS_PQ, pq_pivots_path, make_zero_mean); - } else { - generate_opq_pivots(train_data, train_size, (uint32_t)train_dim, - NUM_PQ_CENTROIDS, (uint32_t)num_pq_chunks, - pq_pivots_path, make_zero_mean); - } + generate_pq_data_from_pivots(data_file_to_use, 256, (uint32_t)disk_pq_dims, disk_pq_pivots_path, + disk_pq_compressed_vectors_path); + else + generate_pq_data_from_pivots(data_file_to_use, 256, (uint32_t)disk_pq_dims, disk_pq_pivots_path, + disk_pq_compressed_vectors_path); + delete[] train_data; - } else { - diskann::cout << "Skip Training with predefined pivots in: " - << pq_pivots_path << std::endl; - } - generate_pq_data_from_pivots(data_file_to_use, NUM_PQ_CENTROIDS, - (uint32_t)num_pq_chunks, pq_pivots_path, - pq_compressed_vectors_path, use_opq); +} + +template +void generate_quantized_data(const std::string &data_file_to_use, const std::string &pq_pivots_path, + const std::string &pq_compressed_vectors_path, diskann::Metric compareMetric, + const double p_val, const size_t num_pq_chunks, const bool use_opq, + const std::string &codebook_prefix) +{ + size_t train_size, train_dim; + float *train_data; + if (!file_exists(codebook_prefix)) + { + // instantiates train_data with random sample updates train_size + gen_random_slice(data_file_to_use.c_str(), p_val, train_data, train_size, train_dim); + diskann::cout << "Training data with " << train_size << " samples loaded." 
<< std::endl; + + bool make_zero_mean = true; + if (compareMetric == diskann::Metric::INNER_PRODUCT) + make_zero_mean = false; + if (use_opq) // we also do not center the data for OPQ + make_zero_mean = false; + + if (!use_opq) + { + generate_pq_pivots(train_data, train_size, (uint32_t)train_dim, NUM_PQ_CENTROIDS, (uint32_t)num_pq_chunks, + NUM_KMEANS_REPS_PQ, pq_pivots_path, make_zero_mean); + } + else + { + generate_opq_pivots(train_data, train_size, (uint32_t)train_dim, NUM_PQ_CENTROIDS, (uint32_t)num_pq_chunks, + pq_pivots_path, make_zero_mean); + } + delete[] train_data; + } + else + { + diskann::cout << "Skip Training with predefined pivots in: " << pq_pivots_path << std::endl; + } + generate_pq_data_from_pivots(data_file_to_use, NUM_PQ_CENTROIDS, (uint32_t)num_pq_chunks, pq_pivots_path, + pq_compressed_vectors_path, use_opq); } // Instantations of supported templates -template DISKANN_DLLEXPORT int generate_pq_data_from_pivots( - const std::string &data_file, uint32_t num_centers, uint32_t num_pq_chunks, - const std::string &pq_pivots_path, - const std::string &pq_compressed_vectors_path, bool use_opq); -template DISKANN_DLLEXPORT int generate_pq_data_from_pivots( - const std::string &data_file, uint32_t num_centers, uint32_t num_pq_chunks, - const std::string &pq_pivots_path, - const std::string &pq_compressed_vectors_path, bool use_opq); -template DISKANN_DLLEXPORT int generate_pq_data_from_pivots( - const std::string &data_file, uint32_t num_centers, uint32_t num_pq_chunks, - const std::string &pq_pivots_path, - const std::string &pq_compressed_vectors_path, bool use_opq); - -template DISKANN_DLLEXPORT void generate_disk_quantized_data( - const std::string &data_file_to_use, const std::string &disk_pq_pivots_path, - const std::string &disk_pq_compressed_vectors_path, - diskann::Metric compareMetric, const double p_val, size_t &disk_pq_dims); +template DISKANN_DLLEXPORT int generate_pq_data_from_pivots(const std::string &data_file, uint32_t num_centers, + uint32_t num_pq_chunks, + const std::string &pq_pivots_path, + const std::string &pq_compressed_vectors_path, + bool use_opq); +template DISKANN_DLLEXPORT int generate_pq_data_from_pivots(const std::string &data_file, uint32_t num_centers, + uint32_t num_pq_chunks, + const std::string &pq_pivots_path, + const std::string &pq_compressed_vectors_path, + bool use_opq); +template DISKANN_DLLEXPORT int generate_pq_data_from_pivots(const std::string &data_file, uint32_t num_centers, + uint32_t num_pq_chunks, + const std::string &pq_pivots_path, + const std::string &pq_compressed_vectors_path, + bool use_opq); + +template DISKANN_DLLEXPORT void generate_disk_quantized_data(const std::string &data_file_to_use, + const std::string &disk_pq_pivots_path, + const std::string &disk_pq_compressed_vectors_path, + diskann::Metric compareMetric, const double p_val, + size_t &disk_pq_dims); template DISKANN_DLLEXPORT void generate_disk_quantized_data( const std::string &data_file_to_use, const std::string &disk_pq_pivots_path, - const std::string &disk_pq_compressed_vectors_path, - diskann::Metric compareMetric, const double p_val, size_t &disk_pq_dims); - -template DISKANN_DLLEXPORT void generate_disk_quantized_data( - const std::string &data_file_to_use, const std::string &disk_pq_pivots_path, - const std::string &disk_pq_compressed_vectors_path, - diskann::Metric compareMetric, const double p_val, size_t &disk_pq_dims); - -template DISKANN_DLLEXPORT void generate_quantized_data( - const std::string &data_file_to_use, const std::string 
&pq_pivots_path, - const std::string &pq_compressed_vectors_path, - diskann::Metric compareMetric, const double p_val, - const size_t num_pq_chunks, const bool use_opq, - const std::string &codebook_prefix); - -template DISKANN_DLLEXPORT void generate_quantized_data( - const std::string &data_file_to_use, const std::string &pq_pivots_path, - const std::string &pq_compressed_vectors_path, - diskann::Metric compareMetric, const double p_val, - const size_t num_pq_chunks, const bool use_opq, - const std::string &codebook_prefix); - -template DISKANN_DLLEXPORT void generate_quantized_data( - const std::string &data_file_to_use, const std::string &pq_pivots_path, - const std::string &pq_compressed_vectors_path, - diskann::Metric compareMetric, const double p_val, - const size_t num_pq_chunks, const bool use_opq, - const std::string &codebook_prefix); + const std::string &disk_pq_compressed_vectors_path, diskann::Metric compareMetric, const double p_val, + size_t &disk_pq_dims); + +template DISKANN_DLLEXPORT void generate_disk_quantized_data(const std::string &data_file_to_use, + const std::string &disk_pq_pivots_path, + const std::string &disk_pq_compressed_vectors_path, + diskann::Metric compareMetric, const double p_val, + size_t &disk_pq_dims); + +template DISKANN_DLLEXPORT void generate_quantized_data(const std::string &data_file_to_use, + const std::string &pq_pivots_path, + const std::string &pq_compressed_vectors_path, + diskann::Metric compareMetric, const double p_val, + const size_t num_pq_chunks, const bool use_opq, + const std::string &codebook_prefix); + +template DISKANN_DLLEXPORT void generate_quantized_data(const std::string &data_file_to_use, + const std::string &pq_pivots_path, + const std::string &pq_compressed_vectors_path, + diskann::Metric compareMetric, const double p_val, + const size_t num_pq_chunks, const bool use_opq, + const std::string &codebook_prefix); + +template DISKANN_DLLEXPORT void generate_quantized_data(const std::string &data_file_to_use, + const std::string &pq_pivots_path, + const std::string &pq_compressed_vectors_path, + diskann::Metric compareMetric, const double p_val, + const size_t num_pq_chunks, const bool use_opq, + const std::string &codebook_prefix); } // namespace diskann diff --git a/src/pq_data_store.cpp b/src/pq_data_store.cpp index 167215fe8..16e94d40e 100644 --- a/src/pq_data_store.cpp +++ b/src/pq_data_store.cpp @@ -6,260 +6,251 @@ #include "pq_scratch.h" #include "utils.h" -namespace diskann { +namespace diskann +{ // REFACTOR TODO: Assuming that num_pq_chunks is known already. Must verify if // this is true. 
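// For orientation before the implementation below: the store keeps only the compressed
// codes. With 256 centers per chunk (as used throughout this file) each point occupies
// num_pq_chunks bytes in _quantized_data, versus dim * sizeof(float) bytes for the raw
// vector. A hedged, standalone sketch of that bookkeeping; the sizes are made up and not
// tied to any particular index:
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main()
{
    const std::size_t num_points = 1000000;
    const std::size_t dim = 128;
    const std::size_t num_pq_chunks = 32; // one uint8_t code per chunk when centers <= 256
    const std::size_t raw_bytes = num_points * dim * sizeof(float);
    const std::size_t pq_bytes = num_points * num_pq_chunks * sizeof(std::uint8_t);
    std::printf("raw: %zu MiB, pq codes: %zu MiB, ratio: %.1fx\n", raw_bytes >> 20, pq_bytes >> 20,
                (double)raw_bytes / (double)pq_bytes);
    return 0;
}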
template -PQDataStore::PQDataStore( - size_t dim, location_t num_points, size_t num_pq_chunks, - std::unique_ptr> distance_fn, - std::unique_ptr> pq_distance_fn) - : AbstractDataStore(num_points, dim), _quantized_data(nullptr), - _num_chunks(num_pq_chunks), _distance_metric(distance_fn->get_metric()) { - if (num_pq_chunks > dim) { - throw diskann::ANNException("ERROR: num_pq_chunks > dim", -1, __FUNCSIG__, - __FILE__, __LINE__); - } - _distance_fn = std::move(distance_fn); - _pq_distance_fn = std::move(pq_distance_fn); +PQDataStore::PQDataStore(size_t dim, location_t num_points, size_t num_pq_chunks, + std::unique_ptr> distance_fn, + std::unique_ptr> pq_distance_fn) + : AbstractDataStore(num_points, dim), _quantized_data(nullptr), _num_chunks(num_pq_chunks), + _distance_metric(distance_fn->get_metric()) +{ + if (num_pq_chunks > dim) + { + throw diskann::ANNException("ERROR: num_pq_chunks > dim", -1, __FUNCSIG__, __FILE__, __LINE__); + } + _distance_fn = std::move(distance_fn); + _pq_distance_fn = std::move(pq_distance_fn); } -template PQDataStore::~PQDataStore() { - if (_quantized_data != nullptr) { - aligned_free(_quantized_data); - _quantized_data = nullptr; - } +template PQDataStore::~PQDataStore() +{ + if (_quantized_data != nullptr) + { + aligned_free(_quantized_data); + _quantized_data = nullptr; + } } -template -location_t PQDataStore::load(const std::string &filename) { - return load_impl(filename); +template location_t PQDataStore::load(const std::string &filename) +{ + return load_impl(filename); } -template -size_t PQDataStore::save(const std::string &filename, - const location_t num_points) { - return diskann::save_bin(filename, _quantized_data, this->capacity(), - _num_chunks, 0); +template size_t PQDataStore::save(const std::string &filename, const location_t num_points) +{ + return diskann::save_bin(filename, _quantized_data, this->capacity(), _num_chunks, 0); } -template size_t PQDataStore::get_aligned_dim() const { - return this->get_dims(); +template size_t PQDataStore::get_aligned_dim() const +{ + return this->get_dims(); } // Populate quantized data from regular data. -template -void PQDataStore::populate_data(const data_t *vectors, - const location_t num_pts) { - throw std::logic_error("Not implemented yet"); +template void PQDataStore::populate_data(const data_t *vectors, const location_t num_pts) +{ + throw std::logic_error("Not implemented yet"); } -template -void PQDataStore::populate_data(const std::string &filename, - const size_t offset) { - if (_quantized_data != nullptr) { - aligned_free(_quantized_data); - } - - uint64_t file_num_points = 0, file_dim = 0; - get_bin_metadata(filename, file_num_points, file_dim, offset); - this->_capacity = (location_t)file_num_points; - this->_dim = file_dim; - - double p_val = std::min( - 1.0, ((double)MAX_PQ_TRAINING_SET_SIZE / (double)file_num_points)); - - auto pivots_file = _pq_distance_fn->get_pivot_data_filename(filename); - auto compressed_file = - _pq_distance_fn->get_quantized_vectors_filename(filename); - - generate_quantized_data(filename, pivots_file, compressed_file, - _distance_metric, p_val, _num_chunks, - _pq_distance_fn->is_opq()); - - // REFACTOR TODO: Not sure of the alignment. 
Just copying from index.cpp - alloc_aligned(((void **)&_quantized_data), - file_num_points * _num_chunks * sizeof(uint8_t), 1); - copy_aligned_data_from_file(compressed_file.c_str(), _quantized_data, - file_num_points, _num_chunks, - _num_chunks); +template void PQDataStore::populate_data(const std::string &filename, const size_t offset) +{ + if (_quantized_data != nullptr) + { + aligned_free(_quantized_data); + } + + uint64_t file_num_points = 0, file_dim = 0; + get_bin_metadata(filename, file_num_points, file_dim, offset); + this->_capacity = (location_t)file_num_points; + this->_dim = file_dim; + + double p_val = std::min(1.0, ((double)MAX_PQ_TRAINING_SET_SIZE / (double)file_num_points)); + + auto pivots_file = _pq_distance_fn->get_pivot_data_filename(filename); + auto compressed_file = _pq_distance_fn->get_quantized_vectors_filename(filename); + + generate_quantized_data(filename, pivots_file, compressed_file, _distance_metric, p_val, _num_chunks, + _pq_distance_fn->is_opq()); + + // REFACTOR TODO: Not sure of the alignment. Just copying from index.cpp + alloc_aligned(((void **)&_quantized_data), file_num_points * _num_chunks * sizeof(uint8_t), 1); + copy_aligned_data_from_file(compressed_file.c_str(), _quantized_data, file_num_points, _num_chunks, + _num_chunks); #ifdef EXEC_ENV_OLS - throw ANNException("load_pq_centroid_bin should not be called when " - "EXEC_ENV_OLS is defined.", - -1, __FUNCSIG__, __FILE__, __LINE__); + throw ANNException("load_pq_centroid_bin should not be called when " + "EXEC_ENV_OLS is defined.", + -1, __FUNCSIG__, __FILE__, __LINE__); #else - _pq_distance_fn->load_pivot_data(pivots_file.c_str(), _num_chunks); + _pq_distance_fn->load_pivot_data(pivots_file.c_str(), _num_chunks); #endif } template -void PQDataStore::extract_data_to_bin(const std::string &filename, - const location_t num_pts) { - throw std::logic_error("Not implemented yet"); +void PQDataStore::extract_data_to_bin(const std::string &filename, const location_t num_pts) +{ + throw std::logic_error("Not implemented yet"); } -template -void PQDataStore::get_vector(const location_t i, data_t *target) const { - // REFACTOR TODO: Should we inflate the compressed vector here? - if (i < this->capacity()) { - throw std::logic_error("Not implemented yet."); - } else { - std::stringstream ss; - ss << "Requested vector " << i << " but only " << this->capacity() - << " vectors are present"; - throw diskann::ANNException(ss.str(), -1); - } +template void PQDataStore::get_vector(const location_t i, data_t *target) const +{ + // REFACTOR TODO: Should we inflate the compressed vector here? + if (i < this->capacity()) + { + throw std::logic_error("Not implemented yet."); + } + else + { + std::stringstream ss; + ss << "Requested vector " << i << " but only " << this->capacity() << " vectors are present"; + throw diskann::ANNException(ss.str(), -1); + } } -template -void PQDataStore::set_vector(const location_t i, - const data_t *const vector) { - // REFACTOR TODO: Should we accept a normal vector and compress here? - // memcpy (_data + i * _num_chunks, vector, _num_chunks * sizeof(data_t)); - throw std::logic_error("Not implemented yet"); +template void PQDataStore::set_vector(const location_t i, const data_t *const vector) +{ + // REFACTOR TODO: Should we accept a normal vector and compress here? 
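// One plausible answer to the question above is the per-chunk nearest-centroid assignment
// already used by the PQ generation code earlier in this patch. Hedged, self-contained
// sketch only (assumes num_chunks divides dim, 256 centers, and the centroid-major
// [center * dim + d] pivot layout; pq_encode is a hypothetical name, not the DiskANN API):
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Encode one vector of length dim into num_chunks one-byte codes.
std::vector<std::uint8_t> pq_encode(const std::vector<float> &vec, const std::vector<float> &pivots, std::size_t dim,
                                    std::size_t num_chunks)
{
    const std::size_t num_centers = 256;
    const std::size_t chunk_size = dim / num_chunks;
    std::vector<std::uint8_t> codes(num_chunks);
    for (std::size_t c = 0; c < num_chunks; c++)
    {
        std::size_t best = 0;
        float best_dist = std::numeric_limits<float>::max();
        for (std::size_t k = 0; k < num_centers; k++)
        {
            float dist = 0;
            for (std::size_t d = 0; d < chunk_size; d++)
            {
                float diff = vec[c * chunk_size + d] - pivots[k * dim + c * chunk_size + d];
                dist += diff * diff;
            }
            if (dist < best_dist)
            {
                best_dist = dist;
                best = k;
            }
        }
        codes[c] = (std::uint8_t)best;
    }
    return codes;
}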
+ // memcpy (_data + i * _num_chunks, vector, _num_chunks * sizeof(data_t)); + throw std::logic_error("Not implemented yet"); } -template -void PQDataStore::prefetch_vector(const location_t loc) { - const uint8_t *ptr = - _quantized_data + ((size_t)loc) * _num_chunks * sizeof(data_t); - diskann::prefetch_vector((const char *)ptr, _num_chunks * sizeof(data_t)); +template void PQDataStore::prefetch_vector(const location_t loc) +{ + const uint8_t *ptr = _quantized_data + ((size_t)loc) * _num_chunks * sizeof(data_t); + diskann::prefetch_vector((const char *)ptr, _num_chunks * sizeof(data_t)); } template -void PQDataStore::move_vectors(const location_t old_location_start, - const location_t new_location_start, - const location_t num_points) { - // REFACTOR TODO: Moving vectors is only for in-mem fresh. - throw std::logic_error("Not implemented yet"); +void PQDataStore::move_vectors(const location_t old_location_start, const location_t new_location_start, + const location_t num_points) +{ + // REFACTOR TODO: Moving vectors is only for in-mem fresh. + throw std::logic_error("Not implemented yet"); } template -void PQDataStore::copy_vectors(const location_t from_loc, - const location_t to_loc, - const location_t num_points) { - // REFACTOR TODO: Is the number of bytes correct? - memcpy(_quantized_data + to_loc * _num_chunks, - _quantized_data + from_loc * _num_chunks, _num_chunks * num_points); +void PQDataStore::copy_vectors(const location_t from_loc, const location_t to_loc, const location_t num_points) +{ + // REFACTOR TODO: Is the number of bytes correct? + memcpy(_quantized_data + to_loc * _num_chunks, _quantized_data + from_loc * _num_chunks, _num_chunks * num_points); } // REFACTOR TODO: Currently, we take aligned_query as parameter, but this // function should also do the alignment. 
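// The PQ distance path below (preprocess_query followed by preprocessed_distance) follows
// the classic lookup-table ("asymmetric") scheme: per query, build a (num_chunks x 256)
// table of squared distances from each query sub-vector to every centroid of that chunk,
// then score a point by summing one table entry per chunk of its code. Generic, hedged
// sketch, independent of the PQScratch plumbing; the same layout assumptions as above:
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<float> build_query_table(const std::vector<float> &query, const std::vector<float> &pivots,
                                     std::size_t dim, std::size_t num_chunks)
{
    const std::size_t num_centers = 256;
    const std::size_t chunk_size = dim / num_chunks;
    std::vector<float> table(num_chunks * num_centers, 0.0f);
    for (std::size_t c = 0; c < num_chunks; c++)
        for (std::size_t k = 0; k < num_centers; k++)
            for (std::size_t d = 0; d < chunk_size; d++)
            {
                float diff = query[c * chunk_size + d] - pivots[k * dim + c * chunk_size + d];
                table[c * num_centers + k] += diff * diff;
            }
    return table;
}

// Approximate distance from the query to a point given its one-byte-per-chunk codes.
float adc_distance(const std::uint8_t *code, const std::vector<float> &table, std::size_t num_chunks)
{
    float dist = 0;
    for (std::size_t c = 0; c < num_chunks; c++)
        dist += table[c * 256 + code[c]]; // 256 matches num_centers used when building the table
    return dist;
}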
template -void PQDataStore::preprocess_query( - const data_t *aligned_query, AbstractScratch *scratch) const { - if (scratch == nullptr) { - throw diskann::ANNException("Scratch space is null", -1); - } +void PQDataStore::preprocess_query(const data_t *aligned_query, AbstractScratch *scratch) const +{ + if (scratch == nullptr) + { + throw diskann::ANNException("Scratch space is null", -1); + } - PQScratch *pq_scratch = scratch->pq_scratch(); + PQScratch *pq_scratch = scratch->pq_scratch(); - if (pq_scratch == nullptr) { - throw diskann::ANNException( - "PQScratch space has not been set in the scratch object.", -1); - } + if (pq_scratch == nullptr) + { + throw diskann::ANNException("PQScratch space has not been set in the scratch object.", -1); + } - _pq_distance_fn->preprocess_query(aligned_query, (location_t)this->get_dims(), - *pq_scratch); + _pq_distance_fn->preprocess_query(aligned_query, (location_t)this->get_dims(), *pq_scratch); } -template -float PQDataStore::get_distance(const data_t *query, - const location_t loc) const { - throw std::logic_error("Not implemented yet"); +template float PQDataStore::get_distance(const data_t *query, const location_t loc) const +{ + throw std::logic_error("Not implemented yet"); } -template -float PQDataStore::get_distance(const location_t loc1, - const location_t loc2) const { - throw std::logic_error("Not implemented yet"); +template float PQDataStore::get_distance(const location_t loc1, const location_t loc2) const +{ + throw std::logic_error("Not implemented yet"); } template -void PQDataStore::get_distance( - const data_t *preprocessed_query, const location_t *locations, - const uint32_t location_count, float *distances, - AbstractScratch *scratch_space) const { - if (scratch_space == nullptr) { - throw diskann::ANNException("Scratch space is null", -1); - } - PQScratch *pq_scratch = scratch_space->pq_scratch(); - if (pq_scratch == nullptr) { - throw diskann::ANNException("PQScratch not set in scratch space.", -1); - } - diskann::aggregate_coords(locations, location_count, _quantized_data, - this->_num_chunks, - pq_scratch->aligned_pq_coord_scratch); - _pq_distance_fn->preprocessed_distance(*pq_scratch, location_count, - distances); +void PQDataStore::get_distance(const data_t *preprocessed_query, const location_t *locations, + const uint32_t location_count, float *distances, + AbstractScratch *scratch_space) const +{ + if (scratch_space == nullptr) + { + throw diskann::ANNException("Scratch space is null", -1); + } + PQScratch *pq_scratch = scratch_space->pq_scratch(); + if (pq_scratch == nullptr) + { + throw diskann::ANNException("PQScratch not set in scratch space.", -1); + } + diskann::aggregate_coords(locations, location_count, _quantized_data, this->_num_chunks, + pq_scratch->aligned_pq_coord_scratch); + _pq_distance_fn->preprocessed_distance(*pq_scratch, location_count, distances); } template -void PQDataStore::get_distance( - const data_t *preprocessed_query, const std::vector &ids, - std::vector &distances, - AbstractScratch *scratch_space) const { - if (scratch_space == nullptr) { - throw diskann::ANNException("Scratch space is null", -1); - } - PQScratch *pq_scratch = scratch_space->pq_scratch(); - if (pq_scratch == nullptr) { - throw diskann::ANNException("PQScratch not set in scratch space.", -1); - } - diskann::aggregate_coords(ids, _quantized_data, this->_num_chunks, - pq_scratch->aligned_pq_coord_scratch); - _pq_distance_fn->preprocessed_distance(*pq_scratch, (location_t)ids.size(), - distances); +void 
PQDataStore::get_distance(const data_t *preprocessed_query, const std::vector &ids, + std::vector &distances, AbstractScratch *scratch_space) const +{ + if (scratch_space == nullptr) + { + throw diskann::ANNException("Scratch space is null", -1); + } + PQScratch *pq_scratch = scratch_space->pq_scratch(); + if (pq_scratch == nullptr) + { + throw diskann::ANNException("PQScratch not set in scratch space.", -1); + } + diskann::aggregate_coords(ids, _quantized_data, this->_num_chunks, pq_scratch->aligned_pq_coord_scratch); + _pq_distance_fn->preprocessed_distance(*pq_scratch, (location_t)ids.size(), distances); } -template -location_t PQDataStore::calculate_medoid() const { - // REFACTOR TODO: Must calculate this just like we do with data store. - size_t r = (size_t)rand() * (size_t)RAND_MAX + (size_t)rand(); - return (uint32_t)(r % (size_t)this->capacity()); +template location_t PQDataStore::calculate_medoid() const +{ + // REFACTOR TODO: Must calculate this just like we do with data store. + size_t r = (size_t)rand() * (size_t)RAND_MAX + (size_t)rand(); + return (uint32_t)(r % (size_t)this->capacity()); } -template -size_t PQDataStore::get_alignment_factor() const { - return 1; +template size_t PQDataStore::get_alignment_factor() const +{ + return 1; } -template -Distance *PQDataStore::get_dist_fn() const { - return _distance_fn.get(); +template Distance *PQDataStore::get_dist_fn() const +{ + return _distance_fn.get(); } -template -location_t PQDataStore::load_impl(const std::string &file_prefix) { - if (_quantized_data != nullptr) { - aligned_free(_quantized_data); - } - auto quantized_vectors_file = - _pq_distance_fn->get_quantized_vectors_filename(file_prefix); - - size_t num_points; - load_aligned_bin(quantized_vectors_file, _quantized_data, num_points, - _num_chunks, _num_chunks); - this->_capacity = (location_t)num_points; - - auto pivots_file = _pq_distance_fn->get_pivot_data_filename(file_prefix); - _pq_distance_fn->load_pivot_data(pivots_file, _num_chunks); - - return this->_capacity; +template location_t PQDataStore::load_impl(const std::string &file_prefix) +{ + if (_quantized_data != nullptr) + { + aligned_free(_quantized_data); + } + auto quantized_vectors_file = _pq_distance_fn->get_quantized_vectors_filename(file_prefix); + + size_t num_points; + load_aligned_bin(quantized_vectors_file, _quantized_data, num_points, _num_chunks, _num_chunks); + this->_capacity = (location_t)num_points; + + auto pivots_file = _pq_distance_fn->get_pivot_data_filename(file_prefix); + _pq_distance_fn->load_pivot_data(pivots_file, _num_chunks); + + return this->_capacity; } -template -location_t PQDataStore::expand(const location_t new_size) { - throw std::logic_error("Not implemented yet"); +template location_t PQDataStore::expand(const location_t new_size) +{ + throw std::logic_error("Not implemented yet"); } -template -location_t PQDataStore::shrink(const location_t new_size) { - throw std::logic_error("Not implemented yet"); +template location_t PQDataStore::shrink(const location_t new_size) +{ + throw std::logic_error("Not implemented yet"); } #ifdef EXEC_ENV_OLS -template -location_t PQDataStore::load_impl(AlignedFileReader &reader) {} +template location_t PQDataStore::load_impl(AlignedFileReader &reader) +{ +} #endif template DISKANN_DLLEXPORT class PQDataStore; diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index e5a460601..80c1c7460 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -21,1388 +21,1437 @@ #define READ_UNSIGNED(stream, val) stream.read((char 
*)&val, sizeof(unsigned)) // sector # beyond the end of graph where data for id is present for reordering -#define VECTOR_SECTOR_NO(id) \ - (((uint64_t)(id)) / _nvecs_per_sector + _reorder_data_start_sector) +#define VECTOR_SECTOR_NO(id) (((uint64_t)(id)) / _nvecs_per_sector + _reorder_data_start_sector) // sector # beyond the end of graph where data for id is present for reordering -#define VECTOR_SECTOR_OFFSET(id) \ - ((((uint64_t)(id)) % _nvecs_per_sector) * _data_dim * sizeof(float)) +#define VECTOR_SECTOR_OFFSET(id) ((((uint64_t)(id)) % _nvecs_per_sector) * _data_dim * sizeof(float)) -namespace diskann { +namespace diskann +{ template -PQFlashIndex::PQFlashIndex( - std::shared_ptr &fileReader, diskann::Metric m) - : reader(fileReader), metric(m), _thread_data(nullptr) { - diskann::Metric metric_to_invoke = m; - if (m == diskann::Metric::COSINE || m == diskann::Metric::INNER_PRODUCT) { - if (std::is_floating_point::value) { - diskann::cout << "Since data is floating point, we assume that it has " - "been appropriately pre-processed " - "(normalization for cosine, and convert-to-l2 by " - "adding extra dimension for MIPS). So we " - "shall invoke an l2 distance function." - << std::endl; - metric_to_invoke = diskann::Metric::L2; - } else { - diskann::cerr << "WARNING: Cannot normalize integral data types." - << " This may result in erroneous results or poor recall." - << " Consider using L2 distance with integral data types." - << std::endl; +PQFlashIndex::PQFlashIndex(std::shared_ptr &fileReader, diskann::Metric m) + : reader(fileReader), metric(m), _thread_data(nullptr) +{ + diskann::Metric metric_to_invoke = m; + if (m == diskann::Metric::COSINE || m == diskann::Metric::INNER_PRODUCT) + { + if (std::is_floating_point::value) + { + diskann::cout << "Since data is floating point, we assume that it has " + "been appropriately pre-processed " + "(normalization for cosine, and convert-to-l2 by " + "adding extra dimension for MIPS). So we " + "shall invoke an l2 distance function." + << std::endl; + metric_to_invoke = diskann::Metric::L2; + } + else + { + diskann::cerr << "WARNING: Cannot normalize integral data types." + << " This may result in erroneous results or poor recall." + << " Consider using L2 distance with integral data types." 
<< std::endl; + } } - } - this->_dist_cmp.reset(diskann::get_distance_function(metric_to_invoke)); - this->_dist_cmp_float.reset( - diskann::get_distance_function(metric_to_invoke)); - this->_filter_store = std::make_unique>(); + this->_dist_cmp.reset(diskann::get_distance_function(metric_to_invoke)); + this->_dist_cmp_float.reset(diskann::get_distance_function(metric_to_invoke)); + this->_filter_store = std::make_unique>(); } -template -PQFlashIndex::~PQFlashIndex() { +template PQFlashIndex::~PQFlashIndex() +{ #ifndef EXEC_ENV_OLS - if (data != nullptr) { - delete[] data; - } + if (data != nullptr) + { + delete[] data; + } #endif - if (_centroid_data != nullptr) - aligned_free(_centroid_data); - // delete backing bufs for nhood and coord cache - if (_nhood_cache_buf != nullptr) { - delete[] _nhood_cache_buf; - diskann::aligned_free(_coord_cache_buf); - } - - if (_medoids != nullptr) { - delete[] _medoids; - _medoids = nullptr; - } - - if (_load_flag) { - diskann::cout << "Clearing scratch" << std::endl; - ScratchStoreManager> manager(this->_thread_data); - manager.destroy(); - this->reader->deregister_all_threads(); - reader->close(); - } + if (_centroid_data != nullptr) + aligned_free(_centroid_data); + // delete backing bufs for nhood and coord cache + if (_nhood_cache_buf != nullptr) + { + delete[] _nhood_cache_buf; + diskann::aligned_free(_coord_cache_buf); + } + + if (_medoids != nullptr) + { + delete[] _medoids; + _medoids = nullptr; + } + + if (_load_flag) + { + diskann::cout << "Clearing scratch" << std::endl; + ScratchStoreManager> manager(this->_thread_data); + manager.destroy(); + this->reader->deregister_all_threads(); + reader->close(); + } } -template -inline uint64_t PQFlashIndex::get_node_sector(uint64_t node_id) { - return 1 + - (_nnodes_per_sector > 0 - ? node_id / _nnodes_per_sector - : node_id * DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN)); +template inline uint64_t PQFlashIndex::get_node_sector(uint64_t node_id) +{ + return 1 + (_nnodes_per_sector > 0 ? node_id / _nnodes_per_sector + : node_id * DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN)); } template -inline char *PQFlashIndex::offset_to_node(char *sector_buf, - uint64_t node_id) { - return sector_buf + (_nnodes_per_sector == 0 - ? 0 - : (node_id % _nnodes_per_sector) * _max_node_len); +inline char *PQFlashIndex::offset_to_node(char *sector_buf, uint64_t node_id) +{ + return sector_buf + (_nnodes_per_sector == 0 ? 
0 : (node_id % _nnodes_per_sector) * _max_node_len); } -template -inline uint32_t *PQFlashIndex::offset_to_node_nhood(char *node_buf) { - return (unsigned *)(node_buf + _disk_bytes_per_point); +template inline uint32_t *PQFlashIndex::offset_to_node_nhood(char *node_buf) +{ + return (unsigned *)(node_buf + _disk_bytes_per_point); } -template -inline T *PQFlashIndex::offset_to_node_coords(char *node_buf) { - return (T *)(node_buf); +template inline T *PQFlashIndex::offset_to_node_coords(char *node_buf) +{ + return (T *)(node_buf); } template -void PQFlashIndex::setup_thread_data(uint64_t nthreads, - uint64_t visited_reserve) { - diskann::cout << "Setting up thread-specific contexts for nthreads: " - << nthreads << std::endl; +void PQFlashIndex::setup_thread_data(uint64_t nthreads, uint64_t visited_reserve) +{ + diskann::cout << "Setting up thread-specific contexts for nthreads: " << nthreads << std::endl; // omp parallel for to generate unique thread IDs #pragma omp parallel for num_threads((int)nthreads) - for (int64_t thread = 0; thread < (int64_t)nthreads; thread++) { -#pragma omp critical + for (int64_t thread = 0; thread < (int64_t)nthreads; thread++) { - SSDThreadData *data = - new SSDThreadData(this->_aligned_dim, visited_reserve); - this->reader->register_thread(); - data->ctx = this->reader->get_ctx(); - this->_thread_data.push(data); +#pragma omp critical + { + SSDThreadData *data = new SSDThreadData(this->_aligned_dim, visited_reserve); + this->reader->register_thread(); + data->ctx = this->reader->get_ctx(); + this->_thread_data.push(data); + } } - } - _load_flag = true; + _load_flag = true; } template -std::vector PQFlashIndex::read_nodes( - const std::vector &node_ids, std::vector &coord_buffers, - std::vector> &nbr_buffers) { - std::vector read_reqs; - std::vector retval(node_ids.size(), true); - - char *buf = nullptr; - auto num_sectors = _nnodes_per_sector > 0 - ? 1 - : DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN); - alloc_aligned((void **)&buf, - node_ids.size() * num_sectors * defaults::SECTOR_LEN, - defaults::SECTOR_LEN); - - // create read requests - for (size_t i = 0; i < node_ids.size(); ++i) { - auto node_id = node_ids[i]; - - AlignedRead read; - read.len = num_sectors * defaults::SECTOR_LEN; - read.buf = buf + i * num_sectors * defaults::SECTOR_LEN; - read.offset = get_node_sector(node_id) * defaults::SECTOR_LEN; - read_reqs.push_back(read); - } - - // borrow thread data and issue reads - ScratchStoreManager> manager(this->_thread_data); - auto this_thread_data = manager.scratch_space(); - IOContext &ctx = this_thread_data->ctx; - reader->read(read_reqs, ctx); - - // copy reads into buffers - for (uint32_t i = 0; i < read_reqs.size(); i++) { -#if defined(_WINDOWS) && \ - defined(USE_BING_INFRA) // this block is to handle failed reads in - // production settings - if ((*ctx.m_pRequestsStatus)[i] != IOContext::READ_SUCCESS) { - retval[i] = false; - continue; +std::vector PQFlashIndex::read_nodes(const std::vector &node_ids, + std::vector &coord_buffers, + std::vector> &nbr_buffers) +{ + std::vector read_reqs; + std::vector retval(node_ids.size(), true); + + char *buf = nullptr; + auto num_sectors = _nnodes_per_sector > 0 ? 
1 : DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN); + alloc_aligned((void **)&buf, node_ids.size() * num_sectors * defaults::SECTOR_LEN, defaults::SECTOR_LEN); + + // create read requests + for (size_t i = 0; i < node_ids.size(); ++i) + { + auto node_id = node_ids[i]; + + AlignedRead read; + read.len = num_sectors * defaults::SECTOR_LEN; + read.buf = buf + i * num_sectors * defaults::SECTOR_LEN; + read.offset = get_node_sector(node_id) * defaults::SECTOR_LEN; + read_reqs.push_back(read); } + + // borrow thread data and issue reads + ScratchStoreManager> manager(this->_thread_data); + auto this_thread_data = manager.scratch_space(); + IOContext &ctx = this_thread_data->ctx; + reader->read(read_reqs, ctx); + + // copy reads into buffers + for (uint32_t i = 0; i < read_reqs.size(); i++) + { +#if defined(_WINDOWS) && defined(USE_BING_INFRA) // this block is to handle failed reads in + // production settings + if ((*ctx.m_pRequestsStatus)[i] != IOContext::READ_SUCCESS) + { + retval[i] = false; + continue; + } #endif - char *node_buf = offset_to_node((char *)read_reqs[i].buf, node_ids[i]); + char *node_buf = offset_to_node((char *)read_reqs[i].buf, node_ids[i]); - if (coord_buffers[i] != nullptr) { - T *node_coords = offset_to_node_coords(node_buf); - memcpy(coord_buffers[i], node_coords, _disk_bytes_per_point); - } + if (coord_buffers[i] != nullptr) + { + T *node_coords = offset_to_node_coords(node_buf); + memcpy(coord_buffers[i], node_coords, _disk_bytes_per_point); + } - if (nbr_buffers[i].second != nullptr) { - uint32_t *node_nhood = offset_to_node_nhood(node_buf); - auto num_nbrs = *node_nhood; - nbr_buffers[i].first = num_nbrs; - memcpy(nbr_buffers[i].second, node_nhood + 1, - num_nbrs * sizeof(uint32_t)); + if (nbr_buffers[i].second != nullptr) + { + uint32_t *node_nhood = offset_to_node_nhood(node_buf); + auto num_nbrs = *node_nhood; + nbr_buffers[i].first = num_nbrs; + memcpy(nbr_buffers[i].second, node_nhood + 1, num_nbrs * sizeof(uint32_t)); + } } - } - aligned_free(buf); + aligned_free(buf); - return retval; + return retval; } -template -void PQFlashIndex::load_cache_list( - std::vector &node_list) { - diskann::cout << "Loading the cache list into memory.." << std::flush; - size_t num_cached_nodes = node_list.size(); - - // Allocate space for neighborhood cache - _nhood_cache_buf = new uint32_t[num_cached_nodes * (_max_degree + 1)]; - memset(_nhood_cache_buf, 0, num_cached_nodes * (_max_degree + 1)); - - // Allocate space for coordinate cache - size_t coord_cache_buf_len = num_cached_nodes * _aligned_dim; - diskann::alloc_aligned((void **)&_coord_cache_buf, - coord_cache_buf_len * sizeof(T), 8 * sizeof(T)); - memset(_coord_cache_buf, 0, coord_cache_buf_len * sizeof(T)); - - size_t BLOCK_SIZE = 8; - size_t num_blocks = DIV_ROUND_UP(num_cached_nodes, BLOCK_SIZE); - for (size_t block = 0; block < num_blocks; block++) { - size_t start_idx = block * BLOCK_SIZE; - size_t end_idx = (std::min)(num_cached_nodes, (block + 1) * BLOCK_SIZE); - - // Copy offset into buffers to read into - std::vector nodes_to_read; - std::vector coord_buffers; - std::vector> nbr_buffers; - for (size_t node_idx = start_idx; node_idx < end_idx; node_idx++) { - nodes_to_read.push_back(node_list[node_idx]); - coord_buffers.push_back(_coord_cache_buf + node_idx * _aligned_dim); - nbr_buffers.emplace_back(0, - _nhood_cache_buf + node_idx * (_max_degree + 1)); - } +template void PQFlashIndex::load_cache_list(std::vector &node_list) +{ + diskann::cout << "Loading the cache list into memory.." 
<< std::flush; + size_t num_cached_nodes = node_list.size(); - // issue the reads - auto read_status = read_nodes(nodes_to_read, coord_buffers, nbr_buffers); + // Allocate space for neighborhood cache + _nhood_cache_buf = new uint32_t[num_cached_nodes * (_max_degree + 1)]; + memset(_nhood_cache_buf, 0, num_cached_nodes * (_max_degree + 1)); - // check for success and insert into the cache. - for (size_t i = 0; i < read_status.size(); i++) { - if (read_status[i] == true) { - _coord_cache.insert(std::make_pair(nodes_to_read[i], coord_buffers[i])); - _nhood_cache.insert(std::make_pair(nodes_to_read[i], nbr_buffers[i])); - } + // Allocate space for coordinate cache + size_t coord_cache_buf_len = num_cached_nodes * _aligned_dim; + diskann::alloc_aligned((void **)&_coord_cache_buf, coord_cache_buf_len * sizeof(T), 8 * sizeof(T)); + memset(_coord_cache_buf, 0, coord_cache_buf_len * sizeof(T)); + + size_t BLOCK_SIZE = 8; + size_t num_blocks = DIV_ROUND_UP(num_cached_nodes, BLOCK_SIZE); + for (size_t block = 0; block < num_blocks; block++) + { + size_t start_idx = block * BLOCK_SIZE; + size_t end_idx = (std::min)(num_cached_nodes, (block + 1) * BLOCK_SIZE); + + // Copy offset into buffers to read into + std::vector nodes_to_read; + std::vector coord_buffers; + std::vector> nbr_buffers; + for (size_t node_idx = start_idx; node_idx < end_idx; node_idx++) + { + nodes_to_read.push_back(node_list[node_idx]); + coord_buffers.push_back(_coord_cache_buf + node_idx * _aligned_dim); + nbr_buffers.emplace_back(0, _nhood_cache_buf + node_idx * (_max_degree + 1)); + } + + // issue the reads + auto read_status = read_nodes(nodes_to_read, coord_buffers, nbr_buffers); + + // check for success and insert into the cache. + for (size_t i = 0; i < read_status.size(); i++) + { + if (read_status[i] == true) + { + _coord_cache.insert(std::make_pair(nodes_to_read[i], coord_buffers[i])); + _nhood_cache.insert(std::make_pair(nodes_to_read[i], nbr_buffers[i])); + } + } } - } - diskann::cout << "..done." << std::endl; + diskann::cout << "..done." 
<< std::endl; } #ifdef EXEC_ENV_OLS template -void PQFlashIndex::generate_cache_list_from_sample_queries( - MemoryMappedFiles &files, std::string sample_bin, uint64_t l_search, - uint64_t beamwidth, uint64_t num_nodes_to_cache, uint32_t nthreads, - std::vector &node_list) { +void PQFlashIndex::generate_cache_list_from_sample_queries(MemoryMappedFiles &files, std::string sample_bin, + uint64_t l_search, uint64_t beamwidth, + uint64_t num_nodes_to_cache, uint32_t nthreads, + std::vector &node_list) +{ #else template -void PQFlashIndex::generate_cache_list_from_sample_queries( - std::string sample_bin, uint64_t l_search, uint64_t beamwidth, - uint64_t num_nodes_to_cache, uint32_t nthreads, - std::vector &node_list) { +void PQFlashIndex::generate_cache_list_from_sample_queries(std::string sample_bin, uint64_t l_search, + uint64_t beamwidth, uint64_t num_nodes_to_cache, + uint32_t nthreads, + std::vector &node_list) +{ #endif - if (num_nodes_to_cache >= this->_num_points) { - // for small num_points and big num_nodes_to_cache, use below way to get - // the node_list quickly - node_list.resize(this->_num_points); - for (uint32_t i = 0; i < this->_num_points; ++i) { - node_list[i] = i; + if (num_nodes_to_cache >= this->_num_points) + { + // for small num_points and big num_nodes_to_cache, use below way to get + // the node_list quickly + node_list.resize(this->_num_points); + for (uint32_t i = 0; i < this->_num_points; ++i) + { + node_list[i] = i; + } + return; } - return; - } - this->_count_visited_nodes = true; - this->_node_visit_counter.clear(); - this->_node_visit_counter.resize(this->_num_points); - for (uint32_t i = 0; i < _node_visit_counter.size(); i++) { - this->_node_visit_counter[i].first = i; - this->_node_visit_counter[i].second = 0; - } + this->_count_visited_nodes = true; + this->_node_visit_counter.clear(); + this->_node_visit_counter.resize(this->_num_points); + for (uint32_t i = 0; i < _node_visit_counter.size(); i++) + { + this->_node_visit_counter[i].first = i; + this->_node_visit_counter[i].second = 0; + } - uint64_t sample_num, sample_dim, sample_aligned_dim; - T *samples; + uint64_t sample_num, sample_dim, sample_aligned_dim; + T *samples; #ifdef EXEC_ENV_OLS - if (files.fileExists(sample_bin)) { - diskann::load_aligned_bin(files, sample_bin, samples, sample_num, - sample_dim, sample_aligned_dim); - } + if (files.fileExists(sample_bin)) + { + diskann::load_aligned_bin(files, sample_bin, samples, sample_num, sample_dim, sample_aligned_dim); + } #else - if (file_exists(sample_bin)) { - diskann::load_aligned_bin(sample_bin, samples, sample_num, sample_dim, - sample_aligned_dim); - } + if (file_exists(sample_bin)) + { + diskann::load_aligned_bin(sample_bin, samples, sample_num, sample_dim, sample_aligned_dim); + } #endif - else { - diskann::cerr << "Sample bin file not found. Not generating cache." - << std::endl; - return; - } + else + { + diskann::cerr << "Sample bin file not found. Not generating cache." 
<< std::endl; + return; + } - std::vector tmp_result_ids_64(sample_num, 0); - std::vector tmp_result_dists(sample_num, 0); + std::vector tmp_result_ids_64(sample_num, 0); + std::vector tmp_result_dists(sample_num, 0); - bool filtered_search = false; - std::vector random_query_filters(sample_num); - if (this->_filter_index) { - filtered_search = true; - _filter_store->generate_random_labels(random_query_filters, - (uint32_t)sample_num, nthreads); - } + bool filtered_search = false; + std::vector random_query_filters(sample_num); + if (this->_filter_index) + { + filtered_search = true; + _filter_store->generate_random_labels(random_query_filters, (uint32_t)sample_num, nthreads); + } #pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads) - for (int64_t i = 0; i < (int64_t)sample_num; i++) { - auto &label_for_search = random_query_filters[i]; - // run a search on the sample query with a random label (sampled from base - // label distribution), and it will concurrently update the - // node_visit_counter to track most visited nodes. The last false is to - // not use the "use_reorder_data" option which enables a final reranking - // if the disk index itself contains only PQ data. - cached_beam_search(samples + (i * sample_aligned_dim), 1, l_search, - tmp_result_ids_64.data() + i, - tmp_result_dists.data() + i, beamwidth, filtered_search, - label_for_search, false); - } - - std::sort(this->_node_visit_counter.begin(), _node_visit_counter.end(), - [](std::pair &left, - std::pair &right) { - return left.second > right.second; - }); - node_list.clear(); - node_list.shrink_to_fit(); - num_nodes_to_cache = - std::min(num_nodes_to_cache, this->_node_visit_counter.size()); - node_list.reserve(num_nodes_to_cache); - for (uint64_t i = 0; i < num_nodes_to_cache; i++) { - node_list.push_back(this->_node_visit_counter[i].first); - } - this->_count_visited_nodes = false; - - diskann::aligned_free(samples); + for (int64_t i = 0; i < (int64_t)sample_num; i++) + { + auto &label_for_search = random_query_filters[i]; + // run a search on the sample query with a random label (sampled from base + // label distribution), and it will concurrently update the + // node_visit_counter to track most visited nodes. The last false is to + // not use the "use_reorder_data" option which enables a final reranking + // if the disk index itself contains only PQ data. 
+ cached_beam_search(samples + (i * sample_aligned_dim), 1, l_search, tmp_result_ids_64.data() + i, + tmp_result_dists.data() + i, beamwidth, filtered_search, label_for_search, false); + } + + std::sort(this->_node_visit_counter.begin(), _node_visit_counter.end(), + [](std::pair &left, std::pair &right) { + return left.second > right.second; + }); + node_list.clear(); + node_list.shrink_to_fit(); + num_nodes_to_cache = std::min(num_nodes_to_cache, this->_node_visit_counter.size()); + node_list.reserve(num_nodes_to_cache); + for (uint64_t i = 0; i < num_nodes_to_cache; i++) + { + node_list.push_back(this->_node_visit_counter[i].first); + } + this->_count_visited_nodes = false; + + diskann::aligned_free(samples); } template -void PQFlashIndex::cache_bfs_levels(uint64_t num_nodes_to_cache, - std::vector &node_list, - const bool shuffle) { - std::random_device rng; - std::mt19937 urng(rng()); - - tsl::robin_set node_set; - - // Do not cache more than 10% of the nodes in the index - uint64_t tenp_nodes = (uint64_t)(std::round(this->_num_points * 0.1)); - if (num_nodes_to_cache > tenp_nodes) { - diskann::cout << "Reducing nodes to cache from: " << num_nodes_to_cache - << " to: " << tenp_nodes - << "(10 percent of total nodes:" << this->_num_points << ")" - << std::endl; - num_nodes_to_cache = tenp_nodes == 0 ? 1 : tenp_nodes; - } - diskann::cout << "Caching " << num_nodes_to_cache << "..." << std::endl; - - std::unique_ptr> cur_level, prev_level; - cur_level = std::make_unique>(); - prev_level = std::make_unique>(); - - for (uint64_t miter = 0; - miter < _num_medoids && cur_level->size() < num_nodes_to_cache; - miter++) { - cur_level->insert(_medoids[miter]); - } - - auto filter_to_medoid_ids = _filter_store->get_label_to_medoids(); - if ((filter_to_medoid_ids.size() > 0) && - (cur_level->size() < num_nodes_to_cache)) { - for (auto &x : filter_to_medoid_ids) { - for (auto &y : x.second) { - cur_level->insert(y); - if (cur_level->size() == num_nodes_to_cache) - break; - } - if (cur_level->size() == num_nodes_to_cache) - break; +void PQFlashIndex::cache_bfs_levels(uint64_t num_nodes_to_cache, std::vector &node_list, + const bool shuffle) +{ + std::random_device rng; + std::mt19937 urng(rng()); + + tsl::robin_set node_set; + + // Do not cache more than 10% of the nodes in the index + uint64_t tenp_nodes = (uint64_t)(std::round(this->_num_points * 0.1)); + if (num_nodes_to_cache > tenp_nodes) + { + diskann::cout << "Reducing nodes to cache from: " << num_nodes_to_cache << " to: " << tenp_nodes + << "(10 percent of total nodes:" << this->_num_points << ")" << std::endl; + num_nodes_to_cache = tenp_nodes == 0 ? 1 : tenp_nodes; } - } - - uint64_t lvl = 1; - uint64_t prev_node_set_size = 0; - while ((node_set.size() + cur_level->size() < num_nodes_to_cache) && - cur_level->size() != 0) { - // swap prev_level and cur_level - std::swap(prev_level, cur_level); - // clear cur_level - cur_level->clear(); - - std::vector nodes_to_expand; - - for (const uint32_t &id : *prev_level) { - if (node_set.find(id) != node_set.end()) { - continue; - } - node_set.insert(id); - nodes_to_expand.push_back(id); + diskann::cout << "Caching " << num_nodes_to_cache << "..." 
<< std::endl; + + std::unique_ptr> cur_level, prev_level; + cur_level = std::make_unique>(); + prev_level = std::make_unique>(); + + for (uint64_t miter = 0; miter < _num_medoids && cur_level->size() < num_nodes_to_cache; miter++) + { + cur_level->insert(_medoids[miter]); } - if (shuffle) - std::shuffle(nodes_to_expand.begin(), nodes_to_expand.end(), urng); - else - std::sort(nodes_to_expand.begin(), nodes_to_expand.end()); + auto filter_to_medoid_ids = _filter_store->get_label_to_medoids(); + if ((filter_to_medoid_ids.size() > 0) && (cur_level->size() < num_nodes_to_cache)) + { + for (auto &x : filter_to_medoid_ids) + { + for (auto &y : x.second) + { + cur_level->insert(y); + if (cur_level->size() == num_nodes_to_cache) + break; + } + if (cur_level->size() == num_nodes_to_cache) + break; + } + } - diskann::cout << "Level: " << lvl << std::flush; - bool finish_flag = false; - - uint64_t BLOCK_SIZE = 1024; - uint64_t nblocks = DIV_ROUND_UP(nodes_to_expand.size(), BLOCK_SIZE); - for (size_t block = 0; block < nblocks && !finish_flag; block++) { - diskann::cout << "." << std::flush; - size_t start = block * BLOCK_SIZE; - size_t end = (std::min)((block + 1) * BLOCK_SIZE, nodes_to_expand.size()); - - std::vector nodes_to_read; - std::vector coord_buffers(end - start, nullptr); - std::vector> nbr_buffers; - - for (size_t cur_pt = start; cur_pt < end; cur_pt++) { - nodes_to_read.push_back(nodes_to_expand[cur_pt]); - nbr_buffers.emplace_back(0, new uint32_t[_max_degree + 1]); - } - - // issue read requests - auto read_status = read_nodes(nodes_to_read, coord_buffers, nbr_buffers); - - // process each nhood buf - for (uint32_t i = 0; i < read_status.size(); i++) { - if (read_status[i] == false) { - continue; - } else { - uint32_t nnbrs = nbr_buffers[i].first; - uint32_t *nbrs = nbr_buffers[i].second; - - // explore next level - for (uint32_t j = 0; j < nnbrs && !finish_flag; j++) { - if (node_set.find(nbrs[j]) == node_set.end()) { - cur_level->insert(nbrs[j]); + uint64_t lvl = 1; + uint64_t prev_node_set_size = 0; + while ((node_set.size() + cur_level->size() < num_nodes_to_cache) && cur_level->size() != 0) + { + // swap prev_level and cur_level + std::swap(prev_level, cur_level); + // clear cur_level + cur_level->clear(); + + std::vector nodes_to_expand; + + for (const uint32_t &id : *prev_level) + { + if (node_set.find(id) != node_set.end()) + { + continue; + } + node_set.insert(id); + nodes_to_expand.push_back(id); + } + + if (shuffle) + std::shuffle(nodes_to_expand.begin(), nodes_to_expand.end(), urng); + else + std::sort(nodes_to_expand.begin(), nodes_to_expand.end()); + + diskann::cout << "Level: " << lvl << std::flush; + bool finish_flag = false; + + uint64_t BLOCK_SIZE = 1024; + uint64_t nblocks = DIV_ROUND_UP(nodes_to_expand.size(), BLOCK_SIZE); + for (size_t block = 0; block < nblocks && !finish_flag; block++) + { + diskann::cout << "." 
<< std::flush; + size_t start = block * BLOCK_SIZE; + size_t end = (std::min)((block + 1) * BLOCK_SIZE, nodes_to_expand.size()); + + std::vector nodes_to_read; + std::vector coord_buffers(end - start, nullptr); + std::vector> nbr_buffers; + + for (size_t cur_pt = start; cur_pt < end; cur_pt++) + { + nodes_to_read.push_back(nodes_to_expand[cur_pt]); + nbr_buffers.emplace_back(0, new uint32_t[_max_degree + 1]); } - if (cur_level->size() + node_set.size() >= num_nodes_to_cache) { - finish_flag = true; + + // issue read requests + auto read_status = read_nodes(nodes_to_read, coord_buffers, nbr_buffers); + + // process each nhood buf + for (uint32_t i = 0; i < read_status.size(); i++) + { + if (read_status[i] == false) + { + continue; + } + else + { + uint32_t nnbrs = nbr_buffers[i].first; + uint32_t *nbrs = nbr_buffers[i].second; + + // explore next level + for (uint32_t j = 0; j < nnbrs && !finish_flag; j++) + { + if (node_set.find(nbrs[j]) == node_set.end()) + { + cur_level->insert(nbrs[j]); + } + if (cur_level->size() + node_set.size() >= num_nodes_to_cache) + { + finish_flag = true; + } + } + } + delete[] nbr_buffers[i].second; } - } } - delete[] nbr_buffers[i].second; - } + + diskann::cout << ". #nodes: " << node_set.size() - prev_node_set_size + << ", #nodes thus far: " << node_set.size() << std::endl; + prev_node_set_size = node_set.size(); + lvl++; } - diskann::cout << ". #nodes: " << node_set.size() - prev_node_set_size - << ", #nodes thus far: " << node_set.size() << std::endl; - prev_node_set_size = node_set.size(); - lvl++; - } - - assert(node_set.size() + cur_level->size() == num_nodes_to_cache || - cur_level->size() == 0); - - node_list.clear(); - node_list.reserve(node_set.size() + cur_level->size()); - for (auto node : node_set) - node_list.push_back(node); - for (auto node : *cur_level) - node_list.push_back(node); - - diskann::cout << "Level: " << lvl << std::flush; - diskann::cout << ". #nodes: " << node_list.size() - prev_node_set_size - << ", #nodes thus far: " << node_list.size() << std::endl; - diskann::cout << "done" << std::endl; + assert(node_set.size() + cur_level->size() == num_nodes_to_cache || cur_level->size() == 0); + + node_list.clear(); + node_list.reserve(node_set.size() + cur_level->size()); + for (auto node : node_set) + node_list.push_back(node); + for (auto node : *cur_level) + node_list.push_back(node); + + diskann::cout << "Level: " << lvl << std::flush; + diskann::cout << ". 
#nodes: " << node_list.size() - prev_node_set_size << ", #nodes thus far: " << node_list.size() + << std::endl; + diskann::cout << "done" << std::endl; } -template -void PQFlashIndex::use_medoids_data_as_centroids() { - if (_centroid_data != nullptr) - aligned_free(_centroid_data); - alloc_aligned(((void **)&_centroid_data), - _num_medoids * _aligned_dim * sizeof(float), 32); - std::memset(_centroid_data, 0, _num_medoids * _aligned_dim * sizeof(float)); - - diskann::cout << "Loading centroid data from medoids vector data of " - << _num_medoids << " medoid(s)" << std::endl; - - std::vector nodes_to_read; - std::vector medoid_bufs; - std::vector> nbr_bufs; - - for (uint64_t cur_m = 0; cur_m < _num_medoids; cur_m++) { - nodes_to_read.push_back(_medoids[cur_m]); - medoid_bufs.push_back(new T[_data_dim]); - nbr_bufs.emplace_back(0, nullptr); - } - - auto read_status = read_nodes(nodes_to_read, medoid_bufs, nbr_bufs); - - for (uint64_t cur_m = 0; cur_m < _num_medoids; cur_m++) { - if (read_status[cur_m] == true) { - if (!_use_disk_index_pq) { - for (uint32_t i = 0; i < _data_dim; i++) - _centroid_data[cur_m * _aligned_dim + i] = medoid_bufs[cur_m][i]; - } else { - _disk_pq_table.inflate_vector((uint8_t *)medoid_bufs[cur_m], - (_centroid_data + cur_m * _aligned_dim)); - } - } else { - throw ANNException("Unable to read a medoid", -1, __FUNCSIG__, __FILE__, - __LINE__); +template void PQFlashIndex::use_medoids_data_as_centroids() +{ + if (_centroid_data != nullptr) + aligned_free(_centroid_data); + alloc_aligned(((void **)&_centroid_data), _num_medoids * _aligned_dim * sizeof(float), 32); + std::memset(_centroid_data, 0, _num_medoids * _aligned_dim * sizeof(float)); + + diskann::cout << "Loading centroid data from medoids vector data of " << _num_medoids << " medoid(s)" << std::endl; + + std::vector nodes_to_read; + std::vector medoid_bufs; + std::vector> nbr_bufs; + + for (uint64_t cur_m = 0; cur_m < _num_medoids; cur_m++) + { + nodes_to_read.push_back(_medoids[cur_m]); + medoid_bufs.push_back(new T[_data_dim]); + nbr_bufs.emplace_back(0, nullptr); + } + + auto read_status = read_nodes(nodes_to_read, medoid_bufs, nbr_bufs); + + for (uint64_t cur_m = 0; cur_m < _num_medoids; cur_m++) + { + if (read_status[cur_m] == true) + { + if (!_use_disk_index_pq) + { + for (uint32_t i = 0; i < _data_dim; i++) + _centroid_data[cur_m * _aligned_dim + i] = medoid_bufs[cur_m][i]; + } + else + { + _disk_pq_table.inflate_vector((uint8_t *)medoid_bufs[cur_m], (_centroid_data + cur_m * _aligned_dim)); + } + } + else + { + throw ANNException("Unable to read a medoid", -1, __FUNCSIG__, __FILE__, __LINE__); + } + delete[] medoid_bufs[cur_m]; } - delete[] medoid_bufs[cur_m]; - } } #ifdef EXEC_ENV_OLS template -int PQFlashIndex::load(MemoryMappedFiles &files, - uint32_t num_threads, - const char *index_prefix) { +int PQFlashIndex::load(MemoryMappedFiles &files, uint32_t num_threads, const char *index_prefix) +{ #else -template -int PQFlashIndex::load(uint32_t num_threads, - const char *index_prefix) { +template int PQFlashIndex::load(uint32_t num_threads, const char *index_prefix) +{ #endif - std::string pq_table_bin = std::string(index_prefix) + "_pq_pivots.bin"; - std::string pq_compressed_vectors = - std::string(index_prefix) + "_pq_compressed.bin"; - std::string _disk_index_file = std::string(index_prefix) + "_disk.index"; + std::string pq_table_bin = std::string(index_prefix) + "_pq_pivots.bin"; + std::string pq_compressed_vectors = std::string(index_prefix) + "_pq_compressed.bin"; + std::string _disk_index_file = 
std::string(index_prefix) + "_disk.index"; #ifdef EXEC_ENV_OLS - return load_from_separate_paths(files, num_threads, _disk_index_file.c_str(), - pq_table_bin.c_str(), - pq_compressed_vectors.c_str()); + return load_from_separate_paths(files, num_threads, _disk_index_file.c_str(), pq_table_bin.c_str(), + pq_compressed_vectors.c_str()); #else - return load_from_separate_paths(num_threads, _disk_index_file.c_str(), - pq_table_bin.c_str(), - pq_compressed_vectors.c_str()); + return load_from_separate_paths(num_threads, _disk_index_file.c_str(), pq_table_bin.c_str(), + pq_compressed_vectors.c_str()); #endif } #ifdef EXEC_ENV_OLS template -int PQFlashIndex::load_from_separate_paths( - diskann::MemoryMappedFiles &files, uint32_t num_threads, - const char *index_filepath, const char *pivots_filepath, - const char *compressed_filepath) { +int PQFlashIndex::load_from_separate_paths(diskann::MemoryMappedFiles &files, uint32_t num_threads, + const char *index_filepath, const char *pivots_filepath, + const char *compressed_filepath) +{ #else template -int PQFlashIndex::load_from_separate_paths( - uint32_t num_threads, const char *index_filepath, - const char *pivots_filepath, const char *compressed_filepath) { +int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, const char *index_filepath, + const char *pivots_filepath, const char *compressed_filepath) +{ #endif - std::string pq_table_bin = pivots_filepath; - std::string pq_compressed_vectors = compressed_filepath; - std::string _disk_index_file = index_filepath; - std::string medoids_file = std::string(_disk_index_file) + "_medoids.bin"; - std::string centroids_file = std::string(_disk_index_file) + "_centroids.bin"; + std::string pq_table_bin = pivots_filepath; + std::string pq_compressed_vectors = compressed_filepath; + std::string _disk_index_file = index_filepath; + std::string medoids_file = std::string(_disk_index_file) + "_medoids.bin"; + std::string centroids_file = std::string(_disk_index_file) + "_centroids.bin"; - size_t pq_file_dim, pq_file_num_centroids; + size_t pq_file_dim, pq_file_num_centroids; #ifdef EXEC_ENV_OLS - get_bin_metadata(files, pq_table_bin, pq_file_num_centroids, pq_file_dim, - METADATA_SIZE); + get_bin_metadata(files, pq_table_bin, pq_file_num_centroids, pq_file_dim, METADATA_SIZE); #else - get_bin_metadata(pq_table_bin, pq_file_num_centroids, pq_file_dim, - METADATA_SIZE); + get_bin_metadata(pq_table_bin, pq_file_num_centroids, pq_file_dim, METADATA_SIZE); #endif - this->_disk_index_file = _disk_index_file; + this->_disk_index_file = _disk_index_file; - if (pq_file_num_centroids != 256) { - diskann::cout << "Error. Number of PQ centroids is not 256. Exiting." - << std::endl; - return -1; - } + if (pq_file_num_centroids != 256) + { + diskann::cout << "Error. Number of PQ centroids is not 256. Exiting." 
<< std::endl; + return -1; + } - this->_data_dim = pq_file_dim; - // will change later if we use PQ on disk or if we are using - // inner product without PQ - this->_disk_bytes_per_point = this->_data_dim * sizeof(T); - this->_aligned_dim = ROUND_UP(pq_file_dim, 8); + this->_data_dim = pq_file_dim; + // will change later if we use PQ on disk or if we are using + // inner product without PQ + this->_disk_bytes_per_point = this->_data_dim * sizeof(T); + this->_aligned_dim = ROUND_UP(pq_file_dim, 8); - size_t npts_u64, nchunks_u64; + size_t npts_u64, nchunks_u64; #ifdef EXEC_ENV_OLS - diskann::load_bin(files, pq_compressed_vectors, this->data, npts_u64, - nchunks_u64); + diskann::load_bin(files, pq_compressed_vectors, this->data, npts_u64, nchunks_u64); #else - diskann::load_bin(pq_compressed_vectors, this->data, npts_u64, - nchunks_u64); + diskann::load_bin(pq_compressed_vectors, this->data, npts_u64, nchunks_u64); #endif - this->_num_points = npts_u64; - this->_n_chunks = nchunks_u64; - - _filter_store = std::make_unique>(); - try { - _filter_index = _filter_store->load(_disk_index_file); - if (_filter_index) { - diskann::cout << "Index has filter support. " << std::endl; - } else { - diskann::cout << "Index does not have filter support." << std::endl; + this->_num_points = npts_u64; + this->_n_chunks = nchunks_u64; + + _filter_store = std::make_unique>(); + try + { + _filter_index = _filter_store->load(_disk_index_file); + if (_filter_index) + { + diskann::cout << "Index has filter support. " << std::endl; + } + else + { + diskann::cout << "Index does not have filter support." << std::endl; + } + } + catch (diskann::ANNException &ex) + { + // If filter_store=>load() returns false, it means filters are not + // enabled. If it throws, it means there was an error in processing a + // filter index. + diskann::cerr << "Filter index load failed because: " << ex.what() << std::endl; + return false; } - } catch (diskann::ANNException &ex) { - // If filter_store=>load() returns false, it means filters are not - // enabled. If it throws, it means there was an error in processing a - // filter index. - diskann::cerr << "Filter index load failed because: " << ex.what() - << std::endl; - return false; - } #ifdef EXEC_ENV_OLS - _pq_table.load_pq_centroid_bin(files, pq_table_bin.c_str(), nchunks_u64); + _pq_table.load_pq_centroid_bin(files, pq_table_bin.c_str(), nchunks_u64); #else - _pq_table.load_pq_centroid_bin(pq_table_bin.c_str(), nchunks_u64); + _pq_table.load_pq_centroid_bin(pq_table_bin.c_str(), nchunks_u64); #endif - diskann::cout - << "Loaded PQ centroids and in-memory compressed vectors. #points: " - << _num_points << " #dim: " << _data_dim - << " #aligned_dim: " << _aligned_dim << " #chunks: " << _n_chunks - << std::endl; - - if (_n_chunks > MAX_PQ_CHUNKS) { - std::stringstream stream; - stream << "Error loading index. Ensure that max PQ bytes for in-memory " - "PQ data does not exceed " - << MAX_PQ_CHUNKS << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - - std::string disk_pq_pivots_path = this->_disk_index_file + "_pq_pivots.bin"; + diskann::cout << "Loaded PQ centroids and in-memory compressed vectors. #points: " << _num_points + << " #dim: " << _data_dim << " #aligned_dim: " << _aligned_dim << " #chunks: " << _n_chunks + << std::endl; + + if (_n_chunks > MAX_PQ_CHUNKS) + { + std::stringstream stream; + stream << "Error loading index. 
Ensure that max PQ bytes for in-memory " + "PQ data does not exceed " + << MAX_PQ_CHUNKS << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + + std::string disk_pq_pivots_path = this->_disk_index_file + "_pq_pivots.bin"; #ifdef EXEC_ENV_OLS - if (files.fileExists(disk_pq_pivots_path)) { - _use_disk_index_pq = true; - // giving 0 chunks to make the _pq_table infer from the - // chunk_offsets file the correct value - _disk_pq_table.load_pq_centroid_bin(files, disk_pq_pivots_path.c_str(), 0); + if (files.fileExists(disk_pq_pivots_path)) + { + _use_disk_index_pq = true; + // giving 0 chunks to make the _pq_table infer from the + // chunk_offsets file the correct value + _disk_pq_table.load_pq_centroid_bin(files, disk_pq_pivots_path.c_str(), 0); #else - if (file_exists(disk_pq_pivots_path)) { - _use_disk_index_pq = true; - // giving 0 chunks to make the _pq_table infer from the - // chunk_offsets file the correct value - _disk_pq_table.load_pq_centroid_bin(disk_pq_pivots_path.c_str(), 0); + if (file_exists(disk_pq_pivots_path)) + { + _use_disk_index_pq = true; + // giving 0 chunks to make the _pq_table infer from the + // chunk_offsets file the correct value + _disk_pq_table.load_pq_centroid_bin(disk_pq_pivots_path.c_str(), 0); #endif - _disk_pq_n_chunks = _disk_pq_table.get_num_chunks(); - _disk_bytes_per_point = - _disk_pq_n_chunks * - sizeof(uint8_t); // revising disk_bytes_per_point since DISK PQ is used. - diskann::cout << "Disk index uses PQ data compressed down to " - << _disk_pq_n_chunks << " bytes per point." << std::endl; - } + _disk_pq_n_chunks = _disk_pq_table.get_num_chunks(); + _disk_bytes_per_point = + _disk_pq_n_chunks * sizeof(uint8_t); // revising disk_bytes_per_point since DISK PQ is used. + diskann::cout << "Disk index uses PQ data compressed down to " << _disk_pq_n_chunks << " bytes per point." + << std::endl; + } // read index metadata #ifdef EXEC_ENV_OLS - // This is a bit tricky. We have to read the header from the - // disk_index_file. But this is now exclusively a preserve of the - // DiskPriorityIO class. So, we need to estimate how many - // bytes are needed to store the header and read in that many using our - // 'standard' aligned file reader approach. - reader->open(_disk_index_file); - this->setup_thread_data(num_threads); - this->_max_nthreads = num_threads; - - char *bytes = getHeaderBytes(); - ContentBuf buf(bytes, HEADER_SIZE); - std::basic_istream index_metadata(&buf); + // This is a bit tricky. We have to read the header from the + // disk_index_file. But this is now exclusively a preserve of the + // DiskPriorityIO class. So, we need to estimate how many + // bytes are needed to store the header and read in that many using our + // 'standard' aligned file reader approach. 
+ reader->open(_disk_index_file); + this->setup_thread_data(num_threads); + this->_max_nthreads = num_threads; + + char *bytes = getHeaderBytes(); + ContentBuf buf(bytes, HEADER_SIZE); + std::basic_istream index_metadata(&buf); #else - std::ifstream index_metadata(_disk_index_file, std::ios::binary); + std::ifstream index_metadata(_disk_index_file, std::ios::binary); #endif - uint32_t nr, nc; // metadata itself is stored as bin format (nr is number - // of metadata, nc should be 1) - READ_U32(index_metadata, nr); - READ_U32(index_metadata, nc); - - uint64_t disk_nnodes; - uint64_t disk_ndims; // can be disk PQ dim if disk_PQ is set to true - READ_U64(index_metadata, disk_nnodes); - READ_U64(index_metadata, disk_ndims); - - if (disk_nnodes != _num_points) { - diskann::cout << "Mismatch in #points for compressed data file and disk " - "index file: " - << disk_nnodes << " vs " << _num_points << std::endl; - return -1; - } - - size_t medoid_id_on_file; - READ_U64(index_metadata, medoid_id_on_file); - READ_U64(index_metadata, _max_node_len); - READ_U64(index_metadata, _nnodes_per_sector); - _max_degree = - ((_max_node_len - _disk_bytes_per_point) / sizeof(uint32_t)) - 1; - - if (_max_degree > defaults::MAX_GRAPH_DEGREE) { - std::stringstream stream; - stream << "Error loading index. Ensure that max graph degree (R) does " - "not exceed " - << defaults::MAX_GRAPH_DEGREE << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } - - // setting up concept of frozen points in disk index for streaming-DiskANN - READ_U64(index_metadata, this->_num_frozen_points); - uint64_t file_frozen_id; - READ_U64(index_metadata, file_frozen_id); - if (this->_num_frozen_points == 1) - this->_frozen_location = file_frozen_id; - if (this->_num_frozen_points == 1) { - diskann::cout << " Detected frozen point in index at location " - << this->_frozen_location - << ". Will not output it at search time." << std::endl; - } - - READ_U64(index_metadata, this->_reorder_data_exists); - if (this->_reorder_data_exists) { - if (this->_use_disk_index_pq == false) { - throw ANNException("Reordering is designed for used with disk PQ " - "compression option", - -1, __FUNCSIG__, __FILE__, __LINE__); + uint32_t nr, nc; // metadata itself is stored as bin format (nr is number + // of metadata, nc should be 1) + READ_U32(index_metadata, nr); + READ_U32(index_metadata, nc); + + uint64_t disk_nnodes; + uint64_t disk_ndims; // can be disk PQ dim if disk_PQ is set to true + READ_U64(index_metadata, disk_nnodes); + READ_U64(index_metadata, disk_ndims); + + if (disk_nnodes != _num_points) + { + diskann::cout << "Mismatch in #points for compressed data file and disk " + "index file: " + << disk_nnodes << " vs " << _num_points << std::endl; + return -1; + } + + size_t medoid_id_on_file; + READ_U64(index_metadata, medoid_id_on_file); + READ_U64(index_metadata, _max_node_len); + READ_U64(index_metadata, _nnodes_per_sector); + _max_degree = ((_max_node_len - _disk_bytes_per_point) / sizeof(uint32_t)) - 1; + + if (_max_degree > defaults::MAX_GRAPH_DEGREE) + { + std::stringstream stream; + stream << "Error loading index. 
Ensure that max graph degree (R) does " + "not exceed " + << defaults::MAX_GRAPH_DEGREE << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + + // setting up concept of frozen points in disk index for streaming-DiskANN + READ_U64(index_metadata, this->_num_frozen_points); + uint64_t file_frozen_id; + READ_U64(index_metadata, file_frozen_id); + if (this->_num_frozen_points == 1) + this->_frozen_location = file_frozen_id; + if (this->_num_frozen_points == 1) + { + diskann::cout << " Detected frozen point in index at location " << this->_frozen_location + << ". Will not output it at search time." << std::endl; + } + + READ_U64(index_metadata, this->_reorder_data_exists); + if (this->_reorder_data_exists) + { + if (this->_use_disk_index_pq == false) + { + throw ANNException("Reordering is designed for used with disk PQ " + "compression option", + -1, __FUNCSIG__, __FILE__, __LINE__); + } + READ_U64(index_metadata, this->_reorder_data_start_sector); + READ_U64(index_metadata, this->_ndims_reorder_vecs); + READ_U64(index_metadata, this->_nvecs_per_sector); } - READ_U64(index_metadata, this->_reorder_data_start_sector); - READ_U64(index_metadata, this->_ndims_reorder_vecs); - READ_U64(index_metadata, this->_nvecs_per_sector); - } - diskann::cout << "Disk-Index File Meta-data: "; - diskann::cout << "# nodes per sector: " << _nnodes_per_sector; - diskann::cout << ", max node len (bytes): " << _max_node_len; - diskann::cout << ", max node degree: " << _max_degree << std::endl; + diskann::cout << "Disk-Index File Meta-data: "; + diskann::cout << "# nodes per sector: " << _nnodes_per_sector; + diskann::cout << ", max node len (bytes): " << _max_node_len; + diskann::cout << ", max node degree: " << _max_degree << std::endl; #ifdef EXEC_ENV_OLS - delete[] bytes; + delete[] bytes; #else - index_metadata.close(); + index_metadata.close(); #endif #ifndef EXEC_ENV_OLS - // open AlignedFileReader handle to index_file - std::string index_fname(_disk_index_file); - reader->open(index_fname); - this->setup_thread_data(num_threads); - this->_max_nthreads = num_threads; + // open AlignedFileReader handle to index_file + std::string index_fname(_disk_index_file); + reader->open(index_fname); + this->setup_thread_data(num_threads); + this->_max_nthreads = num_threads; #endif #ifdef EXEC_ENV_OLS - if (files.fileExists(medoids_file)) { - size_t tmp_dim; - diskann::load_bin(files, norm_file, medoids_file, _medoids, - _num_medoids, tmp_dim); + if (files.fileExists(medoids_file)) + { + size_t tmp_dim; + diskann::load_bin(files, norm_file, medoids_file, _medoids, _num_medoids, tmp_dim); #else - if (file_exists(medoids_file)) { - size_t tmp_dim; - diskann::load_bin(medoids_file, _medoids, _num_medoids, tmp_dim); + if (file_exists(medoids_file)) + { + size_t tmp_dim; + diskann::load_bin(medoids_file, _medoids, _num_medoids, tmp_dim); #endif - if (tmp_dim != 1) { - std::stringstream stream; - stream << "Error loading medoids file. Expected bin format of m times " - "1 vector of uint32_t." - << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } + if (tmp_dim != 1) + { + std::stringstream stream; + stream << "Error loading medoids file. Expected bin format of m times " + "1 vector of uint32_t." 
+ << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } #ifdef EXEC_ENV_OLS - if (!files.fileExists(centroids_file)) { + if (!files.fileExists(centroids_file)) + { #else - if (!file_exists(centroids_file)) { + if (!file_exists(centroids_file)) + { #endif - diskann::cout - << "Centroid data file not found. Using corresponding vectors " - "for the medoids " - << std::endl; - use_medoids_data_as_centroids(); - } else { - size_t num_centroids, aligned_tmp_dim; + diskann::cout << "Centroid data file not found. Using corresponding vectors " + "for the medoids " + << std::endl; + use_medoids_data_as_centroids(); + } + else + { + size_t num_centroids, aligned_tmp_dim; #ifdef EXEC_ENV_OLS - diskann::load_aligned_bin(files, centroids_file, _centroid_data, - num_centroids, tmp_dim, aligned_tmp_dim); + diskann::load_aligned_bin(files, centroids_file, _centroid_data, num_centroids, tmp_dim, + aligned_tmp_dim); #else - diskann::load_aligned_bin(centroids_file, _centroid_data, - num_centroids, tmp_dim, aligned_tmp_dim); + diskann::load_aligned_bin(centroids_file, _centroid_data, num_centroids, tmp_dim, aligned_tmp_dim); #endif - if (aligned_tmp_dim != _aligned_dim || num_centroids != _num_medoids) { - std::stringstream stream; - stream << "Error loading centroids data file. Expected bin format " - "of " - "m times data_dim vector of float, where m is number of " - "medoids " - "in medoids file."; - diskann::cerr << stream.str() << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); - } + if (aligned_tmp_dim != _aligned_dim || num_centroids != _num_medoids) + { + std::stringstream stream; + stream << "Error loading centroids data file. Expected bin format " + "of " + "m times data_dim vector of float, where m is number of " + "medoids " + "in medoids file."; + diskann::cerr << stream.str() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + } + } + else + { + _num_medoids = 1; + _medoids = new uint32_t[1]; + _medoids[0] = (uint32_t)(medoid_id_on_file); + use_medoids_data_as_centroids(); } - } else { - _num_medoids = 1; - _medoids = new uint32_t[1]; - _medoids[0] = (uint32_t)(medoid_id_on_file); - use_medoids_data_as_centroids(); - } - std::string norm_file = std::string(_disk_index_file) + "_max_base_norm.bin"; + std::string norm_file = std::string(_disk_index_file) + "_max_base_norm.bin"; #ifdef EXEC_ENV_OLS - if (files.fileExists(norm_file) && metric == diskann::Metric::INNER_PRODUCT) { - uint64_t dumr, dumc; - float *norm_val; - diskann::load_bin(files, norm_val, dumr, dumc); + if (files.fileExists(norm_file) && metric == diskann::Metric::INNER_PRODUCT) + { + uint64_t dumr, dumc; + float *norm_val; + diskann::load_bin(files, norm_val, dumr, dumc); #else - if (file_exists(norm_file) && metric == diskann::Metric::INNER_PRODUCT) { - uint64_t dumr, dumc; - float *norm_val; - diskann::load_bin(norm_file, norm_val, dumr, dumc); + if (file_exists(norm_file) && metric == diskann::Metric::INNER_PRODUCT) + { + uint64_t dumr, dumc; + float *norm_val; + diskann::load_bin(norm_file, norm_val, dumr, dumc); #endif - this->_max_base_norm = norm_val[0]; - diskann::cout << "Setting re-scaling factor of base vectors to " - << this->_max_base_norm << std::endl; - delete[] norm_val; - } - diskann::cout << "done.." 
<< std::endl; - return 0; + this->_max_base_norm = norm_val[0]; + diskann::cout << "Setting re-scaling factor of base vectors to " << this->_max_base_norm << std::endl; + delete[] norm_val; + } + diskann::cout << "done.." << std::endl; + return 0; } #ifdef USE_BING_INFRA -bool getNextCompletedRequest(std::shared_ptr &reader, - IOContext &ctx, size_t size, int &completedIndex) { - if ((*ctx.m_pRequests)[0].m_callback) { - bool waitsRemaining = false; - long completeCount = ctx.m_completeCount; - do { - for (int i = 0; i < size; i++) { - auto ithStatus = (*ctx.m_pRequestsStatus)[i]; - if (ithStatus == IOContext::Status::READ_SUCCESS) { - completedIndex = i; - return true; - } else if (ithStatus == IOContext::Status::READ_WAIT) { - waitsRemaining = true; - } - } - - // if we didn't find one in READ_SUCCESS, wait for one to complete. - if (waitsRemaining) { - WaitOnAddress(&ctx.m_completeCount, &completeCount, - sizeof(completeCount), 100); - // this assumes the knowledge of the reader behavior (implicit - // contract). need better factoring? - } - } while (waitsRemaining); - - completedIndex = -1; - return false; - } else { - reader->wait(ctx, completedIndex); - return completedIndex != -1; - } +bool getNextCompletedRequest(std::shared_ptr &reader, IOContext &ctx, size_t size, + int &completedIndex) +{ + if ((*ctx.m_pRequests)[0].m_callback) + { + bool waitsRemaining = false; + long completeCount = ctx.m_completeCount; + do + { + for (int i = 0; i < size; i++) + { + auto ithStatus = (*ctx.m_pRequestsStatus)[i]; + if (ithStatus == IOContext::Status::READ_SUCCESS) + { + completedIndex = i; + return true; + } + else if (ithStatus == IOContext::Status::READ_WAIT) + { + waitsRemaining = true; + } + } + + // if we didn't find one in READ_SUCCESS, wait for one to complete. + if (waitsRemaining) + { + WaitOnAddress(&ctx.m_completeCount, &completeCount, sizeof(completeCount), 100); + // this assumes the knowledge of the reader behavior (implicit + // contract). need better factoring? 
+ } + } while (waitsRemaining); + + completedIndex = -1; + return false; + } + else + { + reader->wait(ctx, completedIndex); + return completedIndex != -1; + } } #endif template -void PQFlashIndex::cached_beam_search( - const T *query1, const uint64_t k_search, const uint64_t l_search, - uint64_t *indices, float *distances, const uint64_t beam_width, - const bool use_reorder_data, QueryStats *stats) { - cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, - std::numeric_limits::max(), use_reorder_data, - stats); +void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, + uint64_t *indices, float *distances, const uint64_t beam_width, + const bool use_reorder_data, QueryStats *stats) +{ + cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, std::numeric_limits::max(), + use_reorder_data, stats); } template -void PQFlashIndex::cached_beam_search( - const T *query1, const uint64_t k_search, const uint64_t l_search, - uint64_t *indices, float *distances, const uint64_t beam_width, - const bool use_filter, const LabelT &filter_label, - const bool use_reorder_data, QueryStats *stats) { - cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, - use_filter, filter_label, - std::numeric_limits::max(), use_reorder_data, - stats); +void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, + uint64_t *indices, float *distances, const uint64_t beam_width, + const bool use_filter, const LabelT &filter_label, + const bool use_reorder_data, QueryStats *stats) +{ + cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, use_filter, filter_label, + std::numeric_limits::max(), use_reorder_data, stats); } template -void PQFlashIndex::cached_beam_search( - const T *query1, const uint64_t k_search, const uint64_t l_search, - uint64_t *indices, float *distances, const uint64_t beam_width, - const uint32_t io_limit, const bool use_reorder_data, QueryStats *stats) { - LabelT dummy_filter = 0; - cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, - false, dummy_filter, io_limit, use_reorder_data, stats); +void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, + uint64_t *indices, float *distances, const uint64_t beam_width, + const uint32_t io_limit, const bool use_reorder_data, + QueryStats *stats) +{ + LabelT dummy_filter = 0; + cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, false, dummy_filter, io_limit, + use_reorder_data, stats); } template -void PQFlashIndex::cached_beam_search( - const T *query1, const uint64_t k_search, const uint64_t l_search, - uint64_t *indices, float *distances, const uint64_t beam_width, - const bool use_filter, const LabelT &filter_label, const uint32_t io_limit, - const bool use_reorder_data, QueryStats *stats) { - uint64_t num_sector_per_nodes = - DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN); - if (beam_width > num_sector_per_nodes * defaults::MAX_N_SECTOR_READS) - throw ANNException( - "Beamwidth can not be higher than defaults::MAX_N_SECTOR_READS", -1, - __FUNCSIG__, __FILE__, __LINE__); - - ScratchStoreManager> manager(this->_thread_data); - auto data = manager.scratch_space(); - IOContext &ctx = data->ctx; - auto query_scratch = &(data->scratch); - auto pq_query_scratch = query_scratch->pq_scratch(); - - // reset query scratch - query_scratch->reset(); - - // copy query to 
thread specific aligned and allocated memory (for distance - // calculations we need aligned data) - float query_norm = 0; - T *aligned_query_T = query_scratch->aligned_query_T(); - float *query_float = pq_query_scratch->aligned_query_float; - float *query_rotated = pq_query_scratch->rotated_query; - - // normalization step. for cosine, we simply normalize the query - // for mips, we normalize the first d-1 dims, and add a 0 for last dim, - // since an extra coordinate was used to convert MIPS to L2 search - if (metric == diskann::Metric::INNER_PRODUCT || - metric == diskann::Metric::COSINE) { - uint64_t inherent_dim = (metric == diskann::Metric::COSINE) - ? this->_data_dim - : (uint64_t)(this->_data_dim - 1); - for (size_t i = 0; i < inherent_dim; i++) { - aligned_query_T[i] = query1[i]; - query_norm += query1[i] * query1[i]; - } - if (metric == diskann::Metric::INNER_PRODUCT) - aligned_query_T[this->_data_dim - 1] = 0; +void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, + uint64_t *indices, float *distances, const uint64_t beam_width, + const bool use_filter, const LabelT &filter_label, + const uint32_t io_limit, const bool use_reorder_data, + QueryStats *stats) +{ + uint64_t num_sector_per_nodes = DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN); + if (beam_width > num_sector_per_nodes * defaults::MAX_N_SECTOR_READS) + throw ANNException("Beamwidth can not be higher than defaults::MAX_N_SECTOR_READS", -1, __FUNCSIG__, __FILE__, + __LINE__); - query_norm = std::sqrt(query_norm); + ScratchStoreManager> manager(this->_thread_data); + auto data = manager.scratch_space(); + IOContext &ctx = data->ctx; + auto query_scratch = &(data->scratch); + auto pq_query_scratch = query_scratch->pq_scratch(); + + // reset query scratch + query_scratch->reset(); + + // copy query to thread specific aligned and allocated memory (for distance + // calculations we need aligned data) + float query_norm = 0; + T *aligned_query_T = query_scratch->aligned_query_T(); + float *query_float = pq_query_scratch->aligned_query_float; + float *query_rotated = pq_query_scratch->rotated_query; + + // normalization step. for cosine, we simply normalize the query + // for mips, we normalize the first d-1 dims, and add a 0 for last dim, + // since an extra coordinate was used to convert MIPS to L2 search + if (metric == diskann::Metric::INNER_PRODUCT || metric == diskann::Metric::COSINE) + { + uint64_t inherent_dim = (metric == diskann::Metric::COSINE) ? 
this->_data_dim : (uint64_t)(this->_data_dim - 1); + for (size_t i = 0; i < inherent_dim; i++) + { + aligned_query_T[i] = query1[i]; + query_norm += query1[i] * query1[i]; + } + if (metric == diskann::Metric::INNER_PRODUCT) + aligned_query_T[this->_data_dim - 1] = 0; - for (size_t i = 0; i < inherent_dim; i++) { - aligned_query_T[i] = (T)(aligned_query_T[i] / query_norm); - } - pq_query_scratch->initialize(this->_data_dim, aligned_query_T); - } else { - for (size_t i = 0; i < this->_data_dim; i++) { - aligned_query_T[i] = query1[i]; + query_norm = std::sqrt(query_norm); + + for (size_t i = 0; i < inherent_dim; i++) + { + aligned_query_T[i] = (T)(aligned_query_T[i] / query_norm); + } + pq_query_scratch->initialize(this->_data_dim, aligned_query_T); } - pq_query_scratch->initialize(this->_data_dim, aligned_query_T); - } - - // pointers to buffers for data - T *data_buf = query_scratch->coord_scratch; - _mm_prefetch((char *)data_buf, _MM_HINT_T1); - - // sector scratch - char *sector_scratch = query_scratch->sector_scratch; - uint64_t §or_scratch_idx = query_scratch->sector_idx; - const uint64_t num_sectors_per_node = - _nnodes_per_sector > 0 - ? 1 - : DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN); - - // query <-> PQ chunk centers distances - _pq_table.preprocess_query(query_rotated); // center the query and rotate - // if we have a rotation matrix - float *pq_dists = pq_query_scratch->aligned_pqtable_dist_scratch; - _pq_table.populate_chunk_distances(query_rotated, pq_dists); - - // query <-> neighbor list - float *dist_scratch = pq_query_scratch->aligned_dist_scratch; - uint8_t *pq_coord_scratch = pq_query_scratch->aligned_pq_coord_scratch; - - // lambda to batch compute query<-> node distances in PQ space - auto compute_dists = [this, pq_coord_scratch, pq_dists](const uint32_t *ids, - const uint64_t n_ids, - float *dists_out) { - diskann::aggregate_coords(ids, n_ids, this->data, this->_n_chunks, - pq_coord_scratch); - diskann::pq_dist_lookup(pq_coord_scratch, n_ids, this->_n_chunks, pq_dists, - dists_out); - }; - Timer query_timer, io_timer, cpu_timer; - - tsl::robin_set &visited = query_scratch->visited; - NeighborPriorityQueue &retset = query_scratch->retset; - retset.reserve(l_search); - std::vector &full_retset = query_scratch->full_retset; - - uint32_t best_medoid = 0; - float best_dist = (std::numeric_limits::max)(); - if (!use_filter) { - for (uint64_t cur_m = 0; cur_m < _num_medoids; cur_m++) { - float cur_expanded_dist = _dist_cmp_float->compare( - query_float, _centroid_data + _aligned_dim * cur_m, - (uint32_t)_aligned_dim); - if (cur_expanded_dist < best_dist) { - best_medoid = _medoids[cur_m]; - best_dist = cur_expanded_dist; - } + else + { + for (size_t i = 0; i < this->_data_dim; i++) + { + aligned_query_T[i] = query1[i]; + } + pq_query_scratch->initialize(this->_data_dim, aligned_query_T); } - } else { - const auto &medoid_ids = _filter_store->get_medoids_of_label(filter_label); - if (medoid_ids.size() > 0) - // if (_filter_to_medoid_ids.find(filter_label) != - // _filter_to_medoid_ids.end()) + + // pointers to buffers for data + T *data_buf = query_scratch->coord_scratch; + _mm_prefetch((char *)data_buf, _MM_HINT_T1); + + // sector scratch + char *sector_scratch = query_scratch->sector_scratch; + uint64_t §or_scratch_idx = query_scratch->sector_idx; + const uint64_t num_sectors_per_node = + _nnodes_per_sector > 0 ? 
1 : DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN); + + // query <-> PQ chunk centers distances + _pq_table.preprocess_query(query_rotated); // center the query and rotate + // if we have a rotation matrix + float *pq_dists = pq_query_scratch->aligned_pqtable_dist_scratch; + _pq_table.populate_chunk_distances(query_rotated, pq_dists); + + // query <-> neighbor list + float *dist_scratch = pq_query_scratch->aligned_dist_scratch; + uint8_t *pq_coord_scratch = pq_query_scratch->aligned_pq_coord_scratch; + + // lambda to batch compute query<-> node distances in PQ space + auto compute_dists = [this, pq_coord_scratch, pq_dists](const uint32_t *ids, const uint64_t n_ids, + float *dists_out) { + diskann::aggregate_coords(ids, n_ids, this->data, this->_n_chunks, pq_coord_scratch); + diskann::pq_dist_lookup(pq_coord_scratch, n_ids, this->_n_chunks, pq_dists, dists_out); + }; + Timer query_timer, io_timer, cpu_timer; + + tsl::robin_set &visited = query_scratch->visited; + NeighborPriorityQueue &retset = query_scratch->retset; + retset.reserve(l_search); + std::vector &full_retset = query_scratch->full_retset; + + uint32_t best_medoid = 0; + float best_dist = (std::numeric_limits::max)(); + if (!use_filter) { - // const auto &medoid_ids = _filter_to_medoid_ids[filter_label]; - - for (uint64_t cur_m = 0; cur_m < medoid_ids.size(); cur_m++) { - // for filtered index, we dont store global centroid data as for - // unfiltered index, so we use PQ distance as approximation to decide - // closest medoid matching the query filter. - compute_dists(&medoid_ids[cur_m], 1, dist_scratch); - float cur_expanded_dist = dist_scratch[0]; - if (cur_expanded_dist < best_dist) { - best_medoid = medoid_ids[cur_m]; - best_dist = cur_expanded_dist; + for (uint64_t cur_m = 0; cur_m < _num_medoids; cur_m++) + { + float cur_expanded_dist = + _dist_cmp_float->compare(query_float, _centroid_data + _aligned_dim * cur_m, (uint32_t)_aligned_dim); + if (cur_expanded_dist < best_dist) + { + best_medoid = _medoids[cur_m]; + best_dist = cur_expanded_dist; + } } - } - } else { - throw ANNException("Cannot find medoid for specified filter.", -1, - __FUNCSIG__, __FILE__, __LINE__); } - } - - compute_dists(&best_medoid, 1, dist_scratch); - retset.insert(Neighbor(best_medoid, dist_scratch[0])); - visited.insert(best_medoid); - - uint32_t cmps = 0; - uint32_t hops = 0; - uint32_t num_ios = 0; - - // cleared every iteration - std::vector frontier; - frontier.reserve(2 * beam_width); - std::vector> frontier_nhoods; - frontier_nhoods.reserve(2 * beam_width); - std::vector frontier_read_reqs; - frontier_read_reqs.reserve(2 * beam_width); - std::vector>> - cached_nhoods; - cached_nhoods.reserve(2 * beam_width); - - while (retset.has_unexpanded_node() && num_ios < io_limit) { - // clear iteration state - frontier.clear(); - frontier_nhoods.clear(); - frontier_read_reqs.clear(); - cached_nhoods.clear(); - sector_scratch_idx = 0; - // find new beam - uint32_t num_seen = 0; - while (retset.has_unexpanded_node() && frontier.size() < beam_width && - num_seen < beam_width) { - auto nbr = retset.closest_unexpanded(); - num_seen++; - auto iter = _nhood_cache.find(nbr.id); - if (iter != _nhood_cache.end()) { - cached_nhoods.push_back(std::make_pair(nbr.id, iter->second)); - if (stats != nullptr) { - stats->n_cache_hits++; + else + { + const auto &medoid_ids = _filter_store->get_medoids_of_label(filter_label); + if (medoid_ids.size() > 0) + // if (_filter_to_medoid_ids.find(filter_label) != + // _filter_to_medoid_ids.end()) + { + // const auto 
&medoid_ids = _filter_to_medoid_ids[filter_label]; + + for (uint64_t cur_m = 0; cur_m < medoid_ids.size(); cur_m++) + { + // for filtered index, we dont store global centroid data as for + // unfiltered index, so we use PQ distance as approximation to decide + // closest medoid matching the query filter. + compute_dists(&medoid_ids[cur_m], 1, dist_scratch); + float cur_expanded_dist = dist_scratch[0]; + if (cur_expanded_dist < best_dist) + { + best_medoid = medoid_ids[cur_m]; + best_dist = cur_expanded_dist; + } + } + } + else + { + throw ANNException("Cannot find medoid for specified filter.", -1, __FUNCSIG__, __FILE__, __LINE__); } - } else { - frontier.push_back(nbr.id); - } - if (this->_count_visited_nodes) { - reinterpret_cast &>( - this->_node_visit_counter[nbr.id].second) - .fetch_add(1); - } } - // read nhoods of frontier ids - if (!frontier.empty()) { - if (stats != nullptr) - stats->n_hops++; - for (uint64_t i = 0; i < frontier.size(); i++) { - auto id = frontier[i]; - std::pair fnhood; - fnhood.first = id; - fnhood.second = sector_scratch + num_sectors_per_node * - sector_scratch_idx * - defaults::SECTOR_LEN; - sector_scratch_idx++; - frontier_nhoods.push_back(fnhood); - frontier_read_reqs.emplace_back( - get_node_sector((size_t)id) * defaults::SECTOR_LEN, - num_sectors_per_node * defaults::SECTOR_LEN, fnhood.second); - if (stats != nullptr) { - stats->n_4k++; - stats->n_ios++; + compute_dists(&best_medoid, 1, dist_scratch); + retset.insert(Neighbor(best_medoid, dist_scratch[0])); + visited.insert(best_medoid); + + uint32_t cmps = 0; + uint32_t hops = 0; + uint32_t num_ios = 0; + + // cleared every iteration + std::vector frontier; + frontier.reserve(2 * beam_width); + std::vector> frontier_nhoods; + frontier_nhoods.reserve(2 * beam_width); + std::vector frontier_read_reqs; + frontier_read_reqs.reserve(2 * beam_width); + std::vector>> cached_nhoods; + cached_nhoods.reserve(2 * beam_width); + + while (retset.has_unexpanded_node() && num_ios < io_limit) + { + // clear iteration state + frontier.clear(); + frontier_nhoods.clear(); + frontier_read_reqs.clear(); + cached_nhoods.clear(); + sector_scratch_idx = 0; + // find new beam + uint32_t num_seen = 0; + while (retset.has_unexpanded_node() && frontier.size() < beam_width && num_seen < beam_width) + { + auto nbr = retset.closest_unexpanded(); + num_seen++; + auto iter = _nhood_cache.find(nbr.id); + if (iter != _nhood_cache.end()) + { + cached_nhoods.push_back(std::make_pair(nbr.id, iter->second)); + if (stats != nullptr) + { + stats->n_cache_hits++; + } + } + else + { + frontier.push_back(nbr.id); + } + if (this->_count_visited_nodes) + { + reinterpret_cast &>(this->_node_visit_counter[nbr.id].second).fetch_add(1); + } } - num_ios++; - } - io_timer.reset(); + + // read nhoods of frontier ids + if (!frontier.empty()) + { + if (stats != nullptr) + stats->n_hops++; + for (uint64_t i = 0; i < frontier.size(); i++) + { + auto id = frontier[i]; + std::pair fnhood; + fnhood.first = id; + fnhood.second = sector_scratch + num_sectors_per_node * sector_scratch_idx * defaults::SECTOR_LEN; + sector_scratch_idx++; + frontier_nhoods.push_back(fnhood); + frontier_read_reqs.emplace_back(get_node_sector((size_t)id) * defaults::SECTOR_LEN, + num_sectors_per_node * defaults::SECTOR_LEN, fnhood.second); + if (stats != nullptr) + { + stats->n_4k++; + stats->n_ios++; + } + num_ios++; + } + io_timer.reset(); #ifdef USE_BING_INFRA - reader->read(frontier_read_reqs, ctx, - true); // asynhronous reader for Bing. 
+ reader->read(frontier_read_reqs, ctx, + true); // asynhronous reader for Bing. #else - reader->read(frontier_read_reqs, ctx); // synchronous IO linux + reader->read(frontier_read_reqs, ctx); // synchronous IO linux #endif - if (stats != nullptr) { - stats->io_us += (float)io_timer.elapsed(); - } - } + if (stats != nullptr) + { + stats->io_us += (float)io_timer.elapsed(); + } + } - // process cached nhoods - for (auto &cached_nhood : cached_nhoods) { - auto global_cache_iter = _coord_cache.find(cached_nhood.first); - T *node_fp_coords_copy = global_cache_iter->second; - float cur_expanded_dist; - if (!_use_disk_index_pq) { - cur_expanded_dist = _dist_cmp->compare( - aligned_query_T, node_fp_coords_copy, (uint32_t)_aligned_dim); - } else { - if (metric == diskann::Metric::INNER_PRODUCT) - cur_expanded_dist = _disk_pq_table.inner_product( - query_float, (uint8_t *)node_fp_coords_copy); - else - cur_expanded_dist = - _disk_pq_table.l2_distance( // disk_pq does not support OPQ yet - query_float, (uint8_t *)node_fp_coords_copy); - } - full_retset.push_back( - Neighbor((uint32_t)cached_nhood.first, cur_expanded_dist)); - - uint64_t nnbrs = cached_nhood.second.first; - uint32_t *node_nbrs = cached_nhood.second.second; - - // compute node_nbrs <-> query dists in PQ space - cpu_timer.reset(); - compute_dists(node_nbrs, nnbrs, dist_scratch); - if (stats != nullptr) { - stats->n_cmps += (uint32_t)nnbrs; - stats->cpu_us += (float)cpu_timer.elapsed(); - } - - // process prefetched nhood - for (uint64_t m = 0; m < nnbrs; ++m) { - uint32_t id = node_nbrs[m]; - if (visited.insert(id).second) { - // if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) - // unfiltered search, but filtered index! - if (!use_filter && _filter_store->is_dummy_point(id)) - continue; + // process cached nhoods + for (auto &cached_nhood : cached_nhoods) + { + auto global_cache_iter = _coord_cache.find(cached_nhood.first); + T *node_fp_coords_copy = global_cache_iter->second; + float cur_expanded_dist; + if (!_use_disk_index_pq) + { + cur_expanded_dist = _dist_cmp->compare(aligned_query_T, node_fp_coords_copy, (uint32_t)_aligned_dim); + } + else + { + if (metric == diskann::Metric::INNER_PRODUCT) + cur_expanded_dist = _disk_pq_table.inner_product(query_float, (uint8_t *)node_fp_coords_copy); + else + cur_expanded_dist = _disk_pq_table.l2_distance( // disk_pq does not support OPQ yet + query_float, (uint8_t *)node_fp_coords_copy); + } + full_retset.push_back(Neighbor((uint32_t)cached_nhood.first, cur_expanded_dist)); + + uint64_t nnbrs = cached_nhood.second.first; + uint32_t *node_nbrs = cached_nhood.second.second; + + // compute node_nbrs <-> query dists in PQ space + cpu_timer.reset(); + compute_dists(node_nbrs, nnbrs, dist_scratch); + if (stats != nullptr) + { + stats->n_cmps += (uint32_t)nnbrs; + stats->cpu_us += (float)cpu_timer.elapsed(); + } - // if (use_filter && !(point_has_label(id, filter_label)) && - // (!_use_universal_label || !point_has_label(id, - // _universal_filter_label))) - if (use_filter && !_filter_store->point_has_label_or_universal_label( - id, filter_label)) - continue; - cmps++; - float dist = dist_scratch[m]; - Neighbor nn(id, dist); - retset.insert(nn); + // process prefetched nhood + for (uint64_t m = 0; m < nnbrs; ++m) + { + uint32_t id = node_nbrs[m]; + if (visited.insert(id).second) + { + // if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) + // unfiltered search, but filtered index! 
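// compute_dists above expands to aggregate_coords + pq_dist_lookup: gather each candidate's PQ
// codes, then sum precomputed per-chunk entries from pq_dists. A minimal standalone sketch of that
// lookup-table distance (function and argument names are illustrative; the 256-entries-per-chunk
// table layout matches populate_chunk_distances above and prepopulate_chunkwise_distances later in
// this patch):

#include <cstddef>
#include <cstdint>

// dist(query, candidate) ~= sum over chunks of table[chunk * 256 + codes[chunk]], where `table`
// holds the precomputed per-chunk (query - centroid)^2 sums and `codes` are the candidate's PQ codes.
inline float pq_lookup_distance(const uint8_t *codes, const float *table, size_t n_chunks)
{
    float dist = 0.0f;
    for (size_t chunk = 0; chunk < n_chunks; chunk++)
        dist += table[chunk * 256 + codes[chunk]];
    return dist;
}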
+ if (!use_filter && _filter_store->is_dummy_point(id)) + continue; + + // if (use_filter && !(point_has_label(id, filter_label)) && + // (!_use_universal_label || !point_has_label(id, + // _universal_filter_label))) + if (use_filter && !_filter_store->point_has_label_or_universal_label(id, filter_label)) + continue; + cmps++; + float dist = dist_scratch[m]; + Neighbor nn(id, dist); + retset.insert(nn); + } + } } - } - } #ifdef USE_BING_INFRA - // process each frontier nhood - compute distances to unvisited nodes - int completedIndex = -1; - long requestCount = static_cast(frontier_read_reqs.size()); - // If we issued read requests and if a read is complete or there are - // reads in wait state, then enter the while loop. - while (requestCount > 0 && - getNextCompletedRequest(reader, ctx, requestCount, completedIndex)) { - assert(completedIndex >= 0); - auto &frontier_nhood = frontier_nhoods[completedIndex]; - (*ctx.m_pRequestsStatus)[completedIndex] = IOContext::PROCESS_COMPLETE; + // process each frontier nhood - compute distances to unvisited nodes + int completedIndex = -1; + long requestCount = static_cast(frontier_read_reqs.size()); + // If we issued read requests and if a read is complete or there are + // reads in wait state, then enter the while loop. + while (requestCount > 0 && getNextCompletedRequest(reader, ctx, requestCount, completedIndex)) + { + assert(completedIndex >= 0); + auto &frontier_nhood = frontier_nhoods[completedIndex]; + (*ctx.m_pRequestsStatus)[completedIndex] = IOContext::PROCESS_COMPLETE; #else - for (auto &frontier_nhood : frontier_nhoods) { + for (auto &frontier_nhood : frontier_nhoods) + { #endif - char *node_disk_buf = - offset_to_node(frontier_nhood.second, frontier_nhood.first); - uint32_t *node_buf = offset_to_node_nhood(node_disk_buf); - uint64_t nnbrs = (uint64_t)(*node_buf); - T *node_fp_coords = offset_to_node_coords(node_disk_buf); - memcpy(data_buf, node_fp_coords, _disk_bytes_per_point); - float cur_expanded_dist; - if (!_use_disk_index_pq) { - cur_expanded_dist = _dist_cmp->compare(aligned_query_T, data_buf, - (uint32_t)_aligned_dim); - } else { - if (metric == diskann::Metric::INNER_PRODUCT) - cur_expanded_dist = - _disk_pq_table.inner_product(query_float, (uint8_t *)data_buf); - else - cur_expanded_dist = - _disk_pq_table.l2_distance(query_float, (uint8_t *)data_buf); - } - full_retset.push_back(Neighbor(frontier_nhood.first, cur_expanded_dist)); - uint32_t *node_nbrs = (node_buf + 1); - // compute node_nbrs <-> query dist in PQ space - cpu_timer.reset(); - compute_dists(node_nbrs, nnbrs, dist_scratch); - if (stats != nullptr) { - stats->n_cmps += (uint32_t)nnbrs; - stats->cpu_us += (float)cpu_timer.elapsed(); - } - - cpu_timer.reset(); - // process prefetch-ed nhood - for (uint64_t m = 0; m < nnbrs; ++m) { - uint32_t id = node_nbrs[m]; - if (visited.insert(id).second) { - // if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) - if (!use_filter && _filter_store->is_dummy_point(id)) - continue; + char *node_disk_buf = offset_to_node(frontier_nhood.second, frontier_nhood.first); + uint32_t *node_buf = offset_to_node_nhood(node_disk_buf); + uint64_t nnbrs = (uint64_t)(*node_buf); + T *node_fp_coords = offset_to_node_coords(node_disk_buf); + memcpy(data_buf, node_fp_coords, _disk_bytes_per_point); + float cur_expanded_dist; + if (!_use_disk_index_pq) + { + cur_expanded_dist = _dist_cmp->compare(aligned_query_T, data_buf, (uint32_t)_aligned_dim); + } + else + { + if (metric == diskann::Metric::INNER_PRODUCT) + cur_expanded_dist = 
_disk_pq_table.inner_product(query_float, (uint8_t *)data_buf); + else + cur_expanded_dist = _disk_pq_table.l2_distance(query_float, (uint8_t *)data_buf); + } + full_retset.push_back(Neighbor(frontier_nhood.first, cur_expanded_dist)); + uint32_t *node_nbrs = (node_buf + 1); + // compute node_nbrs <-> query dist in PQ space + cpu_timer.reset(); + compute_dists(node_nbrs, nnbrs, dist_scratch); + if (stats != nullptr) + { + stats->n_cmps += (uint32_t)nnbrs; + stats->cpu_us += (float)cpu_timer.elapsed(); + } - // if (use_filter && !(point_has_label(id, filter_label)) && - // (!_use_universal_label || !point_has_label(id, - // _universal_filter_label))) - if (use_filter && !_filter_store->point_has_label_or_universal_label( - id, filter_label)) - continue; - cmps++; - float dist = dist_scratch[m]; - if (stats != nullptr) { - stats->n_cmps++; - } - - Neighbor nn(id, dist); - retset.insert(nn); + cpu_timer.reset(); + // process prefetch-ed nhood + for (uint64_t m = 0; m < nnbrs; ++m) + { + uint32_t id = node_nbrs[m]; + if (visited.insert(id).second) + { + // if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) + if (!use_filter && _filter_store->is_dummy_point(id)) + continue; + + // if (use_filter && !(point_has_label(id, filter_label)) && + // (!_use_universal_label || !point_has_label(id, + // _universal_filter_label))) + if (use_filter && !_filter_store->point_has_label_or_universal_label(id, filter_label)) + continue; + cmps++; + float dist = dist_scratch[m]; + if (stats != nullptr) + { + stats->n_cmps++; + } + + Neighbor nn(id, dist); + retset.insert(nn); + } + } + + if (stats != nullptr) + { + stats->cpu_us += (float)cpu_timer.elapsed(); + } } - } - if (stats != nullptr) { - stats->cpu_us += (float)cpu_timer.elapsed(); - } + hops++; } - hops++; - } - - // re-sort by distance - std::sort(full_retset.begin(), full_retset.end()); + // re-sort by distance + std::sort(full_retset.begin(), full_retset.end()); - if (use_reorder_data) { - if (!(this->_reorder_data_exists)) { - throw ANNException("Requested use of reordering data which does " - "not exist in index " - "file", - -1, __FUNCSIG__, __FILE__, __LINE__); - } + if (use_reorder_data) + { + if (!(this->_reorder_data_exists)) + { + throw ANNException("Requested use of reordering data which does " + "not exist in index " + "file", + -1, __FUNCSIG__, __FILE__, __LINE__); + } - std::vector vec_read_reqs; + std::vector vec_read_reqs; - if (full_retset.size() > k_search * FULL_PRECISION_REORDER_MULTIPLIER) - full_retset.erase(full_retset.begin() + - k_search * FULL_PRECISION_REORDER_MULTIPLIER, - full_retset.end()); + if (full_retset.size() > k_search * FULL_PRECISION_REORDER_MULTIPLIER) + full_retset.erase(full_retset.begin() + k_search * FULL_PRECISION_REORDER_MULTIPLIER, full_retset.end()); - for (size_t i = 0; i < full_retset.size(); ++i) { - // MULTISECTORFIX - vec_read_reqs.emplace_back( - VECTOR_SECTOR_NO(((size_t)full_retset[i].id)) * defaults::SECTOR_LEN, - defaults::SECTOR_LEN, sector_scratch + i * defaults::SECTOR_LEN); + for (size_t i = 0; i < full_retset.size(); ++i) + { + // MULTISECTORFIX + vec_read_reqs.emplace_back(VECTOR_SECTOR_NO(((size_t)full_retset[i].id)) * defaults::SECTOR_LEN, + defaults::SECTOR_LEN, sector_scratch + i * defaults::SECTOR_LEN); - if (stats != nullptr) { - stats->n_4k++; - stats->n_ios++; - } - } + if (stats != nullptr) + { + stats->n_4k++; + stats->n_ios++; + } + } - io_timer.reset(); + io_timer.reset(); #ifdef USE_BING_INFRA - reader->read(vec_read_reqs, ctx, true); // async reader windows. 
+ reader->read(vec_read_reqs, ctx, true); // async reader windows. #else - reader->read(vec_read_reqs, ctx); // synchronous IO linux + reader->read(vec_read_reqs, ctx); // synchronous IO linux #endif - if (stats != nullptr) { - stats->io_us += io_timer.elapsed(); - } + if (stats != nullptr) + { + stats->io_us += io_timer.elapsed(); + } - for (size_t i = 0; i < full_retset.size(); ++i) { - auto id = full_retset[i].id; - // MULTISECTORFIX - auto location = (sector_scratch + i * defaults::SECTOR_LEN) + - VECTOR_SECTOR_OFFSET(id); - full_retset[i].distance = _dist_cmp->compare( - aligned_query_T, (T *)location, (uint32_t)this->_data_dim); - } + for (size_t i = 0; i < full_retset.size(); ++i) + { + auto id = full_retset[i].id; + // MULTISECTORFIX + auto location = (sector_scratch + i * defaults::SECTOR_LEN) + VECTOR_SECTOR_OFFSET(id); + full_retset[i].distance = _dist_cmp->compare(aligned_query_T, (T *)location, (uint32_t)this->_data_dim); + } - std::sort(full_retset.begin(), full_retset.end()); - } - - // copy k_search values - for (uint64_t i = 0; i < k_search; i++) { - indices[i] = full_retset[i].id; - auto key = (uint32_t)indices[i]; - if (_filter_store->is_dummy_point(key)) { - indices[i] = _filter_store->get_real_point_for_dummy(key); + std::sort(full_retset.begin(), full_retset.end()); } - if (distances != nullptr) { - distances[i] = full_retset[i].distance; - if (metric == diskann::Metric::INNER_PRODUCT) { - // flip the sign to convert min to max - distances[i] = (-distances[i]); - // rescale to revert back to original norms (cancelling the - // effect of base and query pre-processing) - if (_max_base_norm != 0) - distances[i] *= (_max_base_norm * query_norm); - } + // copy k_search values + for (uint64_t i = 0; i < k_search; i++) + { + indices[i] = full_retset[i].id; + auto key = (uint32_t)indices[i]; + if (_filter_store->is_dummy_point(key)) + { + indices[i] = _filter_store->get_real_point_for_dummy(key); + } + + if (distances != nullptr) + { + distances[i] = full_retset[i].distance; + if (metric == diskann::Metric::INNER_PRODUCT) + { + // flip the sign to convert min to max + distances[i] = (-distances[i]); + // rescale to revert back to original norms (cancelling the + // effect of base and query pre-processing) + if (_max_base_norm != 0) + distances[i] *= (_max_base_norm * query_norm); + } + } } - } #ifdef USE_BING_INFRA - ctx.m_completeCount = 0; + ctx.m_completeCount = 0; #endif - if (stats != nullptr) { - stats->total_us = (float)query_timer.elapsed(); - } + if (stats != nullptr) + { + stats->total_us = (float)query_timer.elapsed(); + } } // range search returns results of all neighbors within distance of range. // indices and distances need to be pre-allocated of size l_search and the // return value is the number of matching hits. template -uint32_t PQFlashIndex::range_search( - const T *query1, const double range, const uint64_t min_l_search, - const uint64_t max_l_search, std::vector &indices, - std::vector &distances, const uint64_t min_beam_width, - QueryStats *stats) { - uint32_t res_count = 0; - - bool stop_flag = false; - - uint32_t l_search = - (uint32_t)min_l_search; // starting size of the candidate list - while (!stop_flag) { - indices.resize(l_search); - distances.resize(l_search); - uint64_t cur_bw = - min_beam_width > (l_search / 5) ? min_beam_width : l_search / 5; - cur_bw = (cur_bw > 100) ? 
100 : cur_bw; - for (auto &x : distances) - x = std::numeric_limits::max(); - this->cached_beam_search(query1, l_search, l_search, indices.data(), - distances.data(), cur_bw, false, stats); - for (uint32_t i = 0; i < l_search; i++) { - if (distances[i] > (float)range) { - res_count = i; - break; - } else if (i == l_search - 1) - res_count = l_search; +uint32_t PQFlashIndex::range_search(const T *query1, const double range, const uint64_t min_l_search, + const uint64_t max_l_search, std::vector &indices, + std::vector &distances, const uint64_t min_beam_width, + QueryStats *stats) +{ + uint32_t res_count = 0; + + bool stop_flag = false; + + uint32_t l_search = (uint32_t)min_l_search; // starting size of the candidate list + while (!stop_flag) + { + indices.resize(l_search); + distances.resize(l_search); + uint64_t cur_bw = min_beam_width > (l_search / 5) ? min_beam_width : l_search / 5; + cur_bw = (cur_bw > 100) ? 100 : cur_bw; + for (auto &x : distances) + x = std::numeric_limits::max(); + this->cached_beam_search(query1, l_search, l_search, indices.data(), distances.data(), cur_bw, false, stats); + for (uint32_t i = 0; i < l_search; i++) + { + if (distances[i] > (float)range) + { + res_count = i; + break; + } + else if (i == l_search - 1) + res_count = l_search; + } + if (res_count < (uint32_t)(l_search / 2.0)) + stop_flag = true; + l_search = l_search * 2; + if (l_search > max_l_search) + stop_flag = true; } - if (res_count < (uint32_t)(l_search / 2.0)) - stop_flag = true; - l_search = l_search * 2; - if (l_search > max_l_search) - stop_flag = true; - } - indices.resize(res_count); - distances.resize(res_count); - return res_count; + indices.resize(res_count); + distances.resize(res_count); + return res_count; } -template -uint64_t PQFlashIndex::get_data_dim() { - return _data_dim; +template uint64_t PQFlashIndex::get_data_dim() +{ + return _data_dim; } -template -diskann::Metric PQFlashIndex::get_metric() { - return this->metric; +template diskann::Metric PQFlashIndex::get_metric() +{ + return this->metric; } template -LabelT -PQFlashIndex::get_converted_label(const std::string &filter_label) { - return _filter_store->get_converted_label(filter_label); +LabelT PQFlashIndex::get_converted_label(const std::string &filter_label) +{ + return _filter_store->get_converted_label(filter_label); } #ifdef EXEC_ENV_OLS -template -char *PQFlashIndex::getHeaderBytes() { - IOContext &ctx = reader->get_ctx(); - AlignedRead readReq; - readReq.buf = new char[PQFlashIndex::HEADER_SIZE]; - readReq.len = PQFlashIndex::HEADER_SIZE; - readReq.offset = 0; +template char *PQFlashIndex::getHeaderBytes() +{ + IOContext &ctx = reader->get_ctx(); + AlignedRead readReq; + readReq.buf = new char[PQFlashIndex::HEADER_SIZE]; + readReq.len = PQFlashIndex::HEADER_SIZE; + readReq.offset = 0; - std::vector readReqs; - readReqs.push_back(readReq); + std::vector readReqs; + readReqs.push_back(readReq); - reader->read(readReqs, ctx, false); + reader->read(readReqs, ctx, false); - return (char *)readReq.buf; + return (char *)readReq.buf; } #endif template -std::vector -PQFlashIndex::get_pq_vector(std::uint64_t vid) { - std::uint8_t *pqVec = &this->data[vid * this->_n_chunks]; - return std::vector(pqVec, pqVec + this->_n_chunks); +std::vector PQFlashIndex::get_pq_vector(std::uint64_t vid) +{ + std::uint8_t *pqVec = &this->data[vid * this->_n_chunks]; + return std::vector(pqVec, pqVec + this->_n_chunks); } -template -std::uint64_t PQFlashIndex::get_num_points() { - return _num_points; +template std::uint64_t 
PQFlashIndex::get_num_points() +{ + return _num_points; } // instantiations diff --git a/src/pq_l2_distance.cpp b/src/pq_l2_distance.cpp index fc5c3d4e4..9168d26be 100644 --- a/src/pq_l2_distance.cpp +++ b/src/pq_l2_distance.cpp @@ -6,271 +6,275 @@ // block size for reading/processing large files and matrices in blocks #define BLOCK_SIZE 5000000 -namespace diskann { +namespace diskann +{ template -PQL2Distance::PQL2Distance(uint32_t num_chunks, bool use_opq) - : _num_chunks(num_chunks), _is_opq(use_opq) {} +PQL2Distance::PQL2Distance(uint32_t num_chunks, bool use_opq) : _num_chunks(num_chunks), _is_opq(use_opq) +{ +} -template PQL2Distance::~PQL2Distance() { +template PQL2Distance::~PQL2Distance() +{ #ifndef EXEC_ENV_OLS - if (_tables != nullptr) - delete[] _tables; - if (_chunk_offsets != nullptr) - delete[] _chunk_offsets; - if (_centroid != nullptr) - delete[] _centroid; - if (_rotmat_tr != nullptr) - delete[] _rotmat_tr; + if (_tables != nullptr) + delete[] _tables; + if (_chunk_offsets != nullptr) + delete[] _chunk_offsets; + if (_centroid != nullptr) + delete[] _centroid; + if (_rotmat_tr != nullptr) + delete[] _rotmat_tr; #endif - if (_tables_tr != nullptr) - delete[] _tables_tr; + if (_tables_tr != nullptr) + delete[] _tables_tr; } -template bool PQL2Distance::is_opq() const { - return this->_is_opq; +template bool PQL2Distance::is_opq() const +{ + return this->_is_opq; } template -std::string PQL2Distance::get_quantized_vectors_filename( - const std::string &prefix) const { - if (_num_chunks == 0) { - throw diskann::ANNException( - "Must set num_chunks before calling get_quantized_vectors_filename", -1, - __FUNCSIG__, __FILE__, __LINE__); - } - return diskann::get_quantized_vectors_filename(prefix, _is_opq, - (uint32_t)_num_chunks); +std::string PQL2Distance::get_quantized_vectors_filename(const std::string &prefix) const +{ + if (_num_chunks == 0) + { + throw diskann::ANNException("Must set num_chunks before calling get_quantized_vectors_filename", -1, + __FUNCSIG__, __FILE__, __LINE__); + } + return diskann::get_quantized_vectors_filename(prefix, _is_opq, (uint32_t)_num_chunks); } -template -std::string -PQL2Distance::get_pivot_data_filename(const std::string &prefix) const { - if (_num_chunks == 0) { - throw diskann::ANNException( - "Must set num_chunks before calling get_pivot_data_filename", -1, - __FUNCSIG__, __FILE__, __LINE__); - } - return diskann::get_pivot_data_filename(prefix, _is_opq, - (uint32_t)_num_chunks); +template std::string PQL2Distance::get_pivot_data_filename(const std::string &prefix) const +{ + if (_num_chunks == 0) + { + throw diskann::ANNException("Must set num_chunks before calling get_pivot_data_filename", -1, __FUNCSIG__, + __FILE__, __LINE__); + } + return diskann::get_pivot_data_filename(prefix, _is_opq, (uint32_t)_num_chunks); } template -std::string PQL2Distance::get_rotation_matrix_suffix( - const std::string &pq_pivots_filename) const { - return diskann::get_rotation_matrix_suffix(pq_pivots_filename); +std::string PQL2Distance::get_rotation_matrix_suffix(const std::string &pq_pivots_filename) const +{ + return diskann::get_rotation_matrix_suffix(pq_pivots_filename); } #ifdef EXEC_ENV_OLS template -void PQL2Distance::load_pivot_data(MemoryMappedFiles &files, - const std::string &pq_table_file, - size_t num_chunks) { +void PQL2Distance::load_pivot_data(MemoryMappedFiles &files, const std::string &pq_table_file, + size_t num_chunks) +{ #else template -void PQL2Distance::load_pivot_data(const std::string &pq_table_file, - size_t num_chunks) { +void 
PQL2Distance::load_pivot_data(const std::string &pq_table_file, size_t num_chunks) +{ #endif - uint64_t nr, nc; - // std::string rotmat_file = get_opq_rot_matrix_filename(pq_table_file, - // false); + uint64_t nr, nc; + // std::string rotmat_file = get_opq_rot_matrix_filename(pq_table_file, + // false); #ifdef EXEC_ENV_OLS - size_t *file_offset_data; // since load_bin only sets the pointer, no need - // to delete. - diskann::load_bin(files, pq_table_file, file_offset_data, nr, nc); + size_t *file_offset_data; // since load_bin only sets the pointer, no need + // to delete. + diskann::load_bin(files, pq_table_file, file_offset_data, nr, nc); #else - std::unique_ptr file_offset_data; - diskann::load_bin(pq_table_file, file_offset_data, nr, nc); + std::unique_ptr file_offset_data; + diskann::load_bin(pq_table_file, file_offset_data, nr, nc); #endif - bool use_old_filetype = false; + bool use_old_filetype = false; - if (nr != 4 && nr != 5) { - diskann::cout << "Error reading pq_pivots file " << pq_table_file - << ". Offsets dont contain correct metadata, # offsets = " - << nr << ", but expecting " << 4 << " or " << 5; - throw diskann::ANNException("Error reading pq_pivots file at offsets data.", - -1, __FUNCSIG__, __FILE__, __LINE__); - } + if (nr != 4 && nr != 5) + { + diskann::cout << "Error reading pq_pivots file " << pq_table_file + << ". Offsets dont contain correct metadata, # offsets = " << nr << ", but expecting " << 4 + << " or " << 5; + throw diskann::ANNException("Error reading pq_pivots file at offsets data.", -1, __FUNCSIG__, __FILE__, + __LINE__); + } - if (nr == 4) { - diskann::cout << "Offsets: " << file_offset_data[0] << " " - << file_offset_data[1] << " " << file_offset_data[2] << " " - << file_offset_data[3] << std::endl; - } else if (nr == 5) { - use_old_filetype = true; - diskann::cout << "Offsets: " << file_offset_data[0] << " " - << file_offset_data[1] << " " << file_offset_data[2] << " " - << file_offset_data[3] << file_offset_data[4] << std::endl; - } else { - throw diskann::ANNException("Wrong number of offsets in pq_pivots", -1, - __FUNCSIG__, __FILE__, __LINE__); - } + if (nr == 4) + { + diskann::cout << "Offsets: " << file_offset_data[0] << " " << file_offset_data[1] << " " << file_offset_data[2] + << " " << file_offset_data[3] << std::endl; + } + else if (nr == 5) + { + use_old_filetype = true; + diskann::cout << "Offsets: " << file_offset_data[0] << " " << file_offset_data[1] << " " << file_offset_data[2] + << " " << file_offset_data[3] << file_offset_data[4] << std::endl; + } + else + { + throw diskann::ANNException("Wrong number of offsets in pq_pivots", -1, __FUNCSIG__, __FILE__, __LINE__); + } #ifdef EXEC_ENV_OLS - diskann::load_bin(files, pq_table_file, tables, nr, nc, - file_offset_data[0]); + diskann::load_bin(files, pq_table_file, tables, nr, nc, file_offset_data[0]); #else - diskann::load_bin(pq_table_file, _tables, nr, nc, file_offset_data[0]); + diskann::load_bin(pq_table_file, _tables, nr, nc, file_offset_data[0]); #endif - if ((nr != NUM_PQ_CENTROIDS)) { - diskann::cout << "Error reading pq_pivots file " << pq_table_file - << ". file_num_centers = " << nr << " but expecting " - << NUM_PQ_CENTROIDS << " centers"; - throw diskann::ANNException("Error reading pq_pivots file at pivots data.", - -1, __FUNCSIG__, __FILE__, __LINE__); - } + if ((nr != NUM_PQ_CENTROIDS)) + { + diskann::cout << "Error reading pq_pivots file " << pq_table_file << ". 
file_num_centers = " << nr + << " but expecting " << NUM_PQ_CENTROIDS << " centers"; + throw diskann::ANNException("Error reading pq_pivots file at pivots data.", -1, __FUNCSIG__, __FILE__, + __LINE__); + } - this->_ndims = nc; + this->_ndims = nc; #ifdef EXEC_ENV_OLS - diskann::load_bin(files, pq_table_file, centroid, nr, nc, - file_offset_data[1]); + diskann::load_bin(files, pq_table_file, centroid, nr, nc, file_offset_data[1]); #else - diskann::load_bin(pq_table_file, _centroid, nr, nc, - file_offset_data[1]); + diskann::load_bin(pq_table_file, _centroid, nr, nc, file_offset_data[1]); #endif - if ((nr != this->_ndims) || (nc != 1)) { - diskann::cerr << "Error reading centroids from pq_pivots file " - << pq_table_file << ". file_dim = " << nr - << ", file_cols = " << nc << " but expecting " << this->_ndims - << " entries in 1 dimension."; - throw diskann::ANNException( - "Error reading pq_pivots file at centroid data.", -1, __FUNCSIG__, - __FILE__, __LINE__); - } + if ((nr != this->_ndims) || (nc != 1)) + { + diskann::cerr << "Error reading centroids from pq_pivots file " << pq_table_file << ". file_dim = " << nr + << ", file_cols = " << nc << " but expecting " << this->_ndims << " entries in 1 dimension."; + throw diskann::ANNException("Error reading pq_pivots file at centroid data.", -1, __FUNCSIG__, __FILE__, + __LINE__); + } - int chunk_offsets_index = 2; - if (use_old_filetype) { - chunk_offsets_index = 3; - } + int chunk_offsets_index = 2; + if (use_old_filetype) + { + chunk_offsets_index = 3; + } #ifdef EXEC_ENV_OLS - diskann::load_bin(files, pq_table_file, chunk_offsets, nr, nc, - file_offset_data[chunk_offsets_index]); + diskann::load_bin(files, pq_table_file, chunk_offsets, nr, nc, file_offset_data[chunk_offsets_index]); #else - diskann::load_bin(pq_table_file, _chunk_offsets, nr, nc, - file_offset_data[chunk_offsets_index]); + diskann::load_bin(pq_table_file, _chunk_offsets, nr, nc, file_offset_data[chunk_offsets_index]); #endif - if (nc != 1 || (nr != num_chunks + 1 && num_chunks != 0)) { - diskann::cerr << "Error loading chunk offsets file. numc: " << nc - << " (should be 1). numr: " << nr << " (should be " - << num_chunks + 1 << " or 0 if we need to infer)" - << std::endl; - throw diskann::ANNException("Error loading chunk offsets file", -1, - __FUNCSIG__, __FILE__, __LINE__); - } + if (nc != 1 || (nr != num_chunks + 1 && num_chunks != 0)) + { + diskann::cerr << "Error loading chunk offsets file. numc: " << nc << " (should be 1). numr: " << nr + << " (should be " << num_chunks + 1 << " or 0 if we need to infer)" << std::endl; + throw diskann::ANNException("Error loading chunk offsets file", -1, __FUNCSIG__, __FILE__, __LINE__); + } - this->_num_chunks = nr - 1; - diskann::cout << "Loaded PQ Pivots: #ctrs: " << NUM_PQ_CENTROIDS - << ", #dims: " << this->_ndims - << ", #chunks: " << this->_num_chunks << std::endl; + this->_num_chunks = nr - 1; + diskann::cout << "Loaded PQ Pivots: #ctrs: " << NUM_PQ_CENTROIDS << ", #dims: " << this->_ndims + << ", #chunks: " << this->_num_chunks << std::endl; - // For OPQ there will be a rotation matrix to load. - if (this->_is_opq) { - std::string rotmat_file = get_rotation_matrix_suffix(pq_table_file); + // For OPQ there will be a rotation matrix to load. 
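// The rotation matrix loaded in the branch below is applied at query time (see preprocess_query
// further down): the centroid-centered query is multiplied against the stored, transposed rotation
// before the per-chunk distance tables are built. A small sketch of that step, with illustrative
// names only:

#include <cstdint>

// Applies the stored (transposed, row-major, ndims x ndims) rotation to a centered query,
// mirroring the inner loop of preprocess_query: out[d] = sum_d1 centered_q[d1] * rotmat_tr[d1 * ndims + d].
inline void rotate_centered_query(const float *centered_q, const float *rotmat_tr, uint32_t ndims, float *out)
{
    for (uint32_t d = 0; d < ndims; d++)
    {
        float acc = 0.0f;
        for (uint32_t d1 = 0; d1 < ndims; d1++)
            acc += centered_q[d1] * rotmat_tr[d1 * ndims + d];
        out[d] = acc;
    }
}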
+ if (this->_is_opq) + { + std::string rotmat_file = get_rotation_matrix_suffix(pq_table_file); #ifdef EXEC_ENV_OLS - diskann::load_bin(files, rotmat_file, (float *&)rotmat_tr, nr, nc); + diskann::load_bin(files, rotmat_file, (float *&)rotmat_tr, nr, nc); #else - diskann::load_bin(rotmat_file, _rotmat_tr, nr, nc); + diskann::load_bin(rotmat_file, _rotmat_tr, nr, nc); #endif - if (nr != this->_ndims || nc != this->_ndims) { - diskann::cerr << "Error loading rotation matrix file" << std::endl; - throw diskann::ANNException("Error loading rotation matrix file", -1, - __FUNCSIG__, __FILE__, __LINE__); + if (nr != this->_ndims || nc != this->_ndims) + { + diskann::cerr << "Error loading rotation matrix file" << std::endl; + throw diskann::ANNException("Error loading rotation matrix file", -1, __FUNCSIG__, __FILE__, __LINE__); + } } - } - // alloc and compute transpose - _tables_tr = new float[256 * this->_ndims]; - for (size_t i = 0; i < 256; i++) { - for (size_t j = 0; j < this->_ndims; j++) { - _tables_tr[j * 256 + i] = _tables[i * this->_ndims + j]; + // alloc and compute transpose + _tables_tr = new float[256 * this->_ndims]; + for (size_t i = 0; i < 256; i++) + { + for (size_t j = 0; j < this->_ndims; j++) + { + _tables_tr[j * 256 + i] = _tables[i * this->_ndims + j]; + } } - } } -template -uint32_t PQL2Distance::get_num_chunks() const { - return static_cast(_num_chunks); +template uint32_t PQL2Distance::get_num_chunks() const +{ + return static_cast(_num_chunks); } // REFACTOR: Instead of doing half the work in the caller and half in this // function, we let this function // do all of the work, making it easier for the caller. template -void PQL2Distance::preprocess_query(const data_t *aligned_query, - uint32_t dim, - PQScratch &scratch) { - // Copy query vector to float and then to "rotated" query - for (size_t d = 0; d < dim; d++) { - scratch.aligned_query_float[d] = (float)aligned_query[d]; - } - scratch.initialize(dim, aligned_query); +void PQL2Distance::preprocess_query(const data_t *aligned_query, uint32_t dim, PQScratch &scratch) +{ + // Copy query vector to float and then to "rotated" query + for (size_t d = 0; d < dim; d++) + { + scratch.aligned_query_float[d] = (float)aligned_query[d]; + } + scratch.initialize(dim, aligned_query); - for (uint32_t d = 0; d < _ndims; d++) { - scratch.rotated_query[d] -= _centroid[d]; - } - std::vector tmp(_ndims, 0); - if (_is_opq) { - for (uint32_t d = 0; d < _ndims; d++) { - for (uint32_t d1 = 0; d1 < _ndims; d1++) { - tmp[d] += scratch.rotated_query[d1] * _rotmat_tr[d1 * _ndims + d]; - } + for (uint32_t d = 0; d < _ndims; d++) + { + scratch.rotated_query[d] -= _centroid[d]; } - std::memcpy(scratch.rotated_query, tmp.data(), _ndims * sizeof(float)); - } - this->prepopulate_chunkwise_distances(scratch.rotated_query, - scratch.aligned_pqtable_dist_scratch); + std::vector tmp(_ndims, 0); + if (_is_opq) + { + for (uint32_t d = 0; d < _ndims; d++) + { + for (uint32_t d1 = 0; d1 < _ndims; d1++) + { + tmp[d] += scratch.rotated_query[d1] * _rotmat_tr[d1 * _ndims + d]; + } + } + std::memcpy(scratch.rotated_query, tmp.data(), _ndims * sizeof(float)); + } + this->prepopulate_chunkwise_distances(scratch.rotated_query, scratch.aligned_pqtable_dist_scratch); } template -void PQL2Distance::preprocessed_distance(PQScratch &pq_scratch, - const uint32_t n_ids, - float *dists_out) { - pq_dist_lookup(pq_scratch.aligned_pq_coord_scratch, n_ids, _num_chunks, - pq_scratch.aligned_pqtable_dist_scratch, dists_out); +void PQL2Distance::preprocessed_distance(PQScratch 
&pq_scratch, const uint32_t n_ids, float *dists_out) +{ + pq_dist_lookup(pq_scratch.aligned_pq_coord_scratch, n_ids, _num_chunks, pq_scratch.aligned_pqtable_dist_scratch, + dists_out); } template -void PQL2Distance::preprocessed_distance( - PQScratch &pq_scratch, const uint32_t n_ids, - std::vector &dists_out) { - pq_dist_lookup(pq_scratch.aligned_pq_coord_scratch, n_ids, _num_chunks, - pq_scratch.aligned_pqtable_dist_scratch, dists_out); +void PQL2Distance::preprocessed_distance(PQScratch &pq_scratch, const uint32_t n_ids, + std::vector &dists_out) +{ + pq_dist_lookup(pq_scratch.aligned_pq_coord_scratch, n_ids, _num_chunks, pq_scratch.aligned_pqtable_dist_scratch, + dists_out); } -template -float PQL2Distance::brute_force_distance(const float *query_vec, - uint8_t *base_vec) { - float res = 0; - for (size_t chunk = 0; chunk < _num_chunks; chunk++) { - for (size_t j = _chunk_offsets[chunk]; j < _chunk_offsets[chunk + 1]; j++) { - const float *centers_dim_vec = _tables_tr + (256 * j); - float diff = centers_dim_vec[base_vec[chunk]] - (query_vec[j]); - res += diff * diff; +template float PQL2Distance::brute_force_distance(const float *query_vec, uint8_t *base_vec) +{ + float res = 0; + for (size_t chunk = 0; chunk < _num_chunks; chunk++) + { + for (size_t j = _chunk_offsets[chunk]; j < _chunk_offsets[chunk + 1]; j++) + { + const float *centers_dim_vec = _tables_tr + (256 * j); + float diff = centers_dim_vec[base_vec[chunk]] - (query_vec[j]); + res += diff * diff; + } } - } - return res; + return res; } template -void PQL2Distance::prepopulate_chunkwise_distances( - const float *query_vec, float *dist_vec) { - memset(dist_vec, 0, 256 * _num_chunks * sizeof(float)); - // chunk wise distance computation - for (size_t chunk = 0; chunk < _num_chunks; chunk++) { - // sum (q-c)^2 for the dimensions associated with this chunk - float *chunk_dists = dist_vec + (256 * chunk); - for (size_t j = _chunk_offsets[chunk]; j < _chunk_offsets[chunk + 1]; j++) { - const float *centers_dim_vec = _tables_tr + (256 * j); - for (size_t idx = 0; idx < 256; idx++) { - double diff = centers_dim_vec[idx] - (query_vec[j]); - chunk_dists[idx] += (float)(diff * diff); - } +void PQL2Distance::prepopulate_chunkwise_distances(const float *query_vec, float *dist_vec) +{ + memset(dist_vec, 0, 256 * _num_chunks * sizeof(float)); + // chunk wise distance computation + for (size_t chunk = 0; chunk < _num_chunks; chunk++) + { + // sum (q-c)^2 for the dimensions associated with this chunk + float *chunk_dists = dist_vec + (256 * chunk); + for (size_t j = _chunk_offsets[chunk]; j < _chunk_offsets[chunk + 1]; j++) + { + const float *centers_dim_vec = _tables_tr + (256 * j); + for (size_t idx = 0; idx < 256; idx++) + { + double diff = centers_dim_vec[idx] - (query_vec[j]); + chunk_dists[idx] += (float)(diff * diff); + } + } } - } } template DISKANN_DLLEXPORT class PQL2Distance; diff --git a/src/scratch.cpp b/src/scratch.cpp index 14ee81832..2203dcbc3 100644 --- a/src/scratch.cpp +++ b/src/scratch.cpp @@ -7,162 +7,160 @@ #include "pq_scratch.h" #include "scratch.h" -namespace diskann { +namespace diskann +{ // // Functions to manage scratch space for in-memory index based search // template -InMemQueryScratch::InMemQueryScratch(uint32_t search_l, uint32_t indexing_l, - uint32_t r, uint32_t maxc, size_t dim, - size_t aligned_dim, - size_t alignment_factor, - bool init_pq_scratch) - : _L(0), _R(r), _maxc(maxc) { - if (search_l == 0 || indexing_l == 0 || r == 0 || dim == 0) { - std::stringstream ss; - ss << "In InMemQueryScratch, one of 
search_l = " << search_l - << ", indexing_l = " << indexing_l << ", dim = " << dim - << " or r = " << r << " is zero." << std::endl; - throw diskann::ANNException(ss.str(), -1); - } - - alloc_aligned(((void **)&this->_aligned_query_T), aligned_dim * sizeof(T), - alignment_factor * sizeof(T)); - memset(this->_aligned_query_T, 0, aligned_dim * sizeof(T)); - - if (init_pq_scratch) - this->_pq_scratch = - new PQScratch(defaults::MAX_GRAPH_DEGREE, aligned_dim); - else - this->_pq_scratch = nullptr; - - _occlude_factor.reserve(maxc); - _inserted_into_pool_bs = new boost::dynamic_bitset<>(); - _id_scratch.reserve( - (size_t)std::ceil(1.5 * defaults::GRAPH_SLACK_FACTOR * _R)); - _dist_scratch.reserve( - (size_t)std::ceil(1.5 * defaults::GRAPH_SLACK_FACTOR * _R)); - - resize_for_new_L(std::max(search_l, indexing_l)); +InMemQueryScratch::InMemQueryScratch(uint32_t search_l, uint32_t indexing_l, uint32_t r, uint32_t maxc, size_t dim, + size_t aligned_dim, size_t alignment_factor, bool init_pq_scratch) + : _L(0), _R(r), _maxc(maxc) +{ + if (search_l == 0 || indexing_l == 0 || r == 0 || dim == 0) + { + std::stringstream ss; + ss << "In InMemQueryScratch, one of search_l = " << search_l << ", indexing_l = " << indexing_l + << ", dim = " << dim << " or r = " << r << " is zero." << std::endl; + throw diskann::ANNException(ss.str(), -1); + } + + alloc_aligned(((void **)&this->_aligned_query_T), aligned_dim * sizeof(T), alignment_factor * sizeof(T)); + memset(this->_aligned_query_T, 0, aligned_dim * sizeof(T)); + + if (init_pq_scratch) + this->_pq_scratch = new PQScratch(defaults::MAX_GRAPH_DEGREE, aligned_dim); + else + this->_pq_scratch = nullptr; + + _occlude_factor.reserve(maxc); + _inserted_into_pool_bs = new boost::dynamic_bitset<>(); + _id_scratch.reserve((size_t)std::ceil(1.5 * defaults::GRAPH_SLACK_FACTOR * _R)); + _dist_scratch.reserve((size_t)std::ceil(1.5 * defaults::GRAPH_SLACK_FACTOR * _R)); + + resize_for_new_L(std::max(search_l, indexing_l)); } -template void InMemQueryScratch::clear() { - _pool.clear(); - _best_l_nodes.clear(); - _occlude_factor.clear(); +template void InMemQueryScratch::clear() +{ + _pool.clear(); + _best_l_nodes.clear(); + _occlude_factor.clear(); - _inserted_into_pool_rs.clear(); - _inserted_into_pool_bs->reset(); + _inserted_into_pool_rs.clear(); + _inserted_into_pool_bs->reset(); - _id_scratch.clear(); - _dist_scratch.clear(); + _id_scratch.clear(); + _dist_scratch.clear(); - _expanded_nodes_set.clear(); - _expanded_nghrs_vec.clear(); - _occlude_list_output.clear(); + _expanded_nodes_set.clear(); + _expanded_nghrs_vec.clear(); + _occlude_list_output.clear(); } -template -void InMemQueryScratch::resize_for_new_L(uint32_t new_l) { - if (new_l > _L) { - _L = new_l; - _pool.reserve(3 * _L + _R); - _best_l_nodes.reserve(_L); - - _inserted_into_pool_rs.reserve(20 * _L); - } +template void InMemQueryScratch::resize_for_new_L(uint32_t new_l) +{ + if (new_l > _L) + { + _L = new_l; + _pool.reserve(3 * _L + _R); + _best_l_nodes.reserve(_L); + + _inserted_into_pool_rs.reserve(20 * _L); + } } -template InMemQueryScratch::~InMemQueryScratch() { - if (this->_aligned_query_T != nullptr) { - aligned_free(this->_aligned_query_T); - this->_aligned_query_T = nullptr; - } +template InMemQueryScratch::~InMemQueryScratch() +{ + if (this->_aligned_query_T != nullptr) + { + aligned_free(this->_aligned_query_T); + this->_aligned_query_T = nullptr; + } - delete this->_pq_scratch; - delete _inserted_into_pool_bs; + delete this->_pq_scratch; + delete _inserted_into_pool_bs; } // // Functions to 
manage scratch space for SSD based search // -template void SSDQueryScratch::reset() { - sector_idx = 0; - visited.clear(); - retset.clear(); - full_retset.clear(); +template void SSDQueryScratch::reset() +{ + sector_idx = 0; + visited.clear(); + retset.clear(); + full_retset.clear(); } -template -SSDQueryScratch::SSDQueryScratch(size_t aligned_dim, - size_t visited_reserve) { - size_t coord_alloc_size = ROUND_UP(sizeof(T) * aligned_dim, 256); +template SSDQueryScratch::SSDQueryScratch(size_t aligned_dim, size_t visited_reserve) +{ + size_t coord_alloc_size = ROUND_UP(sizeof(T) * aligned_dim, 256); - diskann::alloc_aligned((void **)&coord_scratch, coord_alloc_size, 256); - diskann::alloc_aligned((void **)§or_scratch, - defaults::MAX_N_SECTOR_READS * defaults::SECTOR_LEN, - defaults::SECTOR_LEN); - diskann::alloc_aligned((void **)&this->_aligned_query_T, - aligned_dim * sizeof(T), 8 * sizeof(T)); + diskann::alloc_aligned((void **)&coord_scratch, coord_alloc_size, 256); + diskann::alloc_aligned((void **)§or_scratch, defaults::MAX_N_SECTOR_READS * defaults::SECTOR_LEN, + defaults::SECTOR_LEN); + diskann::alloc_aligned((void **)&this->_aligned_query_T, aligned_dim * sizeof(T), 8 * sizeof(T)); - this->_pq_scratch = new PQScratch(defaults::MAX_GRAPH_DEGREE, aligned_dim); + this->_pq_scratch = new PQScratch(defaults::MAX_GRAPH_DEGREE, aligned_dim); - memset(coord_scratch, 0, coord_alloc_size); - memset(this->_aligned_query_T, 0, aligned_dim * sizeof(T)); + memset(coord_scratch, 0, coord_alloc_size); + memset(this->_aligned_query_T, 0, aligned_dim * sizeof(T)); - visited.reserve(visited_reserve); - full_retset.reserve(visited_reserve); + visited.reserve(visited_reserve); + full_retset.reserve(visited_reserve); } -template SSDQueryScratch::~SSDQueryScratch() { - diskann::aligned_free((void *)coord_scratch); - diskann::aligned_free((void *)sector_scratch); - diskann::aligned_free((void *)this->_aligned_query_T); +template SSDQueryScratch::~SSDQueryScratch() +{ + diskann::aligned_free((void *)coord_scratch); + diskann::aligned_free((void *)sector_scratch); + diskann::aligned_free((void *)this->_aligned_query_T); - delete this->_pq_scratch; + delete this->_pq_scratch; } template -SSDThreadData::SSDThreadData(size_t aligned_dim, size_t visited_reserve) - : scratch(aligned_dim, visited_reserve) {} +SSDThreadData::SSDThreadData(size_t aligned_dim, size_t visited_reserve) : scratch(aligned_dim, visited_reserve) +{ +} -template void SSDThreadData::clear() { scratch.reset(); } +template void SSDThreadData::clear() +{ + scratch.reset(); +} -template -PQScratch::PQScratch(size_t graph_degree, size_t aligned_dim) { - diskann::alloc_aligned( - (void **)&aligned_pq_coord_scratch, - (size_t)graph_degree * (size_t)MAX_PQ_CHUNKS * sizeof(uint8_t), 256); - diskann::alloc_aligned((void **)&aligned_pqtable_dist_scratch, - 256 * (size_t)MAX_PQ_CHUNKS * sizeof(float), 256); - diskann::alloc_aligned((void **)&aligned_dist_scratch, - (size_t)graph_degree * sizeof(float), 256); - diskann::alloc_aligned((void **)&aligned_query_float, - aligned_dim * sizeof(float), 8 * sizeof(float)); - diskann::alloc_aligned((void **)&rotated_query, aligned_dim * sizeof(float), - 8 * sizeof(float)); - - memset(aligned_query_float, 0, aligned_dim * sizeof(float)); - memset(rotated_query, 0, aligned_dim * sizeof(float)); +template PQScratch::PQScratch(size_t graph_degree, size_t aligned_dim) +{ + diskann::alloc_aligned((void **)&aligned_pq_coord_scratch, + (size_t)graph_degree * (size_t)MAX_PQ_CHUNKS * sizeof(uint8_t), 256); + 
diskann::alloc_aligned((void **)&aligned_pqtable_dist_scratch, 256 * (size_t)MAX_PQ_CHUNKS * sizeof(float), 256); + diskann::alloc_aligned((void **)&aligned_dist_scratch, (size_t)graph_degree * sizeof(float), 256); + diskann::alloc_aligned((void **)&aligned_query_float, aligned_dim * sizeof(float), 8 * sizeof(float)); + diskann::alloc_aligned((void **)&rotated_query, aligned_dim * sizeof(float), 8 * sizeof(float)); + + memset(aligned_query_float, 0, aligned_dim * sizeof(float)); + memset(rotated_query, 0, aligned_dim * sizeof(float)); } -template PQScratch::~PQScratch() { - diskann::aligned_free((void *)aligned_pq_coord_scratch); - diskann::aligned_free((void *)aligned_pqtable_dist_scratch); - diskann::aligned_free((void *)aligned_dist_scratch); - diskann::aligned_free((void *)aligned_query_float); - diskann::aligned_free((void *)rotated_query); +template PQScratch::~PQScratch() +{ + diskann::aligned_free((void *)aligned_pq_coord_scratch); + diskann::aligned_free((void *)aligned_pqtable_dist_scratch); + diskann::aligned_free((void *)aligned_dist_scratch); + diskann::aligned_free((void *)aligned_query_float); + diskann::aligned_free((void *)rotated_query); } -template -void PQScratch::initialize(size_t dim, const T *query, const float norm) { - for (size_t d = 0; d < dim; ++d) { - if (norm != 1.0f) - rotated_query[d] = aligned_query_float[d] = - static_cast(query[d]) / norm; - else - rotated_query[d] = aligned_query_float[d] = static_cast(query[d]); - } +template void PQScratch::initialize(size_t dim, const T *query, const float norm) +{ + for (size_t d = 0; d < dim; ++d) + { + if (norm != 1.0f) + rotated_query[d] = aligned_query_float[d] = static_cast(query[d]) / norm; + else + rotated_query[d] = aligned_query_float[d] = static_cast(query[d]); + } } template DISKANN_DLLEXPORT class InMemQueryScratch; diff --git a/src/utils.cpp b/src/utils.cpp index eb3fe8db1..3773cda22 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -17,44 +17,48 @@ const uint32_t MAX_SIMULTANEOUS_READ_REQUESTS = 128; // Taken from: // https://insufficientlycomplicated.wordpress.com/2011/11/07/detecting-intel-advanced-vector-extensions-avx-in-visual-studio/ -bool cpuHasAvxSupport() { - bool avxSupported = false; - - // Checking for AVX requires 3 things: - // 1) CPUID indicates that the OS uses XSAVE and XRSTORE - // instructions (allowing saving YMM registers on context - // switch) - // 2) CPUID indicates support for AVX - // 3) XGETBV indicates the AVX registers will be saved and - // restored on context switch - // - // Note that XGETBV is only available on 686 or later CPUs, so - // the instruction needs to be conditionally run. - int cpuInfo[4]; - __cpuid(cpuInfo, 1); - - bool osUsesXSAVE_XRSTORE = cpuInfo[2] & (1 << 27) || false; - bool cpuAVXSuport = cpuInfo[2] & (1 << 28) || false; - - if (osUsesXSAVE_XRSTORE && cpuAVXSuport) { - // Check if the OS will save the YMM registers - unsigned long long xcrFeatureMask = _xgetbv(_XCR_XFEATURE_ENABLED_MASK); - avxSupported = (xcrFeatureMask & 0x6) || false; - } - - return avxSupported; +bool cpuHasAvxSupport() +{ + bool avxSupported = false; + + // Checking for AVX requires 3 things: + // 1) CPUID indicates that the OS uses XSAVE and XRSTORE + // instructions (allowing saving YMM registers on context + // switch) + // 2) CPUID indicates support for AVX + // 3) XGETBV indicates the AVX registers will be saved and + // restored on context switch + // + // Note that XGETBV is only available on 686 or later CPUs, so + // the instruction needs to be conditionally run. 
+ int cpuInfo[4]; + __cpuid(cpuInfo, 1); + + bool osUsesXSAVE_XRSTORE = cpuInfo[2] & (1 << 27) || false; + bool cpuAVXSuport = cpuInfo[2] & (1 << 28) || false; + + if (osUsesXSAVE_XRSTORE && cpuAVXSuport) + { + // Check if the OS will save the YMM registers + unsigned long long xcrFeatureMask = _xgetbv(_XCR_XFEATURE_ENABLED_MASK); + avxSupported = (xcrFeatureMask & 0x6) || false; + } + + return avxSupported; } -bool cpuHasAvx2Support() { - int cpuInfo[4]; - __cpuid(cpuInfo, 0); - int n = cpuInfo[0]; - if (n >= 7) { - __cpuidex(cpuInfo, 7, 0); - static int avx2Mask = 0x20; - return (cpuInfo[1] & avx2Mask) > 0; - } - return false; +bool cpuHasAvx2Support() +{ + int cpuInfo[4]; + __cpuid(cpuInfo, 0); + int n = cpuInfo[0]; + if (n >= 7) + { + __cpuidex(cpuInfo, 7, 0); + static int avx2Mask = 0x20; + return (cpuInfo[1] & avx2Mask) > 0; + } + return false; } bool AvxSupportedCPU = cpuHasAvxSupport(); @@ -66,422 +70,407 @@ bool Avx2SupportedCPU = true; bool AvxSupportedCPU = false; #endif -namespace diskann { +namespace diskann +{ -void block_convert(std::ofstream &writr, std::ifstream &readr, float *read_buf, - size_t npts, size_t ndims) { - readr.read((char *)read_buf, npts * ndims * sizeof(float)); - uint32_t ndims_u32 = (uint32_t)ndims; +void block_convert(std::ofstream &writr, std::ifstream &readr, float *read_buf, size_t npts, size_t ndims) +{ + readr.read((char *)read_buf, npts * ndims * sizeof(float)); + uint32_t ndims_u32 = (uint32_t)ndims; #pragma omp parallel for - for (int64_t i = 0; i < (int64_t)npts; i++) { - float norm_pt = std::numeric_limits::epsilon(); - for (uint32_t dim = 0; dim < ndims_u32; dim++) { - norm_pt += *(read_buf + i * ndims + dim) * *(read_buf + i * ndims + dim); - } - norm_pt = std::sqrt(norm_pt); - for (uint32_t dim = 0; dim < ndims_u32; dim++) { - *(read_buf + i * ndims + dim) = *(read_buf + i * ndims + dim) / norm_pt; + for (int64_t i = 0; i < (int64_t)npts; i++) + { + float norm_pt = std::numeric_limits::epsilon(); + for (uint32_t dim = 0; dim < ndims_u32; dim++) + { + norm_pt += *(read_buf + i * ndims + dim) * *(read_buf + i * ndims + dim); + } + norm_pt = std::sqrt(norm_pt); + for (uint32_t dim = 0; dim < ndims_u32; dim++) + { + *(read_buf + i * ndims + dim) = *(read_buf + i * ndims + dim) / norm_pt; + } } - } - writr.write((char *)read_buf, npts * ndims * sizeof(float)); + writr.write((char *)read_buf, npts * ndims * sizeof(float)); } -void normalize_data_file(const std::string &inFileName, - const std::string &outFileName) { - std::ifstream readr(inFileName, std::ios::binary); - std::ofstream writr(outFileName, std::ios::binary); - - int npts_s32, ndims_s32; - readr.read((char *)&npts_s32, sizeof(int32_t)); - readr.read((char *)&ndims_s32, sizeof(int32_t)); - - writr.write((char *)&npts_s32, sizeof(int32_t)); - writr.write((char *)&ndims_s32, sizeof(int32_t)); - - size_t npts = (size_t)npts_s32; - size_t ndims = (size_t)ndims_s32; - diskann::cout << "Normalizing FLOAT vectors in file: " << inFileName - << std::endl; - diskann::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims - << std::endl; - - size_t blk_size = 131072; - size_t nblks = ROUND_UP(npts, blk_size) / blk_size; - diskann::cout << "# blks: " << nblks << std::endl; - - float *read_buf = new float[npts * ndims]; - for (size_t i = 0; i < nblks; i++) { - size_t cblk_size = std::min(npts - i * blk_size, blk_size); - block_convert(writr, readr, read_buf, cblk_size, ndims); - } - delete[] read_buf; - - diskann::cout << "Wrote normalized points to file: " << outFileName - << std::endl; -} 
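// The cpuHasAvxSupport()/cpuHasAvx2Support() checks above are written with MSVC intrinsics
// (__cpuid, __cpuidex, _xgetbv). On GCC/Clang x86 targets the same runtime test can be expressed
// with compiler builtins; this is only an illustrative alternative, not code from this patch.
#include <cstdio>

int main()
{
#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
    __builtin_cpu_init(); // initialise the feature cache before querying it
    const bool has_avx = __builtin_cpu_supports("avx");
    const bool has_avx2 = __builtin_cpu_supports("avx2");
    std::printf("AVX: %d, AVX2: %d\n", has_avx ? 1 : 0, has_avx2 ? 1 : 0);
#else
    std::printf("Non-x86 or non-GCC/Clang build; CPUID-based checks not applicable.\n");
#endif
    return 0;
}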
+void normalize_data_file(const std::string &inFileName, const std::string &outFileName) +{ + std::ifstream readr(inFileName, std::ios::binary); + std::ofstream writr(outFileName, std::ios::binary); -double calculate_recall(uint32_t num_queries, uint32_t *gold_std, - float *gs_dist, uint32_t dim_gs, uint32_t *our_results, - uint32_t dim_or, uint32_t recall_at) { - double total_recall = 0; - std::set gt, res; - - for (size_t i = 0; i < num_queries; i++) { - gt.clear(); - res.clear(); - uint32_t *gt_vec = gold_std + dim_gs * i; - uint32_t *res_vec = our_results + dim_or * i; - size_t tie_breaker = recall_at; - if (gs_dist != nullptr) { - tie_breaker = recall_at - 1; - float *gt_dist_vec = gs_dist + dim_gs * i; - while (tie_breaker < dim_gs && - gt_dist_vec[tie_breaker] == gt_dist_vec[recall_at - 1]) - tie_breaker++; - } + int npts_s32, ndims_s32; + readr.read((char *)&npts_s32, sizeof(int32_t)); + readr.read((char *)&ndims_s32, sizeof(int32_t)); + + writr.write((char *)&npts_s32, sizeof(int32_t)); + writr.write((char *)&ndims_s32, sizeof(int32_t)); - gt.insert(gt_vec, gt_vec + tie_breaker); - res.insert(res_vec, - res_vec + recall_at); // change to recall_at for recall k@k - // or dim_or for k@dim_or - uint32_t cur_recall = 0; - for (auto &v : gt) { - if (res.find(v) != res.end()) { - cur_recall++; - } + size_t npts = (size_t)npts_s32; + size_t ndims = (size_t)ndims_s32; + diskann::cout << "Normalizing FLOAT vectors in file: " << inFileName << std::endl; + diskann::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl; + + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; + diskann::cout << "# blks: " << nblks << std::endl; + + float *read_buf = new float[npts * ndims]; + for (size_t i = 0; i < nblks; i++) + { + size_t cblk_size = std::min(npts - i * blk_size, blk_size); + block_convert(writr, readr, read_buf, cblk_size, ndims); } - total_recall += cur_recall; - } - return total_recall / (num_queries) * (100.0 / recall_at); + delete[] read_buf; + + diskann::cout << "Wrote normalized points to file: " << outFileName << std::endl; } -double calculate_recall(uint32_t num_queries, uint32_t *gold_std, - float *gs_dist, uint32_t dim_gs, uint32_t *our_results, - uint32_t dim_or, uint32_t recall_at, - const tsl::robin_set &active_tags) { - double total_recall = 0; - std::set gt, res; - bool printed = false; - for (size_t i = 0; i < num_queries; i++) { - gt.clear(); - res.clear(); - uint32_t *gt_vec = gold_std + dim_gs * i; - uint32_t *res_vec = our_results + dim_or * i; - size_t tie_breaker = recall_at; - uint32_t active_points_count = 0; - uint32_t cur_counter = 0; - while (active_points_count < recall_at && cur_counter < dim_gs) { - if (active_tags.find(*(gt_vec + cur_counter)) != active_tags.end()) { - active_points_count++; - } - cur_counter++; - } - if (active_tags.empty()) - cur_counter = recall_at; - - if ((active_points_count < recall_at && !active_tags.empty()) && !printed) { - diskann::cout << "Warning: Couldn't find enough closest neighbors " - << active_points_count << "/" << recall_at - << " from " - "truthset for query # " - << i << ". Will result in under-reported value of recall." 
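// block_convert()/normalize_data_file() above rescale every vector to unit L2 norm, seeding the
// accumulator with epsilon so an all-zero vector cannot divide by zero. A single-vector version
// of that arithmetic, for illustration only:
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

int main()
{
    std::vector<float> v = {3.0f, 4.0f}; // toy vector; expected result (0.6, 0.8)
    float norm = std::numeric_limits<float>::epsilon();
    for (float x : v)
        norm += x * x;
    norm = std::sqrt(norm);
    for (float &x : v)
        x /= norm;
    std::printf("%f %f\n", v[0], v[1]);
    return 0;
}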
- << std::endl; - printed = true; - } - if (gs_dist != nullptr) { - tie_breaker = cur_counter - 1; - float *gt_dist_vec = gs_dist + dim_gs * i; - while (tie_breaker < dim_gs && - gt_dist_vec[tie_breaker] == gt_dist_vec[cur_counter - 1]) - tie_breaker++; +double calculate_recall(uint32_t num_queries, uint32_t *gold_std, float *gs_dist, uint32_t dim_gs, + uint32_t *our_results, uint32_t dim_or, uint32_t recall_at) +{ + double total_recall = 0; + std::set gt, res; + + for (size_t i = 0; i < num_queries; i++) + { + gt.clear(); + res.clear(); + uint32_t *gt_vec = gold_std + dim_gs * i; + uint32_t *res_vec = our_results + dim_or * i; + size_t tie_breaker = recall_at; + if (gs_dist != nullptr) + { + tie_breaker = recall_at - 1; + float *gt_dist_vec = gs_dist + dim_gs * i; + while (tie_breaker < dim_gs && gt_dist_vec[tie_breaker] == gt_dist_vec[recall_at - 1]) + tie_breaker++; + } + + gt.insert(gt_vec, gt_vec + tie_breaker); + res.insert(res_vec, + res_vec + recall_at); // change to recall_at for recall k@k + // or dim_or for k@dim_or + uint32_t cur_recall = 0; + for (auto &v : gt) + { + if (res.find(v) != res.end()) + { + cur_recall++; + } + } + total_recall += cur_recall; } + return total_recall / (num_queries) * (100.0 / recall_at); +} - gt.insert(gt_vec, gt_vec + tie_breaker); - res.insert(res_vec, res_vec + recall_at); - uint32_t cur_recall = 0; - for (auto &v : res) { - if (gt.find(v) != gt.end()) { - cur_recall++; - } +double calculate_recall(uint32_t num_queries, uint32_t *gold_std, float *gs_dist, uint32_t dim_gs, + uint32_t *our_results, uint32_t dim_or, uint32_t recall_at, + const tsl::robin_set &active_tags) +{ + double total_recall = 0; + std::set gt, res; + bool printed = false; + for (size_t i = 0; i < num_queries; i++) + { + gt.clear(); + res.clear(); + uint32_t *gt_vec = gold_std + dim_gs * i; + uint32_t *res_vec = our_results + dim_or * i; + size_t tie_breaker = recall_at; + uint32_t active_points_count = 0; + uint32_t cur_counter = 0; + while (active_points_count < recall_at && cur_counter < dim_gs) + { + if (active_tags.find(*(gt_vec + cur_counter)) != active_tags.end()) + { + active_points_count++; + } + cur_counter++; + } + if (active_tags.empty()) + cur_counter = recall_at; + + if ((active_points_count < recall_at && !active_tags.empty()) && !printed) + { + diskann::cout << "Warning: Couldn't find enough closest neighbors " << active_points_count << "/" + << recall_at + << " from " + "truthset for query # " + << i << ". Will result in under-reported value of recall." 
<< std::endl; + printed = true; + } + if (gs_dist != nullptr) + { + tie_breaker = cur_counter - 1; + float *gt_dist_vec = gs_dist + dim_gs * i; + while (tie_breaker < dim_gs && gt_dist_vec[tie_breaker] == gt_dist_vec[cur_counter - 1]) + tie_breaker++; + } + + gt.insert(gt_vec, gt_vec + tie_breaker); + res.insert(res_vec, res_vec + recall_at); + uint32_t cur_recall = 0; + for (auto &v : res) + { + if (gt.find(v) != gt.end()) + { + cur_recall++; + } + } + total_recall += cur_recall; } - total_recall += cur_recall; - } - return ((double)(total_recall / (num_queries))) * - ((double)(100.0 / recall_at)); + return ((double)(total_recall / (num_queries))) * ((double)(100.0 / recall_at)); } -double -calculate_range_search_recall(uint32_t num_queries, - std::vector> &groundtruth, - std::vector> &our_results) { - double total_recall = 0; - std::set gt, res; - - for (size_t i = 0; i < num_queries; i++) { - gt.clear(); - res.clear(); - - gt.insert(groundtruth[i].begin(), groundtruth[i].end()); - res.insert(our_results[i].begin(), our_results[i].end()); - uint32_t cur_recall = 0; - for (auto &v : gt) { - if (res.find(v) != res.end()) { - cur_recall++; - } +double calculate_range_search_recall(uint32_t num_queries, std::vector> &groundtruth, + std::vector> &our_results) +{ + double total_recall = 0; + std::set gt, res; + + for (size_t i = 0; i < num_queries; i++) + { + gt.clear(); + res.clear(); + + gt.insert(groundtruth[i].begin(), groundtruth[i].end()); + res.insert(our_results[i].begin(), our_results[i].end()); + uint32_t cur_recall = 0; + for (auto &v : gt) + { + if (res.find(v) != res.end()) + { + cur_recall++; + } + } + if (gt.size() != 0) + total_recall += ((100.0 * cur_recall) / gt.size()); + else + total_recall += 100; } - if (gt.size() != 0) - total_recall += ((100.0 * cur_recall) / gt.size()); - else - total_recall += 100; - } - return total_recall / (num_queries); + return total_recall / (num_queries); } #ifdef EXEC_ENV_OLS -void get_bin_metadata(AlignedFileReader &reader, size_t &npts, size_t &ndim, - size_t offset) { - std::vector readReqs; - AlignedRead readReq; - uint32_t buf[2]; // npts/ndim are uint32_ts. - - readReq.buf = buf; - readReq.offset = offset; - readReq.len = 2 * sizeof(uint32_t); - readReqs.push_back(readReq); - - IOContext &ctx = reader.get_ctx(); - reader.read(readReqs, ctx); // synchronous - if ((*(ctx.m_pRequestsStatus))[0] == IOContext::READ_SUCCESS) { - npts = buf[0]; - ndim = buf[1]; - diskann::cout << "File has: " << npts << " points, " << ndim - << " dimensions at offset: " << offset << std::endl; - } else { - std::stringstream str; - str << "Could not read binary metadata from index file at offset: " - << offset << std::endl; - throw diskann::ANNException(str.str(), -1, __FUNCSIG__, __FILE__, __LINE__); - } -} - -template -void load_bin(AlignedFileReader &reader, T *&data, size_t &npts, size_t &ndim, - size_t offset) { - // Code assumes that the reader is already setup correctly. - get_bin_metadata(reader, npts, ndim, offset); - data = new T[npts * ndim]; - - size_t data_size = npts * ndim * sizeof(T); - size_t write_offset = 0; - size_t read_start = offset + 2 * sizeof(uint32_t); - - // BingAlignedFileReader can only read uint32_t bytes of data. So, - // we limit ourselves even more to reading 1GB at a time. 
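// calculate_recall() above computes, per query, the size of the intersection between the top-k
// ground-truth ids and the returned ids, scaled to a percentage of k (with tie-breaking on equal
// ground-truth distances). A toy single-query example of the same set arithmetic, independent of
// the DiskANN data structures:
#include <cstdint>
#include <cstdio>
#include <set>

int main()
{
    const uint32_t recall_at = 4;
    std::set<uint32_t> gt = {7, 3, 9, 21};   // 4 nearest neighbours per the ground truth
    std::set<uint32_t> res = {3, 9, 50, 21}; // ids returned by the index
    uint32_t hits = 0;
    for (uint32_t v : gt)
        if (res.count(v) != 0)
            hits++;
    double recall = 100.0 * hits / recall_at; // 3 of 4 found -> 75%
    std::printf("recall@%u = %.1f%%\n", recall_at, recall);
    return 0;
}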
- std::vector readReqs; - while (data_size > 0) { +void get_bin_metadata(AlignedFileReader &reader, size_t &npts, size_t &ndim, size_t offset) +{ + std::vector readReqs; AlignedRead readReq; - readReq.buf = data + write_offset; - readReq.offset = read_start + write_offset; - readReq.len = data_size > MAX_REQUEST_SIZE ? MAX_REQUEST_SIZE : data_size; + uint32_t buf[2]; // npts/ndim are uint32_ts. + + readReq.buf = buf; + readReq.offset = offset; + readReq.len = 2 * sizeof(uint32_t); readReqs.push_back(readReq); - // in the corner case, the loop will not execute - data_size -= readReq.len; - write_offset += readReq.len; - } - IOContext &ctx = reader.get_ctx(); - reader.read(readReqs, ctx); - for (int i = 0; i < readReqs.size(); i++) { - // Since we are making sync calls, no request will be in the - // READ_WAIT state. - if ((*(ctx.m_pRequestsStatus))[i] != IOContext::READ_SUCCESS) { - std::stringstream str; - str << "Could not read binary data from index file at offset: " - << readReqs[i].offset << std::endl; - throw diskann::ANNException(str.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); + + IOContext &ctx = reader.get_ctx(); + reader.read(readReqs, ctx); // synchronous + if ((*(ctx.m_pRequestsStatus))[0] == IOContext::READ_SUCCESS) + { + npts = buf[0]; + ndim = buf[1]; + diskann::cout << "File has: " << npts << " points, " << ndim << " dimensions at offset: " << offset + << std::endl; + } + else + { + std::stringstream str; + str << "Could not read binary metadata from index file at offset: " << offset << std::endl; + throw diskann::ANNException(str.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } +} + +template void load_bin(AlignedFileReader &reader, T *&data, size_t &npts, size_t &ndim, size_t offset) +{ + // Code assumes that the reader is already setup correctly. + get_bin_metadata(reader, npts, ndim, offset); + data = new T[npts * ndim]; + + size_t data_size = npts * ndim * sizeof(T); + size_t write_offset = 0; + size_t read_start = offset + 2 * sizeof(uint32_t); + + // BingAlignedFileReader can only read uint32_t bytes of data. So, + // we limit ourselves even more to reading 1GB at a time. + std::vector readReqs; + while (data_size > 0) + { + AlignedRead readReq; + readReq.buf = data + write_offset; + readReq.offset = read_start + write_offset; + readReq.len = data_size > MAX_REQUEST_SIZE ? MAX_REQUEST_SIZE : data_size; + readReqs.push_back(readReq); + // in the corner case, the loop will not execute + data_size -= readReq.len; + write_offset += readReq.len; + } + IOContext &ctx = reader.get_ctx(); + reader.read(readReqs, ctx); + for (int i = 0; i < readReqs.size(); i++) + { + // Since we are making sync calls, no request will be in the + // READ_WAIT state. 
+ if ((*(ctx.m_pRequestsStatus))[i] != IOContext::READ_SUCCESS) + { + std::stringstream str; + str << "Could not read binary data from index file at offset: " << readReqs[i].offset << std::endl; + throw diskann::ANNException(str.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } } - } } template -void load_bin(AlignedFileReader &reader, std::unique_ptr &data, - size_t &npts, size_t &ndim, size_t offset) { - T *ptr = nullptr; - load_bin(reader, ptr, npts, ndim, offset); - data.reset(ptr); +void load_bin(AlignedFileReader &reader, std::unique_ptr &data, size_t &npts, size_t &ndim, size_t offset) +{ + T *ptr = nullptr; + load_bin(reader, ptr, npts, ndim, offset); + data.reset(ptr); } template -void copy_aligned_data_from_file(AlignedFileReader &reader, T *&data, - size_t &npts, size_t &ndim, - const size_t &rounded_dim, size_t offset) { - if (data == nullptr) { - diskann::cerr << "Memory was not allocated for " << data - << " before calling the load function. Exiting..." - << std::endl; - throw diskann::ANNException( - "Null pointer passed to copy_aligned_data_from_file()", -1, __FUNCSIG__, - __FILE__, __LINE__); - } - - size_t pts, dim; - get_bin_metadata(reader, pts, dim, offset); - - if (ndim != dim || npts != pts) { - std::stringstream ss; - ss << "Either file dimension: " << dim - << " is != passed dimension: " << ndim << " or file #pts: " << pts - << " is != passed #pts: " << npts << std::endl; - throw diskann::ANNException(ss.str(), -1, __FUNCSIG__, __FILE__, __LINE__); - } - - // Instead of reading one point of ndim size and setting (rounded_dim - dim) - // values to zero We'll set everything to zero and read in chunks of data at - // the appropriate locations. - size_t read_offset = offset + 2 * sizeof(uint32_t); - memset(data, 0, npts * rounded_dim * sizeof(T)); - int i = 0; - std::vector read_requests; - - while (i < npts) { - int j = 0; - read_requests.clear(); - while (j < MAX_SIMULTANEOUS_READ_REQUESTS && i < npts) { - AlignedRead read_req; - read_req.buf = data + i * rounded_dim; - read_req.len = dim * sizeof(T); - read_req.offset = read_offset + i * dim * sizeof(T); - read_requests.push_back(read_req); - i++; - j++; +void copy_aligned_data_from_file(AlignedFileReader &reader, T *&data, size_t &npts, size_t &ndim, + const size_t &rounded_dim, size_t offset) +{ + if (data == nullptr) + { + diskann::cerr << "Memory was not allocated for " << data << " before calling the load function. Exiting..." + << std::endl; + throw diskann::ANNException("Null pointer passed to copy_aligned_data_from_file()", -1, __FUNCSIG__, __FILE__, + __LINE__); } - IOContext &ctx = reader.get_ctx(); - reader.read(read_requests, ctx); - for (int k = 0; k < read_requests.size(); k++) { - if ((*ctx.m_pRequestsStatus)[k] != IOContext::READ_SUCCESS) { - throw diskann::ANNException( - "Load data from file using AlignedReader failed.", -1, __FUNCSIG__, - __FILE__, __LINE__); - } + + size_t pts, dim; + get_bin_metadata(reader, pts, dim, offset); + + if (ndim != dim || npts != pts) + { + std::stringstream ss; + ss << "Either file dimension: " << dim << " is != passed dimension: " << ndim << " or file #pts: " << pts + << " is != passed #pts: " << npts << std::endl; + throw diskann::ANNException(ss.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + + // Instead of reading one point of ndim size and setting (rounded_dim - dim) + // values to zero We'll set everything to zero and read in chunks of data at + // the appropriate locations. 
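// get_bin_metadata()/load_bin() above assume a flat binary layout: a uint32_t point count and a
// uint32_t dimension at the given offset, followed immediately by npts * ndim values of type T.
// A small round-trip of that layout with plain std::fstream (no AlignedFileReader involved),
// illustrative only; "toy.bin" is just a scratch filename:
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <vector>

int main()
{
    const uint32_t npts = 2, ndim = 3;
    std::vector<float> pts = {1, 2, 3, 4, 5, 6};

    std::ofstream out("toy.bin", std::ios::binary);
    out.write(reinterpret_cast<const char *>(&npts), sizeof(npts));
    out.write(reinterpret_cast<const char *>(&ndim), sizeof(ndim));
    out.write(reinterpret_cast<const char *>(pts.data()), pts.size() * sizeof(float));
    out.close();

    uint32_t r_npts = 0, r_ndim = 0;
    std::ifstream in("toy.bin", std::ios::binary);
    in.read(reinterpret_cast<char *>(&r_npts), sizeof(r_npts));
    in.read(reinterpret_cast<char *>(&r_ndim), sizeof(r_ndim));
    std::vector<float> back(r_npts * r_ndim);
    in.read(reinterpret_cast<char *>(back.data()), back.size() * sizeof(float));
    std::printf("read %u points of dim %u, first value %.1f\n", r_npts, r_ndim, back[0]);
    return 0;
}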
+ size_t read_offset = offset + 2 * sizeof(uint32_t); + memset(data, 0, npts * rounded_dim * sizeof(T)); + int i = 0; + std::vector read_requests; + + while (i < npts) + { + int j = 0; + read_requests.clear(); + while (j < MAX_SIMULTANEOUS_READ_REQUESTS && i < npts) + { + AlignedRead read_req; + read_req.buf = data + i * rounded_dim; + read_req.len = dim * sizeof(T); + read_req.offset = read_offset + i * dim * sizeof(T); + read_requests.push_back(read_req); + i++; + j++; + } + IOContext &ctx = reader.get_ctx(); + reader.read(read_requests, ctx); + for (int k = 0; k < read_requests.size(); k++) + { + if ((*ctx.m_pRequestsStatus)[k] != IOContext::READ_SUCCESS) + { + throw diskann::ANNException("Load data from file using AlignedReader failed.", -1, __FUNCSIG__, + __FILE__, __LINE__); + } + } } - } } // Unlike load_bin, assumes that data is already allocated 'size' entries -template -void read_array(AlignedFileReader &reader, T *data, size_t size, - size_t offset) { - if (data == nullptr) { - throw diskann::ANNException("read_array requires an allocated buffer.", -1); - } - - if (size * sizeof(T) > MAX_REQUEST_SIZE) { - std::stringstream ss; - ss << "Cannot read more than " << MAX_REQUEST_SIZE - << " bytes. Current request size: " << std::to_string(size) - << " sizeof(T): " << sizeof(T) << std::endl; - throw diskann::ANNException(ss.str(), -1, __FUNCSIG__, __FILE__, __LINE__); - } - std::vector read_requests; - AlignedRead read_req; - read_req.buf = data; - read_req.len = size * sizeof(T); - read_req.offset = offset; - read_requests.push_back(read_req); - IOContext &ctx = reader.get_ctx(); - reader.read(read_requests, ctx); - - if ((*(ctx.m_pRequestsStatus))[0] != IOContext::READ_SUCCESS) { - std::stringstream ss; - ss << "Failed to read_array() of size: " << size * sizeof(T) - << " at offset: " << offset << " from reader. " << std::endl; - throw diskann::ANNException(ss.str(), -1, __FUNCSIG__, __FILE__, __LINE__); - } +template void read_array(AlignedFileReader &reader, T *data, size_t size, size_t offset) +{ + if (data == nullptr) + { + throw diskann::ANNException("read_array requires an allocated buffer.", -1); + } + + if (size * sizeof(T) > MAX_REQUEST_SIZE) + { + std::stringstream ss; + ss << "Cannot read more than " << MAX_REQUEST_SIZE << " bytes. Current request size: " << std::to_string(size) + << " sizeof(T): " << sizeof(T) << std::endl; + throw diskann::ANNException(ss.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + std::vector read_requests; + AlignedRead read_req; + read_req.buf = data; + read_req.len = size * sizeof(T); + read_req.offset = offset; + read_requests.push_back(read_req); + IOContext &ctx = reader.get_ctx(); + reader.read(read_requests, ctx); + + if ((*(ctx.m_pRequestsStatus))[0] != IOContext::READ_SUCCESS) + { + std::stringstream ss; + ss << "Failed to read_array() of size: " << size * sizeof(T) << " at offset: " << offset << " from reader. 
" + << std::endl; + throw diskann::ANNException(ss.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } } -template -void read_value(AlignedFileReader &reader, T &value, size_t offset) { - read_array(reader, &value, 1, offset); +template void read_value(AlignedFileReader &reader, T &value, size_t offset) +{ + read_array(reader, &value, 1, offset); } -template DISKANN_DLLEXPORT void -load_bin(AlignedFileReader &reader, std::unique_ptr &data, - size_t &npts, size_t &ndim, size_t offset); -template DISKANN_DLLEXPORT void -load_bin(AlignedFileReader &reader, std::unique_ptr &data, - size_t &npts, size_t &ndim, size_t offset); -template DISKANN_DLLEXPORT void -load_bin(AlignedFileReader &reader, std::unique_ptr &data, - size_t &npts, size_t &ndim, size_t offset); -template DISKANN_DLLEXPORT void -load_bin(AlignedFileReader &reader, std::unique_ptr &data, - size_t &npts, size_t &ndim, size_t offset); -template DISKANN_DLLEXPORT void -load_bin(AlignedFileReader &reader, std::unique_ptr &data, - size_t &npts, size_t &ndim, size_t offset); -template DISKANN_DLLEXPORT void load_bin(AlignedFileReader &reader, - std::unique_ptr &data, - size_t &npts, size_t &ndim, - size_t offset); - -template DISKANN_DLLEXPORT void load_bin(AlignedFileReader &reader, - uint8_t *&data, size_t &npts, - size_t &ndim, size_t offset); -template DISKANN_DLLEXPORT void load_bin(AlignedFileReader &reader, - int64_t *&data, size_t &npts, - size_t &ndim, size_t offset); -template DISKANN_DLLEXPORT void load_bin(AlignedFileReader &reader, - uint64_t *&data, - size_t &npts, size_t &ndim, - size_t offset); -template DISKANN_DLLEXPORT void load_bin(AlignedFileReader &reader, - uint32_t *&data, - size_t &npts, size_t &ndim, - size_t offset); -template DISKANN_DLLEXPORT void load_bin(AlignedFileReader &reader, - int32_t *&data, size_t &npts, - size_t &ndim, size_t offset); - -template DISKANN_DLLEXPORT void -copy_aligned_data_from_file(AlignedFileReader &reader, uint8_t *&data, - size_t &npts, size_t &dim, - const size_t &rounded_dim, size_t offset); -template DISKANN_DLLEXPORT void -copy_aligned_data_from_file(AlignedFileReader &reader, int8_t *&data, - size_t &npts, size_t &dim, - const size_t &rounded_dim, size_t offset); -template DISKANN_DLLEXPORT void -copy_aligned_data_from_file(AlignedFileReader &reader, float *&data, - size_t &npts, size_t &dim, - const size_t &rounded_dim, size_t offset); - -template DISKANN_DLLEXPORT void read_array(AlignedFileReader &reader, - char *data, size_t size, - size_t offset); - -template DISKANN_DLLEXPORT void read_array(AlignedFileReader &reader, - uint8_t *data, size_t size, +template DISKANN_DLLEXPORT void load_bin(AlignedFileReader &reader, std::unique_ptr &data, + size_t &npts, size_t &ndim, size_t offset); +template DISKANN_DLLEXPORT void load_bin(AlignedFileReader &reader, std::unique_ptr &data, + size_t &npts, size_t &ndim, size_t offset); +template DISKANN_DLLEXPORT void load_bin(AlignedFileReader &reader, std::unique_ptr &data, + size_t &npts, size_t &ndim, size_t offset); +template DISKANN_DLLEXPORT void load_bin(AlignedFileReader &reader, std::unique_ptr &data, + size_t &npts, size_t &ndim, size_t offset); +template DISKANN_DLLEXPORT void load_bin(AlignedFileReader &reader, std::unique_ptr &data, + size_t &npts, size_t &ndim, size_t offset); +template DISKANN_DLLEXPORT void load_bin(AlignedFileReader &reader, std::unique_ptr &data, size_t &npts, + size_t &ndim, size_t offset); + +template DISKANN_DLLEXPORT void load_bin(AlignedFileReader &reader, uint8_t *&data, size_t &npts, size_t 
&ndim, + size_t offset); +template DISKANN_DLLEXPORT void load_bin(AlignedFileReader &reader, int64_t *&data, size_t &npts, size_t &ndim, + size_t offset); +template DISKANN_DLLEXPORT void load_bin(AlignedFileReader &reader, uint64_t *&data, size_t &npts, + size_t &ndim, size_t offset); +template DISKANN_DLLEXPORT void load_bin(AlignedFileReader &reader, uint32_t *&data, size_t &npts, + size_t &ndim, size_t offset); +template DISKANN_DLLEXPORT void load_bin(AlignedFileReader &reader, int32_t *&data, size_t &npts, size_t &ndim, + size_t offset); + +template DISKANN_DLLEXPORT void copy_aligned_data_from_file(AlignedFileReader &reader, uint8_t *&data, + size_t &npts, size_t &dim, + const size_t &rounded_dim, size_t offset); +template DISKANN_DLLEXPORT void copy_aligned_data_from_file(AlignedFileReader &reader, int8_t *&data, + size_t &npts, size_t &dim, + const size_t &rounded_dim, size_t offset); +template DISKANN_DLLEXPORT void copy_aligned_data_from_file(AlignedFileReader &reader, float *&data, + size_t &npts, size_t &dim, const size_t &rounded_dim, + size_t offset); + +template DISKANN_DLLEXPORT void read_array(AlignedFileReader &reader, char *data, size_t size, size_t offset); + +template DISKANN_DLLEXPORT void read_array(AlignedFileReader &reader, uint8_t *data, size_t size, size_t offset); -template DISKANN_DLLEXPORT void read_array(AlignedFileReader &reader, - int8_t *data, size_t size, - size_t offset); -template DISKANN_DLLEXPORT void read_array(AlignedFileReader &reader, - uint32_t *data, - size_t size, +template DISKANN_DLLEXPORT void read_array(AlignedFileReader &reader, int8_t *data, size_t size, size_t offset); +template DISKANN_DLLEXPORT void read_array(AlignedFileReader &reader, uint32_t *data, size_t size, size_t offset); -template DISKANN_DLLEXPORT void read_array(AlignedFileReader &reader, - float *data, size_t size, - size_t offset); +template DISKANN_DLLEXPORT void read_array(AlignedFileReader &reader, float *data, size_t size, size_t offset); -template DISKANN_DLLEXPORT void -read_value(AlignedFileReader &reader, uint8_t &value, size_t offset); -template DISKANN_DLLEXPORT void -read_value(AlignedFileReader &reader, int8_t &value, size_t offset); -template DISKANN_DLLEXPORT void read_value(AlignedFileReader &reader, - float &value, size_t offset); -template DISKANN_DLLEXPORT void -read_value(AlignedFileReader &reader, uint32_t &value, size_t offset); -template DISKANN_DLLEXPORT void -read_value(AlignedFileReader &reader, uint64_t &value, size_t offset); +template DISKANN_DLLEXPORT void read_value(AlignedFileReader &reader, uint8_t &value, size_t offset); +template DISKANN_DLLEXPORT void read_value(AlignedFileReader &reader, int8_t &value, size_t offset); +template DISKANN_DLLEXPORT void read_value(AlignedFileReader &reader, float &value, size_t offset); +template DISKANN_DLLEXPORT void read_value(AlignedFileReader &reader, uint32_t &value, size_t offset); +template DISKANN_DLLEXPORT void read_value(AlignedFileReader &reader, uint64_t &value, size_t offset); #endif diff --git a/src/windows_aligned_file_reader.cpp b/src/windows_aligned_file_reader.cpp index e62a5fed7..4ddd50902 100644 --- a/src/windows_aligned_file_reader.cpp +++ b/src/windows_aligned_file_reader.cpp @@ -10,166 +10,180 @@ #define SECTOR_LEN 4096 -void WindowsAlignedFileReader::open(const std::string &fname) { +void WindowsAlignedFileReader::open(const std::string &fname) +{ #ifdef UNICODE - m_filename = std::wstring(fname.begin(), fname.end()); + m_filename = std::wstring(fname.begin(), fname.end()); #else 
- m_filename = fname; + m_filename = fname; #endif - this->register_thread(); + this->register_thread(); } -void WindowsAlignedFileReader::close() { - for (auto &k_v : ctx_map) { - IOContext ctx = ctx_map[k_v.first]; - CloseHandle(ctx.fhandle); - } +void WindowsAlignedFileReader::close() +{ + for (auto &k_v : ctx_map) + { + IOContext ctx = ctx_map[k_v.first]; + CloseHandle(ctx.fhandle); + } } -void WindowsAlignedFileReader::register_thread() { - std::unique_lock lk(this->ctx_mut); - if (this->ctx_map.find(std::this_thread::get_id()) != ctx_map.end()) { - diskann::cout << "Warning:: Duplicate registration for thread_id : " - << std::this_thread::get_id() << std::endl; - } - - IOContext ctx; - ctx.fhandle = CreateFile(m_filename.c_str(), GENERIC_READ, FILE_SHARE_READ, - NULL, OPEN_EXISTING, - FILE_ATTRIBUTE_READONLY | FILE_FLAG_NO_BUFFERING | - FILE_FLAG_OVERLAPPED | FILE_FLAG_RANDOM_ACCESS, - NULL); - if (ctx.fhandle == INVALID_HANDLE_VALUE) { - const size_t c_max_filepath_len = 256; - size_t actual_len = 0; - char filePath[c_max_filepath_len]; - if (wcstombs_s(&actual_len, filePath, c_max_filepath_len, - m_filename.c_str(), m_filename.length()) == 0) { - diskann::cout << "Error opening " << filePath - << " -- error=" << GetLastError() << std::endl; - } else { - diskann::cout << "Error converting wchar to char -- error=" - << GetLastError() << std::endl; +void WindowsAlignedFileReader::register_thread() +{ + std::unique_lock lk(this->ctx_mut); + if (this->ctx_map.find(std::this_thread::get_id()) != ctx_map.end()) + { + diskann::cout << "Warning:: Duplicate registration for thread_id : " << std::this_thread::get_id() << std::endl; + } + + IOContext ctx; + ctx.fhandle = CreateFile( + m_filename.c_str(), GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING, + FILE_ATTRIBUTE_READONLY | FILE_FLAG_NO_BUFFERING | FILE_FLAG_OVERLAPPED | FILE_FLAG_RANDOM_ACCESS, NULL); + if (ctx.fhandle == INVALID_HANDLE_VALUE) + { + const size_t c_max_filepath_len = 256; + size_t actual_len = 0; + char filePath[c_max_filepath_len]; + if (wcstombs_s(&actual_len, filePath, c_max_filepath_len, m_filename.c_str(), m_filename.length()) == 0) + { + diskann::cout << "Error opening " << filePath << " -- error=" << GetLastError() << std::endl; + } + else + { + diskann::cout << "Error converting wchar to char -- error=" << GetLastError() << std::endl; + } } - } - - // create IOCompletionPort - ctx.iocp = CreateIoCompletionPort(ctx.fhandle, ctx.iocp, 0, 0); - - // create MAX_DEPTH # of reqs - for (uint64_t i = 0; i < MAX_IO_DEPTH; i++) { - OVERLAPPED os; - memset(&os, 0, sizeof(OVERLAPPED)); - // os.hEvent = CreateEventA(NULL, TRUE, FALSE, NULL); - ctx.reqs.push_back(os); - } - this->ctx_map.insert(std::make_pair(std::this_thread::get_id(), ctx)); + + // create IOCompletionPort + ctx.iocp = CreateIoCompletionPort(ctx.fhandle, ctx.iocp, 0, 0); + + // create MAX_DEPTH # of reqs + for (uint64_t i = 0; i < MAX_IO_DEPTH; i++) + { + OVERLAPPED os; + memset(&os, 0, sizeof(OVERLAPPED)); + // os.hEvent = CreateEventA(NULL, TRUE, FALSE, NULL); + ctx.reqs.push_back(os); + } + this->ctx_map.insert(std::make_pair(std::this_thread::get_id(), ctx)); } -IOContext &WindowsAlignedFileReader::get_ctx() { - std::unique_lock lk(this->ctx_mut); - if (ctx_map.find(std::this_thread::get_id()) == ctx_map.end()) { - std::stringstream stream; - stream << "unable to find IOContext for thread_id : " - << std::this_thread::get_id() << "\n"; - throw diskann::ANNException(stream.str(), -2, __FUNCSIG__, __FILE__, - __LINE__); - } - IOContext &ctx = 
ctx_map[std::this_thread::get_id()]; - lk.unlock(); - return ctx; +IOContext &WindowsAlignedFileReader::get_ctx() +{ + std::unique_lock lk(this->ctx_mut); + if (ctx_map.find(std::this_thread::get_id()) == ctx_map.end()) + { + std::stringstream stream; + stream << "unable to find IOContext for thread_id : " << std::this_thread::get_id() << "\n"; + throw diskann::ANNException(stream.str(), -2, __FUNCSIG__, __FILE__, __LINE__); + } + IOContext &ctx = ctx_map[std::this_thread::get_id()]; + lk.unlock(); + return ctx; } -void WindowsAlignedFileReader::read(std::vector &read_reqs, - IOContext &ctx, bool async) { - using namespace std::chrono_literals; - // execute each request sequentially - size_t n_reqs = read_reqs.size(); - uint64_t n_batches = ROUND_UP(n_reqs, MAX_IO_DEPTH) / MAX_IO_DEPTH; - for (uint64_t i = 0; i < n_batches; i++) { - // reset all OVERLAPPED objects - for (auto &os : ctx.reqs) { - // HANDLE evt = os.hEvent; - memset(&os, 0, sizeof(os)); - // os.hEvent = evt; - - /* - if (ResetEvent(os.hEvent) == 0) { - diskann::cerr << "ResetEvent failed" << std::endl; - exit(-3); +void WindowsAlignedFileReader::read(std::vector &read_reqs, IOContext &ctx, bool async) +{ + using namespace std::chrono_literals; + // execute each request sequentially + size_t n_reqs = read_reqs.size(); + uint64_t n_batches = ROUND_UP(n_reqs, MAX_IO_DEPTH) / MAX_IO_DEPTH; + for (uint64_t i = 0; i < n_batches; i++) + { + // reset all OVERLAPPED objects + for (auto &os : ctx.reqs) + { + // HANDLE evt = os.hEvent; + memset(&os, 0, sizeof(os)); + // os.hEvent = evt; + + /* + if (ResetEvent(os.hEvent) == 0) { + diskann::cerr << "ResetEvent failed" << std::endl; + exit(-3); + } + */ } - */ - } - // batch start/end - uint64_t batch_start = MAX_IO_DEPTH * i; - uint64_t batch_size = - std::min((uint64_t)(n_reqs - batch_start), (uint64_t)MAX_IO_DEPTH); - - // fill OVERLAPPED and issue them - for (uint64_t j = 0; j < batch_size; j++) { - AlignedRead &req = read_reqs[batch_start + j]; - OVERLAPPED &os = ctx.reqs[j]; - - uint64_t offset = req.offset; - uint64_t nbytes = req.len; - char *read_buf = (char *)req.buf; - assert(IS_ALIGNED(read_buf, SECTOR_LEN)); - assert(IS_ALIGNED(offset, SECTOR_LEN)); - assert(IS_ALIGNED(nbytes, SECTOR_LEN)); - - // fill in OVERLAPPED struct - os.Offset = offset & 0xffffffff; - os.OffsetHigh = (offset >> 32); - - BOOL ret = ReadFile(ctx.fhandle, read_buf, (DWORD)nbytes, NULL, &os); - if (ret == FALSE) { - auto error = GetLastError(); - if (error != ERROR_IO_PENDING) { - diskann::cerr << "Error queuing IO -- " << error << "\n"; + // batch start/end + uint64_t batch_start = MAX_IO_DEPTH * i; + uint64_t batch_size = std::min((uint64_t)(n_reqs - batch_start), (uint64_t)MAX_IO_DEPTH); + + // fill OVERLAPPED and issue them + for (uint64_t j = 0; j < batch_size; j++) + { + AlignedRead &req = read_reqs[batch_start + j]; + OVERLAPPED &os = ctx.reqs[j]; + + uint64_t offset = req.offset; + uint64_t nbytes = req.len; + char *read_buf = (char *)req.buf; + assert(IS_ALIGNED(read_buf, SECTOR_LEN)); + assert(IS_ALIGNED(offset, SECTOR_LEN)); + assert(IS_ALIGNED(nbytes, SECTOR_LEN)); + + // fill in OVERLAPPED struct + os.Offset = offset & 0xffffffff; + os.OffsetHigh = (offset >> 32); + + BOOL ret = ReadFile(ctx.fhandle, read_buf, (DWORD)nbytes, NULL, &os); + if (ret == FALSE) + { + auto error = GetLastError(); + if (error != ERROR_IO_PENDING) + { + diskann::cerr << "Error queuing IO -- " << error << "\n"; + } + } + else + { + diskann::cerr << "Error queueing IO -- ReadFile returned TRUE" << std::endl; + } } - } 
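// register_thread()/get_ctx() above keep one IOContext per std::thread::id in a mutex-guarded
// map, warning on duplicate registration and throwing if a thread asks for a context it never
// registered. The same bookkeeping pattern with standard C++ only (no Win32 handles); ToyContext
// and PerThreadRegistry are illustrative stand-ins, not the real IOContext:
#include <iostream>
#include <map>
#include <mutex>
#include <stdexcept>
#include <thread>

struct ToyContext
{
    int fd = -1; // placeholder for per-thread file/IO state
};

class PerThreadRegistry
{
  public:
    void register_thread()
    {
        std::unique_lock<std::mutex> lk(_mut);
        if (_map.find(std::this_thread::get_id()) != _map.end())
            std::cerr << "Warning: duplicate registration for this thread" << std::endl;
        _map[std::this_thread::get_id()] = ToyContext{42};
    }
    ToyContext &get_ctx()
    {
        std::unique_lock<std::mutex> lk(_mut);
        auto it = _map.find(std::this_thread::get_id());
        if (it == _map.end())
            throw std::runtime_error("no context registered for this thread");
        return it->second; // std::map references stay valid after later insertions
    }

  private:
    std::mutex _mut;
    std::map<std::thread::id, ToyContext> _map;
};

int main()
{
    PerThreadRegistry reg;
    std::thread t([&] {
        reg.register_thread();
        std::cout << "worker ctx fd = " << reg.get_ctx().fd << std::endl;
    });
    t.join();
    return 0;
}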
else { - diskann::cerr << "Error queueing IO -- ReadFile returned TRUE" - << std::endl; - } - } - DWORD n_read = 0; - uint64_t n_complete = 0; - ULONG_PTR completion_key = 0; - OVERLAPPED *lp_os; - while (n_complete < batch_size) { - if (GetQueuedCompletionStatus(ctx.iocp, &n_read, &completion_key, &lp_os, - INFINITE) != 0) { - // successfully dequeued a completed I/O - n_complete++; - } else { - // failed to dequeue OR dequeued failed I/O - if (lp_os == NULL) { - DWORD error = GetLastError(); - if (error != WAIT_TIMEOUT) { - diskann::cerr << "GetQueuedCompletionStatus() failed " - "with error = " - << error << std::endl; - throw diskann::ANNException( - "GetQueuedCompletionStatus failed with error: ", error, - __FUNCSIG__, __FILE__, __LINE__); - } - // no completion packet dequeued ==> sleep for 5us and try - // again - std::this_thread::sleep_for(5us); - } else { - // completion packet for failed IO dequeued - auto op_idx = lp_os - ctx.reqs.data(); - std::stringstream stream; - stream << "I/O failed , offset: " << read_reqs[op_idx].offset - << "with error code: " << GetLastError() << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, - __LINE__); + DWORD n_read = 0; + uint64_t n_complete = 0; + ULONG_PTR completion_key = 0; + OVERLAPPED *lp_os; + while (n_complete < batch_size) + { + if (GetQueuedCompletionStatus(ctx.iocp, &n_read, &completion_key, &lp_os, INFINITE) != 0) + { + // successfully dequeued a completed I/O + n_complete++; + } + else + { + // failed to dequeue OR dequeued failed I/O + if (lp_os == NULL) + { + DWORD error = GetLastError(); + if (error != WAIT_TIMEOUT) + { + diskann::cerr << "GetQueuedCompletionStatus() failed " + "with error = " + << error << std::endl; + throw diskann::ANNException("GetQueuedCompletionStatus failed with error: ", error, __FUNCSIG__, + __FILE__, __LINE__); + } + // no completion packet dequeued ==> sleep for 5us and try + // again + std::this_thread::sleep_for(5us); + } + else + { + // completion packet for failed IO dequeued + auto op_idx = lp_os - ctx.reqs.data(); + std::stringstream stream; + stream << "I/O failed , offset: " << read_reqs[op_idx].offset + << "with error code: " << GetLastError() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + } } - } } - } } #endif #endif From 505572f4f8a405b5c209c952cedb50368002c897 Mon Sep 17 00:00:00 2001 From: rakri Date: Tue, 12 Nov 2024 23:00:59 -0800 Subject: [PATCH 6/7] added new .clang-format file --- .clang-format | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .clang-format diff --git a/.clang-format b/.clang-format new file mode 100644 index 000000000..ad3192fd6 --- /dev/null +++ b/.clang-format @@ -0,0 +1,6 @@ +--- +BasedOnStyle: Microsoft +--- +Language: Cpp +SortIncludes: false +... 
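The .clang-format added above pins the tree to clang-format's Microsoft base style for C++, with include sorting disabled so the existing include order is preserved; clang-format picks the file up automatically from the repository root. As the hunks in this series show, that style means Allman braces and four-space indentation. A small snippet in roughly that shape, purely for illustration:

#include <vector>

namespace example
{

// Allman braces, 4-space indentation, wrapped argument lists: the same shape the
// reformatted DiskANN sources above follow.
int sum_of_squares(const std::vector<int> &values)
{
    int total = 0;
    for (int v : values)
    {
        total += v * v;
    }
    return total;
}

} // namespace example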
From 7ac2891545b9aeae7ce4b5e6a1b055846c6406ba Mon Sep 17 00:00:00 2001 From: rakri Date: Tue, 12 Nov 2024 23:17:15 -0800 Subject: [PATCH 7/7] clang-formatted restapi files --- include/restapi/common.h | 12 +- include/restapi/search_wrapper.h | 203 +++++++++++++++++-------------- include/restapi/server.h | 74 ++++++----- 3 files changed, 154 insertions(+), 135 deletions(-) diff --git a/include/restapi/common.h b/include/restapi/common.h index ec321ec9a..b8339635a 100644 --- a/include/restapi/common.h +++ b/include/restapi/common.h @@ -6,14 +6,12 @@ #include #include -namespace diskann { +namespace diskann +{ // Constants -static const std::string VECTOR_KEY = "query", K_KEY = "k", - INDICES_KEY = "indices", DISTANCES_KEY = "distances", - TAGS_KEY = "tags", QUERY_ID_KEY = "query_id", - ERROR_MESSAGE_KEY = "error", L_KEY = "Ls", - TIME_TAKEN_KEY = "time_taken_in_us", - PARTITION_KEY = "partition", +static const std::string VECTOR_KEY = "query", K_KEY = "k", INDICES_KEY = "indices", DISTANCES_KEY = "distances", + TAGS_KEY = "tags", QUERY_ID_KEY = "query_id", ERROR_MESSAGE_KEY = "error", L_KEY = "Ls", + TIME_TAKEN_KEY = "time_taken_in_us", PARTITION_KEY = "partition", UNKNOWN_ERROR = "unknown_error"; const unsigned int DEFAULT_L = 100; diff --git a/include/restapi/search_wrapper.h b/include/restapi/search_wrapper.h index e7ed1725e..d41b2b7cd 100644 --- a/include/restapi/search_wrapper.h +++ b/include/restapi/search_wrapper.h @@ -10,106 +10,131 @@ #include #include -namespace diskann { -class SearchResult { -public: - SearchResult(unsigned int K, unsigned int elapsed_time_in_ms, - const unsigned *const indices, const float *const distances, - const std::string *const tags = nullptr, - const unsigned *const partitions = nullptr); - - const std::vector &get_indices() const { return _indices; } - const std::vector &get_distances() const { return _distances; } - bool tags_enabled() const { return _tags_enabled; } - const std::vector &get_tags() const { return _tags; } - bool partitions_enabled() const { return _partitions_enabled; } - const std::vector &get_partitions() const { return _partitions; } - unsigned get_time() const { return _search_time_in_ms; } - -private: - unsigned int _K; - unsigned int _search_time_in_ms; - std::vector _indices; - std::vector _distances; - - bool _tags_enabled; - std::vector _tags; - - bool _partitions_enabled; - std::vector _partitions; +namespace diskann +{ +class SearchResult +{ + public: + SearchResult(unsigned int K, unsigned int elapsed_time_in_ms, const unsigned *const indices, + const float *const distances, const std::string *const tags = nullptr, + const unsigned *const partitions = nullptr); + + const std::vector &get_indices() const + { + return _indices; + } + const std::vector &get_distances() const + { + return _distances; + } + bool tags_enabled() const + { + return _tags_enabled; + } + const std::vector &get_tags() const + { + return _tags; + } + bool partitions_enabled() const + { + return _partitions_enabled; + } + const std::vector &get_partitions() const + { + return _partitions; + } + unsigned get_time() const + { + return _search_time_in_ms; + } + + private: + unsigned int _K; + unsigned int _search_time_in_ms; + std::vector _indices; + std::vector _distances; + + bool _tags_enabled; + std::vector _tags; + + bool _partitions_enabled; + std::vector _partitions; }; -class SearchNotImplementedException : public std::logic_error { -private: - std::string _errormsg; - -public: - SearchNotImplementedException(const char *type) - : 
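// The string constants in restapi/common.h above are the JSON field names the REST layer works
// with ("query", "k", "Ls", "query_id", "indices", "distances", ...). Purely as an illustration
// of those keys, and only a guess at how they combine, since the server's JSON parsing code is
// not shown in this patch, a toy request body could be assembled like this:
#include <cstdio>
#include <sstream>
#include <string>
#include <vector>

int main()
{
    std::vector<float> query = {0.1f, 0.2f, 0.3f};
    unsigned k = 10, Ls = 100;
    long long query_id = 7;

    std::ostringstream body;
    body << "{\"query\": [";
    for (size_t i = 0; i < query.size(); i++)
        body << (i ? ", " : "") << query[i];
    body << "], \"k\": " << k << ", \"Ls\": " << Ls << ", \"query_id\": " << query_id << "}";

    std::printf("%s\n", body.str().c_str());
    return 0;
}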
std::logic_error("Not Implemented") { - _errormsg = "Search with data type "; - _errormsg += std::string(type); - _errormsg += " not implemented : "; - _errormsg += __FUNCTION__; - } - - virtual const char *what() const throw() { return _errormsg.c_str(); } +class SearchNotImplementedException : public std::logic_error +{ + private: + std::string _errormsg; + + public: + SearchNotImplementedException(const char *type) : std::logic_error("Not Implemented") + { + _errormsg = "Search with data type "; + _errormsg += std::string(type); + _errormsg += " not implemented : "; + _errormsg += __FUNCTION__; + } + + virtual const char *what() const throw() + { + return _errormsg.c_str(); + } }; -class BaseSearch { -public: - BaseSearch(const std::string &tagsFile = nullptr); - virtual SearchResult search(const float *query, const unsigned int dimensions, - const unsigned int K, const unsigned int Ls) { - throw SearchNotImplementedException("float"); - } - virtual SearchResult search(const int8_t *query, - const unsigned int dimensions, - const unsigned int K, const unsigned int Ls) { - throw SearchNotImplementedException("int8_t"); - } - - virtual SearchResult search(const uint8_t *query, - const unsigned int dimensions, - const unsigned int K, const unsigned int Ls) { - throw SearchNotImplementedException("uint8_t"); - } - - void lookup_tags(const unsigned K, const unsigned *indices, - std::string *ret_tags); - -protected: - bool _tags_enabled; - std::vector _tags_str; +class BaseSearch +{ + public: + BaseSearch(const std::string &tagsFile = nullptr); + virtual SearchResult search(const float *query, const unsigned int dimensions, const unsigned int K, + const unsigned int Ls) + { + throw SearchNotImplementedException("float"); + } + virtual SearchResult search(const int8_t *query, const unsigned int dimensions, const unsigned int K, + const unsigned int Ls) + { + throw SearchNotImplementedException("int8_t"); + } + + virtual SearchResult search(const uint8_t *query, const unsigned int dimensions, const unsigned int K, + const unsigned int Ls) + { + throw SearchNotImplementedException("uint8_t"); + } + + void lookup_tags(const unsigned K, const unsigned *indices, std::string *ret_tags); + + protected: + bool _tags_enabled; + std::vector _tags_str; }; -template class InMemorySearch : public BaseSearch { -public: - InMemorySearch(const std::string &baseFile, const std::string &indexFile, - const std::string &tagsFile, Metric m, uint32_t num_threads, - uint32_t search_l); - virtual ~InMemorySearch(); +template class InMemorySearch : public BaseSearch +{ + public: + InMemorySearch(const std::string &baseFile, const std::string &indexFile, const std::string &tagsFile, Metric m, + uint32_t num_threads, uint32_t search_l); + virtual ~InMemorySearch(); - SearchResult search(const T *query, const unsigned int dimensions, - const unsigned int K, const unsigned int Ls); + SearchResult search(const T *query, const unsigned int dimensions, const unsigned int K, const unsigned int Ls); -private: - unsigned int _dimensions, _numPoints; - std::unique_ptr> _index; + private: + unsigned int _dimensions, _numPoints; + std::unique_ptr> _index; }; -template class PQFlashSearch : public BaseSearch { -public: - PQFlashSearch(const std::string &indexPrefix, - const unsigned num_nodes_to_cache, const unsigned num_threads, - const std::string &tagsFile, Metric m); - virtual ~PQFlashSearch(); +template class PQFlashSearch : public BaseSearch +{ + public: + PQFlashSearch(const std::string &indexPrefix, const unsigned 
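// BaseSearch above exposes one virtual search() overload per element type and throws
// SearchNotImplementedException from the ones a concrete searcher does not support; a typed
// subclass such as InMemorySearch<T> overrides only its own overload. A stripped-down sketch of
// that pattern (ToyBase/FloatSearcher are illustrative names, and the "search" here just counts
// non-zero coordinates):
#include <cstdint>
#include <cstdio>
#include <stdexcept>

class ToyBase
{
  public:
    virtual ~ToyBase() = default;
    virtual int search(const float *, unsigned)
    {
        throw std::logic_error("float search not implemented");
    }
    virtual int search(const uint8_t *, unsigned)
    {
        throw std::logic_error("uint8_t search not implemented");
    }
};

class FloatSearcher : public ToyBase
{
  public:
    using ToyBase::search; // keep the uint8_t overload visible (it will still throw)
    int search(const float *query, unsigned dims) override
    {
        int nonzero = 0;
        for (unsigned d = 0; d < dims; d++)
            nonzero += (query[d] != 0.0f);
        return nonzero;
    }
};

int main()
{
    FloatSearcher s;
    float q[3] = {0.0f, 1.5f, 2.5f};
    std::printf("handled floats: %d non-zero\n", s.search(q, 3));
    return 0;
}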
num_nodes_to_cache, const unsigned num_threads, + const std::string &tagsFile, Metric m); + virtual ~PQFlashSearch(); - SearchResult search(const T *query, const unsigned int dimensions, - const unsigned int K, const unsigned int Ls); + SearchResult search(const T *query, const unsigned int dimensions, const unsigned int K, const unsigned int Ls); -private: - unsigned int _dimensions, _numPoints; - std::unique_ptr> _index; - std::shared_ptr reader; + private: + unsigned int _dimensions, _numPoints; + std::unique_ptr> _index; + std::shared_ptr reader; }; } // namespace diskann diff --git a/include/restapi/server.h b/include/restapi/server.h index 9cb9449da..ddb19d17a 100644 --- a/include/restapi/server.h +++ b/include/restapi/server.h @@ -6,44 +6,40 @@ #include #include -namespace diskann { -class Server { -public: - Server(web::uri &url, - std::vector> &multi_searcher, - const std::string &typestring); - virtual ~Server(); - - pplx::task open(); - pplx::task close(); - -protected: - template void handle_post(web::http::http_request message); - - template - web::json::value - toJsonArray(const std::vector &v, - std::function valConverter); - web::json::value prepareResponse(const int64_t &queryId, const int k); - - template - void parseJson(const utility::string_t &body, unsigned int &k, - int64_t &queryId, T *&queryVector, unsigned int &dimensions, - unsigned &Ls); - - web::json::value idsToJsonArray(const diskann::SearchResult &result); - web::json::value distancesToJsonArray(const diskann::SearchResult &result); - web::json::value tagsToJsonArray(const diskann::SearchResult &result); - web::json::value partitionsToJsonArray(const diskann::SearchResult &result); - - SearchResult - aggregate_results(const unsigned K, - const std::vector &results); - -private: - bool _isDebug; - std::unique_ptr _listener; - const bool _multi_search; - std::vector> _multi_searcher; +namespace diskann +{ +class Server +{ + public: + Server(web::uri &url, std::vector> &multi_searcher, + const std::string &typestring); + virtual ~Server(); + + pplx::task open(); + pplx::task close(); + + protected: + template void handle_post(web::http::http_request message); + + template + web::json::value toJsonArray(const std::vector &v, std::function valConverter); + web::json::value prepareResponse(const int64_t &queryId, const int k); + + template + void parseJson(const utility::string_t &body, unsigned int &k, int64_t &queryId, T *&queryVector, + unsigned int &dimensions, unsigned &Ls); + + web::json::value idsToJsonArray(const diskann::SearchResult &result); + web::json::value distancesToJsonArray(const diskann::SearchResult &result); + web::json::value tagsToJsonArray(const diskann::SearchResult &result); + web::json::value partitionsToJsonArray(const diskann::SearchResult &result); + + SearchResult aggregate_results(const unsigned K, const std::vector &results); + + private: + bool _isDebug; + std::unique_ptr _listener; + const bool _multi_search; + std::vector> _multi_searcher; }; } // namespace diskann
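Server::aggregate_results() above merges the per-searcher SearchResult objects into a single K-sized answer; its implementation is not part of this patch, so the following is only a plausible sketch of such a merge (collect all candidate distance/id pairs, sort by distance, keep the closest K), written against plain std::vector rather than the real SearchResult type:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// one (distance, id) list per backend searcher; smaller distance means closer
using Candidates = std::vector<std::pair<float, uint32_t>>;

Candidates aggregate_top_k(unsigned K, const std::vector<Candidates> &per_searcher)
{
    Candidates all;
    for (const auto &c : per_searcher)
        all.insert(all.end(), c.begin(), c.end());
    std::sort(all.begin(), all.end()); // pairs sort by distance first
    if (all.size() > K)
        all.resize(K); // duplicates across searchers are not deduplicated in this sketch
    return all;
}

int main()
{
    std::vector<Candidates> results = {{{0.10f, 4}, {0.40f, 9}}, {{0.05f, 17}, {0.60f, 2}}};
    for (const auto &[dist, id] : aggregate_top_k(3, results))
        std::printf("id %u at distance %.2f\n", id, dist);
    return 0;
}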