From 9f9f2599b29d70898a4b46e05e3432d91ef14e5b Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Fri, 26 Jan 2024 02:54:22 -0800 Subject: [PATCH] Fix (equalize): align cross layer equalization with channel splitting --- src/brevitas/graph/equalize.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 580e4eb24..f5d8d5534 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -100,9 +100,9 @@ class EqualizationIndexes: # Required for being hashable @dataclass(eq=True, frozen=True) -class WeightBiasTuple: - weight: nn.Module = None - bias: nn.Module = None +class WeightBiasWrapper: + weight: torch.Tensor = None + bias: torch.Tensor = None # Required for being hashable @@ -359,16 +359,16 @@ def _combine_weights_bias( return weight_bias -def transpose(module: torch.nn.Module, axis: int): +def transpose(tensor: torch.Tensor, axis: int): """ - Given a module and an axis, this function re-arranges the module's weights so that the axis and + Given a tensor and an axis, this function re-arranges the tensor so that the axis and the first dimension are swapped. """ - shape = list(range(module.weight.ndim)) + shape = list(range(tensor.ndim)) axis = shape[axis] shape.insert(0, axis) del shape[axis + 1] - return module.weight.permute(shape) + return tensor.permute(shape) def _cross_layer_equalization( @@ -430,7 +430,7 @@ def _no_equalize(): # For MultiheadAttention, we support only self-attetion if isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is not None: # For sinks, we only need to modify the weight but not the bias - module = WeightBiasTuple(module.in_proj_weight) + module = WeightBiasWrapper(module.in_proj_weight) elif isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is None: return _no_equalize() sink_axes[name] = (module, axis) @@ -452,12 +452,12 @@ def _no_equalize(): # Check if any of the axis is None, which means that the module is not supported. # In that case, do not perform graph equalization - axes_to_check = [*src_axes.values(), *sink_axes.values()] + axes_to_check = [axis for _, axis in list(src_axes.values()) + list(sink_axes.values())] if None in axes_to_check: return _no_equalize() scale_fn = _select_scale_computation_fn(scale_computation_type) - sink_weights = {name: transpose(m, axis) for name, (m, axis) in sink_axes.items()} + sink_weights = {name: transpose(m.weight, axis) for name, (m, axis) in sink_axes.items()} srcs_range = -1 * torch.ones(max_shape_srcs, device=device, dtype=dtype) sinks_range = -1 * torch.ones(max_shape_sinks, device=device, dtype=dtype) for k, v in sink_weights.items(): @@ -480,17 +480,16 @@ def _no_equalize(): shape_0 = list_of_act_val_shapes[0] if any(shape_0 != shape for shape in list_of_act_val_shapes): return _no_equalize() - list_of_act_val = [ - transpose(WeightBiasTuple(act_val), act_axis) for act_val in list_of_act_val] + list_of_act_val = [transpose(act_val, act_axis) for act_val in list_of_act_val] srcs_range = scale_fn( torch.cat([act_val.reshape(act_val.size(0), -1) for act_val in list_of_act_val], 1)) else: if merge_bias: src_weights = { - name: _combine_weights_bias(transpose(m, axis), bias_shrinkage, m.bias) + name: _combine_weights_bias(transpose(m.weight, axis), bias_shrinkage, m.bias) for name, (m, axis) in src_axes.items()} else: - src_weights = {name: transpose(m, axis) for name, (m, axis) in src_axes.items()} + src_weights = {name: transpose(m.weight, axis) for name, (m, axis) in src_axes.items()} for k, v in src_weights.items(): # Srcs are always fully equalized, thus we simply need to apply the offset to position them # correctly with respect to the other srcs matrices. @@ -562,7 +561,7 @@ def _no_equalize(): def _update_weights(original_module, new_value, attr='weight'): - if isinstance(original_module, WeightBiasTuple): + if isinstance(original_module, WeightBiasWrapper): setattr(getattr(original_module, attr), 'data', new_value) else: setattr(original_module, attr, nn.Parameter(new_value)) @@ -645,7 +644,7 @@ def get_weight_sink(module): transpose = lambda weight, axis: weight if axis == 0 else weight.transpose(0, 1) if isinstance(module, nn.MultiheadAttention) and not hasattr(module, 'in_proj_weight'): raise RuntimeError("Configuration for Multiheadattention not supported") - weight = WeightBiasTuple(module.in_proj_weight).weight if isinstance( + weight = WeightBiasWrapper(module.in_proj_weight).weight if isinstance( module, nn.MultiheadAttention) else module.weight axis = _get_input_axis(module) weight = transpose(weight, axis)