Fix (graph/equalize): small refactor for act equalization
Giuseppe5 committed Dec 17, 2023
1 parent 0dd37f9 commit f7ee7f5
Showing 1 changed file with 94 additions and 118 deletions.
212 changes: 94 additions & 118 deletions src/brevitas/graph/equalize.py
@@ -1,6 +1,8 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from dataclasses import field
from functools import partial
@@ -147,7 +149,7 @@ def __exit__(self, type, value, traceback):

def dict_name_to_module(model, regions):
name_to_module: Dict[str, torch.nn.Module] = {}
# name_set = {name for region in regions for module_set in region for name in module_set}

name_set = set()
for region in regions:
for name in region.srcs:
@@ -689,11 +691,67 @@ def apply(self,
return graph_model


class LayerwiseActivationEqualization(GraphTransform):
class ActivationEqualization(GraphTransform, ABC):

def __init__(self, model, scale_computation_type: str = 'maxabs'):
super(LayerwiseActivationEqualization, self).__init__()
def __init__(
self, model: Union[nn.Module, GraphModule], scale_computation_type: str = 'maxabs'):
self.model = model
self.scale_computation_type = scale_computation_type

@abstractmethod
def setup(self):
pass

@abstractmethod
def insert_mul_node(self):
pass

def create_mul_node(self, scale, shape, axis, batch_dim=0):
broadcastable_shape = [1] * len(shape)
broadcastable_shape[axis] = shape[axis]
# Add Batch Dim
broadcastable_shape.insert(batch_dim, 1)
mul_factor = ScaleBias(
num_features=shape[axis], bias=False, runtime_shape=broadcastable_shape)
mul_factor.weight.data = scale
return mul_factor
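
For intuition, the shape juggling above builds a view that broadcasts a per-channel scale over an activation tensor. A minimal sketch in plain PyTorch (tensor sizes are illustrative; `shape` excludes the batch dimension, as implied by the extra inserted dim in the method above):

import torch

# Feature shape without the batch dimension: (channels, length) = (16, 32),
# with channels at axis 0.
shape, axis, batch_dim = (16, 32), 0, 0
scale = torch.rand(shape[axis])

broadcastable_shape = [1] * len(shape)
broadcastable_shape[axis] = shape[axis]
broadcastable_shape.insert(batch_dim, 1)     # -> [1, 16, 1]

x = torch.rand(4, 16, 32)                    # (batch, channels, length)
y = x * scale.reshape(broadcastable_shape)   # scales each channel independently
assert y.shape == x.shape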

def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **kwargs):
# Check for MHA Cross attention, and if found, skip it
kwargs.update(zip(module.forward.__code__.co_varnames[1:], args[:-1]))
if 'query' in kwargs and 'key' in kwargs and 'value' in kwargs:
if kwargs['query'].data_ptr() != kwargs['key'].data_ptr() != kwargs['value'].data_ptr():
self.float_act_map[name] = None
return

possible_input_kwargs = ['input', 'inp', 'query']
input_kwarg = [x for x in kwargs.keys() if x in possible_input_kwargs][0]
if use_inp:
x = kwargs[input_kwarg]
elif not use_inp:
x = args[-1]

# Extra check for batch_dim
if hasattr(x, 'names') and 'N' in x.names:
batch_dim = x.names.index('N')

self.batch_dim_act_map[name] = batch_dim

input_scales = self.scale_fn(x, dim=batch_dim)
if name not in self.float_act_map:
self.float_act_map[name] = input_scales
else:
self.float_act_map[name] = torch.max(self.float_act_map[name], input_scales)
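
The effect of the last branch is a running per-channel maximum across every batch the hook observes. A minimal sketch of that aggregation, with a simplified stand-in for the scale function (not brevitas's `_channel_maxabs`; channels are assumed at dim 1):

import torch

def channel_maxabs(x, dim=0):
    # Simplified stand-in: per-channel max of |x|, reducing every
    # dimension except the channel one (assumed to be dim 1).
    reduce_dims = tuple(d for d in range(x.ndim) if d != 1)
    return x.abs().amax(dim=reduce_dims)

float_act_map = {}
name = 'layer0'
for _ in range(3):                       # three forward passes / batches
    x = torch.randn(8, 16, 32)           # (batch, channels, length)
    scales = channel_maxabs(x, dim=0)
    float_act_map[name] = (
        scales if name not in float_act_map
        else torch.max(float_act_map[name], scales))

assert float_act_map[name].shape == (16,)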

def remove_hooks(self):
for hook in self.hooks:
ModuleInstanceToModuleInstance(hook, hook.module).apply(self.model)


class LayerwiseActivationEqualization(ActivationEqualization):

def __init__(self, model, scale_computation_type: str = 'maxabs'):
super(LayerwiseActivationEqualization, self).__init__(model, scale_computation_type)
self.float_act_map = {}
self.batch_dim_act_map = {}
self.hooks = []
@@ -703,7 +761,6 @@ def __init__(self, model, scale_computation_type: str = 'maxabs'):
self.find_module(model, regions)
self.regions = regions

self.scale_computation_type = scale_computation_type
if self.scale_computation_type == 'maxabs':
self.scale_fn = _channel_maxabs
elif self.scale_computation_type == 'range':
@@ -751,79 +808,34 @@ def apply(self, alpha):
alpha=alpha))
return scale_factors

def remove_hooks(self):
for hook in self.hooks:
ModuleInstanceToModuleInstance(hook, hook.module).apply(self.model)

def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **kwargs):
# Check for MHA Cross attention, and if found, skip it
kwargs.update(zip(module.forward.__code__.co_varnames[1:], args[:-1]))
if 'query' in kwargs and 'key' in kwargs and 'value' in kwargs:
if kwargs['query'].data_ptr() != kwargs['key'].data_ptr() != kwargs['value'].data_ptr():
self.float_act_map[name] = None
return

possible_input_kwargs = ['input', 'inp', 'query']
input_kwarg = [x for x in kwargs.keys() if x in possible_input_kwargs][0]
if use_inp:
x = kwargs[input_kwarg]
elif not use_inp:
x = args[-1]

# Extra check for batch_dim
if hasattr(x, 'names') and 'N' in x.names:
batch_dim = x.names.index('N')

self.batch_dim_act_map[name] = batch_dim

input_scales = self.scale_fn(x, dim=batch_dim)
if name not in self.float_act_map:
self.float_act_map[name] = input_scales
else:
self.float_act_map[name] = torch.max(self.float_act_map[name], input_scales)

def insert_mul_node(self, scale, shape, axis, region, batch_dim=0):
broadcastable_shape = [1] * len(shape)
broadcastable_shape[axis] = shape[axis]
# Add Batch Dim
broadcastable_shape.insert(batch_dim, 1)
mul_factor = ScaleBias(
num_features=shape[axis], bias=False, runtime_shape=broadcastable_shape)
mul_factor.weight.data = scale
mul_factor = self.create_mul_node(scale, shape, axis, batch_dim)
rewriter = ModuleInstanceToModuleInstance(
region, EqualizedModule(scale_module=mul_factor, layer=region))
rewriter.apply(self.model)
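
The rewriter above swaps the target layer for a wrapper that runs the scaling module first and the original layer second. A minimal sketch of that wrapping pattern (`ScaledWrapper` and the linear scale module are illustrative stand-ins for `EqualizedModule` and `ScaleBias`):

import torch
import torch.nn as nn

class ScaledWrapper(nn.Module):
    # Illustrative stand-in for EqualizedModule: apply the scale
    # module to the input, then the wrapped layer, leaving the
    # layer itself untouched.
    def __init__(self, scale_module, layer):
        super().__init__()
        self.scale_module = scale_module
        self.layer = layer

    def forward(self, x):
        return self.layer(self.scale_module(x))

layer = nn.Linear(16, 8)
scale_module = nn.Linear(16, 16, bias=False)   # stand-in for ScaleBias
wrapped = ScaledWrapper(scale_module, layer)
assert wrapped(torch.randn(4, 16)).shape == (4, 8)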


class GraphActivationEqualization(GraphTransform):
class GraphActivationEqualization(ActivationEqualization):

def __init__(
self, model, add_mul_node, layerwise=False, scale_computation_type: str = 'maxabs'):
super(GraphActivationEqualization, self).__init__()
self.graph_model = model
self,
model: GraphModule,
add_mul_node: bool = False,
scale_computation_type: str = 'maxabs'):
super(GraphActivationEqualization, self).__init__(model, scale_computation_type)
self.float_act_map = {}
self.batch_dim_act_map = {}
self.hooks = []
self.layerwise = layerwise
if self.layerwise:
self.add_mul_node = True
else:
self.add_mul_node = add_mul_node
if self.layerwise:
regions = []
self.find_module(model, regions)
self.regions = regions
else:
self.regions = _extract_regions(model, add_mul_node=add_mul_node, return_acts=True)
self.add_mul_node = add_mul_node
self.regions = _extract_regions(model, add_mul_node=add_mul_node, return_acts=True)

self.scale_computation_type = scale_computation_type
if self.scale_computation_type == 'maxabs':
self.scale_fn = _channel_maxabs
elif self.scale_computation_type == 'range':
self.scale_fn = _channel_range

def setup(self):
name_to_module = dict_name_to_module(self.graph_model, self.regions)
name_to_module = dict_name_to_module(self.model, self.regions)
# Select only regions with activations to equalize through.
# If a region has multiple scale-varying activations, it must also be
# dropped, because we can't propagate scaling factors through it.
@@ -835,29 +847,30 @@ def setup(self):
_scale_varying_activations)
for act_name in region.acts]):
regions_to_drop.append(region)
else:
# We assume that the entire region has a unique batch_dim
batch_dim = 0
region_to_search = region.sinks if len(region.acts) == 0 else region.acts
for name in region.srcs + region.sinks:
module = name_to_module[name]
if hasattr(module, 'batch_first'):
batch_dim = 0 if module.batch_first else 1
for name in region_to_search:
act_module = name_to_module[name]
use_inp = True if region_to_search == region.sinks else False
hook_fn = partial(
self.forward_stats_hook, name=name, batch_dim=batch_dim, use_inp=use_inp)
new_instance = KwargsForwardHook(act_module, hook_fn)
ModuleInstanceToModuleInstance(act_module, new_instance).apply(self.graph_model)
self.hooks.append(new_instance)
continue

# We assume that the entire region has a unique batch_dim
batch_dim = 0
region_to_search = region.sinks if len(region.acts) == 0 else region.acts
for name in region.srcs + region.sinks:
module = name_to_module[name]
if hasattr(module, 'batch_first'):
batch_dim = 0 if module.batch_first else 1
for name in region_to_search:
module = name_to_module[name]
use_inp = True if region_to_search == region.sinks else False
hook_fn = partial(
self.forward_stats_hook, name=name, batch_dim=batch_dim, use_inp=use_inp)
new_instance = KwargsForwardHook(module, hook_fn)
ModuleInstanceToModuleInstance(module, new_instance).apply(self.model)
self.hooks.append(new_instance)

self.regions = [x for x in self.regions if x not in regions_to_drop]
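
The hook-binding loop above follows the standard pattern of baking per-region metadata (`name`, `batch_dim`, `use_inp`) into a callback via `partial` before attaching it. The same pattern with plain PyTorch forward hooks instead of brevitas's `KwargsForwardHook` (a minimal sketch; the model and module names are illustrative):

import torch
import torch.nn as nn
from functools import partial

stats = {}

def forward_stats_hook(module, inputs, output, name, batch_dim=0):
    # Track the largest absolute input value seen for this module.
    x = inputs[0]
    stats[name] = max(stats.get(name, 0.0), x.abs().max().item())

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
hooks = [
    module.register_forward_hook(partial(forward_stats_hook, name=name))
    for name, module in model.named_modules()
    if isinstance(module, (nn.Linear, nn.ReLU))
]
model(torch.randn(4, 16))
for handle in hooks:
    handle.remove()   # analogous to remove_hooks() before apply()
print(stats)          # e.g. {'0': ..., '1': ...}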

def apply(self, alpha):
scale_factors = []
self.remove_hooks()
name_to_module = dict_name_to_module(self.graph_model, self.regions)
name_to_module = dict_name_to_module(self.model, self.regions)
for region in self.regions:
region_to_search = region.sinks if len(region.acts) == 0 else region.acts
if any([self.float_act_map[name] is None for name in region_to_search]):
@@ -877,7 +890,7 @@ def apply(self, alpha):
# Even though we iterate, this list will always have a single element by definition
list_of_insert_mul_node_fn = []
for act_name in region.acts:
act_node = get_node(self.graph_model, act_name)
act_node = get_node(self.model, act_name)
list_of_insert_mul_node_fn.append(
partial(
self.insert_mul_node,
Expand All @@ -895,46 +908,9 @@ def apply(self, alpha):

return scale_factors

def remove_hooks(self):
for hook in self.hooks:
ModuleInstanceToModuleInstance(hook, hook.module).apply(self.graph_model)

def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **kwargs):
# Check for MHA Cross attention, and if found, skip it
kwargs.update(zip(module.forward.__code__.co_varnames[1:], args[:-1]))
if 'query' in kwargs and 'key' in kwargs and 'value' in kwargs:
if kwargs['query'].data_ptr() != kwargs['key'].data_ptr() != kwargs['value'].data_ptr():
self.float_act_map[name] = None
return

possible_input_kwargs = ['input', 'inp', 'query']
input_kwarg = [x for x in kwargs.keys() if x in possible_input_kwargs][0]
if use_inp:
x = kwargs[input_kwarg]
elif not use_inp:
x = args[-1]

# Extra check for batch_dim
if hasattr(x, 'names') and 'N' in x.names:
batch_dim = x.names.index('N')

self.batch_dim_act_map[name] = batch_dim

input_scales = self.scale_fn(x, dim=batch_dim)
if name not in self.float_act_map:
self.float_act_map[name] = input_scales
else:
self.float_act_map[name] = torch.max(self.float_act_map[name], input_scales)

def insert_mul_node(self, scale, shape, axis, act_node, batch_dim=0):
broadcastable_shape = [1] * len(shape)
broadcastable_shape[axis] = shape[axis]
# Add Batch Dim
broadcastable_shape.insert(batch_dim, 1)
mul_factor = ScaleBias(
num_features=shape[axis], bias=False, runtime_shape=broadcastable_shape)
mul_factor.weight.data = scale
mul_factor = self.create_mul_node(scale, shape, axis, batch_dim)
mul_factor_name = act_node.name + 'act_eq_mul'
self.graph_model.add_module(mul_factor_name, mul_factor)
self.model.add_module(mul_factor_name, mul_factor)
rewriter = InsertModuleCallAfter(mul_factor_name, act_node)
rewriter.apply(self.graph_model)
rewriter.apply(self.model)
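
In the graph variant, the scale module is registered on the model and a call to it is spliced into the FX graph right after the activation node. A minimal torch.fx sketch of that splice (independent of brevitas's `InsertModuleCallAfter`; `Net` and the `nn.Identity` stand-in are illustrative):

import torch
import torch.nn as nn
import torch.fx as fx

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(16, 16)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.lin(x))

gm = fx.symbolic_trace(Net())
gm.add_module('act_eq_mul', nn.Identity())   # stand-in for the ScaleBias module
for node in gm.graph.nodes:
    if node.op == 'call_module' and node.target == 'act':
        with gm.graph.inserting_after(node):
            new_node = gm.graph.call_module('act_eq_mul', args=(node,))
        # Route every consumer of the activation through the new node,
        # then restore the new node's own input to the activation.
        node.replace_all_uses_with(new_node)
        new_node.args = (node,)

gm.graph.lint()
gm.recompile()
assert gm(torch.randn(2, 16)).shape == (2, 16)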
