# attentionmask.py
import torch
from einops import einsum
from torch import Tensor
from torch.nn import Linear, Module, Parameter, Sigmoid

from attentionpooling import GlobalAttentionPooling1D


class AttentionMask1d(Module):
"""
Calculates masks for a channel derived from the information contained in other channels.
"""
def __init__(self,
num_of_channels: int,
bias: bool = True,
init: float = 5):
"""_summary_
Args:
num_of_channels (int): Number of channels
bias (bool, optional): Should the layers have bias. Defaults to True.
init (float, optional): Starts the weights from this value.
"""
super().__init__()
self.num_of_channels = num_of_channels
self.bias = bias
self.weights = Linear(out_features = num_of_channels,
in_features = num_of_channels,
bias = bias)
self.activation = Sigmoid()
with torch.no_grad():
self.weights.weight[:] = 0
self.weights.weight.fill_diagonal_(init)
if self.bias:
self.weights.bias = Parameter(self.weights.bias.view(1, -1, 1))
self.weights.bias[:] = 0
    def forward(self, x: Tensor) -> Tensor:
        """Forward operation for the layer.

        Args:
            x (Tensor): input tensor of shape (batch, channels, length)

        Returns:
            Tensor: the input with a per-position, per-channel mask applied
        """
        # Mix channels with the (out, in) weight matrix at every position along the length.
        mask = einsum(self.weights.weight, x, "o i, b i l -> b o l")
        if self.bias:
            mask = mask + self.weights.bias
        mask = self.activation(mask)
        return x * mask
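

# Minimal usage sketch for AttentionMask1d. The concrete sizes below are illustrative
# assumptions, not taken from the original module; the layer itself only requires a
# (batch, channels, length) float tensor, as implied by the "b i l" einsum pattern above.
def _example_attention_mask_1d() -> Tensor:
    layer = AttentionMask1d(num_of_channels = 4)
    x = torch.randn(2, 4, 16)  # (batch, channels, length), assumed shapes for the demo
    y = layer(x)               # same shape as x; each entry is gated by a sigmoid mask
    return y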
class AttentionMaskGlobal1d(Module):
    """
    Calculates a mask for each channel derived from the information contained in the other
    channels. Unlike AttentionMask1d, it derives a single mask value per channel for the
    whole sequence instead of computing a mask at each position separately.
    """

    def __init__(self,
                 num_of_channels: int,
                 bias: bool = True,
                 init: float = 5):
        """Initializes the layer.

        Args:
            num_of_channels (int): Number of channels.
            bias (bool, optional): Whether the linear layer has a bias. Defaults to True.
            init (float, optional): Initial value placed on the diagonal of the weight matrix.
        """
        super().__init__()
        self.num_of_channels = num_of_channels
        self.bias = bias
        self.weights = Linear(in_features = num_of_channels,
                              out_features = num_of_channels,
                              bias = bias)
        self.activation = Sigmoid()
        self.pooling = GlobalAttentionPooling1D(feature_size = num_of_channels)
        # Same initialization as AttentionMask1d: an identity-like mixing matrix with
        # `init` on the diagonal and a zero bias broadcastable over batch and length.
        with torch.no_grad():
            self.weights.weight[:] = 0
            self.weights.weight.fill_diagonal_(init)
            if self.bias:
                self.weights.bias = Parameter(self.weights.bias.view(1, -1, 1))
                self.weights.bias[:] = 0
    def forward(self, x: Tensor) -> Tensor:
        """Forward operation for the layer.

        Args:
            x (Tensor): input tensor of shape (batch, channels, length)

        Returns:
            Tensor: the input with a single per-channel mask applied along the whole length
        """
        # Mix channels at every position, then pool over the length dimension to obtain
        # one value per channel before applying the sigmoid gate.
        mask = einsum(self.weights.weight, x, "o i, b i l -> b o l")
        if self.bias:
            mask = mask + self.weights.bias
        mask = self.pooling(mask)[0]
        mask = self.activation(mask)
        return einsum(x, mask, "b c l, b c -> b c l")
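

# Minimal usage sketch for AttentionMaskGlobal1d, mirroring the example above. It assumes
# the companion attentionpooling module is importable and that GlobalAttentionPooling1D
# returns a tuple whose first element has shape (batch, channels), as the forward pass
# above relies on. The tensor sizes are illustrative assumptions only.
def _example_attention_mask_global_1d() -> Tensor:
    layer = AttentionMaskGlobal1d(num_of_channels = 4)
    x = torch.randn(2, 4, 16)  # (batch, channels, length), assumed shapes for the demo
    y = layer(x)               # same shape as x; one sigmoid gate per channel
    return y


if __name__ == "__main__":
    # Quick smoke test of both sketches; prints the output shapes only.
    print(_example_attention_mask_1d().shape)
    print(_example_attention_mask_global_1d().shape)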