Fix (core): deprecate bitwidth-less bias

Giuseppe5 committed Feb 9, 2024
1 parent 30238b0 commit 63b9945

Showing 15 changed files with 38 additions and 76 deletions.
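In short: bias quantizers no longer receive the accumulator bit width as a forward input. `IntBias`-style quantizers that derived their bit width from the input are deprecated in favor of fixed-width variants such as `Int32Bias`, which still rescale against the input scale but declare their own bit width. A minimal usage sketch of the updated API (shapes and names are illustrative, mirroring the updated tests below):

```python
import torch

from brevitas.nn import QuantLinear
from brevitas.quant import Int32Bias
from brevitas.quant_tensor import QuantTensor

# Int32Bias declares bit_width = 32 itself; only the accumulator scale is
# still taken from the input, so the layer must be fed a QuantTensor.
fc = QuantLinear(10, 5, bias=True, bias_quant=Int32Bias)
x = QuantTensor(
    torch.rand(3, 10),  # value
    torch.tensor(1.0),  # scale
    torch.tensor(0.0),  # zero-point
    torch.tensor(8.0),  # input bit width
    signed=True)
y = fc(x)
```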
6 changes: 3 additions & 3 deletions src/brevitas/nn/mixin/parameter.py
@@ -183,7 +183,7 @@ def quant_bias(self):
     def quant_bias_scale(self):
         if self.bias is None or not self.is_bias_quant_enabled:
             return None
-        if not self.bias_quant.requires_input_scale and not self.bias_quant.requires_input_bit_width:
+        if not self.bias_quant.requires_input_scale:
             return self.bias_quant(self.bias).scale
         else:
             if self._cached_bias is None:
@@ -197,7 +197,7 @@ def quant_bias_scale(self):
     def quant_bias_zero_point(self):
         if self.bias is None:
             return None
-        if not self.bias_quant.requires_input_scale and not self.bias_quant.requires_input_bit_width:
+        if not self.bias_quant.requires_input_scale:
             return self.bias_quant(self.bias).zero_point
         else:
             if self._cached_bias is None:
@@ -211,7 +211,7 @@ def quant_bias_zero_point(self):
     def quant_bias_bit_width(self):
         if self.bias is None or not self.is_bias_quant_enabled:
             return None
-        if not self.bias_quant.requires_input_scale and not self.bias_quant.requires_input_bit_width:
+        if not self.bias_quant.requires_input_scale:
             return self.bias_quant(self.bias).bit_width
         else:
             if self._cached_bias is None:
5 changes: 2 additions & 3 deletions src/brevitas/nn/quant_layer.py
@@ -311,8 +311,7 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
         quant_weight = self.quant_weight(quant_input)
 
         if (self.return_quant_tensor or
-                (self.is_bias_quant_enabled and
-                 (self.bias_quant.requires_input_scale or self.bias_quant.requires_input_bit_width))):
+                (self.is_bias_quant_enabled and self.bias_quant.requires_input_scale)):
             if quant_input.bit_width is not None and quant_weight.bit_width is not None:
                 output_bit_width = self.max_acc_bit_width(
                     quant_input.bit_width, quant_weight.bit_width)
@@ -323,7 +322,7 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
             output_signed = inp.signed or quant_weight.signed
 
         if self.bias is not None:
-            quant_bias = self.bias_quant(self.bias, output_scale, output_bit_width)
+            quant_bias = self.bias_quant(self.bias, output_scale)
             if not self.training and self.cache_inference_quant_bias:
                 self._cached_bias = _CachedIO(quant_bias.detach(), metadata_only=False)
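The bias call site above now needs only the accumulator scale; the bit width comes solely from the bias quantizer itself. A small sketch of where that scale comes from under the usual affine quantization assumption (the helper name is hypothetical; the commit only changes the call site):

```python
import torch

def accumulator_scale(input_scale: torch.Tensor, weight_scale: torch.Tensor) -> torch.Tensor:
    """For y = (s_x * x_int) @ (s_w * w_int) + b, the accumulator (and hence
    bias) scale is s_x * s_w; this is the output_scale handed to bias_quant."""
    return input_scale * weight_scale

# e.g. a per-tensor input scale combined with per-channel weight scales:
print(accumulator_scale(torch.tensor(0.1), torch.tensor([0.02, 0.04])))
```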
25 changes: 3 additions & 22 deletions src/brevitas/proxy/parameter_quant.py
@@ -33,7 +33,6 @@ def forward(self, x: torch.Tensor) -> QuantTensor:
 
 @runtime_checkable
 class BiasQuantProxyProtocol(QuantProxyProtocol, Protocol):
-    requires_input_bit_width: bool
     requires_input_scale: bool
 
     def forward(
@@ -162,13 +161,6 @@ class BiasQuantProxyFromInjector(ParameterQuantProxyFromInjector, BiasQuantProxyProtocol):
     def tracked_parameter_list(self):
         return [m.bias for m in self.tracked_module_list if m.bias is not None]
 
-    @property
-    def requires_input_bit_width(self) -> bool:
-        if self.is_quant_enabled:
-            return self.quant_injector.requires_input_bit_width
-        else:
-            return False
-
     @property
     def requires_input_scale(self) -> bool:
         if self.is_quant_enabled:
@@ -189,30 +181,19 @@ def zero_point(self):
         return zero_point
 
     def bit_width(self):
-        if self.requires_input_bit_width:
-            return None
         zhs = self._zero_hw_sentinel()
         bit_width = self.__call__(self.tracked_parameter_list[0], zhs, zhs).bit_width
         return bit_width
 
-    def forward(
-            self,
-            x: Tensor,
-            input_scale: Optional[Tensor] = None,
-            input_bit_width: Optional[Tensor] = None) -> QuantTensor:
+    def forward(self, x: Tensor, input_scale: Optional[Tensor] = None) -> QuantTensor:
         if self.is_quant_enabled:
             impl = self.export_handler if self.export_mode else self.tensor_quant
             if self.requires_input_scale and input_scale is None:
                 raise RuntimeError("Input scale required")
-            if self.requires_input_bit_width and input_bit_width is None:
-                raise RuntimeError("Input bit-width required")
-            if self.requires_input_scale and self.requires_input_bit_width:
-                input_scale = input_scale.view(-1)
-                out, out_scale, out_zp, out_bit_width = impl(x, input_scale, input_bit_width)
-            elif self.requires_input_scale and not self.requires_input_bit_width:
+            if self.requires_input_scale:
                 input_scale = input_scale.view(-1)
                 out, out_scale, out_zp, out_bit_width = impl(x, input_scale)
-            elif not self.requires_input_scale and not self.requires_input_bit_width:
+            elif not self.requires_input_scale:
                 out, out_scale, out_zp, out_bit_width = impl(x)
             else:
                 raise RuntimeError("Internally defined bit-width required")
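With the input bit width gone, the proxy exposes at most a two-argument call. For quantizers that scale internally it can be invoked with the bias alone; a minimal sketch, assuming quantization is enabled (layer and variable names are illustrative):

```python
import torch

from brevitas.nn import QuantLinear
from brevitas.quant import Int8BiasPerTensorFloatInternalScaling

# requires_input_scale is False for this quantizer, so the bias proxy
# resolves both the scale and the 8-bit width on its own.
fc = QuantLinear(10, 5, bias=True, bias_quant=Int8BiasPerTensorFloatInternalScaling)
quant_bias = fc.bias_quant(fc.bias)
print(quant_bias.scale, quant_bias.bit_width)
```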
1 change: 0 additions & 1 deletion src/brevitas/quant/fixed_point.py
@@ -178,7 +178,6 @@ class Int8BiasPerTensorFixedPointInternalScaling(IntQuant,
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int8BiasPerTensorFixedPointInternalScaling)
     """
     requires_input_scale = False
-    requires_input_bit_width = False
 
 
 class Int4WeightPerTensorFixedPointDecoupled(WeightPerTensorFloatDecoupledL2Param):
1 change: 0 additions & 1 deletion src/brevitas/quant/none.py
@@ -31,7 +31,6 @@ class NoneBiasQuant(BiasQuantSolver):
     """
     quant_type = QuantType.FP
     requires_input_scale = False
-    requires_input_bit_width = False
 
 
 class NoneTruncQuant(TruncQuantSolver):
6 changes: 0 additions & 6 deletions src/brevitas/quant/scaled_int.py
@@ -85,7 +85,6 @@ class IntBias(IntQuant, BiasQuantSolver):
     """
     tensor_clamp_impl = TensorClamp
     requires_input_scale = True
-    requires_input_bit_width = True
 
 
 class Int8Bias(IntBias):
@@ -98,7 +97,6 @@ class Int8Bias(IntBias):
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int8Bias)
     """
     bit_width = 8
-    requires_input_bit_width = False
 
 
 class Int16Bias(IntBias):
@@ -111,7 +109,6 @@ class Int16Bias(IntBias):
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int16Bias)
     """
     bit_width = 16
-    requires_input_bit_width = False
 
 
 class Int24Bias(IntBias):
@@ -124,7 +121,6 @@ class Int24Bias(IntBias):
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int16Bias)
     """
     bit_width = 24
-    requires_input_bit_width = False
 
 
 class Int32Bias(IntBias):
@@ -137,7 +133,6 @@ class Int32Bias(IntBias):
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int16Bias)
     """
     bit_width = 32
-    requires_input_bit_width = False
 
 
 class Int8BiasPerTensorFloatInternalScaling(IntQuant,
@@ -153,7 +148,6 @@ class Int8BiasPerTensorFloatInternalScaling(IntQuant,
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int8BiasPerTensorFloatInternalScaling)
     """
     requires_input_scale = False
-    requires_input_bit_width = False
 
 
 class Int8WeightPerTensorFloat(NarrowIntQuant,
22 changes: 4 additions & 18 deletions src/brevitas/quant/solver/bias.py
@@ -1,9 +1,7 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
-from brevitas.core.function_wrapper import Identity
 from brevitas.core.quant import PrescaledRestrictIntQuant
-from brevitas.core.quant import PrescaledRestrictIntQuantWithInputBitWidth
 from brevitas.core.quant import RescalingIntQuant
 from brevitas.inject import ExtendedInjector
 from brevitas.inject import value
@@ -34,29 +32,17 @@ def scaling_per_output_channel_shape(module):
     return (module.out_channels,)
 
 
-class SolveBiasBitWidthImplFromEnum(ExtendedInjector):
-
-    @value
-    def bit_width_impl(bit_width_impl_type, requires_input_bit_width):
-        if not requires_input_bit_width:
-            return solve_bit_width_impl_from_enum(bit_width_impl_type)
-        else:
-            return Identity
-
-
 class SolveBiasTensorQuantFromEnum(SolveIntQuantFromEnum):
 
     @value
-    def tensor_quant(quant_type, requires_input_bit_width, requires_input_scale):
+    def tensor_quant(quant_type, requires_input_scale):
         if quant_type == QuantType.FP:
             return None
         elif quant_type == QuantType.INT:
-            if not requires_input_bit_width and requires_input_scale:
+            if requires_input_scale:
                 return PrescaledRestrictIntQuant
-            elif not requires_input_bit_width and not requires_input_scale:
+            elif not requires_input_scale:
                 return RescalingIntQuant
-            else:  # requires_input_bit_width == True
-                return PrescaledRestrictIntQuantWithInputBitWidth
         elif quant_type == QuantType.TERNARY:
             raise RuntimeError(f'{quant_type} not supported.')
         elif quant_type == QuantType.BINARY:
@@ -75,7 +61,7 @@ class BiasQuantSolver(SolveScalingStatsInputViewShapeImplFromEnum,
                       SolveParameterScalingImplFromEnum,
                       SolveParameterTensorClampImplFromEnum,
                       SolveParameterScalingInitFromEnum,
-                      SolveBiasBitWidthImplFromEnum,
+                      SolveBitWidthImplFromEnum,
                       SolveBiasScalingPerOutputChannelShapeFromModule,
                       SolveBiasScalingStatsInputConcatDimFromModule,
                       SolveBiasTensorQuantFromEnum,
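With the input-bit-width branch gone, the solver picks between just two INT implementations. A standalone mirror of that dispatch, under a hypothetical helper name, for reference:

```python
from brevitas.core.quant import PrescaledRestrictIntQuant
from brevitas.core.quant import RescalingIntQuant


def solve_int_bias_tensor_quant(requires_input_scale: bool):
    """Hypothetical standalone mirror of SolveBiasTensorQuantFromEnum
    for quant_type == INT."""
    if requires_input_scale:
        # Scale prescribed externally, e.g. input scale * weight scale.
        return PrescaledRestrictIntQuant
    # Scale computed internally from the bias itself.
    return RescalingIntQuant
```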
2 changes: 1 addition & 1 deletion src/brevitas_examples/bnn_pynq/README.md
@@ -20,7 +20,7 @@ These pretrained models and training scripts are courtesy of
 | CNV_1W1A | 8 bit | 1 bit | 1 bit | CIFAR10 | 84.22% |
 | CNV_1W2A | 8 bit | 1 bit | 2 bit | CIFAR10 | 87.80% |
 | CNV_2W2A | 8 bit | 2 bit | 2 bit | CIFAR10 | 89.03% |
-| RESNET18_4W4A | 8 bit (assumed) | 4 bit | 4 bit | CIFAR10 | 92.60% |
+| RESNET18_4W4A | 8 bit (assumed) | 4 bit | 4 bit | CIFAR10 | 92.61% |
 
 ## Train
4 changes: 2 additions & 2 deletions src/brevitas_examples/bnn_pynq/models/resnet.py
@@ -9,7 +9,7 @@
 import brevitas.nn as qnn
 from brevitas.quant import Int8WeightPerChannelFloat
 from brevitas.quant import Int8WeightPerTensorFloat
-from brevitas.quant import IntBias
+from brevitas.quant import Int32Bias
 from brevitas.quant import TruncTo8bit
 from brevitas.quant_tensor import QuantTensor
 
@@ -121,7 +121,7 @@ def __init__(
             act_bit_width=8,
             weight_bit_width=8,
             round_average_pool=False,
-            last_layer_bias_quant=IntBias,
+            last_layer_bias_quant=Int32Bias,
             weight_quant=Int8WeightPerChannelFloat,
             first_layer_weight_quant=Int8WeightPerChannelFloat,
             last_layer_weight_quant=Int8WeightPerTensorFloat):
4 changes: 2 additions & 2 deletions (file name not captured in this view)
@@ -34,7 +34,7 @@
 from brevitas.nn import QuantLinear
 from brevitas.nn import QuantReLU
 from brevitas.nn import TruncAvgPool2d
-from brevitas.quant import IntBias
+from brevitas.quant import Int32Bias
 
 from .common import CommonIntWeightPerChannelQuant
 from .common import CommonIntWeightPerTensorQuant
@@ -181,7 +181,7 @@ def __init__(
             in_channels,
             num_classes,
             bias=True,
-            bias_quant=IntBias,
+            bias_quant=Int32Bias,
             weight_quant=last_layer_weight_quant,
             weight_bit_width=last_layer_bit_width)
4 changes: 2 additions & 2 deletions (file name not captured in this view)
@@ -30,7 +30,7 @@
 from brevitas.nn import QuantLinear
 from brevitas.nn import QuantReLU
 from brevitas.nn import TruncAvgPool2d
-from brevitas.quant import IntBias
+from brevitas.quant import Int32Bias
 
 from .common import *
 
@@ -300,7 +300,7 @@ def __init__(
             in_features=in_channels,
             out_features=num_classes,
             bias=True,
-            bias_quant=IntBias,
+            bias_quant=Int32Bias,
             weight_bit_width=bit_width,
             weight_quant=CommonIntWeightPerTensorQuant)
12 changes: 6 additions & 6 deletions src/brevitas_examples/imagenet_classification/qat/README.md
@@ -5,12 +5,12 @@ and by no means a direct mapping to hardware should be assumed.
 
 Below in the table is a list of example pretrained models made available for reference.
 
-| Name | Cfg | Scaling Type | First layer weights | Weights | Activations | Avg pool | Top1 | Top5 | Pretrained model | Retrained from |
-|--------------|-----------------------|----------------------------|---------------------|---------|-------------|----------|-------|-------|-------------------------------------------------------------------------------------------------|---------------------------------------------------------------|
-| MobileNet V1 | quant_mobilenet_v1_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 71.14 | 90.10 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_mobilenet_v1_4b-r1/quant_mobilenet_v1_4b-0100a667.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
-| ProxylessNAS Mobile14 w/ Hadamard classifier | quant_proxylessnas_mobile14_hadamard_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 73.52 | 91.46 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_hadamard_4b-r0/quant_proxylessnas_mobile14_hadamard_4b-4acbfa9f.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
-| ProxylessNAS Mobile14 | quant_proxylessnas_mobile14_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 74.42 | 92.04 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_4b-r0/quant_proxylessnas_mobile14_4b-e10882e1.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
-| ProxylessNAS Mobile14 | quant_proxylessnas_mobile14_4b5b | Floating-point per channel | 8 bit | 4 bit, 5 bit | 4 bit, 5 bit | 4 bit | 75.01 | 92.33 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_4b5b-r0/quant_proxylessnas_mobile14_4b5b-2bdf7f8d.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
+| Name | Cfg | Scaling Type | First layer weights | Weights | Activations | Avg pool | Top1 | Pretrained model | Retrained from |
+|--------------|-----------------------|----------------------------|---------------------|---------|-------------|----------|-------|-------------------------------------------------------------------------------------------------|---------------------------------------------------------------|
+| MobileNet V1 | quant_mobilenet_v1_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 70.95 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_mobilenet_v1_4b-r1/quant_mobilenet_v1_4b-0100a667.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
+| ProxylessNAS Mobile14 w/ Hadamard classifier | quant_proxylessnas_mobile14_hadamard_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 72.87 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_hadamard_4b-r0/quant_proxylessnas_mobile14_hadamard_4b-4acbfa9f.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
+| ProxylessNAS Mobile14 | quant_proxylessnas_mobile14_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 74.39 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_4b-r0/quant_proxylessnas_mobile14_4b-e10882e1.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
+| ProxylessNAS Mobile14 | quant_proxylessnas_mobile14_4b5b | Floating-point per channel | 8 bit | 4 bit, 5 bit | 4 bit, 5 bit | 4 bit | 74.94 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_4b5b-r0/quant_proxylessnas_mobile14_4b5b-2bdf7f8d.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
 
 To evaluate a pretrained quantized model on ImageNet:
4 changes: 4 additions & 0 deletions src/brevitas_examples/imagenet_classification/utils.py
@@ -5,6 +5,8 @@
 import torchvision.transforms as transforms
 from tqdm import tqdm
 
+from brevitas.quant_tensor import QuantTensor
+
 SEED = 123456
 
 MEAN = [0.485, 0.456, 0.406]
@@ -81,6 +83,8 @@ def print_accuracy(top1, prefix=''):
             images = images.to(dtype)
 
             output = model(images)
+            if isinstance(output, QuantTensor):
+                output = output.value
             # measure accuracy
             acc1, = accuracy(output, target, stable=stable)
             top1.update(acc1[0], images.size(0))
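A model whose classifier returns a QuantTensor (e.g. one built with `return_quant_tensor=True`) now has its output unwrapped before the accuracy computation. The same guard as a small standalone helper (the helper name is illustrative):

```python
from brevitas.quant_tensor import QuantTensor


def unwrap_output(output):
    # QuantTensor.value holds the dequantized torch.Tensor that
    # metrics such as top-1 accuracy expect.
    return output.value if isinstance(output, QuantTensor) else output
```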
5 changes: 0 additions & 5 deletions tests/brevitas/nn/test_hadamard.py
@@ -1,9 +1,4 @@
-import torch
-from torch.nn import Module
-
 from brevitas.nn import HadamardClassifier
-from brevitas.quant import IntBias
-from brevitas.quant_tensor import QuantTensor
 
 OUTPUT_FEATURES = 10
 INPUT_FEATURES = 5
13 changes: 9 additions & 4 deletions tests/brevitas/nn/test_linear.py
@@ -2,10 +2,9 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 import torch
-from torch.nn import Module
 
 from brevitas.nn import QuantLinear
-from brevitas.quant import IntBias
+from brevitas.quant import Int32Bias
 from brevitas.quant_tensor import QuantTensor
 
 OUTPUT_FEATURES = 10
@@ -25,7 +24,10 @@ def test_module_init_bias_fp(self):
 
     def test_module_init_bias_int(self):
         mod = QuantLinear(
-            out_features=OUTPUT_FEATURES, in_features=INPUT_FEATURES, bias=True, bias_quant=IntBias)
+            out_features=OUTPUT_FEATURES,
+            in_features=INPUT_FEATURES,
+            bias=True,
+            bias_quant=Int32Bias)
         assert mod
 
     def test_module_init_scale_impl_type_override(self):
@@ -51,7 +53,10 @@ def test_forward_bias_fp(self):
 
     def test_forward_bias_int(self):
         mod = QuantLinear(
-            out_features=OUTPUT_FEATURES, in_features=INPUT_FEATURES, bias=True, bias_quant=IntBias)
+            out_features=OUTPUT_FEATURES,
+            in_features=INPUT_FEATURES,
+            bias=True,
+            bias_quant=Int32Bias)
         x = QuantTensor(
             torch.rand(size=(3, INPUT_FEATURES)),
             torch.tensor(1.0),
