-
Notifications
You must be signed in to change notification settings - Fork 55
/
mnmc_ddp_slurm.py
161 lines (137 loc) · 4.97 KB
/
mnmc_ddp_slurm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""
(MNMC) Multiple Nodes Multi-GPU Cards Training
with DistributedDataParallel and Slurm
Try to compare with [mnmc_ddp_launch.py & mnmc_ddp_mp.py] and find out the differences.
"""
import os
import subprocess
import torch
import torch.distributed as dist
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.nn.parallel import DistributedDataParallel as DDP
# Mini-batch size handed to the DataLoader that each process builds
# (i.e. per process, not summed over all ranks).
BATCH_SIZE = 256
# Number of passes over the training set made by the training loop.
EPOCHS = 5
def setup_distributed(backend="nccl", port=None):
    """Initialize the torch.distributed process group from Slurm variables.

    Adapted from mmcv. Reads the task layout exported through the
    ``SLURM_*`` environment variables, binds this process to one GPU,
    fills in ``MASTER_PORT``, ``MASTER_ADDR``, ``WORLD_SIZE``,
    ``LOCAL_RANK`` and ``RANK`` and then calls ``dist.init_process_group``.

    Args:
        backend: Backend name passed to ``dist.init_process_group``.
        port: Rendezvous port. When ``None``, an existing ``MASTER_PORT``
            environment variable is kept; otherwise ``"29500"`` is used.

    Raises:
        KeyError: If ``SLURM_PROCID``, ``SLURM_NTASKS`` or
            ``SLURM_NODELIST`` is not set.
        RuntimeError: If no CUDA device is visible, or if the master
            address has to be resolved and ``scontrol`` fails.
    """
    proc_id = int(os.environ["SLURM_PROCID"])
    ntasks = int(os.environ["SLURM_NTASKS"])
    node_list = os.environ["SLURM_NODELIST"]

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        # Without this guard the modulo below dies with an opaque
        # ZeroDivisionError.
        raise RuntimeError(
            "setup_distributed() found no visible CUDA device; "
            "cannot assign a GPU to this process."
        )
    local_rank = proc_id % num_gpus
    torch.cuda.set_device(local_rank)

    # specify master port
    if port is not None:
        os.environ["MASTER_PORT"] = str(port)
    elif "MASTER_PORT" not in os.environ:
        # nothing configured anywhere: fall back to the default port
        os.environ["MASTER_PORT"] = "29500"

    # Only ask Slurm for the first host when the caller has not already
    # chosen a master address.
    if "MASTER_ADDR" not in os.environ:
        try:
            # Argument list, no shell: a bracketed node list such as
            # "node[01-04]" would otherwise be treated as a glob pattern.
            result = subprocess.run(
                ["scontrol", "show", "hostname", node_list],
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                universal_newlines=True,
            )
        except (OSError, subprocess.CalledProcessError) as err:
            raise RuntimeError(
                f"could not resolve master address from node list {node_list!r}"
            ) from err
        hostnames = result.stdout.split()
        if not hostnames:
            raise RuntimeError(
                f"scontrol returned no hostnames for node list {node_list!r}"
            )
        # the first expanded host acts as the rendezvous master
        os.environ["MASTER_ADDR"] = hostnames[0]

    os.environ["WORLD_SIZE"] = str(ntasks)
    os.environ["LOCAL_RANK"] = str(local_rank)
    os.environ["RANK"] = str(proc_id)
    # No init_method is given, so the variables set above are what
    # init_process_group has to work from.
    dist.init_process_group(backend=backend)
def main():
    """Train ResNet-18 on CIFAR-10 with DistributedDataParallel under Slurm.

    Each process initializes the process group, wraps the model in DDP on
    its own GPU, trains for ``EPOCHS`` epochs on its shard of the data and
    finally destroys the process group. Progress lines are printed by
    rank 0 only and are computed from rank 0's own batches (no reduction
    across ranks is performed).
    """
    # 0. set up distributed device
    setup_distributed()
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    print(f"[init] == local rank: {local_rank}, global rank: {rank} ==")

    # 1. define network
    net = torchvision.models.resnet18(pretrained=False, num_classes=10)
    net = net.cuda(local_rank)
    # DistributedDataParallel
    net = DDP(net, device_ids=[local_rank], output_device=local_rank)

    # 2. define dataloader
    # download=False: the dataset must already be present under ./data
    trainset = torchvision.datasets.CIFAR10(
        root="./data",
        train=True,
        download=False,
        transform=transforms.Compose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(
                    (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
                ),
            ]
        ),
    )
    # DistributedSampler
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        trainset,
        shuffle=True,
    )
    train_loader = torch.utils.data.DataLoader(
        trainset,
        batch_size=BATCH_SIZE,
        num_workers=4,
        pin_memory=True,
        sampler=train_sampler,
    )
    # loop-invariant: number of batches this rank sees per epoch
    num_steps = len(train_loader)

    # 3. define loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(
        net.parameters(),
        # base learning rate scaled linearly with the number of processes
        lr=0.01 * dist.get_world_size(),
        momentum=0.9,
        weight_decay=0.0001,
        nesterov=True,
    )

    if rank == 0:
        print(" ======= Training ======= \n")

    # 4. start to train
    net.train()
    for ep in range(1, EPOCHS + 1):
        train_loss = correct = total = 0
        # set sampler: the epoch number feeds the sampler's shuffle
        train_sampler.set_epoch(ep)

        for idx, (inputs, targets) in enumerate(train_loader):
            # The loader pins host memory, so the copies can be issued
            # asynchronously.
            inputs = inputs.cuda(local_rank, non_blocking=True)
            targets = targets.cuda(local_rank, non_blocking=True)
            outputs = net(inputs)

            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            total += targets.size(0)
            correct += torch.eq(outputs.argmax(dim=1), targets).sum().item()

            step = idx + 1
            if rank == 0 and (step % 25 == 0 or step == num_steps):
                print(
                    " == step: [{:3}/{}] [{}/{}] | loss: {:.3f} | acc: {:6.3f}%".format(
                        step,
                        num_steps,
                        ep,
                        EPOCHS,
                        train_loss / step,
                        100.0 * correct / total,
                    )
                )

    if rank == 0:
        print("\n ======= Training Finished ======= \n")

    # Tear the process group down once training completed normally, so
    # communicator resources are released before interpreter exit.
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
"""
usage:
>>> srun --help
example:
>>> srun --partition=openai -n8 --gres=gpu:8 --ntasks-per-node=8 --job-name=slurm_test \
python -u mnmc_ddp_slurm.py
======= Training =======
[init] == local rank: 1, global rank: 1 ==
[init] == local rank: 7, global rank: 7 ==
[init] == local rank: 4, global rank: 4 ==
[init] == local rank: 2, global rank: 2 ==
[init] == local rank: 0, global rank: 0 ==
[init] == local rank: 5, global rank: 5 ==
[init] == local rank: 6, global rank: 6 ==
[init] == local rank: 3, global rank: 3 ==
== step: [ 25/25] [0/5] | loss: 1.934 | acc: 29.152%
== step: [ 25/25] [1/5] | loss: 1.546 | acc: 42.976%
== step: [ 25/25] [2/5] | loss: 1.418 | acc: 48.064%
== step: [ 25/25] [3/5] | loss: 1.322 | acc: 51.728%
== step: [ 25/25] [4/5] | loss: 1.219 | acc: 55.920%
======= Training Finished =======
"""