
Commit

update
ZeldaHuang committed Dec 16, 2024
1 parent 6044d19 commit a1bd218
Showing 19 changed files with 353 additions and 203 deletions.
5 changes: 3 additions & 2 deletions llumnix/backends/backend_interface.py
@@ -195,7 +195,8 @@ def pre_alloc(self,
request_id: str,
request_status: RequestStatus,
request_arrival_time: float,
block_num: int) -> List[int]:
block_num: int,
token_ids: List[int]) -> List[int]:
"""Pre-allocates cache blocks for a migrating request.
This method selects a specified number of free cache blocks to be reserved for an incoming
@@ -211,7 +212,7 @@ def pre_alloc(self,
request_status: The status (waiting/running) of the request.
request_arrival_time: The arrival time of the request.
block_num: The number of cache blocks that need to be pre-allocated for the request.
token_ids: The token IDs of the request.
Returns:
A list of integers where each integer represents the block table reserved for the migration request.
"""
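
To illustrate the extended interface, here is a minimal sketch of a caller. DummyDstScheduler is a made-up stand-in (not from this repository); it only shows how the new token_ids argument flows into pre_alloc alongside block_num.

# Hypothetical sketch: a destination-side scheduler satisfying the extended
# pre_alloc signature introduced in this commit.
from typing import List

class DummyDstScheduler:
    """Stand-in for the scheduler on the migration destination."""
    def __init__(self, num_free_blocks: int = 16):
        self.free_blocks = list(range(num_free_blocks))
        self.pre_alloc_cache_dict = {}

    def pre_alloc(self,
                  request_id: str,
                  request_status: str,
                  request_arrival_time: float,
                  block_num: int,
                  token_ids: List[int]) -> List[int]:
        # Reserve block_num free cache blocks for the incoming migrating request.
        blocks = self.free_blocks[:block_num]
        self.free_blocks = self.free_blocks[block_num:]
        self.pre_alloc_cache_dict[request_id] = blocks
        return blocks

scheduler = DummyDstScheduler()
# token_ids is the new argument; block_num would normally be derived from it.
reserved = scheduler.pre_alloc("req-0", "waiting", 0.0, block_num=2,
                               token_ids=[101, 2023, 2003, 1037])
print(reserved)  # e.g. [0, 1]
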
176 changes: 141 additions & 35 deletions llumnix/backends/vllm/executor.py
@@ -15,19 +15,19 @@
import asyncio

from collections import defaultdict
from typing import List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple, Type
import ray
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy
# pylint: disable=unused-import
from ray.util.placement_group import PlacementGroup

from vllm.executor.executor_base import ExecutorBase
from vllm.model_executor.layers.sampler import SamplerOutput, CompletionSequenceGroupOutput
from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync, RayWorkerWrapper,\
get_distributed_init_method, get_ip, get_vllm_instance_id, get_open_port
from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync, RayWorkerWrapper, envs, \
get_ip, get_vllm_instance_id, get_distributed_init_method, get_open_port
from vllm.worker.worker_base import WorkerBase

from vllm import envs
from vllm.sequence import Logprob, SequenceOutput, SequenceGroupOutput, ExecuteModelRequest
from vllm.sequence import Logprob, SequenceOutput, ExecuteModelRequest
from vllm.utils import GiB_bytes

from llumnix.internal_config import MigrationConfig
@@ -40,11 +40,12 @@
class LlumnixRayGPUExecutor(RayGPUExecutorAsync):
node_id: str = None
migration_config: MigrationConfig = None
last_inference_latency:int = 0

def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
self.last_inference_latency = 0
if self.parallel_config.tensor_parallel_size == 1:
if (self.parallel_config.tensor_parallel_size == 1
and self.parallel_config.pipeline_parallel_size == 1):
# For single GPU case, we use a ray worker with constrained memory.
num_gpus = self.cache_config.gpu_memory_utilization
else:
@@ -57,13 +58,21 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
# The remaining workers are the actual ray actors.
self.workers: List[RayWorkerWrapper] = []

# Used in ray compiled DAG: indexed first by PP rank,
# and then TP rank. In other words, the inner list is
# the TP group of workers for a PP rank.
self.pp_tp_workers: List[List[RayWorkerWrapper]] = []

if self.parallel_config.ray_workers_use_nsight:
ray_remote_kwargs = self._configure_ray_workers_use_nsight(
ray_remote_kwargs)

logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)

# Create the workers.
driver_ip = get_ip()
node_id = self.node_id
worker_wrapper_kwargs = self._get_worker_wrapper_args()
for rank in range(self.parallel_config.world_size):
if placement_group:
bundle = placement_group.bundle_specs[rank+1]
@@ -78,51 +87,94 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
node_id=node_id,
soft=False,
)

worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
max_concurrency=2,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(
worker_module_name="llumnix.backends.vllm.worker",
worker_class_name="MigrationWorker",
trust_remote_code=self.model_config.trust_remote_code,
)
worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
worker_module_name="llumnix.backends.vllm.worker",
worker_class_name="MigrationWorker",
trust_remote_code=self.model_config.trust_remote_code,
)
else:
# Else, added to the list of workers.
)(RayWorkerWrapper).remote(**worker_wrapper_kwargs)

if self.use_ray_spmd_worker:
self.workers.append(worker)
else:
worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
**worker_wrapper_kwargs)
else:
# Else, added to the list of workers.
self.workers.append(worker)

if self.driver_dummy_worker is None:
logger.debug("workers: %s", self.workers)
logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
raise ValueError(
"Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
"GPU node.")

worker_ips = [
ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined]
for worker in self.workers
]
ip_counts: Dict[str, int] = {}
for ip in worker_ips:
ip_counts[ip] = ip_counts.get(ip, 0) + 1

def sort_by_driver_then_worker_ip(worker):
"""
Sort the workers based on 3 properties:
1. If the worker is on the same node as the driver (vllm engine),
it should be placed first.
2. Then, if the worker is on a node with fewer workers, it should
be placed first.
3. Finally, if the work is on a node with smaller IP address, it
should be placed first.
"""
ip = ray.get(worker.get_node_ip.remote())
return (ip != driver_ip, ip_counts[ip], ip)

# After sorting, the workers on the same node will be
# close to each other, and the workers on the driver
# node will be placed first.
self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)

# Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
use_dummy_driver=True)

node_workers = defaultdict(list)
node_gpus = defaultdict(list)
node_workers = defaultdict(list) # node id -> list of worker ranks
node_gpus = defaultdict(list) # node id -> list of gpu ids

for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
node_workers[node_id].append(i)
# `gpu_ids` can be a list of strings or integers.
# convert them to integers for consistency.
# NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
# string sorting is not sufficient.
# see https://github.com/vllm-project/vllm/issues/5590
gpu_ids = [int(x) for x in gpu_ids]
node_gpus[node_id].extend(gpu_ids)
for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids)

# pylint: disable=invalid-name
all_ips = set(worker_ips + [driver_ip])
n_ips = len(all_ips)
n_nodes = len(node_workers)

if n_nodes != n_ips:
raise RuntimeError(
f"Every node should have a unique IP address. Got {n_nodes}"
f" nodes with node ids {list(node_workers.keys())} and "
f"{n_ips} unique IP addresses {all_ips}. Please check your"
" network configuration. If you set `VLLM_HOST_IP` or "
"`HOST_IP` environment variable, make sure it is unique for"
" each node.")

VLLM_INSTANCE_ID = get_vllm_instance_id()

# Set environment variables for the driver and workers.
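
The sort key defined above can be checked in isolation; a small sketch with made-up IP addresses (no Ray required):

# Illustrative sketch: how the sort_by_driver_then_worker_ip key orders workers.
driver_ip = "10.0.0.1"
worker_ips = ["10.0.0.2", "10.0.0.1", "10.0.0.3", "10.0.0.2"]

ip_counts = {}
for ip in worker_ips:
    ip_counts[ip] = ip_counts.get(ip, 0) + 1

# Same key as above: driver node first, then nodes with fewer workers,
# then lexicographically smaller IPs.
ordered = sorted(worker_ips, key=lambda ip: (ip != driver_ip, ip_counts[ip], ip))
print(ordered)  # ['10.0.0.1', '10.0.0.3', '10.0.0.2', '10.0.0.2']
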
@@ -133,10 +185,27 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
VLLM_INSTANCE_ID,
"VLLM_TRACE_FUNCTION":
str(envs.VLLM_TRACE_FUNCTION),
**({
"VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND
} if envs.VLLM_ATTENTION_BACKEND is not None else {})
}, ) for (node_id, _) in worker_node_and_gpu_ids]

self._env_vars_for_all_workers = (
all_args_to_update_environment_variables)

self._run_workers("update_environment_variables",
all_args=all_args_to_update_environment_variables)
all_args=self._get_env_vars_to_be_updated())

if len(node_gpus) == 1:
# in single node case, we don't need to get the IP address.
# the loopback address is sufficient
# NOTE: a node may have several IP addresses, one for each
# network interface. `get_ip()` might return any of them,
# while they might not work for communication inside the node
# if the network setup is complicated. Using the loopback address
# solves this issue, as it always works for communication inside
# the node.
driver_ip = "127.0.0.1"
distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())

@@ -154,12 +223,49 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
self._run_workers("load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
self._run_workers("reserve_memory_for_migration",
migration_config=self.migration_config,
model_config=self.model_config,
cache_config=self.cache_config,
parallel_config=self.parallel_config)

if self.use_ray_spmd_worker:
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([])
for tp_rank in range(
self.parallel_config.tensor_parallel_size):
# PP=2, TP=4
# pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
rank = (pp_rank * self.parallel_config.tensor_parallel_size
) + tp_rank
assert len(self.pp_tp_workers[pp_rank]) == tp_rank
assert pp_rank < len(self.pp_tp_workers)
self.pp_tp_workers[pp_rank].append(self.workers[rank])

# This is the list of workers that are rank 0 of each TP group EXCEPT
# global rank 0. These are the workers that will broadcast to the
# rest of the workers.
self.tp_driver_workers: List[RayWorkerWrapper] = []
# This is the list of workers that are not drivers and not the first
# worker in a TP group. These are the workers that will be
# broadcasted to.
self.non_driver_workers: List[RayWorkerWrapper] = []

# Enforce rank order for correct rank to return final output.
for index, worker in enumerate(self.workers):
# The driver worker is rank 0 and not in self.workers.
rank = index + 1
if rank % self.parallel_config.tensor_parallel_size == 0:
self.tp_driver_workers.append(worker)
else:
self.non_driver_workers.append(worker)

def _get_worker_module_and_class(
self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
worker_class_fn = None
if self.scheduler_config.is_multi_step:
raise NotImplementedError
elif self.speculative_config:
raise NotImplementedError
else:
worker_module_name = "llumnix.backends.vllm.worker"
worker_class_name = "MigrationWorker"
return (worker_module_name, worker_class_name, worker_class_fn)
async def execute_model_async(self, *args, **kwargs):
t0 = time.time()
outputs = await super().execute_model_async(*args, **kwargs)
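
As a sanity check on the rank bookkeeping added in this file, a self-contained sketch with plain integers standing in for RayWorkerWrapper actors (assumed PP=2, TP=4):

# Self-contained sketch: integer ranks stand in for RayWorkerWrapper actors.
pipeline_parallel_size = 2
tensor_parallel_size = 4
world_size = pipeline_parallel_size * tensor_parallel_size
workers = list(range(world_size))  # ranks 0..7; rank 0 is the driver worker

# Group workers by PP rank, then TP rank (matches the nested loop above).
pp_tp_workers = []
for pp_rank in range(pipeline_parallel_size):
    pp_tp_workers.append([])
    for tp_rank in range(tensor_parallel_size):
        rank = pp_rank * tensor_parallel_size + tp_rank
        pp_tp_workers[pp_rank].append(workers[rank])
print(pp_tp_workers)  # [[0, 1, 2, 3], [4, 5, 6, 7]]

# Split the non-rank-0 workers into TP-group drivers and the rest,
# mirroring the enumerate(self.workers) loop (driver rank 0 excluded).
tp_driver_workers, non_driver_workers = [], []
for index, worker in enumerate(workers[1:]):
    rank = index + 1
    if rank % tensor_parallel_size == 0:
        tp_driver_workers.append(worker)
    else:
        non_driver_workers.append(worker)
print(tp_driver_workers)   # [4]
print(non_driver_workers)  # [1, 2, 3, 5, 6, 7]
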
44 changes: 23 additions & 21 deletions llumnix/backends/vllm/llm_engine.py
@@ -155,7 +155,7 @@ def from_engine_args(
from llumnix.backends.vllm.executor import SimGPUExecutor
executor_class = SimGPUExecutor
executor_class.latency_mem = latency_mem
elif engine_config.parallel_config.worker_use_ray:
elif engine_config.parallel_config.use_ray:
from llumnix.backends.vllm.executor import LlumnixRayGPUExecutor
executor_class = LlumnixRayGPUExecutor
executor_class.migration_config = migration_config
@@ -275,8 +275,10 @@ def __init__(
instance_id=instance_id,
placement_group=placement_group,
node_id=node_id)
self.engine.scheduler = [SchedulerLlumnix(self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config)]
self.engine.scheduler.add_update_instance_info_callback(self.engine.update_instance_info)
self.engine.scheduler = [SchedulerLlumnix(self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config)
for _ in range(engine_args.pipeline_parallel_size)]
for vid in range(engine_args.pipeline_parallel_size):
self.engine.scheduler[vid].add_update_instance_info_callback(self.engine.update_instance_info)
self.engine.output_processor.scheduler = self.engine.scheduler
self.instance_id = instance_id
self.worker_handle_list = self.engine.model_executor.workers.copy()
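
A minimal sketch of the wiring above, using dummy stand-in classes (assumed names, not from this repository): with pipeline parallelism the engine now holds one scheduler per virtual engine, and each one is registered with the engine's update_instance_info callback.

# Sketch with stand-in classes; one scheduler per pipeline stage.
class DummyScheduler:
    def __init__(self, stage: int):
        self.stage = stage
        self.callback = None
    def add_update_instance_info_callback(self, cb):
        self.callback = cb

class DummyEngine:
    def update_instance_info(self, info):
        print(f"instance info updated: {info}")

engine = DummyEngine()
pipeline_parallel_size = 2
engine.scheduler = [DummyScheduler(stage) for stage in range(pipeline_parallel_size)]
for vid in range(pipeline_parallel_size):
    engine.scheduler[vid].add_update_instance_info_callback(engine.update_instance_info)

# Migration-related calls elsewhere in BackendVLLM go through scheduler[0].
engine.scheduler[0].callback({"num_running_requests": 3})
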
@@ -337,8 +339,8 @@ def commit_dst_request(self, backend_request: SequenceGroupLlumnix) -> None:
seq = backend_request.get_seqs()[0]
seq.seq_id = next(self.engine.seq_counter)
logger.info("pop request {} from pre_alloc_cache_dict".format(backend_request.request_id))
pre_alloc_blocks = self.engine.scheduler.pre_alloc_cache_dict.pop(backend_request.request_id)
self.engine.scheduler.block_manager.add_block_table(pre_alloc_blocks, seq.seq_id)
pre_alloc_blocks = self.engine.scheduler[0].pre_alloc_cache_dict.pop(backend_request.request_id)
self.engine.scheduler[0].block_manager.add_block_table(pre_alloc_blocks, seq.seq_id)
backend_request.reset_migration_args_dst()
assert backend_request.status in [RequestStatus.WAITING_MIGRATING, RequestStatus.RUNNING_MIGRATING], \
"The status of request migrated to dst instance should be \
@@ -369,47 +371,47 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
request_ids = set(request_id)
return self.engine.abort_request(request_ids)

def get_running_queue(self) -> Deque[SequenceGroupLlumnix]:
return self.engine.scheduler.get_running_queue()
def get_running_queue(self) -> List[SequenceGroupLlumnix]:
return self.engine.scheduler[0].get_running_queue()

def get_waiting_queue(self) -> Deque[SequenceGroupLlumnix]:
return self.engine.scheduler.get_waiting_queue()

def get_request_incremental_blocks(self, *args, **kwargs) -> List[int]:
return self.engine.scheduler.get_request_incremental_blocks(*args, **kwargs)
return self.engine.scheduler[0].get_request_incremental_blocks(*args, **kwargs)

def remove_running_request(self, *args, **kwargs) -> bool:
return self.engine.scheduler.remove_running_request(*args, **kwargs)
def remove_running_request(self, *args, **kwargs) -> None:
return self.engine.scheduler[0].remove_running_request(*args, **kwargs)

def remove_waiting_request(self, *args, **kwargs) -> bool:
return self.engine.scheduler.remove_waiting_request(*args, **kwargs)

def add_migrating_out_request_last_stage(self, *args, **kwargs) -> None:
return self.engine.scheduler.add_migrating_out_request_last_stage(*args, **kwargs)
return self.engine.scheduler[0].add_migrating_out_request_last_stage(*args, **kwargs)

def remove_migrating_out_request_last_stage(self, *args, **kwargs) -> None:
return self.engine.scheduler.remove_migrating_out_request_last_stage(*args, **kwargs)
return self.engine.scheduler[0].remove_migrating_out_request_last_stage(*args, **kwargs)

def pop_migrating_out_requests_last_stage(self, *args, **kwargs) -> List[Any]:
return self.engine.scheduler.pop_migrating_out_requests_last_stage(*args, **kwargs)
return self.engine.scheduler[0].pop_migrating_out_requests_last_stage(*args, **kwargs)

def pre_alloc(self, *args, **kwargs) -> List[int]:
return self.engine.scheduler.pre_alloc(*args, **kwargs)
return self.engine.scheduler[0].pre_alloc(*args, **kwargs)

def should_abort_migration(self, *args, **kwargs) -> bool:
return self.engine.scheduler.should_abort_migration(*args, **kwargs)
return self.engine.scheduler[0].should_abort_migration(*args, **kwargs)

def add_running_request(self, *args, **kwargs) -> None:
return self.engine.scheduler.add_running_request(*args, **kwargs)
return self.engine.scheduler[0].add_running_request(*args, **kwargs)

def add_waiting_request(self, *args, **kwargs) -> None:
return self.engine.scheduler.add_waiting_request(*args, **kwargs)
def is_request_running(self, *args, **kwargs) -> bool:
return self.engine.scheduler[0].is_request_running(*args, **kwargs)

def free_dst_pre_alloc_cache(self, *args, **kwargs) -> None:
return self.engine.scheduler.free_dst_pre_alloc_cache(*args, **kwargs)
return self.engine.scheduler[0].free_dst_pre_alloc_cache(*args, **kwargs)

def free_src_request(self, backend_request: SequenceGroup) -> None:
return self.engine.scheduler.free_src_request(backend_request)
return self.engine.scheduler[0].free_src_request(backend_request)

def get_all_request_ids(self) -> List[str]:
return self.engine.scheduler.get_all_request_ids()
return self.engine.scheduler[0].get_all_request_ids()
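
Finally, a toy sketch (dummy scheduler and block manager, assumed behavior) of the destination-side commit path changed above in commit_dst_request: blocks reserved earlier by pre_alloc on scheduler[0] are popped from pre_alloc_cache_dict and registered under the request's new sequence id.

# Toy sketch of the dst-side commit path.
class DummyBlockManager:
    def __init__(self):
        self.block_tables = {}
    def add_block_table(self, blocks, seq_id):
        self.block_tables[seq_id] = list(blocks)

class DummyScheduler:
    def __init__(self):
        self.pre_alloc_cache_dict = {}
        self.block_manager = DummyBlockManager()

scheduler = DummyScheduler()
scheduler.pre_alloc_cache_dict["req-0"] = [7, 8, 9]   # reserved by pre_alloc earlier

new_seq_id = 42                                        # stand-in for next(seq_counter)
pre_alloc_blocks = scheduler.pre_alloc_cache_dict.pop("req-0")
scheduler.block_manager.add_block_table(pre_alloc_blocks, new_seq_id)
print(scheduler.block_manager.block_tables)            # {42: [7, 8, 9]}
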