diff --git a/CHANGES.md b/CHANGES.md index 32c2f1e..06a2d9e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -5,6 +5,7 @@ ### Fixed - fix root-path handling when setting via env var or on app instance +- Allow `q` parameter to be a `str` not a `list[str]` for Advanced Free-Text extension ### Changed diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index 7854ad0..948dca2 100644 --- a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -54,8 +54,7 @@ async def all_collections( # noqa: C901 sortby: Optional[str] = None, filter_expr: Optional[str] = None, filter_lang: Optional[str] = None, - q: Optional[List[str]] = None, - **kwargs, + **kwargs: Any, ) -> Collections: """Cross catalog search (GET). @@ -86,7 +85,7 @@ async def all_collections( # noqa: C901 sortby=sortby, filter_query=filter_expr, filter_lang=filter_lang, - q=q, + **kwargs, ) async with request.app.state.get_connection(request, "r") as conn: @@ -157,7 +156,10 @@ async def all_collections( # noqa: C901 ) async def get_collection( - self, collection_id: str, request: Request, **kwargs + self, + collection_id: str, + request: Request, + **kwargs: Any, ) -> Collection: """Get collection by id. @@ -202,7 +204,9 @@ async def get_collection( return Collection(**collection) async def _get_base_item( - self, collection_id: str, request: Request + self, + collection_id: str, + request: Request, ) -> Dict[str, Any]: """Get the base item of a collection for use in rehydrating full item collection properties. @@ -359,7 +363,7 @@ async def item_collection( filter_expr: Optional[str] = None, filter_lang: Optional[str] = None, token: Optional[str] = None, - **kwargs, + **kwargs: Any, ) -> ItemCollection: """Get all items from a specific collection. @@ -391,6 +395,7 @@ async def item_collection( filter_lang=filter_lang, fields=fields, sortby=sortby, + **kwargs, ) try: @@ -417,7 +422,11 @@ async def item_collection( return ItemCollection(**item_collection) async def get_item( - self, item_id: str, collection_id: str, request: Request, **kwargs + self, + item_id: str, + collection_id: str, + request: Request, + **kwargs: Any, ) -> Item: """Get item by id. @@ -445,7 +454,10 @@ async def get_item( return Item(**item_collection["features"][0]) async def post_search( - self, search_request: PgstacSearch, request: Request, **kwargs + self, + search_request: PgstacSearch, + request: Request, + **kwargs: Any, ) -> ItemCollection: """Cross catalog search (POST). @@ -489,7 +501,7 @@ async def get_search( filter_expr: Optional[str] = None, filter_lang: Optional[str] = None, token: Optional[str] = None, - **kwargs, + **kwargs: Any, ) -> ItemCollection: """Cross catalog search (GET). @@ -516,6 +528,7 @@ async def get_search( sortby=sortby, filter_query=filter_expr, filter_lang=filter_lang, + **kwargs, ) try: @@ -550,7 +563,8 @@ def _clean_search_args( # noqa: C901 sortby: Optional[str] = None, filter_query: Optional[str] = None, filter_lang: Optional[str] = None, - q: Optional[List[str]] = None, + q: Optional[Union[str, List[str]]] = None, + **kwargs: Any, ) -> Dict[str, Any]: """Clean up search arguments to match format expected by pgstac""" if filter_query: @@ -596,7 +610,7 @@ def _clean_search_args( # noqa: C901 base_args["fields"] = {"include": includes, "exclude": excludes} if q: - base_args["q"] = " OR ".join(q) + base_args["q"] = " OR ".join(q) if isinstance(q, list) else q # Remove None values from dict clean = {} diff --git a/tests/conftest.py b/tests/conftest.py index 05846be..f9afd2b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,6 @@ import json import logging import os -import time from typing import Callable, Dict from urllib.parse import quote_plus as quote from urllib.parse import urljoin @@ -26,6 +25,7 @@ CollectionSearchExtension, CollectionSearchFilterExtension, FieldsExtension, + FreeTextAdvancedExtension, FreeTextExtension, ItemCollectionFilterExtension, OffsetPaginationExtension, @@ -139,6 +139,7 @@ def api_client(request): FieldsExtension(), SearchFilterExtension(client=FiltersClient()), TokenPaginationExtension(), + FreeTextExtension(), # not recommended by PgSTAC ] application_extensions.extend(search_extensions) @@ -167,6 +168,7 @@ def api_client(request): FieldsExtension(conformance_classes=[FieldsConformanceClasses.ITEMS]), ItemCollectionFilterExtension(client=FiltersClient()), TokenPaginationExtension(), + FreeTextExtension(), # not recommended by PgSTAC ] application_extensions.extend(item_collection_extensions) @@ -207,7 +209,6 @@ async def app(api_client, database): pgdatabase=database.dbname, ) logger.info("Creating app Fixture") - time.time() app = api_client.app await connect_to_db( app, @@ -314,7 +315,6 @@ async def app_no_ext(database): pgdatabase=database.dbname, ) logger.info("Creating app Fixture") - time.time() await connect_to_db( api_client_no_ext.app, postgres_settings=postgres_settings, @@ -354,7 +354,6 @@ async def app_no_transaction(database): pgdatabase=database.dbname, ) logger.info("Creating app Fixture") - time.time() await connect_to_db( api.app, postgres_settings=postgres_settings, @@ -402,3 +401,57 @@ async def default_client(default_app): transport=ASGITransport(app=default_app), base_url="http://test" ) as c: yield c + + +@pytest.fixture(scope="function") +async def app_advanced_freetext(database): + """Default stac-fastapi-pgstac application without only the transaction extensions.""" + api_settings = Settings(testing=True) + + application_extensions = [ + TransactionExtension(client=TransactionsClient(), settings=api_settings) + ] + + collection_extensions = [ + FreeTextAdvancedExtension(), + OffsetPaginationExtension(), + ] + collection_search_extension = CollectionSearchExtension.from_extensions( + collection_extensions + ) + application_extensions.append(collection_search_extension) + + app = StacApi( + settings=api_settings, + extensions=application_extensions, + client=CoreCrudClient(), + health_check=health_check, + collections_get_request_model=collection_search_extension.GET, + ) + + postgres_settings = PostgresSettings( + pguser=database.user, + pgpassword=database.password, + pghost=database.host, + pgport=database.port, + pgdatabase=database.dbname, + ) + logger.info("Creating app Fixture") + await connect_to_db( + app.app, + postgres_settings=postgres_settings, + add_write_connection_pool=True, + ) + yield app.app + await close_db_connection(app.app) + + logger.info("Closed Pools.") + + +@pytest.fixture(scope="function") +async def app_client_advanced_freetext(app_advanced_freetext): + logger.info("creating app_client") + async with AsyncClient( + transport=ASGITransport(app=app_advanced_freetext), base_url="http://test" + ) as c: + yield c diff --git a/tests/resources/test_collection.py b/tests/resources/test_collection.py index 745d423..013f9ba 100644 --- a/tests/resources/test_collection.py +++ b/tests/resources/test_collection.py @@ -365,6 +365,71 @@ async def test_collection_search_freetext( assert resp.json()["collections"][0]["id"] == load_test2_collection.id resp = await app_client.get( + "/collections", + params={"q": "temperature,calibrated"}, + ) + assert resp.json()["numberReturned"] == 2 + assert resp.json()["numberMatched"] == 2 + assert len(resp.json()["collections"]) == 2 + + resp = await app_client.get( + "/collections", + params={"q": "temperature,yo"}, + ) + assert resp.json()["numberReturned"] == 1 + assert resp.json()["numberMatched"] == 1 + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + resp = await app_client.get( + "/collections", + params={"q": "nosuchthing"}, + ) + assert len(resp.json()["collections"]) == 0 + + +@requires_pgstac_0_9_2 +@pytest.mark.asyncio +async def test_collection_search_freetext_advanced( + app_client_advanced_freetext, load_test_collection, load_test2_collection +): + # free-text + resp = await app_client_advanced_freetext.get( + "/collections", + params={"q": "temperature"}, + ) + assert resp.json()["numberReturned"] == 1 + assert resp.json()["numberMatched"] == 1 + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + resp = await app_client_advanced_freetext.get( + "/collections", + params={"q": "temperature,calibrated"}, + ) + assert resp.json()["numberReturned"] == 2 + assert resp.json()["numberMatched"] == 2 + assert len(resp.json()["collections"]) == 2 + + resp = await app_client_advanced_freetext.get( + "/collections", + params={"q": "temperature,yo"}, + ) + assert resp.json()["numberReturned"] == 1 + assert resp.json()["numberMatched"] == 1 + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + resp = await app_client_advanced_freetext.get( + "/collections", + params={"q": "temperature OR yo"}, + ) + assert resp.json()["numberReturned"] == 1 + assert resp.json()["numberMatched"] == 1 + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + resp = await app_client_advanced_freetext.get( "/collections", params={"q": "nosuchthing"}, ) diff --git a/tests/resources/test_item.py b/tests/resources/test_item.py index 4ea7019..65112ed 100644 --- a/tests/resources/test_item.py +++ b/tests/resources/test_item.py @@ -18,6 +18,8 @@ from stac_fastapi.pgstac.models.links import CollectionLinks +from ..conftest import requires_pgstac_0_9_2 + async def test_create_collection(app_client, load_test_data: Callable): in_json = load_test_data("test_collection.json") @@ -1689,3 +1691,56 @@ async def test_get_search_link_media(app_client): assert len(links) == 2 get_self_link = next((link for link in links if link["rel"] == "self"), None) assert get_self_link["type"] == "application/geo+json" + + +@requires_pgstac_0_9_2 +@pytest.mark.asyncio +async def test_item_search_freetext(app_client, load_test_data, load_test_collection): + test_item = load_test_data("test_item.json") + resp = await app_client.post( + f"/collections/{test_item['collection']}/items", json=test_item + ) + assert resp.status_code == 201 + + # free-text + resp = await app_client.get( + "/search", + params={"q": "temperature"}, + ) + print(resp.json()) + # assert resp.json()["numberReturned"] == 1 + # assert resp.json()["numberMatched"] == 1 + # assert len(resp.json()["collections"]) == 1 + # assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + # resp = await app_client_advanced_freetext.get( + # "/collections", + # params={"q": "temperature,calibrated"}, + # ) + # assert resp.json()["numberReturned"] == 2 + # assert resp.json()["numberMatched"] == 2 + # assert len(resp.json()["collections"]) == 2 + + # resp = await app_client_advanced_freetext.get( + # "/collections", + # params={"q": "temperature,yo"}, + # ) + # assert resp.json()["numberReturned"] == 1 + # assert resp.json()["numberMatched"] == 1 + # assert len(resp.json()["collections"]) == 1 + # assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + # resp = await app_client_advanced_freetext.get( + # "/collections", + # params={"q": "temperature OR yo"}, + # ) + # assert resp.json()["numberReturned"] == 1 + # assert resp.json()["numberMatched"] == 1 + # assert len(resp.json()["collections"]) == 1 + # assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + # resp = await app_client_advanced_freetext.get( + # "/collections", + # params={"q": "nosuchthing"}, + # ) + # assert len(resp.json()["collections"]) == 0