Skip to content

Commit

Permalink
feat: annotated_types with CustomField support (#120)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lancetnik authored Aug 11, 2024
1 parent 3d7e6f8 commit e87f39a
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 34 deletions.
2 changes: 1 addition & 1 deletion fast_depends/__about__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""FastDepends - extracted and cleared from HTTP domain FastAPI Dependency Injection System"""

__version__ = "2.4.7"
__version__ = "2.4.8"
35 changes: 21 additions & 14 deletions fast_depends/core/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,14 @@ def build_call_model(
elif get_origin(param.annotation) is Annotated:
annotated_args = get_args(param.annotation)
type_annotation = annotated_args[0]
custom_annotations = [
arg for arg in annotated_args[1:] if isinstance(arg, CUSTOM_ANNOTATIONS)
]

custom_annotations = []
regular_annotations = []
for arg in annotated_args[1:]:
if isinstance(arg, CUSTOM_ANNOTATIONS):
custom_annotations.append(arg)
else:
regular_annotations.append(arg)

assert (
len(custom_annotations) <= 1
Expand All @@ -102,7 +107,10 @@ def build_call_model(
else: # pragma: no cover
raise AssertionError("unreachable")

annotation = type_annotation
if regular_annotations:
annotation = param.annotation
else:
annotation = type_annotation
else:
annotation = param.annotation
else:
Expand All @@ -113,23 +121,22 @@ def build_call_model(
default = ()
elif param_name == "kwargs":
default = {}
elif param.default is inspect.Parameter.empty:
default = Ellipsis
else:
default = param.default

if isinstance(default, Depends):
assert (
not dep
), "You can not use `Depends` with `Annotated` and default both"
dep = default
dep, default = default, Ellipsis

elif isinstance(default, CustomField):
assert (
not custom
), "You can not use `CustomField` with `Annotated` and default both"
custom = default

elif default is inspect.Parameter.empty:
class_fields[param_name] = (annotation, ...)
custom, default = default, Ellipsis

else:
class_fields[param_name] = (annotation, default)
Expand All @@ -147,7 +154,7 @@ def build_call_model(
)

if dep.cast is True:
class_fields[param_name] = (annotation, ...)
class_fields[param_name] = (annotation, Ellipsis)

keyword_args.append(param_name)

Expand All @@ -163,7 +170,7 @@ def build_call_model(
annotation = Any

if custom.required:
class_fields[param_name] = (annotation, ...)
class_fields[param_name] = (annotation, default)

else:
class_fields[param_name] = class_fields.get(param_name, (Optional[annotation], None))
Expand All @@ -184,10 +191,10 @@ def build_call_model(

response_model: Optional[Type[ResponseModel[T]]] = None
if cast and return_annotation and return_annotation is not inspect.Parameter.empty:
response_model = create_model(
response_model = create_model( # type: ignore[call-overload]
"ResponseModel",
__config__=get_config_base(pydantic_config), # type: ignore[assignment]
response=(return_annotation, ...),
__config__=get_config_base(pydantic_config),
response=(return_annotation, Ellipsis),
)

return CallModel(
Expand Down
38 changes: 19 additions & 19 deletions tests/async/test_depends.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ async def another_dep_func(b: int, a: int = 3) -> dict: # pragma: no cover

@inject
async def some_func(
b: int, c=Depends(dep_func), d=Depends(another_dep_func)
b: int, c=Depends(dep_func), d=Depends(another_dep_func)
) -> int: # pragma: no cover
assert c is None
return b
Expand Down Expand Up @@ -108,17 +108,17 @@ async def dep_func(a):

@inject
async def some_func(
a: int,
b: int,
c: "Annotated[int, Depends(dep_func)]",
a: int,
b: int,
c: "Annotated[int, Depends(dep_func)]",
) -> float:
assert isinstance(c, int)
return a + b + c

@inject
async def another_func(
a: int,
c: "Annotated[int, Depends(dep_func)]",
a: int,
c: "Annotated[int, Depends(dep_func)]",
):
return a + c

Expand All @@ -133,17 +133,17 @@ async def adep_func(a):

@inject
async def some_func(
a: int,
b: int,
c: Annotated["float", Depends(adep_func)],
a: int,
b: int,
c: Annotated["float", Depends(adep_func)],
) -> float:
assert isinstance(c, float)
return a + b + c

@inject
async def another_func(
a: int,
c: Annotated["float", Depends(adep_func)],
a: int,
c: Annotated["float", Depends(adep_func)],
):
return a + c

Expand Down Expand Up @@ -184,8 +184,8 @@ async def dep_func(a=Depends(nested_dep_func, use_cache=False)):

@inject
async def some_func(
a=Depends(dep_func, use_cache=False),
b=Depends(nested_dep_func, use_cache=False),
a=Depends(dep_func, use_cache=False),
b=Depends(nested_dep_func, use_cache=False),
):
assert a is b
return a + b
Expand Down Expand Up @@ -361,9 +361,9 @@ async def get_logger() -> logging.Logger:

@inject
async def some_func(
b,
a: A = Depends(dep, cast=False),
logger: logging.Logger = Depends(get_logger, cast=False),
b,
a: A = Depends(dep, cast=False),
logger: logging.Logger = Depends(get_logger, cast=False),
):
assert a.a == 1
assert logger
Expand All @@ -386,9 +386,9 @@ async def get_logger() -> logging.Logger:

@inject(cast=False)
async def some_func(
b: str,
a: A = Depends(dep),
logger: logging.Logger = Depends(get_logger),
b: str,
a: A = Depends(dep),
logger: logging.Logger = Depends(get_logger),
) -> str:
assert a.a == 1
assert logger
Expand Down
16 changes: 16 additions & 0 deletions tests/library/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import anyio
import pydantic
import pytest
from annotated_types import Ge
from typing_extensions import Annotated

from fast_depends import Depends, inject
from fast_depends.library import CustomField
from tests.marks import pydanticV2


class Header(CustomField):
Expand Down Expand Up @@ -128,6 +130,20 @@ def sync_catch(key: Annotated[int, Header()]):
assert sync_catch(headers={"key": "1"}) == 1


@pydanticV2
def test_annotated_header_with_meta():
@inject
def sync_catch(key: Annotated[int, Header(), Ge(3)] = 3): # noqa: B008
return key

with pytest.raises(pydantic.ValidationError):
sync_catch(headers={"key": "2"})

assert sync_catch(headers={"key": "4"}) == 4

assert sync_catch(headers={}) == 3


def test_header_required():
@inject
def sync_catch(key2=Header()): # pragma: no cover # noqa: B008
Expand Down

0 comments on commit e87f39a

Please sign in to comment.