diff --git a/tests/brevitas/nn/nn_quantizers_fixture.py b/tests/brevitas/nn/nn_quantizers_fixture.py index fc8308f46..f270c1942 100644 --- a/tests/brevitas/nn/nn_quantizers_fixture.py +++ b/tests/brevitas/nn/nn_quantizers_fixture.py @@ -10,6 +10,7 @@ import torch.nn as nn from brevitas import torch_version +import brevitas.config as config from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d from brevitas.nn import QuantConv3d @@ -21,6 +22,8 @@ from brevitas.nn.quant_mha import QuantMultiheadAttention from brevitas.nn.quant_rnn import QuantLSTM from brevitas.nn.quant_rnn import QuantRNN +from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat +from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat from brevitas.quant.experimental.mx_quant_ocp import MXInt8Act from brevitas.quant.experimental.mx_quant_ocp import MXInt8Weight from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant @@ -36,7 +39,6 @@ from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat from brevitas.quant_tensor import IntQuantTensor -from brevitas.quant_tensor import QuantTensor SEED = 123456 OUT_CH = 16 @@ -61,12 +63,14 @@ 'quant_asym': ShiftedUint8WeightPerTensorFloat, 'quant_decoupled': Int8WeightNormL2PerChannelFixedPoint, 'quant_mx': MXInt8Weight, + 'quant_float': Fp8e4m3WeightPerTensorFloat, **A2Q_WBIOL_WEIGHT_QUANTIZER} WBIOL_IO_QUANTIZER = { 'None': None, 'batch_quant': (Int8ActPerTensorFloatBatchQuant1d, Int8ActPerTensorFloatBatchQuant2d), 'quant_mx': MXInt8Act, + 'quant_float': Fp8e4m3ActPerTensorFloat, 'quant_sym': Int8ActPerTensorFloat, 'quant_asym': ShiftedUint8ActPerTensorFloat} @@ -130,6 +134,8 @@ def build_case_model( io_quantizer == MXInt8Act): pytest.skip("MX requires input and weights quantization to be aligned") elif weight_quantizer == MXInt8Weight: + if config.JIT_ENABLED: + pytest.skip("Dynamic act quant is not compatible with JIT") bias_quantizer = None impl = module.__name__