Skip to content

Commit

Permalink
Do further validation on sequence arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky committed Sep 10, 2024
1 parent 636c8fa commit af239d3
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 15 deletions.
28 changes: 21 additions & 7 deletions tripy/tests/frontend/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,14 +293,28 @@ def test_sync_arg_type_invalid(self):
):
x, y, z = sync_arg_types(3.0, 3, 4)

def test_seq_arg_invalid(self):
with helper.raises(
tp.TripyException,
match=r"Encountered non-number of type str in sequence: hello",
):
_ = func([1, 2, "hello"])

def test_nested_seq_inconsistent_len(self):
with helper.raises(
tp.TripyException,
match=r"Expected a sequence of length 3 but got length 4: \[7, 8, 9, 10\]",
):
_ = func([[1, 2, 3], [4, 5, 6], [7, 8, 9, 10]])

def test_nested_seq_inconsistent_types(self):
with helper.raises(
tp.TripyException,
match=r"Expected a sequence but got str: hi",
):
_ = func([[1, 2, 3], [4, 5, 6], "hi"])

def test_invalid_argument_type_not_converted(self):
a = np.array([1, 2, 3])
b = func(np.array([1, 2, 3]))
assert (a == b).all()

def test_invalid_argument_type_in_sequence_not_converted(self):
a = np.array([1, 2, 3])
b = np.array([4, 5, 6])
c, d = func([np.array([1, 2, 3]), np.array([4, 5, 6])])
assert (c == a).all()
assert (d == b).all()
47 changes: 39 additions & 8 deletions tripy/tripy/frontend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,19 +195,50 @@ def find_sync_target_dtype(arg_name):
return None

def convert_nontensor_arg(arg, list_index=None):
def is_valid_sequence(seq_arg: Sequence) -> bool:
from tripy.utils import Result

def is_valid_sequence(seq_arg: Sequence) -> Result:
if len(seq_arg) == 0:
return True
if isinstance(seq_arg[0], Sequence):
return is_valid_sequence(seq_arg[0])
return isinstance(seq_arg[0], numbers.Number)
return Result.ok()
# If one is a sequence, all must be sequences of the same length. Do not accept strings.
if isinstance(seq_arg[0], Sequence) and not isinstance(seq_arg[0], str):
target_len = len(seq_arg[0])
for inner_arg in seq_arg[1:]:
if not isinstance(inner_arg, Sequence) or isinstance(inner_arg, str):
return Result.err(
[f"Expected a sequence but got {type(inner_arg).__qualname__}: {inner_arg}"]
)
if len(inner_arg) != target_len:
return Result.err(
[
f"Expected a sequence of length {target_len} but got length {len(inner_arg)}: {inner_arg}"
]
)
valid_inner = is_valid_sequence(inner_arg)
if not valid_inner:
return valid_inner
return Result.ok()
# Otherwise check for numbers.
for inner_arg in seq_arg:
if not isinstance(inner_arg, numbers.Number):
return Result.err(
[
f"Encountered non-number of type {type(inner_arg).__qualname__} in sequence: {inner_arg}"
]
)
return Result.ok()

# simply do not convert in these cases and let the registry give an error instead
if not isinstance(arg, numbers.Number) and not (
isinstance(arg, Sequence) and is_valid_sequence(arg)
):
if not isinstance(arg, numbers.Number) and not isinstance(arg, Sequence):
return arg

if isinstance(arg, Sequence):
valid_sequence = is_valid_sequence(arg)
if not valid_sequence:
raise_error(
f"Encountered invalid sequence argument: {arg}", details=valid_sequence.error_details
)

cast_dtype = find_sync_target_dtype(name)
return add_column_info_for_non_tensor(
arg, index, is_kwarg=name in kwargs, dtype=cast_dtype, list_index=list_index
Expand Down

0 comments on commit af239d3

Please sign in to comment.