Skip to content

Commit

Permalink
feat: hfft2 function and tests added to paddle frontend. (#23520)
Browse files Browse the repository at this point in the history
Co-authored-by: nathzi1505 <[email protected]>
  • Loading branch information
Dhruv-Varshney-developer and p3jitnath authored Sep 15, 2023
1 parent 4b53291 commit 3ad8ff4
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 3 deletions.
23 changes: 22 additions & 1 deletion ivy/functional/frontends/paddle/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,27 @@ def hfft(x, n=None, axis=-1, norm="backward", name=None):
return ivy.real(result)


@with_supported_dtypes(
{"2.5.1 and below": "complex64"},
"paddle",
)
@to_ivy_arrays_and_back
def hfft2(x, s=None, axis=(-2, -1), norm="backward"):
# check if the input tensor x is a hermitian complex
if not ivy.allclose(ivy.conj(ivy.matrix_transpose(x)), x):
raise ValueError("Input tensor x must be Hermitian complex.")

fft_result = ivy.fft2(x, s=s, dim=axis, norm=norm)

# Depending on the norm, apply scaling and normalization
if norm == "forward":
fft_result /= ivy.sqrt(ivy.prod(ivy.shape(fft_result)))
elif norm == "ortho":
fft_result /= ivy.sqrt(ivy.prod(ivy.shape(x)))

return ivy.real(fft_result) # Return the real part of the result


@with_supported_dtypes(
{"2.5.1 and below": ("complex64", "complex128")},
"paddle",
Expand Down Expand Up @@ -130,4 +151,4 @@ def rfftfreq(n, d=1.0, dtype=None, name=None):
val = 1.0 / (n * d)
pos_max = n // 2 + 1
indices = ivy.arange(0, pos_max, dtype=dtype)
return indices * val
return indices * val
47 changes: 45 additions & 2 deletions ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# global
from hypothesis import strategies as st
from hypothesis import given, strategies as st

# local
import ivy_tests.test_ivy.helpers as helpers
Expand Down Expand Up @@ -107,6 +107,49 @@ def test_paddle_hfft(
)


@given(
s=st.one_of(
st.none(), st.tuples(st.integers(min_value=1), st.integers(min_value=1))
),
axis=st.one_of(st.none(), st.tuples(st.integers(min_value=-2, max_value=-1))),
shape=st.lists(st.integers(min_value=1, max_value=10), min_size=2, max_size=2).map(
tuple
),
)
@handle_frontend_test(
fn_tree="paddle.fft.hfft2",
dtype_x_axis=helpers.dtype_values_axis(
available_dtypes=helpers.get_dtypes("complex64"),
),
)
def test_paddle_hfft2(
dtype_x_axis,
s,
axis,
norm,
frontend,
backend_fw,
test_flags,
fn_tree,
shape,
):
input_dtypes, x, axis = dtype_x_axis
x = x.reshape(shape) # reshape x to the generated shape

for norm in ["backward", "forward", "ortho"]:
helpers.test_frontend_function(
input_dtypes=input_dtypes,
frontend=frontend,
backend_to_test=backend_fw,
test_flags=test_flags,
fn_tree=fn_tree,
x=x,
s=s,
axis=axis,
norm=norm,
)


@handle_frontend_test(
fn_tree="paddle.fft.ifft",
dtype_x_axis=helpers.dtype_values_axis(
Expand Down Expand Up @@ -241,4 +284,4 @@ def test_paddle_rfftfreq(
test_values=True,
n=n,
d=d,
)
)

0 comments on commit 3ad8ff4

Please sign in to comment.