Skip to content

Commit

Permalink
Allow arange to accept tensors; add floordiv operation; allow fill op…
Browse files Browse the repository at this point in the history
…eration to take fill value as tensor
  • Loading branch information
parthchadha committed Sep 6, 2024
1 parent 9df5c16 commit 50b1f24
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 65 deletions.
28 changes: 0 additions & 28 deletions tripy/tests/integration/test_arange.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cupy as cp
import numpy as np
Expand Down
45 changes: 8 additions & 37 deletions tripy/tripy/frontend/ops/tensor_initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def triu(tensor: "tripy.Tensor", diagonal: int = 0) -> "tripy.Tensor":
)
def arange(
start: Union[numbers.Number, "tripy.Tensor"],
stop: Union[numbers.Number, "tripy.Tensor"],
stop: Union[numbers.Number, "tripy.Tensor"] = None,
step: Union[numbers.Number, "tripy.Tensor"] = 1,
dtype: "tripy.dtype" = datatype.float32,
) -> "tripy.Tensor":
Expand All @@ -317,6 +317,9 @@ def arange(
Returns:
A tensor of shape :math:`[\frac{\text{stop}-\text{start}}{\text{step}}]`.
Note:
If only single argument is passed to arange, argument will be interpreted as stop and start will be set to 0.
.. code-block:: python
:linenos:
:caption: Example
Expand All @@ -336,6 +339,10 @@ def arange(
from tripy.common.datatype import int32, int64
from tripy.frontend.tensor import Tensor

if stop is None:
stop = start
start = 0

if dtype == int64:
raise_error("Known issue with i64. Arange currently does not work with int64 inputs.")

Expand All @@ -361,39 +368,3 @@ def arange(

output = iota((size,), 0, dtype) * full((size,), step, dtype) + full((size,), start, dtype)
return output


@export.public_api(document_under="operations/initializers")
@constraints.dtype_info(
dtype_variables={
"T1": ["float32", "float16", "bfloat16", "int8", "int32", "bool"],
},
dtype_constraints={"dtype": "T1", constraints.RETURN_VALUE: "T1"},
)
def arange(stop: numbers.Number, dtype: "tripy.dtype" = datatype.float32) -> "tripy.Tensor":
r"""
Returns a 1D tensor containing a sequence of numbers in the half-open interval
:math:`[0, \text{stop})` incrementing by 1.
Args:
stop: The exclusive upper bound of the values to generate.
dtype: The desired datatype of the tensor.
Returns:
A tensor of shape :math:`[\text{stop}]`.
.. code-block:: python
:linenos:
:caption: Example
output = tp.arange(5)
assert (cp.from_dlpack(output).get() == np.arange(5, dtype=np.float32)).all()
"""
from tripy.common.datatype import int64
from tripy.frontend.tensor import Tensor

if dtype == int64:
raise_error("Known issue with i64. Arange currently does not work with int64 inputs. Issue #116")
return arange(0, stop, dtype=dtype)

0 comments on commit 50b1f24

Please sign in to comment.