11from typing import TypedDict , Dict , Any , Optional , cast , get_args
22import json
3+ import copy
34from chromadb .api .types import (
45 Space ,
56 CollectionMetadata ,
67 UpdateMetadata ,
78 EmbeddingFunction ,
9+ QueryConfig ,
810)
911from chromadb .utils .embedding_functions import (
1012 known_embedding_functions ,
@@ -41,6 +43,7 @@ class CollectionConfiguration(TypedDict, total=True):
4143 hnsw : Optional [HNSWConfiguration ]
4244 spann : Optional [SpannConfiguration ]
4345 embedding_function : Optional [EmbeddingFunction ] # type: ignore
46+ query_embedding_function : Optional [EmbeddingFunction ] # type: ignore
4447
4548
4649def load_collection_configuration_from_json_str (
@@ -64,6 +67,8 @@ def load_collection_configuration_from_json(
6467 spann_config = None
6568 ef_config = None
6669
70+ query_ef = None
71+
6772 # Process vector index configuration (HNSW or SPANN)
6873 if config_json_map .get ("hnsw" ) is not None :
6974 hnsw_config = cast (HNSWConfiguration , config_json_map ["hnsw" ])
@@ -100,13 +105,27 @@ def load_collection_configuration_from_json(
100105 f"Could not build embedding function { ef_config ['name' ]} from config { ef_config ['config' ]} : { e } "
101106 )
102107
108+ if config_json_map .get ("query_config" ) is not None :
109+ query_config = config_json_map ["query_config" ]
110+ query_ef_config = copy .deepcopy (ef_config )
111+ query_ef = known_embedding_functions [ef_name ]
112+ for k , v in query_config .items ():
113+ query_ef_config ["config" ][k ] = v
114+
115+ try :
116+ query_ef = query_ef .build_from_config (query_ef_config ["config" ]) # type: ignore
117+ except Exception as e :
118+ raise ValueError (
119+ f"Could not build query embedding function { query_ef_config ['name' ]} from config { query_ef_config ['config' ]} : { e } "
120+ )
103121 else :
104122 ef = None
105123
106124 return CollectionConfiguration (
107125 hnsw = hnsw_config ,
108126 spann = spann_config ,
109127 embedding_function = ef , # type: ignore
128+ query_embedding_function = query_ef , # type: ignore
110129 )
111130
112131
@@ -258,16 +277,7 @@ class CreateCollectionConfiguration(TypedDict, total=False):
258277 hnsw : Optional [CreateHNSWConfiguration ]
259278 spann : Optional [CreateSpannConfiguration ]
260279 embedding_function : Optional [EmbeddingFunction ] # type: ignore
261-
262-
263- def load_collection_configuration_from_create_collection_configuration (
264- config : CreateCollectionConfiguration ,
265- ) -> CollectionConfiguration :
266- return CollectionConfiguration (
267- hnsw = config .get ("hnsw" ),
268- spann = config .get ("spann" ),
269- embedding_function = config .get ("embedding_function" ),
270- )
280+ query_config : Optional [QueryConfig ]
271281
272282
273283def create_collection_configuration_from_legacy_collection_metadata (
@@ -301,13 +311,6 @@ def create_collection_configuration_from_legacy_metadata_dict(
301311 return CreateCollectionConfiguration (hnsw = hnsw_config )
302312
303313
304- def load_create_collection_configuration_from_json_str (
305- json_str : str ,
306- ) -> CreateCollectionConfiguration :
307- json_map = json .loads (json_str )
308- return load_create_collection_configuration_from_json (json_map )
309-
310-
311314# TODO: make warnings prettier and add link to migration docs
312315def load_create_collection_configuration_from_json (
313316 json_map : Dict [str , Any ]
@@ -353,6 +356,7 @@ def create_collection_configuration_to_json(
353356) -> Dict [str , Any ]:
354357 """Convert a CreateCollection configuration to a JSON-serializable dict"""
355358 ef_config : Dict [str , Any ] | None = None
359+ query_config : Dict [str , Any ] | None = None
356360 hnsw_config = config .get ("hnsw" )
357361 spann_config = config .get ("spann" )
358362 if hnsw_config is not None :
@@ -389,6 +393,15 @@ def create_collection_configuration_to_json(
389393 "config" : ef .get_config (),
390394 }
391395 register_embedding_function (type (ef )) # type: ignore
396+
397+ q = config .get ("query_config" )
398+ if q is not None :
399+ if q .name () == ef .name ():
400+ query_config = q .get_config ()
401+ else :
402+ raise ValueError (
403+ f"query config name { q .name ()} does not match embedding function name { ef .name ()} "
404+ )
392405 except Exception as e :
393406 warnings .warn (
394407 f"legacy embedding function config: { e } " ,
@@ -402,6 +415,7 @@ def create_collection_configuration_to_json(
402415 "hnsw" : hnsw_config ,
403416 "spann" : spann_config ,
404417 "embedding_function" : ef_config ,
418+ "query_config" : query_config ,
405419 }
406420
407421
@@ -473,6 +487,7 @@ class UpdateCollectionConfiguration(TypedDict, total=False):
473487 hnsw : Optional [UpdateHNSWConfiguration ]
474488 spann : Optional [UpdateSpannConfiguration ]
475489 embedding_function : Optional [EmbeddingFunction ] # type: ignore
490+ query_config : Optional [QueryConfig ]
476491
477492
478493def update_collection_configuration_from_legacy_collection_metadata (
@@ -528,6 +543,7 @@ def update_collection_configuration_to_json(
528543 hnsw_config = config .get ("hnsw" )
529544 spann_config = config .get ("spann" )
530545 ef = config .get ("embedding_function" )
546+ query_config : Dict [str , Any ] | None = None
531547 if hnsw_config is None and spann_config is None and ef is None :
532548 return {}
533549
@@ -555,13 +571,22 @@ def update_collection_configuration_to_json(
555571 "config" : ef .get_config (),
556572 }
557573 register_embedding_function (type (ef )) # type: ignore
574+ q = config .get ("query_config" )
575+ if q is not None :
576+ if q .name () == ef .name ():
577+ query_config = q .get_config ()
578+ else :
579+ raise ValueError (
580+ f"query config name { q .name ()} does not match embedding function name { ef .name ()} "
581+ )
558582 else :
559583 ef_config = None
560584
561585 return {
562586 "hnsw" : hnsw_config ,
563587 "spann" : spann_config ,
564588 "embedding_function" : ef_config ,
589+ "query_config" : query_config ,
565590 }
566591
567592
@@ -710,10 +735,26 @@ def overwrite_collection_configuration(
710735 else :
711736 updated_embedding_function = update_ef
712737
738+ query_ef = None
739+ if updated_embedding_function is not None :
740+ q = update_config .get ("query_config" )
741+ if q is not None :
742+ if q .name () != updated_embedding_function .name ():
743+ raise ValueError (
744+ f"query config name { q .name ()} does not match embedding function name { updated_embedding_function .name ()} "
745+ )
746+ else :
747+ ef_config = copy .deepcopy (updated_embedding_function .get_config ())
748+ query_config = q .get_config ()
749+ for k , v in query_config .items ():
750+ ef_config [k ] = v
751+ query_ef = updated_embedding_function .build_from_config (ef_config )
752+
713753 return CollectionConfiguration (
714754 hnsw = updated_hnsw_config ,
715755 spann = updated_spann_config ,
716756 embedding_function = updated_embedding_function ,
757+ query_embedding_function = query_ef ,
717758 )
718759
719760
0 commit comments