From e0f222d7517b73c51904e9ffac201730eb57098b Mon Sep 17 00:00:00 2001
From: Steven Lyubomirsky
Date: Thu, 5 Sep 2024 17:27:20 -0400
Subject: [PATCH] Do not perform implicit conversions of external datatypes in
 convert_inputs_to_tensors decorator

---
 tripy/examples/nanogpt/weight_loader.py    |  5 ++-
 tripy/tests/frontend/test_shape.py         | 29 +++++++++++-----
 tripy/tests/frontend/test_utils.py         | 15 ++++++++
 .../trace/ops/test_binary_elementwise.py   |  4 ---
 tripy/tests/integration/test_functional.py |  5 +--
 tripy/tests/integration/test_groupnorm.py  | 10 +++---
 tripy/tests/integration/test_layernorm.py  | 12 +++----
 tripy/tests/integration/test_linear.py     |  2 +-
 tripy/tripy/frontend/shape.py              | 34 +++++++++++++------
 tripy/tripy/frontend/utils.py              | 21 ++++++++++--
 10 files changed, 92 insertions(+), 45 deletions(-)

diff --git a/tripy/examples/nanogpt/weight_loader.py b/tripy/examples/nanogpt/weight_loader.py
index 19018348b..b770ec125 100644
--- a/tripy/examples/nanogpt/weight_loader.py
+++ b/tripy/examples/nanogpt/weight_loader.py
@@ -1,4 +1,3 @@
-
 #
 # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
@@ -50,7 +49,7 @@ def load_weights_from_hf(model, model_type, dtype):
             weight = hf_state_dict[key].t().contiguous()
             if "ln" not in key:
                 weight = weight.to(torch_dtype)
-            param = tp.Parameter(weight)
+            param = tp.Parameter(tp.Tensor(weight))
             tripy_state_dict[key] = param
 
     model.load_from_state_dict(tripy_state_dict)
@@ -110,7 +109,7 @@ def get_submodule(module, attr_name):
         if "ln" not in key:
             weight = weight.to(torch_dtype)
 
-        param = tp.Parameter(weight.contiguous())
+        param = tp.Parameter(tp.Tensor(weight.contiguous()))
         tripy_state_dict[key] = param
 
     model.load_from_state_dict(tripy_state_dict)
diff --git a/tripy/tests/frontend/test_shape.py b/tripy/tests/frontend/test_shape.py
index 5537b242a..eb6589562 100644
--- a/tripy/tests/frontend/test_shape.py
+++ b/tripy/tests/frontend/test_shape.py
@@ -29,11 +29,10 @@ def values(request):
 
 
 @pytest.fixture(
-    params=[[4, 5], tp.Tensor([4, 5], dtype=tp.int32), np.array([4, 5], dtype=np.int32)],
+    params=[[4, 5], tp.Tensor([4, 5], dtype=tp.int32)],
     ids=[
         "python_list",
         "tripy_tensor",
-        "numpy_array",
     ],
 )
 def other_values(request):
@@ -95,7 +94,7 @@ def test_plus_override(self, values, other_values):
         appended = [4, 5]
         s = tp.Shape(values)
 
-        # conversion is implicit except for tp.Tensor
+        # conversion must be explicit for tp.Tensor
         lhs_shape = other_values if not isinstance(other_values, tp.Tensor) else tp.Shape(other_values)
         new_shape = s + lhs_shape
         assert isinstance(new_shape, tp.Shape)
@@ -308,7 +307,7 @@ def test_right_addition(self, other_values):
         appended = [4, 5]
         s = tp.Shape(values)
 
-        # conversion is implicit except for tp.Tensor
+        # conversion must be explicit for tp.Tensor
         rhs_shape = other_values if not isinstance(other_values, tp.Tensor) else tp.Shape(other_values)
 
         new_shape = rhs_shape + s
@@ -418,19 +417,29 @@ def test_invalid_input_rank_tensor(self):
         with raises(tp.TripyException, match="Shape tensors must be of rank 1, but input tensor is rank 2"):
             _ = tp.Shape(tp.ones((3, 2), dtype=tp.int32))
 
+    def test_invalid_mul_sequence(self, values):
+        s = tp.Shape(values)
+        with raises(tp.TripyException, match="Attempting to multiply a Tripy Shape by a sequence, which is undefined"):
+            _ = s * values
+
     def test_invalid_mul_rank(self, values):
         s = tp.Shape(values)
+        t = tp.Tensor(values)
         with raises(
             tp.TripyException, match="Attempting to multiply a Tripy Shape by a tensor of rank >= 1, which is undefined"
         ):
-            _ = s * values
+            _ = s * t
 
     def test_invalid_plus_type(self, values):
         s = tp.Shape(values)
         t = tp.Tensor(values, dtype=tp.int32)
         with raises(
             tp.TripyException,
-            match="Attempting to add a Tripy Tensor to a Tripy Shape, which is not allowed. Consider calling tp.Shape explicitly",
+            match=(
+                "Invalid types for addition with a Tripy Shape: "
+                "Implicit conversions are done only for sequences of Python ints. "
+                "Consider calling tp.Shape for an explicit conversion."
+            ),
         ):
             s + t
 
@@ -439,7 +448,11 @@ def test_invalid_right_addition_type(self, values):
         t = tp.Tensor(values, dtype=tp.int32)
         with raises(
             tp.TripyException,
-            match="Attempting to add a Tripy Tensor to a Tripy Shape, which is not allowed. Consider calling tp.Shape explicitly",
+            match=(
+                "Invalid types for addition with a Tripy Shape: "
+                "Implicit conversions are done only for sequences of Python ints. "
+                "Consider calling tp.Shape for an explicit conversion."
+            ),
         ):
             t + s
 
@@ -469,8 +482,6 @@ def test_unary_elementwise_fails_at_run_time(self, values):
 
     def test_shape_equality(self, other_values):
         a = tp.Shape([4, 5])
-        if isinstance(other_values, np.ndarray):
-            pytest.skip("numpy array cannot be implicitly cast to Shape type")
         eq = a == other_values
         assert isinstance(eq, bool)
         assert eq
diff --git a/tripy/tests/frontend/test_utils.py b/tripy/tests/frontend/test_utils.py
index 2d666626d..00669c095 100644
--- a/tripy/tests/frontend/test_utils.py
+++ b/tripy/tests/frontend/test_utils.py
@@ -16,6 +16,7 @@
 #
 
 import cupy as cp
+import numpy as np
 
 import tripy as tp
 from tripy.frontend.utils import convert_inputs_to_tensors
@@ -291,3 +292,17 @@ def test_sync_arg_type_invalid(self):
             match=r"At least one of the arguments: \('a', 'b', 'c'\) must be a \`tripy.Tensor\`.",
         ):
             x, y, z = sync_arg_types(3.0, 3, 4)
+
+    def test_invalid_argument_type(self):
+        with helper.raises(
+            tp.TripyException,
+            match=r"convert_inputs_to_tensors decorator supports conversion only for Python numbers or sequences thereof",
+        ):
+            _ = func(np.array([1, 2, 3]))
+
+    def test_invalid_argument_type_in_sequence(self):
+        with helper.raises(
+            tp.TripyException,
+            match=r"convert_inputs_to_tensors decorator supports conversion only for Python numbers or sequences thereof",
+        ):
+            _ = func([np.array([1, 2, 3]), np.array([4, 5, 6])])
diff --git a/tripy/tests/frontend/trace/ops/test_binary_elementwise.py b/tripy/tests/frontend/trace/ops/test_binary_elementwise.py
index 1d65dadae..d594ce027 100644
--- a/tripy/tests/frontend/trace/ops/test_binary_elementwise.py
+++ b/tripy/tests/frontend/trace/ops/test_binary_elementwise.py
@@ -59,10 +59,8 @@ class TestBinaryElementwise:
         "lhs, rhs, left_side_is_non_tensor",
         [
             (tp.Tensor([1.0]), tp.Tensor([2.0]), False),
-            (tp.Tensor([1.0]), np.array([2.0], dtype=np.float32), False),
             # shape of (0,) is broadcastable with (1,)
             (tp.Tensor([], dtype=tp.float32), tp.Tensor([1.0], dtype=tp.float32), False),
-            (np.array([1.0], dtype=np.float32), tp.Tensor([2.0]), True),
             (tp.Tensor([1.0]), 2.0, False),
             (1.0, tp.Tensor([2.0]), True),
         ],
@@ -86,8 +84,6 @@ def test_op_funcs(self, kind, func, lhs, rhs, left_side_is_non_tensor):
         "lhs, rhs, expected_rank",
         [
             (tp.Tensor([1.0]), tp.Tensor([2.0]), 1),
-            (tp.Tensor([1.0]), np.array([2.0], dtype=np.float32), 1),
-            (np.array([1.0], dtype=np.float32), tp.Tensor([2.0]), 1),
             (tp.Tensor([1.0]), 2.0, 1),
             (1.0, tp.Tensor([2.0]), 1),
             (tp.ones((2, 3)), 2.0, 2),
diff --git a/tripy/tests/integration/test_functional.py b/tripy/tests/integration/test_functional.py
index 1b71aa507..e35fa3253 100644
--- a/tripy/tests/integration/test_functional.py
+++ b/tripy/tests/integration/test_functional.py
@@ -131,6 +131,7 @@ def test_multiple_copy_2(self):
         assert out.tolist() == [1, 2]
         assert out.device.kind == "gpu"
 
+
 class TestConversionToTripyType:
     @pytest.mark.parametrize(
         "reverse_direction",
@@ -157,8 +158,8 @@ def test_element_wise_prod(self, reverse_direction, input0, input1):
         if isinstance(input1, torch.Tensor):
             input1 = input1.to("cuda")
         if reverse_direction:
-            out = input1 * a
+            out = tp.Tensor(input1) * a
             input0, input1 = input1, input0
         else:
-            out = a * input1
+            out = a * tp.Tensor(input1)
         assert cp.array_equal(cp.from_dlpack(out), cp.array(input0) * cp.array(input1))
diff --git a/tripy/tests/integration/test_groupnorm.py b/tripy/tests/integration/test_groupnorm.py
index a1353dd3e..1f47bcb54 100644
--- a/tripy/tests/integration/test_groupnorm.py
+++ b/tripy/tests/integration/test_groupnorm.py
@@ -22,10 +22,8 @@
 import tripy as tp
 from tripy.common.exception import TripyException
 
-DTYPES = [
-    (torch.float16, tp.float16),
-    (torch.float32, tp.float32)
-]
+DTYPES = [(torch.float16, tp.float16), (torch.float32, tp.float32)]
+
 
 class TestGroupNorm:
@@ -49,8 +47,8 @@ def test_groupnorm_accuracy(self, torch_dtype, tp_dtype, input_shape, num_groups
             dtype=tp_dtype,
         )
 
-        tp_groupnorm.weight = tp.Parameter(groupnorm.weight.detach())
-        tp_groupnorm.bias = tp.Parameter(groupnorm.bias.detach())
+        tp_groupnorm.weight = tp.Parameter(tp.Tensor(groupnorm.weight.detach()))
+        tp_groupnorm.bias = tp.Parameter(tp.Tensor(groupnorm.bias.detach()))
 
         input = torch.arange(torch.prod(torch.Tensor(input_shape))).reshape(input_shape).to(torch_dtype)
         tp_input = tp.Tensor(input, dtype=tp_dtype)
diff --git a/tripy/tests/integration/test_layernorm.py b/tripy/tests/integration/test_layernorm.py
index 82f4389d0..c6aa5b497 100644
--- a/tripy/tests/integration/test_layernorm.py
+++ b/tripy/tests/integration/test_layernorm.py
@@ -23,10 +23,8 @@
 import tripy as tp
 from tests import helper
 
-DTYPES = [
-    (torch.float16, tp.float16),
-    (torch.float32, tp.float32)
-]
+DTYPES = [(torch.float16, tp.float16), (torch.float32, tp.float32)]
+
 
 class TestLayerNorm:
@@ -48,8 +46,8 @@ def test_layernorm_accuracy(self, torch_dtype, tp_dtype, input_shape, normalized
         )
 
         # use Tripy's parameters
-        tp_layernorm.weight = tp.Parameter(layernorm.weight.detach())
-        tp_layernorm.bias = tp.Parameter(layernorm.bias.detach())
+        tp_layernorm.weight = tp.Parameter(tp.Tensor(layernorm.weight.detach()))
+        tp_layernorm.bias = tp.Parameter(tp.Tensor(layernorm.bias.detach()))
 
         input = torch.arange(torch.prod(torch.Tensor(input_shape))).reshape(input_shape).to(torch_dtype)
         tp_input = tp.Tensor(input, dtype=tp_dtype)
@@ -65,7 +63,7 @@ def test_layernorm_improper_dimensions(self):
         tp_layernorm = tp.LayerNorm(
             normalized_shape=[2, 2],
         )
-        x = tp.ones((5,5,5))
+        x = tp.ones((5, 5, 5))
         with helper.raises(
             tp.TripyException,
             match=re.escape("size of operand dimension 1 (5) is not compatible with size of result dimension 1 (2)"),
diff --git a/tripy/tests/integration/test_linear.py b/tripy/tests/integration/test_linear.py
index d55940b5d..25a3f5a16 100644
--- a/tripy/tests/integration/test_linear.py
+++ b/tripy/tests/integration/test_linear.py
@@ -115,7 +115,7 @@ def test_quant_linear(self, use_input_scale, quant_dtype, weight_quant_dim):
         ids=["block-wise", "per-tensor", "per-channel-0", "per-channel-1"],
     )
     def test_quant_linear_int4_weight_only(self, weight_quant_dim, scale):
-        scale = tp.Parameter(scale)
+        scale = tp.Parameter(tp.Tensor(scale))
         linear = tp.Linear(4, 8, quant_dtype=tp.int4, weight_quant_dim=weight_quant_dim)
         linear.weight_scale = scale
diff --git a/tripy/tripy/frontend/shape.py b/tripy/tripy/frontend/shape.py
index e803f5fdf..44e43739d 100644
--- a/tripy/tripy/frontend/shape.py
+++ b/tripy/tripy/frontend/shape.py
@@ -155,25 +155,30 @@ def __str__(self) -> str:
 
     # addition for shapes is concatenation, not tensor addition
 
+    def _validate_add_argument(self, other):
+        if isinstance(other, Shape):
+            return
+        if not isinstance(other, Sequence) or (len(other) != 0 and not isinstance(other[0], int)):
+            raise_error(
+                "Invalid types for addition with a Tripy Shape: "
+                "Implicit conversions are done only for sequences of Python ints. "
+                "Consider calling tp.Shape for an explicit conversion.",
+                details=[other],
+            )
+
     def __add__(self, other):
         from tripy.frontend.trace.ops.concatenate import concatenate
 
-        if not isinstance(other, Shape) and isinstance(other, Tensor):
-            raise_error(
-                "Attempting to add a Tripy Tensor to a Tripy Shape, which is not allowed. Consider calling tp.Shape explicitly"
-            )
-        elif not isinstance(other, Shape):
+        self._validate_add_argument(other)
+        if not isinstance(other, Shape):
             other = Shape(other)
         return concatenate([self, other], 0)
 
     def __radd__(self, other):
         from tripy.frontend.trace.ops.concatenate import concatenate
 
-        if not isinstance(other, Shape) and isinstance(other, Tensor):
-            raise_error(
-                "Attempting to add a Tripy Tensor to a Tripy Shape, which is not allowed. Consider calling tp.Shape explicitly"
-            )
-        elif not isinstance(other, Shape):
+        self._validate_add_argument(other)
+        if not isinstance(other, Shape):
             other = Shape(other)
         return concatenate([other, self], 0)
 
@@ -191,6 +196,15 @@ def __mul__(self, other):
 
         # Only defined with a scalar argument
         if not isinstance(other, Tensor):
+            # note: Python does not accept floats as arguments for list multiplication either
+            if isinstance(other, Sequence):
+                raise_error("Attempting to multiply a Tripy Shape by a sequence, which is undefined", details=[other])
+            if not isinstance(other, int):
+                raise_error(
+                    "Attempting to multiply an invalid datatype with a Tripy Shape. "
+                    "Implicit conversions are done only for Python ints. Consider calling tp.Shape explicitly.",
+                    details=[other],
+                )
             other = Tensor(other, dtype=int32)
         if other.rank >= 1:
             raise_error(
diff --git a/tripy/tripy/frontend/utils.py b/tripy/tripy/frontend/utils.py
index ea2748232..acd8b63ef 100644
--- a/tripy/tripy/frontend/utils.py
+++ b/tripy/tripy/frontend/utils.py
@@ -25,7 +25,7 @@ from tripy.frontend.trace.ops import BaseTraceOp
 
 
-# Decorator to preprocess inputs of a function and convert numpy, python types to tripy tensors.
+# Decorator to preprocess inputs of a function and convert Python numbers to tripy tensors.
 def convert_inputs_to_tensors(
     sync_arg_types: Optional[List[Tuple[str]]] = None,
     exclude: Optional[List[str]] = None,
@@ -34,8 +34,9 @@ def convert_inputs_to_tensors(
     skip_num_stack_entries: int = 0,
 ):
     """
-    Decorator that converts all arguments to `Tensor`s before passing them along
-    to the decorated function.
+    Decorator that converts all arguments to `Tensor`s or `Shape`s before passing them along
+    to the decorated function. Converts only Python numbers or lists of Python numbers;
+    inputs like `numpy` arrays should be handled manually.
 
     Args:
         sync_arg_types: A list of tuples of strings indicating the parameter indices for parameters
@@ -193,6 +194,20 @@ def find_sync_target_dtype(arg_name):
             return None
 
         def convert_nontensor_arg(arg, list_index=None):
+            def is_valid_sequence(seq_arg: Sequence) -> bool:
+                if len(seq_arg) == 0:
+                    return True
+                if isinstance(seq_arg[0], Sequence):
+                    return is_valid_sequence(seq_arg[0])
+                return isinstance(seq_arg[0], (int, float))
+
+            if not isinstance(arg, (int, float)) and not (isinstance(arg, Sequence) and is_valid_sequence(arg)):
+                raise_error(
+                    "convert_inputs_to_tensors decorator supports conversion only"
+                    f" for Python numbers or sequences thereof. Given argument of type {type(arg)}",
+                    [arg],
+                )
+
             cast_dtype = find_sync_target_dtype(name)
             return add_column_info_for_non_tensor(
                 arg, index, is_kwarg=name in kwargs, dtype=cast_dtype, list_index=list_index
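
A minimal usage sketch of the convention this patch enforces, assuming the `tp` import alias used throughout; the tensor shapes and dtypes here are illustrative, and the pattern mirrors the weight_loader.py and test_shape.py hunks above:

    import torch
    import tripy as tp

    # External datatypes such as torch tensors are no longer converted implicitly;
    # wrap them in tp.Tensor explicitly, as weight_loader.py now does:
    weight = torch.ones((4, 8), dtype=torch.float32)
    param = tp.Parameter(tp.Tensor(weight))  # tp.Parameter(weight) is now rejected
                                             # by the convert_inputs_to_tensors check

    # Python numbers and sequences of Python ints are still converted implicitly:
    s = tp.Shape([4, 8])
    s2 = s + [2]  # a sequence of ints is still accepted for Shape concatenation

    # A Tripy Tensor must be converted explicitly before Shape addition,
    # per the new error message ("Consider calling tp.Shape for an explicit conversion."):
    t = tp.Tensor([3], dtype=tp.int32)
    s3 = s + tp.Shape(t)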