Skip to content

Commit

Permalink
More fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Oct 24, 2024
1 parent 9a3f25d commit 293dcdc
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 31 deletions.
12 changes: 0 additions & 12 deletions src/brevitas/core/restrict_val.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,6 @@ def restrict_init_module(self):
def restrict_init_inplace_module(self):
return Identity()

def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
return x / threshold

@brevitas.jit.script_method
def forward(self, x: Tensor) -> Tensor:
return x
Expand All @@ -116,9 +113,6 @@ def restrict_init_module(self):
def restrict_init_inplace_module(self):
return InplaceLogTwo()

def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
return x - threshold

@brevitas.jit.script_method
def forward(self, x: Tensor):
x = self.power_of_two(x)
Expand All @@ -143,9 +137,6 @@ def restrict_init_module(self):
def restrict_init_inplace_module(self):
return Identity()

def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
return x / threshold

@brevitas.jit.script_method
def forward(self, x: Tensor):
x = self.float_to_int_impl(x)
Expand All @@ -171,9 +162,6 @@ def restrict_init_module(self):
def restrict_init_inplace_module(self):
return InplaceLogTwo()

def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
return x - threshold

@brevitas.jit.script_method
def forward(self, x: Tensor):
x = self.float_to_int_impl(x)
Expand Down
3 changes: 2 additions & 1 deletion src/brevitas/core/scaling/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,11 @@ def forward(
if threshold is None:
threshold = torch.ones(1).type_as(stats)
threshold = self.restrict_scaling_pre(threshold)
threshold = self.restrict_clamp_scaling(threshold)
stats = self.restrict_scaling_pre(stats)
stats = self.restrict_scaling_impl.combine_scale_threshold(stats, threshold)
stats = self.affine_rescaling(stats)
stats = self.restrict_clamp_scaling(stats)
stats = stats / threshold
return stats


Expand Down
38 changes: 22 additions & 16 deletions src/brevitas/core/scaling/standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,10 @@ def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor
# This is because we don't want to store a parameter dependent on a runtime value (threshold)
# And because restrict needs to happen after we divide by threshold
if self.init_done:
threshold = self.restrict_inplace_preprocess(threshold)
value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value))
threshold = self.stats_scaling_impl.restrict_clamp_scaling(
self.restrict_preprocess(threshold))
value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value))
value = value / threshold
return value
else:
stats = self.parameter_list_stats()
Expand All @@ -231,10 +232,11 @@ def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor
if self.local_loss_mode:
return self.stats_scaling_impl(stats, threshold)
stats = self.restrict_inplace_preprocess(stats)
threshold = self.restrict_inplace_preprocess(threshold)
threshold = self.stats_scaling_impl.restrict_clamp_scaling(
self.restrict_preprocess(threshold))
inplace_tensor_mul(self.value.detach(), stats)
value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value))
value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value))
value = value / threshold
self.init_done = True
return value

Expand Down Expand Up @@ -360,14 +362,16 @@ def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor:
elif self.counter == self.collect_stats_steps:
self.restrict_inplace_preprocess(self.buffer)
inplace_tensor_mul(self.value.detach(), self.buffer)
threshold = self.restrict_preprocess(threshold)
value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
threshold = self.restrict_scaling(self.restrict_preprocess(threshold))
value = self.clamp_scaling(self.restrict_scaling(self.value))
value = value / threshold
self.counter = self.counter + 1
return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
return abs_binary_sign_grad(value)
else:
threshold = self.restrict_preprocess(threshold)
value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
threshold = self.restrict_scaling(self.restrict_preprocess(threshold))
value = self.clamp_scaling(self.restrict_scaling(self.value))
value = value / threshold
return abs_binary_sign_grad(value)

@brevitas.jit.script_method
def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
Expand All @@ -378,12 +382,14 @@ def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Te
return self.training_forward(stats_input, threshold)
else:
if self.counter <= self.collect_stats_steps:
out = self.buffer / threshold
out = self.buffer
out = self.restrict_preprocess(out)
else:
threshold = self.restrict_preprocess(threshold)
out = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
out = abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(out)))
out = self.value
threshold = self.restrict_scaling(self.restrict_preprocess(threshold))
out = self.clamp_scaling(self.restrict_scaling(out))
out = out / threshold
out = abs_binary_sign_grad(self.clamp_scaling(out))
return out

def state_dict(self, destination=None, prefix='', keep_vars=False):
Expand Down
3 changes: 1 addition & 2 deletions tests/brevitas/graph/test_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def reference_implementation_scale_factors_po2(
return scale


@given(inp=float_tensor_random_size_st())
@given(inp=float_tensor_random_size_st(max_val=1e10, min_val=-1e10))
def test_scale_factors_ptq_calibration_po2(inp):

class TestModel(nn.Module):
Expand All @@ -74,7 +74,6 @@ def forward(self, x):

expected_scale = reference_implementation_scale_factors_po2(inp)
scale = model.act.act_quant.scale()

assert torch.allclose(expected_scale, scale)


Expand Down

0 comments on commit 293dcdc

Please sign in to comment.