Skip to content

Commit 5100775

Browse files
committed
[ENH] add query config on collection configuration
1 parent dcb9b49 commit 5100775

File tree

5 files changed

+134
-61
lines changed

5 files changed

+134
-61
lines changed

chromadb/api/collection_configuration.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ def load_collection_configuration_from_json(
9999
raise ValueError(
100100
f"Could not build embedding function {ef_config['name']} from config {ef_config['config']}: {e}"
101101
)
102-
103102
else:
104103
ef = None
105104

@@ -148,11 +147,6 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
148147
if ef is None:
149148
ef = None
150149
ef_config = {"type": "legacy"}
151-
return {
152-
"hnsw": hnsw_config,
153-
"spann": spann_config,
154-
"embedding_function": ef_config,
155-
}
156150

157151
if ef is not None:
158152
try:
@@ -260,16 +254,6 @@ class CreateCollectionConfiguration(TypedDict, total=False):
260254
embedding_function: Optional[EmbeddingFunction] # type: ignore
261255

262256

263-
def load_collection_configuration_from_create_collection_configuration(
264-
config: CreateCollectionConfiguration,
265-
) -> CollectionConfiguration:
266-
return CollectionConfiguration(
267-
hnsw=config.get("hnsw"),
268-
spann=config.get("spann"),
269-
embedding_function=config.get("embedding_function"),
270-
)
271-
272-
273257
def create_collection_configuration_from_legacy_collection_metadata(
274258
metadata: CollectionMetadata,
275259
) -> CreateCollectionConfiguration:
@@ -301,13 +285,6 @@ def create_collection_configuration_from_legacy_metadata_dict(
301285
return CreateCollectionConfiguration(hnsw=hnsw_config)
302286

303287

304-
def load_create_collection_configuration_from_json_str(
305-
json_str: str,
306-
) -> CreateCollectionConfiguration:
307-
json_map = json.loads(json_str)
308-
return load_create_collection_configuration_from_json(json_map)
309-
310-
311288
# TODO: make warnings prettier and add link to migration docs
312289
def load_create_collection_configuration_from_json(
313290
json_map: Dict[str, Any]

chromadb/api/models/CollectionCommon.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,9 @@ def _validate_and_prepare_query_request(
313313
# Prepare
314314
if query_records["embeddings"] is None:
315315
validate_record_set_for_embedding(record_set=query_records)
316-
request_embeddings = self._embed_record_set(record_set=query_records)
316+
request_embeddings = self._embed_record_set(
317+
record_set=query_records, is_query=True
318+
)
317319
else:
318320
request_embeddings = query_records["embeddings"]
319321

@@ -531,7 +533,10 @@ def _update_model_after_modify_success(
531533
)
532534

533535
def _embed_record_set(
534-
self, record_set: BaseRecordSet, embeddable_fields: Optional[Set[str]] = None
536+
self,
537+
record_set: BaseRecordSet,
538+
embeddable_fields: Optional[Set[str]] = None,
539+
is_query: bool = False,
535540
) -> Embeddings:
536541
if embeddable_fields is None:
537542
embeddable_fields = get_default_embeddable_record_set_fields()
@@ -545,27 +550,41 @@ def _embed_record_set(
545550
"You must set a data loader on the collection if loading from URIs."
546551
)
547552
return self._embed(
548-
input=self._data_loader(uris=cast(URIs, record_set[field])) # type: ignore[literal-required]
553+
input=self._data_loader(uris=cast(URIs, record_set[field])), # type: ignore[literal-required]
554+
is_query=is_query,
549555
)
550556
else:
551-
return self._embed(input=record_set[field]) # type: ignore[literal-required]
557+
return self._embed(
558+
input=record_set[field], # type: ignore[literal-required]
559+
is_query=is_query,
560+
)
552561
raise ValueError(
553562
"Record does not contain any non-None fields that can be embedded."
554563
f"Embeddable Fields: {embeddable_fields}"
555564
f"Record Fields: {record_set}"
556565
)
557566

def _embed(self, input: Any, is_query: bool = False) -> Embeddings:
    """
    Compute embeddings for ``input`` with the collection's embedding function.

    The embedding function is chosen in this order:
      1. ``self._embedding_function`` when it is set and is not the default
         embedding function;
      2. the ``embedding_function`` entry of ``self.configuration``;
      3. ``self._embedding_function`` (the default one) as a last resort.

    Args:
        input (Any): The data to embed.
        is_query (bool): When True, use the embedding function's
            ``embed_query`` method if it has one; otherwise call the
            embedding function directly.

    Returns:
        Embeddings: The embeddings produced by the chosen embedding function.

    Raises:
        ValueError: If no embedding function is available.
    """
    explicit_ef = self._embedding_function
    if explicit_ef is not None and not isinstance(
        explicit_ef, ef.DefaultEmbeddingFunction
    ):
        chosen = explicit_ef
    else:
        chosen = self.configuration.get("embedding_function")
        if chosen is None:
            chosen = explicit_ef

    if chosen is None:
        raise ValueError(
            "You must provide an embedding function to compute embeddings."
            "https://docs.trychroma.com/guides/embeddings"
        )

    if is_query:
        # Duck-typed embedding functions (plain callables, or classes that do
        # not inherit the default ``embed_query``) may not define the method;
        # fall back to ``__call__`` instead of failing with AttributeError.
        query_embedder = getattr(chosen, "embed_query", None)
        if callable(query_embedder):
            return query_embedder(input=input)
    return chosen(input=input)

chromadb/api/types.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,13 @@ class EmbeddingFunction(Protocol[D]):
576576
def __call__(self, input: D) -> Embeddings:
577577
...
578578

def embed_query(self, input: D) -> Embeddings:
    """
    Embed an input that is being used as a query.

    Implementations may override this method to embed queries differently
    from other inputs. Overriding is optional: this default implementation
    delegates to ``__call__``.
    """
    query_embeddings = self.__call__(input)
    return query_embeddings
585+
579586
def __init_subclass__(cls) -> None:
580587
super().__init_subclass__()
581588
# Raise an exception if __call__ is not defined since it is expected to be defined

chromadb/utils/embedding_functions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333
from chromadb.utils.embedding_functions.jina_embedding_function import (
3434
JinaEmbeddingFunction,
35+
JinaQueryConfig,
3536
)
3637
from chromadb.utils.embedding_functions.voyageai_embedding_function import (
3738
VoyageAIEmbeddingFunction,
@@ -237,6 +238,7 @@ def config_to_embedding_function(config: Dict[str, Any]) -> EmbeddingFunction:
237238
"OllamaEmbeddingFunction",
238239
"InstructorEmbeddingFunction",
239240
"JinaEmbeddingFunction",
241+
"JinaQueryConfig",
240242
"MistralEmbeddingFunction",
241243
"MorphEmbeddingFunction",
242244
"VoyageAIEmbeddingFunction",

chromadb/utils/embedding_functions/jina_embedding_function.py

Lines changed: 98 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,26 @@
1-
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
1+
from chromadb.api.types import (
2+
Embeddings,
3+
EmbeddingFunction,
4+
Space,
5+
Embeddable,
6+
is_image,
7+
is_document,
8+
)
29
from chromadb.utils.embedding_functions.schemas import validate_config_schema
3-
from typing import List, Dict, Any, Union, Optional
10+
from typing import List, Dict, Any, Union, Optional, TypedDict
411
import os
512
import numpy as np
613
import warnings
14+
import importlib
15+
import base64
16+
import io
717

818

9-
class JinaEmbeddingFunction(EmbeddingFunction[Documents]):
# Settings applied only when embedding a query: every entry is copied into
# the request payload after the instance-level settings, so ``task`` here
# replaces the instance-level ``task`` for query requests.
JinaQueryConfig = TypedDict("JinaQueryConfig", {"task": str})
21+
22+
23+
class JinaEmbeddingFunction(EmbeddingFunction[Embeddable]):
1024
"""
1125
This class is used to get embeddings for a list of texts using the Jina AI API.
1226
It requires an API key and a model name. The default model name is "jina-embeddings-v2-base-en".
@@ -23,6 +37,7 @@ def __init__(
2337
dimensions: Optional[int] = None,
2438
embedding_type: Optional[str] = None,
2539
normalized: Optional[bool] = None,
40+
query_config: Optional[JinaQueryConfig] = None,
2641
):
2742
"""
2843
Initialize the JinaEmbeddingFunction.
@@ -52,6 +67,12 @@ def __init__(
5267
raise ValueError(
5368
"The httpx python package is not installed. Please install it with `pip install httpx`"
5469
)
70+
try:
71+
self._PILImage = importlib.import_module("PIL.Image")
72+
except ImportError:
73+
raise ValueError(
74+
"The PIL python package is not installed. Please install it with `pip install pillow`"
75+
)
5576

5677
if api_key is not None:
5778
warnings.warn(
@@ -74,57 +95,71 @@ def __init__(
7495
self.dimensions = dimensions
7596
self.embedding_type = embedding_type
7697
self.normalized = normalized
98+
self.query_config = query_config
7799

78100
self._api_url = "https://api.jina.ai/v1/embeddings"
79101
self._session = httpx.Client()
80102
self._session.headers.update(
81103
{"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"}
82104
)
83105

84-
def __call__(self, input: Documents) -> Embeddings:
85-
"""
86-
Get the embeddings for a list of texts.
87-
88-
Args:
89-
input (Documents): A list of texts to get embeddings for.
90-
91-
Returns:
92-
Embeddings: The embeddings for the texts.
93-
94-
Example:
95-
>>> jina_ai_fn = JinaEmbeddingFunction(api_key_env_var="CHROMA_JINA_API_KEY")
96-
>>> input = ["Hello, world!", "How are you?"]
97-
"""
98-
# Jina AI only works with text documents
99-
if not all(isinstance(item, str) for item in input):
100-
raise ValueError("Jina AI only supports text documents, not images")
101-
def _build_payload(self, input: Embeddable, is_query: bool) -> Dict[str, Any]:
    """
    Build the JSON request body for the Jina AI embeddings endpoint.

    Args:
        input (Embeddable): The documents and/or images to embed.
        is_query (bool): When True and ``self.query_config`` is set, its
            entries are merged into the payload last, so they override the
            instance-level settings with the same key (for example ``task``).

    Returns:
        Dict[str, Any]: The payload to send as the JSON body of the request.

    Raises:
        ValueError: If an item is neither a document nor an image, or if an
            image cannot be converted to a PNG.
    """
    payload: Dict[str, Any] = {
        "input": [],
        "model": self.model_name,
    }

    if all(is_document(item) for item in input):
        # Text-only batches are sent as a plain list of strings.
        payload["input"] = input
    else:
        # Mixed batches are sent as one tagged object per item.
        for index, item in enumerate(input):
            if is_document(item):
                payload["input"].append({"text": item})
            elif is_image(item):
                try:
                    buffer = io.BytesIO()
                    self._PILImage.fromarray(item).save(buffer, format="PNG")
                except Exception as e:
                    raise ValueError(
                        f"Failed to convert image numpy array to base64 data URI: {e}"
                    ) from e
                encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
                payload["input"].append({"image": encoded_image})
            else:
                # Silently skipping the item would send fewer inputs than the
                # caller provided, so reject it explicitly.
                raise ValueError(
                    f"Unsupported input at index {index}: expected a document "
                    f"or an image, got {type(item).__name__}"
                )

    # Optional request settings are only sent when they were configured.
    optional_settings = (
        "task",
        "late_chunking",
        "truncate",
        "dimensions",
        "embedding_type",
        "normalized",
    )
    for setting in optional_settings:
        value = getattr(self, setting)
        if value is not None:
            payload[setting] = value

    if is_query and self.query_config is not None:
        # Applied last so query-specific values win over instance settings.
        payload.update(self.query_config)

    return payload
127152

153+
def _convert_resp(self, resp: Any, is_query: bool = False) -> Embeddings:
154+
"""
155+
Convert the response from the Jina AI API to a list of numpy arrays.
156+
157+
Args:
158+
resp (Any): The response from the Jina AI API.
159+
160+
Returns:
161+
Embeddings: A list of numpy arrays representing the embeddings.
162+
"""
128163
if "data" not in resp:
129164
raise RuntimeError(resp.get("detail", "Unknown error"))
130165

@@ -139,6 +174,36 @@ def __call__(self, input: Documents) -> Embeddings:
139174
for result in sorted_embeddings
140175
]
141176

def __call__(self, input: Embeddable) -> Embeddings:
    """
    Get the embeddings for a list of texts and/or images.

    Args:
        input (Embeddable): A list of texts and/or images to get embeddings for.

    Returns:
        Embeddings: The embeddings for the given inputs.

    Example:
        >>> jina_ai_fn = JinaEmbeddingFunction(api_key_env_var="CHROMA_JINA_API_KEY")
        >>> input = ["Hello, world!", "How are you?"]
    """
    request_body = self._build_payload(input, is_query=False)

    # Call the Jina AI Embedding API and decode the JSON body of the reply.
    http_response = self._session.post(
        self._api_url, json=request_body, timeout=60
    )
    return self._convert_resp(http_response.json())
198+
def embed_query(self, input: Embeddable) -> Embeddings:
    """
    Get the embeddings for query inputs.

    The request payload is built with ``is_query=True`` before being sent to
    the Jina AI Embedding API.

    Args:
        input (Embeddable): The query inputs to get embeddings for.

    Returns:
        Embeddings: The embeddings for the query inputs.
    """
    request_body = self._build_payload(input, is_query=True)

    # Call the Jina AI Embedding API and decode the JSON body of the reply.
    http_response = self._session.post(
        self._api_url, json=request_body, timeout=60
    )
    return self._convert_resp(http_response.json(), is_query=True)
206+
142207
@staticmethod
143208
def name() -> str:
144209
return "jina"
@@ -150,7 +215,7 @@ def supported_spaces(self) -> List[Space]:
150215
return ["cosine", "l2", "ip"]
151216

152217
@staticmethod
153-
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
218+
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]":
154219
api_key_env_var = config.get("api_key_env_var")
155220
model_name = config.get("model_name")
156221
task = config.get("task")
@@ -159,6 +224,7 @@ def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
159224
dimensions = config.get("dimensions")
160225
embedding_type = config.get("embedding_type")
161226
normalized = config.get("normalized")
227+
query_config = config.get("query_config")
162228

163229
if api_key_env_var is None or model_name is None:
164230
assert False, "This code should not be reached" # this is for type checking
@@ -172,6 +238,7 @@ def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
172238
dimensions=dimensions,
173239
embedding_type=embedding_type,
174240
normalized=normalized,
241+
query_config=query_config,
175242
)
176243

177244
def get_config(self) -> Dict[str, Any]:
@@ -184,6 +251,7 @@ def get_config(self) -> Dict[str, Any]:
184251
"dimensions": self.dimensions,
185252
"embedding_type": self.embedding_type,
186253
"normalized": self.normalized,
254+
"query_config": self.query_config,
187255
}
188256

189257
def validate_config_update(

0 commit comments

Comments
 (0)