From a11f3265282c712d1d9fa75368e2a8c40019fbb7 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Sun, 8 Dec 2024 04:50:51 -0800 Subject: [PATCH] [V1] Initial support of multimodal models for V1 re-arch (#10699) Signed-off-by: Roger Wang --- vllm/engine/arg_utils.py | 16 +-- vllm/model_executor/models/interfaces.py | 5 + vllm/model_executor/models/internvl.py | 68 ++++++++++--- vllm/model_executor/models/molmo.py | 72 ++++++++++++-- vllm/model_executor/models/pixtral.py | 121 +++++++++++++++++------ vllm/model_executor/models/utils.py | 28 +++++- vllm/multimodal/inputs.py | 3 +- vllm/multimodal/utils.py | 10 +- vllm/v1/core/scheduler.py | 4 +- vllm/v1/engine/llm_engine.py | 24 ++++- vllm/v1/engine/mm_input_mapper.py | 2 +- 11 files changed, 284 insertions(+), 69 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 96c11ec2b4f9e..3db069ec64ee4 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1050,9 +1050,12 @@ def create_engine_config(self, # long context (> 32K) models. This is to avoid OOM errors in the # initial memory profiling phase. - # Chunked prefill is currently disabled for multimodal models by - # default. - if use_long_context and not model_config.is_multimodal_model: + # For multimodal models, chunked prefill is disabled by default in + # V0, but enabled by design in V1 + if model_config.is_multimodal_model: + self.enable_chunked_prefill = bool(envs.VLLM_USE_V1) + + elif use_long_context: is_gpu = device_config.device_type == "cuda" use_sliding_window = (model_config.get_sliding_window() is not None) @@ -1241,12 +1244,9 @@ def _override_v1_engine_config(self, engine_config: VllmConfig) -> None: Override the EngineConfig's configs based on the usage context for V1. """ assert envs.VLLM_USE_V1, "V1 is not enabled" - # TODO (ywang96): Enable APC by default when VLM supports it. if engine_config.model_config.is_multimodal_model: - logger.warning( - "Prefix caching is currently not supported for multimodal " - "models and has been disabled.") - engine_config.cache_config.enable_prefix_caching = False + # TODO (ywang96): Enable APC by default when VLM supports it. + assert not engine_config.cache_config.enable_prefix_caching @dataclass diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 01a381381ccec..c3979eab905db 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -36,6 +36,11 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[T]: """ Returns multimodal embeddings generated from multimodal kwargs to be merged with text embeddings. + + The output embeddings must be one of the following formats: + - A list or tuple of 2D tensors, where each tensor corresponds to + each input image. + - A single 3D tensor, with the batch dimension grouping the 2D tensors. """ ... 
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index d5a7781fecfc3..42c769f79e202 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -26,7 +26,7 @@
                                                    InternVisionPatchModel)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
-from vllm.multimodal.inputs import NestedTensors
+from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
@@ -52,12 +52,18 @@ class InternVLImagePixelInputs(TypedDict):
     Shape:
     `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
     """
+    patches_per_image: List[int]
+    """
+    List of the total number of patches for each image in the batch.
+    """
 
 
 class InternVLImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
-    data: torch.Tensor
-    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
+    data: NestedTensors
+    """
+    A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
+    or a list of tensors of shape `(total_image_feature_size, hidden_size)`
 
     `hidden_size` must match the hidden size of language model backbone.
     """
@@ -349,10 +355,32 @@ def input_processor(
         new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
                                                num_patches)
         new_prompt_token_ids = tokenizer.encode(new_prompt)
+        img_context_token_id = tokenizer.encode(self.img_context_token,
+                                                add_special_tokens=False)
+        assert len(img_context_token_id) == 1, \
+            (f"Invalid image token '{self.img_context_token}': A valid image "
+             f"token encodes to a single token ID, got {img_context_token_id}.")
+        img_context_token_id = img_context_token_id[0]
+
+        # Get precise tracking of placeholder positions
+        token_idx = image_idx = 0
+        placeholder_ranges = []
+        while token_idx < len(new_prompt_token_ids):
+            if new_prompt_token_ids[token_idx] == img_context_token_id:
+                curr_image_feature_size = image_feature_sizes[image_idx]
+                placeholder_ranges.append(
+                    PlaceholderRange(offset=token_idx,
+                                     length=curr_image_feature_size))
+                image_idx += 1
+                token_idx += curr_image_feature_size
+            else:
+                token_idx += 1
 
-        return token_inputs(prompt=prompt,
-                            prompt_token_ids=new_prompt_token_ids,
-                            multi_modal_data=multi_modal_data)
+        return token_inputs(
+            prompt=prompt,
+            prompt_token_ids=new_prompt_token_ids,
+            multi_modal_data=multi_modal_data,
+            multi_modal_placeholders={"image": placeholder_ranges})
 
     def input_mapper(
         self,
@@ -614,26 +642,46 @@ def _parse_and_validate_image_input(
             if not isinstance(pixel_values, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
+
+            patches_per_image = []
+            for request_pixel_values in pixel_values:
+                for image_pixel_values in request_pixel_values:
+                    patches_per_image.append(image_pixel_values.shape[0])
             # We need to flatten (B, N, P) to (B*N*P),
             # so we call flatten_bn twice.
return InternVLImagePixelInputs( type="pixel_values", data=self._validate_pixel_values( flatten_bn(flatten_bn(pixel_values), concat=True)), - ) + patches_per_image=patches_per_image) raise AssertionError("This line should be unreachable.") def _process_image_input( self, image_input: InternVLImageInputs, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor]: if image_input["type"] == "image_embeds": return image_input["data"] assert self.vision_model is not None + image_embeds = self.extract_feature(image_input["data"]) + patches_per_image = image_input["patches_per_image"] + if len(patches_per_image) == 1: + image_embeds = image_embeds.unsqueeze(0) + return image_embeds + + # NOTE: Image embeddings are split into separate tensors for each image + # by the size of each embedding. + feature_size = image_embeds.shape[1] + image_embeds = image_embeds.view(-1, + self.config.text_config.hidden_size) + image_feature_sizes = [ + num_patches * feature_size for num_patches in patches_per_image + ] + image_embeds = image_embeds.split(image_feature_sizes) return image_embeds def _set_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -696,13 +744,11 @@ def forward( "inputs_embeds": inputs_embeds, } + # Only required if the model is mono-architecture if self.visual_token_mask is not None: - # overwrite visual_token_mask and img_context_token_id back to None, - # so that this doesn't need to depend on encoder output forward_kwargs.update( {"visual_token_mask": self.visual_token_mask}) self.visual_token_mask = None - self.img_context_token_id = None hidden_states = self.language_model.model(**forward_kwargs) return hidden_states diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index d1fcbd167c199..a328b5a2aeea7 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -37,7 +37,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.inputs import NestedTensors, PlaceholderRange from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SequenceData) @@ -46,12 +46,16 @@ from .interfaces import SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) + maybe_prefix, merge_multimodal_embeddings) # TODO: hard-coded for now. Consider making it configurable. 
VIT_LAYERS = [-2, -9] NUM_PREFIX_TOKENS = 1 ADDITIONAL_VOCAB_SIZE = 128 +DEFAULT_IMAGE_PATCH_TOKEN_ID = 152066 +DEFAULT_IM_START_TOKEN_ID = 152067 +DEFAULT_IM_END_TOKEN_ID = 152064 +DEFAULT_IM_COL_TOKEN_ID = 152065 class MolmoImageInputs(TypedDict): @@ -75,6 +79,11 @@ class MolmoImageInputs(TypedDict): `(batch_size, num_crops, num_patch)` """ + image_start_end: Tuple[int, int] + """Starting and ending index of placeholder + tokens + """ + @dataclass class VisionBackboneConfig: @@ -918,6 +927,8 @@ def image_input_mapper_for_molmo( ctx: InputContext, data: object, ): + if isinstance(data, list): + data = data[0] return MultiModalKwargs(data) @@ -967,7 +978,22 @@ def dummy_data_for_molmo(ctx: InputContext, seq_len: int, if "image_masks" in out: dummy_imgdata["image_masks"] = out["image_masks"] dummy_imgdata["seq_len"] = torch.tensor(seq_len, dtype=torch.long) - return DummyData(dummy_seqdata, {"image": dummy_imgdata}) + size = 0 + offset = -1 + for i in range(len(token_ids)): + if token_ids[i] in (DEFAULT_IMAGE_PATCH_TOKEN_ID, + DEFAULT_IM_START_TOKEN_ID, DEFAULT_IM_END_TOKEN_ID, + DEFAULT_IM_COL_TOKEN_ID): + if offset < 0: + offset = i + size += 1 + dummy_imgdata["image_start_end"] = (offset, offset + size) + return DummyData(seq_data=dummy_seqdata, + multi_modal_data={"image": dummy_imgdata}, + multi_modal_placeholders={ + "image": + [PlaceholderRange(offset=offset, length=size)] + }) def pad_images( @@ -1055,19 +1081,34 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs): if image_masks is not None: image_data["image_masks"] = image_masks - image_data["seq_len"] = torch.tensor(len(out["input_ids"]), + new_prompt_token_ids = out["input_ids"].tolist() + image_data["seq_len"] = torch.tensor(len(new_prompt_token_ids), dtype=torch.long) multi_modal_data = dict(image=image_data) + size = 0 + offset = -1 + for i in range(len(new_prompt_token_ids)): + if new_prompt_token_ids[i] in (DEFAULT_IMAGE_PATCH_TOKEN_ID, + DEFAULT_IM_START_TOKEN_ID, + DEFAULT_IM_END_TOKEN_ID, + DEFAULT_IM_COL_TOKEN_ID): + if offset < 0: + offset = i + size += 1 + image_data["image_start_end"] = (offset, offset + size) prompt = inputs.get("prompt") if prompt is None: - prompt = tokenizer.decode(out["input_ids"]) + prompt = tokenizer.decode(new_prompt_token_ids) return token_inputs( - prompt_token_ids=out["input_ids"], + prompt_token_ids=new_prompt_token_ids, prompt=prompt, multi_modal_data=multi_modal_data, + multi_modal_placeholders={ + "image": [PlaceholderRange(offset=offset, length=size)] + }, ) @@ -1113,6 +1154,7 @@ def _parse_and_validate_image_input( ) -> Optional[MolmoImageInputs]: images = kwargs.pop("images", None) image_masks = kwargs.pop("image_masks", None) + image_start_end = kwargs.pop("image_start_end", None) if images is None: return None @@ -1130,6 +1172,7 @@ def _parse_and_validate_image_input( image_input_idx=image_input_idx, seq_len=seq_len, image_masks=image_masks, + image_start_end=image_start_end, ) def _process_image_input( @@ -1178,9 +1221,16 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: # Note: In this original implementation from AI2, the final # vision_embeddings will be always be the same length - # of input embedddings, which is not very efficient. - # TODO(ywang96): see if this can be optimized. + # of input embeddings. vision_embeddings = torch.einsum('nd,nm->md', image_features, mat) + + # Split by the sizes of the input sequences. For each full embedding, + # extract the actual vision embeddings to be merged. 
+ vision_embeddings = list(vision_embeddings.split(seq_len.tolist())) + for i in range(len(vision_embeddings)): + start, end = image_input['image_start_end'][i] + vision_embeddings[i] = vision_embeddings[i][start:end] + return vision_embeddings def get_input_embeddings( @@ -1190,7 +1240,11 @@ def get_input_embeddings( ) -> torch.Tensor: inputs_embeds = self.model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: - inputs_embeds = inputs_embeds + multimodal_embeddings + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, [ + DEFAULT_IMAGE_PATCH_TOKEN_ID, DEFAULT_IM_START_TOKEN_ID, + DEFAULT_IM_END_TOKEN_ID, DEFAULT_IM_COL_TOKEN_ID + ]) return inputs_embeds def forward( diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 215727cadd954..c6786c363ab4a 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -48,6 +48,9 @@ except ImportError: USE_XFORMERS_OPS = False +PIXTRAL_IMAGE_BREAK_ID = 12 +PIXTRAL_IMAGE_END_ID = 13 + def get_max_pixtral_image_tokens(ctx: InputContext): tokenizer = cached_get_tokenizer( @@ -68,7 +71,6 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, tokenizer_mode=ctx.model_config.tokenizer_mode) mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder - patch_size = mm_encoder.mm_config.image_patch_size image_token_id = mm_encoder.special_ids.img mm_config = ctx.model_config.multimodal_config @@ -78,8 +80,8 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, size = 256 image = Image.new("RGB", (size, size), color=0) - image_feature_size = (size**2) // (patch_size**2) - + encoding = tokenizer.instruct.mm_encoder(ImageChunk(image=image)) + image_feature_size = len(encoding.tokens) num_image_tokens = image_feature_size * num_images seq_data = SequenceData.from_prompt_token_counts( (image_token_id, num_image_tokens), @@ -101,14 +103,13 @@ def input_mapper_for_pixtral(ctx: InputContext, Args: ctx: Context of the loaded model. - data: data potentially containing image/image embeddings to be mapped - to pixel_values in .forward() for a visual QWenLMHeadModel model. + data: data potentially containing PIL images to be processed + and mapped to `images`. Returns: MultiModalKwargs containing the stacked normalized images tensor or image embeddings. 
""" - # Early exit if we have provided an image to a language only Qwen model model_config = ctx.model_config tokenizer = cached_get_tokenizer( model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode) @@ -116,35 +117,67 @@ def input_mapper_for_pixtral(ctx: InputContext, data_list = data if isinstance(data, list) else [data] images = [] + image_tokens_list = [] for image_data in data_list: image = ImageChunk(image=image_data) encoding = tokenizer.instruct.mm_encoder(image) image = torch.from_numpy(encoding.image).to(device="cuda", dtype=torch.float16) images.append(image) + image_tokens_list.append(encoding.tokens) - return MultiModalKwargs({"images": images}) + image_tokens = torch.tensor([ + token_id for image_tokens in image_tokens_list + for token_id in image_tokens + ]) + return MultiModalKwargs({"images": images, "image_tokens": image_tokens}) def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs): multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is not None and "image" in multi_modal_data: - tokenizer = cached_get_tokenizer( - ctx.model_config.tokenizer, - tokenizer_mode=ctx.model_config.tokenizer_mode) - - mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder - image_token_id = mm_encoder.special_ids.img + if multi_modal_data is None or "image" not in multi_modal_data: + return inputs - if image_token_id not in inputs['prompt_token_ids']: - raise ValueError( - f"You've passed {inputs=} without {image_token_id=}" - " Make sure to process your input via mistral_common's" - " tokenizer or pass a chat completion request. For more" - " For more info, see: " - "https://github.com/vllm-project/vllm/issues/8411.") + prompt_token_ids = inputs.get("prompt_token_ids") + prompt = inputs.get("prompt") + tokenizer = cached_get_tokenizer( + ctx.model_config.tokenizer, + tokenizer_mode=ctx.model_config.tokenizer_mode) - return inputs + mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder + image_token_id = mm_encoder.special_ids.img + image_break_id = mm_encoder.special_ids.img_break + image_end_id = mm_encoder.special_ids.img_end + + if image_token_id not in inputs['prompt_token_ids']: + raise ValueError( + f"You've passed {inputs=} without {image_token_id=}" + " Make sure to process your input via mistral_common's" + " tokenizer or pass a chat completion request. 
For more" + " For more info, see: " + "https://github.com/vllm-project/vllm/issues/8411.") + + # Get precise tracking of placeholder positions + placeholder_ranges = [] + curr_offset = -1 + curr_length = 0 + for i in range(len(prompt_token_ids)): + if prompt_token_ids[i] in (image_token_id, image_break_id): + if curr_offset < 0: + curr_offset = i + curr_length += 1 + elif prompt_token_ids[i] == image_end_id: + curr_length += 1 + placeholder_ranges.append( + PlaceholderRange(offset=curr_offset, length=curr_length)) + curr_offset = -1 + curr_length = 0 + else: + pass + return token_inputs(prompt=prompt, + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"image": placeholder_ranges}) @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral) @@ -192,11 +225,29 @@ def sampler(self): return get_sampler() def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: - image_input = self._parse_and_validate_image_input(**kwargs) + image_input, image_tokens = self._parse_and_validate_image_input( + **kwargs) if image_input is None: return None + vision_embeddings = self._process_image_input(image_input) - return vision_embeddings + + # NOTE: We patch the outputs of the vision encoder with embeddings + # from `[IMG_BREAK]` and `[IMG_END]` tokens. + image_embeds = self.language_model.get_input_embeddings(image_tokens) + image_token_mask = image_tokens == self.vision_args.image_token_id + image_embeds[image_token_mask] = vision_embeddings + + # NOTE: Image embeddings are split into separate tensors for each image + # by the indices of `[IMG_END]` token. + split_indices = torch.where( + image_tokens == PIXTRAL_IMAGE_END_ID)[0] + 1 + if len(split_indices) <= 1: + # Do not split, return as tensor of shape [1, fs, hs] + return image_embeds.unsqueeze(0) + + image_embeds = image_embeds.tensor_split(split_indices.cpu()) + return image_embeds def get_input_embeddings( self, @@ -206,8 +257,10 @@ def get_input_embeddings( inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.vision_args.image_token_id) + input_ids, inputs_embeds, multimodal_embeddings, [ + self.vision_args.image_token_id, PIXTRAL_IMAGE_END_ID, + PIXTRAL_IMAGE_BREAK_ID + ]) return inputs_embeds def forward( @@ -245,10 +298,11 @@ def forward( def _parse_and_validate_image_input( self, images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor], - torch.Tensor]] = None + torch.Tensor]] = None, + image_tokens: Optional[torch.Tensor] = None, ) -> Optional[List[torch.Tensor]]: if images is None: - return None + return None, None if isinstance(images, torch.Tensor): # if passed as batch take all images @@ -267,7 +321,16 @@ def _parse_and_validate_image_input( images = flatten_images - return images + if isinstance(image_tokens, torch.Tensor): + # image_tokens are batched + image_tokens = image_tokens.flatten() + elif isinstance(image_tokens, list): + # image_tokens are of different lengths thus passed as a list + image_tokens = torch.cat(image_tokens) + + assert image_tokens.dim() == 1 + + return images, image_tokens def _process_image_input(self, image_input: List[torch.Tensor]) -> torch.Tensor: diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 5ec44955dbd80..269b66806adf4 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ 
-409,16 +409,42 @@ def merge_multimodal_embeddings( input_ids: torch.Tensor, inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors, - placeholder_token_id: int, + placeholder_token_id: Union[int, List[int]], ) -> torch.Tensor: """ Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the positions in ``inputs_embeds`` corresponding to placeholder tokens in ``input_ids``. + + ``placeholder_token_id`` can be a list of token ids (e.g, token ids + of img_start, img_break, and img_end tokens) when needed: This means + the order of these tokens in the ``input_ids`` MUST MATCH the order of + their embeddings in ``multimodal_embeddings`` since we need to + slice-merge instead of individually scattering. + + For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where + - T is text token + - S is image start token + - I is image embedding token + - B is image break token + - E is image end token. + + Then the image embeddings (that correspond to I's) from vision encoder + must be padded with embeddings of S, B, and E in the same order of + input_ids for a correct embedding merge. Note: This updates ``inputs_embeds`` in place. """ + if isinstance(placeholder_token_id, list): + placeholder_token_id = torch.tensor(placeholder_token_id, + device=input_ids.device) + return _merge_multimodal_embeddings( + inputs_embeds, + torch.isin(input_ids, placeholder_token_id), + multimodal_embeddings, + ) + return _merge_multimodal_embeddings( inputs_embeds, (input_ids == placeholder_token_id), diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 640c7c04b8817..229a8fbdf5831 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -96,7 +96,8 @@ class PlaceholderRange(TypedDict): """The length of the placeholder.""" -NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor] +NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor, + Tuple[torch.Tensor, ...]] """ Uses a list instead of a tensor if the dimensions of each element do not match. """ diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index d4333b7519b47..c898ca4e6573e 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -535,11 +535,13 @@ def repeat_and_pad_placeholder_tokens( return new_prompt, new_token_ids, placeholder_ranges -def consecutive_placeholder_ranges(num_items: int, - item_size: int) -> List[PlaceholderRange]: +def consecutive_placeholder_ranges( + num_items: int, + item_size: int, + initial_offset: int = 0) -> List[PlaceholderRange]: """Returns a list of consecutive PlaceholderRanges of a fixed size""" return [ - PlaceholderRange(offset=i * item_size, length=item_size) - for i in range(num_items) + PlaceholderRange(offset=initial_offset + i * item_size, + length=item_size) for i in range(num_items) ] diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index f1f26f4e8d443..1203d35fc985f 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -73,12 +73,12 @@ def __init__( # has the Transformer architecture (e.g., ViT). # FIXME(woosuk): Below are placeholder values. We need to calculate the # actual values from the configurations. - self.max_num_encoder_input_tokens = 2048 + self.max_num_encoder_input_tokens = 16384 # NOTE(woosuk): For the models without encoder (e.g., text-only models), # the encoder cache will not be initialized and used, regardless of # the cache size. This is because the memory space for the encoder cache # is preallocated in the profiling run. 
- self.encoder_cache_manager = EncoderCacheManager(cache_size=2048) + self.encoder_cache_manager = EncoderCacheManager(cache_size=16384) def schedule(self) -> "SchedulerOutput": # NOTE(woosuk) on the scheduling algorithm: diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 312c0242a45dd..994e68669108e 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -1,5 +1,7 @@ from typing import Dict, List, Mapping, Optional, Type, Union +from typing_extensions import TypeVar + from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics_types import StatLoggerBase @@ -12,7 +14,8 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.transformers_utils.tokenizer_group import ( + BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import UsageContext from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.detokenizer import Detokenizer @@ -21,6 +24,8 @@ logger = init_logger(__name__) +_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) + class LLMEngine: """Legacy LLMEngine for backwards compatibility.""" @@ -169,5 +174,18 @@ def start_profile(self): def stop_profile(self): self.engine_core.profile(False) - def get_tokenizer_group(self, group_type): - pass + def get_tokenizer_group( + self, + group_type: Type[_G] = BaseTokenizerGroup, + ) -> _G: + tokenizer_group = self.tokenizer + + if tokenizer_group is None: + raise ValueError("Unable to get tokenizer because " + "skip_tokenizer_init is True") + if not isinstance(tokenizer_group, group_type): + raise TypeError("Invalid type of tokenizer group. " + f"Expected type: {group_type}, but " + f"found type: {type(tokenizer_group)}") + + return tokenizer_group diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_mapper.py index 45882f8f076d4..7ad6882b04520 100644 --- a/vllm/v1/engine/mm_input_mapper.py +++ b/vllm/v1/engine/mm_input_mapper.py @@ -33,7 +33,7 @@ def process_inputs( num_images = len(image_inputs) for i in range(num_images): mm_input = self.multi_modal_input_mapper( - {"image": [image_inputs[i]]}, + {"image": image_inputs[i]}, mm_processor_kwargs=mm_processor_kwargs, ) mm_inputs.append(mm_input)
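
The revised `get_multimodal_embeddings` docstring in interfaces.py admits two return formats: a list or tuple of per-image 2D tensors, or a single batched 3D tensor. The sketch below shows how a caller could normalize both into a list of per-image tensors; `as_per_image_embeddings` is a hypothetical helper written for illustration, not part of this patch, and the shapes are made up.

from typing import List, Sequence, Union

import torch

MultiModalEmbeddings = Union[torch.Tensor, Sequence[torch.Tensor]]


def as_per_image_embeddings(embeds: MultiModalEmbeddings) -> List[torch.Tensor]:
    """Normalize either accepted format into a list of 2D (tokens, hidden) tensors."""
    if isinstance(embeds, torch.Tensor):
        # Batched 3D form: (num_images, feature_size, hidden_size).
        assert embeds.dim() == 3
        return list(embeds.unbind(0))
    # List/tuple form: each element is already (feature_size_i, hidden_size).
    return list(embeds)


batched = torch.zeros(2, 16, 8)                    # same feature size per image
ragged = (torch.zeros(16, 8), torch.zeros(24, 8))  # feature size differs per image
assert len(as_per_image_embeddings(batched)) == 2
assert [e.shape[0] for e in as_per_image_embeddings(ragged)] == [16, 24]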
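
InternVL, Molmo, and Pixtral all gain loops that walk the expanded prompt and record `PlaceholderRange(offset=..., length=...)` entries for the image tokens. Below is a simplified, self-contained variant of that scanning pattern which just groups contiguous runs of image-related token ids; the in-tree versions additionally rely on the known per-image feature size (InternVL), a single start/end span over all image tokens (Molmo), or the `[IMG_END]` token (Pixtral) to delimit ranges. The token ids used here are arbitrary.

from typing import List, Sequence, Set, Tuple


def find_placeholder_ranges(prompt_token_ids: Sequence[int],
                            image_token_ids: Set[int]) -> List[Tuple[int, int]]:
    """Return an (offset, length) pair for each contiguous run of image tokens."""
    ranges: List[Tuple[int, int]] = []
    offset, length = -1, 0
    for idx, token_id in enumerate(prompt_token_ids):
        if token_id in image_token_ids:
            if offset < 0:
                offset = idx              # a new placeholder run starts here
            length += 1
        elif offset >= 0:
            ranges.append((offset, length))  # the current run just ended
            offset, length = -1, 0
    if offset >= 0:                          # run extends to the end of the prompt
        ranges.append((offset, length))
    return ranges


# Two images, each expanded to three placeholder tokens with id 9.
assert find_placeholder_ranges([1, 9, 9, 9, 2, 9, 9, 9, 3],
                               {9}) == [(1, 3), (5, 3)]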
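
InternVL's `_process_image_input` now flattens the vision-tower output and splits it back into one tensor per image using the recorded `patches_per_image`. A toy-sized sketch of that split, with illustrative dimensions rather than the model's real ones:

import torch

patches_per_image = [2, 5, 1]   # three images in one batch
feature_size = 4                # embedding rows produced per patch
hidden_size = 8

image_embeds = torch.randn(sum(patches_per_image), feature_size, hidden_size)

# Flatten to (total_patches * feature_size, hidden_size), then carve out one
# tensor per image, feature_size rows per patch.
flat = image_embeds.view(-1, hidden_size)
split_sizes = [n * feature_size for n in patches_per_image]
per_image = flat.split(split_sizes)

assert [t.shape[0] for t in per_image] == [8, 20, 4]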
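
`merge_multimodal_embeddings` in utils.py now accepts a list of placeholder token ids and builds the overwrite mask with `torch.isin`, which is why the multimodal embeddings must already be ordered (and padded with the `[IMG_START]`/`[IMG_BREAK]`/`[IMG_END]` embeddings, as Pixtral does) to line up with those positions. A simplified sketch of that path: `_merge` stands in for vLLM's private `_merge_multimodal_embeddings`, and the token ids are toy values, not the real Pixtral ids.

from typing import List

import torch


def _merge(inputs_embeds: torch.Tensor, is_multimodal: torch.Tensor,
           multimodal_embeddings: torch.Tensor) -> torch.Tensor:
    # Overwrite the masked rows in place; row counts must line up exactly.
    assert int(is_multimodal.sum()) == multimodal_embeddings.shape[0]
    inputs_embeds[is_multimodal] = multimodal_embeddings
    return inputs_embeds


def merge_with_token_list(input_ids: torch.Tensor, inputs_embeds: torch.Tensor,
                          multimodal_embeddings: torch.Tensor,
                          placeholder_token_ids: List[int]) -> torch.Tensor:
    placeholder = torch.tensor(placeholder_token_ids, device=input_ids.device)
    return _merge(inputs_embeds, torch.isin(input_ids, placeholder),
                  multimodal_embeddings)


# Toy ids: 1 = text, 7/8/9 = image start/break/end, 5 = image patch token,
# mirroring the docstring's "TTTTTSIIIBIIIBIIIETTT" layout.
input_ids = torch.tensor([1, 1, 7, 5, 5, 8, 5, 5, 9, 1])
inputs_embeds = torch.zeros(10, 4)
vision_embeds = torch.ones(7, 4)   # one row per placeholder position, in order
merged = merge_with_token_list(input_ids, inputs_embeds, vision_embeds,
                               [5, 7, 8, 9])
assert merged[2:9].eq(1).all() and merged[:2].eq(0).all()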