Skip to content

Commit

Permalink
keyword retriever
Browse files Browse the repository at this point in the history
  • Loading branch information
jlonge4 committed Apr 5, 2024
1 parent 0fd7f91 commit 9eefa57
Show file tree
Hide file tree
Showing 2 changed files with 204 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import Document
from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore


@component
class PgvectorKeywordRetriever:
    """
    Retrieves documents from the `PgvectorDocumentStore`, based on keyword (full-text) matching.

    Matching is performed with PostgreSQL full-text search (`to_tsvector` /
    `plainto_tsquery`) and results are ranked by `ts_rank_cd`.

    Example usage:
    ```python
    from haystack import Document, Pipeline
    from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore
    from haystack_integrations.components.retrievers.pgvector import PgvectorKeywordRetriever

    # Set an environment variable `PG_CONN_STR` with the connection string to your PostgreSQL database.
    # e.g., "postgresql://USER:PASSWORD@HOST:PORT/DB_NAME"
    document_store = PgvectorDocumentStore(
        embedding_dimension=768,
        recreate_table=True,
    )

    documents = [Document(content="There are over 7,000 languages spoken around the world today."),
                 Document(content="Elephants have been observed to behave in a way that indicates..."),
                 Document(content="In certain places, you can witness the phenomenon of bioluminescent waves.")]
    document_store.write_documents(documents)

    query_pipeline = Pipeline()
    query_pipeline.add_component("retriever", PgvectorKeywordRetriever(document_store=document_store))

    query = "languages"
    res = query_pipeline.run({"retriever": {"user_query": query}})
    assert res['retriever']['documents'][0].content == "There are over 7,000 languages spoken around the world today."
    ```
    """

    def __init__(
        self,
        *,
        document_store: PgvectorDocumentStore,
        filters: Optional[Dict[str, Any]] = None,
        top_k: int = 10,
        language: str = "english",
    ):
        """
        :param document_store: An instance of `PgvectorDocumentStore`.
        :param filters: Filters applied to the retrieved Documents.
        :param top_k: Maximum number of Documents to return.
        :param language: The PostgreSQL text-search configuration (language) used for
            the full-text query, e.g. `"english"`.
        :raises ValueError: If `document_store` is not an instance of `PgvectorDocumentStore`.
        """
        if not isinstance(document_store, PgvectorDocumentStore):
            msg = "document_store must be an instance of PgvectorDocumentStore"
            raise ValueError(msg)

        self.document_store = document_store
        self.filters = filters or {}
        self.top_k = top_k
        self.language = language

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            filters=self.filters,
            top_k=self.top_k,
            document_store=self.document_store.to_dict(),
            language=self.language,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PgvectorKeywordRetriever":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        doc_store_params = data["init_parameters"]["document_store"]
        data["init_parameters"]["document_store"] = PgvectorDocumentStore.from_dict(doc_store_params)
        return default_from_dict(cls, data)

    @component.output_types(documents=List[Document])
    def run(
        self,
        user_query: str,
        filters: Optional[Dict[str, Any]] = None,
        top_k: Optional[int] = None,
        language: Optional[str] = None,
    ):
        """
        Retrieve documents from the `PgvectorDocumentStore` using full-text keyword search.

        :param user_query: The user's query.
        :param filters: Filters applied to the retrieved Documents. Overrides the
            filters set at initialization if provided.
        :param top_k: Maximum number of Documents to return. Overrides the value
            set at initialization if provided.
        :param language: Text-search language for this query. Defaults to the
            language set at initialization. (A default of `None` — not `"english"` —
            is required here; otherwise a non-English language configured in
            `__init__` would always be silently overridden.)
        :returns: List of Documents matching `user_query`.
        """
        filters = filters or self.filters
        top_k = top_k or self.top_k
        language = language or self.language

        docs = self.document_store._keyword_retrieval(
            user_query=user_query,
            filters=filters,
            top_k=top_k,
            language=language,
        )
        return {"documents": docs}
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,11 @@ def __init__(
vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] = "cosine_similarity",
recreate_table: bool = False,
search_strategy: Literal["exact_nearest_neighbor", "hnsw"] = "exact_nearest_neighbor",
hybrid_search: bool = False,
hnsw_recreate_index_if_exists: bool = False,
hnsw_index_creation_kwargs: Optional[Dict[str, int]] = None,
hnsw_ef_search: Optional[int] = None,
language: Optional[str] = "english",
):
"""
Creates a new PgvectorDocumentStore instance.
Expand Down Expand Up @@ -117,6 +119,7 @@ def __init__(
:param hnsw_ef_search: The `ef_search` parameter to use at query time. Only used if search_strategy is set to
`"hnsw"`. You can find more information about this parameter in the
[pgvector documentation](https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw)
:param language: The language to use for the full-text/hybrid search.
"""

self.connection_string = connection_string
Expand All @@ -128,9 +131,11 @@ def __init__(
self.vector_function = vector_function
self.recreate_table = recreate_table
self.search_strategy = search_strategy
self.hybrid_search = hybrid_search
self.hnsw_recreate_index_if_exists = hnsw_recreate_index_if_exists
self.hnsw_index_creation_kwargs = hnsw_index_creation_kwargs or {}
self.hnsw_ef_search = hnsw_ef_search
self.language = language

connection = connect(self.connection_string.resolve_value())
connection.autocommit = True
Expand Down Expand Up @@ -168,6 +173,7 @@ def to_dict(self) -> Dict[str, Any]:
hnsw_recreate_index_if_exists=self.hnsw_recreate_index_if_exists,
hnsw_index_creation_kwargs=self.hnsw_index_creation_kwargs,
hnsw_ef_search=self.hnsw_ef_search,
language=self.language,
)

