diff --git a/docs/examples/contrib/prometheus/__init__.py b/docs/examples/plugins/prometheus/__init__.py similarity index 100% rename from docs/examples/contrib/prometheus/__init__.py rename to docs/examples/plugins/prometheus/__init__.py diff --git a/docs/examples/contrib/prometheus/using_prometheus_exporter.py b/docs/examples/plugins/prometheus/using_prometheus_exporter.py similarity index 90% rename from docs/examples/contrib/prometheus/using_prometheus_exporter.py rename to docs/examples/plugins/prometheus/using_prometheus_exporter.py index 99ed912be4..28738e9d5a 100644 --- a/docs/examples/contrib/prometheus/using_prometheus_exporter.py +++ b/docs/examples/plugins/prometheus/using_prometheus_exporter.py @@ -1,5 +1,5 @@ from litestar import Litestar -from litestar.contrib.prometheus import PrometheusConfig, PrometheusController +from litestar.plugins.prometheus import PrometheusConfig, PrometheusController def create_app(group_path: bool = False): diff --git a/docs/examples/contrib/prometheus/using_prometheus_exporter_with_extra_configs.py b/docs/examples/plugins/prometheus/using_prometheus_exporter_with_extra_configs.py similarity index 92% rename from docs/examples/contrib/prometheus/using_prometheus_exporter_with_extra_configs.py rename to docs/examples/plugins/prometheus/using_prometheus_exporter_with_extra_configs.py index d14f4e92a1..913d93748e 100644 --- a/docs/examples/contrib/prometheus/using_prometheus_exporter_with_extra_configs.py +++ b/docs/examples/plugins/prometheus/using_prometheus_exporter_with_extra_configs.py @@ -1,7 +1,7 @@ from typing import Any, Dict from litestar import Litestar, Request -from litestar.contrib.prometheus import PrometheusConfig, PrometheusController +from litestar.plugins.prometheus import PrometheusConfig, PrometheusController # We can modify the path of our custom handler and override the metrics format by subclassing the PrometheusController. @@ -38,7 +38,7 @@ def custom_exemplar(request: Request[Any, Any, Any]) -> Dict[str, str]: app_name="litestar-example", prefix="litestar", labels=extra_labels, - buckets=buckets, + buckets=buckets, # pyright: ignore[reportArgumentType] exemplars=custom_exemplar, excluded_http_methods=["POST"], ) diff --git a/docs/reference/plugins/index.rst b/docs/reference/plugins/index.rst index 79365fa48e..50bf99e769 100644 --- a/docs/reference/plugins/index.rst +++ b/docs/reference/plugins/index.rst @@ -12,6 +12,7 @@ plugins flash_messages htmx problem_details + prometheus pydantic structlog sqlalchemy diff --git a/docs/reference/plugins/prometheus.rst b/docs/reference/plugins/prometheus.rst new file mode 100644 index 0000000000..c45052a9f8 --- /dev/null +++ b/docs/reference/plugins/prometheus.rst @@ -0,0 +1,5 @@ +prometheus +========== + +.. automodule:: litestar.plugins.prometheus + :members: diff --git a/docs/usage/metrics/prometheus.rst b/docs/usage/metrics/prometheus.rst index 766de0a2fb..49db6555a0 100644 --- a/docs/usage/metrics/prometheus.rst +++ b/docs/usage/metrics/prometheus.rst @@ -1,7 +1,7 @@ Prometheus ========== -Litestar includes optional Prometheus exporter that is exported from ``litestar.contrib.prometheus``. To use +Litestar includes optional Prometheus exporter that is exported from ``litestar.plugins.prometheus``. To use this package, you should first install the required dependencies: .. code-block:: bash @@ -17,12 +17,12 @@ this package, you should first install the required dependencies: Once these requirements are satisfied, you can instrument your Litestar application: -.. literalinclude:: /examples/contrib/prometheus/using_prometheus_exporter.py +.. literalinclude:: /examples/plugins/prometheus/using_prometheus_exporter.py :language: python :caption: Using the Prometheus Exporter You can also customize the configuration: -.. literalinclude:: /examples/contrib/prometheus/using_prometheus_exporter_with_extra_configs.py +.. literalinclude:: /examples/plugins/prometheus/using_prometheus_exporter_with_extra_configs.py :language: python :caption: Configuring the Prometheus Exporter diff --git a/litestar/contrib/prometheus/__init__.py b/litestar/contrib/prometheus/__init__.py index 1ccb494695..bedeec9976 100644 --- a/litestar/contrib/prometheus/__init__.py +++ b/litestar/contrib/prometheus/__init__.py @@ -1,5 +1,38 @@ -from .config import PrometheusConfig -from .controller import PrometheusController -from .middleware import PrometheusMiddleware +# ruff: noqa: TCH004, F401 +from __future__ import annotations + +from typing import TYPE_CHECKING + +from litestar.utils import warn_deprecation __all__ = ("PrometheusMiddleware", "PrometheusConfig", "PrometheusController") + + +def __getattr__(attr_name: str) -> object: + if attr_name in __all__: + from litestar.plugins.prometheus import ( + PrometheusConfig, + PrometheusController, + PrometheusMiddleware, + ) + + warn_deprecation( + deprecated_name=f"litestar.contrib.prometheus.{attr_name}", + version="2.13.0", + kind="import", + removal_in="3.0", + info=f"importing {attr_name} from 'litestar.contrib.prometheus' is deprecated, please " + f"import it from 'litestar.plugins.prometheus' instead", + ) + value = globals()[attr_name] = locals()[attr_name] + return value + + raise AttributeError(f"module {__name__!r} has no attribute {attr_name!r}") # pragma: no cover + + +if TYPE_CHECKING: + from litestar.plugins.prometheus import ( + PrometheusConfig, + PrometheusController, + PrometheusMiddleware, + ) diff --git a/litestar/contrib/prometheus/config.py b/litestar/contrib/prometheus/config.py index 6b0ceb6409..b24ec5340b 100644 --- a/litestar/contrib/prometheus/config.py +++ b/litestar/contrib/prometheus/config.py @@ -1,67 +1,30 @@ +# ruff: noqa: TCH004, F401 from __future__ import annotations -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Callable, Mapping, Sequence +from typing import TYPE_CHECKING -from litestar.contrib.prometheus.middleware import ( - PrometheusMiddleware, -) -from litestar.exceptions import MissingDependencyException -from litestar.middleware.base import DefineMiddleware +from litestar.utils import warn_deprecation __all__ = ("PrometheusConfig",) -try: - import prometheus_client # noqa: F401 -except ImportError as e: - raise MissingDependencyException("prometheus_client", "prometheus-client", "prometheus") from e +def __getattr__(attr_name: str) -> object: + if attr_name in __all__: + from litestar.plugins.prometheus import PrometheusConfig + warn_deprecation( + deprecated_name=f"litestar.contrib.prometheus.config.{attr_name}", + version="2.13.0", + kind="import", + removal_in="3.0", + info=f"importing {attr_name} from 'litestar.contrib.prometheus.config' is deprecated, please " + f"import it from 'litestar.plugins.prometheus' instead", + ) + value = globals()[attr_name] = locals()[attr_name] + return value -if TYPE_CHECKING: - from litestar.connection.request import Request - from litestar.types import Method, Scopes - - -@dataclass -class PrometheusConfig: - """Configuration class for the PrometheusConfig middleware.""" + raise AttributeError(f"module {__name__!r} has no attribute {attr_name!r}") # pragma: no cover - app_name: str = field(default="litestar") - """The name of the application to use in the metrics.""" - prefix: str = "litestar" - """The prefix to use for the metrics.""" - labels: Mapping[str, str | Callable] | None = field(default=None) - """A mapping of labels to add to the metrics. The values can be either a string or a callable that returns a string.""" - exemplars: Callable[[Request], dict] | None = field(default=None) - """A callable that returns a list of exemplars to add to the metrics. Only supported in opementrics-text exposition format.""" - buckets: list[str | float] | None = field(default=None) - """A list of buckets to use for the histogram.""" - excluded_http_methods: Method | Sequence[Method] | None = field(default=None) - """A list of http methods to exclude from the metrics.""" - exclude_unhandled_paths: bool = field(default=False) - """Whether to ignore requests for unhandled paths from the metrics.""" - exclude: str | list[str] | None = field(default=None) - """A pattern or list of patterns for routes to exclude from the metrics.""" - exclude_opt_key: str | None = field(default=None) - """A key or list of keys in ``opt`` with which a route handler can "opt-out" of the middleware.""" - scopes: Scopes | None = field(default=None) - """ASGI scopes processed by the middleware, if None both ``http`` and ``websocket`` will be processed.""" - middleware_class: type[PrometheusMiddleware] = field(default=PrometheusMiddleware) - """The middleware class to use. - """ - group_path: bool = field(default=False) - """Whether to group paths in the metrics to avoid cardinality explosion. - """ - @property - def middleware(self) -> DefineMiddleware: - """Create an instance of :class:`DefineMiddleware ` that wraps with. - - [PrometheusMiddleware][litestar.contrib.prometheus.PrometheusMiddleware]. or a subclass - of this middleware. - - Returns: - An instance of ``DefineMiddleware``. - """ - return DefineMiddleware(self.middleware_class, config=self) +if TYPE_CHECKING: + from litestar.plugins.prometheus import PrometheusConfig diff --git a/litestar/contrib/prometheus/controller.py b/litestar/contrib/prometheus/controller.py index 15f5bf1d52..112238f445 100644 --- a/litestar/contrib/prometheus/controller.py +++ b/litestar/contrib/prometheus/controller.py @@ -1,53 +1,30 @@ +# ruff: noqa: TCH004, F401 from __future__ import annotations -import os - -from litestar import Controller, get -from litestar.exceptions import MissingDependencyException -from litestar.response import Response - -try: - import prometheus_client # noqa: F401 -except ImportError as e: - raise MissingDependencyException("prometheus_client", "prometheus-client", "prometheus") from e - -from prometheus_client import ( - CONTENT_TYPE_LATEST, - REGISTRY, - CollectorRegistry, - generate_latest, - multiprocess, -) -from prometheus_client.openmetrics.exposition import ( - CONTENT_TYPE_LATEST as OPENMETRICS_CONTENT_TYPE_LATEST, -) -from prometheus_client.openmetrics.exposition import ( - generate_latest as openmetrics_generate_latest, -) - -__all__ = [ - "PrometheusController", -] - - -class PrometheusController(Controller): - """Controller for Prometheus endpoints.""" - - path: str = "/metrics" - """The path to expose the metrics on.""" - openmetrics_format: bool = False - """Whether to expose the metrics in OpenMetrics format.""" - - @get() - async def get(self) -> Response: - registry = REGISTRY - if "prometheus_multiproc_dir" in os.environ or "PROMETHEUS_MULTIPROC_DIR" in os.environ: - registry = CollectorRegistry() - multiprocess.MultiProcessCollector(registry) # type: ignore[no-untyped-call] - - if self.openmetrics_format: - headers = {"Content-Type": OPENMETRICS_CONTENT_TYPE_LATEST} - return Response(openmetrics_generate_latest(registry), status_code=200, headers=headers) # type: ignore[no-untyped-call] - - headers = {"Content-Type": CONTENT_TYPE_LATEST} - return Response(generate_latest(registry), status_code=200, headers=headers) +from typing import TYPE_CHECKING + +from litestar.utils import warn_deprecation + +__all__ = ("PrometheusController",) + + +def __getattr__(attr_name: str) -> object: + if attr_name in __all__: + from litestar.plugins.prometheus import PrometheusController + + warn_deprecation( + deprecated_name=f"litestar.contrib.prometheus.controller.{attr_name}", + version="2.13.0", + kind="import", + removal_in="3.0", + info=f"importing {attr_name} from 'litestar.contrib.prometheus.controller' is deprecated, please " + f"import it from 'litestar.plugins.prometheus' instead", + ) + value = globals()[attr_name] = locals()[attr_name] + return value + + raise AttributeError(f"module {__name__!r} has no attribute {attr_name!r}") # pragma: no cover + + +if TYPE_CHECKING: + from litestar.plugins.prometheus import PrometheusController diff --git a/litestar/contrib/prometheus/middleware.py b/litestar/contrib/prometheus/middleware.py index 150cf59311..80d02692b7 100644 --- a/litestar/contrib/prometheus/middleware.py +++ b/litestar/contrib/prometheus/middleware.py @@ -1,184 +1,30 @@ +# ruff: noqa: TCH004, F401 from __future__ import annotations -import time -from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, ClassVar, cast +from typing import TYPE_CHECKING -from litestar.connection.request import Request -from litestar.enums import ScopeType -from litestar.exceptions import MissingDependencyException -from litestar.middleware.base import AbstractMiddleware +from litestar.utils import warn_deprecation __all__ = ("PrometheusMiddleware",) -from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR -try: - import prometheus_client # noqa: F401 -except ImportError as e: - raise MissingDependencyException("prometheus_client", "prometheus-client", "prometheus") from e +def __getattr__(attr_name: str) -> object: + if attr_name in __all__: + from litestar.plugins.prometheus import PrometheusMiddleware -from prometheus_client import Counter, Gauge, Histogram + warn_deprecation( + deprecated_name=f"litestar.contrib.prometheus.middleware.{attr_name}", + version="2.13.0", + kind="import", + removal_in="3.0", + info=f"importing {attr_name} from 'litestar.contrib.prometheus.middleware' is deprecated, please " + f"import it from 'litestar.plugins.prometheus' instead", + ) + value = globals()[attr_name] = locals()[attr_name] + return value -if TYPE_CHECKING: - from prometheus_client.metrics import MetricWrapperBase - - from litestar.contrib.prometheus import PrometheusConfig - from litestar.types import ASGIApp, Message, Receive, Scope, Send - - -class PrometheusMiddleware(AbstractMiddleware): - """Prometheus Middleware.""" - - _metrics: ClassVar[dict[str, MetricWrapperBase]] = {} - - def __init__(self, app: ASGIApp, config: PrometheusConfig) -> None: - """Middleware that adds Prometheus instrumentation to the application. - - Args: - app: The ``next`` ASGI app to call. - config: An instance of :class:`PrometheusConfig <.contrib.prometheus.PrometheusConfig>` - """ - super().__init__(app=app, scopes=config.scopes, exclude=config.exclude, exclude_opt_key=config.exclude_opt_key) - self._config = config - self._kwargs: dict[str, Any] = {} - - if self._config.buckets is not None: - self._kwargs["buckets"] = self._config.buckets - - def request_count(self, labels: dict[str, str | int | float]) -> Counter: - metric_name = f"{self._config.prefix}_requests_total" - - if metric_name not in PrometheusMiddleware._metrics: - PrometheusMiddleware._metrics[metric_name] = Counter( - name=metric_name, - documentation="Total requests", - labelnames=[*labels.keys()], - ) - - return cast("Counter", PrometheusMiddleware._metrics[metric_name]) - - def request_time(self, labels: dict[str, str | int | float]) -> Histogram: - metric_name = f"{self._config.prefix}_request_duration_seconds" - - if metric_name not in PrometheusMiddleware._metrics: - PrometheusMiddleware._metrics[metric_name] = Histogram( - name=metric_name, - documentation="Request duration, in seconds", - labelnames=[*labels.keys()], - **self._kwargs, - ) - return cast("Histogram", PrometheusMiddleware._metrics[metric_name]) - - def requests_in_progress(self, labels: dict[str, str | int | float]) -> Gauge: - metric_name = f"{self._config.prefix}_requests_in_progress" - - if metric_name not in PrometheusMiddleware._metrics: - PrometheusMiddleware._metrics[metric_name] = Gauge( - name=metric_name, - documentation="Total requests currently in progress", - labelnames=[*labels.keys()], - multiprocess_mode="livesum", - ) - return cast("Gauge", PrometheusMiddleware._metrics[metric_name]) - - def requests_error_count(self, labels: dict[str, str | int | float]) -> Counter: - metric_name = f"{self._config.prefix}_requests_error_total" - - if metric_name not in PrometheusMiddleware._metrics: - PrometheusMiddleware._metrics[metric_name] = Counter( - name=metric_name, - documentation="Total errors in requests", - labelnames=[*labels.keys()], - ) - return cast("Counter", PrometheusMiddleware._metrics[metric_name]) - - def _get_extra_labels(self, request: Request[Any, Any, Any]) -> dict[str, str]: - """Get extra labels provided by the config and if they are callable, parse them. - - Args: - request: The request object. - - Returns: - A dictionary of extra labels. - """ - - return {k: str(v(request) if callable(v) else v) for k, v in (self._config.labels or {}).items()} + raise AttributeError(f"module {__name__!r} has no attribute {attr_name!r}") # pragma: no cover - def _get_default_labels(self, request: Request[Any, Any, Any]) -> dict[str, str | int | float]: - """Get default label values from the request. - Args: - request: The request object. - - Returns: - A dictionary of default labels. - """ - - path = request.url.path - if self._config.group_path: - path = request.scope["path_template"] - return { - "method": request.method if request.scope["type"] == ScopeType.HTTP else request.scope["type"], - "path": path, - "status_code": 200, - "app_name": self._config.app_name, - } - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - """ASGI callable. - - Args: - scope: The ASGI connection scope. - receive: The ASGI receive function. - send: The ASGI send function. - - Returns: - None - """ - - request = Request[Any, Any, Any](scope, receive) - - if self._config.excluded_http_methods and request.method in self._config.excluded_http_methods: - await self.app(scope, receive, send) - return - - labels = {**self._get_default_labels(request), **self._get_extra_labels(request)} - - request_span = {"start_time": time.perf_counter(), "end_time": 0, "duration": 0, "status_code": 200} - - wrapped_send = self._get_wrapped_send(send, request_span) - - self.requests_in_progress(labels).labels(*labels.values()).inc() - - try: - await self.app(scope, receive, wrapped_send) - finally: - extra: dict[str, Any] = {} - if self._config.exemplars: - extra["exemplar"] = self._config.exemplars(request) - - self.requests_in_progress(labels).labels(*labels.values()).dec() - - labels["status_code"] = request_span["status_code"] - label_values = [*labels.values()] - - if request_span["status_code"] >= HTTP_500_INTERNAL_SERVER_ERROR: - self.requests_error_count(labels).labels(*label_values).inc(**extra) - - self.request_count(labels).labels(*label_values).inc(**extra) - self.request_time(labels).labels(*label_values).observe(request_span["duration"], **extra) - - def _get_wrapped_send(self, send: Send, request_span: dict[str, float]) -> Callable: - @wraps(send) - async def wrapped_send(message: Message) -> None: - if message["type"] == "http.response.start": - request_span["status_code"] = message["status"] - - if message["type"] == "http.response.body": - end = time.perf_counter() - request_span["duration"] = end - request_span["start_time"] - request_span["end_time"] = end - await send(message) - - return wrapped_send +if TYPE_CHECKING: + from litestar.plugins.prometheus import PrometheusMiddleware diff --git a/litestar/plugins/prometheus/__init__.py b/litestar/plugins/prometheus/__init__.py new file mode 100644 index 0000000000..1ccb494695 --- /dev/null +++ b/litestar/plugins/prometheus/__init__.py @@ -0,0 +1,5 @@ +from .config import PrometheusConfig +from .controller import PrometheusController +from .middleware import PrometheusMiddleware + +__all__ = ("PrometheusMiddleware", "PrometheusConfig", "PrometheusController") diff --git a/litestar/plugins/prometheus/config.py b/litestar/plugins/prometheus/config.py new file mode 100644 index 0000000000..49828898a3 --- /dev/null +++ b/litestar/plugins/prometheus/config.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Callable, Mapping, Sequence + +from litestar.exceptions import MissingDependencyException +from litestar.middleware.base import DefineMiddleware +from litestar.plugins.prometheus.middleware import ( + PrometheusMiddleware, +) + +__all__ = ("PrometheusConfig",) + + +try: + import prometheus_client # noqa: F401 +except ImportError as e: + raise MissingDependencyException("prometheus_client", "prometheus-client", "prometheus") from e + + +if TYPE_CHECKING: + from litestar.connection.request import Request + from litestar.types import Method, Scopes + + +@dataclass +class PrometheusConfig: + """Configuration class for the PrometheusConfig middleware.""" + + app_name: str = field(default="litestar") + """The name of the application to use in the metrics.""" + prefix: str = "litestar" + """The prefix to use for the metrics.""" + labels: Mapping[str, str | Callable] | None = field(default=None) + """A mapping of labels to add to the metrics. The values can be either a string or a callable that returns a string.""" + exemplars: Callable[[Request], dict] | None = field(default=None) + """A callable that returns a list of exemplars to add to the metrics. Only supported in opementrics-text exposition format.""" + buckets: list[str | float] | None = field(default=None) + """A list of buckets to use for the histogram.""" + excluded_http_methods: Method | Sequence[Method] | None = field(default=None) + """A list of http methods to exclude from the metrics.""" + exclude_unhandled_paths: bool = field(default=False) + """Whether to ignore requests for unhandled paths from the metrics.""" + exclude: str | list[str] | None = field(default=None) + """A pattern or list of patterns for routes to exclude from the metrics.""" + exclude_opt_key: str | None = field(default=None) + """A key or list of keys in ``opt`` with which a route handler can "opt-out" of the middleware.""" + scopes: Scopes | None = field(default=None) + """ASGI scopes processed by the middleware, if None both ``http`` and ``websocket`` will be processed.""" + middleware_class: type[PrometheusMiddleware] = field(default=PrometheusMiddleware) + """The middleware class to use. + """ + group_path: bool = field(default=False) + """Whether to group paths in the metrics to avoid cardinality explosion. + """ + + @property + def middleware(self) -> DefineMiddleware: + """Create an instance of :class:`DefineMiddleware ` that wraps with. + + [PrometheusMiddleware][litestar.plugins.prometheus.PrometheusMiddleware]. or a subclass + of this middleware. + + Returns: + An instance of ``DefineMiddleware``. + """ + return DefineMiddleware(self.middleware_class, config=self) diff --git a/litestar/plugins/prometheus/controller.py b/litestar/plugins/prometheus/controller.py new file mode 100644 index 0000000000..15f5bf1d52 --- /dev/null +++ b/litestar/plugins/prometheus/controller.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import os + +from litestar import Controller, get +from litestar.exceptions import MissingDependencyException +from litestar.response import Response + +try: + import prometheus_client # noqa: F401 +except ImportError as e: + raise MissingDependencyException("prometheus_client", "prometheus-client", "prometheus") from e + +from prometheus_client import ( + CONTENT_TYPE_LATEST, + REGISTRY, + CollectorRegistry, + generate_latest, + multiprocess, +) +from prometheus_client.openmetrics.exposition import ( + CONTENT_TYPE_LATEST as OPENMETRICS_CONTENT_TYPE_LATEST, +) +from prometheus_client.openmetrics.exposition import ( + generate_latest as openmetrics_generate_latest, +) + +__all__ = [ + "PrometheusController", +] + + +class PrometheusController(Controller): + """Controller for Prometheus endpoints.""" + + path: str = "/metrics" + """The path to expose the metrics on.""" + openmetrics_format: bool = False + """Whether to expose the metrics in OpenMetrics format.""" + + @get() + async def get(self) -> Response: + registry = REGISTRY + if "prometheus_multiproc_dir" in os.environ or "PROMETHEUS_MULTIPROC_DIR" in os.environ: + registry = CollectorRegistry() + multiprocess.MultiProcessCollector(registry) # type: ignore[no-untyped-call] + + if self.openmetrics_format: + headers = {"Content-Type": OPENMETRICS_CONTENT_TYPE_LATEST} + return Response(openmetrics_generate_latest(registry), status_code=200, headers=headers) # type: ignore[no-untyped-call] + + headers = {"Content-Type": CONTENT_TYPE_LATEST} + return Response(generate_latest(registry), status_code=200, headers=headers) diff --git a/litestar/plugins/prometheus/middleware.py b/litestar/plugins/prometheus/middleware.py new file mode 100644 index 0000000000..cd987e8ac6 --- /dev/null +++ b/litestar/plugins/prometheus/middleware.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +import time +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable, ClassVar, cast + +from litestar.connection.request import Request +from litestar.enums import ScopeType +from litestar.exceptions import MissingDependencyException +from litestar.middleware.base import AbstractMiddleware + +__all__ = ("PrometheusMiddleware",) + +from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR + +try: + import prometheus_client # noqa: F401 +except ImportError as e: + raise MissingDependencyException("prometheus_client", "prometheus-client", "prometheus") from e + +from prometheus_client import Counter, Gauge, Histogram + +if TYPE_CHECKING: + from prometheus_client.metrics import MetricWrapperBase + + from litestar.plugins.prometheus import PrometheusConfig + from litestar.types import ASGIApp, Message, Receive, Scope, Send + + +class PrometheusMiddleware(AbstractMiddleware): + """Prometheus Middleware.""" + + _metrics: ClassVar[dict[str, MetricWrapperBase]] = {} + + def __init__(self, app: ASGIApp, config: PrometheusConfig) -> None: + """Middleware that adds Prometheus instrumentation to the application. + + Args: + app: The ``next`` ASGI app to call. + config: An instance of :class:`PrometheusConfig <.plugins.prometheus.PrometheusConfig>` + """ + super().__init__(app=app, scopes=config.scopes, exclude=config.exclude, exclude_opt_key=config.exclude_opt_key) + self._config = config + self._kwargs: dict[str, Any] = {} + + if self._config.buckets is not None: + self._kwargs["buckets"] = self._config.buckets + + def request_count(self, labels: dict[str, str | int | float]) -> Counter: + metric_name = f"{self._config.prefix}_requests_total" + + if metric_name not in PrometheusMiddleware._metrics: + PrometheusMiddleware._metrics[metric_name] = Counter( + name=metric_name, + documentation="Total requests", + labelnames=[*labels.keys()], + ) + + return cast("Counter", PrometheusMiddleware._metrics[metric_name]) + + def request_time(self, labels: dict[str, str | int | float]) -> Histogram: + metric_name = f"{self._config.prefix}_request_duration_seconds" + + if metric_name not in PrometheusMiddleware._metrics: + PrometheusMiddleware._metrics[metric_name] = Histogram( + name=metric_name, + documentation="Request duration, in seconds", + labelnames=[*labels.keys()], + **self._kwargs, + ) + return cast("Histogram", PrometheusMiddleware._metrics[metric_name]) + + def requests_in_progress(self, labels: dict[str, str | int | float]) -> Gauge: + metric_name = f"{self._config.prefix}_requests_in_progress" + + if metric_name not in PrometheusMiddleware._metrics: + PrometheusMiddleware._metrics[metric_name] = Gauge( + name=metric_name, + documentation="Total requests currently in progress", + labelnames=[*labels.keys()], + multiprocess_mode="livesum", + ) + return cast("Gauge", PrometheusMiddleware._metrics[metric_name]) + + def requests_error_count(self, labels: dict[str, str | int | float]) -> Counter: + metric_name = f"{self._config.prefix}_requests_error_total" + + if metric_name not in PrometheusMiddleware._metrics: + PrometheusMiddleware._metrics[metric_name] = Counter( + name=metric_name, + documentation="Total errors in requests", + labelnames=[*labels.keys()], + ) + return cast("Counter", PrometheusMiddleware._metrics[metric_name]) + + def _get_extra_labels(self, request: Request[Any, Any, Any]) -> dict[str, str]: + """Get extra labels provided by the config and if they are callable, parse them. + + Args: + request: The request object. + + Returns: + A dictionary of extra labels. + """ + + return {k: str(v(request) if callable(v) else v) for k, v in (self._config.labels or {}).items()} + + def _get_default_labels(self, request: Request[Any, Any, Any]) -> dict[str, str | int | float]: + """Get default label values from the request. + + Args: + request: The request object. + + Returns: + A dictionary of default labels. + """ + + path = request.url.path + if self._config.group_path: + path = request.scope["path_template"] + return { + "method": request.method if request.scope["type"] == ScopeType.HTTP else request.scope["type"], + "path": path, + "status_code": 200, + "app_name": self._config.app_name, + } + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + + request = Request[Any, Any, Any](scope, receive) + + if self._config.excluded_http_methods and request.method in self._config.excluded_http_methods: + await self.app(scope, receive, send) + return + + labels = {**self._get_default_labels(request), **self._get_extra_labels(request)} + + request_span = {"start_time": time.perf_counter(), "end_time": 0, "duration": 0, "status_code": 200} + + wrapped_send = self._get_wrapped_send(send, request_span) + + self.requests_in_progress(labels).labels(*labels.values()).inc() + + try: + await self.app(scope, receive, wrapped_send) + finally: + extra: dict[str, Any] = {} + if self._config.exemplars: + extra["exemplar"] = self._config.exemplars(request) + + self.requests_in_progress(labels).labels(*labels.values()).dec() + + labels["status_code"] = request_span["status_code"] + label_values = [*labels.values()] + + if request_span["status_code"] >= HTTP_500_INTERNAL_SERVER_ERROR: + self.requests_error_count(labels).labels(*label_values).inc(**extra) + + self.request_count(labels).labels(*label_values).inc(**extra) + self.request_time(labels).labels(*label_values).observe(request_span["duration"], **extra) + + def _get_wrapped_send(self, send: Send, request_span: dict[str, float]) -> Callable: + @wraps(send) + async def wrapped_send(message: Message) -> None: + if message["type"] == "http.response.start": + request_span["status_code"] = message["status"] + + if message["type"] == "http.response.body": + end = time.perf_counter() + request_span["duration"] = end - request_span["start_time"] + request_span["end_time"] = end + await send(message) + + return wrapped_send diff --git a/tests/examples/test_contrib/prometheus/test_prometheus_exporter_example.py b/tests/examples/test_contrib/prometheus/test_prometheus_exporter_example.py index 569fd614ac..512537b7f0 100644 --- a/tests/examples/test_contrib/prometheus/test_prometheus_exporter_example.py +++ b/tests/examples/test_contrib/prometheus/test_prometheus_exporter_example.py @@ -4,7 +4,7 @@ from prometheus_client import REGISTRY from litestar import Controller, Litestar, Request, get -from litestar.contrib.prometheus import PrometheusMiddleware +from litestar.plugins.prometheus import PrometheusMiddleware from litestar.status_codes import HTTP_200_OK from litestar.testing import TestClient @@ -41,7 +41,7 @@ def clear_collectors() -> None: def test_prometheus_exporter_example( group_path: bool, route_path: str, route_template: str, expected_path: str ) -> None: - from docs.examples.contrib.prometheus.using_prometheus_exporter import create_app + from docs.examples.plugins.prometheus.using_prometheus_exporter import create_app app = create_app(group_path=group_path) diff --git a/tests/examples/test_contrib/prometheus/test_prometheus_exporter_example_with_extra_config.py b/tests/examples/test_contrib/prometheus/test_prometheus_exporter_example_with_extra_config.py index 3b85e69224..937cf3f2d6 100644 --- a/tests/examples/test_contrib/prometheus/test_prometheus_exporter_example_with_extra_config.py +++ b/tests/examples/test_contrib/prometheus/test_prometheus_exporter_example_with_extra_config.py @@ -3,7 +3,7 @@ from prometheus_client import REGISTRY from litestar import get -from litestar.contrib.prometheus import PrometheusMiddleware +from litestar.plugins.prometheus import PrometheusMiddleware from litestar.status_codes import HTTP_200_OK from litestar.testing import TestClient @@ -17,7 +17,7 @@ def clear_collectors() -> None: def test_prometheus_exporter_with_extra_config_example() -> None: - from docs.examples.contrib.prometheus.using_prometheus_exporter_with_extra_configs import app + from docs.examples.plugins.prometheus.using_prometheus_exporter_with_extra_configs import app clear_collectors() diff --git a/tests/unit/test_contrib/test_prometheus.py b/tests/unit/test_contrib/test_prometheus.py index e9d82f7430..f6fc75c049 100644 --- a/tests/unit/test_contrib/test_prometheus.py +++ b/tests/unit/test_contrib/test_prometheus.py @@ -1,217 +1,63 @@ -import re -import time -from http.client import HTTPException +# ruff: noqa: TCH004, F401 +from __future__ import annotations + +import importlib +import sys +from importlib.util import cache_from_source from pathlib import Path -from typing import Any import pytest -from _pytest.monkeypatch import MonkeyPatch -from prometheus_client import REGISTRY -from pytest_mock import MockerFixture - -from litestar import get, post, websocket_listener -from litestar.contrib.prometheus import PrometheusConfig, PrometheusController, PrometheusMiddleware -from litestar.status_codes import HTTP_200_OK -from litestar.testing import create_test_client - - -def create_config(**kwargs: Any) -> PrometheusConfig: - collectors = list(REGISTRY._collector_to_names.keys()) - for collector in collectors: - REGISTRY.unregister(collector) - - PrometheusMiddleware._metrics = {} - return PrometheusConfig(**kwargs) - - -@pytest.mark.flaky(reruns=5) -def test_prometheus_exporter_metrics_with_http() -> None: - config = create_config() - - @get("/duration") - def duration_handler() -> dict: - time.sleep(0.1) - return {"hello": "world"} - - @get("/error") - def handler_error() -> dict: - raise HTTPException("Error Occurred") - - with create_test_client( - [duration_handler, handler_error, PrometheusController], middleware=[config.middleware] - ) as client: - client.get("/error") - client.get("/duration") - metrics_exporter_response = client.get("/metrics") - - assert metrics_exporter_response.status_code == HTTP_200_OK - metrics = metrics_exporter_response.content.decode() - - assert ( - """litestar_request_duration_seconds_sum{app_name="litestar",method="GET",path="/duration",status_code="200"}""" - in metrics - ) - - assert ( - """litestar_requests_error_total{app_name="litestar",method="GET",path="/error",status_code="500"} 1.0""" - in metrics - ) - - assert ( - """litestar_request_duration_seconds_bucket{app_name="litestar",le="0.005",method="GET",path="/error",status_code="500"} 1.0""" - in metrics - ) - - assert ( - """litestar_requests_in_progress{app_name="litestar",method="GET",path="/metrics",status_code="200"} 1.0""" - in metrics - ) - - assert ( - """litestar_requests_in_progress{app_name="litestar",method="GET",path="/duration",status_code="200"} 0.0""" - in metrics - ) - - duration_metric_matches = re.findall( - r"""litestar_request_duration_seconds_sum{app_name="litestar",method="GET",path="/duration",status_code="200"} (\d+\.\d+)""", - metrics, - ) - - assert duration_metric_matches != [] - assert round(float(duration_metric_matches[0]), 1) == 0.1 - - client.get("/duration") - metrics = client.get("/metrics").content.decode() - - assert ( - """litestar_requests_total{app_name="litestar",method="GET",path="/duration",status_code="200"} 2.0""" - in metrics - ) - - assert ( - """litestar_requests_in_progress{app_name="litestar",method="GET",path="/error",status_code="200"} 0.0""" - in metrics - ) - - assert ( - """litestar_requests_in_progress{app_name="litestar",method="GET",path="/metrics",status_code="200"} 1.0""" - in metrics - ) - - -def test_prometheus_middleware_configurations() -> None: - labels = {"foo": "bar", "baz": lambda a: "qux"} - - config = create_config( - app_name="litestar_test", - prefix="litestar_rocks", - labels=labels, - buckets=[0.1, 0.5, 1.0], - excluded_http_methods=["POST"], - ) - - @get("/test") - def test() -> dict: - return {"hello": "world"} - - @post("/ignore") - def ignore() -> dict: - return {"hello": "world"} - - with create_test_client([test, ignore, PrometheusController], middleware=[config.middleware]) as client: - client.get("/test") - client.post("/ignore") - metrics_exporter_response = client.get("/metrics") - - assert metrics_exporter_response.status_code == HTTP_200_OK - metrics = metrics_exporter_response.content.decode() - - assert ( - """litestar_rocks_requests_total{app_name="litestar_test",baz="qux",foo="bar",method="GET",path="/test",status_code="200"} 1.0""" - in metrics - ) - - assert ( - """litestar_rocks_requests_total{app_name="litestar_test",baz="qux",foo="bar",method="POST",path="/ignore",status_code="201"} 1.0""" - not in metrics - ) - - assert ( - """litestar_rocks_request_duration_seconds_bucket{app_name="litestar_test",baz="qux",foo="bar",le="0.1",method="GET",path="/test",status_code="200"} 1.0""" - in metrics - ) - - assert ( - """litestar_rocks_request_duration_seconds_bucket{app_name="litestar_test",baz="qux",foo="bar",le="0.5",method="GET",path="/test",status_code="200"} 1.0""" - in metrics - ) - - assert ( - """litestar_rocks_request_duration_seconds_bucket{app_name="litestar_test",baz="qux",foo="bar",le="1.0",method="GET",path="/test",status_code="200"} 1.0""" - in metrics - ) - - -def test_prometheus_controller_configurations() -> None: - config = create_config( - exemplars=lambda a: {"trace_id": "1234"}, - ) - - class CustomPrometheusController(PrometheusController): - path: str = "/metrics/custom" - openmetrics_format: bool = True - - @get("/test") - def test() -> dict: - return {"hello": "world"} - - with create_test_client([test, CustomPrometheusController], middleware=[config.middleware]) as client: - client.get("/test") - - metrics_exporter_response = client.get("/metrics/custom") - - assert metrics_exporter_response.status_code == HTTP_200_OK - metrics = metrics_exporter_response.content.decode() - - assert ( - """litestar_requests_total{app_name="litestar",method="GET",path="/test",status_code="200"} 1.0 # {trace_id="1234"} 1.0""" - in metrics - ) - - -def test_prometheus_with_websocket() -> None: - config = create_config() - - @websocket_listener("/test") - def test(data: str) -> dict: - return {"hello": data} - - with create_test_client([test, PrometheusController], middleware=[config.middleware]) as client: - with client.websocket_connect("/test") as websocket: - websocket.send_text("litestar") - websocket.receive_json() - - metrics_exporter_response = client.get("/metrics") - - assert metrics_exporter_response.status_code == HTTP_200_OK - metrics = metrics_exporter_response.content.decode() - - assert ( - """litestar_requests_total{app_name="litestar",method="websocket",path="/test",status_code="200"} 1.0""" - in metrics - ) - - -@pytest.mark.parametrize("env_var", ["PROMETHEUS_MULTIPROC_DIR", "prometheus_multiproc_dir"]) -def test_procdir(monkeypatch: MonkeyPatch, tmp_path: Path, mocker: MockerFixture, env_var: str) -> None: - proc_dir = tmp_path / "something" - proc_dir.mkdir() - monkeypatch.setenv(env_var, str(proc_dir)) - config = create_config() - mock_registry = mocker.patch("litestar.contrib.prometheus.controller.CollectorRegistry") - mock_collector = mocker.patch("litestar.contrib.prometheus.controller.multiprocess.MultiProcessCollector") - with create_test_client([PrometheusController], middleware=[config.middleware]) as client: - client.get("/metrics") - mock_collector.assert_called_once_with(mock_registry.return_value) +def purge_module(module_names: list[str], path: str | Path) -> None: + for name in module_names: + if name in sys.modules: + del sys.modules[name] + Path(cache_from_source(str(path))).unlink(missing_ok=True) + + +def test_deprecated_prometheus_imports() -> None: + purge_module(["litestar.contrib.prometheus"], __file__) + with pytest.warns( + DeprecationWarning, match="importing PrometheusMiddleware from 'litestar.contrib.prometheus' is deprecated" + ): + from litestar.contrib.prometheus import PrometheusMiddleware + + purge_module(["litestar.contrib.prometheus"], __file__) + with pytest.warns( + DeprecationWarning, match="importing PrometheusConfig from 'litestar.contrib.prometheus' is deprecated" + ): + from litestar.contrib.prometheus import PrometheusConfig + + purge_module(["litestar.contrib.prometheus"], __file__) + with pytest.warns( + DeprecationWarning, match="importing PrometheusController from 'litestar.contrib.prometheus' is deprecated" + ): + from litestar.contrib.prometheus import PrometheusController + + +def test_deprecated_prometheus_middleware_imports() -> None: + purge_module(["litestar.contrib.prometheus.middleware"], __file__) + with pytest.warns( + DeprecationWarning, + match="importing PrometheusMiddleware from 'litestar.contrib.prometheus.middleware' is deprecated", + ): + from litestar.contrib.prometheus.middleware import PrometheusMiddleware + + +def test_deprecated_prometheus_config_imports() -> None: + purge_module(["litestar.contrib.prometheus.config"], __file__) + with pytest.warns( + DeprecationWarning, + match="importing PrometheusConfig from 'litestar.contrib.prometheus.config' is deprecated", + ): + from litestar.contrib.prometheus.config import PrometheusConfig + + +def test_deprecated_prometheus_controller_imports() -> None: + purge_module(["litestar.contrib.prometheus.controller"], __file__) + with pytest.warns( + DeprecationWarning, + match="importing PrometheusController from 'litestar.contrib.prometheus.controller' is deprecated", + ): + from litestar.contrib.prometheus.controller import PrometheusController diff --git a/tests/unit/test_plugins/test_prometheus.py b/tests/unit/test_plugins/test_prometheus.py new file mode 100644 index 0000000000..894e19e17e --- /dev/null +++ b/tests/unit/test_plugins/test_prometheus.py @@ -0,0 +1,217 @@ +import re +import time +from http.client import HTTPException +from pathlib import Path +from typing import Any + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from prometheus_client import REGISTRY +from pytest_mock import MockerFixture + +from litestar import get, post, websocket_listener +from litestar.plugins.prometheus import PrometheusConfig, PrometheusController, PrometheusMiddleware +from litestar.status_codes import HTTP_200_OK +from litestar.testing import create_test_client + + +def create_config(**kwargs: Any) -> PrometheusConfig: + collectors = list(REGISTRY._collector_to_names.keys()) + for collector in collectors: + REGISTRY.unregister(collector) + + PrometheusMiddleware._metrics = {} + return PrometheusConfig(**kwargs) + + +@pytest.mark.flaky(reruns=5) +def test_prometheus_exporter_metrics_with_http() -> None: + config = create_config() + + @get("/duration") + def duration_handler() -> dict: + time.sleep(0.1) + return {"hello": "world"} + + @get("/error") + def handler_error() -> dict: + raise HTTPException("Error Occurred") + + with create_test_client( + [duration_handler, handler_error, PrometheusController], middleware=[config.middleware] + ) as client: + client.get("/error") + client.get("/duration") + metrics_exporter_response = client.get("/metrics") + + assert metrics_exporter_response.status_code == HTTP_200_OK + metrics = metrics_exporter_response.content.decode() + + assert ( + """litestar_request_duration_seconds_sum{app_name="litestar",method="GET",path="/duration",status_code="200"}""" + in metrics + ) + + assert ( + """litestar_requests_error_total{app_name="litestar",method="GET",path="/error",status_code="500"} 1.0""" + in metrics + ) + + assert ( + """litestar_request_duration_seconds_bucket{app_name="litestar",le="0.005",method="GET",path="/error",status_code="500"} 1.0""" + in metrics + ) + + assert ( + """litestar_requests_in_progress{app_name="litestar",method="GET",path="/metrics",status_code="200"} 1.0""" + in metrics + ) + + assert ( + """litestar_requests_in_progress{app_name="litestar",method="GET",path="/duration",status_code="200"} 0.0""" + in metrics + ) + + duration_metric_matches = re.findall( + r"""litestar_request_duration_seconds_sum{app_name="litestar",method="GET",path="/duration",status_code="200"} (\d+\.\d+)""", + metrics, + ) + + assert duration_metric_matches != [] + assert round(float(duration_metric_matches[0]), 1) == 0.1 + + client.get("/duration") + metrics = client.get("/metrics").content.decode() + + assert ( + """litestar_requests_total{app_name="litestar",method="GET",path="/duration",status_code="200"} 2.0""" + in metrics + ) + + assert ( + """litestar_requests_in_progress{app_name="litestar",method="GET",path="/error",status_code="200"} 0.0""" + in metrics + ) + + assert ( + """litestar_requests_in_progress{app_name="litestar",method="GET",path="/metrics",status_code="200"} 1.0""" + in metrics + ) + + +def test_prometheus_middleware_configurations() -> None: + labels = {"foo": "bar", "baz": lambda a: "qux"} + + config = create_config( + app_name="litestar_test", + prefix="litestar_rocks", + labels=labels, + buckets=[0.1, 0.5, 1.0], + excluded_http_methods=["POST"], + ) + + @get("/test") + def test() -> dict: + return {"hello": "world"} + + @post("/ignore") + def ignore() -> dict: + return {"hello": "world"} + + with create_test_client([test, ignore, PrometheusController], middleware=[config.middleware]) as client: + client.get("/test") + client.post("/ignore") + metrics_exporter_response = client.get("/metrics") + + assert metrics_exporter_response.status_code == HTTP_200_OK + metrics = metrics_exporter_response.content.decode() + + assert ( + """litestar_rocks_requests_total{app_name="litestar_test",baz="qux",foo="bar",method="GET",path="/test",status_code="200"} 1.0""" + in metrics + ) + + assert ( + """litestar_rocks_requests_total{app_name="litestar_test",baz="qux",foo="bar",method="POST",path="/ignore",status_code="201"} 1.0""" + not in metrics + ) + + assert ( + """litestar_rocks_request_duration_seconds_bucket{app_name="litestar_test",baz="qux",foo="bar",le="0.1",method="GET",path="/test",status_code="200"} 1.0""" + in metrics + ) + + assert ( + """litestar_rocks_request_duration_seconds_bucket{app_name="litestar_test",baz="qux",foo="bar",le="0.5",method="GET",path="/test",status_code="200"} 1.0""" + in metrics + ) + + assert ( + """litestar_rocks_request_duration_seconds_bucket{app_name="litestar_test",baz="qux",foo="bar",le="1.0",method="GET",path="/test",status_code="200"} 1.0""" + in metrics + ) + + +def test_prometheus_controller_configurations() -> None: + config = create_config( + exemplars=lambda a: {"trace_id": "1234"}, + ) + + class CustomPrometheusController(PrometheusController): + path: str = "/metrics/custom" + openmetrics_format: bool = True + + @get("/test") + def test() -> dict: + return {"hello": "world"} + + with create_test_client([test, CustomPrometheusController], middleware=[config.middleware]) as client: + client.get("/test") + + metrics_exporter_response = client.get("/metrics/custom") + + assert metrics_exporter_response.status_code == HTTP_200_OK + metrics = metrics_exporter_response.content.decode() + + assert ( + """litestar_requests_total{app_name="litestar",method="GET",path="/test",status_code="200"} 1.0 # {trace_id="1234"} 1.0""" + in metrics + ) + + +def test_prometheus_with_websocket() -> None: + config = create_config() + + @websocket_listener("/test") + def test(data: str) -> dict: + return {"hello": data} + + with create_test_client([test, PrometheusController], middleware=[config.middleware]) as client: + with client.websocket_connect("/test") as websocket: + websocket.send_text("litestar") + websocket.receive_json() + + metrics_exporter_response = client.get("/metrics") + + assert metrics_exporter_response.status_code == HTTP_200_OK + metrics = metrics_exporter_response.content.decode() + + assert ( + """litestar_requests_total{app_name="litestar",method="websocket",path="/test",status_code="200"} 1.0""" + in metrics + ) + + +@pytest.mark.parametrize("env_var", ["PROMETHEUS_MULTIPROC_DIR", "prometheus_multiproc_dir"]) +def test_procdir(monkeypatch: MonkeyPatch, tmp_path: Path, mocker: MockerFixture, env_var: str) -> None: + proc_dir = tmp_path / "something" + proc_dir.mkdir() + monkeypatch.setenv(env_var, str(proc_dir)) + config = create_config() + mock_registry = mocker.patch("litestar.plugins.prometheus.controller.CollectorRegistry") + mock_collector = mocker.patch("litestar.plugins.prometheus.controller.multiprocess.MultiProcessCollector") + + with create_test_client([PrometheusController], middleware=[config.middleware]) as client: + client.get("/metrics") + + mock_collector.assert_called_once_with(mock_registry.return_value)