diff --git a/faiss/IVFlib.cpp b/faiss/IVFlib.cpp
index 83812f6abe..11900f4b09 100644
--- a/faiss/IVFlib.cpp
+++ b/faiss/IVFlib.cpp
@@ -16,7 +16,9 @@
 #include <faiss/IndexAdditiveQuantizer.h>
 #include <faiss/IndexIVFAdditiveQuantizer.h>
 #include <faiss/IndexPreTransform.h>
+#include <faiss/clone_index.h>
 #include <faiss/impl/FaissAssert.h>
+#include <faiss/index_io.h>
 #include <faiss/utils/distances.h>
 #include <faiss/utils/hamming.h>
 #include <faiss/utils/utils.h>
@@ -519,5 +521,146 @@ void ivf_residual_add_from_flat_codes(
     index->ntotal += nb;
 }
 
+int64_t DefaultShardingFunction::operator()(int64_t i, int64_t shard_count) {
+    return i % shard_count;
+}
+
+void handle_ivf(
+        faiss::IndexIVF* index,
+        int64_t shard_count,
+        const std::string& filename_template,
+        ShardingFunction* sharding_function) {
+    std::vector<faiss::IndexIVF*> sharded_indexes(shard_count);
+    auto clone = static_cast<faiss::IndexIVF*>(faiss::clone_index(index));
+    clone->quantizer->reset();
+    for (int64_t i = 0; i < shard_count; i++) {
+        sharded_indexes[i] =
+                static_cast<faiss::IndexIVF*>(faiss::clone_index(clone));
+    }
+
+    // assign centroids to each sharded Index based on sharding_function, and
+    // add them to the quantizer of each sharded index
+    std::vector<std::vector<float>> sharded_centroids(shard_count);
+    for (int64_t i = 0; i < index->quantizer->ntotal; i++) {
+        int64_t shard_id = (*sharding_function)(i, shard_count);
+        float* reconstructed = new float[index->quantizer->d];
+        index->quantizer->reconstruct(i, reconstructed);
+        sharded_centroids[shard_id].insert(
+                sharded_centroids[shard_id].end(),
+                &reconstructed[0],
+                &reconstructed[index->quantizer->d]);
+        delete[] reconstructed;
+    }
+    for (int64_t i = 0; i < shard_count; i++) {
+        sharded_indexes[i]->quantizer->add(
+                sharded_centroids[i].size() / index->quantizer->d,
+                sharded_centroids[i].data());
+    }
+
+    for (int64_t i = 0; i < shard_count; i++) {
+        char fname[256];
+        snprintf(fname, 256, filename_template.c_str(), i);
+        faiss::write_index(sharded_indexes[i], fname);
+    }
+
+    for (int64_t i = 0; i < shard_count; i++) {
+        delete sharded_indexes[i];
+    }
+}
+
+void handle_binary_ivf(
+        faiss::IndexBinaryIVF* index,
+        int64_t shard_count,
+        const std::string& filename_template,
+        ShardingFunction* sharding_function) {
+    std::vector<faiss::IndexBinaryIVF*> sharded_indexes(shard_count);
+
+    auto clone = static_cast<faiss::IndexBinaryIVF*>(
+            faiss::clone_binary_index(index));
+    clone->quantizer->reset();
+
+    for (int64_t i = 0; i < shard_count; i++) {
+        sharded_indexes[i] = static_cast<faiss::IndexBinaryIVF*>(
+                faiss::clone_binary_index(clone));
+    }
+
+    // assign centroids to each sharded Index based on sharding_function, and
+    // add them to the quantizer of each sharded index
+    int64_t reconstruction_size = index->quantizer->d / 8;
+    std::vector<std::vector<uint8_t>> sharded_centroids(shard_count);
+    for (int64_t i = 0; i < index->quantizer->ntotal; i++) {
+        int64_t shard_id = (*sharding_function)(i, shard_count);
+        uint8_t* reconstructed = new uint8_t[reconstruction_size];
+        index->quantizer->reconstruct(i, reconstructed);
+        sharded_centroids[shard_id].insert(
+                sharded_centroids[shard_id].end(),
+                &reconstructed[0],
+                &reconstructed[reconstruction_size]);
+        delete[] reconstructed;
+    }
+    for (int64_t i = 0; i < shard_count; i++) {
+        sharded_indexes[i]->quantizer->add(
+                sharded_centroids[i].size() / reconstruction_size,
+                sharded_centroids[i].data());
+    }
+
+    for (int64_t i = 0; i < shard_count; i++) {
+        char fname[256];
+        snprintf(fname, 256, filename_template.c_str(), i);
+        faiss::write_index_binary(sharded_indexes[i], fname);
+    }
+
+    for (int64_t i = 0; i < shard_count; i++) {
+        delete sharded_indexes[i];
+    }
+}
+
+template <typename IndexType>
+void sharding_helper(
+        IndexType* index,
+        int64_t shard_count,
+        const std::string& filename_template,
+        ShardingFunction* sharding_function) {
+    FAISS_THROW_IF_MSG(index->quantizer->ntotal == 0, "No centroids to shard.");
+    FAISS_THROW_IF_MSG(
+            filename_template.find("%d") == std::string::npos,
+            "Invalid filename_template. Must contain format specifier for shard count.");
+
+    DefaultShardingFunction default_sharding_function;
+    if (sharding_function == nullptr) {
+        sharding_function = &default_sharding_function;
+    }
+
+    if (typeid(IndexType) == typeid(faiss::IndexIVF)) {
+        handle_ivf(
+                dynamic_cast<faiss::IndexIVF*>(index),
+                shard_count,
+                filename_template,
+                sharding_function);
+    } else if (typeid(IndexType) == typeid(faiss::IndexBinaryIVF)) {
+        handle_binary_ivf(
+                dynamic_cast<faiss::IndexBinaryIVF*>(index),
+                shard_count,
+                filename_template,
+                sharding_function);
+    }
+}
+
+void shard_ivf_index_centroids(
+        faiss::IndexIVF* index,
+        int64_t shard_count,
+        const std::string& filename_template,
+        ShardingFunction* sharding_function) {
+    sharding_helper(index, shard_count, filename_template, sharding_function);
+}
+
+void shard_binary_ivf_index_centroids(
+        faiss::IndexBinaryIVF* index,
+        int64_t shard_count,
+        const std::string& filename_template,
+        ShardingFunction* sharding_function) {
+    sharding_helper(index, shard_count, filename_template, sharding_function);
+}
+
 } // namespace ivflib
 } // namespace faiss
