Skip to content

Commit

Permalink
feat(DI): Support externally typed classes as dependency providers (#…
Browse files Browse the repository at this point in the history
…3066)

* Support injecting externally typed classes
  • Loading branch information
provinzkraut committed Feb 6, 2024
1 parent 1966c4d commit e6eb9f2
Show file tree
Hide file tree
Showing 16 changed files with 326 additions and 32 deletions.
31 changes: 31 additions & 0 deletions docs/examples/plugins/di_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from inspect import Parameter, Signature
from typing import Any, Dict, Tuple

from litestar import Litestar, get
from litestar.di import Provide
from litestar.plugins import DIPlugin


class MyBaseType:
def __init__(self, param):
self.param = param


class MyDIPlugin(DIPlugin):
def has_typed_init(self, type_: Any) -> bool:
return issubclass(type_, MyBaseType)

def get_typed_init(self, type_: Any) -> Tuple[Signature, Dict[str, Any]]:
signature = Signature([Parameter(name="param", kind=Parameter.POSITIONAL_OR_KEYWORD)])
annotations = {"param": str}
return signature, annotations


@get("/", dependencies={"injected": Provide(MyBaseType, sync_to_thread=False)})
async def handler(injected: MyBaseType) -> str:
return injected.param


app = Litestar(route_handlers=[handler], plugins=[MyDIPlugin()])

# run: /?param=hello
25 changes: 21 additions & 4 deletions docs/usage/plugins.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ that can interact with the data that is used to instantiate the application inst
the contract for plugins that extend serialization functionality of the application.

InitPluginProtocol
~~~~~~~~~~~~~~~~~~
------------------

``InitPluginProtocol`` defines an interface that allows for customization of the application's initialization process.
Init plugins can define dependencies, add route handlers, configure middleware, and much more!
Expand All @@ -37,7 +37,7 @@ they are provided in the ``plugins`` argument of the :class:`app <litestar.app.L
authors should make it clear in their documentation if their plugin should be invoked before or after other plugins.

Example
-------
+++++++

The following example shows a simple plugin that adds a route handler, and a dependency to the application.

Expand All @@ -54,7 +54,7 @@ is provided by the ``get_name()`` function, and ``route_handlers`` is updated to
function. The modified :class:`AppConfig <litestar.config.app.AppConfig>` instance is then returned.

SerializationPluginProtocol
~~~~~~~~~~~~~~~~~~~~~~~~~~~
---------------------------

The SerializationPluginProtocol defines a contract for plugins that provide serialization functionality for data types
that are otherwise unsupported by the framework.
Expand All @@ -79,7 +79,7 @@ the plugin, and doesn't otherwise have a ``dto`` or ``return_dto`` defined, the
that annotation.

Example
-------
+++++++

The following example shows the actual implementation of the ``SerializationPluginProtocol`` for
`SQLAlchemy <https://www.sqlalchemy.org/>`_ models that is is provided in ``advanced_alchemy``.
Expand All @@ -106,3 +106,20 @@ subtypes are not created for the same model.

If the annotation is not in the ``_type_dto_map`` dictionary, the method creates a new DTO type for the annotation,
adds it to the ``_type_dto_map`` dictionary, and returns it.


DIPlugin
--------

:class:`~litestar.plugins.DIPlugin` can be used to extend Litestar's dependency
injection by providing information about injectable types.

Its main purpose it to facilitate the injection of callables with unknown signatures,
for example Pydantic's ``BaseModel`` classes; These are not supported natively since,
while they are callables, their type information is not contained within their callable
signature (their :func:`__init__` method).


.. literalinclude:: /examples/plugins/di_plugin.py
:language: python
:caption: Dynamically generating signature information for a custom type
14 changes: 13 additions & 1 deletion litestar/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,18 +488,30 @@ def serialization_plugins(self) -> list[SerializationPluginProtocol]:

