From 9b31212e11e0022941711f39f402ca67353f6ac3 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 19 Nov 2024 08:54:29 +0000 Subject: [PATCH] Fix (ptq/rotation): fix for rotation implementation (#1095) --- src/brevitas/graph/equalize.py | 8 +++----- src/brevitas_examples/llm/README.md | 9 ++++----- src/brevitas_examples/llm/main.py | 20 +++++++++++--------- tests/brevitas_examples/test_llm.py | 7 +++---- 4 files changed, 21 insertions(+), 23 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index f39a951d0..4e5c1a162 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1356,18 +1356,16 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= weight = module.weight.data if axis == 1: - weight = rot_func(weight, rot_mat, K) + _update_weights(module, rot_func(weight, rot_mat, K), 'weight') elif axis == 0: - weight = rot_func(weight.t(), rot_mat, K).t() + _update_weights(module, rot_func(weight.t(), rot_mat, K).t(), 'weight') else: raise RuntimeError("Not supported yet") - module.weight.data = weight if hasattr(module, 'offload_params'): module.offload_params(module) if insert_rotation_module and len(region.srcs) == 0: - # print(name, module.in_features, K) rewriter = ModuleInstanceToModuleInstance( module, RotatedModule(had_mat=rot_mat, k=K, layer=module)) rewriters.append(rewriter) @@ -1467,7 +1465,7 @@ def rotate_matmuls(self, graph_module): def apply(self, graph_model: GraphModule) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: - + rewriters = [] regions = _extract_regions( graph_model, state_impl_kwargs={ diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md index 64b87b3b1..225923e3f 100644 --- a/src/brevitas_examples/llm/README.md +++ b/src/brevitas_examples/llm/README.md @@ -44,9 +44,8 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES] [--convert-layernorm-to-rmsnorm] [--replace-rmsnorm] [--no-quantize] [--no-float16] [--scaling-min-val SCALING_MIN_VAL] [--replace-mha] - [--weight-equalization] - [--graph-rotation {fx,layerwise,fused_no_fx}] - [--graph-rotation-mode {had,ort}] [--rotation-orphan-sink] + [--weight-equalization] [--rotation {fx,layerwise,fused_no_fx}] + [--rotation-mode {had,ort}] [--rotation-orphan-sink] [--act-equalization {None,layerwise,fx}] [--load-awq LOAD_AWQ] [--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}] [--export-prefix EXPORT_PREFIX] @@ -148,9 +147,9 @@ options: --weight-equalization Apply weight equalization. Relevant to ReLU based models (e.g. OPT). - --graph-rotation {fx,layerwise,fused_no_fx} + --rotation {fx,layerwise,fused_no_fx} Apply graph rotation equalization - --graph-rotation-mode {had,ort} + --rotation-mode {had,ort} If GraphRotation is enabled, decide how to compute the random rotation matrix that is fully fused. Online or partial rotation will always be Hadamard diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index b74225a6a..495c47919 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -61,7 +61,7 @@ def fused_rotation_no_fx(model, calibration_loader, args): new_model = offload_model(new_model) eq = GraphRotationEqualization( orphan_sink=args.rotation_orphan_sink, - full_rotation_method=args.graph_rotation_mode, + full_rotation_method=args.rotation_mode, return_rewriters=True) new_model, rewriters = eq.apply(new_model) rewriters = fix_rewriter(rewriters, model, 'weight') @@ -104,10 +104,12 @@ def model_export(model, ref_input, args): def validate(args): - if args.graph_rotation == 'fx': + if args.rotation == 'fx': assert args.ln_affine_merge, 'Graph rotation requires to merge LN/RMS norm affine parameters' assert args.replace_rmsnorm, 'Graph rotation requires to replace HF RMSNorm with PyTorch ones (torch 2.4+ require)' assert args.convert_layernorm_to_rmsnorm, 'Graph rotation requires to replace LayerNorm with RMSNorm' + elif args.rotation == 'fused_no_fx': + assert args.replace_rmsnorm, 'Graph rotation requires to replace HF RMSNorm with PyTorch ones (torch 2.4+ require)' if not args.no_quantize: if args.gptq and args.gpfq: warn("Both GPTQ and GPFQ are enabled.") @@ -259,16 +261,16 @@ def main(args): apply_layernorm_to_rmsnorm(model) print("Layernorm To RMSNorm applied.") - if args.graph_rotation == 'fx': + if args.rotation == 'fx': model = offload_model(model) eq = GraphRotationEqualization( - orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.graph_rotation_mode) + orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode) model = eq.apply(model) remove_hooks(model) - elif args.graph_rotation == 'layerwise': + elif args.rotation == 'layerwise': eq = LayerwiseActivationRotation() model = eq.apply(model) - elif args.graph_rotation == 'fused_no_fx': + elif args.rotation == 'fused_no_fx': fused_rotation_no_fx(model, calibration_loader, args) # Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing @@ -354,7 +356,7 @@ def main(args): # If any equalization has taken places, the embedding layer and the fully connected one are # not tied anymore, and they need to be treated as standalone, separate layers. # In all other cases we can tie them back so to preserve memory. - if args.act_equalization is None and not require_fx: + if args.act_equalization is None and not require_fx and args.rotation is None: model.tie_weights() if args.bias_corr: @@ -600,13 +602,13 @@ def parse_args(args): action='store_true', help='Apply weight equalization. Relevant to ReLU based models (e.g. OPT).') parser.add_argument( - '--graph-rotation', + '--rotation', type=str, default=None, choices=['fx', 'layerwise', 'fused_no_fx'], help='Apply graph rotation equalization') parser.add_argument( - '--graph-rotation-mode', + '--rotation-mode', default='had', choices=['had', 'ort'], help= diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index 105a1ea8b..576af04b1 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -380,7 +380,7 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4): "no_quantize": True, "rotation_orphan_sink": True, "convert_layernorm_to_rmsnorm": True, - "graph_rotation": "fx", + "rotation": "fx", "exp_layer_types": { "L__self___model_layers_0_self_attn_k_proj": "", @@ -394,7 +394,7 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4): "no_quantize": True, "rotation_orphan_sink": False, "convert_layernorm_to_rmsnorm": True, - "graph_rotation": "fx", + "rotation": "fx", "exp_layer_types": { "L__self___model_layers_0_self_attn_k_proj": "", @@ -417,8 +417,7 @@ def test_small_models_quant_layer(caplog, layer_args): if args.replace_rmsnorm: if torch_version < version.parse('2.4'): pytest.skip("Replacing RMSNorm requires torch 2.4+ or greater") - if hasattr(args, 'graph_rotation') and args.graph_rotation == 'fx' and platform.system( - ) == 'Windows': + if hasattr(args, 'rotation') and args.rotation == 'fx' and platform.system() == 'Windows': pytest.skip("Skipping dynamo + windows") float_ppl, quant_ppl, model = validate_args_and_run_main(args) assert_layer_types(model, exp_layer_types)