diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 3bf430235462b..e70bbc9c49076 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -183,61 +183,74 @@ def repeat_and_pad_token( return replacement -def repeat_and_pad_placeholder_tokens( +def repeat_and_pad_image_tokens( tokenizer: AnyTokenizer, prompt: Optional[str], prompt_token_ids: List[int], *, - placeholder_token_id: int, - repeat_count: int = 1, + image_token_id: int, + repeat_count: Union[int, List[int]], pad_token_left: Optional[int] = None, pad_token_right: Optional[int] = None, ) -> Tuple[Optional[str], List[int]]: + if isinstance(repeat_count, int): + repeat_count = [repeat_count] + if prompt is None: new_prompt = None else: - placeholder_token_str = tokenizer.decode(placeholder_token_id) + image_token_str = tokenizer.decode(image_token_id) pad_token_str_left = (None if pad_token_left is None else tokenizer.decode(pad_token_left)) pad_token_str_right = (None if pad_token_right is None else tokenizer.decode(pad_token_right)) - replacement_str = "".join( - repeat_and_pad_token( - placeholder_token_str, - repeat_count=repeat_count, - pad_token_left=pad_token_str_left, - pad_token_right=pad_token_str_right, - )) - - placeholder_token_count = prompt.count(placeholder_token_str) + + image_token_count = prompt.count(image_token_str) # This is an arbitrary number to distinguish between the two cases - if placeholder_token_count > 16: + if image_token_count > 16: logger.warning( "Please follow the prompt format that is " "documented on HuggingFace which does not involve " - "repeating %s tokens.", placeholder_token_str) - elif placeholder_token_count > 1: - logger.warning("Multiple multi-modal input is not supported yet, " - "so any extra placeholder tokens will be treated " - "as plain text.") - - # The image tokens are removed to be consistent with HuggingFace - new_prompt = prompt.replace(placeholder_token_str, replacement_str, 1) + "repeating %s tokens.", image_token_str) + if image_token_count < len(repeat_count): + logger.warning( + "The number of image tokens in the prompt is less than " + "the number of image inputs. Extra image tokens will be " + "treated as plain text") + repeat_count = repeat_count[:image_token_count] + + prompt_parts = prompt.split(image_token_str, + maxsplit=len(repeat_count)) + new_prompt = "" + for i, repeat_count_item in enumerate(repeat_count): + replacement_str = "".join( + repeat_and_pad_token( + image_token_str, + repeat_count=repeat_count_item, + pad_token_left=pad_token_str_left, + pad_token_right=pad_token_str_right, + )) + # The image tokens are removed to be consistent with HuggingFace + new_prompt += prompt_parts[i] + replacement_str + new_prompt += prompt_parts[-1] new_token_ids: List[int] = [] + image_token_idx = 0 for i, token in enumerate(prompt_token_ids): - if token == placeholder_token_id: + if token == image_token_id: replacement_ids = repeat_and_pad_token( - placeholder_token_id, - repeat_count=repeat_count, + image_token_id, + repeat_count=repeat_count[image_token_idx], pad_token_left=pad_token_left, pad_token_right=pad_token_right, ) new_token_ids.extend(replacement_ids) + image_token_idx += 1 - # No need to further scan the list since we only replace once - new_token_ids.extend(prompt_token_ids[i + 1:]) - break + # No need to further scan the list since we replaced all tokens + if image_token_idx >= len(repeat_count): + new_token_ids.extend(prompt_token_ids[i + 1:]) + break else: new_token_ids.append(token)