From 1a8ea549c43a9a0a21fa0bef0f9ae0823d2917eb Mon Sep 17 00:00:00 2001 From: Po-V Date: Sun, 29 Dec 2024 03:28:41 +0000 Subject: [PATCH 1/6] add rot90 --- src/frontends/pytorch/src/op/rot90.cpp | 67 ++++++++++++++++++++++++++ src/frontends/pytorch/src/op_table.cpp | 2 + 2 files changed, 69 insertions(+) create mode 100644 src/frontends/pytorch/src/op/rot90.cpp diff --git a/src/frontends/pytorch/src/op/rot90.cpp b/src/frontends/pytorch/src/op/rot90.cpp new file mode 100644 index 00000000000000..db14e455cb4e61 --- /dev/null +++ b/src/frontends/pytorch/src/op/rot90.cpp @@ -0,0 +1,67 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/transpose.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +using namespace ov::op; + +OutputVector translate_rot90(const NodeContext& context) { + num_inputs_check(context, 1, 3); + auto input = context.get_input(0); + int k = context.input_is_none(1) ? 1 : context.const_input(1); + std::vector dims = context.input_is_none(2) ? std::vector{0, 1} + : context.const_input>(2); + const auto& partial_shape = input.get_partial_shape(); + const auto ndims = partial_shape.rank().get_length(); + + PYTORCH_OP_CONVERSION_CHECK(dims.size() == 2, + "Expected total rotation dims == 2, but got dims = ", + dims.size()); + PYTORCH_OP_CONVERSION_CHECK(ndims >= 2, + "Expected total dims >= 2, but got total dims = ", + ndims); + PYTORCH_OP_CONVERSION_CHECK(dims[0] != dims[1], + "Rotation dimensions must be different, but got dim0 = " + + std::to_string(dims[0]) + " and dim1 = " + std::to_string(dims[1])); + + for (auto& dim : dims) { + dim = (dim + ndims) % ndims; + } + + k = k % 4; + Output rotated; + + if (k == 1 || k == 3) { + int64_t flip_dim = (k == 1) ? dims[1] : dims[0]; + auto flip_dims = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {flip_dim})); + auto flipped = create_flip(input, flip_dims); + std::vector perm_values(ndims); + std::iota(perm_values.begin(), perm_values.end(), 0); + std::swap(perm_values[dims[0]], perm_values[dims[1]]); + auto perm = context.mark_node( + v0::Constant::create(element::i32, Shape{static_cast(ndims)}, perm_values)); + rotated = context.mark_node(std::make_shared(flipped, perm)); + } else if (k == 2) { + size_t dims_size = dims.size(); + auto flip_dims = context.mark_node(v0::Constant::create(element::i32, Shape{dims_size}, dims)); + rotated = create_flip(input, flip_dims); + } else { + rotated = input; + } + + return {rotated}; +}; + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index a73c13814d7663..2ccaabe186d0e7 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -200,6 +200,7 @@ OP_CONVERTER(translate_reshape_as); OP_CONVERTER(translate_rnn); OP_CONVERTER(translate_roi_align); OP_CONVERTER(translate_roll); +OP_CONVERTER(translate_rot90); OP_CONVERTER(translate_round); OP_CONVERTER(translate_rsqrt); OP_CONVERTER(translate_rsub); @@ -624,6 +625,7 @@ const std::unordered_map get_supported_ops_ts() { {"aten::rnn_relu", op::translate_rnn}, {"aten::rnn_tanh", op::translate_rnn}, {"aten::roll", op::translate_roll}, + {"aten::rot90", op::translate_rot90}, {"aten::round", op::translate_round}, {"aten::rsqrt", op::optional_out}, {"aten::rsqrt_", op::inplace_op}, From a748733ccdc2f3462d52f2ce496c8cc3540e7c58 Mon Sep 17 00:00:00 2001 From: Po-V Date: Sun, 29 Dec 2024 03:29:38 +0000 Subject: [PATCH 2/6] add rot90 tests --- tests/layer_tests/pytorch_tests/test_rot90.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 tests/layer_tests/pytorch_tests/test_rot90.py diff --git a/tests/layer_tests/pytorch_tests/test_rot90.py b/tests/layer_tests/pytorch_tests/test_rot90.py new file mode 100644 index 00000000000000..ac369353262746 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_rot90.py @@ -0,0 +1,38 @@ +# Copyright (C) 2018-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import numpy as np + +from pytorch_layer_test_class import PytorchLayerTest + + +class TestRot90(PytorchLayerTest): + def _prepare_input(self): + + x = np.arange(24).reshape(2, 3, 4).astype(np.float32) + return (x,) + + def create_model(self, k, dims): + import torch + + class aten_rot90(torch.nn.Module): + def __init__(self, k=1, dims=(0, 1)): + super(aten_rot90, self).__init__() + self.k = k + self.dims = dims + + def forward(self, x): + return torch.rot90(x, self.k, self.dims) + + ref_net = None + return aten_rot90(k, dims), ref_net, "aten::rot90" + + @pytest.mark.parametrize("k", [1, 2, 3, 4, 5]) + @pytest.mark.parametrize("dims", [(0, 1), (0, 2), (1, 2)]) + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.precommit_torch_export + def test_rot90(self, k, dims, ie_device, precision, ir_version): + self._test(*self.create_model(k, dims), ie_device, precision, ir_version, + trace_model=True,dynamic_shapes=False) \ No newline at end of file From 0940ab044fbac3ed1938d4ef1a92c0a97199d7cb Mon Sep 17 00:00:00 2001 From: Po-V Date: Sat, 4 Jan 2025 11:01:23 +0000 Subject: [PATCH 3/6] Address review comments for rot90 PR --- src/frontends/pytorch/src/op/rot90.cpp | 70 +++++++++++++++++--------- src/frontends/pytorch/src/utils.cpp | 9 ++++ src/frontends/pytorch/src/utils.hpp | 2 + 3 files changed, 58 insertions(+), 23 deletions(-) diff --git a/src/frontends/pytorch/src/op/rot90.cpp b/src/frontends/pytorch/src/op/rot90.cpp index db14e455cb4e61..4790d28600b539 100644 --- a/src/frontends/pytorch/src/op/rot90.cpp +++ b/src/frontends/pytorch/src/op/rot90.cpp @@ -5,6 +5,12 @@ #include "openvino/frontend/pytorch/node_context.hpp" #include "openvino/op/constant.hpp" #include "openvino/op/transpose.hpp" +#include "openvino/op/unsqueeze.hpp" +#include "openvino/op/range.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/scatter_elements_update.hpp" +#include "openvino/core/validation_util.hpp" +#include "openvino/op/shape_of.hpp" #include "utils.hpp" namespace ov { @@ -17,43 +23,61 @@ using namespace ov::op; OutputVector translate_rot90(const NodeContext& context) { num_inputs_check(context, 1, 3); auto input = context.get_input(0); - int k = context.input_is_none(1) ? 1 : context.const_input(1); - std::vector dims = context.input_is_none(2) ? std::vector{0, 1} - : context.const_input>(2); + int k = context.input_is_none(1) ? 1 : context.const_input(1); + auto dims = context.input_is_none(2) + ? context.mark_node(v0::Constant::create(element::i32, Shape{2}, {0,1})) + : get_input_as_i32(context, 2); const auto& partial_shape = input.get_partial_shape(); const auto ndims = partial_shape.rank().get_length(); - PYTORCH_OP_CONVERSION_CHECK(dims.size() == 2, + std::shared_ptr rank = std::make_shared( + ov::element::i32, ov::Shape{}, std::vector{static_cast(ndims)}); + auto dims_norm = normalize_axis(context, dims, rank); + auto dims_const = std::dynamic_pointer_cast(dims_norm.get_node_shared_ptr()); + auto dims_values = dims_const->cast_vector(); + + auto start = v0::Constant::create(element::i32, {}, {0}); + auto step = v0::Constant::create(element::i32, {}, {1}); + auto range = std::make_shared(start, rank, step, element::i32); + + auto axis_0 = v0::Constant::create(element::i32, Shape{}, {0}); + auto dim0_node = std::make_shared( + v0::Constant::create(element::i32, {}, {dims_values[0]}), axis_0); + auto dim1_node = std::make_shared( + v0::Constant::create(element::i32, {}, {dims_values[1]}), axis_0); + + auto indices = std::make_shared(OutputVector{dim0_node, dim1_node}, 0); + auto updates = std::make_shared( + OutputVector{dim1_node, dim0_node}, 0); + + Output scatter = std::make_shared( + range, indices, updates, axis_0); + if (const auto scatter_const = ov::util::get_constant_from_source(scatter)) { + scatter = context.mark_node(scatter_const); + } else { + context.mark_nodes( + {start, step, range, axis_0, dim0_node, dim1_node, indices, updates, scatter.get_node_shared_ptr()}); + } + + PYTORCH_OP_CONVERSION_CHECK(dims_values.size() == 2, "Expected total rotation dims == 2, but got dims = ", - dims.size()); + dims_values.size()); PYTORCH_OP_CONVERSION_CHECK(ndims >= 2, "Expected total dims >= 2, but got total dims = ", ndims); - PYTORCH_OP_CONVERSION_CHECK(dims[0] != dims[1], + PYTORCH_OP_CONVERSION_CHECK(dims_values[0] != dims_values[1], "Rotation dimensions must be different, but got dim0 = " + - std::to_string(dims[0]) + " and dim1 = " + std::to_string(dims[1])); - - for (auto& dim : dims) { - dim = (dim + ndims) % ndims; - } + std::to_string(dims_values[0]) + " and dim1 = " + std::to_string(dims_values[1])); k = k % 4; Output rotated; if (k == 1 || k == 3) { - int64_t flip_dim = (k == 1) ? dims[1] : dims[0]; - auto flip_dims = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {flip_dim})); - auto flipped = create_flip(input, flip_dims); - std::vector perm_values(ndims); - std::iota(perm_values.begin(), perm_values.end(), 0); - std::swap(perm_values[dims[0]], perm_values[dims[1]]); - auto perm = context.mark_node( - v0::Constant::create(element::i32, Shape{static_cast(ndims)}, perm_values)); - rotated = context.mark_node(std::make_shared(flipped, perm)); + Output flip_dims = (k ==1) ? dim1_node : dim0_node; + auto flipped = create_flip(input, flip_dims); + rotated = context.mark_node(std::make_shared(flipped, scatter)); } else if (k == 2) { - size_t dims_size = dims.size(); - auto flip_dims = context.mark_node(v0::Constant::create(element::i32, Shape{dims_size}, dims)); - rotated = create_flip(input, flip_dims); + rotated = create_flip(input, dims_norm); } else { rotated = input; } diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp index 5cc7ec21f30911..a48fce79bb73e3 100644 --- a/src/frontends/pytorch/src/utils.cpp +++ b/src/frontends/pytorch/src/utils.cpp @@ -167,6 +167,15 @@ Output normalize_axis(const NodeContext& context, const Output& axis } } +Output create_flip(const Output& x, const Output& axis) { + auto minus_one = v0::Constant::create(element::i32, Shape{}, {-1}); + auto minimum_int = v0::Constant::create(element::i32, Shape{}, {std::numeric_limits::min()}); + auto axis_shape = std::make_shared(axis, element::i32); + auto start = std::make_shared(minus_one, axis_shape); + auto stop = std::make_shared(minimum_int, axis_shape); + return std::make_shared(x, start, stop, start, axis); +}; + std::shared_ptr numel(const NodeContext& context, const Output& x, element::Type output_type) { auto input_shape = context.mark_node(std::make_shared(x, output_type)); auto axes = context.mark_node(v0::Constant::create(output_type, Shape({1}), {0})); diff --git a/src/frontends/pytorch/src/utils.hpp b/src/frontends/pytorch/src/utils.hpp index 9346b9e18b94a3..024e9158e9cdcd 100644 --- a/src/frontends/pytorch/src/utils.hpp +++ b/src/frontends/pytorch/src/utils.hpp @@ -59,6 +59,8 @@ std::shared_ptr get_node_axes_range(const NodeContext& context, const Outp Output normalize_axis(const NodeContext& context, const Output& axis, const Output& input_node); +Output create_flip(const Output& x, const Output& axis); + std::shared_ptr numel(const NodeContext& context, const Output& x, element::Type output_type = element::i32); From 06924c9e2340372563d441289f25cfe2ca16f0f4 Mon Sep 17 00:00:00 2001 From: Po-V Date: Wed, 8 Jan 2025 14:57:25 +0000 Subject: [PATCH 4/6] refactor dims handling in translate_rot90 and scatter update --- src/frontends/pytorch/src/op/rot90.cpp | 81 ++++++++++++-------------- 1 file changed, 38 insertions(+), 43 deletions(-) diff --git a/src/frontends/pytorch/src/op/rot90.cpp b/src/frontends/pytorch/src/op/rot90.cpp index 4790d28600b539..2f87b242ffdd49 100644 --- a/src/frontends/pytorch/src/op/rot90.cpp +++ b/src/frontends/pytorch/src/op/rot90.cpp @@ -2,15 +2,16 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "openvino/core/validation_util.hpp" #include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/concat.hpp" #include "openvino/op/constant.hpp" -#include "openvino/op/transpose.hpp" -#include "openvino/op/unsqueeze.hpp" #include "openvino/op/range.hpp" -#include "openvino/op/concat.hpp" #include "openvino/op/scatter_elements_update.hpp" -#include "openvino/core/validation_util.hpp" #include "openvino/op/shape_of.hpp" +#include "openvino/op/split.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/op/unsqueeze.hpp" #include "utils.hpp" namespace ov { @@ -24,57 +25,51 @@ OutputVector translate_rot90(const NodeContext& context) { num_inputs_check(context, 1, 3); auto input = context.get_input(0); int k = context.input_is_none(1) ? 1 : context.const_input(1); - auto dims = context.input_is_none(2) - ? context.mark_node(v0::Constant::create(element::i32, Shape{2}, {0,1})) - : get_input_as_i32(context, 2); - const auto& partial_shape = input.get_partial_shape(); - const auto ndims = partial_shape.rank().get_length(); - - std::shared_ptr rank = std::make_shared( - ov::element::i32, ov::Shape{}, std::vector{static_cast(ndims)}); - auto dims_norm = normalize_axis(context, dims, rank); - auto dims_const = std::dynamic_pointer_cast(dims_norm.get_node_shared_ptr()); - auto dims_values = dims_const->cast_vector(); - auto start = v0::Constant::create(element::i32, {}, {0}); - auto step = v0::Constant::create(element::i32, {}, {1}); - auto range = std::make_shared(start, rank, step, element::i32); + auto dims = context.input_is_none(2) ? context.mark_node(v0::Constant::create(element::i32, Shape{2}, {0, 1})) + : get_input_as_i32(context, 2); - auto axis_0 = v0::Constant::create(element::i32, Shape{}, {0}); - auto dim0_node = std::make_shared( - v0::Constant::create(element::i32, {}, {dims_values[0]}), axis_0); - auto dim1_node = std::make_shared( - v0::Constant::create(element::i32, {}, {dims_values[1]}), axis_0); + auto ndims = 0; + if (input.get_partial_shape().rank().is_static()) { + ndims = input.get_partial_shape().rank().get_length(); + } - auto indices = std::make_shared(OutputVector{dim0_node, dim1_node}, 0); - auto updates = std::make_shared( - OutputVector{dim1_node, dim0_node}, 0); + std::shared_ptr rank = + std::make_shared(ov::element::i32, + ov::Shape{}, + std::vector{static_cast(ndims)}); + + auto dims_norm = normalize_axis(context, dims, rank); + auto dims_const = std::dynamic_pointer_cast(dims_norm.get_node_shared_ptr()); + auto dims_values = dims_const->cast_vector(); - Output scatter = std::make_shared( - range, indices, updates, axis_0); - if (const auto scatter_const = ov::util::get_constant_from_source(scatter)) { - scatter = context.mark_node(scatter_const); - } else { - context.mark_nodes( - {start, step, range, axis_0, dim0_node, dim1_node, indices, updates, scatter.get_node_shared_ptr()}); - } - PYTORCH_OP_CONVERSION_CHECK(dims_values.size() == 2, "Expected total rotation dims == 2, but got dims = ", dims_values.size()); - PYTORCH_OP_CONVERSION_CHECK(ndims >= 2, - "Expected total dims >= 2, but got total dims = ", - ndims); + + PYTORCH_OP_CONVERSION_CHECK(ndims >= 2, "Expected total dims >= 2, but got total dims = ", ndims); + PYTORCH_OP_CONVERSION_CHECK(dims_values[0] != dims_values[1], "Rotation dimensions must be different, but got dim0 = " + std::to_string(dims_values[0]) + " and dim1 = " + std::to_string(dims_values[1])); + auto start = v0::Constant::create(element::i32, {}, {0}); + auto step = v0::Constant::create(element::i32, {}, {1}); + auto range = std::make_shared(start, rank, step, element::i32); + auto axis_0 = v0::Constant::create(element::i32, Shape{}, {0}); + auto split = std::make_shared(dims_norm, axis_0, 2); + auto dim0_node = std::make_shared(split->output(0), axis_0); + auto dim1_node = std::make_shared(split->output(1), axis_0); + auto indices = std::make_shared(OutputVector{dim0_node, dim1_node}, 0); + auto updates = std::make_shared(OutputVector{dim1_node, dim0_node}, 0); + + Output scatter = std::make_shared(range, indices, updates, axis_0); + k = k % 4; Output rotated; - if (k == 1 || k == 3) { - Output flip_dims = (k ==1) ? dim1_node : dim0_node; - auto flipped = create_flip(input, flip_dims); + Output flip_dims = (k == 1) ? dim1_node : dim0_node; + auto flipped = create_flip(input, flip_dims); rotated = context.mark_node(std::make_shared(flipped, scatter)); } else if (k == 2) { rotated = create_flip(input, dims_norm); @@ -83,9 +78,9 @@ OutputVector translate_rot90(const NodeContext& context) { } return {rotated}; -}; +} } // namespace op } // namespace pytorch } // namespace frontend -} // namespace ov +} // namespace ov \ No newline at end of file From 6d9e19dc9f560a8f446bf71df75ff6b02ce0afe6 Mon Sep 17 00:00:00 2001 From: Po-V Date: Wed, 8 Jan 2025 17:23:30 +0000 Subject: [PATCH 5/6] removed dims_values from translate_rot90 --- src/frontends/pytorch/src/op/rot90.cpp | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/src/frontends/pytorch/src/op/rot90.cpp b/src/frontends/pytorch/src/op/rot90.cpp index 2f87b242ffdd49..cba13f1f2edfec 100644 --- a/src/frontends/pytorch/src/op/rot90.cpp +++ b/src/frontends/pytorch/src/op/rot90.cpp @@ -34,24 +34,14 @@ OutputVector translate_rot90(const NodeContext& context) { ndims = input.get_partial_shape().rank().get_length(); } + PYTORCH_OP_CONVERSION_CHECK(ndims >= 2, "Expected total dims >= 2, but got total dims = ", ndims); + std::shared_ptr rank = std::make_shared(ov::element::i32, ov::Shape{}, std::vector{static_cast(ndims)}); auto dims_norm = normalize_axis(context, dims, rank); - auto dims_const = std::dynamic_pointer_cast(dims_norm.get_node_shared_ptr()); - auto dims_values = dims_const->cast_vector(); - - PYTORCH_OP_CONVERSION_CHECK(dims_values.size() == 2, - "Expected total rotation dims == 2, but got dims = ", - dims_values.size()); - - PYTORCH_OP_CONVERSION_CHECK(ndims >= 2, "Expected total dims >= 2, but got total dims = ", ndims); - - PYTORCH_OP_CONVERSION_CHECK(dims_values[0] != dims_values[1], - "Rotation dimensions must be different, but got dim0 = " + - std::to_string(dims_values[0]) + " and dim1 = " + std::to_string(dims_values[1])); auto start = v0::Constant::create(element::i32, {}, {0}); auto step = v0::Constant::create(element::i32, {}, {1}); From 0626fb30b1d9bb072847ad826e07bec4326a6bc3 Mon Sep 17 00:00:00 2001 From: Po-V Date: Thu, 16 Jan 2025 15:59:43 +0000 Subject: [PATCH 6/6] Fix conversion failure for aten::rot90 operation in OpenVINO frontend --- src/frontends/pytorch/src/op/rot90.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/frontends/pytorch/src/op/rot90.cpp b/src/frontends/pytorch/src/op/rot90.cpp index cba13f1f2edfec..bd5b9ac095438d 100644 --- a/src/frontends/pytorch/src/op/rot90.cpp +++ b/src/frontends/pytorch/src/op/rot90.cpp @@ -52,8 +52,9 @@ OutputVector translate_rot90(const NodeContext& context) { auto dim1_node = std::make_shared(split->output(1), axis_0); auto indices = std::make_shared(OutputVector{dim0_node, dim1_node}, 0); auto updates = std::make_shared(OutputVector{dim1_node, dim0_node}, 0); + auto expanded_range = std::make_shared(range, axis_0); - Output scatter = std::make_shared(range, indices, updates, axis_0); + Output scatter = std::make_shared(expanded_range, indices, updates, axis_0); k = k % 4; Output rotated;