diff --git a/ivy/functional/frontends/paddle/fft.py b/ivy/functional/frontends/paddle/fft.py index 91b0f243feed8..813fa31f81232 100644 --- a/ivy/functional/frontends/paddle/fft.py +++ b/ivy/functional/frontends/paddle/fft.py @@ -152,3 +152,29 @@ def rfftfreq(n, d=1.0, dtype=None, name=None): pos_max = n // 2 + 1 indices = ivy.arange(0, pos_max, dtype=dtype) return indices * val + + +@with_supported_dtypes( + {"2.5.1 and below": ("complex64", "complex128")}, + "paddle", +) +@to_ivy_arrays_and_back +def rfftn(x, s=None, axes=None, norm="backward", name=None): + """Compute the N-dimensional discrete Fourier Transform over any number of axes in + an M-dimensional real array by means of the Fast Fourier Transform (FFT).""" + if s is None: + s = x.shape + + # Apply rfft along the last axis + rfft_result = ivy.rfftn(x, s=s, axes=axes, norm=norm, out=name) + + if axes is None: + # If axes is not specified, transform all axes except the last one. + axes = tuple(range(x.ndim - 1)) + + # Apply fft on the specified axes for N-dimensional FFT + fftn_result = rfft_result + for axis in axes: + fftn_result = ivy.fft(fftn_result, axis=axis, norm=norm) + + return fftn_result diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py index e5a37ba5c904e..3cf67e5580d51 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py @@ -285,3 +285,52 @@ def test_paddle_rfftfreq( n=n, d=d, ) + + +@handle_frontend_test( + fn_tree="paddle.fft.rfftn", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_value=-10, + max_value=10, + min_num_dims=1, + valid_axis=True, + force_int_axis=True, + ), + s=st.one_of( + st.tuples( + st.integers(min_value=2, max_value=10), + st.integers(min_value=2, max_value=10), + ), + st.just(None), + ), + axes=st.one_of( + st.lists( + st.integers(min_value=0, max_value=2), min_size=1, max_size=3, unique=True + ), + st.just(None), + ), + norm=st.sampled_from(["backward", "ortho", "forward"]), +) +def test_paddle_rfftn( + dtype_x_axis, + s, + axes, + norm, + frontend, + backend_fw, + test_flags, + fn_tree, +): + input_dtypes, x, axis = dtype_x_axis + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + x=x[0], + s=s, + axes=axes, + norm=norm, + )