Merge pull request #212 from aws-samples/cpu_ddp
add DDP CPU example
Showing 3 changed files with 206 additions and 0 deletions.
@@ -0,0 +1,23 @@
#!/bin/bash
#SBATCH --job-name=cpu-ddp
#SBATCH --exclusive
#SBATCH --wait-all-nodes=1
#SBATCH --nodes 2
#SBATCH --cpus-per-task=4
#SBATCH --output=logs/%x_%j.out # logfile for stdout/stderr

# Use the first allocated node as the rendezvous host
nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
nodes_array=($nodes)
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

echo Node IP: $head_node_ip
export LOGLEVEL=INFO

# Launch 4 workers on each of the 2 nodes; ddp.py trains for 50 epochs, saving a snapshot every 10
srun /opt/conda/envs/pytorch/bin/torchrun \
    --nnodes 2 \
    --nproc_per_node 4 \
    --rdzv_id $RANDOM \
    --rdzv_backend c10d \
    --rdzv_endpoint $head_node_ip:29500 \
    ddp.py 50 10
@@ -0,0 +1,60 @@
# PyTorch DDP on CPU <!-- omit in toc -->

This test case provides a simple CPU-only distributed training example using [PyTorch DDP](https://pytorch.org/tutorials/beginner/ddp_series_theory.html).
## 1. Preparation

This guide assumes that you have the following:

* A functional Slurm cluster on AWS whose compute instances are based on the AWS Deep Learning AMI.
* An FSx for Lustre filesystem mounted on `/fsx`.

We recommend that you set up the Slurm cluster using the templates in the architectures [directory](../../1.architectures). A quick way to check both prerequisites from the head node is shown below.
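The following commands are a minimal sanity check (assuming a standard Slurm installation and the `/fsx` mount point above); they are not part of the test case itself.

```bash
# Compute nodes should be listed and in an idle or allocated state
sinfo

# The FSx for Lustre filesystem should appear as a mounted filesystem
df -h /fsx
```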

## 2. Submit training job

Submit the DDP training job with:

```bash
sbatch 1.train.sbatch
```
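You can track the job while it runs; the commands below are a generic sketch (the actual job ID and log file name will differ):

```bash
# Show your queued and running jobs
squeue -u $USER

# Follow the job log; the file name pattern comes from the --output directive in 1.train.sbatch
tail -f logs/cpu-ddp_<jobid>.out
```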

Output of the training job can be found in the `logs` directory:

```bash
# cat logs/cpu-ddp_xxx.out
Node IP: 10.1.96.108
[2024-03-12 08:22:45,549] torch.distributed.run: [WARNING] master_addr is only used for static rdzv_backend and when rdzv_endpoint is not specified.
[2024-03-12 08:22:45,549] torch.distributed.run: [WARNING]
[2024-03-12 08:22:45,549] torch.distributed.run: [WARNING] *****************************************
[2024-03-12 08:22:45,549] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
[2024-03-12 08:22:45,549] torch.distributed.run: [WARNING] *****************************************
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] Starting elastic_operator with launch configs:
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] entrypoint : ddp.py
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] min_nodes : 2
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] max_nodes : 2
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] nproc_per_node : 4
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] run_id : 5982
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] rdzv_backend : c10d
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] rdzv_endpoint : 10.1.96.108:29500
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] rdzv_configs : {'timeout': 900}
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] max_restarts : 0
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] monitor_interval : 5
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] log_dir : None
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO] metrics_cfg : {}
[2024-03-12 08:22:45,549] torch.distributed.launcher.api: [INFO]
[2024-03-12 08:22:45,552] torch.distributed.elastic.agent.server.local_elastic_agent: [INFO] log directory set to: /tmp/torchelastic_9g50nxjq/5982_tflt1tcd
[2024-03-12 08:22:45,552] torch.distributed.elastic.agent.server.api: [INFO] [default] starting workers for entrypoint: python
...
[RANK 3] Epoch 49 | Batchsize: 32 | Steps: 8
[RANK 5] Epoch 49 | Batchsize: 32 | Steps: 8
[RANK 4] Epoch 49 | Batchsize: 32 | Steps: 8
[2024-03-12 08:22:56,574] torch.distributed.elastic.agent.server.api: [INFO] [default] worker group successfully finished. Waiting 300 seconds for other agents to finish.
[2024-03-12 08:22:56,574] torch.distributed.elastic.agent.server.api: [INFO] Local worker group finished (WorkerState.SUCCEEDED). Waiting 300 seconds for other agents to finish
[2024-03-12 08:22:56,575] torch.distributed.elastic.agent.server.api: [INFO] [default] worker group successfully finished. Waiting 300 seconds for other agents to finish.
[2024-03-12 08:22:56,575] torch.distributed.elastic.agent.server.api: [INFO] Local worker group finished (WorkerState.SUCCEEDED). Waiting 300 seconds for other agents to finish
[2024-03-12 08:22:56,575] torch.distributed.elastic.agent.server.api: [INFO] Done waiting for other agents. Elapsed: 0.0010929107666015625 seconds
[2024-03-12 08:22:56,575] torch.distributed.elastic.agent.server.api: [INFO] Done waiting for other agents. Elapsed: 0.0005395412445068359 seconds
```
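If you want to exercise `ddp.py` outside of Slurm first, a single-node smoke test is possible with `torchrun` in standalone mode (a sketch, assuming `torchrun` and the script are available in your current environment; the epoch and snapshot arguments are arbitrary):

```bash
# 2 local workers, 5 epochs, snapshot every 2 epochs
torchrun --standalone --nproc_per_node 2 ddp.py 5 2
```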
@@ -0,0 +1,123 @@
# Modified version of https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/multigpu_torchrun.py

import os

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group


class MyTrainDataset(Dataset):
    """Synthetic dataset of (feature, target) pairs."""

    def __init__(self, size):
        self.size = size
        self.data = [(torch.rand(20), torch.rand(1)) for _ in range(size)]

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        return self.data[index]


def ddp_setup():
    # torchrun exports MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE into the
    # environment; gloo is the process group backend for CPU-only training.
    init_process_group(backend="gloo")


class Trainer:
    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        save_every: int,
        snapshot_path: str,
    ) -> None:
        self.model = model
        self.rank = os.environ["RANK"]
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every
        self.epochs_run = 0
        self.snapshot_path = snapshot_path
        if os.path.exists(snapshot_path):
            print("Loading snapshot")
            self._load_snapshot(snapshot_path)

        self.model = DDP(self.model)

    def _load_snapshot(self, snapshot_path):
        snapshot = torch.load(snapshot_path)
        self.model.load_state_dict(snapshot["MODEL_STATE"])
        self.epochs_run = snapshot["EPOCHS_RUN"]
        print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

    def _run_batch(self, source, targets):
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = F.cross_entropy(output, targets)
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        b_sz = len(next(iter(self.train_data))[0])
        print(f"[RANK {self.rank}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        self.train_data.sampler.set_epoch(epoch)
        for source, targets in self.train_data:
            # CPU-only training: no device transfer needed
            self._run_batch(source, targets)

    def _save_snapshot(self, epoch):
        snapshot = {
            "MODEL_STATE": self.model.module.state_dict(),
            "EPOCHS_RUN": epoch,
        }
        torch.save(snapshot, self.snapshot_path)
        print(f"Epoch {epoch} | Training snapshot saved at {self.snapshot_path}")

    def train(self, max_epochs: int):
        for epoch in range(self.epochs_run, max_epochs):
            self._run_epoch(epoch)
            if epoch % self.save_every == 0:
                self._save_snapshot(epoch)


def load_train_objs():
    train_set = MyTrainDataset(2048)  # load your dataset
    model = torch.nn.Linear(20, 1)  # load your model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    return train_set, model, optimizer


def prepare_dataloader(dataset: Dataset, batch_size: int):
    # DistributedSampler shards the dataset across ranks; per-epoch shuffling is
    # handled by the sampler via set_epoch() in _run_epoch.
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        sampler=DistributedSampler(dataset)
    )


def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"):
    ddp_setup()
    dataset, model, optimizer = load_train_objs()
    train_data = prepare_dataloader(dataset, batch_size)
    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
    trainer.train(total_epochs)
    destroy_process_group()


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='simple distributed training job')
    parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
    parser.add_argument('save_every', type=int, help='How often to save a snapshot')
    parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
    args = parser.parse_args()

    main(args.save_every, args.total_epochs, args.batch_size)
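For reference, `ddp_setup()` above relies on `torchrun` having exported the rendezvous information into each worker's environment before the script starts. The snippet below is only an illustrative sketch of what the launcher provides to `init_process_group(backend="gloo")`; it is not part of the test case.

```python
import os

# Set by torchrun for every worker process; init_process_group() with the
# default env:// initialization reads these to form the process group.
rank = int(os.environ["RANK"])              # global rank of this worker
local_rank = int(os.environ["LOCAL_RANK"])  # rank within this node
world_size = int(os.environ["WORLD_SIZE"])  # total number of workers
master = f"{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"

print(f"worker {rank}/{world_size} (local {local_rank}), rendezvous at {master}")
```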