Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion csp_gateway/server/config/gateway/omnibus.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,28 @@ modules:
_target_: csp_gateway.MountWebSocketRoutes
mount_api_key_middleware:
_target_: csp_gateway.MountAPIKeyMiddleware
api_key: 12345
enforce_ui: false
enforce_controls: false
mount_api_key_middleware_ui:
_target_: csp_gateway.MountAPIKeyMiddleware
api_key: token
enforce: []
enforce_ui: true
enforce_controls: false
mount_api_key_middleware_controls:
_target_: csp_gateway.MountAPIKeyMiddleware
api_key: 54321
enforce: []
enforce_ui: false
enforce_controls: true

gateway:
_target_: csp_gateway.Gateway
settings:
PORT: ${port}
AUTHENTICATE: ${authenticate}
UI: True
API_KEY: "12345"
modules:
- /modules/example_module
- /modules/example_module_feedback
Expand All @@ -49,6 +63,8 @@ gateway:
- /modules/mount_rest_routes
- /modules/mount_websocket_routes
- /modules/mount_api_key_middleware
- /modules/mount_api_key_middleware_ui
- /modules/mount_api_key_middleware_controls
channels:
_target_: csp_gateway.server.demo.ExampleGatewayChannels

Expand Down
4 changes: 2 additions & 2 deletions csp_gateway/server/demo/config/omnibus.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ defaults:
- /gateway: omnibus
- _self_

# csp-gateway-start --config-dir=csp_gateway/server/omnibus +config=omnibus
# csp-gateway-start --config-dir=csp_gateway/server/demo +config=omnibus

