Skip to content

Commit

Permalink
[v1] EngineArgs for better config handling for v1 (vllm-project#10382)
Browse files Browse the repository at this point in the history
Signed-off-by: rickyx <[email protected]>
  • Loading branch information
rickyyx authored and weilong.yu committed Dec 13, 2024
1 parent 1f348a4 commit 2eecda0
Show file tree
Hide file tree
Showing 13 changed files with 109 additions and 27 deletions.
2 changes: 1 addition & 1 deletion .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ steps:
- vllm/
- tests/v1
commands:
- pytest -v -s v1
- VLLM_USE_V1=1 pytest -v -s v1

- label: Examples Test # 15min
working_dir: "/vllm-workspace/examples"
Expand Down
3 changes: 3 additions & 0 deletions tests/v1/engine/test_async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ async def generate(engine: AsyncLLM, request_id: str,

@pytest.mark.asyncio
async def test_load(monkeypatch):
# TODO(rickyx): Remove monkeypatch once we have a better way to test V1
# so that in the future when we switch, we don't have to change all the
# tests.
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

Expand Down
42 changes: 42 additions & 0 deletions tests/v1/engine/test_engine_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest

from vllm import envs
from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.usage.usage_lib import UsageContext

if not envs.VLLM_USE_V1:
pytest.skip(
"Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
allow_module_level=True,
)


def test_defaults():
engine_args = EngineArgs(model="facebook/opt-125m")

# Assert V1 defaults
assert (engine_args.enable_prefix_caching
), "V1 turns on prefix caching by default"


def test_defaults_with_usage_context():
engine_args = EngineArgs(model="facebook/opt-125m")
vllm_config: VllmConfig = engine_args.create_engine_config(
UsageContext.LLM_CLASS)

assert vllm_config.scheduler_config.max_num_seqs == 1024
assert vllm_config.scheduler_config.max_num_batched_tokens == 8192

engine_args = EngineArgs(model="facebook/opt-125m")
vllm_config = engine_args.create_engine_config(
UsageContext.OPENAI_API_SERVER)
assert vllm_config.scheduler_config.max_num_seqs == 1024
assert vllm_config.scheduler_config.max_num_batched_tokens == 2048


def test_prefix_cache_disabled_with_multimodel():
engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf")

vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
assert not vllm_config.cache_config.enable_prefix_caching
3 changes: 2 additions & 1 deletion tests/v1/engine/test_engine_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def test_engine_core(monkeypatch):
m.setenv("VLLM_USE_V1", "1")
"""Setup the EngineCore."""
engine_args = EngineArgs(model=MODEL_NAME)
vllm_config = engine_args.create_engine_config()
vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.UNKNOWN_CONTEXT)
executor_class = AsyncLLM._get_executor_cls(vllm_config)

engine_core = EngineCore(vllm_config=vllm_config,
Expand Down
6 changes: 4 additions & 2 deletions tests/v1/engine/test_engine_core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
m.setenv("VLLM_USE_V1", "1")

engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3)
vllm_config = engine_args.create_engine_config()
vllm_config = engine_args.create_engine_config(
UsageContext.UNKNOWN_CONTEXT)
executor_class = AsyncLLM._get_executor_cls(vllm_config)
client = EngineCoreClient.make_client(
vllm_config,
Expand Down Expand Up @@ -153,7 +154,8 @@ async def test_engine_core_client_asyncio(monkeypatch):
m.setenv("VLLM_USE_V1", "1")

engine_args = EngineArgs(model=MODEL_NAME)
vllm_config = engine_args.create_engine_config()
vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.UNKNOWN_CONTEXT)
executor_class = AsyncLLM._get_executor_cls(vllm_config)
client = EngineCoreClient.make_client(
vllm_config,
Expand Down
53 changes: 50 additions & 3 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.platforms import current_platform
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, StoreBoolean

if TYPE_CHECKING:
Expand Down Expand Up @@ -113,7 +114,7 @@ class EngineArgs:
# NOTE(kzawora): default block size for Gaudi should be 128
# smaller sizes still work, but very inefficiently
block_size: int = 16 if not current_platform.is_hpu() else 128
enable_prefix_caching: bool = False
enable_prefix_caching: Optional[bool] = None
disable_sliding_window: bool = False
use_v2_block_manager: bool = True
swap_space: float = 4 # GiB
Expand Down Expand Up @@ -197,6 +198,11 @@ def __post_init__(self):
if not self.tokenizer:
self.tokenizer = self.model

# Override the default value of enable_prefix_caching if it's not set
# by user.
if self.enable_prefix_caching is None:
self.enable_prefix_caching = bool(envs.VLLM_USE_V1)

# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
# CompilationConfig object
Expand Down Expand Up @@ -953,7 +959,12 @@ def create_load_config(self) -> LoadConfig:
ignore_patterns=self.ignore_patterns,
)

def create_engine_config(self) -> VllmConfig:
def create_engine_config(self,
usage_context: Optional[UsageContext] = None
) -> VllmConfig:
if envs.VLLM_USE_V1:
self._override_v1_engine_args(usage_context)

# gguf file needs a specific model loader and doesn't use hf_repo
if check_gguf_file(self.model):
self.quantization = self.load_format = "gguf"
Expand Down Expand Up @@ -1170,7 +1181,7 @@ def create_engine_config(self) -> VllmConfig:
or "all" in detailed_trace_modules,
)

