From 22985b05466f602652c117ecb727eddea1547ef9 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Wed, 24 Jul 2024 17:10:24 +0200 Subject: [PATCH] adjust tests --- tests/test_fft.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/tests/test_fft.py b/tests/test_fft.py index 293f77a..f52986c 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -5,6 +5,7 @@ from math import prod import jax.numpy as jnp +import numpy as np import pytest from conftest import initialize_distributed from jax.experimental import mesh_utils, multihost_utils @@ -14,6 +15,7 @@ from numpy.testing import assert_allclose import jaxdecomp +from jaxdecomp._src import PENCILS, SLAB_XY, SLAB_YZ # Initialize cuDecomp initialize_distributed() @@ -37,13 +39,21 @@ def create_spmd_array(global_shape, pdims): key=jax.random.PRNGKey(rank)) # Remap to the global array from the local slice devices = mesh_utils.create_device_mesh(pdims) - mesh = Mesh(devices, axis_names=('y', 'z')) + mesh = Mesh(devices.T, axis_names=('z', 'y')) global_array = multihost_utils.host_local_array_to_global_array( local_array, mesh, P('z', 'y')) return global_array, mesh +def print_array(array): + print(f"shape {array.shape} rank {rank}") + for z in range(array.shape[0]): + for y in range(array.shape[1]): + for x in range(array.shape[2]): + print(f"[{z},{y},{x}] {array[z,y,x]}") + + pencil_1 = (size // 2, size // (size // 2)) # 2x2 for V100 and 4x2 for A100 pencil_2 = (size // (size // 2), size // 2) # 2x2 for V100 and 2x4 for A100 @@ -61,6 +71,13 @@ def test_fft(pdims, global_shape): print("*" * 80) print(f"Testing with pdims {pdims} and global shape {global_shape}") + if pdims[0] == 1: + penciltype = SLAB_XY + elif pdims[1] == 1: + penciltype = SLAB_YZ + else: + penciltype = PENCILS + print(f"Decomposition type {penciltype}") global_array, mesh = create_spmd_array(global_shape, pdims) @@ -82,14 +99,21 @@ def test_fft(pdims, global_shape): assert_allclose( gathered_array.imag, gathered_rec_array.imag, rtol=1e-7, atol=1e-7) + print(f"Reconstruction check OK!") + # Check the forward FFT - transpose_back = [1, 2, 0] + if penciltype == SLAB_YZ: + transpose_back = [2, 0, 1] + else: + transpose_back = [1, 2, 0] jax_karray_transposed = jax_karray.transpose(transpose_back) assert_allclose( gathered_karray.real, jax_karray_transposed.real, rtol=1e-7, atol=1e-7) assert_allclose( gathered_karray.imag, jax_karray_transposed.imag, rtol=1e-7, atol=1e-7) + print(f"FFT with transpose check OK!") + # Cartesian product tests @pytest.mark.parametrize("pdims",