Fix (examples/llm): fix for main and README (#1092)
Giuseppe5 authored Nov 14, 2024
1 parent 054c961 commit 552d24f
Showing 3 changed files with 41 additions and 12 deletions.
13 changes: 10 additions & 3 deletions src/brevitas_examples/common/generative/quantize.py
@@ -227,7 +227,8 @@ def generate_quantizers(
        scale_rounding_func_type=None,
        device=None,
        weight_kwargs=None,
        input_kwargs=None):
        input_kwargs=None,
        scaling_min_val=1e-4):
    """
    Replace float layers with quant layers in the target model
    """
@@ -299,8 +300,14 @@ def generate_quantizers(
    if weight_group_dim is not None:
        weight_quant = weight_quant.let(**{'group_dim': weight_group_dim})

    if dtype == torch.float16:
        weight_quant = weight_quant.let(**{'scaling_min_val': 1e-4})
    if dtype != torch.float32:
        weight_quant = weight_quant.let(**{'scaling_min_val': scaling_min_val})
        input_quant = input_quant.let(
            **{'scaling_min_val': scaling_min_val}) if input_quant is not None else None
        linear_input_quant = linear_input_quant.let(
            **{'scaling_min_val': scaling_min_val}) if linear_input_quant is not None else None
        sym_input_quant = sym_input_quant.let(
            **{'scaling_min_val': scaling_min_val}) if sym_input_quant is not None else None
    if weight_kwargs is not None:
        weight_quant = weight_quant.let(**weight_kwargs)

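The quantize.py change above turns the hard-coded fp16 clamp into a configurable scaling_min_val and applies it to the weight and the various input quantizers whenever the base dtype is not float32. As a rough standalone sketch of why such a scale floor matters (plain PyTorch, not Brevitas internals; fake_quant and the all-zero group are illustrative assumptions):

import torch

def fake_quant(x, scale, bits=8):
    # Symmetric fake quantization: divide by scale, round, clamp to the int range, rescale.
    qmax = 2 ** (bits - 1) - 1
    return torch.clamp(torch.round(x / scale), -qmax, qmax) * scale

w = torch.randn(2, 8, dtype=torch.float16)
w[0] = 0.0                                       # a group that happens to be all zeros
scale = w.abs().amax(dim=1, keepdim=True) / 127  # per-group scale; zero for the first group
print(fake_quant(w, scale))                      # 0 / 0 -> NaN across the whole first group
safe_scale = scale.clamp(min=1e-4)               # the scaling_min_val floor used by this commit
print(fake_quant(w, safe_scale))                 # zeros stay zeros; the other group is unchanged

In fp16 the computed scale can also underflow to zero for tensors with very small magnitudes, hitting the same degenerate division; clamping to the new 1e-4 default guards both cases, and the threshold is now tunable from the caller.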
27 changes: 22 additions & 5 deletions src/brevitas_examples/llm/README.md
@@ -23,6 +23,7 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
[--weight-quant-type {sym,asym}]
[--weight-quant-format WEIGHT_QUANT_FORMAT]
[--weight-quant-granularity {per_channel,per_tensor,per_group}]
[--scale-rounding-func-type {round,ceil,floor}]
[--weight-group-dim {1,0}]
[--weight-group-size WEIGHT_GROUP_SIZE]
[--quantize-weight-zero-point]
@@ -35,12 +36,17 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
[--input-quant-granularity {per_tensor,per_row,per_group}]
[--input-group-size INPUT_GROUP_SIZE]
[--quantize-input-zero-point] [--quantize-last-layer] [--gptq]
[--gpfq] [--gpxq-act-order] [--gpxq-use-quant-activations] [--gpxq-create-weight-orig]
[--gpfq] [--gpxq-act-order] [--gpxq-use-quant-activations]
[--gpxq-create-weight-orig]
[--gpxq-max-accumulator-bit-width GPXQ_MAX_ACCUMULATOR_BIT_WIDTH]
[--gpxq-max-accumulator-tile-size GPXQ_MAX_ACCUMULATOR_TILE_SIZE]
[--act-calibration] [--bias-corr] [--ln-affine-merge]
[--no-quantize] [--no-float16] [--replace-mha]
[--convert-layernorm-to-rmsnorm] [--replace-rmsnorm]
[--no-quantize] [--no-float16]
[--scaling-min-val SCALING_MIN_VAL] [--replace-mha]
[--weight-equalization]
[--graph-rotation {fx,layerwise,fused_no_fx}]
[--graph-rotation-mode {had,ort}] [--rotation-orphan-sink]
[--act-equalization {None,layerwise,fx}] [--load-awq LOAD_AWQ]
[--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}]
[--export-prefix EXPORT_PREFIX]
@@ -54,7 +60,7 @@ options:
--seqlen SEQLEN Sequence length. Default: 2048.
--eval Eval model PPL on the chosen Dataset.
--dataset {wikitext2,c4}
Dataset to use for quantization (default: c4)
Dataset to use for quantization (default: wikitext2)
--gpxq-block-name GPXQ_BLOCK_NAME
Block name for faster GPxQ optimization. It works only
if FX is not needed (default: None)
@@ -74,6 +80,9 @@ options:
--weight-quant-granularity {per_channel,per_tensor,per_group}
Granularity for scales/zero-point of weights. Default:
per_group.
--scale-rounding-func-type {round,ceil,floor}
Rounding function to use with Po2 scale. Default:
None.
--weight-group-dim {1,0}
Override default group_dim for groupsize quantization.
Default: layer-dependant
@@ -125,21 +134,29 @@ options:
--act-calibration Apply activation calibration.
--bias-corr Apply bias correction.
--ln-affine-merge Merge LN affine params.
--convert-layernorm-to-rmsnorm
Merge LN affine params.
--replace-rmsnorm Replace HF RMSNorms with Torch one.
--no-quantize Disable quantization.
--no-float16 Disable float16 as base datatype and switch to
float32.
--scaling-min-val SCALING_MIN_VAL
Minimum value to clamp scale to when using bf16 or
fp16 quantization.
--replace-mha Replace HuggingFace Attention with a quantizable
version
--weight-equalization
Apply weight equalization. Relevant to ReLU based
models (e.g. OPT).
--graph-rotation Apply graph rotation equalization
--graph-rotation {fx,layerwise,fused_no_fx}
Apply graph rotation equalization
--graph-rotation-mode {had,ort}
If GraphRotation is enabled, decide how to compute the
random rotation matrix that is fully fused. Online or
partial rotation will always be Hadamard
--layerwise-rotation Apply layerwise rotation equalization
--rotation-orphan-sink
If GraphRotation is enabled, decide whether to add
standalone hadamard matrices for the unfused layers
--act-equalization {None,layerwise,fx}
Apply activation equalization (SmoothQuant). Layerwise
introduces standalone mul nodes, while fx merges them
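Among the regenerated options above, --scale-rounding-func-type picks the rounding function used when the scale is restricted to a power of two. A hedged, standalone illustration of the idea (plain PyTorch; po2_scale is a toy helper, not the Brevitas implementation):

import torch

def po2_scale(scale, rounding='round'):
    # Snap a float scale to a power of two using the chosen rounding of its log2.
    fn = {'round': torch.round, 'ceil': torch.ceil, 'floor': torch.floor}[rounding]
    return torch.exp2(fn(torch.log2(scale)))

s = torch.tensor(0.013)
print(po2_scale(s, 'round'), po2_scale(s, 'ceil'), po2_scale(s, 'floor'))
# ceil never shrinks the representable range (coarser steps); floor gives finer steps but may clip.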
13 changes: 9 additions & 4 deletions src/brevitas_examples/llm/main.py
@@ -260,8 +260,6 @@ def main(args):
print("Layernorm To RMSNorm applied.")

if args.graph_rotation == 'fx':
assert args.ln_affine_merge
assert args.replace_rmsnorm
model = offload_model(model)
eq = GraphRotationEqualization(
orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.graph_rotation_mode)
@@ -319,7 +317,8 @@ def main(args):
        input_group_size=args.input_group_size,
        quantize_input_zero_point=args.quantize_input_zero_point,
        scale_rounding_func_type=args.scale_rounding_func_type,
        device=device)
        device=device,
        scaling_min_val=args.scaling_min_val)
    layer_map = generate_quant_maps(
        linear_input_quant=linear_input_quant,
        weight_quant=weight_quant,
@@ -363,7 +362,8 @@

    model = offload_model(model)

    model(**calibration_loader[0])
    with torch.no_grad():
        model(**calibration_loader[0])

    if args.act_calibration:
        print("Apply act calibration...")
@@ -586,6 +586,11 @@ def parse_args(args):
        '--no-float16',
        action='store_true',
        help='Disable float16 as base datatype and switch to float32.')
    parser.add_argument(
        '--scaling-min-val',
        type=float,
        default=1e-4,
        help='Minimum value to clamp scale to when using bf16 or fp16 quantization.')
    parser.add_argument(
        '--replace-mha',
        action='store_true',
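One of the main.py changes above wraps the single calibration forward pass in torch.no_grad(). A toy illustration of what that buys (a generic stand-in model, not the LLM from the example): without it, autograd keeps every layer's input alive for a backward pass that is never run.

import torch

model = torch.nn.Sequential(*[torch.nn.Linear(1024, 1024) for _ in range(8)])
x = torch.randn(64, 1024)

y = model(x)             # default grad mode: activations are saved for backward
print(y.requires_grad)   # True -> an autograd graph (and its saved tensors) stays alive

with torch.no_grad():    # calibration-style pass, as in the patched main.py
    y = model(x)
print(y.requires_grad)   # False -> nothing retained, lower peak memory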
