diff --git a/src/brevitas/core/restrict_val.py b/src/brevitas/core/restrict_val.py index 59b3fe8ec..7eb9845f9 100644 --- a/src/brevitas/core/restrict_val.py +++ b/src/brevitas/core/restrict_val.py @@ -90,9 +90,6 @@ def restrict_init_module(self): def restrict_init_inplace_module(self): return Identity() - def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor: - return x / threshold - @brevitas.jit.script_method def forward(self, x: Tensor) -> Tensor: return x @@ -116,9 +113,6 @@ def restrict_init_module(self): def restrict_init_inplace_module(self): return InplaceLogTwo() - def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor: - return x - threshold - @brevitas.jit.script_method def forward(self, x: Tensor): x = self.power_of_two(x) @@ -143,9 +137,6 @@ def restrict_init_module(self): def restrict_init_inplace_module(self): return Identity() - def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor: - return x / threshold - @brevitas.jit.script_method def forward(self, x: Tensor): x = self.float_to_int_impl(x) @@ -171,9 +162,6 @@ def restrict_init_module(self): def restrict_init_inplace_module(self): return InplaceLogTwo() - def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor: - return x - threshold - @brevitas.jit.script_method def forward(self, x: Tensor): x = self.float_to_int_impl(x) diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py index 2dc4cea1c..fee4175bc 100644 --- a/src/brevitas/core/scaling/runtime.py +++ b/src/brevitas/core/scaling/runtime.py @@ -90,10 +90,11 @@ def forward( if threshold is None: threshold = torch.ones(1).type_as(stats) threshold = self.restrict_scaling_pre(threshold) + threshold = self.restrict_clamp_scaling(threshold) stats = self.restrict_scaling_pre(stats) - stats = self.restrict_scaling_impl.combine_scale_threshold(stats, threshold) stats = self.affine_rescaling(stats) stats = self.restrict_clamp_scaling(stats) + 
stats = stats / threshold return stats diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py index 4917b859a..e43fd577a 100644 --- a/src/brevitas/core/scaling/standalone.py +++ b/src/brevitas/core/scaling/standalone.py @@ -220,9 +220,10 @@ def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor # This is because we don't want to store a parameter dependant on a runtime value (threshold) # And because restrict needs to happen after we divide by threshold if self.init_done: - threshold = self.restrict_inplace_preprocess(threshold) - value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) - value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value)) + threshold = self.stats_scaling_impl.restrict_clamp_scaling( + self.restrict_preprocess(threshold)) + value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value)) + value = value / threshold return value else: stats = self.parameter_list_stats() @@ -231,10 +232,11 @@ def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor if self.local_loss_mode: return self.stats_scaling_impl(stats, threshold) stats = self.restrict_inplace_preprocess(stats) - threshold = self.restrict_inplace_preprocess(threshold) + threshold = self.stats_scaling_impl.restrict_clamp_scaling( + self.restrict_preprocess(threshold)) inplace_tensor_mul(self.value.detach(), stats) - value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) - value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value)) + value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value)) + value = value / threshold self.init_done = True return value @@ -360,14 +362,16 @@ def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor: elif self.counter == self.collect_stats_steps: self.restrict_inplace_preprocess(self.buffer) 
inplace_tensor_mul(self.value.detach(), self.buffer) - threshold = self.restrict_preprocess(threshold) - value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) + threshold = self.restrict_scaling(self.restrict_preprocess(threshold)) + value = self.clamp_scaling(self.restrict_scaling(self.value)) + value = value / threshold self.counter = self.counter + 1 - return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value))) + return abs_binary_sign_grad(value) else: - threshold = self.restrict_preprocess(threshold) - value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) - return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value))) + threshold = self.restrict_scaling(self.restrict_preprocess(threshold)) + value = self.clamp_scaling(self.restrict_scaling(self.value)) + value = value / threshold + return abs_binary_sign_grad(value) @brevitas.jit.script_method def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Tensor: @@ -378,12 +382,14 @@ def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Te return self.training_forward(stats_input, threshold) else: if self.counter <= self.collect_stats_steps: - out = self.buffer / threshold + out = self.buffer out = self.restrict_preprocess(out) else: - threshold = self.restrict_preprocess(threshold) - out = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) - out = abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(out))) + out = self.value + threshold = self.restrict_scaling(self.restrict_preprocess(threshold)) + out = self.clamp_scaling(self.restrict_scaling(out)) + out = out / threshold + out = abs_binary_sign_grad(self.clamp_scaling(out)) return out def state_dict(self, destination=None, prefix='', keep_vars=False): diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py index 10d8f7e7c..58600b6c8 100644 --- 
a/tests/brevitas/graph/test_calibration.py +++ b/tests/brevitas/graph/test_calibration.py @@ -54,7 +54,7 @@ def reference_implementation_scale_factors_po2( return scale -@given(inp=float_tensor_random_size_st()) +@given(inp=float_tensor_random_size_st(max_val=1e10, min_val=-1e10)) def test_scale_factors_ptq_calibration_po2(inp): class TestModel(nn.Module): @@ -74,7 +74,7 @@ def forward(self, x): expected_scale = reference_implementation_scale_factors_po2(inp) scale = model.act.act_quant.scale() assert torch.allclose(expected_scale, scale)