Skip to content

Commit

Permalink
Review + README
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 20, 2024
1 parent ba347a9 commit f7e1f48
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 11 deletions.
15 changes: 14 additions & 1 deletion src/brevitas_examples/llm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
[--export-prefix EXPORT_PREFIX]
[--checkpoint-name CHECKPOINT_NAME] [--fuse-sequences]
[--learned-round {None,linear_round}]
[--learned-round-fast-update]
[--learned-round-fast-update] [--few-shot-eval]
[--few-shot-compile] [--few-shot-zeroshot]
[--few-shot-limit FEW_SHOT_LIMIT]
[--few-shot-tasks [FEW_SHOT_TASKS ...]]

options:
-h, --help show this help message and exit
Expand Down Expand Up @@ -210,5 +213,15 @@ options:
--learned-round-fast-update
Whether to use fast update with learned round.
Prototype (default: False)
--few-shot-eval Perform few-shot evaluation with lm_eval (default:
False)
--few-shot-compile Compile the model during few-shot evaluation with
lm_eval (default: False)
--few-shot-zeroshot Run zero-shot instead of few-shot evaluation (default: False)
--few-shot-limit FEW_SHOT_LIMIT
Limit passed to lm_eval for the evaluation (default: None)
--few-shot-tasks [FEW_SHOT_TASKS ...]
A list of tasks for few-shot evaluation. Default:
['arc_challenge', 'arc_easy', 'winogrande', 'piqa']

```
27 changes: 17 additions & 10 deletions src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,10 +483,10 @@ def main(args):
model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")

if args.zero_shot_eval:
if args.few_shot_eval:
with torch.no_grad(), quant_inference_mode(model):
model(**calibration_loader[0])
if args.zero_shot_compile:
if args.few_shot_compile:
remove_hooks(model)
model.cuda()
model = torch.compile(model)
Expand All @@ -495,14 +495,15 @@ def main(args):
results = evaluator.simple_evaluate(
model=wrapped_model,
model_args=None,
tasks=list(args.zero_shot_tasks),
tasks=list(args.few_shot_tasks),
device='cuda:0',
limit=None,
num_fewshot=0,
limit=args.few_shot_limit,
num_fewshot=0 if args.few_shot_zeroshot else None,
log_samples=False,
batch_size="auto:3",
batch_size=None,
verbosity="ERROR")
results = filter_results(results, args.zero_shot_tasks)
results = filter_results(results, args.few_shot_tasks)
print("Few shot eval results")
print(results)
remove_hooks(model)

Expand Down Expand Up @@ -825,15 +826,21 @@ def parse_args(args, override_defaults={}):
action="store_true",
help='Whether to use fast update with learned round. Prototype (default: %(default)s)')
parser.add_argument(
'--zero-shot-eval',
'--few-shot-eval',
action="store_true",
help='Perform zero_shot evaluation with lm_eval. Default %(default)s)')
parser.add_argument(
'--zero-shot-compile',
'--few-shot-compile',
action="store_true",
help='Compile during zero_shot evaluation with lm_eval. Default %(default)s)')
parser.add_argument(
'--zero-shot-tasks',
'--few-shot-zeroshot',
action="store_true",
help='Whether to do zero or few shot eval. Default %(default)s)')
parser.add_argument(
'--few-shot-limit', type=int, default=None, help='Few shot limit. Default %(default)s)')
parser.add_argument(
'--few-shot-tasks',
default=['arc_challenge', 'arc_easy', 'winogrande', 'piqa'],
type=str,
nargs='*',
Expand Down

0 comments on commit f7e1f48

Please sign in to comment.