Skip to content

Commit

Permalink
Feat (ex/llm): Specify experiments via YAML files
Browse files Browse the repository at this point in the history
  • Loading branch information
nickfraser committed Dec 5, 2024
1 parent 7a5f77d commit d63bfdd
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 6 deletions.
9 changes: 5 additions & 4 deletions src/brevitas_examples/llm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
Set the env variable BREVITAS_JIT=1 to speed up the quantization process. Currently unsupported whenever export is also toggled or with MSE based scales/zero-points.

```bash
usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
[--seqlen SEQLEN] [--eval] [--dataset {wikitext2,c4}]
[--gpxq-block-name GPXQ_BLOCK_NAME]
usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
[--nsamples NSAMPLES] [--seqlen SEQLEN] [--eval]
[--dataset {wikitext2,c4}] [--gpxq-block-name GPXQ_BLOCK_NAME]
[--weight-bit-width WEIGHT_BIT_WIDTH]
[--weight-param-method {stats,mse,hqo}]
[--weight-scale-precision {float_scale,po2_scale}]
Expand Down Expand Up @@ -58,6 +58,8 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]

options:
-h, --help show this help message and exit
--config CONFIG Specify alternative default commandline args (e.g.,
config/default_template.yml). Default: None.
--model MODEL HF model name. Default: facebook/opt-125m.
--seed SEED Seed for sampling the calibration data. Default: 0.
--nsamples NSAMPLES Number of calibration data samples. Default: 128.
Expand Down Expand Up @@ -200,5 +202,4 @@ options:
--learned-round-fast-update
Whether to use fast update with learned round.
Prototype (default: False)

```
61 changes: 61 additions & 0 deletions src/brevitas_examples/llm/config/default_template.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Default argument template for brevitas_examples/llm/main.py.
# Generated by config/gen_template.py; pass a copy of this file via
# `--config <path>` to override the built-in argparse defaults.
# Keys mirror the commandline flag names with dashes replaced by
# underscores; `null` means "flag not set / use runtime default".

# Activation calibration / equalization.
act_calibration: false
act_equalization: null
bias_corr: false
checkpoint_name: null
# `config` itself is listed for completeness; setting it here has no effect.
config: null
convert_layernorm_to_rmsnorm: false
# Calibration dataset selection (choices in main.py: wikitext2, c4).
dataset: wikitext2
eval: false
# Export settings.
export_prefix: null
export_target: null
fuse_sequences: false
# GPxQ (GPTQ/GPFQ) post-training quantization options.
gpfq: false
gptq: false
gpxq_act_order: false
gpxq_block_name: null
gpxq_create_weight_orig: false
gpxq_max_accumulator_bit_width: null
gpxq_max_accumulator_tile_size: null
gpxq_use_quant_activations: false
# Input (activation) quantization options.
input_bit_width: null
input_group_size: 64
input_param_method: stats
input_quant_format: int
input_quant_granularity: per_tensor
input_quant_type: asym
input_scale_precision: float_scale
input_scale_type: static
# Learned-round options.
learned_round: null
learned_round_fast_update: false
learned_round_iters: 200
learned_round_lr: 0.005
learned_round_scale: false
learned_round_scale_lr: 0.01
learned_round_scale_momentum: 0.9
ln_affine_merge: false
load_awq: null
# Model / data sampling.
model: facebook/opt-125m
no_float16: false
no_quantize: false
nsamples: 128
quantize_input_zero_point: false
quantize_last_layer: false
quantize_weight_zero_point: false
replace_mha: false
replace_rmsnorm: false
# Rotation-based equalization options.
rotation: null
rotation_mode: had
rotation_orphan_sink: false
scale_rounding_func_type: null
scaling_min_val: 0.0001
seed: 0
seqlen: 2048
# Weight quantization options.
weight_bit_width: 8
weight_equalization: false
weight_group_dim: null
weight_group_size: 128
weight_param_method: stats
weight_quant_format: int
weight_quant_granularity: per_group
weight_quant_type: sym
weight_scale_precision: float_scale
8 changes: 8 additions & 0 deletions src/brevitas_examples/llm/config/gen_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import yaml

from brevitas_examples.llm.main import parse_args

if __name__ == "__main__":
    # Materialize the argparse defaults (no CLI args) and dump them as a
    # YAML template that users can edit and pass back via `--config`.
    args = parse_args([])
    with open('default_template.yml', 'w') as f:
        yaml.dump(vars(args), f)
32 changes: 30 additions & 2 deletions src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers.utils.fx import _SUPPORTED_MODELS
import yaml

from brevitas.export import export_torch_qcdq
from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
Expand Down Expand Up @@ -440,8 +441,33 @@ def main(args):
return float_ppl, quant_ppl, model


def parse_args(args):
def override_defaults(args):
    """Extract alternative argparse defaults from an optional YAML config file.

    Runs a minimal pre-parse that only recognizes ``--config`` (all other
    tokens are ignored at this stage) and, when a config path was supplied,
    loads it as a flat mapping of argument-name -> default value suitable for
    ``parser.set_defaults(**...)``.

    Args:
        args: Raw commandline tokens, e.g. ``sys.argv[1:]``.

    Returns:
        dict: Defaults loaded from the YAML file, or an empty dict when no
        ``--config`` flag was given.

    Raises:
        FileNotFoundError: If the path passed via ``--config`` does not exist.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument(
        '--config',
        type=str,
        default=None,
        help=
        'Specify alternative default commandline args (e.g., config/default_template.yml). Default: %(default)s.'
    )
    # Bug fix: parse the *supplied* args instead of implicitly reading
    # sys.argv, so the `args` parameter is honored and callers/tests can pass
    # an explicit argument list. parse_known_args returns (namespace, rest).
    known_args = parser.parse_known_args(args)[0]
    if known_args.config is not None:
        with open(known_args.config, 'r') as f:
            defaults = yaml.safe_load(f)
    else:
        defaults = {}
    return defaults


def parse_args(args, override_defaults={}):
parser = argparse.ArgumentParser()
parser.add_argument(
'--config',
type=str,
default=None,
help=
'Specify alternative default commandline args (e.g., config/default_template.yml). Default: %(default)s.'
)
parser.add_argument(
'--model',
type=str,
Expand Down Expand Up @@ -711,9 +737,11 @@ def parse_args(args):
default=False,
action="store_true",
help='Whether to use fast update with learned round. Prototype (default: %(default)s)')
parser.set_defaults(**override_defaults)
return parser.parse_args(args)


if __name__ == '__main__':
    # Two-stage parsing: an optional `--config` YAML supplies alternative
    # defaults, which are applied before the full argument parse.
    cli = sys.argv[1:]
    config_defaults = override_defaults(cli)
    args = parse_args(cli, override_defaults=config_defaults)
    main(args)

0 comments on commit d63bfdd

Please sign in to comment.