From 32382da2a06489d0c1c175d585e283bffbc52fbe Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Fri, 6 Dec 2024 23:00:00 -0800 Subject: [PATCH] More fancy CEL test --- src/stac_auth_proxy/guards/cel.py | 13 +++++++ tests/test_guard.py | 59 ++++++++++++++++++++----------- 2 files changed, 52 insertions(+), 20 deletions(-) diff --git a/src/stac_auth_proxy/guards/cel.py b/src/stac_auth_proxy/guards/cel.py index 39d94d7..ebce1e5 100644 --- a/src/stac_auth_proxy/guards/cel.py +++ b/src/stac_auth_proxy/guards/cel.py @@ -1,5 +1,7 @@ from dataclasses import dataclass from typing import Any +import re +from urllib.parse import urlparse from fastapi import Request, Depends, HTTPException import celpy @@ -25,6 +27,7 @@ async def check( "path": request.url.path, "method": request.method, "query_params": dict(request.query_params), + "path_params": extract_variables(request.url.path), "headers": dict(request.headers), # Body may need to be read (await request.json()) or (await request.body()) if needed "body": ( @@ -34,6 +37,8 @@ async def check( ), } + print(f"{request_data['path_params']=}") + result = self.program.evaluate( celpy.json_to_cel( { @@ -48,3 +53,11 @@ async def check( ) self.check = check + + +def extract_variables(url: str) -> dict: + path = urlparse(url).path + # This allows either /items or /bulk_items, with an optional item_id following. + pattern = r"^/collections/(?P[^/]+)(?:/(?:items|bulk_items)(?:/(?P[^/]+))?)?/?$" + match = re.match(pattern, path) + return {k: v for k, v in match.groupdict().items() if v} if match else {} diff --git a/tests/test_guard.py b/tests/test_guard.py index e590061..c1f678c 100644 --- a/tests/test_guard.py +++ b/tests/test_guard.py @@ -10,31 +10,15 @@ ) -import pytest -from unittest.mock import patch, MagicMock - - -# Fixture to patch OpenIdConnectAuth and mock valid_token_dependency -@pytest.fixture -def skip_auth(): - with patch("eoapi.auth_utils.OpenIdConnectAuth") as MockClass: - # Create a mock instance - mock_instance = MagicMock() - # Set the return value of `valid_token_dependency` - mock_instance.valid_token_dependency.return_value = "constant" - # Assign the mock instance to the patched class's return value - MockClass.return_value = mock_instance - - # Yield the mock instance for use in tests - yield mock_instance - - @pytest.mark.parametrize( "endpoint, expected_status_code", [ ("/", 403), ("/?foo=xyz", 403), + ("/?bar=foo", 403), ("/?foo=bar", 200), + ("/?foo=xyz&foo=bar", 200), # Only the last value is checked + ("/?foo=bar&foo=xyz", 403), # Only the last value is checked ], ) def test_guard_query_params( @@ -43,7 +27,6 @@ def test_guard_query_params( endpoint, expected_status_code, ): - """When no OpenAPI spec endpoint is set, the proxied OpenAPI spec is unaltered.""" app = app_factory( upstream_url=source_api_server, guard={ @@ -56,3 +39,39 @@ def test_guard_query_params( client = TestClient(app, headers={"Authorization": f"Bearer {token_builder({})}"}) response = client.get(endpoint) assert response.status_code == expected_status_code + + +@pytest.mark.parametrize( + "token, expected_status_code", + [ + ({"foo": "bar"}, 403), + ({"collections": []}, 403), + ({"collections": ["foo", "bar"]}, 403), + ({"collections": ["xyz"]}, 200), + ({"collections": ["foo", "xyz"]}, 200), + ], +) +def test_guard_auth_token( + source_api_server, + token_builder, + token, + expected_status_code, +): + app = app_factory( + upstream_url=source_api_server, + guard={ + "cls": "stac_auth_proxy.guards.cel.Cel", + "kwargs": { + "expression": """ + ("collections" in token) + && ("collection_id" in req.path_params) + && (req.path_params.collection_id in token.collections) + """ + }, + }, + ) + client = TestClient( + app, headers={"Authorization": f"Bearer {token_builder(token)}"} + ) + response = client.get("/collections/xyz") + assert response.status_code == expected_status_code