From 28dadf9021cc6d96d16789a94e3c6fdbe1019fc9 Mon Sep 17 00:00:00 2001 From: Johan Schreurs Date: Thu, 29 Feb 2024 10:16:59 +0100 Subject: [PATCH] Issue #16 Add initial test coverage for STAC-API endpoints module --- requirements/requirements-test.txt | 1 + stacbuilder/builder.py | 2 +- stacbuilder/stacapi/endpoints.py | 123 ++++++++++++++++----- stacbuilder/stacapi/upload.py | 12 +- tests/stacapi/test_endpoints.py | 170 +++++++++++++++++++++++++++++ 5 files changed, 274 insertions(+), 34 deletions(-) create mode 100644 tests/stacapi/test_endpoints.py diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index bec9e75..ee16d8b 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -1,2 +1,3 @@ -r requirements.txt pytest==7.2.* +requests-mock diff --git a/stacbuilder/builder.py b/stacbuilder/builder.py index 849becd..544098b 100644 --- a/stacbuilder/builder.py +++ b/stacbuilder/builder.py @@ -609,7 +609,7 @@ def create_empty_collection(self, group: Optional[str | int] = None) -> None: self._collection = collection self._log_progress_message("DONE: create_empty_collection") - def get_default_extent(self): + def get_default_extent(self) -> Extent: end_dt = dt.datetime.utcnow() return Extent( diff --git a/stacbuilder/stacapi/endpoints.py b/stacbuilder/stacapi/endpoints.py index dca9642..c095aa0 100644 --- a/stacbuilder/stacapi/endpoints.py +++ b/stacbuilder/stacapi/endpoints.py @@ -47,59 +47,103 @@ def _check_response_status(response: requests.Response, expected_status_codes: l class RestApi: - def __init__(self, base_url: URL | str, auth: AuthBase) -> None: + def __init__(self, base_url: URL | str, auth: AuthBase | None = None) -> None: self.base_url = URL(base_url) self.auth = auth or None - def get(self, url_path: str, *args, **kwargs) -> requests.Response: - return requests.get(str(self.base_url / url_path), auth=self.auth, *args, **kwargs) + def join_path(self, *url_path: list[str]) -> str: + return "/".join(url_path) + # if isinstance(url_path, list): + # return "/".join(url_path) + # return url_path - def post(self, url_path: str, *args, **kwargs) -> requests.Response: - return requests.post(str(self.base_url / url_path), auth=self.auth, *args, **kwargs) + def join_url(self, url_path: str | list[str]) -> str: + return str(self.base_url / self.join_path(url_path)) - def put(self, url_path: str, *args, **kwargs) -> requests.Response: - return requests.put(str(self.base_url / url_path), auth=self.auth, *args, **kwargs) + def get(self, url_path: str | list[str], *args, **kwargs) -> requests.Response: + return requests.get(self.join_url(url_path), auth=self.auth, *args, **kwargs) - def delete(self, url_path: str, *args, **kwargs) -> requests.Response: - return requests.delete(str(self.base_url / url_path), auth=self.auth, *args, **kwargs) + def post(self, url_path: str | list[str], *args, **kwargs) -> requests.Response: + return requests.post(self.join_url(url_path), auth=self.auth, *args, **kwargs) + + def put(self, url_path: str | list[str], *args, **kwargs) -> requests.Response: + return requests.put(self.join_url(url_path), auth=self.auth, *args, **kwargs) + + def delete(self, url_path: str | list[str], *args, **kwargs) -> requests.Response: + return requests.delete(self.join_url(url_path), auth=self.auth, *args, **kwargs) class CollectionsEndpoint: - def __init__(self, stac_api_url: URL, auth: AuthBase | None, collection_auth_info: dict | None = None) -> None: - self._stac_api_url = URL(stac_api_url) - self._collections_url = self._stac_api_url / "collections" - self._auth = auth or None + def __init__(self, rest_api: RestApi, collection_auth_info: dict | None = None) -> None: + self._rest_api = rest_api + # self._collections_path = "collections" self._collection_auth_info: dict | None = collection_auth_info or None + @staticmethod + def create_endpoint( + stac_api_url: URL, auth: AuthBase | None, collection_auth_info: dict | None = None + ) -> "CollectionsEndpoint": + rest_api = RestApi(base_url=stac_api_url, auth=auth) + return CollectionsEndpoint( + rest_api=rest_api, + collection_auth_info=collection_auth_info, + ) + @property def stac_api_url(self) -> URL: - return self._stac_api_url + return self._rest_api.base @property - def collections_url(self) -> URL: - return self._collections_url + # def collections_path(self) -> str: + # return self._collections_path + + # def _join_path(self, *url_path: list[str]) -> str: + # return self._rest_api.join_path(*url_path) + + # def get_collections_path(self, collection_id: str | None = None) -> str: + # if not collection_id: + # return self._rest_api.join_path("collections") + # return self._rest_api.join_path("collections", str(collection_id)) + + # def get_collections_url(self, collection_id: str | None) -> URL: + # if not collection_id: + # return self._rest_api.join_url("collections") + # return self._rest_api.join_url("collections", str(collection_id)) def get_all(self) -> List[Collection]: - response = requests.get(str(self.collections_url), auth=self._auth) + response = self._rest_api.get("collections") _check_response_status(response, _EXPECTED_STATUS_GET) data = response.json() if not isinstance(data, dict): raise Exception(f"Expected a dict in the JSON body but received type {type(data)}, value={data!r}") + return [Collection.from_dict(j) for j in data.get("collections", [])] def get(self, collection_id: str) -> Collection: - if not collection_id: - raise ValueError(f'Argument "collection_id" must have a value of type str. {collection_id=!r}') + if not isinstance(collection_id, str): + raise TypeError(f'Argument "collection_id" must be of type str, but its type is {type(collection_id)=}') - url_str = str(self.collections_url / str(collection_id)) - response = requests.get(url_str, auth=self._auth) + if collection_id == "": + raise ValueError( + f'Argument "collection_id" must have a value; it can not be the empty string. {collection_id=!r}' + ) + + response = self._rest_api.get(f"collections/{collection_id}") _check_response_status(response, _EXPECTED_STATUS_GET) + return Collection.from_dict(response.json()) def exists(self, collection_id: str) -> bool: - url_str = str(self.collections_url / str(collection_id)) - response = requests.get(url_str, auth=self._auth) + if not isinstance(collection_id, str): + raise TypeError(f'Argument "collection_id" must be of type str, but its type is {type(collection_id)=}') + + if collection_id == "": + raise ValueError( + f'Argument "collection_id" must have a value; it can not be the empty string. {collection_id=!r}' + ) + + response = self._rest_api.get(f"collections/{collection_id}") # We do expect HTTP 404 when it doesn't exist. # Any other error status means there is an actual problem. @@ -109,22 +153,44 @@ def exists(self, collection_id: str) -> bool: return True def create(self, collection: Collection) -> dict: + if not isinstance(collection, Collection): + raise TypeError( + f'Argument "collection" must be of type pystac.Collection, but its type is {type(collection)=}' + ) + collection.validate() data = self._add_authentication_section(collection) - response = requests.post(str(self.collections_url), json=data, auth=self._auth) + response = self._rest_api.post("collections", json=data) _check_response_status(response, _EXPECTED_STATUS_POST) + return response.json() def update(self, collection: Collection) -> dict: + if not isinstance(collection, Collection): + raise TypeError( + f'Argument "collection" must be of type pystac.Collection, but its type is {type(collection)=}' + ) + collection.validate() data = self._add_authentication_section(collection) - response = requests.put(str(self.collections_url), json=data, auth=self._auth) + response = self._rest_api.put(f"collections/{collection.id}", json=data) _check_response_status(response, _EXPECTED_STATUS_PUT) + return response.json() def delete(self, collection: Collection) -> dict: - collection.validate() - response = requests.delete(str(self.collections_url), json=collection.to_dict(), auth=self._auth) + return self.delete_by_id(collection.id) + + def delete_by_id(self, collection_id: str) -> dict: + if not isinstance(collection_id, str): + raise TypeError(f'Argument "collection_id" must be of type str, but its type is {type(collection_id)=}') + + if collection_id == "": + raise ValueError( + f'Argument "collection_id" must have a value; it can not be the empty string. {collection_id=!r}' + ) + + response = self._rest_api.delete(f"collections/{collection_id}") _check_response_status(response, _EXPECTED_STATUS_DELETE) return response.json() @@ -135,10 +201,11 @@ def create_or_update(self, collection: Collection) -> dict: else: return self.create(collection) - def _add_authentication_section(self, collection) -> dict: + def _add_authentication_section(self, collection: Collection) -> dict: coll_dict = collection.to_dict() if self._collection_auth_info: coll_dict.update(self._collection_auth_info) + return coll_dict diff --git a/stacbuilder/stacapi/upload.py b/stacbuilder/stacapi/upload.py index 439da3d..1db76c1 100644 --- a/stacbuilder/stacapi/upload.py +++ b/stacbuilder/stacapi/upload.py @@ -11,7 +11,7 @@ from stacbuilder.stacapi.auth import get_auth from stacbuilder.stacapi.config import Settings -from stacbuilder.stacapi.endpoints import CollectionsEndpoint, ItemsEndpoint +from stacbuilder.stacapi.endpoints import CollectionsEndpoint, ItemsEndpoint, RestApi _logger = logging.Logger(__name__) @@ -25,15 +25,17 @@ def __init__(self, collections_ep: CollectionsEndpoint, items_ep: ItemsEndpoint) @classmethod def from_settings(cls, settings: Settings) -> "Uploader": auth = get_auth(settings.auth) - return cls.setup( + return cls.create_uploader( stac_api_url=settings.stac_api_url, auth=auth, collection_auth_info=settings.collection_auth_info ) @staticmethod - def setup(stac_api_url: URL, auth: AuthBase | None, collection_auth_info: dict | None = None) -> "Uploader": + def create_uploader( + stac_api_url: URL, auth: AuthBase | None, collection_auth_info: dict | None = None + ) -> "Uploader": + rest_api = RestApi(base_url=stac_api_url, auth=auth) collections_endpoint = CollectionsEndpoint( - stac_api_url=stac_api_url, - auth=auth, + rest_api=rest_api, collection_auth_info=collection_auth_info, ) items_endpoint = ItemsEndpoint(stac_api_url=stac_api_url, auth=auth) diff --git a/tests/stacapi/test_endpoints.py b/tests/stacapi/test_endpoints.py new file mode 100644 index 0000000..bdda5e8 --- /dev/null +++ b/tests/stacapi/test_endpoints.py @@ -0,0 +1,170 @@ +import datetime as dt + +import pytest +import pystac +import requests +from requests.auth import AuthBase +from pystac import Collection, Item, Extent, SpatialExtent, TemporalExtent +from yarl import URL + + +from stacbuilder.stacapi.endpoints import CollectionsEndpoint, RestApi + + +@pytest.fixture +def default_extent() -> Extent: + return Extent( + # Default spatial extent is the entire world. + SpatialExtent([-180.0, -90.0, 180.0, 90.0]), + # Default temporal extent is from 1 year ago up until now. + TemporalExtent([[dt.datetime(2020, 1, 1), dt.datetime(2021, 1, 1)]]), + ) + + +@pytest.fixture +def test_provider() -> pystac.Provider: + return pystac.Provider( + name="ACME Faux GeoData Org", description="ACME providers of faux geodata", roles=[pystac.ProviderRole.PRODUCER] + ) + + +@pytest.fixture +def test_collection(test_provider, default_extent) -> Collection: + return Collection( + id="ACME-test-collection", + title="Collection of faux ACME data", + description="Collection of faux data from ACME org", + keywords=["foo", "bar"], + providers=[test_provider], + extent=default_extent, + ) + + +@pytest.fixture +def test_items() -> list[Item]: + return [] + + +@pytest.fixture +def test_collection_with_items(test_collection, test_items) -> Collection: + test_collection.add_items(test_items) + test_collection.update_extent_from_items() + return test_collection + + +class FauxAuth(AuthBase): + def __call__(self, r): + r.headers["Authorization"] = "magic-token" + return r + + +class MockRestApi(RestApi): + def __init__(self, base_url: URL | str, auth: AuthBase) -> None: + super().__init__(base_url, auth) + self.collections = {} + self.items = {} + + def add_collection(self, collection: Collection) -> None: + self.collections[collection.id] = collection + + def add_item(self, collection_id: str, item: Item) -> None: + if collection_id not in self.collections: + raise Exception("You have to add the collection with id={collection_id} before the item can be added") + self.items[(collection_id, item.id)] = item + + def create(collection: Collection, items: list[Item]): + api = MockRestApi() + api.add_collection(collection) + for item in items: + api.add_item(collection.id, item) + + return api + + def get(self, url_path: str, *args, **kwargs) -> requests.Response: + # return super().get(url_path, *args, **kwargs) + path_parts = url_path.split("/") + coll_id = path_parts[-1] + if coll_id not in self.collections: + # How to return HTTP 404 here? + return requests.Response("???", status_code=404) + + +class TestRestApi: + BASE_URL = URL("http://test.local/api") + + @pytest.fixture + def api(self): + return RestApi(self.BASE_URL, auth=None) + + def test_get(self, requests_mock, api): + m = requests_mock.get(str(self.BASE_URL / "foo/bar"), status_code=200) + api.get("foo/bar") + assert m.called + + def test_post(self, requests_mock, api): + m = requests_mock.post(str(self.BASE_URL / "foo/bar"), json=[1, 2, 3], status_code=200) + api.post("foo/bar") + assert m.called + + def test_put(self, requests_mock, api): + m = requests_mock.put(str(self.BASE_URL / "foo/bar"), json=[1, 2, 3], status_code=200) + api.put("foo/bar") + assert m.called + + def test_delete(self, requests_mock, api): + m = requests_mock.delete(str(self.BASE_URL / "foo/bar"), json=[1, 2, 3], status_code=200) + api.delete("foo/bar") + assert m.called + + +class TestCollectionsEndPoint: + BASE_URL_STR = "https://test.stacapi.local" + BASE_URL = URL(BASE_URL_STR) + + @pytest.fixture + def collection_endpt(self) -> CollectionsEndpoint: + return CollectionsEndpoint.create_endpoint(self.BASE_URL_STR, auth=None) + + def test_get(self, requests_mock, test_collection: Collection, collection_endpt: CollectionsEndpoint): + m = requests_mock.get( + str(self.BASE_URL / "collections" / test_collection.id), json=test_collection.to_dict(), status_code=200 + ) + actual_collection = collection_endpt.get(test_collection.id) + assert test_collection.to_dict() == actual_collection.to_dict() + assert m.called + + @pytest.mark.xfail(reason="Test not implemented yet") + def test_get_all(self, requests_mock, test_collection: Collection, collection_endpt: CollectionsEndpoint): + assert False, "Test not implemented yet" + + def test_create(self, requests_mock, test_collection: Collection, collection_endpt: CollectionsEndpoint): + m = requests_mock.post(str(self.BASE_URL / "collections"), json=test_collection.to_dict(), status_code=201) + response_json = collection_endpt.create(test_collection) + assert test_collection.to_dict() == response_json + assert m.called + + def test_update(self, requests_mock, test_collection: Collection, collection_endpt: CollectionsEndpoint): + m = requests_mock.put( + str(self.BASE_URL / "collections" / test_collection.id), json=test_collection.to_dict(), status_code=200 + ) + response_json = collection_endpt.update(test_collection) + assert test_collection.to_dict() == response_json + assert m.called + + def test_delete_by_id(self, requests_mock, test_collection: Collection, collection_endpt: CollectionsEndpoint): + m = requests_mock.delete( + str(self.BASE_URL / "collections" / test_collection.id), json=test_collection.to_dict(), status_code=200 + ) + collection_endpt.delete_by_id(test_collection.id) + assert m.called + + def test_delete(self, requests_mock, test_collection: Collection, collection_endpt: CollectionsEndpoint): + m = requests_mock.delete( + str(self.BASE_URL / "collections" / test_collection.id), json=test_collection.to_dict(), status_code=200 + ) + collection_endpt.delete(test_collection) + assert m.called + + +class TestItemsEndPoint: + pass