diff --git a/examples/internlm/internlm2_7b/train.py b/examples/internlm/internlm2_7b/train.py
index fefe3a6..6170130 100644
--- a/examples/internlm/internlm2_7b/train.py
+++ b/examples/internlm/internlm2_7b/train.py
@@ -1,6 +1,8 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
+from functools import partial
+
 from internlm.core.context import global_context as gpc
 from internlm.core.trainer_builder import TrainerBuilder
 from internlm.data import (
@@ -29,7 +31,7 @@ def main(args):
     hf_config_initializer.register_module(gpc.config.model_type, InternLM2Config)
 
     # initialize model
-    model = initialize_model(model_dispatch_func=hf_model_dispatch)
+    model = initialize_model(model_dispatch_func=partial(hf_model_dispatch, auto_dispatch=True))
 
     # initialize train dataloader
     train_dl, dataset_types = build_train_loader_with_data_type()
diff --git a/examples/internlm/internlm_7b/train.py b/examples/internlm/internlm_7b/train.py
index afefc8d..f77e8b3 100644
--- a/examples/internlm/internlm_7b/train.py
+++ b/examples/internlm/internlm_7b/train.py
@@ -1,6 +1,8 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
+from functools import partial
+
 from internlm.core.context import global_context as gpc
 from internlm.core.trainer_builder import TrainerBuilder
 from internlm.data import (
@@ -25,7 +27,7 @@ def main(args):
     hf_config_initializer.register_module(gpc.config.model_type, InternLMConfig)
 
     # initialize model
-    model = initialize_model(model_dispatch_func=hf_model_dispatch)
+    model = initialize_model(model_dispatch_func=partial(hf_model_dispatch, auto_dispatch=True))
 
     # initialize train dataloader
     train_dl, dataset_types = build_train_loader_with_data_type()
diff --git a/huggingface_model/dispatch_utils/__init__.py b/huggingface_model/dispatch_utils/__init__.py
index e8a9ab4..761bdff 100644
--- a/huggingface_model/dispatch_utils/__init__.py
+++ b/huggingface_model/dispatch_utils/__init__.py
@@ -1,10 +1,12 @@
 # adapted from https://github.com/InternLM/xtuner/blob/main/xtuner/model/modules/dispatch/__init__.py
-
 import importlib
 from collections import abc
 from typing import Any, Optional, Type, Union
 
 from internlm.core.context.parallel_context import global_context as gpc
+from internlm.utils.logger import get_logger
+
+logger = get_logger(__file__)
 
 
 # adapted from https://github.com/open-mmlab/mmengine/blob/main/mmengine/config/lazy.py#L8
@@ -190,27 +192,23 @@ def is_seq_of(seq: Any, expected_type: Union[Type, tuple], seq_type: Type = None
             return False
     return True
 
 
+
 EMBED_REPLACE_MAPPING = dict(
     Embedding=LazyObject("internlm.model.modules.embedding", "Embedding1D"),
 )
 
-NORM_REPLACE_MAPPING = dict(
-    InternLMRMSNorm=LazyObject("internlm.model.modules.norm", "new_layer_norm"),
-    InternLM2RMSNorm=LazyObject("internlm.model.modules.norm", "new_layer_norm"),
-)
 
 LINEAR_REPLACE_MAPPING = dict(
     Linear=LazyObject("internlm.model.modules.linear", "new_linear"),
 )
 
-NORM2NEW_NORM_NAME_MAPPING = dict(
-    input_layernorm="rmsnorm",
-    post_attention_layernorm="rmsnorm",
-    norm="rmsnorm",
-    attention_norm="rmsnorm",
-    ffn_norm="rmsnorm",
+
+NORM_REPLACE_MAPPING = dict(
+    InternLMRMSNorm=LazyObject("internlm.model.modules.norm", "new_layer_norm"),
+    InternLM2RMSNorm=LazyObject("internlm.model.modules.norm", "new_layer_norm"),
 )
+
 LINEAR2NEW_LINEAR_NAME_MAPPING = dict(
     q_proj="wq",
     k_proj="wk",
@@ -223,8 +221,18 @@ def is_seq_of(seq: Any, expected_type: Union[Type, tuple], seq_type: Type = None
 )
 
+
+NORM2NEW_NORM_NAME_MAPPING = dict(
+    input_layernorm="rmsnorm",
+    post_attention_layernorm="rmsnorm",
+    norm="rmsnorm",
+    attention_norm="rmsnorm",
+    ffn_norm="rmsnorm",
+)
+
 
 RESET_PARAM_FUNC_MAPPING = dict(
     internlm2_7b=LazyObject("huggingface_model.internlm.internlm2_7b", "reset_parameters"),
+    internlm_7b=LazyObject("huggingface_model.internlm.internlm_7b", "reset_parameters"),
 )
 
 
@@ -247,6 +255,26 @@ def traverse(module):
     traverse(model)
 
 
+def replace_linear(model):
+    def traverse(module):
+        for name, child in module.named_children():
+            cls_name = type(child).__name__
+            if cls_name in LINEAR_REPLACE_MAPPING:
+                linear = LINEAR_REPLACE_MAPPING[cls_name]
+                linear = linear.build()
+                child_new = linear(
+                    name=LINEAR2NEW_LINEAR_NAME_MAPPING.get(name, name),
+                    in_features=child.in_features,
+                    out_features=child.out_features,
+                    bias=child.bias is not None,
+                ).to(device=child.weight.device, dtype=child.weight.dtype)
+                setattr(module, name, child_new)
+            else:
+                traverse(child)
+
+    traverse(model)
+
+
 def replace_norm(model):
     def traverse(module):
         for name, child in module.named_children():
@@ -266,33 +294,61 @@ def traverse(module):
     traverse(model)
 
 
-def replace_linear(model):
+def check_embed(model):
+    def traverse(module):
+        for name, child in module.named_children():
+            cls_name = type(child).__name__
+            if cls_name in EMBED_REPLACE_MAPPING:
+                embed = EMBED_REPLACE_MAPPING[cls_name]
+                embed = embed.build()
+                logger.warning(f"{name} of type {cls_name} is suggested to be replaced with type {embed.__name__}")
+            else:
+                traverse(child)
+
+    traverse(model)
+
+
+def check_linear(model):
     def traverse(module):
         for name, child in module.named_children():
             cls_name = type(child).__name__
             if cls_name in LINEAR_REPLACE_MAPPING:
                 linear = LINEAR_REPLACE_MAPPING[cls_name]
                 linear = linear.build()
-                child_new = linear(
-                    name=LINEAR2NEW_LINEAR_NAME_MAPPING.get(name, name),
-                    in_features=child.in_features,
-                    out_features=child.out_features,
-                    bias=child.bias is not None,
-                ).to(device=child.weight.device, dtype=child.weight.dtype)
-                setattr(module, name, child_new)
+                logger.warning(f"{name} of type {cls_name} is suggested to be replaced with type {linear.__name__}")
+            else:
+                traverse(child)
+
+    traverse(model)
+
+
+def check_norm(model):
+    def traverse(module):
+        for name, child in module.named_children():
+            cls_name = type(child).__name__
+            if cls_name in NORM_REPLACE_MAPPING:
+                norm = NORM_REPLACE_MAPPING[cls_name]
+                norm = norm.build()
+                logger.warning(f"{name} of type {cls_name} is suggested to be replaced with type {norm.__name__}")
             else:
                 traverse(child)
 
     traverse(model)
 
 
-def hf_model_dispatch(model):
-    replace_embed(model)
-    replace_norm(model)
-    replace_linear(model)
-    reset_parameters = RESET_PARAM_FUNC_MAPPING.get(gpc.config.HF_MODEL_NAME.split("/")[1].replace("-", "_"), None)
-    assert reset_parameters is not None, "reset_parameters need to be implemented."
-    reset_parameters = reset_parameters.build()
-    reset_parameters(model)
+def hf_model_dispatch(model, auto_dispatch=False):
+    if auto_dispatch:
+        replace_embed(model)
+        replace_linear(model)
+        replace_norm(model)
+        reset_parameters = RESET_PARAM_FUNC_MAPPING.get(gpc.config.HF_MODEL_NAME.split("/")[1].replace("-", "_"), None)
+        assert reset_parameters is not None, "In auto_dispatch mode, the reset_parameters function needs to be implemented."
+        reset_parameters = reset_parameters.build()
+        reset_parameters(model)
+    else:
+        check_embed(model)
+        check_linear(model)
+        check_norm(model)
+
 
-__all__ = ["hf_model_dispatch"]
\ No newline at end of file
+__all__ = ["hf_model_dispatch"]
diff --git a/huggingface_model/internlm/internlm2_7b/__init__.py b/huggingface_model/internlm/internlm2_7b/__init__.py
index c831dea..7485027 100644
--- a/huggingface_model/internlm/internlm2_7b/__init__.py
+++ b/huggingface_model/internlm/internlm2_7b/__init__.py
@@ -1,41 +1,36 @@
+import torch
+from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
+
 from .configuration_internlm2 import InternLM2Config
 from .modeling_internlm2 import InternLM2ForCausalLM
-from internlm.initialize.initialize_tensor import (
-    normal_,
-    scaled_init_method_normal,
-)
-
-import torch
 
 
-def reset_attn_parameters(layer_idx, layer, use_scaled_init=True):
+def reset_attn_parameters(layer_idx, layer, use_scaled_init=True, std=0.02):
     for name, param in layer.attention.named_parameters():
-        if param.ndim == 1:
+        if param.ndim == 1:  # bias
             param.data.zero_()
         elif "wq" in name or "wk" in name or "wv" in name:
-            normal_(std=0.02)(param.data)
+            normal_(std=std)(param.data)
         elif use_scaled_init:  # wo
-            scaled_init_method_normal(sigma=0.02, num_layers=layer_idx + 1)(param.data)
-        else:
-            normal_(std=0.02)(param.data)
+            scaled_init_method_normal(sigma=std, num_layers=layer_idx + 1)(param.data)
+        else:  # wo
+            normal_(std=std)(param.data)
 
     for name, param in layer.feed_forward.named_parameters():
         if use_scaled_init:
-            scaled_init_method_normal(sigma=0.02, num_layers=layer_idx + 1)(param.data)
+            scaled_init_method_normal(sigma=std, num_layers=layer_idx + 1)(param.data)
         else:
-            normal_(std=0.02)(param.data)
+            normal_(std=std)(param.data)
+
 
-def reset_parameters(model):
+def reset_parameters(model, std=0.02):
     with torch.no_grad():
         for _, param in model.model.tok_embeddings.named_parameters():
-            normal_(std=0.02)(param)
+            normal_(std=std)(param)
         for layer_idx, layer in enumerate(model.model.layers):
             reset_attn_parameters(layer_idx, layer)
         for _, param in model.output.named_parameters():
-            normal_(std=0.02)(param)
+            normal_(std=std)(param)
+
 
-__all__ = [
-    "InternLM2Config",
-    "InternLM2ForCausalLM",
-    "reset_parameters"
-]
+__all__ = ["InternLM2Config", "InternLM2ForCausalLM", "reset_parameters"]
diff --git a/huggingface_model/internlm/internlm_7b/__init__.py b/huggingface_model/internlm/internlm_7b/__init__.py
index 14f4e9e..f378db5 100644
--- a/huggingface_model/internlm/internlm_7b/__init__.py
+++ b/huggingface_model/internlm/internlm_7b/__init__.py
@@ -1,7 +1,40 @@
+import torch
+from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
+
 from .configuration_internlm import InternLMConfig
 from .modeling_internlm import InternLMForCausalLM
 
+
+def reset_attn_parameters(layer_idx, layer, use_scaled_init=True, std=0.02):
+    for name, param in layer.self_attn.named_parameters():
+        if param.ndim == 1:  # bias
+            param.data.zero_()
+        elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
+            normal_(std=std)(param.data)
+        elif use_scaled_init:  # wo
+            scaled_init_method_normal(sigma=std, num_layers=layer_idx + 1)(param.data)
+        else:  # wo
+            normal_(std=std)(param.data)
+
+    for name, param in layer.mlp.named_parameters():
+        if use_scaled_init:
+            scaled_init_method_normal(sigma=std, num_layers=layer_idx + 1)(param.data)
+        else:
+            normal_(std=std)(param.data)
+
+
+def reset_parameters(model, std=0.02):
+    with torch.no_grad():
+        for _, param in model.model.embed_tokens.named_parameters():
+            normal_(std=std)(param)
+        for layer_idx, layer in enumerate(model.model.layers):
+            reset_attn_parameters(layer_idx, layer)
+        for _, param in model.lm_head.named_parameters():
+            normal_(std=std)(param)
+
+
 __all__ = [
     "InternLMConfig",
     "InternLMForCausalLM",
+    "reset_parameters"
 ]
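
Usage note (not part of the diff): a minimal sketch of how the two hf_model_dispatch modes added above are intended to be called. The tiny InternLM2Config overrides and the standalone model construction are illustrative assumptions; in practice the dispatch function is passed to initialize_model via functools.partial, as the train.py hunks show, and auto_dispatch=True additionally expects the InternEvo context (gpc.config.HF_MODEL_NAME, parallel groups) to be initialized.

from huggingface_model.dispatch_utils import hf_model_dispatch
from huggingface_model.internlm.internlm2_7b import InternLM2Config, InternLM2ForCausalLM

# Illustrative: a deliberately small config so the dispatcher can be exercised
# without building the full 7B model (the override values are assumptions).
hf_model = InternLM2ForCausalLM(InternLM2Config(hidden_size=128, num_hidden_layers=2, num_attention_heads=4))

# Default check-only mode: walks the module tree and logs a warning for every
# Embedding / Linear / RMSNorm submodule that has an InternEvo replacement,
# leaving the HuggingFace modules untouched.
hf_model_dispatch(hf_model)

# auto_dispatch=True: swaps in Embedding1D / new_linear / new_layer_norm and
# re-initializes parameters via the RESET_PARAM_FUNC_MAPPING entry keyed by
# gpc.config.HF_MODEL_NAME (e.g. "internlm/internlm2-7b" -> "internlm2_7b"),
# so it must run inside an initialized InternEvo training context.
hf_model_dispatch(hf_model, auto_dispatch=True)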