From 8a7fcea59fb5f8564858944ef1f325bd90076ce4 Mon Sep 17 00:00:00 2001
From: sallyjunjun <72725839+sallyjunjun@users.noreply.github.com>
Date: Mon, 5 Aug 2024 16:44:31 +0800
Subject: [PATCH] fix minor error (#12)

---
 .../internlm/internlm_7b/modeling_internlm.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/huggingface_model/internlm/internlm_7b/modeling_internlm.py b/huggingface_model/internlm/internlm_7b/modeling_internlm.py
index 84d0d28..7b41d93 100644
--- a/huggingface_model/internlm/internlm_7b/modeling_internlm.py
+++ b/huggingface_model/internlm/internlm_7b/modeling_internlm.py
@@ -494,10 +494,10 @@ def forward(
                 cumulative_len=cu_seqlens,
                 max_seqlen=max_seqlen,
                 dropout_p=0.0,
-            ).unsqueeze(0)
+            )
         else:
             attn_output = hf_q_k_v_without_cu_seqlens(
-                query_states, key_states, value_states, dropout_p=0, softmax_scale=None, causal=True,
+                query_states, key_states, value_states, dropout_p=0.0, softmax_scale=None, causal=True,
             )
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -541,7 +541,7 @@ def _flash_attention_forward(
             cu_seqlens_q, cu_seqlens_k = cu_seq_lens
             max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
 
-            attn_output_unpad = varlen_flash_attn(
+            attn_output_unpad = flash_attn_varlen_func(
                 query_states,
                 key_states,
                 value_states,
@@ -556,7 +556,7 @@ def _flash_attention_forward(
             attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
         else:
-            attn_output = flash_attn_wo_mask(
+            attn_output = flash_attn_func(
                 query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
             )
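
Note (not part of the patch): the renamed calls flash_attn_varlen_func and flash_attn_func appear to correspond to the public entry points of the flash-attn 2.x package. A minimal sketch of how those two entry points are typically invoked is below; it assumes flash-attn 2.x is installed and a CUDA device is available, and all tensor names, shapes, and sequence lengths are illustrative assumptions, not values taken from modeling_internlm.py.

# Illustrative sketch only; assumes the flash-attn 2.x package, fp16 tensors on a CUDA device.
import torch
from flash_attn import flash_attn_func, flash_attn_varlen_func

bsz, seqlen, nheads, headdim = 2, 16, 8, 64

# Fixed-length (padded) path: inputs are (batch, seqlen, nheads, headdim).
q = torch.randn(bsz, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True)

# Variable-length path: tokens from all sequences are packed into (total_tokens, nheads, headdim),
# with cu_seqlens marking cumulative sequence boundaries (here two sequences of length 10 and 6).
total_tokens = 16
q_packed = torch.randn(total_tokens, nheads, headdim, dtype=torch.float16, device="cuda")
k_packed = torch.randn_like(q_packed)
v_packed = torch.randn_like(q_packed)
cu_seqlens = torch.tensor([0, 10, 16], dtype=torch.int32, device="cuda")
out_packed = flash_attn_varlen_func(
    q_packed, k_packed, v_packed,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=10, max_seqlen_k=10,
    dropout_p=0.0, softmax_scale=None, causal=True,
)

Under this reading, the patch's other changes are consistent with that API: flash_attn_varlen_func already returns a (total_tokens, nheads, headdim) tensor, so the extra .unsqueeze(0) is dropped, and dropout_p is passed as a float (0.0) rather than an int.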