Skip to content

Commit 907413a

Browse files
committed
Fixes and tests
1 parent c1a0d48 commit 907413a

File tree

3 files changed

+125
-16
lines changed

3 files changed

+125
-16
lines changed

src/common/transformations/src/transformations/common_optimizations/compress_float_constants.cpp

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "transformations/rt_info/decompression.hpp"
2424
#include "transformations/rt_info/disable_fp16_compression.hpp"
2525
#include "transformations/rt_info/old_api_map_element_type_attribute.hpp"
26+
#include "openvino/core/type.hpp"
2627

2728
namespace {
2829
template <ov::element::Type_t PREC_FROM>
@@ -197,36 +198,49 @@ ov::pass::CompressFloatConstantsImpl::CompressFloatConstantsImpl(bool postponed)
197198
}
198199
auto target_inputs_to_replace = const_node->get_output_target_inputs(0);
199200

201+
// Check if the next node is a postponed constant. It will be constant_folded later during serialization.
200202
auto postponed_constant_node = [&]() -> std::shared_ptr<ov::Node> {
201203
if (target_inputs_to_replace.size() == 1 &&
202204
target_inputs_to_replace.begin()->get_node()->get_rt_info().count("postponed_constant")) {
203205
return target_inputs_to_replace.begin()->get_node()->shared_from_this();
204206
}
205207
return nullptr;
206208
}();
207-
// postponed_constant_node will be constant_folded later during serialization.
208-
// If f16 conversion is also postponed, we need to insert Convert node after the postponed_constant_node
209209

210-
postponed_constant_node = nullptr;
211210
std::shared_ptr<ov::Node> convert;
212211
if (postponed_constant_node && postponed) {
213-
target_inputs_to_replace = postponed_constant_node->get_output_target_inputs(0);
214-
postponed_constant_node->set_friendly_name(const_node->get_friendly_name() + "_compressedblablabla");
215-
convert = std::make_shared<ov::op::v0::Convert>(postponed_constant_node, const_node->get_element_type());
212+
// If f16 conversion is also postponed, we need to insert Convert after the postponed_constant_node
213+
if (is_fp16_compression_postponed(postponed_constant_node->get_rt_info())) {
214+
// A Convert was already inserted after postponed_constant_node; reuse it and only update its runtime info.
215+
OPENVINO_ASSERT(postponed_constant_node->get_output_target_inputs(0).size() == 1);
216+
auto next_node = postponed_constant_node->get_output_target_inputs(0).begin()->get_node();
217+
OPENVINO_ASSERT(ov::as_type<ov::op::v0::Convert>(next_node));
218+
ov::copy_runtime_info(const_node, next_node->shared_from_this());
219+
} else {
220+
target_inputs_to_replace = postponed_constant_node->get_output_target_inputs(0);
221+
postponed_constant_node->set_friendly_name(const_node->get_friendly_name() + "_compressed");
222+
convert = std::make_shared<ov::op::v0::Convert>(postponed_constant_node, const_node->get_element_type());
223+
postpone_fp16_compression(postponed_constant_node->get_rt_info());
224+
postpone_fp16_compression(postponed_constant_node->get_output_tensor(0).get_rt_info());
225+
}
216226
} else {
217227
convert = std::make_shared<ov::op::v0::Convert>(new_const, const_node->get_element_type());
218228
}
219229

220-
convert->set_friendly_name(const_node->get_friendly_name());
221-
new_const->set_friendly_name(const_node->get_friendly_name() + "_compressed");
222-
ov::copy_runtime_info(const_node, convert);
223-
ov::mark_as_decompression(convert);
230+
if (convert) {
231+
convert->set_friendly_name(const_node->get_friendly_name());
232+
new_const->set_friendly_name(const_node->get_friendly_name() + "_compressed");
233+
ov::copy_runtime_info(const_node, convert);
234+
ov::mark_as_decompression(convert);
235+
}
224236
if (postponed) {
225237
postpone_fp16_compression(new_const->get_rt_info());
226238
postpone_fp16_compression(new_const->get_output_tensor(0).get_rt_info());
227239

228-
for (const auto& target_input : target_inputs_to_replace) {
229-
target_input.replace_source_output(convert);
240+
if (convert) {
241+
for (const auto& target_input : target_inputs_to_replace) {
242+
target_input.replace_source_output(convert);
243+
}
230244
}
231245
} else {
232246
ov::replace_node(const_node, convert);

src/core/src/xml_util/xml_serialize_util.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ class PostponedConstantReplacer {
6666
m_constant = outputs[0].get_node_shared_ptr();
6767
m_node = m_constant.get();
6868
m_node->set_friendly_name(node->get_friendly_name());
69+
if (ov::is_fp16_compression_postponed(node->get_rt_info())) {
70+
// The postponed fp16 compression mark is not copied by constant folding, so re-apply it to the folded constant.
71+
ov::postpone_fp16_compression(m_node->get_rt_info());
72+
}
6973
}
7074
}
7175
};

src/core/tests/pass/serialization/custom_ops.cpp

Lines changed: 95 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
#include "openvino/op/add.hpp"
1212
#include "openvino/op/concat.hpp"
1313
#include "openvino/op/constant.hpp"
14+
#include "openvino/op/convert.hpp"
1415
#include "openvino/op/multiply.hpp"
1516
#include "openvino/pass/constant_folding.hpp"
1617
#include "openvino/pass/manager.hpp"
1718
#include "openvino/pass/serialize.hpp"
1819
#include "openvino/runtime/core.hpp"
20+
#include "transformations/common_optimizations/compress_float_constants.hpp"
1921

