Skip to content

Commit

Permalink
SPMD worker
Browse files Browse the repository at this point in the history
Signed-off-by: Stephanie Wang <[email protected]>
  • Loading branch information
stephanie-wang authored and ruisearch42 committed Jul 16, 2024
1 parent 9f4ccec commit 217b862
Show file tree
Hide file tree
Showing 7 changed files with 247 additions and 139 deletions.
7 changes: 7 additions & 0 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from transformers import PreTrainedTokenizer

import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
LoRAConfig, ModelConfig, MultiModalConfig,
ObservabilityConfig, ParallelConfig,
Expand Down Expand Up @@ -48,6 +49,8 @@
logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5

USE_SPMD_WORKER = envs.VLLM_USE_SPMD_WORKER


def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
config = try_get_generation_config(
Expand Down Expand Up @@ -413,6 +416,9 @@ def from_engine_args(
elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutor)
assert not USE_SPMD_WORKER, (
"multiprocessing distributed executor backend does not "
"support VLLM_USE_SPMD_WORKER=1")
executor_class = MultiprocessingGPUExecutor
else:
from vllm.executor.gpu_executor import GPUExecutor
Expand All @@ -424,6 +430,7 @@ def from_engine_args(
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context,
)

return engine

def __reduce__(self):
Expand Down
8 changes: 8 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
VLLM_USE_SPMD_WORKER: bool = False
VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
Expand Down Expand Up @@ -261,6 +262,13 @@ def get_default_config_root():
"VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS":
lambda: bool(os.getenv("VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS", False)),

# If the env var is set, then all workers will execute as separate
# processes from the engine, and we use the same mechanism to trigger
# execution on all workers.
# Run vLLM with VLLM_USE_SPMD_WORKER=1 to enable it.
"VLLM_USE_SPMD_WORKER":
lambda: bool(os.getenv("VLLM_USE_SPMD_WORKER", 0)),

# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
Expand Down
8 changes: 5 additions & 3 deletions vllm/executor/distributed_gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,18 @@ def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks=num_cpu_blocks)

def execute_model(
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[SamplerOutput]]:
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
async_run_tensor_parallel_workers_only=True,
**self.extra_execute_model_run_workers_kwargs)

# Only the driver worker returns the sampling results.
return self._driver_execute_model(execute_model_req)
driver_outputs = self._driver_execute_model(execute_model_req)
assert driver_outputs is not None
return driver_outputs

def stop_remote_worker_execution_loop(self) -> None:
if self.parallel_worker_tasks is None:
Expand Down
164 changes: 100 additions & 64 deletions vllm/executor/ray_gpu_executor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import os
import pickle
from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
Expand All @@ -23,12 +22,28 @@

logger = init_logger(__name__)

# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
# Currently, this requires USE_SPMD_WORKER=True.
USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
# If the env var is set, then we do not distinguish between the "driver worker"
# vs other workers. Also, the rank 0 worker will be executed in a remote Ray
# worker. Currently this requires USE_RAY_COMPILED_DAG=True.
USE_SPMD_WORKER = envs.VLLM_USE_SPMD_WORKER


class RayGPUExecutor(DistributedGPUExecutor):

def _init_executor(self) -> None:
if USE_RAY_COMPILED_DAG:
assert USE_SPMD_WORKER, (
"VLLM_USE_RAY_COMPILED_DAG=1 requires VLLM_USE_SPMD_WORKER=1")
if USE_SPMD_WORKER:
# TODO: Support SPMD worker for non-DAG Ray executor.
assert USE_RAY_COMPILED_DAG, ("VLLM_USE_SPMD_WORKER=1 requires "
"VLLM_USE_RAY_COMPILED_DAG=1")

assert self.parallel_config.distributed_executor_backend == "ray"
placement_group = self.parallel_config.placement_group

Expand All @@ -40,11 +55,7 @@ def _init_executor(self) -> None:
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)

self.forward_dag = None
if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag()
self.extra_execute_model_run_workers_kwargs[
"use_ray_compiled_dag"] = True
self.forward_dag: Optional["ray.dag.CompiledDAG"] = None

def _configure_ray_workers_use_nsight(self,
ray_remote_kwargs) -> Dict[str, Any]:
Expand Down Expand Up @@ -110,21 +121,24 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
trust_remote_code=self.model_config.trust_remote_code,
)

worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
else:
# Else, added to the list of workers.
if USE_SPMD_WORKER:
self.workers.append(worker)

if self.driver_dummy_worker is None:
else:
worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
else:
# Else, added to the list of workers.
self.workers.append(worker)

