Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MultiThreshold] Generalize data layouts for node execution #143

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
43 changes: 21 additions & 22 deletions src/qonnx/custom_op/general/multithreshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def get_nodeattr_types(self):
"out_dtype": ("s", True, ""),
"out_scale": ("f", False, 1.0),
"out_bias": ("f", False, 0.0),
"data_layout": ("s", False, "NCHW", {"NCHW", "NHWC"}),
"data_layout": ("s", False, ""),
}

def make_shape_compatible_op(self, model):
Expand Down Expand Up @@ -122,29 +122,28 @@ def execute_node(self, context, graph):
# retrieve attributes if output scaling is used
out_scale = self.get_nodeattr("out_scale")
out_bias = self.get_nodeattr("out_bias")
# transpose input if NHWC data layout is chosen

# Consider the data layout for transposing the input into the format
# accepted by the multithreshold function above, i.e, the channel
# dimension is along the axis with index 1.
data_layout = self.get_nodeattr("data_layout")
if data_layout == "NHWC":
if v.ndim == 4:
# NHWC -> NCHW
v = np.transpose(v, (0, 3, 1, 2))
elif v.ndim == 2:
# no HW dimension means NHWC and NCHW layouts are equivalent
pass
else:
raise Exception("Unknown data_layout and input ndim" " combination for MultiThreshold.")
# calculate output
# If there is no layout annotation, guess based on rank of the
# tensor
if not data_layout and len(v.shape) < 5:
# Maps tensor rank to layout annotation
rank_to_layout = {0: None, 1: "C", 2: "NC", 3: "NWC", 4: "NCHW"}
# Lookup the layout required by this input shape
data_layout = rank_to_layout[len(v.shape)]
# Lookup the index of the channel dimension in the data layout
# Note: Assumes there is at most one "C" which denotes the channel
# dimension
cdim = data_layout.index("C") if "C" in data_layout else 1
# Rearrange the input to the expected (N, C, ...) layout
v = v.swapaxes(cdim, 1)
# Now we can use the multithreshold function to calculate output
output = multithreshold(v, thresholds, out_scale, out_bias)
# setting context according to output
if data_layout == "NHWC":
if output.ndim == 4:
# NCHW -> NHWC
output = np.transpose(output, (0, 2, 3, 1))
elif output.ndim == 2:
# no HW dimension means NHWC and NCHW layouts are equivalent
pass
else:
raise Exception("Unknown data_layout and output ndim" " combination for MultiThreshold.")
# Rearrange the output back to the original layout
output = output.swapaxes(cdim, 1)
context[node.output[0]] = output

def verify_node(self):
Expand Down
4 changes: 3 additions & 1 deletion tests/core/test_custom_onnx_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,9 @@ def test_execute_custom_node_multithreshold():
assert (execution_context["out"] == outputs_nhwc).all()
# check the set of allowed values
op_inst = getCustomOp(node_def)
assert op_inst.get_nodeattr_allowed_values("data_layout") == {"NCHW", "NHWC"}
# TODO: Removed this check to generalize the supported data layouts, but do
# we need some other check to verify the validity of data layouts?
# assert op_inst.get_nodeattr_allowed_values("data_layout") == {"NCHW", "NHWC", "NC", "NWC", "NCW"}
# exercise the allowed value checks
# try to set attribute to non-allowed value, should raise an exception
try:
Expand Down
Loading