get optimizer state dicts only if needed (#11451)
Signed-off-by: Ananth Subramaniam <[email protected]>
ananthsub authored Dec 3, 2024
1 parent 9abd81b commit 56157ee
Showing 1 changed file with 7 additions and 8 deletions: nemo/collections/nlp/parts/nlp_overrides.py
@@ -390,17 +390,16 @@ def save_checkpoint(
         """
         # check if using distributed checkpointing
         if self.use_distributed_checkpointing:
-            assert (
-                len(checkpoint['optimizer_states']) == 1
-            ), "Currently only support checkpointing 1 distributed optimizer per time!"
-            # converts the optimizer states to their sharded equivalents
-            sharded_optim_state = self.optimizer_sharded_state_dict(
-                unsharded_optim_state=checkpoint['optimizer_states'][0]
-            )
-
             # Check whether to save optim states
             include_optimizer = True if not storage_options else storage_options.get('include_optimizer', True)
             if include_optimizer:
+                assert (
+                    len(checkpoint['optimizer_states']) == 1
+                ), "Currently only support checkpointing 1 distributed optimizer per time!"
+                # converts the optimizer states to their sharded equivalents
+                sharded_optim_state = self.optimizer_sharded_state_dict(
+                    unsharded_optim_state=checkpoint['optimizer_states'][0]
+                )
                 checkpoint['optimizer_states'] = [sharded_optim_state]
             else:
                 checkpoint['optimizer_states'] = None
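
With this change, the sharded optimizer state dict is built only when optimizer states are actually going to be saved, so weights-only saves skip the conversion entirely. A minimal usage sketch (not part of this commit), assuming an already configured PyTorch Lightning trainer running with this NeMo strategy; Trainer.save_checkpoint forwards storage_options through to the strategy's save_checkpoint shown above:

    # Hypothetical sketch: `trainer` is assumed to be a configured
    # pytorch_lightning.Trainer using the NeMo strategy patched above.

    # Drop optimizer states: checkpoint['optimizer_states'] is set to None,
    # and optimizer_sharded_state_dict() is never called.
    trainer.save_checkpoint(
        "megatron_gpt_weights_only.ckpt",
        storage_options={'include_optimizer': False},
    )

    # Default behavior (storage_options=None): the single distributed
    # optimizer state is still converted to its sharded equivalent and saved.
    trainer.save_checkpoint("megatron_gpt_full.ckpt")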
