feat: update JinaEmbedding for v3 release (#15971)
DresAaron authored Sep 18, 2024
1 parent 4ef2cdc commit bbad1ba
Showing 3 changed files with 136 additions and 24 deletions.
71 changes: 64 additions & 7 deletions docs/docs/examples/embeddings/jinaai_embeddings.ipynb
@@ -90,7 +90,48 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"You can encode your text and your queries using the JinaEmbedding class"
"You can encode your text and your queries using the JinaEmbedding class. Jina offers a range of models adaptable to various use cases.\n",
"\n",
"| Model | Dimension | Language | MRL (matryoshka) | Context |\n",
"|:----------------------:|:---------:|:---------:|:-----------:|:---------:|\n",
"| jina-embeddings-v3 | 1024 | Multilingual (89 languages) | Yes | 8192 |\n",
"| jina-embeddings-v2-base-en | 768 | English | No | 8192 | \n",
"| jina-embeddings-v2-base-de | 768 | German & English | No | 8192 | \n",
"| jina-embeddings-v2-base-es | 768 | Spanish & English | No | 8192 | \n",
"| jina-embeddings-v2-base-zh | 768 | Chinese & English | No | 8192 | \n",
"\n",
"**Recommended Model: jina-embeddings-v3 :**\n",
"\n",
"We recommend `jina-embeddings-v3` as the latest and most performant embedding model from Jina AI. This model features 5 task-specific adapters trained on top of its backbone, optimizing various embedding use cases.\n",
"\n",
"By default `JinaEmbedding` class uses `jina-embeddings-v3`. On top of the backbone, `jina-embeddings-v3` has been trained with 5 task-specific adapters for different embedding uses.\n",
"\n",
"**Task-Specific Adapters:**\n",
"\n",
"Include `task` in your request to optimize your downstream application:\n",
"\n",
"+ **retrieval.query**: Used to encode user queries or questions in retrieval tasks.\n",
"+ **retrieval.passage**: Used to encode large documents in retrieval tasks at indexing time.\n",
"+ **classification**: Used to encode text for text classification tasks.\n",
"+ **text-matching**: Used to encode text for similarity matching, such as measuring similarity between two sentences.\n",
"+ **separation**: Used for clustering or reranking tasks.\n",
"\n",
"\n",
"**Matryoshka Representation Learning**:\n",
"\n",
"`jina-embeddings-v3` supports Matryoshka Representation Learning, allowing users to control the embedding dimension with minimal performance loss. \n",
"Include `dimensions` in your request to select the desired dimension. \n",
"By default, **dimensions** is set to 1024, and a number between 256 and 1024 is recommended. \n",
"You can reference the table below for hints on dimension vs. performance:\n",
"\n",
"\n",
"| Dimension | 32 | 64 | 128 | 256 | 512 | 768 | 1024 | \n",
"|:----------------------:|:---------:|:---------:|:-----------:|:---------:|:----------:|:---------:|:---------:|\n",
"| Average Retrieval Performance (nDCG@10) | 52.54 | 58.54 | 61.64 | 62.72 | 63.16 | 63.3 | 63.35 | \n",
"\n",
"**Late Chunking in Long-Context Embedding Models**\n",
"\n",
"`jina-embeddings-v3` supports [Late Chunking](https://jina.ai/news/late-chunking-in-long-context-embedding-models/), the technique to leverage the model's long-context capabilities for generating contextual chunk embeddings. Include `late_chunking=True` in your request to enable contextual chunked representation. When set to true, Jina AI API will concatenate all sentences in the input field and feed them as a single string to the model. Internally, the model embeds this long concatenated string and then performs late chunking, returning a list of embeddings that matches the size of the input list. "
]
},
{
@@ -101,16 +142,30 @@
"source": [
"from llama_index.embeddings.jinaai import JinaEmbedding\n",
"\n",
"embed_model = JinaEmbedding(\n",
"text_embed_model = JinaEmbedding(\n",
" api_key=jinaai_api_key,\n",
" model=\"jina-embeddings-v2-base-en\",\n",
" model=\"jina-embeddings-v3\",\n",
" # choose `retrieval.passage` to get passage embeddings\n",
" task=\"retrieval.passage\",\n",
")\n",
"\n",
"embeddings = embed_model.get_text_embedding(\"This is the text to embed\")\n",
"embeddings = text_embed_model.get_text_embedding(\"This is the text to embed\")\n",
"print(\"Text dim:\", len(embeddings))\n",
"print(\"Text embed:\", embeddings[:5])\n",
"\n",
"embeddings = embed_model.get_query_embedding(\"This is the query to embed\")\n",
"query_embed_model = JinaEmbedding(\n",
" api_key=jinaai_api_key,\n",
" model=\"jina-embeddings-v3\",\n",
" # choose `retrieval.query` to get query embeddings, or choose your desired task type\n",
" task=\"retrieval.query\",\n",
" # `dimensions` allows users to control the embedding dimension with minimal performance loss. by default it is 1024.\n",
" # A number between 256 and 1024 is recommended.\n",
" dimensions=512,\n",
")\n",
"\n",
"embeddings = query_embed_model.get_query_embedding(\n",
" \"This is the query to embed\"\n",
")\n",
"print(\"Query dim:\", len(embeddings))\n",
"print(\"Query embed:\", embeddings[:5])"
]
@@ -190,8 +245,9 @@
"source": [
"embed_model = JinaEmbedding(\n",
" api_key=jinaai_api_key,\n",
" model=\"jina-embeddings-v2-base-en\",\n",
" model=\"jina-embeddings-v3\",\n",
" embed_batch_size=16,\n",
" task=\"retrieval.passage\",\n",
")\n",
"\n",
"embeddings = embed_model.get_text_embedding_batch(\n",
@@ -290,8 +346,9 @@
"llm = OpenAI(api_key=your_openai_key)\n",
"embed_model = JinaEmbedding(\n",
" api_key=jinaai_api_key,\n",
" model=\"jina-embeddings-v2-base-en\",\n",
" model=\"jina-embeddings-v3\",\n",
" embed_batch_size=16,\n",
" task=\"retrieval.passage\",\n",
")\n",
"\n",
"index = VectorStoreIndex.from_documents(\n",
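The `dimensions` parameter in the notebook above relies on Matryoshka Representation Learning: the leading components of the embedding carry most of the information, so a vector can be cut short with little retrieval loss. As a rough illustration of the idea only (the Jina AI API truncates server-side when you pass `dimensions`; this numpy sketch is not Jina's implementation), truncation amounts to slicing the leading components and re-normalizing before computing cosine similarity:

```python
import numpy as np


def truncate_embedding(vec: np.ndarray, dim: int) -> np.ndarray:
    """Keep the first `dim` components and re-normalize to unit length,
    mirroring how Matryoshka-style embeddings are typically truncated."""
    sliced = vec[:dim]
    return sliced / np.linalg.norm(sliced)


def cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


# Two toy 1024-d vectors standing in for real embeddings.
rng = np.random.default_rng(0)
a, b = rng.normal(size=1024), rng.normal(size=1024)

full = cosine(a, b)
reduced = cosine(truncate_embedding(a, 256), truncate_embedding(b, 256))
print(f"cosine at 1024 dims: {full:.4f}, cosine at 256 dims: {reduced:.4f}")
```

This is why the nDCG@10 table in the notebook degrades so gracefully from 1024 down to 256 dimensions: the similarity structure of the truncated vectors stays close to that of the full vectors.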
@@ -24,7 +24,7 @@
class _JinaAPICaller:
def __init__(
self,
model: str = "jina-embeddings-v2-base-en",
model: str = "jina-embeddings-v3",
base_url: str = DEFAULT_JINA_AI_API_URL,
api_key: Optional[str] = None,
**kwargs: Any,
@@ -37,12 +37,31 @@ def __init__(
{"Authorization": f"Bearer {api_key}", "Accept-Encoding": "identity"}
)

-    def get_embeddings(self, input, encoding_type: str = "float") -> List[List[float]]:
+    def get_embeddings(
+        self,
+        input,
+        encoding_type: str = "float",
+        task: Optional[str] = None,
+        dimensions: Optional[int] = None,
+        late_chunking: Optional[bool] = None,
+    ) -> List[List[float]]:
"""Get embeddings."""
# Call Jina AI Embedding API
input_json = {
"input": input,
"model": self.model,
"encoding_type": encoding_type,
}
if task is not None:
input_json["task"] = task
if dimensions is not None:
input_json["dimensions"] = dimensions
if late_chunking is not None:
input_json["late_chunking"] = late_chunking

resp = self._session.post( # type: ignore
self.api_url,
json={"input": input, "model": self.model, "encoding_type": encoding_type},
json=input_json,
).json()
if "data" not in resp:
raise RuntimeError(resp["detail"])
@@ -68,7 +87,12 @@ def get_embeddings(self, input, encoding_type: str = "float") -> List[List[float
return [result["embedding"] for result in sorted_embeddings]

async def aget_embeddings(
-        self, input, encoding_type: str = "float"
+        self,
+        input,
+        encoding_type: str = "float",
+        task: Optional[str] = None,
+        dimensions: Optional[int] = None,
+        late_chunking: Optional[bool] = None,
) -> List[List[float]]:
"""Asynchronously get text embeddings."""
import aiohttp
@@ -78,13 +102,21 @@ async def aget_embeddings(
"Authorization": f"Bearer {self.api_key}",
"Accept-Encoding": "identity",
}
input_json = {
"input": input,
"model": self.model,
"encoding_type": encoding_type,
}
if task is not None:
input_json["task"] = task
if dimensions is not None:
input_json["dimensions"] = dimensions
if late_chunking is not None:
input_json["late_chunking"] = late_chunking

async with session.post(
self.api_url,
-            json={
-                "input": input,
-                "model": self.model,
-                "encoding_type": encoding_type,
-            },
+            json=input_json,
headers=headers,
) as response:
resp = await response.json()
@@ -130,27 +162,31 @@ class JinaEmbedding(MultiModalEmbedding):
Args:
model (str): Model for embedding.
-            Defaults to `jina-embeddings-v2-base-en`
+            Defaults to `jina-embeddings-v3`
"""

api_key: Optional[str] = Field(default=None, description="The JinaAI API key.")
model: str = Field(
default="jina-embeddings-v2-base-en",
default="jina-embeddings-v3",
description="The model to use when calling Jina AI API",
)

_encoding_queries: str = PrivateAttr()
_encoding_documents: str = PrivateAttr()
    _task: Optional[str] = PrivateAttr()
    _dimensions: Optional[int] = PrivateAttr()
    _late_chunking: Optional[bool] = PrivateAttr()
_api: Any = PrivateAttr()

def __init__(
self,
model: str = "jina-embeddings-v2-base-en",
model: str = "jina-embeddings-v3",
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
api_key: Optional[str] = None,
callback_manager: Optional[CallbackManager] = None,
encoding_queries: Optional[str] = None,
encoding_documents: Optional[str] = None,
task: Optional[str] = None,
dimensions: Optional[int] = None,
late_chunking: Optional[bool] = None,
**kwargs: Any,
) -> None:
super().__init__(
@@ -162,6 +198,9 @@ def __init__(
)
self._encoding_queries = encoding_queries or "float"
self._encoding_documents = encoding_documents or "float"
self._task = task
self._dimensions = dimensions
self._late_chunking = late_chunking

assert (
self._encoding_documents in VALID_ENCODING
@@ -179,13 +218,21 @@ def class_name(cls) -> str:
def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
return self._api.get_embeddings(
-            input=[query], encoding_type=self._encoding_queries
+            input=[query],
+            encoding_type=self._encoding_queries,
+            task=self._task,
+            dimensions=self._dimensions,
+            late_chunking=self._late_chunking,
)[0]

async def _aget_query_embedding(self, query: str) -> List[float]:
"""The asynchronous version of _get_query_embedding."""
result = await self._api.aget_embeddings(
-            input=[query], encoding_type=self._encoding_queries
+            input=[query],
+            encoding_type=self._encoding_queries,
+            task=self._task,
+            dimensions=self._dimensions,
+            late_chunking=self._late_chunking,
)
return result[0]

@@ -200,15 +247,23 @@ async def _aget_text_embedding(self, text: str) -> List[float]:

def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
return self._api.get_embeddings(
-            input=texts, encoding_type=self._encoding_documents
+            input=texts,
+            encoding_type=self._encoding_documents,
+            task=self._task,
+            dimensions=self._dimensions,
+            late_chunking=self._late_chunking,
)

async def _aget_text_embeddings(
self,
texts: List[str],
) -> List[List[float]]:
return await self._api.aget_embeddings(
-            input=texts, encoding_type=self._encoding_documents
+            input=texts,
+            encoding_type=self._encoding_documents,
+            task=self._task,
+            dimensions=self._dimensions,
+            late_chunking=self._late_chunking,
)

def _get_image_embedding(self, img_file_path: ImageType) -> List[float]:
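The core of this patch is that `task`, `dimensions`, and `late_chunking` are only included in the request body when explicitly set, so omitted parameters fall back to the API's server-side defaults. That logic can be sketched as a standalone helper (`build_embedding_payload` is a hypothetical name for illustration; the real `_JinaAPICaller` builds the dict inline):

```python
from typing import Any, Dict, List, Optional


def build_embedding_payload(
    input: List[str],
    model: str = "jina-embeddings-v3",
    encoding_type: str = "float",
    task: Optional[str] = None,
    dimensions: Optional[int] = None,
    late_chunking: Optional[bool] = None,
) -> Dict[str, Any]:
    """Build the JSON request body the way the patched caller does:
    optional fields are sent only when the user explicitly set them."""
    payload: Dict[str, Any] = {
        "input": input,
        "model": model,
        "encoding_type": encoding_type,
    }
    if task is not None:
        payload["task"] = task
    if dimensions is not None:
        payload["dimensions"] = dimensions
    if late_chunking is not None:
        payload["late_chunking"] = late_chunking
    return payload


print(build_embedding_payload(["hello"], task="retrieval.query", dimensions=512))
```

Using `is not None` rather than truthiness matters for `late_chunking`: it lets a caller send an explicit `false` to the API, which a plain `if late_chunking:` check would silently drop.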
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-embeddings-jinaai"
readme = "README.md"
version = "0.3.0"
version = "0.3.1"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
