Skip to content

Commit 70409ad

Browse files
committed
Fixes and tests
1 parent c1a0d48 commit 70409ad

File tree

3 files changed

+125
-16
lines changed

3 files changed

+125
-16
lines changed

src/common/transformations/src/transformations/common_optimizations/compress_float_constants.cpp

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "itt.hpp"
88
#include "openvino/core/graph_util.hpp"
99
#include "openvino/core/rt_info.hpp"
10+
#include "openvino/core/type.hpp"
1011
#include "openvino/op/constant.hpp"
1112
#include "openvino/op/convert.hpp"
1213
#include "openvino/op/fake_convert.hpp"
@@ -197,36 +198,50 @@ ov::pass::CompressFloatConstantsImpl::CompressFloatConstantsImpl(bool postponed)
197198
}
198199
auto target_inputs_to_replace = const_node->get_output_target_inputs(0);
199200

201+
// Check if the next node is a postponed constant. It will be constant_folded later during serialization.
200202
auto postponed_constant_node = [&]() -> std::shared_ptr<ov::Node> {
201203
if (target_inputs_to_replace.size() == 1 &&
202204
target_inputs_to_replace.begin()->get_node()->get_rt_info().count("postponed_constant")) {
203205
return target_inputs_to_replace.begin()->get_node()->shared_from_this();
204206
}
205207
return nullptr;
206208
}();
207-
// postponed_constant_node will be constant_folded later during serialization.
208-
// If f16 conversion is also postponed, we need to insert Convert node after the postponed_constant_node
209209

210-
postponed_constant_node = nullptr;
211210
std::shared_ptr<ov::Node> convert;
212211
if (postponed_constant_node && postponed) {
213-
target_inputs_to_replace = postponed_constant_node->get_output_target_inputs(0);
214-
postponed_constant_node->set_friendly_name(const_node->get_friendly_name() + "_compressedblablabla");
215-
convert = std::make_shared<ov::op::v0::Convert>(postponed_constant_node, const_node->get_element_type());
212+
// If f16 conversion is also postponed, we need to insert Convert after the postponed_constant_node
213+
if (is_fp16_compression_postponed(postponed_constant_node->get_rt_info())) {
214+
// Convert was already added after postponed_constant_node. Get it and just update rt info
215+
OPENVINO_ASSERT(postponed_constant_node->get_output_target_inputs(0).size() == 1);
216+
auto next_node = postponed_constant_node->get_output_target_inputs(0).begin()->get_node();
217+
OPENVINO_ASSERT(ov::as_type<ov::op::v0::Convert>(next_node));
218+
ov::copy_runtime_info(const_node, next_node->shared_from_this());
219+
} else {
220+
target_inputs_to_replace = postponed_constant_node->get_output_target_inputs(0);
221+
postponed_constant_node->set_friendly_name(const_node->get_friendly_name() + "_compressed");
222+
convert =
223+
std::make_shared<ov::op::v0::Convert>(postponed_constant_node, const_node->get_element_type());
224+
postpone_fp16_compression(postponed_constant_node->get_rt_info());
225+
postpone_fp16_compression(postponed_constant_node->get_output_tensor(0).get_rt_info());
226+
}
216227
} else {
217228
convert = std::make_shared<ov::op::v0::Convert>(new_const, const_node->get_element_type());
218229
}
219230

220-
convert->set_friendly_name(const_node->get_friendly_name());
221-
new_const->set_friendly_name(const_node->get_friendly_name() + "_compressed");
222-
ov::copy_runtime_info(const_node, convert);
223-
ov::mark_as_decompression(convert);
231+
if (convert) {
232+
convert->set_friendly_name(const_node->get_friendly_name());
233+
new_const->set_friendly_name(const_node->get_friendly_name() + "_compressed");
234+
ov::copy_runtime_info(const_node, convert);
235+
ov::mark_as_decompression(convert);
236+
}
224237
if (postponed) {
225238
postpone_fp16_compression(new_const->get_rt_info());
226239
postpone_fp16_compression(new_const->get_output_tensor(0).get_rt_info());
227240

228-
for (const auto& target_input : target_inputs_to_replace) {
229-
target_input.replace_source_output(convert);
241+
if (convert) {
242+
for (const auto& target_input : target_inputs_to_replace) {
243+
target_input.replace_source_output(convert);
244+
}
230245
}
231246
} else {
232247
ov::replace_node(const_node, convert);

src/core/src/xml_util/xml_serialize_util.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ class PostponedConstantReplacer {
6666
m_constant = outputs[0].get_node_shared_ptr();
6767
m_node = m_constant.get();
6868
m_node->set_friendly_name(node->get_friendly_name());
69+
if (ov::is_fp16_compression_postponed(node->get_rt_info())) {
70+
// postpone_fp16_compression is not copied by constant_fold
71+
ov::postpone_fp16_compression(m_node->get_rt_info());
72+
}
6973
}
7074
}
7175
};

src/core/tests/pass/serialization/custom_ops.cpp

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
#include "openvino/op/add.hpp"
1212
#include "openvino/op/concat.hpp"
1313
#include "openvino/op/constant.hpp"
14+
#include "openvino/op/convert.hpp"
1415
#include "openvino/op/multiply.hpp"
1516
#include "openvino/pass/constant_folding.hpp"
1617
#include "openvino/pass/manager.hpp"
1718
#include "openvino/pass/serialize.hpp"
1819
#include "openvino/runtime/core.hpp"
20+
#include "transformations/common_optimizations/compress_float_constants.hpp"
1921

