Skip to content

Commit

Permalink
Restore tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 3, 2023
1 parent 944e77b commit 45b5de5
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 166 deletions.
4 changes: 3 additions & 1 deletion src/brevitas/nn/quant_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,9 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
(quant_bias_scale is None or
(quant_bias_scale is not None and
quant_bias_scale.data_ptr() != output_scale.data_ptr()))):
output_scale_broadcast_shape = compute_channel_view_shape(inp, channel_dim=1)
channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1
output_scale_broadcast_shape = compute_channel_view_shape(
inp, channel_dim=channel_dim)
output_zero_point = -quant_bias_value.view(
output_scale_broadcast_shape) / output_scale

Expand Down
332 changes: 167 additions & 165 deletions tests/brevitas/nn/test_nn_quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,171 +30,173 @@ def parse_args(args):
return kwargs


# @pytest_cases.parametrize_with_cases('model_input', cases=case_model)
# def test_quant_wbiol(model_input, current_cases):
# model, input = model_input

# cases_generator_func = current_cases['model_input'][1]
# case_id = get_case_id(cases_generator_func)
# args = case_id.split('-')[1:] # Exclude first argument
# kwargs = parse_args(args)

# is_input_quanttensor = kwargs['io_quant'] is not None or kwargs['input_quantized']

# if (not is_input_quanttensor or kwargs['weight_quant'] is None) and kwargs['return_quant_tensor']:
# with pytest.raises(RuntimeError,
# match='QuantLayer is not correctly configured'):
# output = model(input)
# return
# elif (not is_input_quanttensor or
# kwargs['weight_quant'] is None) and kwargs['bias_quant'] == 'quant_external':
# with pytest.raises(RuntimeError, match='Input scale required'):
# output = model(input)
# return
# elif kwargs['weight_quant'] == 'quant_asym' and kwargs['return_quant_tensor'] and kwargs['io_quant'] is None \
# and kwargs['input_quantized']:
# with pytest.raises(RuntimeError,
# match='Computing zero point of output accumulator not supported yet.'):
# output = model(input)
# return
# else:
# output = model(input)

# if kwargs['return_quant_tensor']:
# assert isinstance(output, QuantTensor)
# # Empty QuantTensor
# if ( not kwargs['input_quantized'] or kwargs['weight_quant'] is None) and \
# kwargs['io_quant'] is None:
# assert output.scale is None
# assert output.bit_width is None
# else: # "Full" QuantTensor
# assert output.scale is not None
# assert output.bit_width is not None
# else:
# assert isinstance(output, torch.Tensor)

# @pytest_cases.parametrize_with_cases(
# 'model_input', cases=[case_quant_lstm_full, case_quant_rnn_full])
# def test_quant_lstm_rnn_full(model_input, current_cases):
# model, input = model_input

# cases_generator_func = current_cases['model_input'][1]
# case_id = get_case_id(cases_generator_func)
# args = case_id.split('-')
# kwargs = parse_args(args)

# is_input_quanttensor = kwargs['io_quant'] is not None or kwargs['input_quantized']

# if (kwargs['bias_quant'] == 'quant_external') and ( \
# (not is_input_quanttensor or kwargs['weight_quant'] is None) or \
# (kwargs['num_layers']> 1 and (kwargs['weight_quant'] is None or kwargs['io_quant'] is None))):
# with pytest.raises(RuntimeError, match='Input scale required'):
# output = model(input)
# return
# else:
# output = model(input)
# if len(output) == 1:
# output = output[0]
# h, c = None, None
# elif len(output) == 2:
# if 'quant_lstm' in args[0]:
# output, (h, c) = output
# else:
# output, h = output
# c = None
# return_quant_tensor = kwargs['return_quant_tensor']

