diff --git a/.example.env b/.example.env index f639ead5..c3fb4b62 100644 --- a/.example.env +++ b/.example.env @@ -4,12 +4,13 @@ CDK_DEFAULT_REGION=[REQUIRED IF DEPLOYING TO EXISTING VPC] STAGE=[FILL ME IN] -VEDA_PROJECT_NAME= -VEDA_PROJECT_DESCRIPTION= - -VEDA_DB_PGSTAC_VERSION=0.6.6 +VEDA_DB_PGSTAC_VERSION=0.7.10 VEDA_DB_SCHEMA_VERSION=0.1.0 VEDA_DB_SNAPSHOT_ID=[OPTIONAL BUT **REQUIRED** FOR ALL DEPLOYMENTS AFTER BASING DEPLOYMENT ON SNAPSHOT] +VEDA_DB_PUBLICLY_ACCESSIBLE=TRUE +VEDA_DB_USE_RDS_PROXY=[OPTIONAL] +VEDA_DB_RDS_INSTANCE_CLASS=[OPTIONAL] +VEDA_DB_RDS_INSTANCE_SIZE=[OPTIONAL] VEDA_DOMAIN_HOSTED_ZONE_ID=[OPTIONAL] VEDA_DOMAIN_HOSTED_ZONE_NAME=[OPTIONAL] @@ -21,11 +22,16 @@ VEDA_DOMAIN_ALT_HOSTED_ZONE_NAME=[OPTIONAL SECOND DOMAIN] VEDA_RASTER_ENABLE_MOSAIC_SEARCH=TRUE VEDA_RASTER_DATA_ACCESS_ROLE_ARN=[OPTIONAL ARN OF IAM ROLE TO BE ASSUMED BY RASTER API] VEDA_RASTER_EXPORT_ASSUME_ROLE_CREDS_AS_ENVS=False - -VEDA_DB_PUBLICLY_ACCESSIBLE=TRUE - VEDA_RASTER_ROOT_PATH= + VEDA_STAC_ROOT_PATH= +VEDA_STAC_ENABLE_TRANSACTIONS=FALSE + +VEDA_USERPOOL_ID= +VEDA_CLIENT_ID= +VEDA_CLIENT_SECRET=secret +VEDA_DATA_ACCESS_ROLE_ARN= +VEDA_COGNITO_DOMAIN= STAC_BROWSER_BUCKET= STAC_URL= diff --git a/.github/workflows/develop.yml b/.github/workflows/develop.yml index 64e4e495..f3c12243 100644 --- a/.github/workflows/develop.yml +++ b/.github/workflows/develop.yml @@ -65,6 +65,9 @@ jobs: - name: Install reqs for ingest api run: python -m pip install -r ingest_api/runtime/requirements_dev.txt + - name: Install veda auth for ingest api + run: python -m pip install common/auth + - name: Ingest unit tests run: NO_PYDANTIC_SSM_SETTINGS=1 python -m pytest ingest_api/runtime/tests/ -vv -s diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 9afb6849..7541f92b 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -65,6 +65,9 @@ jobs: - name: Install reqs for ingest api run: python -m pip install -r ingest_api/runtime/requirements_dev.txt + - name: Install veda auth for ingest api + run: python -m pip install common/auth + - name: Ingest unit tests run: NO_PYDANTIC_SSM_SETTINGS=1 python -m pytest ingest_api/runtime/tests/ -vv -s diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index ccad6ebc..03b7fbad 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -77,9 +77,15 @@ jobs: - name: Install reqs for ingest api run: python -m pip install -r ingest_api/runtime/requirements_dev.txt + - name: Install veda auth for ingest api + run: python -m pip install common/auth + - name: Ingest unit tests run: NO_PYDANTIC_SSM_SETTINGS=1 python -m pytest ingest_api/runtime/tests/ -vv -s + # - name: Stac-api transactions unit tests + # run: python -m pytest stac_api/runtime/tests/ -vv -s + - name: Stop services run: docker compose stop diff --git a/README.md b/README.md index 15fb6fe6..3407ff31 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ This project uses an AWS CDK [CloudFormation](https://docs.aws.amazon.com/AWSClo ### Enviroment variables -An [.example.env](.example.env) template is supplied for for local deployments. If updating an existing deployment, it is essential to check the most current values for these variables by fetching these values from AWS Secrets Manager. The environment secrets are named `--env`, for example `veda-backend-dev-env`. +An [.example.env](.example.env) template is supplied for local deployments. If updating an existing deployment, it is essential to check the most current values for these variables by fetching these values from AWS Secrets Manager. The environment secrets are named `--env`, for example `veda-backend-dev-env`. > **Warning** The environment variables stored as AWS secrets are manually maintained and should be reviewed before deploying updates to existing stacks. ### Fetch environment variables using AWS CLI @@ -92,6 +92,8 @@ python3 -m pip install -e ".[dev,deploy,test]" #### Run the deployment ``` +# Login to ECR so that you can pull public docker images +aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws # Review what infrastructure changes your deployment will cause cdk diff # Execute deployment and standby--security changes will require approval for deployment diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 00000000..bde69941 --- /dev/null +++ b/common/__init__.py @@ -0,0 +1 @@ +"""common utils shared by veda stacks""" diff --git a/common/auth/setup.py b/common/auth/setup.py new file mode 100644 index 00000000..a076c143 --- /dev/null +++ b/common/auth/setup.py @@ -0,0 +1,17 @@ +"""Setup veda_auth +""" + +from setuptools import find_packages, setup + +inst_reqs = ["cryptography>=42.0.5", "pyjwt>=2.8.0", "fastapi", "pydantic<2"] + +setup( + name="veda_auth", + version="0.0.1", + description="", + python_requires=">=3.7", + packages=find_packages(), + zip_safe=False, + install_requires=inst_reqs, + include_package_data=True, +) diff --git a/common/auth/veda_auth/__init__.py b/common/auth/veda_auth/__init__.py new file mode 100644 index 00000000..1948ba93 --- /dev/null +++ b/common/auth/veda_auth/__init__.py @@ -0,0 +1,5 @@ +""" + VEDA cognito auth +""" + +from veda_auth.main import VedaAuth # noqa: F401 diff --git a/common/auth/veda_auth/main.py b/common/auth/veda_auth/main.py new file mode 100644 index 00000000..ce19e65f --- /dev/null +++ b/common/auth/veda_auth/main.py @@ -0,0 +1,128 @@ +"""Authentication handler for veda.stac and veda.ingest""" + +import base64 +import hashlib +import hmac +import logging +from typing import Annotated, Any, Dict + +import boto3 +import jwt + +from fastapi import Depends, HTTPException, Security, security, status + +logger = logging.getLogger(__name__) + + +class VedaAuth: + """Class for handling authentication""" + + def __init__(self, settings) -> None: + """ + Args: + settings: pydantic settings object containing cognito details + Returns: + None + + """ + self.oauth2_scheme = security.OAuth2AuthorizationCodeBearer( + authorizationUrl=settings.cognito_authorization_url, + tokenUrl=settings.cognito_token_url, + refreshUrl=settings.cognito_token_url, + ) + + self.jwks_client = jwt.PyJWKClient(settings.jwks_url) # Caches JWKS + + def validated_token( + token_str: Annotated[str, Security(self.oauth2_scheme)], + required_scopes: security.SecurityScopes, + ) -> Dict: + # Parse & validate token + logger.info(f"\nToken String {token_str}") + try: + token = jwt.decode( + token_str, + self.jwks_client.get_signing_key_from_jwt(token_str).key, + algorithms=["RS256"], + ) + except jwt.exceptions.InvalidTokenError as e: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) from e + + # Validate scopes (if required) + for scope in required_scopes.scopes: + if scope not in token["scope"]: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not enough permissions", + headers={ + "WWW-Authenticate": f'Bearer scope="{required_scopes.scope_str}"' + }, + ) + + return token + + self.validated_token = validated_token + + def get_username( + token: Annotated[Dict[Any, Any], Depends(self.validated_token)] + ) -> str: + result = token["username"] if "username" in token else str(token.get("sub")) + return result + + self.get_username = get_username + + def _get_secret_hash( + self, username: str, client_id: str, client_secret: str + ) -> str: + # A keyed-hash message authentication code (HMAC) calculated using + # the secret key of a user pool client and username plus the client + # ID in the message. + message = username + client_id + dig = hmac.new( + bytearray(client_secret, "utf-8"), + msg=message.encode("UTF-8"), + digestmod=hashlib.sha256, + ).digest() + return base64.b64encode(dig).decode() + + def authenticate_and_get_token( + self, + username: str, + password: str, + user_pool_id: str, + app_client_id: str, + app_client_secret: str, + ) -> Dict: + """Authenticates the credentials and returns token""" + client = boto3.client("cognito-idp") + if app_client_secret: + auth_params = { + "USERNAME": username, + "PASSWORD": password, + "SECRET_HASH": self._get_secret_hash( + username, app_client_id, app_client_secret + ), + } + else: + auth_params = { + "USERNAME": username, + "PASSWORD": password, + } + try: + resp = client.admin_initiate_auth( + UserPoolId=user_pool_id, + ClientId=app_client_id, + AuthFlow="ADMIN_USER_PASSWORD_AUTH", + AuthParameters=auth_params, + ) + except client.exceptions.NotAuthorizedException: + return { + "message": "Login failed, please make sure the credentials are correct." + } + except Exception as e: + return {"message": f"Login failed with exception {e}"} + return resp["AuthenticationResult"] diff --git a/docker-compose.yml b/docker-compose.yml index a126d76c..5c513a03 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -128,6 +128,7 @@ services: - PGPASSWORD=password - PGDATABASE=postgis - DYNAMODB_ENDPOINT=http://localhost:8085 + - VEDA_DB_PGSTAC_VERSION=0.7.10 ports: - "8083:8083" command: bash -c "bash /tmp/scripts/wait-for-it.sh -t 120 -h database -p 5432 && python /asset/local.py" diff --git a/ingest_api/infrastructure/config.py b/ingest_api/infrastructure/config.py index 13651b6b..771d346d 100644 --- a/ingest_api/infrastructure/config.py +++ b/ingest_api/infrastructure/config.py @@ -67,6 +67,10 @@ class IngestorConfig(BaseSettings): ingest_root_path: str = Field("", description="Root path for ingest API") custom_host: Optional[str] = Field(description="Custom host name") + db_pgstac_version: str = Field( + ..., + description="Version of PgStac database, i.e. 0.5", + ) class Config: case_sensitive = False diff --git a/ingest_api/infrastructure/construct.py b/ingest_api/infrastructure/construct.py index 1e763d09..c4322647 100644 --- a/ingest_api/infrastructure/construct.py +++ b/ingest_api/infrastructure/construct.py @@ -66,6 +66,7 @@ def __init__( "db_secret": db_secret, "db_vpc": db_vpc, "db_security_group": db_security_group, + "pgstac_version": config.db_pgstac_version, } if config.raster_data_access_role_arn: @@ -98,21 +99,15 @@ def __init__( custom_host=config.custom_host, ) - # CfnOutput(self, "ingest-api", value=self.api.url) stack_name = Stack.of(self).stack_name CfnOutput( self, "stac-ingestor-api-url", export_name=f"{stack_name}-stac-ingestor-api-url", value=self.api.url, + key="ingestapiurl", ) - register_ssm_parameter( - self, - name="jwks_url", - value=self.jwks_url, - description="JWKS URL for Cognito user pool", - ) register_ssm_parameter( self, name="dynamodb_table", @@ -130,6 +125,7 @@ def build_api_lambda( db_vpc: ec2.IVpc, db_security_group: ec2.ISecurityGroup, data_access_role: Union[iam.IRole, None] = None, + pgstac_version: str, code_dir: str = "./", ) -> apigateway.LambdaRestApi: stack_name = Stack.of(self).stack_name @@ -156,6 +152,7 @@ def build_api_lambda( path=os.path.abspath(code_dir), file="ingest_api/runtime/Dockerfile", platform="linux/amd64", + build_args={"PGSTAC_VERSION": pgstac_version}, ), runtime=aws_lambda.Runtime.PYTHON_3_9, timeout=Duration.seconds(30), @@ -297,6 +294,7 @@ def __init__( db_vpc=db_vpc, db_security_group=db_security_group, db_vpc_subnets=db_vpc_subnets, + pgstac_version=config.db_pgstac_version, ) def build_ingestor( @@ -308,6 +306,7 @@ def build_ingestor( db_vpc: ec2.IVpc, db_security_group: ec2.ISecurityGroup, db_vpc_subnets: ec2.SubnetSelection, + pgstac_version: str, code_dir: str = "./", ) -> aws_lambda.Function: handler = aws_lambda.Function( @@ -317,6 +316,7 @@ def build_ingestor( path=os.path.abspath(code_dir), file="ingest_api/runtime/Dockerfile", platform="linux/amd64", + build_args={"PGSTAC_VERSION": pgstac_version}, ), handler="ingestor.handler", runtime=aws_lambda.Runtime.PYTHON_3_9, diff --git a/ingest_api/runtime/Dockerfile b/ingest_api/runtime/Dockerfile index c2955171..01bc2f6a 100644 --- a/ingest_api/runtime/Dockerfile +++ b/ingest_api/runtime/Dockerfile @@ -1,9 +1,16 @@ FROM public.ecr.aws/sam/build-python3.9:latest +ARG PGSTAC_VERSION +RUN echo "Using PGSTAC Version ${PGSTAC_VERSION}" + WORKDIR /tmp +COPY common/auth /tmp/common/auth +RUN pip install /tmp/common/auth -t /asset +RUN rm -rf /tmp/common + COPY ingest_api/runtime/requirements.txt /tmp/ingestor/requirements.txt -RUN pip install -r /tmp/ingestor/requirements.txt -t /asset --no-binary pydantic uvicorn +RUN pip install -r /tmp/ingestor/requirements.txt pypgstac==${PGSTAC_VERSION} -t /asset --no-binary pydantic uvicorn RUN rm -rf /tmp/ingestor # TODO this is temporary until we use a real packaging system like setup.py or poetry COPY ingest_api/runtime/src /asset/src diff --git a/ingest_api/runtime/requirements.txt b/ingest_api/runtime/requirements.txt index de984f61..c773522e 100644 --- a/ingest_api/runtime/requirements.txt +++ b/ingest_api/runtime/requirements.txt @@ -1,15 +1,13 @@ # Waiting for https://github.com/stac-utils/stac-pydantic/pull/116 and 117 cryptography>=42.0.5 ddbcereal==2.1.1 -fastapi<=0.108.0 +fastapi>=0.109.1 fsspec==2023.3.0 mangum>=0.15.0 orjson>=3.6.8 psycopg[binary,pool]>=3.0.15 pydantic_ssm_settings>=0.2.0 pydantic>=1.10.12 -pyjwt>=2.8.0 -pypgstac==0.7.4 python-multipart==0.0.7 requests>=2.27.1 s3fs==2023.3.0 diff --git a/ingest_api/runtime/src/auth.py b/ingest_api/runtime/src/auth.py deleted file mode 100644 index 294bec64..00000000 --- a/ingest_api/runtime/src/auth.py +++ /dev/null @@ -1,106 +0,0 @@ -import base64 -import hashlib -import hmac -import logging -from typing import Annotated, Any, Dict - -import boto3 -import jwt -from src.config import settings - -from fastapi import Depends, HTTPException, Security, security, status - -logger = logging.getLogger(__name__) - -oauth2_scheme = security.OAuth2AuthorizationCodeBearer( - authorizationUrl=settings.cognito_authorization_url, - tokenUrl=settings.cognito_token_url, - refreshUrl=settings.cognito_token_url, -) - -jwks_client = jwt.PyJWKClient(settings.jwks_url) # Caches JWKS - - -def validated_token( - token_str: Annotated[str, Security(oauth2_scheme)], - required_scopes: security.SecurityScopes, -) -> Dict: - # Parse & validate token - try: - token = jwt.decode( - token_str, - jwks_client.get_signing_key_from_jwt(token_str).key, - algorithms=["RS256"], - ) - except jwt.exceptions.InvalidTokenError as e: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) from e - - # Validate scopes (if required) - for scope in required_scopes.scopes: - if scope not in token["scope"]: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Not enough permissions", - headers={ - "WWW-Authenticate": f'Bearer scope="{required_scopes.scope_str}"' - }, - ) - - return token - - -def get_username(token: Annotated[Dict[Any, Any], Depends(validated_token)]) -> str: - result = token["username"] if "username" in token else str(token.get("sub")) - return result - - -def _get_secret_hash(username: str, client_id: str, client_secret: str) -> str: - # A keyed-hash message authentication code (HMAC) calculated using - # the secret key of a user pool client and username plus the client - # ID in the message. - message = username + client_id - dig = hmac.new( - bytearray(client_secret, "utf-8"), - msg=message.encode("UTF-8"), - digestmod=hashlib.sha256, - ).digest() - return base64.b64encode(dig).decode() - - -def authenticate_and_get_token( - username: str, - password: str, - user_pool_id: str, - app_client_id: str, - app_client_secret: str, -) -> Dict: - client = boto3.client("cognito-idp") - if app_client_secret: - auth_params = { - "USERNAME": username, - "PASSWORD": password, - "SECRET_HASH": _get_secret_hash(username, app_client_id, app_client_secret), - } - else: - auth_params = { - "USERNAME": username, - "PASSWORD": password, - } - try: - resp = client.admin_initiate_auth( - UserPoolId=user_pool_id, - ClientId=app_client_id, - AuthFlow="ADMIN_USER_PASSWORD_AUTH", - AuthParameters=auth_params, - ) - except client.exceptions.NotAuthorizedException: - return { - "message": "Login failed, please make sure the credentials are correct." - } - except Exception as e: - return {"message": f"Login failed with exception {e}"} - return resp["AuthenticationResult"] diff --git a/ingest_api/runtime/src/config.py b/ingest_api/runtime/src/config.py index ff51f680..0b9e1c0b 100644 --- a/ingest_api/runtime/src/config.py +++ b/ingest_api/runtime/src/config.py @@ -4,6 +4,7 @@ from pydantic import AnyHttpUrl, BaseSettings, Field, constr from pydantic_ssm_settings import AwsSsmSourceConfig +from veda_auth import VedaAuth AwsArn = constr(regex=r"^arn:aws:iam::\d{12}:role/.+") @@ -63,3 +64,5 @@ def from_ssm(cls, stack: str): ), ) ) + +auth = VedaAuth(settings) diff --git a/ingest_api/runtime/src/dependencies.py b/ingest_api/runtime/src/dependencies.py index 901f338c..5c772c1f 100644 --- a/ingest_api/runtime/src/dependencies.py +++ b/ingest_api/runtime/src/dependencies.py @@ -1,9 +1,8 @@ import logging import boto3 -import src.auth as auth -import src.config as config import src.services as services +from src.config import auth, settings from fastapi import Depends, HTTPException, security @@ -14,7 +13,7 @@ def get_table(): client = boto3.resource("dynamodb") - return client.Table(config.settings.dynamodb_table) + return client.Table(settings.dynamodb_table) def get_db(table=Depends(get_table)) -> services.Database: diff --git a/ingest_api/runtime/src/main.py b/ingest_api/runtime/src/main.py index 9b5fa6d2..8074c860 100644 --- a/ingest_api/runtime/src/main.py +++ b/ingest_api/runtime/src/main.py @@ -1,12 +1,11 @@ from typing import Dict -import src.auth as auth import src.dependencies as dependencies import src.schemas as schemas import src.services as services from aws_lambda_powertools.metrics import MetricUnit from src.collection_publisher import CollectionPublisher, ItemPublisher -from src.config import settings +from src.config import auth, settings from src.doc import DESCRIPTION from src.monitoring import LoggerRouteHandler, logger, metrics, tracer diff --git a/ingest_api/runtime/tests/conftest.py b/ingest_api/runtime/tests/conftest.py index e24208d4..3c81c23b 100644 --- a/ingest_api/runtime/tests/conftest.py +++ b/ingest_api/runtime/tests/conftest.py @@ -27,7 +27,7 @@ def test_environ(): os.environ["RASTER_URL"] = "https://test-raster.url" os.environ["USERPOOL_ID"] = "fake_id" os.environ["STAGE"] = "testing" - os.environ["ROOT_PATH"] = "/" + os.environ["ROOT_PATH"] = "" os.environ["COGNITO_DOMAIN"] = "https://test-cognito.url" diff --git a/local/Dockerfile.ingest b/local/Dockerfile.ingest index bcce1482..ccc5741e 100644 --- a/local/Dockerfile.ingest +++ b/local/Dockerfile.ingest @@ -7,6 +7,8 @@ RUN pip install -r /tmp/ingestor/requirements.txt --no-binary pydantic uvicorn RUN rm -rf /tmp/ingestor # TODO this is temporary until we use a real packaging system like setup.py or poetry COPY ingest_api/runtime/src /asset/src +COPY common/auth /tmp/common/auth +RUN pip install /tmp/common/auth # # Reduce package size and remove useless files RUN cd /asset && find . -type f -name '*.pyc' | while read f; do n=$(echo $f | sed 's/__pycache__\///' | sed 's/.cpython-[2-3][0-9]//'); cp $f $n; done; diff --git a/local/Dockerfile.stac b/local/Dockerfile.stac index 4a3b1af1..71463001 100644 --- a/local/Dockerfile.stac +++ b/local/Dockerfile.stac @@ -4,12 +4,14 @@ FROM ghcr.io/vincentsarago/uvicorn-gunicorn:${PYTHON_VERSION} ENV CURL_CA_BUNDLE /etc/ssl/certs/ca-certificates.crt -RUN pip install boto3 - -COPY stac_api/runtime /tmp/stac # Installing boto3, which isn't needed in the lambda container instance # since lambda execution environment includes boto3 by default RUN pip install boto3 + +COPY stac_api/runtime /tmp/stac + +COPY common/auth /tmp/stac/common/auth +RUN pip install /tmp/stac/common/auth RUN pip install /tmp/stac RUN rm -rf /tmp/stac diff --git a/raster_api/infrastructure/construct.py b/raster_api/infrastructure/construct.py index 95c71265..c7362148 100644 --- a/raster_api/infrastructure/construct.py +++ b/raster_api/infrastructure/construct.py @@ -122,6 +122,7 @@ def __init__( "raster-api", value=self.raster_api.url, export_name=f"{stack_name}-raster-url", + key="rasterapiurl", ) CfnOutput(self, "raster-api-arn", value=veda_raster_function.function_arn) diff --git a/s3_website/infrastructure/construct.py b/s3_website/infrastructure/construct.py index 40d7c07f..3fffca3b 100644 --- a/s3_website/infrastructure/construct.py +++ b/s3_website/infrastructure/construct.py @@ -48,6 +48,7 @@ def __init__( CfnOutput( self, - "bucket-website", - value=f"https://{self.bucket.bucket_website_domain_name}", + "stac-browser-bucket-name", + value=self.bucket.bucket_name, + key="stacbrowserbucketname", ) diff --git a/stac_api/infrastructure/config.py b/stac_api/infrastructure/config.py index 65c04eda..2196a430 100644 --- a/stac_api/infrastructure/config.py +++ b/stac_api/infrastructure/config.py @@ -2,7 +2,7 @@ from typing import Dict, Optional -from pydantic import BaseSettings, Field +from pydantic import AnyHttpUrl, BaseSettings, Field, root_validator class vedaSTACSettings(BaseSettings): @@ -44,6 +44,37 @@ class vedaSTACSettings(BaseSettings): description="Description of the STAC Catalog", ) + userpool_id: Optional[str] = Field( + description="The Cognito Userpool used for authentication" + ) + cognito_domain: Optional[AnyHttpUrl] = Field( + description="The base url of the Cognito domain for authorization and token urls" + ) + client_id: Optional[str] = Field(description="The Cognito APP client ID") + client_secret: Optional[str] = Field( + "", description="The Cognito APP client secret" + ) + stac_enable_transactions: bool = Field( + False, description="Whether to enable transactions endpoints" + ) + + @root_validator + def check_transaction_fields(cls, values): + """ + Validates the existence of auth env vars in case stac_enable_transactions is True + """ + if values.get("stac_enable_transactions"): + missing_fields = [ + field + for field in ["userpool_id", "cognito_domain", "client_id"] + if not values.get(field) + ] + if missing_fields: + raise ValueError( + f"When 'stac_enable_transactions' is True, the following fields must be provided: {', '.join(missing_fields)}" + ) + return values + class Config: """model config""" diff --git a/stac_api/infrastructure/construct.py b/stac_api/infrastructure/construct.py index 272cddca..f4848a41 100644 --- a/stac_api/infrastructure/construct.py +++ b/stac_api/infrastructure/construct.py @@ -43,6 +43,22 @@ def __init__( # TODO config stack_name = Stack.of(self).stack_name + lambda_env = { + "VEDA_STAC_PROJECT_NAME": veda_stac_settings.project_name, + "VEDA_STAC_PROJECT_DESCRIPTION": veda_stac_settings.project_description, + "VEDA_STAC_ROOT_PATH": veda_stac_settings.stac_root_path, + "VEDA_STAC_STAGE": stage, + "VEDA_STAC_USERPOOL_ID": veda_stac_settings.userpool_id, + "VEDA_STAC_CLIENT_ID": veda_stac_settings.client_id, + "VEDA_STAC_COGNITO_DOMAIN": veda_stac_settings.cognito_domain, + "VEDA_STAC_ENABLE_TRANSACTIONS": str( + veda_stac_settings.stac_enable_transactions + ), + "DB_MIN_CONN_SIZE": "0", + "DB_MAX_CONN_SIZE": "1", + **{k.upper(): v for k, v in veda_stac_settings.env.items()}, + } + lambda_function = aws_lambda.Function( self, "lambda", @@ -56,15 +72,7 @@ def __init__( allow_public_subnet=True, memory_size=veda_stac_settings.memory, timeout=Duration.seconds(veda_stac_settings.timeout), - environment={ - **{k.upper(): v for k, v in veda_stac_settings.env.items()}, - "DB_MIN_CONN_SIZE": "0", - "DB_MAX_CONN_SIZE": "1", - "VEDA_STAC_ROOT_PATH": veda_stac_settings.stac_root_path, - "VEDA_STAC_STAGE": stage, - "VEDA_STAC_PROJECT_NAME": veda_stac_settings.project_name, - "VEDA_STAC_PROJECT_DESCRIPTION": veda_stac_settings.project_description, - }, + environment=lambda_env, log_retention=aws_logs.RetentionDays.ONE_WEEK, tracing=aws_lambda.Tracing.ACTIVE, ) @@ -120,4 +128,5 @@ def __init__( "stac-api", value=self.stac_api.url, export_name=f"{stack_name}-stac-url", + key="stacapiurl", ) diff --git a/stac_api/runtime/Dockerfile b/stac_api/runtime/Dockerfile index 5dcae19c..c8a8c852 100644 --- a/stac_api/runtime/Dockerfile +++ b/stac_api/runtime/Dockerfile @@ -3,7 +3,10 @@ FROM --platform=linux/amd64 public.ecr.aws/sam/build-python3.9:latest WORKDIR /tmp COPY stac_api/runtime /tmp/stac + RUN pip install "mangum>=0.14,<0.15" "plpygis>=0.2.1" /tmp/stac -t /asset --no-binary pydantic +COPY common/auth /tmp/stac/common/auth +RUN pip install /tmp/stac/common/auth -t /asset RUN rm -rf /tmp/stac # Reduce package size and remove useless files diff --git a/stac_api/runtime/handler.py b/stac_api/runtime/handler.py index b04844cf..176be13b 100644 --- a/stac_api/runtime/handler.py +++ b/stac_api/runtime/handler.py @@ -4,11 +4,8 @@ from mangum import Mangum from src.app import app -from src.config import ApiSettings from src.monitoring import logger, metrics, tracer -settings = ApiSettings() - logging.getLogger("mangum.lifespan").setLevel(logging.ERROR) logging.getLogger("mangum.http").setLevel(logging.ERROR) diff --git a/stac_api/runtime/setup.py b/stac_api/runtime/setup.py index 6dfa0ff0..604e25dd 100644 --- a/stac_api/runtime/setup.py +++ b/stac_api/runtime/setup.py @@ -18,6 +18,7 @@ "pygeoif<=0.8", # newest release (1.0+ / 09-22-2022) breaks a number of other geo libs "aws-lambda-powertools>=1.18.0", "aws_xray_sdk>=2.6.0,<3", + "pystac[validation]==1.10.1", ] extra_reqs = { diff --git a/stac_api/runtime/src/app.py b/stac_api/runtime/src/app.py index 0e30fa82..61dac75c 100644 --- a/stac_api/runtime/src/app.py +++ b/stac_api/runtime/src/app.py @@ -1,14 +1,16 @@ """FastAPI application using PGStac. Based on https://github.com/developmentseed/eoAPI/tree/master/src/eoapi/stac """ + from aws_lambda_powertools.metrics import MetricUnit -from src.config import ApiSettings, TilesApiSettings +from src.config import TilesApiSettings, api_settings from src.config import extensions as PgStacExtensions from src.config import get_request_model as GETModel from src.config import post_request_model as POSTModel from src.extension import TiTilerExtension from fastapi import APIRouter, FastAPI +from fastapi.params import Depends from fastapi.responses import ORJSONResponse from stac_fastapi.pgstac.db import close_db_connection, connect_to_db from starlette.middleware.cors import CORSMiddleware @@ -20,6 +22,8 @@ from .api import VedaStacApi from .core import VedaCrudClient from .monitoring import LoggerRouteHandler, logger, metrics, tracer +from .routes import add_route_dependencies +from .validation import ValidationMiddleware try: from importlib.resources import files as resources_files # type: ignore @@ -30,7 +34,6 @@ templates = Jinja2Templates(directory=str(resources_files(__package__) / "templates")) # type: ignore -api_settings = ApiSettings() tiles_settings = TilesApiSettings() api = VedaStacApi( @@ -39,6 +42,15 @@ openapi_url="/openapi.json", docs_url="/docs", root_path=api_settings.root_path, + swagger_ui_init_oauth=( + { + "appName": "Cognito", + "clientId": api_settings.client_id, + "usePkceWithAuthorizationCodeGrant": True, + } + if api_settings.client_id + else {} + ), ), title=f"{api_settings.project_name} STAC API", description=api_settings.project_description, @@ -48,7 +60,7 @@ search_get_request_model=GETModel, search_post_request_model=POSTModel, response_class=ORJSONResponse, - middlewares=[CompressionMiddleware], + middlewares=[CompressionMiddleware, ValidationMiddleware], router=APIRouter(route_class=LoggerRouteHandler), ) app = api.app @@ -62,10 +74,45 @@ CORSMiddleware, allow_origins=api_settings.cors_origins, allow_credentials=True, - allow_methods=["GET", "POST", "OPTIONS"], + allow_methods=["GET", "POST", "PUT", "OPTIONS"], allow_headers=["*"], ) +if api_settings.enable_transactions: + from veda_auth import VedaAuth + + auth = VedaAuth(api_settings) + # Require auth for all endpoints that create, modify or delete data. + add_route_dependencies( + app.router.routes, + [ + {"path": "/collections", "method": "POST", "type": "http"}, + {"path": "/collections/{collectionId}", "method": "PUT", "type": "http"}, + {"path": "/collections/{collectionId}", "method": "DELETE", "type": "http"}, + { + "path": "/collections/{collectionId}/items", + "method": "POST", + "type": "http", + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "PUT", + "type": "http", + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "DELETE", + "type": "http", + }, + { + "path": "/collections/{collectionId}/bulk_items", + "method": "POST", + "type": "http", + }, + ], + [Depends(auth.validated_token)], + ) + if tiles_settings.titiler_endpoint: # Register to the TiTiler extension to the api extension = TiTilerExtension() diff --git a/stac_api/runtime/src/config.py b/stac_api/runtime/src/config.py index 0a70e13f..78eccf94 100644 --- a/stac_api/runtime/src/config.py +++ b/stac_api/runtime/src/config.py @@ -1,13 +1,15 @@ """API settings. Based on https://github.com/developmentseed/eoAPI/tree/master/src/eoapi/stac""" + import base64 import json from functools import lru_cache from typing import Optional import boto3 -import pydantic +from pydantic import AnyHttpUrl, BaseSettings, Field, root_validator, validator +from fastapi.responses import ORJSONResponse from stac_fastapi.api.models import create_get_request_model, create_post_request_model # from stac_fastapi.pgstac.extensions import QueryExtension @@ -18,8 +20,11 @@ QueryExtension, SortExtension, TokenPaginationExtension, + TransactionExtension, ) +from stac_fastapi.extensions.third_party import BulkTransactionExtension from stac_fastapi.pgstac.config import Settings +from stac_fastapi.pgstac.transactions import BulkTransactionsClient, TransactionsClient from stac_fastapi.pgstac.types.search import PgstacSearch @@ -47,7 +52,7 @@ def get_secret_dict(secret_name: str): return json.loads(base64.b64decode(get_secret_value_response["SecretBinary"])) -class _ApiSettings(pydantic.BaseSettings): +class _ApiSettings(BaseSettings): """API settings""" project_name: Optional[str] = "veda" @@ -59,7 +64,54 @@ class _ApiSettings(pydantic.BaseSettings): pgstac_secret_arn: Optional[str] stage: Optional[str] = None - @pydantic.validator("cors_origins") + userpool_id: Optional[str] = Field( + "", description="The Cognito Userpool used for authentication" + ) + cognito_domain: Optional[AnyHttpUrl] = Field( + description="The base url of the Cognito domain for authorization and token urls" + ) + client_id: Optional[str] = Field(description="The Cognito APP client ID") + client_secret: Optional[str] = Field( + "", description="The Cognito APP client secret" + ) + enable_transactions: bool = Field( + False, description="Whether to enable transactions" + ) + + @root_validator + def check_transaction_fields(cls, values): + enable_transactions = values.get("enable_transactions") + + if enable_transactions: + missing_fields = [ + field + for field in ["userpool_id", "cognito_domain", "client_id"] + if not values.get(field) + ] + if missing_fields: + raise ValueError( + f"When 'enable_transactions' is True, the following fields must be provided: {', '.join(missing_fields)}" + ) + return values + + @property + def jwks_url(self) -> AnyHttpUrl: + """JWKS url""" + if self.userpool_id: + region = self.userpool_id.split("_")[0] + return f"https://cognito-idp.{region}.amazonaws.com/{self.userpool_id}/.well-known/jwks.json" + + @property + def cognito_authorization_url(self) -> AnyHttpUrl: + """Cognito user pool authorization url""" + return f"{self.cognito_domain}/oauth2/authorize" + + @property + def cognito_token_url(self) -> AnyHttpUrl: + """Cognito user pool token and refresh url""" + return f"{self.cognito_domain}/oauth2/token" + + @validator("cors_origins") def parse_cors_origin(cls, v): """Parse CORS origins.""" return [origin.strip() for origin in v.split(",")] @@ -101,7 +153,10 @@ def ApiSettings() -> _ApiSettings: return _ApiSettings() -class _TilesApiSettings(pydantic.BaseSettings): +api_settings = ApiSettings() + + +class _TilesApiSettings(BaseSettings): """Tile API settings""" titiler_endpoint: Optional[str] @@ -123,12 +178,24 @@ def TilesApiSettings() -> _TilesApiSettings: extensions = [ + ContextExtension(), + FieldsExtension(), FilterExtension(), QueryExtension(), SortExtension(), - FieldsExtension(), TokenPaginationExtension(), - ContextExtension(), ] + +if api_settings.enable_transactions: + extensions.extend( + [ + BulkTransactionExtension(client=BulkTransactionsClient()), + TransactionExtension( + client=TransactionsClient(), + settings=ApiSettings().load_postgres_settings(), + response_class=ORJSONResponse, + ), + ] + ) post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) get_request_model = create_get_request_model(extensions) diff --git a/stac_api/runtime/src/routes.py b/stac_api/runtime/src/routes.py new file mode 100644 index 00000000..ae0bae08 --- /dev/null +++ b/stac_api/runtime/src/routes.py @@ -0,0 +1,25 @@ +"""Dependency injection in to fastapi routes""" + +from typing import List + +from fastapi.dependencies.utils import get_parameterless_sub_dependant +from fastapi.params import Depends +from fastapi.routing import APIRoute +from starlette.routing import Match +from starlette.types import Scope + + +def add_route_dependencies( + routes: List[APIRoute], scopes: List[Scope], dependencies: List[Depends] +): + """Inject dependencies to routes""" + for route in routes: + if not any(route.matches(scope)[0] == Match.FULL for scope in scopes): + continue + + route.dependant.dependencies = [ + # Mimicking how APIRoute handles dependencies: + # https://github.com/tiangolo/fastapi/blob/1760da0efa55585c19835d81afa8ca386036c325/fastapi/routing.py#L408-L412 + get_parameterless_sub_dependant(depends=depends, path=route.path_format) + for depends in dependencies + ] + route.dependant.dependencies diff --git a/stac_api/runtime/src/validation.py b/stac_api/runtime/src/validation.py new file mode 100644 index 00000000..9f429e3c --- /dev/null +++ b/stac_api/runtime/src/validation.py @@ -0,0 +1,60 @@ +"""Middleware for validating transaction endpoints""" + +import json +import re +from typing import Dict + +from pydantic import BaseModel, Field +from pystac import STACObjectType +from pystac.errors import STACValidationError +from pystac.validation import validate_dict +from src.config import api_settings + +from fastapi import Request +from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware + +path_prefix = api_settings.root_path or "" + + +class BulkItems(BaseModel): + """Validation model for bulk-items endpoint request""" + + items: Dict[str, dict] + method: str = Field(default="insert") + + +class ValidationMiddleware(BaseHTTPMiddleware): + """Middleware that handles STAC collection and item validation in transaction endpoints""" + + async def dispatch(self, request: Request, call_next): + """Middleware dispatch""" + if request.method in ("POST", "PUT"): + try: + body = await request.body() + request_data = json.loads(body) + if re.match( + f"^{path_prefix}/collections(?:/[^/]+)?$", + request.url.path, + ): + validate_dict(request_data, STACObjectType.COLLECTION) + elif re.match( + f"^{path_prefix}/collections/[^/]+/items(?:/[^/]+)?$", + request.url.path, + ): + validate_dict(request_data, STACObjectType.ITEM) + elif re.match( + f"^{path_prefix}/collections/[^/]+/bulk-items$", + request.url.path, + ): + bulk_items = BulkItems(**request_data) + for item_data in bulk_items.items.values(): + validate_dict(item_data, STACObjectType.ITEM) + except STACValidationError as e: + return JSONResponse( + status_code=422, + content={"detail": "Validation Error", "errors": str(e)}, + ) + + response = await call_next(request) + return response diff --git a/stac_api/runtime/tests/__init__.py b/stac_api/runtime/tests/__init__.py new file mode 100644 index 00000000..2d9078d5 --- /dev/null +++ b/stac_api/runtime/tests/__init__.py @@ -0,0 +1 @@ +"""STAC API tests""" diff --git a/stac_api/runtime/tests/conftest.py b/stac_api/runtime/tests/conftest.py new file mode 100644 index 00000000..e584bde6 --- /dev/null +++ b/stac_api/runtime/tests/conftest.py @@ -0,0 +1,338 @@ +""" +Test fixtures and data for STAC Transactions API testing. + +This module contains fixtures and mock data used for testing the STAC API. +It includes valid and invalid STAC collections and items, as well as environment +setup for testing with mock AWS and PostgreSQL configurations. +""" + +import os + +import pytest + +from fastapi.testclient import TestClient + +VALID_COLLECTION = { + "id": "CMIP245-winter-median-pr", + "type": "Collection", + "title": "Projected changes to winter (January, February, and March) cumulative daily precipitation", + "links": [], + "description": "Differences in winter (January, February, and March) cumulative daily precipitation between a historical period (1995 - 2014) and multiple 20-year periods from an ensemble of CMIP6 climate projections (SSP2-4.5) downscaled by NASA Earth Exchange (NEX-GDDP-CMIP6)", + "extent": { + "spatial": {"bbox": [[-126, 30, -104, 51]]}, + "temporal": {"interval": [["2025-01-01T00:00:00Z", "2085-03-31T12:00:00Z"]]}, + }, + "license": "MIT", + "stac_extensions": [ + "https://stac-extensions.github.io/render/v1.0.0/schema.json", + "https://stac-extensions.github.io/item-assets/v1.0.0/schema.json", + ], + "item_assets": { + "cog_default": { + "type": "image/tiff; application=geotiff; profile=cloud-optimized", + "roles": ["data", "layer"], + "title": "Default COG Layer", + "description": "Cloud optimized default layer to display on map", + } + }, + "dashboard:is_periodic": False, + "dashboard:time_density": "year", + "stac_version": "1.0.0", + "renders": { + "dashboard": { + "resampling": "bilinear", + "bidx": [1], + "nodata": "nan", + "colormap_name": "rdbu", + "rescale": [[-60, 60]], + "assets": ["cog_default"], + "title": "VEDA Dashboard Render Parameters", + } + }, + "providers": [ + { + "name": "NASA Center for Climate Simulation (NCCS)", + "url": "https://www.nccs.nasa.gov/services/data-collections/land-based-products/nex-gddp-cmip6", + "roles": ["producer", "processor", "licensor"], + }, + { + "name": "NASA VEDA", + "url": "https://www.earthdata.nasa.gov/dashboard/", + "roles": ["host"], + }, + ], + "assets": { + "thumbnail": { + "title": "Thumbnail", + "description": "Photo by Justin Pflug (Photo of Nisqually glacier)", + "href": "https://thumbnails.openveda.cloud/CMIP-winter-median.jpeg", + "type": "image/jpeg", + "roles": ["thumbnail"], + } + }, +} + +VALID_ITEM = { + "id": "OMI_trno2_0.10x0.10_2023_Col3_V4", + "bbox": [-180.0, -90.0, 180.0, 90.0], + "type": "Feature", + "links": [ + { + "rel": "collection", + "type": "application/json", + "href": "https://dev.openveda.cloud/api/stac/collections/CMIP245-winter-median-pr", + }, + { + "rel": "parent", + "type": "application/json", + "href": "https://dev.openveda.cloud/api/stac/collections/CMIP245-winter-median-pr", + }, + { + "rel": "root", + "type": "application/json", + "href": "https://dev.openveda.cloud/api/stac/", + }, + { + "rel": "self", + "type": "application/geo+json", + "href": "https://dev.openveda.cloud/api/stac/collections/CMIP245-winter-median-pr/items/OMI_trno2_0.10x0.10_2023_Col3_V4", + }, + { + "title": "Map of Item", + "href": "https://dev.openveda.cloud/api/raster/stac/map?collection=CMIP245-winter-median-pr&item=OMI_trno2_0.10x0.10_2023_Col3_V4&assets=cog_default&rescale=0%2C3000000000000000&colormap_name=reds", + "rel": "preview", + "type": "text/html", + }, + ], + "assets": { + "no2": { + "href": "s3://veda-data-store-staging/OMI_trno2-COG/OMI_trno2_0.10x0.10_2023_Col3_V4.tif", + "type": "image/tiff; application=geotiff", + "roles": ["data", "layer"], + "title": "NO2 values", + "proj:bbox": [-180.0, -90.0, 180.0, 90.0], + "proj:epsg": 4326, + "proj:wkt2": 'GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AXIS["Latitude",NORTH],AXIS["Longitude",EAST],AUTHORITY["EPSG","4326"]]', + "proj:shape": [1800, 3600], + "description": "description", + "raster:bands": [ + { + "scale": 1.0, + "nodata": -1.2676506002282294e30, + "offset": 0.0, + "sampling": "area", + "data_type": "float32", + "histogram": { + "max": 14863169193246720, + "min": -2293753591103488.0, + "count": 11, + "buckets": [57, 484234, 23295, 2552, 694, 318, 230, 79, 42, 12], + }, + "statistics": { + "mean": 365095923477877.9, + "stddev": 569167954388057.0, + "maximum": 14863169193246720, + "minimum": -2293753591103488.0, + "valid_percent": 97.56336212158203, + }, + } + ], + "proj:geometry": { + "type": "Polygon", + "coordinates": [ + [ + [-180.0, -90.0], + [180.0, -90.0], + [180.0, 90.0], + [-180.0, 90.0], + [-180.0, -90.0], + ] + ], + }, + "proj:projjson": { + "id": {"code": 4326, "authority": "EPSG"}, + "name": "WGS 84", + "type": "GeographicCRS", + "datum": { + "name": "World Geodetic System 1984", + "type": "GeodeticReferenceFrame", + "ellipsoid": { + "name": "WGS 84", + "semi_major_axis": 6378137, + "inverse_flattening": 298.257223563, + }, + }, + "$schema": "https://proj.org/schemas/v0.7/projjson.schema.json", + "coordinate_system": { + "axis": [ + { + "name": "Geodetic latitude", + "unit": "degree", + "direction": "north", + "abbreviation": "Lat", + }, + { + "name": "Geodetic longitude", + "unit": "degree", + "direction": "east", + "abbreviation": "Lon", + }, + ], + "subtype": "ellipsoidal", + }, + }, + "proj:transform": [0.1, 0.0, -180.0, 0.0, -0.1, 90.0, 0.0, 0.0, 1.0], + }, + "rendered_preview": { + "title": "Rendered preview", + "href": "https://dev.openveda.cloud/api/raster/stac/preview.png?collection=CMIP245-winter-median-pr&item=OMI_trno2_0.10x0.10_2023_Col3_V4&assets=cog_default&rescale=0%2C3000000000000000&colormap_name=reds", + "rel": "preview", + "roles": ["overview"], + "type": "image/png", + }, + }, + "geometry": { + "type": "Polygon", + "coordinates": [[[-180, -90], [180, -90], [180, 90], [-180, 90], [-180, -90]]], + }, + "collection": "CMIP245-winter-median-pr", + "properties": { + "end_datetime": "2023-12-31T00:00:00+00:00", + "start_datetime": "2023-01-01T00:00:00+00:00", + "datetime": None, + }, + "stac_version": "1.0.0", + "stac_extensions": [ + "https://stac-extensions.github.io/raster/v1.1.0/schema.json", + "https://stac-extensions.github.io/projection/v1.1.0/schema.json", + ], +} + + +@pytest.fixture +def test_environ(): + """ + Set up the test environment with mocked AWS and PostgreSQL credentials. + + This fixture sets environment variables to mock AWS credentials and + PostgreSQL database configuration for testing purposes. + """ + # Mocked AWS Credentials for moto (best practice recommendation from moto) + os.environ["AWS_ACCESS_KEY_ID"] = "testing" + os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" + os.environ["AWS_SECURITY_TOKEN"] = "testing" + os.environ["AWS_SESSION_TOKEN"] = "testing" + os.environ["AWS_REGION"] = "us-west-2" + os.environ["VEDA_STAC_USERPOOL_ID"] = "us-west-2_FAKEUSERPOOL" + os.environ["VEDA_STAC_CLIENT_ID"] = "Xdjkfghadsfkdsadfjas" + os.environ["VEDA_STAC_CLIENT_SECRET"] = "dsakfjdsalfkjadslfjalksfj" + os.environ[ + "VEDA_STAC_COGNITO_DOMAIN" + ] = "https://fake.auth.us-west-2.amazoncognito.com" + os.environ["VEDA_STAC_ENABLE_TRANSACTIONS"] = "TRUE" + + # Config mocks + os.environ["POSTGRES_USER"] = "username" + os.environ["POSTGRES_PASS"] = "password" + os.environ["POSTGRES_DBNAME"] = "postgis" + os.environ["POSTGRES_HOST_READER"] = "database" + os.environ["POSTGRES_HOST_WRITER"] = "database" + os.environ["POSTGRES_PORT"] = "5432" + + +def override_validated_token(): + """ + Mock function to override validated token dependency. + + Returns: + str: A fake token to bypass authorization in tests. + """ + return "fake_token" + + +@pytest.fixture +def app(test_environ): + """ + Fixture to initialize the FastAPI application. + + This fixture imports and returns the FastAPI application instance + for testing purposes. + + Args: + test_environ: A fixture setting up the test environment. + + Returns: + FastAPI: The FastAPI application instance. + """ + from src.app import app + + return app + + +@pytest.fixture +def api_client(app): + """ + Fixture to initialize the API client for making requests. + + This fixture creates a TestClient instance for interacting with the + FastAPI application, and sets up dependency overrides for testing. + + Args: + app: A fixture providing the FastAPI application instance. + + Yields: + TestClient: The TestClient instance for API testing. + """ + from src.app import auth + + app.dependency_overrides[auth.validated_token] = override_validated_token + yield TestClient(app) + app.dependency_overrides.clear() + + +@pytest.fixture +def valid_stac_collection(): + """ + Fixture providing a valid STAC collection for testing. + + Returns: + dict: A valid STAC collection. + """ + return VALID_COLLECTION + + +@pytest.fixture +def invalid_stac_collection(): + """ + Fixture providing an invalid STAC collection for testing. + + Returns: + dict: An invalid STAC collection with the 'extent' field removed. + """ + invalid = VALID_COLLECTION.copy() + invalid.pop("extent") + return invalid + + +@pytest.fixture +def valid_stac_item(): + """ + Fixture providing a valid STAC item for testing. + + Returns: + dict: A valid STAC item. + """ + return VALID_ITEM + + +@pytest.fixture +def invalid_stac_item(): + """ + Fixture providing an invalid STAC item for testing. + + Returns: + dict: An invalid STAC item with the 'properties' field removed. + """ + invalid_item = VALID_ITEM.copy() + invalid_item.pop("properties") + return invalid_item diff --git a/stac_api/runtime/tests/test_transactions.py b/stac_api/runtime/tests/test_transactions.py new file mode 100644 index 00000000..6a5cd8e3 --- /dev/null +++ b/stac_api/runtime/tests/test_transactions.py @@ -0,0 +1,137 @@ +""" +Test suite for STAC (SpatioTemporal Asset Catalog) Transactions API endpoints. + +This module contains tests for the collection and item endpoints of the STAC API. +It verifies the behavior of the API when posting valid and invalid STAC collections and items, +as well as bulk items. + +Endpoints tested: +- /collections +- /collections/{}/items +- /collections/{}/bulk_items +""" + +import pytest + +collections_endpoint = "/collections" +items_endpoint = "/collections/{}/items" +bulk_endpoint = "/collections/{}/bulk_items" + + +class TestList: + """ + Test cases for STAC API's collection and item endpoints. + + This class contains tests to ensure that the STAC API correctly handles + posting valid and invalid STAC collections and items, both individually + and in bulk. It uses pytest fixtures to set up the test environment with + necessary data. + """ + + @pytest.fixture(autouse=True) + def setup( + self, + api_client, + valid_stac_collection, + valid_stac_item, + invalid_stac_collection, + invalid_stac_item, + ): + """ + Set up the test environment with the required fixtures. + + Args: + api_client: The API client for making requests. + valid_stac_collection: A valid STAC collection for testing. + valid_stac_item: A valid STAC item for testing. + invalid_stac_collection: An invalid STAC collection for testing. + invalid_stac_item: An invalid STAC item for testing. + """ + self.api_client = api_client + self.valid_stac_collection = valid_stac_collection + self.valid_stac_item = valid_stac_item + self.invalid_stac_collection = invalid_stac_collection + self.invalid_stac_item = invalid_stac_item + + def test_post_invalid_collection(self): + """ + Test the API's response to posting an invalid STAC collection. + + Asserts that the response status code is 422 and the detail + is "Validation Error". + """ + response = self.api_client.post( + collections_endpoint, json=self.invalid_stac_collection + ) + assert response.json()["detail"] == "Validation Error" + assert response.status_code == 422 + + def test_post_valid_collection(self): + """ + Test the API's response to posting a valid STAC collection. + + Asserts that the response status code is 200. + """ + response = self.api_client.post( + collections_endpoint, json=self.valid_stac_collection + ) + # assert response.json() == {} + assert response.status_code == 200 + + def test_post_invalid_item(self): + """ + Test the API's response to posting an invalid STAC item. + + Asserts that the response status code is 422 and the detail + is "Validation Error". + """ + response = self.api_client.post( + items_endpoint.format(self.invalid_stac_item["collection"]), + json=self.invalid_stac_item, + ) + assert response.json()["detail"] == "Validation Error" + assert response.status_code == 422 + + def test_post_valid_item(self): + """ + Test the API's response to posting a valid STAC item. + + Asserts that the response status code is 200. + """ + response = self.api_client.post( + items_endpoint.format(self.valid_stac_item["collection"]), + json=self.valid_stac_item, + ) + # assert response.json() == {} + assert response.status_code == 200 + + def test_post_invalid_bulk_items(self): + """ + Test the API's response to posting invalid bulk STAC items. + + Asserts that the response status code is 422. + """ + item_id = self.invalid_stac_item["id"] + collection_id = self.invalid_stac_item["collection"] + invalid_request = { + "items": {item_id: self.invalid_stac_item}, + "method": "upsert", + } + response = self.api_client.post( + bulk_endpoint.format(collection_id), json=invalid_request + ) + assert response.status_code == 422 + + def test_post_valid_bulk_items(self): + """ + Test the API's response to posting valid bulk STAC items. + + Asserts that the response status code is 200. + """ + item_id = self.valid_stac_item["id"] + collection_id = self.valid_stac_item["collection"] + valid_request = {"items": {item_id: self.valid_stac_item}, "method": "upsert"} + response = self.api_client.post( + bulk_endpoint.format(collection_id), json=valid_request + ) + assert response.status_code == 200