Skip to content

Commit 1907b4e

Browse files
committed
[ENH] Add schema support to collection configuration
1 parent 77473be commit 1907b4e

File tree

16 files changed

+345
-40
lines changed

16 files changed

+345
-40
lines changed

chromadb/api/collection_configuration.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
UpdateMetadata,
77
EmbeddingFunction,
88
)
9+
from chromadb.base_types import CollectionSchema, ValueType
910
from chromadb.utils.embedding_functions import (
1011
known_embedding_functions,
1112
register_embedding_function,
@@ -41,6 +42,7 @@ class CollectionConfiguration(TypedDict, total=True):
4142
hnsw: Optional[HNSWConfiguration]
4243
spann: Optional[SpannConfiguration]
4344
embedding_function: Optional[EmbeddingFunction] # type: ignore
45+
schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]]
4446

4547

4648
def load_collection_configuration_from_json_str(
@@ -106,6 +108,7 @@ def load_collection_configuration_from_json(
106108
hnsw=hnsw_config,
107109
spann=spann_config,
108110
embedding_function=ef, # type: ignore
111+
schema=config_json_map.get("schema"),
109112
)
110113

111114

@@ -118,6 +121,7 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
118121
hnsw_config = config.get("hnsw")
119122
spann_config = config.get("spann")
120123
ef = config.get("embedding_function")
124+
schema = config.get("schema")
121125
else:
122126
try:
123127
hnsw_config = config.get_parameter("hnsw").value
@@ -172,6 +176,7 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
172176
"hnsw": hnsw_config,
173177
"spann": spann_config,
174178
"embedding_function": ef_config,
179+
"schema": schema,
175180
}
176181

177182

@@ -252,6 +257,7 @@ class CreateCollectionConfiguration(TypedDict, total=False):
252257
hnsw: Optional[CreateHNSWConfiguration]
253258
spann: Optional[CreateSpannConfiguration]
254259
embedding_function: Optional[EmbeddingFunction] # type: ignore
260+
schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]]
255261

256262

257263
def create_collection_configuration_from_legacy_collection_metadata(
@@ -379,6 +385,7 @@ def create_collection_configuration_to_json(
379385
"hnsw": hnsw_config,
380386
"spann": spann_config,
381387
"embedding_function": ef_config,
388+
"schema": config.get("schema"),
382389
}
383390

384391

@@ -450,6 +457,7 @@ class UpdateCollectionConfiguration(TypedDict, total=False):
450457
hnsw: Optional[UpdateHNSWConfiguration]
451458
spann: Optional[UpdateSpannConfiguration]
452459
embedding_function: Optional[EmbeddingFunction] # type: ignore
460+
schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]]
453461

454462

