diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py
index 48b93dae..410c4a8d 100644
--- a/fastembed/text/onnx_embedding.py
+++ b/fastembed/text/onnx_embedding.py
@@ -1,4 +1,4 @@
-from typing import Any, Iterable, Optional, Sequence, Type, Union
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Type, Union
 
 import numpy as np
 
@@ -168,15 +168,16 @@
         "model_file": "onnx/model.onnx",
     },
     {
-        "model": "jinaai/jina-clip-v1",
-        "dim": 768,
-        "description": "Text embeddings, Multimodal (text&image), English, Prefixes for queries/documents: not necessary, 2024 year",
-        "license": "apache-2.0",
-        "size_in_GB": 0.55,
+        "model": "akshayballal/colpali-v1.2-merged",
+        "dim": 128,
+        "description": "",
+        "license": "mit",
+        "size_in_GB": 6.08,
         "sources": {
-            "hf": "jinaai/jina-clip-v1",
+            "hf": "akshayballal/colpali-v1.2-merged-onnx",
         },
-        "model_file": "onnx/text_model.onnx",
+        "additional_files": ["model.onnx_data"],
+        "model_file": "model.onnx",
     },
 ]
 
@@ -185,12 +186,12 @@ class OnnxTextEmbedding(TextEmbeddingBase, OnnxTextModel[np.ndarray]):
     """Implementation of the Flag Embedding model."""
 
     @classmethod
-    def list_supported_models(cls) -> list[dict[str, Any]]:
+    def list_supported_models(cls) -> List[Dict[str, Any]]:
         """
         Lists the supported models.
 
         Returns:
-            list[dict[str, Any]]: A list of dictionaries containing the model information.
+            List[Dict[str, Any]]: A list of dictionaries containing the model information.
         """
         return supported_onnx_models
 
@@ -201,7 +202,7 @@ def __init__(
         threads: Optional[int] = None,
         providers: Optional[Sequence[OnnxProvider]] = None,
         cuda: bool = False,
-        device_ids: Optional[list[int]] = None,
+        device_ids: Optional[List[int]] = None,
         lazy_load: bool = False,
         device_id: Optional[int] = None,
         **kwargs,
@@ -217,7 +218,7 @@ def __init__(
                 Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
             cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`
                 Defaults to False.
-            device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
+            device_ids (Optional[List[int]], optional): The list of device ids to use for data parallel processing in
                 workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
             lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
                 Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
@@ -290,8 +291,8 @@ def _get_worker_class(cls) -> Type["TextEmbeddingWorker"]:
         return OnnxTextEmbeddingWorker
 
     def _preprocess_onnx_input(
-        self, onnx_input: dict[str, np.ndarray], **kwargs
-    ) -> dict[str, np.ndarray]:
+        self, onnx_input: Dict[str, np.ndarray], **kwargs
+    ) -> Dict[str, np.ndarray]:
         """
         Preprocess the onnx input.
         """
@@ -299,13 +300,7 @@ def _preprocess_onnx_input(
     def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
         embeddings = output.model_output
-        if embeddings.ndim == 3:  # (batch_size, seq_len, embedding_dim)
-            processed_embeddings = embeddings[:, 0]
-        elif embeddings.ndim == 2:  # (batch_size, embedding_dim)
-            processed_embeddings = embeddings
-        else:
-            raise ValueError(f"Unsupported embedding shape: {embeddings.shape}")
-        return normalize(processed_embeddings).astype(np.float32)
+        return normalize(embeddings[:, 0]).astype(np.float32)
 
     def load_onnx_model(self) -> None:
         self._load_onnx_model(
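
A minimal usage sketch (not part of the patch), assuming fastembed is installed: it prints the registry that list_supported_models() now returns, and mirrors what the simplified _post_process_onnx_output does (CLS-token pooling followed by L2 normalization) using plain numpy in place of fastembed's normalize helper; cls_pool_and_normalize is a hypothetical name introduced here for illustration.

import numpy as np

from fastembed.text.onnx_embedding import OnnxTextEmbedding

# The entry added above should now appear alongside the existing models.
for info in OnnxTextEmbedding.list_supported_models():
    print(info["model"], info["dim"])

def cls_pool_and_normalize(model_output: np.ndarray) -> np.ndarray:
    # Hypothetical helper mirroring the new post-processing: assumes a
    # (batch_size, seq_len, embedding_dim) model output, keeps the first
    # (CLS) token per sequence, then L2-normalizes each row.
    cls = model_output[:, 0]
    return (cls / np.linalg.norm(cls, axis=1, keepdims=True)).astype(np.float32)

dummy = np.random.rand(2, 8, 128).astype(np.float32)  # stand-in for raw ONNX output
print(cls_pool_and_normalize(dummy).shape)  # (2, 128)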