Skip to content

Commit

Permalink
shaffle private imports
Browse files Browse the repository at this point in the history
  • Loading branch information
Lancetnik committed Apr 17, 2023
1 parent e8c9565 commit d77e375
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 99 deletions.
2 changes: 1 addition & 1 deletion fast_depends/__about__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""FastDepends - extracted and cleared from HTTP domain Fastapi Dependency Injection System"""

__version__ = "1.0.2"
__version__ = "1.0.3"
99 changes: 12 additions & 87 deletions fast_depends/injector.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
import asyncio
import functools
import inspect
from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
from contextlib import AsyncExitStack, ExitStack
from copy import deepcopy
from typing import (
Any,
AsyncGenerator,
Callable,
ContextManager,
Dict,
List,
Mapping,
Expand All @@ -18,14 +12,22 @@
cast,
)

import anyio
from pydantic.error_wrappers import ErrorWrapper
from pydantic.errors import MissingError
from pydantic.fields import ModelField

from fast_depends.construct import get_dependant
from fast_depends.model import Dependant
from fast_depends.types import AnyCallable, AnyDict, P
from fast_depends.types import AnyCallable, AnyDict
from fast_depends.utils import (
run_async,
is_coroutine_callable,
is_gen_callable,
is_async_gen_callable,
solve_generator_async,
solve_generator_sync,
)


T = TypeVar("T")

Expand Down Expand Up @@ -100,7 +102,7 @@ async def solve_dependencies_async(
)
else:
solved, sub_errors = use_sub_dependant.cast_response(
await run_async(dependant=use_sub_dependant, values=sub_values)
await run_async(use_sub_dependant.call, **sub_values)
)

if sub_errors:
Expand Down Expand Up @@ -206,83 +208,6 @@ def solve_dependencies_sync(
return params, errors, dependency_cache


def is_async_gen_callable(call: Callable[..., Any]) -> bool:
if inspect.isasyncgenfunction(call):
return True
dunder_call = getattr(call, "__call__", None) # noqa: B004
return inspect.isasyncgenfunction(dunder_call)


def is_gen_callable(call: Callable[..., Any]) -> bool:
if inspect.isgeneratorfunction(call):
return True
dunder_call = getattr(call, "__call__", None) # noqa: B004
return inspect.isgeneratorfunction(dunder_call)


def is_coroutine_callable(call: AnyCallable) -> bool:
if inspect.isroutine(call):
return inspect.iscoroutinefunction(call)
if inspect.isclass(call):
return False
call = getattr(call, "__call__", None) # noqa: B004
return inspect.iscoroutinefunction(call)


async def solve_generator_async(
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
) -> Any:
if is_gen_callable(call):
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
elif is_async_gen_callable(call): # pragma: no branch
cm = asynccontextmanager(call)(**sub_values)
return await stack.enter_async_context(cm)


def solve_generator_sync(
*, call: Callable[..., Any], stack: ExitStack, sub_values: Dict[str, Any]
) -> Any:
cm = contextmanager(call)(**sub_values)
return stack.enter_context(cm)


async def run_async(*, dependant: Dependant, values: AnyDict) -> Any:
assert dependant.call is not None, "dependant.call must be a function"
if asyncio.iscoroutinefunction(dependant.call):
return await dependant.call(**values)
else:
return await run_in_threadpool(dependant.call, **values)


async def run_in_threadpool(
func: Callable[P, T], *args: P.args, **kwargs: P.kwargs
) -> T:
if kwargs: # pragma: no cover
func = functools.partial(func, **kwargs)
return await anyio.to_thread.run_sync(func, *args)


@asynccontextmanager
async def contextmanager_in_threadpool(
cm: ContextManager[T],
) -> AsyncGenerator[T, None]:
exit_limiter = anyio.CapacityLimiter(1)
try:
yield await run_in_threadpool(cm.__enter__)
except Exception as e:
ok = bool(
await anyio.to_thread.run_sync(
cm.__exit__, type(e), e, None, limiter=exit_limiter
)
)
if not ok: # pragma: no branch
raise e
else:
await anyio.to_thread.run_sync(
cm.__exit__, None, None, None, limiter=exit_limiter
)


def params_to_args(
required_params: Sequence[ModelField],
received_params: Mapping[str, Any],
Expand Down
2 changes: 2 additions & 0 deletions fast_depends/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, List, Optional, Tuple

from pydantic import create_model
from pydantic.error_wrappers import ErrorList
from pydantic.fields import ModelField

Expand Down Expand Up @@ -31,6 +32,7 @@ def __init__(
self.path = path
# Save the cache key at creation to optimize performance
self.cache_key = (self.call,)
self.error_model = create_model(getattr(call, "__name__", str(call)))

def cast_response(self, response: Any) -> Tuple[Optional[Any], Optional[ErrorList]]:
if self.return_field is None:
Expand Down
21 changes: 12 additions & 9 deletions fast_depends/usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,19 @@ def Depends( # noqa: N802
return model.Depends(dependency=dependency, use_cache=use_cache)


def wrap_dependant(dependant: model.Dependant) -> model.Dependant:
return dependant


def inject(
func: Callable[P, T],
*,
dependency_overrides_provider: Optional[Any] = dependency_provider,
wrap_dependant: Callable[[model.Dependant], model.Dependant] = wrap_dependant,
) -> Callable[P, T]:
dependant = get_dependant(call=func, path=func.__name__)
error_model = create_model(func.__name__)

dependant = wrap_dependant(dependant)

if is_coroutine_callable(func) is True:
f = async_typed_wrapper
Expand All @@ -42,7 +48,6 @@ def inject(
partial(
f,
dependant=dependant,
error_model=error_model,
dependency_overrides_provider=dependency_overrides_provider,
)
)
Expand All @@ -51,7 +56,6 @@ def inject(
async def async_typed_wrapper(
*args: P.args,
dependant: model.Dependant,
error_model: BaseModel,
dependency_overrides_provider: Optional[Any],
**kwargs: P.kwargs,
) -> Any:
Expand All @@ -66,22 +70,21 @@ async def async_typed_wrapper(
)

if errors:
raise ValidationError(errors, error_model)
raise ValidationError(errors, dependant.error_model)

v, errors = dependant.cast_response(
await run_async(dependant=dependant, values=solved_result)
await run_async(dependant.call, **solved_result)
)

if errors:
raise ValidationError(errors, error_model)
raise ValidationError(errors, dependant.error_model)

return v


def sync_typed_wrapper(
*args: P.args,
dependant: model.Dependant,
error_model: BaseModel,
dependency_overrides_provider: Optional[Any],
**kwargs: P.kwargs,
) -> Any:
Expand All @@ -96,11 +99,11 @@ def sync_typed_wrapper(
)

if errors:
raise ValidationError(errors, error_model)
raise ValidationError(errors, dependant.error_model)

v, errors = dependant.cast_response(dependant.call(**solved_result))

if errors:
raise ValidationError(errors, error_model)
raise ValidationError(errors, dependant.error_model)

return v
100 changes: 98 additions & 2 deletions fast_depends/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,23 @@
from typing import Sequence
import asyncio
import functools
import inspect
from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
from typing import (
Any,
AsyncGenerator,
Callable,
ContextManager,
Dict,
Sequence,
TypeVar,
)

from fast_depends.types import AnyDict, P
import anyio

from fast_depends.types import AnyDict, P, AnyCallable


T = TypeVar("T")


def args_to_kwargs(
Expand All @@ -12,3 +29,82 @@ def args_to_kwargs(
unused = filter(lambda x: x not in kwargs, arguments)

return dict((*zip(unused, args), *kwargs.items()))



async def run_async(func: AnyCallable, *args: Any, **kwargs: AnyDict) -> Any:
if asyncio.iscoroutinefunction(func):
return await func(*args, **kwargs)
else:
return await run_in_threadpool(func, *args, **kwargs)


async def run_in_threadpool(
func: Callable[P, T], *args: P.args, **kwargs: P.kwargs
) -> T:
if kwargs: # pragma: no cover
func = functools.partial(func, **kwargs)
return await anyio.to_thread.run_sync(func, *args)



def is_async_gen_callable(call: Callable[..., Any]) -> bool:
if inspect.isasyncgenfunction(call):
return True
dunder_call = getattr(call, "__call__", None) # noqa: B004
return inspect.isasyncgenfunction(dunder_call)


def is_gen_callable(call: Callable[..., Any]) -> bool:
if inspect.isgeneratorfunction(call):
return True
dunder_call = getattr(call, "__call__", None) # noqa: B004
return inspect.isgeneratorfunction(dunder_call)


def is_coroutine_callable(call: AnyCallable) -> bool:
if inspect.isroutine(call):
return inspect.iscoroutinefunction(call)
if inspect.isclass(call):
return False
call = getattr(call, "__call__", None) # noqa: B004
return inspect.iscoroutinefunction(call)


async def solve_generator_async(
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
) -> Any:
if is_gen_callable(call):
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
elif is_async_gen_callable(call): # pragma: no branch
cm = asynccontextmanager(call)(**sub_values)
return await stack.enter_async_context(cm)


def solve_generator_sync(
*, call: Callable[..., Any], stack: ExitStack, sub_values: Dict[str, Any]
) -> Any:
cm = contextmanager(call)(**sub_values)
return stack.enter_context(cm)


@asynccontextmanager
async def contextmanager_in_threadpool(
cm: ContextManager[T],
) -> AsyncGenerator[T, None]:
exit_limiter = anyio.CapacityLimiter(1)
try:
yield await run_in_threadpool(cm.__enter__)
except Exception as e:
ok = bool(
await anyio.to_thread.run_sync(
cm.__exit__, type(e), e, None, limiter=exit_limiter
)
)
if not ok: # pragma: no branch
raise e
else:
await anyio.to_thread.run_sync(
cm.__exit__, None, None, None, limiter=exit_limiter
)

0 comments on commit d77e375

Please sign in to comment.