From fcdc6238508ec64ca598fcdf7968524c75f3221e Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 21 Dec 2023 23:43:51 +0100
Subject: [PATCH] Feat (graph/equalize): support for cat equalization (#778)

---
 src/brevitas/graph/equalize.py                | 535 ++++++++++++------
 src/brevitas/graph/target/flexml.py           |   4 +-
 tests/brevitas/graph/equalization_fixtures.py |  18 +-
 tests/brevitas/graph/test_equalization.py     |  26 +-
 4 files changed, 389 insertions(+), 194 deletions(-)

diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py
index 9701c1fdf..fa455cbac 100644
--- a/src/brevitas/graph/equalize.py
+++ b/src/brevitas/graph/equalize.py
@@ -62,7 +62,13 @@
     nn.ReLU,
     nn.LeakyReLU)
 
-_scale_invariant_op = (torch.mul, operator.mul, operator.imul, operator.__mul__, operator.__imul__)
+_scale_invariant_op = (
+    torch.mul,
+    operator.mul,
+    operator.imul,
+    operator.__mul__,
+    operator.__imul__,
+    torch.nn.functional.interpolate)
 
 _select_op = (operator.getitem, operator.__getitem__)
 
@@ -80,6 +86,18 @@
 _ignore_ops = (getattr, 'size')
 
 
+# Start and End identify the starting and ending channels of the weight matrix that need to be
+# equalized.
+# Offset refers to the relative position of these channels with respect to the other matrices'
+# channels that are equalized simultaneously.
+# Source matrices are always fully equalized, while sinks can be partially equalized.
+@dataclass
+class EqualizationIndexes:
+    start: int = 0
+    end: int = 0
+    offset: int = 0
+
+    def __str__(self):
+        return str(self.start) + '_' + str(self.end) + '_' + str(self.offset)
+
+
 # Required for being hashable
 @dataclass(eq=True, frozen=True)
 class WeightBiasTuple:
@@ -90,18 +108,74 @@ class WeightBiasTuple:
 
 # Required for being hashable
 @dataclass(eq=True, frozen=True)
 class Region:
-    srcs: Tuple = field(default_factory=tuple)
-    sinks: Tuple = field(default_factory=tuple)
+    srcs: Dict = field(default_factory=dict)
+    sinks: Dict = field(default_factory=dict)
     acts: Tuple = field(default_factory=tuple)
+    name_to_module: Dict = field(default_factory=dict)
+
+    @property
+    def srcs_names(self):
+        return [name.split("$")[0] for name in self.srcs.keys()]
+
+    @property
+    def sinks_names(self):
+        return [name.split("$")[0] for name in self.sinks.keys()]
+
+    def get_module_from_name(self, name: str) -> nn.Module:
+        name = name.split("$")[0]
+        return self.name_to_module[name]
 
 
 @dataclass
 class WalkRegionState:
-    srcs: Set = field(default_factory=set)
-    sinks: Set = field(default_factory=set)
+    srcs: Dict = field(default_factory=dict)
+    sinks: Dict = field(default_factory=dict)
     acts: Set = field(default_factory=set)
     history: set = field(default_factory=set)
-    add_mul_node: bool = False
+    name_to_module: Dict = field(default_factory=dict)
+
+    cat_encoutered: bool = False
+    offset: int = 0
+    update_offset: bool = False
+
+    @property
+    def srcs_names(self):
+        return [name.split("$")[0] for name in self.srcs.keys()]
+
+    @property
+    def sinks_names(self):
+        return [name.split("$")[0] for name in self.sinks.keys() if name is not _UNSUPPORTED_OP]
+
+    def add(
+            self,
+            type: str,
+            name: str,
+            module: nn.Module,
+            indexes: Optional[EqualizationIndexes] = None):
+        if type == 'srcs' or type == 'sinks':
+            assert indexes is not None
+            full_source_name = name + '$' + str(indexes)
+            getattr(self, type)[full_source_name] = indexes
+        elif type == 'acts':
+            self.acts.add(name)
+        self.name_to_module[name] = module
+
+    def add_srcs(self, src_name: str, src: nn.Module, indexes: EqualizationIndexes):
+        self.add('srcs', src_name, src, indexes)
+
+    def add_sinks(self, sink_name: str, sink: nn.Module, indexes: EqualizationIndexes):
+        self.add('sinks', sink_name, sink, indexes)
+
+    def add_acts(self, act_name: str, act: nn.Module):
+        self.add('acts', act_name, act)
+
+    def get_module_from_name(self, name: str) -> nn.Module:
+        name = name.split("$")[0]
+        return self.name_to_module[name]
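
As a worked example of the indexing scheme above (the 16/32/48 channel counts are made up for illustration, not from the patch): for cat([conv_a, conv_b]) with 16 and 32 output channels feeding a 48-channel sink, the walk produces

    src_a = EqualizationIndexes(start=0, end=16, offset=0)   # channels [0, 16) of the cat
    src_b = EqualizationIndexes(start=0, end=32, offset=16)  # channels [16, 48) of the cat
    sink = EqualizationIndexes(start=0, end=48, offset=0)    # the sink consumes the whole cat

    # _cross_layer_equalization later checks that both views span the same width:
    max_shape_srcs = max(s.end + s.offset for s in (src_a, src_b))   # 48
    max_shape_sinks = sink.offset + (sink.end - sink.start)          # 48
    assert max_shape_srcs == max_shape_sinks
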
 
 
 _UNSUPPORTED_OP = object()
 
@@ -147,23 +221,6 @@ def __exit__(self, type, value, traceback):
         return True  # To propagate exceptions
 
 
-def dict_name_to_module(model, regions):
-    name_to_module: Dict[str, torch.nn.Module] = {}
-
-    name_set = set()
-    for region in regions:
-        for name in region.srcs:
-            name_set.add(name)
-        for name in region.sinks:
-            name_set.add(name)
-        for name in region.acts:
-            name_set.add(name)
-    for name, module in model.named_modules():
-        if name in name_set:
-            name_to_module[name] = module
-    return name_to_module
-
-
 def _channel_range(inp: torch.Tensor, dim: int = 1) -> torch.Tensor:
     mins, _ = inp.min(dim=dim)
     maxs, _ = inp.max(dim=dim)
@@ -183,15 +240,6 @@ def _channel_maxabs(inp: torch.Tensor, dim: int = 1) -> torch.Tensor:
     return out
 
 
-def _get_size(axes: Dict[nn.Module, int]) -> int:
-    m0, axis0 = list(axes.items())[0]
-    size = m0.weight.size(axis0)
-    for m, axis in axes.items():
-        if m.weight.size(axis) != size:
-            return None
-    return size
-
-
 def _get_input_axis(module: nn.Module) -> Optional[int]:
     """
     Given a sink module, determine the axis associated to the input channels.
@@ -282,7 +330,7 @@ def _combine_weights_bias(
         return weight.data
 
     bias = bias.data
-    weight = weight.data.reshape(weight.shape[0], -1)
+    weight = weight.reshape(weight.shape[0], -1)
     bias = bias.reshape(-1, 1)
 
     weight = torch.where(
@@ -324,8 +372,7 @@ def transpose(module: torch.nn.Module, axis: int):
 
 
 def _cross_layer_equalization(
-        srcs: List[nn.Module],
-        sinks: List[nn.Module],
+        region: Region,
         merge_bias: bool,
         scale_computation_type: str,
         bias_shrinkage: Optional[Union[float, str]] = None,
@@ -338,42 +385,55 @@ def _cross_layer_equalization(
     ranges of the second tensors' input channel
     """
 
-    # Determine device and type of tensors
-    device = next(sinks[0].parameters()).device
-    dtype = next(sinks[0].parameters()).dtype
-
     # If equalization criteria are not met, we return a scalar one to indicate that no equalization
     # has been performed
     def _no_equalize():
         return torch.tensor(1., dtype=dtype, device=device)
 
-    src_axes = {}
-    sink_axes = {}
     act_sink_axes = {}
     act_sources_axes = {}
+    single_module = region.get_module_from_name(next(iter(region.sinks_names)))
+    device = next(single_module.parameters()).device
+    dtype = next(single_module.parameters()).dtype
+
+    max_shape_srcs = 0
+    for name, indexes in region.srcs.items():
+        max_shape_srcs = max(max_shape_srcs, indexes.end + indexes.offset)
+    max_shape_sinks = 0
+    for name, indexes in region.sinks.items():
+        max_shape_sinks = max(max_shape_sinks, indexes.offset + (indexes.end - indexes.start))
+
+    # Exit if source and sink have different sizes
+    if max_shape_srcs != max_shape_sinks and len(region.srcs) > 0:
+        return _no_equalize()
 
-    for i, module in enumerate(srcs):
+    src_axes = {}
+    for name, indexes in region.srcs.items():
+        module = region.get_module_from_name(name)
         # If module is not supported, do not perform graph equalization
+        axis = _get_output_axis(module)
+        act_sources_axes[name] = _get_act_axis(module)
         if not isinstance(module, _supported_layers):
             return _no_equalize()
         if isinstance(module, nn.MultiheadAttention):
-            srcs[i] = module.out_proj
-        src_axes[srcs[i]] = _get_output_axis(module)
-        act_sources_axes[srcs[i]] = _get_act_axis(module)
+            module = module.out_proj
+        src_axes[name] = (module, axis)
 
-    for i, module in enumerate(sinks):
+    sink_axes = {}
+    for name, indexes in region.sinks.items():
+        module = region.get_module_from_name(name)
+        axis = _get_input_axis(module)
+        act_sink_axes[name] = _get_act_axis(module)
         # If module is not supported, do not perform graph equalization
-        if not isinstance(module, _supported_layers):
+        if not isinstance(module, _supported_layers) or module in _batch_norm:
             return _no_equalize()
         # For MultiheadAttention, we support only self-attetion
         if isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is not None:
             # For sinks, we only need to modify the weight but not the bias
-            sinks[i] = WeightBiasTuple(weight=module.in_proj_weight)
+            module = WeightBiasTuple(module.in_proj_weight)
         elif isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is None:
             return _no_equalize()
-        sink_axes[sinks[i]] = _get_input_axis(module)
-        act_sink_axes[sinks[i]] = _get_act_axis(module)
-
+        sink_axes[name] = (module, axis)
 
     # If act_val is enabled, use source or sink weights to determine the activation channel
     # For example, if the source is BatchNorm, we need to use the information coming from the sinks
     if list_of_act_val is not None:
@@ -396,40 +456,50 @@ def _no_equalize():
         if None in axes_to_check:
             return _no_equalize()
 
-    # Check if the sink_size is None,
-    # which means that the some of the sinks do not have the same size as the others.
-    sink_size = _get_size(sink_axes)
-    if None in [sink_size]:
-        return _no_equalize()
-
     scale_fn = _select_scale_computation_fn(scale_computation_type)
-    sink_weights = [transpose(m, axis) for m, axis in sink_axes.items()]
-    sinks_range = scale_fn(torch.cat([w.reshape(w.size(0), -1) for w in sink_weights], 1))
+    sink_weights = {name: transpose(m, axis) for name, (m, axis) in sink_axes.items()}
+    srcs_range = -1 * torch.ones(max_shape_srcs, device=device, dtype=dtype)
+    sinks_range = -1 * torch.ones(max_shape_sinks, device=device, dtype=dtype)
+    for k, v in sink_weights.items():
+        # Sinks can be partially equalized, thus we need to select
+        # only the channels we are interested in
+        indexes = region.sinks[k]
+        # Compute the range of the channels we need to equalize
+        weight_range = scale_fn(v.reshape(v.size(0), -1))[indexes.start:indexes.end]
+        # Compute the number of channels we are equalizing
+        channel_range = indexes.end - indexes.start
+        # Use the offset and the range to update the correct range in the sinks
+        sinks_range[indexes.offset:indexes.offset + channel_range] = torch.max(
+            sinks_range[indexes.offset:indexes.offset + channel_range], weight_range)
 
     # Determine the srcs_range based on where we are performing activation equalization or
     # weight equalization
     if list_of_act_val is not None:
         list_of_act_val_shapes = [act_val.shape for act_val in list_of_act_val]
+        if len(list_of_act_val_shapes) > 0:
+            shape_0 = list_of_act_val_shapes[0]
+            if any(shape_0 != shape for shape in list_of_act_val_shapes):
+                return _no_equalize()
         list_of_act_val = [
             transpose(WeightBiasTuple(act_val), act_axis) for act_val in list_of_act_val]
         srcs_range = scale_fn(
             torch.cat([act_val.reshape(act_val.size(0), -1) for act_val in list_of_act_val], 1))
     else:
-        # If we do weight equalization, perform additional check on source size
-        src_size = _get_size(src_axes)
-        # Exit if source and sink have different different sizes, or if sources contains None
-        if src_size != sink_size or None in [src_size]:
-            warnings.warn(
"Detected source and sink with non compatible shapes, equalization is skipped") - return _no_equalize() - if merge_bias: - src_weights = [ - _combine_weights_bias(transpose(m, axis), bias_shrinkage, m.bias) for m, - axis in src_axes.items()] + src_weights = { + name: _combine_weights_bias(transpose(m, axis), bias_shrinkage, m.bias) + for name, (m, axis) in src_axes.items()} else: - src_weights = [transpose(m, axis) for m, axis in src_axes.items()] - srcs_range = scale_fn(torch.cat([w.reshape(w.size(0), -1) for w in src_weights], 1)) + src_weights = {name: transpose(m, axis) for name, (m, axis) in src_axes.items()} + for k, v in src_weights.items(): + # Srcs are always fully equalized, thus we simply need to apply the offset to position them + # correctly with respect to the other srcs matrices. + indexes = region.srcs[k] + channel_start = indexes.offset + indexes.start + channel_end = indexes.offset + indexes.end + weight_range = scale_fn(v.reshape(v.size(0), -1)) + srcs_range[channel_start:channel_end] = torch.max( + srcs_range[channel_start:channel_end], weight_range) # If there is a mismatch between srcs and sinks values, exit if srcs_range.shape != sinks_range.shape: @@ -455,32 +525,37 @@ def _no_equalize(): for act_val_shape, insert_mul_node_fn in zip(list_of_act_val_shapes, list_of_insert_mul_node_fn): insert_mul_node_fn(inverse_scaling_factors, act_val_shape, act_axis) if len(src_axes) > 0: - for module, axis in src_axes.items(): + for name, (module, axis) in src_axes.items(): + indexes = region.srcs[name] + channel_start = indexes.offset + indexes.start + channel_end = indexes.offset + indexes.end if hasattr(module, 'bias') and module.bias is not None: _update_weights( module, - module.bias.clone() * inverse_scaling_factors.view_as(module.bias), + module.bias.clone() * + inverse_scaling_factors[channel_start:channel_end].view_as(module.bias), attr='bias') src_broadcast_size = [1] * module.weight.ndim src_broadcast_size[axis] = module.weight.size(axis) + _update_weights( - module, ( - module.weight.clone() * - torch.reshape(inverse_scaling_factors, src_broadcast_size)), + module, + module.weight.clone() * torch.reshape( + inverse_scaling_factors[channel_start:channel_end], src_broadcast_size), attr='weight') - for module, axis in sink_axes.items(): + for name, (module, axis) in sink_axes.items(): sink_broadcast_size = [1] * module.weight.ndim sink_broadcast_size[axis] = module.weight.size(axis) - if isinstance(module, _batch_norm): - # We re-compute the bias as function of running_mean and running_var to adjust the - # additive factor for equalization. 
@@ -504,25 +579,14 @@ def _equalize(
     """
     Generalized version of section 4.1 of https://arxiv.org/pdf/1906.04721.pdf
     """
-    name_to_module: Dict[str, nn.Module] = {}
-    name_set = set()
-    for region in regions:
-        for name in region.srcs:
-            name_set.add(name)
-        for name in region.sinks:
-            name_set.add(name)
-
-    for name, module in model.named_modules():
-        if name in name_set:
-            name_to_module[name] = module
     for i in range(iterations):
         scale_factor_max = None
         for region in regions:
             scale_factors_region = _cross_layer_equalization(
-                [name_to_module[n] for n in region.srcs], [name_to_module[n] for n in region.sinks],
+                region,
                 merge_bias=merge_bias,
-                scale_computation_type=scale_computation_type,
-                bias_shrinkage=bias_shrinkage)
+                bias_shrinkage=bias_shrinkage,
+                scale_computation_type=scale_computation_type)
             scale_factor_region_max = torch.max(torch.abs(1 - scale_factors_region))
             if scale_factor_max is not None:
                 scale_factor_max = torch.max(scale_factor_max, scale_factor_region_max)
@@ -557,16 +621,95 @@ def _is_scale_varying_activation(graph_model, node):
 
 
 def _is_scale_invariant_function(node: Node) -> bool:
-    return node.op == 'call_function' and node.target in _scale_invariant_op + _select_op
+    out = node.op == 'call_function' and node.target in _scale_invariant_op + _select_op
+    if node.target == torch.nn.functional.interpolate:
+        out &= node.kwargs.get('mode', None) == 'nearest'
+    return out
 
 
 def _is_reshaping_op(node: Node) -> bool:
     return node.target in _reshaping_op
 
 
+def get_weight_source(module):
+    transpose = lambda weight, axis: weight if axis == 0 else weight.transpose(0, 1)
+    if isinstance(module, nn.MultiheadAttention) and not hasattr(module, 'out_proj'):
+        raise RuntimeError("Configuration for Multiheadattention not supported")
+    weight = module.out_proj.weight if isinstance(module, nn.MultiheadAttention) else module.weight
+    axis = _get_output_axis(module)
+    weight = transpose(weight, axis)
+    return weight
+
+
+def get_weight_sink(module):
+    transpose = lambda weight, axis: weight if axis == 0 else weight.transpose(0, 1)
+    if isinstance(module, nn.MultiheadAttention) and not hasattr(module, 'in_proj_weight'):
+        raise RuntimeError("Configuration for Multiheadattention not supported")
+    weight = WeightBiasTuple(module.in_proj_weight).weight if isinstance(
+        module, nn.MultiheadAttention) else module.weight
+    axis = _get_input_axis(module)
+    weight = transpose(weight, axis)
+    return weight
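
A quick sanity check of the two helpers above (illustrative, assuming standard torch layer layouts): both return the weight with the axis being equalized moved to dim 0, so callers can read the channel count from shape[0]:

    import torch.nn as nn

    conv = nn.Conv2d(8, 16, 3)              # weight: (out=16, in=8, 3, 3)
    assert get_weight_source(conv).shape[0] == 16   # output channels first
    assert get_weight_sink(conv).shape[0] == 8      # input channels first

    deconv = nn.ConvTranspose2d(8, 16, 3)   # weight: (in=8, out=16, 3, 3)
    assert get_weight_source(deconv).shape[0] == 16
    assert get_weight_sink(deconv).shape[0] == 8
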
+
+
+def find_srcs_channel_dim(model, inp_node):
+    if _is_supported_module(model, inp_node):
+        # If we meet a supported module, determine the channel shape
+        module = get_module(model, inp_node.target)
+        # Since we are walking up, we consider the module as srcs
+        weight = get_weight_source(module)
+        channel = weight.shape[0]
+        return channel
+    elif _is_add(inp_node):
+        all_channels = []
+        for n in inp_node.all_input_nodes:
+            all_channels.append(find_srcs_channel_dim(model, n))
+        # All branches to add should have the same number of channels
+        if all([channel == all_channels[0] for channel in all_channels]):
+            return all_channels[0]
+        else:
+            return _UNSUPPORTED_OP
+    elif _is_cat(inp_node):
+        total_channels = 0
+        # If it's cat, we need to sum the channel shape of all the branches
+        for n in inp_node.all_input_nodes:
+            total_channels += find_srcs_channel_dim(model, n)
+        return total_channels
+    elif _is_scale_invariant_module(model, inp_node) or _is_scale_invariant_function(inp_node):
+        return find_srcs_channel_dim(model, inp_node.all_input_nodes[0])
+    else:
+        return _UNSUPPORTED_OP
+
+
+def cat_handler(graph_model: GraphModule, starting_node: Node, state: WalkRegionState):
+
+    state.srcs.clear()
+    state.sinks.clear()
+    state.history.clear()
+    # Keep track that concatenation has been encountered once
+    state.cat_encoutered = True
+    state.update_offset = True
+    state.offset = 0
+    find_srcs(graph_model, starting_node, state)
+    state.update_offset = False
+    state.offset = 0
+    find_sinks(graph_model, starting_node, state)
+
+
+def _is_cat(node):
+    return node.target in (torch.cat,)
+
+
+def _is_add(node):
+    return (
+        node.op == 'call_method' and node.target in _residual_methods or
+        node.op == 'call_function' and node.target in _residual_fns)
+
+
 def find_srcs(graph_model: GraphModule, starting_node: Node,
               state: WalkRegionState) -> Dict[str, Set]:
     node_list = starting_node.all_input_nodes
+    update_offset_state = state.update_offset
     for node in node_list:
         # we keep a history of how the graph has been walked already, invariant to the direction,
         # to avoid getting stuck in a loop
@@ -576,27 +719,47 @@ def find_srcs(graph_model: GraphModule, starting_node: Node,
         else:
             continue
         if _is_supported_module(graph_model, node):
-            state.srcs.add(node.target)
+            module = get_module(graph_model, node.target)
+            weight = get_weight_source(module)
+            eq_indexes = EqualizationIndexes(0, weight.shape[0], state.offset)
+            # After finding a source, we need to check if it branches into multiple sinks
+            state.add_srcs(node.target, module, eq_indexes)
             find_sinks(graph_model, node, state)
+            state.offset = state.offset if not state.update_offset else state.offset + weight.shape[0]
         elif _is_scale_invariant_module(
                 graph_model, node) or _is_scale_invariant_function(node) or _is_reshaping_op(node):
-            find_srcs(graph_model, node, state)
             find_sinks(graph_model, node, state)
+            find_srcs(graph_model, node, state)
         elif (node.op == 'call_method' and node.target in _residual_methods or
               node.op == 'call_function' and node.target in _residual_fns):
-            find_srcs(graph_model, node, state)
+            state.update_offset = False
             find_sinks(graph_model, node, state)
+            find_srcs(graph_model, node, state)
+            state.update_offset = update_offset_state
+        elif _is_cat(node):
+            # The first time we encounter a cat differs from all subsequent ones
+            if not state.cat_encoutered:
+                # We restart the region search starting from the cat
+                cat_handler(graph_model, node, state)
+            else:
+                state.update_offset = False
+                find_sinks(graph_model, node, state)
+                state.update_offset = True
+                find_srcs(graph_model, node, state)
+                state.update_offset = update_offset_state
         elif node.target in _ignore_ops:
             continue
        else:
             # If we meet an unrecognized op, we add None to invalidate the region
-            state.srcs.add(_UNSUPPORTED_OP)
+            state.sinks[_UNSUPPORTED_OP] = _UNSUPPORTED_OP
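
The find_sinks rework below uses find_srcs_channel_dim to work out which slice of a shared sink each cat input owns. A sketch of that arithmetic, with made-up channel counts:

    # cat([a, b, c]) with 8, 16 and 24 channels, walked up from branch b:
    channels = [8, 16, 24]          # per-branch counts from find_srcs_channel_dim
    index = 1                       # b is the second input of the cat
    start = sum(channels[:index])   # 8
    end = start + channels[index]   # 24
    # Sinks reached below the cat are then registered with
    # EqualizationIndexes(start, end, offset), i.e. only input channels [8, 24)
    # of each sink are equalized against branch b.
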
 
 
 def find_sinks(graph_model: GraphModule, starting_node: Node,
               state: WalkRegionState) -> Dict[str, Set]:
     node_list = starting_node.users
+    update_offset_state = state.update_offset
     for node in node_list:
         # we keep a history of how the graph has been walked already, invariant to the direction,
         # to avoid getting stuck in a loop
@@ -608,50 +771,93 @@ def find_sinks(graph_model: GraphModule, starting_node: Node,
             continue
         if _is_supported_module(graph_model, node):
             module = get_module(graph_model, node.target)
+            weight = get_weight_sink(module)
+            eq_indexes = EqualizationIndexes(0, weight.shape[0], state.offset)
             # It is not possible to equalize through LayerNorm as sink
             if isinstance(module, (nn.LayerNorm,) + _batch_norm):
-                state.sinks.add(_UNSUPPORTED_OP)
             else:
-                state.sinks.add(node.target)
+                state.add_sinks(node.target, module, eq_indexes)
+
         elif _is_scale_invariant_module(
                 graph_model, node) or _is_scale_invariant_function(node) or _is_reshaping_op(node):
             find_sinks(graph_model, node, state)
         elif (node.op == 'call_method' and node.target in _residual_methods or
               node.op == 'call_function' and node.target in _residual_fns):
+            state.update_offset = False
             find_sinks(graph_model, node, state)
             find_srcs(graph_model, node, state)
+            state.update_offset = update_offset_state
+        elif _is_cat(node):
+            # The first time we encounter a cat differs from all subsequent ones
+            if not state.cat_encoutered:
+                # We restart the region search starting from the cat
+                cat_handler(graph_model, node, state)
+            else:
+                # In this case we define all our sinks, and isolate only the channels we want
+                # to equalize (start, end).
+                # Furthermore, we need to consider the offset given by the sources of the second cat
+                index = node.all_input_nodes.index(starting_node)
+                channels = []
+                for n in node.all_input_nodes:
+                    channel_dim = find_srcs_channel_dim(graph_model, n)
+                    if channel_dim is _UNSUPPORTED_OP:
+                        state.sinks[_UNSUPPORTED_OP] = _UNSUPPORTED_OP
+                        continue
+                    channels.append(channel_dim)
+                start = sum(channels[:index])
+                end = start + channels[index]
+                new_state = WalkRegionState(offset=state.offset)
+                find_sinks(graph_model, node, new_state)
+
+                for k in new_state.sinks_names:
+                    state.add_sinks(
+                        k,
+                        new_state.get_module_from_name(k),
+                        EqualizationIndexes(start, end, new_state.offset))
+                state.srcs.update(new_state.srcs)
         elif node.target in _ignore_ops:
             continue
         else:
             # If we meet an unrecognized op, we add None to invalidate the region
             state.sinks[_UNSUPPORTED_OP] = _UNSUPPORTED_OP
 
 
 def _extract_regions(
         graph_model: GraphModule,
         add_mul_node: bool = False,
         return_acts: bool = False) -> List[Region]:
-    regions = []
+    regions = list()
     for node in graph_model.graph.nodes:
         if _is_supported_module(graph_model,
                                 node) or (add_mul_node and
                                           _is_scale_varying_activation(graph_model, node)):
-            state = WalkRegionState(srcs={node.target}, add_mul_node=add_mul_node)
+            state = WalkRegionState()
             if _is_scale_varying_activation(graph_model, node):
-                state.acts.add(node.target)
+                module = get_module(graph_model, node.target)
+                state.add_acts(node.target, module)
+            else:
+                module = get_module(graph_model, node.target)
+                weight = get_weight_source(module)
+                eq_indexes = EqualizationIndexes(0, weight.shape[0], 0)
+                state.add_srcs(node.target, module, eq_indexes)
             find_sinks(graph_model, node, state)
-            if state.sinks and _UNSUPPORTED_OP not in state.sinks and _UNSUPPORTED_OP not in state.srcs:
-                # each region should appear only once, so to make it hashable
-                # we convert srcs and sinks to ordered lists first, and then to tuples
-                srcs = tuple(sorted(state.srcs))
-                sinks = tuple(sorted(state.sinks))
-                acts = tuple(sorted(state.acts))
+            if len(state.sinks) > 0 and _UNSUPPORTED_OP not in state.sinks.keys():
+                sorted_srcs = dict(sorted(state.srcs.items()))
+                sorted_sinks = dict(sorted(state.sinks.items()))
+                sorted_acts = tuple(sorted(state.acts))
                 if return_acts:
-                    region_to_add = Region(srcs=srcs, sinks=sinks, acts=acts)
+                    region = Region(
+                        srcs=sorted_srcs,
+                        sinks=sorted_sinks,
+                        acts=sorted_acts,
+                        name_to_module=state.name_to_module)
                 else:
-                    region_to_add = Region(srcs=srcs, sinks=sinks)
-                if region_to_add not in regions:
-                    regions.append(region_to_add)
+                    region = Region(
+                        srcs=sorted_srcs, sinks=sorted_sinks, name_to_module=state.name_to_module)
+
+                if region not in regions:
+                    regions.append(region)
 
     return regions
 
@@ -757,7 +963,7 @@ def __init__(self, model, scale_computation_type: str = 'maxabs'):
         self.hooks = []
         self.add_mul_node = True
 
-        regions = []
+        regions: List[Region] = []
         self.find_module(model, regions)
         self.regions = regions
 
@@ -773,39 +979,44 @@ def find_module(self, model, regions: List):
         """
         if isinstance(model,
                       _supported_layers) and not isinstance(model, _batch_norm + (nn.LayerNorm,)):
-            regions.append(model)
+            weight = get_weight_sink(model)
+            eq_indexes = EqualizationIndexes(0, weight.shape[0], 0)
+            region = Region(sinks={'sinks0': eq_indexes}, name_to_module={'sinks0': model})
+            regions.append(region)
         else:
             for module in model.children():
                 self.find_module(module, regions)
 
     def setup(self):
         for region in self.regions:
+            module = region.get_module_from_name('sinks0')
             batch_dim = 0
             if hasattr(region, 'batch_first'):
                 batch_dim = 0 if region.batch_first else 1
 
             hook_fn = partial(
-                self.forward_stats_hook, name=region, batch_dim=batch_dim, use_inp=True)
-            new_instance = KwargsForwardHook(region, hook_fn)
-            ModuleInstanceToModuleInstance(region, new_instance).apply(self.model)
+                self.forward_stats_hook, name=module, batch_dim=batch_dim, use_inp=True)
+            new_instance = KwargsForwardHook(module, hook_fn)
+            ModuleInstanceToModuleInstance(module, new_instance).apply(self.model)
             self.hooks.append(new_instance)
 
     def apply(self, alpha):
         scale_factors = []
         self.remove_hooks()
         for region in self.regions:
-            if self.float_act_map[region] == None:
+            module = region.get_module_from_name('sinks0')
+            if self.float_act_map[module] == None:
                 continue
 
-            sinks = region
             insert_mul_fn = partial(
-                self.insert_mul_node, region=region, batch_dim=self.batch_dim_act_map[region])
+                self.insert_mul_node, region=module, batch_dim=self.batch_dim_act_map[module])
             scale_factors.append(
-                _cross_layer_equalization([], [sinks],
-                                          False,
-                                          scale_computation_type=self.scale_computation_type,
-                                          list_of_act_val=[self.float_act_map[region]],
-                                          list_of_insert_mul_node_fn=[insert_mul_fn],
-                                          alpha=alpha))
+                _cross_layer_equalization(
+                    region,
+                    False,
+                    scale_computation_type=self.scale_computation_type,
+                    list_of_act_val=[self.float_act_map[module]],
+                    list_of_insert_mul_node_fn=[insert_mul_fn],
+                    alpha=alpha))
         return scale_factors
 
     def insert_mul_node(self, scale, shape, axis, region, batch_dim=0):
@@ -826,6 +1037,7 @@ def __init__(
         self.float_act_map = {}
         self.batch_dim_act_map = {}
         self.hooks = []
+        self.hooked_modules = set()
         self.add_mul_node = add_mul_node
 
         self.regions = _extract_regions(model, add_mul_node=add_mul_node, return_acts=True)
@@ -835,7 +1047,6 @@ def __init__(
         self.scale_fn = _channel_range
 
     def setup(self):
-        name_to_module = dict_name_to_module(self.model, self.regions)
         # Select only regions with activation to equalize through.
         # If a region has multiple scale varying activation, must also be dropped
         # because we can't propagate scaling factors
@@ -843,7 +1054,7 @@ def setup(self):
         for region in self.regions:
             # This condition is for redudancy, since
             # a region with two scale-varying activations cannot be detected in the first place
-            if len(region.acts) > 1 and any([isinstance(name_to_module[act_name],
+            if len(region.acts) > 1 and any([isinstance(region.get_module_from_name(act_name),
                                                         _scale_varying_activations)
                                              for act_name in region.acts]):
                 regions_to_drop.append(region)
@@ -851,38 +1062,38 @@ def setup(self):
                 # We assume that the entire region has a unique batch_dim
                 batch_dim = 0
-                region_to_search = region.sinks if len(region.acts) == 0 else region.acts
-                for name in region.srcs + region.sinks:
-                    module = name_to_module[name]
-                    if hasattr(module, 'batch_first'):
-                        batch_dim = 0 if module.batch_first else 1
+                for name in region.srcs:
+                    module = region.get_module_from_name(name)
+                    if hasattr(module, 'batch_first') and not module.batch_first:
+                        batch_dim = 1
+                for name in region.sinks:
+                    module = region.get_module_from_name(name)
+                    if hasattr(module, 'batch_first') and not module.batch_first:
+                        batch_dim = 1
 
+                region_to_search = region.sinks_names if len(region.acts) == 0 else region.acts
                 for name in region_to_search:
-                    module = name_to_module[name]
-                    use_inp = True if region_to_search == region.sinks else False
-                    hook_fn = partial(
-                        self.forward_stats_hook, name=name, batch_dim=batch_dim, use_inp=use_inp)
-                    new_instance = KwargsForwardHook(module, hook_fn)
-                    ModuleInstanceToModuleInstance(module, new_instance).apply(self.model)
-                    self.hooks.append(new_instance)
+                    module = region.get_module_from_name(name)
+                    if module not in self.hooked_modules:
+                        self.hooked_modules.add(module)
+                        use_inp = True if region_to_search == region.sinks_names else False
+                        hook_fn = partial(
+                            self.forward_stats_hook, name=name, batch_dim=batch_dim, use_inp=use_inp)
+                        new_instance = KwargsForwardHook(module, hook_fn)
+                        ModuleInstanceToModuleInstance(module, new_instance).apply(self.model)
+                        self.hooks.append(new_instance)
 
         self.regions = [x for x in self.regions if x not in regions_to_drop]
 
     def apply(self, alpha):
         scale_factors = []
         self.remove_hooks()
-        name_to_module = dict_name_to_module(self.model, self.regions)
         for region in self.regions:
-            region_to_search = region.sinks if len(region.acts) == 0 else region.acts
-            if any([self.float_act_map[name] is None for name in region_to_search]):
+            region_names = region.sinks_names if len(region.acts) == 0 else region.acts
+            if any([self.float_act_map[name] is None for name in region_names]):
                 continue
-            act_module = [name_to_module[act_name] for act_name in region.acts]
-            list_of_act_val = [self.float_act_map[name] for name in region_to_search]
-            sinks = [name_to_module[sink] for sink in region.sinks]
-            # Filter out scale_varying activations from the srcs
-            srcs = [
-                name_to_module[src]
-                for src in region.srcs
-                if not isinstance(name_to_module[src], _scale_varying_activations)]
+            act_module = [region.get_module_from_name(act_name) for act_name in region.acts]
+            list_of_act_val = [self.float_act_map[name] for name in region_names]
 
             list_of_insert_mul_node_fn = None
             if self.add_mul_node and any([
@@ -896,10 +1107,10 @@ def apply(self, alpha):
                             self.insert_mul_node,
                             act_node=act_node,
                             batch_dim=self.batch_dim_act_map[act_name]))
+
             scale_factors.append(
                 _cross_layer_equalization(
-                    srcs,
-                    sinks,
+                    region,
                     False,
                     scale_computation_type=self.scale_computation_type,
                     list_of_act_val=list_of_act_val,
diff --git a/src/brevitas/graph/target/flexml.py b/src/brevitas/graph/target/flexml.py
index c10f688e3..9aedd337c 100644
--- a/src/brevitas/graph/target/flexml.py
+++ b/src/brevitas/graph/target/flexml.py
@@ -124,8 +124,8 @@ def preprocess_for_flexml_quantize(
         equalize_iters=0,
         equalize_merge_bias=True,
         merge_bn=True,
-        equalize_bias_shrinkage: str = 'vaiq',
-        equalize_scale_computation: str = 'maxabs',
+        equalize_bias_shrinkage='vaiq',
+        equalize_scale_computation='maxabs',
         **model_kwargs):
     training_state = model.training
     model.eval()
diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py
index 8e13c49cd..263543a82 100644
--- a/tests/brevitas/graph/equalization_fixtures.py
+++ b/tests/brevitas/graph/equalization_fixtures.py
@@ -20,30 +20,20 @@
     'shufflenet_v2_x0_5': [0.318, 0.649],
     'mobilenet_v2': [0.161, 0.320],
     'resnet18': [0.487, 0.952],
-    'googlenet': [0.1826, 0.413],
-    'inception_v3': [0.264, 0.6],
+    'googlenet': [0.495, 0.982],
+    'inception_v3': [0.497, 0.989],
     'alexnet': [0.875, 0.875],}
 
 IN_SIZE_CONV = (1, 3, 224, 224)
 IN_SIZE_LINEAR = (1, 224, 3)
 
 
-def equalize_test(model, regions, merge_bias, bias_shrinkage, scale_computation_type):
-    name_to_module = {}
-    name_set = set()
-    for region in regions:
-        for name in region.srcs:
-            name_set.add(name)
-        for name in region.sinks:
-            name_set.add(name)
+def equalize_test(regions, merge_bias, bias_shrinkage, scale_computation_type):
     scale_factors_regions = []
-    for name, module in model.named_modules():
-        if name in name_set:
-            name_to_module[name] = module
     for i in range(3):
         for region in regions:
             scale_factors_region = _cross_layer_equalization(
-                [name_to_module[n] for n in region.srcs], [name_to_module[n] for n in region.sinks],
+                region,
                 merge_bias=merge_bias,
                 bias_shrinkage=bias_shrinkage,
                 scale_computation_type=scale_computation_type)
diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py
index caca0fd29..4f713211c 100644
--- a/tests/brevitas/graph/test_equalization.py
+++ b/tests/brevitas/graph/test_equalization.py
@@ -30,13 +30,13 @@ def test_resnet18_equalization():
     model_orig = copy.deepcopy(model)
     regions = _extract_regions(model)
     _ = equalize_test(
-        model, regions, merge_bias=True, bias_shrinkage='vaiq', scale_computation_type='maxabs')
+        regions, merge_bias=True, bias_shrinkage='vaiq', scale_computation_type='maxabs')
     out = model(inp)
 
     # Check that equalization is not introducing FP variations
     assert torch.allclose(expected_out, out, atol=ATOL)
 
-    regions = sorted(regions, key=lambda region: region.srcs[0])
+    regions = sorted(regions, key=lambda region: sorted([r for r in region.srcs_names]))
     resnet_18_regions = sorted(RESNET_18_REGIONS, key=lambda region: region[0][0])
     equalized_layers = set()
     for r in resnet_18_regions:
@@ -45,8 +45,10 @@ def test_resnet18_equalization():
 
     # Check that we found all the expected regions
     for region, expected_region in zip(regions, resnet_18_regions):
-        sources_check = set(region.srcs) == set(expected_region[0])
-        sinks_check = set(region.sinks) == set(expected_region[1])
+        srcs = region.srcs_names
+        sources_check = set(srcs) == set(expected_region[0])
+        sinks = region.sinks_names
+        sinks_check = set(sinks) == set(expected_region[1])
         assert sources_check
         assert sinks_check
 
@@ -73,19 +75,15 @@ def test_equalization_torchvision_models(model_coverage: tuple, merge_bias: bool
     regions = _extract_regions(model)
     scale_factor_regions = equalize_test(
-        model,
-        regions,
-        merge_bias=merge_bias,
-        bias_shrinkage='vaiq',
-        scale_computation_type='maxabs')
+        regions, merge_bias=merge_bias, bias_shrinkage='vaiq', scale_computation_type='maxabs')
     shape_scale_regions = [scale.shape for scale in scale_factor_regions]
 
     out = model(inp)
     srcs = set()
     sinks = set()
     for r in regions:
-        srcs.update(list(r.srcs))
-        sinks.update(list(r.sinks))
+        srcs.update([x for x in list(r.srcs_names)])
+        sinks.update([x for x in list(r.sinks_names)])
 
     count_region_srcs = 0
     count_region_sinks = 0
@@ -130,11 +128,7 @@ def test_models(toy_model, merge_bias, request):
     model = symbolic_trace(model)
     regions = _extract_regions(model)
     scale_factor_regions = equalize_test(
-        model,
-        regions,
-        merge_bias=merge_bias,
-        bias_shrinkage='vaiq',
-        scale_computation_type='maxabs')
+        regions, merge_bias=merge_bias, bias_shrinkage='vaiq', scale_computation_type='maxabs')
     shape_scale_regions = [scale.shape for scale in scale_factor_regions]
 
     with torch.no_grad():
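
An end-to-end sketch of the new cat support, in the style of the tests above (the toy model, channel counts and tolerance are illustrative and not part of the test suite; symbolic_trace, _extract_regions and equalize_test are the helpers already used in these files):

    import torch
    import torch.nn as nn

    class ToyCatModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv_a = nn.Conv2d(3, 16, 1)
            self.conv_b = nn.Conv2d(3, 32, 1)
            self.sink = nn.Conv2d(48, 8, 1)

        def forward(self, x):
            return self.sink(torch.cat([self.conv_a(x), self.conv_b(x)], dim=1))

    model = symbolic_trace(ToyCatModel())
    inp = torch.randn(1, 3, 32, 32)
    expected_out = model(inp)

    regions = _extract_regions(model)
    # Expected: a single region where conv_a covers channels [0, 16) and conv_b
    # channels [16, 48) of the cat, so each scale factor vector has 48 entries.
    scale_factor_regions = equalize_test(
        regions, merge_bias=True, bias_shrinkage='vaiq', scale_computation_type='maxabs')
    assert all(scale.shape == (48,) for scale in scale_factor_regions)

    # Equalization must be numerically transparent in floating point
    assert torch.allclose(expected_out, model(inp), atol=1e-5)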