-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
230 lines (172 loc) · 9.06 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
from typing import Dict
from typing import Optional
from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions
from ane_transformers.reference.layer_norm import LayerNormANE as LayerNormANEBase
# Load-state-dict pre-hook: reshapes attention/MLP nn.Linear weights to fit the 1x1 nn.Conv2d layers used in the encoder and decoder blocks
def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
                         missing_keys, unexpected_keys, error_msgs):
    """
    Reshape nn.Linear weights in ``state_dict`` into nn.Conv2d (1x1) weights.

    An entry is converted when its key either contains both 'attn' and
    '.weight', or ends with 'mlp.0.weight' / 'mlp.2.weight', and its tensor is
    2-D. Conversion appends two singleton axes, so (out, in) becomes
    (out, in, 1, 1). ``state_dict`` is modified in place and nothing is
    returned; the remaining arguments are accepted only to satisfy the
    load-state-dict pre-hook signature and are not used.
    """
    mlp_suffixes = ('mlp.0.weight', 'mlp.2.weight')
    for name, tensor in state_dict.items():
        matches_attention = 'attn' in name and '.weight' in name
        if not (matches_attention or name.endswith(mlp_suffixes)):
            continue
        # 1-D entries (e.g. layer-norm weights under an 'attn' key) and
        # already-4-D weights are left alone.
        if len(tensor.shape) == 2:
            state_dict[name] = tensor[:, :, None, None]
def correct_for_bias_scale_order_inversion(state_dict, prefix, local_metadata,
                                           strict, missing_keys,
                                           unexpected_keys, error_msgs):
    """
    Load-state-dict pre-hook that divides a module's bias by its weight.

    Replaces ``state_dict[prefix + 'bias']`` with ``bias / weight`` (element-wise)
    and returns the same ``state_dict`` object. The entry is overwritten in
    place, so running the hook twice over the same dict rescales the bias twice.

    NOTE(review): the function name suggests the target layer adds the bias
    before applying the scale, i.e. computes (x + bias) * weight -- confirm
    against the LayerNormANE base implementation.

    If either the bias or the weight entry is absent (for example a partial
    checkpoint), the dict is returned unchanged so that the regular
    missing-key handling of ``load_state_dict`` can deal with it instead of
    this hook raising ``KeyError``.

    The remaining arguments are part of the pre-hook signature and are unused.
    """
    bias_key = prefix + 'bias'
    weight_key = prefix + 'weight'
    if bias_key not in state_dict or weight_key not in state_dict:
        return state_dict
    state_dict[bias_key] = state_dict[bias_key] / state_dict[weight_key]
    return state_dict
