Skip to content

Commit

Permalink
Do not perform implicit conversions of external datatypes in convert_…
Browse files Browse the repository at this point in the history
…inputs_to_tensors decorator
  • Loading branch information
slyubomirsky committed Sep 5, 2024
1 parent 8599c65 commit e0f222d
Show file tree
Hide file tree
Showing 10 changed files with 92 additions and 45 deletions.
5 changes: 2 additions & 3 deletions tripy/examples/nanogpt/weight_loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
Expand Down Expand Up @@ -50,7 +49,7 @@ def load_weights_from_hf(model, model_type, dtype):
weight = hf_state_dict[key].t().contiguous()
if "ln" not in key:
weight = weight.to(torch_dtype)
param = tp.Parameter(weight)
param = tp.Parameter(tp.Tensor(weight))
tripy_state_dict[key] = param

model.load_from_state_dict(tripy_state_dict)
Expand Down Expand Up @@ -110,7 +109,7 @@ def get_submodule(module, attr_name):

if "ln" not in key:
weight = weight.to(torch_dtype)
param = tp.Parameter(weight.contiguous())
param = tp.Parameter(tp.Tensor(weight.contiguous()))
tripy_state_dict[key] = param

model.load_from_state_dict(tripy_state_dict)
Expand Down
29 changes: 20 additions & 9 deletions tripy/tests/frontend/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,10 @@ def values(request):


@pytest.fixture(
params=[[4, 5], tp.Tensor([4, 5], dtype=tp.int32), np.array([4, 5], dtype=np.int32)],
params=[[4, 5], tp.Tensor([4, 5], dtype=tp.int32)],
ids=[
"python_list",
"tripy_tensor",
"numpy_array",
],
)
def other_values(request):
Expand Down Expand Up @@ -95,7 +94,7 @@ def test_plus_override(self, values, other_values):
appended = [4, 5]
s = tp.Shape(values)

# conversion is implicit except for tp.Tensor
# conversion must be explicit for tp.Tensor
lhs_shape = other_values if not isinstance(other_values, tp.Tensor) else tp.Shape(other_values)
new_shape = s + lhs_shape
assert isinstance(new_shape, tp.Shape)
Expand Down Expand Up @@ -308,7 +307,7 @@ def test_right_addition(self, other_values):
appended = [4, 5]
s = tp.Shape(values)

# conversion is implicit except for tp.Tensor
# conversion must be explicit for tp.Tensor
rhs_shape = other_values if not isinstance(other_values, tp.Tensor) else tp.Shape(other_values)

new_shape = rhs_shape + s
Expand Down Expand Up @@ -418,19 +417,29 @@ def test_invalid_input_rank_tensor(self):
with raises(tp.TripyException, match="Shape tensors must be of rank 1, but input tensor is rank 2"):
_ = tp.Shape(tp.ones((3, 2), dtype=tp.int32))

def test_invalid_mul_sequence(self, values):
s = tp.Shape(values)
with raises(tp.TripyException, match="Attempting to multiply a Tripy Shape by a sequence, which is undefined"):
_ = s * values

def test_invalid_mul_rank(self, values):
s = tp.Shape(values)
t = tp.Tensor(values)
with raises(
tp.TripyException, match="Attempting to multiply a Tripy Shape by a tensor of rank >= 1, which is undefined"
):
_ = s * values
_ = s * t

def test_invalid_plus_type(self, values):
s = tp.Shape(values)
t = tp.Tensor(values, dtype=tp.int32)
with raises(
tp.TripyException,
match="Attempting to add a Tripy Tensor to a Tripy Shape, which is not allowed. Consider calling tp.Shape explicitly",
match=(
"Invalid types for addition with a Tripy Shape: "
"Implicit conversions are done only for sequences of Python ints. "
"Consider calling tp.Shape for an explicit conversion."
),
):
s + t

Expand All @@ -439,7 +448,11 @@ def test_invalid_right_addition_type(self, values):
t = tp.Tensor(values, dtype=tp.int32)
with raises(
tp.TripyException,
match="Attempting to add a Tripy Tensor to a Tripy Shape, which is not allowed. Consider calling tp.Shape explicitly",
match=(
"Invalid types for addition with a Tripy Shape: "
"Implicit conversions are done only for sequences of Python ints. "
"Consider calling tp.Shape for an explicit conversion."
),
):
t + s

Expand Down Expand Up @@ -469,8 +482,6 @@ def test_unary_elementwise_fails_at_run_time(self, values):

def test_shape_equality(self, other_values):
a = tp.Shape([4, 5])
if isinstance(other_values, np.ndarray):
pytest.skip("numpy array cannot be implicitly cast to Shape type")
eq = a == other_values
assert isinstance(eq, bool)
assert eq
Expand Down
15 changes: 15 additions & 0 deletions tripy/tests/frontend/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#

import cupy as cp
import numpy as np

