|
11 | 11 | #include "openvino/op/add.hpp" |
12 | 12 | #include "openvino/op/concat.hpp" |
13 | 13 | #include "openvino/op/constant.hpp" |
| 14 | +#include "openvino/op/convert.hpp" |
14 | 15 | #include "openvino/op/multiply.hpp" |
15 | 16 | #include "openvino/pass/constant_folding.hpp" |
16 | 17 | #include "openvino/pass/manager.hpp" |
17 | 18 | #include "openvino/pass/serialize.hpp" |
18 | 19 | #include "openvino/runtime/core.hpp" |
| 20 | +#include "transformations/common_optimizations/compress_float_constants.hpp" |
19 | 21 |
|
20 | 22 | class CustomOpsSerializationTest : public ::testing::Test { |
21 | 23 | protected: |
@@ -186,7 +188,7 @@ TEST(PostponedConstantTest, ConcatWithPostponedConstant) { |
186 | 188 |
|
187 | 189 | auto model = std::make_shared<ov::Model>(add->outputs(), ov::ParameterVector{param}, "ConcatAddModel"); |
188 | 190 |
|
189 | | - ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model); |
| 191 | + ASSERT_NO_THROW(ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model)); |
190 | 192 | } |
191 | 193 | ov::Core core; |
192 | 194 |
|
@@ -230,7 +232,7 @@ TEST(PostponedConstantTest, SubgraphExclusion) { |
230 | 232 | auto model = |
231 | 233 | std::make_shared<ov::Model>(final_add->outputs(), ov::ParameterVector{param}, "SubgraphExclusionModel"); |
232 | 234 |
|
233 | | - ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model); |
| 235 | + ASSERT_NO_THROW(ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model)); |
234 | 236 | } |
235 | 237 | ov::Core core; |
236 | 238 |
|
@@ -274,7 +276,7 @@ TEST(PostponedConstantTest, NodeWithMultipleConsumers) { |
274 | 276 |
|
275 | 277 | concat->get_rt_info()["postponed_constant"] = true; |
276 | 278 |
|
277 | | - ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model); |
| 279 | + ASSERT_NO_THROW(ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model)); |
278 | 280 | } |
279 | 281 | ov::Core core; |
280 | 282 |
|
@@ -330,7 +332,7 @@ TEST(PostponedConstantTest, ModelIsUnchangedAfterSerialization) { |
330 | 332 | ov::pass::disable_constant_folding(concat); |
331 | 333 |
|
332 | 334 | auto model_copy = model->clone(); |
333 | | - ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model); |
| 335 | + ASSERT_NO_THROW(ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model)); |
334 | 336 |
|
335 | 337 | const auto& [success, message] = compare_functions(model_copy, model, true, true, true, true, true); |
336 | 338 | ASSERT_TRUE(success) << message; |
@@ -358,3 +360,91 @@ TEST(PostponedConstantTest, ModelIsUnchangedAfterSerialization) { |
358 | 360 | ASSERT_TRUE(success) << message; |
359 | 361 | } |
360 | 362 | } |
| 363 | + |
| 364 | +TEST(PostponedConstantTest, F16Compression2Inputs) { |
| 365 | + std::stringstream serialized_xml, serialized_bin; |
| 366 | + { |
| 367 | + auto const1 = |
| 368 | + std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{2, 2}, std::vector<float>{1, 2, 3, 4}); |
| 369 | + auto const2 = |
| 370 | + std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{2, 2}, std::vector<float>{5, 6, 7, 8}); |
| 371 | + auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{const1, const2}, 0); |
| 372 | + concat->get_rt_info()["postponed_constant"] = true; |
| 373 | + |
| 374 | + auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 2}); |
| 375 | + auto add = std::make_shared<ov::op::v1::Add>(concat, param); |
| 376 | + |
| 377 | + auto model = std::make_shared<ov::Model>(add->outputs(), ov::ParameterVector{param}, "ConcatAddModel"); |
| 378 | + |
| 379 | + // in case of postponed_constant + postponed f16 compression, f16 -> f32 convert should be added after postponed |
| 380 | + // constant |
| 381 | + bool postponed = true; |
| 382 | + ov::pass::compress_model_to_f16(model, postponed); |
| 383 | + |
| 384 | + ASSERT_NO_THROW(ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model)); |
| 385 | + } |
| 386 | + ov::Core core; |
| 387 | + |
| 388 | + auto weights = serialized_bin.str(); |
| 389 | + ov::Tensor weights_tensor(ov::element::u8, ov::Shape{weights.size()}, weights.data()); |
| 390 | + |
| 391 | + auto deserialized_model = core.read_model(serialized_xml.str(), weights_tensor); |
| 392 | + |
| 393 | + { |
| 394 | + auto constant = std::make_shared<ov::op::v0::Constant>(ov::element::f16, |
| 395 | + ov::Shape{4, 2}, |
| 396 | + std::vector<float>{1, 2, 3, 4, 5, 6, 7, 8}); |
| 397 | + auto convert = std::make_shared<ov::op::v0::Convert>(constant, ov::element::f32); |
| 398 | + auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 2}); |
| 399 | + auto add = std::make_shared<ov::op::v1::Add>(convert, param); |
| 400 | + |
| 401 | + auto expected = std::make_shared<ov::Model>(add->outputs(), ov::ParameterVector{param}, "ConcatAddModel"); |
| 402 | + |
| 403 | + const auto& [success, message] = |
| 404 | + compare_functions(deserialized_model, expected, true, false, false, true, true); |
| 405 | + ASSERT_TRUE(success) << message; |
| 406 | + } |
| 407 | +} |
| 408 | + |
| 409 | +TEST(PostponedConstantTest, F16CompressionNotPostponned) { |
| 410 | + std::stringstream serialized_xml, serialized_bin; |
| 411 | + auto check_model = [](const std::shared_ptr<ov::Model>& model) { |
| 412 | + auto const1 = |
| 413 | + std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{2, 2}, std::vector<float>{1, 2, 3, 4}); |
| 414 | + auto convert1 = std::make_shared<ov::op::v0::Convert>(const1, ov::element::f32); |
| 415 | + auto const2 = |
| 416 | + std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{2, 2}, std::vector<float>{5, 6, 7, 8}); |
| 417 | + auto convert2 = std::make_shared<ov::op::v0::Convert>(const2, ov::element::f32); |
| 418 | + auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{convert1, convert2}, 0); |
| 419 | + concat->get_rt_info()["postponed_constant"] = true; |
| 420 | + |
| 421 | + auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 2}); |
| 422 | + auto add = std::make_shared<ov::op::v1::Add>(concat, param); |
| 423 | + |
| 424 | + auto expected = std::make_shared<ov::Model>(add->outputs(), ov::ParameterVector{param}, "ConcatAddModel"); |
| 425 | + |
| 426 | + const auto& [success, message] = compare_functions(model, expected, true, false, false, true, true); |
| 427 | + ASSERT_TRUE(success) << message; |
| 428 | + }; |
| 429 | + |
| 430 | + { |
| 431 | + auto const1 = |
| 432 | + std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{2, 2}, std::vector<float>{1, 2, 3, 4}); |
| 433 | + auto const2 = |
| 434 | + std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{2, 2}, std::vector<float>{5, 6, 7, 8}); |
| 435 | + auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{const1, const2}, 0); |
| 436 | + concat->get_rt_info()["postponed_constant"] = true; |
| 437 | + |
| 438 | + auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 2}); |
| 439 | + auto add = std::make_shared<ov::op::v1::Add>(concat, param); |
| 440 | + |
| 441 | + auto model = std::make_shared<ov::Model>(add->outputs(), ov::ParameterVector{param}, "ConcatAddModel"); |
| 442 | + |
| 443 | + bool postponed = false; |
| 444 | + ov::pass::compress_model_to_f16(model, postponed); |
| 445 | + |
| 446 | + check_model(model); |
| 447 | + |
| 448 | + ASSERT_THROW(ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model), ov::Exception); |
| 449 | + } |
| 450 | +} |
0 commit comments