Skip to content
This repository has been archived by the owner on Jul 18, 2024. It is now read-only.

Commit

Permalink
Remove extra Transpose layers from the generated graph
Browse files Browse the repository at this point in the history
Transpose layers are added to every operator's input and output if the
layout is NHWC.
Instead, they should be added only at the beginning and end of the graph.

Signed-off-by: Ritul Jasuja <[email protected]>
  • Loading branch information
Rjasuja committed Feb 17, 2022
1 parent a2a9c70 commit a981368
Show file tree
Hide file tree
Showing 17 changed files with 153 additions and 49 deletions.
1 change: 1 addition & 0 deletions ngraph_creator/operations/include/OperationsBase.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ class OperationsBase {
// override connectOperationToGraph in case Operation has multiple outputs
virtual void connectOperationToGraph();
virtual ~OperationsBase() {}
bool transposed_nchw = false;
};

} // namespace nnhal
Expand Down
11 changes: 8 additions & 3 deletions ngraph_creator/operations/src/AveragePool2D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,11 @@ std::shared_ptr<ngraph::Node> AveragePool2D::createNode() {
}
}

if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
transposed_nchw = true;
}
}

strides = {(size_t)stride_height, (size_t)stride_width};
Expand All @@ -135,7 +138,9 @@ std::shared_ptr<ngraph::Node> AveragePool2D::createNode() {

outputNode = applyActivation(outputNode, activationFn);

if (!useNchw) {
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

Expand Down
14 changes: 11 additions & 3 deletions ngraph_creator/operations/src/BatchToSpace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,21 @@ std::shared_ptr<ngraph::Node> BatchToSpace::createNode() {
const auto crop_begin = createConstNode(ngraph::element::i64, {shape.size()}, shape);
const auto crop_end = createConstNode(ngraph::element::i64, {shape.size()}, shape);

if (!useNchw) // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
transposed_nchw = true;
}
}

std::shared_ptr<ngraph::Node> outputNode = std::make_shared<ngraph::opset3::BatchToSpace>(
inputNode, block_shape_node, crop_begin, crop_end);

if (!useNchw) outputNode = transpose(NCHW_NHWC, outputNode);
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

return outputNode;
}
Expand Down
11 changes: 8 additions & 3 deletions ngraph_creator/operations/src/Conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,11 @@ std::shared_ptr<ngraph::Node> Conv2d::createNode() {

// OpenVino expects filter in OIHW format
filterNode = transpose(OHWI_OIHW, filterNode);
if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
transposed_nchw = true;
}
}

strides = {(size_t)stride_height, (size_t)stride_width};
Expand All @@ -229,7 +232,9 @@ std::shared_ptr<ngraph::Node> Conv2d::createNode() {
convNode, biasNode, ngraph::op::AutoBroadcastType::NUMPY);
outputNode = applyActivation(outputNode, activationFn);

if (!useNchw) {
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

Expand Down
14 changes: 11 additions & 3 deletions ngraph_creator/operations/src/DepthToSpace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,23 @@ std::shared_ptr<ngraph::Node> DepthToSpace::createNode() {
input = getInputNode(0);
auto block_size = sModelInfo->ParseOperationInput<uint32_t>(mNnapiOperationIndex, 1);

if (!useNchw) // No conversion needed if useNchw set
input = transpose(NHWC_NCHW, input);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
input = transpose(NHWC_NCHW, input);
transposed_nchw = true;
}
}

std::shared_ptr<ngraph::Node> outputNode;

outputNode = std::make_shared<ngraph::opset3::DepthToSpace>(
input, ngraph::op::v0::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST, block_size);

if (!useNchw) outputNode = transpose(NCHW_NHWC, outputNode);
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

return outputNode;
}
Expand Down
11 changes: 8 additions & 3 deletions ngraph_creator/operations/src/DepthwiseConv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,11 @@ std::shared_ptr<ngraph::Node> DepthwiseConv2d::createNode() {

// OpenVino expects filter in OIHW format
filterNode = transpose(IHWO_OIHW, filterNode);
if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
transposed_nchw = true;
}
}

strides = {(size_t)stride_height, (size_t)stride_width};
Expand Down Expand Up @@ -249,7 +252,9 @@ std::shared_ptr<ngraph::Node> DepthwiseConv2d::createNode() {
groupConvNode, biasNode, ngraph::op::AutoBroadcastType::NUMPY);
outputNode = applyActivation(outputNode, activationFn);

