Skip to content

Commit

Permalink
Allow arange to accept tensors; add floordiv operation; allow fill op…
Browse files Browse the repository at this point in the history
…eration to take fill value as tensor
  • Loading branch information
parthchadha committed Sep 6, 2024
1 parent 9df5c16 commit 7947339
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 38 deletions.
28 changes: 0 additions & 28 deletions tripy/tests/integration/test_arange.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cupy as cp
import numpy as np
Expand Down
4 changes: 1 addition & 3 deletions tripy/tripy/frontend/ops/tensor_initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# limitations under the License.
#

import math
import numbers
from typing import Optional, Sequence, Union

Expand Down Expand Up @@ -370,7 +369,7 @@ def arange(
},
dtype_constraints={"dtype": "T1", constraints.RETURN_VALUE: "T1"},
)
def arange(stop: numbers.Number, dtype: "tripy.dtype" = datatype.float32) -> "tripy.Tensor":
def arange(stop: Union[numbers.Number, "tripy.Tensor"], dtype: "tripy.dtype" = datatype.float32) -> "tripy.Tensor":
r"""
Returns a 1D tensor containing a sequence of numbers in the half-open interval
:math:`[0, \text{stop})` incrementing by 1.
Expand All @@ -392,7 +391,6 @@ def arange(stop: numbers.Number, dtype: "tripy.dtype" = datatype.float32) -> "tr
assert (cp.from_dlpack(output).get() == np.arange(5, dtype=np.float32)).all()
"""
from tripy.common.datatype import int64
from tripy.frontend.tensor import Tensor

if dtype == int64:
raise_error("Known issue with i64. Arange currently does not work with int64 inputs. Issue #116")
Expand Down
27 changes: 20 additions & 7 deletions tripy/tripy/function_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import inspect
from collections import OrderedDict, defaultdict
from textwrap import dedent
from typing import Any, Callable, Dict, List
from typing import Any, Callable, Dict, List, ForwardRef
from typing import get_origin, get_args, Union


from dataclasses import dataclass

Expand Down Expand Up @@ -102,14 +104,23 @@ def matches_type(name: str, annotation: type, arg: Any):
# In cases where a type is not available at the time of function definition, the type
# annotation may be provided as a string. Since we need the actual type, we just
# eval it here.
if isinstance(annotation, str):
import importlib

if isinstance(annotation, (str, ForwardRef)):
if isinstance(annotation, ForwardRef):
annotation = annotation.__forward_arg__
try:
annotation = eval(annotation)
module_name, class_name = annotation.rsplit(".", 1)
module = importlib.import_module(module_name)
resolved_type = getattr(module, class_name)
return matches_type(name, resolved_type, arg)
except Exception as e:
raise NameError(
f"Error while evaluating type annotation: '{annotation}' for parameter: '{name}' of function: '{self.func.__name__}'."
f"\nNote: Error was: {e}"
)
print(f"Warning: Could not resolve type '{annotation}'. Assuming it matches. Error: {e}")
return True

# Handle Union types
if get_origin(annotation) is Union:
return any(matches_type(name, t, arg) for t in get_args(annotation))

try:
return isinstance(arg, annotation)
Expand Down Expand Up @@ -137,6 +148,8 @@ def matches_type(name: str, annotation: type, arg: Any):
f"'{annotation.type_info.__qualname__}' but got argument of type: '{type(arg).__qualname__}'."
],
)
else:
print(f"matched type {annotation.type_info.__qualname__} with arg type of {type(arg).__qualname__}")

for name, arg in kwargs.items():
if name in annotations:
Expand Down

0 comments on commit 7947339

Please sign in to comment.