This collection of notebooks demonstrates how to run Particle Mesh (PM) simulations with JAXPM, using JAX for efficient computation on multi-GPU and multi-host systems. The notebooks build on one another, progressing from single-GPU simulations to distributed simulations that span multiple GPUs and multiple nodes.
Single-GPU Particle Mesh Simulation
- Introduction to basic PM simulations on a single GPU.
- Uses JAXPM to run simulations with absolute particle positions and Cloud-in-Cell (CIC) painting.
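To make the CIC painting step concrete, here is a minimal pure-JAX sketch of Cloud-in-Cell deposition on a periodic 3D mesh. It is not JAXPM's implementation. It assumes positions are absolute and given in grid units, with mesh nodes at integer coordinates, and the name `cic_paint` is used here only for illustration. Each particle spreads its weight over the 8 nodes of the cell it sits in, with a weight that falls off linearly with distance along each axis.

```python
import jax
import jax.numpy as jnp

# The 8 corners of a unit cell, as integer offsets (0 or 1) along x, y, z.
CORNERS = jnp.array([[i, j, k] for i in (0, 1) for j in (0, 1) for k in (0, 1)])


def cic_paint(mesh, positions, weights=None):
    """Deposit particles onto a periodic 3D mesh with Cloud-in-Cell weights.

    Assumes `positions` has shape (N, 3) in grid units (nodes at integer coords).
    """
    if weights is None:
        weights = jnp.ones(positions.shape[0], dtype=mesh.dtype)
    base = jnp.floor(positions)                              # (N, 3) lower cell corner
    neighbours = base[:, None, :] + CORNERS[None, :, :]      # (N, 8, 3) surrounding nodes
    # CIC kernel: product over the 3 axes of (1 - distance to the node).
    kernel = jnp.prod(1.0 - jnp.abs(positions[:, None, :] - neighbours), axis=-1)  # (N, 8)
    # Wrap node indices for periodic boundary conditions.
    idx = neighbours.astype(jnp.int32) % jnp.array(mesh.shape)
    return mesh.at[idx[..., 0], idx[..., 1], idx[..., 2]].add(weights[:, None] * kernel)


# Usage: paint 1000 uniformly distributed particles onto an 8^3 mesh.
key = jax.random.PRNGKey(0)
pos = jax.random.uniform(key, (1000, 3), minval=0.0, maxval=8.0)
density = jax.jit(cic_paint)(jnp.zeros((8, 8, 8)), pos)
```

Because the 8 weights of each particle sum to one, the painted mesh conserves total mass, which is a cheap sanity check on any painting routine.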
Advanced Particle Mesh Simulation on a Single GPU
- Shows how to use Diffrax solvers for the ODE time-integration step.
- Explores second-order Lagrangian Perturbation Theory (2LPT) simulations.
- Introduces weighted density field projections.
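The extra ingredient in second-order LPT is a second potential whose source term is built from second derivatives of the first-order potential. The sketch below computes that source on a periodic cubic grid with FFTs. It assumes the convention commonly used for 2LPT initial conditions, where the first-order potential satisfies lap(phi1) = delta and lap(phi2) equals the returned source, with positions x = q - D1 grad(phi1) + D2 grad(phi2) and D2 approximately -3/7 D1^2; JAXPM's own sign and growth-factor conventions may differ. The function name `second_order_source` is hypothetical.

```python
import jax.numpy as jnp


def second_order_source(delta, box_size):
    """Right-hand side of the 2LPT Poisson equation on a periodic cubic grid.

    Solves lap(phi1) = delta with FFTs, then returns
    sum_{i<j} [phi1_ii * phi1_jj - phi1_ij**2], i.e. lap(phi2).
    """
    n = delta.shape[0]
    k1d = 2 * jnp.pi * jnp.fft.fftfreq(n, d=box_size / n)   # angular wavenumbers
    k = jnp.meshgrid(k1d, k1d, k1d, indexing="ij")
    k2 = k[0] ** 2 + k[1] ** 2 + k[2] ** 2
    # The Laplacian becomes -k^2 in Fourier space, so phi1_k = -delta_k / k^2.
    # The k = 0 (mean) mode is set to zero.
    safe_k2 = jnp.where(k2 > 0, k2, 1.0)
    phi1_k = jnp.where(k2 > 0, -jnp.fft.fftn(delta) / safe_k2, 0.0)

    def hessian(i, j):
        # d_i d_j -> (i k_i)(i k_j) = -k_i k_j in Fourier space.
        return jnp.fft.ifftn(-k[i] * k[j] * phi1_k).real

    h = {(i, j): hessian(i, j) for i in range(3) for j in range(i, 3)}
    source = jnp.zeros(delta.shape, dtype=h[(0, 0)].dtype)
    for i in range(3):
        for j in range(i + 1, 3):
            source = source + h[(i, i)] * h[(j, j)] - h[(i, j)] ** 2
    return source
```

A quick check: for delta = cos(kx) + cos(ky) the only non-zero second derivatives are phi1_xx = cos(kx) and phi1_yy = cos(ky), so the source reduces to cos(kx) cos(ky), while a single plane wave gives a zero source.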
Multi-GPU Particle Mesh Simulation with Halo Exchange
- Extends PM simulation to multi-GPU setups with halo exchange.
- Uses sharding and device mesh configurations to manage distributed data across GPUs.
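The halo-exchange idea can be sketched with JAX's own primitives: each device paints into its slab plus a few ghost cells on either side, then sends those ghost cells to the neighbouring device, which adds them to its edge cells. The sketch below uses `shard_map` and `jax.lax.ppermute` over a 1D device mesh that splits the first axis of the field. It is an illustration under stated assumptions, not how JAXPM implements its halo exchange; `HALO` and `make_halo_exchange` are hypothetical names. On a single device the exchange reduces to a periodic wrap.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older JAX
    from jax.experimental.shard_map import shard_map

HALO = 1  # hypothetical ghost-cell width, in grid cells


def make_halo_exchange(mesh):
    """Build a jitted function that folds ghost cells back onto neighbouring slabs."""
    n_dev = mesh.shape["x"]
    to_next = [(i, (i + 1) % n_dev) for i in range(n_dev)]  # periodic ring, forward
    to_prev = [(i, (i - 1) % n_dev) for i in range(n_dev)]  # periodic ring, backward

    def fold_halos(padded):
        # `padded` is this device's slab with HALO ghost cells on both ends of axis 0.
        lower_ghost, upper_ghost = padded[:HALO], padded[-HALO:]
        interior = padded[HALO:-HALO]
        # Upper ghost cells belong to the next slab's first cells, and vice versa.
        from_prev = jax.lax.ppermute(upper_ghost, "x", perm=to_next)
        from_next = jax.lax.ppermute(lower_ghost, "x", perm=to_prev)
        interior = interior.at[:HALO].add(from_prev)
        interior = interior.at[-HALO:].add(from_next)
        return interior

    return jax.jit(shard_map(fold_halos, mesh=mesh, in_specs=P("x"), out_specs=P("x")))


# Usage: one slab of 4 cells (+2 ghost cells) per available device.
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
n_dev = mesh.shape["x"]
local = 4
padded = jnp.arange(n_dev * (local + 2 * HALO), dtype=jnp.float32).reshape(-1, 1, 1)
folded = make_halo_exchange(mesh)(padded)  # shape (n_dev * local, 1, 1)
```

The exchange only moves the thin ghost layers between devices, so communication scales with the slab surface rather than its volume, and the folded field conserves the total painted mass.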
Multi-GPU Particle Mesh Simulation with Advanced Solvers
- Compares the Leapfrog and Dopri5 ODE solvers in multi-GPU simulations.
- Discusses performance, memory usage, and how the choice of solver affects simulation accuracy.
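The trade-off between the two solver families can be seen on a toy problem. The sketch below integrates a harmonic oscillator, standing in for the PM force, with a fixed-step kick-drift-kick leapfrog and with an adaptive Dormand-Prince 5(4) scheme. As an assumption to keep the example dependency-free, it uses `jax.experimental.ode.odeint` (which implements adaptive Dormand-Prince) instead of Diffrax's `Dopri5`; `accel`, `leapfrog`, and `dopri` are illustrative names.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint  # adaptive Dormand-Prince Runge-Kutta


def accel(x):
    # Toy force law (harmonic oscillator), standing in for the PM force computation.
    return -x


def leapfrog(x0, v0, t1, n_steps):
    """Fixed-step kick-drift-kick leapfrog (symplectic, second order)."""
    dt = t1 / n_steps

    def step(carry, _):
        x, v = carry
        v = v + 0.5 * dt * accel(x)  # half kick
        x = x + dt * v               # drift
        v = v + 0.5 * dt * accel(x)  # half kick
        return (x, v), None

    (x, v), _ = jax.lax.scan(step, (x0, v0), None, length=n_steps)
    return x, v


def dopri(x0, v0, t1, rtol=1e-5, atol=1e-5):
    """Adaptive Dormand-Prince integration of the same system in first-order form."""
    def rhs(state, t):
        x, v = state
        return (v, accel(x))

    xs, vs = odeint(rhs, (x0, v0), jnp.array([0.0, t1]), rtol=rtol, atol=atol)
    return xs[-1], vs[-1]  # values at t1


t1 = 10.0
x_lf, _ = leapfrog(jnp.array(1.0), jnp.array(0.0), t1, n_steps=100)
x_dp, _ = dopri(jnp.array(1.0), jnp.array(0.0), t1)
exact = jnp.cos(t1)
print(abs(x_lf - exact), abs(x_dp - exact))
```

Leapfrog has a fixed, predictable cost per step and needs to keep only positions and velocities, while an adaptive Runge-Kutta scheme evaluates the force several times per step and holds the intermediate stages in memory; on a large distributed mesh, that extra state is the main memory consideration.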
Multi-Host Particle Mesh Simulation
- Extends PM simulations to multi-host, multi-GPU setups for large-scale simulations.
- Walks through job submission, device initialization, and retrieval of results across nodes.
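Device initialization in a multi-host job usually comes down to one call made before any other JAX work. The sketch below, assuming a SLURM launch with one task per process, calls `jax.distributed.initialize()` only when more than one task is present (with no arguments, JAX tries to auto-detect the cluster settings from the SLURM environment) and reports how devices are split across hosts. The function name `init_distributed` is hypothetical, and run locally it simply reports a single process.

```python
import os
import jax


def init_distributed():
    """Initialise JAX's multi-process runtime when running as a multi-task SLURM job.

    Must be called before any other JAX operation in the process.
    """
    if int(os.environ.get("SLURM_NTASKS", "1")) > 1:
        # No arguments: JAX auto-detects coordinator address, process count and
        # process id from the SLURM environment.
        jax.distributed.initialize()
    return {
        "process_index": jax.process_index(),       # rank of this process
        "process_count": jax.process_count(),       # number of processes in the job
        "local_devices": jax.local_device_count(),  # devices attached to this host
        "global_devices": jax.device_count(),       # devices across all hosts
    }


info = init_distributed()
if info["process_index"] == 0:
    # Let a single process print or write results to shared storage.
    print(info)
```

Gating output on `jax.process_index() == 0` avoids every node writing the same file; sharded results must first be gathered (for example with `jax.experimental.multihost_utils.process_allgather`) before a single process can save them.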
Each notebook includes installation instructions and guidance on configuring JAXPM and its dependencies; follow the setup steps in each notebook to prepare a suitable environment. The main requirements are:
- JAXPM (included in the installation commands within notebooks)
- Diffrax for ODE solvers
- JAX with CUDA support for multi-GPU setups (or with TPU support for TPU setups)
- SLURM for job scheduling on clusters (if running multi-host setups)
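Before opening the notebooks, a quick check can confirm that JAX sees an accelerator and that the Python packages above are importable. This is a minimal sketch; it assumes JAXPM and Diffrax are importable as `jaxpm` and `diffrax`, and `check_environment` is an illustrative name.

```python
import importlib.util
import jax


def check_environment():
    """Report the JAX backend and whether the required packages can be imported."""
    return {
        "jax_version": jax.__version__,
        "backend": jax.default_backend(),  # e.g. "gpu", "tpu" or "cpu"
        "device_count": jax.device_count(),
        "diffrax": importlib.util.find_spec("diffrax") is not None,
        "jaxpm": importlib.util.find_spec("jaxpm") is not None,
    }


report = check_environment()
print(report)
if report["backend"] == "cpu":
    print("Warning: no GPU/TPU backend found; check the CUDA-enabled JAX install.")
```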
Note: These notebooks have been tested on the Jean Zay supercomputer; other HPC clusters may require configuration changes.