Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 62 additions & 2 deletions src/frontends/onnx/frontend/src/op/gather_nd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,79 @@
#include "openvino/op/gather_nd.hpp"

#include "core/operator_set.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/maximum.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"

using namespace ov::op;

namespace ov {
namespace frontend {
namespace onnx {
namespace ai_onnx {
namespace opset_1 {

namespace {
// Extracts the single dimension at position `index` from a 1-D shape tensor,
// keeping it as a rank-1 tensor with one element (suitable for Concat).
ov::Output<ov::Node> get_dimension(const ov::Output<ov::Node>& shape, int64_t index) {
    const auto slice_axis = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
    const auto begin = v0::Constant::create(ov::element::i64, ov::Shape{1}, {index});
    const auto end = v0::Constant::create(ov::element::i64, ov::Shape{1}, {index + 1});
    const auto stride = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1});
    return std::make_shared<v8::Slice>(shape, begin, end, stride, slice_axis);
}
}  // namespace

// Converts an ONNX GatherND node into an OpenVINO v8::GatherND.
//
// When batch_dims > 0, ONNXRuntime allows a batch dimension of size 1 on one
// input to broadcast against the corresponding batch dimension of the other
// input (non-standard w.r.t. the ONNX spec). To reproduce that behavior, both
// inputs are explicitly broadcast to a common batch shape, computed dimension
// by dimension as max(data_dim, indices_dim), before GatherND is created.
//
// \param node  ONNX node providing inputs {data, indices} and the optional
//              "batch_dims" attribute (default 0).
// \return      Single-output vector holding the GatherND result.
ov::OutputVector gather_nd(const ov::frontend::onnx::Node& node) {
    const ov::OutputVector ng_inputs{node.get_ov_inputs()};
    // Non-const: both may be replaced by broadcast nodes below.
    auto data = ng_inputs.at(0);
    auto indices = ng_inputs.at(1);
    const auto batch_dims = node.get_attribute_value<int64_t>("batch_dims", 0);

    if (batch_dims > 0) {
        auto data_shape = std::make_shared<v3::ShapeOf>(data, ov::element::i64);
        auto indices_shape = std::make_shared<v3::ShapeOf>(indices, ov::element::i64);

        // Target batch shape: element-wise max of the two batch prefixes, so a
        // size-1 dimension on either side broadcasts to the other's size.
        ov::OutputVector batch_dims_vec;
        for (int64_t i = 0; i < batch_dims; ++i) {
            auto data_dim = get_dimension(data_shape, i);
            auto indices_dim = get_dimension(indices_shape, i);
            auto max_dim = std::make_shared<v1::Maximum>(data_dim, indices_dim);
            batch_dims_vec.push_back(max_dim);
        }

        // The non-batch tail of each shape is kept as-is: slice [batch_dims, rank).
        auto zero_const = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
        auto batch_dims_const = v0::Constant::create(ov::element::i64, ov::Shape{1}, {batch_dims});
        auto one_step = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1});

        // ShapeOf of a 1-D shape tensor yields the rank as a 1-element tensor.
        auto data_rank_node = std::make_shared<v3::ShapeOf>(data_shape, ov::element::i64);
        auto indices_rank_node = std::make_shared<v3::ShapeOf>(indices_shape, ov::element::i64);

        auto data_remaining =
            std::make_shared<v8::Slice>(data_shape, batch_dims_const, data_rank_node, one_step, zero_const);
        auto indices_remaining =
            std::make_shared<v8::Slice>(indices_shape, batch_dims_const, indices_rank_node, one_step, zero_const);

        // Full target shapes = broadcast batch prefix + original tail.
        auto target_batch_shape = std::make_shared<v0::Concat>(batch_dims_vec, 0);
        auto target_data_shape = std::make_shared<v0::Concat>(ov::OutputVector{target_batch_shape, data_remaining}, 0);
        auto target_indices_shape =
            std::make_shared<v0::Concat>(ov::OutputVector{target_batch_shape, indices_remaining}, 0);

