Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ Typically settings are set by setting an environment variable with the same name
| `static_registration_pin` | `int` | If set - all new EndDevice registrations will have their Registration PIN set to this value (use 5 digit form). Uses a random number generator otherwise. |
| `nmi_validation_enabled` | `bool` | If `true` - all updates of `ConnectionPoint` resource will trigger validation on `ConnectionPoint.id` against on AEMO's NMI Allocation List (Version 13 – November 2022). Defaults to `false`. |
| `nmi_validation_participant_id` | `str` | Specifies the Participant ID (DNSP-only) as defined in AEMO’s NMI Allocation List (Version 13 – November 2022). For entities without an official Participant ID, a custom identifier is used - refer to DNSPParticipantId for details. This setting is required if `nmi_validation_enabled` is `true`. |
| `exclude_endpoints` | `string` | JSON-encoded set of tuples of the form (HTTP Method, URI), each defining an endpoint which should be excluded from the App at runtime e.g. `[["GET", "/tm"], ["HEAD", "/tm"]]`. Optional. |

**Additional Admin Server Settings (admin)**

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ markers = [
"admin_ro_user: marks tests that install the admin server 'Read Only' user/passwords",
"disable_device_registration: marks tests that disable NULL Aggregator and disable unrecognised devices from registering (equivalent to allow_device_registration = False)",
"nmi_validation_enabled: marks tests that enables NMI validation logic for the PUT ConnectionPoint endpoint.",
"exclude_endpoints: marks test that excludes endpoints from the application"
]

# (for pytests only) Using pytest-env to set placeholder values for required settings.
Expand Down
74 changes: 74 additions & 0 deletions src/envoy/server/endpoint_exclusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from copy import deepcopy
from enum import Enum
import logging

from starlette.routing import BaseRoute, Route
from fastapi import APIRouter


logging.basicConfig(style="{", level=logging.INFO)
logger = logging.getLogger(__name__)


# This should be replaced with http.HTTPMethod when this project is ported to Python 3.11
class HTTPMethod(str, Enum):
DELETE = "DELETE"
GET = "GET"
HEAD = "HEAD"
POST = "POST"
PATCH = "PATCH"
PUT = "PUT"


class ExcludeEndpointException(Exception): ... # noqa: E701


EndpointExclusionSet = set[tuple[HTTPMethod, str]]


def generate_routers_with_excluded_endpoints(
api_routers: list[APIRouter], exclude_endpoints: EndpointExclusionSet
) -> list[APIRouter]:
"""Generates a new list of api routers with endpoint filters applied. Endpoint filters are defined as tuple
of HTTPMethod and URI string). A route is removed entirely if all it's available methods are removed. Validates all
endpoints before modifying routers.

NOTE: This function should be called before routers are included in the FastAPI app. The assumption is that FastAPI
defers route registration and schema generation until routers are included. If this changes and internal state is
managed during APIRouter setup, this approach may need to be revisited.

Raises:
ExcludeEndpointException: if any endpoints cannot be found across the given routers
"""

logger.info(f"Disabling the following endpoints from routers: {exclude_endpoints}")

# We deepcopy and mutate to avoid reconstruction (where we may miss metadata), should be safe.
routers = deepcopy(api_routers)
endpoint_filters = deepcopy(exclude_endpoints)

for router in routers:
remaining_routes: list[BaseRoute] = []
for route in router.routes:
if isinstance(route, Route) and route.methods:
remaining_methods: list[str] = []

# filtering route methods
for method in route.methods:
endpoint = (HTTPMethod(method), route.path)
if endpoint in endpoint_filters:
endpoint_filters.discard(endpoint) # tracking which filters have been applied.
else:
remaining_methods.append(method)

# mutating route methods
route.methods = set(remaining_methods)
if route.methods:
remaining_routes.append(route)
router.routes = remaining_routes

if endpoint_filters:
raise ExcludeEndpointException(
f"The following endpoints cannot be found in provided routers: {endpoint_filters}"
)
return routers
9 changes: 8 additions & 1 deletion src/envoy/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from envoy.server.api.router import routers, unsecured_routers
from envoy.server.database import enable_dynamic_azure_ad_database_credentials
from envoy.server.lifespan import generate_combined_lifespan_manager
from envoy.server.endpoint_exclusion import generate_routers_with_excluded_endpoints
from envoy.server.settings import AppSettings, settings

