From 552d24fba6abfa5fa7443fdbb6f24d69139f1657 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 14 Nov 2024 13:43:32 +0000
Subject: [PATCH] Fix (examples/llm): fix for main and README (#1092)

---
 .../common/generative/quantize.py   | 13 ++++++---
 src/brevitas_examples/llm/README.md | 27 +++++++++++++++----
 src/brevitas_examples/llm/main.py   | 13 ++++++---
 3 files changed, 41 insertions(+), 12 deletions(-)

diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py
index 87392a83f..08543d4e4 100644
--- a/src/brevitas_examples/common/generative/quantize.py
+++ b/src/brevitas_examples/common/generative/quantize.py
@@ -227,7 +227,8 @@ def generate_quantizers(
         scale_rounding_func_type=None,
         device=None,
         weight_kwargs=None,
-        input_kwargs=None):
+        input_kwargs=None,
+        scaling_min_val=1e-4):
     """
     Replace float layers with quant layers in the target model
     """
@@ -299,8 +300,14 @@ def generate_quantizers(
     if weight_group_dim is not None:
         weight_quant = weight_quant.let(**{'group_dim': weight_group_dim})

-    if dtype == torch.float16:
-        weight_quant = weight_quant.let(**{'scaling_min_val': 1e-4})
+    if dtype != torch.float32:
+        weight_quant = weight_quant.let(**{'scaling_min_val': scaling_min_val})
+        input_quant = input_quant.let(
+            **{'scaling_min_val': scaling_min_val}) if input_quant is not None else None
+        linear_input_quant = linear_input_quant.let(
+            **{'scaling_min_val': scaling_min_val}) if linear_input_quant is not None else None
+        sym_input_quant = sym_input_quant.let(
+            **{'scaling_min_val': scaling_min_val}) if sym_input_quant is not None else None

     if weight_kwargs is not None:
         weight_quant = weight_quant.let(**weight_kwargs)
diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md
index de55258db..64b87b3b1 100644
--- a/src/brevitas_examples/llm/README.md
+++ b/src/brevitas_examples/llm/README.md
@@ -23,6 +23,7 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
                [--weight-quant-type {sym,asym}]
                [--weight-quant-format WEIGHT_QUANT_FORMAT]
                [--weight-quant-granularity {per_channel,per_tensor,per_group}]
+               [--scale-rounding-func-type {round,ceil,floor}]
                [--weight-group-dim {1,0}]
                [--weight-group-size WEIGHT_GROUP_SIZE]
                [--quantize-weight-zero-point]
@@ -35,12 +36,17 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
                [--input-quant-granularity {per_tensor,per_row,per_group}]
                [--input-group-size INPUT_GROUP_SIZE]
                [--quantize-input-zero-point] [--quantize-last-layer] [--gptq]
-               [--gpfq] [--gpxq-act-order] [--gpxq-use-quant-activations] [--gpxq-create-weight-orig]
+               [--gpfq] [--gpxq-act-order] [--gpxq-use-quant-activations]
+               [--gpxq-create-weight-orig]
                [--gpxq-max-accumulator-bit-width GPXQ_MAX_ACCUMULATOR_BIT_WIDTH]
                [--gpxq-max-accumulator-tile-size GPXQ_MAX_ACCUMULATOR_TILE_SIZE]
                [--act-calibration] [--bias-corr] [--ln-affine-merge]
-               [--no-quantize] [--no-float16] [--replace-mha]
+               [--convert-layernorm-to-rmsnorm] [--replace-rmsnorm]
+               [--no-quantize] [--no-float16]
+               [--scaling-min-val SCALING_MIN_VAL] [--replace-mha]
                [--weight-equalization]
+               [--graph-rotation {fx,layerwise,fused_no_fx}]
+               [--graph-rotation-mode {had,ort}] [--rotation-orphan-sink]
                [--act-equalization {None,layerwise,fx}] [--load-awq LOAD_AWQ]
                [--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}]
                [--export-prefix EXPORT_PREFIX]
                [--checkpoint-name CHECKPOINT_NAME] [--fuse-sequences]
@@ -54,7 +60,7 @@ options:
   --seqlen SEQLEN       Sequence length. Default: 2048.
  --eval                Eval model PPL on the chosen Dataset.
   --dataset {wikitext2,c4}
-                        Dataset to use for quantization (default: c4)
+                        Dataset to use for quantization (default: wikitext2)
   --gpxq-block-name GPXQ_BLOCK_NAME
                         Block name for faster GPxQ optimization. It works
                         only if FX is not needed (default: None)
@@ -74,6 +80,9 @@ options:
   --weight-quant-granularity {per_channel,per_tensor,per_group}
                         Granularity for scales/zero-point of weights.
                         Default: per_group.
+  --scale-rounding-func-type {round,ceil,floor}
+                        Rounding function to use with Po2 scale. Default:
+                        None.
   --weight-group-dim {1,0}
                         Override default group_dim for groupsize
                         quantization. Default: layer-dependant
@@ -125,21 +134,29 @@ options:
   --act-calibration     Apply activation calibration.
   --bias-corr           Apply bias correction.
   --ln-affine-merge     Merge LN affine params.
+  --convert-layernorm-to-rmsnorm
+                        Convert LayerNorm to RMSNorm.
   --replace-rmsnorm     Replace HF RMSNorms with Torch one.
   --no-quantize         Disable quantization.
   --no-float16          Disable float16 as base datatype and switch to
                         float32.
+  --scaling-min-val SCALING_MIN_VAL
+                        Minimum value to clamp scale to when using bf16 or
+                        fp16 quantization.
   --replace-mha         Replace HuggingFace Attention with a quantizable
                         version
   --weight-equalization
                         Apply weight equalization. Relevant to ReLU based
                         models (e.g. OPT).
-  --graph-rotation      Apply graph rotation equalization
+  --graph-rotation {fx,layerwise,fused_no_fx}
+                        Apply graph rotation equalization
   --graph-rotation-mode {had,ort}
                         If GraphRotation is enabled, decide how to compute
                         the random rotation matrix that is fully fused.
                         Online or partial rotation will always be Hadamard
-  --layerwise-rotation  Apply layerwise rotation equalization
+  --rotation-orphan-sink
+                        If GraphRotation is enabled, decide whether to add
+                        standalone hadamard matrices for the unfused layers
   --act-equalization {None,layerwise,fx}
                         Apply activation equalization (SmoothQuant).
                         Layerwise introduces standalone mul nodes,while fx
                         merges them
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 03035dd02..b74225a6a 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -260,8 +260,6 @@ def main(args):
         print("Layernorm To RMSNorm applied.")

     if args.graph_rotation == 'fx':
-        assert args.ln_affine_merge
-        assert args.replace_rmsnorm
         model = offload_model(model)
         eq = GraphRotationEqualization(
             orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.graph_rotation_mode)
@@ -319,7 +317,8 @@ def main(args):
         input_group_size=args.input_group_size,
         quantize_input_zero_point=args.quantize_input_zero_point,
         scale_rounding_func_type=args.scale_rounding_func_type,
-        device=device)
+        device=device,
+        scaling_min_val=args.scaling_min_val)
     layer_map = generate_quant_maps(
         linear_input_quant=linear_input_quant,
         weight_quant=weight_quant,
@@ -363,7 +362,8 @@ def main(args):

     model = offload_model(model)

-    model(**calibration_loader[0])
+    with torch.no_grad():
+        model(**calibration_loader[0])

     if args.act_calibration:
         print("Apply act calibration...")
@@ -586,6 +586,11 @@ def parse_args(args):
         '--no-float16',
         action='store_true',
         help='Disable float16 as base datatype and switch to float32.')
+    parser.add_argument(
+        '--scaling-min-val',
+        type=float,
+        default=1e-4,
+        help='Minimum value to clamp scale to when using bf16 or fp16 quantization.')
     parser.add_argument(
         '--replace-mha',
         action='store_true',
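
Note (illustrative, not part of the patch): the new scaling_min_val knob exists because very small quantization scales underflow to zero in float16, which makes dequantization degenerate. The patch extends the clamp from the float16-only case to any non-float32 base dtype and exposes it on the CLI as --scaling-min-val (default 1e-4). A minimal standalone sketch of the underflow it guards against, using only standard PyTorch ops (the values are made up for illustration):

    import torch

    # A per-group scale that is fine in float32 but underflows in float16.
    scale = torch.tensor(1e-9)
    print(scale.to(torch.float16))  # prints 0 -> degenerate scale, dequantization collapses

    # Clamping to a small floor keeps the scale representable in half precision,
    # mirroring what the quantizers do when the base dtype is not float32.
    scaling_min_val = 1e-4
    print(torch.clamp(scale, min=scaling_min_val).to(torch.float16))  # ~1e-4, still representable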