Async update for vineyard llm cache (#12)
Vineyard is designed to drop updates whose prefix chunks are not
already present in the cache, which imposes an ordering requirement:
for each sequence, updates must be applied in the order they were
issued. For simplicity, we use a single thread to process all updates sequentially.

Signed-off-by: Haiyang Shi <[email protected]>
DwyaneShi authored Nov 15, 2024
1 parent 28ef00e commit 22e90d8
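
As a rough illustration of that single-threaded design, the sketch below drains a task queue with one worker so updates are applied in issued order. It is a minimal sketch under assumed names: the task shape and the cache.update call are illustrative, not the actual vineyard API.

import queue
import threading

# Hypothetical task: (prefix_chunks, new_chunk) for one sequence.
update_queue: queue.Queue = queue.Queue(maxsize=32)

def _update_worker(cache) -> None:
    # A single worker drains the queue, so updates are applied in
    # exactly the order they were enqueued; vineyard drops an update
    # whose prefix chunks are not already present in the cache.
    while True:
        task = update_queue.get()
        if task is None:  # shutdown sentinel
            break
        prefix_chunks, new_chunk = task
        cache.update(prefix_chunks, new_chunk)  # illustrative call
        update_queue.task_done()

def start_async_updates(cache) -> threading.Thread:
    worker = threading.Thread(target=_update_worker, args=(cache,),
                              daemon=True)
    worker.start()
    return worker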
Showing 4 changed files with 343 additions and 48 deletions.
12 changes: 12 additions & 0 deletions vllm/entrypoints/llm.py
@@ -764,6 +764,18 @@ def _run_engine(
np.mean(normalized_times_to_first_token), np.std(normalized_times_to_first_token), np.median(normalized_times_to_first_token),

)
metrics = self.llm_engine.cache_service_metrics
logger.info(
    "Cache service errors: err_async_update_task_queue_full %d",
    metrics.err_async_update_task_queue_full,
)

with metrics.lock:
    logger.info(
        "Cache service async update: "
        "time queueing avg %.4f std %.4f median %.4f, "
        "time execution avg %.4f std %.4f median %.4f, "
        "counter updated avg %.2f std %.2f median %.2f",
        np.mean(metrics.time_async_update_queue),
        np.std(metrics.time_async_update_queue),
        np.median(metrics.time_async_update_queue),
        np.mean(metrics.time_async_update_exec),
        np.std(metrics.time_async_update_exec),
        np.median(metrics.time_async_update_exec),
        np.mean(metrics.counter_async_update_updated),
        np.std(metrics.counter_async_update_updated),
        np.median(metrics.counter_async_update_updated),
    )


# Restore original behavior
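The logging above reads fields off a shared cache_service_metrics object under its lock. A minimal sketch of what such a container could look like; the field names come from the calls above, while the dataclass shape and element types are assumptions:

import threading
from dataclasses import dataclass, field
from typing import List

@dataclass
class CacheServiceMetrics:
    # Guards the lists below, which the async update worker appends to
    # and the logging code in _run_engine reads.
    lock: threading.Lock = field(default_factory=threading.Lock)
    # Seconds each async update task waited in the queue.
    time_async_update_queue: List[float] = field(default_factory=list)
    # Seconds each async update task spent executing.
    time_async_update_exec: List[float] = field(default_factory=list)
    # Number of entries updated per task (assumed unit).
    counter_async_update_updated: List[int] = field(default_factory=list)
    # Updates dropped because the async task queue was full.
    err_async_update_task_queue_full: int = 0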
27 changes: 27 additions & 0 deletions vllm/envs.py
@@ -64,6 +64,11 @@
VLLM_PLUGINS: Optional[List[str]] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VINEYARD_CACHE_CPU_MEM_LIMIT_GB: float = 10.0
VINEYARD_CACHE_ENABLE_ASYNC_UPDATE: bool = False
VINEYARD_CACHE_ASYNC_UPDATE_CPU_MEM_UTIL: float = 0.2
VINEYARD_CACHE_MIN_INFLIGHT_TASKS: int = 1
VINEYARD_CACHE_MAX_INFLIGHT_TASKS: int = 32


def get_default_cache_root():
@@ -431,6 +436,28 @@ def get_default_config_root():
lambda:
(os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() in
("1", "true")),

# CPU memory limit for vineyard cache, in GB
"VINEYARD_CACHE_CPU_MEM_LIMIT_GB":
lambda: float(os.getenv("VINEYARD_CACHE_CPU_MEM_LIMIT_GB", "10")),

# If set, vineyard updates the cache asynchronously
"VINEYARD_CACHE_ENABLE_ASYNC_UPDATE": lambda: (
os.environ.get("VINEYARD_CACHE_ENABLE_ASYNC_UPDATE", "0").strip().lower()
in ("1", "true")
),

# CPU memory utilization for async update, default 20%
"VINEYARD_CACHE_ASYNC_UPDATE_CPU_MEM_UTIL":
lambda: float(os.getenv("VINEYARD_CACHE_ASYNC_UPDATE_CPU_MEM_UTIL", "0.2")),

# Min number of inflight async tasks for vineyard cache
"VINEYARD_CACHE_MIN_INFLIGHT_TASKS":
lambda: int(os.getenv("VINEYARD_CACHE_MIN_INFLIGHT_TASKS", "1")),

# Max number of inflight async tasks for vineyard cache
"VINEYARD_CACHE_MAX_INFLIGHT_TASKS":
lambda: int(os.getenv("VINEYARD_CACHE_MAX_INFLIGHT_TASKS", "32")),
}

# end-env-vars-definition
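Taken together, these variables bound the memory and concurrency of the async update path. A sketch of how a consumer might combine them (the sizing formula is an assumption, not documented behavior; vllm.envs resolves these names via module-level attribute access):

from typing import Tuple

import vllm.envs as envs

def async_update_budget_bytes() -> int:
    # Hypothetical sizing: reserve a fraction of the vineyard CPU
    # cache limit for in-flight async updates.
    limit_bytes = envs.VINEYARD_CACHE_CPU_MEM_LIMIT_GB * (1 << 30)
    return int(limit_bytes * envs.VINEYARD_CACHE_ASYNC_UPDATE_CPU_MEM_UTIL)

def inflight_task_bounds() -> Tuple[int, int]:
    # The task pool is clamped between the configured min and max.
    return (envs.VINEYARD_CACHE_MIN_INFLIGHT_TASKS,
            envs.VINEYARD_CACHE_MAX_INFLIGHT_TASKS)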
76 changes: 76 additions & 0 deletions vllm/utils.py
@@ -18,6 +18,7 @@
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
Hashable, List, Literal, Optional, OrderedDict, Set, Tuple,
Type, TypeVar, Union, overload)
from queue import Empty, Queue
from uuid import uuid4

import numpy as np
@@ -1249,3 +1250,78 @@ def dec(self, num=1):
@property
def value(self):
return self._value


class ObjectPool(Generic[T]):
    def __init__(
        self,
        min_pool_size: int,
        max_pool_size: int,
        object_creator: Callable[[], T],
    ):
        """
        Initialize the ObjectPool.

        Args:
            min_pool_size: The min number of objects to maintain
                in the pool.
            max_pool_size: The max number of objects that can be
                in the pool.
            object_creator: A function that returns a new object.
        """
        self.min_pool_size: int = min_pool_size
        self.max_pool_size: int = max_pool_size
        self.object_creator: Callable[[], T] = object_creator
        self._pool: Queue = Queue(maxsize=max_pool_size)
        self._lock: threading.Lock = threading.Lock()
        # Number of objects created by the pool so far, whether idle
        # in the queue or currently checked out by callers.
        self._size: int = min_pool_size

        # Pre-fill the pool with min_pool_size objects.
        for _ in range(min_pool_size):
            self._pool.put(object_creator())

    def get(
        self, block: bool = True, timeout: Optional[float] = None
    ) -> Optional[T]:
        """
        Fetch an object from the pool, creating one if none are available
        and the max pool size isn't reached.

        Args:
            block: If True, block until an object is available.
            timeout: Time in seconds to wait for an available object.
        Returns:
            The object if available, otherwise None.
        """
        try:
            # Fast path: take an idle object without blocking.
            return self._pool.get(block=False)
        except Empty:
            pass
        with self._lock:
            # The pool is empty; grow it if we haven't hit the max size.
            if self._size < self.max_pool_size:
                self._size += 1
                return self.object_creator()
        if not block:
            return None
        try:
            # At the max size: wait for another caller to put() an
            # object back. The lock is not held here, so put() can
            # make progress while we wait.
            return self._pool.get(block=True, timeout=timeout)
        except Empty:
            return None

    def put(self, obj: T) -> None:
        """
        Return an object to the pool. Only objects handed out by get()
        should be put back, so the queue can never be full here.

        Args:
            obj: The object to return to the pool.
        """
        if obj is not None:
            with self._lock:
                assert self._pool.qsize() < self._size
                self._pool.put(obj, block=False)

    @property
    def size(self) -> int:
        """
        Get the number of objects currently idle in the pool.
        """
        return self._pool.qsize()
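
A usage sketch of the pool above, under assumed names (make_buffer and the buffer type are hypothetical; the bounds mirror the VINEYARD_CACHE_*_INFLIGHT_TASKS defaults):

from vllm.utils import ObjectPool

def make_buffer() -> bytearray:
    # Hypothetical reusable scratch object for one async update task.
    return bytearray(1 << 20)

pool: ObjectPool = ObjectPool(
    min_pool_size=1,   # mirrors VINEYARD_CACHE_MIN_INFLIGHT_TASKS
    max_pool_size=32,  # mirrors VINEYARD_CACHE_MAX_INFLIGHT_TASKS
    object_creator=make_buffer,
)

buf = pool.get(block=True, timeout=0.1)
if buf is None:
    # No object within the timeout; the caller can drop the task and
    # bump a counter such as err_async_update_task_queue_full.
    pass
else:
    try:
        ...  # use buf for one async cache update
    finally:
        pool.put(buf)  # hand it back for reuse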