diff --git a/src/qonnx/custom_op/general/multithreshold.py b/src/qonnx/custom_op/general/multithreshold.py
index 57a2872d..c1f6a01a 100644
--- a/src/qonnx/custom_op/general/multithreshold.py
+++ b/src/qonnx/custom_op/general/multithreshold.py
@@ -92,7 +92,7 @@ def get_nodeattr_types(self):
             "out_dtype": ("s", True, ""),
             "out_scale": ("f", False, 1.0),
             "out_bias": ("f", False, 0.0),
-            "data_layout": ("s", False, "NCHW", {"NCHW", "NHWC"}),
+            "data_layout": ("s", False, "NCHW", {"NCHW", "NHWC", "NC"}),
         }
 
     def make_shape_compatible_op(self, model):
diff --git a/src/qonnx/util/range_analysis.py b/src/qonnx/util/range_analysis.py
index 7651b835..4eb0c052 100644
--- a/src/qonnx/util/range_analysis.py
+++ b/src/qonnx/util/range_analysis.py
@@ -34,6 +34,7 @@ from onnx import ValueInfoProto
 from warnings import warn
 
+from qonnx.core.datatype import DataType
 from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.core.onnx_exec import execute_node
 from qonnx.transformation.batchnorm_to_affine import BatchNormToAffine
@@ -250,6 +251,22 @@ def calc_range_outdtype(node, model, range_dict):
     range_dict[oname].range = (odt.min(), odt.max())
 
 
+# Softmax always produces outputs in [0,1]
+def calc_softmax_range(node, model, range_dict):
+    oname = node.output[0]
+    assert node.op_type == "Softmax"
+    range_dict[oname].range = (0, 1)
+
+
+# LogSoftmax always produces outputs in [-inf,0], which is the log of the
+# range of the Softmax
+def calc_logsoftmax_range(node, model, range_dict):
+    oname = node.output[0]
+    assert node.op_type == "LogSoftmax"
+    # Note: replaces -inf by the most negative representable float32 value
+    range_dict[oname].range = (DataType["FLOAT32"].min(), 0)
+
+
 # return whether a given tensor is a shape operand
 def is_shape_operand(tensor_name, model):
     cons = model.find_consumer(tensor_name)
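The two handlers above encode static knowledge instead of executing the node: Softmax normalizes its input to probabilities, so every output lies in [0,1], and LogSoftmax is its elementwise log. A minimal sketch of the claim, illustration only and not part of the patch, assuming plain numpy:

import numpy as np

x = np.random.randn(4, 10).astype(np.float32)
# softmax normalizes to probabilities, so outputs always lie in [0,1]
sm = np.exp(x) / np.exp(x).sum(axis=-1, keepdims=True)
assert (sm >= 0).all() and (sm <= 1).all()
# log-softmax lies in (-inf, 0]; -inf occurs only for exactly-zero
# probabilities, which the handler clamps to DataType["FLOAT32"].min()
lsm = np.log(sm)
assert (lsm <= 0).all()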
@@ -293,7 +310,7 @@ def calc_range_with_lowering(prep_transforms, lowering_transforms, node, model,
     for node_inp in node_model.graph.input:
         node_range_dict[node_inp.name] = range_dict[node_inp.name]
     # run range analysis on the lowered single-node model
-    ret_range_dict = range_analysis(node_model, irange=node_range_dict, report_mode=REPORT_MODE_RANGE)
+    ret_range_dict, _ = range_analysis(node_model, irange=node_range_dict, report_mode=REPORT_MODE_RANGE)
     # copy results back into original range_dict
     for node_out in node.output:
         range_dict[node_out] = ret_range_dict[node_out]
@@ -829,7 +846,7 @@ def calc_intrange_eltwise_monotonic_intrangefirst(node, model, range_dict):
     # strategy: use regular range analysis (which will execute the node on the corners of the input
     # range) using the integer range as the input, which gives us the output integer range. then figure
     # out the scale/bias based on the output integer range.
-    orange_inf = range_dict[node.output[0]]
+    orange_inf = {out: range_dict[out] for out in node.output}
     int_range_dict = {}
     for node_out in node.output:
         oshape = model.get_tensor_shape(node_out)
@@ -847,23 +864,24 @@ def calc_intrange_eltwise_monotonic_intrangefirst(node, model, range_dict):
         int_range_dict[node_in] = range_dict[node_in]
     range_calc_fxn = optype_to_range_calc[node.op_type]
     range_calc_fxn(node, model, int_range_dict)
-    int_orange_inf = int_range_dict[node.output[0]]
-    range_dict[node.output[0]].int_range = int_orange_inf.range
-    # now deduce the output scale factor and bias from all available info
-    # range_max = S*int_range_max + B
-    # range_min = S*int_range_min + B
-    # so S = (range_max - range_min) / (int_range_max - int_range_min)
-    # and afterwards, B = range_max - S*int_range_max
-    # TODO scale and bias may contain NaN's when channels are stuck
-    # how best to deal with this? leave as is? set to 1/0?
-    # try to recover in some other way? (perturb the actual range before calling range_calc_fxn)
-    scale = (orange_inf.range[1] - orange_inf.range[0]) / (int_orange_inf.range[1] - int_orange_inf.range[0])
-    if not np.isfinite(scale).all():
-        warn(f"{node.name} has stuck values, forcing scale to 1.0 for those")
-        scale = np.nan_to_num(scale, nan=1.0, posinf=1.0, neginf=1.0)
-    bias = orange_inf.range[1] - scale * int_orange_inf.range[1]
-    range_dict[node.output[0]].scale = scale
-    range_dict[node.output[0]].bias = bias
+    for out in node.output:
+        int_orange_inf = int_range_dict[out]
+        range_dict[out].int_range = int_orange_inf.range
+        # now deduce the output scale factor and bias from all available info
+        # range_max = S*int_range_max + B
+        # range_min = S*int_range_min + B
+        # so S = (range_max - range_min) / (int_range_max - int_range_min)
+        # and afterwards, B = range_max - S*int_range_max
+        # TODO scale and bias may contain NaNs when channels are stuck
+        # how best to deal with this? leave as is? set to 1/0?
+        # try to recover in some other way? (perturb the actual range before calling range_calc_fxn)
+        scale = (orange_inf[out].range[1] - orange_inf[out].range[0]) / (int_orange_inf.range[1] - int_orange_inf.range[0])
+        if not np.isfinite(scale).all():
+            warn(f"{node.name} has stuck values, forcing scale to 1.0 for those")
+            scale = np.nan_to_num(scale, nan=1.0, posinf=1.0, neginf=1.0)
+        bias = orange_inf[out].range[1] - scale * int_orange_inf.range[1]
+        range_dict[out].scale = scale
+        range_dict[out].bias = bias
 
 
 # for several types of nodes, we dynamically convert ("lower") the node to something else that we can
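Worked through on concrete numbers, the scale/bias deduction in the loop above behaves as expected. Illustration only, with hypothetical values (a real output range of [-0.5, 1.0] mapped onto an integer range of [0, 255]):

import numpy as np

range_min, range_max = -0.5, 1.0   # real-valued output range
int_min, int_max = 0.0, 255.0      # integer output range
# S = (range_max - range_min) / (int_range_max - int_range_min)
scale = (range_max - range_min) / (int_max - int_min)           # ~0.00588
# guard against stuck channels (int_max == int_min gives inf/NaN)
scale = np.nan_to_num(scale, nan=1.0, posinf=1.0, neginf=1.0)
# B = range_max - S * int_range_max
bias = range_max - scale * int_max                              # -0.5
assert np.isclose(scale * int_min + bias, range_min)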
@@ -883,7 +901,7 @@ def calc_intrange_with_lowering(prep_transforms, lowering_transforms, node, model,
     for node_inp in node_model.graph.input:
         node_range_dict[node_inp.name] = range_dict[node_inp.name]
     # run range analysis on the lowered single-node model
-    ret_range_dict = range_analysis(node_model, irange=node_range_dict, report_mode=REPORT_MODE_RANGE, scaled_int=True)
+    ret_range_dict, _ = range_analysis(node_model, irange=node_range_dict, report_mode=REPORT_MODE_RANGE, scaled_int=True)
     # copy results back into original range_dict
     for node_out in node.output:
         range_dict[node_out] = ret_range_dict[node_out]
@@ -927,7 +945,6 @@ def calc_intrange_gemm(node, model, range_dict):
     "Div": calc_monotonic_range,
     "Add": calc_monotonic_range,
     "BatchNormalization": calc_monotonic_range,
-    "Relu": calc_monotonic_range,
     "Pad": calc_monotonic_range,
     "AveragePool": calc_monotonic_range,
     "Trunc": calc_monotonic_range,
@@ -937,11 +954,43 @@ def calc_intrange_gemm(node, model, range_dict):
     "GlobalAveragePool": calc_monotonic_range,
     "QuantizeLinear": calc_monotonic_range,
     "DequantizeLinear": calc_monotonic_range,
-    "Clip": calc_monotonic_range,
-    "Sigmoid": calc_monotonic_range,
     "Concat": calc_monotonic_range,
     "Split": calc_monotonic_range,
     "Im2Col": calc_monotonic_range,
+    # Monotonic activation functions: this list is not complete yet; some ops
+    # are not supported/produced by export, so they are not verified and thus
+    # not added here.
+    "Identity": calc_monotonic_range,
+    "Relu": calc_monotonic_range,
+    "LeakyRelu": calc_monotonic_range,
+    "Clip": calc_monotonic_range,
+    "Selu": calc_monotonic_range,
+    "Celu": calc_monotonic_range,
+    "Elu": calc_monotonic_range,
+    "Sigmoid": calc_monotonic_range,
+    "HardSigmoid": calc_monotonic_range,
+    "Tanh": calc_monotonic_range,
+    "Softplus": calc_monotonic_range,
+    "Exp": calc_monotonic_range,
+    "Log": calc_monotonic_range,
+    "Sqrt": calc_monotonic_range,
+    "Erf": calc_monotonic_range,
+    "Floor": calc_monotonic_range,
+    "Ceil": calc_monotonic_range,
+    "Round": calc_monotonic_range,
+    "Sign": calc_monotonic_range,
+    # Softmax has a defined output range of [0,1], while LogSoftmax yields the
+    # log of this range
+    "Softmax": calc_softmax_range,
+    "LogSoftmax": calc_logsoftmax_range,
+    # Squeeze and Unsqueeze are special cases of Reshape, which is monotonic
+    "Squeeze": calc_monotonic_range,
+    "Unsqueeze": calc_monotonic_range,
+    # Treat MultiThreshold as monotonic. This might be necessary for iterated
+    # rounds of activation function to MultiThreshold conversion to absorb
+    # chains of monotonic activation functions into MultiThreshold.
+    # TODO: Check whether this is actually ok...
+    "MultiThreshold": calc_monotonic_range,
     "Conv": calc_conv_range,
     "Gemm": calc_gemm_range,
 }
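Registering these ops as monotonic is sound because calc_monotonic_range evaluates the node on the corners of the input range (see the strategy comment earlier in the patch): for an elementwise monotonic f, the output extremes can only occur at those corners. A hypothetical standalone sketch, not part of the patch (corner_range is an invented helper):

import numpy as np

def corner_range(f, lo, hi):
    # evaluate on both input corners; taking min/max over the corners also
    # covers monotonically decreasing functions
    corners = np.stack([f(lo), f(hi)])
    return corners.min(axis=0), corners.max(axis=0)

lo, hi = np.float32(-3.0), np.float32(2.0)
print(corner_range(lambda x: np.maximum(x, 0.0), lo, hi))  # Relu: (0.0, 2.0)
print(corner_range(np.tanh, lo, hi))                       # Tanh: (tanh(-3), tanh(2))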
+ "MultiThreshold": calc_monotonic_range, "Conv": calc_conv_range, "Gemm": calc_gemm_range, } @@ -950,15 +999,26 @@ def calc_intrange_gemm(node, model, range_dict): optype_to_intrange_calc = { "MatMul": calc_intrange_matmul, "Conv": calc_intrange_conv, - "Add": calc_intrange_add, "Mul": calc_intrange_mul, "Relu": calc_intrange_relu, "Quant": calc_intrange_quant, "Pad": calc_intrange_eltwise_monotonic, "MaxPool": calc_intrange_eltwise_monotonic, - "Reshape": calc_intrange_eltwise_monotonic, - "Transpose": calc_intrange_eltwise_monotonic, "Im2Col": calc_intrange_eltwise_monotonic, + "Concat": calc_intrange_eltwise_monotonic, + # TODO: Workaround for some weird RA behavior producing NANs, zero scales or + # ranges from -0 to +0. So far only observed in rather complex topology + # involving residual connections, attention and novel activation functions + # and it is unclear how to reproduce this in isolation... + "Add": calc_intrange_eltwise_monotonic_intrangefirst, + "Reshape": calc_intrange_eltwise_monotonic_intrangefirst, + "Transpose": calc_intrange_eltwise_monotonic_intrangefirst, + "Split": calc_intrange_eltwise_monotonic_intrangefirst, + # Treat MultiThreshold as monotonic. This might be necessary for iterated + # rounds of activation function to MultiThreshold conversion to absorb + # chains of monotonic activation functions into MultiThreshold + # TODO: Check whether this is actually ok... + "MultiThreshold": calc_intrange_eltwise_monotonic, "Sub": calc_intrange_sub, "Div": calc_intrange_div, "Gemm": calc_intrange_gemm, @@ -1141,7 +1201,10 @@ def range_analysis( ret = new_ret if prettyprint: ret = pprint.pformat(ret, sort_dicts=False) - return ret + # Return the range information and the transformed model as we might have + # added, removed or renamed some tensors above, and thus we need the new + # model to match tensor names from range information. + return ret, model def main():