diff --git a/apis/python/test/test_backwards_compatibility.py b/apis/python/test/test_backwards_compatibility.py index ade1cb6cc..9b8f3516d 100644 --- a/apis/python/test/test_backwards_compatibility.py +++ b/apis/python/test/test_backwards_compatibility.py @@ -2,6 +2,7 @@ from tiledb.vector_search.flat_index import FlatIndex from tiledb.vector_search.ivf_flat_index import IVFFlatIndex +from tiledb.vector_search.ivf_pq_index import IVFPQIndex from tiledb.vector_search.utils import load_fvecs from tiledb.vector_search.vamana_index import VamanaIndex @@ -63,6 +64,8 @@ def test_query_old_indices(): index = FlatIndex(uri=index_uri) elif "vamana" in index_name: index = VamanaIndex(uri=index_uri) + elif "ivf_pq" in index_name: + index = IVFPQIndex(uri=index_uri) else: assert False, f"Unknown index name: {index_name}" diff --git a/backwards-compatibility-data/generate_data.py b/backwards-compatibility-data/generate_data.py index 726fad58f..b6e286c2a 100644 --- a/backwards-compatibility-data/generate_data.py +++ b/backwards-compatibility-data/generate_data.py @@ -66,7 +66,7 @@ def generate_indexes(version): queries = base[indices] # Generate each index and query to make sure it works before we write it. - index_types = ["FLAT", "IVF_FLAT", "VAMANA"] + index_types = ["FLAT", "IVF_FLAT", "VAMANA", "IVF_PQ"] data_types = ["float32", "uint8"] for index_type in index_types: for data_type in data_types: @@ -75,6 +75,7 @@ def generate_indexes(version): index_type=index_type, index_uri=index_uri, input_vectors=base.astype(data_type), + num_subspaces=len(base[0]), ) result_d, result_i = index.query(queries, k=1) diff --git a/src/include/test/unit_backwards_compatibility.cc b/src/include/test/unit_backwards_compatibility.cc index 8a758af0e..4fa3c3b1f 100644 --- a/src/include/test/unit_backwards_compatibility.cc +++ b/src/include/test/unit_backwards_compatibility.cc @@ -35,9 +35,11 @@ #include #include "api/feature_vector_array.h" #include "api/ivf_flat_index.h" +#include "api/ivf_pq_index.h" #include "api/vamana_index.h" #include "detail/linalg/matrix.h" #include "index/ivf_flat_index.h" +#include "index/ivf_pq_index.h" #include "index/vamana_index.h" #include "mdspan/mdspan.hpp" #include "test/utils/array_defs.h" @@ -135,6 +137,27 @@ TEST_CASE("test_query_old_indices", "[backwards_compatibility]") { // Next check that we can load the metadata. auto metadata = vamana_index_metadata(); metadata.load_metadata(read_group); + } else if (index_uri.find("ivf_pq") != std::string::npos) { + // First check that we can query the index. + auto index = IndexIVFPQ(ctx, index_uri); + auto&& [scores, ids] = index.query(queries_feature_vector_array, 1, 10); + auto scores_span = + MatrixView{ + (siftsmall_feature_type*)scores.data(), + extents(scores)[0], + extents(scores)[1]}; + + auto ids_span = MatrixView{ + (siftsmall_ids_type*)ids.data(), extents(ids)[0], extents(ids)[1]}; + + for (size_t i = 0; i < query_indices.size(); ++i) { + CHECK(ids_span[0][i] == query_indices[i]); + CHECK(scores_span[0][i] == 0); + } + + // Next check that we can load the metadata. + auto metadata = ivf_pq_metadata(); + metadata.load_metadata(read_group); } else { REQUIRE(false); }