add handling of bfloat16 models
natolambert committed Sep 30, 2024
1 parent fe5caf3 commit 23b1d36
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions rewardbench/rewardbench.py
@@ -166,6 +166,7 @@ def actual_main(args: Args):
    custom_dialogue = config["custom_dialogue"]
    pipeline_builder = config["pipeline_builder"]
    _ = config["model_type"]
+    torch_dtype = config.get("torch_dtype", None)
    if custom_dialogue:
        raise NotImplementedError("Custom dialogue not implemented yet for simpler data formatting.")
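The added line pulls an optional torch_dtype out of the per-model config, staying None when a model entry does not specify one. Below is a minimal sketch of what such an entry might look like; only the keys read in this hunk (custom_dialogue, pipeline_builder, model_type, torch_dtype) come from the diff, while the dict name and the remaining fields are illustrative assumptions.

import torch

# Hypothetical per-model config entry; the dict name and extra keys are
# assumptions for illustration, not taken from the repository.
EXAMPLE_MODEL_CONFIG = {
    "custom_dialogue": False,
    "pipeline_builder": None,   # pipeline class used to score chosen/rejected pairs
    "model_type": "Seq. Classifier",
    "quantized": False,
    "torch_dtype": torch.bfloat16,  # the new optional key this commit reads
}

config = EXAMPLE_MODEL_CONFIG
torch_dtype = config.get("torch_dtype", None)  # stays None if the key is absent
print(torch_dtype)  # torch.bfloat16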

@@ -277,14 +278,21 @@ def actual_main(args: Args):
        "return_token_type_ids": False,
    }
    if quantized:
+        if torch_dtype is not None:
+            torch_dtype = torch_dtype
+        else:
+            torch_dtype = torch.float16
        model_kwargs = {
            "load_in_8bit": True,
            "device_map": {"": current_device},
            "torch_dtype": torch.float16 if torch.cuda.is_available() else None,
        }
    else:
-        # note, device map auto does not work for quantized models
-        model_kwargs = {"device_map": "auto"}
+        # note, device map auto does not work for bitsandbytes quantized models
+        model_kwargs = {
+            "device_map": "auto",
+            "torch_dtype": torch_dtype,
+        }

    model = model_builder(
        args.model, **model_kwargs, revision=args.revision, trust_remote_code=args.trust_remote_code
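With this change, a config-specified dtype such as torch.bfloat16 is forwarded to the model loader through model_kwargs for non-quantized models. The short sketch below illustrates the effect of the torch_dtype keyword in transformers' from_pretrained; the checkpoint name and num_labels value are placeholders, not taken from the commit.

import torch
from transformers import AutoModelForSequenceClassification

# Illustration of the torch_dtype kwarg forwarded via model_kwargs above;
# the checkpoint and label count are placeholders for the sketch.
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=1,
    torch_dtype=torch.bfloat16,
)
print(next(model.parameters()).dtype)  # torch.bfloat16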
@@ -341,8 +349,8 @@ def actual_main(args: Args):
            score_rejected_batch = [result["score"] for result in rewards_rejected]
        # for classes that directly output scores (custom code)
        else:
-            score_chosen_batch = rewards_chosen.cpu().numpy().tolist()
-            score_rejected_batch = rewards_rejected.cpu().numpy().tolist()
+            score_chosen_batch = rewards_chosen.float().cpu().numpy().tolist()
+            score_rejected_batch = rewards_rejected.float().cpu().numpy().tolist()

        # log results
        [
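The .float() calls added in the last hunk are what make bfloat16 models work end to end: NumPy has no bfloat16 dtype, so calling .numpy() on a bfloat16 reward tensor raises a TypeError, while upcasting to float32 first converts cleanly. A minimal reproduction, assuming rewards_chosen is a plain bfloat16 tensor of scores:

import torch

rewards_chosen = torch.tensor([0.25, -1.5], dtype=torch.bfloat16)

# rewards_chosen.cpu().numpy() would raise a TypeError here, because NumPy
# has no bfloat16 scalar type.

# Upcasting to float32 first, as the commit does, works:
score_chosen_batch = rewards_chosen.float().cpu().numpy().tolist()
print(score_chosen_batch)  # [0.25, -1.5]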
