Skip to content

Commit

Permalink
More fancy CEL test
Browse files Browse the repository at this point in the history
  • Loading branch information
alukach committed Dec 7, 2024
1 parent cff2a24 commit 32382da
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 20 deletions.
13 changes: 13 additions & 0 deletions src/stac_auth_proxy/guards/cel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from dataclasses import dataclass
from typing import Any
import re
from urllib.parse import urlparse

from fastapi import Request, Depends, HTTPException
import celpy
Expand All @@ -25,6 +27,7 @@ async def check(
"path": request.url.path,
"method": request.method,
"query_params": dict(request.query_params),
"path_params": extract_variables(request.url.path),
"headers": dict(request.headers),
# Body may need to be read (await request.json()) or (await request.body()) if needed
"body": (
Expand All @@ -34,6 +37,8 @@ async def check(
),
}

print(f"{request_data['path_params']=}")

result = self.program.evaluate(
celpy.json_to_cel(
{
Expand All @@ -48,3 +53,11 @@ async def check(
)

self.check = check


def extract_variables(url: str) -> dict:
path = urlparse(url).path
# This allows either /items or /bulk_items, with an optional item_id following.
pattern = r"^/collections/(?P<collection_id>[^/]+)(?:/(?:items|bulk_items)(?:/(?P<item_id>[^/]+))?)?/?$"
match = re.match(pattern, path)
return {k: v for k, v in match.groupdict().items() if v} if match else {}
59 changes: 39 additions & 20 deletions tests/test_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,15 @@
)


import pytest
from unittest.mock import patch, MagicMock


# Fixture to patch OpenIdConnectAuth and mock valid_token_dependency
@pytest.fixture
def skip_auth():
with patch("eoapi.auth_utils.OpenIdConnectAuth") as MockClass:
# Create a mock instance
mock_instance = MagicMock()
# Set the return value of `valid_token_dependency`
mock_instance.valid_token_dependency.return_value = "constant"
# Assign the mock instance to the patched class's return value
MockClass.return_value = mock_instance

# Yield the mock instance for use in tests
yield mock_instance


@pytest.mark.parametrize(
"endpoint, expected_status_code",
[
("/", 403),
("/?foo=xyz", 403),
("/?bar=foo", 403),
("/?foo=bar", 200),
("/?foo=xyz&foo=bar", 200), # Only the last value is checked
("/?foo=bar&foo=xyz", 403), # Only the last value is checked
],
)
def test_guard_query_params(
Expand All @@ -43,7 +27,6 @@ def test_guard_query_params(
endpoint,
expected_status_code,
):
"""When no OpenAPI spec endpoint is set, the proxied OpenAPI spec is unaltered."""
app = app_factory(
upstream_url=source_api_server,
guard={
Expand All @@ -56,3 +39,39 @@ def test_guard_query_params(
client = TestClient(app, headers={"Authorization": f"Bearer {token_builder({})}"})
response = client.get(endpoint)
assert response.status_code == expected_status_code


@pytest.mark.parametrize(
"token, expected_status_code",
[
({"foo": "bar"}, 403),
({"collections": []}, 403),
({"collections": ["foo", "bar"]}, 403),
({"collections": ["xyz"]}, 200),
({"collections": ["foo", "xyz"]}, 200),
],
)
def test_guard_auth_token(
source_api_server,
token_builder,
token,
expected_status_code,
):
app = app_factory(
upstream_url=source_api_server,
guard={
"cls": "stac_auth_proxy.guards.cel.Cel",
"kwargs": {
"expression": """
("collections" in token)
&& ("collection_id" in req.path_params)
&& (req.path_params.collection_id in token.collections)
"""
},
},
)
client = TestClient(
app, headers={"Authorization": f"Bearer {token_builder(token)}"}
)
response = client.get("/collections/xyz")
assert response.status_code == expected_status_code

0 comments on commit 32382da

Please sign in to comment.