# if return_quant_tensor:
# assert isinstance(output, QuantTensor)
# # Empty QuantTensor
# if ( not kwargs['input_quantized'] or kwargs['weight_quant'] is None) and \
# kwargs['io_quant'] is None:
# assert output.scale is None
# assert output.bit_width is None
# else: # "Full" QuantTensor
# assert output.scale is not None
# assert output.bit_width is not None
# else:
# assert isinstance(output, torch.Tensor)

# if h is not None:
# if (return_quant_tensor or kwargs['num_layers'] == 2) and kwargs['io_quant'] is not None:
# assert isinstance(h, QuantTensor)
# else:
# assert isinstance(h, torch.Tensor)

# if c is not None:
# if kwargs['signed_act'] is None or not kwargs['return_quant_tensor']:
# if not kwargs['bidirectional']:
# if not kwargs['return_quant_tensor'] and kwargs['num_layers'] == 1:
# assert isinstance(c, torch.Tensor)
# else:
# if kwargs['num_layers'] == 2 and kwargs['signed_act'] is None:
# assert isinstance(c, torch.Tensor)
# else:
# assert isinstance(c, QuantTensor)
# else:
# if kwargs['num_layers'] == 2 and kwargs['signed_act'] is not None:
# assert isinstance(c, QuantTensor)
# else:
# assert isinstance(c, torch.Tensor)
# else:
# assert isinstance(c, QuantTensor)

# @pytest_cases.parametrize_with_cases('model_input', cases=[case_quant_lstm, case_quant_rnn])
# def test_quant_lstm_rnn(model_input, current_cases):
# model, input = model_input

# cases_generator_func = current_cases['model_input'][1]
# case_id = get_case_id(cases_generator_func)
# args = case_id.split('-')
# kwargs = parse_args(args)

# is_input_quanttensor = kwargs['io_quant'] is not None or kwargs['input_quantized']

# if (kwargs['bias_quant'] == 'quant_external') and ( \
# (not is_input_quanttensor or kwargs['weight_quant'] is None) or \
# (kwargs['num_layers']> 1 and (kwargs['weight_quant'] is None or kwargs['io_quant'] is None))):
# with pytest.raises(RuntimeError, match='Input scale required'):
# output = model(input)
# return
# else:
# output = model(input)
# if len(output) == 1:
# output = output[0]
# h, c = None, None
# elif len(output) == 2:
# if args[0] == 'quant_lstm':
# output, (h, c) = output
# else:
# output, h = output
# c = None
# return_quant_tensor = kwargs['return_quant_tensor'] and kwargs['io_quant'] is not None

# if return_quant_tensor:
# assert isinstance(output, QuantTensor)
# # Empty QuantTensor
# if ( not kwargs['input_quantized'] or kwargs['weight_quant'] is None) and \
# kwargs['io_quant'] is None:
# assert output.scale is None
# assert output.bit_width is None
# else: # "Full" QuantTensor
# assert output.scale is not None
# assert output.bit_width is not None
# else:
# assert isinstance(output, torch.Tensor)

# if h is not None:
# if (return_quant_tensor or kwargs['num_layers'] == 2) and kwargs['io_quant'] is not None:
# assert isinstance(h, QuantTensor)
# else:
# assert isinstance(h, torch.Tensor)

