From 8180de607025530c965cfc8de956ccd2e570d77d Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Aug 2024 11:33:36 +0100 Subject: [PATCH 01/16] Fix MX --- notebooks/minifloat_mx_tutorial.ipynb | 6 +++--- src/brevitas/core/function_wrapper/shape.py | 20 +++++++++++++++++++ src/brevitas/core/quant/float.py | 4 +++- src/brevitas/core/quant/int.py | 8 +++++++- src/brevitas/core/scaling/runtime.py | 13 +++--------- .../proxy/groupwise_float_parameter_quant.py | 6 ------ .../proxy/groupwise_float_runtime_quant.py | 1 - .../proxy/groupwise_int_parameter_quant.py | 1 - src/brevitas/quant/solver/act.py | 13 +++++++++++- src/brevitas/quant/solver/parameter.py | 10 ++++++++++ src/brevitas/quant/solver/weight.py | 6 +++++- .../groupwise_float_quant_tensor.py | 8 +++++--- .../groupwise_int_quant_tensor.py | 7 ++++--- 13 files changed, 72 insertions(+), 31 deletions(-) diff --git a/notebooks/minifloat_mx_tutorial.ipynb b/notebooks/minifloat_mx_tutorial.ipynb index e764fd05c..60f00fcd4 100644 --- a/notebooks/minifloat_mx_tutorial.ipynb +++ b/notebooks/minifloat_mx_tutorial.ipynb @@ -233,12 +233,12 @@ "import brevitas.nn as qnn\n", "import torch\n", "\n", - "class MXFloat8Weight(MXInt8Weight):\n", + "class MXInt8Weight(MXInt8Weight):\n", " # The group dimension for the weights it is automatically identified based on the layer type\n", " # If a new layer type is used, it can be manually specified\n", " bit_width = 8\n", "\n", - "class MXFloat8Act(MXInt8Act):\n", + "class MXInt8Act(MXInt8Act):\n", " # It is necessary to specify the group dimension for the activation quantization\n", " group_dim = 1\n", " bit_width = 8\n", @@ -246,7 +246,7 @@ "class MXModel(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", - " self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=MXFloat8Weight, input_quant=MXFloat8Act)\n", + " self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=MXInt8Weight, input_quant=MXInt8Act)\n", " \n", " def forward(self, x):\n", " return self.conv(x)\n", diff --git a/src/brevitas/core/function_wrapper/shape.py b/src/brevitas/core/function_wrapper/shape.py index cdef81b3e..1bb77476e 100644 --- a/src/brevitas/core/function_wrapper/shape.py +++ b/src/brevitas/core/function_wrapper/shape.py @@ -171,6 +171,25 @@ def forward(self, x: torch.Tensor): return y +class DynamicOverSubChannelBlockView(brevitas.jit.ScriptModule): + __constants__ = ['group_size', 'group_dim'] + + def __init__(self, group_size, group_dim) -> None: + super(DynamicOverSubChannelBlockView, self).__init__() + self.group_size = group_size + self.group_dim = group_dim + + @brevitas.jit.script_method + def forward(self, x): + tensor_shape = x.shape + tensor_shape_list = list(tensor_shape) + tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size) + block_dim = self.group_dim + 1 if self.group_dim != -1 else -1 + tensor_shape_list.insert(block_dim, self.group_size) + x = x.view(tensor_shape_list) + return x + + class StatsInputViewShapeImpl(object): """ Enum-like object to collect pointers to variants of ScriptModules that perform a view on a tensor. @@ -182,3 +201,4 @@ class StatsInputViewShapeImpl(object): OVER_BATCH_OVER_OUTPUT_CHANNELS = OverBatchOverOutputChannelView OVER_OUTPUT_FEATURES = OverOutputFeaturesView OVER_SUBCHANNEL_BLOCK = OverSubChannelBlockView + DYNAMIC_OVER_SUBCHANNEL_BLOCK = DynamicOverSubChannelBlockView diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 195d42a96..f4fd79f1a 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -24,6 +24,7 @@ def __init__( mantissa_bit_width: int, exponent_bias: int, float_clamp_impl: nn.Module, + input_view_impl: nn.Module, scaling_impl: Optional[nn.Module] = None, float_scaling_impl: Optional[nn.Module] = None, float_to_int_impl: nn.Module = RoundSte(), @@ -52,6 +53,7 @@ def __init__( if scaling_impl is None: scaling_impl = ConstScaling(1., device=device, dtype=dtype) + self.input_view_impl = input_view_impl # Zero-point is currently hardcoded to 0 self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype)) self.float_scaling_impl = float_scaling_impl @@ -71,7 +73,7 @@ def quantize(self, x: torch.Tensor): float_scaling_impl_value = self.float_scaling_impl( self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()) scale = scale / float_scaling_impl_value - + x = self.input_view_impl(x) scaled_x = x / scale internal_scale = float_internal_scale( scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min(), self.eps) diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index cdb75df74..9c29ecdd2 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -135,6 +135,7 @@ class RescalingIntQuant(brevitas.jit.ScriptModule): def __init__( self, int_quant: Module, + input_view_impl: Module, scaling_impl: Module, int_scaling_impl: Module, zero_point_impl: Module, @@ -145,6 +146,7 @@ def __init__( self.int_scaling_impl = int_scaling_impl self.zero_point_impl = zero_point_impl self.msb_clamp_bit_width_impl = bit_width_impl + self.input_view_impl = input_view_impl @brevitas.jit.script_method def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: @@ -153,6 +155,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: int_threshold = self.int_scaling_impl(bit_width) scale = threshold / int_threshold zero_point = self.zero_point_impl(x, scale, bit_width) + x = self.input_view_impl(x) y = self.int_quant(scale, zero_point, bit_width, x) return y, scale, zero_point, bit_width @@ -167,7 +170,8 @@ def __init__( int_scaling_impl: Module, pre_zero_point_impl: Module, zero_point_impl: Module, - bit_width_impl: Module): + bit_width_impl: Module, + input_view_impl: Module): super(DecoupledRescalingIntQuant, self).__init__() self.decoupled_int_quant = decoupled_int_quant self.pre_scaling_impl = pre_scaling_impl @@ -176,6 +180,7 @@ def __init__( self.pre_zero_point_impl = pre_zero_point_impl self.zero_point_impl = zero_point_impl self.msb_clamp_bit_width_impl = bit_width_impl + self.input_view_impl = input_view_impl @brevitas.jit.script_method def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: @@ -184,6 +189,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Te pre_threshold = self.pre_scaling_impl(x) pre_scale = pre_threshold / int_threshold pre_zero_point = self.pre_zero_point_impl(x, pre_scale, bit_width) + x = self.input_view_impl(x) threshold = self.scaling_impl(x) scale = threshold / int_threshold zero_point = self.zero_point_impl(x, scale, bit_width) diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py index 23707344f..eda69ba10 100644 --- a/src/brevitas/core/scaling/runtime.py +++ b/src/brevitas/core/scaling/runtime.py @@ -166,6 +166,7 @@ def __init__( self, group_size: int, group_dim: int, + input_view_impl: torch.nn.Module, scaling_stats_impl: torch.nn.Module, scaling_min_val: Optional[float], restrict_scaling_impl: Optional[torch.nn.Module]) -> None: @@ -174,21 +175,13 @@ def __init__( self.group_dim = group_dim self.scaling_stats_impl = scaling_stats_impl self.scaling_min_val = scaling_min_val + self.input_view_impl = input_view_impl self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) - @brevitas.jit.script_method - def group_scaling_reshape(self, stats_input): - tensor_shape = stats_input.shape - tensor_shape_list = list(tensor_shape) - tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size) - block_dim = self.group_dim + 1 if self.group_dim != -1 else -1 - tensor_shape_list.insert(block_dim, self.group_size) - stats_input = stats_input.view(tensor_shape_list) - return stats_input @brevitas.jit.script_method def forward(self, stats_input) -> torch.Tensor: - stats_input_reshaped = self.group_scaling_reshape(stats_input) + stats_input_reshaped = self.input_view_impl(stats_input) out = self.scaling_stats_impl(stats_input_reshaped) # Scaling min val out = self.restrict_clamp_scaling(out) diff --git a/src/brevitas/proxy/groupwise_float_parameter_quant.py b/src/brevitas/proxy/groupwise_float_parameter_quant.py index cd38d9906..d08033f8e 100644 --- a/src/brevitas/proxy/groupwise_float_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_float_parameter_quant.py @@ -9,11 +9,6 @@ class GroupwiseWeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # TODO: Is this always generated? - self.view_impl = self.quant_injector.scaling_stats_input_view_shape_impl - @property def group_dim(self): return self.quant_injector.group_dim @@ -25,7 +20,6 @@ def group_size(self): def forward(self, x: torch.Tensor) -> Union[Tensor, GroupwiseFloatQuantTensor]: if self.is_quant_enabled: impl = self.export_handler if self.export_mode else self.tensor_quant - x = self.view_impl(x) out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x) return GroupwiseFloatQuantTensor( out, diff --git a/src/brevitas/proxy/groupwise_float_runtime_quant.py b/src/brevitas/proxy/groupwise_float_runtime_quant.py index 4ab182d20..b2aad4729 100644 --- a/src/brevitas/proxy/groupwise_float_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_float_runtime_quant.py @@ -24,7 +24,6 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, GroupwiseFloat y = x if isinstance(y, QuantTensor): y = y.value - if self.export_mode: y = self.fused_activation_quant_proxy.activation_impl(y) y = self.export_handler(y) diff --git a/src/brevitas/proxy/groupwise_int_parameter_quant.py b/src/brevitas/proxy/groupwise_int_parameter_quant.py index 035ee9729..35892daeb 100644 --- a/src/brevitas/proxy/groupwise_int_parameter_quant.py +++ b/src/brevitas/proxy/groupwise_int_parameter_quant.py @@ -25,7 +25,6 @@ def group_size(self): def forward(self, x: torch.Tensor) -> Union[Tensor, GroupwiseIntQuantTensor]: if self.is_quant_enabled: impl = self.export_handler if self.export_mode else self.tensor_quant - x = self.view_impl(x) out, scale, zero_point, bit_width = impl(x) return GroupwiseIntQuantTensor( out, diff --git a/src/brevitas/quant/solver/act.py b/src/brevitas/quant/solver/act.py index 3149e75b9..e7c09869f 100644 --- a/src/brevitas/quant/solver/act.py +++ b/src/brevitas/quant/solver/act.py @@ -1,6 +1,8 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from brevitas.core.function_wrapper.misc import Identity +from brevitas.core.function_wrapper.shape import StatsInputViewShapeImpl import torch from torch import nn from torch import Tensor @@ -128,6 +130,14 @@ def update_state_dict_impl(scaling_impl_type): return None +class SolveInputViewImpl(ExtendedInjector): + @value + def input_view_impl(scaling_per_output): + if scaling_per_output == ScalingPerOutputType.GROUP: + return StatsInputViewShapeImpl.DYNAMIC_OVER_SUBCHANNEL_BLOCK + else: + return Identity + class ActQuantSolver(SolveActTensorQuantFromEnum, SolveActScalingImplFromEnum, SolveIntScalingImplFromEnum, @@ -140,7 +150,8 @@ class ActQuantSolver(SolveActTensorQuantFromEnum, SolveActScalingShape, SolveScalingStatsInputViewShapeImplFromEnum, SolveActScalingPerOutputChannelShape, - SolveUpdateStateDictImplFromEnum): + SolveUpdateStateDictImplFromEnum, + SolveInputViewImpl): """ Translate enum directives to activation-specific quantization core modules. It should be placed last in the list of classes a quantizer inherits from, diff --git a/src/brevitas/quant/solver/parameter.py b/src/brevitas/quant/solver/parameter.py index d8c655efa..38d4990df 100644 --- a/src/brevitas/quant/solver/parameter.py +++ b/src/brevitas/quant/solver/parameter.py @@ -4,6 +4,8 @@ import math from typing import List +from brevitas.core.function_wrapper.misc import Identity +from brevitas.core.function_wrapper.shape import StatsInputViewShapeImpl from dependencies import this from dependencies import value import torch @@ -139,3 +141,11 @@ def expanded_scaling_shape(module, group_size=None): def group_dim(module, group_size=None): if group_size is not None: return 1 + +class SolveInputViewImpl(ExtendedInjector): + @value + def input_view_impl(scaling_per_output): + if scaling_per_output == ScalingPerOutputType.GROUP: + return StatsInputViewShapeImpl.OVER_SUBCHANNEL_BLOCK + else: + return Identity diff --git a/src/brevitas/quant/solver/weight.py b/src/brevitas/quant/solver/weight.py index 097f65443..54ac52ee8 100644 --- a/src/brevitas/quant/solver/weight.py +++ b/src/brevitas/quant/solver/weight.py @@ -8,6 +8,7 @@ from brevitas.inject import this from brevitas.inject import value from brevitas.proxy import WeightQuantProxyFromInjector +from brevitas.quant.solver.parameter import SolveInputViewImpl from brevitas.quant.solver.common import * from brevitas.quant.solver.parameter import * @@ -68,6 +69,8 @@ def scaling_stats_input_concat_dim(scaling_per_output): return 0 elif scaling_per_output == ScalingPerOutputType.CHANNEL: return 1 + else: + raise RuntimeError("Shared groupwise quantization is not supported") @value def permute_dims(module, output_channel_dim): @@ -103,7 +106,8 @@ class WeightQuantSolver(SolveStatsReduceDimFromEnum, SolveParameterScalingShape, SolveWeightScalingPerOutputChannelShapeFromModule, SolveWeightTensorQuantFromEnum, - SolveDtypeDeviceFromTrackedParameterList): + SolveDtypeDeviceFromTrackedParameterList, + SolveInputViewImpl): """ Translate enum and shape directives to weight-specific quantization core modules. It should be placed last in the list of classes a quantizer inherits from, diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py index 7d73bf7de..e0b7eab0e 100644 --- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py @@ -90,13 +90,15 @@ def __torch_function__(self, func, types, args=(), kwargs=None): def expand(self): curr_shape = self.value_.shape - new_value = self.value_.flatten(self.group_dim, self.group_dim + 1) + start_dim = self.group_dim if self.group_dim != -1 else -2 + new_value = self.value_.flatten(start_dim, start_dim + 1) + new_value = self.value_.flatten(start_dim, start_dim + 1) if self.scale_.shape != (): - new_scale = self.scale_.expand(curr_shape).flatten(self.group_dim, self.group_dim + 1) + new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1) else: new_scale = self.scale_ if self.zero_point_.shape != (): - new_zp = self.zero_point_.expand(curr_shape).flatten(self.group_dim, self.group_dim + 1) + new_zp = self.zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1) else: new_zp = self.zero_point_ diff --git a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py index 976e86130..21068c8c9 100644 --- a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py @@ -59,13 +59,14 @@ def __torch_function__(self, func, types, args=(), kwargs=None): def expand(self): curr_shape = self.value_.shape - new_value = self.value_.flatten(self.group_dim, self.group_dim + 1) + start_dim = self.group_dim if self.group_dim != -1 else -2 + new_value = self.value_.flatten(start_dim, start_dim + 1) if self.scale_.shape != (): - new_scale = self.scale_.expand(curr_shape).flatten(self.group_dim, self.group_dim + 1) + new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1) else: new_scale = self.scale_ if self.zero_point_.shape != (): - new_zp = self.zero_point_.expand(curr_shape).flatten(self.group_dim, self.group_dim + 1) + new_zp = self.zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1) else: new_zp = self.zero_point_ From 615524152ffcade36e4cf1dbc6f8efae592200f2 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Aug 2024 12:30:06 +0100 Subject: [PATCH 02/16] Fix Bias --- src/brevitas/quant/solver/bias.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/brevitas/quant/solver/bias.py b/src/brevitas/quant/solver/bias.py index 33a55f55c..eb840541e 100644 --- a/src/brevitas/quant/solver/bias.py +++ b/src/brevitas/quant/solver/bias.py @@ -9,6 +9,7 @@ from brevitas.proxy import BiasQuantProxyFromInjector from brevitas.quant.solver.common import * from brevitas.quant.solver.parameter import * +from brevitas.quant.solver.parameter import SolveInputViewImpl __all__ = [ 'BiasQuantSolver', @@ -65,7 +66,8 @@ class BiasQuantSolver(SolveScalingStatsInputViewShapeImplFromEnum, SolveBiasScalingPerOutputChannelShapeFromModule, SolveBiasScalingStatsInputConcatDimFromModule, SolveBiasTensorQuantFromEnum, - SolveDtypeDeviceFromTrackedParameterList): + SolveDtypeDeviceFromTrackedParameterList, + SolveInputViewImpl): """ Translate enum directives to bias-specific quantization core modules. It should be placed last in the list of classes a quantizer inherits from, From 2ee59f4487d22063fc40b1760796796f47c5dce8 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Aug 2024 12:36:57 +0100 Subject: [PATCH 03/16] Moved one level down --- src/brevitas/core/quant/int.py | 5 ----- src/brevitas/core/quant/int_base.py | 6 ++++++ 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index 9c29ecdd2..ff3e16a5b 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -135,7 +135,6 @@ class RescalingIntQuant(brevitas.jit.ScriptModule): def __init__( self, int_quant: Module, - input_view_impl: Module, scaling_impl: Module, int_scaling_impl: Module, zero_point_impl: Module, @@ -146,7 +145,6 @@ def __init__( self.int_scaling_impl = int_scaling_impl self.zero_point_impl = zero_point_impl self.msb_clamp_bit_width_impl = bit_width_impl - self.input_view_impl = input_view_impl @brevitas.jit.script_method def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: @@ -155,7 +153,6 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: int_threshold = self.int_scaling_impl(bit_width) scale = threshold / int_threshold zero_point = self.zero_point_impl(x, scale, bit_width) - x = self.input_view_impl(x) y = self.int_quant(scale, zero_point, bit_width, x) return y, scale, zero_point, bit_width @@ -180,7 +177,6 @@ def __init__( self.pre_zero_point_impl = pre_zero_point_impl self.zero_point_impl = zero_point_impl self.msb_clamp_bit_width_impl = bit_width_impl - self.input_view_impl = input_view_impl @brevitas.jit.script_method def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: @@ -189,7 +185,6 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Te pre_threshold = self.pre_scaling_impl(x) pre_scale = pre_threshold / int_threshold pre_zero_point = self.pre_zero_point_impl(x, pre_scale, bit_width) - x = self.input_view_impl(x) threshold = self.scaling_impl(x) scale = threshold / int_threshold zero_point = self.zero_point_impl(x, scale, bit_width) diff --git a/src/brevitas/core/quant/int_base.py b/src/brevitas/core/quant/int_base.py index 7a7a0f828..338e5a433 100644 --- a/src/brevitas/core/quant/int_base.py +++ b/src/brevitas/core/quant/int_base.py @@ -51,6 +51,7 @@ def __init__( self, narrow_range: bool, signed: bool, + input_view_impl: Module, float_to_int_impl: Module = RoundSte(), tensor_clamp_impl: Module = TensorClamp(), quant_delay_steps: int = 0): @@ -60,9 +61,11 @@ def __init__( self.signed = signed self.narrow_range = narrow_range self.delay_wrapper = DelayWrapper(quant_delay_steps) + self.input_view_impl = input_view_impl @brevitas.jit.script_method def to_int(self, scale: Tensor, zero_point: Tensor, bit_width: Tensor, x: Tensor) -> Tensor: + x = self.input_view_impl(x) y = x / scale y = y + zero_point min_int_val = self.min_int(bit_width) @@ -124,6 +127,7 @@ def __init__( self, narrow_range: bool, signed: bool, + input_view_impl: Module, float_to_int_impl: Module = RoundSte(), tensor_clamp_impl: Module = TensorClamp(), quant_delay_steps: int = 0): @@ -133,11 +137,13 @@ def __init__( self.signed = signed self.narrow_range = narrow_range self.delay_wrapper = DelayWrapper(quant_delay_steps) + self.input_view_impl = input_view_impl @brevitas.jit.script_method def to_int( self, pre_scale: Tensor, pre_zero_point: Tensor, bit_width: Tensor, x: Tensor) -> Tensor: + x = self.input_view_impl(x) y = x / pre_scale y = y + pre_zero_point min_int_val = self.min_int(bit_width) From d6019493477e5ebefaa814b687cd037e1c45d70b Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Aug 2024 12:47:58 +0100 Subject: [PATCH 04/16] fix for bias v2 --- src/brevitas/quant/solver/common.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/brevitas/quant/solver/common.py b/src/brevitas/quant/solver/common.py index 61eccc90b..2847275e8 100644 --- a/src/brevitas/quant/solver/common.py +++ b/src/brevitas/quant/solver/common.py @@ -195,8 +195,6 @@ def scaling_per_output(scaling_per_output_type=None, scaling_per_output_channel= return ScalingPerOutputType.CHANNEL if scaling_per_output_channel else ScalingPerOutputType.TENSOR elif scaling_per_output_type is not None: return scaling_per_output_type - else: - raise RuntimeError("Specify scaling_per_output_type or scaling_per_output_channel") class SolveScalingStatsInputViewShapeImplFromEnum(ExtendedInjector): From f9578518e9a8343aa0c1cd0bb81a7b41d335dcde Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Aug 2024 12:51:44 +0100 Subject: [PATCH 05/16] precommit and remove raise --- src/brevitas/core/scaling/runtime.py | 1 - src/brevitas/quant/solver/act.py | 6 ++++-- src/brevitas/quant/solver/parameter.py | 6 ++++-- src/brevitas/quant/solver/weight.py | 4 +--- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py index eda69ba10..e4333186d 100644 --- a/src/brevitas/core/scaling/runtime.py +++ b/src/brevitas/core/scaling/runtime.py @@ -178,7 +178,6 @@ def __init__( self.input_view_impl = input_view_impl self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) - @brevitas.jit.script_method def forward(self, stats_input) -> torch.Tensor: stats_input_reshaped = self.input_view_impl(stats_input) diff --git a/src/brevitas/quant/solver/act.py b/src/brevitas/quant/solver/act.py index e7c09869f..345239089 100644 --- a/src/brevitas/quant/solver/act.py +++ b/src/brevitas/quant/solver/act.py @@ -1,12 +1,12 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -from brevitas.core.function_wrapper.misc import Identity -from brevitas.core.function_wrapper.shape import StatsInputViewShapeImpl import torch from torch import nn from torch import Tensor +from brevitas.core.function_wrapper.misc import Identity +from brevitas.core.function_wrapper.shape import StatsInputViewShapeImpl from brevitas.core.quant import ClampedBinaryQuant from brevitas.core.quant import RescalingIntQuant from brevitas.core.quant import TernaryQuant @@ -131,6 +131,7 @@ def update_state_dict_impl(scaling_impl_type): class SolveInputViewImpl(ExtendedInjector): + @value def input_view_impl(scaling_per_output): if scaling_per_output == ScalingPerOutputType.GROUP: @@ -138,6 +139,7 @@ def input_view_impl(scaling_per_output): else: return Identity + class ActQuantSolver(SolveActTensorQuantFromEnum, SolveActScalingImplFromEnum, SolveIntScalingImplFromEnum, diff --git a/src/brevitas/quant/solver/parameter.py b/src/brevitas/quant/solver/parameter.py index 38d4990df..97137c567 100644 --- a/src/brevitas/quant/solver/parameter.py +++ b/src/brevitas/quant/solver/parameter.py @@ -4,8 +4,6 @@ import math from typing import List -from brevitas.core.function_wrapper.misc import Identity -from brevitas.core.function_wrapper.shape import StatsInputViewShapeImpl from dependencies import this from dependencies import value import torch @@ -14,6 +12,8 @@ from brevitas.core.bit_width import * from brevitas.core.function_wrapper import TensorClamp from brevitas.core.function_wrapper import TensorClampSte +from brevitas.core.function_wrapper.misc import Identity +from brevitas.core.function_wrapper.shape import StatsInputViewShapeImpl from brevitas.core.scaling import * from brevitas.core.scaling import ScalingImplType from brevitas.core.scaling import ScalingPerOutputType @@ -142,7 +142,9 @@ def group_dim(module, group_size=None): if group_size is not None: return 1 + class SolveInputViewImpl(ExtendedInjector): + @value def input_view_impl(scaling_per_output): if scaling_per_output == ScalingPerOutputType.GROUP: diff --git a/src/brevitas/quant/solver/weight.py b/src/brevitas/quant/solver/weight.py index 54ac52ee8..7f63fe17e 100644 --- a/src/brevitas/quant/solver/weight.py +++ b/src/brevitas/quant/solver/weight.py @@ -8,9 +8,9 @@ from brevitas.inject import this from brevitas.inject import value from brevitas.proxy import WeightQuantProxyFromInjector -from brevitas.quant.solver.parameter import SolveInputViewImpl from brevitas.quant.solver.common import * from brevitas.quant.solver.parameter import * +from brevitas.quant.solver.parameter import SolveInputViewImpl __all__ = [ 'SolveWeightTensorQuantFromEnum', @@ -69,8 +69,6 @@ def scaling_stats_input_concat_dim(scaling_per_output): return 0 elif scaling_per_output == ScalingPerOutputType.CHANNEL: return 1 - else: - raise RuntimeError("Shared groupwise quantization is not supported") @value def permute_dims(module, output_channel_dim): From bd921b1f261b9b4ac5d4ba7ce870ee323d9928c7 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Aug 2024 12:56:27 +0100 Subject: [PATCH 06/16] fix notebook --- notebooks/Brevitas_TVMCon2021.ipynb | 2 ++ 1 file changed, 2 insertions(+) diff --git a/notebooks/Brevitas_TVMCon2021.ipynb b/notebooks/Brevitas_TVMCon2021.ipynb index 20ce30701..81b866269 100644 --- a/notebooks/Brevitas_TVMCon2021.ipynb +++ b/notebooks/Brevitas_TVMCon2021.ipynb @@ -1659,11 +1659,13 @@ "from brevitas.core.bit_width import BitWidthConst\n", "from brevitas.core.quant import IntQuant, RescalingIntQuant\n", "from brevitas.core.zero_point import ZeroZeroPoint\n", + "from brevitas.core.function_wrapper.misc import Identity\n", "\n", "tensor_quant = RescalingIntQuant(\n", " int_quant=IntQuant(\n", " float_to_int_impl=RoundSte(),\n", " tensor_clamp_impl=TensorClamp(),\n", + " input_view_impl=Identity,\n", " signed=False,\n", " narrow_range=False),\n", " zero_point_impl=ZeroZeroPoint(),\n", From 52d6560684d7652f4864b79a945abc7af40215d0 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Aug 2024 13:04:28 +0100 Subject: [PATCH 07/16] fix a2q dep inj --- src/brevitas/quant/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py index 18351a05b..abcf6d4f0 100644 --- a/src/brevitas/quant/base.py +++ b/src/brevitas/quant/base.py @@ -53,6 +53,7 @@ from brevitas.proxy import DecoupledWeightQuantProxyFromInjector from brevitas.proxy import DecoupledWeightQuantWithInputProxyFromInjector from brevitas.quant.solver.common import SolveStatsReduceDimFromEnum +from brevitas.quant.solver.parameter import SolveInputViewImpl from brevitas.quant.solver.parameter import SolveParameterScalingShape from brevitas.quant.solver.weight import SolveWeightScalingPerOutputChannelShapeFromModule from brevitas.quant.solver.weight import SolveWeightScalingStatsInputDimsFromModule @@ -333,7 +334,8 @@ class WeightPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum, class WeightNormPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum, SolveWeightScalingStatsInputDimsFromModule, SolveWeightScalingPerOutputChannelShapeFromModule, - SolveParameterScalingShape): + SolveParameterScalingShape, + SolveInputViewImpl): """Experimental narrow per-channel weight normalization-based signed integer quantizer based on `Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow Avoidance` by I. Colbert, A. Pappalardo, and J. Petri-Koenig. From fd323a4ae8f237c85847f7a0cb27dfb3855afd46 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Aug 2024 13:09:19 +0100 Subject: [PATCH 08/16] notebook fix --- notebooks/Brevitas_TVMCon2021.ipynb | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/notebooks/Brevitas_TVMCon2021.ipynb b/notebooks/Brevitas_TVMCon2021.ipynb index 81b866269..0453c7bd6 100644 --- a/notebooks/Brevitas_TVMCon2021.ipynb +++ b/notebooks/Brevitas_TVMCon2021.ipynb @@ -1769,6 +1769,7 @@ "from brevitas.inject import value\n", "from brevitas.proxy import WeightQuantProxyFromInjector\n", "from brevitas.core.scaling import ParameterScaling\n", + "from brevitas.core.function_wrapper.misc import Identity\n", "\n", "class Int8ActPerTensorFloatParameterFromScratch(ExtendedInjector):\n", " \n", @@ -1786,11 +1787,12 @@ " int_scaling_impl = IntScaling\n", " scaling_impl = ParameterScaling\n", " restrict_scaling_impl = FloatRestrictValue\n", + " input_view_impl = Identity\n", " scaling_shape = ()\n", " bit_width = 8\n", " narrow_range = True\n", " signed = True\n", - " \n", + "\n", "quant_linear = QuantLinear(2, 4, weight_quant=Int8ActPerTensorFloatParameterFromScratch, bias=False)" ] }, From 6eafb8aa525f6fcf9a17ba6e2fedb88aa891250f Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Aug 2024 13:10:28 +0100 Subject: [PATCH 09/16] remove unused init param --- src/brevitas/core/quant/int.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index ff3e16a5b..cdb75df74 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -167,8 +167,7 @@ def __init__( int_scaling_impl: Module, pre_zero_point_impl: Module, zero_point_impl: Module, - bit_width_impl: Module, - input_view_impl: Module): + bit_width_impl: Module): super(DecoupledRescalingIntQuant, self).__init__() self.decoupled_int_quant = decoupled_int_quant self.pre_scaling_impl = pre_scaling_impl From 6d48dd633ec1477a99d6ebbff318afa723ba4fff Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Aug 2024 13:20:52 +0100 Subject: [PATCH 10/16] Qonnx fix --- notebooks/Brevitas_TVMCon2021.ipynb | 2 +- src/brevitas/export/onnx/qonnx/function.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/notebooks/Brevitas_TVMCon2021.ipynb b/notebooks/Brevitas_TVMCon2021.ipynb index 0453c7bd6..7f5846e09 100644 --- a/notebooks/Brevitas_TVMCon2021.ipynb +++ b/notebooks/Brevitas_TVMCon2021.ipynb @@ -1940,7 +1940,7 @@ "torch.manual_seed(0)\n", "\n", "from brevitas.export import export_qonnx\n", - "from brevitas.quant import Int8WeightPerTensorFloat, Int8ActPerTensorFloat, Int16Bias\n", + "from brevitas.quant import Int8ActPerTensorFloat, Int16Bias\n", "\n", "float_inp = torch.randn(1, 2, 5)\n", "\n", diff --git a/src/brevitas/export/onnx/qonnx/function.py b/src/brevitas/export/onnx/qonnx/function.py index d410ee31e..3e7faad0e 100644 --- a/src/brevitas/export/onnx/qonnx/function.py +++ b/src/brevitas/export/onnx/qonnx/function.py @@ -7,6 +7,7 @@ from brevitas.core.bit_width import BitWidthConst from brevitas.core.function_wrapper.clamp import TensorClamp +from brevitas.core.function_wrapper.misc import Identity from brevitas.core.quant import IntQuant from brevitas.core.quant import TruncIntQuant from brevitas.function import binary_sign @@ -51,6 +52,7 @@ def forward(ctx, x, scale, zero_point, bit_width, narrow_range, signed, rounding quant = IntQuant( float_to_int_impl=float_to_int_impl(), tensor_clamp_impl=TensorClamp(), + input_view_impl=Identity(), #TODO: Update this when QONNX support Groupwise export narrow_range=narrow_range, signed=signed) y = quant(scale, zero_point, bit_width, x) From 6365b65402fe0104654e7d80fce723795e771638 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Aug 2024 13:43:06 +0100 Subject: [PATCH 11/16] last fix group_dim maybe --- src/brevitas/quant_tensor/groupwise_float_quant_tensor.py | 1 + src/brevitas/quant_tensor/groupwise_int_quant_tensor.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py index e0b7eab0e..fa91bdca1 100644 --- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py @@ -106,6 +106,7 @@ def expand(self): @staticmethod def from_expanded(value, group_size, group_dim, compress=False): + group_dim = group_dim if group_dim != -1 else -2 size = list(value.shape) assert size[group_dim] % group_size == 0, 'Input channel is not divisible by group size' if compress: diff --git a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py index 21068c8c9..082ec1234 100644 --- a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py @@ -74,6 +74,7 @@ def expand(self): @staticmethod def from_expanded(value, group_size, group_dim, compress=False): + group_dim = group_dim if group_dim != -1 else -2 size = list(value.shape) assert size[group_dim] % group_size == 0, 'Input channel is not divisible by group size' if compress: From 075eeb44681c219b2ec0ed5f5d72da7cdedbe108 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Aug 2024 13:45:45 +0100 Subject: [PATCH 12/16] fix last test --- tests/brevitas/core/test_int_quant.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/brevitas/core/test_int_quant.py b/tests/brevitas/core/test_int_quant.py index 312795235..5197b5b67 100644 --- a/tests/brevitas/core/test_int_quant.py +++ b/tests/brevitas/core/test_int_quant.py @@ -5,6 +5,7 @@ import mock import torch +from brevitas.core.function_wrapper import Identity from brevitas.core.function_wrapper import RoundSte from brevitas.core.function_wrapper import TensorClamp from brevitas.core.quant import * @@ -30,6 +31,7 @@ def test_int_quant_to_int_called_with( int_quant = IntQuant( narrow_range=narrow_range, signed=signed, + input_view_impl=Identity, float_to_int_impl=float_to_int_impl, tensor_clamp_impl=tensor_clamp_impl) bit_width = torch.tensor(bit_width_init) From 13437f75e82789d1aeb2510caa1103cb43d0350a Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Aug 2024 14:08:39 +0100 Subject: [PATCH 13/16] fix tests --- tests/brevitas/core/test_int_quant.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/brevitas/core/test_int_quant.py b/tests/brevitas/core/test_int_quant.py index 5197b5b67..b792ea71f 100644 --- a/tests/brevitas/core/test_int_quant.py +++ b/tests/brevitas/core/test_int_quant.py @@ -31,7 +31,7 @@ def test_int_quant_to_int_called_with( int_quant = IntQuant( narrow_range=narrow_range, signed=signed, - input_view_impl=Identity, + input_view_impl=Identity(), float_to_int_impl=float_to_int_impl, tensor_clamp_impl=tensor_clamp_impl) bit_width = torch.tensor(bit_width_init) @@ -53,7 +53,7 @@ def test_int_quant_arange( zero_point_init, bit_width_init, arange_int_tensor): - int_quant = IntQuant(narrow_range=narrow_range, signed=signed) + int_quant = IntQuant(narrow_range=narrow_range, signed=signed, input_view_impl=Identity()) zero_point = torch.tensor(zero_point_init).float() bit_width = torch.tensor(bit_width_init).float() scale = torch.tensor(standalone_scaling_init).float() From 59b2a44c7856cd4e13291e60c05f663ead5acf20 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Aug 2024 14:09:46 +0100 Subject: [PATCH 14/16] fix last last test --- tests/brevitas/core/test_int_quant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/brevitas/core/test_int_quant.py b/tests/brevitas/core/test_int_quant.py index b792ea71f..5e106dc4c 100644 --- a/tests/brevitas/core/test_int_quant.py +++ b/tests/brevitas/core/test_int_quant.py @@ -41,7 +41,7 @@ def test_int_quant_to_int_called_with( output, min_val=int_quant.min_int(bit_width), max_val=int_quant.max_int(bit_width)) def test_int_quant_defaults(self, narrow_range, signed): - int_quant = IntQuant(narrow_range=narrow_range, signed=signed) + int_quant = IntQuant(narrow_range=narrow_range, signed=signed, input_view_impl=Identity()) assert isinstance(int_quant.float_to_int_impl, RoundSte) assert isinstance(int_quant.tensor_clamp_impl, TensorClamp) From 8b88ff4a4f4e470b5801be087f1789e4d7051420 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Aug 2024 14:51:10 +0100 Subject: [PATCH 15/16] fix tests minifloat --- tests/brevitas/core/test_float_quant.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index 2d4c829f0..16b8a4b5f 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -9,6 +9,7 @@ from brevitas.core.function_wrapper import FloatClamp from brevitas.core.function_wrapper import RoundSte from brevitas.core.function_wrapper import TensorClamp +from brevitas.core.function_wrapper.misc import Identity from brevitas.core.quant.float import FloatQuant from brevitas.core.scaling import ConstScaling from brevitas.core.scaling import FloatScaling @@ -32,6 +33,7 @@ def test_float_quant_defaults(minifloat_format): mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, signed=signed, + input_view_impl=Identity(), float_clamp_impl=None) else: # init FloatClamp @@ -48,6 +50,7 @@ def test_float_quant_defaults(minifloat_format): exponent_bit_width=exponent_bit_width, mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, + input_view_impl=Identity(), signed=signed, float_clamp_impl=float_clamp) assert isinstance(float_quant.float_to_int_impl, RoundSte) @@ -73,6 +76,7 @@ def test_float_to_quant_float(inp, minifloat_format): exponent_bit_width=exponent_bit_width, mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, + input_view_impl=Identity(), signed=signed, float_clamp_impl=None) else: @@ -90,6 +94,7 @@ def test_float_to_quant_float(inp, minifloat_format): exponent_bit_width=exponent_bit_width, mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, + input_view_impl=Identity(), signed=signed, float_clamp_impl=float_clamp) expected_out, *_ = float_quant(inp) @@ -115,6 +120,7 @@ def test_scaling_impls_called_once(inp, minifloat_format): mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, signed=signed, + input_view_impl=Identity(), scaling_impl=scaling_impl, float_scaling_impl=float_scaling_impl, float_clamp_impl=None) @@ -132,6 +138,7 @@ def test_scaling_impls_called_once(inp, minifloat_format): mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, signed=signed, + input_view_impl=Identity(), scaling_impl=scaling_impl, float_scaling_impl=float_scaling_impl, float_clamp_impl=float_clamp) @@ -162,6 +169,7 @@ def test_inner_scale(inp, minifloat_format, scale): mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, signed=signed, + input_view_impl=Identity(), scaling_impl=scaling_impl, float_scaling_impl=float_scaling_impl, float_clamp_impl=None) @@ -179,6 +187,7 @@ def test_inner_scale(inp, minifloat_format, scale): mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, signed=signed, + input_view_impl=Identity(), scaling_impl=scaling_impl, float_scaling_impl=float_scaling_impl, float_clamp_impl=float_clamp) From 417dfcb6f68bd64fab1d8f6b3ae0097298c414c7 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 23 Aug 2024 07:47:36 +0100 Subject: [PATCH 16/16] Fix last 2 tests --- src/brevitas/quant/base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py index abcf6d4f0..7b6fe409e 100644 --- a/src/brevitas/quant/base.py +++ b/src/brevitas/quant/base.py @@ -7,6 +7,7 @@ from brevitas.core.bit_width import BitWidthConst from brevitas.core.bit_width import BitWidthStatefulConst +from brevitas.core.function_wrapper import Identity from brevitas.core.function_wrapper import OverOutputChannelView from brevitas.core.function_wrapper import RoundToZeroSte from brevitas.core.function_wrapper import TensorClamp @@ -294,6 +295,8 @@ class WeightPerTensorFloatDecoupledL2Param(SolveWeightScalingStatsInputDimsFromM stats_reduce_dim = SCALING_STATS_REDUCE_DIM restrict_scaling_impl = FloatRestrictValue scaling_shape = SCALAR_SHAPE + scaling_per_output_type = ScalingPerOutputType.TENSOR + input_view_impl = Identity scaling_impl = ParameterFromStatsFromParameterScaling int_scaling_impl = IntScaling zero_point_impl = ZeroZeroPoint @@ -306,7 +309,8 @@ class WeightPerTensorFloatDecoupledL2Param(SolveWeightScalingStatsInputDimsFromM class WeightPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum, SolveWeightScalingStatsInputDimsFromModule, SolveWeightScalingPerOutputChannelShapeFromModule, - SolveParameterScalingShape): + SolveParameterScalingShape, + SolveInputViewImpl): """ Experimental narrow per-channel signed int weight quantizer fragment with decoupled Linf normalization and learned scaling.