2022
class CustomOpsSerializationTest : public ::testing::Test {
2123
protected:
@@ -186,7 +188,7 @@ TEST(PostponedConstantTest, ConcatWithPostponedConstant) {
186188

187189
auto model = std::make_shared<ov::Model>(add->outputs(), ov::ParameterVector{param}, "ConcatAddModel");
188190

189-
ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model);
191+
ASSERT_NO_THROW(ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model));
190192
}
191193
ov::Core core;
192194

@@ -230,7 +232,7 @@ TEST(PostponedConstantTest, SubgraphExclusion) {
230232
auto model =
231233
std::make_shared<ov::Model>(final_add->outputs(), ov::ParameterVector{param}, "SubgraphExclusionModel");
232234

233-
ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model);
235+
ASSERT_NO_THROW(ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model));
234236
}
235237
ov::Core core;
236238

@@ -274,7 +276,7 @@ TEST(PostponedConstantTest, NodeWithMultipleConsumers) {
274276

275277
concat->get_rt_info()["postponed_constant"] = true;
276278

277-
ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model);
279+
ASSERT_NO_THROW(ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model));
278280
}
279281
ov::Core core;
280282

@@ -330,7 +332,7 @@ TEST(PostponedConstantTest, ModelIsUnchangedAfterSerialization) {
330332
ov::pass::disable_constant_folding(concat);
331333

332334
auto model_copy = model->clone();
333-
ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model);
335+
ASSERT_NO_THROW(ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model));
334336

335337
const auto& [success, message] = compare_functions(model_copy, model, true, true, true, true, true);
336338
ASSERT_TRUE(success) << message;
@@ -358,3 +360,92 @@ TEST(PostponedConstantTest, ModelIsUnchangedAfterSerialization) {
358360
ASSERT_TRUE(success) << message;
359361
}
360362
}
363+
364+
TEST(PostponedConstantTest, F16Compression2Inputs) {
365+
std::stringstream serialized_xml, serialized_bin;
366+
{
367+
auto const1 =
368+
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{2, 2}, std::vector<float>{1, 2, 3, 4});
369+
auto const2 =
370+
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{2, 2}, std::vector<float>{5, 6, 7, 8});
371+
auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{const1, const2}, 0);
372+
concat->get_rt_info()["postponed_constant"] = true;
373+
374+
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 2});
375+
auto add = std::make_shared<ov::op::v1::Add>(concat, param);
376+
377+
auto model = std::make_shared<ov::Model>(add->outputs(), ov::ParameterVector{param}, "ConcatAddModel");
378+
379+
// In case of postponed_constant + postponed f16 compression, the f16 -> f32 Convert should be added after the postponed constant.
380+
bool postponed = true;
381+
ov::pass::compress_model_to_f16(model, postponed);
382+
383+
ASSERT_NO_THROW(ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model));
384+
}
385+
ov::Core core;
386+
387+
auto weights = serialized_bin.str();
388+
ov::Tensor weights_tensor(ov::element::u8, ov::Shape{weights.size()}, weights.data());
389+
390+
auto deserialized_model = core.read_model(serialized_xml.str(), weights_tensor);
391+
392+
{
393+
auto constant = std::make_shared<ov::op::v0::Constant>(ov::element::f16,
394+
ov::Shape{4, 2},
395+
std::vector<float>{1, 2, 3, 4, 5, 6, 7, 8});
396+
auto convert = std::make_shared<ov::op::v0::Convert>(constant, ov::element::f32);
397+
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 2});
398+
auto add = std::make_shared<ov::op::v1::Add>(convert, param);
399+
400+
auto expected = std::make_shared<ov::Model>(add->outputs(), ov::ParameterVector{param}, "ConcatAddModel");
401+
402+
const auto& [success, message] =
403+
compare_functions(deserialized_model, expected, true, false, false, true, true);
404+
ASSERT_TRUE(success) << message;
405+
}
406+
}
407+
408+
TEST(PostponedConstantTest, F16CompressionNotPostponned) {
409+
std::stringstream serialized_xml, serialized_bin;
410+
auto check_model = [](const std::shared_ptr<ov::Model>& model) {
411+
auto const1 =
412+
std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{2, 2}, std::vector<float>{1, 2, 3, 4});
413+
auto convert1 = std::make_shared<ov::op::v0::Convert>(const1, ov::element::f32);
414+
auto const2 =
415+
std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{2, 2}, std::vector<float>{5, 6, 7, 8});
416+
auto convert2 = std::make_shared<ov::op::v0::Convert>(const2, ov::element::f32);
417+
auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{convert1, convert2}, 0);
418+
concat->get_rt_info()["postponed_constant"] = true;
419+
420+
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 2});
421+
auto add = std::make_shared<ov::op::v1::Add>(concat, param);
422+
423+
auto expected = std::make_shared<ov::Model>(add->outputs(), ov::ParameterVector{param}, "ConcatAddModel");
424+
425+
const auto& [success, message] =
426+
compare_functions(model, expected, true, false, false, true, true);
427+
ASSERT_TRUE(success) << message;
428+
};
429+
430+
{
431+
auto const1 =
432+
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{2, 2}, std::vector<float>{1, 2, 3, 4});
433+
auto const2 =
434+
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{2, 2}, std::vector<float>{5, 6, 7, 8});
435+
auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{const1, const2}, 0);
436+
concat->get_rt_info()["postponed_constant"] = true;
437+
438+
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 2});
439+
auto add = std::make_shared<ov::op::v1::Add>(concat, param);
440+
441+
auto model = std::make_shared<ov::Model>(add->outputs(), ov::ParameterVector{param}, "ConcatAddModel");
442+
443+
bool postponed = false;
444+
ov::pass::compress_model_to_f16(model, postponed);
445+
446+
check_model(model);
447+
448+
ASSERT_THROW(ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model), ov::Exception);
449+
}
450+
}
451+

0 commit comments

Comments
 (0)