Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Chroma - defer the DB connection #1107

Merged
merged 6 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions integrations/chroma/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ ignore = [
"PLR0915",
# Ignore unused params
"ARG002",
# Allow assertions
"S101",
]
unfixable = [
# Don't touch unused imports
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ def __init__(
**embedding_function_params,
):
"""
Initializes the store. The __init__ constructor is not part of the Store Protocol
and the signature can be customized to your needs. For example, parameters needed
to set up a database client would be passed to this method.
Creates a new ChromaDocumentStore instance.
It is meant to be connected to a Chroma collection.

Note: for the component to be part of a serializable pipeline, the __init__
parameters must be serializable, reason why we use a registry to configure the
Expand All @@ -65,7 +64,6 @@ def __init__(
:param metadata: a dictionary of chromadb collection parameters passed directly to chromadb's client
method `create_collection`. If it contains the key `"hnsw:space"`, the value will take precedence over the
`distance_function` parameter above.

:param embedding_function_params: additional parameters to pass to the embedding function.
"""

Expand All @@ -79,60 +77,70 @@ def __init__(
# Store the params for marshalling
self._collection_name = collection_name
self._embedding_function = embedding_function
self._embedding_func = get_embedding_function(embedding_function, **embedding_function_params)
self._embedding_function_params = embedding_function_params
self._distance_function = distance_function
self._metadata = metadata
self._collection = None

self._persist_path = persist_path
self._host = host
self._port = port

# Create the client instance
if persist_path and (host or port is not None):
error_message = (
"You must specify `persist_path` for local persistent storage or, "
"alternatively, `host` and `port` for remote HTTP client connection. "
"You cannot specify both options."
)
raise ValueError(error_message)
if host and port is not None:
# Remote connection via HTTP client
self._chroma_client = chromadb.HttpClient(
host=host,
port=port,
)
elif persist_path is None:
# In-memory storage
self._chroma_client = chromadb.Client()
else:
# Local persistent storage
self._chroma_client = chromadb.PersistentClient(path=persist_path)
self._initialized = False

embedding_func = get_embedding_function(embedding_function, **embedding_function_params)
def _ensure_initialized(self):
"""
Sets up the Chroma client and the collection the first time it is called.

While `self._initialized` is false, this picks a client from the stored connection
settings (HTTP client when `host` and `port` are set, in-memory client when no
`persist_path` is set, persistent client otherwise). If `"hnsw:space"` is missing from
the stored metadata it is filled in from the stored distance function. The collection
named `self._collection_name` is fetched when the client already lists it, otherwise it
is created. A warning is logged when an existing collection's metadata differs from the
stored metadata. Finally `self._initialized` is set to True, so the guard at the top
skips the setup on later calls.

:raises ValueError: if `persist_path` is set together with `host` or `port`.
"""
# NOTE(review): this text comes from a rendered diff whose +/- markers and indentation were
# stripped. The statements below that use bare `metadata`, `collection_name`,
# `embedding_func`, `distance_function` or `self._chroma_client` look like the removed
# pre-change lines of `__init__` -- confirm against the merged file before relying on them.
if not self._initialized:
# Create the client instance
if self._persist_path and (self._host or self._port is not None):
error_message = (
"You must specify `persist_path` for local persistent storage or, "
"alternatively, `host` and `port` for remote HTTP client connection. "
"You cannot specify both options."
)
raise ValueError(error_message)
if self._host and self._port is not None:
# Remote connection via HTTP client
client = chromadb.HttpClient(
host=self._host,
port=self._port,
)
elif self._persist_path is None:
# In-memory storage
client = chromadb.Client()
else:
# Local persistent storage
client = chromadb.PersistentClient(path=self._persist_path)

metadata = metadata or {}
if "hnsw:space" not in metadata:
metadata["hnsw:space"] = distance_function
self._metadata = self._metadata or {}
if "hnsw:space" not in self._metadata:
self._metadata["hnsw:space"] = self._distance_function

if collection_name in [c.name for c in self._chroma_client.list_collections()]:
self._collection = self._chroma_client.get_collection(collection_name, embedding_function=embedding_func)
if self._collection_name in [c.name for c in client.list_collections()]:
self._collection = client.get_collection(self._collection_name, embedding_function=self._embedding_func)

if metadata != self._collection.metadata:
logger.warning(
"Collection already exists. The `distance_function` and `metadata` parameters will be ignored."
if self._metadata != self._collection.metadata:
logger.warning(
"Collection already exists. "
"The `distance_function` and `metadata` parameters will be ignored."
)
else:
self._collection = client.create_collection(
name=self._collection_name,
metadata=self._metadata,
embedding_function=self._embedding_func,
)
else:
self._collection = self._chroma_client.create_collection(
name=collection_name,
metadata=metadata,
embedding_function=embedding_func,
)

self._initialized = True

def count_documents(self) -> int:
    """
    Counts the documents currently held in the document store.

    The store is initialized on demand before the collection is queried.

    :returns: the number of documents reported by the underlying collection.
    """
    self._ensure_initialized()
    collection = self._collection
    # Initialization always assigns a collection; the assert narrows the Optional type.
    assert collection is not None
    return collection.count()

def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
Expand Down Expand Up @@ -197,6 +205,9 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
:param filters: the filters to apply to the document list.
:returns: a list of Documents that match the given filters.
"""
self._ensure_initialized()
assert self._collection is not None

if filters:
chroma_filter = _convert_filters(filters)
kwargs: Dict[str, Any] = {"where": chroma_filter.where}
Expand Down Expand Up @@ -227,6 +238,9 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
:returns:
The number of documents written
"""
self._ensure_initialized()
assert self._collection is not None

for doc in documents:
if not isinstance(doc, Document):
msg = "param 'documents' must contain a list of objects of type Document"
Expand Down Expand Up @@ -280,8 +294,11 @@ def delete_documents(self, document_ids: List[str]) -> None:
"""
Deletes all documents with a matching document_ids from the document store.

:param document_ids: the object_ids to delete
:param document_ids: the document ids to delete
"""
self._ensure_initialized()
assert self._collection is not None

self._collection.delete(ids=document_ids)

def search(self, queries: List[str], top_k: int, filters: Optional[Dict[str, Any]] = None) -> List[List[Document]]:
Expand All @@ -292,6 +309,9 @@ def search(self, queries: List[str], top_k: int, filters: Optional[Dict[str, Any
:param filters: a dictionary of filters to apply to the search. Accepts filters in haystack format.
:returns: matching documents for each query.
"""
self._ensure_initialized()
assert self._collection is not None

if filters is None:
results = self._collection.query(
query_texts=queries,
Expand Down Expand Up @@ -323,6 +343,9 @@ def search_embeddings(
:returns: a list of lists of documents that match the given filters.

"""
self._ensure_initialized()
assert self._collection is not None

if filters is None:
results = self._collection.query(
query_embeddings=query_embeddings,
Expand Down
10 changes: 9 additions & 1 deletion integrations/chroma/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ def test_invalid_initialization_both_host_and_persist_path(self):
Test that providing both host and persist_path raises an error.
"""
with pytest.raises(ValueError):
ChromaDocumentStore(persist_path="./path/to/local/store", host="localhost")
store = ChromaDocumentStore(persist_path="./path/to/local/store", host="localhost")
store._ensure_initialized()

def test_delete_empty(self, document_store: ChromaDocumentStore):
"""
Expand Down Expand Up @@ -207,6 +208,7 @@ def test_same_collection_name_reinitialization(self):
@pytest.mark.integration
def test_distance_metric_initialization(self):
store = ChromaDocumentStore("test_2", distance_function="cosine")
store._ensure_initialized()
assert store._collection.metadata["hnsw:space"] == "cosine"

with pytest.raises(ValueError):
Expand All @@ -215,9 +217,11 @@ def test_distance_metric_initialization(self):
@pytest.mark.integration
def test_distance_metric_reinitialization(self, caplog):
store = ChromaDocumentStore("test_4", distance_function="cosine")
store._ensure_initialized()

with caplog.at_level(logging.WARNING):
new_store = ChromaDocumentStore("test_4", distance_function="ip")
new_store._ensure_initialized()

assert (
"Collection already exists. The `distance_function` and `metadata` parameters will be ignored."
Expand All @@ -238,6 +242,8 @@ def test_metadata_initialization(self, caplog):
"hnsw:M": 103,
},
)
store._ensure_initialized()

assert store._collection.metadata["hnsw:space"] == "ip"
assert store._collection.metadata["hnsw:search_ef"] == 101
assert store._collection.metadata["hnsw:construction_ef"] == 102
Expand All @@ -254,6 +260,8 @@ def test_metadata_initialization(self, caplog):
},
)

new_store._ensure_initialized()

assert (
"Collection already exists. The `distance_function` and `metadata` parameters will be ignored."
in caplog.text
Expand Down
Loading