From af23a55444526d1518241f2ddf0c733e2883369d Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Mon, 18 Nov 2024 08:16:12 +0000
Subject: [PATCH] Fix (examples/llm): change rotation interface

---
 src/brevitas/graph/equalize.py      |  3 +--
 src/brevitas_examples/llm/README.md |  9 ++++-----
 src/brevitas_examples/llm/main.py   | 18 ++++++++++--------
 3 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py
index f39a951d0..d6a327bab 100644
--- a/src/brevitas/graph/equalize.py
+++ b/src/brevitas/graph/equalize.py
@@ -1367,7 +1367,6 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
             module.offload_params(module)
 
         if insert_rotation_module and len(region.srcs) == 0:
-            # print(name, module.in_features, K)
             rewriter = ModuleInstanceToModuleInstance(
                 module, RotatedModule(had_mat=rot_mat, k=K, layer=module))
             rewriters.append(rewriter)
@@ -1467,7 +1466,7 @@ def rotate_matmuls(self, graph_module):
 
     def apply(self,
               graph_model: GraphModule) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]:
-
+        rewriters = []
         regions = _extract_regions(
             graph_model,
             state_impl_kwargs={
diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md
index 64b87b3b1..225923e3f 100644
--- a/src/brevitas_examples/llm/README.md
+++ b/src/brevitas_examples/llm/README.md
@@ -44,9 +44,8 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
                [--convert-layernorm-to-rmsnorm] [--replace-rmsnorm]
                [--no-quantize] [--no-float16]
                [--scaling-min-val SCALING_MIN_VAL] [--replace-mha]
-               [--weight-equalization]
-               [--graph-rotation {fx,layerwise,fused_no_fx}]
-               [--graph-rotation-mode {had,ort}] [--rotation-orphan-sink]
+               [--weight-equalization] [--rotation {fx,layerwise,fused_no_fx}]
+               [--rotation-mode {had,ort}] [--rotation-orphan-sink]
                [--act-equalization {None,layerwise,fx}] [--load-awq LOAD_AWQ]
                [--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}]
                [--export-prefix EXPORT_PREFIX]
@@ -148,9 +147,9 @@
   --weight-equalization
                         Apply weight equalization. Relevant to ReLU based
                         models (e.g. OPT).
-  --graph-rotation {fx,layerwise,fused_no_fx}
+  --rotation {fx,layerwise,fused_no_fx}
                         Apply graph rotation equalization
-  --graph-rotation-mode {had,ort}
+  --rotation-mode {had,ort}
                         If GraphRotation is enabled, decide how to compute
                         the random rotation matrix that is fully fused.
                         Online or partial rotation will always be Hadamard
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index b74225a6a..5df431786 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -61,7 +61,7 @@ def fused_rotation_no_fx(model, calibration_loader, args):
     new_model = offload_model(new_model)
     eq = GraphRotationEqualization(
         orphan_sink=args.rotation_orphan_sink,
-        full_rotation_method=args.graph_rotation_mode,
+        full_rotation_method=args.rotation_mode,
         return_rewriters=True)
     new_model, rewriters = eq.apply(new_model)
     rewriters = fix_rewriter(rewriters, model, 'weight')
@@ -104,10 +104,12 @@ def model_export(model, ref_input, args):
 
 
 def validate(args):
-    if args.graph_rotation == 'fx':
+    if args.rotation == 'fx':
         assert args.ln_affine_merge, 'Graph rotation requires to merge LN/RMS norm affine parameters'
         assert args.replace_rmsnorm, 'Graph rotation requires to replace HF RMSNorm with PyTorch ones (torch 2.4+ require)'
         assert args.convert_layernorm_to_rmsnorm, 'Graph rotation requires to replace LayerNorm with RMSNorm'
+    elif args.rotation == 'fused_no_fx':
+        assert args.replace_rmsnorm, 'Graph rotation requires to replace HF RMSNorm with PyTorch ones (torch 2.4+ require)'
     if not args.no_quantize:
         if args.gptq and args.gpfq:
             warn("Both GPTQ and GPFQ are enabled.")
@@ -259,16 +261,16 @@ def main(args):
             apply_layernorm_to_rmsnorm(model)
             print("Layernorm To RMSNorm applied.")
 
-    if args.graph_rotation == 'fx':
+    if args.rotation == 'fx':
         model = offload_model(model)
         eq = GraphRotationEqualization(
-            orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.graph_rotation_mode)
+            orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode)
         model = eq.apply(model)
         remove_hooks(model)
-    elif args.graph_rotation == 'layerwise':
+    elif args.rotation == 'layerwise':
         eq = LayerwiseActivationRotation()
         model = eq.apply(model)
-    elif args.graph_rotation == 'fused_no_fx':
+    elif args.rotation == 'fused_no_fx':
         fused_rotation_no_fx(model, calibration_loader, args)
 
     # Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing
@@ -600,13 +602,13 @@ def parse_args(args):
         action='store_true',
         help='Apply weight equalization. Relevant to ReLU based models (e.g. OPT).')
     parser.add_argument(
-        '--graph-rotation',
+        '--rotation',
         type=str,
         default=None,
         choices=['fx', 'layerwise', 'fused_no_fx'],
         help='Apply graph rotation equalization')
     parser.add_argument(
-        '--graph-rotation-mode',
+        '--rotation-mode',
        default='had',
         choices=['had', 'ort'],
         help=
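A minimal usage sketch of the renamed interface, mirroring the validate()
rules introduced above. The model name is an illustrative placeholder and the
commands assume they are run from the repository root:

    # 'fused_no_fx' rotation only requires replacing the HF RMSNorm modules
    python src/brevitas_examples/llm/main.py \
        --model facebook/opt-125m \
        --rotation fused_no_fx \
        --rotation-mode had \
        --replace-rmsnorm

    # 'fx' rotation additionally requires merging LN/RMSNorm affine
    # parameters and converting LayerNorm to RMSNorm
    python src/brevitas_examples/llm/main.py \
        --model facebook/opt-125m \
        --rotation fx \
        --ln-affine-merge \
        --replace-rmsnorm \
        --convert-layernorm-to-rmsnorm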