diff --git a/faiss/IVFlib.h b/faiss/IVFlib.h
index 6f6a590c72..8a83ba515e 100644
--- a/faiss/IVFlib.h
+++ b/faiss/IVFlib.h
@@ -14,6 +14,7 @@
  * IndexIVFs embedded within an IndexPreTransform.
  */
 
+#include <faiss/IndexBinaryIVF.h>
 #include <faiss/IndexIVF.h>
 #include <vector>
 
@@ -167,6 +168,43 @@ void ivf_residual_add_from_flat_codes(
         const uint8_t* codes,
         int64_t code_size = -1);
 
+struct ShardingFunction {
+    virtual int64_t operator()(int64_t i, int64_t shard_count) = 0;
+    virtual ~ShardingFunction() = default;
+    ShardingFunction() {}
+    ShardingFunction(const ShardingFunction&) = default;
+    ShardingFunction(ShardingFunction&&) = default;
+    ShardingFunction& operator=(const ShardingFunction&) = default;
+    ShardingFunction& operator=(ShardingFunction&&) = default;
+};
+struct DefaultShardingFunction : ShardingFunction {
+    int64_t operator()(int64_t i, int64_t shard_count) override;
+};
+
+/**
+ * Shards an IVF index centroids by the given sharding function, and writes
+ * the index to the path given by filename_generator. The centroids must already
+ * be added to the index quantizer.
+ *
+ * @param index The IVF index containing centroids to shard.
+ * @param shard_count Number of shards.
+ * @param filename_template Template for shard filenames.
+ * @param sharding_function The function to shard by. The default is ith vector
+ * mod shard_count.
+ * @return The number of shards written.
+ */
+void shard_ivf_index_centroids(
+        IndexIVF* index,
+        int64_t shard_count = 20,
+        const std::string& filename_template = "shard.%d.index",
+        ShardingFunction* sharding_function = nullptr);
+
+void shard_binary_ivf_index_centroids(
+        faiss::IndexBinaryIVF* index,
+        int64_t shard_count = 20,
+        const std::string& filename_template = "shard.%d.index",
+        ShardingFunction* sharding_function = nullptr);
+
 } // namespace ivflib
 } // namespace faiss
diff --git a/faiss/clone_index.cpp b/faiss/clone_index.cpp
index 7174cd6ae0..bc08283740 100644
--- a/faiss/clone_index.cpp
+++ b/faiss/clone_index.cpp
@@ -19,6 +19,8 @@
 #include <faiss/Index2Layer.h>
 #include <faiss/IndexAdditiveQuantizer.h>
 #include <faiss/IndexBinaryFlat.h>
