-
Notifications
You must be signed in to change notification settings - Fork 18
/
language_model.py
81 lines (67 loc) · 2.58 KB
/
language_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
class WordEmbedding(nn.Module):
    """Word embedding lookup with one extra padding row, followed by dropout.

    The table has ``ntoken + 1`` rows: indices ``0 .. ntoken - 1`` are real
    tokens and index ``ntoken`` is used as ``padding_idx``. This agrees
    *implicitly* with the definition in Dictionary (not visible here), which
    presumably maps its padding symbol to ``ntoken`` -- confirm there.
    """

    def __init__(self, ntoken, emb_dim, dropout):
        """
        Args:
            ntoken: number of real tokens; index ``ntoken`` is reserved for
                padding.
            emb_dim: size of each embedding vector.
            dropout: dropout probability applied to the looked-up embeddings.
        """
        super(WordEmbedding, self).__init__()
        # One extra row for padding; nn.Embedding excludes the padding_idx
        # row from gradient updates.
        self.emb = nn.Embedding(ntoken + 1, emb_dim, padding_idx=ntoken)
        self.dropout = nn.Dropout(dropout)
        self.ntoken = ntoken
        self.emb_dim = emb_dim

    def init_embedding(self, np_file):
        """Overwrite the token rows with pretrained vectors from ``np_file``.

        The padding row (index ``ntoken``) is left untouched.

        Args:
            np_file: anything accepted by ``numpy.load``; it must hold an
                array of shape ``(ntoken, emb_dim)``.

        Raises:
            ValueError: if the loaded array's shape is not
                ``(ntoken, emb_dim)``.
        """
        weight_init = torch.from_numpy(np.load(np_file))
        expected = (self.ntoken, self.emb_dim)
        # Explicit check rather than ``assert`` so it survives ``python -O``;
        # without it a (1, emb_dim) array would silently broadcast over
        # every token row in the slice assignment below.
        if tuple(weight_init.shape) != expected:
            raise ValueError(
                'pretrained embedding shape {} does not match expected {}'.format(
                    tuple(weight_init.shape), expected))
        # Slice assignment copies (and casts, if needed) into the existing
        # parameter storage, leaving the padding row as-is.
        self.emb.weight.data[:self.ntoken] = weight_init

    def forward(self, x):
        """Look up embeddings for token indices ``x`` and apply dropout.

        Args:
            x: integer tensor of token indices in ``[0, ntoken]``.

        Returns:
            Tensor of shape ``x.shape + (emb_dim,)``.
        """
        emb = self.emb(x)
        emb = self.dropout(emb)
        return emb
class QuestionEmbedding(nn.Module):
    """Encode a batch of embedded token sequences with a GRU or LSTM."""

    def __init__(self, in_dim, num_hid, nlayers, bidirect, dropout, rnn_type='GRU'):
        """Module for question embedding.

        Args:
            in_dim: feature size of each input time step.
            num_hid: RNN hidden size per direction.
            nlayers: number of stacked RNN layers.
            bidirect: if true, build a bidirectional RNN.
            dropout: passed to the RNN's ``dropout`` argument (applied between
                stacked layers).
            rnn_type: ``'GRU'`` (default) or ``'LSTM'``.

        Raises:
            ValueError: if ``rnn_type`` is neither ``'LSTM'`` nor ``'GRU'``.
        """
        super(QuestionEmbedding, self).__init__()
        # Explicit check instead of ``assert``: under ``python -O`` the assert
        # disappears and an unknown rnn_type would silently build a GRU.
        if rnn_type not in ('LSTM', 'GRU'):
            raise ValueError(
                "rnn_type must be 'LSTM' or 'GRU', got {!r}".format(rnn_type))
        rnn_cls = nn.LSTM if rnn_type == 'LSTM' else nn.GRU
        self.rnn = rnn_cls(
            in_dim, num_hid, nlayers,
            bidirectional=bidirect,
            dropout=dropout,
            batch_first=True)
        self.in_dim = in_dim
        self.num_hid = num_hid
        self.nlayers = nlayers
        self.rnn_type = rnn_type
        self.ndirections = 1 + int(bidirect)

    def init_hidden(self, batch):
        """Return an all-zero initial hidden state for ``batch`` sequences.

        Each tensor has shape ``(nlayers * ndirections, batch, num_hid)`` and
        the same type as the module's parameters. LSTM gets an ``(h0, c0)``
        tuple; GRU gets a single tensor.
        """
        # Any parameter works here; it only serves as a type/device template.
        weight = next(self.parameters()).data
        hid_shape = (self.nlayers * self.ndirections, batch, self.num_hid)
        # Variable wrapping is only required on pre-0.4 PyTorch; kept for
        # compatibility with the version this file was written against.
        if self.rnn_type == 'LSTM':
            return (Variable(weight.new(*hid_shape).zero_()),
                    Variable(weight.new(*hid_shape).zero_()))
        return Variable(weight.new(*hid_shape).zero_())

    def _run_rnn(self, x):
        """Run the RNN over ``x`` from a zero state; return every step's output."""
        hidden = self.init_hidden(x.size(0))
        # Re-pack the weights into one contiguous chunk so the fused
        # (cuDNN) path can be used.
        self.rnn.flatten_parameters()
        output, _ = self.rnn(x, hidden)
        return output

    def forward(self, x):
        """Return one fixed-size encoding per sequence.

        Args:
            x: tensor of shape ``(batch, seq_len, in_dim)`` (batch_first).

        Returns:
            Unidirectional: the output at the last time step,
            ``(batch, num_hid)``.
            Bidirectional: the forward direction at the last step concatenated
            with the backward direction at the first step (the step where
            each direction has consumed the whole sequence),
            ``(batch, 2 * num_hid)``.

        No sequence lengths are used, so "last step" is the final position of
        ``x`` as given, including any padding.
        """
        output = self._run_rnn(x)
        if self.ndirections == 1:
            return output[:, -1]
        forward_ = output[:, -1, :self.num_hid]
        backward = output[:, 0, self.num_hid:]
        return torch.cat((forward_, backward), dim=1)

    def forward_all(self, x):
        """Return the RNN output at every time step.

        Args:
            x: tensor of shape ``(batch, seq_len, in_dim)`` (batch_first).

        Returns:
            Tensor of shape ``(batch, seq_len, ndirections * num_hid)``.
        """
        return self._run_rnn(x)