Skip to content

Commit 4a22684

Browse files
committed
[BUG] Replace empty config objects with None during deserializing
1 parent 730ed06 commit 4a22684

14 files changed

+252
-61
lines changed

chromadb/api/collection_configuration.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
known_embedding_functions,
1111
register_embedding_function,
1212
)
13+
from chromadb.utils.embedding_functions.malformed_embedding_function import (
14+
MalformedEmbeddingFunction,
15+
)
1316
from multiprocessing import cpu_count
1417
import warnings
1518

@@ -83,11 +86,23 @@ def load_collection_configuration_from_json(
8386
else:
8487
try:
8588
ef = known_embedding_functions[ef_config["name"]]
86-
ef = ef.build_from_config(ef_config["config"]) # type: ignore
8789
except KeyError:
8890
raise ValueError(
8991
f"Embedding function {ef_config['name']} not found. Add @register_embedding_function decorator to the class definition."
9092
)
93+
try:
94+
ef = ef.build_from_config(ef_config["config"]) # type: ignore
95+
except Exception as e:
96+
warnings.warn(
97+
f"Could not build embedding function {ef_config['name']} from config {ef_config['config']}: {e} \
98+
Returning a MalformedEmbeddingFunction",
99+
UserWarning,
100+
stacklevel=2,
101+
)
102+
ef = MalformedEmbeddingFunction( # type: ignore
103+
malformed_ef_name=ef_config["name"],
104+
config=ef_config["config"],
105+
)
91106
else:
92107
ef = None
93108

@@ -649,7 +664,10 @@ def overwrite_embedding_function(
649664
return existing_embedding_function
650665

651666
# Validate function compatibility
652-
if existing_embedding_function.name() != update_embedding_function.name():
667+
if (
668+
existing_embedding_function.name() != update_embedding_function.name()
669+
and not isinstance(existing_embedding_function, MalformedEmbeddingFunction)
670+
):
653671
raise ValueError(
654672
f"Cannot update embedding function: incompatible types "
655673
f"({existing_embedding_function.name()} vs {update_embedding_function.name()})"

chromadb/test/configurations/test_collection_configuration.py

Lines changed: 74 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,17 @@
1616
UpdateHNSWConfiguration,
1717
CreateSpannConfiguration,
1818
UpdateSpannConfiguration,
19-
load_update_collection_configuration_from_json,
2019
SpannConfiguration,
2120
overwrite_spann_configuration,
21+
overwrite_collection_configuration,
2222
)
2323
import json
2424
import os
2525
from chromadb.utils.embedding_functions import register_embedding_function
2626
from chromadb.test.conftest import ClientFactories
27+
from chromadb.utils.embedding_functions.malformed_embedding_function import (
28+
MalformedEmbeddingFunction,
29+
)
2730

2831

2932
# Check if we are running in a mode where SPANN is disabled
@@ -709,19 +712,11 @@ def test_spann_update_from_json(client: ClientAPI) -> None:
709712
configuration={"spann": initial_spann},
710713
)
711714

712-
# Create JSON for update
713-
update_json = """
714-
{
715-
"spann": {
716-
"search_nprobe": 15,
717-
"ef_search": 200
718-
}
719-
}
720-
"""
721-
722-
# Parse JSON and create update configuration
723-
update_config = load_update_collection_configuration_from_json(
724-
json.loads(update_json)
715+
update_config = UpdateCollectionConfiguration(
716+
spann=UpdateSpannConfiguration(
717+
search_nprobe=15,
718+
ef_search=200,
719+
)
725720
)
726721

727722
# Apply the update
@@ -830,3 +825,68 @@ def test_default_collection_creation(client: ClientAPI) -> None:
830825
ef = config.get("embedding_function")
831826
assert ef is not None
832827
assert ef.name() == "default"
828+
829+
830+
def test_malformed_embedding_function() -> None:
831+
"""Test that on an invalid configuration, the embedding function is a MalformedEmbeddingFunction"""
832+
invalid_config: Dict[str, Any] = {
833+
"hnsw": {
834+
"space": "l2",
835+
"ef_construction": 100,
836+
"ef_search": 100,
837+
"max_neighbors": 16,
838+
"resize_factor": 1.2,
839+
"sync_threshold": 1000,
840+
},
841+
"spann": None,
842+
"embedding_function": {
843+
"name": "custom_ef",
844+
"type": "known",
845+
"config": {},
846+
},
847+
}
848+
loaded_config = load_collection_configuration_from_json(invalid_config)
849+
assert loaded_config is not None
850+
loaded_ef = loaded_config.get("embedding_function")
851+
assert loaded_ef is not None
852+
assert isinstance(loaded_ef, MalformedEmbeddingFunction)
853+
assert loaded_ef.malformed_ef_name == "custom_ef"
854+
855+
assert isinstance(invalid_config["embedding_function"], dict)
856+
assert isinstance(invalid_config["embedding_function"]["config"], dict)
857+
assert loaded_ef.config == invalid_config["embedding_function"]["config"]
858+
859+
860+
def test_update_ef_when_malformed() -> None:
861+
invalid_config: Dict[str, Any] = {
862+
"hnsw": {
863+
"space": "l2",
864+
"ef_construction": 100,
865+
"ef_search": 100,
866+
"max_neighbors": 16,
867+
"resize_factor": 1.2,
868+
"sync_threshold": 1000,
869+
},
870+
"spann": None,
871+
"embedding_function": {
872+
"name": "custom_ef",
873+
"type": "known",
874+
"config": {},
875+
},
876+
}
877+
loaded_config = load_collection_configuration_from_json(invalid_config)
878+
assert loaded_config is not None
879+
assert isinstance(
880+
loaded_config.get("embedding_function"), MalformedEmbeddingFunction
881+
)
882+
883+
overwrite_config: UpdateCollectionConfiguration = {
884+
"embedding_function": CustomEmbeddingFunction(dim=10),
885+
}
886+
887+
new_config = overwrite_collection_configuration(loaded_config, overwrite_config)
888+
assert new_config is not None
889+
ef = new_config.get("embedding_function")
890+
assert ef is not None
891+
assert isinstance(ef, CustomEmbeddingFunction)
892+
assert ef.get_config() == {"dim": 10}

chromadb/types.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,13 @@
88
from uuid import UUID
99
from enum import Enum
1010
from pydantic import BaseModel
11-
import warnings
1211

1312
from chromadb.api.configuration import (
1413
ConfigurationInternal,
1514
)
1615
from chromadb.serde import BaseModelJSONSerializable
1716
from chromadb.api.collection_configuration import (
1817
CollectionConfiguration,
19-
HNSWConfiguration,
20-
SpannConfiguration,
2118
collection_configuration_to_json,
2219
load_collection_configuration_from_json,
2320
)
@@ -149,15 +146,8 @@ def get_configuration(self) -> CollectionConfiguration:
149146
try:
150147
return load_collection_configuration_from_json(self.configuration_json)
151148
except Exception as e:
152-
warnings.warn(
153-
f"Server does not respond with configuration_json. Please update server: {e}",
154-
DeprecationWarning,
155-
stacklevel=2,
156-
)
157-
return CollectionConfiguration(
158-
hnsw=HNSWConfiguration(),
159-
spann=SpannConfiguration(),
160-
embedding_function=None,
149+
raise ValueError(
150+
f"Could not deserialize configuration_json: {e}",
161151
)
162152

163153
def set_configuration(self, configuration: CollectionConfiguration) -> None:
@@ -175,19 +165,12 @@ def get_model_fields(self) -> Dict[Any, Any]:
175165
@override
176166
def from_json(cls, json_map: Dict[str, Any]) -> Self:
177167
"""Deserializes a Collection object from JSON"""
178-
configuration: CollectionConfiguration = {
179-
"hnsw": {},
180-
"spann": {},
181-
"embedding_function": None,
182-
}
183168
try:
184169
configuration_json = json_map.get("configuration_json", None)
185170
configuration = load_collection_configuration_from_json(configuration_json)
186171
except Exception as e:
187-
warnings.warn(
188-
f"Server does not respond with configuration_json. Please update server: {e}",
189-
DeprecationWarning,
190-
stacklevel=2,
172+
raise ValueError(
173+
f"Could not deserialize configuration_json: {e}",
191174
)
192175
return cls(
193176
id=json_map["id"],

chromadb/utils/embedding_functions/baseten_embedding_function.py

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
import os
2-
from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction
3-
from chromadb.api.types import Documents, EmbeddingFunction
2+
from chromadb.utils.embedding_functions.openai_embedding_function import (
3+
OpenAIEmbeddingFunction,
4+
)
45
from typing import Dict, Any, Optional
6+
import warnings
57

6-
class BasetenEmbeddingFunction(OpenAIEmbeddingFunction):
78

9+
class BasetenEmbeddingFunction(OpenAIEmbeddingFunction):
810
def __init__(
9-
self,
10-
api_key: Optional[str],
11-
api_base: str,
12-
api_key_env_var: str = "CHROMA_BASETEN_API_KEY",
13-
):
11+
self,
12+
api_key: Optional[str],
13+
api_base: str,
14+
api_key_env_var: str = "CHROMA_BASETEN_API_KEY",
15+
):
1416
"""
1517
Initialize the BasetenEmbeddingFunction.
1618
Args:
@@ -25,33 +27,35 @@ def __init__(
2527
"The openai python package is not installed. Please install it with `pip install openai`"
2628
)
2729

30+
if api_key is not None:
31+
warnings.warn(
32+
"Direct api_key configuration will not be persisted. "
33+
"Please use environment variables via api_key_env_var for persistent storage.",
34+
DeprecationWarning,
35+
)
36+
2837
self.api_key_env_var = api_key_env_var
2938
# Prioritize api_key argument, then environment variable
3039
resolved_api_key = api_key or os.getenv(api_key_env_var)
3140
if not resolved_api_key:
32-
raise ValueError(f"API key not provided and {api_key_env_var} environment variable is not set.")
41+
raise ValueError(
42+
f"API key not provided and {api_key_env_var} environment variable is not set."
43+
)
3344
self.api_key = resolved_api_key
3445
if not api_base:
3546
raise ValueError("The api_base argument must be provided.")
3647
self.api_base = api_base
3748
self.model_name = "baseten-embedding-model"
3849
self.dimensions = None
3950

40-
self.client = openai.OpenAI(
41-
api_key=self.api_key,
42-
base_url=self.api_base
43-
)
44-
51+
self.client = openai.OpenAI(api_key=self.api_key, base_url=self.api_base)
52+
4553
@staticmethod
4654
def name() -> str:
4755
return "baseten"
48-
56+
4957
def get_config(self) -> Dict[str, Any]:
50-
return {
51-
"api_base": self.api_base,
52-
"api_key_env_var": self.api_key_env_var
53-
}
54-
58+
return {"api_base": self.api_base, "api_key_env_var": self.api_key_env_var}
5559

5660
@staticmethod
5761
def build_from_config(config: Dict[str, Any]) -> "BasetenEmbeddingFunction":
@@ -68,16 +72,20 @@ def build_from_config(config: Dict[str, Any]) -> "BasetenEmbeddingFunction":
6872
api_key_env_var = config.get("api_key_env_var")
6973
api_base = config.get("api_base")
7074
if api_key_env_var is None or api_base is None:
71-
raise ValueError("Missing 'api_key_env_var' or 'api_base' in configuration for BasetenEmbeddingFunction.")
75+
raise ValueError(
76+
"Missing 'api_key_env_var' or 'api_base' in configuration for BasetenEmbeddingFunction."
77+
)
7278

7379
# Note: We rely on the __init__ method to handle potential missing api_key
7480
# by checking the environment variable if the config value is None.
7581
# However, api_base must be present either in config or have a default.
7682
if api_base is None:
77-
raise ValueError("Missing 'api_base' in configuration for BasetenEmbeddingFunction.")
83+
raise ValueError(
84+
"Missing 'api_base' in configuration for BasetenEmbeddingFunction."
85+
)
7886

7987
return BasetenEmbeddingFunction(
80-
api_key=None, # Pass None if not in config, __init__ will check env var
88+
api_key=None, # Pass None if not in config, __init__ will check env var
8189
api_base=api_base,
8290
api_key_env_var=api_key_env_var,
83-
)
91+
)

chromadb/utils/embedding_functions/cloudflare_workers_ai_embedding_function.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import os
99
from chromadb.utils.embedding_functions.schemas import validate_config_schema
1010
from typing import cast
11+
import warnings
1112

1213
BASE_URL = "https://api.cloudflare.com/client/v4/accounts"
1314
GATEWAY_BASE_URL = "https://gateway.ai.cloudflare.com/v1"
@@ -43,6 +44,13 @@ def __init__(
4344
raise ValueError(
4445
"The httpx python package is not installed. Please install it with `pip install httpx`"
4546
)
47+
48+
if api_key is not None:
49+
warnings.warn(
50+
"Direct api_key configuration will not be persisted. "
51+
"Please use environment variables via api_key_env_var for persistent storage.",
52+
DeprecationWarning,
53+
)
4654
self.model_name = model_name
4755
self.account_id = account_id
4856
self.api_key_env_var = api_key_env_var

chromadb/utils/embedding_functions/cohere_embedding_function.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import base64
1414
import io
1515
import importlib
16+
import warnings
1617

1718

1819
class CohereEmbeddingFunction(EmbeddingFunction[Embeddable]):
@@ -36,6 +37,12 @@ def __init__(
3637
"The PIL python package is not installed. Please install it with `pip install pillow`"
3738
)
3839

40+
if api_key is not None:
41+
warnings.warn(
42+
"Direct api_key configuration will not be persisted. "
43+
"Please use environment variables via api_key_env_var for persistent storage.",
44+
DeprecationWarning,
45+
)
3946
self.api_key_env_var = api_key_env_var
4047
self.api_key = api_key or os.getenv(api_key_env_var)
4148
if not self.api_key:

0 commit comments

Comments
 (0)