diff --git a/tests/brevitas/proxy/test_proxy.py b/tests/brevitas/proxy/test_proxy.py
new file mode 100644
index 000000000..08b525a71
--- /dev/null
+++ b/tests/brevitas/proxy/test_proxy.py
@@ -0,0 +1,89 @@
+import pytest
+
+from brevitas.nn import QuantLinear
+from brevitas.nn.quant_activation import QuantReLU
+from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant
+from brevitas.quant.scaled_int import Int8BiasPerTensorFloatInternalScaling
+from brevitas.quant.scaled_int import Int8WeightPerChannelFloatDecoupled
+from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
+from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat
+from tests.marker import jit_disabled_for_dynamic_quant_act
+
+
+class TestProxy:
+    """Check the scale, zero-point and bit-width accessors of the quantization proxies."""
+
+    def test_bias_proxy(self):
+        """Bias proxy metadata is available, and None once ``disable_quant`` is set."""
+        model = QuantLinear(10, 5, bias_quant=Int8BiasPerTensorFloatInternalScaling)
+        assert model.bias_quant.scale() is not None
+        assert model.bias_quant.zero_point() is not None
+        assert model.bias_quant.bit_width() is not None
+
+        model.bias_quant.disable_quant = True
+        assert model.bias_quant.scale() is None
+        assert model.bias_quant.zero_point() is None
+        assert model.bias_quant.bit_width() is None
+
+    def test_weight_proxy(self):
+        """Weight proxy metadata is available, and None once ``disable_quant`` is set."""
+        model = QuantLinear(10, 5, weight_quant=Int8WeightPerTensorFloat)
+        assert model.weight_quant.scale() is not None
+        assert model.weight_quant.zero_point() is not None
+        assert model.weight_quant.bit_width() is not None
+
+        model.weight_quant.disable_quant = True
+        assert model.weight_quant.scale() is None
+        assert model.weight_quant.zero_point() is None
+        assert model.weight_quant.bit_width() is None
+
+    def test_weight_decoupled_proxy(self):
+        """Decoupled pre-scale/pre-zero-point become None once ``disable_quant`` is set."""
+        model = QuantLinear(10, 5, weight_quant=Int8WeightPerChannelFloatDecoupled)
+        assert model.weight_quant.pre_scale() is not None
+        assert model.weight_quant.pre_zero_point() is not None
+
+        model.weight_quant.disable_quant = True
+        assert model.weight_quant.pre_scale() is None
+        assert model.weight_quant.pre_zero_point() is None
+
+    def test_weight_decoupled_with_input_proxy(self):
+        """Accumulator-aware weight proxy scale/zero-point accessors raise NotImplementedError."""
+        model = QuantLinear(10, 5, weight_quant=Int8AccumulatorAwareWeightQuant)
+        with pytest.raises(NotImplementedError):
+            model.weight_quant.scale()
+        with pytest.raises(NotImplementedError):
+            model.weight_quant.zero_point()
+
+        with pytest.raises(NotImplementedError):
+            model.weight_quant.pre_scale()
+        with pytest.raises(NotImplementedError):
+            model.weight_quant.pre_zero_point()
+
+    def test_act_proxy(self):
+        """Activation proxy metadata is available, and None once ``disable_quant`` is set."""
+        model = QuantReLU()
+        assert model.act_quant.scale() is not None
+        assert model.act_quant.zero_point() is not None
+        assert model.act_quant.bit_width() is not None
+
+        model.act_quant.disable_quant = True
+        assert model.act_quant.scale() is None
+        assert model.act_quant.zero_point() is None
+        assert model.act_quant.bit_width() is None
+
+    @jit_disabled_for_dynamic_quant_act()
+    def test_dynamic_act_proxy(self):
+        """Dynamic activation proxy rejects scale/zero-point queries but reports bit-width."""
+        model = QuantReLU(Int8DynamicActPerTensorFloat)
+
+        with pytest.raises(RuntimeError, match="Scale for Dynamic Act Quant is input-dependant"):
+            model.act_quant.scale()
+        with pytest.raises(RuntimeError,
+                           match="Zero point for Dynamic Act Quant is input-dependant"):
+            model.act_quant.zero_point()
+
+        assert model.act_quant.bit_width() is not None
+
+        model.act_quant.disable_quant = True
+        assert model.act_quant.bit_width() is None
diff --git a/tests/brevitas/proxy/test_weight_scaling.py b/tests/brevitas/proxy/test_weight_scaling.py
index 074ca7c61..49a7f20fe 100644
--- a/tests/brevitas/proxy/test_weight_scaling.py
+++ b/tests/brevitas/proxy/test_weight_scaling.py
@@ -1,7 +1,6 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
-import pytest
 from torch import nn
 
 from brevitas import config
diff --git a/tests/marker.py b/tests/marker.py
index 59e76a7d2..f11dc7a4a 100644
--- a/tests/marker.py
+++ b/tests/marker.py
@@ -50,5 +50,15 @@ def skip_wrapper(f):
     return skip_wrapper
 
 
+def jit_disabled_for_dynamic_quant_act():
+    """Return a decorator that skips the decorated test when ``config.JIT_ENABLED`` is truthy."""
+    skip = config.JIT_ENABLED
+
+    def skip_wrapper(f):
+        return pytest.mark.skipif(skip, reason='Dynamic Act Quant requires JIT to be disabled')(f)
+
+    return skip_wrapper
+
+
 skip_on_macos_nox = pytest.mark.skipif(
     platform.system() == "Darwin", reason="Known issue with Nox and MacOS.")