From 6624a1d30d53d2ab39c29b1948d4f7bf0ae0d299 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Aug 2024 11:33:36 +0100 Subject: [PATCH 1/9] Fix MX --- src/brevitas/core/quant/int.py | 8 +++++++- src/brevitas/core/scaling/runtime.py | 4 ++++ src/brevitas/quant/solver/act.py | 4 ++-- src/brevitas/quant/solver/parameter.py | 2 ++ src/brevitas/quant/solver/weight.py | 3 +++ 5 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index cdb75df74..9c29ecdd2 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -135,6 +135,7 @@ class RescalingIntQuant(brevitas.jit.ScriptModule): def __init__( self, int_quant: Module, + input_view_impl: Module, scaling_impl: Module, int_scaling_impl: Module, zero_point_impl: Module, @@ -145,6 +146,7 @@ def __init__( self.int_scaling_impl = int_scaling_impl self.zero_point_impl = zero_point_impl self.msb_clamp_bit_width_impl = bit_width_impl + self.input_view_impl = input_view_impl @brevitas.jit.script_method def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: @@ -153,6 +155,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: int_threshold = self.int_scaling_impl(bit_width) scale = threshold / int_threshold zero_point = self.zero_point_impl(x, scale, bit_width) + x = self.input_view_impl(x) y = self.int_quant(scale, zero_point, bit_width, x) return y, scale, zero_point, bit_width @@ -167,7 +170,8 @@ def __init__( int_scaling_impl: Module, pre_zero_point_impl: Module, zero_point_impl: Module, - bit_width_impl: Module): + bit_width_impl: Module, + input_view_impl: Module): super(DecoupledRescalingIntQuant, self).__init__() self.decoupled_int_quant = decoupled_int_quant self.pre_scaling_impl = pre_scaling_impl @@ -176,6 +180,7 @@ def __init__( self.pre_zero_point_impl = pre_zero_point_impl self.zero_point_impl = zero_point_impl self.msb_clamp_bit_width_impl = bit_width_impl + self.input_view_impl = input_view_impl @brevitas.jit.script_method def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: @@ -184,6 +189,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Te pre_threshold = self.pre_scaling_impl(x) pre_scale = pre_threshold / int_threshold pre_zero_point = self.pre_zero_point_impl(x, pre_scale, bit_width) + x = self.input_view_impl(x) threshold = self.scaling_impl(x) scale = threshold / int_threshold zero_point = self.zero_point_impl(x, scale, bit_width) diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py index e4333186d..5728dc5ec 100644 --- a/src/brevitas/core/scaling/runtime.py +++ b/src/brevitas/core/scaling/runtime.py @@ -178,6 +178,10 @@ def __init__( self.input_view_impl = input_view_impl self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) +<<<<<<< HEAD +======= + +>>>>>>> Fix MX @brevitas.jit.script_method def forward(self, stats_input) -> torch.Tensor: stats_input_reshaped = self.input_view_impl(stats_input) diff --git a/src/brevitas/quant/solver/act.py b/src/brevitas/quant/solver/act.py index 345239089..b12299f67 100644 --- a/src/brevitas/quant/solver/act.py +++ b/src/brevitas/quant/solver/act.py @@ -1,6 +1,8 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause +from brevitas.core.function_wrapper.misc import Identity +from brevitas.core.function_wrapper.shape import StatsInputViewShapeImpl import torch from torch import nn from torch import Tensor @@ -131,7 +133,6 @@ def update_state_dict_impl(scaling_impl_type): class SolveInputViewImpl(ExtendedInjector): - @value def input_view_impl(scaling_per_output): if scaling_per_output == ScalingPerOutputType.GROUP: @@ -139,7 +140,6 @@ def input_view_impl(scaling_per_output): else: return Identity - class ActQuantSolver(SolveActTensorQuantFromEnum, SolveActScalingImplFromEnum, SolveIntScalingImplFromEnum, diff --git a/src/brevitas/quant/solver/parameter.py b/src/brevitas/quant/solver/parameter.py index 97137c567..7b5ecf2af 100644 --- a/src/brevitas/quant/solver/parameter.py +++ b/src/brevitas/quant/solver/parameter.py @@ -4,6 +4,8 @@ import math from typing import List +from brevitas.core.function_wrapper.misc import Identity +from brevitas.core.function_wrapper.shape import StatsInputViewShapeImpl from dependencies import this from dependencies import value import torch diff --git a/src/brevitas/quant/solver/weight.py b/src/brevitas/quant/solver/weight.py index 7f63fe17e..1e37b0c12 100644 --- a/src/brevitas/quant/solver/weight.py +++ b/src/brevitas/quant/solver/weight.py @@ -8,6 +8,7 @@ from brevitas.inject import this from brevitas.inject import value from brevitas.proxy import WeightQuantProxyFromInjector +from brevitas.quant.solver.parameter import SolveInputViewImpl from brevitas.quant.solver.common import * from brevitas.quant.solver.parameter import * from brevitas.quant.solver.parameter import SolveInputViewImpl @@ -69,6 +70,8 @@ def scaling_stats_input_concat_dim(scaling_per_output): return 0 elif scaling_per_output == ScalingPerOutputType.CHANNEL: return 1 + else: + raise RuntimeError("Shared groupwise quantization is not supported") @value def permute_dims(module, output_channel_dim): From cbb4fa0df8c3a5f9e78834e6031d075d01fe66e3 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Aug 2024 12:36:57 +0100 Subject: [PATCH 2/9] Moved one level down --- src/brevitas/core/quant/int.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index 9c29ecdd2..ff3e16a5b 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -135,7 +135,6 @@ class RescalingIntQuant(brevitas.jit.ScriptModule): def __init__( self, int_quant: Module, - input_view_impl: Module, scaling_impl: Module, int_scaling_impl: Module, zero_point_impl: Module, @@ -146,7 +145,6 @@ def __init__( self.int_scaling_impl = int_scaling_impl self.zero_point_impl = zero_point_impl self.msb_clamp_bit_width_impl = bit_width_impl - self.input_view_impl = input_view_impl @brevitas.jit.script_method def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: @@ -155,7 +153,6 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: int_threshold = self.int_scaling_impl(bit_width) scale = threshold / int_threshold zero_point = self.zero_point_impl(x, scale, bit_width) - x = self.input_view_impl(x) y = self.int_quant(scale, zero_point, bit_width, x) return y, scale, zero_point, bit_width @@ -180,7 +177,6 @@ def __init__( self.pre_zero_point_impl = pre_zero_point_impl self.zero_point_impl = zero_point_impl self.msb_clamp_bit_width_impl = bit_width_impl - self.input_view_impl = input_view_impl @brevitas.jit.script_method def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, 
Tensor, Tensor, Tensor]: @@ -189,7 +185,6 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Te pre_threshold = self.pre_scaling_impl(x) pre_scale = pre_threshold / int_threshold pre_zero_point = self.pre_zero_point_impl(x, pre_scale, bit_width) - x = self.input_view_impl(x) threshold = self.scaling_impl(x) scale = threshold / int_threshold zero_point = self.zero_point_impl(x, scale, bit_width) From 69a0589cba6fb6945d146dbf0a995c30f437c066 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Aug 2024 12:51:44 +0100 Subject: [PATCH 3/9] precommit and remove raise --- src/brevitas/core/scaling/runtime.py | 4 ---- src/brevitas/quant/solver/act.py | 4 ++-- src/brevitas/quant/solver/parameter.py | 2 -- src/brevitas/quant/solver/weight.py | 3 --- 4 files changed, 2 insertions(+), 11 deletions(-) diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py index 5728dc5ec..e4333186d 100644 --- a/src/brevitas/core/scaling/runtime.py +++ b/src/brevitas/core/scaling/runtime.py @@ -178,10 +178,6 @@ def __init__( self.input_view_impl = input_view_impl self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) -<<<<<<< HEAD -======= - ->>>>>>> Fix MX @brevitas.jit.script_method def forward(self, stats_input) -> torch.Tensor: stats_input_reshaped = self.input_view_impl(stats_input) diff --git a/src/brevitas/quant/solver/act.py b/src/brevitas/quant/solver/act.py index b12299f67..345239089 100644 --- a/src/brevitas/quant/solver/act.py +++ b/src/brevitas/quant/solver/act.py @@ -1,8 +1,6 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -from brevitas.core.function_wrapper.misc import Identity -from brevitas.core.function_wrapper.shape import StatsInputViewShapeImpl import torch from torch import nn from torch import Tensor @@ -133,6 +131,7 @@ def update_state_dict_impl(scaling_impl_type): class SolveInputViewImpl(ExtendedInjector): + @value def input_view_impl(scaling_per_output): if scaling_per_output == ScalingPerOutputType.GROUP: @@ -140,6 +139,7 @@ def input_view_impl(scaling_per_output): else: return Identity + class ActQuantSolver(SolveActTensorQuantFromEnum, SolveActScalingImplFromEnum, SolveIntScalingImplFromEnum, diff --git a/src/brevitas/quant/solver/parameter.py b/src/brevitas/quant/solver/parameter.py index 7b5ecf2af..97137c567 100644 --- a/src/brevitas/quant/solver/parameter.py +++ b/src/brevitas/quant/solver/parameter.py @@ -4,8 +4,6 @@ import math from typing import List -from brevitas.core.function_wrapper.misc import Identity -from brevitas.core.function_wrapper.shape import StatsInputViewShapeImpl from dependencies import this from dependencies import value import torch diff --git a/src/brevitas/quant/solver/weight.py b/src/brevitas/quant/solver/weight.py index 1e37b0c12..7f63fe17e 100644 --- a/src/brevitas/quant/solver/weight.py +++ b/src/brevitas/quant/solver/weight.py @@ -8,7 +8,6 @@ from brevitas.inject import this from brevitas.inject import value from brevitas.proxy import WeightQuantProxyFromInjector -from brevitas.quant.solver.parameter import SolveInputViewImpl from brevitas.quant.solver.common import * from brevitas.quant.solver.parameter import * from brevitas.quant.solver.parameter import SolveInputViewImpl @@ -70,8 +69,6 @@ def scaling_stats_input_concat_dim(scaling_per_output): return 0 elif scaling_per_output == ScalingPerOutputType.CHANNEL: return 1 - else: - raise RuntimeError("Shared groupwise quantization is 
not supported") @value def permute_dims(module, output_channel_dim): From c3bea163b8562a99a3251c8face8343684d0b952 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 22 Aug 2024 13:10:28 +0100 Subject: [PATCH 4/9] remove unused init param --- src/brevitas/core/quant/int.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index ff3e16a5b..cdb75df74 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -167,8 +167,7 @@ def __init__( int_scaling_impl: Module, pre_zero_point_impl: Module, zero_point_impl: Module, - bit_width_impl: Module, - input_view_impl: Module): + bit_width_impl: Module): super(DecoupledRescalingIntQuant, self).__init__() self.decoupled_int_quant = decoupled_int_quant self.pre_scaling_impl = pre_scaling_impl From 8ef9010a4f4d38a0fd0914cfaa523ad64ed89fe1 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 23 Aug 2024 11:13:33 +0100 Subject: [PATCH 5/9] Feat (mx): adding padding and transposed support --- notebooks/minifloat_mx_tutorial.ipynb | 82 ++++++++++++++++++++- src/brevitas/core/function_wrapper/shape.py | 19 +++-- src/brevitas/quant/solver/common.py | 4 +- src/brevitas/quant/solver/parameter.py | 29 +++++--- 4 files changed, 115 insertions(+), 19 deletions(-) diff --git a/notebooks/minifloat_mx_tutorial.ipynb b/notebooks/minifloat_mx_tutorial.ipynb index 60f00fcd4..284a0d4f5 100644 --- a/notebooks/minifloat_mx_tutorial.ipynb +++ b/notebooks/minifloat_mx_tutorial.ipynb @@ -104,7 +104,8 @@ "o = ocp_fp8_model(x)\n", "\n", "intermediate_input = ocp_fp8_model.conv.input_quant(x)\n", - "assert isinstance(intermediate_input, FloatQuantTensor)" + "assert isinstance(intermediate_input, FloatQuantTensor)\n", + "assert isinstance(ocp_fp8_model.conv.quant_weight(), FloatQuantTensor)" ] }, { @@ -180,7 +181,84 @@ "o = mx_model(x)\n", "\n", "intermediate_input = mx_model.conv.input_quant(x)\n", - "assert isinstance(intermediate_input, GroupwiseFloatQuantTensor)" + "assert isinstance(intermediate_input, GroupwiseFloatQuantTensor)\n", + "assert isinstance(mx_model.conv.quant_weight(), GroupwiseFloatQuantTensor)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If the input channel dimension is not divisible by group size, padding will be applied." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Non padding weights shape torch.Size([64, 8, 3, 3])\n", + "Padded weights shape torch.Size([64, 32, 3, 3])\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/brevitas_dev/lib/python3.11/site-packages/torch/nn/modules/conv.py:456: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. 
(Triggered internally at /opt/conda/conda-bld/pytorch_1712608853099/work/torch/csrc/utils/python_arg_parser.cpp:294.)\n",
      "  return F.conv2d(input, weight, bias, self.stride,\n"
     ]
    }
   ],
   "source": [
    "class MXFloat8WeightNoPadding(MXFloat8e4m3Weight, Fp8e4m3Mixin):\n",
    "    # The group dimension for the weights is automatically identified based on the layer type\n",
    "    # If a new layer type is used, it can be manually specified\n",
    "    group_size = 8\n",
    "\n",
    "class MXFloat8ActNoPadding(MXFloat8e4m3Act, Fp8e4m3Mixin):\n",
    "    # It is necessary to specify the group dimension for the activation quantization\n",
    "    group_size = 8\n",
    "    group_dim = 1\n",
    "\n",
    "\n",
    "class MXModelNoPadding(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv = qnn.QuantConv2d(8, 64, 3, weight_quant=MXFloat8WeightNoPadding, input_quant=MXFloat8ActNoPadding)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return self.conv(x)\n",
    "\n",
    "class MXModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv = qnn.QuantConv2d(8, 64, 3, weight_quant=MXFloat8Weight, input_quant=MXFloat8Act)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return self.conv(x)\n",
    "\n",
    "mx_model_no_padding = MXModelNoPadding()\n",
    "mx_model = MXModel()\n",
    "# Make sure that the modules are the same\n",
    "mx_model_no_padding.load_state_dict(mx_model.state_dict())\n",
    "\n",
    "x = torch.randn(1, 8, 8, 8)\n",
    "mx_model.eval()\n",
    "mx_model_no_padding.eval()\n",
    "o_no_padding = mx_model_no_padding(x)\n",
    "o = mx_model(x)\n",
    "\n",
    "# The quant weight of the padded model is different from the non-padded one\n",
    "print(f\"Non padding weights shape {mx_model_no_padding.conv.quant_weight().value.shape}\")\n",
    "print(f\"Padded weights shape {mx_model.conv.quant_weight().value.shape}\")\n",
    "\n",
    "# However, results are still the same\n",
    "assert torch.allclose(o, o_no_padding)"
   ]
  },
  {
diff --git a/src/brevitas/core/function_wrapper/shape.py b/src/brevitas/core/function_wrapper/shape.py
index 1bb77476e..84ee9f355 100644
--- a/src/brevitas/core/function_wrapper/shape.py
+++ b/src/brevitas/core/function_wrapper/shape.py
@@ -156,17 +156,14 @@ def forward(self, x: torch.Tensor):
 class OverSubChannelBlockView(brevitas.jit.ScriptModule):
     __constants__ = ['expanded_scaling_shape']
 
-    def __init__(self, expanded_scaling_shape, permute_dims: Optional[Tuple[int, ...]]) -> None:
+    def __init__(self, expanded_scaling_shape, padding) -> None:
         super(OverSubChannelBlockView, self).__init__()
         self.expanded_scaling_shape = expanded_scaling_shape
-        if permute_dims is not None:
-            self.permute_impl = PermuteDims(permute_dims)
-        else:
-            self.permute_impl = torch.nn.Identity()
+        self.padding = padding
 
     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor):
-        y = self.permute_impl(x)
+        y = torch.nn.functional.pad(x, self.padding, mode='constant', value=0)
         y = y.view(self.expanded_scaling_shape)
         return y
 
@@ -181,6 +178,16 @@ def __init__(self, group_size, group_dim) -> None:
 
     @brevitas.jit.script_method
     def forward(self, x):
+
+        tensor_shape = x.shape
+        tensor_shape_list = list(tensor_shape)
+        padding = [0, 0] * len(tensor_shape_list)
+        if tensor_shape_list[self.group_dim] % self.group_size != 0:
+            padding[2 * self.group_dim] = self.group_size - tensor_shape_list[
+                self.group_dim] % self.group_size
+        padding = list(reversed(padding))
+        x = torch.nn.functional.pad(x, padding, mode='constant', value=0)
+
         tensor_shape = x.shape
         tensor_shape_list 
= list(tensor_shape) tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size) diff --git a/src/brevitas/quant/solver/common.py b/src/brevitas/quant/solver/common.py index 2847275e8..4d46cc704 100644 --- a/src/brevitas/quant/solver/common.py +++ b/src/brevitas/quant/solver/common.py @@ -172,13 +172,13 @@ def int_scaling_impl(restrict_scaling_type): class SolveStatsReduceDimFromEnum(ExtendedInjector): @value - def stats_reduce_dim(scaling_stats_op, scaling_per_output): + def stats_reduce_dim(scaling_stats_op, scaling_per_output, group_dim=None): if scaling_per_output == ScalingPerOutputType.CHANNEL or scaling_stats_op == StatsOp.MAX_AVE: return SCALING_STATS_REDUCE_DIM elif scaling_per_output == ScalingPerOutputType.TENSOR: return None elif scaling_per_output == ScalingPerOutputType.GROUP: - return SCALING_STATS_REDUCE_DIM + 1 + return group_dim + 1 @value def keepdim(scaling_per_output): diff --git a/src/brevitas/quant/solver/parameter.py b/src/brevitas/quant/solver/parameter.py index 97137c567..198505ec8 100644 --- a/src/brevitas/quant/solver/parameter.py +++ b/src/brevitas/quant/solver/parameter.py @@ -111,7 +111,7 @@ def scaling_impl(scaling_impl_type): class SolveParameterScalingShape(ExtendedInjector): @value - def scaling_shape(module, group_size=None, scaling_per_output=None): + def scaling_shape(module, input_channel_dim, group_size=None, scaling_per_output=None): if scaling_per_output == ScalingPerOutputType.TENSOR: return SCALAR_SHAPE elif scaling_per_output == ScalingPerOutputType.CHANNEL: @@ -119,9 +119,8 @@ def scaling_shape(module, group_size=None, scaling_per_output=None): elif scaling_per_output == ScalingPerOutputType.GROUP: assert group_size is not None, "Per Group scaling requires group size" size = list(module.weight.shape) - assert size[1] % group_size == 0, 'Input channel is not divisible by group size' - size[1] = size[1] // group_size - size.insert(2, 1) + size[input_channel_dim] = (size[input_channel_dim] + group_size - 1) // group_size + size.insert(input_channel_dim + 1, 1) return size @value @@ -129,18 +128,30 @@ def reshaped_scaling_shape(module): return module.weight.shape @value - def expanded_scaling_shape(module, group_size=None): + def expanded_scaling_shape(module, input_channel_dim, group_size=None): assert group_size is not None, "Per Group scaling requires group size" size = list(module.weight.shape) - assert size[1] % group_size == 0, 'Input channel is not divisible by group size' - size[1] = size[1] // group_size - size.insert(2, group_size) + size[input_channel_dim] = (size[input_channel_dim] + group_size - 1) // group_size + size.insert(input_channel_dim + 1, group_size) return size + @value + def input_channel_dim(module): + return 1 if not module.transposed else 0 + + @value + def padding(module, input_channel_dim, group_size): + padding = [0, 0] * len(module.weight.shape) + size = list(module.weight.shape) + if size[input_channel_dim] % group_size != 0: + padding[2 * input_channel_dim] = group_size - size[input_channel_dim] % group_size + padding = list(reversed(padding)) + return padding + @value def group_dim(module, group_size=None): if group_size is not None: - return 1 + return 1 if not module.transposed else 0 class SolveInputViewImpl(ExtendedInjector): From b4c8cf5ff84260fece4da298db1e4e3172b25916 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 23 Aug 2024 11:35:00 +0100 Subject: [PATCH 6/9] fix transposed --- src/brevitas/quant/solver/parameter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) diff --git a/src/brevitas/quant/solver/parameter.py b/src/brevitas/quant/solver/parameter.py index 198505ec8..67c3a56c1 100644 --- a/src/brevitas/quant/solver/parameter.py +++ b/src/brevitas/quant/solver/parameter.py @@ -137,7 +137,7 @@ def expanded_scaling_shape(module, input_channel_dim, group_size=None): @value def input_channel_dim(module): - return 1 if not module.transposed else 0 + return 1 if not hasattr(module, 'transposed') or not module.transposed else 0 @value def padding(module, input_channel_dim, group_size): @@ -151,7 +151,7 @@ def padding(module, input_channel_dim, group_size): @value def group_dim(module, group_size=None): if group_size is not None: - return 1 if not module.transposed else 0 + return 1 if not hasattr(module, 'transposed') or not module.transposed else 0 class SolveInputViewImpl(ExtendedInjector): From 7ec08ed647ec8c86a8ac7b53f3bc5b047206428e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 23 Aug 2024 11:48:49 +0100 Subject: [PATCH 7/9] simplify --- src/brevitas/quant/solver/parameter.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/brevitas/quant/solver/parameter.py b/src/brevitas/quant/solver/parameter.py index 67c3a56c1..cc36cc72d 100644 --- a/src/brevitas/quant/solver/parameter.py +++ b/src/brevitas/quant/solver/parameter.py @@ -128,23 +128,19 @@ def reshaped_scaling_shape(module): return module.weight.shape @value - def expanded_scaling_shape(module, input_channel_dim, group_size=None): + def expanded_scaling_shape(module, group_dim, group_size=None): assert group_size is not None, "Per Group scaling requires group size" size = list(module.weight.shape) - size[input_channel_dim] = (size[input_channel_dim] + group_size - 1) // group_size - size.insert(input_channel_dim + 1, group_size) + size[group_dim] = (size[group_dim] + group_size - 1) // group_size + size.insert(group_dim + 1, group_size) return size @value - def input_channel_dim(module): - return 1 if not hasattr(module, 'transposed') or not module.transposed else 0 - - @value - def padding(module, input_channel_dim, group_size): + def padding(module, group_dim, group_size): padding = [0, 0] * len(module.weight.shape) size = list(module.weight.shape) - if size[input_channel_dim] % group_size != 0: - padding[2 * input_channel_dim] = group_size - size[input_channel_dim] % group_size + if size[group_dim] % group_size != 0: + padding[2 * group_dim] = group_size - size[group_dim] % group_size padding = list(reversed(padding)) return padding From a85c733a925ecb09a8071d23a6a7f061744b7ef6 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 23 Aug 2024 11:53:19 +0100 Subject: [PATCH 8/9] fix simplify --- src/brevitas/quant/solver/parameter.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/brevitas/quant/solver/parameter.py b/src/brevitas/quant/solver/parameter.py index cc36cc72d..d4befc1c9 100644 --- a/src/brevitas/quant/solver/parameter.py +++ b/src/brevitas/quant/solver/parameter.py @@ -111,16 +111,17 @@ def scaling_impl(scaling_impl_type): class SolveParameterScalingShape(ExtendedInjector): @value - def scaling_shape(module, input_channel_dim, group_size=None, scaling_per_output=None): + def scaling_shape(module, group_dim, group_size=None, scaling_per_output=None): if scaling_per_output == ScalingPerOutputType.TENSOR: return SCALAR_SHAPE elif scaling_per_output == ScalingPerOutputType.CHANNEL: return this.scaling_per_output_channel_shape elif scaling_per_output == ScalingPerOutputType.GROUP: assert 
group_size is not None, "Per Group scaling requires group size"
+            assert group_dim is not None, "Per Group scaling requires group dim"
             size = list(module.weight.shape)
-            size[input_channel_dim] = (size[input_channel_dim] + group_size - 1) // group_size
-            size.insert(input_channel_dim + 1, 1)
+            size[group_dim] = (size[group_dim] + group_size - 1) // group_size
+            size.insert(group_dim + 1, 1)
             return size
 
     @value

From 47a66dbc9c63e63137bf84eaf5390ffbf6dfd525 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Mon, 26 Aug 2024 18:14:32 +0200
Subject: [PATCH 9/9] Update parameter.py

---
 src/brevitas/quant/solver/parameter.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/brevitas/quant/solver/parameter.py b/src/brevitas/quant/solver/parameter.py
index d4befc1c9..76d3f6f3e 100644
--- a/src/brevitas/quant/solver/parameter.py
+++ b/src/brevitas/quant/solver/parameter.py
@@ -141,7 +141,10 @@ def padding(module, group_dim, group_size):
         padding = [0, 0] * len(module.weight.shape)
         size = list(module.weight.shape)
         if size[group_dim] % group_size != 0:
+            # Pad the group dim up to the next multiple of group_size
             padding[2 * group_dim] = group_size - size[group_dim] % group_size
+            # Padding takes a list of 2 values per dim in reverse order (N_DIM, N_DIM-1,...,0)
+            # so we need to reverse the order
             padding = list(reversed(padding))
         return padding
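
Taken together, patches 5-9 land on the following behaviour for group-wise (MX) quantization: the group dimension (dim 1 for most layers, dim 0 for transposed ones, per the `group_dim` solver) is zero-padded up to the next multiple of group_size, then split into (number of groups, group_size) so that scales can be computed per group. For readers who want to sanity-check the mechanics outside of Brevitas, below is a minimal standalone sketch of that padding-plus-view logic, mirroring the `padding` solver and `DynamicOverSubChannelBlockView.forward` above. It assumes only plain PyTorch; the helper name `group_view` is illustrative, not a Brevitas API.

import torch
import torch.nn.functional as F


def group_view(x: torch.Tensor, group_size: int, group_dim: int) -> torch.Tensor:
    # One (begin, end) padding slot pair per dimension, as built by the padding solver
    shape = list(x.shape)
    padding = [0, 0] * len(shape)
    if shape[group_dim] % group_size != 0:
        padding[2 * group_dim] = group_size - shape[group_dim] % group_size
    # F.pad reads pairs starting from the last dimension, hence the reversal;
    # after it, the extra zeros land at the end of group_dim
    padding = list(reversed(padding))
    x = F.pad(x, padding, mode='constant', value=0)
    # Split group_dim into (num_groups, group_size): the expanded scaling shape
    shape = list(x.shape)
    shape[group_dim] = shape[group_dim] // group_size
    shape.insert(group_dim + 1, group_size)
    return x.view(shape)


# Shapes from the notebook example: 8 input channels, MX group size 32
w = torch.randn(64, 8, 3, 3)
g = group_view(w, group_size=32, group_dim=1)  # group_dim is 1 for a non-transposed conv
print(g.shape)  # torch.Size([64, 1, 32, 3, 3]): channels padded 8 -> 32, then grouped

Per-group scales computed over this view have the shape produced by the `scaling_shape` solver, here [64, 1, 1, 3, 3], which broadcasts back over each group of 32 channels. In the notebook example the padded entries are zeros and all eight real channels fall into a single group, so the computed scale matches the unpadded case and the padded and non-padded models produce identical outputs.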