diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 7f412148f..f39a951d0 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -10,25 +10,59 @@ from typing import Callable, Dict, List, Optional, Set, Tuple, Union import warnings +import packaging +import packaging.version import torch from torch.fx import GraphModule as TorchGraphModule import torch.nn as nn +from brevitas import torch_version from brevitas.fx import GraphModule from brevitas.fx import Node +from brevitas.graph import ModuleToModuleByClass +from brevitas.graph import ModuleToModuleByInstance from brevitas.graph.base import GraphTransform +from brevitas.graph.base import InsertModuleCallAfter from brevitas.graph.base import ModuleInstanceToModuleInstance +from brevitas.graph.base import Transform +from brevitas.graph.hadamard import get_hadK +from brevitas.graph.hadamard import matmul_hadU +from brevitas.graph.hadamard import matmul_hadU_cuda from brevitas.graph.utils import get_module from brevitas.graph.utils import get_node from brevitas.nn.equalized_layer import EqualizedModule +from brevitas.nn.equalized_layer import functional_rotate_input from brevitas.nn.equalized_layer import INPUT_NAMES +from brevitas.nn.equalized_layer import RotatedModule from brevitas.nn.quant_scale_bias import ScaleBias from brevitas.utils.torch_utils import KwargsForwardHook -from .base import GraphTransform -from .base import InsertModuleCallAfter +# External optional dependency +try: + # fast_hadamard_transform @ git+https://github.com/Dao-AILab/fast-hadamard-transform.git@main + import fast_hadamard_transform +except: + warnings.warn("fast_hadamard_transform package not found, using standard pytorch kernels") + fast_hadamard_transform = None -__all__ = ['GraphActivationEqualization', 'LayerwiseActivationEqualization', 'EqualizeGraph'] +# RMSNorm was introduced with torch 2.4 +if torch_version >= packaging.version.parse('2.4'): + RMSNorm = nn.RMSNorm +else: + + class PlaceholderRMSNorm: + pass + + RMSNorm = PlaceholderRMSNorm + +__all__ = [ + 'GraphActivationEqualization', + 'LayerwiseActivationEqualization', + 'EqualizeGraph', + 'LayerwiseActivationRotation', + 'MergeLnAffine', + 'LayerNormToRMS', + 'GraphRotationEqualization'] EPSILON = 1e-9 @@ -69,14 +103,13 @@ operator.imul, operator.__mul__, operator.__imul__, - torch.nn.functional.interpolate) + nn.functional.interpolate) _select_op = (operator.getitem, operator.__getitem__) -_reshaping_op = ('view', 'reshape', 'flatten', 'contiguous', torch.reshape, torch.flatten) +_reshaping_op = ('view', 'reshape', 'flatten', 'contiguous', 'to', torch.reshape, torch.flatten) -_scale_varying_activations = ( - torch.nn.Sigmoid, torch.nn.Tanh, torch.nn.ReLU6, torch.nn.GELU, torch.nn.SiLU) +_scale_varying_activations = (nn.Sigmoid, nn.Tanh, nn.ReLU6, nn.GELU, nn.SiLU) _residual_methods = ('add', 'add_') @@ -126,6 +159,34 @@ def get_module_from_name(self, name: str) -> nn.Module: name = name.split("$")[0] return self.name_to_module[name] + @property + def max_shape_srcs(self): + # Compute the number of output channel from the sources. If we are equalizing through cat, + # we need to add together the number of channels. Otherwise, all sources must have the same + # number of output channel. + # Furthermore, all output channels of all the sources are always fully equalized. 
+ max_shape_srcs = 0 + for name, indexes in self.srcs.items(): + max_shape_srcs = max(max_shape_srcs, indexes.end + indexes.offset) + return max_shape_srcs + + @property + def max_shape_sinks(self): + # Compute the number of input channel from the sinks. If we are equalizing through cat, + # we need to slice and potentially select only a subset of input channel from sinks. + max_shape_sinks = 0 + for name, indexes in self.sinks.items(): + max_shape_sinks = max(max_shape_sinks, indexes.offset + (indexes.end - indexes.start)) + return max_shape_sinks + + @property + def is_valid(self): + """ + To perform equalization, we need that the number of output channel of the sources matches the + number of input channel of the sinks. If that's not the case, the region is considered invalid + """ + return self.max_shape_srcs == self.max_shape_sinks + @dataclass class WalkRegionState: @@ -135,6 +196,11 @@ class WalkRegionState: history: set = field(default_factory=set) name_to_module: Dict = field(default_factory=dict) + supported_srcs: set = _supported_layers + supported_sinks: set = _supported_layers + scale_invariant_function: set = _scale_invariant_op + scale_invariant_layers: set = _scale_invariant_layers + cat_encoutered: bool = False offset: int = 0 update_offset: bool = False @@ -175,10 +241,6 @@ def get_module_from_name(self, name: str) -> nn.Module: return self.name_to_module[name] -def __str__(self): - return str(self.start) + '_' + str(self.end) + '_' + str(self.offset) - - _UNSUPPORTED_OP = object() @@ -271,7 +333,7 @@ def _get_input_axis(module: nn.Module) -> Optional[int]: return 0 elif module.groups == module.out_channels: return 1 - elif isinstance(module, nn.LayerNorm): + elif isinstance(module, (nn.LayerNorm, RMSNorm)): # We assume normalization happens only along the channel dimension if len(module.weight.shape) == 1: return 0 @@ -296,9 +358,10 @@ def _get_output_axis(module: nn.Module) -> Optional[int]: nn.BatchNorm2d, nn.BatchNorm3d)): return 0 - elif isinstance(module, (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)): + elif isinstance(module, + (nn.Embedding, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)): return 1 - elif isinstance(module, nn.LayerNorm): + elif isinstance(module, (nn.LayerNorm, RMSNorm)): # We assume normalization happens only along the channel dimension if len(module.weight.shape) == 1: return 0 @@ -415,15 +478,8 @@ def _no_equalize(): single_module = region.get_module_from_name(next(iter(region.sinks_names))) dtype = next(single_module.parameters()).dtype - max_shape_srcs = 0 - for name, indexes in region.srcs.items(): - max_shape_srcs = max(max_shape_srcs, indexes.end + indexes.offset) - max_shape_sinks = 0 - for name, indexes in region.sinks.items(): - max_shape_sinks = max(max_shape_sinks, indexes.offset + (indexes.end - indexes.start)) - - # Exit if source and sink have different sizes - if max_shape_srcs != max_shape_sinks and len(region.srcs) > 0: + # If region is not valid, don't equalize. 
If we are inserting a standalone mul, we don't need this check + if not region.is_valid and list_of_insert_mul_node_fn is None: return _no_equalize() src_axes = {} @@ -432,8 +488,7 @@ def _no_equalize(): # If module is not supported, do not perform graph equalization axis = _get_output_axis(module) act_sources_axes[name] = _get_act_axis(module) - if not isinstance(module, _supported_layers): - return _no_equalize() + if isinstance(module, nn.MultiheadAttention): module = module.out_proj src_axes[name] = (module, axis) @@ -443,9 +498,6 @@ def _no_equalize(): module = region.get_module_from_name(name) axis = _get_input_axis(module) act_sink_axes[name] = _get_act_axis(module) - # If module is not supported, do not perform graph equalization - if not isinstance(module, _supported_layers) or module in _batch_norm: - return _no_equalize() # For MultiheadAttention, we support only self-attetion if isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is not None: # For sinks, we only need to modify the weight but not the bias @@ -479,8 +531,8 @@ def _no_equalize(): sink_weights = { name: transpose(m.weight.cpu().to(torch.float32), axis) for name, (m, axis) in sink_axes.items()} - srcs_range = -1 * torch.ones(max_shape_srcs, device='cpu', dtype=torch.float32) - sinks_range = -1 * torch.ones(max_shape_sinks, device='cpu', dtype=torch.float32) + srcs_range = -1 * torch.ones(region.max_shape_srcs, device='cpu', dtype=torch.float32) + sinks_range = -1 * torch.ones(region.max_shape_sinks, device='cpu', dtype=torch.float32) for k, v in sink_weights.items(): # Sinks can be partially equalized, thus we need to select # only the channels we are interested in @@ -638,10 +690,11 @@ def _equalize( return model -def _is_supported_module(graph_model: GraphModule, node: Node) -> bool: +def _is_supported_module( + graph_model: GraphModule, node: Node, supported_layers: Set = _supported_layers) -> bool: if node.op == 'call_module': module = get_module(graph_model, node.target) - if isinstance(module, _supported_layers): + if isinstance(module, supported_layers): # We support only self-attention if isinstance(module, nn.MultiheadAttention): kwargs = dict(node.kwargs) @@ -654,27 +707,30 @@ def _is_supported_module(graph_model: GraphModule, node: Node) -> bool: return False -def _is_scale_invariant_module(graph_model: GraphModule, node: Node) -> bool: +def _is_scale_invariant_module( + graph_model: GraphModule, + node: Node, + scale_invariant_layers=_scale_invariant_layers) -> bool: return node.op == 'call_module' and isinstance( - get_module(graph_model, node.target), _scale_invariant_layers) + get_module(graph_model, node.target), scale_invariant_layers) def _is_scale_varying_activation(graph_model, node): + node_target = node.meta.get('orig_target', node.target) return node.op == 'call_module' and isinstance( - get_module(graph_model, node.target), _scale_varying_activations) + get_module(graph_model, node_target), _scale_varying_activations) -def _is_scale_invariant_function(node: Node) -> bool: - out = node.op == 'call_function' and node.target in _scale_invariant_op + _select_op - if node.target == torch.nn.functional.interpolate: +def _is_scale_invariant_function(node: Node, scale_invariant_op: Set = _scale_invariant_op) -> bool: + node_target = node.meta.get('orig_target', node.target) + out = node.op in ( + 'call_function', + 'call_method') and node_target in scale_invariant_op + _select_op + _reshaping_op + if node_target == nn.functional.interpolate: out &= node.kwargs.get('mode', None) == 
'nearest' return out -def _is_reshaping_op(node: Node) -> bool: - return node.target in _reshaping_op - - def get_weight_source(module): transpose = lambda weight, axis: weight if axis == 0 else weight.transpose(0, 1) if isinstance(module, nn.MultiheadAttention) and not hasattr(module, 'out_proj'): @@ -696,10 +752,11 @@ def get_weight_sink(module): return weight -def find_srcs_channel_dim(model, inp_node): - if _is_supported_module(model, inp_node): +def find_srcs_channel_dim(state, model, inp_node): + inp_node_target = inp_node.meta.get('orig_target', inp_node.target) + if _is_supported_module(model, inp_node, state.supported_srcs): # If we meet a supported module, determine the channel shape - module = get_module(model, inp_node.target) + module = get_module(model, inp_node_target) # Since we are walking up, we consider the module as srcs weight = get_weight_source(module) channel = weight.shape[0] @@ -707,7 +764,7 @@ def find_srcs_channel_dim(model, inp_node): elif _is_add(inp_node): all_channels = [] for n in inp_node.all_input_nodes: - all_channels.append(find_srcs_channel_dim(model, n)) + all_channels.append(find_srcs_channel_dim(state, model, n)) # All branches to add should have the same amount of channels if all([channel == all_channels[0] for channel in all_channels]): return all_channels[0] @@ -717,10 +774,12 @@ def find_srcs_channel_dim(model, inp_node): total_channels = 0 # If it's cat, we need to sum the channel shape of all the branches for n in inp_node.all_input_nodes: - total_channels += find_srcs_channel_dim(model, n) + total_channels += find_srcs_channel_dim(state, model, n) return total_channels - elif _is_scale_invariant_module(model, inp_node) or _is_scale_invariant_function(inp_node): - return find_srcs_channel_dim(model, inp_node.all_input_nodes[0]) + elif _is_scale_invariant_module(model, inp_node, + state.scale_invariant_layers) or _is_scale_invariant_function( + inp_node, state.scale_invariant_function): + return find_srcs_channel_dim(state, model, inp_node.all_input_nodes[0]) else: return _UNSUPPORTED_OP @@ -741,13 +800,15 @@ def cat_handler(graph_model: GraphModule, starting_node: Node, state: WalkRegion def _is_cat(node): - return node.target in (torch.cat,) + node_target = node.meta.get('orig_target', node.target) + return node_target in (torch.cat,) def _is_add(node): + node_target = node.meta.get('orig_target', node.target) return ( - node.op == 'call_method' and node.target in _residual_methods or - node.op == 'call_function' and node.target in _residual_fns) + node.op == 'call_method' and node_target in _residual_methods or + node.op == 'call_function' and node_target in _residual_fns) def find_srcs(graph_model: GraphModule, starting_node: Node, @@ -757,27 +818,29 @@ def find_srcs(graph_model: GraphModule, starting_node: Node, for node in node_list: # we keep a history of how the graph has been walked already, invariant to the direction, # to avoid getting stuck in a loop + node_target = node.meta.get('orig_target', node.target) path = (node, starting_node) if path not in state.history: state.history.add(path) else: continue - if _is_supported_module(graph_model, node): - module = get_module(graph_model, node.target) + if _is_supported_module(graph_model, node, state.supported_srcs): + module = get_module(graph_model, node_target) weight = get_weight_source(module) eq_indexes = EqualizationIndexes(0, weight.shape[0], state.offset) # After we found a source, we need to check if it branches into multiple sinks - state.add_srcs(node.target, module, 
eq_indexes) + state.add_srcs(node_target, module, eq_indexes) find_sinks(graph_model, node, state) state.offset = state.offset if not state.update_offset else state.offset + weight.shape[ 0] elif _is_scale_invariant_module( - graph_model, node) or _is_scale_invariant_function(node) or _is_reshaping_op(node): + graph_model, node, state.scale_invariant_layers) or _is_scale_invariant_function( + node, state.scale_invariant_function): find_sinks(graph_model, node, state) find_srcs(graph_model, node, state) - elif (node.op == 'call_method' and node.target in _residual_methods or - node.op == 'call_function' and node.target in _residual_fns): + elif (node.op == 'call_method' and node_target in _residual_methods or + node.op == 'call_function' and node_target in _residual_fns): state.update_offset = False find_sinks(graph_model, node, state) find_srcs(graph_model, node, state) @@ -793,7 +856,7 @@ def find_srcs(graph_model: GraphModule, starting_node: Node, state.update_offset = True find_srcs(graph_model, node, state) state.update_offset = update_offset_state - elif node.target in _ignore_ops: + elif node_target in _ignore_ops: continue else: # If we meet an unrecognized op, we add None to invalidate the region @@ -808,26 +871,24 @@ def find_sinks(graph_model: GraphModule, starting_node: Node, # we keep a history of how the graph has been walked already, invariant to the direction, # to avoid getting stuck in a loop # Note that the path is inverted with respect to find_srcs + node_target = node.meta.get('orig_target', node.target) path = (starting_node, node) if path not in state.history: state.history.add(path) else: continue - if _is_supported_module(graph_model, node): - module = get_module(graph_model, node.target) + if _is_supported_module(graph_model, node, state.supported_sinks): + module = get_module(graph_model, node_target) weight = get_weight_sink(module) eq_indexes = EqualizationIndexes(0, weight.shape[0], state.offset) - # It is not possible to equalize through LayerNorm as sink - if isinstance(module, (nn.LayerNorm,) + _batch_norm): - state.sinks[_UNSUPPORTED_OP] = _UNSUPPORTED_OP - else: - state.add_sinks(node.target, module, eq_indexes) + state.add_sinks(node_target, module, eq_indexes) elif _is_scale_invariant_module( - graph_model, node) or _is_scale_invariant_function(node) or _is_reshaping_op(node): + graph_model, node, state.scale_invariant_layers) or _is_scale_invariant_function( + node, state.scale_invariant_function): find_sinks(graph_model, node, state) - elif (node.op == 'call_method' and node.target in _residual_methods or - node.op == 'call_function' and node.target in _residual_fns): + elif (node.op == 'call_method' and node_target in _residual_methods or + node.op == 'call_function' and node_target in _residual_fns): state.update_offset = False find_sinks(graph_model, node, state) find_srcs(graph_model, node, state) @@ -844,7 +905,7 @@ def find_sinks(graph_model: GraphModule, starting_node: Node, index = node.all_input_nodes.index(starting_node) channels = [] for n in node.all_input_nodes: - channel_dim = find_srcs_channel_dim(graph_model, n) + channel_dim = find_srcs_channel_dim(state, graph_model, n) channels.append(channel_dim) # If we found an unsupported op while walking up, we exit this branch and @@ -863,7 +924,7 @@ def find_sinks(graph_model: GraphModule, starting_node: Node, new_state.get_module_from_name(k), EqualizationIndexes(start, end, new_state.offset)) state.srcs.update(new_state.srcs) - elif node.target in _ignore_ops: + elif node_target in 
_ignore_ops: continue else: # If we meet an unrecognized op, we add None to invalidate the region @@ -873,13 +934,16 @@ def find_sinks(graph_model: GraphModule, starting_node: Node, def _extract_regions( graph_model: GraphModule, add_mul_node: bool = False, - return_acts: bool = False) -> List[Region]: + return_acts: bool = False, + state_impl_kwargs=None) -> List[Region]: regions = list() for node in graph_model.graph.nodes: - if _is_supported_module(graph_model, - node) or (add_mul_node and - _is_scale_varying_activation(graph_model, node)): + if state_impl_kwargs is not None: + state = WalkRegionState(**state_impl_kwargs) + else: state = WalkRegionState() + if _is_supported_module(graph_model, node, state.supported_srcs) or ( + add_mul_node and _is_scale_varying_activation(graph_model, node)): if _is_scale_varying_activation(graph_model, node): module = get_module(graph_model, node.target) state.add_acts(node.target, module) @@ -928,7 +992,11 @@ def __init__( def apply(self, graph_model: GraphModule) -> Union[Tuple[GraphModule, Set[Tuple[str]]], GraphModule]: - regions = _extract_regions(graph_model) + # It is not possible to equalize through LayerNorm/BatchNorm as sink + supported_sinks = tuple([ + x for x in _supported_layers if x not in (nn.LayerNorm, *_batch_norm)]) + regions = _extract_regions( + graph_model, state_impl_kwargs={'supported_sinks': supported_sinks}) if len(regions) > 0: graph_model = _equalize( graph_model, @@ -1099,7 +1167,15 @@ def __init__( self.hooked_modules = set() self.add_mul_node = add_mul_node self.co_optimize_act_weights = co_optimize_act_weights - self.regions = _extract_regions(model, add_mul_node=add_mul_node, return_acts=True) + + # It is not possible to equalize through LayerNorm/BatchNorm as sink + supported_sinks = tuple([ + x for x in _supported_layers if x not in (nn.LayerNorm, *_batch_norm)]) + self.regions = _extract_regions( + model, + add_mul_node=add_mul_node, + return_acts=True, + state_impl_kwargs={'supported_sinks': supported_sinks}) if self.scale_computation_type == 'maxabs': self.scale_fn = _channel_maxabs @@ -1186,3 +1262,330 @@ def insert_mul_node(self, scale, shape, axis, act_node, batch_dim=0): self.model.add_module(mul_factor_name, mul_factor) rewriter = InsertModuleCallAfter(mul_factor_name, act_node) rewriter.apply(self.model) + + +def _apply_had_device(tensor, had_K, K): + is_cuda = 'cuda' in str(tensor.device) and torch.version.cuda is not None + # Accelerated kernel only available for CUDA + if is_cuda and fast_hadamard_transform is not None: + return matmul_hadU_cuda(tensor, had_K, K) + else: + return matmul_hadU(tensor) + + +def _apply_ort_device(tensor, ort, *args): + ort = ort.type_as(tensor) + return torch.matmul(tensor, ort) + + +# Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/eval_utils/rotation_utils.py#L26 +def random_orthogonal_matrix(size): + """ + Generate a random orthogonal matrix of the specified size. + First, we generate a random matrix with entries from a standard distribution. + Then, we use QR decomposition to obtain an orthogonal matrix. + Finally, we multiply by a diagonal matrix with diag r to adjust the signs. + + Args: + size (int): The size of the matrix (size x size). + + Returns: + torch.Tensor: An orthogonal matrix of the specified size. 
+ """ + torch.cuda.empty_cache() + random_matrix = torch.randn(size, size, dtype=torch.float64) + q, r = torch.linalg.qr(random_matrix) + q *= torch.sign(torch.diag(r)).unsqueeze(0).float() + return q + + +def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method='had'): + rewriters = [] + for region in regions: + insert_rotation_module = len(region.srcs) == 0 + + if not insert_rotation_module and not region.is_valid: + continue + hidden_dim = region.max_shape_sinks + if not insert_rotation_module and full_rotation_method == 'ort': + rot_mat = random_orthogonal_matrix(hidden_dim) + K = None + rot_func = _apply_ort_device + else: + try: + # Build hadamard rotation matrix + rot_mat, K = get_hadK(hidden_dim) + rot_func = _apply_had_device + except AssertionError as e: + print(f"Incomptible shapes {hidden_dim}") + if not insert_rotation_module: + print("Falling back to orthogonal matrices") + rot_mat = random_orthogonal_matrix(hidden_dim) + K = None + rot_func = _apply_ort_device + print("Skipping layers") + continue + + for name, indexes in region.srcs.items(): + module = region.get_module_from_name(name) + if hasattr(module, 'allocate_params'): + module.allocate_params(module) + axis = _get_output_axis(module) + weight = module.weight.data + + if axis == 0: + weight = rot_func(weight.t(), rot_mat, K).t() + elif axis == 1: + weight = rot_func(weight, rot_mat, K) + else: + raise RuntimeError("Not supported yet") + module.weight.data = weight + + if getattr(module, 'bias', None) is not None: + bias = module.bias.data + bias = rot_func(bias, rot_mat, K) + module.bias.data = bias + if hasattr(module, 'offload_params'): + module.offload_params(module) + + for name, indexes in region.sinks.items(): + module = region.get_module_from_name(name) + if hasattr(module, 'allocate_params'): + module.allocate_params(module) + axis = _get_input_axis(module) + weight = module.weight.data + + if axis == 1: + weight = rot_func(weight, rot_mat, K) + elif axis == 0: + weight = rot_func(weight.t(), rot_mat, K).t() + else: + raise RuntimeError("Not supported yet") + + module.weight.data = weight + if hasattr(module, 'offload_params'): + module.offload_params(module) + + if insert_rotation_module and len(region.srcs) == 0: + # print(name, module.in_features, K) + rewriter = ModuleInstanceToModuleInstance( + module, RotatedModule(had_mat=rot_mat, k=K, layer=module)) + rewriters.append(rewriter) + for r in rewriters: + model = r.apply(model) + return rewriters + + +def _replace_bias(next_module, new_bias): + new_bias = new_bias.view(-1) + if next_module.bias is not None: + next_module.bias.data.copy_(new_bias) + else: + new_bias = new_bias.to(next_module.weight.device).to(next_module.weight.dtype) + next_module.register_parameter('bias', nn.Parameter(new_bias)) + + +def _merge_ln(layer_norm, next_module, scale_bias_by_weight): + view_shape = (1, -1) + # Merge weight + if scale_bias_by_weight and hasattr(layer_norm, 'bias'): + layer_norm.bias.data /= layer_norm.weight.data + # We can't do an inplace update as some layers we merge into like lm_head might share the weight tensor + scale = layer_norm.weight.data.view(view_shape).expand_as(next_module.weight) + next_module.weight = nn.Parameter(next_module.weight.clone() * scale) + + # Merge bias, new_bias includes the bias of next_module by going through its fwd + if hasattr(layer_norm, 'bias'): + inp = layer_norm.bias.data.view(view_shape) + new_bias = next_module(inp) + _replace_bias(next_module, new_bias) + + +class 
RotationEqualization(GraphTransform): + + def __init__(self) -> None: + super(RotationEqualization, self).__init__() + + def find_module(self, model, regions: List, prefix=''): + """ + Iterate through the model looking at immediate children of every module to look for supported modules. + This allows us to stop the search when we meet a top-level module that is supported. + """ + if isinstance(model, self.supported_sinks): + if self.blacklist_layers is not None and prefix in self.blacklist_layers: + return + weight = get_weight_sink(model) + eq_indexes = EqualizationIndexes(0, weight.shape[0], 0) + region = Region(sinks={'sinks0': eq_indexes}, name_to_module={'sinks0': model}) + regions.append(region) + else: + for name, module in model.named_children(): + full_name = prefix + '.' + name if prefix != '' else name + self.find_module(module, regions, full_name) + + +class GraphRotationEqualization(RotationEqualization): + + def __init__( + self, + blacklist_layers: Optional[List[str]] = None, + orphan_sink: bool = False, + rotate_matmul: bool = False, + full_rotation_method: str = 'had', + return_rewriters: bool = False) -> None: + super(GraphRotationEqualization, self).__init__() + + self.supported_srcs = (nn.Linear, nn.Embedding) + self.supported_sinks = (nn.Linear) + common_scale_invariant = list(_scale_invariant_layers) + common_scale_invariant.remove(torch.nn.ReLU) + common_scale_invariant.remove(torch.nn.LeakyReLU) + self.scale_invariant_layers = tuple(common_scale_invariant) + (RMSNorm,) + self.scale_invariant_function = () + self.blacklist_layers = blacklist_layers + self.orphan_sink = orphan_sink + self.rotate_matmul = rotate_matmul + self.full_rotation_method = full_rotation_method + self.return_rewriters = return_rewriters + + def rotate_matmuls(self, graph_module): + matmul_nodes = list(graph_module.graph.nodes) + matmul_nodes = [c for c in matmul_nodes if c.name == 'matmul'] + for node in matmul_nodes: + with graph_module.graph.inserting_before(node): + matmul_arg0 = graph_module.graph.call_function( + functional_rotate_input, args=(node.args[0],)) + matmul_arg1 = graph_module.graph.call_function( + functional_rotate_input, args=(node.args[1],), kwargs={'transpose': True}) + args = list(node.args) + args[0] = matmul_arg0 + args[1] = matmul_arg1 + node.args = tuple(args) + + graph_module.recompile() + graph_module.graph.lint() + + def apply(self, + graph_model: GraphModule) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: + + regions = _extract_regions( + graph_model, + state_impl_kwargs={ + 'supported_srcs': self.supported_srcs, + 'supported_sinks': self.supported_sinks, + 'scale_invariant_layers': self.scale_invariant_layers, + 'scale_invariant_function': self.scale_invariant_function}) + eq_layers = set() + orphan_regions = [] + self.find_module(graph_model, orphan_regions) + for r in regions: + id_list = [id(r.name_to_module[sink_name]) for sink_name in r.sinks_names] + eq_layers.update(id_list) + if self.orphan_sink: + for o_r in orphan_regions: + # Layerwise have only a single sink named 'sinks0' + id_sink = id(o_r.get_module_from_name('sinks0')) + if id_sink not in eq_layers: + regions.append(o_r) + if self.rotate_matmul: + self.rotate_matmuls(graph_model) + if len(regions) > 0: + rewriters = _apply_rotate(graph_model, regions, self.full_rotation_method) + if self.return_rewriters: + return graph_model, rewriters + else: + return graph_model + + +class LayerNormToRMS(GraphTransform): + + def __init__(self, return_rewriters=False) -> None: + super(LayerNormToRMS, 
self).__init__() + self.supported_srcs = (nn.Linear, nn.Embedding) + self.supported_sinks = (nn.LayerNorm) + self.return_rewriters = return_rewriters + assert RMSNorm is not object, 'Update your Pytorch version to 2.4+' + + def apply(self, graph_model: GraphModule) -> GraphModule: + regions = _extract_regions( + graph_model, + state_impl_kwargs={ + 'supported_srcs': self.supported_srcs, 'supported_sinks': self.supported_sinks}) + + rewriters = [] + if len(regions) > 0: + for region in regions: + for src in region.srcs: + linear = region.get_module_from_name(src) + if isinstance(linear, torch.nn.Embedding): + dim = -1 + else: + dim = -2 + linear_dtype = linear.weight.data.dtype + W_ = linear.weight.data.double() + linear.weight.data = W_ - W_.mean(dim=dim, keepdim=True) + linear.weight.data = linear.weight.data.to(linear_dtype) + if hasattr(linear, 'bias') and linear.bias is not None: + b_ = linear.bias.data.double() + linear.bias.data = b_ - b_.mean() + linear.bias.data = linear.bias.data.to(linear_dtype) + for sink in region.sinks: + layer_norm = region.get_module_from_name(sink) + del layer_norm.bias + layer_norm_dtype = layer_norm.weight.data.dtype + rewriters.append( + ModuleToModuleByInstance(layer_norm, RMSNorm, dtype=layer_norm_dtype)) + for r in rewriters: + graph_model = r.apply(graph_model) + if self.return_rewriters: + return graph_model, rewriters + else: + return graph_model + + +class MergeLnAffine(GraphTransform): + + def __init__(self) -> None: + super(MergeLnAffine, self).__init__() + self.supported_srcs = (RMSNorm, nn.LayerNorm) + self.supported_sinks = (nn.Linear) + + def apply(self, graph_model: GraphModule) -> GraphModule: + regions = _extract_regions( + graph_model, + state_impl_kwargs={ + 'supported_srcs': self.supported_srcs, 'supported_sinks': self.supported_sinks}) + + if len(regions) > 0: + scaled_biases = set() + for region in regions: + layernorm_module_name = next(iter(region.srcs)) + layernorm_module = region.get_module_from_name(layernorm_module_name) + if not layernorm_module.elementwise_affine: + continue + for name, indexes in region.sinks.items(): + module = region.get_module_from_name(name) + scale_bias = id(module) not in scaled_biases + _merge_ln(layernorm_module, module, scale_bias_by_weight=scale_bias) + + scaled_biases.add(id(module)) + layernorm_module.weight.data.fill_(1.) + if hasattr(layernorm_module, 'bias'): + layernorm_module.bias.data.fill_(0.) + return graph_model + + +class LayerwiseActivationRotation(RotationEqualization): + + def __init__(self, blacklist_layer=None): + super(GraphTransform, self).__init__() + + self.supported_sinks = (nn.Linear) + self.blacklist_layers = blacklist_layer + + def apply(self, model: nn.Module) -> nn.Module: + regions: List[Region] = [] + self.find_module(model, regions) + if len(regions) > 0: + _apply_rotate(model, regions) + return model diff --git a/src/brevitas/graph/hadamard.py b/src/brevitas/graph/hadamard.py new file mode 100644 index 000000000..235e22567 --- /dev/null +++ b/src/brevitas/graph/hadamard.py @@ -0,0 +1,168 @@ +# This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot). +# Licensed under Apache License 2.0. 
+ +import math +import os +import pathlib + +try: + import fast_hadamard_transform +except: + fast_hadamard_transform = None +import torch + +# Adapted from https://github.com/Cornell-RelaxML/quip-sharp/blob/main/lib/utils/matmul_had.py + + +def get_hadK(n, transpose=False): + parent = pathlib.Path(os.path.abspath(__file__)).parent + # hadamard matrices for had12, had36.pal2, had52,will, + # # had60.pal, had108.pal, had140.pal, had156.will, had172.will: + # http://www.neilsloane.com/hadamard/index.html + tensors = torch.load(str(parent) + '/hadamard_tensors.pt') + tensors = {k: v.to(torch.float) for k, v in tensors.items()} + hadK, K = None, None + if n % 172 == 0: # llama-2-7b up + assert (is_pow2(n // 172)) + K = 172 + hadK = tensors['get_had172'].T if transpose else tensors['get_had172'] + elif n % 156 == 0: # llama-1-30b 3x hidden + assert (is_pow2(n // 156)) + K = 156 + hadK = tensors['get_had156'].T if transpose else tensors['get_had156'] + elif n % 140 == 0: # llama-1-30b intermediate + assert (is_pow2(n // 140)) + K = 140 + hadK = tensors['get_had140'].T if transpose else tensors['get_had140'] + elif n % 108 == 0: # llama-1-13b intermediate + assert (is_pow2(n // 108)) + K = 108 + hadK = tensors['get_had108'].T if transpose else tensors['get_had108'] + elif n % 60 == 0: # llama-1-13b 3x hidden + assert (is_pow2(n // 60)) + K = 60 + hadK = tensors['get_had60'].T if transpose else tensors['get_had60'] + elif n % 52 == 0: # llama-1-13b 1x hidden + assert (is_pow2(n // 52)) + K = 52 + hadK = tensors['get_had52'].T if transpose else tensors['get_had52'] + elif n % 36 == 0: + assert (is_pow2(n // 36)) + K = 36 + hadK = tensors['get_had36'].T if transpose else tensors['get_had36'] + elif n % 28 == 0: + assert (is_pow2(n // 28)) + K = 28 + hadK = tensors['get_had28'].T if transpose else tensors['get_had28'] + elif n % 40 == 0: + assert (is_pow2(n // 40)) + K = 40 + hadK = tensors['get_had40'].T if transpose else tensors['get_had40'] + elif n % 20 == 0: + assert (is_pow2(n // 20)) + K = 20 + hadK = tensors['get_had20'].T if transpose else tensors['get_had20'] + elif n % 12 == 0: + assert (is_pow2(n // 12)) + K = 12 + hadK = tensors['get_had12'].T if transpose else tensors['get_had12'] + else: + assert (is_pow2(n)) + K = 1 + + return hadK, K + + +def matmul_hadU(X, transpose=False): + n = X.shape[-1] + hadK, K = get_hadK(n, transpose) + input = X.clone().view(-1, n, 1) + output = input.clone() + while input.shape[1] > K: + input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2]) + output = output.view(input.shape) + output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :] + output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :] + output = output.view(input.shape[0], input.shape[1], -1) + (input, output) = (output, input) + del output + + if K > 1: + # Do not explicitly repeat - OOM + # input = torch.bmm( + # hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input) + # Use bcast instead + input = hadK.view(1, K, K).to(input) @ input + + return input.view(X.shape) / torch.tensor(n).sqrt() + + +def matmul_hadUt(X): + return matmul_hadU(X, transpose=True) + + +def random_hadamard_matrix(size, device): + # See https://github.com/Cornell-RelaxML/quip-sharp , Section "Randomized Hadamard Transformation" + Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64) + Q = Q * 2 - 1 + Q = torch.diag(Q) + return matmul_hadU(Q).to(device) + + +def matmul_hadU_cuda(X, hadK, K): + n = X.shape[-1] + if K == 1: + return fast_hadamard_transform.hadamard_transform( + 
X.contiguous(), 1.0 / torch.tensor(n).sqrt()) + # if transpose: + # hadK = hadK.T.contiguous() + input = X.view(*X.shape[:-1], K, n // K) + input = fast_hadamard_transform.hadamard_transform( + input.contiguous(), 1.0 / torch.tensor(n).sqrt()) + input = hadK.to(input.device).to(input.dtype) @ input + return input.reshape(X.shape) + + +def matmul_hadUt_cuda(X, hadK, K): + return matmul_hadU_cuda(X, hadK, K, transpose=True) + + +def apply_exact_had_to_linear(module, had_dim=-1, output=False): + assert isinstance(module, torch.nn.Linear) + in_features, out_features = module.in_features, module.out_features + + if had_dim != -1: + assert is_pow2(had_dim), "Hadamard dimension must be a power of 2!" + + W_ = module.weight.data + dtype = W_.dtype + dev = W_.device + init_shape = W_.shape + W_ = W_.float().cuda() + + if had_dim == -1: + if output: + had_K, K = get_hadK(out_features) + W_ = matmul_hadU_cuda(W_.t(), had_K, K).t() + if not output: + had_K, K = get_hadK(in_features) + W_ = matmul_hadU_cuda(W_, had_K, K) + else: + # Apply Hadamard to the last had_dim chunks of the weights + if output: + W_ = W_.t() + transposed_shape = W_.shape + W_ = fast_hadamard_transform.hadamard_transform( + W_.reshape(-1, transposed_shape[-1] // had_dim, had_dim), + scale=1 / math.sqrt(had_dim)).reshape(transposed_shape).t() + else: + raise NotImplementedError("Not implemented (or tested) yet!") + n = W_.shape[1] + W_ = hadamard_transform( + W_.reshape(-1, n // had_dim, had_dim), + scale=1 / math.sqrt(had_dim)).reshape(init_shape) + module.weight.data = W_.to(device=dev, dtype=dtype) + + +def is_pow2(n): + return (n & (n - 1) == 0) and (n > 0) diff --git a/src/brevitas/graph/hadamard_tensors.pt b/src/brevitas/graph/hadamard_tensors.pt new file mode 100644 index 000000000..e45a68538 Binary files /dev/null and b/src/brevitas/graph/hadamard_tensors.pt differ diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py index 7093c8c17..8413a8208 100644 --- a/src/brevitas/nn/equalized_layer.py +++ b/src/brevitas/nn/equalized_layer.py @@ -2,8 +2,16 @@ import torch +from brevitas.graph.hadamard import get_hadK +from brevitas.graph.hadamard import matmul_hadU +from brevitas.graph.hadamard import matmul_hadU_cuda from brevitas.nn.quant_mha import QuantMultiheadAttention +try: + import fast_hadamard_transform +except: + fast_hadamard_transform = None + INPUT_NAMES = ['input', 'inp', 'query', 'x', 'hidden_states'] @@ -41,3 +49,45 @@ def forward(self, *args, **kwargs): # We convert everything to args so that hooks can work correctly out = self.layer(*kwargs.values()) return out + + +class RotatedModule(torch.nn.Module): + + def __init__(self, layer, had_mat=None, k=None) -> None: + super().__init__() + if had_mat is not None: + self.had_mat = torch.nn.Parameter(had_mat).cpu() + else: + self.had_mat = None + self.layer = layer + self.k = k + + def forward(self, inp, **kwargs): + is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None + if is_cuda and fast_hadamard_transform is not None: + if self.had_mat is None or self.k is None: + had_K, K = get_hadK(inp.shape[-1]) + else: + had_K = self.had_mat + K = self.k + inp = matmul_hadU_cuda(inp, had_K, K) + else: + inp = matmul_hadU(inp) + o = self.layer(inp) + + return o + + +def functional_rotate_input(inp, transpose=False): + is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None + if transpose: + inp = inp.t() + if is_cuda and fast_hadamard_transform is not None: + had_K, K = get_hadK(inp.shape[-1]) + inp = 
matmul_hadU_cuda(inp, had_K, K) + else: + inp = matmul_hadU(inp) + + if transpose: + inp = inp.t() + return inp diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md index 5cd067e64..de55258db 100644 --- a/src/brevitas_examples/llm/README.md +++ b/src/brevitas_examples/llm/README.md @@ -16,6 +16,7 @@ Set the env variable BREVITAS_JIT=1 to speed up the quantization process. Curren ```bash usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES] [--seqlen SEQLEN] [--eval] [--dataset {wikitext2,c4}] + [--gpxq-block-name GPXQ_BLOCK_NAME] [--weight-bit-width WEIGHT_BIT_WIDTH] [--weight-param-method {stats,mse,hqo}] [--weight-scale-precision {float_scale,po2_scale}] @@ -53,7 +54,10 @@ options: --seqlen SEQLEN Sequence length. Default: 2048. --eval Eval model PPL on the chosen Dataset. --dataset {wikitext2,c4} - Dataset to use for quantization (default: wikitext2) + Dataset to use for quantization (default: c4) + --gpxq-block-name GPXQ_BLOCK_NAME + Block name for faster GPxQ optimization. It works only + if FX is not needed (default: None) --weight-bit-width WEIGHT_BIT_WIDTH Weight bit width. Default: 8. --weight-param-method {stats,mse,hqo} @@ -121,6 +125,7 @@ options: --act-calibration Apply activation calibration. --bias-corr Apply bias correction. --ln-affine-merge Merge LN affine params. + --replace-rmsnorm Replace HF RMSNorms with Torch one. --no-quantize Disable quantization. --no-float16 Disable float16 as base datatype and switch to float32. @@ -129,6 +134,12 @@ options: --weight-equalization Apply weight equalization. Relevant to ReLU based models (e.g. OPT). + --graph-rotation Apply graph rotation equalization + --graph-rotation-mode {had,ort} + If GraphRotation is enabled, decide how to compute the + random rotation matrix that is fully fused. Online or + partial rotation will always be Hadamard + --layerwise-rotation Apply layerwise rotation equalization --act-equalization {None,layerwise,fx} Apply activation equalization (SmoothQuant). 
Layerwise introduces standalone mul nodes,while fx merges them diff --git a/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py b/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py index 7ac39347f..75e31826a 100644 --- a/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py +++ b/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py @@ -3,13 +3,34 @@ SPDX-License-Identifier: MIT """ +from packaging import version import torch from torch import nn -from brevitas.graph.equalize import _is_reshaping_op +from brevitas import torch_version +from brevitas.graph.base import ModuleToModuleByClass from brevitas.graph.equalize import _is_scale_invariant_module +from brevitas.graph.equalize import LayerNormToRMS +from brevitas.graph.equalize import MergeLnAffine from brevitas.graph.utils import get_module -from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32 + + +def replace_rmsnorm_with_torch(model, config): + assert torch_version >= version.parse('2.4'), "torch.nn.RMSNorm requires torch 2.4 or greater" + set_of_layers = set(type(x) for x in model.modules() if 'RMS' in type(x).__name__) + dtype = next(model.parameters()).dtype + rewriters = [ + ModuleToModuleByClass( + rms_cls, + torch.nn.RMSNorm, + normalized_shape=config.hidden_size, + eps=config.rms_norm_eps, + dtype=dtype) for rms_cls in set_of_layers] + dtype = next(iter(model.parameters())).dtype + for r in rewriters: + model = r.apply(model) + model = model.to(dtype) + return model def replace_bias(next_module, new_bias): @@ -49,7 +70,7 @@ def merge_layernorm_affine_params(graph_model): module = get_module(graph_model, node.target) if isinstance(module, nn.LayerNorm): for next in node.users: - while (_is_reshaping_op(next) or _is_scale_invariant_module(graph_model, next)): + while (_is_scale_invariant_module(graph_model, next)): next = node.next if next.op == 'call_module': next_module = get_module(graph_model, next.target) @@ -83,8 +104,13 @@ def merge_layernorm_affine_params(graph_model): @torch.no_grad() -def apply_layernorm_affine_merge(graph_model, dtype): - # We can't do fp16 tracing on CPU as many kernels are not implemented - # So we have to cast to fp32 first, trace, apply merging, and then cast back - with cast_to_float32(graph_model, dtype): - merge_layernorm_affine_params(graph_model) +def apply_layernorm_affine_merge(graph_model): + eq = MergeLnAffine() + graph_model = eq.apply(graph_model) + return graph_model + + +@torch.no_grad() +def apply_layernorm_to_rmsnorm(graph_model, return_rewriters=False): + eq = LayerNormToRMS(return_rewriters) + return eq.apply(graph_model) diff --git a/src/brevitas_examples/llm/llm_quant/run_utils.py b/src/brevitas_examples/llm/llm_quant/run_utils.py index bdc1b3a1e..44ba711a5 100644 --- a/src/brevitas_examples/llm/llm_quant/run_utils.py +++ b/src/brevitas_examples/llm/llm_quant/run_utils.py @@ -32,11 +32,13 @@ from brevitas.fx.value_tracer import ValueProxy -def get_fx(model): +def get_fx(model, is_export=True): forward_signature = inspect.signature(model.forward).parameters if all(input_name in forward_signature for input_name in ["input_ids", "attention_mask", "past_key_values"]): input_names = ["input_ids", "attention_mask", "past_key_values"] + if not is_export: + input_names.remove('past_key_values') else: raise ValueError( f"Quantization with an FX graph is currently only supported for models taking `input_ids`, `attention_mask` and `past_key_values` as inputs. 
The model only has the following inputs: {forward_signature}"
         )
