support rlhf #184

Merged: 9 commits, Dec 22, 2023
8 changes: 8 additions & 0 deletions dbgpt_hub/configs/data_args.py
@@ -81,6 +81,14 @@ class DataArguments:
default="dbgpt_hub/data/",
metadata={"help": "The name of the folder containing datasets."},
)
cutoff_len: Optional[int] = field(
default=1024,
metadata={"help": "The maximum length of the model inputs after tokenization."},
)
reserved_label_len: Optional[int] = field(
default=1,
metadata={"help": "The maximum length reserved for label after tokenization."},
)
split: Optional[str] = field(
default="train",
metadata={"help": "Which dataset split to use for training and evaluation."},
12 changes: 12 additions & 0 deletions dbgpt_hub/configs/model_args.py
@@ -94,6 +94,15 @@ class ModelArguments:
"help": "Used in rope scaling. Do not specify this argument manually."
},
)
hf_hub_token: Optional[str] = field(
default=None, metadata={"help": "Auth token to log in with Hugging Face Hub."}
)
split_special_tokens: Optional[bool] = field(
default=False,
metadata={
"help": "Whether or not the special tokens should be split during the tokenization process."
},
)

def __post_init__(self):
if self.compute_dtype is not None or self.model_max_length is not None:
@@ -182,6 +191,9 @@ class FinetuningArguments:
r"""
Arguments pertaining to which fine-tuning techniques we are going to use.
"""
stage: Optional[Literal["sft", "rm"]] = field(
default="sft", metadata={"help": "Which stage will be performed in training."}
)
finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
default="lora", metadata={"help": "Which fine-tuning method to use."}
)
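
A note on the `split_special_tokens` flag added above: when enabled, special-token strings appearing in raw text are tokenized as ordinary text rather than mapped to their reserved ids. A minimal sketch of the behaviour, assuming a transformers version that supports the flag (roughly 4.32 or later); "gpt2" is only a stand-in model:

from transformers import AutoTokenizer

# Default: the special-token string collapses to its single reserved token.
tok = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
print(tok.tokenize("<|endoftext|>"))  # ['<|endoftext|>']

# With split_special_tokens=True the same string is BPE-split like plain text.
tok_split = AutoTokenizer.from_pretrained(
    "gpt2", use_fast=False, split_special_tokens=True
)
print(tok_split.tokenize("<|endoftext|>"))  # several ordinary sub-word pieces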
9 changes: 9 additions & 0 deletions dbgpt_hub/data/dataset_info.json
@@ -7,5 +7,14 @@
"response": "output",
"history": "history"
}
},
"example_rm_train": {
"file_name": "oaast_rm_zh.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
}
}
}
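
For context: the pairwise preprocessing in data_utils.py below requires each record mapped by `example_rm_train` to carry at least two responses, with the preferred one first. A hypothetical record in that layout (field values invented for illustration; not taken from the actual oaast_rm_zh.json):

# Hypothetical RM record; response[0] is treated as chosen, response[1] as rejected.
example_record = {
    "instruction": "Summarize the following paragraph.",
    "input": "The quick brown fox jumps over the lazy dog.",
    "output": [
        "A fox jumps over a dog.",      # preferred (chosen) answer
        "Foxes are a kind of mammal.",  # dispreferred (rejected) answer
    ],
    "history": [],
}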
101 changes: 97 additions & 4 deletions dbgpt_hub/data_process/data_utils.py
@@ -4,7 +4,17 @@
import pandas as pd
import tiktoken
from itertools import chain
-from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING, Generator
+from typing import (
+    Any,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    TYPE_CHECKING,
+    Generator,
+    Literal,
+)
from datasets import (
Dataset,
DatasetDict,
@@ -64,6 +74,17 @@ def extract_sql_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]:
return {"input": prompt_format.format(**example)}


def infer_max_len(
source_len: int, target_len: int, data_args: "DataArguments"
) -> Tuple[int, int]:
max_target_len = int(
data_args.cutoff_len * (target_len / (source_len + target_len))
)
max_target_len = max(max_target_len, data_args.reserved_label_len)
max_source_len = data_args.cutoff_len - max_target_len
return max_source_len, max_target_len
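
A quick worked example of the budget split above, assuming the default cutoff_len=1024 and reserved_label_len=1 from DataArguments:

# Standalone sketch of infer_max_len with the data_args.py defaults inlined.
cutoff_len, reserved_label_len = 1024, 1

def split_budget(source_len: int, target_len: int):
    # The target gets its proportional share of the token budget...
    max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
    # ...but never less than the reserved label length.
    max_target_len = max(max_target_len, reserved_label_len)
    return cutoff_len - max_target_len, max_target_len

# A 900-token prompt and a 300-token answer are squeezed 3:1 into 1024 tokens.
print(split_budget(900, 300))  # -> (768, 256)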


def local_dataset(
dataset_path: str, eval_dataset_size: float = 0.1
) -> Tuple[Dataset, Dataset]:
@@ -579,6 +600,7 @@ def preprocess_dataset(
tokenizer: "PreTrainedTokenizer",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"],
) -> Union["Dataset", "IterableDataset"]:
column_names = list(next(iter(dataset)).keys())
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
@@ -670,6 +692,69 @@ def preprocess_unsupervised_dataset(

return model_inputs

def preprocess_pairwise_dataset(
examples: Dict[str, List[Any]]
) -> Dict[str, List[List[int]]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>` for rm stage
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
for query, response, history, system in construct_example(examples):
if not (
isinstance(query, str)
and isinstance(response, list)
and query != ""
and len(response) > 1
):
continue

prompt_ids, chosen_ids = template.encode_oneturn(
tokenizer, query, response[0], history, system
)
_, rejected_ids = template.encode_oneturn(
tokenizer, query, response[1], history, system
)

# if template.efficient_eos:
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]

source_len, target_len = len(prompt_ids), max(
len(chosen_ids), len(rejected_ids)
)
max_source_len, max_target_len = infer_max_len(
source_len, target_len, data_args
)
if source_len > max_source_len:
prompt_ids = prompt_ids[:max_source_len]
if target_len > max_target_len:
chosen_ids = chosen_ids[:max_target_len]
rejected_ids = rejected_ids[:max_target_len]

model_inputs["prompt_ids"].append(prompt_ids)
model_inputs["chosen_ids"].append(chosen_ids)
model_inputs["rejected_ids"].append(rejected_ids)

return model_inputs
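
Downstream of this preprocessing, a reward model scores each (prompt, chosen) and (prompt, rejected) pair. The usual training objective, shown here only for context and not part of this diff, is the pairwise Bradley-Terry loss over the two scalar rewards:

import torch
import torch.nn.functional as F

def pairwise_rm_loss(chosen_rewards, rejected_rewards):
    # Minimized when r(chosen) exceeds r(rejected) by a wide margin.
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()

# Toy rewards for two chosen/rejected pairs.
print(pairwise_rm_loss(torch.tensor([1.2, 0.3]), torch.tensor([0.4, -0.1])))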

def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None:
print("prompt_ids:\n{}".format(example["prompt_ids"]))
print(
"prompt:\n{}".format(
tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)
)
)
print("chosen_ids:\n{}".format(example["chosen_ids"]))
print(
"chosen:\n{}".format(
tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)
)
)
print("rejected_ids:\n{}".format(example["rejected_ids"]))
print(
"rejected:\n{}".format(
tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)
)
)

def print_supervised_dataset_example(example):
print("input_ids:\n{}".format(example["input_ids"]))
print(
@@ -690,9 +775,17 @@ def print_supervised_dataset_example(example):
)
)

     dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
-    preprocess_function = preprocess_supervised_dataset
-    print_function = print_supervised_dataset_example
+    if stage == "pt":
+        pass
+    elif stage == "sft" and not training_args.predict_with_generate:
+        preprocess_function = preprocess_supervised_dataset
+        print_function = print_supervised_dataset_example
+    elif stage == "rm":
+        print(111111111111111111)

[Review comment from Contributor]: debug code remains

[Reply from wangzaistone (Member), Dec 21, 2023]: @oushu1zhangxiangxuan1 what are the details of the bug, and your command and env? It passed on my side.

+        preprocess_function = preprocess_pairwise_dataset
+        print_function = print_pairwise_dataset_example
+    else:
+        pass

with training_args.main_process_first(desc="dataset map pre-processing"):
kwargs = {}
119 changes: 73 additions & 46 deletions dbgpt_hub/llm_base/load_tokenizer.py
@@ -1,5 +1,6 @@
import os
import torch
import inspect
import math
from typing import Optional, Tuple, Dict, TYPE_CHECKING, Literal, List
from types import MethodType
@@ -10,8 +11,9 @@
from dbgpt_hub.configs.config import LAYERNORM_NAMES, VALUE_HEAD_FILE_NAME

 from transformers import PreTrainedModel, PreTrainedTokenizer
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, cached_file
 from transformers.utils.versions import require_version
+from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 from transformers.deepspeed import is_deepspeed_zero3_enabled
from transformers import (
AutoConfig,
@@ -103,32 +105,54 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return model


-def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
-    valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
-    if not os.path.exists(valuehead_file):
-        logger.warning(
-            "Provided path ({}) does not contain valuehead weights.".format(
-                checkpoint_dir
-            )
-        )
-        return False
-    valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
-    model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
-    model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
-    model.register_buffer(
-        "default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"])
-    )
-    model.register_buffer(
-        "default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"])
-    )
-    return True
+def load_valuehead_params(
+    path_or_repo_id: str, model_args: "ModelArguments"
+) -> Dict[str, torch.Tensor]:
+    r"""
+    Loads value head parameters from Hugging Face Hub or local disk.
+
+    Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
+    """
+    kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir}
+
+    if "token" in inspect.signature(cached_file).parameters:
+        kwargs["token"] = model_args.hf_hub_token
+    elif (
+        "use_auth_token" in inspect.signature(cached_file).parameters
+    ):  # for transformers==4.31.0
+        kwargs["use_auth_token"] = model_args.hf_hub_token
+    else:
+        logger.warning("Ignore `hf_hub_token` since matched parameter is not found.")
+
+    try:
+        vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
+        return torch.load(vhead_file, map_location="cpu")
+    except Exception as err:
+        logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
+
+    try:
+        from safetensors import safe_open
+
+        vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
+        with safe_open(vhead_file, framework="pt", device="cpu") as f:
+            return {
+                "v_head.summary.weight": f.get_tensor("v_head.summary.weight"),
+                "v_head.summary.bias": f.get_tensor("v_head.summary.bias"),
+            }
+    except Exception as err:
+        logger.info("Failed to load {}: {}".format(SAFE_WEIGHTS_NAME, str(err)))
+
+    logger.warning(
+        "Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id)
+    )
+    return None


 def load_model_and_tokenizer(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
     is_trainable: Optional[bool] = False,
-    stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft",
+    add_valuehead: Optional[bool] = False,
 ) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
r"""
Loads pretrained model and tokenizer.
@@ -151,7 +175,8 @@ def load_model_and_tokenizer(
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path,
         use_fast=model_args.use_fast_tokenizer,
-        padding_side=model_args.padding_side,
+        split_special_tokens=model_args.split_special_tokens,
+        padding_side="right",  # training with left-padded tensors in fp16 precision may cause overflow
         **config_kwargs
     )

@@ -171,6 +196,15 @@ def load_model_and_tokenizer(
else:
setattr(config, "fp16", True)

# Fix config (for Qwen)
if getattr(config, "model_type", None) == "qwen":
for dtype_name, dtype in [
("fp16", torch.float16),
("bf16", torch.bfloat16),
("fp32", torch.float32),
]:
setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)

# Set RoPE scaling
if model_args.rope_scaling is not None:
if hasattr(config, "use_dynamic_ntk"): # for Qwen models
@@ -294,33 +328,26 @@ def load_model_and_tokenizer(
model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)

# Prepare model with valuehead for RLHF
-    if stage == "rm" or stage == "ppo":
-        model: AutoModelForCausalLMWithValueHead = (
+    if add_valuehead:
+        model: "AutoModelForCausalLMWithValueHead" = (
             AutoModelForCausalLMWithValueHead.from_pretrained(model)
         )
         reset_logging()
-        if (
-            stage == "rm" and model_args.checkpoint_dir is not None
-        ):  # load valuehead weights to evaluate reward model
-            logger.warning(
-                "Only the last checkpoint containing valuehead will be loaded as the valuehead."
-            )
-            if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
-                model.v_head.load_state_dict(
-                    {
-                        "summary.weight": getattr(model, "reward_head_weight"),
-                        "summary.bias": getattr(model, "reward_head_bias"),
-                    }
-                )
-
-        if stage == "ppo":  # load reward model
-            logger.info("Load reward model from {}".format(model_args.reward_model))
-            model.pretrained_model.load_adapter(
-                model_args.reward_model, "reward", is_trainable=False
-            )
-            assert load_valuehead_params(
-                model, model_args.reward_model
-            ), "Reward model is not correctly loaded."
+        ignore_modules = [
+            name for name, _ in model.named_parameters() if "pretrained_model" in name
+        ]
+        setattr(model, "_keys_to_ignore_on_save", ignore_modules)
+        setattr(
+            model, "tie_weights", MethodType(lambda _: None, model)
+        )  # use empty method
+        vhead_path = (
+            model_args.checkpoint_dir[-1]
+            if model_args.checkpoint_dir is not None
+            else model_args.model_name_or_path
+        )
+        vhead_params = load_valuehead_params(vhead_path, model_args)
+        if vhead_params is not None:
+            model.load_state_dict(vhead_params, strict=False)
+            logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))

# Prepare model for inference
if not is_trainable: