diff --git a/src/frontends/onnx/frontend/src/utils/dft.cpp b/src/frontends/onnx/frontend/src/utils/dft.cpp index 4b9702f54ba0b4..0977f062d55d39 100644 --- a/src/frontends/onnx/frontend/src/utils/dft.cpp +++ b/src/frontends/onnx/frontend/src/utils/dft.cpp @@ -12,6 +12,7 @@ #include "openvino/op/idft.hpp" #include "openvino/op/irdft.hpp" #include "openvino/op/rdft.hpp" +#include "openvino/op/reshape.hpp" #include "openvino/op/shape_of.hpp" #include "openvino/op/unsqueeze.hpp" @@ -55,27 +56,31 @@ ov::Output make_dft(const ov::Output& signal, conversion_to_complex_applied = try_convert_real_to_complex(processed_signal); } - bool dft_length_provided = !ov::op::util::is_null(length); + const bool dft_length_provided = !ov::op::util::is_null(length); + const auto& signal_size = + dft_length_provided + ? std::make_shared(length, v0::Constant::create(ov::element::i32, {1}, {1}), false)->output(0) + : length; ov::Output result; if (is_inversed) { if (is_onesided) { - result = dft_length_provided ? std::make_shared(processed_signal, axis_const, length) + result = dft_length_provided ? std::make_shared(processed_signal, axis_const, signal_size) : std::make_shared(processed_signal, axis_const); if (conversion_to_complex_applied) { // align the output shape with a real numbers representation - const auto unsqueeze_axis = v0::Constant::create(ov::element::i64, {}, {-1}); + const auto unsqueeze_axis = v0::Constant::create(ov::element::i32, {}, {-1}); result = std::make_shared(result, unsqueeze_axis); } } else { - result = dft_length_provided ? std::make_shared(processed_signal, axis_const, length) + result = dft_length_provided ? std::make_shared(processed_signal, axis_const, signal_size) : std::make_shared(processed_signal, axis_const); } } else { if (is_onesided) { - result = dft_length_provided ? std::make_shared(processed_signal, axis_const, length) + result = dft_length_provided ? std::make_shared(processed_signal, axis_const, signal_size) : std::make_shared(processed_signal, axis_const); } else { - result = dft_length_provided ? std::make_shared(processed_signal, axis_const, length) + result = dft_length_provided ? std::make_shared(processed_signal, axis_const, signal_size) : std::make_shared(processed_signal, axis_const); } } diff --git a/src/frontends/onnx/tests/models/dft_scalar_length_provided.prototxt b/src/frontends/onnx/tests/models/dft_scalar_length_provided.prototxt new file mode 100644 index 00000000000000..bbd7accd246588 --- /dev/null +++ b/src/frontends/onnx/tests/models/dft_scalar_length_provided.prototxt @@ -0,0 +1,71 @@ +ir_version: 7 +graph { + node { + output: "dft_length" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 0 + data_type: 7 + int64_data: 1 + name: "const_tensor" + } + type: TENSOR + } + } + node { + input: "data" + input: "dft_length" + output: "out" + op_type: "DFT" + attribute { + name: "inverse" + i: 0 + type: INT + } + attribute { + name: "onesided" + i: 0 + type: INT + } + attribute { + name: "axis" + i: 0 + type: INT + } + } + input { + name: "data" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 5 + } + dim { + dim_value: 2 + } + } + } + } + } + output { + name: "out" + type { + tensor_type { + elem_type: 1 + shape { + } + } + } + } +} +opset_import { + domain: "" + version: 13 +} diff --git a/src/frontends/onnx/tests/onnx_import_signal.in.cpp b/src/frontends/onnx/tests/onnx_import_signal.in.cpp index 14cf03b1dfb886..665598fe6dbff4 100644 --- a/src/frontends/onnx/tests/onnx_import_signal.in.cpp +++ b/src/frontends/onnx/tests/onnx_import_signal.in.cpp @@ -124,6 +124,19 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_dft_length_provided) { {0.000000f, 0.000000f, 1.000000f, 0.000000f, 2.000000f, 0.000000f, 3.000000f, 0.000000f, 4.000000f, 0.000000f}); } +OPENVINO_TEST(${BACKEND_NAME}, onnx_model_dft_scalar_length_provided) { + auto model = convert_model("dft_scalar_length_provided.onnx"); + auto test_case = ov::test::TestCase(model, s_device); + test_case.add_input(Shape{3, 5, 2}, {0.000000f, 0.000000f, 1.000000f, 0.000000f, 2.000000f, 0.000000f, + 3.000000f, 0.000000f, 4.000000f, 0.000000f, 5.000000f, 0.000000f, + 6.000000f, 0.000000f, 7.000000f, 0.000000f, 8.000000f, 0.000000f, + 9.000000f, 0.000000f, 10.000000f, 0.000000f, 11.000000f, 0.000000f, + 12.000000f, 0.000000f, 13.000000f, 0.000000f, 14.000000f, 0.000000f}); + test_case.add_expected_output( + Shape{1, 5, 2}, + {0.000000f, 0.000000f, 1.000000f, 0.000000f, 2.000000f, 0.000000f, 3.000000f, 0.000000f, 4.000000f, 0.000000f}); +} + OPENVINO_TEST(${BACKEND_NAME}, onnx_model_dft_length_provided_onesided) { auto model = convert_model("dft_lenght_provided_onesided.onnx"); auto test_case = ov::test::TestCase(model, s_device);