Feat (brevitas_examples/llm): zero shot eval
Giuseppe5 committed Dec 18, 2024
1 parent 3612e90 commit 89609d8
Showing 2 changed files with 57 additions and 0 deletions.
1 change: 1 addition & 0 deletions requirements/requirements-llm.txt
@@ -1,5 +1,6 @@
accelerate
datasets
lm_eval
onnx
onnxruntime
optimum
56 changes: 56 additions & 0 deletions src/brevitas_examples/llm/main.py
@@ -7,6 +7,8 @@
import sys
from warnings import warn

from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM
import numpy as np
from optimum.exporters.onnx import onnx_export_from_model
import torch
@@ -51,6 +53,23 @@
from brevitas_examples.llm.llm_quant.run_utils import get_fx


def filter_results(results, tasks):
    # filter out what we actually want to track in azureml
    eval_results = dict()
    for task_name in tasks:
        # first, log n_shots for each task
        # for subtask, n_shots in results["n-shot"].items():
        #     name = f"{subtask}_n_shot"
        #     eval_results[name] = float(n_shots)
        # then log all result metrics we have for this task
        for key, val in results["results"][task_name].items():
            if not isinstance(val, str):
                # for mmlu, we don't log results per subtask, but simply overall results
                name = f"{task_name}_{key}"
                eval_results[name] = val
    return eval_results


def set_seed(seed):
    np.random.seed(seed)
    torch.random.manual_seed(seed)
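
For reference, a small sketch of what the new filter_results helper produces. The input below is a hand-written stand-in for the dictionary returned by lm_eval's simple_evaluate; the exact metric keys (e.g. "acc,none") depend on the installed lm_eval version and the tasks run, so the values are purely illustrative. String-valued entries such as the task alias are skipped, and numeric metrics are flattened into "{task}_{metric}" keys.

# Illustrative input: a made-up lm_eval-style results dict.
sample_results = {
    "results": {
        "arc_easy": {"alias": "arc_easy", "acc,none": 0.71, "acc_stderr,none": 0.01},
        "piqa": {"alias": "piqa", "acc,none": 0.68, "acc_norm,none": 0.69},
    }
}

print(filter_results(sample_results, ["arc_easy", "piqa"]))
# {'arc_easy_acc,none': 0.71, 'arc_easy_acc_stderr,none': 0.01,
#  'piqa_acc,none': 0.68, 'piqa_acc_norm,none': 0.69}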
@@ -462,6 +481,28 @@ def main(args):
        quant_ppl = compute_perplexity(
            model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
        print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")

    if args.zero_shot_eval:
        with torch.no_grad(), quant_inference_mode(model):
            model(**calibration_loader[0])
            if args.zero_shot_compile:
                remove_hooks(model)
                model.cuda()
                model = torch.compile(model)

            wrapped_model = HFLM(pretrained=model)  # need to wrap for LLM eval
            results = evaluator.simple_evaluate(
                model=wrapped_model,
                model_args=None,
                tasks=list(args.zero_shot_tasks),
                device='cuda:0',
                limit=None,
                num_fewshot=0,
                log_samples=False,
                batch_size="auto:3",
                verbosity="ERROR")
            results = filter_results(results, args.zero_shot_tasks)
            print(results)
        remove_hooks(model)

    if args.checkpoint_name is not None:
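
For context, the same lm_eval call pattern can be exercised on a plain (unquantized) HuggingFace model, which is a convenient way to sanity-check the harness independently of Brevitas. This is a self-contained sketch, not part of the commit: the model name is a placeholder and the metric keys in the returned dict depend on the installed lm_eval version.

from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# Wrap the HuggingFace model so lm_eval can drive it.
wrapped_model = HFLM(pretrained=model, tokenizer=tokenizer)
results = evaluator.simple_evaluate(
    model=wrapped_model,
    tasks=["arc_easy", "piqa"],
    num_fewshot=0,
    batch_size=8)
print(results["results"])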
@@ -757,6 +798,21 @@ def parse_args(args):
        default=False,
        action="store_true",
        help='Whether to use fast update with learned round. Prototype (default: %(default)s)')
    parser.add_argument(
        '--zero-shot-eval',
        action="store_true",
        help='Perform zero-shot evaluation with lm_eval (default: %(default)s)')
    parser.add_argument(
        '--zero-shot-compile',
        action="store_true",
        help='Compile the model during zero-shot evaluation with lm_eval (default: %(default)s)')
    parser.add_argument(
        '--zero-shot-tasks',
        default=['arc_challenge', 'arc_easy', 'winogrande', 'piqa'],
        type=str,
        nargs='*',
        help='A list of tasks for zero-shot evaluation. Default: %(default)s')

    return parser.parse_args(args)


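
A small standalone sketch (hypothetical parser, not part of the commit) of how nargs='*' interacts with a list default, which is the pattern the new --zero-shot-tasks option relies on.

import argparse

# Minimal parser mirroring only the --zero-shot-tasks option.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--zero-shot-tasks',
    default=['arc_challenge', 'arc_easy', 'winogrande', 'piqa'],
    type=str,
    nargs='*')

# Omitting the flag keeps the default list.
print(parser.parse_args([]).zero_shot_tasks)
# ['arc_challenge', 'arc_easy', 'winogrande', 'piqa']

# Passing values replaces the whole list (space-separated, not comma-separated).
print(parser.parse_args(['--zero-shot-tasks', 'arc_easy', 'piqa']).zero_shot_tasks)
# ['arc_easy', 'piqa']

# Passing the flag with no values yields an empty list, i.e. no zero-shot tasks.
print(parser.parse_args(['--zero-shot-tasks']).zero_shot_tasks)
# []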