Support multiple embeddings per object in Object API (#542)
Support multiple embeddings per object in the Object API.

This adds:
- The ability for embedding functions to return multiple embeddings per object, returning both the embeddings and their respective external_ids (see the sketch below).
- Query post-processing that merges the scores of different query embeddings per object and per query.
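For illustration, a minimal sketch of an embedding class that returns several embeddings per object. Everything here is hypothetical (the ChunkingEmbedding name, the character-window chunking, the toy 4-dimensional encoder); it only demonstrates the shape of the new return contract:

    import numpy as np
    from collections import OrderedDict
    from typing import Tuple

    class ChunkingEmbedding:  # would subclass ObjectEmbedding in real use
        def embed(
            self, objects: OrderedDict, metadata: OrderedDict
        ) -> Tuple[np.ndarray, np.ndarray]:
            embeddings = []
            external_ids = []
            for text, ext_id in zip(objects["text"], metadata["external_id"]):
                # Toy chunking: fixed-size character windows, one embedding each.
                for start in range(0, len(text), 32):
                    chunk = text[start : start + 32]
                    # Toy 4-d vector; a real implementation would call a model.
                    vec = np.zeros(4, dtype=np.float32)
                    vec[: min(4, len(chunk))] = [ord(c) for c in chunk[:4]]
                    embeddings.append(vec)
                    external_ids.append(ext_id)
            # One row per chunk; external_ids repeats an object's id for each
            # of its chunks, which is how rows map back to their objects.
            return np.array(embeddings), np.array(external_ids, dtype=np.uint64)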
NikolaosPapailiou authored Oct 14, 2024
1 parent c7609f7 commit f41bf58
Showing 4 changed files with 258 additions and 27 deletions.
@@ -1,6 +1,6 @@
from abc import ABC
from abc import abstractmethod
-from typing import Dict, OrderedDict
+from typing import Dict, OrderedDict, Tuple, Union

import numpy as np

@@ -43,10 +43,12 @@ def load(self) -> None:
raise NotImplementedError

@abstractmethod
-def embed(self, objects: OrderedDict, metadata: OrderedDict) -> np.ndarray:
+def embed(
+    self, objects: OrderedDict, metadata: OrderedDict
+) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
"""
Creates embedding vectors for objects. Returns a numpy array of embedding vectors.
-There is no enforced restriction on the object format. ObjectReaders and ObjectEmbeddings should use comatible object and metadata formats.
+There is no enforced restriction on the object format. ObjectReaders and ObjectEmbeddings should use compatible object and metadata formats.
Parameters
----------
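To make the new Union return type concrete, a hypothetical embed() call over 2 objects that split into 5 chunks could return either form (all shapes and values below are illustrative):

    import numpy as np

    # Legacy form, unchanged: one embedding per object.
    embeddings = np.zeros((2, 768), dtype=np.float32)

    # Multi-embedding form added by this commit: 5 chunk embeddings plus
    # external_ids mapping each row back to its source object
    # (object 0 contributed 3 chunks, object 1 contributed 2).
    embeddings = np.zeros((5, 768), dtype=np.float32)
    external_ids = np.array([0, 0, 0, 1, 1], dtype=np.uint64)
    result = (embeddings, external_ids)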
@@ -1,5 +1,5 @@
from functools import partial
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, List, Mapping, Optional, Tuple

from tiledb.cloud.dag import Mode

@@ -298,15 +298,19 @@ def instantiate_object(code, class_name, **kwargs):

logger.debug("Embedding objects...")
embeddings = object_embedding.embed(objects, metadata)

if isinstance(embeddings, Tuple):
external_ids = embeddings[1]
embeddings = embeddings[0]
else:
external_ids = objects["external_id"].astype(np.uint64)
logger.debug("Write embeddings partition_id: %d", partition_id)
if use_updates_array:
vectors = np.empty(embeddings.shape[0], dtype="O")
for i in range(embeddings.shape[0]):
vectors[i] = embeddings[i].astype(vector_type)
obj_index.index.update_batch(
vectors=vectors,
-external_ids=objects["external_id"].astype(np.uint64),
+external_ids=external_ids.astype(np.uint64),
)
else:
embeddings_flattened = np.empty(1, dtype="O")
@@ -317,16 +321,16 @@ def instantiate_object(code, class_name, **kwargs):
embeddings_shape[0] = np.array(
embeddings.shape, dtype=np.uint32
)
-external_ids = np.empty(1, dtype="O")
-external_ids[0] = objects["external_id"].astype(np.uint64)
+write_external_ids = np.empty(1, dtype="O")
+write_external_ids[0] = external_ids.astype(np.uint64)
embeddings_array[partition_id] = {
"vectors": embeddings_flattened,
"vectors_shape": embeddings_shape,
"external_ids": external_ids,
"external_ids": write_external_ids,
}
if metadata_array_uri is not None:
-external_ids = metadata.pop("external_id", None)
-metadata_array[external_ids] = metadata
+metadata_external_ids = metadata.pop("external_id", None)
+metadata_array[metadata_external_ids] = metadata

if not use_updates_array:
embeddings_array.close()
161 changes: 145 additions & 16 deletions apis/python/src/tiledb/vector_search/object_api/object_index.py
@@ -1,8 +1,9 @@
import json
import operator
import random
import string
from collections import OrderedDict
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple

import numpy as np

@@ -13,6 +14,7 @@
from tiledb.vector_search import IVFFlatIndex
from tiledb.vector_search import IVFPQIndex
from tiledb.vector_search import VamanaIndex
from tiledb.vector_search import _tiledbvspy as vspy
from tiledb.vector_search import flat_index
from tiledb.vector_search import ivf_flat_index
from tiledb.vector_search import ivf_pq_index
@@ -21,6 +23,8 @@
from tiledb.vector_search.object_readers import ObjectReader
from tiledb.vector_search.storage_formats import STORAGE_VERSION
from tiledb.vector_search.storage_formats import storage_formats
from tiledb.vector_search.utils import MAX_FLOAT32
from tiledb.vector_search.utils import MAX_UINT64
from tiledb.vector_search.utils import add_to_group

TILEDB_CLOUD_PROTOCOL = 4
@@ -288,6 +292,10 @@ def query(
driver_resources: Optional[Mapping[str, Any]] = None,
extra_driver_modules: Optional[List[str]] = None,
driver_access_credentials_name: Optional[str] = None,
merge_results_result_pos_as_score: bool = True,
merge_results_reverse_dist: Optional[bool] = None,
merge_results_per_query_embedding_group_fn: Callable = max,
merge_results_per_query_group_fn: Callable = operator.add,
**kwargs,
):
"""
@@ -360,6 +368,19 @@
A list of extra Python modules to install on the driver node.
driver_access_credentials_name: Optional[str]
If `driver_mode` was not `None`, the access credentials name to use for the driver execution.
merge_results_result_pos_as_score: bool
Applies only when there are multiple query embeddings per query.
If True, each result score is based on the position of the result for the query embedding.
merge_results_reverse_dist: Optional[bool]
Applies only when there are multiple query embeddings per query.
If True, each distance is replaced by its reciprocal (1 / dist), so that smaller distances yield higher scores.
merge_results_per_query_embedding_group_fn: Callable
Applies only when there are multiple query embeddings per query.
Group function used to combine an object's scores within a single query embedding (e.g., max, min).
merge_results_per_query_group_fn: Callable
Applies only when there are multiple query embeddings per query.
Group function used to combine an object's scores across the embeddings of a query (e.g., operator.add). This is applied after
`merge_results_per_query_embedding_group_fn`.
**kwargs
Keyword arguments to pass to the index query method.
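A usage sketch for the merge_results_* parameters documented above; obj_index and query_objects are assumed to already exist (an ObjectIndex and an OrderedDict in the reader's object format), and nothing beyond the parameter names is taken from this commit:

    import operator

    # Score results by rank rather than raw distance, keep the best (max)
    # score an object achieves for each query embedding, then sum the
    # object's scores across all embeddings of the same query.
    distances, object_ids = obj_index.query(
        query_objects,
        k=10,
        merge_results_result_pos_as_score=True,
        merge_results_per_query_embedding_group_fn=max,
        merge_results_per_query_group_fn=operator.add,
    )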
@@ -412,16 +433,50 @@ def query(
if not self.embedding_loaded:
self.embedding.load()
self.embedding_loaded = True

num_queries = len(query_objects[list(query_objects.keys())[0]])
if query_metadata is None:
query_metadata = {}
if "external_id" not in query_metadata:
query_metadata["external_id"] = np.arange(num_queries).astype(np.uint64)
query_embeddings = self.embedding.embed(
objects=query_objects, metadata=query_metadata
)
if isinstance(query_embeddings, Tuple):
query_ids = query_embeddings[1].astype(np.uint64)
query_embeddings = query_embeddings[0]
else:
query_ids = query_metadata["external_id"].astype(np.uint64)

fetch_k = k
if metadata_array_cond is not None or metadata_df_filter_fn is not None:
fetch_k = min(50 * k, self.index.size)

distances, object_ids = self.index.query(
queries=query_embeddings, k=fetch_k, **kwargs
)

# Post-process query results for multiple embeddings per query object
if merge_results_reverse_dist is None:
merge_results_reverse_dist = (
False
if self.index.distance_metric == vspy.DistanceMetric.INNER_PRODUCT
else True
)

if query_embeddings.shape[0] > num_queries:
distances, object_ids = self._merge_results_per_query(
distances=distances,
object_ids=object_ids,
query_ids=query_ids,
num_queries=num_queries,
k=fetch_k,
result_pos_as_score=merge_results_result_pos_as_score,
reverse_dist=merge_results_reverse_dist,
per_query_embedding_group_fn=merge_results_per_query_embedding_group_fn,
per_query_group_fn=merge_results_per_query_group_fn,
)

unique_ids, idx = np.unique(object_ids, return_inverse=True)
idx = np.reshape(idx, object_ids.shape)
if metadata_array_cond is not None or metadata_df_filter_fn is not None:
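A quick numeric check of the two scoring modes used by the merge step (hypothetical values, runnable on their own):

    k = 4
    # Rank-based scoring (merge_results_result_pos_as_score=True):
    # positions 0..3 score 1 - pos/k = 1.0, 0.75, 0.5, 0.25.
    rank_scores = [1 - pos / k for pos in range(k)]

    # Distance-based scoring with merge_results_reverse_dist=True:
    # distances 0.5 and 2.0 become 2.0 and 0.5, so smaller distances yield
    # larger scores. This is why the reversal defaults to on for every
    # metric except INNER_PRODUCT, where larger raw values are already better.
    reciprocal_scores = [1 / d for d in (0.5, 2.0)]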
Expand All @@ -448,13 +503,9 @@ def query(
filtered_unique_ids = unique_ids_metadata_df[
self.object_metadata_external_id_dim
].to_numpy()
-filtered_distances = np.zeros((query_embeddings.shape[0], k)).astype(
-    object_ids.dtype
-)
-filtered_object_ids = np.zeros((query_embeddings.shape[0], k)).astype(
-    object_ids.dtype
-)
-for query_id in range(query_embeddings.shape[0]):
+filtered_distances = np.zeros((num_queries, k)).astype(object_ids.dtype)
+filtered_object_ids = np.zeros((num_queries, k)).astype(object_ids.dtype)
+for query_id in range(num_queries):
write_id = 0
for result_id in range(fetch_k):
if object_ids[query_id, result_id] in filtered_unique_ids:
@@ -507,6 +558,82 @@ def query(
elif not return_objects and not return_metadata:
return distances, object_ids

def _merge_results_per_query(
self,
distances,
object_ids,
query_ids,
num_queries,
k,
result_pos_as_score=True,
reverse_dist=True,
per_query_embedding_group_fn=max,
per_query_group_fn=operator.add,
):
"""
Post-process query results for multiple embeddings per query object.
- Computes a score for each original result:
- If `result_pos_as_score=True`, the score is based on the position of the result for the query embedding.
- Otherwise, the distance is used as the score.
- If `reverse_dist=True`, the reciprocal of the distance (1 / distance) is used as the score.
- Applies `per_query_embedding_group_fn` to group object results per query embedding.
- Applies `per_query_group_fn` to group object results per query.
"""

def get_reciprocal(dist):
if dist == 0:
return MAX_FLOAT32
return 1 / dist

# Apply `per_query_embedding_group_fn` for each query embedding
q_emb_to_obj_score = []
for q_emb_id in range(distances.shape[0]):
q_emb_score = {}
for result_id in range(distances.shape[1]):
obj_id = object_ids[q_emb_id][result_id]
if result_pos_as_score:
score = 1 - result_id / k
else:
score = distances[q_emb_id][result_id]
if reverse_dist:
score = get_reciprocal(score)
if obj_id not in q_emb_score:
q_emb_score[obj_id] = score
else:
q_emb_score[obj_id] = per_query_embedding_group_fn(
q_emb_score[obj_id], score
)
q_emb_to_obj_score.append(q_emb_score)

# Apply `per_query_group_fn` for each query
q_to_obj_score = []
for q_id in range(num_queries):
q_to_obj_score.append({})

for q_emb_id in range(distances.shape[0]):
q_id = query_ids[q_emb_id]
for obj_id, score in q_emb_to_obj_score[q_emb_id].items():
if obj_id not in q_to_obj_score[q_id]:
q_to_obj_score[q_id][obj_id] = score
else:
q_to_obj_score[q_id][obj_id] = per_query_group_fn(
q_to_obj_score[q_id][obj_id], score
)

merged_distances = MAX_FLOAT32 * np.ones((num_queries, k), dtype=np.float32)
merged_object_ids = MAX_UINT64 * np.ones((num_queries, k), dtype=np.uint64)
for q_id in range(num_queries):
pos_id = 0
for obj_id, score in sorted(
q_to_obj_score[q_id].items(), key=lambda item: item[1], reverse=True
):
if pos_id >= k:
break
merged_distances[q_id, pos_id] = score
merged_object_ids[q_id, pos_id] = obj_id
pos_id += 1
return merged_distances, merged_object_ids
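A small worked example of this merge with hypothetical numbers: two query embeddings that both belong to query 0, k = 2, rank-based scoring:

    import numpy as np

    # Each of the two query embeddings returned two results.
    object_ids = np.array([[7, 9], [9, 3]], dtype=np.uint64)
    query_ids = np.array([0, 0], dtype=np.uint64)

    # Rank scores per row: position 0 -> 1 - 0/2 = 1.0, position 1 -> 0.5.
    # Per-embedding grouping (max) changes nothing here, since no object id
    # repeats within a row. Per-query grouping (add) then yields:
    #   object 9: 0.5 + 1.0 = 1.5, object 7: 1.0, object 3: 0.5
    # Sorted descending and cut to k = 2, the merged output would be:
    #   merged_object_ids[0] == [9, 7]
    #   merged_distances[0] == [1.5, 1.0]  # merged "distances" are scores,
    #                                      # so higher now means better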

def update_object_reader(
self,
object_reader: ObjectReader,
@@ -626,6 +753,7 @@ def update_index(
config: Optional[Mapping[str, Any]] = None,
namespace: Optional[str] = None,
environment_variables: Dict = {},
use_updates_array: bool = True,
**kwargs,
):
"""Updates the index with new data.
@@ -703,14 +831,14 @@
Keyword arguments to pass to the ingestion function.
"""
with tiledb.scope_ctx(ctx_or_config=config):
use_updates_array = True
embeddings_array_uri = None
if self.index.size == 0:
use_updates_array = False
if not use_updates_array:
(
temp_dir_name,
embeddings_array_uri,
) = self._create_embeddings_partitioned_array()
use_updates_array = False

storage_formats[self.index.storage_version]["EXTERNAL_IDS_ARRAY_NAME"]
metadata_array_uri = None
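Surfacing use_updates_array as a parameter lets callers opt out of the updates array explicitly. A sketch, assuming obj_index is an existing ObjectIndex (per the code above, the flag is still forced off when the index is empty):

    # Default behaviour: route new vectors through the updates array.
    obj_index.update_index()

    # Opt out: write a freshly partitioned embeddings array instead.
    obj_index.update_index(use_updates_array=False)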
@@ -790,6 +918,7 @@ def create(
embedding: ObjectEmbedding,
config: Optional[Mapping[str, Any]] = None,
storage_version: str = STORAGE_VERSION,
metadata_tile_size: int = 10000,
**kwargs,
) -> ObjectIndex:
"""Creates a new ObjectIndex.
@@ -827,6 +956,7 @@
group_exists=False,
config=config,
storage_version=storage_version,
**kwargs,
)
elif index_type == "IVF_FLAT":
index = ivf_flat_index.create(
@@ -836,6 +966,7 @@
group_exists=False,
config=config,
storage_version=storage_version,
**kwargs,
)
elif index_type == "VAMANA":
index = vamana_index.create(
@@ -844,22 +975,20 @@
vector_type=vector_type,
config=config,
storage_version=storage_version,
**kwargs,
)
elif index_type == "IVF_PQ":
if "num_subspaces" not in kwargs:
raise ValueError(
"num_subspaces must be provided when creating an IVF_PQ index"
)
num_subspaces = kwargs["num_subspaces"]
partitions = kwargs.get("partitions", None)
index = ivf_pq_index.create(
uri=uri,
dimensions=dimensions,
vector_type=vector_type,
config=config,
storage_version=storage_version,
partitions=partitions,
num_subspaces=num_subspaces,
**kwargs,
)
else:
raise ValueError(f"Unsupported index type {index_type}")
Expand All @@ -883,8 +1012,8 @@ def create(
object_metadata_array_uri = f"{uri}/{metadata_array_name}"
external_ids_dim = tiledb.Dim(
name="external_id",
-domain=(0, np.iinfo(np.dtype("uint64")).max - 10000),
-tile=10000,
+domain=(0, np.iinfo(np.dtype("uint64")).max - metadata_tile_size),
+tile=metadata_tile_size,
dtype=np.dtype(np.uint64),
)
external_ids_dom = tiledb.Domain(external_ids_dim)
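A creation sketch using the new metadata_tile_size parameter; the URI and the reader/embedding objects are placeholders, not from this commit. Note that the external_id domain above is shrunk by the tile size so that domain plus tile extent cannot overflow uint64:

    from tiledb.vector_search.object_api import object_index

    obj_index = object_index.create(
        uri="s3://bucket/my_object_index",  # placeholder URI
        index_type="IVF_FLAT",
        object_reader=my_reader,            # assumed ObjectReader instance
        embedding=my_embedding,             # assumed ObjectEmbedding instance
        metadata_tile_size=4096,            # default shown above is 10000
    )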