Skip to content

Commit

Permalink
Refactored loads from pq_flash_index into in_mem_filter_store
Browse files Browse the repository at this point in the history
  • Loading branch information
gopal-msr committed Sep 30, 2024
1 parent fd93a04 commit 34f723f
Show file tree
Hide file tree
Showing 9 changed files with 695 additions and 309 deletions.
6 changes: 0 additions & 6 deletions .clang-format

This file was deleted.

34 changes: 34 additions & 0 deletions include/abstract_filter_store.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#pragma once
#include <vector>
#include "types.h"
#include "multi_filter/abstract_predicate.h"
namespace diskann
{
template<typename LabelT>
class AbstractFilterStore
{
public:
/// <summary>
/// Returns the filters for a data point. Only valid for base points
/// </summary>
/// <param name="point">base point id</param>
/// <returns>list of filters of the base point</returns>
virtual const std::vector<LabelT> &get_filters_for_point(location_t point) const = 0;

/// <summary>
/// Adds filters for a point.
/// </summary>
/// <param name="point"></param>
/// <param name="filters"></param>
virtual void add_filters_for_point(location_t point, const std::vector<LabelT> &filters) = 0;

/// <summary>
/// Returns a score between [0,1] indicating how many points in the dataset
/// matched the predicate
/// </summary>
/// <param name="pred">Predicate to match</param>
/// <returns>Score between [0,1] indicate %age of points matching pred</returns>
virtual float get_predicate_selectivity(const AbstractPredicate &pred) const = 0;
};

}
90 changes: 90 additions & 0 deletions include/in_mem_filter_store.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#pragma once

#include <vector>
#include "tsl/robin_map.h"
#include "tsl/robin_set.h"
#include <abstract_filter_store.h>

namespace diskann
{
template<typename LabelT>
class InMemFilterStore : public AbstractFilterStore<LabelT>
{
public:
/// <summary>
/// Returns the filters for a data point. Only valid for base points
/// </summary>
/// <param name="point">base point id</param>
/// <returns>list of filters of the base point</returns>
virtual const std::vector<LabelT> &get_filters_for_point(location_t point) const override;

/// <summary>
/// Adds filters for a point.
/// </summary>
/// <param name="point"></param>
/// <param name="filters"></param>
virtual void add_filters_for_point(location_t point, const std::vector<LabelT> &filters) override;

/// <summary>
/// Returns a score between [0,1] indicating how many points in the dataset
/// matched the predicate
/// </summary>
/// <param name="pred">Predicate to match</param>
/// <returns>Score between [0,1] indicate %age of points matching pred</returns>
virtual float get_predicate_selectivity(const AbstractPredicate &pred) const override;


virtual const std::unordered_map<LabelT, std::vector<location_t>>& get_label_to_medoids() const;

virtual const std::vector<location_t> &get_medoids_of_label(const LabelT label) const;

virtual void set_universal_label(const LabelT univ_label);

inline bool point_has_label(location_t point_id, const LabelT label_id) const;

inline bool is_dummy_point(location_t id) const;

inline bool point_has_label_or_universal_label(location_t point_id, const LabelT label_id) const;

inline LabelT get_converted_label(const std::string &filter_label) const;

//Returns true if the index is filter-enabled and all files were loaded correctly.
//false otherwise. Note that "false" can mean that the index does not have filter support,
//or that some index files do not exist, or that they exist and could not be opened.
bool load(const std::string& disk_index_file);

private:

// Load functions for search START
void load_label_file(const std::string_view& file_content);
void load_label_map(std::basic_istream<char> &map_reader);
void load_labels_to_medoids(std::basic_istream<char> &reader);
void load_dummy_map(std::basic_istream<char> &dummy_map_stream);

bool load_file_and_parse(
const std::string &filename,
void (*parse_fn)(const std::string_view &content));

bool load_file_and_parse(
const std::string &filename,
void (*parse_fn)(std::basic_istream<char> &stream))


// Load functions for search END

// filter support
uint32_t *_pts_to_label_offsets = nullptr;
uint32_t *_pts_to_label_counts = nullptr;
LabelT *_pts_to_labels = nullptr;
std::unordered_map<LabelT, std::vector<location_t>> _filter_to_medoid_ids;
bool _use_universal_label = false;
LabelT _universal_filter_label;
tsl::robin_set<uint32_t> _dummy_pts;
tsl::robin_set<uint32_t> _has_dummy_pts;
tsl::robin_map<uint32_t, uint32_t> _dummy_to_real_map;
tsl::robin_map<uint32_t, std::vector<uint32_t>> _real_to_dummy_map;
std::unordered_map<std::string, LabelT> _label_map;

};

}
14 changes: 14 additions & 0 deletions include/multi_filter/abstract_predicate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#pragma once
#include <vector>

namespace diskann
{

class AbstractPredicate
{
public:
virtual ~AbstractPredicate() = 0;

};

} // namespace diskann
9 changes: 9 additions & 0 deletions include/multi_filter/filter_matcher.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#pragma once
namespace diskann
{
class AbstractFilterMatcher
{
public:
uint32_t get_approximate_match_count(const AbstractFilter& filter)
};
}
44 changes: 44 additions & 0 deletions include/multi_filter/simple_boolean_predicate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#pragma once
#include <vector>

namespace diskann {

enum BooleanOperator
{
AND, OR
};

/// <summary>
/// Represents a simple boolean filter condition with only
/// one kind of operator. The operator can be either AND or
/// OR. The NOT operator is not supported. The predicates
/// are expected to be integers representing predicates
/// provided by the user.
/// </summary>
/// <typeparam name="T"></typeparam>
template <typename T>
class SimpleBooleanPredicate : public AbstractPredicate
{
public:
SimpleBooleanPredicate(BooleanOperator op)
{
_op = op;
}
void add_predicate(const T &predicate)
{
_predicates.push_back(predicate);
}
const std::vector<T> &get_predicates() const
{
return _predicates;
}
const BooleanOperator get_op() const
{
return _op;
}

private:
BooleanOperator _op;
std::vector<T> _predicates;
};
}
18 changes: 6 additions & 12 deletions include/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#include "tsl/robin_map.h"
#include "tsl/robin_set.h"

#include "in_mem_filter_store.h"

#define FULL_PRECISION_REORDER_MULTIPLIER 3

namespace diskann
Expand Down Expand Up @@ -221,18 +223,10 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
bool _reorder_data_exists = false;
uint64_t _reoreder_data_offset = 0;

// filter support
uint32_t *_pts_to_label_offsets = nullptr;
uint32_t *_pts_to_label_counts = nullptr;
LabelT *_pts_to_labels = nullptr;
std::unordered_map<LabelT, std::vector<uint32_t>> _filter_to_medoid_ids;
bool _use_universal_label = false;
LabelT _universal_filter_label;
tsl::robin_set<uint32_t> _dummy_pts;
tsl::robin_set<uint32_t> _has_dummy_pts;
tsl::robin_map<uint32_t, uint32_t> _dummy_to_real_map;
tsl::robin_map<uint32_t, std::vector<uint32_t>> _real_to_dummy_map;
std::unordered_map<std::string, LabelT> _label_map;
//Moved filter-specific data structures to in_mem_filter_store.
//TODO: Make this a unique pointer
InMemFilterStore<LabelT>* _filter_store;


#ifdef EXEC_ENV_OLS
// Set to a larger value than the actual header to accommodate
Expand Down
Loading

0 comments on commit 34f723f

Please sign in to comment.