Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tweak ArmorRM implementation, add args to CLI #194

Merged
merged 4 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions rewardbench/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,8 @@
"model_builder": AutoModelForSequenceClassification.from_pretrained,
"pipeline_builder": ArmoRMPipeline,
"quantized": False,
"custom_dialogue": True,
"model_type": "Custom Classifier",
"custom_dialogue": False,
"model_type": "Sequence Classifier",
"torch_dtype": torch.bfloat16,
},
"Ray2333/GRM-Gemma-2B-sftreg": {
Expand Down
29 changes: 29 additions & 0 deletions rewardbench/models/armorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,35 @@


class ArmoRMPipeline:
    """Minimal sequence-classification pipeline for ArmoRM-style reward models.

    Tokenizes raw text samples and returns the model's logits directly
    (no softmax or other post-processing), so callers read reward scores
    straight off the logits tensor.
    """

    def __init__(self, task, model, tokenizer):
        # task: pipeline task name (kept for interface parity with HF pipelines)
        self.task = task
        # eval() disables dropout etc. for deterministic scoring
        self.model = model.eval()
        self.tokenizer = tokenizer

    def __call__(self, samples, return_inputs=False, **kwargs):
        """Score a batch of text samples.

        samples: list of strings (or whatever the tokenizer accepts).
        return_inputs: when True, also return the tokenized inputs.
        kwargs: truncation / padding / max_length forwarded to the tokenizer;
            batch_size is accepted for API compatibility but unused here.
        Returns the raw logits tensor (and optionally the tokenized inputs).
        """
        _ = kwargs.get("batch_size", 1)
        truncation = kwargs.get("truncation", True)
        padding = kwargs.get("padding", True)
        max_length = kwargs.get("max_length", 2048)
        inputs = self.tokenizer(
            samples,
            truncation=truncation,
            max_length=max_length,
            padding=padding,
            # return_special_tokens_mask=True,
            return_tensors="pt",
        ).to("cuda" if torch.cuda.is_available() else "cpu")  # fix: don't hard-require a GPU

        with torch.no_grad():
            outputs = self.model(**inputs)
        if return_inputs:
            return outputs.logits, inputs
        return outputs.logits


# Moved to newer implementation that doesn't require "Custom Dialogue" tag
class LegacyArmoRMPipeline:
def __init__(self, task, model, tokenizer):
self.task = task
self.model = model
Expand Down
38 changes: 31 additions & 7 deletions rewardbench/rewardbench.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import os
import sys
from dataclasses import dataclass
from typing import Optional
from typing import Literal, Optional

import numpy as np
import torch
Expand All @@ -36,6 +36,7 @@
REWARD_MODEL_CONFIG,
check_tokenizer_chat_template,
load_preference_dataset,
torch_dtype_mapping,
)


Expand Down Expand Up @@ -70,6 +71,10 @@ class Args:
"""The batch size to use."""
max_length: int = 512
"""The max length to use."""
torch_dtype: Literal["float16", "bfloat16", "float32", "float64"] = "float16"
"""PyTorch dtype (default: float16)"""
attn_implementation: Optional[Literal["eager", "sdpa", "flash_attention_2"]] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a nice add!

"""Attention implementation to use (default: None)"""

# system args
load_json: bool = False
Expand Down Expand Up @@ -166,11 +171,22 @@ def actual_main(args: Args):
custom_dialogue = config["custom_dialogue"]
pipeline_builder = config["pipeline_builder"]
_ = config["model_type"]
torch_dtype = config.get("torch_dtype", None)
if custom_dialogue:
raise NotImplementedError("Custom dialogue not implemented yet for simpler data formatting.")

model_builder = config["model_builder"]

# Handle datatype
args.torch_dtype = torch_dtype_mapping(args.torch_dtype)
# if not datatype in config (default), check args
if torch_dtype is None:
# if datatype is bfloat16, then manually turn off quantization (done with bitsandbytes)
if args.torch_dtype == torch.bfloat16:
quantized = False
logger.info("Disabling quantization for bfloat16 datatype")
torch_dtype = args.torch_dtype

#########################
# load dataset
#########################
Expand Down Expand Up @@ -216,7 +232,7 @@ def actual_main(args: Args):

model_kwargs = {
"load_in_8bit": True,
"device_map": "auto",
"device_map": "auto" if torch.cuda.is_available() else "cpu",
"torch_dtype": torch.float16 if torch.cuda.is_available() else None,
}
model = model_builder(
Expand Down Expand Up @@ -280,11 +296,19 @@ def actual_main(args: Args):
model_kwargs = {
"load_in_8bit": True,
"device_map": {"": current_device},
"torch_dtype": torch.float16 if torch.cuda.is_available() else None,
"torch_dtype": torch_dtype if torch.cuda.is_available() else None,
}
else:
# note, device map auto does not work for quantized models
model_kwargs = {"device_map": "auto"}
# note, device map auto does not work for bitsandbytes quantized models
model_kwargs = {
"device_map": "auto",
"torch_dtype": torch_dtype,
}

# if attn_implementation is not specified, this falls back to Hugging Face's default
# strategy (which chooses between sdpa and eager depending on pytorch version)
if args.attn_implementation:
model_kwargs["attn_implementation"] = args.attn_implementation

model = model_builder(
args.model, **model_kwargs, revision=args.revision, trust_remote_code=args.trust_remote_code
Expand Down Expand Up @@ -341,8 +365,8 @@ def actual_main(args: Args):
score_rejected_batch = [result["score"] for result in rewards_rejected]
# for classes that directly output scores (custom code)
else:
score_chosen_batch = rewards_chosen.cpu().numpy().tolist()
score_rejected_batch = rewards_rejected.cpu().numpy().tolist()
score_chosen_batch = rewards_chosen.float().cpu().numpy().tolist()
score_rejected_batch = rewards_rejected.float().cpu().numpy().tolist()

# log results
[
Expand Down
3 changes: 2 additions & 1 deletion scripts/run_rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def main():
model_builder = config["model_builder"]
pipeline_builder = config["pipeline_builder"]
torch_dtype = config.get("torch_dtype", None)

# if not datatype in config (default), check args
if torch_dtype is None:
# if datatype is bfloat16, then manually turn off quantization (done with bitsandbytes)
Expand Down Expand Up @@ -211,7 +212,7 @@ def main():
}
else:
model_kwargs = {
"device_map": "auto",
"device_map": "auto" if torch.cuda.is_available() else "cpu",
"torch_dtype": torch_dtype,
}

Expand Down
Loading