@staticmethod
def _get_default_plugins(plugins: list[PluginProtocol]) -> list[PluginProtocol]:
from litestar.plugins.core import MsgspecDIPlugin

plugins.append(MsgspecDIPlugin())

with suppress(MissingDependencyException):
from litestar.contrib.pydantic import PydanticInitPlugin, PydanticPlugin, PydanticSchemaPlugin
from litestar.contrib.pydantic import (
PydanticDIPlugin,
PydanticInitPlugin,
PydanticPlugin,
PydanticSchemaPlugin,
)

pydantic_plugin_found = any(isinstance(plugin, PydanticPlugin) for plugin in plugins)
pydantic_init_plugin_found = any(isinstance(plugin, PydanticInitPlugin) for plugin in plugins)
pydantic_schema_plugin_found = any(isinstance(plugin, PydanticSchemaPlugin) for plugin in plugins)
pydantic_serialization_plugin_found = any(isinstance(plugin, PydanticDIPlugin) for plugin in plugins)
if not pydantic_plugin_found and not pydantic_init_plugin_found and not pydantic_schema_plugin_found:
plugins.append(PydanticPlugin())
elif not pydantic_plugin_found and pydantic_init_plugin_found and not pydantic_schema_plugin_found:
plugins.append(PydanticSchemaPlugin())
elif not pydantic_plugin_found and not pydantic_init_plugin_found:
plugins.append(PydanticInitPlugin())
if not pydantic_plugin_found and not pydantic_serialization_plugin_found:
plugins.append(PydanticDIPlugin())
with suppress(MissingDependencyException):
from litestar.contrib.attrs import AttrsSchemaPlugin

Expand Down
15 changes: 13 additions & 2 deletions litestar/contrib/pydantic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from litestar.plugins import InitPluginProtocol

from .pydantic_di_plugin import PydanticDIPlugin
from .pydantic_dto_factory import PydanticDTO
from .pydantic_init_plugin import PydanticInitPlugin
from .pydantic_schema_plugin import PydanticSchemaPlugin
Expand All @@ -14,7 +15,13 @@

from litestar.config.app import AppConfig

__all__ = ("PydanticDTO", "PydanticInitPlugin", "PydanticSchemaPlugin", "PydanticPlugin")
__all__ = (
"PydanticDTO",
"PydanticInitPlugin",
"PydanticSchemaPlugin",
"PydanticPlugin",
"PydanticDIPlugin",
)


