From 89609d854d75604abd90175481b5f2453b977ba1 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Sun, 15 Dec 2024 20:49:22 +0000
Subject: [PATCH] Feat (brevitas_examples/llm): zero shot eval

---
 requirements/requirements-llm.txt |  1 +
 src/brevitas_examples/llm/main.py | 56 +++++++++++++++++++++++++++++++
 2 files changed, 57 insertions(+)

diff --git a/requirements/requirements-llm.txt b/requirements/requirements-llm.txt
index 38999bdb8..1724827c8 100644
--- a/requirements/requirements-llm.txt
+++ b/requirements/requirements-llm.txt
@@ -1,5 +1,6 @@
 accelerate
 datasets
+lm_eval
 onnx
 onnxruntime
 optimum
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index b9c2c1c4d..3b4b1eb05 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -7,6 +7,8 @@
 import sys
 from warnings import warn
 
+from lm_eval import evaluator
+from lm_eval.models.huggingface import HFLM
 import numpy as np
 from optimum.exporters.onnx import onnx_export_from_model
 import torch
@@ -51,6 +53,23 @@
 from brevitas_examples.llm.llm_quant.run_utils import get_fx
 
 
+def filter_results(results, tasks):
+    # filter out what we actually want to track in azureml
+    eval_results = dict()
+    for task_name in tasks:
+        # first, log n_shots for each task
+        # for subtask, n_shots in results["n-shot"].items():
+        #     name = f"{subtask}_n_shot"
+        #     eval_results[name] = float(n_shots)
+        # then log all result metrics we have for this task
+        for key, val in results["results"][task_name].items():
+            if not isinstance(val, str):
+                # for mmlu, we don't log results per subtask, but simply overall results
+                name = f"{task_name}_{key}"
+                eval_results[name] = val
+    return eval_results
+
+
 def set_seed(seed):
     np.random.seed(seed)
     torch.random.manual_seed(seed)
@@ -462,6 +481,28 @@ def main(args):
         quant_ppl = compute_perplexity(
             model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
         print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")
+
+    if args.zero_shot_eval:
+        with torch.no_grad(), quant_inference_mode(model):
+            model(**calibration_loader[0])
+            if args.zero_shot_compile:
+                remove_hooks(model)
+                model.cuda()
+                model = torch.compile(model)
+
+            wrapped_model = HFLM(pretrained=model)  # need to wrap for LLM eval
+            results = evaluator.simple_evaluate(
+                model=wrapped_model,
+                model_args=None,
+                tasks=list(args.zero_shot_tasks),
+                device='cuda:0',
+                limit=None,
+                num_fewshot=0,
+                log_samples=False,
+                batch_size="auto:3",
+                verbosity="ERROR")
+            results = filter_results(results, args.zero_shot_tasks)
+            print(results)
     remove_hooks(model)
 
     if args.checkpoint_name is not None:
@@ -757,6 +798,21 @@
         default=False,
         action="store_true",
         help='Whether to use fast update with learned round. Prototype (default: %(default)s)')
+    parser.add_argument(
+        '--zero-shot-eval',
+        action="store_true",
+        help='Perform zero-shot evaluation with lm_eval (default: %(default)s)')
+    parser.add_argument(
+        '--zero-shot-compile',
+        action="store_true",
+        help='Compile the model during zero-shot evaluation with lm_eval (default: %(default)s)')
+    parser.add_argument(
+        '--zero-shot-tasks',
+        default=['arc_challenge', 'arc_easy', 'winogrande', 'piqa'],
+        type=str,
+        nargs='*',
+        help='A list of tasks for zero-shot evaluation (default: %(default)s)')
+
     return parser.parse_args(args)
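
Note: the filter_results helper added above flattens the nested dictionary returned by lm_eval's evaluator.simple_evaluate into a flat {task}_{metric} mapping, dropping string-valued entries. Below is a minimal sketch of the intended behaviour, assuming filter_results from the patch is in scope and using hand-written mock metrics; the task names, metric keys, and values are illustrative, not real lm_eval output.

    # Hypothetical stand-in for the dict returned by evaluator.simple_evaluate();
    # real task/metric names and values will differ.
    mock_results = {
        "results": {
            "arc_easy": {"acc,none": 0.71, "acc_stderr,none": 0.01, "alias": "arc_easy"},
            "piqa": {"acc,none": 0.76, "acc_norm,none": 0.77, "alias": "piqa"},
        },
        "n-shot": {"arc_easy": 0, "piqa": 0},
    }

    # String-valued fields such as "alias" are skipped; numeric metrics are
    # prefixed with the task name, e.g. "arc_easy_acc,none" -> 0.71.
    print(filter_results(mock_results, ["arc_easy", "piqa"]))

Keeping only task-level numeric metrics in a flat dict makes the results straightforward to log to a tracking backend (AzureML, per the comment in the helper) without per-subtask breakdowns.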