[WIP] Ray Backend V1 #10006

Closed · wants to merge 2 commits
10 changes: 6 additions & 4 deletions tests/basic_correctness/test_basic_correctness.py
@@ -19,8 +19,9 @@
from ..utils import multi_gpu_test

MODELS = [
"google/gemma-2-2b-it",
"meta-llama/Llama-3.2-1B",
"facebook/opt-125m",
# "google/gemma-2-2b-it",
# "meta-llama/Llama-3.2-1B",
]

TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
@@ -37,10 +37,11 @@ def test_vllm_gc_ed():


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
# @pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
@pytest.mark.parametrize("backend", ["FLASH_ATTN_VLLM_V1"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False, True])
@pytest.mark.parametrize("enforce_eager", [True])
def test_models(
hf_runner,
model: str,
2 changes: 0 additions & 2 deletions vllm/v1/engine/llm_engine.py
@@ -158,8 +158,6 @@ def step(self) -> List[RequestOutput]:

return request_outputs

# TODO(rob): Can we get rid of these?

def get_model_config(self):
pass

89 changes: 89 additions & 0 deletions vllm/v1/executor/distributed_gpu_executor.py
@@ -0,0 +1,89 @@
import asyncio

Check failure: GitHub Actions / ruff (3.12), Ruff (F401), vllm/v1/executor/distributed_gpu_executor.py:1:8: `asyncio` imported but unused

from abc import abstractmethod
from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union

Check failure: GitHub Actions / ruff (3.12), Ruff (F401), vllm/v1/executor/distributed_gpu_executor.py:3: `typing.Awaitable`, `typing.Dict`, `typing.List`, `typing.Set`, and `typing.Union` imported but unused

from vllm.v1.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.core.scheduler import SchedulerOutput

logger = init_logger(__name__)


class DistributedGPUExecutor(GPUExecutor):
"""Abstract superclass of multi-GPU executor implementations."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks.

This invokes `determine_num_available_blocks` on each worker and takes
the min of the results, guaranteeing that the selected cache sizes are
compatible with all workers.

Returns:
- tuple[num_gpu_blocks, num_cpu_blocks]
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers("determine_num_available_blocks")

# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# operators can be applied to all workers.
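        # Example: if three workers report (1200, 0), (1180, 0), and (1210, 0)
        # available (GPU, CPU) blocks, the minimum of 1180 GPU blocks is used
        # so the selected cache size fits on every worker.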
num_gpu_blocks = min(b[0] for b in num_blocks)
return num_gpu_blocks, 0

def initialize_cache(self, num_gpu_blocks: int) -> None:
"""Initialize the KV cache in all workers.
"""
# NOTE: This is logged in the executor because there can be >1 worker
        # with other executors. We could log at the engine level, but work
# remains to abstract away the device for non-GPU configurations.
logger.info("# GPU blocks: %d", num_gpu_blocks)
self._run_workers("initialize_cache", num_gpu_blocks)
self._run_workers("compile_or_warm_up_model")

@abstractmethod
def execute_model(
self,
scheduler_output: SchedulerOutput,
) -> ModelRunnerOutput:
raise NotImplementedError

def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
self._run_workers("save_sharded_state",
path=path,
pattern=pattern,
max_size=max_size)

@abstractmethod
def _run_workers(
self,
method: str,
*args,
async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers.

Args:
            async_run_tensor_parallel_workers_only: If True, the method will be
                run only on the remote TP workers, not the driver worker.
It will also be run asynchronously and return a list of futures
rather than blocking on the results.
"""
raise NotImplementedError

@abstractmethod
def check_health(self) -> None:
# SANG-TODO
raise NotImplementedError

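The `_run_workers` hook above is the piece a Ray backend has to supply. As a rough illustration only (not this PR's implementation), a Ray-based subclass might dispatch the named method to worker actors as sketched below. The `self.workers` list of Ray actor handles, the `self.driver_worker` attribute, and the `execute_method` entry point are assumptions borrowed from the V0 Ray executor, not code in this diff; the remaining abstract methods (`execute_model`, `check_health`) are omitted for brevity.

# Illustrative sketch only. Assumes `self.workers` holds Ray actor handles that
# expose `execute_method(name, *args, **kwargs)` and `self.driver_worker` is a
# local Worker, mirroring the V0 RayGPUExecutor; none of this is in the PR.
from typing import Any, Optional

import ray

from vllm.v1.executor.distributed_gpu_executor import DistributedGPUExecutor


class RayGPUExecutorSketch(DistributedGPUExecutor):

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported in this sketch.")
        # Launch the call on every remote tensor-parallel worker actor.
        refs = [
            worker.execute_method.remote(method, *args, **kwargs)
            for worker in self.workers
        ]
        if async_run_tensor_parallel_workers_only:
            # Return the Ray object refs (futures) without blocking.
            return refs
        # Run the method locally on the driver worker, then gather the remote
        # results so the caller sees one output per worker.
        driver_output = getattr(self.driver_worker, method)(*args, **kwargs)
        return [driver_output] + ray.get(refs)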
25 changes: 24 additions & 1 deletion vllm/v1/executor/gpu_executor.py
@@ -1,11 +1,12 @@
import os
from typing import Optional, Tuple
from typing import Optional, Tuple, Callable, Type, Dict, Any

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_worker import Worker
from vllm.worker.worker_base import WorkerBase

logger = init_logger(__name__)

@@ -75,3 +76,25 @@ def check_health(self) -> None:
# GPUExecutor will always be healthy as long as
# it's running.
return

def _get_worker_module_and_class(
self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
worker_module_name = "vllm.v1.worker.gpu_worker"
worker_class_name = "Worker"
        # Return None for the optional worker_class_fn so the value matches
        # the three-element annotation above.
        return worker_module_name, worker_class_name, None

def _get_worker_kwargs(
self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
"""Return worker init args for a given rank."""
if distributed_init_method is None:
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
return dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
)
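
These two helpers mirror the V0 executor's worker-construction hooks: the module and class names let a separate process import and build the worker without pickling the class, and `_get_worker_kwargs` supplies its constructor arguments. A minimal sketch of how they would typically be consumed is below; `create_worker_from_hooks` is a hypothetical helper for illustration, not a function in this PR, and it assumes the three-element return (including the optional worker_class_fn) shown above.

# Hypothetical helper, for illustration only: shows how the module/class pair
# and the kwargs dict returned by the hooks above fit together.
import importlib


def create_worker_from_hooks(executor, local_rank: int = 0, rank: int = 0):
    module_name, class_name, worker_class_fn = (
        executor._get_worker_module_and_class())
    if worker_class_fn is not None:
        worker_cls = worker_class_fn()
    else:
        # Resolve the worker class by name so a remote process (e.g. a Ray
        # actor) can construct it without shipping the class object itself.
        worker_cls = getattr(importlib.import_module(module_name), class_name)
    kwargs = executor._get_worker_kwargs(local_rank=local_rank, rank=rank)
    return worker_cls(**kwargs)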