This file gives instructions for how to set up and run the experiment on next-sentence continuation. We will go over:
- How to create the dataset of prompts that the model is trained on.
- How to train the GFlowNet fine-tuned model on the dataset of prompts.
- The specific configuration used in our paper.
The requirements can be installed using pip:

```bash
pip install -r requirements.txt
```
Note that this experiment has to be run from the `next_sentence` directory for all imports and relative paths to function correctly.
The dataset consists of 1000 prompts obtained from Open Web Text. Each prompt is guaranteed to have between 1 and 3 sentences, each limited to a minimum of 5 and a maximum of 30 tokens under the GPT-2 tokenizer. The data is already included in the repository under `data/openwebtext/prompts.txt`, but to see how it was generated (and to regenerate it yourself), run:

```bash
python data/openwebtext/data_generator.py
```

The resulting prompts will be saved to `data/openwebtext/prompts.txt`, and some statistics about them will be saved to `data/openwebtext/prompt_stats.png`.
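The exact selection logic lives in `data/openwebtext/data_generator.py`; the sketch below only illustrates the kind of length filter described above. The crude regex sentence splitter is an assumption for this sketch and is not necessarily what the script uses; the tokenizer is the GPT-2 tokenizer mentioned above.

```python
# Illustrative only -- not the repository's data_generator.py.
# The regex sentence splitter is an assumption for this sketch.
import re
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")


def keep_prompt(text: str, max_sentences: int = 3,
                min_tokens: int = 5, max_tokens: int = 30) -> bool:
    """Return True if the prompt has 1-3 sentences, each 5-30 GPT-2 tokens long."""
    sentences = [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]
    if not 1 <= len(sentences) <= max_sentences:
        return False
    return all(
        min_tokens <= len(tokenizer.encode(s)) <= max_tokens
        for s in sentences
    )


# Example: a two-sentence prompt that satisfies both constraints.
print(keep_prompt("GFlowNet fine-tuning samples continuations in proportion "
                  "to a reward. This encourages diverse generations."))
```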
To train the GFlowNet, run:
```bash
python train.py task=open_webtext_gpt2 device=gpu
```
The code uses PyTorch Lightning to train the model, wandb to log results during training, and hydra to manage configurations.
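For orientation, here is a minimal sketch of how these three pieces typically fit together. It is not the repository's `train.py`: the config path and name, the dummy module, the toy data, and the wandb project name are all placeholders for illustration.

```python
# Minimal hydra + PyTorch Lightning + wandb wiring sketch (not the real train.py).
import hydra
import torch
import pytorch_lightning as pl
from omegaconf import DictConfig
from pytorch_lightning.loggers import WandbLogger


class DummyModule(pl.LightningModule):
    """Stand-in for the real GFlowNet fine-tuning LightningModule."""

    def __init__(self, lr: float = 1e-4):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)
        self.lr = lr

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.log("train/loss", loss)  # forwarded to wandb by the logger
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


@hydra.main(config_path="configs", config_name="train")  # placeholder config name
def main(cfg: DictConfig) -> None:
    # hydra builds `cfg` from the YAML defaults plus any command-line overrides
    model = DummyModule(lr=cfg.get("lr", 1e-4))
    trainer = pl.Trainer(
        accelerator="gpu" if cfg.get("device", "cpu") == "gpu" else "cpu",
        max_steps=10,
        logger=WandbLogger(project="next-sentence-demo", offline=True),
        accumulate_grad_batches=cfg.get("accumulate_grad_batches", 1),
    )
    # A real run would build a DataLoader from the prompt dataset here.
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(torch.randn(64, 8), torch.randn(64, 1)),
        batch_size=8,
    )
    trainer.fit(model, train_dataloaders=train_loader)


if __name__ == "__main__":
    main()
```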
We train with a specific set of arguments in our paper. These arguments can be specified at the command line and modify the defaults from `configs/task/openwebtext_gpt2.yaml`. To replicate our results, run:

```bash
python train.py task=open_webtext_gpt2 device=gpu task.training.n_samples=8 task.training.accumulate_grad_batches=32 task.reward.temp_end=[temp_end]
```

where `[temp_end]` is a placeholder that must be specified. In our paper, we sweep over (0.8, 0.95) in increments of 0.25.
Note that this configuration requires approximately 40GB of GPU RAM.
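If you want to check what these overrides resolve to before spending GPU hours, a compose-based dry run like the one below can help. This is a sketch under assumptions: the top-level config name and the relative config path are guesses, and `temp_end=0.8` is just one value from the sweep.

```python
# Optional sanity check (not part of the repository): resolve the config with
# hydra's compose API and print the merged settings before launching a long run.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(config_path="configs"):          # assumed config directory
    cfg = compose(
        config_name="train",                     # placeholder top-level config name
        overrides=[
            "task=open_webtext_gpt2",
            "device=gpu",
            "task.training.n_samples=8",
            "task.training.accumulate_grad_batches=32",
            "task.reward.temp_end=0.8",          # one value from the sweep
        ],
    )

print(OmegaConf.to_yaml(cfg))
```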