Skip to content

Commit

Permalink
HQO
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Sep 12, 2024
1 parent 10dcee3 commit 6e0c4fc
Show file tree
Hide file tree
Showing 12 changed files with 399 additions and 26 deletions.
222 changes: 222 additions & 0 deletions src/brevitas/core/stats/stats_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import brevitas
from brevitas import config
from brevitas.core.function_wrapper.misc import Identity
from brevitas.core.function_wrapper.ops_ste import ScalarClampMinSte
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops import max_int
from brevitas.quant_tensor import _unpack_quant_tensor
Expand Down Expand Up @@ -544,3 +546,223 @@ def forward(self, x):
x = self.input_view_shape_impl(x)
self.internal_candidate = self.mse_init_op(x).detach()
return self.internal_candidate


class HalfQuadraticOptimizerScale(torch.nn.Module):
# References:
# https://mobiusml.github.io/hqq_blog/
# https://github.com/mobiusml/hqq?tab=readme-ov-file

def __init__(
self,
proxy_module,
hqo_init_op_scale,
keepdim: bool,
inner_stats_input_view_shape_impl: torch.nn.Module,
scaling_min_val: Optional[float] = None,
stats_reduce_dim: Optional[int] = None,
int_scaling_impl=None,
bit_width_impl=None,
hqo_beta_scale: float = 1e5,
hqo_kappa_scale: float = 1.01,
hqo_lp_norm_scale: float = .7,
hqo_iters_scale: int = 1000):
super(HalfQuadraticOptimizerScale, self).__init__()
self.hqo_init_op = hqo_init_op_scale
self.input_view_shape_impl = inner_stats_input_view_shape_impl
self.proxy_forward = proxy_module.forward
self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
self.internal_candidate = None
self.hqo_iters = hqo_iters_scale
self.stats_reduce_dim = stats_reduce_dim
self.local_loss_mode: bool = False

self.beta = hqo_beta_scale
self.kappa = hqo_kappa_scale
self.lp_norm = hqo_lp_norm_scale

self.int_scaling_impl = int_scaling_impl
self.msb_clamp_bit_width_impl = bit_width_impl
if scaling_min_val is not None and scaling_min_val != 0:
self.clamp_min_ste = ScalarClampMinSte(scaling_min_val)
else:
self.clamp_min_ste = Identity()
self.keepdim = keepdim

def parameter_search(self, xl, x):
best_loss = torch.tensor(float('inf'), device=x.device, dtype=x.dtype)
candidate = xl
best_candidate = candidate
beta = self.beta
with torch.no_grad():
for i in range(0, self.hqo_iters):
self.internal_candidate = candidate
self.set_local_loss_mode(True)
quant_tensor = self.proxy_forward(x).detach()
self.set_local_loss_mode(False)
loss = torch.abs(quant_tensor.value - x).mean()

best_candidate = torch.where(loss < best_loss, candidate, best_candidate)
if loss >= best_loss:
break
best_loss = torch.min(loss, best_loss)
W_e = shrink_lp_op(x - quant_tensor.value, beta, self.lp_norm)
zero_point = quant_tensor.zero_point
num = self.input_view_shape_impl(x - W_e).detach()
den = self.input_view_shape_impl(
torch.round(quant_tensor.value / quant_tensor.scale) - zero_point).detach()
mask = (num != 0.) & (den != 0.)
if self.stats_reduce_dim is None:
candidate = masked_median(num / den, mask)
else:
candidate = masked_median(
num / den, mask, dim=self.stats_reduce_dim, keepdim=self.keepdim)
candidate = candidate.type_as(self.internal_candidate)
candidate = self.clamp_min_ste(candidate)
bit_width = self.msb_clamp_bit_width_impl()
int_threshold = self.int_scaling_impl(bit_width)
candidate = candidate * int_threshold
candidate[torch.isnan(candidate)] = self.internal_candidate[torch.isnan(candidate)]
candidate[torch.isinf(candidate)] = self.internal_candidate[torch.isinf(candidate)]
beta *= self.kappa
return best_candidate

def optimize(self, x):
x_view = self.input_view_shape_impl(x)

init = self.hqo_init_op(x_view).detach()
best_candidate = self.parameter_search(init, x_view)

# Save for evaluation by other modules (e.g. zp) invoking local loss mode
self.internal_candidate = best_candidate.detach()
torch.cuda.empty_cache()
return best_candidate

