Commit

clean

rickyyx committed Nov 7, 2024
1 parent 97b8475 commit 666e0fc
Showing 8 changed files with 420 additions and 167 deletions.
9 changes: 9 additions & 0 deletions vllm/core/block/cpu_gpu_block_allocator.py
@@ -342,6 +342,15 @@ def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
self._swap_mapping.clear()
return list(mapping.items())

def find_cached_blocks_prefix(
self,
block_hashes: List[int],
is_allocated: bool,
device: Device = Device.GPU,
) -> List[int]:
return self._allocators[device].find_cached_blocks_prefix(
block_hashes, is_allocated)


class NullBlock(Block):
"""
17 changes: 17 additions & 0 deletions vllm/core/block/interfaces.py
@@ -192,6 +192,14 @@ def get_prefix_cache_hit_rate(self) -> float:
class NoFreeBlocksError(ValueError):
pass

@abstractmethod
def find_cached_blocks_prefix(
self,
block_hashes: List[int],
is_allocated: bool,
) -> List[int]:
pass


class DeviceAwareBlockAllocator(ABC):

@@ -284,3 +292,12 @@ def allocate_or_get_null_block(self) -> Block:
def get_prefix_cache_hit_rate(self, device: Device) -> float:
"""Prefix cache hit rate. -1 means not supported or disabled."""
pass

@abstractmethod
def find_cached_blocks_prefix(
self,
block_hashes: List[int],
is_allocated: bool,
device: Device = Device.GPU,
) -> List[int]:
pass
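
For an implementer, the contract of the new abstract method is: return the
longest prefix of block_hashes whose blocks are cached and, when is_allocated
is True, also currently allocated. A toy in-memory sketch of a conforming
implementation (hypothetical class, not part of this commit):

from typing import List, Set

class ToyAllocator:
    """Illustrative only: cached/allocated block hashes tracked in sets."""

    def __init__(self) -> None:
        self._cached: Set[int] = set()
        self._allocated: Set[int] = set()

    def find_cached_blocks_prefix(self, block_hashes: List[int],
                                  is_allocated: bool) -> List[int]:
        prefix: List[int] = []
        for h in block_hashes:
            if h not in self._cached:
                break
            if is_allocated and h not in self._allocated:
                break
            prefix.append(h)
        return prefix
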
5 changes: 5 additions & 0 deletions vllm/core/block/naive_block.py
@@ -329,6 +329,11 @@ def swap_in(self, blocks: List[Block]) -> None:
def get_prefix_cache_hit_rate(self) -> float:
return -1

def find_cached_blocks_prefix(self, block_hashes: List[int],
is_allocated: bool) -> List[int]:
# Not applicable for naive block allocator.
return []


class NaiveBlock(Block):
"""An implementation of the Block class that does not support prefix
210 changes: 134 additions & 76 deletions vllm/core/block/prefix_caching_block.py
@@ -4,10 +4,12 @@

from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, Device,
DeviceAwareBlockAllocator)
from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
NaiveBlockAllocator)
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
from vllm.sequence import Sequence

PrefixHash = int

@@ -634,6 +636,45 @@ def swap_in(self, blocks: List[Block]) -> None:

block.block_id = block_id # Assign block_id

def find_cached_blocks_prefix(self, block_hashes: List[int],
is_allocated: bool) -> List[int]:
"""
Return the prefix of the block hashes that are already computed and
cached.
When `is_allocated` is True, only return the blocks that are allocated.
"""

def _block_is_cached(block_hash: PrefixHash) -> bool:
if block_hash not in self._cached_blocks:
return False

cached_block_id = self._cached_blocks[block_hash]
if is_allocated and cached_block_id in self.evictor:
# When the caller requires the block to be allocated, a cached block
# that sits in the evictor does not count, since it is not allocated.
return False

# We only consider the blocks that are marked as computed.
return self.block_is_computed(cached_block_id)

def _bisect_left(a, x, key) -> int:
import sys
from bisect import bisect_left

# bisect_left only supports the key argument on Python >= 3.10,
# so on older versions map the key over the list manually.
if sys.version_info < (3, 10):
a = [key(e) for e in a]
return bisect_left(a, x)
else:
return bisect_left(a, x, key=key)

# Look for the first block that's not cached, and return the prefix,
# i.e. the blocks that are cached.
idx = _bisect_left(block_hashes,
True,
key=lambda x: not _block_is_cached(x))
return block_hashes[:idx]
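
The binary search works because cached blocks always form a contiguous prefix
of block_hashes, so the key not _block_is_cached(h) is monotonically
non-decreasing over the list. A minimal standalone sketch of the same idea
(requires Python >= 3.10 for the key argument; is_cached stands in for
_block_is_cached):

from bisect import bisect_left
from typing import Callable, List

def cached_prefix(block_hashes: List[int],
                  is_cached: Callable[[int], bool]) -> List[int]:
    # The key maps the list to [False, ..., False, True, ..., True], so
    # bisect_left finds the first non-cached block in O(log N) predicate
    # evaluations.
    idx = bisect_left(block_hashes, True, key=lambda h: not is_cached(h))
    return block_hashes[:idx]

cache = {10, 11}  # hashes 10 and 11 are cached, 12 and 13 are not
assert cached_prefix([10, 11, 12, 13], cache.__contains__) == [10, 11]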


class PrefixCachingBlock(Block):
"""A block implementation that supports prefix caching.
@@ -843,86 +884,103 @@ def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int],


class ComputedBlocksTracker:
"""Handles caching of per-sequence computed block ids.
When a sequence appears for the first time, it traverses all of the
blocks and detects the prefix of blocks that is computed. On the
subsequent times, it only traverses the new blocks that were added
and updates the already recorded prefix of blocks with the newly
computed blocks.
To avoid redundant traversals, the algorithm also detects when there
is a "gap" in the computed prefix. For example, if we have blocks =
[1,2,3,4,5], and we have detected [1,2,3] as the computed prefix, then
we won't try to add more computed blocks to [1,2,3] in this sequence
iteration, and will add more computed blocks only after the sequence is
freed and reused again.
Note that currently, for a given sequence, we also skip the last
block id for caching purposes, to avoid caching of a full sequence
"""
""" """

def __init__(self, allocator):
def __init__(
self,
allocator: DeviceAwareBlockAllocator,
block_size: int,
enable_caching: bool,
):
self._allocator = allocator
self._cached_computed_seq_blocks: Dict[int, Tuple[List[int],
bool]] = {}

def add_seq(self, seq_id: int) -> None:
"""Start tracking seq_id
"""
assert seq_id not in self._cached_computed_seq_blocks
self._cached_computed_seq_blocks[seq_id] = ([], False)
self._block_size = block_size
self._enable_caching = enable_caching

# A map from seq_id to the list of block hashes for the
# sequence. This is so that we don't have to recompute the block hashes
# for the sequence when we need to check if the sequence is cached.
self._full_blocks_hashes: Dict[int, List[int]] = {}

# A map from (seq_id, is_allocated) to the number of tokens
# that are cached for the sequence.
self._num_tokens_computed: Dict[Tuple[int, bool], int] = {}

def _update_seq_hashes(self, seq: Sequence) -> None:
assert self._enable_caching

block_hashes_recorded = self._full_blocks_hashes.get(seq.seq_id, [])
cur_num_blocks_recorded = len(block_hashes_recorded)
token_ids = seq.get_token_ids()
assert len(token_ids) >= cur_num_blocks_recorded * self._block_size, (
f"The sequence has {len(token_ids)} tokens, but"
f" already recorded {cur_num_blocks_recorded} blocks. "
"This should not happen since we assume blocks are "
"only appended other than recomputation. When the sequence is "
"recomputed, we should have removed the info of the old blocks.")
# Update the computed block hashes for the sequence
num_total_blocks = len(token_ids) // self._block_size

# We need the hash of the previous block to compute the hash of the
# current block, so that a block is uniquely identified by the entire
# prefix of tokens that precedes it, across sequences.
prev_block_hash = (None if cur_num_blocks_recorded == 0 else
block_hashes_recorded[-1])
# Only update the computed block hashes for the new blocks
for i in range(cur_num_blocks_recorded, num_total_blocks):
assert len(token_ids) >= (i + 1) * self._block_size
block_token_ids = token_ids[i * self._block_size:(i + 1) *
self._block_size]
block_hash = PrefixCachingBlock.hash_block_tokens(
is_first_block=prev_block_hash is None,
prev_block_hash=prev_block_hash,
cur_block_token_ids=block_token_ids,
)
block_hashes_recorded.append(block_hash)
prev_block_hash = block_hash

# Update the cached sequence hashes
self._full_blocks_hashes[seq.seq_id] = block_hashes_recorded
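
The chaining above makes a block's hash depend on every token before it, so
equal hashes imply equal token prefixes across sequences. A rough standalone
sketch of the incremental update, using Python's built-in hash in place of
PrefixCachingBlock.hash_block_tokens (names are illustrative, not the vLLM
API):

from typing import List, Optional

BLOCK_SIZE = 4

def hash_block(prev_hash: Optional[int], tokens: List[int]) -> int:
    # Chain the previous block's hash into the current one so the result
    # identifies the whole token prefix, not just this block's tokens.
    return hash((prev_hash, tuple(tokens)))

def update_hashes(token_ids: List[int], recorded: List[int]) -> List[int]:
    num_full_blocks = len(token_ids) // BLOCK_SIZE
    prev = recorded[-1] if recorded else None
    for i in range(len(recorded), num_full_blocks):
        block = token_ids[i * BLOCK_SIZE:(i + 1) * BLOCK_SIZE]
        prev = hash_block(prev, block)
        recorded.append(prev)
    return recorded

hashes = update_hashes(list(range(9)), [])       # 9 tokens -> 2 full blocks
assert len(hashes) == 2
hashes = update_hashes(list(range(13)), hashes)  # only the third block is new
assert len(hashes) == 3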

def get_num_cached_tokens(self, seq: Sequence, is_allocated: bool) -> int:
if not self._enable_caching:
return 0

# We always try to update the sequence hashes for the sequence.
# This ensures that we don't miss any cached tokens for the
# sequence during decode. Note that this only computes hashes
# for newly completed blocks.
self._update_seq_hashes(seq)

num_computed_tokens_prev = self._num_tokens_computed.get(
(seq.seq_id, is_allocated), None)

if num_computed_tokens_prev is not None and seq.is_prefill():
# For a sequence that is still in prefill, we don't have to
# recompute the number of cached tokens.
# This also correctly handles chunked prefill: since we currently
# mark blocks as computed even if the sequence is only partially
# prefilled, a continuously prefilled sequence should not see its
# cached token count change while running.
return num_computed_tokens_prev

block_hashes = self._full_blocks_hashes[seq.seq_id]

# This is O(logN), where N is the number of blocks.
num_cached_blocks = len(
self._allocator.find_cached_blocks_prefix(block_hashes,
is_allocated))
num_cached_tokens = num_cached_blocks * self._block_size
self._num_tokens_computed[(seq.seq_id,
is_allocated)] = num_cached_tokens
return num_cached_tokens
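
The memo key is (seq_id, is_allocated), and while a sequence is still
prefilling the memoized count is returned as-is, so repeated scheduler queries
during chunked prefill are O(1) after the first. A small sketch of that
short-circuit (illustrative names, not the tracker's real signature):

memo = {}

def num_cached_tokens(seq_id, is_prefill, is_allocated, recompute):
    prev = memo.get((seq_id, is_allocated))
    if prev is not None and is_prefill:
        return prev          # chunked prefill: reuse the memoized count
    tokens = recompute()     # the real tracker does an O(log N) search
    memo[(seq_id, is_allocated)] = tokens
    return tokens

assert num_cached_tokens(0, True, True, lambda: 16) == 16
# While still prefilling, recompute is not invoked again:
assert num_cached_tokens(0, True, True, lambda: 1 / 0) == 16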

def remove_seq(self, seq_id: int) -> None:
"""Stop tracking seq_id
"""
assert seq_id in self._cached_computed_seq_blocks
del self._cached_computed_seq_blocks[seq_id]

def get_cached_computed_blocks_and_update(
self, seq_id: int, block_ids: List[int]) -> List[int]:
""" Look at the class documentation for details
"""
# Ensure seq_id is already tracked
assert seq_id in self._cached_computed_seq_blocks

# Get cached data (may be empty on the first time)
prev_computed_block_ids, has_gap = self._cached_computed_seq_blocks[
seq_id]

if has_gap:
# When a gap is detected, we do not add more computed blocks in this
# sequence iteration.
return prev_computed_block_ids

# We do not consider the last block id for caching purposes.
num_cur_blocks = len(block_ids) - 1
assert num_cur_blocks >= 0

if len(prev_computed_block_ids) >= num_cur_blocks:
# Cache HIT
assert len(prev_computed_block_ids) == num_cur_blocks
return prev_computed_block_ids

# If we get here, we may be able to add more computed blocks. As a
# result, traverse the additional blocks after prev_computed_block_ids
# to detect more computed blocks and add them.

# Incremental init for seq_id => Look only at the new blocks
computed_block_ids = self._allocator.get_computed_block_ids( # noqa: E501
prev_computed_block_ids,
block_ids,
skip_last_block_id=True,  # skip the last block id to avoid caching a full seq
)

# Detect if there is a "gap"
has_gap = len(computed_block_ids) < num_cur_blocks

# Record
self._cached_computed_seq_blocks[seq_id] = (computed_block_ids,
has_gap)
"""Stop tracking seq_id"""
if not self._enable_caching:
return

return computed_block_ids
assert seq_id in self._full_blocks_hashes
del self._full_blocks_hashes[seq_id]
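
Putting the tracker together: hashes are extended lazily on each
get_num_cached_tokens call and dropped in remove_seq once the sequence is
freed. A hedged end-to-end sketch (FakeSeq and FakeAllocator are stand-ins for
vllm.sequence.Sequence and a real device-aware allocator):

class FakeSeq:

    def __init__(self, seq_id, token_ids):
        self.seq_id = seq_id
        self._token_ids = token_ids

    def get_token_ids(self):
        return self._token_ids

    def is_prefill(self):
        return False

class FakeAllocator:

    def find_cached_blocks_prefix(self, block_hashes, is_allocated):
        return block_hashes[:1]  # pretend only the first block is cached

tracker = ComputedBlocksTracker(FakeAllocator(), block_size=16,
                                enable_caching=True)
seq = FakeSeq(seq_id=0, token_ids=list(range(40)))  # two full blocks
assert tracker.get_num_cached_tokens(seq, is_allocated=True) == 16
tracker.remove_seq(0)  # stop tracking once the sequence is freed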


class LastAccessBlocksTracker:
22 changes: 13 additions & 9 deletions vllm/core/block_manager.py
@@ -101,7 +101,7 @@ def __init__(
self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {}

self._computed_blocks_tracker = ComputedBlocksTracker(
self.block_allocator)
self.block_allocator, self.block_size, self.enable_caching)
self._last_access_blocks_tracker = LastAccessBlocksTracker(
self.block_allocator)

@@ -170,15 +170,13 @@ def allocate(self, seq_group: SequenceGroup) -> None:
self.block_tables[seq.seq_id] = block_table

# Track seq
self._computed_blocks_tracker.add_seq(seq.seq_id)
self._last_access_blocks_tracker.add_seq(seq.seq_id)

# Assign the block table for each sequence.
for seq in waiting_seqs[1:]:
self.block_tables[seq.seq_id] = block_table.fork()

# Track seq
self._computed_blocks_tracker.add_seq(seq.seq_id)
self._last_access_blocks_tracker.add_seq(seq.seq_id)

# Allocate cross-attention block table for encoder sequence
@@ -314,11 +312,14 @@ def get_common_computed_block_ids(
"""
computed_seq_block_ids = []
for seq in seqs:
computed_seq_block_ids.append(
self._computed_blocks_tracker.
get_cached_computed_blocks_and_update(
seq.seq_id,
self.block_tables[seq.seq_id].physical_block_ids))
all_blocks = self.block_tables[seq.seq_id].physical_block_ids
num_cached_tokens = (
self._computed_blocks_tracker.get_num_cached_tokens(
seq, is_allocated=True))
assert num_cached_tokens % self.block_size == 0
num_cached_blocks = num_cached_tokens // self.block_size
computed_block_ids = all_blocks[:num_cached_blocks]
computed_seq_block_ids.append(computed_block_ids)

# NOTE(sang): This assumes seq_block_ids doesn't contain any None.
return self.block_allocator.get_common_computed_block_ids(
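
The conversion above relies on the tracker reporting cached tokens in
whole-block multiples (hence the modulo assert). A quick worked check of the
slicing, with made-up values:

block_size = 16
all_blocks = [7, 3, 9, 4]    # physical block ids for the sequence
num_cached_tokens = 32       # as reported by the tracker

assert num_cached_tokens % block_size == 0
num_cached_blocks = num_cached_tokens // block_size  # -> 2
computed_block_ids = all_blocks[:num_cached_blocks]  # -> [7, 3]
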
@@ -332,7 +333,6 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
self.block_tables[child_seq.seq_id] = src_block_table.fork()

# Track child seq
self._computed_blocks_tracker.add_seq(child_seq.seq_id)
self._last_access_blocks_tracker.add_seq(child_seq.seq_id)

def can_swap_in(self, seq_group: SequenceGroup,
@@ -503,3 +503,7 @@ def _can_swap(self,
return AllocStatus.OK
else:
return AllocStatus.LATER

def get_num_cached_tokens(self, seq: Sequence, is_allocated: bool) -> int:
return self._computed_blocks_tracker.get_num_cached_tokens(
seq, is_allocated)
4 changes: 4 additions & 0 deletions vllm/core/interfaces.py
@@ -121,3 +121,7 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup,
def get_prefix_cache_hit_rate(self, device: Device) -> float:
"""Prefix cache hit rate. -1 means not supported or disabled."""
pass

@abstractmethod
def get_num_cached_tokens(self, seq: Sequence, is_allocated: bool) -> int:
pass