From 22e90d8dea5673f8141b91e7fd2b530b056cac93 Mon Sep 17 00:00:00 2001
From: Haiyang Shi
Date: Fri, 15 Nov 2024 13:45:30 -0800
Subject: [PATCH] Async update for vineyard llm cache (#12)

Vineyard is designed to drop updates whose prefix chunks are not already
present in the cache, which imposes an ordering requirement on updates: we
must perform updates in the issued order for each sequence. For simplicity,
we use a single thread to process all updates sequentially.

Signed-off-by: Haiyang Shi
---
 vllm/entrypoints/llm.py           |  12 ++
 vllm/envs.py                      |  27 +++
 vllm/utils.py                     |  76 ++++++++
 vllm/worker/vineyard_llm_cache.py | 276 ++++++++++++++++++++++++------
 4 files changed, 343 insertions(+), 48 deletions(-)

diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index c9d3f1b4da01d..73f6582b933c0 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -764,6 +764,18 @@ def _run_engine(
                 np.mean(normalized_times_to_first_token), np.std(normalized_times_to_first_token), np.median(normalized_times_to_first_token),
             )
+            logger.info(
+                "Cache service errors: err_async_update_task_queue_full %d",
+                self.llm_engine.cache_service_metrics.err_async_update_task_queue_full,
+            )
+
+            with self.llm_engine.cache_service_metrics.lock:
+                logger.info(
+                    "Cache service async update: time queueing avg %.4f std %.4f median %.4f, time execution avg %.4f std %.4f median %.4f, counter updated avg %.2f std %.2f median %.2f",
+                    np.mean(self.llm_engine.cache_service_metrics.time_async_update_queue), np.std(self.llm_engine.cache_service_metrics.time_async_update_queue), np.median(self.llm_engine.cache_service_metrics.time_async_update_queue),
+                    np.mean(self.llm_engine.cache_service_metrics.time_async_update_exec), np.std(self.llm_engine.cache_service_metrics.time_async_update_exec), np.median(self.llm_engine.cache_service_metrics.time_async_update_exec),
+                    np.mean(self.llm_engine.cache_service_metrics.counter_async_update_updated), np.std(self.llm_engine.cache_service_metrics.counter_async_update_updated), np.median(self.llm_engine.cache_service_metrics.counter_async_update_updated),
+                )
 
         # Restore original behavior
diff --git a/vllm/envs.py b/vllm/envs.py
index 59fe9244a92ea..4aba1fb63c877 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -64,6 +64,11 @@
     VLLM_PLUGINS: Optional[List[str]] = None
     VLLM_TORCH_PROFILER_DIR: Optional[str] = None
     VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
+    VINEYARD_CACHE_CPU_MEM_LIMIT_GB: float = 10
+    VINEYARD_CACHE_ENABLE_ASYNC_UPDATE: bool = False
+    VINEYARD_CACHE_ASYNC_UPDATE_CPU_MEM_UTIL: float = 0.2
+    VINEYARD_CACHE_MIN_INFLIGHT_TASKS: int = 1
+    VINEYARD_CACHE_MAX_INFLIGHT_TASKS: int = 32
 
 
 def get_default_cache_root():
@@ -431,6 +436,28 @@ def get_default_config_root():
     lambda: (os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() in
              ("1", "true")),
+
+    # CPU memory limit for vineyard cache, in GB
+    "VINEYARD_CACHE_CPU_MEM_LIMIT_GB":
+    lambda: float(os.getenv("VINEYARD_CACHE_CPU_MEM_LIMIT_GB", "10")),
+
+    # If set, vineyard will use async update
+    "VINEYARD_CACHE_ENABLE_ASYNC_UPDATE": lambda: (
+        os.environ.get("VINEYARD_CACHE_ENABLE_ASYNC_UPDATE", "0").strip().lower()
+        in ("1", "true")
+    ),
+
+    # CPU memory utilization for async update, default 20%
+    "VINEYARD_CACHE_ASYNC_UPDATE_CPU_MEM_UTIL":
+    lambda: float(os.getenv("VINEYARD_CACHE_ASYNC_UPDATE_CPU_MEM_UTIL", "0.2")),
+
+    # Min number of inflight async tasks for vineyard cache
+    "VINEYARD_CACHE_MIN_INFLIGHT_TASKS":
+    lambda: int(os.getenv("VINEYARD_CACHE_MIN_INFLIGHT_TASKS", "1")),
+
+    # Max number of inflight async tasks for vineyard cache
+    "VINEYARD_CACHE_MAX_INFLIGHT_TASKS":
+    lambda: int(os.getenv("VINEYARD_CACHE_MAX_INFLIGHT_TASKS", "32")),
 }
 
 # end-env-vars-definition
