-- LanguageModel.lua (forked from jcjohnson/torch-rnn)
require 'torch'
require 'nn'
require 'VanillaRNN'
require 'LSTM'
require 'cudnn'
local utils = require 'util.utils'
local utf8 = require 'lua-utf8'
local LM, parent = torch.class('nn.LanguageModel', 'nn.Module')

function LM:__init(kwargs)
  self.idx_to_token = utils.get_kwarg(kwargs, 'idx_to_token')
  self.token_to_idx = {}
  self.vocab_size = 0
  for idx, token in pairs(self.idx_to_token) do
    self.token_to_idx[token] = idx
    self.vocab_size = self.vocab_size + 1
  end

  self.model_type = utils.get_kwarg(kwargs, 'model_type')
  self.wordvec_dim = utils.get_kwarg(kwargs, 'wordvec_size')
  self.rnn_size = utils.get_kwarg(kwargs, 'rnn_size')
  self.num_layers = utils.get_kwarg(kwargs, 'num_layers')
  self.dropout = utils.get_kwarg(kwargs, 'dropout')
  self.batchnorm = utils.get_kwarg(kwargs, 'batchnorm')
  local usecudnn = utils.get_kwarg(kwargs, 'cudnn')
  if utils.get_kwarg(kwargs, 'cudnn_fastest') then
    cudnn.fastest = true
  end

  local V, D, H = self.vocab_size, self.wordvec_dim, self.rnn_size

  self.net = nn.Sequential()
  self.rnns = {}
  self.bn_view_in = {}
  self.bn_view_out = {}

  self.net:add(nn.LookupTable(V, D))
  if usecudnn == 0 then
    -- Build the recurrent stack layer by layer from the torch-rnn modules
    for i = 1, self.num_layers do
      local prev_dim = H
      if i == 1 then prev_dim = D end
      local rnn
      if self.model_type == 'rnn' then
        rnn = nn.VanillaRNN(prev_dim, H)
      elseif self.model_type == 'lstm' then
        rnn = nn.LSTM(prev_dim, H)
      end
      rnn.remember_states = true
      table.insert(self.rnns, rnn)
      self.net:add(rnn)
      if self.batchnorm == 1 then
        local view_in = nn.View(1, 1, -1):setNumInputDims(3)
        table.insert(self.bn_view_in, view_in)
        self.net:add(view_in)
        self.net:add(nn.BatchNormalization(H))
        local view_out = nn.View(1, -1):setNumInputDims(2)
        table.insert(self.bn_view_out, view_out)
        self.net:add(view_out)
      end
      if self.dropout > 0 then
        self.net:add(nn.Dropout(self.dropout))
      end
    end
  else
    -- Use cuDNN's fused multi-layer RNN implementation instead
    local rnn
    local batchFirst = true
    if self.model_type == 'rnn' then
      rnn = cudnn.RNNTanh(D, H, self.num_layers, batchFirst, self.dropout, true)
    elseif self.model_type == 'lstm' then
      rnn = cudnn.LSTM(D, H, self.num_layers, batchFirst, self.dropout, true)
    end
    rnn:resetDropoutDescriptor()
    if not batchFirst then
      self.net:add(nn.Transpose({1, 2}))
      -- contiguous not needed after transpose, as transpose makes outputs contiguous
    else
      self.net:add(nn.Contiguous())
    end
    self.net:add(rnn)
    if not batchFirst then
      self.net:add(nn.Transpose({1, 2}))
    else
      self.net:add(nn.Contiguous())
    end
    if self.dropout > 0 then
      self.net:add(nn.Dropout(self.dropout))
    end
    table.insert(self.rnns, rnn)
  end

  -- After all the RNNs run, we will have a tensor of shape (N, T, H);
  -- we want to apply a 1D temporal convolution to predict scores for each
  -- vocab element, giving a tensor of shape (N, T, V). Unfortunately
  -- nn.TemporalConvolution is SUPER slow, so instead we will use a pair of
  -- views (N, T, H) -> (NT, H) and (NT, V) -> (N, T, V) with a nn.Linear in
  -- between. Unfortunately N and T can change on every minibatch, so we need
  -- to set them in the forward pass.
  self.view1 = nn.View(1, 1, -1):setNumInputDims(3)
  self.view2 = nn.View(1, -1):setNumInputDims(2)

  self.net:add(self.view1)
  self.net:add(nn.Linear(H, V))
  self.net:add(self.view2)
end
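
--[[
A minimal construction sketch (an illustration, not code from this repo): the
kwargs below mirror the fields read by LM:__init above; in practice the torch-rnn
training script builds this table from command-line options and the preprocessed
vocabulary. The three-token vocabulary is hypothetical.

local idx_to_token = {'a', 'b', 'c'}
local model = nn.LanguageModel{
  idx_to_token = idx_to_token,
  model_type = 'lstm',
  wordvec_size = 64,
  rnn_size = 128,
  num_layers = 2,
  dropout = 0,
  batchnorm = 0,
  cudnn = 0,               -- 0: nn.VanillaRNN / nn.LSTM stack; 1: fused cudnn RNN
  cudnn_fastest = false,   -- only false/nil are falsy in Lua, so pass false (not 0)
}
--]]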

function LM:updateOutput(input)
  local N, T = input:size(1), input:size(2)
  self.view1:resetSize(N * T, -1)
  self.view2:resetSize(N, T, -1)

  for _, view_in in ipairs(self.bn_view_in) do
    view_in:resetSize(N * T, -1)
  end
  for _, view_out in ipairs(self.bn_view_out) do
    view_out:resetSize(N, T, -1)
  end

  return self.net:forward(input)
end


function LM:backward(input, gradOutput, scale)
  return self.net:backward(input, gradOutput, scale)
end


function LM:parameters()
  return self.net:parameters()
end


function LM:training()
  self.net:training()
  parent.training(self)
end


function LM:evaluate()
  self.net:evaluate()
  parent.evaluate(self)
end


function LM:resetStates()
  for i, rnn in ipairs(self.rnns) do
    rnn:resetStates()
  end
end

function LM:encode_string(s)
  local encoded = torch.LongTensor(utf8.len(s))
  for i = 1, utf8.len(s) do
    local token = utf8.sub(s, i, i)
    local idx = self.token_to_idx[token]
    assert(idx ~= nil, 'Got invalid idx')
    encoded[i] = idx
  end
  return encoded
end


function LM:decode_string(encoded)
  assert(torch.isTensor(encoded) and encoded:dim() == 1)
  local s = {}
  for i = 1, encoded:size(1) do
    local idx = encoded[i]
    local token = self.idx_to_token[idx]
    table.insert(s, token)
  end
  return table.concat(s)
end
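
--[[
Hedged round-trip sketch (assumes 'a', 'b', and 'c' are in the vocabulary; the
string is illustrative only):

local x = model:encode_string('abc')  -- LongTensor of size 3 holding token indices
print(model:decode_string(x))         -- prints 'abc' again
--]]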

--[[
Sample from the language model. Note that this will reset the states of the
underlying RNNs.

Inputs (passed in kwargs, with defaults as read in the code below):
- length: Number of characters to sample (default 100)
- start_text: String used to seed the model before sampling (default '')
- start_tokens: Path to a JSON file whose 'tokens' array seeds the model (default '')
- sample: 0 to take the argmax over scores at each timestep, 1 to sample from the
  softmax distribution (default 1)
- temperature: Softmax temperature used when sample == 1 (default 1)
- stream: 1 to write characters to stdout as they are sampled (default 0)
- verbose: 1 to print information about the seed text (default 0)

Returns:
- sampled: String of length `length`; if seed text is given it forms the prefix.
--]]
function LM:sample(kwargs)
  local T = utils.get_kwarg(kwargs, 'length', 100)
  local start_text = utils.get_kwarg(kwargs, 'start_text', '')
  local verbose = utils.get_kwarg(kwargs, 'verbose', 0)
  local sample = utils.get_kwarg(kwargs, 'sample', 1)
  local temperature = utils.get_kwarg(kwargs, 'temperature', 1)
  local start_tokens = utils.get_kwarg(kwargs, 'start_tokens', '')
  local stream = utils.get_kwarg(kwargs, 'stream', 0)

  local sampled = torch.LongTensor(1, T)
  self:resetStates()

  local scores, first_t
  if #start_tokens > 0 then
    -- Seed the model with token indices read from a JSON file
    local json_tokens = utils.read_json(start_tokens)
    local num_tokens = #json_tokens.tokens
    local tokenTensor = torch.LongTensor(num_tokens)
    for i = 1, num_tokens do
      tokenTensor[i] = json_tokens.tokens[i]
    end
    local x = tokenTensor:view(1, -1)
    local T0 = x:size(2)
    sampled[{{}, {1, T0}}]:copy(x)
    scores = self:forward(x)[{{}, {T0, T0}}]
    first_t = T0 + 1
  elseif #start_text > 0 then
    -- Seed the model with the provided start text
    if verbose > 0 then
      print('Seeding with: "' .. start_text .. '"')
    end
    local x = self:encode_string(start_text):view(1, -1)
    if stream == 1 then
      io.write(start_text)
    end
    local T0 = x:size(2)
    sampled[{{}, {1, T0}}]:copy(x)
    scores = self:forward(x)[{{}, {T0, T0}}]
    first_t = T0 + 1
  else
    if verbose > 0 then
      print('Seeding with uniform probabilities')
    end
    local w = self.net:get(1).weight
    scores = w.new(1, 1, self.vocab_size):fill(1)
    first_t = 1
  end

  local _, next_char = nil, nil
  for t = first_t, T do
    if sample == 0 then
      _, next_char = scores:max(3)
      next_char = next_char[{{}, {}, 1}]
    else
      local probs = torch.div(scores, temperature):double():exp():squeeze()
      probs:div(torch.sum(probs))
      next_char = torch.multinomial(probs, 1):view(1, 1)
    end
    sampled[{{}, {t, t}}]:copy(next_char)
    if stream == 1 then
      io.write(self.idx_to_token[next_char[1][1]])
    end
    scores = self:forward(next_char)
  end

  self:resetStates()
  return self:decode_string(sampled[1])
end

function LM:clearState()
  self.net:clearState()
end
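
--[[
A sampling sketch under assumptions: 'checkpoint.t7' is an illustrative path to a
checkpoint saved by the torch-rnn training script, which typically stores the model
under a 'model' field. All kwargs fall back to the defaults read in LM:sample above;
the values shown are illustrative.

local checkpoint = torch.load('checkpoint.t7')
local model = checkpoint.model
model:evaluate()
local text = model:sample{
  length = 200,
  start_text = 'The ',
  temperature = 0.9,
  sample = 1,
}
print(text)
--]]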