Skip to content

Commit

Permalink
fix: Allow to use */** arguments with non-standard names
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhenay committed Aug 27, 2024
1 parent abb0e1e commit 97a1d97
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 11 deletions.
12 changes: 9 additions & 3 deletions fast_depends/core/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def build_call_model(
custom_fields: Dict[str, CustomField] = {}
positional_args: List[str] = []
keyword_args: List[str] = []
var_positional_arg: Optional[str] = None
var_keyword_arg: Optional[str] = None

for param_name, param in typed_params.parameters.items():
dep: Optional[Depends] = None
Expand Down Expand Up @@ -117,10 +119,12 @@ def build_call_model(
annotation = param.annotation

default: Any
if param_name == "args":
if param.kind == inspect.Parameter.VAR_POSITIONAL:
default = ()
elif param_name == "kwargs":
var_positional_arg = param_name
elif param.kind == inspect.Parameter.VAR_KEYWORD:
default = {}
var_keyword_arg = param_name
elif param.default is inspect.Parameter.empty:
default = Ellipsis
else:
Expand Down Expand Up @@ -180,7 +184,7 @@ def build_call_model(
else:
if param.kind is param.KEYWORD_ONLY:
keyword_args.append(param_name)
elif param_name not in ("args", "kwargs"):
elif param.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
positional_args.append(param_name)

func_model = create_model( # type: ignore[call-overload]
Expand Down Expand Up @@ -210,6 +214,8 @@ def build_call_model(
custom_fields=custom_fields,
positional_args=positional_args,
keyword_args=keyword_args,
var_positional_arg=var_positional_arg,
var_keyword_arg=var_keyword_arg,
extra_dependencies=[
build_call_model(
d.dependency,
Expand Down
25 changes: 17 additions & 8 deletions fast_depends/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class CallModel(Generic[P, T]):
custom_fields: Dict[str, CustomField]
keyword_args: Tuple[str, ...]
positional_args: Tuple[str, ...]
var_positional_arg: Optional[str]
var_keyword_arg: Optional[str]

# Dependencies and custom fields
use_cache: bool
Expand All @@ -82,6 +84,8 @@ class CallModel(Generic[P, T]):
"alias_arguments",
"keyword_args",
"positional_args",
"var_positional_arg",
"var_keyword_arg",
"dependencies",
"extra_dependencies",
"sorted_dependencies",
Expand Down Expand Up @@ -152,6 +156,8 @@ def __init__(
extra_dependencies: Optional[Iterable["CallModel[..., Any]"]] = None,
keyword_args: Optional[List[str]] = None,
positional_args: Optional[List[str]] = None,
var_positional_arg: Optional[str] = None,
var_keyword_arg: Optional[str] = None,
custom_fields: Optional[Dict[str, CustomField]] = None,
):
self.call = call
Expand All @@ -164,6 +170,8 @@ def __init__(

self.keyword_args = tuple(keyword_args or ())
self.positional_args = tuple(positional_args or ())
self.var_positional_arg = var_positional_arg
self.var_keyword_arg = var_keyword_arg
self.response_model = response_model
self.use_cache = use_cache
self.cast = cast
Expand Down Expand Up @@ -241,8 +249,8 @@ def _solve(
if (v := kwargs.pop(arg, Parameter.empty)) is not Parameter.empty:
kw[arg] = v

if "kwargs" in self.alias_arguments:
kw["kwargs"] = kwargs
if self.var_keyword_arg is not None:
kw[self.var_keyword_arg] = kwargs
else:
kw.update(kwargs)

Expand All @@ -253,8 +261,8 @@ def _solve(
break

keyword_args: Iterable[str]
if has_args := "args" in self.alias_arguments:
kw["args"] = args
if self.var_positional_arg is not None:
kw[self.var_positional_arg] = args
keyword_args = self.keyword_args

else:
Expand All @@ -281,21 +289,22 @@ def _solve(
arg: getattr(casted_model, arg, solved_kw.get(arg))
for arg in keyword_args
}
kwargs_.update(getattr(casted_model, "kwargs", {}))
if self.var_keyword_arg:
kwargs_.update(getattr(casted_model, self.var_keyword_arg, {}))

if has_args:
if self.var_positional_arg is not None:
args_ = [
getattr(casted_model, arg, solved_kw.get(arg))
for arg in self.positional_args
]
args_.extend(getattr(casted_model, "args", ()))
args_.extend(getattr(casted_model, self.var_positional_arg, ()))
else:
args_ = ()

else:
kwargs_ = {arg: solved_kw.get(arg) for arg in keyword_args}

if has_args:
if self.var_positional_arg is not None:
args_ = tuple(map(solved_kw.get, self.positional_args))
else:
args_ = ()
Expand Down
30 changes: 30 additions & 0 deletions tests/test_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,33 @@ def extra_func(n): ...

assert set(model.params.keys()) == {"a", "b"}
assert set(model.flat_params.keys()) == {"a", "b", "c", "m", "n"}


def test_args_kwargs_params():
def func1(m): ...

def func2(c, b=Depends(func1), d=CustomField()): # noqa: B008
...

def func3(b): ...

def default_var_names(a, *args, b, m=Depends(func2), k=Depends(func3), **kwargs):
return a, args, b, kwargs

def custom_var_names(a, *args_, b, m=Depends(func2), k=Depends(func3), **kwargs_):
return a, args_, b, kwargs_

def extra_func(n): ...

model1 = build_call_model(default_var_names, extra_dependencies=(Depends(extra_func),))

assert set(model1.params.keys()) == {"a", "args", "b", "kwargs"}
assert set(model1.flat_params.keys()) == {"a", "args", "b", "kwargs", "c", "m", "n"}

model2 = build_call_model(custom_var_names, extra_dependencies=(Depends(extra_func),))

assert set(model2.params.keys()) == {"a", "args_", "b", "kwargs_"}
assert set(model2.flat_params.keys()) == {"a", "args_", "b", "kwargs_", "c", "m", "n"}

assert default_var_names(1, *('a'), b=2, **{'kw': 'kw'}) == (1, ('a',), 2, {'kw': 'kw'})
assert custom_var_names(1, *('a'), b=2, **{'kw': 'kw'}) == (1, ('a',), 2, {'kw': 'kw'})

0 comments on commit 97a1d97

Please sign in to comment.