Commit ce973c2
1 parent: 6131210
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 13 changed files with 1,683 additions and 1,266 deletions.
@@ -5,4 +5,5 @@ data/*
 results/*
 .ipynb_checkpoints/
 .DS_Store
 *.zip
+data/*/*/cache*
@@ -1,57 +1,60 @@
 from transformers import HfArgumentParser
-from trainer import ScriptArguments, load_dataset, trainer
 import wandb
+from trainer import ScriptArguments, load_dataset, trainer

 parser = HfArgumentParser(ScriptArguments)

 # for DPO
 script_args = parser.parse_args_into_dataclasses(
     args=[
-        '--per_device_train_batch_size', '1',
-        '--per_device_eval_batch_size', '1',
-        '--gradient_accumulation_steps', '1',
-        '--model_name_or_path', 'sshleifer/tiny-gpt2',
-        # '--model_name_or_path', 'results/avs/BASE_model/gpt2',
-        # '--model_name_or_path', 'gpt2',
-        # '--model_name_or_path', 'meta-llama/Llama-2-7b-hf',
-        '--load_in_4bit',
-        '--use_peft',
-        # '--learning_rate', '1e-3',
-        '--learning_rate', '1e-4',
-        # '--report_to', 'wandb',
-        '--run_name', 'DPO-avs-gpt2',
-        '--max_length', '1024',
-        '--max_prompt_length', '768',
-        '--num_train_epochs', '5',
-        '--max_steps', '-1',
-        '--evaluation_strategy', 'epoch',
-        '--eval_steps', '-1',
-        # '--eval_first_step',
-        '--logging_strategy', 'steps',
-        '--log_steps', '20',
-        '--logging_first_step',
-        # '--save_strategy', 'epoch',
-        '--save_strategy', 'steps',
-        '--save_steps', '10000000',
-        # '--save_total_limit', '3',
-        # '--load_best_model_at_end',
-        # '--metric_for_best_model', 'metrics_policy_rouge1',
-        # '--alignment_function', 'dpo',
-        '--output_dir', './results/avs/DPO_model/DPO-avs-gpt2(1|1|0.3)',
-        '--alpha1', '1.0',  # sft loss
-        '--alpha2', '1.0',  # dpo loss
-        '--beta', '0.3',
-    ])[0]
+        '--per_device_train_batch_size', '2',
+        '--per_device_eval_batch_size', '2',
+        '--gradient_accumulation_steps', '4',
+        '--model_name_or_path', 'gpt2',
+        # '--model_name_or_path', 'sshleifer/tiny-gpt2',
+        # '--model_name_or_path', 'huggyllama/llama-7b',
+        # '--model_name_or_path', 'meta-llama/Llama-2-7b-hf',
+        '--load_in_4bit',
+        '--use_peft',
+        '--learning_rate', '1e-4',
+        # '--report_to', 'wandb',
+        '--run_name', 'DPO-avs-gpt2',
+        '--max_length', '1024',
+        '--max_prompt_length', '768',
+        '--num_train_epochs', '1',
+        '--max_steps', '-1',
+        '--evaluation_strategy', 'epoch',
+        '--eval_steps', '-1',
+        '--logging_strategy', 'steps',
+        '--log_steps', '10',
+        '--logging_first_step',
+        '--save_strategy', 'epoch',
+        '--save_steps', '-1',
+        '--save_total_limit', '3',
+        '--load_best_model_at_end',
+        '--metric_for_best_model', 'metrics_policy_rouge1',
+        '--output_dir', './results/avs/DPO_model/DPO-avs-gpt2(1|1|0.3)',
+        "--alpha1", "1.0",  # sft loss
+        "--alpha2", "1.0",  # dpo loss
+        "--beta", "0.3",
+    ]
+)[0]

 # Initialize wandb if reporting to wandb
-if script_args.report_to == 'wandb':
+if script_args.report_to == "wandb":
     wandb.init(project=script_args.run_name)

 # 2. Load training dataset
-# train_dataset = load_dataset("train", sanity_check=script_args.sanity_check)
-train_dataset = load_dataset("sub_eval", sanity_check=script_args.sanity_check, alignment_function=script_args.alignment_function)
+data_subset = "sub_eval_w_simulated_edits"
+train_dataset = load_dataset(
+    data_subset,
+    sanity_check=script_args.sanity_check,
+    alignment_function=script_args.alignment_function,
+)

 # 3. Load evaluation dataset
-eval_dataset = load_dataset("sub_eval", sanity_check=script_args.sanity_check, alignment_function=script_args.alignment_function)
+eval_dataset = load_dataset(
+    data_subset,
+    sanity_check=True,
+    alignment_function=script_args.alignment_function,
+)

 dpo_trainer = trainer(script_args, train_dataset, eval_dataset)
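The '--alpha1' (sft loss), '--alpha2' (dpo loss), and '--beta' flags suggest this run mixes a supervised negative log-likelihood term with the standard DPO preference term, with beta acting as the usual DPO temperature. Below is a minimal sketch of how such a weighted combination could look, assuming summed per-example log-probabilities; the function name, signature, and toy numbers are hypothetical and not taken from trainer.py.

import torch
import torch.nn.functional as F

def combined_loss(policy_chosen_logps, policy_rejected_logps,
                  ref_chosen_logps, ref_rejected_logps,
                  sft_nll, alpha1=1.0, alpha2=1.0, beta=0.3):
    # Standard DPO term: -log sigmoid(beta * (policy margin - reference margin))
    margin = (policy_chosen_logps - policy_rejected_logps) - (ref_chosen_logps - ref_rejected_logps)
    dpo_term = -F.logsigmoid(beta * margin)
    # Weighted sum of the supervised NLL and the preference term (hypothetical combination)
    return (alpha1 * sft_nll + alpha2 * dpo_term).mean()

# Toy usage with per-example summed log-probabilities
loss = combined_loss(
    policy_chosen_logps=torch.tensor([-10.0]),
    policy_rejected_logps=torch.tensor([-12.0]),
    ref_chosen_logps=torch.tensor([-11.0]),
    ref_rejected_logps=torch.tensor([-11.5]),
    sft_nll=torch.tensor([10.0]),
)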
@@ -1,56 +1,64 @@
 from transformers import HfArgumentParser
 import wandb
 from trainer import ScriptArguments, load_dataset, trainer

 parser = HfArgumentParser(ScriptArguments)

 # for SALT
-script_args = parser.parse_args_into_dataclasses(args=['--per_device_train_batch_size', '1',
-        '--per_device_eval_batch_size', '1',
-        '--gradient_accumulation_steps', '1',
-        '--model_name_or_path', 'sshleifer/tiny-gpt2',
-        # '--model_name_or_path', 'gpt2',
-        # '--model_name_or_path', 'results/avs/BASE_model/gpt2',
-        # '--model_name_or_path', 'huggyllama/llama-7b',
-        # '--model_name_or_path', 'meta-llama/Llama-2-7b-hf',
-        '--load_in_4bit',
-        '--use_peft',
-        '--learning_rate', '1e-3',
-        # '--learning_rate', '1e-4',
-        # '--report_to', 'wandb',
-        # '--run_name', 'SALT-avs-llama2',
-        '--max_length', '1024',
-        '--max_prompt_length', '768',
-        '--num_train_epochs', '5',
-        '--max_steps', '-1',
-        '--evaluation_strategy', 'epoch',
-        '--eval_steps', '-1',
-        # '--eval_first_step',
-        '--logging_strategy', 'steps',
-        '--log_steps', '2',
-        '--logging_first_step',
-        # '--save_strategy', 'epoch',
-        '--save_strategy', 'steps',
-        '--save_steps', '10000000',
-        # '--save_total_limit', '3',
-        # '--load_best_model_at_end',
-        # '--metric_for_best_model', 'metrics_policy_rouge1',
-        '--alignment_function', 'salt',
-        '--output_dir', './results/avs/SALT_model/SALT-avs-llama2(1|-0.1|-0.1|1|1.1|1.1)',
-        '--omega1', '1.0',  # salt chosen likelihood loss weight
-        '--omega2', '0.1',  # salt rejected unlikelihood loss weight
-        '--S_generated_C_weight', '1.0',  # sequence alignment weights
-        '--S_generated_D_weight', '-0.1',  # sequence alignment weights
-        '--S_generated_S_weight', '-0.1',  # sequence alignment weights
-        '--S_edited_C_weight', '1.0',  # sequence alignment weights
-        '--S_edited_I_weight', '1.1',  # sequence alignment weights
-        '--S_edited_S_weight', '1.1',  # sequence alignment weights
-        ])[0]
+script_args = parser.parse_args_into_dataclasses(
+    args=[
+        "--per_device_train_batch_size", "2",
+        "--per_device_eval_batch_size", "2",
+        "--gradient_accumulation_steps", "4",
+        '--model_name_or_path', 'gpt2',
+        # "--model_name_or_path", "sshleifer/tiny-gpt2",
+        # '--model_name_or_path', 'huggyllama/llama-7b',
+        # '--model_name_or_path', 'meta-llama/Llama-2-7b-hf',
+        "--load_in_4bit",
+        "--use_peft",
+        "--learning_rate", "1e-4",
+        '--run_name', 'SALT-avs-gpt2',
+        "--max_length", "1024",
+        "--max_prompt_length", "768",
+        "--num_train_epochs", "1",
+        "--max_steps", "11",
+        "--evaluation_strategy", "epoch",
+        "--eval_steps", "-1",
+        "--logging_strategy", "steps",
+        "--log_steps", "10",
+        "--logging_first_step",
+        "--save_strategy", "epoch",
+        '--save_steps', '-1',
+        '--save_total_limit', '3',
+        '--load_best_model_at_end',
+        '--metric_for_best_model', 'metrics_policy_rouge1',
+        "--output_dir", "./results/avs/SALT_model/SALT-avs-llama2(1|-0.1|-0.1|1|1.1|1.1)",
+        "--omega1", "1.0",  # salt chosen likelihood loss weight
+        "--omega2", "0.1",  # salt rejected unlikelihood loss weight
+        "--S_generated_C_weight", "1.0",  # sequence alignment weights
+        "--S_generated_D_weight", "-0.1",  # sequence alignment weights
+        "--S_generated_S_weight", "-0.1",  # sequence alignment weights
+        "--S_edited_C_weight", "1.0",  # sequence alignment weights
+        "--S_edited_I_weight", "1.1",  # sequence alignment weights
+        "--S_edited_S_weight", "1.1",  # sequence alignment weights
+    ]
+)[0]

-# 2. Load training dataset
-# train_dataset = load_dataset("train", sanity_check=script_args.sanity_check)
-train_dataset = load_dataset("sub_eval", sanity_check=script_args.sanity_check, alignment_function=script_args.alignment_function)
+# Initialize wandb if reporting to wandb
+if script_args.report_to == "wandb":
+    wandb.init(project=script_args.run_name)
+
+data_subset = "sub_eval_w_simulated_edits"
+train_dataset = load_dataset(
+    data_subset,
+    sanity_check=script_args.sanity_check,
+    alignment_function=script_args.alignment_function,
+)

 # 3. Load evaluation dataset
-eval_dataset = load_dataset("sub_eval", sanity_check=script_args.sanity_check, alignment_function=script_args.alignment_function)
+eval_dataset = load_dataset(
+    data_subset,
+    sanity_check=True,
+    alignment_function=script_args.alignment_function,
+)

 dpo_trainer = trainer(script_args, train_dataset, eval_dataset)
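The SALT flags pair a weighted likelihood loss on the human-edited (chosen) summary with a weighted unlikelihood loss on the model-generated (rejected) summary; the S_edited_* and S_generated_* values read as per-token weights keyed by the sequence-alignment operation (C = kept, S = substituted, I = inserted, D = deleted). The following is a rough sketch under those assumptions; the function name, signature, and toy tensors are hypothetical, not code from trainer.py.

import torch

def salt_loss(edited_logps, edited_weights, generated_logps, generated_weights,
              omega1=1.0, omega2=0.1):
    # omega1 * per-token weighted NLL on the edited (chosen) sequence
    likelihood = -(edited_weights * edited_logps).sum(dim=-1)
    # omega2 * per-token weighted unlikelihood, -log(1 - p), on the generated (rejected) sequence
    p = generated_logps.exp().clamp(max=1.0 - 1e-6)
    unlikelihood = -(generated_weights * torch.log1p(-p)).sum(dim=-1)
    return (omega1 * likelihood + omega2 * unlikelihood).mean()

# Toy usage: one 3-token sequence each; weights follow the S_edited_* / S_generated_* flags
loss = salt_loss(
    edited_logps=torch.tensor([[-0.5, -1.0, -0.2]]),
    edited_weights=torch.tensor([[1.0, 1.1, 1.1]]),       # kept / inserted / substituted tokens
    generated_logps=torch.tensor([[-0.4, -2.0, -0.1]]),
    generated_weights=torch.tensor([[1.0, -0.1, -0.1]]),  # kept / deleted / substituted tokens
)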
@@ -1,56 +1,57 @@
 from transformers import HfArgumentParser
-from trainer import ScriptArguments, load_dataset, trainer
 import wandb
+from trainer import ScriptArguments, load_dataset, trainer

 parser = HfArgumentParser(ScriptArguments)

 # for SFT
 script_args = parser.parse_args_into_dataclasses(
     args=[
-        '--per_device_train_batch_size', '1',
-        '--per_device_eval_batch_size', '1',
-        '--gradient_accumulation_steps', '1',
-        # '--model_name_or_path', 'gpt2',
-        '--model_name_or_path', 'sshleifer/tiny-gpt2',
+        '--per_device_train_batch_size', '2',
+        '--per_device_eval_batch_size', '2',
+        '--gradient_accumulation_steps', '4',
+        '--model_name_or_path', 'gpt2',
+        # '--model_name_or_path', 'sshleifer/tiny-gpt2',
         # '--model_name_or_path', 'huggyllama/llama-7b',
         # '--model_name_or_path', 'meta-llama/Llama-2-7b-hf',
         '--load_in_4bit',
         '--use_peft',
         # '--learning_rate', '1e-3',
         '--learning_rate', '1e-4',
         # '--report_to', 'wandb',
         # '--report_to', 'tensorboard',
         '--run_name', 'SFT-avs-gpt2',
         '--max_length', '1024',
         '--max_prompt_length', '768',
         '--num_train_epochs', '1',
-        '--max_steps', '-1',
+        '--max_steps', '30',
         '--evaluation_strategy', 'epoch',
         '--eval_steps', '-1',
         # '--eval_first_step',
         '--logging_strategy', 'steps',
-        '--log_steps', '1',
+        '--log_steps', '10',
         '--logging_first_step',
         '--save_strategy', 'epoch',
         '--save_steps', '-1',
         '--save_total_limit', '3',
         '--load_best_model_at_end',
         '--metric_for_best_model', 'metrics_policy_rouge1',
         '--alignment_function', 'sft',
         '--output_dir', './results/avs/SFT_model/gpt2',
         # '--output_dir', './results/SFT_model/llama2_7b',
     ]
 )[0]

 # Initialize wandb if reporting to wandb
-if script_args.report_to == 'wandb':
+if script_args.report_to == "wandb":
     wandb.init(project=script_args.run_name)

 # 2. Load training dataset
-# train_dataset = load_dataset("train", sanity_check=script_args.sanity_check)
-train_dataset = load_dataset("sub_eval", sanity_check=script_args.sanity_check, alignment_function=script_args.alignment_function)
+data_subset = "sub_eval_w_simulated_edits"
+train_dataset = load_dataset(
+    data_subset,
+    sanity_check=script_args.sanity_check,
+    alignment_function=script_args.alignment_function,
+)

 # 3. Load evaluation dataset
-eval_dataset = load_dataset("sub_eval", sanity_check=script_args.sanity_check, alignment_function=script_args.alignment_function)
+eval_dataset = load_dataset(
+    data_subset,
+    sanity_check=True,
+    alignment_function=script_args.alignment_function,
+)

 dpo_trainer = trainer(script_args, train_dataset, eval_dataset)
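All three scripts parse their flag lists into a single ScriptArguments dataclass via HfArgumentParser; the dataclass itself lives in trainer.py and is not part of this commit. Below is a minimal sketch of that pattern, restricted to fields the scripts reference above; the defaults are guesses, not the real definition.

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class ScriptArguments:
    model_name_or_path: str = field(default="gpt2")
    learning_rate: float = field(default=1e-4)
    report_to: Optional[str] = field(default=None)
    run_name: str = field(default="SFT-avs-gpt2")
    sanity_check: bool = field(default=False)
    alignment_function: str = field(default="sft")
    output_dir: str = field(default="./results")

parser = HfArgumentParser(ScriptArguments)
# parse_args_into_dataclasses returns one instance per dataclass; [0] picks ours
script_args = parser.parse_args_into_dataclasses(
    args=["--model_name_or_path", "gpt2", "--learning_rate", "1e-4"]
)[0]
print(script_args.model_name_or_path, script_args.learning_rate)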
Binary file not shown.