From 2d21eefe3dcb987cf580bbdd7ca56ff49712aeb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janek=20Nouvertn=C3=A9?= <25355197+provinzkraut@users.noreply.github.com> Date: Sat, 6 Jan 2024 12:49:37 +0100 Subject: [PATCH 01/14] Bump version to 2.6.0 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 51173302aa..655feceff1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ maintainers = [ name = "litestar" readme = "README.md" requires-python = ">=3.8,<4.0" -version = "2.5.5" +version = "2.6.0" [project.urls] Blog = "https://blog.litestar.dev" From 0af4c8d5e7bd62c186e42722da7355ee0c19a645 Mon Sep 17 00:00:00 2001 From: Cody Fincher <204685+cofin@users.noreply.github.com> Date: Sun, 7 Jan 2024 05:00:34 -0600 Subject: [PATCH 02/14] feat: `structlog` plugin & bug fixes (#2943) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(channels): Postgres backends (#2803) * wip Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * some debugging Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * formatting Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * use a separate connection to publish/listen Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * reintroduce flaky Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix typing Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Add psycopg backend Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix backend issues Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Undo test debugging changes Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * mark groups Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Ensure channel names ar quoted Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * sleep debugging Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * update docs Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Add missing test Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix docs link Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Add missing listener test Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Formatting Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix test typing Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix some coverage issue Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> --------- Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> Co-authored-by: Cody Fincher <204685+cofin@users.noreply.github.com> * test skip sourcery Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * test(channels): Improve channels testing (#2838) * Improve channels testing --------- Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> Co-authored-by: Cody Fincher <204685+cofin@users.noreply.github.com> * chore(typing): various pyright issues (#2897) Fix various pyright issues Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * fix(channels): Trailing messages after unsubscribes (#2894) Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * feat(cli): Add ``--schema`` and ``--exclude`` option to route CLI. (#2886) * add exclude and schema cli options to route. * updates per linting, mypy, and etc. * fix some more mypy stuff. * fix issue with linting. * add doc for route cli options. * fix issue with python3.8 not liking dict type. * add exclude and schema cli options to route. * updates per linting, mypy, and etc. * fix some more mypy stuff. * fix issue with linting. * add doc for route cli options. * fix issue with python3.8 not liking dict type. * Update docs/usage/cli.rst * Update litestar/cli/_utils.py * fix malformed docs table. --------- Co-authored-by: Jacob Coffee * test(CLI): Fix xdist issue (#2931) Fix test for xdist Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * feat(core): Replace `anyio.to_thread.run_sync` with native versions (#2937) Replace anyio.to_thread.run_sync with native versions Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * feat: example app using structlog * fix: updated structlog with request logging * fix: lazy initialized structlog fix * feat: add structlog plugin * fix: adds `set_level` to all Logging configurations * fix: check that the object has the `setLevel` method before calling * feat: adds test for plugin * fix: parameter naming for `set_level` abstract method * feat(channels): Postgres backends (#2803) * wip Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * some debugging Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * formatting Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * use a separate connection to publish/listen Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * reintroduce flaky Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix typing Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Add psycopg backend Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix backend issues Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Undo test debugging changes Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * mark groups Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Ensure channel names ar quoted Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * sleep debugging Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * update docs Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Add missing test Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix docs link Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Add missing listener test Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Formatting Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix test typing Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix some coverage issue Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> --------- Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> Co-authored-by: Cody Fincher <204685+cofin@users.noreply.github.com> * test skip sourcery Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * test(channels): Improve channels testing (#2838) * Improve channels testing --------- Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> Co-authored-by: Cody Fincher <204685+cofin@users.noreply.github.com> * chore(typing): various pyright issues (#2897) Fix various pyright issues Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * fix(channels): Trailing messages after unsubscribes (#2894) Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * feat(cli): Add ``--schema`` and ``--exclude`` option to route CLI. (#2886) * add exclude and schema cli options to route. * updates per linting, mypy, and etc. * fix some more mypy stuff. * fix issue with linting. * add doc for route cli options. * fix issue with python3.8 not liking dict type. * add exclude and schema cli options to route. * updates per linting, mypy, and etc. * fix some more mypy stuff. * fix issue with linting. * add doc for route cli options. * fix issue with python3.8 not liking dict type. * Update docs/usage/cli.rst * Update litestar/cli/_utils.py * fix malformed docs table. --------- Co-authored-by: Jacob Coffee * test(CLI): Fix xdist issue (#2931) Fix test for xdist Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * feat(core): Replace `anyio.to_thread.run_sync` with native versions (#2937) Replace anyio.to_thread.run_sync with native versions Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * feat(channels): Postgres backends (#2803) * wip Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * some debugging Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * formatting Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * use a separate connection to publish/listen Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * reintroduce flaky Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix typing Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Add psycopg backend Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix backend issues Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Undo test debugging changes Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * mark groups Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Ensure channel names ar quoted Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * sleep debugging Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * update docs Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Add missing test Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix docs link Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Add missing listener test Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Formatting Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix test typing Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix some coverage issue Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> --------- Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> Co-authored-by: Cody Fincher <204685+cofin@users.noreply.github.com> * test skip sourcery Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * test(channels): Improve channels testing (#2838) * Improve channels testing --------- Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> Co-authored-by: Cody Fincher <204685+cofin@users.noreply.github.com> * chore(typing): various pyright issues (#2897) Fix various pyright issues Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * fix(channels): Trailing messages after unsubscribes (#2894) Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * feat(cli): Add ``--schema`` and ``--exclude`` option to route CLI. (#2886) * add exclude and schema cli options to route. * updates per linting, mypy, and etc. * fix some more mypy stuff. * fix issue with linting. * add doc for route cli options. * fix issue with python3.8 not liking dict type. * add exclude and schema cli options to route. * updates per linting, mypy, and etc. * fix some more mypy stuff. * fix issue with linting. * add doc for route cli options. * fix issue with python3.8 not liking dict type. * Update docs/usage/cli.rst * Update litestar/cli/_utils.py * fix malformed docs table. --------- Co-authored-by: Jacob Coffee * test(CLI): Fix xdist issue (#2931) Fix test for xdist Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * feat(core): Replace `anyio.to_thread.run_sync` with native versions (#2937) Replace anyio.to_thread.run_sync with native versions Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * feat: structlog detects tty by default * chore: linting fixes * fix: color code correction * fix: adjusted color code to be more visible * fix: additional config settings * feat: enable pretty-print in TTY * fix: apply rich configuration * fix: updated formatting to align with other messages * chore: trim whitespace --------- Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> Co-authored-by: Janek Nouvertné Co-authored-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> Co-authored-by: andrew do Co-authored-by: Jacob Coffee --- docs/usage/logging.rst | 8 +- litestar/app.py | 8 +- litestar/logging/config.py | 123 +++++++++++++++--- litestar/plugins/structlog.py | 61 +++++++++ litestar/response/sse.py | 3 +- pdm.lock | 48 +++---- test_apps/structlog_app/__init__.py | 0 test_apps/structlog_app/main.py | 26 ++++ .../test_logging/test_structlog_config.py | 15 +++ .../test_exception_handler_middleware.py | 2 +- 10 files changed, 244 insertions(+), 50 deletions(-) create mode 100644 litestar/plugins/structlog.py create mode 100644 test_apps/structlog_app/__init__.py create mode 100644 test_apps/structlog_app/main.py diff --git a/docs/usage/logging.rst b/docs/usage/logging.rst index ca35ea782d..c39861aea4 100644 --- a/docs/usage/logging.rst +++ b/docs/usage/logging.rst @@ -113,12 +113,12 @@ Using StructLog ^^^^^^^^^^^^^^^ `StructLog `_ is a powerful structured-logging library. Litestar ships with a dedicated -logging config for using it: +logging plugin and config for using it: .. code-block:: python from litestar import Litestar, Request, get - from litestar.logging import StructLoggingConfig + from litestar.plugins.structlog import StructlogPlugin @get("/") @@ -127,9 +127,9 @@ logging config for using it: return None - logging_config = StructLoggingConfig() + structlog_plugin = StructlogPlugin() - app = Litestar(route_handlers=[my_router_handler], logging_config=logging_config) + app = Litestar(route_handlers=[my_router_handler], plugins=[StructlogPlugin()]) Subclass Logging Configs ^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/litestar/app.py b/litestar/app.py index 9ac90f2ddf..2d3decae3e 100644 --- a/litestar/app.py +++ b/litestar/app.py @@ -508,7 +508,13 @@ def debug(self) -> bool: @debug.setter def debug(self, value: bool) -> None: - if self.logger: + """Sets the debug logging level for the application. + + When possible, it calls the `self.logging_config.set_level` method. This allows for implementation specific code and APIs to be called. + """ + if self.logger and self.logging_config: + self.logging_config.set_level(self.logger, logging.DEBUG if value else logging.INFO) + elif self.logger and hasattr(self.logger, "setLevel"): self.logger.setLevel(logging.DEBUG if value else logging.INFO) if isinstance(self.logging_config, LoggingConfig): self.logging_config.loggers["litestar"]["level"] = "DEBUG" if value else "INFO" diff --git a/litestar/logging/config.py b/litestar/logging/config.py index 4b5e64d67d..6b2e0ecebb 100644 --- a/litestar/logging/config.py +++ b/litestar/logging/config.py @@ -100,7 +100,7 @@ def _default_exception_logging_handler(logger: Logger, scope: Scope, tb: list[st if is_struct_logger: logger.exception( - "uncaught exception", + "Uncaught Exception", connection_type=scope["type"], path=scope["path"], traceback="".join(tb[-traceback_line_limit:]), @@ -135,6 +135,11 @@ def configure(self) -> GetLogger: """ raise NotImplementedError("abstract method") + @staticmethod + def set_level(logger: Any, level: int) -> None: + """Provides a consistent interface to call `setLevel` for all loggers.""" + raise NotImplementedError("abstract method") + @dataclass class LoggingConfig(BaseLoggingConfig): @@ -234,12 +239,17 @@ def configure(self) -> GetLogger: config.dictConfig(values) return cast("Callable[[str], Logger]", getLogger) + @staticmethod + def set_level(logger: Logger, level: int) -> None: + """Provides a consistent interface to call `setLevel` for all loggers.""" + logger.setLevel(level) + def default_json_serializer(value: Any, default: Callable[[Any], Any] | None = None) -> bytes: return encode_json(value=value, serializer=default) -def default_structlog_processors() -> list[Processor] | None: # pyright: ignore +def default_structlog_processors(as_json: bool = True) -> list[Processor]: # pyright: ignore """Set the default processors for structlog. Returns: @@ -247,19 +257,57 @@ def default_structlog_processors() -> list[Processor] | None: # pyright: ignore """ try: import structlog - + from structlog.dev import RichTracebackFormatter + + if as_json: + return [ + structlog.contextvars.merge_contextvars, + structlog.processors.add_log_level, + structlog.processors.format_exc_info, + structlog.processors.TimeStamper(fmt="iso"), + structlog.processors.JSONRenderer(serializer=default_json_serializer), + ] return [ structlog.contextvars.merge_contextvars, structlog.processors.add_log_level, - structlog.processors.format_exc_info, structlog.processors.TimeStamper(fmt="iso"), - structlog.processors.JSONRenderer(serializer=default_json_serializer), + structlog.dev.ConsoleRenderer( + colors=True, exception_formatter=RichTracebackFormatter(max_frames=1, show_locals=False, width=80) + ), ] + except ImportError: - return None + return [] + + +def default_structlog_standard_lib_processors(as_json: bool = True) -> list[Processor]: # pyright: ignore + """Set the default processors for structlog stdlib. + + Returns: + An optional list of processors. + """ + try: + import structlog + from structlog.dev import RichTracebackFormatter + + if as_json: + return [ + structlog.stdlib.add_log_level, + structlog.stdlib.ExtraAdder(), + structlog.processors.JSONRenderer(serializer=default_json_serializer), + ] + return [ + structlog.stdlib.add_log_level, + structlog.stdlib.ExtraAdder(), + structlog.dev.ConsoleRenderer( + colors=True, exception_formatter=RichTracebackFormatter(max_frames=1, show_locals=False, width=80) + ), + ] + except ImportError: + return [] -def default_wrapper_class() -> type[BindableLogger] | None: # pyright: ignore +def default_wrapper_class(log_level: int = INFO) -> type[BindableLogger] | None: # pyright: ignore """Set the default wrapper class for structlog. Returns: @@ -269,12 +317,12 @@ def default_wrapper_class() -> type[BindableLogger] | None: # pyright: ignore try: import structlog - return structlog.make_filtering_bound_logger(INFO) + return structlog.make_filtering_bound_logger(log_level) except ImportError: return None -def default_logger_factory() -> Callable[..., WrappedLogger] | None: +def default_logger_factory(as_json: bool = True) -> Callable[..., WrappedLogger] | None: """Set the default logger factory for structlog. Returns: @@ -283,7 +331,9 @@ def default_logger_factory() -> Callable[..., WrappedLogger] | None: try: import structlog - return structlog.BytesLoggerFactory() + if as_json: + return structlog.BytesLoggerFactory() + return structlog.WriteLoggerFactory() except ImportError: return None @@ -296,13 +346,18 @@ class StructLoggingConfig(BaseLoggingConfig): - requires ``structlog`` to be installed. """ - processors: list[Processor] | None = field(default_factory=default_structlog_processors) # pyright: ignore + processors: list[Processor] | None = field(default=None) # pyright: ignore """Iterable of structlog logging processors.""" - wrapper_class: type[BindableLogger] | None = field(default_factory=default_wrapper_class) # pyright: ignore + standard_lib_logging_config: LoggingConfig | None = field(default=None) # pyright: ignore + """Optional customized standard logging configuration. + + Use this when you need to modify the standard library outside of the Structlog pre-configured implementation. + """ + wrapper_class: type[BindableLogger] | None = field(default=None) # pyright: ignore """Structlog bindable logger.""" context_class: dict[str, Any] | None = None """Context class (a 'contextvar' context) for the logger.""" - logger_factory: Callable[..., WrappedLogger] | None = field(default_factory=default_logger_factory) + logger_factory: Callable[..., WrappedLogger] | None = field(default=None) # pyright: ignore """Logger factory to use.""" cache_logger_on_first_use: bool = field(default=True) """Whether to cache the logger configuration and reuse.""" @@ -312,12 +367,34 @@ class StructLoggingConfig(BaseLoggingConfig): """Max number of lines to print for exception traceback""" exception_logging_handler: ExceptionLoggingHandler | None = field(default=None) """Handler function for logging exceptions.""" + pretty_print_tty: bool = field(default=True) + """Pretty print log output when run from an interactive terminal.""" def __post_init__(self) -> None: + if self.processors is None: + self.processors = default_structlog_processors(not sys.stderr.isatty() and self.pretty_print_tty) + if self.logger_factory is None: + self.logger_factory = default_logger_factory(not sys.stderr.isatty() and self.pretty_print_tty) if self.log_exceptions != "never" and self.exception_logging_handler is None: self.exception_logging_handler = _default_exception_logging_handler_factory( is_struct_logger=True, traceback_line_limit=self.traceback_line_limit ) + try: + import structlog + + if self.standard_lib_logging_config is None: + self.standard_lib_logging_config = LoggingConfig( + formatters={ + "standard": { + "()": structlog.stdlib.ProcessorFormatter, + "processors": default_structlog_standard_lib_processors( + as_json=not sys.stderr.isatty() and self.pretty_print_tty + ), + } + } + ) + except ImportError: + self.standard_lib_logging_config = LoggingConfig() def configure(self) -> GetLogger: """Return logger with the given configuration. @@ -326,13 +403,11 @@ def configure(self) -> GetLogger: A 'logging.getLogger' like function. """ try: - import structlog # noqa: F401 + import structlog except ImportError as e: raise MissingDependencyException("structlog") from e - from structlog import configure, get_logger - - configure( + structlog.configure( **{ k: v for k, v in asdict(self).items() @@ -342,7 +417,19 @@ def configure(self) -> GetLogger: "log_exceptions", "traceback_line_limit", "exception_logging_handler", + "pretty_print_tty", ) } ) - return get_logger + return structlog.get_logger + + @staticmethod + def set_level(logger: Logger, level: int) -> None: + """Provides a consistent interface to call `setLevel` for all loggers.""" + + try: + import structlog + + structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(level)) + except ImportError: + """""" diff --git a/litestar/plugins/structlog.py b/litestar/plugins/structlog.py new file mode 100644 index 0000000000..0ff53fbfd4 --- /dev/null +++ b/litestar/plugins/structlog.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from litestar.cli._utils import console +from litestar.logging.config import StructLoggingConfig +from litestar.middleware.logging import LoggingMiddlewareConfig +from litestar.plugins import CLIPluginProtocol, InitPluginProtocol + +if TYPE_CHECKING: + from click import Group + + from litestar.config.app import AppConfig + + +@dataclass +class StructlogConfig: + structlog_logging_config: StructLoggingConfig = field(default_factory=StructLoggingConfig) + """Structlog Logging configuration for Litestar. See ``litestar.logging.config.StructLoggingConfig``` for details.""" + middleware_logging_config: LoggingMiddlewareConfig = field(default_factory=LoggingMiddlewareConfig) + """Middleware logging config.""" + enable_middleware_logging: bool = True + """Enable request logging.""" + + +class StructlogPlugin(InitPluginProtocol, CLIPluginProtocol): + """Structlog Plugin.""" + + __slots__ = ("_config",) + + def __init__(self, config: StructlogConfig | None = None) -> None: + if config is None: + config = StructlogConfig() + self._config = config + super().__init__() + + def on_cli_init(self, cli: Group) -> None: + return super().on_cli_init(cli) + + def on_app_init(self, app_config: AppConfig) -> AppConfig: + """Structlog Plugin + + Args: + app_config: The :class:`AppConfig ` instance. + + Returns: + The app config object. + """ + if app_config.logging_config is not None and isinstance(app_config.logging_config, StructLoggingConfig): + console.print( + "[red dim]* Found pre-configured `StructLoggingConfig` on the `app` instance. Skipping configuration.[/]", + ) + else: + app_config.logging_config = self._config.structlog_logging_config + app_config.logging_config.configure() + if self._config.structlog_logging_config.standard_lib_logging_config is not None: + self._config.structlog_logging_config.standard_lib_logging_config.configure() + if self._config.enable_middleware_logging: + app_config.middleware.append(self._config.middleware_logging_config.middleware) + return app_config # pragma: no cover diff --git a/litestar/response/sse.py b/litestar/response/sse.py index 7770929f9b..59576c5208 100644 --- a/litestar/response/sse.py +++ b/litestar/response/sse.py @@ -85,8 +85,7 @@ async def _async_generator(self) -> AsyncGenerator[bytes, None]: yield await sync_to_thread(self._call_next) except ValueError: async for value in self.content_async_iterator: - data = self.ensure_bytes(value, DEFAULT_SEPARATOR) - yield data + yield self.ensure_bytes(value, DEFAULT_SEPARATOR) break diff --git a/pdm.lock b/pdm.lock index 2876d76103..b5d0c25969 100644 --- a/pdm.lock +++ b/pdm.lock @@ -1087,12 +1087,12 @@ files = [ [[package]] name = "fsspec" -version = "2023.12.0" +version = "2023.12.1" requires_python = ">=3.8" summary = "File-system specification" files = [ - {file = "fsspec-2023.12.0-py3-none-any.whl", hash = "sha256:f807252ee2018f2223760315beb87a2166c2b9532786eeca9e6548dfcf2cfac9"}, - {file = "fsspec-2023.12.0.tar.gz", hash = "sha256:8e0bb2db2a94082968483b7ba2eaebf3949835e2dfdf09243dda387539464b31"}, + {file = "fsspec-2023.12.1-py3-none-any.whl", hash = "sha256:6271f1d3075a378bfe432f6f42bf7e1d2a6ba74f78dd9b512385474c579146a0"}, + {file = "fsspec-2023.12.1.tar.gz", hash = "sha256:c4da01a35ac65c853f833e43f67802c25213f560820d54ddf248f92eddd5e990"}, ] [[package]] @@ -1426,12 +1426,12 @@ files = [ [[package]] name = "identify" -version = "2.5.32" +version = "2.5.33" requires_python = ">=3.8" summary = "File identification library for Python" files = [ - {file = "identify-2.5.32-py2.py3-none-any.whl", hash = "sha256:0b7656ef6cba81664b783352c73f8c24b39cf82f926f78f4550eda928e5e0545"}, - {file = "identify-2.5.32.tar.gz", hash = "sha256:5d9979348ec1a21c768ae07e0a652924538e8bce67313a73cb0f681cf08ba407"}, + {file = "identify-2.5.33-py2.py3-none-any.whl", hash = "sha256:d40ce5fcd762817627670da8a7d8d8e65f24342d14539c59488dc603bf662e34"}, + {file = "identify-2.5.33.tar.gz", hash = "sha256:161558f9fe4559e1557e1bff323e8631f6a0e4837f7497767c1782832f16b62d"}, ] [[package]] @@ -3051,27 +3051,27 @@ files = [ [[package]] name = "ruff" -version = "0.1.6" +version = "0.1.7" requires_python = ">=3.7" summary = "An extremely fast Python linter and code formatter, written in Rust." files = [ - {file = "ruff-0.1.6-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:88b8cdf6abf98130991cbc9f6438f35f6e8d41a02622cc5ee130a02a0ed28703"}, - {file = "ruff-0.1.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:5c549ed437680b6105a1299d2cd30e4964211606eeb48a0ff7a93ef70b902248"}, - {file = "ruff-0.1.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cf5f701062e294f2167e66d11b092bba7af6a057668ed618a9253e1e90cfd76"}, - {file = "ruff-0.1.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:05991ee20d4ac4bb78385360c684e4b417edd971030ab12a4fbd075ff535050e"}, - {file = "ruff-0.1.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:87455a0c1f739b3c069e2f4c43b66479a54dea0276dd5d4d67b091265f6fd1dc"}, - {file = "ruff-0.1.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:683aa5bdda5a48cb8266fcde8eea2a6af4e5700a392c56ea5fb5f0d4bfdc0240"}, - {file = "ruff-0.1.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:137852105586dcbf80c1717facb6781555c4e99f520c9c827bd414fac67ddfb6"}, - {file = "ruff-0.1.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bd98138a98d48a1c36c394fd6b84cd943ac92a08278aa8ac8c0fdefcf7138f35"}, - {file = "ruff-0.1.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a0cd909d25f227ac5c36d4e7e681577275fb74ba3b11d288aff7ec47e3ae745"}, - {file = "ruff-0.1.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e8fd1c62a47aa88a02707b5dd20c5ff20d035d634aa74826b42a1da77861b5ff"}, - {file = "ruff-0.1.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:fd89b45d374935829134a082617954120d7a1470a9f0ec0e7f3ead983edc48cc"}, - {file = "ruff-0.1.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:491262006e92f825b145cd1e52948073c56560243b55fb3b4ecb142f6f0e9543"}, - {file = "ruff-0.1.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:ea284789861b8b5ca9d5443591a92a397ac183d4351882ab52f6296b4fdd5462"}, - {file = "ruff-0.1.6-py3-none-win32.whl", hash = "sha256:1610e14750826dfc207ccbcdd7331b6bd285607d4181df9c1c6ae26646d6848a"}, - {file = "ruff-0.1.6-py3-none-win_amd64.whl", hash = "sha256:4558b3e178145491e9bc3b2ee3c4b42f19d19384eaa5c59d10acf6e8f8b57e33"}, - {file = "ruff-0.1.6-py3-none-win_arm64.whl", hash = "sha256:03910e81df0d8db0e30050725a5802441c2022ea3ae4fe0609b76081731accbc"}, - {file = "ruff-0.1.6.tar.gz", hash = "sha256:1b09f29b16c6ead5ea6b097ef2764b42372aebe363722f1605ecbcd2b9207184"}, + {file = "ruff-0.1.7-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:7f80496854fdc65b6659c271d2c26e90d4d401e6a4a31908e7e334fab4645aac"}, + {file = "ruff-0.1.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:1ea109bdb23c2a4413f397ebd8ac32cb498bee234d4191ae1a310af760e5d287"}, + {file = "ruff-0.1.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b0c2de9dd9daf5e07624c24add25c3a490dbf74b0e9bca4145c632457b3b42a"}, + {file = "ruff-0.1.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:69a4bed13bc1d5dabf3902522b5a2aadfebe28226c6269694283c3b0cecb45fd"}, + {file = "ruff-0.1.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de02ca331f2143195a712983a57137c5ec0f10acc4aa81f7c1f86519e52b92a1"}, + {file = "ruff-0.1.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:45b38c3f8788a65e6a2cab02e0f7adfa88872696839d9882c13b7e2f35d64c5f"}, + {file = "ruff-0.1.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6c64cb67b2025b1ac6d58e5ffca8f7b3f7fd921f35e78198411237e4f0db8e73"}, + {file = "ruff-0.1.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9dcc6bb2f4df59cb5b4b40ff14be7d57012179d69c6565c1da0d1f013d29951b"}, + {file = "ruff-0.1.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df2bb4bb6bbe921f6b4f5b6fdd8d8468c940731cb9406f274ae8c5ed7a78c478"}, + {file = "ruff-0.1.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:276a89bcb149b3d8c1b11d91aa81898fe698900ed553a08129b38d9d6570e717"}, + {file = "ruff-0.1.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:90c958fe950735041f1c80d21b42184f1072cc3975d05e736e8d66fc377119ea"}, + {file = "ruff-0.1.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6b05e3b123f93bb4146a761b7a7d57af8cb7384ccb2502d29d736eaade0db519"}, + {file = "ruff-0.1.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:290ecab680dce94affebefe0bbca2322a6277e83d4f29234627e0f8f6b4fa9ce"}, + {file = "ruff-0.1.7-py3-none-win32.whl", hash = "sha256:416dfd0bd45d1a2baa3b1b07b1b9758e7d993c256d3e51dc6e03a5e7901c7d80"}, + {file = "ruff-0.1.7-py3-none-win_amd64.whl", hash = "sha256:4af95fd1d3b001fc41325064336db36e3d27d2004cdb6d21fd617d45a172dd96"}, + {file = "ruff-0.1.7-py3-none-win_arm64.whl", hash = "sha256:0683b7bfbb95e6df3c7c04fe9d78f631f8e8ba4868dfc932d43d690698057e2e"}, + {file = "ruff-0.1.7.tar.gz", hash = "sha256:dffd699d07abf54833e5f6cc50b85a6ff043715da8788c4a79bcd4ab4734d306"}, ] [[package]] diff --git a/test_apps/structlog_app/__init__.py b/test_apps/structlog_app/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test_apps/structlog_app/main.py b/test_apps/structlog_app/main.py new file mode 100644 index 0000000000..be98ae2150 --- /dev/null +++ b/test_apps/structlog_app/main.py @@ -0,0 +1,26 @@ +from typing import Dict + +from litestar import Litestar, Request, get +from litestar.logging.config import StructLoggingConfig +from litestar.middleware.logging import LoggingMiddlewareConfig + + +@get("/") +async def handler(request: Request) -> Dict[str, str]: + request.logger.info("Logging in the handler") + return {"hello": "world"} + + +logging_middleware_config = LoggingMiddlewareConfig() + +app = Litestar( + route_handlers=[handler], + logging_config=StructLoggingConfig(log_exceptions="always"), + middleware=[logging_middleware_config.middleware], +) + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app) diff --git a/tests/unit/test_logging/test_structlog_config.py b/tests/unit/test_logging/test_structlog_config.py index 4983a8f92b..62f53f0f9a 100644 --- a/tests/unit/test_logging/test_structlog_config.py +++ b/tests/unit/test_logging/test_structlog_config.py @@ -3,6 +3,7 @@ from structlog.types import BindableLogger from litestar.logging.config import StructLoggingConfig, default_json_serializer +from litestar.plugins.structlog import StructlogPlugin from litestar.serialization import decode_json from litestar.testing import create_test_client @@ -10,6 +11,20 @@ # Because we want to test processors, use capsys instead +def test_structlog_plugin(capsys: CaptureFixture) -> None: + with create_test_client([], plugins=[StructlogPlugin()]) as client: + assert client.app.logger + assert isinstance(client.app.logger.bind(), BindableLogger) + client.app.logger.info("message", key="value") + + log_messages = [decode_json(value=x) for x in capsys.readouterr().out.splitlines()] + assert len(log_messages) == 1 + + # Format should be: {event: message, key: value, level: info, timestamp: isoformat} + log_messages[0].pop("timestamp") # Assume structlog formats timestamp correctly + assert log_messages[0] == {"event": "message", "key": "value", "level": "info"} + + def test_structlog_config_default(capsys: CaptureFixture) -> None: with create_test_client([], logging_config=StructLoggingConfig()) as client: assert client.app.logger diff --git a/tests/unit/test_middleware/test_exception_handler_middleware.py b/tests/unit/test_middleware/test_exception_handler_middleware.py index cb53adc88a..85cbe4c84b 100644 --- a/tests/unit/test_middleware/test_exception_handler_middleware.py +++ b/tests/unit/test_middleware/test_exception_handler_middleware.py @@ -233,7 +233,7 @@ def handler() -> None: assert cap_logs[0].get("connection_type") == "http" assert cap_logs[0].get("path") == "/test" assert cap_logs[0].get("traceback") - assert cap_logs[0].get("event") == "uncaught exception" + assert cap_logs[0].get("event") == "Uncaught Exception" assert cap_logs[0].get("log_level") == "error" else: assert not cap_logs From 3ec3ba7bbeb938e5e16884dfa54a5bfe501fbbe2 Mon Sep 17 00:00:00 2001 From: guacs <126393040+guacs@users.noreply.github.com> Date: Mon, 8 Jan 2024 18:00:55 +0530 Subject: [PATCH 03/14] feat: allow using custom `CompressionFacade` implementations (#2952) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: move the different compressions into their own implementations * feat: check if encoding given by the facade is accepted * Bump version to 2.6.0 Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * feat: allow use of custom CompressionFacade implementations * feat: allow gzip to be used as fallback for any backend * docs: clarify 'backend' * test: rename test to a clearer name * docs: add docstring for 'CompressionFacade.encoding' * fix: explicitly specify facade type * fix: only import BrotliCompression if backend is brotli If the backend is not brotli, then the user may not have installed brotli which would result in an incorrect MissingDependency exception. --------- Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> Co-authored-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> Co-authored-by: Janek Nouvertné Co-authored-by: Cody Fincher <204685+cofin@users.noreply.github.com> --- litestar/config/compression.py | 38 +++++-- litestar/middleware/compression/__init__.py | 4 + .../middleware/compression/brotli_facade.py | 51 ++++++++++ litestar/middleware/compression/facade.py | 47 +++++++++ .../middleware/compression/gzip_facade.py | 32 ++++++ .../middleware.py} | 98 ++++--------------- .../test_compression_middleware.py | 39 +++++++- 7 files changed, 216 insertions(+), 93 deletions(-) create mode 100644 litestar/middleware/compression/__init__.py create mode 100644 litestar/middleware/compression/brotli_facade.py create mode 100644 litestar/middleware/compression/facade.py create mode 100644 litestar/middleware/compression/gzip_facade.py rename litestar/middleware/{compression.py => compression/middleware.py} (65%) diff --git a/litestar/config/compression.py b/litestar/config/compression.py index 2d6ecf1491..c339329144 100644 --- a/litestar/config/compression.py +++ b/litestar/config/compression.py @@ -1,10 +1,14 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Literal +from typing import TYPE_CHECKING, Any, Literal from litestar.exceptions import ImproperlyConfiguredException from litestar.middleware.compression import CompressionMiddleware +from litestar.middleware.compression.gzip_facade import GzipCompression + +if TYPE_CHECKING: + from litestar.middleware.compression.facade import CompressionFacade __all__ = ("CompressionConfig",) @@ -17,8 +21,11 @@ class CompressionConfig: using the ``compression_config`` key. """ - backend: Literal["gzip", "brotli"] - """Literal of "gzip" or "brotli".""" + backend: Literal["gzip", "brotli"] | str + """The backend to use. + + If the value given is `gzip` or `brotli`, then the builtin gzip and brotli compression is used. + """ minimum_size: int = field(default=500) """Minimum response size (bytes) to enable compression, affects all backends.""" gzip_compress_level: int = field(default=9) @@ -48,16 +55,29 @@ class CompressionConfig: """A pattern or list of patterns to skip in the compression middleware.""" exclude_opt_key: str | None = None """An identifier to use on routes to disable compression for a particular route.""" + compression_facade: type[CompressionFacade] = GzipCompression + """The compression facade to use for the actual compression.""" + backend_config: Any = None + """Configuration specific to the backend.""" + gzip_fallback: bool = True + """Use GZIP as a fallback if the provided backend is not supported by the client.""" def __post_init__(self) -> None: if self.minimum_size <= 0: raise ImproperlyConfiguredException("minimum_size must be greater than 0") - if self.gzip_compress_level < 0 or self.gzip_compress_level > 9: - raise ImproperlyConfiguredException("gzip_compress_level must be a value between 0 and 9") + if self.backend == "gzip": + if self.gzip_compress_level < 0 or self.gzip_compress_level > 9: + raise ImproperlyConfiguredException("gzip_compress_level must be a value between 0 and 9") + elif self.backend == "brotli": + # Brotli is not guaranteed to be installed. + from litestar.middleware.compression.brotli_facade import BrotliCompression + + if self.brotli_quality < 0 or self.brotli_quality > 11: + raise ImproperlyConfiguredException("brotli_quality must be a value between 0 and 11") - if self.brotli_quality < 0 or self.brotli_quality > 11: - raise ImproperlyConfiguredException("brotli_quality must be a value between 0 and 11") + if self.brotli_lgwin < 10 or self.brotli_lgwin > 24: + raise ImproperlyConfiguredException("brotli_lgwin must be a value between 10 and 24") - if self.brotli_lgwin < 10 or self.brotli_lgwin > 24: - raise ImproperlyConfiguredException("brotli_lgwin must be a value between 10 and 24") + self.gzip_fallback = self.brotli_gzip_fallback + self.compression_facade = BrotliCompression diff --git a/litestar/middleware/compression/__init__.py b/litestar/middleware/compression/__init__.py new file mode 100644 index 0000000000..0885932dd0 --- /dev/null +++ b/litestar/middleware/compression/__init__.py @@ -0,0 +1,4 @@ +from litestar.middleware.compression.facade import CompressionFacade +from litestar.middleware.compression.middleware import CompressionMiddleware + +__all__ = ("CompressionMiddleware", "CompressionFacade") diff --git a/litestar/middleware/compression/brotli_facade.py b/litestar/middleware/compression/brotli_facade.py new file mode 100644 index 0000000000..3d01950a45 --- /dev/null +++ b/litestar/middleware/compression/brotli_facade.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from litestar.enums import CompressionEncoding +from litestar.exceptions import MissingDependencyException +from litestar.middleware.compression.facade import CompressionFacade + +try: + from brotli import MODE_FONT, MODE_GENERIC, MODE_TEXT, Compressor +except ImportError as e: + raise MissingDependencyException("brotli") from e + + +if TYPE_CHECKING: + from io import BytesIO + + from litestar.config.compression import CompressionConfig + + +class BrotliCompression(CompressionFacade): + __slots__ = ("compressor", "buffer", "compression_encoding") + + encoding = CompressionEncoding.BROTLI + + def __init__( + self, + buffer: BytesIO, + compression_encoding: Literal[CompressionEncoding.BROTLI] | str, + config: CompressionConfig, + ) -> None: + self.buffer = buffer + self.compression_encoding = compression_encoding + modes: dict[Literal["generic", "text", "font"], int] = { + "text": int(MODE_TEXT), + "font": int(MODE_FONT), + "generic": int(MODE_GENERIC), + } + self.compressor = Compressor( + quality=config.brotli_quality, + mode=modes[config.brotli_mode], + lgwin=config.brotli_lgwin, + lgblock=config.brotli_lgblock, + ) + + def write(self, body: bytes) -> None: + self.buffer.write(self.compressor.process(body)) + self.buffer.write(self.compressor.flush()) + + def close(self) -> None: + self.buffer.write(self.compressor.finish()) diff --git a/litestar/middleware/compression/facade.py b/litestar/middleware/compression/facade.py new file mode 100644 index 0000000000..0074b57419 --- /dev/null +++ b/litestar/middleware/compression/facade.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, ClassVar, Protocol + +if TYPE_CHECKING: + from io import BytesIO + + from litestar.config.compression import CompressionConfig + from litestar.enums import CompressionEncoding + + +class CompressionFacade(Protocol): + """A unified facade offering a uniform interface for different compression libraries.""" + + encoding: ClassVar[str] + """The encoding of the compression.""" + + def __init__( + self, buffer: BytesIO, compression_encoding: CompressionEncoding | str, config: CompressionConfig + ) -> None: + """Initialize ``CompressionFacade``. + + Args: + buffer: A bytes IO buffer to write the compressed data into. + compression_encoding: The compression encoding used. + config: The app compression config. + """ + ... + + def write(self, body: bytes) -> None: + """Write compressed bytes. + + Args: + body: Message body to process + + Returns: + None + """ + ... + + def close(self) -> None: + """Close the compression stream. + + Returns: + None + """ + ... diff --git a/litestar/middleware/compression/gzip_facade.py b/litestar/middleware/compression/gzip_facade.py new file mode 100644 index 0000000000..b10ef73991 --- /dev/null +++ b/litestar/middleware/compression/gzip_facade.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from gzip import GzipFile +from typing import TYPE_CHECKING, Literal + +from litestar.enums import CompressionEncoding +from litestar.middleware.compression.facade import CompressionFacade + +if TYPE_CHECKING: + from io import BytesIO + + from litestar.config.compression import CompressionConfig + + +class GzipCompression(CompressionFacade): + __slots__ = ("compressor", "buffer", "compression_encoding") + + encoding = CompressionEncoding.GZIP + + def __init__( + self, buffer: BytesIO, compression_encoding: Literal[CompressionEncoding.GZIP] | str, config: CompressionConfig + ) -> None: + self.buffer = buffer + self.compression_encoding = compression_encoding + self.compressor = GzipFile(mode="wb", fileobj=buffer, compresslevel=config.gzip_compress_level) + + def write(self, body: bytes) -> None: + self.compressor.write(body) + self.compressor.flush() + + def close(self) -> None: + self.compressor.close() diff --git a/litestar/middleware/compression.py b/litestar/middleware/compression/middleware.py similarity index 65% rename from litestar/middleware/compression.py rename to litestar/middleware/compression/middleware.py index c5cd860dda..7ea7853b08 100644 --- a/litestar/middleware/compression.py +++ b/litestar/middleware/compression/middleware.py @@ -1,18 +1,18 @@ from __future__ import annotations -from gzip import GzipFile from io import BytesIO from typing import TYPE_CHECKING, Any, Literal from litestar.datastructures import Headers, MutableScopeHeaders from litestar.enums import CompressionEncoding, ScopeType -from litestar.exceptions import MissingDependencyException from litestar.middleware.base import AbstractMiddleware +from litestar.middleware.compression.gzip_facade import GzipCompression from litestar.utils.empty import value_or_default from litestar.utils.scope.state import ScopeState if TYPE_CHECKING: from litestar.config.compression import CompressionConfig + from litestar.middleware.compression.facade import CompressionFacade from litestar.types import ( ASGIApp, HTTPResponseStartEvent, @@ -27,76 +27,6 @@ except ImportError: Compressor = Any -__all__ = ("CompressionFacade", "CompressionMiddleware") - - -class CompressionFacade: - """A unified facade offering a uniform interface for different compression libraries.""" - - __slots__ = ("compressor", "buffer", "compression_encoding") - - compressor: GzipFile | Compressor # pyright: ignore - - def __init__(self, buffer: BytesIO, compression_encoding: CompressionEncoding, config: CompressionConfig) -> None: - """Initialize ``CompressionFacade``. - - Args: - buffer: A bytes IO buffer to write the compressed data into. - compression_encoding: The compression encoding used. - config: The app compression config. - """ - self.buffer = buffer - self.compression_encoding = compression_encoding - - if compression_encoding == CompressionEncoding.BROTLI: - try: - import brotli # noqa: F401 - except ImportError as e: - raise MissingDependencyException("brotli") from e - - from brotli import MODE_FONT, MODE_GENERIC, MODE_TEXT, Compressor - - modes: dict[Literal["generic", "text", "font"], int] = { - "text": int(MODE_TEXT), - "font": int(MODE_FONT), - "generic": int(MODE_GENERIC), - } - self.compressor = Compressor( - quality=config.brotli_quality, - mode=modes[config.brotli_mode], - lgwin=config.brotli_lgwin, - lgblock=config.brotli_lgblock, - ) - else: - self.compressor = GzipFile(mode="wb", fileobj=buffer, compresslevel=config.gzip_compress_level) - - def write(self, body: bytes) -> None: - """Write compressed bytes. - - Args: - body: Message body to process - - Returns: - None - """ - - if self.compression_encoding == CompressionEncoding.BROTLI: - self.buffer.write(self.compressor.process(body) + self.compressor.flush()) # type: ignore - else: - self.compressor.write(body) - self.compressor.flush() - - def close(self) -> None: - """Close the compression stream. - - Returns: - None - """ - if self.compression_encoding == CompressionEncoding.BROTLI: - self.buffer.write(self.compressor.finish()) # type: ignore - else: - self.compressor.close() - class CompressionMiddleware(AbstractMiddleware): """Compression Middleware Wrapper. @@ -128,20 +58,19 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: None """ accept_encoding = Headers.from_scope(scope).get("accept-encoding", "") + config = self.config - if CompressionEncoding.BROTLI in accept_encoding and self.config.backend == "brotli": + if config.compression_facade.encoding in accept_encoding: await self.app( scope, receive, self.create_compression_send_wrapper( - send=send, compression_encoding=CompressionEncoding.BROTLI, scope=scope + send=send, compression_encoding=config.compression_facade.encoding, scope=scope ), ) return - if CompressionEncoding.GZIP in accept_encoding and ( - self.config.backend == "gzip" or self.config.brotli_gzip_fallback - ): + if config.gzip_fallback and CompressionEncoding.GZIP in accept_encoding: await self.app( scope, receive, @@ -156,7 +85,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: def create_compression_send_wrapper( self, send: Send, - compression_encoding: Literal[CompressionEncoding.BROTLI, CompressionEncoding.GZIP], + compression_encoding: Literal[CompressionEncoding.BROTLI, CompressionEncoding.GZIP] | str, scope: Scope, ) -> Send: """Wrap ``send`` to handle brotli compression. @@ -170,13 +99,20 @@ def create_compression_send_wrapper( An ASGI send function. """ bytes_buffer = BytesIO() - facade = CompressionFacade(buffer=bytes_buffer, compression_encoding=compression_encoding, config=self.config) + + facade: CompressionFacade + # We can't use `self.config.compression_facade` directly if the compression is `gzip` since + # it may be being used as a fallback. + if compression_encoding == CompressionEncoding.GZIP: + facade = GzipCompression(buffer=bytes_buffer, compression_encoding=compression_encoding, config=self.config) + else: + facade = self.config.compression_facade( + buffer=bytes_buffer, compression_encoding=compression_encoding, config=self.config + ) initial_message: HTTPResponseStartEvent | None = None started = False - _own_encoding = compression_encoding.encode("latin-1") - connection_state = ScopeState.from_scope(scope) async def send_wrapper(message: Message) -> None: diff --git a/tests/unit/test_middleware/test_compression_middleware.py b/tests/unit/test_middleware/test_compression_middleware.py index e1c3ec08bb..c71784f81e 100644 --- a/tests/unit/test_middleware/test_compression_middleware.py +++ b/tests/unit/test_middleware/test_compression_middleware.py @@ -1,4 +1,6 @@ -from typing import AsyncIterator, Callable, Literal +import zlib +from io import BytesIO +from typing import AsyncIterator, Callable, Literal, Union from unittest.mock import MagicMock import pytest @@ -9,6 +11,7 @@ from litestar.exceptions import ImproperlyConfiguredException from litestar.handlers import HTTPRouteHandler from litestar.middleware.compression import CompressionMiddleware +from litestar.middleware.compression.facade import CompressionFacade from litestar.response.streaming import Stream from litestar.status_codes import HTTP_200_OK from litestar.testing import create_test_client @@ -146,9 +149,9 @@ def test_config_minimum_size_validation(minimum_size: int, should_raise: bool) - def test_config_gzip_compress_level_validation(gzip_compress_level: int, should_raise: bool) -> None: if should_raise: with pytest.raises(ImproperlyConfiguredException): - CompressionConfig(backend="brotli", brotli_gzip_fallback=False, gzip_compress_level=gzip_compress_level) + CompressionConfig(backend="gzip", brotli_gzip_fallback=False, gzip_compress_level=gzip_compress_level) else: - CompressionConfig(backend="brotli", brotli_gzip_fallback=False, gzip_compress_level=gzip_compress_level) + CompressionConfig(backend="gzip", brotli_gzip_fallback=False, gzip_compress_level=gzip_compress_level) @pytest.mark.parametrize("brotli_quality, should_raise", ((0, False), (1, False), (-1, True), (12, True), (11, False))) @@ -216,3 +219,33 @@ def handler_fn() -> str: assert response.text == "_litestar_" * 4000 assert response.headers["Content-Encoding"] == compression_encoding assert int(response.headers["Content-Length"]) < 40000 + + +def test_compression_with_custom_backend(handler: HTTPRouteHandler) -> None: + class ZlibCompression(CompressionFacade): + encoding = "deflate" + + def __init__( + self, + buffer: BytesIO, + compression_encoding: Union[Literal[CompressionEncoding.GZIP], str], + config: CompressionConfig, + ) -> None: + self.buffer = buffer + self.compression_encoding = compression_encoding + self.config = config + + def write(self, body: bytes) -> None: + self.buffer.write(zlib.compress(body, level=self.config.backend_config["level"])) + + def close(self) -> None: + ... + + zlib_config = {"level": 9} + config = CompressionConfig(backend="deflate", compression_facade=ZlibCompression, backend_config=zlib_config) + with create_test_client([handler], compression_config=config) as client: + response = client.get("/", headers={"Accept-Encoding": "deflate"}) + assert response.status_code == HTTP_200_OK + assert response.text == "_litestar_" * 4000 + assert response.headers["Content-Encoding"] == "deflate" + assert int(response.headers["Content-Length"]) < 40000 From 283ef423e968e4932bcf596e2f2dbe02a88ad149 Mon Sep 17 00:00:00 2001 From: FergusMok Date: Fri, 12 Jan 2024 08:21:56 +0800 Subject: [PATCH 04/14] feat: Add `reload-include` and `reload-exclude` from uvicorn to CLI (#2973) * Add reload-include and reload-exclude feature * Update documentation * Fix tests, update documentation --- docs/usage/cli.rst | 38 ++++++++++++++++++++ litestar/cli/_utils.py | 6 ++++ litestar/cli/commands/core.py | 20 ++++++++++- tests/unit/test_cli/test_core_commands.py | 43 +++++++++++++++++------ 4 files changed, 95 insertions(+), 12 deletions(-) diff --git a/docs/usage/cli.rst b/docs/usage/cli.rst index 2360962691..b52463d315 100644 --- a/docs/usage/cli.rst +++ b/docs/usage/cli.rst @@ -107,6 +107,10 @@ Options +-------------------------------------------+----------------------------------------------+----------------------------------------------------------------------------------+ | ``-R``\ , ``--reload-dir`` | ``LITESTAR_RELOAD_DIRS`` | Specify directories to watch for reload. | +-------------------------------------------+----------------------------------------------+----------------------------------------------------------------------------------+ +| ``-I``\ , ``--reload-include`` | ``LITESTAR_RELOAD_INCLUDES`` | Specify glob patterns for files to include when watching for reload. | ++-------------------------------------------+----------------------------------------------+----------------------------------------------------------------------------------+ +| ``-E``\ , ``--reload-exclude`` | ``LITESTAR_RELOAD_EXCLUDES`` | Specify glob patterns for files to exclude when watching for reload. | ++-------------------------------------------+----------------------------------------------+----------------------------------------------------------------------------------+ | ``-p``\ , ``--port`` | ``LITESTAR_PORT`` | Bind the server to this port [default: 8000] | +-------------------------------------------+----------------------------------------------+----------------------------------------------------------------------------------+ | ``--wc``\ , ``--web-concurrency`` | ``WEB_CONCURRENCY`` | The number of concurrent web workers to start [default: 1] | @@ -143,6 +147,40 @@ To set multiple directories via an environment variable, use a comma-separated l LITESTAR_RELOAD_DIRS=.,../other-library/src +--reload-include +++++++++++++++++ + +The ``--reload-include`` flag allows you to specify glob patterns to include when watching for file changes. If you specify this flag, the ``--reload`` flag is implied. Furthermore, ``.py`` files are included implicitly by default. + +You can specify multiple glob patterns by passing the flag multiple times: + +.. code-block:: shell + + litestar run --reload-include="*.rst" --reload-include="*.yml" + +To set multiple directories via an environment variable, use a comma-separated list: + +.. code-block:: shell + + LITESTAR_RELOAD_INCLUDES=*.rst,*.yml + +--reload-exclude +++++++++++++++++ + +The ``--reload-exclude`` flag allows you to specify glob patterns to exclude when watching for file changes. If you specify this flag, the ``--reload`` flag is implied. + +You can specify multiple glob patterns by passing the flag multiple times: + +.. code-block:: shell + + litestar run --reload-exclude="*.py" --reload-exclude="*.yml" + +To set multiple directories via an environment variable, use a comma-separated list: + +.. code-block:: shell + + LITESTAR_RELOAD_EXCLUDES=*.py,*.yml + SSL +++ diff --git a/litestar/cli/_utils.py b/litestar/cli/_utils.py index 4559f53a47..c3f35689d6 100644 --- a/litestar/cli/_utils.py +++ b/litestar/cli/_utils.py @@ -85,6 +85,8 @@ class LitestarEnv: uds: str | None = None reload: bool | None = None reload_dirs: tuple[str, ...] | None = None + reload_include: tuple[str, ...] | None = None + reload_exclude: tuple[str, ...] | None = None web_concurrency: int | None = None is_app_factory: bool = False certfile_path: str | None = None @@ -120,6 +122,8 @@ def from_env(cls, app_path: str | None, app_dir: Path | None = None) -> Litestar uds = getenv("LITESTAR_UNIX_DOMAIN_SOCKET") fd = getenv("LITESTAR_FILE_DESCRIPTOR") reload_dirs = tuple(s.strip() for s in getenv("LITESTAR_RELOAD_DIRS", "").split(",") if s) or None + reload_include = tuple(s.strip() for s in getenv("LITESTAR_RELOAD_INCLUDES", "").split(",") if s) or None + reload_exclude = tuple(s.strip() for s in getenv("LITESTAR_RELOAD_EXCLUDES", "").split(",") if s) or None return cls( app_path=loaded_app.app_path, @@ -131,6 +135,8 @@ def from_env(cls, app_path: str | None, app_dir: Path | None = None) -> Litestar fd=int(fd) if fd else None, reload=_bool_from_env("LITESTAR_RELOAD"), reload_dirs=reload_dirs, + reload_include=reload_include, + reload_exclude=reload_exclude, web_concurrency=int(web_concurrency) if web_concurrency else None, is_app_factory=loaded_app.is_factory, cwd=cwd, diff --git a/litestar/cli/commands/core.py b/litestar/cli/commands/core.py index a77abe053b..5b552533b1 100644 --- a/litestar/cli/commands/core.py +++ b/litestar/cli/commands/core.py @@ -69,6 +69,8 @@ def _run_uvicorn_in_subprocess( workers: int | None, reload: bool, reload_dirs: tuple[str, ...] | None, + reload_include: tuple[str, ...] | None, + reload_exclude: tuple[str, ...] | None, fd: int | None, uds: str | None, certfile_path: str | None, @@ -87,6 +89,10 @@ def _run_uvicorn_in_subprocess( process_args["uds"] = uds if reload_dirs: process_args["reload-dir"] = reload_dirs + if reload_include: + process_args["reload-include"] = reload_include + if reload_exclude: + process_args["reload-exclude"] = reload_exclude if certfile_path is not None: process_args["ssl-certfile"] = certfile_path if keyfile_path is not None: @@ -116,6 +122,12 @@ def info_command(app: Litestar) -> None: @command(name="run") @option("-r", "--reload", help="Reload server on changes", default=False, is_flag=True) @option("-R", "--reload-dir", help="Directories to watch for file changes", multiple=True) +@option( + "-I", "--reload-include", help="Glob patterns for files to include when watching for file changes", multiple=True +) +@option( + "-E", "--reload-exclude", help="Glob patterns for files to exclude when watching for file changes", multiple=True +) @option("-p", "--port", help="Serve under this port", type=int, default=8000, show_default=True) @option( "-W", @@ -155,6 +167,8 @@ def run_command( uds: str | None, debug: bool, reload_dir: tuple[str, ...], + reload_include: tuple[str, ...], + reload_exclude: tuple[str, ...], pdb: bool, ssl_certfile: str | None, ssl_keyfile: str | None, @@ -194,12 +208,14 @@ def run_command( app = env.app reload_dirs = env.reload_dirs or reload_dir + reload_include = env.reload_include or reload_include + reload_exclude = env.reload_exclude or reload_exclude host = env.host or host port = env.port if env.port is not None else port fd = env.fd if env.fd is not None else fd uds = env.uds or uds - reload = env.reload or reload or bool(reload_dirs) + reload = env.reload or reload or bool(reload_dirs) or bool(reload_include) or bool(reload_exclude) workers = env.web_concurrency or wc ssl_certfile = ssl_certfile or env.certfile_path @@ -248,6 +264,8 @@ def run_command( workers=workers, reload=reload, reload_dirs=reload_dirs, + reload_include=reload_include, + reload_exclude=reload_exclude, fd=fd, uds=uds, certfile_path=certfile_path, diff --git a/tests/unit/test_cli/test_core_commands.py b/tests/unit/test_cli/test_core_commands.py index 8330604aee..46c89d68d1 100644 --- a/tests/unit/test_cli/test_core_commands.py +++ b/tests/unit/test_cli/test_core_commands.py @@ -40,16 +40,19 @@ def mock_show_app_info(mocker: MockerFixture) -> MagicMock: @pytest.mark.parametrize("custom_app_file,", [Path("my_app.py"), None]) @pytest.mark.parametrize("app_dir", ["custom_subfolder", None]) @pytest.mark.parametrize( - "reload, reload_dir, web_concurrency", + "reload, reload_dir, reload_include, reload_exclude, web_concurrency", [ - (None, None, None), - (True, None, None), - (False, None, None), - (True, [".", "../somewhere_else"], None), - (False, [".", "../somewhere_else"], None), - (None, None, 2), - (True, None, 2), - (False, None, 2), + (None, None, None, None, None), + (True, None, None, None, None), + (False, None, None, None, None), + (True, [".", "../somewhere_else"], None, None, None), + (False, [".", "../somewhere_else"], None, None, None), + (True, None, ["*.rst", "*.yml"], None, None), + (False, None, None, ["*.py"], None), + (False, None, ["*.yml", "*.rst"], None, None), + (None, None, None, None, 2), + (True, None, None, None, 2), + (False, None, None, None, 2), ], ) def test_run_command( @@ -64,6 +67,8 @@ def test_run_command( web_concurrency: Optional[int], app_dir: Optional[str], reload_dir: Optional[List[str]], + reload_include: Optional[List[str]], + reload_exclude: Optional[List[str]], custom_app_file: Optional[Path], create_app_file: CreateAppFileFixture, set_in_env: bool, @@ -131,6 +136,18 @@ def test_run_command( else: args.extend([f"--reload-dir={s}" for s in reload_dir]) + if reload_include is not None: + if set_in_env: + monkeypatch.setenv("LITESTAR_RELOAD_INCLUDES", ",".join(reload_include)) + else: + args.extend([f"--reload-include={s}" for s in reload_include]) + + if reload_exclude is not None: + if set_in_env: + monkeypatch.setenv("LITESTAR_RELOAD_EXCLUDES", ",".join(reload_exclude)) + else: + args.extend([f"--reload-exclude={s}" for s in reload_exclude]) + path = create_app_file(custom_app_file or "app.py", directory=app_dir) result = runner.invoke(cli_command, args) @@ -138,7 +155,7 @@ def test_run_command( assert result.exception is None assert result.exit_code == 0 - if reload or reload_dir or web_concurrency > 1: + if reload or reload_dir or reload_include or reload_exclude or web_concurrency > 1: expected_args = [ sys.executable, "-m", @@ -151,12 +168,16 @@ def test_run_command( expected_args.append(f"--fd={fd}") if uds is not None: expected_args.append(f"--uds={uds}") - if reload or reload_dir: + if reload or reload_dir or reload_include or reload_exclude: expected_args.append("--reload") if web_concurrency: expected_args.append(f"--workers={web_concurrency}") if reload_dir: expected_args.extend([f"--reload-dir={s}" for s in reload_dir]) + if reload_include: + expected_args.extend([f"--reload-include={s}" for s in reload_include]) + if reload_exclude: + expected_args.extend([f"--reload-exclude={s}" for s in reload_exclude]) mock_subprocess_run.assert_called_once() assert sorted(mock_subprocess_run.call_args_list[0].args[0]) == sorted(expected_args) else: From 1c28ce718ae5994f203139f1904f44f2f224da7c Mon Sep 17 00:00:00 2001 From: Cody Fincher <204685+cofin@users.noreply.github.com> Date: Sat, 13 Jan 2024 12:36:35 -0600 Subject: [PATCH 05/14] feat: allow `root` logger configuration to be disabled (#2969) --- litestar/logging/config.py | 17 +++++++++----- .../unit/test_logging/test_logging_config.py | 22 ++++++++++++++++++- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/litestar/logging/config.py b/litestar/logging/config.py index 6b2e0ecebb..5d17dcb756 100644 --- a/litestar/logging/config.py +++ b/litestar/logging/config.py @@ -192,6 +192,8 @@ class LoggingConfig(BaseLoggingConfig): Processing of the configuration will be as for any logger, except that the propagate setting will not be applicable. """ + configure_root_logger: bool = field(default=True) + """Should the root logger be configured, defaults to True for ease of configuration.""" log_exceptions: Literal["always", "debug", "never"] = field(default="debug") """Should exceptions be logged, defaults to log exceptions when 'app.debug == True'""" traceback_line_limit: int = field(default=20) @@ -224,18 +226,21 @@ def configure(self) -> GetLogger: if "picologging" in str(encode_json(self.handlers)): try: - import picologging # noqa: F401 + from picologging import config, getLogger except ImportError as e: raise MissingDependencyException("picologging") from e - from picologging import config, getLogger - - values = {k: v for k, v in asdict(self).items() if v is not None and k != "incremental"} + values = { + k: v + for k, v in asdict(self).items() + if v is not None and k not in ("incremental", "configure_root_logger") + } else: from logging import config, getLogger # type: ignore[no-redef, assignment] - values = {k: v for k, v in asdict(self).items() if v is not None} - + values = {k: v for k, v in asdict(self).items() if v is not None and k not in ("configure_root_logger",)} + if not self.configure_root_logger: + values.pop("root") config.dictConfig(values) return cast("Callable[[str], Logger]", getLogger) diff --git a/tests/unit/test_logging/test_logging_config.py b/tests/unit/test_logging/test_logging_config.py index f86ad0809c..e850305d5a 100644 --- a/tests/unit/test_logging/test_logging_config.py +++ b/tests/unit/test_logging/test_logging_config.py @@ -144,6 +144,27 @@ def test_root_logger(handlers: Any, listener: Any) -> None: assert isinstance(root_logger.handlers[0], listener) # type: ignore +@pytest.mark.parametrize( + "handlers, listener", + [ + [default_handlers, StandardQueueListenerHandler], + [default_picologging_handlers, PicologgingQueueListenerHandler], + ], +) +def test_root_logger_no_config(handlers: Any, listener: Any) -> None: + logging_config = LoggingConfig(handlers=handlers, configure_root_logger=False) + get_logger = logging_config.configure() + root_logger = get_logger() + for handler in root_logger.handlers: # type: ignore[attr-defined] + root_logger.removeHandler(handler) # type: ignore[attr-defined] + get_logger = logging_config.configure() + root_logger = get_logger() + if handlers["console"]["class"] == "logging.StreamHandler": + assert not isinstance(root_logger.handlers[0], listener) # type: ignore[attr-defined] + else: + assert len(root_logger.handlers) < 1 # type: ignore[attr-defined] + + @pytest.mark.parametrize( "handlers, listener", [ @@ -159,7 +180,6 @@ def test_root_logger(handlers: Any, listener: Any) -> None: ) def test_customizing_handler(handlers: Any, listener: Any, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setitem(handlers["queue_listener"], "handlers", ["cfg://handlers.console"]) - logging_config = LoggingConfig(handlers=handlers) get_logger = logging_config.configure() root_logger = get_logger() From d501aa75381485131e74a6de6e8e7ca7851bce9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janek=20Nouvertn=C3=A9?= Date: Sat, 13 Jan 2024 19:59:55 +0100 Subject: [PATCH 06/14] docs: Fix reference error (#2983) Fix doc reference error --- docs/conf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/conf.py b/docs/conf.py index e9a00e4fb0..2bcb2c679e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -178,6 +178,7 @@ ("py:exc", "InternalServerError"), ("py:exc", "HTTPExceptions"), (PY_CLASS, "litestar.template.Template"), + (PY_CLASS, "litestar.middleware.compression.gzip_facade.GzipCompression"), ] nitpick_ignore_regex = [ From e1e6c7c42ae29f0b1b09168f13e9d478ac89bd5b Mon Sep 17 00:00:00 2001 From: Cody Fincher <204685+cofin@users.noreply.github.com> Date: Sat, 20 Jan 2024 08:50:06 -0800 Subject: [PATCH 07/14] fix: additional `structlog` fixes (#2985) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Bump version to 2.6.0 Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * feat: `structlog` plugin & bug fixes (#2943) * feat(channels): Postgres backends (#2803) * wip Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * some debugging Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * formatting Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * use a separate connection to publish/listen Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * reintroduce flaky Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix typing Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Add psycopg backend Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix backend issues Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Undo test debugging changes Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * mark groups Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Ensure channel names ar quoted Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * sleep debugging Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * update docs Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Add missing test Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix docs link Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Add missing listener test Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Formatting Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix test typing Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix some coverage issue Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> --------- Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> Co-authored-by: Cody Fincher <204685+cofin@users.noreply.github.com> * test skip sourcery Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * test(channels): Improve channels testing (#2838) * Improve channels testing --------- Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> Co-authored-by: Cody Fincher <204685+cofin@users.noreply.github.com> * chore(typing): various pyright issues (#2897) Fix various pyright issues Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * fix(channels): Trailing messages after unsubscribes (#2894) Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * feat(cli): Add ``--schema`` and ``--exclude`` option to route CLI. (#2886) * add exclude and schema cli options to route. * updates per linting, mypy, and etc. * fix some more mypy stuff. * fix issue with linting. * add doc for route cli options. * fix issue with python3.8 not liking dict type. * add exclude and schema cli options to route. * updates per linting, mypy, and etc. * fix some more mypy stuff. * fix issue with linting. * add doc for route cli options. * fix issue with python3.8 not liking dict type. * Update docs/usage/cli.rst * Update litestar/cli/_utils.py * fix malformed docs table. --------- Co-authored-by: Jacob Coffee * test(CLI): Fix xdist issue (#2931) Fix test for xdist Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * feat(core): Replace `anyio.to_thread.run_sync` with native versions (#2937) Replace anyio.to_thread.run_sync with native versions Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * feat: example app using structlog * fix: updated structlog with request logging * fix: lazy initialized structlog fix * feat: add structlog plugin * fix: adds `set_level` to all Logging configurations * fix: check that the object has the `setLevel` method before calling * feat: adds test for plugin * fix: parameter naming for `set_level` abstract method * feat(channels): Postgres backends (#2803) * wip Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * some debugging Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * formatting Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * use a separate connection to publish/listen Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * reintroduce flaky Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix typing Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Add psycopg backend Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix backend issues Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Undo test debugging changes Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * mark groups Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Ensure channel names ar quoted Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * sleep debugging Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * update docs Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Add missing test Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix docs link Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Add missing listener test Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Formatting Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix test typing Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix some coverage issue Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> --------- Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> Co-authored-by: Cody Fincher <204685+cofin@users.noreply.github.com> * test skip sourcery Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * test(channels): Improve channels testing (#2838) * Improve channels testing --------- Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> Co-authored-by: Cody Fincher <204685+cofin@users.noreply.github.com> * chore(typing): various pyright issues (#2897) Fix various pyright issues Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * fix(channels): Trailing messages after unsubscribes (#2894) Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * feat(cli): Add ``--schema`` and ``--exclude`` option to route CLI. (#2886) * add exclude and schema cli options to route. * updates per linting, mypy, and etc. * fix some more mypy stuff. * fix issue with linting. * add doc for route cli options. * fix issue with python3.8 not liking dict type. * add exclude and schema cli options to route. * updates per linting, mypy, and etc. * fix some more mypy stuff. * fix issue with linting. * add doc for route cli options. * fix issue with python3.8 not liking dict type. * Update docs/usage/cli.rst * Update litestar/cli/_utils.py * fix malformed docs table. --------- Co-authored-by: Jacob Coffee * test(CLI): Fix xdist issue (#2931) Fix test for xdist Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * feat(core): Replace `anyio.to_thread.run_sync` with native versions (#2937) Replace anyio.to_thread.run_sync with native versions Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * feat(channels): Postgres backends (#2803) * wip Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * some debugging Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * formatting Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * use a separate connection to publish/listen Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * reintroduce flaky Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix typing Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Add psycopg backend Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix backend issues Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Undo test debugging changes Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * mark groups Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Ensure channel names ar quoted Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * sleep debugging Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * update docs Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Add missing test Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix docs link Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Add missing listener test Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Formatting Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix test typing Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * Fix some coverage issue Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> --------- Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> Co-authored-by: Cody Fincher <204685+cofin@users.noreply.github.com> * test skip sourcery Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * test(channels): Improve channels testing (#2838) * Improve channels testing --------- Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> Co-authored-by: Cody Fincher <204685+cofin@users.noreply.github.com> * chore(typing): various pyright issues (#2897) Fix various pyright issues Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * fix(channels): Trailing messages after unsubscribes (#2894) Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * feat(cli): Add ``--schema`` and ``--exclude`` option to route CLI. (#2886) * add exclude and schema cli options to route. * updates per linting, mypy, and etc. * fix some more mypy stuff. * fix issue with linting. * add doc for route cli options. * fix issue with python3.8 not liking dict type. * add exclude and schema cli options to route. * updates per linting, mypy, and etc. * fix some more mypy stuff. * fix issue with linting. * add doc for route cli options. * fix issue with python3.8 not liking dict type. * Update docs/usage/cli.rst * Update litestar/cli/_utils.py * fix malformed docs table. --------- Co-authored-by: Jacob Coffee * test(CLI): Fix xdist issue (#2931) Fix test for xdist Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * feat(core): Replace `anyio.to_thread.run_sync` with native versions (#2937) Replace anyio.to_thread.run_sync with native versions Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * feat: structlog detects tty by default * chore: linting fixes * fix: color code correction * fix: adjusted color code to be more visible * fix: additional config settings * feat: enable pretty-print in TTY * fix: apply rich configuration * fix: updated formatting to align with other messages * chore: trim whitespace --------- Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> Co-authored-by: Janek Nouvertné Co-authored-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> Co-authored-by: andrew do Co-authored-by: Jacob Coffee * feat: allow using custom `CompressionFacade` implementations (#2952) * refactor: move the different compressions into their own implementations * feat: check if encoding given by the facade is accepted * Bump version to 2.6.0 Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * feat: allow use of custom CompressionFacade implementations * feat: allow gzip to be used as fallback for any backend * docs: clarify 'backend' * test: rename test to a clearer name * docs: add docstring for 'CompressionFacade.encoding' * fix: explicitly specify facade type * fix: only import BrotliCompression if backend is brotli If the backend is not brotli, then the user may not have installed brotli which would result in an incorrect MissingDependency exception. --------- Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> Co-authored-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> Co-authored-by: Janek Nouvertné Co-authored-by: Cody Fincher <204685+cofin@users.noreply.github.com> * feat: Add `reload-include` and `reload-exclude` from uvicorn to CLI (#2973) * Add reload-include and reload-exclude feature * Update documentation * Fix tests, update documentation * feat: allow `root` logger configuration to be disabled (#2969) * docs: Fix reference error (#2983) Fix doc reference error * fix: correctly render stdlib logs as string instead of bytes * feat: add missing timestamper to standard logging for structlog * feat: filter out `color_message` by default * feat: add nocover for dev logger * feat: adds test for `TTY` config of structlog * feat: increased coverage * fix: remove incorrect call to `get` a plugin * feat: additional coverage * fix: remove unnecessary mixin * feat: additional coverage * fix: add ignore on lines that are actually covered * feat: add deprecated function --------- Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> Co-authored-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> Co-authored-by: Janek Nouvertné Co-authored-by: andrew do Co-authored-by: Jacob Coffee Co-authored-by: guacs <126393040+guacs@users.noreply.github.com> Co-authored-by: FergusMok --- litestar/logging/config.py | 89 +++++-- litestar/plugins/structlog.py | 13 +- .../test_logging/test_structlog_config.py | 240 ++++++++++++++---- 3 files changed, 259 insertions(+), 83 deletions(-) diff --git a/litestar/logging/config.py b/litestar/logging/config.py index 5d17dcb756..d7733ac39d 100644 --- a/litestar/logging/config.py +++ b/litestar/logging/config.py @@ -9,15 +9,19 @@ from litestar.exceptions import ImproperlyConfiguredException, MissingDependencyException from litestar.serialization import encode_json +from litestar.serialization.msgspec_hooks import _msgspec_json_encoder +from litestar.utils.deprecation import deprecated __all__ = ("BaseLoggingConfig", "LoggingConfig", "StructLoggingConfig") if TYPE_CHECKING: + from collections.abc import Iterable from typing import NoReturn # these imports are duplicated on purpose so sphinx autodoc can find and link them from structlog.types import BindableLogger, Processor, WrappedLogger + from structlog.typing import EventDict from litestar.types import Logger, Scope from litestar.types.callable_types import ExceptionLoggingHandler, GetLogger @@ -250,8 +254,55 @@ def set_level(logger: Logger, level: int) -> None: logger.setLevel(level) -def default_json_serializer(value: Any, default: Callable[[Any], Any] | None = None) -> bytes: - return encode_json(value=value, serializer=default) +class StructlogEventFilter: + """Remove keys from the log event. + + Add an instance to the processor chain. + + .. code-block:: python + :caption: Examples + + structlog.configure( + ..., + processors=[ + ..., + EventFilter(["color_message"]), + ..., + ], + ) + + """ + + def __init__(self, filter_keys: Iterable[str]) -> None: + """Initialize the EventFilter. + + Args: + filter_keys: Iterable of string keys to be excluded from the log event. + """ + self.filter_keys = filter_keys + + def __call__(self, _: WrappedLogger, __: str, event_dict: EventDict) -> EventDict: + """Receive the log event, and filter keys. + + Args: + _ (): + __ (): + event_dict (): The data to be logged. + + Returns: + The log event with any key in `self.filter_keys` removed. + """ + for key in self.filter_keys: + event_dict.pop(key, None) + return event_dict + + +def default_json_serializer(value: EventDict, **_: Any) -> bytes: + return _msgspec_json_encoder.encode(value) + + +def stdlib_json_serializer(value: EventDict, **_: Any) -> str: # pragma: no cover + return _msgspec_json_encoder.encode(value).decode("utf-8") def default_structlog_processors(as_json: bool = True) -> list[Processor]: # pyright: ignore @@ -297,13 +348,19 @@ def default_structlog_standard_lib_processors(as_json: bool = True) -> list[Proc if as_json: return [ + structlog.processors.TimeStamper(fmt="iso"), structlog.stdlib.add_log_level, structlog.stdlib.ExtraAdder(), - structlog.processors.JSONRenderer(serializer=default_json_serializer), + StructlogEventFilter(["color_message"]), + structlog.stdlib.ProcessorFormatter.remove_processors_meta, + structlog.processors.JSONRenderer(serializer=stdlib_json_serializer), ] return [ + structlog.processors.TimeStamper(fmt="iso"), structlog.stdlib.add_log_level, structlog.stdlib.ExtraAdder(), + StructlogEventFilter(["color_message"]), + structlog.stdlib.ProcessorFormatter.remove_processors_meta, structlog.dev.ConsoleRenderer( colors=True, exception_formatter=RichTracebackFormatter(max_frames=1, show_locals=False, width=80) ), @@ -312,21 +369,6 @@ def default_structlog_standard_lib_processors(as_json: bool = True) -> list[Proc return [] -def default_wrapper_class(log_level: int = INFO) -> type[BindableLogger] | None: # pyright: ignore - """Set the default wrapper class for structlog. - - Returns: - An optional wrapper class. - """ - - try: - import structlog - - return structlog.make_filtering_bound_logger(log_level) - except ImportError: - return None - - def default_logger_factory(as_json: bool = True) -> Callable[..., WrappedLogger] | None: """Set the default logger factory for structlog. @@ -438,3 +480,14 @@ def set_level(logger: Logger, level: int) -> None: structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(level)) except ImportError: """""" + return + + +@deprecated(version="2.6.0", removal_in="3.0.0", alternative="`StructLoggingConfig.set_level`") +def default_wrapper_class(log_level: int = INFO) -> type[BindableLogger] | None: # pragma: no cover # pyright: ignore + try: # pragma: no cover + import structlog + + return structlog.make_filtering_bound_logger(log_level) + except ImportError: + return None diff --git a/litestar/plugins/structlog.py b/litestar/plugins/structlog.py index 0ff53fbfd4..fafa3dde8f 100644 --- a/litestar/plugins/structlog.py +++ b/litestar/plugins/structlog.py @@ -6,11 +6,9 @@ from litestar.cli._utils import console from litestar.logging.config import StructLoggingConfig from litestar.middleware.logging import LoggingMiddlewareConfig -from litestar.plugins import CLIPluginProtocol, InitPluginProtocol +from litestar.plugins import InitPluginProtocol if TYPE_CHECKING: - from click import Group - from litestar.config.app import AppConfig @@ -24,7 +22,7 @@ class StructlogConfig: """Enable request logging.""" -class StructlogPlugin(InitPluginProtocol, CLIPluginProtocol): +class StructlogPlugin(InitPluginProtocol): """Structlog Plugin.""" __slots__ = ("_config",) @@ -35,9 +33,6 @@ def __init__(self, config: StructlogConfig | None = None) -> None: self._config = config super().__init__() - def on_cli_init(self, cli: Group) -> None: - return super().on_cli_init(cli) - def on_app_init(self, app_config: AppConfig) -> AppConfig: """Structlog Plugin @@ -54,8 +49,8 @@ def on_app_init(self, app_config: AppConfig) -> AppConfig: else: app_config.logging_config = self._config.structlog_logging_config app_config.logging_config.configure() - if self._config.structlog_logging_config.standard_lib_logging_config is not None: - self._config.structlog_logging_config.standard_lib_logging_config.configure() + if self._config.structlog_logging_config.standard_lib_logging_config is not None: # pragma: no cover + self._config.structlog_logging_config.standard_lib_logging_config.configure() # pragma: no cover if self._config.enable_middleware_logging: app_config.middleware.append(self._config.middleware_logging_config.middleware) return app_config # pragma: no cover diff --git a/tests/unit/test_logging/test_structlog_config.py b/tests/unit/test_logging/test_structlog_config.py index 62f53f0f9a..e850305d5a 100644 --- a/tests/unit/test_logging/test_structlog_config.py +++ b/tests/unit/test_logging/test_structlog_config.py @@ -1,58 +1,186 @@ -from pytest import CaptureFixture -from structlog.processors import JSONRenderer -from structlog.types import BindableLogger - -from litestar.logging.config import StructLoggingConfig, default_json_serializer -from litestar.plugins.structlog import StructlogPlugin -from litestar.serialization import decode_json +import logging +import sys +from typing import TYPE_CHECKING, Any, Dict +from unittest.mock import Mock, patch + +import pytest + +from litestar import Request, get +from litestar.exceptions import ImproperlyConfiguredException +from litestar.logging.config import LoggingConfig, _get_default_handlers, default_handlers, default_picologging_handlers +from litestar.logging.picologging import QueueListenerHandler as PicologgingQueueListenerHandler +from litestar.logging.standard import QueueListenerHandler as StandardQueueListenerHandler +from litestar.status_codes import HTTP_200_OK from litestar.testing import create_test_client -# structlog.testing.capture_logs changes the processors -# Because we want to test processors, use capsys instead - - -def test_structlog_plugin(capsys: CaptureFixture) -> None: - with create_test_client([], plugins=[StructlogPlugin()]) as client: - assert client.app.logger - assert isinstance(client.app.logger.bind(), BindableLogger) - client.app.logger.info("message", key="value") - - log_messages = [decode_json(value=x) for x in capsys.readouterr().out.splitlines()] - assert len(log_messages) == 1 - - # Format should be: {event: message, key: value, level: info, timestamp: isoformat} - log_messages[0].pop("timestamp") # Assume structlog formats timestamp correctly - assert log_messages[0] == {"event": "message", "key": "value", "level": "info"} - - -def test_structlog_config_default(capsys: CaptureFixture) -> None: - with create_test_client([], logging_config=StructLoggingConfig()) as client: - assert client.app.logger - assert isinstance(client.app.logger.bind(), BindableLogger) - client.app.logger.info("message", key="value") - - log_messages = [decode_json(value=x) for x in capsys.readouterr().out.splitlines()] - assert len(log_messages) == 1 - - # Format should be: {event: message, key: value, level: info, timestamp: isoformat} - log_messages[0].pop("timestamp") # Assume structlog formats timestamp correctly - assert log_messages[0] == {"event": "message", "key": "value", "level": "info"} - - -def test_structlog_config_specify_processors(capsys: CaptureFixture) -> None: - logging_config = StructLoggingConfig(processors=[JSONRenderer(serializer=default_json_serializer)]) - - with create_test_client([], logging_config=logging_config) as client: - assert client.app.logger - assert isinstance(client.app.logger.bind(), BindableLogger) - - client.app.logger.info("message1", key="value1") - # Log twice to make sure issue #882 doesn't appear again - client.app.logger.info("message2", key="value2") - - log_messages = [decode_json(value=x) for x in capsys.readouterr().out.splitlines()] - - assert log_messages == [ - {"key": "value1", "event": "message1"}, - {"key": "value2", "event": "message2"}, - ] +if TYPE_CHECKING: + from _pytest.logging import LogCaptureFixture + + +@pytest.mark.parametrize( + "dict_config_class, handlers, expected_called", + [ + ["logging.config.dictConfig", default_handlers, True], + ["logging.config.dictConfig", default_picologging_handlers, False], + ["picologging.config.dictConfig", default_handlers, False], + ["picologging.config.dictConfig", default_picologging_handlers, True], + ], +) +def test_correct_dict_config_called( + dict_config_class: str, handlers: Dict[str, Dict[str, Any]], expected_called: bool +) -> None: + with patch(dict_config_class) as dict_config_mock: + log_config = LoggingConfig(handlers=handlers) + log_config.configure() + if expected_called: + assert dict_config_mock.called + else: + assert not dict_config_mock.called + + +@pytest.mark.parametrize("picologging_exists", [True, False]) +def test_correct_default_handlers_set(picologging_exists: bool) -> None: + with patch("litestar.logging.config.find_spec") as find_spec_mock: + find_spec_mock.return_value = picologging_exists + log_config = LoggingConfig() + + if picologging_exists: + assert log_config.handlers == default_picologging_handlers + else: + assert log_config.handlers == default_handlers + + +@pytest.mark.parametrize( + "dict_config_class, handlers", + [ + ["logging.config.dictConfig", default_handlers], + ["picologging.config.dictConfig", default_picologging_handlers], + ], +) +def test_dictconfig_startup(dict_config_class: str, handlers: Any) -> None: + with patch(dict_config_class) as dict_config_mock: + test_logger = LoggingConfig( + handlers=handlers, + ) + with create_test_client([], on_startup=[test_logger.configure]): + assert dict_config_mock.called + + +def test_standard_queue_listener_logger(caplog: "LogCaptureFixture") -> None: + with caplog.at_level("INFO", logger="test_logger"): + logger = logging.getLogger("test_logger") + logger.info("Testing now!") + assert "Testing now!" in caplog.text + var = "test_var" + logger.info("%s", var) + assert var in caplog.text + + +@patch("picologging.config.dictConfig") +def test_picologging_dictconfig_when_disabled(dict_config_mock: Mock) -> None: + test_logger = LoggingConfig(loggers={"app": {"level": "INFO", "handlers": ["console"]}}, handlers=default_handlers) + with create_test_client([], on_startup=[test_logger.configure], logging_config=None): + assert not dict_config_mock.called + + +def test_get_logger_without_logging_config() -> None: + with create_test_client(logging_config=None) as client: + with pytest.raises( + ImproperlyConfiguredException, + match="cannot call '.get_logger' without passing 'logging_config' to the Litestar constructor first", + ): + client.app.get_logger() + + +def test_get_default_logger() -> None: + with create_test_client(logging_config=LoggingConfig(handlers=default_handlers)) as client: + assert isinstance(client.app.logger.handlers[0], StandardQueueListenerHandler) + new_logger = client.app.get_logger() + assert isinstance(new_logger.handlers[0], StandardQueueListenerHandler) + + +def test_get_picologging_logger() -> None: + with create_test_client(logging_config=LoggingConfig(handlers=default_picologging_handlers)) as client: + assert isinstance(client.app.logger.handlers[0], PicologgingQueueListenerHandler) + new_logger = client.app.get_logger() + assert isinstance(new_logger.handlers[0], PicologgingQueueListenerHandler) + + +@pytest.mark.parametrize( + "handlers, listener", + [ + [default_handlers, StandardQueueListenerHandler], + [default_picologging_handlers, PicologgingQueueListenerHandler], + ], +) +def test_connection_logger(handlers: Any, listener: Any) -> None: + @get("/") + def handler(request: Request) -> Dict[str, bool]: + return {"isinstance": isinstance(request.logger.handlers[0], listener)} # type: ignore + + with create_test_client(route_handlers=[handler], logging_config=LoggingConfig(handlers=handlers)) as client: + response = client.get("/") + assert response.status_code == HTTP_200_OK + assert response.json()["isinstance"] + + +def test_validation() -> None: + logging_config = LoggingConfig(handlers={}, loggers={}) + assert logging_config.handlers["queue_listener"] == _get_default_handlers()["queue_listener"] + assert logging_config.loggers["litestar"] + + +@pytest.mark.parametrize( + "handlers, listener", + [ + [default_handlers, StandardQueueListenerHandler], + [default_picologging_handlers, PicologgingQueueListenerHandler], + ], +) +def test_root_logger(handlers: Any, listener: Any) -> None: + logging_config = LoggingConfig(handlers=handlers) + get_logger = logging_config.configure() + root_logger = get_logger() + assert isinstance(root_logger.handlers[0], listener) # type: ignore + + +@pytest.mark.parametrize( + "handlers, listener", + [ + [default_handlers, StandardQueueListenerHandler], + [default_picologging_handlers, PicologgingQueueListenerHandler], + ], +) +def test_root_logger_no_config(handlers: Any, listener: Any) -> None: + logging_config = LoggingConfig(handlers=handlers, configure_root_logger=False) + get_logger = logging_config.configure() + root_logger = get_logger() + for handler in root_logger.handlers: # type: ignore[attr-defined] + root_logger.removeHandler(handler) # type: ignore[attr-defined] + get_logger = logging_config.configure() + root_logger = get_logger() + if handlers["console"]["class"] == "logging.StreamHandler": + assert not isinstance(root_logger.handlers[0], listener) # type: ignore[attr-defined] + else: + assert len(root_logger.handlers) < 1 # type: ignore[attr-defined] + + +@pytest.mark.parametrize( + "handlers, listener", + [ + pytest.param( + default_handlers, + StandardQueueListenerHandler, + marks=pytest.mark.xfail( + condition=sys.version_info >= (3, 12), reason="change to QueueHandler/QueueListener config in 3.12" + ), + ), + [default_picologging_handlers, PicologgingQueueListenerHandler], + ], +) +def test_customizing_handler(handlers: Any, listener: Any, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(handlers["queue_listener"], "handlers", ["cfg://handlers.console"]) + logging_config = LoggingConfig(handlers=handlers) + get_logger = logging_config.configure() + root_logger = get_logger() + assert isinstance(root_logger.handlers[0], listener) # type: ignore From 53e872cd3971c9f1621774930c359e8b1bb08759 Mon Sep 17 00:00:00 2001 From: Cody Fincher <204685+cofin@users.noreply.github.com> Date: Sat, 20 Jan 2024 16:42:47 -0600 Subject: [PATCH 08/14] fix: re-add `structlog` tests (#3001) * fix: re-add structlog tests * fix: re-add hint on unreachable section --- litestar/app.py | 4 +- .../test_logging/test_structlog_config.py | 328 ++++++++---------- 2 files changed, 149 insertions(+), 183 deletions(-) diff --git a/litestar/app.py b/litestar/app.py index 2d3decae3e..cf1fc21ddd 100644 --- a/litestar/app.py +++ b/litestar/app.py @@ -514,8 +514,8 @@ def debug(self, value: bool) -> None: """ if self.logger and self.logging_config: self.logging_config.set_level(self.logger, logging.DEBUG if value else logging.INFO) - elif self.logger and hasattr(self.logger, "setLevel"): - self.logger.setLevel(logging.DEBUG if value else logging.INFO) + elif self.logger and hasattr(self.logger, "setLevel"): # pragma: no cover + self.logger.setLevel(logging.DEBUG if value else logging.INFO) # pragma: no cover if isinstance(self.logging_config, LoggingConfig): self.logging_config.loggers["litestar"]["level"] = "DEBUG" if value else "INFO" self._debug = value diff --git a/tests/unit/test_logging/test_structlog_config.py b/tests/unit/test_logging/test_structlog_config.py index e850305d5a..39adcc1e1a 100644 --- a/tests/unit/test_logging/test_structlog_config.py +++ b/tests/unit/test_logging/test_structlog_config.py @@ -1,186 +1,152 @@ -import logging -import sys -from typing import TYPE_CHECKING, Any, Dict -from unittest.mock import Mock, patch +from typing import Callable import pytest - -from litestar import Request, get -from litestar.exceptions import ImproperlyConfiguredException -from litestar.logging.config import LoggingConfig, _get_default_handlers, default_handlers, default_picologging_handlers -from litestar.logging.picologging import QueueListenerHandler as PicologgingQueueListenerHandler -from litestar.logging.standard import QueueListenerHandler as StandardQueueListenerHandler -from litestar.status_codes import HTTP_200_OK +import structlog +from pytest import CaptureFixture +from structlog import BytesLoggerFactory, get_logger +from structlog.processors import JSONRenderer +from structlog.types import BindableLogger, WrappedLogger + +from litestar.logging.config import LoggingConfig, StructlogEventFilter, StructLoggingConfig, default_json_serializer +from litestar.plugins.structlog import StructlogConfig, StructlogPlugin +from litestar.serialization import decode_json from litestar.testing import create_test_client -if TYPE_CHECKING: - from _pytest.logging import LogCaptureFixture - - -@pytest.mark.parametrize( - "dict_config_class, handlers, expected_called", - [ - ["logging.config.dictConfig", default_handlers, True], - ["logging.config.dictConfig", default_picologging_handlers, False], - ["picologging.config.dictConfig", default_handlers, False], - ["picologging.config.dictConfig", default_picologging_handlers, True], - ], -) -def test_correct_dict_config_called( - dict_config_class: str, handlers: Dict[str, Dict[str, Any]], expected_called: bool -) -> None: - with patch(dict_config_class) as dict_config_mock: - log_config = LoggingConfig(handlers=handlers) - log_config.configure() - if expected_called: - assert dict_config_mock.called - else: - assert not dict_config_mock.called - - -@pytest.mark.parametrize("picologging_exists", [True, False]) -def test_correct_default_handlers_set(picologging_exists: bool) -> None: - with patch("litestar.logging.config.find_spec") as find_spec_mock: - find_spec_mock.return_value = picologging_exists - log_config = LoggingConfig() - - if picologging_exists: - assert log_config.handlers == default_picologging_handlers - else: - assert log_config.handlers == default_handlers - - -@pytest.mark.parametrize( - "dict_config_class, handlers", - [ - ["logging.config.dictConfig", default_handlers], - ["picologging.config.dictConfig", default_picologging_handlers], - ], -) -def test_dictconfig_startup(dict_config_class: str, handlers: Any) -> None: - with patch(dict_config_class) as dict_config_mock: - test_logger = LoggingConfig( - handlers=handlers, +# structlog.testing.capture_logs changes the processors +# Because we want to test processors, use capsys instead + + +def test_event_filter() -> None: + """Functionality test for the event filter processor.""" + event_filter = StructlogEventFilter(["a_key"]) + log_event = {"a_key": "a_val", "b_key": "b_val"} + log_event = event_filter(..., "", log_event) # type:ignore[assignment] + assert log_event == {"b_key": "b_val"} + + +def test_set_level_custom_logger_factory() -> None: + """Functionality test for the event filter processor.""" + + def custom_logger_factory() -> Callable[..., WrappedLogger]: + """Set the default logger factory for structlog. + + Returns: + An optional logger factory. + """ + return BytesLoggerFactory() + + log_config = StructLoggingConfig(logger_factory=custom_logger_factory, wrapper_class=structlog.stdlib.BoundLogger) + logger = get_logger() + assert logger.bind().__class__.__name__ != "BoundLoggerFilteringAtDebug" + log_config.set_level(logger, 10) + logger.info("a message") + assert logger.bind().__class__.__name__ == "BoundLoggerFilteringAtDebug" + + +def test_structlog_plugin(capsys: CaptureFixture) -> None: + with create_test_client([], plugins=[StructlogPlugin()]) as client: + assert client.app.logger + assert isinstance(client.app.logger.bind(), BindableLogger) + client.app.logger.info("message", key="value") + + log_messages = [decode_json(value=x) for x in capsys.readouterr().out.splitlines()] + assert len(log_messages) == 1 + + # Format should be: {event: message, key: value, level: info, timestamp: isoformat} + log_messages[0].pop("timestamp") # Assume structlog formats timestamp correctly + assert log_messages[0] == {"event": "message", "key": "value", "level": "info"} + + +def test_structlog_plugin_config(capsys: CaptureFixture) -> None: + config = StructlogConfig() + with create_test_client([], plugins=[StructlogPlugin(config=config)]) as client: + assert client.app.logger + assert isinstance(client.app.logger.bind(), BindableLogger) + client.app.logger.info("message", key="value") + + log_messages = [decode_json(value=x) for x in capsys.readouterr().out.splitlines()] + assert len(log_messages) == 1 + assert client.app.plugins.get(StructlogPlugin)._config == config + + +def test_structlog_plugin_config_custom_standard_logger() -> None: + standard_logging_config = LoggingConfig() + structlog_logging_config = StructLoggingConfig(standard_lib_logging_config=standard_logging_config) + config = StructlogConfig(structlog_logging_config=structlog_logging_config) + with create_test_client([], plugins=[StructlogPlugin(config=config)]) as client: + assert client.app.plugins.get(StructlogPlugin)._config == config + assert ( + client.app.plugins.get(StructlogPlugin)._config.structlog_logging_config.standard_lib_logging_config + == standard_logging_config + ) + + +def test_structlog_plugin_config_custom() -> None: + structlog_logging_config = StructLoggingConfig(standard_lib_logging_config=None) + config = StructlogConfig(structlog_logging_config=structlog_logging_config) + with create_test_client([], plugins=[StructlogPlugin(config=config)]) as client: + assert client.app.plugins.get(StructlogPlugin)._config == config + assert client.app.plugins.get(StructlogPlugin)._config.structlog_logging_config == structlog_logging_config + assert ( + client.app.plugins.get(StructlogPlugin)._config.structlog_logging_config.standard_lib_logging_config + is not None ) - with create_test_client([], on_startup=[test_logger.configure]): - assert dict_config_mock.called - - -def test_standard_queue_listener_logger(caplog: "LogCaptureFixture") -> None: - with caplog.at_level("INFO", logger="test_logger"): - logger = logging.getLogger("test_logger") - logger.info("Testing now!") - assert "Testing now!" in caplog.text - var = "test_var" - logger.info("%s", var) - assert var in caplog.text - - -@patch("picologging.config.dictConfig") -def test_picologging_dictconfig_when_disabled(dict_config_mock: Mock) -> None: - test_logger = LoggingConfig(loggers={"app": {"level": "INFO", "handlers": ["console"]}}, handlers=default_handlers) - with create_test_client([], on_startup=[test_logger.configure], logging_config=None): - assert not dict_config_mock.called - - -def test_get_logger_without_logging_config() -> None: - with create_test_client(logging_config=None) as client: - with pytest.raises( - ImproperlyConfiguredException, - match="cannot call '.get_logger' without passing 'logging_config' to the Litestar constructor first", - ): - client.app.get_logger() - - -def test_get_default_logger() -> None: - with create_test_client(logging_config=LoggingConfig(handlers=default_handlers)) as client: - assert isinstance(client.app.logger.handlers[0], StandardQueueListenerHandler) - new_logger = client.app.get_logger() - assert isinstance(new_logger.handlers[0], StandardQueueListenerHandler) - - -def test_get_picologging_logger() -> None: - with create_test_client(logging_config=LoggingConfig(handlers=default_picologging_handlers)) as client: - assert isinstance(client.app.logger.handlers[0], PicologgingQueueListenerHandler) - new_logger = client.app.get_logger() - assert isinstance(new_logger.handlers[0], PicologgingQueueListenerHandler) - - -@pytest.mark.parametrize( - "handlers, listener", - [ - [default_handlers, StandardQueueListenerHandler], - [default_picologging_handlers, PicologgingQueueListenerHandler], - ], -) -def test_connection_logger(handlers: Any, listener: Any) -> None: - @get("/") - def handler(request: Request) -> Dict[str, bool]: - return {"isinstance": isinstance(request.logger.handlers[0], listener)} # type: ignore - - with create_test_client(route_handlers=[handler], logging_config=LoggingConfig(handlers=handlers)) as client: - response = client.get("/") - assert response.status_code == HTTP_200_OK - assert response.json()["isinstance"] - - -def test_validation() -> None: - logging_config = LoggingConfig(handlers={}, loggers={}) - assert logging_config.handlers["queue_listener"] == _get_default_handlers()["queue_listener"] - assert logging_config.loggers["litestar"] - - -@pytest.mark.parametrize( - "handlers, listener", - [ - [default_handlers, StandardQueueListenerHandler], - [default_picologging_handlers, PicologgingQueueListenerHandler], - ], -) -def test_root_logger(handlers: Any, listener: Any) -> None: - logging_config = LoggingConfig(handlers=handlers) - get_logger = logging_config.configure() - root_logger = get_logger() - assert isinstance(root_logger.handlers[0], listener) # type: ignore - - -@pytest.mark.parametrize( - "handlers, listener", - [ - [default_handlers, StandardQueueListenerHandler], - [default_picologging_handlers, PicologgingQueueListenerHandler], - ], -) -def test_root_logger_no_config(handlers: Any, listener: Any) -> None: - logging_config = LoggingConfig(handlers=handlers, configure_root_logger=False) - get_logger = logging_config.configure() - root_logger = get_logger() - for handler in root_logger.handlers: # type: ignore[attr-defined] - root_logger.removeHandler(handler) # type: ignore[attr-defined] - get_logger = logging_config.configure() - root_logger = get_logger() - if handlers["console"]["class"] == "logging.StreamHandler": - assert not isinstance(root_logger.handlers[0], listener) # type: ignore[attr-defined] - else: - assert len(root_logger.handlers) < 1 # type: ignore[attr-defined] - - -@pytest.mark.parametrize( - "handlers, listener", - [ - pytest.param( - default_handlers, - StandardQueueListenerHandler, - marks=pytest.mark.xfail( - condition=sys.version_info >= (3, 12), reason="change to QueueHandler/QueueListener config in 3.12" - ), - ), - [default_picologging_handlers, PicologgingQueueListenerHandler], - ], -) -def test_customizing_handler(handlers: Any, listener: Any, monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setitem(handlers["queue_listener"], "handlers", ["cfg://handlers.console"]) - logging_config = LoggingConfig(handlers=handlers) - get_logger = logging_config.configure() - root_logger = get_logger() - assert isinstance(root_logger.handlers[0], listener) # type: ignore + + +def test_structlog_plugin_config_with_existing_logging_config(capsys: CaptureFixture) -> None: + existing_log_config = StructLoggingConfig() + standard_logging_config = LoggingConfig() + structlog_logging_config = StructLoggingConfig(standard_lib_logging_config=standard_logging_config) + config = StructlogConfig(structlog_logging_config=structlog_logging_config) + with create_test_client([], logging_config=existing_log_config, plugins=[StructlogPlugin(config=config)]) as client: + assert client.app.plugins.get(StructlogPlugin)._config == config + assert "Found pre-configured" in capsys.readouterr().out + + +def test_structlog_config_no_tty_default(capsys: CaptureFixture) -> None: + with create_test_client([], logging_config=StructLoggingConfig()) as client: + assert client.app.logger + assert isinstance(client.app.logger.bind(), BindableLogger) + client.app.logger.info("message", key="value") + + log_messages = [decode_json(value=x) for x in capsys.readouterr().out.splitlines()] + assert len(log_messages) == 1 + + # Format should be: {event: message, key: value, level: info, timestamp: isoformat} + log_messages[0].pop("timestamp") # Assume structlog formats timestamp correctly + assert log_messages[0] == {"event": "message", "key": "value", "level": "info"} + + +def test_structlog_config_tty_default(capsys: CaptureFixture, monkeypatch: pytest.MonkeyPatch) -> None: + from sys import stderr + + monkeypatch.setattr(stderr, "isatty", lambda: True) + + with create_test_client([], logging_config=StructLoggingConfig()) as client: + assert client.app.logger + assert isinstance(client.app.logger.bind(), BindableLogger) + client.app.logger.info("message", key="value") + + log_messages = capsys.readouterr().out.splitlines() + assert len(log_messages) == 1 + + assert log_messages[0].startswith("\x1b[") + + +def test_structlog_config_specify_processors(capsys: CaptureFixture) -> None: + logging_config = StructLoggingConfig(processors=[JSONRenderer(serializer=default_json_serializer)]) + + with create_test_client([], logging_config=logging_config) as client: + assert client.app.logger + assert isinstance(client.app.logger.bind(), BindableLogger) + + client.app.logger.info("message1", key="value1") + # Log twice to make sure issue #882 doesn't appear again + client.app.logger.info("message2", key="value2") + + log_messages = [decode_json(value=x) for x in capsys.readouterr().out.splitlines()] + + assert log_messages == [ + {"key": "value1", "event": "message1"}, + {"key": "value2", "event": "message2"}, + ] From a62516f874e75c94858b5fcec9941677447c8d95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janek=20Nouvertn=C3=A9?= Date: Tue, 23 Jan 2024 20:18:50 +0100 Subject: [PATCH 09/14] feat: Implement static file serving with regular route handlers (#2960) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Implement static files with regular handlers * Add more router options * Add docs * Deprecate StaticFilesConfig --------- Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> --- docs/examples/static_files/__init__.py | 0 docs/examples/static_files/custom_router.py | 18 ++ docs/examples/static_files/file_system.py | 14 ++ docs/examples/static_files/full_example.py | 22 ++ docs/examples/static_files/html_mode.py | 29 +++ docs/examples/static_files/passing_options.py | 14 ++ docs/examples/static_files/route_reverse.py | 11 + .../static_files/send_as_attachment.py | 12 ++ .../static_files/upgrade_from_static_1.py | 8 + .../static_files/upgrade_from_static_2.py | 8 + docs/usage/static-files.rst | 161 ++++++++------ litestar/app.py | 12 +- litestar/static_files/__init__.py | 4 +- litestar/static_files/base.py | 24 ++- litestar/static_files/config.py | 161 ++++++++++++-- tests/examples/test_static_files.py | 69 ++++++ tests/unit/test_static_files/conftest.py | 31 +++ .../test_create_static_router.py | 73 +++++++ .../test_file_serving_resolution.py | 202 +++++++++++------- .../unit/test_static_files/test_html_mode.py | 39 ++-- .../test_static_files_validation.py | 51 +++-- 21 files changed, 752 insertions(+), 211 deletions(-) create mode 100644 docs/examples/static_files/__init__.py create mode 100644 docs/examples/static_files/custom_router.py create mode 100644 docs/examples/static_files/file_system.py create mode 100644 docs/examples/static_files/full_example.py create mode 100644 docs/examples/static_files/html_mode.py create mode 100644 docs/examples/static_files/passing_options.py create mode 100644 docs/examples/static_files/route_reverse.py create mode 100644 docs/examples/static_files/send_as_attachment.py create mode 100644 docs/examples/static_files/upgrade_from_static_1.py create mode 100644 docs/examples/static_files/upgrade_from_static_2.py create mode 100644 tests/examples/test_static_files.py create mode 100644 tests/unit/test_static_files/conftest.py create mode 100644 tests/unit/test_static_files/test_create_static_router.py diff --git a/docs/examples/static_files/__init__.py b/docs/examples/static_files/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs/examples/static_files/custom_router.py b/docs/examples/static_files/custom_router.py new file mode 100644 index 0000000000..cce97ac9c1 --- /dev/null +++ b/docs/examples/static_files/custom_router.py @@ -0,0 +1,18 @@ +from litestar import Litestar +from litestar.router import Router +from litestar.static_files import create_static_files_router + + +class MyRouter(Router): + pass + + +app = Litestar( + route_handlers=[ + create_static_files_router( + path="/static", + directories=["assets"], + router_class=MyRouter, + ) + ] +) diff --git a/docs/examples/static_files/file_system.py b/docs/examples/static_files/file_system.py new file mode 100644 index 0000000000..e98c1f459a --- /dev/null +++ b/docs/examples/static_files/file_system.py @@ -0,0 +1,14 @@ +from fsspec.implementations.ftp import FTPFileSystem + +from litestar import Litestar +from litestar.static_files import create_static_files_router + +app = Litestar( + route_handlers=[ + create_static_files_router( + path="/static", + directories=["assets"], + file_system=FTPFileSystem(host="127.0.0.1"), + ), + ] +) diff --git a/docs/examples/static_files/full_example.py b/docs/examples/static_files/full_example.py new file mode 100644 index 0000000000..f2fb443f6a --- /dev/null +++ b/docs/examples/static_files/full_example.py @@ -0,0 +1,22 @@ +from pathlib import Path + +from litestar import Litestar +from litestar.static_files import create_static_files_router + +ASSETS_DIR = Path("assets") + + +def on_startup(): + ASSETS_DIR.mkdir(exist_ok=True) + ASSETS_DIR.joinpath("hello.txt").write_text("Hello, world!") + + +app = Litestar( + route_handlers=[ + create_static_files_router(path="/static", directories=["assets"]), + ], + on_startup=[on_startup], +) + + +# run: /static/hello.txt diff --git a/docs/examples/static_files/html_mode.py b/docs/examples/static_files/html_mode.py new file mode 100644 index 0000000000..6b6ed95b65 --- /dev/null +++ b/docs/examples/static_files/html_mode.py @@ -0,0 +1,29 @@ +from pathlib import Path + +from litestar import Litestar +from litestar.static_files import create_static_files_router + +HTML_DIR = Path("html") + + +def on_startup() -> None: + HTML_DIR.mkdir(exist_ok=True) + HTML_DIR.joinpath("index.html").write_text("Hello, world!") + HTML_DIR.joinpath("404.html").write_text("

Not found

") + + +app = Litestar( + route_handlers=[ + create_static_files_router( + path="/", + directories=["html"], + html_mode=True, + ) + ], + on_startup=[on_startup], +) + + +# run: / +# run: /index.html +# run: /something diff --git a/docs/examples/static_files/passing_options.py b/docs/examples/static_files/passing_options.py new file mode 100644 index 0000000000..63f4449c17 --- /dev/null +++ b/docs/examples/static_files/passing_options.py @@ -0,0 +1,14 @@ +from litestar import Litestar +from litestar.static_files import create_static_files_router + +app = Litestar( + route_handlers=[ + create_static_files_router( + path="/", + directories=["assets"], + opt={"some": True}, + include_in_schema=False, + tags=["static"], + ) + ] +) diff --git a/docs/examples/static_files/route_reverse.py b/docs/examples/static_files/route_reverse.py new file mode 100644 index 0000000000..0276c51974 --- /dev/null +++ b/docs/examples/static_files/route_reverse.py @@ -0,0 +1,11 @@ +from litestar import Litestar +from litestar.static_files import create_static_files_router + +app = Litestar( + route_handlers=[ + create_static_files_router(path="/static", directories=["assets"]), + ] +) + + +print(app.route_reverse(name="static", file_path="/some_file.txt")) # /static/some_file.txt diff --git a/docs/examples/static_files/send_as_attachment.py b/docs/examples/static_files/send_as_attachment.py new file mode 100644 index 0000000000..e13ae69c05 --- /dev/null +++ b/docs/examples/static_files/send_as_attachment.py @@ -0,0 +1,12 @@ +from litestar import Litestar +from litestar.static_files import create_static_files_router + +app = Litestar( + route_handlers=[ + create_static_files_router( + path="/static", + directories=["assets"], + send_as_attachment=True, + ) + ] +) diff --git a/docs/examples/static_files/upgrade_from_static_1.py b/docs/examples/static_files/upgrade_from_static_1.py new file mode 100644 index 0000000000..ad0b8aa61d --- /dev/null +++ b/docs/examples/static_files/upgrade_from_static_1.py @@ -0,0 +1,8 @@ +from litestar import Litestar +from litestar.static_files.config import StaticFilesConfig + +app = Litestar( + static_files_config=[ + StaticFilesConfig(directories=["assets"], path="/static"), + ], +) diff --git a/docs/examples/static_files/upgrade_from_static_2.py b/docs/examples/static_files/upgrade_from_static_2.py new file mode 100644 index 0000000000..af4578f333 --- /dev/null +++ b/docs/examples/static_files/upgrade_from_static_2.py @@ -0,0 +1,8 @@ +from litestar import Litestar +from litestar.static_files import create_static_files_router + +app = Litestar( + route_handlers=[ + create_static_files_router(directories=["assets"], path="/static"), + ], +) diff --git a/docs/usage/static-files.rst b/docs/usage/static-files.rst index fb1b8f2d47..8216a37818 100644 --- a/docs/usage/static-files.rst +++ b/docs/usage/static-files.rst @@ -1,89 +1,116 @@ -Static Files +Static files ============ -Static files are served by the app from predefined locations. To configure static file serving, either pass an -instance of :class:`StaticFilesConfig <.static_files.config.StaticFilesConfig>` or a list -thereof to :class:`Litestar <.app.Litestar>` using the ``static_files_config`` kwarg. +To serve static files (i.e. serve arbitrary files from a given directory), the +:func:`~litestar.static_files.create_static_files_router` can be used to create a +:class:`Router ` to handle this task. -For example, lets say our Litestar app is going to serve **regular files** from the ``my_app/static`` folder and **html -documents** from the ``my_app/html`` folder, and we would like to serve the **static files** on the ``/files`` path, -and the **html files** on the ``/html`` path: +.. literalinclude:: /examples/static_files/full_example.py + :language: python -.. code-block:: python +In this example, files from the directory ``assets`` will be served on the path +``/static``. A file ``assets/hello.txt`` would now be available on ``/static/hello.txt`` - from litestar import Litestar - from litestar.static_files.config import StaticFilesConfig +.. attention:: + Directories are interpreted as relative to the working directory from which the + application is started - app = Litestar( - route_handlers=[...], - static_files_config=[ - StaticFilesConfig(directories=["static"], path="/files"), - StaticFilesConfig(directories=["html"], path="/html", html_mode=True), - ], - ) -Matching is done based on filename, for example, assume we have a request that is trying to retrieve the path -``/files/file.txt``\ , the **directory for the base path** ``/files`` **will be searched** for the file ``file.txt``. If it is -found, the file will be sent, otherwise a **404 response** will be sent. +Sending files as attachments +---------------------------- -If ``html_mode`` is enabled and no specific file is requested, the application will fall back to serving ``index.html``. If -no file is found the application will look for a ``404.html`` file in order to render a response, otherwise a 404 -:class:`NotFoundException <.exceptions.http_exceptions.NotFoundException>` will be returned. +By default, files are sent "inline", meaning they will have a +``Content-Disposition: inline`` header. Setting ``send_as_attachment=True`` flag will +send them with a ``Content-Disposition: attachment`` instead: -You can provide a ``name`` parameter to ``StaticFilesConfig`` to identify the given config and generate links to files in -folders belonging to that config. ``name`` should be a unique string across all static configs and -:doc:`/usage/routing/handlers`. +.. literalinclude:: /examples/static_files/send_as_attachment.py + :language: python -.. code-block:: python - from litestar import Litestar - from litestar.static_files.config import StaticFilesConfig +HTML mode +--------- - app = Litestar( - route_handlers=[...], - static_files_config=[ - StaticFilesConfig( - directories=["static"], path="/some_folder/static/path", name="static" - ), - ], - ) +"HTML mode" can be enabled by setting ``html_mode=True``. This will: - url_path = app.url_for_static_asset("static", "file.pdf") - # /some_folder/static/path/file.pdf +- Serve and ``/index.html`` when the path ``/`` is requested +- Attempt to serve ``/404.html`` when a requested file is not found + + +.. literalinclude:: /examples/static_files/html_mode.py + :language: python + + +Passing options to the generated router +--------------------------------------- + +Options available on :class:`~litestar.router.Router` can be passed to directly +:func:`~litestar.static_files.create_static_files_router`: + +.. literalinclude:: /examples/static_files/passing_options.py + :language: python + + +Using a custom router class +--------------------------- + +The router class used can be customized with the ``router_class`` parameter: + +.. literalinclude:: /examples/static_files/custom_router.py + :language: python + + + +Retrieving paths to static files +-------------------------------- + +:meth:`~litestar.app.Litestar.route_reverse` and +:meth:`~litestar.connection.ASGIConnection.url_for` can be used to retrieve the path +under which a specific file will be available: + +.. literalinclude:: /examples/static_files/route_reverse.py + :language: python + +.. tip:: + + The ``name`` parameter has to match the ``name`` parameter passed to + :func:`create_static_files_router`, which defaults to ``static``. + + +(Remote) file systems +--------------------- + +To customize how Litestar interacts with the file system, a class implementing the +:class:`~litestar.types.FileSystemProtocol` can be passed to ``file_system``. An example +of this are the file systems provided by +`fsspec `_, which includes support +for FTP, SFTP, Hadoop, SMB, GitHub and +`many more `_, +with support for popular cloud providers available via 3rd party implementations such as + +- S3 via `S3FS `_ +- Google Cloud Storage via `GCFS `_ +- Azure Blob Storage via `adlfs `_ + + +.. literalinclude:: /examples/static_files/file_system.py + :language: python -Sending files as attachments ----------------------------- -By default, files are sent "inline", meaning they will have a ``Content-Disposition: inline`` header. -To send them as attachments, use the ``send_as_attachment=True`` flag, which will add a -``Content-Disposition: attachment`` header: +Upgrading from legacy StaticFilesConfig +--------------------------------------- -.. code-block:: python +.. important:: Info + :class:`StaticFilesConfig` is deprecated and will be removed in Litestar 3.0 - from litestar import Litestar - from litestar.static_files.config import StaticFilesConfig - app = Litestar( - route_handlers=[...], - static_files_config=[ - StaticFilesConfig( - directories=["static"], - path="/some_folder/static/path", - name="static", - send_as_attachment=True, - ), - ], - ) +Existing code can be upgraded to :func:`create_static_files_router` by replacing +:class:`StaticFilesConfig` instances with this function call and passing the result to +``route_handlers`` instead of ``static_files_config``: -File System support and Cloud Files ------------------------------------ -The :class:`StaticFilesConfig <.static_files.StaticFilesConfig>` class accepts a value called ``file_system``, -which can be any class adhering to the Litestar :class:`FileSystemProtocol `. +.. literalinclude:: /examples/static_files/upgrade_from_static_1.py + :language: python -This protocol is similar to the file systems defined by `fsspec `_, -which cover all major cloud providers and a wide range of other use cases (e.g. HTTP based file service, ``ftp``, etc.). -In order to use any file system, simply use `fsspec `_ or one of -the libraries based upon it, or provide a custom implementation adhering to the -:class:`FileSystemProtocol `. +.. literalinclude:: /examples/static_files/upgrade_from_static_2.py + :language: python diff --git a/litestar/app.py b/litestar/app.py index cf1fc21ddd..1b62113f9f 100644 --- a/litestar/app.py +++ b/litestar/app.py @@ -137,6 +137,7 @@ class Litestar(Router): "_server_lifespan_managers", "_debug", "_openapi_schema", + "_static_files_config", "plugins", "after_exception", "allowed_hosts", @@ -160,7 +161,6 @@ class Litestar(Router): "route_map", "signature_namespace", "state", - "static_files_config", "stores", "template_engine", "websocket_class", @@ -410,7 +410,7 @@ def __init__( self.request_class = config.request_class or Request self.response_cache_config = config.response_cache_config self.state = config.state - self.static_files_config = config.static_files_config + self._static_files_config = config.static_files_config self.template_engine = config.template_config.engine_instance if config.template_config else None self.websocket_class = config.websocket_class or WebSocket self.debug = config.debug @@ -465,6 +465,11 @@ def __init__( config.stores if isinstance(config.stores, StoreRegistry) else StoreRegistry(config.stores) ) + @property + @deprecated(version="2.6.0", kind="property", info="Use create_static_files router instead") + def static_files_config(self) -> list[StaticFilesConfig]: + return self._static_files_config + @property @deprecated(version="2.0", alternative="Litestar.plugins.cli", kind="property") def cli_plugins(self) -> list[CLIPluginProtocol]: @@ -739,6 +744,9 @@ def get_membership_details(group_id: int, user_id: int) -> None: return join_paths(output) + @deprecated( + "2.6.0", info="Use create_static_files router instead of StaticFilesConfig, which works with route_reverse" + ) def url_for_static_asset(self, name: str, file_path: str) -> str: """Receives a static files handler name, an asset file path and returns resolved url path to the asset. diff --git a/litestar/static_files/__init__.py b/litestar/static_files/__init__.py index 22bab52f8f..3cd45945f7 100644 --- a/litestar/static_files/__init__.py +++ b/litestar/static_files/__init__.py @@ -1,4 +1,4 @@ from litestar.static_files.base import StaticFiles -from litestar.static_files.config import StaticFilesConfig +from litestar.static_files.config import StaticFilesConfig, create_static_files_router -__all__ = ("StaticFiles", "StaticFilesConfig") +__all__ = ("StaticFiles", "StaticFilesConfig", "create_static_files_router") diff --git a/litestar/static_files/base.py b/litestar/static_files/base.py index 889d7020d0..12f197f987 100644 --- a/litestar/static_files/base.py +++ b/litestar/static_files/base.py @@ -30,6 +30,7 @@ def __init__( directories: Sequence[PathType], file_system: FileSystemProtocol, send_as_attachment: bool = False, + resolve_symlinks: bool = True, ) -> None: """Initialize the Application. @@ -39,9 +40,10 @@ def __init__( file_system: The file_system spec to use for serving files. send_as_attachment: Whether to send the file with a ``content-disposition`` header of ``attachment`` or ``inline`` + resolve_symlinks: Resolve symlinks to the directories """ self.adapter = FileSystemAdapter(file_system) - self.directories = tuple(Path(p).resolve() for p in directories) + self.directories = tuple(Path(p).resolve() if resolve_symlinks else Path(p) for p in directories) self.is_html_mode = is_html_mode self.send_as_attachment = send_as_attachment @@ -82,7 +84,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] != ScopeType.HTTP or scope["method"] not in {"GET", "HEAD"}: raise MethodNotAllowedException() - split_path = scope["path"].split("/") + res = await self.handle(path=scope["path"], is_head_response=scope["method"] == "HEAD") + await res(scope=scope, receive=receive, send=send) + + async def handle(self, path: str, is_head_response: bool) -> ASGIFileResponse: + split_path = path.split("/") filename = split_path[-1] joined_path = Path(*split_path) resolved_path, fs_info = await self.get_fs_info(directories=self.directories, file_path=joined_path) @@ -98,15 +104,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ) if fs_info and fs_info["type"] == "file": - await ASGIFileResponse( + return ASGIFileResponse( file_path=resolved_path or joined_path, file_info=fs_info, file_system=self.adapter.file_system, filename=filename, content_disposition_type=content_disposition_type, - is_head_response=scope["method"] == "HEAD", - )(scope, receive, send) - return + is_head_response=is_head_response, + ) if self.is_html_mode: # for some reason coverage doesn't catch these two lines @@ -116,16 +121,15 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ) if fs_info and fs_info["type"] == "file": - await ASGIFileResponse( + return ASGIFileResponse( file_path=resolved_path or joined_path, file_info=fs_info, file_system=self.adapter.file_system, filename=filename, status_code=HTTP_404_NOT_FOUND, content_disposition_type=content_disposition_type, - is_head_response=scope["method"] == "HEAD", - )(scope, receive, send) - return + is_head_response=is_head_response, + ) raise NotFoundException( f"no file or directory match the path {resolved_path or joined_path} was found" diff --git a/litestar/static_files/config.py b/litestar/static_files/config.py index 9cca30fa79..283d095665 100644 --- a/litestar/static_files/config.py +++ b/litestar/static_files/config.py @@ -1,20 +1,34 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from pathlib import PurePath # noqa: TCH003 +from typing import TYPE_CHECKING, Any, Sequence from litestar.exceptions import ImproperlyConfiguredException from litestar.file_system import BaseLocalFileSystem -from litestar.handlers import asgi +from litestar.handlers import asgi, get, head +from litestar.response.file import ASGIFileResponse # noqa: TCH001 +from litestar.router import Router from litestar.static_files.base import StaticFiles -from litestar.utils import normalize_path +from litestar.types import Empty +from litestar.utils import normalize_path, warn_deprecation __all__ = ("StaticFilesConfig",) - if TYPE_CHECKING: + from litestar.datastructures import CacheControlHeader from litestar.handlers.asgi_handlers import ASGIRouteHandler - from litestar.types import ExceptionHandlersMap, Guard, PathType + from litestar.openapi.spec import SecurityRequirement + from litestar.types import ( + AfterRequestHookHandler, + AfterResponseHookHandler, + BeforeRequestHookHandler, + EmptyType, + ExceptionHandlersMap, + Guard, + Middleware, + PathType, + ) @dataclass @@ -58,21 +72,17 @@ class StaticFilesConfig: """Whether to send the file as an attachment.""" def __post_init__(self) -> None: - if not self.path: - raise ImproperlyConfiguredException("path must be a non-zero length string,") - - if not self.directories or not any(bool(d) for d in self.directories): - raise ImproperlyConfiguredException("directories must include at least one path.") - - if "{" in self.path: - raise ImproperlyConfiguredException("path parameters are not supported for static files") - - if not ( - callable(getattr(self.file_system, "info", None)) and callable(getattr(self.file_system, "open", None)) - ): - raise ImproperlyConfiguredException("file_system must adhere to the FileSystemProtocol type") - + _validate_config(path=self.path, directories=self.directories, file_system=self.file_system) self.path = normalize_path(self.path) + warn_deprecation( + "2.6.0", + kind="class", + deprecated_name="StaticFilesConfig", + removal_in="3.0", + alternative="create_static_files_router", + info='Replace static_files_config=[StaticFilesConfig(path="/static", directories=["assets"])] with ' + 'route_handlers=[..., create_static_files_router(path="/static", directories=["assets"])]', + ) def to_static_files_app(self) -> ASGIRouteHandler: """Return an ASGI app serving static files based on the config. @@ -94,3 +104,116 @@ def to_static_files_app(self) -> ASGIRouteHandler: guards=self.guards, exception_handlers=self.exception_handlers, )(static_files) + + +def create_static_files_router( + path: str, + directories: list[PathType], + file_system: Any = None, + send_as_attachment: bool = False, + html_mode: bool = False, + name: str = "static", + after_request: AfterRequestHookHandler | None = None, + after_response: AfterResponseHookHandler | None = None, + before_request: BeforeRequestHookHandler | None = None, + cache_control: CacheControlHeader | None = None, + exception_handlers: ExceptionHandlersMap | None = None, + guards: list[Guard] | None = None, + include_in_schema: bool | EmptyType = Empty, + middleware: Sequence[Middleware] | None = None, + opt: dict[str, Any] | None = None, + security: Sequence[SecurityRequirement] | None = None, + tags: Sequence[str] | None = None, + router_class: type[Router] = Router, + resolve_symlinks: bool = True, +) -> Router: + """Create a router with handlers to serve static files. + + Args: + path: Path to serve static files under + directories: Directories to serve static files from + file_system: A *file system* implementing + :class:`~litestar.types.FileSystemProtocol`. + `fsspec `_ can be passed + here as well + send_as_attachment: Whether to send the file as an attachment + html_mode: When in HTML: + - Serve an ``index.html`` file from ``/`` + - Serve ``404.html`` when a file could not be found + name: Name to pass to the generated handlers + after_request: ``after_request`` handlers passed to the router + after_response: ``after_response`` handlers passed to the router + before_request: ``before_request`` handlers passed to the router + cache_control: ``cache_control`` passed to the router + exception_handlers: Exception handlers passed to the router + guards: Guards passed to the router + include_in_schema: Include the routes / router in the OpenAPI schema + middleware: Middlewares passed to the router + opt: Opts passed to the router + security: Security options passed to the router + tags: ``tags`` passed to the router + router_class: The class used to construct a router from + resolve_symlinks: Resolve symlinks of ``directories`` + """ + + if file_system is None: + file_system = BaseLocalFileSystem() + + _validate_config(path=path, directories=directories, file_system=file_system) + path = normalize_path(path) + + static_files = StaticFiles( + is_html_mode=html_mode, + directories=directories, + file_system=file_system, + send_as_attachment=send_as_attachment, + resolve_symlinks=resolve_symlinks, + ) + + @get("{file_path:path}", name=name) + async def get_handler(file_path: PurePath) -> ASGIFileResponse: + return await static_files.handle(path=str(file_path), is_head_response=False) + + @head("/{file_path:path}", name=f"{name}/head") + async def head_handler(file_path: PurePath) -> ASGIFileResponse: + return await static_files.handle(path=str(file_path), is_head_response=True) + + handlers = [get_handler, head_handler] + + if html_mode: + + @get("/", name=f"{name}/index") + async def index_handler() -> ASGIFileResponse: + return await static_files.handle(path="/", is_head_response=False) + + handlers.append(index_handler) + + return router_class( + after_request=after_request, + after_response=after_response, + before_request=before_request, + cache_control=cache_control, + exception_handlers=exception_handlers, + guards=guards, + include_in_schema=include_in_schema, + middleware=middleware, + opt=opt, + path=path, + route_handlers=handlers, + security=security, + tags=tags, + ) + + +def _validate_config(path: str, directories: list[PathType], file_system: Any) -> None: + if not path: + raise ImproperlyConfiguredException("path must be a non-zero length string,") + + if not directories or not any(bool(d) for d in directories): + raise ImproperlyConfiguredException("directories must include at least one path.") + + if "{" in path: + raise ImproperlyConfiguredException("path parameters are not supported for static files") + + if not (callable(getattr(file_system, "info", None)) and callable(getattr(file_system, "open", None))): + raise ImproperlyConfiguredException("file_system must adhere to the FileSystemProtocol type") diff --git a/tests/examples/test_static_files.py b/tests/examples/test_static_files.py new file mode 100644 index 0000000000..7d48603992 --- /dev/null +++ b/tests/examples/test_static_files.py @@ -0,0 +1,69 @@ +import secrets +from pathlib import Path + +import pytest +from _pytest.monkeypatch import MonkeyPatch + +from litestar.testing import TestClient + + +@pytest.fixture(autouse=True) +def _chdir(tmp_path: Path, monkeypatch: MonkeyPatch) -> None: + monkeypatch.chdir(tmp_path) + + +@pytest.fixture() +def assets_file(tmp_path: Path) -> str: + content = secrets.token_hex() + assets_path = tmp_path / "assets" + assets_path.mkdir() + assets_path.joinpath("test.txt").write_text(content) + return content + + +def test_custom_router() -> None: + from docs.examples.static_files import custom_router # noqa: F401 + + +def test_full_example() -> None: + from docs.examples.static_files import full_example + + with TestClient(full_example.app) as client: + assert client.get("/static/hello.txt").text == "Hello, world!" + + +def test_html_mode() -> None: + from docs.examples.static_files import html_mode + + with TestClient(html_mode.app) as client: + assert client.get("/").text == "Hello, world!" + assert client.get("/index.html").text == "Hello, world!" + assert client.get("/something").text == "

