From 56157ee9e1aa990f5a9fe4f38a2c4351dc402999 Mon Sep 17 00:00:00 2001
From: Ananth Subramaniam
Date: Tue, 3 Dec 2024 10:13:42 -0800
Subject: [PATCH] get optimizer state dicts only if needed (#11451)

Signed-off-by: Ananth Subramaniam
---
 nemo/collections/nlp/parts/nlp_overrides.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index 73263896af82..431c7ab84bb7 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -390,17 +390,16 @@ def save_checkpoint(
         """
         # check if using distributed checkpointing
         if self.use_distributed_checkpointing:
-            assert (
-                len(checkpoint['optimizer_states']) == 1
-            ), "Currently only support checkpointing 1 distributed optimizer per time!"
-            # converts the optimizer states to their sharded equivalents
-            sharded_optim_state = self.optimizer_sharded_state_dict(
-                unsharded_optim_state=checkpoint['optimizer_states'][0]
-            )
-
             # Check whether to save optim states
             include_optimizer = True if not storage_options else storage_options.get('include_optimizer', True)
             if include_optimizer:
+                assert (
+                    len(checkpoint['optimizer_states']) == 1
+                ), "Currently only support checkpointing 1 distributed optimizer per time!"
+                # converts the optimizer states to their sharded equivalents
+                sharded_optim_state = self.optimizer_sharded_state_dict(
+                    unsharded_optim_state=checkpoint['optimizer_states'][0]
+                )
                 checkpoint['optimizer_states'] = [sharded_optim_state]
             else:
                 checkpoint['optimizer_states'] = None
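
For reviewers, a minimal, self-contained sketch (not part of the patch) of the control flow after this change. `optimizer_sharded_state_dict` below is a hypothetical stand-in for the strategy method of the same name, and the checkpoint dict is toy data:

def optimizer_sharded_state_dict(unsharded_optim_state):
    # Hypothetical stand-in: the real method converts the optimizer state
    # to its sharded (distributed-checkpoint) equivalent.
    return {'sharded': unsharded_optim_state}

def save_checkpoint(checkpoint, storage_options=None):
    # Only inspect and shard optimizer states when the caller wants them saved.
    include_optimizer = True if not storage_options else storage_options.get('include_optimizer', True)
    if include_optimizer:
        assert (
            len(checkpoint['optimizer_states']) == 1
        ), "Currently only support checkpointing 1 distributed optimizer per time!"
        checkpoint['optimizer_states'] = [
            optimizer_sharded_state_dict(checkpoint['optimizer_states'][0])
        ]
    else:
        # Weights-only save: optimizer states are dropped without ever being
        # converted (before this patch, the conversion ran unconditionally).
        checkpoint['optimizer_states'] = None
    return checkpoint

# A weights-only save skips both the assert and the sharding work entirely.
ckpt = {'state_dict': {}, 'optimizer_states': [{'step': 10}]}
print(save_checkpoint(ckpt, storage_options={'include_optimizer': False}))

The practical effect: a save with include_optimizer=False no longer pays the sharding cost and no longer trips the single-optimizer assert when the optimizer states are about to be discarded anyway. In a Lightning setup, storage_options typically reaches this method via Trainer.save_checkpoint(filepath, storage_options=...).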