diff --git a/integrations/chroma/pyproject.toml b/integrations/chroma/pyproject.toml index 2bffabfd8..27b204432 100644 --- a/integrations/chroma/pyproject.toml +++ b/integrations/chroma/pyproject.toml @@ -131,6 +131,8 @@ ignore = [ "PLR0915", # Ignore unused params "ARG002", + # Allow assertions + "S101", ] unfixable = [ # Don't touch unused imports diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py index 359ace58d..990aa4c34 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py @@ -40,9 +40,8 @@ def __init__( **embedding_function_params, ): """ - Initializes the store. The __init__ constructor is not part of the Store Protocol - and the signature can be customized to your needs. For example, parameters needed - to set up a database client would be passed to this method. + Creates a new ChromaDocumentStore instance. + It is meant to be connected to a Chroma collection. Note: for the component to be part of a serializable pipeline, the __init__ parameters must be serializable, reason why we use a registry to configure the @@ -65,7 +64,6 @@ def __init__( :param metadata: a dictionary of chromadb collection parameters passed directly to chromadb's client method `create_collection`. If it contains the key `"hnsw:space"`, the value will take precedence over the `distance_function` parameter above. - :param embedding_function_params: additional parameters to pass to the embedding function. """ @@ -79,53 +77,61 @@ def __init__( # Store the params for marshalling self._collection_name = collection_name self._embedding_function = embedding_function + self._embedding_func = get_embedding_function(embedding_function, **embedding_function_params) self._embedding_function_params = embedding_function_params self._distance_function = distance_function + self._metadata = metadata + self._collection = None self._persist_path = persist_path self._host = host self._port = port - # Create the client instance - if persist_path and (host or port is not None): - error_message = ( - "You must specify `persist_path` for local persistent storage or, " - "alternatively, `host` and `port` for remote HTTP client connection. " - "You cannot specify both options." - ) - raise ValueError(error_message) - if host and port is not None: - # Remote connection via HTTP client - self._chroma_client = chromadb.HttpClient( - host=host, - port=port, - ) - elif persist_path is None: - # In-memory storage - self._chroma_client = chromadb.Client() - else: - # Local persistent storage - self._chroma_client = chromadb.PersistentClient(path=persist_path) + self._initialized = False - embedding_func = get_embedding_function(embedding_function, **embedding_function_params) + def _ensure_initialized(self): + if not self._initialized: + # Create the client instance + if self._persist_path and (self._host or self._port is not None): + error_message = ( + "You must specify `persist_path` for local persistent storage or, " + "alternatively, `host` and `port` for remote HTTP client connection. " + "You cannot specify both options." + ) + raise ValueError(error_message) + if self._host and self._port is not None: + # Remote connection via HTTP client + client = chromadb.HttpClient( + host=self._host, + port=self._port, + ) + elif self._persist_path is None: + # In-memory storage + client = chromadb.Client() + else: + # Local persistent storage + client = chromadb.PersistentClient(path=self._persist_path) - metadata = metadata or {} - if "hnsw:space" not in metadata: - metadata["hnsw:space"] = distance_function + self._metadata = self._metadata or {} + if "hnsw:space" not in self._metadata: + self._metadata["hnsw:space"] = self._distance_function - if collection_name in [c.name for c in self._chroma_client.list_collections()]: - self._collection = self._chroma_client.get_collection(collection_name, embedding_function=embedding_func) + if self._collection_name in [c.name for c in client.list_collections()]: + self._collection = client.get_collection(self._collection_name, embedding_function=self._embedding_func) - if metadata != self._collection.metadata: - logger.warning( - "Collection already exists. The `distance_function` and `metadata` parameters will be ignored." + if self._metadata != self._collection.metadata: + logger.warning( + "Collection already exists. " + "The `distance_function` and `metadata` parameters will be ignored." + ) + else: + self._collection = client.create_collection( + name=self._collection_name, + metadata=self._metadata, + embedding_function=self._embedding_func, ) - else: - self._collection = self._chroma_client.create_collection( - name=collection_name, - metadata=metadata, - embedding_function=embedding_func, - ) + + self._initialized = True def count_documents(self) -> int: """ @@ -133,6 +139,8 @@ def count_documents(self) -> int: :returns: how many documents are present in the document store. """ + self._ensure_initialized() + assert self._collection is not None return self._collection.count() def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: @@ -197,6 +205,9 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc :param filters: the filters to apply to the document list. :returns: a list of Documents that match the given filters. """ + self._ensure_initialized() + assert self._collection is not None + if filters: chroma_filter = _convert_filters(filters) kwargs: Dict[str, Any] = {"where": chroma_filter.where} @@ -227,6 +238,9 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D :returns: The number of documents written """ + self._ensure_initialized() + assert self._collection is not None + for doc in documents: if not isinstance(doc, Document): msg = "param 'documents' must contain a list of objects of type Document" @@ -280,8 +294,11 @@ def delete_documents(self, document_ids: List[str]) -> None: """ Deletes all documents with a matching document_ids from the document store. - :param document_ids: the object_ids to delete + :param document_ids: the document ids to delete """ + self._ensure_initialized() + assert self._collection is not None + self._collection.delete(ids=document_ids) def search(self, queries: List[str], top_k: int, filters: Optional[Dict[str, Any]] = None) -> List[List[Document]]: @@ -292,6 +309,9 @@ def search(self, queries: List[str], top_k: int, filters: Optional[Dict[str, Any :param filters: a dictionary of filters to apply to the search. Accepts filters in haystack format. :returns: matching documents for each query. """ + self._ensure_initialized() + assert self._collection is not None + if filters is None: results = self._collection.query( query_texts=queries, @@ -323,6 +343,9 @@ def search_embeddings( :returns: a list of lists of documents that match the given filters. """ + self._ensure_initialized() + assert self._collection is not None + if filters is None: results = self._collection.query( query_embeddings=query_embeddings, diff --git a/integrations/chroma/tests/test_document_store.py b/integrations/chroma/tests/test_document_store.py index d33086945..41491dc4d 100644 --- a/integrations/chroma/tests/test_document_store.py +++ b/integrations/chroma/tests/test_document_store.py @@ -98,7 +98,8 @@ def test_invalid_initialization_both_host_and_persist_path(self): Test that providing both host and persist_path raises an error. """ with pytest.raises(ValueError): - ChromaDocumentStore(persist_path="./path/to/local/store", host="localhost") + store = ChromaDocumentStore(persist_path="./path/to/local/store", host="localhost") + store._ensure_initialized() def test_delete_empty(self, document_store: ChromaDocumentStore): """ @@ -207,6 +208,7 @@ def test_same_collection_name_reinitialization(self): @pytest.mark.integration def test_distance_metric_initialization(self): store = ChromaDocumentStore("test_2", distance_function="cosine") + store._ensure_initialized() assert store._collection.metadata["hnsw:space"] == "cosine" with pytest.raises(ValueError): @@ -215,9 +217,11 @@ def test_distance_metric_initialization(self): @pytest.mark.integration def test_distance_metric_reinitialization(self, caplog): store = ChromaDocumentStore("test_4", distance_function="cosine") + store._ensure_initialized() with caplog.at_level(logging.WARNING): new_store = ChromaDocumentStore("test_4", distance_function="ip") + new_store._ensure_initialized() assert ( "Collection already exists. The `distance_function` and `metadata` parameters will be ignored." @@ -238,6 +242,8 @@ def test_metadata_initialization(self, caplog): "hnsw:M": 103, }, ) + store._ensure_initialized() + assert store._collection.metadata["hnsw:space"] == "ip" assert store._collection.metadata["hnsw:search_ef"] == 101 assert store._collection.metadata["hnsw:construction_ef"] == 102 @@ -254,6 +260,8 @@ def test_metadata_initialization(self, caplog): }, ) + new_store._ensure_initialized() + assert ( "Collection already exists. The `distance_function` and `metadata` parameters will be ignored." in caplog.text