Skip to content

Commit

Permalink
chore: simplify unwrap_partial (#2590)
Browse files Browse the repository at this point in the history
* simplify unwrap_partial
* fix async_partial

---------

Signed-off-by: Janek Nouvertné <[email protected]>
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
  • Loading branch information
provinzkraut and sourcery-ai[bot] authored Nov 4, 2023
1 parent 580b76c commit 0039e99
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 27 deletions.
8 changes: 4 additions & 4 deletions litestar/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
from urllib.parse import quote

Expand Down Expand Up @@ -81,10 +82,9 @@ def unwrap_partial(value: MaybePartial[T]) -> T:
Returns:
Callable
"""
output: Any = value.func if hasattr(value, "func") else value # pyright: ignore
while hasattr(output, "func"):
output = output.func
return cast("T", output)
from litestar.utils.sync import async_partial

return cast("T", value.func if isinstance(value, (partial, async_partial)) else value)


def filter_cookies(local_cookies: Iterable[Cookie], layered_cookies: Iterable[Cookie]) -> list[Cookie]:
Expand Down
5 changes: 2 additions & 3 deletions litestar/utils/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from collections import defaultdict, deque
from collections.abc import Iterable as CollectionsIterable
from dataclasses import is_dataclass
from functools import partial
from inspect import isasyncgenfunction, isclass, isgeneratorfunction
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -39,6 +38,7 @@
from litestar.constants import UNDEFINED_SENTINELS
from litestar.types import Empty
from litestar.types.builtin_types import NoneType, UnionTypes
from litestar.utils.helpers import unwrap_partial
from litestar.utils.typing import get_origin_or_inner_type

if TYPE_CHECKING:
Expand Down Expand Up @@ -86,8 +86,7 @@ class instances with ``async def __call__()`` defined.
Returns:
Bool determining if type of ``value`` is an awaitable.
"""
while isinstance(value, partial):
value = value.func # type: ignore[unreachable]
value = unwrap_partial(value)

return iscoroutinefunction(value) or (
callable(value) and iscoroutinefunction(value.__call__) # type: ignore[operator]
Expand Down
26 changes: 8 additions & 18 deletions litestar/utils/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ def __init__(self, fn: Callable[P, T]) -> None:
self._parsed_signature: ParsedSignature | EmptyType = Empty
self.is_method = ismethod(fn) or (callable(fn) and ismethod(fn.__call__)) # type: ignore
self.num_expected_args = len(getfullargspec(fn).args) - (1 if self.is_method else 0)
self.ref = Ref[Callable[..., Awaitable[T]]](
fn if is_async_callable(fn) else async_partial(fn) # pyright: ignore
)
self.ref = Ref[Callable[..., Awaitable[T]]](fn if is_async_callable(fn) else async_partial(fn)) # type: ignore

async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
"""Proxy the wrapped function's call method.
Expand Down Expand Up @@ -84,25 +82,17 @@ def set_parsed_signature(self, namespace: dict[str, Any]) -> None:
self._parsed_signature = ParsedSignature.from_fn(unwrap_partial(self.ref.value), namespace)


def async_partial(fn: Callable) -> Callable:
class async_partial: # noqa: N801
"""Wrap a given sync function making it async.
In difference to the :func:`asyncio.run_sync` function, it allows for passing kwargs.
Args:
fn: A sync callable to wrap.
Returns:
A wrapper
In difference to the :func:`anyio.run_sync` function, it allows for passing kwargs.
"""

async def wrapper(*args: Any, **kwargs: Any) -> Any:
applied_kwarg = partial(fn, **kwargs)
return await run_sync(applied_kwarg, *args)
def __init__(self, fn: Callable[P, T]) -> None: # pyright: ignore
self.func = fn

# this allows us to unwrap the partial later, so it's an important "hack".
wrapper.func = fn # type: ignore
return wrapper
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: # pyright: ignore
applied_kwarg = partial(self.func, **kwargs)
return await run_sync(applied_kwarg, *args) # pyright: ignore


class AsyncIteratorWrapper(Generic[T]):
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/test_utils/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from functools import partial
from typing import Any, Generic
from typing import Any, Generic, TypeVar

import pytest
from typing_extensions import TypeVar

from litestar.utils.helpers import get_name, unique_name_for_scope, unwrap_partial

Expand Down

0 comments on commit 0039e99

Please sign in to comment.