# benchmark_kungfu_torch.py — 135 lines (110 loc) · 4.16 KB
# Forked from https://github.com/horovod/horovod/blob/master/examples/pytorch_synthetic_benchmark.py
import argparse
import timeit

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data.distributed
from torchvision import models

import kungfu.torch as kf
# Benchmark settings (shared flags with the upstream Horovod benchmark).
parser = argparse.ArgumentParser(
    description='PyTorch Synthetic Benchmark',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
_add = parser.add_argument
_add('--fp16-allreduce', action='store_true', default=False,
     help='use fp16 compression during allreduce')
_add('--model', type=str, default='resnet50',
     help='model to benchmark')
_add('--batch-size', type=int, default=32,
     help='input batch size')
_add('--num-warmup-batches', type=int, default=10,
     help='number of warm-up batches that don\'t count towards benchmark')
_add('--num-batches-per-iter', type=int, default=10,
     help='number of batches per benchmark iteration')
_add('--num-iters', type=int, default=10,
     help='number of benchmark iterations')
_add('--no-cuda', action='store_true', default=False,
     help='disables CUDA training')
_add('--use-adasum', action='store_true', default=False,
     help='use adasum algorithm to do reduction')
args = parser.parse_args()
# Decide whether this run uses CUDA and pin the process to its device.
use_cuda = torch.cuda.is_available() and not args.no_cuda
args.cuda = use_cuda
if use_cuda:
    # Each worker drives the GPU index that KungFu assigns to it.
    torch.cuda.set_device(kf.get_cuda_index())
cudnn.benchmark = True  # let cuDNN auto-tune kernels for the fixed input size
# Instantiate the model to benchmark from torchvision (e.g. resnet50).
model_cls = getattr(models, args.model)
model = model_cls()
# By default, Adasum doesn't need scaling up learning rate.
if args.use_adasum:
    lr_scaler = 1
else:
    lr_scaler = kf.current_cluster_size()
if args.cuda:
    model.cuda()  # move parameters to the selected GPU
    # With GPU Adasum over NCCL, scale the learning rate by local size instead.
    if args.use_adasum and kf.nccl_built():
        lr_scaler = kf.current_local_size()
optimizer = optim.SGD(model.parameters(), lr=0.01 * lr_scaler)
# KungFu: wrap the plain SGD optimizer with SynchronousSGDOptimizer.
optimizer = kf.optimizers.SynchronousSGDOptimizer(
    optimizer,
    named_parameters=model.named_parameters(),
    # compression=compression,
    # op=kf.Adasum if args.use_adasum else kf.Average,
)
# TODO: broadcast parameters & optimizer state.
# kf.broadcast_parameters(model.state_dict(), root_rank=0)
# kf.broadcast_optimizer_state(optimizer, root_rank=0)
# Fixed synthetic batch: ImageNet-shaped images and labels in [0, 1000).
data = torch.randn(args.batch_size, 3, 224, 224)
target = torch.LongTensor(args.batch_size).random_() % 1000
if args.cuda:
    data = data.cuda()
    target = target.cuda()
def benchmark_step():
    """Run one synthetic training step: forward, loss, backward, update.

    Uses the module-level ``model``, ``optimizer``, ``data`` and ``target``.
    The KungFu-wrapped optimizer presumably synchronizes gradients across
    workers inside ``step()`` — confirm against SynchronousSGDOptimizer docs.
    """
    optimizer.zero_grad()
    output = model(data)
    loss = F.cross_entropy(output, target)
    loss.backward()
    optimizer.step()
def log(s, nl=True):
    """Print ``s`` from the rank-0 worker only.

    Pass ``nl=False`` to suppress the trailing newline.
    """
    if kf.current_rank() == 0:
        print(s, end='\n' if nl else '')
# Report the configuration, then run untimed warm-up batches.
device = 'GPU' if args.cuda else 'CPU'
log('Model: %s' % args.model)
log('Batch size: %d' % args.batch_size)
log('Number of %ss: %d' % (device, kf.current_cluster_size()))
log('Running warmup...')
timeit.timeit(benchmark_step, number=args.num_warmup_batches)
# Timed benchmark: each iteration times a fixed number of batches.
log('Running benchmark...')
img_secs = []
for it in range(args.num_iters):
    # (renamed local: 'time' would shadow the stdlib module name)
    elapsed = timeit.timeit(benchmark_step, number=args.num_batches_per_iter)
    rate = args.batch_size * args.num_batches_per_iter / elapsed
    log('Iter #%d: %.1f img/sec per %s' % (it, rate, device))
    img_secs.append(rate)
# Results: mean throughput with a 1.96-sigma (normal-approximation) interval.
img_sec_mean = np.mean(img_secs)
img_sec_conf = 1.96 * np.std(img_secs)
log('Img/sec per %s: %.1f +-%.1f' % (device, img_sec_mean, img_sec_conf))
n_workers = kf.current_cluster_size()
log('Total img/sec on %d %s(s): %.1f +-%.1f' %
    (n_workers, device, n_workers * img_sec_mean, n_workers * img_sec_conf))