Merge pull request #13 from allenai/beaver_fix
Beaver fix; working towards another model
natolambert authored Feb 9, 2024
2 parents a4a5f38 + d3c9c33 commit 0299429
Showing 5 changed files with 89 additions and 4 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -47,6 +47,7 @@ python scripts/run_rm.py --model=berkeley-nest/Starling-RM-7B-alpha --tokenizer=
 python scripts/run_rm.py --model=stanfordnlp/SteamSHP-flan-t5-xl --direct_load --batch_size=32
 python scripts/run_rm.py --model=PKU-Alignment/beaver-7b-v1.0-reward --chat_template=pku-align --direct_load --batch_size=16
 python scripts/run_rm.py --model=PKU-Alignment/beaver-7b-v1.0-cost --chat_template=pku-align --direct_load --batch_size=16
+python scripts/run_rm.py --model=IDEA-CCNL/Ziya-LLaMA-7B-Reward --batch_size=32 --direct_load --trust_remote_code --chat_template=Ziya  # custom code causing CUDA issues
 ```
 
 And for DPO:
15 changes: 15 additions & 0 deletions herm/chattemplates.py
@@ -0,0 +1,15 @@
# Copyright 2023 AllenAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Chat templates for models, added when their model cards provide examples.
# TODO: add more as needed.
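For illustration only (not part of this commit), a future entry in this file could register a template the same way herm/models/ziya.py does below; the template name and separators here are placeholders:

```python
# Hypothetical example; the template name and separators are placeholders.
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template

register_conv_template(
    Conversation(
        name="example-rm",  # placeholder name for some future reward model
        roles=("Human", "Assistant"),
        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
        sep="\n",
    )
)
```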
2 changes: 1 addition & 1 deletion herm/models/beaver.py
@@ -491,4 +491,4 @@ def __call__(self, samples, **kwargs):
).to("cuda")
with torch.no_grad():
outputs = self.model(**inputs)
return outputs.scores
return outputs.end_scores
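(For context: in safe-rlhf's score-model output, LlamaForScore returns both per-token scores and end_scores; end_scores holds the score at the sequence's final token, i.e. the reward for the completed response, which is why the fix switches to it.)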
55 changes: 55 additions & 0 deletions herm/models/ziya.py
@@ -0,0 +1,55 @@
# Copyright 2023 AllenAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template

# e.g. https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-7B-Reward#usage
# prefix_user = "Human:"
# prefix_bot = "\n\nAssistant:"
# query = "列举一种空气污染。"
# response = "一种常见的空气污染源是化石燃料的燃烧产生的尾气排放,包括来自汽车、卡车、飞机、
# 火车和工业厂房的废气排放。这会导致大气中的二氧化硫、氮氧化物、一氧化碳、臭氧和颗粒物(例如灰尘和烟雾)等污染物含量增加,对人类健康和环境造成不利影响。"
register_conv_template(
    Conversation(
        name="Ziya",
        roles=("Human", "Assistant"),
        sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE,
        sep="\n\n",
    )
)
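# With this template, prompts render roughly as
#   "Human: {query}\n\nAssistant: {response}"
# (my reading of fastchat's ADD_COLON_SPACE_SINGLE style: "{role}: {message}" joined by sep;
# this matches the prefix_user / prefix_bot shown in the model-card snippet above).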


# Custom pipeline: unlike the other reward models, this model's forward pass
# returns the reward directly.
class ZiyaPipeline:
    def __init__(self, task, model, tokenizer):
        self.task = task
        self.model = model.eval().half().cuda()  # fp16 on GPU, following the model card
        self.tokenizer = tokenizer

    def __call__(self, query, **kwargs):
        _ = kwargs.get("batch_size", 1)  # accepted for pipeline-API compatibility, unused here
        truncation = kwargs.get("truncation", True)
        padding = kwargs.get("padding", True)
        max_length = kwargs.get("max_length", 2048)
        inputs = self.tokenizer(
            query,
            truncation=truncation,
            max_length=max_length,
            padding=padding,
            return_tensors="pt",
        ).to("cuda")
        with torch.no_grad():
            reward = self.model(**inputs)  # forward pass yields the reward directly
        return reward
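To see how the pieces fit together, here is a minimal usage sketch (hypothetical, mirroring the direct_load path in scripts/run_rm.py; the prompt text is the model-card example formatted with the Ziya template registered above):

```python
# Hypothetical usage sketch; mirrors the direct_load path in scripts/run_rm.py.
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from herm.models.ziya import ZiyaPipeline

model_id = "IDEA-CCNL/Ziya-LLaMA-7B-Reward"
model = AutoModelForSequenceClassification.from_pretrained(model_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

pipe = ZiyaPipeline("text-classification", model, tokenizer)
# Prompt formatted with the Ziya template; text abridged from the model-card example.
prompt = "Human: 列举一种空气污染。\n\nAssistant: 一种常见的空气污染源是化石燃料的燃烧产生的尾气排放。"
reward = pipe(prompt, max_length=2048)  # tensor holding the scalar reward
```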
20 changes: 17 additions & 3 deletions scripts/run_rm.py
@@ -54,7 +54,12 @@ def get_args():
"--tokenizer", type=str, default=None, help="path to non-matching tokenizer, requires --direct_load"
)
parser.add_argument("--chat_template", type=str, default="tulu", help="path to chat template")
parser.add_argument("--direct_load", action="store_true", help="directly load model instead of pipeline")
parser.add_argument(
"--direct_load", action="store_true", default=False, help="directly load model instead of pipeline"
)
parser.add_argument(
"--trust_remote_code", action="store_true", default=False, help="directly load model instead of pipeline"
)
parser.add_argument("--do_not_save", action="store_true", help="do not save results to hub (for debugging)")
parser.add_argument("--batch_size", type=int, default=64, help="batch size for inference")
parser.add_argument(
@@ -102,10 +107,18 @@ def main():

         model_builder = LlamaForScore.from_pretrained
         pipeline_builder = BeaverPipeline
+    elif "Ziya" in args.model or "Ziya" in args.chat_template:
+        from herm.models.ziya import ZiyaPipeline
+
+        model_builder = AutoModelForSequenceClassification.from_pretrained
+        pipeline_builder = ZiyaPipeline
+        quantized = False  # quantization is handled by .half() in the custom pipeline, as in the model card
     else:
         model_builder = AutoModelForSequenceClassification.from_pretrained
         pipeline_builder = pipeline
 
+    trust_remote_code = args.trust_remote_code  # forwarded to both the direct load and the HF pipeline
+
     ###############
     # Setup logging
     ###############
@@ -168,8 +181,8 @@ def main():
model_kwargs = {"device_map": {"": current_device}}
# TODO remove direct load logic
# if pipeline_builder is pipeline, use built in pipeline, else custom
if args.direct_load:
model = model_builder(args.model, **model_kwargs)
if args.direct_load or not pipeline_builder == pipeline:
model = model_builder(args.model, **model_kwargs, trust_remote_code=trust_remote_code)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
reward_pipe = pipeline_builder(
"text-classification",
@@ -183,6 +196,7 @@
             tokenizer=tokenizer,
             revision="main",
             model_kwargs=model_kwargs,
+            trust_remote_code=trust_remote_code,
         )
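Taken together, the load logic after this change dispatches roughly as follows (a paraphrase stitched from the hunks above; elided arguments are assumptions, not a verbatim excerpt):

```python
# Paraphrased sketch of scripts/run_rm.py's load logic; elided details assumed.
if args.direct_load or pipeline_builder is not pipeline:
    # custom pipelines (BeaverPipeline, ZiyaPipeline) wrap a directly loaded model
    model = model_builder(args.model, **model_kwargs, trust_remote_code=trust_remote_code)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    reward_pipe = pipeline_builder("text-classification", model=model, tokenizer=tokenizer)
else:
    # stock transformers pipeline; it loads the model itself
    reward_pipe = pipeline(
        "text-classification",
        model=args.model,  # assumed; this argument sits outside the visible hunk
        tokenizer=tokenizer,
        revision="main",
        model_kwargs=model_kwargs,
        trust_remote_code=trust_remote_code,
    )
```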

############################
