diff --git a/langchain_postgres/v2/async_vectorstore.py b/langchain_postgres/v2/async_vectorstore.py
index fd0bfd7..7cbeefd 100644
--- a/langchain_postgres/v2/async_vectorstore.py
+++ b/langchain_postgres/v2/async_vectorstore.py
@@ -674,6 +674,80 @@ async def __query_collection(
             return combined_results
         return dense_results
 
+    async def __query_collection_with_filter(
+        self,
+        *,
+        k: Optional[int] = None,
+        filter: Optional[dict] = None,
+        **kwargs: Any,
+    ) -> Sequence[RowMapping]:
+        """
+        Asynchronously query the database collection using optional filters and return matching rows.
+
+        The statement has no ORDER BY, so when more than `k` rows match, which of them
+        are returned is left to the database.
+
+        Args:
+            k (Optional[int]): The maximum number of rows to retrieve. If not provided (or 0), a default is
+                computed based on the hybrid search configuration or a fallback value.
+            filter (Optional[dict]): A dictionary representing filtering conditions to apply in the SQL WHERE clause.
+            **kwargs (Any): Additional keyword arguments (currently unused but accepted for extensibility).
+
+        Returns:
+            Sequence[RowMapping]: A sequence of row mappings, each holding the id, content and
+                metadata columns of one row (the embedding column is not selected).
+
+        Notes:
+            - If `k` is not specified, it defaults to the maximum of the configured top-k values.
+            - If `index_query_options` are set, they are applied using `SET LOCAL` before executing the query.
+        """
+
+        if not k:
+            k = (
+                max(
+                    self.k,
+                    self.hybrid_search_config.primary_top_k,
+                    self.hybrid_search_config.secondary_top_k,
+                )
+                if self.hybrid_search_config
+                else self.k
+            )
+
+        # The embedding column is deliberately not selected: aget(), the only caller in
+        # this change, never reads it, so fetching the vectors would only add transfer cost.
+        columns = [self.id_column, self.content_column] + self.metadata_columns
+        if self.metadata_json_column:
+            columns.append(self.metadata_json_column)
+
+        column_names = ", ".join(f'"{col}"' for col in columns)
+
+        safe_filter = None
+        filter_dict = None
+        if filter and isinstance(filter, dict):
+            safe_filter, filter_dict = self._create_filter_clause(filter)
+
+        where_filters = f"WHERE {safe_filter}" if safe_filter else ""
+        dense_query_stmt = f"""SELECT {column_names}
+        FROM "{self.schema_name}"."{self.table_name}" {where_filters} LIMIT :k;
+        """
+        param_dict = {"k": k}
+        if filter_dict:
+            param_dict.update(filter_dict)
+
+        # One connection serves both cases; the SET LOCAL statements are only issued
+        # when index query options are configured.
+        async with self.engine.connect() as conn:
+            if self.index_query_options:
+                # Set each query option individually
+                for query_option in self.index_query_options.to_parameter():
+                    query_options_stmt = f"SET LOCAL {query_option};"
+                    await conn.execute(text(query_options_stmt))
+            result = await conn.execute(text(dense_query_stmt), param_dict)
+            result_map = result.mappings()
+            results = result_map.fetchall()
+
+        return results
+
     async def asimilarity_search(
         self,
         query: str,
@@ -997,6 +1071,53 @@ async def is_valid_index(
             results = result_map.fetchall()
         return bool(len(results) == 1)
 
+    async def aget(
+        self,
+        filter: Optional[dict] = None,
+        k: Optional[int] = None,
+        **kwargs: Any,
+    ) -> list[Document]:
+        """
+        Asynchronously retrieves documents from a collection based on an optional filter and other parameters.
+
+        This method queries the underlying collection using the provided filter and additional keyword arguments.
+        It constructs a list of `Document` objects from the query results, combining content and metadata from
+        specified columns.
+
+        Args:
+            filter (Optional[dict]): A dictionary specifying filtering criteria for the query. Defaults to None.
+            k (Optional[int]): The maximum number of documents to retrieve. If None, the limit falls back to
+                the store's default `k` (or the largest of `k` and the hybrid-search top-k values when configured).
+            **kwargs (Any): Additional keyword arguments passed to the internal query method.
+
+        Returns:
+            list[Document]: A list of `Document` instances, each containing content, metadata, and an identifier.
+
+        """
+
+        results = await self.__query_collection_with_filter(
+            k=k, filter=filter, **kwargs
+        )
+
+        documents = []
+        for row in results:
+            metadata = (
+                row[self.metadata_json_column]
+                if self.metadata_json_column and row[self.metadata_json_column]
+                else {}
+            )
+            for col in self.metadata_columns:
+                metadata[col] = row[col]
+            documents.append(
+                Document(
+                    page_content=row[self.content_column],
+                    metadata=metadata,
+                    id=str(row[self.id_column]),
+                ),
+            )
+
+        return documents
+
     async def aget_by_ids(self, ids: Sequence[str]) -> list[Document]:
         """Get documents by ids."""
 
diff --git a/langchain_postgres/v2/vectorstores.py b/langchain_postgres/v2/vectorstores.py
index 52224db..69f9fd8 100644
--- a/langchain_postgres/v2/vectorstores.py
+++ b/langchain_postgres/v2/vectorstores.py
@@ -857,5 +857,52 @@ def get_by_ids(self, ids: Sequence[str]) -> list[Document]:
         """Get documents by ids."""
         return self._engine._run_as_sync(self.__vs.aget_by_ids(ids=ids))
 
+    async def aget(
+        self,
+        filter: Optional[dict] = None,
+        k: Optional[int] = None,
+        **kwargs: Any,
+    ) -> list[Document]:
+        """
+        Asynchronously retrieve documents from the collection using optional filters and parameters.
+
+        Args:
+            filter (Optional[dict]): A dictionary specifying filtering criteria for the query. Defaults to None.
+            k (Optional[int]): The maximum number of documents to retrieve. If None, the underlying
+                store's default limit is used.
+            **kwargs (Any): Additional keyword arguments passed to the underlying store's `aget` method.
+
+        Returns:
+            list[Document]: A list of `Document` instances matching the filter criteria.
+        """
+
+        return await self._engine._run_as_async(
+            self.__vs.aget(filter=filter, k=k, **kwargs)
+        )
+
+    def get(
+        self,
+        filter: Optional[dict] = None,
+        k: Optional[int] = None,
+        **kwargs: Any,
+    ) -> list[Document]:
+        """
+        Retrieve documents from the collection using optional filters and parameters.
+
+        Args:
+            filter (Optional[dict]): A dictionary specifying filtering criteria for the query. Defaults to None.
+            k (Optional[int]): The maximum number of documents to retrieve. If None, the underlying
+                store's default limit is used.
+            **kwargs (Any): Additional keyword arguments passed to the asynchronous `aget` method.
+
+        Returns:
+            list[Document]: A list of `Document` instances matching the filter criteria.
+
+        Raises:
+            Any exceptions raised by the underlying asynchronous method or the sync execution engine.
+        """
+
+        return self._engine._run_as_sync(self.__vs.aget(filter=filter, k=k, **kwargs))
+
     def get_table_name(self) -> str:
         return self.__vs.table_name
diff --git a/tests/unit_tests/v2/test_async_pg_vectorstore_search.py b/tests/unit_tests/v2/test_async_pg_vectorstore_search.py
index 16c70fd..cdc4cfb 100644
--- a/tests/unit_tests/v2/test_async_pg_vectorstore_search.py
+++ b/tests/unit_tests/v2/test_async_pg_vectorstore_search.py
@@ -370,6 +370,18 @@ async def test_vectorstore_with_metadata_filters(
         )
         assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter
 
+    @pytest.mark.parametrize("test_filter, expected_ids", FILTERING_TEST_CASES)
+    async def test_vectorstore_get(
+        self,
+        vs_custom_filter: AsyncPGVectorStore,
+        test_filter: dict,
+        expected_ids: list[str],
+    ) -> None:
+        """Test that aget applies the metadata filter."""
+        docs = await vs_custom_filter.aget(test_filter, k=5)
+        # Compared as sets because aget does not order its results.
+        assert {doc.metadata["code"] for doc in docs} == set(expected_ids), test_filter
+
     async def test_asimilarity_hybrid_search(self, vs: AsyncPGVectorStore) -> None:
         results = await vs.asimilarity_search(
             "foo", k=1, hybrid_search_config=HybridSearchConfig()
diff --git a/tests/unit_tests/v2/test_pg_vectorstore_search.py b/tests/unit_tests/v2/test_pg_vectorstore_search.py
index 7815a25..9697883 100644
--- a/tests/unit_tests/v2/test_pg_vectorstore_search.py
+++ b/tests/unit_tests/v2/test_pg_vectorstore_search.py
@@ -429,6 +429,19 @@ def test_sync_vectorstore_with_metadata_filters(
         docs = vs_custom_filter_sync.similarity_search("meow", k=5, filter=test_filter)
         assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter
 
+    @pytest.mark.parametrize("test_filter, expected_ids", FILTERING_TEST_CASES)
+    def test_sync_vectorstore_get(
+        self,
+        vs_custom_filter_sync: PGVectorStore,
+        test_filter: dict,
+        expected_ids: list[str],
+    ) -> None:
+        """Test that get applies the metadata filter."""
+
+        docs = vs_custom_filter_sync.get(k=5, filter=test_filter)
+        # Compared as sets because get does not order its results.
+        assert {doc.metadata["code"] for doc in docs} == set(expected_ids), test_filter
+
     @pytest.mark.parametrize("test_filter", NEGATIVE_TEST_CASES)
     def test_metadata_filter_negative_tests(
         self, vs_custom_filter_sync: PGVectorStore, test_filter: dict