From fc392a0049137383386b981750259e0b21383844 Mon Sep 17 00:00:00 2001 From: HPatto <139283897+HPatto@users.noreply.github.com> Date: Thu, 12 Oct 2023 21:19:07 +0000 Subject: [PATCH] Implementation of inv function in Paddle. --- ivy/functional/frontends/paddle/linalg.py | 5 +++ .../test_frontends/test_paddle/test_linalg.py | 38 +++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/ivy/functional/frontends/paddle/linalg.py b/ivy/functional/frontends/paddle/linalg.py index cad91bd96d5d0..992126b86905e 100644 --- a/ivy/functional/frontends/paddle/linalg.py +++ b/ivy/functional/frontends/paddle/linalg.py @@ -99,6 +99,11 @@ def eigvalsh(x, UPLO="L", name=None): return ivy.eigvalsh(x, UPLO=UPLO) +@to_ivy_arrays_and_back +def inv(x, name=None): + return ivy.inv(x) + + @to_ivy_arrays_and_back def lu_unpack(lu_data, lu_pivots, unpack_datas=True, unpack_pivots=True, *, out=None): A = lu_data diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_linalg.py index f38a9b0d66383..02f940f9680b3 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_linalg.py @@ -688,6 +688,44 @@ def test_paddle_eigvalsh( ) +# inv +@handle_frontend_test( + fn_tree="paddle.linalg.inv", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-100.0, + max_value=100.0, + shape=helpers.ints(min_value=2, max_value=10).map(lambda x: tuple([x, x])), + ).filter( + lambda x: "float16" not in x[0] + and "bfloat16" not in x[0] + and np.linalg.det(np.asarray(x[1][0])) != 0 + ), + test_with_out=st.just(False), +) +def test_paddle_inv( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + rtol=1e-01, + atol=1e-01, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + @handle_frontend_test( fn_tree="paddle.lu_unpack", dtype_x=_get_dtype_and_square_matrix(real_and_complex_only=True),