@@ -106,3 +108,17 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         args, kwargs = tree_map(self.cast_cpu_float32, (args, kwargs))
         out = func(*args, **kwargs)
         return out
+
+
+# This function remaps rewriters so they match modules in a potentially different model that shares the same underlying tensors.
+# We rely on the fact that two versions of the same model (eager vs FX) might have different module ids (id(fx_module) != id(eager_module)).
+# However, the underlying tensors are still shared, so we can reconstruct the mapping between the two modules.
+def fix_rewriter(rewriters, old_model_ref, tensor_name):
+    for r in rewriters:
+        tensor_id = id(r.old_module_instance.weight)
+        module = [
+            m for m in old_model_ref.modules()
+            if hasattr(m, tensor_name) and id(m.weight) == tensor_id]
+        r.old_module_instance = module[0]
+    return rewriters
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index f40a367e1..f0613b3f6 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 import argparse
+from copy import deepcopy
 import sys
 from warnings import warn
 
@@ -11,10 +12,14 @@
 import torch
 from transformers import AutoModelForCausalLM
 from transformers import AutoTokenizer
+from transformers.utils.fx import _SUPPORTED_MODELS
 
 from brevitas.export import export_torch_qcdq
 from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
+from brevitas.graph.equalize import GraphRotationEqualization
+from brevitas.graph.equalize import LayerwiseActivationRotation
 from brevitas.graph.quantize import layerwise_quantize
