Skip to content

Commit

Permalink
Pre-commit fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Aug 28, 2024
1 parent ad54986 commit d45a591
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/brevitas/quant/solver/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def scaling_impl(scaling_impl_type):


class SolveParameterScalingShape(ExtendedInjector):

@value
def scaling_shape(weight_shape, group_dim, group_size=None, scaling_per_output=None):
if scaling_per_output == ScalingPerOutputType.TENSOR:
Expand Down
31 changes: 24 additions & 7 deletions tests/brevitas/graph/equalization_fixtures.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
from packaging import version
import pytest
import pytest_cases
from pytest_cases import fixture_union
import torch
import torch.nn as nn
from torchvision import models
from brevitas.quant.experimental.mx_quant_ocp import MXInt8Act, MXInt8Weight, MXFloat8e4m3Act, MXFloat8e4m3Weight

from brevitas import torch_version
from brevitas.graph.equalize import _cross_layer_equalization
import brevitas.nn as qnn
from brevitas.quant import Int8ActPerTensorFloat
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight
from brevitas.quant.experimental.mx_quant_ocp import MXInt8Act
from brevitas.quant.experimental.mx_quant_ocp import MXInt8Weight
from brevitas.quant.scaled_int import Int8WeightPerTensorFloat

SEED = 123456
ATOL = 1e-3
Expand Down Expand Up @@ -383,6 +386,7 @@ def forward(self, x):

input_quant, weight_quant = pytest_cases.param_fixtures("input_quant, weight_quant", [(Int8ActPerTensorFloat, Int8WeightPerTensorFloat), (MXInt8Act, MXInt8Weight), (MXFloat8e4m3Act, MXFloat8e4m3Weight)])


@pytest_cases.fixture
def quant_conv_with_input_quant_model(input_quant, weight_quant):

Expand All @@ -392,7 +396,8 @@ def __init__(self) -> None:
super().__init__()
self.conv_0 = qnn.QuantConv2d(
3, 16, kernel_size=3) # gpxq tests assume no quant on first layer
self.conv_1 = qnn.QuantConv2d(16, 32, kernel_size=3, input_quant=input_quant, weight_quant=weight_quant)
self.conv_1 = qnn.QuantConv2d(
16, 32, kernel_size=3, input_quant=input_quant, weight_quant=weight_quant)

def forward(self, x):
x = self.conv_0(x)
Expand Down Expand Up @@ -430,8 +435,10 @@ class QuantResidualModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.conv = qnn.QuantConv2d(3, 16, kernel_size=1, input_quant=input_quant, weight_quant=weight_quant)
self.conv_0 = qnn.QuantConv2d(16, 3, kernel_size=1, input_quant=input_quant, weight_quant=weight_quant)
self.conv = qnn.QuantConv2d(
3, 16, kernel_size=1, input_quant=input_quant, weight_quant=weight_quant)
self.conv_0 = qnn.QuantConv2d(
16, 3, kernel_size=1, input_quant=input_quant, weight_quant=weight_quant)
self.relu = qnn.QuantReLU(return_quant_tensor=True)

def forward(self, x):
Expand All @@ -453,8 +460,18 @@ class QuantConvTransposeModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.relu = qnn.QuantReLU(return_quant_tensor=True)
self.conv_0 = qnn.QuantConvTranspose2d(in_channels=3, out_channels=8, kernel_size=3, input_quant=input_quant, weight_quant=weight_quant)
self.conv_1 = qnn.QuantConvTranspose2d(in_channels=8, out_channels=32, kernel_size=3, input_quant=input_quant, weight_quant=weight_quant)
self.conv_0 = qnn.QuantConvTranspose2d(
in_channels=3,
out_channels=8,
kernel_size=3,
input_quant=input_quant,
weight_quant=weight_quant)
self.conv_1 = qnn.QuantConvTranspose2d(
in_channels=8,
out_channels=32,
kernel_size=3,
input_quant=input_quant,
weight_quant=weight_quant)

def forward(self, x):
x = self.conv_0(x)
Expand Down

0 comments on commit d45a591

Please sign in to comment.