diff --git a/src/stac_auth_proxy/config.py b/src/stac_auth_proxy/config.py index 4e61cf8..8ea16f6 100644 --- a/src/stac_auth_proxy/config.py +++ b/src/stac_auth_proxy/config.py @@ -48,4 +48,10 @@ class Settings(BaseSettings): public_endpoints: EndpointMethods = {"/api.html": ["GET"], "/api": ["GET"]} openapi_spec_endpoint: Optional[str] = None + collections_filter: Optional[ClassInput] = { + "cls": "stac_auth_proxy.filters.Template", + "args": ["""A_CONTAINEDBY(id, ( '{{ token.collections | join("', '") }}' ))"""], + } + items_filter: Optional[ClassInput] = None + model_config = SettingsConfigDict(env_prefix="STAC_AUTH_PROXY_") diff --git a/src/stac_auth_proxy/filters/__init__.py b/src/stac_auth_proxy/filters/__init__.py new file mode 100644 index 0000000..35f216f --- /dev/null +++ b/src/stac_auth_proxy/filters/__init__.py @@ -0,0 +1,3 @@ +from .template import Template + +__all__ = ["Template"] diff --git a/src/stac_auth_proxy/filters/template.py b/src/stac_auth_proxy/filters/template.py new file mode 100644 index 0000000..ffaea42 --- /dev/null +++ b/src/stac_auth_proxy/filters/template.py @@ -0,0 +1,44 @@ +from typing import Any, Callable + +from cql2 import Expr +from jinja2 import Environment, BaseLoader +from fastapi import Request, Security + +from ..utils import extract_variables + +from dataclasses import dataclass, field + + +@dataclass +class Template: + template_str: str + token_dependency: Callable[..., Any] + + # Generated attributes + env: Environment = field(init=False) + + def __post_init__(self): + self.env = Environment(loader=BaseLoader).from_string(self.template_str) + self.render.__annotations__["auth_token"] = Security(self.token_dependency) + + async def cql2(self, request: Request, auth_token=Security(...)) -> Expr: + # TODO: How to handle the case where auth_token is null? + context = { + "req": { + "path": request.url.path, + "method": request.method, + "query_params": dict(request.query_params), + "path_params": extract_variables(request.url.path), + "headers": dict(request.headers), + "body": ( + await request.json() + if request.headers.get("content-type") == "application/json" + else (await request.body()).decode() + ), + }, + "token": auth_token, + } + cql2_str = self.env.render(**context) + cql2_expr = Expr(cql2_str) + cql2_expr.validate() + return cql2_expr