Skip to content

Commit

Permalink
Add first pass of using CEL for custom auth policies
Browse files Browse the repository at this point in the history
  • Loading branch information
alukach committed Dec 7, 2024
1 parent 50b99e6 commit 4a6765a
Show file tree
Hide file tree
Showing 7 changed files with 332 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ classifiers = [
dependencies = [
"authlib>=1.3.2",
"brotli>=1.1.0",
"cel-python>=0.1.5",
"eoapi-auth-utils>=0.4.0",
"fastapi>=0.115.5",
"httpx>=0.28.0",
Expand Down
8 changes: 8 additions & 0 deletions src/stac_auth_proxy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
authentication, authorization, and proxying of requests to some internal STAC API.
"""

import logging
from typing import Optional

from eoapi.auth_utils import OpenIdConnectAuth
Expand All @@ -14,6 +15,8 @@
from .handlers import OpenApiSpecHandler, ReverseProxyHandler
from .middleware import AddProcessTimeHeaderMiddleware

logger = logging.getLogger(__name__)


def create_app(settings: Optional[Settings] = None) -> FastAPI:
"""FastAPI Application Factory."""
Expand All @@ -26,6 +29,11 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
openid_configuration_url=str(settings.oidc_discovery_url)
).valid_token_dependency

if settings.guard:
logger.info("Wrapping auth scheme")
auth_scheme = settings.guard(auth_scheme).check
print(f"{auth_scheme=}")

proxy_handler = ReverseProxyHandler(upstream=str(settings.upstream_url))
openapi_handler = OpenApiSpecHandler(
proxy=proxy_handler, oidc_config_url=str(settings.oidc_discovery_url)
Expand Down
16 changes: 16 additions & 0 deletions src/stac_auth_proxy/config.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
"""Configuration for the STAC Auth Proxy."""

import importlib
from typing import Optional, TypeAlias

from pydantic import BaseModel
from pydantic.networks import HttpUrl
from pydantic_settings import BaseSettings, SettingsConfigDict

EndpointMethods: TypeAlias = dict[str, list[str]]


class ClassInput(BaseModel):
cls: str
kwargs: Optional[dict[str, str]] = {}

def __call__(self, token_dependency):
"""Dynamically load a class and instantiate it with kwargs."""
module_path, class_name = self.cls.rsplit(".", 1)
module = importlib.import_module(module_path)
cls = getattr(module, class_name)
return cls(**self.kwargs, token_dependency=token_dependency)


class Settings(BaseSettings):
"""Configuration settings for the STAC Auth Proxy."""

Expand All @@ -30,3 +44,5 @@ class Settings(BaseSettings):
openapi_spec_endpoint: Optional[str] = None

model_config = SettingsConfigDict(env_prefix="STAC_AUTH_PROXY_")

guard: Optional[ClassInput] = None
Empty file.
46 changes: 46 additions & 0 deletions src/stac_auth_proxy/guards/cel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from dataclasses import dataclass
from typing import Any

from fastapi import Request, Depends, HTTPException
import celpy


@dataclass
class Cel:
"""Custom middleware."""

expression: str
token_dependency: Any

def __post_init__(self):
env = celpy.Environment()
ast = env.compile(self.expression)
self.program = env.program(ast)

async def check(
request: Request,
auth_token=Depends(self.token_dependency),
):
request_data = {
"path": request.url.path,
"method": request.method,
"query_params": dict(request.query_params), # Convert to a dict
"headers": dict(request.headers), # Convert headers to a dict if needed
# Body may need to be read (await request.json()) or (await request.body()) if needed
"body": (
await request.json()
if request.headers.get("content-type") == "application/json"
else (await request.body()).decode()
),
}

activation = {"req": request_data, "token": auth_token}
print(f"{activation=}")
result = self.program.evaluate(celpy.json_to_cel(activation))
print(f"{result=}")
if not result:
raise HTTPException(
status_code=403, detail="Forbidden (failed CEL check)"
)

self.check = check
58 changes: 58 additions & 0 deletions tests/test_guard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""Tests for OpenAPI spec handling."""

import pytest
from fastapi.testclient import TestClient
from utils import AppFactory

app_factory = AppFactory(
oidc_discovery_url="https://samples.auth0.com/.well-known/openid-configuration",
default_public=False,
)


import pytest
from unittest.mock import patch, MagicMock


# Fixture to patch OpenIdConnectAuth and mock valid_token_dependency
@pytest.fixture
def skip_auth():
with patch("eoapi.auth_utils.OpenIdConnectAuth") as MockClass:
# Create a mock instance
mock_instance = MagicMock()
# Set the return value of `valid_token_dependency`
mock_instance.valid_token_dependency.return_value = "constant"
# Assign the mock instance to the patched class's return value
MockClass.return_value = mock_instance

# Yield the mock instance for use in tests
yield mock_instance


@pytest.mark.parametrize(
"endpoint, expected_status_code",
[
("/", 403),
("/?foo=xyz", 403),
("/?foo=bar", 200),
],
)
def test_guard_query_params(
source_api_server,
token_builder,
endpoint,
expected_status_code,
):
"""When no OpenAPI spec endpoint is set, the proxied OpenAPI spec is unaltered."""
app = app_factory(
upstream_url=source_api_server,
guard={
"cls": "stac_auth_proxy.guards.cel.Cel",
"kwargs": {
"expression": '("foo" in req.query_params) && req.query_params.foo == "bar"'
},
},
)
client = TestClient(app, headers={"Authorization": f"Bearer {token_builder({})}"})
response = client.get(endpoint)
assert response.status_code == expected_status_code
Loading

0 comments on commit 4a6765a

Please sign in to comment.