# Setup logs
Expand Down Expand Up @@ -78,7 +79,13 @@ def generate_app(new_settings: AppSettings) -> FastAPI:
new_app = FastAPI(**new_settings.fastapi_kwargs, lifespan=generate_combined_lifespan_manager(lifespan_managers))
new_app.add_middleware(SQLAlchemyMiddleware, **new_settings.db_middleware_kwargs)

for router in routers:
# install routers
if new_settings.exclude_endpoints:
routers_to_include = generate_routers_with_excluded_endpoints(routers, new_settings.exclude_endpoints)
else:
routers_to_include = routers

for router in routers_to_include:
new_app.include_router(router, dependencies=global_dependencies)
for router in unsecured_routers:
new_app.include_router(router)
Expand Down
3 changes: 3 additions & 0 deletions src/envoy/server/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pydantic import Field, model_validator
from pydantic_settings import BaseSettings

from envoy.server.endpoint_exclusion import EndpointExclusionSet
from envoy.server.manager.nmi_validator import NmiValidator, DNSPParticipantId
from envoy.settings import CommonSettings

Expand Down Expand Up @@ -61,6 +62,8 @@ class AppSettings(CommonSettings):

nmi_validation: NmiValidationSettings = Field(default_factory=NmiValidationSettings)

exclude_endpoints: Optional[EndpointExclusionSet] = None

