Skip to content

Commit

Permalink
fix(chromadb): queries without embeddings result in program crash (#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
poxrud authored Feb 28, 2024
1 parent de0f8c2 commit 381afe6
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -291,13 +291,26 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul
else:
where = kwargs.pop("where", {})

results = self._collection.query(
if not query.query_embedding:
return self._get(limit=query.similarity_top_k, where=where, **kwargs)

return self._query(
query_embeddings=query.query_embedding,
n_results=query.similarity_top_k,
where=where,
**kwargs,
)

def _query(
self, query_embeddings: List["float"], n_results: int, where: dict, **kwargs
) -> VectorStoreQueryResult:
results = self._collection.query(
query_embeddings=query_embeddings,
n_results=n_results,
where=where,
**kwargs,
)

logger.debug(f"> Top {len(results['documents'])} nodes:")
nodes = []
similarities = []
Expand Down Expand Up @@ -338,3 +351,48 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul
ids.append(node_id)

return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)

def _get(self, limit: int, where: dict, **kwargs) -> VectorStoreQueryResult:
    """Fetch nodes by metadata filter only, without a similarity search.

    This is the path taken by ``query`` when the incoming query carries no
    embedding: records are read with ``self._collection.get`` instead of
    ``self._collection.query``.

    Args:
        limit: Passed to ``collection.get`` as ``limit``.
        where: Passed to ``collection.get`` as ``where``.
        **kwargs: Forwarded unchanged to ``collection.get``.

    Returns:
        A ``VectorStoreQueryResult`` holding the reconstructed nodes and
        their ids. No similarities are set, since no ranking took place.
    """
    results = self._collection.get(
        limit=limit,
        where=where,
        **kwargs,
    )

    logger.debug(f"> Top {len(results['documents'])} nodes:")
    nodes = []
    ids = []

    # get() returns flat, parallel lists (ids / documents / metadatas), not
    # the per-query nested lists that query() returns. The ids must therefore
    # be iterated directly: indexing with [0] would zip over the *characters*
    # of the first id string, yielding wrong ids and truncating the results.
    record_ids = results["ids"] or []

    for node_id, text, metadata in zip(
        record_ids, results["documents"], results["metadatas"]
    ):
        try:
            node = metadata_dict_to_node(metadata)
            node.set_content(text)
        except Exception:
            # NOTE: deprecated legacy logic for backward compatibility
            metadata, node_info, relationships = legacy_metadata_dict_to_node(
                metadata
            )

            node = TextNode(
                text=text,
                id_=node_id,
                metadata=metadata,
                start_char_idx=node_info.get("start", None),
                end_char_idx=node_info.get("end", None),
                relationships=relationships,
            )

        nodes.append(node)

        logger.debug(
            f"> [Node {node_id}] [Similarity score: N/A - using get()] "
            f"{truncate_text(str(text), 100)}"
        )
        ids.append(node_id)

    return VectorStoreQueryResult(nodes=nodes, ids=ids)
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
import pytest
from llama_index.core.schema import NodeRelationship, RelatedNodeInfo, TextNode
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.core.vector_stores.types import VectorStoreQuery
from llama_index.core.vector_stores.types import (
VectorStoreQuery,
MetadataFilter,
MetadataFilters,
FilterOperator,
FilterCondition,
)

##
# Run tests
Expand Down Expand Up @@ -159,3 +165,34 @@ async def test_add_to_chromadb_and_query(
)
assert res.nodes
assert res.nodes[0].get_content() == "lorem ipsum"


@pytest.mark.asyncio()
@pytest.mark.parametrize("use_async", [True, False])
async def test_add_to_chromadb_and_query_by_metafilters_only(
    vector_store: ChromaVectorStore,
    node_embeddings: List[TextNode],
    use_async: bool,
) -> None:
    """Query with metadata filters only (no query embedding), sync and async.

    The nodes are added to the store, then a query that supplies just an
    ``author == "Marie Curie"`` filter and ``similarity_top_k=1`` is issued.
    The first returned node must carry the expected quote.
    """
    author_filter = MetadataFilter(
        key="author", value="Marie Curie", operator=FilterOperator.EQ
    )
    filters = MetadataFilters(
        filters=[author_filter],
        condition=FilterCondition.AND,
    )
    expected_text = (
        "I was taught that the way of progress was neither swift nor easy."
    )

    # Exercise either the coroutine API or the blocking API, depending on
    # the parametrized flag.
    if not use_async:
        vector_store.add(node_embeddings)
        res = vector_store.query(VectorStoreQuery(filters=filters, similarity_top_k=1))
    else:
        await vector_store.async_add(node_embeddings)
        res = await vector_store.aquery(
            VectorStoreQuery(filters=filters, similarity_top_k=1)
        )

    first_node = res.nodes[0]
    assert first_node.get_content() == expected_text

0 comments on commit 381afe6

Please sign in to comment.