Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

new: Added jina embedding v3 #428

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
8b4eb26
new: Added jina embedding v3
hh-space-invader Dec 19, 2024
64127fc
refactor: Changed dim to int value
hh-space-invader Dec 23, 2024
e48f647
new: Updated notice
hh-space-invader Dec 23, 2024
eb475d5
new: Extended text embedding with query embed and passage embed
hh-space-invader Dec 23, 2024
1650252
fix: Fix lazy load in query and passage embed
hh-space-invader Dec 23, 2024
197b381
tests: Added test for multitask embeddings
hh-space-invader Dec 23, 2024
c917201
nit: Remove cache dir from tests
hh-space-invader Dec 23, 2024
1ed62e9
tests: Updated tests
hh-space-invader Dec 23, 2024
eda3bae
improve: Improve task selection
hh-space-invader Dec 30, 2024
a31461b
fix: Fix ci
hh-space-invader Dec 30, 2024
2f8290d
fix: Update fastembed/text/multitask_embedding.py
hh-space-invader Jan 3, 2025
38ad796
Update fastembed/text/multitask_embedding.py
hh-space-invader Jan 3, 2025
b33301d
fix: Pass task id using kwargs to parallel processor
hh-space-invader Jan 3, 2025
89cf732
tests: Added test for task assignment
hh-space-invader Jan 3, 2025
91afca7
prefer enums over ints
joein Jan 3, 2025
3bee8c3
tests: Added test for parallel
hh-space-invader Jan 5, 2025
a4c9499
improve: Updated model description
hh-space-invader Jan 5, 2025
2fdee75
fix: Fix ci
hh-space-invader Jan 5, 2025
dd6111c
fix: Fix ci
hh-space-invader Jan 5, 2025
16aebc0
refactor: Refactor query_embed and passage_embed
hh-space-invader Jan 8, 2025
a7c6582
tests: Added task propagation to parallel
hh-space-invader Jan 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NOTICE
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ This distribution includes the following Jina AI models, each with its respectiv
- License: cc-by-nc-4.0
- jinaai/jina-reranker-v2-base-multilingual
- License: cc-by-nc-4.0
- jinaai/jina-embeddings-v3
- License: cc-by-nc-4.0

These models are developed by Jina (https://jina.ai/) and are subject to Jina AI's licensing terms.

Expand Down
112 changes: 112 additions & 0 deletions fastembed/text/multitask_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from enum import Enum
from typing import Any, Type, Iterable, Union, Optional

import numpy as np

from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
from fastembed.text.onnx_embedding import OnnxTextEmbeddingWorker
from fastembed.text.onnx_text_model import TextEmbeddingWorker

supported_multitask_models = [
{
"model": "jinaai/jina-embeddings-v3",
"dim": 1024,
"tasks": {
"retrieval.query": 0,
"retrieval.passage": 1,
"separation": 2,
"classification": 3,
"text-matching": 4,
},
"description": "Multi-task, multi-lingual embedding model with Matryoshka architecture",
hh-space-invader marked this conversation as resolved.
Show resolved Hide resolved
"license": "cc-by-nc-4.0",
"size_in_GB": 2.29,
"sources": {
"hf": "jinaai/jina-embeddings-v3",
},
"model_file": "onnx/model.onnx",
"additional_files": ["onnx/model.onnx_data"],
},
]


class Task(int, Enum):
RETRIEVAL_QUERY = 0
RETRIEVAL_PASSAGE = 1
SEPARATION = 2
CLASSIFICATION = 3
TEXT_MATCHING = 4


class JinaEmbeddingV3(PooledNormalizedEmbedding):
DEFAULT_TASK = Task.TEXT_MATCHING
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that it might be inconvenient to have a default task different from passage embed task, could we please make it the same?

PASSAGE_TASK = Task.RETRIEVAL_PASSAGE
QUERY_TASK = Task.RETRIEVAL_QUERY

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._current_task_id = self.DEFAULT_TASK

@classmethod
def _get_worker_class(cls) -> Type["TextEmbeddingWorker"]:
return JinaEmbeddingV3Worker

@classmethod
def list_supported_models(cls) -> list[dict[str, Any]]:
return supported_multitask_models

def _preprocess_onnx_input(
self, onnx_input: dict[str, np.ndarray], **kwargs
) -> dict[str, np.ndarray]:
onnx_input["task_id"] = np.array(self._current_task_id, dtype=np.int64)
return onnx_input

def embed(
self,
documents: Union[str, Iterable[str]],
batch_size: int = 256,
parallel: Optional[int] = None,
task_id: int = DEFAULT_TASK,
**kwargs,
) -> Iterable[np.ndarray]:
self._current_task_id = task_id
kwargs["task_id"] = task_id
yield from super().embed(documents, batch_size, parallel, **kwargs)

def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]:
self._current_task_id = self.QUERY_TASK

