[DRAFT] Feat (stats): extending sigma quant #1202

Draft · wants to merge 1 commit into base: dev
2 changes: 2 additions & 0 deletions src/brevitas/core/stats/__init__.py
@@ -18,7 +18,9 @@
from .stats_op import MSE
from .stats_op import NegativeMinOrZero
from .stats_op import NegativePercentileOrZero
from .stats_op import NegativeSigmaOrZero
from .stats_op import PercentileInterval
from .stats_op import SigmaStdInterval
from .stats_wrapper import _ParameterListStats
from .stats_wrapper import _RuntimeStats
from .stats_wrapper import _Stats
69 changes: 69 additions & 0 deletions src/brevitas/core/stats/stats_op.py
@@ -115,6 +115,75 @@ def forward(self, x: Tensor) -> Tensor:
return result


class SigmaStdInterval(brevitas.jit.ScriptModule):
    __constants__ = ['stats_reduce_dim', 'keepdim']

    def __init__(
            self,
            sigma: float,
            stats_reduce_dim: Optional[int] = None,
            std_dev_epsilon: float = DEFAULT_STD_DEV_EPSILON,
            dtype: Optional[torch.dtype] = None,
            device: Optional[torch.device] = None,
            keepdim: bool = False) -> None:
        super(SigmaStdInterval, self).__init__()
        self.stats_reduce_dim = stats_reduce_dim
        self.epsilon = std_dev_epsilon
        self.sigma = StatelessBuffer(torch.tensor(sigma, dtype=dtype, device=device))
        self.zero = StatelessBuffer(torch.tensor(0.0, dtype=dtype, device=device))
        self.keepdim = keepdim

    @brevitas.jit.script_method
    def forward(self, x: Tensor) -> Tensor:
        sigma = self.sigma()
        if self.stats_reduce_dim is None:
            max_val = torch.max(x)
            min_val = torch.min(x)
            std_val = torch.sqrt(torch.var(x) + self.epsilon)
        else:
            max_val = torch.max(x, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0]
            min_val = torch.min(x, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0]
            std_val = torch.sqrt(torch.var(x, dim=self.stats_reduce_dim) + self.epsilon)
            std_val = std_val.view(-1)
        # Positive minima are clamped to zero so the interval always covers zero
        min_val = torch.clamp(min_val, max=self.zero())
        max_range = torch.abs(max_val - min_val)
        std_range = 2. * sigma * std_val
        # Use the 2*sigma*std interval unless the observed min/max range is tighter
        val_range = torch.where(std_range < max_range, std_range, max_range)
        return val_range


class NegativeSigmaOrZero(brevitas.jit.ScriptModule):
    __constants__ = ['stats_reduce_dim', 'keepdim']

    def __init__(
            self,
            sigma: float,
            std_dev_epsilon: float = DEFAULT_STD_DEV_EPSILON,
            stats_reduce_dim: Optional[int] = None,
            dtype: Optional[torch.dtype] = None,
            device: Optional[torch.device] = None,
            keepdim: bool = False) -> None:
        super(NegativeSigmaOrZero, self).__init__()
        self.keepdim = keepdim
        self.stats_reduce_dim = stats_reduce_dim
        self.epsilon = std_dev_epsilon
        self.sigma = StatelessBuffer(torch.tensor(sigma, dtype=dtype, device=device))
        self.zero = StatelessBuffer(torch.tensor(0.0, dtype=dtype, device=device))

    @brevitas.jit.script_method
    def forward(self, x: Tensor) -> Tensor:
        sigma = self.sigma()
        if self.stats_reduce_dim is None:
            min_val = torch.min(x)
            std_val = torch.sqrt(torch.var(x) + self.epsilon)
        else:
            min_val = torch.min(x, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0]
            std_val = torch.sqrt(torch.var(x, dim=self.stats_reduce_dim) + self.epsilon)
            std_val = std_val.view(-1)
        # Take the looser of -sigma*std and the observed minimum, then cap at zero
        val = torch.where(-sigma * std_val > min_val, -sigma * std_val, min_val)
        val = torch.clamp(val, max=self.zero())
        return val


class PercentileInterval(brevitas.jit.ScriptModule):
__constants__ = ['stats_reduce_dim', 'low_q', 'high_q', 'keepdim']

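For intuition, the per-tensor branches of the two new statistics reduce to a few plain PyTorch ops. The sketch below mirrors them outside the ScriptModule classes; the variable names, the epsilon value and the sample data are illustrative and not part of this diff:

import torch

sigma, eps = 3.0, 1e-8                         # illustrative values
x = torch.randn(256) * 0.5 + 0.1               # illustrative sample

std_val = torch.sqrt(torch.var(x) + eps)
min_val = torch.clamp(torch.min(x), max=0.0)   # positive minima are treated as zero
max_range = torch.abs(torch.max(x) - min_val)

# SigmaStdInterval: the range is the tighter of the 2*sigma*std interval
# and the observed (max - clamped min) range.
val_range = torch.minimum(2. * sigma * std_val, max_range)

# NegativeSigmaOrZero: the zero-point statistic is the looser of -sigma*std
# and the observed minimum, capped at zero.
zp_val = torch.clamp(torch.maximum(-sigma * std_val, torch.min(x)), max=0.0)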
14 changes: 14 additions & 0 deletions src/brevitas/quant/shifted_scaled_int.py
@@ -1,6 +1,10 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from dependencies import value

from brevitas.core.stats import NegativeSigmaOrZero
from brevitas.core.stats import SigmaStdInterval
from brevitas.inject.enum import ScalingPerOutputType
from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector
from brevitas.quant.base import *
@@ -201,3 +205,13 @@ class ShiftedUint8WeightGroupQuantFloat(ShiftedUint8WeightPerChannelFloat):
"""
proxy_class = GroupwiseWeightQuantProxyFromInjector
scaling_per_output_type = ScalingPerOutputType.GROUP


class GaussianUint8WeightPerChannelFloat(ShiftedUint8WeightPerChannelFloat):
    """
    8-bit per-channel asymmetric weight quantizer whose scale and zero-point statistics
    assume a roughly Gaussian weight distribution: the range is the tighter of the
    2*sigma*std interval and the observed min/max range, with sigma = (1 + bit_width) / 2.
    """

    @value
    def sigma(bit_width):
        return (1. + bit_width) / 2.

    scaling_stats_impl = SigmaStdInterval
    zero_point_stats_impl = NegativeSigmaOrZero
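
A minimal usage sketch for the new quantizer, assuming the standard brevitas.nn layer API; the layer type and feature sizes are illustrative only:

from brevitas.nn import QuantLinear
from brevitas.quant.shifted_scaled_int import GaussianUint8WeightPerChannelFloat

# Illustrative: attach the Gaussian-statistics quantizer to a layer. At the default
# 8-bit width, the @value rule above yields sigma = (1. + 8.) / 2. = 4.5.
layer = QuantLinear(64, 32, bias=False, weight_quant=GaussianUint8WeightPerChannelFloat)
qw = layer.quant_weight()  # per-channel asymmetric uint8 quantized weights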
2 changes: 2 additions & 0 deletions src/brevitas/quant/solver/common.py
@@ -110,6 +110,8 @@ def scaling_stats_impl(scaling_stats_op):
return AbsMinMax
elif scaling_stats_op == StatsOp.PERCENTILE_INTERVAL:
return PercentileInterval
elif scaling_stats_op == StatsOp.SIGMA_STD_INTERVAL:
return SigmaStdInterval
else:
raise RuntimeError(f"{scaling_stats_op} not recognized.")

10 changes: 7 additions & 3 deletions src/brevitas_examples/common/generative/quantize.py
@@ -48,6 +48,7 @@
from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
from brevitas.quant.scaled_int import Int8WeightPerTensorFloatHQO
from brevitas.quant.scaled_int import Int8WeightPerTensorFloatMSE
from brevitas.quant.shifted_scaled_int import GaussianUint8WeightPerChannelFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatMSE
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightGroupQuantFloat
@@ -83,7 +84,9 @@
'per_tensor': {
'sym': Int8WeightPerTensorFloat, 'asym': ShiftedUint8WeightPerTensorFloat},
'per_channel': {
'sym': Int8WeightPerChannelFloat, 'asym': ShiftedUint8WeightPerChannelFloat},
'sym': Int8WeightPerChannelFloat,
'asym': ShiftedUint8WeightPerChannelFloat,
'gauss': GaussianUint8WeightPerChannelFloat},
'per_group': {
'sym': IntWeightSymmetricGroupQuant,
'asym': ShiftedUint8WeightGroupQuantFloat}},
@@ -255,7 +258,8 @@ def generate_quantizers(
weight_kwargs=None,
input_kwargs=None,
quant_attn_mode=None,
scaling_min_val=1e-4):
scaling_min_val=1e-4,
weight_narrow_range=False):
"""
Replace float layers with quant layers in the target model
"""
@@ -340,7 +344,7 @@ def generate_quantizers(
weight_quant = weight_quant.let(
**{
'bit_width': weight_bit_width,
'narrow_range': False,
'narrow_range': weight_narrow_range,
'quantize_zero_point': quantize_weight_zero_point},
**weight_float_format)

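For reference, the new weight_narrow_range keyword defaults to the previously hard-coded False. Assuming Brevitas' usual convention for unsigned quantizers, enabling it drops one representable level from the top of the integer grid; a small illustrative sketch of the resulting bounds:

def uint_grid_bounds(bit_width: int, narrow_range: bool):
    # Assumption: narrow_range removes one level from the top of the unsigned grid.
    max_int = 2 ** bit_width - (2 if narrow_range else 1)
    return 0, max_int

print(uint_grid_bounds(8, False))  # (0, 255): previous hard-coded behaviour
print(uint_grid_bounds(8, True))   # (0, 254): opt-in via weight_narrow_range=True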