diff --git a/tripy/tests/integration/test_arange.py b/tripy/tests/integration/test_arange.py index 31951e7df..e5a03ec72 100644 --- a/tripy/tests/integration/test_arange.py +++ b/tripy/tests/integration/test_arange.py @@ -12,34 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. import cupy as cp import numpy as np diff --git a/tripy/tripy/frontend/ops/tensor_initializers.py b/tripy/tripy/frontend/ops/tensor_initializers.py index 20af5aa8e..beabf8c4c 100644 --- a/tripy/tripy/frontend/ops/tensor_initializers.py +++ b/tripy/tripy/frontend/ops/tensor_initializers.py @@ -300,7 +300,7 @@ def triu(tensor: "tripy.Tensor", diagonal: int = 0) -> "tripy.Tensor": ) def arange( start: Union[numbers.Number, "tripy.Tensor"], - stop: Union[numbers.Number, "tripy.Tensor"], + stop: Union[numbers.Number, "tripy.Tensor"] = None, step: Union[numbers.Number, "tripy.Tensor"] = 1, dtype: "tripy.dtype" = datatype.float32, ) -> "tripy.Tensor": @@ -317,6 +317,9 @@ def arange( Returns: A tensor of shape :math:`[\frac{\text{stop}-\text{start}}{\text{step}}]`. + Note: + If only single argument is passed to arange, argument will be interpreted as stop and start will be set to 0. + .. code-block:: python :linenos: :caption: Example @@ -336,6 +339,10 @@ def arange( from tripy.common.datatype import int32, int64 from tripy.frontend.tensor import Tensor + if stop is None: + stop = start + start = 0 + if dtype == int64: raise_error("Known issue with i64. Arange currently does not work with int64 inputs.") @@ -361,39 +368,3 @@ def arange( output = iota((size,), 0, dtype) * full((size,), step, dtype) + full((size,), start, dtype) return output - - -@export.public_api(document_under="operations/initializers") -@constraints.dtype_info( - dtype_variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "bool"], - }, - dtype_constraints={"dtype": "T1", constraints.RETURN_VALUE: "T1"}, -) -def arange(stop: numbers.Number, dtype: "tripy.dtype" = datatype.float32) -> "tripy.Tensor": - r""" - Returns a 1D tensor containing a sequence of numbers in the half-open interval - :math:`[0, \text{stop})` incrementing by 1. - - - Args: - stop: The exclusive upper bound of the values to generate. - dtype: The desired datatype of the tensor. - - Returns: - A tensor of shape :math:`[\text{stop}]`. - - .. code-block:: python - :linenos: - :caption: Example - - output = tp.arange(5) - - assert (cp.from_dlpack(output).get() == np.arange(5, dtype=np.float32)).all() - """ - from tripy.common.datatype import int64 - from tripy.frontend.tensor import Tensor - - if dtype == int64: - raise_error("Known issue with i64. Arange currently does not work with int64 inputs. Issue #116") - return arange(0, stop, dtype=dtype)