2 changes: 1 addition & 1 deletion tests/basic_correctness/test_basic_correctness.py
@@ -130,7 +130,7 @@ def test_models_distributed(
    # Import VLLM_USE_V1 dynamically to handle patching
    from vllm.envs import VLLM_USE_V1
    if VLLM_USE_V1 and distributed_executor_backend != "mp":
-        pytest.skip(f"Skip {distributed_executor_backend} for V1")
+        os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
Member

Why disable this? I thought this should be able to work with the Ray executor

Collaborator Author

It does not work with the Ray executor. But I think it makes sense for this option to be MP-only, since Ray is a different executor.

Member

I don't think this should have anything to do with the executor that's used. #9826 is the PR that introduced it and has a nice diagram of what's going on.

What happens if we try to run the Ray executor with the AsyncLLM?

Collaborator Author

When VLLM_ENABLE_V1_MULTIPROCESSING is enabled, Ray worker initialization actually hangs, because the Ray environment is not initialized in the newly spawned process.

Member (@tlrmchlsmth), Dec 20, 2024

We should fix this, since AsyncLLM allows detokenization to be overlapped with the forward pass, and is the default in V1 now.

We probably want to fix this by initializing the ray environment in the entry point for process P1 that owns the executor:

vllm/v1/engine/core.py, lines 239 to 276 (at 48edab8):

    @staticmethod
    def run_engine_core(*args, **kwargs):
        """Launch EngineCore busy loop in background process."""

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core = None
        try:
            engine_core = EngineCoreProc(*args, **kwargs)
            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore interrupted.")

        except BaseException as e:
            logger.exception(e)
            raise e

        finally:
            if engine_core is not None:
                engine_core.shutdown()
            engine_core = None
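
Roughly, something like this could run near the top of `run_engine_core`, before `EngineCoreProc(*args, **kwargs)` is constructed (untested sketch; the helper name is mine, and it assumes `ray.init(ignore_reinit_error=True)` is enough to attach the spawned process to the existing cluster):

```python
import ray


def _maybe_init_ray_in_child(vllm_config) -> None:
    """Hypothetical helper: attach the spawned engine-core process to Ray.

    Only relevant when the Ray backend is selected; without this, actors
    created by RayExecutor can hang waiting for an initialized Ray runtime.
    """
    backend = vllm_config.parallel_config.distributed_executor_backend
    if backend == "ray" and not ray.is_initialized():
        # ignore_reinit_error keeps this a no-op if Ray was already set up
        # in this process (e.g. when the cluster address is inherited).
        ray.init(ignore_reinit_error=True)
```

Whether a plain `ray.init()` is sufficient here, or the runtime context from the parent process has to be passed explicitly, would need to be verified in the follow-up.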

I don't think this should be a blocker for this PR but could you look into fixing this in a follow-up soon?

Collaborator Author

Thanks, sounds good. I will follow up soon.


dtype = "half"
max_tokens = 5
Expand Down
7 changes: 6 additions & 1 deletion vllm/v1/engine/llm_engine.py
@@ -21,6 +21,7 @@
from vllm.v1.engine.detokenizer import Detokenizer
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
+from vllm.v1.executor.ray_utils import initialize_ray_cluster

logger = init_logger(__name__)

@@ -110,7 +111,11 @@ def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
        executor_class: Type[Executor]
        distributed_executor_backend = (
            vllm_config.parallel_config.distributed_executor_backend)
-        if distributed_executor_backend == "mp":
+        if distributed_executor_backend == "ray":
+            initialize_ray_cluster(vllm_config.parallel_config)
+            from vllm.v1.executor.ray_executor import RayExecutor
+            executor_class = RayExecutor
+        elif distributed_executor_backend == "mp":
            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
            executor_class = MultiprocExecutor
        else:
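For reference, the new "ray" branch above is taken when the engine is constructed with `distributed_executor_backend="ray"` under V1. A minimal sketch of such an invocation (the model name and tensor-parallel size are placeholders, not part of this PR):

```python
import os

os.environ["VLLM_USE_V1"] = "1"  # assumed: V1 engine enabled via env var

from vllm import LLM

# Selecting the Ray backend routes _get_executor_cls through the new
# RayExecutor path added in this PR; "mp" would pick MultiprocExecutor.
llm = LLM(
    model="facebook/opt-125m",            # placeholder model
    tensor_parallel_size=2,               # needs 2 GPUs visible to Ray
    distributed_executor_backend="ray",
)
print(llm.generate(["Hello, my name is"])[0].outputs[0].text)
```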
339 changes: 339 additions & 0 deletions vllm/v1/executor/ray_executor.py
@@ -0,0 +1,339 @@
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.ray_utils import RayWorkerWrapper, ray
from vllm.v1.outputs import ModelRunnerOutput

if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)


class RayExecutor(Executor):

    def __init__(self, vllm_config: VllmConfig) -> None:
        self.vllm_config = vllm_config
        self.parallel_config = vllm_config.parallel_config
        self.model_config = vllm_config.model_config
        self.forward_dag: Optional[ray.dag.CompiledDAG] = None

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        placement_group = self.parallel_config.placement_group
        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        # A list of workers to run a model.
        self.workers: List[RayWorkerWrapper] = []
        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        # Create the workers.
        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("GPU", 0):
                # Skip bundles that don't have GPUs,
                # as each worker needs one GPU.
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            worker = ray.remote(
                num_cpus=0,
                num_gpus=1,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
            self.workers.append(worker)

        logger.debug("workers: %s", self.workers)
        worker_ips = [
            ray.get(worker.get_node_ip.remote())  # type: ignore[attr-defined]
            for worker in self.workers
        ]
        ip_counts: Dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        worker_to_ip = dict(zip(self.workers, worker_ips))

        def sort_by_driver_then_worker_ip(worker):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (vllm engine),
               it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
               be placed first.
            3. Finally, if the worker is on a node with a smaller IP address,
               it should be placed first. This is simply a tiebreaker to make
               sure the workers are sorted in a deterministic way.
            """
            ip = worker_to_ip[worker]
            return (ip != driver_ip, ip_counts[ip], ip)

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)

        # Get the set of GPU IDs used on each node.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids")

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            # `gpu_ids` can be a list of strings or integers.
            # Convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs), so
            # string sorting is not sufficient.
            # See https://github.com/vllm-project/vllm/issues/5590
            gpu_ids = [int(x) for x in gpu_ids]
            node_gpus[node_id].extend(gpu_ids)

        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        all_ips = set(worker_ips)
        n_ips = len(all_ips)
        n_nodes = len(node_workers)

        if n_nodes != n_ips:
            raise RuntimeError(
                f"Every node should have a unique IP address. Got {n_nodes}"
                f" nodes with node ids {list(node_workers.keys())} and "
                f"{n_ips} unique IP addresses {all_ips}. Please check your"
                " network configuration. If you set `VLLM_HOST_IP` or "
                "`HOST_IP` environment variable, make sure it is unique for"
                " each node.")

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [({
            "CUDA_VISIBLE_DEVICES":
            ",".join(map(str, node_gpus[node_id])),
            "VLLM_TRACE_FUNCTION":
            str(envs.VLLM_TRACE_FUNCTION),
            "VLLM_USE_V1":
            str(int(envs.VLLM_USE_V1)),
            **({
                "VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND
            } if envs.VLLM_ATTENTION_BACKEND is not None else {})
        }, ) for (node_id, _) in worker_node_and_gpu_ids]

        self._env_vars_for_all_workers = (
            all_args_to_update_environment_variables)

        self._run_workers("update_environment_variables",
                          all_args=self._get_env_vars_to_be_updated())

        if len(node_gpus) == 1:
            # In the single-node case, we don't need to get the IP address;
            # the loopback address is sufficient.
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside worker wrapper.
        init_worker_all_kwargs = [
            self._get_worker_kwargs(
                local_rank=node_workers[node_id].index(rank),
                rank=rank,
                distributed_init_method=distributed_init_method,
            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
        ]
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
        self._run_workers("initialize")
        self._run_workers("load_model")

    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update({
            "nsight": {
                "t": "cuda,cudnn,cublas",
                "o": "'worker_process_%p'",
                "cuda-graph-trace": "node",
            }
        })

        return ray_remote_kwargs

    def _get_env_vars_to_be_updated(self):
        return self._env_vars_for_all_workers

    def _get_worker_kwargs(
            self,
            local_rank: int = 0,
            rank: int = 0,
            distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
        """
        Return worker init args for a given rank.
        """
        if distributed_init_method is None:
            distributed_init_method = get_distributed_init_method(
                get_ip(), get_open_port())
        return dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
        )

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """
        Determine the number of available KV blocks.

        This invokes `determine_num_available_blocks` on each worker and takes
        the min of the results, guaranteeing that the selected cache sizes are
        compatible with all workers.

        Returns:
            - tuple[num_gpu_blocks, num_cpu_blocks]
        """
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers("determine_num_available_blocks")

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)

        return num_gpu_blocks, num_cpu_blocks

    def initialize(self, num_gpu_blocks: int) -> None:
        """
        Initialize the KV cache in all workers.
        """
        # NOTE: This is logged in the executor because there can be >1 worker
        # with other executors. We could log in the engine level, but work
        # remains to abstract away the device for non-GPU configurations.
        logger.info("# GPU blocks: %d", num_gpu_blocks)
        self._run_workers("initialize_cache", num_gpu_blocks)
        self._run_workers("compile_or_warm_up_model")

    def _run_workers(
        self,
        method: str,
        *args,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        **kwargs,
    ) -> Any:
        """
        Runs the given method on all workers. Can be used in the following
        ways:

        Args:
            - args/kwargs: All workers share the same args/kwargs
            - all_args/all_kwargs: args/kwargs for each worker are specified
              individually
        """
        count = len(self.workers)
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, 0, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, 0, None)

        ray_worker_refs = [
            worker.execute_method.remote(  # type: ignore[attr-defined]
                method, *worker_args, **worker_kwargs)
            for (worker, worker_args, worker_kwargs
                 ) in zip(self.workers, all_worker_args, all_worker_kwargs)
        ]
        return ray.get(ray_worker_refs)

    def execute_model(
        self,
        scheduler_output,
    ) -> ModelRunnerOutput:
        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag()
        # Only the first worker (with rank 0) returns the execution result.
        # Others return None.
        output = ray.get(self.forward_dag.execute(scheduler_output))[0]
        return output

    def profile(self, is_start=True):
        raise NotImplementedError

    def shutdown(self):
        if hasattr(self, "forward_dag") and self.forward_dag is not None:
            self.forward_dag.teardown()
            import ray
            for worker in self.workers:
                ray.kill(worker)
            self.forward_dag = None

    def check_health(self) -> None:
        logger.debug("Called check_health.")

    def _check_ray_compiled_graph_installation(self):
        import pkg_resources
        from packaging import version

        required_version = version.parse("2.39")
        current_version = version.parse(
            pkg_resources.get_distribution("ray").version)
        if current_version < required_version:
            raise ValueError(f"Ray version {required_version} or higher is "
                             f"required, but found {current_version}")

        import importlib.util
        raycg = importlib.util.find_spec("ray.experimental.compiled_dag_ref")
        if raycg is None:
            raise ValueError("Ray Compiled Graph is not installed. "
                             "Run `pip install ray[adag]` to install it.")

        cupy_spec = importlib.util.find_spec("cupy")
        if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL:
            raise ValueError(
                "cupy is not installed but required since "
                "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set. "
                "Run `pip install ray[adag]` and check cupy installation.")

    def _compiled_ray_dag(self):
        assert self.parallel_config.use_ray
        self._check_ray_compiled_graph_installation()
        from ray.dag import InputNode, MultiOutputNode

        with InputNode() as input_batches:
            outputs = [
                worker.execute_model.bind(  # type: ignore[attr-defined]
                    input_batches) for worker in self.workers
            ]
            forward_dag = MultiOutputNode(outputs)

        return forward_dag.experimental_compile()

    def __del__(self):
        self.shutdown()