[Tripy] Separate converting shape args from converting tensor args #217

Merged: 4 commits, Sep 26, 2024
8 changes: 4 additions & 4 deletions tripy/docs/post0_developer_guides/how-to-add-new-ops.md
@@ -200,10 +200,10 @@ import tripy.frontend.utils as frontend_utils
# If we needed to provide any special autodoc options, we could use the `autodoc_options` parameter.
@export.public_api(document_under="tensor_operations")

# The `convert_inputs_to_tensors` decorator converts function arguments to Tensors.
# This is what makes it possible for the user to use Python numbers in Tripy functions (e.g. `tensor + 1`)
# In this case, we want `shape` to turn into a `tripy.Shape` instead of a regular `Tensor`.
@frontend_utils.convert_inputs_to_tensors(shape_argument=["shape"], exclude=["dim", "dtype"])
# The `convert_shape_inputs` decorator converts the specified function arguments into `tripy.Shape`s,
# which allows the user to pass Python numbers and sequences for shapes. The `convert_inputs_to_tensors`
# decorator more generally converts function arguments into Tripy tensors and is also commonly used in the codebase.
@frontend_utils.convert_shape_inputs(["shape"])
def theta(shape: Tuple[int], dim: int = 0, dtype: datatype.dtype = datatype.float32) -> "tripy.Tensor":
# For any public facing interfaces, we have documentation requirements which you can read
# about in the 'Docs README' (linked below). The docstring we've implemented here
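To make the distinction concrete, a minimal sketch of the two decorators (the `describe_*` functions are illustrative, not part of the codebase):

import tripy as tp
import tripy.frontend.utils as frontend_utils

# `convert_shape_inputs` converts the named arguments into `tripy.Shape`s.
@frontend_utils.convert_shape_inputs(["shape"])
def describe_shape(shape):
    return shape

# `convert_inputs_to_tensors` converts arguments into regular `tripy.Tensor`s.
@frontend_utils.convert_inputs_to_tensors()
def describe_tensor(x):
    return x

assert isinstance(describe_shape([2, 3]), tp.Shape)
assert isinstance(describe_tensor(1.0), tp.Tensor)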
126 changes: 70 additions & 56 deletions tripy/tests/frontend/test_utils.py
@@ -19,111 +19,119 @@
import numpy as np

import tripy as tp
from tripy.frontend.utils import convert_inputs_to_tensors
from tripy.frontend.utils import convert_inputs_to_tensors, convert_shape_inputs
from tests import helper

# Putting underscores at the beginning and end of the names to get around the check
# for magic methods. We would not want to see this outside of tests.


@convert_inputs_to_tensors()
def func(a):
def __func_test_basic__(a):
return a


@convert_inputs_to_tensors()
def multi_input(a, b, c):
def __func_test_multi_input__(a, b, c):
return a, b, c


@convert_inputs_to_tensors(sync_arg_types=[("a", "b", "c")])
def sync_arg_types(a, b, c):
def __func_test_sync_arg_types__(a, b, c):
return a, b, c


@convert_inputs_to_tensors()
def variadic_positional_args(*args):
def __func_test_variadic_positional_args__(*args):
return args


@convert_inputs_to_tensors()
def arg_before_variadic_positional_args(x, *args):
def __func_test_arg_before_variadic_positional_args__(x, *args):
return (x,) + args


@convert_inputs_to_tensors()
def kwarg_after_variadic_positional_args(*args, y):
def __func_test_kwarg_after_variadic_positional_args__(*args, y):
return args + (y,)


@convert_inputs_to_tensors(unpack_argument=["xs"])
def convert_list_input(xs):
def __func_test_convert_list_input__(xs):
return xs


@convert_inputs_to_tensors(sync_arg_types=[("xs",)], unpack_argument=["xs"])
def sync_within_list(xs):
def __func_test_sync_within_list__(xs):
return xs


@convert_inputs_to_tensors(sync_arg_types=[("x", "ys")], unpack_argument=["ys"])
def sync_single_type_to_list(x, ys):
def __func_test_sync_single_type_to_list__(x, ys):
return x, ys


@convert_inputs_to_tensors(sync_arg_types=[("xs", "y")], unpack_argument=["xs"])
def sync_list_type_to_single(xs, y):
def __func_test_sync_list_type_to_single__(xs, y):
return xs, y


@convert_inputs_to_tensors(sync_arg_types=[("xs", "ys")], unpack_argument=["xs", "ys"])
def sync_list_types(xs, ys):
def __func_test_sync_list_types__(xs, ys):
return xs, ys


@convert_inputs_to_tensors(shape_argument=["s"])
@convert_shape_inputs(["s"])
def convert_shape(s):
return s


@convert_shape_inputs(["s"])
def ignore_not_named(s, t):
return (s, t)


class TestConvertInputsToTensors:
def test_args(self):
assert isinstance(func(0), tp.Tensor)
assert isinstance(__func_test_basic__(0), tp.Tensor)

def test_kwargs(self):
assert isinstance(func(a=0), tp.Tensor)
assert isinstance(__func_test_basic__(a=0), tp.Tensor)

def test_convert_list_into_tensor(self):
t1 = func([1, 2, 3])
t1 = __func_test_basic__([1, 2, 3])
assert isinstance(t1, tp.Tensor)
assert t1.shape == [3]

t2 = func([[1, 2], [3, 4]])
t2 = __func_test_basic__([[1, 2], [3, 4]])
assert t2.shape == [2, 2]

def test_convert_list_input(self):
xs = convert_list_input([1.0, 2.0, 3.0, 4.0])
xs = __func_test_convert_list_input__([1.0, 2.0, 3.0, 4.0])
assert len(xs) == 4
for x in xs:
assert isinstance(x, tp.Tensor)
assert not convert_list_input([])
assert not __func_test_convert_list_input__([])

def test_convert_tuple_input(self):
xs = convert_list_input((1.0, 2.0))
xs = __func_test_convert_list_input__((1.0, 2.0))
assert isinstance(xs, tuple)
assert len(xs) == 2
assert isinstance(xs[0], tp.Tensor)
assert isinstance(xs[1], tp.Tensor)

def test_variadic_positional_args(self):
x, y = variadic_positional_args(1.0, 2.0)
x, y = __func_test_variadic_positional_args__(1.0, 2.0)
assert isinstance(x, tp.Tensor)
assert isinstance(y, tp.Tensor)

def test_arg_before_variadic_positional_args(self):
x, y = arg_before_variadic_positional_args(1.0, 2.0)
x, y = __func_test_arg_before_variadic_positional_args__(1.0, 2.0)
assert isinstance(x, tp.Tensor)
assert isinstance(y, tp.Tensor)

def test_kwarg_after_variadic_positional_args(self):
x, y = kwarg_after_variadic_positional_args(1.0, y=2.0)
x, y = __func_test_kwarg_after_variadic_positional_args__(1.0, y=2.0)
assert isinstance(x, tp.Tensor)
assert isinstance(y, tp.Tensor)

@@ -171,29 +179,35 @@ def test_convert_shape_unsqueeze_tensors(self):
assert isinstance(s, tp.Shape)
assert cp.from_dlpack(s).get().tolist() == [1, 2]

def test_convert_only_specified_argument_to_shape(self):
t1 = tp.Tensor([1, 2, 3], dtype=tp.int32)
s, t2 = ignore_not_named(t1, [4, 5, 6])
assert isinstance(s, tp.Shape)
assert t2 == [4, 5, 6]

# When we convert arguments to tensors, we should preserve the column range
# of the original non-Tensor argument.
def test_includes_column_range_for_non_tensors(self):
tensor = func(3.0)
tensor = __func_test_basic__(3.0)

# Column offset of the `3.0` above.
assert tensor.stack_info[tensor.stack_info.get_first_user_frame_index()].column_range == (22, 25)
assert tensor.stack_info[tensor.stack_info.get_first_user_frame_index()].column_range == (37, 40)

def test_includes_column_range_for_non_tensors_multiple_inputs(self):
a, b, c = multi_input(1, 2.0, 3)
a, b, c = __func_test_multi_input__(1, 2.0, 3)

# Column offsets of the arguments above.
assert a.stack_info[a.stack_info.get_first_user_frame_index()].column_range == (30, 31)
assert b.stack_info[b.stack_info.get_first_user_frame_index()].column_range == (33, 36)
assert c.stack_info[c.stack_info.get_first_user_frame_index()].column_range == (38, 39)
assert a.stack_info[a.stack_info.get_first_user_frame_index()].column_range == (44, 45)
assert b.stack_info[b.stack_info.get_first_user_frame_index()].column_range == (47, 50)
assert c.stack_info[c.stack_info.get_first_user_frame_index()].column_range == (52, 53)

def test_includes_column_range_for_non_tensors_multiple_inputs_with_kwargs(self):
a, b, c = multi_input(1, b=2.0, c=3)
a, b, c = __func_test_multi_input__(1, b=2.0, c=3)

# Column offsets of the arguments above.
assert a.stack_info[a.stack_info.get_first_user_frame_index()].column_range == (30, 31)
assert b.stack_info[b.stack_info.get_first_user_frame_index()].column_range == (33, 38)
assert c.stack_info[c.stack_info.get_first_user_frame_index()].column_range == (40, 43)
assert a.stack_info[a.stack_info.get_first_user_frame_index()].column_range == (44, 45)
assert b.stack_info[b.stack_info.get_first_user_frame_index()].column_range == (47, 52)
assert c.stack_info[c.stack_info.get_first_user_frame_index()].column_range == (54, 57)

def test_includes_column_range_for_non_tensors_for_magic_methods(self):
c = tp.ones((2, 3)) + 3
@@ -212,33 +226,33 @@ def test_includes_column_range_for_non_tensors_for_magic_methods_with_kwargs(self):
assert stack_info[stack_info.get_first_user_frame_index()].column_range == (36, 43)

def test_includes_column_range_for_list_elements(self):
xs = convert_list_input([1.0, 2.0])
assert xs[0].stack_info[xs[0].stack_info.get_first_user_frame_index()].column_range == (33, 36)
assert xs[1].stack_info[xs[1].stack_info.get_first_user_frame_index()].column_range == (38, 41)
xs = __func_test_convert_list_input__([1.0, 2.0])
assert xs[0].stack_info[xs[0].stack_info.get_first_user_frame_index()].column_range == (47, 50)
assert xs[1].stack_info[xs[1].stack_info.get_first_user_frame_index()].column_range == (52, 55)

def test_includes_column_range_for_tuple_elements(self):
xs = convert_list_input((1.0, 2.0))
assert xs[0].stack_info[xs[0].stack_info.get_first_user_frame_index()].column_range == (33, 36)
assert xs[1].stack_info[xs[1].stack_info.get_first_user_frame_index()].column_range == (38, 41)
xs = __func_test_convert_list_input__((1.0, 2.0))
assert xs[0].stack_info[xs[0].stack_info.get_first_user_frame_index()].column_range == (47, 50)
assert xs[1].stack_info[xs[1].stack_info.get_first_user_frame_index()].column_range == (52, 55)

def test_sync_arg_type_includes_non_tensor_column_range(self):
x, y, z = sync_arg_types(tp.Tensor(3.0, dtype=tp.float16), 3, 4.0)
x, y, z = __func_test_sync_arg_types__(tp.Tensor(3.0, dtype=tp.float16), 3, 4.0)

assert y.dtype == tp.float16
assert z.dtype == tp.float16
assert y.stack_info[y.stack_info.get_first_user_frame_index()].column_range == (67, 68)
assert z.stack_info[z.stack_info.get_first_user_frame_index()].column_range == (70, 73)
assert y.stack_info[y.stack_info.get_first_user_frame_index()].column_range == (81, 82)
assert z.stack_info[z.stack_info.get_first_user_frame_index()].column_range == (84, 87)

def test_sync_arg_type_includes_non_tensor_column_range_with_kwargs(self):
x, y, z = sync_arg_types(tp.Tensor(3.0, dtype=tp.float16), b=3, c=4.0)
x, y, z = __func_test_sync_arg_types__(tp.Tensor(3.0, dtype=tp.float16), b=3, c=4.0)

assert y.dtype == tp.float16
assert z.dtype == tp.float16
assert y.stack_info[y.stack_info.get_first_user_frame_index()].column_range == (67, 70)
assert z.stack_info[z.stack_info.get_first_user_frame_index()].column_range == (72, 77)
assert y.stack_info[y.stack_info.get_first_user_frame_index()].column_range == (81, 84)
assert z.stack_info[z.stack_info.get_first_user_frame_index()].column_range == (86, 91)

def test_sync_arg_type_not_applied_to_tensors(self):
x, y, z = sync_arg_types(
x, y, z = __func_test_sync_arg_types__(
tp.Tensor(3.0),
tp.Tensor(3, dtype=tp.int32),
tp.Tensor(4, dtype=tp.float16),
@@ -249,36 +263,36 @@ def test_sync_arg_type_not_applied_to_tensors(self):
assert z.dtype == tp.float16

def test_sync_arg_type_within_list(self):
xs = sync_within_list([1.0, tp.Tensor(3, dtype=tp.float16), 5])
xs = __func_test_sync_within_list__([1.0, tp.Tensor(3, dtype=tp.float16), 5])

assert xs[0].dtype == tp.float16
assert xs[1].dtype == tp.float16
assert xs[2].dtype == tp.float16

def test_sync_single_arg_type_to_list(self):
_, ys = sync_single_type_to_list(tp.Tensor(5, dtype=tp.int32), [2.0, 3.0, 4.0])
_, ys = __func_test_sync_single_type_to_list__(tp.Tensor(5, dtype=tp.int32), [2.0, 3.0, 4.0])

assert ys[0].dtype == tp.int32
assert ys[1].dtype == tp.int32
assert ys[2].dtype == tp.int32

def test_sync_list_arg_type_to_single_arg(self):
xs, y = sync_list_type_to_single([1.0, tp.Tensor(5, dtype=tp.int32), 4.0], 1.0)
xs, y = __func_test_sync_list_type_to_single__([1.0, tp.Tensor(5, dtype=tp.int32), 4.0], 1.0)

assert xs[0].dtype == tp.int32
assert xs[2].dtype == tp.int32
assert y.dtype == tp.int32

def test_sync_list_arg_types(self):
xs, ys = sync_list_types([1.0, 2.0, 3.0], [3, 4, tp.Tensor(6, dtype=tp.int32)])
xs, ys = __func_test_sync_list_types__([1.0, 2.0, 3.0], [3, 4, tp.Tensor(6, dtype=tp.int32)])

for x in xs:
assert x.dtype == tp.int32
for y in ys:
assert y.dtype == tp.int32

def test_sync_arg_type_list_not_applied_to_tensors(self):
xs = sync_within_list(
xs = __func_test_sync_within_list__(
[tp.Tensor(1.0, dtype=tp.int32), tp.Tensor(3, dtype=tp.float16), tp.Tensor(5, dtype=tp.float32)]
)

@@ -291,30 +305,30 @@ def test_sync_arg_type_invalid(self):
tp.TripyException,
match=r"At least one of the arguments: \('a', 'b', 'c'\) must be a \`tripy.Tensor\`.",
):
x, y, z = sync_arg_types(3.0, 3, 4)
x, y, z = __func_test_sync_arg_types__(3.0, 3, 4)

def test_seq_arg_invalid(self):
with helper.raises(
tp.TripyException,
match=r"Encountered non-number of type str in sequence: hello",
):
_ = func([1, 2, "hello"])
_ = __func_test_basic__([1, 2, "hello"])

def test_nested_seq_inconsistent_len(self):
with helper.raises(
tp.TripyException,
match=r"Expected a sequence of length 3 but got length 4: \[7, 8, 9, 10\]",
):
_ = func([[1, 2, 3], [4, 5, 6], [7, 8, 9, 10]])
_ = __func_test_basic__([[1, 2, 3], [4, 5, 6], [7, 8, 9, 10]])

def test_nested_seq_inconsistent_types(self):
with helper.raises(
tp.TripyException,
match=r"Expected a sequence but got str: hi",
):
_ = func([[1, 2, 3], [4, 5, 6], "hi"])
_ = __func_test_basic__([[1, 2, 3], [4, 5, 6], "hi"])

def test_invalid_argument_type_not_converted(self):
a = np.array([1, 2, 3])
b = func(np.array([1, 2, 3]))
b = __func_test_basic__(np.array([1, 2, 3]))
assert (a == b).all()
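For orientation on the `sync_arg_types` tests above: the option makes converted non-tensor arguments adopt the dtype of a tensor argument in the same group, and raises if no argument in the group is already a tensor. A minimal sketch (`add_like` is an illustrative name):

import tripy as tp
from tripy.frontend.utils import convert_inputs_to_tensors

@convert_inputs_to_tensors(sync_arg_types=[("a", "b")])
def add_like(a, b):
    return a, b

# `2` is converted to a tensor whose dtype is synced to float16 to match `a`.
a, b = add_like(tp.Tensor(1.0, dtype=tp.float16), 2)
assert b.dtype == tp.float16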
9 changes: 9 additions & 0 deletions tripy/tests/frontend/trace/ops/test_binary_elementwise.py
@@ -41,6 +41,9 @@
(Comparison.Kind.GREATER, lambda a, b: a > b),
]

# Ops that do not automatically convert data types
_NON_CONVERTING_OPS = {BinaryElementwise.Kind.MAXIMUM, BinaryElementwise.Kind.MINIMUM}

# Ops that are flipped instead of calling a right-side version.
_FLIP_OPS = {}
for key, val in {
@@ -68,6 +71,12 @@ class TestBinaryElementwise:
)
@pytest.mark.parametrize("kind, func", _BINARY_OPS + _COMPARISON_OPS)
def test_op_funcs(self, kind, func, lhs, rhs, left_side_is_non_tensor):
if kind in _NON_CONVERTING_OPS:
if not isinstance(lhs, tp.Tensor):
lhs = tp.Tensor(lhs)
if not isinstance(rhs, tp.Tensor):
rhs = tp.Tensor(rhs)

out = func(lhs, rhs)
assert isinstance(out, tp.Tensor)
assert isinstance(out.trace_tensor.producer, BinaryElementwise)
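Context for the `_NON_CONVERTING_OPS` special case: after this change, `maximum` and `minimum` no longer convert Python numbers automatically, so their operands must already be tensors. A sketch of the resulting calling convention (assuming the public `tp.maximum` op):

import tripy as tp

a = tp.Tensor([1.0, 5.0])
out = tp.maximum(a, tp.Tensor(3.0))  # rather than tp.maximum(a, 3.0)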
4 changes: 2 additions & 2 deletions tripy/tripy/common/exception.py
@@ -105,9 +105,9 @@ def apply_color(inp, color):


def _make_stack_info_message(stack_info: "utils.StackInfo", enable_color: bool = True) -> Optional[str]:
from tripy.frontend.utils import convert_inputs_to_tensors
from tripy.frontend.utils import convert_inputs_to_tensors, convert_shape_inputs

EXCLUDE_FUNCTIONS = [convert_inputs_to_tensors]
EXCLUDE_FUNCTIONS = [convert_inputs_to_tensors, convert_shape_inputs]

def should_exclude(frame):
for func in EXCLUDE_FUNCTIONS:
4 changes: 2 additions & 2 deletions tripy/tripy/frontend/ops/tensor_initializers.py
@@ -28,7 +28,7 @@


@export.public_api(document_under="operations/initializers")
@frontend_utils.convert_inputs_to_tensors(shape_argument=["shape"], exclude=["dtype"])
@frontend_utils.convert_shape_inputs(["shape"])
@constraints.dtype_info(
dtype_variables={
"T1": ["int32"],
@@ -64,7 +64,7 @@ def ones(


@export.public_api(document_under="operations/initializers")
@frontend_utils.convert_inputs_to_tensors(shape_argument=["shape"], exclude=["dtype"])
@frontend_utils.convert_shape_inputs(["shape"])
@constraints.dtype_info(
dtype_variables={
"T1": ["int32"],
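The user-facing call patterns for the initializers are unchanged by the decorator swap; a brief sketch:

import tripy as tp

# Python sequences are still accepted; `convert_shape_inputs` turns them into a `tripy.Shape`.
t = tp.ones((2, 3))
assert t.shape == [2, 3]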
4 changes: 2 additions & 2 deletions tripy/tripy/frontend/shape.py
@@ -280,7 +280,7 @@ def __mul__(self, other):
# so we should clamp the argument
if other.rank == 0:
other = reshape(other, (1,))
other = Shape(maximum(other, 0)) + [len(self)]
other = Shape(maximum(other, Tensor(0, int32))) + [len(self)]

unsqueezed = unsqueeze(self, 0)
tiled = expand(unsqueezed, other)
@@ -291,7 +291,7 @@
def __rmul__(self, other):
return self.__mul__(other)

@frontend_utils.convert_inputs_to_tensors(shape_argument=["other"])
@frontend_utils.convert_shape_inputs(["other"])
def __eq__(self, other):
from tripy.frontend.trace.ops.reduce import all

Expand Down