diff --git a/src/brevitas/quant/solver/parameter.py b/src/brevitas/quant/solver/parameter.py
index f8532b71a..1d1e5972a 100644
--- a/src/brevitas/quant/solver/parameter.py
+++ b/src/brevitas/quant/solver/parameter.py
@@ -109,6 +109,7 @@ def scaling_impl(scaling_impl_type):
 
 
 class SolveParameterScalingShape(ExtendedInjector):
 
+    @value
     def scaling_shape(weight_shape, group_dim, group_size=None, scaling_per_output=None):
         if scaling_per_output == ScalingPerOutputType.TENSOR:
diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py
index 8e7f7b097..225321933 100644
--- a/tests/brevitas/graph/equalization_fixtures.py
+++ b/tests/brevitas/graph/equalization_fixtures.py
@@ -1,7 +1,6 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
-from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
 from packaging import version
 import pytest
 import pytest_cases
@@ -9,12 +8,16 @@
 import torch
 import torch.nn as nn
 from torchvision import models
 
-from brevitas.quant.experimental.mx_quant_ocp import MXInt8Act, MXInt8Weight, MXFloat8e4m3Act, MXFloat8e4m3Weight
 from brevitas import torch_version
 from brevitas.graph.equalize import _cross_layer_equalization
 import brevitas.nn as qnn
 from brevitas.quant import Int8ActPerTensorFloat
+from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
+from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight
+from brevitas.quant.experimental.mx_quant_ocp import MXInt8Act
+from brevitas.quant.experimental.mx_quant_ocp import MXInt8Weight
+from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
 
 SEED = 123456
 ATOL = 1e-3
@@ -383,6 +386,7 @@ def forward(self, x):
 input_quant, weight_quant = pytest_cases.param_fixtures(
     "input_quant, weight_quant",
     [(Int8ActPerTensorFloat, Int8WeightPerTensorFloat), (MXInt8Act, MXInt8Weight), (MXFloat8e4m3Act, MXFloat8e4m3Weight)])
+
 
 @pytest_cases.fixture
 def quant_conv_with_input_quant_model(input_quant, weight_quant):
@@ -392,7 +396,8 @@ def __init__(self) -> None:
             super().__init__()
             self.conv_0 = qnn.QuantConv2d(
                 3, 16, kernel_size=3)  # gpxq tests assume no quant on first layer
-            self.conv_1 = qnn.QuantConv2d(16, 32, kernel_size=3, input_quant=input_quant, weight_quant=weight_quant)
+            self.conv_1 = qnn.QuantConv2d(
+                16, 32, kernel_size=3, input_quant=input_quant, weight_quant=weight_quant)
 
         def forward(self, x):
             x = self.conv_0(x)
@@ -430,8 +435,10 @@ class QuantResidualModel(nn.Module):
 
         def __init__(self) -> None:
             super().__init__()
-            self.conv = qnn.QuantConv2d(3, 16, kernel_size=1, input_quant=input_quant, weight_quant=weight_quant)
-            self.conv_0 = qnn.QuantConv2d(16, 3, kernel_size=1, input_quant=input_quant, weight_quant=weight_quant)
+            self.conv = qnn.QuantConv2d(
+                3, 16, kernel_size=1, input_quant=input_quant, weight_quant=weight_quant)
+            self.conv_0 = qnn.QuantConv2d(
+                16, 3, kernel_size=1, input_quant=input_quant, weight_quant=weight_quant)
             self.relu = qnn.QuantReLU(return_quant_tensor=True)
 
         def forward(self, x):
@@ -453,8 +460,18 @@ class QuantConvTransposeModel(nn.Module):
         def __init__(self) -> None:
             super().__init__()
             self.relu = qnn.QuantReLU(return_quant_tensor=True)
-            self.conv_0 = qnn.QuantConvTranspose2d(in_channels=3, out_channels=8, kernel_size=3, input_quant=input_quant, weight_quant=weight_quant)
-            self.conv_1 = qnn.QuantConvTranspose2d(in_channels=8, out_channels=32, kernel_size=3, input_quant=input_quant, weight_quant=weight_quant)
+            self.conv_0 = qnn.QuantConvTranspose2d(
+                in_channels=3,
+                out_channels=8,
+                kernel_size=3,
+                input_quant=input_quant,
+                weight_quant=weight_quant)
+            self.conv_1 = qnn.QuantConvTranspose2d(
+                in_channels=8,
+                out_channels=32,
+                kernel_size=3,
+                input_quant=input_quant,
+                weight_quant=weight_quant)
 
         def forward(self, x):
             x = self.conv_0(x)
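Context for the first hunk: on a Brevitas `ExtendedInjector`, the `@value` decorator marks a function whose arguments are resolved from the injector's other attributes and whose return value becomes the attribute, which is what `SolveParameterScalingShape.scaling_shape` relies on once the decorator is added. Below is a minimal sketch of that pattern, not the actual solver: `ShapeSolver` and the example `weight_shape` are hypothetical, while `ExtendedInjector` and `value` are the real `brevitas.inject` symbols used in the hunk.

```python
from brevitas.inject import ExtendedInjector
from brevitas.inject import value


class ShapeSolver(ExtendedInjector):
    # Plain attributes act as injected dependencies (hypothetical example value).
    weight_shape = (16, 3, 3, 3)

    @value
    def scaling_shape(weight_shape):
        # One scale per output channel, broadcastable over the remaining dims.
        return (weight_shape[0], 1, 1, 1)


# Attribute access on the injector class resolves the dependency graph:
# scaling_shape is computed from weight_shape thanks to @value.
print(ShapeSolver.scaling_shape)  # expected: (16, 1, 1, 1)
```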