From 9b1df32780dd2b84e87f918f4b8bd0db4fae3fd9 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 10 Sep 2024 23:27:57 -0400 Subject: [PATCH] Move registry for typing --- tripy/tests/test_function_registry.py | 14 +++---- .../frontend/trace/ops/binary_elementwise.py | 42 ++++++++++++------- tripy/tripy/frontend/trace/ops/dequantize.py | 2 +- tripy/tripy/frontend/trace/ops/quantize.py | 2 +- tripy/tripy/{common => }/types.py | 23 +++++----- 5 files changed, 49 insertions(+), 34 deletions(-) rename tripy/tripy/{common => }/types.py (58%) diff --git a/tripy/tests/test_function_registry.py b/tripy/tests/test_function_registry.py index 102daf31..5f86a9c1 100644 --- a/tripy/tests/test_function_registry.py +++ b/tripy/tests/test_function_registry.py @@ -318,7 +318,7 @@ def func(n: Sequence[Union[int, Sequence[int]]]) -> int: def test_tensor_literal(self, registry): @registry("test") - def func(n: "tripy.TensorLiteral.sig"): + def func(n: "tripy.types.tensor_literal"): return n assert registry["test"](1) == 1 @@ -472,7 +472,7 @@ def func(n: Sequence[Union[int, float]]) -> int: def test_error_tensor_literal_not_sequence(self, registry): @registry("test") - def func(n: "tripy.TensorLiteral.sig"): + def func(n: "tripy.types.tensor_literal"): return n with helper.raises( @@ -484,11 +484,11 @@ def func(n: "tripy.TensorLiteral.sig"): --> \x1b\[38;5;3m{__file__}\x1b\[0m:[0-9]+ in \x1b\[38;5;6mfunc\(\)\x1b\[0m \| - [0-9]+ \| def func\(n: \"tripy\.TensorLiteral\.sig\"\): + [0-9]+ \| def func\(n: \"tripy\.types\.tensor_literal\"\): [0-9]+ \| \.\.\. \|\s - Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Union\[numbers\.Number, typing\.Sequence\[ForwardRef\('tripy\.TensorLiteral\.sig'\)\]\]' but got argument of type: 'str'\. + Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Union\[numbers\.Number, typing\.Sequence\[ForwardRef\('tripy\.types\.tensor_literal'\)\]\]' but got argument of type: 'str'\. """ ).strip(), ): @@ -496,7 +496,7 @@ def func(n: "tripy.TensorLiteral.sig"): def test_error_tensor_literal_not_sequence_of_numbers(self, registry): @registry("test") - def func(n: "tripy.TensorLiteral.sig"): + def func(n: "tripy.types.tensor_literal"): return n with helper.raises( @@ -508,11 +508,11 @@ def func(n: "tripy.TensorLiteral.sig"): --> \x1b\[38;5;3m{__file__}\x1b\[0m:[0-9]+ in \x1b\[38;5;6mfunc\(\)\x1b\[0m \| - [0-9]+ \| def func\(n: \"tripy\.TensorLiteral\.sig\"\): + [0-9]+ \| def func\(n: \"tripy\.types\.tensor_literal\"\): [0-9]+ \| \.\.\. \|\s - Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Union\[numbers\.Number, typing\.Sequence\[ForwardRef\('tripy\.TensorLiteral\.sig'\)\]\]' but got argument of type: 'List\[List\[str\]\]' + Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Union\[numbers\.Number, typing\.Sequence\[ForwardRef\('tripy\.types\.tensor_literal'\)\]\]' but got argument of type: 'List\[List\[str\]\]' """ ).strip(), ): diff --git a/tripy/tripy/frontend/trace/ops/binary_elementwise.py b/tripy/tripy/frontend/trace/ops/binary_elementwise.py index 0b6a6f19..b26c91cc 100644 --- a/tripy/tripy/frontend/trace/ops/binary_elementwise.py +++ b/tripy/tripy/frontend/trace/ops/binary_elementwise.py @@ -184,7 +184,8 @@ def to_flat_ir(self, inputs, outputs): function_name="__radd__", ) def __add__( - self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"] + self: Union["tripy.Tensor", "tripy.types.tensor_literal"], + other: Union["tripy.Tensor", "tripy.types.tensor_literal"], ) -> "tripy.Tensor": """ Performs an elementwise sum. @@ -217,7 +218,8 @@ def __add__( dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"}, ) def __sub__( - self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"] + self: Union["tripy.Tensor", "tripy.types.tensor_literal"], + other: Union["tripy.Tensor", "tripy.types.tensor_literal"], ) -> "tripy.Tensor": """ Performs an elementwise subtraction. @@ -250,7 +252,7 @@ def __sub__( dtype_constraints={"other": "T1", constraints.RETURN_VALUE: "T1"}, ) def __rsub__( - self: "tripy.TensorLiteral.sig", other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"] + self: "tripy.types.tensor_literal", other: Union["tripy.Tensor", "tripy.types.tensor_literal"] ) -> "tripy.Tensor": """ Performs an elementwise subtraction. @@ -283,7 +285,8 @@ def __rsub__( dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"}, ) def __pow__( - self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"] + self: Union["tripy.Tensor", "tripy.types.tensor_literal"], + other: Union["tripy.Tensor", "tripy.types.tensor_literal"], ) -> "tripy.Tensor": """ Performs an elementwise exponentiation. @@ -354,7 +357,8 @@ def __rpow__(self: numbers.Number, other: Union["tripy.Tensor", Any]) -> "tripy. function_name="__rmul__", ) def __mul__( - self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"] + self: Union["tripy.Tensor", "tripy.types.tensor_literal"], + other: Union["tripy.Tensor", "tripy.types.tensor_literal"], ) -> "tripy.Tensor": """ Performs an elementwise multiplication. @@ -387,7 +391,8 @@ def __mul__( dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"}, ) def __truediv__( - self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"] + self: Union["tripy.Tensor", "tripy.types.tensor_literal"], + other: Union["tripy.Tensor", "tripy.types.tensor_literal"], ) -> "tripy.Tensor": """ Performs an elementwise division. @@ -419,7 +424,7 @@ def __truediv__( dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64"]}, dtype_constraints={"other": "T1", constraints.RETURN_VALUE: "T1"}, ) -def __rtruediv__(self: numbers.Number, other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"]) -> "tripy.Tensor": +def __rtruediv__(self: numbers.Number, other: Union["tripy.Tensor", "tripy.types.tensor_literal"]) -> "tripy.Tensor": """ Performs an elementwise division. @@ -451,7 +456,7 @@ def __rtruediv__(self: numbers.Number, other: Union["tripy.Tensor", "tripy.Tenso dtype_constraints={"lhs": "T1", "rhs": "T1", constraints.RETURN_VALUE: "T1"}, ) def maximum( - lhs: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], rhs: Union["tripy.Tensor", "tripy.TensorLiteral.sig"] + lhs: Union["tripy.Tensor", "tripy.types.tensor_literal"], rhs: Union["tripy.Tensor", "tripy.types.tensor_literal"] ) -> "tripy.Tensor": """ Performs an elementwise maximum. @@ -484,7 +489,7 @@ def maximum( dtype_constraints={"lhs": "T1", "rhs": "T1", constraints.RETURN_VALUE: "T1"}, ) def minimum( - lhs: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], rhs: Union["tripy.Tensor", "tripy.TensorLiteral.sig"] + lhs: Union["tripy.Tensor", "tripy.types.tensor_literal"], rhs: Union["tripy.Tensor", "tripy.types.tensor_literal"] ) -> "tripy.Tensor": """ Performs an elementwise minimum. @@ -520,7 +525,8 @@ def minimum( dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"}, ) def __lt__( - self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"] + self: Union["tripy.Tensor", "tripy.types.tensor_literal"], + other: Union["tripy.Tensor", "tripy.types.tensor_literal"], ) -> "tripy.Tensor": """ Performs a 'less than' comparison. @@ -556,7 +562,8 @@ def __lt__( dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"}, ) def __le__( - self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"] + self: Union["tripy.Tensor", "tripy.types.tensor_literal"], + other: Union["tripy.Tensor", "tripy.types.tensor_literal"], ) -> "tripy.Tensor": """ Performs a 'less than or equal' comparison. @@ -592,7 +599,8 @@ def __le__( dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"}, ) def __eq__( - self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"] + self: Union["tripy.Tensor", "tripy.types.tensor_literal"], + other: Union["tripy.Tensor", "tripy.types.tensor_literal"], ) -> "tripy.Tensor": """ Performs an 'equal' comparison. @@ -627,7 +635,9 @@ def __eq__( }, dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"}, ) -def __ne__(self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union["tripy.Tensor", Any]) -> "tripy.Tensor": +def __ne__( + self: Union["tripy.Tensor", "tripy.types.tensor_literal"], other: Union["tripy.Tensor", Any] +) -> "tripy.Tensor": """ Performs a 'not equal' comparison. @@ -662,7 +672,8 @@ def __ne__(self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union[ dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"}, ) def __ge__( - self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"] + self: Union["tripy.Tensor", "tripy.types.tensor_literal"], + other: Union["tripy.Tensor", "tripy.types.tensor_literal"], ) -> "tripy.Tensor": """ Performs a 'greater than or equal' comparison. @@ -698,7 +709,8 @@ def __ge__( dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"}, ) def __gt__( - self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"] + self: Union["tripy.Tensor", "tripy.types.tensor_literal"], + other: Union["tripy.Tensor", "tripy.types.tensor_literal"], ) -> "tripy.Tensor": """ Performs a 'greater than' comparison. diff --git a/tripy/tripy/frontend/trace/ops/dequantize.py b/tripy/tripy/frontend/trace/ops/dequantize.py index a2f3d946..008004bd 100644 --- a/tripy/tripy/frontend/trace/ops/dequantize.py +++ b/tripy/tripy/frontend/trace/ops/dequantize.py @@ -111,7 +111,7 @@ def to_flat_ir(self, inputs, outputs): @frontend_utils.convert_inputs_to_tensors(exclude=["dtype", "dim"]) def dequantize( input: "tripy.Tensor", - scale: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], + scale: Union["tripy.Tensor", "tripy.types.tensor_literal"], dtype: datatype.dtype, dim: Union[int, Any] = None, ) -> "tripy.Tensor": diff --git a/tripy/tripy/frontend/trace/ops/quantize.py b/tripy/tripy/frontend/trace/ops/quantize.py index e0bc210e..8855dcab 100644 --- a/tripy/tripy/frontend/trace/ops/quantize.py +++ b/tripy/tripy/frontend/trace/ops/quantize.py @@ -135,7 +135,7 @@ def to_flat_ir(self, inputs, outputs): ) def quantize( input: "tripy.Tensor", - scale: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], + scale: Union["tripy.Tensor", "tripy.types.tensor_literal"], dtype: datatype.dtype, dim: Union[int, Any] = None, ) -> "tripy.Tensor": diff --git a/tripy/tripy/common/types.py b/tripy/tripy/types.py similarity index 58% rename from tripy/tripy/common/types.py rename to tripy/tripy/types.py index bfc19044..827fde9b 100644 --- a/tripy/tripy/common/types.py +++ b/tripy/tripy/types.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2024-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,17 +16,20 @@ # import numbers +import sys from typing import Union, Sequence from tripy import export +export.public_api()(sys.modules[__name__]) -@export.public_api(document_under="types.rst") -class TensorLiteral: - """ - The `sig` member of this class can be used as a type annotation for tensor literals. - A tensor literal can be a Python number or a sequence of tensor literals - (i.e., a sequence of numbers of any depth). - """ - - sig = Union[numbers.Number, Sequence["tripy.TensorLiteral.sig"]] +tensor_literal = export.public_api( + document_under="types.rst", + module=sys.modules[__name__], + symbol="tensor_literal", +)(Union[numbers.Number, Sequence["tripy.types.tensor_literal"]]) +""" +Denotes the recursive type annotation for tensor literals. +A tensor literal can be a Python number or a sequence of tensor literals +(i.e., a sequence of numbers of any depth). +"""