headnorm.py
from __future__ import annotations

import torch
from torch import nn

from exllamav2.module import ExLlamaV2Module
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor

from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from exllamav2.model import ExLlamaV2


class ExLlamaV2HeadNorm(ExLlamaV2Module):

    name: str = "LayerNorm"

    layernorm: nn.LayerNorm | None
    weight: nn.Parameter | None
    bias: nn.Parameter | None
    variance_epsilon: float
    head_dim: int
    num_heads: int

    def __init__(self,
                 model: ExLlamaV2,
                 key: str,
                 num_heads: int,
                 head_dim: int):
        super().__init__(model, key)

        self.layernorm = None
        self.weight = None
        self.bias = None
        self.variance_epsilon = self.model.config.norm_eps
        self.head_dim = head_dim
        self.num_heads = num_heads

    def load(self):

        # Weight may load as a bare parameter or as a (weight, bias) tuple
        w = self.load_weight()

        if isinstance(w, tuple):
            weight = w[0]
            bias = w[1]
        else:
            weight = w
            bias = None

        assert isinstance(weight, nn.Parameter)
        assert bias is None or isinstance(bias, nn.Parameter)

        self.layernorm = nn.LayerNorm(self.model.config.hidden_size,
                                      elementwise_affine = True,
                                      bias = bias is not None)

        self.layernorm.weight = weight
        self.weight = weight

        if bias is not None:
            self.layernorm.bias = bias
            self.bias = bias

        assert self.weight.shape == (self.num_heads, self.head_dim), "Head norm tensor shape mismatch"

    def unload(self):

        self.layernorm = None
        self.weight = None
        self.bias = None

    def numel(self):

        return 0
        # return self.layernorm.weight.data.numel()

    def get_weight(self) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:

        if self.bias is not None: return self.weight, self.bias
        return self.weight

    def weight_footprint(self) -> int:

        hidden_size = self.model.config.hidden_size
        return hidden_size * 2

    def scratch_space_fixed(self) -> int:

        return 0

    def scratch_space(self) -> int:

        return 0

    def forward(self,
                hidden_states: torch.Tensor,
                cache = None,
                attn_params = None,
                past_len = None,
                intermediates: bool = False,
                loras = None,
                **kwargs) -> torch.Tensor | dict[str, torch.Tensor]:

        # The extension kernel writes its result back into hidden_states; norm is
        # allocated but not used as the output buffer here
        norm = torch.empty_like(hidden_states)
        ext_c.head_norm(hidden_states,
                        self.weight.data,
                        self.bias.data if self.bias is not None else none_tensor,
                        hidden_states,
                        self.variance_epsilon)

        if intermediates:
            return {"hidden_states": hidden_states}
        else:
            return hidden_states

    def forward_torch(self,
                      hidden_states: torch.Tensor,
                      cache = None,
                      attn_params = None,
                      past_len = None,
                      intermediates: bool = False,
                      loras = None,
                      **kwargs) -> torch.Tensor | dict[str, torch.Tensor]:

        # Reference path: LayerNorm over the last (head_dim) axis computed in float32,
        # then scaled by the per-head weight; bias (if any) is not applied here
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        mean = hidden_states.mean(-1, keepdim = True)
        variance = (hidden_states - mean).pow(2).mean(-1, keepdim = True)
        hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = self.weight.to(torch.float32) * hidden_states
        hidden_states = hidden_states.to(input_dtype)

        if intermediates:
            return {"hidden_states": hidden_states}
        else:
            return hidden_states