Introduce multi-node training setup #26

Open · wants to merge 5 commits into main
1 change: 0 additions & 1 deletion .gitignore
@@ -9,7 +9,6 @@ graphs
sweeps
test_*.sh
.vscode
*slurm*

### Python ###
# Byte-compiled / optimized / DLL files
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [unreleased](https://github.com/joeloskarsson/neural-lam/compare/v0.1.0...HEAD)

### Added

- Support for distributed training using DDP (DistributedDataParallel) and SLURM on multi-node multi-GPU setups
[/#26](https://github.com/mllam/neural-lam/pull/26)
@sadamov

- Added tests for loading dataset, creating graph, and training model based on reduced MEPS dataset stored on AWS S3, along with automatic running of tests on push/PR to GitHub. Added caching of test data to speed up running tests.
[/#38](https://github.com/mllam/neural-lam/pull/38)
@SimonKamuk
9 changes: 9 additions & 0 deletions README.md
@@ -136,6 +136,7 @@ A few of the key ones are outlined below:
* `--ar_steps`: Number of time steps to unroll for when making predictions and computing the loss

Checkpoints of trained models are stored in the `saved_models` directory.

The implemented models are:

### Graph-LAM
@@ -172,6 +173,14 @@ python train_model.py --model hi_lam_parallel --graph hierarchical ...

Checkpoint files for our models trained on the MEPS data are available upon request.

### High Performance Computing

The training script can be run on a cluster with multiple GPU nodes. Neural LAM is set up to use PyTorch Lightning's `DDP` backend for distributed training.
Currently, only the SLURM (Simple Linux Utility for Resource Management) scheduler is supported.
To run on a cluster, start from the example script `docs/examples/submit_slurm_job.sh`.
This script must first be adapted to the specific requirements of the cluster and then submitted with `sbatch`, as shown below.
If SLURM is not available in the current environment, the training script falls back to running locally on a single node.
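A typical submission, assuming you run from the repository root, might look like:

sbatch docs/examples/submit_slurm_job.sh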

## Evaluate Models
Evaluation is also done using `train_model.py`, but using the `--eval` option.
Use `--eval val` to evaluate the model on the validation set and `--eval test` to evaluate on test data.
17 changes: 17 additions & 0 deletions docs/examples/submit_slurm_job.sh
@@ -0,0 +1,17 @@
#!/bin/bash -l
#SBATCH --job-name=Neural-LAM
#SBATCH --time=24:00:00
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=4
#SBATCH --gres=gpu:4
#SBATCH --partition=normal
#SBATCH --mem=444G
#SBATCH --no-requeue
#SBATCH --exclusive
#SBATCH --output=lightning_logs/neurallam_out_%j.log
#SBATCH --error=lightning_logs/neurallam_err_%j.log

# Load necessary modules or activate environment, for example:
conda activate neural-lam

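# srun flags: -u streams output unbuffered, -l prefixes each output line with the task rank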
srun -ul python train_model.py --val_interval 5 --epochs 20 --n_workers 8 --batch_size 12 --model hi_lam --graph hierarchical
28 changes: 21 additions & 7 deletions train_model.py
@@ -1,5 +1,6 @@
# Standard library
import json
import os
import random
import time
from argparse import ArgumentParser
@@ -259,13 +260,24 @@ def main(input_args=None):
)

# Instantiate model + trainer
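# Shard the dataset across processes only when training; during
# evaluation every process should see the full dataset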
if args.eval:
use_distributed_sampler = False
else:
use_distributed_sampler = True

devices = 1
num_nodes = 1
if torch.cuda.is_available():
device_name = "cuda"
torch.set_float32_matmul_precision(
"high"
) # Allows using Tensor Cores on A100s
accelerator = "cuda"
if "SLURM_JOB_ID" in os.environ and not args.eval:
devices = int(
os.environ.get("SLURM_GPUS_PER_NODE", torch.cuda.device_count())
)
num_nodes = int(os.environ.get("SLURM_JOB_NUM_NODES", 1))
# Allows using Tensor Cores on A100s
torch.set_float32_matmul_precision("high")
else:
device_name = "cpu"
accelerator = "cpu"

# Load model parameters, use new args for model
model_class = MODELS[args.model]
@@ -291,8 +303,10 @@
trainer = pl.Trainer(
max_epochs=args.epochs,
deterministic=True,
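# DDP runs one process per device; under SLURM these are launched by srun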
strategy="ddp",
accelerator=device_name,
accelerator=accelerator,
devices=devices,
num_nodes=num_nodes,
use_distributed_sampler=use_distributed_sampler,
logger=logger,
log_every_n_steps=1,
callbacks=[checkpoint_callback],