Skip to content

Commit 36a2766

Browse files
committed
[ENH] add query config on collection configuration
1 parent 9bda3dc commit 36a2766

File tree

7 files changed

+131
-80
lines changed

7 files changed

+131
-80
lines changed

chromadb/api/collection_configuration.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ def load_collection_configuration_from_json(
9999
raise ValueError(
100100
f"Could not build embedding function {ef_config['name']} from config {ef_config['config']}: {e}"
101101
)
102-
103102
else:
104103
ef = None
105104

@@ -148,11 +147,6 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
148147
if ef is None:
149148
ef = None
150149
ef_config = {"type": "legacy"}
151-
return {
152-
"hnsw": hnsw_config,
153-
"spann": spann_config,
154-
"embedding_function": ef_config,
155-
}
156150

157151
if ef is not None:
158152
try:
@@ -260,16 +254,6 @@ class CreateCollectionConfiguration(TypedDict, total=False):
260254
embedding_function: Optional[EmbeddingFunction] # type: ignore
261255

262256

263-
def load_collection_configuration_from_create_collection_configuration(
    config: CreateCollectionConfiguration,
) -> CollectionConfiguration:
    """Build a ``CollectionConfiguration`` from a create-time configuration.

    The ``hnsw``, ``spann`` and ``embedding_function`` entries are each read
    with ``config.get``, so a key that is absent from ``config`` is forwarded
    as ``None`` instead of raising ``KeyError``.

    Args:
        config: The create-collection configuration to copy from.

    Returns:
        A ``CollectionConfiguration`` carrying the same three entries.
    """
    return CollectionConfiguration(
        hnsw=config.get("hnsw"),
        spann=config.get("spann"),
        embedding_function=config.get("embedding_function"),
    )
271-
272-
273257
def create_collection_configuration_from_legacy_collection_metadata(
274258
metadata: CollectionMetadata,
275259
) -> CreateCollectionConfiguration:
@@ -301,13 +285,6 @@ def create_collection_configuration_from_legacy_metadata_dict(
301285
return CreateCollectionConfiguration(hnsw=hnsw_config)
302286

303287

304-
def load_create_collection_configuration_from_json_str(
    json_str: str,
) -> CreateCollectionConfiguration:
    """Parse ``json_str`` and build a ``CreateCollectionConfiguration`` from it.

    This is a thin wrapper: the string is decoded with ``json.loads`` and the
    resulting object is handed to
    ``load_create_collection_configuration_from_json``.

    Args:
        json_str: JSON text describing the configuration.

    Returns:
        The configuration built by
        ``load_create_collection_configuration_from_json``.

    Raises:
        json.JSONDecodeError: If ``json_str`` is not valid JSON (raised by
            ``json.loads`` and not caught here).
    """
    json_map = json.loads(json_str)
    return load_create_collection_configuration_from_json(json_map)
309-
310-
311288
# TODO: make warnings prettier and add link to migration docs
312289
def load_create_collection_configuration_from_json(
313290
json_map: Dict[str, Any]

chromadb/api/models/CollectionCommon.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,9 @@ def _validate_and_prepare_query_request(
313313
# Prepare
314314
if query_records["embeddings"] is None:
315315
validate_record_set_for_embedding(record_set=query_records)
316-
request_embeddings = self._embed_record_set(record_set=query_records)
316+
request_embeddings = self._embed_record_set(
317+
record_set=query_records, is_query=True
318+
)
317319
else:
318320
request_embeddings = query_records["embeddings"]
319321

@@ -531,7 +533,10 @@ def _update_model_after_modify_success(
531533
)
532534

533535
def _embed_record_set(
534-
self, record_set: BaseRecordSet, embeddable_fields: Optional[Set[str]] = None
536+
self,
537+
record_set: BaseRecordSet,
538+
embeddable_fields: Optional[Set[str]] = None,
539+
is_query: bool = False,
535540
) -> Embeddings:
536541
if embeddable_fields is None:
537542
embeddable_fields = get_default_embeddable_record_set_fields()
@@ -545,27 +550,41 @@ def _embed_record_set(
545550
"You must set a data loader on the collection if loading from URIs."
546551
)
547552
return self._embed(
548-
input=self._data_loader(uris=cast(URIs, record_set[field])) # type: ignore[literal-required]
553+
input=self._data_loader(uris=cast(URIs, record_set[field])), # type: ignore[literal-required]
554+
is_query=is_query,
549555
)
550556
else:
551-
return self._embed(input=record_set[field]) # type: ignore[literal-required]
557+
return self._embed(
558+
input=record_set[field], # type: ignore[literal-required]
559+
is_query=is_query,
560+
)
552561
raise ValueError(
553562
"Record does not contain any non-None fields that can be embedded."
554563
f"Embeddable Fields: {embeddable_fields}"
555564
f"Record Fields: {record_set}"
556565
)
557566

558-
def _embed(self, input: Any) -> Embeddings:
567+
def _embed(self, input: Any, is_query: bool = False) -> Embeddings:
    """Embed ``input`` with the embedding function that applies to this collection.

    The embedding function is chosen in this order:
      1. ``self._embedding_function`` when it is set and is not a
         ``DefaultEmbeddingFunction`` instance;
      2. the ``embedding_function`` entry of ``self.configuration``;
      3. ``self._embedding_function`` itself, which at that point is either
         ``None`` (an error) or a ``DefaultEmbeddingFunction``.

    Args:
        input: The data to embed; passed to the embedding function unchanged.
        is_query: When True, the function's ``embed_query`` is used if it has
            one; otherwise the function is called directly.

    Returns:
        The embeddings returned by the selected embedding function.

    Raises:
        ValueError: If no embedding function is available.
    """

    def run(embedding_function: Any) -> Embeddings:
        # ``embed_query`` is documented as optional on the protocol, and
        # duck-typed callables may not define it at all, so fall back to the
        # plain call instead of raising AttributeError.
        if is_query:
            embed_query = getattr(embedding_function, "embed_query", None)
            if callable(embed_query):
                return embed_query(input=input)
        return embedding_function(input=input)

    explicit_ef = self._embedding_function
    if explicit_ef is not None and not isinstance(
        explicit_ef, ef.DefaultEmbeddingFunction
    ):
        return run(explicit_ef)

    config_ef = self.configuration.get("embedding_function")
    if config_ef is not None:
        return run(config_ef)

    if explicit_ef is None:
        raise ValueError(
            "You must provide an embedding function to compute embeddings. "
            "https://docs.trychroma.com/guides/embeddings"
        )
    # Only the default embedding function can reach this point.
    return run(explicit_ef)

chromadb/api/types.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,13 @@ class EmbeddingFunction(Protocol[D]):
545545
def __call__(self, input: D) -> Embeddings:
546546
...
547547

548+
def embed_query(self, input: D) -> Embeddings:
    """
    Get the embeddings for a query input.

    This method is optional. An implementation that needs to embed queries
    differently from stored documents overrides it; when it is not
    overridden, the query is embedded exactly like any other input by
    delegating to ``__call__``.

    Args:
        input: The query data to embed.

    Returns:
        Embeddings: The embeddings produced by ``__call__`` for ``input``.
    """
    return self.__call__(input)
554+
548555
def __init_subclass__(cls) -> None:
549556
super().__init_subclass__()
550557
# Raise an exception if __call__ is not defined since it is expected to be defined

chromadb/utils/embedding_functions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333
from chromadb.utils.embedding_functions.jina_embedding_function import (
3434
JinaEmbeddingFunction,
35+
JinaQueryConfig,
3536
)
3637
from chromadb.utils.embedding_functions.voyageai_embedding_function import (
3738
VoyageAIEmbeddingFunction,
@@ -232,6 +233,7 @@ def config_to_embedding_function(config: Dict[str, Any]) -> EmbeddingFunction:
232233
"OllamaEmbeddingFunction",
233234
"InstructorEmbeddingFunction",
234235
"JinaEmbeddingFunction",
236+
"JinaQueryConfig",
235237
"MistralEmbeddingFunction",
236238
"VoyageAIEmbeddingFunction",
237239
"ONNXMiniLM_L6_V2",

chromadb/utils/embedding_functions/jina_embedding_function.py

Lines changed: 69 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
1-
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
1+
from chromadb.api.types import (
2+
Embeddings,
3+
Documents,
4+
EmbeddingFunction,
5+
Space,
6+
)
27
from chromadb.utils.embedding_functions.schemas import validate_config_schema
3-
from typing import List, Dict, Any, Union, Optional
8+
from typing import List, Dict, Any, Union, Optional, TypedDict
49
import os
510
import numpy as np
611
import warnings
712

813

14+
class JinaQueryConfig(TypedDict):
    """Request settings that apply only when a query is being embedded.

    Every key/value pair of this mapping is copied into the request payload
    by ``JinaEmbeddingFunction._build_payload`` when ``is_query`` is true,
    replacing any value of the same key that was set for document embedding.
    """

    # Value sent as the "task" field of the request for query embeddings.
    task: str
16+
17+
918
class JinaEmbeddingFunction(EmbeddingFunction[Documents]):
1019
"""
1120
This class is used to get embeddings for a list of texts using the Jina AI API.
@@ -23,6 +32,7 @@ def __init__(
2332
dimensions: Optional[int] = None,
2433
embedding_type: Optional[str] = None,
2534
normalized: Optional[bool] = None,
35+
query_config: Optional[JinaQueryConfig] = None,
2636
):
2737
"""
2838
Initialize the JinaEmbeddingFunction.
@@ -74,57 +84,49 @@ def __init__(
7484
self.dimensions = dimensions
7585
self.embedding_type = embedding_type
7686
self.normalized = normalized
87+
self.query_config = query_config
7788

7889
self._api_url = "https://api.jina.ai/v1/embeddings"
7990
self._session = httpx.Client()
8091
self._session.headers.update(
8192
{"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"}
8293
)
8394

84-
def __call__(self, input: Documents) -> Embeddings:
85-
"""
86-
Get the embeddings for a list of texts.
87-
88-
Args:
89-
input (Documents): A list of texts to get embeddings for.
90-
91-
Returns:
92-
Embeddings: The embeddings for the texts.
93-
94-
Example:
95-
>>> jina_ai_fn = JinaEmbeddingFunction(api_key_env_var="CHROMA_JINA_API_KEY")
96-
>>> input = ["Hello, world!", "How are you?"]
97-
"""
98-
# Jina AI only works with text documents
99-
if not all(isinstance(item, str) for item in input):
100-
raise ValueError("Jina AI only supports text documents, not images")
101-
95+
def _build_payload(self, input: Documents, is_query: bool) -> Dict[str, Any]:
    """Assemble the JSON body for a request to the Jina AI embeddings API.

    ``input`` and the model name are always included. Each optional setting
    is included only when it is not ``None``. When ``is_query`` is true and a
    ``query_config`` is set, its entries are applied last, so they override
    any same-named setting chosen above.

    Args:
        input (Documents): The texts to embed.
        is_query (bool): Whether the payload is for a query embedding.

    Returns:
        Dict[str, Any]: The request payload.
    """
    payload: Dict[str, Any] = {
        "input": input,
        "model": self.model_name,
    }

    # Optional request fields, in the order they are added to the payload.
    optional_fields = {
        "task": self.task,
        "late_chunking": self.late_chunking,
        "truncate": self.truncate,
        "dimensions": self.dimensions,
        "embedding_type": self.embedding_type,
        "normalized": self.normalized,
    }
    payload.update(
        (key, value) for key, value in optional_fields.items() if value is not None
    )

    # Query-specific overrides win over the document-time settings.
    if is_query and self.query_config is not None:
        payload.update(self.query_config)

    return payload
119+
120+
def _convert_resp(self, resp: Any) -> Embeddings:
121+
"""
122+
Convert the response from the Jina AI API to a list of numpy arrays.
123+
124+
Args:
125+
resp (Any): The response from the Jina AI API.
127126
127+
Returns:
128+
Embeddings: A list of numpy arrays representing the embeddings.
129+
"""
128130
if "data" not in resp:
129131
raise RuntimeError(resp.get("detail", "Unknown error"))
130132

@@ -139,6 +141,43 @@ def __call__(self, input: Documents) -> Embeddings:
139141
for result in sorted_embeddings
140142
]
141143

144+
def __call__(self, input: Documents) -> Embeddings:
    """
    Get the embeddings for a list of texts.

    Args:
        input (Documents): A list of texts to get embeddings for.

    Returns:
        Embeddings: The embeddings for the texts.

    Raises:
        ValueError: If any item of ``input`` is not a string.

    Example:
        >>> jina_ai_fn = JinaEmbeddingFunction(api_key_env_var="CHROMA_JINA_API_KEY")
        >>> input = ["Hello, world!", "How are you?"]
        >>> embeddings = jina_ai_fn(input)
    """
    # Jina AI only works with text documents
    if not all(isinstance(item, str) for item in input):
        raise ValueError("Jina AI only supports text documents, not images")

    # Document embeddings never apply the query-only overrides.
    payload = self._build_payload(input, is_query=False)

    # Call Jina AI Embedding API
    resp = self._session.post(self._api_url, json=payload).json()

    return self._convert_resp(resp)
168+
169+
def embed_query(self, input: Documents) -> Embeddings:
    """
    Get the embeddings for a list of query texts.

    The request is built like the one for documents, except that the
    entries of ``query_config`` (when set) are applied on top of the
    payload.

    Args:
        input (Documents): A list of query texts to get embeddings for.

    Returns:
        Embeddings: The embeddings for the query texts.

    Raises:
        ValueError: If any item of ``input`` is not a string.
    """
    # Jina AI only works with text documents
    if not all(isinstance(item, str) for item in input):
        raise ValueError("Jina AI only supports text documents, not images")

    payload = self._build_payload(input, is_query=True)

    # Call Jina AI Embedding API
    resp = self._session.post(self._api_url, json=payload).json()

    return self._convert_resp(resp)
180+
142181
@staticmethod
143182
def name() -> str:
144183
return "jina"
@@ -159,6 +198,7 @@ def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
159198
dimensions = config.get("dimensions")
160199
embedding_type = config.get("embedding_type")
161200
normalized = config.get("normalized")
201+
query_config = config.get("query_config")
162202

163203
if api_key_env_var is None or model_name is None:
164204
assert False, "This code should not be reached" # this is for type checking
@@ -172,6 +212,7 @@ def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
172212
dimensions=dimensions,
173213
embedding_type=embedding_type,
174214
normalized=normalized,
215+
query_config=query_config,
175216
)
176217

177218
def get_config(self) -> Dict[str, Any]:
@@ -184,6 +225,7 @@ def get_config(self) -> Dict[str, Any]:
184225
"dimensions": self.dimensions,
185226
"embedding_type": self.embedding_type,
186227
"normalized": self.normalized,
228+
"query_config": self.query_config,
187229
}
188230

189231
def validate_config_update(

0 commit comments

Comments
 (0)