diff --git a/src/brevitas_examples/llm/llm_quant/equalize.py b/src/brevitas_examples/llm/llm_quant/equalize.py
index 1a10668c4..cd0746dc6 100644
--- a/src/brevitas_examples/llm/llm_quant/equalize.py
+++ b/src/brevitas_examples/llm/llm_quant/equalize.py
@@ -3,12 +3,15 @@
 # SPDX-License-Identifier: BSD-3-Clause
 """
 
+import warnings
+
 import torch
 
 from brevitas.fx.brevitas_tracer import value_trace
 from brevitas.graph.equalize import activation_equalization_mode
 from brevitas.graph.equalize import EqualizeGraph
 from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn
+from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32
 
 
 @torch.no_grad()
@@ -24,17 +27,47 @@ def activation_equalization_iter(curr_layer, inps, outs, cached_values, alpha):
 
 
 @torch.no_grad()
-def apply_act_equalization(model, dataloader, nsamples, seqlen=2048, alpha=0.5):
-    apply_layer_ptq_fn(
+def apply_act_equalization(
         model,
+        act_equalization_type,
         dataloader,
         nsamples,
-        inference_fn=activation_equalization_iter,
-        seqlen=seqlen,
-        alpha=alpha)
+        seqlen=2048,
+        alpha=0.5,
+        ref_kwargs=None):
+    if act_equalization_type == 'layerwise':
+        apply_layer_ptq_fn(
+            model,
+            dataloader,
+            nsamples,
+            inference_fn=activation_equalization_iter,
+            seqlen=seqlen,
+            alpha=alpha)
+    elif act_equalization_type == 'fx':
+        assert ref_kwargs is not None, "Ref kwargs required to perform tracing and lift the model into FX."
+        # We can't do fp16 tracing on CPU as many kernels are not implemented
+        # So we have to cast to fp32 first, trace, apply equalization, and then cast back
+        with cast_to_float32(model):
+            graph_model = value_trace(model, value_args=ref_kwargs)
+            # TODO this is currently running on CPU. We need Accelerate or a TorchDispatchMode
+            # or an FX interpreter to run it on GPU
+            warnings.warn(
+                "FX mode activation equalization currently runs on CPU, expect it to be slow for large models."
+            )
+            with activation_equalization_mode(graph_model,
+                                              alpha,
+                                              add_mul_node=False,
+                                              layerwise=False):
+                for input_ids in dataloader:
+                    graph_model(input_ids=input_ids)
+    else:
+        raise RuntimeError(f"{act_equalization_type} not supported.")
 
 
 @torch.no_grad()
 def apply_weight_equalization(model, ref_kwargs, scale_computation_type='range'):
-    graph_model = value_trace(model, value_args=ref_kwargs)
-    EqualizeGraph(scale_computation_type=scale_computation_type).apply(graph_model)
+    # We can't do fp16 tracing on CPU as many kernels are not implemented
+    # So we have to cast to fp32 first, trace, apply equalization, and then cast back
+    with cast_to_float32(model):
+        graph_model = value_trace(model, value_args=ref_kwargs)
+        EqualizeGraph(scale_computation_type=scale_computation_type).apply(graph_model)
diff --git a/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py b/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py
index dd21d0fdb..19035a775 100644
--- a/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py
+++ b/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py
@@ -10,6 +10,7 @@
 from brevitas.graph.equalize import _is_reshaping_op
 from brevitas.graph.equalize import _is_scale_invariant_module
 from brevitas.graph.utils import get_module
+from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32
 
 
 def replace_bias(next_module, new_bias):
@@ -73,6 +74,7 @@ def merge_layernorm_affine_params(graph_model):
                     )
     for module, merged in merged_dict.items():
         if merged:
+            # We preserve weight and bias in case they are used to merge SmoothQuant scales in fx mode later on
            module.weight.data.fill_(1.)
             module.bias.data.fill_(0.)
         else:
@@ -83,5 +85,8 @@
 
 @torch.no_grad()
 def apply_layernorm_affine_merge(model, ref_kwargs):
-    graph_model = value_trace(model, ref_kwargs)
-    merge_layernorm_affine_params(graph_model)
+    # We can't do fp16 tracing on CPU as many kernels are not implemented
+    # So we have to cast to fp32 first, trace, apply merging, and then cast back
+    with cast_to_float32(model):
+        graph_model = value_trace(model, ref_kwargs)
+        merge_layernorm_affine_params(graph_model)
diff --git a/src/brevitas_examples/llm/llm_quant/quantize.py b/src/brevitas_examples/llm/llm_quant/quantize.py
index ae2435a75..86768e272 100644
--- a/src/brevitas_examples/llm/llm_quant/quantize.py
+++ b/src/brevitas_examples/llm/llm_quant/quantize.py
@@ -139,7 +139,9 @@ def quantize_model(
     if input_quant is not None:
         input_quant = input_quant.let(
             **{
-                'bit_width': input_bit_width, 'quantize_zero_point': quantize_input_zero_point})
+                'bit_width': input_bit_width,
+                'quantize_zero_point': quantize_input_zero_point,
+                'dtype': dtype})
         if input_quant_granularity == 'per_row':
             # QuantMHA internally always uses Seq, B, E
             input_quant = input_quant.let(
@@ -150,7 +152,9 @@
     if sym_input_quant is not None:
         sym_input_quant = sym_input_quant.let(
             **{
-                'bit_width': input_bit_width, 'quantize_zero_point': quantize_input_zero_point})
+                'bit_width': input_bit_width,
+                'quantize_zero_point': quantize_input_zero_point,
+                'dtype': dtype})
         if input_quant_granularity == 'per_row':
             q_scaled_quant = sym_input_quant.let(
                 **{
@@ -169,7 +173,9 @@
     if per_tensor_input_quant is not None:
         per_tensor_input_quant = per_tensor_input_quant.let(
             **{
-                'bit_width': input_bit_width, 'quantize_zero_point': quantize_input_zero_point})
+                'bit_width': input_bit_width,
+                'quantize_zero_point': quantize_input_zero_point,
+                'dtype': dtype})
 
     quant_linear_kwargs = {
         'input_quant': per_tensor_input_quant, 'weight_quant': weight_quant, 'dtype': dtype}
diff --git a/src/brevitas_examples/llm/llm_quant/run_utils.py b/src/brevitas_examples/llm/llm_quant/run_utils.py
index f5e9bdb7e..c81a62e4f 100644
--- a/src/brevitas_examples/llm/llm_quant/run_utils.py
+++ b/src/brevitas_examples/llm/llm_quant/run_utils.py
@@ -19,6 +19,8 @@
 limitations under the License.
 """
 
+from contextlib import contextmanager
+
 import torch
 from torch import nn
 from tqdm import tqdm
@@ -140,3 +142,21 @@ def apply_layer_ptq_fn(
         input_capture_fn=calib_input_capture,
         seqlen=seqlen,
         **inference_fn_kwargs)
+
+
+@contextmanager
+def cast_to_float32(model):
+    dtype_dict = {}
+    for name, p in model.named_parameters():
+        dtype_dict[name] = p.dtype
+    for name, b in model.named_buffers():
+        dtype_dict[name] = b.dtype
+    if any(dtype != torch.float32 for dtype in dtype_dict.values()):
+        model.to(dtype=torch.float32)
+    try:
+        yield model
+    finally:
+        for name, p in model.named_parameters():
+            p.data = p.data.to(dtype_dict[name])
+        for name, b in model.named_buffers():
+            b.data = b.data.to(dtype_dict[name])
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index f9cff0c52..592c0e446 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -4,7 +4,6 @@
 """
 
 import argparse
-import warnings
 
 import numpy as np
 import torch
@@ -113,7 +112,12 @@
     action='store_true',
     help='Apply weight equalization. Relevant to ReLU based models (e.g. OPT).')
 parser.add_argument(
-    '--act-equalization', action='store_true', help='Apply activation equalization (SmoothQuant).')
+    '--act-equalization',
+    default=None,
+    choices=[None, 'layerwise', 'fx'],
+    help='Apply activation equalization (SmoothQuant). Layerwise introduces standalone mul nodes, '
+    'while fx merges them whenever possible into previous tensors, which is possible on ReLU based models (e.g. OPT).'
+)
 parser.add_argument(
     '--export-target',
     default=None,
@@ -168,10 +172,8 @@ def validate(args):
         assert args.quantize_weight_zero_point, "Quantized weight zero point required."
     if args.input_quant_type == 'asym':
         assert args.quantize_input_zero_point, "Quantized input zero point required."
-    if args.input_bit_width is not None and not args.act_calibration:
-        warnings.warn(
-            "Input quantization is being applied without activation calibration. Set --act-calibration."
-        )
+    if args.input_bit_width:
+        assert args.act_calibration, "Input quantization is being applied without activation calibration. Set --act-calibration."
 
 
 def main():
@@ -204,9 +206,9 @@ def main():
         apply_layernorm_affine_merge(model, ref_kwargs={'input_ids': calibration_loader[0]})
         print("LN affine merge applied.")
 
-    # Insert standard MHA layers when performing weight equalization to avoid dealing
+    # Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing
     # with all the variability in HF implementations
-    if args.weight_equalization or args.input_bit_width:
+    if args.weight_equalization or args.act_equalization == 'fx' or args.input_bit_width:
         print("Replace HF MHA with quantizable variants...")
         model = replace_mha_with_quantizable_layers(model, dtype)
         print("Replacing done.")
@@ -216,9 +218,14 @@
         apply_weight_equalization(model, ref_kwargs={'input_ids': calibration_loader[0]})
         print("Weight equalization applied.")
 
-    if args.act_equalization:
+    if args.act_equalization is not None:
         print("Apply act equalization (SmoothQuant)...")
-        apply_act_equalization(model, calibration_loader, args.nsamples)
+        apply_act_equalization(
+            model,
+            args.act_equalization,
+            calibration_loader,
+            args.nsamples,
+            ref_kwargs={'input_ids': calibration_loader[0]})
         print("Act equalization applied.")
 
     if not args.no_quantize:
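
For context, below is a minimal usage sketch of the new apply_act_equalization entry point added by this patch. It mirrors the call pattern in main.py, but the checkpoint name, calibration texts, and sample count are placeholders, and the MHA-replacement step that main.py performs before fx-mode equalization is omitted, so treat it as a sketch under those assumptions rather than a tested recipe.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from brevitas_examples.llm.llm_quant.equalize import apply_act_equalization

# Placeholder checkpoint; any HF causal LM loaded in fp16 would follow the same pattern.
model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Toy calibration set: the fx branch iterates the dataloader and calls
# graph_model(input_ids=input_ids), so each element is an input_ids tensor.
texts = ["The quick brown fox jumps over the lazy dog."] * 4
calibration_loader = [tokenizer(t, return_tensors="pt").input_ids for t in texts]

# 'fx' mode traces the model with value_trace (casting fp16 params to fp32 on CPU via
# cast_to_float32 first), then runs activation_equalization_mode over the calibration
# data; ref_kwargs is required for tracing in this mode.
apply_act_equalization(
    model,
    'fx',
    calibration_loader,
    nsamples=len(calibration_loader),
    ref_kwargs={'input_ids': calibration_loader[0]})

# 'layerwise' mode goes through apply_layer_ptq_fn instead and needs no ref_kwargs:
# apply_act_equalization(model, 'layerwise', calibration_loader, nsamples=len(calibration_loader))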