+from brevitas.graph.utils import get_module
 from brevitas_examples.common.accelerate_utils.accelerate import offload_model
 from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks
 from brevitas_examples.common.generative.quantize import generate_quant_maps
@@ -30,12 +35,42 @@
 from brevitas_examples.llm.llm_quant.gpxq import apply_gpfq
 from brevitas_examples.llm.llm_quant.gpxq import apply_gptq
 from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge
+from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_to_rmsnorm
+from brevitas_examples.llm.llm_quant.ln_affine_merge import replace_rmsnorm_with_torch
 from brevitas_examples.llm.llm_quant.prepare_for_quantize import add_zero_bias_to_linear
 from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers
 from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32
+from brevitas_examples.llm.llm_quant.run_utils import fix_rewriter
 from brevitas_examples.llm.llm_quant.run_utils import get_fx
+
+
+def set_seed(seed):
+    np.random.seed(seed)
+    torch.random.manual_seed(seed)
+
+
+def fused_rotation_no_fx(model, calibration_loader, args):
+    with torch.no_grad():
+        new_model, guards = torch._dynamo.export(model)(**calibration_loader[0])
+    apply_layernorm_affine_merge(new_model)
+    new_model, rewriters = apply_layernorm_to_rmsnorm(new_model, return_rewriters=True)
+    rewriters = fix_rewriter(rewriters, model, 'weight')
+
+    for r in rewriters:
+        r.apply(model)
+    new_model = offload_model(new_model)
+    eq = GraphRotationEqualization(
+        orphan_sink=args.rotation_orphan_sink,
+        full_rotation_method=args.graph_rotation_mode,
+        return_rewriters=True)
+    new_model, rewriters = eq.apply(new_model)
+    rewriters = fix_rewriter(rewriters, model, 'weight')
+
+    for r in rewriters:
+        r.apply(model)
+    remove_hooks(new_model)
+
+
 def set_seed(seed):
     np.random.seed(seed)
     torch.random.manual_seed(seed)
