Skip to content

Commit

Permalink
Added IFFT to stateful class for stateful layers
Browse files Browse the repository at this point in the history
  • Loading branch information
arshPratap committed Sep 1, 2023
1 parent 2c0bc53 commit 85f84b6
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 0 deletions.
55 changes: 55 additions & 0 deletions ivy/stateful/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2224,3 +2224,58 @@ def _forward(self, x):
The input array as it is.
"""
return x


class IFFT(Module):
def __init__(
self,
dim,
/,
*,
norm="backward",
n=None,
out=None,
device=None,
dtype=None,
):
"""
Class for applying IFFT to input.
Parameters
----------
dim : int
Dimension along which to take the IFFT.
norm : str
Normalization mode. Default: 'backward'
n : int
Size of the IFFT. Default: None
out : int
Size of the output. Default: None
"""
self._dim = dim
self._norm = norm
self._n = n
self._out = out
Module.__init__(self, device=device, dtype=dtype)

def _forward(self, inputs):
"""
Forward pass of the layer.
Parameters
----------
inputs : array
Input array to take the IFFT of.
Returns
-------
array
The output array of the layer.
"""
return ivy.ifft(
inputs,
self._dim,
norm=self._norm,
n=self._n,
out=self._out,
)
39 changes: 39 additions & 0 deletions ivy_tests/test_ivy/test_stateful/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1198,6 +1198,45 @@ def test_identity_layer(
)


# IFFT
@handle_method(
method_tree="IFFT.__call__",
x_and_ifft=exp_layers_tests._x_and_ifft(),
)
def test_ifft_layer(
*,
x_and_ifft,
test_gradients,
on_device,
class_name,
method_name,
ground_truth_backend,
init_flags,
method_flags,
backend_fw,
):
dtype, x, dim, norm, n = x_and_ifft
helpers.test_method(
ground_truth_backend=ground_truth_backend,
backend_to_test=backend_fw,
init_flags=init_flags,
method_flags=method_flags,
init_all_as_kwargs_np={
"dim": dim,
"norm": norm,
"n": n,
"device": on_device,
"dtype": dtype[0],
},
method_input_dtypes=dtype,
method_all_as_kwargs_np={"inputs": x[0]},
class_name=class_name,
method_name=method_name,
test_gradients=test_gradients,
on_device=on_device,
)


# linear
@handle_method(
method_tree="Linear.__call__",
Expand Down

0 comments on commit 85f84b6

Please sign in to comment.