Commit 630f32e
Giuseppe5 committed Sep 30, 2024
1 parent d23b2d4 commit 630f32e
Showing 6 changed files with 56 additions and 18 deletions.
src/brevitas/core/quant/binary.py (4 changes: 2 additions & 2 deletions)
@@ -58,7 +58,7 @@ def __init__(self, scaling_impl: Module, signed: bool = True, quant_delay_steps:

     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
-        scale = self.scaling_impl(x, torch.tensor(1.).type_as(x))
+        scale = self.scaling_impl(x)
         y = binary_sign_ste(x) * scale
         y = self.delay_wrapper(x, y)
         return y, scale, self.zero_point(), self.bit_width()
@@ -118,7 +118,7 @@ def __init__(

     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
-        scale = self.scaling_impl(x, torch.tensor(1.).type_as(x))
+        scale = self.scaling_impl(x)
         y = self.tensor_clamp_impl(x, -scale, scale)
         y = binary_sign_ste(y) * scale
         y = self.delay_wrapper(x, y)
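
The same pattern repeats across the quantizers below: call sites stop passing an explicit unit threshold because the scaling implementations (see runtime.py and standalone.py further down) now default it to ones when it is omitted. A minimal sketch of the calling convention, using a toy module rather than the actual Brevitas classes:

from typing import Optional

import torch


# Toy stand-in for a Brevitas scaling implementation, only to illustrate the
# optional-threshold convention introduced by this commit.
class ToyScaling(torch.nn.Module):

    def forward(self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
        if threshold is None:
            threshold = torch.ones(1).type_as(x)  # default now provided by the scaling impl
        return x.abs().max() / threshold


scaling_impl = ToyScaling()
x = torch.randn(4)
old_style = scaling_impl(x, torch.tensor(1.).type_as(x))  # what call sites used to do
new_style = scaling_impl(x)  # what they do after this commit
assert torch.equal(old_style, new_style)
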
src/brevitas/core/quant/float.py (2 changes: 1 addition & 1 deletion)
@@ -72,7 +72,7 @@ def quantize(self, x: torch.Tensor):
             float_scaling_impl_value = self.float_scaling_impl(
                 self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
         else:
-            float_scaling_impl_value = torch.tensor(1.).type_as(x)
+            float_scaling_impl_value = None
         scale = self.scaling_impl(x, float_scaling_impl_value)
         x = self.input_view_impl(x)
         scaled_x = x / scale
src/brevitas/core/quant/ternary.py (2 changes: 1 addition & 1 deletion)
@@ -61,7 +61,7 @@ def __init__(self, scaling_impl: Module, threshold: float, quant_delay_steps: in

     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
-        scale = self.scaling_impl(x, torch.tensor(1.).type_as(x))
+        scale = self.scaling_impl(x)
         mask = x.abs().gt(self.threshold * scale)
         y = mask.float() * ternary_sign_ste(x)
         y = y * scale
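
For reference, the ternary step in the hunk above keeps only values whose magnitude exceeds threshold * scale and maps them to plus or minus scale. A small worked example with assumed numbers (the real module obtains scale from scaling_impl and uses a straight-through sign estimator):

import torch

x = torch.tensor([0.9, -0.05, -0.7, 0.1])
scale = torch.tensor(1.0)  # assumed; normally produced by scaling_impl(x)
threshold = 0.5            # the module's fixed ternary threshold
mask = x.abs().gt(threshold * scale)       # [True, False, True, False]
y = mask.float() * torch.sign(x) * scale   # [1., 0., -1., 0.] (up to signed zeros)
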
src/brevitas/core/scaling/pre_scaling.py (4 changes: 2 additions & 2 deletions)
@@ -97,7 +97,7 @@ def forward(self, weights: Tensor) -> Tensor:
         weights = self.stats_input_view_shape_impl(weights)
         d_w = self.stats(weights)  # denominator for weight normalization
         g = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))  # g
-        s = self.scaling_impl(weights, torch.tensor(1.).type_as(weights))  # s
+        s = self.scaling_impl(weights)  # s
         value = (s * d_w) / g
         return value

@@ -184,7 +184,7 @@ def calc_max_l1_norm(self, input_bit_width: Tensor, input_is_signed: bool) -> Te
     def inner_forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: bool):
         weights = self.stats_input_view_shape_impl(weights)
         d_w = self.stats(weights)  # denominator for weight normalization
-        s = self.scaling_impl(weights, torch.tensor(1.).type_as(weights))  # s
+        s = self.scaling_impl(weights)  # s
         g = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))  # g
         T = self.calc_max_l1_norm(input_bit_width, input_is_signed)  # T / s
         g = torch.clamp_max(g / s, T)
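
A tiny numeric sketch of the two expressions touched above, with stand-in scalars in place of the stats and learned parameters (illustrative values only, not the Brevitas modules):

import torch

s = torch.tensor(0.5)    # stand-in for scaling_impl(weights)
d_w = torch.tensor(4.0)  # stand-in for self.stats(weights), the weight-norm denominator
g = torch.tensor(2.0)    # stand-in for the learned norm parameter g
value = (s * d_w) / g    # weight-norm pre-scaling value: 1.0

T = torch.tensor(3.0)                 # stand-in for calc_max_l1_norm(input_bit_width, input_is_signed)
g_capped = torch.clamp_max(g / s, T)  # accumulator-aware variant caps g / s at T: min(4.0, 3.0) = 3.0
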
src/brevitas/core/scaling/runtime.py (19 changes: 15 additions & 4 deletions)
@@ -51,7 +51,10 @@ def __init__(
             device)
 
     @brevitas.jit.script_method
-    def forward(self, ignored: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
+    def forward(
+            self, ignored: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
         stats = self.parameter_list_stats()
+        if threshold is None:
+            threshold = torch.ones(1).type_as(stats)
         return self.stats_scaling_impl(stats, threshold)

@@ -80,7 +83,10 @@ def __init__(
         self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
 
     @brevitas.jit.script_method
-    def forward(self, stats: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
+    def forward(
+            self, stats: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(stats)
         stats = self.restrict_scaling_pre(stats / threshold)
         stats = self.affine_rescaling(stats)
         stats = self.restrict_clamp_scaling(stats)
@@ -120,7 +126,7 @@ def __init__(
             device)
 
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
         stats = self.runtime_stats(x)
         return self.stats_scaling_impl(stats, threshold)

@@ -179,7 +185,12 @@ def __init__(
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
 
     @brevitas.jit.script_method
-    def forward(self, stats_input: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
+    def forward(
+            self,
+            stats_input: torch.Tensor,
+            threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(stats_input)
         stats_input_reshaped = self.input_view_impl(stats_input)
         out = self.scaling_stats_impl(stats_input_reshaped) / threshold
         # Scaling min val
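
The default added in each forward above follows the same idiom: build a unit threshold with the dtype (and CPU/CUDA placement) of the tensor being scaled, so the later division never mixes tensor types. A minimal sketch:

import torch

x = torch.randn(3, dtype=torch.float16)
threshold = torch.ones(1).type_as(x)  # float16 ones, same tensor type as x
scale_stats = x.abs().max()           # stand-in for the stats-based scale computed above
scale = scale_stats / threshold       # identical to scale_stats when the threshold is 1
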
src/brevitas/core/scaling/standalone.py (43 changes: 35 additions & 8 deletions)
@@ -77,7 +77,9 @@ def __init__(
         self.value = StatelessBuffer(torch.tensor(scaling_init, dtype=dtype, device=device))
 
     @brevitas.jit.script_method
-    def forward(self, placeholder: Tensor, threshold: torch.Tensor) -> Tensor:
+    def forward(self, placeholder: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(placeholder)
         value = self.value() / threshold
         restricted_value = self.restrict_clamp_scaling(value)
         return restricted_value
@@ -149,7 +151,9 @@ def __init__(
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
 
     @brevitas.jit.script_method
-    def forward(self, placeholder: Tensor, threshold: torch.Tensor) -> Tensor:
+    def forward(self, placeholder: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(placeholder)
         value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value) / threshold)
         return value

@@ -197,14 +201,23 @@ def __init__(
         self.local_loss_mode: bool = brevitas.jit.Attribute(False, bool)
         if restrict_scaling_impl is not None:
             self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
+            self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
         else:
             self.restrict_inplace_preprocess = Identity()
+            self.restrict_preprocess = Identity()
+
         self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device))

     @brevitas.jit.script_method
-    def forward(self, ignored: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
+    def forward(
+            self, ignored: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(ignored)
+        # Threshold division must happen after we update self.value, but before we apply restrict_preprocess.
+        # This is because we don't want to store a parameter that depends on a runtime value (the threshold),
+        # and because restrict needs to happen after we divide by the threshold.
         if self.init_done:
-            value = self.restrict_inplace_preprocess(self.value / threshold)
+            value = self.restrict_preprocess(self.value / threshold)
             value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value))
             return value
         else:
@@ -214,7 +227,7 @@ def forward(self, ignored: torch.Tensor, threshold: torch.Tensor) -> torch.Tenso
             if self.local_loss_mode:
                 return self.stats_scaling_impl(stats, threshold)
             inplace_tensor_mul(self.value.detach(), stats)
-            value = self.restrict_inplace_preprocess(self.value / threshold)
+            value = self.restrict_preprocess(self.value / threshold)
             value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value))
             self.init_done = True
             return value
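
The comment added above pins down an ordering constraint. A sketch with a log2-style restriction and made-up numbers (not the Brevitas modules) shows why dividing by the threshold must precede the restriction, and why the divided result must not be written back into the parameter:

import torch

value = torch.tensor(8.0)      # learned scale parameter, linear domain
threshold = torch.tensor(2.0)  # runtime-dependent threshold

log_correct = torch.log2(value / threshold)  # divide first, then restrict: log2(4) = 2
log_wrong = torch.log2(value) / threshold    # restrict first, then divide: 3 / 2 = 1.5

# Writing value / threshold back into self.value (as the old in-place preprocess did)
# would bake a runtime quantity into the stored parameter, which the comment warns against.
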
@@ -231,6 +244,10 @@ def _load_from_state_dict(
             self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
             error_msgs):
         value_key = prefix + 'value'
+
+        # Previously, the parameter was stored after restrict_preprocess had been applied (e.g., Log2).
+        # When retrocompatibility is enabled, we apply the opposite operation on load (e.g., Po2),
+        # knowing that the forward pass will re-apply restrict_preprocess (e.g., Log2 again).
         if config._RETROCOMPATIBLE_SCALING:
             if not isinstance(self.restrict_scaling_impl, Identity):
                 state_dict[value_key] = self.restrict_scaling_impl.retrocompatibility_op(
@@ -325,7 +342,11 @@ def __init__(
             self.restrict_preprocess = Identity()
 
     @brevitas.jit.script_method
-    def training_forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tensor:
+    def training_forward(
+            self, stats_input: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
+        # Threshold division must happen after we update self.value, but before we apply restrict_preprocess.
+        # This is because we don't want to store a parameter that depends on a runtime value (the threshold),
+        # and because restrict needs to happen after we divide by the threshold.
         if self.counter < self.collect_stats_steps:
             stats_input = self.stats_input_view_shape_impl(stats_input)
             stats = self.stats(stats_input)
@@ -335,6 +356,7 @@ def training_forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tens
             new_counter = self.counter + 1
             # Whenever we are in local loss mode, we don't update the counter nor the buffer
             if self.local_loss_mode:
+                # Local loss mode, we early exit and divide by threshold
                 return abs_binary_sign_grad(clamped_stats / threshold)
             if self.counter == 0:
                 inplace_tensor_mul(self.buffer, clamped_stats.detach())
@@ -346,16 +368,18 @@ def training_forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tens
         elif self.counter == self.collect_stats_steps:
             inplace_tensor_mul(self.value.detach(), self.buffer)
             value = self.restrict_preprocess(self.value / threshold)
-            # self.restrict_inplace_preprocess(self.value / threshold)
             self.counter = self.counter + 1
             return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
         else:
             value = self.restrict_preprocess(self.value / threshold)
             return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
 
     @brevitas.jit.script_method
-    def forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tensor:
+    def forward(self, stats_input: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(stats_input)
         if self.training:
+            # Threshold division handled inside the training_forward
             return self.training_forward(stats_input, threshold)
         else:
             if self.counter <= self.collect_stats_steps:
@@ -388,6 +412,9 @@ def _load_from_state_dict(
         if retrocomp_value_key in state_dict:
             state_dict[value_key] = state_dict.pop(retrocomp_value_key)
 
+        # Previously, the parameter was stored after restrict_preprocess had been applied (e.g., Log2).
+        # When retrocompatibility is enabled, we apply the opposite operation on load (e.g., Po2),
+        # knowing that the forward pass will re-apply restrict_preprocess (e.g., Log2 again).
         if config._RETROCOMPATIBLE_SCALING:
             if not isinstance(self.restrict_scaling_impl, Identity):
                 state_dict[value_key] = self.restrict_scaling_impl.retrocompatibility_op(
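
A sketch of the load-time conversion those comments describe, assuming the Log2/Po2 pairing they give as an example (illustrative tensors only; retrocompatibility_op is the actual method called in the hunks above):

import torch

stored = torch.tensor(-3.0)     # old checkpoints stored the value after restrict: log2(scale)
linear = torch.exp2(stored)     # the "opposite operation" (Po2) applied on load: 0.125
roundtrip = torch.log2(linear)  # the forward pass re-applies the restriction (Log2 again)
assert torch.allclose(roundtrip, stored)
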