# if c is not None:
# if (return_quant_tensor or kwargs['num_layers'] == 2) and kwargs['io_quant'] is not None:
# assert isinstance(c, QuantTensor)
# else:
# assert isinstance(c, torch.Tensor)
@pytest_cases.parametrize_with_cases('model_input', cases=case_model)
def test_quant_wbiol(model_input, current_cases):
    """Exercise quant WBIOL layers: verify expected config errors and output type.

    The case id encodes the quantizer configuration; it is parsed back into
    kwargs to decide which RuntimeError (if any) the forward pass must raise,
    and whether the output is a QuantTensor or a plain torch.Tensor.
    """
    model, inp = model_input

    case_gen = current_cases['model_input'][1]
    # Drop the leading case-id token; the remaining tokens encode the config.
    kwargs = parse_args(get_case_id(case_gen).split('-')[1:])

    quant_input = kwargs['io_quant'] is not None or kwargs['input_quantized']
    # No quantized input or no weight quant => no scale available downstream.
    scale_unavailable = not quant_input or kwargs['weight_quant'] is None

    if scale_unavailable and kwargs['return_quant_tensor']:
        with pytest.raises(RuntimeError, match='QuantLayer is not correctly configured'):
            model(inp)
        return
    if scale_unavailable and kwargs['bias_quant'] == 'quant_external':
        with pytest.raises(RuntimeError, match='Input scale required'):
            model(inp)
        return
    if (kwargs['weight_quant'] == 'quant_asym' and kwargs['return_quant_tensor'] and
            kwargs['io_quant'] is None and kwargs['input_quantized']):
        with pytest.raises(RuntimeError,
                           match='Computing zero point of output accumulator not supported yet.'):
            model(inp)
        return

    out = model(inp)

    if not kwargs['return_quant_tensor']:
        assert isinstance(out, torch.Tensor)
        return

    assert isinstance(out, QuantTensor)
    is_empty = ((not kwargs['input_quantized'] or kwargs['weight_quant'] is None) and
                kwargs['io_quant'] is None)
    if is_empty:
        # "Empty" QuantTensor: carries values but no quantization metadata.
        assert out.scale is None
        assert out.bit_width is None
    else:
        # "Full" QuantTensor: metadata must be populated.
        assert out.scale is not None
        assert out.bit_width is not None


@pytest_cases.parametrize_with_cases(
    'model_input', cases=[case_quant_lstm_full, case_quant_rnn_full])
def test_quant_lstm_rnn_full(model_input, current_cases):
    """Exercise fully-parameterized QuantLSTM/QuantRNN forward passes.

    Parses the quantizer configuration back out of the pytest-cases case id,
    then checks that (a) configs lacking an input scale raise the expected
    RuntimeError for external bias quant, and (b) the output and the hidden
    states h/c are QuantTensor or torch.Tensor as the config dictates.
    """
    model, input = model_input

    # The case id string encodes the configuration used to build the model;
    # recover it as a kwargs dict (first token is the model kind, kept in args).
    cases_generator_func = current_cases['model_input'][1]
    case_id = get_case_id(cases_generator_func)
    args = case_id.split('-')
    kwargs = parse_args(args)

    is_input_quanttensor = kwargs['io_quant'] is not None or kwargs['input_quantized']

    # External bias quantization requires an input scale; these configs cannot
    # provide one (no quantized input/weights, or deeper stacks missing io/weight quant).
    if (kwargs['bias_quant'] == 'quant_external') and ( \
        (not is_input_quanttensor or kwargs['weight_quant'] is None) or \
        (kwargs['num_layers']> 1 and (kwargs['weight_quant'] is None or kwargs['io_quant'] is None))):
        with pytest.raises(RuntimeError, match='Input scale required'):
            output = model(input)
        return
    else:
        output = model(input)
        # Unpack: LSTM returns (output, (h, c)), RNN returns (output, h).
        if len(output) == 1:
            output = output[0]
            h, c = None, None
        elif len(output) == 2:
            if 'quant_lstm' in args[0]:
                output, (h, c) = output
            else:
                output, h = output
                c = None
    return_quant_tensor = kwargs['return_quant_tensor']

    if return_quant_tensor:
        assert isinstance(output, QuantTensor)
        # Empty QuantTensor
        if ( not kwargs['input_quantized'] or kwargs['weight_quant'] is None) and \
            kwargs['io_quant'] is None:
            assert output.scale is None
            assert output.bit_width is None
        else:  # "Full" QuantTensor
            assert output.scale is not None
            assert output.bit_width is not None
    else:
        assert isinstance(output, torch.Tensor)

    # Hidden state h is quantized only when io_quant is set and either a
    # QuantTensor is requested or the stack is two layers deep.
    if h is not None:
        if (return_quant_tensor or kwargs['num_layers'] == 2) and kwargs['io_quant'] is not None:
            assert isinstance(h, QuantTensor)
        else:
            assert isinstance(h, torch.Tensor)

    # Cell state c (LSTM only): whether it is quantized depends on the
    # interaction of signed_act, return_quant_tensor, bidirectional and
    # num_layers — the branch order below mirrors that decision tree exactly.
    if c is not None:
        if kwargs['signed_act'] is None or not kwargs['return_quant_tensor']:
            if not kwargs['bidirectional']:
                if not kwargs['return_quant_tensor'] and kwargs['num_layers'] == 1:
                    assert isinstance(c, torch.Tensor)
                else:
                    if kwargs['num_layers'] == 2 and kwargs['signed_act'] is None:
                        assert isinstance(c, torch.Tensor)
                    else:
                        assert isinstance(c, QuantTensor)
            else:
                if kwargs['num_layers'] == 2 and kwargs['signed_act'] is not None:
                    assert isinstance(c, QuantTensor)
                else:
                    assert isinstance(c, torch.Tensor)
        else:
            assert isinstance(c, QuantTensor)


