Skip to content
This repository has been archived by the owner on Jul 18, 2024. It is now read-only.

Remove extra Transpose layers from the generated graph #152

Open
wants to merge 1 commit into
base: A11_V1.3_dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ngraph_creator/operations/include/OperationsBase.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ class OperationsBase {
// override connectOperationToGraph in case Operation has multiple outputs
virtual void connectOperationToGraph();
virtual ~OperationsBase() {}
bool transposed_nchw = false;
};

} // namespace nnhal
Expand Down
11 changes: 8 additions & 3 deletions ngraph_creator/operations/src/AveragePool2D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,11 @@ std::shared_ptr<ngraph::Node> AveragePool2D::createNode() {
}
}

if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
transposed_nchw = true;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this code be shared by multiple operations? It seems like it's needed in all operations; maybe we can have a function in the parent class?

}

strides = {(size_t)stride_height, (size_t)stride_width};
Expand All @@ -135,7 +138,9 @@ std::shared_ptr<ngraph::Node> AveragePool2D::createNode() {

outputNode = applyActivation(outputNode, activationFn);

if (!useNchw) {
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

Expand Down
14 changes: 11 additions & 3 deletions ngraph_creator/operations/src/BatchToSpace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,21 @@ std::shared_ptr<ngraph::Node> BatchToSpace::createNode() {
const auto crop_begin = createConstNode(ngraph::element::i64, {shape.size()}, shape);
const auto crop_end = createConstNode(ngraph::element::i64, {shape.size()}, shape);

if (!useNchw) // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
transposed_nchw = true;
}
}

std::shared_ptr<ngraph::Node> outputNode = std::make_shared<ngraph::opset3::BatchToSpace>(
inputNode, block_shape_node, crop_begin, crop_end);

if (!useNchw) outputNode = transpose(NCHW_NHWC, outputNode);
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

return outputNode;
}
Expand Down
11 changes: 8 additions & 3 deletions ngraph_creator/operations/src/Conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,11 @@ std::shared_ptr<ngraph::Node> Conv2d::createNode() {

// OpenVino expects filter in OIHW format
filterNode = transpose(OHWI_OIHW, filterNode);
if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
transposed_nchw = true;
}
}

strides = {(size_t)stride_height, (size_t)stride_width};
Expand All @@ -229,7 +232,9 @@ std::shared_ptr<ngraph::Node> Conv2d::createNode() {
convNode, biasNode, ngraph::op::AutoBroadcastType::NUMPY);
outputNode = applyActivation(outputNode, activationFn);

if (!useNchw) {
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

Expand Down
14 changes: 11 additions & 3 deletions ngraph_creator/operations/src/DepthToSpace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,23 @@ std::shared_ptr<ngraph::Node> DepthToSpace::createNode() {
input = getInputNode(0);
auto block_size = sModelInfo->ParseOperationInput<uint32_t>(mNnapiOperationIndex, 1);

if (!useNchw) // No conversion needed if useNchw set
input = transpose(NHWC_NCHW, input);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
input = transpose(NHWC_NCHW, input);
transposed_nchw = true;
}
}

std::shared_ptr<ngraph::Node> outputNode;

outputNode = std::make_shared<ngraph::opset3::DepthToSpace>(
input, ngraph::op::v0::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST, block_size);

if (!useNchw) outputNode = transpose(NCHW_NHWC, outputNode);
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

return outputNode;
}
Expand Down
11 changes: 8 additions & 3 deletions ngraph_creator/operations/src/DepthwiseConv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,11 @@ std::shared_ptr<ngraph::Node> DepthwiseConv2d::createNode() {

// OpenVino expects filter in OIHW format
filterNode = transpose(IHWO_OIHW, filterNode);
if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
transposed_nchw = true;
}
}

strides = {(size_t)stride_height, (size_t)stride_width};
Expand Down Expand Up @@ -249,7 +252,9 @@ std::shared_ptr<ngraph::Node> DepthwiseConv2d::createNode() {
groupConvNode, biasNode, ngraph::op::AutoBroadcastType::NUMPY);
outputNode = applyActivation(outputNode, activationFn);

if (!useNchw) {
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

Expand Down
11 changes: 8 additions & 3 deletions ngraph_creator/operations/src/GroupedConv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,11 @@ std::shared_ptr<ngraph::Node> GroupedConv2d::createNode() {
// OpenVino expects filter in OIHW format
filterNode = transpose(OHWI_OIHW, filterNode);

if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
transposed_nchw = true;
}
}

