Add sharding convenience function for IVF indexes (#4150)
Summary:
Pull Request resolved: #4150

Creates a sharding convenience function for IVF indexes.
- The __**centroids on the quantizer**__ are sharded based on the given sharding function (not the data, since data sharding by ids is already implemented by copy_subset_to, https://github.com/facebookresearch/faiss/blob/main/faiss/IndexIVF.h#L408).
- The output is written to files based on the template filename generator param.
- The default sharding function is simply the ith vector mod the total shard count.
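A minimal, self-contained sketch of the two behaviours listed above: assigning centroid indices to shards with `i % shard_count`, and expanding a `%d` filename template per shard. The helper names `assign_to_shards` and `shard_filename` are hypothetical and are not part of FAISS; the real code operates on quantizer centroids and writes index files.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// Hypothetical helper (not a FAISS function): groups centroid indices
// 0..n_centroids-1 by shard, using the default rule i % shard_count.
std::vector<std::vector<int64_t>> assign_to_shards(
        int64_t n_centroids,
        int64_t shard_count) {
    assert(shard_count > 0);
    std::vector<std::vector<int64_t>> shards(shard_count);
    for (int64_t i = 0; i < n_centroids; i++) {
        shards[i % shard_count].push_back(i);
    }
    return shards;
}

// Hypothetical helper (not a FAISS function): expands a printf-style
// template containing a single %d into the filename for one shard.
std::string shard_filename(
        const std::string& filename_template,
        int64_t shard_id) {
    char fname[256];
    // %d expects an int, so the shard id is narrowed explicitly here.
    std::snprintf(
            fname,
            sizeof(fname),
            filename_template.c_str(),
            static_cast<int>(shard_id));
    return std::string(fname);
}
```

With 7 centroids and 3 shards, this rule places centroids 0, 3, 6 in shard 0, centroids 1, 4 in shard 1, and centroids 2, 5 in shard 2. The template `shard.%d.index` with shard id 2 expands to `shard.2.index`.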

This would be called by Laser here: https://www.internalfb.com/code/fbsource/[ce1f2e028e79]/fbcode/fblearner/flow/projects/laser/laser_sim_search/knn_trainer.py?lines=295-296. This convenience function will do the file writing, and return the created file names.

There are a few key required changes in FAISS:
1. Allow `std::vector<std::string>` to be used. Updates swigfaiss.swig and array_conversions.py to accommodate. These have to be numpy dtype of `object` instead of the more correct `unicode`, because the unicode dtype is fixed length. I couldn't figure out how to create a numpy array in which each output file name has a different dtype. (Say the file names are file1, file11, file111. The dtypes would need to be U5, U6, U7 respectively, as the unicode dtype includes the length.) I tried structured arrays, but this does not work either, as numpy turns the result into a matrix: the `file1 file11 file111` example with explicit U5, U6, U7 becomes `[[file1 file1 file1], [file1 file11 file11], [file1 file11 file111]]`, which we do not want. If someone knows the right syntax, please yell at me.
2. Create Python callbacks for sharding and template filename: `PyCallbackFilenameTemplateGenerator` and `PyCallbackShardingFunction`. Users of this function would inherit from the FilenameTemplateGenerator or ShardingFunction in C++ to pass to `shard_ivf_index_centroids`. See the other examples in python_callbacks.cpp. This is required because Python functions cannot be passed through SWIG to C++ (i.e. no std::function or function pointers), so we have to use this approach. This approach allows it to be called from both C++ and Python. test_sharding.py shows the Python calling, test_utils.cpp shows the C++ calling.
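The callback pattern from item 2 can be sketched in plain C++ as follows. This is a simplified illustration, not the FAISS implementation: `ShardingFunction` and `DefaultShardingFunction` are reduced copies of the structs added in IVFlib.h, while `BlockShardingFunction` and `pick_shard` are hypothetical names introduced here to show how a user-defined subclass is passed by pointer and how `nullptr` falls back to the default.

```cpp
#include <cassert>
#include <cstdint>

// Reduced copy of the abstract interface: subclasses decide which shard
// centroid i belongs to.
struct ShardingFunction {
    virtual int64_t operator()(int64_t i, int64_t shard_count) = 0;
    virtual ~ShardingFunction() = default;
};

// Default rule: i modulo the shard count.
struct DefaultShardingFunction : ShardingFunction {
    int64_t operator()(int64_t i, int64_t shard_count) override {
        return i % shard_count;
    }
};

// Hypothetical user-defined rule: contiguous blocks of centroids per shard.
struct BlockShardingFunction : ShardingFunction {
    int64_t total; // total number of centroids
    explicit BlockShardingFunction(int64_t total) : total(total) {}
    int64_t operator()(int64_t i, int64_t shard_count) override {
        // Ceiling division gives the block size per shard.
        int64_t block = (total + shard_count - 1) / shard_count;
        return i / block;
    }
};

// Hypothetical stand-in for the library entry point: a null pointer selects
// the default rule, mirroring the nullptr check in sharding_helper.
int64_t pick_shard(
        int64_t i,
        int64_t shard_count,
        ShardingFunction* sharding_function = nullptr) {
    assert(shard_count > 0);
    DefaultShardingFunction default_sharding_function;
    if (sharding_function == nullptr) {
        sharding_function = &default_sharding_function;
    }
    return (*sharding_function)(i, shard_count);
}
```

A virtual functor keeps the entry point callable from C++ directly and from Python through a wrapper subclass such as `PyCallbackShardingFunction`, which forwards `operator()` to a Python callable.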

Reviewed By: asadoughi

Differential Revision: D68534991

fbshipit-source-id: b857e20c6cc4249a2ab7792db4c93dd4fb8403fd
Michael Norris authored and facebook-github-bot committed Feb 7, 2025
1 parent 1d8f393 commit aff6bfc
Showing 8 changed files with 412 additions and 1 deletion.
143 changes: 143 additions & 0 deletions faiss/IVFlib.cpp
Original file line number Diff line number Diff line change
@@ -16,7 +16,9 @@
#include <faiss/IndexPreTransform.h>
#include <faiss/IndexRefine.h>
#include <faiss/MetaIndexes.h>
#include <faiss/clone_index.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/index_io.h>
#include <faiss/utils/distances.h>
#include <faiss/utils/hamming.h>
#include <faiss/utils/utils.h>
@@ -519,5 +521,146 @@ void ivf_residual_add_from_flat_codes(
index->ntotal += nb;
}

int64_t DefaultShardingFunction::operator()(int64_t i, int64_t shard_count) {
return i % shard_count;
}

void handle_ivf(
faiss::IndexIVF* index,
int64_t shard_count,
const std::string& filename_template,
ShardingFunction* sharding_function) {
std::vector<faiss::IndexIVF*> sharded_indexes(shard_count);
auto clone = static_cast<faiss::IndexIVF*>(faiss::clone_index(index));
clone->quantizer->reset();
for (int64_t i = 0; i < shard_count; i++) {
sharded_indexes[i] =
static_cast<faiss::IndexIVF*>(faiss::clone_index(clone));
}

// assign centroids to each sharded Index based on sharding_function, and
// add them to the quantizer of each sharded index
std::vector<std::vector<float>> sharded_centroids(shard_count);
for (int64_t i = 0; i < index->quantizer->ntotal; i++) {
int64_t shard_id = (*sharding_function)(i, shard_count);
float* reconstructed = new float[index->quantizer->d];
index->quantizer->reconstruct(i, reconstructed);
sharded_centroids[shard_id].insert(
sharded_centroids[shard_id].end(),
&reconstructed[0],
&reconstructed[index->quantizer->d]);
delete[] reconstructed;
}
for (int64_t i = 0; i < shard_count; i++) {
sharded_indexes[i]->quantizer->add(
sharded_centroids[i].size() / index->quantizer->d,
sharded_centroids[i].data());
}

for (int64_t i = 0; i < shard_count; i++) {
char fname[256];
snprintf(fname, 256, filename_template.c_str(), i);
faiss::write_index(sharded_indexes[i], fname);
}

for (int64_t i = 0; i < shard_count; i++) {
delete sharded_indexes[i];
}
}

void handle_binary_ivf(
faiss::IndexBinaryIVF* index,
int64_t shard_count,
const std::string& filename_template,
ShardingFunction* sharding_function) {
std::vector<faiss::IndexBinaryIVF*> sharded_indexes(shard_count);

auto clone = static_cast<faiss::IndexBinaryIVF*>(
faiss::clone_binary_index(index));
clone->quantizer->reset();

for (int64_t i = 0; i < shard_count; i++) {
sharded_indexes[i] = static_cast<faiss::IndexBinaryIVF*>(
faiss::clone_binary_index(clone));
}

// assign centroids to each sharded Index based on sharding_function, and
// add them to the quantizer of each sharded index
int64_t reconstruction_size = index->quantizer->d / 8;
std::vector<std::vector<uint8_t>> sharded_centroids(shard_count);
for (int64_t i = 0; i < index->quantizer->ntotal; i++) {
int64_t shard_id = (*sharding_function)(i, shard_count);
uint8_t* reconstructed = new uint8_t[reconstruction_size];
index->quantizer->reconstruct(i, reconstructed);
sharded_centroids[shard_id].insert(
sharded_centroids[shard_id].end(),
&reconstructed[0],
&reconstructed[reconstruction_size]);
delete[] reconstructed;
}
for (int64_t i = 0; i < shard_count; i++) {
sharded_indexes[i]->quantizer->add(
sharded_centroids[i].size() / reconstruction_size,
sharded_centroids[i].data());
}

for (int64_t i = 0; i < shard_count; i++) {
char fname[256];
snprintf(fname, 256, filename_template.c_str(), i);
faiss::write_index_binary(sharded_indexes[i], fname);
}

for (int64_t i = 0; i < shard_count; i++) {
delete sharded_indexes[i];
}
}

template <typename IndexType>
void sharding_helper(
IndexType* index,
int64_t shard_count,
const std::string& filename_template,
ShardingFunction* sharding_function) {
FAISS_THROW_IF_MSG(index->quantizer->ntotal == 0, "No centroids to shard.");
FAISS_THROW_IF_MSG(
filename_template.find("%d") == std::string::npos,
"Invalid filename_template. Must contain format specifier for shard count.");

DefaultShardingFunction default_sharding_function;
if (sharding_function == nullptr) {
sharding_function = &default_sharding_function;
}

if (typeid(IndexType) == typeid(faiss::IndexIVF)) {
handle_ivf(
dynamic_cast<faiss::IndexIVF*>(index),
shard_count,
filename_template,
sharding_function);
} else if (typeid(IndexType) == typeid(faiss::IndexBinaryIVF)) {
handle_binary_ivf(
dynamic_cast<faiss::IndexBinaryIVF*>(index),
shard_count,
filename_template,
sharding_function);
}
}

void shard_ivf_index_centroids(
faiss::IndexIVF* index,
int64_t shard_count,
const std::string& filename_template,
ShardingFunction* sharding_function) {
sharding_helper(index, shard_count, filename_template, sharding_function);
}

void shard_binary_ivf_index_centroids(
faiss::IndexBinaryIVF* index,
int64_t shard_count,
const std::string& filename_template,
ShardingFunction* sharding_function) {
sharding_helper(index, shard_count, filename_template, sharding_function);
}

} // namespace ivflib
} // namespace faiss
38 changes: 38 additions & 0 deletions faiss/IVFlib.h
@@ -14,6 +14,7 @@
* IndexIVFs embedded within an IndexPreTransform.
*/

#include <faiss/IndexBinaryIVF.h>
#include <faiss/IndexIVF.h>
#include <vector>

@@ -167,6 +168,43 @@ void ivf_residual_add_from_flat_codes(
const uint8_t* codes,
int64_t code_size = -1);

struct ShardingFunction {
virtual int64_t operator()(int64_t i, int64_t shard_count) = 0;
virtual ~ShardingFunction() = default;
ShardingFunction() {}
ShardingFunction(const ShardingFunction&) = default;
ShardingFunction(ShardingFunction&&) = default;
ShardingFunction& operator=(const ShardingFunction&) = default;
ShardingFunction& operator=(ShardingFunction&&) = default;
};
struct DefaultShardingFunction : ShardingFunction {
int64_t operator()(int64_t i, int64_t shard_count) override;
};

/**
* Shards the centroids of an IVF index with the given sharding function, and
* writes each shard to the path produced from filename_template. The centroids
* must already be added to the index quantizer.
*
* @param index The IVF index containing centroids to shard.
* @param shard_count Number of shards.
* @param filename_template Template for shard filenames. Must contain a %d
* format specifier, which is replaced by the shard index.
* @param sharding_function The function to shard by. The default is ith vector
* mod shard_count.
*/
void shard_ivf_index_centroids(
IndexIVF* index,
int64_t shard_count = 20,
const std::string& filename_template = "shard.%d.index",
ShardingFunction* sharding_function = nullptr);

void shard_binary_ivf_index_centroids(
faiss::IndexBinaryIVF* index,
int64_t shard_count = 20,
const std::string& filename_template = "shard.%d.index",
ShardingFunction* sharding_function = nullptr);

} // namespace ivflib
} // namespace faiss

34 changes: 34 additions & 0 deletions faiss/clone_index.cpp
@@ -19,6 +19,8 @@
#include <faiss/IndexAdditiveQuantizerFastScan.h>
#include <faiss/IndexBinary.h>
#include <faiss/IndexBinaryFlat.h>
#include <faiss/IndexBinaryHNSW.h>
#include <faiss/IndexBinaryIVF.h>
#include <faiss/IndexFlat.h>
#include <faiss/IndexHNSW.h>
#include <faiss/IndexIVF.h>
@@ -107,6 +109,11 @@ IndexIVF* Cloner::clone_IndexIVF(const IndexIVF* ivf) {
return nullptr;
}

IndexBinaryIVF* clone_IndexBinaryIVF(const IndexBinaryIVF* ivf) {
TRYCLONE(IndexBinaryIVF, ivf)
return nullptr;
}

IndexRefine* clone_IndexRefine(const IndexRefine* ir) {
TRYCLONE(IndexRefineFlat, ir)
TRYCLONE(IndexRefine, ir) {
@@ -131,6 +138,11 @@ IndexHNSW* clone_IndexHNSW(const IndexHNSW* ihnsw) {
}
}

IndexBinaryHNSW* clone_IndexBinaryHNSW(const IndexBinaryHNSW* ihnsw) {
TRYCLONE(IndexBinaryHNSW, ihnsw)
return nullptr;
}

IndexNNDescent* clone_IndexNNDescent(const IndexNNDescent* innd) {
TRYCLONE(IndexNNDescentFlat, innd)
TRYCLONE(IndexNNDescent, innd) {
@@ -385,6 +397,28 @@ Quantizer* clone_Quantizer(const Quantizer* quant) {
IndexBinary* clone_binary_index(const IndexBinary* index) {
if (auto ii = dynamic_cast<const IndexBinaryFlat*>(index)) {
return new IndexBinaryFlat(*ii);
} else if (
const IndexBinaryIVF* ivf =
dynamic_cast<const IndexBinaryIVF*>(index)) {
IndexBinaryIVF* res = clone_IndexBinaryIVF(ivf);
if (ivf->invlists == nullptr) {
res->invlists = nullptr;
} else {
res->invlists = clone_InvertedLists(ivf->invlists);
res->own_invlists = true;
}

res->own_fields = true;
res->quantizer = clone_binary_index(ivf->quantizer);

return res;
} else if (
const IndexBinaryHNSW* ihnsw =
dynamic_cast<const IndexBinaryHNSW*>(index)) {
IndexBinaryHNSW* res = clone_IndexBinaryHNSW(ihnsw);
res->own_fields = true;
res->storage = clone_binary_index(ihnsw->storage);
return res;
} else {
FAISS_THROW_MSG("cannot clone this type of index");
}
3 changes: 2 additions & 1 deletion faiss/python/__init__.py
@@ -53,6 +53,7 @@
class_wrappers.handle_Linear(Linear)
class_wrappers.handle_QINCo(QINCo)
class_wrappers.handle_QINCoStep(QINCoStep)
shard_ivf_index_centroids = class_wrappers.handle_shard_ivf_index_centroids(shard_ivf_index_centroids)


this_module = sys.modules[__name__]
@@ -170,7 +171,7 @@ def replacement_function(*args):
add_ref_in_constructor(GpuIndexIVFPQ, 1)
add_ref_in_constructor(GpuIndexIVFScalarQuantizer, 1)
except NameError as e:
logger.info("Failed to load GPU Faiss: %s. Will not load constructor refs for GPU indexes." % e.args[0])
logger.info("Failed to load GPU Faiss: %s. Will not load constructor refs for GPU indexes. This is only an error if you're trying to use GPU Faiss." % e.args[0])

add_ref_in_constructor(IndexIVFFlat, 0)
add_ref_in_constructor(IndexIVFFlatDedup, 0)
9 changes: 9 additions & 0 deletions faiss/python/class_wrappers.py
@@ -1395,3 +1395,12 @@ def from_torch(self, qinco):

the_class.__init__ = replacement_init
the_class.from_torch = from_torch


def handle_shard_ivf_index_centroids(func):
def wrapper(*args, **kwargs):
args = list(args)
if len(args) > 3 and args[3] is not None:
args[3] = faiss.PyCallbackShardingFunction(args[3])
return func(*args, **kwargs)
return wrapper
24 changes: 24 additions & 0 deletions faiss/python/python_callbacks.cpp
@@ -134,3 +134,27 @@ PyCallbackIDSelector::~PyCallbackIDSelector() {
PyThreadLock gil;
Py_DECREF(callback);
}

/***********************************************************
* Callbacks for IVF index sharding
***********************************************************/

PyCallbackShardingFunction::PyCallbackShardingFunction(PyObject* callback)
: callback(callback) {
PyThreadLock gil;
Py_INCREF(callback);
}

int64_t PyCallbackShardingFunction::operator()(int64_t i, int64_t shard_count) {
PyThreadLock gil;
PyObject* shard_id = PyObject_CallFunction(callback, "LL", i, shard_count);
if (shard_id == nullptr) {
FAISS_THROW_MSG("propagate py error");
}
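// PyObject_CallFunction returns a new reference; release it after reading.
int64_t result = PyLong_AsLongLong(shard_id);
Py_DECREF(shard_id);
return result;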
}

PyCallbackShardingFunction::~PyCallbackShardingFunction() {
PyThreadLock gil;
Py_DECREF(callback);
}
22 changes: 22 additions & 0 deletions faiss/python/python_callbacks.h
@@ -7,6 +7,7 @@

#pragma once

#include <faiss/IVFlib.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/impl/io.h>
#include <faiss/invlists/InvertedLists.h>
@@ -58,3 +59,24 @@ struct PyCallbackIDSelector : faiss::IDSelector {

~PyCallbackIDSelector() override;
};

/***********************************************************
* Callbacks for IVF index sharding
***********************************************************/

struct PyCallbackShardingFunction : faiss::ivflib::ShardingFunction {
PyObject* callback;

explicit PyCallbackShardingFunction(PyObject* callback);

int64_t operator()(int64_t i, int64_t shard_count) override;

~PyCallbackShardingFunction() override;

PyCallbackShardingFunction(const PyCallbackShardingFunction&) = delete;
PyCallbackShardingFunction(PyCallbackShardingFunction&&) noexcept = default;
PyCallbackShardingFunction& operator=(const PyCallbackShardingFunction&) =
default;
PyCallbackShardingFunction& operator=(PyCallbackShardingFunction&&) =
default;
};