diff --git a/tests/brevitas/quant_tensor/test_quant_tensor.py b/tests/brevitas/quant_tensor/test_quant_tensor.py index da3b06d24..4e3401057 100644 --- a/tests/brevitas/quant_tensor/test_quant_tensor.py +++ b/tests/brevitas/quant_tensor/test_quant_tensor.py @@ -2,15 +2,15 @@ # SPDX-License-Identifier: BSD-3-Clause from enum import Enum +from packaging import version import pytest import torch -from brevitas.inject.enum import QuantType +from brevitas import torch_version from brevitas.nn import QuantIdentity from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPActPerTensorFloat from brevitas.quant_tensor import FloatQuantTensor from brevitas.quant_tensor import IntQuantTensor -from brevitas.quant_tensor import QuantTensor class Operator(Enum): @@ -49,6 +49,10 @@ def test_quant_tensor_init(): 'op', [Operator.ADD, Operator.SUBTRACT, Operator.DIVIDE, Operator.MULTIPLY, Operator.MATMUL]) @pytest.mark.parametrize('quant_fn', [to_quant_tensor, to_float_quant_tensor]) def test_quant_tensor_operators(op, quant_fn): + + if quant_fn == to_float_quant_tensor and torch_version < version.parse('1.13'): + pytest.skip("Torch 1.13 is required for JIT to be compatible with FloatQuantTensor") + # Avoid 0 values x = 1 + torch.rand(4, 4)