Skip to content

Commit

Permalink
[PT FE] Improve support for complex data type (#28482)
Browse files Browse the repository at this point in the history
### Details:
 - *Remove transformations for FFT*
 - *Use `ComplexTypeMark` to provide information about a complex type*

### Tickets:
 - *CVS-159375*

---------

Signed-off-by: Maxim Vafin <[email protected]>
Co-authored-by: Roman Kazantsev <[email protected]>
  • Loading branch information
mvafin and rkazants authored Jan 17, 2025
1 parent 049c8ba commit 0848f86
Show file tree
Hide file tree
Showing 18 changed files with 497 additions and 450 deletions.
9 changes: 5 additions & 4 deletions src/frontends/pytorch/src/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
#include "transforms/dict_resolver.hpp"
#include "transforms/einsum_list_construct.hpp"
#include "transforms/index_loop_getitem_replacer.hpp"
#include "transforms/irfftn_complex_replacer.hpp"
#include "transforms/listconstruct_replacer.hpp"
#include "transforms/min_max_prim_list_construct_replacer.hpp"
#include "transforms/prim_list_construct_pad.hpp"
Expand All @@ -40,7 +39,6 @@
#include "transforms/quantized_node_remover.hpp"
#include "transforms/remove_packing_ops.hpp"
#include "transforms/reverseprop_resolver.hpp"
#include "transforms/rfftn_complex_replacer.hpp"
#include "transforms/softmax_reshape_elimination.hpp"
#include "transforms/string_equality_replacer.hpp"
#include "transforms/torchfx_gptq_pattern_replacer.hpp"
Expand Down Expand Up @@ -69,6 +67,11 @@ std::map<std::string, std::string> get_unconverted_types_from_model(const std::s
if (!unconverted_ops_types.count(op_type_it->second)) {
unconverted_ops_types.emplace(op_type_it->second, std::move(exception_msg));
}
} else if (const auto& fw_node = ov::as_type_ptr<ov::op::util::FrameworkNode>(node)) {
auto op_type = std::string(fw_node->get_type_name());
if (!unconverted_ops_types.count(op_type)) {
unconverted_ops_types.emplace(op_type, "This is OpenVINO internal type.");
}
}
if (const auto& fw_node = ov::as_type_ptr<ov::op::util::MultiSubGraphOp>(node)) {
for (size_t i = 0; i < fw_node->get_internal_subgraphs_size(); ++i) {
Expand Down Expand Up @@ -283,8 +286,6 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.register_pass<ov::frontend::pytorch::pass::AtenEinsumListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::MinMaxPrimListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::StringEqualityReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::RFFTNComplexReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::IRFFTNComplexReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::PrimTupleUnpackReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::DecomposeListTupleResults>();
manager.register_pass<ov::frontend::pytorch::pass::DecomposeUnpackParameters>();
Expand Down
84 changes: 84 additions & 0 deletions src/frontends/pytorch/src/op/complex.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/complex_type_mark.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/split.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_complex(const NodeContext& context) {
num_inputs_check(context, 2, 2);
auto real = context.get_input(0);
auto imag = context.get_input(1);

auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
real = context.mark_node(std::make_shared<v0::Unsqueeze>(real, const_neg_1));
imag = context.mark_node(std::make_shared<v0::Unsqueeze>(imag, const_neg_1));

auto complex = context.mark_node(std::make_shared<v0::Concat>(OutputVector{real, imag}, -1));

return {context.mark_node(std::make_shared<ComplexTypeMark>(complex, complex->get_element_type()))};
};

OutputVector translate_imag(const NodeContext& context) {
num_inputs_check(context, 1, 1, true);
auto complex = context.get_input(0);

auto complex_type_mark = as_type_ptr<ComplexTypeMark>(complex.get_node_shared_ptr());
PYTORCH_OP_CONVERSION_CHECK(complex_type_mark, "aten::imag operation expects complex type tensor on input.");

complex = complex_type_mark->input_value(0);
auto axis = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
auto imag = context.mark_node(std::make_shared<v1::Split>(complex, axis, 2))->output(1);

auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
return {context.mark_node(std::make_shared<v0::Squeeze>(imag, const_neg_1))};
};

OutputVector translate_real(const NodeContext& context) {
num_inputs_check(context, 1, 1, true);
auto complex = context.get_input(0);

auto complex_type_mark = as_type_ptr<ComplexTypeMark>(complex.get_node_shared_ptr());
PYTORCH_OP_CONVERSION_CHECK(complex_type_mark, "aten::real operation expects complex type tensor on input.");

complex = complex_type_mark->input_value(0);
auto axis = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
auto real = context.mark_node(std::make_shared<v1::Split>(complex, axis, 2))->output(0);

auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
return {context.mark_node(std::make_shared<v0::Squeeze>(real, const_neg_1))};
};

OutputVector translate_view_as_real(const NodeContext& context) {
num_inputs_check(context, 1, 1, true);
auto complex = context.get_input(0);

auto complex_type_mark = as_type_ptr<ComplexTypeMark>(complex.get_node_shared_ptr());
PYTORCH_OP_CONVERSION_CHECK(complex_type_mark, "aten::real operation expects complex type tensor on input.");

return {complex_type_mark->input_value(0)};
};

OutputVector translate_view_as_complex(const NodeContext& context) {
num_inputs_check(context, 1, 1);
auto complex = context.get_input(0);

return {context.mark_node(std::make_shared<ComplexTypeMark>(complex, complex.get_element_type()))};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
208 changes: 208 additions & 0 deletions src/frontends/pytorch/src/op/fft.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/complex_type_mark.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/equal.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/irdft.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/rdft.hpp"
#include "openvino/op/reduce_prod.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/scatter_update.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/sqrt.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/subtract.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_fft_rfftn(const NodeContext& context) {
// aten::fft_rfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
num_inputs_check(context, 1, 4);
auto input = context.get_input(0);

auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));

Output<Node> input_shape;
Output<Node> input_rank_scalar;
std::tie(input_shape, input_rank_scalar) = get_shape_rank(context, input, true);

Output<Node> raw_s;
// Inputs can be either none or List. Check whether input values should be used or should be set to default values.
if (!context.input_is_none(1)) {
// s is provided, load from input.
raw_s = get_input_concat_if_list(context, 1);
raw_s = context.mark_node(std::make_shared<v0::Convert>(raw_s, element::i32));
}
Output<Node> dim;
// Handle dim parameter containing vector of integers indicating dimensions to be transformed.
if (!context.input_is_none(2)) {
// dim is provided, load from input.
dim = get_input_concat_if_list(context, 2);
dim = context.mark_node(std::make_shared<v0::Convert>(dim, element::i32));
} else if (!context.input_is_none(1)) {
// If dim is default and s is provided, use last s_len dimensions where s_len is length of s.
auto s_len = context.mark_node(std::make_shared<v3::ShapeOf>(raw_s, element::i32));
auto slice_start = context.mark_node(std::make_shared<v1::Subtract>(input_rank_scalar, s_len));
auto slice_start_scalar = context.mark_node(std::make_shared<v0::Squeeze>(slice_start));
dim = context.mark_node(
std::make_shared<v4::Range>(slice_start_scalar, input_rank_scalar, const_1, element::i32));
} else {
// Dim and s are set to default, use all of dimensions.
dim = context.mark_node(std::make_shared<v4::Range>(const_0, input_rank_scalar, const_1, element::i32));
}

Output<Node> s;
if (context.input_is_none(1)) {
// Value for s was set to default, use full size for all dimensions.
s = context.mark_node(std::make_shared<v8::Gather>(input_shape, dim, const_0));
} else {
// Values for s were provided. Replace -1 values with default full size in given dimension.
auto full_s_cond = context.mark_node(std::make_shared<v1::Equal>(raw_s, const_neg_1));
auto full_s_values = context.mark_node(std::make_shared<v8::Gather>(input_shape, dim, const_0));
s = context.mark_node(std::make_shared<v1::Select>(full_s_cond, full_s_values, raw_s));
}

// Handle norm parameter indicating normalization mode to use. Defaults to "backward".
std::string norm = "backward";
if (!context.input_is_none(3)) {
norm = context.const_input<std::string>(3);
}

auto rdft = context.mark_node(std::make_shared<v9::RDFT>(input, dim, s));

// Apply normalizations
auto n_int = context.mark_node(std::make_shared<v1::ReduceProd>(s, const_0));
auto n = context.mark_node(std::make_shared<v1::ConvertLike>(n_int, rdft));
Output<Node> normalized_rfftn;
if (norm == "forward") {
// Normalize by 1/n
normalized_rfftn = context.mark_node(std::make_shared<v1::Divide>(rdft, n));
} else if (norm == "backward") {
// No normalization
normalized_rfftn = rdft;
} else if (norm == "ortho") {
// Normalize by 1/sqrt(n)
auto sqrt_n = context.mark_node(std::make_shared<v0::Sqrt>(n));
normalized_rfftn = context.mark_node(std::make_shared<v1::Divide>(rdft, sqrt_n));
} else {
FRONT_END_THROW(
"aten::fft_rfftn: unrecognized normalization mode. Only forward, backward and ortho are supported.");
}

return {std::make_shared<ComplexTypeMark>(normalized_rfftn, normalized_rfftn.get_element_type())};
}

OutputVector translate_fft_irfftn(const NodeContext& context) {
// aten::fft_irfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
num_inputs_check(context, 1, 4, true);
auto input = context.get_input(0);

auto complex_type_mark = as_type_ptr<ComplexTypeMark>(input.get_node_shared_ptr());
PYTORCH_OP_CONVERSION_CHECK(complex_type_mark, "aten::fft_irfftn operation expects complex type tensor on input.");
input = complex_type_mark->input_value(0);

auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
auto const_scalar_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
auto const_scalar_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
auto const_2 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {2}));

