From 66d6661c7aa073810518186a577e64872937cfad Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sun, 20 Oct 2024 15:57:34 +0100 Subject: [PATCH] First prototype impl --- src/brevitas/graph/equalize.py | 81 ++-- src/brevitas/nn/equalized_layer.py | 32 +- .../llm/llm_quant/ln_affine_merge.py | 25 +- src/brevitas_examples/llm/main.py | 49 +- tests/brevitas/graph/equalization_fixtures.py | 33 ++ tests/brevitas/graph/test_equalization.py | 427 ++++++++++-------- 6 files changed, 387 insertions(+), 260 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 8b470091d..025abddd7 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -18,6 +18,7 @@ from brevitas.fx import Node from brevitas.graph.base import GraphTransform from brevitas.graph.base import ModuleInstanceToModuleInstance +from brevitas.graph.hadamard import get_hadK, matmul_hadU, matmul_hadU_cuda from brevitas.graph.utils import get_module from brevitas.graph.utils import get_node from brevitas.nn.equalized_layer import EqualizedModule, RotatedModule @@ -25,12 +26,8 @@ from brevitas.nn.quant_scale_bias import ScaleBias from brevitas.utils.torch_utils import KwargsForwardHook -from .base import GraphTransform from .base import InsertModuleCallAfter -try: - from scipy.linalg import hadamard -except ImportError: - hadamard = None + __all__ = ['GraphActivationEqualization', 'LayerwiseActivationEqualization', 'EqualizeGraph'] EPSILON = 1e-9 @@ -680,7 +677,7 @@ def _is_scale_varying_activation(graph_model, node): def _is_scale_invariant_function(node: Node, scale_invariant_op: Set =_scale_invariant_op) -> bool: - out = node.op == 'call_function' and node.target in scale_invariant_op + _select_op + _reshaping_op + out = node.op in ('call_function', 'call_method') and node.target in scale_invariant_op + _select_op + _reshaping_op if node.target == torch.nn.functional.interpolate: out &= node.kwargs.get('mode', None) == 'nearest' return out @@ -1205,55 +1202,75 @@ def insert_mul_node(self, scale, shape, axis, act_node, batch_dim=0): rewriter = InsertModuleCallAfter(mul_factor_name, act_node) rewriter.apply(self.model) +def _apply_had_device(tensor, had_K, K): + is_cuda = 'cuda' in str(tensor.device) and torch.version.cuda is not None + # Accelerated kernel only available for CUDA + if is_cuda: + return matmul_hadU_cuda(tensor, had_K, K) + else: + return matmul_hadU(tensor) def _apply_rotate(model: nn.Module, regions: List[Region], insert_rotation_func: bool = False): for region in regions: - for name in (region.srcs_names + region.sinks_names): - module = region.get_module_from_name(name) - if hasattr(module, 'allocate_params'): - module.allocate_params(module) if not insert_rotation_func and not region.is_valid: continue hidden_dim = region.max_shape_sinks - # Check that hidden_dim is an exact Po2 - if torch.log2(torch.tensor(hidden_dim)) != torch.ceil(torch.log2(torch.tensor(hidden_dim))): - continue - # Build hadamard rotation matrix - h = torch.from_numpy(hadamard(hidden_dim)) / torch.sqrt(torch.tensor(hidden_dim)) - hadamard_inverse = h.t() + try: + # Build hadamard rotation matrix + had_K, K = get_hadK(hidden_dim) + except AssertionError as e: + print("Incomptible shapes") + raise e + for name, indexes in region.srcs.items(): module = region.get_module_from_name(name) + if hasattr(module, 'allocate_params'): + module.allocate_params(module) axis = _get_output_axis(module) - h_inv = hadamard_inverse.type_as(module.weight.data) + weight = module.weight.data + if axis == 0: - module.weight.data = torch.matmul(h_inv, module.weight.data) + weight = _apply_had_device(weight.t(), had_K, K).t() # matmul_hadU_cuda(weight.t(), had_K, K).t() elif axis == 1: - module.weight.data = torch.matmul(module.weight.data, h_inv) + weight = _apply_had_device(weight, had_K, K) else: raise RuntimeError("Not supported yet") + module.weight.data = weight + + if getattr(module, 'bias', None) is not None: + bias = module.bias.data + bias = _apply_had_device(bias, had_K, K) #matmul_hadU_cuda(bias, had_K, K) + module.bias.data = bias + if hasattr(module, 'offload_params'): + module.offload_params(module) for name, indexes in region.sinks.items(): module = region.get_module_from_name(name) - h = h.type_as(module.weight.data) + if hasattr(module, 'allocate_params'): + module.allocate_params(module) axis = _get_input_axis(module) + weight = module.weight.data + if axis == 1: - module.weight.data = torch.matmul(module.weight.data, h) + weight = _apply_had_device(weight, had_K, K) elif axis == 0: - module.weight.data = torch.matmul(h, module.weight.data) + weight = _apply_had_device(weight.t(), had_K, K).t() else: raise RuntimeError("Not supported yet") - if insert_rotation_func and len(region.srcs) == 0: - rewriter = ModuleInstanceToModuleInstance( - module, RotatedModule(h_inv=hadamard_inverse, layer=module)) - rewriter.apply(model) - for name in (region.srcs_names + region.sinks_names): - module = region.get_module_from_name(name) + module.weight.data = weight if hasattr(module, 'offload_params'): module.offload_params(module) + if insert_rotation_func and len(region.srcs) == 0: + # print(name, module.in_features, K) + rewriter = ModuleInstanceToModuleInstance( + module, RotatedModule(had_mat=had_K, k=K, layer=module)) + rewriter.apply(model) + + class GraphRotationEqualization(GraphTransform): def __init__(self) -> None: @@ -1266,9 +1283,8 @@ def __init__(self) -> None: def apply(self, graph_model: GraphModule) -> Union[Tuple[GraphModule, Set[Tuple[str]]], GraphModule]: - # It is not possible to equalize through LayerNorm/BatchNorm as sink - regions = _extract_regions(graph_model, state_impl_kwargs={'supported_srcs':self.supported_srcs, 'supported_sinks':self.supported_sinks, 'scale_invariant_layers':self.scale_invariant_layers }) + regions = _extract_regions(graph_model, state_impl_kwargs={'supported_srcs':self.supported_srcs, 'supported_sinks':self.supported_sinks, 'scale_invariant_layers':self.scale_invariant_layers, 'scale_invariant_function': self.scale_invariant_function }) if len(regions) > 0: _apply_rotate(graph_model, regions, False) @@ -1291,6 +1307,7 @@ def _merge_ln(layer_norm, next_module, scale_bias_by_weight): # We can't do an inplace update as some layers we merge into like lm_head might share the weight tensor scale = layer_norm.weight.data.view(view_shape).expand_as(next_module.weight) next_module.weight = torch.nn.Parameter(next_module.weight.clone() * scale) + # Merge bias, new_bias includes the bias of next_module by going through its fwd if hasattr(layer_norm, 'bias'): inp = layer_norm.bias.data.view(view_shape) @@ -1307,6 +1324,7 @@ def __init__(self) -> None: def apply(self, graph_model: GraphModule) -> GraphModule: regions = _extract_regions(graph_model, state_impl_kwargs={'supported_srcs':self.supported_srcs, 'supported_sinks':self.supported_sinks}) + if len(regions) > 0: scaled_biases = set() for region in regions: @@ -1314,7 +1332,6 @@ def apply(self, graph_model: GraphModule) -> GraphModule: layernorm_module =region.get_module_from_name(layernorm_module_name) if not layernorm_module.elementwise_affine: continue - for name, indexes in region.sinks.items(): module = region.get_module_from_name(name) scale_bias = id(module) not in scaled_biases @@ -1327,7 +1344,7 @@ def apply(self, graph_model: GraphModule) -> GraphModule: return graph_model -class LayerwiseActivationEqualization(GraphTransform): +class LayerwiseActivationRotation(GraphTransform): def __init__(self, blacklist_layer=None): super(GraphTransform, self).__init__() diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py index 7620fbeab..a2339beef 100644 --- a/src/brevitas/nn/equalized_layer.py +++ b/src/brevitas/nn/equalized_layer.py @@ -3,7 +3,7 @@ import torch from brevitas.nn.quant_mha import QuantMultiheadAttention -from brevitas.quant_tensor.base_quant_tensor import QuantTensor +import fast_hadamard_transform INPUT_NAMES = ['input', 'inp', 'query', 'x', 'hidden_states'] @@ -45,15 +45,27 @@ def forward(self, *args, **kwargs): class RotatedModule(torch.nn.Module): - def __init__(self, h_inv, layer) -> None: + def __init__(self, had_mat, k, layer) -> None: super().__init__() - self.h_inv = torch.nn.Parameter(h_inv) + self.had_mat = torch.nn.Parameter(had_mat).cpu() self.layer = layer + self.k = k + + def forward(self, inp, **kwargs): + + shape = inp.shape + n = inp.shape[-1] + if self.k == 1: + inp = fast_hadamard_transform.hadamard_transform(inp.contiguous(), 1.0/torch.tensor(n).sqrt()) + o = self.layer(inp) + + # if transpose: + # hadK = hadK.T.contiguous() + inp = inp.view(*inp.shape[:-1], self.k, n // self.k) + inp = fast_hadamard_transform.hadamard_transform(inp.contiguous(), 1.0/torch.tensor(n).sqrt()) + inp = self.had_mat.to(inp.device).to(inp.dtype) @ inp + inp = inp.reshape(shape) + o = self.layer(inp) + + return o - def forward(self, *args, **kwargs): - inp = args[0] - if isinstance(inp, QuantTensor): - inp = inp.value - inp = torch.matmul(inp, self.h_inv) - out = self.layer(inp) - return out diff --git a/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py b/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py index 7ac39347f..14653016f 100644 --- a/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py +++ b/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py @@ -6,11 +6,21 @@ import torch from torch import nn -from brevitas.graph.equalize import _is_reshaping_op -from brevitas.graph.equalize import _is_scale_invariant_module +# from brevitas.graph.equalize import _is_reshaping_op +from brevitas.graph.base import ModuleToModuleByClass +from brevitas.graph.equalize import MergeLnAffine, _is_scale_invariant_module from brevitas.graph.utils import get_module from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32 +def replace_rmsnorm_with_torch(model, config): + from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS + ALL_RMSNORM_LAYERS = [x for x in ALL_LAYERNORM_LAYERS if 'RMS' in x.__name__] + rewriters = [ModuleToModuleByClass(rms_cls, torch.nn.RMSNorm, normalized_shape = config.hidden_size, eps = config.rms_norm_eps) for rms_cls in ALL_RMSNORM_LAYERS] + dtype = next(iter(model.parameters())).dtype + for r in rewriters: + model = r.apply(model) + model = model.to(dtype) + return model def replace_bias(next_module, new_bias): new_bias = new_bias.view(-1) @@ -49,7 +59,7 @@ def merge_layernorm_affine_params(graph_model): module = get_module(graph_model, node.target) if isinstance(module, nn.LayerNorm): for next in node.users: - while (_is_reshaping_op(next) or _is_scale_invariant_module(graph_model, next)): + while (_is_scale_invariant_module(graph_model, next)): next = node.next if next.op == 'call_module': next_module = get_module(graph_model, next.target) @@ -83,8 +93,7 @@ def merge_layernorm_affine_params(graph_model): @torch.no_grad() -def apply_layernorm_affine_merge(graph_model, dtype): - # We can't do fp16 tracing on CPU as many kernels are not implemented - # So we have to cast to fp32 first, trace, apply merging, and then cast back - with cast_to_float32(graph_model, dtype): - merge_layernorm_affine_params(graph_model) +def apply_layernorm_affine_merge(graph_model): + eq = MergeLnAffine() + graph_model = eq.apply(graph_model) + return graph_model diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index bf995a426..22d02e750 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -8,6 +8,7 @@ from warnings import warn import numpy as np +from brevitas.graph.equalize import GraphRotationEqualization, LayerwiseActivationRotation from optimum.amd.brevitas.accelerate_utils import offload_model from optimum.amd.brevitas.accelerate_utils import remove_hooks from optimum.amd.brevitas.data_utils import compute_perplexity @@ -31,7 +32,7 @@ from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode from brevitas_examples.llm.llm_quant.gpxq import apply_gpfq from brevitas_examples.llm.llm_quant.gpxq import apply_gptq -from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge +from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge, replace_rmsnorm_with_torch from brevitas_examples.llm.llm_quant.prepare_for_quantize import add_zero_bias_to_linear from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32 @@ -176,18 +177,22 @@ def main(args): device = next(iter(model.parameters())).device print("Data loaded.") - - if args.eval: - assert args.export_target != 'torch_qcdq', "TorchScript QCDQ export and Evaluation simultaneously" - print("Float model eval...") - model = offload_model(model) - float_ppl = compute_perplexity( - model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) - remove_hooks(model) - print(f"Float perplexity ({args.dataset}): {float_ppl:.3f}") + + # if args.eval: + # assert args.export_target != 'torch_qcdq', "TorchScript QCDQ export and Evaluation simultaneously" + # print("Float model eval...") + # model = offload_model(model) + # float_ppl = compute_perplexity( + # model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) + # remove_hooks(model) + # print(f"Float perplexity ({args.dataset}): {float_ppl:.3f}") + + if args.replace_rmsnorm: + model = replace_rmsnorm_with_torch(model, model.config) if require_fx: - model = get_fx(model) + with torch.no_grad(): + model, guards = torch._dynamo.export(model)(**calibration_loader[0]) # Blockwise optimization does not work with FX at the moment args.gpxq_block_name = None @@ -195,8 +200,18 @@ def main(args): # since currently there is support only for merging into Linear if args.ln_affine_merge: print("Apply LN affine merge...") - apply_layernorm_affine_merge(model, dtype) + # apply_layernorm_affine_merge(model) print("LN affine merge applied.") + + + if args.graph_rotation: + assert args.ln_affine_merge + assert args.replace_rmsnorm + eq = GraphRotationEqualization() + model = eq.apply(model) + elif args.layerwise_rotation: + eq = LayerwiseActivationRotation() + model = eq.apply(model) # Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing # with all the variability in HF implementations @@ -466,6 +481,7 @@ def parse_args(args): '--act-calibration', action='store_true', help='Apply activation calibration.') parser.add_argument('--bias-corr', action='store_true', help='Apply bias correction.') parser.add_argument('--ln-affine-merge', action='store_true', help='Merge LN affine params.') + parser.add_argument('--replace-rmsnorm', action='store_true', help='Replace HF RMSNorms with Torch one.') parser.add_argument('--no-quantize', action='store_true', help='Disable quantization.') parser.add_argument( '--no-float16', @@ -479,6 +495,15 @@ def parse_args(args): '--weight-equalization', action='store_true', help='Apply weight equalization. Relevant to ReLU based models (e.g. OPT).') + parser.add_argument( + '--graph-rotation', + default=True, + action='store_true', + help='Apply graph rotation equalization') + parser.add_argument( + '--layerwise-rotation', + action='store_true', + help='Apply layerwise rotation equalization') parser.add_argument( '--act-equalization', default=None, diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index 2719b48a0..cfba3ba30 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -491,3 +491,36 @@ def forward(self, x): toy_quant_model = fixture_union( 'toy_quant_model', list_of_quant_fixtures, ids=list_of_quant_fixtures) + + +## List of Rotation fixtures + + +@pytest_cases.fixture +def linear_rms(): + class LinearRMSModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(3, 4, bias=True) + self.linear.weight.data.fill_(2.) + self.linear.bias.data.fill_(1.) + self.rms = nn.RMSNorm(4) + self.rms.weight.data = torch.randn_like(self.rms.weight.data) # Change learned parameters + self.linear_1 = nn.Linear(4, 8, bias=True) + self.linear_1.weight.data.fill_(2.) + + def forward(self, x): + x = self.linear(x) + x = self.rms(x) + x = self.linear_1(x) + return x + + return LinearRMSModel + + +list_of_rotation_mixtures = ['linear_rms'] + +rotation_fixtures = fixture_union( + 'rotation_fixtures', list_of_rotation_mixtures, ids=list_of_rotation_mixtures) + diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index 89759b41a..d5e963405 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -7,7 +7,7 @@ from torchvision import models from brevitas.fx import symbolic_trace -from brevitas.graph.equalize import _batch_norm +from brevitas.graph.equalize import GraphRotationEqualization, MergeLnAffine, _batch_norm from brevitas.graph.equalize import _extract_regions from brevitas.graph.equalize import _is_supported_module from brevitas.graph.equalize import activation_equalization_mode @@ -18,210 +18,241 @@ from .equalization_fixtures import * -def test_resnet18_equalization(): - model = models.resnet18(pretrained=True) - - torch.manual_seed(SEED) - inp = torch.randn(IN_SIZE_CONV) - model.eval() - model = symbolic_trace(model) - expected_out = model(inp) - - model_orig = copy.deepcopy(model) - regions = _extract_regions(model) - _ = equalize_test( - regions, merge_bias=True, bias_shrinkage='vaiq', scale_computation_type='maxabs') - out = model(inp) - - # Check that equalization is not introducing FP variations - assert torch.allclose(expected_out, out, atol=ATOL) - - regions = sorted(regions, key=lambda region: sorted([r for r in region.srcs_names])) - resnet_18_regions = sorted(RESNET_18_REGIONS, key=lambda region: region[0][0]) - equalized_layers = set() - for r in resnet_18_regions: - equalized_layers.update(r[0]) - equalized_layers.update(r[1]) - - # Check that we found all the expected regions - for region, expected_region in zip(regions, resnet_18_regions): - srcs = region.srcs_names - sources_check = set(srcs) == set(expected_region[0]) - sinks = region.sinks_names - sinks_check = set(sinks) == set(expected_region[1]) - assert sources_check - assert sinks_check - - # Check that all layers were equalized and weights changed - for layer in equalized_layers: - eq_module = get_module(model, layer) - orig_module = get_module(model_orig, layer) - assert not torch.allclose(eq_module.weight, orig_module.weight) - - -@pytest_cases.parametrize("merge_bias", [True, False]) -def test_equalization_torchvision_models(model_coverage: tuple, merge_bias: bool): - model, coverage = model_coverage - - torch.manual_seed(SEED) - inp = torch.randn(IN_SIZE_CONV) - model.eval() - # The isistance does not work after symbolic trace - is_alexnet = isinstance(model, models.AlexNet) - model = symbolic_trace(model) - model = TorchFunctionalToModule().apply(model) - - expected_out = model(inp) - - regions = _extract_regions(model) - scale_factor_regions = equalize_test( - regions, merge_bias=merge_bias, bias_shrinkage='vaiq', scale_computation_type='maxabs') - shape_scale_regions = [scale.shape for scale in scale_factor_regions] - - out = model(inp) - srcs = set() - sinks = set() - for r in regions: - srcs.update([x for x in list(r.srcs_names)]) - sinks.update([x for x in list(r.sinks_names)]) - - count_region_srcs = 0 - count_region_sinks = 0 - for n in model.graph.nodes: - if _is_supported_module(model, n): - count_region_srcs += 1 - if not isinstance(get_module(model, n.target), (nn.LayerNorm,) + _batch_norm): - count_region_sinks += 1 - - src_coverage = len(srcs) / count_region_srcs - sink_coverage = len(sinks) / count_region_sinks - assert src_coverage >= coverage[0] - assert sink_coverage >= coverage[1] - assert torch.allclose(expected_out, out, atol=ATOL) - # Graph equalization can exit in case of shape mismatches or other error without performing any - # equalization and returning a scalar value. We check that the equalized regions are as many as - # expected - if is_alexnet: - # In AlexNet, we cannot equalize only through one region - assert sum([shape == () for shape in shape_scale_regions]) == 1 - else: - assert all([shape != () for shape in shape_scale_regions]) - - -@pytest_cases.parametrize("merge_bias", [True, False]) -def test_models(toy_model, merge_bias, request): - test_id = request.node.callspec.id - - if 'mha' in test_id: - in_shape = IN_SIZE_LINEAR - else: - in_shape = IN_SIZE_CONV - - model_class = toy_model - model = model_class() - inp = torch.randn(in_shape) +# def test_resnet18_equalization(): +# model = models.resnet18(pretrained=True) + +# torch.manual_seed(SEED) +# inp = torch.randn(IN_SIZE_CONV) +# model.eval() +# model = symbolic_trace(model) +# expected_out = model(inp) + +# model_orig = copy.deepcopy(model) +# regions = _extract_regions(model) +# _ = equalize_test( +# regions, merge_bias=True, bias_shrinkage='vaiq', scale_computation_type='maxabs') +# out = model(inp) + +# # Check that equalization is not introducing FP variations +# assert torch.allclose(expected_out, out, atol=ATOL) + +# regions = sorted(regions, key=lambda region: sorted([r for r in region.srcs_names])) +# resnet_18_regions = sorted(RESNET_18_REGIONS, key=lambda region: region[0][0]) +# equalized_layers = set() +# for r in resnet_18_regions: +# equalized_layers.update(r[0]) +# equalized_layers.update(r[1]) + +# # Check that we found all the expected regions +# for region, expected_region in zip(regions, resnet_18_regions): +# srcs = region.srcs_names +# sources_check = set(srcs) == set(expected_region[0]) +# sinks = region.sinks_names +# sinks_check = set(sinks) == set(expected_region[1]) +# assert sources_check +# assert sinks_check + +# # Check that all layers were equalized and weights changed +# for layer in equalized_layers: +# eq_module = get_module(model, layer) +# orig_module = get_module(model_orig, layer) +# assert not torch.allclose(eq_module.weight, orig_module.weight) + + +# @pytest_cases.parametrize("merge_bias", [True, False]) +# def test_equalization_torchvision_models(model_coverage: tuple, merge_bias: bool): +# model, coverage = model_coverage + +# torch.manual_seed(SEED) +# inp = torch.randn(IN_SIZE_CONV) +# model.eval() +# # The isistance does not work after symbolic trace +# is_alexnet = isinstance(model, models.AlexNet) +# model = symbolic_trace(model) +# model = TorchFunctionalToModule().apply(model) + +# expected_out = model(inp) + +# regions = _extract_regions(model) +# scale_factor_regions = equalize_test( +# regions, merge_bias=merge_bias, bias_shrinkage='vaiq', scale_computation_type='maxabs') +# shape_scale_regions = [scale.shape for scale in scale_factor_regions] + +# out = model(inp) +# srcs = set() +# sinks = set() +# for r in regions: +# srcs.update([x for x in list(r.srcs_names)]) +# sinks.update([x for x in list(r.sinks_names)]) + +# count_region_srcs = 0 +# count_region_sinks = 0 +# for n in model.graph.nodes: +# if _is_supported_module(model, n): +# count_region_srcs += 1 +# if not isinstance(get_module(model, n.target), (nn.LayerNorm,) + _batch_norm): +# count_region_sinks += 1 + +# src_coverage = len(srcs) / count_region_srcs +# sink_coverage = len(sinks) / count_region_sinks +# assert src_coverage >= coverage[0] +# assert sink_coverage >= coverage[1] +# assert torch.allclose(expected_out, out, atol=ATOL) +# # Graph equalization can exit in case of shape mismatches or other error without performing any +# # equalization and returning a scalar value. We check that the equalized regions are as many as +# # expected +# if is_alexnet: +# # In AlexNet, we cannot equalize only through one region +# assert sum([shape == () for shape in shape_scale_regions]) == 1 +# else: +# assert all([shape != () for shape in shape_scale_regions]) + + +# @pytest_cases.parametrize("merge_bias", [True, False]) +# def test_models(toy_model, merge_bias, request): +# test_id = request.node.callspec.id + +# if 'mha' in test_id: +# in_shape = IN_SIZE_LINEAR +# else: +# in_shape = IN_SIZE_CONV + +# model_class = toy_model +# model = model_class() +# inp = torch.randn(in_shape) + +# model.eval() +# with torch.no_grad(): +# expected_out = model(inp) + +# model = symbolic_trace(model) +# regions = _extract_regions(model) +# scale_factor_regions = equalize_test( +# regions, merge_bias=merge_bias, bias_shrinkage='vaiq', scale_computation_type='maxabs') +# shape_scale_regions = [scale.shape for scale in scale_factor_regions] + +# with torch.no_grad(): +# out = model(inp) +# assert len(regions) > 0 +# assert torch.allclose(expected_out, out, atol=ATOL) +# # Check that at least one region performs "true" equalization +# # If all shapes are scalar, no equalization has been performed +# if 'convgroupconv' in test_id: +# with pytest.raises(AssertionError): +# assert all([shape != () for shape in shape_scale_regions]) +# else: +# assert all([shape != () for shape in shape_scale_regions]) + + +# @pytest_cases.parametrize("layerwise", [True, False]) +# def test_act_equalization_models(toy_model, layerwise, request): +# test_id = request.node.callspec.id + +# if 'mha' in test_id: +# in_shape = IN_SIZE_LINEAR +# else: +# in_shape = IN_SIZE_CONV + +# model_class = toy_model +# model = model_class() +# inp = torch.randn(in_shape) + +# model.eval() +# expected_out = model(inp) +# model = symbolic_trace(model) +# with torch.no_grad(): +# with activation_equalization_mode(model, 0.5, True, layerwise=layerwise) as aem: +# regions = aem.graph_act_eq.regions +# model(inp) +# scale_factor_regions = aem.scale_factors +# shape_scale_regions = [scale.shape for scale in scale_factor_regions] + +# out = model(inp) +# assert torch.allclose(expected_out, out, atol=ATOL) + +# # This region is made up of a residual branch, so no regions are found for act equalization +# if 'convgroupconv' in test_id: +# with pytest.raises(AssertionError): +# assert len(regions) > 0 +# # Check that at least one region performs "true" equalization +# # If all shapes are scalar, no equalization has been performed +# assert all([shape != () for shape in shape_scale_regions]) +# else: +# assert len(regions) > 0 +# # Check that at least one region performs "true" equalization +# # If all shapes are scalar, no equalization has been performed +# assert all([shape != () for shape in shape_scale_regions]) + + +# @pytest_cases.parametrize( +# "model_dict", [(model_name, coverage) for model_name, coverage in MODELS.items()], +# ids=[model_name for model_name, _ in MODELS.items()]) +# @pytest_cases.parametrize("layerwise", [True, False]) +# def test_act_equalization_torchvision_models(model_dict: dict, layerwise: bool): +# model, coverage = model_dict + +# if model == 'googlenet' and torch_version == version.parse('1.8.1'): +# pytest.skip( +# 'Skip because of PyTorch error = AttributeError: \'function\' object has no attribute \'GoogLeNetOutputs\' ' +# ) +# if 'vit' in model and torch_version < version.parse('1.13'): +# pytest.skip( +# f'ViT supported from torch version 1.13, current torch version is {torch_version}') + +# try: +# model = getattr(models, model)(pretrained=True, transform_input=False) +# except TypeError: +# model = getattr(models, model)(pretrained=True) + +# torch.manual_seed(SEED) +# inp = torch.randn(IN_SIZE_CONV) +# model.eval() + +# model = symbolic_trace(model) +# model = TorchFunctionalToModule().apply(model) +# model = DuplicateSharedStatelessModule().apply(model) +# expected_out = model(inp) + +# with torch.no_grad(): +# with activation_equalization_mode(model, 0.5, True, layerwise=layerwise) as aem: +# model(inp) +# scale_factor_regions = aem.scale_factors +# shape_scale_regions = [scale.shape for scale in scale_factor_regions] + +# out = model(inp) + +# assert torch.allclose(expected_out, out, atol=ATOL) +# # Check that at least one region performs "true" equalization +# # If all shapes are scalar, no equalization has been performed +# assert any([shape != () for shape in shape_scale_regions]) + + +def test_models(rotation_fixtures): + + in_shape = IN_SIZE_LINEAR + + model_class = rotation_fixtures + model = model_class().cuda() + inp = torch.ones(in_shape).cuda() model.eval() + weight = copy.deepcopy(list(iter(model.parameters()))[-2]) with torch.no_grad(): expected_out = model(inp) model = symbolic_trace(model) - regions = _extract_regions(model) - scale_factor_regions = equalize_test( - regions, merge_bias=merge_bias, bias_shrinkage='vaiq', scale_computation_type='maxabs') - shape_scale_regions = [scale.shape for scale in scale_factor_regions] + merge = MergeLnAffine() + model = merge.apply(model) + eq = GraphRotationEqualization() + model = eq.apply(model) with torch.no_grad(): out = model(inp) - assert len(regions) > 0 - assert torch.allclose(expected_out, out, atol=ATOL) - # Check that at least one region performs "true" equalization - # If all shapes are scalar, no equalization has been performed - if 'convgroupconv' in test_id: - with pytest.raises(AssertionError): - assert all([shape != () for shape in shape_scale_regions]) - else: - assert all([shape != () for shape in shape_scale_regions]) - - -@pytest_cases.parametrize("layerwise", [True, False]) -def test_act_equalization_models(toy_model, layerwise, request): - test_id = request.node.callspec.id - - if 'mha' in test_id: - in_shape = IN_SIZE_LINEAR - else: - in_shape = IN_SIZE_CONV - - model_class = toy_model - model = model_class() - inp = torch.randn(in_shape) - - model.eval() - expected_out = model(inp) - model = symbolic_trace(model) - with torch.no_grad(): - with activation_equalization_mode(model, 0.5, True, layerwise=layerwise) as aem: - regions = aem.graph_act_eq.regions - model(inp) - scale_factor_regions = aem.scale_factors - shape_scale_regions = [scale.shape for scale in scale_factor_regions] - - out = model(inp) - assert torch.allclose(expected_out, out, atol=ATOL) - - # This region is made up of a residual branch, so no regions are found for act equalization - if 'convgroupconv' in test_id: - with pytest.raises(AssertionError): - assert len(regions) > 0 - # Check that at least one region performs "true" equalization - # If all shapes are scalar, no equalization has been performed - assert all([shape != () for shape in shape_scale_regions]) - else: - assert len(regions) > 0 - # Check that at least one region performs "true" equalization - # If all shapes are scalar, no equalization has been performed - assert all([shape != () for shape in shape_scale_regions]) - - -@pytest_cases.parametrize( - "model_dict", [(model_name, coverage) for model_name, coverage in MODELS.items()], - ids=[model_name for model_name, _ in MODELS.items()]) -@pytest_cases.parametrize("layerwise", [True, False]) -def test_act_equalization_torchvision_models(model_dict: dict, layerwise: bool): - model, coverage = model_dict - - if model == 'googlenet' and torch_version == version.parse('1.8.1'): - pytest.skip( - 'Skip because of PyTorch error = AttributeError: \'function\' object has no attribute \'GoogLeNetOutputs\' ' - ) - if 'vit' in model and torch_version < version.parse('1.13'): - pytest.skip( - f'ViT supported from torch version 1.13, current torch version is {torch_version}') - - try: - model = getattr(models, model)(pretrained=True, transform_input=False) - except TypeError: - model = getattr(models, model)(pretrained=True) - - torch.manual_seed(SEED) - inp = torch.randn(IN_SIZE_CONV) - model.eval() - - model = symbolic_trace(model) - model = TorchFunctionalToModule().apply(model) - model = DuplicateSharedStatelessModule().apply(model) - expected_out = model(inp) - - with torch.no_grad(): - with activation_equalization_mode(model, 0.5, True, layerwise=layerwise) as aem: - model(inp) - scale_factor_regions = aem.scale_factors - shape_scale_regions = [scale.shape for scale in scale_factor_regions] - - out = model(inp) - assert torch.allclose(expected_out, out, atol=ATOL) - # Check that at least one region performs "true" equalization - # If all shapes are scalar, no equalization has been performed - assert any([shape != () for shape in shape_scale_regions]) + nweight = copy.deepcopy(list(iter(model.parameters()))[-2]) + # Invariance of the output + assert torch.allclose(out, expected_out, atol=ATOL) + # Rotate weights must be different + assert not torch.allclose(weight, nweight) + # Merging affine parameters of RMS + assert torch.allclose(model.rms.weight.data, torch.ones_like(model.rms.weight.data))