@pytest_cases.parametrize_with_cases('model_input', cases=[case_quant_lstm, case_quant_rnn])
def test_quant_lstm_rnn(model_input, current_cases):
    """Exercise QuantLSTM/QuantRNN: expected errors and output/state types.

    Recovers the quantizer configuration from the case id, checks that
    external bias quant without an input scale raises, then validates the
    types of the output and of the hidden states h and c.
    """
    model, inp = model_input

    case_gen = current_cases['model_input'][1]
    args = get_case_id(case_gen).split('-')
    kwargs = parse_args(args)

    quantized_input = kwargs['io_quant'] is not None or kwargs['input_quantized']

    # External bias quantization needs an input scale; these configurations
    # cannot supply one and must fail the forward pass.
    scale_missing = (not quantized_input or kwargs['weight_quant'] is None) or (
        kwargs['num_layers'] > 1 and
        (kwargs['weight_quant'] is None or kwargs['io_quant'] is None))
    if kwargs['bias_quant'] == 'quant_external' and scale_missing:
        with pytest.raises(RuntimeError, match='Input scale required'):
            model(inp)
        return

    out = model(inp)
    # Unpack: LSTM yields (output, (h, c)); RNN yields (output, h).
    if len(out) == 1:
        out = out[0]
        h, c = None, None
    elif len(out) == 2:
        if args[0] == 'quant_lstm':
            out, (h, c) = out
        else:
            out, h = out
            c = None

    return_quant_tensor = kwargs['return_quant_tensor'] and kwargs['io_quant'] is not None

    if not return_quant_tensor:
        assert isinstance(out, torch.Tensor)
    else:
        assert isinstance(out, QuantTensor)
        if ((not kwargs['input_quantized'] or kwargs['weight_quant'] is None) and
                kwargs['io_quant'] is None):
            # Empty QuantTensor: no quantization metadata attached.
            assert out.scale is None
            assert out.bit_width is None
        else:
            # Full QuantTensor: metadata must be present.
            assert out.scale is not None
            assert out.bit_width is not None

    # h and c share the same quantization condition in this variant.
    state_quantized = (return_quant_tensor or
                       kwargs['num_layers'] == 2) and kwargs['io_quant'] is not None
    if h is not None:
        assert isinstance(h, QuantTensor if state_quantized else torch.Tensor)
    if c is not None:
        assert isinstance(c, QuantTensor if state_quantized else torch.Tensor)


@pytest_cases.parametrize_with_cases('model_input', cases=case_mha)
Expand Down

0 comments on commit 45b5de5

Please sign in to comment.