diff --git a/tripy/docs/README.md b/tripy/docs/README.md index a2a55435b..2fac8d95a 100644 --- a/tripy/docs/README.md +++ b/tripy/docs/README.md @@ -7,7 +7,7 @@ This directory includes all the source files for the public API documentation. You can build the documentation locally in the development container by running: ```bash python3 docs/generate_rsts.py -sphinx-build build/doc_sources build/docs -c docs/ -j auto -W +sphinx-build build/doc_sources build/docs -c docs/ -j 4 -W ``` To view the documentation, you can open `build/docs/index.html` in a browser. diff --git a/tripy/docs/conf.py b/tripy/docs/conf.py index 2e87a5ed4..e48f1f3d9 100644 --- a/tripy/docs/conf.py +++ b/tripy/docs/conf.py @@ -195,14 +195,19 @@ def process_docstring(app, what, name, obj, options, lines): if unqual_name in TYPE_VERIFICATION: add_text_index = -1 for index, block in enumerate(blocks): + + def insert_block(text): + nonlocal index + + blocks.insert(index, text) + index += 1 + if re.search(r".. code-block::", block): type_dict = TYPE_VERIFICATION[unqual_name].dtypes - blocks.insert(index, "Type Constraints:") - index += 1 + insert_block("TYPE CONSTRAINTS:") # Add the dtype constraint name and the dtypes that correlate. for type_name, dt in type_dict.items(): - blocks.insert( - index, + insert_block( f" - **{type_name}**: :class:`" + "`, :class:`".join( sorted( @@ -215,20 +220,17 @@ def process_docstring(app, what, name, obj, options, lines): ) + "`", ) - index += 1 - blocks.insert(index, "\n") - if TYPE_VERIFICATION[unqual_name].dtype_exceptions != []: + insert_block("\n") + + if TYPE_VERIFICATION[unqual_name].dtype_exceptions: # Add the dtype exceptions. 
- index += 1 - blocks.insert(index, "**Unsupported Type Combinations**:") - dtype_exception_text = [] + insert_block("UNSUPPORTED TYPE COMBINATIONS:") for exception_dict in TYPE_VERIFICATION[unqual_name].dtype_exceptions: - dtype_exception_text.append( - ", ".join([f"{key}: :class:`{val}`" for key, val in exception_dict.items()]) + insert_block( + " - " + + ", ".join([f"**{key}**\ =\ :class:`{val}`" for key, val in exception_dict.items()]), ) - dtype_exception_text = "; ".join(dtype_exception_text) + "\n" - index += 1 - blocks.insert(index, dtype_exception_text) + insert_block("\n") break if re.search(r":param \w+: ", block): @@ -237,14 +239,14 @@ def process_docstring(app, what, name, obj, options, lines): if TYPE_VERIFICATION[unqual_name].dtype_constraints.get(param_name, None): add_text_index = re.search(r":param \w+: ", block).span()[1] blocks[index] = ( - f"{block[0:add_text_index]}[dtype=\ **{TYPE_VERIFICATION[unqual_name].dtype_constraints[param_name]}**\ ] {block[add_text_index:]}" + f"{block[0:add_text_index]}[*dtype=*\ **{TYPE_VERIFICATION[unqual_name].dtype_constraints[param_name]}**\ ] {block[add_text_index:]}" ) if TYPE_VERIFICATION[unqual_name].return_dtype is not None and re.search(r":returns:", block): add_text_index = re.search(r":returns:", block).span()[1] + 1 # Add dtype constraint to start of returns description. blocks[index] = ( - f"{block[0:add_text_index]}[dtype=\ **{TYPE_VERIFICATION[unqual_name].return_dtype}**\ ] {block[add_text_index:]}" + f"{block[0:add_text_index]}[*dtype=*\ **{TYPE_VERIFICATION[unqual_name].return_dtype}**\ ] {block[add_text_index:]}" ) seen_classes.add(name) diff --git a/tripy/tests/frontend/ops/test_cumsum.py b/tripy/tests/frontend/ops/test_cumsum.py new file mode 100644 index 000000000..06ca5cd91 --- /dev/null +++ b/tripy/tests/frontend/ops/test_cumsum.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tests import helper +import tripy as tp + + +class TestCumsum: + def test_invalid_dim_fails(self): + a = tp.ones((2, 2)) + with helper.raises(tp.TripyException, "Dimension argument is out of bounds."): + tp.cumsum(a, dim=4) diff --git a/tripy/tests/integration/test_cumsum.py b/tripy/tests/integration/test_cumsum.py new file mode 100644 index 000000000..c8f8bbb7e --- /dev/null +++ b/tripy/tests/integration/test_cumsum.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest + +import tripy as tp + + +class TestCumsum: + @pytest.mark.parametrize( + "data,dim,expected", + [ + ([0, 1, 2, 3], 0, [0, 1, 3, 6]), + # Negative dim: + ([[2, 3], [4, 5]], -2, [[2, 3], [6, 8]]), + # Non-innermost dim: + ([[2, 3], [4, 5]], 0, [[2, 3], [6, 8]]), + # >2D (can potentially find transposition bugs) + ([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], 0, [[[1, 2], [3, 4]], [[6, 8], [10, 12]]]), + ], + ) + def test_cumsum(self, data, dim, expected): + inp = tp.Tensor(data, dtype=tp.float32) + + out = tp.cumsum(inp, dim=dim) + + expected = tp.Tensor(expected, dtype=tp.float32) + assert tp.allclose(out, expected) + assert out.shape == expected.shape diff --git a/tripy/tests/spec_verification/object_builders.py b/tripy/tests/spec_verification/object_builders.py index 4b289721a..6d86d737d 100644 --- a/tripy/tests/spec_verification/object_builders.py +++ b/tripy/tests/spec_verification/object_builders.py @@ -81,43 +81,19 @@ def default_builder(init, dtype, namespace): All other types do not have defaults and must be passed to the verifier using default_constraints_all. 
""" default_constraints_all = { - "__rtruediv__": {"self": 1}, - "__rsub__": {"self": 1}, + "__getitem__": {"index": 2}, + "__matmul__": {"self": tp.ones((2, 3))}, "__radd__": {"self": 1}, - "__rpow__": {"self": 1}, "__rmul__": {"self": 1}, - "softmax": {"dim": 1}, - "concatenate": {"dim": 0}, - "expand": {"sizes": tp.Tensor([3, 4]), "input": tp.ones((3, 1))}, - "full": {"shape": tp.Tensor([3]), "value": 1}, - "full_like": {"value": 1}, - "flip": {"dim": 1}, - "gather": {"dim": 0, "index": tp.Tensor([1])}, - "iota": {"shape": tp.Tensor([4])}, - "__matmul__": {"self": tp.ones((2, 3))}, - "transpose": {"dim0": 0, "dim1": 1}, - "permute": {"perm": [1, 0]}, - "quantize": {"scale": tp.Tensor([1, 1, 1]), "dim": 0}, - "dequantize": {"scale": tp.Tensor([1, 1, 1]), "dim": 0}, - "sum": {"dim": 0}, + "__rpow__": {"self": 1}, + "__rsub__": {"self": 1}, + "__rtruediv__": {"self": 1}, "all": {"dim": 0}, "any": {"dim": 0}, - "max": {"dim": 0}, - "prod": {"dim": 0}, - "mean": {"dim": 0}, - "var": {"dim": 0}, + "arange": {"start": 0, "stop": 5}, "argmax": {"dim": 0}, "argmin": {"dim": 0}, - "reshape": {"shape": tp.Tensor([6])}, - "squeeze": {"input": tp.ones((3, 1)), "dims": (1)}, - "__getitem__": {"index": 2}, - "split": {"indices_or_sections": 2}, - "unsqueeze": {"dim": 1}, - "masked_fill": {"value": 1}, - "ones": {"shape": tp.Tensor([3, 2])}, - "zeros": {"shape": tp.Tensor([3, 2])}, - "arange": {"start": 0, "stop": 5}, - "repeat": {"repeats": 2, "dim": 0}, + "concatenate": {"dim": 0}, "convolution": { "input": tp.ones((1, 3, 5, 5)), "weight": tp.ones((1, 3, 3, 3)), @@ -127,6 +103,31 @@ def default_builder(init, dtype, namespace): "lhs_dilation": [1, 1], "rhs_dilation": [1, 1], }, + "cumsum": {"dim": 0}, + "dequantize": {"scale": tp.Tensor([1, 1, 1]), "dim": 0}, + "expand": {"sizes": tp.Tensor([3, 4]), "input": tp.ones((3, 1))}, + "flip": {"dim": 1}, + "full_like": {"value": 1}, + "full": {"shape": tp.Tensor([3]), "value": 1}, + "gather": {"dim": 0, "index": tp.Tensor([1])}, + 
"iota": {"shape": tp.Tensor([4])}, + "masked_fill": {"value": 1}, + "max": {"dim": 0}, + "mean": {"dim": 0}, + "ones": {"shape": tp.Tensor([3, 2])}, + "permute": {"perm": [1, 0]}, + "prod": {"dim": 0}, + "quantize": {"scale": tp.Tensor([1, 1, 1]), "dim": 0}, + "repeat": {"repeats": 2, "dim": 0}, + "reshape": {"shape": tp.Tensor([6])}, + "softmax": {"dim": 1}, + "split": {"indices_or_sections": 2}, + "squeeze": {"input": tp.ones((3, 1)), "dims": (1)}, + "sum": {"dim": 0}, + "transpose": {"dim0": 0, "dim1": 1}, + "unsqueeze": {"dim": 1}, + "var": {"dim": 0}, + "zeros": {"shape": tp.Tensor([3, 2])}, } @@ -137,24 +138,30 @@ def create_obj(func_obj, func_name, param_name, param_dtype, namespace): param_dict = func_sig.parameters param_type_annot = param_dict[param_name] init = None + # Check if there is a value in default_constraints_all for func_name and param_name and use it. default_constraints = default_constraints_all.get(func_name, None) if default_constraints != None: other_constraint = default_constraints.get(param_name, None) if other_constraint is not None: init = other_constraint + # If parameter had a default then use it otherwise skip. if init is None and param_type_annot.default is not param_type_annot.empty: # Checking if not equal to None since default can be 0 or similar. if param_type_annot.default != None: init = param_type_annot.default + param_type = param_type_annot.annotation while get_origin(param_type) in [Union, Optional]: param_type = get_args(param_type)[0] # ForwardRef refers to any case where type hint is a string. 
if isinstance(param_type, ForwardRef): param_type = param_type.__forward_arg__ + create_obj_func = find_func.get(param_type, default_builder) if create_obj_func: namespace[param_name] = create_obj_func(init, param_dtype, namespace) return namespace[param_name] + + assert False, f"Could not create parameter: {param_name}" diff --git a/tripy/tripy/flat_ir/ops/__init__.py b/tripy/tripy/flat_ir/ops/__init__.py index 940ab5cdd..719132bc4 100644 --- a/tripy/tripy/flat_ir/ops/__init__.py +++ b/tripy/tripy/flat_ir/ops/__init__.py @@ -31,6 +31,7 @@ from tripy.flat_ir.ops.dot import DotOp from tripy.flat_ir.ops.exponential import ExpOp from tripy.flat_ir.ops.flip import FlipOp +from tripy.flat_ir.ops.floor import FloorOp from tripy.flat_ir.ops.gather import DynamicGatherOp from tripy.flat_ir.ops.get_dimension_size import GetDimensionSizeOp from tripy.flat_ir.ops.iota import DynamicIotaOp diff --git a/tripy/tripy/flat_ir/ops/floor.py b/tripy/tripy/flat_ir/ops/floor.py new file mode 100644 index 000000000..26a0c3890 --- /dev/null +++ b/tripy/tripy/flat_ir/ops/floor.py @@ -0,0 +1,28 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from dataclasses import dataclass + +from mlir_tensorrt.compiler.dialects import stablehlo + +from tripy.flat_ir.ops.base import BaseFlatIROp + + +@dataclass(repr=False) +class FloorOp(BaseFlatIROp): + def to_mlir(self, operands): + return [stablehlo.floor(*operands)] diff --git a/tripy/tripy/frontend/ops/allclose.py b/tripy/tripy/frontend/ops/allclose.py index 1fd1f6e99..bbe6d8557 100644 --- a/tripy/tripy/frontend/ops/allclose.py +++ b/tripy/tripy/frontend/ops/allclose.py @@ -57,10 +57,6 @@ def allclose(a: "tripy.Tensor", b: "tripy.Tensor", rtol: float = 1e-05, atol: fl """ from tripy.frontend.trace.ops.unary_elementwise import abs from tripy.frontend.trace.ops.reduce import all - from tripy.common.datatype import int64 - - if a.dtype == int64: - raise_error("Known issue with i64. Allclose currently does not work with int64 inputs. Issue #116") compare = abs(a - b) <= (atol + rtol * abs(b)) return bool(all(compare)) diff --git a/tripy/tripy/frontend/ops/cumsum.py b/tripy/tripy/frontend/ops/cumsum.py new file mode 100644 index 000000000..6b017b2ae --- /dev/null +++ b/tripy/tripy/frontend/ops/cumsum.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from tripy import constraints, export + +from tripy.frontend import utils as frontend_utils + + +@export.public_api(document_under="operations/functions") +@constraints.dtype_info( + dtype_variables={ + "T1": ["float32", "float16", "bfloat16", "float8", "int32"], + }, + dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"}, +) +@frontend_utils.process_dim +def cumsum(input: "tripy.Tensor", dim: int) -> "tripy.Tensor": + """ + Computes the cumulative sum of elements in the input along the dimension ``dim``. + + Args: + input: The input tensor. + dim: The dimension along which to compute the cumulative sum. + + Returns: + A tensor of the same shape as the input. + + .. code-block:: python + :linenos: + :caption: 1D tensor + + input = tp.arange(4, 0, step=-1, dtype=tp.int32) + output = tp.cumsum(input, dim=0) + + assert cp.array_equal(cp.cumsum(cp.from_dlpack(input)), cp.from_dlpack(output)) + + .. code-block:: python + :linenos: + :caption: 2D tensor + + input = tp.reshape(tp.arange(9, 0, step=-1, dtype=tp.int32), (3, 3)) + output = tp.cumsum(input, dim=0) + + assert cp.array_equal(cp.cumsum(cp.from_dlpack(input), axis=0), cp.from_dlpack(output)) + """ + # Consider: + # + # a = [3, 2, 1] + # + # then, we can implement cumsum as: + # + # out = a @ [[1, 1, 1] + # [0, 1, 1] + # [0, 0, 1]] + # + # which will yield: + # + # out = [3, 3 + 2, 3 + 2 + 1] + # + # In the general case where `a` is an N-dimensional tensor, we simply transpose + # the dimension of interest to the innermost position and then carry out the + # GEMM described above, then transpose the output back. + from tripy.frontend.trace.ops.permute import permute + from tripy.frontend.ops.tensor_initializers import triu, ones + + # For the examples in the comments that follow, assume the input shape is (3, 5, 7) and + # we are applying cumsum over dim=1 (the dimension of length 5). 
+ + # Swap dim to innermost position: (3, 5, 7) -> (3, 7, 5) + move_to_innermost_perm = list(range(input.rank)) + del move_to_innermost_perm[dim] + move_to_innermost_perm.append(dim) + transposed = permute(input, move_to_innermost_perm) + + # GEMM with square upper triangular matrix: (3, 7, 5) @ (5, 5) -> (3, 7, 5) + + # TODO: We should replace this with: + # shape = transposed.shape[-1:] * 2 + # once the relevant shape inference bugs are fixed. + shape = (transposed.shape[input.rank - 1], transposed.shape[input.rank - 1]) + out = transposed @ triu(ones(shape=shape, dtype=transposed.dtype)) + + # Swap innermost position back to `dim`: (3, 7, 5) -> (3, 5, 7) + reset_dim_perm = list(range(input.rank)) + del reset_dim_perm[-1] + reset_dim_perm.insert(dim, input.rank - 1) + out = permute(out, reset_dim_perm) + + return out diff --git a/tripy/tripy/frontend/ops/tensor_initializers.py b/tripy/tripy/frontend/ops/tensor_initializers.py index 3e6f84662..c47c8c6fb 100644 --- a/tripy/tripy/frontend/ops/tensor_initializers.py +++ b/tripy/tripy/frontend/ops/tensor_initializers.py @@ -217,10 +217,6 @@ def tril(tensor: "tripy.Tensor", diagonal: int = 0) -> "tripy.Tensor": assert np.array_equal(cp.from_dlpack(output).get(), np.tril(cp.from_dlpack(input).get(), -1)) """ - from tripy.common.datatype import int64 - - if tensor.dtype == int64: - raise_error("Known issue with i64. Tril currently does not work with int64 inputs.") tri_mask = (iota_like(tensor, -2, datatype.int32) + full_like(tensor, diagonal, datatype.int32)) >= iota_like( tensor, -1, datatype.int32 ) @@ -279,10 +275,6 @@ def triu(tensor: "tripy.Tensor", diagonal: int = 0) -> "tripy.Tensor": assert np.array_equal(cp.from_dlpack(output).get(), np.triu(cp.from_dlpack(input).get(), -1)) """ - from tripy.common.datatype import int64 - - if tensor.dtype == int64: - raise_error("Known issue with i64. 
Triu currently does not work with int64 inputs.") tri_mask = (iota_like(tensor, -2, datatype.int32) + full_like(tensor, diagonal, datatype.int32)) <= iota_like( tensor, -1, datatype.int32 ) @@ -332,11 +324,8 @@ def arange( assert tp.allclose(output, tp.Tensor(np.arange(2.3, 0.8, -0.2, dtype=np.float32))) """ - from tripy.common.datatype import int32, int64 - from tripy.frontend.tensor import Tensor - - if dtype == int64: - raise_error("Known issue with i64. Arange currently does not work with int64 inputs.") + from tripy.frontend import Tensor + from tripy.common.datatype import int32 if isinstance(step, numbers.Number) and step == 0: raise_error("Step in arange cannot be 0.", []) @@ -390,8 +379,4 @@ def arange(stop: Union[numbers.Number, "tripy.Tensor"], dtype: "tripy.dtype" = d assert (cp.from_dlpack(output).get() == np.arange(5, dtype=np.float32)).all() """ - from tripy.common.datatype import int64 - - if dtype == int64: - raise_error("Known issue with i64. Arange currently does not work with int64 inputs. 
Issue #116") return arange(0, stop, dtype=dtype) diff --git a/tripy/tripy/frontend/trace/ops/binary_elementwise.py b/tripy/tripy/frontend/trace/ops/binary_elementwise.py index 0033491b9..4911fd9a3 100644 --- a/tripy/tripy/frontend/trace/ops/binary_elementwise.py +++ b/tripy/tripy/frontend/trace/ops/binary_elementwise.py @@ -34,6 +34,7 @@ class Kind: POW = " ** " MUL = " * " DIV = " / " + FLOOR_DIV = " // " MAXIMUM = "maximum" MINIMUM = "minimum" @@ -126,19 +127,35 @@ def broadcast_inputs(self, inputs, outputs): return inputs def to_flat_ir(self, inputs, outputs): - from tripy.flat_ir.ops import AddOp, DivideOp, MaxOp, MinOp, MulOp, PowOp, SubtractOp + from tripy.flat_ir.ops import AddOp, DivideOp, FloorOp, MaxOp, MinOp, MulOp, PowOp, SubtractOp + from tripy.flat_ir.tensor import FlatIRTensor inputs = self.broadcast_inputs(inputs, outputs) - OpType = { - BinaryElementwise.Kind.SUM: AddOp, - BinaryElementwise.Kind.POW: PowOp, - BinaryElementwise.Kind.MUL: MulOp, - BinaryElementwise.Kind.SUB: SubtractOp, - BinaryElementwise.Kind.DIV: DivideOp, - BinaryElementwise.Kind.MAXIMUM: MaxOp, - BinaryElementwise.Kind.MINIMUM: MinOp, - }[self.kind] - OpType.build(inputs, outputs) + + if self.kind == BinaryElementwise.Kind.FLOOR_DIV: + # First apply DivideOp + divide_out = FlatIRTensor.build( + shape=outputs[0].shape, + rank=outputs[0].rank, + dtype=outputs[0].dtype, + device=outputs[0].device, + reason_details=["Intermediate output of division operator for FLOOR_DIV operation."], + ) + DivideOp.build(inputs, [divide_out]) + # Then apply FloorOp to the result of the division + FloorOp.build([divide_out], outputs) + else: + OpType = { + BinaryElementwise.Kind.SUM: AddOp, + BinaryElementwise.Kind.POW: PowOp, + BinaryElementwise.Kind.MUL: MulOp, + BinaryElementwise.Kind.SUB: SubtractOp, + BinaryElementwise.Kind.DIV: DivideOp, + BinaryElementwise.Kind.MAXIMUM: MaxOp, + BinaryElementwise.Kind.MINIMUM: MinOp, + BinaryElementwise.Kind.FLOOR_DIV: DivideOp, + }[self.kind] + 
OpType.build(inputs, outputs) @dataclass(repr=False) @@ -435,8 +452,8 @@ def __floordiv__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", Performs an elementwise floor division. Args: - self: Tensor to be divided by other. - other: The tensor by which to divide this one. + self: Tensor to be floor-divided by other. + other: The tensor by which to floor-divide this one. It should be broadcast-compatible. Returns: @@ -456,6 +473,8 @@ def __floordiv__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", from tripy.common.datatype import int32 return cast(cast(BinaryElementwise.build([self, other], BinaryElementwise.Kind.DIV), int32), self.dtype) + # Use the below code when https://github.com/NVIDIA/TensorRT-Incubator/issues/208 is fixed + # return BinaryElementwise.build([self, other], BinaryElementwise.Kind.FLOOR_DIV) @TENSOR_METHOD_REGISTRY("__rfloordiv__") @@ -469,8 +488,8 @@ def __rfloordiv__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", Performs an elementwise floor division. Args: - self: Tensor to be divided by other. - other: The tensor to be divided by this one. + self: Tensor by which to floor-divide other. + other: The tensor to be floor-divided by this one. It should be broadcast-compatible. 
Returns: @@ -490,6 +509,8 @@ def __rfloordiv__(self: Union["tripy.Tensor", Any], other: Union["tripy.Tensor", from tripy.common.datatype import int32 return cast(cast(BinaryElementwise.build([other, self], BinaryElementwise.Kind.DIV), int32), self.dtype) + # Use the below code when https://github.com/NVIDIA/TensorRT-Incubator/issues/208 is fixed + # return BinaryElementwise.build([other, self], BinaryElementwise.Kind.FLOOR_DIV) @export.public_api(document_under="operations/functions") diff --git a/tripy/tripy/frontend/trace/ops/flip.py b/tripy/tripy/frontend/trace/ops/flip.py index 3ea162a6b..d09477444 100644 --- a/tripy/tripy/frontend/trace/ops/flip.py +++ b/tripy/tripy/frontend/trace/ops/flip.py @@ -79,10 +79,6 @@ def flip(input: "tripy.Tensor", dims: Optional[Union[int, Sequence[int]]] = None output = tp.flip(input, dims=-1) assert cp.array_equal(cp.from_dlpack(output), cp.array([[4, 3, 2, 1, 0], [9, 8, 7, 6, 5]])) """ - from tripy.common.datatype import int64 - - if input.dtype == int64: - raise_error("Known issue with i64. Flip currently does not work with int64 inputs. Issue #116") rank = input.rank if dims is None: dims = [d for d in range(rank)] diff --git a/tripy/tripy/frontend/trace/ops/gather.py b/tripy/tripy/frontend/trace/ops/gather.py index 63b99e730..59f55470c 100644 --- a/tripy/tripy/frontend/trace/ops/gather.py +++ b/tripy/tripy/frontend/trace/ops/gather.py @@ -129,8 +129,4 @@ def gather(input: "tripy.Tensor", dim: int, index: "tripy.Tensor") -> "tripy.Ten assert np.array_equal(cp.from_dlpack(output).get(), np.take(cp.from_dlpack(data).get(), cp.from_dlpack(indices).get(), axis=1)) """ - from tripy.common.datatype import int64 - - if input.dtype == int64: - raise_error("Known issue with i64. Gather currently does not work with int64 inputs. 
Issue #116") return Gather.build([input, index], dim) diff --git a/tripy/tripy/frontend/trace/ops/iota.py b/tripy/tripy/frontend/trace/ops/iota.py index d0d981c77..7ba0ad779 100644 --- a/tripy/tripy/frontend/trace/ops/iota.py +++ b/tripy/tripy/frontend/trace/ops/iota.py @@ -115,10 +115,6 @@ def iota( assert np.array_equal(cp.from_dlpack(output).get(), np.arange(0, 3, dtype=np.float32)) """ - from tripy.common.datatype import int64 - - if dtype == int64: - raise_error("Known issue with i64. Iota currently does not work with int64 inputs. Issue #116") output_rank = len(shape) if isinstance(shape, Sequence) else None return iota_impl(shape, dim, dtype, output_rank) @@ -154,8 +150,4 @@ def iota_like(input: "tripy.Tensor", dim: int = 0, dtype: Optional[datatype.dtyp assert np.array_equal(cp.from_dlpack(output).get(), np.arange(0, 3, dtype=np.float32)) """ - from tripy.common.datatype import int64 - - if dtype == int64: - raise_error("Known issue with i64. Iota_like currently does not work with int64 inputs. Issue #116") return iota_impl(input.shape, dim, utils.default(dtype, input.dtype), input.rank) diff --git a/tripy/tripy/frontend/trace/ops/matmul.py b/tripy/tripy/frontend/trace/ops/matmul.py index 2cc2b4ae4..e040654c1 100644 --- a/tripy/tripy/frontend/trace/ops/matmul.py +++ b/tripy/tripy/frontend/trace/ops/matmul.py @@ -237,8 +237,4 @@ def __matmul__(self: "tripy.Tensor", other: "tripy.Tensor") -> "tripy.Tensor": output = a @ b assert np.array_equal(cp.from_dlpack(output).get(), cp.from_dlpack(a).get() @ cp.from_dlpack(b).get()) """ - from tripy.common.datatype import int64 - - if other.dtype == int64: - raise_error("Known issue with i64. __matmul__ currently does not work with int64 inputs. 
Issue #116") return MatrixMultiplication.build([self, other]) diff --git a/tripy/tripy/frontend/trace/ops/reduce.py b/tripy/tripy/frontend/trace/ops/reduce.py index c3f6d52c6..cfc33d722 100644 --- a/tripy/tripy/frontend/trace/ops/reduce.py +++ b/tripy/tripy/frontend/trace/ops/reduce.py @@ -167,10 +167,6 @@ def sum( assert np.array_equal(cp.from_dlpack(output).get(), np.sum(np.arange(6, dtype=np.float32).reshape((2, 3)), 0)) """ - from tripy.common.datatype import int64 - - if input.dtype == int64: - raise_error("Known issue with i64. Sum currently does not work with int64 inputs. Issue #116") return _reduce_impl(input, Reduce.Kind.SUM, dim, keepdim) @@ -274,10 +270,6 @@ def max( assert np.array_equal(cp.from_dlpack(output).get(), np.max(np.arange(6, dtype=np.float32).reshape((2, 3)), 0)) """ - from tripy.common.datatype import int64 - - if input.dtype == int64: - raise_error("Known issue with i64. Max currently does not work with int64 inputs. Issue #116") return _reduce_impl(input, Reduce.Kind.MAX, dim, keepdim) @@ -313,10 +305,6 @@ def prod( assert np.array_equal(cp.from_dlpack(output).get(), np.prod(np.arange(6, dtype=np.float32).reshape((2, 3)), 0)) """ - from tripy.common.datatype import int64 - - if input.dtype == int64: - raise_error("Known issue with i64. Prod currently does not work with int64 inputs. Issue #116") return _reduce_impl(input, Reduce.Kind.MUL, dim, keepdim) @@ -374,10 +362,6 @@ def mean( assert np.array_equal(cp.from_dlpack(output).get(), np.mean(np.arange(6, dtype=np.float32).reshape((2, 3)), axis=1, keepdims=True)) """ - from tripy.common.datatype import int64 - - if input.dtype == int64: - raise_error("Known issue with i64. Mean currently does not work with int64 inputs. 
Issue #116") return mean_impl(input, dim, keepdim) diff --git a/tripy/tripy/frontend/trace/ops/unary_elementwise.py b/tripy/tripy/frontend/trace/ops/unary_elementwise.py index 8aad7c598..25a2041c5 100644 --- a/tripy/tripy/frontend/trace/ops/unary_elementwise.py +++ b/tripy/tripy/frontend/trace/ops/unary_elementwise.py @@ -287,9 +287,4 @@ def abs(input: "tripy.Tensor") -> "tripy.Tensor": assert np.array_equal(cp.from_dlpack(output).get(), np.array([1, 2], dtype=np.float32)) """ - from tripy.frontend import Tensor - from tripy.common.datatype import int64 - - if input.dtype == int64: - raise_error("Known issue with i64. Abs currently does not work with int64 inputs. Issue #116") return UnaryElementwise.build([input], UnaryElementwise.Kind.ABS)