diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md index 64c82c80a..96e2e9c63 100644 --- a/src/brevitas_examples/llm/README.md +++ b/src/brevitas_examples/llm/README.md @@ -13,9 +13,9 @@ Set the env variable BREVITAS_JIT=1 to speed up the quantization process. Currently unsupported whenever export is also toggled or with MSE based scales/zero-points. ```bash -usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES] - [--seqlen SEQLEN] [--eval] [--dataset {wikitext2,c4}] - [--gpxq-block-name GPXQ_BLOCK_NAME] +usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED] + [--nsamples NSAMPLES] [--seqlen SEQLEN] [--eval] + [--dataset {wikitext2,c4}] [--gpxq-block-name GPXQ_BLOCK_NAME] [--weight-bit-width WEIGHT_BIT_WIDTH] [--weight-param-method {stats,mse,hqo}] [--weight-scale-precision {float_scale,po2_scale}] @@ -58,6 +58,8 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES] options: -h, --help show this help message and exit + --config CONFIG Specify alternative default commandline args (e.g., + config/default_template.yml). Default: None. --model MODEL HF model name. Default: facebook/opt-125m. --seed SEED Seed for sampling the calibration data. Default: 0. --nsamples NSAMPLES Number of calibration data samples. Default: 128. @@ -200,5 +202,4 @@ options: --learned-round-fast-update Whether to use fast update with learned round. Prototype (default: False) - ``` diff --git a/src/brevitas_examples/llm/config/default_template.yml b/src/brevitas_examples/llm/config/default_template.yml new file mode 100644 index 000000000..02f803cf2 --- /dev/null +++ b/src/brevitas_examples/llm/config/default_template.yml @@ -0,0 +1,61 @@ +act_calibration: false +act_equalization: null +bias_corr: false +checkpoint_name: null +config: null +convert_layernorm_to_rmsnorm: false +dataset: wikitext2 +eval: false +export_prefix: null +export_target: null +fuse_sequences: false +gpfq: false +gptq: false +gpxq_act_order: false +gpxq_block_name: null +gpxq_create_weight_orig: false +gpxq_max_accumulator_bit_width: null +gpxq_max_accumulator_tile_size: null +gpxq_use_quant_activations: false +input_bit_width: null +input_group_size: 64 +input_param_method: stats +input_quant_format: int +input_quant_granularity: per_tensor +input_quant_type: asym +input_scale_precision: float_scale +input_scale_type: static +learned_round: null +learned_round_fast_update: false +learned_round_iters: 200 +learned_round_lr: 0.005 +learned_round_scale: false +learned_round_scale_lr: 0.01 +learned_round_scale_momentum: 0.9 +ln_affine_merge: false +load_awq: null +model: facebook/opt-125m +no_float16: false +no_quantize: false +nsamples: 128 +quantize_input_zero_point: false +quantize_last_layer: false +quantize_weight_zero_point: false +replace_mha: false +replace_rmsnorm: false +rotation: null +rotation_mode: had +rotation_orphan_sink: false +scale_rounding_func_type: null +scaling_min_val: 0.0001 +seed: 0 +seqlen: 2048 +weight_bit_width: 8 +weight_equalization: false +weight_group_dim: null +weight_group_size: 128 +weight_param_method: stats +weight_quant_format: int +weight_quant_granularity: per_group +weight_quant_type: sym +weight_scale_precision: float_scale diff --git a/src/brevitas_examples/llm/config/gen_template.py b/src/brevitas_examples/llm/config/gen_template.py new file mode 100644 index 000000000..fb95af0e1 --- /dev/null +++ b/src/brevitas_examples/llm/config/gen_template.py @@ -0,0 +1,8 @@ +import yaml + +from brevitas_examples.llm.main import parse_args + +if __name__ == "__main__": + default_args = parse_args([]) + with open('default_template.yml', 'w') as f: + yaml.dump(default_args.__dict__, f) diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index c8c76d4a1..33ee49f09 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -12,6 +12,7 @@ from transformers import AutoModelForCausalLM from transformers import AutoTokenizer from transformers.utils.fx import _SUPPORTED_MODELS +import yaml from brevitas.export import export_torch_qcdq from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager @@ -440,8 +441,33 @@ def main(args): return float_ppl, quant_ppl, model -def parse_args(args): +def override_defaults(args): + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument( + '--config', + type=str, + default=None, + help= + 'Specify alternative default commandline args (e.g., config/default_template.yml). Default: %(default)s.' + ) + known_args = parser.parse_known_args()[0] # Returns a tuple + if known_args.config is not None: + with open(known_args.config, 'r') as f: + defaults = yaml.safe_load(f) + else: + defaults = {} + return defaults + + +def parse_args(args, override_defaults={}): parser = argparse.ArgumentParser() + parser.add_argument( + '--config', + type=str, + default=None, + help= + 'Specify alternative default commandline args (e.g., config/default_template.yml). Default: %(default)s.' + ) parser.add_argument( '--model', type=str, @@ -711,9 +737,11 @@ def parse_args(args): default=False, action="store_true", help='Whether to use fast update with learned round. Prototype (default: %(default)s)') + parser.set_defaults(**override_defaults) return parser.parse_args(args) if __name__ == '__main__': - args = parse_args(sys.argv[1:]) + overrides = override_defaults(sys.argv[1:]) + args = parse_args(sys.argv[1:], override_defaults=overrides) main(args)