+#include <faiss/IndexBinaryHNSW.h>
+#include <faiss/IndexBinaryIVF.h>
 #include <faiss/IndexFlat.h>
 #include <faiss/IndexHNSW.h>
 #include <faiss/IndexIVF.h>
@@ -107,6 +109,11 @@ IndexIVF* Cloner::clone_IndexIVF(const IndexIVF* ivf) {
     return nullptr;
 }
 
+IndexBinaryIVF* clone_IndexBinaryIVF(const IndexBinaryIVF* ivf) {
+    TRYCLONE(IndexBinaryIVF, ivf)
+    return nullptr;
+}
+
 IndexRefine* clone_IndexRefine(const IndexRefine* ir) {
     TRYCLONE(IndexRefineFlat, ir)
     TRYCLONE(IndexRefine, ir) {
@@ -131,6 +138,11 @@ IndexHNSW* clone_IndexHNSW(const IndexHNSW* ihnsw) {
     }
 }
 
+IndexBinaryHNSW* clone_IndexBinaryHNSW(const IndexBinaryHNSW* ihnsw) {
+    TRYCLONE(IndexBinaryHNSW, ihnsw)
+    return nullptr;
+}
+
 IndexNNDescent* clone_IndexNNDescent(const IndexNNDescent* innd) {
     TRYCLONE(IndexNNDescentFlat, innd)
     TRYCLONE(IndexNNDescent, innd) {
@@ -385,6 +397,28 @@ Quantizer* clone_Quantizer(const Quantizer* quant) {
 IndexBinary* clone_binary_index(const IndexBinary* index) {
     if (auto ii = dynamic_cast<const IndexBinaryFlat*>(index)) {
         return new IndexBinaryFlat(*ii);
+    } else if (
+            const IndexBinaryIVF* ivf =
+                    dynamic_cast<const IndexBinaryIVF*>(index)) {
+        IndexBinaryIVF* res = clone_IndexBinaryIVF(ivf);
+        if (ivf->invlists == nullptr) {
+            res->invlists = nullptr;
+        } else {
+            res->invlists = clone_InvertedLists(ivf->invlists);
+            res->own_invlists = true;
+        }
+
+        res->own_fields = true;
+        res->quantizer = clone_binary_index(ivf->quantizer);
+
+        return res;
+    } else if (
+            const IndexBinaryHNSW* ihnsw =
+                    dynamic_cast<const IndexBinaryHNSW*>(index)) {
+        IndexBinaryHNSW* res = clone_IndexBinaryHNSW(ihnsw);
+        res->own_fields = true;
+        res->storage = clone_binary_index(ihnsw->storage);
+        return res;
     } else {
         FAISS_THROW_MSG("cannot clone this type of index");
     }
diff --git a/faiss/python/__init__.py b/faiss/python/__init__.py
index 9d956ebe71..7266da71f3 100644
--- a/faiss/python/__init__.py
+++ b/faiss/python/__init__.py
@@ -53,6 +53,7 @@
 class_wrappers.handle_Linear(Linear)
 class_wrappers.handle_QINCo(QINCo)
 class_wrappers.handle_QINCoStep(QINCoStep)
+shard_ivf_index_centroids = class_wrappers.handle_shard_ivf_index_centroids(shard_ivf_index_centroids)
 
 this_module = sys.modules[__name__]
 
@@ -170,7 +171,7 @@ def replacement_function(*args):
     add_ref_in_constructor(GpuIndexIVFPQ, 1)
     add_ref_in_constructor(GpuIndexIVFScalarQuantizer, 1)
 except NameError as e:
-    logger.info("Failed to load GPU Faiss: %s. Will not load constructor refs for GPU indexes." % e.args[0])
+    logger.info("Failed to load GPU Faiss: %s. Will not load constructor refs for GPU indexes. This is only an error if you're trying to use GPU Faiss." % e.args[0])
 
 add_ref_in_constructor(IndexIVFFlat, 0)
 add_ref_in_constructor(IndexIVFFlatDedup, 0)
diff --git a/faiss/python/class_wrappers.py b/faiss/python/class_wrappers.py
index 607fdd6d29..46f8b0195f 100644
--- a/faiss/python/class_wrappers.py
+++ b/faiss/python/class_wrappers.py
@@ -1395,3 +1395,12 @@ def from_torch(self, qinco):
 
     the_class.__init__ = replacement_init
     the_class.from_torch = from_torch
+
+
+def handle_shard_ivf_index_centroids(func):
+    def wrapper(*args, **kwargs):
+        args = list(args)
+        if len(args) > 3 and args[3] is not None:
+            args[3] = faiss.PyCallbackShardingFunction(args[3])
+        return func(*args, **kwargs)
+    return wrapper
diff --git a/faiss/python/python_callbacks.cpp b/faiss/python/python_callbacks.cpp
index ce36bed437..8b78bf1e43 100644
--- a/faiss/python/python_callbacks.cpp
+++ b/faiss/python/python_callbacks.cpp
@@ -134,3 +134,27 @@ PyCallbackIDSelector::~PyCallbackIDSelector() {
     PyThreadLock gil;
     Py_DECREF(callback);
 }
+
+/***********************************************************
+ * Callbacks for IVF index sharding
+ ***********************************************************/
+
+PyCallbackShardingFunction::PyCallbackShardingFunction(PyObject* callback)
+        : callback(callback) {
+    PyThreadLock gil;
+    Py_INCREF(callback);
+}
+
+int64_t PyCallbackShardingFunction::operator()(int64_t i, int64_t shard_count) {
+    PyThreadLock gil;
+    PyObject* shard_id = PyObject_CallFunction(callback, "LL", i, shard_count);
+    if (shard_id == nullptr) {
+        FAISS_THROW_MSG("propagate py error");
+    }
+    return PyLong_AsLongLong(shard_id);
+}
+
+PyCallbackShardingFunction::~PyCallbackShardingFunction() {
+    PyThreadLock gil;
+    Py_DECREF(callback);
+}
diff --git a/faiss/python/python_callbacks.h b/faiss/python/python_callbacks.h
index fa8ebaf53c..072e69f91f 100644
--- a/faiss/python/python_callbacks.h
+++ b/faiss/python/python_callbacks.h
@@ -7,6 +7,7 @@
 
 #pragma once
 
+#include <faiss/IVFlib.h>
 #include <faiss/impl/IDSelector.h>
 #include <faiss/impl/io.h>
 #include <Python.h>
@@ -58,3 +59,24 @@ struct PyCallbackIDSelector : faiss::IDSelector {
 
     ~PyCallbackIDSelector() override;
 };
+
+/***********************************************************
+ * Callbacks for IVF index sharding
+ ***********************************************************/
+
+struct PyCallbackShardingFunction : faiss::ivflib::ShardingFunction {
+    PyObject* callback;
+
+    explicit PyCallbackShardingFunction(PyObject* callback);
+
+    int64_t operator()(int64_t i, int64_t shard_count) override;
+
+    ~PyCallbackShardingFunction() override;
+
+    PyCallbackShardingFunction(const PyCallbackShardingFunction&) = delete;
+    PyCallbackShardingFunction(PyCallbackShardingFunction&&) noexcept = default;
+    PyCallbackShardingFunction& operator=(const PyCallbackShardingFunction&) =
+            default;
+    PyCallbackShardingFunction& operator=(PyCallbackShardingFunction&&) =
+            default;
+};
diff --git a/tests/test_ivflib.py b/tests/test_ivflib.py
index d905f3d486..4121304689 100644
--- a/tests/test_ivflib.py
+++ b/tests/test_ivflib.py
@@ -8,6 +8,9 @@
 import unittest
 import faiss
 import numpy as np
+import os
+import random
+
 
 class TestIVFlib(unittest.TestCase):
 
@@ -180,3 +183,140 @@ def test_small_data(self):
         assert np.all(lims == ref_lims)
         assert np.all(D == ref_D)
         assert np.all(I == ref_I)
+
+
+class TestIvfSharding(unittest.TestCase):
+    d = 32
+    nlist = 100
+    nb = 1000
+
+    def custom_sharding_function(self, i, _):
+        return 1 if i % 2 == 0 else 7
+
+    # Mimics the default in DefaultShardingFunction.
+    # This impl is just used for verification.
+    def default_sharding_function(self, i, shard_count):
+        return i % shard_count
+
+    def verify_sharded_ivf_indexes(
+            self, template, xb, shard_count, sharding_function):
+        sharded_indexes_counters = [0] * shard_count
+        sharded_indexes = []
+        for i in range(shard_count):
+            if xb[0].dtype.name == 'uint8':
+                index = faiss.read_index_binary(template % i)
+            else:
+                index = faiss.read_index(template % i)
+            sharded_indexes.append(index)
+        # Reconstruct and verify each centroid
+        nb = len(xb)
+        for i in range(nb):
+            shard_id = sharding_function(i, shard_count)
+            reconstructed = sharded_indexes[shard_id].quantizer.reconstruct(
+                sharded_indexes_counters[shard_id])
+            sharded_indexes_counters[shard_id] += 1
+            print(f"reconstructed: {reconstructed} xb[i]: {xb[i]}")
+            np.testing.assert_array_equal(reconstructed, xb[i])
+        # Clean up
+        for i in range(shard_count):
+            os.remove(template % i)
+
+    def test_save_index_shards_by_centroids_no_op(self):
+        quantizer = faiss.IndexFlatL2(self.d)
+        index = faiss.IndexIVFFlat(quantizer, self.d, self.nlist)
+        with self.assertRaises(RuntimeError):
+            faiss.shard_ivf_index_centroids(
+                index,
+                10,
+                "shard.%d.index",
+                None
+            )
+
+    def test_save_index_shards_by_centroids_flat_quantizer_default_sharding(
+            self):
+        xb = np.random.rand(self.nb, self.d).astype('float32')
+        quantizer = faiss.IndexFlatL2(self.d)
+        index = faiss.IndexIVFFlat(quantizer, self.d, self.nlist)
+        shard_count = 3
+
+        index.quantizer.add(xb)
+
+        template = str(random.randint(0, 100000)) + "shard.%d.index"
+        faiss.shard_ivf_index_centroids(
+            index,
+            shard_count,
+            template
+        )
+        self.verify_sharded_ivf_indexes(
+            template, xb, shard_count, self.default_sharding_function)
+
+    def test_save_index_shards_by_centroids_flat_quantizer_custom_sharding(
+            self):
+        xb = np.random.rand(self.nb, self.d).astype('float32')
+        quantizer = faiss.IndexFlatL2(self.d)
+        index = faiss.IndexIVFFlat(quantizer, self.d, self.nlist)
+        shard_count = 20
+
+        index.quantizer.add(xb)
+
+        template = str(random.randint(0, 100000)) + "shard.%d.index"
+        faiss.shard_ivf_index_centroids(
+            index,
+            shard_count,
+            template,
+            self.custom_sharding_function
+        )
+        self.verify_sharded_ivf_indexes(
+            template, xb, shard_count, self.custom_sharding_function)
+
+    def test_save_index_shards_by_centroids_hnsw_quantizer(self):
+        xb = np.random.rand(self.nb, self.d).astype('float32')
+        quantizer = faiss.IndexHNSWFlat(self.d, 32)
+        index = faiss.IndexIVFFlat(quantizer, self.d, self.nlist)
+        shard_count = 17
+
+        index.quantizer.add(xb)
+
+        template = str(random.randint(0, 100000)) + "shard.%d.index"
+        faiss.shard_ivf_index_centroids(
+            index,
+            shard_count,
+            template,
+            None
+        )
+        self.verify_sharded_ivf_indexes(
+            template, xb, shard_count, self.default_sharding_function)
+
+    def test_save_index_shards_by_centroids_binary_flat_quantizer(self):
+        xb = np.random.randint(256, size=(self.nb, int(self.d / 8))).astype('uint8')
+        quantizer = faiss.IndexBinaryFlat(self.d)
+        index = faiss.IndexBinaryIVF(quantizer, self.d, self.nlist)
+        shard_count = 11
+
+        index.quantizer.add(xb)
+
+        template = str(random.randint(0, 100000)) + "shard.%d.index"
+        faiss.shard_binary_ivf_index_centroids(
+            index,
+            shard_count,
+            template
+        )
+        self.verify_sharded_ivf_indexes(
+            template, xb, shard_count, self.default_sharding_function)
+
+    def test_save_index_shards_by_centroids_binary_hnsw_quantizer(self):
+        xb = np.random.randint(256, size=(self.nb, int(self.d / 8))).astype('uint8')
+        quantizer = faiss.IndexBinaryHNSW(self.d, 32)
+        index = faiss.IndexBinaryIVF(quantizer, self.d, self.nlist)
+        shard_count = 13
+
+        index.quantizer.add(xb)
+
+        template = str(random.randint(0, 100000)) + "shard.%d.index"
+        faiss.shard_binary_ivf_index_centroids(
+            index,
+            shard_count,
+            template
+        )
+        self.verify_sharded_ivf_indexes(
+            template, xb, shard_count, self.default_sharding_function)