diff --git a/CHANGES.md b/CHANGES.md index 7f5d6ba2..0fcc0713 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -5,12 +5,12 @@ ### Changed * Add more openapi metadata in input models [#734](https://github.com/stac-utils/stac-fastapi/pull/734) -* Use same `Limit` (capped to `10_000`) for `/items` and `GET - /search` input models ([#737](https://github.com/stac-utils/stac-fastapi/pull/737)) +* Use same `Limit` (capped to `10_000`) for `/items` and `GET - /search` input models ([#738](https://github.com/stac-utils/stac-fastapi/pull/738)) ### Added * Add Free-text Extension ([#655](https://github.com/stac-utils/stac-fastapi/pull/655)) -* Add Collection-Search Extension ([#736](https://github.com/stac-utils/stac-fastapi/pull/736)) +* Add Collection-Search Extension ([#736](https://github.com/stac-utils/stac-fastapi/pull/736), [#739](https://github.com/stac-utils/stac-fastapi/pull/739)) ## [3.0.0b2] - 2024-07-09 diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py index 385bd902..17eccde7 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py @@ -1,7 +1,7 @@ """stac_api.extensions.core module.""" from .aggregation import AggregationExtension -from .collection_search import CollectionSearchExtension +from .collection_search import CollectionSearchExtension, CollectionSearchPostExtension from .context import ContextExtension from .fields import FieldsExtension from .filter import FilterExtension @@ -24,4 +24,5 @@ "TokenPaginationExtension", "TransactionExtension", "CollectionSearchExtension", + "CollectionSearchPostExtension", ) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/collection_search/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/collection_search/__init__.py index f919491d..eed6d502 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/collection_search/__init__.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/collection_search/__init__.py @@ -1,5 +1,13 @@ """Collection-Search extension module.""" -from .collection_search import CollectionSearchExtension, ConformanceClasses +from .collection_search import ( + CollectionSearchExtension, + CollectionSearchPostExtension, + ConformanceClasses, +) -__all__ = ["CollectionSearchExtension", "ConformanceClasses"] +__all__ = [ + "CollectionSearchExtension", + "CollectionSearchPostExtension", + "ConformanceClasses", +] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/collection_search/client.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/collection_search/client.py new file mode 100644 index 00000000..ac148dfb --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/collection_search/client.py @@ -0,0 +1,49 @@ +"""collection-search extensions clients.""" + +import abc + +import attr + +from stac_fastapi.types import stac + +from .request import BaseCollectionSearchPostRequest + + +@attr.s +class AsyncBaseCollectionSearchClient(abc.ABC): + """Defines a pattern for implementing the STAC collection-search POST extension.""" + + @abc.abstractmethod + async def post_all_collections( + self, + search_request: BaseCollectionSearchPostRequest, + **kwargs, + ) -> stac.ItemCollection: + """Get all available collections. + + Called with `POST /collections`. + + Returns: + A list of collections. + + """ + ... + + +@attr.s +class BaseCollectionSearchClient(abc.ABC): + """Defines a pattern for implementing the STAC collection-search POST extension.""" + + @abc.abstractmethod + def post_all_collections( + self, search_request: BaseCollectionSearchPostRequest, **kwargs + ) -> stac.ItemCollection: + """Get all available collections. + + Called with `POST /collections`. + + Returns: + A list of collections. + + """ + ... diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/collection_search/collection_search.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/collection_search/collection_search.py index aac81205..2927cd82 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/collection_search/collection_search.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/collection_search/collection_search.py @@ -1,14 +1,20 @@ """Collection-Search extension.""" from enum import Enum -from typing import List, Optional +from typing import List, Optional, Union import attr -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI +from stac_pydantic.api.collections import Collections +from stac_pydantic.shared import MimeTypes +from stac_fastapi.api.models import GeoJSONResponse +from stac_fastapi.api.routes import create_async_endpoint +from stac_fastapi.types.config import ApiSettings from stac_fastapi.types.extension import ApiExtension -from .request import CollectionSearchExtensionGetRequest +from .client import AsyncBaseCollectionSearchClient, BaseCollectionSearchClient +from .request import BaseCollectionSearchGetRequest, BaseCollectionSearchPostRequest class ConformanceClasses(str, Enum): @@ -46,7 +52,7 @@ class CollectionSearchExtension(ApiExtension): the extension """ - GET = CollectionSearchExtensionGetRequest + GET: BaseCollectionSearchGetRequest = attr.ib(default=BaseCollectionSearchGetRequest) POST = None conformance_classes: List[str] = attr.ib( @@ -64,3 +70,65 @@ def register(self, app: FastAPI) -> None: None """ pass + + +@attr.s +class CollectionSearchPostExtension(CollectionSearchExtension): + """Collection-Search Extension. + + Extents the collection-search extension with an additional + POST - /collections endpoint + + NOTE: the POST - /collections endpoint can be conflicting with the + POST /collections endpoint registered for the Transaction extension. + + https://github.com/stac-api-extensions/collection-search + + Attributes: + conformance_classes (list): Defines the list of conformance classes for + the extension + """ + + client: Union[AsyncBaseCollectionSearchClient, BaseCollectionSearchClient] = attr.ib() + settings: ApiSettings = attr.ib() + conformance_classes: List[str] = attr.ib( + default=[ConformanceClasses.COLLECTIONSEARCH, ConformanceClasses.BASIS] + ) + schema_href: Optional[str] = attr.ib(default=None) + router: APIRouter = attr.ib(factory=APIRouter) + + GET: BaseCollectionSearchGetRequest = attr.ib(default=BaseCollectionSearchGetRequest) + POST: BaseCollectionSearchPostRequest = attr.ib( + default=BaseCollectionSearchPostRequest + ) + + def register(self, app: FastAPI) -> None: + """Register the extension with a FastAPI application. + + Args: + app: target FastAPI application. + + Returns: + None + """ + self.router.prefix = app.state.router_prefix + + self.router.add_api_route( + name="Collections", + path="/collections", + methods=["POST"], + response_model=( + Collections if self.settings.enable_response_models else None + ), + responses={ + 200: { + "content": { + MimeTypes.json.value: {}, + }, + "model": Collections, + }, + }, + response_class=GeoJSONResponse, + endpoint=create_async_endpoint(self.client.post_all_collections, self.POST), + ) + app.include_router(self.router) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/collection_search/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/collection_search/request.py index 663f488d..0bc6d22e 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/collection_search/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/collection_search/request.py @@ -1,18 +1,26 @@ """Request models for the Collection-Search extension.""" -from typing import Optional +from datetime import datetime as dt +from typing import List, Optional, Tuple, cast import attr from fastapi import Query +from pydantic import BaseModel, Field, field_validator +from stac_pydantic.api.search import SearchDatetime from stac_pydantic.shared import BBox from typing_extensions import Annotated from stac_fastapi.types.rfc3339 import DateTimeType -from stac_fastapi.types.search import APIRequest, _bbox_converter, _datetime_converter +from stac_fastapi.types.search import ( + APIRequest, + Limit, + _bbox_converter, + _datetime_converter, +) @attr.s -class CollectionSearchExtensionGetRequest(APIRequest): +class BaseCollectionSearchGetRequest(APIRequest): """Basics additional Collection-Search parameters for the GET request.""" bbox: Optional[BBox] = attr.ib(default=None, converter=_bbox_converter) @@ -20,8 +28,112 @@ class CollectionSearchExtensionGetRequest(APIRequest): default=None, converter=_datetime_converter ) limit: Annotated[ - Optional[int], + Optional[Limit], Query( description="Limits the number of results that are included in each page of the response." # noqa: E501 ), ] = attr.ib(default=10) + + +class BaseCollectionSearchPostRequest(BaseModel): + """Collection-Search POST model.""" + + bbox: Optional[BBox] = None + datetime: Optional[str] = None + limit: Optional[Limit] = Field( + 10, + description="Limits the number of results that are included in each page of the response (capped to 10_000).", # noqa: E501 + ) + + # Private properties to store the parsed datetime values. + # Not part of the model schema. + _start_date: Optional[dt] = None + _end_date: Optional[dt] = None + + # Properties to return the private values + @property + def start_date(self) -> Optional[dt]: + """start date.""" + return self._start_date + + @property + def end_date(self) -> Optional[dt]: + """end date.""" + return self._end_date + + @field_validator("bbox") + @classmethod + def validate_bbox(cls, v: BBox) -> BBox: + """validate bbox.""" + if v: + # Validate order + if len(v) == 4: + xmin, ymin, xmax, ymax = cast(Tuple[int, int, int, int], v) + else: + xmin, ymin, min_elev, xmax, ymax, max_elev = cast( + Tuple[int, int, int, int, int, int], v + ) + if max_elev < min_elev: + raise ValueError( + "Maximum elevation must greater than minimum elevation" + ) + + if xmax < xmin: + raise ValueError( + "Maximum longitude must be greater than minimum longitude" + ) + + if ymax < ymin: + raise ValueError( + "Maximum longitude must be greater than minimum longitude" + ) + + # Validate against WGS84 + if xmin < -180 or ymin < -90 or xmax > 180 or ymax > 90: + raise ValueError("Bounding box must be within (-180, -90, 180, 90)") + + return v + + @field_validator("datetime") + @classmethod + def validate_datetime(cls, value: str) -> str: + """validate datetime.""" + # Split on "/" and replace no value or ".." with None + values = [v if v and v != ".." else None for v in value.split("/")] + + # If there are more than 2 dates, it's invalid + if len(values) > 2: + raise ValueError( + """Invalid datetime range. Too many values. + Must match format: {begin_date}/{end_date}""" + ) + + # If there is only one date, duplicate to use for both start and end dates + if len(values) == 1: + values = [values[0], values[0]] + + # Cast because pylance gets confused by the type adapter and annotated type + dates = cast( + List[Optional[dt]], + [ + # Use the type adapter to validate the datetime strings, + # strict is necessary due to pydantic issues #8736 and #8762 + SearchDatetime.validate_strings(v, strict=True) if v else None + for v in values + ], + ) + + # If there is a start and end date, + # check that the start date is before the end date + if dates[0] and dates[1] and dates[0] > dates[1]: + raise ValueError( + "Invalid datetime range. Begin date after end date. " + "Must match format: {begin_date}/{end_date}" + ) + + # Store the parsed dates + cls._start_date = dates[0] + cls._end_date = dates[1] + + # Return the original string value + return value diff --git a/stac_fastapi/extensions/tests/test_collection_search.py b/stac_fastapi/extensions/tests/test_collection_search.py index 856c5b03..edc29221 100644 --- a/stac_fastapi/extensions/tests/test_collection_search.py +++ b/stac_fastapi/extensions/tests/test_collection_search.py @@ -1,20 +1,44 @@ import json from urllib.parse import quote_plus +import attr from starlette.testclient import TestClient from stac_fastapi.api.app import StacApi from stac_fastapi.api.models import create_request_model -from stac_fastapi.extensions.core import CollectionSearchExtension +from stac_fastapi.extensions.core import ( + CollectionSearchExtension, + CollectionSearchPostExtension, +) from stac_fastapi.extensions.core.collection_search import ConformanceClasses +from stac_fastapi.extensions.core.collection_search.client import ( + BaseCollectionSearchClient, +) from stac_fastapi.extensions.core.collection_search.request import ( - CollectionSearchExtensionGetRequest, + BaseCollectionSearchGetRequest, + BaseCollectionSearchPostRequest, +) +from stac_fastapi.extensions.core.fields.request import ( + FieldsExtensionGetRequest, + FieldsExtensionPostRequest, +) +from stac_fastapi.extensions.core.filter.request import ( + FilterExtensionGetRequest, + FilterExtensionPostRequest, +) +from stac_fastapi.extensions.core.free_text.request import ( + FreeTextExtensionGetRequest, + FreeTextExtensionPostRequest, +) +from stac_fastapi.extensions.core.query.request import ( + QueryExtensionGetRequest, + QueryExtensionPostRequest, +) +from stac_fastapi.extensions.core.sort.request import ( + SortExtensionGetRequest, + SortExtensionPostRequest, ) -from stac_fastapi.extensions.core.fields.request import FieldsExtensionGetRequest -from stac_fastapi.extensions.core.filter.request import FilterExtensionGetRequest -from stac_fastapi.extensions.core.free_text.request import FreeTextExtensionGetRequest -from stac_fastapi.extensions.core.query.request import QueryExtensionGetRequest -from stac_fastapi.extensions.core.sort.request import SortExtensionGetRequest +from stac_fastapi.types import stac from stac_fastapi.types.config import ApiSettings from stac_fastapi.types.core import BaseCoreClient @@ -40,13 +64,22 @@ def item_collection(self, *args, **kwargs): raise NotImplementedError +@attr.s +class DummyPostClient(BaseCollectionSearchClient): + def post_all_collections( + self, search_request: BaseCollectionSearchPostRequest, **kwargs + ) -> stac.ItemCollection: + """fake method.""" + return search_request.model_dump() + + def test_collection_search_extension_default(): - """Test /collections endpoint with collection-search ext.""" + """Test GET - /collections endpoint with collection-search ext.""" api = StacApi( settings=ApiSettings(), client=DummyCoreClient(), extensions=[CollectionSearchExtension()], - collections_get_request_model=CollectionSearchExtensionGetRequest, + collections_get_request_model=BaseCollectionSearchGetRequest, ) with TestClient(api.app) as client: response = client.get("/conformance") @@ -87,10 +120,12 @@ def test_collection_search_extension_default(): def test_collection_search_extension_models(): - """Test /collections endpoint with collection-search ext with additional models.""" + """Test GET - /collections endpoint with collection-search ext + with additional models. + """ collections_get_request_model = create_request_model( model_name="SearchGetRequest", - base_model=CollectionSearchExtensionGetRequest, + base_model=BaseCollectionSearchGetRequest, mixins=[ FreeTextExtensionGetRequest, FilterExtensionGetRequest, @@ -106,6 +141,7 @@ def test_collection_search_extension_models(): client=DummyCoreClient(), extensions=[ CollectionSearchExtension( + GET=collections_get_request_model, conformance_classes=[ ConformanceClasses.COLLECTIONSEARCH, ConformanceClasses.BASIS, @@ -114,7 +150,7 @@ def test_collection_search_extension_models(): ConformanceClasses.QUERY, ConformanceClasses.SORT, ConformanceClasses.FIELDS, - ] + ], ) ], collections_get_request_model=collections_get_request_model, @@ -179,3 +215,180 @@ def test_collection_search_extension_models(): assert "query" in response_dict assert ["-gsd", "-datetime"] == response_dict["sortby"] assert ["properties.datetime"] == response_dict["fields"] + + +def test_collection_search_extension_post_default(): + """Test POST - /collections endpoint with collection-search ext.""" + settings = ApiSettings() + collection_search_ext = CollectionSearchPostExtension( + client=DummyPostClient(), + settings=settings, + ) + + api = StacApi( + settings=settings, + client=DummyCoreClient(), + extensions=[collection_search_ext], + ) + with TestClient(api.app) as client: + response = client.get("/conformance") + assert response.is_success, response.json() + response_dict = response.json() + assert ( + "https://api.stacspec.org/v1.0.0-rc.1/collection-search" + in response_dict["conformsTo"] + ) + assert ( + "http://www.opengis.net/spec/ogcapi-common-2/1.0/conf/simple-query" + in response_dict["conformsTo"] + ) + + response = client.post("/collections", json={}) + assert response.is_success, response.json() + response_dict = response.json() + assert "bbox" in response_dict + assert "datetime" in response_dict + assert "limit" in response_dict + assert response_dict["limit"] == 10 + + response = client.post( + "/collections", + json={ + "datetime": "2020-06-13T13:00:00Z/2020-06-13T14:00:00Z", + "bbox": [-175.05, -85.05, 175.05, 85.05], + "limit": 100_000, + }, + ) + assert response.is_success, response.json() + response_dict = response.json() + assert [-175.05, -85.05, 175.05, 85.05] == response_dict["bbox"] + assert "2020-06-13T13:00:00Z/2020-06-13T14:00:00Z" == response_dict["datetime"] + assert 10_000 == response_dict["limit"] + + +def test_collection_search_extension_post_models(): + """Test POST - /collections endpoint with collection-search ext + with additional models. + """ + post_request_model = create_request_model( + model_name="SearchPostRequest", + base_model=BaseCollectionSearchPostRequest, + mixins=[ + FreeTextExtensionPostRequest, + FilterExtensionPostRequest, + QueryExtensionPostRequest, + SortExtensionPostRequest, + FieldsExtensionPostRequest, + ], + request_type="POST", + ) + + get_request_model = create_request_model( + model_name="SearchGetRequest", + base_model=BaseCollectionSearchGetRequest, + mixins=[ + FreeTextExtensionGetRequest, + FilterExtensionGetRequest, + QueryExtensionGetRequest, + SortExtensionGetRequest, + FieldsExtensionGetRequest, + ], + request_type="GET", + ) + + settings = ApiSettings() + api = StacApi( + settings=settings, + client=DummyCoreClient(), + extensions=[ + CollectionSearchPostExtension( + settings=settings, + client=DummyPostClient(), + GET=get_request_model, + POST=post_request_model, + conformance_classes=[ + ConformanceClasses.COLLECTIONSEARCH, + ConformanceClasses.BASIS, + ConformanceClasses.FREETEXT, + ConformanceClasses.FILTER, + ConformanceClasses.QUERY, + ConformanceClasses.SORT, + ConformanceClasses.FIELDS, + ], + ) + ], + collections_get_request_model=get_request_model, + ) + + with TestClient(api.app) as client: + response = client.get("/conformance") + assert response.is_success, response.json() + response_dict = response.json() + conforms = response_dict["conformsTo"] + assert "https://api.stacspec.org/v1.0.0-rc.1/collection-search" in conforms + assert ( + "http://www.opengis.net/spec/ogcapi-common-2/1.0/conf/simple-query" + in conforms + ) + assert ( + "https://api.stacspec.org/v1.0.0-rc.1/collection-search#free-text" in conforms + ) + assert "https://api.stacspec.org/v1.0.0-rc.1/collection-search#filter" in conforms + assert "https://api.stacspec.org/v1.0.0-rc.1/collection-search#query" in conforms + assert "https://api.stacspec.org/v1.0.0-rc.1/collection-search#sort" in conforms + assert "https://api.stacspec.org/v1.0.0-rc.1/collection-search#fields" in conforms + + response = client.post("/collections", json={}) + assert response.is_success, response.json() + response_dict = response.json() + assert "bbox" in response_dict + assert "datetime" in response_dict + assert "limit" in response_dict + assert "q" in response_dict + assert "filter" in response_dict + assert "query" in response_dict + assert "sortby" in response_dict + assert "fields" in response_dict + + response = client.post( + "/collections", + json={ + "datetime": "2020-06-13T13:00:00Z/2020-06-13T14:00:00Z", + "bbox": [-175.05, -85.05, 175.05, 85.05], + "limit": 100_000, + "q": ["EO", "Earth Observation"], + "filter": { + "op": "and", + "args": [ + {"op": "=", "args": [{"property": "id"}, "item_id"]}, + { + "op": "=", + "args": [{"property": "collection"}, "collection_id"], + }, + ], + }, + "query": {"eo:cloud_cover": {"gte": 95}}, + "sortby": [ + { + "field": "properties.gsd", + "direction": "desc", + }, + { + "field": "properties.datetime", + "direction": "asc", + }, + ], + }, + ) + assert response.is_success, response.json() + response_dict = response.json() + assert [-175.05, -85.05, 175.05, 85.05] == response_dict["bbox"] + assert "2020-06-13T13:00:00Z/2020-06-13T14:00:00Z" == response_dict["datetime"] + assert 10_000 == response_dict["limit"] + assert ["EO", "Earth Observation"] == response_dict["q"] + assert response_dict["filter"] + assert "filter_crs" in response_dict + assert "cql2-json" in response_dict["filter_lang"] + assert response_dict["query"] + assert response_dict["sortby"] + assert response_dict["fields"]