[Model] Whisper model implementation #11280

Status: Open. Wants to merge 39 commits into base: main.

Commits (39 total; the changes below are from 37 commits)
cfbd164  add model def (sfc-gh-aqiao, Dec 11, 2024)
ced0141  WIP (sfc-gh-aqiao, Dec 12, 2024)
248bafb  WIP, passes profile run (sfc-gh-aqiao, Dec 12, 2024)
6c9ee61  Merge remote-tracking branch 'upstream/main' into whisper (aurickq, Dec 12, 2024)
7329b2d  WIP (sfc-gh-aqiao, Dec 12, 2024)
77ad7ed  kinda working with encoder decoder (sfc-gh-aqiao, Dec 12, 2024)
755086b  add whisper example (sfc-gh-aqiao, Dec 12, 2024)
b38f5b7  update (sfc-gh-aqiao, Dec 12, 2024)
ff70bce  cleanup a bit (sfc-gh-aqiao, Dec 13, 2024)
3fbd067  batching (sfc-gh-aqiao, Dec 13, 2024)
9032aa1  flash_attn (sfc-gh-aqiao, Dec 13, 2024)
ce3a87c  WIP (broken) (sfc-gh-aqiao, Dec 13, 2024)
04a0ef4  WIP (sfc-gh-aqiao, Dec 16, 2024)
fd4ed14  13rps (sfc-gh-aqiao, Dec 16, 2024)
26cfede  fuse qkv (sfc-gh-aqiao, Dec 16, 2024)
34c5830  clean (sfc-gh-aqiao, Dec 16, 2024)
bf111b2  20 RPS (sfc-gh-aqiao, Dec 16, 2024)
a21470b  26rps (sfc-gh-aqiao, Dec 17, 2024)
b457c01  41 rps (sfc-gh-aqiao, Dec 17, 2024)
d81d217  fix tokenizer (sfc-gh-aqiao, Dec 17, 2024)
17712a4  fix tp (sfc-gh-aqiao, Dec 17, 2024)
b573fa9  clean (sfc-gh-aqiao, Dec 17, 2024)
6d6cbd9  clean (sfc-gh-aqiao, Dec 17, 2024)
94a867b  udpate (sfc-gh-aqiao, Dec 17, 2024)
787708a  add test (sfc-gh-aqiao, Dec 18, 2024)
e943905  some cleanup (sfc-gh-aqiao, Dec 18, 2024)
606642e  formatting (sfc-gh-aqiao, Dec 19, 2024)
fe8e245  format (sfc-gh-aqiao, Dec 19, 2024)
b59fddb  mypy (sfc-gh-aqiao, Dec 19, 2024)
d66cd42  mypy (sfc-gh-aqiao, Dec 19, 2024)
6ba1afc  format (sfc-gh-aqiao, Dec 19, 2024)
26fd92a  fix tests (sfc-gh-aqiao, Dec 19, 2024)
4566b10  librosa (sfc-gh-aqiao, Dec 19, 2024)
a21334c  Merge remote-tracking branch 'vllm-project/main' into whisper (sfc-gh-aqiao, Dec 19, 2024)
1fe41fc  small (sfc-gh-aqiao, Dec 19, 2024)
1c16ad2  updates (sfc-gh-aqiao, Dec 20, 2024)
7282280  lint (sfc-gh-aqiao, Dec 20, 2024)
3442852  add todos (sfc-gh-aqiao, Dec 20, 2024)
e0cc63e  bugfix (sfc-gh-aqiao, Dec 20, 2024)
61 changes: 61 additions & 0 deletions examples/offline_inference_whisper.py
@@ -0,0 +1,61 @@
import time

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

dtype = "float"

# Create a Whisper encoder/decoder model instance
llm = LLM(
    model="openai/whisper-large-v3",
    max_model_len=448,
    max_num_seqs=400,
    limit_mm_per_prompt={"audio": 1},
    kv_cache_dtype="fp8",
)

prompts = [
    {
        "prompt": "<|startoftranscript|>",
        "multi_modal_data": {
            "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
        },
    },
    {  # Test explicit encoder/decoder prompt
        "encoder_prompt": {
            "prompt": "",
            "multi_modal_data": {
                "audio": AudioAsset("winning_call").audio_and_sample_rate,
            },
        },
        "decoder_prompt": "<|startoftranscript|>",
    }
] * 1024

# Create a sampling params object.
sampling_params = SamplingParams(
    temperature=0,
    top_p=1.0,
    max_tokens=200,
)

start = time.time()

# Generate output tokens from the prompts. The output is a list of
# RequestOutput objects that contain the prompt, generated
# text, and other information.
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    encoder_prompt = output.encoder_prompt
    generated_text = output.outputs[0].text
    print(f"Encoder prompt: {encoder_prompt!r}, "
          f"Decoder prompt: {prompt!r}, "
          f"Generated text: {generated_text!r}")

duration = time.time() - start

print("Duration:", duration)
print("RPS:", len(prompts) / duration)
Empty file.
107 changes: 107 additions & 0 deletions tests/models/encoder_decoder/audio/test_whisper.py
@@ -0,0 +1,107 @@
"""Compare the outputs of HF and vLLM for Whisper models using greedy sampling.

Run `pytest tests/models/encoder_decoder/audio/test_whisper.py`.
"""
from typing import Optional

import pytest

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

from ....utils import fork_new_process_for_each_test, multi_gpu_test

PROMPTS = [
    {
        "prompt":
        "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
        "multi_modal_data": {
            "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
        },
    },
    {  # Test explicit encoder/decoder prompt
        "encoder_prompt": {
            "prompt": "",
            "multi_modal_data": {
                "audio": AudioAsset("winning_call").audio_and_sample_rate,
            },
        },
        "decoder_prompt":
        "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
    }
]

EXPECTED = {
    "openai/whisper-medium": [
        " The first words I spoke in the original phonograph, a little piece"
        " of practical poetry. Mary had a little lamb, its fleece was white as"
        " snow, and everywhere that Mary went the lamb would shun it all.",
        " And the old one pitch on the way to Edgar Martinez swung on the line"
        " down the left field line for Obeysmith. Here comes Joy. Here is"
        " Jorgen at third base. They're gonna wave him in. The throw to the"
        " plate will be late. The Mariners are going to play for the American"
        " League Championship. I don't believe it. It just continues. My, oh"
        " my."
    ],
    "openai/whisper-large-v3": [
        " The first words I spoke in the original phonograph. A little piece"
        " of practical poetry. Mary had a little lamb, its fleece was white as"
        " snow, and everywhere that Mary went, the lamb was sure to go.",
        " And the 0-1 pitch on the way to Edgar Martinez. Swung on the line,"
        " down the left field line for a base hit. Here comes Joy. Here is"
        " Junior to third base. They're going to wave him in. The throw to the"
        " plate will be late. The Mariners are going to play for the American"
        " League Championship. I don't believe it. It just continues. My, oh,"
        " my."
    ]
}


def run_test(
    model: str,
    *,
    enforce_eager: bool,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
) -> None:
    prompt_list = PROMPTS * 10
    expected_list = EXPECTED[model] * 10

    llm = LLM(
        model=model,
        tensor_parallel_size=tensor_parallel_size,
        distributed_executor_backend=distributed_executor_backend,
        enforce_eager=enforce_eager,
    )

    sampling_params = SamplingParams(
        temperature=0,
        top_p=1.0,
        max_tokens=200,
    )

    outputs = llm.generate(prompt_list, sampling_params)

    for output, expected in zip(outputs, expected_list):
        print(output.outputs[0].text)
        assert output.outputs[0].text == expected


@fork_new_process_for_each_test
@pytest.mark.parametrize("model",
                         ["openai/whisper-medium", "openai/whisper-large-v3"])
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_models(model, enforce_eager) -> None:
    run_test(model, enforce_eager=enforce_eager, tensor_parallel_size=1)


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("model", ["openai/whisper-large-v3"])
@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
def test_models_distributed(model, enforce_eager,
                            distributed_executor_backend) -> None:
    run_test(model,
             enforce_eager=enforce_eager,
             tensor_parallel_size=2,
             distributed_executor_backend=distributed_executor_backend)
1 change: 1 addition & 0 deletions tests/models/registry.py
@@ -126,6 +126,7 @@ class _HfExamplesInfo:
     # [Encoder-decoder]
     "BartModel": _HfExamplesInfo("facebook/bart-base"),
     "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
+    "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501
     # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
     # Therefore, we borrow the BartTokenizer from the original Bart model
     "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501
2 changes: 2 additions & 0 deletions vllm/config.py
@@ -2252,6 +2252,8 @@ def _get_and_verify_max_len(
         "seq_length",
         # Command-R
         "model_max_length",
+        # Whisper
+        "max_target_positions",
         # Others
         "max_sequence_length",
         "max_seq_length",
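For reference, the value this new key picks up for Whisper can be read straight from the Hugging Face config (a standalone sketch assuming transformers is installed; not part of the diff):

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("openai/whisper-large-v3")
# max_target_positions is the decoder's maximum sequence length, which is why
# the example script above sets max_model_len=448.
print(cfg.max_target_positions)  # 448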
1 change: 1 addition & 0 deletions vllm/core/scheduler.py
@@ -1579,6 +1579,7 @@ def _preempt_by_recompute(
             seq.status = SequenceStatus.WAITING
             self.free_seq(seq)
             seq.reset_state_for_recompute()
+        self._free_seq_group_cross_attn_blocks(seq_group)
 
     def _preempt_by_swap(
         self,
36 changes: 28 additions & 8 deletions vllm/inputs/preprocess.py
@@ -184,10 +184,16 @@ def _tokenize_prompt(
         corresponding token IDs.
         """
         tokenizer = self.get_tokenizer_group()
-
+        add_special_tokens = None
+        if self.model_config.hf_config.model_type == "whisper":
+            # For Whisper, special tokens should be provided by the user based
+            # on the task and language of their request. Also needed to avoid
+            # appending an EOS token to the prompt which disrupts generation.
+            add_special_tokens = False
         return tokenizer.encode(request_id=request_id,
                                 prompt=prompt,
-                                lora_request=lora_request)
+                                lora_request=lora_request,
+                                add_special_tokens=add_special_tokens)
 
     async def _tokenize_prompt_async(
         self,
@@ -197,10 +203,17 @@ async def _tokenize_prompt_async(
     ) -> List[int]:
         """Async version of :meth:`_tokenize_prompt`."""
         tokenizer = self.get_tokenizer_group()
-
-        return await tokenizer.encode_async(request_id=request_id,
-                                            prompt=prompt,
-                                            lora_request=lora_request)
+        add_special_tokens = None
+        if self.model_config.hf_config.model_type == "whisper":
+            # For Whisper, special tokens should be provided by the user based
+            # on the task and language of their request. Also needed to avoid
+            # appending an EOS token to the prompt which disrupts generation.
+            add_special_tokens = False
+        return await tokenizer.encode_async(
+            request_id=request_id,
+            prompt=prompt,
+            lora_request=lora_request,
+            add_special_tokens=add_special_tokens)
 
     def _can_process_multimodal(self) -> bool:
         model_config = self.model_config
@@ -439,8 +452,15 @@ def _build_enc_dec_llm_inputs(
             assert_never(encoder_inputs)
 
         if decoder_inputs is None:
-            dec_token_ids = self._prepare_decoder_input_ids_for_generation(
-                None)
+            if self.model_config.hf_config.model_type == "whisper":
+                # For Whisper models, the text prompt should go to the decoder.
+                # If no explicit encoder/decoder inputs, then copy the prompt
+                # from the encoder to the decoder. The encoder tokens are later
+                # overridden by the audio features.
+                dec_token_ids = encoder_inputs["prompt_token_ids"].copy()
+            else:
+                dec_token_ids = self._prepare_decoder_input_ids_for_generation(
+                    None)
Review comment on lines +455 to +463 (Member): Is there a way to determine this without model type information?

Reply (Contributor, Author): I am not sure about generalizing this from a single example. In the long term it may be better to allow the model definition to specify exactly the mapping between input fields and where they go (e.g. encoder/decoder).
             decoder_inputs = token_inputs(dec_token_ids)
         elif (decoder_inputs["type"] == "token"
               or decoder_inputs["type"] == "multimodal"):
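To see why add_special_tokens=False matters for Whisper, compare what the Hugging Face tokenizer produces with and without it (a standalone sketch assuming transformers is installed; not part of the diff):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("openai/whisper-large-v3")
prompt = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"

# Default behaviour: the tokenizer adds its own prefix tokens and appends
# <|endoftext|>, the EOS token the comment in the diff is trying to avoid.
print(tok.convert_ids_to_tokens(tok(prompt).input_ids))

# With add_special_tokens=False the user-supplied control tokens pass through
# unchanged and nothing is appended.
print(tok.convert_ids_to_tokens(tok(prompt, add_special_tokens=False).input_ids))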
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -169,6 +169,7 @@
     "UltravoxModel": ("ultravox", "UltravoxModel"),
     # [Encoder-decoder]
     "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
+    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
 }
 
 _SPECULATIVE_DECODING_MODELS = {
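With the mapping above in place, the new architecture should be discoverable through vLLM's public registry (a quick check, assuming a build that includes this PR):

from vllm import ModelRegistry

# The registry entry added above resolves this architecture name to the
# whisper module.
assert "WhisperForConditionalGeneration" in ModelRegistry.get_supported_archs()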