strides = {(size_t)stride_width, (size_t)stride_height};
Expand Down Expand Up @@ -218,7 +221,9 @@ std::shared_ptr<ngraph::Node> GroupedConv2d::createNode() {
groupConvNode, biasNode, ngraph::op::AutoBroadcastType::NUMPY);
outputNode = applyActivation(outputNode, activationFn);

if (!useNchw) {
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

Expand Down
14 changes: 11 additions & 3 deletions ngraph_creator/operations/src/InstanceNormalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,12 @@ std::shared_ptr<ngraph::Node> InstanceNormalization::createNode() {
auto layout = sModelInfo->ParseOperationInput<uint8_t>(mNnapiOperationIndex, 4);
if (layout) useNchw = true;

if (!useNchw) // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
transposed_nchw = true;
}
}

// output[b, h, w, c] = (input[b, h, w, c] - mean[b, c]) * gamma /
// sqrt(var[b, c] + epsilon) + beta
Expand All @@ -71,7 +75,11 @@ std::shared_ptr<ngraph::Node> InstanceNormalization::createNode() {
std::shared_ptr<ngraph::Node> outputNode =
std::make_shared<ngraph::opset3::Add>(mulGamma, betaNode);

if (!useNchw) outputNode = transpose(NCHW_NHWC, outputNode);
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}
ALOGV("%s PASSED", __func__);

