From c76d61c9081edad502d80f31d2b2db3c6a90a2ed Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 11 Oct 2023 12:33:30 +0200 Subject: [PATCH] expand tests ffftn --- heat/fft/fft.py | 11 +++++++++++ heat/fft/tests/test_fft.py | 19 +++++++++++++++++-- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index 05dd2d831d..e827c5c0de 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -117,6 +117,7 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: local_x = x.larray except AttributeError: raise TypeError("x must be a DNDarray, is {}".format(type(x))) + original_split = x.split # sanitize kwargs @@ -126,6 +127,7 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: if repeated_axes: raise NotImplementedError("Multiple transforms over the same axis not implemented yet.") s = kwargs.get("s", None) + s = sanitize_axis(x.gshape, s) norm = kwargs.get("norm", None) # non-distributed DNDarray @@ -142,6 +144,7 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: for i, axis in enumerate(axes): output_shape[axis] = s[i] else: + axes = tuple(range(x.ndim)) s = tuple(output_shape[axis] for axis in axes) output_shape = tuple(output_shape) @@ -170,6 +173,14 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: ) x = x.transpose(transpose_axes) + # original split is 0 and fft is along axis 0 + if x.ndim == 1: + _ = x.resplit(axis=None) + result = __fftn_op(_, fftn_op, **kwargs) + del _ + result.resplit_(axis=0) + return result + # redistribute x from axis 0 to 1 _ = x.resplit(axis=1) # FFT along axis 0 (now non-split) diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index 9f7de3de4e..afb00c402a 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -62,7 +62,7 @@ def test_fft(self): def test_ifft(self): # 1D non-distributed - x = ht.random.randn(6) + x = ht.random.randn(6, dtype=ht.float64) x_fft = ht.fft.fft(x) y = ht.fft.ifft(x_fft) self.assertIsInstance(y, ht.DNDarray) @@ -88,7 +88,22 @@ def test_irfft2(self): pass def test_fftn(self): - pass + # 1D non-distributed + x = ht.random.randn(6) + y = ht.fft.fftn(x) + np_y = np.fft.fftn(x.numpy()) + self.assertIsInstance(y, ht.DNDarray) + self.assertEqual(y.shape, x.shape) + self.assert_array_equal(y, np_y) + + # 1D distributed + x = ht.random.randn(6, split=0) + y = ht.fft.fftn(x) + np_y = np.fft.fftn(x.numpy()) + self.assertIsInstance(y, ht.DNDarray) + self.assertEqual(y.shape, x.shape) + self.assertTrue(y.split == 0) + self.assert_array_equal(y, np_y) def test_ifftn(self): pass