Skip to content

Commit

Permalink
bug fix and cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
saiprabhakar committed Jan 5, 2025
1 parent 6131210 commit ce973c2
Show file tree
Hide file tree
Showing 13 changed files with 1,683 additions and 1,266 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ data/*
results/*
.ipynb_checkpoints/
.DS_Store
*.zip
*.zip
data/*/*/cache*
89 changes: 46 additions & 43 deletions DPO_trainer.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,60 @@
from transformers import HfArgumentParser
from trainer import ScriptArguments, load_dataset, trainer
import wandb
from trainer import ScriptArguments, load_dataset, trainer

parser = HfArgumentParser(ScriptArguments)

# for DPO
script_args = parser.parse_args_into_dataclasses(
args=[
'--per_device_train_batch_size', '1',
'--per_device_eval_batch_size', '1',
'--gradient_accumulation_steps', '1',
'--model_name_or_path', 'sshleifer/tiny-gpt2',
# '--model_name_or_path', 'results/avs/BASE_model/gpt2',
# '--model_name_or_path', 'gpt2',
# '--model_name_or_path', 'meta-llama/Llama-2-7b-hf',
'--load_in_4bit',
'--use_peft',
# '--learning_rate', '1e-3',
'--learning_rate', '1e-4',
# '--report_to', 'wandb',
'--run_name', 'DPO-avs-gpt2',
'--max_length', '1024',
'--max_prompt_length', '768',
'--num_train_epochs', '5',
'--max_steps', '-1',
'--evaluation_strategy', 'epoch',
'--eval_steps', '-1',
# '--eval_first_step',
'--logging_strategy', 'steps',
'--log_steps', '20',
'--logging_first_step',
# '--save_strategy', 'epoch',
'--save_strategy', 'steps',
'--save_steps', '10000000',
# '--save_total_limit', '3',
# '--load_best_model_at_end',
# '--metric_for_best_model', 'metrics_policy_rouge1',
# '--alignment_function', 'dpo',
'--output_dir', './results/avs/DPO_model/DPO-avs-gpt2(1|1|0.3)',
'--alpha1', '1.0', #sft loss
'--alpha2', '1.0', #dpo loss
'--beta', '0.3',
])[0]
'--per_device_train_batch_size', '2',
'--per_device_eval_batch_size', '2',
'--gradient_accumulation_steps', '4',
'--model_name_or_path', 'gpt2',
# '--model_name_or_path', 'sshleifer/tiny-gpt2',
# '--model_name_or_path', 'huggy llama/llama-7b',
# '--model_name_or_path', 'meta-llama/Llama-2-7b-hf',
'--load_in_4bit',
'--use_peft',
'--learning_rate', '1e-4',
# '--report_to', 'wandb',
'--run_name', 'DPO-avs-gpt2',
'--max_length', '1024',
'--max_prompt_length', '768',
'--num_train_epochs', '1',
'--max_steps', '-1',
'--evaluation_strategy', 'epoch',
'--eval_steps', '-1',
'--logging_strategy', 'steps',
'--log_steps', '10',
'--logging_first_step',
'--save_strategy', 'epoch',
'--save_steps', '-1',
'--save_total_limit', '3',
'--load_best_model_at_end',
'--metric_for_best_model', 'metrics_policy_rouge1',
'--output_dir', './results/avs/DPO_model/DPO-avs-gpt2(1|1|0.3)',
"--alpha1", "1.0", # sft loss
"--alpha2", "1.0", # dpo loss
"--beta", "0.3",
]
)[0]

# Initialize wandb if reporting to wandb
if script_args.report_to == 'wandb':
if script_args.report_to == "wandb":
wandb.init(project=script_args.run_name)

# 2. Load training dataset
# train_dataset = load_dataset("train", sanity_check=script_args.sanity_check)
train_dataset = load_dataset("sub_eval", sanity_check=script_args.sanity_check, alignment_function=script_args.alignment_function)
data_subset = "sub_eval_w_simulated_edits"
train_dataset = load_dataset(
data_subset,
sanity_check=script_args.sanity_check,
alignment_function=script_args.alignment_function,
)

# 3. Load evaluation dataset
eval_dataset = load_dataset("sub_eval", sanity_check=script_args.sanity_check, alignment_function=script_args.alignment_function)
eval_dataset = load_dataset(
data_subset,
sanity_check=True,
alignment_function=script_args.alignment_function,
)

dpo_trainer = trainer(script_args, train_dataset, eval_dataset)
dpo_trainer = trainer(script_args, train_dataset, eval_dataset)
60 changes: 8 additions & 52 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,48 +1,17 @@
# LearnFromHumanEdit

## Installation
If using `conda`, you can get this to work as follows:
## Poetry installation

```
conda create -n salt python=3.8
conda activate salt
```
`poetry install`

