Skip to content

Feature: Add Document Retrieval with Metadata Filtering #231

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions langchain_postgres/v2/async_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,80 @@ async def __query_collection(
return combined_results
return dense_results

async def __query_collection_with_filter(
self,
*,
k: Optional[int] = None,
filter: Optional[dict] = None,
**kwargs: Any,
) -> Sequence[RowMapping]:
"""
Asynchronously query the database collection using optional filters and return matching rows.

Args:
k (Optional[int]): The maximum number of rows to retrieve. If not provided, a default is
computed based on the hybrid search configuration or a fallback value.
filter (Optional[dict]): A dictionary representing filtering conditions to apply in the SQL WHERE clause.
**kwargs (Any): Additional keyword arguments (currently unused but accepted for extensibility).

Returns:
Sequence[RowMapping]: A sequence of row mappings, representing a Document

Notes:
- If `k` is not specified, it defaults to the maximum of the configured top-k values.
- If `index_query_options` are set, they are applied using `SET LOCAL` before executing the query.
"""

if not k:
k = (
max(
self.k,
self.hybrid_search_config.primary_top_k,
self.hybrid_search_config.secondary_top_k,
)
if self.hybrid_search_config
else self.k
)

columns = [
self.id_column,
self.content_column,
self.embedding_column,
] + self.metadata_columns
if self.metadata_json_column:
columns.append(self.metadata_json_column)

column_names = ", ".join(f'"{col}"' for col in columns)

safe_filter = None
filter_dict = None
if filter and isinstance(filter, dict):
safe_filter, filter_dict = self._create_filter_clause(filter)

where_filters = f"WHERE {safe_filter}" if safe_filter else ""
dense_query_stmt = f"""SELECT {column_names}
FROM "{self.schema_name}"."{self.table_name}" {where_filters} LIMIT :k;
"""
param_dict = {"k": k}
if filter_dict:
param_dict.update(filter_dict)
if self.index_query_options:
async with self.engine.connect() as conn:
# Set each query option individually
for query_option in self.index_query_options.to_parameter():
query_options_stmt = f"SET LOCAL {query_option};"
await conn.execute(text(query_options_stmt))
result = await conn.execute(text(dense_query_stmt), param_dict)
result_map = result.mappings()
results = result_map.fetchall()
else:
async with self.engine.connect() as conn:
result = await conn.execute(text(dense_query_stmt), param_dict)
result_map = result.mappings()
results = result_map.fetchall()

return results

async def asimilarity_search(
self,
query: str,
Expand Down Expand Up @@ -997,6 +1071,52 @@ async def is_valid_index(
results = result_map.fetchall()
return bool(len(results) == 1)

async def aget(
self,
filter: Optional[dict] = None,
k: Optional[int] = None,
**kwargs: Any,
) -> list[Document]:
"""
Asynchronously retrieves documents from a collection based on an optional filter and other parameters.

This method queries the underlying collection using the provided filter and additional keyword arguments.
It constructs a list of `Document` objects from the query results, combining content and metadata from
specified columns.

Args:
filter (Optional[dict]): A dictionary specifying filtering criteria for the query. Defaults to None.
k (Optional[int]): The maximum number of documents to retrieve. If None, retrieves all matching documents.
**kwargs (Any): Additional keyword arguments passed to the internal query method.

Returns:
list[Document]: A list of `Document` instances, each containing content, metadata, and an identifier.

"""

results = await self.__query_collection_with_filter(
k=k, filter=filter, **kwargs
)

documents = []
for row in results:
metadata = (
row[self.metadata_json_column]
if self.metadata_json_column and row[self.metadata_json_column]
else {}
)
for col in self.metadata_columns:
metadata[col] = row[col]
documents.append(
Document(
page_content=row[self.content_column],
metadata=metadata,
id=str(row[self.id_column]),
),
)

return documents

async def aget_by_ids(self, ids: Sequence[str]) -> list[Document]:
"""Get documents by ids."""

Expand Down
23 changes: 23 additions & 0 deletions langchain_postgres/v2/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,6 +853,29 @@ async def aget_by_ids(self, ids: Sequence[str]) -> list[Document]:
"""Get documents by ids."""
return await self._engine._run_as_async(self.__vs.aget_by_ids(ids=ids))

def get(
self,
filter: Optional[dict] = None,
k: Optional[int] = None,
**kwargs: Any,
) -> list[Document]:
"""
Retrieve documents from the collection using optional filters and parameters.

Args:
filter (Optional[dict]): A dictionary specifying filtering criteria for the query. Defaults to None.
k (Optional[int]): The maximum number of documents to retrieve. If None, retrieves all matching documents.
**kwargs (Any): Additional keyword arguments passed to the asynchronous `aget` method.

Returns:
list[Document]: A list of `Document` instances matching the filter criteria.

Raises:
Any exceptions raised by the underlying asynchronous method or the sync execution engine.
"""

return self._engine._run_as_sync(self.__vs.aget(filter=filter, k=k, **kwargs))

def get_by_ids(self, ids: Sequence[str]) -> list[Document]:
"""Get documents by ids."""
return self._engine._run_as_sync(self.__vs.aget_by_ids(ids=ids))
Expand Down
13 changes: 13 additions & 0 deletions tests/unit_tests/v2/test_async_pg_vectorstore_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,19 @@ async def test_vectorstore_with_metadata_filters(
)
assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter

@pytest.mark.parametrize("test_filter, expected_ids", FILTERING_TEST_CASES)
async def test_vectorstore_get(
self,
vs_custom_filter: AsyncPGVectorStore,
test_filter: dict,
expected_ids: list[str],
) -> None:
"""Test end to end construction and filter."""
docs = await vs_custom_filter.aget(test_filter, k=5)
assert set([doc.metadata["code"] for doc in docs]) == set(
expected_ids
), test_filter

async def test_asimilarity_hybrid_search(self, vs: AsyncPGVectorStore) -> None:
results = await vs.asimilarity_search(
"foo", k=1, hybrid_search_config=HybridSearchConfig()
Expand Down
14 changes: 14 additions & 0 deletions tests/unit_tests/v2/test_pg_vectorstore_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,20 @@ def test_sync_vectorstore_with_metadata_filters(
docs = vs_custom_filter_sync.similarity_search("meow", k=5, filter=test_filter)
assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter

@pytest.mark.parametrize("test_filter, expected_ids", FILTERING_TEST_CASES)
def test_sync_vectorstore_get(
self,
vs_custom_filter_sync: PGVectorStore,
test_filter: dict,
expected_ids: list[str],
) -> None:
"""Test end to end construction and filter."""

docs = vs_custom_filter_sync.get(k=5, filter=test_filter)
assert set([doc.metadata["code"] for doc in docs]) == set(
expected_ids
), test_filter

@pytest.mark.parametrize("test_filter", NEGATIVE_TEST_CASES)
def test_metadata_filter_negative_tests(
self, vs_custom_filter_sync: PGVectorStore, test_filter: dict
Expand Down