From 79473397d5f9ce5ac2737541a987f74d80b918ee Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Fri, 6 Sep 2024 09:30:03 -0700 Subject: [PATCH] Allow arange to accept tensors; add floordiv operation; allow fill operation to take fill value as tensor --- tripy/tests/integration/test_arange.py | 28 ------------------- .../tripy/frontend/ops/tensor_initializers.py | 4 +-- tripy/tripy/function_registry.py | 27 +++++++++++++----- 3 files changed, 21 insertions(+), 38 deletions(-) diff --git a/tripy/tests/integration/test_arange.py b/tripy/tests/integration/test_arange.py index 31951e7df..e5a03ec72 100644 --- a/tripy/tests/integration/test_arange.py +++ b/tripy/tests/integration/test_arange.py @@ -12,34 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. import cupy as cp import numpy as np diff --git a/tripy/tripy/frontend/ops/tensor_initializers.py b/tripy/tripy/frontend/ops/tensor_initializers.py index 20af5aa8e..c2e756dd9 100644 --- a/tripy/tripy/frontend/ops/tensor_initializers.py +++ b/tripy/tripy/frontend/ops/tensor_initializers.py @@ -15,7 +15,6 @@ # limitations under the License. # -import math import numbers from typing import Optional, Sequence, Union @@ -370,7 +369,7 @@ def arange( }, dtype_constraints={"dtype": "T1", constraints.RETURN_VALUE: "T1"}, ) -def arange(stop: numbers.Number, dtype: "tripy.dtype" = datatype.float32) -> "tripy.Tensor": +def arange(stop: Union[numbers.Number, "tripy.Tensor"], dtype: "tripy.dtype" = datatype.float32) -> "tripy.Tensor": r""" Returns a 1D tensor containing a sequence of numbers in the half-open interval :math:`[0, \text{stop})` incrementing by 1. @@ -392,7 +391,6 @@ def arange(stop: numbers.Number, dtype: "tripy.dtype" = datatype.float32) -> "tr assert (cp.from_dlpack(output).get() == np.arange(5, dtype=np.float32)).all() """ from tripy.common.datatype import int64 - from tripy.frontend.tensor import Tensor if dtype == int64: raise_error("Known issue with i64. Arange currently does not work with int64 inputs. Issue #116") diff --git a/tripy/tripy/function_registry.py b/tripy/tripy/function_registry.py index 6d26f44ca..796d19042 100644 --- a/tripy/tripy/function_registry.py +++ b/tripy/tripy/function_registry.py @@ -19,7 +19,9 @@ import inspect from collections import OrderedDict, defaultdict from textwrap import dedent -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, ForwardRef +from typing import get_origin, get_args, Union + from dataclasses import dataclass @@ -102,14 +104,23 @@ def matches_type(name: str, annotation: type, arg: Any): # In cases where a type is not available at the time of function definition, the type # annotation may be provided as a string. Since we need the actual type, we just # eval it here. - if isinstance(annotation, str): + import importlib + + if isinstance(annotation, (str, ForwardRef)): + if isinstance(annotation, ForwardRef): + annotation = annotation.__forward_arg__ try: - annotation = eval(annotation) + module_name, class_name = annotation.rsplit(".", 1) + module = importlib.import_module(module_name) + resolved_type = getattr(module, class_name) + return matches_type(name, resolved_type, arg) except Exception as e: - raise NameError( - f"Error while evaluating type annotation: '{annotation}' for parameter: '{name}' of function: '{self.func.__name__}'." - f"\nNote: Error was: {e}" - ) + print(f"Warning: Could not resolve type '{annotation}'. Assuming it matches. Error: {e}") + return True + + # Handle Union types + if get_origin(annotation) is Union: + return any(matches_type(name, t, arg) for t in get_args(annotation)) try: return isinstance(arg, annotation) @@ -137,6 +148,8 @@ def matches_type(name: str, annotation: type, arg: Any): f"'{annotation.type_info.__qualname__}' but got argument of type: '{type(arg).__qualname__}'." ], ) + else: + print(f"matched type {annotation.type_info.__qualname__} with arg type of {type(arg).__qualname__}") for name, arg in kwargs.items(): if name in annotations: