diff --git a/src/qonnx/custom_op/general/multithreshold.py b/src/qonnx/custom_op/general/multithreshold.py index 57a2872d..bcf0731d 100644 --- a/src/qonnx/custom_op/general/multithreshold.py +++ b/src/qonnx/custom_op/general/multithreshold.py @@ -92,7 +92,7 @@ def get_nodeattr_types(self): "out_dtype": ("s", True, ""), "out_scale": ("f", False, 1.0), "out_bias": ("f", False, 0.0), - "data_layout": ("s", False, "NCHW", {"NCHW", "NHWC"}), + "data_layout": ("s", False, ""), } def make_shape_compatible_op(self, model): @@ -122,29 +122,28 @@ def execute_node(self, context, graph): # retrieve attributes if output scaling is used out_scale = self.get_nodeattr("out_scale") out_bias = self.get_nodeattr("out_bias") - # transpose input if NHWC data layout is chosen + + # Consider the data layout for transposing the input into the format + # accepted by the multithreshold function above, i.e, the channel + # dimension is along the axis with index 1. data_layout = self.get_nodeattr("data_layout") - if data_layout == "NHWC": - if v.ndim == 4: - # NHWC -> NCHW - v = np.transpose(v, (0, 3, 1, 2)) - elif v.ndim == 2: - # no HW dimension means NHWC and NCHW layouts are equivalent - pass - else: - raise Exception("Unknown data_layout and input ndim" " combination for MultiThreshold.") - # calculate output + # If there is no layout annotation, guess based on rank of the + # tensor + if not data_layout and len(v.shape) < 5: + # Maps tensor rank to layout annotation + rank_to_layout = {0: None, 1: "C", 2: "NC", 3: "NWC", 4: "NCHW"} + # Lookup the layout required by this input shape + data_layout = rank_to_layout[len(v.shape)] + # Lookup the index of the channel dimension in the data layout + # Note: Assumes there is at most one "C" which denotes the channel + # dimension + cdim = data_layout.index("C") if "C" in data_layout else 1 + # Rearrange the input to the expected (N, C, ...) layout + v = v.swapaxes(cdim, 1) + # Now we can use the multithreshold function to calculate output output = multithreshold(v, thresholds, out_scale, out_bias) - # setting context according to output - if data_layout == "NHWC": - if output.ndim == 4: - # NCHW -> NHWC - output = np.transpose(output, (0, 2, 3, 1)) - elif output.ndim == 2: - # no HW dimension means NHWC and NCHW layouts are equivalent - pass - else: - raise Exception("Unknown data_layout and output ndim" " combination for MultiThreshold.") + # Rearrange the output back to the original layout + output = output.swapaxes(cdim, 1) context[node.output[0]] = output def verify_node(self): diff --git a/tests/core/test_custom_onnx_exec.py b/tests/core/test_custom_onnx_exec.py index a763e268..8eec7156 100644 --- a/tests/core/test_custom_onnx_exec.py +++ b/tests/core/test_custom_onnx_exec.py @@ -274,7 +274,9 @@ def test_execute_custom_node_multithreshold(): assert (execution_context["out"] == outputs_nhwc).all() # check the set of allowed values op_inst = getCustomOp(node_def) - assert op_inst.get_nodeattr_allowed_values("data_layout") == {"NCHW", "NHWC"} + # TODO: Removed this check to generalize the supported data layouts, but do + # we need some other check to verify the validity of data layouts? + # assert op_inst.get_nodeattr_allowed_values("data_layout") == {"NCHW", "NHWC", "NC", "NWC", "NCW"} # exercise the allowed value checks # try to set attribute to non-allowed value, should raise an exception try: