feat: multi-image input support for Phi3V #917

Merged
1 commit merged on Dec 18, 2024
86 changes: 57 additions & 29 deletions aphrodite/modeling/models/phi3v.py
@@ -14,6 +14,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import itertools
 import re
 from functools import lru_cache
 from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
@@ -29,6 +30,7 @@
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, ModelConfig, MultiModalConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import is_list_of
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.sampler import Sampler
@@ -38,11 +40,11 @@
 from aphrodite.modeling.models.llama import LlamaModel
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY
-from aphrodite.multimodal.utils import cached_get_tokenizer
+from aphrodite.multimodal.utils import (cached_get_tokenizer,
+                                        repeat_and_pad_token)
 from aphrodite.quantization.base_config import QuantizationConfig

-from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
-                   input_processor_for_clip)
+from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
 from .interfaces import SupportsMultiModal
 from .utils import merge_multimodal_embeddings

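Note: the new multi-image path leans on two helpers, is_list_of (imported above) and repeat_and_pad_token, whose bodies are not part of this diff. The following is a minimal sketch of what is_list_of plausibly checks, not the actual implementation in aphrodite.common.utils:

from typing import Any, Type

def is_list_of(value: Any, typ: Type) -> bool:
    # Assumed behavior: True iff `value` is a list whose elements are all
    # instances of `typ` (the real helper may support extra options).
    return isinstance(value, list) and all(isinstance(v, typ) for v in value)

assert is_list_of([1, 2], int) and not is_list_of((1, 2), int)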
@@ -397,55 +399,81 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
     if isinstance(image_data, Image.Image):
         w, h = image_data.size

-        image_feature_size = get_phi3v_image_feature_size(hf_config,
-                                                          input_width=w,
-                                                          input_height=h)
+        image_feature_size = [
+            get_phi3v_image_feature_size(hf_config,
+                                         input_width=w,
+                                         input_height=h)
+        ]
+        image_data = [image_data]
+    elif is_list_of(image_data, Image.Image):
+        image_feature_size = []
+        for image in image_data:
+            w, h = image.size
+            image_feature_size.append(
+                get_phi3v_image_feature_size(hf_config,
+                                             input_width=w,
+                                             input_height=h))
+
     elif isinstance(image_data, torch.Tensor):
         image_feature_size = image_data.shape[0]
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")

     prompt = llm_inputs.get("prompt")
     if prompt is None:
+        image_idx = []
         new_prompt = None
     else:
+        image_idx = sorted(map(int, re.findall(r"<\|image_(\d+)\|>+", prompt)))
         if prompt.count("<|image|>") > 0:
             logger.warning("Please follow the prompt format that is "
                            "documented on HuggingFace, which does not "
                            "involve repeating <|image|> tokens.")
-        elif len(re.findall(r"(<\|image_\d+\|>)+", prompt)) > 1:
-            logger.warning("Multiple image input is not supported yet, "
-                           "so any extra image tokens will be treated "
-                           "as plain text.")
-
+        elif (num_image_tags := len(image_idx)) > 1:
+            assert num_image_tags == len(image_data), (
+                "Number of image placeholders does not match number of images")
         new_prompt = prompt

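For illustration, this is how the added image_idx extraction behaves on a two-image prompt (the prompt string is a made-up example following the documented Phi-3-vision format):

import re

prompt = "<|user|>\n<|image_1|>\n<|image_2|>\nCompare these.<|end|>"
image_idx = sorted(map(int, re.findall(r"<\|image_(\d+)\|>+", prompt)))
# image_idx == [1, 2]; with two images supplied, the assert above passes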
-    prompt_token_ids = llm_inputs["prompt_token_ids"]
-    image_1_token_ids = _get_image_placeholder_token_ids(model_config, idx=1)
+    prompt_token_ids = llm_inputs["prompt_token_ids"].copy()
+    # Mask each <|image_N|> placeholder with the image token id
+    for idx in image_idx:
+        image_token_ids = _get_image_placeholder_token_ids(model_config,
+                                                           idx=idx)
+        for i in range(len(prompt_token_ids) - len(image_token_ids) + 1):
+            if prompt_token_ids[i:i + len(image_token_ids)] == image_token_ids:
+                prompt_token_ids[i:i + len(image_token_ids)] = [
+                    _IMAGE_TOKEN_ID
+                ] * len(image_token_ids)
+                break
+    # Merge consecutive image token ids into one sentinel per placeholder
+    merged_token_ids: List[int] = []
+    for is_placeholder, token_ids in itertools.groupby(
+            prompt_token_ids, lambda x: x == _IMAGE_TOKEN_ID):
+        if is_placeholder:
+            merged_token_ids.append(_IMAGE_TOKEN_ID)
+        else:
+            merged_token_ids.extend(list(token_ids))

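A toy walk-through of the mask-and-merge logic above (the token values are invented; the real _IMAGE_TOKEN_ID is a model-specific constant):

import itertools

_IMAGE_TOKEN_ID = -1  # illustrative stand-in for the real constant

# After the masking loop, every token of each multi-token placeholder has
# been rewritten to _IMAGE_TOKEN_ID:
prompt_token_ids = [7, -1, -1, -1, 42, -1, -1, 9]

merged_token_ids = []
for is_placeholder, token_ids in itertools.groupby(
        prompt_token_ids, lambda x: x == _IMAGE_TOKEN_ID):
    if is_placeholder:
        merged_token_ids.append(_IMAGE_TOKEN_ID)
    else:
        merged_token_ids.extend(list(token_ids))

assert merged_token_ids == [7, -1, 42, -1, 9]  # one sentinel per image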
     # TODO: Move this to utils or integrate with clip.
     new_token_ids: List[int] = []
-    for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1):
-        if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids:
-            new_token_ids.append(_IMAGE_TOKEN_ID)
-
-            # No need to further scan the list since we only replace once
-            new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):])
-            break
+    placeholder_idx = 0
+    while merged_token_ids:
+        token_id = merged_token_ids.pop(0)
+        if token_id == _IMAGE_TOKEN_ID:
+            new_token_ids.extend(
+                repeat_and_pad_token(
+                    _IMAGE_TOKEN_ID,
+                    repeat_count=image_feature_size[placeholder_idx],
+                ))
+            placeholder_idx += 1
         else:
-            new_token_ids.append(prompt_token_ids[i])
+            new_token_ids.append(token_id)

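Each sentinel is then expanded so the sequence reserves one position per vision embedding of the corresponding image. repeat_and_pad_token is imported above but not defined in this diff; a sketch under the assumption that it simply repeats the token, with optional framing pad tokens:

from typing import List, Optional

def repeat_and_pad_token(token: int, *, repeat_count: int = 1,
                         pad_token_left: Optional[int] = None,
                         pad_token_right: Optional[int] = None) -> List[int]:
    # Assumed behavior; the real helper lives in aphrodite.multimodal.utils.
    replacement = [token] * repeat_count
    if pad_token_left is not None:
        replacement = [pad_token_left] + replacement
    if pad_token_right is not None:
        replacement = replacement + [pad_token_right]
    return replacement

assert repeat_and_pad_token(-1, repeat_count=3) == [-1, -1, -1]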
     # NOTE: Create a defensive copy of the original inputs
     llm_inputs = LLMInputs(prompt_token_ids=new_token_ids,
                            prompt=new_prompt,
                            multi_modal_data=multi_modal_data)

-    return input_processor_for_clip(
-        model_config,
-        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
-        llm_inputs,
-        image_token_id=_IMAGE_TOKEN_ID,
-        image_feature_size_override=image_feature_size,
-    )
+    return llm_inputs


@MULTIMODAL_REGISTRY.register_image_input_mapper()
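End-to-end, a request can now reference several images, each exactly once via an indexed placeholder. A hedged usage sketch (the offline-inference API follows Aphrodite's vLLM-style interface; the entry points, model id, and file names here are assumptions, not taken from this PR):

from PIL import Image
from aphrodite import LLM  # assumed top-level export

llm = LLM(model="microsoft/Phi-3-vision-128k-instruct",
          trust_remote_code=True)

image_a = Image.open("before.png")  # hypothetical input files
image_b = Image.open("after.png")

prompt = ("<|user|>\n<|image_1|>\n<|image_2|>\n"
          "What changed between the two images?<|end|>\n<|assistant|>\n")

outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": {"image": [image_a, image_b]},
})
print(outputs[0].outputs[0].text)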