From 7d724f622f7d66d2255f2d3821bf94b258858b14 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 15 Dec 2023 13:17:18 +0000 Subject: [PATCH] Fix --- src/brevitas/graph/equalize.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index f753eba8c..c28bce57d 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -727,7 +727,7 @@ def setup(self): for region in self.regions: batch_dim = 0 if hasattr(region, 'batch_first'): - batch_dim = 0 if region.batch_first == True else 1 + batch_dim = 0 if region.batch_first else 1 hook_fn = partial( self.forward_stats_hook, name=region, batch_dim=batch_dim, use_inp=True) @@ -844,7 +844,7 @@ def setup(self): for name in region.srcs + region.sinks: module = name_to_module[name] if hasattr(module, 'batch_first'): - batch_dim = 0 if module.batch_first == True else 1 + batch_dim = 0 if module.batch_first else 1 for name in region_to_search: act_module = name_to_module[name] use_inp = True if region_to_search == region.sinks else False @@ -920,14 +920,6 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k if hasattr(x, 'names') and 'N' in x.names: batch_dim = x.names.index('N') - self.batch_dim_act_map[name] = batch_dim - if name not in self.float_act_map: - self.float_act_map[name] = self.scale_fn(x, dim=batch_dim) - else: - batch_data = torch.cat([self.float_act_map[name].unsqueeze(batch_dim), x], - dim=batch_dim) - self.float_act_map[name] = self.scale_fn(batch_data, dim=batch_dim) - input_scales = self.scale_fn(x, dim=batch_dim) if name not in self.float_act_map: self.float_act_map[name] = input_scales