authenticate: False
authenticate: True
port: 8000
2 changes: 1 addition & 1 deletion csp_gateway/server/demo/omnibus.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def push_to_perspective( # type: ignore[no-untyped-def]
# be instantiated directly as we do so here:

# Setting authentication
settings = GatewaySettings(API_KEY="12345", AUTHENTICATE=False)
settings = GatewaySettings(AUTHENTICATE=False)

# instantiate gateway
gateway = Gateway(
Expand Down
20 changes: 17 additions & 3 deletions csp_gateway/server/gateway/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,12 +275,26 @@ def start(
log.info("Launching web server on:")
url = f"http://{gethostname()}:{self.settings.PORT}"

if ui:
if self.settings.AUTHENTICATE:
log.info(f"\tUI: {url}?token={self.settings.API_KEY}")
if ui and self.settings.AUTHENTICATE:
from ..middleware import MountAPIKeyMiddleware

# TODO: Will need to handle others
auth = ""

# Find any middleware enforcing auth
for module in self.modules:
if isinstance(module, MountAPIKeyMiddleware) and module.enforce_ui is True:
auth = module.api_key
break

if auth:
log.info(f"\tUI: {url}?{module.api_key_name}={auth}")
else:
log.info(f"\tUI: {url}")

else:
log.info(f"\tUI: {url}")

log.info(f"\tDocs: {url}/docs")
log.info(f"\tDocs: {url}/redoc")

Expand Down
2 changes: 1 addition & 1 deletion csp_gateway/server/middleware/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .api_key import MountAPIKeyMiddleware
from .api_key import *
213 changes: 157 additions & 56 deletions csp_gateway/server/middleware/api_key.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from datetime import timedelta
from logging import getLogger
from secrets import token_urlsafe
from typing import List

from fastapi import APIRouter, Depends, HTTPException, Request, Security
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from pydantic import Field, PrivateAttr
from pydantic import Field, PrivateAttr, field_validator
from starlette.status import HTTP_403_FORBIDDEN

from csp_gateway.server import GatewayChannels, GatewayModule

from ..shared import ChannelSelection

# separate to avoid circular
from ..web import GatewayWebApp
from .hacks.api_key_middleware_websocket_fix.api_key import (
Expand All @@ -15,54 +20,154 @@
APIKeyQuery,
)

_log = getLogger(__name__)

class MountAPIKeyMiddleware(GatewayModule):
api_key_timeout: timedelta = Field(description="Cookie timeout for API Key authentication", default=timedelta(hours=12))
__all__ = (
"MountAuthMiddleware",
"MountAPIKeyMiddleware",
)

# TODO: More eventually


class MountAuthMiddleware(GatewayModule):
enforce: list = Field(default=(), description="Routes to enforce, default empty means 'all'")
channels: ChannelSelection = Field(
default_factory=ChannelSelection,
description="Channels or subroutes to enforce. If route is not present in `enforce`, implies 'allow all'",
)

# NOTE: don't make this publically configureable
# as it is needed in gateway.py
_api_key_name: str = PrivateAttr("token")
_api_key_secret: str = PrivateAttr("")
enforce_controls: bool = Field(default=False, description="Whether to allow access to controls routes. Defaults to True")
enforce_ui: bool = Field(default=True, description="Whether to allow web access to the API Key authentication routes. Defaults to True")

unauthorized_status_message: str = "unauthorized"

_enforced_channels: List[str] = PrivateAttr(default_factory=list)

def connect(self, channels: GatewayChannels) -> None:
# NO-OP
...


class MountAPIKeyMiddleware(MountAuthMiddleware):
api_key: str = Field(default=token_urlsafe(32), description="API Key to use")
api_key_name: str = Field(default="token", description="API Key to use")
api_key_timeout: timedelta = Field(description="Cookie timeout for API Key authentication", default=timedelta(hours=12))

_instance_count = 0

@field_validator("api_key_name", mode="before")
@classmethod
def _validate_api_key_name(cls, value: str) -> str:
if not value:
raise ValueError("API Key name must be a non-empty string")
value = f"{value.strip().lower()}-{cls._instance_count}"
cls._instance_count += 1
return value

def rest(self, app: GatewayWebApp) -> None:
if app.settings.AUTHENTICATE:
# first, pull out the api key secret from the settings
self._api_key_secret = app.settings.API_KEY

# reinitialize header
api_key_query = APIKeyQuery(name=self._api_key_name, auto_error=False)
api_key_header = APIKeyHeader(name=self._api_key_name, auto_error=False)
api_key_cookie = APIKeyCookie(name=self._api_key_name, auto_error=False)

# routers
auth_router: APIRouter = app.get_router("auth")
public_router: APIRouter = app.get_router("public")

# now mount middleware
async def get_api_key(
api_key_query: str = Security(api_key_query),
api_key_header: str = Security(api_key_header),
api_key_cookie: str = Security(api_key_cookie),
):
if api_key_query == self._api_key_secret or api_key_header == self._api_key_secret or api_key_cookie == self._api_key_secret:
return self._api_key_secret
else:
# Use configuration to determine allowed routes
# for this API key
self._calculate_auth(app)

# Setup the routes for authentication
self._setup_routes(app)

def _calculate_auth(self, app: GatewayWebApp) -> None:
self._enforced_channels = self.channels.select_from(app.gateway.channels_model)

# Fully form the url
self._api_str = app.settings.API_STR

def _setup_routes(self, app: GatewayWebApp) -> None:
# reinitialize header
api_key_query = APIKeyQuery(name=self.api_key_name, auto_error=False)
api_key_header = APIKeyHeader(name=self.api_key_name, auto_error=False)
api_key_cookie = APIKeyCookie(name=self.api_key_name, auto_error=False)

# routers
auth_router: APIRouter = app.get_router("auth")
public_router: APIRouter = app.get_router("public")

# now mount middleware
async def get_api_key(
request: Request = None,
api_key_query: str = Security(api_key_query),
api_key_header: str = Security(api_key_header),
api_key_cookie: str = Security(api_key_cookie),
):
if request is None:
# If request is None, we are not in a request context, return None
_log.warning("API Key check: request is None, returning None")
return None

if hasattr(request.state, "auth"):
# Already authenticated, return the API key
_log.info(f"API Key check: already authenticated, returning {self.api_key_name}")
return request.state.auth

resolved_path = request.url.path.rstrip("/").replace(self._api_str, "").lstrip("/").rsplit("/", 1)

if len(resolved_path) == 1:
root = resolved_path[0]
channel = ""

elif len(resolved_path) > 1:
root = resolved_path[0]
channel = resolved_path[1]

if self.enforce and root not in self.enforce:
# Route not in enforce, allow
_log.info(f"API Key check: {root}/{channel} not in enforced list {self.enforce}, allowing")
return ""

if root == "controls" and not self.enforce_controls:
# Controls route not enforced, allow
_log.info(f"API Key check: root {root} not enforced, allowing")
return ""

# TODO
if root in ("", "auth", "perspective") and not self.enforce_ui:
# UI route not enforced, allow
_log.info(f"API Key check: root {root} not enforced, allowing")
return ""

if root not in ("controls", "auth", "perspective") and channel and channel not in self._enforced_channels:
# Channel not in enforce, allow
_log.info(f"API Key check: channel {root}/{channel} not in enforced channels {self._enforced_channels}, allowing")
return ""

# Else, enforce
if api_key_query == self.api_key or api_key_header == self.api_key or api_key_cookie == self.api_key:
# Return the API key secret to allow access
_log.info(f"API Key check: {self.api_key_name} matched for {root}/{channel}, allowing access")

# NOTE: only set this if we are the one validating, not if we are ignoring
request.state.auth = self.api_key
return self.api_key

_log.warning(f"API Key check: {self.api_key_name} did not match, denying access")
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail=self.unauthorized_status_message,
)

# add auth to all other routes
app.add_middleware(Depends(get_api_key))

if self.enforce_ui:

@auth_router.get("/login")
async def route_login_and_add_cookie(api_key: str = Depends(get_api_key)):
if not api_key:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail=self.unauthorized_status_message,
)

@auth_router.get("/login")
async def route_login_and_add_cookie(api_key: str = Depends(get_api_key)):
response = RedirectResponse(url="/")
response.set_cookie(
self._api_key_name,
self.api_key_name,
value=api_key,
domain=app.settings.AUTHENTICATION_DOMAIN,
httponly=True,
Expand All @@ -74,44 +179,40 @@ async def route_login_and_add_cookie(api_key: str = Depends(get_api_key)):
@auth_router.get("/logout")
async def route_logout_and_remove_cookie():
response = RedirectResponse(url="/login")
response.delete_cookie(self._api_key_name, domain=app.settings.AUTHENTICATION_DOMAIN)
response.delete_cookie(self.api_key_name, domain=app.settings.AUTHENTICATION_DOMAIN)
return response

# I'm hand rolling these for now...
@public_router.get("/login", response_class=HTMLResponse, include_in_schema=False)
async def get_login_page(token: str = "", request: Request = None):
if token:
if token != "":
return RedirectResponse(url=f"{app.settings.API_V1_STR}/auth/login?token={token}")
if token and token != "":
return RedirectResponse(url=f"{self._api_str}/auth/login?token={token}")
return app.templates.TemplateResponse(
"login.html.j2",
{"request": request, "api_key_name": self._api_key_name},
{"request": request, "api_key_name": self.api_key_name},
)

@public_router.get("/logout", response_class=HTMLResponse, include_in_schema=False)
async def get_logout_page(request: Request = None):
return app.templates.TemplateResponse("logout.html.j2", {"request": request})

# add auth to all other routes
app.add_middleware(Depends(get_api_key))

@app.app.exception_handler(403)
async def custom_403_handler(request: Request = None, *args):
if "/api" in request.url.path:
# programmatic api access, return json
return JSONResponse(
{
"detail": self.unauthorized_status_message,
"status_code": 403,
},
status_code=403,
)
return app.templates.TemplateResponse(
"login.html.j2",
@app.app.exception_handler(403)
async def custom_403_handler(request: Request = None, *args):
if "/api" in request.url.path:
# programmatic api access, return json
return JSONResponse(
{
"request": request,
"api_key_name": self._api_key_name,
"status_code": 403,
"detail": self.unauthorized_status_message,
"status_code": 403,
},
status_code=403,
)
return app.templates.TemplateResponse(
"login.html.j2",
{
"request": request,
"api_key_name": self.api_key_name,
"status_code": 403,
"detail": self.unauthorized_status_message,
},
)
5 changes: 0 additions & 5 deletions csp_gateway/server/settings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from secrets import token_urlsafe
from socket import gethostname
from typing import List

Expand Down Expand Up @@ -31,8 +30,4 @@ class Settings(BaseSettings):

UI: bool = Field(False, description="Enables ui in the web application")
AUTHENTICATE: bool = Field(False, description="Whether to authenticate users for access to the web application")
API_KEY: str = Field(
token_urlsafe(32),
description="The API key for access if `AUTHENTICATE=True`. The default is auto-generated, but a user-provided value can be used.",
)
AUTHENTICATION_DOMAIN: str = gethostname()
Loading