From 40cda8cc275330a3e703d6223fc4151fa7dff826 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 17 Oct 2024 15:22:55 +0100 Subject: [PATCH] Test rotation --- src/brevitas/graph/equalize.py | 314 ++++++++++++++++++++++------- src/brevitas/nn/equalized_layer.py | 16 ++ 2 files changed, 260 insertions(+), 70 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 7f412148f..8b470091d 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -20,14 +20,17 @@ from brevitas.graph.base import ModuleInstanceToModuleInstance from brevitas.graph.utils import get_module from brevitas.graph.utils import get_node -from brevitas.nn.equalized_layer import EqualizedModule +from brevitas.nn.equalized_layer import EqualizedModule, RotatedModule from brevitas.nn.equalized_layer import INPUT_NAMES from brevitas.nn.quant_scale_bias import ScaleBias from brevitas.utils.torch_utils import KwargsForwardHook from .base import GraphTransform from .base import InsertModuleCallAfter - +try: + from scipy.linalg import hadamard +except ImportError: + hadamard = None __all__ = ['GraphActivationEqualization', 'LayerwiseActivationEqualization', 'EqualizeGraph'] EPSILON = 1e-9 @@ -73,7 +76,7 @@ _select_op = (operator.getitem, operator.__getitem__) -_reshaping_op = ('view', 'reshape', 'flatten', 'contiguous', torch.reshape, torch.flatten) +_reshaping_op = ('view', 'reshape', 'flatten', 'contiguous', 'to', torch.reshape, torch.flatten) _scale_varying_activations = ( torch.nn.Sigmoid, torch.nn.Tanh, torch.nn.ReLU6, torch.nn.GELU, torch.nn.SiLU) @@ -86,6 +89,24 @@ _ignore_ops = (getattr, 'size') +def _is_supported_module(graph_model: GraphModule, node: Node, supported_layers: Set =_supported_layers) -> bool: + if node.op == 'call_module': + module = get_module(graph_model, node.target) + if isinstance(module, supported_layers): + # We support only self-attention + if isinstance(module, nn.MultiheadAttention): + kwargs = dict(node.kwargs) + # When using hf/accelerate, we need to check the signature of the original forward + forward_to_check = module._old_forward if hasattr( + module, '_old_forward') else module.forward + kwargs.update(zip(forward_to_check.__code__.co_varnames[1:], node.args)) + return kwargs['query'].name == kwargs['key'].name == kwargs['value'].name + return True + return False + +def _is_scale_invariant_module(graph_model: GraphModule, node: Node, scale_invariant_layers=_scale_invariant_layers) -> bool: + return node.op == 'call_module' and isinstance( + get_module(graph_model, node.target), scale_invariant_layers) # Start and End identify the starting and ending channels of the weight matrix that need to be # equalized. 
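The guarded scipy.linalg.hadamard import introduced above is what backs the rotation matrices used later in this patch. As a minimal sketch (the helper name build_rotation is illustrative, not part of the patch), this is how such a matrix is built and why the rotation pass later skips regions whose hidden dimension is not an exact power of two:

import torch

try:
    from scipy.linalg import hadamard
except ImportError:  # scipy is an optional dependency
    hadamard = None


def build_rotation(dim: int) -> torch.Tensor:
    if hadamard is None:
        raise RuntimeError("scipy is required to build Hadamard rotation matrices")
    # scipy.linalg.hadamard only accepts power-of-two sizes, which is why
    # the rotation pass checks the hidden dimension before rotating a region.
    h = torch.from_numpy(hadamard(dim)).float()
    # Normalized so that the matrix is orthonormal: h @ h.t() == identity.
    return h / torch.sqrt(torch.tensor(float(dim)))


r = build_rotation(16)
assert torch.allclose(r @ r.t(), torch.eye(16), atol=1e-6)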
@@ -105,6 +126,12 @@ class WeightBiasWrapper: weight: torch.Tensor = None bias: torch.Tensor = None +# Required for being hashable +@dataclass(eq=True, frozen=True) +class FunctionalNodeReference: + node: Node = None + args_index: int = None + # Required for being hashable @dataclass(eq=True, frozen=True) @@ -125,7 +152,24 @@ def sinks_names(self): def get_module_from_name(self, name: str) -> nn.Module: name = name.split("$")[0] return self.name_to_module[name] + + @property + def max_shape_srcs(self): + max_shape_srcs = 0 + for name, indexes in self.srcs.items(): + max_shape_srcs = max(max_shape_srcs, indexes.end + indexes.offset) + return max_shape_srcs + @property + def max_shape_sinks(self): + max_shape_sinks = 0 + for name, indexes in self.sinks.items(): + max_shape_sinks = max(max_shape_sinks, indexes.end + indexes.offset) + return max_shape_sinks + + @property + def is_valid(self): + return self.max_shape_srcs == self.max_shape_sinks @dataclass class WalkRegionState: @@ -135,6 +179,11 @@ class WalkRegionState: history: set = field(default_factory=set) name_to_module: Dict = field(default_factory=dict) + supported_srcs: set = _supported_layers + supported_sinks: set = _supported_layers + scale_invariant_function: set = _scale_invariant_op + scale_invariant_layers: set = _scale_invariant_layers + cat_encoutered: bool = False offset: int = 0 update_offset: bool = False @@ -271,7 +320,7 @@ def _get_input_axis(module: nn.Module) -> Optional[int]: return 0 elif module.groups == module.out_channels: return 1 - elif isinstance(module, nn.LayerNorm): + elif isinstance(module, (nn.LayerNorm, nn.RMSNorm)): # We assume normalization happens only along the channel dimension if len(module.weight.shape) == 1: return 0 @@ -296,9 +345,9 @@ def _get_output_axis(module: nn.Module) -> Optional[int]: nn.BatchNorm2d, nn.BatchNorm3d)): return 0 - elif isinstance(module, (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)): + elif isinstance(module, (nn.Embedding, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)): return 1 - elif isinstance(module, nn.LayerNorm): + elif isinstance(module, (nn.LayerNorm, nn.RMSNorm)): # We assume normalization happens only along the channel dimension if len(module.weight.shape) == 1: return 0 @@ -415,16 +464,6 @@ def _no_equalize(): single_module = region.get_module_from_name(next(iter(region.sinks_names))) dtype = next(single_module.parameters()).dtype - max_shape_srcs = 0 - for name, indexes in region.srcs.items(): - max_shape_srcs = max(max_shape_srcs, indexes.end + indexes.offset) - max_shape_sinks = 0 - for name, indexes in region.sinks.items(): - max_shape_sinks = max(max_shape_sinks, indexes.offset + (indexes.end - indexes.start)) - - # Exit if source and sink have different sizes - if max_shape_srcs != max_shape_sinks and len(region.srcs) > 0: - return _no_equalize() src_axes = {} for name, indexes in region.srcs.items(): @@ -432,8 +471,7 @@ def _no_equalize(): # If module is not supported, do not perform graph equalization axis = _get_output_axis(module) act_sources_axes[name] = _get_act_axis(module) - if not isinstance(module, _supported_layers): - return _no_equalize() + if isinstance(module, nn.MultiheadAttention): module = module.out_proj src_axes[name] = (module, axis) @@ -443,9 +481,6 @@ def _no_equalize(): module = region.get_module_from_name(name) axis = _get_input_axis(module) act_sink_axes[name] = _get_act_axis(module) - # If module is not supported, do not perform graph equalization - if not isinstance(module, _supported_layers) or 
module in _batch_norm: - return _no_equalize() # For MultiheadAttention, we support only self-attetion if isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is not None: # For sinks, we only need to modify the weight but not the bias @@ -479,8 +514,8 @@ def _no_equalize(): sink_weights = { name: transpose(m.weight.cpu().to(torch.float32), axis) for name, (m, axis) in sink_axes.items()} - srcs_range = -1 * torch.ones(max_shape_srcs, device='cpu', dtype=torch.float32) - sinks_range = -1 * torch.ones(max_shape_sinks, device='cpu', dtype=torch.float32) + srcs_range = -1 * torch.ones(region.max_shape_srcs, device='cpu', dtype=torch.float32) + sinks_range = -1 * torch.ones(region.max_shape_sinks, device='cpu', dtype=torch.float32) for k, v in sink_weights.items(): # Sinks can be partially equalized, thus we need to select # only the channels we are interested in @@ -638,43 +673,19 @@ def _equalize( return model -def _is_supported_module(graph_model: GraphModule, node: Node) -> bool: - if node.op == 'call_module': - module = get_module(graph_model, node.target) - if isinstance(module, _supported_layers): - # We support only self-attention - if isinstance(module, nn.MultiheadAttention): - kwargs = dict(node.kwargs) - # When using hf/accelerate, we need to check the signature of the original forward - forward_to_check = module._old_forward if hasattr( - module, '_old_forward') else module.forward - kwargs.update(zip(forward_to_check.__code__.co_varnames[1:], node.args)) - return kwargs['query'].name == kwargs['key'].name == kwargs['value'].name - return True - return False - - -def _is_scale_invariant_module(graph_model: GraphModule, node: Node) -> bool: - return node.op == 'call_module' and isinstance( - get_module(graph_model, node.target), _scale_invariant_layers) - def _is_scale_varying_activation(graph_model, node): return node.op == 'call_module' and isinstance( get_module(graph_model, node.target), _scale_varying_activations) -def _is_scale_invariant_function(node: Node) -> bool: - out = node.op == 'call_function' and node.target in _scale_invariant_op + _select_op +def _is_scale_invariant_function(node: Node, scale_invariant_op: Set =_scale_invariant_op) -> bool: + out = node.op == 'call_function' and node.target in scale_invariant_op + _select_op + _reshaping_op if node.target == torch.nn.functional.interpolate: out &= node.kwargs.get('mode', None) == 'nearest' return out -def _is_reshaping_op(node: Node) -> bool: - return node.target in _reshaping_op - - def get_weight_source(module): transpose = lambda weight, axis: weight if axis == 0 else weight.transpose(0, 1) if isinstance(module, nn.MultiheadAttention) and not hasattr(module, 'out_proj'): @@ -696,8 +707,8 @@ def get_weight_sink(module): return weight -def find_srcs_channel_dim(model, inp_node): - if _is_supported_module(model, inp_node): +def find_srcs_channel_dim(state, model, inp_node): + if _is_supported_module(model, inp_node, state.supported_srcs): # If we meet a supported module, determine the channel shape module = get_module(model, inp_node.target) # Since we are walking up, we consider the module as srcs @@ -707,7 +718,7 @@ def find_srcs_channel_dim(model, inp_node): elif _is_add(inp_node): all_channels = [] for n in inp_node.all_input_nodes: - all_channels.append(find_srcs_channel_dim(model, n)) + all_channels.append(find_srcs_channel_dim(state, model, n)) # All branches to add should have the same amount of channels if all([channel == all_channels[0] for channel in all_channels]): return all_channels[0] 
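The explicit source/sink size check removed from _cross_layer_equalization above is now expressed through the Region.max_shape_srcs, Region.max_shape_sinks and Region.is_valid properties added earlier in this patch. A self-contained sketch of the same logic, using a toy dataclass that mirrors the fields of EqualizationIndexes rather than importing from Brevitas:

from dataclasses import dataclass, field
from typing import Dict


@dataclass(frozen=True)
class Indexes:
    start: int
    end: int
    offset: int


@dataclass
class ToyRegion:
    srcs: Dict[str, Indexes] = field(default_factory=dict)
    sinks: Dict[str, Indexes] = field(default_factory=dict)

    @property
    def max_shape_srcs(self) -> int:
        # Largest output-channel index covered by any source.
        return max((i.end + i.offset for i in self.srcs.values()), default=0)

    @property
    def max_shape_sinks(self) -> int:
        # Largest input-channel index covered by any sink.
        return max((i.end + i.offset for i in self.sinks.values()), default=0)

    @property
    def is_valid(self) -> bool:
        # Sources and sinks must span the same number of channels to equalize.
        return self.max_shape_srcs == self.max_shape_sinks


region = ToyRegion(
    srcs={'linear1': Indexes(0, 2048, 0)},
    sinks={'linear2': Indexes(0, 2048, 0)})
assert region.is_valid and region.max_shape_sinks == 2048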
@@ -717,10 +728,10 @@ def find_srcs_channel_dim(model, inp_node): total_channels = 0 # If it's cat, we need to sum the channel shape of all the branches for n in inp_node.all_input_nodes: - total_channels += find_srcs_channel_dim(model, n) + total_channels += find_srcs_channel_dim(state, model, n) return total_channels - elif _is_scale_invariant_module(model, inp_node) or _is_scale_invariant_function(inp_node): - return find_srcs_channel_dim(model, inp_node.all_input_nodes[0]) + elif _is_scale_invariant_module(model, inp_node, state.scale_invariant_layers) or _is_scale_invariant_function(inp_node, state.scale_invariant_function): + return find_srcs_channel_dim(state, model, inp_node.all_input_nodes[0]) else: return _UNSUPPORTED_OP @@ -762,7 +773,7 @@ def find_srcs(graph_model: GraphModule, starting_node: Node, state.history.add(path) else: continue - if _is_supported_module(graph_model, node): + if _is_supported_module(graph_model, node, state.supported_srcs): module = get_module(graph_model, node.target) weight = get_weight_source(module) eq_indexes = EqualizationIndexes(0, weight.shape[0], state.offset) @@ -773,7 +784,7 @@ def find_srcs(graph_model: GraphModule, starting_node: Node, state.offset = state.offset if not state.update_offset else state.offset + weight.shape[ 0] elif _is_scale_invariant_module( - graph_model, node) or _is_scale_invariant_function(node) or _is_reshaping_op(node): + graph_model, node, state.scale_invariant_layers) or _is_scale_invariant_function(node, state.scale_invariant_function): find_sinks(graph_model, node, state) find_srcs(graph_model, node, state) elif (node.op == 'call_method' and node.target in _residual_methods or @@ -813,18 +824,14 @@ def find_sinks(graph_model: GraphModule, starting_node: Node, state.history.add(path) else: continue - if _is_supported_module(graph_model, node): + if _is_supported_module(graph_model, node, state.supported_sinks): module = get_module(graph_model, node.target) weight = get_weight_sink(module) eq_indexes = EqualizationIndexes(0, weight.shape[0], state.offset) - # It is not possible to equalize through LayerNorm as sink - if isinstance(module, (nn.LayerNorm,) + _batch_norm): - state.sinks[_UNSUPPORTED_OP] = _UNSUPPORTED_OP - else: - state.add_sinks(node.target, module, eq_indexes) + state.add_sinks(node.target, module, eq_indexes) elif _is_scale_invariant_module( - graph_model, node) or _is_scale_invariant_function(node) or _is_reshaping_op(node): + graph_model, node, state.scale_invariant_layers) or _is_scale_invariant_function(node, state.scale_invariant_function): find_sinks(graph_model, node, state) elif (node.op == 'call_method' and node.target in _residual_methods or node.op == 'call_function' and node.target in _residual_fns): @@ -844,7 +851,7 @@ def find_sinks(graph_model: GraphModule, starting_node: Node, index = node.all_input_nodes.index(starting_node) channels = [] for n in node.all_input_nodes: - channel_dim = find_srcs_channel_dim(graph_model, n) + channel_dim = find_srcs_channel_dim(state, graph_model, n) channels.append(channel_dim) # If we found an unsupported op while walking up, we exit this branch and @@ -873,13 +880,17 @@ def find_sinks(graph_model: GraphModule, starting_node: Node, def _extract_regions( graph_model: GraphModule, add_mul_node: bool = False, - return_acts: bool = False) -> List[Region]: + return_acts: bool = False, + state_impl_kwargs = None)-> List[Region]: regions = list() for node in graph_model.graph.nodes: + if state_impl_kwargs is not None: + state = 
WalkRegionState(**state_impl_kwargs) + else: + state = WalkRegionState() if _is_supported_module(graph_model, - node) or (add_mul_node and + node, state.supported_srcs) or (add_mul_node and _is_scale_varying_activation(graph_model, node)): - state = WalkRegionState() if _is_scale_varying_activation(graph_model, node): module = get_module(graph_model, node.target) state.add_acts(node.target, module) @@ -928,7 +939,10 @@ def __init__( def apply(self, graph_model: GraphModule) -> Union[Tuple[GraphModule, Set[Tuple[str]]], GraphModule]: - regions = _extract_regions(graph_model) + # It is not possible to equalize through LayerNorm/BatchNorm as sink + supported_sinks = list(_supported_layers) + supported_sinks = [x for x in _supported_layers if x not in (torch.nn.LayerNorm, *_batch_norm)] + regions = _extract_regions(graph_model, state_impl_kwargs={'supported_sinks':supported_sinks}) if len(regions) > 0: graph_model = _equalize( graph_model, @@ -1099,7 +1113,11 @@ def __init__( self.hooked_modules = set() self.add_mul_node = add_mul_node self.co_optimize_act_weights = co_optimize_act_weights - self.regions = _extract_regions(model, add_mul_node=add_mul_node, return_acts=True) + + # It is not possible to equalize through LayerNorm/BatchNorm as sink + supported_sinks = list(_supported_layers) + supported_sinks = [x for x in _supported_layers if x not in (torch.nn.LayerNorm, *_batch_norm)] + self.regions = _extract_regions(model, add_mul_node=add_mul_node, return_acts=True, state_impl_kwargs={'supported_sinks': supported_sinks}) if self.scale_computation_type == 'maxabs': self.scale_fn = _channel_maxabs @@ -1186,3 +1204,159 @@ def insert_mul_node(self, scale, shape, axis, act_node, batch_dim=0): self.model.add_module(mul_factor_name, mul_factor) rewriter = InsertModuleCallAfter(mul_factor_name, act_node) rewriter.apply(self.model) + + +def _apply_rotate(model: nn.Module, regions: List[Region], insert_rotation_func: bool = False): + for region in regions: + for name in (region.srcs_names + region.sinks_names): + module = region.get_module_from_name(name) + if hasattr(module, 'allocate_params'): + module.allocate_params(module) + + if not insert_rotation_func and not region.is_valid: + continue + hidden_dim = region.max_shape_sinks + # Check that hidden_dim is an exact Po2 + if torch.log2(torch.tensor(hidden_dim)) != torch.ceil(torch.log2(torch.tensor(hidden_dim))): + continue + + # Build hadamard rotation matrix + h = torch.from_numpy(hadamard(hidden_dim)) / torch.sqrt(torch.tensor(hidden_dim)) + hadamard_inverse = h.t() + for name, indexes in region.srcs.items(): + module = region.get_module_from_name(name) + axis = _get_output_axis(module) + h_inv = hadamard_inverse.type_as(module.weight.data) + if axis == 0: + module.weight.data = torch.matmul(h_inv, module.weight.data) + elif axis == 1: + module.weight.data = torch.matmul(module.weight.data, h_inv) + else: + raise RuntimeError("Not supported yet") + + for name, indexes in region.sinks.items(): + module = region.get_module_from_name(name) + h = h.type_as(module.weight.data) + axis = _get_input_axis(module) + if axis == 1: + module.weight.data = torch.matmul(module.weight.data, h) + elif axis == 0: + module.weight.data = torch.matmul(h, module.weight.data) + else: + raise RuntimeError("Not supported yet") + if insert_rotation_func and len(region.srcs) == 0: + rewriter = ModuleInstanceToModuleInstance( + module, RotatedModule(h_inv=hadamard_inverse, layer=module)) + rewriter.apply(model) + + for name in (region.srcs_names + region.sinks_names): + 
module = region.get_module_from_name(name) + if hasattr(module, 'offload_params'): + module.offload_params(module) + +class GraphRotationEqualization(GraphTransform): + + def __init__(self) -> None: + super(GraphRotationEqualization, self).__init__() + + self.supported_srcs = (torch.nn.Linear, torch.nn.Embedding) + self.supported_sinks = (torch.nn.Linear) + self.scale_invariant_layers = (torch.nn.RMSNorm,) + self.scale_invariant_function = () + + def apply(self, + graph_model: GraphModule) -> Union[Tuple[GraphModule, Set[Tuple[str]]], GraphModule]: + # It is not possible to equalize through LayerNorm/BatchNorm as sink + + regions = _extract_regions(graph_model, state_impl_kwargs={'supported_srcs':self.supported_srcs, 'supported_sinks':self.supported_sinks, 'scale_invariant_layers':self.scale_invariant_layers }) + if len(regions) > 0: + _apply_rotate(graph_model, regions, False) + + return graph_model + +def _replace_bias(next_module, new_bias): + new_bias = new_bias.view(-1) + if next_module.bias is not None: + next_module.bias.data.copy_(new_bias) + else: + new_bias = new_bias.to(next_module.weight.device).to(next_module.weight.dtype) + next_module.register_parameter('bias', torch.nn.Parameter(new_bias)) + + +def _merge_ln(layer_norm, next_module, scale_bias_by_weight): + view_shape = (1, -1) + # Merge weight + if scale_bias_by_weight and hasattr(layer_norm, 'bias'): + layer_norm.bias.data /= layer_norm.weight.data + # We can't do an inplace update as some layers we merge into like lm_head might share the weight tensor + scale = layer_norm.weight.data.view(view_shape).expand_as(next_module.weight) + next_module.weight = torch.nn.Parameter(next_module.weight.clone() * scale) + # Merge bias, new_bias includes the bias of next_module by going through its fwd + if hasattr(layer_norm, 'bias'): + inp = layer_norm.bias.data.view(view_shape) + new_bias = next_module(inp) + _replace_bias(next_module, new_bias) + + +class MergeLnAffine(GraphTransform): + + def __init__(self) -> None: + super(MergeLnAffine, self).__init__() + self.supported_srcs = (torch.nn.RMSNorm, torch.nn.LayerNorm) + self.supported_sinks = (torch.nn.Linear) + + def apply(self, graph_model: GraphModule) -> GraphModule: + regions = _extract_regions(graph_model, state_impl_kwargs={'supported_srcs':self.supported_srcs, 'supported_sinks':self.supported_sinks}) + if len(regions) > 0: + scaled_biases = set() + for region in regions: + layernorm_module_name = next(iter(region.srcs)) + layernorm_module =region.get_module_from_name(layernorm_module_name) + if not layernorm_module.elementwise_affine: + continue + + for name, indexes in region.sinks.items(): + module = region.get_module_from_name(name) + scale_bias = id(module) not in scaled_biases + _merge_ln(layernorm_module, module, scale_bias_by_weight=scale_bias) + + scaled_biases.add(id(module)) + layernorm_module.weight.data.fill_(1.) + if hasattr(layernorm_module, 'bias'): + layernorm_module.bias.data.fill_(0.) + return graph_model + + +class LayerwiseActivationEqualization(GraphTransform): + + def __init__(self, blacklist_layer=None): + super(GraphTransform, self).__init__() + + self.supported_sinks = (torch.nn.Linear) + self.blacklist_layers = blacklist_layer + + def find_module(self, model, regions: List, prefix=''): + """ + Iterate through the model looking at immediate children of every module to look for supported modules. + This allows us to stop the search when we meet a top-level module that is supported. 
+ """ + if isinstance(model, self.supported_sinks): + if self.blacklist_layers is not None and prefix in self.blacklist_layers: + return + weight = get_weight_sink(model) + eq_indexes = EqualizationIndexes(0, weight.shape[0], 0) + region = Region(sinks={'sinks0': eq_indexes}, name_to_module={'sinks0': model}) + regions.append(region) + else: + for name, module in model.named_children(): + full_name = prefix + '.' + name if prefix != '' else name + self.find_module(module, regions, full_name) + + def apply(self, model: nn.Module) -> nn.Module: + regions: List[Region] = [] + self.find_module(model, regions) + if len(regions) > 0: + _apply_rotate(model, regions, True) + return model + + diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py index 7093c8c17..7620fbeab 100644 --- a/src/brevitas/nn/equalized_layer.py +++ b/src/brevitas/nn/equalized_layer.py @@ -3,6 +3,7 @@ import torch from brevitas.nn.quant_mha import QuantMultiheadAttention +from brevitas.quant_tensor.base_quant_tensor import QuantTensor INPUT_NAMES = ['input', 'inp', 'query', 'x', 'hidden_states'] @@ -41,3 +42,18 @@ def forward(self, *args, **kwargs): # We convert everything to args so that hooks can work correctly out = self.layer(*kwargs.values()) return out + +class RotatedModule(torch.nn.Module): + + def __init__(self, h_inv, layer) -> None: + super().__init__() + self.h_inv = torch.nn.Parameter(h_inv) + self.layer = layer + + def forward(self, *args, **kwargs): + inp = args[0] + if isinstance(inp, QuantTensor): + inp = inp.value + inp = torch.matmul(inp, self.h_inv) + out = self.layer(inp) + return out