Feat (brevitas_examples/llm): Eval harness for few-shot testing (#1131)
Giuseppe5 authored Dec 20, 2024
1 parent 09235be commit 39ce837
Showing 4 changed files with 87 additions and 1 deletion.
1 change: 1 addition & 0 deletions requirements/requirements-llm.txt
@@ -1,5 +1,6 @@
accelerate
datasets
lm_eval
onnx
onnxruntime
optimum
15 changes: 14 additions & 1 deletion src/brevitas_examples/llm/README.md
@@ -57,7 +57,10 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
[--export-prefix EXPORT_PREFIX]
[--checkpoint-name CHECKPOINT_NAME] [--fuse-sequences]
[--learned-round {None,linear_round}]
[--learned-round-fast-update]
[--learned-round-fast-update] [--few-shot-eval]
[--few-shot-compile] [--few-shot-zeroshot]
[--few-shot-limit FEW_SHOT_LIMIT]
[--few-shot-tasks [FEW_SHOT_TASKS ...]]

options:
-h, --help show this help message and exit
@@ -210,5 +213,15 @@ options:
--learned-round-fast-update
Whether to use fast update with learned round.
Prototype (default: False)
--few-shot-eval Perform zero/few-shot evaluation with lm_eval
(default: False)
--few-shot-compile Compile the model during lm_eval evaluation
(default: False)
--few-shot-zeroshot Run zero-shot rather than few-shot evaluation
(default: False)
--few-shot-limit FEW_SHOT_LIMIT
Limit the number of evaluated examples per task
(default: None)
--few-shot-tasks [FEW_SHOT_TASKS ...]
A list of tasks for evaluation. Default:
['arc_challenge', 'arc_easy', 'winogrande', 'piqa']

```
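As an illustration of the new flags (not part of this commit), the example can also be driven programmatically; the checkpoint name, task selection, and limit below are placeholders, not defaults introduced by this change:

```python
# Hedged sketch: drive the LLM example with the new few-shot flags from Python.
from brevitas_examples.llm.main import main, parse_args

args = parse_args([
    "--model", "facebook/opt-125m",          # placeholder HF checkpoint
    "--few-shot-eval",                       # run the lm_eval harness after quantization
    "--few-shot-tasks", "arc_easy", "piqa",  # subset of the default task list
    "--few-shot-limit", "64",                # cap examples per task for a quick run
])
main(args)
```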
9 changes: 9 additions & 0 deletions src/brevitas_examples/llm/config/default_template.yml
@@ -8,6 +8,15 @@ dataset: wikitext2
eval: false
export_prefix: null
export_target: null
few_shot_compile: false
few_shot_eval: false
few_shot_limit: null
few_shot_tasks:
- arc_challenge
- arc_easy
- winogrande
- piqa
few_shot_zeroshot: false
fuse_sequences: false
gpfq: false
gptq: false
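The new YAML keys mirror the CLI flags one-to-one. As a hedged sketch (the config-loading path itself is not shown in this diff), a template like the one above can be turned into parser defaults through parse_args' override_defaults parameter; the file path and PyYAML usage are illustrative only:

```python
# Assumes PyYAML is installed; the path points at the template added by this commit.
import yaml

from brevitas_examples.llm.main import parse_args

with open("src/brevitas_examples/llm/config/default_template.yml") as f:
    overrides = yaml.safe_load(f)

args = parse_args([], override_defaults=overrides)
print(args.few_shot_eval)   # False
print(args.few_shot_tasks)  # ['arc_challenge', 'arc_easy', 'winogrande', 'piqa']
```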
63 changes: 63 additions & 0 deletions src/brevitas_examples/llm/main.py
@@ -7,6 +7,8 @@
import sys
from warnings import warn

from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM
import numpy as np
from optimum.exporters.onnx import onnx_export_from_model
import torch
@@ -52,6 +54,23 @@
from brevitas_examples.llm.llm_quant.run_utils import get_fx


def filter_results(results, tasks):
    # filter out what we actually want to track in azureml
    eval_results = dict()
    for task_name in tasks:
        # first, log n_shots for each task
        # for subtask, n_shots in results["n-shot"].items():
        #     name = f"{subtask}_n_shot"
        #     eval_results[name] = float(n_shots)
        # then log all result metrics we have for this task
        for key, val in results["results"][task_name].items():
            if not isinstance(val, str):
                # for mmlu, we don't log results per subtask, but simply overall results
                name = f"{task_name}_{key}"
                eval_results[name] = val
    return eval_results


def set_seed(seed):
    np.random.seed(seed)
    torch.random.manual_seed(seed)
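For context (not from the commit), here is a hedged sketch of how filter_results flattens a typical lm_eval results payload; the metric names and numbers below are invented for illustration:

```python
# Invented lm_eval-style payload: string-valued entries (e.g. "alias") are dropped,
# numeric metrics are flattened to "<task>_<metric>" keys.
from brevitas_examples.llm.main import filter_results

sample_results = {
    "results": {
        "arc_easy": {"alias": "arc_easy", "acc,none": 0.62, "acc_stderr,none": 0.01},
        "piqa": {"alias": "piqa", "acc,none": 0.71, "acc_stderr,none": 0.01},
    },
    "n-shot": {"arc_easy": 0, "piqa": 0},
}

print(filter_results(sample_results, ["arc_easy", "piqa"]))
# {'arc_easy_acc,none': 0.62, 'arc_easy_acc_stderr,none': 0.01,
#  'piqa_acc,none': 0.71, 'piqa_acc_stderr,none': 0.01}
```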
@@ -463,6 +482,29 @@ def main(args):
            quant_ppl = compute_perplexity(
                model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
        print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")

    if args.few_shot_eval:
        with torch.no_grad(), quant_inference_mode(model):
            model(**calibration_loader[0])
            if args.few_shot_compile:
                remove_hooks(model)
                model.cuda()
                model = torch.compile(model)

            wrapped_model = HFLM(pretrained=model)  # need to wrap for LLM eval
            results = evaluator.simple_evaluate(
                model=wrapped_model,
                model_args=None,
                tasks=list(args.few_shot_tasks),
                device='cuda:0',
                limit=args.few_shot_limit,
                num_fewshot=0 if args.few_shot_zeroshot else None,
                log_samples=False,
                batch_size=None,
                verbosity="ERROR")
            results = filter_results(results, args.few_shot_tasks)
            print("Few shot eval results")
            print(results)
    remove_hooks(model)

if args.checkpoint_name is not None:
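As a standalone reference (not part of the commit), the same lm_eval pattern used in the block above can be exercised against a plain Hugging Face model; the checkpoint, task, and limit are placeholders chosen to keep a smoke test short:

```python
from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM

# HFLM also accepts a model id string, not just an in-memory model object.
lm = HFLM(pretrained="facebook/opt-125m")  # placeholder checkpoint

results = evaluator.simple_evaluate(
    model=lm,
    tasks=["arc_easy"],      # any lm_eval task name
    num_fewshot=0,           # zero-shot, mirroring --few-shot-zeroshot
    limit=8,                 # mirrors --few-shot-limit
    log_samples=False,
)
print(results["results"]["arc_easy"])
```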
@@ -783,7 +825,28 @@ def parse_args(args, override_defaults={}):
        default=False,
        action="store_true",
        help='Whether to use fast update with learned round. Prototype (default: %(default)s)')
    parser.add_argument(
        '--few-shot-eval',
        action="store_true",
        help='Perform zero/few-shot evaluation with lm_eval (default: %(default)s)')
    parser.add_argument(
        '--few-shot-compile',
        action="store_true",
        help='Compile the model during lm_eval evaluation (default: %(default)s)')
    parser.add_argument(
        '--few-shot-zeroshot',
        action="store_true",
        help='Run zero-shot rather than few-shot evaluation (default: %(default)s)')
    parser.add_argument(
        '--few-shot-limit',
        type=int,
        default=None,
        help='Limit the number of evaluated examples per task (default: %(default)s)')
    parser.add_argument(
        '--few-shot-tasks',
        default=['arc_challenge', 'arc_easy', 'winogrande', 'piqa'],
        type=str,
        nargs='*',
        help='A list of tasks for evaluation. Default: %(default)s')
    parser.set_defaults(**override_defaults)

    return parser.parse_args(args)


