Support specifying a similarity threshold (#81)
* Support specifying a similarity threshold

Related to #21

Add support for specifying a similarity threshold in vector search methods.

* **VectorIndex Class**: Add `similarity_threshold` parameter to `query`, `find_similar`, and `search` methods in `src/wagtail_vector_index/storage/base.py`.
* **PgvectorIndexMixin Class**: Update `get_similar_documents` and `aget_similar_documents` methods to accept `similarity_threshold` parameter and filter results based on it in `src/wagtail_vector_index/storage/pgvector/provider.py`.
* **QdrantIndexMixin Class**: Update `get_similar_documents` method to accept `similarity_threshold` parameter and filter results based on it in `src/wagtail_vector_index/storage/qdrant/provider.py`.
* **WeaviateIndexMixin Class**: Update `get_similar_documents` method to accept `similarity_threshold` parameter and filter results based on it in `src/wagtail_vector_index/storage/weaviate/provider.py`.
* **NumpyIndexMixin Class**: Update `get_similar_documents` method to accept `similarity_threshold` parameter and filter results based on it in `src/wagtail_vector_index/storage/numpy/provider.py`.
* **Documentation**: Update `docs/vector-indexes.md` to include information on the new `similarity_threshold` parameter and provide examples of its usage.
* **Tests**: Add tests for the new `similarity_threshold` parameter in `query`, `find_similar`, and `search` methods in `tests/test_index.py`.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/wagtail/wagtail-vector-index/issues/21?shareId=XXXX-XXXX-XXXX-XXXX).

* Add similarity_score to mock get_similar_documents function

* Ignore .DS_Store

* Fix test_search_with_similarity_threshold

* Update similarity documentation

* Run pre-commit

* Simplify/clarify docs with Grammarly

* Fix Pyright error

* Use only one set of vectors

* Use Pytest style assertions
brylie authored Aug 12, 2024
1 parent ba35e82 commit cd015ec
Showing 8 changed files with 330 additions and 25 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -23,3 +23,4 @@ devenv.lock
devenv.nix
.env
.ruff_cache
.DS_Store
54 changes: 54 additions & 0 deletions docs/vector-indexes.md
@@ -191,3 +191,57 @@ class MyWeaviateVectorIndex(
):
storage_provider_alias = "weaviate"
```

## Using Similarity Threshold in Vector Search

### Understanding Similarity in Vector Search

Vector search finds documents or items that are "similar" to a given query or reference item. Similarity is typically measured by the distance between vectors in a high-dimensional space: the smaller the distance, the more similar the items.

### What is a Similarity Threshold?

A similarity threshold in vector search helps control the relevance of results. It filters out less relevant documents, improving result quality and query performance.

The similarity threshold is a float between 0 and 1 that sets the minimum degree of similarity a result must have.

- 0 means no threshold (the system returns all results within the specified limit)
- 1 means maximum similarity (the system returns only exact matches)
- Values in between filter results based on their similarity to the query

The similarity threshold has a significant impact on the number of returned results. Higher threshold values lead to fewer results, but these are potentially more relevant, highlighting the trade-off between result quantity and relevance.
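
The effect of the threshold on result count can be illustrated with a minimal sketch. It assumes cosine similarity as the measure and uses plain Python lists as vectors; the helper names `cosine_similarity` and `filter_by_threshold` are hypothetical and are not part of Wagtail Vector Index.

```python
import math


def cosine_similarity(a, b):
    """Dot product divided by the product of the two vector norms."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_product = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    # Treat a zero-length vector as "not similar" rather than dividing by zero.
    return dot / norm_product if norm_product else 0.0


def filter_by_threshold(query_vector, vectors, similarity_threshold=0.0):
    """Return (index, score) pairs at or above the threshold, best match first."""
    scored = [(i, cosine_similarity(query_vector, v)) for i, v in enumerate(vectors)]
    kept = [(i, score) for i, score in scored if score >= similarity_threshold]
    return sorted(kept, key=lambda pair: pair[1], reverse=True)


query = [1.0, 0.0]
vectors = [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]

# A threshold of 0.0 keeps every vector; 0.9 keeps only the closest one.
all_results = filter_by_threshold(query, vectors, 0.0)
strict_results = filter_by_threshold(query, vectors, 0.9)
```

With these three vectors, raising the threshold from 0.0 to 0.9 shrinks the result set from three entries to one.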

### Implementing Similarity Threshold in Wagtail Vector Index

In Wagtail Vector Index, the `VectorIndex` class includes a `similarity_threshold` parameter in key methods:

* `query`: This method is used for querying the vector index.
* `find_similar`: This method finds similar objects in the vector index.
* `search`: This method performs a search in the vector index.

Pass `similarity_threshold` as a keyword argument to set the threshold for that specific call. It defaults to `0.0`, which applies no threshold.
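
Because the threshold is set per call, a caller can retry with a looser value when a strict one returns nothing. The sketch below shows one way to do that. `StubIndex` is a hypothetical stand-in that only mimics the keyword signature of `search` (`limit` and `similarity_threshold`), and `search_with_fallback` is an illustrative helper, not part of the library.

```python
class StubIndex:
    """Hypothetical stand-in for a vector index; mimics only the `search` signature."""

    def __init__(self, scored_items):
        # List of (item, similarity) pairs, assumed to be sorted best first.
        self.scored_items = scored_items

    def search(self, query, *, limit=5, similarity_threshold=0.0):
        kept = [item for item, score in self.scored_items if score >= similarity_threshold]
        return kept[:limit]


def search_with_fallback(index, query, thresholds=(0.8, 0.5, 0.0), limit=5):
    """Try thresholds from strictest to loosest; return the first non-empty result."""
    for threshold in thresholds:
        results = index.search(query, limit=limit, similarity_threshold=threshold)
        if results:
            return threshold, results
    return None, []


index = StubIndex([("page-a", 0.72), ("page-b", 0.41)])
used_threshold, results = search_with_fallback(index, "wagtail")
```

Here the 0.8 pass finds nothing, so the helper falls back to 0.5 and returns `page-a`. With a real index, the same pattern would apply to any object exposing a compatible `search` method.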

### Best Practices and Considerations

1. **Choosing a Threshold**: Start with a lower threshold (e.g., 0.5) and adjust based on your specific use case and the quality of results.
2. **Performance Impact**: Higher thresholds may improve query performance by reducing the number of results processed.
3. **Result Set Size**: Be aware that high thresholds might significantly reduce the number of results. Always check if your result set is empty, and consider lowering the threshold if necessary.
4. **Backend Differences**: While we strive for consistency, different vector search backends (e.g., pgvector, Qdrant, Weaviate) may calculate similarity slightly differently. Test thoroughly with your specific backend.
5. **Combining with Limit**: The `similarity_threshold` parameter works in conjunction with the `limit` parameter. Results are first filtered by the similarity threshold and then limited to the specified number.
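
Some backends work with distances rather than similarities, so a similarity threshold has to be converted before it can be applied. As a sketch, assuming cosine distance (defined as 1 minus cosine similarity), the conversion could look like this. The function name `similarity_to_distance` is hypothetical, and the conversion does not hold for other distance measures such as Euclidean distance.

```python
def similarity_to_distance(similarity_threshold):
    """Convert a cosine similarity threshold in [0, 1] to a cosine distance threshold.

    Returns None when the threshold is 0, meaning "apply no filter".
    """
    if not 0 <= similarity_threshold <= 1:
        raise ValueError("similarity_threshold must be between 0 and 1")
    if similarity_threshold == 0:
        return None
    # Cosine distance is 1 - cosine similarity, so a minimum similarity
    # becomes a maximum distance.
    return 1 - similarity_threshold


max_distance = similarity_to_distance(0.75)
no_filter = similarity_to_distance(0.0)
```

A minimum similarity of 0.75 becomes a maximum distance of 0.25, which is why higher similarity thresholds translate to tighter distance limits.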

### Practical Applications

Consider these scenarios where adjusting the similarity threshold can be beneficial:

1. **Content Recommendations**: In a content recommendation system, start with a lower threshold to cast a wide net, then gradually increase it as you gather more user data.
2. **Semantic Search**: A higher threshold in a semantic search engine can ensure that it returns more relevant results, thereby improving the user experience.
3. **Duplicate Detection**: A very high threshold is appropriate to catch only the closest matches when looking for near-duplicate content.

### Debugging and Tuning

If you're not getting the expected results:

1. Try lowering the threshold to see if more relevant results appear.
2. Check the similarity scores of your results (if available) to understand the distribution.
3. Consider the nature of your data and queries. Some domains require lower thresholds to capture relevant semantic relationships.
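
One way to inspect the score distribution is to count how many results would survive at several candidate thresholds. The sketch below assumes you already have a list of similarity scores for a representative query; the scores shown are made-up illustrative values, and `count_by_threshold` is a hypothetical helper.

```python
def count_by_threshold(scores, thresholds=(0.0, 0.25, 0.5, 0.75, 0.9)):
    """Map each candidate threshold to the number of scores at or above it."""
    return {t: sum(1 for s in scores if s >= t) for t in thresholds}


# Illustrative similarity scores for one query (not real measurements).
scores = [0.92, 0.81, 0.64, 0.58, 0.33]
counts = count_by_threshold(scores)
```

For these scores, `counts` is `{0.0: 5, 0.25: 5, 0.5: 4, 0.75: 2, 0.9: 1}`, which shows where the result set starts to shrink sharply.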

Remember, the optimal threshold can vary depending on your specific use case, data, and embedding model. Experimentation and iterative tuning are often necessary to find the best balance between precision and recall for your application.
50 changes: 40 additions & 10 deletions src/wagtail_vector_index/storage/base.py
@@ -138,15 +138,24 @@ def get_converter(self) -> DocumentConverter:
# Public API

def query(
self, query: str, *, sources_limit: int = 5, chat_backend_alias: str = "default"
self,
query: str,
*,
sources_limit: int = 5,
chat_backend_alias: str = "default",
similarity_threshold: float = 0.0,
) -> QueryResponse:
"""Perform a natural language query against the index, returning a QueryResponse containing the natural language response, and a list of sources"""
try:
query_embedding = next(self.get_embedding_backend().embed([query]))
except StopIteration as e:
raise ValueError("No embeddings were generated for the given query.") from e

similar_documents = list(self.get_similar_documents(query_embedding))
similar_documents = list(
self.get_similar_documents(
query_embedding, similarity_threshold=similarity_threshold
)
)

sources = list(self.get_converter().bulk_from_documents(similar_documents))

@@ -165,7 +174,12 @@ def query(
return QueryResponse(response=response.choices[0], sources=sources)

async def aquery(
self, query: str, *, sources_limit: int = 5, chat_backend_alias: str = "default"
self,
query: str,
*,
sources_limit: int = 5,
chat_backend_alias: str = "default",
similarity_threshold: float = 0.0,
) -> AsyncQueryResponse:
"""
Replicates the features of `VectorIndex.query()`, but in an async way.
@@ -176,7 +190,10 @@ async def aquery(
raise ValueError("No embeddings were generated for the given query.") from e

similar_documents = [
doc async for doc in self.aget_similar_documents(query_embedding)
doc
async for doc in self.aget_similar_documents(
query_embedding, similarity_threshold=similarity_threshold
)
]

sources = [
@@ -209,7 +226,12 @@ async def async_stream_wrapper():
)

def find_similar(
self, object, *, include_self: bool = False, limit: int = 5
self,
object,
*,
include_self: bool = False,
limit: int = 5,
similarity_threshold: float = 0.0,
) -> list:
"""Find similar objects to the given object"""
converter = self.get_converter()
@@ -219,7 +241,7 @@ def find_similar(
similar_documents = []
for document in object_documents:
similar_documents += self.get_similar_documents(
document.vector, limit=limit
document.vector, limit=limit, similarity_threshold=similarity_threshold
)

return [
@@ -228,13 +250,17 @@ def find_similar(
if include_self or obj != object
]

def search(self, query: str, *, limit: int = 5) -> list:
def search(
self, query: str, *, limit: int = 5, similarity_threshold: float = 0.0
) -> list:
"""Perform a search against the index, returning only a list of matching sources"""
try:
query_embedding = next(self.get_embedding_backend().embed([query]))
except StopIteration as e:
raise ValueError("No embeddings were generated for the given query.") from e
similar_documents = self.get_similar_documents(query_embedding, limit=limit)
similar_documents = self.get_similar_documents(
query_embedding, limit=limit, similarity_threshold=similarity_threshold
)
return list(self.get_converter().bulk_from_documents(similar_documents))

# Utilities
@@ -262,11 +288,15 @@ def delete(self, *, document_ids: Sequence[str]) -> None:
raise NotImplementedError

def get_similar_documents(
self, query_vector: Sequence[float], *, limit: int = 5
self,
query_vector: Sequence[float],
*,
limit: int = 5,
similarity_threshold: float = 0.0,
) -> Generator[Document, None, None]:
raise NotImplementedError

def aget_similar_documents(
self, query_vector, *, limit: int = 5
self, query_vector, *, limit: int = 5, similarity_threshold: float = 0.0
) -> AsyncGenerator[Document, None]:
raise NotImplementedError
9 changes: 7 additions & 2 deletions src/wagtail_vector_index/storage/numpy/provider.py
@@ -35,7 +35,11 @@ def delete(self, *, document_ids: Sequence[str]) -> None:
pass

def get_similar_documents(
self, query_vector: Sequence[float], *, limit: int = 5
self,
query_vector: Sequence[float],
*,
limit: int = 5,
similarity_threshold: float = 0.0,
) -> Generator[Document, None, None]:
similarities = []
for document in self.get_documents():
@@ -44,7 +48,8 @@ def get_similar_documents(
/ np.linalg.norm(query_vector)
* np.linalg.norm(document.vector)
)
similarities.append((cosine_similarity, document))
if cosine_similarity >= similarity_threshold:
similarities.append((cosine_similarity, document))

sorted_similarities = sorted(
similarities, key=lambda pair: pair[0], reverse=True
21 changes: 13 additions & 8 deletions src/wagtail_vector_index/storage/pgvector/provider.py
@@ -80,19 +80,19 @@ def clear(self):
self._get_queryset().delete()

def get_similar_documents(
self, query_vector, *, limit: int = 5
self, query_vector, *, limit: int = 5, similarity_threshold: float = 0.0
) -> Generator[Document, None, None]:
for pgvector_embedding in self._get_similar_documents_queryset(
query_vector, limit=limit
query_vector, limit=limit, similarity_threshold=similarity_threshold
).iterator():
embedding = pgvector_embedding.embedding
yield embedding.to_document()

async def aget_similar_documents(
self, query_vector, *, limit: int = 5
self, query_vector, *, limit: int = 5, similarity_threshold: float = 0.0
) -> AsyncGenerator[Document, None]:
async for pgvector_embedding in self._get_similar_documents_queryset(
query_vector, limit=limit
query_vector, limit=limit, similarity_threshold=similarity_threshold
):
embedding = pgvector_embedding.embedding
yield embedding.to_document()
@@ -105,18 +105,23 @@ def _get_queryset(self) -> "PgvectorEmbeddingQuerySet":
)

def _get_similar_documents_queryset(
self, query_vector: Sequence[float], *, limit: int
self, query_vector: Sequence[float], *, limit: int, similarity_threshold: float
) -> "PgvectorEmbeddingQuerySet":
return (
queryset = (
self._get_queryset()
.select_related("embedding")
.filter(embedding_output_dimensions=len(query_vector))
.order_by_distance(
query_vector,
distance_method=self.distance_method,
fetch_distance=False,
)[:limit]
fetch_distance=True,
)
)
if similarity_threshold > 0.0:
# Convert similarity threshold to distance threshold
distance_threshold = 1 - similarity_threshold
queryset = queryset.filter(distance__lte=distance_threshold)
return queryset[:limit]

def _bulk_create(self, embeddings: Sequence["PgvectorEmbedding"]) -> None:
_embedding_model().objects.bulk_create(
36 changes: 34 additions & 2 deletions src/wagtail_vector_index/storage/qdrant/provider.py
@@ -61,10 +61,42 @@ def delete(self, *, document_ids: Sequence[str]) -> None:
)

def get_similar_documents(
self, query_vector: Sequence[float], *, limit: int = 5
self,
query_vector: Sequence[float],
*,
limit: int = 5,
similarity_threshold: float = 0.0,
) -> Generator[Document, None, None]:
"""
Retrieve similar documents from Qdrant.
Args:
query_vector (Sequence[float]): The query vector to find similar documents for.
limit (int): The maximum number of similar documents to return.
similarity_threshold (float): The minimum similarity score for returned documents.
Range is [0, 1] where 1 is most similar.
0 means no threshold (default).
Returns:
Generator[Document, None, None]: A generator of similar documents.
Note:
Qdrant uses cosine similarity by default, where higher scores indicate
more similar vectors. The similarity_threshold is used directly as
Qdrant's score_threshold.
"""
if not 0 <= similarity_threshold <= 1:
raise ValueError("similarity_threshold must be between 0 and 1")

# Convert similarity threshold to score threshold
# For Qdrant with cosine similarity, we can use the similarity_threshold directly
score_threshold = similarity_threshold if similarity_threshold > 0 else None

similar_documents = self.storage_provider.client.search(
collection_name=self.index_name, query_vector=query_vector, limit=limit
collection_name=self.index_name,
query_vector=query_vector,
limit=limit,
score_threshold=score_threshold,
)
for doc in similar_documents:
yield Document(
36 changes: 35 additions & 1 deletion src/wagtail_vector_index/storage/weaviate/provider.py
@@ -63,11 +63,45 @@ def delete(self, *, document_ids: Sequence[str]) -> None:
raise NotImplementedError

def get_similar_documents(
self, query_vector: Sequence[float], *, limit: int = 5
self,
query_vector: Sequence[float],
*,
limit: int = 5,
similarity_threshold: float = 0.0,
) -> Generator[Document, None, None]:
"""
Retrieve similar documents from Weaviate.
Args:
query_vector (Sequence[float]): The query vector to find similar documents for.
limit (int): The maximum number of similar documents to return. Defaults to 5.
similarity_threshold (float): The minimum similarity for returned documents.
Range is [0, 1] where 1 is most similar.
0 means no threshold (default).
Returns:
Generator[Document, None, None]: A generator of similar documents.
Note:
Weaviate uses cosine distance internally, where lower values indicate
more similar vectors. The similarity_threshold is converted to a
distance threshold where distance = 1 - similarity.
"""
if not 0 <= similarity_threshold <= 1:
raise ValueError("similarity_threshold must be between 0 and 1")

# Convert similarity threshold to distance threshold
# Weaviate uses cosine distance, which is 1 - cosine similarity
distance_threshold = (
1 - similarity_threshold if similarity_threshold > 0 else None
)

near_vector = {
"vector": query_vector,
}
if distance_threshold is not None:
near_vector["distance"] = [distance_threshold]

similar_documents = (
self.storage_provider.client.query.get(
self.index_name,