From d1ef75332605b64b658314d833805b1f83d013e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janek=20Nouvertn=C3=A9?= Date: Sat, 24 Aug 2024 16:46:17 +0200 Subject: [PATCH] fix: OpenTelemetry doesn't capture exceptions in the outermost application layer (#3689) * fix(opentelemetry): wrap all middleware under Otel middleware to ensure spans are created and exceptions are logged correctly under that span --------- Co-authored-by: bella --- docs/usage/metrics/open-telemetry.rst | 4 +- litestar/app.py | 29 ++- litestar/contrib/opentelemetry/__init__.py | 7 +- litestar/contrib/opentelemetry/_utils.py | 10 +- litestar/contrib/opentelemetry/plugin.py | 50 +++++ tests/unit/test_contrib/test_opentelemetry.py | 206 +++++++++++++++--- 6 files changed, 267 insertions(+), 39 deletions(-) create mode 100644 litestar/contrib/opentelemetry/plugin.py diff --git a/docs/usage/metrics/open-telemetry.rst b/docs/usage/metrics/open-telemetry.rst index 27b69677ae..495eabf823 100644 --- a/docs/usage/metrics/open-telemetry.rst +++ b/docs/usage/metrics/open-telemetry.rst @@ -22,11 +22,11 @@ the Litestar constructor: .. code-block:: python from litestar import Litestar - from litestar.contrib.opentelemetry import OpenTelemetryConfig + from litestar.contrib.opentelemetry import OpenTelemetryConfig, OpenTelemetryPlugin open_telemetry_config = OpenTelemetryConfig() - app = Litestar(middleware=[open_telemetry_config.middleware]) + app = Litestar(plugins=[OpenTelemetryPlugin(open_telemetry_config)]) The above example will work out of the box if you configure a global ``tracer_provider`` and/or ``metric_provider`` and an exporter to use these (see the diff --git a/litestar/app.py b/litestar/app.py index e17b8d00c4..b17bff272a 100644 --- a/litestar/app.py +++ b/litestar/app.py @@ -60,6 +60,7 @@ from litestar.config.compression import CompressionConfig from litestar.config.cors import CORSConfig from litestar.config.csrf import CSRFConfig + from litestar.contrib.opentelemetry import OpenTelemetryPlugin from litestar.datastructures import CacheControlHeader, ETag from litestar.dto import AbstractDTO from litestar.events.listener import EventListener @@ -385,8 +386,10 @@ def __init__( for handler in chain( on_app_init or [], (p.on_app_init for p in config.plugins if isinstance(p, InitPluginProtocol)), + [self._patch_opentelemetry_middleware], ): config = handler(config) # pyright: ignore + self.plugins = PluginRegistry(config.plugins) self._openapi_schema: OpenAPI | None = None @@ -446,7 +449,6 @@ def __init__( config.exception_handlers.setdefault(StarletteHTTPException, _starlette_exception_handler) except ImportError: pass - super().__init__( after_request=config.after_request, after_response=config.after_response, @@ -492,6 +494,23 @@ def __init__( self.asgi_handler = self._create_asgi_handler() + @staticmethod + def _patch_opentelemetry_middleware(config: AppConfig) -> AppConfig: + # workaround to support otel middleware priority. Should be replaced by regular + # middleware priorities once available + try: + from litestar.contrib.opentelemetry import OpenTelemetryPlugin + + if not any(isinstance(p, OpenTelemetryPlugin) for p in config.plugins): + config.middleware, otel_middleware = OpenTelemetryPlugin._pop_otel_middleware(config.middleware) + if otel_middleware: + otel_plugin = OpenTelemetryPlugin() + otel_plugin._middleware = otel_middleware + config.plugins = [*config.plugins, otel_plugin] + except ImportError: + pass + return config + @property @deprecated(version="2.6.0", kind="property", info="Use create_static_files router instead") def static_files_config(self) -> list[StaticFilesConfig]: @@ -843,7 +862,13 @@ def _create_asgi_handler(self) -> ASGIApp: asgi_handler = wrap_in_exception_handler(app=self.asgi_router) if self.cors_config: - return CORSMiddleware(app=asgi_handler, config=self.cors_config) + asgi_handler = CORSMiddleware(app=asgi_handler, config=self.cors_config) + + try: + otel_plugin: OpenTelemetryPlugin = self.plugins.get("OpenTelemetryPlugin") + asgi_handler = otel_plugin.middleware(app=asgi_handler) + except KeyError: + pass return asgi_handler diff --git a/litestar/contrib/opentelemetry/__init__.py b/litestar/contrib/opentelemetry/__init__.py index 3f936110d2..983b777754 100644 --- a/litestar/contrib/opentelemetry/__init__.py +++ b/litestar/contrib/opentelemetry/__init__.py @@ -1,4 +1,9 @@ from .config import OpenTelemetryConfig from .middleware import OpenTelemetryInstrumentationMiddleware +from .plugin import OpenTelemetryPlugin -__all__ = ("OpenTelemetryConfig", "OpenTelemetryInstrumentationMiddleware") +__all__ = ( + "OpenTelemetryConfig", + "OpenTelemetryInstrumentationMiddleware", + "OpenTelemetryPlugin", +) diff --git a/litestar/contrib/opentelemetry/_utils.py b/litestar/contrib/opentelemetry/_utils.py index 0ba7cb9960..66ba442c9a 100644 --- a/litestar/contrib/opentelemetry/_utils.py +++ b/litestar/contrib/opentelemetry/_utils.py @@ -27,5 +27,11 @@ def get_route_details_from_scope(scope: Scope) -> tuple[str, dict[Any, str]]: Returns: A tuple of the span name and a dict of attrs. """ - route_handler_fn_name = scope["route_handler"].handler_name - return route_handler_fn_name, {SpanAttributes.HTTP_ROUTE: route_handler_fn_name} + + path = scope.get("path", "").strip() + method = str(scope.get("method", "")).strip() + + if method and path: # http + return f"{method} {path}", {SpanAttributes.HTTP_ROUTE: f"{method} {path}"} + + return path, {SpanAttributes.HTTP_ROUTE: path} # websocket diff --git a/litestar/contrib/opentelemetry/plugin.py b/litestar/contrib/opentelemetry/plugin.py new file mode 100644 index 0000000000..b8f60d6d5b --- /dev/null +++ b/litestar/contrib/opentelemetry/plugin.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from litestar.contrib.opentelemetry.config import OpenTelemetryConfig +from litestar.contrib.opentelemetry.middleware import OpenTelemetryInstrumentationMiddleware +from litestar.middleware.base import DefineMiddleware +from litestar.plugins import InitPluginProtocol + +if TYPE_CHECKING: + from litestar.config.app import AppConfig + from litestar.types.composite_types import Middleware + + +class OpenTelemetryPlugin(InitPluginProtocol): + """OpenTelemetry Plugin.""" + + __slots__ = ("config", "_middleware") + + def __init__(self, config: OpenTelemetryConfig | None = None) -> None: + self.config = config or OpenTelemetryConfig() + self._middleware: DefineMiddleware | None = None + super().__init__() + + @property + def middleware(self) -> DefineMiddleware: + if self._middleware: + return self._middleware + return DefineMiddleware(OpenTelemetryInstrumentationMiddleware, config=self.config) + + def on_app_init(self, app_config: AppConfig) -> AppConfig: + app_config.middleware, _middleware = self._pop_otel_middleware(app_config.middleware) + return app_config + + @staticmethod + def _pop_otel_middleware(middlewares: list[Middleware]) -> tuple[list[Middleware], DefineMiddleware | None]: + """Get the OpenTelemetry middleware if it is enabled in the application. + Remove the middleware from the list of middlewares if it is found. + """ + otel_middleware: DefineMiddleware | None = None + other_middlewares = [] + for middleware in middlewares: + if ( + isinstance(middleware, DefineMiddleware) + and middleware.middleware is OpenTelemetryInstrumentationMiddleware + ): + otel_middleware = middleware + else: + other_middlewares.append(middleware) + return other_middlewares, otel_middleware diff --git a/tests/unit/test_contrib/test_opentelemetry.py b/tests/unit/test_contrib/test_opentelemetry.py index 2afc46f6ab..f34907ef7d 100644 --- a/tests/unit/test_contrib/test_opentelemetry.py +++ b/tests/unit/test_contrib/test_opentelemetry.py @@ -1,5 +1,7 @@ -from typing import Any, Tuple, cast +from typing import Tuple, cast +import pytest +from _pytest.fixtures import FixtureRequest from opentelemetry.metrics import get_meter_provider, set_meter_provider from opentelemetry.sdk.metrics._internal import MeterProvider from opentelemetry.sdk.metrics._internal.aggregation import ( @@ -13,48 +15,64 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from litestar import WebSocket, get, websocket -from litestar.contrib.opentelemetry import OpenTelemetryConfig +from litestar.config.app import AppConfig +from litestar.contrib.opentelemetry import OpenTelemetryConfig, OpenTelemetryPlugin +from litestar.exceptions import http_exceptions from litestar.status_codes import HTTP_200_OK from litestar.testing import create_test_client +from litestar.types.asgi_types import ASGIApp, Receive, Scope, Send -def create_config(**kwargs: Any) -> Tuple[OpenTelemetryConfig, InMemoryMetricReader, InMemorySpanExporter]: - """Create OpenTelemetryConfig, an InMemoryMetricReader and InMemorySpanExporter. +@pytest.fixture(scope="session") +def resource() -> Resource: + return Resource(attributes={SERVICE_NAME: "litestar-test"}) - Args: - **kwargs: Any config kwargs to pass to the OpenTelemetryConfig constructor. - - Returns: - A tuple containing an OpenTelemetryConfig, an InMemoryMetricReader and InMemorySpanExporter. - """ - resource = Resource(attributes={SERVICE_NAME: "litestar-test"}) - tracer_provider = TracerProvider(resource=resource) - exporter = InMemorySpanExporter() - tracer_provider.add_span_processor(SimpleSpanProcessor(exporter)) +@pytest.fixture(scope="session") +def reader() -> InMemoryMetricReader: aggregation_last_value = {Counter: ExplicitBucketHistogramAggregation()} - reader = InMemoryMetricReader(preferred_aggregation=aggregation_last_value) # type: ignore[arg-type] - meter_provider = MeterProvider(resource=resource, metric_readers=[reader]) + return InMemoryMetricReader(preferred_aggregation=aggregation_last_value) # type: ignore[arg-type] + + +@pytest.fixture(scope="session") +def meter_provider(resource: Resource, reader: InMemoryMetricReader) -> MeterProvider: + provider = MeterProvider(resource=resource, metric_readers=[reader]) + set_meter_provider(provider) + return provider - set_meter_provider(meter_provider) - meter = get_meter_provider().get_meter("litestar-test") +@pytest.fixture() +def exporter() -> InMemorySpanExporter: + return InMemorySpanExporter() - return ( - OpenTelemetryConfig(tracer_provider=tracer_provider, meter=meter, **kwargs), - reader, - exporter, - ) + +@pytest.fixture() +def config( + resource: Resource, exporter: InMemorySpanExporter, meter_provider: MeterProvider, request: FixtureRequest +) -> OpenTelemetryConfig: + tracer_provider = TracerProvider(resource=resource) + tracer_provider.add_span_processor(SimpleSpanProcessor(exporter)) + meter = get_meter_provider().get_meter(f"litestar-test-{request.node.nodeid}") + return OpenTelemetryConfig(tracer_provider=tracer_provider, meter=meter) -def test_open_telemetry_middleware_with_http_route() -> None: - config, reader, exporter = create_config() +@pytest.fixture(params=["middleware", "plugin"]) +def app_config(request: FixtureRequest, config: OpenTelemetryConfig) -> AppConfig: + if request.param == "middleware": + return AppConfig(middleware=[config.middleware]) + return AppConfig(plugins=[OpenTelemetryPlugin(config)]) + +def test_open_telemetry_middleware_with_http_route( + app_config: AppConfig, + reader: InMemoryMetricReader, + exporter: InMemorySpanExporter, +) -> None: @get("/") def handler() -> dict: return {"hello": "world"} - with create_test_client(handler, middleware=[config.middleware]) as client: + with create_test_client(handler, middleware=app_config.middleware, plugins=app_config.plugins) as client: response = client.get("/") assert response.status_code == HTTP_200_OK assert reader.get_metrics_data() @@ -74,7 +92,7 @@ def handler() -> dict: "http.user_agent": "testclient", "net.peer.ip": "testclient", "net.peer.port": 50000, - "http.route": "handler", + "http.route": "GET /", "http.status_code": 200, } @@ -91,16 +109,20 @@ def handler() -> dict: assert len(list(request_metric.data.data_points)) == 1 -def test_open_telemetry_middleware_with_websocket_route() -> None: - config, reader, exporter = create_config() - +def test_open_telemetry_middleware_with_websocket_route( + app_config: AppConfig, + reader: InMemoryMetricReader, + exporter: InMemorySpanExporter, +) -> None: @websocket("/") async def handler(socket: "WebSocket") -> None: await socket.accept() await socket.send_json({"hello": "world"}) await socket.close() - with create_test_client(handler, middleware=[config.middleware]).websocket_connect("/") as client: + with create_test_client(handler, middleware=app_config.middleware, plugins=app_config.plugins).websocket_connect( + "/" + ) as client: data = client.receive_json() assert data == {"hello": "world"} @@ -121,6 +143,126 @@ async def handler(socket: "WebSocket") -> None: "http.user_agent": "testclient", "net.peer.ip": "testclient", "net.peer.port": 50000, - "http.route": "handler", + "http.route": "/", "http.status_code": 200, } + + +def test_open_telemetry_middleware_handles_route_not_found_under_span_http( + app_config: AppConfig, + reader: InMemoryMetricReader, + exporter: InMemorySpanExporter, +) -> None: + @get("/") + def handler() -> dict: + raise Exception("random Exception") + + with create_test_client(handler, middleware=app_config.middleware, plugins=app_config.plugins) as client: + response = client.get("/route_that_does_not_exist") + assert response.status_code + + first_span, second_span, third_span = cast("Tuple[Span, Span, Span]", exporter.get_finished_spans()) + assert dict(first_span.attributes) == { # type: ignore[arg-type] + "http.status_code": 404, + "asgi.event.type": "http.response.start", + } + assert dict(second_span.attributes) == {"asgi.event.type": "http.response.body"} # type: ignore[arg-type] + assert dict(third_span.attributes) == { # type: ignore[arg-type] + "http.scheme": "http", + "http.host": "testserver.local", + "net.host.port": 80, + "http.flavor": "1.1", + "http.target": "/route_that_does_not_exist", + "http.url": "http://testserver.local/route_that_does_not_exist", + "http.method": "GET", + "http.server_name": "testserver.local", + "http.user_agent": "testclient", + "net.peer.ip": "testclient", + "net.peer.port": 50000, + "http.route": "GET /route_that_does_not_exist", + "http.status_code": 404, + } + + +def test_open_telemetry_middleware_handles_method_not_allowed_under_span_http( + app_config: AppConfig, + reader: InMemoryMetricReader, + exporter: InMemorySpanExporter, +) -> None: + @get("/") + def handler() -> dict: + raise Exception("random Exception") + + with create_test_client(handler, middleware=app_config.middleware, plugins=app_config.plugins) as client: + response = client.post("/") + assert response.status_code + + first_span, second_span, third_span = cast("Tuple[Span, Span, Span]", exporter.get_finished_spans()) + assert dict(first_span.attributes) == { # type: ignore[arg-type] + "http.status_code": 405, + "asgi.event.type": "http.response.start", + } + assert dict(second_span.attributes) == {"asgi.event.type": "http.response.body"} # type: ignore[arg-type] + assert dict(third_span.attributes) == { # type: ignore[arg-type] + "http.scheme": "http", + "http.host": "testserver.local", + "net.host.port": 80, + "http.flavor": "1.1", + "http.target": "/", + "http.url": "http://testserver.local/", + "http.method": "POST", + "http.server_name": "testserver.local", + "http.user_agent": "testclient", + "net.peer.ip": "testclient", + "net.peer.port": 50000, + "http.route": "POST /", + "http.status_code": 405, + } + + +def test_open_telemetry_middleware_handles_errors_caused_on_middleware( + app_config: AppConfig, + reader: InMemoryMetricReader, + exporter: InMemorySpanExporter, +) -> None: + raise_exception = True + + def middleware_factory(app: ASGIApp) -> ASGIApp: + async def error_middleware(scope: Scope, receive: Receive, send: Send) -> None: + if raise_exception: + raise http_exceptions.NotAuthorizedException() + await app(scope, receive, send) + + return error_middleware + + @get("/") + def handler() -> dict: + raise Exception("random Exception") + + with create_test_client( + handler, middleware=[middleware_factory, *app_config.middleware], plugins=app_config.plugins + ) as client: + response = client.get("/") + assert response.status_code + + first_span, second_span, third_span = cast("Tuple[Span, Span, Span]", exporter.get_finished_spans()) + assert dict(first_span.attributes) == { # type: ignore[arg-type] + "http.status_code": 401, + "asgi.event.type": "http.response.start", + } + assert dict(second_span.attributes) == {"asgi.event.type": "http.response.body"} # type: ignore[arg-type] + assert dict(third_span.attributes) == { # type: ignore[arg-type] + "http.scheme": "http", + "http.host": "testserver.local", + "net.host.port": 80, + "http.flavor": "1.1", + "http.target": "/", + "http.url": "http://testserver.local/", + "http.method": "GET", + "http.server_name": "testserver.local", + "http.user_agent": "testclient", + "net.peer.ip": "testclient", + "net.peer.port": 50000, + "http.route": "GET /", + "http.status_code": 401, + }