# Experiments

## Warming up the pause token on random pauses on GSM8K

In this experiment, we compare several ways of warming up our model on random pauses on GSM8K by unfreezing different combinations of model components. All experiments start from the same checkpoint, located at /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-08-28_13-23-45/final. This model was fine-tuned on GSM8K (without any pauses) for 1 epoch. The dataset with random pauses can be found at /dlabscratch1/baldwin/pause2/PauseToken/data/gsm8k_json/gsm8k_variable_random_pauses. Alternatively, you can generate it with the following command:

    python scripts/data_generation/gsm8k_pause_injector.py --dataset_location data/gsm8k_jsonl/gsm8k --pause_token "<|pause|>" --n_pauses_per_patterns '{}' --augm_dataset_save_location data/gsm8k_json/gsm8k_variable_random_pauses --pause_augm_col_name "answer" --verbose --n_random_pauses 10 --tokenizer_hf_name "/dlabscratch1/public/llm_weights/llm_hub/Mistral-7B-v0.1/" --variable_number_of_pauses --n_generated_samples_per_datapoint 5 --seed 42
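
For intuition, here is a minimal sketch of what the injection step amounts to: a variable number of `<|pause|>` tokens are scattered across each answer. The insertion-at-word-boundaries logic below is an illustrative assumption, not the script's actual implementation; the real behavior (tokenizer-aware, pattern-based) lives in `scripts/data_generation/gsm8k_pause_injector.py`:

```python
# Illustrative sketch of random pause injection (assumed logic, not the
# actual gsm8k_pause_injector.py implementation).
import random

PAUSE_TOKEN = "<|pause|>"

def inject_random_pauses(answer: str, max_pauses: int = 10, seed: int = 42) -> str:
    """Insert a variable number of pause tokens at random word boundaries."""
    rng = random.Random(seed)
    words = answer.split()
    # Mirrors --variable_number_of_pauses with --n_random_pauses 10.
    n_pauses = rng.randint(0, max_pauses)
    for _ in range(n_pauses):
        pos = rng.randint(0, len(words))  # any word boundary, incl. start/end
        words.insert(pos, PAUSE_TOKEN)
    return " ".join(words)

print(inject_random_pauses("Natalia sold 48 clips in April. #### 48"))
```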

### Training Specifications

All models are trained for 1 epoch on GSM8K with random pauses.

The table below lists, for each experiment, the components that are unfrozen during training, together with the corresponding training command. Components without an X stay frozen. A sketch of how these flags might translate into trainable parameters follows the table.

| experiment-yaml-file | Unfreeze Pause Embedding | Unfreeze Pause Head | Unfreeze LM Head | Unfreeze LM Embeddings | LoRA | python command |
|---|:-:|:-:|:-:|:-:|:-:|---|
| sft.yaml | X | X | | | | `python src/trl_train.py experiment=trl_train/step2_exp/sft.yaml` |
| sft_peft.yaml | X | X | | | X | `python src/trl_train.py experiment=trl_train/step2_exp/sft_peft.yaml` |
| sft_unfr_lm_head.yaml | X | X | X | | | `python src/trl_train.py experiment=trl_train/step2_exp/sft_unfr_lm_head.yaml` |
| sft_unfr_lm_head_peft.yaml | X | X | X | | X | `python src/trl_train.py experiment=trl_train/step2_exp/sft_unfr_lm_head_peft.yaml` |
| sft_unfr_lm_head_embed.yaml | X | X | X | X | | `python src/trl_train.py experiment=trl_train/step2_exp/sft_unfr_lm_head_embed.yaml` |
| sft_unfr_lm_head_embed_peft.yaml | X | X | X | X | X | `python src/trl_train.py experiment=trl_train/step2_exp/sft_unfr_lm_head_embed_peft.yaml` |
| sft_fr_phead.yaml | X | | | | | `python src/trl_train.py experiment=trl_train/step2_exp/sft_fr_phead.yaml` |
| sft_fr_phead_peft.yaml | X | | | | X | `python src/trl_train.py experiment=trl_train/step2_exp/sft_fr_phead_peft.yaml` |
| sft_fr_phead_unfr_lm_head.yaml | X | | X | | | `python src/trl_train.py experiment=trl_train/step2_exp/sft_fr_phead_unfr_lm_head.yaml` |
| sft_fr_phead_unfr_lm_head_peft.yaml | X | | X | | X | `python src/trl_train.py experiment=trl_train/step2_exp/sft_fr_phead_unfr_lm_head_peft.yaml` |
| sft_fr_phead_unfr_lm_head_embed.yaml | X | | X | X | | `python src/trl_train.py experiment=trl_train/step2_exp/sft_fr_phead_unfr_lm_head_embed.yaml` |
| sft_fr_phead_unfr_lm_head_embed_peft.yaml | X | | X | X | X | `python src/trl_train.py experiment=trl_train/step2_exp/sft_fr_phead_unfr_lm_head_embed_peft.yaml` |
| sft_fr_pembed.yaml | | X | | | | `python src/trl_train.py experiment=trl_train/step2_exp/sft_fr_pembed.yaml` |
| sft_fr_pembed_peft.yaml | | X | | | X | `python src/trl_train.py experiment=trl_train/step2_exp/sft_fr_pembed_peft.yaml` |
| sft_fr_pembed_unfr_lm_head.yaml | | X | X | | | `python src/trl_train.py experiment=trl_train/step2_exp/sft_fr_pembed_unfr_lm_head.yaml` |
| sft_fr_pembed_unfr_lm_head_peft.yaml | | X | X | | X | `python src/trl_train.py experiment=trl_train/step2_exp/sft_fr_pembed_unfr_lm_head_peft.yaml` |
| sft_fr_pembed_unfr_lm_head_embed.yaml | | X | X | X | | `python src/trl_train.py experiment=trl_train/step2_exp/sft_fr_pembed_unfr_lm_head_embed.yaml` |
| sft_fr_pembed_unfr_lm_head_embed_peft.yaml | | X | X | X | X | `python src/trl_train.py experiment=trl_train/step2_exp/sft_fr_pembed_unfr_lm_head_embed_peft.yaml` |
| baseline (model w/out pause; peft), 1 epoch | | | | | X | `python src/trl_train.py experiment=trl_train/step_1_sft.yaml` |
| baseline (model w/out pause; peft), 2 epochs | | | | | X | `python src/trl_train.py experiment=trl_train/step_1_sft.yaml trainer.args.num_train_epochs=2.0` |
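
As a rough illustration (not the actual implementation, which is driven by the experiment YAMLs), the flags in the table could map onto trainable parameters as follows. The attribute names `pause_embedding` and `pause_head`, and the LoRA hyperparameters, are assumptions:

```python
# Hedged sketch of how the unfreeze flags might configure a model.
import torch.nn as nn
from peft import LoraConfig, get_peft_model

def configure_experiment(
    model: nn.Module,
    unfreeze_pause_embedding: bool = False,
    unfreeze_pause_head: bool = False,
    unfreeze_lm_head: bool = False,
    unfreeze_lm_embeddings: bool = False,
    use_lora: bool = False,
) -> nn.Module:
    """Freeze the whole model, then selectively unfreeze components."""
    for param in model.parameters():
        param.requires_grad = False

    # Hypothetical attribute names -- the real wiring is defined by the
    # PauseToken model code and the experiment YAMLs.
    flags = {
        "pause_embedding": unfreeze_pause_embedding,
        "pause_head": unfreeze_pause_head,
        "lm_head": unfreeze_lm_head,
        "embed_tokens": unfreeze_lm_embeddings,
    }
    for attr, unfreeze in flags.items():
        if unfreeze and hasattr(model, attr):
            for param in getattr(model, attr).parameters():
                param.requires_grad = True

    if use_lora:
        # The LoRA column: low-rank adapters on the attention projections
        # remain trainable while the frozen base weights are only updated
        # through the low-rank deltas.
        lora_cfg = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"])
        model = get_peft_model(model, lora_cfg)
    return model
```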

### Results

| experiment-yaml-file | Test Accuracy | Eval Loss | Avg. Pauses per Reply | Avg. Pauses per Reply (correct) | Avg. Pauses per Reply (incorrect) |
|---|---|---|---|---|---|
| sft.yaml | 0.48 | 0.57003 | 1.57 | 0.99 | 2.12 |
| sft_peft.yaml | 0.56 | 0.51219 | 1.6 | 1.52 | 1.7 |
| sft_unfr_lm_head.yaml | 0.46 | 0.47131 | 1.33 | 0.95 | 1.64 |
| sft_unfr_lm_head_peft.yaml | 0.51 | 0.43576 | 5.62 | 3.1 | 8.34 |
| sft_unfr_lm_head_embed.yaml | 0.49 | 0.40649 | 2.25 | 1.86 | 2.62 |
| sft_unfr_lm_head_embed_peft.yaml | 0.50 | 0.37507 | 14.03 | 3.37 | 24.52 |
| sft_fr_phead.yaml | 0.14 | 1.99349 | 77.8 | 45.39 | 83.49 |
| sft_fr_phead_peft.yaml | 0.05 | 1.96408 | 648.07 | 52.18 | 681.97 |
| sft_fr_phead_unfr_lm_head.yaml | 0.10 | 1.92745 | 205.72 | 47.72 | 224.19 |
| sft_fr_phead_unfr_lm_head_peft.yaml | 0.02 | 2.31768 | 793.02 | 51.59 | 805.59 |
| sft_fr_phead_unfr_lm_head_embed.yaml | 0.08 | 1.9016 | 365.44 | 56.37 | 391.62 |
| sft_fr_phead_unfr_lm_head_embed_peft.yaml | 0.02 | 2.30675 | 798.37 | 55.84 | 816.24 |
| sft_fr_pembed.yaml | 0.46 | 0.77144 | 365.44 | 56.37 | 391.62 |
| sft_fr_pembed_peft.yaml | 0.55 | 0.51348 | 1.66 | 1.47 | 1.89 |
| sft_fr_pembed_unfr_lm_head.yaml | 0.46 | 0.54759 | 0.77 | 0.56 | 0.95 |
| sft_fr_pembed_unfr_lm_head_peft.yaml | 0.51 | 0.43743 | 5.02 | 3.36 | 6.77 |
| sft_fr_pembed_unfr_lm_head_embed.yaml | 0.47 | 0.40834 | 2.06 | 1.86 | 2.24 |
| sft_fr_pembed_unfr_lm_head_embed_peft.yaml | 0.01 | 2.40071 | 781.04 | 43.67 | 787.81 |
| baseline (model w/out pause; peft), 1 epoch | 0.47 | - | 0.0 | 0.0 | 0.0 |
| baseline (model w/out pause; peft), 2 epochs | 0.53 | - | 0.0 | 0.0 | 0.0 |

The table below lists, for each experiment, the predictions dump, the final model checkpoint, and the WandB run.

| experiment-yaml-file | Path to Predictions | Model Location | WandB Link |
|---|---|---|---|
| sft.yaml | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-37-50/test_results.json | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-37-50/final | click here |
| sft_peft.yaml | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-38-00/test_results.json | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-38-00/final | click here |
| sft_unfr_lm_head.yaml | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-38-08/test_results.json | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-38-08/final | click here |
| sft_unfr_lm_head_peft.yaml | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-38-21/test_results.json | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-38-21/final | click here |
| sft_unfr_lm_head_embed.yaml | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-42-59/test_results.json | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-42-59/final | click here |
| sft_unfr_lm_head_embed_peft.yaml | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-21_09-50-27/test_results.json | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-21_09-50-27/final | click here |
| sft_fr_phead.yaml | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-38-51/test_results.json | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-38-51/final | click here |
| sft_fr_phead_peft.yaml | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-38-56/test_results.json | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-38-56/final | click here |
| sft_fr_phead_unfr_lm_head.yaml | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-39-17/test_results.json | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-39-17/final | click here |
| sft_fr_phead_unfr_lm_head_peft.yaml | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-39-21/test_results.json | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-39-21/final | click here |
| sft_fr_phead_unfr_lm_head_embed.yaml | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-43-29/test_results.json | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-43-29/final | click here |
| sft_fr_phead_unfr_lm_head_embed_peft.yaml | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-21_09-50-28/test_results.json | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-21_09-50-28/final | click here |
| sft_fr_pembed.yaml | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-43-29/test_results.json | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-43-29/final | click here |
| sft_fr_pembed_peft.yaml | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-40-10/test_results.json | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-40-10/final | click here |
| sft_fr_pembed_unfr_lm_head.yaml | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-40-18/test_results.json | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-40-18/final | click here |
| sft_fr_pembed_unfr_lm_head_peft.yaml | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-40-26/test_results.json | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-40-26/final | click here |
| sft_fr_pembed_unfr_lm_head_embed.yaml | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-40-37/test_results.json | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-40-37/final | click here |
| sft_fr_pembed_unfr_lm_head_embed_peft.yaml | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-21_09-50-47/test_results.json | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-21_09-50-47/final | click here |
| baseline (model w/out pause; peft), 1 epoch | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-21_10-10-53/test_results.json | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-08-28_13-23-45/final | click here |
| baseline (model w/out pause; peft), 2 epochs | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-21_10-08-12/test_results.json | /dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-21_10-08-12/final | click here |
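
The pause statistics above can be recomputed from any of the `test_results.json` dumps. A minimal sketch, assuming the dump is a list of records each exposing a `prediction` string and an `is_correct` flag (both field names are guesses about the format, not a documented schema):

```python
# Hedged sketch: recompute accuracy and per-reply pause counts from a
# test_results.json dump. Field names are assumptions.
import json

PAUSE_TOKEN = "<|pause|>"

def pause_stats(path: str) -> dict:
    with open(path) as f:
        results = json.load(f)  # assumed: list of {"prediction": str, "is_correct": bool}
    counts = [(r["prediction"].count(PAUSE_TOKEN), r["is_correct"]) for r in results]
    mean = lambda xs: sum(xs) / len(xs) if xs else 0.0
    return {
        "accuracy": mean([float(ok) for _, ok in counts]),
        "avg_pauses": mean([c for c, _ in counts]),
        "avg_pauses_correct": mean([c for c, ok in counts if ok]),
        "avg_pauses_incorrect": mean([c for c, ok in counts if not ok]),
    }

print(pause_stats("logs/sft/runs/2024-10-18_17-37-50/test_results.json"))
```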

### Takeaways

It seems that the best way to 'pretrain' the model on random pauses is to unfreeze the Pause Head and the Pause Embedding while updating the base model with LoRA (sft_peft.yaml). Note, however, that training the pause embedding does not have much of an effect on accuracy. We also notice a general tendency (in this pretraining phase) for incorrect replies to contain more pauses, on average, than correct ones.