2022
class CustomOpsSerializationTest : public ::testing::Test {
2123
protected:
@@ -186,7 +188,7 @@ TEST(PostponedConstantTest, ConcatWithPostponedConstant) {
186188

187189
auto model = std::make_shared<ov::Model>(add->outputs(), ov::ParameterVector{param}, "ConcatAddModel");
188190

189-
ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model);
191+
ASSERT_NO_THROW(ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model));
190192
}
191193
ov::Core core;
192194

@@ -230,7 +232,7 @@ TEST(PostponedConstantTest, SubgraphExclusion) {
230232
auto model =
231233
std::make_shared<ov::Model>(final_add->outputs(), ov::ParameterVector{param}, "SubgraphExclusionModel");
232234

233-
ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model);
235+
ASSERT_NO_THROW(ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model));
234236
}
235237
ov::Core core;
236238

@@ -274,7 +276,7 @@ TEST(PostponedConstantTest, NodeWithMultipleConsumers) {
274276

275277
concat->get_rt_info()["postponed_constant"] = true;
276278

277-
ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model);
279+
ASSERT_NO_THROW(ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model));
278280
}
279281
ov::Core core;
280282

@@ -330,7 +332,7 @@ TEST(PostponedConstantTest, ModelIsUnchangedAfterSerialization) {
330332
ov::pass::disable_constant_folding(concat);
331333

332334
auto model_copy = model->clone();
333-
ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model);
335+
ASSERT_NO_THROW(ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model));
334336

335337
const auto& [success, message] = compare_functions(model_copy, model, true, true, true, true, true);
336338
ASSERT_TRUE(success) << message;
@@ -358,3 +360,91 @@ TEST(PostponedConstantTest, ModelIsUnchangedAfterSerialization) {
358360
ASSERT_TRUE(success) << message;
359361
}
360362
}
363+
364+
TEST(PostponedConstantTest, F16Compression2Inputs) {
365+
std::stringstream serialized_xml, serialized_bin;
366+
{
367+
auto const1 =
368+
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{2, 2}, std::vector<float>{1, 2, 3, 4});
369+
auto const2 =
370+
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{2, 2}, std::vector<float>{5, 6, 7, 8});
371+
auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{const1, const2}, 0);
372+
concat->get_rt_info()["postponed_constant"] = true;
373+
374+
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 2});
375+
auto add = std::make_shared<ov::op::v1::Add>(concat, param);
376+
377+
auto model = std::make_shared<ov::Model>(add->outputs(), ov::ParameterVector{param}, "ConcatAddModel");
378+
379+
// in case of postponed_constant + postponed f16 compression, f16 -> f32 convert should be added after postponed
380+
// constant
381+
bool postponed = true;
382+
ov::pass::compress_model_to_f16(model, postponed);
383+
384+
ASSERT_NO_THROW(ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model));
385+
}
386+
ov::Core core;
387+
388+
auto weights = serialized_bin.str();
389+
ov::Tensor weights_tensor(ov::element::u8, ov::Shape{weights.size()}, weights.data());
390+
391+
auto deserialized_model = core.read_model(serialized_xml.str(), weights_tensor);
392+
393+
{
394+
auto constant = std::make_shared<ov::op::v0::Constant>(ov::element::f16,
395+
ov::Shape{4, 2},
396+
std::vector<float>{1, 2, 3, 4, 5, 6, 7, 8});
397+
auto convert = std::make_shared<ov::op::v0::Convert>(constant, ov::element::f32);
398+
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 2});
399+
auto add = std::make_shared<ov::op::v1::Add>(convert, param);
400+
401+
auto expected = std::make_shared<ov::Model>(add->outputs(), ov::ParameterVector{param}, "ConcatAddModel");
402+
403+
const auto& [success, message] =
404+
compare_functions(deserialized_model, expected, true, false, false, true, true);
405+
ASSERT_TRUE(success) << message;
406+
}
407+
}
408+
409+
TEST(PostponedConstantTest, F16CompressionNotPostponned) {
410+
std::stringstream serialized_xml, serialized_bin;
411+
auto check_model = [](const std::shared_ptr<ov::Model>& model) {
412+
auto const1 =
413+
std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{2, 2}, std::vector<float>{1, 2, 3, 4});
414+
auto convert1 = std::make_shared<ov::op::v0::Convert>(const1, ov::element::f32);
415+
auto const2 =
416+
std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{2, 2}, std::vector<float>{5, 6, 7, 8});
417+
auto convert2 = std::make_shared<ov::op::v0::Convert>(const2, ov::element::f32);
418+
auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{convert1, convert2}, 0);
419+
concat->get_rt_info()["postponed_constant"] = true;
420+
421+
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 2});
422+
auto add = std::make_shared<ov::op::v1::Add>(concat, param);
423+
424+
auto expected = std::make_shared<ov::Model>(add->outputs(), ov::ParameterVector{param}, "ConcatAddModel");
425+
426+
const auto& [success, message] = compare_functions(model, expected, true, false, false, true, true);
427+
ASSERT_TRUE(success) << message;
428+
};
429+
430+
{
431+
auto const1 =
432+
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{2, 2}, std::vector<float>{1, 2, 3, 4});
433+
auto const2 =
434+
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{2, 2}, std::vector<float>{5, 6, 7, 8});
435+
auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{const1, const2}, 0);
436+
concat->get_rt_info()["postponed_constant"] = true;
437+
438+
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 2});
439+
auto add = std::make_shared<ov::op::v1::Add>(concat, param);
440+
441+
auto model = std::make_shared<ov::Model>(add->outputs(), ov::ParameterVector{param}, "ConcatAddModel");
442+
443+
bool postponed = false;
444+
ov::pass::compress_model_to_f16(model, postponed);
445+
446+
check_model(model);
447+
448+
ASSERT_THROW(ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model), ov::Exception);
449+
}
450+
}

0 commit comments

Comments
 (0)