From f05ef84b1579925bf45e59b151c4ba4a03806f8a Mon Sep 17 00:00:00 2001 From: Ian Colbert <88047104+i-colbert@users.noreply.github.com> Date: Fri, 26 Jan 2024 15:48:58 +0000 Subject: [PATCH] Feat (examples/a2q+): new super resolution models (#811) * Feat (a2q+): adding to super_res example * Updating links to pre-trained checkpoints --- .../super_resolution/README.md | 2 + .../super_resolution/models/__init__.py | 46 +++++++++++++--- .../super_resolution/models/common.py | 10 +++- .../super_resolution/models/espcn.py | 4 +- .../super_resolution/utils/evaluate.py | 55 +++++++++++++++++-- 5 files changed, 98 insertions(+), 19 deletions(-) diff --git a/src/brevitas_examples/super_resolution/README.md b/src/brevitas_examples/super_resolution/README.md index 1c73873e5..5a30efef2 100644 --- a/src/brevitas_examples/super_resolution/README.md +++ b/src/brevitas_examples/super_resolution/README.md @@ -20,10 +20,12 @@ Note that this is a difference from many academic works that train only on the Y | [quant_espcn_x2_w8a8_base](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w8a8_base-f761e4a1.pth) | x2 | int8 | (u)int8 | 30.96 | | [quant_espcn_x2_w8a8_a2q_32b](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w8a8_a2q_32b-85470d9b.pth) | x2 | int8 | (u)int8 | 30.79 | | [quant_espcn_x2_w8a8_a2q_16b](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w8a8_a2q_16b-f9e1da66.pth) | x2 | int8 | (u)int8 | 30.56 | +| [quant_espcn_x2_w8a8_a2q_plus_16b](https://github.com/Xilinx/brevitas/releases/download/super_res_r2/quant_espcn_x2_w8a8_a2q_plus_16b-0ddf46f1.pth) | x2 | int8 | (u)int8 | 31.24 | || | [quant_espcn_x2_w4a4_base](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w4a4_base-80658e6d.pth) | x2 | int4 | (u)int4 | 30.30 | | [quant_espcn_x2_w4a4_a2q_32b](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w4a4_a2q_32b-8702a412.pth) | x2 | int4 | (u)int4 | 30.27 | | [quant_espcn_x2_w4a4_a2q_13b](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w4a4_a2q_13b-9fff234e.pth) | x2 | int4 | (u)int4 | 30.24 | +| [quant_espcn_x2_w4a4_a2q_plus_13b](https://github.com/Xilinx/brevitas/releases/download/super_res_r2/quant_espcn_x2_w4a4_a2q_plus_13b-6e6d55f0.pth) | x2 | int4 | (u)int4 | 30.95 | ## Train diff --git a/src/brevitas_examples/super_resolution/models/__init__.py b/src/brevitas_examples/super_resolution/models/__init__.py index 6d533fccb..6af691038 100644 --- a/src/brevitas_examples/super_resolution/models/__init__.py +++ b/src/brevitas_examples/super_resolution/models/__init__.py @@ -7,6 +7,7 @@ from torch import hub import torch.nn as nn +from .common import CommonIntAccumulatorAwareZeroCenterWeightQuant from .espcn import * model_impl = { @@ -43,18 +44,45 @@ upscale_factor=2, weight_bit_width=4, act_bit_width=4, - acc_bit_width=13)} + acc_bit_width=13), + 'quant_espcn_x2_w4a4_a2q_plus_13b': + partial( + quant_espcn, + upscale_factor=2, + weight_bit_width=4, + act_bit_width=4, + acc_bit_width=13, + weight_quant=CommonIntAccumulatorAwareZeroCenterWeightQuant), + 'quant_espcn_x2_w8a8_a2q_plus_16b': + partial( + quant_espcn, + upscale_factor=2, + weight_bit_width=8, + act_bit_width=8, + acc_bit_width=16, + weight_quant=CommonIntAccumulatorAwareZeroCenterWeightQuant)} -root_url = 'https://github.com/Xilinx/brevitas/releases/download/super_res_r1' +root_url = 'https://github.com/Xilinx/brevitas/releases/download/' model_url = { - 'float_espcn_x2': f'{root_url}/float_espcn_x2-2f85a454.pth', - 'quant_espcn_x2_w4a4_a2q_13b': f'{root_url}/quant_espcn_x2_w4a4_a2q_13b-9fff234e.pth', - 'quant_espcn_x2_w4a4_a2q_32b': f'{root_url}/quant_espcn_x2_w4a4_a2q_32b-8702a412.pth', - 'quant_espcn_x2_w4a4_base': f'{root_url}/quant_espcn_x2_w4a4_base-80658e6d.pth', - 'quant_espcn_x2_w8a8_a2q_16b': f'{root_url}/quant_espcn_x2_w8a8_a2q_16b-f9e1da66.pth', - 'quant_espcn_x2_w8a8_a2q_32b': f'{root_url}/quant_espcn_x2_w8a8_a2q_32b-85470d9b.pth', - 'quant_espcn_x2_w8a8_base': f'{root_url}/quant_espcn_x2_w8a8_base-f761e4a1.pth'} + 'float_espcn_x2': + f'{root_url}/super_res_r1/float_espcn_x2-2f85a454.pth', + 'quant_espcn_x2_w4a4_a2q_13b': + f'{root_url}/super_res_r1/quant_espcn_x2_w4a4_a2q_13b-9fff234e.pth', + 'quant_espcn_x2_w4a4_a2q_32b': + f'{root_url}/super_res_r1/quant_espcn_x2_w4a4_a2q_32b-8702a412.pth', + 'quant_espcn_x2_w4a4_base': + f'{root_url}/super_res_r1/quant_espcn_x2_w4a4_base-80658e6d.pth', + 'quant_espcn_x2_w8a8_a2q_16b': + f'{root_url}/super_res_r1/quant_espcn_x2_w8a8_a2q_16b-f9e1da66.pth', + 'quant_espcn_x2_w8a8_a2q_32b': + f'{root_url}/super_res_r1/quant_espcn_x2_w8a8_a2q_32b-85470d9b.pth', + 'quant_espcn_x2_w8a8_base': + f'{root_url}/super_res_r1/quant_espcn_x2_w8a8_base-f761e4a1.pth', + 'quant_espcn_x2_w4a4_a2q_plus_13b': + f'{root_url}/super_res_r2/quant_espcn_x2_w4a4_a2q_plus_13b-6e6d55f0.pth', + 'quant_espcn_x2_w8a8_a2q_plus_16b': + f'{root_url}/super_res_r2/quant_espcn_x2_w8a8_a2q_plus_16b-0ddf46f1.pth'} def get_model_by_name(name: str, pretrained: bool = False) -> Union[FloatESPCN, QuantESPCN]: diff --git a/src/brevitas_examples/super_resolution/models/common.py b/src/brevitas_examples/super_resolution/models/common.py index 16ba143c5..b9ac6d3fc 100644 --- a/src/brevitas_examples/super_resolution/models/common.py +++ b/src/brevitas_examples/super_resolution/models/common.py @@ -12,6 +12,7 @@ import brevitas.nn as qnn from brevitas.nn.quant_layer import WeightQuantType from brevitas.quant import Int8AccumulatorAwareWeightQuant +from brevitas.quant import Int8AccumulatorAwareZeroCenterWeightQuant from brevitas.quant import Int8ActPerTensorFloat from brevitas.quant import Int8WeightPerTensorFloat from brevitas.quant import Uint8ActPerTensorFloat @@ -26,9 +27,14 @@ class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat): class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant): + """A2Q: Accumulator-Aware Quantization with Guaranteed Overflow Avoidance""" restrict_scaling_impl = FloatRestrictValue # backwards compatibility - pre_scaling_min_val = 1e-10 - scaling_min_val = 1e-10 + bit_width = None + + +class CommonIntAccumulatorAwareZeroCenterWeightQuant(Int8AccumulatorAwareZeroCenterWeightQuant): + """A2Q+: Improving Accumulator-Aware Weight Quantization""" + bit_width = None class CommonIntActQuant(Int8ActPerTensorFloat): diff --git a/src/brevitas_examples/super_resolution/models/espcn.py b/src/brevitas_examples/super_resolution/models/espcn.py index f3123fcb9..2ec6226a8 100644 --- a/src/brevitas_examples/super_resolution/models/espcn.py +++ b/src/brevitas_examples/super_resolution/models/espcn.py @@ -164,7 +164,7 @@ def float_espcn(upscale_factor: int, num_channels: int = 3) -> FloatESPCN: def quant_espcn( - upcsale_factor: int, + upscale_factor: int, num_channels: int = 3, weight_bit_width: int = 8, act_bit_width: int = 8, @@ -172,7 +172,7 @@ def quant_espcn( weight_quant: WeightQuantType = CommonIntWeightPerChannelQuant) -> QuantESPCN: """ """ return QuantESPCN( - upscale_factor=upcsale_factor, + upscale_factor=upscale_factor, num_channels=num_channels, act_bit_width=act_bit_width, acc_bit_width=acc_bit_width, diff --git a/src/brevitas_examples/super_resolution/utils/evaluate.py b/src/brevitas_examples/super_resolution/utils/evaluate.py index 2d317e9dc..2eb9e3627 100644 --- a/src/brevitas_examples/super_resolution/utils/evaluate.py +++ b/src/brevitas_examples/super_resolution/utils/evaluate.py @@ -5,12 +5,52 @@ from torch import Tensor import torch.nn as nn +from brevitas.core.scaling import AccumulatorAwareParameterPreScaling +from brevitas.core.scaling import AccumulatorAwareZeroCenterParameterPreScaling import brevitas.nn as qnn from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL EPS = 1e-10 +def _get_a2q_module(module: nn.Module): + for submod in module.modules(): + if isinstance(submod, AccumulatorAwareParameterPreScaling): + return submod + return None + + +def _calc_a2q_acc_bit_width( + weight_max_l1_norm: Tensor, input_bit_width: Tensor, input_is_signed: bool): + """Using the closed-form bounds on accumulator bit-width as derived in + `A2Q: Accumulator-Aware Quantization with Guaranteed Overflow Avoidance`. + This function returns the minimum accumulator bit-width that can be used + without risk of overflow.""" + assert weight_max_l1_norm.numel() == 1 + input_is_signed = float(input_is_signed) + weight_max_l1_norm = torch.clamp_min(weight_max_l1_norm, EPS) + alpha = torch.log2(weight_max_l1_norm) + input_bit_width - input_is_signed + phi = lambda x: torch.log2(1. + pow(2., -x)) + min_bit_width = alpha + phi(alpha) + 1. + min_bit_width = torch.ceil(min_bit_width) + return min_bit_width + + +def _calc_a2q_plus_acc_bit_width( + weight_max_l1_norm: Tensor, input_bit_width: Tensor, input_is_signed: bool): + """Using the closed-form bounds on accumulator bit-width as derived in `A2Q+: + Improving Accumulator-Aware Weight Quantization`. This function returns the + minimum accumulator bit-width that can be used without risk of overflow, + assuming that the floating-point weights are zero-centered.""" + input_is_signed = float(input_is_signed) + assert weight_max_l1_norm.numel() == 1 + weight_max_l1_norm = torch.clamp_min(weight_max_l1_norm, EPS) + input_range = pow(2., input_bit_width) - 1. # 2^N - 1. + min_bit_width = torch.log2(weight_max_l1_norm * input_range + 2.) + min_bit_width = torch.ceil(min_bit_width) + return min_bit_width + + def _calc_min_acc_bit_width(module: QuantWBIOL) -> Tensor: assert isinstance(module, qnn.QuantConv2d), "Error: function only support QuantConv2d." @@ -24,12 +64,15 @@ def _calc_min_acc_bit_width(module: QuantWBIOL) -> Tensor: quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(1, 2, 3)) # using the closed-form bounds on accumulator bit-width - weight_max_l1_norm = quant_weight_per_channel_l1_norm.max() - weight_max_l1_norm = torch.clamp_min(weight_max_l1_norm, EPS) - alpha = torch.log2(weight_max_l1_norm) + input_bit_width - input_is_signed - phi = lambda x: torch.log2(1. + pow(2., -x)) - min_bit_width = alpha + phi(alpha) + 1. - min_bit_width = torch.ceil(min_bit_width) + min_bit_width = _calc_a2q_acc_bit_width( + quant_weight_per_channel_l1_norm.max(), + input_bit_width=input_bit_width, + input_is_signed=input_is_signed) + if isinstance(_get_a2q_module(module), AccumulatorAwareZeroCenterParameterPreScaling): + min_bit_width = _calc_a2q_plus_acc_bit_width( + quant_weight_per_channel_l1_norm.max(), + input_bit_width=input_bit_width, + input_is_signed=input_is_signed) return min_bit_width