Add Phi-3.5-vision-instruct and Phi-3-vision-128k-instruct (#1609)
Ticket 156662

---------

Co-authored-by: Ilya Lavrenov <[email protected]>
Wovchena and ilya-lavrenov authored Jan 23, 2025
1 parent af41f9c commit aa63ad9
Showing 14 changed files with 703 additions and 38 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/mac.yml
@@ -17,7 +17,7 @@ concurrency:

env:
PYTHON_VERSION: '3.10'
OV_BRANCH: 'master'
OV_BRANCH: 7f56fcd4658c6a427111ac835e809ddd87f0cad2
OV_TARBALL: ''

jobs:
21 changes: 21 additions & 0 deletions SUPPORTED_MODELS.md
@@ -312,6 +312,7 @@ In addition to image generation models, `InpaintingPipeline` supports specialize
<th>Models</th>
<th>LoRA support</th>
<th>Example HuggingFace Models</th>
<th>Notes</th>
</tr>
<tr>
<td><code>InternVL2</code></td>
@@ -329,6 +330,7 @@ In addition to image generation models, `InpaintingPipeline` supports specialize
<li><a href="https://huggingface.co/OpenGVLab/InternVL2_5-8B"><code>OpenGVLab/InternVL2_5-8B</code></a></li>
</ul>
</td>
<td></td>
</tr>
<tr>
<td><code>LLaVA</code></td>
@@ -339,6 +341,7 @@ In addition to image generation models, `InpaintingPipeline` supports specialize
<li><a href="https://huggingface.co/llava-hf/llava-1.5-7b-hf"><code>llava-hf/llava-1.5-7b-hf</code></a></li>
</ul>
</td>
<td></td>
</tr>
<tr>
<td><code>LLaVA-NeXT</code></td>
@@ -351,6 +354,7 @@ In addition to image generation models, `InpaintingPipeline` supports specialize
<li><a href="https://huggingface.co/llava-hf/llama3-llava-next-8b-hf"><code>llava-hf/llama3-llava-next-8b-hf</code></a></li>
</ul>
</td>
<td></td>
</tr>
<tr>
<td><code>MiniCPMV</code></td>
@@ -361,6 +365,22 @@ In addition to image generation models, `InpaintingPipeline` supports specialize
<li><a href="https://huggingface.co/openbmb/MiniCPM-V-2_6"><code>openbmb/MiniCPM-V-2_6</code></a></li>
</ul>
</td>
<td></td>
</tr>
<tr>
<td><code>Phi3VForCausalLM</code></td>
<td>phi3_v</td>
<td>Not supported</td>
<td>
<ul>
<li><a href="https://huggingface.co/microsoft/Phi-3-vision-128k-instruct"><code>microsoft/Phi-3-vision-128k-instruct</code></a></li>
<li><a href="https://huggingface.co/microsoft/Phi-3.5-vision-instruct"><code>microsoft/Phi-3.5-vision-instruct</code></a></li>
</ul>
</td>
<td>
<ul>
<li>GPU is not supported.</li>
<li>These models' configs are inconsistent, so the default <code>eos_token_id</code> must be overridden with the value from the tokenizer: <code>generation_config.set_eos_token_id(pipe.get_tokenizer().get_eos_token_id())</code>. See the sketch after the table.</li>
</ul>
</td>
</tr>
<tr>
<td><code>Qwen2-VL</code></td>
@@ -372,6 +392,7 @@ In addition to image generation models, `InpaintingPipeline` supports specialize
<li><a href="https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct"><code>Qwen/Qwen2-VL-7B-Instruct</code></a></li>
</ul>
</td>
<td></td>
</tr>
</tbody>
</table>
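The `eos_token_id` workaround from the notes above can be illustrated with a minimal C++ sketch. It assumes the `ov::genai::VLMPipeline` API used in the OpenVINO GenAI samples; the model directory path, the prompt, and the solid-gray dummy image are placeholders, not part of this change.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

#include "openvino/openvino.hpp"
#include "openvino/genai/visual_language/pipeline.hpp"

int main() {
    // Placeholder path to a model exported with optimum-intel. GPU is not supported for phi3_v.
    ov::genai::VLMPipeline pipe("./Phi-3.5-vision-instruct", "CPU");

    // Work around the inconsistent model config: take eos_token_id from the tokenizer.
    ov::genai::GenerationConfig generation_config = pipe.get_generation_config();
    generation_config.set_eos_token_id(pipe.get_tokenizer().get_eos_token_id());
    generation_config.max_new_tokens = 100;

    // Dummy solid-gray {1, H, W, 3} u8 image to keep the sketch self-contained;
    // a real application would decode an image file into this tensor instead.
    ov::Tensor image(ov::element::u8, ov::Shape{1, 336, 336, 3});
    std::fill_n(image.data<uint8_t>(), image.get_size(), uint8_t{128});

    pipe.generate("Describe the image.",
                  ov::genai::image(image),
                  ov::genai::generation_config(generation_config),
                  ov::genai::streamer([](std::string word) {
                      std::cout << word << std::flush;
                      return false;  // do not stop generation early
                  }));
    std::cout << '\n';
    return 0;
}
```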
2 changes: 1 addition & 1 deletion src/cpp/src/visual_language/clip.cpp
@@ -12,7 +12,7 @@ static float clip_lerp(float s, float e, float t) {
}

// Bilinear resize function
static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) {
void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) {
dst.nx = target_width;
dst.ny = target_height;
dst.buf.resize(3 * target_width * target_height);
1 change: 1 addition & 0 deletions src/cpp/src/visual_language/clip.hpp
@@ -31,6 +31,7 @@ struct clip_image_f32 {
};

void bicubic_resize(const clip_image_u8& img, clip_image_u8& dst, int target_width, int target_height);
void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height);

/** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model configuration */
clip_image_f32 clip_image_preprocess(struct clip_ctx& ctx, const clip_image_u8& img);
459 changes: 432 additions & 27 deletions src/cpp/src/visual_language/inputs_embedder.cpp

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/cpp/src/visual_language/inputs_embedder.hpp
@@ -68,6 +68,7 @@ class InputsEmbedder {
friend class InputsEmbedderLLaVA;
friend class InputsEmbedderLLaVANext;
friend class InputsEmbedderInternVLChat;
friend class InputsEmbedderPhi3V;
friend class InputsEmbedderQwen2VL;
};

4 changes: 4 additions & 0 deletions src/cpp/src/visual_language/processor_config.cpp
@@ -41,6 +41,10 @@ ov::genai::ProcessorConfig::ProcessorConfig(const std::filesystem::path& json_pa
if (parsed.contains("image_grid_pinpoints")) {
image_grid_pinpoints = parsed.at("image_grid_pinpoints").get<std::vector<std::pair<int, int>>>();
}
read_json_param(parsed, "num_crops", phi3_v.num_crops);
if (parsed.contains("img_processor")) {
phi3_v.num_img_tokens = parsed.at("img_processor").at("num_img_tokens");
}

// Setting qwen2vl config params
read_json_param(parsed, "min_pixels", min_pixels);
7 changes: 6 additions & 1 deletion src/cpp/src/visual_language/processor_config.hpp
@@ -35,16 +35,21 @@ class ProcessorConfig {
/// llava calls it image_std.
std::array<float, 3> norm_std{1.0f, 1.0f, 1.0f};

// llava specific config params
// A renamed version of norm_mean.
std::array<float, 3> image_mean{0.0f, 0.0f, 0.0f};
std::array<float, 3> image_std{1.0f, 1.0f, 1.0f};
// llava specific config params
size_t crop_size_height = 336;
size_t crop_size_width = 336;
size_t size_shortest_edge = 336;

// llava-next specific config params
std::vector<std::pair<int, int>> image_grid_pinpoints{{336, 672}, {672, 336}, {672, 672}, {1008, 336}, {336, 1008}};

struct {
size_t num_crops = 4;
size_t num_img_tokens = 144;
} phi3_v;
// qwen2vl specific params
size_t min_pixels = 3136;
size_t max_pixels = 12845056;
205 changes: 205 additions & 0 deletions src/cpp/src/visual_language/vision_encoder.cpp
@@ -645,6 +645,202 @@ ov::Tensor get_pixel_values_internvl(const ov::Tensor& image, const ProcessorCon
return output_tensor;
}

namespace phi3_v {
constexpr size_t INPUT_IMAGE_SIZE = 336;

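// Pads the shorter spatial dimension of a {1, H, W, 3} u8 image up to the next multiple of
// INPUT_IMAGE_SIZE with white (255) pixels, keeping the original content centered.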
ov::Tensor padding_336(const ov::Tensor& unpadded) {
ov::Shape _1ss3 = unpadded.get_shape();
size_t s1 = _1ss3.at(1), s2 = _1ss3.at(2);
if (s1 < s2) {
size_t tar = size_t(std::ceil(float(s1) / INPUT_IMAGE_SIZE) * INPUT_IMAGE_SIZE);
size_t top_padding = (tar - s1) / 2;
ov::Tensor padded{ov::element::u8, {1, tar, s2, 3}};
uint8_t* padded_data = padded.data<uint8_t>();
std::fill_n(padded_data, padded.get_size(), 255);
std::copy_n(unpadded.data<uint8_t>(), unpadded.get_size(), padded_data + top_padding * s2 * 3);
return padded;
}
size_t tar = size_t(std::ceil(float(s2) / INPUT_IMAGE_SIZE) * INPUT_IMAGE_SIZE);
size_t left_padding = (tar - s2) / 2;
ov::Tensor padded{ov::element::u8, {1, s1, tar, 3}};
uint8_t* padded_data = padded.data<uint8_t>();
std::fill_n(padded_data, padded.get_size(), 255);
uint8_t* unpadded_data = unpadded.data<uint8_t>();
for (size_t row = 0; row < s1; ++row) {
std::copy_n(unpadded_data + row * s2 * 3, s2 * 3, padded_data + row * tar * 3 + left_padding * 3);
}
return padded;
}

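// Resizes the image, preserving its aspect ratio, so that it tiles into at most num_crops
// 336x336 patches (portrait inputs are handled by swapping width and height internally),
// then rounds the shorter side up to a multiple of 336 via padding_336().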
ov::Tensor HD_transform(const ov::Tensor& uint8, size_t num_crops) {
ov::Shape _1hwc = uint8.get_shape();
size_t height = _1hwc.at(1), width = _1hwc.at(2);
bool trans = false;
if (width < height) {
std::swap(height, width);
trans = true;
}
float ratio = float(width) / height;
unsigned scale = 1;
while (scale * std::ceil(scale / ratio) <= num_crops) {
++scale;
}
--scale;
size_t new_w = scale * INPUT_IMAGE_SIZE;
size_t new_h = new_w / ratio;
clip_image_u8 src{}, dst{};
uint8_t* uint8_data = uint8.data<uint8_t>();
if (trans) {
src = clip_image_u8{int(height), int(width), {uint8_data, uint8_data + uint8.get_size()}};
bilinear_resize(src, dst, new_h, new_w);
return padding_336(ov::Tensor{ov::element::u8, {1, new_w, new_h, 3}, dst.buf.data()});
}
src = clip_image_u8{int(width), int(height), {uint8_data, uint8_data + uint8.get_size()}};
bilinear_resize(src, dst, new_w, new_h);
return padding_336(ov::Tensor{ov::element::u8, {1, new_h, new_w, 3}, dst.buf.data()});
}

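// Converts interleaved u8 RGB values to f32 and normalizes each channel with image_mean and image_std.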
ov::Tensor mean_scale(const ov::Tensor& uint8, const ProcessorConfig& config) {
uint8_t* uint_8_data = uint8.data<uint8_t>();
ov::Tensor float_normalized{ov::element::f32, uint8.get_shape()};
float* float_data = float_normalized.data<float>();
OPENVINO_ASSERT(0 == uint8.get_size() % 3, "RGB");
for (size_t idx = 0; idx < uint8.get_size(); idx += 3) {
float_data[idx] = (float(uint_8_data[idx]) / 255.0f - config.image_mean[0]) / config.image_std[0];
float_data[idx + 1] = (float(uint_8_data[idx + 1]) / 255.0f - config.image_mean[1]) / config.image_std[1];
float_data[idx + 2] = (float(uint_8_data[idx + 2]) / 255.0f - config.image_mean[2]) / config.image_std[2];
}
return float_normalized;
}

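// Transposes a {1, H, W, 3} (NHWC) float tensor into {1, 3, H, W} (NCHW).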
ov::Tensor channels_first(const ov::Tensor& _1hw3) {
ov::Shape shape = _1hw3.get_shape();
ov::Tensor _13hw = ov::Tensor{ov::element::f32, {1, 3, shape.at(1), shape.at(2)}};
float* _1hw3_data = _1hw3.data<float>();
float* _13hw_data = _13hw.data<float>();
for (size_t plane = 0; plane < 3; ++plane) {
for (size_t row = 0; row < shape.at(1); ++row) {
for (size_t col = 0; col < shape.at(2); ++col) {
_13hw_data[plane * shape.at(1) * shape.at(2) + row * shape.at(2) + col] = _1hw3_data[row * shape.at(2) * 3 + col * 3 + plane];
}
}
}
return _13hw;
}

// Reimplementation of Python im.reshape(1, 3, h//336, 336, w//336, 336).permute(0,2,4,1,3,5).reshape(-1, 3, 336, 336)
ov::Tensor slice_image(const ov::Tensor& image) {
ov::Shape shape = image.get_shape();
size_t N = shape[0];
size_t C = shape[1];
size_t H = shape[2];
size_t W = shape[3];

size_t num_h_slices = H / INPUT_IMAGE_SIZE;
size_t num_w_slices = W / INPUT_IMAGE_SIZE;

// Step 1: Define and populate the reshaped tensor in the correct shape order
ov::Tensor reshaped{ov::element::f32, {N, num_h_slices, num_w_slices, C, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE}};
float* reshaped_data = reshaped.data<float>();
float* image_data = image.data<float>();

// Populate the reshaped tensor
for (size_t n = 0; n < N; ++n) {
for (size_t h = 0; h < num_h_slices; ++h) {
for (size_t w = 0; w < num_w_slices; ++w) {
for (size_t c = 0; c < C; ++c) {
for (size_t i = 0; i < INPUT_IMAGE_SIZE; ++i) {
for (size_t j = 0; j < INPUT_IMAGE_SIZE; ++j) {
size_t src_idx = n * C * H * W + c * H * W + (h * INPUT_IMAGE_SIZE + i) * W + (w * INPUT_IMAGE_SIZE + j);
size_t dst_idx = n * num_h_slices * num_w_slices * C * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE +
h * num_w_slices * C * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE +
w * C * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE +
c * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE +
i * INPUT_IMAGE_SIZE + j;
reshaped_data[dst_idx] = image_data[src_idx];
}
}
}
}
}
}

// Step 2: Define the permuted tensor in the final shape
ov::Tensor permuted{ov::element::f32, {N * num_h_slices * num_w_slices, C, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE}};
float* permuted_data = permuted.data<float>();

// Perform permutation by flattening N, num_h_slices, and num_w_slices
for (size_t n = 0; n < N; ++n) {
for (size_t h = 0; h < num_h_slices; ++h) {
for (size_t w = 0; w < num_w_slices; ++w) {
for (size_t c = 0; c < C; ++c) {
for (size_t i = 0; i < INPUT_IMAGE_SIZE; ++i) {
for (size_t j = 0; j < INPUT_IMAGE_SIZE; ++j) {
size_t src_idx = n * num_h_slices * num_w_slices * C * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE +
h * num_w_slices * C * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE +
w * C * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE +
c * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE +
i * INPUT_IMAGE_SIZE + j;
size_t dst_idx = (n * num_h_slices * num_w_slices + h * num_w_slices + w) * C * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE +
c * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE +
i * INPUT_IMAGE_SIZE + j;
permuted_data[dst_idx] = reshaped_data[src_idx];
}
}
}
}
}
}

return permuted;
}

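// Concatenates two NCHW tensors along the batch (crop) dimension.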
ov::Tensor concatenate_batch(const ov::Tensor& float_first, const ov::Tensor& float_second) {
ov::Shape shape_first = float_first.get_shape();
ov::Shape shape_second = float_second.get_shape();
OPENVINO_ASSERT(shape_first.at(1) == shape_second.at(1), "Channels must be the same");
OPENVINO_ASSERT(shape_first.at(2) == shape_second.at(2), "Height must be the same");
OPENVINO_ASSERT(shape_first.at(3) == shape_second.at(3), "Width must be the same");
ov::Tensor concatenated{ov::element::f32, {shape_first.at(0) + shape_second.at(0), shape_first.at(1), shape_first.at(2), shape_first.at(3)}};
float* concatenated_data = concatenated.data<float>();
float* first_data = float_first.data<float>();
float* second_data = float_second.data<float>();
std::copy(first_data, first_data + float_first.get_size(), concatenated_data);
std::copy(second_data, second_data + float_second.get_size(), concatenated_data + float_first.get_size());
return concatenated;
}

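// Grows the crop (batch) dimension to max_crops by allocating a larger tensor and copying the
// existing crops to the front; the extra crop slots are left uninitialized.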
ov::Tensor pad_to_max_num_crops_tensor(const ov::Tensor& nchw, size_t max_crops) {
ov::Shape shape = nchw.get_shape();
size_t num_crops = shape[0];
if (num_crops >= max_crops) {
return nchw;
}
ov::Tensor padded{ov::element::f32, {max_crops, shape[1], shape[2], shape[3]}};
float* padded_data = padded.data<float>();
float* nchw_data = nchw.data<float>();
std::copy_n(nchw_data, nchw.get_size(), padded_data);
return padded;
}

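// End-to-end phi3_v preprocessing: HD-transform the image, build a 336x336 global view,
// normalize both, convert to NCHW, slice the HD image into 336x336 crops, prepend the global
// view, and pad the crop dimension to num_crops. Returns the pixel values together with the
// size of the HD-transformed image.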
std::tuple<ov::Tensor, ImageSize> get_pixel_values_phi3_v(const ov::Tensor& image, const ProcessorConfig& config) {
ov::Tensor hd_image = HD_transform(image, config.phi3_v.num_crops);
ImageSize image_size{hd_image.get_shape().at(2), hd_image.get_shape().at(1)};
clip_image_u8 img{int(hd_image.get_shape().at(2)), int(hd_image.get_shape().at(1)), {hd_image.data<uint8_t>(), hd_image.data<uint8_t>() + hd_image.get_size()}};
clip_image_u8 dst;
bicubic_resize(img, dst, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE);
ov::Tensor global_image{ov::element::u8, {1, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE, 3}, dst.buf.data()};
global_image = mean_scale(global_image, config);
hd_image = mean_scale(hd_image, config);
global_image = channels_first(global_image);
hd_image = channels_first(hd_image);
ov::Tensor slices = slice_image(hd_image);
ov::Tensor concatenated = concatenate_batch(global_image, slices);
ov::Tensor pixel_values = pad_to_max_num_crops_tensor(concatenated, config.phi3_v.num_crops);
return {std::move(pixel_values), image_size};
}
} // namespace phi3_v

ImageSize smart_resize_qwen2vl(size_t height, size_t width, size_t factor, size_t min_pixels, size_t max_pixels) {
if (height < factor || width < factor) {
OPENVINO_THROW("Height or width must be larger than factor");
@@ -832,6 +1028,8 @@ EncodedImage VisionEncoder::encode(const ov::Tensor& image, const ProcessorConfi
return encode_llava_next(image, config);
} else if (model_type == VLMModelType::INTERNVL_CHAT) {
return encode_internvl(image, config);
} else if (model_type == VLMModelType::PHI3_V) {
return encode_phi3_v(image, config);
} else if (model_type == VLMModelType::QWEN2_VL) {
return encode_qwen2vl(image, config);
} else {
@@ -908,6 +1106,13 @@ EncodedImage VisionEncoder::encode_internvl(const ov::Tensor& image, const Proce
return {std::move(image_features), resized_source_size};
}

EncodedImage VisionEncoder::encode_phi3_v(const ov::Tensor& image, const ProcessorConfig& config) {
const auto& [pixel_values, image_size] = phi3_v::get_pixel_values_phi3_v(image, config);
m_vision_encoder.set_input_tensor(pixel_values);
m_vision_encoder.infer();
return {m_vision_encoder.get_output_tensor(), image_size};
}

EncodedImage VisionEncoder::encode_qwen2vl(const ov::Tensor& image, const ProcessorConfig& config) {
ov::Shape image_shape = image.get_shape();
auto original_height = image_shape.at(1);
4 changes: 4 additions & 0 deletions src/cpp/src/visual_language/vision_encoder.hpp
@@ -159,6 +159,10 @@ class VisionEncoder {
const ov::Tensor& image, const ProcessorConfig& config
);

EncodedImage encode_phi3_v(
const ov::Tensor& image, const ProcessorConfig& config
);

EncodedImage encode_qwen2vl(
const ov::Tensor& image, const ProcessorConfig& config
);
9 changes: 9 additions & 0 deletions src/cpp/src/visual_language/vlm_config.cpp
@@ -19,4 +19,13 @@ ov::genai::VLMConfig::VLMConfig(const std::filesystem::path& json_path) {

// Setting llava_next specific config params
read_json_param(parsed, "image_newline", image_newline);
// phi3_v
if (parsed.contains("sub_GN")) {
sub_GN = parsed.at("sub_GN").get<std::vector<std::vector<std::vector<std::vector<float>>>>>().at(0).at(0).at(0);
}
OPENVINO_ASSERT(sub_GN.size() == 4096);
if (parsed.contains("glb_GN")) {
glb_GN = parsed.at("glb_GN").get<std::vector<std::vector<std::vector<float>>>>().at(0).at(0);
}
OPENVINO_ASSERT(glb_GN.size() == 4096);
}
3 changes: 3 additions & 0 deletions src/cpp/src/visual_language/vlm_config.hpp
@@ -54,6 +54,9 @@ class VLMConfig {
std::string image_context_token = "<IMG_CONTEXT>";
/// @brief A string token denoting end of image embeddings for InternVL2 model.
std::string image_end_token = "</img>";
/// @brief phi3_v newline token embedding inserted after each row of image features.
std::vector<float> sub_GN = std::vector(4096, 0.0f);
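/// @brief phi3_v separator embedding placed between the sliced image features and the global image features.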
std::vector<float> glb_GN = std::vector(4096, 0.0f);

/// @brief A string token denoting start of vision embeddings for Qwen2VL model.
std::string vision_start_token = "<|vision_start|>";
