KV cache without dependency on chunked prefill #19

Open · wants to merge 1 commit into base: feat/distributed-kv-cache
10 changes: 6 additions & 4 deletions vllm/attention/backends/flash_attn.py
@@ -380,7 +380,8 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):

def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool, prefix_cache_hit: bool):
chunked_prefill_enabled: bool, prefix_cache_hit: bool,
vineyard_llm_cache_enabled: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
@@ -417,8 +418,8 @@ def _add_seq_group(
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table = block_tables[seq_id]
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
elif ((chunked_prefill_enabled or vineyard_llm_cache_enabled
or not is_prompt) and block_tables is not None):
if curr_sliding_window_block == 0:
block_table = block_tables[seq_id]
else:
@@ -453,7 +454,8 @@ def build(self, seq_lens: List[int], query_lens: List[int],
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled,
prefix_cache_hit)
prefix_cache_hit,
self.input_builder.vineyard_llm_cache_enabled)

device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
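The change in flash_attn.py widens the condition under which a block table is attached during prefill: previously it required chunked prefill (or a decode step); now an enabled vineyard LLM cache is sufficient on its own. A minimal sketch of that selection logic, using a hypothetical helper name (select_block_table) and ignoring the sliding-window branch:

from typing import Dict, List, Optional


def select_block_table(
        block_tables: Optional[Dict[int, List[int]]],
        seq_id: int,
        is_prompt: bool,
        prefix_cache_hit: bool,
        chunked_prefill_enabled: bool,
        vineyard_llm_cache_enabled: bool) -> List[int]:
    # Prefix cache hit: flash-attn needs the entries for the incoming prefill tokens.
    if prefix_cache_hit:
        assert block_tables is not None
        return block_tables[seq_id]
    # After this patch the vineyard cache alone is enough to keep the block table
    # for prompt tokens; chunked prefill is no longer a prerequisite.
    if ((chunked_prefill_enabled or vineyard_llm_cache_enabled or not is_prompt)
            and block_tables is not None):
        return block_tables[seq_id]
    return []


# Pure prefill with the vineyard cache enabled but chunked prefill disabled now
# yields a populated block table instead of an empty one.
print(select_block_table({0: [7, 8, 9]}, seq_id=0, is_prompt=True,
                         prefix_cache_hit=False, chunked_prefill_enabled=False,
                         vineyard_llm_cache_enabled=True))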
5 changes: 0 additions & 5 deletions vllm/envs.py
@@ -49,7 +49,6 @@
VLLM_AUDIO_FETCH_TIMEOUT: int = 5
VLLM_TARGET_DEVICE: str = "cuda"
VLLM_USE_VINEYARD_CACHE: Optional[str] = None
VLLM_USE_FLASH_ATTN_DECODING: Optional[str] = None
MAX_JOBS: Optional[str] = None
NVCC_THREADS: Optional[str] = None
VLLM_USE_PRECOMPILED: bool = False
@@ -103,10 +102,6 @@ def get_default_config_root():
# Enable vineyard kv cache for vLLM.
"VLLM_USE_VINEYARD_CACHE":
lambda: os.getenv("VLLM_USE_VINEYARD_CACHE", None),

# Enable vineyard kv cache for vLLM.
"VLLM_USE_FLASH_ATTN_DECODING":
lambda: os.getenv("VLLM_USE_FLASH_ATTN_DECODING", None),

# Maximum number of compilation jobs to run in parallel.
# By default this is the number of CPUs
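With VLLM_USE_FLASH_ATTN_DECODING removed (0 additions, 5 deletions in envs.py), VLLM_USE_VINEYARD_CACHE is the only remaining switch for the vineyard cache. A small self-contained sketch of the envs.py pattern the diff touches, a dict of lazily evaluated os.getenv lambdas, shown here outside of vLLM for illustration:

import os
from typing import Callable, Dict, Optional

environment_variables: Dict[str, Callable[[], Optional[str]]] = {
    # Enable vineyard kv cache for vLLM.
    "VLLM_USE_VINEYARD_CACHE":
    lambda: os.getenv("VLLM_USE_VINEYARD_CACHE", None),
    # The VLLM_USE_FLASH_ATTN_DECODING entry is gone; no second flag is consulted.
}

# vLLM resolves these lazily (envs.VLLM_USE_VINEYARD_CACHE); calling the lambda
# directly gives the same result.
print(bool(environment_variables["VLLM_USE_VINEYARD_CACHE"]()))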
21 changes: 10 additions & 11 deletions vllm/worker/model_runner.py
@@ -459,6 +459,8 @@ def __init__(self,
self.block_aligned_sliding_window = \
self.sliding_window_blocks * self.block_size

self.vineyard_llm_cache_enabled = self.runner.vineyard_llm_cache is not None

def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
seq_group_metadata: SequenceGroupMetadata):
"""Compute context length, sequence length and tokens
@@ -989,30 +991,27 @@ def __init__(
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None

# Delay the initialization of vineyard cache after model loading
# to ensure the tensor model parallel group is initialized.
self.vineyard_llm_cache = None

set_cpu_offload_max_bytes(
int(self.cache_config.cpu_offload_gb * 1024**3))

# Delay the initialization of vineyard cache after model loading
# to ensure the tensor model parallel group is initialized.
self.vineyard_llm_cache = None

def _init_vineyard_cache(self, metrics: CacheServiceMetrics = None):
if envs.VLLM_USE_VINEYARD_CACHE:
if not self.scheduler_config.chunked_prefill_enabled:
raise Exception("Vineyard LLM cache is not enabled, requires chunked prefill")
if not envs.VLLM_USE_FLASH_ATTN_DECODING:
raise Exception("Vineyard LLM cache is not enabled, requires flash attention decoding")

from vllm.attention.backends.flash_attn import FlashAttentionBackend
if not issubclass(self.attn_backend, FlashAttentionBackend):
raise Exception(
f"Vineyard LLM cache does not support {self.attn_backend}, "
"requires flash attention backend")

from vllm.worker.vineyard_llm_cache import VineyardLLMCache
self.vineyard_llm_cache: VineyardLLMCache = VineyardLLMCache.from_envs(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
kv_cache_dtype=self.kv_cache_dtype,
max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens,
torch_dtype=get_kv_cache_torch_dtype(self.kv_cache_dtype,
self.model_config.dtype),
metrics = metrics,
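On the model_runner.py side, the PR title indicates the chunked-prefill requirement in _init_vineyard_cache is being lifted, the VLLM_USE_FLASH_ATTN_DECODING check has to go along with the env var deleted above, and the input builder now records whether the runner actually constructed a cache. A minimal sketch of the remaining gating and flag propagation, using hypothetical stand-ins (Runner, InputBuilder) rather than the real vLLM classes:

from typing import Optional


class FlashAttentionBackend:
    """Stand-in for vllm.attention.backends.flash_attn.FlashAttentionBackend."""


class Runner:
    def __init__(self, use_vineyard_cache_env: Optional[str], attn_backend: type):
        # Delayed vineyard cache initialization, as in the diffed __init__.
        self.vineyard_llm_cache = None
        if use_vineyard_cache_env:
            # Only the backend check remains; the chunked-prefill and
            # VLLM_USE_FLASH_ATTN_DECODING requirements are dropped.
            if not issubclass(attn_backend, FlashAttentionBackend):
                raise Exception(
                    f"Vineyard LLM cache does not support {attn_backend}, "
                    "requires flash attention backend")
            self.vineyard_llm_cache = object()  # stand-in for VineyardLLMCache.from_envs(...)


class InputBuilder:
    def __init__(self, runner: Runner):
        # Mirrors the new line in ModelInputForGPUBuilder.__init__: the flag is
        # derived from whether the runner holds a cache, not from env vars.
        self.vineyard_llm_cache_enabled = runner.vineyard_llm_cache is not None


builder = InputBuilder(Runner("1", FlashAttentionBackend))
print(builder.vineyard_llm_cache_enabled)  # True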