Skip to content

Commit

Permalink
Mv from dataclasses to higher order functions
Browse files Browse the repository at this point in the history
  • Loading branch information
alukach committed Dec 13, 2024
1 parent d4757da commit a34c370
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 47 deletions.
77 changes: 31 additions & 46 deletions src/stac_auth_proxy/filters/template.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Generate CQL2 filter expressions via Jinja2 templating."""

from dataclasses import dataclass, field
from typing import Any, Callable
from typing import Any, Annotated, Callable

from cql2 import Expr
from fastapi import Request, Security
Expand All @@ -10,48 +9,34 @@
from ..utils import extract_variables


@dataclass
class Template:
def Template(template_str: str, token_dependency: Callable[..., Any]):
"""Generate CQL2 filter expressions via Jinja2 templating."""

template_str: str
token_dependency: Callable[..., Any]

# Generated attributes
env: Environment = field(init=False)
dependency: Callable[[Request, Security], Expr] = field(init=False)

def __post_init__(self):
"""Initialize the Jinja2 environment."""
self.env = Environment(loader=BaseLoader).from_string(self.template_str)
self.dependency = self.build()

def build(self):
"""Generate a dependency for rendering a CQL2 filter expression."""

async def dependency(
request: Request, auth_token=Security(self.token_dependency)
) -> Expr:
"""Render a CQL2 filter expression with the request and auth token."""
# TODO: How to handle the case where auth_token is null?
context = {
"req": {
"path": request.url.path,
"method": request.method,
"query_params": dict(request.query_params),
"path_params": extract_variables(request.url.path),
"headers": dict(request.headers),
"body": (
await request.json()
if request.headers.get("content-type") == "application/json"
else (await request.body()).decode()
),
},
"token": auth_token,
}
cql2_str = self.env.render(**context)
cql2_expr = Expr(cql2_str)
cql2_expr.validate()
return cql2_expr

return dependency
env = Environment(loader=BaseLoader).from_string(template_str)

async def dependency(
request: Request,
auth_token=Annotated[dict[str, Any], Security(token_dependency)],
) -> Expr:
"""Render a CQL2 filter expression with the request and auth token."""
# TODO: How to handle the case where auth_token is null?
context = {
"req": {
"path": request.url.path,
"method": request.method,
"query_params": dict(request.query_params),
"path_params": extract_variables(request.url.path),
"headers": dict(request.headers),
"body": (
await request.json()
if request.headers.get("content-type") == "application/json"
else (await request.body()).decode()
),
},
"token": auth_token,
}
cql2_str = env.render(**context)
cql2_expr = Expr(cql2_str)
cql2_expr.validate()
return cql2_expr

return dependency
3 changes: 2 additions & 1 deletion src/stac_auth_proxy/handlers/reverse_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __post_init__(self):
for endpoint in [self.proxy_request, self.stream]:
endpoint.__annotations__["collections_filter"] = Annotated[
Optional[Expr],
Depends(getattr(self.collections_filter, "dependency", lambda: None)),
Depends(self.collections_filter or (lambda: None)),
]

async def proxy_request(
Expand All @@ -55,6 +55,7 @@ async def proxy_request(
path = request.url.path
query = request.url.query

# Appliy filters
if utils.is_collection_endpoint(path) and collections_filter:
if request.method == "GET" and path == "/collections":
query = utils.insert_filter(qs=query, filter=collections_filter)
Expand Down

0 comments on commit a34c370

Please sign in to comment.