forked from vllm-project/vllm
[V1] AsyncLLM Implementation (vllm-project#9826)
Signed-off-by: Nick Hill <[email protected]>
Signed-off-by: [email protected] <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
1 parent 658520f · commit 3ac6aa6
29 changed files with 2,412 additions and 727 deletions.
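The tests added in this commit exercise the new V1 AsyncLLM front end. Based only on the imports and calls that appear in this diff (AsyncEngineArgs, AsyncLLM.from_engine_args, engine.generate, engine.shutdown, and the VLLM_USE_V1 environment variable), a minimal driver would look roughly like the sketch below; treat it as an illustration of the surface under test, not documentation of the final API.

import asyncio
import os

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM


async def main():
    # Opt in to the V1 engine, as the tests in this commit do via monkeypatch.
    os.environ["VLLM_USE_V1"] = "1"

    engine = AsyncLLM.from_engine_args(
        AsyncEngineArgs(model="meta-llama/Llama-3.2-1B"))

    # engine.generate() is an async generator yielding incremental outputs.
    async for output in engine.generate(
            request_id="request-0",
            prompt="Hello my name is Robert and",
            sampling_params=SamplingParams(max_tokens=10, temperature=0)):
        print(output)

    engine.shutdown()


if __name__ == "__main__":
    asyncio.run(main())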
@@ -0,0 +1,56 @@
""" | ||
This file test accuracy of the vLLM server via LMEval. | ||
It uses local-completions, which interacts with vLLM | ||
through the OAI API with N concurrent connections. | ||
This simulates real work usage of the API and makes | ||
sure that the zmq frontend mp RPC message passing and | ||
AsyncLLMEngine are working correctly. | ||
""" | ||

import lm_eval
import pytest

from vllm.platforms import current_platform

MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
NUM_CONCURRENT = 500
TASK = "gsm8k"
FILTER = "exact_match,strict-match"
RTOL = 0.03
EXPECTED_VALUE = 0.58


def run_test():
    """Run the end to end accuracy test."""

    model_args = f"pretrained={MODEL_NAME},max_model_len=2048"

    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=model_args,
        tasks="gsm8k",
        batch_size="auto",
    )

    measured_value = results["results"][TASK][FILTER]
    assert (measured_value - RTOL < EXPECTED_VALUE
            and measured_value + RTOL > EXPECTED_VALUE
            ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"


@pytest.mark.skipif(not current_platform.is_cuda(),
                    reason="V1 is currently only supported on CUDA.")
def test_lm_eval_accuracy_v1_engine(monkeypatch):
    """Run with the V1 Engine."""

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        run_test()


def test_lm_eval_accuracy_v0_engine(monkeypatch):
    """Run with the V0 Engine."""

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "0")
        run_test()
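The module docstring above describes driving the server over the OpenAI-compatible API via lm_eval's local-completions backend with NUM_CONCURRENT connections, while run_test() here calls lm_eval's in-process "vllm" backend directly. For reference, a local-completions invocation against an already running server would look roughly like the sketch below; the model_args keys (base_url, num_concurrent) come from lm_eval's local-completions backend and should be treated as an assumption rather than part of this commit.

import lm_eval

# Sketch only: assumes a vLLM OpenAI-compatible server is already running at
# localhost:8000 serving MODEL_NAME, and reuses the constants defined above.
results = lm_eval.simple_evaluate(
    model="local-completions",
    model_args=(f"model={MODEL_NAME},"
                "base_url=http://localhost:8000/v1/completions,"
                f"num_concurrent={NUM_CONCURRENT}"),
    tasks=TASK,
)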
File renamed without changes.
Empty file.
@@ -0,0 +1,66 @@
import asyncio
from typing import Tuple

import pytest

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.platforms import current_platform
from vllm.v1.engine.async_llm import AsyncLLM

if not current_platform.is_cuda():
    pytest.skip(reason="V1 currently only supported on CUDA.",
                allow_module_level=True)

ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B",
                              disable_log_requests=True)


async def generate(engine: AsyncLLM, request_id: str,
                   max_tokens: int) -> Tuple[int, str]:
    count = 0
    async for _ in engine.generate(request_id=request_id,
                                   prompt="Hello my name is Robert and",
                                   sampling_params=SamplingParams(
                                       max_tokens=max_tokens, temperature=0)):

        count += 1
        await asyncio.sleep(0.)

    return count, request_id


@pytest.mark.asyncio
async def test_load(monkeypatch):
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        engine = AsyncLLM.from_engine_args(ENGINE_ARGS)

        NUM_REQUESTS = 10000
        NUM_EXPECTED_TOKENS = 10

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks = []
        for request_id in request_ids:
            tasks.append(
                asyncio.create_task(
                    generate(engine, request_id, NUM_EXPECTED_TOKENS)))

        # Confirm that we got all the EXPECTED tokens from the requests.
        failed_request_id = None
        tokens = None
        for task in tasks:
            num_generated_tokens, request_id = await task
            if (num_generated_tokens != NUM_EXPECTED_TOKENS
                    and failed_request_id is None):
                failed_request_id = request_id
                tokens = num_generated_tokens

        assert failed_request_id is None, (
            f"{failed_request_id} generated {tokens} but "
            f"expected {NUM_EXPECTED_TOKENS}")

        engine.shutdown()
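The load test above awaits each task in submission order so it can report the first request that produced the wrong number of tokens. If per-request attribution were not needed, the aggregation section of test_load could instead be written with asyncio.gather, as in the sketch below (it reuses generate, engine, request_ids, and NUM_EXPECTED_TOKENS from the test and is not part of the commit):

        # Sketch: collect all (token_count, request_id) pairs at once.
        results = await asyncio.gather(
            *(generate(engine, request_id, NUM_EXPECTED_TOKENS)
              for request_id in request_ids))
        bad = [(rid, n) for n, rid in results if n != NUM_EXPECTED_TOKENS]
        assert not bad, f"Requests with unexpected token counts: {bad}"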
@@ -0,0 +1,205 @@
from typing import List

import pytest
from transformers import AutoTokenizer

from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine import EngineCoreOutput
from vllm.v1.engine.detokenizer import Detokenizer, DetokenizerRequest

TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

FULL_STRINGS = [
    "My name is Robert from Neural Magic and I love working on vLLM so much!",
    "Red Hat is the best open source company by far across Linux, K8s, and AI.",
    "Nick is the name of my brother in addition to my colleague from Red Hat.",
]

STOP_STRINGS = ["I love working on", "company by far", "brother in"]

FULL_TOKENS = [tokenizer(text).input_ids for text in FULL_STRINGS]
PROMPT_LEN = 5
PROMPT_TOKENS = [
    tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS
]
GENERATION_TOKENS = [
    tokenizer(text).input_ids[PROMPT_LEN:] for text in FULL_STRINGS
]
PROMPT_STRINGS = [
    tokenizer.decode(prompt_tokens, skip_special_tokens=True)
    for prompt_tokens in PROMPT_TOKENS
]
PROMPT_STRINGS_LEN = [len(prompt_string) for prompt_string in PROMPT_STRINGS]
GENERATION_STRINGS = [
    text[prompt_len:]
    for text, prompt_len in zip(FULL_STRINGS, PROMPT_STRINGS_LEN)
]

class MockEngineCore:
    """Mock outputs from premade token lists."""

    def __init__(self, tokens_list: List[List[int]]):
        self.tokens_list = tokens_list
        self.current_idx = 0

    def get_outputs(self) -> List[EngineCoreOutput]:
        token_idx = self.current_idx
        self.current_idx += 1

        outputs = []
        for req_idx, token_ids in enumerate(self.tokens_list):
            if len(token_ids) > token_idx:
                output = EngineCoreOutput(request_id=f"request-{req_idx}",
                                          new_token_ids=[token_ids[token_idx]],
                                          finished=False)
                if token_idx == len(token_ids) - 1:
                    output.finished = True
                    output.finish_reason = "stopped"
                outputs.append(output)

        return outputs

@pytest.mark.parametrize(
    "request_output_kind",
    [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
def test_incremental_detokenization(request_output_kind: RequestOutputKind):
    detokenizer = Detokenizer(TOKENIZER_NAME)
    engine_core = MockEngineCore(GENERATION_TOKENS)

    # Make N requests.
    requests = [
        DetokenizerRequest(
            request_id=f"request-{idx}",
            prompt=prompt,
            prompt_token_ids=prompt_tokens,
            skip_special_tokens=False,
            spaces_between_special_tokens=False,
            output_kind=request_output_kind,
            stop=[],
            include_stop_str_in_output=False,
        ) for idx, (
            prompt,
            prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
    ]

    # Add requests to the detokenizer.
    for request in requests:
        detokenizer.add_request(request)

    gen_strings = {}
    gen_tokens = {}
    while True:
        # Mock output from the EngineCore.
        outputs = engine_core.get_outputs()
        if len(outputs) == 0:
            break

        # Step the Detokenizer.
        request_outputs, requests_to_abort = detokenizer.step(outputs)
        assert len(requests_to_abort) == 0

        # Update tracking.
        for request_output in request_outputs:
            request_id = request_output.request_id
            new_text = request_output.outputs[0].text
            new_tokens = request_output.outputs[0].token_ids
            if request_id not in gen_strings:
                gen_strings[request_id] = new_text
                gen_tokens[request_id] = new_tokens
            else:
                gen_strings[request_id] += new_text
                gen_tokens[request_id].extend(new_tokens)

    # Confirm that the tracked values match what we expect.
    for idx, (ref_gen_str, ref_gen_toks) in enumerate(
            zip(GENERATION_STRINGS, GENERATION_TOKENS)):
        gen_str = gen_strings[f"request-{idx}"]
        gen_toks = gen_tokens[f"request-{idx}"]

        assert gen_str == ref_gen_str, f"{gen_str=}, {ref_gen_str=}"
        assert gen_toks == ref_gen_toks, f"{gen_toks=}, {ref_gen_toks=}"

    assert detokenizer.get_num_unfinished_requests() == 0
    assert not detokenizer.has_unfinished_requests()

@pytest.mark.parametrize("include_stop_str_in_output", [True, False])
def test_stop_string(include_stop_str_in_output: bool):
    detokenizer = Detokenizer(TOKENIZER_NAME)
    engine_core = MockEngineCore(GENERATION_TOKENS)

    # Make N requests.
    requests = [
        DetokenizerRequest(
            request_id=f"request-{idx}",
            prompt=prompt,
            prompt_token_ids=prompt_tokens,
            skip_special_tokens=False,
            spaces_between_special_tokens=False,
            output_kind=RequestOutputKind.DELTA,
            stop=STOP_STRINGS,
            include_stop_str_in_output=include_stop_str_in_output,
        ) for idx, (
            prompt,
            prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
    ]

    # Add requests to the detokenizer.
    for request in requests:
        detokenizer.add_request(request)

    gen_strings = {}
    aborted = []
    while True:
        # Mock output from the EngineCore.
        outputs = engine_core.get_outputs()
        if len(outputs) == 0:
            break

        # Step the Detokenizer.
        request_outputs, requests_to_abort = detokenizer.step(outputs)
        for request_output in request_outputs:
            # If aborted, we should not get a request output.
            assert request_output.request_id not in aborted
        aborted.extend(requests_to_abort)

        # Update tracking.
        for request_output in request_outputs:
            if request_output.finished:
                assert request_output.outputs[0].finish_reason == "stop"

            request_id = request_output.request_id
            new_text = request_output.outputs[0].text
            if request_id not in gen_strings:
                gen_strings[request_id] = new_text
            else:
                gen_strings[request_id] += new_text

    # Confirm that the tracked values match what we expect.
    for idx, (ref_gen_str,
              stop_str) in enumerate(zip(GENERATION_STRINGS, STOP_STRINGS)):

        # Request should be aborted.
        request_id = f"request-{idx}"
        assert request_id in aborted

        # Collected values that were generated.
        gen_str = gen_strings[request_id]

        # Construct reference strings.
        stop_str_idx = ref_gen_str.find(stop_str)
        ref_str_exc_stop = ref_gen_str[:stop_str_idx]
        ref_str_inc_stop = ref_gen_str[:stop_str_idx] + stop_str

        if include_stop_str_in_output:
            assert gen_str == ref_str_inc_stop, (
                f"{gen_str=}, {ref_str_inc_stop=}")
        else:
            assert gen_str == ref_str_exc_stop, (
                f"{gen_str=}, {ref_str_exc_stop=}")

    assert detokenizer.get_num_unfinished_requests() == 0
    assert not detokenizer.has_unfinished_requests()
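Both tests drive the same contract: Detokenizer.step() consumes a batch of EngineCoreOutputs and returns the user-facing RequestOutputs for that step plus the ids of requests that hit a stop string and must be aborted upstream. Condensed into a standalone driver built from the same pieces defined in this file, the loop looks like the sketch below (illustration only, not part of the commit):

def collect_texts() -> dict:
    """Sketch: drain MockEngineCore through the Detokenizer, per-request text."""
    detokenizer = Detokenizer(TOKENIZER_NAME)
    engine_core = MockEngineCore(GENERATION_TOKENS)

    for idx, (prompt, prompt_tokens) in enumerate(
            zip(PROMPT_STRINGS, PROMPT_TOKENS)):
        detokenizer.add_request(
            DetokenizerRequest(request_id=f"request-{idx}",
                               prompt=prompt,
                               prompt_token_ids=prompt_tokens,
                               skip_special_tokens=False,
                               spaces_between_special_tokens=False,
                               output_kind=RequestOutputKind.DELTA,
                               stop=[],
                               include_stop_str_in_output=False))

    texts: dict = {}
    while True:
        outputs = engine_core.get_outputs()
        if not outputs:
            break
        request_outputs, _ = detokenizer.step(outputs)
        for ro in request_outputs:
            # Accumulate the DELTA text chunks per request id.
            texts[ro.request_id] = texts.get(ro.request_id, "") + ro.outputs[0].text
    return texts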