From d5393c3017c3d4fb02ddbcd1f1ea5be49c7f761c Mon Sep 17 00:00:00 2001 From: Akhil Goel Date: Thu, 19 Sep 2024 15:21:41 -0700 Subject: [PATCH] Add rank verification, address comments --- tripy/tests/frontend/ops/test_outer.py | 23 +++++++++++++++++++++++ tripy/tripy/frontend/ops/outer.py | 16 +++++++++++----- 2 files changed, 34 insertions(+), 5 deletions(-) create mode 100644 tripy/tests/frontend/ops/test_outer.py diff --git a/tripy/tests/frontend/ops/test_outer.py b/tripy/tests/frontend/ops/test_outer.py new file mode 100644 index 000000000..17d2fd7be --- /dev/null +++ b/tripy/tests/frontend/ops/test_outer.py @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tests import helper +import tripy as tp +class TestOuter: + def test_invalid_rank_fails(self): + a = tp.ones((5, 1)) + b = tp.ones((1, 4)) + with helper.raises(tp.TripyException, "Expected input vectors to be 1-d."): + tp.outer(a, b) \ No newline at end of file diff --git a/tripy/tripy/frontend/ops/outer.py b/tripy/tripy/frontend/ops/outer.py index 727ed5fd0..fd620013e 100644 --- a/tripy/tripy/frontend/ops/outer.py +++ b/tripy/tripy/frontend/ops/outer.py @@ -20,15 +20,14 @@ @export.public_api(document_under="operations/functions") -@frontend_utils.convert_inputs_to_tensors(sync_arg_types=[("vec1", "vec2")]) @constraints.dtype_info( dtype_variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"]}, - dtype_constraints={"vec1": "T1", "vec2": "T1", "other": "T1", constraints.RETURN_VALUE: "T1"}, + dtype_constraints={"vec1": "T1", "vec2": "T1", constraints.RETURN_VALUE: "T1"}, ) def outer(vec1: "tripy.Tensor", vec2: "tripy.Tensor") -> "tripy.Tensor": r""" - Computes the outer product of 1-d vectors `vec1` and `vec2`, such that the - output dimension is (m x n) if the inputs are of size m and n respectively. + Computes the outer product of 1-d vectors ``vec1`` and ``vec2``, such that the + output dimension is :math:`(m \times n)` if the inputs are of size :math:`m` and :math:`n` respectively. Args: vec1: The first 1d input vector. @@ -50,5 +49,12 @@ def outer(vec1: "tripy.Tensor", vec2: "tripy.Tensor") -> "tripy.Tensor": assert tp.allclose(output, tp.Tensor(torch.outer(t1, t2))) """ from tripy.frontend.trace.ops.unsqueeze import unsqueeze + from tripy.common.exception import raise_error - return unsqueeze(vec1, -1) * unsqueeze(vec2, 0) \ No newline at end of file + if vec1.rank != 1 or vec2.rank != 1: + raise_error( + "Expected input vectors to be 1-d.", + [f"Got vec1.rank={vec1.rank}, ", f"vec2.rank={vec2.rank}"], + ) + + return unsqueeze(vec1, -1) @ unsqueeze(vec2, 0) \ No newline at end of file