diff --git a/requirements/requirements-llm.txt b/requirements/requirements-llm.txt index 38999bdb8..1724827c8 100644 --- a/requirements/requirements-llm.txt +++ b/requirements/requirements-llm.txt @@ -1,5 +1,6 @@ accelerate datasets +lm_eval onnx onnxruntime optimum diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md index cad08ab9f..0d6fb5f42 100644 --- a/src/brevitas_examples/llm/README.md +++ b/src/brevitas_examples/llm/README.md @@ -57,7 +57,10 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED] [--export-prefix EXPORT_PREFIX] [--checkpoint-name CHECKPOINT_NAME] [--fuse-sequences] [--learned-round {None,linear_round}] - [--learned-round-fast-update] + [--learned-round-fast-update] [--few-shot-eval] + [--few-shot-compile] [--few-shot-zeroshot] + [--few-shot-limit FEW_SHOT_LIMIT] + [--few-shot-tasks [FEW_SHOT_TASKS ...]] options: -h, --help show this help message and exit @@ -210,5 +213,15 @@ options: --learned-round-fast-update Whether to use fast update with learned round. Prototype (default: False) + --few-shot-eval Perform zero_shot evaluation with lm_eval. Default + False) + --few-shot-compile Compile during zero_shot evaluation with lm_eval. + Default False) + --few-shot-zeroshot Whether to do zero or few shot eval. Default False) + --few-shot-limit FEW_SHOT_LIMIT + Few shot limit. Default None) + --few-shot-tasks [FEW_SHOT_TASKS ...] + A list of tasks for zero_shot evaluation. Default: + ['arc_challenge', 'arc_easy', 'winogrande', 'piqa'] ``` diff --git a/src/brevitas_examples/llm/config/default_template.yml b/src/brevitas_examples/llm/config/default_template.yml index e844c7dbe..ef20184ac 100644 --- a/src/brevitas_examples/llm/config/default_template.yml +++ b/src/brevitas_examples/llm/config/default_template.yml @@ -8,6 +8,15 @@ dataset: wikitext2 eval: false export_prefix: null export_target: null +few_shot_compile: false +few_shot_eval: false +few_shot_limit: null +few_shot_tasks: +- arc_challenge +- arc_easy +- winogrande +- piqa +few_shot_zeroshot: false fuse_sequences: false gpfq: false gptq: false diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 877fbac5d..1d2c86d88 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -7,6 +7,8 @@ import sys from warnings import warn +from lm_eval import evaluator +from lm_eval.models.huggingface import HFLM import numpy as np from optimum.exporters.onnx import onnx_export_from_model import torch @@ -52,6 +54,23 @@ from brevitas_examples.llm.llm_quant.run_utils import get_fx +def filter_results(results, tasks): + # filter out what we actually want to track in azureml + eval_results = dict() + for task_name in tasks: + # first, log n_shots for each task + # for subtask, n_shots in results["n-shot"].items(): + # name = f"{subtask}_n_shot" + # eval_results[name] = float(n_shots) + # then log all result metrics we have for this task + for key, val in results["results"][task_name].items(): + if not isinstance(val, str): + # for mmlu, we don't log results per subtask, but simply overall results + name = f"{task_name}_{key}" + eval_results[name] = val + return eval_results + + def set_seed(seed): np.random.seed(seed) torch.random.manual_seed(seed) @@ -463,6 +482,29 @@ def main(args): quant_ppl = compute_perplexity( model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}") + + if args.few_shot_eval: + with torch.no_grad(), quant_inference_mode(model): + model(**calibration_loader[0]) + if args.few_shot_compile: + remove_hooks(model) + model.cuda() + model = torch.compile(model) + + wrapped_model = HFLM(pretrained=model) # need to wrap for LLM eval + results = evaluator.simple_evaluate( + model=wrapped_model, + model_args=None, + tasks=list(args.few_shot_tasks), + device='cuda:0', + limit=args.few_shot_limit, + num_fewshot=0 if args.few_shot_zeroshot else None, + log_samples=False, + batch_size=None, + verbosity="ERROR") + results = filter_results(results, args.few_shot_tasks) + print("Few shot eval results") + print(results) remove_hooks(model) if args.checkpoint_name is not None: @@ -783,7 +825,28 @@ def parse_args(args, override_defaults={}): default=False, action="store_true", help='Whether to use fast update with learned round. Prototype (default: %(default)s)') + parser.add_argument( + '--few-shot-eval', + action="store_true", + help='Perform zero_shot evaluation with lm_eval. Default %(default)s)') + parser.add_argument( + '--few-shot-compile', + action="store_true", + help='Compile during zero_shot evaluation with lm_eval. Default %(default)s)') + parser.add_argument( + '--few-shot-zeroshot', + action="store_true", + help='Whether to do zero or few shot eval. Default %(default)s)') + parser.add_argument( + '--few-shot-limit', type=int, default=None, help='Few shot limit. Default %(default)s)') + parser.add_argument( + '--few-shot-tasks', + default=['arc_challenge', 'arc_easy', 'winogrande', 'piqa'], + type=str, + nargs='*', + help='A list of tasks for zero_shot evaluation. Default: %(default)s') parser.set_defaults(**override_defaults) + return parser.parse_args(args)