[llama] Update kv cache to have read/write functions (#280)
Aligned the interfaces of the two caches so that model code can drive either one through the same utility functions (allocate / read / write / write_timestep) instead of branching on the cache type. Some roughness remains in their parameters, but each implementation simply ignores the arguments that are irrelevant to it.

Still need to add some slicing / ignoring of the page ids to make this more flexible.
rsuderman authored Oct 29, 2024
1 parent 187c45e commit 89e26c0
Showing 5 changed files with 671 additions and 123 deletions.
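
For orientation, here is a minimal plain-PyTorch mirror of the caller flow that the unified interface enables: prefill via write, decode via write_timestep followed by read. This is only an illustrative sketch of the direct-cache layout described in the kv_cache.py hunks below, not the sharktank classes themselves, and every size in it is made up.

import torch

# Made-up sizes for illustration.
bs, max_seq_len, heads, head_dim, block_count = 2, 16, 4, 8, 1
block_index = 0

# Direct-cache state layout from allocate(): 2 * transformer_block_count
# tensors, each [bs, seq_length, attn_head_count, attn_head_dim].
state = [
    torch.zeros(bs, max_seq_len, heads, head_dim)
    for _ in range(2 * block_count)
]

# Prefill: what cache.write(state, [xk, xv], transformer_block_index=0, ...)
# does, i.e. copy the whole K/V for this block into the cache.
prefill_len = 4
xk = torch.randn(bs, prefill_len, heads, head_dim)
xv = torch.randn(bs, prefill_len, heads, head_dim)
state[block_index * 2][:, :prefill_len] = xk
state[block_index * 2 + 1][:, :prefill_len] = xv

# Decode: what cache.write_timestep(...) does, i.e. scatter one new row per
# sequence at its own (ragged) start position.
start_positions = torch.tensor([4, 2])
new_k = torch.randn(bs, 1, heads, head_dim)
new_v = torch.randn(bs, 1, heads, head_dim)
for b in range(bs):
    state[block_index * 2][b, start_positions[b]] = new_k[b, 0]
    state[block_index * 2 + 1][b, start_positions[b]] = new_v[b, 0]

# Read: what cache.read(..., seq_len=kv_seq_len) returns, i.e. views of the
# cached K/V sliced to the current sequence length.
kv_seq_len = 5
xk_full = state[block_index * 2][:, :kv_seq_len]
xv_full = state[block_index * 2 + 1][:, :kv_seq_len]
print(xk_full.shape, xv_full.shape)  # torch.Size([2, 5, 4, 8]) twice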
132 changes: 113 additions & 19 deletions sharktank/sharktank/layers/kv_cache.py
@@ -92,6 +92,7 @@ def __init__(
attn_head_count: int,
attn_head_dim: int,
seq_length: int,
shard_count: int = 1,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
):
@@ -100,6 +101,7 @@ def __init__(
self.attn_head_count = attn_head_count
self.attn_head_dim = attn_head_dim
self.seq_length = seq_length
self.shard_count = shard_count
self.device = device
self.dtype = dtype

@@ -113,15 +115,109 @@ def allocate(self, *, bs: int) -> list[torch.Tensor]:
Each tensor has shape: [bs, sl, attn_head_count, attn_head_dim]
"""
return [
allocations = [
torch.empty(
[bs, self.seq_length, self.attn_head_count, self.attn_head_dim],
[
bs,
self.seq_length,
self.attn_head_count,
self.attn_head_dim,
],
dtype=self.dtype,
device=self.device,
)
for _ in range(2 * self.transformer_block_count)
]

if self.shard_count == 1:
return allocations

return [
ops.reshard_split(allocation, dim=2, count=self.shard_count)
for allocation in allocations
]

def read(
self,
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
transformer_block_index: int,
seq_len: int,
page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None,
):
"""Reads cache partitions from the page table for the given page_ids.
Args:
state: State struct as returned from allocate().
read_into_partitions: List of cache partitions to read into in-place.
transformer_block_index: The index of the transformer block accessing
the cache.
page_ids: Tensor of [bs, max_seqlen // block_pos_stride] of page ids
to access.
Returns a tuple of cache partitions (i.e. k and v caches for the transformer
block), linearized. Note that this reference approach to reading by
materializing linearly may not be terribly efficient unless the
compiler can fuse the gather.
"""
read_count = len(read_into_partitions)
reads = []
for i in range(read_count):
reads.append(
state[transformer_block_index * read_count + i][:, :seq_len, :, :]
)

return tuple(reads)

def write_timestep(
self,
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
# List of [bs, 1, attn_head_count, attn_head_dim]
cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
transformer_block_index: int,
# [bs]
seq_positions: Union[torch.Tensor, ReplicatedTensor],
# [bs, max_seqlen // block_pos_stride]
page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None,
):
"""Writes a single batched timestep across all cache partitions.
Note that this internally loops over the batch size, which cannot be
dynamic.
"""
bs, _, _, _ = cache_partitions[0].shape
update_count = len(cache_partitions)

for b in range(bs):
row_index = torch.tensor(b, dtype=torch.int64)
row_start_pos = seq_positions[row_index]

for i, update in enumerate(cache_partitions):
cache = state[transformer_block_index * update_count + i]
cache.index_put_((row_index, row_start_pos), update[row_index, 0])

def write(
self,
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
transformer_block_index: int,
page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None,
):
"""Writes cache partitions from a linear layout to the page table.
This is the inverse of the linear read. The same caveat applies if the
in-place scatter cannot be fused.
"""
update_count = len(cache_partitions)

for idx, update_src in enumerate(cache_partitions):
cache_dest = state[transformer_block_index * update_count + idx]
_, batch_seq_len, _, _ = update_src.shape
cache_dest[:, :batch_seq_len, :, :] = update_src


class PagedKVCache(BaseKVCache):
"""Implementation of a KV cache on top of a 'page table'.
@@ -238,31 +334,27 @@ def allocate(
"""Allocates tensor state for a page table for the given capacity in
pages.
"""
shards = [
torch.empty(
[page_count, self.page_slab_flat_dim],
dtype=self.dtype,
device=self.device,
)
for _ in range(self.shard_count)
]

if self.shard_count == 1:
return [
torch.empty(
[page_count, self.page_slab_flat_dim],
dtype=self.dtype,
device=self.device,
)
]
else:
shards = [
torch.empty(
[page_count, self.page_slab_flat_dim],
dtype=self.dtype,
device=self.device,
)
for _ in range(self.shard_count)
]
return [SplitPrimitiveTensor(ts=shards, shard_dim=1)]
return shards

return [SplitPrimitiveTensor(ts=shards, shard_dim=1)]

def read(
self,
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
transformer_block_index: int,
seq_len: int,
page_ids: Union[torch.Tensor, ReplicatedTensor],
):
"""Reads cache partitions from the page table for the given page_ids.
@@ -331,6 +423,8 @@ def read_cache_partition(
for index, read_into_partition in enumerate(read_into_partitions):
read_cache_partition(index, read_into_partition)

return tuple([p[:, :seq_len, :] for p in read_into_partitions])

def write_timestep(
self,
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
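
Both allocate implementations now handle shard_count > 1: the direct cache reshards each buffer along the attention-head dimension via ops.reshard_split(..., dim=2, count=shard_count), while the page table wraps one slab per shard in a SplitPrimitiveTensor with shard_dim=1. As a rough stand-in for what splitting on the head dimension means (plain torch.chunk here, not the sharktank sharding machinery, with made-up sizes):

import torch

bs, seq_len, heads, head_dim, shard_count = 2, 16, 8, 64, 2

# One unsharded direct-cache buffer: [bs, seq_length, attn_head_count, attn_head_dim].
buf = torch.empty(bs, seq_len, heads, head_dim)

# Conceptually, reshard_split(buf, dim=2, count=shard_count) carves the head
# dimension into shard_count equal pieces, one per shard/device.
shards = torch.chunk(buf, shard_count, dim=2)
assert all(s.shape == (bs, seq_len, heads // shard_count, head_dim) for s in shards)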
156 changes: 53 additions & 103 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -113,27 +113,16 @@ def forward(
# Full sequence length.
kv_seq_len = seq_block_ids.shape[1] * self.cache.block_seq_stride

if self.cache.is_paged:
xk, xv = self.transact_cache_paged(
xk_cache_update=xk,
xv_cache_update=xv,
seq_block_ids=seq_block_ids,
kv_seq_len=kv_seq_len,
start_positions=start_positions,
cache_state=cache_state,
xk_temp=xk_temp,
xv_temp=xv_temp,
)
elif self.cache.is_direct:
xk, xv = self.transact_cache_direct(
xk_cache_update=xk,
xv_cache_update=xv,
start_positions=start_positions,
kv_seq_len=kv_seq_len,
cache_state=cache_state,
)
else:
raise NotImplementedError(f"Unsupported KV cache type: {type(self.cache)}")
xk, xv = self.transact_cache(
xk_cache_update=xk,
xv_cache_update=xv,
seq_block_ids=seq_block_ids,
kv_seq_len=kv_seq_len,
start_positions=start_positions,
cache_state=cache_state,
xk_temp=xk_temp,
xv_temp=xv_temp,
)

# Expand kv heads for GQA.
gqa_n_rep = self.head_count // self.head_count_kv
@@ -202,58 +191,20 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
h = h + attn_output
return h

def transact_cache_direct(
self,
*,
cache_state: list[torch.Tensor],
xk_cache_update: torch.Tensor,
xv_cache_update: torch.Tensor,
kv_seq_len: int,
start_positions: Optional[torch.Tensor] = None,
):
bs, batch_seq_len, _, _ = xk_cache_update.shape
cache_k = cache_state[self.block_index * 2]
cache_v = cache_state[self.block_index * 2 + 1]

if start_positions is None:
# Prefill. Write the entire cache.
cache_k[:, :batch_seq_len] = xk_cache_update
cache_v[:, :batch_seq_len] = xv_cache_update
return xk_cache_update, xv_cache_update
else:
# Decode. Write a single timestep.
# TODO: This needs to be reworked with index ops.
assert xk_cache_update.shape[1] == 1
assert xv_cache_update.shape[1] == 1
for b in range(bs):
# Make a tensor because indices must be all tensors, so we can avoid
# doing start_positions[row_index].item(), which generates a lot of SymInts.
row_index = torch.tensor(
b, dtype=torch.int64, device=xk_cache_update.device
)
row_start_pos = start_positions[row_index]
cache_k.index_put_(
(row_index, row_start_pos), xk_cache_update[row_index, 0]
)
cache_v.index_put_(
(row_index, row_start_pos), xv_cache_update[row_index, 0]
)
return cache_k[:, :kv_seq_len], cache_v[:, :kv_seq_len]

def transact_cache_paged(
def transact_cache(
self,
*,
xk_cache_update: torch.Tensor,
xv_cache_update: torch.Tensor,
cache_state: list[torch.Tensor],
# [bs, batch_seq_len // block_seq_stride]
seq_block_ids: torch.Tensor,
seq_block_ids: Optional[torch.Tensor],
kv_seq_len: int,
start_positions: Optional[torch.Tensor] = None,
xk_temp: Optional[torch.Tensor] = None,
xv_temp: Optional[torch.Tensor] = None,
):
cache = self.cache.paged
cache = self.cache
# Manage the cache.
if start_positions is None:
# Prefill: Write the entire cache.
@@ -264,46 +215,45 @@ def transact_cache_paged(
page_ids=seq_block_ids,
)
return xk_cache_update, xv_cache_update
else:
# Decode at ragged start positions.
# We need to initialize/read the K/V from the cache for the whole
# sequence. Note that at this point, it is possible to fork and
# use a memory efficient attention kernel that can do indirect
# reads, skipping this materialization. This path is taken for
# a decode step.
assert xk_temp is not None and xv_temp is not None
assert xk_cache_update.shape[1] == 1
assert xv_cache_update.shape[1] == 1
assert kv_seq_len == seq_block_ids.shape[1] * cache.block_seq_stride

# Write our one updated cache row into the cache.
cache.write_timestep(
cache_state,
cache_partitions=[
xk_cache_update,
xv_cache_update,
],
transformer_block_index=self.block_index,
seq_positions=start_positions,
page_ids=seq_block_ids,
)

# Restore from the cache.
cache.read(
cache_state,
read_into_partitions=[
xk_temp[:, 0:kv_seq_len, ...],
xv_temp[:, 0:kv_seq_len, ...],
],
transformer_block_index=self.block_index,
page_ids=seq_block_ids,
)
# Decode at ragged start positions.
# We need to initialize/read the K/V from the cache for the whole
# sequence. Note that at this point, it is possible to fork and
# use a memory efficient attention kernel that can do indirect
# reads, skipping this materialization. This path is taken for
# a decode step.
assert xk_temp is not None and xv_temp is not None
assert xk_cache_update.shape[1] == 1
assert xv_cache_update.shape[1] == 1
assert kv_seq_len == seq_block_ids.shape[1] * cache.block_seq_stride

# Write our one updated cache row into the cache.
cache.write_timestep(
cache_state,
cache_partitions=[
xk_cache_update,
xv_cache_update,
],
transformer_block_index=self.block_index,
seq_positions=start_positions,
page_ids=seq_block_ids,
)

# Restore from the cache.
xk, xv = cache.read(
cache_state,
read_into_partitions=[
xk_temp[:, 0:kv_seq_len, ...],
xv_temp[:, 0:kv_seq_len, ...],
],
transformer_block_index=self.block_index,
page_ids=seq_block_ids,
seq_len=kv_seq_len,
)

# For computation, we create a subview of the xk/xv tensors to have
# a sequence length covering the blocked size. This must include
# the newly added row (the caller is responsible for ensuring that
# every block has at least one row left). We'll compute on this
# ragged view and use an appropriate mask.
xk = xk_temp[:, 0:kv_seq_len, ...]
xv = xv_temp[:, 0:kv_seq_len, ...]
return xk, xv
# For computation, we create a subview of the xk/xv tensors to have
# a sequence length covering the blocked size. This must include
# the newly added row (the caller is responsible for ensuring that
# every block has at least one row left). We'll compute on this
# ragged view and use an appropriate mask.
return xk, xv
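
One invariant the decode path relies on (asserted in transact_cache above) is that the K/V window attention sees is always a whole number of cache blocks: kv_seq_len = seq_block_ids.shape[1] * block_seq_stride, and xk_temp / xv_temp are sliced to that length. A small numeric check of that relationship with made-up sizes:

import torch

block_seq_stride = 16                                  # made-up stride
seq_block_ids = torch.zeros(2, 3, dtype=torch.int64)   # bs=2, three blocks per row

kv_seq_len = seq_block_ids.shape[1] * block_seq_stride
assert kv_seq_len == 48

# The decode step appends a single new row, then attention runs on the
# blocked view xk_temp[:, 0:kv_seq_len, ...].
xk_temp = torch.empty(2, 64, 4, 8)                     # [bs, max_seqlen, heads, head_dim]
xk = xk_temp[:, 0:kv_seq_len, ...]
assert xk.shape[1] == kv_seq_len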
2 changes: 1 addition & 1 deletion sharktank/sharktank/types/tensors.py
@@ -990,7 +990,7 @@ def _is_slicing_split_dim(self, key):
else:
# Any other collection indexes only dimension 0.
return self.shard_dim == 0
if len(key) < self.shard_dim:
if len(key) <= self.shard_dim:
return False
if not isinstance(key[self.shard_dim], slice):
return True
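
The one-character fix in _is_slicing_split_dim covers a boundary case: an indexing key with exactly shard_dim entries never reaches the split dimension, so it cannot be slicing it. A tiny plain-Python illustration of that boundary (the remainder of the method is not shown in this hunk):

shard_dim = 2
key = (slice(0, 4), slice(None))   # indexes dimensions 0 and 1 only

# New check: the key stops before the split dimension, so it is not slicing it.
assert len(key) <= shard_dim

# With the old strict "<" comparison this key fell through, and the
# subsequent key[shard_dim] lookup would raise IndexError.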