From 5f0e9df977282bb65aef2bf517c825bf5c34fc7d Mon Sep 17 00:00:00 2001 From: pranavm Date: Wed, 11 Sep 2024 17:16:32 -0700 Subject: [PATCH 1/4] Adds implementation of cumsum --- tripy/docs/conf.py | 4 +- tripy/tests/frontend/ops/test_cumsum.py | 24 ++++++ tripy/tests/integration/test_cumsum.py | 40 ++++++++++ tripy/tripy/frontend/ops/cumsum.py | 100 ++++++++++++++++++++++++ 4 files changed, 166 insertions(+), 2 deletions(-) create mode 100644 tripy/tests/frontend/ops/test_cumsum.py create mode 100644 tripy/tests/integration/test_cumsum.py create mode 100644 tripy/tripy/frontend/ops/cumsum.py diff --git a/tripy/docs/conf.py b/tripy/docs/conf.py index 2e87a5ed..376962dc 100644 --- a/tripy/docs/conf.py +++ b/tripy/docs/conf.py @@ -237,14 +237,14 @@ def process_docstring(app, what, name, obj, options, lines): if TYPE_VERIFICATION[unqual_name].dtype_constraints.get(param_name, None): add_text_index = re.search(r":param \w+: ", block).span()[1] blocks[index] = ( - f"{block[0:add_text_index]}[dtype=\ **{TYPE_VERIFICATION[unqual_name].dtype_constraints[param_name]}**\ ] {block[add_text_index:]}" + f"{block[0:add_text_index]}[*dtype=*\ **{TYPE_VERIFICATION[unqual_name].dtype_constraints[param_name]}**\ ] {block[add_text_index:]}" ) if TYPE_VERIFICATION[unqual_name].return_dtype is not None and re.search(r":returns:", block): add_text_index = re.search(r":returns:", block).span()[1] + 1 # Add dtype constraint to start of returns description. 
blocks[index] = ( - f"{block[0:add_text_index]}[dtype=\ **{TYPE_VERIFICATION[unqual_name].return_dtype}**\ ] {block[add_text_index:]}" + f"{block[0:add_text_index]}[*dtype=*\ **{TYPE_VERIFICATION[unqual_name].return_dtype}**\ ] {block[add_text_index:]}" ) seen_classes.add(name) diff --git a/tripy/tests/frontend/ops/test_cumsum.py b/tripy/tests/frontend/ops/test_cumsum.py new file mode 100644 index 00000000..06ca5cd9 --- /dev/null +++ b/tripy/tests/frontend/ops/test_cumsum.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tests import helper +import tripy as tp + + +class TestCumsum: + def test_invalid_dim_fails(self): + a = tp.ones((2, 2)) + with helper.raises(tp.TripyException, "Dimension argument is out of bounds."): + tp.cumsum(a, dim=4) diff --git a/tripy/tests/integration/test_cumsum.py b/tripy/tests/integration/test_cumsum.py new file mode 100644 index 00000000..c8f8bbb7 --- /dev/null +++ b/tripy/tests/integration/test_cumsum.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +import tripy as tp + + +class TestCumsum: + @pytest.mark.parametrize( + "data,dim,expected", + [ + ([0, 1, 2, 3], 0, [0, 1, 3, 6]), + # Negative dim: + ([[2, 3], [4, 5]], -2, [[2, 3], [6, 8]]), + # Non-innermost dim: + ([[2, 3], [4, 5]], 0, [[2, 3], [6, 8]]), + # >2D (can potentially find transposition bugs) + ([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], 0, [[[1, 2], [3, 4]], [[6, 8], [10, 12]]]), + ], + ) + def test_cumsum(self, data, dim, expected): + inp = tp.Tensor(data, dtype=tp.float32) + + out = tp.cumsum(inp, dim=dim) + + expected = tp.Tensor(expected, dtype=tp.float32) + assert tp.allclose(out, expected) + assert out.shape == expected.shape diff --git a/tripy/tripy/frontend/ops/cumsum.py b/tripy/tripy/frontend/ops/cumsum.py new file mode 100644 index 00000000..f0a76719 --- /dev/null +++ b/tripy/tripy/frontend/ops/cumsum.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from tripy import constraints, export + +from tripy.frontend import utils as frontend_utils + + +@export.public_api(document_under="operations/functions") +@constraints.dtype_info( + dtype_variables={ + "T1": ["float32", "float16", "bfloat16", "float8", "int8", "int32", "int64", "bool"], + }, + dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, +) +@frontend_utils.process_dim +def cumsum(input: "tripy.Tensor", dim: int) -> "tripy.Tensor": + """ + Computes the cumulative sum of elements in the input along the dimension ``dim``. + + Args: + input: The input tensor. + dim: The dimension along which to compute the cumulative sum. + + Returns: + A tensor of the same shape as the input. + + .. code-block:: python + :linenos: + :caption: 1D tensor + + input = tp.arange(4, 0, step=-1, dtype=tp.int32) + output = tp.cumsum(input, dim=0) + + assert cp.array_equal(cp.cumsum(cp.from_dlpack(input)), cp.from_dlpack(output)) + + .. code-block:: python + :linenos: + :caption: 2D tensor + + input = tp.reshape(tp.arange(9, 0, step=-1, dtype=tp.int32), (3, 3)) + output = tp.cumsum(input, dim=0) + + assert cp.array_equal(cp.cumsum(cp.from_dlpack(input), axis=0), cp.from_dlpack(output)) + """ + # Consider: + # + # a = [3, 2, 1] + # + # then, we can implement cumsum as: + # + # out = a @ [[1, 1, 1] + # [0, 1, 1] + # [0, 0, 1]] + # + # which will yield: + # + # out = [3, 3 + 2, 3 + 2 + 1] + # + # In the general case where `a` is an N-dimensional tensor, we simply transpose + # the dimension of interest to the innermost position and then carry out the + # GEMM described above, then transpose the output back. + from tripy.frontend.trace.ops.permute import permute + from tripy.frontend.ops.tensor_initializers import triu, ones + + # For the examples in the comments that follow, assume the input shape is (3, 5, 7) and + # we are applying cumsum over dim=1 (the dimension of length 5). 
+ + # Swap dim to innermost position: (3, 5, 7) -> (3, 7, 5) + move_to_innermost_perm = list(range(input.rank)) + del move_to_innermost_perm[dim] + move_to_innermost_perm.append(dim) + transposed = permute(input, move_to_innermost_perm) + + # GEMM with square upper triangular matrix: (3, 7, 5) @ (5, 5) -> (3, 7, 5) + + # TODO: We should replace this with: + # shape = transposed.shape[-1:] * 2 + # once the relevant shape inference bugs are fixed. + shape = (transposed.shape[input.rank - 1], transposed.shape[input.rank - 1]) + out = transposed @ triu(ones(shape=shape, dtype=transposed.dtype)) + + # Swap innermost position back to `dim`: (3, 7, 5) -> (3, 5, 7) + reset_dim_perm = list(range(input.rank)) + del reset_dim_perm[-1] + reset_dim_perm.insert(dim, input.rank - 1) + out = permute(out, reset_dim_perm) + + return out From 1b708b385bc1f76f911d1757b3fd53b8e821c48a Mon Sep 17 00:00:00 2001 From: pranavm Date: Wed, 11 Sep 2024 17:26:23 -0700 Subject: [PATCH 2/4] Removes manual checks for INT64 data types from various ops --- .../spec_verification/object_builders.py | 69 ++++++++++--------- tripy/tripy/frontend/ops/allclose.py | 4 -- tripy/tripy/frontend/ops/cumsum.py | 2 +- .../tripy/frontend/ops/tensor_initializers.py | 16 ----- tripy/tripy/frontend/trace/ops/flip.py | 4 -- tripy/tripy/frontend/trace/ops/gather.py | 4 -- tripy/tripy/frontend/trace/ops/iota.py | 8 --- tripy/tripy/frontend/trace/ops/matmul.py | 4 -- tripy/tripy/frontend/trace/ops/reduce.py | 16 ----- .../frontend/trace/ops/unary_elementwise.py | 5 -- 10 files changed, 39 insertions(+), 93 deletions(-) diff --git a/tripy/tests/spec_verification/object_builders.py b/tripy/tests/spec_verification/object_builders.py index 4b289721..6d86d737 100644 --- a/tripy/tests/spec_verification/object_builders.py +++ b/tripy/tests/spec_verification/object_builders.py @@ -81,43 +81,19 @@ def default_builder(init, dtype, namespace): All other types do not have defaults and must be passed to the verifier using 
default_constraints_all. """ default_constraints_all = { - "__rtruediv__": {"self": 1}, - "__rsub__": {"self": 1}, + "__getitem__": {"index": 2}, + "__matmul__": {"self": tp.ones((2, 3))}, "__radd__": {"self": 1}, - "__rpow__": {"self": 1}, "__rmul__": {"self": 1}, - "softmax": {"dim": 1}, - "concatenate": {"dim": 0}, - "expand": {"sizes": tp.Tensor([3, 4]), "input": tp.ones((3, 1))}, - "full": {"shape": tp.Tensor([3]), "value": 1}, - "full_like": {"value": 1}, - "flip": {"dim": 1}, - "gather": {"dim": 0, "index": tp.Tensor([1])}, - "iota": {"shape": tp.Tensor([4])}, - "__matmul__": {"self": tp.ones((2, 3))}, - "transpose": {"dim0": 0, "dim1": 1}, - "permute": {"perm": [1, 0]}, - "quantize": {"scale": tp.Tensor([1, 1, 1]), "dim": 0}, - "dequantize": {"scale": tp.Tensor([1, 1, 1]), "dim": 0}, - "sum": {"dim": 0}, + "__rpow__": {"self": 1}, + "__rsub__": {"self": 1}, + "__rtruediv__": {"self": 1}, "all": {"dim": 0}, "any": {"dim": 0}, - "max": {"dim": 0}, - "prod": {"dim": 0}, - "mean": {"dim": 0}, - "var": {"dim": 0}, + "arange": {"start": 0, "stop": 5}, "argmax": {"dim": 0}, "argmin": {"dim": 0}, - "reshape": {"shape": tp.Tensor([6])}, - "squeeze": {"input": tp.ones((3, 1)), "dims": (1)}, - "__getitem__": {"index": 2}, - "split": {"indices_or_sections": 2}, - "unsqueeze": {"dim": 1}, - "masked_fill": {"value": 1}, - "ones": {"shape": tp.Tensor([3, 2])}, - "zeros": {"shape": tp.Tensor([3, 2])}, - "arange": {"start": 0, "stop": 5}, - "repeat": {"repeats": 2, "dim": 0}, + "concatenate": {"dim": 0}, "convolution": { "input": tp.ones((1, 3, 5, 5)), "weight": tp.ones((1, 3, 3, 3)), @@ -127,6 +103,31 @@ def default_builder(init, dtype, namespace): "lhs_dilation": [1, 1], "rhs_dilation": [1, 1], }, + "cumsum": {"dim": 0}, + "dequantize": {"scale": tp.Tensor([1, 1, 1]), "dim": 0}, + "expand": {"sizes": tp.Tensor([3, 4]), "input": tp.ones((3, 1))}, + "flip": {"dim": 1}, + "full_like": {"value": 1}, + "full": {"shape": tp.Tensor([3]), "value": 1}, + "gather": {"dim": 0, 
"index": tp.Tensor([1])}, + "iota": {"shape": tp.Tensor([4])}, + "masked_fill": {"value": 1}, + "max": {"dim": 0}, + "mean": {"dim": 0}, + "ones": {"shape": tp.Tensor([3, 2])}, + "permute": {"perm": [1, 0]}, + "prod": {"dim": 0}, + "quantize": {"scale": tp.Tensor([1, 1, 1]), "dim": 0}, + "repeat": {"repeats": 2, "dim": 0}, + "reshape": {"shape": tp.Tensor([6])}, + "softmax": {"dim": 1}, + "split": {"indices_or_sections": 2}, + "squeeze": {"input": tp.ones((3, 1)), "dims": (1)}, + "sum": {"dim": 0}, + "transpose": {"dim0": 0, "dim1": 1}, + "unsqueeze": {"dim": 1}, + "var": {"dim": 0}, + "zeros": {"shape": tp.Tensor([3, 2])}, } @@ -137,24 +138,30 @@ def create_obj(func_obj, func_name, param_name, param_dtype, namespace): param_dict = func_sig.parameters param_type_annot = param_dict[param_name] init = None + # Check if there is a value in default_constraints_all for func_name and param_name and use it. default_constraints = default_constraints_all.get(func_name, None) if default_constraints != None: other_constraint = default_constraints.get(param_name, None) if other_constraint is not None: init = other_constraint + # If parameter had a default then use it otherwise skip. if init is None and param_type_annot.default is not param_type_annot.empty: # Checking if not equal to None since default can be 0 or similar. if param_type_annot.default != None: init = param_type_annot.default + param_type = param_type_annot.annotation while get_origin(param_type) in [Union, Optional]: param_type = get_args(param_type)[0] # ForwardRef refers to any case where type hint is a string. 
if isinstance(param_type, ForwardRef): param_type = param_type.__forward_arg__ + create_obj_func = find_func.get(param_type, default_builder) if create_obj_func: namespace[param_name] = create_obj_func(init, param_dtype, namespace) return namespace[param_name] + + assert False, f"Could not create parameter: {param_name}" diff --git a/tripy/tripy/frontend/ops/allclose.py b/tripy/tripy/frontend/ops/allclose.py index 1fd1f6e9..bbe6d855 100644 --- a/tripy/tripy/frontend/ops/allclose.py +++ b/tripy/tripy/frontend/ops/allclose.py @@ -57,10 +57,6 @@ def allclose(a: "tripy.Tensor", b: "tripy.Tensor", rtol: float = 1e-05, atol: fl """ from tripy.frontend.trace.ops.unary_elementwise import abs from tripy.frontend.trace.ops.reduce import all - from tripy.common.datatype import int64 - - if a.dtype == int64: - raise_error("Known issue with i64. Allclose currently does not work with int64 inputs. Issue #116") compare = abs(a - b) <= (atol + rtol * abs(b)) return bool(all(compare)) diff --git a/tripy/tripy/frontend/ops/cumsum.py b/tripy/tripy/frontend/ops/cumsum.py index f0a76719..6b017b2a 100644 --- a/tripy/tripy/frontend/ops/cumsum.py +++ b/tripy/tripy/frontend/ops/cumsum.py @@ -20,7 +20,7 @@ @export.public_api(document_under="operations/functions") @constraints.dtype_info( dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "float8", "int8", "int32", "int64", "bool"], + "T1": ["float32", "float16", "bfloat16", "float8", "int32"], }, dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, ) diff --git a/tripy/tripy/frontend/ops/tensor_initializers.py b/tripy/tripy/frontend/ops/tensor_initializers.py index d148b4f1..d3386b52 100644 --- a/tripy/tripy/frontend/ops/tensor_initializers.py +++ b/tripy/tripy/frontend/ops/tensor_initializers.py @@ -218,10 +218,6 @@ def tril(tensor: "tripy.Tensor", diagonal: int = 0) -> "tripy.Tensor": assert np.array_equal(cp.from_dlpack(output).get(), np.tril(cp.from_dlpack(input).get(), -1)) """ - from tripy.common.datatype 
import int64 - - if tensor.dtype == int64: - raise_error("Known issue with i64. Tril currently does not work with int64 inputs.") tri_mask = (iota_like(tensor, -2, datatype.int32) + full_like(tensor, diagonal, datatype.int32)) >= iota_like( tensor, -1, datatype.int32 ) @@ -280,10 +276,6 @@ def triu(tensor: "tripy.Tensor", diagonal: int = 0) -> "tripy.Tensor": assert np.array_equal(cp.from_dlpack(output).get(), np.triu(cp.from_dlpack(input).get(), -1)) """ - from tripy.common.datatype import int64 - - if tensor.dtype == int64: - raise_error("Known issue with i64. Triu currently does not work with int64 inputs.") tri_mask = (iota_like(tensor, -2, datatype.int32) + full_like(tensor, diagonal, datatype.int32)) <= iota_like( tensor, -1, datatype.int32 ) @@ -330,10 +322,6 @@ def arange( assert tp.allclose(output, tp.Tensor(np.arange(2.3, 0.8, -0.2, dtype=np.float32))) """ - from tripy.common.datatype import int64 - - if dtype == int64: - raise_error("Known issue with i64. Arange currently does not work with int64 inputs.") if step == 0: raise_error("Step in arange cannot be 0.", []) @@ -377,8 +365,4 @@ def arange(stop: numbers.Number, dtype: "tripy.dtype" = datatype.float32) -> "tr assert (cp.from_dlpack(output).get() == np.arange(5, dtype=np.float32)).all() """ - from tripy.common.datatype import int64 - - if dtype == int64: - raise_error("Known issue with i64. Arange currently does not work with int64 inputs. 
Issue #116") return arange(0, stop, dtype=dtype) diff --git a/tripy/tripy/frontend/trace/ops/flip.py b/tripy/tripy/frontend/trace/ops/flip.py index 3ea162a6..d0947744 100644 --- a/tripy/tripy/frontend/trace/ops/flip.py +++ b/tripy/tripy/frontend/trace/ops/flip.py @@ -79,10 +79,6 @@ def flip(input: "tripy.Tensor", dims: Optional[Union[int, Sequence[int]]] = None output = tp.flip(input, dims=-1) assert cp.array_equal(cp.from_dlpack(output), cp.array([[4, 3, 2, 1, 0], [9, 8, 7, 6, 5]])) """ - from tripy.common.datatype import int64 - - if input.dtype == int64: - raise_error("Known issue with i64. Flip currently does not work with int64 inputs. Issue #116") rank = input.rank if dims is None: dims = [d for d in range(rank)] diff --git a/tripy/tripy/frontend/trace/ops/gather.py b/tripy/tripy/frontend/trace/ops/gather.py index 63b99e73..59f55470 100644 --- a/tripy/tripy/frontend/trace/ops/gather.py +++ b/tripy/tripy/frontend/trace/ops/gather.py @@ -129,8 +129,4 @@ def gather(input: "tripy.Tensor", dim: int, index: "tripy.Tensor") -> "tripy.Ten assert np.array_equal(cp.from_dlpack(output).get(), np.take(cp.from_dlpack(data).get(), cp.from_dlpack(indices).get(), axis=1)) """ - from tripy.common.datatype import int64 - - if input.dtype == int64: - raise_error("Known issue with i64. Gather currently does not work with int64 inputs. Issue #116") return Gather.build([input, index], dim) diff --git a/tripy/tripy/frontend/trace/ops/iota.py b/tripy/tripy/frontend/trace/ops/iota.py index d0d981c7..7ba0ad77 100644 --- a/tripy/tripy/frontend/trace/ops/iota.py +++ b/tripy/tripy/frontend/trace/ops/iota.py @@ -115,10 +115,6 @@ def iota( assert np.array_equal(cp.from_dlpack(output).get(), np.arange(0, 3, dtype=np.float32)) """ - from tripy.common.datatype import int64 - - if dtype == int64: - raise_error("Known issue with i64. Iota currently does not work with int64 inputs. 
Issue #116") output_rank = len(shape) if isinstance(shape, Sequence) else None return iota_impl(shape, dim, dtype, output_rank) @@ -154,8 +150,4 @@ def iota_like(input: "tripy.Tensor", dim: int = 0, dtype: Optional[datatype.dtyp assert np.array_equal(cp.from_dlpack(output).get(), np.arange(0, 3, dtype=np.float32)) """ - from tripy.common.datatype import int64 - - if dtype == int64: - raise_error("Known issue with i64. Iota_like currently does not work with int64 inputs. Issue #116") return iota_impl(input.shape, dim, utils.default(dtype, input.dtype), input.rank) diff --git a/tripy/tripy/frontend/trace/ops/matmul.py b/tripy/tripy/frontend/trace/ops/matmul.py index 3d019ba6..6dd7b0a8 100644 --- a/tripy/tripy/frontend/trace/ops/matmul.py +++ b/tripy/tripy/frontend/trace/ops/matmul.py @@ -237,8 +237,4 @@ def __matmul__(self: "tripy.Tensor", other: "tripy.Tensor") -> "tripy.Tensor": output = a @ b assert np.array_equal(cp.from_dlpack(output).get(), cp.from_dlpack(a).get() @ cp.from_dlpack(b).get()) """ - from tripy.common.datatype import int64 - - if other.dtype == int64: - raise_error("Known issue with i64. __matmul__ currently does not work with int64 inputs. Issue #116") return MatrixMultiplication.build([self, other]) diff --git a/tripy/tripy/frontend/trace/ops/reduce.py b/tripy/tripy/frontend/trace/ops/reduce.py index c3f6d52c..cfc33d72 100644 --- a/tripy/tripy/frontend/trace/ops/reduce.py +++ b/tripy/tripy/frontend/trace/ops/reduce.py @@ -167,10 +167,6 @@ def sum( assert np.array_equal(cp.from_dlpack(output).get(), np.sum(np.arange(6, dtype=np.float32).reshape((2, 3)), 0)) """ - from tripy.common.datatype import int64 - - if input.dtype == int64: - raise_error("Known issue with i64. Sum currently does not work with int64 inputs. 
Issue #116") return _reduce_impl(input, Reduce.Kind.SUM, dim, keepdim) @@ -274,10 +270,6 @@ def max( assert np.array_equal(cp.from_dlpack(output).get(), np.max(np.arange(6, dtype=np.float32).reshape((2, 3)), 0)) """ - from tripy.common.datatype import int64 - - if input.dtype == int64: - raise_error("Known issue with i64. Max currently does not work with int64 inputs. Issue #116") return _reduce_impl(input, Reduce.Kind.MAX, dim, keepdim) @@ -313,10 +305,6 @@ def prod( assert np.array_equal(cp.from_dlpack(output).get(), np.prod(np.arange(6, dtype=np.float32).reshape((2, 3)), 0)) """ - from tripy.common.datatype import int64 - - if input.dtype == int64: - raise_error("Known issue with i64. Prod currently does not work with int64 inputs. Issue #116") return _reduce_impl(input, Reduce.Kind.MUL, dim, keepdim) @@ -374,10 +362,6 @@ def mean( assert np.array_equal(cp.from_dlpack(output).get(), np.mean(np.arange(6, dtype=np.float32).reshape((2, 3)), axis=1, keepdims=True)) """ - from tripy.common.datatype import int64 - - if input.dtype == int64: - raise_error("Known issue with i64. Mean currently does not work with int64 inputs. Issue #116") return mean_impl(input, dim, keepdim) diff --git a/tripy/tripy/frontend/trace/ops/unary_elementwise.py b/tripy/tripy/frontend/trace/ops/unary_elementwise.py index 8aad7c59..25a2041c 100644 --- a/tripy/tripy/frontend/trace/ops/unary_elementwise.py +++ b/tripy/tripy/frontend/trace/ops/unary_elementwise.py @@ -287,9 +287,4 @@ def abs(input: "tripy.Tensor") -> "tripy.Tensor": assert np.array_equal(cp.from_dlpack(output).get(), np.array([1, 2], dtype=np.float32)) """ - from tripy.frontend import Tensor - from tripy.common.datatype import int64 - - if input.dtype == int64: - raise_error("Known issue with i64. Abs currently does not work with int64 inputs. 
Issue #116") return UnaryElementwise.build([input], UnaryElementwise.Kind.ABS) From 6da16dcc13233ffe8cfb8e50849215b4cc53f0e8 Mon Sep 17 00:00:00 2001 From: pranavm Date: Wed, 11 Sep 2024 17:50:37 -0700 Subject: [PATCH 3/4] Cleans up display of type constraints and exceptions --- tripy/docs/README.md | 2 +- tripy/docs/conf.py | 32 +++++++++++++++++--------------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/tripy/docs/README.md b/tripy/docs/README.md index a2a55435..2fac8d95 100644 --- a/tripy/docs/README.md +++ b/tripy/docs/README.md @@ -7,7 +7,7 @@ This directory includes all the source files for the public API documentation. You can build the documentation locally in the development container by running: ```bash python3 docs/generate_rsts.py -sphinx-build build/doc_sources build/docs -c docs/ -j auto -W +sphinx-build build/doc_sources build/docs -c docs/ -j 4 -W ``` To view the documentation, you can open `build/docs/index.html` in a browser. diff --git a/tripy/docs/conf.py b/tripy/docs/conf.py index 376962dc..e48f1f3d 100644 --- a/tripy/docs/conf.py +++ b/tripy/docs/conf.py @@ -195,14 +195,19 @@ def process_docstring(app, what, name, obj, options, lines): if unqual_name in TYPE_VERIFICATION: add_text_index = -1 for index, block in enumerate(blocks): + + def insert_block(text): + nonlocal index + + blocks.insert(index, text) + index += 1 + if re.search(r".. code-block::", block): type_dict = TYPE_VERIFICATION[unqual_name].dtypes - blocks.insert(index, "Type Constraints:") - index += 1 + insert_block("TYPE CONSTRAINTS:") # Add the dtype constraint name and the dtypes that correlate. 
for type_name, dt in type_dict.items(): - blocks.insert( - index, + insert_block( f" - **{type_name}**: :class:`" + "`, :class:`".join( sorted( @@ -215,20 +220,17 @@ def process_docstring(app, what, name, obj, options, lines): ) + "`", ) - index += 1 - blocks.insert(index, "\n") - if TYPE_VERIFICATION[unqual_name].dtype_exceptions != []: + insert_block("\n") + + if TYPE_VERIFICATION[unqual_name].dtype_exceptions: # Add the dtype exceptions. - index += 1 - blocks.insert(index, "**Unsupported Type Combinations**:") - dtype_exception_text = [] + insert_block("UNSUPPORTED TYPE COMBINATIONS:") for exception_dict in TYPE_VERIFICATION[unqual_name].dtype_exceptions: - dtype_exception_text.append( - ", ".join([f"{key}: :class:`{val}`" for key, val in exception_dict.items()]) + insert_block( + " - " + + ", ".join([f"**{key}**\ =\ :class:`{val}`" for key, val in exception_dict.items()]), ) - dtype_exception_text = "; ".join(dtype_exception_text) + "\n" - index += 1 - blocks.insert(index, dtype_exception_text) + insert_block("\n") break if re.search(r":param \w+: ", block): From 251102e4c5577c14c87d211aca8f1a9ee279494c Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Mon, 16 Sep 2024 09:58:45 -0700 Subject: [PATCH 4/4] Add scalar shape class to Tripy (#202) This PR adds `ShapeScalar` to Tripy, which encodes a value that is sliced out of a `Shape` tensor. 
--- tripy/tests/frontend/test_shape.py | 24 ++++++++ tripy/tripy/frontend/shape.py | 57 ++++++++++++++++++- tripy/tripy/frontend/trace/ops/base.py | 33 +++++++---- .../frontend/trace/ops/binary_elementwise.py | 9 ++- tripy/tripy/frontend/trace/ops/cast.py | 4 +- tripy/tripy/frontend/trace/ops/expand.py | 4 +- tripy/tripy/frontend/trace/ops/matmul.py | 2 +- tripy/tripy/frontend/trace/ops/reshape.py | 4 +- tripy/tripy/frontend/trace/ops/shape.py | 2 +- tripy/tripy/frontend/trace/ops/slice.py | 4 +- tripy/tripy/frontend/trace/ops/utils.py | 6 +- tripy/tripy/frontend/trace/ops/where.py | 4 +- 12 files changed, 122 insertions(+), 31 deletions(-) diff --git a/tripy/tests/frontend/test_shape.py b/tripy/tests/frontend/test_shape.py index e9728c95..c852426a 100644 --- a/tripy/tests/frontend/test_shape.py +++ b/tripy/tests/frontend/test_shape.py @@ -40,6 +40,30 @@ def other_values(request): return request.param +class TestShapeScalar: + @pytest.mark.parametrize("value", [1, tp.Tensor(1), np.array(2)]) + def test_scalar_shape(self, value): + s = tp.ShapeScalar(values) + + assert isinstance(s, tp.ShapeScalar) + assert s.trace_tensor.producer.inputs == [] + + def test_scalar_slice(self): + a = tp.iota((3, 3)) + assert isinstance(a.shape[0], tp.ShapeScalar) + + s = a.shape[0] * a.shape[1] + b = tp.reshape(a, tp.reshape(s, (1,))) + assert tp.allclose(tp.flatten(a), b) + + def test_scalar_scalar_op(self): + a = tp.iota((3, 4)) + s1 = a.shape[0] + s2 = a.shape[1] + s = s1 + s2 + assert isinstance(s, tp.ShapeScalar) + + class TestShape: def test_shape(self, values): s = tp.Shape(values) diff --git a/tripy/tripy/frontend/shape.py b/tripy/tripy/frontend/shape.py index 7650f704..86201ad9 100644 --- a/tripy/tripy/frontend/shape.py +++ b/tripy/tripy/frontend/shape.py @@ -24,6 +24,62 @@ import tripy.frontend.utils as frontend_utils +@export.public_api() +class ShapeScalar(Tensor): + """ + Scalar shape is a tensor used to represent a scalar value extracted from a shape tensor. 
+ ShapeScalars are scalars (rank 0) that hold a non-negative integer (using int32 as the datatype). + """ + + def __init__( + self, + data: Union[Sequence, Tensor, "np.ndarray", "cp.ndarray", "torch.Tensor", "jnp.ndarray"], + name: Optional[str] = None, + ) -> None: + r""" + Args: + data: The value of the ShapeScalar, which should be a scalar integer. + name: An optional name + """ + + from tripy.common.exception import raise_error + + if isinstance(data, Tensor): + # these fields can be None in the case of an uninitialized tensor (like Tensor(None)) + if data.trace_tensor.rank is not None and data.trace_tensor.rank != 0: + raise_error( + f"Scalar shape tensors must be of rank 0, but input tensor is rank {data.rank}", details=[data] + ) + if data.dtype is not None and data.dtype != int32: + raise_error( + f"Scalar shape tensor must have int32 member, but input tensor has data type {data.dtype}", + details=[data], + ) + + # the shape of data should correspond to the given rank + super().__init__(data=None, dtype=int32, name=name, device=data.device) + # share the underlying data + self.trace_tensor = data.trace_tensor + self.stack_info = data.stack_info + else: + shape = data.shape if hasattr(data, "shape") else utils.get_shape(data) + device = data.device if hasattr(data, "device") else None + if len(shape) != 0: + raise_error( + f"Tensors used to represent scalar shapes must be of rank 0, but given shape {shape} has rank {len(shape)}." 
+ ) + super().__init__(data=data, dtype=int32, name=name, device=device) + + def __repr__(self) -> str: + # denote the representation as a shape rather than a tensor + tensor_repr = super().__repr__() + assert tensor_repr[:6] == "tensor" + return "shape_scalar" + tensor_repr[6:] + + def __str__(self) -> str: + return "shape_scalar" + "(" + ", ".join(map(str, self.tolist())) + ")" + + @export.public_api() class Shape(Tensor): """ @@ -47,7 +103,6 @@ def __init__( r""" Args: data: The value of the shape, which should be a 1D array of integers (the dimensions). - num_dims: The number of dimensions in the shape (its rank), which should correspond to the number of elements in data name: An optional name """ diff --git a/tripy/tripy/frontend/trace/ops/base.py b/tripy/tripy/frontend/trace/ops/base.py index eabd89a6..d4b4ce33 100644 --- a/tripy/tripy/frontend/trace/ops/base.py +++ b/tripy/tripy/frontend/trace/ops/base.py @@ -67,7 +67,7 @@ def build(cls, inputs: List["Tensor"], *args, num_outputs=1, **kwargs) -> Union[ """ from tripy.common.exception import raise_error - from tripy.frontend.shape import Shape + from tripy.frontend.shape import Shape, ShapeScalar from tripy.frontend.tensor import Tensor # NOTE: If you change the stack depth where the tensors are constructed, update STACK_DEPTH_OF_BUILD in @@ -87,13 +87,22 @@ def build(cls, inputs: List["Tensor"], *args, num_outputs=1, **kwargs) -> Union[ raise_error( f"Error processing shape inputs in operator {cls.__name__}{custom_err}\n(Shape input indices: {shape_arg_msg}.)" ) - # for shape outputs, we infer the length - if len(res.value) != 0: - inferred_lengths = op.infer_len() - for idx in res.value: - outputs[idx] = Shape(outputs[idx]) - if inferred_lengths[idx] is not None: - out_trace_tensors[idx].shape = [inferred_lengths[idx]] + + shape = res.value.get("shape") + if shape is not None: + # for shape outputs, we infer the length + if len(shape) != 0: + inferred_lengths = op.infer_len() + + for idx in shape: + 
outputs[idx] = Shape(outputs[idx]) + if inferred_lengths[idx] is not None: + out_trace_tensors[idx].shape = [inferred_lengths[idx]] + + scalar_shape = res.value.get("scalar") + if scalar_shape is not None: + for idx in scalar_shape: + outputs[idx] = ShapeScalar(outputs[idx]) if num_outputs == 1: return outputs[0] @@ -101,8 +110,8 @@ def build(cls, inputs: List["Tensor"], *args, num_outputs=1, **kwargs) -> Union[ def infer_shape_output_idxs(self, inputs: List["Tensor"]) -> Result: """ - Given the operator's inputs, this method returns a `Result` containing a list of the operator's output indices - that should be wrapped in `tp.Shape`. + Given the operator's inputs, this method returns a `Result` containing a dict of the operator's output indices + that should be wrapped in `tp.Shape` or `tp.ShapeScalar`. By default, this will wrap all the outputs in `tp.Shape` if all the inputs are `tp.Shape`s and not wrap any otherwise, treating it as an error if the inputs are inconsistent. @@ -126,9 +135,9 @@ def infer_shape_output_idxs(self, inputs: List["Tensor"]) -> Result: if any(map(is_shape, inputs)): if all(map(is_shape, inputs)): - return Result.ok(list(range(len(self.outputs)))) + return Result.ok({"shape": list(range(len(self.outputs))), "scalar": []}) return Result.err(["Either all inputs must be tp.Shape or all must be tp.Tensor."]) - return Result.ok([]) + return Result.ok({}) def infer_len(self) -> List[Optional[int]]: """ diff --git a/tripy/tripy/frontend/trace/ops/binary_elementwise.py b/tripy/tripy/frontend/trace/ops/binary_elementwise.py index 5a00ba78..f58a1cc1 100644 --- a/tripy/tripy/frontend/trace/ops/binary_elementwise.py +++ b/tripy/tripy/frontend/trace/ops/binary_elementwise.py @@ -48,7 +48,7 @@ def __str__(self): def infer_shape_output_idxs(self, inputs): # permit one input to be a shape but require the output to be a shape - from tripy.frontend.shape import Shape + from tripy.frontend.shape import Shape, ShapeScalar from tripy.utils import Result if 
any(map(lambda t: isinstance(t, Shape), inputs)): @@ -66,9 +66,12 @@ def infer_shape_output_idxs(self, inputs): f"The following inputs have invalid ranks: {invalid_indices_message}", ] ) - return Result.ok([0]) + return Result.ok({"shape": [0]}) + elif all(map(lambda t: isinstance(t, ShapeScalar), inputs)): + # Binary operation on ShapeScalar should yield another ShapeScalar. + return Result.ok({"scalar": [0]}) else: - return Result.ok([]) + return Result.ok({}) def infer_len(self): # For the shape case, the result will be broadcast to the max of the input shapes diff --git a/tripy/tripy/frontend/trace/ops/cast.py b/tripy/tripy/frontend/trace/ops/cast.py index 690575b7..f435c2e4 100644 --- a/tripy/tripy/frontend/trace/ops/cast.py +++ b/tripy/tripy/frontend/trace/ops/cast.py @@ -33,8 +33,8 @@ def infer_shape_output_idxs(self, inputs): if isinstance(inputs[0], Shape): # Only still a valid shape if it remains int32 if self.dtype == int32: - return Result.ok([0]) - return Result.ok([]) + return Result.ok({"shape": [0]}) + return Result.ok({}) infer_len = InferLenPolicies.infer_same_as_first_input diff --git a/tripy/tripy/frontend/trace/ops/expand.py b/tripy/tripy/frontend/trace/ops/expand.py index 5956e8a9..66c8b982 100644 --- a/tripy/tripy/frontend/trace/ops/expand.py +++ b/tripy/tripy/frontend/trace/ops/expand.py @@ -39,8 +39,8 @@ def infer_shape_output_idxs(self, inputs) -> Result: # wrap if the first input is a shape and the output is rank-1 if isinstance(inputs[0], Shape) and self.output_rank == 1: - return Result.ok([0]) - return Result.ok([]) + return Result.ok({"shape": [0]}) + return Result.ok({}) def infer_len(self): if self.output_len is not None: diff --git a/tripy/tripy/frontend/trace/ops/matmul.py b/tripy/tripy/frontend/trace/ops/matmul.py index 6dd7b0a8..e040654c 100644 --- a/tripy/tripy/frontend/trace/ops/matmul.py +++ b/tripy/tripy/frontend/trace/ops/matmul.py @@ -38,7 +38,7 @@ def infer_shape_output_idxs(self, inputs): if (isinstance(inputs[0], Shape) 
and isinstance(inputs[1], Shape)) or ( not isinstance(inputs[0], Shape) and not isinstance(inputs[1], Shape) ): - return Result.ok([]) + return Result.ok({}) return Result.err(None) def infer_rank(self): diff --git a/tripy/tripy/frontend/trace/ops/reshape.py b/tripy/tripy/frontend/trace/ops/reshape.py index f86c2613..5b111439 100644 --- a/tripy/tripy/frontend/trace/ops/reshape.py +++ b/tripy/tripy/frontend/trace/ops/reshape.py @@ -47,8 +47,8 @@ def infer_shape_output_idxs(self, inputs): # Only wrap the reshaped output if the result is rank 1, otherwise don't wrap if isinstance(inputs[0], Shape) and self.output_rank == 1: - return Result.ok([0]) - return Result.ok([]) + return Result.ok({"shape": [0]}) + return Result.ok({}) def infer_rank(self): if self.output_rank is None: diff --git a/tripy/tripy/frontend/trace/ops/shape.py b/tripy/tripy/frontend/trace/ops/shape.py index 4835ed77..c4498391 100644 --- a/tripy/tripy/frontend/trace/ops/shape.py +++ b/tripy/tripy/frontend/trace/ops/shape.py @@ -28,7 +28,7 @@ class Shape(BaseTraceOp): # always return a shape def infer_shape_output_idxs(self, inputs) -> Result: - return Result.ok([0]) + return Result.ok({"shape": [0]}) def infer_len(self): return [self.inputs[0].rank] diff --git a/tripy/tripy/frontend/trace/ops/slice.py b/tripy/tripy/frontend/trace/ops/slice.py index b0766f7a..e0e78093 100644 --- a/tripy/tripy/frontend/trace/ops/slice.py +++ b/tripy/tripy/frontend/trace/ops/slice.py @@ -208,7 +208,7 @@ def __getitem__(self: "tripy.Tensor", index: Union[slice, int, Tuple[int], "trip assert np.array_equal(cp.from_dlpack(output).get(), np.arange(10)[8:2:-1]) """ - from tripy.frontend.shape import Shape + from tripy.frontend.shape import ShapeScalar, Shape from tripy.frontend.tensor import Tensor from tripy.frontend.trace.ops.flip import flip from tripy.frontend.trace.ops.reshape import reshape, squeeze @@ -297,7 +297,7 @@ def clamp_bound(bound: Union[int, Tensor]) -> Union[int, Tensor]: if squeeze_dims: out = squeeze(out, 
make_tuple(squeeze_dims)) - return out + return ShapeScalar(out) if isinstance(self, Shape) and out.rank == 0 else out # Conveniently converts the inputs to tensors. The decorator also fills in column info for the converted tensors. diff --git a/tripy/tripy/frontend/trace/ops/utils.py b/tripy/tripy/frontend/trace/ops/utils.py index 998a6e17..e11ce998 100644 --- a/tripy/tripy/frontend/trace/ops/utils.py +++ b/tripy/tripy/frontend/trace/ops/utils.py @@ -61,14 +61,14 @@ def infer_from_first_input_only(self, inputs): from tripy.frontend.shape import Shape if isinstance(inputs[0], Shape): - return Result.ok(list(range(len(self.outputs)))) - return Result.ok([]) + return Result.ok({"shape": list(range(len(self.outputs)))}) + return Result.ok({}) def never_return_shape(self, inputs): """ Accepts shapes but the result is always no shape indices """ - return Result.ok([]) + return Result.ok({}) ## diff --git a/tripy/tripy/frontend/trace/ops/where.py b/tripy/tripy/frontend/trace/ops/where.py index 31634727..edd8e7c8 100644 --- a/tripy/tripy/frontend/trace/ops/where.py +++ b/tripy/tripy/frontend/trace/ops/where.py @@ -40,9 +40,9 @@ def infer_shape_output_idxs(self, inputs): f" the Boolean input must be rank 1, but given rank {inputs[0].rank}", ] ) - return Result.ok([0]) + return Result.ok({"shape": [0]}) elif not isinstance(inputs[1], Shape) and not isinstance(inputs[2], Shape): - return Result.ok([]) + return Result.ok({}) else: return Result.err( [