if isinstance(query, str):
query = [query]

if not hasattr(self, "model") or self.model is None:
self.load_onnx_model()

for text in query:
yield from self._post_process_onnx_output(self.onnx_embed([text]))
hh-space-invader marked this conversation as resolved.
Show resolved Hide resolved

def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
self._current_task_id = self.PASSAGE_TASK

if not hasattr(self, "model") or self.model is None:
self.load_onnx_model()

for text in texts:
yield from self._post_process_onnx_output(self.onnx_embed([text]))


class JinaEmbeddingV3Worker(OnnxTextEmbeddingWorker):
def init_embedding(
self,
model_name: str,
cache_dir: str,
**kwargs,
) -> JinaEmbeddingV3:
model = JinaEmbeddingV3(
model_name=model_name,
cache_dir=cache_dir,
threads=1,
**kwargs,
)
model._current_task_id = kwargs["task_id"]
return model
29 changes: 29 additions & 0 deletions fastembed/text/text_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fastembed.text.e5_onnx_embedding import E5OnnxEmbedding
from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
from fastembed.text.pooled_embedding import PooledEmbedding
from fastembed.text.multitask_embedding import JinaEmbeddingV3
from fastembed.text.onnx_embedding import OnnxTextEmbedding
from fastembed.text.text_embedding_base import TextEmbeddingBase

Expand All @@ -18,6 +19,7 @@ class TextEmbedding(TextEmbeddingBase):
CLIPOnnxEmbedding,
PooledNormalizedEmbedding,
PooledEmbedding,
JinaEmbeddingV3,
]

