Skip to content

Commit

Permalink
Fix test_wbiol
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 13, 2023
1 parent 9b30b18 commit baccb2c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 15 deletions.
4 changes: 2 additions & 2 deletions tests/brevitas_finn/brevitas/test_brevitas_avg_pool_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from qonnx.util.basic import gen_finn_dt_tensor
import torch

from brevitas.export import FINNManager
from brevitas.export import export_qonnx
from brevitas.nn import TruncAvgPool2d
from brevitas.quant_tensor import QuantTensor

Expand Down Expand Up @@ -50,7 +50,7 @@ def test_brevitas_avg_pool_export(
# export
test_id = request.node.callspec.id
export_path = test_id + '_' + export_onnx_path
FINNManager.export(quant_avgpool, export_path=export_path, input_t=input_quant_tensor)
export_qonnx(quant_avgpool, export_path=export_path, input_t=input_quant_tensor)
model = ModelWrapper(export_path)
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
Expand Down
24 changes: 11 additions & 13 deletions tests/brevitas_finn/brevitas/test_wbiol.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from brevitas.nn import QuantLinear
import brevitas.onnx as bo
from brevitas.quant import Int16Bias
from brevitas.quant.scaled_int import Int8ActPerTensorFloat


@pytest.mark.parametrize("bias", [True, False])
Expand All @@ -25,23 +26,21 @@
@pytest.mark.parametrize("channel_scaling", [True, False])
@pytest.mark.parametrize("i_bits", [2, 4])
def test_quant_linear(bias, bias_quant, out_features, in_features, w_bits, channel_scaling, i_bits):
# required to generated quantized inputs, not part of the exported model to test
quant_inp = QuantIdentity(bit_width=i_bits, return_quant_tensor=True)
inp_tensor = quant_inp(torch.randn(1, in_features))
inp_tensor = torch.randn(1, in_features)
linear = QuantLinear(
input_quant=Int8ActPerTensorFloat,
input_bit_width=i_bits,
out_features=out_features,
in_features=in_features,
bias=bias,
bias_quant=bias_quant,
weight_bit_width=w_bits,
weight_scaling_per_output_channel=channel_scaling)
linear.eval()
model = bo.export_qonnx(linear, input_t=inp_tensor, export_path='linear.onnx')
model = bo.export_qonnx(linear, inp_tensor, export_path='linear.onnx')
model = ModelWrapper(model)
model = model.transform(InferShapes())
# the quantized input tensor passed to FINN should be in integer form
int_inp_array = inp_tensor.int(float_datatype=True).detach().numpy()
idict = {model.graph.input[0].name: int_inp_array}
idict = {model.graph.input[0].name: inp_tensor.detach().numpy()}
odict = oxe.execute_onnx(model, idict, True)
produced = odict[model.graph.output[0].name]
expected = linear(inp_tensor).detach().numpy()
Expand Down Expand Up @@ -73,11 +72,11 @@ def test_quant_conv2d(
padding,
stride,
i_bits):
# required to generated quantized inputs, not part of the exported model to test
quant_inp = QuantIdentity(bit_width=i_bits, return_quant_tensor=True)
inp_tensor = quant_inp(torch.randn(1, in_channels, in_features, in_features))
inp_tensor = torch.randn(1, in_channels, in_features, in_features)
try:
conv = QuantConv2d(
input_quant=Int8ActPerTensorFloat,
input_bit_width=i_bits,
in_channels=in_channels,
# out_channels=in_channels if dw else out_channels,
out_channels=
Expand Down Expand Up @@ -105,12 +104,11 @@ def test_quant_conv2d(
assert False

conv.eval()
model = bo.export_qonnx(conv, input_t=inp_tensor)
model = bo.export_qonnx(conv, inp_tensor)
model = ModelWrapper(model)
model = model.transform(InferShapes())
# the quantized input tensor passed to FINN should be in integer form
int_inp_array = inp_tensor.int(float_datatype=True).detach().numpy()
idict = {model.graph.input[0].name: int_inp_array}
idict = {model.graph.input[0].name: inp_tensor.detach().numpy()}
odict = oxe.execute_onnx(model, idict, True)
produced = odict[model.graph.output[0].name]
expected = conv(inp_tensor).detach().numpy()
Expand Down

0 comments on commit baccb2c

Please sign in to comment.