diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 174552241..9701c1fdf 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1,6 +1,8 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from abc import ABC +from abc import abstractmethod from dataclasses import dataclass from dataclasses import field from functools import partial @@ -147,7 +149,7 @@ def __exit__(self, type, value, traceback): def dict_name_to_module(model, regions): name_to_module: Dict[str, torch.nn.Module] = {} - # name_set = {name for region in regions for module_set in region for name in module_set} + name_set = set() for region in regions: for name in region.srcs: @@ -689,11 +691,67 @@ def apply(self, return graph_model -class LayerwiseActivationEqualization(GraphTransform): +class ActivationEqualization(GraphTransform, ABC): - def __init__(self, model, scale_computation_type: str = 'maxabs'): - super(LayerwiseActivationEqualization, self).__init__() + def __init__( + self, model: Union[nn.Module, GraphModule], scale_computation_type: str = 'maxabs'): self.model = model + self.scale_computation_type = scale_computation_type + + @abstractmethod + def setup(self): + pass + + @abstractmethod + def insert_mul_node(self): + pass + + def create_mul_node(self, scale, shape, axis, batch_dim=0): + broadcastable_shape = [1] * len(shape) + broadcastable_shape[axis] = shape[axis] + # Add Batch Dim + broadcastable_shape.insert(batch_dim, 1) + mul_factor = ScaleBias( + num_features=shape[axis], bias=False, runtime_shape=broadcastable_shape) + mul_factor.weight.data = scale + return mul_factor + + def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **kwargs): + # Check for MHA Cross attention, and if found, skip it + kwargs.update(zip(module.forward.__code__.co_varnames[1:], args[:-1])) + if 'query' in kwargs and 'key' in kwargs and 'value' in kwargs: + if kwargs['query'].data_ptr() != kwargs['key'].data_ptr() != kwargs['value'].data_ptr(): + self.float_act_map[name] = None + return + + possible_input_kwargs = ['input', 'inp', 'query'] + input_kwarg = [x for x in kwargs.keys() if x in possible_input_kwargs][0] + if use_inp: + x = kwargs[input_kwarg] + elif not use_inp: + x = args[-1] + + # Extra check for batch_dim + if hasattr(x, 'names') and 'N' in x.names: + batch_dim = x.names.index('N') + + self.batch_dim_act_map[name] = batch_dim + + input_scales = self.scale_fn(x, dim=batch_dim) + if name not in self.float_act_map: + self.float_act_map[name] = input_scales + else: + self.float_act_map[name] = torch.max(self.float_act_map[name], input_scales) + + def remove_hooks(self): + for hook in self.hooks: + ModuleInstanceToModuleInstance(hook, hook.module).apply(self.model) + + +class LayerwiseActivationEqualization(ActivationEqualization): + + def __init__(self, model, scale_computation_type: str = 'maxabs'): + super(LayerwiseActivationEqualization, self).__init__(model, scale_computation_type) self.float_act_map = {} self.batch_dim_act_map = {} self.hooks = [] @@ -703,7 +761,6 @@ def __init__(self, model, scale_computation_type: str = 'maxabs'): self.find_module(model, regions) self.regions = regions - self.scale_computation_type = scale_computation_type if self.scale_computation_type == 'maxabs': self.scale_fn = _channel_maxabs elif self.scale_computation_type == 'range': @@ -751,79 +808,34 @@ def apply(self, alpha): alpha=alpha)) return scale_factors - def remove_hooks(self): - for hook in self.hooks: - ModuleInstanceToModuleInstance(hook, hook.module).apply(self.model) - - def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **kwargs): - # Check for MHA Cross attention, and if found, skip it - kwargs.update(zip(module.forward.__code__.co_varnames[1:], args[:-1])) - if 'query' in kwargs and 'key' in kwargs and 'value' in kwargs: - if kwargs['query'].data_ptr() != kwargs['key'].data_ptr() != kwargs['value'].data_ptr(): - self.float_act_map[name] = None - return - - possible_input_kwargs = ['input', 'inp', 'query'] - input_kwarg = [x for x in kwargs.keys() if x in possible_input_kwargs][0] - if use_inp: - x = kwargs[input_kwarg] - elif not use_inp: - x = args[-1] - - # Extra check for batch_dim - if hasattr(x, 'names') and 'N' in x.names: - batch_dim = x.names.index('N') - - self.batch_dim_act_map[name] = batch_dim - - input_scales = self.scale_fn(x, dim=batch_dim) - if name not in self.float_act_map: - self.float_act_map[name] = input_scales - else: - self.float_act_map[name] = torch.max(self.float_act_map[name], input_scales) - def insert_mul_node(self, scale, shape, axis, region, batch_dim=0): - broadcastable_shape = [1] * len(shape) - broadcastable_shape[axis] = shape[axis] - # Add Batch Dim - broadcastable_shape.insert(batch_dim, 1) - mul_factor = ScaleBias( - num_features=shape[axis], bias=False, runtime_shape=broadcastable_shape) - mul_factor.weight.data = scale + mul_factor = self.create_mul_node(scale, shape, axis, batch_dim) rewriter = ModuleInstanceToModuleInstance( region, EqualizedModule(scale_module=mul_factor, layer=region)) rewriter.apply(self.model) -class GraphActivationEqualization(GraphTransform): +class GraphActivationEqualization(ActivationEqualization): def __init__( - self, model, add_mul_node, layerwise=False, scale_computation_type: str = 'maxabs'): - super(GraphActivationEqualization, self).__init__() - self.graph_model = model + self, + model: GraphModule, + add_mul_node: bool = False, + scale_computation_type: str = 'maxabs'): + super(GraphActivationEqualization, self).__init__(model, scale_computation_type) self.float_act_map = {} self.batch_dim_act_map = {} self.hooks = [] - self.layerwise = layerwise - if self.layerwise: - self.add_mul_node = True - else: - self.add_mul_node = add_mul_node - if self.layerwise: - regions = [] - self.find_module(model, regions) - self.regions = regions - else: - self.regions = _extract_regions(model, add_mul_node=add_mul_node, return_acts=True) + self.add_mul_node = add_mul_node + self.regions = _extract_regions(model, add_mul_node=add_mul_node, return_acts=True) - self.scale_computation_type = scale_computation_type if self.scale_computation_type == 'maxabs': self.scale_fn = _channel_maxabs elif self.scale_computation_type == 'range': self.scale_fn = _channel_range def setup(self): - name_to_module = dict_name_to_module(self.graph_model, self.regions) + name_to_module = dict_name_to_module(self.model, self.regions) # Select only regions with activation to equalize through. # If a region has multiple scale varying activation, must also be dropped # because we can't propagate scaling factors @@ -835,29 +847,30 @@ def setup(self): _scale_varying_activations) for act_name in region.acts]): regions_to_drop.append(region) - else: - # We assume that the entire region has a unique batch_dim - batch_dim = 0 - region_to_search = region.sinks if len(region.acts) == 0 else region.acts - for name in region.srcs + region.sinks: - module = name_to_module[name] - if hasattr(module, 'batch_first'): - batch_dim = 0 if module.batch_first else 1 - for name in region_to_search: - act_module = name_to_module[name] - use_inp = True if region_to_search == region.sinks else False - hook_fn = partial( - self.forward_stats_hook, name=name, batch_dim=batch_dim, use_inp=use_inp) - new_instance = KwargsForwardHook(act_module, hook_fn) - ModuleInstanceToModuleInstance(act_module, new_instance).apply(self.graph_model) - self.hooks.append(new_instance) + continue + + # We assume that the entire region has a unique batch_dim + batch_dim = 0 + region_to_search = region.sinks if len(region.acts) == 0 else region.acts + for name in region.srcs + region.sinks: + module = name_to_module[name] + if hasattr(module, 'batch_first'): + batch_dim = 0 if module.batch_first else 1 + for name in region_to_search: + module = name_to_module[name] + use_inp = True if region_to_search == region.sinks else False + hook_fn = partial( + self.forward_stats_hook, name=name, batch_dim=batch_dim, use_inp=use_inp) + new_instance = KwargsForwardHook(module, hook_fn) + ModuleInstanceToModuleInstance(module, new_instance).apply(self.model) + self.hooks.append(new_instance) self.regions = [x for x in self.regions if x not in regions_to_drop] def apply(self, alpha): scale_factors = [] self.remove_hooks() - name_to_module = dict_name_to_module(self.graph_model, self.regions) + name_to_module = dict_name_to_module(self.model, self.regions) for region in self.regions: region_to_search = region.sinks if len(region.acts) == 0 else region.acts if any([self.float_act_map[name] is None for name in region_to_search]): @@ -877,7 +890,7 @@ def apply(self, alpha): # Even though we iterate, this list will always have a single element by definition list_of_insert_mul_node_fn = [] for act_name in region.acts: - act_node = get_node(self.graph_model, act_name) + act_node = get_node(self.model, act_name) list_of_insert_mul_node_fn.append( partial( self.insert_mul_node, @@ -895,46 +908,9 @@ def apply(self, alpha): return scale_factors - def remove_hooks(self): - for hook in self.hooks: - ModuleInstanceToModuleInstance(hook, hook.module).apply(self.graph_model) - - def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **kwargs): - # Check for MHA Cross attention, and if found, skip it - kwargs.update(zip(module.forward.__code__.co_varnames[1:], args[:-1])) - if 'query' in kwargs and 'key' in kwargs and 'value' in kwargs: - if kwargs['query'].data_ptr() != kwargs['key'].data_ptr() != kwargs['value'].data_ptr(): - self.float_act_map[name] = None - return - - possible_input_kwargs = ['input', 'inp', 'query'] - input_kwarg = [x for x in kwargs.keys() if x in possible_input_kwargs][0] - if use_inp: - x = kwargs[input_kwarg] - elif not use_inp: - x = args[-1] - - # Extra check for batch_dim - if hasattr(x, 'names') and 'N' in x.names: - batch_dim = x.names.index('N') - - self.batch_dim_act_map[name] = batch_dim - - input_scales = self.scale_fn(x, dim=batch_dim) - if name not in self.float_act_map: - self.float_act_map[name] = input_scales - else: - self.float_act_map[name] = torch.max(self.float_act_map[name], input_scales) - def insert_mul_node(self, scale, shape, axis, act_node, batch_dim=0): - broadcastable_shape = [1] * len(shape) - broadcastable_shape[axis] = shape[axis] - # Add Batch Dim - broadcastable_shape.insert(batch_dim, 1) - mul_factor = ScaleBias( - num_features=shape[axis], bias=False, runtime_shape=broadcastable_shape) - mul_factor.weight.data = scale + mul_factor = self.create_mul_node(scale, shape, axis, batch_dim) mul_factor_name = act_node.name + 'act_eq_mul' - self.graph_model.add_module(mul_factor_name, mul_factor) + self.model.add_module(mul_factor_name, mul_factor) rewriter = InsertModuleCallAfter(mul_factor_name, act_node) - rewriter.apply(self.graph_model) + rewriter.apply(self.model)