23 changes: 0 additions & 23 deletions chromadb/api/collection_configuration.py
@@ -99,7 +99,6 @@ def load_collection_configuration_from_json(
raise ValueError(
f"Could not build embedding function {ef_config['name']} from config {ef_config['config']}: {e}"
)

else:
ef = None

@@ -148,11 +147,6 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
if ef is None:
ef = None
ef_config = {"type": "legacy"}
return {
"hnsw": hnsw_config,
"spann": spann_config,
"embedding_function": ef_config,
}

if ef is not None:
try:
@@ -260,16 +254,6 @@ class CreateCollectionConfiguration(TypedDict, total=False):
embedding_function: Optional[EmbeddingFunction] # type: ignore


def load_collection_configuration_from_create_collection_configuration(
config: CreateCollectionConfiguration,
) -> CollectionConfiguration:
return CollectionConfiguration(
hnsw=config.get("hnsw"),
spann=config.get("spann"),
embedding_function=config.get("embedding_function"),
)


def create_collection_configuration_from_legacy_collection_metadata(
metadata: CollectionMetadata,
) -> CreateCollectionConfiguration:
@@ -301,13 +285,6 @@ def create_collection_configuration_from_legacy_metadata_dict(
return CreateCollectionConfiguration(hnsw=hnsw_config)


def load_create_collection_configuration_from_json_str(
json_str: str,
) -> CreateCollectionConfiguration:
json_map = json.loads(json_str)
return load_create_collection_configuration_from_json(json_map)


# TODO: make warnings prettier and add link to migration docs
def load_create_collection_configuration_from_json(
json_map: Dict[str, Any]
35 changes: 27 additions & 8 deletions chromadb/api/models/CollectionCommon.py
@@ -313,7 +313,9 @@ def _validate_and_prepare_query_request(
# Prepare
if query_records["embeddings"] is None:
validate_record_set_for_embedding(record_set=query_records)
request_embeddings = self._embed_record_set(record_set=query_records)
request_embeddings = self._embed_record_set(
record_set=query_records, is_query=True
)
else:
request_embeddings = query_records["embeddings"]

@@ -531,7 +533,10 @@ def _update_model_after_modify_success(
)

def _embed_record_set(
self, record_set: BaseRecordSet, embeddable_fields: Optional[Set[str]] = None
self,
record_set: BaseRecordSet,
embeddable_fields: Optional[Set[str]] = None,
is_query: bool = False,
Contributor: This works, but just for discussion: an alternative approach is to have separate methods for the read and write paths. Any idea here?

Collaborator: Yeah, +1

) -> Embeddings:
if embeddable_fields is None:
embeddable_fields = get_default_embeddable_record_set_fields()
@@ -545,27 +550,41 @@ def _embed_record_set(
"You must set a data loader on the collection if loading from URIs."
)
return self._embed(
input=self._data_loader(uris=cast(URIs, record_set[field])) # type: ignore[literal-required]
input=self._data_loader(uris=cast(URIs, record_set[field])), # type: ignore[literal-required]
is_query=is_query,
)
else:
return self._embed(input=record_set[field]) # type: ignore[literal-required]
return self._embed(
input=record_set[field], # type: ignore[literal-required]
is_query=is_query,
)
raise ValueError(
"Record does not contain any non-None fields that can be embedded."
f"Embeddable Fields: {embeddable_fields}"
f"Record Fields: {record_set}"
)

def _embed(self, input: Any) -> Embeddings:
def _embed(self, input: Any, is_query: bool = False) -> Embeddings:
if self._embedding_function is not None and not isinstance(
self._embedding_function, ef.DefaultEmbeddingFunction
):
return self._embedding_function(input=input)
if is_query:
return self._embedding_function.embed_query(input=input)
else:
return self._embedding_function(input=input)
Comment on lines +571 to +574

Contributor: How does this work for all embedding functions we support?

Contributor (Author): Since there's a default that assumes the query config doesn't exist, none of the existing EFs will break.

config_ef = self.configuration.get("embedding_function")
if config_ef is not None:
return config_ef(input=input)
if is_query:
return config_ef.embed_query(input=input)
else:
return config_ef(input=input)
if self._embedding_function is None:
raise ValueError(
"You must provide an embedding function to compute embeddings."
"https://docs.trychroma.com/guides/embeddings"
)
return self._embedding_function(input=input)
if is_query:
return self._embedding_function.embed_query(input=input)
else:
return self._embedding_function(input=input)
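
Aside on the review threads above: because the protocol-level `embed_query` default simply delegates to `__call__`, an embedding function that predates this PR keeps working on both the write and read paths. A minimal sketch, not part of the PR — the `DocOnlyEF` class and its fixed vector are made up for illustration and mirror the style of the test EFs later in this diff:

```python
from typing import Any, Dict, cast

import numpy as np

from chromadb.api.types import Embeddable, EmbeddingFunction, Embeddings


class DocOnlyEF(EmbeddingFunction[Embeddable]):
    """Hypothetical EF that only defines __call__ and never overrides embed_query."""

    def __init__(self) -> None:
        pass

    def __call__(self, input: Embeddable) -> Embeddings:
        # Same fixed embedding for any input, mirroring the test EFs in this PR.
        return cast(Embeddings, np.array([[0.1, 0.2, 0.3]], dtype=np.float32))

    @staticmethod
    def name() -> str:
        return "doc_only_ef"

    @staticmethod
    def build_from_config(config: Dict[str, Any]) -> "DocOnlyEF":
        return DocOnlyEF()

    def get_config(self) -> Dict[str, Any]:
        return {}


ef = DocOnlyEF()
# Write path: _embed(..., is_query=False) calls ef(input=...).
doc_embs = ef(input=["a document"])
# Read path: _embed(..., is_query=True) calls ef.embed_query(input=...),
# which falls back to __call__ because DocOnlyEF does not override it.
query_embs = ef.embed_query(input=["a query"])
assert np.array_equal(doc_embs[0], query_embs[0])
```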
116 changes: 116 additions & 0 deletions chromadb/api/types.py
@@ -79,6 +79,8 @@ def maybe_cast_one_to_many(target: Optional[OneOrMany[T]]) -> Optional[List[T]]:
PyEmbeddings = List[PyEmbedding]
Embedding = Vector
Embeddings = List[Embedding]
SparseEmbedding = SparseVector
SparseEmbeddings = List[SparseEmbedding]

Space = Literal["cosine", "l2", "ip"]

@@ -569,6 +571,13 @@ class EmbeddingFunction(Protocol[D]):
def __call__(self, input: D) -> Embeddings:
...

def embed_query(self, input: D) -> Embeddings:
"""
Get the embeddings for a query input.
This method is optional, and if not implemented, the default behavior is to call __call__.
"""
return self.__call__(input)

def __init_subclass__(cls) -> None:
super().__init_subclass__()
# Raise an exception if __call__ is not defined since it is expected to be defined
@@ -1096,6 +1105,21 @@ def validate_embeddings(embeddings: Embeddings) -> Embeddings:
return embeddings


def validate_sparse_embeddings(embeddings: SparseEmbeddings) -> SparseEmbeddings:
"""Validates sparse embeddings to ensure it is a list of sparse vectors"""
if not isinstance(embeddings, list):
raise ValueError(
f"Expected sparse embeddings to be a list, got {type(embeddings).__name__}"
)
if len(embeddings) == 0:
raise ValueError(
f"Expected sparse embeddings to be a non-empty list, got {len(embeddings)} sparse embeddings"
)
for embedding in embeddings:
validate_sparse_vector(embedding)
return embeddings


def validate_documents(documents: Documents, nullable: bool = False) -> None:
"""Validates documents to ensure it is a list of strings"""
if not isinstance(documents, list):
@@ -1150,3 +1174,95 @@ def convert_np_embeddings_to_list(embeddings: Embeddings) -> PyEmbeddings:

def convert_list_embeddings_to_np(embeddings: PyEmbeddings) -> Embeddings:
return [np.array(embedding) for embedding in embeddings]


@runtime_checkable
class SparseEmbeddingFunction(Protocol[D]):
"""
A protocol for sparse embedding functions. To implement a new sparse embedding function,
you need to implement the following methods at minimum:
- __call__

For future compatibility, it is strongly recommended to also implement:
- __init__
- name
- build_from_config
- get_config
"""

@abstractmethod
def __call__(self, input: D) -> SparseEmbeddings:
...

def embed_query(self, input: D) -> SparseEmbeddings:
"""
Get the embeddings for a query input.
This method is optional, and if not implemented, the default behavior is to call __call__.
"""
return self.__call__(input)

def __init_subclass__(cls) -> None:
super().__init_subclass__()
# Raise an exception if __call__ is not defined since it is expected to be defined
call = getattr(cls, "__call__")

def __call__(self: SparseEmbeddingFunction[D], input: D) -> SparseEmbeddings:
result = call(self, input)
assert result is not None
return validate_sparse_embeddings(cast(SparseEmbeddings, result))

setattr(cls, "__call__", __call__)

def embed_with_retries(
self, input: D, **retry_kwargs: Dict[str, Any]
) -> SparseEmbeddings:
return cast(SparseEmbeddings, retry(**retry_kwargs)(self.__call__)(input))

@abstractmethod
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""
Initialize the embedding function.
Pass any arguments that will be needed to build the embedding function
config.
"""
...

@staticmethod
@abstractmethod
def name() -> str:
"""
Return the name of the embedding function.
"""
...

@staticmethod
@abstractmethod
def build_from_config(config: Dict[str, Any]) -> "SparseEmbeddingFunction[D]":
"""
Build the embedding function from a config, which will be used to
deserialize the embedding function.
"""
...

@abstractmethod
def get_config(self) -> Dict[str, Any]:
"""
Return the config for the embedding function, which will be used to
serialize the embedding function.
"""
...

def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
"""
Validate the update to the config.
"""
return

@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the config.
"""
return
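
A side note on the `__init_subclass__` hook that both protocols above rely on: it rebinds every subclass's `__call__` so the returned embeddings are validated transparently at call time. A self-contained toy of the pattern, using standalone types rather than chromadb's:

```python
from typing import Any, Callable, List

Vectors = List[List[float]]


def validate(result: Vectors) -> Vectors:
    # Stand-in for validate_embeddings / validate_sparse_embeddings.
    if not isinstance(result, list) or len(result) == 0:
        raise ValueError("expected a non-empty list of vectors")
    return result


class Base:
    def __init_subclass__(cls) -> None:
        super().__init_subclass__()
        call: Callable[..., Any] = getattr(cls, "__call__")

        def wrapped(self: "Base", input: Any) -> Vectors:
            result = call(self, input)
            assert result is not None
            return validate(result)

        # Every subclass's __call__ is replaced with the validating wrapper.
        setattr(cls, "__call__", wrapped)


class ToyEF(Base):
    def __call__(self, input: List[str]) -> Vectors:
        return [[1.0, 2.0] for _ in input]


assert ToyEF()(["hello"]) == [[1.0, 2.0]]
# A subclass whose __call__ returned [] would raise ValueError at call time.
```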
123 changes: 123 additions & 0 deletions chromadb/test/configurations/test_collection_configuration.py
@@ -25,6 +25,7 @@
from chromadb.test.conftest import ClientFactories
from chromadb.test.conftest import is_spann_disabled_mode, skip_reason_spann_disabled
from chromadb.types import Collection as CollectionModel
from typing import Optional, TypedDict


class LegacyEmbeddingFunction(EmbeddingFunction[Embeddable]):
@@ -1616,3 +1617,125 @@ def test_default_space_custom_embedding_function_with_metadata_and_config(
spann_config = coll.configuration.get("spann")
assert spann_config is not None
assert spann_config.get("space") == "ip"


class CustomEmbeddingFunctionQueryConfig(TypedDict):
task: str


@register_embedding_function
class CustomEmbeddingFunctionWithQueryConfig(EmbeddingFunction[Embeddable]):
def __init__(
self,
task: str,
model_name: str,
dim: int = 3,
query_config: Optional[CustomEmbeddingFunctionQueryConfig] = None,
):
self._dim = dim
self._model_name = model_name
self._task = task
self._query_config = query_config

def __call__(self, input: Embeddable) -> Embeddings:
return cast(Embeddings, np.array([[1.0] * self._dim], dtype=np.float32))

def embed_query(self, input: Embeddable) -> Embeddings:
if self._query_config is not None and self._query_config.get("task") == "query":
return cast(Embeddings, np.array([[2.0] * self._dim], dtype=np.float32))
else:
return self.__call__(input)

@staticmethod
def name() -> str:
return "custom_ef_with_query_config"

def get_config(self) -> Dict[str, Any]:
return {
"model_name": self._model_name,
"dim": self._dim,
"task": self._task,
"query_config": self._query_config,
}

@staticmethod
def build_from_config(
config: Dict[str, Any]
) -> "CustomEmbeddingFunctionWithQueryConfig":
model_name = config.get("model_name")
dim = config.get("dim")
task = config.get("task")
query_config = config.get("query_config")

if model_name is None or dim is None:
assert False, "This code should not be reached"

return CustomEmbeddingFunctionWithQueryConfig(
model_name=model_name, dim=dim, task=task, query_config=query_config # type: ignore
)

def default_space(self) -> Space:
return "cosine"

def supported_spaces(self) -> List[Space]:
return ["cosine"]


def test_custom_embedding_function_with_query_config(client: ClientAPI) -> None:
client.reset()
coll = client.create_collection(
name="test_custom_embedding_function_with_query_config",
embedding_function=CustomEmbeddingFunctionWithQueryConfig(
task="document",
model_name="i_want_anything",
dim=3,
query_config={"task": "query"},
),
)
assert coll is not None
ef = coll.configuration.get("embedding_function")
assert ef is not None
assert ef.name() == "custom_ef_with_query_config"
assert ef.get_config() == {
"model_name": "i_want_anything",
"dim": 3,
"task": "document",
"query_config": {"task": "query"},
}
assert ef.default_space() == "cosine"
assert ef.supported_spaces() == ["cosine"]
assert np.array_equal(
ef.embed_query(input="How many people in Berlin?"),
np.array([[2.0, 2.0, 2.0]], dtype=np.float32),
)


def test_deserializing_custom_embedding_function_with_query_config_no_query_config(
client: ClientAPI,
) -> None:
json_string = """
{
"embedding_function": {
"type": "known",
"name": "custom_ef_with_query_config",
"config": {"model_name": "i_want_anything", "dim": 3, "task": "document"}
}
}
"""
config = load_collection_configuration_from_json(json.loads(json_string))
assert config is not None
assert config.get("embedding_function") is not None
ef = config.get("embedding_function")
assert ef is not None
assert ef.get_config() == {
"model_name": "i_want_anything",
"dim": 3,
"task": "document",
"query_config": None,
}
assert ef.default_space() == "cosine"
assert ef.supported_spaces() == ["cosine"]
assert np.array_equal(
ef.embed_query(input="How many people in Berlin?"),
np.array([[1.0, 1.0, 1.0]], dtype=np.float32),
)
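
Finally, a hedged end-to-end sketch of how the new flag changes behavior at the collection level: `add` embeds documents through `__call__`, while `query` now flows through `_embed_record_set(..., is_query=True)` and therefore `embed_query`. The client setup, collection name, and document text below are illustrative; the embedding function is the test class defined above.

```python
import chromadb

# Reuses CustomEmbeddingFunctionWithQueryConfig from the test module above.
client = chromadb.Client()  # ephemeral client for illustration; any ClientAPI works
ef = CustomEmbeddingFunctionWithQueryConfig(
    task="document",
    model_name="i_want_anything",
    dim=3,
    query_config={"task": "query"},
)
coll = client.create_collection(name="query_config_demo", embedding_function=ef)

# Write path: documents are embedded via __call__ -> [[1.0, 1.0, 1.0]].
coll.add(ids=["doc-1"], documents=["Berlin has about 3.7 million residents."])

# Read path: query texts are embedded via embed_query -> [[2.0, 2.0, 2.0]].
res = coll.query(query_texts=["How many people in Berlin?"], n_results=1)
print(res["ids"])  # [["doc-1"]]
```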