Skip to content

Commit d23b2d4

Browse files
committed
fix
1 parent 149bc32 commit d23b2d4

File tree

3 files changed

+4
-5
lines changed

3 files changed

+4
-5
lines changed

src/brevitas/config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,4 @@ def env_to_bool(name, default):
2525
_FULL_STATE_DICT = False
2626
_IS_INSIDE_QUANT_LAYER = None
2727
_ONGOING_EXPORT = None
28-
_RETROCOMPATIBLE_SCALING = False
28+
_RETROCOMPATIBLE_SCALING = False

src/brevitas/core/restrict_val.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def restrict_init_module(self):
142142

143143
def restrict_init_inplace_module(self):
144144
return Identity()
145-
145+
146146
def retrocompatibility_op(self, x):
147147
return x
148148

@@ -170,7 +170,7 @@ def restrict_init_module(self):
170170

171171
def restrict_init_inplace_module(self):
172172
return InplaceLogTwo()
173-
173+
174174
def retrocompatibility_op(self, x):
175175
return self.power_of_two(x)
176176

src/brevitas/core/scaling/standalone.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def __init__(
150150

151151
@brevitas.jit.script_method
152152
def forward(self, placeholder: Tensor, threshold: torch.Tensor) -> Tensor:
153-
value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value / threshold))
153+
value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value) / threshold)
154154
return value
155155

156156
def _load_from_state_dict(
@@ -214,7 +214,6 @@ def forward(self, ignored: torch.Tensor, threshold: torch.Tensor) -> torch.Tenso
214214
if self.local_loss_mode:
215215
return self.stats_scaling_impl(stats, threshold)
216216
inplace_tensor_mul(self.value.detach(), stats)
217-
print(self.restrict_inplace_preprocess)
218217
value = self.restrict_inplace_preprocess(self.value / threshold)
219218
value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value))
220219
self.init_done = True

0 commit comments

Comments
 (0)