
Commit d5c83a0
formatting and comments
saiprabhakar committed Jan 12, 2025
1 parent 142b609 commit d5c83a0
Showing 4 changed files with 33 additions and 21 deletions.
8 changes: 4 additions & 4 deletions train_DPO.py
@@ -1,6 +1,6 @@
 from transformers import HfArgumentParser
 import wandb
-from trainer.trainer import ScriptArguments, load_dataset, trainer
+from trainer.trainer import ScriptArguments, load_dataset_hg_local, trainer

 parser = HfArgumentParser(ScriptArguments)

@@ -44,17 +44,17 @@
 wandb.init(project=script_args.run_name)

 data_subset = "sub_eval_w_simulated_edits"
-train_dataset = load_dataset(
+train_dataset = load_dataset_hg_local(
     data_subset,
     sanity_check=script_args.sanity_check,
     alignment_function=script_args.alignment_function,
 )

 # 3. Load evaluation dataset
-eval_dataset = load_dataset(
+eval_dataset = load_dataset_hg_local(
     data_subset,
     sanity_check=True,
     alignment_function=script_args.alignment_function,
 )

-dpo_trainer = trainer(script_args, train_dataset, eval_dataset)
+dpo_trainer = trainer(script_args, train_dataset, eval_dataset)
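
The rename from load_dataset to load_dataset_hg_local most likely exists to stop the local helper from shadowing datasets.load_dataset from the Hugging Face datasets library; the same commit drops the commented-out load_dataset import in trainer/trainer.py below. A minimal sketch of the clash, with a hypothetical body, not code from this repository:

```python
# Hypothetical illustration of the name clash the rename avoids.
from datasets import load_dataset  # Hugging Face loader


def load_dataset(split, sanity_check=False):  # local helper with the same name
    """Shadows the import above: every later call resolves to this helper."""
    ...


# After the rename, both names can coexist in one module:
# from datasets import load_dataset
# from trainer.trainer import load_dataset_hg_local
```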
6 changes: 3 additions & 3 deletions trainer_SALT.py → train_SALT.py
@@ -1,6 +1,6 @@
 from transformers import HfArgumentParser
 import wandb
-from trainer.trainer import ScriptArguments, load_dataset, trainer
+from trainer.trainer import ScriptArguments, load_dataset_hg_local, trainer

 parser = HfArgumentParser(ScriptArguments)

@@ -48,14 +48,14 @@
 wandb.init(project=script_args.run_name)

 data_subset = "sub_eval_w_simulated_edits"
-train_dataset = load_dataset(
+train_dataset = load_dataset_hg_local(
     data_subset,
     sanity_check=script_args.sanity_check,
     alignment_function=script_args.alignment_function,
 )

 # 3. Load evaluation dataset
-eval_dataset = load_dataset(
+eval_dataset = load_dataset_hg_local(
     data_subset,
     sanity_check=True,
     alignment_function=script_args.alignment_function,
6 changes: 3 additions & 3 deletions trainer_SFT.py → train_SFT.py
@@ -1,6 +1,6 @@
 from transformers import HfArgumentParser
 import wandb
-from trainer.trainer import ScriptArguments, load_dataset, trainer
+from trainer.trainer import ScriptArguments, load_dataset_hg_local, trainer

 parser = HfArgumentParser(ScriptArguments)

@@ -41,14 +41,14 @@
 wandb.init(project=script_args.run_name)

 data_subset = "sub_eval_w_simulated_edits"
-train_dataset = load_dataset(
+train_dataset = load_dataset_hg_local(
     data_subset,
     sanity_check=script_args.sanity_check,
     alignment_function=script_args.alignment_function,
 )

 # 3. Load evaluation dataset
-eval_dataset = load_dataset(
+eval_dataset = load_dataset_hg_local(
     data_subset,
     sanity_check=True,
     alignment_function=script_args.alignment_function,
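All three entry scripts now share one flow: parse ScriptArguments, load the train split with the user-supplied sanity_check flag, and load the eval split with sanity_check=True so evaluation always runs on the small subset. A condensed sketch of that shared flow; the parse_args_into_dataclasses() line is an assumption based on standard HfArgumentParser usage and is not visible in this diff:

```python
# Condensed sketch of the flow shared by train_DPO.py, train_SALT.py,
# and train_SFT.py. The argument-parsing line is assumed, not shown above.
import wandb
from transformers import HfArgumentParser
from trainer.trainer import ScriptArguments, load_dataset_hg_local, trainer

parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

wandb.init(project=script_args.run_name)

data_subset = "sub_eval_w_simulated_edits"
train_dataset = load_dataset_hg_local(
    data_subset,
    sanity_check=script_args.sanity_check,  # full data unless sanity-checking
    alignment_function=script_args.alignment_function,
)
# Evaluation always loads the small subset.
eval_dataset = load_dataset_hg_local(
    data_subset,
    sanity_check=True,
    alignment_function=script_args.alignment_function,
)

dpo_trainer = trainer(script_args, train_dataset, eval_dataset)
```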
34 changes: 23 additions & 11 deletions trainer/trainer.py
@@ -4,27 +4,26 @@
 from dataclasses import dataclass, field
 from typing import Dict, Optional
 import torch
-from datasets import Dataset, load_from_disk  # , load_dataset, load_metric
+from datasets import Dataset, load_from_disk
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     TrainingArguments,
     BitsAndBytesConfig,
 )

-# from transformers.trainer_utils import EvalPrediction  # , EvalLoopOutput
-# from transformers.trainer_pt_utils import find_batch_size, nested_concat
-
-# import pandas as pd
-
-from peft import LoraConfig, get_peft_model
-
-# from torch.utils.data import DataLoader
-
+from peft import LoraConfig
 from trainer.dpo_salt_sft_trainer import DPOTrainer, SALTTrainer, SFTTrainer


 def extract_prompt(prompt_and_response):
+    """
+    Extract the prompt from the prompt-and-response string by searching for the hard-coded search term.
+    args:
+        prompt_and_response: str - the prompt and response string
+    returns:
+        str: the prompt
+    """
     search_term = "\n\nGenerate the corresponding Discharge Instructions according to the input article:"
     search_term_idx = prompt_and_response.rfind(search_term)
     assert (
@@ -33,7 +32,7 @@ def extract_prompt(prompt_and_response):
     return prompt_and_response[: search_term_idx + len(search_term)]


-def load_dataset(
+def load_dataset_hg_local(
     split: str,
     sanity_check: bool = False,
     alignment_function: str = "sft",
@@ -51,6 +50,16 @@ def load_dataset(
     Prompts should be structured as follows:
         Conversation <prompt>\n\nSummary
+    args:
+        split: str - the split to load
+        sanity_check: bool - only load a small subset of the dataset
+        alignment_function: str - the alignment function to use
+        silent: bool - whether to print output
+        cache_dir: str - the cache directory to use
+    returns:
+        Dataset: the dataset
     """
     # dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
     if alignment_function in ["sft", "dpo", "salt"]:
@@ -241,6 +250,9 @@ class ScriptArguments:


 def trainer(script_args, train_dataset, eval_dataset):
+    """
+    Train a model using the DPO, SFT, or SALT loss function.
+    """
     with open("hg_secret", "r") as f:
         hg_auth_token = f.read()
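
The newly documented extract_prompt splits on a hard-coded instruction string via rfind and returns everything up to and including it. A self-contained sketch with a toy input; the assert message is hypothetical because the diff truncates the assert body:

```python
# Self-contained sketch of extract_prompt as shown in this diff.
# The assert message is hypothetical; the diff truncates the assert body.
SEARCH_TERM = (
    "\n\nGenerate the corresponding Discharge Instructions "
    "according to the input article:"
)


def extract_prompt(prompt_and_response: str) -> str:
    idx = prompt_and_response.rfind(SEARCH_TERM)
    assert idx != -1, "search term not found in prompt_and_response"
    return prompt_and_response[: idx + len(SEARCH_TERM)]


sample = "Some clinical note text." + SEARCH_TERM + " Take your medication daily."
assert extract_prompt(sample) == "Some clinical note text." + SEARCH_TERM
```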
