Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
KuilongCui committed Aug 22, 2024
1 parent 7fc2836 commit e0c7ab5
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 31 deletions.
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ BAZEL_CMD = bazel
init:
@git submodule update --init --recursive

# Install the package, but do not include pygloo, which should be installed on demand.
.PHONY: install
install: init
pip install -e .

.PHONY: lint
lint:
pylint --rcfile=.pylintrc ./llumnix
Expand Down
2 changes: 1 addition & 1 deletion llumnix/backends/migration_backend_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def destory_backend(self) -> None:
raise NotImplementedError

@abstractmethod
def warmup(self) -> None:
def warmup(self) -> bool:
raise NotImplementedError

@abstractmethod
Expand Down
2 changes: 1 addition & 1 deletion llumnix/backends/vllm/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
self._run_workers("load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
self._run_workers("reserver_memory_for_migration_backend",
self._run_workers("reserve_memory_for_migration",
migration_config=self.migration_config,
model_config=self.model_config,
cache_config=self.cache_config,
Expand Down
41 changes: 20 additions & 21 deletions llumnix/backends/vllm/migration_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine,
self.rpc_dtype = self.cache_engine.dtype
else:
self.rpc_dtype = torch.float32
logger.warning("Detecting numpy unsupported dtype: {}. Using torch.float32.".format(self.cache_engine.dtype))
logger.warning("Detect numpy unsupported dtype: {}. Using torch.float32.".format(self.cache_engine.dtype))

self.is_driver_worker = is_driver_worker
self.gpu_cache = gpu_cache
Expand All @@ -81,8 +81,9 @@ def init_backend(self, group_name, world_size, rank) -> bool:
def destory_backend(self) -> None:
logger.info("destory rpc migrate backend successfully.")

def warmup(self) -> None:
def warmup(self) -> bool:
self.actor.exec_method.remote(self.is_driver_worker, "do_send", [0])
return True

def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]) -> None:
tot_blocks = len(src_blocks)
Expand Down Expand Up @@ -158,10 +159,8 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine,
if self.backend == 'gloo':
try_import_gloo()
self.cache_device = "cpu"
elif self.backend == 'nccl':
self.cache_device = torch.device(f"cuda:{self.local_rank}")
else:
raise ValueError("backend must be 'gloo' or 'nccl'")
self.cache_device = torch.device(f"cuda:{self.local_rank}")

pin_memory = (self.backend == 'gloo')
self.dummy_cache = torch.empty(
Expand All @@ -181,29 +180,15 @@ def init_group(world_size, rank, backend, group_name):
try:
init_group(world_size, rank, self.backend, group_name)
except FunctionTimedOut:
logger.info("create ray collective group fail (group_name:{}, world_size: {}, rank: {}, backbend: {})."
logger.info("create ray collective group fail (group_name: {}, world_size: {}, rank: {}, backbend: {})."
.format(group_name, world_size, rank, self.backend))
return False

self.group_name = group_name
self.global_world_size = world_size
self.global_rank = rank

logger.info("create ray collective group successfully (group_name:{}, world_size: {}, rank: {}, backbend: {})."
.format(self.group_name, self.global_world_size, self.global_rank, self.backend))
return True

def warmup(self):
if self.global_world_size > 1 and self.group_name is not None:
try:
col.allreduce(self.dummy_cache[0], self.group_name)
# pylint: disable=W0703
except Exception as e:
logger.info("warmup collective group failed (group_name:{}, world_size: {}, rank: {}, backbend: {}), err: {}."
.format(self.group_name, self.global_world_size, self.global_rank, self.backend, e))
return False

logger.info("ray collective group warmup successfully (group_name:{}, world_size: {}, rank: {}, backbend: {})."
logger.info("create ray collective group successfully (group_name: {}, world_size: {}, rank: {}, backbend: {})."
.format(self.group_name, self.global_world_size, self.global_rank, self.backend))
return True

Expand All @@ -221,6 +206,20 @@ def destory_backend(self) -> None:
.format(self.group_name, self.backend, err_info))
self.group_name = None

def warmup(self) -> bool:
if self.global_world_size > 1 and self.group_name is not None:
try:
col.allreduce(self.dummy_cache[0], self.group_name)
# pylint: disable=W0703
except Exception as e:
logger.info("warmup collective group failed (group_name:{}, world_size: {}, rank: {}, backbend: {}), err: {}."
.format(self.group_name, self.global_world_size, self.global_rank, self.backend, e))
return False

logger.info("ray collective group warmup successfully (group_name:{}, world_size: {}, rank: {}, backbend: {})."
.format(self.group_name, self.global_world_size, self.global_rank, self.backend))
return True

def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]) -> None:
tot_blocks = len(src_blocks)
src_rank = ray.get(self.actor.exec_method.remote(self.is_driver_worker, src_handle, "get_global_rank"))
Expand Down
14 changes: 7 additions & 7 deletions llumnix/backends/vllm/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ def load_model(self):
def get_global_rank(self):
return self.global_rank

def reserver_memory_for_migration_backend(self, migration_config: MigrationConfig, model_config: ModelConfig,
cache_config: CacheConfig, parallel_config: ParallelConfig):
def reserve_memory_for_migration(self, migration_config: MigrationConfig, model_config: ModelConfig,
cache_config: CacheConfig, parallel_config: ParallelConfig):
# TODO(s5u13b): move this to arguments checker
if parallel_config.world_size > 1 and migration_config.migration_backend == 'nccl':
logger.warning("NCCL backend is not supported for PP or TP enabled model, using gloo instead.")
logger.warning("nccl backend is not supported for PP or TP enabled model, use gloo instead.")
migration_config.migration_backend = 'gloo'

# for nccl migration backend, reserve gpu memory for dummy cache in migration backend
Expand All @@ -68,7 +68,7 @@ def reserver_memory_for_migration_backend(self, migration_config: MigrationConfi

if cache_config.gpu_memory_utilization <= 0:
raise ValueError("nccl migration backend take {:.4f} gpu memory, which is greater than gpu_memory_utilization {:.4f}. "
"try to increase gpu_memory_utilization or reduce migration-cache-blocks."
"try to increase gpu-memory-utilization or reduce migration-cache-blocks."
.format(migrate_ratio, cache_config.gpu_memory_utilization))

logger.info("nccl migration backend take {:.4f} gpu memory, left gpu_memory_utilization {:.4f} for kv cache."
Expand Down Expand Up @@ -106,7 +106,7 @@ def init_migration(self, instance_id: str, migration_config: MigrationConfig, sr
worker_rank=self.rank,
local_rank=self.local_rank)

def migrate_cache(self, src_worker_handle_list, src_blocks: List[int], dst_blocks: List[int]):
def migrate_cache(self, src_worker_handle_list, src_blocks: List[int], dst_blocks: List[int]) -> None:
src_worker_handle = src_worker_handle_list[self.rank]
try:
self.migration_backend.migrate_cache(src_worker_handle, src_blocks, dst_blocks)
Expand All @@ -119,7 +119,7 @@ def do_recv(self, *args, **kwargs):
def do_send(self, *args, **kwargs):
return self.migration_backend.do_send(*args, **kwargs)

def rebuild_migration_backend(self, instance_rank: Dict[str, int], group_name: str):
def rebuild_migration_backend(self, instance_rank: Dict[str, int], group_name: str) -> bool:
self.migration_backend.destory_backend()

ret = True
Expand All @@ -131,7 +131,7 @@ def rebuild_migration_backend(self, instance_rank: Dict[str, int], group_name: s

return ret

def warmup(self):
def warmup(self) -> bool:
return self.migration_backend.warmup()

def shutdown(self) -> None:
Expand Down
1 change: 0 additions & 1 deletion llumnix/entrypoints/vllm/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import time
import asyncio
import json
import os
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
import uvicorn
Expand Down

0 comments on commit e0c7ab5

Please sign in to comment.