if (!useNchw) {
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

Expand Down
11 changes: 8 additions & 3 deletions ngraph_creator/operations/src/GroupedConv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,11 @@ std::shared_ptr<ngraph::Node> GroupedConv2d::createNode() {
// OpenVino expects filter in OIHW format
filterNode = transpose(OHWI_OIHW, filterNode);

if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
transposed_nchw = true;
}
}

strides = {(size_t)stride_width, (size_t)stride_height};
Expand Down Expand Up @@ -218,7 +221,9 @@ std::shared_ptr<ngraph::Node> GroupedConv2d::createNode() {
groupConvNode, biasNode, ngraph::op::AutoBroadcastType::NUMPY);
outputNode = applyActivation(outputNode, activationFn);

if (!useNchw) {
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

Expand Down
14 changes: 11 additions & 3 deletions ngraph_creator/operations/src/InstanceNormalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,12 @@ std::shared_ptr<ngraph::Node> InstanceNormalization::createNode() {
auto layout = sModelInfo->ParseOperationInput<uint8_t>(mNnapiOperationIndex, 4);
if (layout) useNchw = true;

if (!useNchw) // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
transposed_nchw = true;
}
}

// output[b, h, w, c] = (input[b, h, w, c] - mean[b, c]) * gamma /
// sqrt(var[b, c] + epsilon) + beta
Expand All @@ -71,7 +75,11 @@ std::shared_ptr<ngraph::Node> InstanceNormalization::createNode() {
std::shared_ptr<ngraph::Node> outputNode =
std::make_shared<ngraph::opset3::Add>(mulGamma, betaNode);

if (!useNchw) outputNode = transpose(NCHW_NHWC, outputNode);
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}
ALOGV("%s PASSED", __func__);