455463
def update_collection_configuration_from_legacy_collection_metadata(
@@ -504,8 +512,14 @@ def update_collection_configuration_to_json(
504512
"""Convert an UpdateCollectionConfiguration to a JSON-serializable dict"""
505513
hnsw_config = config.get("hnsw")
506514
spann_config = config.get("spann")
515+
schema = config.get("schema")
507516
ef = config.get("embedding_function")
508-
if hnsw_config is None and spann_config is None and ef is None:
517+
if (
518+
hnsw_config is None
519+
and spann_config is None
520+
and ef is None
521+
and schema is None
522+
):
509523
return {}
510524

511525
if hnsw_config is not None:
@@ -539,6 +553,7 @@ def update_collection_configuration_to_json(
539553
"hnsw": hnsw_config,
540554
"spann": spann_config,
541555
"embedding_function": ef_config,
556+
"schema": schema,
542557
}
543558

544559

@@ -687,13 +702,40 @@ def overwrite_collection_configuration(
687702
else:
688703
updated_embedding_function = update_ef
689704

705+
706+
existing_schema = existing_config.get("schema")
707+
new_diff_schema = update_config.get("schema")
708+
updated_schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]] = None
709+
if existing_schema is not None:
710+
if new_diff_schema is not None:
711+
updated_schema = overwrite_schema(existing_schema, new_diff_schema)
712+
else:
713+
updated_schema = existing_schema
714+
else:
715+
updated_schema = new_diff_schema
716+
690717
return CollectionConfiguration(
691718
hnsw=updated_hnsw_config,
692719
spann=updated_spann_config,
693720
embedding_function=updated_embedding_function,
721+
schema=updated_schema,
694722
)
695723

696724

725+
def overwrite_schema(
    existing_schema: Dict[str, Dict[ValueType, CollectionSchema]],
    new_diff_schema: Dict[str, Dict[ValueType, CollectionSchema]],
) -> Dict[str, Dict[ValueType, CollectionSchema]]:
    """Merge a schema diff into an existing schema and return the result.

    For every key in ``new_diff_schema``:

    * if the key already exists in ``existing_schema``, each value-type entry
      from the diff replaces the entry of the same value type, and value types
      that the diff does not mention are kept as they are;
    * otherwise the key is added with the diff's value-type entries.

    Neither argument is modified. The returned mapping and its per-key inner
    mappings are new dicts, so the caller's existing configuration and the
    update payload cannot be changed through the merged result.

    Args:
        existing_schema: The schema currently stored on the collection.
        new_diff_schema: The partial schema whose entries take precedence.

    Returns:
        A new schema containing the existing entries overlaid with the diff.
    """
    # Copy one level deep: the inner per-value-type dicts are what get
    # updated below, so they must not be shared with ``existing_schema``.
    merged: Dict[str, Dict[ValueType, CollectionSchema]] = {
        key: dict(per_type) for key, per_type in existing_schema.items()
    }
    for key, per_type in new_diff_schema.items():
        # ``setdefault`` gives brand-new keys their own inner dict, so the
        # result never aliases an inner dict of ``new_diff_schema`` either.
        merged.setdefault(key, {}).update(per_type)
    return merged
737+
738+
697739
def validate_embedding_function_conflict_on_create(
698740
embedding_function: Optional[EmbeddingFunction], # type: ignore
699741
configuration_ef: Optional[EmbeddingFunction], # type: ignore

chromadb/api/models/AsyncCollection.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,22 @@ async def add(
6060
ValueError: If you provide an id that already exists
6161
6262
"""
63-
add_request = self._validate_and_prepare_add_request(
63+
64+
curr_schema = self._model.get_configuration().get("schema")
65+
66+
add_request, new_attributes = self._validate_and_prepare_add_request(
6467
ids=ids,
6568
embeddings=embeddings,
6669
metadatas=metadatas,
6770
documents=documents,
6871
images=images,
6972
uris=uris,
73+
schema=curr_schema,
7074
)
7175

76+
if len(new_attributes.keys()) > 0:
77+
await self.modify(configuration={"schema": new_attributes})
78+
7279
await self._client._add(
7380
collection_id=self.id,
7481
ids=add_request["ids"],
@@ -313,15 +320,20 @@ async def update(
313320
Returns:
314321
None
315322
"""
316-
update_request = self._validate_and_prepare_update_request(
323+
curr_schema = self._model.get_configuration().get("schema")
324+
update_request, new_attributes = self._validate_and_prepare_update_request(
317325
ids=ids,
318326
embeddings=embeddings,
319327
metadatas=metadatas,
320328
documents=documents,
321329
images=images,
322330
uris=uris,
331+
schema=curr_schema,
323332
)
324333

334+
if len(new_attributes.keys()) > 0:
335+
await self.modify(configuration={"schema": new_attributes})
336+
325337
await self._client._update(
326338
collection_id=self.id,
327339
ids=update_request["ids"],
@@ -358,14 +370,18 @@ async def upsert(
358370
Returns:
359371
None
360372
"""
361-
upsert_request = self._validate_and_prepare_upsert_request(
373+
curr_schema = self._model.get_configuration().get("schema")
374+
upsert_request, new_attributes = self._validate_and_prepare_upsert_request(
362375
ids=ids,
363376
embeddings=embeddings,
364377
metadatas=metadatas,
365378
documents=documents,
366379
images=images,
367380
uris=uris,
381+
schema=curr_schema,
368382
)
383+
if len(new_attributes.keys()) > 0:
384+
await self.modify(configuration={"schema": new_attributes})
369385

370386
await self._client._upsert(
371387
collection_id=self.id,

chromadb/api/models/Collection.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,20 @@ def add(
7777
7878
"""
7979

80-
add_request = self._validate_and_prepare_add_request(
80+
curr_schema = self._model.get_configuration().get("schema")
81+
add_request, new_attributes = self._validate_and_prepare_add_request(
8182
ids=ids,
8283
embeddings=embeddings,
8384
metadatas=metadatas,
8485
documents=documents,
8586
images=images,
8687
uris=uris,
88+
schema=curr_schema,
8789
)
8890

91+
if len(new_attributes.keys()) > 0:
92+
self.modify(configuration={"schema": new_attributes})
93+
8994
self._client._add(
9095
collection_id=self.id,
9196
ids=add_request["ids"],
@@ -255,6 +260,7 @@ def modify(
255260
# Note there is a race condition here where the metadata can be updated
256261
# but another thread sees the cached local metadata.
257262
# TODO: fixme
263+
258264
self._client._modify(
259265
id=self.id,
260266
new_name=name,
@@ -317,15 +323,20 @@ def update(
317323
Returns:
318324
None
319325
"""
320-
update_request = self._validate_and_prepare_update_request(
326+
curr_schema = self._model.get_configuration().get("schema")
327+
update_request, new_attributes = self._validate_and_prepare_update_request(
321328
ids=ids,
322329
embeddings=embeddings,
323330
metadatas=metadatas,
324331
documents=documents,
325332
images=images,
326333
uris=uris,
334+
schema=curr_schema,
327335
)
328336

337+
if len(new_attributes.keys()) > 0:
338+
self.modify(configuration={"schema": new_attributes})
339+
329340
self._client._update(
330341
collection_id=self.id,
331342
ids=update_request["ids"],
@@ -362,15 +373,20 @@ def upsert(
362373
Returns:
363374
None
364375
"""
365-
upsert_request = self._validate_and_prepare_upsert_request(
376+
curr_schema = self._model.get_configuration().get("schema")
377+
upsert_request, new_attributes = self._validate_and_prepare_upsert_request(
366378
ids=ids,
367379
embeddings=embeddings,
368380
metadatas=metadatas,
369381
documents=documents,
370382
images=images,
371383
uris=uris,
384+
schema=curr_schema,
372385
)
373386

387+
if len(new_attributes.keys()) > 0:
388+
self.modify(configuration={"schema": new_attributes})
389+
374390
self._client._upsert(
375391
collection_id=self.id,
376392
ids=upsert_request["ids"],

0 commit comments

Comments
 (0)