Despite its computational simplicity and efficiency, maximum likelihood training of sequence generation models (e.g., RNNs) suffers from exposure bias (Ranzato et al., 2015). That is, the model is trained to predict the next token given the previous ground-truth tokens, while at test time, since the resulting model does not have access to the ground truth, tokens generated by the model itself are used to make the next prediction instead. This discrepancy between training and test means that prediction mistakes can quickly accumulate.
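The minimal sketch below illustrates this discrepancy; the `model` API used here (`initial_state`, `step`, `cross_entropy`) is hypothetical and only serves to contrast teacher forcing with free-running decoding.

```python
# A minimal sketch of the train/test discrepancy. The `model` API is hypothetical.

def train_step(model, target_tokens):
    """Teacher forcing: every prediction is conditioned on ground-truth prefixes."""
    state = model.initial_state()
    loss = 0.0
    for t in range(1, len(target_tokens)):
        logits, state = model.step(target_tokens[t - 1], state)  # ground-truth input
        loss += model.cross_entropy(logits, target_tokens[t])
    return loss

def generate(model, bos_token, eos_token, max_len=50):
    """Free-running decoding: predictions are conditioned on the model's own
    outputs, so an early mistake is fed back in and can compound."""
    state = model.initial_state()
    token, output = bos_token, []
    for _ in range(max_len):
        logits, state = model.step(token, state)  # model-generated input
        token = logits.argmax()
        if token == eos_token:
            break
        output.append(token)
    return output
```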
This example provides implementations of several classic and advanced training algorithms that tackle exposure bias. The base model is an attentional seq2seq model.
- Maximum Likelihood (MLE): attentional seq2seq model with maximum likelihood training.
- Reward Augmented Maximum Likelihood (RAML): Described in (Norouzi et al., 2016); we use the sampling approach (n-gram replacement) of (Ma et al., 2017).
- Scheduled Sampling: Described in (Bengio et al., 2015).
- Interpolation Algorithm: Described in (Tan et al., 2018), Connecting the Dots Between MLE and RL for Sequence Generation.
Two example datasets are provided:
- iwslt14: The benchmark IWSLT2014 (de-en) machine translation dataset, following (Ranzato et al., 2015) for data pre-processing.
- gigaword: The benchmark GIGAWORD text summarization dataset. We sampled 200K out of the 3.8M pre-processed training examples provided by (Rush et al., 2015) for the sake of training efficiency. We used the refined validation and test sets provided by (Zhou et al., 2017).
Download the data with the following commands:
```
python utils/prepare_data.py --data iwslt14
python utils/prepare_data.py --data giga
```
Train the MLE baseline with:

```
python baseline_seq2seq_attn_main.py \
--config_model configs.config_model \
--config_data configs.config_iwslt14
```
Here:

- `--config_model` specifies the model config. Note not to include the `.py` suffix.
- `--config_data` specifies the data config.
`configs/config_model.py` specifies a single-layer seq2seq model with Luong attention and a bi-directional RNN encoder. Hyperparameters taking default values can be omitted from the config file.
For demonstration purposes, `configs/config_model_full.py` gives all possible hyperparameters for the model. The two config files lead to the same model.
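For reference, such a config module is an ordinary Python file defining nested hyperparameter dictionaries. The sketch below is illustrative only; the field names follow Texar's seq2seq conventions, and the exact values may differ from the shipped `configs/config_model.py`.

```python
# Illustrative sketch of a model config module (not the shipped file).
num_units = 256
beam_width = 10

# Word embedder for source and target vocabularies.
embedder = {
    'dim': num_units,
}

# Bi-directional RNN encoder; unspecified fields fall back to library defaults.
encoder = {
    'rnn_cell_fw': {
        'kwargs': {'num_units': num_units},
    },
}

# Attentional RNN decoder with Luong attention.
decoder = {
    'rnn_cell': {
        'kwargs': {'num_units': num_units},
    },
    'attention': {
        'kwargs': {'num_units': num_units},
        'attention_layer_size': num_units,
    },
}
```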
Train with RAML using:

```
python raml_main.py \
--config_model configs.config_model \
--config_data configs.config_iwslt14 \
--raml_file data/iwslt14/samples_iwslt14.txt \
--n_samples 10
```
Here:

- `--raml_file` specifies the file containing the augmented samples and rewards.
- `--n_samples` specifies the number of augmented samples for every target sentence.
- `--tau` specifies the temperature of the exponentiated payoff distribution in RAML.
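As a reminder of what `--tau` controls, RAML weights each augmented sample y' of a target y* according to the exponentiated payoff distribution q(y' | y*) ∝ exp(r(y', y*) / tau). The sketch below shows this normalization; the reward values are made up, and the actual file format of `--raml_file` is not reproduced here.

```python
import math

def exponentiated_payoff_weights(rewards, tau=0.4):
    """Turn per-sample rewards r(y', y*) into weights q(y' | y*) ∝ exp(r / tau).

    `rewards` holds one reward per augmented sample of the same target sentence;
    a smaller `tau` concentrates the weight on samples closer to the ground truth.
    """
    scores = [math.exp(r / tau) for r in rewards]
    total = sum(scores)
    return [s / total for s in scores]

# Example: three augmented samples of one target sentence, rewards in [0, 1].
print(exponentiated_payoff_weights([1.0, 0.8, 0.5], tau=0.4))
```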
In the downloaded datasets, we have provided example files for `--raml_file`, which include augmented samples for iwslt14 and gigaword, respectively. We also provide scripts for generating augmented samples yourself; please refer to `utils/raml_samples_generation`.
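To get a feel for what n-gram replacement sampling does before running those scripts, the toy sketch below corrupts a target sentence by overwriting one random n-gram and assigns a simple token-overlap reward. The actual generation scripts use their own reward definition and output format.

```python
import random

def ngram_replacement_sample(target_tokens, vocab, max_n=3):
    """Corrupt a target sentence by overwriting one random n-gram with random
    vocabulary tokens, and score the result with a toy token-overlap reward."""
    sample = list(target_tokens)
    n = random.randint(1, min(max_n, len(sample)))
    start = random.randint(0, len(sample) - n)
    for i in range(start, start + n):
        sample[i] = random.choice(vocab)
    # Toy reward: fraction of positions that still match the ground truth.
    reward = sum(s == t for s, t in zip(sample, target_tokens)) / len(target_tokens)
    return sample, reward

random.seed(0)
print(ngram_replacement_sample(
    "we also provide scripts for sample generation".split(),
    vocab=["the", "a", "model", "data", "token"]))
```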
Train with scheduled sampling using:

```
python scheduled_sampling_main.py \
--config_model configs.config_model \
--config_data configs.config_iwslt14 \
--decay_factor 500.
```
Here:

- `--decay_factor` specifies the hyperparameter controlling how quickly the probability of sampling from the model increases.
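Bengio et al. (2015) propose several decay schedules for the probability of feeding ground-truth tokens; a common choice is inverse-sigmoid decay, where a larger decay factor keeps teacher forcing around for longer. The sketch below assumes that schedule; whether `scheduled_sampling_main.py` uses exactly this form is an assumption.

```python
import math

def teacher_forcing_probability(step, decay_factor=500.0):
    """Inverse-sigmoid decay, eps_i = k / (k + exp(i / k)) (Bengio et al., 2015).

    Returns the probability of feeding the ground-truth token at training `step`;
    with probability 1 - eps_i the previous token is sampled from the model instead.
    A larger `decay_factor` makes the transition away from teacher forcing slower.
    """
    k = decay_factor
    return k / (k + math.exp(step / k))

# With decay_factor=500, model samples take over gradually as training proceeds.
for step in (0, 1000, 2000, 4000):
    print(step, round(teacher_forcing_probability(step), 3))
```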
Train with the interpolation algorithm using:

```
python interpolation_main.py \
--config_model configs.config_model \
--config_data configs.config_iwslt14 \
--lambdas_init [0.04,0.96,0.0] \
--delta_lambda_self 0.06 \
--delta_lambda_reward 0.06 \
--lambda_reward_steps 4
```
Here:

- `--lambdas_init` specifies the initial value of the lambdas.
- `--delta_lambda_reward` specifies the increment of lambda_reward at every annealing step.
- `--delta_lambda_self` specifies the decrement of lambda_self at every annealing step.
- `--lambda_reward_steps` specifies the number of times lambda_reward is increased after lambda_self is increased once.
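The three lambdas weight the maximum-likelihood, reward, and self-sampling components of the interpolated objective and are annealed over training according to the flags above. The sketch below only shows how such a weighted objective combines the per-batch losses; the ordering of the values in `--lambdas_init` and the exact annealing rule are defined in `interpolation_main.py` and are assumed here for illustration.

```python
def interpolated_loss(loss_mle, loss_reward, loss_self, lambdas):
    """Combine the three training signals with the current (annealed) weights.

    `lambdas` = [lambda_mle, lambda_reward, lambda_self] is an assumed ordering;
    the weights are non-negative and sum to one.
    """
    lam_mle, lam_reward, lam_self = lambdas
    return lam_mle * loss_mle + lam_reward * loss_reward + lam_self * loss_self

# Example with the initial weights from the command above (ordering assumed).
print(interpolated_loss(2.3, 2.5, 2.7, [0.04, 0.96, 0.0]))
```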
Results on the iwslt14 machine translation dataset:

Model | BLEU Score |
---|---|
MLE | 26.44 ± 0.18 |
Scheduled Sampling | 26.76 ± 0.17 |
RAML | 27.22 ± 0.14 |
Interpolation | 27.82 ± 0.11 |

Results on the gigaword summarization dataset:

Model | ROUGE-1 | ROUGE-2 | ROUGE-L |
---|---|---|---|
MLE | 36.11 ± 0.21 | 16.39 ± 0.16 | 32.32 ± 0.19 |
Scheduled Sampling | 36.59 ± 0.12 | 16.79 ± 0.22 | 32.77 ± 0.17 |
RAML | 36.30 ± 0.24 | 16.69 ± 0.20 | 32.49 ± 0.17 |
Interpolation | 36.72 ± 0.29 | 16.99 ± 0.17 | 32.95 ± 0.33 |