class LayerNormANE(LayerNormANEBase):
    """
    LayerNormANEBase that rewrites its bias while a state dict is being loaded.

    Construction arguments are forwarded untouched to the base class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The hook replaces this module's 'bias' entry with bias / weight
        # before the values are copied into the module.
        pre_load_hook = correct_for_bias_scale_order_inversion
        self._register_load_state_dict_pre_hook(pre_load_hook)
class MultiHeadAttentionANE(MultiHeadAttention):
    """
    Multi-head attention whose query/key/value/out projections are 1x1 Conv2d.

    ``qkv_attention_ane`` unpacks the query as a 4-D tensor
    (batch, channels, singleton axis, sequence) and splits heads along the
    channel axis.
    """

    def __init__(self, n_state: int, n_head: int):
        super().__init__(n_state, n_head)
        # (Re)assign the four projections as pointwise convolutions; the key
        # projection carries no bias.
        self.query = nn.Conv2d(n_state, n_state, kernel_size=1)
        self.key = nn.Conv2d(n_state, n_state, kernel_size=1, bias=False)
        self.value = nn.Conv2d(n_state, n_state, kernel_size=1)
        self.out = nn.Conv2d(n_state, n_state, kernel_size=1)

    def forward(self,
                x: Tensor,
                xa: Optional[Tensor] = None,
                mask: Optional[Tensor] = None,
                kv_cache: Optional[dict] = None):
        """
        Attend from ``x`` to itself (``xa`` is None) or to ``xa``.

        Returns a pair: the output projection of the attention result, and the
        detached attention scores produced by ``qkv_attention_ane``.
        """
        q = self.query(x)

        reuse_cached = (
            kv_cache is not None
            and xa is not None
            and self.key in kv_cache
        )
        if reuse_cached:
            # Cross-attention: keys/values were computed on an earlier call
            # and are looked up by projection module.
            k = kv_cache[self.key]
            v = kv_cache[self.value]
        else:
            # Project as usual. When cache hooks are installed on the
            # projections, the hooks are what extend the tensors returned here.
            source = xa if xa is not None else x
            k = self.key(source)
            v = self.value(source)

        wv, qk = self.qkv_attention_ane(q, k, v, mask)
        return self.out(wv), qk

    def qkv_attention_ane(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
        """
        Compute per-head scaled dot-product attention on channel-first tensors.

        Returns ``(attn, scores)`` where ``attn`` is the per-head results
        concatenated along dim 1 and ``scores`` is the per-head (masked,
        pre-softmax) scores concatenated along dim 1, cast to float and
        detached.
        """
        _, channels, _, seqlen = q.size()
        head_dim = channels // self.n_head

        # The whole 1/sqrt(head_dim) factor is applied to the query.
        q = q * float(head_dim) ** -0.5

        q_heads = q.split(head_dim, dim=1)
        k_heads = k.transpose(1, 3).split(head_dim, dim=3)
        v_heads = v.split(head_dim, dim=1)

        # One score tensor per head, laid out as (batch, key, 1, query).
        scores = []
        for q_h, k_h in zip(q_heads, k_heads):
            scores.append(torch.einsum('bchq,bkhc->bkhq', q_h, k_h))

        if mask is not None:
            window = mask[:, :seqlen, :, :seqlen]
            for head_idx in range(self.n_head):
                scores[head_idx] = scores[head_idx] + window

        # Softmax over dim 1, which is the key axis in the 'bkhq' layout.
        weights = [head_scores.softmax(dim=1) for head_scores in scores]

        per_head = []
        for w_h, v_h in zip(weights, v_heads):
            per_head.append(torch.einsum('bkhq,bchk->bchq', w_h, v_h))
        attn = torch.cat(per_head, dim=1)

        return attn, torch.cat(scores, dim=1).float().detach()
class ResidualAttentionBlockANE(ResidualAttentionBlock):
    """
    Residual attention block rebuilt from ANE-oriented sub-modules.

    After the parent constructor runs, the attention modules, their layer
    norms and the MLP are reassigned to MultiHeadAttentionANE, LayerNormANE
    and a Conv2d-based MLP respectively.
    """

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__(n_state, n_head, cross_attention)

        self.attn = MultiHeadAttentionANE(n_state, n_head)
        self.attn_ln = LayerNormANE(n_state)

        # Cross-attention exists only when requested; otherwise both
        # attributes are None.
        if cross_attention:
            self.cross_attn = MultiHeadAttentionANE(n_state, n_head)
            self.cross_attn_ln = LayerNormANE(n_state)
        else:
            self.cross_attn = None
            self.cross_attn_ln = None

        # Pointwise-conv MLP with a hidden width of 4 * n_state.
        hidden = n_state * 4
        expand = nn.Conv2d(n_state, hidden, kernel_size=1)
        activation = nn.GELU()
        project = nn.Conv2d(hidden, n_state, kernel_size=1)
        self.mlp = nn.Sequential(expand, activation, project)
        self.mlp_ln = LayerNormANE(n_state)
class AudioEncoderANE(AudioEncoder):
    """
    Audio encoder whose blocks and final layer norm are the ANE variants.
    """

    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__(n_mels, n_ctx, n_state, n_head, n_layer)
        ane_blocks = [ResidualAttentionBlockANE(n_state, n_head) for _ in range(n_layer)]
        self.blocks = nn.ModuleList(ane_blocks)
        self.ln_post = LayerNormANE(n_state)

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio

        Returns the output of ``ln_post`` applied after all blocks.
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))

        # The conv output keeps channels before time, so it must match the
        # positional embedding with its two axes reversed.
        assert x.shape[1:] == self.positional_embedding.shape[::-1], "incorrect audio shape"

        # Add the (transposed) positional embedding, then insert a singleton
        # axis at dim 2 to get the 4-D layout the ANE blocks operate on.
        positions = self.positional_embedding.transpose(0, 1)
        x = (x + positions).to(x.dtype)
        x = x.unsqueeze(2)

        for block in self.blocks:
            x = block(x)

        return self.ln_post(x)
class TextDecoderANE(TextDecoder):
    """
    Text decoder whose blocks and final layer norm are the ANE variants.

    Logits are computed against the token-embedding matrix in several row
    chunks rather than in one matmul.
    """

    # Number of row chunks the token-embedding matrix is split into when
    # computing logits (the last chunk may be smaller, and one extra chunk
    # appears when the vocabulary size is not divisible by this number).
    _N_LOGIT_CHUNKS = 11

    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__(n_vocab, n_ctx, n_state, n_head, n_layer)
        self.blocks = nn.ModuleList(
            [ResidualAttentionBlockANE(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
        )
        self.ln = LayerNormANE(n_state)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor
            the encoded audio features to be attended on; presumably the 4-D
            output of AudioEncoderANE, (batch_size, n_audio_state, 1, n_audio_ctx)
            -- TODO confirm against the caller
        kv_cache : optional dict of cached key/value tensors; when non-empty,
            dim 3 of its first value is used as the position offset

        Returns logits of shape (batch_size, n_tokens, n_vocab).
        """
        # Cached tensors are channel-first 4-D, so the sequence length is dim 3.
        offset = next(iter(kv_cache.values())).shape[3] if kv_cache else 0
        x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
        x = x.to(xa.dtype)

        # Reformat for ANE: mask becomes (1, key, 1, query) and x becomes
        # (batch, n_state, 1, n_tokens).
        mask = self.mask[None, None, :, :].permute(0, 3, 1, 2)
        x = x.transpose(1, 2).unsqueeze(2)

        for block in self.blocks:
            x = block(x, xa, mask=mask, kv_cache=kv_cache)

        x = self.ln(x)

        # Reformat back from ANE: (batch, n_state, 1, n_tokens) ->
        # (batch, 1, n_tokens, n_state) -> (batch, n_tokens, n_state).
        # The singleton axis sits at dim 1 after the permute; squeezing dim 0
        # would only work for batch_size == 1 and leave a 4-D tensor otherwise.
        x = x.permute(0, 2, 3, 1).squeeze(1)

        # ANE can only load tensors with dim size of at most 16,384 - whisper uses 51,865 tokens so we need to compute in chunks
        embedding = self.token_embedding.weight
        chunk_rows = max(1, embedding.shape[0] // self._N_LOGIT_CHUNKS)
        splits = embedding.split(chunk_rows, dim=0)

        # Each chunk yields (batch, n_tokens, chunk_vocab); joining along the
        # last axis rebuilds the vocabulary dimension in order for any batch
        # size, sequence length, and vocabulary size.
        logits = torch.cat(
            [torch.einsum('bid,jd->bij', x, split) for split in splits],
            dim=-1,
        )

        return logits
class WhisperANE(Whisper):
    """
    Whisper model whose encoder and decoder are replaced by the ANE variants.

    A load-state-dict pre-hook (``linear_to_conv2d_map``) is registered so
    that nn.Linear-shaped weights can be loaded into the Conv2d layers.
    """

    def __init__(self, dims: ModelDimensions):
        super().__init__(dims)
        d = self.dims

        self.encoder = AudioEncoderANE(
            d.n_mels,
            d.n_audio_ctx,
            d.n_audio_state,
            d.n_audio_head,
            d.n_audio_layer,
        )

        # TODO: Remove from here
        self.decoder = TextDecoderANE(
            d.n_vocab,
            d.n_text_ctx,
            d.n_text_state,
            d.n_text_head,
            d.n_text_layer,
        )

        self._register_load_state_dict_pre_hook(linear_to_conv2d_map)

    def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Encode ``mel`` and decode ``tokens`` against the encoded audio.

        NOTE(review): the return annotation says Dict, but the value returned
        is whatever ``self.decoder`` returns (TextDecoderANE.forward in this
        file returns a logits tensor).
        """
        audio_features = self.encoder(mel)
        return self.decoder(tokens, audio_features)

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        Attach forward hooks that cache key/value projections in the decoder.

        A shallow copy of ``cache`` (or a new dict) is used as the store, keyed
        by projection module. Returns ``(cache, hooks)`` where ``hooks`` holds
        the handles returned by ``register_forward_hook``, in registration
        order.
        """
        if cache is not None:
            cache = {**cache}
        else:
            cache = {}
        hooks = []

        def save_to_cache(module, _, output):
            # Append along dim 3 only when something is already cached for
            # this module and the new output is no longer than the decoder's
            # positional embedding; otherwise store the output as-is (first
            # token, or cross attention).
            if module in cache and output.shape[3] <= self.decoder.positional_embedding.shape[0]:
                cache[module] = torch.cat([cache[module], output], dim=3).detach()
            else:
                cache[module] = output
            # Returning a value from a forward hook replaces the module output.
            return cache[module]

        def install_hooks(layer: nn.Module):
            if not isinstance(layer, MultiHeadAttentionANE):
                return
            for projection in (layer.key, layer.value):
                hooks.append(projection.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks