Skip to content

Commit

Permalink
Update resnet.py
Browse files Browse the repository at this point in the history
  • Loading branch information
i-colbert committed Oct 11, 2023
1 parent b051309 commit aff5214
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/brevitas_examples/bnn_pynq/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from brevitas.quant import Int8WeightPerChannelFloat
from brevitas.quant import Int8WeightPerTensorFloat
from brevitas.quant import TruncTo8bit
from brevitas.quant import IntBias
from brevitas.quant_tensor import QuantTensor


Expand Down Expand Up @@ -120,6 +121,7 @@ def __init__(
act_bit_width=8,
weight_bit_width=8,
round_average_pool=False,
last_layer_bias_quant=IntBias,
weight_quant=Int8WeightPerChannelFloat,
first_layer_weight_quant=Int8WeightPerChannelFloat,
last_layer_weight_quant=Int8WeightPerTensorFloat):
Expand Down Expand Up @@ -163,6 +165,7 @@ def __init__(
num_classes,
weight_bit_width=8,
bias=True,
bias_quant=last_layer_bias_quant,
weight_quant=last_layer_weight_quant)

for m in self.modules():
Expand Down Expand Up @@ -224,8 +227,9 @@ def quant_resnet18(cfg) -> QuantResNet:
act_bit_width = cfg.getint('QUANT', 'ACT_BIT_WIDTH')
num_classes = cfg.getint('MODEL', 'NUM_CLASSES')
model = QuantResNet(
QuantBasicBlock, [2, 2, 2, 2],
block_impl=QuantBasicBlock,
num_blocks=[2, 2, 2, 2],
num_classes=num_classes,
weight_bit_width=weight_bit_width,
act_bit_width=act_bit_width)
return model
return model

0 comments on commit aff5214

Please sign in to comment.