Commit

fix tests and more tests
Giuseppe5 committed Aug 31, 2024
1 parent f4e0ed8 commit 1426481
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion tests/brevitas/nn/nn_quantizers_fixture.py
@@ -10,6 +10,7 @@
 import torch.nn as nn
 
 from brevitas import torch_version
+import brevitas.config as config
 from brevitas.nn import QuantConv1d
 from brevitas.nn import QuantConv2d
 from brevitas.nn import QuantConv3d
@@ -21,6 +22,8 @@
 from brevitas.nn.quant_mha import QuantMultiheadAttention
 from brevitas.nn.quant_rnn import QuantLSTM
 from brevitas.nn.quant_rnn import QuantRNN
+from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
+from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat
 from brevitas.quant.experimental.mx_quant_ocp import MXInt8Act
 from brevitas.quant.experimental.mx_quant_ocp import MXInt8Weight
 from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant
@@ -36,7 +39,6 @@
 from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat
 from brevitas.quant_tensor import IntQuantTensor
-from brevitas.quant_tensor import QuantTensor
 
 SEED = 123456
 OUT_CH = 16
@@ -61,12 +63,14 @@
     'quant_asym': ShiftedUint8WeightPerTensorFloat,
     'quant_decoupled': Int8WeightNormL2PerChannelFixedPoint,
     'quant_mx': MXInt8Weight,
+    'quant_float': Fp8e4m3WeightPerTensorFloat,
     **A2Q_WBIOL_WEIGHT_QUANTIZER}
 
 WBIOL_IO_QUANTIZER = {
     'None': None,
     'batch_quant': (Int8ActPerTensorFloatBatchQuant1d, Int8ActPerTensorFloatBatchQuant2d),
     'quant_mx': MXInt8Act,
+    'quant_float': Fp8e4m3ActPerTensorFloat,
     'quant_sym': Int8ActPerTensorFloat,
     'quant_asym': ShiftedUint8ActPerTensorFloat}
 
@@ -130,6 +134,8 @@ def build_case_model(
             io_quantizer == MXInt8Act):
         pytest.skip("MX requires input and weights quantization to be aligned")
     elif weight_quantizer == MXInt8Weight:
+        if config.JIT_ENABLED:
+            pytest.skip("Dynamic act quant is not compatible with JIT")
         bias_quantizer = None
 
     impl = module.__name__
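The two new 'quant_float' entries route the FP8 (e4m3) quantizers through the same test paths as the integer quantizers. A minimal sketch of what the fixture ends up exercising, assuming the quantizer classes are passed directly as Brevitas layer keyword arguments (the layer shape values here are illustrative only):

import torch

from brevitas.nn import QuantConv2d
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat

# 'quant_float' pairs per-tensor FP8 e4m3 weight quantization with
# per-tensor FP8 e4m3 input quantization on a standard quantized layer.
layer = QuantConv2d(
    in_channels=3,
    out_channels=16,
    kernel_size=3,
    weight_quant=Fp8e4m3WeightPerTensorFloat,
    input_quant=Fp8e4m3ActPerTensorFloat)

# Forward pass runs with FP8-quantized weights and inputs.
out = layer(torch.randn(1, 3, 8, 8))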

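The new guard skips the MX weight tests when Brevitas's JIT mode is on, since, per the skip message, the dynamic activation quantization used by MX is not compatible with JIT compilation. A minimal sketch of how that flag is toggled, assuming config.JIT_ENABLED is driven by the BREVITAS_JIT environment variable read at import time:

# Illustrative assumption: BREVITAS_JIT must be set before brevitas is imported.
import os
os.environ['BREVITAS_JIT'] = '1'

import brevitas.config as config

# Tests guarded by config.JIT_ENABLED (like the MX weight cases above)
# would now be skipped.
print(config.JIT_ENABLED)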