diff --git a/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py
index 090fc5345..c5bff8c57 100644
--- a/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py
+++ b/tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py
@@ -13,7 +13,7 @@
 from qonnx.util.basic import gen_finn_dt_tensor
 import torch
 
-from brevitas.export import FINNManager
+from brevitas.export import export_qonnx
 from brevitas.nn import TruncAvgPool2d
 from brevitas.quant_tensor import QuantTensor
 
@@ -50,7 +50,7 @@ def test_brevitas_avg_pool_export(
     # export
     test_id = request.node.callspec.id
     export_path = test_id + '_' + export_onnx_path
-    FINNManager.export(quant_avgpool, export_path=export_path, input_t=input_quant_tensor)
+    export_qonnx(quant_avgpool, export_path=export_path, input_t=input_quant_tensor)
     model = ModelWrapper(export_path)
     model = model.transform(InferShapes())
     model = model.transform(InferDataTypes())
diff --git a/tests/brevitas_finn/brevitas/test_wbiol.py b/tests/brevitas_finn/brevitas/test_wbiol.py
index 9f41d9255..1b0341bd4 100644
--- a/tests/brevitas_finn/brevitas/test_wbiol.py
+++ b/tests/brevitas_finn/brevitas/test_wbiol.py
@@ -15,6 +15,7 @@
 from brevitas.nn import QuantLinear
 import brevitas.onnx as bo
 from brevitas.quant import Int16Bias
+from brevitas.quant.scaled_int import Int8ActPerTensorFloat
 
 
 @pytest.mark.parametrize("bias", [True, False])
@@ -25,10 +26,10 @@
 @pytest.mark.parametrize("channel_scaling", [True, False])
 @pytest.mark.parametrize("i_bits", [2, 4])
 def test_quant_linear(bias, bias_quant, out_features, in_features, w_bits, channel_scaling, i_bits):
-    # required to generated quantized inputs, not part of the exported model to test
-    quant_inp = QuantIdentity(bit_width=i_bits, return_quant_tensor=True)
-    inp_tensor = quant_inp(torch.randn(1, in_features))
+    inp_tensor = torch.randn(1, in_features)
     linear = QuantLinear(
+        input_quant=Int8ActPerTensorFloat,
+        input_bit_width=i_bits,
         out_features=out_features,
         in_features=in_features,
         bias=bias,
@@ -36,12 +37,10 @@ def test_quant_linear(bias, bias_quant, out_features, in_features, w_bits, chann
         weight_bit_width=w_bits,
         weight_scaling_per_output_channel=channel_scaling)
     linear.eval()
-    model = bo.export_qonnx(linear, input_t=inp_tensor, export_path='linear.onnx')
+    model = bo.export_qonnx(linear, inp_tensor, export_path='linear.onnx')
     model = ModelWrapper(model)
     model = model.transform(InferShapes())
-    # the quantized input tensor passed to FINN should be in integer form
-    int_inp_array = inp_tensor.int(float_datatype=True).detach().numpy()
-    idict = {model.graph.input[0].name: int_inp_array}
+    idict = {model.graph.input[0].name: inp_tensor.detach().numpy()}
     odict = oxe.execute_onnx(model, idict, True)
     produced = odict[model.graph.output[0].name]
     expected = linear(inp_tensor).detach().numpy()
@@ -73,11 +72,11 @@ def test_quant_conv2d(
         padding,
         stride,
         i_bits):
-    # required to generated quantized inputs, not part of the exported model to test
-    quant_inp = QuantIdentity(bit_width=i_bits, return_quant_tensor=True)
-    inp_tensor = quant_inp(torch.randn(1, in_channels, in_features, in_features))
+    inp_tensor = torch.randn(1, in_channels, in_features, in_features)
     try:
         conv = QuantConv2d(
+            input_quant=Int8ActPerTensorFloat,
+            input_bit_width=i_bits,
             in_channels=in_channels,
             # out_channels=in_channels if dw else out_channels,
             out_channels=
@@ -105,12 +104,11 @@ def test_quant_conv2d(
         assert False
 
     conv.eval()
-    model = bo.export_qonnx(conv, input_t=inp_tensor)
+    model = bo.export_qonnx(conv, inp_tensor)
     model = ModelWrapper(model)
     model = model.transform(InferShapes())
     # the quantized input tensor passed to FINN should be in integer form
-    int_inp_array = inp_tensor.int(float_datatype=True).detach().numpy()
-    idict = {model.graph.input[0].name: int_inp_array}
+    idict = {model.graph.input[0].name: inp_tensor.detach().numpy()}
     odict = oxe.execute_onnx(model, idict, True)
     produced = odict[model.graph.output[0].name]
     expected = conv(inp_tensor).detach().numpy()