diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py
index bbbbc27bb..d594c851d 100644
--- a/src/brevitas/graph/base.py
+++ b/src/brevitas/graph/base.py
@@ -296,6 +296,15 @@ class FnToModule(CallableToModule):
 
     def match_node(self, node: Node) -> bool:
         return node.op == 'call_function' and node.target is self.old_callable
 
+    def move_node_args_to_kwargs(self, node: Node):
+        super().move_node_args_to_kwargs(node)
+        # Moving to stateful modules, we remove the 'training' argument if it is passed to the
+        # functional version of the layer since it is not needed anymore
+        kwargs = dict(node.kwargs)
+        if 'training' in kwargs:
+            del kwargs['training']
+        node.kwargs = immutable_dict(kwargs)
+
 
 class MethodToModule(CallableToModule):
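Side note (not part of the patch): the minimal sketch below, built on plain torch.fx rather than Brevitas' value_trace, shows why the 'training' kwarg has to be dropped when FnToModule rewrites a functional call into a module. Tracing F.dropout records 'training' (and 'inplace') as node kwargs, but nn.Dropout keeps the training flag as module state and its constructor does not accept it. ToyDropoutModel is a made-up name used only for this illustration.

import torch.nn as nn
import torch.nn.functional as F
from torch.fx import symbolic_trace


class ToyDropoutModel(nn.Module):  # hypothetical module, only for illustration

    def forward(self, x):
        # The functional form receives the training flag explicitly at call time
        return F.dropout(x, p=0.1, training=self.training)


traced = symbolic_trace(ToyDropoutModel())
for node in traced.graph.nodes:
    if node.op == 'call_function' and node.target is F.dropout:
        # Prints kwargs including 'training' and 'inplace'; nn.Dropout() accepts 'p' and
        # 'inplace' but not 'training', hence the deletion in move_node_args_to_kwargs above
        print(node.kwargs)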
diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py
index 0012e9065..f753eba8c 100644
--- a/src/brevitas/graph/equalize.py
+++ b/src/brevitas/graph/equalize.py
@@ -27,8 +27,10 @@
 __all__ = [
     'GraphActivationEqualization', 'LayerwiseActivationEqualization', 'EqualizeGraph']
 
-EPSILON = 1e-9
-FLOAT16_EPSILON = 2e-5
+# TODO: if we are able to run activation equalization on GPU + float16, we could have two separate
+# epsilon factors for float16 (2e-5) vs float32/bfloat16 (1e-9). At the moment we are tied to one
+# single epsilon for both cases.
+EPSILON = 2e-5
 
 _supported_layers = (
     nn.ConvTranspose1d,
@@ -74,6 +76,8 @@
 
 _batch_norm = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
 
+_ignore_ops = (getattr, 'size', 'contiguous')
+
 
 # Required for being hashable
 @dataclass(eq=True, frozen=True)
@@ -280,8 +284,7 @@ def _combine_weights_bias(
     weight = weight.data.reshape(weight.shape[0], -1)
     bias = bias.reshape(-1, 1)
 
-    epsilon = FLOAT16_EPSILON if weight.dtype == torch.float16 else EPSILON
-    weight = torch.where(torch.abs(weight) < epsilon, torch.tensor(epsilon).type_as(weight), weight)
+    weight = torch.where(torch.abs(weight) < EPSILON, torch.tensor(EPSILON).type_as(weight), weight)
     factor = torch.abs(bias) / torch.abs(weight)
 
     # From https://github.com/Xilinx/Vitis-AI/blob/master/src/vai_quantizer/vai_q_pytorch/nndct_shared/optimization/commander.py#L450
@@ -336,7 +339,6 @@ def _cross_layer_equalization(
     # Determine device and type of tensors
     device = next(sinks[0].parameters()).device
     dtype = next(sinks[0].parameters()).dtype
-    epsilon = FLOAT16_EPSILON if dtype == torch.float16 else EPSILON
 
     # If equalization criteria are not met, we return a scalar one to indicate that no equalization
     # has been performed
@@ -401,7 +403,6 @@ def _no_equalize():
     scale_fn = _select_scale_computation_fn(scale_computation_type)
     sink_weights = [transpose(m, axis) for m, axis in sink_axes.items()]
     sinks_range = scale_fn(torch.cat([w.reshape(w.size(0), -1) for w in sink_weights], 1))
-    sinks_range = torch.clamp(sinks_range, epsilon)
 
     # Determine the srcs_range based on where we are performing activation equalization or
     # weight equalization
@@ -434,10 +435,16 @@ def _no_equalize():
                 "Detected source and sink with non compatible shapes, equalization is skipped")
             return _no_equalize()
 
+    # Instead of clipping very low values, which would cause their reciprocal to be very large,
+    # thus hindering quantization, we set them to one, which is the no-op equivalent for equalization
+    sinks_range = torch.where(
+        sinks_range > EPSILON, sinks_range, torch.tensor(1., dtype=dtype, device=device))
+    srcs_range = torch.where(
+        srcs_range > EPSILON, srcs_range, torch.tensor(1., dtype=dtype, device=device))
     srcs_range = torch.pow(srcs_range, alpha)
+    sinks_range = torch.pow(sinks_range, 1 - alpha)
     scaling_factors = srcs_range / sinks_range
-    scaling_factors = torch.clamp(scaling_factors, epsilon)
     inverse_scaling_factors = torch.reciprocal(scaling_factors)
 
     if list_of_act_val is not None and list_of_insert_mul_node_fn is not None:
@@ -458,8 +465,8 @@ def _no_equalize():
                 torch.reshape(inverse_scaling_factors, src_broadcast_size)),
             attr='weight')
     for module, axis in sink_axes.items():
-        src_broadcast_size = [1] * module.weight.ndim
-        src_broadcast_size[axis] = module.weight.size(axis)
+        sink_broadcast_size = [1] * module.weight.ndim
+        sink_broadcast_size[axis] = module.weight.size(axis)
         if isinstance(module, _batch_norm):
             # We re-compute the bias as function of running_mean and running_var to adjust the
             # additive factor for equalization.
@@ -469,7 +476,7 @@ def _no_equalize():
                 module, module.bias.clone() + additive_factor * (scaling_factors - 1), attr='bias')
         _update_weights(
             module,
-            module.weight.clone() * torch.reshape(scaling_factors, src_broadcast_size),
+            module.weight.clone() * torch.reshape(scaling_factors, sink_broadcast_size),
             attr='weight')
 
     return scaling_factors
@@ -578,6 +585,8 @@ def find_srcs(graph_model: GraphModule, starting_node: Node,
                 node.op == 'call_function' and node.target in _residual_fns):
             find_srcs(graph_model, node, state)
             find_sinks(graph_model, node, state)
+        elif node.target in _ignore_ops:
+            continue
         else:
             # If we meet an unrecognized op, we add None to invalidate the region
             state.srcs.add(_UNSUPPORTED_OP)
@@ -609,6 +618,8 @@ def find_sinks(graph_model: GraphModule, starting_node: Node,
                 node.op == 'call_function' and node.target in _residual_fns):
             find_sinks(graph_model, node, state)
             find_srcs(graph_model, node, state)
+        elif node.target in _ignore_ops:
+            continue
         else:
             # If we meet an unrecognized op, we add None to invalidate the region
             state.sinks.add(_UNSUPPORTED_OP)
@@ -764,16 +775,14 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k
         # Extra check for batch_dim
         if hasattr(x, 'names') and 'N' in x.names:
             batch_dim = x.names.index('N')
-            x = x.transpose(0, batch_dim)
 
         self.batch_dim_act_map[name] = batch_dim
 
+        input_scales = self.scale_fn(x, dim=batch_dim)
         if name not in self.float_act_map:
-            self.float_act_map[name] = self.scale_fn(x, dim=batch_dim)
+            self.float_act_map[name] = input_scales
         else:
-            batch_data = torch.cat([self.float_act_map[name].unsqueeze(batch_dim), x],
-                                   dim=batch_dim)
-            self.float_act_map[name] = self.scale_fn(batch_data, dim=batch_dim)
+            self.float_act_map[name] = torch.max(self.float_act_map[name], input_scales)
 
     def insert_mul_node(self, scale, shape, axis, region, batch_dim=0):
         broadcastable_shape = [1] * len(shape)
@@ -910,10 +919,8 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k
         # Extra check for batch_dim
         if hasattr(x, 'names') and 'N' in x.names:
             batch_dim = x.names.index('N')
-            x = x.transpose(0, batch_dim)
 
         self.batch_dim_act_map[name] = batch_dim
-
         if name not in self.float_act_map:
             self.float_act_map[name] = self.scale_fn(x, dim=batch_dim)
         else:
@@ -921,6 +928,12 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k
                                    dim=batch_dim)
             self.float_act_map[name] = self.scale_fn(batch_data, dim=batch_dim)
 
+        input_scales = self.scale_fn(x, dim=batch_dim)
+        if name not in self.float_act_map:
+            self.float_act_map[name] = input_scales
+        else:
+            self.float_act_map[name] = torch.max(self.float_act_map[name], input_scales)
+
     def insert_mul_node(self, scale, shape, axis, act_node, batch_dim=0):
         broadcastable_shape = [1] * len(shape)
         broadcastable_shape[axis] = shape[axis]
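Side note (not part of the patch, dtype/device arguments omitted): the short numeric sketch below contrasts the old clamp with the new torch.where guard. Clamping a near-degenerate scaling factor to EPSILON used to produce a reciprocal on the order of 1/EPSILON that was then multiplied into the source weights, hurting their quantization; replacing degenerate ranges with 1.0, the no-op value for equalization, keeps the factors moderate.

import torch

EPSILON = 2e-5  # the single epsilon value the patch settles on
alpha = 0.5
srcs_range = torch.tensor([0.5, 1e-12])  # second source channel has an almost-zero range
sinks_range = torch.tensor([0.25, 0.3])

# Old behaviour: clamp the scaling factors, then take the reciprocal for the sources
old_scaling = torch.clamp(
    torch.pow(srcs_range, alpha) / torch.pow(sinks_range, 1 - alpha), EPSILON)
print(torch.reciprocal(old_scaling))  # ~[0.71, 5.0e+04]: a huge multiplier on the source weights

# New behaviour: degenerate ranges are replaced with 1. before computing the factors
srcs_safe = torch.where(srcs_range > EPSILON, srcs_range, torch.tensor(1.))
sinks_safe = torch.where(sinks_range > EPSILON, sinks_range, torch.tensor(1.))
new_scaling = torch.pow(srcs_safe, alpha) / torch.pow(sinks_safe, 1 - alpha)
print(torch.reciprocal(new_scaling))  # ~[0.71, 0.55]: stays well-behaved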
diff --git a/src/brevitas/graph/standardize.py b/src/brevitas/graph/standardize.py
index 5ff3b6676..c3d667aee 100644
--- a/src/brevitas/graph/standardize.py
+++ b/src/brevitas/graph/standardize.py
@@ -107,7 +107,8 @@ class TorchFunctionalToModule(GraphTransform):
         nn.AvgPool1d), (F.avg_pool2d, nn.AvgPool2d), (F.avg_pool3d, nn.AvgPool3d),
                        (F.adaptive_avg_pool1d, nn.AdaptiveAvgPool1d),
                        (F.adaptive_avg_pool2d,
-                        nn.AdaptiveAvgPool2d), (F.adaptive_avg_pool3d, nn.AdaptiveAvgPool3d))
+                        nn.AdaptiveAvgPool2d), (F.adaptive_avg_pool3d,
+                        nn.AdaptiveAvgPool3d), (F.dropout, nn.Dropout))
 
     def __init__(self, fn_to_module_map=FN_TO_MODULE_MAP):
         super().__init__()
diff --git a/src/brevitas_examples/llm/llm_quant/equalize.py b/src/brevitas_examples/llm/llm_quant/equalize.py
index f3e4c3b0d..f736971a9 100644
--- a/src/brevitas_examples/llm/llm_quant/equalize.py
+++ b/src/brevitas_examples/llm/llm_quant/equalize.py
@@ -10,6 +10,8 @@
 from brevitas.fx.brevitas_tracer import value_trace
 from brevitas.graph.equalize import activation_equalization_mode
 from brevitas.graph.equalize import EqualizeGraph
+from brevitas.graph.standardize import DuplicateSharedStatelessModule
+from brevitas.graph.standardize import TorchFunctionalToModule
 from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn
 from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32
 
@@ -50,6 +52,8 @@ def apply_act_equalization(
     # So we have to cast to fp32 first, trace, apply equalization, and then cast back
     with cast_to_float32(model, dtype):
         graph_model = value_trace(model, value_args=ref_kwargs)
+        graph_model = TorchFunctionalToModule().apply(graph_model)
+        graph_model = DuplicateSharedStatelessModule().apply(graph_model)
         # TODO this is currently running on CPU. We need Accelerate or a TorchDispatchMode
         # or an FX interpreter to run it on GPU
         warnings.warn(
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 4d8b2c3ef..05dc06981 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -303,9 +303,14 @@ def main():
             quantize_embedding=args.quantize_embedding,
             seqlen=args.seqlen)
         # Tie back first/last layer weights in case they got untied
-        model.tie_weights()
         print("Model quantization applied.")
 
+    # If any equalization has taken place, the embedding layer and the fully connected one are
+    # not tied anymore, and they need to be treated as standalone, separate layers.
+    # In all other cases we can tie them back so as to preserve memory.
+    if args.act_equalization is None and not args.weight_equalization:
+        model.tie_weights()
+
     if args.act_calibration:
         print("Apply act calibration...")
         apply_calibration(model, calibration_loader, args.nsamples)
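Side note (not part of the patch): a rough end-to-end sketch of what the two extra standardization passes in apply_act_equalization accomplish, shown on a toy fx-traced module instead of a value_trace'd LLM. ToySubBlock is a made-up stand-in; the transform names come from the diff above. After the passes, the F.dropout call_function node becomes an nn.Dropout call_module node (with the 'training' kwarg dropped by the FnToModule change), so the graph only contains ops that the equalization region search can recognize or ignore.

import torch.nn as nn
import torch.nn.functional as F
from torch.fx import symbolic_trace

from brevitas.graph.standardize import DuplicateSharedStatelessModule
from brevitas.graph.standardize import TorchFunctionalToModule


class ToySubBlock(nn.Module):  # hypothetical stand-in for a transformer sub-block

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x):
        return F.dropout(self.fc(x), p=0.1, training=self.training)


graph_model = symbolic_trace(ToySubBlock())
graph_model = TorchFunctionalToModule().apply(graph_model)
graph_model = DuplicateSharedStatelessModule().apply(graph_model)
print(graph_model.graph)  # the dropout op now shows up as a call_module node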