diff --git a/docs/examples/plugins/di_plugin.py b/docs/examples/plugins/di_plugin.py new file mode 100644 index 0000000000..35625d4163 --- /dev/null +++ b/docs/examples/plugins/di_plugin.py @@ -0,0 +1,31 @@ +from inspect import Parameter, Signature +from typing import Any, Dict, Tuple + +from litestar import Litestar, get +from litestar.di import Provide +from litestar.plugins import DIPlugin + + +class MyBaseType: + def __init__(self, param): + self.param = param + + +class MyDIPlugin(DIPlugin): + def has_typed_init(self, type_: Any) -> bool: + return issubclass(type_, MyBaseType) + + def get_typed_init(self, type_: Any) -> Tuple[Signature, Dict[str, Any]]: + signature = Signature([Parameter(name="param", kind=Parameter.POSITIONAL_OR_KEYWORD)]) + annotations = {"param": str} + return signature, annotations + + +@get("/", dependencies={"injected": Provide(MyBaseType, sync_to_thread=False)}) +async def handler(injected: MyBaseType) -> str: + return injected.param + + +app = Litestar(route_handlers=[handler], plugins=[MyDIPlugin()]) + +# run: /?param=hello diff --git a/docs/usage/plugins.rst b/docs/usage/plugins.rst index 8c4b64ce0a..4911b8da23 100644 --- a/docs/usage/plugins.rst +++ b/docs/usage/plugins.rst @@ -19,7 +19,7 @@ that can interact with the data that is used to instantiate the application inst the contract for plugins that extend serialization functionality of the application. InitPluginProtocol -~~~~~~~~~~~~~~~~~~ +------------------ ``InitPluginProtocol`` defines an interface that allows for customization of the application's initialization process. Init plugins can define dependencies, add route handlers, configure middleware, and much more! @@ -37,7 +37,7 @@ they are provided in the ``plugins`` argument of the :class:`app ` instance is then returned. SerializationPluginProtocol -~~~~~~~~~~~~~~~~~~~~~~~~~~~ +--------------------------- The SerializationPluginProtocol defines a contract for plugins that provide serialization functionality for data types that are otherwise unsupported by the framework. @@ -79,7 +79,7 @@ the plugin, and doesn't otherwise have a ``dto`` or ``return_dto`` defined, the that annotation. Example -------- ++++++++ The following example shows the actual implementation of the ``SerializationPluginProtocol`` for `SQLAlchemy `_ models that is is provided in ``advanced_alchemy``. @@ -106,3 +106,20 @@ subtypes are not created for the same model. If the annotation is not in the ``_type_dto_map`` dictionary, the method creates a new DTO type for the annotation, adds it to the ``_type_dto_map`` dictionary, and returns it. + + +DIPlugin +-------- + +:class:`~litestar.plugins.DIPlugin` can be used to extend Litestar's dependency +injection by providing information about injectable types. + +Its main purpose it to facilitate the injection of callables with unknown signatures, +for example Pydantic's ``BaseModel`` classes; These are not supported natively since, +while they are callables, their type information is not contained within their callable +signature (their :func:`__init__` method). + + +.. literalinclude:: /examples/plugins/di_plugin.py + :language: python + :caption: Dynamically generating signature information for a custom type diff --git a/litestar/app.py b/litestar/app.py index 4f9f002fd1..745514a164 100644 --- a/litestar/app.py +++ b/litestar/app.py @@ -488,18 +488,30 @@ def serialization_plugins(self) -> list[SerializationPluginProtocol]: @staticmethod def _get_default_plugins(plugins: list[PluginProtocol]) -> list[PluginProtocol]: + from litestar.plugins.core import MsgspecDIPlugin + + plugins.append(MsgspecDIPlugin()) + with suppress(MissingDependencyException): - from litestar.contrib.pydantic import PydanticInitPlugin, PydanticPlugin, PydanticSchemaPlugin + from litestar.contrib.pydantic import ( + PydanticDIPlugin, + PydanticInitPlugin, + PydanticPlugin, + PydanticSchemaPlugin, + ) pydantic_plugin_found = any(isinstance(plugin, PydanticPlugin) for plugin in plugins) pydantic_init_plugin_found = any(isinstance(plugin, PydanticInitPlugin) for plugin in plugins) pydantic_schema_plugin_found = any(isinstance(plugin, PydanticSchemaPlugin) for plugin in plugins) + pydantic_serialization_plugin_found = any(isinstance(plugin, PydanticDIPlugin) for plugin in plugins) if not pydantic_plugin_found and not pydantic_init_plugin_found and not pydantic_schema_plugin_found: plugins.append(PydanticPlugin()) elif not pydantic_plugin_found and pydantic_init_plugin_found and not pydantic_schema_plugin_found: plugins.append(PydanticSchemaPlugin()) elif not pydantic_plugin_found and not pydantic_init_plugin_found: plugins.append(PydanticInitPlugin()) + if not pydantic_plugin_found and not pydantic_serialization_plugin_found: + plugins.append(PydanticDIPlugin()) with suppress(MissingDependencyException): from litestar.contrib.attrs import AttrsSchemaPlugin diff --git a/litestar/contrib/pydantic/__init__.py b/litestar/contrib/pydantic/__init__.py index 122a710539..9bab707c31 100644 --- a/litestar/contrib/pydantic/__init__.py +++ b/litestar/contrib/pydantic/__init__.py @@ -4,6 +4,7 @@ from litestar.plugins import InitPluginProtocol +from .pydantic_di_plugin import PydanticDIPlugin from .pydantic_dto_factory import PydanticDTO from .pydantic_init_plugin import PydanticInitPlugin from .pydantic_schema_plugin import PydanticSchemaPlugin @@ -14,7 +15,13 @@ from litestar.config.app import AppConfig -__all__ = ("PydanticDTO", "PydanticInitPlugin", "PydanticSchemaPlugin", "PydanticPlugin") +__all__ = ( + "PydanticDTO", + "PydanticInitPlugin", + "PydanticSchemaPlugin", + "PydanticPlugin", + "PydanticDIPlugin", +) def _model_dump(model: BaseModel | BaseModelV1, *, by_alias: bool = False) -> dict[str, Any]: @@ -53,6 +60,10 @@ def on_app_init(self, app_config: AppConfig) -> AppConfig: app_config: The :class:`AppConfig <.config.app.AppConfig>` instance. """ app_config.plugins.extend( - [PydanticInitPlugin(prefer_alias=self.prefer_alias), PydanticSchemaPlugin(prefer_alias=self.prefer_alias)] + [ + PydanticInitPlugin(prefer_alias=self.prefer_alias), + PydanticSchemaPlugin(prefer_alias=self.prefer_alias), + PydanticDIPlugin(), + ] ) return app_config diff --git a/litestar/contrib/pydantic/pydantic_di_plugin.py b/litestar/contrib/pydantic/pydantic_di_plugin.py new file mode 100644 index 0000000000..2096fd4ab6 --- /dev/null +++ b/litestar/contrib/pydantic/pydantic_di_plugin.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import inspect +from inspect import Signature +from typing import Any + +from litestar.contrib.pydantic.utils import is_pydantic_model_class +from litestar.plugins import DIPlugin + + +class PydanticDIPlugin(DIPlugin): + def has_typed_init(self, type_: Any) -> bool: + return is_pydantic_model_class(type_) + + def get_typed_init(self, type_: Any) -> tuple[Signature, dict[str, Any]]: + try: + model_fields = dict(type_.model_fields) + except AttributeError: + model_fields = {k: model_field.field_info for k, model_field in type_.__fields__.items()} + + parameters = [ + inspect.Parameter(name=field_name, kind=inspect.Parameter.KEYWORD_ONLY, annotation=Any) + for field_name in model_fields + ] + type_hints = {field_name: Any for field_name in model_fields} + return Signature(parameters), type_hints diff --git a/litestar/handlers/base.py b/litestar/handlers/base.py index da2c53f3b9..aa02b56a4f 100644 --- a/litestar/handlers/base.py +++ b/litestar/handlers/base.py @@ -9,6 +9,7 @@ from litestar.di import Provide from litestar.dto import DTOData from litestar.exceptions import ImproperlyConfiguredException +from litestar.plugins import DIPlugin, PluginRegistry from litestar.serialization import default_deserializer, default_serializer from litestar.types import ( Dependencies, @@ -339,37 +340,60 @@ def resolve_guards(self) -> list[Guard]: return self._resolved_guards + def _get_plugin_registry(self) -> PluginRegistry | None: + from litestar.app import Litestar + + root_owner = self.ownership_layers[0] + if isinstance(root_owner, Litestar): + return root_owner.plugins + return None + def resolve_dependencies(self) -> dict[str, Provide]: """Return all dependencies correlating to handler function's kwargs that exist in the handler's scope.""" + plugin_registry = self._get_plugin_registry() if self._resolved_dependencies is Empty: self._resolved_dependencies = {} - for layer in self.ownership_layers: for key, provider in (layer.dependencies or {}).items(): - if not isinstance(provider, Provide): - provider = Provide(provider) - - self._validate_dependency_is_unique( - dependencies=self._resolved_dependencies, key=key, provider=provider + self._resolved_dependencies[key] = self._resolve_dependency( + key=key, provider=provider, plugin_registry=plugin_registry ) - if not getattr(provider, "parsed_signature", None): - provider.parsed_fn_signature = ParsedSignature.from_fn( - unwrap_partial(provider.dependency), self.resolve_signature_namespace() - ) - - if not getattr(provider, "signature_model", None): - provider.signature_model = SignatureModel.create( - dependency_name_set=self.dependency_name_set, - fn=provider.dependency, - parsed_signature=provider.parsed_fn_signature, - data_dto=self.resolve_data_dto(), - type_decoders=self.resolve_type_decoders(), - ) - - self._resolved_dependencies[key] = provider return self._resolved_dependencies + def _resolve_dependency( + self, key: str, provider: Provide | AnyCallable, plugin_registry: PluginRegistry | None + ) -> Provide: + if not isinstance(provider, Provide): + provider = Provide(provider) + + if self._resolved_dependencies is not Empty: # pragma: no cover + self._validate_dependency_is_unique(dependencies=self._resolved_dependencies, key=key, provider=provider) + + if not getattr(provider, "parsed_fn_signature", None): + dependency = unwrap_partial(provider.dependency) + plugin: DIPlugin | None = None + if plugin_registry: + plugin = next( + (p for p in plugin_registry.di if isinstance(p, DIPlugin) and p.has_typed_init(dependency)), + None, + ) + if plugin: + signature, init_type_hints = plugin.get_typed_init(dependency) + provider.parsed_fn_signature = ParsedSignature.from_signature(signature, init_type_hints) + else: + provider.parsed_fn_signature = ParsedSignature.from_fn(dependency, self.resolve_signature_namespace()) + + if not getattr(provider, "signature_model", None): + provider.signature_model = SignatureModel.create( + dependency_name_set=self.dependency_name_set, + fn=provider.dependency, + parsed_signature=provider.parsed_fn_signature, + data_dto=self.resolve_data_dto(), + type_decoders=self.resolve_type_decoders(), + ) + return provider + def resolve_middleware(self) -> list[Middleware]: """Build the middleware stack for the RouteHandler and return it. diff --git a/litestar/plugins/__init__.py b/litestar/plugins/__init__.py index 6f71b78b4a..f09310436d 100644 --- a/litestar/plugins/__init__.py +++ b/litestar/plugins/__init__.py @@ -1,6 +1,7 @@ from litestar.plugins.base import ( CLIPlugin, CLIPluginProtocol, + DIPlugin, InitPluginProtocol, OpenAPISchemaPlugin, OpenAPISchemaPluginProtocol, @@ -11,6 +12,7 @@ __all__ = ( "SerializationPluginProtocol", + "DIPlugin", "CLIPlugin", "InitPluginProtocol", "OpenAPISchemaPluginProtocol", diff --git a/litestar/plugins/base.py b/litestar/plugins/base.py index a6b83537b8..afc571efe7 100644 --- a/litestar/plugins/base.py +++ b/litestar/plugins/base.py @@ -1,9 +1,12 @@ from __future__ import annotations +import abc from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Iterator, Protocol, TypeVar, Union, cast, runtime_checkable if TYPE_CHECKING: + from inspect import Signature + from click import Group from litestar._openapi.schema_generation import SchemaCreator @@ -23,6 +26,7 @@ "CLIPlugin", "CLIPluginProtocol", "PluginRegistry", + "DIPlugin", ) @@ -154,6 +158,26 @@ def create_dto_for_type(self, field_definition: FieldDefinition) -> type[Abstrac raise NotImplementedError() +class DIPlugin(abc.ABC): + """Extend dependency injection""" + + @abc.abstractmethod + def has_typed_init(self, type_: Any) -> bool: + """Return ``True`` if ``type_`` has type information available for its + :func:`__init__` method that cannot be extracted from this method's type + annotations (e.g. a Pydantic BaseModel subclass), and + :meth:`DIPlugin.get_typed_init` supports extraction of these annotations. + """ + ... + + @abc.abstractmethod + def get_typed_init(self, type_: Any) -> tuple[Signature, dict[str, Any]]: + r"""Return signature and type information about the ``type_``\ s :func:`__init__` + method. + """ + ... + + @runtime_checkable class OpenAPISchemaPluginProtocol(Protocol): """Plugin protocol to extend the support of OpenAPI schema generation for non-library types.""" @@ -241,6 +265,7 @@ def is_constrained_field(field_definition: FieldDefinition) -> bool: OpenAPISchemaPluginProtocol, ReceiveRoutePlugin, SerializationPluginProtocol, + DIPlugin, ] PluginT = TypeVar("PluginT", bound=PluginProtocol) @@ -250,9 +275,10 @@ class PluginRegistry: __slots__ = { "init": "Plugins that implement the InitPluginProtocol", "openapi": "Plugins that implement the OpenAPISchemaPluginProtocol", - "receive_route": "ReceiveRoutePlugin types", + "receive_route": "ReceiveRoutePlugin instances", "serialization": "Plugins that implement the SerializationPluginProtocol", "cli": "Plugins that implement the CLIPluginProtocol", + "di": "DIPlugin instances", "_plugins_by_type": None, "_plugins": None, "_get_plugins_of_type": None, @@ -266,6 +292,7 @@ def __init__(self, plugins: list[PluginProtocol]) -> None: self.receive_route = tuple(p for p in plugins if isinstance(p, ReceiveRoutePlugin)) self.serialization = tuple(p for p in plugins if isinstance(p, SerializationPluginProtocol)) self.cli = tuple(p for p in plugins if isinstance(p, CLIPluginProtocol)) + self.di = tuple(p for p in plugins if isinstance(p, DIPlugin)) def get(self, type_: type[PluginT] | str) -> PluginT: """Return the registered plugin of ``type_``. diff --git a/litestar/plugins/core.py b/litestar/plugins/core.py new file mode 100644 index 0000000000..d25d6d661b --- /dev/null +++ b/litestar/plugins/core.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import inspect +from inspect import Signature +from typing import Any + +import msgspec + +from litestar.plugins import DIPlugin + +__all__ = ("MsgspecDIPlugin",) + + +class MsgspecDIPlugin(DIPlugin): + def has_typed_init(self, type_: Any) -> bool: + return type(type_) is type(msgspec.Struct) # noqa: E721 + + def get_typed_init(self, type_: Any) -> tuple[Signature, dict[str, Any]]: + parameters = [] + type_hints = {} + for field_info in msgspec.structs.fields(type_): + type_hints[field_info.name] = field_info.type + parameters.append( + inspect.Parameter( + name=field_info.name, + kind=inspect.Parameter.KEYWORD_ONLY, + annotation=field_info.type, + default=field_info.default, + ) + ) + return inspect.Signature(parameters), type_hints diff --git a/pyproject.toml b/pyproject.toml index 655feceff1..736a90e55d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -224,6 +224,7 @@ markers = [ "server_integration: Test integration with ASGI server" ] xfail_strict = true +testpaths = ["tests", "docs/examples/testing"] [tool.mypy] packages = ["litestar", "tests"] diff --git a/tests/e2e/test_dependency_injection/test_injection_of_classes.py b/tests/e2e/test_dependency_injection/test_injection_of_classes.py index 795c4f657e..6091015b0e 100644 --- a/tests/e2e/test_dependency_injection/test_injection_of_classes.py +++ b/tests/e2e/test_dependency_injection/test_injection_of_classes.py @@ -1,3 +1,7 @@ +from dataclasses import dataclass + +import msgspec + from litestar import Controller, get from litestar.di import Provide from litestar.testing import create_test_client @@ -37,3 +41,32 @@ def test_function(self, container: HandlerDependency) -> str: with create_test_client(MyController) as client: response = client.get(f"/test/{path_param_value}?query_param={query_param_value}") assert response.text == "15" + + +def test_inject_dataclass() -> None: + @dataclass + class Foo: + bar: str + + @get("/", dependencies={"foo": Provide(Foo, sync_to_thread=False)}) + async def handler(foo: Foo) -> Foo: + return foo + + with create_test_client([handler]) as client: + res = client.get("/?bar=baz") + assert res.status_code == 200 + assert res.json() == {"bar": "baz"} + + +def test_inject_msgspec_struct() -> None: + class Foo(msgspec.Struct): + bar: str + + @get("/", dependencies={"foo": Provide(Foo, sync_to_thread=False)}) + async def handler(foo: Foo) -> Foo: + return foo + + with create_test_client([handler]) as client: + res = client.get("/?bar=baz") + assert res.status_code == 200 + assert res.json() == {"bar": "baz"} diff --git a/tests/examples/test_plugins/test_di_plugin.py b/tests/examples/test_plugins/test_di_plugin.py new file mode 100644 index 0000000000..c42786b068 --- /dev/null +++ b/tests/examples/test_plugins/test_di_plugin.py @@ -0,0 +1,10 @@ +from docs.examples.plugins.di_plugin import app + +from litestar.testing import TestClient + + +def test_di_plugin_example() -> None: + with TestClient(app) as client: + res = client.get("/?param=hello") + assert res.status_code == 200 + assert res.text == "hello" diff --git a/tests/unit/test_contrib/test_attrs/test_inject_attrs_class.py b/tests/unit/test_contrib/test_attrs/test_inject_attrs_class.py new file mode 100644 index 0000000000..bc6f526b68 --- /dev/null +++ b/tests/unit/test_contrib/test_attrs/test_inject_attrs_class.py @@ -0,0 +1,20 @@ +from attrs import define + +from litestar import get +from litestar.di import Provide +from litestar.testing import create_test_client + + +def test_inject_attrs_class() -> None: + @define + class Foo: + bar: str + + @get("/", dependencies={"foo": Provide(Foo, sync_to_thread=False)}) + async def handler(foo: Foo) -> Foo: + return foo + + with create_test_client([handler]) as client: + res = client.get("/?bar=baz") + assert res.status_code == 200 + assert res.json() == {"bar": "baz"} diff --git a/tests/unit/test_contrib/test_pydantic/test_inject_pydantic.py b/tests/unit/test_contrib/test_pydantic/test_inject_pydantic.py new file mode 100644 index 0000000000..ed7a53267c --- /dev/null +++ b/tests/unit/test_contrib/test_pydantic/test_inject_pydantic.py @@ -0,0 +1,22 @@ +import pydantic as pydantic_v2 +import pytest +from pydantic import v1 as pydantic_v1 + +from litestar import get +from litestar.di import Provide +from litestar.testing import create_test_client + + +@pytest.mark.parametrize("base_model", [pydantic_v1.BaseModel, pydantic_v2.BaseModel]) +def test_inject_pydantic_model(base_model: type) -> None: + class Foo(base_model): # type: ignore[misc] + bar: str + + @get("/", dependencies={"foo": Provide(Foo, sync_to_thread=False)}) + async def handler(foo: Foo) -> Foo: + return foo + + with create_test_client([handler]) as client: + res = client.get("/?bar=baz") + assert res.status_code == 200 + assert res.json() == {"bar": "baz"} diff --git a/tests/unit/test_handlers/test_base_handlers/test_resolution.py b/tests/unit/test_handlers/test_base_handlers/test_resolution.py index 079c4a2cdc..12809610a4 100644 --- a/tests/unit/test_handlers/test_base_handlers/test_resolution.py +++ b/tests/unit/test_handlers/test_base_handlers/test_resolution.py @@ -52,3 +52,18 @@ async def handler(self) -> None: "controller": Provide(controller_dependency), "handler": Provide(handler_dependency), } + + +def test_resolve_dependencies_cached() -> None: + dependency = Provide(function_factory()) + + @get(dependencies={"foo": dependency}) + async def handler() -> None: + pass + + @get(dependencies={"foo": dependency}) + async def handler_2() -> None: + pass + + assert handler.resolve_dependencies() is handler.resolve_dependencies() + assert handler_2.resolve_dependencies() is handler_2.resolve_dependencies() diff --git a/tests/unit/test_plugins/test_base.py b/tests/unit/test_plugins/test_base.py index f1d146f1d1..8598eb5c64 100644 --- a/tests/unit/test_plugins/test_base.py +++ b/tests/unit/test_plugins/test_base.py @@ -8,9 +8,10 @@ from litestar import Litestar, MediaType, get from litestar.constants import UNDEFINED_SENTINELS from litestar.contrib.attrs import AttrsSchemaPlugin -from litestar.contrib.pydantic import PydanticInitPlugin, PydanticPlugin, PydanticSchemaPlugin +from litestar.contrib.pydantic import PydanticDIPlugin, PydanticInitPlugin, PydanticPlugin, PydanticSchemaPlugin from litestar.contrib.sqlalchemy.plugins import SQLAlchemySerializationPlugin from litestar.plugins import CLIPluginProtocol, InitPluginProtocol, OpenAPISchemaPlugin, PluginRegistry +from litestar.plugins.core import MsgspecDIPlugin from litestar.testing import create_test_client from litestar.typing import FieldDefinition @@ -121,6 +122,17 @@ def test_app_get_default_plugins( any_pydantic = bool(init_plugin) or bool(schema_plugin) default_plugins = Litestar._get_default_plugins(plugins) # type: ignore[arg-type] if not any_pydantic: - assert {type(p) for p in default_plugins} == {PydanticPlugin, AttrsSchemaPlugin} + assert {type(p) for p in default_plugins} == { + PydanticPlugin, + AttrsSchemaPlugin, + PydanticDIPlugin, + MsgspecDIPlugin, + } else: - assert {type(p) for p in default_plugins} == {PydanticInitPlugin, PydanticSchemaPlugin, AttrsSchemaPlugin} + assert {type(p) for p in default_plugins} == { + PydanticInitPlugin, + PydanticSchemaPlugin, + AttrsSchemaPlugin, + PydanticDIPlugin, + MsgspecDIPlugin, + }