diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index bd3397245f5a26..01a1ac3e7b47fa 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -235,37 +235,37 @@ bool Transformations::fuse_type_to_convert(const std::shared_ptr& node return false; const auto& to = it->second; - // For Convert node, converting precision from floating point to boolean will lead to mathematical - // error, because here the output precision boolean is replaced by u8: - // - floating point value 0.01 is converted to be 1 for boolean, but 0 for u8 - need to insert Ceil. - // - floating point value 256 is converted to be 1 for boolean, but 0 for u8 - need to insert Min(x, UINT8_MAX) - // - floating point value -256 is converted to be 1 for boolean, but 0 for u8 - need to insert Abs before Min. - // Thus an Abs, Ceil and Min nodes should be added before the Convert node for this scenario. - if (convert->input(0).get_element_type().is_real() && - convert->get_convert_element_type() == ov::element::boolean && to.is_integral_number()) { + if (convert->get_convert_element_type() == ov::element::boolean && to.is_integral_number()) { + // For Convert node, converting precision from numerical data types to boolean will lead to mathematical + // error, because here the output precision boolean is replaced by u8: + // - floating point value 0.01 is converted to be 1 for boolean, but 0 for u8 - need to insert Ceil. + // - either float or int values should be clipped with the interval [0; 1] to mimic bool cast behavior, i.e. + // 0 - is false, 1 - is true + // - to perform clamping correctly an Abs op should be inserted before Clamp + // Thus an Abs, Ceil and Clamp nodes should be added before the Convert node for this scenario. ov::pass::NodeRegistry reg; const auto& in_prec = convert->get_input_element_type(0); - auto data = convert->input_value(0).get_node_shared_ptr(); + auto parent_node = convert->input_value(0).get_node_shared_ptr(); auto item = precisions.find(in_prec); if (item != precisions.end()) { - // Add convert node for unsupported precision, such as FP64 - data = reg.make(data, item->second); + // Add convert node for unsupported precision, such as FP64 or INT64 + parent_node = reg.make(parent_node, item->second); } - const auto abs = reg.make(data); - const auto to_max_value = reg.make(ov::util::make_tensor_of_max_value(to)); - const auto to_max_convert = reg.make(to_max_value, abs->get_output_element_type(0)); - const auto min = reg.make(abs, to_max_convert); - const auto ceil = reg.make(min); - const auto new_convert = reg.make(ceil, to); + if (in_prec.is_signed()) { + parent_node = reg.make(parent_node); + } + if (in_prec.is_real()) { + parent_node = reg.make(parent_node); + } + parent_node = reg.make(parent_node, 0, 1); + const auto new_convert = reg.make(parent_node, to); new_convert->set_friendly_name(convert->get_friendly_name()); ov::copy_runtime_info(convert, reg.get()); ov::replace_node(convert, new_convert); return true; - } else { - convert->set_convert_element_type(to); - return true; } - return false; + convert->set_convert_element_type(to); + return true; } void Transformations::UpToLpt() { diff --git a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/common/convert_bool_math.cpp b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/common/convert_bool_math.cpp new file mode 100644 index 00000000000000..b3f08e11624f1b --- /dev/null +++ b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/common/convert_bool_math.cpp @@ -0,0 +1,64 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "shared_test_classes/base/ov_subgraph.hpp" +#include "common_test_utils/ov_tensor_utils.hpp" + +namespace ov { +namespace test { +// ┌────────┐ +// │ Param │ +// └───┬────┘ +// │ f32 +// │ +// ┌───┴────┐ +// │Convert │ +// └───┬────┘ +// │ bool +// │ +// ┌───┴────┐ +// │Reshape │ +// └───┬────┘ +// │ bool +// │ +// ┌───┴────┐ ┌────────┐ +// │Convert │ │ Param │ +// └───┬────┘ └───┬────┘ +// │ f32 │ f32 +// │ │ +// │ ┌────────┐ │ +// └─────┤ Add ├───┘ +// └───┬────┘ +// │ f32 +// │ +// ┌───┴────┐ +// │Reshape │ +// └────────┘ + +class ConvertBoolMathTest : public SubgraphBaseStaticTest { +public: + void SetUp() override { + targetDevice = ov::test::utils::DEVICE_CPU; + + ov::ParameterVector inputParams{std::make_shared(ov::element::f32, ov::Shape{24, 7}), + std::make_shared(ov::element::f32, ov::Shape{3, 8, 7})}; + + auto inputConvert = std::make_shared(inputParams.front(), ov::element::boolean); + + auto reshapeConst = ov::opset10::Constant::create(ov::element::i32, ov::Shape{3}, {3, 8, 7}); + auto reshape = std::make_shared(inputConvert, reshapeConst, false); + + auto secondConvert = std::make_shared(reshape, ov::element::f32); + auto add = std::make_shared(secondConvert, inputParams.back()); + + ov::ResultVector results{std::make_shared(add)}; + function = std::make_shared(results, inputParams, "ConvertBoolMath"); + } +}; + +TEST_F(ConvertBoolMathTest, smoke_CompareWithRefs) { + run(); +} + +} // namespace test +} // namespace ov \ No newline at end of file