diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/llama_index/vector_stores/chroma/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/llama_index/vector_stores/chroma/base.py index c562c2f5f636e..409e5294147a2 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/llama_index/vector_stores/chroma/base.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/llama_index/vector_stores/chroma/base.py @@ -291,13 +291,26 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul else: where = kwargs.pop("where", {}) - results = self._collection.query( + if not query.query_embedding: + return self._get(limit=query.similarity_top_k, where=where, **kwargs) + + return self._query( query_embeddings=query.query_embedding, n_results=query.similarity_top_k, where=where, **kwargs, ) + def _query( + self, query_embeddings: List["float"], n_results: int, where: dict, **kwargs + ) -> VectorStoreQueryResult: + results = self._collection.query( + query_embeddings=query_embeddings, + n_results=n_results, + where=where, + **kwargs, + ) + logger.debug(f"> Top {len(results['documents'])} nodes:") nodes = [] similarities = [] @@ -338,3 +351,48 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul ids.append(node_id) return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids) + + def _get(self, limit: int, where: dict, **kwargs) -> VectorStoreQueryResult: + results = self._collection.get( + limit=limit, + where=where, + **kwargs, + ) + + logger.debug(f"> Top {len(results['documents'])} nodes:") + nodes = [] + ids = [] + + if not results["ids"]: + results["ids"] = [[]] + + for node_id, text, metadata in zip( + results["ids"][0], results["documents"], results["metadatas"] + ): + try: + node = metadata_dict_to_node(metadata) + node.set_content(text) + except Exception: + # NOTE: deprecated legacy logic for backward compatibility + metadata, node_info, relationships = legacy_metadata_dict_to_node( + metadata + ) + + node = TextNode( + text=text, + id_=node_id, + metadata=metadata, + start_char_idx=node_info.get("start", None), + end_char_idx=node_info.get("end", None), + relationships=relationships, + ) + + nodes.append(node) + + logger.debug( + f"> [Node {node_id}] [Similarity score: N/A - using get()] " + f"{truncate_text(str(text), 100)}" + ) + ids.append(node_id) + + return VectorStoreQueryResult(nodes=nodes, ids=ids) diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/tests/test_chromadb.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/tests/test_chromadb.py index ccf10b02eacc3..6404452c3c795 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/tests/test_chromadb.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/tests/test_chromadb.py @@ -4,7 +4,13 @@ import pytest from llama_index.core.schema import NodeRelationship, RelatedNodeInfo, TextNode from llama_index.vector_stores.chroma import ChromaVectorStore -from llama_index.core.vector_stores.types import VectorStoreQuery +from llama_index.core.vector_stores.types import ( + VectorStoreQuery, + MetadataFilter, + MetadataFilters, + FilterOperator, + FilterCondition, +) ## # Run tests @@ -159,3 +165,34 @@ async def test_add_to_chromadb_and_query( ) assert res.nodes assert res.nodes[0].get_content() == "lorem ipsum" + + +@pytest.mark.asyncio() +@pytest.mark.parametrize("use_async", [True, False]) +async def test_add_to_chromadb_and_query_by_metafilters_only( + vector_store: ChromaVectorStore, + node_embeddings: List[TextNode], + use_async: bool, +) -> None: + filters = MetadataFilters( + filters=[ + MetadataFilter( + key="author", value="Marie Curie", operator=FilterOperator.EQ + ) + ], + condition=FilterCondition.AND, + ) + + if use_async: + await vector_store.async_add(node_embeddings) + res = await vector_store.aquery( + VectorStoreQuery(filters=filters, similarity_top_k=1) + ) + else: + vector_store.add(node_embeddings) + res = vector_store.query(VectorStoreQuery(filters=filters, similarity_top_k=1)) + + assert ( + res.nodes[0].get_content() + == "I was taught that the way of progress was neither swift nor easy." + )