// Input shape of complex number (excluding dimension created by concatenation of real and imag)
auto complex_input_shape = get_complex_shape(context, input);
auto input_rank = context.mark_node(std::make_shared<v3::ShapeOf>(complex_input_shape, element::i32));
auto input_rank_scalar = context.mark_node(std::make_shared<v0::Squeeze>(input_rank));

Output<Node> raw_s;
// Inputs can be either none or List. Check whether input values should be used or should be set to default values.
if (!context.input_is_none(1)) {
// s is provided, load from input.
raw_s = get_input_concat_if_list(context, 1);
raw_s = context.mark_node(std::make_shared<v0::Convert>(raw_s, element::i32));
}

// Handle dim parameter containing vector of integers indicating dimensions to be transformed.
Output<Node> dim;
if (!context.input_is_none(2)) {
// Dim values is provided, load from input.
dim = get_input_concat_if_list(context, 2);
dim = context.mark_node(std::make_shared<v0::Convert>(dim, element::i32));
} else if (!context.input_is_none(1)) {
// If dim is default and s is provided, use last s_len dimensions where s_len is length of s.
auto s_len = context.mark_node(std::make_shared<v3::ShapeOf>(raw_s, element::i32));
auto range_start = context.mark_node(std::make_shared<v1::Subtract>(input_rank, s_len));
auto range_start_scalar = context.mark_node(std::make_shared<v0::Squeeze>(range_start));
dim = context.mark_node(
std::make_shared<v4::Range>(range_start_scalar, input_rank_scalar, const_scalar_1, element::i32));
} else {
// Dim and s are set to default, use all of dimensions.
dim = context.mark_node(
std::make_shared<v4::Range>(const_scalar_0, input_rank_scalar, const_scalar_1, element::i32));
}

