Fix llm tests transformers (#1118)

pablomlago authored Dec 11, 2024
1 parent abd9855 commit 482531c
Showing 1 changed file with 162 additions and 27 deletions.

189 changes: 162 additions & 27 deletions tests/brevitas_examples/test_llm.py
@@ -72,6 +72,19 @@ def assert_layer_types(model, exp_layer_types):
assert matched, f"Layer key: {key} not found in {layer_names}"


def assert_layer_types_count(model, exp_layer_types_count):
layer_types_count = {}
for name, layer in model.named_modules():
ltype = str(type(layer))
if ltype not in layer_types_count:
layer_types_count[ltype] = 0
layer_types_count[ltype] += 1

for name, count in exp_layer_types_count.items():
curr_count = 0 if name not in layer_types_count else layer_types_count[name]
assert count == curr_count, f"Expected {count} instances of layer type: {name}, found {curr_count}."


class UpdatableNamespace(Namespace):

def update(self, **kwargs):
@@ -293,9 +306,7 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"mistral-fp8_fnuz",
"llama-mxfp8",
"llama-int8-act_equalization=layerwise",
"mistral-int8-quant-last-layer",
"llama-rotation-mixed-fx",
"llama-rotation-full-fx",],
"mistral-int8-quant-last-layer",],
params=[
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
@@ -307,7 +318,8 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",}},
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",},
}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"input_bit_width": None,
@@ -318,7 +330,8 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"model.layers.0.self_attn.q_proj.input_quant":
"<class 'brevitas.proxy.runtime_quant.ActQuantProxyFromInjector'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",}},
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",},
}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"weight_quant_format": "float_ocp_e4m3",
@@ -331,7 +344,8 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",}},
"<class 'brevitas.core.quant.float.FloatQuant'>",},
}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"weight_quant_format": "float_fnuz_e4m3",
@@ -344,7 +358,8 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",}},
"<class 'brevitas.core.quant.float.FloatQuant'>",},
}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"weight_quant_format": "float_ocp_e4m3",
@@ -371,20 +386,138 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant.input_view_impl":
"<class 'brevitas.core.function_wrapper.shape.OverSubChannelBlockView'>",}},
"<class 'brevitas.core.function_wrapper.shape.OverSubChannelBlockView'>",},},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"act_equalization": "layerwise",
"exp_layer_types": {
"model.layers.0.self_attn.q_proj":
"<class 'brevitas.nn.equalized_layer.EqualizedModule'>",
"model.layers.0.self_attn.q_proj.layer":
"<class 'brevitas.nn.quant_linear.QuantLinear'>",}},
"<class 'brevitas.nn.quant_linear.QuantLinear'>",},},
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"quantize_last_layer": True,
"exp_layer_types": {
"lm_head": "<class 'brevitas.nn.quant_linear.QuantLinear'>"}},
"lm_head": "<class 'brevitas.nn.quant_linear.QuantLinear'>"},},
]) # LM Head + Q/K/V/O projs + Up/Gate/Down projs
def layer_args(default_run_args, request):
args = default_run_args
layer_dict = request.param
exp_layer_types = layer_dict["exp_layer_types"]
del layer_dict["exp_layer_types"]
args.update(**layer_dict)
yield args, exp_layer_types


@pytest.mark.llm
@requires_pt_ge('2.2')
def test_small_models_quant_layer(caplog, layer_args):
caplog.set_level(logging.INFO)
args, exp_layer_types = layer_args
if args.replace_rmsnorm:
if torch_version < version.parse('2.4'):
pytest.skip("Replacing RMSNorm requires torch 2.4+ or greater")
if hasattr(args, 'rotation') and args.rotation == 'fx' and platform.system() == 'Windows':
pytest.skip("Skipping dynamo + windows")
float_ppl, quant_ppl, model = validate_args_and_run_main(args)
assert_layer_types(model, exp_layer_types)


@pytest_cases.fixture(
ids=[
"mistral-int8",
"mistral-weight-only",
"mistral-fp8_ocp",
"mistral-fp8_fnuz",
"llama-mxfp8",
"llama-int8-act_equalization=layerwise",
"mistral-int8-quant-last-layer",
"llama-rotation-mixed-fx",
"llama-rotation-full-fx",],
params=[
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'brevitas.core.quant.int.RescalingIntQuant'>": 28,
}}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"input_bit_width": None,
"act_calibration": False,
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'brevitas.core.quant.int.RescalingIntQuant'>": 14,
}}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"weight_quant_format": "float_ocp_e4m3",
"weight_quant_type": "sym",
"input_quant_format": "float_ocp_e5m2",
"input_quant_type": "sym",
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'brevitas.core.quant.float.FloatQuant'>": 28,}}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"weight_quant_format": "float_fnuz_e4m3",
"weight_quant_type": "sym",
"input_quant_format": "float_fnuz_e5m2",
"input_quant_type": "sym",
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'brevitas.core.quant.float.FloatQuant'>": 28,}}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"weight_quant_format": "float_ocp_e4m3",
"weight_scale_precision": "po2_scale",
"weight_param_method": "stats",
"weight_quant_granularity": "per_group",
"weight_group_size": 16,
"weight_quant_type": "sym",
"input_quant_format": "float_ocp_e5m2",
"input_scale_type": "dynamic",
"input_scale_precision": "po2_scale",
"input_param_method": "stats",
"input_quant_granularity": "per_group",
"input_group_size": 16,
"input_quant_type": "sym",
"act_calibration": False,
"exp_layer_types_count": {
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'brevitas.core.quant.float.FloatQuant'>": 28, # input_quant/weight_quant
"<class 'brevitas.core.function_wrapper.shape.DynamicOverSubChannelBlockView'>":
14, # input_quant..input_view_impl/input_quant..scaling_impl.input_view_impl
"<class 'brevitas.core.function_wrapper.shape.OverSubChannelBlockView'>":
28, # weight_quant..input_view_impl/weight_quant..scaling_impl.input_view_impl
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>": 5,}},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"act_equalization": "layerwise",
"exp_layer_types_count": {
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'brevitas.nn.equalized_layer.EqualizedModule'>":
15, # LM Head + Q/K/V/O projs + Up/Gate/Down projs
"<class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>": 5,}},
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"quantize_last_layer": True,
"exp_layer_types_count": {
"<class 'brevitas.nn.quant_linear.QuantLinear'>": 15,
}}, # LM Head + Q/K/V/O projs + Up/Gate/Down projs
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"ln_affine_merge": True,
@@ -394,11 +527,13 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"rotation_orphan_sink": True,
"convert_layernorm_to_rmsnorm": True,
"rotation": "fx",
"exp_layer_types": {
"L__self___model_layers_0_self_attn_k_proj":
"<class 'torch.nn.modules.linear.Linear'>",
"L__self___model_layers_0_self_attn_o_proj":
"<class 'brevitas.nn.equalized_layer.RotatedModule'>"}},
"exp_layer_types_count": {
"<class 'brevitas.nn.equalized_layer.RotatedModule'>":
4, # Sinks: O proj + Down proj
"<class 'torch.nn.modules.linear.Linear'>":
15, # LM Head + Q/K/V/O projs + Up/Gate/Down projs
"<class 'torch.nn.modules.normalization.RMSNorm'>": 5,
"<class 'torch.nn.modules.normalization.LayerNorm'>": 0,}},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"ln_affine_merge": True,
Expand All @@ -408,32 +543,32 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"rotation_orphan_sink": False,
"convert_layernorm_to_rmsnorm": True,
"rotation": "fx",
"exp_layer_types": {
"L__self___model_layers_0_self_attn_k_proj":
"<class 'torch.nn.modules.linear.Linear'>",
"L__self___model_layers_0_self_attn_o_proj":
"<class 'torch.nn.modules.linear.Linear'>"}},])
def layer_args(default_run_args, request):
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>":
15, # LM Head + Q/K/V/O projs + Up/Gate/Down projs
"<class 'torch.nn.modules.normalization.RMSNorm'>": 5, # Input + Post attention
"<class 'torch.nn.modules.normalization.LayerNorm'>": 0,}},])
def layer_args_types_count(default_run_args, request):
args = default_run_args
layer_dict = request.param
exp_layer_types = layer_dict["exp_layer_types"]
del layer_dict["exp_layer_types"]
exp_layer_types_count = layer_dict["exp_layer_types_count"]
del layer_dict["exp_layer_types_count"]
args.update(**layer_dict)
yield args, exp_layer_types
yield args, exp_layer_types_count


@pytest.mark.llm
@requires_pt_ge('2.2')
def test_small_models_quant_layer(caplog, layer_args):
def test_small_models_quant_layer_types_count(caplog, layer_args_types_count):
caplog.set_level(logging.INFO)
args, exp_layer_types = layer_args
args, exp_layer_types_count = layer_args_types_count
if args.replace_rmsnorm:
if torch_version < version.parse('2.4'):
pytest.skip("Replacing RMSNorm requires torch 2.4+ or greater")
if hasattr(args, 'rotation') and args.rotation == 'fx' and platform.system() == 'Windows':
pytest.skip("Skipping dynamo + windows")
float_ppl, quant_ppl, model = validate_args_and_run_main(args)
assert_layer_types(model, exp_layer_types)
assert_layer_types_count(model, exp_layer_types_count)


@pytest_cases.fixture(

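For context, a minimal, self-contained sketch of the counting check this commit introduces. The helper mirrors assert_layer_types_count from the diff; TinyModel and the expected counts are illustrative stand-ins, not part of the test suite.

import torch.nn as nn


def assert_layer_types_count(model, exp_layer_types_count):
    # Tally how many submodules of each type the model contains.
    layer_types_count = {}
    for _, layer in model.named_modules():
        ltype = str(type(layer))
        layer_types_count[ltype] = layer_types_count.get(ltype, 0) + 1
    # Compare against the expected counts; a type that never appears counts as 0.
    for name, count in exp_layer_types_count.items():
        curr_count = layer_types_count.get(name, 0)
        assert count == curr_count, f"Expected {count} instances of layer type: {name}, found {curr_count}."


class TinyModel(nn.Module):  # hypothetical model, for illustration only

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.norm = nn.LayerNorm(8)


assert_layer_types_count(
    TinyModel(),
    {
        "<class 'torch.nn.modules.linear.Linear'>": 1,
        "<class 'torch.nn.modules.normalization.LayerNorm'>": 1,
    })

Counting by type avoids depending on the mangled module names produced by FX tracing (e.g. L__self___model_layers_0_self_attn_k_proj in the old assertions), which is presumably why the rotation cases switch to count-based checks.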
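And a minimal sketch (assumed names, plain pytest rather than pytest_cases) of the fixture-plus-test pattern the new count-based cases follow: a parametrized fixture pops the expected values out of each param dict, yields (args, expected), and the test consumes the pair.

import pytest


@pytest.fixture(
    ids=["case-a", "case-b"],
    params=[
        {"value": 1, "expected_count": 1},
        {"value": 2, "expected_count": 2},])
def case_args(request):
    case = dict(request.param)  # copy so the shared param dict is not mutated across runs
    expected_count = case.pop("expected_count")
    yield case, expected_count


def test_case_count(case_args):
    case, expected_count = case_args
    assert case["value"] == expected_count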