# vasnet_model.py (forked from ok1zjf/VASNet)
__author__ = 'Jiri Fajtl'
__email__ = '[email protected]'
__version__ = '3.6'
__status__ = "Research"
__date__ = "1/12/2018"
__license__ = "MIT License"

import torch
import torch.nn as nn
import torch.nn.functional as F

# Project-local modules from the VASNet repository
from config import *
from layer_norm import *

class SelfAttention(nn.Module):

    def __init__(self, apperture=-1, ignore_itself=False, input_size=1024, output_size=1024):
        super(SelfAttention, self).__init__()

        self.apperture = apperture
        self.ignore_itself = ignore_itself

        self.m = input_size
        self.output_size = output_size

        # Key, query and value projections (no bias)
        self.K = nn.Linear(in_features=self.m, out_features=self.output_size, bias=False)
        self.Q = nn.Linear(in_features=self.m, out_features=self.output_size, bias=False)
        self.V = nn.Linear(in_features=self.m, out_features=self.output_size, bias=False)
        self.output_linear = nn.Linear(in_features=self.output_size, out_features=self.m, bias=False)

        self.drop50 = nn.Dropout(0.5)

    def forward(self, x):
        n = x.shape[0]  # sequence length

        K = self.K(x)  # (n x m) => (n x H), H = hidden size
        Q = self.Q(x)  # (n x m) => (n x H)
        V = self.V(x)

        # Scale the queries; 0.06 is the constant used by the original implementation
        Q *= 0.06
        logits = torch.matmul(Q, K.transpose(1, 0))

        if self.ignore_itself:
            # Mask the diagonal (the attention of each frame to itself)
            logits[torch.eye(n, dtype=torch.bool)] = -float("Inf")

        if self.apperture > 0:
            # Mask attention to frames further than +/- aperture from the current one
            onesmask = torch.ones(n, n)
            trimask = torch.tril(onesmask, -self.apperture) + torch.triu(onesmask, self.apperture)
            logits[trimask == 1] = -float("Inf")

        att_weights_ = F.softmax(logits, dim=-1)
        weights = self.drop50(att_weights_)
        y = torch.matmul(V.transpose(1, 0), weights).transpose(1, 0)
        y = self.output_linear(y)

        return y, att_weights_
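
# Usage sketch: for a sequence of n frame feature vectors x, shape (n, 1024),
#
#   att = SelfAttention(input_size=1024, output_size=1024)
#   y, att_weights = att(x)  # y: (n, 1024), att_weights: (n, n)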

class VASNet(nn.Module):

    def __init__(self):
        super(VASNet, self).__init__()

        self.m = 1024  # CNN feature size
        self.hidden_size = 1024

        self.att = SelfAttention(input_size=self.m, output_size=self.m)
        self.ka = nn.Linear(in_features=self.m, out_features=1024)
        # kb and kc are defined but not used in forward()
        self.kb = nn.Linear(in_features=self.ka.out_features, out_features=1024)
        self.kc = nn.Linear(in_features=self.kb.out_features, out_features=1024)
        self.kd = nn.Linear(in_features=self.ka.out_features, out_features=1)

        self.sig = nn.Sigmoid()
        self.relu = nn.ReLU()
        self.drop50 = nn.Dropout(0.5)
        self.softmax = nn.Softmax(dim=0)
        self.layer_norm_y = LayerNorm(self.m)
        self.layer_norm_ka = LayerNorm(self.ka.out_features)

    def forward(self, x, seq_len):
        m = x.shape[2]  # feature size

        # Move the video frames to the batch dimension to allow batched
        # arithmetic operations. Assumes an input batch size of 1.
        x = x.view(-1, m)
        y, att_weights_ = self.att(x)

        # Residual connection, dropout and layer normalization
        y = y + x
        y = self.drop50(y)
        y = self.layer_norm_y(y)

        # Frame-level importance score regression (two-layer NN)
        y = self.ka(y)
        y = self.relu(y)
        y = self.drop50(y)
        y = self.layer_norm_ka(y)

        y = self.kd(y)
        y = self.sig(y)
        y = y.view(1, -1)

        return y, att_weights_

if __name__ == "__main__":
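    # Minimal smoke test (a sketch, not part of the training code): push a
    # random feature sequence through the model and check output shapes.
    # The sequence length n = 10 and the random input are arbitrary choices.
    model = VASNet()
    model.eval()  # disable dropout for the check

    n = 10
    x = torch.rand(1, n, 1024)  # (batch=1, frames, feature size)
    with torch.no_grad():
        scores, att_weights = model(x, n)

    print(scores.shape)       # torch.Size([1, 10])
    print(att_weights.shape)  # torch.Size([10, 10])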