 from net import Net
 
 # Training settings
-parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
-parser.add_argument('--batch-size', type=int, default=64, metavar='N',
-                    help='input batch size for training (default: 64)')
-parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
-                    help='input batch size for testing (default: 1000)')
-parser.add_argument('--epochs', type=int, default=14, metavar='N',
-                    help='number of epochs to train (default: 14)')
-parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
-                    help='learning rate (default: 0.01)')
-parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
-                    help='SGD momentum (default: 0.5)')
-parser.add_argument('--no-cuda', action='store_true', default=False,
-                    help='disables CUDA training')
-parser.add_argument('--seed', type=int, default=42, metavar='S',
-                    help='random seed (default: 42)')
-parser.add_argument('--log-interval', type=int, default=10, metavar='N',
-                    help='how many batches to wait before logging training status')
-parser.add_argument('--fp16-allreduce', action='store_true', default=False,
-                    help='use fp16 compression during allreduce')
-parser.add_argument('--use-adasum', action='store_true', default=False,
-                    help='use adasum algorithm to do reduction')
+parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+parser.add_argument(
+    "--batch-size",
+    type=int,
+    default=64,
+    metavar="N",
+    help="input batch size for training (default: 64)",
+)
+parser.add_argument(
+    "--test-batch-size",
+    type=int,
+    default=1000,
+    metavar="N",
+    help="input batch size for testing (default: 1000)",
+)
+parser.add_argument(
+    "--epochs",
+    type=int,
+    default=14,
+    metavar="N",
+    help="number of epochs to train (default: 14)",
+)
+parser.add_argument(
+    "--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)"
+)
+parser.add_argument(
+    "--momentum",
+    type=float,
+    default=0.5,
+    metavar="M",
+    help="SGD momentum (default: 0.5)",
+)
+parser.add_argument(
+    "--no-cuda", action="store_true", default=False, help="disables CUDA training"
+)
+parser.add_argument(
+    "--seed", type=int, default=42, metavar="S", help="random seed (default: 42)"
+)
+parser.add_argument(
+    "--log-interval",
+    type=int,
+    default=10,
+    metavar="N",
+    help="how many batches to wait before logging training status",
+)
+parser.add_argument(
+    "--fp16-allreduce",
+    action="store_true",
+    default=False,
+    help="use fp16 compression during allreduce",
+)
+parser.add_argument(
+    "--use-adasum",
+    action="store_true",
+    default=False,
+    help="use adasum algorithm to do reduction",
+)
 
 
 def train(model, train_sampler, train_loader, args, optimizer, epoch):
@@ -53,9 +89,15 @@ def train(model, train_sampler, train_loader, args, optimizer, epoch):
         if batch_idx % args.log_interval == 0:
             # Horovod: use train_sampler to determine the number of examples in
             # this worker's partition.
-            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                epoch, hvd.size() * batch_idx * len(data), len(train_loader.dataset),
-                100. * batch_idx / len(train_loader), loss.item()))
+            print(
+                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
+                    epoch,
+                    hvd.size() * batch_idx * len(data),
+                    len(train_loader.dataset),
+                    100.0 * batch_idx / len(train_loader),
+                    loss.item(),
+                )
+            )
 
 
 def metric_average(val, name):
@@ -67,14 +109,14 @@ def metric_average(val, name):
 
 def test(model, test_sampler, test_loader, args):
     model.eval()
-    test_loss = 0.
-    test_accuracy = 0.
+    test_loss = 0.0
+    test_accuracy = 0.0
     for data, target in test_loader:
         if args.cuda:
             data, target = data.cuda(), target.cuda()
         output = model(data)
         # sum up batch loss
-        test_loss += F.nll_loss(output, target, reduction='sum').item()
+        test_loss += F.nll_loss(output, target, reduction="sum").item()
         # get the index of the max log-probability
         pred = output.data.max(1, keepdim=True)[1]
         test_accuracy += pred.eq(target.data.view_as(pred)).cpu().float().sum()
@@ -85,13 +127,16 @@ def test(model, test_sampler, test_loader, args):
     test_accuracy /= len(test_sampler)
 
     # Horovod: average metric values across workers.
-    test_loss = metric_average(test_loss, 'avg_loss')
-    test_accuracy = metric_average(test_accuracy, 'avg_accuracy')
+    test_loss = metric_average(test_loss, "avg_loss")
+    test_accuracy = metric_average(test_accuracy, "avg_accuracy")
 
     # Horovod: print output only on first rank.
     if hvd.rank() == 0:
-        print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(
-            test_loss, 100. * test_accuracy))
+        print(
+            "\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n".format(
+                test_loss, 100.0 * test_accuracy
+            )
+        )
 
 
 def main():
@@ -107,44 +152,54 @@ def main():
         torch.cuda.set_device(hvd.local_rank())
         torch.cuda.manual_seed(args.seed)
 
-
     # Horovod: limit # of CPU threads to be used per worker.
     torch.set_num_threads(1)
 
-    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
+    kwargs = {"num_workers": 1, "pin_memory": True} if args.cuda else {}
     # When supported, use 'forkserver' to spawn dataloader workers instead of 'fork' to prevent
     # issues with Infiniband implementations that are not fork-safe
-    if (kwargs.get('num_workers', 0) > 0 and hasattr(mp, '_supports_context') and
-            mp._supports_context and 'forkserver' in mp.get_all_start_methods()):
-        kwargs['multiprocessing_context'] = 'forkserver'
-
-    transform = transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize((0.1307,), (0.3081,))
-    ])
+    if (
+        kwargs.get("num_workers", 0) > 0
+        and hasattr(mp, "_supports_context")
+        and mp._supports_context
+        and "forkserver" in mp.get_all_start_methods()
+    ):
+        kwargs["multiprocessing_context"] = "forkserver"
+
+    transform = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+    )
 
     if hvd.rank() != 0:
         # might be downloading mnist data, let rank 0 download first
         hvd.barrier()
 
     # train_dataset = datasets.MNIST('data-%d' % hvd.rank(), train=True, download=True, transform=transform)
-    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
+    train_dataset = datasets.MNIST(
+        "./data", train=True, download=True, transform=transform
+    )
 
     if hvd.rank() == 0:
         # mnist data is downloaded, indicate other ranks can proceed
         hvd.barrier()
 
     # Horovod: use DistributedSampler to partition the training data.
-    train_sampler = dist.DistributedSampler(train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
+    train_sampler = dist.DistributedSampler(
+        train_dataset, num_replicas=hvd.size(), rank=hvd.rank()
+    )
     train_loader = torch.utils.data.DataLoader(
-        train_dataset, batch_size=args.batch_size, sampler=train_sampler, **kwargs)
+        train_dataset, batch_size=args.batch_size, sampler=train_sampler, **kwargs
+    )
 
     # test_dataset = datasets.MNIST('data-%d' % hvd.rank(), train=False, transform=transform)
-    test_dataset = datasets.MNIST('./data', train=False, transform=transform)
+    test_dataset = datasets.MNIST("./data", train=False, transform=transform)
     # Horovod: use DistributedSampler to partition the test data.
-    test_sampler = dist.DistributedSampler(test_dataset, num_replicas=hvd.size(), rank=hvd.rank())
-    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.test_batch_size,
-                                              sampler=test_sampler, **kwargs)
+    test_sampler = dist.DistributedSampler(
+        test_dataset, num_replicas=hvd.size(), rank=hvd.rank()
+    )
+    test_loader = torch.utils.data.DataLoader(
+        test_dataset, batch_size=args.test_batch_size, sampler=test_sampler, **kwargs
+    )
 
     model = Net()
 
@@ -159,8 +214,9 @@ def main():
         lr_scaler = hvd.local_size()
 
     # Horovod: scale learning rate by lr_scaler.
-    optimizer = optim.SGD(model.parameters(), lr=args.lr * lr_scaler,
-                          momentum=args.momentum)
+    optimizer = optim.SGD(
+        model.parameters(), lr=args.lr * lr_scaler, momentum=args.momentum
+    )
 
     # Horovod: broadcast parameters & optimizer state.
     hvd.broadcast_parameters(model.state_dict(), root_rank=0)
@@ -170,12 +226,14 @@ def main():
     compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none
 
     # Horovod: wrap optimizer with DistributedOptimizer.
-    optimizer = hvd.DistributedOptimizer(optimizer,
-                                         named_parameters=model.named_parameters(),
-                                         compression=compression,
-                                         op=hvd.Adasum if args.use_adasum else hvd.Average)
+    optimizer = hvd.DistributedOptimizer(
+        optimizer,
+        named_parameters=model.named_parameters(),
+        compression=compression,
+        op=hvd.Adasum if args.use_adasum else hvd.Average,
+    )
 
-    total_time = 0.
+    total_time = 0.0
 
     for epoch in range(1, args.epochs + 1):
         start = time.time()
@@ -186,6 +244,6 @@ def main():
     return hvd.rank(), total_time
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     rk, tt = main()
-    print(f'[{rk}] Total time elapsed: {tt} seconds')
+    print(f"[{rk}] Total time elapsed: {tt} seconds")
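
For reference, Horovod scripts like this one are normally launched with the `horovodrun` wrapper rather than plain `python`, so that one worker process starts per GPU or CPU slot. A minimal sketch, assuming the file is saved as `pytorch_mnist.py` (the filename and process count are illustrative, not part of this commit):

    # Start 4 local Horovod workers; each trains on its own shard of MNIST
    # as partitioned by DistributedSampler above.
    horovodrun -np 4 python pytorch_mnist.py --epochs 14 --batch-size 64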