enable weight parallel for huggingface_model internlm2 (#17)
zigzagcai authored Aug 15, 2024
1 parent 1de0f6e commit 9e05051
Showing 2 changed files with 46 additions and 1 deletion.
13 changes: 12 additions & 1 deletion huggingface_model/dispatch_utils/__init__.py
@@ -4,6 +4,9 @@
from collections import abc
from typing import Any, Optional, Type, Union

from internlm.core.context.parallel_context import global_context as gpc


# adapted from https://github.com/open-mmlab/mmengine/blob/main/mmengine/config/lazy.py#L8
class LazyObject:
    """LazyObject is used to lazily initialize the imported module during
@@ -220,6 +223,11 @@ def is_seq_of(seq: Any, expected_type: Union[Type, tuple], seq_type: Type = None
)


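# Maps a model key derived from gpc.config.HF_MODEL_NAME to the lazily imported
# reset_parameters function for that model.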
RESET_PARAM_FUNC_MAPPING = dict(
    internlm2_7b=LazyObject("huggingface_model.internlm.internlm2_7b", "reset_parameters"),
)


def replace_embed(model):
    def traverse(module):
        for name, child in module.named_children():
@@ -282,6 +290,9 @@ def hf_model_dispatch(model):
    replace_embed(model)
    replace_norm(model)
    replace_linear(model)

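    # Resolve the model-specific reset_parameters function: the key is the part of
    # gpc.config.HF_MODEL_NAME after "/", with "-" replaced by "_" (e.g. "internlm2_7b").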
    reset_parameters = RESET_PARAM_FUNC_MAPPING.get(gpc.config.HF_MODEL_NAME.split("/")[1].replace("-", "_"), None)
    assert reset_parameters is not None, "reset_parameters needs to be implemented."
    reset_parameters = reset_parameters.build()
    reset_parameters(model)

__all__ = ["hf_model_dispatch"]
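
For context, a minimal usage sketch (not part of this commit; the default-config construction is an assumption, and gpc.config.HF_MODEL_NAME must already be set before dispatching) of how the dispatcher is applied to an InternLM2 model:

from huggingface_model.dispatch_utils import hf_model_dispatch
from huggingface_model.internlm.internlm2_7b import InternLM2Config, InternLM2ForCausalLM

config = InternLM2Config()            # assumed: default config, for illustration only
model = InternLM2ForCausalLM(config)  # build the HuggingFace-style model
hf_model_dispatch(model)              # replace embed/norm/linear modules, then reset parameters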
34 changes: 34 additions & 0 deletions huggingface_model/internlm/internlm2_7b/__init__.py
@@ -1,7 +1,41 @@
from .configuration_internlm2 import InternLM2Config
from .modeling_internlm2 import InternLM2ForCausalLM

from internlm.initialize.initialize_tensor import (
    normal_,
    scaled_init_method_normal,
)

import torch

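# Per-layer init: 1-D attention params are zeroed, wq/wk/wv weights get N(0, 0.02),
# and wo plus the feed-forward weights use scaled_init_method_normal when
# use_scaled_init is True (plain N(0, 0.02) otherwise).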
def reset_attn_parameters(layer_idx, layer, use_scaled_init=True):
    for name, param in layer.attention.named_parameters():
        if param.ndim == 1:
            param.data.zero_()
        elif "wq" in name or "wk" in name or "wv" in name:
            normal_(std=0.02)(param.data)
        elif use_scaled_init:  # wo
            scaled_init_method_normal(sigma=0.02, num_layers=layer_idx + 1)(param.data)
        else:
            normal_(std=0.02)(param.data)

    for name, param in layer.feed_forward.named_parameters():
        if use_scaled_init:
            scaled_init_method_normal(sigma=0.02, num_layers=layer_idx + 1)(param.data)
        else:
            normal_(std=0.02)(param.data)

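# Re-initialize the full model in-place: token embeddings, every decoder layer,
# and the output head.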
def reset_parameters(model):
    with torch.no_grad():
        for _, param in model.model.tok_embeddings.named_parameters():
            normal_(std=0.02)(param)
        for layer_idx, layer in enumerate(model.model.layers):
            reset_attn_parameters(layer_idx, layer)
        for _, param in model.output.named_parameters():
            normal_(std=0.02)(param)

__all__ = [
    "InternLM2Config",
    "InternLM2ForCausalLM",
    "reset_parameters",
]
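
As a side note, a hedged sketch of how another model directory could register its own reset function in the same mapping; the key and module path below are hypothetical:

from huggingface_model.dispatch_utils import LazyObject, RESET_PARAM_FUNC_MAPPING

# Hypothetical entry; the target module would need to define reset_parameters(model).
RESET_PARAM_FUNC_MAPPING["internlm2_20b"] = LazyObject(
    "huggingface_model.internlm.internlm2_20b", "reset_parameters"
)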
