diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py
index 410c4a8d..48b93dae 100644
--- a/fastembed/text/onnx_embedding.py
+++ b/fastembed/text/onnx_embedding.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Iterable, List, Optional, Sequence, Type, Union
+from typing import Any, Iterable, Optional, Sequence, Type, Union
 
 import numpy as np
 
@@ -168,16 +168,15 @@
         "model_file": "onnx/model.onnx",
     },
     {
-        "model": "akshayballal/colpali-v1.2-merged",
-        "dim": 128,
-        "description": "",
-        "license": "mit",
-        "size_in_GB": 6.08,
+        "model": "jinaai/jina-clip-v1",
+        "dim": 768,
+        "description": "Text embeddings, Multimodal (text&image), English, Prefixes for queries/documents: not necessary, 2024 year",
+        "license": "apache-2.0",
+        "size_in_GB": 0.55,
         "sources": {
-            "hf": "akshayballal/colpali-v1.2-merged-onnx",
+            "hf": "jinaai/jina-clip-v1",
         },
-        "additional_files": ["model.onnx_data"],
-        "model_file": "model.onnx",
+        "model_file": "onnx/text_model.onnx",
     },
 ]
 
@@ -186,12 +185,12 @@ class OnnxTextEmbedding(TextEmbeddingBase, OnnxTextModel[np.ndarray]):
     """Implementation of the Flag Embedding model."""
 
     @classmethod
-    def list_supported_models(cls) -> List[Dict[str, Any]]:
+    def list_supported_models(cls) -> list[dict[str, Any]]:
         """
         Lists the supported models.
 
         Returns:
-            List[Dict[str, Any]]: A list of dictionaries containing the model information.
+            list[dict[str, Any]]: A list of dictionaries containing the model information.
         """
         return supported_onnx_models
 
@@ -202,7 +201,7 @@ def __init__(
         threads: Optional[int] = None,
         providers: Optional[Sequence[OnnxProvider]] = None,
         cuda: bool = False,
-        device_ids: Optional[List[int]] = None,
+        device_ids: Optional[list[int]] = None,
         lazy_load: bool = False,
         device_id: Optional[int] = None,
         **kwargs,
@@ -218,7 +217,7 @@ def __init__(
                 Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
             cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`
                 Defaults to False.
-            device_ids (Optional[List[int]], optional): The list of device ids to use for data parallel processing in
+            device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
                 workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
             lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
                 Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
@@ -291,8 +290,8 @@ def _get_worker_class(cls) -> Type["TextEmbeddingWorker"]:
         return OnnxTextEmbeddingWorker
 
     def _preprocess_onnx_input(
-        self, onnx_input: Dict[str, np.ndarray], **kwargs
-    ) -> Dict[str, np.ndarray]:
+        self, onnx_input: dict[str, np.ndarray], **kwargs
+    ) -> dict[str, np.ndarray]:
         """
         Preprocess the onnx input.
         """
@@ -300,7 +299,13 @@ def _preprocess_onnx_input(
 
     def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
         embeddings = output.model_output
-        return normalize(embeddings[:, 0]).astype(np.float32)
+        if embeddings.ndim == 3:  # (batch_size, seq_len, embedding_dim)
+            processed_embeddings = embeddings[:, 0]
+        elif embeddings.ndim == 2:  # (batch_size, embedding_dim)
+            processed_embeddings = embeddings
+        else:
+            raise ValueError(f"Unsupported embedding shape: {embeddings.shape}")
+        return normalize(processed_embeddings).astype(np.float32)
 
     def load_onnx_model(self) -> None:
         self._load_onnx_model(
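
A note on two of the changes above. The move from typing.List/typing.Dict to the builtin list/dict generics (PEP 585) assumes Python 3.9+, since builtin generics are not valid in annotations on 3.8. Separately, the reworked _post_process_onnx_output now accepts two output layouts: token-level outputs of shape (batch_size, seq_len, embedding_dim), from which the first (CLS) token is pooled, and already-pooled outputs of shape (batch_size, embedding_dim), presumably what the onnx/text_model.onnx export of jinaai/jina-clip-v1 returns. A minimal standalone sketch of that dispatch in plain numpy; l2_normalize here is an assumed stand-in for fastembed's normalize helper, not the library's actual implementation:

    import numpy as np

    def l2_normalize(x: np.ndarray) -> np.ndarray:
        # Stand-in for fastembed's normalize(): L2-normalize each row.
        norms = np.linalg.norm(x, axis=1, keepdims=True)
        return x / np.maximum(norms, 1e-12)

    def post_process(model_output: np.ndarray) -> np.ndarray:
        if model_output.ndim == 3:  # (batch_size, seq_len, embedding_dim): CLS pooling
            pooled = model_output[:, 0]
        elif model_output.ndim == 2:  # (batch_size, embedding_dim): already pooled
            pooled = model_output
        else:
            raise ValueError(f"Unsupported embedding shape: {model_output.shape}")
        return l2_normalize(pooled).astype(np.float32)

    # Both layouts reduce to (2, 4):
    print(post_process(np.random.rand(2, 7, 4)).shape)
    print(post_process(np.random.rand(2, 4)).shape)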
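
Assuming this model entry ships in a release, the new model should be reachable through fastembed's regular text-embedding interface; a usage sketch, where the model name and file come from the diff, the TextEmbedding API is fastembed's public interface, and the expected output dimensionality follows from the "dim" field above:

    from fastembed import TextEmbedding

    # First use downloads onnx/text_model.onnx from the jinaai/jina-clip-v1 HF repo.
    model = TextEmbedding(model_name="jinaai/jina-clip-v1")

    # Per the model description, no query/document prefixes are necessary.
    embeddings = list(model.embed(["Qdrant is a vector database."]))
    print(embeddings[0].shape)  # expected: (768,)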