return outputNode;
Expand Down
13 changes: 9 additions & 4 deletions ngraph_creator/operations/src/L2Pooling2D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,12 @@ std::shared_ptr<ngraph::Node> L2Pooling2D::createNode() {
inputNode = getInputNode(0);
inputSquared = std::make_shared<ngraph::op::v1::Multiply>(inputNode, inputNode);

if (!useNchw) {
ALOGD("%s Forced NCHW conversion at operationIndex %d", __func__, mNnapiOperationIndex);
inputSquared = transpose(NHWC_NCHW, inputSquared);
if (!transposed_nchw) {
if (!useNchw) {
ALOGD("%s Forced NCHW conversion at operationIndex %d", __func__, mNnapiOperationIndex);
inputSquared = transpose(NHWC_NCHW, inputSquared);
transposed_nchw = true;
}
}

strides = {(size_t)stride_height, (size_t)stride_width};
Expand All @@ -132,7 +135,9 @@ std::shared_ptr<ngraph::Node> L2Pooling2D::createNode() {

auto outputNode = applyActivation(sqrtOutput, activationFn);

if (!useNchw) {
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

Expand Down
11 changes: 8 additions & 3 deletions ngraph_creator/operations/src/MaxPool2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,11 @@ std::shared_ptr<ngraph::Node> MaxPool2d::createNode() {
std::shared_ptr<ngraph::Node> inputNode;
inputNode = getInputNode(0);

if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
transposed_nchw = true;
}
}

strides = {(size_t)stride_height, (size_t)stride_width};
Expand All @@ -130,7 +133,9 @@ std::shared_ptr<ngraph::Node> MaxPool2d::createNode() {

auto outputNode = applyActivation(maxpoolNode, activationFn);

if (!useNchw) {
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

Expand Down
14 changes: 11 additions & 3 deletions ngraph_creator/operations/src/ROIAlign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,12 @@ std::shared_ptr<ngraph::Node> ROIAlign::createNode() {

if (layout) useNchw = true;

if (!useNchw) // No conversion needed if useNchw set
feat_maps = transpose(NHWC_NCHW, feat_maps);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
feat_maps = transpose(NHWC_NCHW, feat_maps);
transposed_nchw = true;
}
}

float spatial_scale = 1.0 / (height_ratio);
int sampling_ratio = sampling_pts_h;
Expand All @@ -85,7 +89,11 @@ std::shared_ptr<ngraph::Node> ROIAlign::createNode() {
feat_maps, rois, batch_indices, output_height, output_width, sampling_ratio, spatial_scale,
"avg");

if (!useNchw) outputNode = transpose(NCHW_NHWC, outputNode);
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

ALOGV("%s PASSED", __func__);

Expand Down
14 changes: 11 additions & 3 deletions ngraph_creator/operations/src/ROIPooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,12 @@ std::shared_ptr<ngraph::Node> ROIPooling::createNode() {

if (layout) useNchw = true;

if (!useNchw) // No conversion needed if useNchw set
feat_maps = transpose(NHWC_NCHW, feat_maps);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
feat_maps = transpose(NHWC_NCHW, feat_maps);
transposed_nchw = true;
}
}

auto output_size = ngraph::Shape{(size_t)output_height, (size_t)output_width};
float spatial_scale = 1.0 / (height_ratio);
Expand All @@ -91,7 +95,11 @@ std::shared_ptr<ngraph::Node> ROIPooling::createNode() {
std::shared_ptr<ngraph::Node> outputNode = std::make_shared<ngraph::opset3::ROIPooling>(
feat_maps, roiNode, output_size, spatial_scale);

if (!useNchw) outputNode = transpose(NCHW_NHWC, outputNode);
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

ALOGV("%s PASSED", __func__);

Expand Down
12 changes: 9 additions & 3 deletions ngraph_creator/operations/src/ResizeBilinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,12 @@ std::shared_ptr<ngraph::Node> ResizeBilinear::createNode() {
input_width = inputDimensions[2];
input_height = inputDimensions[1];
}

if (!useNchw) inputNode = transpose(NHWC_NCHW, inputNode);
if (!transposed_nchw) {
if (!useNchw) {
inputNode = transpose(NHWC_NCHW, inputNode);
transposed_nchw = true;
}
}
// FLOAT16 type check added for future when VPUX plugin support is added
if (checkInputOperandType(1, (int32_t)OperandType::FLOAT32)) {
// In tensorflow lite, resizing by size is supported. Scaling factors are
Expand Down Expand Up @@ -123,7 +127,9 @@ std::shared_ptr<ngraph::Node> ResizeBilinear::createNode() {

outputNode = std::make_shared<ngraph::op::v4::Interpolate>(inputNode, outputShapeNode,
scaleNode, axesNode, attrs);
if (!useNchw) {
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}
return outputNode;
Expand Down
12 changes: 9 additions & 3 deletions ngraph_creator/operations/src/ResizeNearestNeighbor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,12 @@ std::shared_ptr<ngraph::Node> ResizeNearestNeighbor::createNode() {
input_width = inputDimensions[2];
input_height = inputDimensions[1];
}

if (!useNchw) inputNode = transpose(NHWC_NCHW, inputNode);
if (!transposed_nchw) {
if (!useNchw) {
inputNode = transpose(NHWC_NCHW, inputNode);
transposed_nchw = true;
}
}

attrs.shape_calculation_mode = ngraph::op::v4::Interpolate::ShapeCalcMode::sizes;
// mode is passed as "nearest" for Nearest Neighbor interpolation
Expand Down Expand Up @@ -118,7 +122,9 @@ std::shared_ptr<ngraph::Node> ResizeNearestNeighbor::createNode() {

outputNode = std::make_shared<ngraph::op::v4::Interpolate>(inputNode, outputShapeNode,
scaleNode, axesNode, attrs);
if (!useNchw) {
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}
return outputNode;
Expand Down
14 changes: 11 additions & 3 deletions ngraph_creator/operations/src/SpaceToBatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,21 @@ std::shared_ptr<ngraph::Node> SpaceToBatch::createNode() {
if (inputsSize == 4) layout = sModelInfo->ParseOperationInput<uint8_t>(mNnapiOperationIndex, 3);
if (layout) useNchw = true;

if (!useNchw) // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
inputNode = transpose(NHWC_NCHW, inputNode);
transposed_nchw = true;
}
}

std::shared_ptr<ngraph::Node> outputNode = std::make_shared<ngraph::opset3::SpaceToBatch>(
inputNode, block_shape_node, pad_begin, pad_end);

if (!useNchw) outputNode = transpose(NCHW_NHWC, outputNode);
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

return outputNode;
}
Expand Down
14 changes: 11 additions & 3 deletions ngraph_creator/operations/src/SpaceToDepth.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,23 @@ std::shared_ptr<ngraph::Node> SpaceToDepth::createNode() {
input = getInputNode(0);
auto block_size = sModelInfo->ParseOperationInput<uint32_t>(mNnapiOperationIndex, 1);

if (!useNchw) // No conversion needed if useNchw set
input = transpose(NHWC_NCHW, input);
if (!transposed_nchw) {
if (!useNchw) { // No conversion needed if useNchw set
input = transpose(NHWC_NCHW, input);
transposed_nchw = true;
}
}

std::shared_ptr<ngraph::Node> outputNode;

outputNode = std::make_shared<ngraph::opset3::SpaceToDepth>(
input, ngraph::op::v0::SpaceToDepth::SpaceToDepthMode::BLOCKS_FIRST, block_size);

if (!useNchw) outputNode = transpose(NCHW_NHWC, outputNode);
auto outputIndex = sModelInfo->getOperationOutput(mNnapiOperationIndex, 0);
const auto outputOp = sModelInfo->getOperand(outputIndex);
if (!useNchw && (outputOp.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT)) {
outputNode = transpose(NCHW_NHWC, outputNode);
}

return outputNode;
}
Expand Down
Loading