Fix (core): deprecate bitwidth-less bias #839

Merged: 6 commits, Feb 28, 2024
5 changes: 1 addition & 4 deletions src/brevitas/nn/mixin/base.py
```diff
@@ -242,16 +242,13 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
     @staticmethod
     def gate_params_fwd(gate, quant_input):
         acc_scale = None
-        acc_bit_width = None
         quant_weight_ih = gate.input_weight()
         quant_weight_hh = gate.hidden_weight()
-        if isinstance(quant_input, QuantTensor):
-            acc_bit_width = None  # TODO
+        if isinstance(quant_input, QuantTensor) and isinstance(quant_weight_ih, QuantTensor):
             acc_scale_shape = compute_channel_view_shape(quant_input.value, channel_dim=1)
             acc_scale = quant_weight_ih.scale.view(acc_scale_shape)
             acc_scale = acc_scale * quant_input.scale.view(acc_scale_shape)
-        quant_bias = gate.bias_quant(gate.bias, acc_scale, acc_bit_width)
+        quant_bias = gate.bias_quant(gate.bias, acc_scale)
         return quant_weight_ih, quant_weight_hh, quant_bias

     def reset_parameters(self) -> None:
```
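The only runtime input the gate's bias quantizer still takes is the accumulator scale computed above. A minimal standalone sketch of that scale rule, with illustrative values (shapes and numbers are assumptions, not taken from the PR):

```python
import torch

# The integer accumulator of input * weight has scale
# input_scale * weight_scale; after this PR that product is the only
# runtime argument passed to the bias quantizer (the bit-width is now
# defined by the quantizer itself).
input_scale = torch.tensor(0.05)          # assumed per-tensor input scale
weight_scale = torch.rand(8, 1) * 0.01    # assumed per-channel weight scale
acc_scale = weight_scale * input_scale    # scale the bias is quantized against
```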
10 changes: 5 additions & 5 deletions src/brevitas/nn/mixin/parameter.py
```diff
@@ -176,14 +176,13 @@ def quant_bias(self):
         if self.bias is None:
             return None
         scale = self.quant_bias_scale()
-        bit_width = self.quant_bias_bit_width()
-        quant_bias = self.bias_quant(self.bias, scale, bit_width)
+        quant_bias = self.bias_quant(self.bias, scale)
         return quant_bias

     def quant_bias_scale(self):
         if self.bias is None or not self.is_bias_quant_enabled:
             return None
-        if not self.bias_quant.requires_input_scale and not self.bias_quant.requires_input_bit_width:
+        if not self.bias_quant.requires_input_scale:
             return self.bias_quant(self.bias).scale
         else:
             if self._cached_bias is None:
@@ -197,7 +196,8 @@
     def quant_bias_zero_point(self):
         if self.bias is None:
             return None
-        if not self.bias_quant.requires_input_scale and not self.bias_quant.requires_input_bit_width:
+
+        if not self.bias_quant.requires_input_scale:
             bias_quant = self.bias_quant(self.bias)
             if isinstance(bias_quant, QuantTensor):
                 return bias_quant.zero_point
@@ -215,7 +215,7 @@
     def quant_bias_bit_width(self):
         if self.bias is None or not self.is_bias_quant_enabled:
             return None
-        if not self.bias_quant.requires_input_scale and not self.bias_quant.requires_input_bit_width:
+        if not self.bias_quant.requires_input_scale:
             return self.bias_quant(self.bias).bit_width
         else:
             if self._cached_bias is None:
```
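These accessors now branch on `requires_input_scale` alone. A hedged sketch of the internally-scaled path, which can answer without a cached forward pass (the quantizer name comes from the docstrings touched by this PR):

```python
from brevitas.nn import QuantLinear
from brevitas.quant import Int8BiasPerTensorFloatInternalScaling

# requires_input_scale is False here, so scale, zero-point and bit-width
# are all computable directly from the bias tensor, no cached bias needed.
fc = QuantLinear(10, 5, bias=True,
                 bias_quant=Int8BiasPerTensorFloatInternalScaling)
print(fc.quant_bias_scale())      # derived from the bias values themselves
print(fc.quant_bias_bit_width())  # 8, fixed by the quantizer
```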
3 changes: 2 additions & 1 deletion src/brevitas/nn/quant_layer.py
```diff
@@ -323,9 +323,10 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
         output_signed = quant_input.signed or quant_weight.signed

         if self.bias is not None:
-            quant_bias = self.bias_quant(self.bias, output_scale, output_bit_width)
+            quant_bias = self.bias_quant(self.bias, output_scale)
             if not self.training and self.cache_inference_quant_bias and isinstance(quant_bias,
                                                                                     QuantTensor):
+
                 self._cached_bias = _CachedIO(quant_bias.detach(), metadata_only=False)
         output_tensor = self.inner_forward_impl(
             _unpack_quant_tensor(quant_input),
```
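Quantizers that do require the input scale can only report their metadata after a forward pass has populated `_cached_bias`, which is what the caching branch above does. A hedged sketch of that flow (the input quantizer is an assumption, added so that `output_scale` exists; the attribute names come from this diff):

```python
import torch
from brevitas.nn import QuantLinear
from brevitas.quant import Int8ActPerTensorFloat, Int32Bias

fc = QuantLinear(10, 5, bias=True,
                 input_quant=Int8ActPerTensorFloat, bias_quant=Int32Bias)
fc.eval()                            # caching is skipped while training
fc.cache_inference_quant_bias = True
_ = fc(torch.randn(2, 10))           # forward populates fc._cached_bias
print(fc.quant_bias_bit_width())     # 32, served from the cached bias
```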
37 changes: 10 additions & 27 deletions src/brevitas/proxy/parameter_quant.py
```diff
@@ -33,7 +33,6 @@ def forward(self, x: torch.Tensor) -> QuantTensor:

 @runtime_checkable
 class BiasQuantProxyProtocol(QuantProxyProtocol, Protocol):
-    requires_input_bit_width: bool
     requires_input_scale: bool

     def forward(
@@ -161,13 +160,6 @@ class BiasQuantProxyFromInjector(ParameterQuantProxyFromInjector, BiasQuantProxyProtocol):
     def tracked_parameter_list(self):
         return [m.bias for m in self.tracked_module_list if m.bias is not None]

-    @property
-    def requires_input_bit_width(self) -> bool:
-        if self.is_quant_enabled:
-            return self.quant_injector.requires_input_bit_width
-        else:
-            return False
-
     @property
     def requires_input_scale(self) -> bool:
         if self.is_quant_enabled:
@@ -179,42 +171,33 @@ def scale(self):
         if self.requires_input_scale:
             return None
         zhs = self._zero_hw_sentinel()
-        scale = self.__call__(self.tracked_parameter_list[0], zhs, zhs).scale
+        scale = self.__call__(self.tracked_parameter_list[0], zhs).scale
         return scale

     def zero_point(self):
         zhs = self._zero_hw_sentinel()
-        zero_point = self.__call__(self.tracked_parameter_list[0], zhs, zhs).zero_point
+        zero_point = self.__call__(self.tracked_parameter_list[0], zhs).zero_point
         return zero_point

     def bit_width(self):
-        if self.requires_input_bit_width:
-            return None
         zhs = self._zero_hw_sentinel()
-        bit_width = self.__call__(self.tracked_parameter_list[0], zhs, zhs).bit_width
+        bit_width = self.__call__(self.tracked_parameter_list[0], zhs).bit_width
         return bit_width

-    def forward(
-            self,
-            x: Tensor,
-            input_scale: Optional[Tensor] = None,
-            input_bit_width: Optional[Tensor] = None) -> Union[Tensor, QuantTensor]:
+    def forward(self,
+                x: Tensor,
+                input_scale: Optional[Tensor] = None) -> Union[Tensor, QuantTensor]:
         if self.is_quant_enabled:
             impl = self.export_handler if self.export_mode else self.tensor_quant
             if self.requires_input_scale and input_scale is None:
                 raise RuntimeError("Input scale required")
-            if self.requires_input_bit_width and input_bit_width is None:
-                raise RuntimeError("Input bit-width required")
-            if self.requires_input_scale and self.requires_input_bit_width:
-                input_scale = input_scale.view(-1)
-                out, out_scale, out_zp, out_bit_width = impl(x, input_scale, input_bit_width)
-            elif self.requires_input_scale and not self.requires_input_bit_width:
+
+            if self.requires_input_scale:
                 input_scale = input_scale.view(-1)
                 out, out_scale, out_zp, out_bit_width = impl(x, input_scale)
-            elif not self.requires_input_scale and not self.requires_input_bit_width:
-                out, out_scale, out_zp, out_bit_width = impl(x)
             else:
-                raise RuntimeError("Internally defined bit-width required")
+                out, out_scale, out_zp, out_bit_width = impl(x)

             return QuantTensor(out, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
         else:
             return x
```
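After this change the proxy has exactly two call shapes instead of three. A hedged usage sketch against the quantizers this PR touches (the scalar scale value is illustrative):

```python
import torch
from brevitas.nn import QuantLinear
from brevitas.quant import Int16Bias, Int8BiasPerTensorFloatInternalScaling

# Path 1: requires_input_scale is True, the accumulator scale is mandatory.
fc = QuantLinear(10, 5, bias=True, bias_quant=Int16Bias)
qb = fc.bias_quant(fc.bias, torch.tensor(0.01))  # illustrative scale
# fc.bias_quant(fc.bias) would raise RuntimeError("Input scale required")

# Path 2: requires_input_scale is False, the scale is derived internally.
fc2 = QuantLinear(10, 5, bias=True,
                  bias_quant=Int8BiasPerTensorFloatInternalScaling)
qb2 = fc2.bias_quant(fc2.bias)
```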
1 change: 0 additions & 1 deletion src/brevitas/quant/fixed_point.py
```diff
@@ -178,7 +178,6 @@ class Int8BiasPerTensorFixedPointInternalScaling(IntQuant,
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int8BiasPerTensorFixedPointInternalScaling)
     """
     requires_input_scale = False
-    requires_input_bit_width = False


 class Int4WeightPerTensorFixedPointDecoupled(WeightPerTensorFloatDecoupledL2Param):
```
1 change: 0 additions & 1 deletion src/brevitas/quant/none.py
```diff
@@ -31,7 +31,6 @@ class NoneBiasQuant(BiasQuantSolver):
     """
     quant_type = QuantType.FP
     requires_input_scale = False
-    requires_input_bit_width = False


 class NoneTruncQuant(TruncQuantSolver):
```
6 changes: 0 additions & 6 deletions src/brevitas/quant/scaled_int.py
```diff
@@ -85,7 +85,6 @@ class IntBias(IntQuant, BiasQuantSolver):
     """
     tensor_clamp_impl = TensorClamp
     requires_input_scale = True
-    requires_input_bit_width = True


 class Int8Bias(IntBias):
@@ -98,7 +97,6 @@ class Int8Bias(IntBias):
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int8Bias)
     """
     bit_width = 8
-    requires_input_bit_width = False


 class Int16Bias(IntBias):
@@ -111,7 +109,6 @@ class Int16Bias(IntBias):
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int16Bias)
     """
     bit_width = 16
-    requires_input_bit_width = False


 class Int24Bias(IntBias):
@@ -124,7 +121,6 @@ class Int24Bias(IntBias):
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int24Bias)
     """
     bit_width = 24
-    requires_input_bit_width = False


 class Int32Bias(IntBias):
@@ -137,7 +133,6 @@ class Int32Bias(IntBias):
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int32Bias)
     """
     bit_width = 32
-    requires_input_bit_width = False


 class Int8BiasPerTensorFloatInternalScaling(IntQuant,
@@ -153,7 +148,6 @@ class Int8BiasPerTensorFloatInternalScaling(IntQuant,
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int8BiasPerTensorFloatInternalScaling)
     """
     requires_input_scale = False
-    requires_input_bit_width = False


 class Int8WeightPerTensorFloat(NarrowIntQuant,
```
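With `requires_input_bit_width` removed, the only thing distinguishing these subclasses is their fixed `bit_width`. A construction sketch following the docstrings above:

```python
from brevitas.nn import QuantLinear
from brevitas.quant import Int8Bias, Int16Bias, Int32Bias

# Same quantizer family, different fixed widths; each still needs the
# accumulator scale at forward time (requires_input_scale stays True).
fc8 = QuantLinear(10, 5, bias=True, bias_quant=Int8Bias)
fc16 = QuantLinear(10, 5, bias=True, bias_quant=Int16Bias)
fc32 = QuantLinear(10, 5, bias=True, bias_quant=Int32Bias)
```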
22 changes: 4 additions & 18 deletions src/brevitas/quant/solver/bias.py
```diff
@@ -1,9 +1,7 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause

-from brevitas.core.function_wrapper import Identity
 from brevitas.core.quant import PrescaledRestrictIntQuant
-from brevitas.core.quant import PrescaledRestrictIntQuantWithInputBitWidth
 from brevitas.core.quant import RescalingIntQuant
 from brevitas.inject import ExtendedInjector
 from brevitas.inject import value
@@ -34,29 +32,17 @@ def scaling_per_output_channel_shape(module):
         return (module.out_channels,)


-class SolveBiasBitWidthImplFromEnum(ExtendedInjector):
-
-    @value
-    def bit_width_impl(bit_width_impl_type, requires_input_bit_width):
-        if not requires_input_bit_width:
-            return solve_bit_width_impl_from_enum(bit_width_impl_type)
-        else:
-            return Identity
-
-
 class SolveBiasTensorQuantFromEnum(SolveIntQuantFromEnum):

     @value
-    def tensor_quant(quant_type, requires_input_bit_width, requires_input_scale):
+    def tensor_quant(quant_type, requires_input_scale):
         if quant_type == QuantType.FP:
             return None
         elif quant_type == QuantType.INT:
-            if not requires_input_bit_width and requires_input_scale:
+            if requires_input_scale:
                 return PrescaledRestrictIntQuant
-            elif not requires_input_bit_width and not requires_input_scale:
+            else:
                 return RescalingIntQuant
-            else:  # requires_input_bit_width == True
-                return PrescaledRestrictIntQuantWithInputBitWidth
         elif quant_type == QuantType.TERNARY:
             raise RuntimeError(f'{quant_type} not supported.')
         elif quant_type == QuantType.BINARY:
@@ -75,7 +61,7 @@ class BiasQuantSolver(SolveScalingStatsInputViewShapeImplFromEnum,
                       SolveParameterScalingImplFromEnum,
                       SolveParameterTensorClampImplFromEnum,
                       SolveParameterScalingInitFromEnum,
-                      SolveBiasBitWidthImplFromEnum,
+                      SolveBitWidthImplFromEnum,
                       SolveBiasScalingPerOutputChannelShapeFromModule,
                       SolveBiasScalingStatsInputConcatDimFromModule,
                       SolveBiasTensorQuantFromEnum,
```
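With `SolveBiasBitWidthImplFromEnum` gone, the tensor-quant choice collapses to two implementations. A standalone sketch of the decision (the real solver is a dependency-injected `@value` method, not a plain function):

```python
from brevitas.core.quant import PrescaledRestrictIntQuant
from brevitas.core.quant import RescalingIntQuant

def pick_int_bias_tensor_quant(requires_input_scale: bool):
    # Prescaled: reuse the externally supplied accumulator scale.
    # Rescaling: derive the scale from the bias tensor itself.
    if requires_input_scale:
        return PrescaledRestrictIntQuant
    return RescalingIntQuant
```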
2 changes: 1 addition & 1 deletion src/brevitas_examples/bnn_pynq/README.md
```diff
@@ -20,7 +20,7 @@ These pretrained models and training scripts are courtesy of
 | CNV_1W1A | 8 bit | 1 bit | 1 bit | CIFAR10 | 84.22% |
 | CNV_1W2A | 8 bit | 1 bit | 2 bit | CIFAR10 | 87.80% |
 | CNV_2W2A | 8 bit | 2 bit | 2 bit | CIFAR10 | 89.03% |
-| RESNET18_4W4A | 8 bit (assumed) | 4 bit | 4 bit | CIFAR10 | 92.60% |
+| RESNET18_4W4A | 8 bit (assumed) | 4 bit | 4 bit | CIFAR10 | 92.61% |

 ## Train
```
4 changes: 2 additions & 2 deletions src/brevitas_examples/bnn_pynq/models/resnet.py
```diff
@@ -9,7 +9,7 @@
 import brevitas.nn as qnn
 from brevitas.quant import Int8WeightPerChannelFloat
 from brevitas.quant import Int8WeightPerTensorFloat
-from brevitas.quant import IntBias
+from brevitas.quant import Int32Bias
 from brevitas.quant import TruncTo8bit
 from brevitas.quant_tensor import QuantTensor

@@ -121,7 +121,7 @@ def __init__(
             act_bit_width=8,
             weight_bit_width=8,
             round_average_pool=False,
-            last_layer_bias_quant=IntBias,
+            last_layer_bias_quant=Int32Bias,
             weight_quant=Int8WeightPerChannelFloat,
             first_layer_weight_quant=Int8WeightPerChannelFloat,
             last_layer_weight_quant=Int8WeightPerTensorFloat):
```
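The example models migrate from `IntBias` to `Int32Bias`: without an input bit-width, bare `IntBias` no longer has a defined width, so callers pick one explicitly. A hedged migration sketch (layer sizes are illustrative):

```python
from brevitas.nn import QuantLinear
from brevitas.quant import Int32Bias

# Before this PR: bias_quant=IntBias, width inferred from the accumulator.
# After: choose an explicit width; 32 bit is a safe match for an int8
# input/weight accumulator.
head = QuantLinear(512, 10, bias=True, bias_quant=Int32Bias)
```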
```diff
@@ -34,7 +34,7 @@
 from brevitas.nn import QuantLinear
 from brevitas.nn import QuantReLU
 from brevitas.nn import TruncAvgPool2d
-from brevitas.quant import IntBias
+from brevitas.quant import Int32Bias

 from .common import CommonIntWeightPerChannelQuant
 from .common import CommonIntWeightPerTensorQuant
@@ -181,7 +181,7 @@ def __init__(
             in_channels,
             num_classes,
             bias=True,
-            bias_quant=IntBias,
+            bias_quant=Int32Bias,
             weight_quant=last_layer_weight_quant,
             weight_bit_width=last_layer_bit_width)
```
```diff
@@ -30,7 +30,7 @@
 from brevitas.nn import QuantLinear
 from brevitas.nn import QuantReLU
 from brevitas.nn import TruncAvgPool2d
-from brevitas.quant import IntBias
+from brevitas.quant import Int32Bias

 from .common import *
@@ -300,7 +300,7 @@ def __init__(
             in_features=in_channels,
             out_features=num_classes,
             bias=True,
-            bias_quant=IntBias,
+            bias_quant=Int32Bias,
             weight_bit_width=bit_width,
             weight_quant=CommonIntWeightPerTensorQuant)
```
12 changes: 6 additions & 6 deletions src/brevitas_examples/imagenet_classification/qat/README.md
```diff
@@ -5,12 +5,12 @@ and by no means a direct mapping to hardware should be assumed.

 Below in the table is a list of example pretrained models made available for reference.

-| Name | Cfg | Scaling Type | First layer weights | Weights | Activations | Avg pool | Top1 | Top5 | Pretrained model | Retrained from |
-|--------------|-----------------------|----------------------------|---------------------|---------|-------------|----------|-------|-------|-------------------------------------------------------------------------------------------------|---------------------------------------------------------------|
-| MobileNet V1 | quant_mobilenet_v1_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 71.14 | 90.10 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_mobilenet_v1_4b-r1/quant_mobilenet_v1_4b-0100a667.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
-| ProxylessNAS Mobile14 w/ Hadamard classifier | quant_proxylessnas_mobile14_hadamard_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 73.52 | 91.46 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_hadamard_4b-r0/quant_proxylessnas_mobile14_hadamard_4b-4acbfa9f.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
-| ProxylessNAS Mobile14 | quant_proxylessnas_mobile14_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 74.42 | 92.04 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_4b-r0/quant_proxylessnas_mobile14_4b-e10882e1.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
-| ProxylessNAS Mobile14 | quant_proxylessnas_mobile14_4b5b | Floating-point per channel | 8 bit | 4 bit, 5 bit | 4 bit, 5 bit | 4 bit | 75.01 | 92.33 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_4b5b-r0/quant_proxylessnas_mobile14_4b5b-2bdf7f8d.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
+| Name | Cfg | Scaling Type | First layer weights | Weights | Activations | Avg pool | Top1 | Pretrained model | Retrained from |
+|--------------|-----------------------|----------------------------|---------------------|---------|-------------|----------|-------|-------------------------------------------------------------------------------------------------|---------------------------------------------------------------|
+| MobileNet V1 | quant_mobilenet_v1_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 70.95 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_mobilenet_v1_4b-r1/quant_mobilenet_v1_4b-0100a667.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
+| ProxylessNAS Mobile14 w/ Hadamard classifier | quant_proxylessnas_mobile14_hadamard_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 72.87 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_hadamard_4b-r0/quant_proxylessnas_mobile14_hadamard_4b-4acbfa9f.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
+| ProxylessNAS Mobile14 | quant_proxylessnas_mobile14_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 74.39 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_4b-r0/quant_proxylessnas_mobile14_4b-e10882e1.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
+| ProxylessNAS Mobile14 | quant_proxylessnas_mobile14_4b5b | Floating-point per channel | 8 bit | 4 bit, 5 bit | 4 bit, 5 bit | 4 bit | 74.94 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_4b5b-r0/quant_proxylessnas_mobile14_4b5b-2bdf7f8d.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |


 To evaluate a pretrained quantized model on ImageNet:
```
4 changes: 4 additions & 0 deletions src/brevitas_examples/imagenet_classification/utils.py
```diff
@@ -5,6 +5,8 @@
 import torchvision.transforms as transforms
 from tqdm import tqdm

+from brevitas.quant_tensor import QuantTensor
+
 SEED = 123456

 MEAN = [0.485, 0.456, 0.406]
@@ -81,6 +83,8 @@ def print_accuracy(top1, prefix=''):
         images = images.to(dtype)

         output = model(images)
+        if isinstance(output, QuantTensor):
+            output = output.value
         # measure accuracy
         acc1, = accuracy(output, target, stable=stable)
         top1.update(acc1[0], images.size(0))
```
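The guard added above is the general pattern for mixing quantized models with plain-tensor metrics. A small hedged helper capturing it (the helper name is ours, not from the PR):

```python
from brevitas.quant_tensor import QuantTensor

def to_plain_tensor(output):
    # QuantTensor bundles value, scale, zero-point and bit-width; metrics
    # such as top-1 accuracy only want the dequantized value tensor.
    return output.value if isinstance(output, QuantTensor) else output
```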
5 changes: 0 additions & 5 deletions tests/brevitas/nn/test_hadamard.py
```diff
@@ -1,9 +1,4 @@
-import torch
-from torch.nn import Module
-
 from brevitas.nn import HadamardClassifier
-from brevitas.quant import IntBias
-from brevitas.quant_tensor import QuantTensor

 OUTPUT_FEATURES = 10
 INPUT_FEATURES = 5
```