diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index 73263896af82..431c7ab84bb7 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -390,17 +390,16 @@ def save_checkpoint(
         """
         # check if using distributed checkpointing
         if self.use_distributed_checkpointing:
-            assert (
-                len(checkpoint['optimizer_states']) == 1
-            ), "Currently only support checkpointing 1 distributed optimizer per time!"
-            # converts the optimizer states to their sharded equivalents
-            sharded_optim_state = self.optimizer_sharded_state_dict(
-                unsharded_optim_state=checkpoint['optimizer_states'][0]
-            )
-
             # Check whether to save optim states
             include_optimizer = True if not storage_options else storage_options.get('include_optimizer', True)
             if include_optimizer:
+                assert (
+                    len(checkpoint['optimizer_states']) == 1
+                ), "Currently only support checkpointing 1 distributed optimizer per time!"
+                # converts the optimizer states to their sharded equivalents
+                sharded_optim_state = self.optimizer_sharded_state_dict(
+                    unsharded_optim_state=checkpoint['optimizer_states'][0]
+                )
                 checkpoint['optimizer_states'] = [sharded_optim_state]
             else:
                 checkpoint['optimizer_states'] = None
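
Usage note (not part of the patch): a minimal sketch of how a caller could opt out of saving optimizer states under this change, assuming the Lightning trainer forwards storage_options through to the strategy's save_checkpoint (the 'include_optimizer' key is the one read above; the checkpoint path is hypothetical):

    # Save a checkpoint without optimizer states; when storage_options is
    # omitted, include_optimizer defaults to True and the old behavior applies.
    trainer.save_checkpoint(
        "/tmp/megatron_gpt.ckpt",
        storage_options={'include_optimizer': False},
    )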