Skip to content

Commit

Permalink
Support KNN search for FAISS IVF indices (#13258)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #13258

Differential Revision: D67684898
  • Loading branch information
ltamasi authored and facebook-github-bot committed Dec 28, 2024
1 parent ac3cde3 commit bd905aa
Show file tree
Hide file tree
Showing 4 changed files with 422 additions and 13 deletions.
18 changes: 18 additions & 0 deletions include/rocksdb/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -1960,6 +1960,24 @@ struct ReadOptions {
// Default: false
bool allow_unprepared_value = false;

// The maximum number of neighbors K to return when performing a
// K-nearest-neighbors vector similarity search. The number of neighbors
// returned can be smaller if there are not enough vectors in the inverted
// lists probed. Only applicable to FAISS IVF secondary indices, where it must
// be specified and positive. See also `SecondaryIndex::NewIterator` and
// `similarity_search_probes` below.
//
// Default: none
std::optional<size_t> similarity_search_neighbors;

// The number of inverted lists to probe when performing a K-nearest-neighbors
// vector similarity search. Only applicable to FAISS IVF secondary indices,
// where it must be specified and positive. See also
// `SecondaryIndex::NewIterator` and `similarity_search_neighbors` above.
//
// Default: none
std::optional<size_t> similarity_search_probes;

// *** END options only relevant to iterators or scans ***

// *** BEGIN options for RocksDB internal use only ***
Expand Down
250 changes: 239 additions & 11 deletions utilities/secondary_index/faiss_ivf_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,157 @@
#include "utilities/secondary_index/faiss_ivf_index.h"

#include <cassert>
#include <optional>
#include <utility>

#include "faiss/invlists/InvertedLists.h"
#include "util/autovector.h"
#include "util/coding.h"

namespace ROCKSDB_NAMESPACE {

class FaissIVFIndex::KNNIterator : public Iterator {
public:
KNNIterator(faiss::IndexIVF* index, Iterator* underlying_it, size_t k,
size_t probes)
: index_(index),
underlying_it_(underlying_it),
k_(k),
probes_(probes),
distances_(k, 0.0f),
labels_(k, -1),
pos_(0) {
assert(index_);
assert(underlying_it_);
assert(k_ > 0);
assert(probes_ > 0);
}

Iterator* GetUnderlyingIterator() const { return underlying_it_; }

faiss::idx_t AddKey(std::string&& key) {
keys_.emplace_back(std::move(key));

return static_cast<faiss::idx_t>(keys_.size()) - 1;
}

bool Valid() const override {
assert(labels_.size() == k_);
assert(distances_.size() == k_);

return status_.ok() && pos_ >= 0 && pos_ < k_ && labels_[pos_] >= 0;
}

void SeekToFirst() override {
status_ =
Status::NotSupported("SeekToFirst not supported for FaissIVFIndex");
}

void SeekToLast() override {
status_ =
Status::NotSupported("SeekToLast not supported for FaissIVFIndex");
}

void Seek(const Slice& target) override {
distances_.assign(k_, 0.0f);
labels_.assign(k_, -1);
status_ = Status::OK();
pos_ = 0;
keys_.clear();

faiss::SearchParametersIVF params;
params.nprobe = probes_;
params.inverted_list_context = this;

constexpr faiss::idx_t n = 1;

try {
index_->search(n, reinterpret_cast<const float*>(target.data()), k_,
distances_.data(), labels_.data(), &params);
} catch (const std::exception& e) {
status_ = Status::InvalidArgument(e.what());
}
}

void SeekForPrev(const Slice& /* target */) override {
status_ =
Status::NotSupported("SeekForPrev not supported for FaissIVFIndex");
}

void Next() override {
assert(Valid());

++pos_;
}

void Prev() override {
assert(Valid());

--pos_;
}

Status status() const override { return status_; }

Slice key() const override {
assert(Valid());
assert(labels_[pos_] >= 0);
assert(labels_[pos_] < keys_.size());

return keys_[labels_[pos_]];
}

Slice value() const override {
assert(Valid());

return Slice();
}

const WideColumns& columns() const override {
assert(Valid());

return kNoWideColumns;
}

Slice timestamp() const override {
assert(Valid());

return Slice();
}

Status GetProperty(std::string prop_name, std::string* prop) override {
if (!prop) {
return Status::InvalidArgument("No property pointer provided");
}

if (!Valid()) {
return Status::InvalidArgument("Iterator is not valid");
}

if (prop_name == kPropertyName_) {
*prop = std::to_string(distances_[pos_]);
return Status::OK();
}

return Iterator::GetProperty(std::move(prop_name), prop);
}

private:
faiss::IndexIVF* index_;
Iterator* underlying_it_;
size_t k_;
size_t probes_;
std::vector<float> distances_;
std::vector<faiss::idx_t> labels_;
Status status_;
faiss::idx_t pos_;
autovector<std::string> keys_;

static const std::string kPropertyName_;
};

const std::string FaissIVFIndex::KNNIterator::kPropertyName_ =
"rocksdb.faiss.ivf.index.distance";

class FaissIVFIndex::Adapter : public faiss::InvertedLists {
public:
Adapter(size_t num_lists, size_t code_size)
Expand All @@ -36,14 +181,13 @@ class FaissIVFIndex::Adapter : public faiss::InvertedLists {
return nullptr;
}

// Iterator-based read interface; not yet implemented
// Iterator-based read interface
faiss::InvertedListsIterator* get_iterator(
size_t /* list_no */,
void* /* inverted_list_context */ = nullptr) const override {
// TODO: implement this
size_t list_no, void* inverted_list_context = nullptr) const override {
KNNIterator* const it = static_cast<KNNIterator*>(inverted_list_context);
assert(it);

assert(false);
return nullptr;
return new IteratorAdapter(it, list_no, code_size);
}

// Write interface; only add_entry is implemented/required for now
Expand Down Expand Up @@ -80,6 +224,77 @@ class FaissIVFIndex::Adapter : public faiss::InvertedLists {
void resize(size_t /* list_no */, size_t /* new_size */) override {
assert(false);
}

private:
class IteratorAdapter : public faiss::InvertedListsIterator {
public:
IteratorAdapter(KNNIterator* it, size_t list_no, size_t code_size)
: it_(it),
underlying_it_(it->GetUnderlyingIterator()),
prefix_(FaissIVFIndex::SerializeLabel(list_no)),
prefix_slice_(prefix_),
code_size_(code_size) {
assert(it_);
assert(underlying_it_);

// FIXME: here we rely on the empty Slice being less than any other one,
// which is true for e.g. BytewiseComparator but not in general
underlying_it_->Seek(prefix_slice_);
Update();
}

bool is_available() const override { return id_and_codes_.has_value(); }

void next() override {
underlying_it_->Next();
Update();
}

std::pair<faiss::idx_t, const uint8_t*> get_id_and_codes() override {
assert(is_available());

return *id_and_codes_;
}

private:
void Update() {
id_and_codes_.reset();

if (!underlying_it_->Valid()) {
return;
}

Slice key = underlying_it_->key();
if (!key.starts_with(prefix_slice_)) {
return;
}

if (!underlying_it_->PrepareValue()) {
throw std::runtime_error(
"Failed to prepare value during iteration in FaissIVFIndex");
}

const Slice& value = underlying_it_->value();
if (value.size() != code_size_) {
throw std::runtime_error(
"Code with unexpected size encountered during iteration in "
"FaissIVFIndex");
}

key.remove_prefix(prefix_slice_.size());

const faiss::idx_t id = it_->AddKey(key.ToString());

id_and_codes_.emplace(id, reinterpret_cast<const uint8_t*>(value.data()));
}

KNNIterator* it_;
Iterator* underlying_it_;
std::string prefix_;
Slice prefix_slice_;
size_t code_size_;
std::optional<std::pair<faiss::idx_t, const uint8_t*>> id_and_codes_;
};
};

std::string FaissIVFIndex::SerializeLabel(faiss::idx_t label) {
Expand All @@ -105,6 +320,7 @@ FaissIVFIndex::FaissIVFIndex(std::unique_ptr<faiss::IndexIVF>&& index,
assert(index_);
assert(index_->quantizer);

index_->parallel_mode = 0;
index_->replace_invlists(adapter_.get());
}

Expand Down Expand Up @@ -202,7 +418,7 @@ Status FaissIVFIndex::GetSecondaryValue(

if (code_str.size() != index_->code_size) {
return Status::InvalidArgument(
"Unexpected code returned by fine quantizer");
"Code with unexpected size returned by fine quantizer");
}

secondary_value->emplace(std::move(code_str));
Expand All @@ -211,10 +427,22 @@ Status FaissIVFIndex::GetSecondaryValue(
}

std::unique_ptr<Iterator> FaissIVFIndex::NewIterator(
const ReadOptions& /* read_options */,
Iterator* /* underlying_it */) const {
// TODO: implement this
return std::unique_ptr<Iterator>(NewErrorIterator(Status::NotSupported()));
const ReadOptions& read_options, Iterator* it) const {
if (!read_options.similarity_search_neighbors.has_value() ||
*read_options.similarity_search_neighbors == 0) {
return std::unique_ptr<Iterator>(NewErrorIterator(
Status::InvalidArgument("Invalid number of neighbors")));
}

if (!read_options.similarity_search_probes.has_value() ||
*read_options.similarity_search_probes == 0) {
return std::unique_ptr<Iterator>(
NewErrorIterator(Status::InvalidArgument("Invalid number of probes")));
}

return std::make_unique<KNNIterator>(
index_.get(), it, *read_options.similarity_search_neighbors,
*read_options.similarity_search_probes);
}

} // namespace ROCKSDB_NAMESPACE
1 change: 1 addition & 0 deletions utilities/secondary_index/faiss_ivf_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class FaissIVFIndex : public SecondaryIndex {
Iterator* underlying_it) const override;

private:
class KNNIterator;
class Adapter;

static std::string SerializeLabel(faiss::idx_t label);
Expand Down
Loading

0 comments on commit bd905aa

Please sign in to comment.