diff --git a/README.md b/README.md index 5ffe8b4..0ee4019 100644 --- a/README.md +++ b/README.md @@ -117,18 +117,16 @@ $ pip install --user . #### IDRIS [Jean Zay](http://www.idris.fr/eng/jean-zay/cpu/jean-zay-cpu-hw-eng.html) HPE SGI 8600 supercomputer -As of April. 2024, the following works: +As of September. 2024, the following works: You need to load modules in that order exactly. ```bash # Load NVHPC 23.9 because it has cuda 12.2 module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake -# Installing mpi4py -CFLAGS=-noswitcherror pip install mpi4py # Installing jax -pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +pip install --upgrade "jax[cuda12]" # Installing jaxdecomp -export CMAKE_PREFIX_PATH=$NVHPC_ROOT/cmake +export CMAKE_PREFIX_PATH=$NVHPC_ROOT/cmake # Not always needed pip install . ``` @@ -141,7 +139,7 @@ export CRAY_ACCEL_TARGET=nvidia80 # Installing mpi4py MPICC="cc -target-accel=nvidia80 -shared" CC=nvc CFLAGS="-noswitcherror" pip install --force --no-cache-dir --no-binary=mpi4py mpi4py # Installing jax -pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +pip install --upgrade "jax[cuda12]" # Installing jaxdecomp export CMAKE_PREFIX_PATH=/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cmake pip install . diff --git a/tests/test_fft.py b/tests/test_fft.py index d392f53..c53ee32 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -46,6 +46,14 @@ def create_spmd_array(global_shape, pdims): return global_array, mesh +def print_array(array): + print(f"shape {array.shape} rank {rank}") + for z in range(array.shape[0]): + for y in range(array.shape[1]): + for x in range(array.shape[2]): + print(f"[{z},{y},{x}] {array[z,y,x]}") + + pencil_1 = (size // 2, size // (size // 2)) # 2x2 for V100 and 4x2 for A100 pencil_2 = (size // (size // 2), size // 2) # 2x2 for V100 and 2x4 for A100 @@ -122,6 +130,18 @@ def test_fft(pdims, global_shape, local_transpose): # Temporary solution because I need to find a way to retrigger the jit compile if the config changes jax.clear_caches() + # Check the forward FFT + if penciltype == SLAB_YZ: + transpose_back = [2, 0, 1] + else: + transpose_back = [1, 2, 0] + jax_karray_transposed = jax_karray.transpose(transpose_back) + assert_allclose(gathered_ - 7, atol=1e-7) + assert_allclose( + gathered_karray.imag, jax_karray_transposed.imag, rtol=1e-7, atol=1e-7) + + print(f"FFT with transpose check OK!") + # Cartesian product tests @pytest.mark.parametrize(