diff --git a/stacbuilder/stacapi/endpoints.py b/stacbuilder/stacapi/endpoints.py index 8389ede..d807f4c 100644 --- a/stacbuilder/stacapi/endpoints.py +++ b/stacbuilder/stacapi/endpoints.py @@ -53,9 +53,6 @@ def __init__(self, base_url: URL | str, auth: AuthBase | None = None) -> None: def join_path(self, *url_path: list[str]) -> str: return "/".join(url_path) - # if isinstance(url_path, list): - # return "/".join(url_path) - # return url_path def join_url(self, url_path: str | list[str]) -> str: return str(self.base_url / self.join_path(url_path)) @@ -76,7 +73,6 @@ def delete(self, url_path: str | list[str], *args, **kwargs) -> requests.Respons class CollectionsEndpoint: def __init__(self, rest_api: RestApi, collection_auth_info: dict | None = None) -> None: self._rest_api = rest_api - # self._collections_path = "collections" self._collection_auth_info: dict | None = collection_auth_info or None @staticmethod @@ -193,25 +189,38 @@ def _add_authentication_section(self, collection: Collection) -> dict: class ItemsEndpoint: - def __init__(self, stac_api_url: URL, auth: AuthBase | None) -> None: - self._stac_api_url = URL(stac_api_url) - self._auth = auth or None + def __init__(self, rest_api: RestApi) -> None: + self._rest_api: RestApi = rest_api + # self._stac_api_url = URL(stac_api_url) + # self._auth = auth or None + + @staticmethod + def create_endpoint(stac_api_url: URL, auth: AuthBase | None) -> "ItemsEndpoint": + return ItemsEndpoint(rest_api=RestApi(base_url=stac_api_url, auth=auth)) + + # @property + # def stac_api_url(self) -> URL: + # return self._stac_api_url + + # @property + # def rest_api(self) -> RestApi: + # return self._rest_api @property def stac_api_url(self) -> URL: - return self._stac_api_url + return self._rest_api.base_url def get_items_url(self, collection_id) -> URL: if not collection_id: raise ValueError(f'Argument "collection_id" must have a value of type str. {collection_id=!r}') - return self._stac_api_url / "collections" / str(collection_id) / "items" + return f"collections/{collection_id}/items" - def get_items_url_for_id(self, collection_id, item_id) -> URL: + def get_items_url_for_id(self, collection_id: str, item_id: str) -> URL: if not collection_id: - raise ValueError(f'Argument "collection_id" must have a value of type str. {collection_id=!r}') + raise ValueError(f'Argument "collection_id" miust have a value of type str. {collection_id=!r}') if not item_id: raise ValueError(f'Argument "item_id" must have a value of type str. {item_id=!r}') - return self._stac_api_url / "collections" / str(collection_id) / "items" / str(item_id) + return f"collections/{collection_id}/items/{item_id}" def get_items_url_for_item(self, item: Item) -> URL: if not item: @@ -219,9 +228,8 @@ def get_items_url_for_item(self, item: Item) -> URL: return self.get_items_url_for_id(item.collection_id, item.id) def get_all(self, collection_id) -> ItemCollection: - url_str = str(self.get_items_url(collection_id)) - response = requests.get(url_str, auth=self._auth) - response.raise_for_status() + response = self._rest_api.get(self.get_items_url(collection_id)) + _check_response_status(response, _EXPECTED_STATUS_GET) data = response.json() if not isinstance(data, dict): @@ -230,10 +238,14 @@ def get_all(self, collection_id) -> ItemCollection: return ItemCollection.from_dict(data) def get(self, collection_id: str, item_id: str) -> Item: - url_str = str(self.get_items_url_for_id(collection_id, item_id)) - response = requests.get(url_str, auth=self._auth) + response = self._rest_api.get(self.get_items_url_for_id(collection_id, item_id)) + _check_response_status(response, _EXPECTED_STATUS_GET) - return Item.from_dict(response.json()) + data = response.json() + if not isinstance(data, dict): + raise Exception(f"Expected a dict in the JSON body but received type {type(data)}, value={data!r}") + + return Item.from_dict(data) def exists_by_id(self, collection_id: str, item_id: str) -> bool: if not collection_id: @@ -245,8 +257,7 @@ def exists_by_id(self, collection_id: str, item_id: str) -> bool: raise InvalidOperation( f"item_id must have a non-empty str value. Actual type and value: {type(item_id)=}, {item_id=!r}" ) - url_str = str(self.get_items_url_for_id(collection_id, item_id)) - response = requests.get(url_str, auth=self._auth) + response = self._rest_api.get(self.get_items_url_for_id(collection_id, item_id)) # We do expect HTTP 404 when it doesn't exist. # Any other error status means there is an actual problem. @@ -270,8 +281,7 @@ def exists(self, item: Item) -> bool: def create(self, item: Item) -> dict: item.validate() - url_str = str(self.get_items_url(item.collection_id)) - response = requests.post(url_str, json=item.to_dict(), auth=self._auth) + response = self._rest_api.post(self.get_items_url(item.collection_id), json=item.to_dict()) _check_response_status(response, _EXPECTED_STATUS_POST) return response.json() @@ -280,16 +290,15 @@ def ingest_bulk(self, items: Iterable[Item]) -> dict: if not all(i.collection_id == collection_id for i in items): raise Exception("All collection IDs should be identical for bulk ingests") - url_str = str(self._stac_api_url / "collections" / str(collection_id) / "bulk_items") + url_path = str(self._stac_api_url / "collections" / str(collection_id) / "bulk_items") data = {"items": {item.id: item.to_dict() for item in items}} - response = requests.post(url_str, json=data, auth=self._auth) + response = self._rest_api.post(url_path, json=data) _check_response_status(response, _EXPECTED_STATUS_POST) return response.json() def update(self, item: Item) -> dict: item.validate() - url_str = str(self.get_items_url_for_id(item.collection_id, item.id)) - response = requests.put(url_str, json=item.to_dict(), auth=self._auth) + response = self._rest_api.put(self.get_items_url_for_id(item.collection_id, item.id), json=item.to_dict()) _check_response_status(response, _EXPECTED_STATUS_PUT) return response.json() @@ -309,8 +318,7 @@ def delete_by_id(self, collection_id: str, item_id: str) -> dict: raise InvalidOperation( f"item_id must have a non-empty str value. Actual type and value: {type(item_id)=}, {item_id=!r}" ) - url_str = str(self.get_items_url_for_id(collection_id, item_id)) - response = requests.delete(url_str, auth=self._auth) + response = self._rest_api.delete(self.get_items_url_for_id(collection_id, item_id)) _check_response_status(response, _EXPECTED_STATUS_DELETE) return response.json() diff --git a/tests/stacapi/test_endpoints.py b/tests/stacapi/test_endpoints.py index bdda5e8..b445634 100644 --- a/tests/stacapi/test_endpoints.py +++ b/tests/stacapi/test_endpoints.py @@ -1,14 +1,22 @@ import datetime as dt +import json +from pathlib import Path import pytest import pystac import requests +import shapely +from pystac import Asset, Collection, Item, ItemCollection, Extent, SpatialExtent, TemporalExtent +from pystac.layout import TemplateLayoutStrategy from requests.auth import AuthBase -from pystac import Collection, Item, Extent, SpatialExtent, TemporalExtent from yarl import URL -from stacbuilder.stacapi.endpoints import CollectionsEndpoint, RestApi +from stacbuilder.stacapi.endpoints import CollectionsEndpoint, ItemsEndpoint, RestApi +from stacbuilder.boundingbox import BoundingBox + + +API_BASE_URL = URL("http://test.stacapi.local") @pytest.fixture @@ -22,34 +30,97 @@ def default_extent() -> Extent: @pytest.fixture -def test_provider() -> pystac.Provider: +def provider() -> pystac.Provider: return pystac.Provider( name="ACME Faux GeoData Org", description="ACME providers of faux geodata", roles=[pystac.ProviderRole.PRODUCER] ) @pytest.fixture -def test_collection(test_provider, default_extent) -> Collection: - return Collection( - id="ACME-test-collection", +def empty_collection(provider, default_extent) -> Collection: + coll_id = "ACME-test-collection" + collection = Collection( + id=coll_id, title="Collection of faux ACME data", description="Collection of faux data from ACME org", keywords=["foo", "bar"], - providers=[test_provider], + providers=[provider], extent=default_extent, ) + return collection + + +def create_asset(asset_path: Path) -> Asset: + return Asset( + href=str(asset_path), + title=asset_path.stem, + description=f"GeoTIFF File {asset_path.stem}", + media_type=pystac.MediaType.COG, + roles=["data"], + ) @pytest.fixture -def test_items() -> list[Item]: - return [] +def asset_paths() -> dict[str, Path]: + return { + "t2m": Path("2000/observations_2m-temp-monthly_2000-01-01.tif"), + "pr_tot": Path("2000/observations_tot-precip-monthly_2000-01-01.tif"), + } + + +@pytest.fixture +def fake_assets(asset_paths: dict[str, Path]) -> dict[str, Asset]: + return {asset_type: create_asset(path) for asset_type, path in asset_paths.items()} + + +def create_item(item_id: str, fake_assets) -> Item: + bbox_list = [-180, -90, 180, 90] + geometry = BoundingBox.from_list(bbox_list, epsg=4326).as_geometry_dict() + + polygon: shapely.Polygon = shapely.from_geojson(json.dumps(geometry)) + geo_json = shapely.to_geojson(polygon) + geo_dict = json.loads(geo_json) + + item = pystac.Item( + id=item_id, + assets=fake_assets, + bbox=bbox_list, + geometry=geo_dict, + datetime=dt.datetime(2024, 1, 1), + properties={}, + href=f"./{item_id}", + ) + + item.validate() + + return item @pytest.fixture -def test_collection_with_items(test_collection, test_items) -> Collection: - test_collection.add_items(test_items) - test_collection.update_extent_from_items() - return test_collection +def single_item(fake_assets) -> Item: + return create_item("items01", fake_assets) + + +@pytest.fixture +def multiple_items(fake_assets) -> Item: + return [create_item("items01", fake_assets), create_item("items02", fake_assets)] + + +def feature_collection(multiple_items) -> ItemCollection: + return ItemCollection(items=multiple_items) + + +@pytest.fixture +def collection_with_items(empty_collection, multiple_items) -> Collection: + item: Item + for item in multiple_items: + item.collection = empty_collection + + empty_collection.add_items(multiple_items) + empty_collection.update_extent_from_items() + empty_collection.make_all_asset_hrefs_relative() + + return empty_collection class FauxAuth(AuthBase): @@ -81,7 +152,6 @@ def create(collection: Collection, items: list[Item]): return api def get(self, url_path: str, *args, **kwargs) -> requests.Response: - # return super().get(url_path, *args, **kwargs) path_parts = url_path.split("/") coll_id = path_parts[-1] if coll_id not in self.collections: @@ -90,7 +160,7 @@ def get(self, url_path: str, *args, **kwargs) -> requests.Response: class TestRestApi: - BASE_URL = URL("http://test.local/api") + BASE_URL = API_BASE_URL @pytest.fixture def api(self): @@ -118,53 +188,127 @@ def test_delete(self, requests_mock, api): class TestCollectionsEndPoint: - BASE_URL_STR = "https://test.stacapi.local" - BASE_URL = URL(BASE_URL_STR) + BASE_URL = API_BASE_URL + BASE_URL_STR = str(API_BASE_URL) @pytest.fixture def collection_endpt(self) -> CollectionsEndpoint: return CollectionsEndpoint.create_endpoint(self.BASE_URL_STR, auth=None) - def test_get(self, requests_mock, test_collection: Collection, collection_endpt: CollectionsEndpoint): + def test_get(self, requests_mock, empty_collection: Collection, collection_endpt: CollectionsEndpoint): m = requests_mock.get( - str(self.BASE_URL / "collections" / test_collection.id), json=test_collection.to_dict(), status_code=200 + str(self.BASE_URL / "collections" / empty_collection.id), json=empty_collection.to_dict(), status_code=200 ) - actual_collection = collection_endpt.get(test_collection.id) - assert test_collection.to_dict() == actual_collection.to_dict() + actual_collection = collection_endpt.get(empty_collection.id) + assert empty_collection.to_dict() == actual_collection.to_dict() assert m.called @pytest.mark.xfail(reason="Test not implemented yet") - def test_get_all(self, requests_mock, test_collection: Collection, collection_endpt: CollectionsEndpoint): + def test_get_all(self, requests_mock, empty_collection: Collection, collection_endpt: CollectionsEndpoint): assert False, "Test not implemented yet" - def test_create(self, requests_mock, test_collection: Collection, collection_endpt: CollectionsEndpoint): - m = requests_mock.post(str(self.BASE_URL / "collections"), json=test_collection.to_dict(), status_code=201) - response_json = collection_endpt.create(test_collection) - assert test_collection.to_dict() == response_json + def test_create(self, requests_mock, empty_collection: Collection, collection_endpt: CollectionsEndpoint): + m = requests_mock.post(str(self.BASE_URL / "collections"), json=empty_collection.to_dict(), status_code=201) + response_json = collection_endpt.create(empty_collection) + assert empty_collection.to_dict() == response_json assert m.called - def test_update(self, requests_mock, test_collection: Collection, collection_endpt: CollectionsEndpoint): + def test_update(self, requests_mock, empty_collection: Collection, collection_endpt: CollectionsEndpoint): m = requests_mock.put( - str(self.BASE_URL / "collections" / test_collection.id), json=test_collection.to_dict(), status_code=200 + str(self.BASE_URL / "collections" / empty_collection.id), json=empty_collection.to_dict(), status_code=200 ) - response_json = collection_endpt.update(test_collection) - assert test_collection.to_dict() == response_json + response_json = collection_endpt.update(empty_collection) + assert empty_collection.to_dict() == response_json assert m.called - def test_delete_by_id(self, requests_mock, test_collection: Collection, collection_endpt: CollectionsEndpoint): + def test_delete_by_id(self, requests_mock, empty_collection: Collection, collection_endpt: CollectionsEndpoint): m = requests_mock.delete( - str(self.BASE_URL / "collections" / test_collection.id), json=test_collection.to_dict(), status_code=200 + str(self.BASE_URL / "collections" / empty_collection.id), json=empty_collection.to_dict(), status_code=200 ) - collection_endpt.delete_by_id(test_collection.id) + collection_endpt.delete_by_id(empty_collection.id) assert m.called - def test_delete(self, requests_mock, test_collection: Collection, collection_endpt: CollectionsEndpoint): + def test_delete(self, requests_mock, empty_collection: Collection, collection_endpt: CollectionsEndpoint): m = requests_mock.delete( - str(self.BASE_URL / "collections" / test_collection.id), json=test_collection.to_dict(), status_code=200 + str(self.BASE_URL / "collections" / empty_collection.id), json=empty_collection.to_dict(), status_code=200 ) - collection_endpt.delete(test_collection) + collection_endpt.delete(empty_collection) assert m.called class TestItemsEndPoint: - pass + BASE_URL = API_BASE_URL + BASE_URL_STR = str(API_BASE_URL) + + @pytest.fixture(autouse=True) + def items_endpt(self) -> ItemsEndpoint: + return ItemsEndpoint.create_endpoint(self.BASE_URL_STR, auth=None) + + def test_get(self, requests_mock, collection_with_items: Collection, items_endpt: ItemsEndpoint): + items = list(collection_with_items.get_all_items()) + expected_item: Item = items[0] + m = requests_mock.get( + str(self.BASE_URL / "collections" / collection_with_items.id / "items" / expected_item.id), + json=expected_item.to_dict(), + status_code=200, + ) + actual_item: Item = items_endpt.get(collection_with_items.id, expected_item.id) + assert expected_item.id == actual_item.id + assert expected_item.collection_id == actual_item.collection_id + assert expected_item.bbox == actual_item.bbox + + assert len(expected_item.assets) == len(actual_item.assets) + for asset_type, expected_asset in expected_item.assets.items(): + assert asset_type in actual_item.assets + assert expected_asset.to_dict() == actual_item.assets[asset_type].to_dict() + + assert m.called + + @pytest.mark.skip(reason="Test not yet correct, ItemCollection does not work yet") + def test_get_all(self, requests_mock, collection_with_items: Collection, items_endpt: ItemsEndpoint): + collection_path = Path(collection_with_items.self_href) + if not collection_path.parent.exists(): + collection_path.mkdir(parents=True) + collection_with_items.save(catalog_type=pystac.CatalogType.SELF_CONTAINED) + + expected_items = list(collection_with_items.get_all_items()) + expected_item_collection = ItemCollection(expected_items) + data = expected_item_collection.to_dict() + m = requests_mock.get( + str(self.BASE_URL / "collections" / collection_with_items.id / "items"), json=data, status_code=200 + ) + actual_item_collection = items_endpt.get_all(collection_with_items.id) + + assert expected_item_collection == actual_item_collection + assert m.called + + def test_create(self, requests_mock, collection_with_items: Collection, items_endpt: ItemsEndpoint, tmp_path): + collection_dir = tmp_path / "STAC" / collection_with_items.id + strategy = TemplateLayoutStrategy(item_template="${collection}") + collection_with_items.normalize_hrefs(root_href=str(collection_dir), strategy=strategy, skip_unresolved=True) + + items = list(collection_with_items.get_all_items()) + item = items[0] + m = requests_mock.post( + str(self.BASE_URL / "collections" / collection_with_items.id / "items"), + json=item.to_dict(), + status_code=201, + ) + actual_dict: dict = items_endpt.create(item) + assert item.to_dict() == actual_dict + assert m.called + + def test_update(self, requests_mock, collection_with_items: Collection, items_endpt: ItemsEndpoint, tmp_path): + collection_dir = tmp_path / "STAC" / collection_with_items.id + strategy = TemplateLayoutStrategy(item_template="${collection}") + collection_with_items.normalize_hrefs(root_href=str(collection_dir), strategy=strategy, skip_unresolved=True) + + item = list(collection_with_items.get_all_items())[0] + m = requests_mock.put( + str(self.BASE_URL / "collections" / collection_with_items.id / "items" / item.id), + json=item.to_dict(), + status_code=200, + ) + actual_dict: dict = items_endpt.update(item) + assert item.to_dict() == actual_dict + assert m.called