1- from chromadb .api .types import Embeddings , Documents , EmbeddingFunction , Space
1+ from chromadb .api .types import (
2+ Embeddings ,
3+ Documents ,
4+ EmbeddingFunction ,
5+ Space ,
6+ )
27from chromadb .utils .embedding_functions .schemas import validate_config_schema
3- from typing import List , Dict , Any , Union , Optional
8+ from typing import List , Dict , Any , Union , Optional , TypedDict
49import os
510import numpy as np
611import warnings
712
813
14+ class JinaQueryConfig (TypedDict ):
15+ task : str
16+
17+
918class JinaEmbeddingFunction (EmbeddingFunction [Documents ]):
1019 """
1120 This class is used to get embeddings for a list of texts using the Jina AI API.
@@ -23,6 +32,7 @@ def __init__(
2332 dimensions : Optional [int ] = None ,
2433 embedding_type : Optional [str ] = None ,
2534 normalized : Optional [bool ] = None ,
35+ query_config : Optional [JinaQueryConfig ] = None ,
2636 ):
2737 """
2838 Initialize the JinaEmbeddingFunction.
@@ -74,57 +84,49 @@ def __init__(
7484 self .dimensions = dimensions
7585 self .embedding_type = embedding_type
7686 self .normalized = normalized
87+ self .query_config = query_config
7788
7889 self ._api_url = "https://api.jina.ai/v1/embeddings"
7990 self ._session = httpx .Client ()
8091 self ._session .headers .update (
8192 {"Authorization" : f"Bearer { self .api_key } " , "Accept-Encoding" : "identity" }
8293 )
8394
84- def __call__ (self , input : Documents ) -> Embeddings :
85- """
86- Get the embeddings for a list of texts.
87-
88- Args:
89- input (Documents): A list of texts to get embeddings for.
90-
91- Returns:
92- Embeddings: The embeddings for the texts.
93-
94- Example:
95- >>> jina_ai_fn = JinaEmbeddingFunction(api_key_env_var="CHROMA_JINA_API_KEY")
96- >>> input = ["Hello, world!", "How are you?"]
97- """
98- # Jina AI only works with text documents
99- if not all (isinstance (item , str ) for item in input ):
100- raise ValueError ("Jina AI only supports text documents, not images" )
101-
95+ def _build_payload (self , input : Documents , is_query : bool ) -> Dict [str , Any ]:
10296 payload : Dict [str , Any ] = {
10397 "input" : input ,
10498 "model" : self .model_name ,
10599 }
106100
107101 if self .task is not None :
108102 payload ["task" ] = self .task
109-
110103 if self .late_chunking is not None :
111104 payload ["late_chunking" ] = self .late_chunking
112-
113105 if self .truncate is not None :
114106 payload ["truncate" ] = self .truncate
115-
116107 if self .dimensions is not None :
117108 payload ["dimensions" ] = self .dimensions
118-
119109 if self .embedding_type is not None :
120110 payload ["embedding_type" ] = self .embedding_type
121-
122111 if self .normalized is not None :
123112 payload ["normalized" ] = self .normalized
124113
125- # Call Jina AI Embedding API
126- resp = self ._session .post (self ._api_url , json = payload ).json ()
114+ if is_query and self .query_config is not None :
115+ for key , value in self .query_config .items ():
116+ payload [key ] = value
117+
118+ return payload
119+
120+ def _convert_resp (self , resp : Any ) -> Embeddings :
121+ """
122+ Convert the response from the Jina AI API to a list of numpy arrays.
123+
124+ Args:
125+ resp (Any): The response from the Jina AI API.
127126
127+ Returns:
128+ Embeddings: A list of numpy arrays representing the embeddings.
129+ """
128130 if "data" not in resp :
129131 raise RuntimeError (resp .get ("detail" , "Unknown error" ))
130132
@@ -139,6 +141,43 @@ def __call__(self, input: Documents) -> Embeddings:
139141 for result in sorted_embeddings
140142 ]
141143
144+ def __call__ (self , input : Documents ) -> Embeddings :
145+ """
146+ Get the embeddings for a list of texts.
147+
148+ Args:
149+ input (Documents): A list of texts to get embeddings for.
150+
151+ Returns:
152+ Embeddings: The embeddings for the texts.
153+
154+ Example:
155+ >>> jina_ai_fn = JinaEmbeddingFunction(api_key_env_var="CHROMA_JINA_API_KEY")
156+ >>> input = ["Hello, world!", "How are you?"]
157+ """
158+ # Jina AI only works with text documents
159+ if not all (isinstance (item , str ) for item in input ):
160+ raise ValueError ("Jina AI only supports text documents, not images" )
161+
162+ payload = self ._build_payload (input , is_query = False )
163+
164+ # Call Jina AI Embedding API
165+ resp = self ._session .post (self ._api_url , json = payload ).json ()
166+
167+ return self ._convert_resp (resp )
168+
169+ def embed_query (self , input : Documents ) -> Embeddings :
170+ # Jina AI only works with text documents
171+ if not all (isinstance (item , str ) for item in input ):
172+ raise ValueError ("Jina AI only supports text documents, not images" )
173+
174+ payload = self ._build_payload (input , is_query = True )
175+
176+ # Call Jina AI Embedding API
177+ resp = self ._session .post (self ._api_url , json = payload ).json ()
178+
179+ return self ._convert_resp (resp )
180+
142181 @staticmethod
143182 def name () -> str :
144183 return "jina"
@@ -159,6 +198,7 @@ def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
159198 dimensions = config .get ("dimensions" )
160199 embedding_type = config .get ("embedding_type" )
161200 normalized = config .get ("normalized" )
201+ query_config = config .get ("query_config" )
162202
163203 if api_key_env_var is None or model_name is None :
164204 assert False , "This code should not be reached" # this is for type checking
@@ -172,6 +212,7 @@ def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
172212 dimensions = dimensions ,
173213 embedding_type = embedding_type ,
174214 normalized = normalized ,
215+ query_config = query_config ,
175216 )
176217
177218 def get_config (self ) -> Dict [str , Any ]:
@@ -184,6 +225,7 @@ def get_config(self) -> Dict[str, Any]:
184225 "dimensions" : self .dimensions ,
185226 "embedding_type" : self .embedding_type ,
186227 "normalized" : self .normalized ,
228+ "query_config" : self .query_config ,
187229 }
188230
189231 def validate_config_update (
0 commit comments