Commit
Fix MX
Giuseppe5 committed Aug 22, 2024
1 parent 6733ba2 commit 8180de6
Showing 13 changed files with 72 additions and 31 deletions.
6 changes: 3 additions & 3 deletions notebooks/minifloat_mx_tutorial.ipynb
@@ -233,20 +233,20 @@
"import brevitas.nn as qnn\n",
"import torch\n",
"\n",
"class MXFloat8Weight(MXInt8Weight):\n",
"class MXInt8Weight(MXInt8Weight):\n",
" # The group dimension for the weights it is automatically identified based on the layer type\n",
" # If a new layer type is used, it can be manually specified\n",
" bit_width = 8\n",
"\n",
"class MXFloat8Act(MXInt8Act):\n",
"class MXInt8Act(MXInt8Act):\n",
" # It is necessary to specify the group dimension for the activation quantization\n",
" group_dim = 1\n",
" bit_width = 8\n",
"\n",
"class MXModel(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=MXFloat8Weight, input_quant=MXFloat8Act)\n",
" self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=MXInt8Weight, input_quant=MXInt8Act)\n",
" \n",
" def forward(self, x):\n",
" return self.conv(x)\n",
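Note: a minimal usage sketch of the cell above, assuming the renamed MXInt8Weight/MXInt8Act subclasses and the MXModel definition from that cell are in scope; the input shape and printed output shape are illustrative only.

import torch

model = MXModel()                 # uses the MXInt8Weight / MXInt8Act quantizers defined above
x = torch.randn(1, 32, 8, 8)      # one sample, 32 input channels
y = model(x)                      # weights and input are quantized groupwise before the conv
print(y.shape)                    # torch.Size([1, 64, 6, 6]) for a 3x3 conv with no padding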
20 changes: 20 additions & 0 deletions src/brevitas/core/function_wrapper/shape.py
@@ -171,6 +171,25 @@ def forward(self, x: torch.Tensor):
return y


class DynamicOverSubChannelBlockView(brevitas.jit.ScriptModule):
__constants__ = ['group_size', 'group_dim']

def __init__(self, group_size, group_dim) -> None:
super(DynamicOverSubChannelBlockView, self).__init__()
self.group_size = group_size
self.group_dim = group_dim

@brevitas.jit.script_method
def forward(self, x):
tensor_shape = x.shape
tensor_shape_list = list(tensor_shape)
tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size)
block_dim = self.group_dim + 1 if self.group_dim != -1 else -1
tensor_shape_list.insert(block_dim, self.group_size)
x = x.view(tensor_shape_list)
return x


class StatsInputViewShapeImpl(object):
"""
Enum-like object to collect pointers to variants of ScriptModules that perform a view on a tensor.
@@ -182,3 +201,4 @@ class StatsInputViewShapeImpl(object):
OVER_BATCH_OVER_OUTPUT_CHANNELS = OverBatchOverOutputChannelView
OVER_OUTPUT_FEATURES = OverOutputFeaturesView
OVER_SUBCHANNEL_BLOCK = OverSubChannelBlockView
DYNAMIC_OVER_SUBCHANNEL_BLOCK = DynamicOverSubChannelBlockView
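Note: a quick illustration of what the new DynamicOverSubChannelBlockView computes, shown in eager mode (under brevitas.jit it runs as a ScriptModule) and assuming the size along group_dim divides evenly by group_size.

import torch
from brevitas.core.function_wrapper.shape import DynamicOverSubChannelBlockView

view = DynamicOverSubChannelBlockView(group_size=4, group_dim=1)
x = torch.randn(2, 8, 4, 4)       # e.g. an (N, C, H, W) activation
print(view(x).shape)              # torch.Size([2, 2, 4, 4, 4]): C is split into (num_groups, group_size)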
4 changes: 3 additions & 1 deletion src/brevitas/core/quant/float.py
@@ -24,6 +24,7 @@ def __init__(
mantissa_bit_width: int,
exponent_bias: int,
float_clamp_impl: nn.Module,
input_view_impl: nn.Module,
scaling_impl: Optional[nn.Module] = None,
float_scaling_impl: Optional[nn.Module] = None,
float_to_int_impl: nn.Module = RoundSte(),
@@ -52,6 +53,7 @@ def __init__(
if scaling_impl is None:
scaling_impl = ConstScaling(1., device=device, dtype=dtype)

self.input_view_impl = input_view_impl
# Zero-point is currently hardcoded to 0
self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype))
self.float_scaling_impl = float_scaling_impl
@@ -71,7 +73,7 @@ def quantize(self, x: torch.Tensor):
float_scaling_impl_value = self.float_scaling_impl(
self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
scale = scale / float_scaling_impl_value

x = self.input_view_impl(x)
scaled_x = x / scale
internal_scale = float_internal_scale(
scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min(), self.eps)
8 changes: 7 additions & 1 deletion src/brevitas/core/quant/int.py
@@ -135,6 +135,7 @@ class RescalingIntQuant(brevitas.jit.ScriptModule):
def __init__(
self,
int_quant: Module,
input_view_impl: Module,
scaling_impl: Module,
int_scaling_impl: Module,
zero_point_impl: Module,
@@ -145,6 +146,7 @@ def __init__(
self.int_scaling_impl = int_scaling_impl
self.zero_point_impl = zero_point_impl
self.msb_clamp_bit_width_impl = bit_width_impl
self.input_view_impl = input_view_impl

@brevitas.jit.script_method
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
@@ -153,6 +155,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
int_threshold = self.int_scaling_impl(bit_width)
scale = threshold / int_threshold
zero_point = self.zero_point_impl(x, scale, bit_width)
x = self.input_view_impl(x)
y = self.int_quant(scale, zero_point, bit_width, x)
return y, scale, zero_point, bit_width

@@ -167,7 +170,8 @@ def __init__(
int_scaling_impl: Module,
pre_zero_point_impl: Module,
zero_point_impl: Module,
bit_width_impl: Module):
bit_width_impl: Module,
input_view_impl: Module):
super(DecoupledRescalingIntQuant, self).__init__()
self.decoupled_int_quant = decoupled_int_quant
self.pre_scaling_impl = pre_scaling_impl
@@ -176,6 +180,7 @@ def __init__(
self.pre_zero_point_impl = pre_zero_point_impl
self.zero_point_impl = zero_point_impl
self.msb_clamp_bit_width_impl = bit_width_impl
self.input_view_impl = input_view_impl

@brevitas.jit.script_method
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
@@ -184,6 +189,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
pre_threshold = self.pre_scaling_impl(x)
pre_scale = pre_threshold / int_threshold
pre_zero_point = self.pre_zero_point_impl(x, pre_scale, bit_width)
x = self.input_view_impl(x)
threshold = self.scaling_impl(x)
scale = threshold / int_threshold
zero_point = self.zero_point_impl(x, scale, bit_width)
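Note: for intuition, a rough self-contained sketch (plain PyTorch, not the Brevitas API) of the groupwise pattern that the injected input_view_impl enables here: the tensor is viewed per group before statistics and quantization, so each group gets its own scale. The group size of 8 and the absmax/127 scale rule are illustrative assumptions, not library defaults.

import torch

w = torch.randn(64, 32)                                     # (out_channels, in_channels)
grouped = w.view(64, 32 // 8, 8)                            # split the input axis into (num_groups, group_size=8)
scale = grouped.abs().amax(dim=-1, keepdim=True) / 127.0    # one scale per group
q = torch.clamp(torch.round(grouped / scale), -128, 127)    # 8-bit integer codes
deq = (q * scale).view(64, 32)                              # dequantize and restore the original layout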
13 changes: 3 additions & 10 deletions src/brevitas/core/scaling/runtime.py
@@ -166,6 +166,7 @@ def __init__(
self,
group_size: int,
group_dim: int,
input_view_impl: torch.nn.Module,
scaling_stats_impl: torch.nn.Module,
scaling_min_val: Optional[float],
restrict_scaling_impl: Optional[torch.nn.Module]) -> None:
@@ -174,21 +175,13 @@ def __init__(
self.group_dim = group_dim
self.scaling_stats_impl = scaling_stats_impl
self.scaling_min_val = scaling_min_val
self.input_view_impl = input_view_impl
self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)

@brevitas.jit.script_method
def group_scaling_reshape(self, stats_input):
tensor_shape = stats_input.shape
tensor_shape_list = list(tensor_shape)
tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size)
block_dim = self.group_dim + 1 if self.group_dim != -1 else -1
tensor_shape_list.insert(block_dim, self.group_size)
stats_input = stats_input.view(tensor_shape_list)
return stats_input

@brevitas.jit.script_method
def forward(self, stats_input) -> torch.Tensor:
stats_input_reshaped = self.group_scaling_reshape(stats_input)
stats_input_reshaped = self.input_view_impl(stats_input)
out = self.scaling_stats_impl(stats_input_reshaped)
# Scaling min val
out = self.restrict_clamp_scaling(out)
6 changes: 0 additions & 6 deletions src/brevitas/proxy/groupwise_float_parameter_quant.py
@@ -9,11 +9,6 @@

class GroupwiseWeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Is this always generated?
self.view_impl = self.quant_injector.scaling_stats_input_view_shape_impl

@property
def group_dim(self):
return self.quant_injector.group_dim
@@ -25,7 +20,6 @@ def group_size(self):
def forward(self, x: torch.Tensor) -> Union[Tensor, GroupwiseFloatQuantTensor]:
if self.is_quant_enabled:
impl = self.export_handler if self.export_mode else self.tensor_quant
x = self.view_impl(x)
out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x)
return GroupwiseFloatQuantTensor(
out,
1 change: 0 additions & 1 deletion src/brevitas/proxy/groupwise_float_runtime_quant.py
@@ -24,7 +24,6 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, GroupwiseFloatQuantTensor]:
y = x
if isinstance(y, QuantTensor):
y = y.value

if self.export_mode:
y = self.fused_activation_quant_proxy.activation_impl(y)
y = self.export_handler(y)
1 change: 0 additions & 1 deletion src/brevitas/proxy/groupwise_int_parameter_quant.py
@@ -25,7 +25,6 @@ def group_size(self):
def forward(self, x: torch.Tensor) -> Union[Tensor, GroupwiseIntQuantTensor]:
if self.is_quant_enabled:
impl = self.export_handler if self.export_mode else self.tensor_quant
x = self.view_impl(x)
out, scale, zero_point, bit_width = impl(x)
return GroupwiseIntQuantTensor(
out,
13 changes: 12 additions & 1 deletion src/brevitas/quant/solver/act.py
@@ -1,6 +1,8 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from brevitas.core.function_wrapper.misc import Identity
from brevitas.core.function_wrapper.shape import StatsInputViewShapeImpl
import torch
from torch import nn
from torch import Tensor
@@ -128,6 +130,14 @@ def update_state_dict_impl(scaling_impl_type):
return None


class SolveInputViewImpl(ExtendedInjector):
@value
def input_view_impl(scaling_per_output):
if scaling_per_output == ScalingPerOutputType.GROUP:
return StatsInputViewShapeImpl.DYNAMIC_OVER_SUBCHANNEL_BLOCK
else:
return Identity

class ActQuantSolver(SolveActTensorQuantFromEnum,
SolveActScalingImplFromEnum,
SolveIntScalingImplFromEnum,
@@ -140,7 +150,8 @@ class ActQuantSolver(SolveActTensorQuantFromEnum,
SolveActScalingShape,
SolveScalingStatsInputViewShapeImplFromEnum,
SolveActScalingPerOutputChannelShape,
SolveUpdateStateDictImplFromEnum):
SolveUpdateStateDictImplFromEnum,
SolveInputViewImpl):
"""
Translate enum directives to activation-specific quantization core modules.
It should be placed last in the list of classes a quantizer inherits from,
10 changes: 10 additions & 0 deletions src/brevitas/quant/solver/parameter.py
@@ -4,6 +4,8 @@
import math
from typing import List

from brevitas.core.function_wrapper.misc import Identity
from brevitas.core.function_wrapper.shape import StatsInputViewShapeImpl
from dependencies import this
from dependencies import value
import torch
@@ -139,3 +141,11 @@ def expanded_scaling_shape(module, group_size=None):
def group_dim(module, group_size=None):
if group_size is not None:
return 1

class SolveInputViewImpl(ExtendedInjector):
@value
def input_view_impl(scaling_per_output):
if scaling_per_output == ScalingPerOutputType.GROUP:
return StatsInputViewShapeImpl.OVER_SUBCHANNEL_BLOCK
else:
return Identity
6 changes: 5 additions & 1 deletion src/brevitas/quant/solver/weight.py
@@ -8,6 +8,7 @@
from brevitas.inject import this
from brevitas.inject import value
from brevitas.proxy import WeightQuantProxyFromInjector
from brevitas.quant.solver.parameter import SolveInputViewImpl
from brevitas.quant.solver.common import *
from brevitas.quant.solver.parameter import *

@@ -68,6 +69,8 @@ def scaling_stats_input_concat_dim(scaling_per_output):
return 0
elif scaling_per_output == ScalingPerOutputType.CHANNEL:
return 1
else:
raise RuntimeError("Shared groupwise quantization is not supported")

@value
def permute_dims(module, output_channel_dim):
@@ -103,7 +106,8 @@ class WeightQuantSolver(SolveStatsReduceDimFromEnum,
SolveParameterScalingShape,
SolveWeightScalingPerOutputChannelShapeFromModule,
SolveWeightTensorQuantFromEnum,
SolveDtypeDeviceFromTrackedParameterList):
SolveDtypeDeviceFromTrackedParameterList,
SolveInputViewImpl):
"""
Translate enum and shape directives to weight-specific quantization core modules.
It should be placed last in the list of classes a quantizer inherits from,
8 changes: 5 additions & 3 deletions src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
@@ -90,13 +90,15 @@ def __torch_function__(self, func, types, args=(), kwargs=None):

def expand(self):
curr_shape = self.value_.shape
new_value = self.value_.flatten(self.group_dim, self.group_dim + 1)
start_dim = self.group_dim if self.group_dim != -1 else -2
new_value = self.value_.flatten(start_dim, start_dim + 1)
new_value = self.value_.flatten(start_dim, start_dim + 1)
if self.scale_.shape != ():
new_scale = self.scale_.expand(curr_shape).flatten(self.group_dim, self.group_dim + 1)
new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1)
else:
new_scale = self.scale_
if self.zero_point_.shape != ():
new_zp = self.zero_point_.expand(curr_shape).flatten(self.group_dim, self.group_dim + 1)
new_zp = self.zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1)
else:
new_zp = self.zero_point_

7 changes: 4 additions & 3 deletions src/brevitas/quant_tensor/groupwise_int_quant_tensor.py
@@ -59,13 +59,14 @@ def __torch_function__(self, func, types, args=(), kwargs=None):

def expand(self):
curr_shape = self.value_.shape
new_value = self.value_.flatten(self.group_dim, self.group_dim + 1)
start_dim = self.group_dim if self.group_dim != -1 else -2
new_value = self.value_.flatten(start_dim, start_dim + 1)
if self.scale_.shape != ():
new_scale = self.scale_.expand(curr_shape).flatten(self.group_dim, self.group_dim + 1)
new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1)
else:
new_scale = self.scale_
if self.zero_point_.shape != ():
new_zp = self.zero_point_.expand(curr_shape).flatten(self.group_dim, self.group_dim + 1)
new_zp = self.zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1)
else:
new_zp = self.zero_point_

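Note: to make the start_dim change concrete, a small sketch of the flatten/broadcast that expand() performs, assuming a grouped layout of (out_channels, num_groups, group_size) with group_dim=1; when group_dim is -1 the same pair of axes sits at the end of the shape, hence the -2 fallback.

import torch

value = torch.randn(64, 4, 8)                             # (out_ch, num_groups, group_size)
scale = torch.rand(64, 4, 1)                              # one scale per group
flat_value = value.flatten(1, 2)                          # (64, 32): groups folded back into the channel axis
flat_scale = scale.expand(value.shape).flatten(1, 2)      # (64, 32): each group's scale repeated group_size times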