// Calculate default s values. Use full available size except last element, which is set to even value in last
// dimension: s[-1] = 2 * (complex_input_shape[dim[-1]])
auto default_s_raw = context.mark_node(std::make_shared<v8::Gather>(complex_input_shape, dim, const_0));
auto last_s = context.mark_node(std::make_shared<v8::Gather>(default_s_raw, const_neg_1, const_0));
auto last_s_m_1 = context.mark_node(std::make_shared<v1::Subtract>(last_s, const_1));
auto s_upd = context.mark_node(std::make_shared<v1::Multiply>(last_s_m_1, const_2));
auto s_shape = context.mark_node(std::make_shared<v3::ShapeOf>(default_s_raw, element::i32));
auto last_s_idx = context.mark_node(std::make_shared<v1::Subtract>(s_shape, const_1));
auto default_s = context.mark_node(std::make_shared<v3::ScatterUpdate>(default_s_raw, last_s_idx, s_upd, const_0));

// Handle s parameter containing vector of intigers indicating signal sizes for dimensions.
Output<Node> s;
if (!context.input_is_none(1)) {
// Values for s were provided. Replace -1 values with default full size in given dimension.
auto full_s_cond = context.mark_node(std::make_shared<v1::Equal>(raw_s, const_neg_1));
s = context.mark_node(std::make_shared<v1::Select>(full_s_cond, default_s, raw_s));
} else {
// Value for s was set to default.
s = default_s;
}

