internevo.patch
--- huggingface_model/Qwen/Qwen2_7B/modeling_qwen2.py 2024-09-14 00:50:47.339987000 +0800
+++ huggingface_model/Qwen/Qwen2_7B/modeling_qwen2.py 2024-09-14 00:53:06.091602000 +0800
@@ -49,9 +49,10 @@
 )
 from .configuration_qwen2 import Qwen2Config
 
-
-if is_flash_attn_2_available():
-    from transformers.modeling_flash_attention_utils import _flash_attention_forward
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
+from internlm.model.ops.attention import isp_flash_attn_varlen_func, isp_flash_attn_func
+from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
 
 
 logger = logging.get_logger(__name__)
@@ -328,6 +329,12 @@
     ):
         bsz, q_len, _ = hidden_states.size()
 
+        use_packed_dataset = gpc.config.data.get("use_packed_dataset", False)
+        if use_packed_dataset:
+            assert bsz == 1, "hidden_states should be packed into bsz=1 when use_packed_dataset=True"
+            cu_seqlens = gpc.config.data[f"cu_seqlens_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"]
+            max_seqlen = gpc.config.data[f"max_seqlen_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"]
+
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
@@ -414,27 +421,28 @@
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
 
-        if (
-            self.config.use_sliding_window
-            and getattr(self.config, "sliding_window", None) is not None
-            and self.layer_idx >= self.config.max_window_layers
-        ):
-            sliding_window = self.config.sliding_window
+        if use_packed_dataset:
+            attn_output = isp_flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens,
+                cu_seqlens,
+                max_seqlen,
+                max_seqlen,
+                attention_dropout=dropout_rate,
+                softmax_scale=None,
+                causal=True,
+            )
         else:
-            sliding_window = None
-
-        attn_output = _flash_attention_forward(
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            q_len,
-            position_ids=position_ids,
-            dropout=dropout_rate,
-            sliding_window=sliding_window,
-            is_causal=self.is_causal,
-            use_top_left_mask=self._flash_attn_uses_top_left_mask,
-        )
+            attn_output = isp_flash_attn_func(
+                query_states,
+                key_states,
+                value_states,
+                attention_dropout=dropout_rate,
+                softmax_scale=None,
+                causal=True,
+            )
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
         attn_output = self.o_proj(attn_output)
@@ -1139,6 +1147,30 @@
         )
         return model_inputs
 
+    def reset_parameters(self, std=0.02):
+        def reset_attn_parameters(layer_idx, layer, use_scaled_init=True, std=0.02):
+            for name, param in layer.self_attn.named_parameters():
+                if param.ndim == 1:  # bias
+                    param.data.zero_()
+                elif "q_proj" in name or "k_proj" in name or "v_proj" in name:  # wq, wk, wv
+                    normal_(std=std)(param.data)
+                elif use_scaled_init:  # wo
+                    scaled_init_method_normal(sigma=std, num_layers=layer_idx + 1)(param.data)
+                else:  # wo
+                    normal_(std=std)(param.data)
+
+            for name, param in layer.mlp.named_parameters():
+                if use_scaled_init:
+                    scaled_init_method_normal(sigma=std, num_layers=layer_idx + 1)(param.data)
+                else:
+                    normal_(std=std)(param.data)
+        with torch.no_grad():
+            for _, param in self.model.embed_tokens.named_parameters():
+                normal_(std=std)(param)
+            for layer_idx, layer in enumerate(self.model.layers):
+                reset_attn_parameters(layer_idx, layer)
+            for _, param in self.lm_head.named_parameters():
+                normal_(std=std)(param)
 
 @add_start_docstrings(
     """