|
1 | 1 | """Tests for Jinja2 CQL2 filter."""
|
2 | 2 |
|
| 3 | +from dataclasses import dataclass |
| 4 | +from typing import Generator |
| 5 | +from unittest.mock import AsyncMock, MagicMock, patch |
| 6 | +from urllib.parse import parse_qs |
| 7 | + |
| 8 | +import httpx |
3 | 9 | import pytest
|
4 | 10 | from fastapi.testclient import TestClient
|
5 | 11 | from utils import AppFactory
|
|
10 | 16 | )
|
11 | 17 |
|
12 | 18 |
|
13 |
| -def test_collections_filter_contained_by_token(source_api_server, token_builder): |
14 |
| - """""" |
| 19 | +@pytest.fixture |
| 20 | +def mock_send() -> Generator[MagicMock, None, None]: |
| 21 | + """Mock the HTTPX send method. Useful when we want to inspect the request is sent to upstream API.""" |
| 22 | + with patch( |
| 23 | + "stac_auth_proxy.handlers.reverse_proxy.httpx.AsyncClient.send", |
| 24 | + new_callable=AsyncMock, |
| 25 | + ) as mock_send_method: |
| 26 | + yield mock_send_method |
| 27 | + |
| 28 | + |
| 29 | +@dataclass |
| 30 | +class SingleChunkAsyncStream(httpx.AsyncByteStream): |
| 31 | + """Mock async stream that returns a single chunk of data.""" |
| 32 | + |
| 33 | + body: bytes |
| 34 | + |
| 35 | + async def __aiter__(self): |
| 36 | + """Return a single chunk of data.""" |
| 37 | + yield self.body |
| 38 | + |
| 39 | + |
| 40 | +def test_collections_filter_contained_by_token( |
| 41 | + mock_send, source_api_server, token_builder |
| 42 | +): |
| 43 | + """Test that the collections filter is applied correctly.""" |
| 44 | + # Mock response from upstream API |
| 45 | + mock_send.return_value = httpx.Response( |
| 46 | + 200, |
| 47 | + stream=SingleChunkAsyncStream(b"{}"), |
| 48 | + headers={"content-type": "application/json"}, |
| 49 | + ) |
| 50 | + |
15 | 51 | app = app_factory(
|
16 | 52 | upstream_url=source_api_server,
|
17 | 53 | collections_filter={
|
18 | 54 | "cls": "stac_auth_proxy.filters.Template",
|
19 | 55 | "args": [
|
20 |
| - "A_CONTAINEDBY(id, ( '{{ token.collections | join(\"', '\") }}' ))" |
| 56 | + "A_CONTAINEDBY(id, ('{{ token.collections | join(\"', '\") }}' ))" |
21 | 57 | ],
|
22 | 58 | },
|
23 | 59 | )
|
| 60 | + |
| 61 | + auth_token = token_builder({"collections": ["foo", "bar"]}) |
24 | 62 | client = TestClient(
|
25 | 63 | app,
|
26 |
| - headers={ |
27 |
| - "Authorization": f"Bearer {token_builder({"collections": ["foo", "bar"]})}" |
28 |
| - }, |
| 64 | + headers={"Authorization": f"Bearer {auth_token}"}, |
29 | 65 | )
|
| 66 | + |
30 | 67 | response = client.get("/collections")
|
31 | 68 | assert response.status_code == 200
|
32 |
| - |
33 |
| - # TODO: We need to verify that the upstream API was called with an applied filter |
| 69 | + assert mock_send.call_count == 1 |
| 70 | + [r] = mock_send.call_args[0] |
| 71 | + assert parse_qs(r.url.query.decode()) == { |
| 72 | + "filter": ["a_containedby(id, ('foo', 'bar'))"] |
| 73 | + } |
0 commit comments