Skip to content

Commit

Permalink
reparate static search and streaming search
Browse files Browse the repository at this point in the history
  • Loading branch information
Sanhaoji2 committed Jan 29, 2024
1 parent 12457b8 commit 237f2e4
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 92 deletions.
6 changes: 4 additions & 2 deletions include/abstract_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ class AbstractIndex
// Initialize space for res_vectors before calling.
template <typename data_type, typename tag_type>
size_t search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags,
float *distances, std::vector<data_type *> &res_vectors);
float *distances, std::vector<data_type *> &res_vectors, bool use_filters = false,
const std::string filter_label = "");

// Added search overload that takes L as parameter, so that we
// can customize L on a per-query basis without tampering with "Parameters"
Expand Down Expand Up @@ -120,7 +121,8 @@ class AbstractIndex
virtual void _set_start_points_at_random(DataType radius, uint32_t random_seed = 0) = 0;
virtual int _get_vector_by_tag(TagType &tag, DataType &vec) = 0;
virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags,
float *distances, DataVector &res_vectors) = 0;
float *distances, DataVector &res_vectors, bool use_filters,
const std::string filter_label) = 0;
virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) = 0;
virtual void _set_universal_label(const LabelType universal_label) = 0;
};
Expand Down
6 changes: 4 additions & 2 deletions include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas

// Initialize space for res_vectors before calling.
DISKANN_DLLEXPORT size_t search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags,
float *distances, std::vector<T *> &res_vectors);
float *distances, std::vector<T *> &res_vectors, bool use_filters = false,
const std::string filter_label = "");

// Filter support search
template <typename IndexType>
Expand Down Expand Up @@ -226,7 +227,8 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) override;

virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags,
float *distances, DataVector &res_vectors) override;
float *distances, DataVector &res_vectors, bool use_filters = false,
const std::string filter_label = "") override;

virtual void _set_universal_label(const LabelType universal_label) override;

Expand Down
77 changes: 35 additions & 42 deletions src/abstract_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ std::pair<uint32_t, uint32_t> AbstractIndex::search(const data_type *query, cons

template <typename data_type, typename tag_type>
size_t AbstractIndex::search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags,
float *distances, std::vector<data_type *> &res_vectors)
float *distances, std::vector<data_type *> &res_vectors, bool use_filters,
const std::string filter_label)
{
auto any_query = std::any(query);
auto any_tags = std::any(tags);
auto any_res_vectors = DataVector(res_vectors);
return this->_search_with_tags(any_query, K, L, any_tags, distances, any_res_vectors);
return this->_search_with_tags(any_query, K, L, any_tags, distances, any_res_vectors, use_filters, filter_label);
}

template <typename IndexType>
Expand Down Expand Up @@ -162,61 +163,53 @@ template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search_w
const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint64_t *indices,
float *distances);

