diff --git a/ivy/functional/frontends/paddle/fft.py b/ivy/functional/frontends/paddle/fft.py index eb9d9dec4c825..84dc998b03fd2 100644 --- a/ivy/functional/frontends/paddle/fft.py +++ b/ivy/functional/frontends/paddle/fft.py @@ -11,8 +11,9 @@ "paddle", ) @to_ivy_arrays_and_back -def fft(x, n=None, axis=-1.0, norm="backward", name=None): - ret = ivy.fft(ivy.astype(x, "complex128"), axis, norm=norm, n=n) +def fft2(x, s=None, axes=(- 2, - 1), norm='backward', name=None): + x_comp = x.astype("complex128") + ret = ivy.fft2(x_comp, s=s, dim=axes, norm=norm) return ivy.astype(ret, x.dtype) @@ -44,3 +45,11 @@ def fftshift(x, axes=None, name=None): roll = ivy.roll(x, shifts, axis=axes) return roll +@with_supported_dtypes( + {"2.5.0 and below": ("complex64", "complex128")}, + "paddle", +) +@to_ivy_arrays_and_back +def fft2(x, n=None, axes=(- 2, - 1), norm='backward', name=None): + ret = ivy.fft2(ivy.astype(x, "complex128"), axes, norm=norm, n=n) + return ivy.astype(ret, x.dtype) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_paddle_fft.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_paddle_fft.py index 0fa652f27d4f3..36c7d6649c0d5 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_paddle_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_paddle_fft.py @@ -67,3 +67,39 @@ def test_paddle_fttshift(dtype_x_axis, frontend, test_flags, fn_tree, on_device) x=x[0], axes=axes, ) +@handle_frontend_test( + fn_tree="paddle.fft.fft2", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_value=-10, + max_value=10, + min_num_dims=2, + min_dim_size=2, + valid_axis=True, + force_int_axis=True, + ), + n=st.one_of( + st.integers(min_value=2, max_value=10), + st.just(None), + ), + norm=st.sampled_from(["backward", "ortho", "forward"]), +) +def test_paddle_fft2( + dtype_x_axis, + n, + norm, + frontend, + test_flags, + fn_tree, +): + input_dtypes, x, axes = dtype_x_axis + helpers.test_frontend_function( + input_dtypes=input_dtypes, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + x=x[0], + n=n, + axes=axes, + norm=norm, + )