|
1 | | -from typing import TypedDict, Dict, Any, Optional, cast, get_args |
| 1 | +from typing import TypedDict, Dict, Any, Optional, cast, get_args, Literal |
2 | 2 | import json |
3 | 3 | import copy |
4 | 4 | from chromadb.api.types import ( |
|
15 | 15 | from multiprocessing import cpu_count |
16 | 16 | import warnings |
17 | 17 |
|
| 18 | +ValueType = Literal["int", "float", "string", "boolean"] |
| 19 | + |
| 20 | + |
| 21 | +class CollectionSchema(TypedDict): |
| 22 | + value_type: ValueType |
| 23 | + metadata_index: bool |
| 24 | + |
18 | 25 |
|
19 | 26 | class HNSWConfiguration(TypedDict, total=False): |
20 | 27 | space: Space |
@@ -44,6 +51,7 @@ class CollectionConfiguration(TypedDict, total=True): |
44 | 51 | spann: Optional[SpannConfiguration] |
45 | 52 | embedding_function: Optional[EmbeddingFunction] # type: ignore |
46 | 53 | query_embedding_function: Optional[EmbeddingFunction] # type: ignore |
| 54 | + schema: Optional[Dict[str, CollectionSchema]] |
47 | 55 |
|
48 | 56 |
|
49 | 57 | def load_collection_configuration_from_json_str( |
@@ -126,6 +134,7 @@ def load_collection_configuration_from_json( |
126 | 134 | spann=spann_config, |
127 | 135 | embedding_function=ef, # type: ignore |
128 | 136 | query_embedding_function=query_ef, # type: ignore |
| 137 | + schema=config_json_map.get("schema"), |
129 | 138 | ) |
130 | 139 |
|
131 | 140 |
|
@@ -278,6 +287,7 @@ class CreateCollectionConfiguration(TypedDict, total=False): |
278 | 287 | spann: Optional[CreateSpannConfiguration] |
279 | 288 | embedding_function: Optional[EmbeddingFunction] # type: ignore |
280 | 289 | query_config: Optional[QueryConfig] |
| 290 | + schema: Optional[Dict[str, CollectionSchema]] |
281 | 291 |
|
282 | 292 |
|
283 | 293 | def create_collection_configuration_from_legacy_collection_metadata( |
@@ -416,6 +426,7 @@ def create_collection_configuration_to_json( |
416 | 426 | "spann": spann_config, |
417 | 427 | "embedding_function": ef_config, |
418 | 428 | "query_config": query_config, |
| 429 | + "schema": config.get("schema"), |
419 | 430 | } |
420 | 431 |
|
421 | 432 |
|
@@ -488,6 +499,7 @@ class UpdateCollectionConfiguration(TypedDict, total=False): |
488 | 499 | spann: Optional[UpdateSpannConfiguration] |
489 | 500 | embedding_function: Optional[EmbeddingFunction] # type: ignore |
490 | 501 | query_config: Optional[QueryConfig] |
| 502 | + schema: Optional[Dict[str, CollectionSchema]] |
491 | 503 |
|
492 | 504 |
|
493 | 505 | def update_collection_configuration_from_legacy_collection_metadata( |
@@ -587,6 +599,7 @@ def update_collection_configuration_to_json( |
587 | 599 | "spann": spann_config, |
588 | 600 | "embedding_function": ef_config, |
589 | 601 | "query_config": query_config, |
| 602 | + "schema": config.get("schema"), |
590 | 603 | } |
591 | 604 |
|
592 | 605 |
|
@@ -750,14 +763,34 @@ def overwrite_collection_configuration( |
750 | 763 | ef_config[k] = v |
751 | 764 | query_ef = updated_embedding_function.build_from_config(ef_config) |
752 | 765 |
|
| 766 | + existing_schema = existing_config.get("schema") |
| 767 | + new_diff_schema = update_config.get("schema") |
| 768 | + updated_schema: Optional[Dict[str, CollectionSchema]] = None |
| 769 | + if existing_schema is not None: |
| 770 | + if new_diff_schema is not None: |
| 771 | + updated_schema = overwrite_schema(existing_schema, new_diff_schema) |
| 772 | + else: |
| 773 | + updated_schema = existing_schema |
| 774 | + else: |
| 775 | + updated_schema = new_diff_schema |
| 776 | + |
753 | 777 | return CollectionConfiguration( |
754 | 778 | hnsw=updated_hnsw_config, |
755 | 779 | spann=updated_spann_config, |
756 | 780 | embedding_function=updated_embedding_function, |
757 | 781 | query_embedding_function=query_ef, |
| 782 | + schema=updated_schema, |
758 | 783 | ) |
759 | 784 |
|
760 | 785 |
|
| 786 | +def overwrite_schema( |
| 787 | + existing_schema: Dict[str, CollectionSchema], |
| 788 | + new_diff_schema: Dict[str, CollectionSchema], |
| 789 | +) -> Dict[str, CollectionSchema]: |
| 790 | + """Overwrite a schema with a new configuration""" |
| 791 | + return {**existing_schema, **new_diff_schema} |
| 792 | + |
| 793 | + |
761 | 794 | def validate_embedding_function_conflict_on_create( |
762 | 795 | embedding_function: Optional[EmbeddingFunction], # type: ignore |
763 | 796 | configuration_ef: Optional[EmbeddingFunction], # type: ignore |
|
0 commit comments