diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py
index dc5d42a03..1143186f1 100644
--- a/tests/brevitas_examples/test_llm.py
+++ b/tests/brevitas_examples/test_llm.py
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 from argparse import Namespace
-from collections import defaultdict
 from dataclasses import dataclass
 import logging
 import os
@@ -83,7 +82,7 @@ def assert_layer_types_count(model, exp_layer_types_count):
 
     for name, count in exp_layer_types_count.items():
         curr_count = 0 if name not in layer_types_count else layer_types_count[name]
-        assert count == curr_count, f"Expect {count} instances of layer type: {name}, found {curr_count}."
+        assert count == curr_count, f"Expected {count} instances of layer type: {name}, found {curr_count}."
 
 
 class UpdatableNamespace(Namespace):
@@ -299,9 +298,7 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
         "mistral-fp8_fnuz",
         "llama-mxfp8",
         "llama-int8-act_equalization=layerwise",
-        "mistral-int8-quant-last-layer",
-        "llama-rotation-mixed-fx",
-        "llama-rotation-full-fx",],
+        "mistral-int8-quant-last-layer",],
     params=[
         {
             "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
@@ -314,12 +311,7 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
                     "",
                 "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
                     "",},
-            "exp_layer_types_count": {
-                "": 1,  # LM Head
-                "":
-                    14,  # Q/K/V/O projs + Up/Gate/Down projs
-                "": 28,
-            }},  # input_quant/weight_quant
+        },  # input_quant/weight_quant
         {
             "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
             "input_bit_width": None,
@@ -331,12 +323,7 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
                     "",
                 "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
                     "",},
-            "exp_layer_types_count": {
-                "": 1,  # LM Head
-                "":
-                    14,  # Q/K/V/O projs + Up/Gate/Down projs
-                "": 14,
-            }},  # input_quant/weight_quant
+        },  # input_quant/weight_quant
         {
             "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
             "weight_quant_format": "float_ocp_e4m3",
@@ -350,11 +337,7 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
                     "",
                 "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
                     "",},
-            "exp_layer_types_count": {
-                "": 1,  # LM Head
-                "":
-                    14,  # Q/K/V/O projs + Up/Gate/Down projs
-                "": 28,}},  # input_quant/weight_quant
+        },  # input_quant/weight_quant
         {
             "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
             "weight_quant_format": "float_fnuz_e4m3",
@@ -368,11 +351,7 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
                     "",
                 "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
                     "",},
