Skip to content

Commit 74a802f

Browse files
committed
Fix
1 parent 06bcaae commit 74a802f

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

src/brevitas/graph/equalize.py

+2
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,8 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k
922922
if hasattr(x, 'names') and 'N' in x.names:
923923
batch_dim = x.names.index('N')
924924

925+
self.batch_dim_act_map[name] = batch_dim
926+
925927
input_scales = self.scale_fn(x, dim=batch_dim)
926928
if name not in self.float_act_map:
927929
self.float_act_map[name] = input_scales

src/brevitas_examples/llm/llm_quant/equalize.py

-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def activation_equalization_iter(curr_layer, inps, outs, cached_values, alpha):
3131
def trace_and_standardize(model, ref_kwargs):
3232
graph_model = value_trace(model, value_args=ref_kwargs)
3333
graph_model = TorchFunctionalToModule().apply(graph_model)
34-
graph_model = DuplicateSharedStatelessModule().apply(graph_model)
3534
return graph_model
3635

3736

0 commit comments

Comments (0)