// Handle norm parameter indicating normalization mode to use. Defaults to "backward".
std::string norm = "backward";
if (!context.input_is_none(3)) {
norm = context.const_input<std::string>(3);
}

auto irdft = context.mark_node(std::make_shared<v9::IRDFT>(input, dim, s));

// Apply normalizations.
auto n_int = context.mark_node(std::make_shared<v1::ReduceProd>(s, const_0));
auto n = context.mark_node(std::make_shared<v1::ConvertLike>(n_int, irdft));
Output<Node> normalized_irfftn;
if (norm == "forward") {
normalized_irfftn = context.mark_node(std::make_shared<v1::Multiply>(irdft, n));
} else if (norm == "backward") {
normalized_irfftn = irdft;
} else if (norm == "ortho") {
auto sqrt_n = context.mark_node(std::make_shared<v0::Sqrt>(n));
normalized_irfftn = context.mark_node(std::make_shared<v1::Multiply>(irdft, sqrt_n));
} else {
FRONT_END_THROW(
"aten::fft_irfftn: unrecognized normalization mode. Only forward, backward and ortho are supported.");
}
return {normalized_irfftn};
}

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
35 changes: 31 additions & 4 deletions src/frontends/pytorch/src/op/permute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
//

#include "openvino/core/validation_util.hpp"
#include "openvino/frontend/complex_type_mark.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/transpose.hpp"
#include "utils.hpp"

Expand All @@ -12,17 +15,41 @@ namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_permute(const NodeContext& context) {
num_inputs_check(context, 2, 2);
num_inputs_check(context, 2, 2, true);
auto data = context.get_input(0);
auto order = get_input_concat_if_list(context, 1);
auto rank = std::get<1>(get_shape_rank(context, data));
auto rank_converted = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(rank, order));

Output<Node> rank;
auto complex_type_mark = as_type_ptr<ComplexTypeMark>(data.get_node_shared_ptr());
if (complex_type_mark) {
data = complex_type_mark->input_value(0);
rank = std::get<1>(get_shape_rank(context, data));
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
rank = context.mark_node(std::make_shared<v1::Subtract>(rank, const_1));
} else {
rank = std::get<1>(get_shape_rank(context, data));
}

auto rank_converted = context.mark_node(std::make_shared<v1::ConvertLike>(rank, order));
auto order_normalized = normalize_axis(context, order, rank_converted);

if (complex_type_mark) {
auto to_concat = OutputVector{order_normalized, rank_converted};
order_normalized = context.mark_node(std::make_shared<v0::Concat>(to_concat, 0));
}

if (const auto order_const = ov::util::get_constant_from_source(order_normalized)) {
order_normalized = order_const;
}
return {context.mark_node(std::make_shared<ov::op::v1::Transpose>(data, order_normalized))};
auto permute = context.mark_node(std::make_shared<v1::Transpose>(data, order_normalized));
if (complex_type_mark) {
const auto& complex_dtype = complex_type_mark->get_complex_part_type();
permute = context.mark_node(std::make_shared<ComplexTypeMark>(permute, complex_dtype));
}
return {permute};
}

} // namespace op
Expand Down
Loading

0 comments on commit 0848f86

Please sign in to comment.