diff --git a/vllm/utils.py b/vllm/utils.py
index a22081ebe8df0..071e66334cbee 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -18,6 +18,7 @@
 from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
                     Hashable, List, Literal, Optional, OrderedDict, Set, Tuple,
                     Type, TypeVar, Union, overload)
+from queue import Queue, Empty
 from uuid import uuid4
 
 import numpy as np
@@ -1249,3 +1250,78 @@ def dec(self, num=1):
     @property
     def value(self):
         return self._value
+
+
+class ObjectPool:
+    def __init__(
+        self,
+        min_pool_size: int,
+        max_pool_size: int,
+        object_creator: Callable[[], T],
+    ):
+        """
+        Initialize ObjectPool
+
+        Args:
+            min_pool_size: The min number of objects to maintain
+                in the pool.
+            max_pool_size: The max number of objects that can be
+                in the pool.
+            object_creator: A function that returns a new object.
+        """
+        self.min_pool_size: int = min_pool_size
+        self.max_pool_size: int = max_pool_size
+        self.object_creator: Callable[[], T] = object_creator
+        self._pool: Queue = Queue(maxsize=max_pool_size)
+        self._lock: threading.Lock = threading.Lock()
+        self._size: int = min_pool_size
+
+        # Pre-fill the pool with min_pool_size objects
+        for _ in range(min_pool_size):
+            self._pool.put(object_creator())
+
+    def get(
+        self, block: bool = True, timeout: Optional[float] = None
+    ) -> Optional[T]:
+        """
+        Fetch an object from the pool, creating one if none are available
+        and the max pool size isn't reached.
+
+        Args:
+            block: If True, block until an object is available.
+            timeout: Time in seconds to wait for an available object.
+
+        Returns:
+            The object if available, otherwise None.
+        """
+        with self._lock:
+            try:
+                return self._pool.get(block=block, timeout=timeout)
+            except Empty:
+                # If the pool is empty but we haven't hit the max size,
+                # create a new object
+                if self._size < self.max_pool_size:
+                    self._size += 1
+                    return self.object_creator()
+                return None
+
+    def put(self, object: T) -> None:
+        """
+        Return an object to the pool. Only objects returned by get()
+        can be put back to the pool. Therefore, the pool should never
+        be full.
+
+        Args:
+            object: The object to return to the pool.
+        """
+        if object is not None:
+            with self._lock:
+                assert self._pool.qsize() < self._size
+                self._pool.put(object, block=False)
+
+    @property
+    def size(self) -> int:
+        """
+        Get the current size of the pool.
+        """
+        return self._pool.qsize()
diff --git a/vllm/worker/vineyard_llm_cache.py b/vllm/worker/vineyard_llm_cache.py
index a3e2efa7371ea..b768e4b6b4dbe 100644
--- a/vllm/worker/vineyard_llm_cache.py
+++ b/vllm/worker/vineyard_llm_cache.py
@@ -1,6 +1,9 @@
 import logging
-import time
 import numpy as np
+import time
+import threading
+from functools import partial
+from queue import Queue, Full
 from typing import Dict, List, NamedTuple, Optional, Set, Tuple
 
 import torch
@@ -12,7 +15,7 @@
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.sequence import (SequenceData, SequenceGroupMetadata)
-from vllm.utils import init_logger
+from vllm.utils import init_logger, ObjectPool
 
 try:
     from vineyard.llm import KVCache as VineyardKVCache
@@ -41,9 +44,15 @@ class CacheServiceMetrics:
     normalized_time_reshape: list = [] # Times used reshaping tensors for flash attention KV format normalized by number of tokens.
     normalized_time_unload: list = [] # Times used move computed KV from device memory normalized by number of tokens.
     normalized_time_update: list = [] # Times used update computed KV to cache service normalized by number of tokens.
