From a98136865e9a890a2923123b190e9fa76aceb9b0 Mon Sep 17 00:00:00 2001 From: Ritul Jasuja Date: Thu, 17 Feb 2022 14:28:33 +0530 Subject: [PATCH] Remove extra Transpose layers from the generated graph Transpose layers are added to every operator's input and output if the layout is NHWC. Instead, they should be added only at the beginning and end of the graph. Signed-off-by: Ritul Jasuja --- .../operations/include/OperationsBase.hpp | 1 + ngraph_creator/operations/src/AveragePool2D.cpp | 11 ++++++++--- ngraph_creator/operations/src/BatchToSpace.cpp | 14 +++++++++++--- ngraph_creator/operations/src/Conv2d.cpp | 11 ++++++++--- ngraph_creator/operations/src/DepthToSpace.cpp | 14 +++++++++++--- ngraph_creator/operations/src/DepthwiseConv2d.cpp | 11 ++++++++--- ngraph_creator/operations/src/GroupedConv2d.cpp | 11 ++++++++--- .../operations/src/InstanceNormalization.cpp | 14 +++++++++++--- ngraph_creator/operations/src/L2Pooling2D.cpp | 13 +++++++++---- ngraph_creator/operations/src/MaxPool2d.cpp | 11 ++++++++--- ngraph_creator/operations/src/ROIAlign.cpp | 14 +++++++++++--- ngraph_creator/operations/src/ROIPooling.cpp | 14 +++++++++++--- ngraph_creator/operations/src/ResizeBilinear.cpp | 12 +++++++++--- .../operations/src/ResizeNearestNeighbor.cpp | 12 +++++++++--- ngraph_creator/operations/src/SpaceToBatch.cpp | 14 +++++++++++--- ngraph_creator/operations/src/SpaceToDepth.cpp | 14 +++++++++++--- ngraph_creator/operations/src/TransposeConv2D.cpp | 11 ++++++++--- 17 files changed, 153 insertions(+), 49 deletions(-) diff --git a/ngraph_creator/operations/include/OperationsBase.hpp b/ngraph_creator/operations/include/OperationsBase.hpp index 5f2f06a49..f28b0eb97 100644 --- a/ngraph_creator/operations/include/OperationsBase.hpp +++ b/ngraph_creator/operations/include/OperationsBase.hpp @@ -196,6 +196,7 @@ class OperationsBase { // override connectOperationToGraph in case Operation has multiple outputs virtual void connectOperationToGraph(); virtual ~OperationsBase() {} + bool 
transposed_nchw = false; }; } // namespace nnhal diff --git a/ngraph_creator/operations/src/AveragePool2D.cpp b/ngraph_creator/operations/src/AveragePool2D.cpp index f72c71486..d0d28246a 100644 --- a/ngraph_creator/operations/src/AveragePool2D.cpp +++ b/ngraph_creator/operations/src/AveragePool2D.cpp @@ -120,8 +120,11 @@ std::shared_ptr AveragePool2D::createNode() { } } - if (!useNchw) { // No conversion needed if useNchw set - inputNode = transpose(NHWC_NCHW, inputNode); + if (!transposed_nchw) { + if (!useNchw) { // No conversion needed if useNchw set + inputNode = transpose(NHWC_NCHW, inputNode); + transposed_nchw = true; + } } strides = {(size_t)stride_height, (size_t)stride_width}; @@ -135,7 +138,9 @@ std::shared_ptr AveragePool2D::createNode() { outputNode = applyActivation(outputNode, activationFn); - if (!useNchw) { + auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0); + const auto outputOp = sModelInfo->getOperand(outputIndex); + if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) { outputNode = transpose(NCHW_NHWC, outputNode); } diff --git a/ngraph_creator/operations/src/BatchToSpace.cpp b/ngraph_creator/operations/src/BatchToSpace.cpp index fdc78483a..c9fdefafd 100644 --- a/ngraph_creator/operations/src/BatchToSpace.cpp +++ b/ngraph_creator/operations/src/BatchToSpace.cpp @@ -54,13 +54,21 @@ std::shared_ptr BatchToSpace::createNode() { const auto crop_begin = createConstNode(ngraph::element::i64, {shape.size()}, shape); const auto crop_end = createConstNode(ngraph::element::i64, {shape.size()}, shape); - if (!useNchw) // No conversion needed if useNchw set - inputNode = transpose(NHWC_NCHW, inputNode); + if (!transposed_nchw) { + if (!useNchw) { // No conversion needed if useNchw set + inputNode = transpose(NHWC_NCHW, inputNode); + transposed_nchw = true; + } + } std::shared_ptr outputNode = std::make_shared( inputNode, block_shape_node, crop_begin, crop_end); - if (!useNchw) outputNode = transpose(NCHW_NHWC, 
outputNode); + auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0); + const auto outputOp = sModelInfo->getOperand(outputIndex); + if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) { + outputNode = transpose(NCHW_NHWC, outputNode); + } return outputNode; } diff --git a/ngraph_creator/operations/src/Conv2d.cpp b/ngraph_creator/operations/src/Conv2d.cpp index 4f210399f..8a4c53481 100644 --- a/ngraph_creator/operations/src/Conv2d.cpp +++ b/ngraph_creator/operations/src/Conv2d.cpp @@ -205,8 +205,11 @@ std::shared_ptr Conv2d::createNode() { // OpenVino expects filter in OIHW format filterNode = transpose(OHWI_OIHW, filterNode); - if (!useNchw) { // No conversion needed if useNchw set - inputNode = transpose(NHWC_NCHW, inputNode); + if (!transposed_nchw) { + if (!useNchw) { // No conversion needed if useNchw set + inputNode = transpose(NHWC_NCHW, inputNode); + transposed_nchw = true; + } } strides = {(size_t)stride_height, (size_t)stride_width}; @@ -229,7 +232,9 @@ std::shared_ptr Conv2d::createNode() { convNode, biasNode, ngraph::op::AutoBroadcastType::NUMPY); outputNode = applyActivation(outputNode, activationFn); - if (!useNchw) { + auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0); + const auto outputOp = sModelInfo->getOperand(outputIndex); + if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) { outputNode = transpose(NCHW_NHWC, outputNode); } diff --git a/ngraph_creator/operations/src/DepthToSpace.cpp b/ngraph_creator/operations/src/DepthToSpace.cpp index dbf30743f..508aaf025 100644 --- a/ngraph_creator/operations/src/DepthToSpace.cpp +++ b/ngraph_creator/operations/src/DepthToSpace.cpp @@ -25,15 +25,23 @@ std::shared_ptr DepthToSpace::createNode() { input = getInputNode(0); auto block_size = sModelInfo->ParseOperationInput(mNnapiOperationIndex, 1); - if (!useNchw) // No conversion needed if useNchw set - input = transpose(NHWC_NCHW, input); + if (!transposed_nchw) { + if 
(!useNchw) { // No conversion needed if useNchw set + input = transpose(NHWC_NCHW, input); + transposed_nchw = true; + } + } std::shared_ptr outputNode; outputNode = std::make_shared( input, ngraph::op::v0::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST, block_size); - if (!useNchw) outputNode = transpose(NCHW_NHWC, outputNode); + auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0); + const auto outputOp = sModelInfo->getOperand(outputIndex); + if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) { + outputNode = transpose(NCHW_NHWC, outputNode); + } return outputNode; } diff --git a/ngraph_creator/operations/src/DepthwiseConv2d.cpp b/ngraph_creator/operations/src/DepthwiseConv2d.cpp index 3507c67c4..29483e7b4 100644 --- a/ngraph_creator/operations/src/DepthwiseConv2d.cpp +++ b/ngraph_creator/operations/src/DepthwiseConv2d.cpp @@ -214,8 +214,11 @@ std::shared_ptr DepthwiseConv2d::createNode() { // OpenVino expects filter in OIHW format filterNode = transpose(IHWO_OIHW, filterNode); - if (!useNchw) { // No conversion needed if useNchw set - inputNode = transpose(NHWC_NCHW, inputNode); + if (!transposed_nchw) { + if (!useNchw) { // No conversion needed if useNchw set + inputNode = transpose(NHWC_NCHW, inputNode); + transposed_nchw = true; + } } strides = {(size_t)stride_height, (size_t)stride_width}; @@ -249,7 +252,9 @@ std::shared_ptr DepthwiseConv2d::createNode() { groupConvNode, biasNode, ngraph::op::AutoBroadcastType::NUMPY); outputNode = applyActivation(outputNode, activationFn); - if (!useNchw) { + auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0); + const auto outputOp = sModelInfo->getOperand(outputIndex); + if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) { outputNode = transpose(NCHW_NHWC, outputNode); } diff --git a/ngraph_creator/operations/src/GroupedConv2d.cpp b/ngraph_creator/operations/src/GroupedConv2d.cpp index 7c09a2fd6..eb0481aa2 100644 --- 
a/ngraph_creator/operations/src/GroupedConv2d.cpp +++ b/ngraph_creator/operations/src/GroupedConv2d.cpp @@ -178,8 +178,11 @@ std::shared_ptr GroupedConv2d::createNode() { // OpenVino expects filter in OIHW format filterNode = transpose(OHWI_OIHW, filterNode); - if (!useNchw) { // No conversion needed if useNchw set - inputNode = transpose(NHWC_NCHW, inputNode); + if (!transposed_nchw) { + if (!useNchw) { // No conversion needed if useNchw set + inputNode = transpose(NHWC_NCHW, inputNode); + transposed_nchw = true; + } } strides = {(size_t)stride_width, (size_t)stride_height}; @@ -218,7 +221,9 @@ std::shared_ptr GroupedConv2d::createNode() { groupConvNode, biasNode, ngraph::op::AutoBroadcastType::NUMPY); outputNode = applyActivation(outputNode, activationFn); - if (!useNchw) { + auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0); + const auto outputOp = sModelInfo->getOperand(outputIndex); + if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) { outputNode = transpose(NCHW_NHWC, outputNode); } diff --git a/ngraph_creator/operations/src/InstanceNormalization.cpp b/ngraph_creator/operations/src/InstanceNormalization.cpp index fdf3d0f35..fa068aa17 100644 --- a/ngraph_creator/operations/src/InstanceNormalization.cpp +++ b/ngraph_creator/operations/src/InstanceNormalization.cpp @@ -50,8 +50,12 @@ std::shared_ptr InstanceNormalization::createNode() { auto layout = sModelInfo->ParseOperationInput(mNnapiOperationIndex, 4); if (layout) useNchw = true; - if (!useNchw) // No conversion needed if useNchw set - inputNode = transpose(NHWC_NCHW, inputNode); + if (!transposed_nchw) { + if (!useNchw) { // No conversion needed if useNchw set + inputNode = transpose(NHWC_NCHW, inputNode); + transposed_nchw = true; + } + } // output[b, h, w, c] = (input[b, h, w, c] - mean[b, c]) * gamma / // sqrt(var[b, c] + epsilon) + beta @@ -71,7 +75,11 @@ std::shared_ptr InstanceNormalization::createNode() { std::shared_ptr outputNode = 
std::make_shared(mulGamma, betaNode); - if (!useNchw) outputNode = transpose(NCHW_NHWC, outputNode); + auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0); + const auto outputOp = sModelInfo->getOperand(outputIndex); + if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) { + outputNode = transpose(NCHW_NHWC, outputNode); + } ALOGV("%s PASSED", __func__); return outputNode; diff --git a/ngraph_creator/operations/src/L2Pooling2D.cpp b/ngraph_creator/operations/src/L2Pooling2D.cpp index d8520b132..a7463724b 100644 --- a/ngraph_creator/operations/src/L2Pooling2D.cpp +++ b/ngraph_creator/operations/src/L2Pooling2D.cpp @@ -114,9 +114,12 @@ std::shared_ptr L2Pooling2D::createNode() { inputNode = getInputNode(0); inputSquared = std::make_shared(inputNode, inputNode); - if (!useNchw) { - ALOGD("%s Forced NCHW conversion at operationIndex %d", __func__, mNnapiOperationIndex); - inputSquared = transpose(NHWC_NCHW, inputSquared); + if (!transposed_nchw) { + if (!useNchw) { + ALOGD("%s Forced NCHW conversion at operationIndex %d", __func__, mNnapiOperationIndex); + inputSquared = transpose(NHWC_NCHW, inputSquared); + transposed_nchw = true; + } } strides = {(size_t)stride_height, (size_t)stride_width}; @@ -132,7 +135,9 @@ std::shared_ptr L2Pooling2D::createNode() { auto outputNode = applyActivation(sqrtOutput, activationFn); - if (!useNchw) { + auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0); + const auto outputOp = sModelInfo->getOperand(outputIndex); + if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) { outputNode = transpose(NCHW_NHWC, outputNode); } diff --git a/ngraph_creator/operations/src/MaxPool2d.cpp b/ngraph_creator/operations/src/MaxPool2d.cpp index bf320b897..ee492a5e4 100644 --- a/ngraph_creator/operations/src/MaxPool2d.cpp +++ b/ngraph_creator/operations/src/MaxPool2d.cpp @@ -115,8 +115,11 @@ std::shared_ptr MaxPool2d::createNode() { std::shared_ptr inputNode; inputNode = 
getInputNode(0); - if (!useNchw) { // No conversion needed if useNchw set - inputNode = transpose(NHWC_NCHW, inputNode); + if (!transposed_nchw) { + if (!useNchw) { // No conversion needed if useNchw set + inputNode = transpose(NHWC_NCHW, inputNode); + transposed_nchw = true; + } } strides = {(size_t)stride_height, (size_t)stride_width}; @@ -130,7 +133,9 @@ std::shared_ptr MaxPool2d::createNode() { auto outputNode = applyActivation(maxpoolNode, activationFn); - if (!useNchw) { + auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0); + const auto outputOp = sModelInfo->getOperand(outputIndex); + if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) { outputNode = transpose(NCHW_NHWC, outputNode); } diff --git a/ngraph_creator/operations/src/ROIAlign.cpp b/ngraph_creator/operations/src/ROIAlign.cpp index bb5a7ff86..ac093262b 100644 --- a/ngraph_creator/operations/src/ROIAlign.cpp +++ b/ngraph_creator/operations/src/ROIAlign.cpp @@ -75,8 +75,12 @@ std::shared_ptr ROIAlign::createNode() { if (layout) useNchw = true; - if (!useNchw) // No conversion needed if useNchw set - feat_maps = transpose(NHWC_NCHW, feat_maps); + if (!transposed_nchw) { + if (!useNchw) { // No conversion needed if useNchw set + feat_maps = transpose(NHWC_NCHW, feat_maps); + transposed_nchw = true; + } + } float spatial_scale = 1.0 / (height_ratio); int sampling_ratio = sampling_pts_h; @@ -85,7 +89,11 @@ std::shared_ptr ROIAlign::createNode() { feat_maps, rois, batch_indices, output_height, output_width, sampling_ratio, spatial_scale, "avg"); - if (!useNchw) outputNode = transpose(NCHW_NHWC, outputNode); + auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0); + const auto outputOp = sModelInfo->getOperand(outputIndex); + if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) { + outputNode = transpose(NCHW_NHWC, outputNode); + } ALOGV("%s PASSED", __func__); diff --git a/ngraph_creator/operations/src/ROIPooling.cpp 
b/ngraph_creator/operations/src/ROIPooling.cpp index 3e561bf87..8b189d4e9 100755 --- a/ngraph_creator/operations/src/ROIPooling.cpp +++ b/ngraph_creator/operations/src/ROIPooling.cpp @@ -64,8 +64,12 @@ std::shared_ptr ROIPooling::createNode() { if (layout) useNchw = true; - if (!useNchw) // No conversion needed if useNchw set - feat_maps = transpose(NHWC_NCHW, feat_maps); + if (!transposed_nchw) { + if (!useNchw) { // No conversion needed if useNchw set + feat_maps = transpose(NHWC_NCHW, feat_maps); + transposed_nchw = true; + } + } auto output_size = ngraph::Shape{(size_t)output_height, (size_t)output_width}; float spatial_scale = 1.0 / (height_ratio); @@ -91,7 +95,11 @@ std::shared_ptr ROIPooling::createNode() { std::shared_ptr outputNode = std::make_shared( feat_maps, roiNode, output_size, spatial_scale); - if (!useNchw) outputNode = transpose(NCHW_NHWC, outputNode); + auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0); + const auto outputOp = sModelInfo->getOperand(outputIndex); + if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) { + outputNode = transpose(NCHW_NHWC, outputNode); + } ALOGV("%s PASSED", __func__); diff --git a/ngraph_creator/operations/src/ResizeBilinear.cpp b/ngraph_creator/operations/src/ResizeBilinear.cpp index 97e61ce71..c01fc0a34 100644 --- a/ngraph_creator/operations/src/ResizeBilinear.cpp +++ b/ngraph_creator/operations/src/ResizeBilinear.cpp @@ -59,8 +59,12 @@ std::shared_ptr ResizeBilinear::createNode() { input_width = inputDimensions[2]; input_height = inputDimensions[1]; } - - if (!useNchw) inputNode = transpose(NHWC_NCHW, inputNode); + if (!transposed_nchw) { + if (!useNchw) { + inputNode = transpose(NHWC_NCHW, inputNode); + transposed_nchw = true; + } + } // FLOAT16 type check added for future when VPUX plugin support is added if (checkInputOperandType(1, (int32_t)OperandType::FLOAT32)) { // In tensorflow lite, resizing by size is supported. 
Scaling factors are @@ -123,7 +127,9 @@ std::shared_ptr ResizeBilinear::createNode() { outputNode = std::make_shared(inputNode, outputShapeNode, scaleNode, axesNode, attrs); - if (!useNchw) { + auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0); + const auto outputOp = sModelInfo->getOperand(outputIndex); + if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) { outputNode = transpose(NCHW_NHWC, outputNode); } return outputNode; diff --git a/ngraph_creator/operations/src/ResizeNearestNeighbor.cpp b/ngraph_creator/operations/src/ResizeNearestNeighbor.cpp index ba163a729..6712507c7 100644 --- a/ngraph_creator/operations/src/ResizeNearestNeighbor.cpp +++ b/ngraph_creator/operations/src/ResizeNearestNeighbor.cpp @@ -59,8 +59,12 @@ std::shared_ptr ResizeNearestNeighbor::createNode() { input_width = inputDimensions[2]; input_height = inputDimensions[1]; } - - if (!useNchw) inputNode = transpose(NHWC_NCHW, inputNode); + if (!transposed_nchw) { + if (!useNchw) { + inputNode = transpose(NHWC_NCHW, inputNode); + transposed_nchw = true; + } + } attrs.shape_calculation_mode = ngraph::op::v4::Interpolate::ShapeCalcMode::sizes; // mode is passed as "nearest" for Nearest Neighbor interpolation @@ -118,7 +122,9 @@ std::shared_ptr ResizeNearestNeighbor::createNode() { outputNode = std::make_shared(inputNode, outputShapeNode, scaleNode, axesNode, attrs); - if (!useNchw) { + auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0); + const auto outputOp = sModelInfo->getOperand(outputIndex); + if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) { outputNode = transpose(NCHW_NHWC, outputNode); } return outputNode; diff --git a/ngraph_creator/operations/src/SpaceToBatch.cpp b/ngraph_creator/operations/src/SpaceToBatch.cpp index 976c16e7d..b9452abff 100644 --- a/ngraph_creator/operations/src/SpaceToBatch.cpp +++ b/ngraph_creator/operations/src/SpaceToBatch.cpp @@ -78,13 +78,21 @@ std::shared_ptr 
SpaceToBatch::createNode() { if (inputsSize == 4) layout = sModelInfo->ParseOperationInput(mNnapiOperationIndex, 3); if (layout) useNchw = true; - if (!useNchw) // No conversion needed if useNchw set - inputNode = transpose(NHWC_NCHW, inputNode); + if (!transposed_nchw) { + if (!useNchw) { // No conversion needed if useNchw set + inputNode = transpose(NHWC_NCHW, inputNode); + transposed_nchw = true; + } + } std::shared_ptr outputNode = std::make_shared( inputNode, block_shape_node, pad_begin, pad_end); - if (!useNchw) outputNode = transpose(NCHW_NHWC, outputNode); + auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0); + const auto outputOp = sModelInfo->getOperand(outputIndex); + if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) { + outputNode = transpose(NCHW_NHWC, outputNode); + } return outputNode; } diff --git a/ngraph_creator/operations/src/SpaceToDepth.cpp b/ngraph_creator/operations/src/SpaceToDepth.cpp index 83eb8fbb2..2d4ecdf27 100644 --- a/ngraph_creator/operations/src/SpaceToDepth.cpp +++ b/ngraph_creator/operations/src/SpaceToDepth.cpp @@ -26,15 +26,23 @@ std::shared_ptr SpaceToDepth::createNode() { input = getInputNode(0); auto block_size = sModelInfo->ParseOperationInput(mNnapiOperationIndex, 1); - if (!useNchw) // No conversion needed if useNchw set - input = transpose(NHWC_NCHW, input); + if (!transposed_nchw) { + if (!useNchw) { // No conversion needed if useNchw set + input = transpose(NHWC_NCHW, input); + transposed_nchw = true; + } + } std::shared_ptr outputNode; outputNode = std::make_shared( input, ngraph::op::v0::SpaceToDepth::SpaceToDepthMode::BLOCKS_FIRST, block_size); - if (!useNchw) outputNode = transpose(NCHW_NHWC, outputNode); + auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0); + const auto outputOp = sModelInfo->getOperand(outputIndex); + if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) { + outputNode = transpose(NCHW_NHWC, outputNode); + } 
return outputNode; } diff --git a/ngraph_creator/operations/src/TransposeConv2D.cpp b/ngraph_creator/operations/src/TransposeConv2D.cpp index 5b9d6999a..31b4b8d6b 100644 --- a/ngraph_creator/operations/src/TransposeConv2D.cpp +++ b/ngraph_creator/operations/src/TransposeConv2D.cpp @@ -186,8 +186,11 @@ std::shared_ptr TransposeConv2D::createNode() { // OpenVino expects filter in OIHW format filterNode = transpose(IHWO_OIHW, filterNode); - if (!useNchw) { // No conversion needed if useNchw set - inputNode = transpose(NHWC_NCHW, inputNode); + if (!transposed_nchw) { + if (!useNchw) { // No conversion needed if useNchw set + inputNode = transpose(NHWC_NCHW, inputNode); + transposed_nchw = true; + } } strides = {(size_t)stride_height, (size_t)stride_width}; @@ -218,7 +221,9 @@ std::shared_ptr TransposeConv2D::createNode() { transposeConvNode, biasNode, ngraph::op::AutoBroadcastType::NUMPY); outputNode = applyActivation(outputNode, activationFn); - if (!useNchw) { + auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0); + const auto outputOp = sModelInfo->getOperand(outputIndex); + if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) { outputNode = transpose(NCHW_NHWC, outputNode); }