last changes
Giuseppe5 committed Nov 13, 2024
1 parent 604e3c4 commit 17328aa
Showing 4 changed files with 147 additions and 113 deletions.
8 changes: 6 additions & 2 deletions src/brevitas/graph/equalize.py
@@ -1499,10 +1499,11 @@ def apply(self,

class LayerNormToRMS(GraphTransform):

def __init__(self) -> None:
def __init__(self, return_rewriters=False) -> None:
super(LayerNormToRMS, self).__init__()
self.supported_srcs = (nn.Linear, nn.Embedding)
self.supported_sinks = (nn.LayerNorm)
self.return_rewriters = return_rewriters
assert RMSNorm is not object, 'Update your Pytorch version to 2.4+'

def apply(self, graph_model: GraphModule) -> GraphModule:
@@ -1536,7 +1537,10 @@ def apply(self, graph_model: GraphModule) -> GraphModule:
ModuleToModuleByInstance(layer_norm, RMSNorm, dtype=layer_norm_dtype))
for r in rewriters:
graph_model = r.apply(graph_model)
return graph_model, rewriters
if self.return_rewriters:
return graph_model, rewriters
else:
return graph_model


class MergeLnAffine(GraphTransform):
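For reference, a minimal sketch of how the new return_rewriters flag changes the LayerNormToRMS call site. The toy Sequential model and the use of torch.fx.symbolic_trace are illustrative assumptions, not taken from this commit; the transform needs PyTorch 2.4+ (nn.RMSNorm), and whether a given LayerNorm actually gets rewritten depends on matching logic outside this hunk.

import torch.nn as nn
from torch.fx import symbolic_trace

from brevitas.graph.equalize import LayerNormToRMS

# Hypothetical FX-traceable model with a Linear source feeding a LayerNorm sink.
model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
graph_model = symbolic_trace(model)

# Either: default path, backward compatible, only the transformed graph is returned.
graph_model = LayerNormToRMS().apply(graph_model)

# Or: opt in to also collect the rewriters so they can be replayed later,
# e.g. on a non-traced copy of the same model, as fused_rotation_no_fx does below.
graph_model, rewriters = LayerNormToRMS(return_rewriters=True).apply(graph_model)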
7 changes: 3 additions & 4 deletions src/brevitas_examples/llm/llm_quant/ln_affine_merge.py
@@ -111,7 +111,6 @@ def apply_layernorm_affine_merge(graph_model):


@torch.no_grad()
def apply_layernorm_to_rmsnorm(graph_model):
eq = LayerNormToRMS()
graph_model = eq.apply(graph_model)
return graph_model
def apply_layernorm_to_rmsnorm(graph_model, return_rewriters=False):
eq = LayerNormToRMS(return_rewriters)
return eq.apply(graph_model)
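The reason for threading return_rewriters through this wrapper shows up in fused_rotation_no_fx below: the conversion runs on a traced copy of the model, and the collected rewriters are then replayed on the original eager model. A hedged sketch of that pattern follows; model and calibration_sample are placeholders for an eager HF causal-LM module and a dict of example inputs, and the import for fix_rewriter is omitted because it is defined outside this diff.

import torch

from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_to_rmsnorm

with torch.no_grad():
    # Trace a graph copy of the eager model on the placeholder inputs.
    traced_model, guards = torch._dynamo.export(model)(**calibration_sample)

# Convert LayerNorm to RMSNorm on the traced copy and collect the rewriters.
traced_model, rewriters = apply_layernorm_to_rmsnorm(traced_model, return_rewriters=True)

# main.py first re-points the rewriters at the eager model's parameters with
# fix_rewriter (a helper not shown in this diff), then replays the swaps on it.
rewriters = fix_rewriter(rewriters, model, 'weight')
for r in rewriters:
    r.apply(model)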
25 changes: 21 additions & 4 deletions src/brevitas_examples/llm/main.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import argparse
from copy import deepcopy
import sys
from warnings import warn

@@ -18,6 +19,7 @@
from brevitas.graph.equalize import GraphRotationEqualization
from brevitas.graph.equalize import LayerwiseActivationRotation
from brevitas.graph.quantize import layerwise_quantize
from brevitas.graph.utils import get_module
from brevitas_examples.common.accelerate_utils.accelerate import offload_model
from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks
from brevitas_examples.common.generative.quantize import generate_quant_maps
@@ -51,14 +53,16 @@ def fused_rotation_no_fx(model, calibration_loader, args):
with torch.no_grad():
new_model, guards = torch._dynamo.export(model)(**calibration_loader[0])
apply_layernorm_affine_merge(new_model)
new_model, rewriters = apply_layernorm_to_rmsnorm(new_model)
new_model, rewriters = apply_layernorm_to_rmsnorm(new_model, return_rewriters=True)
rewriters = fix_rewriter(rewriters, model, 'weight')

for r in rewriters:
r.apply(model)
new_model = offload_model(new_model)
eq = GraphRotationEqualization(
orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.graph_rotation_mode)
orphan_sink=args.rotation_orphan_sink,
full_rotation_method=args.graph_rotation_mode,
return_rewriters=True)
new_model, rewriters = eq.apply(new_model)
rewriters = fix_rewriter(rewriters, model, 'weight')

@@ -100,7 +104,7 @@ def model_export(model, ref_input, args):


def validate(args):
if args.graph_rotation:
if args.graph_rotation == 'fx':
assert args.ln_affine_merge, 'Graph rotation requires to merge LN/RMS norm affine parameters'
assert args.replace_rmsnorm, 'Graph rotation requires to replace HF RMSNorm with PyTorch ones (torch 2.4+ require)'
assert args.convert_layernorm_to_rmsnorm, 'Graph rotation requires to replace LayerNorm with RMSNorm'
@@ -329,7 +333,20 @@ def main(args):
input_quant_format=args.input_quant_format,
quantize_embedding=False)
if not args.quantize_last_layer:
name_blacklist += ["lm_head", "embed_out"]
if require_fx:
last_node = [node for node in model.graph.nodes if node.op == 'call_module'][-1]
last_module = get_module(model, last_node.target)
last_layer_kwargs = layer_map[type(last_module)][1]
prev_weight_quant = deepcopy(last_layer_kwargs['weight_quant'])
prev_input_quant = deepcopy(last_layer_kwargs['input_quant'])
weight_quant = lambda module: prev_weight_quant if id(module) != id(
last_module) else None
input_quant = lambda module: prev_input_quant if id(module) != id(
last_module) else None
last_layer_kwargs['weight_quant'] = weight_quant
last_layer_kwargs['input_quant'] = input_quant
else:
name_blacklist += ["lm_head", "embed_out"]
model = layerwise_quantize(
model=model, compute_layer_map=layer_map, name_blacklist=name_blacklist)
# Tie back first/last layer weights in case they got untied
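The FX branch added above cannot rely on the "lm_head"/"embed_out" name blacklist, since FX-traced module names differ from the original Hugging Face ones. Instead it rewrites the layer-map kwargs for the last call_module node so that only that module ends up with weight_quant and input_quant set to None. A consolidated, commented sketch of that pattern, assuming (as the diff relies on) that callable kwargs in the compute layer map are resolved per module by layerwise_quantize; model and layer_map are placeholders here.

from copy import deepcopy

from brevitas.graph.utils import get_module

# model is an FX-traced GraphModule, layer_map the usual
# {layer_type: (quant_layer_class, quant_kwargs)} mapping; both are placeholders.
last_node = [n for n in model.graph.nodes if n.op == 'call_module'][-1]
last_module = get_module(model, last_node.target)
last_layer_kwargs = layer_map[type(last_module)][1]

# Snapshot the configured quantizers, then swap in callables that return them
# for every module except the last one, which gets None (left unquantized).
prev_weight_quant = deepcopy(last_layer_kwargs['weight_quant'])
prev_input_quant = deepcopy(last_layer_kwargs['input_quant'])
last_layer_kwargs['weight_quant'] = lambda module: None if module is last_module else prev_weight_quant
last_layer_kwargs['input_quant'] = lambda module: None if module is last_module else prev_input_quant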
220 changes: 117 additions & 103 deletions tests/brevitas_examples/test_llm.py
@@ -282,109 +282,123 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"llama-int8-act_equalization=layerwise",
"mistral-int8-quant-last-layer",
"llama-rotation-fx"],
params=
[{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"exp_layer_types": {
"lm_head":
"<class 'torch.nn.modules.linear.Linear'>",
"model.layers.0.self_attn.q_proj":
"<class 'brevitas.nn.quant_linear.QuantLinear'>",
"model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",}},
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"input_bit_width": None,
"act_calibration": False,
"exp_layer_types": {
"model.layers.0.self_attn.q_proj":
"<class 'brevitas.nn.quant_linear.QuantLinear'>",
"model.layers.0.self_attn.q_proj.input_quant":
"<class 'brevitas.proxy.runtime_quant.ActQuantProxyFromInjector'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",}},
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"weight_quant_format": "float_ocp_e4m3",
"weight_quant_type": "sym",
"input_quant_format": "float_ocp_e5m2",
"input_quant_type": "sym",
"exp_layer_types": {
"model.layers.0.self_attn.q_proj":
"<class 'brevitas.nn.quant_linear.QuantLinear'>",
"model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",}},
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"weight_quant_format": "float_fnuz_e4m3",
"weight_quant_type": "sym",
"input_quant_format": "float_fnuz_e5m2",
"input_quant_type": "sym",
"exp_layer_types": {
"model.layers.0.self_attn.q_proj":
"<class 'brevitas.nn.quant_linear.QuantLinear'>",
"model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",}},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"weight_quant_format": "float_ocp_e4m3",
"weight_scale_precision": "po2_scale",
"weight_param_method": "stats",
"weight_quant_granularity": "per_group",
"weight_group_size": 16,
"weight_quant_type": "sym",
"input_quant_format": "float_ocp_e5m2",
"input_scale_type": "dynamic",
"input_scale_precision": "po2_scale",
"input_param_method": "stats",
"input_quant_granularity": "per_group",
"input_group_size": 16,
"input_quant_type": "sym",
"act_calibration": False,
"exp_layer_types": {
"model.layers.0.self_attn.q_proj":
"<class 'brevitas.nn.quant_linear.QuantLinear'>",
"model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",
"model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant.input_view_impl":
"<class 'brevitas.core.function_wrapper.shape.DynamicOverSubChannelBlockView'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant.input_view_impl":
"<class 'brevitas.core.function_wrapper.shape.OverSubChannelBlockView'>",}},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"act_equalization": "layerwise",
"exp_layer_types": {
"model.layers.0.self_attn.q_proj":
"<class 'brevitas.nn.equalized_layer.EqualizedModule'>",
"model.layers.0.self_attn.q_proj.layer":
"<class 'brevitas.nn.quant_linear.QuantLinear'>",}},
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"quantize_last_layer": True,
"exp_layer_types": {
"lm_head": "<class 'brevitas.nn.quant_linear.QuantLinear'>"}},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"ln_affine_merge": True,
"replace_rmsnorm": True,
"quantize_last_layer": True,
"no_quantize": True,
"rotation_orphan_sink": True,
"convert_layernorm_to_rmsnorm": True,
"graph_rotation": "fx",
"exp_layer_types": {
"L__self___model_layers_0_self_attn_k_proj":
"<class 'torch.nn.modules.linear.Linear'>",
"L__self___model_layers_0_self_attn_o_proj":
"<class 'brevitas.nn.equalized_layer.RotatedModule'>"}}])
params=[
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"exp_layer_types": {
"lm_head":
"<class 'torch.nn.modules.linear.Linear'>",
"model.layers.0.self_attn.q_proj":
"<class 'brevitas.nn.quant_linear.QuantLinear'>",
"model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",}},
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"input_bit_width": None,
"act_calibration": False,
"exp_layer_types": {
"model.layers.0.self_attn.q_proj":
"<class 'brevitas.nn.quant_linear.QuantLinear'>",
"model.layers.0.self_attn.q_proj.input_quant":
"<class 'brevitas.proxy.runtime_quant.ActQuantProxyFromInjector'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",}},
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"weight_quant_format": "float_ocp_e4m3",
"weight_quant_type": "sym",
"input_quant_format": "float_ocp_e5m2",
"input_quant_type": "sym",
"exp_layer_types": {
"model.layers.0.self_attn.q_proj":
"<class 'brevitas.nn.quant_linear.QuantLinear'>",
"model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",}},
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"weight_quant_format": "float_fnuz_e4m3",
"weight_quant_type": "sym",
"input_quant_format": "float_fnuz_e5m2",
"input_quant_type": "sym",
"exp_layer_types": {
"model.layers.0.self_attn.q_proj":
"<class 'brevitas.nn.quant_linear.QuantLinear'>",
"model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",}},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"weight_quant_format": "float_ocp_e4m3",
"weight_scale_precision": "po2_scale",
"weight_param_method": "stats",
"weight_quant_granularity": "per_group",
"weight_group_size": 16,
"weight_quant_type": "sym",
"input_quant_format": "float_ocp_e5m2",
"input_scale_type": "dynamic",
"input_scale_precision": "po2_scale",
"input_param_method": "stats",
"input_quant_granularity": "per_group",
"input_group_size": 16,
"input_quant_type": "sym",
"act_calibration": False,
"exp_layer_types": {
"model.layers.0.self_attn.q_proj":
"<class 'brevitas.nn.quant_linear.QuantLinear'>",
"model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",
"model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant.input_view_impl":
"<class 'brevitas.core.function_wrapper.shape.DynamicOverSubChannelBlockView'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant.input_view_impl":
"<class 'brevitas.core.function_wrapper.shape.OverSubChannelBlockView'>",}},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"act_equalization": "layerwise",
"exp_layer_types": {
"model.layers.0.self_attn.q_proj":
"<class 'brevitas.nn.equalized_layer.EqualizedModule'>",
"model.layers.0.self_attn.q_proj.layer":
"<class 'brevitas.nn.quant_linear.QuantLinear'>",}},
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"quantize_last_layer": True,
"exp_layer_types": {
"lm_head": "<class 'brevitas.nn.quant_linear.QuantLinear'>"}},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"ln_affine_merge": True,
"replace_rmsnorm": True,
"quantize_last_layer": True,
"no_quantize": True,
"rotation_orphan_sink": True,
"convert_layernorm_to_rmsnorm": True,
"graph_rotation": "fx",
"exp_layer_types": {
"L__self___model_layers_0_self_attn_k_proj":
"<class 'torch.nn.modules.linear.Linear'>",
"L__self___model_layers_0_self_attn_o_proj":
"<class 'brevitas.nn.equalized_layer.RotatedModule'>"}},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"ln_affine_merge": True,
"replace_rmsnorm": True,
"quantize_last_layer": True,
"no_quantize": True,
"rotation_orphan_sink": False,
"convert_layernorm_to_rmsnorm": True,
"graph_rotation": "fx",
"exp_layer_types": {
"L__self___model_layers_0_self_attn_k_proj":
"<class 'torch.nn.modules.linear.Linear'>",
"L__self___model_layers_0_self_attn_o_proj":
"<class 'torch.nn.modules.linear.Linear'>"}},])
def layer_args(default_run_args, request):
args = default_run_args
layer_dict = request.param
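For orientation, the exp_layer_types dictionaries above map dotted (or FX-flattened) module names to the string of the expected module type. The assertion that consumes them lives outside this hunk; a hedged sketch of what such a check could look like, with assert_layer_types as a hypothetical helper name:

def assert_layer_types(model, exp_layer_types):
    # Resolve every expected name against the model and compare the type string.
    named_modules = dict(model.named_modules())
    for name, expected in exp_layer_types.items():
        assert name in named_modules, f"module {name} not found"
        actual = str(type(named_modules[name]))
        assert actual == expected, f"{name}: expected {expected}, got {actual}"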
