diff --git a/model/LSTM.lua b/model/LSTM.lua index c9a738bd..70498be3 100644 --- a/model/LSTM.lua +++ b/model/LSTM.lua @@ -28,7 +28,7 @@ function LSTM.lstm(input_size, rnn_size, n, dropout) end -- evaluate the input sums at once for efficiency local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x):annotate{name='i2h_'..L} - local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h):annotate{name='h2h_'..L} + local h2h = nn.Linear(rnn_size, 4 * rnn_size, false)(prev_h):annotate{name='h2h_'..L} -- no bias term local all_input_sums = nn.CAddTable()({i2h, h2h}) local reshaped = nn.Reshape(4, rnn_size)(all_input_sums)