diff --git a/tests/test_qlsi.py b/tests/test_qlsi.py index a8f507d..48b5e06 100644 --- a/tests/test_qlsi.py +++ b/tests/test_qlsi.py @@ -127,6 +127,11 @@ def test_qlsi_unwrap_phase_2d_3d(): def test_qlsi_rotate_2d_3d(hologram): + """ + Ensure the old 2d and new 3d rotation is identical. + Note that the hologram is used only as an example input image, + and it is not the correct data type for QLSI. + """ data_2d = hologram data_3d, _ = qpretrieve.data_array_layout._convert_2d_to_3d(data_2d) @@ -155,6 +160,11 @@ def test_qlsi_rotate_2d_3d(hologram): def test_qlsi_pad_2d_3d(hologram): + """ + Ensure the old 2d and new 3d padding is identical. + Note that the hologram is used only as an example input image, + and it is not the correct data type for QLSI. + """ data_2d = hologram data_3d, _ = qpretrieve.data_array_layout._convert_2d_to_3d(data_2d) @@ -171,22 +181,31 @@ def test_qlsi_pad_2d_3d(hologram): def test_fxy_complex_mul(hologram): + """ + Ensure the old 2d and new 3d complex multiplication is identical. + Note that the hologram is used only as an example input image, + and it is not the correct data type for QLSI. + """ data_2d = hologram data_3d, _ = qpretrieve.data_array_layout._convert_2d_to_3d(data_2d) assert np.array_equal(data_2d, data_3d[0]) # 2d - fx_2d = data_2d.reshape(-1, 1) - fy_2d = data_2d.reshape(1, -1) + fx_2d = np.fft.fftfreq(data_2d.shape[0]).reshape(-1, 1) + fy_2d = np.fft.fftfreq(data_2d.shape[1]).reshape(1, -1) fxy_2d = -2 * np.pi * 1j * (fx_2d + 1j * fy_2d) + assert fxy_2d.shape == (64, 64) fxy_2d[0, 0] = 1 # 3d - fx_3d = data_3d.reshape(data_3d.shape[0], -1, 1) - fy_3d = data_3d.reshape(data_3d.shape[0], 1, -1) + fx_3d = np.fft.fftfreq(data_3d.shape[-2]).reshape(-1, 1) + fy_3d = np.fft.fftfreq(data_3d.shape[-1]).reshape(1, -1) fxy_3d = -2 * np.pi * 1j * (fx_3d + 1j * fy_3d) + fxy_3d = np.repeat(fxy_3d[np.newaxis, :, :], + repeats=data_3d.shape[0], axis=0) + assert fxy_3d.shape == (1, 64, 64) fxy_3d[:, 0, 0] = 1 - assert np.array_equal(fx_2d, fx_3d[0]) + assert np.array_equal(fx_2d, fx_3d) assert np.array_equal(fxy_2d, fxy_3d[0])