Commit 8dd85f4

Rework RoundAndClipThresholds to avoid range and type promotion issues
iksnagreb committed Apr 17, 2024
1 parent e632328 commit 8dd85f4
Showing 1 changed file with 76 additions and 29 deletions.
105 changes: 76 additions & 29 deletions src/finn/transformation/streamline/round_thresholds.py
@@ -26,43 +26,90 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
+# Need numpy for modifying the onnx graph tensors, which are numpy style arrays
 import numpy as np
 
+# QONNX wrapper of ONNX model graphs
+from qonnx.core.modelwrapper import ModelWrapper
+
+# QONNX graph transformation base class
 from qonnx.transformation.base import Transformation
 
+# Transformation running qonnx datatype inference
+from qonnx.transformation.infer_datatypes import InferDataTypes
+
 
+# Rounds and clips thresholds to integer values if the node inputs are integer,
+# respecting range, representability and data type (promotion) of the container
+# data type
 class RoundAndClipThresholds(Transformation):
     """For MultiThreshold nodes operating on integer inputs, round up
     thresholds values to the nearest integer. Additionally, if the input
-    is unsigned, sets negative thresholds to zero."""
+    is unsigned, sets negative thresholds to zero. Type-casts thresholds (back)
+    to the float32 container type (this is separate from the quantization
+    annotation). Runs InferDataTypes() afterward to propagate any changes to the
+    quantization data types."""
 
-    def apply(self, model):
+    # Applies the transform to a whole model graph
+    def apply(self, model: ModelWrapper):  # noqa
+        # Get the model graph out of the model wrapper object
         graph = model.graph
+        # Keep track of whether the graph has been modified
         graph_modified = False
-        for n in graph.node:
-            if n.op_type == "MultiThreshold":
-                idtype = model.get_tensor_datatype(n.input[0])
-                T = model.get_initializer(n.input[1])
-                Tnew = np.ceil(T)
-                if idtype.is_integer() and (T != Tnew).any():
-                    # round up the thresholds to nearest integer
-                    model.set_initializer(n.input[1], Tnew)
-                    # use same datatype as inputs for thresholds
-                    model.set_tensor_datatype(n.input[1], idtype)
-                    graph_modified = True
-                if idtype.is_integer() and not idtype.signed() and (Tnew < 0).any():
-                    # clip any negative thresholds if input is unsigned
-                    Tnew = np.clip(Tnew, 0, None)
-                    model.set_initializer(n.input[1], Tnew)
-                    # use same datatype as inputs for thresholds
-                    model.set_tensor_datatype(n.input[1], idtype)
-                    graph_modified = True
-                if idtype.is_integer() and (
-                    (Tnew < (idtype.min())).any() or (Tnew > (idtype.max())).any()
-                ):
-                    # clip any large thresholds to input range + 1
-                    Tnew = np.clip(Tnew, idtype.min(), idtype.max())
-                    model.set_initializer(n.input[1], Tnew)
-                    # use same datatype as inputs for thresholds
-                    model.set_tensor_datatype(n.input[1], idtype)
-        return (model, graph_modified)
+        # Iterate all nodes in the graph keeping track of the index
+        for index, node in enumerate(graph.node):
+            # Applies to initializer tensors of MultiThreshold operations
+            if node.op_type == "MultiThreshold":
+                # Try to get the thresholds initializer tensor
+                thresholds = model.get_initializer(node.input[1])
+                # There might be no constant thresholds stored as initializer
+                # tensor inside the model
+                if thresholds is None:
+                    # Nothing we can do, skip to the next node
+                    continue
+                # Get the data type of the inputs to this operation
+                dtype = model.get_tensor_datatype(node.input[0])
+                # This transformation only applies to thresholding operations
+                # operating on integer inputs
+                if not dtype.is_integer():
+                    # Nothing we can do, skip to the next node
+                    continue
+                # Round thresholds up to nearest integer and clip thresholds
+                # outside the input range
+                # Note: This might promote the thresholds to float64 and
+                # introduce extra inaccuracies due to large integers not being
+                # exactly representable in floating-point representation.
+                # See for example: np.ceil(np.float32(16777217)) == 16777216
+                # fmt: off
+                new_thresholds = np.clip(
+                    np.ceil(thresholds), dtype.min(), dtype.max()
+                )
+                # fmt: on
+                # Convert back to the preferred float32 container type
+                # Note: np.clip might have promoted the thresholds to float64
+                # TODO: Maybe consider an int64 container type for thresholds
+                # rounded to integer? Need to check all other transformations
+                # and code generation through the whole FINN and QONNX stack
+                # first, as these probably assume a float32 container type.
+                new_thresholds = new_thresholds.astype(np.float32)
+                # Insert the rounded and clipped thresholds back into the model
+                model.set_initializer(node.input[1], new_thresholds)
+                # The rounded and clipped thresholds now fit into the input data
+                # type
+                model.set_tensor_datatype(node.input[1], dtype)
+                # Test whether the new thresholds actually differ from the old
+                # ones
+                if np.any(new_thresholds != thresholds):
+                    # Track the graph has been modified to inform the transform
+                    # container to exhaustively repeat this transformation until
+                    # no changes are possible
+                    graph_modified = True
+                # Immediately exit here to propagate the data type changes
+                # before considering the next node
+                break
+        # Some data types might have changed, do one pass of data type inference
+        # to propagate these changes through the graph
+        model = model.transform(InferDataTypes())
+        # Return the transformed model and indicate whether the graph actually
+        # has been transformed to exhaustively apply this transformation again.
+        return model, graph_modified
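
The representability pitfall referenced in the new code's comments is easy to reproduce. Below is a minimal NumPy sketch (not part of the commit) of why the thresholds are cast back to float32 explicitly: 16777217 = 2**24 + 1 is the first integer that float32 cannot represent exactly, so a value that rounds cleanly in a promoted float64 array can still shift when it is returned to the float32 container type.

    import numpy as np

    # 16777217 = 2**24 + 1 has no exact float32 representation: constructing
    # it already rounds to the nearest representable value.
    print(np.float32(16777217) == 16777216)  # True

    # float64 holds the value exactly (it is well below 2**53), so the ceil
    # itself is lossless...
    t = np.float64(16777217)
    print(np.ceil(t))  # 16777217.0

    # ...but casting the result back to the float32 container type shifts it,
    # which is the inaccuracy the code comment warns about.
    print(np.ceil(t).astype(np.float32))  # 16777216.0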

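For reference on the clipping bounds: dtype in the new code is a QONNX DataType instance, as returned by ModelWrapper.get_tensor_datatype(). A short sketch of the relevant API, with INT8 and UINT8 chosen here purely as examples:

    from qonnx.core.datatype import DataType

    # The clipping bounds come from the QONNX datatype of the input tensor.
    idt = DataType["INT8"]
    print(idt.is_integer(), idt.min(), idt.max())  # True -128 127

    # For unsigned types min() is 0, so the single clip to [min(), max()] also
    # covers the old implementation's separate "set negative thresholds to
    # zero" branch for unsigned inputs.
    udt = DataType["UINT8"]
    print(udt.min(), udt.max())  # 0 255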
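
Finally, a usage sketch (the file names are placeholders, not from this commit). QONNX transformations are applied through ModelWrapper.transform(), which calls apply() repeatedly until it returns graph_modified as False; this is why the reworked loop can break immediately after the first change and still eventually visit every MultiThreshold node.

    from qonnx.core.modelwrapper import ModelWrapper

    from finn.transformation.streamline.round_thresholds import (
        RoundAndClipThresholds,
    )

    # Placeholder path: any model containing MultiThreshold nodes that operate
    # on integer inputs qualifies.
    model = ModelWrapper("model.onnx")

    # transform() re-applies the transformation until apply() reports no
    # further changes, then the rounded model can be saved.
    model = model.transform(RoundAndClipThresholds())
    model.save("model-rounded.onnx")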