diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index bed0b451d..048bc364c 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -92,6 +92,7 @@ def __init__( attn_head_count: int, attn_head_dim: int, seq_length: int, + shard_count: int = 1, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, ): @@ -100,6 +101,7 @@ def __init__( self.attn_head_count = attn_head_count self.attn_head_dim = attn_head_dim self.seq_length = seq_length + self.shard_count = shard_count self.device = device self.dtype = dtype @@ -113,15 +115,109 @@ def allocate(self, *, bs: int) -> list[torch.Tensor]: Each tensor has shape: [bs, sl, attn_head_count, attn_head_dim] """ - return [ + allocations = [ torch.empty( - [bs, self.seq_length, self.attn_head_count, self.attn_head_dim], + [ + bs, + self.seq_length, + self.attn_head_count, + self.attn_head_dim, + ], dtype=self.dtype, device=self.device, ) for _ in range(2 * self.transformer_block_count) ] + if self.shard_count == 1: + return allocations + + return [ + ops.reshard_split(allocation, dim=2, count=self.shard_count) + for allocation in allocations + ] + + def read( + self, + state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + *, + read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], + transformer_block_index: int, + seq_len: int, + page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None, + ): + """Reads cache partitions from the page table for the given page_ids. + + Args: + state: State struct as returned from allocate(). + read_into_partitions: List of cache partitions to read into in-place. + transformer_block_index: The index of the transformer block accessing + the cache. + page_ids: Tensor of [bs, max_seqlen // block_pos_stride] of page ids + to access. + + Returns a tuple of cache partitions (i.e. k and v caches for the transformer + block), linearized. Note that this reference approach to reading by + materializing linearly may not be terribly efficient unless if the + compiler can fuse the gather. + """ + read_count = len(read_into_partitions) + reads = [] + for i in range(read_count): + reads.append( + state[transformer_block_index * read_count + i][:, :seq_len, :, :] + ) + + return tuple(reads) + + def write_timestep( + self, + state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + # List of [bs, 1, attn_head_count, attn_head_dim] + cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], + *, + transformer_block_index: int, + # [bs] + seq_positions: Union[torch.Tensor, ReplicatedTensor], + # [bs, max_seqlen // block_pos_stride] + page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None, + ): + """Writes a single batched timestep across all cache partitions. + + Note that this internally loops over the batch size, which cannot be + dynamic. 
+ """ + bs, _, _, _ = cache_partitions[0].shape + update_count = len(cache_partitions) + + for b in range(bs): + row_index = torch.tensor(b, dtype=torch.int64) + row_start_pos = seq_positions[row_index] + + for i, update in enumerate(cache_partitions): + cache = state[transformer_block_index * update_count + i] + cache.index_put_((row_index, row_start_pos), update[row_index, 0]) + + def write( + self, + state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], + *, + transformer_block_index: int, + page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None, + ): + """Writes cache partitions from a linear layout to the page table. + + This is the inverse of the linear read. The same caveat applies if the + in-place scatter cannot be fused. + """ + update_count = len(cache_partitions) + + for idx, update_src in enumerate(cache_partitions): + cache_dest = state[transformer_block_index * update_count + idx] + _, batch_seq_len, _, _ = update_src.shape + cache_dest[:, :batch_seq_len, :, :] = update_src + class PagedKVCache(BaseKVCache): """Implementation of a KV cache on top of a 'page table'. @@ -238,24 +334,19 @@ def allocate( """Allocates tensor state for a page table for the given capacity in pages. """ + shards = [ + torch.empty( + [page_count, self.page_slab_flat_dim], + dtype=self.dtype, + device=self.device, + ) + for _ in range(self.shard_count) + ] + if self.shard_count == 1: - return [ - torch.empty( - [page_count, self.page_slab_flat_dim], - dtype=self.dtype, - device=self.device, - ) - ] - else: - shards = [ - torch.empty( - [page_count, self.page_slab_flat_dim], - dtype=self.dtype, - device=self.device, - ) - for _ in range(self.shard_count) - ] - return [SplitPrimitiveTensor(ts=shards, shard_dim=1)] + return shards + + return [SplitPrimitiveTensor(ts=shards, shard_dim=1)] def read( self, @@ -263,6 +354,7 @@ def read( *, read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], transformer_block_index: int, + seq_len: int, page_ids: Union[torch.Tensor, ReplicatedTensor], ): """Reads cache partitions from the page table for the given page_ids. @@ -331,6 +423,8 @@ def read_cache_partition( for index, read_into_partition in enumerate(read_into_partitions): read_cache_partition(index, read_into_partition) + return tuple([p[:, :seq_len, :] for p in read_into_partitions]) + def write_timestep( self, state: list[Union[torch.Tensor, SplitPrimitiveTensor]], diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index 59ed7b43a..958dc954e 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -113,27 +113,16 @@ def forward( # Full sequence length. 
kv_seq_len = seq_block_ids.shape[1] * self.cache.block_seq_stride - if self.cache.is_paged: - xk, xv = self.transact_cache_paged( - xk_cache_update=xk, - xv_cache_update=xv, - seq_block_ids=seq_block_ids, - kv_seq_len=kv_seq_len, - start_positions=start_positions, - cache_state=cache_state, - xk_temp=xk_temp, - xv_temp=xv_temp, - ) - elif self.cache.is_direct: - xk, xv = self.transact_cache_direct( - xk_cache_update=xk, - xv_cache_update=xv, - start_positions=start_positions, - kv_seq_len=kv_seq_len, - cache_state=cache_state, - ) - else: - raise NotImplementedError(f"Unsupported KV cache type: {type(self.cache)}") + xk, xv = self.transact_cache( + xk_cache_update=xk, + xv_cache_update=xv, + seq_block_ids=seq_block_ids, + kv_seq_len=kv_seq_len, + start_positions=start_positions, + cache_state=cache_state, + xk_temp=xk_temp, + xv_temp=xv_temp, + ) # Expand kv heads for GQA. gqa_n_rep = self.head_count // self.head_count_kv @@ -202,58 +191,20 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: h = h + attn_output return h - def transact_cache_direct( - self, - *, - cache_state: list[torch.Tensor], - xk_cache_update: torch.Tensor, - xv_cache_update: torch.Tensor, - kv_seq_len: int, - start_positions: Optional[torch.Tensor] = None, - ): - bs, batch_seq_len, _, _ = xk_cache_update.shape - cache_k = cache_state[self.block_index * 2] - cache_v = cache_state[self.block_index * 2 + 1] - - if start_positions is None: - # Prefill. Write the entire cache. - cache_k[:, :batch_seq_len] = xk_cache_update - cache_v[:, :batch_seq_len] = xv_cache_update - return xk_cache_update, xv_cache_update - else: - # Decode. Write a single timestep. - # TODO: This needs to be reworked with index ops. - assert xk_cache_update.shape[1] == 1 - assert xv_cache_update.shape[1] == 1 - for b in range(bs): - # Make a tensor because indices must be all tensors, so we can avoid - # doing start_positions[row_index].item(), which generates a lot of SymInts. - row_index = torch.tensor( - b, dtype=torch.int64, device=xk_cache_update.device - ) - row_start_pos = start_positions[row_index] - cache_k.index_put_( - (row_index, row_start_pos), xk_cache_update[row_index, 0] - ) - cache_v.index_put_( - (row_index, row_start_pos), xv_cache_update[row_index, 0] - ) - return cache_k[:, :kv_seq_len], cache_v[:, :kv_seq_len] - - def transact_cache_paged( + def transact_cache( self, *, xk_cache_update: torch.Tensor, xv_cache_update: torch.Tensor, cache_state: list[torch.Tensor], # [bs, batch_seq_len // block_seq_stride] - seq_block_ids: torch.Tensor, + seq_block_ids: Optional[torch.Tensor], kv_seq_len: int, start_positions: Optional[torch.Tensor] = None, xk_temp: Optional[torch.Tensor] = None, xv_temp: Optional[torch.Tensor] = None, ): - cache = self.cache.paged + cache = self.cache # Manage the cache. if start_positions is None: # Prefill: Write the entire cache. @@ -264,46 +215,45 @@ def transact_cache_paged( page_ids=seq_block_ids, ) return xk_cache_update, xv_cache_update - else: - # Decode at ragged start positions. - # We need to initialize/read the K/V from the cache for the whole - # sequence. Note that at this point, it is possible to fork and - # use a memory efficient attention kernel that can do indirect - # reads, skipping this materialization. This path is taken for - # a decode step. 
- assert xk_temp is not None and xv_temp is not None - assert xk_cache_update.shape[1] == 1 - assert xv_cache_update.shape[1] == 1 - assert kv_seq_len == seq_block_ids.shape[1] * cache.block_seq_stride - - # Write our one updated cache row into the cache. - cache.write_timestep( - cache_state, - cache_partitions=[ - xk_cache_update, - xv_cache_update, - ], - transformer_block_index=self.block_index, - seq_positions=start_positions, - page_ids=seq_block_ids, - ) - # Restore from the cache. - cache.read( - cache_state, - read_into_partitions=[ - xk_temp[:, 0:kv_seq_len, ...], - xv_temp[:, 0:kv_seq_len, ...], - ], - transformer_block_index=self.block_index, - page_ids=seq_block_ids, - ) + # Decode at ragged start positions. + # We need to initialize/read the K/V from the cache for the whole + # sequence. Note that at this point, it is possible to fork and + # use a memory efficient attention kernel that can do indirect + # reads, skipping this materialization. This path is taken for + # a decode step. + assert xk_temp is not None and xv_temp is not None + assert xk_cache_update.shape[1] == 1 + assert xv_cache_update.shape[1] == 1 + assert kv_seq_len == seq_block_ids.shape[1] * cache.block_seq_stride + + # Write our one updated cache row into the cache. + cache.write_timestep( + cache_state, + cache_partitions=[ + xk_cache_update, + xv_cache_update, + ], + transformer_block_index=self.block_index, + seq_positions=start_positions, + page_ids=seq_block_ids, + ) + + # Restore from the cache. + xk, xv = cache.read( + cache_state, + read_into_partitions=[ + xk_temp[:, 0:kv_seq_len, ...], + xv_temp[:, 0:kv_seq_len, ...], + ], + transformer_block_index=self.block_index, + page_ids=seq_block_ids, + seq_len=kv_seq_len, + ) - # For computation, we create a subview of the xk/xv tensors to have - # a sequence length covering the blocked size. This must include - # the newly added row (the caller is responsible for ensuring that - # every block has at least one row left). We'll compute on this - # ragged view and use an appropriate mask. - xk = xk_temp[:, 0:kv_seq_len, ...] - xv = xv_temp[:, 0:kv_seq_len, ...] - return xk, xv + # For computation, we create a subview of the xk/xv tensors to have + # a sequence length covering the blocked size. This must include + # the newly added row (the caller is responsible for ensuring that + # every block has at least one row left). We'll compute on this + # ragged view and use an appropriate mask. + return xk, xv diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 226ffd777..7b3d2e04b 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -990,7 +990,7 @@ def _is_slicing_split_dim(self, key): else: # Any other collection is a indexing only dimension 0. return self.shard_dim == 0 - if len(key) < self.shard_dim: + if len(key) <= self.shard_dim: return False if not isinstance(key[self.shard_dim], slice): return True diff --git a/sharktank/tests/layers/kv_cache_test.py b/sharktank/tests/layers/kv_cache_test.py new file mode 100644 index 000000000..65b42c986 --- /dev/null +++ b/sharktank/tests/layers/kv_cache_test.py @@ -0,0 +1,502 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import unittest + +import torch + +from sharktank.ops import replicate, reshard_split, unshard +from sharktank.layers import * +from sharktank.types import * + + +def test_direct(): + bs = 4 + seq_length = 24 + attn_head_count = 4 + attn_head_dim = 16 + transformer_block_count = 4 + cache = DirectKVCache( + block_seq_stride=4, + transformer_block_count=transformer_block_count, + attn_head_count=attn_head_count, + attn_head_dim=attn_head_dim, + seq_length=seq_length, + dtype=torch.float32, + device=None, + ) + + allocation = cache.allocate(bs=bs) + allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation] + + write_seq_length = seq_length - 5 + + # Write a prefill in: + write_ones = torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), 1.0, dtype=torch.float32 + ) + write_twos = torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), 2.0, dtype=torch.float32 + ) + cache.write( + allocation, cache_partitions=[write_ones, write_twos], transformer_block_index=1 + ) + + # Check the written values have updated: + read_empty = [ + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length, + ) + torch.testing.assert_close(write_ones, read_back[0]) + torch.testing.assert_close(write_twos, read_back[1]) + + # Check the others are still zero: + for i in range(transformer_block_count): + if i == 1: + continue + read_ones = [ + torch.zeros( + (bs, write_seq_length, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_ones = cache.read( + allocation, + read_into_partitions=read_ones, + transformer_block_index=i, + seq_len=write_seq_length, + ) + torch.testing.assert_close(read_ones[0], torch.full(read_ones[0].shape, 0.0)) + torch.testing.assert_close(read_ones[1], torch.full(read_ones[0].shape, 0.0)) + + # Write timestep + write_threes = torch.full( + (bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32 + ) + write_fours = torch.full( + (bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32 + ) + write_pos = torch.full((bs,), write_seq_length, dtype=torch.int64) + cache.write_timestep( + allocation, + cache_partitions=[write_threes, write_fours], + transformer_block_index=1, + seq_positions=write_pos, + ) + + read_empty = [ + torch.zeros( + (bs, write_seq_length + 1, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length + 1, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length + 1, + ) + + check_concat_0 = torch.concat([write_ones, write_threes], dim=1) + check_concat_1 = torch.concat([write_twos, write_fours], dim=1) + + torch.testing.assert_close(check_concat_0, read_back[0]) + torch.testing.assert_close(check_concat_1, read_back[1]) + + +def test_sharded_direct(): + bs = 4 + seq_length = 24 + attn_head_count = 8 + attn_head_dim = 16 + transformer_block_count = 4 + shard_count = 4 + cache = DirectKVCache( + block_seq_stride=4, + transformer_block_count=transformer_block_count, + 
attn_head_count=attn_head_count, + attn_head_dim=attn_head_dim, + seq_length=seq_length, + shard_count=shard_count, + dtype=torch.float32, + device=None, + ) + + allocation = cache.allocate(bs=bs) + # allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation] + + write_seq_length = seq_length - 5 + + # Write a prefill in: + write_ones = reshard_split( + torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), + 1.0, + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + + write_twos = reshard_split( + torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), + 2.0, + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + + cache.write( + allocation, cache_partitions=[write_ones, write_twos], transformer_block_index=1 + ) + + # Check the written values have updated: + read_empty = [ + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length, + ) + torch.testing.assert_close(unshard(write_ones), unshard(read_back[0])) + torch.testing.assert_close(unshard(write_twos), unshard(read_back[1])) + + # Write timestep + write_threes = reshard_split( + torch.full((bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32), + dim=2, + count=shard_count, + ) + write_fours = reshard_split( + torch.full((bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32), + dim=2, + count=shard_count, + ) + + write_pos = replicate( + torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count + ) + cache.write_timestep( + allocation, + cache_partitions=[write_threes, write_fours], + transformer_block_index=1, + seq_positions=write_pos, + ) + + read_empty = [ + torch.zeros( + (bs, write_seq_length + 1, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length + 1, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length + 1, + ) + + check_concat_0 = torch.concat([unshard(write_ones), unshard(write_threes)], dim=1) + check_concat_1 = torch.concat([unshard(write_twos), unshard(write_fours)], dim=1) + + torch.testing.assert_close(check_concat_0, unshard(read_back[0])) + torch.testing.assert_close(check_concat_1, unshard(read_back[1])) + + +def test_paged(): + bs = 4 + seq_length = 24 + attn_head_count = 4 + attn_head_dim = 16 + transformer_block_count = 4 + block_seq_stride = 4 + cache = PagedKVCache( + block_seq_stride=block_seq_stride, + transformer_block_count=transformer_block_count, + attn_head_count=attn_head_count, + attn_head_dim=attn_head_dim, + dtype=torch.float32, + device=None, + ) + + write_seq_length = seq_length - 4 + page_count = bs * seq_length // block_seq_stride + page_ids = torch.arange(page_count, dtype=torch.int64) + page_ids = page_ids.view(bs, seq_length // block_seq_stride) + write_page_ids = page_ids[:, : write_seq_length // block_seq_stride] + + allocation = cache.allocate(page_count=page_count) + allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation] + + # Write a prefill in: + write_ones = torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), 1.0, dtype=torch.float32 + ) + write_twos = torch.full( + (bs, write_seq_length, attn_head_count, 
attn_head_dim), 2.0, dtype=torch.float32 + ) + + cache.write( + allocation, + cache_partitions=[write_ones, write_twos], + transformer_block_index=1, + page_ids=write_page_ids, + ) + + # Check the written values have updated: + read_empty = [ + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length, + page_ids=write_page_ids, + ) + torch.testing.assert_close(write_ones, read_back[0]) + torch.testing.assert_close(write_twos, read_back[1]) + + # Check the others are still zero: + for i in range(transformer_block_count): + if i == 1: + continue + read_ones = [ + torch.zeros( + (bs, write_seq_length, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_ones = cache.read( + allocation, + read_into_partitions=read_ones, + transformer_block_index=i, + seq_len=write_seq_length, + page_ids=write_page_ids, + ) + torch.testing.assert_close(read_ones[0], torch.full(read_ones[0].shape, 0.0)) + torch.testing.assert_close(read_ones[1], torch.full(read_ones[0].shape, 0.0)) + + # Write timestep + write_threes = torch.full( + (bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32 + ) + write_fours = torch.full( + (bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32 + ) + write_pos = torch.full((bs,), write_seq_length, dtype=torch.int64) + cache.write_timestep( + allocation, + cache_partitions=[write_threes, write_fours], + transformer_block_index=1, + seq_positions=write_pos, + page_ids=page_ids, + ) + + read_empty = [ + torch.zeros( + (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length + 1, + page_ids=page_ids, + ) + + check_concat_0 = torch.concat([write_ones, write_threes], dim=1) + check_concat_1 = torch.concat([write_twos, write_fours], dim=1) + + torch.testing.assert_close(check_concat_0, read_back[0]) + torch.testing.assert_close(check_concat_1, read_back[1]) + + +def test_sharded_paged(): + bs = 4 + seq_length = 24 + attn_head_count = 8 + attn_head_dim = 16 + transformer_block_count = 4 + block_seq_stride = 4 + shard_count = 4 + cache = PagedKVCache( + block_seq_stride=block_seq_stride, + transformer_block_count=transformer_block_count, + attn_head_count=attn_head_count, + attn_head_dim=attn_head_dim, + shard_count=shard_count, + dtype=torch.float32, + device=None, + ) + + write_seq_length = seq_length - 4 + page_count = bs * seq_length // block_seq_stride + page_ids = torch.arange(page_count, dtype=torch.int64) + page_ids = page_ids.view(bs, seq_length // block_seq_stride) + page_ids = replicate(page_ids, shard_count) + write_page_ids = page_ids[:, : write_seq_length // block_seq_stride] + + allocation = cache.allocate(page_count=page_count) + + # Write a prefill in: + write_ones = reshard_split( + torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), + 1.0, + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + write_twos = reshard_split( + torch.full( + (bs, 
write_seq_length, attn_head_count, attn_head_dim), + 2.0, + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + + cache.write( + allocation, + cache_partitions=[write_ones, write_twos], + transformer_block_index=1, + page_ids=write_page_ids, + ) + + # Check the written values have updated: + empty_k = reshard_split( + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + dim=2, + count=shard_count, + ) + + empty_v = reshard_split( + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + dim=2, + count=shard_count, + ) + + read_empty = [empty_k, empty_v] + + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length, + page_ids=write_page_ids, + ) + torch.testing.assert_close(unshard(write_ones), unshard(read_back[0])) + torch.testing.assert_close(unshard(write_twos), unshard(read_back[1])) + + # Write timestep + write_threes = reshard_split( + torch.full((bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32), + dim=2, + count=shard_count, + ) + + write_fours = reshard_split( + torch.full((bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32), + dim=2, + count=shard_count, + ) + + write_pos = replicate( + torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count + ) + + cache.write_timestep( + allocation, + cache_partitions=[write_threes, write_fours], + transformer_block_index=1, + seq_positions=write_pos, + page_ids=page_ids, + ) + + empty_k = reshard_split( + torch.zeros( + (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + + empty_v = reshard_split( + torch.zeros( + (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + + read_back = cache.read( + allocation, + read_into_partitions=[empty_k, empty_v], + transformer_block_index=1, + seq_len=write_seq_length + 1, + page_ids=page_ids, + ) + + check_concat_0 = torch.concat([unshard(write_ones), unshard(write_threes)], dim=1) + check_concat_1 = torch.concat([unshard(write_twos), unshard(write_fours)], dim=1) + + torch.testing.assert_close(check_concat_0, unshard(read_back[0])) + torch.testing.assert_close(check_concat_1, unshard(read_back[1])) diff --git a/sharktank/tests/layers/sharded_paged_kv_cache_test.py b/sharktank/tests/layers/sharded_paged_kv_cache_test.py index d58874f25..d7b6a0b33 100644 --- a/sharktank/tests/layers/sharded_paged_kv_cache_test.py +++ b/sharktank/tests/layers/sharded_paged_kv_cache_test.py @@ -123,6 +123,7 @@ def testRead(self): read_into_partitions=read_into_partitions, transformer_block_index=transformer_block_index, page_ids=page_ids, + seq_len=self.block_seq_len * self.block_seq_stride, ) sharded_read_into_partitions = deepcopy( [ @@ -136,6 +137,7 @@ def testRead(self): read_into_partitions=sharded_read_into_partitions, transformer_block_index=transformer_block_index, page_ids=sharded_page_ids, + seq_len=self.block_seq_len * self.block_seq_stride, ) for unsharded, sharded in zip( read_into_partitions, sharded_read_into_partitions
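
Usage note (not part of the patch): the sketch below walks the unified DirectKVCache flow this change introduces — allocate, prefill write, single-timestep decode write, then the new read(..., seq_len=...) that returns the linearized K/V. It mirrors test_direct() in the new kv_cache_test.py and assumes the sharktank package from this change is importable and exports DirectKVCache as the tests do; shapes and argument names are taken from the diff above.

# Minimal sketch of the DirectKVCache API after this change; mirrors test_direct().
import torch

from sharktank.layers import DirectKVCache  # assumed exported, as in the new tests

bs, seq_length = 4, 24
attn_head_count, attn_head_dim = 4, 16
transformer_block_count = 4

cache = DirectKVCache(
    block_seq_stride=4,
    transformer_block_count=transformer_block_count,
    attn_head_count=attn_head_count,
    attn_head_dim=attn_head_dim,
    seq_length=seq_length,
    dtype=torch.float32,
    device=None,
)

# One K and one V buffer per transformer block:
# [bs, seq_length, attn_head_count, attn_head_dim] each (shard_count defaults to 1).
state = cache.allocate(bs=bs)

# Prefill: write the first `prefill_len` positions of transformer block 0.
prefill_len = seq_length - 1
k = torch.randn(bs, prefill_len, attn_head_count, attn_head_dim)
v = torch.randn(bs, prefill_len, attn_head_count, attn_head_dim)
cache.write(state, cache_partitions=[k, v], transformer_block_index=0)

# Decode: append one timestep per batch row at its current sequence position.
k_step = torch.randn(bs, 1, attn_head_count, attn_head_dim)
v_step = torch.randn(bs, 1, attn_head_count, attn_head_dim)
positions = torch.full((bs,), prefill_len, dtype=torch.int64)
cache.write_timestep(
    state,
    cache_partitions=[k_step, v_step],
    transformer_block_index=0,
    seq_positions=positions,
)

# Read back the linearized K/V, truncated to the new seq_len argument.
read_k = torch.empty(bs, prefill_len + 1, attn_head_count, attn_head_dim)
read_v = torch.empty_like(read_k)
k_read, v_read = cache.read(
    state,
    read_into_partitions=[read_k, read_v],
    transformer_block_index=0,
    seq_len=prefill_len + 1,
)
assert k_read.shape[1] == prefill_len + 1

The sharded path is the same call sequence; the tests show the only differences are passing shard_count to the constructor, resharding the K/V updates with reshard_split(..., dim=2, count=shard_count), and replicating seq_positions (and, for PagedKVCache, page_ids) across shards.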