From 8b4eb26d4fc790d00796a0074f2dd9f026ef81a7 Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Thu, 19 Dec 2024 12:57:14 +0200 Subject: [PATCH 01/21] new: Added jina embedding v3 --- fastembed/text/multitask_embedding.py | 83 +++++++++++++++++++++++++++ fastembed/text/text_embedding.py | 2 + 2 files changed, 85 insertions(+) create mode 100644 fastembed/text/multitask_embedding.py diff --git a/fastembed/text/multitask_embedding.py b/fastembed/text/multitask_embedding.py new file mode 100644 index 00000000..f910b282 --- /dev/null +++ b/fastembed/text/multitask_embedding.py @@ -0,0 +1,83 @@ +from typing import Any, Type, Iterable, Union, Optional + +import numpy as np + +from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding +from fastembed.text.onnx_embedding import OnnxTextEmbeddingWorker +from fastembed.text.onnx_text_model import TextEmbeddingWorker + +supported_multitask_models = [ + { + "model": "jinaai/jina-embeddings-v3", + "dim": [32, 64, 128, 256, 512, 768, 1024], + "tasks": { + "retrieval.query": 0, + "retrieval.passage": 1, + "separation": 2, + "classification": 3, + "text-matching": 4, + }, + "description": "Multi-task, multi-lingual embedding model with Matryoshka architecture", + "license": "cc-by-nc-4.0", + "size_in_GB": 2.29, + "sources": { + "hf": "jinaai/jina-embeddings-v3", + }, + "model_file": "onnx/model.onnx", + "additional_files": ["onnx/model.onnx_data"], + }, +] + + +class JinaEmbeddingV3(PooledNormalizedEmbedding): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._current_task_id = 4 + + @classmethod + def _get_worker_class(cls) -> Type["TextEmbeddingWorker"]: + return JinaEmbeddingV3Worker + + @classmethod + def list_supported_models(cls) -> list[dict[str, Any]]: + return supported_multitask_models + + def _preprocess_onnx_input( + self, onnx_input: dict[str, np.ndarray], **kwargs + ) -> dict[str, np.ndarray]: + onnx_input["task_id"] = np.array(self._current_task_id, dtype=np.int64) + return onnx_input + + def embed( + self, + documents: Union[str, Iterable[str]], + batch_size: int = 256, + parallel: Optional[int] = None, + task_id: int = 4, + **kwargs, + ) -> Iterable[np.ndarray]: + self._current_task_id = task_id + yield from super().embed(documents, batch_size, parallel, **kwargs) + + def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]: + self._current_task_id = 0 + yield from super().query_embed(query, **kwargs) + + def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]: + self._current_task_id = 1 + yield from super().passage_embed(texts, **kwargs) + + +class JinaEmbeddingV3Worker(OnnxTextEmbeddingWorker): + def init_embedding( + self, + model_name: str, + cache_dir: str, + **kwargs, + ) -> JinaEmbeddingV3: + return JinaEmbeddingV3( + model_name=model_name, + cache_dir=cache_dir, + threads=1, + **kwargs, + ) diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py index 960d68f7..f7e44775 100644 --- a/fastembed/text/text_embedding.py +++ b/fastembed/text/text_embedding.py @@ -7,6 +7,7 @@ from fastembed.text.e5_onnx_embedding import E5OnnxEmbedding from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding from fastembed.text.pooled_embedding import PooledEmbedding +from fastembed.text.multitask_embedding import JinaEmbeddingV3 from fastembed.text.onnx_embedding import OnnxTextEmbedding from fastembed.text.text_embedding_base import TextEmbeddingBase @@ -18,6 +19,7 @@ class TextEmbedding(TextEmbeddingBase): CLIPOnnxEmbedding, PooledNormalizedEmbedding, PooledEmbedding, + JinaEmbeddingV3, ] @classmethod From 64127fcf3ae8bd0cfc663de0b4dcd6709a3ee39c Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Mon, 23 Dec 2024 08:27:52 +0200 Subject: [PATCH 02/21] refactor: Changed dim to int value --- fastembed/text/multitask_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastembed/text/multitask_embedding.py b/fastembed/text/multitask_embedding.py index f910b282..8a9148cf 100644 --- a/fastembed/text/multitask_embedding.py +++ b/fastembed/text/multitask_embedding.py @@ -9,7 +9,7 @@ supported_multitask_models = [ { "model": "jinaai/jina-embeddings-v3", - "dim": [32, 64, 128, 256, 512, 768, 1024], + "dim": 1024, "tasks": { "retrieval.query": 0, "retrieval.passage": 1, From e48f64731629e59fce969722d8c70833895a577e Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Mon, 23 Dec 2024 08:28:08 +0200 Subject: [PATCH 03/21] new: Updated notice --- NOTICE | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NOTICE b/NOTICE index bfa9618d..caa664b7 100644 --- a/NOTICE +++ b/NOTICE @@ -7,6 +7,8 @@ This distribution includes the following Jina AI models, each with its respectiv - License: cc-by-nc-4.0 - jinaai/jina-reranker-v2-base-multilingual - License: cc-by-nc-4.0 +- jinaai/jina-embeddings-v3 + - License: cc-by-nc-4.0 These models are developed by Jina (https://jina.ai/) and are subject to Jina AI's licensing terms. From eb475d5474c66644020602b234ff05f3619a8ab6 Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Mon, 23 Dec 2024 08:28:41 +0200 Subject: [PATCH 04/21] new: Extended text embedding with query embed and passage embed --- fastembed/text/text_embedding.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py index f7e44775..a8def42e 100644 --- a/fastembed/text/text_embedding.py +++ b/fastembed/text/text_embedding.py @@ -107,3 +107,30 @@ def embed( List of embeddings, one per document """ yield from self.model.embed(documents, batch_size, parallel, **kwargs) + + def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]: + """ + Embeds queries + + Args: + query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries. + + Returns: + Iterable[np.ndarray]: The embeddings. + """ + # This is model-specific, so that different models can have specialized implementations + yield from self.model.query_embed(query, **kwargs) + + def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]: + """ + Embeds a list of text passages into a list of embeddings. + + Args: + texts (Iterable[str]): The list of texts to embed. + **kwargs: Additional keyword argument to pass to the embed method. + + Yields: + Iterable[SparseEmbedding]: The sparse embeddings. + """ + # This is model-specific, so that different models can have specialized implementations + yield from self.model.passage_embed(texts, **kwargs) From 1650252e07f6abc998bcf7259a3d8793ee7aa13b Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Mon, 23 Dec 2024 09:37:30 +0200 Subject: [PATCH 05/21] fix: Fix lazy load in query and passage embed --- fastembed/text/multitask_embedding.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/fastembed/text/multitask_embedding.py b/fastembed/text/multitask_embedding.py index 8a9148cf..3ae31a82 100644 --- a/fastembed/text/multitask_embedding.py +++ b/fastembed/text/multitask_embedding.py @@ -61,11 +61,24 @@ def embed( def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]: self._current_task_id = 0 - yield from super().query_embed(query, **kwargs) + + if isinstance(query, str): + query = [query] + + if not hasattr(self, "model") or self.model is None: + self.load_onnx_model() + + for text in query: + yield from self._post_process_onnx_output(self.onnx_embed([text])) def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]: self._current_task_id = 1 - yield from super().passage_embed(texts, **kwargs) + + if not hasattr(self, "model") or self.model is None: + self.load_onnx_model() + + for text in texts: + yield from self._post_process_onnx_output(self.onnx_embed([text])) class JinaEmbeddingV3Worker(OnnxTextEmbeddingWorker): From 197b381e8f19b52fb9f272d038c5a5ae4bfd66ac Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Mon, 23 Dec 2024 09:38:04 +0200 Subject: [PATCH 06/21] tests: Added test for multitask embeddings --- tests/test_text_multitask_embeddings.py | 231 ++++++++++++++++++++++++ 1 file changed, 231 insertions(+) create mode 100644 tests/test_text_multitask_embeddings.py diff --git a/tests/test_text_multitask_embeddings.py b/tests/test_text_multitask_embeddings.py new file mode 100644 index 00000000..606b4138 --- /dev/null +++ b/tests/test_text_multitask_embeddings.py @@ -0,0 +1,231 @@ +import os + +import numpy as np +import pytest + +from fastembed import TextEmbedding +from tests.utils import delete_model_cache + + +CANONICAL_VECTOR_VALUES = { + "jinaai/jina-embeddings-v3": [ + { + "task_id": 0, + "vectors": np.array( + [ + [0.0623, -0.0402, 0.1706, -0.0143, 0.0617], + [-0.1064, -0.0733, 0.0353, 0.0096, 0.0667], + ] + ), + }, + { + "task_id": 1, + "vectors": np.array( + [ + [0.0513, -0.0247, 0.1751, -0.0075, 0.0679], + [-0.0987, -0.0786, 0.09, 0.0087, 0.0577], + ] + ), + }, + { + "task_id": 2, + "vectors": np.array( + [ + [0.094, -0.1065, 0.1305, 0.0547, 0.0556], + [0.0315, -0.1468, 0.065, 0.0568, 0.0546], + ] + ), + }, + { + "task_id": 3, + "vectors": np.array( + [ + [0.0606, -0.0877, 0.1384, 0.0065, 0.0722], + [-0.0502, -0.119, 0.032, 0.0514, 0.0689], + ] + ), + }, + { + "task_id": 4, + "vectors": np.array( + [ + [0.0911, -0.0341, 0.1305, -0.026, 0.0576], + [-0.1432, -0.05, 0.0133, 0.0464, 0.0789], + ] + ), + }, + ] +} +docs = ["Hello World", "Follow the white rabbit."] + + +def test_batch_embedding(): + is_ci = os.getenv("CI") + docs_to_embed = docs * 10 + default_task = 4 + + for model_desc in TextEmbedding.list_supported_models(): + # if not is_ci and model_desc["size_in_GB"] > 1: + # continue + + model_name = model_desc["model"] + dim = model_desc["dim"] + + if model_name not in CANONICAL_VECTOR_VALUES.keys(): + continue + + model = TextEmbedding(model_name=model_name, cache_dir="models") + + print(f"evaluating {model_name} default task") + + embeddings = list(model.embed(documents=docs_to_embed, batch_size=6)) + embeddings = np.stack(embeddings, axis=0) + + assert embeddings.shape == (len(docs_to_embed), dim) + + canonical_vector = CANONICAL_VECTOR_VALUES[model_name][default_task]["vectors"] + assert np.allclose( + embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4 + ), model_desc["model"] + + if is_ci: + delete_model_cache(model.model._model_dir) + + +def test_single_embedding(): + is_ci = os.getenv("CI") + + for model_desc in TextEmbedding.list_supported_models(): + # if not is_ci and model_desc["size_in_GB"] > 1: + # continue + + model_name = model_desc["model"] + dim = model_desc["dim"] + + if model_name not in CANONICAL_VECTOR_VALUES.keys(): + continue + + model = TextEmbedding(model_name=model_name, cache_dir="models") + + for task in CANONICAL_VECTOR_VALUES[model_name]: + print(f"evaluating {model_name} task_id: {task['task_id']}") + + embeddings = list(model.embed(documents=docs, task_id=task["task_id"])) + embeddings = np.stack(embeddings, axis=0) + + assert embeddings.shape == (len(docs), dim) + + canonical_vector = task["vectors"] + assert np.allclose( + embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4 + ), model_desc["model"] + + if is_ci: + delete_model_cache(model.model._model_dir) + + +def test_single_embedding_query(): + is_ci = os.getenv("CI") + task_id = 0 + + for model_desc in TextEmbedding.list_supported_models(): + # if not is_ci and model_desc["size_in_GB"] > 1: + # continue + + model_name = model_desc["model"] + dim = model_desc["dim"] + + if model_name not in CANONICAL_VECTOR_VALUES.keys(): + continue + + model = TextEmbedding(model_name=model_name, cache_dir="models") + + print(f"evaluating {model_name} query_embed task_id: {task_id}") + + embeddings = list(model.query_embed(query=docs)) + embeddings = np.stack(embeddings, axis=0) + + assert embeddings.shape == (len(docs), dim) + + canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"] + assert np.allclose( + embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4 + ), model_desc["model"] + + if is_ci: + delete_model_cache(model.model._model_dir) + + +def test_single_embedding_passage(): + is_ci = os.getenv("CI") + task_id = 1 + + for model_desc in TextEmbedding.list_supported_models(): + # if not is_ci and model_desc["size_in_GB"] > 1: + # continue + + model_name = model_desc["model"] + dim = model_desc["dim"] + + if model_name not in CANONICAL_VECTOR_VALUES.keys(): + continue + + model = TextEmbedding(model_name=model_name, cache_dir="models") + + print(f"evaluating {model_name} passage_embed task_id: {task_id}") + + embeddings = list(model.passage_embed(texts=docs)) + embeddings = np.stack(embeddings, axis=0) + + assert embeddings.shape == (len(docs), dim) + + canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"] + assert np.allclose( + embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4 + ), model_desc["model"] + + if is_ci: + delete_model_cache(model.model._model_dir) + + +def test_parallel_processing(): + is_ci = os.getenv("CI") + + docs = ["Hello World", "Follow the white rabbit."] * 100 + + model_name = "jinaai/jina-embeddings-v3" + dim = 1024 + + model = TextEmbedding(model_name=model_name, cache_dir="models") + + embeddings = list(model.embed(docs, batch_size=10, parallel=2)) + embeddings = np.stack(embeddings, axis=0) + + embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None)) + embeddings_2 = np.stack(embeddings_2, axis=0) + + embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0)) + embeddings_3 = np.stack(embeddings_3, axis=0) + + assert embeddings.shape[0] == len(docs) and embeddings.shape[-1] == dim + assert np.allclose(embeddings, embeddings_2, atol=1e-4) + assert np.allclose(embeddings, embeddings_3, atol=1e-4) + + if is_ci: + delete_model_cache(model.model._model_dir) + + +@pytest.mark.parametrize( + "model_name", + ["jinaai/jina-embeddings-v3"], +) +def test_lazy_load(model_name): + is_ci = os.getenv("CI") + model = TextEmbedding(model_name=model_name, lazy_load=True, cache_dir="models") + assert not hasattr(model.model, "model") + + list(model.embed(docs)) + assert hasattr(model.model, "model") + + if is_ci: + delete_model_cache(model.model._model_dir) From c9172015298247342bd9c2f545756d9e9aedd00a Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Mon, 23 Dec 2024 09:51:53 +0200 Subject: [PATCH 07/21] nit: Remove cache dir from tests --- tests/test_text_multitask_embeddings.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_text_multitask_embeddings.py b/tests/test_text_multitask_embeddings.py index 606b4138..67964f62 100644 --- a/tests/test_text_multitask_embeddings.py +++ b/tests/test_text_multitask_embeddings.py @@ -74,7 +74,7 @@ def test_batch_embedding(): if model_name not in CANONICAL_VECTOR_VALUES.keys(): continue - model = TextEmbedding(model_name=model_name, cache_dir="models") + model = TextEmbedding(model_name=model_name) print(f"evaluating {model_name} default task") @@ -105,7 +105,7 @@ def test_single_embedding(): if model_name not in CANONICAL_VECTOR_VALUES.keys(): continue - model = TextEmbedding(model_name=model_name, cache_dir="models") + model = TextEmbedding(model_name=model_name) for task in CANONICAL_VECTOR_VALUES[model_name]: print(f"evaluating {model_name} task_id: {task['task_id']}") @@ -138,7 +138,7 @@ def test_single_embedding_query(): if model_name not in CANONICAL_VECTOR_VALUES.keys(): continue - model = TextEmbedding(model_name=model_name, cache_dir="models") + model = TextEmbedding(model_name=model_name) print(f"evaluating {model_name} query_embed task_id: {task_id}") @@ -170,7 +170,7 @@ def test_single_embedding_passage(): if model_name not in CANONICAL_VECTOR_VALUES.keys(): continue - model = TextEmbedding(model_name=model_name, cache_dir="models") + model = TextEmbedding(model_name=model_name) print(f"evaluating {model_name} passage_embed task_id: {task_id}") @@ -196,7 +196,7 @@ def test_parallel_processing(): model_name = "jinaai/jina-embeddings-v3" dim = 1024 - model = TextEmbedding(model_name=model_name, cache_dir="models") + model = TextEmbedding(model_name=model_name) embeddings = list(model.embed(docs, batch_size=10, parallel=2)) embeddings = np.stack(embeddings, axis=0) @@ -221,7 +221,7 @@ def test_parallel_processing(): ) def test_lazy_load(model_name): is_ci = os.getenv("CI") - model = TextEmbedding(model_name=model_name, lazy_load=True, cache_dir="models") + model = TextEmbedding(model_name=model_name, lazy_load=True) assert not hasattr(model.model, "model") list(model.embed(docs)) From 1ed62e96922bb8c01662548e14c6a86fae0a28fc Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Mon, 23 Dec 2024 13:35:06 +0200 Subject: [PATCH 08/21] tests: Updated tests --- tests/test_text_multitask_embeddings.py | 38 ++++++++++++------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/test_text_multitask_embeddings.py b/tests/test_text_multitask_embeddings.py index 67964f62..5e2d241e 100644 --- a/tests/test_text_multitask_embeddings.py +++ b/tests/test_text_multitask_embeddings.py @@ -65,8 +65,8 @@ def test_batch_embedding(): default_task = 4 for model_desc in TextEmbedding.list_supported_models(): - # if not is_ci and model_desc["size_in_GB"] > 1: - # continue + if not is_ci and model_desc["size_in_GB"] > 1: + continue model_name = model_desc["model"] dim = model_desc["dim"] @@ -96,8 +96,8 @@ def test_single_embedding(): is_ci = os.getenv("CI") for model_desc in TextEmbedding.list_supported_models(): - # if not is_ci and model_desc["size_in_GB"] > 1: - # continue + if not is_ci and model_desc["size_in_GB"] > 1: + continue model_name = model_desc["model"] dim = model_desc["dim"] @@ -129,8 +129,8 @@ def test_single_embedding_query(): task_id = 0 for model_desc in TextEmbedding.list_supported_models(): - # if not is_ci and model_desc["size_in_GB"] > 1: - # continue + if not is_ci and model_desc["size_in_GB"] > 1: + continue model_name = model_desc["model"] dim = model_desc["dim"] @@ -161,8 +161,8 @@ def test_single_embedding_passage(): task_id = 1 for model_desc in TextEmbedding.list_supported_models(): - # if not is_ci and model_desc["size_in_GB"] > 1: - # continue + if not is_ci and model_desc["size_in_GB"] > 1: + continue model_name = model_desc["model"] dim = model_desc["dim"] @@ -196,22 +196,22 @@ def test_parallel_processing(): model_name = "jinaai/jina-embeddings-v3" dim = 1024 - model = TextEmbedding(model_name=model_name) + if is_ci: + model = TextEmbedding(model_name=model_name) - embeddings = list(model.embed(docs, batch_size=10, parallel=2)) - embeddings = np.stack(embeddings, axis=0) + embeddings = list(model.embed(docs, batch_size=10, parallel=2)) + embeddings = np.stack(embeddings, axis=0) - embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None)) - embeddings_2 = np.stack(embeddings_2, axis=0) + embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None)) + embeddings_2 = np.stack(embeddings_2, axis=0) - embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0)) - embeddings_3 = np.stack(embeddings_3, axis=0) + embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0)) + embeddings_3 = np.stack(embeddings_3, axis=0) - assert embeddings.shape[0] == len(docs) and embeddings.shape[-1] == dim - assert np.allclose(embeddings, embeddings_2, atol=1e-4) - assert np.allclose(embeddings, embeddings_3, atol=1e-4) + assert embeddings.shape[0] == len(docs) and embeddings.shape[-1] == dim + assert np.allclose(embeddings, embeddings_2, atol=1e-4) + assert np.allclose(embeddings, embeddings_3, atol=1e-4) - if is_ci: delete_model_cache(model.model._model_dir) From eda3baee0ffeb9168a1529c71cd2f92dc600e108 Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Mon, 30 Dec 2024 18:27:42 +0200 Subject: [PATCH 09/21] improve: Improve task selection --- fastembed/text/multitask_embedding.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/fastembed/text/multitask_embedding.py b/fastembed/text/multitask_embedding.py index 3ae31a82..051fb028 100644 --- a/fastembed/text/multitask_embedding.py +++ b/fastembed/text/multitask_embedding.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Any, Type, Iterable, Union, Optional import numpy as np @@ -29,10 +30,22 @@ ] +class Task(str, Enum): + RETRIEVAL_QUERY = 0 + RETRIEVAL_PASSAGE = 1 + SEPARATION = 2 + CLASSIFICATION = 3 + TEXT_MATCHING = 4 + + class JinaEmbeddingV3(PooledNormalizedEmbedding): + DEFAULT_TASK = Task.TEXT_MATCHING + PASSAGE_TASK = Task.RETRIEVAL_PASSAGE + QUERY_TASK = Task.RETRIEVAL_QUERY + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._current_task_id = 4 + self._current_task_id = self.DEFAULT_TASK @classmethod def _get_worker_class(cls) -> Type["TextEmbeddingWorker"]: @@ -60,7 +73,7 @@ def embed( yield from super().embed(documents, batch_size, parallel, **kwargs) def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]: - self._current_task_id = 0 + self._current_task_id = self.QUERY_TASK if isinstance(query, str): query = [query] @@ -72,7 +85,7 @@ def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np yield from self._post_process_onnx_output(self.onnx_embed([text])) def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]: - self._current_task_id = 1 + self._current_task_id = self.PASSAGE_TASK if not hasattr(self, "model") or self.model is None: self.load_onnx_model() From a31461b54dd3e714becd3c0e688f20f772ce54bf Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Mon, 30 Dec 2024 18:29:49 +0200 Subject: [PATCH 10/21] fix: Fix ci --- tests/test_text_multitask_embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_text_multitask_embeddings.py b/tests/test_text_multitask_embeddings.py index 5e2d241e..d908a91c 100644 --- a/tests/test_text_multitask_embeddings.py +++ b/tests/test_text_multitask_embeddings.py @@ -197,7 +197,7 @@ def test_parallel_processing(): dim = 1024 if is_ci: - model = TextEmbedding(model_name=model_name) + model = TextEmbedding(model_name=model_name, lazy_load=True) embeddings = list(model.embed(docs, batch_size=10, parallel=2)) embeddings = np.stack(embeddings, axis=0) From 2f8290d10bc54745c28ed6f02d0011d1cb03cabb Mon Sep 17 00:00:00 2001 From: Hossam Hagag <90828745+hh-space-invader@users.noreply.github.com> Date: Fri, 3 Jan 2025 10:46:03 +0200 Subject: [PATCH 11/21] fix: Update fastembed/text/multitask_embedding.py Co-authored-by: George --- fastembed/text/multitask_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastembed/text/multitask_embedding.py b/fastembed/text/multitask_embedding.py index 051fb028..05189dc8 100644 --- a/fastembed/text/multitask_embedding.py +++ b/fastembed/text/multitask_embedding.py @@ -30,7 +30,7 @@ ] -class Task(str, Enum): +class Task(int, Enum): RETRIEVAL_QUERY = 0 RETRIEVAL_PASSAGE = 1 SEPARATION = 2 From 38ad796d646ee373762eb8861b9b97217fca8802 Mon Sep 17 00:00:00 2001 From: Hossam Hagag <90828745+hh-space-invader@users.noreply.github.com> Date: Fri, 3 Jan 2025 10:46:13 +0200 Subject: [PATCH 12/21] Update fastembed/text/multitask_embedding.py Co-authored-by: George --- fastembed/text/multitask_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastembed/text/multitask_embedding.py b/fastembed/text/multitask_embedding.py index 05189dc8..e0aac0e2 100644 --- a/fastembed/text/multitask_embedding.py +++ b/fastembed/text/multitask_embedding.py @@ -66,7 +66,7 @@ def embed( documents: Union[str, Iterable[str]], batch_size: int = 256, parallel: Optional[int] = None, - task_id: int = 4, + task_id: int = DEFAULT_TASK, **kwargs, ) -> Iterable[np.ndarray]: self._current_task_id = task_id From b33301d15cf86f8001b75d74471daf5c90ad1c6c Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Fri, 3 Jan 2025 10:47:46 +0200 Subject: [PATCH 13/21] fix: Pass task id using kwargs to parallel processor --- fastembed/text/multitask_embedding.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/fastembed/text/multitask_embedding.py b/fastembed/text/multitask_embedding.py index e0aac0e2..31c774bc 100644 --- a/fastembed/text/multitask_embedding.py +++ b/fastembed/text/multitask_embedding.py @@ -70,6 +70,7 @@ def embed( **kwargs, ) -> Iterable[np.ndarray]: self._current_task_id = task_id + kwargs["task_id"] = task_id yield from super().embed(documents, batch_size, parallel, **kwargs) def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]: @@ -101,9 +102,11 @@ def init_embedding( cache_dir: str, **kwargs, ) -> JinaEmbeddingV3: - return JinaEmbeddingV3( + model = JinaEmbeddingV3( model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs, ) + model._current_task_id = kwargs["task_id"] + return model From 89cf73252affe7fbf2413f8f2f0f986d5926c1bc Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Fri, 3 Jan 2025 11:04:30 +0200 Subject: [PATCH 14/21] tests: Added test for task assignment --- tests/test_text_multitask_embeddings.py | 37 ++++++++++++++----------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/tests/test_text_multitask_embeddings.py b/tests/test_text_multitask_embeddings.py index d908a91c..53878576 100644 --- a/tests/test_text_multitask_embeddings.py +++ b/tests/test_text_multitask_embeddings.py @@ -4,6 +4,7 @@ import pytest from fastembed import TextEmbedding +from fastembed.text.multitask_embedding import Task from tests.utils import delete_model_cache @@ -188,31 +189,35 @@ def test_single_embedding_passage(): delete_model_cache(model.model._model_dir) -def test_parallel_processing(): +def test_task_assignment(): is_ci = os.getenv("CI") - docs = ["Hello World", "Follow the white rabbit."] * 100 + for model_desc in TextEmbedding.list_supported_models(): + if not is_ci and model_desc["size_in_GB"] > 1: + continue - model_name = "jinaai/jina-embeddings-v3" - dim = 1024 + model_name = model_desc["model"] + if model_name not in CANONICAL_VECTOR_VALUES.keys(): + continue - if is_ci: - model = TextEmbedding(model_name=model_name, lazy_load=True) + model = TextEmbedding(model_name=model_name) - embeddings = list(model.embed(docs, batch_size=10, parallel=2)) - embeddings = np.stack(embeddings, axis=0) + _ = list(model.embed(documents=docs, batch_size=1, task_id=2)) + assert model.model._current_task_id == Task.SEPARATION - embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None)) - embeddings_2 = np.stack(embeddings_2, axis=0) + _ = list( + model.embed(documents=docs, batch_size=1, parallel=1, task_id=Task.CLASSIFICATION) + ) + assert model.model._current_task_id == 3 - embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0)) - embeddings_3 = np.stack(embeddings_3, axis=0) + _ = list(model.query_embed(query=docs)) + assert model.model._current_task_id == Task.RETRIEVAL_QUERY - assert embeddings.shape[0] == len(docs) and embeddings.shape[-1] == dim - assert np.allclose(embeddings, embeddings_2, atol=1e-4) - assert np.allclose(embeddings, embeddings_3, atol=1e-4) + _ = list(model.passage_embed(texts=docs)) + assert model.model._current_task_id == Task.RETRIEVAL_PASSAGE - delete_model_cache(model.model._model_dir) + if is_ci: + delete_model_cache(model.model._model_dir) @pytest.mark.parametrize( From 91afca785988daceddd3d33b913a8fc767d4deb5 Mon Sep 17 00:00:00 2001 From: George Panchuk Date: Fri, 3 Jan 2025 11:55:20 +0100 Subject: [PATCH 15/21] prefer enums over ints --- tests/test_text_multitask_embeddings.py | 32 +++++++++---------------- 1 file changed, 11 insertions(+), 21 deletions(-) diff --git a/tests/test_text_multitask_embeddings.py b/tests/test_text_multitask_embeddings.py index 53878576..3cf0b1cf 100644 --- a/tests/test_text_multitask_embeddings.py +++ b/tests/test_text_multitask_embeddings.py @@ -11,7 +11,7 @@ CANONICAL_VECTOR_VALUES = { "jinaai/jina-embeddings-v3": [ { - "task_id": 0, + "task_id": Task.RETRIEVAL_QUERY, "vectors": np.array( [ [0.0623, -0.0402, 0.1706, -0.0143, 0.0617], @@ -20,7 +20,7 @@ ), }, { - "task_id": 1, + "task_id": Task.RETRIEVAL_PASSAGE, "vectors": np.array( [ [0.0513, -0.0247, 0.1751, -0.0075, 0.0679], @@ -29,7 +29,7 @@ ), }, { - "task_id": 2, + "task_id": Task.SEPARATION, "vectors": np.array( [ [0.094, -0.1065, 0.1305, 0.0547, 0.0556], @@ -38,7 +38,7 @@ ), }, { - "task_id": 3, + "task_id": Task.CLASSIFICATION, "vectors": np.array( [ [0.0606, -0.0877, 0.1384, 0.0065, 0.0722], @@ -47,7 +47,7 @@ ), }, { - "task_id": 4, + "task_id": Task.TEXT_MATCHING, "vectors": np.array( [ [0.0911, -0.0341, 0.1305, -0.026, 0.0576], @@ -63,7 +63,7 @@ def test_batch_embedding(): is_ci = os.getenv("CI") docs_to_embed = docs * 10 - default_task = 4 + default_task = Task.TEXT_MATCHING for model_desc in TextEmbedding.list_supported_models(): if not is_ci and model_desc["size_in_GB"] > 1: @@ -127,7 +127,7 @@ def test_single_embedding(): def test_single_embedding_query(): is_ci = os.getenv("CI") - task_id = 0 + task_id = Task.RETRIEVAL_QUERY for model_desc in TextEmbedding.list_supported_models(): if not is_ci and model_desc["size_in_GB"] > 1: @@ -159,7 +159,7 @@ def test_single_embedding_query(): def test_single_embedding_passage(): is_ci = os.getenv("CI") - task_id = 1 + task_id = Task.RETRIEVAL_PASSAGE for model_desc in TextEmbedding.list_supported_models(): if not is_ci and model_desc["size_in_GB"] > 1: @@ -202,19 +202,9 @@ def test_task_assignment(): model = TextEmbedding(model_name=model_name) - _ = list(model.embed(documents=docs, batch_size=1, task_id=2)) - assert model.model._current_task_id == Task.SEPARATION - - _ = list( - model.embed(documents=docs, batch_size=1, parallel=1, task_id=Task.CLASSIFICATION) - ) - assert model.model._current_task_id == 3 - - _ = list(model.query_embed(query=docs)) - assert model.model._current_task_id == Task.RETRIEVAL_QUERY - - _ = list(model.passage_embed(texts=docs)) - assert model.model._current_task_id == Task.RETRIEVAL_PASSAGE + for i, task_id in enumerate(Task): + _ = list(model.embed(documents=docs, batch_size=1, task_id=i)) + assert model.model._current_task_id == task_id if is_ci: delete_model_cache(model.model._model_dir) From 3bee8c3bdb2b611a81de10a858301339f5524f06 Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Sun, 5 Jan 2025 22:17:08 +0200 Subject: [PATCH 16/21] tests: Added test for parallel --- tests/test_text_multitask_embeddings.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_text_multitask_embeddings.py b/tests/test_text_multitask_embeddings.py index 3cf0b1cf..e83b9e77 100644 --- a/tests/test_text_multitask_embeddings.py +++ b/tests/test_text_multitask_embeddings.py @@ -189,6 +189,29 @@ def test_single_embedding_passage(): delete_model_cache(model.model._model_dir) +def test_parallel_processing(): + is_ci = os.getenv("CI") + + docs = ["Hello World", "Follow the white rabbit."] * 10 + + model_name = "jinaai/jina-embeddings-v3" + dim = 1024 + + model = TextEmbedding(model_name=model_name) + + embeddings_1 = list(model.embed(docs, batch_size=10, parallel=None)) + embeddings_1 = np.stack(embeddings_1, axis=0) + + embeddings_2 = list(model.embed(docs, batch_size=10, parallel=1)) + embeddings_2 = np.stack(embeddings_2, axis=0) + + assert embeddings_1.shape[0] == len(docs) and embeddings_1.shape[-1] == dim + assert np.allclose(embeddings_1, embeddings_2, atol=1e-4) + + if is_ci: + delete_model_cache(model.model._model_dir) + + def test_task_assignment(): is_ci = os.getenv("CI") From a4c94990ae9255a35136426f323086608aab7a17 Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Sun, 5 Jan 2025 22:23:18 +0200 Subject: [PATCH 17/21] improve: Updated model description --- fastembed/text/multitask_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastembed/text/multitask_embedding.py b/fastembed/text/multitask_embedding.py index 31c774bc..bc436461 100644 --- a/fastembed/text/multitask_embedding.py +++ b/fastembed/text/multitask_embedding.py @@ -18,7 +18,7 @@ "classification": 3, "text-matching": 4, }, - "description": "Multi-task, multi-lingual embedding model with Matryoshka architecture", + "description": "Multi-task unimodal (text) embedding model, multi-lingual (~100), 1024 tokens truncation, and 8192 sequence length. Prefixes for queries/documents: not necessary, 2024 year.", "license": "cc-by-nc-4.0", "size_in_GB": 2.29, "sources": { From 2fdee75fb061d64ee327caa8d0f90a45dc939c5e Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Sun, 5 Jan 2025 23:25:49 +0200 Subject: [PATCH 18/21] fix: Fix ci --- tests/test_text_onnx_embeddings.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index ac2c41a3..24801229 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -68,12 +68,18 @@ "jinaai/jina-clip-v1": np.array([-0.0862, -0.0101, -0.0056, 0.0375, -0.0472]), } +MULTI_TASK_MODELS = ["jinaai/jina-embeddings-v3"] + def test_embedding(): is_ci = os.getenv("CI") for model_desc in TextEmbedding.list_supported_models(): - if not is_ci and model_desc["size_in_GB"] > 1: + if ( + not is_ci + and model_desc["size_in_GB"] > 1 + and model_desc["model"] not in MULTI_TASK_MODELS + ): continue dim = model_desc["dim"] From dd6111c9927ac97fc7359ec26a3cd4e774695157 Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Sun, 5 Jan 2025 23:26:58 +0200 Subject: [PATCH 19/21] fix: Fix ci --- tests/test_text_onnx_embeddings.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index 24801229..26b3f7f7 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -75,11 +75,9 @@ def test_embedding(): is_ci = os.getenv("CI") for model_desc in TextEmbedding.list_supported_models(): - if ( - not is_ci - and model_desc["size_in_GB"] > 1 - and model_desc["model"] not in MULTI_TASK_MODELS - ): + if (not is_ci and model_desc["size_in_GB"] > 1) or model_desc[ + "model" + ] in MULTI_TASK_MODELS: continue dim = model_desc["dim"] From 16aebc010294368e475f90f9f8fb97e7c1fb6abe Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Wed, 8 Jan 2025 08:20:45 +0200 Subject: [PATCH 20/21] refactor: Refactor query_embed and passage_embed --- fastembed/text/multitask_embedding.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/fastembed/text/multitask_embedding.py b/fastembed/text/multitask_embedding.py index bc436461..a22c9a42 100644 --- a/fastembed/text/multitask_embedding.py +++ b/fastembed/text/multitask_embedding.py @@ -75,24 +75,11 @@ def embed( def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]: self._current_task_id = self.QUERY_TASK - - if isinstance(query, str): - query = [query] - - if not hasattr(self, "model") or self.model is None: - self.load_onnx_model() - - for text in query: - yield from self._post_process_onnx_output(self.onnx_embed([text])) + yield from super().embed(query, **kwargs) def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]: self._current_task_id = self.PASSAGE_TASK - - if not hasattr(self, "model") or self.model is None: - self.load_onnx_model() - - for text in texts: - yield from self._post_process_onnx_output(self.onnx_embed([text])) + yield from super().embed(texts, **kwargs) class JinaEmbeddingV3Worker(OnnxTextEmbeddingWorker): From a7c6582e1ff891d91a4e77587b2691ebde3d8831 Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Wed, 8 Jan 2025 08:52:54 +0200 Subject: [PATCH 21/21] tests: Added task propagation to parallel --- tests/test_text_multitask_embeddings.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_text_multitask_embeddings.py b/tests/test_text_multitask_embeddings.py index e83b9e77..c86a9cdb 100644 --- a/tests/test_text_multitask_embeddings.py +++ b/tests/test_text_multitask_embeddings.py @@ -199,15 +199,19 @@ def test_parallel_processing(): model = TextEmbedding(model_name=model_name) - embeddings_1 = list(model.embed(docs, batch_size=10, parallel=None)) + task_id = Task.SEPARATION + embeddings_1 = list(model.embed(docs, batch_size=10, parallel=None, task_id=task_id)) embeddings_1 = np.stack(embeddings_1, axis=0) - embeddings_2 = list(model.embed(docs, batch_size=10, parallel=1)) + embeddings_2 = list(model.embed(docs, batch_size=10, parallel=1, task_id=task_id)) embeddings_2 = np.stack(embeddings_2, axis=0) assert embeddings_1.shape[0] == len(docs) and embeddings_1.shape[-1] == dim assert np.allclose(embeddings_1, embeddings_2, atol=1e-4) + canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"] + assert np.allclose(embeddings_2[:2, : canonical_vector.shape[1]], canonical_vector, atol=1e-4) + if is_ci: delete_model_cache(model.model._model_dir)