Skip to content

Commit

Permalink
Move registry for typing
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky committed Sep 11, 2024
1 parent 9a389f5 commit 9b1df32
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 34 deletions.
14 changes: 7 additions & 7 deletions tripy/tests/test_function_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def func(n: Sequence[Union[int, Sequence[int]]]) -> int:

def test_tensor_literal(self, registry):
@registry("test")
def func(n: "tripy.TensorLiteral.sig"):
def func(n: "tripy.types.tensor_literal"):
return n

assert registry["test"](1) == 1
Expand Down Expand Up @@ -472,7 +472,7 @@ def func(n: Sequence[Union[int, float]]) -> int:

def test_error_tensor_literal_not_sequence(self, registry):
@registry("test")
def func(n: "tripy.TensorLiteral.sig"):
def func(n: "tripy.types.tensor_literal"):
return n

with helper.raises(
Expand All @@ -484,19 +484,19 @@ def func(n: "tripy.TensorLiteral.sig"):
--> \x1b\[38;5;3m{__file__}\x1b\[0m:[0-9]+ in \x1b\[38;5;6mfunc\(\)\x1b\[0m
\|
[0-9]+ \| def func\(n: \"tripy\.TensorLiteral\.sig\"\):
[0-9]+ \| def func\(n: \"tripy\.types\.tensor_literal\"\):
[0-9]+ \| \.\.\.
\|\s
Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Union\[numbers\.Number, typing\.Sequence\[ForwardRef\('tripy\.TensorLiteral\.sig'\)\]\]' but got argument of type: 'str'\.
Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Union\[numbers\.Number, typing\.Sequence\[ForwardRef\('tripy\.types\.tensor_literal'\)\]\]' but got argument of type: 'str'\.
"""
).strip(),
):
registry["test"]("hi")

def test_error_tensor_literal_not_sequence_of_numbers(self, registry):
@registry("test")
def func(n: "tripy.TensorLiteral.sig"):
def func(n: "tripy.types.tensor_literal"):
return n

with helper.raises(
Expand All @@ -508,11 +508,11 @@ def func(n: "tripy.TensorLiteral.sig"):
--> \x1b\[38;5;3m{__file__}\x1b\[0m:[0-9]+ in \x1b\[38;5;6mfunc\(\)\x1b\[0m
\|
[0-9]+ \| def func\(n: \"tripy\.TensorLiteral\.sig\"\):
[0-9]+ \| def func\(n: \"tripy\.types\.tensor_literal\"\):
[0-9]+ \| \.\.\.
\|\s
Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Union\[numbers\.Number, typing\.Sequence\[ForwardRef\('tripy\.TensorLiteral\.sig'\)\]\]' but got argument of type: 'List\[List\[str\]\]'
Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Union\[numbers\.Number, typing\.Sequence\[ForwardRef\('tripy\.types\.tensor_literal'\)\]\]' but got argument of type: 'List\[List\[str\]\]'
"""
).strip(),
):
Expand Down
42 changes: 27 additions & 15 deletions tripy/tripy/frontend/trace/ops/binary_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ def to_flat_ir(self, inputs, outputs):
function_name="__radd__",
)
def __add__(
self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"]
self: Union["tripy.Tensor", "tripy.types.tensor_literal"],
other: Union["tripy.Tensor", "tripy.types.tensor_literal"],
) -> "tripy.Tensor":
"""
Performs an elementwise sum.
Expand Down Expand Up @@ -217,7 +218,8 @@ def __add__(
dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"},
)
def __sub__(
self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"]
self: Union["tripy.Tensor", "tripy.types.tensor_literal"],
other: Union["tripy.Tensor", "tripy.types.tensor_literal"],
) -> "tripy.Tensor":
"""
Performs an elementwise subtraction.
Expand Down Expand Up @@ -250,7 +252,7 @@ def __sub__(
dtype_constraints={"other": "T1", constraints.RETURN_VALUE: "T1"},
)
def __rsub__(
self: "tripy.TensorLiteral.sig", other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"]
self: "tripy.types.tensor_literal", other: Union["tripy.Tensor", "tripy.types.tensor_literal"]
) -> "tripy.Tensor":
"""
Performs an elementwise subtraction.
Expand Down Expand Up @@ -283,7 +285,8 @@ def __rsub__(
dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"},
)
def __pow__(
self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"]
self: Union["tripy.Tensor", "tripy.types.tensor_literal"],
other: Union["tripy.Tensor", "tripy.types.tensor_literal"],
) -> "tripy.Tensor":
"""
Performs an elementwise exponentiation.
Expand Down Expand Up @@ -354,7 +357,8 @@ def __rpow__(self: numbers.Number, other: Union["tripy.Tensor", Any]) -> "tripy.
function_name="__rmul__",
)
def __mul__(
self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"]
self: Union["tripy.Tensor", "tripy.types.tensor_literal"],
other: Union["tripy.Tensor", "tripy.types.tensor_literal"],
) -> "tripy.Tensor":
"""
Performs an elementwise multiplication.
Expand Down Expand Up @@ -387,7 +391,8 @@ def __mul__(
dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"},
)
def __truediv__(
self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"]
self: Union["tripy.Tensor", "tripy.types.tensor_literal"],
other: Union["tripy.Tensor", "tripy.types.tensor_literal"],
) -> "tripy.Tensor":
"""
Performs an elementwise division.
Expand Down Expand Up @@ -419,7 +424,7 @@ def __truediv__(
dtype_variables={"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64"]},
dtype_constraints={"other": "T1", constraints.RETURN_VALUE: "T1"},
)
def __rtruediv__(self: numbers.Number, other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"]) -> "tripy.Tensor":
def __rtruediv__(self: numbers.Number, other: Union["tripy.Tensor", "tripy.types.tensor_literal"]) -> "tripy.Tensor":
"""
Performs an elementwise division.
Expand Down Expand Up @@ -451,7 +456,7 @@ def __rtruediv__(self: numbers.Number, other: Union["tripy.Tensor", "tripy.Tenso
dtype_constraints={"lhs": "T1", "rhs": "T1", constraints.RETURN_VALUE: "T1"},
)
def maximum(
lhs: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], rhs: Union["tripy.Tensor", "tripy.TensorLiteral.sig"]
lhs: Union["tripy.Tensor", "tripy.types.tensor_literal"], rhs: Union["tripy.Tensor", "tripy.types.tensor_literal"]
) -> "tripy.Tensor":
"""
Performs an elementwise maximum.
Expand Down Expand Up @@ -484,7 +489,7 @@ def maximum(
dtype_constraints={"lhs": "T1", "rhs": "T1", constraints.RETURN_VALUE: "T1"},
)
def minimum(
lhs: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], rhs: Union["tripy.Tensor", "tripy.TensorLiteral.sig"]
lhs: Union["tripy.Tensor", "tripy.types.tensor_literal"], rhs: Union["tripy.Tensor", "tripy.types.tensor_literal"]
) -> "tripy.Tensor":
"""
Performs an elementwise minimum.
Expand Down Expand Up @@ -520,7 +525,8 @@ def minimum(
dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"},
)
def __lt__(
self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"]
self: Union["tripy.Tensor", "tripy.types.tensor_literal"],
other: Union["tripy.Tensor", "tripy.types.tensor_literal"],
) -> "tripy.Tensor":
"""
Performs a 'less than' comparison.
Expand Down Expand Up @@ -556,7 +562,8 @@ def __lt__(
dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"},
)
def __le__(
self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"]
self: Union["tripy.Tensor", "tripy.types.tensor_literal"],
other: Union["tripy.Tensor", "tripy.types.tensor_literal"],
) -> "tripy.Tensor":
"""
Performs a 'less than or equal' comparison.
Expand Down Expand Up @@ -592,7 +599,8 @@ def __le__(
dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"},
)
def __eq__(
self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"]
self: Union["tripy.Tensor", "tripy.types.tensor_literal"],
other: Union["tripy.Tensor", "tripy.types.tensor_literal"],
) -> "tripy.Tensor":
"""
Performs an 'equal' comparison.
Expand Down Expand Up @@ -627,7 +635,9 @@ def __eq__(
},
dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"},
)
def __ne__(self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union["tripy.Tensor", Any]) -> "tripy.Tensor":
def __ne__(
self: Union["tripy.Tensor", "tripy.types.tensor_literal"], other: Union["tripy.Tensor", Any]
) -> "tripy.Tensor":
"""
Performs a 'not equal' comparison.
Expand Down Expand Up @@ -662,7 +672,8 @@ def __ne__(self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union[
dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"},
)
def __ge__(
self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"]
self: Union["tripy.Tensor", "tripy.types.tensor_literal"],
other: Union["tripy.Tensor", "tripy.types.tensor_literal"],
) -> "tripy.Tensor":
"""
Performs a 'greater than or equal' comparison.
Expand Down Expand Up @@ -698,7 +709,8 @@ def __ge__(
dtype_constraints={"self": "T1", "other": "T1", constraints.RETURN_VALUE: "T2"},
)
def __gt__(
self: Union["tripy.Tensor", "tripy.TensorLiteral.sig"], other: Union["tripy.Tensor", "tripy.TensorLiteral.sig"]
self: Union["tripy.Tensor", "tripy.types.tensor_literal"],
other: Union["tripy.Tensor", "tripy.types.tensor_literal"],
) -> "tripy.Tensor":
"""
Performs a 'greater than' comparison.
Expand Down
2 changes: 1 addition & 1 deletion tripy/tripy/frontend/trace/ops/dequantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def to_flat_ir(self, inputs, outputs):
@frontend_utils.convert_inputs_to_tensors(exclude=["dtype", "dim"])
def dequantize(
input: "tripy.Tensor",
scale: Union["tripy.Tensor", "tripy.TensorLiteral.sig"],
scale: Union["tripy.Tensor", "tripy.types.tensor_literal"],
dtype: datatype.dtype,
dim: Union[int, Any] = None,
) -> "tripy.Tensor":
Expand Down
2 changes: 1 addition & 1 deletion tripy/tripy/frontend/trace/ops/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def to_flat_ir(self, inputs, outputs):
)
def quantize(
input: "tripy.Tensor",
scale: Union["tripy.Tensor", "tripy.TensorLiteral.sig"],
scale: Union["tripy.Tensor", "tripy.types.tensor_literal"],
dtype: datatype.dtype,
dim: Union[int, Any] = None,
) -> "tripy.Tensor":
Expand Down
23 changes: 13 additions & 10 deletions tripy/tripy/common/types.py → tripy/tripy/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -16,17 +16,20 @@
#

import numbers
import sys
from typing import Union, Sequence

from tripy import export

export.public_api()(sys.modules[__name__])

@export.public_api(document_under="types.rst")
class TensorLiteral:
"""
The `sig` member of this class can be used as a type annotation for tensor literals.
A tensor literal can be a Python number or a sequence of tensor literals
(i.e., a sequence of numbers of any depth).
"""

sig = Union[numbers.Number, Sequence["tripy.TensorLiteral.sig"]]
tensor_literal = export.public_api(
document_under="types.rst",
module=sys.modules[__name__],
symbol="tensor_literal",
)(Union[numbers.Number, Sequence["tripy.types.tensor_literal"]])
"""
Denotes the recursive type annotation for tensor literals.
A tensor literal can be a Python number or a sequence of tensor literals
(i.e., a sequence of numbers of any depth).
"""

0 comments on commit 9b1df32

Please sign in to comment.