Fix llm tests transformers #1118

Merged: 3 commits, Dec 11, 2024
189 changes: 162 additions & 27 deletions tests/brevitas_examples/test_llm.py
@@ -72,6 +72,19 @@ def assert_layer_types(model, exp_layer_types):
assert matched, f"Layer key: {key} not found in {layer_names}"


def assert_layer_types_count(model, exp_layer_types_count):
layer_types_count = {}
for name, layer in model.named_modules():
ltype = str(type(layer))
if ltype not in layer_types_count:
layer_types_count[ltype] = 0
layer_types_count[ltype] += 1

for name, count in exp_layer_types_count.items():
curr_count = 0 if name not in layer_types_count else layer_types_count[name]
assert count == curr_count, f"Expected {count} instances of layer type: {name}, found {curr_count}."

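A minimal usage sketch of the new helper (not part of this diff; it assumes only torch is installed, and the toy model and expected counts are purely illustrative): the helper tallies str(type(module)) over model.named_modules() and compares the tallies against the expected counts, treating types that never appear as a count of 0.

import torch.nn as nn

# Toy model: two Linear layers and one ReLU; Dropout never appears, so its count must be 0.
toy = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
assert_layer_types_count(
    toy,
    {
        "<class 'torch.nn.modules.linear.Linear'>": 2,
        "<class 'torch.nn.modules.activation.ReLU'>": 1,
        "<class 'torch.nn.modules.dropout.Dropout'>": 0,  # absent type, counted as 0
    })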

class UpdatableNamespace(Namespace):

def update(self, **kwargs):
@@ -293,9 +306,7 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"mistral-fp8_fnuz",
"llama-mxfp8",
"llama-int8-act_equalization=layerwise",
"mistral-int8-quant-last-layer",
"llama-rotation-mixed-fx",
"llama-rotation-full-fx",],
"mistral-int8-quant-last-layer",],
params=[
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
@@ -307,7 +318,8 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",}},
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",},
}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"input_bit_width": None,
@@ -318,7 +330,8 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"model.layers.0.self_attn.q_proj.input_quant":
"<class 'brevitas.proxy.runtime_quant.ActQuantProxyFromInjector'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",}},
"<class 'brevitas.core.quant.int.RescalingIntQuant'>",},
}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"weight_quant_format": "float_ocp_e4m3",
@@ -331,7 +344,8 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",}},
"<class 'brevitas.core.quant.float.FloatQuant'>",},
}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"weight_quant_format": "float_fnuz_e4m3",
@@ -344,7 +358,8 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",}},
"<class 'brevitas.core.quant.float.FloatQuant'>",},
}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"weight_quant_format": "float_ocp_e4m3",
@@ -371,20 +386,138 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant":
"<class 'brevitas.core.quant.float.FloatQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant.input_view_impl":
"<class 'brevitas.core.function_wrapper.shape.OverSubChannelBlockView'>",}},
"<class 'brevitas.core.function_wrapper.shape.OverSubChannelBlockView'>",},},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"act_equalization": "layerwise",
"exp_layer_types": {
"model.layers.0.self_attn.q_proj":
"<class 'brevitas.nn.equalized_layer.EqualizedModule'>",
"model.layers.0.self_attn.q_proj.layer":
"<class 'brevitas.nn.quant_linear.QuantLinear'>",}},
"<class 'brevitas.nn.quant_linear.QuantLinear'>",},},
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"quantize_last_layer": True,
"exp_layer_types": {
"lm_head": "<class 'brevitas.nn.quant_linear.QuantLinear'>"}},
"lm_head": "<class 'brevitas.nn.quant_linear.QuantLinear'>"},},
]) # LM Head + Q/K/V/O projs + Up/Gate/Down projs
def layer_args(default_run_args, request):
args = default_run_args
layer_dict = request.param
exp_layer_types = layer_dict["exp_layer_types"]
del layer_dict["exp_layer_types"]
args.update(**layer_dict)
yield args, exp_layer_types


@pytest.mark.llm
@requires_pt_ge('2.2')
def test_small_models_quant_layer(caplog, layer_args):
caplog.set_level(logging.INFO)
args, exp_layer_types = layer_args
if args.replace_rmsnorm:
if torch_version < version.parse('2.4'):
pytest.skip("Replacing RMSNorm requires torch 2.4+ or greater")
if hasattr(args, 'rotation') and args.rotation == 'fx' and platform.system() == 'Windows':
pytest.skip("Skipping dynamo + windows")
float_ppl, quant_ppl, model = validate_args_and_run_main(args)
assert_layer_types(model, exp_layer_types)


@pytest_cases.fixture(
ids=[
"mistral-int8",
"mistral-weight-only",
"mistral-fp8_ocp",
"mistral-fp8_fnuz",
"llama-mxfp8",
"llama-int8-act_equalization=layerwise",
"mistral-int8-quant-last-layer",
"llama-rotation-mixed-fx",
"llama-rotation-full-fx",],
params=[
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'brevitas.core.quant.int.RescalingIntQuant'>": 28,
}}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"input_bit_width": None,
"act_calibration": False,
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'brevitas.core.quant.int.RescalingIntQuant'>": 14,
}}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"weight_quant_format": "float_ocp_e4m3",
"weight_quant_type": "sym",
"input_quant_format": "float_ocp_e5m2",
"input_quant_type": "sym",
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'brevitas.core.quant.float.FloatQuant'>": 28,}}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"weight_quant_format": "float_fnuz_e4m3",
"weight_quant_type": "sym",
"input_quant_format": "float_fnuz_e5m2",
"input_quant_type": "sym",
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'brevitas.core.quant.float.FloatQuant'>": 28,}}, # input_quant/weight_quant
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"weight_quant_format": "float_ocp_e4m3",
"weight_scale_precision": "po2_scale",
"weight_param_method": "stats",
"weight_quant_granularity": "per_group",
"weight_group_size": 16,
"weight_quant_type": "sym",
"input_quant_format": "float_ocp_e5m2",
"input_scale_type": "dynamic",
"input_scale_precision": "po2_scale",
"input_param_method": "stats",
"input_quant_granularity": "per_group",
"input_group_size": 16,
"input_quant_type": "sym",
"act_calibration": False,
"exp_layer_types_count": {
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'brevitas.core.quant.float.FloatQuant'>": 28, # input_quant/weight_quant
"<class 'brevitas.core.function_wrapper.shape.DynamicOverSubChannelBlockView'>":
14, # input_quant..input_view_impl/input_quant..scaling_impl.input_view_impl
"<class 'brevitas.core.function_wrapper.shape.OverSubChannelBlockView'>":
28, # weight_quant..input_view_impl/weight_quant..scaling_impl.input_view_impl
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>": 5,}},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"act_equalization": "layerwise",
"exp_layer_types_count": {
"<class 'brevitas.nn.quant_linear.QuantLinear'>":
14, # Q/K/V/O projs + Up/Gate/Down projs
"<class 'torch.nn.modules.linear.Linear'>": 1, # LM Head
"<class 'brevitas.nn.equalized_layer.EqualizedModule'>":
15, # LM Head + Q/K/V/O projs + Up/Gate/Down projs
"<class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>": 5,}},
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
"quantize_last_layer": True,
"exp_layer_types_count": {
"<class 'brevitas.nn.quant_linear.QuantLinear'>": 15,
}}, # LM Head + Q/K/V/O projs + Up/Gate/Down projs
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"ln_affine_merge": True,
@@ -394,11 +527,13 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"rotation_orphan_sink": True,
"convert_layernorm_to_rmsnorm": True,
"rotation": "fx",
"exp_layer_types": {
"L__self___model_layers_0_self_attn_k_proj":
"<class 'torch.nn.modules.linear.Linear'>",
"L__self___model_layers_0_self_attn_o_proj":
"<class 'brevitas.nn.equalized_layer.RotatedModule'>"}},
"exp_layer_types_count": {
"<class 'brevitas.nn.equalized_layer.RotatedModule'>":
4, # Sinks: O proj + Down proj
"<class 'torch.nn.modules.linear.Linear'>":
15, # LM Head + Q/K/V/O projs + Up/Gate/Down projs
"<class 'torch.nn.modules.normalization.RMSNorm'>": 5,
"<class 'torch.nn.modules.normalization.LayerNorm'>": 0,}},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"ln_affine_merge": True,
@@ -408,32 +543,32 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"rotation_orphan_sink": False,
"convert_layernorm_to_rmsnorm": True,
"rotation": "fx",
"exp_layer_types": {
"L__self___model_layers_0_self_attn_k_proj":
"<class 'torch.nn.modules.linear.Linear'>",
"L__self___model_layers_0_self_attn_o_proj":
"<class 'torch.nn.modules.linear.Linear'>"}},])
def layer_args(default_run_args, request):
"exp_layer_types_count": {
"<class 'torch.nn.modules.linear.Linear'>":
15, # LM Head + Q/K/V/O projs + Up/Gate/Down projs
"<class 'torch.nn.modules.normalization.RMSNorm'>": 5, # Input + Post attention
"<class 'torch.nn.modules.normalization.LayerNorm'>": 0,}},])
def layer_args_types_count(default_run_args, request):
args = default_run_args
layer_dict = request.param
exp_layer_types = layer_dict["exp_layer_types"]
del layer_dict["exp_layer_types"]
exp_layer_types_count = layer_dict["exp_layer_types_count"]
del layer_dict["exp_layer_types_count"]
args.update(**layer_dict)
yield args, exp_layer_types
yield args, exp_layer_types_count
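The expected counts in the parameters above follow from the shape of the tiny test checkpoints. A rough sketch of the arithmetic (not part of this diff; the depth of 2 decoder layers is an assumption inferred from the per-layer comments above):

# Back-of-the-envelope check of the expected counts (illustrative only).
num_layers = 2                                  # assumed depth of the tiny-random checkpoints
projs_per_layer = 4 + 3                         # q/k/v/o attention projs + gate/up/down MLP projs
quant_linears = num_layers * projs_per_layer    # 14 QuantLinear modules
tensor_quants = 2 * quant_linears               # 28: one input_quant and one weight_quant per QuantLinear
linears_with_head = quant_linears + 1           # 15 when the LM head is also counted
print(quant_linears, tensor_quants, linears_with_head)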


@pytest.mark.llm
@requires_pt_ge('2.2')
def test_small_models_quant_layer(caplog, layer_args):
def test_small_models_quant_layer_types_count(caplog, layer_args_types_count):
caplog.set_level(logging.INFO)
args, exp_layer_types = layer_args
args, exp_layer_types_count = layer_args_types_count
if args.replace_rmsnorm:
if torch_version < version.parse('2.4'):
pytest.skip("Replacing RMSNorm requires torch 2.4+ or greater")
if hasattr(args, 'rotation') and args.rotation == 'fx' and platform.system() == 'Windows':
pytest.skip("Skipping dynamo + windows")
float_ppl, quant_ppl, model = validate_args_and_run_main(args)
assert_layer_types(model, exp_layer_types)
assert_layer_types_count(model, exp_layer_types_count)
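Usage note (an assumption, not stated in this diff): the new count-based checks can be run on their own with a command like pytest tests/brevitas_examples/test_llm.py -m llm -k "types_count" -v, provided the llm marker is registered in the project's pytest configuration.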


@pytest_cases.fixture(