diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 316b3a9..1f5896d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,7 +2,7 @@ If you already cloned the repository and you know that you need to deep dive in the code, here are some guidelines to set up your environment. -### Virtual environment with `venv` +## Virtual environment with `venv` You can create a virtual environment in a directory using Python's `venv` module: @@ -12,7 +12,7 @@ python -m venv venv That will create a directory `./venv/` with the Python binaries and then you will be able to install packages for that isolated environment. -### Activate the environment +## Activate the environment Activate the new environment with: @@ -20,12 +20,13 @@ Activate the new environment with: source ./venv/bin/activate ``` -Make sure you have the latest pip version on your virtual environment to +Make sure you have the latest pip version on your virtual environment to + ```bash python -m pip install --upgrade pip ``` -### pip +## pip After activating the environment as described above: @@ -35,7 +36,7 @@ pip install -e ."[dev]" It will install all the dependencies and your local FastDepends in your local environment. -#### Using your local FastDepends +### Using your local FastDepends If you create a Python file that imports and uses FastDepends, and run it with the Python from your local environment, it will use your local FastDepends source code. @@ -43,9 +44,10 @@ And if you update that local FastDepends source code, as it is installed with `- That way, you don't have to "install" your local version to be able to test every change. -### Tests +## Tests + +### Pytests -#### Pytests To run tests with your current FastDepends application and Python environment use: ```bash @@ -55,30 +57,3 @@ bash ./scripts/test.sh # with coverage output bash ./scripts/test-cov.sh ``` - -#### Hatch - -If you are using **hatch** use following environments to run tests: - -##### **TEST** - -Run tests at all python 3.8-3.12 versions. - -All python versions should be avalilable at your system. - -```bash -# Run test at all python 3.8-3.12 versions -hatch run test:run -``` - -##### **TEST-LAST** - -Run tests at python 3.12 version. - -```bash -# Run tests at python 3.12 -hatch run test-last:run - -# Run all tests at python 3.12 and show coverage -hatch run test-last:cov -``` \ No newline at end of file diff --git a/docs/docs/alternatives.md b/docs/docs/alternatives.md index e3f3728..c1e510a 100644 --- a/docs/docs/alternatives.md +++ b/docs/docs/alternatives.md @@ -19,7 +19,6 @@ Key features: * Composable: decoupled internal APIs give you the flixibility to customize wiring, execution and binding. * Performant: `di` can execute dependencies in parallel and cache results ins scopes. - ## [Dependency Injector](https://python-dependency-injector.etc-labs.org) Dependency Injector is a dependency injection framework for Python. @@ -37,4 +36,4 @@ Key features: * Asynchronous * Typing * Perfomance -* Maturity \ No newline at end of file +* Maturity diff --git a/docs/docs/contributing.md b/docs/docs/contributing.md index 316b3a9..1f5896d 100644 --- a/docs/docs/contributing.md +++ b/docs/docs/contributing.md @@ -2,7 +2,7 @@ If you already cloned the repository and you know that you need to deep dive in the code, here are some guidelines to set up your environment. -### Virtual environment with `venv` +## Virtual environment with `venv` You can create a virtual environment in a directory using Python's `venv` module: @@ -12,7 +12,7 @@ python -m venv venv That will create a directory `./venv/` with the Python binaries and then you will be able to install packages for that isolated environment. -### Activate the environment +## Activate the environment Activate the new environment with: @@ -20,12 +20,13 @@ Activate the new environment with: source ./venv/bin/activate ``` -Make sure you have the latest pip version on your virtual environment to +Make sure you have the latest pip version on your virtual environment to + ```bash python -m pip install --upgrade pip ``` -### pip +## pip After activating the environment as described above: @@ -35,7 +36,7 @@ pip install -e ."[dev]" It will install all the dependencies and your local FastDepends in your local environment. -#### Using your local FastDepends +### Using your local FastDepends If you create a Python file that imports and uses FastDepends, and run it with the Python from your local environment, it will use your local FastDepends source code. @@ -43,9 +44,10 @@ And if you update that local FastDepends source code, as it is installed with `- That way, you don't have to "install" your local version to be able to test every change. -### Tests +## Tests + +### Pytests -#### Pytests To run tests with your current FastDepends application and Python environment use: ```bash @@ -55,30 +57,3 @@ bash ./scripts/test.sh # with coverage output bash ./scripts/test-cov.sh ``` - -#### Hatch - -If you are using **hatch** use following environments to run tests: - -##### **TEST** - -Run tests at all python 3.8-3.12 versions. - -All python versions should be avalilable at your system. - -```bash -# Run test at all python 3.8-3.12 versions -hatch run test:run -``` - -##### **TEST-LAST** - -Run tests at python 3.12 version. - -```bash -# Run tests at python 3.12 -hatch run test-last:run - -# Run all tests at python 3.12 and show coverage -hatch run test-last:cov -``` \ No newline at end of file diff --git a/fast_depends/__about__.py b/fast_depends/__about__.py index 5c85b3a..c43efe5 100644 --- a/fast_depends/__about__.py +++ b/fast_depends/__about__.py @@ -1,3 +1,3 @@ """FastDepends - extracted and cleared from HTTP domain FastAPI Dependency Injection System""" -__version__ = "2.2.8" +__version__ = "2.3.0" diff --git a/fast_depends/_compat.py b/fast_depends/_compat.py index dd0fdaf..cc8f44d 100644 --- a/fast_depends/_compat.py +++ b/fast_depends/_compat.py @@ -1,3 +1,5 @@ +import sys +from importlib.metadata import version as get_version from typing import Any, Dict, Optional, Type from pydantic import BaseModel, create_model @@ -12,6 +14,7 @@ "get_config_base", "get_model_fields", "ConfigDict", + "ExceptionGroup", ) @@ -55,3 +58,14 @@ class CreateBaseModel(BaseModel): # type: ignore[no-redef] class Config: arbitrary_types_allowed = True + + +ANYIO_V3 = get_version("anyio").startswith("3.") + +if ANYIO_V3: + from anyio import ExceptionGroup as ExceptionGroup +else: + if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup as ExceptionGroup + else: + ExceptionGroup = ExceptionGroup diff --git a/fast_depends/core/build.py b/fast_depends/core/build.py index 1c8f6b9..57c862e 100644 --- a/fast_depends/core/build.py +++ b/fast_depends/core/build.py @@ -16,7 +16,6 @@ Annotated, ParamSpec, TypeVar, - assert_never, get_args, get_origin, ) @@ -25,7 +24,12 @@ from fast_depends.core.model import CallModel, ResponseModel from fast_depends.dependencies import Depends from fast_depends.library import CustomField -from fast_depends.utils import get_typed_signature, is_coroutine_callable +from fast_depends.utils import ( + get_typed_signature, + is_async_gen_callable, + is_coroutine_callable, + is_gen_callable, +) CUSTOM_ANNOTATIONS = (Depends, CustomField) @@ -57,6 +61,12 @@ def build_call_model( ), f"You cannot use async dependency `{name}` at sync main" typed_params, return_annotation = get_typed_signature(call) + if ( + (is_call_generator := is_gen_callable(call) or + is_async_gen_callable(call)) and + (return_args := get_args(return_annotation)) + ): + return_annotation = return_args[0] class_fields: Dict[str, Tuple[Any, Any]] = {} dependencies: Dict[str, "CallModel[..., Any]"] = {} @@ -89,7 +99,7 @@ def build_call_model( elif isinstance(next_custom, CustomField): custom = next_custom else: # pragma: no cover - assert_never() + raise AssertionError("unreachable") annotation = type_annotation else: @@ -185,6 +195,7 @@ def build_call_model( cast=cast, use_cache=use_cache, is_async=is_call_async, + is_generator=is_call_generator, dependencies=dependencies, custom_fields=custom_fields, positional_args=positional_args, diff --git a/fast_depends/core/model.py b/fast_depends/core/model.py index e2274dc..a9ddb76 100644 --- a/fast_depends/core/model.py +++ b/fast_depends/core/model.py @@ -1,4 +1,6 @@ +from collections import namedtuple from contextlib import AsyncExitStack, ExitStack +from functools import partial from inspect import Parameter, unwrap from typing import ( Any, @@ -10,14 +12,16 @@ Iterable, List, Optional, + Sequence, Tuple, Type, Union, ) -from typing_extensions import ParamSpec, TypeVar, assert_never +import anyio +from typing_extensions import ParamSpec, TypeVar -from fast_depends._compat import BaseModel, FieldInfo, get_model_fields +from fast_depends._compat import BaseModel, ExceptionGroup, FieldInfo, get_model_fields from fast_depends.library import CustomField from fast_depends.utils import ( async_map, @@ -33,6 +37,11 @@ T = TypeVar("T") +PriorityPair = namedtuple( + "PriorityPair", ("call", "dependencies_number", "dependencies_names") +) + + class ResponseModel(BaseModel, Generic[T]): response: T @@ -52,6 +61,7 @@ class CallModel(Generic[P, T]): dependencies: Dict[str, "CallModel[..., Any]"] extra_dependencies: Iterable["CallModel[..., Any]"] + sorted_dependencies: Tuple[Tuple["CallModel[..., Any]", int], ...] custom_fields: Dict[str, CustomField] keyword_args: Tuple[str, ...] positional_args: Tuple[str, ...] @@ -72,6 +82,7 @@ class CallModel(Generic[P, T]): "positional_args", "dependencies", "extra_dependencies", + "sorted_dependencies", "custom_fields", "use_cache", "cast", @@ -96,6 +107,38 @@ def flat_params(self) -> Dict[str, FieldInfo]: params.update(d.flat_params) return params + @property + def flat_dependencies( + self, + ) -> Dict[ + Callable[..., Any], + Tuple[ + "CallModel[..., Any]", + Tuple[Callable[..., Any], ...], + ], + ]: + flat: Dict[ + Callable[..., Any], + Tuple[ + "CallModel[..., Any]", + Tuple[Callable[..., Any], ...], + ], + ] = {} + + for i in (*self.dependencies.values(), *self.extra_dependencies): + flat.update( + { + i.call: ( + i, + tuple(j.call for j in i.dependencies.values()), + ) + } + ) + + flat.update(i.flat_dependencies) + + return flat + def __init__( self, /, @@ -108,6 +151,7 @@ def __init__( use_cache: bool = True, cast: bool = True, is_async: bool = False, + is_generator: bool = False, dependencies: Optional[Dict[str, "CallModel[..., Any]"]] = None, extra_dependencies: Optional[Iterable["CallModel[..., Any]"]] = None, keyword_args: Optional[List[str]] = None, @@ -121,7 +165,7 @@ def __init__( fields: Dict[str, FieldInfo] = get_model_fields(model) self.dependencies = dependencies or {} - self.extra_dependencies = extra_dependencies or [] + self.extra_dependencies = extra_dependencies or () self.custom_fields = custom_fields or {} self.alias_arguments = tuple(f.alias or name for name, f in fields.items()) @@ -135,16 +179,25 @@ def __init__( self.use_cache = use_cache self.cast = cast self.is_async = ( - is_async or is_coroutine_callable(call) or is_async_gen_callable(self.call) + is_async or is_coroutine_callable(call) or is_async_gen_callable(call) + ) + self.is_generator = ( + is_generator or is_gen_callable(call) or is_async_gen_callable(call) ) - self.is_generator = is_gen_callable(self.call) or is_async_gen_callable( - self.call + + sorted_dep: List["CallModel[..., Any]"] = [] + flat = self.flat_dependencies + for calls in flat.values(): + _sort_dep(sorted_dep, calls, flat) + + self.sorted_dependencies = tuple( + (i, len(i.sorted_dependencies)) for i in sorted_dep if i.use_cache ) def _solve( self, /, - *args: P.args, + *args: Tuple[Any, ...], cache_dependencies: Dict[ Union[ Callable[P, T], @@ -164,15 +217,12 @@ def _solve( ], ] ] = None, - **kwargs: P.kwargs, + **kwargs: Dict[str, Any], ) -> Generator[ Tuple[ - Iterable[Any], + Sequence[Any], Dict[str, Any], - Union[ - Callable[P, T], - Callable[P, Awaitable[T]], - ], + Callable[..., Any], ], Any, T, @@ -182,13 +232,14 @@ def _solve( assert self.is_async or not is_coroutine_callable( call ), f"You cannot use async dependency `{self.call_name}` at sync main" + else: call = self.call - if self.use_cache and self.call in cache_dependencies: - return cache_dependencies[self.call] + if self.use_cache and call in cache_dependencies: + return cache_dependencies[call] - kw = {} + kw: Dict[str, Any] = {} for arg in self.keyword_args: if (v := kwargs.pop(arg, Parameter.empty)) is not Parameter.empty: @@ -220,8 +271,7 @@ def _solve( solved_kw: Dict[str, Any] solved_kw = yield (), kw, call - args_: Iterable[Any] - + args_: Sequence[Any] if self.cast: casted_model = self.model(**solved_kw) @@ -229,24 +279,22 @@ def _solve( arg: getattr(casted_model, arg, solved_kw.get(arg)) for arg in keyword_args } - kwargs_.update(getattr(casted_model, "kwargs", solved_kw.get("kwargs", {}))) + kwargs_.update(getattr(casted_model, "kwargs", {})) if has_args: args_ = [ getattr(casted_model, arg, solved_kw.get(arg)) for arg in self.positional_args ] - args_.extend(getattr(casted_model, "args", solved_kw.get("args", ()))) + args_.extend(getattr(casted_model, "args", ())) else: args_ = () else: kwargs_ = {arg: solved_kw.get(arg) for arg in keyword_args} - kwargs_.update(solved_kw.get("kwargs", {})) if has_args: - args_ = [solved_kw.get(arg) for arg in self.positional_args] - args_.extend(solved_kw.get("args", ())) + args_ = tuple(map(solved_kw.get, self.positional_args)) else: args_ = () @@ -270,7 +318,7 @@ def _cast_response(self, /, value: Any) -> Any: def solve( self, /, - *args: P.args, + *args: Tuple[Any, ...], stack: ExitStack, cache_dependencies: Dict[ Union[ @@ -292,7 +340,7 @@ def solve( ] ] = None, nested: bool = False, - **kwargs: P.kwargs, + **kwargs: Dict[str, Any], ) -> T: cast_gen = self._solve( *args, @@ -306,6 +354,17 @@ def solve( cached_value: T = e.value return cached_value + # Heat cache and solve extra dependencies + for dep, _ in self.sorted_dependencies: + dep.solve( + stack=stack, + cache_dependencies=cache_dependencies, + dependency_overrides=dependency_overrides, + nested=True, + **kwargs, + ) + + # Always get from cache for dep in self.extra_dependencies: dep.solve( stack=stack, @@ -325,7 +384,10 @@ def solve( ) for custom in self.custom_fields.values(): - kwargs = custom.use(**kwargs) + if custom.field: + custom.use_field(kwargs) + else: + kwargs = custom.use(**kwargs) final_args, final_kwargs, call = cast_gen.send(kwargs) @@ -351,12 +413,12 @@ def solve( else: return map(self._cast_response, value) # type: ignore[no-any-return, call-overload] - assert_never(response) # pragma: no cover + raise AssertionError("unreachable") async def asolve( self, /, - *args: P.args, + *args: Tuple[Any, ...], stack: AsyncExitStack, cache_dependencies: Dict[ Union[ @@ -378,7 +440,7 @@ async def asolve( ] ] = None, nested: bool = False, - **kwargs: P.kwargs, + **kwargs: Dict[str, Any], ) -> T: cast_gen = self._solve( *args, @@ -392,6 +454,31 @@ async def asolve( cached_value: T = e.value return cached_value + # Heat cache and solve extra dependencies + dep_to_solve: List[Callable[..., Awaitable[Any]]] = [] + try: + async with anyio.create_task_group() as tg: + for dep, subdep in self.sorted_dependencies: + solve = partial( + dep.asolve, + stack=stack, + cache_dependencies=cache_dependencies, + dependency_overrides=dependency_overrides, + nested=True, + **kwargs, + ) + if not subdep: + tg.start_soon(solve) + else: + dep_to_solve.append(solve) + except ExceptionGroup as exgr: + for ex in exgr.exceptions: + raise ex from None + + for i in dep_to_solve: + await i() + + # Always get from cache for dep in self.extra_dependencies: await dep.asolve( stack=stack, @@ -410,8 +497,22 @@ async def asolve( **kwargs, ) - for custom in self.custom_fields.values(): - kwargs = await run_async(custom.use, **kwargs) + custom_to_solve: List[CustomField] = [] + + try: + async with anyio.create_task_group() as tg: + for custom in self.custom_fields.values(): + if custom.field: + tg.start_soon(run_async, custom.use_field, kwargs) + else: + custom_to_solve.append(custom) + + except ExceptionGroup as exgr: + for ex in exgr.exceptions: + raise ex from None + + for j in custom_to_solve: + kwargs = await run_async(j.use, **kwargs) final_args, final_kwargs, call = cast_gen.send(kwargs) @@ -436,4 +537,37 @@ async def asolve( else: return async_map(self._cast_response, value) # type: ignore[return-value, arg-type] - assert_never(response) # pragma: no cover + raise AssertionError("unreachable") + + +def _sort_dep( + collector: List["CallModel[..., Any]"], + items: Tuple[ + "CallModel[..., Any]", + Tuple[Callable[..., Any], ...], + ], + flat: Dict[ + Callable[..., Any], + Tuple[ + "CallModel[..., Any]", + Tuple[Callable[..., Any], ...], + ], + ], +) -> None: + model, calls = items + + if model in collector: + return + + if not calls: + position = -1 + + else: + for i in calls: + sub_model, _ = flat[i] + if sub_model not in collector: + _sort_dep(collector, flat[i], flat) + + position = max(collector.index(flat[i][0]) for i in calls) + + collector.insert(position + 1, model) diff --git a/fast_depends/library/model.py b/fast_depends/library/model.py index 13a721f..8b18ea4 100644 --- a/fast_depends/library/model.py +++ b/fast_depends/library/model.py @@ -9,6 +9,13 @@ class CustomField(ABC): cast: bool required: bool + __slots__ = ( + "cast", + "param_name", + "required", + "field", + ) + def __init__( self, *, @@ -18,11 +25,15 @@ def __init__( self.cast = cast self.param_name = None self.required = required + self.field = False def set_param_name(self: Cls, name: str) -> Cls: self.param_name = name return self - def use(self, /, **kwargs: Dict[str, Any]) -> Dict[str, Any]: + def use(self, /, **kwargs: Any) -> Dict[str, Any]: assert self.param_name, "You should specify `param_name` before using" return kwargs + + def use_field(self, kwargs: Dict[str, Any]) -> None: + raise NotImplementedError("You should implement `use_field` method.") diff --git a/pyproject.toml b/pyproject.toml index 39a2a21..0553d04 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,34 +86,6 @@ exclude = [ "/docs", ] -[tool.hatch.envs.default] -python = "3.12" -skip-install = false -features = [ - "dev", -] - -[tool.hatch.envs.test] -features = [ - "test", -] - -[tool.hatch.envs.test.scripts] -run = "pytest -q" - -[[tool.hatch.envs.test.matrix]] -python = ["38", "39", "310", "311", "312"] - -[tool.hatch.envs.test-last] -python = "3.12" -features = [ - "test", -] - -[tool.hatch.envs.test-last.scripts] -run = "pytest -q" -cov = "bash ./scripts/test-cov.sh -v" - [tool.mypy] strict = true ignore_missing_imports = true diff --git a/tests/async/test_cast.py b/tests/async/test_cast.py index 2fb958e..f6ac5ae 100644 --- a/tests/async/test_cast.py +++ b/tests/async/test_cast.py @@ -1,10 +1,12 @@ -from typing import Dict, Tuple +from typing import Dict, Iterator, Tuple import pytest +from annotated_types import Ge from pydantic import BaseModel, Field, ValidationError from typing_extensions import Annotated from fast_depends import inject +from tests.marks import pydanticV2 @pytest.mark.anyio @@ -192,16 +194,26 @@ async def simple_func(a: str) -> int: @pytest.mark.anyio -async def test_args_kwargs_without_cast(): - @inject(cast=False) - async def simple_func( - a: int, - *args: Tuple[float, ...], - b: int, - **kwargs: Dict[str, int], - ): - return a, args, b, kwargs +async def test_generator_iterator_type(): + @inject + async def simple_func(a: str) -> Iterator[int]: + for _ in range(2): + yield a - assert (1.0, (2.0, 3), 3.0, {"key": 1.0}) == await simple_func( - 1.0, 2.0, 3, b=3.0, key=1.0 - ) + async for i in simple_func("1"): + assert i == 1 + + +@pytest.mark.anyio +@pydanticV2 +async def test_multi_annotated(): + from pydantic.functional_validators import AfterValidator + + @inject() + async def f(a: Annotated[int, Ge(10), AfterValidator(lambda x: x + 10)]) -> int: + return a + + with pytest.raises(ValidationError): + await f(1) + + assert await f(10) == 20 diff --git a/tests/async/test_depends.py b/tests/async/test_depends.py index dd3a748..c85ccc1 100644 --- a/tests/async/test_depends.py +++ b/tests/async/test_depends.py @@ -108,6 +108,29 @@ async def some_func(a=Depends(dep_func), b=Depends(nested_dep_func)): mock.assert_called_once() +@pytest.mark.anyio +async def test_not_cache(): + mock = Mock() + + async def nested_dep_func(): + mock() + return 1000 + + async def dep_func(a=Depends(nested_dep_func, use_cache=False)): + return a + + @inject + async def some_func( + a=Depends(dep_func, use_cache=False), + b=Depends(nested_dep_func, use_cache=False), + ): + assert a is b + return a + b + + assert await some_func() + assert mock.call_count == 2 + + @pytest.mark.anyio async def test_yield(): mock = Mock() @@ -277,11 +300,9 @@ async def get_logger() -> logging.Logger: async def some_func( b, a: A = Depends(dep, cast=False), - c: str = Depends(lambda: 1, cast=False), logger: logging.Logger = Depends(get_logger, cast=False), ): assert a.a == 1 - assert c == 1 assert logger return b diff --git a/tests/library/test_custom.py b/tests/library/test_custom.py index 5e68c66..043243a 100644 --- a/tests/library/test_custom.py +++ b/tests/library/test_custom.py @@ -1,6 +1,8 @@ import logging -from typing import Any, Callable +from time import monotonic_ns +from typing import Any, Dict +import anyio import pydantic import pytest from typing_extensions import Annotated @@ -10,18 +12,39 @@ class Header(CustomField): - def use(self, /, **kwargs: Callable[..., Any]) -> Callable[..., Any]: + def use(self, /, **kwargs: Any) -> Dict[str, Any]: kwargs = super().use(**kwargs) - if kwargs.get("headers", {}).get(self.param_name): - kwargs[self.param_name] = kwargs.get("headers", {}).get(self.param_name) + if v := kwargs.get("headers", {}).get(self.param_name): + kwargs[self.param_name] = v return kwargs +class FieldHeader(Header): + def __init__(self, *, cast: bool = True, required: bool = True) -> None: + super().__init__(cast=cast, required=required) + self.field = True + + def use_field(self, kwargs: Any) -> None: + if v := kwargs.get("headers", {}).get(self.param_name): + kwargs[self.param_name] = v + + class AsyncHeader(Header): - async def use(self, **kwargs: Callable[..., Any]) -> Callable[..., Any]: + async def use(self, /, **kwargs: Any) -> Dict[str, Any]: return super().use(**kwargs) +class AsyncFieldHeader(Header): + def __init__(self, *, cast: bool = True, required: bool = True) -> None: + super().__init__(cast=cast, required=required) + self.field = True + + async def use_field(self, kwargs: Any) -> None: + await anyio.sleep(0.1) + if v := kwargs.get("headers", {}).get(self.param_name): + kwargs[self.param_name] = v + + def test_header(): @inject def sync_catch(key: int = Header()): # noqa: B008 @@ -60,10 +83,33 @@ def sync_catch(key: str = Header(), key2: int = Header()): # noqa: B008 @pytest.mark.anyio async def test_async_header_async(): @inject - async def async_catch(key: float = AsyncHeader()): # noqa: B008 - return key + async def async_catch( # noqa: B008 + key: float = AsyncHeader(), key2: int = AsyncHeader() + ): + return key, key2 + + assert (await async_catch(headers={"key": "1", "key2": 1})) == (1.0, 1) - assert (await async_catch(headers={"key": "1"})) == 1.0 + +def test_sync_field_header(): + @inject + def sync_catch(key: float = FieldHeader(), key2: int = FieldHeader()): # noqa: B008 + return key, key2 + + assert sync_catch(headers={"key": "1", "key2": 1}) == (1.0, 1) + + +@pytest.mark.anyio +async def test_async_field_header(): + @inject + async def async_catch( # noqa: B008 + key: float = AsyncFieldHeader(), key2: int = AsyncFieldHeader() + ): + return key, key2 + + start = monotonic_ns() + assert (await async_catch(headers={"key": "1", "key2": 1})) == (1.0, 1) + assert (monotonic_ns() - start) / 10**9 < 0.2 def test_async_header_sync(): diff --git a/tests/marks.py b/tests/marks.py new file mode 100644 index 0000000..7e4dcaa --- /dev/null +++ b/tests/marks.py @@ -0,0 +1,7 @@ +import pytest + +from fast_depends._compat import PYDANTIC_V2 + +pydanticV1 = pytest.mark.skipif(PYDANTIC_V2, reason="requires PydanticV2") # noqa: N816 + +pydanticV2 = pytest.mark.skipif(not PYDANTIC_V2, reason="requires PydanticV1") # noqa: N816 diff --git a/tests/sync/test_cast.py b/tests/sync/test_cast.py index e2d8364..b50b133 100644 --- a/tests/sync/test_cast.py +++ b/tests/sync/test_cast.py @@ -1,10 +1,12 @@ -from typing import Dict, Tuple +from typing import Dict, Iterator, Tuple import pytest +from annotated_types import Ge from pydantic import BaseModel, Field, ValidationError from typing_extensions import Annotated from fast_depends import inject +from tests.marks import pydanticV2 def test_not_annotated(): @@ -186,16 +188,25 @@ def simple_func(a: str) -> int: assert i == 1 -def test_args_kwargs_without_cast(): - @inject(cast=False) - def simple_func( - a: int, - *args: Tuple[float, ...], - b: int, - **kwargs: Dict[str, int], - ): - return a, args, b, kwargs +def test_generator_iterator_type(): + @inject + def simple_func(a: str) -> Iterator[int]: + for _ in range(2): + yield a - assert (1.0, (2.0, 3), 3.0, {"key": 1.0}) == simple_func( - 1.0, 2.0, 3, b=3.0, key=1.0 - ) + for i in simple_func("1"): + assert i == 1 + + +@pydanticV2 +def test_multi_annotated(): + from pydantic.functional_validators import AfterValidator + + @inject() + def f(a: Annotated[int, Ge(10), AfterValidator(lambda x: x + 10)]) -> int: + return a + + with pytest.raises(ValidationError): + f(1) + + assert f(10) == 20 diff --git a/tests/sync/test_depends.py b/tests/sync/test_depends.py index fe63ac6..f1604b3 100644 --- a/tests/sync/test_depends.py +++ b/tests/sync/test_depends.py @@ -71,7 +71,7 @@ def another_func(a: int, c: D): assert another_func("3") == 6.0 -def test_cash(): +def test_cache(): mock = Mock() def nested_dep_func(): @@ -82,7 +82,10 @@ def dep_func(a=Depends(nested_dep_func)): return a @inject - def some_func(a=Depends(dep_func), b=Depends(nested_dep_func)): + def some_func( + a=Depends(dep_func), + b=Depends(nested_dep_func), + ): assert a is b return a + b @@ -90,6 +93,28 @@ def some_func(a=Depends(dep_func), b=Depends(nested_dep_func)): mock.assert_called_once() +def test_not_cache(): + mock = Mock() + + def nested_dep_func(): + mock() + return 1000 + + def dep_func(a=Depends(nested_dep_func, use_cache=False)): + return a + + @inject + def some_func( + a=Depends(dep_func, use_cache=False), + b=Depends(nested_dep_func, use_cache=False), + ): + assert a is b + return a + b + + some_func() + assert mock.call_count == 2 + + def test_yield(): mock = Mock() @@ -154,11 +179,9 @@ def get_logger() -> logging.Logger: def some_func( b, a: A = Depends(dep, cast=False), - c: str = Depends(lambda: 1, cast=False), logger: logging.Logger = Depends(get_logger, cast=False), ): assert a.a == 1 - assert c == 1 assert logger return b @@ -254,7 +277,7 @@ def dep(a): return a @inject - def func(a=Depends(partial(dep, 10))): # noqa D008 + def func(a=Depends(partial(dep, 10))): return a assert func() == 10 diff --git a/tests/test_overrides.py b/tests/test_overrides.py index 7abdbd7..da3f899 100644 --- a/tests/test_overrides.py +++ b/tests/test_overrides.py @@ -67,23 +67,6 @@ def func(d=Depends(base_dep)): assert func() == 1 -def test_override_context_with_yield(provider): - def base_dep(): - yield 1 - - def override_dep(): - yield 2 - - @inject - def func(d=Depends(base_dep)): - return d - - with provider.scope(base_dep, override_dep): - assert func() == 2 - - assert func() == 1 - - def test_sync_by_async_override(provider): def base_dep(): # pragma: no cover return 1