        // Broadcast both inputs so their batch dimensions match exactly.
        data = std::make_shared<v3::Broadcast>(data, target_data_shape);
        indices = std::make_shared<v3::Broadcast>(indices, target_indices_shape);
    }

    return {std::make_shared<v8::GatherND>(data, indices, batch_dims)};
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
ir_version: 3
producer_name: "OpenVINO ONNX Frontend"
graph {
node {
input: "data"
input: "indices"
output: "y"
op_type: "GatherND"
attribute {
name: "batch_dims"
i: 1
type: INT
}
}
name: "test_gatherND_batch_dims_1_broadcast"
input {
name: "data"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 4
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "indices"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
version: 13
}
54 changes: 54 additions & 0 deletions src/frontends/onnx/tests/onnx_import.in.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3565,6 +3565,60 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gatherND_float) {
test_case.run();
}

OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gatherND_batch_dims_1_broadcast) {
// Test GatherND with batch_dims=1 and broadcasting
// data shape: [1, 3, 4, 4] -> broadcasts to [2, 3, 4, 4]
// indices shape: [2, 2, 2]
// output shape: [2, 2, 4]
// Exercises the ONNXRuntime-compatible broadcasting of a size-1 batch
// dimension in `data` against the size-2 batch dimension of `indices`.
const auto model = convert_model("gatherND_batch_dims_1_broadcast.onnx");
auto test_case = ov::test::TestCase(model, s_device);

// data: [1, 3, 4, 4] - random values with seed=42, validated against ONNXRuntime
test_case.add_input<float>({
0.496714f, -0.138264f, 0.647689f, 1.523030f, -0.234153f, -0.234137f, 1.579213f, 0.767435f,
-0.469474f, 0.542560f, -0.463418f, -0.465730f, 0.241962f, -1.913280f, -1.724918f, -0.562288f,
-1.012831f, 0.314247f, -0.908024f, -1.412304f, 1.465649f, -0.225776f, 0.067528f, -1.424748f,
-0.544383f, 0.110923f, -1.150994f, 0.375698f, -0.600639f, -0.291694f, -0.601707f, 1.852278f,
-0.013497f, -1.057711f, 0.822545f, -1.220844f, 0.208864f, -1.959670f, -1.328186f, 0.196861f,
0.738467f, 0.171368f, -0.115648f, -0.301104f, -1.478522f, -0.719844f, -0.460639f, 1.057122f,
});

// indices: [2, 2, 2] - batch_dims=1 means first dim is batch
// Each innermost pair selects (dim1, dim2) of the (broadcast) data batch.
test_case.add_input<int64_t>({
0,
0, // batch 0: index (0, 0) -> data[0, 0, 0, :]
1,
1, // batch 0: index (1, 1) -> data[0, 1, 1, :]
0,
1, // batch 1: index (0, 1) -> data[0, 0, 1, :] (broadcasted)
2,
3, // batch 1: index (2, 3) -> data[0, 2, 3, :]
});

// Expected output: [2, 2, 4] - reference from ONNXRuntime
test_case.add_expected_output<float>(Shape{2, 2, 4},
{
0.496714f,
-0.138264f,
0.647689f,
1.523030f, // batch 0: data[0, 0, 0, :]
1.465649f,
-0.225776f,
0.067528f,
-1.424748f, // batch 0: data[0, 1, 1, :]
-0.234153f,
-0.234137f,
1.579213f,
0.767435f, // batch 1: data[0, 0, 1, :]
-1.478522f,
-0.719844f,
-0.460639f,
1.057122f, // batch 1: data[0, 2, 3, :]
});

test_case.run();
}

OPENVINO_TEST(${BACKEND_NAME}, onnx_model_pad_constant) {
const auto model = convert_model("pad_constant.onnx");
auto test_case = ov::test::TestCase(model, s_device);
Expand Down
Loading