diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md index cad08ab9f..0d6fb5f42 100644 --- a/src/brevitas_examples/llm/README.md +++ b/src/brevitas_examples/llm/README.md @@ -57,7 +57,10 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED] [--export-prefix EXPORT_PREFIX] [--checkpoint-name CHECKPOINT_NAME] [--fuse-sequences] [--learned-round {None,linear_round}] - [--learned-round-fast-update] + [--learned-round-fast-update] [--few-shot-eval] + [--few-shot-compile] [--few-shot-zeroshot] + [--few-shot-limit FEW_SHOT_LIMIT] + [--few-shot-tasks [FEW_SHOT_TASKS ...]] options: -h, --help show this help message and exit @@ -210,5 +213,15 @@ options: --learned-round-fast-update Whether to use fast update with learned round. Prototype (default: False) + --few-shot-eval Perform zero_shot evaluation with lm_eval. (default: + False) + --few-shot-compile Compile during zero_shot evaluation with lm_eval. + (default: False) + --few-shot-zeroshot Whether to do zero or few shot eval. (default: False) + --few-shot-limit FEW_SHOT_LIMIT + Few shot limit. (default: None) + --few-shot-tasks [FEW_SHOT_TASKS ...] + A list of tasks for zero_shot evaluation. 
Default: + ['arc_challenge', 'arc_easy', 'winogrande', 'piqa'] ``` diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index d028bf52c..1d2c86d88 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -483,10 +483,10 @@ def main(args): model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}") - if args.zero_shot_eval: + if args.few_shot_eval: with torch.no_grad(), quant_inference_mode(model): model(**calibration_loader[0]) - if args.zero_shot_compile: + if args.few_shot_compile: remove_hooks(model) model.cuda() model = torch.compile(model) @@ -495,14 +495,15 @@ def main(args): results = evaluator.simple_evaluate( model=wrapped_model, model_args=None, - tasks=list(args.zero_shot_tasks), + tasks=list(args.few_shot_tasks), device='cuda:0', - limit=None, - num_fewshot=0, + limit=args.few_shot_limit, + num_fewshot=0 if args.few_shot_zeroshot else None, log_samples=False, - batch_size="auto:3", + batch_size=None, verbosity="ERROR") - results = filter_results(results, args.zero_shot_tasks) + results = filter_results(results, args.few_shot_tasks) + print("Few shot eval results") print(results) remove_hooks(model) @@ -825,15 +826,21 @@ def parse_args(args, override_defaults={}): action="store_true", help='Whether to use fast update with learned round. Prototype (default: %(default)s)') parser.add_argument( - '--zero-shot-eval', + '--few-shot-eval', action="store_true", help='Perform zero_shot evaluation with lm_eval. (default: %(default)s)') parser.add_argument( - '--zero-shot-compile', + '--few-shot-compile', action="store_true", help='Compile during zero_shot evaluation with lm_eval. (default: %(default)s)') parser.add_argument( - '--zero-shot-tasks', + '--few-shot-zeroshot', + action="store_true", + help='Whether to do zero or few shot eval. 
(default: %(default)s)') parser.add_argument( + '--few-shot-limit', type=int, default=None, help='Few shot limit. (default: %(default)s)') + parser.add_argument( + '--few-shot-tasks', default=['arc_challenge', 'arc_easy', 'winogrande', 'piqa'], type=str, nargs='*',