def _model_dump(model: BaseModel | BaseModelV1, *, by_alias: bool = False) -> dict[str, Any]:
Expand Down Expand Up @@ -53,6 +60,10 @@ def on_app_init(self, app_config: AppConfig) -> AppConfig:
app_config: The :class:`AppConfig <.config.app.AppConfig>` instance.
"""
app_config.plugins.extend(
[PydanticInitPlugin(prefer_alias=self.prefer_alias), PydanticSchemaPlugin(prefer_alias=self.prefer_alias)]
[
PydanticInitPlugin(prefer_alias=self.prefer_alias),
PydanticSchemaPlugin(prefer_alias=self.prefer_alias),
PydanticDIPlugin(),
]
)
return app_config
26 changes: 26 additions & 0 deletions litestar/contrib/pydantic/pydantic_di_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from __future__ import annotations

import inspect
from inspect import Signature
from typing import Any

from litestar.contrib.pydantic.utils import is_pydantic_model_class
from litestar.plugins import DIPlugin


class PydanticDIPlugin(DIPlugin):
def has_typed_init(self, type_: Any) -> bool:
return is_pydantic_model_class(type_)

def get_typed_init(self, type_: Any) -> tuple[Signature, dict[str, Any]]:
try:
model_fields = dict(type_.model_fields)
except AttributeError:
model_fields = {k: model_field.field_info for k, model_field in type_.__fields__.items()}

parameters = [
inspect.Parameter(name=field_name, kind=inspect.Parameter.KEYWORD_ONLY, annotation=Any)
for field_name in model_fields
]
type_hints = {field_name: Any for field_name in model_fields}
return Signature(parameters), type_hints
66 changes: 45 additions & 21 deletions litestar/handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from litestar.di import Provide
from litestar.dto import DTOData
from litestar.exceptions import ImproperlyConfiguredException
from litestar.plugins import DIPlugin, PluginRegistry
from litestar.serialization import default_deserializer, default_serializer
from litestar.types import (
Dependencies,
Expand Down Expand Up @@ -339,37 +340,60 @@ def resolve_guards(self) -> list[Guard]:

return self._resolved_guards

def _get_plugin_registry(self) -> PluginRegistry | None:
from litestar.app import Litestar

root_owner = self.ownership_layers[0]
if isinstance(root_owner, Litestar):
return root_owner.plugins
return None

def resolve_dependencies(self) -> dict[str, Provide]:
"""Return all dependencies correlating to handler function's kwargs that exist in the handler's scope."""
plugin_registry = self._get_plugin_registry()
if self._resolved_dependencies is Empty:
self._resolved_dependencies = {}

for layer in self.ownership_layers:
for key, provider in (layer.dependencies or {}).items():
if not isinstance(provider, Provide):
provider = Provide(provider)

self._validate_dependency_is_unique(
dependencies=self._resolved_dependencies, key=key, provider=provider
self._resolved_dependencies[key] = self._resolve_dependency(
key=key, provider=provider, plugin_registry=plugin_registry
)

if not getattr(provider, "parsed_signature", None):
provider.parsed_fn_signature = ParsedSignature.from_fn(
unwrap_partial(provider.dependency), self.resolve_signature_namespace()
)

if not getattr(provider, "signature_model", None):
provider.signature_model = SignatureModel.create(
dependency_name_set=self.dependency_name_set,
fn=provider.dependency,
parsed_signature=provider.parsed_fn_signature,
data_dto=self.resolve_data_dto(),
type_decoders=self.resolve_type_decoders(),
)

self._resolved_dependencies[key] = provider
return self._resolved_dependencies

def _resolve_dependency(
self, key: str, provider: Provide | AnyCallable, plugin_registry: PluginRegistry | None
) -> Provide:
if not isinstance(provider, Provide):
provider = Provide(provider)

if self._resolved_dependencies is not Empty: # pragma: no cover
self._validate_dependency_is_unique(dependencies=self._resolved_dependencies, key=key, provider=provider)

if not getattr(provider, "parsed_fn_signature", None):
dependency = unwrap_partial(provider.dependency)
plugin: DIPlugin | None = None
if plugin_registry:
plugin = next(
(p for p in plugin_registry.di if isinstance(p, DIPlugin) and p.has_typed_init(dependency)),
None,
)
if plugin:
signature, init_type_hints = plugin.get_typed_init(dependency)
provider.parsed_fn_signature = ParsedSignature.from_signature(signature, init_type_hints)
else:
provider.parsed_fn_signature = ParsedSignature.from_fn(dependency, self.resolve_signature_namespace())

if not getattr(provider, "signature_model", None):
provider.signature_model = SignatureModel.create(
dependency_name_set=self.dependency_name_set,
fn=provider.dependency,
parsed_signature=provider.parsed_fn_signature,
data_dto=self.resolve_data_dto(),
type_decoders=self.resolve_type_decoders(),
)
return provider

def resolve_middleware(self) -> list[Middleware]:
"""Build the middleware stack for the RouteHandler and return it.
Expand Down
2 changes: 2 additions & 0 deletions litestar/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from litestar.plugins.base import (
CLIPlugin,
CLIPluginProtocol,
DIPlugin,
InitPluginProtocol,
OpenAPISchemaPlugin,
OpenAPISchemaPluginProtocol,
Expand All @@ -11,6 +12,7 @@

__all__ = (
"SerializationPluginProtocol",
"DIPlugin",
"CLIPlugin",
"InitPluginProtocol",
"OpenAPISchemaPluginProtocol",
Expand Down
29 changes: 28 additions & 1 deletion litestar/plugins/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

import abc
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Iterator, Protocol, TypeVar, Union, cast, runtime_checkable

if TYPE_CHECKING:
from inspect import Signature

from click import Group

from litestar._openapi.schema_generation import SchemaCreator
Expand All @@ -23,6 +26,7 @@
"CLIPlugin",
"CLIPluginProtocol",
"PluginRegistry",
"DIPlugin",
)


Expand Down Expand Up @@ -154,6 +158,26 @@ def create_dto_for_type(self, field_definition: FieldDefinition) -> type[Abstrac
raise NotImplementedError()


class DIPlugin(abc.ABC):
"""Extend dependency injection"""

@abc.abstractmethod
def has_typed_init(self, type_: Any) -> bool:
"""Return ``True`` if ``type_`` has type information available for its
:func:`__init__` method that cannot be extracted from this method's type
annotations (e.g. a Pydantic BaseModel subclass), and
:meth:`DIPlugin.get_typed_init` supports extraction of these annotations.
"""
...

@abc.abstractmethod
def get_typed_init(self, type_: Any) -> tuple[Signature, dict[str, Any]]:
r"""Return signature and type information about the ``type_``\ s :func:`__init__`
method.
"""
...


@runtime_checkable
class OpenAPISchemaPluginProtocol(Protocol):
"""Plugin protocol to extend the support of OpenAPI schema generation for non-library types."""
Expand Down Expand Up @@ -241,6 +265,7 @@ def is_constrained_field(field_definition: FieldDefinition) -> bool:
OpenAPISchemaPluginProtocol,
ReceiveRoutePlugin,
SerializationPluginProtocol,
DIPlugin,
]

PluginT = TypeVar("PluginT", bound=PluginProtocol)
Expand All @@ -250,9 +275,10 @@ class PluginRegistry:
__slots__ = {
"init": "Plugins that implement the InitPluginProtocol",
"openapi": "Plugins that implement the OpenAPISchemaPluginProtocol",
"receive_route": "ReceiveRoutePlugin types",
"receive_route": "ReceiveRoutePlugin instances",
"serialization": "Plugins that implement the SerializationPluginProtocol",
"cli": "Plugins that implement the CLIPluginProtocol",
"di": "DIPlugin instances",
"_plugins_by_type": None,
"_plugins": None,
"_get_plugins_of_type": None,
Expand All @@ -266,6 +292,7 @@ def __init__(self, plugins: list[PluginProtocol]) -> None:
self.receive_route = tuple(p for p in plugins if isinstance(p, ReceiveRoutePlugin))
self.serialization = tuple(p for p in plugins if isinstance(p, SerializationPluginProtocol))
self.cli = tuple(p for p in plugins if isinstance(p, CLIPluginProtocol))
self.di = tuple(p for p in plugins if isinstance(p, DIPlugin))

def get(self, type_: type[PluginT] | str) -> PluginT:
"""Return the registered plugin of ``type_``.
Expand Down
31 changes: 31 additions & 0 deletions litestar/plugins/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

import inspect
from inspect import Signature
from typing import Any

import msgspec

from litestar.plugins import DIPlugin

__all__ = ("MsgspecDIPlugin",)


class MsgspecDIPlugin(DIPlugin):
def has_typed_init(self, type_: Any) -> bool:
return type(type_) is type(msgspec.Struct) # noqa: E721

def get_typed_init(self, type_: Any) -> tuple[Signature, dict[str, Any]]:
parameters = []
type_hints = {}
for field_info in msgspec.structs.fields(type_):
type_hints[field_info.name] = field_info.type
parameters.append(
inspect.Parameter(
name=field_info.name,
kind=inspect.Parameter.KEYWORD_ONLY,
annotation=field_info.type,
default=field_info.default,
)
)
return inspect.Signature(parameters), type_hints
Loading

0 comments on commit e6eb9f2

Please sign in to comment.