clean up handling
natolambert committed Sep 30, 2024
1 parent 23b1d36 commit e1cf377
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions rewardbench/rewardbench.py
@@ -19,7 +19,7 @@
 import os
 import sys
 from dataclasses import dataclass
-from typing import Optional
+from typing import Literal, Optional
 
 import numpy as np
 import torch
@@ -36,6 +36,7 @@
     REWARD_MODEL_CONFIG,
     check_tokenizer_chat_template,
     load_preference_dataset,
+    torch_dtype_mapping,
 )
 
 
@@ -70,6 +71,10 @@ class Args:
     """The batch size to use."""
     max_length: int = 512
     """The max length to use."""
+    torch_dtype: Literal["float16", "bfloat16", "float32", "float64"] = "float16"
+    """PyTorch dtype (default: float16)"""
+    attn_implementation: Optional[Literal["eager", "sdpa", "flash_attention_2"]] = None
+    """Attention implementation to use (default: None)"""
 
     # system args
     load_json: bool = False
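
The two new fields become part of the script's argument surface; they can also be set when constructing Args directly. A minimal sketch of the latter, assuming the flag spelling follows the field names (other Args fields are elided, and "my-reward-model" is a hypothetical placeholder):

    # Sketch: setting the new options programmatically; any Args fields
    # not shown in this diff are omitted here (assumption).
    args = Args(
        model="my-reward-model",
        torch_dtype="bfloat16",  # one of: float16, bfloat16, float32, float64
        attn_implementation="flash_attention_2",  # or "eager" / "sdpa" / None
    )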
@@ -172,6 +177,16 @@ def actual_main(args: Args):
 
     model_builder = config["model_builder"]
 
+    # Handle datatype
+    args.torch_dtype = torch_dtype_mapping(args.torch_dtype)
+    # if the config does not set a dtype (the default), fall back to the CLI arg
+    if torch_dtype is None:
+        # bfloat16 is incompatible with bitsandbytes 8-bit quantization, so turn it off
+        if args.torch_dtype == torch.bfloat16:
+            quantized = False
+            logger.info("Disabling quantization for bfloat16 datatype")
+        torch_dtype = args.torch_dtype
+
     #########################
     # load dataset
     #########################
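
torch_dtype_mapping is imported above but its body is not part of this diff. A minimal sketch of what such a string-to-dtype helper plausibly looks like, given the Literal choices on the Args field (the actual implementation in rewardbench's utils may differ):

    # Sketch of a string -> torch.dtype helper; real implementation not
    # shown in this diff (assumption).
    import torch

    def torch_dtype_mapping(dtype_str: str) -> torch.dtype:
        mapping = {
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "float32": torch.float32,
            "float64": torch.float64,
        }
        if dtype_str not in mapping:
            raise ValueError(f"Unsupported dtype: {dtype_str}")
        return mapping[dtype_str]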
@@ -278,14 +293,10 @@ def actual_main(args: Args):
         "return_token_type_ids": False,
     }
     if quantized:
-        if torch_dtype is not None:
-            torch_dtype = torch_dtype
-        else:
-            torch_dtype = torch.float16
         model_kwargs = {
             "load_in_8bit": True,
             "device_map": {"": current_device},
-            "torch_dtype": torch.float16 if torch.cuda.is_available() else None,
+            "torch_dtype": torch_dtype if torch.cuda.is_available() else None,
         }
     else:
         # note, device map auto does not work for bitsandbytes quantized models
@@ -294,6 +305,11 @@ def actual_main(args: Args):
             "torch_dtype": torch_dtype,
         }
 
+    # if attn_implementation is not specified, this falls back to Hugging Face's default
+    # strategy (which chooses between sdpa and eager depending on the PyTorch version)
+    if args.attn_implementation:
+        model_kwargs["attn_implementation"] = args.attn_implementation
+
     model = model_builder(
         args.model, **model_kwargs, revision=args.revision, trust_remote_code=args.trust_remote_code
     )
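
For reference, model_builder typically wraps a Hugging Face from_pretrained call, so the assembled kwargs land roughly as below. This is a sketch: the concrete builder comes from REWARD_MODEL_CONFIG per model, the class name here is an assumption, and the checkpoint name is a placeholder.

    # Sketch of where model_kwargs end up; the actual builder class varies
    # per REWARD_MODEL_CONFIG entry (assumption).
    import torch
    from transformers import AutoModelForSequenceClassification

    model = AutoModelForSequenceClassification.from_pretrained(
        "my-reward-model",            # hypothetical checkpoint name
        torch_dtype=torch.bfloat16,   # from args.torch_dtype via torch_dtype_mapping
        attn_implementation="sdpa",   # only set when args.attn_implementation is given
        device_map={"": 0},
    )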