Skip to content

Commit

Permalink
Merge pull request Xilinx#866 from mmrahorovic/fix/minimize_bitwidth_…
Browse files Browse the repository at this point in the history
…thresholds

Fix to threshold's width optimization
  • Loading branch information
auphelia authored Aug 4, 2023
2 parents 99f61fc + f52871d commit fd72d48
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
1 change: 0 additions & 1 deletion src/finn/custom_op/fpgadataflow/matrixvectoractivation.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,6 @@ def minimize_accumulator_width(self, model):
# for no-activation nodes, output dt = acc dt
self.set_nodeattr("outputDataType", adt.name)
self.set_nodeattr("accDataType", adt.name)

return DataType[self.get_nodeattr("accDataType")]

def minimize_weight_bit_width(self, model):
Expand Down
2 changes: 2 additions & 0 deletions src/finn/custom_op/fpgadataflow/thresholding_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ def minimize_accumulator_width(self, model):
threshold_tensor
).all(), "Thresholds can't be expressed with type %s" % str(tdt)
self.set_nodeattr("weightDataType", tdt.name)
# Update QONNX DataType of tensor for consistency
model.set_tensor_datatype(self.onnx_node.input[1], tdt)
return DataType[self.get_nodeattr("weightDataType")]

def get_instream_width(self, ind=0):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from qonnx.custom_op.registry import getCustomOp
from qonnx.transformation.base import Transformation
from qonnx.transformation.infer_datatypes import InferDataTypes

from finn.util.fpgadataflow import is_fpgadataflow_node

Expand All @@ -41,9 +42,15 @@ def __init__(self):
super().__init__()

def apply(self, model):
for node in model.graph.node:
for node_id in range(len(model.graph.node)):
# Since InferDataTypes potentially changes node attributes in each loop iterations,
# the for-loop cannot loop over a list of a snapshot of the graph's node protos
node = model.graph.node[node_id]
if is_fpgadataflow_node(node) is True:
inst = getCustomOp(node)
if hasattr(inst, "minimize_accumulator_width"):
inst.minimize_accumulator_width(model)
# Since this transformation is applied iteratively, we have to ensure that
# we propagate the new datatype to other layers
model = model.transform(InferDataTypes())
return (model, False)

0 comments on commit fd72d48

Please sign in to comment.