@@ -69,6 +104,10 @@ def model_export(model, ref_input, args):
 def validate(args):
+    if args.graph_rotation == 'fx':
+        assert args.ln_affine_merge, 'Graph rotation requires merging LN/RMS norm affine parameters'
+        assert args.replace_rmsnorm, 'Graph rotation requires replacing HF RMSNorms with the PyTorch one (torch 2.4+ required)'
+        assert args.convert_layernorm_to_rmsnorm, 'Graph rotation requires converting LayerNorm to RMSNorm'
     if not args.no_quantize:
         if args.gptq and args.gpfq:
             warn("Both GPTQ and GPFQ are enabled.")
@@ -157,7 +196,7 @@ def main(args):
         with CastFloat16ToFloat32():
             apply_awq(model, awq_results)
 
-    require_fx = True if args.weight_equalization or args.act_equalization == 'fx' or args.ln_affine_merge else False
+    require_fx = True if args.weight_equalization or args.act_equalization == 'fx' or args.ln_affine_merge or args.convert_layernorm_to_rmsnorm else False
 
     # Load the data for calibration and evaluation.
     calibration_loader = get_dataset_for_model(
@@ -168,7 +207,7 @@ def main(args):
         seqlen=args.seqlen,
         split="train",
         seed=args.seed,
-        require_fx=require_fx,
+        require_fx=require_fx and args.export_target is not None,
         device=None,
         fuse_sequences=args.fuse_sequences)
 
