Skip to content

Commit

Permalink
Review
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 18, 2023
1 parent 60e15d0 commit bbeca9f
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 25 deletions.
38 changes: 19 additions & 19 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@ class Region:
sinks: Dict = field(default_factory=dict)
acts: Tuple = field(default_factory=tuple)

@property
def srcs_names(self):
return [name.split("$")[0] for name in self.srcs.keys()]

@property
def sinks_names(self):
return [name.split("$")[0] for name in self.sinks.keys()]


@dataclass
class WalkRegionState:
Expand Down Expand Up @@ -176,14 +184,11 @@ def dict_name_to_module(model, regions):

name_set = set()
for region in regions:
for name in region.srcs:
name = name.split("$")[0]
for name in region.srcs_names:
name_set.add(name)
for name in region.sinks:
name = name.split("$")[0]
for name in region.sinks_names:
name_set.add(name)
for name in region.acts:
name = name.split("$")[0]
name_set.add(name)
for name, module in model.named_modules():
if name in name_set:
Expand Down Expand Up @@ -363,7 +368,6 @@ def _cross_layer_equalization(
# If equalization criteria are not met, we return a scalar one to indicate that no equalization
# has been performed
def _no_equalize():
print("No Eq")
return torch.tensor(1., dtype=dtype, device=device)

act_sink_axes = {}
Expand All @@ -378,12 +382,11 @@ def _no_equalize():
max_shape_sinks = 0
for name, (k, v) in sinks.items():
max_shape_sinks = max(max_shape_sinks, v.offset + (v.end - v.start))
# Exit if source and sink have different sizes

# Exit if source and sink have different sizes
if max_shape_srcs != max_shape_sinks and len(srcs) > 0:
return _no_equalize()

# for i, module in enumerate(srcs):
src_axes = {}
for i, (name, (module, indexes)) in enumerate(srcs.items()):
# If module is not supported, do not perform graph equalization
Expand Down Expand Up @@ -577,10 +580,10 @@ def _equalize(
name_to_module: Dict[str, nn.Module] = {}
name_set = set()
for region in regions:
for name in region.srcs.keys():
name_set.add(name.split("$")[0])
for name in region.sinks.keys():
name_set.add(name.split("$")[0])
for name in region.srcs_names:
name_set.add(name)
for name in region.sinks_names:
name_set.add(name)

for name, module in model.named_modules():
if name in name_set:
Expand Down Expand Up @@ -737,7 +740,6 @@ def find_srcs(graph_model: GraphModule, starting_node: Node,
module = get_module(graph_model, node.target)
weight = get_weight_source([module])
eq_indexes = EqualizationIndexes(0, weight[0].shape[0], state.offset)
# state.srcs.add((node.target, eq_indexes))
full_source_name = node.target + '$' + str(eq_indexes)
state.srcs[full_source_name] = eq_indexes
# After we found a source, we need to check if it branches into multiple sinks
Expand Down Expand Up @@ -792,12 +794,10 @@ def find_sinks(graph_model: GraphModule, starting_node: Node,
eq_indexes = EqualizationIndexes(0, weight[0].shape[0], state.offset)
# It is not possible to equalize through LayerNorm as sink
if isinstance(module, (nn.LayerNorm,) + _batch_norm):
# state.sinks.add((_UNSUPPORTED_OP, _UNSUPPORTED_OP))
state.sinks[_UNSUPPORTED_OP] = _UNSUPPORTED_OP
else:
full_sink_name = node.target + '$' + str(eq_indexes)
state.sinks[full_sink_name] = eq_indexes
# state.sinks[node.target] = eq_indexes
elif _is_scale_invariant_module(
graph_model, node) or _is_scale_invariant_function(node) or _is_reshaping_op(node):
find_sinks(graph_model, node, state)
Expand Down Expand Up @@ -1092,11 +1092,11 @@ def apply(self, alpha):
self.remove_hooks()
name_to_module = dict_name_to_module(self.model, self.regions)
for region in self.regions:
region_to_search = region.sinks if len(region.acts) == 0 else region.acts
if any([self.float_act_map[name.split("$")[0]] is None for name in region_to_search]):
region_names = region.sinks_names if len(region.acts) == 0 else region.acts
if any([self.float_act_map[name] is None for name in region_names]):
continue
act_module = [name_to_module[act_name.split("$")[0]] for act_name in region.acts]
list_of_act_val = [self.float_act_map[name.split("$")[0]] for name in region_to_search]
act_module = [name_to_module[act_name] for act_name in region.acts]
list_of_act_val = [self.float_act_map[name] for name in region_names]

list_of_insert_mul_node_fn = None
if self.add_mul_node and any([
Expand Down
6 changes: 2 additions & 4 deletions tests/brevitas/graph/equalization_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,9 @@ def equalize_test(model, regions, merge_bias, bias_shrinkage, scale_computation_
name_to_module = {}
name_set = set()
for region in regions:
for name in region.srcs:
name = name.split('$')[0]
for name in region.srcs_names:
name_set.add(name)
for name in region.sinks:
name = name.split('$')[0]
for name in region.sinks_names:
name_set.add(name)
scale_factors_regions = []
for name, module in model.named_modules():
Expand Down
4 changes: 2 additions & 2 deletions tests/brevitas/graph/test_equalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def test_equalization_torchvision_models(model_coverage: tuple, merge_bias: bool
srcs = set()
sinks = set()
for r in regions:
srcs.update([x.split("$")[0] for x in list(r.srcs)])
sinks.update([x.split("$")[0] for x in list(r.sinks)])
srcs.update([x for x in list(r.srcs_names)])
sinks.update([x for x in list(r.sinks_names)])

count_region_srcs = 0
count_region_sinks = 0
Expand Down

0 comments on commit bbeca9f

Please sign in to comment.