diff --git a/ivy/functional/frontends/jax/numpy/fft.py b/ivy/functional/frontends/jax/numpy/fft.py index c11c0b0e31c83..944e07da8e156 100644 --- a/ivy/functional/frontends/jax/numpy/fft.py +++ b/ivy/functional/frontends/jax/numpy/fft.py @@ -29,7 +29,7 @@ def fftshift(x, axes=None, name=None): return roll -@with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") +@with_unsupported_dtypes({"2.5.2 and below": ("float16", "bfloat16")}, "paddle") @to_ivy_arrays_and_back def rfft(a, n=None, axis=-1, norm=None): if norm is None: @@ -43,4 +43,5 @@ def rfft(a, n=None, axis=-1, norm=None): a_new = ivy.astype(a, "complex64") else: a_new = a - return ivy.fft(a_new, axis, norm=norm, n=n)[:n//2 + 1] + fft_fun = ivy.fft + return fft_fun(a_new, axis, norm=norm, n=n)[:n//2 + 1]