@@ -180,7 +219,7 @@ def main(args):
         seqlen=args.seqlen,
         split="validation",
         seed=args.seed,
-        require_fx=require_fx,
+        require_fx=require_fx and args.export_target is not None,
         device=None,
         fuse_sequences=args.fuse_sequences)
 
@@ -196,8 +235,15 @@ def main(args):
         remove_hooks(model)
         print(f"Float perplexity ({args.dataset}): {float_ppl:.3f}")
 
+    if args.replace_rmsnorm:
+        model = replace_rmsnorm_with_torch(model, model.config)
+
     if require_fx:
-        model = get_fx(model)
+        if model.__class__.__name__ in _SUPPORTED_MODELS and not args.replace_rmsnorm:
+            model = get_fx(model, is_export=args.export_target is not None)
+        else:
+            with torch.no_grad():
+                model, guards = torch._dynamo.export(model)(**calibration_loader[0])
         # Blockwise optimization does not work with FX at the moment
         args.gpxq_block_name = None
@@ -205,9 +251,28 @@ def main(args):
     # since currently there is support only for merging into Linear
     if args.ln_affine_merge:
         print("Apply LN affine merge...")
-        apply_layernorm_affine_merge(model, dtype)
+        apply_layernorm_affine_merge(model)
         print("LN affine merge applied.")
 
