Skip to content

Commit

Permalink
address comment
Browse files Browse the repository at this point in the history
  • Loading branch information
vwxyzjn committed Sep 25, 2024
1 parent 7f501f6 commit ba94c43
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
31 changes: 19 additions & 12 deletions rewardbench/rewardbench.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
import os
import sys
from dataclasses import dataclass
from typing import Optional, List
from typing import Optional

import numpy as np
import wandb
import torch
import transformers
import wandb
from accelerate import Accelerator
from accelerate.logging import get_logger
from tqdm import tqdm
Expand All @@ -38,7 +38,6 @@
)



@dataclass
class Args:
# core args
Expand All @@ -58,8 +57,10 @@ class Args:
"""The chat template to use (defaults to from tokenizer, from chattemplate)."""
not_quantized: bool = False
"""Disable quantization for models that are quantized by default."""

# wandb args
wandb_run: Optional[str] = None
"""The split to evaluate on."""
"""The wandb run to extract model and revision from."""

# inference args
batch_size: int = 8
Expand All @@ -82,10 +83,13 @@ class Args:
"""Force truncation (for if model errors)."""


def main(args: Args):
if args.wandb_run is not None:
import wandb
def main():
parser = HfArgumentParser((Args))
actual_main(*parser.parse_args_into_dataclasses())


def actual_main(args: Args):
if args.wandb_run is not None:
wandb_run = wandb.Api().run(args.wandb_run)
args.model = wandb_run.config["hf_repo_id"]
args.revision = wandb_run.config["hf_repo_revision"]
Expand Down Expand Up @@ -169,7 +173,9 @@ def main(args: Args):
#########################
logger.info("*** Load dataset ***")
tokenizer_path = args.tokenizer if args.tokenizer else args.model
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=args.trust_remote_code, revision=args.revision)
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, trust_remote_code=args.trust_remote_code, revision=args.revision
)
if args.dataset == "allenai/reward-bench":
logger.info("Running core eval dataset.")
from rewardbench import load_eval_dataset
Expand Down Expand Up @@ -277,7 +283,9 @@ def main(args: Args):
# note, device map auto does not work for quantized models
model_kwargs = {"device_map": "auto"}

model = model_builder(args.model, **model_kwargs, revision=args.revision, trust_remote_code=args.trust_remote_code)
model = model_builder(
args.model, **model_kwargs, revision=args.revision, trust_remote_code=args.trust_remote_code
)
reward_pipe = pipeline_builder(
"text-classification", # often not used
model=model,
Expand Down Expand Up @@ -395,7 +403,7 @@ def main(args: Args):
}
with open(output_path, "w") as f:
json.dump(final_results, f)

if args.wandb_run is not None:
for key in final_results:
wandb_run.summary[f"rewardbench/{key}"] = final_results[key]
Expand All @@ -418,5 +426,4 @@ def main(args: Args):


if __name__ == "__main__":
parser = HfArgumentParser((Args))
main(*parser.parse_args_into_dataclasses())
main()
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
"tiktoken==0.6.0", # added for llama 3
"transformers==4.43.4", # pinned at llama 3
"trl>=0.8.2", # fixed transformers import error, for DPO
"wandb", # for loading model path / revisions from wandb
],
extras_require={
"generative": [
Expand Down

0 comments on commit ba94c43

Please sign in to comment.