From 0dfed16827697589b4105c6c2a840c7b2a9d990a Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 4 Sep 2024 15:53:00 +0200 Subject: [PATCH] Feat (mx): automatic group_dim in layerwise quant (#1012) --- notebooks/minifloat_mx_tutorial.ipynb | 18 ++++++++--------- src/brevitas/quant/solver/act.py | 29 +++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/notebooks/minifloat_mx_tutorial.ipynb b/notebooks/minifloat_mx_tutorial.ipynb index 284a0d4f5..bd43880de 100644 --- a/notebooks/minifloat_mx_tutorial.ipynb +++ b/notebooks/minifloat_mx_tutorial.ipynb @@ -163,8 +163,8 @@ " pass\n", "\n", "class MXFloat8Act(MXFloat8e4m3Act, Fp8e4m3Mixin):\n", - " # It is necessary to specify the group dimension for the activation quantization\n", - " group_dim = 1\n", + " # In layerwise quantization, groupdim is automatically determined\n", + " pass\n", "\n", "\n", "class MXModel(nn.Module):\n", @@ -221,9 +221,8 @@ " group_size = 8\n", "\n", "class MXFloat8ActNoPadding(MXFloat8e4m3Act, Fp8e4m3Mixin):\n", - " # It is necessary to specify the group dimension for the activation quantization\n", + " # In layerwise quantization, groupdim is automatically determined\n", " group_size = 8\n", - " group_dim = 1\n", "\n", "\n", "class MXModelNoPadding(nn.Module):\n", @@ -277,8 +276,8 @@ " pass\n", "\n", "class MXFloat8Act(MXFloat8e4m3Act, Fp8e4m3Mixin):\n", - " # It is necessary to specify the group dimension for the activation quantization\n", - " group_dim = 1\n", + " # In layerwise quantization, groupdim is automatically determined\n", + " pass\n", "\n", "\n", "class MXModel(nn.Module):\n", @@ -314,12 +313,11 @@ "class MXInt8Weight(MXInt8Weight):\n", " # The group dimension for the weights it is automatically identified based on the layer type\n", " # If a new layer type is used, it can be manually specified\n", - " bit_width = 8\n", + " pass\n", "\n", "class MXInt8Act(MXInt8Act):\n", - " # It is necessary to specify the group dimension for the activation quantization\n", - " group_dim = 1\n", - " bit_width = 8\n", + " # In layerwise quantization, groupdim is automatically determined\n", + " pass\n", "\n", "class MXModel(nn.Module):\n", " def __init__(self):\n", diff --git a/src/brevitas/quant/solver/act.py b/src/brevitas/quant/solver/act.py index 345239089..35f771c54 100644 --- a/src/brevitas/quant/solver/act.py +++ b/src/brevitas/quant/solver/act.py @@ -1,6 +1,8 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from warnings import warn + import torch from torch import nn from torch import Tensor @@ -111,6 +113,33 @@ def scaling_shape(scaling_per_output): elif scaling_per_output == ScalingPerOutputType.TENSOR: return SCALAR_SHAPE + @value + def group_dim(module=None, group_size=None): + # Avoid circular import + from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer + + if group_size is not None and module is not None: + if isinstance(module, QuantWeightBiasInputOutputLayer): + if isinstance(module, nn.Linear): + return -1 + elif isinstance(module, + (nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d)): + warn( + "Group dim is being selected assuming batched input. Using unbatched input will fail and requires manually specification of group_dim" + ) + # We are assuming batched input + return 1 + else: + raise RuntimeError("Cannot determine automatically group_dim. Please specify") + else: + raise RuntimeError( + f"Cannot determine automatically group_dim for {type(module)}. Please specify") + class SolveActScalingPerOutputChannelShape(ExtendedInjector):