Skip to content

Commit

Permalink
Also validate seq inputs in the function registry
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky committed Sep 11, 2024
1 parent af239d3 commit b5f78de
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
24 changes: 24 additions & 0 deletions tripy/tests/test_function_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,30 @@ def func(n: Union[int, float]) -> int:
):
registry["test"](["hi"])

def test_error_inconsistent_sequence(self, registry):
@registry("test")
def func(n: Sequence[int]) -> int:
return sum(n)

with helper.raises(
TripyException,
match=dedent(
rf"""
Could not find an implementation for function: 'test'.
Candidate overloads were:
--> \x1b\[38;5;3m{__file__}\x1b\[0m:[0-9]+ in \x1b\[38;5;6mfunc\(\)\x1b\[0m
\|
[0-9]+ \| def func\(n: Sequence\[int\]\) \-> int:
[0-9]+ \| \.\.\.
\|\s
Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Sequence\[int\]' but got argument of type: 'List\[Union\[int, str\]\]'\.
"""
).strip(),
):
registry["test"]([1, 2, "a"])

def test_error_not_sequence(self, registry):
@registry("test")
def func(n: Sequence[Sequence[int]]) -> int:
Expand Down
9 changes: 6 additions & 3 deletions tripy/tripy/function_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,11 @@ def render_arg_type(arg: Any) -> str:
if isinstance(arg, List):
if len(arg) == 0:
return "List"
return f"List[{render_arg_type(arg[0])}]"
# catch inconsistencies this way
arg_types = {render_arg_type(member) for member in arg}
if len(arg_types) == 1:
return f"List[{list(arg_types)[0]}]"
return f"List[Union[{', '.join(arg_types)}]]"
if isinstance(arg, Tuple):
return f"Tuple[{', '.join(map(render_arg_type, arg))}]"
return type(arg).__qualname__
Expand Down Expand Up @@ -142,8 +146,7 @@ def matches_type(name: str, annotation: type, arg: Any) -> bool:
seq_arg = get_args(annotation)
if seq_arg and len(arg) > 0:
assert len(seq_arg) == 1
# We could check every member of the arg but this would result in much more iteration, especially if nested
return matches_type(name, seq_arg[0], arg[0])
return all(map(lambda member: matches_type(name, seq_arg[0], member), arg))
return True

# Forward references can be used for recursive type definitions. Warning: Has the potential for infinite looping if there is no base case!
Expand Down

0 comments on commit b5f78de

Please sign in to comment.