Skip to content

Commit 67f6b3a

Browse files
committed
Continue test buildout
1 parent 166ca41 commit 67f6b3a

File tree

3 files changed

+82
-39
lines changed

3 files changed

+82
-39
lines changed

tests/conftest.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import json
44
import os
55
import threading
6-
from typing import Any
7-
from unittest.mock import MagicMock, patch
6+
from typing import Any, Generator
7+
from unittest.mock import AsyncMock, MagicMock, patch
88

99
import pytest
1010
import uvicorn
@@ -140,3 +140,13 @@ def mock_env():
140140
"""Clear environment variables to avoid poluting configs from runtime env."""
141141
with patch.dict(os.environ, clear=True):
142142
yield
143+
144+
145+
@pytest.fixture
146+
def mock_upstream() -> Generator[MagicMock, None, None]:
147+
"""Mock the HTTPX send method. Useful when we want to inspect the request is sent to upstream API."""
148+
with patch(
149+
"stac_auth_proxy.handlers.reverse_proxy.httpx.AsyncClient.send",
150+
new_callable=AsyncMock,
151+
) as mock_send_method:
152+
yield mock_send_method

tests/test_filters_jinja2.py

Lines changed: 45 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,26 @@
11
"""Tests for Jinja2 CQL2 filter."""
22

3-
from dataclasses import dataclass
4-
from typing import Generator
5-
from unittest.mock import AsyncMock, MagicMock, patch
63
from urllib.parse import parse_qs
74

85
import httpx
96
import pytest
107
from fastapi.testclient import TestClient
118
from utils import AppFactory
129

10+
from tests.utils import single_chunk_async_stream_response
11+
1312
app_factory = AppFactory(
1413
oidc_discovery_url="https://example-stac-api.com/.well-known/openid-configuration",
1514
default_public=False,
1615
)
1716

1817

19-
@pytest.fixture
20-
def mock_send() -> Generator[MagicMock, None, None]:
21-
"""Mock the HTTPX send method. Useful when we want to inspect the request is sent to upstream API."""
22-
with patch(
23-
"stac_auth_proxy.handlers.reverse_proxy.httpx.AsyncClient.send",
24-
new_callable=AsyncMock,
25-
) as mock_send_method:
26-
yield mock_send_method
27-
28-
29-
@dataclass
30-
class SingleChunkAsyncStream(httpx.AsyncByteStream):
31-
"""Mock async stream that returns a single chunk of data."""
32-
33-
body: bytes
34-
35-
async def __aiter__(self):
36-
"""Return a single chunk of data."""
37-
yield self.body
38-
39-
4018
def test_collections_filter_contained_by_token(
41-
mock_send, source_api_server, token_builder
19+
mock_upstream, source_api_server, token_builder
4220
):
4321
"""Test that the collections filter is applied correctly."""
4422
# Mock response from upstream API
45-
mock_send.return_value = httpx.Response(
46-
200,
47-
stream=SingleChunkAsyncStream(b"{}"),
48-
headers={"content-type": "application/json"},
49-
)
23+
mock_upstream.return_value = single_chunk_async_stream_response(b"{}")
5024

5125
app = app_factory(
5226
upstream_url=source_api_server,
@@ -59,15 +33,49 @@ def test_collections_filter_contained_by_token(
5933
)
6034

6135
auth_token = token_builder({"collections": ["foo", "bar"]})
62-
client = TestClient(
63-
app,
64-
headers={"Authorization": f"Bearer {auth_token}"},
65-
)
66-
36+
client = TestClient(app, headers={"Authorization": f"Bearer {auth_token}"})
6737
response = client.get("/collections")
38+
6839
assert response.status_code == 200
69-
assert mock_send.call_count == 1
70-
[r] = mock_send.call_args[0]
40+
assert mock_upstream.call_count == 1
41+
[r] = mock_upstream.call_args[0]
7142
assert parse_qs(r.url.query.decode()) == {
7243
"filter": ["a_containedby(id, ('foo', 'bar'))"]
7344
}
45+
46+
47+
@pytest.mark.parametrize(
48+
"authenticated, expected_filter",
49+
[
50+
(True, "true"),
51+
(False, "(private = false)"),
52+
],
53+
)
54+
def test_collections_filter_private_and_public(
55+
mock_upstream, source_api_server, token_builder, authenticated, expected_filter
56+
):
57+
"""Test that filter can be used for private/public collections."""
58+
# Mock response from upstream API
59+
mock_upstream.return_value = single_chunk_async_stream_response(b"{}")
60+
61+
app = app_factory(
62+
upstream_url=source_api_server,
63+
collections_filter={
64+
"cls": "stac_auth_proxy.filters.Template",
65+
"args": ["{{ '(private = false)' if token is none else true }}"],
66+
},
67+
default_public=True,
68+
)
69+
70+
client = TestClient(
71+
app,
72+
headers=(
73+
{"Authorization": f"Bearer {token_builder({})}"} if authenticated else {}
74+
),
75+
)
76+
response = client.get("/collections")
77+
78+
assert response.status_code == 200
79+
assert mock_upstream.call_count == 1
80+
[r] = mock_upstream.call_args[0]
81+
assert parse_qs(r.url.query.decode()) == {"filter": [expected_filter]}

tests/utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
"""Utilities for testing."""
22

3+
from dataclasses import dataclass
34
from typing import Callable
45

6+
import httpx
7+
58
from stac_auth_proxy import Settings, create_app
69

710

@@ -23,3 +26,25 @@ def __call__(self, *, upstream_url, **overrides) -> Callable:
2326
},
2427
)
2528
)
29+
30+
31+
@dataclass
32+
class SingleChunkAsyncStream(httpx.AsyncByteStream):
33+
"""Mock async stream that returns a single chunk of data."""
34+
35+
body: bytes
36+
37+
async def __aiter__(self):
38+
"""Return a single chunk of data."""
39+
yield self.body
40+
41+
42+
def single_chunk_async_stream_response(
43+
body: bytes, status_code=200, headers={"content-type": "application/json"}
44+
):
45+
"""Create a response with a single chunk of data."""
46+
return httpx.Response(
47+
stream=SingleChunkAsyncStream(body),
48+
status_code=status_code,
49+
headers=headers,
50+
)

0 commit comments

Comments
 (0)