Skip to content

Commit

Permalink
read postgres settings from PGSTAC_SECRET_ARN
Browse files Browse the repository at this point in the history
  • Loading branch information
hrodmn committed Feb 4, 2025
1 parent 8c16d16 commit 4607867
Show file tree
Hide file tree
Showing 13 changed files with 164 additions and 70 deletions.
10 changes: 5 additions & 5 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ services:
- POSTGRES_PORT=5432
- DB_MIN_CONN_SIZE=1
- DB_MAX_CONN_SIZE=1
# - EOAPI_STAC_TITILER_ENDPOINT=raster
- EOAPI_STAC_TITILER_ENDPOINT=http://127.0.0.1:8082
# - TITILER_ENDPOINT=raster
- TITILER_ENDPOINT=http://127.0.0.1:8082
# PgSTAC extensions
# - EOAPI_STAC_EXTENSIONS=["filter", "query", "sort", "fields", "pagination", "titiler", "transaction"] # defaults
# - EOAPI_STAC_CORS_METHODS='GET,POST,PUT,OPTIONS'
# - EXTENSIONS=["filter", "query", "sort", "fields", "pagination", "titiler", "transaction"] # defaults
# - CORS_METHODS='GET,POST,PUT,OPTIONS'
env_file:
- path: .env
required: false
Expand Down Expand Up @@ -123,7 +123,7 @@ services:
- POSTGRES_PORT=5432
- DB_MIN_CONN_SIZE=1
- DB_MAX_CONN_SIZE=10
- EOAPI_VECTOR_DEBUG=TRUE
- DEBUG=TRUE
env_file:
- path: .env
required: false
Expand Down
58 changes: 5 additions & 53 deletions infrastructure/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,23 +136,8 @@ def __init__(
self,
"raster-api",
api_env={
"EOAPI_RASTER_NAME": app_config.build_service_name("raster"),
"NAME": app_config.build_service_name("raster"),
"description": f"{app_config.stage} Raster API",
"POSTGRES_HOST": pgstac_db.pgstac_secret.secret_value_from_json(
"host"
).to_string(),
"POSTGRES_DBNAME": pgstac_db.pgstac_secret.secret_value_from_json(
"dbname"
).to_string(),
"POSTGRES_USER": pgstac_db.pgstac_secret.secret_value_from_json(
"username"
).to_string(),
"POSTGRES_PASS": pgstac_db.pgstac_secret.secret_value_from_json(
"password"
).to_string(),
"POSTGRES_PORT": pgstac_db.pgstac_secret.secret_value_from_json(
"port"
).to_string(),
},
db=pgstac_db.connection_target,
db_secret=pgstac_db.pgstac_secret,
Expand Down Expand Up @@ -199,28 +184,10 @@ def __init__(
self,
"stac-api",
api_env={
"EOAPI_STAC_NAME": app_config.build_service_name("stac"),
"NAME": app_config.build_service_name("stac"),
"description": f"{app_config.stage} STAC API",
"POSTGRES_HOST_READER": pgstac_db.pgstac_secret.secret_value_from_json(
"host"
).to_string(),
"POSTGRES_HOST_WRITER": pgstac_db.pgstac_secret.secret_value_from_json(
"host"
).to_string(),
"POSTGRES_DBNAME": pgstac_db.pgstac_secret.secret_value_from_json(
"dbname"
).to_string(),
"POSTGRES_USER": pgstac_db.pgstac_secret.secret_value_from_json(
"username"
).to_string(),
"POSTGRES_PASS": pgstac_db.pgstac_secret.secret_value_from_json(
"password"
).to_string(),
"POSTGRES_PORT": pgstac_db.pgstac_secret.secret_value_from_json(
"port"
).to_string(),
"EOAPI_STAC_TITILER_ENDPOINT": raster.url.strip("/"),
"EOAPI_STAC_EXTENSIONS": '["filter", "query", "sort", "fields", "pagination", "titiler"]',
"TITILER_ENDPOINT": raster.url.strip("/"),
"EXTENSIONS": '["filter", "query", "sort", "fields", "pagination", "titiler"]',
},
db=pgstac_db.connection_target,
db_secret=pgstac_db.pgstac_secret,
Expand Down Expand Up @@ -268,23 +235,8 @@ def __init__(
db=pgstac_db.connection_target,
db_secret=pgstac_db.pgstac_secret,
api_env={
"EOAPI_VECTOR_NAME": app_config.build_service_name("vector"),
"NAME": app_config.build_service_name("vector"),
"description": f"{app_config.stage} tipg API",
"POSTGRES_HOST": pgstac_db.pgstac_secret.secret_value_from_json(
"host"
).to_string(),
"POSTGRES_DBNAME": pgstac_db.pgstac_secret.secret_value_from_json(
"dbname"
).to_string(),
"POSTGRES_USER": pgstac_db.pgstac_secret.secret_value_from_json(
"username"
).to_string(),
"POSTGRES_PASS": pgstac_db.pgstac_secret.secret_value_from_json(
"password"
).to_string(),
"POSTGRES_PORT": pgstac_db.pgstac_secret.secret_value_from_json(
"port"
).to_string(),
},
# If the db is not in the public subnet then we need to put
# the lambda within the VPC
Expand Down
8 changes: 7 additions & 1 deletion infrastructure/handlers/raster_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,23 @@
import os

from eoapi.raster.app import app
from eoapi.raster.config import ApiSettings
from mangum import Mangum
from titiler.pgstac.db import connect_to_db

logging.getLogger("mangum.lifespan").setLevel(logging.ERROR)
logging.getLogger("mangum.http").setLevel(logging.ERROR)

settings = ApiSettings()


@app.on_event("startup")
async def startup_event() -> None:
"""Connect to database on startup."""
await connect_to_db(app)
await connect_to_db(
app,
settings=settings.load_postgres_settings(),
)


handler = Mangum(app, lifespan="off")
Expand Down
6 changes: 3 additions & 3 deletions infrastructure/handlers/vector_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
import os

from eoapi.vector.app import app
from eoapi.vector.config import ApiSettings
from mangum import Mangum
from tipg.collections import register_collection_catalog
from tipg.database import connect_to_db
from tipg.settings import PostgresSettings

logging.getLogger("mangum.lifespan").setLevel(logging.ERROR)
logging.getLogger("mangum.http").setLevel(logging.ERROR)

postgres_settings = PostgresSettings()
settings = ApiSettings()

try:
from importlib.resources import files as resources_files # type: ignore
Expand All @@ -31,7 +31,7 @@ async def startup_event() -> None:
"""Connect to database on startup."""
await connect_to_db(
app,
settings=postgres_settings,
settings=settings.load_postgres_settings(),
# We enable both pgstac and public schemas (pgstac will be used by custom functions)
schemas=["pgstac", "public"],
user_sql_files=sql_files,
Expand Down
2 changes: 1 addition & 1 deletion runtimes/eoapi/raster/eoapi/raster/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
async def lifespan(app: FastAPI):
"""FastAPI Lifespan."""
logger.debug("Connecting to db...")
await connect_to_db(app)
await connect_to_db(app, settings=settings.load_postgres_settings())
logger.debug("Connected to db.")

yield
Expand Down
48 changes: 47 additions & 1 deletion runtimes/eoapi/raster/eoapi/raster/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,36 @@
"""API settings."""

import base64
import json
from typing import Optional

import boto3
from pydantic import field_validator
from pydantic_settings import BaseSettings
from titiler.pgstac.settings import PostgresSettings


def get_secret_dict(secret_name: str):
"""Retrieve secrets from AWS Secrets Manager
Args:
secret_name (str): name of aws secrets manager secret containing database connection secrets
profile_name (str, optional): optional name of aws profile for use in debugger only
Returns:
secrets (dict): decrypted secrets in dict
"""

# Create a Secrets Manager client
session = boto3.session.Session()
client = session.client(service_name="secretsmanager")

get_secret_value_response = client.get_secret_value(SecretId=secret_name)

if "SecretString" in get_secret_value_response:
return json.loads(get_secret_value_response["SecretString"])
else:
return json.loads(base64.b64decode(get_secret_value_response["SecretBinary"]))


class ApiSettings(BaseSettings):
Expand All @@ -14,8 +43,9 @@ class ApiSettings(BaseSettings):
debug: bool = False
root_path: str = ""

pgstac_secret_arn: Optional[str] = None

model_config = {
"env_prefix": "EOAPI_RASTER_",
"env_file": ".env",
"extra": "allow",
}
Expand All @@ -29,3 +59,19 @@ def parse_cors_origin(cls, v):
def parse_cors_methods(cls, v):
"""Parse CORS methods."""
return [method.strip() for method in v.split(",")]

def load_postgres_settings(self) -> "PostgresSettings":
"""Load postgres connection params from AWS secret"""

if self.pgstac_secret_arn:
secret = get_secret_dict(self.pgstac_secret_arn)

return PostgresSettings(
postgres_host=secret["host"],
postgres_dbname=secret["dbname"],
postgres_user=secret["username"],
postgres_pass=secret["password"],
postgres_port=secret["port"],
)
else:
return PostgresSettings()
1 change: 1 addition & 0 deletions runtimes/eoapi/raster/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies = [
"titiler.extensions",
"starlette-cramjam>=0.4,<0.5",
"importlib_resources>=1.1.0;python_version<'3.9'",
"boto3",
"eoapi.auth-utils>=0.2.0",
]

Expand Down
3 changes: 1 addition & 2 deletions runtimes/eoapi/stac/eoapi/stac/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
TransactionExtension,
)
from stac_fastapi.extensions.third_party import BulkTransactionExtension
from stac_fastapi.pgstac.config import Settings
from stac_fastapi.pgstac.core import CoreCrudClient
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db
from stac_fastapi.pgstac.extensions import QueryExtension
Expand Down Expand Up @@ -51,7 +50,7 @@

api_settings = ApiSettings()
auth_settings = OpenIdConnectSettings()
settings = Settings(enable_response_models=True)
settings = api_settings.load_postgres_settings()

# Logs
init_logging(debug=api_settings.debug)
Expand Down
46 changes: 45 additions & 1 deletion runtimes/eoapi/stac/eoapi/stac/config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,36 @@
"""API settings."""

import base64
import json
from typing import List, Optional

import boto3
from pydantic import field_validator
from pydantic_settings import BaseSettings
from stac_fastapi.pgstac.config import Settings


def get_secret_dict(secret_name: str):
"""Retrieve secrets from AWS Secrets Manager
Args:
secret_name (str): name of aws secrets manager secret containing database connection secrets
profile_name (str, optional): optional name of aws profile for use in debugger only
Returns:
secrets (dict): decrypted secrets in dict
"""

# Create a Secrets Manager client
session = boto3.session.Session()
client = session.client(service_name="secretsmanager")

get_secret_value_response = client.get_secret_value(SecretId=secret_name)

if "SecretString" in get_secret_value_response:
return json.loads(get_secret_value_response["SecretString"])
else:
return json.loads(base64.b64decode(get_secret_value_response["SecretBinary"]))


class ApiSettings(BaseSettings):
Expand All @@ -15,6 +42,7 @@ class ApiSettings(BaseSettings):
cachecontrol: str = "public, max-age=3600"
debug: bool = False

pgstac_secret_arn: Optional[str] = None
titiler_endpoint: Optional[str] = None

extensions: List[str] = [
Expand All @@ -37,8 +65,24 @@ def parse_cors_methods(cls, v):
"""Parse CORS methods."""
return [method.strip() for method in v.split(",")]

def load_postgres_settings(self) -> "Settings":
"""Load postgres connection params from AWS secret"""

if self.pgstac_secret_arn:
secret = get_secret_dict(self.pgstac_secret_arn)

return Settings(
postgres_host_reader=secret["host"],
postgres_host_writer=secret["host"],
postgres_dbname=secret["dbname"],
postgres_user=secret["username"],
postgres_pass=secret["password"],
postgres_port=secret["port"],
)
else:
return Settings()

model_config = {
"env_prefix": "EOAPI_STAC_",
"env_file": ".env",
"extra": "allow",
}
1 change: 1 addition & 0 deletions runtimes/eoapi/stac/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies = [
"starlette-cramjam>=0.4,<0.5",
"psycopg_pool",
"eoapi.auth-utils>=0.2.0",
"boto3"
]

[project.optional-dependencies]
Expand Down
3 changes: 1 addition & 2 deletions runtimes/eoapi/vector/eoapi/vector/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from tipg.errors import DEFAULT_STATUS_CODES, add_exception_handlers
from tipg.factory import Endpoints as TiPgEndpoints
from tipg.middleware import CacheControlMiddleware, CatalogUpdateMiddleware
from tipg.settings import PostgresSettings

from . import __version__ as eoapi_vector_version
from .config import ApiSettings
Expand All @@ -24,7 +23,7 @@
CUSTOM_SQL_DIRECTORY = resources_files(__package__) / "sql"

settings = ApiSettings()
postgres_settings = PostgresSettings()
postgres_settings = settings.load_postgres_settings()
auth_settings = OpenIdConnectSettings()

# Logs
Expand Down
Loading

0 comments on commit 4607867

Please sign in to comment.