From b5f78ded284186e8ec8678c3b979bb8f5daa73c4 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 10 Sep 2024 20:04:13 -0400 Subject: [PATCH] Also validate seq inputs in the function registry --- tripy/tests/test_function_registry.py | 24 ++++++++++++++++++++++++ tripy/tripy/function_registry.py | 9 ++++++--- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/tripy/tests/test_function_registry.py b/tripy/tests/test_function_registry.py index 6f2613762..181eb3485 100644 --- a/tripy/tests/test_function_registry.py +++ b/tripy/tests/test_function_registry.py @@ -374,6 +374,30 @@ def func(n: Union[int, float]) -> int: ): registry["test"](["hi"]) + def test_error_inconsistent_sequence(self, registry): + @registry("test") + def func(n: Sequence[int]) -> int: + return sum(n) + + with helper.raises( + TripyException, + match=dedent( + rf""" + Could not find an implementation for function: 'test'. + Candidate overloads were: + + --> \x1b\[38;5;3m{__file__}\x1b\[0m:[0-9]+ in \x1b\[38;5;6mfunc\(\)\x1b\[0m + \| + [0-9]+ \| def func\(n: Sequence\[int\]\) \-> int: + [0-9]+ \| \.\.\. + \|\s + + Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Sequence\[int\]' but got argument of type: 'List\[Union\[int, str\]\]'\. + """ + ).strip(), + ): + registry["test"]([1, 2, "a"]) + def test_error_not_sequence(self, registry): @registry("test") def func(n: Sequence[Sequence[int]]) -> int: diff --git a/tripy/tripy/function_registry.py b/tripy/tripy/function_registry.py index 59bf492ef..68dbf50bc 100644 --- a/tripy/tripy/function_registry.py +++ b/tripy/tripy/function_registry.py @@ -109,7 +109,11 @@ def render_arg_type(arg: Any) -> str: if isinstance(arg, List): if len(arg) == 0: return "List" - return f"List[{render_arg_type(arg[0])}]" + # catch inconsistencies this way + arg_types = {render_arg_type(member) for member in arg} + if len(arg_types) == 1: + return f"List[{list(arg_types)[0]}]" + return f"List[Union[{', '.join(arg_types)}]]" if isinstance(arg, Tuple): return f"Tuple[{', '.join(map(render_arg_type, arg))}]" return type(arg).__qualname__ @@ -142,8 +146,7 @@ def matches_type(name: str, annotation: type, arg: Any) -> bool: seq_arg = get_args(annotation) if seq_arg and len(arg) > 0: assert len(seq_arg) == 1 - # We could check every member of the arg but this would result in much more iteration, especially if nested - return matches_type(name, seq_arg[0], arg[0]) + return all(map(lambda member: matches_type(name, seq_arg[0], member), arg)) return True # Forward references can be used for recursive type definitions. Warning: Has the potential for infinite looping if there is no base case!