From 17328aacbcc698c1d15fed409c168436eb9e3475 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Wed, 13 Nov 2024 19:58:06 +0000
Subject: [PATCH] last changes

---
 src/brevitas/graph/equalize.py               |   8 +-
 .../llm/llm_quant/ln_affine_merge.py         |   7 +-
 src/brevitas_examples/llm/main.py            |  25 +-
 tests/brevitas_examples/test_llm.py          | 220 ++++++++++--------
 4 files changed, 147 insertions(+), 113 deletions(-)

diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py
index a1ab49b36..f39a951d0 100644
--- a/src/brevitas/graph/equalize.py
+++ b/src/brevitas/graph/equalize.py
@@ -1499,10 +1499,11 @@ def apply(self,
 
 class LayerNormToRMS(GraphTransform):
 
-    def __init__(self) -> None:
+    def __init__(self, return_rewriters=False) -> None:
         super(LayerNormToRMS, self).__init__()
         self.supported_srcs = (nn.Linear, nn.Embedding)
         self.supported_sinks = (nn.LayerNorm)
+        self.return_rewriters = return_rewriters
         assert RMSNorm is not object, 'Update your Pytorch version to 2.4+'
 
     def apply(self, graph_model: GraphModule) -> GraphModule:
@@ -1536,7 +1537,10 @@ def apply(self, graph_model: GraphModule) -> GraphModule:
                         ModuleToModuleByInstance(layer_norm, RMSNorm, dtype=layer_norm_dtype))
         for r in rewriters:
             graph_model = r.apply(graph_model)
-        return graph_model, rewriters
+        if self.return_rewriters:
+            return graph_model, rewriters
+        else:
+            return graph_model
 
 
 class MergeLnAffine(GraphTransform):
diff --git a/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py b/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py
index 1e004e95f..75e31826a 100644
--- a/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py
+++ b/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py
@@ -111,7 +111,6 @@ def apply_layernorm_affine_merge(graph_model):
 
 
 @torch.no_grad()
-def apply_layernorm_to_rmsnorm(graph_model):
-    eq = LayerNormToRMS()
-    graph_model = eq.apply(graph_model)
-    return graph_model
+def apply_layernorm_to_rmsnorm(graph_model, return_rewriters=False):
+    eq = LayerNormToRMS(return_rewriters)
+    return eq.apply(graph_model)
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 95f7eae60..f0613b3f6 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 import argparse
+from copy import deepcopy
 import sys
 from warnings import warn
 
@@ -18,6 +19,7 @@
 from brevitas.graph.equalize import GraphRotationEqualization
 from brevitas.graph.equalize import LayerwiseActivationRotation
 from brevitas.graph.quantize import layerwise_quantize
+from brevitas.graph.utils import get_module
 from brevitas_examples.common.accelerate_utils.accelerate import offload_model
 from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks
 from brevitas_examples.common.generative.quantize import generate_quant_maps
@@ -51,14 +53,16 @@ def fused_rotation_no_fx(model, calibration_loader, args):
     with torch.no_grad():
         new_model, guards = torch._dynamo.export(model)(**calibration_loader[0])
     apply_layernorm_affine_merge(new_model)
-    new_model, rewriters = apply_layernorm_to_rmsnorm(new_model)
+    new_model, rewriters = apply_layernorm_to_rmsnorm(new_model, return_rewriters=True)
     rewriters = fix_rewriter(rewriters, model, 'weight')
 
     for r in rewriters:
         r.apply(model)
     new_model = offload_model(new_model)
     eq = GraphRotationEqualization(
-        orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.graph_rotation_mode)
+        orphan_sink=args.rotation_orphan_sink,
+        full_rotation_method=args.graph_rotation_mode,
+        return_rewriters=True)
     new_model, rewriters = eq.apply(new_model)
     rewriters = fix_rewriter(rewriters, model, 'weight')
 
@@ -100,7 +104,7 @@ def model_export(model, ref_input, args):
 
 
 def validate(args):
-    if args.graph_rotation:
+    if args.graph_rotation == 'fx':
         assert args.ln_affine_merge, 'Graph rotation requires to merge LN/RMS norm affine parameters'
         assert args.replace_rmsnorm, 'Graph rotation requires to replace HF RMSNorm with PyTorch ones (torch 2.4+ require)'
         assert args.convert_layernorm_to_rmsnorm, 'Graph rotation requires to replace LayerNorm with RMSNorm'
@@ -329,7 +333,20 @@ def main(args):
         input_quant_format=args.input_quant_format,
         quantize_embedding=False)
     if not args.quantize_last_layer:
-        name_blacklist += ["lm_head", "embed_out"]
+        if require_fx:
+            last_node = [node for node in model.graph.nodes if node.op == 'call_module'][-1]
+            last_module = get_module(model, last_node.target)
+            last_layer_kwargs = layer_map[type(last_module)][1]
+            prev_weight_quant = deepcopy(last_layer_kwargs['weight_quant'])
+            prev_input_quant = deepcopy(last_layer_kwargs['input_quant'])
+            weight_quant = lambda module: prev_weight_quant if id(module) != id(
+                last_module) else None
+            input_quant = lambda module: prev_input_quant if id(module) != id(
+                last_module) else None
+            last_layer_kwargs['weight_quant'] = weight_quant
+            last_layer_kwargs['input_quant'] = input_quant
+        else:
+            name_blacklist += ["lm_head", "embed_out"]
     model = layerwise_quantize(
         model=model, compute_layer_map=layer_map, name_blacklist=name_blacklist)
     # Tie back first/last layer weights in case they got untied
diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py
index c1bf4d0f6..f933cffea 100644
--- a/tests/brevitas_examples/test_llm.py
+++ b/tests/brevitas_examples/test_llm.py
@@ -282,109 +282,123 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
         "llama-int8-act_equalization=layerwise",
         "mistral-int8-quant-last-layer",
         "llama-rotation-fx"],
-    params=
-    [{
-        "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
-        "exp_layer_types": {
-            "lm_head":
-                "",
-            "model.layers.0.self_attn.q_proj":
-                "",
-            "model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
-                "",
-            "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
-                "",}},
-    {
-        "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
-        "input_bit_width": None,
-        "act_calibration": False,
-        "exp_layer_types": {
-            "model.layers.0.self_attn.q_proj":
-                "",
-            "model.layers.0.self_attn.q_proj.input_quant":
-                "",
-            "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
-                "",}},
-    {
-        "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
-        "weight_quant_format": "float_ocp_e4m3",
-        "weight_quant_type": "sym",
-        "input_quant_format": "float_ocp_e5m2",
-        "input_quant_type": "sym",
-        "exp_layer_types": {
-            "model.layers.0.self_attn.q_proj":
-                "",
-            "model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
-                "",
-            "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
-                "",}},
-    {
-        "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
-        "weight_quant_format": "float_fnuz_e4m3",
-        "weight_quant_type": "sym",
-        "input_quant_format": "float_fnuz_e5m2",
-        "input_quant_type": "sym",
-        "exp_layer_types": {
-            "model.layers.0.self_attn.q_proj":
-                "",
-            "model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
-                "",
-            "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
-                "",}},
-    {
-        "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
-        "weight_quant_format": "float_ocp_e4m3",
-        "weight_scale_precision": "po2_scale",
-        "weight_param_method": "stats",
-        "weight_quant_granularity": "per_group",
-        "weight_group_size": 16,
-        "weight_quant_type": "sym",
-        "input_quant_format": "float_ocp_e5m2",
-        "input_scale_type": "dynamic",
-        "input_scale_precision": "po2_scale",
-        "input_param_method": "stats",
-        "input_quant_granularity": "per_group",
-        "input_group_size": 16,
-        "input_quant_type": "sym",
-        "act_calibration": False,
-        "exp_layer_types": {
-            "model.layers.0.self_attn.q_proj":
-                "",
-            "model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
-                "",
-            "model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant.input_view_impl":
-                "",
-            "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
-                "",
-            "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant.input_view_impl":
-                "",}},
-    {
-        "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
-        "act_equalization": "layerwise",
-        "exp_layer_types": {
-            "model.layers.0.self_attn.q_proj":
-                "",
-            "model.layers.0.self_attn.q_proj.layer":
-                "",}},
-    {
-        "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
-        "quantize_last_layer": True,
-        "exp_layer_types": {
-            "lm_head": ""}},
-    {
-        "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
-        "ln_affine_merge": True,
-        "replace_rmsnorm": True,
-        "quantize_last_layer": True,
-        "no_quantize": True,
-        "rotation_orphan_sink": True,
-        "convert_layernorm_to_rmsnorm": True,
-        "graph_rotation": "fx",
-        "exp_layer_types": {
-            "L__self___model_layers_0_self_attn_k_proj":
-                "",
-            "L__self___model_layers_0_self_attn_o_proj":
-                ""}}])
+    params=[
+        {
+            "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
+            "exp_layer_types": {
+                "lm_head":
+                    "",
+                "model.layers.0.self_attn.q_proj":
+                    "",
+                "model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
+                    "",
+                "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
+                    "",}},
+        {
+            "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
+            "input_bit_width": None,
+            "act_calibration": False,
+            "exp_layer_types": {
+                "model.layers.0.self_attn.q_proj":
+                    "",
+                "model.layers.0.self_attn.q_proj.input_quant":
+                    "",
+                "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
+                    "",}},
+        {
+            "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
+            "weight_quant_format": "float_ocp_e4m3",
+            "weight_quant_type": "sym",
+            "input_quant_format": "float_ocp_e5m2",
+            "input_quant_type": "sym",
+            "exp_layer_types": {
+                "model.layers.0.self_attn.q_proj":
+                    "",
+                "model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
+                    "",
+                "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
+                    "",}},
+        {
+            "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
+            "weight_quant_format": "float_fnuz_e4m3",
+            "weight_quant_type": "sym",
+            "input_quant_format": "float_fnuz_e5m2",
+            "input_quant_type": "sym",
+            "exp_layer_types": {
+                "model.layers.0.self_attn.q_proj":
+                    "",
+                "model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
+                    "",
+                "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
+                    "",}},
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "weight_quant_format": "float_ocp_e4m3",
+            "weight_scale_precision": "po2_scale",
+            "weight_param_method": "stats",
+            "weight_quant_granularity": "per_group",
+            "weight_group_size": 16,
+            "weight_quant_type": "sym",
+            "input_quant_format": "float_ocp_e5m2",
+            "input_scale_type": "dynamic",
+            "input_scale_precision": "po2_scale",
+            "input_param_method": "stats",
+            "input_quant_granularity": "per_group",
+            "input_group_size": 16,
+            "input_quant_type": "sym",
+            "act_calibration": False,
+            "exp_layer_types": {
+                "model.layers.0.self_attn.q_proj":
+                    "",
+                "model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
+                    "",
+                "model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant.input_view_impl":
+                    "",
+                "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
+                    "",
+                "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant.input_view_impl":
+                    "",}},
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "act_equalization": "layerwise",
+            "exp_layer_types": {
+                "model.layers.0.self_attn.q_proj":
+                    "",
+                "model.layers.0.self_attn.q_proj.layer":
+                    "",}},
+        {
+            "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
+            "quantize_last_layer": True,
+            "exp_layer_types": {
+                "lm_head": ""}},
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "ln_affine_merge": True,
+            "replace_rmsnorm": True,
+            "quantize_last_layer": True,
+            "no_quantize": True,
+            "rotation_orphan_sink": True,
+            "convert_layernorm_to_rmsnorm": True,
+            "graph_rotation": "fx",
+            "exp_layer_types": {
+                "L__self___model_layers_0_self_attn_k_proj":
+                    "",
+                "L__self___model_layers_0_self_attn_o_proj":
+                    ""}},
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "ln_affine_merge": True,
+            "replace_rmsnorm": True,
+            "quantize_last_layer": True,
+            "no_quantize": True,
+            "rotation_orphan_sink": False,
+            "convert_layernorm_to_rmsnorm": True,
+            "graph_rotation": "fx",
+            "exp_layer_types": {
+                "L__self___model_layers_0_self_attn_k_proj":
+                    "",
+                "L__self___model_layers_0_self_attn_o_proj":
+                    ""}},])
 def layer_args(default_run_args, request):
     args = default_run_args
     layer_dict = request.param
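
Reviewer note, not part of the patch: a minimal usage sketch of the new return_rewriters flag wired through LayerNormToRMS and apply_layernorm_to_rmsnorm above. The graph_model and orig_model names are hypothetical stand-ins for an FX-traced copy of the LLM and the original untraced module.

from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_to_rmsnorm


def replace_ln_with_rms(graph_model, orig_model):
    # New path: also collect the LayerNorm -> RMSNorm rewriters so the same
    # substitution can be replayed on the untraced copy of the model,
    # mirroring what fused_rotation_no_fx does in main.py.
    graph_model, rewriters = apply_layernorm_to_rmsnorm(graph_model, return_rewriters=True)
    for r in rewriters:
        orig_model = r.apply(orig_model)
    return graph_model, orig_model

# Existing call sites are unaffected: with the default return_rewriters=False
# only the rewritten graph is returned, e.g.
#     graph_model = apply_layernorm_to_rmsnorm(graph_model)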
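
A second sketch, also not part of the patch: the per-module quantizer override that main.py now uses to skip the last layer when FX tracing is required (tracing rewrites module names, cf. the L__self___... names in the tests, so the lm_head/embed_out name blacklist no longer matches). toy_layer_map and toy_last_module are hypothetical stand-ins, not Brevitas API.

from copy import deepcopy

import torch.nn as nn

toy_last_module = nn.Linear(8, 8)  # stand-in for the traced lm_head module
# Stand-in for the layer map produced by generate_quant_maps: layer type -> (quant class, kwargs).
toy_layer_map = {
    nn.Linear: (nn.Linear, {'weight_quant': 'toy_weight_quant', 'input_quant': 'toy_input_quant'})}

last_layer_kwargs = toy_layer_map[type(toy_last_module)][1]
prev_weight_quant = deepcopy(last_layer_kwargs['weight_quant'])
prev_input_quant = deepcopy(last_layer_kwargs['input_quant'])

# The kwargs become callables evaluated per module: every layer keeps its
# quantizers except the designated last module, which gets None (no quantization).
last_layer_kwargs['weight_quant'] = lambda module: prev_weight_quant if id(module) != id(toy_last_module) else None
last_layer_kwargs['input_quant'] = lambda module: prev_input_quant if id(module) != id(toy_last_module) else None

assert last_layer_kwargs['weight_quant'](nn.Linear(4, 4)) == 'toy_weight_quant'
assert last_layer_kwargs['weight_quant'](toy_last_module) is None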