
[ONNX] Extend ONNX frontend with SoftmaxCrossEntropyLoss function #20547 #22650

Closed
wants to merge 2 commits

Conversation

tanishka321

Details:

Extend ONNX frontend with SoftmaxCrossEntropyLoss function

Tickets:

Closes #20547

@tanishka321 tanishka321 requested a review from a team as a code owner February 5, 2024 03:19
@github-actions github-actions bot added the category: ONNX FE OpenVINO ONNX FrontEnd label Feb 5, 2024
@mlukasze mlukasze added the ExternalPR External contributor label Feb 8, 2024
Contributor

@mitruska mitruska left a comment

Hello @tanishka321, thank you for your contribution.

A decomposition subgraph can contain only existing OpenVINO operators;
it's not possible to create default_opset::SoftmaxCrossEntropyLoss because no such class exists (so this code can't compile). The SoftmaxCrossEntropyLoss operation is not available as an OpenVINO primitive operator (see the list of ov ops), which is why the ONNX Function expand mechanism is used.

@gkrivor As ONNX SoftmaxCrossEntropyLoss is a training-related operator, I would suggest verifying its applicability for custom decomposition in the OpenVINO ONNX FE.

const auto coerced_logits = ov::op::util::flatten(logits, static_cast<int>(axis));
const auto coerced_labels = ov::op::util::flatten(labels, static_cast<int>(axis));

const auto result = std::make_shared<default_opset::SoftmaxCrossEntropyLoss>(coerced_logits, coerced_labels, 1);
Contributor

@mitruska mitruska Feb 15, 2024

It's not possible to create default_opset::SoftmaxCrossEntropyLoss because no such class exists. The SoftmaxCrossEntropyLoss operation is not available as an OpenVINO primitive operator (see the list of ov ops), which is why the ONNX Function expand mechanism is used.
The decomposition should follow, or optimize, the current subgraph produced by the ONNX Frontend.
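
For illustration, a decomposition restricted to existing OpenVINO operators could look roughly like the sketch below. This is a minimal sketch, assuming 2-D logits [N, C], the default "mean" reduction, and no weights or ignore_index; LogSoftmax, ShapeOf, Gather, OneHot, Multiply, ReduceSum, Negative, and ReduceMean are all existing OpenVINO operators, while the variable names are illustrative:

// Minimal decomposition sketch using only existing OpenVINO operators.
// Assumes 2-D logits [N, C], "mean" reduction, no weights/ignore_index.
const auto log_probs = std::make_shared<default_opset::LogSoftmax>(logits, axis);

// depth = number of classes along `axis`, taken from the runtime shape
const auto logits_shape = std::make_shared<default_opset::ShapeOf>(logits);
const auto axis_idx = default_opset::Constant::create(element::i64, Shape{}, {axis});
const auto zero = default_opset::Constant::create(element::i64, Shape{}, {0});
const auto depth = std::make_shared<default_opset::Gather>(logits_shape, axis_idx, zero);

// one-hot encode the labels in the element type of the logits
const auto on_value = default_opset::Constant::create(logits.get_element_type(), Shape{}, {1});
const auto off_value = default_opset::Constant::create(logits.get_element_type(), Shape{}, {0});
const auto one_hot = std::make_shared<default_opset::OneHot>(labels, depth, on_value, off_value, axis);

// per-sample negative log-likelihood of the target class
const auto picked = std::make_shared<default_opset::Multiply>(log_probs, one_hot);
const auto class_axis = default_opset::Constant::create(element::i64, Shape{1}, {axis});
const auto per_sample = std::make_shared<default_opset::ReduceSum>(picked, class_axis, false);
const auto nll = std::make_shared<default_opset::Negative>(per_sample);

// "mean" reduction over the batch dimension
const auto batch_axis = default_opset::Constant::create(element::i64, Shape{1}, {0});
const auto loss = std::make_shared<default_opset::ReduceMean>(nll, batch_axis, false);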

Comment on lines +33 to +45
namespace set_1 {
OutputVector softmax_cross_entropy_loss(const Node& node) {
const auto logits = node.get_ng_inputs().at(0);
const auto labels = node.get_ng_inputs().at(1);

const auto axis = node.get_attribute_value<int64_t>("axis", 1);

std::shared_ptr<ngraph::Node> result;
result = onnx_softmax_cross_entropy_loss(logits, labels, axis);

return {result};
}
} // namespace set_1
Contributor

This code is repeated three times, once for each namespace (set_1, set_11, set_13). Creating the common helper function onnx_softmax_cross_entropy_loss and reusing it was a good direction, but more of the code can be unified. In fact there is no functional difference between the versions; only bfloat16 support was added to the ONNX spec, so the import code can be shared between versions.
https://onnx.ai/onnx/operators/text_diff_SoftmaxCrossEntropyLoss_12_13.html
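
One possible way to unify it, sketched below under the assumption that the translator registration allows reusing the same function across opset namespaces (a sketch, not the definitive layout):

// Sketch: implement the translator once and alias it in the newer opset
// namespaces, since opsets 11/13 only extend type support (bfloat16).
namespace set_1 {
OutputVector softmax_cross_entropy_loss(const Node& node) {
    const auto inputs = node.get_ng_inputs();
    const auto axis = node.get_attribute_value<int64_t>("axis", 1);
    return {onnx_softmax_cross_entropy_loss(inputs.at(0), inputs.at(1), axis)};
}
}  // namespace set_1

namespace set_11 {
using set_1::softmax_cross_entropy_loss;
}  // namespace set_11

namespace set_13 {
using set_1::softmax_cross_entropy_loss;
}  // namespace set_13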

Comment on lines +35 to +36
const auto logits = node.get_ng_inputs().at(0);
const auto labels = node.get_ng_inputs().at(1);
Contributor

It's better to call node.get_ng_inputs() once, assign the result to a variable, and then take each required input by index.
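
For example, a minimal sketch of that change:

const auto inputs = node.get_ng_inputs();  // fetch all inputs once
const auto logits = inputs.at(0);
const auto labels = inputs.at(1);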

auto test_case = ov::test::TestCase(function, s_device);
test_case.add_input<int64_t>(Shape{2}, {1, 2});
test_case.add_input<int64_t>(Shape{2}, {0, 2});
test_case.add_expected_output<bool>(Shape{2}, {false, true});
Contributor

Why is the type of the expected output values bool?
According to the ONNX SoftmaxCrossEntropyLoss specification, the output should be a floating-point type.
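
A minimal sketch of a float-typed check (the shapes and values here are illustrative assumptions, not reference data from the PR; with uniform logits {0, 0} and label 0, the mean loss is ln(2) ≈ 0.693147):

auto test_case = ov::test::TestCase(function, s_device);
test_case.add_input<float>(Shape{1, 2}, {0.0f, 0.0f});       // logits
test_case.add_input<int64_t>(Shape{1}, {0});                 // labels
test_case.add_expected_output<float>(Shape{}, {0.693147f});  // mean loss = ln(2)
test_case.run();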

Comment on lines +7 to +18
#include "openvino/core/deprecated.hpp"
OPENVINO_SUPPRESS_DEPRECATED_START

#include "ngraph/node.hpp"
#include "onnx_import/core/node.hpp"

namespace ngraph {
namespace onnx_import {
namespace op {
namespace set_1 {

OutputVector softmax_cross_entropy_loss(const Node& node);
Contributor

Please sync with the latest master; the ONNX FE has been slightly refactored. The changes (namespaces, headers) need to be aligned with the rest of the frontends/onnx files.

Any of the ONNX frontend op files can be taken as an example to follow, for instance:

/src/frontends/onnx/frontend/src/op/leaky_relu.hpp
/src/frontends/onnx/frontend/src/op/leaky_relu.cpp
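
For orientation only, a post-refactor declaration might look roughly as below; the namespaces and the include path are assumptions modeled on leaky_relu.hpp and should be verified against the current master:

#pragma once

#include "core/node.hpp"

namespace ov {
namespace frontend {
namespace onnx {
namespace op {
namespace set_1 {
// Declaration only; the definition would live in the matching .cpp file.
ov::OutputVector softmax_cross_entropy_loss(const ov::frontend::onnx::Node& node);
}  // namespace set_1
}  // namespace op
}  // namespace onnx
}  // namespace frontend
}  // namespace ov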

Contributor

github-actions bot commented Mar 3, 2024

This PR will be closed in a week because of 2 weeks of no activity.

@github-actions github-actions bot added the Stale label Mar 3, 2024
Contributor

This PR was closed because it has been stalled for 2 weeks with no activity.

@github-actions github-actions bot closed this Mar 10, 2024
Labels: category: ONNX FE (OpenVINO ONNX FrontEnd), ExternalPR (External contributor), Stale

Development: Successfully merging this pull request may close these issues:

[Good First Issue]: Extend ONNX Frontend with Function SoftmaxCrossEntropyLoss

4 participants