From 45b5de5bbf36468f8606ebe0beecf28fb7bdf072 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 3 Nov 2023 19:37:57 +0000 Subject: [PATCH] Restore tests --- src/brevitas/nn/quant_layer.py | 4 +- tests/brevitas/nn/test_nn_quantizers.py | 332 ++++++++++++------------ 2 files changed, 170 insertions(+), 166 deletions(-) diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index 000384f32..a71583f22 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -355,7 +355,9 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe (quant_bias_scale is None or (quant_bias_scale is not None and quant_bias_scale.data_ptr() != output_scale.data_ptr()))): - output_scale_broadcast_shape = compute_channel_view_shape(inp, channel_dim=1) + channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1 + output_scale_broadcast_shape = compute_channel_view_shape( + inp, channel_dim=channel_dim) output_zero_point = -quant_bias_value.view( output_scale_broadcast_shape) / output_scale diff --git a/tests/brevitas/nn/test_nn_quantizers.py b/tests/brevitas/nn/test_nn_quantizers.py index 2113b3350..27f400308 100644 --- a/tests/brevitas/nn/test_nn_quantizers.py +++ b/tests/brevitas/nn/test_nn_quantizers.py @@ -30,171 +30,173 @@ def parse_args(args): return kwargs -# @pytest_cases.parametrize_with_cases('model_input', cases=case_model) -# def test_quant_wbiol(model_input, current_cases): -# model, input = model_input - -# cases_generator_func = current_cases['model_input'][1] -# case_id = get_case_id(cases_generator_func) -# args = case_id.split('-')[1:] # Exclude first argument -# kwargs = parse_args(args) - -# is_input_quanttensor = kwargs['io_quant'] is not None or kwargs['input_quantized'] - -# if (not is_input_quanttensor or kwargs['weight_quant'] is None) and kwargs['return_quant_tensor']: -# with pytest.raises(RuntimeError, -# match='QuantLayer is not correctly configured'): -# output = model(input) -# return -# elif (not is_input_quanttensor or -# kwargs['weight_quant'] is None) and kwargs['bias_quant'] == 'quant_external': -# with pytest.raises(RuntimeError, match='Input scale required'): -# output = model(input) -# return -# elif kwargs['weight_quant'] == 'quant_asym' and kwargs['return_quant_tensor'] and kwargs['io_quant'] is None \ -# and kwargs['input_quantized']: -# with pytest.raises(RuntimeError, -# match='Computing zero point of output accumulator not supported yet.'): -# output = model(input) -# return -# else: -# output = model(input) - -# if kwargs['return_quant_tensor']: -# assert isinstance(output, QuantTensor) -# # Empty QuantTensor -# if ( not kwargs['input_quantized'] or kwargs['weight_quant'] is None) and \ -# kwargs['io_quant'] is None: -# assert output.scale is None -# assert output.bit_width is None -# else: # "Full" QuantTensor -# assert output.scale is not None -# assert output.bit_width is not None -# else: -# assert isinstance(output, torch.Tensor) - -# @pytest_cases.parametrize_with_cases( -# 'model_input', cases=[case_quant_lstm_full, case_quant_rnn_full]) -# def test_quant_lstm_rnn_full(model_input, current_cases): -# model, input = model_input - -# cases_generator_func = current_cases['model_input'][1] -# case_id = get_case_id(cases_generator_func) -# args = case_id.split('-') -# kwargs = parse_args(args) - -# is_input_quanttensor = kwargs['io_quant'] is not None or kwargs['input_quantized'] - -# if (kwargs['bias_quant'] == 'quant_external') and ( \ -# (not is_input_quanttensor or kwargs['weight_quant'] is None) or \ -# (kwargs['num_layers']> 1 and (kwargs['weight_quant'] is None or kwargs['io_quant'] is None))): -# with pytest.raises(RuntimeError, match='Input scale required'): -# output = model(input) -# return -# else: -# output = model(input) -# if len(output) == 1: -# output = output[0] -# h, c = None, None -# elif len(output) == 2: -# if 'quant_lstm' in args[0]: -# output, (h, c) = output -# else: -# output, h = output -# c = None -# return_quant_tensor = kwargs['return_quant_tensor'] - -# if return_quant_tensor: -# assert isinstance(output, QuantTensor) -# # Empty QuantTensor -# if ( not kwargs['input_quantized'] or kwargs['weight_quant'] is None) and \ -# kwargs['io_quant'] is None: -# assert output.scale is None -# assert output.bit_width is None -# else: # "Full" QuantTensor -# assert output.scale is not None -# assert output.bit_width is not None -# else: -# assert isinstance(output, torch.Tensor) - -# if h is not None: -# if (return_quant_tensor or kwargs['num_layers'] == 2) and kwargs['io_quant'] is not None: -# assert isinstance(h, QuantTensor) -# else: -# assert isinstance(h, torch.Tensor) - -# if c is not None: -# if kwargs['signed_act'] is None or not kwargs['return_quant_tensor']: -# if not kwargs['bidirectional']: -# if not kwargs['return_quant_tensor'] and kwargs['num_layers'] == 1: -# assert isinstance(c, torch.Tensor) -# else: -# if kwargs['num_layers'] == 2 and kwargs['signed_act'] is None: -# assert isinstance(c, torch.Tensor) -# else: -# assert isinstance(c, QuantTensor) -# else: -# if kwargs['num_layers'] == 2 and kwargs['signed_act'] is not None: -# assert isinstance(c, QuantTensor) -# else: -# assert isinstance(c, torch.Tensor) -# else: -# assert isinstance(c, QuantTensor) - -# @pytest_cases.parametrize_with_cases('model_input', cases=[case_quant_lstm, case_quant_rnn]) -# def test_quant_lstm_rnn(model_input, current_cases): -# model, input = model_input - -# cases_generator_func = current_cases['model_input'][1] -# case_id = get_case_id(cases_generator_func) -# args = case_id.split('-') -# kwargs = parse_args(args) - -# is_input_quanttensor = kwargs['io_quant'] is not None or kwargs['input_quantized'] - -# if (kwargs['bias_quant'] == 'quant_external') and ( \ -# (not is_input_quanttensor or kwargs['weight_quant'] is None) or \ -# (kwargs['num_layers']> 1 and (kwargs['weight_quant'] is None or kwargs['io_quant'] is None))): -# with pytest.raises(RuntimeError, match='Input scale required'): -# output = model(input) -# return -# else: -# output = model(input) -# if len(output) == 1: -# output = output[0] -# h, c = None, None -# elif len(output) == 2: -# if args[0] == 'quant_lstm': -# output, (h, c) = output -# else: -# output, h = output -# c = None -# return_quant_tensor = kwargs['return_quant_tensor'] and kwargs['io_quant'] is not None - -# if return_quant_tensor: -# assert isinstance(output, QuantTensor) -# # Empty QuantTensor -# if ( not kwargs['input_quantized'] or kwargs['weight_quant'] is None) and \ -# kwargs['io_quant'] is None: -# assert output.scale is None -# assert output.bit_width is None -# else: # "Full" QuantTensor -# assert output.scale is not None -# assert output.bit_width is not None -# else: -# assert isinstance(output, torch.Tensor) - -# if h is not None: -# if (return_quant_tensor or kwargs['num_layers'] == 2) and kwargs['io_quant'] is not None: -# assert isinstance(h, QuantTensor) -# else: -# assert isinstance(h, torch.Tensor) - -# if c is not None: -# if (return_quant_tensor or kwargs['num_layers'] == 2) and kwargs['io_quant'] is not None: -# assert isinstance(c, QuantTensor) -# else: -# assert isinstance(c, torch.Tensor) +@pytest_cases.parametrize_with_cases('model_input', cases=case_model) +def test_quant_wbiol(model_input, current_cases): + model, input = model_input + + cases_generator_func = current_cases['model_input'][1] + case_id = get_case_id(cases_generator_func) + args = case_id.split('-')[1:] # Exclude first argument + kwargs = parse_args(args) + + is_input_quanttensor = kwargs['io_quant'] is not None or kwargs['input_quantized'] + + if (not is_input_quanttensor or + kwargs['weight_quant'] is None) and kwargs['return_quant_tensor']: + with pytest.raises(RuntimeError, match='QuantLayer is not correctly configured'): + output = model(input) + return + elif (not is_input_quanttensor or + kwargs['weight_quant'] is None) and kwargs['bias_quant'] == 'quant_external': + with pytest.raises(RuntimeError, match='Input scale required'): + output = model(input) + return + elif kwargs['weight_quant'] == 'quant_asym' and kwargs['return_quant_tensor'] and kwargs['io_quant'] is None \ + and kwargs['input_quantized']: + with pytest.raises(RuntimeError, + match='Computing zero point of output accumulator not supported yet.'): + output = model(input) + return + else: + output = model(input) + + if kwargs['return_quant_tensor']: + assert isinstance(output, QuantTensor) + # Empty QuantTensor + if ( not kwargs['input_quantized'] or kwargs['weight_quant'] is None) and \ + kwargs['io_quant'] is None: + assert output.scale is None + assert output.bit_width is None + else: # "Full" QuantTensor + assert output.scale is not None + assert output.bit_width is not None + else: + assert isinstance(output, torch.Tensor) + + +@pytest_cases.parametrize_with_cases( + 'model_input', cases=[case_quant_lstm_full, case_quant_rnn_full]) +def test_quant_lstm_rnn_full(model_input, current_cases): + model, input = model_input + + cases_generator_func = current_cases['model_input'][1] + case_id = get_case_id(cases_generator_func) + args = case_id.split('-') + kwargs = parse_args(args) + + is_input_quanttensor = kwargs['io_quant'] is not None or kwargs['input_quantized'] + + if (kwargs['bias_quant'] == 'quant_external') and ( \ + (not is_input_quanttensor or kwargs['weight_quant'] is None) or \ + (kwargs['num_layers']> 1 and (kwargs['weight_quant'] is None or kwargs['io_quant'] is None))): + with pytest.raises(RuntimeError, match='Input scale required'): + output = model(input) + return + else: + output = model(input) + if len(output) == 1: + output = output[0] + h, c = None, None + elif len(output) == 2: + if 'quant_lstm' in args[0]: + output, (h, c) = output + else: + output, h = output + c = None + return_quant_tensor = kwargs['return_quant_tensor'] + + if return_quant_tensor: + assert isinstance(output, QuantTensor) + # Empty QuantTensor + if ( not kwargs['input_quantized'] or kwargs['weight_quant'] is None) and \ + kwargs['io_quant'] is None: + assert output.scale is None + assert output.bit_width is None + else: # "Full" QuantTensor + assert output.scale is not None + assert output.bit_width is not None + else: + assert isinstance(output, torch.Tensor) + + if h is not None: + if (return_quant_tensor or kwargs['num_layers'] == 2) and kwargs['io_quant'] is not None: + assert isinstance(h, QuantTensor) + else: + assert isinstance(h, torch.Tensor) + + if c is not None: + if kwargs['signed_act'] is None or not kwargs['return_quant_tensor']: + if not kwargs['bidirectional']: + if not kwargs['return_quant_tensor'] and kwargs['num_layers'] == 1: + assert isinstance(c, torch.Tensor) + else: + if kwargs['num_layers'] == 2 and kwargs['signed_act'] is None: + assert isinstance(c, torch.Tensor) + else: + assert isinstance(c, QuantTensor) + else: + if kwargs['num_layers'] == 2 and kwargs['signed_act'] is not None: + assert isinstance(c, QuantTensor) + else: + assert isinstance(c, torch.Tensor) + else: + assert isinstance(c, QuantTensor) + + +@pytest_cases.parametrize_with_cases('model_input', cases=[case_quant_lstm, case_quant_rnn]) +def test_quant_lstm_rnn(model_input, current_cases): + model, input = model_input + + cases_generator_func = current_cases['model_input'][1] + case_id = get_case_id(cases_generator_func) + args = case_id.split('-') + kwargs = parse_args(args) + + is_input_quanttensor = kwargs['io_quant'] is not None or kwargs['input_quantized'] + + if (kwargs['bias_quant'] == 'quant_external') and ( \ + (not is_input_quanttensor or kwargs['weight_quant'] is None) or \ + (kwargs['num_layers']> 1 and (kwargs['weight_quant'] is None or kwargs['io_quant'] is None))): + with pytest.raises(RuntimeError, match='Input scale required'): + output = model(input) + return + else: + output = model(input) + if len(output) == 1: + output = output[0] + h, c = None, None + elif len(output) == 2: + if args[0] == 'quant_lstm': + output, (h, c) = output + else: + output, h = output + c = None + return_quant_tensor = kwargs['return_quant_tensor'] and kwargs['io_quant'] is not None + + if return_quant_tensor: + assert isinstance(output, QuantTensor) + # Empty QuantTensor + if ( not kwargs['input_quantized'] or kwargs['weight_quant'] is None) and \ + kwargs['io_quant'] is None: + assert output.scale is None + assert output.bit_width is None + else: # "Full" QuantTensor + assert output.scale is not None + assert output.bit_width is not None + else: + assert isinstance(output, torch.Tensor) + + if h is not None: + if (return_quant_tensor or kwargs['num_layers'] == 2) and kwargs['io_quant'] is not None: + assert isinstance(h, QuantTensor) + else: + assert isinstance(h, torch.Tensor) + + if c is not None: + if (return_quant_tensor or kwargs['num_layers'] == 2) and kwargs['io_quant'] is not None: + assert isinstance(c, QuantTensor) + else: + assert isinstance(c, torch.Tensor) @pytest_cases.parametrize_with_cases('model_input', cases=case_mha)