import tripy as tp
from tripy.frontend.utils import convert_inputs_to_tensors
Expand Down Expand Up @@ -291,3 +292,17 @@ def test_sync_arg_type_invalid(self):
match=r"At least one of the arguments: \('a', 'b', 'c'\) must be a \`tripy.Tensor\`.",
):
x, y, z = sync_arg_types(3.0, 3, 4)

def test_invalid_argument_type(self):
with helper.raises(
tp.TripyException,
match=r"convert_inputs_to_tensors decorator supports conversion only for Python numbers or sequences thereof",
):
_ = func(np.array([1, 2, 3]))

def test_invalid_argument_type_in_sequence(self):
with helper.raises(
tp.TripyException,
match=r"convert_inputs_to_tensors decorator supports conversion only for Python numbers or sequences thereof",
):
_ = func([np.array([1, 2, 3]), np.array([4, 5, 6])])
4 changes: 0 additions & 4 deletions tripy/tests/frontend/trace/ops/test_binary_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,8 @@ class TestBinaryElementwise:
"lhs, rhs, left_side_is_non_tensor",
[
(tp.Tensor([1.0]), tp.Tensor([2.0]), False),
(tp.Tensor([1.0]), np.array([2.0], dtype=np.float32), False),
# shape of (0,) is broadcastable with (1,)
(tp.Tensor([], dtype=tp.float32), tp.Tensor([1.0], dtype=tp.float32), False),
(np.array([1.0], dtype=np.float32), tp.Tensor([2.0]), True),
(tp.Tensor([1.0]), 2.0, False),
(1.0, tp.Tensor([2.0]), True),
],
Expand All @@ -86,8 +84,6 @@ def test_op_funcs(self, kind, func, lhs, rhs, left_side_is_non_tensor):
"lhs, rhs, expected_rank",
[
(tp.Tensor([1.0]), tp.Tensor([2.0]), 1),
(tp.Tensor([1.0]), np.array([2.0], dtype=np.float32), 1),
(np.array([1.0], dtype=np.float32), tp.Tensor([2.0]), 1),
(tp.Tensor([1.0]), 2.0, 1),
(1.0, tp.Tensor([2.0]), 1),
(tp.ones((2, 3)), 2.0, 2),
Expand Down
5 changes: 3 additions & 2 deletions tripy/tests/integration/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def test_multiple_copy_2(self):
assert out.tolist() == [1, 2]
assert out.device.kind == "gpu"


class TestConversionToTripyType:
@pytest.mark.parametrize(
"reverse_direction",
Expand All @@ -157,8 +158,8 @@ def test_element_wise_prod(self, reverse_direction, input0, input1):
if isinstance(input1, torch.Tensor):
input1 = input1.to("cuda")
if reverse_direction:
out = input1 * a
out = tp.Tensor(input1) * a
input0, input1 = input1, input0
else:
out = a * input1
out = a * tp.Tensor(input1)
assert cp.array_equal(cp.from_dlpack(out), cp.array(input0) * cp.array(input1))
10 changes: 4 additions & 6 deletions tripy/tests/integration/test_groupnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@
import tripy as tp
from tripy.common.exception import TripyException

DTYPES = [
(torch.float16, tp.float16),
(torch.float32, tp.float32)
]
DTYPES = [(torch.float16, tp.float16), (torch.float32, tp.float32)]


class TestGroupNorm:

Expand All @@ -49,8 +47,8 @@ def test_groupnorm_accuracy(self, torch_dtype, tp_dtype, input_shape, num_groups
dtype=tp_dtype,
)

tp_groupnorm.weight = tp.Parameter(groupnorm.weight.detach())
tp_groupnorm.bias = tp.Parameter(groupnorm.bias.detach())
tp_groupnorm.weight = tp.Parameter(tp.Tensor(groupnorm.weight.detach()))
tp_groupnorm.bias = tp.Parameter(tp.Tensor(groupnorm.bias.detach()))

input = torch.arange(torch.prod(torch.Tensor(input_shape))).reshape(input_shape).to(torch_dtype)
tp_input = tp.Tensor(input, dtype=tp_dtype)
Expand Down
12 changes: 5 additions & 7 deletions tripy/tests/integration/test_layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@
import tripy as tp
from tests import helper

DTYPES = [
(torch.float16, tp.float16),
(torch.float32, tp.float32)
]
DTYPES = [(torch.float16, tp.float16), (torch.float32, tp.float32)]


class TestLayerNorm:

Expand All @@ -48,8 +46,8 @@ def test_layernorm_accuracy(self, torch_dtype, tp_dtype, input_shape, normalized
)

# use Tripy's parameters
tp_layernorm.weight = tp.Parameter(layernorm.weight.detach())
tp_layernorm.bias = tp.Parameter(layernorm.bias.detach())
tp_layernorm.weight = tp.Parameter(tp.Tensor(layernorm.weight.detach()))
tp_layernorm.bias = tp.Parameter(tp.Tensor(layernorm.bias.detach()))

input = torch.arange(torch.prod(torch.Tensor(input_shape))).reshape(input_shape).to(torch_dtype)
tp_input = tp.Tensor(input, dtype=tp_dtype)
Expand All @@ -65,7 +63,7 @@ def test_layernorm_improper_dimensions(self):
tp_layernorm = tp.LayerNorm(
normalized_shape=[2, 2],
)
x = tp.ones((5,5,5))
x = tp.ones((5, 5, 5))
with helper.raises(
tp.TripyException,
match=re.escape("size of operand dimension 1 (5) is not compatible with size of result dimension 1 (2)"),
Expand Down
2 changes: 1 addition & 1 deletion tripy/tests/integration/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_quant_linear(self, use_input_scale, quant_dtype, weight_quant_dim):
ids=["block-wise", "per-tensor", "per-channel-0", "per-channel-1"],
)
def test_quant_linear_int4_weight_only(self, weight_quant_dim, scale):
scale = tp.Parameter(scale)
scale = tp.Parameter(tp.Tensor(scale))

linear = tp.Linear(4, 8, quant_dtype=tp.int4, weight_quant_dim=weight_quant_dim)
linear.weight_scale = scale
Expand Down
34 changes: 24 additions & 10 deletions tripy/tripy/frontend/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,25 +155,30 @@ def __str__(self) -> str:

# addition for shapes is concatenation, not tensor addition

def _validate_add_argument(self, other):
if isinstance(other, Shape):
return
if not isinstance(other, Sequence) or (len(other) != 0 and not isinstance(other[0], int)):
raise_error(
"Invalid types for addition with a Tripy Shape: "
"Implicit conversions are done only for sequences of Python ints. "
"Consider calling tp.Shape for an explicit conversion.",
details=[other],
)

def __add__(self, other):
from tripy.frontend.trace.ops.concatenate import concatenate

if not isinstance(other, Shape) and isinstance(other, Tensor):
raise_error(
"Attempting to add a Tripy Tensor to a Tripy Shape, which is not allowed. Consider calling tp.Shape explicitly"
)
elif not isinstance(other, Shape):
self._validate_add_argument(other)
if not isinstance(other, Shape):
other = Shape(other)
return concatenate([self, other], 0)

def __radd__(self, other):
from tripy.frontend.trace.ops.concatenate import concatenate

if not isinstance(other, Shape) and isinstance(other, Tensor):
raise_error(
"Attempting to add a Tripy Tensor to a Tripy Shape, which is not allowed. Consider calling tp.Shape explicitly"
)
elif not isinstance(other, Shape):
self._validate_add_argument(other)
if not isinstance(other, Shape):
other = Shape(other)
return concatenate([other, self], 0)

Expand All @@ -191,6 +196,15 @@ def __mul__(self, other):

# Only defined with a scalar argument
if not isinstance(other, Tensor):
# note: Python does not accept floats as arguments for list multiplication either
if isinstance(other, Sequence):
raise_error("Attempting to multiply a Tripy Shape by a sequence, which is undefined", details=[other])
if not isinstance(other, int):
raise_error(
"Attempting to multiply an invalid datatype with a Tripy Shape. "
"Implicit conversions are done only for Python ints. Consider calling tp.Shape explicitly.",
details=[other],
)
other = Tensor(other, dtype=int32)
if other.rank >= 1:
raise_error(
Expand Down
21 changes: 18 additions & 3 deletions tripy/tripy/frontend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from tripy.frontend.trace.ops import BaseTraceOp


# Decorator to preprocess inputs of a function and convert numpy, python types to tripy tensors.
# Decorator to preprocess inputs of a function and convert Python numbers to tripy tensors.
def convert_inputs_to_tensors(
sync_arg_types: Optional[List[Tuple[str]]] = None,
exclude: Optional[List[str]] = None,
Expand All @@ -34,8 +34,9 @@ def convert_inputs_to_tensors(
skip_num_stack_entries: int = 0,
):
"""
Decorator that converts all arguments to `Tensor`s before passing them along
to the decorated function.
Decorator that converts all arguments to `Tensor`s or `Shape`s before passing them along
to the decorated function. Converts only Python numbers or lists of Python numbers;
inputs like `numpy` arrays should be handled manually.
Args:
sync_arg_types: A list of tuples of strings indicating the parameter indices for parameters
Expand Down Expand Up @@ -193,6 +194,20 @@ def find_sync_target_dtype(arg_name):
return None

def convert_nontensor_arg(arg, list_index=None):
def is_valid_sequence(seq_arg: Sequence) -> bool:
if len(seq_arg) == 0:
return True
if isinstance(seq_arg[0], Sequence):
return is_valid_sequence(seq_arg[0])
return isinstance(seq_arg[0], (int, float))

if not isinstance(arg, (int, float)) and not (isinstance(arg, Sequence) and is_valid_sequence(arg)):
raise_error(
"convert_inputs_to_tensors decorator supports conversion only"
f" for Python numbers or sequences thereof. Given argument of type {type(arg)}",
[arg],
)

cast_dtype = find_sync_target_dtype(name)
return add_column_info_for_non_tensor(
arg, index, is_kwarg=name in kwargs, dtype=cast_dtype, list_index=list_index
Expand Down

0 comments on commit e0f222d

Please sign in to comment.