Skip to content

Commit

Permalink
merge with main
Browse files Browse the repository at this point in the history
  • Loading branch information
ASKabalan committed Oct 6, 2024
2 parents d921300 + 3c9797a commit bfec8fb
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
10 changes: 4 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,18 +117,16 @@ $ pip install --user .

#### IDRIS [Jean Zay](http://www.idris.fr/eng/jean-zay/cpu/jean-zay-cpu-hw-eng.html) HPE SGI 8600 supercomputer

As of April. 2024, the following works:
As of September. 2024, the following works:

You need to load modules in that order exactly.
```bash
# Load NVHPC 23.9 because it has cuda 12.2
module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake
# Installing mpi4py
CFLAGS=-noswitcherror pip install mpi4py
# Installing jax
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade "jax[cuda12]"
# Installing jaxdecomp
export CMAKE_PREFIX_PATH=$NVHPC_ROOT/cmake
export CMAKE_PREFIX_PATH=$NVHPC_ROOT/cmake # Not always needed
pip install .
```

Expand All @@ -141,7 +139,7 @@ export CRAY_ACCEL_TARGET=nvidia80
# Installing mpi4py
MPICC="cc -target-accel=nvidia80 -shared" CC=nvc CFLAGS="-noswitcherror" pip install --force --no-cache-dir --no-binary=mpi4py mpi4py
# Installing jax
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade "jax[cuda12]"
# Installing jaxdecomp
export CMAKE_PREFIX_PATH=/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cmake
pip install .
Expand Down
20 changes: 20 additions & 0 deletions tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ def create_spmd_array(global_shape, pdims):
return global_array, mesh


def print_array(array):
print(f"shape {array.shape} rank {rank}")
for z in range(array.shape[0]):
for y in range(array.shape[1]):
for x in range(array.shape[2]):
print(f"[{z},{y},{x}] {array[z,y,x]}")


pencil_1 = (size // 2, size // (size // 2)) # 2x2 for V100 and 4x2 for A100
pencil_2 = (size // (size // 2), size // 2) # 2x2 for V100 and 2x4 for A100

Expand Down Expand Up @@ -122,6 +130,18 @@ def test_fft(pdims, global_shape, local_transpose):
# Temporary solution because I need to find a way to retrigger the jit compile if the config changes
jax.clear_caches()

# Check the forward FFT
if penciltype == SLAB_YZ:
transpose_back = [2, 0, 1]
else:
transpose_back = [1, 2, 0]
jax_karray_transposed = jax_karray.transpose(transpose_back)
assert_allclose(gathered_ - 7, atol=1e-7)
assert_allclose(
gathered_karray.imag, jax_karray_transposed.imag, rtol=1e-7, atol=1e-7)

print(f"FFT with transpose check OK!")


# Cartesian product tests
@pytest.mark.parametrize(
Expand Down

0 comments on commit bfec8fb

Please sign in to comment.