1
1
"""Tests for Jinja2 CQL2 filter."""
2
2
3
- from dataclasses import dataclass
4
- from typing import Generator
5
- from unittest .mock import AsyncMock , MagicMock , patch
6
3
from urllib .parse import parse_qs
7
4
8
5
import httpx
9
6
import pytest
10
7
from fastapi .testclient import TestClient
11
8
from utils import AppFactory
12
9
10
+ from tests .utils import single_chunk_async_stream_response
11
+
13
12
app_factory = AppFactory (
14
13
oidc_discovery_url = "https://example-stac-api.com/.well-known/openid-configuration" ,
15
14
default_public = False ,
16
15
)
17
16
18
17
19
- @pytest .fixture
20
- def mock_send () -> Generator [MagicMock , None , None ]:
21
- """Mock the HTTPX send method. Useful when we want to inspect the request is sent to upstream API."""
22
- with patch (
23
- "stac_auth_proxy.handlers.reverse_proxy.httpx.AsyncClient.send" ,
24
- new_callable = AsyncMock ,
25
- ) as mock_send_method :
26
- yield mock_send_method
27
-
28
-
29
- @dataclass
30
- class SingleChunkAsyncStream (httpx .AsyncByteStream ):
31
- """Mock async stream that returns a single chunk of data."""
32
-
33
- body : bytes
34
-
35
- async def __aiter__ (self ):
36
- """Return a single chunk of data."""
37
- yield self .body
38
-
39
-
40
18
def test_collections_filter_contained_by_token (
41
- mock_send , source_api_server , token_builder
19
+ mock_upstream , source_api_server , token_builder
42
20
):
43
21
"""Test that the collections filter is applied correctly."""
44
22
# Mock response from upstream API
45
- mock_send .return_value = httpx .Response (
46
- 200 ,
47
- stream = SingleChunkAsyncStream (b"{}" ),
48
- headers = {"content-type" : "application/json" },
49
- )
23
+ mock_upstream .return_value = single_chunk_async_stream_response (b"{}" )
50
24
51
25
app = app_factory (
52
26
upstream_url = source_api_server ,
@@ -59,15 +33,49 @@ def test_collections_filter_contained_by_token(
59
33
)
60
34
61
35
auth_token = token_builder ({"collections" : ["foo" , "bar" ]})
62
- client = TestClient (
63
- app ,
64
- headers = {"Authorization" : f"Bearer { auth_token } " },
65
- )
66
-
36
+ client = TestClient (app , headers = {"Authorization" : f"Bearer { auth_token } " })
67
37
response = client .get ("/collections" )
38
+
68
39
assert response .status_code == 200
69
- assert mock_send .call_count == 1
70
- [r ] = mock_send .call_args [0 ]
40
+ assert mock_upstream .call_count == 1
41
+ [r ] = mock_upstream .call_args [0 ]
71
42
assert parse_qs (r .url .query .decode ()) == {
72
43
"filter" : ["a_containedby(id, ('foo', 'bar'))" ]
73
44
}
45
+
46
+
47
+ @pytest .mark .parametrize (
48
+ "authenticated, expected_filter" ,
49
+ [
50
+ (True , "true" ),
51
+ (False , "(private = false)" ),
52
+ ],
53
+ )
54
+ def test_collections_filter_private_and_public (
55
+ mock_upstream , source_api_server , token_builder , authenticated , expected_filter
56
+ ):
57
+ """Test that filter can be used for private/public collections."""
58
+ # Mock response from upstream API
59
+ mock_upstream .return_value = single_chunk_async_stream_response (b"{}" )
60
+
61
+ app = app_factory (
62
+ upstream_url = source_api_server ,
63
+ collections_filter = {
64
+ "cls" : "stac_auth_proxy.filters.Template" ,
65
+ "args" : ["{{ '(private = false)' if token is none else true }}" ],
66
+ },
67
+ default_public = True ,
68
+ )
69
+
70
+ client = TestClient (
71
+ app ,
72
+ headers = (
73
+ {"Authorization" : f"Bearer { token_builder ({})} " } if authenticated else {}
74
+ ),
75
+ )
76
+ response = client .get ("/collections" )
77
+
78
+ assert response .status_code == 200
79
+ assert mock_upstream .call_count == 1
80
+ [r ] = mock_upstream .call_args [0 ]
81
+ assert parse_qs (r .url .query .decode ()) == {"filter" : [expected_filter ]}
0 commit comments