Skip to content

Commit

Permalink
refactor(utils): Remove Ref (#2624)
Browse files Browse the repository at this point in the history
* Remove Ref usage from route handlers
* Remove Ref usage from Provide
* Remove Ref usage from compression middleware
* Remove Ref
---------

Signed-off-by: Janek Nouvertné <[email protected]>
  • Loading branch information
provinzkraut authored Nov 5, 2023
1 parent 6d0c86f commit 1ad1997
Show file tree
Hide file tree
Showing 22 changed files with 76 additions and 88 deletions.
2 changes: 1 addition & 1 deletion litestar/_openapi/path_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_description_for_handler(route_handler: HTTPRouteHandler, use_handler_doc
"""
handler_description = route_handler.description
if handler_description is None and use_handler_docstrings:
fn = unwrap_partial(route_handler.fn.value)
fn = unwrap_partial(route_handler.fn)
return cleandoc(fn.__doc__) if fn.__doc__ else None
return handler_description

Expand Down
2 changes: 1 addition & 1 deletion litestar/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ def url_for_static_asset(self, name: str, file_path: str) -> str:
if handler_index is None:
raise NoRouteMatchFoundException(f"Static handler {name} can not be found")

handler_fn = cast("AnyCallable", handler_index["handler"].fn.value)
handler_fn = cast("AnyCallable", handler_index["handler"].fn)
if not isinstance(handler_fn, StaticFiles):
raise NoRouteMatchFoundException(f"Handler with name {name} is not a static files handler")

Expand Down
2 changes: 1 addition & 1 deletion litestar/cli/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def routes_command(app: Litestar) -> None: # pragma: no cover
f"[blue]{handler.name or handler.handler_name}[/blue]",
]

if inspect.iscoroutinefunction(unwrap_partial(handler.fn.value)):
if inspect.iscoroutinefunction(unwrap_partial(handler.fn)):
handler_info.append("[magenta]async[/magenta]")
else:
handler_info.append("[yellow]sync[/yellow]")
Expand Down
6 changes: 4 additions & 2 deletions litestar/controller.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

import types
from collections import defaultdict
from copy import deepcopy
from functools import partial
from operator import attrgetter
from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast

Expand Down Expand Up @@ -209,7 +209,9 @@ def get_route_handlers(self) -> list[BaseRouteHandler]:
self_handlers.sort(key=attrgetter("handler_id"))
for self_handler in self_handlers:
route_handler = deepcopy(self_handler)
route_handler.fn.value = partial(route_handler.fn.value, self)
# at the point we get a reference to the handler function, it's unbound, so
# we replace it with a regular bound method here
route_handler._fn = types.MethodType(route_handler._fn, self)
route_handler.owner = self
route_handlers.append(route_handler)

Expand Down
12 changes: 6 additions & 6 deletions litestar/di.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from litestar.exceptions import ImproperlyConfiguredException
from litestar.types import Empty
from litestar.utils import Ref, ensure_async_callable
from litestar.utils import ensure_async_callable
from litestar.utils.predicates import is_async_callable, is_sync_or_async_generator
from litestar.utils.warnings import (
warn_implicit_sync_to_thread,
Expand Down Expand Up @@ -35,7 +35,7 @@ class Provide:
)

signature_model: type[SignatureModel]
dependency: Ref[AnyCallable]
dependency: AnyCallable

def __init__(
self,
Expand Down Expand Up @@ -64,10 +64,10 @@ def __init__(
warn_implicit_sync_to_thread(dependency, stacklevel=3)

if sync_to_thread and has_sync_callable:
self.dependency = Ref["AnyCallable"](ensure_async_callable(dependency)) # pyright: ignore
self.dependency = ensure_async_callable(dependency) # pyright: ignore
self.has_sync_callable = False
else:
self.dependency = Ref["AnyCallable"](dependency) # pyright: ignore
self.dependency = dependency # pyright: ignore
self.has_sync_callable = has_sync_callable

self.sync_to_thread = bool(sync_to_thread)
Expand All @@ -81,9 +81,9 @@ async def __call__(self, **kwargs: Any) -> Any:
return self.value

if self.has_sync_callable:
value = self.dependency.value(**kwargs)
value = self.dependency(**kwargs)
else:
value = await self.dependency.value(**kwargs)
value = await self.dependency(**kwargs)

if self.use_cache:
self.value = value
Expand Down
2 changes: 1 addition & 1 deletion litestar/handlers/asgi_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _validate_handler_function(self) -> None:
raise ImproperlyConfiguredException(
"ASGI handler functions should define 'scope', 'send' and 'receive' arguments"
)
if not is_async_callable(self.fn.value):
if not is_async_callable(self.fn):
raise ImproperlyConfiguredException("Functions decorated with 'asgi' must be async functions")


Expand Down
19 changes: 9 additions & 10 deletions litestar/handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@
Empty,
ExceptionHandlersMap,
Guard,
MaybePartial,
Middleware,
TypeDecodersSequence,
TypeEncodersMap,
)
from litestar.typing import FieldDefinition
from litestar.utils import Ref, ensure_async_callable, get_name, normalize_path
from litestar.utils import ensure_async_callable, get_name, normalize_path
from litestar.utils.helpers import unwrap_partial
from litestar.utils.signature import ParsedSignature, add_types_to_signature_namespace

Expand Down Expand Up @@ -155,7 +154,7 @@ def __init__(

def __call__(self, fn: AsyncAnyCallable) -> Self:
"""Replace a function with itself."""
self._fn = Ref["MaybePartial[AsyncAnyCallable]"](fn)
self._fn = fn
return self

@property
Expand Down Expand Up @@ -194,15 +193,15 @@ def signature_model(self) -> type[SignatureModel]:
if self._signature_model is Empty:
self._signature_model = SignatureModel.create(
dependency_name_set=self.dependency_name_set,
fn=cast("AnyCallable", self.fn.value),
fn=cast("AnyCallable", self.fn),
parsed_signature=self.parsed_fn_signature,
data_dto=self.resolve_data_dto(),
type_decoders=self.resolve_type_decoders(),
)
return cast("type[SignatureModel]", self._signature_model)

@property
def fn(self) -> Ref[MaybePartial[AsyncAnyCallable]]:
def fn(self) -> AsyncAnyCallable:
"""Get the handler function.
Raises:
Expand All @@ -226,7 +225,7 @@ def parsed_fn_signature(self) -> ParsedSignature:
"""
if self._parsed_fn_signature is Empty:
self._parsed_fn_signature = ParsedSignature.from_fn(
unwrap_partial(self.fn.value), self.resolve_signature_namespace()
unwrap_partial(self.fn), self.resolve_signature_namespace()
)

return cast("ParsedSignature", self._parsed_fn_signature)
Expand All @@ -253,7 +252,7 @@ def handler_name(self) -> str:
Returns:
Name of the handler function
"""
return get_name(unwrap_partial(self.fn.value))
return get_name(unwrap_partial(self.fn))

@property
def dependency_name_set(self) -> set[str]:
Expand Down Expand Up @@ -358,9 +357,9 @@ def resolve_dependencies(self) -> dict[str, Provide]:
if not getattr(provider, "signature_model", None):
provider.signature_model = SignatureModel.create(
dependency_name_set=self.dependency_name_set,
fn=provider.dependency.value,
fn=provider.dependency,
parsed_signature=ParsedSignature.from_fn(
unwrap_partial(provider.dependency.value), self.resolve_signature_namespace()
unwrap_partial(provider.dependency), self.resolve_signature_namespace()
),
data_dto=self.resolve_data_dto(),
type_decoders=self.resolve_type_decoders(),
Expand Down Expand Up @@ -537,7 +536,7 @@ def __str__(self) -> str:
A string
"""
target: type[AsyncAnyCallable] | AsyncAnyCallable # pyright: ignore
target = unwrap_partial(self.fn.value)
target = unwrap_partial(self.fn)
if not hasattr(target, "__qualname__"):
target = type(target)
return f"{target.__module__}.{target.__qualname__}"
4 changes: 2 additions & 2 deletions litestar/handlers/http_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,10 +490,10 @@ def on_registration(self, app: Litestar) -> None:
super().on_registration(app)
self.resolve_after_response()
self.resolve_include_in_schema()
self.has_sync_callable = not is_async_callable(self.fn.value)
self.has_sync_callable = not is_async_callable(self.fn)

if self.has_sync_callable and self.sync_to_thread:
self.fn.value = ensure_async_callable(self.fn.value)
self._fn = ensure_async_callable(self.fn)
self.has_sync_callable = False

def _validate_handler_function(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion litestar/handlers/websocket_handlers/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def signature_model(self) -> type[SignatureModel]:
if self._signature_model is Empty:
self._signature_model = SignatureModel.create(
dependency_name_set=self.dependency_name_set,
fn=cast("AnyCallable", self.fn.value),
fn=cast("AnyCallable", self.fn),
parsed_signature=self.parsed_fn_signature,
type_decoders=self.resolve_type_decoders(),
)
Expand Down
2 changes: 1 addition & 1 deletion litestar/handlers/websocket_handlers/route_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _validate_handler_function(self) -> None:
if param in self.parsed_fn_signature.parameters:
raise ImproperlyConfiguredException(f"The {param} kwarg is not supported with websocket handlers")

if not is_async_callable(self.fn.value):
if not is_async_callable(self.fn):
raise ImproperlyConfiguredException("Functions decorated with 'websocket' must be async functions")


Expand Down
32 changes: 17 additions & 15 deletions litestar/middleware/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from gzip import GzipFile
from io import BytesIO
from typing import TYPE_CHECKING, Any, Literal, Optional
from typing import TYPE_CHECKING, Any, Literal

from litestar.constants import SCOPE_STATE_IS_CACHED, SCOPE_STATE_RESPONSE_COMPRESSED
from litestar.datastructures import Headers, MutableScopeHeaders
from litestar.enums import CompressionEncoding, ScopeType
from litestar.exceptions import MissingDependencyException
from litestar.middleware.base import AbstractMiddleware
from litestar.utils import Ref, get_litestar_scope_state, set_litestar_scope_state
from litestar.utils import get_litestar_scope_state, set_litestar_scope_state

__all__ = ("CompressionFacade", "CompressionMiddleware")

Expand Down Expand Up @@ -173,8 +173,8 @@ def create_compression_send_wrapper(
bytes_buffer = BytesIO()
facade = CompressionFacade(buffer=bytes_buffer, compression_encoding=compression_encoding, config=self.config)

initial_message = Ref[Optional["HTTPResponseStartEvent"]](None)
started = Ref[bool](False)
initial_message: HTTPResponseStartEvent | None = None
started = False

_own_encoding = compression_encoding.encode("latin-1")

Expand All @@ -184,24 +184,26 @@ async def send_wrapper(message: Message) -> None:
Args:
message (Message): An ASGI Message.
"""
nonlocal started
nonlocal initial_message

if message["type"] == "http.response.start":
initial_message.value = message
initial_message = message
return

if initial_message.value and get_litestar_scope_state(scope, SCOPE_STATE_IS_CACHED):
await send(initial_message.value)
if initial_message and get_litestar_scope_state(scope, SCOPE_STATE_IS_CACHED):
await send(initial_message)
await send(message)
return

if initial_message.value and message["type"] == "http.response.body":
if initial_message and message["type"] == "http.response.body":
body = message["body"]
more_body = message.get("more_body")

if not started.value:
started.value = True
if not started:
started = True
if more_body:
headers = MutableScopeHeaders(initial_message.value)
headers = MutableScopeHeaders(initial_message)
headers["Content-Encoding"] = compression_encoding
headers.extend_header_value("vary", "Accept-Encoding")
del headers["Content-Length"]
Expand All @@ -212,26 +214,26 @@ async def send_wrapper(message: Message) -> None:
message["body"] = bytes_buffer.getvalue()
bytes_buffer.seek(0)
bytes_buffer.truncate()
await send(initial_message.value)
await send(initial_message)
await send(message)

elif len(body) >= self.config.minimum_size:
facade.write(body)
facade.close()
body = bytes_buffer.getvalue()

headers = MutableScopeHeaders(initial_message.value)
headers = MutableScopeHeaders(initial_message)
headers["Content-Encoding"] = compression_encoding
headers["Content-Length"] = str(len(body))
headers.extend_header_value("vary", "Accept-Encoding")
message["body"] = body
set_litestar_scope_state(scope, SCOPE_STATE_RESPONSE_COMPRESSED, True)

await send(initial_message.value)
await send(initial_message)
await send(message)

else:
await send(initial_message.value)
await send(initial_message)
await send(message)

else:
Expand Down
2 changes: 1 addition & 1 deletion litestar/routes/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
connection = ASGIConnection["ASGIRouteHandler", Any, Any, Any](scope=scope, receive=receive)
await self.route_handler.authorize_connection(connection=connection)

await self.route_handler.fn.value(scope=scope, receive=receive, send=send)
await self.route_handler.fn(scope=scope, receive=receive, send=send)
8 changes: 4 additions & 4 deletions litestar/routes/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,14 @@ async def _get_response_data(
if cleanup_group:
async with cleanup_group:
data = (
route_handler.fn.value(**parsed_kwargs)
route_handler.fn(**parsed_kwargs)
if route_handler.has_sync_callable
else await route_handler.fn.value(**parsed_kwargs)
else await route_handler.fn(**parsed_kwargs)
)
elif route_handler.has_sync_callable:
data = route_handler.fn.value(**parsed_kwargs)
data = route_handler.fn(**parsed_kwargs)
else:
data = await route_handler.fn.value(**parsed_kwargs)
data = await route_handler.fn(**parsed_kwargs)

return data, cleanup_group

Expand Down
4 changes: 2 additions & 2 deletions litestar/routes/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ async def handle(self, scope: WebSocketScope, receive: Receive, send: Send) -> N

if cleanup_group:
async with cleanup_group:
await self.route_handler.fn.value(**parsed_kwargs)
await self.route_handler.fn(**parsed_kwargs)
await cleanup_group.cleanup()
else:
await self.route_handler.fn.value(**parsed_kwargs)
await self.route_handler.fn(**parsed_kwargs)
3 changes: 1 addition & 2 deletions litestar/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from litestar.utils.deprecation import deprecated, warn_deprecation

from .helpers import Ref, get_enum_string_value, get_name, unique_name_for_scope, url_quote
from .helpers import get_enum_string_value, get_name, unique_name_for_scope, url_quote
from .path import join_paths, normalize_path
from .predicates import (
is_annotated_type,
Expand Down Expand Up @@ -33,7 +33,6 @@
__all__ = (
"ensure_async_callable",
"AsyncIteratorWrapper",
"Ref",
"delete_litestar_scope_state",
"deprecated",
"find_index",
Expand Down
14 changes: 1 addition & 13 deletions litestar/utils/helpers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
from typing import TYPE_CHECKING, Any, TypeVar, cast
from urllib.parse import quote

from litestar.utils.typing import get_origin_or_inner_type
Expand All @@ -14,7 +13,6 @@
from litestar.types import MaybePartial

__all__ = (
"Ref",
"get_enum_string_value",
"get_name",
"unwrap_partial",
Expand Down Expand Up @@ -60,16 +58,6 @@ def get_enum_string_value(value: Enum | str) -> str:
return value.value if isinstance(value, Enum) else value # type:ignore


@dataclass
class Ref(Generic[T]):
"""A helper class that encapsulates a value."""

__slots__ = ("value",)

value: T
"""The value wrapped by the ref."""


def unwrap_partial(value: MaybePartial[T]) -> T:
"""Unwraps a partial, returning the underlying callable.
Expand Down
Loading

0 comments on commit 1ad1997

Please sign in to comment.