Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

merging multifilter for bann #543

Merged
merged 1 commit into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions include/abstract_scratch.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once
namespace diskann
{

// Forward declaration only: this header stores a pointer to PQScratch and
// never needs the complete type.
template <typename data_t> class PQScratch;

// Common base for the query scratch classes. Both InMemQueryScratch and
// SSDQueryScratch carry an aligned query buffer and a PQScratch object, so
// those two pointers live here while PQScratch stays a standalone class.
//
// Ownership: this base class does not take any responsibility for the memory
// behind its members. Derived classes must manage it themselves.
template <typename data_t> class AbstractScratch
{
  protected:
    // Both pointers start out null; nothing in this class assigns them.
    data_t *_aligned_query_T{nullptr};
    PQScratch<data_t> *_pq_scratch{nullptr};

  public:
    AbstractScratch() = default;

    // Virtual so that derived scratch objects can be destroyed through a
    // pointer to this base.
    virtual ~AbstractScratch() = default;

    // Scratch objects should not be copied.
    AbstractScratch(const AbstractScratch &) = delete;
    AbstractScratch &operator=(const AbstractScratch &) = delete;

    // Returns the stored query pointer (null until a derived class sets it).
    data_t *aligned_query_T() { return _aligned_query_T; }

    // Returns the stored PQScratch pointer (null until a derived class sets it).
    PQScratch<data_t> *pq_scratch() { return _pq_scratch; }
};
} // namespace diskann
2 changes: 2 additions & 0 deletions include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include "in_mem_data_store.h"
#include "in_mem_graph_store.h"
#include "abstract_index.h"
#include "pq_scratch.h"
#include "pq.h"

#define OVERHEAD_FACTOR 1.1
#ifdef EXEC_ENV_OLS
Expand Down
5 changes: 5 additions & 0 deletions include/neighbor.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ class NeighborPriorityQueue
return _cur < _size;
}

// Sorts the first _size entries of _data in place with std::sort, using the
// element type's operator<. Entries at index _size and beyond are not touched.
void sort()
{
    const auto first = _data.begin();
    std::sort(first, first + _size);
}

