From cc2df563d70648996db880c75bf30b6bc9a05638 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Thu, 20 Feb 2025 13:22:09 -0800 Subject: [PATCH 01/25] update ultravox to accept more than 30s audio Signed-off-by: Farzad Abdolhosseini --- vllm/model_executor/models/ultravox.py | 194 ++++++++++++++++--------- 1 file changed, 122 insertions(+), 72 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index e24b4aeb8ae84..f59c233883ce2 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -2,7 +2,6 @@ # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py """PyTorch Ultravox model.""" -import math from functools import cached_property from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) @@ -47,8 +46,18 @@ class UltravoxAudioFeatureInputs(TypedDict): type: Literal["audio_features"] - data: NestedTensors + data: torch.Tensor """Shape: `(batch_size, num_audios, 80, M)`""" + lens: torch.Tensor + """ + Length of the audio frames. Used for attention mask in WhisperEncoder. + Shape: `(batch_size)` + """ + token_len: torch.Tensor + """ + Length of the audio tokens. Used for flattening the audio features. + Shape: `(batch_size)` + """ class UltravoxAudioEmbeddingInputs(TypedDict): @@ -76,7 +85,9 @@ def get_hf_processor( # placeholder that will cause confusion with the actual end of turn # token, thus we override placeholder with a reserved special # token. - hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE + hf_processor.audio_replacement_token_id = _AUDIO_PLACEHOLDER_TOKEN + hf_processor.audio_replacement = _AUDIO_PLACEHOLDER_OVERRIDE + # TODO: the old code currently does not work with the new processor return hf_processor def get_feature_extractor( @@ -99,11 +110,7 @@ def get_mm_max_tokens_per_item( seq_len: int, mm_counts: Mapping[str, int], ) -> Mapping[str, int]: - feature_extractor = self.get_feature_extractor() - max_audio_tokens = math.ceil(feature_extractor.chunk_length * - _AUDIO_TOKENS_PER_SECOND) - - return {"audio": max_audio_tokens} + return {} class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo] @@ -158,32 +165,19 @@ def _call_hf_processor( mm_kwargs = dict( **mm_kwargs, sampling_rate=feature_extractor.sampling_rate, + include_audio_num_chunks=True, ) - # Ultravox processor doesn't support multiple inputs, - # therefore we need to input text and audio one by one - audio_features, audio_token_len = [], [] - shared_outputs = {} - for audio in audios: - # NOTE: Ultravox processor accepts "audio" instead of "audios" - item_processor_data = dict(**mm_data, audio=audio) - - item_outputs = super()._call_hf_processor( - prompt=prompt, - mm_data=item_processor_data, - mm_kwargs=mm_kwargs, - ) + item_processor_data = dict(**mm_data, audios=audios) - audio_features.append(item_outputs.pop("audio_values")[0]) - audio_token_len.append(item_outputs.pop("audio_token_len").item()) - shared_outputs = item_outputs - - combined_outputs = dict( - **shared_outputs, - audio_features=audio_features, - audio_token_len=audio_token_len, + output = super()._call_hf_processor( + prompt=prompt, + mm_data=item_processor_data, + mm_kwargs=mm_kwargs, ) - return BatchFeature(combined_outputs) + output['audio_features'] = output.pop('audio_values') + + return output def _apply_hf_processor_tokens_only( self, @@ -201,8 +195,14 @@ def 
_get_mm_fields_config(
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         return dict(
-            audio_features=MultiModalFieldConfig.batched("audio"),
-            audio_token_len=MultiModalFieldConfig.batched("audio"),
+            # to handle longer than 30s audio, each audio might be split
+            # into multiple chunks as such, their batch dimension can be
+            # higher than the number of audio samples
+            audio_features=MultiModalFieldConfig.batched("audio_chunked"),
+            audio_token_len=MultiModalFieldConfig.batched("audio_chunked"),
+            audio_lens=MultiModalFieldConfig.batched("audio_chunked"),
+            # num_chunks can convert audio_chunked to audio batch dimension
+            audio_num_chunks=MultiModalFieldConfig.batched("audio"),
             audio_embeds=MultiModalFieldConfig.batched("audio"),
         )
 
@@ -213,14 +213,22 @@ def _get_prompt_replacements(
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
-        tokenizer = self.info.get_tokenizer()
-        vocab = tokenizer.get_vocab()
-        replacement_id = vocab[
-            hf_processor.audio_token_replacement]  # type: ignore
+        replacement_id = hf_processor.audio_replacement_token_id  # type: ignore
+
+        # Each audio can be split into multiple chunks.
+        # chunks_start_idx[i] indicates the start index of the chunks
+        # belonging to the i-th audio.
+        chunks_start_idx: torch.Tensor = torch.cumsum(
+            out_mm_kwargs["audio_num_chunks"], dim=0, dtype=torch.int32)
+        chunks_start_idx = torch.cat(
+            [torch.tensor([0], dtype=torch.int32), chunks_start_idx])
+        out_mm_kwargs.pop("audio_num_chunks", None)
 
         def get_replacement_ultravox(item_idx: int):
-            audio_token_len = out_mm_kwargs["audio_token_len"][item_idx]
+            start = chunks_start_idx[item_idx]
+            end = chunks_start_idx[item_idx + 1]
+            audio_token_len = out_mm_kwargs["audio_token_len"][start:end].sum()
             return [replacement_id] * int(audio_token_len)  # type: ignore
 
         return [
@@ -312,12 +320,49 @@ class ModifiedWhisperEncoder(WhisperEncoder):
 
     base_model_prefix = "model.encoder"
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.config.is_decoder = False
+
+    @property
+    def max_context_length(self):
+        return (self.config.max_source_positions * self.conv1.stride[0] *
+                self.conv2.stride[0])
+
+    def get_attention_mask_by_audio_len(self,
+                                        audio_lens: Optional[torch.Tensor],
+                                        hidden_states: torch.Tensor):
+        """
+        Create an attention mask from audio lengths to mask out padding tokens.
+        For each sample in the batch:
+        - Convert raw audio length to feature length after convolutions
+        - Create bool mask: True for valid positions and False for padding
+        - Convert to attention mask format expected by transformer layers
+        (1.0 for positions to attend to, large negative for positions to ignore)
+        This masking ensures consistent behavior between training and inference
+        by preventing the model from attending to padding tokens in both cases.
+        """
+        if audio_lens is None:
+            return None
+
+        audio_feature_len = self._get_feat_extract_output_lengths(audio_lens)
+        max_seq_len = hidden_states.shape[1]
+        attention_mask = torch.arange(max_seq_len,
+                                      device=hidden_states.device)[None, :].lt(
+                                          audio_feature_len.view(-1, 1))
+        attention_mask = self.get_extended_attention_mask(
+            attention_mask,
+            None,
+            dtype=hidden_states.dtype,
+        )
+        return attention_mask
+
     def forward(
         self,
-        input_features,
+        input_features: torch.Tensor,
+        audio_lens: Optional[torch.Tensor] = None,
     ):
-        expected_seq_length = (self.config.max_source_positions *
-                               self.conv1.stride[0] * 
self.conv2.stride[0]) + expected_seq_length = self.max_context_length if input_features.shape[-1] > expected_seq_length: raise ValueError( f"Whisper expects the mel input features to be of length " @@ -336,10 +381,13 @@ def forward( p=self.dropout, training=self.training) + attention_mask = self.get_attention_mask_by_audio_len( + audio_lens, hidden_states) + for encoder_layer in self.layers: layer_outputs = encoder_layer( hidden_states, - None, + attention_mask, layer_head_mask=None, ) @@ -425,9 +473,10 @@ def get_mm_mapping(self) -> MultiModelKeys: ) def _audio_features_to_embeddings( - self, input_features: torch.Tensor) -> torch.Tensor: + self, input_features: torch.Tensor, + audio_lens: Optional[torch.Tensor]) -> torch.Tensor: audio_input = input_features.to(self.audio_tower.dtype) - audio_features = self.audio_tower(audio_input) + audio_features = self.audio_tower(audio_input, audio_lens) audio_features = audio_features.to(self.audio_tower.dtype) audio_embeddings = self.multi_modal_projector(audio_features) return audio_embeddings @@ -436,6 +485,8 @@ def _parse_and_validate_audio_input( self, **kwargs: object) -> Optional[UltravoxAudioInputs]: audio_features = kwargs.pop("audio_features", None) audio_embeds = kwargs.pop("audio_embeds", None) + audio_lens = kwargs.pop("audio_lens", None) + audio_token_len = kwargs.pop("audio_token_len", None) if audio_features is None and audio_embeds is None: return None @@ -446,7 +497,9 @@ def _parse_and_validate_audio_input( f"Got type: {type(audio_features)}") return UltravoxAudioFeatureInputs(type="audio_features", - data=audio_features) + data=audio_features, + lens=audio_lens, + token_len=audio_token_len) if audio_embeds is not None: if not isinstance(audio_embeds, (torch.Tensor, list)): @@ -463,34 +516,27 @@ def _process_audio_input( if audio_input["type"] == "audio_embeds": return audio_input["data"] - audio_features = audio_input["data"] - if isinstance(audio_features, torch.Tensor): - # Combine the B and N dimensions for the encoder/projector - flattened = flatten_bn(audio_features) - flattened_embeddings = self._audio_features_to_embeddings( - flattened) - - # Restore the original dimensions - embeddings = flattened_embeddings.unflatten( - 0, audio_features.shape[:2]) - return embeddings - - result = [] - # TODO: Batch heterogeneous tensors through the encoder/projector - for audio_features_item in audio_features: - if isinstance(audio_features_item, torch.Tensor): - result.append( - self._audio_features_to_embeddings(audio_features_item)) - else: - embeddings = [ - # Add a batch dimension to embed it, then remove it. 
-                    self._audio_features_to_embeddings(tensor.unsqueeze(0)
-                                                       ).squeeze(0)
-                    for tensor in audio_features_item
-                ]
-                result.append(embeddings)
+        # remove unneeded extra dimension added to all elements of mm_kwargs
+        audio_features = flatten_bn(audio_input["data"])
+        audio_lens = flatten_bn(audio_input["lens"])
+        audio_token_len = flatten_bn(audio_input["token_len"])
+
+        embeddings = self._audio_features_to_embeddings(
+            audio_features, audio_lens)
 
-        return result
+        # Flatten and concatenate the embeddings based on token lengths.
+        # For example, with token_len = [4, 2, 3], flattened_embeddings will be
+        # concat(embeddings[0][:4], embeddings[1][:2], embeddings[2][:3])
+
+        # Create a mask of valid indices based on token lengths
+        max_len = embeddings.shape[1]
+        indices = torch.arange(max_len, device=embeddings.device).expand(
+            embeddings.shape[0], -1)
+        mask = indices < audio_token_len[:, None]
+        # Apply mask and flatten
+        flattened_embeddings = embeddings[mask]
+
+        return flattened_embeddings
 
     def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
         audio_input = self._parse_and_validate_audio_input(**kwargs)
@@ -537,7 +583,11 @@ def forward(self,
         with the `input_ids`.
 
         Args:
-            audio_features: A batch of audio inputs [B, N, 80, M].
+            audio_features: A batch of audio input chunks [B, N, 80, M].
+            audio_lens: Length of audio frames for each audio chunk [B].
+            audio_token_len: Length of audio tokens for each audio chunk [B'].
+                Note: this batch dim can differ from the audio chunk batch dim.
+
         """
 
         if intermediate_tensors is not None:

From c7e0329ba84c1d6f02e8b3754bd5993d41734b06 Mon Sep 17 00:00:00 2001
From: Farzad Abdolhosseini
Date: Thu, 20 Feb 2025 13:22:47 -0800
Subject: [PATCH 02/25] temporarily use model with updated processor for tests

Signed-off-by: Farzad Abdolhosseini
---
 tests/models/decoder_only/audio_language/test_ultravox.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py
index d1f643a8fdb73..50d48838b317f 100644
--- a/tests/models/decoder_only/audio_language/test_ultravox.py
+++ b/tests/models/decoder_only/audio_language/test_ultravox.py
@@ -15,7 +15,7 @@
 from ....utils import RemoteOpenAIServer
 from ...utils import check_logprobs_close
 
-MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
+MODEL_NAME = "fixie-ai/ultravox-v0_3-llama-3_2-1b"
 
 AudioTuple = Tuple[np.ndarray, int]

From 0c5363eebbb0f9ca6304944c811ba0557fd9faa6 Mon Sep 17 00:00:00 2001
From: Farzad Abdolhosseini
Date: Fri, 21 Feb 2025 13:01:40 -0800
Subject: [PATCH 03/25] fix collation

Signed-off-by: Farzad Abdolhosseini
---
 vllm/model_executor/models/ultravox.py | 38 ++++++++++++++++++++------
 1 file changed, 30 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index f59c233883ce2..ab9a74767335f 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -2,6 +2,7 @@
 # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
 """PyTorch Ultravox model."""
 
+import math
 from functools import cached_property
 from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
                     Tuple, TypedDict, Union)
@@ -46,14 +47,14 @@ class UltravoxAudioFeatureInputs(TypedDict):
 
     type: Literal["audio_features"]
-    data: torch.Tensor
+    data: NestedTensors
     """Shape: `(batch_size, num_audios, 80, M)`"""
-    
lens: torch.Tensor + lens: NestedTensors """ Length of the audio frames. Used for attention mask in WhisperEncoder. Shape: `(batch_size)` """ - token_len: torch.Tensor + token_len: NestedTensors """ Length of the audio tokens. Used for flattening the audio features. Shape: `(batch_size)` @@ -110,7 +111,11 @@ def get_mm_max_tokens_per_item( seq_len: int, mm_counts: Mapping[str, int], ) -> Mapping[str, int]: - return {} + feature_extractor = self.get_feature_extractor() + max_audio_tokens = math.ceil(feature_extractor.chunk_length * + _AUDIO_TOKENS_PER_SECOND) + + return {"audio": max_audio_tokens} class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo] @@ -422,6 +427,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config + # Due to the batching of audio chunks, the preprocessor cache cannot + # do the right thing so disable it. + vllm_config.model_config.disable_mm_preprocessor_cache = True multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multi_modal_config = multimodal_config @@ -516,10 +524,24 @@ def _process_audio_input( if audio_input["type"] == "audio_embeds": return audio_input["data"] - # remove unneeded extra dimension added to all elements of mm_kwargs - audio_features = flatten_bn(audio_input["data"]) - audio_lens = flatten_bn(audio_input["lens"]) - audio_token_len = flatten_bn(audio_input["token_len"]) + audio_features = audio_input["data"] + if isinstance(audio_features, list): + max_len = max(x.shape[-1] for x in audio_features) + # Pad and concatenate: + # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)] + audio_features = torch.cat( + [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_features]) + else: + # Flatten [B, N, 80, M] -> [B * N, 80, M] + audio_features = flatten_bn(audio_features) + + if isinstance(audio_input['lens'], list): + # [B1, B2] -> [B1+B2] + audio_lens = torch.cat(audio_input['lens']) + audio_token_len = torch.cat(audio_input['token_len']) + else: + audio_lens = flatten_bn(audio_input['lens']) + audio_token_len = flatten_bn(audio_input['token_len']) embeddings = self._audio_features_to_embeddings( audio_features, audio_lens) From 189f5cc23b3529472503839ae8548901692fd20b Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Mon, 24 Feb 2025 18:03:08 -0800 Subject: [PATCH 04/25] revert audio_replacement -> audio_token_replacement Signed-off-by: Farzad Abdolhosseini --- vllm/model_executor/models/ultravox.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index ab9a74767335f..a4bff36ed7cfd 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -48,16 +48,16 @@ class UltravoxAudioFeatureInputs(TypedDict): type: Literal["audio_features"] data: NestedTensors - """Shape: `(batch_size, num_audios, 80, M)`""" + """Shape: `(batch_size, num_chunks, 80, M)`""" lens: NestedTensors """ Length of the audio frames. Used for attention mask in WhisperEncoder. - Shape: `(batch_size)` + Shape: `(batch_size, num_chunks)` """ token_len: NestedTensors """ Length of the audio tokens. Used for flattening the audio features. 
- Shape: `(batch_size)` + Shape: `(batch_size, num_chunks)` """ @@ -86,9 +86,10 @@ def get_hf_processor( # placeholder that will cause confusion with the actual end of turn # token, thus we override placeholder with a reserved special # token. - hf_processor.audio_replacement_token_id = _AUDIO_PLACEHOLDER_TOKEN - hf_processor.audio_replacement = _AUDIO_PLACEHOLDER_OVERRIDE - # TODO: the old code currently does not work with the new processor + hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE + vocab = hf_processor.tokenizer.get_vocab() + hf_processor.audio_replacement_token_id = vocab[ + _AUDIO_PLACEHOLDER_OVERRIDE] return hf_processor def get_feature_extractor( From bc3ba8c882a3aeeaaea2b364df60490f492c028d Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Mon, 24 Feb 2025 23:25:22 -0800 Subject: [PATCH 05/25] increase max mm tokens Signed-off-by: Farzad Abdolhosseini --- vllm/model_executor/models/ultravox.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index a4bff36ed7cfd..b52da75c9f96c 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -116,7 +116,7 @@ def get_mm_max_tokens_per_item( max_audio_tokens = math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND) - return {"audio": max_audio_tokens} + return {"audio": max_audio_tokens * 20} class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo] From 618e7527c1772b20d3de40209c83fd0686dad93c Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Mon, 24 Feb 2025 23:42:38 -0800 Subject: [PATCH 06/25] increase max mm tokens Signed-off-by: Farzad Abdolhosseini --- vllm/model_executor/models/ultravox.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index b52da75c9f96c..316b5ebaa086d 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -43,6 +43,7 @@ _AUDIO_PLACEHOLDER_OVERRIDE = "<|reserved_special_token_0|>" _AUDIO_PLACEHOLDER_TOKEN = 128002 _AUDIO_TOKENS_PER_SECOND = 6.25 +_MAX_AUDIO_CHUNKS = 20 class UltravoxAudioFeatureInputs(TypedDict): @@ -116,7 +117,7 @@ def get_mm_max_tokens_per_item( max_audio_tokens = math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND) - return {"audio": max_audio_tokens * 20} + return {"audio": max_audio_tokens * _MAX_AUDIO_CHUNKS} class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo] @@ -130,7 +131,8 @@ def get_dummy_processor_inputs( feature_extractor = self.info.get_feature_extractor() sampling_rate = feature_extractor.sampling_rate - audio_len = feature_extractor.chunk_length * sampling_rate + audio_len = (feature_extractor.chunk_length * sampling_rate * + _MAX_AUDIO_CHUNKS) num_audios = mm_counts.get("audio", 0) mm_data = { From 0e629457e2bf6a380361bf6f439adfad602e34b0 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Tue, 25 Feb 2025 08:12:07 -0800 Subject: [PATCH 07/25] reduce max mm tokens Signed-off-by: Farzad Abdolhosseini --- vllm/model_executor/models/ultravox.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 316b5ebaa086d..47bcbdd26a09b 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -43,7 +43,7 @@ _AUDIO_PLACEHOLDER_OVERRIDE = 
"<|reserved_special_token_0|>" _AUDIO_PLACEHOLDER_TOKEN = 128002 _AUDIO_TOKENS_PER_SECOND = 6.25 -_MAX_AUDIO_CHUNKS = 20 +_MAX_AUDIO_CHUNKS = 10 class UltravoxAudioFeatureInputs(TypedDict): From 69278e22b66477ce0aef2fe71aa5b10d32dddd15 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Tue, 25 Feb 2025 09:17:29 -0800 Subject: [PATCH 08/25] revert increasing max mm tokens Signed-off-by: Farzad Abdolhosseini --- vllm/model_executor/models/ultravox.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 47bcbdd26a09b..f4e5899fb1222 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -43,7 +43,7 @@ _AUDIO_PLACEHOLDER_OVERRIDE = "<|reserved_special_token_0|>" _AUDIO_PLACEHOLDER_TOKEN = 128002 _AUDIO_TOKENS_PER_SECOND = 6.25 -_MAX_AUDIO_CHUNKS = 10 +_MAX_AUDIO_CHUNKS = 1 class UltravoxAudioFeatureInputs(TypedDict): From 75c138b52bf3c53ff6bd343a266102a93de87bb8 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Tue, 25 Feb 2025 15:29:56 -0800 Subject: [PATCH 09/25] fix <|begin_of_text|> not being included Signed-off-by: Farzad Abdolhosseini --- vllm/model_executor/models/ultravox.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 2bbd4be2d34e2..11a4b6c6c3612 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -191,11 +191,16 @@ def _apply_hf_processor_tokens_only( self, prompt_tokens: list[int], ) -> list[int]: - # HF processor omits bos_token_id by setting add_special_tokens=False tokenizer = self.info.get_tokenizer() assert prompt_tokens[0] == tokenizer.bos_token_id - return prompt_tokens[1:] + # temporary fix: when running with api_server, the bos_token_id is not + # added to the prompt_tokens, but when using llm.generate(), it is. + # This is a hack to make the output consistent between the two cases. + # TODO: find the root cause and fix it. 
+ if prompt_tokens[0] == prompt_tokens[1]: + return prompt_tokens[1:] + return prompt_tokens def _get_mm_fields_config( self, From 3b0e237d284f7a0f938d4d51ac7c2a6d914c907a Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Tue, 25 Feb 2025 16:08:57 -0800 Subject: [PATCH 10/25] batching for whisper to avoid oom Signed-off-by: Farzad Abdolhosseini --- vllm/model_executor/models/ultravox.py | 30 +++++++++++++++++++------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 11a4b6c6c3612..92e0b7f18aa25 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -43,7 +43,7 @@ _AUDIO_PLACEHOLDER_OVERRIDE = "<|reserved_special_token_0|>" _AUDIO_PLACEHOLDER_TOKEN = 128002 _AUDIO_TOKENS_PER_SECOND = 6.25 -_MAX_AUDIO_CHUNKS = 1 +_MAX_ENCODER_BATCH_SIZE = 16 class UltravoxAudioFeatureInputs(TypedDict): @@ -117,7 +117,7 @@ def get_mm_max_tokens_per_item( max_audio_tokens = math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND) - return {"audio": max_audio_tokens * _MAX_AUDIO_CHUNKS} + return {"audio": max_audio_tokens * _MAX_ENCODER_BATCH_SIZE} class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo] @@ -132,7 +132,7 @@ def get_dummy_processor_inputs( sampling_rate = feature_extractor.sampling_rate audio_len = (feature_extractor.chunk_length * sampling_rate * - _MAX_AUDIO_CHUNKS) + _MAX_ENCODER_BATCH_SIZE) num_audios = mm_counts.get("audio", 0) mm_data = { @@ -482,11 +482,25 @@ def get_mm_mapping(self) -> MultiModelKeys: def _audio_features_to_embeddings( self, input_features: torch.Tensor, - audio_lens: Optional[torch.Tensor]) -> torch.Tensor: - audio_input = input_features.to(self.audio_tower.dtype) - audio_features = self.audio_tower(audio_input, audio_lens) - audio_features = audio_features.to(self.audio_tower.dtype) - audio_embeddings = self.multi_modal_projector(audio_features) + audio_lens: torch.Tensor) -> torch.Tensor: + audio_features = input_features.to(self.audio_tower.dtype) + batch_size = audio_features.size(0) + audio_embeddings = [] + + # Process audio features in batches to keep memory usage predictable + for start in range(0, batch_size, _MAX_ENCODER_BATCH_SIZE): + end = min(start + _MAX_ENCODER_BATCH_SIZE, batch_size) + # Process through audio tower + batch_features = self.audio_tower(audio_features[start:end], + audio_lens[start:end]) + batch_features = batch_features.to(self.audio_tower.dtype) + + # Process through projector + batch_embeddings = self.multi_modal_projector(batch_features) + audio_embeddings.append(batch_embeddings) + + # Concatenate results + audio_embeddings = torch.cat(audio_embeddings, dim=0) return audio_embeddings def _parse_and_validate_audio_input( From 97f6f5ba6d3b45d5c0dbb5ac14b2f203d6b7a085 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Tue, 25 Feb 2025 16:12:40 -0800 Subject: [PATCH 11/25] add comment Signed-off-by: Farzad Abdolhosseini --- vllm/model_executor/models/ultravox.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 92e0b7f18aa25..94fd38288a2bf 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -89,6 +89,7 @@ def get_hf_processor( # token. 
hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE vocab = hf_processor.tokenizer.get_vocab() + # Updating both variables for compatibility with older versions hf_processor.audio_replacement_token_id = vocab[ _AUDIO_PLACEHOLDER_OVERRIDE] return hf_processor From bea5a31193b2fb7d3581f131e48e2d4797d89e10 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Tue, 25 Feb 2025 16:46:16 -0800 Subject: [PATCH 12/25] use flat_from_sizes for ultravox mm_fields_config Signed-off-by: Farzad Abdolhosseini --- vllm/model_executor/models/ultravox.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 94fd38288a2bf..b981210443376 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -208,13 +208,17 @@ def _get_mm_fields_config( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: + num_chunks = hf_inputs.get('audio_num_chunks', torch.zeros(0)) return dict( # to handle longer than 30s audio, each audio might be split # into multiple chunks as such, their batch dimension can be # higher than the number of audio samples - audio_features=MultiModalFieldConfig.batched("audio_chunked"), - audio_token_len=MultiModalFieldConfig.batched("audio_chunked"), - audio_lens=MultiModalFieldConfig.batched("audio_chunked"), + audio_features=MultiModalFieldConfig.flat_from_sizes( + "audio", num_chunks), + audio_token_len=MultiModalFieldConfig.flat_from_sizes( + "audio", num_chunks), + audio_lens=MultiModalFieldConfig.flat_from_sizes( + "audio", num_chunks), # num_chunks can convert audio_chunked to audio batch dimension audio_num_chunks=MultiModalFieldConfig.batched("audio"), audio_embeds=MultiModalFieldConfig.batched("audio"), @@ -428,9 +432,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - # Due to the batching of audio chunks, the preprocessor cache cannot - # do the right thing so disable it. 
- vllm_config.model_config.disable_mm_preprocessor_cache = True multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multi_modal_config = multimodal_config From 28f16ce101f8f7f1e155de02e2423ea4337fe275 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Tue, 25 Feb 2025 16:50:05 -0800 Subject: [PATCH 13/25] revert ultravox test model id Signed-off-by: Farzad Abdolhosseini --- tests/models/decoder_only/audio_language/test_ultravox.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index 50d48838b317f..d1f643a8fdb73 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -15,7 +15,7 @@ from ....utils import RemoteOpenAIServer from ...utils import check_logprobs_close -MODEL_NAME = "fixie-ai/ultravox-v0_3-llama-3_2-1b" +MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b" AudioTuple = Tuple[np.ndarray, int] From 48c359b125688389389fd079b8f350c3b4e9d6a8 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Tue, 25 Feb 2025 16:52:46 -0800 Subject: [PATCH 14/25] improve documentation for double bos_id case Signed-off-by: Farzad Abdolhosseini --- vllm/model_executor/models/ultravox.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index b981210443376..9bc7bb37cd3f8 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -195,11 +195,9 @@ def _apply_hf_processor_tokens_only( tokenizer = self.info.get_tokenizer() assert prompt_tokens[0] == tokenizer.bos_token_id - # temporary fix: when running with api_server, the bos_token_id is not - # added to the prompt_tokens, but when using llm.generate(), it is. - # This is a hack to make the output consistent between the two cases. - # TODO: find the root cause and fix it. - if prompt_tokens[0] == prompt_tokens[1]: + # If the prompt was generated with the bos_token_id, we might end up + # with two bos_token_ids in the prompt, so we remove the first one. + if len(prompt_tokens) > 1 and prompt_tokens[0] == prompt_tokens[1]: return prompt_tokens[1:] return prompt_tokens From 4a54ea1585ce54f4553036565255174e31f65231 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Wed, 26 Feb 2025 11:02:14 -0800 Subject: [PATCH 15/25] do not use vocab in get_hf_processor Signed-off-by: Farzad Abdolhosseini --- vllm/model_executor/models/ultravox.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 07e8220a5918b..bc8908ac27fa3 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -88,10 +88,9 @@ def get_hf_processor( # token, thus we override placeholder with a reserved special # token. 
hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE - vocab = hf_processor.tokenizer.get_vocab() # Updating both variables for compatibility with older versions - hf_processor.audio_replacement_token_id = vocab[ - _AUDIO_PLACEHOLDER_OVERRIDE] + + hf_processor.audio_replacement_token_id = _AUDIO_PLACEHOLDER_TOKEN return hf_processor def get_feature_extractor( From 347ada89db6bba2d2b42db075f90b701e1fada98 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Wed, 26 Feb 2025 15:39:31 -0800 Subject: [PATCH 16/25] revert tests to use v0_5 Signed-off-by: Farzad Abdolhosseini --- tests/models/multimodal/processing/test_common.py | 2 +- tests/models/registry.py | 3 +-- vllm/model_executor/models/ultravox.py | 2 -- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index a84999cfbf4fd..e005bb1401568 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -172,7 +172,7 @@ def _test_processing_correctness( "Qwen/Qwen2-VL-2B-Instruct", "Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct", - "fixie-ai/ultravox-v0_4", + "fixie-ai/ultravox-v0_5-llama-3_2-1b", "openai/whisper-large-v3", ]) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) diff --git a/tests/models/registry.py b/tests/models/registry.py index 8614baf18f3b7..767b0810e1443 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -284,8 +284,7 @@ def check_available_online( "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501 "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501 min_transformers_version="4.49"), # noqa: E501 - "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_4", - extras={"v0.5": "fixie-ai/ultravox-v0_5-llama-3_2-1b"}, # noqa: E501 + "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501 trust_remote_code=True), # [Encoder-decoder] "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index bc8908ac27fa3..c305340f684a6 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -88,8 +88,6 @@ def get_hf_processor( # token, thus we override placeholder with a reserved special # token. 
hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE - # Updating both variables for compatibility with older versions - hf_processor.audio_replacement_token_id = _AUDIO_PLACEHOLDER_TOKEN return hf_processor From b04878e015a1a5ef459a8bb12ee28c157a261e44 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Fri, 28 Feb 2025 22:45:31 -0800 Subject: [PATCH 17/25] revert tests to use v0_5 Signed-off-by: Farzad Abdolhosseini --- tests/models/decoder_only/audio_language/test_ultravox.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index 0ea17247028f5..8a6564291f4b6 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -15,7 +15,7 @@ from ....utils import RemoteOpenAIServer from ...utils import check_logprobs_close -MODEL_NAME = "fixie-ai/ultravox-v0_4" +MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b" AudioTuple = Tuple[np.ndarray, int] From 631487f707e94100740a6fd5133b72ece90811fb Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Fri, 28 Feb 2025 22:48:56 -0800 Subject: [PATCH 18/25] adding tests for both ultravox v0.4 and v0.5 Signed-off-by: Farzad Abdolhosseini --- tests/models/multimodal/processing/test_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index c7a4a4332bcf6..c6f78be29f29d 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -173,6 +173,7 @@ def _test_processing_correctness( "Qwen/Qwen2-VL-2B-Instruct", "Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct", + "fixie-ai/ultravox-v0_4", "fixie-ai/ultravox-v0_5-llama-3_2-1b", "openai/whisper-large-v3", ]) From a9828eae2e2489f23b63035e01ba54435083e1a9 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Sun, 2 Mar 2025 23:24:07 -0800 Subject: [PATCH 19/25] handle audio_num_chunks when no audio is passed Signed-off-by: Farzad Abdolhosseini --- vllm/model_executor/models/ultravox.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index cb5890433db49..208d4a08754aa 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -221,11 +221,12 @@ def _get_prompt_updates( # Each audio can be split into multiple chunks. # chunks_start_idx[i] indicates the start index of the chunks # belonging to the i-th audio. 
- chunks_start_idx: torch.Tensor = torch.cumsum( - out_mm_kwargs["audio_num_chunks"], dim=0, dtype=torch.int32) + num_chunks = out_mm_kwargs.get("audio_num_chunks", torch.zeros(0)) + chunks_start_idx: torch.Tensor = torch.cumsum(num_chunks, + dim=0, + dtype=torch.int32) chunks_start_idx = torch.cat( [torch.tensor([0], dtype=torch.int32), chunks_start_idx]) - out_mm_kwargs.pop("audio_num_chunks", None) def get_replacement_ultravox(item_idx: int): start = chunks_start_idx[item_idx] From 33a9cf0d2c6c23ca2ae213a2709ea535a982d034 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Sun, 2 Mar 2025 23:24:55 -0800 Subject: [PATCH 20/25] drop test for ultravox v0_4 Signed-off-by: Farzad Abdolhosseini --- tests/models/multimodal/processing/test_common.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index c6f78be29f29d..c7a4a4332bcf6 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -173,7 +173,6 @@ def _test_processing_correctness( "Qwen/Qwen2-VL-2B-Instruct", "Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct", - "fixie-ai/ultravox-v0_4", "fixie-ai/ultravox-v0_5-llama-3_2-1b", "openai/whisper-large-v3", ]) From 7ca61cf440ec4dde24f7865d3565bcacf94b8629 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Tue, 4 Mar 2025 15:34:29 -0800 Subject: [PATCH 21/25] drop matching Ultravox audio_features with cache Signed-off-by: Farzad Abdolhosseini --- .../multimodal/processing/test_common.py | 55 +++++++++++++++++-- 1 file changed, 49 insertions(+), 6 deletions(-) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index c7a4a4332bcf6..18ed1741fdc7e 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 +import copy +from collections.abc import Mapping from functools import partial import numpy as np import pytest +import torch from PIL import Image from vllm.config import ModelConfig @@ -21,6 +24,7 @@ def _test_processing_correctness( hit_rate: float, num_batches: int, simplify_rate: float, + ignore_keys: list[str], ): model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) model_info.check_available_online(on_fail="skip") @@ -123,8 +127,9 @@ def _test_processing_correctness( hf_processor_mm_kwargs={}, ) - assert baseline_result == cached_result, ( - f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") + assert _drop_keys(baseline_result, ignore_keys) == _drop_keys( + cached_result, + ignore_keys), (f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") baseline_tokenized_result = baseline_processor.apply( tokenizer.encode(prompt, **tokenizer_encode_kwargs), @@ -132,8 +137,9 @@ def _test_processing_correctness( hf_processor_mm_kwargs={}, ) - assert baseline_result == baseline_tokenized_result, ( - f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") + assert _drop_keys(baseline_result, ignore_keys) == _drop_keys( + baseline_tokenized_result, + ignore_keys), (f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") cached_tokenized_result = cached_processor.apply( tokenizer.encode(prompt, **tokenizer_encode_kwargs), @@ -141,8 +147,9 @@ def _test_processing_correctness( hf_processor_mm_kwargs={}, ) - assert cached_result == cached_tokenized_result, ( - f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") + assert _drop_keys(cached_result, ignore_keys) == 
_drop_keys( + cached_tokenized_result, + ignore_keys), (f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") # yapf: disable @@ -186,11 +193,15 @@ def test_processing_correctness( num_batches: int, simplify_rate: float, ): + ignore_keys = [] + if 'ultravox' in model_id: + ignore_keys = ['mm_kwargs.audio_features'] _test_processing_correctness( model_id, hit_rate=hit_rate, num_batches=num_batches, simplify_rate=simplify_rate, + ignore_keys=ignore_keys, ) @@ -218,4 +229,36 @@ def test_processing_correctness_phi3v( hit_rate=hit_rate, num_batches=num_batches, simplify_rate=simplify_rate, + ignore_keys=[], ) + + +def _drop_keys(result: dict, ignore_keys: list[str]) -> bool: + """Drop specified nested keys from the result and convert tensors to lists + for easier comparison. + + Args: + result: Result to drop keys from + ignore_keys: List of tuples containing nested key paths to ignore + e.g. ['mm_kwargs.audio_features'] + """ + result = copy.deepcopy(result) + + for key_path in ignore_keys: + keys = key_path.split('.') + curr = result + for key in keys[:-1]: + curr = curr.get(key, {}) + curr.pop(keys[-1], None) + + return _convert_tensors_to_list(result) + + +def _convert_tensors_to_list(obj): + if isinstance(obj, torch.Tensor): + return obj.detach().cpu().numpy().tolist() + elif isinstance(obj, Mapping): + return {k: _convert_tensors_to_list(v) for k, v in obj.items()} + elif isinstance(obj, (list, tuple)): + return [_convert_tensors_to_list(v) for v in obj] + return obj From 48f7da3dd4943d63eb8cfc4bd62393ce81ccdb53 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Wed, 5 Mar 2025 10:50:15 -0800 Subject: [PATCH 22/25] ignore exact match for audio_features in _items_by_modality Signed-off-by: Farzad Abdolhosseini --- .../multimodal/processing/test_common.py | 82 +++++++++---------- 1 file changed, 41 insertions(+), 41 deletions(-) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 18ed1741fdc7e..2e67fc3431c33 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 import copy -from collections.abc import Mapping from functools import partial +from typing import Optional import numpy as np import pytest -import torch from PIL import Image from vllm.config import ModelConfig @@ -24,7 +23,7 @@ def _test_processing_correctness( hit_rate: float, num_batches: int, simplify_rate: float, - ignore_keys: list[str], + ignore_mm_keys: Optional[list[str]] = None, ): model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) model_info.check_available_online(on_fail="skip") @@ -127,9 +126,10 @@ def _test_processing_correctness( hf_processor_mm_kwargs={}, ) - assert _drop_keys(baseline_result, ignore_keys) == _drop_keys( - cached_result, - ignore_keys), (f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") + assert _drop_mm_kwargs_keys( + baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys( + cached_result, ignore_mm_keys), ( + f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") baseline_tokenized_result = baseline_processor.apply( tokenizer.encode(prompt, **tokenizer_encode_kwargs), @@ -137,9 +137,10 @@ def _test_processing_correctness( hf_processor_mm_kwargs={}, ) - assert _drop_keys(baseline_result, ignore_keys) == _drop_keys( - baseline_tokenized_result, - ignore_keys), (f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") + assert _drop_mm_kwargs_keys( + baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys( + 
baseline_tokenized_result, ignore_mm_keys), ( + f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") cached_tokenized_result = cached_processor.apply( tokenizer.encode(prompt, **tokenizer_encode_kwargs), @@ -147,9 +148,10 @@ def _test_processing_correctness( hf_processor_mm_kwargs={}, ) - assert _drop_keys(cached_result, ignore_keys) == _drop_keys( - cached_tokenized_result, - ignore_keys), (f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") + assert _drop_mm_kwargs_keys( + cached_result, ignore_mm_keys) == _drop_mm_kwargs_keys( + cached_tokenized_result, ignore_mm_keys), ( + f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") # yapf: disable @@ -193,15 +195,19 @@ def test_processing_correctness( num_batches: int, simplify_rate: float, ): - ignore_keys = [] + ignore_mm_keys = None if 'ultravox' in model_id: - ignore_keys = ['mm_kwargs.audio_features'] + # In Ultravox, the audio_features can be different depending on padding + # The slight difference should not be a problem though, since + # attention_mask lets us ignore the difference. + ignore_mm_keys = ['audio_features'] + _test_processing_correctness( model_id, hit_rate=hit_rate, num_batches=num_batches, simplify_rate=simplify_rate, - ignore_keys=ignore_keys, + ignore_mm_keys=ignore_mm_keys, ) @@ -229,36 +235,30 @@ def test_processing_correctness_phi3v( hit_rate=hit_rate, num_batches=num_batches, simplify_rate=simplify_rate, - ignore_keys=[], ) -def _drop_keys(result: dict, ignore_keys: list[str]) -> bool: - """Drop specified nested keys from the result and convert tensors to lists - for easier comparison. +def _drop_mm_kwargs_keys(result: dict, + ignore_mm_keys: Optional[list[str]] = None) -> bool: + """Drop specified keys from result['mm_kwargs']. + + This is mainly to avoid doing exact match of audio_features in ultravox. Args: result: Result to drop keys from - ignore_keys: List of tuples containing nested key paths to ignore - e.g. ['mm_kwargs.audio_features'] + ignore_mm_keys: List of keys to ignore, e.g. 
['audio_features'] """ - result = copy.deepcopy(result) - - for key_path in ignore_keys: - keys = key_path.split('.') - curr = result - for key in keys[:-1]: - curr = curr.get(key, {}) - curr.pop(keys[-1], None) - - return _convert_tensors_to_list(result) - - -def _convert_tensors_to_list(obj): - if isinstance(obj, torch.Tensor): - return obj.detach().cpu().numpy().tolist() - elif isinstance(obj, Mapping): - return {k: _convert_tensors_to_list(v) for k, v in obj.items()} - elif isinstance(obj, (list, tuple)): - return [_convert_tensors_to_list(v) for v in obj] - return obj + if not ignore_mm_keys: + return result + + if 'mm_kwargs' in result: + result = copy.deepcopy(result) + mm_kwargs = result['mm_kwargs'] + for key in ignore_mm_keys: + mm_kwargs.pop(key, None) + for items in mm_kwargs._items_by_modality.values(): + for item in items: + for key in ignore_mm_keys: + item.pop(key, None) + + return result From 2813a47ac5b1690b7f25c11ebd6e56c72a6029bb Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Wed, 5 Mar 2025 10:58:42 -0800 Subject: [PATCH 23/25] fix type hint Signed-off-by: Farzad Abdolhosseini --- tests/models/multimodal/processing/test_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 2e67fc3431c33..219c964618b9a 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -239,7 +239,7 @@ def test_processing_correctness_phi3v( def _drop_mm_kwargs_keys(result: dict, - ignore_mm_keys: Optional[list[str]] = None) -> bool: + ignore_mm_keys: Optional[list[str]] = None) -> dict: """Drop specified keys from result['mm_kwargs']. This is mainly to avoid doing exact match of audio_features in ultravox. From 11ff27f69f547bc36aae495315c2f38a0f1eb41d Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Mon, 10 Mar 2025 10:41:44 -0700 Subject: [PATCH 24/25] debug logs for ci Signed-off-by: Farzad Abdolhosseini --- vllm/model_executor/models/ultravox.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 2a2508266c044..b7086e331de8f 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -527,6 +527,16 @@ def _process_audio_input( audio_features = audio_input["data"] if isinstance(audio_features, list): + if len(audio_features) > 1 and isinstance(audio_features[0], list): + print("This should not happen. This is likely due to a bug.") + print(audio_input['lens']) + print(audio_input['token_len']) + print(audio_features) + audio_features = [torch.stack(x) for x in audio_features] + print("After stack:", [x.shape for x in audio_features]) + audio_features = [ + x.reshape(-1, *x.shape[-2:]) for x in audio_features + ] max_len = max(x.shape[-1] for x in audio_features) # Pad and concatenate: # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)] From 2776a31d35244a70d9a6597221837f27fc6bf8b0 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Mon, 10 Mar 2025 14:00:05 -0700 Subject: [PATCH 25/25] if all else fails just stack? 
Signed-off-by: Farzad Abdolhosseini --- vllm/model_executor/models/ultravox.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index b7086e331de8f..edfb2aaba0583 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -527,16 +527,13 @@ def _process_audio_input( audio_features = audio_input["data"] if isinstance(audio_features, list): - if len(audio_features) > 1 and isinstance(audio_features[0], list): - print("This should not happen. This is likely due to a bug.") - print(audio_input['lens']) - print(audio_input['token_len']) - print(audio_features) - audio_features = [torch.stack(x) for x in audio_features] - print("After stack:", [x.shape for x in audio_features]) - audio_features = [ - x.reshape(-1, *x.shape[-2:]) for x in audio_features - ] + audio_features = [ + torch.stack(x) if isinstance(x, list) else x + for x in audio_features + ] + audio_features = [ + x.reshape(-1, *x.shape[-2:]) for x in audio_features + ] max_len = max(x.shape[-1] for x in audio_features) # Pad and concatenate: # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
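
For reference, a minimal standalone sketch of the chunk-collation behavior the final patch converges on in _process_audio_input: nested per-audio lists are stacked, leading dims are flattened to [B_i, 80, M_i], each tensor is right-padded along the mel-frame axis, and everything is concatenated into a single [B1+...+Bn, 80, max(M_i)] batch. The helper name collate_audio_chunks and the example shapes are illustrative, not part of the patch; only torch and torch.nn.functional are assumed.

import torch
import torch.nn.functional as F

def collate_audio_chunks(audio_features: list) -> torch.Tensor:
    # Stack any nested lists of chunk tensors, then flatten the leading
    # dims so every element has shape [B_i, 80, M_i].
    audio_features = [
        torch.stack(x) if isinstance(x, list) else x for x in audio_features
    ]
    audio_features = [x.reshape(-1, *x.shape[-2:]) for x in audio_features]
    # Right-pad each tensor along the frame axis to the longest M_i, then
    # concatenate along the chunk batch axis (dim 0).
    max_len = max(x.shape[-1] for x in audio_features)
    return torch.cat(
        [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_features])

# Example: one audio with a single 30s chunk and one with two shorter chunks.
chunks = [torch.randn(1, 80, 3000), torch.randn(2, 80, 1500)]
assert collate_audio_chunks(chunks).shape == (3, 80, 3000)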