diff --git a/apis/python/src/tiledb/vector_search/embeddings/object_embedding.py b/apis/python/src/tiledb/vector_search/embeddings/object_embedding.py
index f29b711ca..d19be70ae 100644
--- a/apis/python/src/tiledb/vector_search/embeddings/object_embedding.py
+++ b/apis/python/src/tiledb/vector_search/embeddings/object_embedding.py
@@ -1,6 +1,6 @@
 from abc import ABC
 from abc import abstractmethod
-from typing import Dict, OrderedDict
+from typing import Dict, OrderedDict, Tuple, Union

 import numpy as np

@@ -43,10 +43,12 @@ def load(self) -> None:
         raise NotImplementedError

     @abstractmethod
-    def embed(self, objects: OrderedDict, metadata: OrderedDict) -> np.ndarray:
+    def embed(
+        self, objects: OrderedDict, metadata: OrderedDict
+    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
         """
         Creates embedding vectors for objects. Returns a numpy array of embedding vectors.
-        There is no enforced restriction on the object format. ObjectReaders and ObjectEmbeddings should use comatible object and metadata formats.
+        There is no enforced restriction on the object format. ObjectReaders and ObjectEmbeddings should use compatible object and metadata formats.

         Parameters
         ----------
diff --git a/apis/python/src/tiledb/vector_search/object_api/embeddings_ingestion.py b/apis/python/src/tiledb/vector_search/object_api/embeddings_ingestion.py
index 9491e2b1d..4964cca3b 100644
--- a/apis/python/src/tiledb/vector_search/object_api/embeddings_ingestion.py
+++ b/apis/python/src/tiledb/vector_search/object_api/embeddings_ingestion.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, List, Mapping, Optional, Tuple

 from tiledb.cloud.dag import Mode

@@ -298,7 +298,11 @@ def instantiate_object(code, class_name, **kwargs):
                 logger.debug("Embedding objects...")
                 embeddings = object_embedding.embed(objects, metadata)
-
+                if isinstance(embeddings, Tuple):
+                    external_ids = embeddings[1]
+                    embeddings = embeddings[0]
+                else:
+                    external_ids = objects["external_id"].astype(np.uint64)
                 logger.debug("Write embeddings partition_id: %d", partition_id)
                 if use_updates_array:
                     vectors = np.empty(embeddings.shape[0], dtype="O")
                     for i in range(embeddings.shape[0]):
@@ -306,7 +310,7 @@ def instantiate_object(code, class_name, **kwargs):
                         vectors[i] = embeddings[i].astype(vector_type)
                     obj_index.index.update_batch(
                         vectors=vectors,
-                        external_ids=objects["external_id"].astype(np.uint64),
+                        external_ids=external_ids.astype(np.uint64),
                     )
                 else:
                     embeddings_flattened = np.empty(1, dtype="O")
@@ -317,16 +321,16 @@ def instantiate_object(code, class_name, **kwargs):
                     embeddings_shape[0] = np.array(
                         embeddings.shape, dtype=np.uint32
                     )
-                    external_ids = np.empty(1, dtype="O")
-                    external_ids[0] = objects["external_id"].astype(np.uint64)
+                    write_external_ids = np.empty(1, dtype="O")
+                    write_external_ids[0] = external_ids.astype(np.uint64)
                     embeddings_array[partition_id] = {
                         "vectors": embeddings_flattened,
                         "vectors_shape": embeddings_shape,
-                        "external_ids": external_ids,
+                        "external_ids": write_external_ids,
                     }
                 if metadata_array_uri is not None:
-                    external_ids = metadata.pop("external_id", None)
-                    metadata_array[external_ids] = metadata
+                    metadata_external_ids = metadata.pop("external_id", None)
+                    metadata_array[metadata_external_ids] = metadata

         if not use_updates_array:
             embeddings_array.close()
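Reviewer note: a minimal sketch (class name and import path are assumptions, not part of this diff) of an embedding that opts into the new tuple contract, returning one external_id per emitted vector so the ingestion path above can map every vector back to its source object:

```python
from typing import Dict, OrderedDict, Tuple

import numpy as np

from tiledb.vector_search.embeddings import ObjectEmbedding  # assumed import path


class ChunkEmbedding(ObjectEmbedding):
    """Hypothetical embedding emitting 3 vectors per object."""

    def init_kwargs(self) -> Dict:
        return {}

    def dimensions(self) -> int:
        return 4

    def vector_type(self) -> np.dtype:
        return np.float32

    def load(self) -> None:
        pass

    def embed(
        self, objects: OrderedDict, metadata: OrderedDict
    ) -> Tuple[np.ndarray, np.ndarray]:
        chunks = 3
        n = len(objects["object"]) * chunks
        embeddings = np.zeros((n, self.dimensions()), dtype=self.vector_type())
        # Repeat each object's external_id once per emitted vector.
        external_ids = np.repeat(
            np.asarray(metadata["external_id"], dtype=np.uint64), chunks
        )
        return embeddings, external_ids
```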
diff --git a/apis/python/src/tiledb/vector_search/object_api/object_index.py b/apis/python/src/tiledb/vector_search/object_api/object_index.py
index dbcfa1531..4a3ec1d15 100644
--- a/apis/python/src/tiledb/vector_search/object_api/object_index.py
+++ b/apis/python/src/tiledb/vector_search/object_api/object_index.py
@@ -1,8 +1,9 @@
 import json
+import operator
 import random
 import string
 from collections import OrderedDict
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple

 import numpy as np

@@ -13,6 +14,7 @@
 from tiledb.vector_search import IVFFlatIndex
 from tiledb.vector_search import IVFPQIndex
 from tiledb.vector_search import VamanaIndex
+from tiledb.vector_search import _tiledbvspy as vspy
 from tiledb.vector_search import flat_index
 from tiledb.vector_search import ivf_flat_index
 from tiledb.vector_search import ivf_pq_index
@@ -21,6 +23,8 @@
 from tiledb.vector_search.object_readers import ObjectReader
 from tiledb.vector_search.storage_formats import STORAGE_VERSION
 from tiledb.vector_search.storage_formats import storage_formats
+from tiledb.vector_search.utils import MAX_FLOAT32
+from tiledb.vector_search.utils import MAX_UINT64
 from tiledb.vector_search.utils import add_to_group

 TILEDB_CLOUD_PROTOCOL = 4
@@ -288,6 +292,10 @@ def query(
         driver_resources: Optional[Mapping[str, Any]] = None,
         extra_driver_modules: Optional[List[str]] = None,
         driver_access_credentials_name: Optional[str] = None,
+        merge_results_result_pos_as_score: bool = True,
+        merge_results_reverse_dist: Optional[bool] = None,
+        merge_results_per_query_embedding_group_fn: Callable = max,
+        merge_results_per_query_group_fn: Callable = operator.add,
         **kwargs,
     ):
         """
@@ -360,6 +368,19 @@ def query(
             A list of extra Python modules to install on the driver node.
         driver_access_credentials_name: Optional[str]
             If `driver_mode` was not `None`, the access credentials name to use for the driver execution.
+        merge_results_result_pos_as_score: bool
+            Applies only when there are multiple query embeddings per query.
+            If True, each result score is based on the position of the result for the query embedding.
+        merge_results_reverse_dist: Optional[bool]
+            Applies only when there are multiple query embeddings per query.
+            If True, the distances are reversed based on their reciprocal, (1 / dist).
+        merge_results_per_query_embedding_group_fn: Callable
+            Applies only when there are multiple query embeddings per query.
+            Group function used to group together object scores per query embedding (e.g. max, min).
+        merge_results_per_query_group_fn: Callable
+            Applies only when there are multiple query embeddings per query.
+            Group function used to group together object scores per query (e.g. operator.add). This is applied
+            after `merge_results_per_query_embedding_group_fn`.
         **kwargs
             Keyword arguments to pass to the index query method.
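Reviewer sketch of the new `query()` kwargs (hedged; assumes an ObjectIndex `index` whose embedding emits several vectors per query object, e.g. the chunked class sketched earlier; the defaults are spelled out explicitly):

```python
import operator

# Rank-based scoring: each hit contributes 1 - position / k for its query
# embedding; per-object scores are max-reduced per embedding, then summed
# across the query's embeddings.
distances, object_ids = index.query(
    query_objects={"object": ["a query"]},
    k=10,
    merge_results_result_pos_as_score=True,
    merge_results_per_query_embedding_group_fn=max,
    merge_results_per_query_group_fn=operator.add,
)
```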
@@ -412,9 +433,21 @@ def query(
         if not self.embedding_loaded:
             self.embedding.load()
             self.embedding_loaded = True
+
+        num_queries = len(query_objects[list(query_objects.keys())[0]])
+        if query_metadata is None:
+            query_metadata = {}
+        if "external_id" not in query_metadata:
+            query_metadata["external_id"] = np.arange(num_queries).astype(np.uint64)
         query_embeddings = self.embedding.embed(
             objects=query_objects, metadata=query_metadata
         )
+        if isinstance(query_embeddings, Tuple):
+            query_ids = query_embeddings[1].astype(np.uint64)
+            query_embeddings = query_embeddings[0]
+        else:
+            query_ids = query_metadata["external_id"].astype(np.uint64)
+
         fetch_k = k
         if metadata_array_cond is not None or metadata_df_filter_fn is not None:
             fetch_k = min(50 * k, self.index.size)
@@ -422,6 +455,28 @@ def query(
         distances, object_ids = self.index.query(
             queries=query_embeddings, k=fetch_k, **kwargs
         )
+
+        # Post-process query results for multiple embeddings per query object
+        if merge_results_reverse_dist is None:
+            merge_results_reverse_dist = (
+                False
+                if self.index.distance_metric == vspy.DistanceMetric.INNER_PRODUCT
+                else True
+            )
+
+        if query_embeddings.shape[0] > num_queries:
+            distances, object_ids = self._merge_results_per_query(
+                distances=distances,
+                object_ids=object_ids,
+                query_ids=query_ids,
+                num_queries=num_queries,
+                k=fetch_k,
+                result_pos_as_score=merge_results_result_pos_as_score,
+                reverse_dist=merge_results_reverse_dist,
+                per_query_embedding_group_fn=merge_results_per_query_embedding_group_fn,
+                per_query_group_fn=merge_results_per_query_group_fn,
+            )
+
         unique_ids, idx = np.unique(object_ids, return_inverse=True)
         idx = np.reshape(idx, object_ids.shape)
         if metadata_array_cond is not None or metadata_df_filter_fn is not None:
@@ -448,13 +503,9 @@ def query(
             filtered_unique_ids = unique_ids_metadata_df[
                 self.object_metadata_external_id_dim
             ].to_numpy()
-            filtered_distances = np.zeros((query_embeddings.shape[0], k)).astype(
-                object_ids.dtype
-            )
-            filtered_object_ids = np.zeros((query_embeddings.shape[0], k)).astype(
-                object_ids.dtype
-            )
-            for query_id in range(query_embeddings.shape[0]):
+            filtered_distances = np.zeros((num_queries, k)).astype(object_ids.dtype)
+            filtered_object_ids = np.zeros((num_queries, k)).astype(object_ids.dtype)
+            for query_id in range(num_queries):
                 write_id = 0
                 for result_id in range(fetch_k):
                     if object_ids[query_id, result_id] in filtered_unique_ids:
@@ -507,6 +558,82 @@ def query(
         elif not return_objects and not return_metadata:
             return distances, object_ids

+    def _merge_results_per_query(
+        self,
+        distances,
+        object_ids,
+        query_ids,
+        num_queries,
+        k,
+        result_pos_as_score=True,
+        reverse_dist=True,
+        per_query_embedding_group_fn=max,
+        per_query_group_fn=operator.add,
+    ):
+        """
+        Post-process query results for multiple embeddings per query object.
+        - Computes a score per original result:
+          - If `result_pos_as_score=True`, the score is based on the position of the result for the query embedding.
+          - Else, the distance is used as the score.
+          - If `reverse_dist=True`, the reciprocal of the distance (1 / distance) is used as the score.
+        - Applies `per_query_embedding_group_fn` to group object scores per query embedding.
+        - Applies `per_query_group_fn` to group object scores per query.
+        """
+
+        def get_reciprocal(dist):
+            if dist == 0:
+                return MAX_FLOAT32
+            return 1 / dist
+
+        # Apply `per_query_embedding_group_fn` for each query embedding
+        q_emb_to_obj_score = []
+        for q_emb_id in range(distances.shape[0]):
+            q_emb_score = {}
+            for result_id in range(distances.shape[1]):
+                obj_id = object_ids[q_emb_id][result_id]
+                if result_pos_as_score:
+                    score = 1 - result_id / k
+                else:
+                    score = distances[q_emb_id][result_id]
+                    if reverse_dist:
+                        score = get_reciprocal(score)
+                if obj_id not in q_emb_score:
+                    q_emb_score[obj_id] = score
+                else:
+                    q_emb_score[obj_id] = per_query_embedding_group_fn(
+                        q_emb_score[obj_id], score
+                    )
+            q_emb_to_obj_score.append(q_emb_score)
+
+        # Apply `per_query_group_fn` for each query
+        q_to_obj_score = []
+        for q_id in range(num_queries):
+            q_to_obj_score.append({})
+
+        for q_emb_id in range(distances.shape[0]):
+            q_id = query_ids[q_emb_id]
+            for obj_id, score in q_emb_to_obj_score[q_emb_id].items():
+                if obj_id not in q_to_obj_score[q_id]:
+                    q_to_obj_score[q_id][obj_id] = score
+                else:
+                    q_to_obj_score[q_id][obj_id] = per_query_group_fn(
+                        q_to_obj_score[q_id][obj_id], score
+                    )
+
+        merged_distances = np.zeros((num_queries, k), dtype=np.float32)
+        merged_object_ids = MAX_UINT64 * np.ones((num_queries, k), dtype=np.uint64)
+        for q_id in range(num_queries):
+            pos_id = 0
+            for obj_id, score in sorted(
+                q_to_obj_score[q_id].items(), key=lambda item: item[1], reverse=True
+            ):
+                if pos_id >= k:
+                    break
+                merged_distances[q_id, pos_id] = score
+                merged_object_ids[q_id, pos_id] = obj_id
+                pos_id += 1
+        return merged_distances, merged_object_ids
+
     def update_object_reader(
         self,
         object_reader: ObjectReader,
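To make the merge concrete, a small self-contained walkthrough (hypothetical ids and ranks) of the rank-based scoring path with the default group functions:

```python
import operator

k = 3
# Two query embeddings for one query object; each returns k object ids, best first.
hits_per_embedding = [
    [7, 9, 8],  # results for query embedding 0
    [9, 7, 5],  # results for query embedding 1
]
# Rank-based score, as in _merge_results_per_query: 1 - position / k.
per_embedding_scores = [
    {obj: 1 - pos / k for pos, obj in enumerate(hits)} for hits in hits_per_embedding
]
# max per embedding is a no-op here (each id appears once per list);
# operator.add then merges scores across the query's two embeddings.
merged = {}
for scores in per_embedding_scores:
    for obj, s in scores.items():
        merged[obj] = operator.add(merged[obj], s) if obj in merged else s
ranked = sorted(merged.items(), key=lambda item: item[1], reverse=True)
print(ranked)  # objects 7 and 9 tie at ~1.67, ahead of 8 and 5 at ~0.33
```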
@@ -626,6 +753,7 @@ def update_index(
         config: Optional[Mapping[str, Any]] = None,
         namespace: Optional[str] = None,
         environment_variables: Dict = {},
+        use_updates_array: bool = True,
         **kwargs,
     ):
         """Updates the index with new data.
@@ -703,14 +831,14 @@ def update_index(
             Keyword arguments to pass to the ingestion function.
         """
         with tiledb.scope_ctx(ctx_or_config=config):
-            use_updates_array = True
             embeddings_array_uri = None
             if self.index.size == 0:
+                use_updates_array = False
+            if not use_updates_array:
                 (
                     temp_dir_name,
                     embeddings_array_uri,
                 ) = self._create_embeddings_partitioned_array()
-                use_updates_array = False
                 storage_formats[self.index.storage_version]["EXTERNAL_IDS_ARRAY_NAME"]

             metadata_array_uri = None
@@ -790,6 +918,7 @@ def create(
     embedding: ObjectEmbedding,
     config: Optional[Mapping[str, Any]] = None,
     storage_version: str = STORAGE_VERSION,
+    metadata_tile_size: int = 10000,
     **kwargs,
 ) -> ObjectIndex:
     """Creates a new ObjectIndex.
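A hedged sketch of the two new knobs introduced above (`reader`, `embedding`, and the URI are hypothetical; flow mirrors the test below): `use_updates_array=False` forces a full re-ingestion through the partitioned embeddings array instead of the updates array, and `metadata_tile_size` sizes the external_id dimension of the metadata array at create time.

```python
from tiledb.vector_search.object_api import object_index

index = object_index.create(
    uri="s3://bucket/my_object_index",  # hypothetical
    index_type="IVF_FLAT",
    object_reader=reader,
    embedding=embedding,
    metadata_tile_size=10000,  # new create() parameter
)
index.update_index(partitions=10)

# Re-running ingestion over the same objects: bypass the updates array so
# re-ingested vectors replace rather than duplicate the originals.
index = object_index.ObjectIndex(uri="s3://bucket/my_object_index")
index.update_index(partitions=10, use_updates_array=False)
```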
@@ -827,6 +956,7 @@ def create(
             group_exists=False,
             config=config,
             storage_version=storage_version,
+            **kwargs,
         )
     elif index_type == "IVF_FLAT":
         index = ivf_flat_index.create(
@@ -836,6 +966,7 @@ def create(
             group_exists=False,
             config=config,
             storage_version=storage_version,
+            **kwargs,
         )
     elif index_type == "VAMANA":
         index = vamana_index.create(
@@ -844,22 +975,20 @@ def create(
             uri=uri,
             dimensions=dimensions,
             vector_type=vector_type,
             config=config,
             storage_version=storage_version,
+            **kwargs,
         )
     elif index_type == "IVF_PQ":
         if "num_subspaces" not in kwargs:
             raise ValueError(
                 "num_subspaces must be provided when creating an IVF_PQ index"
             )
-        num_subspaces = kwargs["num_subspaces"]
-        partitions = kwargs.get("partitions", None)
         index = ivf_pq_index.create(
             uri=uri,
             dimensions=dimensions,
             vector_type=vector_type,
             config=config,
             storage_version=storage_version,
-            partitions=partitions,
-            num_subspaces=num_subspaces,
+            **kwargs,
         )
     else:
         raise ValueError(f"Unsupported index type {index_type}")
@@ -883,8 +1012,8 @@ def create(
         object_metadata_array_uri = f"{uri}/{metadata_array_name}"
         external_ids_dim = tiledb.Dim(
             name="external_id",
-            domain=(0, np.iinfo(np.dtype("uint64")).max - 10000),
-            tile=10000,
+            domain=(0, np.iinfo(np.dtype("uint64")).max - metadata_tile_size),
+            tile=metadata_tile_size,
             dtype=np.dtype(np.uint64),
         )
         external_ids_dom = tiledb.Domain(external_ids_dim)
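With `create()` now forwarding `**kwargs` to the per-type index constructors, type-specific options ride along uniformly instead of being plucked out by hand; a hedged sketch (`reader` and `embedding` hypothetical, values taken from the test below):

```python
# num_subspaces is still validated explicitly for IVF_PQ, but both it and
# partitions now flow through create(**kwargs) to ivf_pq_index.create().
index = object_index.create(
    uri=index_uri,
    index_type="IVF_PQ",
    object_reader=reader,
    embedding=embedding,
    num_subspaces=4,
    partitions=10,
)
```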
diff --git a/apis/python/test/test_object_index.py b/apis/python/test/test_object_index.py
index c03324f85..3f235dfe4 100644
--- a/apis/python/test/test_object_index.py
+++ b/apis/python/test/test_object_index.py
@@ -43,6 +43,43 @@ def embed(self, objects: OrderedDict, metadata: OrderedDict) -> np.ndarray:
         return embeddings


+class TestMultipleEmbeddingsPerObject(ObjectEmbedding):
+    def __init__(
+        self,
+    ):
+        self.model = None
+
+    def init_kwargs(self) -> Dict:
+        return {}
+
+    def dimensions(self) -> int:
+        return EMBED_DIM
+
+    def vector_type(self) -> np.dtype:
+        return np.float32
+
+    def load(self) -> None:
+        pass
+
+    def embed(
+        self, objects: OrderedDict, metadata: OrderedDict
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        embeddings_per_object = 10
+        num_embeddings = len(objects["object"]) * embeddings_per_object
+        embeddings = np.zeros((num_embeddings, EMBED_DIM), dtype=self.vector_type())
+        external_ids = np.zeros((num_embeddings))
+        emb_id = 0
+        for obj_id in range(len(objects["object"])):
+            for eid in range(embeddings_per_object):
+                for dim_id in range(EMBED_DIM):
+                    embeddings[emb_id, dim_id] = (
+                        objects["object"][obj_id][0] + 100000 * eid
+                    )
+                external_ids[emb_id] = metadata["external_id"][obj_id]
+                emb_id += 1
+        return embeddings, external_ids
+
+
 class TestPartition(ObjectPartition):
     def __init__(
         self,
@@ -377,6 +414,65 @@ def test_object_index(tmp_path):
     )


+def test_object_index_multiple_embeddings_per_object(tmp_path):
+    from common import INDEXES
+
+    for index_type in INDEXES:
+        index_uri = os.path.join(tmp_path, f"object_index_{index_type}")
+        reader = TestReader(
+            object_id_start=0,
+            object_id_end=1000,
+            vector_dim_offset=0,
+        )
+        embedding = TestMultipleEmbeddingsPerObject()
+
+        index = object_index.create(
+            uri=index_uri,
+            index_type=index_type,
+            object_reader=reader,
+            embedding=embedding,
+            num_subspaces=4,
+        )
+
+        # Check initial ingestion
+        index.update_index(partitions=10)
+        evaluate_query(
+            index_type=index_type,
+            index_uri=index_uri,
+            query_kwargs={"nprobe": 10, "l_search": 250},
+            dim_id=42,
+            vector_dim_offset=0,
+        )
+
+        # Check that updating the same data doesn't create duplicates
+        index = object_index.ObjectIndex(uri=index_uri)
+        index.update_index(partitions=10, use_updates_array=False)
+        evaluate_query(
+            index_type=index_type,
+            index_uri=index_uri,
+            query_kwargs={"nprobe": 10, "l_search": 500},
+            dim_id=42,
+            vector_dim_offset=0,
+        )
+
+        # Add new data with a new reader
+        reader = TestReader(
+            object_id_start=1000,
+            object_id_end=2000,
+            vector_dim_offset=0,
+        )
+        index = object_index.ObjectIndex(uri=index_uri)
+        index.update_object_reader(reader)
+        index.update_index(partitions=10, use_updates_array=False)
+        evaluate_query(
+            index_type=index_type,
+            index_uri=index_uri,
+            query_kwargs={"nprobe": 10, "l_search": 500},
+            dim_id=1042,
+            vector_dim_offset=0,
+        )
+
+
 def test_object_index_ivf_flat_cloud(tmp_path):
     from common import create_cloud_uri
     from common import delete_uri