From 1bdb3943c4ca5870a9840cd4a53a08246e7e295c Mon Sep 17 00:00:00 2001 From: alperkaya Date: Wed, 25 Sep 2024 21:19:17 +0200 Subject: [PATCH 1/5] defer DB from chroma --- .../document_stores/chroma/document_store.py | 111 ++++++++++-------- .../chroma/tests/test_document_store.py | 31 +++-- 2 files changed, 83 insertions(+), 59 deletions(-) diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py index 359ace58d..a94c453a6 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py @@ -79,53 +79,68 @@ def __init__( # Store the params for marshalling self._collection_name = collection_name self._embedding_function = embedding_function + self._embedding_func = get_embedding_function(embedding_function, **embedding_function_params) self._embedding_function_params = embedding_function_params self._distance_function = distance_function + self._metadata = metadata + self._collection = None self._persist_path = persist_path self._host = host self._port = port + self._chroma_client = None + + @property + def chroma_client(self): + if self._chroma_client is None: + # Create the client instance + if self._persist_path and (self._host or self._port is not None): + error_message = ( + "You must specify `persist_path` for local persistent storage or, " + "alternatively, `host` and `port` for remote HTTP client connection. " + "You cannot specify both options." + ) + raise ValueError(error_message) + if self._host and self._port is not None: + # Remote connection via HTTP client + self._chroma_client = chromadb.HttpClient( + host=self._host, + port=self._port, + ) + elif self._persist_path is None: + # In-memory storage + self._chroma_client = chromadb.Client() + else: + # Local persistent storage + self._chroma_client = chromadb.PersistentClient(path=self._persist_path) + + return self._chroma_client + + @property + def collection(self): + if self._collection is None: + + self._metadata = self._metadata or {} + if "hnsw:space" not in self._metadata: + self._metadata["hnsw:space"] = self._distance_function + + if self._collection_name in [c.name for c in self.chroma_client.list_collections()]: + self._collection = self.chroma_client.get_collection( + self._collection_name, embedding_function=self._embedding_func + ) - # Create the client instance - if persist_path and (host or port is not None): - error_message = ( - "You must specify `persist_path` for local persistent storage or, " - "alternatively, `host` and `port` for remote HTTP client connection. " - "You cannot specify both options." - ) - raise ValueError(error_message) - if host and port is not None: - # Remote connection via HTTP client - self._chroma_client = chromadb.HttpClient( - host=host, - port=port, - ) - elif persist_path is None: - # In-memory storage - self._chroma_client = chromadb.Client() - else: - # Local persistent storage - self._chroma_client = chromadb.PersistentClient(path=persist_path) - - embedding_func = get_embedding_function(embedding_function, **embedding_function_params) - - metadata = metadata or {} - if "hnsw:space" not in metadata: - metadata["hnsw:space"] = distance_function - - if collection_name in [c.name for c in self._chroma_client.list_collections()]: - self._collection = self._chroma_client.get_collection(collection_name, embedding_function=embedding_func) - - if metadata != self._collection.metadata: - logger.warning( - "Collection already exists. The `distance_function` and `metadata` parameters will be ignored." + if self._metadata != self._collection.metadata: + logger.warning( + "Collection already exists. The `distance_function` and `metadata` parameters will be ignored." + ) + else: + self._collection = self.chroma_client.create_collection( + name=self._collection_name, + metadata=self._metadata, + embedding_function=self._embedding_func, ) - else: - self._collection = self._chroma_client.create_collection( - name=collection_name, - metadata=metadata, - embedding_function=embedding_func, - ) + + return self._collection def count_documents(self) -> int: """ @@ -133,7 +148,7 @@ def count_documents(self) -> int: :returns: how many documents are present in the document store. """ - return self._collection.count() + return self.collection.count() def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: """ @@ -206,9 +221,9 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc if chroma_filter.where_document: kwargs["where_document"] = chroma_filter.where_document - result = self._collection.get(**kwargs) + result = self.collection.get(**kwargs) else: - result = self._collection.get() + result = self.collection.get() return self._get_result_to_documents(result) @@ -272,7 +287,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D doc.id, ) - self._collection.add(**data) + self.collection.add(**data) return len(documents) @@ -282,7 +297,7 @@ def delete_documents(self, document_ids: List[str]) -> None: :param document_ids: the object_ids to delete """ - self._collection.delete(ids=document_ids) + self.collection.delete(ids=document_ids) def search(self, queries: List[str], top_k: int, filters: Optional[Dict[str, Any]] = None) -> List[List[Document]]: """Search the documents in the store using the provided text queries. @@ -293,14 +308,14 @@ def search(self, queries: List[str], top_k: int, filters: Optional[Dict[str, Any :returns: matching documents for each query. """ if filters is None: - results = self._collection.query( + results = self.collection.query( query_texts=queries, n_results=top_k, include=["embeddings", "documents", "metadatas", "distances"], ) else: chroma_filters = _convert_filters(filters=filters) - results = self._collection.query( + results = self.collection.query( query_texts=queries, n_results=top_k, where=chroma_filters.where, @@ -324,14 +339,14 @@ def search_embeddings( """ if filters is None: - results = self._collection.query( + results = self.collection.query( query_embeddings=query_embeddings, n_results=top_k, include=["embeddings", "documents", "metadatas", "distances"], ) else: chroma_filters = _convert_filters(filters=filters) - results = self._collection.query( + results = self.collection.query( query_embeddings=query_embeddings, n_results=top_k, where=chroma_filters.where, diff --git a/integrations/chroma/tests/test_document_store.py b/integrations/chroma/tests/test_document_store.py index 3a6952ff8..403f35641 100644 --- a/integrations/chroma/tests/test_document_store.py +++ b/integrations/chroma/tests/test_document_store.py @@ -3,10 +3,10 @@ # SPDX-License-Identifier: Apache-2.0 import logging import operator +import sys import uuid from typing import List from unittest import mock -import sys import numpy as np import pytest @@ -98,7 +98,7 @@ def test_invalid_initialization_both_host_and_persist_path(self): Test that providing both host and persist_path raises an error. """ with pytest.raises(ValueError): - ChromaDocumentStore(persist_path="./path/to/local/store", host="localhost") + ChromaDocumentStore(persist_path="./path/to/local/store", host="localhost").chroma_client def test_delete_empty(self, document_store: ChromaDocumentStore): """ @@ -207,7 +207,7 @@ def test_same_collection_name_reinitialization(self): @pytest.mark.integration def test_distance_metric_initialization(self): store = ChromaDocumentStore("test_2", distance_function="cosine") - assert store._collection.metadata["hnsw:space"] == "cosine" + assert store.collection.metadata["hnsw:space"] == "cosine" with pytest.raises(ValueError): ChromaDocumentStore("test_3", distance_function="jaccard") @@ -216,15 +216,20 @@ def test_distance_metric_initialization(self): def test_distance_metric_reinitialization(self, caplog): store = ChromaDocumentStore("test_4", distance_function="cosine") + # Access the collection to trigger the creation and set the distance function + _ = store.collection + with caplog.at_level(logging.WARNING): new_store = ChromaDocumentStore("test_4", distance_function="ip") + # Access the collection of the new_store to trigger the log + _ = new_store.collection assert ( "Collection already exists. The `distance_function` and `metadata` parameters will be ignored." in caplog.text ) - assert store._collection.metadata["hnsw:space"] == "cosine" - assert new_store._collection.metadata["hnsw:space"] == "cosine" + assert store.collection.metadata["hnsw:space"] == "cosine" + assert new_store.collection.metadata["hnsw:space"] == "cosine" @pytest.mark.integration def test_metadata_initialization(self, caplog): @@ -238,10 +243,12 @@ def test_metadata_initialization(self, caplog): "hnsw:M": 103, }, ) - assert store._collection.metadata["hnsw:space"] == "ip" - assert store._collection.metadata["hnsw:search_ef"] == 101 - assert store._collection.metadata["hnsw:construction_ef"] == 102 - assert store._collection.metadata["hnsw:M"] == 103 + _ = store.collection + + assert store.collection.metadata["hnsw:space"] == "ip" + assert store.collection.metadata["hnsw:search_ef"] == 101 + assert store.collection.metadata["hnsw:construction_ef"] == 102 + assert store.collection.metadata["hnsw:M"] == 103 with caplog.at_level(logging.WARNING): new_store = ChromaDocumentStore( @@ -254,12 +261,14 @@ def test_metadata_initialization(self, caplog): }, ) + _ = new_store.collection + assert ( "Collection already exists. The `distance_function` and `metadata` parameters will be ignored." in caplog.text ) - assert store._collection.metadata["hnsw:space"] == "ip" - assert new_store._collection.metadata["hnsw:space"] == "ip" + assert store.collection.metadata["hnsw:space"] == "ip" + assert new_store.collection.metadata["hnsw:space"] == "ip" def test_contains(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) From e884f2b47b89b204ea3e64256cf1050aa3cfe456 Mon Sep 17 00:00:00 2001 From: alperkaya Date: Thu, 26 Sep 2024 21:20:48 +0200 Subject: [PATCH 2/5] added ensure_initialized --- .../document_stores/chroma/document_store.py | 63 ++++++++++++++----- .../chroma/tests/test_document_store.py | 33 +++++----- 2 files changed, 65 insertions(+), 31 deletions(-) diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py index a94c453a6..bbf133adf 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py @@ -90,7 +90,14 @@ def __init__( self._port = port self._chroma_client = None - @property + self._initialized = False + + def _ensure_initialized(self): + if not self._initialized: + self._chroma_client = self.chroma_client() + self._collection = self.collection() + self._initialized = True + def chroma_client(self): if self._chroma_client is None: # Create the client instance @@ -116,7 +123,6 @@ def chroma_client(self): return self._chroma_client - @property def collection(self): if self._collection is None: @@ -124,8 +130,8 @@ def collection(self): if "hnsw:space" not in self._metadata: self._metadata["hnsw:space"] = self._distance_function - if self._collection_name in [c.name for c in self.chroma_client.list_collections()]: - self._collection = self.chroma_client.get_collection( + if self._collection_name in [c.name for c in self._chroma_client.list_collections()]: + self._collection = self._chroma_client.get_collection( self._collection_name, embedding_function=self._embedding_func ) @@ -134,7 +140,7 @@ def collection(self): "Collection already exists. The `distance_function` and `metadata` parameters will be ignored." ) else: - self._collection = self.chroma_client.create_collection( + self._collection = self._chroma_client.create_collection( name=self._collection_name, metadata=self._metadata, embedding_function=self._embedding_func, @@ -148,7 +154,11 @@ def count_documents(self) -> int: :returns: how many documents are present in the document store. """ - return self.collection.count() + self._ensure_initialized() + if self._collection is None: + msg = "Collection is not initialized" + raise ValueError(msg) + return self._collection.count() def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: """ @@ -212,6 +222,11 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc :param filters: the filters to apply to the document list. :returns: a list of Documents that match the given filters. """ + self._ensure_initialized() + if self._collection is None: + msg = "Collection is not initialized" + raise ValueError(msg) + if filters: chroma_filter = _convert_filters(filters) kwargs: Dict[str, Any] = {"where": chroma_filter.where} @@ -221,9 +236,9 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc if chroma_filter.where_document: kwargs["where_document"] = chroma_filter.where_document - result = self.collection.get(**kwargs) + result = self._collection.get(**kwargs) else: - result = self.collection.get() + result = self._collection.get() return self._get_result_to_documents(result) @@ -242,6 +257,11 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D :returns: The number of documents written """ + self._ensure_initialized() + if self._collection is None: + msg = "Collection is not initialized" + raise ValueError(msg) + for doc in documents: if not isinstance(doc, Document): msg = "param 'documents' must contain a list of objects of type Document" @@ -287,7 +307,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D doc.id, ) - self.collection.add(**data) + self._collection.add(**data) return len(documents) @@ -297,7 +317,12 @@ def delete_documents(self, document_ids: List[str]) -> None: :param document_ids: the object_ids to delete """ - self.collection.delete(ids=document_ids) + self._ensure_initialized() + if self._collection is None: + msg = "Collection is not initialized" + raise ValueError(msg) + + self._collection.delete(ids=document_ids) def search(self, queries: List[str], top_k: int, filters: Optional[Dict[str, Any]] = None) -> List[List[Document]]: """Search the documents in the store using the provided text queries. @@ -307,15 +332,20 @@ def search(self, queries: List[str], top_k: int, filters: Optional[Dict[str, Any :param filters: a dictionary of filters to apply to the search. Accepts filters in haystack format. :returns: matching documents for each query. """ + self._ensure_initialized() + if self._collection is None: + msg = "Collection is not initialized" + raise ValueError(msg) + if filters is None: - results = self.collection.query( + results = self._collection.query( query_texts=queries, n_results=top_k, include=["embeddings", "documents", "metadatas", "distances"], ) else: chroma_filters = _convert_filters(filters=filters) - results = self.collection.query( + results = self._collection.query( query_texts=queries, n_results=top_k, where=chroma_filters.where, @@ -338,15 +368,20 @@ def search_embeddings( :returns: a list of lists of documents that match the given filters. """ + self._ensure_initialized() + if self._collection is None: + msg = "Collection is not initialized" + raise ValueError(msg) + if filters is None: - results = self.collection.query( + results = self._collection.query( query_embeddings=query_embeddings, n_results=top_k, include=["embeddings", "documents", "metadatas", "distances"], ) else: chroma_filters = _convert_filters(filters=filters) - results = self.collection.query( + results = self._collection.query( query_embeddings=query_embeddings, n_results=top_k, where=chroma_filters.where, diff --git a/integrations/chroma/tests/test_document_store.py b/integrations/chroma/tests/test_document_store.py index 403f35641..41491dc4d 100644 --- a/integrations/chroma/tests/test_document_store.py +++ b/integrations/chroma/tests/test_document_store.py @@ -98,7 +98,8 @@ def test_invalid_initialization_both_host_and_persist_path(self): Test that providing both host and persist_path raises an error. """ with pytest.raises(ValueError): - ChromaDocumentStore(persist_path="./path/to/local/store", host="localhost").chroma_client + store = ChromaDocumentStore(persist_path="./path/to/local/store", host="localhost") + store._ensure_initialized() def test_delete_empty(self, document_store: ChromaDocumentStore): """ @@ -207,7 +208,8 @@ def test_same_collection_name_reinitialization(self): @pytest.mark.integration def test_distance_metric_initialization(self): store = ChromaDocumentStore("test_2", distance_function="cosine") - assert store.collection.metadata["hnsw:space"] == "cosine" + store._ensure_initialized() + assert store._collection.metadata["hnsw:space"] == "cosine" with pytest.raises(ValueError): ChromaDocumentStore("test_3", distance_function="jaccard") @@ -215,21 +217,18 @@ def test_distance_metric_initialization(self): @pytest.mark.integration def test_distance_metric_reinitialization(self, caplog): store = ChromaDocumentStore("test_4", distance_function="cosine") - - # Access the collection to trigger the creation and set the distance function - _ = store.collection + store._ensure_initialized() with caplog.at_level(logging.WARNING): new_store = ChromaDocumentStore("test_4", distance_function="ip") - # Access the collection of the new_store to trigger the log - _ = new_store.collection + new_store._ensure_initialized() assert ( "Collection already exists. The `distance_function` and `metadata` parameters will be ignored." in caplog.text ) - assert store.collection.metadata["hnsw:space"] == "cosine" - assert new_store.collection.metadata["hnsw:space"] == "cosine" + assert store._collection.metadata["hnsw:space"] == "cosine" + assert new_store._collection.metadata["hnsw:space"] == "cosine" @pytest.mark.integration def test_metadata_initialization(self, caplog): @@ -243,12 +242,12 @@ def test_metadata_initialization(self, caplog): "hnsw:M": 103, }, ) - _ = store.collection + store._ensure_initialized() - assert store.collection.metadata["hnsw:space"] == "ip" - assert store.collection.metadata["hnsw:search_ef"] == 101 - assert store.collection.metadata["hnsw:construction_ef"] == 102 - assert store.collection.metadata["hnsw:M"] == 103 + assert store._collection.metadata["hnsw:space"] == "ip" + assert store._collection.metadata["hnsw:search_ef"] == 101 + assert store._collection.metadata["hnsw:construction_ef"] == 102 + assert store._collection.metadata["hnsw:M"] == 103 with caplog.at_level(logging.WARNING): new_store = ChromaDocumentStore( @@ -261,14 +260,14 @@ def test_metadata_initialization(self, caplog): }, ) - _ = new_store.collection + new_store._ensure_initialized() assert ( "Collection already exists. The `distance_function` and `metadata` parameters will be ignored." in caplog.text ) - assert store.collection.metadata["hnsw:space"] == "ip" - assert new_store.collection.metadata["hnsw:space"] == "ip" + assert store._collection.metadata["hnsw:space"] == "ip" + assert new_store._collection.metadata["hnsw:space"] == "ip" def test_contains(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) From 3076bfc62d83e5ade25ffbd43e20f658de38202b Mon Sep 17 00:00:00 2001 From: alperkaya Date: Fri, 27 Sep 2024 15:44:35 +0200 Subject: [PATCH 3/5] addressed comments --- .../document_stores/chroma/document_store.py | 116 ++++++++---------- 1 file changed, 48 insertions(+), 68 deletions(-) diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py index bbf133adf..9fa3b6513 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py @@ -94,59 +94,51 @@ def __init__( def _ensure_initialized(self): if not self._initialized: - self._chroma_client = self.chroma_client() - self._collection = self.collection() - self._initialized = True - - def chroma_client(self): - if self._chroma_client is None: - # Create the client instance - if self._persist_path and (self._host or self._port is not None): - error_message = ( - "You must specify `persist_path` for local persistent storage or, " - "alternatively, `host` and `port` for remote HTTP client connection. " - "You cannot specify both options." - ) - raise ValueError(error_message) - if self._host and self._port is not None: - # Remote connection via HTTP client - self._chroma_client = chromadb.HttpClient( - host=self._host, - port=self._port, - ) - elif self._persist_path is None: - # In-memory storage - self._chroma_client = chromadb.Client() - else: - # Local persistent storage - self._chroma_client = chromadb.PersistentClient(path=self._persist_path) - - return self._chroma_client - - def collection(self): - if self._collection is None: - - self._metadata = self._metadata or {} - if "hnsw:space" not in self._metadata: - self._metadata["hnsw:space"] = self._distance_function - - if self._collection_name in [c.name for c in self._chroma_client.list_collections()]: - self._collection = self._chroma_client.get_collection( - self._collection_name, embedding_function=self._embedding_func - ) + if self._chroma_client is None: + # Create the client instance + if self._persist_path and (self._host or self._port is not None): + error_message = ( + "You must specify `persist_path` for local persistent storage or, " + "alternatively, `host` and `port` for remote HTTP client connection. " + "You cannot specify both options." + ) + raise ValueError(error_message) + if self._host and self._port is not None: + # Remote connection via HTTP client + self._chroma_client = chromadb.HttpClient( + host=self._host, + port=self._port, + ) + elif self._persist_path is None: + # In-memory storage + self._chroma_client = chromadb.Client() + else: + # Local persistent storage + self._chroma_client = chromadb.PersistentClient(path=self._persist_path) + + if self._collection is None: + self._metadata = self._metadata or {} + if "hnsw:space" not in self._metadata: + self._metadata["hnsw:space"] = self._distance_function + + if self._collection_name in [c.name for c in self._chroma_client.list_collections()]: + self._collection = self._chroma_client.get_collection( + self._collection_name, embedding_function=self._embedding_func + ) - if self._metadata != self._collection.metadata: - logger.warning( - "Collection already exists. The `distance_function` and `metadata` parameters will be ignored." + if self._metadata != self._collection.metadata: + logger.warning( + "Collection already exists. " + "The `distance_function` and `metadata` parameters will be ignored." + ) + else: + self._collection = self._chroma_client.create_collection( + name=self._collection_name, + metadata=self._metadata, + embedding_function=self._embedding_func, ) - else: - self._collection = self._chroma_client.create_collection( - name=self._collection_name, - metadata=self._metadata, - embedding_function=self._embedding_func, - ) - return self._collection + self._initialized = True def count_documents(self) -> int: """ @@ -155,9 +147,7 @@ def count_documents(self) -> int: :returns: how many documents are present in the document store. """ self._ensure_initialized() - if self._collection is None: - msg = "Collection is not initialized" - raise ValueError(msg) + assert self._collection is not None return self._collection.count() def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: @@ -223,9 +213,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc :returns: a list of Documents that match the given filters. """ self._ensure_initialized() - if self._collection is None: - msg = "Collection is not initialized" - raise ValueError(msg) + assert self._collection is not None if filters: chroma_filter = _convert_filters(filters) @@ -258,9 +246,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D The number of documents written """ self._ensure_initialized() - if self._collection is None: - msg = "Collection is not initialized" - raise ValueError(msg) + assert self._collection is not None for doc in documents: if not isinstance(doc, Document): @@ -318,9 +304,7 @@ def delete_documents(self, document_ids: List[str]) -> None: :param document_ids: the object_ids to delete """ self._ensure_initialized() - if self._collection is None: - msg = "Collection is not initialized" - raise ValueError(msg) + assert self._collection is not None self._collection.delete(ids=document_ids) @@ -333,9 +317,7 @@ def search(self, queries: List[str], top_k: int, filters: Optional[Dict[str, Any :returns: matching documents for each query. """ self._ensure_initialized() - if self._collection is None: - msg = "Collection is not initialized" - raise ValueError(msg) + assert self._collection is not None if filters is None: results = self._collection.query( @@ -369,9 +351,7 @@ def search_embeddings( """ self._ensure_initialized() - if self._collection is None: - msg = "Collection is not initialized" - raise ValueError(msg) + assert self._collection is not None if filters is None: results = self._collection.query( From 4ea60d50e14616fb44ecdbc647d609d3f676241a Mon Sep 17 00:00:00 2001 From: anakin87 Date: Mon, 30 Sep 2024 09:51:02 +0200 Subject: [PATCH 4/5] simplification and linting --- integrations/chroma/pyproject.toml | 2 + .../document_stores/chroma/document_store.py | 79 +++++++++---------- 2 files changed, 39 insertions(+), 42 deletions(-) diff --git a/integrations/chroma/pyproject.toml b/integrations/chroma/pyproject.toml index 2bffabfd8..27b204432 100644 --- a/integrations/chroma/pyproject.toml +++ b/integrations/chroma/pyproject.toml @@ -131,6 +131,8 @@ ignore = [ "PLR0915", # Ignore unused params "ARG002", + # Allow assertions + "S101", ] unfixable = [ # Don't touch unused imports diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py index 9fa3b6513..65fda21d1 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py @@ -88,55 +88,50 @@ def __init__( self._persist_path = persist_path self._host = host self._port = port - self._chroma_client = None self._initialized = False def _ensure_initialized(self): if not self._initialized: - if self._chroma_client is None: - # Create the client instance - if self._persist_path and (self._host or self._port is not None): - error_message = ( - "You must specify `persist_path` for local persistent storage or, " - "alternatively, `host` and `port` for remote HTTP client connection. " - "You cannot specify both options." - ) - raise ValueError(error_message) - if self._host and self._port is not None: - # Remote connection via HTTP client - self._chroma_client = chromadb.HttpClient( - host=self._host, - port=self._port, - ) - elif self._persist_path is None: - # In-memory storage - self._chroma_client = chromadb.Client() - else: - # Local persistent storage - self._chroma_client = chromadb.PersistentClient(path=self._persist_path) - - if self._collection is None: - self._metadata = self._metadata or {} - if "hnsw:space" not in self._metadata: - self._metadata["hnsw:space"] = self._distance_function - - if self._collection_name in [c.name for c in self._chroma_client.list_collections()]: - self._collection = self._chroma_client.get_collection( - self._collection_name, embedding_function=self._embedding_func - ) + # Create the client instance + if self._persist_path and (self._host or self._port is not None): + error_message = ( + "You must specify `persist_path` for local persistent storage or, " + "alternatively, `host` and `port` for remote HTTP client connection. " + "You cannot specify both options." + ) + raise ValueError(error_message) + if self._host and self._port is not None: + # Remote connection via HTTP client + client = chromadb.HttpClient( + host=self._host, + port=self._port, + ) + elif self._persist_path is None: + # In-memory storage + client = chromadb.Client() + else: + # Local persistent storage + client = chromadb.PersistentClient(path=self._persist_path) - if self._metadata != self._collection.metadata: - logger.warning( - "Collection already exists. " - "The `distance_function` and `metadata` parameters will be ignored." - ) - else: - self._collection = self._chroma_client.create_collection( - name=self._collection_name, - metadata=self._metadata, - embedding_function=self._embedding_func, + self._metadata = self._metadata or {} + if "hnsw:space" not in self._metadata: + self._metadata["hnsw:space"] = self._distance_function + + if self._collection_name in [c.name for c in client.list_collections()]: + self._collection = client.get_collection(self._collection_name, embedding_function=self._embedding_func) + + if self._metadata != self._collection.metadata: + logger.warning( + "Collection already exists. " + "The `distance_function` and `metadata` parameters will be ignored." ) + else: + self._collection = client.create_collection( + name=self._collection_name, + metadata=self._metadata, + embedding_function=self._embedding_func, + ) self._initialized = True From a9645351d5a0e4235853395e8064b7956d6e233c Mon Sep 17 00:00:00 2001 From: anakin87 Date: Mon, 30 Sep 2024 11:04:28 +0200 Subject: [PATCH 5/5] refinements to docstrings --- .../document_stores/chroma/document_store.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py index 65fda21d1..990aa4c34 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py @@ -40,9 +40,8 @@ def __init__( **embedding_function_params, ): """ - Initializes the store. The __init__ constructor is not part of the Store Protocol - and the signature can be customized to your needs. For example, parameters needed - to set up a database client would be passed to this method. + Creates a new ChromaDocumentStore instance. + It is meant to be connected to a Chroma collection. Note: for the component to be part of a serializable pipeline, the __init__ parameters must be serializable, reason why we use a registry to configure the @@ -65,7 +64,6 @@ def __init__( :param metadata: a dictionary of chromadb collection parameters passed directly to chromadb's client method `create_collection`. If it contains the key `"hnsw:space"`, the value will take precedence over the `distance_function` parameter above. - :param embedding_function_params: additional parameters to pass to the embedding function. """ @@ -296,7 +294,7 @@ def delete_documents(self, document_ids: List[str]) -> None: """ Deletes all documents with a matching document_ids from the document store. - :param document_ids: the object_ids to delete + :param document_ids: the document ids to delete """ self._ensure_initialized() assert self._collection is not None