diff --git a/src/frontends/onnx/frontend/src/op/gather_nd.cpp b/src/frontends/onnx/frontend/src/op/gather_nd.cpp
index 7f9c3fff303383..b141030901ea26 100644
--- a/src/frontends/onnx/frontend/src/op/gather_nd.cpp
+++ b/src/frontends/onnx/frontend/src/op/gather_nd.cpp
@@ -8,6 +8,14 @@
 #include "openvino/op/gather_nd.hpp"
 
 #include "core/operator_set.hpp"
+#include "openvino/op/broadcast.hpp"
+#include "openvino/op/concat.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/op/maximum.hpp"
+#include "openvino/op/shape_of.hpp"
+#include "openvino/op/slice.hpp"
+
 using namespace ov::op;
 
 namespace ov {
@@ -15,12 +23,88 @@ namespace frontend {
 namespace onnx {
 namespace ai_onnx {
 namespace opset_1 {
+
+namespace {
+// Helper function to extract a dimension from a shape tensor at given index
+ov::Output<ov::Node> get_dimension(const ov::Output<ov::Node>& shape, int64_t index) {
+    auto axis = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
+    auto start = v0::Constant::create(ov::element::i64, ov::Shape{1}, {index});
+    auto stop = v0::Constant::create(ov::element::i64, ov::Shape{1}, {index + 1});
+    auto step = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1});
+    return std::make_shared<v8::Slice>(shape, start, stop, step, axis);
+}
+}  // namespace
+
 ov::OutputVector gather_nd(const ov::frontend::onnx::Node& node) {
     const ov::OutputVector ng_inputs{node.get_ov_inputs()};
-    const auto data = ng_inputs.at(0);
-    const auto indices = ng_inputs.at(1);
+    auto data = ng_inputs.at(0);
+    auto indices = ng_inputs.at(1);
     const auto batch_dims = node.get_attribute_value<int64_t>("batch_dims", 0);
 
+    // If batch_dims > 0, we need to handle broadcasting for batch dimensions
+    // This is a workaround for ONNXRuntime's non-standard behavior that allows
+    // dimension 1 to broadcast to any size N in batch dimensions
+    if (batch_dims > 0) {
+        // Check if we can determine statically that broadcasting is not needed
+        bool need_broadcast = false;
+        bool shapes_are_static = data.get_partial_shape().is_static() && indices.get_partial_shape().is_static();
+
+        if (shapes_are_static) {
+            // Compare batch dimensions statically
+            auto data_shape_static = data.get_shape();
+            auto indices_shape_static = indices.get_shape();
+
+            for (int64_t i = 0; i < batch_dims; ++i) {
+                if (data_shape_static[i] != indices_shape_static[i]) {
+                    need_broadcast = true;
+                    break;
+                }
+            }
+        } else {
+            // Dynamic shapes - conservatively assume broadcast may be needed
+            need_broadcast = true;
+        }
+
+        // Only add Broadcast operations if needed
+        if (need_broadcast) {
+            auto data_shape = std::make_shared<v3::ShapeOf>(data, ov::element::i64);
+            auto indices_shape = std::make_shared<v3::ShapeOf>(indices, ov::element::i64);
+
+            // Compute target batch shape as max(data_batch_shape, indices_batch_shape)
+            ov::OutputVector batch_dims_vec;
+            for (int64_t i = 0; i < batch_dims; ++i) {
+                auto data_dim = get_dimension(data_shape, i);
+                auto indices_dim = get_dimension(indices_shape, i);
+                auto max_dim = std::make_shared<v1::Maximum>(data_dim, indices_dim);
+                batch_dims_vec.push_back(max_dim);
+            }
+
+            // Get remaining dimensions
+            auto zero_const = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
+            auto batch_dims_const = v0::Constant::create(ov::element::i64, ov::Shape{1}, {batch_dims});
+            auto one_step = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1});
+
+            auto data_rank_node = std::make_shared<v3::ShapeOf>(data_shape, ov::element::i64);
+            auto indices_rank_node = std::make_shared<v3::ShapeOf>(indices_shape, ov::element::i64);
+
+            auto data_remaining =
+                std::make_shared<v8::Slice>(data_shape, batch_dims_const, data_rank_node, one_step, zero_const);
+            auto indices_remaining =
+                std::make_shared<v8::Slice>(indices_shape, batch_dims_const, indices_rank_node, one_step, zero_const);
+
+            // Construct target shapes
+            auto target_batch_shape = std::make_shared<v0::Concat>(batch_dims_vec, 0);
+            auto target_data_shape =
+                std::make_shared<v0::Concat>(ov::OutputVector{target_batch_shape, data_remaining}, 0);
+            auto target_indices_shape =
+                std::make_shared<v0::Concat>(ov::OutputVector{target_batch_shape, indices_remaining}, 0);
+
+            // Broadcast data and indices to target shapes
+            data = std::make_shared<v3::Broadcast>(data, target_data_shape);
+            indices = std::make_shared<v3::Broadcast>(indices, target_indices_shape);
+        }
+    }
+
     return {std::make_shared<v8::GatherND>(data, indices, batch_dims)};
 }
diff --git a/src/frontends/onnx/tests/models/gatherND_batch_dims_1_broadcast.prototxt b/src/frontends/onnx/tests/models/gatherND_batch_dims_1_broadcast.prototxt
new file mode 100644
index 00000000000000..a0259419516fb6
--- /dev/null
+++ b/src/frontends/onnx/tests/models/gatherND_batch_dims_1_broadcast.prototxt
@@ -0,0 +1,79 @@
+ir_version: 3
+producer_name: "OpenVINO ONNX Frontend"
+graph {
+  node {
+    input: "data"
+    input: "indices"
+    output: "y"
+    op_type: "GatherND"
+    attribute {
+      name: "batch_dims"
+      i: 1
+      type: INT
+    }
+  }
+  name: "test_gatherND_batch_dims_1_broadcast"
+  input {
+    name: "data"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 4
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "indices"
+    type {
+      tensor_type {
+        elem_type: 7
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 2
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "y"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 13
+}
diff --git a/src/frontends/onnx/tests/models/gatherND_batch_dims_1_no_broadcast.prototxt b/src/frontends/onnx/tests/models/gatherND_batch_dims_1_no_broadcast.prototxt
new file mode 100644
index 00000000000000..47ef1f8d1b670a
--- /dev/null
+++ b/src/frontends/onnx/tests/models/gatherND_batch_dims_1_no_broadcast.prototxt
@@ -0,0 +1,79 @@
+ir_version: 3
+producer_name: "OpenVINO ONNX Frontend"
+graph {
+  node {
+    input: "data"
+    input: "indices"
+    output: "y"
+    op_type: "GatherND"
+    attribute {
+      name: "batch_dims"
+      i: 1
+      type: INT
+    }
+  }
+  name: "test_gatherND_batch_dims_1_no_broadcast"
+  input {
+    name: "data"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 4
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "indices"
+    type {
+      tensor_type {
+        elem_type: 7
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 2
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "y"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 13
+}
diff --git a/src/frontends/onnx/tests/onnx_import.in.cpp b/src/frontends/onnx/tests/onnx_import.in.cpp
index 7df87aac9fc732..42376f8df9e2ab 100644
--- a/src/frontends/onnx/tests/onnx_import.in.cpp
+++ b/src/frontends/onnx/tests/onnx_import.in.cpp
@@ -3565,6 +3565,120 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gatherND_float) {
     test_case.run();
 }
 
+OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gatherND_batch_dims_1_broadcast) {
+    // Test GatherND with batch_dims=1 and broadcasting
+    // data shape: [1, 3, 4, 4] -> broadcasts to [2, 3, 4, 4]
+    // indices shape: [2, 2, 2]
+    // output shape: [2, 2, 4]
+    const auto model = convert_model("gatherND_batch_dims_1_broadcast.onnx");
+    auto test_case = ov::test::TestCase(model, s_device);
+
+    // data: [1, 3, 4, 4] - random values with seed=42, validated against ONNXRuntime
+    test_case.add_input<float>({
+        0.496714f,  -0.138264f, 0.647689f,  1.523030f,  -0.234153f, -0.234137f, 1.579213f,  0.767435f,
+        -0.469474f, 0.542560f,  -0.463418f, -0.465730f, 0.241962f,  -1.913280f, -1.724918f, -0.562288f,
+        -1.012831f, 0.314247f,  -0.908024f, -1.412304f, 1.465649f,  -0.225776f, 0.067528f,  -1.424748f,
+        -0.544383f, 0.110923f,  -1.150994f, 0.375698f,  -0.600639f, -0.291694f, -0.601707f, 1.852278f,
+        -0.013497f, -1.057711f, 0.822545f,  -1.220844f, 0.208864f,  -1.959670f, -1.328186f, 0.196861f,
+        0.738467f,  0.171368f,  -0.115648f, -0.301104f, -1.478522f, -0.719844f, -0.460639f, 1.057122f,
+    });
+
+    // indices: [2, 2, 2] - batch_dims=1 means first dim is batch
+    test_case.add_input<int64_t>({
+        0,
+        0,  // batch 0: index (0, 0) -> data[0, 0, 0, :]
+        1,
+        1,  // batch 0: index (1, 1) -> data[0, 1, 1, :]
+        0,
+        1,  // batch 1: index (0, 1) -> data[0, 0, 1, :] (broadcasted)
+        2,
+        3,  // batch 1: index (2, 3) -> data[0, 2, 3, :]
+    });
+
+    // Expected output: [2, 2, 4] - reference from ONNXRuntime
+    test_case.add_expected_output<float>(Shape{2, 2, 4},
+                                         {
+                                             0.496714f,
+                                             -0.138264f,
+                                             0.647689f,
+                                             1.523030f,  // batch 0: data[0, 0, 0, :]
+                                             1.465649f,
+                                             -0.225776f,
+                                             0.067528f,
+                                             -1.424748f,  // batch 0: data[0, 1, 1, :]
+                                             -0.234153f,
+                                             -0.234137f,
+                                             1.579213f,
+                                             0.767435f,  // batch 1: data[0, 0, 1, :]
+                                             -1.478522f,
+                                             -0.719844f,
+                                             -0.460639f,
+                                             1.057122f,  // batch 1: data[0, 2, 3, :]
+                                         });
+
+    test_case.run();
+}
+
+OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gatherND_batch_dims_1_no_broadcast) {
+    // Test GatherND with batch_dims=1 and NO broadcasting needed
+    // data shape: [2, 3, 4, 4] - batch dimension is 2
+    // indices shape: [2, 2, 2] - batch dimension is also 2
+    // No broadcast needed - dimensions are equal
+    // Reference output generated with ONNXRuntime
+    const auto model = convert_model("gatherND_batch_dims_1_no_broadcast.onnx");
+    auto test_case = ov::test::TestCase(model, s_device);
+
+    // data: [2, 3, 4, 4] - random values with seed=123
+    test_case.add_input<float>({
+        -1.085631f, 0.997345f,  0.282979f,  -1.506295f, -0.578600f, 1.651437f,  -2.426679f, -0.428913f, 1.265936f,
+        -0.866740f, -0.678886f, -0.094709f, 1.491390f,  -0.638902f, -0.443982f, -0.434351f, 2.205930f,  2.186786f,
+        1.004054f,  0.386186f,  0.737369f,  1.490732f,  -0.935834f, 1.175829f,  -1.253881f, -0.637752f, 0.907105f,
+        -1.428681f, -0.140069f, -0.861755f, -0.255619f, -2.798589f, -1.771533f, -0.699877f, 0.927462f,  -0.173636f,
+        0.002846f,  0.688223f,  -0.879536f, 0.283627f,  -0.805367f, -1.727669f, -0.390900f, 0.573806f,  0.338589f,
+        -0.011830f, 2.392365f,  0.412912f,  0.978736f,  2.238143f,  -1.294085f, -1.038788f, 1.743712f,  -0.798063f,
+        0.029683f,  1.069316f,  0.890706f,  1.754886f,  1.495644f,  1.069393f,  -0.772709f, 0.794863f,  0.314272f,
+        -1.326265f, 1.417299f,  0.807237f,  0.045490f,  -0.233092f, -1.198301f, 0.199524f,  0.468439f,  -0.831155f,
+        1.162204f,  -1.097203f, -2.123100f, 1.039727f,  -0.403366f, -0.126030f, -0.837517f, -1.605963f, 1.255237f,
+        -0.688869f, 1.660952f,  0.807308f,  -0.314758f, -1.085902f, -0.732462f, -1.212523f, 2.087113f,  0.164441f,
+        1.150205f,  -1.267352f, 0.181035f,  1.177862f,  -0.335011f, 1.031114f,
+    });
+
+    // indices: [2, 2, 2] - batch_dims=1, equal batch dimensions (no broadcast)
+    test_case.add_input<int64_t>({
+        0,
+        0,  // batch 0: index (0, 0) -> data[0, 0, 0, :]
+        1,
+        2,  // batch 0: index (1, 2) -> data[0, 1, 2, :]
+        2,
+        3,  // batch 1: index (2, 3) -> data[1, 2, 3, :]
+        0,
+        1,  // batch 1: index (0, 1) -> data[1, 0, 1, :]
+    });
+
+    // Expected output: [2, 2, 4] - reference from ONNXRuntime
+    test_case.add_expected_output<float>(Shape{2, 2, 4},
+                                         {
+                                             -1.085631f,
+                                             0.997345f,
+                                             0.282979f,
+                                             -1.506295f,  // batch 0: data[0, 0, 0, :]
+                                             -1.253881f,
+                                             -0.637752f,
+                                             0.907105f,
+                                             -1.428681f,  // batch 0: data[0, 1, 2, :]
+                                             0.181035f,
+                                             1.177862f,
+                                             -0.335011f,
+                                             1.031114f,  // batch 1: data[1, 2, 3, :]
+                                             1.743712f,
+                                             -0.798063f,
+                                             0.029683f,
+                                             1.069316f,  // batch 1: data[1, 0, 1, :]
+                                         });
+
+    test_case.run();
+}
+
 OPENVINO_TEST(${BACKEND_NAME}, onnx_model_pad_constant) {
     const auto model = convert_model("pad_constant.onnx");
     auto test_case = ov::test::TestCase(model, s_device);