@@ -147,9 +147,17 @@ def run(self):
         for epoch in range(self.start_epoch, self.max_epoch + 1):

             self._train_epoch(epoch)
+

+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+
             self._validate_epoch(epoch)
-
+
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+
+
             if self.rank == 0:
                 self._save_checkpoint(epoch)

@@ -164,7 +172,9 @@ def run(self):

         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-        self.writer.close()
+
+        if self.writer:
+            self.writer.close()


     def _train_epoch(self, epoch):
@@ -230,6 +240,8 @@ def _train_epoch(self, epoch):
                    continue

                # Execute an optimization step (update model parameters)
+                if self.use_ddp or self.use_fsdp:
+                    dist.barrier()
                self.optim.step()
                self.scheduler.step()
                # Clear gradients for the next accumulation stage
@@ -244,7 +256,7 @@ def _train_epoch(self, epoch):
                pbar.update(1)
                if self.local_rank == 0:
                    description = (
-                        f"Epoch: {epoch}/{self.max_epoch}, "
+                        f"Train epoch: {epoch}/{self.max_epoch}, "
                        f"step {batch_idx}/{len(self.dataloader_train)}, "
                        f"{speed_stats}, "
                        f"(loss: {loss.detach().cpu().item():.3f}), "
@@ -306,7 +318,7 @@ def _validate_epoch(self, epoch):
                pbar.update(1)
                if self.local_rank == 0:
                    description = (
-                        f"validation: \nEpoch: {epoch}/{self.max_epoch}, "
+                        f"validation epoch: {epoch}/{self.max_epoch}, "
                        f"step {batch_idx}/{len(self.dataloader_train)}, "
                        f"{speed_stats}, "
                        f"(loss: {loss.detach().cpu().item():.3f}), "
0 commit comments