diff --git a/huggingface_model/internlm/internlm_7b/modeling_internlm.py b/huggingface_model/internlm/internlm_7b/modeling_internlm.py
index 3450a58..7c7282f 100644
--- a/huggingface_model/internlm/internlm_7b/modeling_internlm.py
+++ b/huggingface_model/internlm/internlm_7b/modeling_internlm.py
@@ -994,6 +994,67 @@ def set_decoder(self, decoder):
     def get_decoder(self):
         return self.model
 
+    def split_weights(self, first_layer, model_state_dict, state_dict, split_size, local_rank, row_dim):
+        for i in range(0, gpc.config.model.num_layers):
+            model_state_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = torch.chunk(
+                state_dict.pop(f"model.layers.{i+first_layer}.self_attn.q_proj.weight"),
+                split_size,
+                dim=0,
+            )[local_rank]
+            model_state_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = torch.chunk(
+                state_dict.pop(f"model.layers.{i+first_layer}.self_attn.k_proj.weight"),
+                split_size,
+                dim=0,
+            )[local_rank]
+            model_state_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = torch.chunk(
+                state_dict.pop(f"model.layers.{i+first_layer}.self_attn.v_proj.weight"),
+                split_size,
+                dim=0,
+            )[local_rank]
+            model_state_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = torch.chunk(
+                state_dict.pop(f"model.layers.{i+first_layer}.self_attn.o_proj.weight"),
+                split_size,
+                dim=0,
+            )[local_rank]
+            model_state_dict[f"model.layers.{i}.mlp.gate_proj.weight"] = torch.chunk(
+                state_dict.pop(f"model.layers.{i+first_layer}.mlp.gate_proj.weight"),
+                split_size,
+                dim=0,
+            )[local_rank]
+            model_state_dict[f"model.layers.{i}.mlp.down_proj.weight"] = torch.chunk(
+                state_dict.pop(f"model.layers.{i+first_layer}.mlp.down_proj.weight"),
+                split_size,
+                dim=0,
+            )[local_rank]
+            model_state_dict[f"model.layers.{i}.mlp.up_proj.weight"] = torch.chunk(
+                state_dict.pop(f"model.layers.{i+first_layer}.mlp.up_proj.weight"),
+                split_size,
+                dim=row_dim,
+            )[local_rank]
+            model_state_dict[f"model.layers.{i}.input_layernorm.weight"] = state_dict.pop(
+                f"model.layers.{i+first_layer}.input_layernorm.weight"
+            )
+            model_state_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = state_dict.pop(
+                f"model.layers.{i+first_layer}.post_attention_layernorm.weight"
+            )
+
+        if (gpc.get_local_rank(ParallelMode.PIPELINE) - 1 == 0) or (not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)):
+            model_state_dict["model.embed_tokens.weight"] = torch.chunk(
+                state_dict.pop("model.embed_tokens.weight"),
+                split_size,
+                dim=1,
+            )[local_rank]
+
+        if gpc.is_last_rank(ParallelMode.PIPELINE):
+            model_state_dict["lm_head.weight"] = torch.chunk(
+                state_dict.pop("lm_head.weight"),
+                split_size,
+                dim=0,
+            )[local_rank]
+            model_state_dict["model.norm.weight"] = state_dict["model.norm.weight"]
+
+        return model_state_dict
+
     @add_start_docstrings_to_model_forward(INTERNLM_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
@@ -1367,4 +1428,4 @@ def forward(
         past_key_values=transformer_outputs.past_key_values,
         hidden_states=transformer_outputs.hidden_states,
         attentions=transformer_outputs.attentions,
-        )
\ No newline at end of file
+        )
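Note: every branch of `split_weights` follows the same pattern — `torch.chunk` the full checkpoint weight along its parallel dimension and keep only this rank's shard (dim=0 for the attention and MLP projections, except `up_proj` which uses the caller-supplied `row_dim`; dim=1 for `embed_tokens`). A minimal standalone sketch of that pattern, with made-up shapes and a hypothetical `split_size`/`local_rank` standing in for the repo's `gpc` runtime context:

```python
import torch

# Hypothetical stand-ins for the values split_weights receives at runtime:
# a 4-way split, and this process holds shard index 1.
split_size, local_rank = 4, 1

# A fake projection weight of shape (out_features, in_features).
full_weight = torch.randn(4096, 4096)

# Same pattern as the patch: chunk along the parallel dimension
# (dim=0 here) and keep only this rank's shard.
shard = torch.chunk(full_weight, split_size, dim=0)[local_rank]
assert shard.shape == (4096 // split_size, 4096)  # (1024, 4096)

# embed_tokens in the patch is instead chunked along dim=1 (hidden size),
# so every rank keeps the full vocabulary but a slice of each embedding.
embed = torch.randn(32000, 4096)
embed_shard = torch.chunk(embed, split_size, dim=1)[local_rank]
assert embed_shard.shape == (32000, 4096 // split_size)
```

A side effect worth noting: using `state_dict.pop(...)` rather than indexing drops the reference to each full-size tensor as soon as its shard is taken, which helps keep peak host memory down while the checkpoint is being sliced.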