Fix RoundAndClipThresholds Transformation #1030

Merged · 8 commits · Oct 7, 2024
2 changes: 2 additions & 0 deletions src/finn/builder/build_dataflow_steps.py
@@ -121,6 +121,7 @@
 )
 from finn.transformation.streamline import Streamline
 from finn.transformation.streamline.reorder import MakeMaxPoolNHWC
+from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds
 from finn.util.basic import (
     get_rtlsim_trace_depth,
     pyverilate_get_liveness_threshold_cycles,
@@ -503,6 +504,7 @@ def step_minimize_bit_width(model: ModelWrapper, cfg: DataflowBuildConfig):
     if cfg.minimize_bit_width:
         model = model.transform(MinimizeWeightBitWidth())
         model = model.transform(MinimizeAccumulatorWidth())
+        model = model.transform(RoundAndClipThresholds())
         # make sure the changed datatypes are propagated through the network
         model = model.transform(InferDataTypes())
     return model
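The net effect of this hunk is that `step_minimize_bit_width` now also rounds and clips thresholds whenever `cfg.minimize_bit_width` is enabled. A minimal sketch of applying the same transforms by hand (the checkpoint path is hypothetical, not from this PR):

```python
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.infer_datatypes import InferDataTypes
from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds

# Hypothetical checkpoint produced after weight/accumulator minimization
model = ModelWrapper("model_after_minimize.onnx")
# Round/clip thresholds, then propagate the tightened datatypes
model = model.transform(RoundAndClipThresholds())
model = model.transform(InferDataTypes())
```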
86 changes: 56 additions & 30 deletions src/finn/transformation/streamline/round_thresholds.py
@@ -1,4 +1,5 @@
-# Copyright (c) 2020, Xilinx
+# Copyright (c) 2020-2022, Xilinx
+# Copyright (C) 2022-2024, Advanced Micro Devices, Inc.
 # All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
@@ -27,42 +28,67 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 import numpy as np
+from qonnx.core.datatype import DataType
+from qonnx.core.modelwrapper import ModelWrapper
+from qonnx.custom_op.registry import getCustomOp
 from qonnx.transformation.base import Transformation
+from qonnx.transformation.infer_datatypes import InferDataTypes


 class RoundAndClipThresholds(Transformation):
     """For MultiThreshold nodes operating on integer inputs, round up
     thresholds values to the nearest integer. Additionally, if the input
-    is unsigned, sets negative thresholds to zero."""
+    is unsigned, sets negative thresholds to zero. Type-casts thresholds (back)
+    to the float32 container type (this is separate from the quantization
+    annotation). Runs InferDataTypes() afterward to propagate any changes to the
+    quantization data types."""

-    def apply(self, model):
+    def apply(self, model: ModelWrapper):  # noqa
         graph = model.graph
         graph_modified = False
-        for n in graph.node:
-            if n.op_type == "MultiThreshold":
-                idtype = model.get_tensor_datatype(n.input[0])
-                T = model.get_initializer(n.input[1])
-                Tnew = np.ceil(T)
-                if idtype.is_integer() and (T != Tnew).any():
-                    # round up the thresholds to nearest integer
-                    model.set_initializer(n.input[1], Tnew)
-                    # use same datatype as inputs for thresholds
-                    model.set_tensor_datatype(n.input[1], idtype)
+        for index, node in enumerate(graph.node):
+            op_type = node.op_type
+            if op_type == "MultiThreshold" or op_type.startswith("Thresholding"):
+                thresholds = model.get_initializer(node.input[1])
+                if thresholds is None:
+                    continue
+                dtype = model.get_tensor_datatype(node.input[0])
+                # This transformation only applies to thresholding operations
+                # operating on integer inputs
+                if not dtype.is_integer():
+                    continue
+                # Round thresholds up to nearest integer and clip thresholds
+                # outside the input range
+                # Note: This might promote the thresholds to float64 and
+                # introduce extra inaccuracies due to large integers not being
+                # exactly representable in floating-point representation.
+                # See for example: np.ceil(np.float32(16777217)) == 16777216
+                new_thresholds = np.clip(np.ceil(thresholds), dtype.min(), dtype.max() + 1)
+                # Convert back to the preferred float32 container type
+                new_thresholds = new_thresholds.astype(np.float32)
+                # Insert the rounded and clipped thresholds back into the model
+                model.set_initializer(node.input[1], new_thresholds)
+                # The rounded and clipped thresholds now fit into a data type
+                # that is one bit bigger than the input datatype
+                # Determine new max_value
+                max_val = dtype.max() + 1
+                if not dtype.signed():
+                    tdt = DataType.get_smallest_possible(max_val)
+                else:
+                    tdt = DataType.get_smallest_possible(-(max_val) - 1)
+                model.set_tensor_datatype(node.input[1], tdt)
+                # If hw op we need to set the weight data type attribute as well
+                if op_type.startswith("Thresholding"):
+                    inst = getCustomOp(node)
+                    inst.set_nodeattr("weightDataType", tdt.name)
+                # Test whether the new thresholds actually differ from the old
+                # ones
+                if np.any(new_thresholds != thresholds):
+                    # Track the graph has been modified to inform the transform
+                    # container to exhaustively repeat this transformation until
+                    # no changes are possible
                     graph_modified = True
-                if idtype.is_integer() and not idtype.signed() and (Tnew < 0).any():
-                    # clip any negative thresholds if input is unsigned
-                    Tnew = np.clip(Tnew, 0, None)
-                    model.set_initializer(n.input[1], Tnew)
-                    # use same datatype as inputs for thresholds
-                    model.set_tensor_datatype(n.input[1], idtype)
-                    graph_modified = True
-                if idtype.is_integer() and (
-                    (Tnew < (idtype.min() - 1)).any() or (Tnew > (idtype.max() + 1)).any()
-                ):
-                    # clip any large thresholds to input range + 1
-                    Tnew = np.clip(Tnew, idtype.min() - 1, idtype.max() + 1)
-                    model.set_initializer(n.input[1], Tnew)
-                    # use same datatype as inputs for thresholds
-                    model.set_tensor_datatype(n.input[1], idtype)
-                    graph_modified = True
-        return (model, graph_modified)
+                    # Immediately exit here to propagate the data type changes
+                    # before considering the next node
+                    break
+        model = model.transform(InferDataTypes())
+        return model, graph_modified
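To make the new rounding/clipping semantics concrete, here is a small standalone sketch (not part of the PR; the threshold values are made up, and only `numpy` and `qonnx` are assumed):

```python
import numpy as np
from qonnx.core.datatype import DataType

# Made-up thresholds for an INT8 input: a fractional value and two
# values outside the representable input range [-128, 127]
dtype = DataType["INT8"]
thresholds = np.array([[-300.0, 0.5, 1.7, 500.0]], dtype=np.float32)

# Same steps as the transformation: ceil, clip to [min, max + 1],
# cast back to the float32 container type
new_thresholds = np.clip(np.ceil(thresholds), dtype.min(), dtype.max() + 1)
new_thresholds = new_thresholds.astype(np.float32)
print(new_thresholds)  # [[-128.    1.    2.  128.]]

# The clipped values now fit a signed type one bit wider than the input
max_val = dtype.max() + 1  # 128
print(DataType.get_smallest_possible(-max_val - 1).name)  # INT9

# The caveat from the code comment: large integers are not exactly
# representable in float32, so ceil can silently land on a neighbor
assert np.ceil(np.float32(16777217)) == 16777216
```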
2 changes: 2 additions & 0 deletions tests/end2end/test_end2end_bnn_pynq.py
@@ -94,6 +94,7 @@
     MakeMaxPoolNHWC,
     MoveScalarLinearPastInvariants,
 )
+from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds
 from finn.util.basic import get_finn_root, make_build_dir, test_board_map
 from finn.util.pytorch import ToTensor
 from finn.util.test import (
@@ -672,6 +673,7 @@ def test_minimize_bit_width(self, topology, wbits, abits, board):
         model = load_test_checkpoint_or_skip(prev_chkpt_name)
         model = model.transform(MinimizeAccumulatorWidth())
         model = model.transform(MinimizeWeightBitWidth())
+        model = model.transform(RoundAndClipThresholds())
         curr_chkpt_name = get_checkpoint_name(topology, wbits, abits, "minimize_bit_width")
         model.save(curr_chkpt_name)

1 change: 1 addition & 0 deletions tests/end2end/test_end2end_mobilenet_v1.py
@@ -353,6 +353,7 @@ def test_end2end_mobilenet_minimize_bit_width():
     model = load_test_checkpoint_or_skip(build_dir + "/end2end_mobilenet_folded.onnx")
     model = model.transform(MinimizeAccumulatorWidth())
     model = model.transform(MinimizeWeightBitWidth())
+    model = model.transform(RoundAndClipThresholds())
     model.save(build_dir + "/end2end_mobilenet_minimize_bitwidth.onnx")


11 changes: 7 additions & 4 deletions tests/fpgadataflow/test_fpgadataflow_thresholding.py
@@ -49,6 +49,7 @@
 from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
 from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
 from finn.transformation.fpgadataflow.specialize_layers import SpecializeLayers
+from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds

 test_fpga_part = "xczu3eg-sbva484-1-e"
 target_clk_ns = 5
@@ -133,10 +134,8 @@ def make_single_multithresholding_modelwrapper(
 @pytest.mark.parametrize(
     "idt_tdt_cfg",
     [
-        (DataType["INT8"], DataType["INT8"]),
-        (DataType["INT8"], DataType["INT9"]),
-        (DataType["UINT5"], DataType["UINT5"]),
-        (DataType["UINT5"], DataType["UINT6"]),
+        (DataType["INT8"], DataType["INT25"]),
+        (DataType["UINT5"], DataType["UINT8"]),
     ],
 )
 @pytest.mark.parametrize("fold", [-1, 1, 2])
@@ -145,6 +144,7 @@ def make_single_multithresholding_modelwrapper(
 @pytest.mark.parametrize("impl_style", ["hls", "rtl"])
 @pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"])
 @pytest.mark.parametrize("mem_mode", ["internal_embedded", "internal_decoupled"])
+@pytest.mark.parametrize("round_thresh", [True, False])
 @pytest.mark.fpgadataflow
 @pytest.mark.vivado
 @pytest.mark.slow
@@ -159,6 +159,7 @@ def test_fpgadataflow_thresholding(
     impl_style,
     exec_mode,
     mem_mode,
+    round_thresh,
 ):
     # the mem_mode parameter can only be used for the hls thresholding
     # so the test will only be executed once for impl_style=rtl and once skipped
@@ -234,6 +235,8 @@ def test_fpgadataflow_thresholding(
     node = model.get_nodes_by_op_type(model.graph.node[0].op_type)[0]
     inst = getCustomOp(node)
     inst.set_nodeattr("PE", pe)
+    if round_thresh is True:
+        model = model.transform(RoundAndClipThresholds())
     model = model.transform(GiveUniqueNodeNames())

     if impl_style == "hls":
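The widened threshold datatypes in the test parametrization (INT25 for INT8 inputs, UINT8 for UINT5 inputs) leave the new `round_thresh` toggle room to demonstrate narrowing. A sketch of the unsigned case, under the same assumptions as the sketch above (values illustrative, not from the test):

```python
from qonnx.core.datatype import DataType

# UINT5 inputs span [0, 31], so rounded/clipped thresholds live in
# [0, 32]; the smallest unsigned type holding 32 is UINT6, well below
# the UINT8 annotation the test starts from
input_dt = DataType["UINT5"]
max_val = input_dt.max() + 1  # 32
assert DataType.get_smallest_possible(max_val).name == "UINT6"
```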