diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index 93aae0434..b863bf618 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -18,8 +18,6 @@ from brevitas.nn import QuantConvTranspose2d from brevitas.nn import QuantConvTranspose3d from brevitas.nn import QuantLinear -from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZActPerTensorFloat -from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZWeightPerTensorFloat from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint @@ -64,8 +62,7 @@ class A2QWeightQuantizerForTests(Int8AccumulatorAwareWeightQuant): (Int8WeightPerChannelFixedPoint, Int8ActPerTensorFixedPoint), 'weight_symmetric_activation_dynamic_asymmetric_per_tensor_float': (Int8WeightPerTensorFloat, ShiftedUint8DynamicActPerTensorFloat), - 'fp8_ocp_per_tensor_float': (Fp8e4m3OCPWeightPerTensorFloat, Fp8e4m3OCPActPerTensorFloat), - 'fp8_fnuz_per_tensor_float': (Fp8e4m3FNUZWeightPerTensorFloat, Fp8e4m3FNUZActPerTensorFloat)} + 'fp8_ocp_per_tensor_float': (Fp8e4m3OCPWeightPerTensorFloat, Fp8e4m3OCPActPerTensorFloat)} LSTM_QUANTIZERS = { 'asymmetric_per_tensor_float': (ShiftedUint8WeightPerTensorFloat, ShiftedUint8ActPerTensorFloat), diff --git a/tests/brevitas_ort/quant_module_cases.py b/tests/brevitas_ort/quant_module_cases.py index 30838f914..88022aaa6 100644 --- a/tests/brevitas_ort/quant_module_cases.py +++ b/tests/brevitas_ort/quant_module_cases.py @@ -27,7 +27,7 @@ def case_quant_wbiol( set_case_id(request.node.callspec.id, QuantWBIOLCases.case_quant_wbiol) weight_quant, io_quant = quantizers - is_fp8 = weight_quant == Fp8e4m3OCPWeightPerTensorFloat or weight_quant == Fp8e4m3FNUZWeightPerTensorFloat + is_fp8 = weight_quant == Fp8e4m3OCPWeightPerTensorFloat if is_fp8: if weight_bit_width < 8 or input_bit_width < 8 or output_bit_width < 8: pytest.skip('FP8 export requires total bitwidth equal to 8')