diff --git a/.gitignore b/.gitignore
index 65e9f6f8..af8bed85 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,7 +9,6 @@ graphs
 sweeps
 test_*.sh
 .vscode
-*slurm*
 
 ### Python ###
 # Byte-compiled / optimized / DLL files
diff --git a/CHANGELOG.md b/CHANGELOG.md
index f4680c37..25eb344c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [unreleased](https://github.com/joeloskarsson/neural-lam/compare/v0.1.0...HEAD)
 
 ### Added
+
+- Support for distributed training using DDP (DistributedDataParallel) and SLURM on multi-node multi-GPU setups
+  [/#26](https://github.com/mllam/neural-lam/pull/26)
+  @sadamov
+
 - Added tests for loading dataset, creating graph, and training model based on reduced MEPS dataset stored on AWS S3, along with automatic running of tests on push/PR to GitHub. Added caching of test data to speed up running tests.
   [/#38](https://github.com/mllam/neural-lam/pull/38)
   @SimonKamuk
diff --git a/README.md b/README.md
index 1bdc6602..1e703560 100644
--- a/README.md
+++ b/README.md
@@ -136,6 +136,7 @@ A few of the key ones are outlined below:
 * `--ar_steps`: Number of time steps to unroll for when making predictions and computing the loss
 
 Checkpoints of trained models are stored in the `saved_models` directory.
+
 The implemented models are:
 
 ### Graph-LAM
@@ -172,6 +173,14 @@ python train_model.py --model hi_lam_parallel --graph hierarchical ...
 
 Checkpoint files for our models trained on the MEPS data are available upon request.
 
+### High Performance Computing
+
+The training script can be run on a cluster with multiple GPU nodes. Neural LAM is set up to use PyTorch Lightning's `DDP` backend for distributed training.
+Currently, only the SLURM (Simple Linux Utility for Resource Management) scheduler is supported.
+An example submission script is provided in `docs/examples/submit_slurm_job.sh`.
+This script must first be adapted to the specific requirements of the cluster and can then be submitted with `sbatch`.
+If SLURM is not available in the current environment, training runs locally on a single device instead.
+
 ## Evaluate Models
 Evaluation is also done using `train_model.py`, but using the `--eval` option.
 Use `--eval val` to evaluate the model on the validation set and `--eval test` to evaluate on test data.
diff --git a/docs/examples/submit_slurm_job.sh b/docs/examples/submit_slurm_job.sh
new file mode 100644
index 00000000..941ebbc0
--- /dev/null
+++ b/docs/examples/submit_slurm_job.sh
@@ -0,0 +1,17 @@
+#!/bin/bash -l
+#SBATCH --job-name=Neural-LAM
+#SBATCH --time=24:00:00
+#SBATCH --nodes=2
+#SBATCH --ntasks-per-node=4
+#SBATCH --gres=gpu:4
+#SBATCH --partition=normal
+#SBATCH --mem=444G
+#SBATCH --no-requeue
+#SBATCH --exclusive
+#SBATCH --output=lightning_logs/neurallam_out_%j.log
+#SBATCH --error=lightning_logs/neurallam_err_%j.log
+
+# Load necessary modules or activate environment, for example:
+conda activate neural-lam
+
+srun -ul python train_model.py --val_interval 5 --epochs 20 --n_workers 8 --batch_size 12 --model hi_lam --graph hierarchical
diff --git a/train_model.py b/train_model.py
index 03863275..69b55740 100644
--- a/train_model.py
+++ b/train_model.py
@@ -1,5 +1,6 @@
 # Standard library
 import json
+import os
 import random
 import time
 from argparse import ArgumentParser
@@ -259,13 +260,24 @@ def main(input_args=None):
     )
 
     # Instantiate model + trainer
+    if args.eval:
+        use_distributed_sampler = False
+    else:
+        use_distributed_sampler = True
+
+    devices = 1
+    num_nodes = 1
     if torch.cuda.is_available():
-        device_name = "cuda"
-        torch.set_float32_matmul_precision(
-            "high"
-        )  # Allows using Tensor Cores on A100s
+        accelerator = "cuda"
+        if "SLURM_JOB_ID" in os.environ and not args.eval:
+            devices = int(
+                os.environ.get("SLURM_GPUS_PER_NODE", torch.cuda.device_count())
+            )
+            num_nodes = int(os.environ.get("SLURM_JOB_NUM_NODES", 1))
+        # Allows using Tensor Cores on A100s
+        torch.set_float32_matmul_precision("high")
     else:
-        device_name = "cpu"
+        accelerator = "cpu"
 
     # Load model parameters Use new args for model
     model_class = MODELS[args.model]
@@ -291,8 +303,10 @@ def main(input_args=None):
     trainer = pl.Trainer(
         max_epochs=args.epochs,
         deterministic=True,
-        strategy="ddp",
-        accelerator=device_name,
+        accelerator=accelerator,
+        devices=devices,
+        num_nodes=num_nodes,
+        use_distributed_sampler=use_distributed_sampler,
         logger=logger,
         log_every_n_steps=1,
         callbacks=[checkpoint_callback],
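
For reference, the device and node selection that this diff adds to `train_model.py` can be illustrated in isolation. The sketch below is not part of the PR; the helper name `resolve_trainer_devices` is hypothetical, but the environment variables it reads (`SLURM_JOB_ID`, `SLURM_GPUS_PER_NODE`, `SLURM_JOB_NUM_NODES`) and the fallback behaviour mirror the logic added to `main()`.

```python
# Hypothetical helper (not part of the PR) mirroring the SLURM detection
# logic added to train_model.py.
import os

import torch


def resolve_trainer_devices(eval_mode: bool = False):
    """Return (accelerator, devices, num_nodes, use_distributed_sampler)."""
    # Evaluation keeps the default sampler so metrics see the full dataset
    use_distributed_sampler = not eval_mode
    devices = 1
    num_nodes = 1
    if torch.cuda.is_available():
        accelerator = "cuda"
        # Scale out only when running inside a SLURM job, and only for training
        if "SLURM_JOB_ID" in os.environ and not eval_mode:
            devices = int(
                os.environ.get("SLURM_GPUS_PER_NODE", torch.cuda.device_count())
            )
            num_nodes = int(os.environ.get("SLURM_JOB_NUM_NODES", 1))
    else:
        accelerator = "cpu"
    return accelerator, devices, num_nodes, use_distributed_sampler


if __name__ == "__main__":
    # Under the example job script (2 nodes, 4 GPUs per node) this would print
    # something like ('cuda', 4, 2, True); without SLURM, or on CPU, it falls
    # back to a single device.
    print(resolve_trainer_devices(eval_mode=False))
```

The returned values correspond to the `accelerator`, `devices`, `num_nodes` and `use_distributed_sampler` arguments now passed to `pl.Trainer`. On a cluster the job would be submitted with `sbatch docs/examples/submit_slurm_job.sh`; with recent PyTorch Lightning versions the explicit `strategy="ddp"` is no longer needed, since DDP is selected automatically when more than one device or node is requested.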