Skip to content

Commit

Permalink
Issue #16 Add initial test coverage for STAC-API endpoints module
Browse files Browse the repository at this point in the history
  • Loading branch information
JohanKJSchreurs committed Feb 29, 2024
1 parent 3f0fc3d commit 28dadf9
Show file tree
Hide file tree
Showing 5 changed files with 274 additions and 34 deletions.
1 change: 1 addition & 0 deletions requirements/requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
-r requirements.txt
pytest==7.2.*
requests-mock
2 changes: 1 addition & 1 deletion stacbuilder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ def create_empty_collection(self, group: Optional[str | int] = None) -> None:
self._collection = collection
self._log_progress_message("DONE: create_empty_collection")

def get_default_extent(self):
def get_default_extent(self) -> Extent:
end_dt = dt.datetime.utcnow()

return Extent(
Expand Down
123 changes: 95 additions & 28 deletions stacbuilder/stacapi/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,59 +47,103 @@ def _check_response_status(response: requests.Response, expected_status_codes: l


class RestApi:
def __init__(self, base_url: URL | str, auth: AuthBase) -> None:
def __init__(self, base_url: URL | str, auth: AuthBase | None = None) -> None:
self.base_url = URL(base_url)
self.auth = auth or None

def get(self, url_path: str, *args, **kwargs) -> requests.Response:
return requests.get(str(self.base_url / url_path), auth=self.auth, *args, **kwargs)
def join_path(self, *url_path: list[str]) -> str:
return "/".join(url_path)
# if isinstance(url_path, list):
# return "/".join(url_path)
# return url_path

def post(self, url_path: str, *args, **kwargs) -> requests.Response:
return requests.post(str(self.base_url / url_path), auth=self.auth, *args, **kwargs)
def join_url(self, url_path: str | list[str]) -> str:
return str(self.base_url / self.join_path(url_path))

def put(self, url_path: str, *args, **kwargs) -> requests.Response:
return requests.put(str(self.base_url / url_path), auth=self.auth, *args, **kwargs)
def get(self, url_path: str | list[str], *args, **kwargs) -> requests.Response:
return requests.get(self.join_url(url_path), auth=self.auth, *args, **kwargs)

def delete(self, url_path: str, *args, **kwargs) -> requests.Response:
return requests.delete(str(self.base_url / url_path), auth=self.auth, *args, **kwargs)
def post(self, url_path: str | list[str], *args, **kwargs) -> requests.Response:
return requests.post(self.join_url(url_path), auth=self.auth, *args, **kwargs)

def put(self, url_path: str | list[str], *args, **kwargs) -> requests.Response:
return requests.put(self.join_url(url_path), auth=self.auth, *args, **kwargs)

def delete(self, url_path: str | list[str], *args, **kwargs) -> requests.Response:
return requests.delete(self.join_url(url_path), auth=self.auth, *args, **kwargs)


class CollectionsEndpoint:
def __init__(self, stac_api_url: URL, auth: AuthBase | None, collection_auth_info: dict | None = None) -> None:
self._stac_api_url = URL(stac_api_url)
self._collections_url = self._stac_api_url / "collections"
self._auth = auth or None
def __init__(self, rest_api: RestApi, collection_auth_info: dict | None = None) -> None:
self._rest_api = rest_api
# self._collections_path = "collections"
self._collection_auth_info: dict | None = collection_auth_info or None

@staticmethod
def create_endpoint(
stac_api_url: URL, auth: AuthBase | None, collection_auth_info: dict | None = None
) -> "CollectionsEndpoint":
rest_api = RestApi(base_url=stac_api_url, auth=auth)
return CollectionsEndpoint(
rest_api=rest_api,
collection_auth_info=collection_auth_info,
)

@property
def stac_api_url(self) -> URL:
return self._stac_api_url
return self._rest_api.base

@property
def collections_url(self) -> URL:
return self._collections_url
# def collections_path(self) -> str:
# return self._collections_path

# def _join_path(self, *url_path: list[str]) -> str:
# return self._rest_api.join_path(*url_path)

# def get_collections_path(self, collection_id: str | None = None) -> str:
# if not collection_id:
# return self._rest_api.join_path("collections")
# return self._rest_api.join_path("collections", str(collection_id))

# def get_collections_url(self, collection_id: str | None) -> URL:
# if not collection_id:
# return self._rest_api.join_url("collections")
# return self._rest_api.join_url("collections", str(collection_id))

def get_all(self) -> List[Collection]:
response = requests.get(str(self.collections_url), auth=self._auth)
response = self._rest_api.get("collections")

_check_response_status(response, _EXPECTED_STATUS_GET)
data = response.json()
if not isinstance(data, dict):
raise Exception(f"Expected a dict in the JSON body but received type {type(data)}, value={data!r}")

return [Collection.from_dict(j) for j in data.get("collections", [])]

def get(self, collection_id: str) -> Collection:
if not collection_id:
raise ValueError(f'Argument "collection_id" must have a value of type str. {collection_id=!r}')
if not isinstance(collection_id, str):
raise TypeError(f'Argument "collection_id" must be of type str, but its type is {type(collection_id)=}')

url_str = str(self.collections_url / str(collection_id))
response = requests.get(url_str, auth=self._auth)
if collection_id == "":
raise ValueError(
f'Argument "collection_id" must have a value; it can not be the empty string. {collection_id=!r}'
)

response = self._rest_api.get(f"collections/{collection_id}")
_check_response_status(response, _EXPECTED_STATUS_GET)

return Collection.from_dict(response.json())

def exists(self, collection_id: str) -> bool:
url_str = str(self.collections_url / str(collection_id))
response = requests.get(url_str, auth=self._auth)
if not isinstance(collection_id, str):
raise TypeError(f'Argument "collection_id" must be of type str, but its type is {type(collection_id)=}')

if collection_id == "":
raise ValueError(
f'Argument "collection_id" must have a value; it can not be the empty string. {collection_id=!r}'
)

response = self._rest_api.get(f"collections/{collection_id}")

# We do expect HTTP 404 when it doesn't exist.
# Any other error status means there is an actual problem.
Expand All @@ -109,22 +153,44 @@ def exists(self, collection_id: str) -> bool:
return True

def create(self, collection: Collection) -> dict:
if not isinstance(collection, Collection):
raise TypeError(
f'Argument "collection" must be of type pystac.Collection, but its type is {type(collection)=}'
)

collection.validate()
data = self._add_authentication_section(collection)
response = requests.post(str(self.collections_url), json=data, auth=self._auth)
response = self._rest_api.post("collections", json=data)
_check_response_status(response, _EXPECTED_STATUS_POST)

return response.json()

def update(self, collection: Collection) -> dict:
if not isinstance(collection, Collection):
raise TypeError(
f'Argument "collection" must be of type pystac.Collection, but its type is {type(collection)=}'
)

collection.validate()
data = self._add_authentication_section(collection)
response = requests.put(str(self.collections_url), json=data, auth=self._auth)
response = self._rest_api.put(f"collections/{collection.id}", json=data)
_check_response_status(response, _EXPECTED_STATUS_PUT)

return response.json()

def delete(self, collection: Collection) -> dict:
collection.validate()
response = requests.delete(str(self.collections_url), json=collection.to_dict(), auth=self._auth)
return self.delete_by_id(collection.id)

def delete_by_id(self, collection_id: str) -> dict:
if not isinstance(collection_id, str):
raise TypeError(f'Argument "collection_id" must be of type str, but its type is {type(collection_id)=}')

if collection_id == "":
raise ValueError(
f'Argument "collection_id" must have a value; it can not be the empty string. {collection_id=!r}'
)

response = self._rest_api.delete(f"collections/{collection_id}")
_check_response_status(response, _EXPECTED_STATUS_DELETE)
return response.json()

Expand All @@ -135,10 +201,11 @@ def create_or_update(self, collection: Collection) -> dict:
else:
return self.create(collection)

def _add_authentication_section(self, collection) -> dict:
def _add_authentication_section(self, collection: Collection) -> dict:
coll_dict = collection.to_dict()
if self._collection_auth_info:
coll_dict.update(self._collection_auth_info)

return coll_dict


Expand Down
12 changes: 7 additions & 5 deletions stacbuilder/stacapi/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from stacbuilder.stacapi.auth import get_auth
from stacbuilder.stacapi.config import Settings
from stacbuilder.stacapi.endpoints import CollectionsEndpoint, ItemsEndpoint
from stacbuilder.stacapi.endpoints import CollectionsEndpoint, ItemsEndpoint, RestApi


_logger = logging.Logger(__name__)
Expand All @@ -25,15 +25,17 @@ def __init__(self, collections_ep: CollectionsEndpoint, items_ep: ItemsEndpoint)
@classmethod
def from_settings(cls, settings: Settings) -> "Uploader":
auth = get_auth(settings.auth)
return cls.setup(
return cls.create_uploader(
stac_api_url=settings.stac_api_url, auth=auth, collection_auth_info=settings.collection_auth_info
)

@staticmethod
def setup(stac_api_url: URL, auth: AuthBase | None, collection_auth_info: dict | None = None) -> "Uploader":
def create_uploader(
stac_api_url: URL, auth: AuthBase | None, collection_auth_info: dict | None = None
) -> "Uploader":
rest_api = RestApi(base_url=stac_api_url, auth=auth)
collections_endpoint = CollectionsEndpoint(
stac_api_url=stac_api_url,
auth=auth,
rest_api=rest_api,
collection_auth_info=collection_auth_info,
)
items_endpoint = ItemsEndpoint(stac_api_url=stac_api_url, auth=auth)
Expand Down
Loading

0 comments on commit 28dadf9

Please sign in to comment.