diff --git a/classy_vision/tasks/classification_task.py b/classy_vision/tasks/classification_task.py index 1cee07a584..4bbc4978e9 100644 --- a/classy_vision/tasks/classification_task.py +++ b/classy_vision/tasks/classification_task.py @@ -717,7 +717,7 @@ def prepare(self): if self.simulated_global_batchsize is not None: if self.simulated_global_batchsize % self.get_global_batchsize() != 0: raise ValueError( - f"Global batch size ({self.get_global_batchsize()}) must divide" + f"Global batch size ({self.get_global_batchsize()}) must divide " f"simulated_global_batchsize ({self.simulated_global_batchsize})" ) else: @@ -726,6 +726,10 @@ def prepare(self): self.optimizer_period = ( self.simulated_global_batchsize // self.get_global_batchsize() ) + if self.optimizer_period > 1: + logging.info( + f"Using gradient accumulation with a period of {self.optimizer_period}" + ) if self.checkpoint_path: self.checkpoint_dict = load_and_broadcast_checkpoint(self.checkpoint_path)