fusion.py
import torch
from torch import nn


# Gated fusion of two aligned sequence representations:
#   t = tanh(W1 [x; y; x - y; x * y])
#   g = sigmoid(W2 [x; y; x - y; x * y])
#   output = g * t + (1 - g) * x
class Fusion(nn.Module):
    def __init__(self, opt, hidden_size):
        super(Fusion, self).__init__()
        self.opt = opt
        # Candidate fused representation t
        self.match = nn.Sequential(
            nn.Dropout(opt.dropout),
            nn.Linear(hidden_size * 4, hidden_size),
            nn.Tanh())
        # Per-position gate g in (0, 1) controlling how much of t to keep
        self.gate = nn.Sequential(
            nn.Dropout(opt.dropout),
            nn.Linear(hidden_size * 4, 1),
            nn.Sigmoid())

    # x of shape (batch_l, seq_l, hidden_size)
    # y of shape (batch_l, seq_l, hidden_size)
    # The elementwise x - y and x * y terms require x and y to have
    # identical shapes, so the assert checks all three dimensions.
    def forward(self, x, y):
        assert x.shape == y.shape
        merged = torch.cat([x, y, x - y, x * y], 2)
        matched = self.match(merged)
        gated = self.gate(merged)
        # Gated residual: interpolate between the fused signal and x.
        # (1 - gated) broadcasts, so the explicit ones tensor and the
        # deprecated torch.autograd.Variable wrapper are unnecessary.
        return gated * matched + (1.0 - gated) * x
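
# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal smoke test, assuming `opt` only needs the `dropout` field
# that Fusion reads in __init__; `Namespace` is a hypothetical stand-in
# for the project's real option object.
if __name__ == '__main__':
    from argparse import Namespace

    opt = Namespace(dropout=0.1)
    fusion = Fusion(opt, hidden_size=8)
    x = torch.randn(2, 5, 8)  # (batch_l, seq_l, hidden_size)
    y = torch.randn(2, 5, 8)  # must match x's shape exactly
    out = fusion(x, y)
    print(out.shape)  # torch.Size([2, 5, 8])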