Merge branch 'fix_mcore_imports' into 'main'
Bugfix: make sure MCore doesn't have MLM imports

See merge request ADLR/megatron-lm!1206
ericharper committed Mar 13, 2024
2 parents 58f13de + 971f9ae commit f0f8150
Showing 2 changed files with 11 additions and 7 deletions.
15 changes: 11 additions & 4 deletions megatron/core/deploy/gpt/state_dict_hooks.py
@@ -1,6 +1,10 @@
 # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 
-from megatron import print_rank_0
+from logging import getLogger
+
+import torch
+
+logger = getLogger(__name__)
 
 
 def mcore_gpt_load_classic_state_dict_pre_hook(
@@ -46,7 +50,8 @@ def mcore_gpt_load_classic_state_dict_pre_hook(
         for key, param in language_model_state_dict["output_layer"].items():
             state_dict.update({"output_layer." + key: param})
 
-    print_rank_0("ModelOptGPTModel {}".format(state_dict.keys()))
+    if torch.distributed.get_rank() == 0:
+        logger.info("ModelOptGPTModel {}".format(state_dict.keys()))
 
     module_name_rewrite_list = [
         ("input_norm", "input_layernorm"),
@@ -69,7 +74,8 @@ def mcore_gpt_load_classic_state_dict_pre_hook(
                 key_rewrite_list += [(key, key.replace(old_name, new_name))]
 
     for old_key, new_key in key_rewrite_list:
-        print_rank_0("replace {} with {}".format(old_key, new_key))
+        if torch.distributed.get_rank() == 0:
+            logger.info("replace {} with {}".format(old_key, new_key))
         state_dict[new_key] = state_dict[old_key]
         state_dict.pop(old_key)
 
@@ -121,6 +127,7 @@ def mcore_gpt_load_te_state_dict_pre_hook(
                 key_rewrite_list += [(key, key.replace(old_name, new_name))]
 
     for old_key, new_key in key_rewrite_list:
-        print_rank_0("replace {} with {}".format(old_key, new_key))
+        if torch.distributed.get_rank() == 0:
+            logger.info("replace {} with {}".format(old_key, new_key))
         state_dict[new_key] = state_dict[old_key]
         state_dict.pop(old_key)
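
For readers skimming the hunks above: the megatron-lm helper print_rank_0 is replaced with stdlib logging gated on the torch.distributed rank, so nothing in megatron.core imports from the top-level megatron package anymore. A minimal sketch of that pattern, assuming torch.distributed is already initialized (as it is when these load hooks run); the log_rank_0 wrapper name is hypothetical, the commit inlines the check at each call site:

from logging import getLogger

import torch

logger = getLogger(__name__)

def log_rank_0(message: str) -> None:
    # Hypothetical wrapper: emit the message from rank 0 only, mirroring
    # print_rank_0's behavior without importing anything from megatron (MLM).
    if torch.distributed.get_rank() == 0:
        logger.info(message)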
3 changes: 0 additions & 3 deletions megatron/core/optimizer/optimizer.py
@@ -11,9 +11,6 @@
 import torch
 from apex.multi_tensor_apply import multi_tensor_applier
 
-from megatron.core import tensor_parallel
-from megatron.model.module import param_is_not_shared
-
 from .. import parallel_state, tensor_parallel
 from ..dist_checkpointing.mapping import ShardedStateDict
 from ..dist_checkpointing.optimizer import (
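
Of the two deleted absolute imports, from megatron.model.module is the actual MLM leak (megatron.model lives in megatron-lm proper, outside megatron.core), while from megatron.core import tensor_parallel was redundant with the package-relative from .. import tensor_parallel that follows it. A hypothetical guard script, not part of this commit, sketching how such regressions could be caught by scanning megatron/core for imports that reach back into megatron-lm:

import pathlib
import re

# Matches "from megatron import ...", "from megatron.model... import ...",
# and "import megatron.model..." -- but not "from megatron.core import ...".
MLM_IMPORT = re.compile(
    r"^\s*(from megatron(\.model\S*|\.training\S*)?\s+import\s"
    r"|import megatron\.(model|training))"
)

def find_mlm_imports(core_root="megatron/core"):
    hits = []
    for path in sorted(pathlib.Path(core_root).rglob("*.py")):
        for lineno, line in enumerate(path.read_text().splitlines(), start=1):
            if MLM_IMPORT.match(line):
                hits.append((path, lineno, line.strip()))
    return hits

if __name__ == "__main__":
    for path, lineno, line in find_mlm_imports():
        print(f"{path}:{lineno}: {line}")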
