-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PT FE] Add aten::rot90 #28224
Open
Po-V
wants to merge
8
commits into
openvinotoolkit:master
Choose a base branch
from
Po-V:add-pytorch-rot90
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+127
−0
Open
[PT FE] Add aten::rot90 #28224
Changes from 6 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
1a8ea54
add rot90
Po-V a748733
add rot90 tests
Po-V 0940ab0
Address review comments for rot90 PR
Po-V 5bb26cf
Merge remote-tracking branch 'upstream/master' into add-pytorch-rot90
Po-V 74f0f21
Merge remote-tracking branch 'upstream/master' into add-pytorch-rot90
Po-V 06924c9
refactor dims handling in translate_rot90 and scatter update
Po-V f4511b1
Merge remote-tracking branch 'upstream/master' into add-pytorch-rot90
Po-V 6d9e19d
removed dims_values from translate_rot90
Po-V File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "openvino/core/validation_util.hpp" | ||
#include "openvino/frontend/pytorch/node_context.hpp" | ||
#include "openvino/op/concat.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/range.hpp" | ||
#include "openvino/op/scatter_elements_update.hpp" | ||
#include "openvino/op/shape_of.hpp" | ||
#include "openvino/op/split.hpp" | ||
#include "openvino/op/transpose.hpp" | ||
#include "openvino/op/unsqueeze.hpp" | ||
#include "utils.hpp" | ||
|
||
namespace ov { | ||
namespace frontend { | ||
namespace pytorch { | ||
namespace op { | ||
|
||
using namespace ov::op; | ||
|
||
OutputVector translate_rot90(const NodeContext& context) { | ||
num_inputs_check(context, 1, 3); | ||
auto input = context.get_input(0); | ||
int k = context.input_is_none(1) ? 1 : context.const_input<int32_t>(1); | ||
|
||
auto dims = context.input_is_none(2) ? context.mark_node(v0::Constant::create(element::i32, Shape{2}, {0, 1})) | ||
: get_input_as_i32(context, 2); | ||
|
||
auto ndims = 0; | ||
if (input.get_partial_shape().rank().is_static()) { | ||
ndims = input.get_partial_shape().rank().get_length(); | ||
} | ||
|
||
std::shared_ptr<ov::Node> rank = | ||
std::make_shared<ov::op::v0::Constant>(ov::element::i32, | ||
ov::Shape{}, | ||
std::vector<int32_t>{static_cast<int32_t>(ndims)}); | ||
|
||
auto dims_norm = normalize_axis(context, dims, rank); | ||
auto dims_const = std::dynamic_pointer_cast<v0::Constant>(dims_norm.get_node_shared_ptr()); | ||
auto dims_values = dims_const->cast_vector<int32_t>(); | ||
|
||
PYTORCH_OP_CONVERSION_CHECK(dims_values.size() == 2, | ||
"Expected total rotation dims == 2, but got dims = ", | ||
dims_values.size()); | ||
|
||
PYTORCH_OP_CONVERSION_CHECK(ndims >= 2, "Expected total dims >= 2, but got total dims = ", ndims); | ||
|
||
PYTORCH_OP_CONVERSION_CHECK(dims_values[0] != dims_values[1], | ||
"Rotation dimensions must be different, but got dim0 = " + | ||
std::to_string(dims_values[0]) + " and dim1 = " + std::to_string(dims_values[1])); | ||
|
||
auto start = v0::Constant::create(element::i32, {}, {0}); | ||
auto step = v0::Constant::create(element::i32, {}, {1}); | ||
auto range = std::make_shared<v4::Range>(start, rank, step, element::i32); | ||
auto axis_0 = v0::Constant::create(element::i32, Shape{}, {0}); | ||
auto split = std::make_shared<v1::Split>(dims_norm, axis_0, 2); | ||
auto dim0_node = std::make_shared<v0::Unsqueeze>(split->output(0), axis_0); | ||
auto dim1_node = std::make_shared<v0::Unsqueeze>(split->output(1), axis_0); | ||
auto indices = std::make_shared<v0::Concat>(OutputVector{dim0_node, dim1_node}, 0); | ||
auto updates = std::make_shared<v0::Concat>(OutputVector{dim1_node, dim0_node}, 0); | ||
|
||
Output<Node> scatter = std::make_shared<v3::ScatterElementsUpdate>(range, indices, updates, axis_0); | ||
|
||
k = k % 4; | ||
Output<Node> rotated; | ||
if (k == 1 || k == 3) { | ||
Output<Node> flip_dims = (k == 1) ? dim1_node : dim0_node; | ||
auto flipped = create_flip(input, flip_dims); | ||
rotated = context.mark_node(std::make_shared<v1::Transpose>(flipped, scatter)); | ||
} else if (k == 2) { | ||
rotated = create_flip(input, dims_norm); | ||
} else { | ||
rotated = input; | ||
} | ||
|
||
return {rotated}; | ||
} | ||
|
||
} // namespace op | ||
} // namespace pytorch | ||
} // namespace frontend | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Copyright (C) 2018-2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
import numpy as np | ||
|
||
from pytorch_layer_test_class import PytorchLayerTest | ||
|
||
|
||
class TestRot90(PytorchLayerTest): | ||
def _prepare_input(self): | ||
|
||
x = np.arange(24).reshape(2, 3, 4).astype(np.float32) | ||
return (x,) | ||
|
||
def create_model(self, k, dims): | ||
import torch | ||
|
||
class aten_rot90(torch.nn.Module): | ||
def __init__(self, k=1, dims=(0, 1)): | ||
super(aten_rot90, self).__init__() | ||
self.k = k | ||
self.dims = dims | ||
|
||
def forward(self, x): | ||
return torch.rot90(x, self.k, self.dims) | ||
|
||
ref_net = None | ||
return aten_rot90(k, dims), ref_net, "aten::rot90" | ||
|
||
@pytest.mark.parametrize("k", [1, 2, 3, 4, 5]) | ||
@pytest.mark.parametrize("dims", [(0, 1), (0, 2), (1, 2)]) | ||
@pytest.mark.nightly | ||
@pytest.mark.precommit | ||
@pytest.mark.precommit_torch_export | ||
def test_rot90(self, k, dims, ie_device, precision, ir_version): | ||
self._test(*self.create_model(k, dims), ie_device, precision, ir_version, | ||
trace_model=True,dynamic_shapes=False) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now you only use
dims_values
only to validate that inputs are expected. That is nice to have, but not required and much better if we allow non-constant inputs then make these checks. So, please remove them. The only check that you can do is to verify that shape of dims is[2]
ordynamic
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have removed the dims_values but I think there is some build errors in the pr request. Do i need to add anything if its dynamic ?