return VllmConfig(
config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
Expand All @@ -1185,6 +1196,42 @@ def create_engine_config(self) -> VllmConfig:
compilation_config=self.compilation_config,
)

if envs.VLLM_USE_V1:
self._override_v1_engine_config(config)
return config

def _override_v1_engine_args(self, usage_context: UsageContext) -> None:
"""
Override the EngineArgs's args based on the usage context for V1.
"""
assert envs.VLLM_USE_V1, "V1 is not enabled"

if self.max_num_batched_tokens is None:
# When no user override, set the default values based on the
# usage context.
if usage_context == UsageContext.LLM_CLASS:
logger.warning("Setting max_num_batched_tokens to 8192 "
"for LLM_CLASS usage context.")
self.max_num_seqs = 1024
self.max_num_batched_tokens = 8192
elif usage_context == UsageContext.OPENAI_API_SERVER:
logger.warning("Setting max_num_batched_tokens to 2048 "
"for OPENAI_API_SERVER usage context.")
self.max_num_seqs = 1024
self.max_num_batched_tokens = 2048

def _override_v1_engine_config(self, engine_config: VllmConfig) -> None:
"""
Override the EngineConfig's configs based on the usage context for V1.
"""
assert envs.VLLM_USE_V1, "V1 is not enabled"
# TODO (ywang96): Enable APC by default when VLM supports it.
if engine_config.model_config.is_multimodal_model:
logger.warning(
"Prefix caching is currently not supported for multimodal "
"models and has been disabled.")
engine_config.cache_config.enable_prefix_caching = False


@dataclass
class AsyncEngineArgs(EngineArgs):
Expand Down
2 changes: 1 addition & 1 deletion vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def from_engine_args(
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
if engine_config is None:
engine_config = engine_args.create_engine_config()
engine_config = engine_args.create_engine_config(usage_context)

executor_class = cls._get_executor_cls(engine_config)

Expand Down
2 changes: 1 addition & 1 deletion vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ def from_engine_args(
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
engine_config = engine_args.create_engine_config(usage_context)
executor_class = cls._get_executor_cls(engine_config)
# Create the LLM engine.
engine = cls(
Expand Down
2 changes: 1 addition & 1 deletion vllm/engine/multiprocessing/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs,
from vllm.plugins import load_general_plugins
load_general_plugins()

engine_config = engine_args.create_engine_config()
engine_config = engine_args.create_engine_config(usage_context)
executor_class = LLMEngine._get_executor_cls(engine_config)

use_async_sockets = engine_config.model_config.use_async_output_proc
Expand Down
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ async def build_async_engine_client_from_engine_args(
# TODO: fill out feature matrix.
if (MQLLMEngineClient.is_unsupported_config(engine_args)
or envs.VLLM_USE_V1 or disable_frontend_multiprocessing):

engine_config = engine_args.create_engine_config()
engine_config = engine_args.create_engine_config(
UsageContext.OPENAI_API_SERVER)
uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
"uses_ray", False)

Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def from_engine_args(

# Create the engine configs.
if engine_config is None:
vllm_config = engine_args.create_engine_config()
vllm_config = engine_args.create_engine_config(usage_context)
else:
vllm_config = engine_config

Expand Down
13 changes: 0 additions & 13 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,6 @@ def __init__(
executor_class: Type[GPUExecutor],
usage_context: UsageContext,
):
# Override the configs for V1.
# FIXME
if usage_context == UsageContext.LLM_CLASS:
vllm_config.scheduler_config.max_num_seqs = 1024
vllm_config.scheduler_config.max_num_batched_tokens = 8192
elif usage_context == UsageContext.OPENAI_API_SERVER:
vllm_config.scheduler_config.max_num_seqs = 1024
vllm_config.scheduler_config.max_num_batched_tokens = 2048

# TODO (ywang96): Enable APC by default when VLM supports it.
if not vllm_config.model_config.is_multimodal_model:
vllm_config.cache_config.enable_prefix_caching = True

assert vllm_config.model_config.task != "embedding"

logger.info("Initializing an LLM engine (v%s) with config: %s",
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def from_engine_args(
"""Creates an LLM engine from the engine arguments."""

# Create the engine configs.
vllm_config = engine_args.create_engine_config()
vllm_config = engine_args.create_engine_config(usage_context)
executor_class = cls._get_executor_cls(vllm_config)

if VLLM_ENABLE_V1_MULTIPROCESSING:
Expand Down

0 comments on commit 2eecda0

Please sign in to comment.