diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index bbbbc27bb..d594c851d 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -296,6 +296,15 @@ class FnToModule(CallableToModule): def match_node(self, node: Node) -> bool: return node.op == 'call_function' and node.target is self.old_callable + def move_node_args_to_kwargs(self, node: Node): + super().move_node_args_to_kwargs(node) + # Moving to stateful modules, we remove the 'training' argument if it is passed to the + # functional version of the layer since it is not needed anymore + kwargs = dict(node.kwargs) + if 'training' in kwargs: + del kwargs['training'] + node.kwargs = immutable_dict(kwargs) + class MethodToModule(CallableToModule): diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index a9c492b97..174552241 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -64,6 +64,8 @@ _select_op = (operator.getitem, operator.__getitem__) +_reshaping_op = ('view', 'reshape', 'flatten', 'contiguous', torch.reshape, torch.flatten) + _scale_varying_activations = ( torch.nn.Sigmoid, torch.nn.Tanh, torch.nn.ReLU6, torch.nn.GELU, torch.nn.SiLU) @@ -73,6 +75,8 @@ _batch_norm = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) +_ignore_ops = (getattr, 'size') + # Required for being hashable @dataclass(eq=True, frozen=True) @@ -279,7 +283,8 @@ def _combine_weights_bias( weight = weight.data.reshape(weight.shape[0], -1) bias = bias.reshape(-1, 1) - weight = torch.where(torch.abs(weight) < EPSILON, torch.tensor(EPSILON).type_as(weight), weight) + weight = torch.where( + torch.abs(weight) <= EPSILON, torch.tensor(EPSILON).type_as(weight), weight) factor = torch.abs(bias) / torch.abs(weight) # From https://github.com/Xilinx/Vitis-AI/blob/master/src/vai_quantizer/vai_q_pytorch/nndct_shared/optimization/commander.py#L450 @@ -398,7 +403,6 @@ def _no_equalize(): scale_fn = _select_scale_computation_fn(scale_computation_type) sink_weights = [transpose(m, axis) for m, axis in sink_axes.items()] sinks_range = scale_fn(torch.cat([w.reshape(w.size(0), -1) for w in sink_weights], 1)) - sinks_range = torch.clamp(sinks_range, EPSILON) # Determine the srcs_range based on where we are performing activation equalization or # weight equalization @@ -431,10 +435,18 @@ def _no_equalize(): "Detected source and sink with non compatible shapes, equalization is skipped") return _no_equalize() + # Instead of clipping very low values, which would cause their reciprocal to be very large + # thus hindering quantization, we set both sources and sinks to one, + # which is the no-op equivalent for equalization. + channelwise_no_equalize = (sinks_range <= EPSILON) | (srcs_range <= EPSILON) + sinks_range = torch.where( + channelwise_no_equalize, torch.tensor(1., dtype=dtype, device=device), sinks_range) + srcs_range = torch.where( + channelwise_no_equalize, torch.tensor(1., dtype=dtype, device=device), srcs_range) + srcs_range = torch.pow(srcs_range, alpha) sinks_range = torch.pow(sinks_range, 1 - alpha) scaling_factors = srcs_range / sinks_range - scaling_factors = torch.clamp(scaling_factors, EPSILON) inverse_scaling_factors = torch.reciprocal(scaling_factors) if list_of_act_val is not None and list_of_insert_mul_node_fn is not None: @@ -455,8 +467,8 @@ def _no_equalize(): torch.reshape(inverse_scaling_factors, src_broadcast_size)), attr='weight') for module, axis in sink_axes.items(): - src_broadcast_size = [1] * module.weight.ndim - src_broadcast_size[axis] = module.weight.size(axis) + sink_broadcast_size = [1] * module.weight.ndim + sink_broadcast_size[axis] = module.weight.size(axis) if isinstance(module, _batch_norm): # We re-compute the bias as function of running_mean and running_var to adjust the # additive factor for equalization. @@ -466,7 +478,7 @@ def _no_equalize(): module, module.bias.clone() + additive_factor * (scaling_factors - 1), attr='bias') _update_weights( module, - module.weight.clone() * torch.reshape(scaling_factors, src_broadcast_size), + module.weight.clone() * torch.reshape(scaling_factors, sink_broadcast_size), attr='weight') return scaling_factors @@ -547,9 +559,7 @@ def _is_scale_invariant_function(node: Node) -> bool: def _is_reshaping_op(node: Node) -> bool: - return ( - node.op == 'call_function' and node.target in [torch.flatten, torch.reshape] or - node.op == 'call_method' and node.target in ['view', 'reshape', 'flatten']) + return node.target in _reshaping_op def find_srcs(graph_model: GraphModule, starting_node: Node, @@ -575,6 +585,8 @@ def find_srcs(graph_model: GraphModule, starting_node: Node, node.op == 'call_function' and node.target in _residual_fns): find_srcs(graph_model, node, state) find_sinks(graph_model, node, state) + elif node.target in _ignore_ops: + continue else: # If we meet an unrecognized op, we add None to invalidate the region state.srcs.add(_UNSUPPORTED_OP) @@ -606,6 +618,8 @@ def find_sinks(graph_model: GraphModule, starting_node: Node, node.op == 'call_function' and node.target in _residual_fns): find_sinks(graph_model, node, state) find_srcs(graph_model, node, state) + elif node.target in _ignore_ops: + continue else: # If we meet an unrecognized op, we add None to invalidate the region state.sinks.add(_UNSUPPORTED_OP) @@ -699,8 +713,6 @@ def find_module(self, model, regions: List): """ Iterate through the model looking at immediate children of every module to look for supported modules. This allows us to stop the search when we meet a top-level module that is supported. - Specifically, it allows to map nn.MultiheadAttetion to its quantized counterpart and not its - Linear submodules. """ if isinstance(model, _supported_layers) and not isinstance(model, _batch_norm + (nn.LayerNorm,)): @@ -713,7 +725,7 @@ def setup(self): for region in self.regions: batch_dim = 0 if hasattr(region, 'batch_first'): - batch_dim = 0 if region.batch_first == True else 1 + batch_dim = 0 if region.batch_first else 1 hook_fn = partial( self.forward_stats_hook, name=region, batch_dim=batch_dim, use_inp=True) @@ -761,16 +773,14 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k # Extra check for batch_dim if hasattr(x, 'names') and 'N' in x.names: batch_dim = x.names.index('N') - x = x.transpose(0, batch_dim) self.batch_dim_act_map[name] = batch_dim + input_scales = self.scale_fn(x, dim=batch_dim) if name not in self.float_act_map: - self.float_act_map[name] = self.scale_fn(x, dim=batch_dim) + self.float_act_map[name] = input_scales else: - batch_data = torch.cat([self.float_act_map[name].unsqueeze(batch_dim), x], - dim=batch_dim) - self.float_act_map[name] = self.scale_fn(batch_data, dim=batch_dim) + self.float_act_map[name] = torch.max(self.float_act_map[name], input_scales) def insert_mul_node(self, scale, shape, axis, region, batch_dim=0): broadcastable_shape = [1] * len(shape) @@ -832,7 +842,7 @@ def setup(self): for name in region.srcs + region.sinks: module = name_to_module[name] if hasattr(module, 'batch_first'): - batch_dim = 0 if module.batch_first == True else 1 + batch_dim = 0 if module.batch_first else 1 for name in region_to_search: act_module = name_to_module[name] use_inp = True if region_to_search == region.sinks else False @@ -907,16 +917,14 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k # Extra check for batch_dim if hasattr(x, 'names') and 'N' in x.names: batch_dim = x.names.index('N') - x = x.transpose(0, batch_dim) self.batch_dim_act_map[name] = batch_dim + input_scales = self.scale_fn(x, dim=batch_dim) if name not in self.float_act_map: - self.float_act_map[name] = self.scale_fn(x, dim=batch_dim) + self.float_act_map[name] = input_scales else: - batch_data = torch.cat([self.float_act_map[name].unsqueeze(batch_dim), x], - dim=batch_dim) - self.float_act_map[name] = self.scale_fn(batch_data, dim=batch_dim) + self.float_act_map[name] = torch.max(self.float_act_map[name], input_scales) def insert_mul_node(self, scale, shape, axis, act_node, batch_dim=0): broadcastable_shape = [1] * len(shape) diff --git a/src/brevitas/graph/standardize.py b/src/brevitas/graph/standardize.py index 5ff3b6676..c3d667aee 100644 --- a/src/brevitas/graph/standardize.py +++ b/src/brevitas/graph/standardize.py @@ -107,7 +107,8 @@ class TorchFunctionalToModule(GraphTransform): nn.AvgPool1d), (F.avg_pool2d, nn.AvgPool2d), (F.avg_pool3d, nn.AvgPool3d), (F.adaptive_avg_pool1d, nn.AdaptiveAvgPool1d), (F.adaptive_avg_pool2d, - nn.AdaptiveAvgPool2d), (F.adaptive_avg_pool3d, nn.AdaptiveAvgPool3d)) + nn.AdaptiveAvgPool2d), (F.adaptive_avg_pool3d, + nn.AdaptiveAvgPool3d), (F.dropout, nn.Dropout)) def __init__(self, fn_to_module_map=FN_TO_MODULE_MAP): super().__init__() diff --git a/src/brevitas_examples/llm/llm_quant/equalize.py b/src/brevitas_examples/llm/llm_quant/equalize.py index f3e4c3b0d..2dc80ccae 100644 --- a/src/brevitas_examples/llm/llm_quant/equalize.py +++ b/src/brevitas_examples/llm/llm_quant/equalize.py @@ -10,6 +10,7 @@ from brevitas.fx.brevitas_tracer import value_trace from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.equalize import EqualizeGraph +from brevitas.graph.standardize import TorchFunctionalToModule from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32 @@ -26,6 +27,12 @@ def activation_equalization_iter(curr_layer, inps, outs, cached_values, alpha): return outs +def trace_and_standardize(model, ref_kwargs): + graph_model = value_trace(model, value_args=ref_kwargs) + graph_model = TorchFunctionalToModule().apply(graph_model) + return graph_model + + @torch.no_grad() def apply_act_equalization( model, @@ -49,7 +56,7 @@ def apply_act_equalization( # We can't do fp16 tracing on CPU as many kernels are not implemented # So we have to cast to fp32 first, trace, apply equalization, and then cast back with cast_to_float32(model, dtype): - graph_model = value_trace(model, value_args=ref_kwargs) + graph_model = trace_and_standardize(model, ref_kwargs=ref_kwargs) # TODO this is currently running on CPU. We need Accelerate or a TorchDispatchMode # or an FX interpreter to run it on GPU warnings.warn( @@ -70,5 +77,5 @@ def apply_weight_equalization(model, dtype, ref_kwargs, scale_computation_type=' # We can't do fp16 tracing on CPU as many kernels are not implemented # So we have to cast to fp32 first, trace, apply equalization, and then cast back with cast_to_float32(model, dtype): - graph_model = value_trace(model, value_args=ref_kwargs) + graph_model = trace_and_standardize(model, ref_kwargs=ref_kwargs) EqualizeGraph(scale_computation_type=scale_computation_type).apply(graph_model) diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 4d8b2c3ef..05dc06981 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -303,9 +303,14 @@ def main(): quantize_embedding=args.quantize_embedding, seqlen=args.seqlen) # Tie back first/last layer weights in case they got untied - model.tie_weights() print("Model quantization applied.") + # If any equalization has taken places, the embedding layer and the fully connected one are + # not tied anymore, and they need to be treated as standalone, separate layers. + # In all other cases we can tie them back so to preserve memory. + if args.act_equalization is None and not args.weight_equalization: + model.tie_weights() + if args.act_calibration: print("Apply act calibration...") apply_calibration(model, calibration_loader, args.nsamples)