diff --git a/rewardbench/models/__init__.py b/rewardbench/models/__init__.py
index e44be10..9c9bcf2 100644
--- a/rewardbench/models/__init__.py
+++ b/rewardbench/models/__init__.py
@@ -152,8 +152,8 @@
         "model_builder": AutoModelForSequenceClassification.from_pretrained,
         "pipeline_builder": ArmoRMPipeline,
         "quantized": False,
-        "custom_dialogue": True,
-        "model_type": "Custom Classifier",
+        "custom_dialogue": False,
+        "model_type": "Sequence Classifier",
         "torch_dtype": torch.bfloat16,
     },
     "Ray2333/GRM-Gemma-2B-sftreg": {
diff --git a/rewardbench/models/armorm.py b/rewardbench/models/armorm.py
index 94c9242..18436b3 100644
--- a/rewardbench/models/armorm.py
+++ b/rewardbench/models/armorm.py
@@ -5,6 +5,35 @@
 
 
 class ArmoRMPipeline:
+    def __init__(self, task, model, tokenizer):
+        self.task = task
+        self.model = model.eval()
+        self.tokenizer = tokenizer
+
+    def __call__(self, samples, return_inputs=False, **kwargs):
+        _ = kwargs.get("batch_size", 1)
+        truncation = kwargs.get("truncation", True)
+        padding = kwargs.get("padding", True)
+        max_length = kwargs.get("max_length", 2048)
+        inputs = self.tokenizer(
+            samples,
+            truncation=truncation,
+            max_length=max_length,
+            padding=padding,
+            # return_special_tokens_mask=True,
+            return_tensors="pt",
+        ).to("cuda")
+
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+        if return_inputs:
+            return outputs.logits, inputs
+        else:
+            return outputs.logits
+
+
+# Moved to newer implementation that doesn't require "Custom Dialogue" tag
+class LegacyArmoRMPipeline:
     def __init__(self, task, model, tokenizer):
         self.task = task
         self.model = model
diff --git a/rewardbench/rewardbench.py b/rewardbench/rewardbench.py
index 6b52223..b480891 100644
--- a/rewardbench/rewardbench.py
+++ b/rewardbench/rewardbench.py
@@ -21,7 +21,7 @@
 import time
 from dataclasses import dataclass
 from pprint import pformat
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Literal, Optional, Union
 
 import numpy as np
 import pkg_resources
@@ -41,6 +41,7 @@
     REWARD_MODEL_CONFIG,
     check_tokenizer_chat_template,
     load_and_process_dataset,
+    torch_dtype_mapping,
 )
 
 
@@ -85,6 +85,10 @@ class Args:
     """The batch size to use."""
     max_length: int = 512
     """The max length to use."""
+    torch_dtype: Literal["float16", "bfloat16", "float32", "float64"] = "float16"
+    """PyTorch dtype (default: float16)"""
+    attn_implementation: Optional[Literal["eager", "sdpa", "flash_attention_2"]] = None
+    """Attention implementation to use (default: None)"""
 
     # system args
     load_json: bool = False
@@ -274,11 +279,22 @@ def rewardbench(args: Args):
     custom_dialogue = config["custom_dialogue"]
     pipeline_builder = config["pipeline_builder"]
     _ = config["model_type"]
+    torch_dtype = config.get("torch_dtype", None)
     if custom_dialogue:
         raise NotImplementedError("Custom dialogue not implemented yet for simpler data formatting.")
 
     model_builder = config["model_builder"]
 
+    # Handle datatype
+    args.torch_dtype = torch_dtype_mapping(args.torch_dtype)
+    # if no datatype in config (default), check args
+    if torch_dtype is None:
+        # if datatype is bfloat16, then manually turn off quantization (done with bitsandbytes)
+        if args.torch_dtype == torch.bfloat16:
+            quantized = False
+            logger.info("Disabling quantization for bfloat16 datatype")
+        torch_dtype = args.torch_dtype
+
     #########################
     # load dataset
     #########################
@@ -344,7 +360,7 @@ def rewardbench(args: Args):
 
         model_kwargs = {
             "load_in_8bit": True,
-            "device_map": "auto",
+            "device_map": "auto" if torch.cuda.is_available() else "cpu",
"torch_dtype": torch.float16 if torch.cuda.is_available() else None, } model = model_builder( @@ -408,11 +424,19 @@ def rewardbench(args: Args): model_kwargs = { "load_in_8bit": True, "device_map": {"": current_device}, - "torch_dtype": torch.float16 if torch.cuda.is_available() else None, + "torch_dtype": torch_dtype if torch.cuda.is_available() else None, } else: - # note, device map auto does not work for quantized models - model_kwargs = {"device_map": "auto"} + # note, device map auto does not work for bitsandbytes quantized models + model_kwargs = { + "device_map": "auto", + "torch_dtype": torch_dtype, + } + + # if attn_implementation is not specified, this falls back to Hugging Face's default + # strategy (which chooses between sdpa and eager depending on pytorch version) + if args.attn_implementation: + model_kwargs["attn_implementation"] = args.attn_implementation model = model_builder( args.model, **model_kwargs, revision=args.revision, trust_remote_code=args.trust_remote_code @@ -472,8 +496,8 @@ def rewardbench(args: Args): score_rejected_batch = [result["score"] for result in rewards_rejected] # for classes that directly output scores (custom code) else: - score_chosen_batch = rewards_chosen.cpu().numpy().tolist() - score_rejected_batch = rewards_rejected.cpu().numpy().tolist() + score_chosen_batch = rewards_chosen.float().cpu().numpy().tolist() + score_rejected_batch = rewards_rejected.float().cpu().numpy().tolist() # log results [ diff --git a/scripts/run_rm.py b/scripts/run_rm.py index f2e6b1d..65d4223 100644 --- a/scripts/run_rm.py +++ b/scripts/run_rm.py @@ -153,6 +153,7 @@ def main(): model_builder = config["model_builder"] pipeline_builder = config["pipeline_builder"] torch_dtype = config.get("torch_dtype", None) + # if not datatype in config (default), check args if torch_dtype is None: # if datatype is bfloat16, then manually turn off quantizaiton (done with bitsandbytes) @@ -211,7 +212,7 @@ def main(): } else: model_kwargs = { - "device_map": "auto", + "device_map": "auto" if torch.cuda.is_available() else "cpu", "torch_dtype": torch_dtype, }