diff --git a/tests/brevitas/proxy/test_proxy.py b/tests/brevitas/proxy/test_proxy.py new file mode 100644 index 000000000..4adb9a99b --- /dev/null +++ b/tests/brevitas/proxy/test_proxy.py @@ -0,0 +1,65 @@ +import pytest + +from brevitas.nn import QuantLinear +from brevitas.nn.quant_activation import QuantReLU +from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant +from brevitas.quant.scaled_int import Int8BiasPerTensorFloatInternalScaling +from brevitas.quant.scaled_int import Int8WeightPerChannelFloatDecoupled +from brevitas.quant.scaled_int import Int8WeightPerTensorFloat + + +class TestProxy: + + def test_bias_proxy(self): + model = QuantLinear(10, 5, bias_quant=Int8BiasPerTensorFloatInternalScaling) + assert model.weight_quant.scale() is not None + assert model.weight_quant.zero_point() is not None + assert model.weight_quant.bit_width() is not None + + model.weight_quant.disable_quant = True + assert model.weight_quant.scale() is None + assert model.weight_quant.zero_point() is None + assert model.weight_quant.bit_width() is None + + def test_weight_proxy(self): + model = QuantLinear(10, 5, weight_quant=Int8WeightPerTensorFloat) + assert model.weight_quant.scale() is not None + assert model.weight_quant.zero_point() is not None + assert model.weight_quant.bit_width() is not None + + model.weight_quant.disable_quant = True + assert model.weight_quant.scale() is None + assert model.weight_quant.zero_point() is None + assert model.weight_quant.bit_width() is None + + def test_weight_decoupled_proxy(self): + model = QuantLinear(10, 5, weight_quant=Int8WeightPerChannelFloatDecoupled) + assert model.weight_quant.pre_scale() is not None + assert model.weight_quant.pre_zero_point() is not None + + model.weight_quant.disable_quant = True + assert model.weight_quant.pre_scale() is None + assert model.weight_quant.pre_zero_point() is None + + def test_weight_decoupled_with_input_proxy(self): + model = QuantLinear(10, 5, weight_quant=Int8AccumulatorAwareWeightQuant) + with pytest.raises(NotImplementedError): + model.weight_quant.scale() + with pytest.raises(NotImplementedError): + model.weight_quant.zero_point() + + with pytest.raises(NotImplementedError): + model.weight_quant.pre_scale() + with pytest.raises(NotImplementedError): + model.weight_quant.pre_zero_point() + + def test_act_proxy(self): + model = QuantReLU() + assert model.act_quant.scale() is not None + assert model.act_quant.zero_point() is not None + assert model.act_quant.bit_width() is not None + + model.act_quant.disable_quant = True + assert model.act_quant.scale() is None + assert model.act_quant.zero_point() is None + assert model.act_quant.bit_width() is None diff --git a/tests/brevitas/proxy/test_weight_scaling.py b/tests/brevitas/proxy/test_weight_scaling.py index 074ca7c61..49a7f20fe 100644 --- a/tests/brevitas/proxy/test_weight_scaling.py +++ b/tests/brevitas/proxy/test_weight_scaling.py @@ -1,7 +1,6 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -import pytest from torch import nn from brevitas import config