def forward(self, x):
if not self.local_loss_mode:
with torch.no_grad():
return self.optimize(x)
else:
# This is invoked for the zero-point whenever scale is being optimized first
if self.internal_candidate is None:
x = self.input_view_shape_impl(x)
self.internal_candidate = self.hqo_init_op(x).detach()
return self.internal_candidate


class HalfQuadraticOptimizerZeroPoint(torch.nn.Module):
# References:
# https://mobiusml.github.io/hqq_blog/
# https://github.com/mobiusml/hqq?tab=readme-ov-file

def __init__(
self,
proxy_module,
keepdim: bool,
hqo_init_op_zp: torch.nn.Module,
inner_stats_input_view_shape_impl: torch.nn.Module,
stats_reduce_dim: Optional[int] = None,
hqo_beta_zp: float = 1e0,
hqo_kappa_zp: float = 1.01,
hqo_lp_norm_zp: float = .5,
hqo_iters_zp: int = 1000):
super(HalfQuadraticOptimizerZeroPoint, self).__init__()
self.hqo_init_op_zp = hqo_init_op_zp
self.input_view_shape_impl = inner_stats_input_view_shape_impl
self.proxy_forward = proxy_module.forward
self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
self.internal_candidate = None
self.stats_reduce_dim = stats_reduce_dim
self.local_loss_mode: bool = False
self.beta = hqo_beta_zp
self.kappa = hqo_kappa_zp
self.lp_norm = hqo_lp_norm_zp
self.hqo_iters = hqo_iters_zp
self.keepdim = keepdim

def parameter_search(self, xl, x):
best_loss = torch.tensor(float('inf'), device=x.device, dtype=x.dtype)
candidate = xl
best_candidate = candidate
with torch.no_grad():
for i in range(0, self.hqo_iters):
self.internal_candidate = candidate
self.set_local_loss_mode(True)
quant_tensor = self.proxy_forward(x).detach()
self.set_local_loss_mode(False)
qt_value = self.input_view_shape_impl(quant_tensor.value)
qt_scale = self.input_view_shape_impl(quant_tensor.scale)
qt_int = self.input_view_shape_impl(quant_tensor.int())
loss = torch.abs(qt_value - x).mean()
best_candidate = torch.where(loss < best_loss, candidate, best_candidate)
if loss >= best_loss:
break
best_loss = torch.min(loss, best_loss)
W_e = shrink_lp_op(x - qt_value, self.beta, self.lp_norm)

val = self.input_view_shape_impl((x - W_e) - qt_int * qt_scale)

if self.stats_reduce_dim is None:
candidate = torch.mean(val)
else:
candidate = torch.mean(val, dim=self.stats_reduce_dim, keepdim=self.keepdim)
self.beta *= self.kappa
return best_candidate

def optimize(self, x):
x_view = self.input_view_shape_impl(x)

init = self.hqo_init_op_zp(x_view).detach()

best_candidate = self.parameter_search(init, x)

# Save for evaluation by other modules (e.g. zp) invoking local loss mode
self.internal_candidate = best_candidate.detach()
torch.cuda.empty_cache()
return best_candidate

def forward(self, x):
if not self.local_loss_mode:
with torch.no_grad():
return self.optimize(x)
else:
# This is invoked for the zero-point whenever scale is being optimized first
if self.internal_candidate is None:
x = self.input_view_shape_impl(x)
self.internal_candidate = self.hqo_init_op_zp(x).detach()
return self.internal_candidate


def masked_median(x, mask, dim=None, keepdim=False):
"""Compute the median of tensor x along dim, ignoring values where mask is False.
x and mask need to be broadcastable.
Args:
x (Tensor): Tensor to compute median of.
mask (BoolTensor): Same shape as x with True where x is valid and False
where x should be masked. Mask should not be all False in any column of
dimension dim to avoid NaNs from zero division.
dim (int, optional): Dimension to take median of. Defaults to 0.
Returns:
Tensor: Same shape as x, except dimension dim reduced.
"""
# uncomment this assert for safety but might impact performance
# assert (
# mask.sum(dim=dim).ne(0).all()
# ), "mask should not be all False in any column, causes zero division"
x_nan = x.float().masked_fill(~mask, float("nan"))
if dim is None:
x_median = x_nan.nanmedian()
else:
x_median, _ = x_nan.nanmedian(dim=dim, keepdim=keepdim)
return x_median


