Skip to content

Commit

Permalink
[CPU] Add Clamp before Convert when bool is replaced with u8 (openvin…
Browse files Browse the repository at this point in the history
…otoolkit#25253)

### Details:
CPU plugin doesn't natively support `boolean` data type, thus it's
replaced with `u8` during the convert precision transformation pass.
However, casting numerical types to `bool` implies clamping the
numerical values with the interval [0; 1] (either true or false), so to
mimic this behavior a Clamp operation should be inserted before the
modified convert.

### Tickets:
 - CVS-145166
  • Loading branch information
maxnick authored Jul 4, 2024
1 parent 82af474 commit c89db83
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -235,37 +235,37 @@ bool Transformations::fuse_type_to_convert(const std::shared_ptr<ov::Node>& node
return false;
const auto& to = it->second;

// For Convert node, converting precision from floating point to boolean will lead to mathematical
// error, because here the output precision boolean is replaced by u8:
// - floating point value 0.01 is converted to be 1 for boolean, but 0 for u8 - need to insert Ceil.
// - floating point value 256 is converted to be 1 for boolean, but 0 for u8 - need to insert Min(x, UINT8_MAX)
// - floating point value -256 is converted to be 1 for boolean, but 0 for u8 - need to insert Abs before Min.
// Thus an Abs, Ceil and Min nodes should be added before the Convert node for this scenario.
if (convert->input(0).get_element_type().is_real() &&
convert->get_convert_element_type() == ov::element::boolean && to.is_integral_number()) {
if (convert->get_convert_element_type() == ov::element::boolean && to.is_integral_number()) {
// For Convert node, converting precision from numerical data types to boolean will lead to mathematical
// error, because here the output precision boolean is replaced by u8:
// - floating point value 0.01 is converted to be 1 for boolean, but 0 for u8 - need to insert Ceil.
// - either float or int values should be clipped with the interval [0; 1] to mimic bool cast behavior, i.e.
// 0 - is false, 1 - is true
// - to perform clamping correctly an Abs op should be inserted before Clamp
// Thus an Abs, Ceil and Clamp nodes should be added before the Convert node for this scenario.
ov::pass::NodeRegistry reg;
const auto& in_prec = convert->get_input_element_type(0);
auto data = convert->input_value(0).get_node_shared_ptr();
auto parent_node = convert->input_value(0).get_node_shared_ptr();
auto item = precisions.find(in_prec);
if (item != precisions.end()) {
// Add convert node for unsupported precision, such as FP64
data = reg.make<ov::opset10::Convert>(data, item->second);
// Add convert node for unsupported precision, such as FP64 or INT64
parent_node = reg.make<ov::opset10::Convert>(parent_node, item->second);
}
const auto abs = reg.make<ov::opset10::Abs>(data);
const auto to_max_value = reg.make<ov::opset10::Constant>(ov::util::make_tensor_of_max_value(to));
const auto to_max_convert = reg.make<ov::opset10::Convert>(to_max_value, abs->get_output_element_type(0));
const auto min = reg.make<ov::opset10::Minimum>(abs, to_max_convert);
const auto ceil = reg.make<ov::opset10::Ceiling>(min);
const auto new_convert = reg.make<ov::opset10::Convert>(ceil, to);
if (in_prec.is_signed()) {
parent_node = reg.make<ov::opset10::Abs>(parent_node);
}
if (in_prec.is_real()) {
parent_node = reg.make<ov::opset10::Ceiling>(parent_node);
}
parent_node = reg.make<ov::opset10::Clamp>(parent_node, 0, 1);
const auto new_convert = reg.make<ov::opset10::Convert>(parent_node, to);
new_convert->set_friendly_name(convert->get_friendly_name());
ov::copy_runtime_info(convert, reg.get());
ov::replace_node(convert, new_convert);
return true;
} else {
convert->set_convert_element_type(to);
return true;
}
return false;
convert->set_convert_element_type(to);
return true;
}

void Transformations::UpToLpt() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "shared_test_classes/base/ov_subgraph.hpp"
#include "common_test_utils/ov_tensor_utils.hpp"

namespace ov {
namespace test {
// ┌────────┐
// │ Param │
// └───┬────┘
// │ f32
//
// ┌───┴────┐
// │Convert │
// └───┬────┘
// │ bool
//
// ┌───┴────┐
// │Reshape │
// └───┬────┘
// │ bool
//
// ┌───┴────┐ ┌────────┐
// │Convert │ │ Param │
// └───┬────┘ └───┬────┘
// │ f32 │ f32
// │ │
// │ ┌────────┐ │
// └─────┤ Add ├───┘
// └───┬────┘
// │ f32
//
// ┌───┴────┐
// │Reshape │
// └────────┘

class ConvertBoolMathTest : public SubgraphBaseStaticTest {
public:
void SetUp() override {
targetDevice = ov::test::utils::DEVICE_CPU;

ov::ParameterVector inputParams{std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{24, 7}),
std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{3, 8, 7})};

auto inputConvert = std::make_shared<ov::opset10::Convert>(inputParams.front(), ov::element::boolean);

auto reshapeConst = ov::opset10::Constant::create<int32_t>(ov::element::i32, ov::Shape{3}, {3, 8, 7});
auto reshape = std::make_shared<ov::opset10::Reshape>(inputConvert, reshapeConst, false);

auto secondConvert = std::make_shared<ov::opset10::Convert>(reshape, ov::element::f32);
auto add = std::make_shared<ov::opset10::Add>(secondConvert, inputParams.back());

ov::ResultVector results{std::make_shared<ov::opset10::Result>(add)};
function = std::make_shared<ov::Model>(results, inputParams, "ConvertBoolMath");
}
};

TEST_F(ConvertBoolMathTest, smoke_CompareWithRefs) {
run();
}

} // namespace test
} // namespace ov

0 comments on commit c89db83

Please sign in to comment.