Skip to content

Commit 69696fd

Browse files
committed
[ENH] add query config on collection configuration
1 parent c4d0c93 commit 69696fd

File tree

11 files changed

+140
-23
lines changed

11 files changed

+140
-23
lines changed

chromadb/api/collection_configuration.py

Lines changed: 58 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from typing import TypedDict, Dict, Any, Optional, cast, get_args
22
import json
3+
import copy
34
from chromadb.api.types import (
45
Space,
56
CollectionMetadata,
67
UpdateMetadata,
78
EmbeddingFunction,
9+
QueryConfig,
810
)
911
from chromadb.utils.embedding_functions import (
1012
known_embedding_functions,
@@ -41,6 +43,7 @@ class CollectionConfiguration(TypedDict, total=True):
4143
hnsw: Optional[HNSWConfiguration]
4244
spann: Optional[SpannConfiguration]
4345
embedding_function: Optional[EmbeddingFunction] # type: ignore
46+
query_embedding_function: Optional[EmbeddingFunction] # type: ignore
4447

4548

4649
def load_collection_configuration_from_json_str(
@@ -64,6 +67,8 @@ def load_collection_configuration_from_json(
6467
spann_config = None
6568
ef_config = None
6669

70+
query_ef = None
71+
6772
# Process vector index configuration (HNSW or SPANN)
6873
if config_json_map.get("hnsw") is not None:
6974
hnsw_config = cast(HNSWConfiguration, config_json_map["hnsw"])
@@ -100,13 +105,27 @@ def load_collection_configuration_from_json(
100105
f"Could not build embedding function {ef_config['name']} from config {ef_config['config']}: {e}"
101106
)
102107

108+
if config_json_map.get("query_config") is not None:
109+
query_config = config_json_map["query_config"]
110+
query_ef_config = copy.deepcopy(ef_config)
111+
query_ef = known_embedding_functions[ef_name]
112+
for k, v in query_config.items():
113+
query_ef_config["config"][k] = v
114+
115+
try:
116+
query_ef = query_ef.build_from_config(query_ef_config["config"]) # type: ignore
117+
except Exception as e:
118+
raise ValueError(
119+
f"Could not build query embedding function {query_ef_config['name']} from config {query_ef_config['config']}: {e}"
120+
)
103121
else:
104122
ef = None
105123

106124
return CollectionConfiguration(
107125
hnsw=hnsw_config,
108126
spann=spann_config,
109127
embedding_function=ef, # type: ignore
128+
query_embedding_function=query_ef, # type: ignore
110129
)
111130

112131

@@ -258,16 +277,7 @@ class CreateCollectionConfiguration(TypedDict, total=False):
258277
hnsw: Optional[CreateHNSWConfiguration]
259278
spann: Optional[CreateSpannConfiguration]
260279
embedding_function: Optional[EmbeddingFunction] # type: ignore
261-
262-
263-
def load_collection_configuration_from_create_collection_configuration(
264-
config: CreateCollectionConfiguration,
265-
) -> CollectionConfiguration:
266-
return CollectionConfiguration(
267-
hnsw=config.get("hnsw"),
268-
spann=config.get("spann"),
269-
embedding_function=config.get("embedding_function"),
270-
)
280+
query_config: Optional[QueryConfig]
271281

272282

273283
def create_collection_configuration_from_legacy_collection_metadata(
@@ -301,13 +311,6 @@ def create_collection_configuration_from_legacy_metadata_dict(
301311
return CreateCollectionConfiguration(hnsw=hnsw_config)
302312

303313

304-
def load_create_collection_configuration_from_json_str(
305-
json_str: str,
306-
) -> CreateCollectionConfiguration:
307-
json_map = json.loads(json_str)
308-
return load_create_collection_configuration_from_json(json_map)
309-
310-
311314
# TODO: make warnings prettier and add link to migration docs
312315
def load_create_collection_configuration_from_json(
313316
json_map: Dict[str, Any]
@@ -353,6 +356,7 @@ def create_collection_configuration_to_json(
353356
) -> Dict[str, Any]:
354357
"""Convert a CreateCollection configuration to a JSON-serializable dict"""
355358
ef_config: Dict[str, Any] | None = None
359+
query_config: Dict[str, Any] | None = None
356360
hnsw_config = config.get("hnsw")
357361
spann_config = config.get("spann")
358362
if hnsw_config is not None:
@@ -389,6 +393,15 @@ def create_collection_configuration_to_json(
389393
"config": ef.get_config(),
390394
}
391395
register_embedding_function(type(ef)) # type: ignore
396+
397+
q = config.get("query_config")
398+
if q is not None:
399+
if q.name() == ef.name():
400+
query_config = q.get_config()
401+
else:
402+
raise ValueError(
403+
f"query config name {q.name()} does not match embedding function name {ef.name()}"
404+
)
392405
except Exception as e:
393406
warnings.warn(
394407
f"legacy embedding function config: {e}",
@@ -402,6 +415,7 @@ def create_collection_configuration_to_json(
402415
"hnsw": hnsw_config,
403416
"spann": spann_config,
404417
"embedding_function": ef_config,
418+
"query_config": query_config,
405419
}
406420

407421

@@ -473,6 +487,7 @@ class UpdateCollectionConfiguration(TypedDict, total=False):
473487
hnsw: Optional[UpdateHNSWConfiguration]
474488
spann: Optional[UpdateSpannConfiguration]
475489
embedding_function: Optional[EmbeddingFunction] # type: ignore
490+
query_config: Optional[QueryConfig]
476491

477492

478493
def update_collection_configuration_from_legacy_collection_metadata(
@@ -528,6 +543,7 @@ def update_collection_configuration_to_json(
528543
hnsw_config = config.get("hnsw")
529544
spann_config = config.get("spann")
530545
ef = config.get("embedding_function")
546+
query_config: Dict[str, Any] | None = None
531547
if hnsw_config is None and spann_config is None and ef is None:
532548
return {}
533549

@@ -555,13 +571,22 @@ def update_collection_configuration_to_json(
555571
"config": ef.get_config(),
556572
}
557573
register_embedding_function(type(ef)) # type: ignore
574+
q = config.get("query_config")
575+
if q is not None:
576+
if q.name() == ef.name():
577+
query_config = q.get_config()
578+
else:
579+
raise ValueError(
580+
f"query config name {q.name()} does not match embedding function name {ef.name()}"
581+
)
558582
else:
559583
ef_config = None
560584

561585
return {
562586
"hnsw": hnsw_config,
563587
"spann": spann_config,
564588
"embedding_function": ef_config,
589+
"query_config": query_config,
565590
}
566591

567592

@@ -710,10 +735,26 @@ def overwrite_collection_configuration(
710735
else:
711736
updated_embedding_function = update_ef
712737

738+
query_ef = None
739+
if updated_embedding_function is not None:
740+
q = update_config.get("query_config")
741+
if q is not None:
742+
if q.name() != updated_embedding_function.name():
743+
raise ValueError(
744+
f"query config name {q.name()} does not match embedding function name {updated_embedding_function.name()}"
745+
)
746+
else:
747+
ef_config = copy.deepcopy(updated_embedding_function.get_config())
748+
query_config = q.get_config()
749+
for k, v in query_config.items():
750+
ef_config[k] = v
751+
query_ef = updated_embedding_function.build_from_config(ef_config)
752+
713753
return CollectionConfiguration(
714754
hnsw=updated_hnsw_config,
715755
spann=updated_spann_config,
716756
embedding_function=updated_embedding_function,
757+
query_embedding_function=query_ef,
717758
)
718759

719760

chromadb/api/models/CollectionCommon.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,9 @@ def _validate_and_prepare_query_request(
313313
# Prepare
314314
if query_records["embeddings"] is None:
315315
validate_record_set_for_embedding(record_set=query_records)
316-
request_embeddings = self._embed_record_set(record_set=query_records)
316+
request_embeddings = self._embed_record_set(
317+
record_set=query_records, is_query=True
318+
)
317319
else:
318320
request_embeddings = query_records["embeddings"]
319321

@@ -531,7 +533,10 @@ def _update_model_after_modify_success(
531533
)
532534

533535
def _embed_record_set(
534-
self, record_set: BaseRecordSet, embeddable_fields: Optional[Set[str]] = None
536+
self,
537+
record_set: BaseRecordSet,
538+
embeddable_fields: Optional[Set[str]] = None,
539+
is_query: bool = False,
535540
) -> Embeddings:
536541
if embeddable_fields is None:
537542
embeddable_fields = get_default_embeddable_record_set_fields()
@@ -545,21 +550,30 @@ def _embed_record_set(
545550
"You must set a data loader on the collection if loading from URIs."
546551
)
547552
return self._embed(
548-
input=self._data_loader(uris=cast(URIs, record_set[field])) # type: ignore[literal-required]
553+
input=self._data_loader(uris=cast(URIs, record_set[field])), # type: ignore[literal-required]
554+
is_query=is_query,
549555
)
550556
else:
551-
return self._embed(input=record_set[field]) # type: ignore[literal-required]
557+
return self._embed(
558+
input=record_set[field], # type: ignore[literal-required]
559+
is_query=is_query,
560+
)
552561
raise ValueError(
553562
"Record does not contain any non-None fields that can be embedded."
554563
f"Embeddable Fields: {embeddable_fields}"
555564
f"Record Fields: {record_set}"
556565
)
557566

558-
def _embed(self, input: Any) -> Embeddings:
567+
def _embed(self, input: Any, is_query: bool = False) -> Embeddings:
559568
if self._embedding_function is not None and not isinstance(
560569
self._embedding_function, ef.DefaultEmbeddingFunction
561570
):
562571
return self._embedding_function(input=input)
572+
if is_query:
573+
config_ef = self.configuration.get("query_embedding_function")
574+
if config_ef is not None:
575+
return config_ef(input=input)
576+
563577
config_ef = self.configuration.get("embedding_function")
564578
if config_ef is not None:
565579
return config_ef(input=input)

chromadb/api/types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,16 @@ def is_legacy(self) -> bool:
672672
return False
673673

674674

675+
class QueryConfig:
    """Interface for query-time overrides of an embedding function's config.

    An implementation is paired with one embedding function: callers compare
    ``name()`` against the embedding function's ``name()`` and reject a
    mismatch, then merge the entries of ``get_config()`` over the embedding
    function's own config to build the function used for queries.

    NOTE: this class does not use ``ABCMeta``, so ``@abstractmethod`` on its
    own does not prevent an incomplete subclass from being instantiated. The
    base methods therefore raise ``NotImplementedError`` so that a missing
    override fails loudly at the call site.
    """

    @abstractmethod
    def name(self) -> str:
        """Return the name of the embedding function this config applies to.

        Raises:
            NotImplementedError: If a subclass does not override this method.
        """
        # ``NotImplemented`` is the sentinel reserved for binary-operator
        # dispatch; returning it here would leak a non-str value into
        # ``q.name() == ef.name()`` style checks instead of signalling an error.
        raise NotImplementedError("QueryConfig subclasses must implement name()")

    @abstractmethod
    def get_config(self) -> Dict[str, Any]:
        """Return the config entries to override when embedding queries.

        Raises:
            NotImplementedError: If a subclass does not override this method.
        """
        raise NotImplementedError(
            "QueryConfig subclasses must implement get_config()"
        )
683+
684+
675685
def validate_embedding_function(
676686
embedding_function: EmbeddingFunction[Embeddable],
677687
) -> None:

chromadb/utils/embedding_functions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333
from chromadb.utils.embedding_functions.jina_embedding_function import (
3434
JinaEmbeddingFunction,
35+
JinaQueryConfig,
3536
)
3637
from chromadb.utils.embedding_functions.voyageai_embedding_function import (
3738
VoyageAIEmbeddingFunction,
@@ -232,6 +233,7 @@ def config_to_embedding_function(config: Dict[str, Any]) -> EmbeddingFunction:
232233
"OllamaEmbeddingFunction",
233234
"InstructorEmbeddingFunction",
234235
"JinaEmbeddingFunction",
236+
"JinaQueryConfig",
235237
"MistralEmbeddingFunction",
236238
"VoyageAIEmbeddingFunction",
237239
"ONNXMiniLM_L6_V2",

chromadb/utils/embedding_functions/jina_embedding_function.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
1-
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
1+
from chromadb.api.types import (
2+
Embeddings,
3+
Documents,
4+
EmbeddingFunction,
5+
Space,
6+
QueryConfig,
7+
)
28
from chromadb.utils.embedding_functions.schemas import validate_config_schema
39
from typing import List, Dict, Any, Union, Optional
10+
from typing_extensions import override
411
import os
512
import numpy as np
613
import warnings
@@ -206,3 +213,17 @@ def validate_config(config: Dict[str, Any]) -> None:
206213
ValidationError: If the configuration does not match the schema
207214
"""
208215
validate_config_schema(config, "jina")
216+
217+
218+
class JinaQueryConfig(QueryConfig):
    """Query-time config overrides for the embedding function named "jina".

    Args:
        task: Value reported under the ``"task"`` key of ``get_config()``.
            Defaults to ``None``.
    """

    def __init__(self, task: Optional[str] = None):
        self.task = task

    @override
    def name(self) -> str:
        """Return ``"jina"``, the name this query config identifies itself by."""
        return "jina"

    @override
    def get_config(self) -> Dict[str, Any]:
        """Return the query-time overrides as a plain dict.

        The ``"task"`` key is always present, even when ``task`` is ``None``.
        """
        # NOTE(review): callers that copy every entry of this dict over the
        # embedding function's config will also copy a ``None`` task —
        # confirm that is intended when no task was supplied.
        return {
            "task": self.task,
        }

go/pkg/sysdb/coordinator/model/collection_configuration.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ type SpannConfiguration struct {
5656
// InternalCollectionConfiguration is the JSON shape of a collection's
// configuration: its vector index, embedding function and query config.
type InternalCollectionConfiguration struct {
5757
// VectorIndex has no omitempty, so the "vector_index" key is always written.
VectorIndex *VectorIndexConfiguration `json:"vector_index"`
5858
// EmbeddingFunction is left out of the JSON when the pointer is nil.
EmbeddingFunction *EmbeddingFunctionConfiguration `json:"embedding_function,omitempty"`
59+
// QueryConfig is untyped, so any decoded JSON value is kept as-is; it is
// left out of the JSON when nil.
QueryConfig interface{} `json:"query_config,omitempty"`
5960
}
6061

6162
// DefaultHnswCollectionConfiguration returns a default configuration using HNSW
@@ -127,4 +128,5 @@ type UpdateVectorIndexConfiguration struct {
127128
// InternalUpdateCollectionConfiguration is the JSON shape of a configuration
// update. Every field is omitempty, so unset fields are left out of the JSON.
type InternalUpdateCollectionConfiguration struct {
128129
// VectorIndex is left out of the JSON when the pointer is nil.
VectorIndex *UpdateVectorIndexConfiguration `json:"vector_index,omitempty"`
129130
// EmbeddingFunction is left out of the JSON when the pointer is nil.
EmbeddingFunction *EmbeddingFunctionConfiguration `json:"embedding_function,omitempty"`
131+
// QueryConfig is untyped, so any decoded JSON value is kept as-is; it is
// left out of the JSON when nil.
QueryConfig interface{} `json:"query_config,omitempty"`
130132
}

go/pkg/sysdb/coordinator/table_catalog.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,10 @@ func (tc *Catalog) updateCollectionConfiguration(
854854
existingConfig.EmbeddingFunction = updateConfig.EmbeddingFunction
855855
}
856856

857+
if updateConfig.QueryConfig != nil {
858+
existingConfig.QueryConfig = updateConfig.QueryConfig
859+
}
860+
857861
// Serialize updated config back to JSON
858862
updatedConfigBytes, err := json.Marshal(existingConfig)
859863
if err != nil {

rust/python_bindings/src/bindings.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ impl Bindings {
280280
hnsw: None,
281281
spann: None,
282282
embedding_function: None,
283+
query_config: None,
283284
},
284285
self.frontend.get_default_knn_index(),
285286
)?),

rust/segment/src/distributed_hnsw.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,7 @@ pub mod test {
420420
},
421421
),
422422
embedding_function: None,
423+
query_config: None,
423424
},
424425
..Default::default()
425426
};

rust/segment/src/distributed_spann.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,7 @@ mod test {
629629
config: chroma_types::InternalCollectionConfiguration {
630630
vector_index: chroma_types::VectorIndexConfiguration::Spann(params),
631631
embedding_function: None,
632+
query_config: None,
632633
},
633634
metadata: None,
634635
dimension: None,
@@ -845,6 +846,7 @@ mod test {
845846
config: InternalCollectionConfiguration {
846847
vector_index: chroma_types::VectorIndexConfiguration::Spann(params),
847848
embedding_function: None,
849+
query_config: None,
848850
},
849851
..Default::default()
850852
};

0 commit comments

Comments
 (0)