
Commit dfd9301

fix: format
1 parent: d83915b

10 files changed: +52 -46 lines

tests/conftest.py

Lines changed: 3 additions & 2 deletions

@@ -27,8 +27,9 @@
                               destroy_model_parallel,
                               init_distributed_environment,
                               initialize_model_parallel)
-from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, EmbedsPrompt,
-                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
+from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
+                         EmbedsPrompt, to_enc_dec_tuple_list,
+                         zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sequence import SampleLogprobs

tests/worker/test_model_runner.py

Lines changed: 18 additions & 16 deletions

@@ -34,9 +34,8 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
     return model_runner


-@pytest.mark.parametrize("batch_size, prompt_embeds_ratio",
-                         list(itertools.product(range(1, 257),
-                                                (0.0, 0.5, 1.0))))
+@pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
+@pytest.mark.parametrize("prompt_embeds_ratio", (0.0, 0.5, 1.0))
 def test_prepare_prompt(batch_size, prompt_embeds_ratio):
     model_runner = _create_model_runner(
         "facebook/opt-125m",
@@ -54,11 +53,13 @@ def test_prepare_prompt(batch_size, prompt_embeds_ratio):
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
         if random.random() < prompt_embeds_ratio:
-            seq_data = SequenceData([], prompt_embeds=torch.rand(seq_len, 10))
+            seq_data = SequenceData(
+                array(VLLM_TOKEN_ID_ARRAY_TYPE, range(seq_len)),
+                torch.rand(seq_len, 10))
             input_embeds_len += seq_len
-        else
-            seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                          range(seq_len)))
+        else:
+            seq_data = SequenceData(
+                array(VLLM_TOKEN_ID_ARRAY_TYPE, range(seq_len)))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -163,7 +164,7 @@ def test_prepare_prompt(batch_size, prompt_embeds_ratio):
     torch.testing.assert_close(actual, expected)


-@pytest.mark.parametrize("batch_size", list(range(1, 257)))
+@pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
 @pytest.mark.parametrize("prompt_embeds_ratio", (0.0, 0.5, 1.0))
 def test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio):
     model_runner = _create_model_runner(
@@ -185,8 +186,8 @@ def test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio):
         context_len = i % (model_runner.block_size - 1) + 1
         context_lens.append(context_len)
         if random.random() < prompt_embeds_ratio:
-            seq_data = SequenceData([],
-                                    prompt_embeds=torch.rand(context_len, 10))
+            seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, range(0)),
+                                    torch.rand(context_len, 10))
             input_embeds_len += context_len
         else:
             seq_data = SequenceData(
@@ -337,7 +338,7 @@ def distributed_init():
     ensure_model_parallel_initialized(1, 1)


-@pytest.mark.parametrize("batch_size", list(range(2, 128)))
+@pytest.mark.parametrize("batch_size", list(range(2, 128, 3)))
 @pytest.mark.parametrize("enforce_eager", [True, False])
 @pytest.mark.parametrize('prompt_embeds_ratio', [0.0, 0.5, 1.0])
 def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio,
@@ -366,11 +367,12 @@ def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio,
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
         if random.random() < prompt_embeds_ratio:
-            seq_data = SequenceData([], prompt_embeds=torch.rand(seq_len, 10))
+            seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, range(0)),
+                                    torch.rand(seq_len, 10))
             input_embeds_len += seq_len
         else:
-            seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                          range(seq_len)))
+            seq_data = SequenceData(
+                array(VLLM_TOKEN_ID_ARRAY_TYPE, range(seq_len)))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -387,8 +389,8 @@ def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio,
         # make sure all tokens fit into one block
         context_len = i % (model_runner.block_size - 1) + 1
         if random.random() < prompt_embeds_ratio:
-            seq_data = SequenceData([],
-                                    prompt_embeds=torch.rand(context_len, 10))
+            seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, range(0)),
+                                    torch.rand(context_len, 10))
         else:
             prompt_toks = array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len))
             seq_data = SequenceData(prompt_toks)
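
Side note on the parametrize hunks above: splitting the combined
"batch_size, prompt_embeds_ratio" parametrization into two stacked decorators
and stepping batch_size by 3 shrinks the generated test matrix (presumably to
keep runtime down; that motivation is an inference, not stated in the commit).
A minimal sketch of the arithmetic, using only itertools (stacked
@pytest.mark.parametrize decorators take the cross product of their values):

    import itertools

    ratios = (0.0, 0.5, 1.0)
    old_cases = list(itertools.product(range(1, 257), ratios))     # before
    new_cases = list(itertools.product(range(1, 257, 3), ratios))  # after

    print(len(old_cases))  # 768
    print(len(new_cases))  # 258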

vllm/engine/async_llm_engine.py

Lines changed: 1 addition & 1 deletion

@@ -436,7 +436,6 @@ async def _extract_prompt_components_async(
             multi_modal_data = None
             prompt_embeds = None
         elif isinstance(inputs, dict):
-            prompt = inputs.get("prompt")
             prompt_embeds = inputs.get("prompt_embeds")
             driver_worker = self.model_executor.driver_worker
             if prompt_embeds is not None:
@@ -450,6 +449,7 @@ async def _extract_prompt_components_async(
                     raise ValueError(
                         f"Model {self.model_config.model} does not support input "
                         "embeddings, but prompt_embeds was provided.")
+                prompt = None
                 prompt_token_ids = []
             elif "prompt_token_ids" in inputs:
                 prompt = None

vllm/engine/llm_engine.py

Lines changed: 7 additions & 5 deletions

@@ -76,7 +76,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
 _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
 _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)

-PromptComponents = Tuple[Optional[str], List[int],
+PromptComponents = Tuple[Optional[str], List[int], Optional[torch.Tensor],
                          Optional[MultiModalDataDict]]
 DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
                                 Optional[MultiModalDataDict]]
@@ -808,7 +808,6 @@ def _extract_prompt_components(
             multi_modal_data = None
             prompt_embeds = None
         elif isinstance(inputs, dict):
-            prompt = inputs.get("prompt")
             prompt_embeds = inputs.get("prompt_embeds")
             driver_worker = self.model_executor.driver_worker
             if prompt_embeds is not None:
@@ -822,6 +821,7 @@
                     raise ValueError(
                         f"Model {self.model_config.model} does not support input "
                         "embeddings, but prompt_embeds was provided.")
+                prompt = None
                 prompt_token_ids = []
             elif "prompt_token_ids" in inputs:
                 prompt = None
@@ -894,7 +894,7 @@ def _build_enc_dec_llm_inputs(
         encoder_comps: PromptComponents,
         decoder_comps: DecoderPromptComponents,
     ) -> EncoderDecoderLLMInputs:
-        encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
+        encoder_prompt, encoder_prompt_ids, _, encoder_mm_data = encoder_comps
         decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps

         if encoder_mm_data is not None or decoder_mm_data is not None:
@@ -961,10 +961,11 @@ def _process_encoder_decoder_prompt(
             if (decoder_input := inputs["decoder_prompt"]) is None:
                 decoder_comps = None, None, None
             else:
-                decoder_comps = self._extract_prompt_components(
+                prompt, prompt_token_ids, _, multi_modal_data = self._extract_prompt_components(
                     decoder_input,
                     request_id=request_id,
                 )
+                decoder_comps = prompt, prompt_token_ids, multi_modal_data
         else:
             encoder_comps = self._extract_prompt_components(
                 inputs,
@@ -2015,7 +2016,8 @@ def _validate_model_inputs(self, inputs: Union[LLMInputs,
         prompt_ids = inputs.get("prompt_token_ids")
         prompt_embeds = inputs.get("prompt_embeds")

-        if (prompt_ids is None or len(prompt_ids) == 0) and prompt_embeds is None:
+        if (prompt_ids is None
+                or len(prompt_ids) == 0) and prompt_embeds is None:
             raise ValueError("Prompt cannot be empty")

         if self.model_config.is_multimodal_model:
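
A note on the llm_engine.py hunks above: PromptComponents now carries an
optional embeddings tensor as its third element, which is why
_build_enc_dec_llm_inputs and _process_encoder_decoder_prompt unpack four
values and discard the embeds slot with "_". A rough sketch of the shape; here
MultiModalDataDict is stubbed as a plain dict for illustration (the real alias
lives in vLLM), so treat this as a sketch rather than the library source:

    from typing import Any, Dict, List, Optional, Tuple

    import torch

    # Stand-in for vLLM's MultiModalDataDict, for illustration only.
    MultiModalDataDict = Dict[str, Any]

    # Mirrors the widened alias in the hunk above: prompt text, token ids,
    # optional prompt embeddings, optional multi-modal data.
    PromptComponents = Tuple[Optional[str], List[int], Optional[torch.Tensor],
                             Optional[MultiModalDataDict]]

    comps: PromptComponents = ("hello", [1, 2, 3], None, None)
    # Callers that never touch embeddings drop the third slot, as the diff does:
    prompt, prompt_token_ids, _, multi_modal_data = comps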

vllm/inputs/__init__.py

Lines changed: 5 additions & 4 deletions

@@ -1,7 +1,8 @@
-from .data import (EmbedsPrompt, EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
-                   LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
-                   TokensPrompt, build_explicit_enc_dec_prompt,
-                   to_enc_dec_tuple_list, zip_enc_dec_prompts)
+from .data import (EmbedsPrompt, EncoderDecoderLLMInputs,
+                   ExplicitEncoderDecoderPrompt, LLMInputs, PromptInputs,
+                   SingletonPromptInputs, TextPrompt, TokensPrompt,
+                   build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
+                   zip_enc_dec_prompts)
 from .registry import InputContext, InputRegistry

 INPUT_REGISTRY = InputRegistry()

vllm/model_executor/models/fuyu.py

Lines changed: 3 additions & 4 deletions

@@ -284,8 +284,8 @@ def forward(
     ):
         image_input = self._parse_and_validate_image_input(**kwargs)
         inputs_embeds = get_inputs_embeds(
-            input_ids, self.language_model.model.embed_tokens,
-            inputs_embeds, inputs_embeds_masks)
+            input_ids, self.language_model.model.embed_tokens, inputs_embeds,
+            inputs_embeds_masks)
         if image_input is not None:
             vision_embeddings = self._process_image_input(image_input)
             inputs_embeds = merge_multimodal_embeddings(
@@ -298,8 +298,7 @@ def forward(
             kv_caches=kv_caches,
             attn_metadata=attn_metadata,
             inputs_embeds=inputs_embeds,
-            inputs_embeds_masks=inputs_embeds_masks
-        )
+            inputs_embeds_masks=inputs_embeds_masks)
         return hidden_states

     def compute_logits(

vllm/model_executor/models/jamba.py

Lines changed: 5 additions & 2 deletions

@@ -637,8 +637,11 @@ def forward(self,
             # CUDA graph capturing runs
             mamba_cache = kwargs["seqlen_agnostic_capture_inputs"]

-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, mamba_cache[0],
+        hidden_states = self.model(input_ids,
+                                   positions,
+                                   kv_caches,
+                                   attn_metadata,
+                                   mamba_cache[0],
                                    mamba_cache[1],
                                    inputs_embeds=inputs_embeds,
                                    inputs_embeds_masks=inputs_embeds_masks)

vllm/model_executor/models/persimmon.py

Lines changed: 6 additions & 8 deletions

@@ -278,14 +278,12 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         inputs_embeds_masks: Optional[torch.Tensor] = None,
     ):
-        hidden_states = self.model(
-            input_ids=input_ids,
-            positions=positions,
-            kv_caches=kv_caches,
-            attn_metadata=attn_metadata,
-            inputs_embeds=inputs_embeds,
-            inputs_embeds_masks=inputs_embeds_masks
-        )
+        hidden_states = self.model(input_ids=input_ids,
+                                   positions=positions,
+                                   kv_caches=kv_caches,
+                                   attn_metadata=attn_metadata,
+                                   inputs_embeds=inputs_embeds,
+                                   inputs_embeds_masks=inputs_embeds_masks)
         return hidden_states

     def compute_logits(

vllm/model_executor/models/phi3v.py

Lines changed: 2 additions & 3 deletions

@@ -620,9 +620,8 @@ def forward(self,
                 **kwargs: object):
         image_input = self._parse_and_validate_image_input(**kwargs)
         inputs_embeds = get_inputs_embeds(input_ids,
-                                          self.model.get_input_embeddings,
-                                          inputs_embeds,
-                                          inputs_embeds_masks)
+                                          self.model.get_input_embeddings,
+                                          inputs_embeds, inputs_embeds_masks)
         if image_input is not None:
             vision_embeddings = self._process_image_input(image_input)
             inputs_embeds = merge_multimodal_embeddings(

vllm/sequence.py

Lines changed: 2 additions & 1 deletion

@@ -402,7 +402,8 @@ def __init__(
                 "encoder input prompt fields?")

         self.data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids), self.prompt_embeds)
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids),
+            self.prompt_embeds)
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""
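
Taken together, the test and vllm/sequence.py hunks show the SequenceData call
shape used on this branch: a token-id array built with
array(VLLM_TOKEN_ID_ARRAY_TYPE, ...) plus an optional prompt-embeddings tensor
passed positionally, replacing the earlier prompt_embeds= keyword form. A
minimal sketch of that shape, assuming this PR's branch (upstream vLLM's
SequenceData may differ, and the "l" typecode is an assumption), with the
SequenceData call itself left as a comment since it only exists on the branch:

    from array import array

    import torch

    # Assumption: "l" (signed long) matches vLLM's VLLM_TOKEN_ID_ARRAY_TYPE;
    # it is used here purely for illustration.
    VLLM_TOKEN_ID_ARRAY_TYPE = "l"

    seq_len = 8
    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, range(seq_len))
    prompt_embeds = torch.rand(seq_len, 10)  # same toy shape as in the tests

    # As used in the diffs on this branch (not executed here):
    # seq_data = SequenceData(token_ids, prompt_embeds)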