-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhighway.py
49 lines (36 loc) · 1.21 KB
/
highway.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import sys
import torch
from torch import nn
from torch.autograd import Variable
from holder import *
from util import *
class HighwayLayer(torch.nn.Module):
    """One highway layer: y = g * relu(W_t x) + (1 - g) * x, with g = sigmoid(W_g x).

    Fixes vs. the original:
    - drops the deprecated ``torch.autograd.Variable`` wrapper;
    - removes the per-forward allocation of a ones tensor and the manual
      ``opt.gpuid`` device branch — the scalar expression ``1.0 - gate``
      broadcasts and stays on whatever device/dtype ``x`` already uses;
    - the dropout rate is now a parameter (default 0.2, backward compatible).
    """

    def __init__(self, opt, hidden_size, dropout=0.2):
        """
        Args:
            opt: option namespace (kept for interface compatibility; ``gpuid``
                is no longer consulted here).
            hidden_size: feature dimension of the input/output.
            dropout: dropout probability applied to the input (was fixed 0.2).
        """
        super(HighwayLayer, self).__init__()
        self.opt = opt
        self.drop = nn.Dropout(dropout)
        self.tran_linear = nn.Linear(hidden_size, hidden_size)
        self.gate_linear = nn.Linear(hidden_size, hidden_size)
        self.tran_act = nn.ReLU()
        self.gate_act = nn.Sigmoid()

    def forward(self, x):
        """x is of shape (batch_l * seq_l, hidden_size); returns the same shape.

        Note: as in the original, dropout is applied to x BEFORE both the
        transform branch and the carry (residual) branch.
        """
        x = self.drop(x)
        tran = self.tran_act(self.tran_linear(x))
        gate = self.gate_act(self.gate_linear(x))
        # convex combination of transformed and carried input
        return gate * tran + (1.0 - gate) * x
# Highway networks
class Highway(torch.nn.Module):
    """A stack of ``opt.hw_layer`` highway layers applied in sequence."""

    def __init__(self, opt, hidden_size):
        """
        Args:
            opt: option namespace; ``opt.hw_layer`` gives the layer count.
            hidden_size: feature dimension shared by every layer.
        """
        super(Highway, self).__init__()
        self.opt = opt
        self.hw_layers = nn.ModuleList(
            HighwayLayer(opt, hidden_size) for _ in range(opt.hw_layer)
        )

    def forward(self, seq):
        """seq is an encoding tensor of shape (batch_l * seq_l, hidden_size).

        Feeds each layer's output into the next and returns the final result.
        """
        for layer in self.hw_layers:
            seq = layer(seq)
        return seq