diff --git a/src/frontends/onnx/frontend/src/op/softmax_crossentropy_loss.cpp b/src/frontends/onnx/frontend/src/op/softmax_crossentropy_loss.cpp
index 4e2573debb187d..dfa97577361fb0 100644
--- a/src/frontends/onnx/frontend/src/op/softmax_crossentropy_loss.cpp
+++ b/src/frontends/onnx/frontend/src/op/softmax_crossentropy_loss.cpp
@@ -3,75 +3,110 @@
 //
+#include <numeric>  // std::iota, used below to build the reduction axes
+
 #include "core/operator_set.hpp"
-#include "exceptions.hpp"
+#include "openvino/core/except.hpp"  // OPENVINO_THROW
 #include "openvino/op/constant.hpp"
 #include "openvino/op/convert.hpp"
+#include "openvino/op/divide.hpp"
 #include "openvino/op/gather.hpp"
 #include "openvino/op/log.hpp"
 #include "openvino/op/multiply.hpp"
 #include "openvino/op/negative.hpp"
+#include "openvino/op/not_equal.hpp"
 #include "openvino/op/reduce_mean.hpp"
 #include "openvino/op/reduce_sum.hpp"
-#include "openvino/op/select.hpp"
 #include "openvino/op/softmax.hpp"
-#include "utils/common.hpp"
 #include "softmax_cross_entropy_loss.hpp"
 
 namespace ov {
 namespace frontend {
 namespace onnx {
 namespace {
-    // softmax cross entropy implementation (Shared helper fn)
-    OutputVector impl_softmax_cross_entropy(const Node& node, int64_t axis_default) {
-        const auto inputs = node.get_ov_inputs();
+// Shared helper: builds the SoftmaxCrossEntropyLoss subgraph for both opset variants.
+OutputVector impl_softmax_cross_entropy(const Node& node, int64_t axis_default) {
+    const auto inputs = node.get_ov_inputs();
 
-        const auto scores = inputs[0];
-        const auto labels = inputs[1];
+    const auto scores = inputs[0];
+    const auto labels = inputs[1];
 
-        const auto axis = node.get_attribute_value<int64_t>("axis", axis_default);
-        const auto reduction = node.get_attribute_value<std::string>("reduction", "mean");
+    bool has_weights = inputs.size() > 2;
+    std::shared_ptr<ov::Node> weights_gather = nullptr;
 
-        // Computing softmax
-        const auto softmax = std::make_shared<ov::op::v8::Softmax>(scores, axis);
-        const auto log_softmax = std::make_shared<ov::op::v0::Log>(softmax);
+    bool has_ignore_index = node.has_attribute("ignore_index");
+    int64_t ignore_index_val = 0;
+    std::shared_ptr<ov::Node> mask = nullptr;
 
-        const auto axis_const = ov::op::v0::Constant::create(element::i64, {}, {axis});
-        const auto gathered = std::make_shared<ov::op::v8::Gather>(log_softmax, labels, axis_const);
+    if (has_ignore_index) {
+        // mask = (labels != ignore_index), converted to the score type so it can scale the loss
+        ignore_index_val = node.get_attribute_value<int64_t>("ignore_index");
+        auto ignore_index_node = ov::op::v0::Constant::create(labels.get_element_type(), {}, {ignore_index_val});
+        auto neq = std::make_shared<ov::op::v1::NotEqual>(labels, ignore_index_node);
+        mask = std::make_shared<ov::op::v0::Convert>(neq, scores.get_element_type());
+    }
+    if (has_weights) {
+        // Per-sample weight: gather the class weight selected by each label, zeroed where ignored
+        const auto weights = inputs[2];
+        const auto axis_for_weights = ov::op::v0::Constant::create(element::i64, {}, {0});
+        weights_gather = std::make_shared<ov::op::v8::Gather>(weights, labels, axis_for_weights);
 
-        // Computing loss
-        std::shared_ptr<ov::Node> loss = std::make_shared<ov::op::v0::Negative>(gathered);
+        if (has_ignore_index) {
+            weights_gather = std::make_shared<ov::op::v1::Multiply>(weights_gather, mask);
+        }
+    } else if (has_ignore_index) {
+        weights_gather = mask;
+    }
 
-        // applying reduction as mentioned in https://github.com/onnx/onnx/blob/main/docs/Changelog.md#softmaxcrossentropyloss-12
+    const auto axis = node.get_attribute_value<int64_t>("axis", axis_default);
+    const auto reduction = node.get_attribute_value<std::string>("reduction", "mean");
 
-        if (reduction != "none") {
-            const auto reduce_axis = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
-
-            loss = (reduction == "mean")
-                       ? static_cast<std::shared_ptr<ov::Node>>(
-                             std::make_shared<ov::op::v1::ReduceMean>(loss->output(0), reduce_axis, true))
-                       : static_cast<std::shared_ptr<ov::Node>>(
-                             std::make_shared<ov::op::v1::ReduceSum>(loss->output(0), reduce_axis, true));
-        }
+    const auto softmax = std::make_shared<ov::op::v8::Softmax>(scores, axis);
+    const auto log_softmax = std::make_shared<ov::op::v0::Log>(softmax);
 
-        return {loss};
-    }
-}
-namespace ai_onnx {
-    namespace opset_12 {
-        OutputVector ov::frontend::onnx::ai_onnx::opset_12::softmax_cross_entropy_loss(const Node& node) {
-            return impl_softmax_cross_entropy(node, 1);
-        }
-        ONNX_OP("SoftmaxCrossEntropyLoss", OPSET_SINCE(12), ai_onnx::opset_12::softmax_cross_entropy_loss);
-    }
-    namespace opset_13 {
-        OutputVector ov::frontend::onnx::ai_onnx::opset_13::softmax_cross_entropy_loss(const Node& node) {
-            return impl_softmax_cross_entropy(node, 1);
-        }
-        ONNX_OP("SoftmaxCrossEntropyLoss", OPSET_SINCE(13), ai_onnx::opset_13::softmax_cross_entropy_loss);
-    }
-}
-}
-}
-}
+    // -log(softmax(scores)) gathered at the label index along "axis" = per-sample NLL
+    const auto axis_const = ov::op::v0::Constant::create(element::i64, {}, {axis});
+    const auto gathered = std::make_shared<ov::op::v8::Gather>(log_softmax, labels, axis_const);
+
+    std::shared_ptr<ov::Node> loss = std::make_shared<ov::op::v0::Negative>(gathered);
+
+    if (weights_gather) {
+        loss = std::make_shared<ov::op::v1::Multiply>(loss, weights_gather);
+    }
+    if (reduction != "none") {
+        auto loss_shape = loss->get_output_partial_shape(0);
+        if (loss_shape.rank().is_static()) {
+            // Reduce over every axis of the per-sample loss
+            size_t loss_rank = loss_shape.rank().get_length();
+            std::vector<int64_t> reduce_axes(loss_rank);
+            std::iota(reduce_axes.begin(), reduce_axes.end(), 0);
+            auto reduce_axis = ov::op::v0::Constant::create(ov::element::i64, {reduce_axes.size()}, reduce_axes);
+
+            if (reduction == "mean") {
+                if (weights_gather) {
+                    // Weighted mean per the ONNX spec: sum(loss) / sum(weights), not a plain ReduceMean
+                    auto loss_sum = std::make_shared<ov::op::v1::ReduceSum>(loss, reduce_axis, false);
+                    auto weight_sum = std::make_shared<ov::op::v1::ReduceSum>(weights_gather, reduce_axis, false);
+                    loss = std::make_shared<ov::op::v1::Divide>(loss_sum, weight_sum);
+                } else {
+                    loss = std::make_shared<ov::op::v1::ReduceMean>(loss, reduce_axis, false);
+                }
+            } else if (reduction == "sum") {
+                loss = std::make_shared<ov::op::v1::ReduceSum>(loss, reduce_axis, false);
+            }
+        } else {
+            OPENVINO_THROW("Dynamic rank is not supported for SoftmaxCrossEntropyLoss reduction.");
+        }
+    }
+
+    return {loss};
+}
+}  // namespace
+namespace ai_onnx {
+namespace opset_12 {
+OutputVector softmax_cross_entropy_loss(const Node& node) {
+    return impl_softmax_cross_entropy(node, 1);
+}
+ONNX_OP("SoftmaxCrossEntropyLoss", OPSET_IN(12), ai_onnx::opset_12::softmax_cross_entropy_loss);
+}  // namespace opset_12
+namespace opset_13 {
+OutputVector softmax_cross_entropy_loss(const Node& node) {
+    return impl_softmax_cross_entropy(node, 1);
+}
+ONNX_OP("SoftmaxCrossEntropyLoss", OPSET_IN(13), ai_onnx::opset_13::softmax_cross_entropy_loss);
+}  // namespace opset_13
+}  // namespace ai_onnx
+}  // namespace onnx
+}  // namespace frontend
+}  // namespace ov
\ No newline at end of file
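
Reviewer note, not part of the patch: for reduction="mean" with weights and/or ignore_index, the subgraph above intentionally computes a weighted mean, ReduceSum(loss) / ReduceSum(weights), rather than a plain ReduceMean, which matches the ONNX SoftmaxCrossEntropyLoss specification. Below is a minimal, self-contained sketch of those semantics for a sanity check (plain C++, no OpenVINO API; the tensor values and the choice of -1 as ignore_index are hypothetical):

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
    // Hypothetical batch: N = 3 samples, C = 2 classes.
    const int N = 3, C = 2;
    const double scores[3][2] = {{1.0, 2.0}, {0.5, 0.5}, {2.0, 0.0}};
    const int labels[3] = {1, 0, -1};            // -1 plays the role of ignore_index
    const double class_weights[2] = {1.0, 0.5};  // optional "weights" input
    const int ignore_index = -1;

    double loss_sum = 0.0, weight_sum = 0.0;
    for (int n = 0; n < N; ++n) {
        // Equivalent of the NotEqual/Convert mask: an ignored sample contributes
        // zero to both the numerator and the denominator.
        if (labels[n] == ignore_index)
            continue;
        // Numerically stable log-softmax over the class axis.
        double peak = scores[n][0];
        for (int c = 1; c < C; ++c)
            peak = std::max(peak, scores[n][c]);
        double denom = 0.0;
        for (int c = 0; c < C; ++c)
            denom += std::exp(scores[n][c] - peak);
        const double log_p = (scores[n][labels[n]] - peak) - std::log(denom);
        // Per-sample NLL scaled by the gathered class weight.
        const double w = class_weights[labels[n]];
        loss_sum += -log_p * w;
        weight_sum += w;
    }
    // reduction == "mean" with weights: sum(loss) / sum(weights).
    std::printf("mean loss = %f\n", loss_sum / weight_sum);
    return 0;
}

The division by the accumulated weight is why the patch cannot reuse ReduceMean once weights_gather is set: ignored and down-weighted samples must shrink the denominator as well as the numerator.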