From 92280e132c97a9fc316ecaa7fb01e1582dbfd195 Mon Sep 17 00:00:00 2001 From: Firi <56671623+firi193@users.noreply.github.com> Date: Fri, 6 Oct 2023 20:03:15 +0300 Subject: [PATCH] feat(Paddle-frontend): Adds paddle.fft.ihfft2 (#25886) Co-authored-by: hirwa-nshuti --- ivy/functional/frontends/paddle/fft.py | 35 +++++++++++++++ .../test_frontends/test_paddle/test_fft.py | 44 +++++++++++++++++++ 2 files changed, 79 insertions(+) diff --git a/ivy/functional/frontends/paddle/fft.py b/ivy/functional/frontends/paddle/fft.py index ae868ddf3e6f5..9045ef19bb89c 100644 --- a/ivy/functional/frontends/paddle/fft.py +++ b/ivy/functional/frontends/paddle/fft.py @@ -153,6 +153,41 @@ def ifftshift(x, axes=None, name=None): return roll +@with_supported_dtypes( + { + "2.5.1 and below": ( + "int32", + "int64", + "float32", + "float64", + ) + }, + "paddle", +) +@to_ivy_arrays_and_back +def ihfft2(x, s=None, axes=(-2, -1), norm="backward", name=None): + # check if the input array is two-dimensional and real + if len(ivy.array(x).shape) != 2 or ivy.is_complex_dtype(x): + raise ValueError("input must be a two-dimensional real array") + + # cast the input to the same float64 type so that there are no backend issues + x_ = ivy.astype(x, ivy.float64) + + ihfft2_result = 0 + # Compute the complex conjugate of the 2-dimensional discrete Fourier Transform + if norm == "backward": + ihfft2_result = ivy.conj(ivy.rfftn(x_, s=s, axes=axes, norm="forward")) + if norm == "forward": + ihfft2_result = ivy.conj(ivy.rfftn(x_, s=s, axes=axes, norm="backward")) + if norm == "ortho": + ihfft2_result = ivy.conj(ivy.rfftn(x_, s=s, axes=axes, norm="ortho")) + + if x.dtype == ivy.float32 or x.dtype == ivy.int32 or x.dtype == ivy.int64: + return ivy.astype(ihfft2_result, ivy.complex64) + if x.dtype == ivy.float64: + return ivy.astype(ihfft2_result, ivy.complex128) + + @with_supported_dtypes( {"2.5.1 and below": ("complex64", "complex128")}, "paddle", diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py index 269473214f479..c49ea38dd9aae 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py @@ -258,6 +258,50 @@ def test_paddle_ifftshift( ) +@handle_frontend_test( + fn_tree="paddle.fft.ihfft2", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=["float64", "float32", "int64", "int32"], + min_value=-10, + max_value=10, + min_num_dims=2, + max_num_dims=2, + shape=st.tuples( + st.integers(min_value=2, max_value=10), + st.integers(min_value=2, max_value=10), + ), + ), + s=st.one_of( + st.lists(st.integers(min_value=2, max_value=10), min_size=2, max_size=2), + ), + axes=st.just([-2, -1]), + norm=st.sampled_from(["backward", "ortho", "forward"]), +) +def test_paddle_ihfft2( + dtype_x_axis, + s, + axes, + norm, + frontend, + backend_fw, + test_flags, + fn_tree, +): + input_dtypes, x, axis_ = dtype_x_axis + + helpers.test_frontend_function( + input_dtypes=input_dtypes, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + x=x[0], + s=s, + axes=axes, + norm=norm, + ) + + @handle_frontend_test( fn_tree="paddle.fft.irfft", dtype_x_axis=helpers.dtype_values_axis(