template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, int32_t>(const float *query, const uint64_t K,
const uint32_t L, int32_t *tags,
float *distances,
std::vector<float *> &res_vectors);
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, int32_t>(
const float *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances,
std::vector<float *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT size_t
AbstractIndex::search_with_tags<uint8_t, int32_t>(const uint8_t *query, const uint64_t K, const uint32_t L,
int32_t *tags, float *distances, std::vector<uint8_t *> &res_vectors);
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<uint8_t, int32_t>(
const uint8_t *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances,
std::vector<uint8_t *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, int32_t>(const int8_t *query,
const uint64_t K, const uint32_t L,
int32_t *tags, float *distances,
std::vector<int8_t *> &res_vectors);
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, int32_t>(
const int8_t *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances,
std::vector<int8_t *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, uint32_t>(const float *query, const uint64_t K,
const uint32_t L, uint32_t *tags,
float *distances,
std::vector<float *> &res_vectors);
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, uint32_t>(
const float *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances,
std::vector<float *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<uint8_t, uint32_t>(
const uint8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances,
std::vector<uint8_t *> &res_vectors);
std::vector<uint8_t *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, uint32_t>(const int8_t *query,
const uint64_t K, const uint32_t L,
uint32_t *tags, float *distances,
std::vector<int8_t *> &res_vectors);
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, uint32_t>(
const int8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances,
std::vector<int8_t *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, int64_t>(const float *query, const uint64_t K,
const uint32_t L, int64_t *tags,
float *distances,
std::vector<float *> &res_vectors);
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, int64_t>(
const float *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances,
std::vector<float *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT size_t
AbstractIndex::search_with_tags<uint8_t, int64_t>(const uint8_t *query, const uint64_t K, const uint32_t L,
int64_t *tags, float *distances, std::vector<uint8_t *> &res_vectors);
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<uint8_t, int64_t>(
const uint8_t *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances,
std::vector<uint8_t *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, int64_t>(const int8_t *query,
const uint64_t K, const uint32_t L,
int64_t *tags, float *distances,
std::vector<int8_t *> &res_vectors);
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, int64_t>(
const int8_t *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances,
std::vector<int8_t *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, uint64_t>(const float *query, const uint64_t K,
const uint32_t L, uint64_t *tags,
float *distances,
std::vector<float *> &res_vectors);
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, uint64_t>(
const float *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances,
std::vector<float *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<uint8_t, uint64_t>(
const uint8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances,
std::vector<uint8_t *> &res_vectors);
std::vector<uint8_t *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, uint64_t>(const int8_t *query,
const uint64_t K, const uint32_t L,
uint64_t *tags, float *distances,
std::vector<int8_t *> &res_vectors);
template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<int8_t, uint64_t>(
const int8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances,
std::vector<int8_t *> &res_vectors, bool use_filters, const std::string filter_label);

template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout<float>(const float *query, size_t K,
size_t L, uint32_t *indices);
Expand Down
71 changes: 25 additions & 46 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2023,10 +2023,14 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::_search_with_filters(const
float *distances)
{
auto converted_label = this->get_converted_label(raw_label);

if (typeid(TagT *) == indices.type())
if (typeid(uint64_t *) == indices.type())
{
auto ptr = std::any_cast<uint64_t *>(indices);
return this->search_with_filters(std::any_cast<T *>(query), converted_label, K, L, ptr, distances);
}
else if (typeid(uint32_t *) == indices.type())
{
auto ptr = std::any_cast<TagT *>(indices);
auto ptr = std::any_cast<uint32_t *>(indices);
return this->search_with_filters(std::any_cast<T *>(query), converted_label, K, L, ptr, distances);
}
else
Expand Down Expand Up @@ -2090,24 +2094,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search_with_filters(const
{
if (best_L_nodes[i].id < _max_points)
{
// safe because Index uses uint32_t ids internally
// and IDType will be uint32_t or uint64_t
if (_enable_tags)
{
TagT tag;
if (_location_to_tag.try_get(best_L_nodes[i].id, tag))
{
indices[pos] = (IdType)tag;
}
else
{
continue;
}
}
else
{
indices[pos] = best_L_nodes[i].id;
}
indices[pos] = (IdType)best_L_nodes[i].id;

if (distances != nullptr)
{
Expand All @@ -2134,12 +2121,13 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search_with_filters(const

template <typename T, typename TagT, typename LabelT>
size_t Index<T, TagT, LabelT>::_search_with_tags(const DataType &query, const uint64_t K, const uint32_t L,
const TagType &tags, float *distances, DataVector &res_vectors)
const TagType &tags, float *distances, DataVector &res_vectors,
bool use_filters, const std::string filter_label)
{
try
{
return this->search_with_tags(std::any_cast<const T *>(query), K, L, std::any_cast<TagT *>(tags), distances,
res_vectors.get<std::vector<T *>>());
res_vectors.get<std::vector<T *>>(), use_filters, filter_label);
}
catch (const std::bad_any_cast &e)
{
Expand All @@ -2153,7 +2141,8 @@ size_t Index<T, TagT, LabelT>::_search_with_tags(const DataType &query, const ui

template <typename T, typename TagT, typename LabelT>
size_t Index<T, TagT, LabelT>::search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags,
float *distances, std::vector<T *> &res_vectors)
float *distances, std::vector<T *> &res_vectors, bool use_filters,
const std::string filter_label)
{
if (K > (uint64_t)L)
{
Expand All @@ -2173,12 +2162,22 @@ size_t Index<T, TagT, LabelT>::search_with_tags(const T *query, const uint64_t K
std::shared_lock<std::shared_timed_mutex> ul(_update_lock);

const std::vector<uint32_t> init_ids = get_init_ids();
const std::vector<LabelT> unused_filter_label;

//_distance->preprocess_query(query, _data_store->get_dims(),
// scratch->aligned_query());
_data_store->preprocess_query(query, scratch);
iterate_to_fixed_point(scratch, L, init_ids, false, unused_filter_label, true);
if (!use_filters)
{
const std::vector<LabelT> unused_filter_label;
iterate_to_fixed_point(scratch, L, init_ids, false, unused_filter_label, true);
}
else
{
std::vector<LabelT> filter_vec;
auto converted_label = this->get_converted_label(filter_label);
filter_vec.push_back(converted_label);
iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true);
}

NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes();
assert(best_L_nodes.size() <= L);
Expand Down Expand Up @@ -3423,16 +3422,6 @@ template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t,
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint32_t>::search_with_filters<
uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
float *distances);
// TagT==uint128
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, tag_uint128, uint32_t>::search_with_filters<
tag_uint128>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L,
tag_uint128 *indices, float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, tag_uint128, uint32_t>::search_with_filters<
tag_uint128>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L,
tag_uint128 *indices, float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, tag_uint128, uint32_t>::search_with_filters<
tag_uint128>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L,
tag_uint128 *indices, float *distances);

template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint16_t>::search<uint64_t>(
const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances);
Expand Down Expand Up @@ -3504,15 +3493,5 @@ template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t,
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint16_t>::search_with_filters<
uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
float *distances);
// TagT==uint128
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, tag_uint128, uint16_t>::search_with_filters<
tag_uint128>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L,
tag_uint128 *indices, float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, tag_uint128, uint16_t>::search_with_filters<
tag_uint128>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L,
tag_uint128 *indices, float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, tag_uint128, uint16_t>::search_with_filters<
tag_uint128>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L,
tag_uint128 *indices, float *distances);

} // namespace diskann

0 comments on commit 237f2e4

Please sign in to comment.