diff --git a/README.md b/README.md index 65a08baa..8eebc4cc 100644 --- a/README.md +++ b/README.md @@ -79,6 +79,7 @@ Typically settings are set by setting an environment variable with the same name | `static_registration_pin` | `int` | If set - all new EndDevice registrations will have their Registration PIN set to this value (use 5 digit form). Uses a random number generator otherwise. | | `nmi_validation_enabled` | `bool` | If `true` - all updates of `ConnectionPoint` resource will trigger validation on `ConnectionPoint.id` against on AEMO's NMI Allocation List (Version 13 – November 2022). Defaults to `false`. | | `nmi_validation_participant_id` | `str` | Specifies the Participant ID (DNSP-only) as defined in AEMO’s NMI Allocation List (Version 13 – November 2022). For entities without an official Participant ID, a custom identifier is used - refer to DNSPParticipantId for details. This setting is required if `nmi_validation_enabled` is `true`. | +| `exclude_endpoints` | `string` | JSON-encoded set of tuples of the form (HTTP Method, URI), each defining an endpoint which should be excluded from the App at runtime e.g. `[["GET", "/tm"], ["HEAD", "/tm"]]`. Optional. | **Additional Admin Server Settings (admin)** diff --git a/pyproject.toml b/pyproject.toml index 1ff12caa..0e537196 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ markers = [ "admin_ro_user: marks tests that install the admin server 'Read Only' user/passwords", "disable_device_registration: marks tests that disable NULL Aggregator and disable unrecognised devices from registering (equivalent to allow_device_registration = False)", "nmi_validation_enabled: marks tests that enables NMI validation logic for the PUT ConnectionPoint endpoint.", + "exclude_endpoints: marks test that excludes endpoints from the application" ] # (for pytests only) Using pytest-env to set placeholder values for required settings. diff --git a/src/envoy/server/endpoint_exclusion.py b/src/envoy/server/endpoint_exclusion.py new file mode 100644 index 00000000..d293fd6f --- /dev/null +++ b/src/envoy/server/endpoint_exclusion.py @@ -0,0 +1,74 @@ +from copy import deepcopy +from enum import Enum +import logging + +from starlette.routing import BaseRoute, Route +from fastapi import APIRouter + + +logging.basicConfig(style="{", level=logging.INFO) +logger = logging.getLogger(__name__) + + +# This should be replaced with http.HTTPMethod when this project is ported to Python 3.11 +class HTTPMethod(str, Enum): + DELETE = "DELETE" + GET = "GET" + HEAD = "HEAD" + POST = "POST" + PATCH = "PATCH" + PUT = "PUT" + + +class ExcludeEndpointException(Exception): ... # noqa: E701 + + +EndpointExclusionSet = set[tuple[HTTPMethod, str]] + + +def generate_routers_with_excluded_endpoints( + api_routers: list[APIRouter], exclude_endpoints: EndpointExclusionSet +) -> list[APIRouter]: + """Generates a new list of api routers with endpoint filters applied. Endpoint filters are defined as tuple + of HTTPMethod and URI string). A route is removed entirely if all it's available methods are removed. Validates all + endpoints before modifying routers. + + NOTE: This function should be called before routers are included in the FastAPI app. The assumption is that FastAPI + defers route registration and schema generation until routers are included. If this changes and internal state is + managed during APIRouter setup, this approach may need to be revisited. + + Raises: + ExcludeEndpointException: if any endpoints cannot be found across the given routers + """ + + logger.info(f"Disabling the following endpoints from routers: {exclude_endpoints}") + + # We deepcopy and mutate to avoid reconstruction (where we may miss metadata), should be safe. + routers = deepcopy(api_routers) + endpoint_filters = deepcopy(exclude_endpoints) + + for router in routers: + remaining_routes: list[BaseRoute] = [] + for route in router.routes: + if isinstance(route, Route) and route.methods: + remaining_methods: list[str] = [] + + # filtering route methods + for method in route.methods: + endpoint = (HTTPMethod(method), route.path) + if endpoint in endpoint_filters: + endpoint_filters.discard(endpoint) # tracking which filters have been applied. + else: + remaining_methods.append(method) + + # mutating route methods + route.methods = set(remaining_methods) + if route.methods: + remaining_routes.append(route) + router.routes = remaining_routes + + if endpoint_filters: + raise ExcludeEndpointException( + f"The following endpoints cannot be found in provided routers: {endpoint_filters}" + ) + return routers diff --git a/src/envoy/server/main.py b/src/envoy/server/main.py index 0619a7db..d6710669 100644 --- a/src/envoy/server/main.py +++ b/src/envoy/server/main.py @@ -22,6 +22,7 @@ from envoy.server.api.router import routers, unsecured_routers from envoy.server.database import enable_dynamic_azure_ad_database_credentials from envoy.server.lifespan import generate_combined_lifespan_manager +from envoy.server.endpoint_exclusion import generate_routers_with_excluded_endpoints from envoy.server.settings import AppSettings, settings # Setup logs @@ -78,7 +79,13 @@ def generate_app(new_settings: AppSettings) -> FastAPI: new_app = FastAPI(**new_settings.fastapi_kwargs, lifespan=generate_combined_lifespan_manager(lifespan_managers)) new_app.add_middleware(SQLAlchemyMiddleware, **new_settings.db_middleware_kwargs) - for router in routers: + # install routers + if new_settings.exclude_endpoints: + routers_to_include = generate_routers_with_excluded_endpoints(routers, new_settings.exclude_endpoints) + else: + routers_to_include = routers + + for router in routers_to_include: new_app.include_router(router, dependencies=global_dependencies) for router in unsecured_routers: new_app.include_router(router) diff --git a/src/envoy/server/settings.py b/src/envoy/server/settings.py index 34ed80e2..b5d7d14a 100644 --- a/src/envoy/server/settings.py +++ b/src/envoy/server/settings.py @@ -6,6 +6,7 @@ from pydantic import Field, model_validator from pydantic_settings import BaseSettings +from envoy.server.endpoint_exclusion import EndpointExclusionSet from envoy.server.manager.nmi_validator import NmiValidator, DNSPParticipantId from envoy.settings import CommonSettings @@ -61,6 +62,8 @@ class AppSettings(CommonSettings): nmi_validation: NmiValidationSettings = Field(default_factory=NmiValidationSettings) + exclude_endpoints: Optional[EndpointExclusionSet] = None + @property def fastapi_kwargs(self) -> Dict[str, Any]: return { diff --git a/tests/conftest.py b/tests/conftest.py index dbbbe5a2..bbef02bb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import json import os from decimal import Decimal from typing import Generator @@ -92,6 +93,10 @@ def pg_empty_config( else: os.environ["nmi_validation_enabled"] = "false" + exclude_endpoints_marker = request.node.get_closest_marker("exclude_endpoints") + if exclude_endpoints_marker is not None: + os.environ["exclude_endpoints"] = json.dumps(exclude_endpoints_marker.args[0]) + # This will install all of the alembic migrations - DB is accessed from the DATABASE_URL env variable upgrade() diff --git a/tests/integration/general/test_endpoint_exclusion.py b/tests/integration/general/test_endpoint_exclusion.py new file mode 100644 index 00000000..fcf8e1f3 --- /dev/null +++ b/tests/integration/general/test_endpoint_exclusion.py @@ -0,0 +1,63 @@ +from http import HTTPStatus +import urllib +import pytest + +from httpx import AsyncClient +from envoy_schema.server.schema.uri import TimeUri, EndDeviceUri + +from envoy.server.endpoint_exclusion import HTTPMethod + +from tests.data.certificates.certificate1 import TEST_CERTIFICATE_FINGERPRINT +from tests.integration.integration_server import cert_header + + +@pytest.mark.exclude_endpoints([(HTTPMethod.HEAD, TimeUri)]) +@pytest.mark.anyio +async def test_exclude_endpoints_no_method(client: AsyncClient): + """Test basic filter usage where route still exist but a specific method has been removed.""" + # Act + resp = await client.head(TimeUri, headers={cert_header: urllib.parse.quote(TEST_CERTIFICATE_FINGERPRINT)}) + + # Assert + assert resp.status_code == HTTPStatus.METHOD_NOT_ALLOWED + + +@pytest.mark.exclude_endpoints([(HTTPMethod.GET, EndDeviceUri)]) +@pytest.mark.anyio +async def test_exclude_endpoints_no_method_formattable_uri(client: AsyncClient): + """Test basic filter usage where route still exist but a specific method has been removed, + using a formattable uri.""" + # Act + resp = await client.get( + EndDeviceUri.format(site_id=1), # site included in base_config.sql + headers={cert_header: urllib.parse.quote(TEST_CERTIFICATE_FINGERPRINT)}, + ) + + # Assert + assert resp.status_code == HTTPStatus.METHOD_NOT_ALLOWED + + +@pytest.mark.exclude_endpoints([(HTTPMethod.HEAD, TimeUri), (HTTPMethod.GET, TimeUri)]) +@pytest.mark.anyio +async def test_exclude_endpoints_no_route(client: AsyncClient): + """Test where all methods of a route have been removed, expecting the entire route to be removed i.e. NOT FOUND""" + # Act + resp = await client.head(TimeUri, headers={cert_header: urllib.parse.quote(TEST_CERTIFICATE_FINGERPRINT)}) + + # Assert + assert resp.status_code == HTTPStatus.NOT_FOUND + + +@pytest.mark.exclude_endpoints( + [(HTTPMethod.HEAD, EndDeviceUri), (HTTPMethod.GET, EndDeviceUri), (HTTPMethod.DELETE, EndDeviceUri)] +) +@pytest.mark.anyio +async def test_exclude_endpoints_no_route_formattable_uri(client: AsyncClient): + """Test where all methods of a route have been removed, expecting the entire route to be removed i.e. NOT FOUND""" + # Act + resp = await client.head( + EndDeviceUri.format(site_id=1), headers={cert_header: urllib.parse.quote(TEST_CERTIFICATE_FINGERPRINT)} + ) + + # Assert + assert resp.status_code == HTTPStatus.NOT_FOUND diff --git a/tests/unit/server/test_endpoint_exclusion.py b/tests/unit/server/test_endpoint_exclusion.py new file mode 100644 index 00000000..0f288eca --- /dev/null +++ b/tests/unit/server/test_endpoint_exclusion.py @@ -0,0 +1,86 @@ +import pytest +from fastapi import APIRouter, FastAPI +from envoy.server.endpoint_exclusion import ( + ExcludeEndpointException, + generate_routers_with_excluded_endpoints, + HTTPMethod, +) + + +def test_generate_routers_with_excluded_endpoints(): + """Basic success test""" + # Arrange + router = APIRouter() + router.add_api_route("/somepath", lambda x: x, methods=[HTTPMethod.GET, HTTPMethod.HEAD]) + router.add_api_route("/someotherpath", lambda x: x, methods=[HTTPMethod.DELETE]) + + # Act + filtered_routers = generate_routers_with_excluded_endpoints([router], {(HTTPMethod.DELETE, "/someotherpath")}) + + # Assert + assert len(filtered_routers[0].routes) == 1 + assert filtered_routers[0].routes[0].path == "/somepath" + assert filtered_routers[0].routes[0].methods == {"GET", "HEAD"} + + +def test_generate_routers_with_excluded_endpoints_single_method(): + """Tests Disabling one method from a route with multiple methods""" + # Arrange + router = APIRouter() + router.add_api_route("/somepath", lambda x: x, methods=[HTTPMethod.GET, HTTPMethod.HEAD, HTTPMethod.DELETE]) + + # Act + filtered_routers = generate_routers_with_excluded_endpoints([router], {("DELETE", "/somepath")}) + + # Assert + assert len(filtered_routers[0].routes) == 1 + assert filtered_routers[0].routes[0].path == "/somepath" + assert filtered_routers[0].routes[0].methods == {"GET", "HEAD"} + + +def test_generate_routers_with_excluded_endpoints_raises_error_on_unmatched_endpoint(): + """Should raise error on unmatched endpoint""" + # Arrange + router = APIRouter() + router.add_api_route("/somepath", lambda x: x, methods=[HTTPMethod.GET, HTTPMethod.HEAD]) + + # Act / Assert + with pytest.raises(ExcludeEndpointException): + generate_routers_with_excluded_endpoints([router], {(HTTPMethod.DELETE, "/sometherepath")}) + + # Assert + assert len(router.routes) == 1 + assert router.routes[0].path == "/somepath" + assert router.routes[0].methods == {"GET", "HEAD"} + + +def test_generate_routers_with_excluded_endpoints_raises_error_no_side_effects(): + """Should raise error on unmatched endpoint, without side-effects""" + # Arrange + router = APIRouter() + router.add_api_route("/somepath", lambda x: x, methods=[HTTPMethod.GET, HTTPMethod.HEAD]) + + # Act / Assert + with pytest.raises(ExcludeEndpointException): + generate_routers_with_excluded_endpoints([router], {("HEAD", "/somepath"), ("GET", "/someotherepath")}) + + # Assert + assert len(router.routes) == 1 + assert router.routes[0].path == "/somepath" + assert router.routes[0].methods == {"GET", "HEAD"} + + +def test_generate_routers_with_excluded_endpoints_includes_successfully(): + """Tests no errors raised when modified router is added to app""" + # Arrange + app = FastAPI() + router = APIRouter() + router.add_api_route("/somepath", lambda x: x, methods=[HTTPMethod.GET, HTTPMethod.HEAD]) + + # Act + filtered_routers = generate_routers_with_excluded_endpoints([router], {("HEAD", "/somepath")}) + app.include_router(filtered_routers[0]) + + # Assert + route = [r for r in app.routes if r.path == "/somepath"].pop() + assert route.methods == {"GET"}