From d1433397b7a694ea737bafb20736f355b19e53ea Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 14 Mar 2024 14:13:08 -0700 Subject: [PATCH] Bypass TE layernorm* params during renaming of state_dict keys Signed-off-by: Kirthi Shankar Sivamani --- megatron/checkpointing.py | 2 +- megatron/model/transformer.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index e9417c4799..0929357e68 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -669,7 +669,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri print_rank_0('could not find arguments in the checkpoint ...') # Model. - strict = False if args.retro_add_retriever or args.transformer_impl == 'transformer_engine' else strict + strict = False if args.retro_add_retriever else strict if len(model) == 1: model[0].load_state_dict(state_dict['model'], strict=strict) else: diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index c90307f0ce..9c9ac389a1 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -1802,6 +1802,10 @@ def load_state_dict(self, state_dict, strict=True): # Handle renaming layernorm -> norm in component names state_dict_ = {} for key in state_dict.keys(): + # Bypass TransformerEngine module parameters. + if "layernorm_qkv" in key or "layernorm_mlp" in key: + state_dict_[key] = state_dict[key] + continue newkey = key.replace("layernorm", "norm") state_dict_[newkey] = state_dict[key]