if not USE_SPMD_WORKER and self.driver_dummy_worker is None:
raise ValueError(
"Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
Expand Down Expand Up @@ -240,9 +254,23 @@ def _driver_execute_model(
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
assert not USE_SPMD_WORKER, (
"driver_worker does not exist for VLLM_USE_SPMD_WORKER=1")
return self.driver_worker.execute_method("execute_model",
execute_model_req)

def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if not USE_SPMD_WORKER:
return super().execute_model(execute_model_req)

if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

outputs = ray.get(self.forward_dag.execute(execute_model_req))
return outputs

def _run_workers(
self,
method: str,
Expand All @@ -252,7 +280,6 @@ def _run_workers(
all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False,
max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
**kwargs,
) -> Any:
"""Runs the given method on all workers. Can be used in the following
Expand Down Expand Up @@ -280,64 +307,57 @@ def _run_workers(
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, 1, None)

if use_ray_compiled_dag:
# Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it.
assert self.forward_dag is not None
output_channels = self.forward_dag.execute(1)
ray_worker_outputs = []
else:
# Start the ray workers first.
ray_workers = self.workers
if async_run_tensor_parallel_workers_only:
ray_workers = self.non_driver_workers
ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args,
**worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(ray_workers, all_worker_args, all_worker_kwargs)
]
# Start the ray workers first.
ray_workers = self.workers
if async_run_tensor_parallel_workers_only:
ray_workers = self.non_driver_workers
ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args, **worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(self.workers, all_worker_args, all_worker_kwargs)
]

if async_run_tensor_parallel_workers_only:
# Just return futures
return ray_worker_outputs

driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
driver_worker_output = []
# In SPMD mode, the driver worker is the same as any other worker,
# so we only explicitly execute on the driver worker if using a
# non-SPMD worker class.
if not USE_SPMD_WORKER:
driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

# Start the driver worker after all the ray workers.
if not use_dummy_driver:
driver_worker_output = [
self.driver_worker.execute_method(method, *driver_args,
**driver_kwargs)
]
else:
assert self.driver_dummy_worker is not None
driver_worker_output = [
ray.get(
self.driver_dummy_worker.execute_method.remote(
method, *driver_args, **driver_kwargs))
]

# Start the driver worker after all the ray workers.
if not use_dummy_driver:
driver_worker_output = self.driver_worker.execute_method(
method, *driver_args, **driver_kwargs)
else:
assert self.driver_dummy_worker is not None
driver_worker_output = ray.get(
self.driver_dummy_worker.execute_method.remote(
method, *driver_args, **driver_kwargs))
# Get the results of the ray workers.
if self.workers:
if use_ray_compiled_dag:
try:
ray_worker_outputs = [
pickle.loads(chan.begin_read())
for chan in output_channels
]
finally:
# Has to call end_read in order to reuse the DAG.
for chan in output_channels:
chan.end_read()
else:
ray_worker_outputs = ray.get(ray_worker_outputs)
ray_worker_outputs = ray.get(ray_worker_outputs)

return [driver_worker_output] + ray_worker_outputs
return driver_worker_output + ray_worker_outputs

def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
ray.get(parallel_worker_tasks)

def _compiled_ray_dag(self):
def _compiled_ray_dag(self, enable_asyncio: bool):
import pkg_resources

# TODO(swang): Upgrade version.
required_version = "2.9"
current_version = pkg_resources.get_distribution("ray").version
if current_version < required_version:
Expand All @@ -355,7 +375,7 @@ def _compiled_ray_dag(self):
bind( # type: ignore[attr-defined]
input_data) for worker in self.workers
])
return forward_dag.experimental_compile()
return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)


class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
Expand All @@ -364,10 +384,24 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_exec_method = make_async(self.driver_worker.execute_method)

async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if not USE_SPMD_WORKER:
return await super().execute_model_async(execute_model_req)

if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)

outputs = await self.forward_dag.execute_async(execute_model_req)
return await outputs

async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
assert not USE_SPMD_WORKER, (
"driver_worker does not exist for VLLM_USE_SPMD_WORKER=1")
if self.pp_locks is None:
# This locks each pipeline parallel stage so multiple virtual
# engines can't execute on the same stage at the same time
Expand Down Expand Up @@ -401,6 +435,8 @@ async def _run_task_with_lock(task, lock, *args, **kwargs):
return results[-1]

async def _start_worker_execution_loop(self):
assert not USE_SPMD_WORKER, (
"worker loop is disabled for VLLM_USE_SPMD_WORKER=1")
coros = [
worker.execute_method.remote("start_worker_execution_loop")
for worker in self.non_driver_workers
Expand Down
20 changes: 10 additions & 10 deletions vllm/executor/ray_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import pickle
from typing import List, Optional, Tuple

from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest
from vllm.utils import get_ip, is_hip, is_xpu
from vllm.worker.worker_base import WorkerWrapperBase

Expand Down Expand Up @@ -31,16 +31,16 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
gpu_ids = ray.get_gpu_ids()
return node_id, gpu_ids

def execute_model_compiled_dag_remote(self, ignored):
"""Used only when compiled DAG is enabled."""
import torch
if not self.compiled_dag_cuda_device_set:
torch.cuda.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True
def execute_model(self, execute_model_req: ExecuteModelRequest):
"""Used only when SPMD worker and compiled DAG are both
enabled."""
## TODO(swang): remove?
#import torch
#if not self.compiled_dag_cuda_device_set:
# torch.cuda.set_device(self.worker.device)
# self.compiled_dag_cuda_device_set = True

output = self.worker.execute_model()
output = pickle.dumps(output)
return output
return self.worker.execute_model(execute_model_req)

ray_import_err = None

Expand Down
Loading

0 comments on commit 217b862

Please sign in to comment.