1- from chromadb .api .types import Embeddings , Documents , EmbeddingFunction , Space
1+ from chromadb .api .types import (
2+ Embeddings ,
3+ EmbeddingFunction ,
4+ Space ,
5+ Embeddable ,
6+ is_image ,
7+ is_document ,
8+ )
29from chromadb .utils .embedding_functions .schemas import validate_config_schema
3- from typing import List , Dict , Any , Union , Optional
10+ from typing import List , Dict , Any , Union , Optional , TypedDict
411import os
512import numpy as np
613import warnings
14+ import importlib
15+ import base64
16+ import io
717
818
9- class JinaEmbeddingFunction (EmbeddingFunction [Documents ]):
class JinaQueryConfig(TypedDict):
    """
    Request-parameter overrides used when embedding queries.

    Each entry is copied into the request payload by the embedding function
    when it builds a query request, after the instance-level settings.
    """

    # Value sent as the "task" field of the request payload for queries.
    task: str
21+
22+
23+ class JinaEmbeddingFunction (EmbeddingFunction [Embeddable ]):
1024 """
1125 This class is used to get embeddings for a list of texts using the Jina AI API.
1226 It requires an API key and a model name. The default model name is "jina-embeddings-v2-base-en".
@@ -23,6 +37,7 @@ def __init__(
2337 dimensions : Optional [int ] = None ,
2438 embedding_type : Optional [str ] = None ,
2539 normalized : Optional [bool ] = None ,
40+ query_config : Optional [JinaQueryConfig ] = None ,
2641 ):
2742 """
2843 Initialize the JinaEmbeddingFunction.
@@ -52,6 +67,12 @@ def __init__(
5267 raise ValueError (
5368 "The httpx python package is not installed. Please install it with `pip install httpx`"
5469 )
70+ try :
71+ self ._PILImage = importlib .import_module ("PIL.Image" )
72+ except ImportError :
73+ raise ValueError (
74+ "The PIL python package is not installed. Please install it with `pip install pillow`"
75+ )
5576
5677 if api_key is not None :
5778 warnings .warn (
@@ -74,57 +95,71 @@ def __init__(
7495 self .dimensions = dimensions
7596 self .embedding_type = embedding_type
7697 self .normalized = normalized
98+ self .query_config = query_config
7799
78100 self ._api_url = "https://api.jina.ai/v1/embeddings"
79101 self ._session = httpx .Client ()
80102 self ._session .headers .update (
81103 {"Authorization" : f"Bearer { self .api_key } " , "Accept-Encoding" : "identity" }
82104 )
83105
84- def __call__ (self , input : Documents ) -> Embeddings :
85- """
86- Get the embeddings for a list of texts.
87-
88- Args:
89- input (Documents): A list of texts to get embeddings for.
90-
91- Returns:
92- Embeddings: The embeddings for the texts.
93-
94- Example:
95- >>> jina_ai_fn = JinaEmbeddingFunction(api_key_env_var="CHROMA_JINA_API_KEY")
96- >>> input = ["Hello, world!", "How are you?"]
97- """
98- # Jina AI only works with text documents
99- if not all (isinstance (item , str ) for item in input ):
100- raise ValueError ("Jina AI only supports text documents, not images" )
101-
def _build_payload(self, input: Embeddable, is_query: bool) -> Dict[str, Any]:
    """
    Build the JSON request body for the Jina AI embeddings endpoint.

    Args:
        input (Embeddable): The documents and/or images (numpy arrays) to embed.
        is_query (bool): When True, the entries of ``self.query_config`` (if
            set) are applied last, so they override the instance-level settings.

    Returns:
        Dict[str, Any]: The request payload, containing "input", "model" and
        every optional setting that is not None.

    Raises:
        ValueError: If an image cannot be converted to a base64-encoded PNG,
            or if an item is neither a document nor an image.
    """
    payload: Dict[str, Any] = {
        "input": [],
        "model": self.model_name,
    }

    if all(is_document(item) for item in input):
        # Text-only batches are sent unchanged as a plain list of strings.
        payload["input"] = input
    else:
        # Mixed batches are sent as one tagged object per item.
        entries: List[Dict[str, Any]] = []
        for item in input:
            if is_document(item):
                entries.append({"text": item})
            elif is_image(item):
                try:
                    pil_image = self._PILImage.fromarray(item)

                    buffer = io.BytesIO()
                    pil_image.save(buffer, format="PNG")

                    # Encode the PNG bytes as a base64 string
                    base64_string = base64.b64encode(buffer.getvalue()).decode(
                        "utf-8"
                    )
                except Exception as e:
                    raise ValueError(
                        f"Failed to convert image numpy array to base64 data URI: {e}"
                    ) from e
                entries.append({"image": base64_string})
            else:
                # Dropping the item silently would leave the request with
                # fewer inputs than the caller passed, so fail loudly.
                raise ValueError(
                    "Jina AI only supports text documents and images, "
                    f"got an item of type {type(item).__name__}"
                )
        payload["input"] = entries

    # Optional settings are only sent when explicitly configured.
    optional_settings: Dict[str, Any] = {
        "task": self.task,
        "late_chunking": self.late_chunking,
        "truncate": self.truncate,
        "dimensions": self.dimensions,
        "embedding_type": self.embedding_type,
        "normalized": self.normalized,
    }
    payload.update(
        {key: value for key, value in optional_settings.items() if value is not None}
    )

    # Query-specific overrides go last so they win over the settings above.
    if is_query and self.query_config is not None:
        payload.update(self.query_config)

    return payload
127152
153+ def _convert_resp (self , resp : Any , is_query : bool = False ) -> Embeddings :
154+ """
155+ Convert the response from the Jina AI API to a list of numpy arrays.
156+
157+ Args:
158+ resp (Any): The response from the Jina AI API.
159+
160+ Returns:
161+ Embeddings: A list of numpy arrays representing the embeddings.
162+ """
128163 if "data" not in resp :
129164 raise RuntimeError (resp .get ("detail" , "Unknown error" ))
130165
@@ -139,6 +174,36 @@ def __call__(self, input: Documents) -> Embeddings:
139174 for result in sorted_embeddings
140175 ]
141176
def __call__(self, input: Embeddable) -> Embeddings:
    """
    Embed a batch of documents and/or images with the Jina AI API.

    The request is built without the query-specific overrides.

    Args:
        input (Embeddable): A list of texts and/or images to get embeddings for.

    Returns:
        Embeddings: The embeddings for the given items.

    Example:
        >>> jina_ai_fn = JinaEmbeddingFunction(api_key_env_var="CHROMA_JINA_API_KEY")
        >>> input = ["Hello, world!", "How are you?"]
    """
    request_body = self._build_payload(input, is_query=False)

    # Send the request to the Jina AI Embedding API and decode the JSON body
    response = self._session.post(self._api_url, json=request_body, timeout=60)
    decoded = response.json()

    return self._convert_resp(decoded)
198+
def embed_query(self, input: Embeddable) -> Embeddings:
    """
    Embed a batch of query items with the Jina AI API.

    The request is built with ``is_query=True``, so any configured
    ``query_config`` entries are included in the payload.

    Args:
        input (Embeddable): A list of texts and/or images to get embeddings for.

    Returns:
        Embeddings: The embeddings for the given query items.
    """
    request_body = self._build_payload(input, is_query=True)

    # Send the request to the Jina AI Embedding API and decode the JSON body
    response = self._session.post(self._api_url, json=request_body, timeout=60)
    decoded = response.json()

    return self._convert_resp(decoded, is_query=True)
206+
142207 @staticmethod
143208 def name () -> str :
144209 return "jina"
@@ -150,7 +215,7 @@ def supported_spaces(self) -> List[Space]:
150215 return ["cosine" , "l2" , "ip" ]
151216
152217 @staticmethod
153- def build_from_config (config : Dict [str , Any ]) -> "EmbeddingFunction[Documents ]" :
218+ def build_from_config (config : Dict [str , Any ]) -> "EmbeddingFunction[Embeddable ]" :
154219 api_key_env_var = config .get ("api_key_env_var" )
155220 model_name = config .get ("model_name" )
156221 task = config .get ("task" )
@@ -159,6 +224,7 @@ def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
159224 dimensions = config .get ("dimensions" )
160225 embedding_type = config .get ("embedding_type" )
161226 normalized = config .get ("normalized" )
227+ query_config = config .get ("query_config" )
162228
163229 if api_key_env_var is None or model_name is None :
164230 assert False , "This code should not be reached" # this is for type checking
@@ -172,6 +238,7 @@ def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
172238 dimensions = dimensions ,
173239 embedding_type = embedding_type ,
174240 normalized = normalized ,
241+ query_config = query_config ,
175242 )
176243
177244 def get_config (self ) -> Dict [str , Any ]:
@@ -184,6 +251,7 @@ def get_config(self) -> Dict[str, Any]:
184251 "dimensions" : self .dimensions ,
185252 "embedding_type" : self .embedding_type ,
186253 "normalized" : self .normalized ,
254+ "query_config" : self .query_config ,
187255 }
188256
189257 def validate_config_update (
0 commit comments