Patch summary (review notes; everything above the "diff --git" line is commentary, not part of the change).

In the `!is_prev_input` branch, the edges leaving `prev` used to be added directly,
one per entry of `output_edges`. With this change, when the type of `prev`'s output 0
differs from the type of `dq_node_ref`'s output 0, a "Cast" node is added first:
`prev` (at `prev_output_index`) feeds input 0 of the Cast, and output 0 of the Cast
is connected to each entry of `output_edges`. The Cast's "to" attribute is FLOAT when
the `dq_node_ref` output type string contains "tensor(float)", otherwise FLOAT16;
ORT_ENFORCE rejects a type string that contains neither "tensor(float)" nor
"tensor(float16)". When the two types are equal, the original direct wiring is kept.

NOTE(review): this patch text was recovered from a copy collapsed onto a single line.
The template arguments `<NodeArg*>` (three places) and `<int64_t>` (one place) were
missing there and have been restored, and the indentation / blank context line were
reconstructed to match the hunk header counts (11 old, 34 new) - verify both against
the real commit before applying.
NOTE(review): the type comparison and the Cast input both read `OutputDefs()[0]` of
`prev`, while the edge into the Cast starts at `prev_output_index` - confirm these
always denote the same output.
NOTE(review): ORT_ENFORCE runs after `AddNode`, so if the check fails the Cast node has
already been added to `original_graph` - consider validating the type first.

diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp
index 4b862bdd7554b..b6972ef5b4136 100644
--- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp
+++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp
@@ -463,11 +463,34 @@ struct CustomGraph {
       }
 
       if (!is_prev_input) {
-        for (const auto& edge : output_edges) {
+        if (prev.node_ptr->OutputDefs()[0]->Type() != dq_node_ref.OutputDefs()[0]->Type()) {
+          NodeArg& output = original_graph.GetOrCreateNodeArg(prev.node_name + "_cast_0", dq_node_ref.OutputDefs()[0]->TypeAsProto());
+          std::string cast_node_name = prev.node_ptr->OutputDefs()[0]->Name() + "_cast";
+          InlinedVector<NodeArg*> input_args = {const_cast<NodeArg*>(prev.node_ptr->OutputDefs()[0])};
+          InlinedVector<NodeArg*> output_args = {&output};
+          Node& cast_node = original_graph.AddNode(cast_node_name, "Cast", "", input_args, output_args, nullptr, "");
+          auto type_str = dq_node_ref.OutputDefs()[0]->Type();
+          auto type_cast = type_str->find("tensor(float)") != std::string::npos ? onnx::TensorProto_DataType_FLOAT : onnx::TensorProto_DataType_FLOAT16;
+          ORT_ENFORCE((type_cast == onnx::TensorProto_DataType_FLOAT) || (type_str->find("tensor(float16)") != std::string::npos),
+                      "QDQ type misalignment, expected float32 or float16 output");
+          cast_node.AddAttribute("to", static_cast<int64_t>(type_cast));
           original_graph.AddEdge(prev.node_ptr->Index(),
-                                 std::get<0>(edge),
+                                 cast_node.Index(),
                                  prev_output_index,
-                                 std::get<2>(edge));
+                                 0);
+          for (const auto& edge : output_edges) {
+            original_graph.AddEdge(cast_node.Index(),
+                                   std::get<0>(edge),
+                                   0,
+                                   std::get<2>(edge));
+          }
+        } else {
+          for (const auto& edge : output_edges) {
+            original_graph.AddEdge(prev.node_ptr->Index(),
+                                   std::get<0>(edge),
+                                   prev_output_index,
+                                   std::get<2>(edge));
+          }
         }
       }
     }