We have experimented with 11.7 and 10.2 cuda version, but this release should work with more recent versions as well.
```
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=10.2 -c pytorch
```
or
- Add HG auth token to the project by creating a `hg_secret` file
- `python -m spacy download en_core_web_sm`
- If you encountered a problem with poetry installation with torch versions (python 3.10) do:

```
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
```

Install other packages:
```
conda install -c conda-forge matplotlib
conda install -c conda-forge spacy
conda install -c conda-forge scipy
python -m spacy download en_core_web_sm
pip install nltk
pip install ipdb
pip install rouge
pip install rouge-score
pip install trl
pip install minineedle
pip install nltk
```poetry run pip install torch==2.1.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121```

pip install datasets
pip install transformers
```
If you want to use qlora for llm:
```
pip install -q -U bitsandbytes
pip install -q -U git+https://github.com/huggingface/peft.git
pip install -q -U git+https://github.com/huggingface/accelerate.git
```

## Run the trainer
## Run training

```
python DPO_trainer.py
Expand All @@ -51,18 +20,5 @@ python SALT_trainer.py
```

## TODO
- Adapt the codes *_trainer.py
- Save output models
- Save outputs
- Modify the classes in dpo.py and rename it to be more generic
- Add link to paper and bib
- Add dataset
- Do we need wandb instructions



## Poetry installation

If you encountered a problem with poetry installation with torch versions (python 3.10)

`poetry run pip install torch==2.1.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121`
- Add full dataset
102 changes: 55 additions & 47 deletions SALT_trainer.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,64 @@
from transformers import HfArgumentParser
import wandb
from trainer import ScriptArguments, load_dataset, trainer

parser = HfArgumentParser(ScriptArguments)

# for SALT
script_args = parser.parse_args_into_dataclasses(args=['--per_device_train_batch_size', '1',
'--per_device_eval_batch_size', '1',
'--gradient_accumulation_steps', '1',
'--model_name_or_path', 'sshleifer/tiny-gpt2',
# '--model_name_or_path', 'gpt2',
# '--model_name_or_path', 'results/avs/BASE_model/gpt2',
# '--model_name_or_path', 'huggy llama/llama-7b',
# '--model_name_or_path', 'meta-llama/Llama-2-7b-hf',
'--load_in_4bit',
'--use_peft',
'--learning_rate', '1e-3',
# '--learning_rate', '1e-4',
# '--report_to', 'wandb',
# '--run_name', 'SALT-avs-llama2',
'--max_length', '1024',
'--max_prompt_length', '768',
'--num_train_epochs', '5',
'--max_steps', '-1',
'--evaluation_strategy', 'epoch',
'--eval_steps', '-1',
# '--eval_first_step',
'--logging_strategy', 'steps',
'--log_steps', '2',
'--logging_first_step',
# '--save_strategy', 'epoch',
'--save_strategy', 'steps',
'--save_steps', '10000000',
# '--save_total_limit', '3',
# '--load_best_model_at_end',
# '--metric_for_best_model', 'metrics_policy_rouge1',
'--alignment_function', 'salt',
'--output_dir', './results/avs/SALT_model/SALT-avs-llama2(1|-0.1|-0.1|1|1.1|1.1)',
'--omega1', '1.0', #salt chosen likelihood loss weight
'--omega2', '0.1', #salt rejected unlikelihood loss weight
'--S_generated_C_weight', '1.0', #sequence alignment weights
'--S_generated_D_weight', '-0.1', #sequence alignment weights
'--S_generated_S_weight', '-0.1', #sequence alignment weights
'--S_edited_C_weight', '1.0', #sequence alignment weights
'--S_edited_I_weight', '1.1', #sequence alignment weights
'--S_edited_S_weight', '1.1', #sequence alignment weights
])[0]
script_args = parser.parse_args_into_dataclasses(
args=[
"--per_device_train_batch_size", "2",
"--per_device_eval_batch_size", "2",
"--gradient_accumulation_steps", "4",
'--model_name_or_path', 'gpt2',
# "--model_name_or_path", "sshleifer/tiny-gpt2",
# '--model_name_or_path', 'huggy llama/llama-7b',
# '--model_name_or_path', 'meta-llama/Llama-2-7b-hf',
"--load_in_4bit",
"--use_peft",
"--learning_rate", "1e-4",
'--run_name', 'SALT-avs-gpt2',
"--max_length", "1024",
"--max_prompt_length", "768",
"--num_train_epochs", "1",
"--max_steps", "11",
"--evaluation_strategy", "epoch",
"--eval_steps", "-1",
"--logging_strategy", "steps",
"--log_steps", "10",
"--logging_first_step",
"--save_strategy", "epoch",
'--save_steps', '-1',
'--save_total_limit', '3',
'--load_best_model_at_end',
'--metric_for_best_model', 'metrics_policy_rouge1',
"--output_dir", "./results/avs/SALT_model/SALT-avs-llama2(1|-0.1|-0.1|1|1.1|1.1)",
"--omega1", "1.0", # salt chosen likelihood loss weight
"--omega2", "0.1", # salt rejected unlikelihood loss weight
"--S_generated_C_weight", "1.0", # sequence alignment weights
"--S_generated_D_weight", "-0.1", # sequence alignment weights
"--S_generated_S_weight", "-0.1", # sequence alignment weights
"--S_edited_C_weight", "1.0", # sequence alignment weights
"--S_edited_I_weight", "1.1", # sequence alignment weights
"--S_edited_S_weight", "1.1", # sequence alignment weights
]
)[0]

