# multi_head_attention_torch.py
import torch
import torch.nn as nn


class Attention(nn.Module):
    """Attention variant that replaces the query/key dot product with two
    linear layers attending over a small learnable memory of size k
    (external-attention style)."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0.):
        super().__init__()
        # qkv_bias and qk_scale are unused; they are kept so the signature
        # matches a standard multi-head self-attention module.
        self.num_heads = num_heads
        assert dim % num_heads == 0
        self.coef = 4  # channel-expansion factor
        self.trans_dims = nn.Linear(dim, dim * self.coef)
        self.num_heads = self.num_heads * self.coef
        self.k = 256 // self.coef  # number of external memory slots
        # linear_0 / linear_1 act as the shared key / value memory units.
        self.linear_0 = nn.Linear(dim * self.coef // self.num_heads, self.k)
        self.linear_1 = nn.Linear(self.k, dim * self.coef // self.num_heads)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim * self.coef, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        x = self.trans_dims(x)  # (B, N, C * coef)
        x = x.view(B, N, self.num_heads, -1).permute(0, 2, 1, 3)  # (B, heads, N, head_dim)
        attn = self.linear_0(x)  # (B, heads, N, k)
        attn = attn.softmax(dim=-2)  # softmax over the token dimension
        attn = attn / (1e-9 + attn.sum(dim=-1, keepdim=True))  # double normalization
        attn = self.attn_drop(attn)
        x = self.linear_1(attn).permute(0, 2, 1, 3).reshape(B, N, -1)  # (B, N, C * coef)
        x = self.proj(x)  # project back to the input dimension
        x = self.proj_drop(x)
        return x
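

# Minimal usage sketch: the shapes below (batch=2, tokens=196, dim=64) are
# arbitrary illustrative values chosen so that dim % num_heads == 0.
if __name__ == "__main__":
    module = Attention(dim=64, num_heads=8)
    tokens = torch.randn(2, 196, 64)  # (B, N, C)
    out = module(tokens)
    print(out.shape)  # torch.Size([2, 196, 64])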