-
-
-
+    err_async_update_task_queue_full: int = 0 # Number of Full exceptions when enqueuing async update tasks
+
+    lock: threading.Lock = threading.Lock()
+    # The following metrics need to be protected by `lock`
+    time_async_update_queue: list = [] # Queuing delays of async update tasks
+    time_async_update_exec: list = [] # Execution times of async update tasks
+    counter_async_update_updated: list = [] # Number of updated tokens
+
+
 class VineyardLLMCache:
     def __init__(
         self,
@@ -54,7 +63,10 @@ def __init__(
         layer: int = 2,
         kv_cache_dtype: str = None,
         torch_dtype: torch.dtype = torch.bfloat16,
-        metrics: CacheServiceMetrics = None
+        metrics: CacheServiceMetrics = None,
+        enable_async_update: bool = False,
+        min_inflight_tasks: int = 1,
+        max_inflight_tasks: int = 1,
     ):
         self._init_vineyard_logger()
@@ -80,23 +92,57 @@ def __init__(
             f"max_num_batched_tokens ({self.max_num_batched_tokens}) must" \
             f"be a multiple of chunk_size ({self.chunk_size})"
         )
-        self.buffer = torch.empty(
+        self.fetch_buffer, self.fetch_tensors = self._pinned_tensor_creator()
+        self.cuda_buffer = self.fetch_buffer.cuda()
+        self.enable_async_update = enable_async_update
+
+        if self.enable_async_update:
+            # We use an object pool to reuse the pinned tensors and restrict the number of
+            # inflight tasks. If an update operation cannot get a tensor from the pool,
+            # meaning we already have max_inflight_tasks tasks issued, it then simply skips
+            # the update. A completed task will return the used tensor back to the pool.
+            self.tensor_pool = ObjectPool(
+                min_inflight_tasks,
+                max_inflight_tasks,
+                self._pinned_tensor_creator,
+            )
+
+            # `_update_tasks` is a task queue being accessed by both the main thread
+            # and the background thread.
+            self._update_tasks = Queue(maxsize=max_inflight_tasks)
+            # The cache backend is designed to drop updates whose prefix chunks are
+            # not already present in the cache, which imposes an ordering requirement
+            # on updates: we must perform updates in the issued order. For simplicity,
+            # we use a single thread to process all updates sequentially.
+            self._background_loop = threading.Thread(
+                target=self._run_background_loop, daemon=True
+            )
+            self._background_loop.start()
+
+        self.metrics = metrics
+        logger.info(f"VineyardLLMCache init {metrics}")
+        logger.info(self)
+
+    def _pinned_tensor_creator(
+        self,
+    ) -> Tuple[torch.Tensor, List[List[Tuple[VineyardKVTensor, VineyardKVTensor]]]]:
+        '''Create a pinned buffer and a list of tensors to hold the KV tensors.
+        '''
+        buffer = torch.empty(
             (2, self.layer, self.max_num_batched_tokens, self.num_kv_heads, self.head_size),
-            dtype=torch_dtype, device='cpu',
+            dtype=self.torch_dtype, device='cpu',
         ).pin_memory()
-        self.cuda_buffer = self.buffer.cuda()
-        self.tensors = []
+        tensors = []
         for i in range(self.max_num_batched_tokens):
-            self.tensors.append([])
+            tensors.append([])
             for j in range(self.layer):
-                k_tensor = self.buffer[0, j, i]
-                v_tensor = self.buffer[1, j, i]
-                self.tensors[-1].append((
+                k_tensor = buffer[0, j, i]
+                v_tensor = buffer[1, j, i]
+                tensors[-1].append((
                     VineyardKVTensor(k_tensor.data_ptr(), k_tensor.numel() * k_tensor.element_size()),
                     VineyardKVTensor(v_tensor.data_ptr(), v_tensor.numel() * v_tensor.element_size()),
                 ))
-        self.metrics = metrics
-        logger.info(f"VineyardLLMCache init {metrics}")
+        return buffer, tensors
 
     def _init_vineyard_logger(self):
         import vineyard
@@ -107,6 +153,22 @@ def _init_vineyard_logger(self):
         for handler in logger.handlers:
             vineyard.logger.addHandler(handler)
 
+    def _run_background_loop(self):
+        '''Start a background loop to process the update tasks.
+        '''
+        logger.info("VineyardKVCache background loop is running")
+        while True:
+            # Wait until there is a task in the queue
+            update_fn = self._update_tasks.get()
+            # Run the task
+            try:
+                update_fn()
+                logger.debug(
+                    f"Completed an update op, current task queue size={self._update_tasks.qsize()}"
+                )
+            except Exception:
+                pass
+
     @staticmethod
     def from_envs(
         model_config: ModelConfig,
@@ -124,21 +186,64 @@ def from_envs(
             logger.warn("VineyardLLMCache requires flash attention decoding")
             return None
 
+        cpu_mem_limit = int(envs.VINEYARD_CACHE_CPU_MEM_LIMIT_GB * 1024**3)
         head_size = model_config.get_head_size()
         num_kv_heads = model_config.get_num_kv_heads(parallel_config)
         num_layers = model_config.get_num_layers(parallel_config)
+        token_nbytes = num_layers * num_kv_heads * head_size * 2
+
+        # We will use one temp buffer to hold the kv tensors fetched from the v6d
+        # cache, i.e., `fetch_buffer`
+        num_temp_cpu_buffer = 1
+        kwargs = {}
+
+        # If async update is enabled, we will use a portion of cpu memory as temp
+        # buffers to hold the kv tensors being updated into the v6d cache
+        if envs.VINEYARD_CACHE_ENABLE_ASYNC_UPDATE:
+            # get the mem limit
+            async_update_cpu_mem_util = envs.VINEYARD_CACHE_ASYNC_UPDATE_CPU_MEM_UTIL
+            async_update_cpu_mem_limit = async_update_cpu_mem_util * cpu_mem_limit
+            max_inflight_tasks = int(
+                async_update_cpu_mem_limit // (max_num_batched_tokens * token_nbytes)
+            )
+            max_inflight_tasks = min(max_inflight_tasks, envs.VINEYARD_CACHE_MAX_INFLIGHT_TASKS)
+            num_temp_cpu_buffer += max_inflight_tasks
+            kwargs["enable_async_update"] = True
+            kwargs["min_inflight_tasks"] = min(envs.VINEYARD_CACHE_MIN_INFLIGHT_TASKS, max_inflight_tasks)
+            kwargs["max_inflight_tasks"] = max_inflight_tasks
+            logger.info(f"VineyardLLMCache async update: {kwargs}")
+
+        # convert cache capacity to number of tokens
+        cache_capacity = (
+            cpu_mem_limit
+            - num_temp_cpu_buffer * max_num_batched_tokens * token_nbytes
+        ) // token_nbytes
 
         logger.info(f"VineyardLLMCache from_envs {metrics}")
         return VineyardLLMCache(
             head_size=head_size,
             num_kv_heads=num_kv_heads,
             max_num_batched_tokens=max_num_batched_tokens,
-            cache_capacity=2**20,
+            cache_capacity=cache_capacity,
             layer=num_layers,
             kv_cache_dtype=kv_cache_dtype,
             torch_dtype=torch_dtype,
-            metrics = metrics
+            metrics = metrics,
+            **kwargs,
         )
 
+    def _update_seq_group_metadata(
+        self, seq_group_metadata: SequenceGroupMetadata, value: int
+    ) -> None:
+        '''Update the sequence group's metadata.
+        '''
+        if seq_group_metadata is not None:
+            seq_ids = list(seq_group_metadata.seq_data.keys())
+            assert len(seq_ids) == 1
+            seq_id = seq_ids[0]
+            seq_data = seq_group_metadata.seq_data[seq_id]
+            seq_data.update_num_computed_tokens(value)
+            seq_group_metadata.token_chunk_size -= value
+
     def prefetch_seq_kv_caches(
         self,
         seq_group_metadata: SequenceGroupMetadata,
@@ -192,11 +297,14 @@ def prefetch_seq_kv_caches(
             query_tokens
         ) = query_args
 
+        if query_token_size <= 0:
+            return seq_id, 0
+
         start_time = time.perf_counter()
         matched = self.cache.query(
             prefix=query_prefix,
             tokens=query_tokens,
-            kv_cache_list=self.tensors[:query_token_size],
+            kv_cache_list=self.fetch_tensors[:query_token_size],
         )
         duration = time.perf_counter() - start_time
         self.metrics.time_query.append(duration)
@@ -247,13 +355,13 @@ def prefetch_seq_kv_caches(
         # efficient than performing multiple smaller copy operations. This
         # approach reduces the number of transfers between CPU and GPU,
         # leading to faster overall performance.
-        buffer = self.cuda_buffer.copy_(self.buffer)[:, :, :matched]
+        buffer = self.cuda_buffer.copy_(self.fetch_buffer)[:, :, :matched]
         copy_end.record()
         copy_end.synchronize()
         duration = copy_start.elapsed_time(copy_end) / 1000.0
         self.metrics.time_load.append(duration)
         self.metrics.normalized_time_load.append(0 if matched == 0 else duration/matched)
-
+
         reshape_start = torch.cuda.Event(enable_timing=True)
         reshape_end = torch.cuda.Event(enable_timing=True)
         reshape_start.record()
@@ -275,11 +383,9 @@ def prefetch_seq_kv_caches(
         duration = reshape_start.elapsed_time(reshape_end) / 1000.0
         self.metrics.time_reshape.append(duration)
         self.metrics.normalized_time_reshape.append(0 if matched == 0 else duration/matched)
-
+
         # update the seq_group_metadata's and seq's metadata
-        if seq_group_metadata is not None:
-            seq_data.update_num_computed_tokens(matched)
-            seq_group_metadata.token_chunk_size -= matched
+        self._update_seq_group_metadata(seq_group_metadata, matched)
 
         return seq_id, matched
 
@@ -319,13 +425,63 @@ def prefetch_kv_caches(
             logger.debug(f"prefetch_kv_caches: matched=%r", matched)
         return matched
 
+    def _update_kv_cache(
+        self,
+        prefix: List[int],
+        tokens: List[int],
+        buffer_tensors_tuple: Tuple[
+            torch.Tensor, List[List[Tuple[VineyardKVTensor, VineyardKVTensor]]]
+        ],
+        scheduled_time: float,
+    ) -> None:
+        '''Update the KV cache.
+
+        Args:
+            prefix: Prefix tokens.
+            tokens: Tokens to be cached.
+            buffer_tensors_tuple: Within the tuple, the first element is a contiguous,
+                                  pinned buffer, and the second element is a logical view
+                                  of the buffer that is used to let v6d know about the
+                                  actual layout of the KV tensors.
+                                  If async update is enabled, the `buffer_tensors_tuple`
+                                  is allocated from the object pool, and thus we need to
+                                  return it back to the pool after completing the update
+                                  operation.
+            scheduled_time: The timestamp at which the task was scheduled.
+ ''' + try: + start_time = time.perf_counter() + queue_duration = start_time - scheduled_time + update_token_size = len(tokens) + kv_cache_list = buffer_tensors_tuple[1][:update_token_size] + updated = self.cache.update(prefix, tokens, kv_cache_list) + exec_duration = time.perf_counter() - start_time + if self.enable_async_update: + logger.debug( + f"update kv cache: #prefix={len(prefix)}, #tokens={len(tokens)}, updated={updated}, " + f"queue_duration={queue_duration:.4f}, exec_duration={exec_duration:.4f}" + ) + with self.metrics.lock: + self.metrics.time_async_update_queue.append(queue_duration) + self.metrics.time_async_update_exec.append(exec_duration) + self.metrics.counter_async_update_updated.append(updated) + else: + logger.debug( + f"update kv cache: #prefix={len(prefix)}, #tokens={len(tokens)}, updated={updated}" + ) + except Exception: + pass + finally: + if self.enable_async_update: + self.tensor_pool.put(buffer_tensors_tuple) + def update_seq_kv_caches( self, matched: Dict[str, int], seq_group_metadata: SequenceGroupMetadata, kv_caches: List[torch.Tensor], block_size: int, - ) -> Tuple[str, int]: + ) -> None: if seq_group_metadata is not None: seq_ids = list(seq_group_metadata.seq_data.keys()) assert len(seq_ids) == 1 @@ -364,10 +520,25 @@ def update_seq_kv_caches( ) = update_args if update_token_size <= 0: # restore the seq_group_metadata's and seq's metadata - if seq_group_metadata is not None: - seq_data.update_num_computed_tokens(-matched[seq_id]) - seq_group_metadata.token_chunk_size += matched[seq_id] - return seq_id, 0 + self._update_seq_group_metadata(seq_group_metadata, -matched[seq_id]) + return + + if self.enable_async_update: + buffer_tensors_tuple = self.tensor_pool.get(block=False) + # buffer_tensors_tuple is None means that we have max number of async + # updates in the flight. Right now, we just skip updates if we have no + # buffer. 
+            if buffer_tensors_tuple is None:
+                # restore the seq_group_metadata's and seq's metadata
+                self._update_seq_group_metadata(seq_group_metadata, -matched[seq_id])
+                return
+        else:
+            # If async update is disabled, it's safe to reuse the same buffer and
+            # tensors used in the fetch operation
+            buffer_tensors_tuple = self.fetch_buffer, self.fetch_tensors
+
+        update_buffer, _ = buffer_tensors_tuple
+
         if seq_group_metadata is not None:
             block_table = seq_group_metadata.block_tables[seq_id]
             slot_mapping = []
@@ -392,30 +563,44 @@ def update_seq_kv_caches(
         for j in range(self.layer):
             self.cuda_buffer[:, j, :update_token_size].copy_(
                 kv_caches[j][:, slot_mapping // block_size, slot_mapping % block_size])
-        self.buffer.copy_(self.cuda_buffer)
+        update_buffer.copy_(self.cuda_buffer)
         end_unload.record()
         end_unload.synchronize()
         duration = start_unload.elapsed_time(end_unload) / 1000.0
         self.metrics.time_unload.append(duration)
         self.metrics.normalized_time_unload.append(0 if update_token_size == 0 else duration/update_token_size)
-
+
         start_time = time.perf_counter()
-        # updates into vineyard
-        updated = self.cache.update(
+
+        update_task = partial(self._update_kv_cache,
             prefix=update_prefix,
             tokens=update_tokens,
-            kv_cache_list=self.tensors[:update_token_size],
+            buffer_tensors_tuple=buffer_tensors_tuple,
+            scheduled_time=start_time,
         )
+        if self.enable_async_update:
+            # async update
+            try:
+                logger.debug(
+                    f"submit update task: #prefix={len(update_prefix)}, #tokens={len(update_tokens)}"
+                )
+                self._update_tasks.put_nowait(update_task)
+                logger.debug(
+                    f"task queue size={self._update_tasks.qsize()}, tensor pool size={self.tensor_pool.size}"
+                )
+            except Full:
+                logger.warning(f"update_seq_kv_caches: queue is full, skip this update")
+                self.metrics.err_async_update_task_queue_full += 1
+                self.tensor_pool.put(buffer_tensors_tuple)
+        else:
+            update_task()
+
         duration = time.perf_counter() - start_time
         self.metrics.time_update.append(duration)
         self.metrics.normalized_time_update.append(0 if update_token_size == 0 else duration/update_token_size)
 
         # restore the seq_group_metadata's and seq's metadata
-        if seq_group_metadata is not None:
-            seq_data.update_num_computed_tokens(-matched[seq_id])
-            seq_group_metadata.token_chunk_size += matched[seq_id]
-
-        return seq_id, updated
+        self._update_seq_group_metadata(seq_group_metadata, -matched[seq_id])
 
     def update_kv_caches(
         self,
@@ -423,9 +608,9 @@ def update_kv_caches(
         seq_group_metadata_list: List[SequenceGroupMetadata],
         kv_caches: List[torch.Tensor],
         block_size: int,
-    ) -> Dict[str, int]:
+    ) -> None:
         if block_size is None or kv_caches[0] is None: # profile run
-            return {}
+            return
 
         if seq_group_metadata_list is not None:
             prefill_requests = []
@@ -442,15 +627,10 @@ def update_kv_caches(
             prefill_requests = [None] * num_prefill_requests[0]
             num_prefill_requests = num_prefill_requests[0]
 
-        updated = {}
         for seq_group_meta in prefill_requests:
-            seq_id, seq_updated = self.update_seq_kv_caches(
+            self.update_seq_kv_caches(
                 matched, seq_group_meta, kv_caches, block_size,
             )
-            updated[seq_id] = seq_updated
-        if updated:
-            logger.debug(f"update_kv_caches: updated=%r", updated)
-        return updated
 
     def __repr__(self):
         return (