From bbad1ba7b540d3482ce6829e2391fa89cb556078 Mon Sep 17 00:00:00 2001 From: Aaron Ji <127167174+DresAaron@users.noreply.github.com> Date: Thu, 19 Sep 2024 01:39:08 +0800 Subject: [PATCH] feat: update JinaEmbedding for v3 release (#15971) --- .../embeddings/jinaai_embeddings.ipynb | 71 +++++++++++++-- .../llama_index/embeddings/jinaai/base.py | 87 +++++++++++++++---- .../pyproject.toml | 2 +- 3 files changed, 136 insertions(+), 24 deletions(-) diff --git a/docs/docs/examples/embeddings/jinaai_embeddings.ipynb b/docs/docs/examples/embeddings/jinaai_embeddings.ipynb index 7bd4c21984fd1..6403f3a3956b1 100644 --- a/docs/docs/examples/embeddings/jinaai_embeddings.ipynb +++ b/docs/docs/examples/embeddings/jinaai_embeddings.ipynb @@ -90,7 +90,48 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "You can encode your text and your queries using the JinaEmbedding class" + "You can encode your text and your queries using the JinaEmbedding class. Jina offers a range of models adaptable to various use cases.\n", + "\n", + "| Model | Dimension | Language | MRL (matryoshka) | Context |\n", + "|:----------------------:|:---------:|:---------:|:-----------:|:---------:|\n", + "| jina-embeddings-v3 | 1024 | Multilingual (89 languages) | Yes | 8192 |\n", + "| jina-embeddings-v2-base-en | 768 | English | No | 8192 | \n", + "| jina-embeddings-v2-base-de | 768 | German & English | No | 8192 | \n", + "| jina-embeddings-v2-base-es | 768 | Spanish & English | No | 8192 | \n", + "| jina-embeddings-v2-base-zh | 768 | Chinese & English | No | 8192 | \n", + "\n", + "**Recommended Model: jina-embeddings-v3 :**\n", + "\n", + "We recommend `jina-embeddings-v3` as the latest and most performant embedding model from Jina AI. This model features 5 task-specific adapters trained on top of its backbone, optimizing various embedding use cases.\n", + "\n", + "By default `JinaEmbedding` class uses `jina-embeddings-v3`. On top of the backbone, `jina-embeddings-v3` has been trained with 5 task-specific adapters for different embedding uses.\n", + "\n", + "**Task-Specific Adapters:**\n", + "\n", + "Include `task` in your request to optimize your downstream application:\n", + "\n", + "+ **retrieval.query**: Used to encode user queries or questions in retrieval tasks.\n", + "+ **retrieval.passage**: Used to encode large documents in retrieval tasks at indexing time.\n", + "+ **classification**: Used to encode text for text classification tasks.\n", + "+ **text-matching**: Used to encode text for similarity matching, such as measuring similarity between two sentences.\n", + "+ **separation**: Used for clustering or reranking tasks.\n", + "\n", + "\n", + "**Matryoshka Representation Learning**:\n", + "\n", + "`jina-embeddings-v3` supports Matryoshka Representation Learning, allowing users to control the embedding dimension with minimal performance loss. \n", + "Include `dimensions` in your request to select the desired dimension. \n", + "By default, **dimensions** is set to 1024, and a number between 256 and 1024 is recommended. \n", + "You can reference the table below for hints on dimension vs. performance:\n", + "\n", + "\n", + "| Dimension | 32 | 64 | 128 | 256 | 512 | 768 | 1024 | \n", + "|:----------------------:|:---------:|:---------:|:-----------:|:---------:|:----------:|:---------:|:---------:|\n", + "| Average Retrieval Performance (nDCG@10) | 52.54 | 58.54 | 61.64 | 62.72 | 63.16 | 63.3 | 63.35 | \n", + "\n", + "**Late Chunking in Long-Context Embedding Models**\n", + "\n", + "`jina-embeddings-v3` supports [Late Chunking](https://jina.ai/news/late-chunking-in-long-context-embedding-models/), the technique to leverage the model's long-context capabilities for generating contextual chunk embeddings. Include `late_chunking=True` in your request to enable contextual chunked representation. When set to true, Jina AI API will concatenate all sentences in the input field and feed them as a single string to the model. Internally, the model embeds this long concatenated string and then performs late chunking, returning a list of embeddings that matches the size of the input list. " ] }, { @@ -101,16 +142,30 @@ "source": [ "from llama_index.embeddings.jinaai import JinaEmbedding\n", "\n", - "embed_model = JinaEmbedding(\n", + "text_embed_model = JinaEmbedding(\n", " api_key=jinaai_api_key,\n", - " model=\"jina-embeddings-v2-base-en\",\n", + " model=\"jina-embeddings-v3\",\n", + " # choose `retrieval.passage` to get passage embeddings\n", + " task=\"retrieval.passage\",\n", ")\n", "\n", - "embeddings = embed_model.get_text_embedding(\"This is the text to embed\")\n", + "embeddings = text_embed_model.get_text_embedding(\"This is the text to embed\")\n", "print(\"Text dim:\", len(embeddings))\n", "print(\"Text embed:\", embeddings[:5])\n", "\n", - "embeddings = embed_model.get_query_embedding(\"This is the query to embed\")\n", + "query_embed_model = JinaEmbedding(\n", + " api_key=jinaai_api_key,\n", + " model=\"jina-embeddings-v3\",\n", + " # choose `retrieval.query` to get query embeddings, or choose your desired task type\n", + " task=\"retrieval.query\",\n", + " # `dimensions` allows users to control the embedding dimension with minimal performance loss. by default it is 1024.\n", + " # A number between 256 and 1024 is recommended.\n", + " dimensions=512,\n", + ")\n", + "\n", + "embeddings = query_embed_model.get_query_embedding(\n", + " \"This is the query to embed\"\n", + ")\n", "print(\"Query dim:\", len(embeddings))\n", "print(\"Query embed:\", embeddings[:5])" ] @@ -190,8 +245,9 @@ "source": [ "embed_model = JinaEmbedding(\n", " api_key=jinaai_api_key,\n", - " model=\"jina-embeddings-v2-base-en\",\n", + " model=\"jina-embeddings-v3\",\n", " embed_batch_size=16,\n", + " task=\"retrieval.passage\",\n", ")\n", "\n", "embeddings = embed_model.get_text_embedding_batch(\n", @@ -290,8 +346,9 @@ "llm = OpenAI(api_key=your_openai_key)\n", "embed_model = JinaEmbedding(\n", " api_key=jinaai_api_key,\n", - " model=\"jina-embeddings-v2-base-en\",\n", + " model=\"jina-embeddings-v3\",\n", " embed_batch_size=16,\n", + " task=\"retrieval.passage\",\n", ")\n", "\n", "index = VectorStoreIndex.from_documents(\n", diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-jinaai/llama_index/embeddings/jinaai/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-jinaai/llama_index/embeddings/jinaai/base.py index 89fd6fe83476f..dad5ad21121c0 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-jinaai/llama_index/embeddings/jinaai/base.py +++ b/llama-index-integrations/embeddings/llama-index-embeddings-jinaai/llama_index/embeddings/jinaai/base.py @@ -24,7 +24,7 @@ class _JinaAPICaller: def __init__( self, - model: str = "jina-embeddings-v2-base-en", + model: str = "jina-embeddings-v3", base_url: str = DEFAULT_JINA_AI_API_URL, api_key: Optional[str] = None, **kwargs: Any, @@ -37,12 +37,31 @@ def __init__( {"Authorization": f"Bearer {api_key}", "Accept-Encoding": "identity"} ) - def get_embeddings(self, input, encoding_type: str = "float") -> List[List[float]]: + def get_embeddings( + self, + input, + encoding_type: str = "float", + task: Optional[str] = None, + dimensions: Optional[int] = None, + late_chunking: Optional[bool] = None, + ) -> List[List[float]]: """Get embeddings.""" # Call Jina AI Embedding API + input_json = { + "input": input, + "model": self.model, + "encoding_type": encoding_type, + } + if task is not None: + input_json["task"] = task + if dimensions is not None: + input_json["dimensions"] = dimensions + if late_chunking is not None: + input_json["late_chunking"] = late_chunking + resp = self._session.post( # type: ignore self.api_url, - json={"input": input, "model": self.model, "encoding_type": encoding_type}, + json=input_json, ).json() if "data" not in resp: raise RuntimeError(resp["detail"]) @@ -68,7 +87,12 @@ def get_embeddings(self, input, encoding_type: str = "float") -> List[List[float return [result["embedding"] for result in sorted_embeddings] async def aget_embeddings( - self, input, encoding_type: str = "float" + self, + input, + encoding_type: str = "float", + task: Optional[str] = None, + dimensions: Optional[int] = None, + late_chunking: Optional[bool] = None, ) -> List[List[float]]: """Asynchronously get text embeddings.""" import aiohttp @@ -78,13 +102,21 @@ async def aget_embeddings( "Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity", } + input_json = { + "input": input, + "model": self.model, + "encoding_type": encoding_type, + } + if task is not None: + input_json["task"] = task + if dimensions is not None: + input_json["dimensions"] = dimensions + if late_chunking is not None: + input_json["late_chunking"] = late_chunking + async with session.post( self.api_url, - json={ - "input": input, - "model": self.model, - "encoding_type": encoding_type, - }, + json=input_json, headers=headers, ) as response: resp = await response.json() @@ -130,27 +162,31 @@ class JinaEmbedding(MultiModalEmbedding): Args: model (str): Model for embedding. - Defaults to `jina-embeddings-v2-base-en` + Defaults to `jina-embeddings-v3` """ api_key: Optional[str] = Field(default=None, description="The JinaAI API key.") model: str = Field( - default="jina-embeddings-v2-base-en", + default="jina-embeddings-v3", description="The model to use when calling Jina AI API", ) _encoding_queries: str = PrivateAttr() _encoding_documents: str = PrivateAttr() + _task: str = PrivateAttr() _api: Any = PrivateAttr() def __init__( self, - model: str = "jina-embeddings-v2-base-en", + model: str = "jina-embeddings-v3", embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE, api_key: Optional[str] = None, callback_manager: Optional[CallbackManager] = None, encoding_queries: Optional[str] = None, encoding_documents: Optional[str] = None, + task: Optional[str] = None, + dimensions: Optional[int] = None, + late_chunking: Optional[bool] = None, **kwargs: Any, ) -> None: super().__init__( @@ -162,6 +198,9 @@ def __init__( ) self._encoding_queries = encoding_queries or "float" self._encoding_documents = encoding_documents or "float" + self._task = task + self._dimensions = dimensions + self._late_chunking = late_chunking assert ( self._encoding_documents in VALID_ENCODING @@ -179,13 +218,21 @@ def class_name(cls) -> str: def _get_query_embedding(self, query: str) -> List[float]: """Get query embedding.""" return self._api.get_embeddings( - input=[query], encoding_type=self._encoding_queries + input=[query], + encoding_type=self._encoding_queries, + task=self._task, + dimensions=self._dimensions, + late_chunking=self._late_chunking, )[0] async def _aget_query_embedding(self, query: str) -> List[float]: """The asynchronous version of _get_query_embedding.""" result = await self._api.aget_embeddings( - input=[query], encoding_type=self._encoding_queries + input=[query], + encoding_type=self._encoding_queries, + task=self._task, + dimensions=self._dimensions, + late_chunking=self._late_chunking, ) return result[0] @@ -200,7 +247,11 @@ async def _aget_text_embedding(self, text: str) -> List[float]: def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]: return self._api.get_embeddings( - input=texts, encoding_type=self._encoding_documents + input=texts, + encoding_type=self._encoding_documents, + task=self._task, + dimensions=self._dimensions, + late_chunking=self._late_chunking, ) async def _aget_text_embeddings( @@ -208,7 +259,11 @@ async def _aget_text_embeddings( texts: List[str], ) -> List[List[float]]: return await self._api.aget_embeddings( - input=texts, encoding_type=self._encoding_documents + input=texts, + encoding_type=self._encoding_documents, + task=self._task, + dimensions=self._dimensions, + late_chunking=self._late_chunking, ) def _get_image_embedding(self, img_file_path: ImageType) -> List[float]: diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-jinaai/pyproject.toml b/llama-index-integrations/embeddings/llama-index-embeddings-jinaai/pyproject.toml index bd1b165f24d65..1256be42ba531 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-jinaai/pyproject.toml +++ b/llama-index-integrations/embeddings/llama-index-embeddings-jinaai/pyproject.toml @@ -27,7 +27,7 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-embeddings-jinaai" readme = "README.md" -version = "0.3.0" +version = "0.3.1" [tool.poetry.dependencies] python = ">=3.8.1,<4.0"