size_t size() const
{
return _size;
Expand Down
50 changes: 9 additions & 41 deletions include/pq.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,7 @@
#pragma once

#include "utils.h"

#define NUM_PQ_BITS 8
#define NUM_PQ_CENTROIDS (1 << NUM_PQ_BITS)
#define MAX_OPQ_ITERS 20
#define NUM_KMEANS_REPS_PQ 12
#define MAX_PQ_TRAINING_SET_SIZE 256000
#define MAX_PQ_CHUNKS 512
#include "pq_common.h"

namespace diskann
{
Expand Down Expand Up @@ -53,40 +47,6 @@ class FixedChunkPQTable
void populate_chunk_inner_products(const float *query_vec, float *dist_vec);
};

// Bundle of aligned scratch buffers used while working with PQ data for one
// query. The constructor allocates every buffer; nothing in this struct
// releases them.
template <typename T> struct PQScratch
{
    float *aligned_pqtable_dist_scratch = nullptr; // MUST BE AT LEAST [256 * NCHUNKS]
    float *aligned_dist_scratch = nullptr;         // MUST BE AT LEAST diskann MAX_DEGREE
    uint8_t *aligned_pq_coord_scratch = nullptr;   // MUST BE AT LEAST [N_CHUNKS * MAX_DEGREE]
    float *rotated_query = nullptr;
    float *aligned_query_float = nullptr;

    // Allocates all five buffers via diskann::alloc_aligned and zero-fills the
    // two query buffers.
    //   graph_degree - sizes the per-neighbor distance and PQ coordinate buffers
    //   aligned_dim  - number of floats in each of the two query buffers
    PQScratch(size_t graph_degree, size_t aligned_dim)
    {
        const size_t coord_bytes = (size_t)graph_degree * (size_t)MAX_PQ_CHUNKS * sizeof(uint8_t);
        const size_t table_bytes = 256 * (size_t)MAX_PQ_CHUNKS * sizeof(float);
        const size_t dist_bytes = (size_t)graph_degree * sizeof(float);
        const size_t query_bytes = aligned_dim * sizeof(float);
        const size_t query_alignment = 8 * sizeof(float);

        diskann::alloc_aligned((void **)&aligned_pq_coord_scratch, coord_bytes, 256);
        diskann::alloc_aligned((void **)&aligned_pqtable_dist_scratch, table_bytes, 256);
        diskann::alloc_aligned((void **)&aligned_dist_scratch, dist_bytes, 256);
        diskann::alloc_aligned((void **)&aligned_query_float, query_bytes, query_alignment);
        diskann::alloc_aligned((void **)&rotated_query, query_bytes, query_alignment);

        memset(aligned_query_float, 0, query_bytes);
        memset(rotated_query, 0, query_bytes);
    }

    // Copies the first `dim` components of `query`, converted to float, into
    // both aligned_query_float and rotated_query. Each component is divided by
    // `norm` unless norm is exactly 1.0f.
    void set(size_t dim, T *query, const float norm = 1.0f)
    {
        const bool rescale = (norm != 1.0f);
        for (size_t d = 0; d < dim; ++d)
        {
            const float component = static_cast<float>(query[d]);
            const float value = rescale ? component / norm : component;
            aligned_query_float[d] = value;
            rotated_query[d] = value;
        }
    }
};

void aggregate_coords(const std::vector<unsigned> &ids, const uint8_t *all_coords, const uint64_t ndims, uint8_t *out);

void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, const size_t pq_nchunks, const float *pq_dists,
Expand All @@ -107,11 +67,19 @@ DISKANN_DLLEXPORT int generate_opq_pivots(const float *train_data, size_t num_tr
unsigned num_pq_chunks, std::string opq_pivots_path,
bool make_zero_mean = false);

DISKANN_DLLEXPORT int generate_pq_pivots_simplified(const float *train_data, size_t num_train, size_t dim,
size_t num_pq_chunks, std::vector<float> &pivot_data_vector);

template <typename T>
int generate_pq_data_from_pivots(const std::string &data_file, unsigned num_centers, unsigned num_pq_chunks,
const std::string &pq_pivots_path, const std::string &pq_compressed_vectors_path,
bool use_opq = false);

DISKANN_DLLEXPORT int generate_pq_data_from_pivots_simplified(const float *data, const size_t num,
const float *pivot_data, const size_t pivots_num,
const size_t dim, const size_t num_pq_chunks,
std::vector<uint8_t> &pq);

template <typename T>
void generate_disk_quantized_data(const std::string &data_file_to_use, const std::string &disk_pq_pivots_path,
const std::string &disk_pq_compressed_vectors_path,
Expand Down
30 changes: 30 additions & 0 deletions include/pq_common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once

#include <string>
#include <sstream>

#define NUM_PQ_BITS 8
#define NUM_PQ_CENTROIDS (1 << NUM_PQ_BITS)
#define MAX_OPQ_ITERS 20
#define NUM_KMEANS_REPS_PQ 12
#define MAX_PQ_TRAINING_SET_SIZE 256000
#define MAX_PQ_CHUNKS 512

namespace diskann
{
// Builds the compressed-vectors file name:
//   <prefix><tag><num_chunks>_compressed.bin
// where <tag> is "_opq" when use_opq is true and "pq" otherwise.
// NOTE(review): the non-OPQ tag has no leading underscore, unlike "_opq". It
// is kept as-is because existing file names depend on it; confirm whether the
// asymmetry is intentional.
inline std::string get_quantized_vectors_filename(const std::string &prefix, bool use_opq, uint32_t num_chunks)
{
    std::string filename(prefix);
    filename += (use_opq ? "_opq" : "pq");
    filename += std::to_string(num_chunks);
    filename += "_compressed.bin";
    return filename;
}

// Builds the pivot-data file name:
//   <prefix><tag><num_chunks>_pivots.bin
// where <tag> is "_opq" when use_opq is true and "pq" otherwise.
// NOTE(review): the non-OPQ tag has no leading underscore, unlike "_opq". It
// is kept as-is because existing file names depend on it; confirm whether the
// asymmetry is intentional.
inline std::string get_pivot_data_filename(const std::string &prefix, bool use_opq, uint32_t num_chunks)
{
    const char *const tag = use_opq ? "_opq" : "pq";
    const std::string chunk_count = std::to_string(num_chunks);
    return prefix + tag + chunk_count + "_pivots.bin";
}

// Returns pivot_data_filename with "_rotation_matrix.bin" appended. Despite
// the "suffix" in its name, the result is the whole name, not just the suffix.
inline std::string get_rotation_matrix_suffix(const std::string &pivot_data_filename)
{
    std::string rotation_matrix_name(pivot_data_filename);
    rotation_matrix_name.append("_rotation_matrix.bin");
    return rotation_matrix_name;
}

} // namespace diskann
18 changes: 15 additions & 3 deletions include/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT license.

#pragma once
#include <unordered_map>
#include "common_includes.h"

#include "aligned_file_reader.h"
Expand Down Expand Up @@ -35,6 +36,15 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
DISKANN_DLLEXPORT int load(uint32_t num_threads, const char *index_prefix);
#endif

#ifdef EXEC_ENV_OLS
DISKANN_DLLEXPORT void load_labels(MemoryMappedFiles &files, const std::string &disk_index_file);
#else
DISKANN_DLLEXPORT void load_labels(const std::string& disk_index_filepath);
#endif
DISKANN_DLLEXPORT void load_label_medoid_map(
const std::string &labels_to_medoids_filepath, std::istream &medoid_stream);
DISKANN_DLLEXPORT void load_dummy_map(const std::string& dummy_map_filepath, std::istream &dummy_map_stream);

#ifdef EXEC_ENV_OLS
DISKANN_DLLEXPORT int load_from_separate_paths(diskann::MemoryMappedFiles &files, uint32_t num_threads,
const char *index_filepath, const char *pivots_filepath,
Expand Down Expand Up @@ -77,7 +87,7 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex

DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search,
uint64_t *res_ids, float *res_dists, const uint64_t beam_width,
const bool use_filter, const LabelT &filter_label,
const bool use_filter, const std::vector<LabelT> &filter_labels,
const uint32_t io_limit, const bool use_reorder_data = false,
QueryStats *stats = nullptr);

Expand Down Expand Up @@ -116,9 +126,11 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex

private:
DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id);
std::unordered_map<std::string, LabelT> load_label_map(std::basic_istream<char> &infile);
DISKANN_DLLEXPORT inline bool point_has_any_label(uint32_t point_id, const std::vector<LabelT> &label_ids);
void load_label_map(std::basic_istream<char> &map_reader,
std::unordered_map<std::string, LabelT> &string_to_int_map);
DISKANN_DLLEXPORT void parse_label_file(std::basic_istream<char> &infile, size_t &num_pts_labels);
DISKANN_DLLEXPORT void get_label_file_metadata(std::basic_istream<char> &infile, uint32_t &num_pts,
DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts,
uint32_t &num_total_labels);
DISKANN_DLLEXPORT void generate_random_labels(std::vector<LabelT> &labels, const uint32_t num_labels,
const uint32_t nthreads);
Expand Down
22 changes: 22 additions & 0 deletions include/pq_scratch.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#pragma once
#include <cstdint>
#include "pq_common.h"
#include "utils.h"

namespace diskann
{

// Holder for the scratch buffers used with PQ data. Every member is public
// and starts out null. The constructor and initialize() are only declared
// here; their definitions live outside this header.
template <typename T> class PQScratch
{
  public:
    float *aligned_pqtable_dist_scratch{nullptr}; // MUST BE AT LEAST [256 * NCHUNKS]
    float *aligned_dist_scratch{nullptr};         // MUST BE AT LEAST diskann MAX_DEGREE
    uint8_t *aligned_pq_coord_scratch{nullptr};   // AT LEAST [N_CHUNKS * MAX_DEGREE]
    float *rotated_query{nullptr};
    float *aligned_query_float{nullptr};

    // NOTE(review): presumably allocates the buffers above, sized from
    // graph_degree and aligned_dim - confirm against the out-of-line definition.
    PQScratch(size_t graph_degree, size_t aligned_dim);

    // NOTE(review): presumably loads `dim` components of `query` into the
    // float query buffers, scaled by `norm` - confirm against the out-of-line
    // definition.
    void initialize(size_t dim, const T *query, const float norm = 1.0f);
};

} // namespace diskann
28 changes: 10 additions & 18 deletions include/scratch.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,22 @@
#include "tsl/sparse_map.h"

#include "aligned_file_reader.h"
#include "concurrent_queue.h"
#include "defaults.h"
#include "abstract_scratch.h"
#include "neighbor.h"
#include "pq.h"
#include "defaults.h"
#include "concurrent_queue.h"

namespace diskann
{
template <typename T> class PQScratch;

//
// Scratch space for in-memory index based search
// Scratch space for in-memory index based search (derives from AbstractScratch)
//
template <typename T> class InMemQueryScratch
template <typename T> class InMemQueryScratch : public AbstractScratch<T>
{
public:
~InMemQueryScratch();
// REFACTOR TODO: move all parameters to a new class.
InMemQueryScratch(uint32_t search_l, uint32_t indexing_l, uint32_t r, uint32_t maxc, size_t dim, size_t aligned_dim,
size_t alignment_factor, bool init_pq_scratch = false);
void resize_for_new_L(uint32_t new_search_l);
Expand All @@ -47,11 +47,11 @@ template <typename T> class InMemQueryScratch
}
// Returns the aligned query buffer stored in the AbstractScratch<T> base.
// The stale `return _aligned_query;` left over from before the refactor has
// been removed: it made the intended return unreachable. `this->` is needed
// because the member lives in a dependent base class.
inline T *aligned_query()
{
    return this->_aligned_query_T;
}
// Returns the PQScratch pointer stored in the AbstractScratch<T> base.
// The stale unqualified `return _pq_scratch;` left over from before the
// refactor has been removed: it made the intended return unreachable. `this->`
// is needed because the member lives in a dependent base class.
inline PQScratch<T> *pq_scratch()
{
    return this->_pq_scratch;
}
inline std::vector<Neighbor> &pool()
{
Expand Down Expand Up @@ -99,10 +99,6 @@ template <typename T> class InMemQueryScratch
uint32_t _R;
uint32_t _maxc;

T *_aligned_query = nullptr;

PQScratch<T> *_pq_scratch = nullptr;

// _pool stores all neighbors explored from best_L_nodes.
// Usually around L+R, but could be higher.
// Initialized to 3L+R for some slack, expands as needed.
Expand Down Expand Up @@ -139,21 +135,17 @@ template <typename T> class InMemQueryScratch
};

//
// Scratch space for SSD index based search
// Scratch space for SSD index based search (derives from AbstractScratch)
//

template <typename T> class SSDQueryScratch
template <typename T> class SSDQueryScratch : public AbstractScratch<T>
{
public:
T *coord_scratch = nullptr; // MUST BE AT LEAST [sizeof(T) * data_dim]

char *sector_scratch = nullptr; // MUST BE AT LEAST [MAX_N_SECTOR_READS * SECTOR_LEN]
size_t sector_idx = 0; // index of next [SECTOR_LEN] scratch to use

T *aligned_query_T = nullptr;

PQScratch<T> *_pq_scratch;

tsl::robin_set<size_t> visited;
NeighborPriorityQueue retset;
std::vector<Neighbor> full_retset;
Expand Down
2 changes: 1 addition & 1 deletion src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
{
query_float[d] = (float)aligned_query[d];
}
pq_query_scratch->set(_dim, aligned_query);
pq_query_scratch->initialize(_dim, aligned_query);

// center the query and rotate if we have a rotation matrix
_pq_table.preprocess_query(query_rotated);
Expand Down
Loading
Loading