# Shrinking operator
def shrink_lp_op(x: Tensor, beta: float, lp_norm: float) -> Tensor:
if lp_norm == 1:
return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta)
else:
return torch.sign(x) * torch.nn.functional.relu(
torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x), lp_norm - 1))
2 changes: 1 addition & 1 deletion src/brevitas/core/zero_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
output_dict = super(ParameterFromStatsFromParameterZeroPoint, self).state_dict(
destination=destination, prefix=prefix, keep_vars=keep_vars)
# Avoid saving the init value
if not self.init_done:
if not self.init_done and not config._FULL_STATE_DICT:
del output_dict[prefix + 'value']
return output_dict

Expand Down
43 changes: 41 additions & 2 deletions src/brevitas/quant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
from brevitas.core.stats import MSE
from brevitas.core.stats import NegativeMinOrZero
from brevitas.core.stats import NegativePercentileOrZero
from brevitas.core.stats.stats_op import HalfQuadraticOptimizerScale
from brevitas.core.stats.stats_op import HalfQuadraticOptimizerZeroPoint
from brevitas.core.utils import SingleArgStatelessBuffer
from brevitas.core.zero_point import ParameterFromRuntimeZeroPoint
from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
Expand Down Expand Up @@ -458,7 +460,7 @@ class MSEAsymmetricScaleSubInjector(MSESubInjectorBase):
stats_impl = MSE
stats_reduce_dim = (this << 1).stats_reduce_dim
device = (this << 1).device
type = (this << 1).type
dtype = (this << 1).dtype


class MSEZeroPointSubInjector(MSESubInjectorBase):
Expand All @@ -470,7 +472,7 @@ class MSEZeroPointSubInjector(MSESubInjectorBase):
stats_impl = MSE
stats_reduce_dim = (this << 1).stats_reduce_dim
device = (this << 1).device
type = (this << 1).type
dtype = (this << 1).dtype


class MSEAsymmetricScale(ExtendedInjector):
Expand Down Expand Up @@ -520,3 +522,40 @@ class MSEWeightZeroPoint(MSEZeroPoint):

class MSEActZeroPoint(MSEZeroPoint):
zero_point_impl = ParameterFromRuntimeZeroPoint


class HQOZeroPoint(ExtendedInjector):

hqo_init_op_zp = NegativeMinOrZero
inner_stats_input_view_shape_impl = this.zero_point_stats_input_view_shape_impl
stats_impl_zp = HalfQuadraticOptimizerZeroPoint

@value
def zero_point_stats_impl():
return this.stats_impl_zp


class HQOScale(ExtendedInjector):
scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS
inner_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
stats_impl_scale = HalfQuadraticOptimizerScale

@value
def scaling_stats_impl():
return this.stats_impl_scale


class HQOAsymmetricScale(HQOScale):
hqo_init_op_scale = AbsMinMax


class HQOSymmetricScale(HQOScale):
hqo_init_op_scale = AbsMax


class HQOActZeroPoint(HQOZeroPoint):
zero_point_impl = ParameterFromRuntimeZeroPoint


class HQOWeightZeroPoint(HQOZeroPoint):
zero_point_impl = ParameterFromStatsFromParameterZeroPoint
25 changes: 25 additions & 0 deletions src/brevitas/quant/scaled_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from brevitas.core.function_wrapper import TensorClamp
from brevitas.quant.base import *
from brevitas.quant.base import HQOSymmetricScale
from brevitas.quant.solver.act import ActQuantSolver
from brevitas.quant.solver.bias import BiasQuantSolver
from brevitas.quant.solver.trunc import TruncQuantSolver
Expand Down Expand Up @@ -443,3 +444,27 @@ class Int8AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareZeroCenterWeight
>>> conv.quant_weight()
"""
bit_width = 8


class Int8WeightPerTensorFloatHQO(HQOSymmetricScale, Int8WeightPerTensorFloat):
"""
8-bit narrow per-tensor signed int weight quantizer with per-tensor floating-point scale factor computed
from HQO local loss.
Examples:
>>> from brevitas.nn import QuantLinear
>>> fc = QuantLinear(10, 5, bias=False, weight_quant=Int8WeightPerTensorFloatHQO)
"""
pass


class Int8WeightPerChannelFloatHQO(HQOSymmetricScale, Int8WeightPerChannelFloat):
"""
8-bit narrow per-tensor signed int weight quantizer with per-tensor floating-point scale factor computed
from HQO local loss.
Examples:
>>> from brevitas.nn import QuantLinear
>>> fc = QuantLinear(10, 5, bias=False, weight_quant=Int8WeightPerChannelFloatHQO)
"""
pass
Loading

0 comments on commit 6e0c4fc

Please sign in to comment.