# 2. Load training dataset
# train_dataset = load_dataset("train", sanity_check=script_args.sanity_check)
train_dataset = load_dataset("sub_eval", sanity_check=script_args.sanity_check, alignment_function=script_args.alignment_function)
# Initialize wandb if reporting to wandb
if script_args.report_to == "wandb":
wandb.init(project=script_args.run_name)

data_subset = "sub_eval_w_simulated_edits"
train_dataset = load_dataset(
data_subset,
sanity_check=script_args.sanity_check,
alignment_function=script_args.alignment_function,
)

# 3. Load evaluation dataset
eval_dataset = load_dataset("sub_eval", sanity_check=script_args.sanity_check, alignment_function=script_args.alignment_function)
eval_dataset = load_dataset(
data_subset,
sanity_check=True,
alignment_function=script_args.alignment_function,
)

dpo_trainer = trainer(script_args, train_dataset, eval_dataset)
dpo_trainer = trainer(script_args, train_dataset, eval_dataset)
41 changes: 21 additions & 20 deletions SFT_trainer.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,57 @@
from transformers import HfArgumentParser
from trainer import ScriptArguments, load_dataset, trainer
import wandb
from trainer import ScriptArguments, load_dataset, trainer

parser = HfArgumentParser(ScriptArguments)

#for SFT
script_args = parser.parse_args_into_dataclasses(
args=[
'--per_device_train_batch_size', '1',
'--per_device_eval_batch_size', '1',
'--gradient_accumulation_steps', '1',
# '--model_name_or_path', 'gpt2',
'--model_name_or_path', 'sshleifer/tiny-gpt2',
'--per_device_train_batch_size', '2',
'--per_device_eval_batch_size', '2',
'--gradient_accumulation_steps', '4',
'--model_name_or_path', 'gpt2',
# '--model_name_or_path', 'sshleifer/tiny-gpt2',
# '--model_name_or_path', 'huggy llama/llama-7b',
# '--model_name_or_path', 'meta-llama/Llama-2-7b-hf',
'--load_in_4bit',
'--use_peft',
# '--learning_rate', '1e-3',
'--learning_rate', '1e-4',
# '--report_to', 'wandb',
# '--report_to', 'tensorboard',
'--run_name', 'SFT-avs-gpt2',
'--max_length', '1024',
'--max_prompt_length', '768',
'--num_train_epochs', '1',
'--max_steps', '-1',
'--max_steps', '30',
'--evaluation_strategy', 'epoch',
'--eval_steps', '-1',
# '--eval_first_step',
'--logging_strategy', 'steps',
'--log_steps', '1',
'--log_steps', '10',
'--logging_first_step',
'--save_strategy', 'epoch',
'--save_steps', '-1',
'--save_total_limit', '3',
'--load_best_model_at_end',
'--metric_for_best_model', 'metrics_policy_rouge1',
'--alignment_function', 'sft',
'--output_dir', './results/avs/SFT_model/gpt2',
# '--output_dir', './results/SFT_model/llama2_7b',
]
)[0]

# Initialize wandb if reporting to wandb
if script_args.report_to == 'wandb':
if script_args.report_to == "wandb":
wandb.init(project=script_args.run_name)

# 2. Load training dataset
# train_dataset = load_dataset("train", sanity_check=script_args.sanity_check)
train_dataset = load_dataset("sub_eval", sanity_check=script_args.sanity_check, alignment_function=script_args.alignment_function)
data_subset = "sub_eval_w_simulated_edits"
train_dataset = load_dataset(
data_subset,
sanity_check=script_args.sanity_check,
alignment_function=script_args.alignment_function,
)

# 3. Load evaluation dataset
eval_dataset = load_dataset("sub_eval", sanity_check=script_args.sanity_check, alignment_function=script_args.alignment_function)
eval_dataset = load_dataset(
data_subset,
sanity_check=True,
alignment_function=script_args.alignment_function,
)

dpo_trainer = trainer(script_args, train_dataset, eval_dataset)
dpo_trainer = trainer(script_args, train_dataset, eval_dataset)
Binary file not shown.
Loading

0 comments on commit ce973c2

Please sign in to comment.