-            "exp_layer_types_count": {
-                "": 1,  # LM Head
-                "":
-                    14,  # Q/K/V/O projs + Up/Gate/Down projs
-                "": 28,}},  # input_quant/weight_quant
+        },  # input_quant/weight_quant
         {
             "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
             "weight_quant_format": "float_ocp_e4m3",
@@ -399,7 +378,112 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
                 "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
                     "",
                 "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant.input_view_impl":
-                    "",},
+                    "",},},
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "act_equalization": "layerwise",
+            "exp_layer_types": {
+                "model.layers.0.self_attn.q_proj":
+                    "",
+                "model.layers.0.self_attn.q_proj.layer":
+                    "",},},
+        {
+            "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
+            "quantize_last_layer": True,
+            "exp_layer_types": {
+                "lm_head": ""},},
+    ])  # LM Head + Q/K/V/O projs + Up/Gate/Down projs
+def layer_args(default_run_args, request):
+    args = default_run_args
+    layer_dict = request.param
+    exp_layer_types = layer_dict["exp_layer_types"]
+    del layer_dict["exp_layer_types"]
+    args.update(**layer_dict)
+    yield args, exp_layer_types
+
+
+@pytest.mark.llm
+@requires_pt_ge('2.2')
+def test_small_models_quant_layer(caplog, layer_args):
+    caplog.set_level(logging.INFO)
+    args, exp_layer_types = layer_args
+    if args.replace_rmsnorm:
+        if torch_version < version.parse('2.4'):
+            pytest.skip("Replacing RMSNorm requires torch 2.4+ or greater")
+    if hasattr(args, 'rotation') and args.rotation == 'fx' and platform.system() == 'Windows':
+        pytest.skip("Skipping dynamo + windows")
+    float_ppl, quant_ppl, model = validate_args_and_run_main(args)
+    assert_layer_types(model, exp_layer_types)
+
+
+@pytest_cases.fixture(
+    ids=[
+        "mistral-int8",
+        "mistral-weight-only",
+        "mistral-fp8_ocp",
+        "mistral-fp8_fnuz",
+        "llama-mxfp8",
+        "llama-int8-act_equalization=layerwise",
+        "mistral-int8-quant-last-layer",
+        "llama-rotation-mixed-fx",
+        "llama-rotation-full-fx",],
+    params=[
+        {
+            "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
+            "exp_layer_types_count": {
+                "": 1,  # LM Head
+                "":
+                    14,  # Q/K/V/O projs + Up/Gate/Down projs
+                "": 28,
+            }},  # input_quant/weight_quant
+        {
+            "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
+            "input_bit_width": None,
+            "act_calibration": False,
+            "exp_layer_types_count": {
+                "": 1,  # LM Head
+                "":
+                    14,  # Q/K/V/O projs + Up/Gate/Down projs
+                "": 14,
+            }},  # input_quant/weight_quant
+        {
+            "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
+            "weight_quant_format": "float_ocp_e4m3",
+            "weight_quant_type": "sym",
+            "input_quant_format": "float_ocp_e5m2",
+            "input_quant_type": "sym",
+            "exp_layer_types_count": {
+                "": 1,  # LM Head
+                "":
+                    14,  # Q/K/V/O projs + Up/Gate/Down projs
+                "": 28,}},  # input_quant/weight_quant
+        {
+            "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
+            "weight_quant_format": "float_fnuz_e4m3",
+            "weight_quant_type": "sym",
+            "input_quant_format": "float_fnuz_e5m2",
+            "input_quant_type": "sym",
+            "exp_layer_types_count": {
+                "": 1,  # LM Head
+                "":
+                    14,  # Q/K/V/O projs + Up/Gate/Down projs
+                "": 28,}},  # input_quant/weight_quant
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "weight_quant_format": "float_ocp_e4m3",
+            "weight_scale_precision": "po2_scale",
+            "weight_param_method": "stats",
+            "weight_quant_granularity": "per_group",
+            "weight_group_size": 16,
+            "weight_quant_type": "sym",
+            "input_quant_format": "float_ocp_e5m2",
+            "input_scale_type": "dynamic",
+            "input_scale_precision": "po2_scale",
+            "input_param_method": "stats",
+            "input_quant_granularity": "per_group",
+            "input_group_size": 16,
+            "input_quant_type": "sym",
+            "act_calibration": False,
             "exp_layer_types_count": {
                 "":
                     14,  # Q/K/V/O projs + Up/Gate/Down projs
@@ -413,11 +497,6 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
         {
             "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
             "act_equalization": "layerwise",
-            "exp_layer_types": {
-                "model.layers.0.self_attn.q_proj":
-                    "",
-                "model.layers.0.self_attn.q_proj.layer":
-                    "",},
             "exp_layer_types_count": {
                 "":
                     14,  # Q/K/V/O projs + Up/Gate/Down projs
@@ -428,8 +507,6 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
         {
             "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
             "quantize_last_layer": True,
-            "exp_layer_types": {
-                "lm_head": ""},
             "exp_layer_types_count": {
                 "": 15,
             }},  # LM Head + Q/K/V/O projs + Up/Gate/Down projs
@@ -442,11 +519,6 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
             "rotation_orphan_sink": True,
             "convert_layernorm_to_rmsnorm": True,
             "rotation": "fx",
-            "exp_layer_types": {
-                "L__self___model_layers_0_self_attn_k_proj":
-                    "",
-                "L__self___model_layers_0_self_attn_o_proj":
-                    "",},
             "exp_layer_types_count": {
                 "":
                     4,  # Sinks: O proj + Down proj
@@ -463,42 +535,31 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
             "rotation_orphan_sink": False,
             "convert_layernorm_to_rmsnorm": True,
             "rotation": "fx",
-            "exp_layer_types": {
-                "L__self___model_layers_0_self_attn_k_proj":
-                    "",
-                "L__self___model_layers_0_self_attn_o_proj":
-                    ""},
             "exp_layer_types_count": {
                 "":
                     15,  # LM Head + Q/K/V projs + Up/Gate/Down projs
                 "": 5,  # Input + Post attention
                 "": 0,}},])
-def layer_args(default_run_args, request):
+def layer_args_types_count(default_run_args, request):
     args = default_run_args
     layer_dict = request.param
-    exp_layer_types = layer_dict["exp_layer_types"]
     exp_layer_types_count = layer_dict["exp_layer_types_count"]
-    del layer_dict["exp_layer_types"]
     del layer_dict["exp_layer_types_count"]
     args.update(**layer_dict)
-    yield args, exp_layer_types, exp_layer_types_count
+    yield args, exp_layer_types_count
 
 
-def test_small_models_quant_layer(caplog, layer_args):
+@pytest.mark.llm
+@requires_pt_ge('2.2')
+def test_small_models_quant_layer_types_count(caplog, layer_args_types_count):
     caplog.set_level(logging.INFO)
-    args, exp_layer_types, exp_layer_types_count = layer_args
+    args, exp_layer_types_count = layer_args_types_count
     if args.replace_rmsnorm:
         if torch_version < version.parse('2.4'):
             pytest.skip("Replacing RMSNorm requires torch 2.4+ or greater")
     if hasattr(args, 'rotation') and args.rotation == 'fx' and platform.system() == 'Windows':
         pytest.skip("Skipping dynamo + windows")
     float_ppl, quant_ppl, model = validate_args_and_run_main(args)
-    # Naming of modules in the GraphModule generated by FX changes across transformers versions, e.g.
-    # (4.45.0)"L__self___model_layers_2_self_attn_k_proj" ->
-    # (4.46.0) 'L__self___model_layers_slice_None__2__None___0_self_attn_q_proj'
-    # Therefore, this check is skipped when rotation="fx".
-    if args.rotation != "fx":
-        assert_layer_types(model, exp_layer_types)
     assert_layer_types_count(model, exp_layer_types_count)