@classmethod
Expand Down Expand Up @@ -105,3 +107,30 @@ def embed(
List of embeddings, one per document
"""
yield from self.model.embed(documents, batch_size, parallel, **kwargs)

def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]:
"""
Embeds queries

Args:
query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries.

Returns:
Iterable[np.ndarray]: The embeddings.
"""
# This is model-specific, so that different models can have specialized implementations
yield from self.model.query_embed(query, **kwargs)

def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
"""
Embeds a list of text passages into a list of embeddings.

Args:
texts (Iterable[str]): The list of texts to embed.
**kwargs: Additional keyword argument to pass to the embed method.

Yields:
Iterable[SparseEmbedding]: The sparse embeddings.
"""
# This is model-specific, so that different models can have specialized implementations
yield from self.model.passage_embed(texts, **kwargs)
226 changes: 226 additions & 0 deletions tests/test_text_multitask_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
import os

import numpy as np
import pytest

from fastembed import TextEmbedding
from fastembed.text.multitask_embedding import Task
from tests.utils import delete_model_cache


CANONICAL_VECTOR_VALUES = {
"jinaai/jina-embeddings-v3": [
{
"task_id": Task.RETRIEVAL_QUERY,
"vectors": np.array(
[
[0.0623, -0.0402, 0.1706, -0.0143, 0.0617],
[-0.1064, -0.0733, 0.0353, 0.0096, 0.0667],
]
),
},
{
"task_id": Task.RETRIEVAL_PASSAGE,
"vectors": np.array(
[
[0.0513, -0.0247, 0.1751, -0.0075, 0.0679],
[-0.0987, -0.0786, 0.09, 0.0087, 0.0577],
]
),
},
{
"task_id": Task.SEPARATION,
"vectors": np.array(
[
[0.094, -0.1065, 0.1305, 0.0547, 0.0556],
[0.0315, -0.1468, 0.065, 0.0568, 0.0546],
]
),
},
{
"task_id": Task.CLASSIFICATION,
"vectors": np.array(
[
[0.0606, -0.0877, 0.1384, 0.0065, 0.0722],
[-0.0502, -0.119, 0.032, 0.0514, 0.0689],
]
),
},
{
"task_id": Task.TEXT_MATCHING,
"vectors": np.array(
[
[0.0911, -0.0341, 0.1305, -0.026, 0.0576],
[-0.1432, -0.05, 0.0133, 0.0464, 0.0789],
]
),
},
]
}
docs = ["Hello World", "Follow the white rabbit."]


def test_batch_embedding():
is_ci = os.getenv("CI")
docs_to_embed = docs * 10
default_task = Task.TEXT_MATCHING

for model_desc in TextEmbedding.list_supported_models():
if not is_ci and model_desc["size_in_GB"] > 1:
continue

model_name = model_desc["model"]
dim = model_desc["dim"]

if model_name not in CANONICAL_VECTOR_VALUES.keys():
continue

model = TextEmbedding(model_name=model_name)

print(f"evaluating {model_name} default task")

embeddings = list(model.embed(documents=docs_to_embed, batch_size=6))
embeddings = np.stack(embeddings, axis=0)

assert embeddings.shape == (len(docs_to_embed), dim)

canonical_vector = CANONICAL_VECTOR_VALUES[model_name][default_task]["vectors"]
assert np.allclose(
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
), model_desc["model"]

if is_ci:
delete_model_cache(model.model._model_dir)


def test_single_embedding():
is_ci = os.getenv("CI")

for model_desc in TextEmbedding.list_supported_models():
if not is_ci and model_desc["size_in_GB"] > 1:
continue

model_name = model_desc["model"]
dim = model_desc["dim"]

if model_name not in CANONICAL_VECTOR_VALUES.keys():
continue

model = TextEmbedding(model_name=model_name)

for task in CANONICAL_VECTOR_VALUES[model_name]:
print(f"evaluating {model_name} task_id: {task['task_id']}")

embeddings = list(model.embed(documents=docs, task_id=task["task_id"]))
embeddings = np.stack(embeddings, axis=0)

assert embeddings.shape == (len(docs), dim)

canonical_vector = task["vectors"]
assert np.allclose(
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
), model_desc["model"]

if is_ci:
delete_model_cache(model.model._model_dir)
hh-space-invader marked this conversation as resolved.
Show resolved Hide resolved


def test_single_embedding_query():
is_ci = os.getenv("CI")
task_id = Task.RETRIEVAL_QUERY

for model_desc in TextEmbedding.list_supported_models():
if not is_ci and model_desc["size_in_GB"] > 1:
continue

model_name = model_desc["model"]
dim = model_desc["dim"]

if model_name not in CANONICAL_VECTOR_VALUES.keys():
continue

model = TextEmbedding(model_name=model_name)

print(f"evaluating {model_name} query_embed task_id: {task_id}")

embeddings = list(model.query_embed(query=docs))
embeddings = np.stack(embeddings, axis=0)

assert embeddings.shape == (len(docs), dim)

canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"]
assert np.allclose(
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
), model_desc["model"]

if is_ci:
delete_model_cache(model.model._model_dir)


def test_single_embedding_passage():
is_ci = os.getenv("CI")
task_id = Task.RETRIEVAL_PASSAGE

for model_desc in TextEmbedding.list_supported_models():
if not is_ci and model_desc["size_in_GB"] > 1:
continue

model_name = model_desc["model"]
dim = model_desc["dim"]

if model_name not in CANONICAL_VECTOR_VALUES.keys():
continue

model = TextEmbedding(model_name=model_name)

print(f"evaluating {model_name} passage_embed task_id: {task_id}")

embeddings = list(model.passage_embed(texts=docs))
embeddings = np.stack(embeddings, axis=0)

assert embeddings.shape == (len(docs), dim)

canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"]
assert np.allclose(
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
), model_desc["model"]

if is_ci:
delete_model_cache(model.model._model_dir)


def test_task_assignment():
is_ci = os.getenv("CI")

for model_desc in TextEmbedding.list_supported_models():
if not is_ci and model_desc["size_in_GB"] > 1:
continue

model_name = model_desc["model"]
if model_name not in CANONICAL_VECTOR_VALUES.keys():
continue

model = TextEmbedding(model_name=model_name)

for i, task_id in enumerate(Task):
_ = list(model.embed(documents=docs, batch_size=1, task_id=i))
assert model.model._current_task_id == task_id

if is_ci:
delete_model_cache(model.model._model_dir)


@pytest.mark.parametrize(
"model_name",
["jinaai/jina-embeddings-v3"],
)
def test_lazy_load(model_name):
is_ci = os.getenv("CI")
model = TextEmbedding(model_name=model_name, lazy_load=True)
assert not hasattr(model.model, "model")

list(model.embed(docs))
assert hasattr(model.model, "model")

if is_ci:
delete_model_cache(model.model._model_dir)
Loading