diff --git a/tripy/examples/nanogpt/weight_loader.py b/tripy/examples/nanogpt/weight_loader.py index 19018348..c1db8575 100644 --- a/tripy/examples/nanogpt/weight_loader.py +++ b/tripy/examples/nanogpt/weight_loader.py @@ -1,4 +1,3 @@ - # # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 diff --git a/tripy/tests/backend/mlir/test_compiler.py b/tripy/tests/backend/mlir/test_compiler.py index 1a120ac2..f04ab1a2 100644 --- a/tripy/tests/backend/mlir/test_compiler.py +++ b/tripy/tests/backend/mlir/test_compiler.py @@ -53,8 +53,8 @@ def test_reason_context(self): with FlatIRTensor.context(["This is the first level of context"]): with FlatIRTensor.context(["This is the second level of context"]): # We need to emit an error from one of the internally created `FlatIRTensor`s to see the context - a = tp.ones(1) - b = tp.ones(1) + a = tp.ones((1,)) + b = tp.ones((1,)) trace = Trace([a + b]) flat_ir = trace.to_flat_ir() producer = flat_ir.outputs[0].producer.inputs[0] diff --git a/tripy/tests/frontend/test_shape.py b/tripy/tests/frontend/test_shape.py index d55c95a0..5f01fb34 100644 --- a/tripy/tests/frontend/test_shape.py +++ b/tripy/tests/frontend/test_shape.py @@ -29,11 +29,10 @@ def values(request): @pytest.fixture( - params=[[4, 5], tp.Tensor([4, 5], dtype=tp.int32), np.array([4, 5], dtype=np.int32)], + params=[[4, 5], tp.Tensor([4, 5], dtype=tp.int32)], ids=[ "python_list", "tripy_tensor", - "numpy_array", ], ) def other_values(request): @@ -95,7 +94,7 @@ def test_plus_override(self, values, other_values): appended = [4, 5] s = tp.Shape(values) - # conversion is implicit except for tp.Tensor + # conversion must be explicit for tp.Tensor lhs_shape = other_values if not isinstance(other_values, tp.Tensor) else tp.Shape(other_values) new_shape = s + lhs_shape assert isinstance(new_shape, tp.Shape) @@ -308,7 +307,7 @@ def test_right_addition(self, other_values): appended = [4, 5] s = tp.Shape(values) - # conversion is implicit except for tp.Tensor + # conversion must be explicit for tp.Tensor rhs_shape = other_values if not isinstance(other_values, tp.Tensor) else tp.Shape(other_values) new_shape = rhs_shape + s @@ -418,19 +417,29 @@ def test_invalid_input_rank_tensor(self): with raises(tp.TripyException, match="Shape tensors must be of rank 1, but input tensor is rank 2"): _ = tp.Shape(tp.ones((3, 2), dtype=tp.int32)) + def test_invalid_mul_sequence(self, values): + s = tp.Shape(values) + with raises(tp.TripyException, match="Attempting to multiply a Tripy Shape by a sequence, which is undefined"): + _ = s * values + def test_invalid_mul_rank(self, values): s = tp.Shape(values) + t = tp.Tensor(values) with raises( tp.TripyException, match="Attempting to multiply a Tripy Shape by a tensor of rank >= 1, which is undefined" ): - _ = s * values + _ = s * t def test_invalid_plus_type(self, values): s = tp.Shape(values) t = tp.Tensor(values, dtype=tp.int32) with raises( tp.TripyException, - match="Attempting to add a Tripy Tensor to a Tripy Shape, which is not allowed. Consider calling tp.Shape explicitly", + match=( + "Invalid types for addition with a Tripy Shape." + r"\s*Implicit conversions are done only for sequences of Python ints. " + "Consider calling tp.Shape for an explicit conversion." + ), ): s + t @@ -439,7 +448,11 @@ def test_invalid_right_addition_type(self, values): t = tp.Tensor(values, dtype=tp.int32) with raises( tp.TripyException, - match="Attempting to add a Tripy Tensor to a Tripy Shape, which is not allowed. Consider calling tp.Shape explicitly", + match=( + "Invalid types for addition with a Tripy Shape." + r"\s*Implicit conversions are done only for sequences of Python ints. " + "Consider calling tp.Shape for an explicit conversion." + ), ): t + s @@ -469,8 +482,6 @@ def test_unary_elementwise_fails_at_run_time(self, values): def test_shape_equality(self, other_values): a = tp.Shape([4, 5]) - if isinstance(other_values, np.ndarray): - pytest.skip("numpy array cannot be implicitly cast to Shape type") eq = a == other_values assert isinstance(eq, bool) assert eq diff --git a/tripy/tests/frontend/test_utils.py b/tripy/tests/frontend/test_utils.py index 90459a01..19f1ec67 100644 --- a/tripy/tests/frontend/test_utils.py +++ b/tripy/tests/frontend/test_utils.py @@ -16,6 +16,7 @@ # import cupy as cp +import numpy as np import tripy as tp from tripy.frontend.utils import convert_inputs_to_tensors @@ -291,3 +292,29 @@ def test_sync_arg_type_invalid(self): match=r"At least one of the arguments: \('a', 'b', 'c'\) must be a \`tripy.Tensor\`.", ): x, y, z = sync_arg_types(3.0, 3, 4) + + def test_seq_arg_invalid(self): + with helper.raises( + tp.TripyException, + match=r"Encountered non-number of type str in sequence: hello", + ): + _ = func([1, 2, "hello"]) + + def test_nested_seq_inconsistent_len(self): + with helper.raises( + tp.TripyException, + match=r"Expected a sequence of length 3 but got length 4: \[7, 8, 9, 10\]", + ): + _ = func([[1, 2, 3], [4, 5, 6], [7, 8, 9, 10]]) + + def test_nested_seq_inconsistent_types(self): + with helper.raises( + tp.TripyException, + match=r"Expected a sequence but got str: hi", + ): + _ = func([[1, 2, 3], [4, 5, 6], "hi"]) + + def test_invalid_argument_type_not_converted(self): + a = np.array([1, 2, 3]) + b = func(np.array([1, 2, 3])) + assert (a == b).all() diff --git a/tripy/tests/frontend/trace/ops/test_binary_elementwise.py b/tripy/tests/frontend/trace/ops/test_binary_elementwise.py index b714ea6c..4deb3cae 100644 --- a/tripy/tests/frontend/trace/ops/test_binary_elementwise.py +++ b/tripy/tests/frontend/trace/ops/test_binary_elementwise.py @@ -59,10 +59,8 @@ class TestBinaryElementwise: "lhs, rhs, left_side_is_non_tensor", [ (tp.Tensor([1.0]), tp.Tensor([2.0]), False), - (tp.Tensor([1.0]), np.array([2.0], dtype=np.float32), False), # shape of (0,) is broadcastable with (1,) (tp.Tensor([], dtype=tp.float32), tp.Tensor([1.0], dtype=tp.float32), False), - (np.array([1.0], dtype=np.float32), tp.Tensor([2.0]), True), (tp.Tensor([1.0]), 2.0, False), (1.0, tp.Tensor([2.0]), True), ], @@ -86,8 +84,6 @@ def test_op_funcs(self, kind, func, lhs, rhs, left_side_is_non_tensor): "lhs, rhs, expected_rank", [ (tp.Tensor([1.0]), tp.Tensor([2.0]), 1), - (tp.Tensor([1.0]), np.array([2.0], dtype=np.float32), 1), - (np.array([1.0], dtype=np.float32), tp.Tensor([2.0]), 1), (tp.Tensor([1.0]), 2.0, 1), (1.0, tp.Tensor([2.0]), 1), (tp.ones((2, 3)), 2.0, 2), diff --git a/tripy/tests/integration/test_functional.py b/tripy/tests/integration/test_functional.py index b3986827..cab3151c 100644 --- a/tripy/tests/integration/test_functional.py +++ b/tripy/tests/integration/test_functional.py @@ -130,36 +130,3 @@ def test_multiple_copy_2(self): out = tp.copy(a, tp.device("gpu")) assert out.tolist() == [1, 2] assert out.device.kind == "gpu" - - -class TestConversionToTripyType: - @pytest.mark.parametrize( - "reverse_direction", - [False, True], - ) - @pytest.mark.parametrize( - "input0", - [cp.ones((2, 3), dtype=cp.float32), cp.ones((3,), dtype=np.float32)], - ) - @pytest.mark.parametrize( - "input1", - [ - [ - 4.0, - ], - (5.0,), - cp.array([4.0], dtype=cp.float32), - cp.ones((1, 3), dtype=cp.float32), - torch.Tensor([[4.0]]), - ], - ) - def test_element_wise_prod(self, reverse_direction, input0, input1): - a = tp.Tensor(input0) - if isinstance(input1, torch.Tensor): - input1 = input1.to("cuda") - if reverse_direction: - out = input1 * a - input0, input1 = input1, input0 - else: - out = a * input1 - assert cp.array_equal(cp.from_dlpack(out), cp.array(input0) * cp.array(input1)) diff --git a/tripy/tests/test_function_registry.py b/tripy/tests/test_function_registry.py index b0648d3e..574b2d5d 100644 --- a/tripy/tests/test_function_registry.py +++ b/tripy/tests/test_function_registry.py @@ -17,10 +17,12 @@ import inspect from textwrap import dedent -from typing import Any, Dict, List +from typing import Any, Dict, List, Sequence, Union import pytest +from tests import helper + from tripy import TripyException from tripy.function_registry import AnnotationInfo, FunctionRegistry @@ -61,31 +63,29 @@ def test_overloading_string_annotations(self, int_float_registry): assert int_float_registry["transform"](1.0) == 0.0 def test_error_on_missing_overload(self, int_float_registry): - with pytest.raises( + # Note presence of ANSI color codes. Also note that the last | has a space after it + with helper.raises( TripyException, match=dedent( rf""" - Could not find an implementation for function: 'transform'. - Note: Argument types were: \[str\]. + Could not find an implementation for function: 'transform'\. Candidate overloads were: - --> {__file__}:[0-9]+ - | - [0-9]+ | \@registry\("transform"\) - [0-9]+ | def transform_int\(a: "int"\): - [0-9]+ | return a \+ 1 - | + --> \x1b\[38;5;3m{__file__}\x1b\[0m:[0-9]+ in \x1b\[38;5;6mtransform_int\(\)\x1b\[0m + \| + [0-9]+ \| def transform_int\(a: \"int\"\): + [0-9]+ \| \.\.\. + \|\s - Not a valid overload because: For parameter: 'a', expected an instance of type: 'int' but got argument of type: 'str'. + Not a valid overload because: For parameter: 'a', expected an instance of type: 'int' but got argument of type: 'str'\. - --> {__file__}:[0-9]+ - | - [0-9]+ | \@registry\("transform"\) - [0-9]+ | def transform_float\(a: "float"\): - [0-9]+ | return a \- 1 - | + --> \x1b\[38;5;3m{__file__}\x1b\[0m:[0-9]+ in \x1b\[38;5;6mtransform_float\(\)\x1b\[0m + \| + [0-9]+ \| def transform_float\(a: "float"\): + [0-9]+ \| \.\.\. + \|\s - Not a valid overload because: For parameter: 'a', expected an instance of type: 'float' but got argument of type: 'str'. + Not a valid overload because: For parameter: 'a', expected an instance of type: 'float' but got argument of type: 'str'\. """ ).strip(), ): @@ -96,22 +96,20 @@ def test_error_when_kwargs_wrong(self, registry): def func(a: int, b: int, c: int): return a + b + c - with pytest.raises( + with helper.raises( TripyException, match=dedent( rf""" Could not find an implementation for function: 'test'. - Note: Argument types were: \[int, b=int, c=float\]. Candidate overloads were: - --> {__file__}:[0-9]+ - | - [0-9]+ | \@registry\("test"\) - [0-9]+ | def func\(a: int, b: int, c: int\): - [0-9]+ | return a \+ b \+ c - | + --> \x1b\[38;5;3m{__file__}\x1b\[0m:[0-9]+ in \x1b\[38;5;6mfunc\(\)\x1b\[0m + \| + [0-9]+ \| def func\(a: int, b: int, c: int\): + [0-9]+ \| \.\.\. + \|\s - Not a valid overload because: For parameter: 'c', expected an instance of type: 'int' but got argument of type: 'float'. + Not a valid overload because: For parameter: 'c', expected an instance of type: 'int' but got argument of type: 'float'\. """ ).strip(), ): @@ -122,7 +120,7 @@ def test_invalid_string_annotation_fails_gracefully(self, registry): def func(a: "not_a_real_type"): pass - with pytest.raises( + with helper.raises( NameError, match="Error while evaluating type annotation: 'not_a_real_type' for parameter: 'a' of function: 'func'." "\nNote: Error was: name 'not_a_real_type' is not defined", @@ -162,27 +160,24 @@ def func(a: int): def func(b: int): return b - 1 - with pytest.raises( + with helper.raises( TripyException, match=dedent( rf""" Ambiguous overload for function: 'test'. - Note: Argument types were: \[int\]. Candidate overloads were: - --> {__file__}:[0-9]+ - | - [0-9]+ | \@registry\("test"\) - [0-9]+ | def func\(a: int\): - [0-9]+ | return a \+ 1 - | - - --> {__file__}:[0-9]+ - | - [0-9]+ | \@registry\("test"\) - [0-9]+ | def func\(b: int\): - [0-9]+ | return b \- 1 - | + --> \x1b\[38;5;3m{__file__}\x1b\[0m:[0-9]+ in \x1b\[38;5;6mfunc\(\)\x1b\[0m + \| + [0-9]+ \| def func\(a: int\): + [0-9]+ \| \.\.\. + \|\s + + --> \x1b\[38;5;3m{__file__}\x1b\[0m:[0-9]+ in \x1b\[38;5;6mfunc\(\)\x1b\[0m + \| + [0-9]+ \| def func\(b: int\): + [0-9]+ \| \.\.\. + \|\s """ ).strip(), ): @@ -193,7 +188,7 @@ def test_missing_arguments_gives_sane_error(self, registry): def func(a: int, b: int): return a + b - with pytest.raises(TripyException, match="Some required arguments were not provided: \['a'\]"): + with helper.raises(TripyException, match="Some required arguments were not provided: \['a'\]"): registry["test"](b=0) def test_func_overload_caches_signature(self, registry): @@ -270,3 +265,255 @@ def func(**kwargs: Dict[str, Any]): return sum(kwargs.values()) assert registry["test"](a=1.0, b=2.0, c=3.0) == 6.0 + + def test_sequence_check(self, registry): + @registry("test") + def func(int_seq: Sequence[int]) -> int: + return sum(int_seq) + + assert registry["test"]([1, 2, 3]) == 6 + # empty should work too + assert registry["test"]([]) == 0 + + def test_sequence_no_arg_check(self, registry): + @registry("test") + def func(seq: Sequence) -> int: + return len(seq) + + assert registry["test"]([1, 2, 3]) == 3 + assert registry["test"](["a", "b"]) == 2 + assert registry["test"]([]) == 0 + + def test_union_check(self, registry): + @registry("test") + def func(n: Union[int, Sequence[int]]) -> int: + if isinstance(n, int): + return n + return sum(n) + + assert registry["test"]([1, 2, 3]) == 6 + assert registry["test"](6) == 6 + + def test_nested_sequence_check(self, registry): + @registry("test") + def func(n: Sequence[Sequence[int]]) -> int: + if n and n[0]: + return n[0][0] + return -1 + + assert registry["test"]([[1, 2, 3], [4, 5, 6]]) == 1 + + def test_nested_union_and_sequence_check(self, registry): + @registry("test") + def func(n: Sequence[Union[int, Sequence[int]]]) -> int: + if len(n) == 0: + return 0 + if isinstance(n[0], Sequence): + return len(n) * len(n[0]) + return len(n) + + assert registry["test"]([]) == 0 + assert registry["test"]([1, 2, 3]) == 3 + assert registry["test"]([[1, 2], [3, 4], [5, 6]]) == 6 + + def test_number_array(self, registry): + @registry("test") + def func(n: "tripy.types.NestedNumberSequence"): + return n + + assert registry["test"](1) == 1 + assert registry["test"](2.0) == 2.0 + assert registry["test"]([1, 2, 3]) == [1, 2, 3] + assert registry["test"]([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]) == [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + + def test_error_sequence(self, registry): + @registry("test") + def func(n: Sequence[int]) -> int: + return sum(n) + + with helper.raises( + TripyException, + match=dedent( + rf""" + Could not find an implementation for function: 'test'\. + Candidate overloads were: + + --> \x1b\[38;5;3m{__file__}\x1b\[0m:[0-9]+ in \x1b\[38;5;6mfunc\(\)\x1b\[0m + \| + [0-9]+ \| def func\(n: Sequence\[int\]\) \-> int: + [0-9]+ \| \.\.\. + \|\s + + Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Sequence\[int\]' but got argument of type: 'List\[float\]'\. + """ + ).strip(), + ): + registry["test"]([1.0, 2.0, 3.0]) + + def test_error_union(self, registry): + @registry("test") + def func(n: Union[int, float]) -> int: + return 0 + + with helper.raises( + TripyException, + match=dedent( + rf""" + Could not find an implementation for function: 'test'. + Candidate overloads were: + + --> \x1b\[38;5;3m{__file__}\x1b\[0m:[0-9]+ in \x1b\[38;5;6mfunc\(\)\x1b\[0m + \| + [0-9]+ \| def func\(n: Union\[int, float\]\) \-> int: + [0-9]+ \| \.\.\. + \|\s + + Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Union\[int, float\]' but got argument of type: 'List\[str\]'\. + """ + ).strip(), + ): + registry["test"](["hi"]) + + def test_error_inconsistent_sequence(self, registry): + @registry("test") + def func(n: Sequence[int]) -> int: + return sum(n) + + with helper.raises( + TripyException, + match=dedent( + rf""" + Could not find an implementation for function: 'test'. + Candidate overloads were: + + --> \x1b\[38;5;3m{__file__}\x1b\[0m:[0-9]+ in \x1b\[38;5;6mfunc\(\)\x1b\[0m + \| + [0-9]+ \| def func\(n: Sequence\[int\]\) \-> int: + [0-9]+ \| \.\.\. + \|\s + + Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Sequence\[int\]' but got argument of type: 'List\[Union\[(int, str)|(str, int)\]\]'\. + """ + ).strip(), + ): + registry["test"]([1, 2, "a"]) + + def test_error_not_sequence(self, registry): + @registry("test") + def func(n: Sequence[Sequence[int]]) -> int: + return sum(sum(n)) + + with helper.raises( + TripyException, + match=dedent( + rf""" + Could not find an implementation for function: 'test'. + Candidate overloads were: + + --> \x1b\[38;5;3m{__file__}\x1b\[0m:[0-9]+ in \x1b\[38;5;6mfunc\(\)\x1b\[0m + \| + [0-9]+ \| def func\(n: Sequence\[Sequence\[int\]\]\) \-> int: + [0-9]+ \| \.\.\. + \|\s + + Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Sequence\[typing\.Sequence\[int\]\]' but got argument of type: 'List\[int\]'\. + """ + ).strip(), + ): + registry["test"]([1, 2, 3]) + + def test_error_nested_sequence(self, registry): + @registry("test") + def func(n: Sequence[Sequence[int]]) -> int: + return sum(sum(n)) + + with helper.raises( + TripyException, + match=dedent( + rf""" + Could not find an implementation for function: 'test'. + Candidate overloads were: + + --> \x1b\[38;5;3m{__file__}\x1b\[0m:[0-9]+ in \x1b\[38;5;6mfunc\(\)\x1b\[0m + \| + [0-9]+ \| def func\(n: Sequence\[Sequence\[int\]\]\) \-> int: + [0-9]+ \| \.\.\. + \|\s + + Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Sequence\[typing\.Sequence\[int\]\]' but got argument of type: 'List\[List\[float\]\]'\. + """ + ).strip(), + ): + registry["test"]([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + + def test_error_nested_union_and_sequence(self, registry): + @registry("test") + def func(n: Sequence[Union[int, float]]) -> int: + return sum(n) + + with helper.raises( + TripyException, + match=dedent( + rf""" + Could not find an implementation for function: 'test'. + Candidate overloads were: + + --> \x1b\[38;5;3m{__file__}\x1b\[0m:[0-9]+ in \x1b\[38;5;6mfunc\(\)\x1b\[0m + \| + [0-9]+ \| def func\(n: Sequence\[Union\[int, float\]\]\) \-> int: + [0-9]+ \| \.\.\. + \|\s + + Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Sequence\[typing\.Union\[int, float\]\]' but got argument of type: 'List\[str\]'\. + """ + ).strip(), + ): + registry["test"](["a", "b", "c"]) + + def test_error_number_array_not_sequence(self, registry): + @registry("test") + def func(n: "tripy.types.NestedNumberSequence"): + return n + + with helper.raises( + TripyException, + match=dedent( + rf""" + Could not find an implementation for function: 'test'. + Candidate overloads were: + + --> \x1b\[38;5;3m{__file__}\x1b\[0m:[0-9]+ in \x1b\[38;5;6mfunc\(\)\x1b\[0m + \| + [0-9]+ \| def func\(n: \"tripy\.types\.NestedNumberSequence\"\): + [0-9]+ \| \.\.\. + \|\s + + Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Union\[numbers\.Number, typing\.Sequence\[ForwardRef\('tripy\.types\.NestedNumberSequence'\)\]\]' but got argument of type: 'str'\. + """ + ).strip(), + ): + registry["test"]("hi") + + def test_error_number_array_not_sequence_of_numbers(self, registry): + @registry("test") + def func(n: "tripy.types.NestedNumberSequence"): + return n + + with helper.raises( + TripyException, + match=dedent( + rf""" + Could not find an implementation for function: 'test'\. + Candidate overloads were: + + --> \x1b\[38;5;3m{__file__}\x1b\[0m:[0-9]+ in \x1b\[38;5;6mfunc\(\)\x1b\[0m + \| + [0-9]+ \| def func\(n: \"tripy\.types\.NestedNumberSequence\"\): + [0-9]+ \| \.\.\. + \|\s + + Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Union\[numbers\.Number, typing\.Sequence\[ForwardRef\('tripy\.types\.NestedNumberSequence'\)\]\]' but got argument of type: 'List\[List\[str\]\]' + """ + ).strip(), + ): + registry["test"]([["a"], ["b"], ["c"]]) diff --git a/tripy/tripy/frontend/module/parameter.py b/tripy/tripy/frontend/module/parameter.py index 5ddb99b6..c872f6ce 100644 --- a/tripy/tripy/frontend/module/parameter.py +++ b/tripy/tripy/frontend/module/parameter.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Sequence +from typing import Any, Sequence import tripy.frontend.utils as frontend_utils from tripy import export, utils @@ -31,10 +31,12 @@ class Parameter(Tensor): """ @frontend_utils.convert_inputs_to_tensors() - def __init__(self, tensor: "tripy.Tensor") -> None: + def __init__(self, tensor: Any) -> None: """ Args: - tensor: The tensor value for this parameter. + tensor: + The tensor value for this parameter. If provided as an external data format (e.g., a Numpy array), + it will be converted into a Tripy Tensor. .. code-block:: python :linenos: @@ -45,7 +47,11 @@ def __init__(self, tensor: "tripy.Tensor") -> None: assert isinstance(parameter, tp.Parameter) assert isinstance(parameter, tp.Tensor) """ - self.__dict__ = tensor.__dict__ + t = tensor + # for convenience, this will convert other dlpack-supporting representations too + if not isinstance(t, Tensor): + t = Tensor(t) + self.__dict__ = t.__dict__ def _is_compatible_helper(self, original_shape, other_shape, original_dtype, other_dtype) -> Result: if original_shape != other_shape: diff --git a/tripy/tripy/frontend/shape.py b/tripy/tripy/frontend/shape.py index 618ddeef..898c24fd 100644 --- a/tripy/tripy/frontend/shape.py +++ b/tripy/tripy/frontend/shape.py @@ -155,25 +155,32 @@ def __str__(self) -> str: # addition for shapes is concatenation, not tensor addition + def _validate_add_argument(self, other): + if isinstance(other, Shape): + return + if not isinstance(other, Sequence) or (len(other) != 0 and not isinstance(other[0], int)): + raise_error( + "Invalid types for addition with a Tripy Shape.", + details=[ + "Implicit conversions are done only for sequences of Python ints. ", + "Consider calling tp.Shape for an explicit conversion. ", + f"Note: argument was {other}.", + ], + ) + def __add__(self, other): from tripy.frontend.trace.ops.concatenate import concatenate - if not isinstance(other, Shape) and isinstance(other, Tensor): - raise_error( - "Attempting to add a Tripy Tensor to a Tripy Shape, which is not allowed. Consider calling tp.Shape explicitly" - ) - elif not isinstance(other, Shape): + self._validate_add_argument(other) + if not isinstance(other, Shape): other = Shape(other) return concatenate([self, other], 0) def __radd__(self, other): from tripy.frontend.trace.ops.concatenate import concatenate - if not isinstance(other, Shape) and isinstance(other, Tensor): - raise_error( - "Attempting to add a Tripy Tensor to a Tripy Shape, which is not allowed. Consider calling tp.Shape explicitly" - ) - elif not isinstance(other, Shape): + self._validate_add_argument(other) + if not isinstance(other, Shape): other = Shape(other) return concatenate([other, self], 0) @@ -191,10 +198,26 @@ def __mul__(self, other): # Only defined with a scalar argument if not isinstance(other, Tensor): + # note: Python does not accept floats as arguments for list multiplication either + if isinstance(other, Sequence): + raise_error( + "Attempting to multiply a Tripy Shape by a sequence, which is undefined", + details=[f"Note: argument was {other}."], + ) + if not isinstance(other, int): + raise_error( + "Invalid types for multplication with a Tripy Shape.", + details=[ + "Implicit conversions are done only for Python ints. ", + "Consider calling tp.Shape for an explicit conversion. ", + f"Note: argument was: {other}.", + ], + ) other = Tensor(other, dtype=int32) if other.rank >= 1: raise_error( - "Attempting to multiply a Tripy Shape by a tensor of rank >= 1, which is undefined", details=[other] + "Attempting to multiply a Tripy Shape by a tensor of rank >= 1, which is undefined", + details=[f"Note: argument was {other}."], ) # note: in Python, if a list is multiplied by a negative number, this is the same as multiplying by 0, # so we should clamp the argument diff --git a/tripy/tripy/frontend/trace/ops/binary_elementwise.py b/tripy/tripy/frontend/trace/ops/binary_elementwise.py index 884b6acc..8f07d37a 100644 --- a/tripy/tripy/frontend/trace/ops/binary_elementwise.py +++ b/tripy/tripy/frontend/trace/ops/binary_elementwise.py @@ -183,7 +183,10 @@ def to_flat_ir(self, inputs, outputs): dtype_constraints={"other": "T1", constraints.RETURN_VALUE: "T1"}, function_name="__radd__", ) -def __add__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", Any]) -> "tripy.Tensor": +def __add__( + self: "tripy.types.TensorLike", + other: "tripy.types.TensorLike", +) -> "tripy.Tensor": """ Performs an elementwise sum. @@ -214,7 +217,10 @@ def __add__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", Any]) dtype_variables={"T1": ["float32", "float16", "bfloat16", "float8", "int8", "int32", "int64"]}, dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"}, ) -def __sub__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", Any]) -> "tripy.Tensor": +def __sub__( + self: "tripy.types.TensorLike", + other: "tripy.types.TensorLike", +) -> "tripy.Tensor": """ Performs an elementwise subtraction. @@ -245,7 +251,7 @@ def __sub__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", Any]) dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "float8", "int32", "int64"]}, dtype_constraints={"other": "T1", constraints.RETURN_VALUE: "T1"}, ) -def __rsub__(self: numbers.Number, other: Union["tripy.Tensor", Any]) -> "tripy.Tensor": +def __rsub__(self: "tripy.types.NestedNumberSequence", other: "tripy.types.TensorLike") -> "tripy.Tensor": """ Performs an elementwise subtraction. @@ -276,7 +282,10 @@ def __rsub__(self: numbers.Number, other: Union["tripy.Tensor", Any]) -> "tripy. dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8"]}, dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"}, ) -def __pow__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", Any]) -> "tripy.Tensor": +def __pow__( + self: "tripy.types.TensorLike", + other: "tripy.types.TensorLike", +) -> "tripy.Tensor": """ Performs an elementwise exponentiation. @@ -345,7 +354,10 @@ def __rpow__(self: numbers.Number, other: Union["tripy.Tensor", Any]) -> "tripy. dtype_constraints={"other": "T1", constraints.RETURN_VALUE: "T1"}, function_name="__rmul__", ) -def __mul__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", Any]) -> "tripy.Tensor": +def __mul__( + self: "tripy.types.TensorLike", + other: "tripy.types.TensorLike", +) -> "tripy.Tensor": """ Performs an elementwise multiplication. @@ -376,7 +388,10 @@ def __mul__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", Any]) dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64"]}, dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"}, ) -def __truediv__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", Any]) -> "tripy.Tensor": +def __truediv__( + self: "tripy.types.TensorLike", + other: "tripy.types.TensorLike", +) -> "tripy.Tensor": """ Performs an elementwise division. @@ -407,7 +422,7 @@ def __truediv__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", A dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64"]}, dtype_constraints={"other": "T1", constraints.RETURN_VALUE: "T1"}, ) -def __rtruediv__(self: numbers.Number, other: Union["tripy.Tensor", Any]) -> "tripy.Tensor": +def __rtruediv__(self: numbers.Number, other: "tripy.types.TensorLike") -> "tripy.Tensor": """ Performs an elementwise division. @@ -438,7 +453,7 @@ def __rtruediv__(self: numbers.Number, other: Union["tripy.Tensor", Any]) -> "tr dtype_variables={"T1": ["float32", "float16", "bfloat16", "float8", "int8", "int32", "int64", "bool"]}, dtype_constraints={"lhs": "T1", "rhs": "T1", constraints.RETURN_VALUE: "T1"}, ) -def maximum(lhs: Union["tripy.Tensor", Any], rhs: Union["tripy.Tensor", Any]) -> "tripy.Tensor": +def maximum(lhs: "tripy.types.TensorLike", rhs: "tripy.types.TensorLike") -> "tripy.Tensor": """ Performs an elementwise maximum. @@ -469,7 +484,7 @@ def maximum(lhs: Union["tripy.Tensor", Any], rhs: Union["tripy.Tensor", Any]) -> dtype_variables={"T1": ["float32", "float16", "bfloat16", "float8", "int8", "int32", "int64", "bool"]}, dtype_constraints={"lhs": "T1", "rhs": "T1", constraints.RETURN_VALUE: "T1"}, ) -def minimum(lhs: Union["tripy.Tensor", Any], rhs: Union["tripy.Tensor", Any]) -> "tripy.Tensor": +def minimum(lhs: "tripy.types.TensorLike", rhs: "tripy.types.TensorLike") -> "tripy.Tensor": """ Performs an elementwise minimum. @@ -503,7 +518,10 @@ def minimum(lhs: Union["tripy.Tensor", Any], rhs: Union["tripy.Tensor", Any]) -> }, dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"}, ) -def __lt__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", Any]) -> "tripy.Tensor": +def __lt__( + self: "tripy.types.TensorLike", + other: "tripy.types.TensorLike", +) -> "tripy.Tensor": """ Performs a 'less than' comparison. @@ -537,7 +555,10 @@ def __lt__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", Any]) }, dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"}, ) -def __le__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", Any]) -> "tripy.Tensor": +def __le__( + self: "tripy.types.TensorLike", + other: "tripy.types.TensorLike", +) -> "tripy.Tensor": """ Performs a 'less than or equal' comparison. @@ -571,7 +592,10 @@ def __le__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", Any]) }, dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"}, ) -def __eq__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", Any]) -> "tripy.Tensor": +def __eq__( + self: "tripy.types.TensorLike", + other: "tripy.types.TensorLike", +) -> "tripy.Tensor": """ Performs an 'equal' comparison. @@ -605,7 +629,7 @@ def __eq__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", Any]) }, dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"}, ) -def __ne__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", Any]) -> "tripy.Tensor": +def __ne__(self: "tripy.types.TensorLike", other: Union["tripy.Tensor", Any]) -> "tripy.Tensor": """ Performs a 'not equal' comparison. @@ -639,7 +663,10 @@ def __ne__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", Any]) }, dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"}, ) -def __ge__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", Any]) -> "tripy.Tensor": +def __ge__( + self: "tripy.types.TensorLike", + other: "tripy.types.TensorLike", +) -> "tripy.Tensor": """ Performs a 'greater than or equal' comparison. @@ -673,7 +700,10 @@ def __ge__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", Any]) }, dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"}, ) -def __gt__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", Any]) -> "tripy.Tensor": +def __gt__( + self: "tripy.types.TensorLike", + other: "tripy.types.TensorLike", +) -> "tripy.Tensor": """ Performs a 'greater than' comparison. diff --git a/tripy/tripy/frontend/trace/ops/dequantize.py b/tripy/tripy/frontend/trace/ops/dequantize.py index b5b74240..7d78aa00 100644 --- a/tripy/tripy/frontend/trace/ops/dequantize.py +++ b/tripy/tripy/frontend/trace/ops/dequantize.py @@ -111,7 +111,7 @@ def to_flat_ir(self, inputs, outputs): @frontend_utils.convert_inputs_to_tensors(exclude=["dtype", "dim"]) def dequantize( input: "tripy.Tensor", - scale: Union["tripy.Tensor", Any], + scale: "tripy.types.TensorLike", dtype: datatype.dtype, dim: Union[int, Any] = None, ) -> "tripy.Tensor": diff --git a/tripy/tripy/frontend/trace/ops/quantize.py b/tripy/tripy/frontend/trace/ops/quantize.py index d1a57e14..af309756 100644 --- a/tripy/tripy/frontend/trace/ops/quantize.py +++ b/tripy/tripy/frontend/trace/ops/quantize.py @@ -135,7 +135,7 @@ def to_flat_ir(self, inputs, outputs): ) def quantize( input: "tripy.Tensor", - scale: Union["tripy.Tensor", Any], + scale: "tripy.types.TensorLike", dtype: datatype.dtype, dim: Union[int, Any] = None, ) -> "tripy.Tensor": diff --git a/tripy/tripy/frontend/utils.py b/tripy/tripy/frontend/utils.py index 0f242715..9de46020 100644 --- a/tripy/tripy/frontend/utils.py +++ b/tripy/tripy/frontend/utils.py @@ -17,6 +17,7 @@ import functools from collections import deque +import numbers from typing import List, Optional, Sequence, Tuple, Union from tripy import utils @@ -25,7 +26,7 @@ from tripy.frontend.trace.ops import BaseTraceOp -# Decorator to preprocess inputs of a function and convert numpy, python types to tripy tensors. +# Decorator to preprocess inputs of a function and convert Python numbers to tripy tensors. def convert_inputs_to_tensors( sync_arg_types: Optional[List[Tuple[str]]] = None, exclude: Optional[List[str]] = None, @@ -34,8 +35,9 @@ def convert_inputs_to_tensors( skip_num_stack_entries: int = 0, ): """ - Decorator that converts all arguments to `Tensor`s before passing them along - to the decorated function. + Decorator that converts all arguments to `Tensor`s or `Shape`s before passing them along + to the decorated function. Converts only Python numbers or lists of Python numbers; + inputs like `numpy` arrays should be handled manually. Args: sync_arg_types: A list of tuples of strings indicating the parameter indices for parameters @@ -193,6 +195,50 @@ def find_sync_target_dtype(arg_name): return None def convert_nontensor_arg(arg, list_index=None): + from tripy.utils import Result + + def is_valid_sequence(seq_arg: Sequence) -> Result: + if len(seq_arg) == 0: + return Result.ok() + # If one is a sequence, all must be sequences of the same length. Do not accept strings. + if isinstance(seq_arg[0], Sequence) and not isinstance(seq_arg[0], str): + target_len = len(seq_arg[0]) + for inner_arg in seq_arg[1:]: + if not isinstance(inner_arg, Sequence) or isinstance(inner_arg, str): + return Result.err( + [f"Expected a sequence but got {type(inner_arg).__qualname__}: {inner_arg}"] + ) + if len(inner_arg) != target_len: + return Result.err( + [ + f"Expected a sequence of length {target_len} but got length {len(inner_arg)}: {inner_arg}" + ] + ) + valid_inner = is_valid_sequence(inner_arg) + if not valid_inner: + return valid_inner + return Result.ok() + # Otherwise check for numbers. + for inner_arg in seq_arg: + if not isinstance(inner_arg, numbers.Number): + return Result.err( + [ + f"Encountered non-number of type {type(inner_arg).__qualname__} in sequence: {inner_arg}" + ] + ) + return Result.ok() + + # simply do not convert in these cases and let the registry give an error instead + if not isinstance(arg, numbers.Number) and not isinstance(arg, Sequence): + return arg + + if isinstance(arg, Sequence): + valid_sequence = is_valid_sequence(arg) + if not valid_sequence: + raise_error( + f"Encountered invalid sequence argument: {arg}", details=valid_sequence.error_details + ) + cast_dtype = find_sync_target_dtype(name) return add_column_info_for_non_tensor( arg, index, is_kwarg=name in kwargs, dtype=cast_dtype, list_index=list_index diff --git a/tripy/tripy/function_registry.py b/tripy/tripy/function_registry.py index 1c3c8214..df2ca569 100644 --- a/tripy/tripy/function_registry.py +++ b/tripy/tripy/function_registry.py @@ -19,7 +19,7 @@ import inspect from collections import OrderedDict, defaultdict from textwrap import dedent -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Optional from dataclasses import dataclass @@ -98,7 +98,30 @@ def _get_annotations(self): def matches_arg_types(self, args, kwargs) -> "Result": from tripy.utils.result import Result - def matches_type(name: str, annotation: type, arg: Any): + def sanitize_name(annotation): + # typing module annotations are likely to be better when pretty-printed due to including subscripts + return annotation if annotation.__module__ == "typing" else annotation.__qualname__ + + def render_arg_type(arg: Any) -> str: + # it is more useful to report more detailed types for sequences/tuples in error messages + from typing import List, Tuple + + if isinstance(arg, List): + if len(arg) == 0: + return "List" + # catch inconsistencies this way + arg_types = {render_arg_type(member) for member in arg} + if len(arg_types) == 1: + return f"List[{list(arg_types)[0]}]" + return f"List[Union[{', '.join(arg_types)}]]" + if isinstance(arg, Tuple): + return f"Tuple[{', '.join(map(render_arg_type, arg))}]" + return type(arg).__qualname__ + + def matches_type(name: str, annotation: type, arg: Any) -> bool: + from collections.abc import Sequence as ABCSequence + from typing import ForwardRef, get_args, get_origin, Sequence, Union + # In cases where a type is not available at the time of function definition, the type # annotation may be provided as a string. Since we need the actual type, we just # eval it here. @@ -111,12 +134,32 @@ def matches_type(name: str, annotation: type, arg: Any): f"\nNote: Error was: {e}" ) + # can add more cases, prioritizing the common ones + if get_origin(annotation) is Union: + return any(map(lambda type_arg: matches_type(name, type_arg, arg), get_args(annotation))) + + # note: get_origin for typing.Sequence normalizes it into collections.abc.Sequence, see spec for get_origin + if get_origin(annotation) is ABCSequence: + # in the context of Tripy, it does not make sense to consider strings as sequences + if not isinstance(arg, Sequence) or isinstance(arg, str): + return False + seq_arg = get_args(annotation) + if seq_arg and len(arg) > 0: + assert len(seq_arg) == 1 + return all(map(lambda member: matches_type(name, seq_arg[0], member), arg)) + return True + + # Forward references can be used for recursive type definitions. Warning: Has the potential for infinite looping if there is no base case! + if isinstance(annotation, ForwardRef): + # need this import in case the annotation references tripy + import tripy + + return matches_type(name, eval(annotation.__forward_arg__), arg) + try: return isinstance(arg, annotation) except TypeError: - # When the type annotation includes a subscripted generic (e.g. Tuple[int]), isinstance - # does not work. We could introduce more advanced type checking using `typing.get_origin` and `typing.get_args` - # but for now, we don't support overloading on generic types. + # When the type annotation includes a subscripted generic that we do not handle above, isinstance does not work return True annotations = self._get_annotations() @@ -133,9 +176,9 @@ def matches_type(name: str, annotation: type, arg: Any): if not matches_type(name, annotation.type_info, arg): return Result.err( [ - f"For parameter: '{name}', expected an instance of type: " - f"'{annotation.type_info.__qualname__}' but got argument of type: '{type(arg).__qualname__}'." - ], + f"For parameter: '{name}', expected an instance of type: '{sanitize_name(annotation.type_info)}' " + f"but got argument of type: '{render_arg_type(arg)}'." + ] ) for name, arg in kwargs.items(): @@ -144,9 +187,9 @@ def matches_type(name: str, annotation: type, arg: Any): if not matches_type(name, typ, arg): return Result.err( [ - f"For parameter: '{name}', expected an instance of type: " - f"'{typ.__qualname__}' but got argument of type: '{type(arg).__qualname__}'." - ], + f"For parameter: '{name}', expected an instance of type: '{sanitize_name(typ)}' " + f"but got argument of type: '{render_arg_type(arg)}'." + ] ) elif not any(annotation.kind == inspect.Parameter.VAR_KEYWORD for annotation in annotations.values()): # We can only validate the names of arguments if the function does not accept variadic kwargs @@ -227,7 +270,7 @@ def raise_overload_error(msg, candidate_overloads, mismatch_reasons=None, extra_ raise_error( f"{msg} for function: '{key}'.", details=[ - f"Note: Argument types were: [{', '.join(arg_type_strs)}].\nCandidate overloads were:\n\n", + f"Candidate overloads were:\n\n", *overloads_error, extra_info, ], diff --git a/tripy/tripy/types.py b/tripy/tripy/types.py new file mode 100644 index 00000000..466a1b69 --- /dev/null +++ b/tripy/tripy/types.py @@ -0,0 +1,51 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2024-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Special type annotations used in Tripy. +""" + +import numbers +import sys +from typing import Union, Sequence + +from tripy import export + +export.public_api()(sys.modules[__name__]) + +NestedNumberSequence = export.public_api( + document_under="types.rst", + autodoc_options=[":no-index:"], + module=sys.modules[__name__], + symbol="NestedNumberSequence", +)(Union[numbers.Number, Sequence["tripy.types.NestedNumberSequence"]]) + +NestedNumberSequence.__doc__ = """ +Denotes the recursive type annotation for sequences of Python numbers, possibly nested to an arbitrary depth. +Tripy often automatically converts these sequences to `tp.Tensor`. +""" + +TensorLike = export.public_api( + document_under="types.rst", + autodoc_options=[":no-index:"], + module=sys.modules[__name__], + symbol="TensorLike", +)(Union["tripy.Tensor", "tripy.types.NestedNumberSequence"]) + +TensorLike.__doc__ = """ +Type annotation for a parameter that is either a Tripy `Tensor` or a Python sequence that can be automatically converted into one. +"""