11from typing import TypedDict , Dict , Any , Optional , cast , get_args
22import json
3+ import copy
34from chromadb .api .types import (
45 Space ,
56 CollectionMetadata ,
67 UpdateMetadata ,
78 EmbeddingFunction ,
9+ QueryConfig ,
810)
911from chromadb .utils .embedding_functions import (
1012 known_embedding_functions ,
@@ -41,6 +43,7 @@ class CollectionConfiguration(TypedDict, total=True):
4143 hnsw : Optional [HNSWConfiguration ]
4244 spann : Optional [SpannConfiguration ]
4345 embedding_function : Optional [EmbeddingFunction ] # type: ignore
46+ query_embedding_function : Optional [EmbeddingFunction ] # type: ignore
4447
4548
4649def load_collection_configuration_from_json_str (
@@ -64,6 +67,8 @@ def load_collection_configuration_from_json(
6467 spann_config = None
6568 ef_config = None
6669
70+ query_ef = None
71+
6772 # Process vector index configuration (HNSW or SPANN)
6873 if config_json_map .get ("hnsw" ) is not None :
6974 hnsw_config = cast (HNSWConfiguration , config_json_map ["hnsw" ])
@@ -100,13 +105,27 @@ def load_collection_configuration_from_json(
100105 f"Could not build embedding function { ef_config ['name' ]} from config { ef_config ['config' ]} : { e } "
101106 )
102107
108+ if config_json_map .get ("query_config" ) is not None :
109+ query_config = config_json_map ["query_config" ]
110+ query_ef_config = copy .deepcopy (ef_config )
111+ query_ef = known_embedding_functions [ef_name ]
112+ for k , v in query_config .items ():
113+ query_ef_config ["config" ][k ] = v
114+
115+ try :
116+ query_ef = query_ef .build_from_config (query_ef_config ["config" ]) # type: ignore
117+ except Exception as e :
118+ raise ValueError (
119+ f"Could not build query embedding function { query_ef_config ['name' ]} from config { query_ef_config ['config' ]} : { e } "
120+ )
103121 else :
104122 ef = None
105123
106124 return CollectionConfiguration (
107125 hnsw = hnsw_config ,
108126 spann = spann_config ,
109127 embedding_function = ef , # type: ignore
128+ query_embedding_function = query_ef , # type: ignore
110129 )
111130
112131
@@ -119,6 +138,7 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
119138 hnsw_config = config .get ("hnsw" )
120139 spann_config = config .get ("spann" )
121140 ef = config .get ("embedding_function" )
141+ query_ef = config .get ("query_embedding_function" )
122142 else :
123143 try :
124144 hnsw_config = config .get_parameter ("hnsw" ).value
@@ -148,11 +168,6 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
148168 if ef is None :
149169 ef = None
150170 ef_config = {"type" : "legacy" }
151- return {
152- "hnsw" : hnsw_config ,
153- "spann" : spann_config ,
154- "embedding_function" : ef_config ,
155- }
156171
157172 if ef is not None :
158173 try :
@@ -174,10 +189,28 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
174189 ef = None
175190 ef_config = {"type" : "legacy" }
176191
192+ query_ef_config : Dict [str , Any ] | None = None
193+ if query_ef is not None :
194+ try :
195+ query_ef_config = {
196+ "name" : query_ef .name (),
197+ "type" : "known" ,
198+ "config" : query_ef .get_config (),
199+ }
200+ except Exception as e :
201+ warnings .warn (
202+ f"legacy query embedding function config: { e } " ,
203+ DeprecationWarning ,
204+ stacklevel = 2 ,
205+ )
206+ query_ef = None
207+ query_ef_config = {"type" : "legacy" }
208+
177209 return {
178210 "hnsw" : hnsw_config ,
179211 "spann" : spann_config ,
180212 "embedding_function" : ef_config ,
213+ "query_embedding_function" : query_ef_config ,
181214 }
182215
183216
@@ -258,16 +291,7 @@ class CreateCollectionConfiguration(TypedDict, total=False):
258291 hnsw : Optional [CreateHNSWConfiguration ]
259292 spann : Optional [CreateSpannConfiguration ]
260293 embedding_function : Optional [EmbeddingFunction ] # type: ignore
261-
262-
263- def load_collection_configuration_from_create_collection_configuration (
264- config : CreateCollectionConfiguration ,
265- ) -> CollectionConfiguration :
266- return CollectionConfiguration (
267- hnsw = config .get ("hnsw" ),
268- spann = config .get ("spann" ),
269- embedding_function = config .get ("embedding_function" ),
270- )
294+ query_config : Optional [QueryConfig ]
271295
272296
273297def create_collection_configuration_from_legacy_collection_metadata (
@@ -301,13 +325,6 @@ def create_collection_configuration_from_legacy_metadata_dict(
301325 return CreateCollectionConfiguration (hnsw = hnsw_config )
302326
303327
304- def load_create_collection_configuration_from_json_str (
305- json_str : str ,
306- ) -> CreateCollectionConfiguration :
307- json_map = json .loads (json_str )
308- return load_create_collection_configuration_from_json (json_map )
309-
310-
311328# TODO: make warnings prettier and add link to migration docs
312329def load_create_collection_configuration_from_json (
313330 json_map : Dict [str , Any ]
@@ -353,6 +370,7 @@ def create_collection_configuration_to_json(
353370) -> Dict [str , Any ]:
354371 """Convert a CreateCollection configuration to a JSON-serializable dict"""
355372 ef_config : Dict [str , Any ] | None = None
373+ query_config : Dict [str , Any ] | None = None
356374 hnsw_config = config .get ("hnsw" )
357375 spann_config = config .get ("spann" )
358376 if hnsw_config is not None :
@@ -389,6 +407,15 @@ def create_collection_configuration_to_json(
389407 "config" : ef .get_config (),
390408 }
391409 register_embedding_function (type (ef )) # type: ignore
410+
411+ q = config .get ("query_config" )
412+ if q is not None :
413+ if q .name () == ef .name ():
414+ query_config = q .get_config ()
415+ else :
416+ raise ValueError (
417+ f"query config name { q .name ()} does not match embedding function name { ef .name ()} "
418+ )
392419 except Exception as e :
393420 warnings .warn (
394421 f"legacy embedding function config: { e } " ,
@@ -402,6 +429,7 @@ def create_collection_configuration_to_json(
402429 "hnsw" : hnsw_config ,
403430 "spann" : spann_config ,
404431 "embedding_function" : ef_config ,
432+ "query_config" : query_config ,
405433 }
406434
407435
@@ -473,6 +501,7 @@ class UpdateCollectionConfiguration(TypedDict, total=False):
473501 hnsw : Optional [UpdateHNSWConfiguration ]
474502 spann : Optional [UpdateSpannConfiguration ]
475503 embedding_function : Optional [EmbeddingFunction ] # type: ignore
504+ query_config : Optional [QueryConfig ]
476505
477506
478507def update_collection_configuration_from_legacy_collection_metadata (
@@ -528,7 +557,9 @@ def update_collection_configuration_to_json(
528557 hnsw_config = config .get ("hnsw" )
529558 spann_config = config .get ("spann" )
530559 ef = config .get ("embedding_function" )
531- if hnsw_config is None and spann_config is None and ef is None :
560+ q = config .get ("query_config" )
561+ query_config : Dict [str , Any ] | None = None
562+ if hnsw_config is None and spann_config is None and ef is None and q is None :
532563 return {}
533564
534565 if hnsw_config is not None :
@@ -555,13 +586,21 @@ def update_collection_configuration_to_json(
555586 "config" : ef .get_config (),
556587 }
557588 register_embedding_function (type (ef )) # type: ignore
589+ if q is not None :
590+ if q .name () == ef .name ():
591+ query_config = q .get_config ()
592+ else :
593+ raise ValueError (
594+ f"query config name { q .name ()} does not match embedding function name { ef .name ()} "
595+ )
558596 else :
559597 ef_config = None
560598
561599 return {
562600 "hnsw" : hnsw_config ,
563601 "spann" : spann_config ,
564602 "embedding_function" : ef_config ,
603+ "query_config" : query_config ,
565604 }
566605
567606
@@ -710,10 +749,26 @@ def overwrite_collection_configuration(
710749 else :
711750 updated_embedding_function = update_ef
712751
752+ query_ef = None
753+ if updated_embedding_function is not None :
754+ q = update_config .get ("query_config" )
755+ if q is not None :
756+ if q .name () != updated_embedding_function .name ():
757+ raise ValueError (
758+ f"query config name { q .name ()} does not match embedding function name { updated_embedding_function .name ()} "
759+ )
760+ else :
761+ ef_config = copy .deepcopy (updated_embedding_function .get_config ())
762+ query_config = q .get_config ()
763+ for k , v in query_config .items ():
764+ ef_config [k ] = v
765+ query_ef = updated_embedding_function .build_from_config (ef_config )
766+
713767 return CollectionConfiguration (
714768 hnsw = updated_hnsw_config ,
715769 spann = updated_spann_config ,
716770 embedding_function = updated_embedding_function ,
771+ query_embedding_function = query_ef ,
717772 )
718773
719774
0 commit comments