return outputNode;
Expand Down
13 changes: 9 additions & 4 deletions ngraph_creator/operations/src/L2Pooling2D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,12 @@ std::shared_ptr<ngraph::Node> L2Pooling2D::createNode() {
inputNode = getInputNode(0);
inputSquared = std::make_shared<ngraph::op::v1::Multiply>(inputNode, inputNode);

if (!useNchw) {
ALOGD("%s Forced NCHW conversion at operationIndex %d", __func__, mNnapiOperationIndex);
inputSquared = transpose(NHWC_NCHW, inputSquared);
if (!transposed_nchw) {
if (!useNchw) {
ALOGD("%s Forced NCHW conversion at operationIndex %d", __func__, mNnapiOperationIndex);
inputSquared = transpose(NHWC_NCHW, inputSquared);
transposed_nchw = true;
}
}

strides = {(size_t)stride_height, (size_t)stride_width};
Expand All @@ -132,7 +135,9 @@ std::shared_ptr<ngraph::Node> L2Pooling2D::createNode() {

auto outputNode = applyActivation(sqrtOutput, activationFn);

if (!useNchw) {
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

Expand Down
11 changes: 8 additions & 3 deletions ngraph_creator/operations/src/MaxPool2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,11 @@ std::shared_ptr<ngraph::Node> MaxPool2d::createNode() {
std::shared_ptr<ngraph::Node> inputNode;
inputNode = getInputNode(0);

if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
transposed_nchw = true;
}
}

strides = {(size_t)stride_height, (size_t)stride_width};
Expand All @@ -130,7 +133,9 @@ std::shared_ptr<ngraph::Node> MaxPool2d::createNode() {

auto outputNode = applyActivation(maxpoolNode, activationFn);

if (!useNchw) {
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

Expand Down
14 changes: 11 additions & 3 deletions ngraph_creator/operations/src/ROIAlign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,12 @@ std::shared_ptr<ngraph::Node> ROIAlign::createNode() {

if (layout) useNchw = true;

if (!useNchw) // No conversion needed if useNchw set
feat_maps = transpose(NHWC_NCHW, feat_maps);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
feat_maps = transpose(NHWC_NCHW, feat_maps);
transposed_nchw = true;
}
}

float spatial_scale = 1.0 / (height_ratio);
int sampling_ratio = sampling_pts_h;
Expand All @@ -85,7 +89,11 @@ std::shared_ptr<ngraph::Node> ROIAlign::createNode() {
feat_maps, rois, batch_indices, output_height, output_width, sampling_ratio, spatial_scale,
"avg");

if (!useNchw) outputNode = transpose(NCHW_NHWC, outputNode);
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

ALOGV("%s PASSED", __func__);

Expand Down
14 changes: 11 additions & 3 deletions ngraph_creator/operations/src/ROIPooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,12 @@ std::shared_ptr<ngraph::Node> ROIPooling::createNode() {

if (layout) useNchw = true;

if (!useNchw) // No conversion needed if useNchw set
feat_maps = transpose(NHWC_NCHW, feat_maps);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
feat_maps = transpose(NHWC_NCHW, feat_maps);
transposed_nchw = true;
}
}

auto output_size = ngraph::Shape{(size_t)output_height, (size_t)output_width};
float spatial_scale = 1.0 / (height_ratio);
Expand All @@ -91,7 +95,11 @@ std::shared_ptr<ngraph::Node> ROIPooling::createNode() {
std::shared_ptr<ngraph::Node> outputNode = std::make_shared<ngraph::opset3::ROIPooling>(
feat_maps, roiNode, output_size, spatial_scale);

if (!useNchw) outputNode = transpose(NCHW_NHWC, outputNode);
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

ALOGV("%s PASSED", __func__);

Expand Down
12 changes: 9 additions & 3 deletions ngraph_creator/operations/src/ResizeBilinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,12 @@ std::shared_ptr<ngraph::Node> ResizeBilinear::createNode() {
input_width = inputDimensions[2];
input_height = inputDimensions[1];
}

if (!useNchw) inputNode = transpose(NHWC_NCHW, inputNode);
if (!transposed_nchw) {
if (!useNchw) {
inputNode = transpose(NHWC_NCHW, inputNode);
transposed_nchw = true;
}
}
// FLOAT16 type check added for future when VPUX plugin support is added
if (checkInputOperandType(1, (int32_t)OperandType::FLOAT32)) {
// In tensorflow lite, resizing by size is supported. Scaling factors are
Expand Down Expand Up @@ -123,7 +127,9 @@ std::shared_ptr<ngraph::Node> ResizeBilinear::createNode() {

outputNode = std::make_shared<ngraph::op::v4::Interpolate>(inputNode, outputShapeNode,
scaleNode, axesNode, attrs);
if (!useNchw) {
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}
return outputNode;
Expand Down
12 changes: 9 additions & 3 deletions ngraph_creator/operations/src/ResizeNearestNeighbor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,12 @@ std::shared_ptr<ngraph::Node> ResizeNearestNeighbor::createNode() {
input_width = inputDimensions[2];
input_height = inputDimensions[1];
}

if (!useNchw) inputNode = transpose(NHWC_NCHW, inputNode);
if (!transposed_nchw) {
if (!useNchw) {
inputNode = transpose(NHWC_NCHW, inputNode);
transposed_nchw = true;
}
}

attrs.shape_calculation_mode = ngraph::op::v4::Interpolate::ShapeCalcMode::sizes;
// mode is passed as "nearest" for Nearest Neighbor interpolation
Expand Down Expand Up @@ -118,7 +122,9 @@ std::shared_ptr<ngraph::Node> ResizeNearestNeighbor::createNode() {

outputNode = std::make_shared<ngraph::op::v4::Interpolate>(inputNode, outputShapeNode,
scaleNode, axesNode, attrs);
if (!useNchw) {
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}
return outputNode;
Expand Down
14 changes: 11 additions & 3 deletions ngraph_creator/operations/src/SpaceToBatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,21 @@ std::shared_ptr<ngraph::Node> SpaceToBatch::createNode() {
if (inputsSize == 4) layout = sModelInfo->ParseOperationInput<uint8_t>(mNnapiOperationIndex, 3);
if (layout) useNchw = true;

if (!useNchw) // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
transposed_nchw = true;
}
}

std::shared_ptr<ngraph::Node> outputNode = std::make_shared<ngraph::opset3::SpaceToBatch>(
inputNode, block_shape_node, pad_begin, pad_end);

if (!useNchw) outputNode = transpose(NCHW_NHWC, outputNode);
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

return outputNode;
}
Expand Down
14 changes: 11 additions & 3 deletions ngraph_creator/operations/src/SpaceToDepth.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,23 @@ std::shared_ptr<ngraph::Node> SpaceToDepth::createNode() {
input = getInputNode(0);
auto block_size = sModelInfo->ParseOperationInput<uint32_t>(mNnapiOperationIndex, 1);

if (!useNchw) // No conversion needed if useNchw set
input = transpose(NHWC_NCHW, input);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
input = transpose(NHWC_NCHW, input);
transposed_nchw = true;
}
}

std::shared_ptr<ngraph::Node> outputNode;

outputNode = std::make_shared<ngraph::opset3::SpaceToDepth>(
input, ngraph::op::v0::SpaceToDepth::SpaceToDepthMode::BLOCKS_FIRST, block_size);

if (!useNchw) outputNode = transpose(NCHW_NHWC, outputNode);
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

return outputNode;
}
Expand Down
Loading

0 comments on commit a981368

Please sign in to comment.