Address comments
pablomlago committed Dec 9, 2024
1 parent f8a2542 commit dc4fe0f
Showing 1 changed file with 118 additions and 57 deletions.
tests/brevitas_examples/test_llm.py (118 additions, 57 deletions)
@@ -2,7 +2,6 @@
# SPDX-License-Identifier: BSD-3-Clause

from argparse import Namespace
from collections import defaultdict
from dataclasses import dataclass
import logging
import os
@@ -83,7 +82,7 @@ def assert_layer_types_count(model, exp_layer_types_count):

for name, count in exp_layer_types_count.items():
curr_count = 0 if name not in layer_types_count else layer_types_count[name]
assert count == curr_count, f"Expect {count} instances of layer type: {name}, found {curr_count}."
assert count == curr_count, f"Expected {count} instances of layer type: {name}, found {curr_count}."


class UpdatableNamespace(Namespace):
@@ -299,9 +298,7 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"mistral-fp8_fnuz",
"llama-mxfp8",
"llama-int8-act_equalization=layerwise",
"mistral-int8-quant-last-layer",
"llama-rotation-mixed-fx",
"llama-rotation-full-fx",],
"mistral-int8-quant-last-layer",],
params=[
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
@@ -314,12 +311,7 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",},
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'brevitas.core.quant.int.RescalingIntQuant'>": 28,
}}, # input_quant/weight_quant
}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"input_bit_width": None,
@@ -331,12 +323,7 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"<class 'brevitas.proxy.runtime_quant.ActQuantProxyFromInjector'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",},
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'brevitas.core.quant.int.RescalingIntQuant'>": 14,
}}, # input_quant/weight_quant
}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"weight_quant_format": "float_ocp_e4m3",
@@ -350,11 +337,7 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"<class 'brevitas.core.quant.float.FloatQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",},
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'brevitas.core.quant.float.FloatQuant'>": 28,}}, # input_quant/weight_quant
}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"weight_quant_format": "float_fnuz_e4m3",
@@ -368,11 +351,7 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"<class 'brevitas.core.quant.float.FloatQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",},
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'brevitas.core.quant.float.FloatQuant'>": 28,}}, # input_quant/weight_quant
}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"weight_quant_format": "float_ocp_e4m3",
@@ -399,7 +378,112 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant.input_view_impl":
"<class 'brevitas.core.function_wrapper.shape.OverSubChannelBlockView'>",},
"<class 'brevitas.core.function_wrapper.shape.OverSubChannelBlockView'>",},},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"act_equalization": "layerwise",
"exp_layer_types": {
"model.layers.0.self_attn.q_proj":
"<class 'brevitas.nn.equalized_layer.EqualizedModule'>",
"model.layers.0.self_attn.q_proj.layer":
"<class 'brevitas.nn.quant_linear.QuantLinear'>",},},
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"quantize_last_layer": True,
"exp_layer_types": {
"lm_head": "<class 'brevitas.nn.quant_linear.QuantLinear'>"},},
]) # LM Head + Q/K/V/O projs + Up/Gate/Down projs
def layer_args(default_run_args, request):
args = default_run_args
layer_dict = request.param
exp_layer_types = layer_dict["exp_layer_types"]
del layer_dict["exp_layer_types"]
args.update(**layer_dict)
yield args, exp_layer_types


@pytest.mark.llm
@requires_pt_ge('2.2')
def test_small_models_quant_layer(caplog, layer_args):
caplog.set_level(logging.INFO)
args, exp_layer_types = layer_args
if args.replace_rmsnorm:
if torch_version < version.parse('2.4'):
pytest.skip("Replacing RMSNorm requires torch 2.4+ or greater")
if hasattr(args, 'rotation') and args.rotation == 'fx' and platform.system() == 'Windows':
pytest.skip("Skipping dynamo + windows")
float_ppl, quant_ppl, model = validate_args_and_run_main(args)
assert_layer_types(model, exp_layer_types)


@pytest_cases.fixture(
ids=[
"mistral-int8",
"mistral-weight-only",
"mistral-fp8_ocp",
"mistral-fp8_fnuz",
"llama-mxfp8",
"llama-int8-act_equalization=layerwise",
"mistral-int8-quant-last-layer",
"llama-rotation-mixed-fx",
"llama-rotation-full-fx",],
params=[
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'brevitas.core.quant.int.RescalingIntQuant'>": 28,
}}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"input_bit_width": None,
"act_calibration": False,
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'brevitas.core.quant.int.RescalingIntQuant'>": 14,
}}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"weight_quant_format": "float_ocp_e4m3",
"weight_quant_type": "sym",
"input_quant_format": "float_ocp_e5m2",
"input_quant_type": "sym",
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'brevitas.core.quant.float.FloatQuant'>": 28,}}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"weight_quant_format": "float_fnuz_e4m3",
"weight_quant_type": "sym",
"input_quant_format": "float_fnuz_e5m2",
"input_quant_type": "sym",
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'brevitas.core.quant.float.FloatQuant'>": 28,}}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"weight_quant_format": "float_ocp_e4m3",
"weight_scale_precision": "po2_scale",
"weight_param_method": "stats",
"weight_quant_granularity": "per_group",
"weight_group_size": 16,
"weight_quant_type": "sym",
"input_quant_format": "float_ocp_e5m2",
"input_scale_type": "dynamic",
"input_scale_precision": "po2_scale",
"input_param_method": "stats",
"input_quant_granularity": "per_group",
"input_group_size": 16,
"input_quant_type": "sym",
"act_calibration": False,
"exp_layer_types_count": {
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
@@ -413,11 +497,6 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"act_equalization": "layerwise",
"exp_layer_types": {
"model.layers.0.self_attn.q_proj":
"<class 'brevitas.nn.equalized_layer.EqualizedModule'>",
"model.layers.0.self_attn.q_proj.layer":
"<class 'brevitas.nn.quant_linear.QuantLinear'>",},
"exp_layer_types_count": {
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
@@ -428,8 +507,6 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"quantize_last_layer": True,
"exp_layer_types": {
"lm_head": "<class 'brevitas.nn.quant_linear.QuantLinear'>"},
"exp_layer_types_count": {
"<class 'brevitas.nn.quant_linear.QuantLinear'>": 15,
}}, # LM Head + Q/K/V/O projs + Up/Gate/Down projs
@@ -442,11 +519,6 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"rotation_orphan_sink": True,
"convert_layernorm_to_rmsnorm": True,
"rotation": "fx",
"exp_layer_types": {
"L__self___model_layers_0_self_attn_k_proj":
"<class 'torch.nn.modules.linear.Linear'>",
"L__self___model_layers_0_self_attn_o_proj":
"<class 'brevitas.nn.equalized_layer.RotatedModule'>",},
"exp_layer_types_count": {
"<class 'brevitas.nn.equalized_layer.RotatedModule'>":
4, # Sinks: O proj + Down proj
@@ -463,42 +535,31 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"rotation_orphan_sink": False,
"convert_layernorm_to_rmsnorm": True,
"rotation": "fx",
"exp_layer_types": {
"L__self___model_layers_0_self_attn_k_proj":
"<class 'torch.nn.modules.linear.Linear'>",
"L__self___model_layers_0_self_attn_o_proj":
"<class 'torch.nn.modules.linear.Linear'>"},
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>":
15, # LM Head + Q/K/V projs + Up/Gate/Down projs
"<class 'torch.nn.modules.normalization.RMSNorm'>": 5, # Input + Post attention
"<class 'torch.nn.modules.normalization.LayerNorm'>": 0,}},])
def layer_args(default_run_args, request):
def layer_args_types_count(default_run_args, request):
args = default_run_args
layer_dict = request.param
exp_layer_types = layer_dict["exp_layer_types"]
exp_layer_types_count = layer_dict["exp_layer_types_count"]
del layer_dict["exp_layer_types"]
del layer_dict["exp_layer_types_count"]
args.update(**layer_dict)
yield args, exp_layer_types, exp_layer_types_count
yield args, exp_layer_types_count


def test_small_models_quant_layer(caplog, layer_args):
@pytest.mark.llm
@requires_pt_ge('2.2')
def test_small_models_quant_layer_types_count(caplog, layer_args_types_count):
caplog.set_level(logging.INFO)
args, exp_layer_types, exp_layer_types_count = layer_args
args, exp_layer_types_count = layer_args_types_count
if args.replace_rmsnorm:
if torch_version < version.parse('2.4'):
pytest.skip("Replacing RMSNorm requires torch 2.4+ or greater")
if hasattr(args, 'rotation') and args.rotation == 'fx' and platform.system() == 'Windows':
pytest.skip("Skipping dynamo + windows")
float_ppl, quant_ppl, model = validate_args_and_run_main(args)
# Naming of modules in the GraphModule generated by FX changes across transformers versions, e.g.
# (4.45.0)"L__self___model_layers_2_self_attn_k_proj" ->
# (4.46.0) 'L__self___model_layers_slice_None__2__None___0_self_attn_q_proj'
# Therefore, this check is skipped when rotation="fx".
if args.rotation != "fx":
assert_layer_types(model, exp_layer_types)
assert_layer_types_count(model, exp_layer_types_count)
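
For reference, the helpers exercised by these tests, assert_layer_types and assert_layer_types_count, are only partially visible in the hunks above (the diff shows just the assertion message of the latter). Below is a minimal sketch of a type-counting helper consistent with that assertion; the traversal via model.modules() and the use of collections.Counter are assumptions for illustration, not the actual code in tests/brevitas_examples/test_llm.py.

from collections import Counter

import torch.nn as nn


def count_layer_types(model: nn.Module) -> Counter:
    # Tally the printed class of every submodule; the keys match the strings used in
    # exp_layer_types_count, e.g. "<class 'brevitas.nn.quant_linear.QuantLinear'>".
    return Counter(str(type(module)) for module in model.modules())


def assert_layer_types_count_sketch(model: nn.Module, exp_layer_types_count: dict) -> None:
    # Hypothetical reimplementation: compare observed counts against the expected ones.
    layer_types_count = count_layer_types(model)
    for name, count in exp_layer_types_count.items():
        curr_count = layer_types_count.get(name, 0)
        assert count == curr_count, (
            f"Expected {count} instances of layer type: {name}, found {curr_count}.")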



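Similarly, only the header of UpdatableNamespace appears among the context lines, while the fixtures rely on args.update(**layer_dict). A plausible sketch, assuming the class merely adds a dict-style update method to argparse.Namespace:

from argparse import Namespace


class UpdatableNamespace(Namespace):
    # Hypothetical sketch: let fixtures override default run arguments in place
    # via args.update(**overrides), mirroring dict.update semantics.
    def update(self, **kwargs) -> None:
        self.__dict__.update(kwargs)

With such a helper, default_run_args can be copied per test case and selectively overridden by each params entry before being passed to validate_args_and_run_main.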