+    if args.convert_layernorm_to_rmsnorm:
+        print("Convert LayerNorm to RMSNorm...")
+        apply_layernorm_to_rmsnorm(model)
+        print("LayerNorm to RMSNorm applied.")
+
+    if args.graph_rotation == 'fx':
+        assert args.ln_affine_merge
+        assert args.replace_rmsnorm
+        model = offload_model(model)
+        eq = GraphRotationEqualization(
+            orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.graph_rotation_mode)
+        model = eq.apply(model)
+        remove_hooks(model)
+    elif args.graph_rotation == 'layerwise':
+        eq = LayerwiseActivationRotation()
+        model = eq.apply(model)
+    elif args.graph_rotation == 'fused_no_fx':
+        fused_rotation_no_fx(model, calibration_loader, args)
+
     # Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing
     # with all the variability in HF implementations
     if args.replace_mha:
@@ -268,7 +333,20 @@ def main(args):
         input_quant_format=args.input_quant_format,
         quantize_embedding=False)
     if not args.quantize_last_layer:
-        name_blacklist += ["lm_head", "embed_out"]
+        if require_fx:
+            last_node = [node for node in model.graph.nodes if node.op == 'call_module'][-1]
+            last_module = get_module(model, last_node.target)
+            last_layer_kwargs = layer_map[type(last_module)][1]
+            prev_weight_quant = deepcopy(last_layer_kwargs['weight_quant'])
+            prev_input_quant = deepcopy(last_layer_kwargs['input_quant'])
+            weight_quant = lambda module: prev_weight_quant if id(module) != id(
+                last_module) else None
+            input_quant = lambda module: prev_input_quant if id(module) != id(
+                last_module) else None
+            last_layer_kwargs['weight_quant'] = weight_quant
+            last_layer_kwargs['input_quant'] = input_quant
+        else:
+            name_blacklist += ["lm_head", "embed_out"]
     model = layerwise_quantize(
         model=model, compute_layer_map=layer_map, name_blacklist=name_blacklist)
     # Tie back first/last layer weights in case they got untied
