GPTQ fix
Giuseppe5 committed Aug 28, 2024
1 parent b9ba558 commit e2cd495
Showing 5 changed files with 47 additions and 28 deletions.
17 changes: 13 additions & 4 deletions src/brevitas/core/function_wrapper/shape.py
@@ -154,16 +154,25 @@ def forward(self, x: torch.Tensor):


 class OverSubChannelBlockView(brevitas.jit.ScriptModule):
-    __constants__ = ['expanded_scaling_shape']
+    __constants__ = ['expanded_scaling_shape', 'group_size', 'group_dim']

-    def __init__(self, expanded_scaling_shape, padding) -> None:
+    def __init__(self, expanded_scaling_shape, group_size, group_dim) -> None:
         super(OverSubChannelBlockView, self).__init__()
         self.expanded_scaling_shape = expanded_scaling_shape
-        self.padding = padding
+        self.group_dim = group_dim
+        self.group_size = group_size
+
+    def padding(self, x):
+        padding = [0, 0] * len(x.shape)
+        size = x.shape
+        if size[self.group_dim] % self.group_size != 0:
+            padding[2 * self.group_dim] = self.group_size - size[self.group_dim] % self.group_size
+        padding = list(reversed(padding))
+        return padding

     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor):
-        y = torch.nn.functional.pad(x, self.padding, mode='constant', value=0)
+        y = torch.nn.functional.pad(x, self.padding(x), mode='constant', value=0)
         y = y.view(self.expanded_scaling_shape)
         return y

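Aside (not part of the commit): a minimal sketch of the new runtime padding, assuming a hypothetical (64, 100) weight grouped along dim 1 with group_size=32. The group dimension is padded up to the next multiple of group_size before the view to expanded_scaling_shape:

import torch

# Hypothetical example values; the real shapes come from the quantized layer.
weight = torch.randn(64, 100)
group_dim, group_size = 1, 32

# Same computation as the new OverSubChannelBlockView.padding(): pad only
# group_dim, on the right, up to the next multiple of group_size.
padding = [0, 0] * len(weight.shape)
if weight.shape[group_dim] % group_size != 0:
    padding[2 * group_dim] = group_size - weight.shape[group_dim] % group_size
padding = list(reversed(padding))  # F.pad expects last-dim-first ordering

padded = torch.nn.functional.pad(weight, padding, mode='constant', value=0)
print(padded.shape)              # torch.Size([64, 128])

# expanded_scaling_shape inserts the group_size axis right after group_dim.
print(padded.view(64, 4, 32).shape)  # torch.Size([64, 4, 32])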
6 changes: 3 additions & 3 deletions src/brevitas/graph/utils.py
@@ -26,13 +26,13 @@
     'get_output_channels',
     'get_output_channel_dim']

-CONV_TRANSPOSED = [
+CONV_TRANSPOSED = (
     nn.ConvTranspose1d,
     nn.ConvTranspose2d,
     nn.ConvTranspose3d,
     qnn.QuantConvTranspose1d,
     qnn.QuantConvTranspose2d,
-    qnn.QuantConvTranspose3d]
+    qnn.QuantConvTranspose3d)


 def module_class_name(m: torch.nn.Module):
@@ -146,7 +146,7 @@ def matches_module_pattern(pattern: Iterable, node: Node, modules: Dict[str, Any


 def is_conv_transposed(module):
-    return isinstance(module, tuple(CONV_TRANSPOSED))
+    return isinstance(module, CONV_TRANSPOSED)


 def get_output_channel_dim(module):
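Aside (not part of the commit): isinstance accepts a tuple of classes directly, so storing CONV_TRANSPOSED as a tuple removes the tuple(...) conversion on every call. A minimal standalone sketch:

import torch.nn as nn

CONV_TRANSPOSED = (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)

def is_conv_transposed(module):
    # A tuple of classes is a valid second argument to isinstance().
    return isinstance(module, CONV_TRANSPOSED)

print(is_conv_transposed(nn.ConvTranspose2d(3, 8, 3)))  # True
print(is_conv_transposed(nn.Conv2d(3, 8, 3)))           # False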
6 changes: 2 additions & 4 deletions src/brevitas/proxy/quant_proxy.py
@@ -14,17 +14,15 @@
 from brevitas.core.utils import StatelessBuffer
 from brevitas.inject import BaseInjector as Injector
 from brevitas.utils.quant_utils import float_to_int_impl_to_enum
+from brevitas.core.scaling import ScalingPerOutputType

 __all__ = [
     'QuantProxyProtocol',
     'QuantProxyFromInjector',]


 def _is_groupwise(quant_injector):
-    if 'group_size' in quant_injector:
-        return True
-    else:
-        return False
+    return 'scaling_per_output' in quant_injector and quant_injector.scaling_per_output == ScalingPerOutputType.GROUP


 def _is_narrow_range(quant_injector):
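Aside (not part of the commit): under the new check, a quantizer counts as groupwise only when it selects ScalingPerOutputType.GROUP, not merely because a group_size attribute is present. A minimal sketch with a hypothetical stand-in for the quant injector:

from brevitas.core.scaling import ScalingPerOutputType

class FakeInjector:  # hypothetical stand-in; real code receives a quant injector
    scaling_per_output = ScalingPerOutputType.GROUP
    group_size = 32

    def __contains__(self, name):
        return hasattr(self, name)

def _is_groupwise(quant_injector):
    return 'scaling_per_output' in quant_injector and \
        quant_injector.scaling_per_output == ScalingPerOutputType.GROUP

print(_is_groupwise(FakeInjector()))  # True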
26 changes: 17 additions & 9 deletions src/brevitas/quant/solver/parameter.py
@@ -109,32 +109,40 @@ def scaling_impl(scaling_impl_type):


 class SolveParameterScalingShape(ExtendedInjector):

     @value
-    def scaling_shape(module, group_dim, group_size=None, scaling_per_output=None):
+    def scaling_shape(weight_shape, group_dim, group_size=None, scaling_per_output=None):
         if scaling_per_output == ScalingPerOutputType.TENSOR:
             return SCALAR_SHAPE
         elif scaling_per_output == ScalingPerOutputType.CHANNEL:
             return this.scaling_per_output_channel_shape
         elif scaling_per_output == ScalingPerOutputType.GROUP:
             assert group_size is not None, "Per Group scaling requires group size"
             assert group_dim is not None, "Per Group scaling requires group dim"
-            size = list(module.weight.shape)
+            size = list(weight_shape)
             size[group_dim] = (size[group_dim] + group_size - 1) // group_size
             size.insert(group_dim + 1, 1)
-            return size
+            return tuple(size)

     @value
-    def reshaped_scaling_shape(module):
-        return module.weight.shape
+    def reshaped_scaling_shape(expanded_scaling_shape, group_dim, group_size):
+        new_shape = list(expanded_scaling_shape)
+        del new_shape[group_dim + 1]  # delete the group_size shape
+        # Expand the group_dim shape, accounting for padding
+        new_shape[group_dim] = new_shape[group_dim] * group_size
+        return new_shape

     @value
-    def expanded_scaling_shape(module, group_dim, group_size=None):
+    def expanded_scaling_shape(weight_shape, group_dim, group_size=None):
         assert group_size is not None, "Per Group scaling requires group size"
-        size = list(module.weight.shape)
+        size = list(weight_shape)
         size[group_dim] = (size[group_dim] + group_size - 1) // group_size
         size.insert(group_dim + 1, group_size)
-        return size
+        return tuple(size)

+    @value
+    def weight_shape(tracked_parameter_list):
+        module = tracked_parameter_list[0]
+        return tuple(module.shape)
+
     @value
     def padding(module, group_dim, group_size):
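Aside (not part of the commit): a worked example of the three shapes these @value functions now derive from weight_shape instead of module.weight, assuming a hypothetical (64, 100) weight with group_dim=1 and group_size=32:

weight_shape, group_dim, group_size = (64, 100), 1, 32

# expanded_scaling_shape: round group_dim up to whole groups, insert the group_size axis.
size = list(weight_shape)
size[group_dim] = (size[group_dim] + group_size - 1) // group_size
size.insert(group_dim + 1, group_size)
expanded = tuple(size)
print(expanded)                       # (64, 4, 32)

# scaling_shape: same layout, but the inserted axis has length 1 (one scale per group).
print((expanded[0], expanded[1], 1))  # (64, 4, 1)

# reshaped_scaling_shape: drop the group_size axis and re-expand group_dim;
# 128 reflects the padded size rather than the original 100.
reshaped = list(expanded)
del reshaped[group_dim + 1]
reshaped[group_dim] = reshaped[group_dim] * group_size
print(reshaped)                       # [64, 128]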
20 changes: 12 additions & 8 deletions tests/brevitas/graph/equalization_fixtures.py
@@ -1,13 +1,15 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause

+from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
 from packaging import version
 import pytest
 import pytest_cases
 from pytest_cases import fixture_union
 import torch
 import torch.nn as nn
 from torchvision import models
+from brevitas.quant.experimental.mx_quant_ocp import MXInt8Act, MXInt8Weight, MXFloat8e4m3Act, MXFloat8e4m3Weight

 from brevitas import torch_version
 from brevitas.graph.equalize import _cross_layer_equalization
@@ -379,16 +381,18 @@ def forward(self, x):
     [('layer4.0.bn2', 'layer4.0.downsample.1', 'layer4.1.bn2'), ('fc', 'layer4.1.conv1')],]


+input_quant, weight_quant = pytest_cases.param_fixtures("input_quant, weight_quant", [(Int8ActPerTensorFloat, Int8WeightPerTensorFloat), (MXInt8Act, MXInt8Weight), (MXFloat8e4m3Act, MXFloat8e4m3Weight)])
+
 @pytest_cases.fixture
-def quant_conv_with_input_quant_model():
+def quant_conv_with_input_quant_model(input_quant, weight_quant):

     class QuantConvModel(nn.Module):

         def __init__(self) -> None:
             super().__init__()
             self.conv_0 = qnn.QuantConv2d(
                 3, 16, kernel_size=3)  # gpxq tests assume no quant on first layer
-            self.conv_1 = qnn.QuantConv2d(16, 32, kernel_size=3, input_quant=Int8ActPerTensorFloat)
+            self.conv_1 = qnn.QuantConv2d(16, 32, kernel_size=3, input_quant=input_quant, weight_quant=weight_quant)

         def forward(self, x):
             x = self.conv_0(x)
@@ -420,14 +424,14 @@ def forward(self, x):


 @pytest_cases.fixture
-def quant_residual_model():
+def quant_residual_model(input_quant, weight_quant):

     class QuantResidualModel(nn.Module):

         def __init__(self) -> None:
             super().__init__()
-            self.conv = qnn.QuantConv2d(3, 16, kernel_size=1)
-            self.conv_0 = qnn.QuantConv2d(16, 3, kernel_size=1)
+            self.conv = qnn.QuantConv2d(3, 16, kernel_size=1, input_quant=input_quant, weight_quant=weight_quant)
+            self.conv_0 = qnn.QuantConv2d(16, 3, kernel_size=1, input_quant=input_quant, weight_quant=weight_quant)
             self.relu = qnn.QuantReLU(return_quant_tensor=True)

         def forward(self, x):
@@ -442,15 +446,15 @@ def forward(self, x):


 @pytest_cases.fixture
-def quant_convtranspose_model():
+def quant_convtranspose_model(input_quant, weight_quant):

     class QuantConvTransposeModel(nn.Module):

         def __init__(self) -> None:
             super().__init__()
             self.relu = qnn.QuantReLU(return_quant_tensor=True)
-            self.conv_0 = qnn.QuantConvTranspose2d(in_channels=3, out_channels=8, kernel_size=3)
-            self.conv_1 = qnn.QuantConvTranspose2d(in_channels=8, out_channels=32, kernel_size=3)
+            self.conv_0 = qnn.QuantConvTranspose2d(in_channels=3, out_channels=8, kernel_size=3, input_quant=input_quant, weight_quant=weight_quant)
+            self.conv_1 = qnn.QuantConvTranspose2d(in_channels=8, out_channels=32, kernel_size=3, input_quant=input_quant, weight_quant=weight_quant)

         def forward(self, x):
             x = self.conv_0(x)
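Aside (not part of the commit): a minimal, self-contained sketch of the pytest_cases parametrization pattern the fixtures now rely on, with the quantizer classes replaced by plain strings so the example runs standalone:

import pytest_cases

input_quant, weight_quant = pytest_cases.param_fixtures(
    "input_quant, weight_quant",
    [("int8_act", "int8_weight"), ("mx_int8_act", "mx_int8_weight")])

@pytest_cases.fixture
def quant_pair(input_quant, weight_quant):
    # Requested fixtures receive each parametrized value pair in turn.
    return input_quant, weight_quant

def test_quant_pair(quant_pair):
    # Collected once per (input_quant, weight_quant) tuple above.
    assert all(q is not None for q in quant_pair)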