@property
def fastapi_kwargs(self) -> Dict[str, Any]:
return {
Expand Down
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
from decimal import Decimal
from typing import Generator
Expand Down Expand Up @@ -92,6 +93,10 @@ def pg_empty_config(
else:
os.environ["nmi_validation_enabled"] = "false"

exclude_endpoints_marker = request.node.get_closest_marker("exclude_endpoints")
if exclude_endpoints_marker is not None:
os.environ["exclude_endpoints"] = json.dumps(exclude_endpoints_marker.args[0])

# This will install all of the alembic migrations - DB is accessed from the DATABASE_URL env variable
upgrade()

Expand Down
63 changes: 63 additions & 0 deletions tests/integration/general/test_endpoint_exclusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from http import HTTPStatus
import urllib
import pytest

from httpx import AsyncClient
from envoy_schema.server.schema.uri import TimeUri, EndDeviceUri

from envoy.server.endpoint_exclusion import HTTPMethod

from tests.data.certificates.certificate1 import TEST_CERTIFICATE_FINGERPRINT
from tests.integration.integration_server import cert_header


@pytest.mark.exclude_endpoints([(HTTPMethod.HEAD, TimeUri)])
@pytest.mark.anyio
async def test_exclude_endpoints_no_method(client: AsyncClient):
"""Test basic filter usage where route still exist but a specific method has been removed."""
# Act
resp = await client.head(TimeUri, headers={cert_header: urllib.parse.quote(TEST_CERTIFICATE_FINGERPRINT)})

# Assert
assert resp.status_code == HTTPStatus.METHOD_NOT_ALLOWED


@pytest.mark.exclude_endpoints([(HTTPMethod.GET, EndDeviceUri)])
@pytest.mark.anyio
async def test_exclude_endpoints_no_method_formattable_uri(client: AsyncClient):
"""Test basic filter usage where route still exist but a specific method has been removed,
using a formattable uri."""
# Act
resp = await client.get(
EndDeviceUri.format(site_id=1), # site included in base_config.sql
headers={cert_header: urllib.parse.quote(TEST_CERTIFICATE_FINGERPRINT)},
)

# Assert
assert resp.status_code == HTTPStatus.METHOD_NOT_ALLOWED


@pytest.mark.exclude_endpoints([(HTTPMethod.HEAD, TimeUri), (HTTPMethod.GET, TimeUri)])
@pytest.mark.anyio
async def test_exclude_endpoints_no_route(client: AsyncClient):
"""Test where all methods of a route have been removed, expecting the entire route to be removed i.e. NOT FOUND"""
# Act
resp = await client.head(TimeUri, headers={cert_header: urllib.parse.quote(TEST_CERTIFICATE_FINGERPRINT)})

# Assert
assert resp.status_code == HTTPStatus.NOT_FOUND


@pytest.mark.exclude_endpoints(
[(HTTPMethod.HEAD, EndDeviceUri), (HTTPMethod.GET, EndDeviceUri), (HTTPMethod.DELETE, EndDeviceUri)]
)
@pytest.mark.anyio
async def test_exclude_endpoints_no_route_formattable_uri(client: AsyncClient):
"""Test where all methods of a route have been removed, expecting the entire route to be removed i.e. NOT FOUND"""
# Act
resp = await client.head(
EndDeviceUri.format(site_id=1), headers={cert_header: urllib.parse.quote(TEST_CERTIFICATE_FINGERPRINT)}
)

# Assert
assert resp.status_code == HTTPStatus.NOT_FOUND
86 changes: 86 additions & 0 deletions tests/unit/server/test_endpoint_exclusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import pytest
from fastapi import APIRouter, FastAPI
from envoy.server.endpoint_exclusion import (
ExcludeEndpointException,
generate_routers_with_excluded_endpoints,
HTTPMethod,
)


def test_generate_routers_with_excluded_endpoints():
"""Basic success test"""
# Arrange
router = APIRouter()
router.add_api_route("/somepath", lambda x: x, methods=[HTTPMethod.GET, HTTPMethod.HEAD])
router.add_api_route("/someotherpath", lambda x: x, methods=[HTTPMethod.DELETE])

# Act
filtered_routers = generate_routers_with_excluded_endpoints([router], {(HTTPMethod.DELETE, "/someotherpath")})

# Assert
assert len(filtered_routers[0].routes) == 1
assert filtered_routers[0].routes[0].path == "/somepath"
assert filtered_routers[0].routes[0].methods == {"GET", "HEAD"}


def test_generate_routers_with_excluded_endpoints_single_method():
"""Tests Disabling one method from a route with multiple methods"""
# Arrange
router = APIRouter()
router.add_api_route("/somepath", lambda x: x, methods=[HTTPMethod.GET, HTTPMethod.HEAD, HTTPMethod.DELETE])

# Act
filtered_routers = generate_routers_with_excluded_endpoints([router], {("DELETE", "/somepath")})

# Assert
assert len(filtered_routers[0].routes) == 1
assert filtered_routers[0].routes[0].path == "/somepath"
assert filtered_routers[0].routes[0].methods == {"GET", "HEAD"}


def test_generate_routers_with_excluded_endpoints_raises_error_on_unmatched_endpoint():
"""Should raise error on unmatched endpoint"""
# Arrange
router = APIRouter()
router.add_api_route("/somepath", lambda x: x, methods=[HTTPMethod.GET, HTTPMethod.HEAD])

# Act / Assert
with pytest.raises(ExcludeEndpointException):
generate_routers_with_excluded_endpoints([router], {(HTTPMethod.DELETE, "/sometherepath")})

# Assert
assert len(router.routes) == 1
assert router.routes[0].path == "/somepath"
assert router.routes[0].methods == {"GET", "HEAD"}


def test_generate_routers_with_excluded_endpoints_raises_error_no_side_effects():
"""Should raise error on unmatched endpoint, without side-effects"""
# Arrange
router = APIRouter()
router.add_api_route("/somepath", lambda x: x, methods=[HTTPMethod.GET, HTTPMethod.HEAD])

# Act / Assert
with pytest.raises(ExcludeEndpointException):
generate_routers_with_excluded_endpoints([router], {("HEAD", "/somepath"), ("GET", "/someotherepath")})

# Assert
assert len(router.routes) == 1
assert router.routes[0].path == "/somepath"
assert router.routes[0].methods == {"GET", "HEAD"}


def test_generate_routers_with_excluded_endpoints_includes_successfully():
"""Tests no errors raised when modified router is added to app"""
# Arrange
app = FastAPI()
router = APIRouter()
router.add_api_route("/somepath", lambda x: x, methods=[HTTPMethod.GET, HTTPMethod.HEAD])

# Act
filtered_routers = generate_routers_with_excluded_endpoints([router], {("HEAD", "/somepath")})
app.include_router(filtered_routers[0])

# Assert
route = [r for r in app.routes if r.path == "/somepath"].pop()
assert route.methods == {"GET"}
Loading