diff --git a/src/brevitas_examples/bnn_pynq/models/resnet.py b/src/brevitas_examples/bnn_pynq/models/resnet.py index f5e23d479..309431d2f 100644 --- a/src/brevitas_examples/bnn_pynq/models/resnet.py +++ b/src/brevitas_examples/bnn_pynq/models/resnet.py @@ -10,6 +10,7 @@ from brevitas.quant import Int8WeightPerChannelFloat from brevitas.quant import Int8WeightPerTensorFloat from brevitas.quant import TruncTo8bit +from brevitas.quant import IntBias from brevitas.quant_tensor import QuantTensor @@ -120,6 +121,7 @@ def __init__( act_bit_width=8, weight_bit_width=8, round_average_pool=False, + last_layer_bias_quant=IntBias, weight_quant=Int8WeightPerChannelFloat, first_layer_weight_quant=Int8WeightPerChannelFloat, last_layer_weight_quant=Int8WeightPerTensorFloat): @@ -163,6 +165,7 @@ def __init__( num_classes, weight_bit_width=8, bias=True, + bias_quant=last_layer_bias_quant, weight_quant=last_layer_weight_quant) for m in self.modules(): @@ -224,8 +227,9 @@ def quant_resnet18(cfg) -> QuantResNet: act_bit_width = cfg.getint('QUANT', 'ACT_BIT_WIDTH') num_classes = cfg.getint('MODEL', 'NUM_CLASSES') model = QuantResNet( - QuantBasicBlock, [2, 2, 2, 2], + block_impl=QuantBasicBlock, + num_blocks=[2, 2, 2, 2], num_classes=num_classes, weight_bit_width=weight_bit_width, act_bit_width=act_bit_width) - return model + return model \ No newline at end of file