@@ -497,6 +575,10 @@ def parse_args(args):
         '--act-calibration', action='store_true', help='Apply activation calibration.')
     parser.add_argument('--bias-corr', action='store_true', help='Apply bias correction.')
     parser.add_argument('--ln-affine-merge', action='store_true', help='Merge LN affine params.')
+    parser.add_argument(
+        '--convert-layernorm-to-rmsnorm', action='store_true', help='Convert LayerNorm to RMSNorm.')
+    parser.add_argument(
+        '--replace-rmsnorm', action='store_true', help='Replace HF RMSNorms with Torch one.')
     parser.add_argument('--no-quantize', action='store_true', help='Disable quantization.')
     parser.add_argument(
         '--no-float16',
@@ -510,6 +592,25 @@ def parse_args(args):
         '--weight-equalization',
         action='store_true',
         help='Apply weight equalization. Relevant to ReLU based models (e.g. OPT).')
+    parser.add_argument(
+        '--graph-rotation',
+        type=str,
+        default=None,
+        choices=['fx', 'layerwise', 'fused_no_fx'],
+        help='Apply graph rotation equalization')
+    parser.add_argument(
+        '--graph-rotation-mode',
+        default='had',
+        choices=['had', 'ort'],
+        help=
+        'If GraphRotation is enabled, decide how to compute the random rotation matrix that is fully fused. Online or partial rotation will always be Hadamard'
+    )
+    parser.add_argument(
+        '--rotation-orphan-sink',
+        action="store_true",
+        help=
+        'If GraphRotation is enabled, decide whether to add standalone Hadamard matrices for the unfused layers'
+    )
     parser.add_argument(
         '--act-equalization',
         default=None,
diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py
index 2719b48a0..035cdaadd 100644
--- a/tests/brevitas/graph/equalization_fixtures.py
+++ b/tests/brevitas/graph/equalization_fixtures.py
@@ -491,3 +491,40 @@ def forward(self, x):
 toy_quant_model = fixture_union(
     'toy_quant_model', list_of_quant_fixtures, ids=list_of_quant_fixtures)
+
+## List of Rotation fixtures
+
+
+@pytest_cases.fixture
+def linear_rms():
+
+    class LinearRMSModel(nn.Module):
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.linear = nn.Linear(3, 4, bias=True)
+            self.linear.weight.data.fill_(2.)
+            self.linear.bias.data.fill_(1.)
+            self.rms = nn.RMSNorm(4)
+            self.rms.weight.data = torch.randn_like(
+                self.rms.weight.data)  # Change learned parameters
+            self.linear_1 = nn.Linear(4, 8, bias=False)
+            self.linear_1.weight.data.fill_(2.)
+            self.linear_2 = nn.Linear(8, 8, bias=False)
+
+        def forward(self, x):
+            x = self.linear(x)
+            x = self.rms(x)
+            x = self.linear_1(x)
+            x = self.linear_2(x) * x
+            x = torch.matmul(x.flatten(1), x.flatten(1).t())
+
+            return x
+
+    return LinearRMSModel
+
+
+list_of_rotation_fixtures = ['linear_rms']
+
+rotation_fixtures = fixture_union(
+    'rotation_fixtures', list_of_rotation_fixtures, ids=list_of_rotation_fixtures)
diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py
index 89759b41a..afb8636e4 100644
--- a/tests/brevitas/graph/test_equalization.py
+++ b/tests/brevitas/graph/test_equalization.py
@@ -10,10 +10,14 @@
 from brevitas.graph.equalize import _batch_norm
 from brevitas.graph.equalize import _extract_regions
 from brevitas.graph.equalize import _is_supported_module
+from brevitas.graph.equalize import _supported_layers
 from brevitas.graph.equalize import activation_equalization_mode
+from brevitas.graph.equalize import GraphRotationEqualization
+from brevitas.graph.equalize import MergeLnAffine
 from brevitas.graph.standardize import DuplicateSharedStatelessModule
 from brevitas.graph.standardize import TorchFunctionalToModule
 from brevitas.graph.utils import get_module
+from tests.marker import requires_pt_ge
 
 from .equalization_fixtures import *
@@ -28,14 +32,14 @@ def test_resnet18_equalization():
     expected_out = model(inp)
 
     model_orig = copy.deepcopy(model)
-    regions = _extract_regions(model)
+    supported_sinks = list(_supported_layers)
+    supported_sinks = tuple([
+        x for x in _supported_layers if x not in (torch.nn.LayerNorm, *_batch_norm)])
+    regions = _extract_regions(model, state_impl_kwargs={'supported_sinks': supported_sinks})
     _ = equalize_test(
         regions, merge_bias=True, bias_shrinkage='vaiq', scale_computation_type='maxabs')
     out = model(inp)
 
-    # Check that equalization is not introducing FP variations
-    assert torch.allclose(expected_out, out, atol=ATOL)
-
     regions = sorted(regions, key=lambda region: sorted([r for r in region.srcs_names]))
     resnet_18_regions = sorted(RESNET_18_REGIONS, key=lambda region: region[0][0])
     equalized_layers = set()
@@ -58,6 +62,9 @@ def test_resnet18_equalization():
         orig_module = get_module(model_orig, layer)
         assert not torch.allclose(eq_module.weight, orig_module.weight)
 
+    # Check that equalization is not introducing FP variations
+    assert torch.allclose(expected_out, out, atol=ATOL)
+
 
 @pytest_cases.parametrize("merge_bias", [True, False])
 def test_equalization_torchvision_models(model_coverage: tuple, merge_bias: bool):
@@ -73,7 +80,10 @@ def test_equalization_torchvision_models(model_coverage: tuple, merge_bias: bool
     expected_out = model(inp)
 
-    regions = _extract_regions(model)
+    supported_sinks = list(_supported_layers)
+    supported_sinks = tuple([
+        x for x in _supported_layers if x not in (torch.nn.LayerNorm, *_batch_norm)])
+    regions = _extract_regions(model, state_impl_kwargs={'supported_sinks': supported_sinks})
     scale_factor_regions = equalize_test(
         regions, merge_bias=merge_bias, bias_shrinkage='vaiq', scale_computation_type='maxabs')
     shape_scale_regions = [scale.shape for scale in scale_factor_regions]
@@ -126,7 +136,10 @@ def test_models(toy_model, merge_bias, request):
     expected_out = model(inp)
 
     model = symbolic_trace(model)
-    regions = _extract_regions(model)
+    supported_sinks = list(_supported_layers)
+    supported_sinks = tuple([
+        x for x in _supported_layers if x not in (torch.nn.LayerNorm, *_batch_norm)])
+    regions = _extract_regions(model, state_impl_kwargs={'supported_sinks': supported_sinks})
scale_factor_regions = equalize_test( regions, merge_bias=merge_bias, bias_shrinkage='vaiq', scale_computation_type='maxabs') shape_scale_regions = [scale.shape for scale in scale_factor_regions] @@ -225,3 +238,41 @@ def test_act_equalization_torchvision_models(model_dict: dict, layerwise: bool): # Check that at least one region performs "true" equalization # If all shapes are scalar, no equalization has been performed assert any([shape != () for shape in shape_scale_regions]) + + +@requires_pt_ge('2.4') +@pytest_cases.parametrize('partial_had', [True, False]) +def test_models(rotation_fixtures, partial_had): + + in_shape = IN_SIZE_LINEAR + + model_class = rotation_fixtures + model = model_class() + inp = torch.ones(in_shape) + + model.eval() + penultimate_weight = model.linear_1.weight.data + last_weight = model.linear_2.weight.data + with torch.no_grad(): + expected_out = model(inp) + + model = symbolic_trace(model) + merge = MergeLnAffine() + model = merge.apply(model) + eq = GraphRotationEqualization(orphan_sink=partial_had) + model = eq.apply(model) + + with torch.no_grad(): + out = model(inp) + + penultimate_weight_new = model.linear_1.weight.data + + # Invariance of the output + assert torch.allclose(out, expected_out, atol=ATOL) + # Rotate weights must be different + assert not torch.allclose(penultimate_weight, penultimate_weight_new) + # Merging affine parameters of RMS + assert torch.allclose(model.rms.weight.data, torch.ones_like(model.rms.weight.data)) + if partial_had: + last_weight_new = model.linear_2.layer.weight.data + assert not torch.allclose(last_weight, last_weight_new) diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index 90981df29..105a1ea8b 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -5,15 +5,18 @@ from dataclasses import dataclass import logging import os +import platform import shutil import numpy as np import onnx +from packaging import version import pytest import pytest_cases import torch from brevitas import config +from brevitas import torch_version # LLM example depends on optimum-amd, which requires PyTorch>=2.2 from brevitas_examples.llm.main import main from brevitas_examples.llm.main import parse_args @@ -277,7 +280,9 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4): "mistral-fp8_fnuz", "llama-mxfp8", "llama-int8-act_equalization=layerwise", - "mistral-int8-quant-last-layer",], + "mistral-int8-quant-last-layer", + "llama-rotation-mixed-fx", + "llama-rotation-full-fx",], params=[ { "model": "hf-internal-testing/tiny-random-MistralForCausalLM", @@ -366,7 +371,35 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4): "model": "hf-internal-testing/tiny-random-MistralForCausalLM", "quantize_last_layer": True, "exp_layer_types": { - "lm_head": ""}},]) + "lm_head": ""}}, + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "ln_affine_merge": True, + "replace_rmsnorm": True, + "quantize_last_layer": True, + "no_quantize": True, + "rotation_orphan_sink": True, + "convert_layernorm_to_rmsnorm": True, + "graph_rotation": "fx", + "exp_layer_types": { + "L__self___model_layers_0_self_attn_k_proj": + "", + "L__self___model_layers_0_self_attn_o_proj": + ""}}, + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "ln_affine_merge": True, + "replace_rmsnorm": True, + "quantize_last_layer": True, + "no_quantize": True, + "rotation_orphan_sink": False, + "convert_layernorm_to_rmsnorm": True, + 
"graph_rotation": "fx", + "exp_layer_types": { + "L__self___model_layers_0_self_attn_k_proj": + "", + "L__self___model_layers_0_self_attn_o_proj": + ""}},]) def layer_args(default_run_args, request): args = default_run_args layer_dict = request.param @@ -381,6 +414,12 @@ def layer_args(default_run_args, request): def test_small_models_quant_layer(caplog, layer_args): caplog.set_level(logging.INFO) args, exp_layer_types = layer_args + if args.replace_rmsnorm: + if torch_version < version.parse('2.4'): + pytest.skip("Replacing RMSNorm requires torch 2.4+ or greater") + if hasattr(args, 'graph_rotation') and args.graph_rotation == 'fx' and platform.system( + ) == 'Windows': + pytest.skip("Skipping dynamo + windows") float_ppl, quant_ppl, model = validate_args_and_run_main(args) assert_layer_types(model, exp_layer_types)