Commit

Fix the AttentionWrapper(contrib_seq2seq) to run on version higher than 1.2
healess committed Dec 13, 2017
1 parent e5afbb9 commit da0defe
Showing 1 changed file with 27 additions and 17 deletions.
44 changes: 27 additions & 17 deletions RNN_seq2seq/contrib_seq2seq/04_AttentionWrapper.ipynb
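The hunks below move the notebook from the pre-1.2 tf.contrib.seq2seq attention API (DynamicAttentionWrapper / DynamicAttentionWrapperState) to the API introduced in TensorFlow 1.2 (AttentionWrapper with attention_layer_size, and an initial state built from the wrapper's own zero_state). For orientation, here is a minimal sketch of the post-1.2 pattern; it is not taken from the notebook. encoder_outputs, encoder_final_state, source_lengths, batch_size and num_units are placeholder names, and the .clone(cell_state=...) step is shown only as a common option; this commit calls zero_state() directly.

import tensorflow as tf

def build_attention_decoder(encoder_outputs, encoder_final_state,
                            source_lengths, batch_size, num_units):
    # Luong attention over the encoder outputs. In 1.2+ LuongAttention takes
    # `scale`, not `normalize`, which is why the diff comments that argument out.
    attn_mech = tf.contrib.seq2seq.LuongAttention(
        num_units=num_units,
        memory=encoder_outputs,
        memory_sequence_length=source_lengths)

    dec_cell = tf.contrib.rnn.BasicLSTMCell(num_units)
    # AttentionWrapper replaces DynamicAttentionWrapper, and
    # `attention_layer_size` replaces the old `attention_size` argument.
    dec_cell = tf.contrib.seq2seq.AttentionWrapper(
        cell=dec_cell,
        attention_mechanism=attn_mech,
        attention_layer_size=num_units)

    # The wrapper now builds its own AttentionWrapperState; clone() optionally
    # seeds it with the encoder's final state instead of all zeros.
    initial_state = dec_cell.zero_state(batch_size, tf.float32).clone(
        cell_state=encoder_final_state)
    return dec_cell, initial_state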
@@ -556,19 +556,29 @@
 "            num_units=self.attn_size,\n",
 "            memory=self.enc_outputs,\n",
 "            memory_sequence_length=self.enc_sequence_length,\n",
-"            normalize=False,\n",
+"            # normalize=False,\n",
 "            name='LuongAttention')\n",
 "\n",
-"        dec_cell = tf.contrib.seq2seq.DynamicAttentionWrapper(\n",
+"        \n",
+"        dec_cell = tf.contrib.seq2seq.AttentionWrapper(\n",
 "            cell=dec_cell,\n",
 "            attention_mechanism=attn_mech,\n",
-"            attention_size=self.attn_size,\n",
-"            # attention_history=False (in ver 1.2)\n",
+"            attention_layer_size=self.attn_size,\n",
+"            # attention_history=False, # (in ver 1.2)\n",
 "            name='Attention_Wrapper')\n",
 "        \n",
-"        initial_state = tf.contrib.seq2seq.DynamicAttentionWrapperState(\n",
-"            cell_state=self.enc_last_state,\n",
-"            attention=_zero_state_tensors(self.attn_size, batch_size, tf.float32))\n",
+"#         outputs = tf.contrib.rnn.OutputProjectionWrapper(\n",
+"#             dec_cell, batch_size, reuse=False\n",
+"#         )\n",
+"        \n",
+"        initial_state=dec_cell.zero_state(dtype=tf.float32, batch_size=batch_size)\n",
+"\n",
+"#         initial_state = tf.contrib.seq2seq.AttentionWrapperState(\n",
+"#             cell_state=self.enc_last_state,\n",
+"#             attention=_zero_state_tensors(self.attn_size, batch_size, tf.float32),\n",
+"#             time=0, alignments=(), \n",
+"#             alignment_history=()\n",
+"#         )\n",
 "\n",
 "        # output projection (replacing `OutputProjectionWrapper`)\n",
 "        output_layer = Dense(dec_vocab_size+2, name='output_projection')\n",
@@ -593,7 +603,7 @@
 "            initial_state=initial_state,\n",
 "            output_layer=output_layer) \n",
 "\n",
-"        train_dec_outputs, train_dec_last_state = tf.contrib.seq2seq.dynamic_decode(\n",
+"        train_dec_outputs, train_dec_last_state, _ = tf.contrib.seq2seq.dynamic_decode(\n",
 "            training_decoder,\n",
 "            output_time_major=False,\n",
 "            impute_finished=True,\n",
@@ -642,7 +652,7 @@
 "            initial_state=initial_state,\n",
 "            output_layer=output_layer)\n",
 "        \n",
-"        infer_dec_outputs, infer_dec_last_state = tf.contrib.seq2seq.dynamic_decode(\n",
+"        infer_dec_outputs, infer_dec_last_state, _ = tf.contrib.seq2seq.dynamic_decode(\n",
 "            inference_decoder,\n",
 "            output_time_major=False,\n",
 "            impute_finished=True,\n",
@@ -659,7 +669,7 @@
 "        self.training_op = self.optimizer(self.learning_rate, name='training_op').minimize(self.batch_loss)\n",
 "        \n",
 "    def save(self, sess, var_list=None, save_path=None):\n",
-"        print(f'Saving model at {save_path}')\n",
+"        print('Saving model at {save_path}')\n",
 "        if hasattr(self, 'training_variables'):\n",
 "            var_list = self.training_variables\n",
 "        saver = tf.train.Saver(var_list)\n",
@@ -741,10 +751,10 @@
 "                    print('Epoch', epoch)\n",
 "                    for input_batch, target_batch, batch_preds in zip(input_batches, target_batches, all_preds):\n",
 "                        for input_sent, target_sent, pred in zip(input_batch, target_batch, batch_preds):\n",
-"                            print(f'\\tInput: {input_sent}')\n",
-"                            print(f'\\tPrediction:', idx2sent(pred, reverse_vocab=dec_reverse_vocab))\n",
-"                            print(f'\\tTarget:, {target_sent}')\n",
-"                    print(f'\\tepoch loss: {epoch_loss:.2f}\\n')\n",
+"                            print('\\tInput: {input_sent}')\n",
+"                            print('\\tPrediction:', idx2sent(pred, reverse_vocab=dec_reverse_vocab))\n",
+"                            print('\\tTarget:, {target_sent}')\n",
+"                    print('\\tepoch loss: {epoch_loss:.2f}\\n')\n",
 "            \n",
 "            if save_path:\n",
 "                self.save(sess, save_path=save_path)\n",
@@ -1103,7 +1113,7 @@
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
-   "version": 3
+   "version": 3.0
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
@@ -1114,5 +1124,5 @@
   }
  },
  "nbformat": 4,
- "nbformat_minor": 2
-}
+ "nbformat_minor": 0
+}
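A related change in the decode hunks above: from TensorFlow 1.2 onward, tf.contrib.seq2seq.dynamic_decode returns three values (outputs, final state, final sequence lengths) instead of two, which is why the commit adds a discarded third variable. A minimal sketch under that assumption, with decoder standing in for the notebook's BasicDecoder instance:

# TF >= 1.2: dynamic_decode returns a 3-tuple; the new third value holds the
# final sequence lengths and is simply discarded with `_` in this commit.
outputs, final_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
    decoder,
    output_time_major=False,
    impute_finished=True)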

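One caveat about the print changes in the two hunks that touch save() and the training loop: dropping only the f prefix while keeping the {} placeholders prints the braces literally instead of interpolating the values. If the intent is to avoid f-strings (a Python 3.6 feature), str.format keeps the original output on older interpreters; a sketch using the same variable names as the notebook:

# Equivalent output without f-strings (works on Python < 3.6).
print('\tInput: {}'.format(input_sent))
print('\tPrediction:', idx2sent(pred, reverse_vocab=dec_reverse_vocab))
print('\tTarget: {}'.format(target_sent))
print('\tepoch loss: {:.2f}\n'.format(epoch_loss))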