Commit 9139191
Merge branch 'main' into v1_logprobs_merge
abf149 committed Nov 4, 2024
2 parents 37a76c3 + 9a5664d commit 9139191
Showing 12 changed files with 374 additions and 432 deletions.
11 changes: 6 additions & 5 deletions .buildkite/run-amd-test.sh
@@ -107,11 +107,12 @@ fi
PARALLEL_JOB_COUNT=8
# Check if the command contains a shard flag; if so, run all shards in parallel because the host has 8 GPUs.
if [[ $commands == *"--shard-id="* ]]; then
# assign job count as the number of shards used
commands=${commands//"--num-shards= "/"--num-shards=${PARALLEL_JOB_COUNT} "}
for GPU in $(seq 0 $(($PARALLEL_JOB_COUNT-1))); do
#replace shard arguments
commands=${commands//"--shard-id= "/"--shard-id=${GPU} "}
commands=${commands//"--num-shards= "/"--num-shards=${PARALLEL_JOB_COUNT} "}
echo "Shard ${GPU} commands:$commands"
# assign shard-id for each shard
commands_gpu=${commands//"--shard-id= "/"--shard-id=${GPU} "}
echo "Shard ${GPU} commands:$commands_gpu"
docker run \
--device /dev/kfd --device /dev/dri \
--network host \
Expand All @@ -123,7 +124,7 @@ if [[ $commands == *"--shard-id="* ]]; then
-e HF_HOME=${HF_MOUNT} \
--name ${container_name}_${GPU} \
${image_name} \
/bin/bash -c "${commands}" \
/bin/bash -c "${commands_gpu}" \
|& while read -r line; do echo ">>Shard $GPU: $line"; done &
PIDS+=($!)
done
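The substantive fix in this hunk: the per-shard substitution now writes into a fresh copy (commands_gpu) on every loop iteration. Substituting into commands in place only matches the literal "--shard-id= " placeholder on the first pass, so every later container would rerun shard 0. A minimal Python sketch of the two behaviours (the command template is illustrative, not taken from the script):

# Illustrative template; the trailing space after "=" is the placeholder
# that the substitution keys on.
template = "pytest --shard-id= --num-shards= tests/"
template = template.replace("--num-shards= ", "--num-shards=8 ")

# In-place substitution: only the first iteration finds the placeholder.
shared = template
in_place = []
for gpu in range(3):
    shared = shared.replace("--shard-id= ", f"--shard-id={gpu} ")
    in_place.append(shared)

# Per-shard copy, as commands_gpu now does in the script above.
per_shard = [template.replace("--shard-id= ", f"--shard-id={gpu} ")
             for gpu in range(3)]

print(in_place)   # all three entries contain --shard-id=0
print(per_shard)  # --shard-id=0, --shard-id=1, --shard-id=2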
2 changes: 1 addition & 1 deletion .buildkite/run-tpu-test.sh
@@ -12,4 +12,4 @@ remove_docker_container
# For HF_TOKEN.
source /etc/environment
# Run a simple end-to-end example.
docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && python3 -m pip install lm_eval[api]==0.4.4 && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
81 changes: 55 additions & 26 deletions benchmarks/benchmark_throughput.py
@@ -4,7 +4,7 @@
import json
import random
import time
from typing import List, Optional, Tuple
from typing import List, Optional

import torch
import uvloop
@@ -15,16 +15,35 @@
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
from vllm.inputs import TextPrompt
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser, merge_async_iterators


@dataclasses.dataclass
class SampleRequest:
"""A class representing a single inference request for benchmarking.
Attributes:
prompt: The input text prompt for the model.
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
images).
prompt_len: The length of the prompt in tokens.
expected_output_len: The expected length of the output in tokens.
"""
prompt: str
prompt_len: int
expected_output_len: int
multi_modal_data: Optional[MultiModalDataDict] = None
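For orientation, a small self-contained sketch (not from the diff) of how the named fields replace the old (prompt, prompt_len, output_len) tuples used throughout this file; the multi-modal field is simplified to a plain dict here:

import dataclasses
from typing import Optional


@dataclasses.dataclass
class SampleRequest:
    prompt: str
    prompt_len: int
    expected_output_len: int
    multi_modal_data: Optional[dict] = None  # simplified stand-in for MultiModalDataDict


# Old style: positional tuples, unpacked as (prompt, _, output_len).
old = ("What is 2 + 2?", 7, 32)

# New style: named fields, read as request.prompt / request.expected_output_len.
request = SampleRequest(prompt=old[0], prompt_len=old[1], expected_output_len=old[2])
assert request.expected_output_len == 32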


def sample_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
) -> List[SampleRequest]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")

@@ -41,7 +60,7 @@ def sample_requests(
random.shuffle(dataset)

# Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = []
filtered_dataset: List[SampleRequest] = []
for i in range(len(dataset)):
if len(filtered_dataset) == num_requests:
break
@@ -60,31 +79,34 @@ def sample_requests(
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))
filtered_dataset.append(
SampleRequest(prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len))

return filtered_dataset


def run_vllm(
requests: List[Tuple[str, int, int]],
requests: List[SampleRequest],
n: int,
engine_args: EngineArgs,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args))

# Add the requests to the engine.
prompts: List[str] = []
prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
prompts.append(prompt)
for request in requests:
prompts.append(TextPrompt(prompt=request.prompt))
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=output_len,
max_tokens=request.expected_output_len,
))

use_beam_search = False
@@ -94,11 +116,11 @@ def run_vllm(
llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
else:
prompts = [prompt for prompt, _, _ in requests]
prompts = [request.prompt for request in requests]
# output_len should be the same for all requests.
output_len = requests[0][2]
for prompt, input_len, _output_len in requests:
assert _output_len == output_len
for request in requests:
assert request.expected_output_len == output_len
start = time.perf_counter()
llm.beam_search(
prompts,
@@ -112,7 +134,7 @@


async def run_vllm_async(
requests: List[Tuple[str, int, int]],
requests: List[SampleRequest],
n: int,
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
@@ -123,17 +145,17 @@ async def run_vllm_async(
engine_args, disable_frontend_multiprocessing) as llm:

# Add the requests to the engine.
prompts: List[str] = []
prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
prompts.append(prompt)
for request in requests:
prompts.append(TextPrompt(prompt=request.prompt))
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=output_len,
max_tokens=request.expected_output_len,
))

generators = []
@@ -149,7 +171,7 @@


def run_hf(
requests: List[Tuple[str, int, int]],
requests: List[SampleRequest],
model: str,
tokenizer: PreTrainedTokenizerBase,
n: int,
@@ -207,14 +229,14 @@ def run_hf(


def run_mii(
requests: List[Tuple[str, int, int]],
requests: List[SampleRequest],
model: str,
tensor_parallel_size: int,
output_len: int,
) -> float:
from mii import client, serve
llm = serve(model, tensor_parallel=tensor_parallel_size)
prompts = [prompt for prompt, _, _ in requests]
prompts = [request.prompt for request in requests]

start = time.perf_counter()
llm.generate(prompts, max_new_tokens=output_len)
@@ -243,8 +265,12 @@ def main(args: argparse.Namespace):
else:
raise ValueError(
f"Failed to synthesize a prompt with {args.input_len} tokens.")
requests = [(prompt, args.input_len, args.output_len)
for _ in range(args.num_prompts)]
requests = [
SampleRequest(prompt=prompt,
prompt_len=args.input_len,
expected_output_len=args.output_len)
for _ in range(args.num_prompts)
]
else:
requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
args.output_len)
@@ -270,9 +296,10 @@ def main(args: argparse.Namespace):
args.output_len)
else:
raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(prompt_len + output_len
for _, prompt_len, output_len in requests)
total_output_tokens = sum(output_len for _, _, output_len in requests)
total_num_tokens = sum(request.prompt_len + request.expected_output_len
for request in requests)
total_output_tokens = sum(request.expected_output_len
for request in requests)
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
@@ -299,7 +326,9 @@ def main(args: argparse.Namespace):
parser.add_argument("--dataset",
type=str,
default=None,
help="Path to the dataset.")
help="Path to the dataset. The dataset is expected to "
"be a json in form of List[Dict[..., conversations: "
"List[Dict[..., value: <prompt_or_response>]]]]")
parser.add_argument("--input-len",
type=int,
default=None,
4 changes: 4 additions & 0 deletions docs/source/models/vlm.rst
@@ -242,6 +242,10 @@ To consume the server, you can use the OpenAI client like in the example below:
A full code example can be found in `examples/openai_chat_completion_client_for_multimodal.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_chat_completion_client_for_multimodal.py>`_.

.. tip::
Loading from local file paths is also supported in vLLM: you can specify the allowed local media path via ``--allowed-local-media-path`` when launching the API server/engine,
and pass the file path as ``url`` in the API request.
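For illustration (not part of the diff), a request against a server started with, say, --allowed-local-media-path /data/images might look like the following; the model name and file path are placeholders:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat_response = client.chat.completions.create(
    model="llava-hf/llava-1.5-7b-hf",  # placeholder multimodal model
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            # The local file path is passed as the image URL, per the tip above.
            {"type": "image_url",
             "image_url": {"url": "/data/images/example.jpg"}},
        ],
    }],
)
print(chat_response.choices[0].message.content)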

.. tip::
There is no need to place image placeholders in the text content of the API request - they are already represented by the image content.
In fact, you can place image placeholders in the middle of the text by interleaving text and image content.
17 changes: 15 additions & 2 deletions tests/entrypoints/openai/test_accuracy.py
@@ -10,6 +10,8 @@
import lm_eval
import pytest

from vllm.platforms import current_platform

from ...utils import RemoteOpenAIServer

MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
@@ -18,12 +20,21 @@
FILTER = "exact_match,strict-match"
RTOL = 0.03
EXPECTED_VALUE = 0.58
DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests"]
MORE_ARGS_LIST = [
[], # Default
["--enable-chunked-prefill"], # Chunked
["--num-scheduler-steps", "8"], # MS
["--num-scheduler-steps", "8", "--multi-step-stream-outputs"] # MS+Stream
]
MAX_WAIT_SECONDS = None

if current_platform.is_tpu():
MORE_ARGS_LIST = [
[], # Default
# ["--num-scheduler-steps", "8"], # Multi-step << currently fails
]
MAX_WAIT_SECONDS = 600


@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
@@ -33,7 +44,9 @@ def test_lm_eval_accuracy(more_args):

print(f"Running with: {args}")

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
with RemoteOpenAIServer(
MODEL_NAME, args,
max_wait_seconds=MAX_WAIT_SECONDS) as remote_server:
url = f"{remote_server.url_for('v1')}/completions"

model_args = (
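The hunk is truncated here; for context, a hedged sketch of how such a GSM8K check is typically driven through lm_eval's OpenAI-compatible backend (argument values below are assumptions, not copied from the test):

import lm_eval

results = lm_eval.simple_evaluate(
    model="local-completions",
    model_args=(
        "model=Qwen/Qwen2-1.5B-Instruct,"
        "base_url=http://localhost:8000/v1/completions,"
        "num_concurrent=100"
    ),
    tasks=["gsm8k"],
)
score = results["results"]["gsm8k"]["exact_match,strict-match"]
print(f"exact_match (strict): {score:.3f}")  # compared against EXPECTED_VALUE within RTOL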
7 changes: 6 additions & 1 deletion tests/lora/test_minicpmv.py
@@ -1,8 +1,11 @@
from typing import List

import pytest

import vllm
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform

MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"

@@ -53,6 +56,9 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.mark.xfail(
current_platform.is_rocm(),
reason="MiniCPM-V dependency xformers incompatible with ROCm")
def test_minicpmv_lora(minicpmv_lora_files):
llm = vllm.LLM(
MODEL_PATH,
@@ -63,7 +69,6 @@ def test_minicpmv_lora(minicpmv_lora_files):
trust_remote_code=True,
gpu_memory_utilization=0.97 # This model is pretty big for CI gpus
)

output1 = do_sample(llm, minicpmv_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output1[i])
12 changes: 10 additions & 2 deletions vllm/engine/multiprocessing/client.py
@@ -112,7 +112,11 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig,

# Stream for each individual request.
self.output_queues: Dict[str, asyncio.Queue] = {}
self.output_loop = asyncio.create_task(self.run_output_handler_loop())

# Loop to handle output of the LLMEngine periodically.
# Started after the MQLLMEngine is ready so that we can
# build the Client in an executor to enable clean shutdown.
self.output_loop: Optional[asyncio.Task] = None

# Loop to check health of the LLMEngine periodically.
# Started after the MQLLMEngine is ready.
@@ -247,6 +251,9 @@ async def run_output_handler_loop(self):
async def setup(self):
"""Setup the client before it starts sending server requests."""

# Start output_loop
self.output_loop = asyncio.create_task(self.run_output_handler_loop())

with self.get_data_socket() as socket:
# Wait until server is ready.
response = await self._wait_for_server_rpc(socket)
@@ -265,7 +272,8 @@ def close(self):
# Cancel background tasks.
if self.health_loop is not None:
self.health_loop.cancel()
self.output_loop.cancel()
if self.output_loop is not None:
self.output_loop.cancel()

def _set_errored(self, e: BaseException):
logger.exception(repr(e))
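A minimal sketch of the construction/startup split introduced above: the task handle starts as None so the client object can be built off the event loop (e.g. in an executor), and the output-handler task is only created once setup() runs inside a running loop. Apart from output_loop and setup, the names below are illustrative:

import asyncio
from typing import Optional


class Client:
    def __init__(self) -> None:
        # No running event loop is needed here, so construction can happen
        # in an executor; the background task is created later in setup().
        self.output_loop: Optional[asyncio.Task] = None

    async def setup(self) -> None:
        self.output_loop = asyncio.create_task(self._handle_outputs())

    async def _handle_outputs(self) -> None:
        while True:
            await asyncio.sleep(1)  # stand-in for draining engine outputs

    def close(self) -> None:
        if self.output_loop is not None:
            self.output_loop.cancel()


async def main() -> None:
    loop = asyncio.get_running_loop()
    client = await loop.run_in_executor(None, Client)  # constructed off the loop
    await client.setup()
    client.close()
    await asyncio.sleep(0)  # let the cancellation propagate


asyncio.run(main())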
24 changes: 15 additions & 9 deletions vllm/engine/multiprocessing/engine.py
@@ -349,16 +349,22 @@ def stop_profile(self) -> None:
self.engine.model_executor._run_workers("stop_profile")


def signal_handler(*_) -> None:
raise KeyboardInterrupt("MQLLMEngine terminated")


def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
ipc_path: str):
ipc_path: str, engine_alive):
try:
engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
usage_context=usage_context,
ipc_path=ipc_path)

def signal_handler(*_) -> None:
# Interrupt server on sigterm
raise KeyboardInterrupt("MQLLMEngine terminated")
signal.signal(signal.SIGTERM, signal_handler)

signal.signal(signal.SIGTERM, signal_handler)
engine.start()

engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
usage_context=usage_context,
ipc_path=ipc_path)
engine.start()
except BaseException as e:
logger.exception(e)
engine_alive.value = False
raise e
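A self-contained sketch of the shutdown pattern this hunk moves toward: a module-level SIGTERM handler that converts termination into KeyboardInterrupt, plus a shared flag (engine_alive in the diff) that the parent process can inspect if the engine dies. Everything other than those two names is illustrative:

import multiprocessing
import signal
import time


def signal_handler(*_) -> None:
    # Turn SIGTERM into KeyboardInterrupt so the usual cleanup path runs.
    raise KeyboardInterrupt("engine terminated")


def run_engine(engine_alive) -> None:
    try:
        signal.signal(signal.SIGTERM, signal_handler)
        while True:  # stand-in for engine.start()
            time.sleep(0.1)
    except BaseException:
        engine_alive.value = False
        raise


if __name__ == "__main__":
    engine_alive = multiprocessing.Value("b", True, lock=False)
    proc = multiprocessing.Process(target=run_engine, args=(engine_alive,))
    proc.start()
    time.sleep(0.5)
    proc.terminate()  # delivers SIGTERM on POSIX
    proc.join()
    print("engine_alive:", bool(engine_alive.value))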