Support specifying a similarity threshold
Related to wagtail#21

Add support for specifying a similarity threshold in vector search methods.

* **VectorIndex Class**: Add `similarity_threshold` parameter to `query`, `find_similar`, and `search` methods in `src/wagtail_vector_index/storage/base.py`.
* **PgvectorIndexMixin Class**: Update `get_similar_documents` and `aget_similar_documents` methods to accept `similarity_threshold` parameter and filter results based on it in `src/wagtail_vector_index/storage/pgvector/provider.py`.
* **QdrantIndexMixin Class**: Update `get_similar_documents` method to accept `similarity_threshold` parameter and filter results based on it in `src/wagtail_vector_index/storage/qdrant/provider.py`.
* **WeaviateIndexMixin Class**: Update `get_similar_documents` method to accept `similarity_threshold` parameter and filter results based on it in `src/wagtail_vector_index/storage/weaviate/provider.py`.
* **NumpyIndexMixin Class**: Update `get_similar_documents` method to accept `similarity_threshold` parameter and filter results based on it in `src/wagtail_vector_index/storage/numpy/provider.py`.
* **Documentation**: Update `docs/vector-indexes.md` to include information on the new `similarity_threshold` parameter and provide examples of its usage.
* **Tests**: Add tests for the new `similarity_threshold` parameter in `query`, `find_similar`, and `search` methods in `tests/test_index.py`.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/wagtail/wagtail-vector-index/issues/21?shareId=XXXX-XXXX-XXXX-XXXX).
brylie committed Jul 30, 2024
1 parent ea41b7d commit aea8437
Showing 7 changed files with 219 additions and 23 deletions.
68 changes: 68 additions & 0 deletions docs/vector-indexes.md
@@ -191,3 +191,71 @@ class MyWeaviateVectorIndex(
):
storage_provider_alias = "weaviate"
```

# Using Similarity Threshold in Vector Search

The vector search implementation supports a similarity threshold for filtering results. Documents whose similarity to the query falls below the threshold are discarded, so low-relevance matches are not returned even when fewer than `limit` results remain.

## Understanding Similarity Threshold

The similarity threshold is a float between 0 and 1 (the Qdrant and Weaviate providers raise a `ValueError` for values outside this range), where:
- 0 (the default) means no threshold: all results are returned, up to the specified limit
- 1 means maximum similarity: only documents whose vectors match the query exactly would be returned
- Values in between drop results whose similarity to the query is below the threshold

Higher threshold values return fewer but potentially more relevant results.
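
To make the filtering concrete, here is a minimal, self-contained sketch of thresholding on cosine similarity. It is illustrative only and is not the library's implementation: `cosine_similarity` and `filter_by_threshold` are hypothetical helper names, and the vectors are made-up data.

```python
import math


def cosine_similarity(a, b):
    """Return dot(a, b) / (|a| * |b|) for two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def filter_by_threshold(query, candidates, similarity_threshold=0.0):
    """Keep (name, score) pairs scoring at least the threshold, best first."""
    scored = [(name, cosine_similarity(query, vec)) for name, vec in candidates.items()]
    kept = [(name, score) for name, score in scored if score >= similarity_threshold]
    return sorted(kept, key=lambda pair: pair[1], reverse=True)


# Hypothetical two-dimensional "embeddings" chosen so the scores are easy to check.
query = [1.0, 0.0]
candidates = {
    "same_direction": [2.0, 0.0],  # cosine similarity 1.0
    "diagonal": [1.0, 1.0],        # cosine similarity about 0.707
    "orthogonal": [0.0, 3.0],      # cosine similarity 0.0
}

results = filter_by_threshold(query, candidates, similarity_threshold=0.5)
print([name for name, _ in results])  # ['same_direction', 'diagonal']
```

Raising the threshold to 0.8 in this sketch would leave only `same_direction`, which shows how a stricter threshold trades recall for precision.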

## Similarity Threshold in Vector Index Methods

The `VectorIndex` class accepts a keyword-only `similarity_threshold` parameter (default `0.0`) in its `query`, `aquery`, `find_similar`, and `search` methods:

```python
class MyVectorIndex(EmbeddableFieldsVectorIndexMixin, PgvectorIndexMixin, VectorIndex):
    def query(self, query: str, *, sources_limit: int = 5, chat_backend_alias: str = "default", similarity_threshold: float = 0.0) -> QueryResponse:
        ...  # Implementation here

    def find_similar(self, object, *, include_self: bool = False, limit: int = 5, similarity_threshold: float = 0.0) -> list:
        ...  # Implementation here

    def search(self, query: str, *, limit: int = 5, similarity_threshold: float = 0.0) -> list:
        ...  # Implementation here
```

## Example Usage

Here are examples of how to use the `similarity_threshold` parameter:

```python
index = MyVectorIndex()

# Query with similarity threshold
response = index.query("What is the capital of France?", similarity_threshold=0.8)

# Find similar objects with similarity threshold
similar_objects = index.find_similar(my_object, similarity_threshold=0.7)

# Search with similarity threshold
search_results = index.search("AI in healthcare", similarity_threshold=0.75)
```
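
Because a strict threshold can return nothing at all, one possible pattern is to retry with progressively lower thresholds. The sketch below assumes only that the search callable accepts `limit` and `similarity_threshold` keyword arguments, as `search` does above; `search_with_fallback` and `fake_search` are hypothetical names, and the scores are invented for the demonstration.

```python
def search_with_fallback(search, query, thresholds=(0.8, 0.6, 0.4), limit=5):
    """Try thresholds from strictest to loosest; return the first non-empty result.

    Returns a (results, threshold_used) tuple, or ([], None) if nothing matched.
    """
    for threshold in sorted(thresholds, reverse=True):
        results = search(query, limit=limit, similarity_threshold=threshold)
        if results:
            return results, threshold
    return [], None


def fake_search(query, *, limit=5, similarity_threshold=0.0):
    """Stand-in for a real index search, using made-up similarity scores."""
    scores = {"doc-a": 0.65, "doc-b": 0.45}
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    hits = [name for name, score in ranked if score >= similarity_threshold]
    return hits[:limit]


results, threshold_used = search_with_fallback(fake_search, "AI in healthcare")
print(results, threshold_used)  # ['doc-a'] 0.6
```

With a real index you would pass the bound method instead, for example `search_with_fallback(index.search, "AI in healthcare")`. Each retry issues another query, so keep the list of thresholds short.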

## Best Practices and Considerations

1. **Choosing a Threshold**: Start with a lower threshold (e.g., 0.5) and adjust based on your specific use case and the quality of results.

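2. **Performance**: Higher thresholds reduce the number of documents returned, and therefore the number of sources converted and passed on to later steps. How much query time this saves depends on the backend: the NumPy backend, for example, still computes the similarity for every stored document before filtering.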

3. **Result Set Size**: Be aware that high thresholds might significantly reduce the number of results. Always check if your result set is empty and consider lowering the threshold if necessary.

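4. **Backend Differences**: Backends apply the threshold differently. Qdrant receives it directly as its `score_threshold`, pgvector and Weaviate convert it to a distance threshold (`1 - similarity_threshold`), and the NumPy backend compares it against the cosine similarity it computes itself. Scores may therefore not be directly comparable across backends, so test thoroughly with your specific backend.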

5. **Combining with Limit**: The `similarity_threshold` works in conjunction with the `limit` parameter. Results are first filtered by the similarity threshold, then limited to the specified number.
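
The distance conversion and the filter-then-limit order described above can be sketched in plain Python. This is a simplified model that assumes a cosine-distance backend where distance equals `1 - similarity`; `similarity_to_distance` and `filter_then_limit` are hypothetical helpers, not part of the library, and the distances are made-up values.

```python
def similarity_to_distance(similarity_threshold):
    """Convert a similarity threshold in [0, 1] to a maximum cosine distance."""
    if not 0 <= similarity_threshold <= 1:
        raise ValueError("similarity_threshold must be between 0 and 1")
    return 1 - similarity_threshold


def filter_then_limit(distances, *, limit=5, similarity_threshold=0.0):
    """Order (doc_id, distance) pairs by distance, filter, then truncate."""
    ordered = sorted(distances, key=lambda pair: pair[1])
    if similarity_threshold > 0.0:
        # A threshold of 0.0 is treated as "no filtering" in this sketch.
        max_distance = similarity_to_distance(similarity_threshold)
        ordered = [pair for pair in ordered if pair[1] <= max_distance]
    return ordered[:limit]


distances = [("a", 0.05), ("b", 0.20), ("c", 0.35), ("d", 0.60)]

# Threshold 0.7 allows distances up to about 0.3, so only "a" and "b" qualify.
top = filter_then_limit(distances, limit=5, similarity_threshold=0.7)
print([doc_id for doc_id, _ in top])  # ['a', 'b']
```

Note that the limit is applied last: with `similarity_threshold=0.7` and `limit=5`, only two results come back, because the filter has already removed the rest.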

## Debugging and Tuning

If you're not getting the expected results:

1. Try lowering the threshold to see if more relevant results appear.
2. Check the similarity scores of your results (if available) to understand the distribution.
3. Consider the nature of your data and queries. Some domains might require lower thresholds to capture relevant semantic relationships.

Remember, the optimal threshold can vary depending on your specific use case, data, and the embedding model used. Experimentation and iterative tuning are often necessary to find the best balance between precision and recall for your application.
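
One way to approach tuning is to count how many results would survive at several candidate thresholds. The sketch below assumes you have already collected similarity scores for a representative query by some means (the methods shown in this document return sources, not scores); `count_results_by_threshold` is a hypothetical helper and the scores are made-up sample values.

```python
def count_results_by_threshold(scores, thresholds):
    """Return {threshold: number of scores at or above that threshold}."""
    return {t: sum(1 for s in scores if s >= t) for t in thresholds}


# Hypothetical similarity scores for one query, highest first.
scores = [0.91, 0.84, 0.77, 0.62, 0.58, 0.41, 0.33]

counts = count_results_by_threshold(scores, [0.5, 0.7, 0.9])
for threshold, count in counts.items():
    print(f"threshold={threshold:.1f} -> {count} results")
```

A sharp drop in the count between two neighbouring thresholds suggests a natural cut-off in that range for this kind of query.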
20 changes: 10 additions & 10 deletions src/wagtail_vector_index/storage/base.py
@@ -138,15 +138,15 @@ def get_converter(self) -> DocumentConverter:
# Public API

def query(
self, query: str, *, sources_limit: int = 5, chat_backend_alias: str = "default"
self, query: str, *, sources_limit: int = 5, chat_backend_alias: str = "default", similarity_threshold: float = 0.0
) -> QueryResponse:
"""Perform a natural language query against the index, returning a QueryResponse containing the natural language response, and a list of sources"""
try:
query_embedding = next(self.get_embedding_backend().embed([query]))
except StopIteration as e:
raise ValueError("No embeddings were generated for the given query.") from e

similar_documents = list(self.get_similar_documents(query_embedding))
similar_documents = list(self.get_similar_documents(query_embedding, similarity_threshold=similarity_threshold))

sources = list(self.get_converter().bulk_from_documents(similar_documents))

@@ -165,7 +165,7 @@ def query(
return QueryResponse(response=response.choices[0], sources=sources)

async def aquery(
self, query: str, *, sources_limit: int = 5, chat_backend_alias: str = "default"
self, query: str, *, sources_limit: int = 5, chat_backend_alias: str = "default", similarity_threshold: float = 0.0
) -> AsyncQueryResponse:
"""
Replicates the features of `VectorIndex.query()`, but in an async way.
@@ -176,7 +176,7 @@ async def aquery(
raise ValueError("No embeddings were generated for the given query.") from e

similar_documents = [
doc async for doc in self.aget_similar_documents(query_embedding)
doc async for doc in self.aget_similar_documents(query_embedding, similarity_threshold=similarity_threshold)
]

sources = [
@@ -209,7 +209,7 @@ async def async_stream_wrapper():
)

def find_similar(
self, object, *, include_self: bool = False, limit: int = 5
self, object, *, include_self: bool = False, limit: int = 5, similarity_threshold: float = 0.0
) -> list:
"""Find similar objects to the given object"""
converter = self.get_converter()
@@ -219,7 +219,7 @@ def find_similar(
similar_documents = []
for document in object_documents:
similar_documents += self.get_similar_documents(
document.vector, limit=limit
document.vector, limit=limit, similarity_threshold=similarity_threshold
)

return [
@@ -228,13 +228,13 @@ def find_similar(
if include_self or obj != object
]

def search(self, query: str, *, limit: int = 5) -> list:
def search(self, query: str, *, limit: int = 5, similarity_threshold: float = 0.0) -> list:
"""Perform a search against the index, returning only a list of matching sources"""
try:
query_embedding = next(self.get_embedding_backend().embed([query]))
except StopIteration as e:
raise ValueError("No embeddings were generated for the given query.") from e
similar_documents = self.get_similar_documents(query_embedding, limit=limit)
similar_documents = self.get_similar_documents(query_embedding, limit=limit, similarity_threshold=similarity_threshold)
return list(self.get_converter().bulk_from_documents(similar_documents))

# Utilities
@@ -262,11 +262,11 @@ def delete(self, *, document_ids: Sequence[str]) -> None:
raise NotImplementedError

def get_similar_documents(
self, query_vector: Sequence[float], *, limit: int = 5
self, query_vector: Sequence[float], *, limit: int = 5, similarity_threshold: float = 0.0
) -> Generator[Document, None, None]:
raise NotImplementedError

def aget_similar_documents(
self, query_vector, *, limit: int = 5
self, query_vector, *, limit: int = 5, similarity_threshold: float = 0.0
) -> AsyncGenerator[Document, None]:
raise NotImplementedError
5 changes: 3 additions & 2 deletions src/wagtail_vector_index/storage/numpy/provider.py
@@ -35,7 +35,7 @@ def delete(self, *, document_ids: Sequence[str]) -> None:
pass

def get_similar_documents(
self, query_vector: Sequence[float], *, limit: int = 5
self, query_vector: Sequence[float], *, limit: int = 5, similarity_threshold: float = 0.0
) -> Generator[Document, None, None]:
similarities = []
for document in self.get_documents():
@@ -44,7 +44,8 @@ def get_similar_documents(
/ np.linalg.norm(query_vector)
* np.linalg.norm(document.vector)
)
similarities.append((cosine_similarity, document))
if cosine_similarity >= similarity_threshold:
similarities.append((cosine_similarity, document))

sorted_similarities = sorted(
similarities, key=lambda pair: pair[0], reverse=True
21 changes: 13 additions & 8 deletions src/wagtail_vector_index/storage/pgvector/provider.py
@@ -80,19 +80,19 @@ def clear(self):
self._get_queryset().delete()

def get_similar_documents(
self, query_vector, *, limit: int = 5
self, query_vector, *, limit: int = 5, similarity_threshold: float = 0.0
) -> Generator[Document, None, None]:
for pgvector_embedding in self._get_similar_documents_queryset(
query_vector, limit=limit
query_vector, limit=limit, similarity_threshold=similarity_threshold
).iterator():
embedding = pgvector_embedding.embedding
yield embedding.to_document()

async def aget_similar_documents(
self, query_vector, *, limit: int = 5
self, query_vector, *, limit: int = 5, similarity_threshold: float = 0.0
) -> AsyncGenerator[Document, None]:
async for pgvector_embedding in self._get_similar_documents_queryset(
query_vector, limit=limit
query_vector, limit=limit, similarity_threshold=similarity_threshold
):
embedding = pgvector_embedding.embedding
yield embedding.to_document()
@@ -105,18 +105,23 @@ def _get_queryset(self) -> "PgvectorEmbeddingQuerySet":
)

def _get_similar_documents_queryset(
self, query_vector: Sequence[float], *, limit: int
self, query_vector: Sequence[float], *, limit: int, similarity_threshold: float
) -> "PgvectorEmbeddingQuerySet":
return (
queryset = (
self._get_queryset()
.select_related("embedding")
.filter(embedding_output_dimensions=len(query_vector))
.order_by_distance(
query_vector,
distance_method=self.distance_method,
fetch_distance=False,
)[:limit]
fetch_distance=True,
)
)
if similarity_threshold > 0.0:
# Convert similarity threshold to distance threshold
distance_threshold = 1 - similarity_threshold
queryset = queryset.filter(distance__lte=distance_threshold)
return queryset[:limit]

def _bulk_create(self, embeddings: Sequence["PgvectorEmbedding"]) -> None:
_embedding_model().objects.bulk_create(
32 changes: 30 additions & 2 deletions src/wagtail_vector_index/storage/qdrant/provider.py
@@ -61,10 +61,38 @@ def delete(self, *, document_ids: Sequence[str]) -> None:
)

def get_similar_documents(
self, query_vector: Sequence[float], *, limit: int = 5
self, query_vector: Sequence[float], *, limit: int = 5, similarity_threshold: float = 0.0
) -> Generator[Document, None, None]:
"""
Retrieve similar documents from Qdrant.
Args:
query_vector (Sequence[float]): The query vector to find similar documents for.
limit (int): The maximum number of similar documents to return.
similarity_threshold (float): The minimum similarity score for returned documents.
Range is [0, 1] where 1 is most similar.
0 means no threshold (default).
Returns:
Generator[Document, None, None]: A generator of similar documents.
Note:
Qdrant uses cosine similarity by default, where higher scores indicate
more similar vectors. The similarity_threshold is used directly as
Qdrant's score_threshold.
"""
if not 0 <= similarity_threshold <= 1:
raise ValueError("similarity_threshold must be between 0 and 1")

# Convert similarity threshold to score threshold
# For Qdrant with cosine similarity, we can use the similarity_threshold directly
score_threshold = similarity_threshold if similarity_threshold > 0 else None

similar_documents = self.storage_provider.client.search(
collection_name=self.index_name, query_vector=query_vector, limit=limit
collection_name=self.index_name,
query_vector=query_vector,
limit=limit,
score_threshold=score_threshold
)
for doc in similar_documents:
yield Document(
30 changes: 29 additions & 1 deletion src/wagtail_vector_index/storage/weaviate/provider.py
@@ -63,11 +63,39 @@ def delete(self, *, document_ids: Sequence[str]) -> None:
raise NotImplementedError

def get_similar_documents(
self, query_vector: Sequence[float], *, limit: int = 5
self, query_vector: Sequence[float], *, limit: int = 5, similarity_threshold: float = 0.0
) -> Generator[Document, None, None]:
"""
Retrieve similar documents from Weaviate.
Args:
query_vector (Sequence[float]): The query vector to find similar documents for.
limit (int): The maximum number of similar documents to return. Defaults to 5.
similarity_threshold (float): The minimum similarity for returned documents.
Range is [0, 1] where 1 is most similar.
0 means no threshold (default).
Returns:
Generator[Document, None, None]: A generator of similar documents.
Note:
Weaviate uses cosine distance internally, where lower values indicate
more similar vectors. The similarity_threshold is converted to a
distance threshold where distance = 1 - similarity.
"""
if not 0 <= similarity_threshold <= 1:
raise ValueError("similarity_threshold must be between 0 and 1")

# Convert similarity threshold to distance threshold
# Weaviate uses cosine distance, which is 1 - cosine similarity
distance_threshold = 1 - similarity_threshold if similarity_threshold > 0 else None

near_vector = {
"vector": query_vector,
}
if distance_threshold is not None:
near_vector["distance"] = distance_threshold

similar_documents = (
self.storage_provider.client.query.get(
self.index_name,
66 changes: 66 additions & 0 deletions tests/test_index.py
@@ -134,3 +134,69 @@ def get_similar_documents(query_embedding, limit=0):
index.query("")
first_call_messages = query_mock.call_args.kwargs["messages"]
assert first_call_messages[1] == {"content": expected_content, "role": "system"}


@pytest.mark.django_db
def test_query_with_similarity_threshold(mocker):
ExamplePageFactory.create_batch(2)
index = ExamplePage.vector_index
documents = index.get_documents()[:2]

def get_similar_documents(query_embedding, limit=0, similarity_threshold=0.5):
yield from documents

query_mock = mocker.patch("conftest.ChatMockBackend.chat")
expected_content = "\n".join([doc.metadata["content"] for doc in documents])
similar_documents_mock = mocker.patch.object(index, "get_similar_documents")
similar_documents_mock.side_effect = get_similar_documents
index.query("", similarity_threshold=0.5)
first_call_messages = query_mock.call_args.kwargs["messages"]
assert first_call_messages[1] == {"content": expected_content, "role": "system"}


@pytest.mark.django_db
def test_find_similar_with_similarity_threshold(mocker):
pages = ExamplePageFactory.create_batch(10)
vector_index = ExamplePage.vector_index

def gen_pages(cls, *args, **kwargs):
yield from pages

mocker.patch(
"wagtail_vector_index.storage.models.EmbeddableFieldsDocumentConverter.bulk_from_documents",
side_effect=gen_pages,
)

case = unittest.TestCase()

# We expect 9 results without the page itself.
actual = vector_index.find_similar(pages[0], limit=100, include_self=False, similarity_threshold=0.5)
case.assertCountEqual(actual, pages[1:])

# We expect 10 results with the page itself.
actual = vector_index.find_similar(pages[0], limit=100, include_self=True, similarity_threshold=0.5)
case.assertCountEqual(actual, pages)


@pytest.mark.django_db
def test_search_with_similarity_threshold(mocker):
pages = ExamplePageFactory.create_batch(10)
vector_index = ExamplePage.vector_index

def gen_pages(cls, *args, **kwargs):
yield from pages

mocker.patch(
"wagtail_vector_index.storage.models.EmbeddableFieldsDocumentConverter.bulk_from_documents",
side_effect=gen_pages,
)

case = unittest.TestCase()

# search() does not exclude any page, so we expect all 10 mocked results.
actual = vector_index.search("test", limit=100, similarity_threshold=0.5)
case.assertCountEqual(actual, pages)
