diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py
index e54ad1ecc..55a5c8150 100644
--- a/src/brevitas/nn/mixin/base.py
+++ b/src/brevitas/nn/mixin/base.py
@@ -242,16 +242,13 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
     @staticmethod
     def gate_params_fwd(gate, quant_input):
         acc_scale = None
-        acc_bit_width = None
         quant_weight_ih = gate.input_weight()
         quant_weight_hh = gate.hidden_weight()
-        if isinstance(quant_input, QuantTensor):
-            acc_bit_width = None  # TODO
         if isinstance(quant_input, QuantTensor) and isinstance(quant_weight_ih, QuantTensor):
             acc_scale_shape = compute_channel_view_shape(quant_input.value, channel_dim=1)
             acc_scale = quant_weight_ih.scale.view(acc_scale_shape)
             acc_scale = acc_scale * quant_input.scale.view(acc_scale_shape)
-        quant_bias = gate.bias_quant(gate.bias, acc_scale, acc_bit_width)
+        quant_bias = gate.bias_quant(gate.bias, acc_scale)
         return quant_weight_ih, quant_weight_hh, quant_bias
 
     def reset_parameters(self) -> None:
diff --git a/src/brevitas/nn/mixin/parameter.py b/src/brevitas/nn/mixin/parameter.py
index 3acfe7c95..a752c35ec 100644
--- a/src/brevitas/nn/mixin/parameter.py
+++ b/src/brevitas/nn/mixin/parameter.py
@@ -176,14 +176,13 @@ def quant_bias(self):
         if self.bias is None:
             return None
         scale = self.quant_bias_scale()
-        bit_width = self.quant_bias_bit_width()
-        quant_bias = self.bias_quant(self.bias, scale, bit_width)
+        quant_bias = self.bias_quant(self.bias, scale)
         return quant_bias
 
     def quant_bias_scale(self):
         if self.bias is None or not self.is_bias_quant_enabled:
             return None
-        if not self.bias_quant.requires_input_scale and not self.bias_quant.requires_input_bit_width:
+        if not self.bias_quant.requires_input_scale:
             return self.bias_quant(self.bias).scale
         else:
             if self._cached_bias is None:
@@ -197,7 +196,8 @@ def quant_bias_zero_point(self):
         if self.bias is None:
             return None
-        if not self.bias_quant.requires_input_scale and not self.bias_quant.requires_input_bit_width:
+
+        if not self.bias_quant.requires_input_scale:
             bias_quant = self.bias_quant(self.bias)
             if isinstance(bias_quant, QuantTensor):
                 return bias_quant.zero_point
@@ -215,7 +215,7 @@ def quant_bias_bit_width(self):
         if self.bias is None or not self.is_bias_quant_enabled:
             return None
-        if not self.bias_quant.requires_input_scale and not self.bias_quant.requires_input_bit_width:
+        if not self.bias_quant.requires_input_scale:
             return self.bias_quant(self.bias).bit_width
         else:
             if self._cached_bias is None:
diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py
index f56ddd160..5cf29602f 100644
--- a/src/brevitas/nn/quant_layer.py
+++ b/src/brevitas/nn/quant_layer.py
@@ -323,9 +323,10 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
             output_signed = quant_input.signed or quant_weight.signed
 
         if self.bias is not None:
-            quant_bias = self.bias_quant(self.bias, output_scale, output_bit_width)
-            if not self.training and self.cache_inference_quant_bias:
+            quant_bias = self.bias_quant(self.bias, output_scale)
+            if not self.training and self.cache_inference_quant_bias and isinstance(quant_bias,
+                                                                                    QuantTensor):
                 self._cached_bias = _CachedIO(quant_bias.detach(), metadata_only=False)
             output_tensor = self.inner_forward_impl(
                 _unpack_quant_tensor(quant_input),
diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py
index 1f6adf549..2927b1662 100644
--- a/src/brevitas/proxy/parameter_quant.py
+++ b/src/brevitas/proxy/parameter_quant.py
@@ -33,7 +33,6 @@ def forward(self, x: torch.Tensor) -> QuantTensor:
 
 @runtime_checkable
 class BiasQuantProxyProtocol(QuantProxyProtocol, Protocol):
-    requires_input_bit_width: bool
     requires_input_scale: bool
 
     def forward(
@@ -161,13 +160,6 @@ class BiasQuantProxyFromInjector(ParameterQuantProxyFromInjector, BiasQuantProxy
     def tracked_parameter_list(self):
         return [m.bias for m in self.tracked_module_list if m.bias is not None]
 
-    @property
-    def requires_input_bit_width(self) -> bool:
-        if self.is_quant_enabled:
-            return self.quant_injector.requires_input_bit_width
-        else:
-            return False
-
     @property
     def requires_input_scale(self) -> bool:
         if self.is_quant_enabled:
@@ -179,42 +171,33 @@ def scale(self):
         if self.requires_input_scale:
             return None
         zhs = self._zero_hw_sentinel()
-        scale = self.__call__(self.tracked_parameter_list[0], zhs, zhs).scale
+        scale = self.__call__(self.tracked_parameter_list[0], zhs).scale
         return scale
 
     def zero_point(self):
         zhs = self._zero_hw_sentinel()
-        zero_point = self.__call__(self.tracked_parameter_list[0], zhs, zhs).zero_point
+        zero_point = self.__call__(self.tracked_parameter_list[0], zhs).zero_point
         return zero_point
 
     def bit_width(self):
-        if self.requires_input_bit_width:
-            return None
         zhs = self._zero_hw_sentinel()
-        bit_width = self.__call__(self.tracked_parameter_list[0], zhs, zhs).bit_width
+        bit_width = self.__call__(self.tracked_parameter_list[0], zhs).bit_width
         return bit_width
 
-    def forward(
-            self,
-            x: Tensor,
-            input_scale: Optional[Tensor] = None,
-            input_bit_width: Optional[Tensor] = None) -> Union[Tensor, QuantTensor]:
+    def forward(self,
+                x: Tensor,
+                input_scale: Optional[Tensor] = None) -> Union[Tensor, QuantTensor]:
         if self.is_quant_enabled:
             impl = self.export_handler if self.export_mode else self.tensor_quant
             if self.requires_input_scale and input_scale is None:
                 raise RuntimeError("Input scale required")
-            if self.requires_input_bit_width and input_bit_width is None:
-                raise RuntimeError("Input bit-width required")
-            if self.requires_input_scale and self.requires_input_bit_width:
-                input_scale = input_scale.view(-1)
-                out, out_scale, out_zp, out_bit_width = impl(x, input_scale, input_bit_width)
-            elif self.requires_input_scale and not self.requires_input_bit_width:
+
+            if self.requires_input_scale:
                 input_scale = input_scale.view(-1)
                 out, out_scale, out_zp, out_bit_width = impl(x, input_scale)
-            elif not self.requires_input_scale and not self.requires_input_bit_width:
-                out, out_scale, out_zp, out_bit_width = impl(x)
             else:
-                raise RuntimeError("Internally defined bit-width required")
+                out, out_scale, out_zp, out_bit_width = impl(x)
+
             return QuantTensor(out, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
         else:
             return x
diff --git a/src/brevitas/quant/fixed_point.py b/src/brevitas/quant/fixed_point.py
index 41c3a403e..1765dfe83 100644
--- a/src/brevitas/quant/fixed_point.py
+++ b/src/brevitas/quant/fixed_point.py
@@ -178,7 +178,6 @@ class Int8BiasPerTensorFixedPointInternalScaling(IntQuant,
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int8BiasPerTensorFixedPointInternalScaling)
     """
     requires_input_scale = False
-    requires_input_bit_width = False
 
 
 class Int4WeightPerTensorFixedPointDecoupled(WeightPerTensorFloatDecoupledL2Param):
diff --git a/src/brevitas/quant/none.py b/src/brevitas/quant/none.py
index 21720932a..57c9b31a5 100644
--- a/src/brevitas/quant/none.py
+++ b/src/brevitas/quant/none.py
@@ -31,7 +31,6 @@ class NoneBiasQuant(BiasQuantSolver):
     """
     quant_type = QuantType.FP
     requires_input_scale = False
-    requires_input_bit_width = False
 
 
 class NoneTruncQuant(TruncQuantSolver):
diff --git a/src/brevitas/quant/scaled_int.py b/src/brevitas/quant/scaled_int.py
index 5b7f16ad2..0f67300c3 100644
--- a/src/brevitas/quant/scaled_int.py
+++ b/src/brevitas/quant/scaled_int.py
@@ -85,7 +85,6 @@ class IntBias(IntQuant, BiasQuantSolver):
     """
     tensor_clamp_impl = TensorClamp
    requires_input_scale = True
-    requires_input_bit_width = True
 
 
 class Int8Bias(IntBias):
@@ -98,7 +97,6 @@ class Int8Bias(IntBias):
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int8Bias)
     """
     bit_width = 8
-    requires_input_bit_width = False
 
 
 class Int16Bias(IntBias):
@@ -111,7 +109,6 @@ class Int16Bias(IntBias):
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int16Bias)
     """
     bit_width = 16
-    requires_input_bit_width = False
 
 
 class Int24Bias(IntBias):
@@ -124,7 +121,6 @@ class Int24Bias(IntBias):
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int24Bias)
     """
     bit_width = 24
-    requires_input_bit_width = False
 
 
 class Int32Bias(IntBias):
@@ -137,7 +133,6 @@ class Int32Bias(IntBias):
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int32Bias)
     """
     bit_width = 32
-    requires_input_bit_width = False
 
 
 class Int8BiasPerTensorFloatInternalScaling(IntQuant,
@@ -153,7 +148,6 @@ class Int8BiasPerTensorFloatInternalScaling(IntQuant,
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int8BiasPerTensorFloatInternalScaling)
     """
     requires_input_scale = False
-    requires_input_bit_width = False
 
 
 class Int8WeightPerTensorFloat(NarrowIntQuant,
diff --git a/src/brevitas/quant/solver/bias.py b/src/brevitas/quant/solver/bias.py
index 6e2f25440..33a55f55c 100644
--- a/src/brevitas/quant/solver/bias.py
+++ b/src/brevitas/quant/solver/bias.py
@@ -1,9 +1,7 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
-from brevitas.core.function_wrapper import Identity
 from brevitas.core.quant import PrescaledRestrictIntQuant
-from brevitas.core.quant import PrescaledRestrictIntQuantWithInputBitWidth
 from brevitas.core.quant import RescalingIntQuant
 from brevitas.inject import ExtendedInjector
 from brevitas.inject import value
@@ -34,29 +32,17 @@ def scaling_per_output_channel_shape(module):
         return (module.out_channels,)
 
 
-class SolveBiasBitWidthImplFromEnum(ExtendedInjector):
-
-    @value
-    def bit_width_impl(bit_width_impl_type, requires_input_bit_width):
-        if not requires_input_bit_width:
-            return solve_bit_width_impl_from_enum(bit_width_impl_type)
-        else:
-            return Identity
-
-
 class SolveBiasTensorQuantFromEnum(SolveIntQuantFromEnum):
 
     @value
-    def tensor_quant(quant_type, requires_input_bit_width, requires_input_scale):
+    def tensor_quant(quant_type, requires_input_scale):
         if quant_type == QuantType.FP:
             return None
         elif quant_type == QuantType.INT:
-            if not requires_input_bit_width and requires_input_scale:
+            if requires_input_scale:
                 return PrescaledRestrictIntQuant
-            elif not requires_input_bit_width and not requires_input_scale:
+            else:
                 return RescalingIntQuant
-            else:  # requires_input_bit_width == True
-                return PrescaledRestrictIntQuantWithInputBitWidth
         elif quant_type == QuantType.TERNARY:
             raise RuntimeError(f'{quant_type} not supported.')
         elif quant_type == QuantType.BINARY:
@@ -75,7 +61,7 @@ class BiasQuantSolver(SolveScalingStatsInputViewShapeImplFromEnum,
                       SolveParameterScalingImplFromEnum,
                       SolveParameterTensorClampImplFromEnum,
                       SolveParameterScalingInitFromEnum,
-                      SolveBiasBitWidthImplFromEnum,
+                      SolveBitWidthImplFromEnum,
                       SolveBiasScalingPerOutputChannelShapeFromModule,
                       SolveBiasScalingStatsInputConcatDimFromModule,
                       SolveBiasTensorQuantFromEnum,
diff --git a/src/brevitas_examples/bnn_pynq/README.md b/src/brevitas_examples/bnn_pynq/README.md
index f27300345..9ebb69d6d 100644
--- a/src/brevitas_examples/bnn_pynq/README.md
+++ b/src/brevitas_examples/bnn_pynq/README.md
@@ -20,7 +20,7 @@ These pretrained models and training scripts are courtesy of
 | CNV_1W1A | 8 bit | 1 bit | 1 bit | CIFAR10 | 84.22% |
 | CNV_1W2A | 8 bit | 1 bit | 2 bit | CIFAR10 | 87.80% |
 | CNV_2W2A | 8 bit | 2 bit | 2 bit | CIFAR10 | 89.03% |
-| RESNET18_4W4A | 8 bit (assumed) | 4 bit | 4 bit | CIFAR10 | 92.60% |
+| RESNET18_4W4A | 8 bit (assumed) | 4 bit | 4 bit | CIFAR10 | 92.61% |
 
 ## Train
diff --git a/src/brevitas_examples/bnn_pynq/models/resnet.py b/src/brevitas_examples/bnn_pynq/models/resnet.py
index 24995c6b6..14efdf498 100644
--- a/src/brevitas_examples/bnn_pynq/models/resnet.py
+++ b/src/brevitas_examples/bnn_pynq/models/resnet.py
@@ -9,7 +9,7 @@
 import brevitas.nn as qnn
 from brevitas.quant import Int8WeightPerChannelFloat
 from brevitas.quant import Int8WeightPerTensorFloat
-from brevitas.quant import IntBias
+from brevitas.quant import Int32Bias
 from brevitas.quant import TruncTo8bit
 from brevitas.quant_tensor import QuantTensor
 
@@ -121,7 +121,7 @@ def __init__(
             act_bit_width=8,
             weight_bit_width=8,
             round_average_pool=False,
-            last_layer_bias_quant=IntBias,
+            last_layer_bias_quant=Int32Bias,
             weight_quant=Int8WeightPerChannelFloat,
             first_layer_weight_quant=Int8WeightPerChannelFloat,
             last_layer_weight_quant=Int8WeightPerTensorFloat):
diff --git a/src/brevitas_examples/imagenet_classification/models/mobilenetv1.py b/src/brevitas_examples/imagenet_classification/models/mobilenetv1.py
index fd085f361..47a6ec8e1 100644
--- a/src/brevitas_examples/imagenet_classification/models/mobilenetv1.py
+++ b/src/brevitas_examples/imagenet_classification/models/mobilenetv1.py
@@ -34,7 +34,7 @@
 from brevitas.nn import QuantLinear
 from brevitas.nn import QuantReLU
 from brevitas.nn import TruncAvgPool2d
-from brevitas.quant import IntBias
+from brevitas.quant import Int32Bias
 
 from .common import CommonIntWeightPerChannelQuant
 from .common import CommonIntWeightPerTensorQuant
@@ -181,7 +181,7 @@ def __init__(
             in_channels,
             num_classes,
             bias=True,
-            bias_quant=IntBias,
+            bias_quant=Int32Bias,
             weight_quant=last_layer_weight_quant,
             weight_bit_width=last_layer_bit_width)
 
diff --git a/src/brevitas_examples/imagenet_classification/models/proxylessnas.py b/src/brevitas_examples/imagenet_classification/models/proxylessnas.py
index 5c8cc5539..a26132228 100644
--- a/src/brevitas_examples/imagenet_classification/models/proxylessnas.py
+++ b/src/brevitas_examples/imagenet_classification/models/proxylessnas.py
@@ -30,7 +30,7 @@
 from brevitas.nn import QuantLinear
 from brevitas.nn import QuantReLU
 from brevitas.nn import TruncAvgPool2d
-from brevitas.quant import IntBias
+from brevitas.quant import Int32Bias
 
 from .common import *
 
@@ -300,7 +300,7 @@ def __init__(
             in_features=in_channels,
             out_features=num_classes,
             bias=True,
-            bias_quant=IntBias,
+            bias_quant=Int32Bias,
             weight_bit_width=bit_width,
             weight_quant=CommonIntWeightPerTensorQuant)
 
diff --git a/src/brevitas_examples/imagenet_classification/qat/README.md b/src/brevitas_examples/imagenet_classification/qat/README.md
index b0fce235e..92e318c05 100644
--- a/src/brevitas_examples/imagenet_classification/qat/README.md
+++ b/src/brevitas_examples/imagenet_classification/qat/README.md
@@ -5,12 +5,12 @@
 and by no means a direct mapping to hardware should be assumed.
 
 Below in the table is a list of example pretrained models made available for reference.
 
-| Name | Cfg | Scaling Type | First layer weights | Weights | Activations | Avg pool | Top1 | Top5 | Pretrained model | Retrained from |
-|--------------|-----------------------|----------------------------|---------------------|---------|-------------|----------|-------|-------|-------------------------------------------------------------------------------------------------|---------------------------------------------------------------|
-| MobileNet V1 | quant_mobilenet_v1_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 71.14 | 90.10 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_mobilenet_v1_4b-r1/quant_mobilenet_v1_4b-0100a667.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
-| ProxylessNAS Mobile14 w/ Hadamard classifier | quant_proxylessnas_mobile14_hadamard_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 73.52 | 91.46 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_hadamard_4b-r0/quant_proxylessnas_mobile14_hadamard_4b-4acbfa9f.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
-| ProxylessNAS Mobile14 | quant_proxylessnas_mobile14_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 74.42 | 92.04 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_4b-r0/quant_proxylessnas_mobile14_4b-e10882e1.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
-| ProxylessNAS Mobile14 | quant_proxylessnas_mobile14_4b5b | Floating-point per channel | 8 bit | 4 bit, 5 bit | 4 bit, 5 bit | 4 bit | 75.01 | 92.33 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_4b5b-r0/quant_proxylessnas_mobile14_4b5b-2bdf7f8d.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
+| Name | Cfg | Scaling Type | First layer weights | Weights | Activations | Avg pool | Top1 | Pretrained model | Retrained from |
+|--------------|-----------------------|----------------------------|---------------------|---------|-------------|----------|-------|-------------------------------------------------------------------------------------------------|---------------------------------------------------------------|
+| MobileNet V1 | quant_mobilenet_v1_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 70.95 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_mobilenet_v1_4b-r1/quant_mobilenet_v1_4b-0100a667.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
+| ProxylessNAS Mobile14 w/ Hadamard classifier | quant_proxylessnas_mobile14_hadamard_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 72.87 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_hadamard_4b-r0/quant_proxylessnas_mobile14_hadamard_4b-4acbfa9f.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
+| ProxylessNAS Mobile14 | quant_proxylessnas_mobile14_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 74.39 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_4b-r0/quant_proxylessnas_mobile14_4b-e10882e1.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
+| ProxylessNAS Mobile14 | quant_proxylessnas_mobile14_4b5b | Floating-point per channel | 8 bit | 4 bit, 5 bit | 4 bit, 5 bit | 4 bit | 74.94 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_4b5b-r0/quant_proxylessnas_mobile14_4b5b-2bdf7f8d.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
 
 To evaluate a pretrained quantized model on ImageNet:
diff --git a/src/brevitas_examples/imagenet_classification/utils.py b/src/brevitas_examples/imagenet_classification/utils.py
index f614f287c..d506b8a61 100644
--- a/src/brevitas_examples/imagenet_classification/utils.py
+++ b/src/brevitas_examples/imagenet_classification/utils.py
@@ -5,6 +5,8 @@
 import torchvision.transforms as transforms
 from tqdm import tqdm
 
+from brevitas.quant_tensor import QuantTensor
+
 SEED = 123456
 MEAN = [0.485, 0.456, 0.406]
@@ -81,6 +83,8 @@ def print_accuracy(top1, prefix=''):
             images = images.to(dtype)
             output = model(images)
+            if isinstance(output, QuantTensor):
+                output = output.value
             # measure accuracy
             acc1, = accuracy(output, target, stable=stable)
             top1.update(acc1[0], images.size(0))
diff --git a/tests/brevitas/nn/test_hadamard.py b/tests/brevitas/nn/test_hadamard.py
index ba552a3e2..9853b549b 100644
--- a/tests/brevitas/nn/test_hadamard.py
+++ b/tests/brevitas/nn/test_hadamard.py
@@ -1,9 +1,4 @@
-import torch
-from torch.nn import Module
-
 from brevitas.nn import HadamardClassifier
-from brevitas.quant import IntBias
-from brevitas.quant_tensor import QuantTensor
 
 OUTPUT_FEATURES = 10
 INPUT_FEATURES = 5
diff --git a/tests/brevitas/nn/test_linear.py b/tests/brevitas/nn/test_linear.py
index b9690ce63..8fd2d6e04 100644
--- a/tests/brevitas/nn/test_linear.py
+++ b/tests/brevitas/nn/test_linear.py
@@ -2,10 +2,9 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 import torch
-from torch.nn import Module
 
 from brevitas.nn import QuantLinear
-from brevitas.quant import IntBias
+from brevitas.quant import Int32Bias
 from brevitas.quant_tensor import QuantTensor
 
 OUTPUT_FEATURES = 10
@@ -25,7 +24,10 @@ def test_module_init_bias_fp(self):
 
     def test_module_init_bias_int(self):
         mod = QuantLinear(
-            out_features=OUTPUT_FEATURES, in_features=INPUT_FEATURES, bias=True, bias_quant=IntBias)
+            out_features=OUTPUT_FEATURES,
+            in_features=INPUT_FEATURES,
+            bias=True,
+            bias_quant=Int32Bias)
         assert mod
 
     def test_module_init_scale_impl_type_override(self):
@@ -51,7 +53,10 @@ def test_forward_bias_fp(self):
 
     def test_forward_bias_int(self):
         mod = QuantLinear(
-            out_features=OUTPUT_FEATURES, in_features=INPUT_FEATURES, bias=True, bias_quant=IntBias)
+            out_features=OUTPUT_FEATURES,
+            in_features=INPUT_FEATURES,
+            bias=True,
+            bias_quant=Int32Bias)
         x = QuantTensor(
             torch.rand(size=(3, INPUT_FEATURES)),
             torch.tensor(1.0),
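Taken together, these changes mean a bias quantizer either derives its scale from the layer input (`requires_input_scale = True`) or computes everything internally; the accumulator bit-width is no longer threaded through `bias_quant` at runtime, and fixed-width classes such as `Int32Bias` declare `bit_width` themselves. A minimal usage sketch of the updated API follows; the layer sizes and the `Int8ActPerTensorFloat` input quantizer are illustrative choices, not part of this diff:

```python
import torch

from brevitas.nn import QuantLinear
from brevitas.quant import Int8ActPerTensorFloat
from brevitas.quant import Int32Bias

# Int32Bias fixes bit_width = 32 on the quantizer class itself, so the layer
# no longer passes an accumulator bit-width to bias_quant on every forward.
# An input quantizer is enabled here because Int32Bias still requires an
# input scale: the bias scale is input scale * weight scale.
fc = QuantLinear(
    in_features=10,
    out_features=5,
    bias=True,
    input_quant=Int8ActPerTensorFloat,
    bias_quant=Int32Bias)

y = fc(torch.rand(3, 10))
```

The base `IntBias` class still works when a bit-width is supplied some other way, but code that relied on the removed `requires_input_bit_width=True` path (e.g. the `IntBias` defaults in the example models and tests above) should migrate to an explicit width such as `Int32Bias`, as this diff itself does.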