Sequence Generation Algorithms Tackling Exposure Bias

Despite its computational simplicity and efficiency, maximum likelihood training of sequence generation models (e.g., RNNs) suffers from exposure bias (Ranzato et al., 2015): the model is trained to predict the next token given the previous ground-truth tokens, while at test time, since the ground truth is not available, the model's own previously generated tokens are fed back to make the next prediction. This discrepancy between training and test means that prediction mistakes can quickly accumulate.
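
As a rough illustration of the discrepancy, the sketch below contrasts the two decoding regimes. The model interface (encode, step, bos_token) is hypothetical and not part of this example's code; it only shows where ground-truth versus model-generated tokens enter the loop.

def decode(model, source, target=None, max_len=50):
    # Hypothetical interface, for illustration only.
    state = model.encode(source)
    token = model.bos_token
    outputs = []
    for t in range(max_len):
        logits, state = model.step(token, state)
        outputs.append(logits)
        if target is not None:
            token = target[t]          # training: condition on the ground-truth token
        else:
            token = logits.argmax()    # test: condition on the model's own prediction
    return outputs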

This example provides implementations of several classic and advanced training algorithms that tackle exposure bias. The base model is an attentional seq2seq model.

Usage

Dataset

Two example datasets are provided:

  • iwslt14: The benchmark IWSLT2014 (de-en) machine translation dataset, pre-processed following (Ranzato et al., 2015).
  • gigaword: The benchmark GIGAWORD text summarization dataset. For training efficiency, we sampled 200K of the 3.8M pre-processed training examples provided by (Rush et al., 2015), and used the refined validation and test sets provided by (Zhou et al., 2017).

Download the data with the following commands:

python utils/prepare_data.py --data iwslt14
python utils/prepare_data.py --data giga

Train the models

Baseline Attentional Seq2seq

python baseline_seq2seq_attn_main.py \
    --config_model configs.config_model \
    --config_data configs.config_iwslt14

Here:

  • --config_model specifies the model config. Note that the .py suffix should not be included.
  • --config_data specifies the data config.

configs.config_model.py specifies a single-layer seq2seq model with Luong attention and bi-directional RNN encoder. Hyperparameters taking default values can be omitted from the config file.

For demonstration purposes, configs.config_model_full.py lists all possible hyperparameters for the model. The two config files lead to the same model.
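
For a rough idea of what such a Texar-style config looks like (plain Python variables holding hyperparameter dicts), a minimal sketch is given below. The values are illustrative only; configs/config_model.py and configs/config_model_full.py remain the authoritative references.

# Illustrative sketch of a Texar-style model config; values are placeholders.
num_units = 256

embedder = {'dim': num_units}

encoder = {
    # bi-directional RNN encoder: forward/backward cells of `num_units` units
    'rnn_cell_fw': {'kwargs': {'num_units': num_units}},
    'rnn_cell_share_config': True,
}

decoder = {
    'rnn_cell': {'kwargs': {'num_units': num_units}},
    # Luong (multiplicative) attention over the encoder outputs
    'attention': {
        'kwargs': {'num_units': num_units},
        'attention_layer_size': num_units,
    },
}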

Reward Augmented Maximum Likelihood (RAML)

python raml_main.py \
    --config_model configs.config_model \
    --config_data configs.config_iwslt14 \
    --raml_file data/iwslt14/samples_iwslt14.txt \
    --n_samples 10

Here:

  • --raml_file specifies the file containing the augmented samples and rewards.
  • --n_samples specifies the number of augmented samples for each target sentence.
  • --tau specifies the temperature of the exponentiated payoff distribution in RAML.

The downloaded datasets include example files for --raml_file, containing augmented samples for iwslt14 and gigaword, respectively. Scripts for generating augmented samples yourself are provided in utils/raml_samples_generation.
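
To make the role of --tau and the rewards concrete, the sketch below shows the standard RAML reweighting, where each augmented target y~ of a ground-truth sentence y* is weighted by the exponentiated payoff distribution q(y~ | y*) ∝ exp(r(y~, y*) / tau). How raml_main.py consumes the sample file in detail should be checked against the script itself.

import numpy as np

def raml_weights(rewards, tau=0.4):
    # Normalized weights of the exponentiated payoff distribution over the
    # n_samples augmented targets of one ground-truth sentence.
    # The default tau here is only an example.
    scores = np.asarray(rewards, dtype=np.float64) / tau
    scores -= scores.max()              # for numerical stability
    weights = np.exp(scores)
    return weights / weights.sum()

# The RAML loss for one source sentence x is then the weighted negative
# log-likelihood of its augmented targets:  -sum_i w_i * log p(y~_i | x)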

Scheduled Sampling

python scheduled_sampling_main.py \
    --config_model configs.config_model \
    --config_data configs.config_iwslt14 \
    --decay_factor 500.

Here:

  • --decay_factor specifies the hyperparameter controlling how quickly the probability of sampling from the model increases during training (see the sketch below).
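
One common schedule for this probability is the inverse-sigmoid decay of Bengio et al. (2015), sketched below with decay_factor playing the role of the k parameter. Whether scheduled_sampling_main.py uses exactly this form should be verified against the script.

import math

def ground_truth_prob(step, decay_factor=500.0):
    # Inverse-sigmoid decay: probability of feeding the ground-truth token
    # at a given training step; the probability of sampling from the model
    # is one minus this value and grows as training proceeds.
    return decay_factor / (decay_factor + math.exp(step / decay_factor))

# With decay_factor = 500, the probability is ~1.0 early in training,
# drops to ~0.5 around step ~3100 (≈ k * ln(k)), and approaches 0 afterwards.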

Interpolation Algorithm

python interpolation_main.py \
    --config_model configs.config_model \
    --config_data configs.config_iwslt14 \
    --lambdas_init [0.04,0.96,0.0] \
    --delta_lambda_self 0.06 \
    --delta_lambda_reward 0.06 \
    --lambda_reward_steps 4

Here:

  • --lambdas_init specifies the initial value of the lambdas.
  • --delta_lambda_reward specifies the increment of lambda_reward at each annealing step.
  • --delta_lambda_self specifies the decrement of lambda_self at each annealing step.
  • --lambda_reward_steps specifies how many times lambda_reward is increased for each update of lambda_self (see the sketch below).
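
A very rough sketch of the idea follows: the lambdas form a mixing distribution over the different training signals and are gradually annealed toward the reward term. The ordering of the three lambdas and the exact annealing rule here are assumptions for illustration; interpolation_main.py is the authoritative reference.

def anneal_lambdas(lambdas, delta_self=0.06, delta_reward=0.06):
    # One annealing step (assumed ordering: [lambda_mle, lambda_self, lambda_reward]).
    lam_mle, lam_self, lam_reward = lambdas
    lam_self = max(lam_self - delta_self, 0.0)   # shift weight away from self samples
    lam_reward = lam_reward + delta_reward       # ... and toward the reward signal
    total = lam_mle + lam_self + lam_reward
    return [lam_mle / total, lam_self / total, lam_reward / total]

def interpolated_loss(loss_mle, loss_self, loss_reward, lambdas):
    # The training loss is the lambda-weighted mixture of the three signals.
    return lambdas[0] * loss_mle + lambdas[1] * loss_self + lambdas[2] * loss_reward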

Results

Machine Translation

Model               BLEU score
MLE                 26.44 ± 0.18
Scheduled Sampling  26.76 ± 0.17
RAML                27.22 ± 0.14
Interpolation       27.82 ± 0.11

Text Summarization

Model               ROUGE-1        ROUGE-2        ROUGE-L
MLE                 36.11 ± 0.21   16.39 ± 0.16   32.32 ± 0.19
Scheduled Sampling  36.59 ± 0.12   16.79 ± 0.22   32.77 ± 0.17
RAML                36.30 ± 0.24   16.69 ± 0.20   32.49 ± 0.17
Interpolation       36.72 ± 0.29   16.99 ± 0.17   32.95 ± 0.33