@classmethod
Expand Down Expand Up @@ -220,6 +226,7 @@ def _create_table_if_not_exists(self):
)

self._execute_sql(create_sql, error_msg="Could not create table in PgvectorDocumentStore")
self._create_keyword_index()

def delete_table(self):
"""
Expand All @@ -231,6 +238,17 @@ def delete_table(self):

self._execute_sql(delete_sql, error_msg=f"Could not delete table {self.table_name} in PgvectorDocumentStore")

def _create_keyword_index(self):
    """
    Internal method to create the GIN full-text index on the `content` column.

    Uses a named index with `CREATE INDEX IF NOT EXISTS` so this is idempotent:
    it is invoked from `_create_table_if_not_exists` on every store
    initialization, and an unnamed `CREATE INDEX` would otherwise create a
    duplicate index each time the store is constructed against an existing table.
    The index name includes the table name, since index names are schema-scoped
    in PostgreSQL.
    """
    index_name = f"haystack_keyword_index_{self.table_name}"

    sql_create_index = SQL(
        "CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} USING GIN (to_tsvector({language}, content))"
    ).format(
        index_name=Identifier(index_name),
        table_name=Identifier(self.table_name),
        language=SQLLiteral(self.language),
    )

    self._execute_sql(sql_create_index, error_msg="Could not create keyword index on table")

def _handle_hnsw(self):
"""
Internal method to handle the HNSW index creation.
Expand Down Expand Up @@ -415,16 +433,6 @@ def _from_haystack_to_pg_documents(documents: List[Document]) -> List[Dict[str,
db_document["dataframe"] = Jsonb(db_document["dataframe"]) if db_document["dataframe"] else None
db_document["meta"] = Jsonb(db_document["meta"])

if "sparse_embedding" in db_document:
sparse_embedding = db_document.pop("sparse_embedding", None)
if sparse_embedding:
logger.warning(
"Document %s has the `sparse_embedding` field set,"
"but storing sparse embeddings in Pgvector is not currently supported."
"The `sparse_embedding` field will be ignored.",
db_document["id"],
)

db_documents.append(db_document)

return db_documents
Expand Down Expand Up @@ -475,6 +483,56 @@ def delete_documents(self, document_ids: List[str]) -> None:

self._execute_sql(delete_sql, error_msg="Could not delete documents from PgvectorDocumentStore")

def _keyword_retrieval(
    self,
    user_query: str,
    top_k: int = 10,
    filters: Optional[Dict[str, Any]] = None,
    language: Optional[str] = "english",
) -> List[Document]:
    """
    Retrieves documents that best match `user_query` using PostgreSQL full-text search.

    This method is not meant to be part of the public interface of
    `PgvectorDocumentStore` and it should not be called directly.
    `PgvectorKeywordRetriever` uses this method directly and is the public interface for it.

    :param user_query: The query text to match against document content.
    :param top_k: Maximum number of Documents to return.
    :param filters: Filters applied to the retrieved Documents.
    :param language: PostgreSQL text-search configuration used by `to_tsvector`
        and `plainto_tsquery`.
    :raises ValueError: If `user_query` is empty.
    :returns: List of Documents ranked by `ts_rank_cd` relevance to `user_query`.
    """
    if not user_query:
        msg = "user_query must be a non-empty string"
        raise ValueError(msg)

    # Values interpolated into a psycopg SQL composition must be Composable
    # objects (Identifier/Literal/SQL) — passing plain Python strings to
    # SQL.format raises TypeError. SQLLiteral also guarantees proper
    # quoting/escaping of the user-supplied query text.
    sql_select = SQL(
        """SELECT *, RANK() OVER (ORDER BY
        ts_rank_cd(to_tsvector({language}, content), query) DESC) AS rank
        FROM {table_name}, plainto_tsquery({language}, {query}) query
        WHERE to_tsvector({language}, content) @@ query"""
    ).format(
        table_name=Identifier(self.table_name),
        language=SQLLiteral(language),
        query=SQLLiteral(user_query),
    )

    sql_where_clause = SQL("")
    params = ()
    if filters:
        # NOTE(review): the SELECT above already contains a WHERE clause;
        # presumably this helper emits another "WHERE ..." fragment (as used by
        # the embedding retrieval path), which would produce invalid SQL here —
        # confirm the helper can emit an AND-combinable clause instead.
        sql_where_clause, params = _convert_filters_to_where_clause_and_params(filters)

    sql_sort = SQL(" ORDER BY rank {sort_order} LIMIT {top_k}").format(
        top_k=SQLLiteral(top_k),
        sort_order=SQL("DESC"),
    )

    sql_query = sql_select + sql_where_clause + sql_sort

    result = self._execute_sql(
        sql_query,
        params,
        error_msg="Could not retrieve documents from PgvectorDocumentStore.",
        cursor=self._dict_cursor,
    )

    records = result.fetchall()
    docs = self._from_pg_to_haystack_documents(records)
    return docs

def _embedding_retrieval(
self,
query_embedding: List[float],
Expand Down

0 comments on commit 9eefa57

Please sign in to comment.