Skip to content

Commit

Permalink
Review
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 20, 2023
1 parent dabb8c5 commit 0f67b6b
Showing 1 changed file with 10 additions and 23 deletions.
33 changes: 10 additions & 23 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class WalkRegionState:
history: set = field(default_factory=set)
name_to_module: Dict = field(default_factory=dict)

add_mul_node: bool = False
cat_encoutered: bool = False
offset: int = 0
update_offset: bool = False

Expand Down Expand Up @@ -471,7 +471,6 @@ def _no_equalize():
# Use the offset and the range to update the correct range in the sinks
sinks_range[indexes.offset:indexes.offset + channel_range] = torch.max(
sinks_range[indexes.offset:indexes.offset + channel_range], weight_range)
sinks_range = torch.clamp(sinks_range, EPSILON)

# Determine the srcs_range based on where we are performing activation equalization or
# weight equalization
Expand Down Expand Up @@ -589,7 +588,7 @@ def _equalize(
"""
for i in range(iterations):
scale_factor_max = None
for ii, region in enumerate(regions):
for region in regions:
scale_factors_region = _cross_layer_equalization(
region,
merge_bias=merge_bias,
Expand Down Expand Up @@ -699,7 +698,8 @@ def cat_handler(graph_model: GraphModule, starting_node: Node, state: WalkRegion
state.srcs.clear()
state.sinks.clear()
state.history.clear()
state.srcs[starting_node.target] = _UNSUPPORTED_OP
# Keep track that concatenation has been encoutered once
state.cat_encoutered = True
state.update_offset = True
state.offset = 0
find_srcs(graph_model, starting_node, state)
Expand All @@ -712,13 +712,6 @@ def _is_cat(node):
return node.target in (torch.cat,)


def _is_cat_in_srcs(srcs):
out = False
for src in srcs:
out = out or src in (torch.cat,)
return out


def _is_add(node):
return (
node.op == 'call_method' and node.target in _residual_methods or
Expand All @@ -733,7 +726,6 @@ def find_srcs(graph_model: GraphModule, starting_node: Node,
# we keep a history of how the graph has been walked already, invariant to the direction,
# to avoid getting stuck in a loop
path = (node, starting_node)
module = None
if path not in state.history:
state.history.add(path)
else:
Expand All @@ -759,11 +751,10 @@ def find_srcs(graph_model: GraphModule, starting_node: Node,
find_srcs(graph_model, node, state)
state.update_offset = update_offset_state
elif _is_cat(node):
# We have never encoutered cat
if not _is_cat_in_srcs(state.srcs):
# The first time we encoutered a cat differes from all subsequent ones
if not state.cat_encoutered:
# We restart the region search starting from the cat
cat_handler(graph_model, node, state)
# We have encoutered cat already
else:
state.update_offset = False
find_sinks(graph_model, node, state)
Expand Down Expand Up @@ -810,11 +801,10 @@ def find_sinks(graph_model: GraphModule, starting_node: Node,
find_srcs(graph_model, node, state)
state.update_offset = update_offset_state
elif _is_cat(node):
# We have never encoutered cat
if not _is_cat_in_srcs(state.srcs):
# The first time we encoutered a cat differes from all subsequent ones
if not state.cat_encoutered:
# We restart the region search starting from the cat
cat_handler(graph_model, node, state)
# We have encoutered cat already
else:
# In this case we define all our sinks, and isolate only the channels we want
# to equalize (start, end).
Expand Down Expand Up @@ -854,7 +844,7 @@ def _extract_regions(
if _is_supported_module(graph_model,
node) or (add_mul_node and
_is_scale_varying_activation(graph_model, node)):
state = WalkRegionState(add_mul_node=add_mul_node)
state = WalkRegionState()
if _is_scale_varying_activation(graph_model, node):
module = get_module(graph_model, node.target)
state.add_acts(node.target, module)
Expand All @@ -864,10 +854,7 @@ def _extract_regions(
eq_indexes = EqualizationIndexes(0, weight[0].shape[0], 0)
state.add_srcs(node.target, module, eq_indexes)
find_sinks(graph_model, node, state)
if state.sinks and _UNSUPPORTED_OP not in state.sinks.keys(
) and _UNSUPPORTED_OP not in state.srcs.keys():
# Drop cat from the srcs
state.srcs = {k: v for k, v in state.srcs.items() if k is not torch.cat}
if len(state.sinks) > 0 and _UNSUPPORTED_OP not in state.sinks.keys():
sorted_srcs = dict(sorted(state.srcs.items()))
sorted_sinks = dict(sorted(state.sinks.items()))
sorted_acts = tuple(sorted(state.acts))
Expand Down

0 comments on commit 0f67b6b

Please sign in to comment.