From aa63ad91665c955f625712bbbcc0a36dac9cb8d3 Mon Sep 17 00:00:00 2001 From: Vladimir Zlobin Date: Thu, 23 Jan 2025 11:17:29 +0400 Subject: [PATCH] Add Phi-3.5-vision-instruct and Phi-3-vision-128k-instruct (#1609) Ticket 156662 --------- Co-authored-by: Ilya Lavrenov --- .github/workflows/mac.yml | 2 +- SUPPORTED_MODELS.md | 21 + src/cpp/src/visual_language/clip.cpp | 2 +- src/cpp/src/visual_language/clip.hpp | 1 + .../src/visual_language/inputs_embedder.cpp | 459 ++++++++++++++++-- .../src/visual_language/inputs_embedder.hpp | 1 + .../src/visual_language/processor_config.cpp | 4 + .../src/visual_language/processor_config.hpp | 7 +- .../src/visual_language/vision_encoder.cpp | 205 ++++++++ .../src/visual_language/vision_encoder.hpp | 4 + src/cpp/src/visual_language/vlm_config.cpp | 9 + src/cpp/src/visual_language/vlm_config.hpp | 3 + .../src/visual_language/vlm_model_type.hpp | 2 + tests/python_tests/test_vlm_pipeline.py | 21 +- 14 files changed, 703 insertions(+), 38 deletions(-) diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml index 57776be64b..e444443ea7 100644 --- a/.github/workflows/mac.yml +++ b/.github/workflows/mac.yml @@ -17,7 +17,7 @@ concurrency: env: PYTHON_VERSION: '3.10' - OV_BRANCH: 'master' + OV_BRANCH: 7f56fcd4658c6a427111ac835e809ddd87f0cad2 OV_TARBALL: '' jobs: diff --git a/SUPPORTED_MODELS.md b/SUPPORTED_MODELS.md index f79234489d..3064fb58c1 100644 --- a/SUPPORTED_MODELS.md +++ b/SUPPORTED_MODELS.md @@ -312,6 +312,7 @@ In addition to image generation models, `InpaintingPipeline` supports specialize Models LoRA support Example HuggingFace Models + Notes InternVL2 @@ -329,6 +330,7 @@ In addition to image generation models, `InpaintingPipeline` supports specialize
  • OpenGVLab/InternVL2_5-8B
  • + LLaVA @@ -339,6 +341,7 @@ In addition to image generation models, `InpaintingPipeline` supports specialize
  • llava-hf/llava-1.5-7b-hf
  • + LLaVA-NeXT @@ -351,6 +354,7 @@ In addition to image generation models, `InpaintingPipeline` supports specialize
  • llava-hf/llama3-llava-next-8b-hf
  • + MiniCPMV @@ -361,6 +365,22 @@ In addition to image generation models, `InpaintingPipeline` supports specialize
  • openbmb/MiniCPM-V-2_6
  • + + + + Phi3VForCausalLM + phi3_v + Not supported + + + + +
  • GPU isn't supported
  • +
  • These models' configs aren't consistent, so the default eos_token_id has to be overridden with the one from the tokenizer: generation_config.set_eos_token_id(pipe.get_tokenizer().get_eos_token_id()); see the sketch below.
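  For reference, a minimal usage sketch of this override with the Python API (the model path is a placeholder; `VLMPipeline`, `GenerationConfig`, and the token-id calls are the ones named in the note above):

  ```python
  from openvino_genai import VLMPipeline, GenerationConfig

  # Placeholder path to an exported Phi-3-vision model directory
  pipe = VLMPipeline("./Phi-3-vision-128k-instruct-ov", "CPU")

  generation_config = GenerationConfig(max_new_tokens=100)
  # Override the inconsistent default eos_token_id with the tokenizer's value
  generation_config.set_eos_token_id(pipe.get_tokenizer().get_eos_token_id())
  ```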
  • + Qwen2-VL @@ -372,6 +392,7 @@ In addition to image generation models, `InpaintingPipeline` supports specialize
  • Qwen/Qwen2-VL-7B-Instruct
  • + diff --git a/src/cpp/src/visual_language/clip.cpp b/src/cpp/src/visual_language/clip.cpp index 30a6dff5ae..9347f63074 100644 --- a/src/cpp/src/visual_language/clip.cpp +++ b/src/cpp/src/visual_language/clip.cpp @@ -12,7 +12,7 @@ static float clip_lerp(float s, float e, float t) { } // Bilinear resize function -static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) { +void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) { dst.nx = target_width; dst.ny = target_height; dst.buf.resize(3 * target_width * target_height); diff --git a/src/cpp/src/visual_language/clip.hpp b/src/cpp/src/visual_language/clip.hpp index 4bdb4542d0..e00ac2fc40 100644 --- a/src/cpp/src/visual_language/clip.hpp +++ b/src/cpp/src/visual_language/clip.hpp @@ -31,6 +31,7 @@ struct clip_image_f32 { }; void bicubic_resize(const clip_image_u8& img, clip_image_u8& dst, int target_width, int target_height); +void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height); /** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model configuration */ clip_image_f32 clip_image_preprocess(struct clip_ctx& ctx, const clip_image_u8& img); diff --git a/src/cpp/src/visual_language/inputs_embedder.cpp b/src/cpp/src/visual_language/inputs_embedder.cpp index 4f3812862c..66b17e5804 100644 --- a/src/cpp/src/visual_language/inputs_embedder.cpp +++ b/src/cpp/src/visual_language/inputs_embedder.cpp @@ -7,15 +7,10 @@ #include "visual_language/clip.hpp" #include "visual_language/vision_encoder.hpp" #include "visual_language/embedding_model.hpp" +#include "openvino/opsets/opset13.hpp" #include "utils.hpp" - - -namespace { - -constexpr size_t BATCH_SIZE = 1; - -} // namespace +#include namespace ov::genai { @@ -155,17 +150,8 @@ class InputsEmbedder::IInputsEmbedder { ), m_tokenizer(tokenizer) { } - ov::Tensor get_encoded_input_ids(const std::string& prompt, ov::genai::VLMPerfMetrics& metrics, const std::string& chat_template_fallback = {}) { - ov::Tensor encoded_input_ids; + std::pair apply_chat_template_tokenize(const std::string& prompt, ov::genai::VLMPerfMetrics& metrics, const std::string& chat_template_fallback = {}) { if (m_is_chat_conversation) { - // KV cache in model already contains prompts and answers from previous iterations. - // So only new prompt wrapped into chat template to be sent into model. Tokenizer always returns - // token_ids = {, ...}. So if tokenizer applies only to the new prompt, - // will be inserted on every iteration. - // So actual pipeline calculates input_ids for whole chat history + for whole chat history without the new prompt - // and takes only the difference between them. - // The chat history cannot be saved as already encoded tokens because generate call doesn't return token, but - // KV cache contains it. So we have to add it manually or get it by tokenization all chat history. 
m_history.push_back({{"role", "user"}, {"content", prompt}}); constexpr bool add_generation_prompt = true; std::string new_templated_chat_history; @@ -177,9 +163,31 @@ class InputsEmbedder::IInputsEmbedder { } auto start_tokenizer_time = std::chrono::steady_clock::now(); ov::Tensor new_chat_tokens = m_tokenizer.encode(new_templated_chat_history, ov::genai::add_special_tokens(false)).input_ids; - TokenizedInputs prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false)); + ov::Tensor prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false)).input_ids; + auto end_tokenizer_time = std::chrono::steady_clock::now(); + metrics.raw_metrics.tokenization_durations.emplace_back(PerfMetrics::get_microsec(end_tokenizer_time - start_tokenizer_time)); + m_templated_chat_history = std::move(new_templated_chat_history); + return {new_chat_tokens, prev_chat_tokens}; + } else { + auto start_tokenizer_time = std::chrono::steady_clock::now(); + ov::Tensor encoded_input_ids = m_tokenizer.encode(prompt).input_ids; auto end_tokenizer_time = std::chrono::steady_clock::now(); metrics.raw_metrics.tokenization_durations.emplace_back(PerfMetrics::get_microsec(end_tokenizer_time - start_tokenizer_time)); + return {encoded_input_ids, ov::Tensor()}; + } + } + + ov::Tensor update_history(const ov::Tensor& new_chat_tokens, const ov::Tensor& prev_chat_tokens) { + if (m_is_chat_conversation) { + ov::Tensor encoded_input_ids; + // KV cache in model already contains prompts and answers from previous iterations. + // So only new prompt wrapped into chat template to be sent into model. Tokenizer always returns + // token_ids = {, ...}. So if tokenizer applies only to the new prompt, + // will be inserted on every iteration. + // So actual pipeline calculates input_ids for whole chat history + for whole chat history without the new prompt + // and takes only the difference between them. + // The chat history cannot be saved as already encoded tokens because generate call doesn't return token, but + // KV cache contains it. So we have to add it manually or get it by tokenization all chat history. 
// some symbols combinations can be encoded by the tokenizer in different ways // if we met sequence with such combination of symbols, we cannot correctly subtract the new history from the old history @@ -187,7 +195,7 @@ class InputsEmbedder::IInputsEmbedder { size_t trusted_history_length = 0; if (!m_tokenized_history.empty()) { std::set stop_tokens = {m_tokenizer.get_eos_token_id()}; - trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_history, stop_tokens); + trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens, m_tokenized_history, stop_tokens); } if (m_tokenized_history.empty()) { @@ -213,25 +221,25 @@ class InputsEmbedder::IInputsEmbedder { new_tensor.copy_to(encoded_input_ids); } else { encoded_input_ids = utils::subtract_chat_tokenized_inputs( - {new_chat_tokens}, prev_chat_tokens + {new_chat_tokens}, {prev_chat_tokens} ).input_ids; if (m_last_disappeared_token.has_value()) encoded_input_ids = ov::genai::utils::push_front_inputs(encoded_input_ids, *m_last_disappeared_token); } - m_templated_chat_history = std::move(new_templated_chat_history); m_tokenized_history.clear(); std::copy_n(new_chat_tokens.data(), new_chat_tokens.get_size(), std::back_inserter(m_tokenized_history)); + return encoded_input_ids; } else { - auto start_tokenizer_time = std::chrono::steady_clock::now(); - encoded_input_ids = m_tokenizer.encode(prompt).input_ids; - auto end_tokenizer_time = std::chrono::steady_clock::now(); - metrics.raw_metrics.tokenization_durations.emplace_back(PerfMetrics::get_microsec(end_tokenizer_time - start_tokenizer_time)); m_tokenized_history.clear(); - std::copy_n(encoded_input_ids.data(), encoded_input_ids.get_size(), std::back_inserter(m_tokenized_history)); + std::copy_n(new_chat_tokens.data(), new_chat_tokens.get_size(), std::back_inserter(m_tokenized_history)); + return new_chat_tokens; } + } - return encoded_input_ids; + ov::Tensor get_encoded_input_ids(const std::string& prompt, ov::genai::VLMPerfMetrics& metrics, const std::string& chat_template_fallback = "") { + const auto [new_chat_tokens, prev_chat_tokens] = apply_chat_template_tokenize(prompt, metrics, chat_template_fallback); + return update_history(new_chat_tokens, prev_chat_tokens); } /** @@ -687,6 +695,7 @@ class InputsEmbedderLLaVA : public InputsEmbedder::IInputsEmbedder { } size_t merged_seq_length = text_embeds_seq_length + total_image_seq_length - num_image_tokens; + constexpr size_t BATCH_SIZE = 1; ov::Tensor merged_embeds(text_embeds.get_element_type(), {BATCH_SIZE, merged_seq_length, hidden_size}); float* merged_data = merged_embeds.data(); @@ -1163,6 +1172,400 @@ class InputsEmbedderInternVLChat : public InputsEmbedder::IInputsEmbedder { } }; +namespace { +namespace phi3_v { +// Reimplementation of python +// N, L, C = image_features.shape +// assert L == 24 * 24 and C == 1024 and N % (h_crop * w_crop) == 0 +// num_images = N // (h_crop * w_crop) +// H = int(L**0.5) +// print(L, H) +// image_features_hd = ( +// image_features.reshape(N, H, H, C) # N, 24, 24, 1024 +// .reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024 +// .permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024 +// .reshape(N, -1, 4 * C) # N, 144, 4096 +// .reshape(num_images, h_crop, w_crop, H // 2, H // 2, -1) # n_img, h_crop, w_crop, 12, 12, 4096 +// .permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096 +// .reshape(num_images, h_crop * H // 2, w_crop * H // 2, 4 * C) # n_img, h_crop*12, w_crop*12, 4096 +// ) +// Obtained in the 
following way +// import torch +// import openvino as ov +// import numpy as np +// class Model(torch.nn.Module): +// def forward(self, image_features, h_crop, w_crop): +// """ +// image_features: (num_images*num_crops, 24*24, 1024) +// output: (num_images, h_crop*12, w_crop*12, 4096), h_crop*w_crop == num_crops +// """ +// N, L, C = image_features.shape +// num_images = N // (h_crop * w_crop) +// H = (torch.tensor(L, dtype=torch.float32)**0.5).int() +// image_features_hd = ( +// image_features.reshape(N, H, H, C) # N, 24, 24, 1024 +// .reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024 +// .permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024 +// .reshape(N, -1, 4 * C) # N, 144, 4096 +// .reshape(num_images, h_crop, w_crop, H // 2, H // 2, -1) # n_img, h_crop, w_crop, 12, 12, 4096 +// .permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096 +// .reshape(num_images, h_crop * H // 2, w_crop * H // 2, 4 * C) # n_img, h_crop*12, w_crop*12, 4096 +// return {"o": image_features_hd} +// model = Model() +// example_input = {"image_features": torch.rand((4, 576, 1024), dtype=torch.float32), "h_crop": torch.tensor(2, dtype=torch.int32), "w_crop": torch.tensor(2, dtype=torch.int32)} +// ov_model = ov.convert_model(model, example_input=example_input, input=ov.PartialShape([-1, 576, 1024])) +// # ov_model.outputs[0].get_tensor().set_names({"out"}) +// ov.save_model(ov_model, "reshape_hd_patches_2x2merge.xml") +// inp = np.arange(4 * 576 * 1024).reshape([4, 576, 1024]) +// test = ov.Core().compile_model(ov_model, "CPU") +// print(ov_model) +// print(test([inp, 2, 2])["o"].flatten()) +// 2. Run https://github.com/slyalin/openvino_devtools/blob/bcd4a51b1354b24b2316ac3e1c77b2f87ae7a497/openvino_devtools/ov2py.py with the IR. +// 3. Translate the printed Python implementation to C++. 
+ov::InferRequest create_hd_feature_transformer() { + using namespace ov; + using namespace element; + using namespace opset13; + using namespace std; + auto t0 = make_shared(f32, PartialShape{-1, 576, 1024}); + auto t1 = make_shared(i32, PartialShape{}); + auto t2 = make_shared(i32, PartialShape{}); + auto t3 = make_shared(t0); + auto t4 = make_shared(i64, Shape{}, vector{0}); + auto t5 = make_shared(i64, Shape{}, vector{0}); + auto t6 = make_shared(t3, t4, t5); + auto t7 = make_shared(i64, Shape{1}, vector{1}); + auto t8 = make_shared(t6, t7, false); + auto t9 = make_shared(i64, Shape{}, vector{1}); + auto t10 = make_shared(i64, Shape{}, vector{0}); + auto t11 = make_shared(t3, t9, t10); + auto t12 = make_shared(t11, element::f32); + auto t13 = make_shared(f32, Shape{}, vector{0.5}); + auto t14 = make_shared(t12, t13, "numpy"); + auto t15 = make_shared(t14, element::i32); + auto t16 = make_shared(t15, element::i64); + auto t17 = make_shared(i32, Shape{}, vector{0}); + auto t18 = make_shared(t16, t17); + auto t19 = make_shared(i64, Shape{1}, vector{2}); + auto t20 = make_shared(i64, Shape{}, vector{0}); + auto t21 = make_shared(t3, t19, t20); + auto t22 = make_shared(NodeVector{t8, t18, t18, t21}, 0); + auto t23 = make_shared(t0, t22, false); + auto t24 = make_shared(i64, Shape{}, vector{2}); + auto t25 = make_shared(t16, t24, "numpy"); + auto t26 = make_shared(t25); + auto t27 = make_shared(i32, Shape{}, vector{0}); + auto t28 = make_shared(t26, t27); + auto t29 = make_shared(i64, Shape{1}, vector{2}); + auto t30 = make_shared(i64, Shape{1}, vector{2}); + auto t31 = make_shared(NodeVector{t8, t28, t29, t28, t30, t21}, 0); + auto t32 = make_shared(t23, t31, false); + auto t33 = make_shared(i64, Shape{6}, vector{0, 1, 3, 2, 4, 5}); + auto t34 = make_shared(t32, t33); + auto t35 = make_shared(i64, Shape{1}, vector{-1}); + auto t36 = make_shared(i64, Shape{1}, vector{4}); + auto t37 = make_shared(t21, t36, "numpy"); + auto t38 = make_shared(NodeVector{t8, t35, t37}, 0); + auto t39 = make_shared(t34, t38, false); + auto t40 = make_shared(t1, t2, "numpy"); + auto t41 = make_shared(t40, element::i64); + auto t42 = make_shared(t6, t41, "numpy"); + auto t43 = make_shared(t42); + auto t44 = make_shared(i64, Shape{}, vector{0}); + auto t45 = make_shared(t43, t44); + auto t46 = make_shared(t1, element::i64); + auto t47 = make_shared(t46, t44); + auto t48 = make_shared(t2, element::i64); + auto t49 = make_shared(t48, t44); + auto t50 = make_shared(i64, Shape{1}, vector{-1}); + auto t51 = make_shared(NodeVector{t45, t47, t49, t28, t28, t50}, 0); + auto t52 = make_shared(t39, t51, false); + auto t53 = make_shared(i64, Shape{6}, vector{0, 1, 3, 2, 4, 5}); + auto t54 = make_shared(t52, t53); + auto t55 = make_shared(t1, t15, "numpy"); + auto t56 = make_shared(t55, element::i64); + auto t57 = make_shared(i64, Shape{}, vector{2}); + auto t58 = make_shared(t56, t57, "numpy"); + auto t59 = make_shared(t58); + auto t60 = make_shared(i32, Shape{}, vector{0}); + auto t61 = make_shared(t59, t60); + auto t62 = make_shared(t2, t15, "numpy"); + auto t63 = make_shared(t62, element::i64); + auto t64 = make_shared(i64, Shape{}, vector{2}); + auto t65 = make_shared(t63, t64, "numpy"); + auto t66 = make_shared(t65); + auto t67 = make_shared(t66, t60); + auto t68 = make_shared(NodeVector{t45, t61, t67, t37}, 0); + auto t69 = make_shared(t54, t68, false); + shared_ptr model = make_shared(make_shared(t69), ParameterVector{t0, t1, t2}); + return utils::singleton_core().compile_model( + model, "CPU" + 
).create_infer_request(); +} + +ov::Tensor reshape_hd_patches_2x2merge(const ov::Tensor& image_features, size_t h_crop, size_t w_crop, InferRequest& hd_feature_transformer) { + ov::Shape shape = image_features.get_shape(); + OPENVINO_ASSERT(3 == shape.size()); + OPENVINO_ASSERT(24 * 24 == shape.at(1)); + OPENVINO_ASSERT(1024 == shape.at(2)); + hd_feature_transformer.set_input_tensor(0, image_features); + ov::Tensor height{ov::element::i32, {}, &h_crop}; + hd_feature_transformer.set_input_tensor(1, height); + ov::Tensor width{ov::element::i32, {}, &w_crop}; + hd_feature_transformer.set_input_tensor(2, width); + hd_feature_transformer.infer(); + return hd_feature_transformer.get_output_tensor(); +} + +// image_features_hd: (num_images, h_crop*12, w_crop*12, 4096) +// output: (num_images, (h_crop*12) * (w_crop*12+1), 4096) +ov::Tensor add_image_newline(const ov::Tensor& image_features_hd, const std::vector& sub_GN) { + const ov::Shape& nhwc = image_features_hd.get_shape(); // [N, 12*h_crop, 12*w_crop, 4096] + const float* in = image_features_hd.data(); + ov::Tensor image_features_hd_new_line{ov::element::f32, {nhwc.at(0), nhwc.at(1) * (nhwc.at(2) + 1), nhwc.at(3)}}; + float* out = image_features_hd_new_line.data(); + for (size_t batch_id = 0; batch_id < nhwc.at(0); ++batch_id) { + for (size_t row_id = 0; row_id < nhwc.at(1); ++row_id) { + for (size_t col_id = 0; col_id < nhwc.at(2); ++col_id) { + std::copy_n( + in + batch_id * nhwc.at(1) * nhwc.at(2) * nhwc.at(3) + row_id * nhwc.at(2) * nhwc.at(3) + col_id * nhwc.at(3), + nhwc.at(3), + out + batch_id * nhwc.at(1) * (nhwc.at(2) + 1) * nhwc.at(3) + row_id * (nhwc.at(2) + 1) * nhwc.at(3) + col_id * nhwc.at(3) + ); + } + std::copy( + sub_GN.begin(), + sub_GN.end(), + out + batch_id * nhwc.at(1) * (nhwc.at(2) + 1) * nhwc.at(3) + row_id * (nhwc.at(2) + 1) * nhwc.at(3) + nhwc.at(2) * nhwc.at(3) + ); + } + } + return image_features_hd_new_line; +} + +ov::Tensor concatenate_2d(const ov::Tensor& first_1lf, const std::vector& second_f, const ov::Tensor& third_1lf) { + size_t first_l = first_1lf.get_shape().at(1); + constexpr size_t second_l = 1; + size_t third_l = third_1lf.get_shape().at(1); + size_t features = first_1lf.get_shape().at(2); + OPENVINO_ASSERT(second_f.size() == features); + ov::Tensor out_1lf{ov::element::f32, {1, first_l + second_l + third_l, features}}; + float* out = out_1lf.data(); + std::copy_n(first_1lf.data(), first_l * features, out); + std::copy(second_f.begin(), second_f.end(), out + first_l * features); + std::copy_n(third_1lf.data(), third_l * features, out + (first_l + second_l) * features); + return out_1lf; +} + +// image_features.resized_source: (num_crops+1, 24*24, 1024) +ov::Tensor hd_feature_transform(const EncodedImage& image_features, InferRequest& hd_feature_transformer, const std::vector& sub_GN, const std::vector& glb_GN, ov::InferRequest& vision_projection) { + const ov::Shape& image_features_shape = image_features.resized_source.get_shape(); + ov::Tensor global_image_features{ov::element::f32, {1, image_features_shape.at(1), image_features_shape.at(2)}, image_features.resized_source.data()}; + // global feature can be viewed as a special HD case with num_crops 1x1 + ov::Tensor global_image_features_hd = reshape_hd_patches_2x2merge(global_image_features, 1, 1, hd_feature_transformer); + ov::Tensor global_image_features_hd_newline = add_image_newline(global_image_features_hd, sub_GN); // [1,12*(12+1),4096] + constexpr size_t INPUT_IMAGE_SIZE = 336; + size_t h_crop = image_features.resized_source_size.height / 
INPUT_IMAGE_SIZE; + size_t w_crop = image_features.resized_source_size.width / INPUT_IMAGE_SIZE; + size_t num_crops = h_crop * w_crop; + + // NOTE: real num_crops is padded + // (num_crops, 24*24, 1024) + ov::Tensor sub_image_features{ov::element::f32, { + num_crops, + image_features_shape.at(1), + image_features_shape.at(2) + }, image_features.resized_source.data() + image_features_shape.at(1) * image_features_shape.at(2)}; + ov::Tensor sub_image_features_hd = reshape_hd_patches_2x2merge(sub_image_features, h_crop, w_crop, hd_feature_transformer); // [1, 24, 24, 4096] + ov::Tensor sub_image_features_hd_newline = add_image_newline(sub_image_features_hd, sub_GN); // [1,h_crop*12*(w_crop*12+1), 4096] + ov::Tensor image_embeddings = concatenate_2d(sub_image_features_hd_newline, glb_GN, global_image_features_hd_newline); // [1,l,4096] + vision_projection.set_input_tensor(image_embeddings); + vision_projection.infer(); + ov::Tensor out = vision_projection.get_output_tensor(); + ov::Tensor res{out.get_element_type(), out.get_shape()}; + out.copy_to(res); + return res; +} + +std::vector split_tokenize(const std::string& text, ov::genai::Tokenizer& tokenizer) { + constexpr int make_suffix_iterator = -1; + std::regex rgx{R"(<\|image_\d+\|>)"}; + std::sregex_token_iterator iter{ + text.begin(), + text.end(), + rgx, + make_suffix_iterator + }; + std::vector tokenized; + for ( ; iter != std::sregex_token_iterator{}; ++iter) { + if (iter->str().empty()) { + continue; + } + std::string substr = *iter; + tokenized.push_back(tokenizer.encode(substr, ov::genai::add_special_tokens(true)).input_ids); + } + return tokenized; +} + +ov::Tensor insert_image_placeholders(const std::vector& chunks, size_t tokens_per_image) { + size_t merged_length = 0; + for (const ov::Tensor& chunk : chunks) { + merged_length += chunk.get_shape().at(1); + } + merged_length += chunks.empty() ? 0 : (chunks.size() - 1) * tokens_per_image; + ov::Tensor merged{ov::element::i64, {1, merged_length}}; + size_t offset = 0; + int64_t image_id = -1; + for (const ov::Tensor& chunk : chunks) { + size_t length = chunk.get_shape().at(1); + std::copy_n( + chunk.data(), + length, + merged.data() + offset + ); + offset += length; + if (offset < merged_length) { + std::fill_n( + merged.data() + offset, + tokens_per_image, + image_id + ); + offset += tokens_per_image; + --image_id; + } + } + return merged; +} + +std::vector drop_image_placeholders(const ov::Tensor& tokens) { + std::vector chunks; + size_t offset = 0; + while (offset < tokens.get_shape().at(1)) { + size_t length = 0; + while (offset + length < tokens.get_shape().at(1) && tokens.data()[offset + length] >= 0) { + ++length; + } + chunks.emplace_back(ov::element::i64, ov::Shape{1, length}, tokens.data() + offset); + offset += length; + while (offset < tokens.get_shape().at(1) && tokens.data()[offset] < 0) { + ++offset; + } + } + return chunks; +} +} // namespace phi3_v +} // anonymous namespace + +class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder { +public: + ov::InferRequest m_hd_feature_transformer; + ov::InferRequest m_vision_projection; + // Used to insert <|image_i|>\n per image (not a slice). 
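+    // Reset to 0 when a chat starts or finishes, and after a generate() call outside of chat mode.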
+ size_t m_image_id = 1; + size_t m_tokens_per_image = 0; + + InputsEmbedderPhi3V( + const VLMConfig& vlm_config, + const std::filesystem::path& model_dir, + const std::string& device, + const ov::AnyMap device_config + ): + IInputsEmbedder(vlm_config, model_dir, device, device_config), m_image_id{0}, + m_hd_feature_transformer{phi3_v::create_hd_feature_transformer()}, + m_vision_projection{utils::singleton_core().compile_model(model_dir / "openvino_vision_projection_model.xml", device, {}).create_infer_request()} {} + + ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector& images, ov::genai::VLMPerfMetrics& metrics) override { + OPENVINO_ASSERT(images.empty() || m_history.empty(), "Images can only be provided for initial prompt"); + std::vector images_features_proj; + std::stringstream images_prompt; + for (const ov::Tensor& image : to_single_image_tensors(images)) { + EncodedImage encoded_image = m_vision_encoder.encode(image); + images_features_proj.push_back(phi3_v::hd_feature_transform(encoded_image, m_hd_feature_transformer, m_vlm_config.sub_GN, m_vlm_config.glb_GN, m_vision_projection)); + images_prompt << "<|image_" << m_image_id << "|>\n"; + ++m_image_id; + } + images_prompt << prompt; + std::vector new_chat_tokens; + std::vector prev_chat_tokens; + if (m_is_chat_conversation) { + m_history.push_back({{"role", "user"}, {"content", images_prompt.str()}}); + constexpr bool add_generation_prompt = true; + std::string new_templated_chat_history; + new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt); + auto start_tokenizer_time = std::chrono::steady_clock::now(); + new_chat_tokens = phi3_v::split_tokenize(new_templated_chat_history, m_tokenizer); + prev_chat_tokens = phi3_v::split_tokenize(m_templated_chat_history, m_tokenizer); + auto end_tokenizer_time = std::chrono::steady_clock::now(); + metrics.raw_metrics.tokenization_durations.emplace_back(PerfMetrics::get_microsec(end_tokenizer_time - start_tokenizer_time)); + m_templated_chat_history = std::move(new_templated_chat_history); + } else { + auto start_tokenizer_time = std::chrono::steady_clock::now(); + new_chat_tokens = phi3_v::split_tokenize(images_prompt.str(), m_tokenizer); + auto end_tokenizer_time = std::chrono::steady_clock::now(); + metrics.raw_metrics.tokenization_durations.emplace_back(PerfMetrics::get_microsec(end_tokenizer_time - start_tokenizer_time)); + } + if (0 == m_tokens_per_image && !images_features_proj.empty()) { + m_tokens_per_image = images_features_proj.at(0).get_shape().at(1); + } + ov::Tensor new_merged_tokens = phi3_v::insert_image_placeholders(new_chat_tokens, m_tokens_per_image); + ov::Tensor prev_merged_tokens = phi3_v::insert_image_placeholders(prev_chat_tokens, m_tokens_per_image); + ov::Tensor new_tokens = update_history(new_merged_tokens, prev_merged_tokens); + std::vector tokens = phi3_v::drop_image_placeholders(new_tokens); + OPENVINO_ASSERT(tokens.size() == images_features_proj.size() + 1); + size_t features_length = 0; + for (size_t im_id = 0; im_id < images_features_proj.size(); ++im_id) { + size_t text_length = tokens.at(im_id).get_shape().at(1); + size_t im_length = images_features_proj.at(im_id).get_shape().at(1); + OPENVINO_ASSERT(im_length == m_tokens_per_image); + features_length += text_length + im_length; + } + features_length += tokens.back().get_shape().at(1); + ov::Tensor inputs_embeds{ov::element::f32, {1, features_length, m_vlm_config.hidden_size}}; + size_t offset = 0; + for (size_t im_id = 0; im_id < 
images_features_proj.size(); ++im_id) { + const ov::Tensor& text_embeds = m_embedding.infer(tokens.at(im_id)); + const ov::Tensor& image_embeds = images_features_proj.at(im_id); + size_t text_length = text_embeds.get_shape().at(1); + size_t im_length = image_embeds.get_shape().at(1); + std::copy_n( + text_embeds.data(), + text_embeds.get_size(), + inputs_embeds.data() + offset * m_vlm_config.hidden_size + ); + offset += text_length; + std::copy_n( + image_embeds.data(), + image_embeds.get_size(), + inputs_embeds.data() + offset * m_vlm_config.hidden_size + ); + offset += im_length; + } + const ov::Tensor& text_embeds = m_embedding.infer(tokens.back()); + size_t text_length = text_embeds.get_shape().at(1); + std::copy_n( + text_embeds.data(), + text_embeds.get_size(), + inputs_embeds.data() + offset * m_vlm_config.hidden_size + ); + + if (!m_is_chat_conversation) { + m_image_id = 0; + } + + return inputs_embeds; + } + + virtual void start_chat(const std::string& system_message) override { + IInputsEmbedder::start_chat(system_message); + m_image_id = 0; + } + + virtual void finish_chat() override { + IInputsEmbedder::finish_chat(); + m_image_id = 0; + } +}; + class InputsEmbedderQwen2VL : public InputsEmbedder::IInputsEmbedder { // A model for merging image embeddings (hidden states), rotary_pos_emb and attension_mask. // Inputs: @@ -1577,6 +1980,8 @@ InputsEmbedder::InputsEmbedder(const VLMConfig& vlm_config, m_impl = std::make_shared(vlm_config, model_dir, device, device_config); } else if (vlm_config.model_type == VLMModelType::INTERNVL_CHAT) { m_impl = std::make_shared(vlm_config, model_dir, device, device_config); + } else if (vlm_config.model_type == VLMModelType::PHI3_V) { + m_impl = std::make_shared(vlm_config, model_dir, device, device_config); } else if (vlm_config.model_type == VLMModelType::QWEN2_VL) { m_impl = std::make_shared(vlm_config, model_dir, device, device_config); } else { diff --git a/src/cpp/src/visual_language/inputs_embedder.hpp b/src/cpp/src/visual_language/inputs_embedder.hpp index 223d090b22..4462c58185 100644 --- a/src/cpp/src/visual_language/inputs_embedder.hpp +++ b/src/cpp/src/visual_language/inputs_embedder.hpp @@ -68,6 +68,7 @@ class InputsEmbedder { friend class InputsEmbedderLLaVA; friend class InputsEmbedderLLaVANext; friend class InputsEmbedderInternVLChat; + friend class InputsEmbedderPhi3V; friend class InputsEmbedderQwen2VL; }; diff --git a/src/cpp/src/visual_language/processor_config.cpp b/src/cpp/src/visual_language/processor_config.cpp index f790c58912..527557061e 100644 --- a/src/cpp/src/visual_language/processor_config.cpp +++ b/src/cpp/src/visual_language/processor_config.cpp @@ -41,6 +41,10 @@ ov::genai::ProcessorConfig::ProcessorConfig(const std::filesystem::path& json_pa if (parsed.contains("image_grid_pinpoints")) { image_grid_pinpoints = parsed.at("image_grid_pinpoints").get>>(); } + read_json_param(parsed, "num_crops", phi3_v.num_crops); + if (parsed.contains("img_processor")) { + phi3_v.num_img_tokens = parsed.at("img_processor").at("num_img_tokens"); + } // Setting qwen2vl config params read_json_param(parsed, "min_pixels", min_pixels); diff --git a/src/cpp/src/visual_language/processor_config.hpp b/src/cpp/src/visual_language/processor_config.hpp index 1d40e091a9..1c4db59fd9 100644 --- a/src/cpp/src/visual_language/processor_config.hpp +++ b/src/cpp/src/visual_language/processor_config.hpp @@ -35,9 +35,10 @@ class ProcessorConfig { /// llava calls it image_std. 
std::array norm_std{1.0f, 1.0f, 1.0f}; - // llava specific config params + // A renamed version of norm_mean. std::array image_mean{0.0f, 0.0f, 0.0f}; std::array image_std{1.0f, 1.0f, 1.0f}; + // llava specific config params size_t crop_size_height = 336; size_t crop_size_width = 336; size_t size_shortest_edge = 336; @@ -45,6 +46,10 @@ class ProcessorConfig { // llava-next specific config params std::vector> image_grid_pinpoints{{336, 672}, {672, 336}, {672, 672}, {1008, 336}, {336, 1008}}; + struct { + size_t num_crops = 4; + size_t num_img_tokens = 144; + } phi3_v; // qwen2vl specific params size_t min_pixels = 3136; size_t max_pixels = 12845056; diff --git a/src/cpp/src/visual_language/vision_encoder.cpp b/src/cpp/src/visual_language/vision_encoder.cpp index 4a5179fdd0..04ddd63145 100644 --- a/src/cpp/src/visual_language/vision_encoder.cpp +++ b/src/cpp/src/visual_language/vision_encoder.cpp @@ -645,6 +645,202 @@ ov::Tensor get_pixel_values_internvl(const ov::Tensor& image, const ProcessorCon return output_tensor; } +namespace phi3_v { +constexpr size_t INPUT_IMAGE_SIZE = 336; + +ov::Tensor padding_336(const ov::Tensor& unpadded) { + ov::Shape _1ss3 = unpadded.get_shape(); + size_t s1 = _1ss3.at(1), s2 = _1ss3.at(2); + if (s1 < s2) { + size_t tar = size_t(std::ceil(float(s1) / INPUT_IMAGE_SIZE) * INPUT_IMAGE_SIZE); + size_t top_padding = (tar - s1) / 2; + ov::Tensor padded{ov::element::u8, {1, tar, s2, 3}}; + uint8_t* padded_data = padded.data(); + std::fill_n(padded_data, padded.get_size(), 255); + std::copy_n(unpadded.data(), unpadded.get_size(), padded_data + top_padding * s2 * 3); + return padded; + } + size_t tar = size_t(std::ceil(float(s2) / INPUT_IMAGE_SIZE) * INPUT_IMAGE_SIZE); + size_t left_padding = (tar - s2) / 2; + ov::Tensor padded{ov::element::u8, {1, s1, tar, 3}}; + uint8_t* padded_data = padded.data(); + std::fill_n(padded_data, padded.get_size(), 255); + uint8_t* unpadded_data = unpadded.data(); + for (size_t row = 0; row < s1; ++row) { + std::copy_n(unpadded_data + row * s2 * 3, s2 * 3, padded_data + row * tar * 3 + left_padding * 3); + } + return padded; +} + +ov::Tensor HD_transform(const ov::Tensor& uint8, size_t num_crops) { + ov::Shape _1hwc = uint8.get_shape(); + size_t height = _1hwc.at(1), width = _1hwc.at(2); + bool trans = false; + if (width < height) { + std::swap(height, width); + trans = true; + } + float ratio = float(width) / height; + unsigned scale = 1; + while (scale * std::ceil(scale / ratio) <= num_crops) { + ++scale; + } + --scale; + size_t new_w = scale * INPUT_IMAGE_SIZE; + size_t new_h = new_w / ratio; + clip_image_u8 src{}, dst{}; + uint8_t* uint8_data = uint8.data(); + if (trans) { + src = clip_image_u8{int(height), int(width), {uint8_data, uint8_data + uint8.get_size()}}; + bilinear_resize(src, dst, new_h, new_w); + return padding_336(ov::Tensor{ov::element::u8, {1, new_w, new_h, 3}, dst.buf.data()}); + } + src = clip_image_u8{int(width), int(height), {uint8_data, uint8_data + uint8.get_size()}}; + bilinear_resize(src, dst, new_w, new_h); + return padding_336(ov::Tensor{ov::element::u8, {1, new_h, new_w, 3}, dst.buf.data()}); +} + +ov::Tensor mean_scale(const ov::Tensor& uint8, const ProcessorConfig& config) { + uint8_t* uint_8_data = uint8.data(); + ov::Tensor float_normalized{ov::element::f32, uint8.get_shape()}; + float* float_data = float_normalized.data(); + OPENVINO_ASSERT(0 == uint8.get_size() % 3, "RGB"); + for (size_t idx = 0; idx < uint8.get_size(); idx += 3) { + float_data[idx] = (float(uint_8_data[idx]) / 255.0f - 
config.image_mean[0]) / config.image_std[0]; + float_data[idx + 1] = (float(uint_8_data[idx + 1]) / 255.0f - config.image_mean[1]) / config.image_std[1]; + float_data[idx + 2] = (float(uint_8_data[idx + 2]) / 255.0f - config.image_mean[2]) / config.image_std[2]; + } + return float_normalized; +} + +ov::Tensor channels_first(const ov::Tensor& _1hw3) { + ov::Shape shape = _1hw3.get_shape(); + ov::Tensor _13hw = ov::Tensor{ov::element::f32, {1, 3, shape.at(1), shape.at(2)}}; + float* _1hw3_data = _1hw3.data(); + float* _13hw_data = _13hw.data(); + for (size_t plane = 0; plane < 3; ++plane) { + for (size_t row = 0; row < shape.at(1); ++row) { + for (size_t col = 0; col < shape.at(2); ++col) { + _13hw_data[plane * shape.at(1) * shape.at(2) + row * shape.at(2) + col] = _1hw3_data[row * shape.at(2) * 3 + col * 3 + plane]; + } + } + } + return _13hw; +} + +// Reimplementation of Python im.reshape(1, 3, h//336, 336, w//336, 336).permute(0,2,4,1,3,5).reshape(-1, 3, 336, 336) +ov::Tensor slice_image(const ov::Tensor& image) { + ov::Shape shape = image.get_shape(); + size_t N = shape[0]; + size_t C = shape[1]; + size_t H = shape[2]; + size_t W = shape[3]; + + size_t num_h_slices = H / INPUT_IMAGE_SIZE; + size_t num_w_slices = W / INPUT_IMAGE_SIZE; + + // Step 1: Define and populate the reshaped tensor in the correct shape order + ov::Tensor reshaped{ov::element::f32, {N, num_h_slices, num_w_slices, C, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE}}; + float* reshaped_data = reshaped.data(); + float* image_data = image.data(); + + // Populate the reshaped tensor + for (size_t n = 0; n < N; ++n) { + for (size_t h = 0; h < num_h_slices; ++h) { + for (size_t w = 0; w < num_w_slices; ++w) { + for (size_t c = 0; c < C; ++c) { + for (size_t i = 0; i < INPUT_IMAGE_SIZE; ++i) { + for (size_t j = 0; j < INPUT_IMAGE_SIZE; ++j) { + size_t src_idx = n * C * H * W + c * H * W + (h * INPUT_IMAGE_SIZE + i) * W + (w * INPUT_IMAGE_SIZE + j); + size_t dst_idx = n * num_h_slices * num_w_slices * C * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE + + h * num_w_slices * C * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE + + w * C * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE + + c * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE + + i * INPUT_IMAGE_SIZE + j; + reshaped_data[dst_idx] = image_data[src_idx]; + } + } + } + } + } + } + + // Step 2: Define the permuted tensor in the final shape + ov::Tensor permuted{ov::element::f32, {N * num_h_slices * num_w_slices, C, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE}}; + float* permuted_data = permuted.data(); + + // Perform permutation by flattening N, num_h_slices, and num_w_slices + for (size_t n = 0; n < N; ++n) { + for (size_t h = 0; h < num_h_slices; ++h) { + for (size_t w = 0; w < num_w_slices; ++w) { + for (size_t c = 0; c < C; ++c) { + for (size_t i = 0; i < INPUT_IMAGE_SIZE; ++i) { + for (size_t j = 0; j < INPUT_IMAGE_SIZE; ++j) { + size_t src_idx = n * num_h_slices * num_w_slices * C * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE + + h * num_w_slices * C * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE + + w * C * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE + + c * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE + + i * INPUT_IMAGE_SIZE + j; + size_t dst_idx = (n * num_h_slices * num_w_slices + h * num_w_slices + w) * C * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE + + c * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE + + i * INPUT_IMAGE_SIZE + j; + permuted_data[dst_idx] = reshaped_data[src_idx]; + } + } + } + } + } + } + + return permuted; +} + +ov::Tensor concatenate_batch(const ov::Tensor& float_first, const ov::Tensor& float_second) { + ov::Shape shape_first = 
float_first.get_shape(); + ov::Shape shape_second = float_second.get_shape(); + OPENVINO_ASSERT(shape_first.at(1) == shape_second.at(1), "Channels must be the same"); + OPENVINO_ASSERT(shape_first.at(2) == shape_second.at(2), "Height must be the same"); + OPENVINO_ASSERT(shape_first.at(3) == shape_second.at(3), "Width must be the same"); + ov::Tensor concatenated{ov::element::f32, {shape_first.at(0) + shape_second.at(0), shape_first.at(1), shape_first.at(2), shape_first.at(3)}}; + float* concatenated_data = concatenated.data(); + float* first_data = float_first.data(); + float* second_data = float_second.data(); + std::copy(first_data, first_data + float_first.get_size(), concatenated_data); + std::copy(second_data, second_data + float_second.get_size(), concatenated_data + float_first.get_size()); + return concatenated; +} + +ov::Tensor pad_to_max_num_crops_tensor(const ov::Tensor& nchw, size_t max_crops) { + ov::Shape shape = nchw.get_shape(); + size_t num_crops = shape[0]; + if (num_crops >= max_crops) { + return nchw; + } + ov::Tensor padded{ov::element::f32, {max_crops, shape[1], shape[2], shape[3]}}; + float* padded_data = padded.data(); + float* nchw_data = nchw.data(); + std::copy_n(nchw_data, nchw.get_size(), padded_data); + return padded; +} + +std::tuple get_pixel_values_phi3_v(const ov::Tensor& image, const ProcessorConfig& config) { + ov::Tensor hd_image = HD_transform(image, config.phi3_v.num_crops); + ImageSize image_size{hd_image.get_shape().at(2), hd_image.get_shape().at(1)}; + clip_image_u8 img{int(hd_image.get_shape().at(2)), int(hd_image.get_shape().at(1)), {hd_image.data(), hd_image.data() + hd_image.get_size()}}; + clip_image_u8 dst; + bicubic_resize(img, dst, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE); + ov::Tensor global_image{ov::element::u8, {1, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE, 3}, dst.buf.data()}; + global_image = mean_scale(global_image, config); + hd_image = mean_scale(hd_image, config); + global_image = channels_first(global_image); + hd_image = channels_first(hd_image); + ov::Tensor slices = slice_image(hd_image); + ov::Tensor concatenated = concatenate_batch(global_image, slices); + ov::Tensor pixel_values = pad_to_max_num_crops_tensor(concatenated, config.phi3_v.num_crops); + return {std::move(pixel_values), image_size}; +} +} // namespace phi3_v + ImageSize smart_resize_qwen2vl(size_t height, size_t width, size_t factor, size_t min_pixels, size_t max_pixels) { if (height < factor || width < factor) { OPENVINO_THROW("Height or width must be larger than factor"); @@ -832,6 +1028,8 @@ EncodedImage VisionEncoder::encode(const ov::Tensor& image, const ProcessorConfi return encode_llava_next(image, config); } else if (model_type == VLMModelType::INTERNVL_CHAT) { return encode_internvl(image, config); + } else if (model_type == VLMModelType::PHI3_V) { + return encode_phi3_v(image, config); } else if (model_type == VLMModelType::QWEN2_VL) { return encode_qwen2vl(image, config); } else { @@ -908,6 +1106,13 @@ EncodedImage VisionEncoder::encode_internvl(const ov::Tensor& image, const Proce return {std::move(image_features), resized_source_size}; } +EncodedImage VisionEncoder::encode_phi3_v(const ov::Tensor& image, const ProcessorConfig& config) { + const auto& [pixel_values, image_size] = phi3_v::get_pixel_values_phi3_v(image, config); + m_vision_encoder.set_input_tensor(pixel_values); + m_vision_encoder.infer(); + return {m_vision_encoder.get_output_tensor(), image_size}; +} + EncodedImage VisionEncoder::encode_qwen2vl(const ov::Tensor& image, const ProcessorConfig& 
config) { ov::Shape image_shape = image.get_shape(); auto original_height = image_shape.at(1); diff --git a/src/cpp/src/visual_language/vision_encoder.hpp b/src/cpp/src/visual_language/vision_encoder.hpp index e725c06bf4..8bec971894 100644 --- a/src/cpp/src/visual_language/vision_encoder.hpp +++ b/src/cpp/src/visual_language/vision_encoder.hpp @@ -159,6 +159,10 @@ class VisionEncoder { const ov::Tensor& image, const ProcessorConfig& config ); + EncodedImage encode_phi3_v( + const ov::Tensor& image, const ProcessorConfig& config + ); + EncodedImage encode_qwen2vl( const ov::Tensor& image, const ProcessorConfig& config ); diff --git a/src/cpp/src/visual_language/vlm_config.cpp b/src/cpp/src/visual_language/vlm_config.cpp index 6eab781fc0..5609c886c4 100644 --- a/src/cpp/src/visual_language/vlm_config.cpp +++ b/src/cpp/src/visual_language/vlm_config.cpp @@ -19,4 +19,13 @@ ov::genai::VLMConfig::VLMConfig(const std::filesystem::path& json_path) { // Setting llava_next specific config params read_json_param(parsed, "image_newline", image_newline); + // phi3_v + if (parsed.contains("sub_GN")) { + sub_GN = parsed.at("sub_GN").get>>>>().at(0).at(0).at(0); + } + OPENVINO_ASSERT(sub_GN.size() == 4096); + if (parsed.contains("glb_GN")) { + glb_GN = parsed.at("glb_GN").get>>>().at(0).at(0); + } + OPENVINO_ASSERT(glb_GN.size() == 4096); } diff --git a/src/cpp/src/visual_language/vlm_config.hpp b/src/cpp/src/visual_language/vlm_config.hpp index c70c757707..7a052b8537 100644 --- a/src/cpp/src/visual_language/vlm_config.hpp +++ b/src/cpp/src/visual_language/vlm_config.hpp @@ -54,6 +54,9 @@ class VLMConfig { std::string image_context_token = ""; /// @brief A string token denoting end of image embeddings for InternVL2 model. std::string image_end_token = ""; + /// @brief phi3_v new line token embedding to separate images. + std::vector sub_GN = std::vector(4096, 0.0f); + std::vector glb_GN = std::vector(4096, 0.0f); /// @brief A string token denoting start of vision embeddings for Qwen2VL model. 
std::string vision_start_token = "<|vision_start|>"; diff --git a/src/cpp/src/visual_language/vlm_model_type.hpp b/src/cpp/src/visual_language/vlm_model_type.hpp index 6f554fbf98..93387cacbc 100644 --- a/src/cpp/src/visual_language/vlm_model_type.hpp +++ b/src/cpp/src/visual_language/vlm_model_type.hpp @@ -16,6 +16,7 @@ enum class VLMModelType { LLAVA, LLAVA_NEXT, INTERNVL_CHAT, + PHI3_V, QWEN2_VL, }; @@ -25,6 +26,7 @@ inline VLMModelType to_vlm_model_type(const std::string& value) { {"llava", VLMModelType::LLAVA}, {"llava_next", VLMModelType::LLAVA_NEXT}, {"internvl_chat", VLMModelType::INTERNVL_CHAT}, + {"phi3_v", VLMModelType::PHI3_V}, {"qwen2_vl", VLMModelType::QWEN2_VL} }; diff --git a/tests/python_tests/test_vlm_pipeline.py b/tests/python_tests/test_vlm_pipeline.py index b413b6cf1d..0f9358b961 100644 --- a/tests/python_tests/test_vlm_pipeline.py +++ b/tests/python_tests/test_vlm_pipeline.py @@ -9,17 +9,17 @@ from openvino_genai import VLMPipeline, GenerationConfig from common import get_image_by_link, get_beam_search, get_multinomial_all_parameters, get_default_properties -def get_ov_model(cache): - model_dir = cache.mkdir("tiny-random-minicpmv-2_6") +def get_ov_model(model_id, cache): + model_dir = cache.mkdir(model_id.split('/')[-1]) if (model_dir / "openvino_language_model.xml").exists(): return model_dir - model_id = "katuni4ka/tiny-random-minicpmv-2_6" processor = transformers.AutoProcessor.from_pretrained(model_id, trust_remote_code=True) processor.tokenizer.save_pretrained(model_dir) ov_tokenizer, ov_detokenizer = openvino_tokenizers.convert_tokenizer(processor.tokenizer, with_detokenizer=True) openvino.save_model(ov_tokenizer, model_dir / "openvino_tokenizer.xml") openvino.save_model(ov_detokenizer, model_dir / "openvino_detokenizer.xml") model = OVModelForVisualCausalLM.from_pretrained(model_id, compile=False, device="CPU", export=True, load_in_8bit=False, trust_remote_code=True, ov_config=get_default_properties()) + processor.chat_template = processor.tokenizer.chat_template # It seems that tiny-random-phi3-vision is saved incorrectly. That line works this around. 
processor.save_pretrained(model_dir) model.save_pretrained(model_dir) return model_dir @@ -44,13 +44,17 @@ def get_ov_model(cache): @pytest.mark.precommit @pytest.mark.nightly -def test_vlm_pipeline(cache): +@pytest.mark.parametrize("model_id", [ + "katuni4ka/tiny-random-minicpmv-2_6", + "katuni4ka/tiny-random-phi3-vision", +]) +def test_vlm_pipeline(model_id, cache): def streamer(word: str) -> bool: nonlocal result_from_streamer result_from_streamer.append(word) return False - models_path = get_ov_model(cache) + models_path = get_ov_model(model_id, cache) generation_config = GenerationConfig(max_new_tokens=30) for links in image_links_for_testing: @@ -76,7 +80,7 @@ def streamer(word: str) -> bool: @pytest.mark.precommit @pytest.mark.nightly def test_vlm_get_tokenizer(cache): - models_path = get_ov_model(cache) + models_path = get_ov_model("katuni4ka/tiny-random-minicpmv-2_6", cache) pipe = VLMPipeline(models_path, "CPU") tokenizer = pipe.get_tokenizer() tokenizer.encode("") @@ -89,15 +93,16 @@ def test_vlm_get_tokenizer(cache): get_multinomial_all_parameters(), ]) def test_sampling(config, cache): - models_path = get_ov_model(cache) + models_path = get_ov_model("katuni4ka/tiny-random-minicpmv-2_6", cache) image = get_image_by_link(image_links[0]) pipe = VLMPipeline(models_path, "CPU") pipe.generate(prompts[0], image=image, generation_config=config) @pytest.mark.precommit +@pytest.mark.nightly def test_perf_metrics(cache): import numpy as np - models_path = get_ov_model(cache) + models_path = get_ov_model("katuni4ka/tiny-random-minicpmv-2_6", cache) images = [get_image_by_link(image_links[0])]