From c49bc5c011e795e85e2c8efe39b7c960d003e733 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky <38561704+bertsky@users.noreply.github.com> Date: Tue, 30 Oct 2018 15:01:38 +0000 Subject: [PATCH] fix decoder depth for SimpleSeq2Seq if decoder depth is given as 1, configure only 1 LSTMCell with output_dim (instead of 2) --- seq2seq/models.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/seq2seq/models.py b/seq2seq/models.py index 8583722..59ab382 100644 --- a/seq2seq/models.py +++ b/seq2seq/models.py @@ -66,20 +66,18 @@ def SimpleSeq2Seq(output_dim, output_length, hidden_dim=None, input_shape=None, decode=True, output_length=output_length) decoder.add(Dropout(dropout, batch_input_shape=(shape[0], hidden_dim))) - if depth[1] == 1: - decoder.add(LSTMCell(output_dim)) - else: + if depth[1] > 1: decoder.add(LSTMCell(hidden_dim)) for _ in range(depth[1] - 2): decoder.add(Dropout(dropout)) decoder.add(LSTMCell(hidden_dim)) - decoder.add(Dropout(dropout)) + decoder.add(Dropout(dropout)) decoder.add(LSTMCell(output_dim)) - _input = Input(batch_shape=shape) - x = encoder(_input) + input_ = Input(batch_shape=shape) + x = encoder(input_) output = decoder(x) - return Model(_input, output) + return Model(input_, output) def Seq2Seq(output_dim, output_length, batch_input_shape=None,