From c3df1877f0a63b04642184c208281ded09fe77d3 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 19 Dec 2024 02:04:48 +0800 Subject: [PATCH] [tool]GQA convert support (#454) * [tools]GQA convert support * fix readme --- .../finetune_hf_llama/README.md | 4 +- .../finetune_hf_llama/ds_config.json | 8 ++- .../finetune_hf_llama/ds_config_empty.json | 5 ++ .../finetune_hf_llama/finetune_llama.sh | 10 ++- tools/hf2megads_weight_converter.py | 66 +++++++++++-------- 5 files changed, 62 insertions(+), 31 deletions(-) create mode 100755 examples_deepspeed/finetune_hf_llama/ds_config_empty.json diff --git a/examples_deepspeed/finetune_hf_llama/README.md b/examples_deepspeed/finetune_hf_llama/README.md index e8641ced21..7c81518fa4 100644 --- a/examples_deepspeed/finetune_hf_llama/README.md +++ b/examples_deepspeed/finetune_hf_llama/README.md @@ -10,9 +10,9 @@ The pre-trained weights can be found at [Hugging Face - LLAMA-7B](https://huggin #### 1. Converting Hugging Face Model Weights to Megatron-Deepspeed Model ```bash -bash examples_deepspeed/finetune_hf_llama/finetune_llama.sh convert +bash examples_deepspeed/finetune_hf_llama/finetune_llama.sh convert_hf2mds ``` -This command writes the Hugging Face model weights into the Megatron-Deepspeed model and saves it. You can adjust the parallel configuration in the script. +This command writes the Hugging Face model weights into the Megatron-Deepspeed model and saves it. You can adjust the parallel configuration in the script.```convert_mds2hf``` can convert a Megatron-Deepspeed model into the Hugging Face format #### 2. Fine-tuning Process ```bash diff --git a/examples_deepspeed/finetune_hf_llama/ds_config.json b/examples_deepspeed/finetune_hf_llama/ds_config.json index 85f439ce47..9c0b332473 100755 --- a/examples_deepspeed/finetune_hf_llama/ds_config.json +++ b/examples_deepspeed/finetune_hf_llama/ds_config.json @@ -1,5 +1,11 @@ { "train_batch_size" : 256, "train_micro_batch_size_per_gpu": 16, - "steps_per_print": 1 + "steps_per_print": 100, + "zero_optimization": { + "stage": 0 + }, + "bf16": { + "enabled": true + } } diff --git a/examples_deepspeed/finetune_hf_llama/ds_config_empty.json b/examples_deepspeed/finetune_hf_llama/ds_config_empty.json new file mode 100755 index 0000000000..bc05743cd3 --- /dev/null +++ b/examples_deepspeed/finetune_hf_llama/ds_config_empty.json @@ -0,0 +1,5 @@ +{ + "train_batch_size" : 256, + "train_micro_batch_size_per_gpu": 16, + "steps_per_print": 100 +} diff --git a/examples_deepspeed/finetune_hf_llama/finetune_llama.sh b/examples_deepspeed/finetune_hf_llama/finetune_llama.sh index ab8bfdf419..d6b472913f 100644 --- a/examples_deepspeed/finetune_hf_llama/finetune_llama.sh +++ b/examples_deepspeed/finetune_hf_llama/finetune_llama.sh @@ -43,6 +43,13 @@ cat < $DS_CONFIG } EOT +if [ "$1" = "convert_hf2mds" ]; then + DS_CONFIG_PATH="./examples_deepspeed/finetune_hf_llama/ds_config_empty.json" +elif [ "$1" = "convert_mds2hf" ]; then + DS_CONFIG_PATH="./examples_deepspeed/finetune_hf_llama/ds_config_empty.json" +else + DS_CONFIG_PATH="./examples_deepspeed/finetune_hf_llama/ds_config.json" +fi covert_hf2mds_args="deepspeed tools/hf2megads_weight_converter.py \ --hf-ckpt-num-shards 2 \ @@ -69,6 +76,7 @@ comm_args="--tensor-model-parallel-size $TP \ --num-layers $NUM_LAYERS \ --hidden-size $HIDDEN_SIZE \ --num-attention-heads $NUM_HEADS \ +--finetune \ --ffn-hidden-size $FFN_HIDDEN_SIZE \ --attention-dropout 0 \ --hidden-dropout 0 \ @@ -97,7 +105,7 @@ comm_args="--tensor-model-parallel-size $TP \ --zero-stage 0 \ --tokenizer-type HFTokenizer \ --tokenizer-model $HF_LLAMA_PATH \ ---deepspeed_config ./examples_deepspeed/finetune_hf_llama/ds_config.json \ +--deepspeed_config $DS_CONFIG_PATH \ --deepspeed \ --distributed-backend nccl \ --num-workers 0 \ diff --git a/tools/hf2megads_weight_converter.py b/tools/hf2megads_weight_converter.py index 12468963c5..74ce462281 100755 --- a/tools/hf2megads_weight_converter.py +++ b/tools/hf2megads_weight_converter.py @@ -193,28 +193,43 @@ def _qkv_refactor(self, pname, p, hf_layer): wk = self.hf_model[hf_wk_name] wv = self.hf_model[hf_wv_name] - hidden_size = wq.shape[0] - per_partition_size, start_index, end_index = compute_partition_range( - hidden_size, self.tp_rank, self.tp_size) - hidden_size_per_attention_head = divide(hidden_size, + query_hidden_size = wq.shape[0] + kv_hidden_size = wk.shape[0] + + per_partition_size, start_qindex, end_index = compute_partition_range( + query_hidden_size, self.tp_rank, self.tp_size) + _,start_kvindex, _= compute_partition_range( + kv_hidden_size, self.tp_rank, self.tp_size) + + hidden_size_per_attention_head = divide(query_hidden_size, self.config.num_attention_heads) num_attention_heads_per_partition = divide(self.config.num_attention_heads, self.tp_size) - new_w = torch.zeros((per_partition_size * 3, wq.shape[1]), dtype=wq.dtype) + num_kv_heads_per_partition= divide(self.config.num_key_value_heads, + self.tp_size) + qkv_size=(num_attention_heads_per_partition+2*num_kv_heads_per_partition)*hidden_size_per_attention_head + num_qheads_per_group=divide(self.config.num_attention_heads,self.config.num_key_value_heads) + num_groups =divide(num_attention_heads_per_partition,num_qheads_per_group) + new_w = torch.zeros((qkv_size, wq.shape[1]), dtype=wq.dtype) + + for i in range(num_groups): + query_current_index=start_qindex+i*num_qheads_per_group*hidden_size_per_attention_head + query_next_index=query_current_index+num_qheads_per_group*hidden_size_per_attention_head + kv_current_index=start_kvindex+i*hidden_size_per_attention_head + kv_next_kvindex=kv_current_index+hidden_size_per_attention_head + + new_w_index=i* (num_qheads_per_group+2)*hidden_size_per_attention_head - for i in range(num_attention_heads_per_partition): - current_index = start_index + i * hidden_size_per_attention_head - next_index = current_index + hidden_size_per_attention_head - new_w_index = i * (3 * hidden_size_per_attention_head) - new_w[new_w_index: new_w_index + (3 * hidden_size_per_attention_head), :] = \ + new_w[new_w_index:new_w_index+(num_qheads_per_group+2)*hidden_size_per_attention_head,:]=\ torch.cat([ - wq[current_index: next_index, :], - wk[current_index: next_index, :], - wv[current_index: next_index, :] - ], dim=0) + wq[query_current_index:query_next_index,:], + wk[kv_current_index:kv_next_kvindex,:], + wv[kv_current_index:kv_next_kvindex,:] + ],dim=0) + self.record_mapping_info( - f"mega-ds:{pname,p.data.shape}<--hf{hf_wq_name,hf_wk_name,hf_wv_name,} cat q,k,v [{current_index}:{next_index},:] of q,k,v{wq.shape}" + f"mega-ds:{pname,p.data.shape}<--hf{hf_wq_name,hf_wk_name,hf_wv_name,} cat q,k,v [{query_current_index}:{query_next_index},:] of q,k,v{wq.shape}" ) return new_w @@ -383,17 +398,18 @@ def _qkv_refactor_to_hf(self, pname, ds_w, hf_layer): hidden_size = oldshape[-1] hidden_size_per_attention_head = divide(hidden_size, self.config.num_attention_heads) - num_attention_heads_per_partition = divide(self.config.num_attention_heads, - self.tp_size) - newshape = (self.tp_size, num_attention_heads_per_partition, 3, hidden_size_per_attention_head, hidden_size) + # MHA & GQA + group = divide(self.config.num_attention_heads, self.config.num_key_value_heads) + newshape = (self.config.num_key_value_heads, group + 2, hidden_size_per_attention_head, hidden_size) ds_w_out = ds_w_all_rank.reshape(*newshape) - self.hf_dict[hf_q_name] = copy.deepcopy(ds_w_out[:, :, 0, :, :].reshape(-1, oldshape[-1])) - self.hf_dict[hf_k_name] = copy.deepcopy(ds_w_out[:, :, 1, :, :].reshape(-1, oldshape[-1])) - self.hf_dict[hf_v_name] = copy.deepcopy(ds_w_out[:, :, 2, :, :].reshape(-1, oldshape[-1])) + query_weight, key_weight, value_weight = torch.split(ds_w_out, [group, 1, 1], dim=1) + self.hf_dict[hf_q_name] = copy.deepcopy(query_weight.reshape(-1, hidden_size)) + self.hf_dict[hf_k_name] = copy.deepcopy(key_weight.reshape(-1, hidden_size)) + self.hf_dict[hf_v_name] = copy.deepcopy(value_weight.reshape(-1, hidden_size)) + del query_weight, key_weight, value_weight def transform_from_megads_to_hf(self): - use_gqa = True if self.num_attention_heads != self.num_key_value_heads else False for pname, p in self.ds_model.named_parameters(): if pname in [ @@ -411,11 +427,7 @@ def transform_from_megads_to_hf(self): subname = mobj.group(2) hf_layer = layer_num - self.offset_num if subname in ["self_attention.query_key_value.weight"]: - if not use_gqa: - self._qkv_refactor_to_hf(pname, p, hf_layer) - else: - #TODO(billishyahao): Not impl yet ... - assert False + self._qkv_refactor_to_hf(pname, p, hf_layer) elif subname in ["mlp.dense_h_to_4h.weight"]: self._mlphto4h_dense_refactor_to_hf(pname, p, hf_layer) elif subname in [