Not found

" + + +def test_passing_options() -> None: + from docs.examples.static_files import passing_options # noqa: F401 + + +def test_route_reverse(capsys) -> None: + from docs.examples.static_files import route_reverse # noqa: F401 + + assert capsys.readouterr().out.strip() == "/static/some_file.txt" + + +def test_send_as_attachment(tmp_path: Path, assets_file: str) -> None: + from docs.examples.static_files import send_as_attachment + + with TestClient(send_as_attachment.app) as client: + res = client.get("/static/test.txt") + assert res.text == assets_file + assert res.headers["content-disposition"].startswith("attachment") + + +def test_upgrade_from_static(tmp_path: Path, assets_file: str) -> None: + from docs.examples.static_files import upgrade_from_static_1, upgrade_from_static_2 + + for app in [upgrade_from_static_1.app, upgrade_from_static_2.app]: + with TestClient(app) as client: + res = client.get("/static/test.txt") + assert res.text == assets_file diff --git a/tests/unit/test_static_files/conftest.py b/tests/unit/test_static_files/conftest.py new file mode 100644 index 0000000000..01d70d2bb1 --- /dev/null +++ b/tests/unit/test_static_files/conftest.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from dataclasses import asdict +from typing import Callable + +import pytest +from _pytest.fixtures import FixtureRequest +from fsspec.implementations.local import LocalFileSystem +from typing_extensions import TypeAlias + +from litestar import Router +from litestar.file_system import BaseLocalFileSystem +from litestar.static_files import StaticFilesConfig, create_static_files_router +from litestar.types import FileSystemProtocol + +MakeConfig: TypeAlias = "Callable[[StaticFilesConfig], tuple[list[StaticFilesConfig], list[Router]]]" + + +@pytest.fixture(params=["config", "handlers"]) +def make_config(request: FixtureRequest) -> MakeConfig: + def make(config: StaticFilesConfig) -> tuple[list[StaticFilesConfig], list[Router]]: + if request.param == "config": + return [config], [] + return [], [create_static_files_router(**asdict(config))] + + return make + + +@pytest.fixture(params=[BaseLocalFileSystem(), LocalFileSystem()]) +def file_system(request: FixtureRequest) -> FileSystemProtocol: + return request.param # type: ignore[no-any-return] diff --git a/tests/unit/test_static_files/test_create_static_router.py b/tests/unit/test_static_files/test_create_static_router.py new file mode 100644 index 0000000000..08cc71653a --- /dev/null +++ b/tests/unit/test_static_files/test_create_static_router.py @@ -0,0 +1,73 @@ +from typing import Any + +from litestar import Litestar, Request, Response, Router +from litestar.connection import ASGIConnection +from litestar.datastructures import CacheControlHeader +from litestar.exceptions import ValidationException +from litestar.handlers import BaseRouteHandler +from litestar.static_files import create_static_files_router + + +def test_route_reverse() -> None: + app = Litestar( + route_handlers=[create_static_files_router(path="/static", directories=["something"], name="static")] + ) + + assert app.route_reverse("static", file_path="foo.py") == "/static/foo.py" + + +def test_pass_options() -> None: + def guard(connection: ASGIConnection, handler: BaseRouteHandler) -> None: + pass + + def handle(request: Request, exception: Any) -> Response: + return Response(b"") + + async def after_request(response: Response) -> Response: + return Response(b"") + + async def after_response(request: Request) -> None: + pass + + async def before_request(request: Request) -> Any: + pass + + exception_handlers = {ValidationException: handle} + opts = {"foo": "bar"} + cache_control = CacheControlHeader() + security = [{"foo": ["bar"]}] + tags = ["static", "random"] + + router = create_static_files_router( + path="/", + directories=["something"], + guards=[guard], + exception_handlers=exception_handlers, # type: ignore[arg-type] + opt=opts, + after_request=after_request, + after_response=after_response, + before_request=before_request, + cache_control=cache_control, + include_in_schema=False, + security=security, + tags=tags, + ) + + assert router.guards == [guard] + assert router.exception_handlers == exception_handlers + assert router.opt == opts + assert router.after_request is after_request + assert router.after_response is after_response + assert router.before_request is before_request + assert router.cache_control is cache_control + assert router.include_in_schema is False + assert router.security == security + assert router.tags == tags + + +def test_custom_router_class() -> None: + class MyRouter(Router): + pass + + router = create_static_files_router("/", directories=["some"], router_class=MyRouter) + assert isinstance(router, MyRouter) diff --git a/tests/unit/test_static_files/test_file_serving_resolution.py b/tests/unit/test_static_files/test_file_serving_resolution.py index 564799a758..3bc9c70d6a 100644 --- a/tests/unit/test_static_files/test_file_serving_resolution.py +++ b/tests/unit/test_static_files/test_file_serving_resolution.py @@ -1,125 +1,145 @@ +from __future__ import annotations + import gzip import mimetypes from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable import brotli import pytest -from fsspec.implementations.local import LocalFileSystem +from typing_extensions import TypeAlias -from litestar import MediaType, get -from litestar.file_system import BaseLocalFileSystem -from litestar.static_files.config import StaticFilesConfig +from litestar import MediaType, Router, get +from litestar.static_files import StaticFilesConfig, create_static_files_router from litestar.status_codes import HTTP_200_OK from litestar.testing import create_test_client +from tests.unit.test_static_files.conftest import MakeConfig if TYPE_CHECKING: from litestar.types import FileSystemProtocol -def test_default_static_files_config(tmpdir: "Path") -> None: +def test_default_static_files_config(tmpdir: Path, make_config: MakeConfig) -> None: path = tmpdir / "test.txt" path.write_text("content", "utf-8") - static_files_config = StaticFilesConfig(path="/static", directories=[tmpdir]) + static_files_config, router = make_config(StaticFilesConfig(path="/static", directories=[tmpdir])) - with create_test_client([], static_files_config=[static_files_config]) as client: + with create_test_client(router, static_files_config=static_files_config) as client: response = client.get("/static/test.txt") assert response.status_code == HTTP_200_OK, response.text assert response.text == "content" -def test_multiple_static_files_configs(tmpdir: "Path") -> None: - root1 = tmpdir.mkdir("1") # type: ignore - root2 = tmpdir.mkdir("2") # type: ignore - path1 = root1 / "test.txt" # pyright: ignore - path1.write_text("content1", "utf-8") - path2 = root2 / "test.txt" # pyright: ignore - path2.write_text("content2", "utf-8") +@pytest.fixture() +def setup_dirs(tmpdir: Path) -> tuple[Path, Path]: + paths = [] + for i in range(1, 3): + root = tmpdir / str(i) + root.mkdir() + file_path = root / f"test_{i}.txt" + file_path.write_text(f"content{i}", "utf-8") + paths.append(root) + + return paths[0], paths[1] + + +MakeConfigs: TypeAlias = ( + "Callable[[StaticFilesConfig, StaticFilesConfig], tuple[list[StaticFilesConfig], list[Router]]]" +) + - static_files_config = [ +@pytest.fixture() +def make_configs(make_config: MakeConfig) -> MakeConfigs: + def make( + first_config: StaticFilesConfig, second_config: StaticFilesConfig + ) -> tuple[list[StaticFilesConfig], list[Router]]: + configs_1, routers_1 = make_config(first_config) + configs_2, routers_2 = make_config(second_config) + return [*configs_1, *configs_2], [*routers_1, *routers_2] + + return make + + +def test_multiple_static_files_configs(setup_dirs: tuple[Path, Path], make_configs: MakeConfigs) -> None: + root1, root2 = setup_dirs + + configs, handlers = make_configs( StaticFilesConfig(path="/static_first", directories=[root1]), # pyright: ignore StaticFilesConfig(path="/static_second", directories=[root2]), # pyright: ignore - ] - with create_test_client([], static_files_config=static_files_config) as client: - response = client.get("/static_first/test.txt") + ) + with create_test_client(handlers, static_files_config=configs) as client: + response = client.get("/static_first/test_1.txt") assert response.status_code == HTTP_200_OK assert response.text == "content1" - response = client.get("/static_second/test.txt") + response = client.get("/static_second/test_2.txt") assert response.status_code == HTTP_200_OK assert response.text == "content2" -@pytest.mark.parametrize("file_system", (BaseLocalFileSystem(), LocalFileSystem())) -def test_static_files_configs_with_mixed_file_systems(tmpdir: "Path", file_system: "FileSystemProtocol") -> None: - root1 = tmpdir.mkdir("1") # type: ignore - root2 = tmpdir.mkdir("2") # type: ignore - path1 = root1 / "test.txt" # pyright: ignore - path1.write_text("content1", "utf-8") - path2 = root2 / "test.txt" # pyright: ignore - path2.write_text("content2", "utf-8") +def test_static_files_configs_with_mixed_file_systems( + file_system: FileSystemProtocol, setup_dirs: tuple[Path, Path], make_configs: MakeConfigs +) -> None: + root1, root2 = setup_dirs - static_files_config = [ + configs, handlers = make_configs( StaticFilesConfig(path="/static_first", directories=[root1], file_system=file_system), # pyright: ignore StaticFilesConfig(path="/static_second", directories=[root2]), # pyright: ignore - ] - with create_test_client([], static_files_config=static_files_config) as client: - response = client.get("/static_first/test.txt") + ) + + with create_test_client(handlers, static_files_config=configs) as client: + response = client.get("/static_first/test_1.txt") assert response.status_code == HTTP_200_OK assert response.text == "content1" - response = client.get("/static_second/test.txt") + response = client.get("/static_second/test_2.txt") assert response.status_code == HTTP_200_OK assert response.text == "content2" -@pytest.mark.parametrize("file_system", (BaseLocalFileSystem(), LocalFileSystem())) -def test_static_files_config_with_multiple_directories(tmpdir: "Path", file_system: "FileSystemProtocol") -> None: - root1 = tmpdir.mkdir("first") # type: ignore - root2 = tmpdir.mkdir("second") # type: ignore - path1 = root1 / "test1.txt" # pyright: ignore - path1.write_text("content1", "utf-8") - path2 = root2 / "test2.txt" # pyright: ignore - path2.write_text("content2", "utf-8") +def test_static_files_config_with_multiple_directories( + file_system: FileSystemProtocol, setup_dirs: tuple[Path, Path], make_config: MakeConfig +) -> None: + root1, root2 = setup_dirs + configs, handlers = make_config( + StaticFilesConfig(path="/static", directories=[root1, root2], file_system=file_system) + ) - with create_test_client( - [], - static_files_config=[ - StaticFilesConfig(path="/static", directories=[root1, root2], file_system=file_system) # pyright: ignore - ], - ) as client: - response = client.get("/static/test1.txt") + with create_test_client(handlers, static_files_config=configs) as client: + response = client.get("/static/test_1.txt") assert response.status_code == HTTP_200_OK assert response.text == "content1" - response = client.get("/static/test2.txt") + response = client.get("/static/test_2.txt") assert response.status_code == HTTP_200_OK assert response.text == "content2" -def test_staticfiles_for_slash_path_regular_mode(tmpdir: "Path") -> None: +def test_staticfiles_for_slash_path_regular_mode(tmpdir: Path, make_config: MakeConfig) -> None: path = tmpdir / "text.txt" path.write_text("content", "utf-8") - static_files_config = StaticFilesConfig(path="/", directories=[tmpdir]) - with create_test_client([], static_files_config=[static_files_config]) as client: + configs, handlers = make_config(StaticFilesConfig(path="/", directories=[tmpdir])) + + with create_test_client(handlers, static_files_config=configs) as client: response = client.get("/text.txt") assert response.status_code == HTTP_200_OK assert response.text == "content" -def test_staticfiles_for_slash_path_html_mode(tmpdir: "Path") -> None: +def test_staticfiles_for_slash_path_html_mode(tmpdir: Path, make_config: MakeConfig) -> None: path = tmpdir / "index.html" path.write_text("", "utf-8") - static_files_config = StaticFilesConfig(path="/", directories=[tmpdir], html_mode=True) - with create_test_client([], static_files_config=[static_files_config]) as client: + configs, handlers = make_config(StaticFilesConfig(path="/", directories=[tmpdir], html_mode=True)) + + with create_test_client(handlers, static_files_config=configs) as client: response = client.get("/") assert response.status_code == HTTP_200_OK assert response.text == "" -def test_sub_path_under_static_path(tmpdir: "Path") -> None: +def test_sub_path_under_static_path(tmpdir: Path, make_config: MakeConfig) -> None: path = tmpdir / "test.txt" path.write_text("content", "utf-8") @@ -127,9 +147,10 @@ def test_sub_path_under_static_path(tmpdir: "Path") -> None: def handler(f: str) -> str: return f - with create_test_client( - handler, static_files_config=[StaticFilesConfig(path="/static", directories=[tmpdir])] - ) as client: + configs, handlers = make_config(StaticFilesConfig(path="/static", directories=[tmpdir])) + handlers.append(handler) # type: ignore[arg-type] + + with create_test_client(handlers, static_files_config=configs) as client: response = client.get("/static/test.txt") assert response.status_code == HTTP_200_OK @@ -137,26 +158,26 @@ def handler(f: str) -> str: assert response.status_code == HTTP_200_OK -def test_static_substring_of_self(tmpdir: "Path") -> None: +def test_static_substring_of_self(tmpdir: Path, make_config: MakeConfig) -> None: path = tmpdir.mkdir("static_part").mkdir("static") / "test.txt" # type: ignore path.write_text("content", "utf-8") - static_files_config = StaticFilesConfig(path="/static", directories=[tmpdir]) - with create_test_client([], static_files_config=[static_files_config]) as client: + configs, handlers = make_config(StaticFilesConfig(path="/static", directories=[tmpdir])) + with create_test_client(handlers, static_files_config=configs) as client: response = client.get("/static/static_part/static/test.txt") assert response.status_code == HTTP_200_OK assert response.text == "content" @pytest.mark.parametrize("extension", ["css", "js", "html", "json"]) -def test_static_files_response_mimetype(tmpdir: "Path", extension: str) -> None: +def test_static_files_response_mimetype(tmpdir: Path, extension: str, make_config: MakeConfig) -> None: fn = f"test.{extension}" path = tmpdir / fn path.write_text("content", "utf-8") - static_files_config = StaticFilesConfig(path="/static", directories=[tmpdir]) + configs, handlers = make_config(StaticFilesConfig(path="/static", directories=[tmpdir])) expected_mime_type = mimetypes.guess_type(fn)[0] - with create_test_client([], static_files_config=[static_files_config]) as client: + with create_test_client(handlers, static_files_config=configs) as client: response = client.get(f"/static/{fn}") assert expected_mime_type assert response.status_code == HTTP_200_OK @@ -164,7 +185,7 @@ def test_static_files_response_mimetype(tmpdir: "Path", extension: str) -> None: @pytest.mark.parametrize("extension", ["gz", "br"]) -def test_static_files_response_encoding(tmp_path: "Path", extension: str) -> None: +def test_static_files_response_encoding(tmp_path: Path, extension: str, make_config: MakeConfig) -> None: fn = f"test.js.{extension}" path = tmp_path / fn compressed_data = None @@ -173,10 +194,11 @@ def test_static_files_response_encoding(tmp_path: "Path", extension: str) -> Non elif extension == "gz": compressed_data = gzip.compress(b"content") path.write_bytes(compressed_data) # type: ignore[arg-type] - static_files_config = StaticFilesConfig(path="/static", directories=[tmp_path]) expected_encoding_type = mimetypes.guess_type(fn)[1] - with create_test_client([], static_files_config=[static_files_config]) as client: + configs, handlers = make_config(StaticFilesConfig(path="/static", directories=[tmp_path])) + + with create_test_client(handlers, static_files_config=configs) as client: response = client.get(f"/static/{fn}") assert expected_encoding_type assert response.status_code == HTTP_200_OK @@ -184,45 +206,51 @@ def test_static_files_response_encoding(tmp_path: "Path", extension: str) -> Non @pytest.mark.parametrize("send_as_attachment,disposition", [(True, "attachment"), (False, "inline")]) -def test_static_files_content_disposition(tmpdir: "Path", send_as_attachment: bool, disposition: str) -> None: +def test_static_files_content_disposition( + tmpdir: Path, send_as_attachment: bool, disposition: str, make_config: MakeConfig +) -> None: path = tmpdir.mkdir("static_part").mkdir("static") / "test.txt" # type: ignore path.write_text("content", "utf-8") - static_files_config = StaticFilesConfig(path="/static", directories=[tmpdir], send_as_attachment=send_as_attachment) + configs, handlers = make_config( + StaticFilesConfig(path="/static", directories=[tmpdir], send_as_attachment=send_as_attachment) + ) - with create_test_client([], static_files_config=[static_files_config]) as client: + with create_test_client(handlers, static_files_config=configs) as client: response = client.get("/static/static_part/static/test.txt") assert response.status_code == HTTP_200_OK assert response.headers["content-disposition"].startswith(disposition) -def test_service_from_relative_path_using_string(tmpdir: "Path") -> None: +def test_service_from_relative_path_using_string(tmpdir: Path, make_config: MakeConfig) -> None: sub_dir = Path(tmpdir.mkdir("low")).resolve() # type: ignore path = tmpdir / "test.txt" path.write_text("content", "utf-8") - static_files_config = StaticFilesConfig(path="/static", directories=[f"{sub_dir}/.."]) - with create_test_client([], static_files_config=[static_files_config]) as client: + configs, handlers = make_config(StaticFilesConfig(path="/static", directories=[f"{sub_dir}/.."])) + + with create_test_client(handlers, static_files_config=configs) as client: response = client.get("/static/test.txt") assert response.status_code == HTTP_200_OK assert response.text == "content" -def test_service_from_relative_path_using_path(tmpdir: "Path") -> None: +def test_service_from_relative_path_using_path(tmpdir: Path, make_config: MakeConfig) -> None: sub_dir = Path(tmpdir.mkdir("low")).resolve() # type: ignore path = tmpdir / "test.txt" path.write_text("content", "utf-8") - static_files_config = StaticFilesConfig(path="/static", directories=[Path(f"{sub_dir}/..")]) - with create_test_client([], static_files_config=[static_files_config]) as client: + configs, handlers = make_config(StaticFilesConfig(path="/static", directories=[Path(f"{sub_dir}/..")])) + + with create_test_client(handlers, static_files_config=configs) as client: response = client.get("/static/test.txt") assert response.status_code == HTTP_200_OK assert response.text == "content" -def test_service_from_base_path_using_string(tmpdir: "Path") -> None: +def test_service_from_base_path_using_string(tmpdir: Path) -> None: sub_dir = Path(tmpdir.mkdir("low")).resolve() # type: ignore path = tmpdir / "test.txt" @@ -249,3 +277,21 @@ def sub_handler() -> dict: response = client.get("/sub") assert response.status_code == HTTP_200_OK assert response.json() == {"hello": "world"} + + +@pytest.mark.parametrize("resolve", [True, False]) +def test_resolve_symlinks(tmp_path: Path, resolve: bool) -> None: + source_dir = tmp_path / "foo" + source_dir.mkdir() + linked_dir = tmp_path / "bar" + linked_dir.symlink_to(source_dir, target_is_directory=True) + source_dir.joinpath("test.txt").write_text("hello") + + router = create_static_files_router(path="/", directories=[linked_dir], resolve_symlinks=resolve) + + with create_test_client(router) as client: + if not resolve: + linked_dir.unlink() + assert client.get("/test.txt").status_code == 404 + else: + assert client.get("/test.txt").status_code == 200 diff --git a/tests/unit/test_static_files/test_html_mode.py b/tests/unit/test_static_files/test_html_mode.py index fa58044da8..c1f792d55a 100644 --- a/tests/unit/test_static_files/test_html_mode.py +++ b/tests/unit/test_static_files/test_html_mode.py @@ -1,12 +1,11 @@ -from typing import TYPE_CHECKING +from __future__ import annotations -import pytest -from fsspec.implementations.local import LocalFileSystem +from typing import TYPE_CHECKING -from litestar.file_system import BaseLocalFileSystem -from litestar.static_files.config import StaticFilesConfig +from litestar.static_files import StaticFilesConfig from litestar.status_codes import HTTP_200_OK, HTTP_404_NOT_FOUND from litestar.testing import create_test_client +from tests.unit.test_static_files.conftest import MakeConfig if TYPE_CHECKING: from pathlib import Path @@ -14,14 +13,14 @@ from litestar.types import FileSystemProtocol -@pytest.mark.parametrize("file_system", (BaseLocalFileSystem(), LocalFileSystem())) -def test_staticfiles_is_html_mode(tmpdir: "Path", file_system: "FileSystemProtocol") -> None: +def test_staticfiles_is_html_mode(tmpdir: Path, file_system: FileSystemProtocol, make_config: MakeConfig) -> None: path = tmpdir / "index.html" path.write_text("content", "utf-8") - static_files_config = StaticFilesConfig( - path="/static", directories=[tmpdir], html_mode=True, file_system=file_system + static_files_config, handlers = make_config( + StaticFilesConfig(path="/static", directories=[tmpdir], html_mode=True, file_system=file_system) ) - with create_test_client([], static_files_config=[static_files_config]) as client: + + with create_test_client(handlers, static_files_config=static_files_config) as client: response = client.get("/static") assert response.status_code == HTTP_200_OK assert response.text == "content" @@ -29,28 +28,28 @@ def test_staticfiles_is_html_mode(tmpdir: "Path", file_system: "FileSystemProtoc assert response.headers["content-disposition"].startswith("inline") -@pytest.mark.parametrize("file_system", (BaseLocalFileSystem(), LocalFileSystem())) -def test_staticfiles_is_html_mode_serves_404_when_present(tmpdir: "Path", file_system: "FileSystemProtocol") -> None: +def test_staticfiles_is_html_mode_serves_404_when_present( + tmpdir: Path, file_system: FileSystemProtocol, make_config: MakeConfig +) -> None: path = tmpdir / "404.html" path.write_text("not found", "utf-8") - static_files_config = StaticFilesConfig( - path="/static", directories=[tmpdir], html_mode=True, file_system=file_system + static_files_config, handlers = make_config( + StaticFilesConfig(path="/static", directories=[tmpdir], html_mode=True, file_system=file_system) ) - with create_test_client([], static_files_config=[static_files_config]) as client: + with create_test_client(handlers, static_files_config=static_files_config) as client: response = client.get("/static") assert response.status_code == HTTP_404_NOT_FOUND assert response.text == "not found" assert response.headers["content-type"] == "text/html; charset=utf-8" -@pytest.mark.parametrize("file_system", (BaseLocalFileSystem(), LocalFileSystem())) def test_staticfiles_is_html_mode_raises_exception_when_no_404_html_is_present( - tmpdir: "Path", file_system: "FileSystemProtocol" + tmpdir: Path, file_system: FileSystemProtocol, make_config: MakeConfig ) -> None: - static_files_config = StaticFilesConfig( - path="/static", directories=[tmpdir], html_mode=True, file_system=file_system + static_files_config, handlers = make_config( + StaticFilesConfig(path="/static", directories=[tmpdir], html_mode=True, file_system=file_system) ) - with create_test_client([], static_files_config=[static_files_config]) as client: + with create_test_client(handlers, static_files_config=static_files_config) as client: response = client.get("/static") assert response.status_code == HTTP_404_NOT_FOUND assert response.json() == {"status_code": 404, "detail": "no file or directory match the path . was found"} diff --git a/tests/unit/test_static_files/test_static_files_validation.py b/tests/unit/test_static_files/test_static_files_validation.py index 075b93d47d..56eeda3c19 100644 --- a/tests/unit/test_static_files/test_static_files_validation.py +++ b/tests/unit/test_static_files/test_static_files_validation.py @@ -1,47 +1,51 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, List import pytest from litestar import HttpMethod, Litestar, MediaType, get from litestar.exceptions import ImproperlyConfiguredException -from litestar.static_files.config import StaticFilesConfig -from litestar.status_codes import HTTP_200_OK, HTTP_405_METHOD_NOT_ALLOWED +from litestar.static_files import StaticFilesConfig, create_static_files_router +from litestar.status_codes import HTTP_200_OK, HTTP_204_NO_CONTENT, HTTP_405_METHOD_NOT_ALLOWED from litestar.testing import create_test_client if TYPE_CHECKING: from pathlib import Path -def test_config_validation_of_directories() -> None: +@pytest.mark.parametrize("directories", [[], [""]]) +@pytest.mark.parametrize("func", [StaticFilesConfig, create_static_files_router]) +def test_config_validation_of_directories(func: Any, directories: List[str]) -> None: with pytest.raises(ImproperlyConfiguredException): - StaticFilesConfig(path="/static", directories=[]) + func(path="/static", directories=directories) -def test_config_validation_of_path(tmpdir: "Path") -> None: +@pytest.mark.parametrize("func", [StaticFilesConfig, create_static_files_router]) +def test_config_validation_of_path(tmpdir: "Path", func: Any) -> None: path = tmpdir / "text.txt" path.write_text("content", "utf-8") with pytest.raises(ImproperlyConfiguredException): - StaticFilesConfig(path="", directories=[tmpdir]) + func(path="", directories=[tmpdir]) with pytest.raises(ImproperlyConfiguredException): - StaticFilesConfig(path="/{param:int}", directories=[tmpdir]) + func(path="/{param:int}", directories=[tmpdir]) -def test_config_validation_of_file_system(tmpdir: "Path") -> None: +@pytest.mark.parametrize("func", [StaticFilesConfig, create_static_files_router]) +def test_config_validation_of_file_system(tmpdir: "Path", func: Any) -> None: class FSWithoutOpen: def info(self) -> None: return with pytest.raises(ImproperlyConfiguredException): - StaticFilesConfig(path="/static", directories=[tmpdir], file_system=FSWithoutOpen()) + func(path="/static", directories=[tmpdir], file_system=FSWithoutOpen()) class FSWithoutInfo: def open(self) -> None: return with pytest.raises(ImproperlyConfiguredException): - StaticFilesConfig(path="/static", directories=[tmpdir], file_system=FSWithoutInfo()) + func(path="/static", directories=[tmpdir], file_system=FSWithoutInfo()) class ImplementedFS: def info(self) -> None: @@ -50,7 +54,7 @@ def info(self) -> None: def open(self) -> None: return - assert StaticFilesConfig(path="/static", directories=[tmpdir], file_system=ImplementedFS()) + assert func(path="/static", directories=[tmpdir], file_system=ImplementedFS()) def test_runtime_validation_of_static_path_and_path_parameter(tmpdir: "Path") -> None: @@ -79,7 +83,7 @@ def handler(f: str) -> str: (HttpMethod.OPTIONS, HTTP_405_METHOD_NOT_ALLOWED), ), ) -def test_runtime_validation_of_request_method(tmpdir: "Path", method: HttpMethod, expected: int) -> None: +def test_runtime_validation_of_request_method_legacy_config(tmpdir: "Path", method: HttpMethod, expected: int) -> None: path = tmpdir / "test.txt" path.write_text("content", "utf-8") @@ -88,3 +92,24 @@ def test_runtime_validation_of_request_method(tmpdir: "Path", method: HttpMethod ) as client: response = client.request(method, "/static/test.txt") assert response.status_code == expected + + +@pytest.mark.parametrize( + "method, expected", + ( + (HttpMethod.GET, HTTP_200_OK), + (HttpMethod.HEAD, HTTP_200_OK), + (HttpMethod.OPTIONS, HTTP_204_NO_CONTENT), + (HttpMethod.PUT, HTTP_405_METHOD_NOT_ALLOWED), + (HttpMethod.PATCH, HTTP_405_METHOD_NOT_ALLOWED), + (HttpMethod.POST, HTTP_405_METHOD_NOT_ALLOWED), + (HttpMethod.DELETE, HTTP_405_METHOD_NOT_ALLOWED), + ), +) +def test_runtime_validation_of_request_method_create_handler(tmpdir: "Path", method: HttpMethod, expected: int) -> None: + path = tmpdir / "test.txt" + path.write_text("content", "utf-8") + + with create_test_client(create_static_files_router(path="/static", directories=[tmpdir])) as client: + response = client.request(method, "/static/test.txt") + assert response.status_code == expected From 842aca024bad2fb11e646cd7c31b16dede8debe9 Mon Sep 17 00:00:00 2001 From: Patrick Neise Date: Fri, 26 Jan 2024 10:51:10 -0500 Subject: [PATCH 10/14] feat: sanitize Piccolo columns with secret=True from PydanticDTO output (#3030) * feature(piccolo): sanitize columns with from PydanticDTO output * feature(piccolo): modify logic to assign to secret columns --------- Co-authored-by: Patrick Neise --- litestar/contrib/piccolo.py | 3 ++- .../test_piccolo_orm/test_piccolo_orm_dto.py | 15 +++++++++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/litestar/contrib/piccolo.py b/litestar/contrib/piccolo.py index 314eb328e2..50648e7537 100644 --- a/litestar/contrib/piccolo.py +++ b/litestar/contrib/piccolo.py @@ -79,10 +79,11 @@ class PiccoloDTO(AbstractDTO[T], Generic[T]): @classmethod def generate_field_definitions(cls, model_type: type[Table]) -> Generator[DTOFieldDefinition, None, None]: for column in model_type._meta.columns: + mark = Mark.WRITE_ONLY if column._meta.secret else Mark.READ_ONLY if column._meta.primary_key else None yield replace( DTOFieldDefinition.from_field_definition( field_definition=_parse_piccolo_type(column, _create_column_extra(column)), - dto_field=DTOField(mark=Mark.READ_ONLY if column._meta.primary_key else None), + dto_field=DTOField(mark=mark), model_name=model_type.__name__, default_factory=None, ), diff --git a/tests/unit/test_contrib/test_piccolo_orm/test_piccolo_orm_dto.py b/tests/unit/test_contrib/test_piccolo_orm/test_piccolo_orm_dto.py index 2a8c897642..6df59aaedb 100644 --- a/tests/unit/test_contrib/test_piccolo_orm/test_piccolo_orm_dto.py +++ b/tests/unit/test_contrib/test_piccolo_orm/test_piccolo_orm_dto.py @@ -58,8 +58,20 @@ def test_serializing_single_piccolo_table(scaffold_piccolo: Callable) -> None: def test_serializing_multiple_piccolo_tables(scaffold_piccolo: Callable) -> None: with create_test_client(route_handlers=[retrieve_venues]) as client: response = client.get("/venues") + + sanitized_venues = [] + for v in venues: + non_secret_data = { + column._meta.db_column_name: v[column._meta.db_column_name] + for column in v.all_columns() + if not column._meta.secret + } + sanitized_venues.append(Venue(**non_secret_data)) + assert response.status_code == HTTP_200_OK - assert [str(Venue(**value).querystring) for value in response.json()] == [str(v.querystring) for v in venues] + assert [str(Venue(**value).querystring) for value in response.json()] == [ + str(v.querystring) for v in sanitized_venues + ] @pytest.mark.parametrize( @@ -154,7 +166,6 @@ def test_piccolo_dto_openapi_spec_generation() -> None: assert venue_schema assert venue_schema.to_schema() == { "properties": { - "capacity": {"oneOf": [{"type": "null"}, {"type": "integer"}]}, "id": {"oneOf": [{"type": "null"}, {"type": "integer"}]}, "name": {"oneOf": [{"type": "null"}, {"type": "string"}]}, }, From f81990b1a3e2201ad4228b1cea1136d4da8228c7 Mon Sep 17 00:00:00 2001 From: Cody Fincher <204685+cofin@users.noreply.github.com> Date: Fri, 26 Jan 2024 15:05:09 -0600 Subject: [PATCH 11/14] feat: allow plugins to be found by dotted path string (#3027) feat: allow plugins to be found by stringified dotted path --- litestar/plugins/base.py | 14 ++- litestar/utils/module_loader.py | 97 +++++++++++++++++++++ tests/unit/test_plugins/test_base.py | 21 +++++ tests/unit/test_utils/test_module_loader.py | 23 +++++ 4 files changed, 154 insertions(+), 1 deletion(-) create mode 100644 litestar/utils/module_loader.py create mode 100644 tests/unit/test_utils/test_module_loader.py diff --git a/litestar/plugins/base.py b/litestar/plugins/base.py index 254c94ab55..a6b83537b8 100644 --- a/litestar/plugins/base.py +++ b/litestar/plugins/base.py @@ -267,11 +267,23 @@ def __init__(self, plugins: list[PluginProtocol]) -> None: self.serialization = tuple(p for p in plugins if isinstance(p, SerializationPluginProtocol)) self.cli = tuple(p for p in plugins if isinstance(p, CLIPluginProtocol)) - def get(self, type_: type[PluginT]) -> PluginT: + def get(self, type_: type[PluginT] | str) -> PluginT: """Return the registered plugin of ``type_``. This should be used with subclasses of the plugin protocols. """ + if isinstance(type_, str): + for plugin in self._plugins: + _name = plugin.__class__.__name__ + _module = plugin.__class__.__module__ + _qualname = ( + f"{_module}.{plugin.__class__.__qualname__}" + if _module is not None and _module != "__builtin__" + else plugin.__class__.__qualname__ + ) + if type_ in {_name, _qualname}: + return cast(PluginT, plugin) + raise KeyError(f"No plugin of type {type_!r} registered") try: return cast(PluginT, self._plugins_by_type[type_]) # type: ignore[index] except KeyError as e: diff --git a/litestar/utils/module_loader.py b/litestar/utils/module_loader.py new file mode 100644 index 0000000000..d466221a0c --- /dev/null +++ b/litestar/utils/module_loader.py @@ -0,0 +1,97 @@ +"""General utility functions.""" + +from __future__ import annotations + +import platform +import sys +from importlib import import_module +from importlib.util import find_spec +from pathlib import Path +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from types import ModuleType + +__all__ = ( + "import_string", + "module_to_os_path", +) + +PYTHON_38 = sys.version_info < (3, 9, 0) + + +def module_to_os_path(dotted_path: str = "app") -> Path: + """Find Module to OS Path. + + Return a path to the base directory of the project or the module + specified by `dotted_path`. + + Args: + dotted_path (str, optional): The path to the module. Defaults to "app". + + Raises: + TypeError: The module could not be found. + + Returns: + Path: The path to the module. + """ + try: + src = find_spec(dotted_path) + except ModuleNotFoundError as e: + msg = "Couldn't find the path for %s" + raise TypeError(msg, dotted_path) from e + path_separator = "\\" if platform.system() == "Windows" else "/" + if PYTHON_38: + suffix = f"{path_separator}__init__.py" + return Path(str(src.origin)[: (-1 * len(suffix))] if src.origin.endswith(suffix) else src.origin) # type: ignore[arg-type, union-attr] + return Path(str(src.origin).removesuffix(f"{path_separator}__init__.py")) # type: ignore + + +def import_string(dotted_path: str) -> Any: + """Dotted Path Import. + + Import a dotted module path and return the attribute/class designated by the + last name in the path. Raise ImportError if the import failed. + + Args: + dotted_path (str): The path of the module to import. + + Raises: + ImportError: Could not import the module. + + Returns: + object: The imported object. + """ + + def _is_loaded(module: ModuleType | None) -> bool: + spec = getattr(module, "__spec__", None) + initializing = getattr(spec, "_initializing", False) + return bool(module and spec and not initializing) + + def _cached_import(module_path: str, class_name: str) -> Any: + """Import and cache a class from a module. + + Args: + module_path (str): dotted path to module. + class_name (str): Class or function name. + + Returns: + object: The imported class or function + """ + # Check whether module is loaded and fully initialized. + module = sys.modules.get(module_path) + if not _is_loaded(module): + module = import_module(module_path) + return getattr(module, class_name) + + try: + module_path, class_name = dotted_path.rsplit(".", 1) + except ValueError as e: + msg = "%s doesn't look like a module path" + raise ImportError(msg, dotted_path) from e + + try: + return _cached_import(module_path, class_name) + except AttributeError as e: + msg = "Module '%s' does not define a '%s' attribute/class" + raise ImportError(msg, module_path, class_name) from e diff --git a/tests/unit/test_plugins/test_base.py b/tests/unit/test_plugins/test_base.py index d0853421d0..f1d146f1d1 100644 --- a/tests/unit/test_plugins/test_base.py +++ b/tests/unit/test_plugins/test_base.py @@ -81,6 +81,27 @@ def on_cli_init(self, cli: Group) -> None: assert PluginRegistry([cli_plugin]).get(CLIPlugin) is cli_plugin +def test_plugin_registry_stringified_get() -> None: + class CLIPlugin(CLIPluginProtocol): + def on_cli_init(self, cli: Group) -> None: + pass + + cli_plugin = CLIPlugin() + pydantic_plugin = PydanticPlugin() + with pytest.raises(KeyError): + PluginRegistry([CLIPlugin()]).get( + "litestar2.contrib.pydantic.PydanticPlugin" + ) # not a fqdn. should fail # type: ignore[list-item] + PluginRegistry([]).get("CLIPlugin") # not a fqdn. should fail # type: ignore[list-item] + + assert PluginRegistry([cli_plugin, pydantic_plugin]).get(CLIPlugin) is cli_plugin + assert PluginRegistry([cli_plugin, pydantic_plugin]).get(PydanticPlugin) is pydantic_plugin + assert PluginRegistry([cli_plugin, pydantic_plugin]).get("PydanticPlugin") is pydantic_plugin + assert ( + PluginRegistry([cli_plugin, pydantic_plugin]).get("litestar.contrib.pydantic.PydanticPlugin") is pydantic_plugin + ) + + def test_openapi_schema_plugin_is_constrained_field() -> None: assert OpenAPISchemaPlugin.is_constrained_field(FieldDefinition.from_annotation(str)) is False diff --git a/tests/unit/test_utils/test_module_loader.py b/tests/unit/test_utils/test_module_loader.py new file mode 100644 index 0000000000..eee27bfe25 --- /dev/null +++ b/tests/unit/test_utils/test_module_loader.py @@ -0,0 +1,23 @@ +import pytest + +from litestar.config.compression import CompressionConfig +from litestar.utils.module_loader import import_string, module_to_os_path + + +def test_import_string() -> None: + cls = import_string("litestar.config.compression.CompressionConfig") + assert type(cls) == type(CompressionConfig) + + with pytest.raises(ImportError): + _ = import_string("CompressionConfigNew") + _ = import_string("litestar.config.compression.CompressionConfigNew") + _ = import_string("imaginary_module_that_doesnt_exist.Config") # a random nonexistent class + + +def test_module_path() -> None: + the_path = module_to_os_path("litestar.config.compression") + assert the_path.exists() + + with pytest.raises(TypeError): + _ = module_to_os_path("litestar.config.compression.Config") + _ = module_to_os_path("litestar.config.compression.extra.module") From 1966c4d24aadee17a61db9d720f4dc8d2fa50a95 Mon Sep 17 00:00:00 2001 From: George Sakkis Date: Sat, 3 Feb 2024 21:28:26 +0200 Subject: [PATCH 12/14] fix: align layered annotations (#2913) * Align route handler decorator annotations with respective app/router/controller ones * Align Controller annotations with Router * Standardize the response_headers annotation * Remove ResponseType type alias --- litestar/app.py | 9 +- litestar/config/app.py | 12 +- litestar/handlers/http_handlers/_utils.py | 11 +- litestar/handlers/http_handlers/base.py | 3 +- litestar/handlers/http_handlers/decorators.py | 112 +++++++++--------- litestar/router.py | 4 +- litestar/testing/helpers.py | 14 +-- litestar/types/__init__.py | 9 +- litestar/types/internal_types.py | 3 - .../test_http_handlers/test_kwarg_handling.py | 5 +- 10 files changed, 82 insertions(+), 100 deletions(-) diff --git a/litestar/app.py b/litestar/app.py index 1b62113f9f..4f9f002fd1 100644 --- a/litestar/app.py +++ b/litestar/app.py @@ -59,12 +59,13 @@ from litestar.config.compression import CompressionConfig from litestar.config.cors import CORSConfig from litestar.config.csrf import CSRFConfig - from litestar.datastructures import CacheControlHeader, ETag, ResponseHeader + from litestar.datastructures import CacheControlHeader, ETag from litestar.dto import AbstractDTO from litestar.events.listener import EventListener from litestar.logging.config import BaseLoggingConfig from litestar.openapi.spec import SecurityRequirement from litestar.openapi.spec.open_api import OpenAPI + from litestar.response import Response from litestar.static_files.config import StaticFilesConfig from litestar.stores.base import Store from litestar.types import ( @@ -91,7 +92,7 @@ ParametersMap, Receive, ResponseCookies, - ResponseType, + ResponseHeaders, RouteHandlerType, Scope, Send, @@ -203,9 +204,9 @@ def __init__( plugins: Sequence[PluginProtocol] | None = None, request_class: type[Request] | None = None, response_cache_config: ResponseCacheConfig | None = None, - response_class: ResponseType | None = None, + response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, - response_headers: Sequence[ResponseHeader] | None = None, + response_headers: ResponseHeaders | None = None, return_dto: type[AbstractDTO] | None | EmptyType = Empty, security: Sequence[SecurityRequirement] | None = None, signature_namespace: Mapping[str, Any] | None = None, diff --git a/litestar/config/app.py b/litestar/config/app.py index ff12a05d97..859999bfd6 100644 --- a/litestar/config/app.py +++ b/litestar/config/app.py @@ -2,7 +2,7 @@ import enum from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, Sequence +from typing import TYPE_CHECKING, Any, Callable from litestar.config.allowed_hosts import AllowedHostsConfig from litestar.config.response_cache import ResponseCacheConfig @@ -13,12 +13,12 @@ if TYPE_CHECKING: from contextlib import AbstractAsyncContextManager - from litestar import Litestar + from litestar import Litestar, Response from litestar.config.compression import CompressionConfig from litestar.config.cors import CORSConfig from litestar.config.csrf import CSRFConfig from litestar.connection import Request, WebSocket - from litestar.datastructures import CacheControlHeader, ETag, ResponseHeader + from litestar.datastructures import CacheControlHeader, ETag from litestar.di import Provide from litestar.dto import AbstractDTO from litestar.events.emitter import BaseEventEmitterBackend @@ -43,7 +43,7 @@ Middleware, ParametersMap, ResponseCookies, - ResponseType, + ResponseHeaders, TypeEncodersMap, ) from litestar.types.callable_types import LifespanHook @@ -157,11 +157,11 @@ class AppConfig: """List of :class:`SerializationPluginProtocol <.plugins.SerializationPluginProtocol>`.""" request_class: type[Request] | None = field(default=None) """An optional subclass of :class:`Request <.connection.Request>` to use for http connections.""" - response_class: ResponseType | None = field(default=None) + response_class: type[Response] | None = field(default=None) """A custom subclass of :class:`Response <.response.Response>` to be used as the app's default response.""" response_cookies: ResponseCookies = field(default_factory=list) """A list of :class:`Cookie <.datastructures.Cookie>`.""" - response_headers: Sequence[ResponseHeader] = field(default_factory=list) + response_headers: ResponseHeaders = field(default_factory=list) """A string keyed dictionary mapping :class:`ResponseHeader <.datastructures.ResponseHeader>`.""" response_cache_config: ResponseCacheConfig = field(default_factory=ResponseCacheConfig) """Configures caching behavior of the application.""" diff --git a/litestar/handlers/http_handlers/_utils.py b/litestar/handlers/http_handlers/_utils.py index e840383cba..2df6717be6 100644 --- a/litestar/handlers/http_handlers/_utils.py +++ b/litestar/handlers/http_handlers/_utils.py @@ -15,14 +15,7 @@ from litestar.background_tasks import BackgroundTask, BackgroundTasks from litestar.connection import Request from litestar.datastructures import Cookie, ResponseHeader - from litestar.types import ( - AfterRequestHookHandler, - ASGIApp, - AsyncAnyCallable, - Method, - ResponseType, - TypeEncodersMap, - ) + from litestar.types import AfterRequestHookHandler, ASGIApp, AsyncAnyCallable, Method, TypeEncodersMap from litestar.typing import FieldDefinition __all__ = ( @@ -42,7 +35,7 @@ def create_data_handler( cookies: frozenset[Cookie], headers: frozenset[ResponseHeader], media_type: str, - response_class: ResponseType, + response_class: type[Response], status_code: int, type_encoders: TypeEncodersMap | None, ) -> AsyncAnyCallable: diff --git a/litestar/handlers/http_handlers/base.py b/litestar/handlers/http_handlers/base.py index b4c3862120..d234a24cca 100644 --- a/litestar/handlers/http_handlers/base.py +++ b/litestar/handlers/http_handlers/base.py @@ -39,7 +39,6 @@ Middleware, ResponseCookies, ResponseHeaders, - ResponseType, TypeEncodersMap, ) from litestar.utils import ensure_async_callable @@ -135,7 +134,7 @@ def __init__( middleware: Sequence[Middleware] | None = None, name: str | None = None, opt: Mapping[str, Any] | None = None, - response_class: ResponseType | None = None, + response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, response_headers: ResponseHeaders | None = None, return_dto: type[AbstractDTO] | None | EmptyType = Empty, diff --git a/litestar/handlers/http_handlers/decorators.py b/litestar/handlers/http_handlers/decorators.py index 045534fe07..f78168efef 100644 --- a/litestar/handlers/http_handlers/decorators.py +++ b/litestar/handlers/http_handlers/decorators.py @@ -13,7 +13,7 @@ from .base import HTTPRouteHandler if TYPE_CHECKING: - from typing import Any, Mapping + from typing import Any, Mapping, Sequence from litestar.background_tasks import BackgroundTask, BackgroundTasks from litestar.config.response_cache import CACHE_FOREVER @@ -21,6 +21,7 @@ from litestar.dto import AbstractDTO from litestar.openapi.datastructures import ResponseSpec from litestar.openapi.spec import SecurityRequirement + from litestar.response import Response from litestar.types import ( AfterRequestHookHandler, AfterResponseHookHandler, @@ -33,7 +34,6 @@ Middleware, ResponseCookies, ResponseHeaders, - ResponseType, TypeEncodersMap, ) from litestar.types.callable_types import OperationIDCreator @@ -52,7 +52,7 @@ class delete(HTTPRouteHandler): def __init__( self, - path: str | list[str] | None = None, + path: str | None | Sequence[str] = None, *, after_request: AfterRequestHookHandler | None = None, after_response: AfterResponseHookHandler | None = None, @@ -65,12 +65,12 @@ def __init__( dto: type[AbstractDTO] | None | EmptyType = Empty, etag: ETag | None = None, exception_handlers: ExceptionHandlersMap | None = None, - guards: list[Guard] | None = None, + guards: Sequence[Guard] | None = None, media_type: MediaType | str | None = None, - middleware: list[Middleware] | None = None, + middleware: Sequence[Middleware] | None = None, name: str | None = None, - opt: dict[str, Any] | None = None, - response_class: ResponseType | None = None, + opt: Mapping[str, Any] | None = None, + response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, response_headers: ResponseHeaders | None = None, return_dto: type[AbstractDTO] | None | EmptyType = Empty, @@ -85,12 +85,12 @@ def __init__( include_in_schema: bool | EmptyType = Empty, operation_class: type[Operation] = Operation, operation_id: str | OperationIDCreator | None = None, - raises: list[type[HTTPException]] | None = None, + raises: Sequence[type[HTTPException]] | None = None, response_description: str | None = None, - responses: dict[int, ResponseSpec] | None = None, - security: list[SecurityRequirement] | None = None, + responses: Mapping[int, ResponseSpec] | None = None, + security: Sequence[SecurityRequirement] | None = None, summary: str | None = None, - tags: list[str] | None = None, + tags: Sequence[str] | None = None, type_encoders: TypeEncodersMap | None = None, **kwargs: Any, ) -> None: @@ -216,7 +216,7 @@ class get(HTTPRouteHandler): def __init__( self, - path: str | list[str] | None = None, + path: str | None | Sequence[str] = None, *, after_request: AfterRequestHookHandler | None = None, after_response: AfterResponseHookHandler | None = None, @@ -229,12 +229,12 @@ def __init__( dto: type[AbstractDTO] | None | EmptyType = Empty, etag: ETag | None = None, exception_handlers: ExceptionHandlersMap | None = None, - guards: list[Guard] | None = None, + guards: Sequence[Guard] | None = None, media_type: MediaType | str | None = None, - middleware: list[Middleware] | None = None, + middleware: Sequence[Middleware] | None = None, name: str | None = None, - opt: dict[str, Any] | None = None, - response_class: ResponseType | None = None, + opt: Mapping[str, Any] | None = None, + response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, response_headers: ResponseHeaders | None = None, return_dto: type[AbstractDTO] | None | EmptyType = Empty, @@ -249,12 +249,12 @@ def __init__( include_in_schema: bool | EmptyType = Empty, operation_class: type[Operation] = Operation, operation_id: str | OperationIDCreator | None = None, - raises: list[type[HTTPException]] | None = None, + raises: Sequence[type[HTTPException]] | None = None, response_description: str | None = None, - responses: dict[int, ResponseSpec] | None = None, - security: list[SecurityRequirement] | None = None, + responses: Mapping[int, ResponseSpec] | None = None, + security: Sequence[SecurityRequirement] | None = None, summary: str | None = None, - tags: list[str] | None = None, + tags: Sequence[str] | None = None, type_encoders: TypeEncodersMap | None = None, **kwargs: Any, ) -> None: @@ -381,7 +381,7 @@ class head(HTTPRouteHandler): def __init__( self, - path: str | list[str] | None = None, + path: str | None | Sequence[str] = None, *, after_request: AfterRequestHookHandler | None = None, after_response: AfterResponseHookHandler | None = None, @@ -394,12 +394,12 @@ def __init__( dto: type[AbstractDTO] | None | EmptyType = Empty, etag: ETag | None = None, exception_handlers: ExceptionHandlersMap | None = None, - guards: list[Guard] | None = None, + guards: Sequence[Guard] | None = None, media_type: MediaType | str | None = None, - middleware: list[Middleware] | None = None, + middleware: Sequence[Middleware] | None = None, name: str | None = None, - opt: dict[str, Any] | None = None, - response_class: ResponseType | None = None, + opt: Mapping[str, Any] | None = None, + response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, response_headers: ResponseHeaders | None = None, signature_namespace: Mapping[str, Any] | None = None, @@ -413,13 +413,13 @@ def __init__( include_in_schema: bool | EmptyType = Empty, operation_class: type[Operation] = Operation, operation_id: str | OperationIDCreator | None = None, - raises: list[type[HTTPException]] | None = None, + raises: Sequence[type[HTTPException]] | None = None, response_description: str | None = None, - responses: dict[int, ResponseSpec] | None = None, + responses: Mapping[int, ResponseSpec] | None = None, return_dto: type[AbstractDTO] | None | EmptyType = Empty, - security: list[SecurityRequirement] | None = None, + security: Sequence[SecurityRequirement] | None = None, summary: str | None = None, - tags: list[str] | None = None, + tags: Sequence[str] | None = None, type_encoders: TypeEncodersMap | None = None, **kwargs: Any, ) -> None: @@ -563,7 +563,7 @@ class patch(HTTPRouteHandler): def __init__( self, - path: str | list[str] | None = None, + path: str | None | Sequence[str] = None, *, after_request: AfterRequestHookHandler | None = None, after_response: AfterResponseHookHandler | None = None, @@ -576,12 +576,12 @@ def __init__( dto: type[AbstractDTO] | None | EmptyType = Empty, etag: ETag | None = None, exception_handlers: ExceptionHandlersMap | None = None, - guards: list[Guard] | None = None, + guards: Sequence[Guard] | None = None, media_type: MediaType | str | None = None, - middleware: list[Middleware] | None = None, + middleware: Sequence[Middleware] | None = None, name: str | None = None, - opt: dict[str, Any] | None = None, - response_class: ResponseType | None = None, + opt: Mapping[str, Any] | None = None, + response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, response_headers: ResponseHeaders | None = None, return_dto: type[AbstractDTO] | None | EmptyType = Empty, @@ -596,12 +596,12 @@ def __init__( include_in_schema: bool | EmptyType = Empty, operation_class: type[Operation] = Operation, operation_id: str | OperationIDCreator | None = None, - raises: list[type[HTTPException]] | None = None, + raises: Sequence[type[HTTPException]] | None = None, response_description: str | None = None, - responses: dict[int, ResponseSpec] | None = None, - security: list[SecurityRequirement] | None = None, + responses: Mapping[int, ResponseSpec] | None = None, + security: Sequence[SecurityRequirement] | None = None, summary: str | None = None, - tags: list[str] | None = None, + tags: Sequence[str] | None = None, type_encoders: TypeEncodersMap | None = None, **kwargs: Any, ) -> None: @@ -727,7 +727,7 @@ class post(HTTPRouteHandler): def __init__( self, - path: str | list[str] | None = None, + path: str | None | Sequence[str] = None, *, after_request: AfterRequestHookHandler | None = None, after_response: AfterResponseHookHandler | None = None, @@ -740,12 +740,12 @@ def __init__( dto: type[AbstractDTO] | None | EmptyType = Empty, etag: ETag | None = None, exception_handlers: ExceptionHandlersMap | None = None, - guards: list[Guard] | None = None, + guards: Sequence[Guard] | None = None, media_type: MediaType | str | None = None, - middleware: list[Middleware] | None = None, + middleware: Sequence[Middleware] | None = None, name: str | None = None, - opt: dict[str, Any] | None = None, - response_class: ResponseType | None = None, + opt: Mapping[str, Any] | None = None, + response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, response_headers: ResponseHeaders | None = None, return_dto: type[AbstractDTO] | None | EmptyType = Empty, @@ -760,12 +760,12 @@ def __init__( include_in_schema: bool | EmptyType = Empty, operation_class: type[Operation] = Operation, operation_id: str | OperationIDCreator | None = None, - raises: list[type[HTTPException]] | None = None, + raises: Sequence[type[HTTPException]] | None = None, response_description: str | None = None, - responses: dict[int, ResponseSpec] | None = None, - security: list[SecurityRequirement] | None = None, + responses: Mapping[int, ResponseSpec] | None = None, + security: Sequence[SecurityRequirement] | None = None, summary: str | None = None, - tags: list[str] | None = None, + tags: Sequence[str] | None = None, type_encoders: TypeEncodersMap | None = None, **kwargs: Any, ) -> None: @@ -891,7 +891,7 @@ class put(HTTPRouteHandler): def __init__( self, - path: str | list[str] | None = None, + path: str | None | Sequence[str] = None, *, after_request: AfterRequestHookHandler | None = None, after_response: AfterResponseHookHandler | None = None, @@ -904,12 +904,12 @@ def __init__( dto: type[AbstractDTO] | None | EmptyType = Empty, etag: ETag | None = None, exception_handlers: ExceptionHandlersMap | None = None, - guards: list[Guard] | None = None, + guards: Sequence[Guard] | None = None, media_type: MediaType | str | None = None, - middleware: list[Middleware] | None = None, + middleware: Sequence[Middleware] | None = None, name: str | None = None, - opt: dict[str, Any] | None = None, - response_class: ResponseType | None = None, + opt: Mapping[str, Any] | None = None, + response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, response_headers: ResponseHeaders | None = None, return_dto: type[AbstractDTO] | None | EmptyType = Empty, @@ -924,12 +924,12 @@ def __init__( include_in_schema: bool | EmptyType = Empty, operation_class: type[Operation] = Operation, operation_id: str | OperationIDCreator | None = None, - raises: list[type[HTTPException]] | None = None, + raises: Sequence[type[HTTPException]] | None = None, response_description: str | None = None, - responses: dict[int, ResponseSpec] | None = None, - security: list[SecurityRequirement] | None = None, + responses: Mapping[int, ResponseSpec] | None = None, + security: Sequence[SecurityRequirement] | None = None, summary: str | None = None, - tags: list[str] | None = None, + tags: Sequence[str] | None = None, type_encoders: TypeEncodersMap | None = None, **kwargs: Any, ) -> None: diff --git a/litestar/router.py b/litestar/router.py index 772b12a620..d65f39d47b 100644 --- a/litestar/router.py +++ b/litestar/router.py @@ -23,6 +23,7 @@ from litestar.datastructures import CacheControlHeader, ETag from litestar.dto import AbstractDTO from litestar.openapi.spec import SecurityRequirement + from litestar.response import Response from litestar.routes import BaseRoute from litestar.types import ( AfterRequestHookHandler, @@ -34,7 +35,6 @@ Middleware, ParametersMap, ResponseCookies, - ResponseType, RouteHandlerMapItem, RouteHandlerType, TypeEncodersMap, @@ -95,7 +95,7 @@ def __init__( middleware: Sequence[Middleware] | None = None, opt: Mapping[str, Any] | None = None, parameters: ParametersMap | None = None, - response_class: ResponseType | None = None, + response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, response_headers: ResponseHeaders | None = None, return_dto: type[AbstractDTO] | None | EmptyType = Empty, diff --git a/litestar/testing/helpers.py b/litestar/testing/helpers.py index 1af5f7a5de..5ac59af7c8 100644 --- a/litestar/testing/helpers.py +++ b/litestar/testing/helpers.py @@ -12,14 +12,14 @@ if TYPE_CHECKING: from contextlib import AbstractAsyncContextManager - from litestar import Request, WebSocket + from litestar import Request, Response, WebSocket from litestar.config.allowed_hosts import AllowedHostsConfig from litestar.config.app import ExperimentalFeatures from litestar.config.compression import CompressionConfig from litestar.config.cors import CORSConfig from litestar.config.csrf import CSRFConfig from litestar.config.response_cache import ResponseCacheConfig - from litestar.datastructures import CacheControlHeader, ETag, ResponseHeader, State + from litestar.datastructures import CacheControlHeader, ETag, State from litestar.dto import AbstractDTO from litestar.events import BaseEventEmitterBackend, EventListener from litestar.logging.config import BaseLoggingConfig @@ -47,7 +47,7 @@ OnAppInitHandler, ParametersMap, ResponseCookies, - ResponseType, + ResponseHeaders, TypeEncodersMap, ) @@ -92,9 +92,9 @@ def create_test_client( pdb_on_exception: bool | None = None, request_class: type[Request] | None = None, response_cache_config: ResponseCacheConfig | None = None, - response_class: ResponseType | None = None, + response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, - response_headers: Sequence[ResponseHeader] | None = None, + response_headers: ResponseHeaders | None = None, return_dto: type[AbstractDTO] | None | EmptyType = Empty, root_path: str = "", security: Sequence[SecurityRequirement] | None = None, @@ -347,9 +347,9 @@ def create_async_test_client( raise_server_exceptions: bool = True, request_class: type[Request] | None = None, response_cache_config: ResponseCacheConfig | None = None, - response_class: ResponseType | None = None, + response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, - response_headers: Sequence[ResponseHeader] | None = None, + response_headers: ResponseHeaders | None = None, return_dto: type[AbstractDTO] | None | EmptyType = Empty, root_path: str = "", security: Sequence[SecurityRequirement] | None = None, diff --git a/litestar/types/__init__.py b/litestar/types/__init__.py index 6eea3f0ff1..90e319277c 100644 --- a/litestar/types/__init__.py +++ b/litestar/types/__init__.py @@ -74,13 +74,7 @@ from .empty import Empty, EmptyType from .file_types import FileInfo, FileSystemProtocol from .helper_types import AnyIOBackend, MaybePartial, OptionalSequence, SSEData, StreamType, SyncOrAsyncUnion -from .internal_types import ( - ControllerRouterHandler, - ReservedKwargs, - ResponseType, - RouteHandlerMapItem, - RouteHandlerType, -) +from .internal_types import ControllerRouterHandler, ReservedKwargs, RouteHandlerMapItem, RouteHandlerType from .protocols import DataclassProtocol, Logger from .serialization import DataContainerType, LitestarEncodableType @@ -146,7 +140,6 @@ "ReservedKwargs", "ResponseCookies", "ResponseHeaders", - "ResponseType", "RouteHandlerMapItem", "RouteHandlerType", "Scope", diff --git a/litestar/types/internal_types.py b/litestar/types/internal_types.py index 499de1b6fc..d473c22677 100644 --- a/litestar/types/internal_types.py +++ b/litestar/types/internal_types.py @@ -9,7 +9,6 @@ "PathParameterDefinition", "PathParameterDefinition", "ReservedKwargs", - "ResponseType", "RouteHandlerMapItem", "RouteHandlerType", ) @@ -22,7 +21,6 @@ from litestar.handlers.asgi_handlers import ASGIRouteHandler from litestar.handlers.http_handlers import HTTPRouteHandler from litestar.handlers.websocket_handlers import WebsocketRouteHandler - from litestar.response import Response from litestar.router import Router from litestar.template import TemplateConfig from litestar.template.config import EngineType @@ -30,7 +28,6 @@ ReservedKwargs: TypeAlias = Literal["request", "socket", "headers", "query", "cookies", "state", "data"] RouteHandlerType: TypeAlias = "HTTPRouteHandler | WebsocketRouteHandler | ASGIRouteHandler" -ResponseType: TypeAlias = "type[Response]" ControllerRouterHandler: TypeAlias = "type[Controller] | RouteHandlerType | Router | Callable[..., Any]" RouteHandlerMapItem: TypeAlias = 'dict[Method | Literal["websocket", "asgi"], RouteHandlerType]' TemplateConfigType: TypeAlias = "TemplateConfig[EngineType]" diff --git a/tests/unit/test_handlers/test_http_handlers/test_kwarg_handling.py b/tests/unit/test_handlers/test_http_handlers/test_kwarg_handling.py index 231da33b97..73ed586b50 100644 --- a/tests/unit/test_handlers/test_http_handlers/test_kwarg_handling.py +++ b/tests/unit/test_handlers/test_http_handlers/test_kwarg_handling.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Optional, Type import pytest from hypothesis import given @@ -9,7 +9,6 @@ from litestar.handlers.http_handlers import HTTPRouteHandler from litestar.handlers.http_handlers._utils import get_default_status_code from litestar.status_codes import HTTP_200_OK, HTTP_201_CREATED, HTTP_204_NO_CONTENT -from litestar.types import ResponseType from litestar.utils import normalize_path @@ -30,7 +29,7 @@ def test_route_handler_kwarg_handling( http_method: Any, media_type: MediaType, include_in_schema: bool, - response_class: Optional[ResponseType], + response_class: Optional[Type[Response]], response_headers: Any, status_code: Any, path: Any, From e6eb9f2023f8af8612be23b3d6e12a962d77b8cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janek=20Nouvertn=C3=A9?= Date: Sun, 4 Feb 2024 11:22:28 +0100 Subject: [PATCH 13/14] feat(DI): Support externally typed classes as dependency providers (#3066) * Support injecting externally typed classes --- docs/examples/plugins/di_plugin.py | 31 +++++++++ docs/usage/plugins.rst | 25 +++++-- litestar/app.py | 14 +++- litestar/contrib/pydantic/__init__.py | 15 ++++- .../contrib/pydantic/pydantic_di_plugin.py | 26 ++++++++ litestar/handlers/base.py | 66 +++++++++++++------ litestar/plugins/__init__.py | 2 + litestar/plugins/base.py | 29 +++++++- litestar/plugins/core.py | 31 +++++++++ pyproject.toml | 1 + .../test_injection_of_classes.py | 33 ++++++++++ tests/examples/test_plugins/test_di_plugin.py | 10 +++ .../test_attrs/test_inject_attrs_class.py | 20 ++++++ .../test_pydantic/test_inject_pydantic.py | 22 +++++++ .../test_base_handlers/test_resolution.py | 15 +++++ tests/unit/test_plugins/test_base.py | 18 ++++- 16 files changed, 326 insertions(+), 32 deletions(-) create mode 100644 docs/examples/plugins/di_plugin.py create mode 100644 litestar/contrib/pydantic/pydantic_di_plugin.py create mode 100644 litestar/plugins/core.py create mode 100644 tests/examples/test_plugins/test_di_plugin.py create mode 100644 tests/unit/test_contrib/test_attrs/test_inject_attrs_class.py create mode 100644 tests/unit/test_contrib/test_pydantic/test_inject_pydantic.py diff --git a/docs/examples/plugins/di_plugin.py b/docs/examples/plugins/di_plugin.py new file mode 100644 index 0000000000..35625d4163 --- /dev/null +++ b/docs/examples/plugins/di_plugin.py @@ -0,0 +1,31 @@ +from inspect import Parameter, Signature +from typing import Any, Dict, Tuple + +from litestar import Litestar, get +from litestar.di import Provide +from litestar.plugins import DIPlugin + + +class MyBaseType: + def __init__(self, param): + self.param = param + + +class MyDIPlugin(DIPlugin): + def has_typed_init(self, type_: Any) -> bool: + return issubclass(type_, MyBaseType) + + def get_typed_init(self, type_: Any) -> Tuple[Signature, Dict[str, Any]]: + signature = Signature([Parameter(name="param", kind=Parameter.POSITIONAL_OR_KEYWORD)]) + annotations = {"param": str} + return signature, annotations + + +@get("/", dependencies={"injected": Provide(MyBaseType, sync_to_thread=False)}) +async def handler(injected: MyBaseType) -> str: + return injected.param + + +app = Litestar(route_handlers=[handler], plugins=[MyDIPlugin()]) + +# run: /?param=hello diff --git a/docs/usage/plugins.rst b/docs/usage/plugins.rst index 8c4b64ce0a..4911b8da23 100644 --- a/docs/usage/plugins.rst +++ b/docs/usage/plugins.rst @@ -19,7 +19,7 @@ that can interact with the data that is used to instantiate the application inst the contract for plugins that extend serialization functionality of the application. InitPluginProtocol -~~~~~~~~~~~~~~~~~~ +------------------ ``InitPluginProtocol`` defines an interface that allows for customization of the application's initialization process. Init plugins can define dependencies, add route handlers, configure middleware, and much more! @@ -37,7 +37,7 @@ they are provided in the ``plugins`` argument of the :class:`app ` instance is then returned. SerializationPluginProtocol -~~~~~~~~~~~~~~~~~~~~~~~~~~~ +--------------------------- The SerializationPluginProtocol defines a contract for plugins that provide serialization functionality for data types that are otherwise unsupported by the framework. @@ -79,7 +79,7 @@ the plugin, and doesn't otherwise have a ``dto`` or ``return_dto`` defined, the that annotation. Example -------- ++++++++ The following example shows the actual implementation of the ``SerializationPluginProtocol`` for `SQLAlchemy `_ models that is is provided in ``advanced_alchemy``. @@ -106,3 +106,20 @@ subtypes are not created for the same model. If the annotation is not in the ``_type_dto_map`` dictionary, the method creates a new DTO type for the annotation, adds it to the ``_type_dto_map`` dictionary, and returns it. + + +DIPlugin +-------- + +:class:`~litestar.plugins.DIPlugin` can be used to extend Litestar's dependency +injection by providing information about injectable types. + +Its main purpose it to facilitate the injection of callables with unknown signatures, +for example Pydantic's ``BaseModel`` classes; These are not supported natively since, +while they are callables, their type information is not contained within their callable +signature (their :func:`__init__` method). + + +.. literalinclude:: /examples/plugins/di_plugin.py + :language: python + :caption: Dynamically generating signature information for a custom type diff --git a/litestar/app.py b/litestar/app.py index 4f9f002fd1..745514a164 100644 --- a/litestar/app.py +++ b/litestar/app.py @@ -488,18 +488,30 @@ def serialization_plugins(self) -> list[SerializationPluginProtocol]: @staticmethod def _get_default_plugins(plugins: list[PluginProtocol]) -> list[PluginProtocol]: + from litestar.plugins.core import MsgspecDIPlugin + + plugins.append(MsgspecDIPlugin()) + with suppress(MissingDependencyException): - from litestar.contrib.pydantic import PydanticInitPlugin, PydanticPlugin, PydanticSchemaPlugin + from litestar.contrib.pydantic import ( + PydanticDIPlugin, + PydanticInitPlugin, + PydanticPlugin, + PydanticSchemaPlugin, + ) pydantic_plugin_found = any(isinstance(plugin, PydanticPlugin) for plugin in plugins) pydantic_init_plugin_found = any(isinstance(plugin, PydanticInitPlugin) for plugin in plugins) pydantic_schema_plugin_found = any(isinstance(plugin, PydanticSchemaPlugin) for plugin in plugins) + pydantic_serialization_plugin_found = any(isinstance(plugin, PydanticDIPlugin) for plugin in plugins) if not pydantic_plugin_found and not pydantic_init_plugin_found and not pydantic_schema_plugin_found: plugins.append(PydanticPlugin()) elif not pydantic_plugin_found and pydantic_init_plugin_found and not pydantic_schema_plugin_found: plugins.append(PydanticSchemaPlugin()) elif not pydantic_plugin_found and not pydantic_init_plugin_found: plugins.append(PydanticInitPlugin()) + if not pydantic_plugin_found and not pydantic_serialization_plugin_found: + plugins.append(PydanticDIPlugin()) with suppress(MissingDependencyException): from litestar.contrib.attrs import AttrsSchemaPlugin diff --git a/litestar/contrib/pydantic/__init__.py b/litestar/contrib/pydantic/__init__.py index 122a710539..9bab707c31 100644 --- a/litestar/contrib/pydantic/__init__.py +++ b/litestar/contrib/pydantic/__init__.py @@ -4,6 +4,7 @@ from litestar.plugins import InitPluginProtocol +from .pydantic_di_plugin import PydanticDIPlugin from .pydantic_dto_factory import PydanticDTO from .pydantic_init_plugin import PydanticInitPlugin from .pydantic_schema_plugin import PydanticSchemaPlugin @@ -14,7 +15,13 @@ from litestar.config.app import AppConfig -__all__ = ("PydanticDTO", "PydanticInitPlugin", "PydanticSchemaPlugin", "PydanticPlugin") +__all__ = ( + "PydanticDTO", + "PydanticInitPlugin", + "PydanticSchemaPlugin", + "PydanticPlugin", + "PydanticDIPlugin", +) def _model_dump(model: BaseModel | BaseModelV1, *, by_alias: bool = False) -> dict[str, Any]: @@ -53,6 +60,10 @@ def on_app_init(self, app_config: AppConfig) -> AppConfig: app_config: The :class:`AppConfig <.config.app.AppConfig>` instance. """ app_config.plugins.extend( - [PydanticInitPlugin(prefer_alias=self.prefer_alias), PydanticSchemaPlugin(prefer_alias=self.prefer_alias)] + [ + PydanticInitPlugin(prefer_alias=self.prefer_alias), + PydanticSchemaPlugin(prefer_alias=self.prefer_alias), + PydanticDIPlugin(), + ] ) return app_config diff --git a/litestar/contrib/pydantic/pydantic_di_plugin.py b/litestar/contrib/pydantic/pydantic_di_plugin.py new file mode 100644 index 0000000000..2096fd4ab6 --- /dev/null +++ b/litestar/contrib/pydantic/pydantic_di_plugin.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import inspect +from inspect import Signature +from typing import Any + +from litestar.contrib.pydantic.utils import is_pydantic_model_class +from litestar.plugins import DIPlugin + + +class PydanticDIPlugin(DIPlugin): + def has_typed_init(self, type_: Any) -> bool: + return is_pydantic_model_class(type_) + + def get_typed_init(self, type_: Any) -> tuple[Signature, dict[str, Any]]: + try: + model_fields = dict(type_.model_fields) + except AttributeError: + model_fields = {k: model_field.field_info for k, model_field in type_.__fields__.items()} + + parameters = [ + inspect.Parameter(name=field_name, kind=inspect.Parameter.KEYWORD_ONLY, annotation=Any) + for field_name in model_fields + ] + type_hints = {field_name: Any for field_name in model_fields} + return Signature(parameters), type_hints diff --git a/litestar/handlers/base.py b/litestar/handlers/base.py index da2c53f3b9..aa02b56a4f 100644 --- a/litestar/handlers/base.py +++ b/litestar/handlers/base.py @@ -9,6 +9,7 @@ from litestar.di import Provide from litestar.dto import DTOData from litestar.exceptions import ImproperlyConfiguredException +from litestar.plugins import DIPlugin, PluginRegistry from litestar.serialization import default_deserializer, default_serializer from litestar.types import ( Dependencies, @@ -339,37 +340,60 @@ def resolve_guards(self) -> list[Guard]: return self._resolved_guards + def _get_plugin_registry(self) -> PluginRegistry | None: + from litestar.app import Litestar + + root_owner = self.ownership_layers[0] + if isinstance(root_owner, Litestar): + return root_owner.plugins + return None + def resolve_dependencies(self) -> dict[str, Provide]: """Return all dependencies correlating to handler function's kwargs that exist in the handler's scope.""" + plugin_registry = self._get_plugin_registry() if self._resolved_dependencies is Empty: self._resolved_dependencies = {} - for layer in self.ownership_layers: for key, provider in (layer.dependencies or {}).items(): - if not isinstance(provider, Provide): - provider = Provide(provider) - - self._validate_dependency_is_unique( - dependencies=self._resolved_dependencies, key=key, provider=provider + self._resolved_dependencies[key] = self._resolve_dependency( + key=key, provider=provider, plugin_registry=plugin_registry ) - if not getattr(provider, "parsed_signature", None): - provider.parsed_fn_signature = ParsedSignature.from_fn( - unwrap_partial(provider.dependency), self.resolve_signature_namespace() - ) - - if not getattr(provider, "signature_model", None): - provider.signature_model = SignatureModel.create( - dependency_name_set=self.dependency_name_set, - fn=provider.dependency, - parsed_signature=provider.parsed_fn_signature, - data_dto=self.resolve_data_dto(), - type_decoders=self.resolve_type_decoders(), - ) - - self._resolved_dependencies[key] = provider return self._resolved_dependencies + def _resolve_dependency( + self, key: str, provider: Provide | AnyCallable, plugin_registry: PluginRegistry | None + ) -> Provide: + if not isinstance(provider, Provide): + provider = Provide(provider) + + if self._resolved_dependencies is not Empty: # pragma: no cover + self._validate_dependency_is_unique(dependencies=self._resolved_dependencies, key=key, provider=provider) + + if not getattr(provider, "parsed_fn_signature", None): + dependency = unwrap_partial(provider.dependency) + plugin: DIPlugin | None = None + if plugin_registry: + plugin = next( + (p for p in plugin_registry.di if isinstance(p, DIPlugin) and p.has_typed_init(dependency)), + None, + ) + if plugin: + signature, init_type_hints = plugin.get_typed_init(dependency) + provider.parsed_fn_signature = ParsedSignature.from_signature(signature, init_type_hints) + else: + provider.parsed_fn_signature = ParsedSignature.from_fn(dependency, self.resolve_signature_namespace()) + + if not getattr(provider, "signature_model", None): + provider.signature_model = SignatureModel.create( + dependency_name_set=self.dependency_name_set, + fn=provider.dependency, + parsed_signature=provider.parsed_fn_signature, + data_dto=self.resolve_data_dto(), + type_decoders=self.resolve_type_decoders(), + ) + return provider + def resolve_middleware(self) -> list[Middleware]: """Build the middleware stack for the RouteHandler and return it. diff --git a/litestar/plugins/__init__.py b/litestar/plugins/__init__.py index 6f71b78b4a..f09310436d 100644 --- a/litestar/plugins/__init__.py +++ b/litestar/plugins/__init__.py @@ -1,6 +1,7 @@ from litestar.plugins.base import ( CLIPlugin, CLIPluginProtocol, + DIPlugin, InitPluginProtocol, OpenAPISchemaPlugin, OpenAPISchemaPluginProtocol, @@ -11,6 +12,7 @@ __all__ = ( "SerializationPluginProtocol", + "DIPlugin", "CLIPlugin", "InitPluginProtocol", "OpenAPISchemaPluginProtocol", diff --git a/litestar/plugins/base.py b/litestar/plugins/base.py index a6b83537b8..afc571efe7 100644 --- a/litestar/plugins/base.py +++ b/litestar/plugins/base.py @@ -1,9 +1,12 @@ from __future__ import annotations +import abc from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Iterator, Protocol, TypeVar, Union, cast, runtime_checkable if TYPE_CHECKING: + from inspect import Signature + from click import Group from litestar._openapi.schema_generation import SchemaCreator @@ -23,6 +26,7 @@ "CLIPlugin", "CLIPluginProtocol", "PluginRegistry", + "DIPlugin", ) @@ -154,6 +158,26 @@ def create_dto_for_type(self, field_definition: FieldDefinition) -> type[Abstrac raise NotImplementedError() +class DIPlugin(abc.ABC): + """Extend dependency injection""" + + @abc.abstractmethod + def has_typed_init(self, type_: Any) -> bool: + """Return ``True`` if ``type_`` has type information available for its + :func:`__init__` method that cannot be extracted from this method's type + annotations (e.g. a Pydantic BaseModel subclass), and + :meth:`DIPlugin.get_typed_init` supports extraction of these annotations. + """ + ... + + @abc.abstractmethod + def get_typed_init(self, type_: Any) -> tuple[Signature, dict[str, Any]]: + r"""Return signature and type information about the ``type_``\ s :func:`__init__` + method. + """ + ... + + @runtime_checkable class OpenAPISchemaPluginProtocol(Protocol): """Plugin protocol to extend the support of OpenAPI schema generation for non-library types.""" @@ -241,6 +265,7 @@ def is_constrained_field(field_definition: FieldDefinition) -> bool: OpenAPISchemaPluginProtocol, ReceiveRoutePlugin, SerializationPluginProtocol, + DIPlugin, ] PluginT = TypeVar("PluginT", bound=PluginProtocol) @@ -250,9 +275,10 @@ class PluginRegistry: __slots__ = { "init": "Plugins that implement the InitPluginProtocol", "openapi": "Plugins that implement the OpenAPISchemaPluginProtocol", - "receive_route": "ReceiveRoutePlugin types", + "receive_route": "ReceiveRoutePlugin instances", "serialization": "Plugins that implement the SerializationPluginProtocol", "cli": "Plugins that implement the CLIPluginProtocol", + "di": "DIPlugin instances", "_plugins_by_type": None, "_plugins": None, "_get_plugins_of_type": None, @@ -266,6 +292,7 @@ def __init__(self, plugins: list[PluginProtocol]) -> None: self.receive_route = tuple(p for p in plugins if isinstance(p, ReceiveRoutePlugin)) self.serialization = tuple(p for p in plugins if isinstance(p, SerializationPluginProtocol)) self.cli = tuple(p for p in plugins if isinstance(p, CLIPluginProtocol)) + self.di = tuple(p for p in plugins if isinstance(p, DIPlugin)) def get(self, type_: type[PluginT] | str) -> PluginT: """Return the registered plugin of ``type_``. diff --git a/litestar/plugins/core.py b/litestar/plugins/core.py new file mode 100644 index 0000000000..d25d6d661b --- /dev/null +++ b/litestar/plugins/core.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import inspect +from inspect import Signature +from typing import Any + +import msgspec + +from litestar.plugins import DIPlugin + +__all__ = ("MsgspecDIPlugin",) + + +class MsgspecDIPlugin(DIPlugin): + def has_typed_init(self, type_: Any) -> bool: + return type(type_) is type(msgspec.Struct) # noqa: E721 + + def get_typed_init(self, type_: Any) -> tuple[Signature, dict[str, Any]]: + parameters = [] + type_hints = {} + for field_info in msgspec.structs.fields(type_): + type_hints[field_info.name] = field_info.type + parameters.append( + inspect.Parameter( + name=field_info.name, + kind=inspect.Parameter.KEYWORD_ONLY, + annotation=field_info.type, + default=field_info.default, + ) + ) + return inspect.Signature(parameters), type_hints diff --git a/pyproject.toml b/pyproject.toml index 655feceff1..736a90e55d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -224,6 +224,7 @@ markers = [ "server_integration: Test integration with ASGI server" ] xfail_strict = true +testpaths = ["tests", "docs/examples/testing"] [tool.mypy] packages = ["litestar", "tests"] diff --git a/tests/e2e/test_dependency_injection/test_injection_of_classes.py b/tests/e2e/test_dependency_injection/test_injection_of_classes.py index 795c4f657e..6091015b0e 100644 --- a/tests/e2e/test_dependency_injection/test_injection_of_classes.py +++ b/tests/e2e/test_dependency_injection/test_injection_of_classes.py @@ -1,3 +1,7 @@ +from dataclasses import dataclass + +import msgspec + from litestar import Controller, get from litestar.di import Provide from litestar.testing import create_test_client @@ -37,3 +41,32 @@ def test_function(self, container: HandlerDependency) -> str: with create_test_client(MyController) as client: response = client.get(f"/test/{path_param_value}?query_param={query_param_value}") assert response.text == "15" + + +def test_inject_dataclass() -> None: + @dataclass + class Foo: + bar: str + + @get("/", dependencies={"foo": Provide(Foo, sync_to_thread=False)}) + async def handler(foo: Foo) -> Foo: + return foo + + with create_test_client([handler]) as client: + res = client.get("/?bar=baz") + assert res.status_code == 200 + assert res.json() == {"bar": "baz"} + + +def test_inject_msgspec_struct() -> None: + class Foo(msgspec.Struct): + bar: str + + @get("/", dependencies={"foo": Provide(Foo, sync_to_thread=False)}) + async def handler(foo: Foo) -> Foo: + return foo + + with create_test_client([handler]) as client: + res = client.get("/?bar=baz") + assert res.status_code == 200 + assert res.json() == {"bar": "baz"} diff --git a/tests/examples/test_plugins/test_di_plugin.py b/tests/examples/test_plugins/test_di_plugin.py new file mode 100644 index 0000000000..c42786b068 --- /dev/null +++ b/tests/examples/test_plugins/test_di_plugin.py @@ -0,0 +1,10 @@ +from docs.examples.plugins.di_plugin import app + +from litestar.testing import TestClient + + +def test_di_plugin_example() -> None: + with TestClient(app) as client: + res = client.get("/?param=hello") + assert res.status_code == 200 + assert res.text == "hello" diff --git a/tests/unit/test_contrib/test_attrs/test_inject_attrs_class.py b/tests/unit/test_contrib/test_attrs/test_inject_attrs_class.py new file mode 100644 index 0000000000..bc6f526b68 --- /dev/null +++ b/tests/unit/test_contrib/test_attrs/test_inject_attrs_class.py @@ -0,0 +1,20 @@ +from attrs import define + +from litestar import get +from litestar.di import Provide +from litestar.testing import create_test_client + + +def test_inject_attrs_class() -> None: + @define + class Foo: + bar: str + + @get("/", dependencies={"foo": Provide(Foo, sync_to_thread=False)}) + async def handler(foo: Foo) -> Foo: + return foo + + with create_test_client([handler]) as client: + res = client.get("/?bar=baz") + assert res.status_code == 200 + assert res.json() == {"bar": "baz"} diff --git a/tests/unit/test_contrib/test_pydantic/test_inject_pydantic.py b/tests/unit/test_contrib/test_pydantic/test_inject_pydantic.py new file mode 100644 index 0000000000..ed7a53267c --- /dev/null +++ b/tests/unit/test_contrib/test_pydantic/test_inject_pydantic.py @@ -0,0 +1,22 @@ +import pydantic as pydantic_v2 +import pytest +from pydantic import v1 as pydantic_v1 + +from litestar import get +from litestar.di import Provide +from litestar.testing import create_test_client + + +@pytest.mark.parametrize("base_model", [pydantic_v1.BaseModel, pydantic_v2.BaseModel]) +def test_inject_pydantic_model(base_model: type) -> None: + class Foo(base_model): # type: ignore[misc] + bar: str + + @get("/", dependencies={"foo": Provide(Foo, sync_to_thread=False)}) + async def handler(foo: Foo) -> Foo: + return foo + + with create_test_client([handler]) as client: + res = client.get("/?bar=baz") + assert res.status_code == 200 + assert res.json() == {"bar": "baz"} diff --git a/tests/unit/test_handlers/test_base_handlers/test_resolution.py b/tests/unit/test_handlers/test_base_handlers/test_resolution.py index 079c4a2cdc..12809610a4 100644 --- a/tests/unit/test_handlers/test_base_handlers/test_resolution.py +++ b/tests/unit/test_handlers/test_base_handlers/test_resolution.py @@ -52,3 +52,18 @@ async def handler(self) -> None: "controller": Provide(controller_dependency), "handler": Provide(handler_dependency), } + + +def test_resolve_dependencies_cached() -> None: + dependency = Provide(function_factory()) + + @get(dependencies={"foo": dependency}) + async def handler() -> None: + pass + + @get(dependencies={"foo": dependency}) + async def handler_2() -> None: + pass + + assert handler.resolve_dependencies() is handler.resolve_dependencies() + assert handler_2.resolve_dependencies() is handler_2.resolve_dependencies() diff --git a/tests/unit/test_plugins/test_base.py b/tests/unit/test_plugins/test_base.py index f1d146f1d1..8598eb5c64 100644 --- a/tests/unit/test_plugins/test_base.py +++ b/tests/unit/test_plugins/test_base.py @@ -8,9 +8,10 @@ from litestar import Litestar, MediaType, get from litestar.constants import UNDEFINED_SENTINELS from litestar.contrib.attrs import AttrsSchemaPlugin -from litestar.contrib.pydantic import PydanticInitPlugin, PydanticPlugin, PydanticSchemaPlugin +from litestar.contrib.pydantic import PydanticDIPlugin, PydanticInitPlugin, PydanticPlugin, PydanticSchemaPlugin from litestar.contrib.sqlalchemy.plugins import SQLAlchemySerializationPlugin from litestar.plugins import CLIPluginProtocol, InitPluginProtocol, OpenAPISchemaPlugin, PluginRegistry +from litestar.plugins.core import MsgspecDIPlugin from litestar.testing import create_test_client from litestar.typing import FieldDefinition @@ -121,6 +122,17 @@ def test_app_get_default_plugins( any_pydantic = bool(init_plugin) or bool(schema_plugin) default_plugins = Litestar._get_default_plugins(plugins) # type: ignore[arg-type] if not any_pydantic: - assert {type(p) for p in default_plugins} == {PydanticPlugin, AttrsSchemaPlugin} + assert {type(p) for p in default_plugins} == { + PydanticPlugin, + AttrsSchemaPlugin, + PydanticDIPlugin, + MsgspecDIPlugin, + } else: - assert {type(p) for p in default_plugins} == {PydanticInitPlugin, PydanticSchemaPlugin, AttrsSchemaPlugin} + assert {type(p) for p in default_plugins} == { + PydanticInitPlugin, + PydanticSchemaPlugin, + AttrsSchemaPlugin, + PydanticDIPlugin, + MsgspecDIPlugin, + } From 1e037d3bbf2c60cba05117f670d55a4951c6463e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janek=20Nouvertn=C3=A9?= Date: Mon, 5 Feb 2024 19:47:22 +0100 Subject: [PATCH 14/14] test: Add missing `module_loader` tests (#3073) Fix module loader coverage --- litestar/utils/module_loader.py | 25 +++++++++------------ tests/unit/test_utils/test_module_loader.py | 16 +++++++++++++ 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/litestar/utils/module_loader.py b/litestar/utils/module_loader.py index d466221a0c..09dbf9f8d6 100644 --- a/litestar/utils/module_loader.py +++ b/litestar/utils/module_loader.py @@ -2,7 +2,7 @@ from __future__ import annotations -import platform +import os.path import sys from importlib import import_module from importlib.util import find_spec @@ -17,8 +17,6 @@ "module_to_os_path", ) -PYTHON_38 = sys.version_info < (3, 9, 0) - def module_to_os_path(dotted_path: str = "app") -> Path: """Find Module to OS Path. @@ -27,7 +25,7 @@ def module_to_os_path(dotted_path: str = "app") -> Path: specified by `dotted_path`. Args: - dotted_path (str, optional): The path to the module. Defaults to "app". + dotted_path: The path to the module. Defaults to "app". Raises: TypeError: The module could not be found. @@ -36,15 +34,12 @@ def module_to_os_path(dotted_path: str = "app") -> Path: Path: The path to the module. """ try: - src = find_spec(dotted_path) + if (src := find_spec(dotted_path)) is None: # pragma: no cover + raise TypeError(f"Couldn't find the path for {dotted_path}") except ModuleNotFoundError as e: - msg = "Couldn't find the path for %s" - raise TypeError(msg, dotted_path) from e - path_separator = "\\" if platform.system() == "Windows" else "/" - if PYTHON_38: - suffix = f"{path_separator}__init__.py" - return Path(str(src.origin)[: (-1 * len(suffix))] if src.origin.endswith(suffix) else src.origin) # type: ignore[arg-type, union-attr] - return Path(str(src.origin).removesuffix(f"{path_separator}__init__.py")) # type: ignore + raise TypeError(f"Couldn't find the path for {dotted_path}") from e + + return Path(str(src.origin).rsplit(os.path.sep + "__init__.py", maxsplit=1)[0]) def import_string(dotted_path: str) -> Any: @@ -54,7 +49,7 @@ def import_string(dotted_path: str) -> Any: last name in the path. Raise ImportError if the import failed. Args: - dotted_path (str): The path of the module to import. + dotted_path: The path of the module to import. Raises: ImportError: Could not import the module. @@ -72,8 +67,8 @@ def _cached_import(module_path: str, class_name: str) -> Any: """Import and cache a class from a module. Args: - module_path (str): dotted path to module. - class_name (str): Class or function name. + module_path: dotted path to module. + class_name: Class or function name. Returns: object: The imported class or function diff --git a/tests/unit/test_utils/test_module_loader.py b/tests/unit/test_utils/test_module_loader.py index eee27bfe25..b76315980f 100644 --- a/tests/unit/test_utils/test_module_loader.py +++ b/tests/unit/test_utils/test_module_loader.py @@ -1,4 +1,7 @@ +from pathlib import Path + import pytest +from _pytest.monkeypatch import MonkeyPatch from litestar.config.compression import CompressionConfig from litestar.utils.module_loader import import_string, module_to_os_path @@ -21,3 +24,16 @@ def test_module_path() -> None: with pytest.raises(TypeError): _ = module_to_os_path("litestar.config.compression.Config") _ = module_to_os_path("litestar.config.compression.extra.module") + + +def test_import_non_existing_attribute_raises() -> None: + with pytest.raises(ImportError): + import_string("litestar.app.some_random_string") + + +def test_import_string_cached(tmp_path: Path, monkeypatch: MonkeyPatch) -> None: + tmp_path.joinpath("testmodule.py").write_text("x = 'foo'") + monkeypatch.chdir(tmp_path) + monkeypatch.syspath_prepend(tmp_path) + + assert import_string("testmodule.x") == "foo"