Feat (brevitas_examples/llm): zero shot eval
Giuseppe5 committed Dec 20, 2024
1 parent 09235be commit ba347a9
Showing 3 changed files with 64 additions and 0 deletions.
1 change: 1 addition & 0 deletions requirements/requirements-llm.txt
@@ -1,5 +1,6 @@
accelerate
datasets
lm_eval
onnx
onnxruntime
optimum
7 changes: 7 additions & 0 deletions src/brevitas_examples/llm/config/default_template.yml
@@ -60,3 +60,10 @@ weight_quant_format: int
weight_quant_granularity: per_group
weight_quant_type: sym
weight_scale_precision: float_scale
zero_shot_compile: false
zero_shot_eval: false
zero_shot_tasks:
- arc_challenge
- arc_easy
- winogrande
- piqa
56 changes: 56 additions & 0 deletions src/brevitas_examples/llm/main.py
@@ -7,6 +7,8 @@
import sys
from warnings import warn

from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM
import numpy as np
from optimum.exporters.onnx import onnx_export_from_model
import torch
@@ -52,6 +54,23 @@
from brevitas_examples.llm.llm_quant.run_utils import get_fx


def filter_results(results, tasks):
    # filter out what we actually want to track in azureml
    eval_results = dict()
    for task_name in tasks:
        # first, log n_shots for each task
        # for subtask, n_shots in results["n-shot"].items():
        #     name = f"{subtask}_n_shot"
        #     eval_results[name] = float(n_shots)
        # then log all result metrics we have for this task
        for key, val in results["results"][task_name].items():
            if not isinstance(val, str):
                # for mmlu, we don't log results per subtask, but simply overall results
                name = f"{task_name}_{key}"
                eval_results[name] = val
    return eval_results


def set_seed(seed):
    np.random.seed(seed)
    torch.random.manual_seed(seed)
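
For reference, filter_results flattens the nested dictionary returned by lm_eval into flat {task}_{metric} keys and drops string-valued entries such as task aliases. The following is a minimal sketch of that behaviour with an illustrative results dictionary; the metric names and numbers are assumptions for the example, not output produced by this commit.

# Illustrative shape of an lm_eval results dict (values are made up for this sketch).
example_results = {
    "results": {
        "arc_easy": {"alias": "arc_easy", "acc,none": 0.75, "acc_norm,none": 0.70},
        "piqa": {"alias": "piqa", "acc,none": 0.78},
    },
    "n-shot": {"arc_easy": 0, "piqa": 0},
}

filtered = filter_results(example_results, ["arc_easy", "piqa"])
print(filtered)
# {'arc_easy_acc,none': 0.75, 'arc_easy_acc_norm,none': 0.70, 'piqa_acc,none': 0.78}
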
@@ -463,6 +482,28 @@ def main(args):
        quant_ppl = compute_perplexity(
            model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
        print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")

    if args.zero_shot_eval:
        with torch.no_grad(), quant_inference_mode(model):
            model(**calibration_loader[0])
            if args.zero_shot_compile:
                remove_hooks(model)
                model.cuda()
                model = torch.compile(model)

            wrapped_model = HFLM(pretrained=model)  # need to wrap for LLM eval
            results = evaluator.simple_evaluate(
                model=wrapped_model,
                model_args=None,
                tasks=list(args.zero_shot_tasks),
                device='cuda:0',
                limit=None,
                num_fewshot=0,
                log_samples=False,
                batch_size="auto:3",
                verbosity="ERROR")
            results = filter_results(results, args.zero_shot_tasks)
            print(results)
        remove_hooks(model)

    if args.checkpoint_name is not None:
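
The block above drives lm_eval programmatically: the model (still under quant_inference_mode) is wrapped in HFLM so the harness can run it, and evaluator.simple_evaluate executes the selected zero-shot tasks before filter_results flattens the metrics. Below is a minimal standalone sketch of the same API against a plain Hugging Face checkpoint; the model name, task and batch size are illustrative choices, not part of this commit.

# Standalone sketch of the lm_eval API used above; names and values are illustrative.
from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "facebook/opt-125m"  # assumption: any small causal LM works for this sketch
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

wrapped = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=8)  # wrap so the harness can drive the model
results = evaluator.simple_evaluate(
    model=wrapped,
    tasks=["arc_easy"],
    num_fewshot=0)
print(results["results"]["arc_easy"])
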
@@ -783,7 +824,22 @@ def parse_args(args, override_defaults={}):
        default=False,
        action="store_true",
        help='Whether to use fast update with learned round. Prototype (default: %(default)s)')
    parser.add_argument(
        '--zero-shot-eval',
        action="store_true",
        help='Perform zero-shot evaluation with lm_eval. Default: %(default)s')
    parser.add_argument(
        '--zero-shot-compile',
        action="store_true",
        help='Compile during zero-shot evaluation with lm_eval. Default: %(default)s')
    parser.add_argument(
        '--zero-shot-tasks',
        default=['arc_challenge', 'arc_easy', 'winogrande', 'piqa'],
        type=str,
        nargs='*',
        help='A list of tasks for zero-shot evaluation. Default: %(default)s')
    parser.set_defaults(**override_defaults)

    return parser.parse_args(args)
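
For completeness, the new options can also be exercised programmatically through parse_args, either with CLI-style flags or via override_defaults, which feeds parser.set_defaults. The sketch below assumes no other arguments are required for parsing to succeed; the task selection is illustrative.

# Sketch: enabling zero-shot evaluation programmatically (values are illustrative).
args = parse_args(["--zero-shot-eval", "--zero-shot-tasks", "arc_easy", "piqa"])
assert args.zero_shot_eval and args.zero_shot_tasks == ["arc_easy", "piqa"]

# Equivalent via override_defaults (assumption: this is the hook the YAML config uses).
args = parse_args([], override_defaults={"zero_shot_eval": True, "zero_shot_tasks": ["winogrande"]})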


