Commit

Merge branch 'fix_checkpointing_with_te' into 'main'
Fix TE checkpoint loading

See merge request ADLR/megatron-lm!1220
jaredcasper committed Mar 20, 2024
2 parents 1106d80 + d143339 commit e4dcc71
Showing 2 changed files with 5 additions and 1 deletion.
2 changes: 1 addition & 1 deletion megatron/checkpointing.py
@@ -677,7 +677,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
         print_rank_0('could not find arguments in the checkpoint ...')
 
     # Model.
-    strict = False if args.retro_add_retriever or args.transformer_impl == 'transformer_engine' else strict
+    strict = False if args.retro_add_retriever else strict
     if len(model) == 1:
         model[0].load_state_dict(state_dict['model'], strict=strict)
     else:
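With the Transformer Engine parameter names now handled explicitly in transformer.py (second file below), the loader no longer forces strict=False for the transformer_engine implementation. As a minimal, standalone PyTorch sketch (not part of the commit; the Linear module and checkpoint dict are made up for illustration) of the kind of mismatch that strict=False would otherwise silently ignore:

    import torch

    model = torch.nn.Linear(4, 4)
    ckpt = {"weight": torch.zeros(4, 4)}  # "bias" is deliberately missing

    # strict=False tolerates the missing key without complaint ...
    model.load_state_dict(ckpt, strict=False)

    # ... while strict=True surfaces the mismatch as an error.
    try:
        model.load_state_dict(ckpt, strict=True)
    except RuntimeError as err:
        print("strict loading reports:", err)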
4 changes: 4 additions & 0 deletions megatron/model/transformer.py
@@ -1802,6 +1802,10 @@ def load_state_dict(self, state_dict, strict=True):
         # Handle renaming layernorm -> norm in component names
         state_dict_ = {}
         for key in state_dict.keys():
+            # Bypass TransformerEngine module parameters.
+            if "layernorm_qkv" in key or "layernorm_mlp" in key:
+                state_dict_[key] = state_dict[key]
+                continue
             newkey = key.replace("layernorm", "norm")
             state_dict_[newkey] = state_dict[key]
 
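To make the effect of the new bypass concrete, here is a minimal, self-contained sketch (not part of the commit; the example keys are invented) of the renaming loop above. Transformer Engine fuses the layernorm into its modules, so checkpoint keys containing "layernorm_qkv" or "layernorm_mlp" must be kept verbatim rather than renamed:

    def rename_layernorm_keys(state_dict):
        """Mirror of the loop above: rename layernorm -> norm, except for TE modules."""
        state_dict_ = {}
        for key in state_dict.keys():
            # TE fused-module parameters: keep the key unchanged.
            if "layernorm_qkv" in key or "layernorm_mlp" in key:
                state_dict_[key] = state_dict[key]
                continue
            # All other keys: apply the layernorm -> norm rename.
            state_dict_[key.replace("layernorm", "norm")] = state_dict[key]
        return state_dict_

    example = {
        "self_attention.layernorm_qkv.weight": 0,  # TE key: preserved as-is
        "input_layernorm.weight": 1,               # renamed to "input_norm.weight"
    }
    print(rename_layernorm_keys(example))
    # {'self_attention.layernorm_qkv.weight': 0, 'input_norm.weight': 1}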
