init phi3v merged input processor
Signed-off-by: Isotr0py <[email protected]>
Isotr0py committed Dec 7, 2024
1 parent b26b4cd commit a40f59a
Showing 1 changed file with 69 additions and 221 deletions.
290 changes: 69 additions & 221 deletions vllm/model_executor/models/phi3v.py
@@ -12,22 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import re
from functools import cached_property, lru_cache
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, PretrainedConfig
from transformers import BatchFeature, CLIPVisionConfig, PretrainedConfig

from vllm.attention import AttentionMetadata
from vllm.config import ModelConfig, VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.config import VllmConfig
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -36,12 +31,16 @@
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.processing import (InputProcessingContext,
ModalityProcessingMetadata,
MultiModalDataDict,
MultiModalProcessingMetadata,
MultiModalProcessor, PromptReplacement)
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .clip import dummy_image_for_clip
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
@@ -303,231 +302,80 @@ def add_image_newline(self, image_features_hd):
return image_features_hd_newline


# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
target_height = int(np.ceil(height / padding_unit) * padding_unit)
top_padding = int((target_height - height) / 2)
bottom_padding = target_height - height - top_padding
padded_width = width
padded_height = height + top_padding + bottom_padding
return padded_width, padded_height
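
# A quick worked example of the padding above (illustrative values only):
# _calc_padded_size(width=1344, height=1000) rounds the height up to the next
# multiple of 336, giving target_height = 1008, top_padding = 4,
# bottom_padding = 4, and returns (1344, 1008).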


# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
def _calc_hd_transform_size(*, width: int, height: int, hd_num: int):
transposed = False
if width < height:
width, height = height, width
transposed = True

ratio = width / height
scale = 1
while scale * np.ceil(scale / ratio) <= hd_num:
scale += 1
scale -= 1

new_width = int(scale * 336)
new_height = int(new_width / ratio)

padded_width, padded_height = _calc_padded_size(width=new_width,
height=new_height)

if transposed:
padded_width, padded_height = padded_height, padded_width

return padded_width, padded_height
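
# A worked example of the HD transform (illustrative values only): for
# _calc_hd_transform_size(width=1344, height=1008, hd_num=16), the orientation
# is kept (width >= height), and the largest scale with
# scale * ceil(scale / ratio) <= 16 is 4 for the 4:3 ratio, so the image is
# resized to 1344 x 1008; no extra padding is needed and (1344, 1008) is
# returned.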


# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
def get_phi3v_image_feature_size(
hf_config: Dict[str, Any],
*,
input_height: int,
input_width: int,
num_crops: int,
) -> int:
if num_crops is None:
num_crops = hf_config.get("num_crops", 16)
new_width, new_height = _calc_hd_transform_size(width=input_width,
height=input_height,
hd_num=num_crops)

return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \
+ (new_height // 336 + 1) * 12
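
# Continuing the worked example above: a 1344 x 1008 input with num_crops=16
# keeps its 1344 x 1008 size after the HD transform, so the feature size is
# (1008 // 336 * 1344 // 336 + 1) * 144 + 1 + (1008 // 336 + 1) * 12
#   = (3 * 4 + 1) * 144 + 1 + 4 * 12
#   = 1921 image tokens.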


def get_max_phi3v_image_tokens(ctx: InputContext,
*,
num_crops: Optional[int] = None):

return get_phi3v_image_feature_size(
ctx.get_hf_image_processor_config(),
input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
num_crops=num_crops,
hf_processor = ctx.get_hf_processor()
image_processor = hf_processor.image_processor # type: ignore
if num_crops is not None:
image_processor.num_crops = num_crops
num_tokens = image_processor.calc_num_image_tokens_from_image_size(
width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
return num_tokens
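
# The HF image processor is expected to reproduce the formula removed above.
# A minimal sanity-check sketch (it assumes the processor for
# microsoft/Phi-3-vision-128k-instruct is available locally and that it uses
# its default num_crops=16):
#
#   from transformers import AutoProcessor
#   hf_processor = AutoProcessor.from_pretrained(
#       "microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True)
#   hf_processor.image_processor.calc_num_image_tokens_from_image_size(
#       width=1344, height=1008)  # expected: 1921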


def dummy_data_for_phi3v(ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
*,
num_crops: Optional[int] = None):
def dummy_mm_kwargs_for_phi3v(ctx: InputProcessingContext,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]

image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops)

seq_data, ranges = dummy_seq_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
seq_len,
num_images,
image_token_id=_IMAGE_TOKEN_ID,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(
data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)

return DummyData(seq_data, mm_data, ranges)


@lru_cache
def _get_image_placeholder_token_id_candidates(
model_config: ModelConfig,
idx: int,
) -> List[List[int]]:
assert idx > 0

tokenizer = cached_get_tokenizer(model_config.tokenizer)

# This is used when the image token is at the start of the string
start_candidate = tokenizer.encode(f"<|image_{idx}|>",
add_special_tokens=False)

# This is used when the image token is in the middle of the string
# We need to get the token for "<", not "▁<"
# https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
a_token_id, = tokenizer.encode("a", add_special_tokens=False)
a_token_id_, *middle_candidate = tokenizer.encode(f"a<|image_{idx}|>",
add_special_tokens=False)
assert a_token_id == a_token_id_

return [start_candidate, middle_candidate]


def input_processor_for_phi3v(ctx: InputContext,
inputs: DecoderOnlyInputs,
*,
num_crops: Optional[int] = None):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs

model_config = ctx.model_config
hf_config = ctx.get_hf_image_processor_config()

image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
w, h = image_data.size
image_feature_size = [
get_phi3v_image_feature_size(hf_config,
input_width=w,
input_height=h,
num_crops=num_crops)
]
image_data = [image_data]
elif is_list_of(image_data, Image.Image):
image_feature_size = []
for image in image_data:
w, h = image.size
image_feature_size.append(
get_phi3v_image_feature_size(hf_config,
input_width=w,
input_height=h,
num_crops=num_crops))
elif isinstance(image_data, torch.Tensor):
image_feature_size = [image_data.shape[0]]
image_data = [image_data]
elif is_list_of(image_data, torch.Tensor):
image_feature_size = [item.shape[0] for item in image_data]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")

prompt = inputs.get("prompt")
if prompt is None:
# for async server requests, we assume the prompt and its token_ids are
# always in the correct format, and that num_image_tags == len(image_data)
# always holds.
image_idx = range(1, len(image_data) + 1)
new_prompt = None
else:
image_idx = sorted(map(int, re.findall(r"<\|image_(\d+)\|>+", prompt)))
if prompt.count("<|image|>") > 0:
logger.warning("Please follow the prompt format that is "
"documented on HuggingFace which does not involve "
"repeating <|image|> tokens.")
elif (num_image_tags := len(image_idx)) > 1:
assert num_image_tags == len(
image_data), "The count of image_placeholder not match image's"
new_prompt = prompt

prompt_token_ids = inputs["prompt_token_ids"].copy()

# masked placeholder with image token id
for idx in image_idx:
candidates = _get_image_placeholder_token_id_candidates(model_config,
idx=idx)

for candidate in candidates:
for i in range(len(prompt_token_ids) - len(candidate) + 1):
if prompt_token_ids[i:i + len(candidate)] == candidate:
prompt_token_ids[i:i +
len(candidate)] = ([_IMAGE_TOKEN_ID] *
len(candidate))
break

# merge consecutive tag ids
merged_token_ids: List[int] = []
for is_placeholder, token_ids in itertools.groupby(
prompt_token_ids, lambda x: x == _IMAGE_TOKEN_ID):
if is_placeholder:
merged_token_ids.append(_IMAGE_TOKEN_ID)
else:
merged_token_ids.extend(list(token_ids))

# TODO: Move this to utils or integrate with clip.
new_token_ids: List[int] = []
placeholder_ranges: List[PlaceholderRange] = []
placeholder_idx = 0
while merged_token_ids:
token_id = merged_token_ids.pop(0)
if token_id == _IMAGE_TOKEN_ID:
replacement_ids = repeat_and_pad_token(
_IMAGE_TOKEN_ID,
repeat_count=image_feature_size[placeholder_idx],
)
placeholder_ranges.append({
"offset": len(new_token_ids),
"length": len(replacement_ids)
})
new_token_ids.extend(replacement_ids)
placeholder_idx += 1
else:
new_token_ids.append(token_id)
hf_processor = ctx.get_hf_processor()
image_processor = hf_processor.image_processor # type: ignore
hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt")

# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": placeholder_ranges})
return MultiModalKwargs(**hf_inputs)
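
# Note on the kwargs above (hedged, since the exact keys come from the HF
# Phi3V image processor): preprocess() is expected to return at least
# "pixel_values" and "image_sizes", matching the fields consumed by
# Phi3VImagePixelInputs elsewhere in this file.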


def create_metadata_for_phi3v(
ctx: InputProcessingContext) -> MultiModalProcessingMetadata:
return {
"image":
ModalityProcessingMetadata(prompt_repls=[
PromptReplacement(target=[_IMAGE_TOKEN_ID],
repl_unit=[_IMAGE_TOKEN_ID],
repl_count=get_max_phi3v_image_tokens(ctx)),
]),
}
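
# Illustration of the replacement metadata above (hypothetical token ids):
# every occurrence of _IMAGE_TOKEN_ID in the tokenized prompt is expanded into
# get_max_phi3v_image_tokens(ctx) repetitions of itself, e.g.
#   [bos, _IMAGE_TOKEN_ID, ...text...] -> [bos, _IMAGE_TOKEN_ID * N, ...text...]
# where N is the maximum image token count.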


class Phi3VProcessor(MultiModalProcessor):

def _apply_hf_processor(
self,
prompt: str,
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._apply_hf_processor(
prompt, mm_data, mm_processor_kwargs)
# The Phi3v processor inserts -1 as a placeholder in the prompt_ids, which
# would cause an OverflowError when decoding the prompt_ids. Therefore, we
# need to do an early replacement here.
token_ids = processed_outputs['input_ids']
token_ids[token_ids == -1] = _IMAGE_TOKEN_ID
processed_outputs['input_ids'] = token_ids
return processed_outputs
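
    # For example (hypothetical values): an input_ids tensor of
    # [1, -1, -1, 29871] becomes
    # [1, _IMAGE_TOKEN_ID, _IMAGE_TOKEN_ID, 29871] after the early replacement
    # above, so the prompt ids can be decoded without overflow.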

def _get_dummy_mm_kwargs(
self,
mm_counts: Mapping[str, int],
) -> MultiModalKwargs:
return dummy_mm_kwargs_for_phi3v(self.ctx, mm_counts)


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
@MULTIMODAL_REGISTRY.register_processor(lambda ctx: Phi3VProcessor(
ctx=ctx,
metadata=create_metadata_for_phi3v(ctx),
))
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):