diff --git a/examples/offline_inference_whisper.py b/examples/offline_inference_whisper.py new file mode 100644 index 0000000000000..35326a435c76e --- /dev/null +++ b/examples/offline_inference_whisper.py @@ -0,0 +1,61 @@ +import time + +from vllm import LLM, SamplingParams +from vllm.assets.audio import AudioAsset + +dtype = "float" + +# Create a Whisper encoder/decoder model instance +llm = LLM( + model="openai/whisper-large-v3", + max_model_len=448, + max_num_seqs=400, + limit_mm_per_prompt={"audio": 1}, + kv_cache_dtype="fp8", +) + +prompts = [ + { + "prompt": "<|startoftranscript|>", + "multi_modal_data": { + "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate, + }, + }, + { # Test explicit encoder/decoder prompt + "encoder_prompt": { + "prompt": "", + "multi_modal_data": { + "audio": AudioAsset("winning_call").audio_and_sample_rate, + }, + }, + "decoder_prompt": "<|startoftranscript|>", + } +] * 1024 + +# Create a sampling params object. +sampling_params = SamplingParams( + temperature=0, + top_p=1.0, + max_tokens=200, +) + +start = time.time() + +# Generate output tokens from the prompts. The output is a list of +# RequestOutput objects that contain the prompt, generated +# text, and other information. +outputs = llm.generate(prompts, sampling_params) + +# Print the outputs. +for output in outputs: + prompt = output.prompt + encoder_prompt = output.encoder_prompt + generated_text = output.outputs[0].text + print(f"Encoder prompt: {encoder_prompt!r}, " + f"Decoder prompt: {prompt!r}, " + f"Generated text: {generated_text!r}") + +duration = time.time() - start + +print("Duration:", duration) +print("RPS:", len(prompts) / duration) diff --git a/tests/models/encoder_decoder/audio/__init__.py b/tests/models/encoder_decoder/audio/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/encoder_decoder/audio/test_whisper.py b/tests/models/encoder_decoder/audio/test_whisper.py new file mode 100644 index 0000000000000..81761d2a2b679 --- /dev/null +++ b/tests/models/encoder_decoder/audio/test_whisper.py @@ -0,0 +1,107 @@ +"""Compare the outputs of HF and vLLM for Whisper models using greedy sampling. + +Run `pytest tests/models/encoder_decoder/audio/test_whisper.py`. +""" +from typing import Optional + +import pytest + +from vllm import LLM, SamplingParams +from vllm.assets.audio import AudioAsset + +from ....utils import fork_new_process_for_each_test, multi_gpu_test + +PROMPTS = [ + { + "prompt": + "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", + "multi_modal_data": { + "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate, + }, + }, + { # Test explicit encoder/decoder prompt + "encoder_prompt": { + "prompt": "", + "multi_modal_data": { + "audio": AudioAsset("winning_call").audio_and_sample_rate, + }, + }, + "decoder_prompt": + "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", + } +] + +EXPECTED = { + "openai/whisper-medium": [ + " The first words I spoke in the original phonograph, a little piece" + " of practical poetry. Mary had a little lamb, its fleece was white as" + " snow, and everywhere that Mary went the lamb would shun it all.", + " And the old one pitch on the way to Edgar Martinez swung on the line" + " down the left field line for Obeysmith. Here comes Joy. Here is" + " Jorgen at third base. They're gonna wave him in. The throw to the" + " plate will be late. The Mariners are going to play for the American" + " League Championship. I don't believe it. It just continues. My, oh" + " my." 
+ ], + "openai/whisper-large-v3": [ + " The first words I spoke in the original phonograph. A little piece" + " of practical poetry. Mary had a little lamb, its fleece was white as" + " snow, and everywhere that Mary went, the lamb was sure to go.", + " And the 0-1 pitch on the way to Edgar Martinez. Swung on the line," + " down the left field line for a base hit. Here comes Joy. Here is" + " Junior to third base. They're going to wave him in. The throw to the" + " plate will be late. The Mariners are going to play for the American" + " League Championship. I don't believe it. It just continues. My, oh," + " my." + ] +} + + +def run_test( + model: str, + *, + enforce_eager: bool, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +) -> None: + prompt_list = PROMPTS * 10 + expected_list = EXPECTED[model] * 10 + + llm = LLM( + model=model, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=enforce_eager, + ) + + sampling_params = SamplingParams( + temperature=0, + top_p=1.0, + max_tokens=200, + ) + + outputs = llm.generate(prompt_list, sampling_params) + + for output, expected in zip(outputs, expected_list): + print(output.outputs[0].text) + assert output.outputs[0].text == expected + + +@fork_new_process_for_each_test +@pytest.mark.parametrize("model", + ["openai/whisper-medium", "openai/whisper-large-v3"]) +@pytest.mark.parametrize("enforce_eager", [True, False]) +def test_models(model, enforce_eager) -> None: + run_test(model, enforce_eager=enforce_eager, tensor_parallel_size=1) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("model", ["openai/whisper-large-v3"]) +@pytest.mark.parametrize("enforce_eager", [True, False]) +@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) +def test_models_distributed(model, enforce_eager, + distributed_executor_backend) -> None: + run_test(model, + enforce_eager=enforce_eager, + tensor_parallel_size=2, + distributed_executor_backend=distributed_executor_backend) diff --git a/tests/models/registry.py b/tests/models/registry.py index 819ef957a07f3..3c4f3951fb4d7 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -126,6 +126,7 @@ class _HfExamplesInfo: # [Encoder-decoder] "BartModel": _HfExamplesInfo("facebook/bart-base"), "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"), + "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501 # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer # Therefore, we borrow the BartTokenizer from the original Bart model "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501 diff --git a/vllm/config.py b/vllm/config.py index 0e886e18fcd6d..17d021c4f5918 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2252,6 +2252,8 @@ def _get_and_verify_max_len( "seq_length", # Command-R "model_max_length", + # Whisper + "max_target_positions", # Others "max_sequence_length", "max_seq_length", diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c3bc6becf0995..b3d396f9cedda 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1579,6 +1579,7 @@ def _preempt_by_recompute( seq.status = SequenceStatus.WAITING self.free_seq(seq) seq.reset_state_for_recompute() + self._free_seq_group_cross_attn_blocks(seq_group) def _preempt_by_swap( self, diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 3d606817e90aa..9ccae3ab11f75 
100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -184,10 +184,16 @@ def _tokenize_prompt( corresponding token IDs. """ tokenizer = self.get_tokenizer_group() - + add_special_tokens = None + if self.model_config.hf_config.model_type == "whisper": + # For Whisper, special tokens should be provided by the user based + # on the task and language of their request. Also needed to avoid + # appending an EOS token to the prompt which disrupts generation. + add_special_tokens = False return tokenizer.encode(request_id=request_id, prompt=prompt, - lora_request=lora_request) + lora_request=lora_request, + add_special_tokens=add_special_tokens) async def _tokenize_prompt_async( self, @@ -197,10 +203,17 @@ async def _tokenize_prompt_async( ) -> List[int]: """Async version of :meth:`_tokenize_prompt`.""" tokenizer = self.get_tokenizer_group() - - return await tokenizer.encode_async(request_id=request_id, - prompt=prompt, - lora_request=lora_request) + add_special_tokens = None + if self.model_config.hf_config.model_type == "whisper": + # For Whisper, special tokens should be provided by the user based + # on the task and language of their request. Also needed to avoid + # appending an EOS token to the prompt which disrupts generation. + add_special_tokens = False + return await tokenizer.encode_async( + request_id=request_id, + prompt=prompt, + lora_request=lora_request, + add_special_tokens=add_special_tokens) def _can_process_multimodal(self) -> bool: model_config = self.model_config @@ -439,8 +452,15 @@ def _build_enc_dec_llm_inputs( assert_never(encoder_inputs) if decoder_inputs is None: - dec_token_ids = self._prepare_decoder_input_ids_for_generation( - None) + if self.model_config.hf_config.model_type == "whisper": + # For Whisper models, the text prompt should go to the decoder. + # If no explicit encoder/decoder inputs, then copy the prompt + # from the encoder to the decoder. The encoder tokens are later + # overridden by the audio features. 
+ dec_token_ids = encoder_inputs["prompt_token_ids"].copy() + else: + dec_token_ids = self._prepare_decoder_input_ids_for_generation( + None) decoder_inputs = token_inputs(dec_token_ids) elif (decoder_inputs["type"] == "token" or decoder_inputs["type"] == "multimodal"): diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 04d806c3c7eae..af6e3b4eb244a 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -169,6 +169,7 @@ "UltravoxModel": ("ultravox", "UltravoxModel"), # [Encoder-decoder] "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 + "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501 } _SPECULATIVE_DECODING_MODELS = { diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py new file mode 100644 index 0000000000000..4207c4249aae5 --- /dev/null +++ b/vllm/model_executor/models/whisper.py @@ -0,0 +1,741 @@ +import math +from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, + Union) + +import numpy as np +import torch +from torch import nn +from transformers.models.whisper.modeling_whisper import sinusoids + +from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.inputs import INPUT_REGISTRY, DummyData, InputContext +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, + NestedTensors) +from vllm.multimodal.audio import resample_audio +from vllm.sequence import SequenceData +from vllm.transformers_utils.processor import cached_get_processor + +from .interfaces import SupportsMultiModal +from .utils import AutoWeightsLoader, WeightsMapper, make_layers + +logger = init_logger(__name__) + + +class WhisperAudioInputs(TypedDict): + input_features: NestedTensors + """Shape: `(batch_size, 128, M)`""" + + +class WhisperPositionalEmbedding(nn.Embedding): + + def __init__(self, + num_positions: int, + embedding_dim: int, + padding_idx: Optional[int] = None): + super().__init__(num_positions, embedding_dim) + + def forward(self, position_ids): + return self.weight[position_ids] + + +class WhisperAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + attn_type: AttentionType = AttentionType.DECODER, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.embed_dim = embed_dim + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + if self.total_num_heads >= tp_size: + # Number of 
heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_heads % tp_size == 0 + else: + # Number of heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_heads == 0 + self.num_kv_heads = max(1, self.total_num_heads // tp_size) + self.head_dim = self.embed_dim // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.attn_type = attn_type + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: " + f"{self.embed_dim} and `num_heads`: {num_heads}).") + self.scaling = self.head_dim**-0.5 + + self._init_qkv(embed_dim, bias, quant_config, prefix=prefix) + self.out_proj = RowParallelLinear( + input_size=embed_dim, + output_size=embed_dim, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + def _init_qkv( + self, + embed_dim: int, + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + self.qkv_proj = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ): + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + attn_output = self.attn(q, + k, + v, + kv_cache, + attn_metadata, + attn_type=self.attn_type) + + output, _ = self.out_proj(attn_output) + + return output + + +class WhisperCrossAttention(WhisperAttention): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__( + embed_dim=embed_dim, + num_heads=num_heads, + bias=bias, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + def _init_qkv( + self, + embed_dim: int, + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + self.q_proj = ColumnParallelLinear( + input_size=embed_dim, + output_size=embed_dim, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) + self.kv_proj = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.head_dim, + total_num_heads=0, + total_num_kv_heads=self.total_num_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.kv_proj", + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ): + q, _ = self.q_proj(hidden_states) + + # Encoder hidden states are only computed once during prefill phase. + # Afterwards, the keys and values should be available in the kv-cache. 
+ if encoder_hidden_states is not None: + kv, _ = self.kv_proj(encoder_hidden_states) + k, v = kv.split([self.kv_size, self.kv_size], dim=-1) + else: + k = v = None + + attn_output = self.attn(q, + k, + v, + kv_cache, + attn_metadata, + attn_type=AttentionType.ENCODER_DECODER) + + output, _ = self.out_proj(attn_output) + + return output + + +class WhisperMLP(nn.Module): + + def __init__( + self, + embed_dim: int, + ffn_dim: int, + act_fn: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.activation_fn = get_act_fn(act_fn) + self.fc1 = ColumnParallelLinear( + input_size=embed_dim, + output_size=ffn_dim, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + input_size=ffn_dim, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + + def forward(self, hidden_states: torch.Tensor): + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class WhisperEncoderLayer(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.embed_dim = config.d_model + self.self_attn = WhisperAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + attn_type=AttentionType.ENCODER, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.mlp = WhisperMLP( + embed_dim=config.d_model, + ffn_dim=config.encoder_ffn_dim, + act_fn=config.activation_function, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ): + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() + or torch.isnan(hidden_states).any()): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, + min=-clamp_value, + max=clamp_value) + + return hidden_states + + +class WhisperDecoderLayer(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.self_attn = WhisperAttention( + embed_dim=config.d_model, + num_heads=config.decoder_attention_heads, + attn_type=AttentionType.DECODER, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.self_attn_layer_norm = nn.LayerNorm(config.d_model) + self.encoder_attn = WhisperCrossAttention( + embed_dim=config.d_model, + num_heads=config.decoder_attention_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.encoder_attn", + ) + 
self.encoder_attn_layer_norm = nn.LayerNorm(config.d_model) + self.mlp = WhisperMLP( + embed_dim=config.d_model, + ffn_dim=config.decoder_ffn_dim, + act_fn=config.activation_function, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.final_layer_norm = nn.LayerNorm(config.d_model) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ): + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn(hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + hidden_states = self.encoder_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class WhisperEncoder(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + embed_dim = config.d_model + self.num_mel_bins = config.num_mel_bins + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_source_positions + self.embed_scale = (math.sqrt(embed_dim) + if config.scale_embedding else 1.0) + + self.conv1 = nn.Conv1d(self.num_mel_bins, + embed_dim, + kernel_size=3, + padding=1) + self.conv2 = nn.Conv1d(embed_dim, + embed_dim, + kernel_size=3, + stride=2, + padding=1) + self.embed_positions = nn.Embedding(self.max_source_positions, + embed_dim) + self.start_layer, self.end_layer, self.layers = make_layers( + config.encoder_layers, + lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config, + prefix=f"{prefix}.layers"), + prefix=f"{prefix}.layers", + ) + self.layer_norm = nn.LayerNorm(config.d_model) + + with torch.no_grad(): + self.embed_positions.weight.copy_( + sinusoids(*self.embed_positions.weight.shape)) + + def forward( + self, + input_features: Union[torch.Tensor, List[torch.Tensor]], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ): + hidden_states = [] + for features in input_features: + embeds = nn.functional.gelu(self.conv1(features)) + embeds = nn.functional.gelu(self.conv2(embeds)) + embeds = embeds.permute(1, 0) + embeds = embeds + self.embed_positions.weight[:embeds.size(0), :] + hidden_states.append(embeds) + hidden_states = torch.cat(hidden_states) + + for idx, encoder_layer in enumerate(self.layers): + hidden_states = encoder_layer( + hidden_states, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, + ) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + +class WhisperDecoder(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_target_positions + self.max_source_positions = config.max_source_positions + self.embed_scale = (math.sqrt(config.d_model) + if config.scale_embedding else 1.0) + + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, + self.padding_idx) + 
self.embed_positions = WhisperPositionalEmbedding( + self.max_target_positions, config.d_model) + self.start_layer, self.end_layer, self.layers = make_layers( + config.decoder_layers, + lambda prefix: WhisperDecoderLayer(vllm_config=vllm_config, + prefix=f"{prefix}.layers"), + prefix=f"{prefix}.layers", + ) + self.layer_norm = nn.LayerNorm(config.d_model) + + def forward( + self, + input_ids, + positions: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ): + inputs_embeds = self.get_input_embeddings(input_ids) + positions = self.embed_positions(positions) + hidden_states = inputs_embeds + positions + + for idx, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, + ) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + return self.embed_tokens(input_ids) + + +class WhisperModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.encoder = WhisperEncoder(vllm_config=vllm_config, + prefix=f"{prefix}.encoder") + self.decoder = WhisperDecoder(vllm_config=vllm_config, + prefix=f"{prefix}.decoder") + + def forward( + self, + input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]], + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + encoder_outputs = self.get_encoder_outputs( + input_features, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + decoder_outputs = self.decoder( + input_ids=input_ids, + positions=positions, + encoder_hidden_states=encoder_outputs, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + return decoder_outputs + + def get_encoder_outputs( + self, + input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> Optional[torch.Tensor]: + if input_features is None: + return None + return self.encoder( + input_features, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".self_attn.qkv_proj", ".self_attn.q_proj", "q"), + (".self_attn.qkv_proj", ".self_attn.k_proj", "k"), + (".self_attn.qkv_proj", ".self_attn.v_proj", "v"), + (".encoder_attn.kv_proj", ".encoder_attn.k_proj", "k"), + (".encoder_attn.kv_proj", ".encoder_attn.v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +def get_max_whisper_audio_tokens(ctx: InputContext) -> int: + return ctx.model_config.hf_config.max_source_positions + + +def dummy_encoder_data_for_whisper(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + assert mm_counts["audio"] == 1 + sample_rate = 16000 + num_tokens = ctx.model_config.hf_config.max_source_positions + return DummyData( + SequenceData.from_prompt_token_counts((0, num_tokens)), + {"audio": [(np.zeros(30 * sample_rate), sample_rate)]}, + ) + + +def input_processor_for_whisper(ctx: InputContext, inputs): + multi_modal_data = inputs["encoder"]["multi_modal_data"] + if isinstance(multi_modal_data["audio"], list): + assert len(multi_modal_data["audio"]) == 1 + multi_modal_data["audio"] = multi_modal_data["audio"][0] + # Resample and process audio + audio, orig_sr = multi_modal_data["audio"] + processor = cached_get_processor(ctx.model_config.model) + target_sr = processor.feature_extractor.sampling_rate + audio = resample_audio(audio, orig_sr=orig_sr, target_sr=target_sr) + if audio.size > 30 * target_sr: + # Truncate audio to 30 seconds + audio = audio[:30 * target_sr] + multi_modal_data["audio"] = (audio, target_sr) + # Calculate number of tokens after feature extraction and convolutions + num_tokens = (audio.size // 160 - 1) // 2 + 1 + # Pre-allocate placeholder tokens in encoder sequence + inputs["encoder"]["prompt_token_ids"] = [0] * num_tokens + return inputs + + +def input_mapper_for_whisper( + ctx: InputContext, + multi_modal_data: Union[np.ndarray, List[np.ndarray]], +) -> MultiModalKwargs: + if not isinstance(multi_modal_data, list): + multi_modal_data = [multi_modal_data] + + assert len(multi_modal_data) == 1 + + if len(multi_modal_data) == 0: + return MultiModalKwargs() + + processor = cached_get_processor(ctx.model_config.model) + sampling_rate = processor.feature_extractor.sampling_rate + + audios = [audio for audio, _ in multi_modal_data] + + kwargs = processor(audios, + sampling_rate=sampling_rate, + padding=False, + return_tensors="pt") + kwargs["input_features"] = kwargs["input_features"].squeeze(0) + kwargs["input_features"] = kwargs["input_features"].to(torch.float16) + + return MultiModalKwargs(kwargs) + + +@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_whisper) +@INPUT_REGISTRY.register_input_processor(input_processor_for_whisper) +@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper) +@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( + "audio", get_max_whisper_audio_tokens) +class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.dtype = vllm_config.model_config.dtype + + self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix) + self.unpadded_vocab_size = config.vocab_size + self.proj_out = ParallelLMHead(config.vocab_size, + config.d_model, + quant_config=quant_config) + self.proj_out = self.proj_out.tie_weights( + self.model.decoder.embed_tokens) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, logit_scale) 
+ self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + **kwargs, + ) -> torch.Tensor: + audio_input = self._parse_and_validate_audio_input(**kwargs) + decoder_outputs = self.model( + input_features=audio_input["input_features"], + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + return decoder_outputs + + def get_multimodal_embeddings( + self, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + **kwargs, + ) -> Optional[NestedTensors]: + # TODO: This method does not obey the interface for SupportsMultiModal. + # Refactor this once encoder/decoder support is implemented in V1. + audio_input = self._parse_and_validate_audio_input(**kwargs) + return self.model.get_encoder_outputs( + audio_input["input_features"], + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> torch.Tensor: + # TODO: This method just returns the decoder sequence embeddings since + # Whisper does not have encoder text tokens. Refactor this once + # encoder/decoder support is implemented in V1. + return self.model.decoder.get_input_embeddings(input_ids) + + def _parse_and_validate_audio_input( + self, **kwargs: object) -> WhisperAudioInputs: + input_features = kwargs.pop("input_features", None) + + if input_features is not None: + if not isinstance(input_features, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio features. " + f"Got type: {type(input_features)}") + input_features = [feat.to(self.dtype) for feat in input_features] + + return WhisperAudioInputs(input_features=input_features) + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.proj_out, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."]) + loaded_weights = [(name, loaded_weight) + for name, loaded_weight in weights] + mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}) + return loader.load_weights(loaded_weights, mapper=mapper) \ No newline at end of file diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 6baf19d675d50..f75b835e8464b 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -14,7 +14,7 @@ from vllm.inputs import DummyData, InputProcessingContext from vllm.logger import init_logger -from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer, encode_tokens from vllm.utils import flatten_2d_lists, full_groupby, is_list_of from .audio import resample_audio @@ -55,24 +55,6 @@ def bind(self, tokenizer: AnyTokenizer) -> "_BoundPromptReplacement": ) -def _encode( - tokenizer: AnyTokenizer, - text: str, - *, - add_special_tokens: bool = False, -) -> list[int]: - """ - Backend-agnostic equivalent of HF's - :code:`tokenizer.encode(text, add_special_tokens=...)`. 
- """ - if isinstance(tokenizer, MistralTokenizer): - return tokenizer.tokenizer.encode(text, - bos=add_special_tokens, - eos=add_special_tokens) - - return tokenizer.encode(text, add_special_tokens=add_special_tokens) - - @lru_cache(maxsize=2048) def _cached_encode( tokenizer: AnyTokenizer, @@ -80,7 +62,9 @@ def _cached_encode( *, add_special_tokens: bool = False, ) -> list[int]: - return _encode(tokenizer, text, add_special_tokens=add_special_tokens) + return encode_tokens(tokenizer, + text, + add_special_tokens=add_special_tokens) def _decode( @@ -763,7 +747,9 @@ def _apply_prompt_replacements( mm_item_counts, ) - token_ids = _encode(tokenizer, text) + token_ids = encode_tokens(tokenizer, + text, + add_special_tokens=False) matched_repls = [match.prompt_repl for match in text_matches] placeholders = self._find_placeholders(matched_repls, token_ids, diff --git a/vllm/sequence.py b/vllm/sequence.py index cc3d96fc93a79..c4f4032a7bac7 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -709,15 +709,27 @@ def token_type_ids(self) -> Optional[List[int]]: @property def multi_modal_data(self) -> MultiModalDataDict: - return self.first_seq.multi_modal_data + if self.first_seq.multi_modal_data: + return self.first_seq.multi_modal_data + elif self.encoder_seq is not None: + return self.encoder_seq.multi_modal_data + return {} @property def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: - return self.first_seq.multi_modal_placeholders + if self.first_seq.multi_modal_data: + return self.first_seq.multi_modal_placeholders + elif self.encoder_seq is not None: + return self.encoder_seq.multi_modal_placeholders + return {} @property def mm_processor_kwargs(self) -> Dict[str, Any]: - return self.first_seq.mm_processor_kwargs + if self.first_seq.multi_modal_data: + return self.first_seq.mm_processor_kwargs + elif self.encoder_seq is not None: + return self.encoder_seq.mm_processor_kwargs + return {} @property def lora_int_id(self) -> int: diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index e6701f4c4b835..42b2f095bc543 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -21,6 +21,25 @@ MistralTokenizer] +def encode_tokens( + tokenizer: AnyTokenizer, + text: str, + *, + add_special_tokens: Optional[bool] = None, +) -> list[int]: + """ + Backend-agnostic equivalent of HF's + :code:`tokenizer.encode(text, add_special_tokens=...)`. + """ + if isinstance(tokenizer, MistralTokenizer): + return tokenizer.tokenizer.encode(text, + bos=add_special_tokens, + eos=add_special_tokens) + elif add_special_tokens is not None: + return tokenizer.encode(text, add_special_tokens=add_special_tokens) + return tokenizer.encode(text) + + def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: """Get tokenizer with cached properties. 
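The `encode_tokens` helper introduced above is what lets the Whisper branch of `InputPreprocessor._tokenize_prompt` opt out of HF's default special-token handling. A minimal usage sketch, not part of the patch, assuming the Whisper tokenizer can be loaded with `get_tokenizer` from the same module:

    from vllm.transformers_utils.tokenizer import encode_tokens, get_tokenizer

    tokenizer = get_tokenizer("openai/whisper-large-v3")
    prompt = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"

    # add_special_tokens=None falls through to the tokenizer's default, which for
    # Whisper prepends its own prefix tokens and appends <|endoftext|> -- the EOS
    # that disrupts generation when these ids are used as the decoder prompt.
    default_ids = encode_tokens(tokenizer, prompt)

    # The Whisper path passes add_special_tokens=False, so the user-supplied
    # task/language tokens are kept verbatim and no EOS is appended.
    raw_ids = encode_tokens(tokenizer, prompt, add_special_tokens=False)

    print(len(default_ids), len(raw_ids))

This is why the `add_special_tokens` keyword is threaded through the tokenizer group classes in the hunks below.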
diff --git a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py index 8f78ef65bbf1a..e6cc7cd4e2e3a 100644 --- a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py @@ -32,7 +32,8 @@ def get_max_input_len( def encode(self, prompt: str, request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: + lora_request: Optional[LoRARequest] = None, + add_special_tokens: Optional[bool] = None) -> List[int]: """Encode a prompt using the tokenizer group.""" pass @@ -41,7 +42,8 @@ async def encode_async( self, prompt: str, request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: + lora_request: Optional[LoRARequest] = None, + add_special_tokens: Optional[bool] = None) -> List[int]: """Encode a prompt using the tokenizer group.""" pass diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py index 9a999a0d6067d..3f7627e11ae5e 100644 --- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py @@ -112,7 +112,8 @@ def _finalize_encode(self, actor: ray.ObjectRef, def encode(self, prompt: str, request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: + lora_request: Optional[LoRARequest] = None, + add_special_tokens: Optional[bool] = None) -> List[int]: """Encode a prompt using the tokenizer group. We pick an idle actor and use it to encode the prompt. @@ -132,7 +133,8 @@ def encode(self, ret = ray.get( actor.encode.remote(request_id=request_id, prompt=prompt, - lora_request=lora_request)) + lora_request=lora_request, + add_special_tokens=add_special_tokens)) except ActorDiedError as e: # If the actor is dead, we first try to reinitialize it. logger.warning("%s died with ActorDiedError, reinitializing.", @@ -143,7 +145,8 @@ def encode(self, ret = ray.get( actor.encode.remote(request_id=request_id, prompt=prompt, - lora_request=lora_request)) + lora_request=lora_request, + add_special_tokens=add_special_tokens)) except ActorDiedError as e: logger.error( "%s died for second time in a row, marking " @@ -160,7 +163,8 @@ async def encode_async( self, prompt: str, request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: + lora_request: Optional[LoRARequest] = None, + add_special_tokens: Optional[bool] = None) -> List[int]: """Encode a prompt using the tokenizer group. We pick an idle actor and use it to encode the prompt. @@ -177,9 +181,11 @@ async def encode_async( actor_is_alive = True original_actor = actor try: - ret = await actor.encode.remote(request_id=request_id, - prompt=prompt, - lora_request=lora_request) + ret = await actor.encode.remote( + request_id=request_id, + prompt=prompt, + lora_request=lora_request, + add_special_tokens=add_special_tokens) except ActorDiedError as e: # If the actor is dead, we first try to reinitialize it. 
logger.warning("%s died with ActorDiedError, reinitializing.", @@ -187,9 +193,11 @@ async def encode_async( exc_info=e) actor = self._init_actor() try: - ret = await actor.encode.remote(request_id=request_id, - prompt=prompt, - lora_request=lora_request) + ret = await actor.encode.remote( + request_id=request_id, + prompt=prompt, + lora_request=lora_request, + add_special_tokens=add_special_tokens) except ActorDiedError as e: logger.error( "%s died for second time in a row, marking " diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index 95a8f7098bbac..6dc2f90561873 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -2,7 +2,7 @@ from vllm.config import TokenizerPoolConfig from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizer import (AnyTokenizer, +from vllm.transformers_utils.tokenizer import (AnyTokenizer, encode_tokens, get_lora_tokenizer, get_lora_tokenizer_async, get_tokenizer) @@ -55,9 +55,12 @@ def _raise_if_input_too_long(self, def encode(self, prompt: str, request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: + lora_request: Optional[LoRARequest] = None, + add_special_tokens: Optional[bool] = None) -> List[int]: tokenizer = self.get_lora_tokenizer(lora_request) - ret = tokenizer.encode(prompt) + ret = encode_tokens(tokenizer, + prompt, + add_special_tokens=add_special_tokens) self._raise_if_input_too_long(ret, lora_request) return ret @@ -65,9 +68,12 @@ async def encode_async( self, prompt: str, request_id: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> List[int]: + lora_request: Optional[LoRARequest] = None, + add_special_tokens: Optional[bool] = None) -> List[int]: tokenizer = await self.get_lora_tokenizer_async(lora_request) - ret = tokenizer.encode(prompt) + ret = encode_tokens(tokenizer, + prompt, + add_special_tokens=add_special_tokens) self._raise_if_input_too_long(ret, lora_request) return ret diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index bff01320d7927..4d5d918087be8 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -287,12 +287,11 @@ def profile_run(self) -> None: seq_len, self.mm_registry, is_encoder_data=False) - encoder_dummy_data \ - = self.input_registry.dummy_data_for_profiling( - self.model_config, - seq_len, - self.mm_registry, - is_encoder_data=True) + encoder_dummy_data = self.input_registry \ + .dummy_data_for_profiling(self.model_config, + seq_len, + self.mm_registry, + is_encoder_data=True) # Having more tokens is over-conservative but otherwise fine assert len(