@@ -34,9 +34,8 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
     return model_runner


-@pytest.mark.parametrize("batch_size, prompt_embeds_ratio",
-                         list(itertools.product(range(1, 257),
-                                                (0.0, 0.5, 1.0))))
+@pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
+@pytest.mark.parametrize("prompt_embeds_ratio", (0.0, 0.5, 1.0))
 def test_prepare_prompt(batch_size, prompt_embeds_ratio):
     model_runner = _create_model_runner(
         "facebook/opt-125m",
@@ -54,11 +53,13 @@ def test_prepare_prompt(batch_size, prompt_embeds_ratio):
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
         if random.random() < prompt_embeds_ratio:
-            seq_data = SequenceData([], prompt_embeds=torch.rand(seq_len, 10))
+            seq_data = SequenceData(
+                array(VLLM_TOKEN_ID_ARRAY_TYPE, range(seq_len)),
+                torch.rand(seq_len, 10))
             input_embeds_len += seq_len
-        else
-            seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                          range(seq_len)))
+        else:
+            seq_data = SequenceData(
+                array(VLLM_TOKEN_ID_ARRAY_TYPE, range(seq_len)))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -163,7 +164,7 @@ def test_prepare_prompt(batch_size, prompt_embeds_ratio):
     torch.testing.assert_close(actual, expected)


-@pytest.mark.parametrize("batch_size", list(range(1, 257)))
+@pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
 @pytest.mark.parametrize("prompt_embeds_ratio", (0.0, 0.5, 1.0))
 def test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio):
     model_runner = _create_model_runner(
@@ -185,8 +186,8 @@ def test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio):
         context_len = i % (model_runner.block_size - 1) + 1
         context_lens.append(context_len)
         if random.random() < prompt_embeds_ratio:
-            seq_data = SequenceData([],
-                                    prompt_embeds=torch.rand(context_len, 10))
+            seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, range(0)),
+                                    torch.rand(context_len, 10))
             input_embeds_len += context_len
         else:
             seq_data = SequenceData(
@@ -337,7 +338,7 @@ def distributed_init():
     ensure_model_parallel_initialized(1, 1)


-@pytest.mark.parametrize("batch_size", list(range(2, 128)))
+@pytest.mark.parametrize("batch_size", list(range(2, 128, 3)))
 @pytest.mark.parametrize("enforce_eager", [True, False])
 @pytest.mark.parametrize('prompt_embeds_ratio', [0.0, 0.5, 1.0])
 def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio,
@@ -366,11 +367,12 @@ def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio,
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
         if random.random() < prompt_embeds_ratio:
-            seq_data = SequenceData([], prompt_embeds=torch.rand(seq_len, 10))
+            seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, range(0)),
+                                    torch.rand(seq_len, 10))
             input_embeds_len += seq_len
         else:
-            seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                          range(seq_len)))
+            seq_data = SequenceData(
+                array(VLLM_TOKEN_ID_ARRAY_TYPE, range(seq_len)))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -387,8 +389,8 @@ def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio,
         # make sure all tokens fit into one block
         context_len = i % (model_runner.block_size - 1) + 1
         if random.random() < prompt_embeds_ratio:
-            seq_data = SequenceData([],
-                                    prompt_embeds=torch.rand(context_len, 10))
+            seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, range(0)),
+                                    torch.rand(context_len, 10))
         else:
             prompt_toks = array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len))
             seq_data = SequenceData(prompt_toks)
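
For context, a minimal sketch of the construction pattern these tests now share: each sequence gets either plain token ids or token ids plus a dummy prompt-embedding tensor, chosen at random according to prompt_embeds_ratio. The sketch assumes the PR's modified SequenceData constructor (embeddings passed as the second positional argument, as in the diff) and that SequenceData and VLLM_TOKEN_ID_ARRAY_TYPE are importable from vllm.sequence; the helper name make_seq_data is illustrative only.

# Illustrative sketch only; mirrors the test pattern in this commit, assuming
# the PR's SequenceData(token_array, prompt_embeds_tensor) constructor.
import random
from array import array

import torch

from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData


def make_seq_data(seq_len: int, prompt_embeds_ratio: float) -> SequenceData:
    """Build SequenceData from token ids, optionally with dummy prompt embeddings."""
    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, range(seq_len))
    if random.random() < prompt_embeds_ratio:
        # Dummy embeddings with the same placeholder width (10) the tests use.
        return SequenceData(token_ids, torch.rand(seq_len, 10))
    return SequenceData(token_ids)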