[Core] Adapting vLLM to the external KV store Vineyard (#1)
* Enable Vineyard LLM KV cache in vLLM

Based on another version of vLLM: sighingnow@d347dab

Cherry-picked from commit d347dab

Signed-off-by: Tao He <[email protected]>
(cherry picked from commit 1545f6bf7edcd667e305d3fbcadd913066f04747)

Resolve diffs from vLLM updates

Temporarily comment out torch.distributed for single-node environments

Add VineyardCacheConfig with https://github.com/v6d-io/v6d/blob/ebe8f077e3d3780a27d49238c501854b6b8e29df/modules/llm-cache/ds/kv_cache_block.cc#L163 commented out; fix cache_ops

Remove CacheConfig from the arguments (configure through environment variables instead)

v6d: fix integration with v1 APIs

Signed-off-by: Haiyang Shi <[email protected]>

Update model_runner to the latest version

Cherry-pick model_runner from sighingnow@d347dab

Fix reshape_and_cache_flash argument

Add cache prefetch/update to work_base

Clean up

Fix after rebasing onto 029c71d

Remove tensor copy from cache-managed addresses to pinned memory

Clean up

* Add fixes to address comments

---------

Co-authored-by: Tao He <[email protected]>
happyandslow and sighingnow authored Oct 1, 2024
1 parent 3fd2b0d commit 7cdac48
Showing 5 changed files with 477 additions and 13 deletions.
6 changes: 5 additions & 1 deletion benchmarks/benchmark_prefix_caching.py
@@ -135,7 +135,8 @@ def main(args):
enforce_eager=True,
use_v2_block_manager=args.use_v2_block_manager,
tensor_parallel_size=args.tensor_parallel_size,
enable_prefix_caching=args.enable_prefix_caching)
enable_prefix_caching=args.enable_prefix_caching,
enable_chunked_prefill=args.enable_chunked_prefill)

sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)

@@ -175,6 +176,9 @@ def main(args):
parser.add_argument('--enable-prefix-caching',
action='store_true',
help='enable prefix caching')
parser.add_argument('--enable-chunked-prefill',
action='store_true',
help='enable chunked prefill')
parser.add_argument('--use-v2-block-manager',
action='store_true',
help='Use BlockSpaceMangerV2')
10 changes: 10 additions & 0 deletions vllm/envs.py
@@ -48,6 +48,8 @@
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_AUDIO_FETCH_TIMEOUT: int = 5
VLLM_TARGET_DEVICE: str = "cuda"
VLLM_USE_VINEYARD_CACHE: Optional[str] = None
VLLM_USE_FLASH_ATTN_DECODING: Optional[str] = None
MAX_JOBS: Optional[str] = None
NVCC_THREADS: Optional[str] = None
VLLM_USE_PRECOMPILED: bool = False
@@ -92,6 +94,14 @@ def get_default_config_root():
"VLLM_TARGET_DEVICE":
lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda"),

# Enable vineyard kv cache for vLLM.
"VLLM_USE_VINEYARD_CACHE":
lambda: os.getenv("VLLM_USE_VINEYARD_CACHE", None),

# Enable flash attention for decoding in vLLM.
"VLLM_USE_FLASH_ATTN_DECODING":
lambda: os.getenv("VLLM_USE_FLASH_ATTN_DECODING", None),

# Maximum number of compilation jobs to run in parallel.
# By default this is the number of CPUs
"MAX_JOBS":
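For reference, both variables are read lazily through os.getenv, so the Vineyard path stays disabled unless they are exported in the serving environment. Below is a minimal usage sketch, not part of this commit; the model name is a placeholder, and the engine flags correspond to the requirements enforced in model_runner.py further down (chunked prefill) plus the prefix caching exercised by the benchmark above.

# Usage sketch only (not part of this commit). Any non-empty value enables
# each environment flag; the model name is a placeholder.
import os
os.environ["VLLM_USE_VINEYARD_CACHE"] = "1"        # turn on the Vineyard KV cache path
os.environ["VLLM_USE_FLASH_ATTN_DECODING"] = "1"   # required by _init_vineyard_cache below

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m",               # placeholder model
          enable_prefix_caching=True,              # as exercised by the benchmark above
          enable_chunked_prefill=True)             # required by _init_vineyard_cache below
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0, max_tokens=16))

The modified benchmark exercises the same path when run with --enable-prefix-caching --enable-chunked-prefill.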
49 changes: 45 additions & 4 deletions vllm/worker/model_runner.py
@@ -25,6 +25,7 @@
from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture
from vllm.inputs import INPUT_REGISTRY, InputRegistry
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
@@ -47,7 +48,7 @@
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d,
flatten_2d_lists, is_hip, is_pin_memory_available,
supports_dynamo)
get_kv_cache_torch_dtype, supports_dynamo)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
@@ -985,8 +986,35 @@ def __init__(
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None

# Delay the initialization of vineyard cache after model loading
# to ensure the tensor model parallel group is initialized.
self.vineyard_llm_cache = None

set_cpu_offload_max_bytes(
int(self.cache_config.cpu_offload_gb * 1024**3))

# Delay the initialization of vineyard cache after model loading
# to ensure the tensor model parallel group is initialized.
self.vineyard_llm_cache = None

def _init_vineyard_cache(self):
if envs.VLLM_USE_VINEYARD_CACHE:
if not self.scheduler_config.chunked_prefill_enabled:
raise Exception("Vineyard LLM cache is not enabled, requires chunked prefill")
if not envs.VLLM_USE_FLASH_ATTN_DECODING:
raise Exception("Vineyard LLM cache is not enabled, requires flash attention decoding")

from vllm.worker.vineyard_llm_cache import VineyardLLMCache
self.vineyard_llm_cache: VineyardLLMCache = VineyardLLMCache.from_envs(
model_config=self.model_config,
parallel_config=self.parallel_config,
kv_cache_dtype=self.kv_cache_dtype,
torch_dtype=get_kv_cache_torch_dtype(self.kv_cache_dtype,
self.model_config.dtype),
)
logger.info("Using Vineyard LLM cache")
else:
logger.info("Vineyard LLM cache is not enabled")

# Used to cache python objects
self.inter_data_cache: Dict[int, PyObjectCache] = {}
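The VineyardLLMCache implementation itself (imported above from vllm.worker.vineyard_llm_cache) is one of the changed files not rendered on this page. The stub below is only a sketch of its interface as inferred from the call sites in this file; parameter names, return types, and the described behavior are assumptions, not the actual implementation.

# Interface sketch inferred from the call sites in model_runner.py; the real
# class lives in vllm/worker/vineyard_llm_cache.py and may differ.
from typing import Any, List, Optional

import torch

class VineyardLLMCacheSketch:
    @classmethod
    def from_envs(cls, model_config, parallel_config, kv_cache_dtype,
                  torch_dtype) -> "VineyardLLMCacheSketch":
        # Build the cache from environment-driven settings once the tensor
        # model parallel group exists (hence the delayed initialization above).
        ...

    def prefetch_kv_caches(self, seq_group_metadata_list,
                           kv_caches: List[torch.Tensor],
                           block_size: Optional[int]) -> Any:
        # Presumably copies matching KV blocks from the external Vineyard store
        # into the GPU caches and returns "cache hints" describing the matches.
        ...

    def update_kv_caches(self, cache_hints: Any, seqs,
                         kv_caches: List[torch.Tensor],
                         block_size: Optional[int]) -> None:
        # Presumably writes newly computed KV blocks back to the Vineyard store.
        ...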
@@ -1068,6 +1096,7 @@ def load_model(self) -> None:
self.model,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend="eager")
self._init_vineyard_cache()

def save_sharded_state(
self,
@@ -1083,6 +1112,9 @@ def save_sharded_state(
max_size=max_size,
)

def set_block_size(self, block_size: int) -> None:
self.block_size = block_size

def save_tensorized_model(
self,
tensorizer_config: TensorizerConfig,
@@ -1205,15 +1237,18 @@ def profile_run(self) -> None:
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
model_input, cache_hints = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids, kv_caches=kv_caches)
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=batch_size,
dtype=self.model_config.dtype,
device=self.device)
self.execute_model(model_input, kv_caches, intermediate_tensors)
if self.vineyard_llm_cache and kv_caches[0] is not None:
self.vineyard_llm_cache.update_kv_caches(
cache_hints, seqs, kv_caches, getattr(self, 'block_size', None))
torch.cuda.synchronize()
return

@@ -1456,6 +1491,7 @@ def prepare_model_input(
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None,
kv_caches: List[torch.Tensor] = [],
) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
@@ -1470,6 +1506,10 @@
If cuda graph is required, this API automatically pads inputs.
"""
cache_hints = None
if self.vineyard_llm_cache and len(kv_caches) > 0 and kv_caches[0] is not None:
cache_hints = self.vineyard_llm_cache.prefetch_kv_caches(
seq_group_metadata_list, kv_caches, getattr(self, 'block_size', None))
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
if get_pp_group().is_last_rank:
@@ -1486,7 +1526,8 @@
return dataclasses.replace(model_input,
sampling_metadata=sampling_metadata,
is_prompt=is_prompt,
virtual_engine=virtual_engine)
virtual_engine=virtual_engine), cache_hints


@torch.inference_mode()
def execute_model(
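Note that prepare_model_input now returns a (model_input, cache_hints) pair, so call sites outside this file (for example the worker-side changes mentioned in the commit message, which are not rendered here) have to unpack both values and close the prefetch/execute/update cycle. The helper below is only a caller-side sketch mirroring the pattern used in profile_run above; it is not taken from this diff, and the names are illustrative.

# Caller-side sketch (not from this diff): drive one prefetch -> execute ->
# update cycle through the model runner.
from typing import List, Optional

import torch

def run_step_with_vineyard(model_runner, seq_group_metadata_list,
                           finished_requests_ids: Optional[List[str]],
                           kv_caches: List[torch.Tensor],
                           intermediate_tensors=None):
    # prepare_model_input now also prefetches from Vineyard and returns hints.
    model_input, cache_hints = model_runner.prepare_model_input(
        seq_group_metadata_list,
        finished_requests_ids=finished_requests_ids,
        kv_caches=kv_caches)
    outputs = model_runner.execute_model(model_input, kv_caches,
                                         intermediate_tensors)
    # After execution, push the newly computed KV blocks back to the store.
    if model_runner.vineyard_llm_cache and kv_caches and kv_caches[0] is not None:
        model_runner.vineyard_llm_cache.update_kv_caches(
            cache_hints, seq_group_metadata_list, kv_caches,
            getattr(model_runner, 'block_size', None))
    return outputs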
(Diffs for the remaining changed files are not rendered here.)
