From ebc53af53a96e41160e72f7a018f630f9dfbd0bb Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 14 Oct 2024 19:01:49 -0700 Subject: [PATCH 1/6] [llama] Update kv cache to have read/write functions Made the interfaces of both caches line up. This allows us to interface with the caches via their utility functions instead of modifying the model behavior. Some roughness still exists in their parameters but the irrelevant details are ignored for each implementation. Need to still add some slicing / ignoring on the page ids to make more flexible. --- sharktank/sharktank/layers/kv_cache.py | 82 ++++++ .../layers/paged_llama_attention_block.py | 156 ++++------- sharktank/tests/layers/kv_cache_test.py | 248 ++++++++++++++++++ 3 files changed, 383 insertions(+), 103 deletions(-) create mode 100644 sharktank/tests/layers/kv_cache_test.py diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index bed0b451d..ab4c84c85 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -122,6 +122,85 @@ def allocate(self, *, bs: int) -> list[torch.Tensor]: for _ in range(2 * self.transformer_block_count) ] + def read( + self, + state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + *, + read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], + transformer_block_index: int, + seq_len: int, + page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None, + ): + """Reads cache partitions from the page table for the given page_ids. + + Args: + state: State struct as returned from allocate(). + read_into_partitions: List of cache partitions to read into in-place. + transformer_block_index: The index of the transformer block accessing + the cache. + page_ids: Tensor of [bs, max_seqlen // block_pos_stride] of page ids + to access. + + Returns a tuple of cache partitions (i.e. k and v caches for the transformer + block), linearized. Note that this reference approach to reading by + materializing linearly may not be terribly efficient unless if the + compiler can fuse the gather. + """ + read_count = len(read_into_partitions) + reads = [] + for i in range(read_count): + reads.append(state[transformer_block_index * read_count + i][:, :seq_len]) + + return tuple(reads) + + def write_timestep( + self, + state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + # List of [bs, 1, attn_head_count, attn_head_dim] + cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], + *, + transformer_block_index: int, + # [bs] + seq_positions: Union[torch.Tensor, ReplicatedTensor], + # [bs, max_seqlen // block_pos_stride] + page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None, + ): + """Writes a single batched timestep across all cache partitions. + + Note that this internally loops over the batch size, which cannot be + dynamic. 
+ """ + bs, _, _, _ = cache_partitions[0].shape + update_count = len(cache_partitions) + + for b in range(bs): + row_index = torch.tensor(b, dtype=torch.int64) + row_start_pos = seq_positions[row_index] + + for i, update in enumerate(cache_partitions): + cache = state[transformer_block_index * update_count + i] + cache.index_put_((row_index, row_start_pos), update[row_index, 0]) + + def write( + self, + state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], + *, + transformer_block_index: int, + page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None, + ): + """Writes cache partitions from a linear layout to the page table. + + This is the inverse of the linear read. The same caveat applies if the + in-place scatter cannot be fused. + """ + update_count = len(cache_partitions) + + for idx, update_src in enumerate(cache_partitions): + cache_dest = state[transformer_block_index * update_count + idx] + _, batch_seq_len, _, _ = update_src.shape + cache_dest[:, :batch_seq_len] = update_src + class PagedKVCache(BaseKVCache): """Implementation of a KV cache on top of a 'page table'. @@ -263,6 +342,7 @@ def read( *, read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], transformer_block_index: int, + seq_len: int, page_ids: Union[torch.Tensor, ReplicatedTensor], ): """Reads cache partitions from the page table for the given page_ids. @@ -331,6 +411,8 @@ def read_cache_partition( for index, read_into_partition in enumerate(read_into_partitions): read_cache_partition(index, read_into_partition) + return tuple([p[:, :seq_len, :] for p in read_into_partitions]) + def write_timestep( self, state: list[Union[torch.Tensor, SplitPrimitiveTensor]], diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index 59ed7b43a..958dc954e 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -113,27 +113,16 @@ def forward( # Full sequence length. kv_seq_len = seq_block_ids.shape[1] * self.cache.block_seq_stride - if self.cache.is_paged: - xk, xv = self.transact_cache_paged( - xk_cache_update=xk, - xv_cache_update=xv, - seq_block_ids=seq_block_ids, - kv_seq_len=kv_seq_len, - start_positions=start_positions, - cache_state=cache_state, - xk_temp=xk_temp, - xv_temp=xv_temp, - ) - elif self.cache.is_direct: - xk, xv = self.transact_cache_direct( - xk_cache_update=xk, - xv_cache_update=xv, - start_positions=start_positions, - kv_seq_len=kv_seq_len, - cache_state=cache_state, - ) - else: - raise NotImplementedError(f"Unsupported KV cache type: {type(self.cache)}") + xk, xv = self.transact_cache( + xk_cache_update=xk, + xv_cache_update=xv, + seq_block_ids=seq_block_ids, + kv_seq_len=kv_seq_len, + start_positions=start_positions, + cache_state=cache_state, + xk_temp=xk_temp, + xv_temp=xv_temp, + ) # Expand kv heads for GQA. gqa_n_rep = self.head_count // self.head_count_kv @@ -202,58 +191,20 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: h = h + attn_output return h - def transact_cache_direct( - self, - *, - cache_state: list[torch.Tensor], - xk_cache_update: torch.Tensor, - xv_cache_update: torch.Tensor, - kv_seq_len: int, - start_positions: Optional[torch.Tensor] = None, - ): - bs, batch_seq_len, _, _ = xk_cache_update.shape - cache_k = cache_state[self.block_index * 2] - cache_v = cache_state[self.block_index * 2 + 1] - - if start_positions is None: - # Prefill. 
Write the entire cache. - cache_k[:, :batch_seq_len] = xk_cache_update - cache_v[:, :batch_seq_len] = xv_cache_update - return xk_cache_update, xv_cache_update - else: - # Decode. Write a single timestep. - # TODO: This needs to be reworked with index ops. - assert xk_cache_update.shape[1] == 1 - assert xv_cache_update.shape[1] == 1 - for b in range(bs): - # Make a tensor because indices must be all tensors, so we can avoid - # doing start_positions[row_index].item(), which generates a lot of SymInts. - row_index = torch.tensor( - b, dtype=torch.int64, device=xk_cache_update.device - ) - row_start_pos = start_positions[row_index] - cache_k.index_put_( - (row_index, row_start_pos), xk_cache_update[row_index, 0] - ) - cache_v.index_put_( - (row_index, row_start_pos), xv_cache_update[row_index, 0] - ) - return cache_k[:, :kv_seq_len], cache_v[:, :kv_seq_len] - - def transact_cache_paged( + def transact_cache( self, *, xk_cache_update: torch.Tensor, xv_cache_update: torch.Tensor, cache_state: list[torch.Tensor], # [bs, batch_seq_len // block_seq_stride] - seq_block_ids: torch.Tensor, + seq_block_ids: Optional[torch.Tensor], kv_seq_len: int, start_positions: Optional[torch.Tensor] = None, xk_temp: Optional[torch.Tensor] = None, xv_temp: Optional[torch.Tensor] = None, ): - cache = self.cache.paged + cache = self.cache # Manage the cache. if start_positions is None: # Prefill: Write the entire cache. @@ -264,46 +215,45 @@ def transact_cache_paged( page_ids=seq_block_ids, ) return xk_cache_update, xv_cache_update - else: - # Decode at ragged start positions. - # We need to initialize/read the K/V from the cache for the whole - # sequence. Note that at this point, it is possible to fork and - # use a memory efficient attention kernel that can do indirect - # reads, skipping this materialization. This path is taken for - # a decode step. - assert xk_temp is not None and xv_temp is not None - assert xk_cache_update.shape[1] == 1 - assert xv_cache_update.shape[1] == 1 - assert kv_seq_len == seq_block_ids.shape[1] * cache.block_seq_stride - - # Write our one updated cache row into the cache. - cache.write_timestep( - cache_state, - cache_partitions=[ - xk_cache_update, - xv_cache_update, - ], - transformer_block_index=self.block_index, - seq_positions=start_positions, - page_ids=seq_block_ids, - ) - # Restore from the cache. - cache.read( - cache_state, - read_into_partitions=[ - xk_temp[:, 0:kv_seq_len, ...], - xv_temp[:, 0:kv_seq_len, ...], - ], - transformer_block_index=self.block_index, - page_ids=seq_block_ids, - ) + # Decode at ragged start positions. + # We need to initialize/read the K/V from the cache for the whole + # sequence. Note that at this point, it is possible to fork and + # use a memory efficient attention kernel that can do indirect + # reads, skipping this materialization. This path is taken for + # a decode step. + assert xk_temp is not None and xv_temp is not None + assert xk_cache_update.shape[1] == 1 + assert xv_cache_update.shape[1] == 1 + assert kv_seq_len == seq_block_ids.shape[1] * cache.block_seq_stride + + # Write our one updated cache row into the cache. + cache.write_timestep( + cache_state, + cache_partitions=[ + xk_cache_update, + xv_cache_update, + ], + transformer_block_index=self.block_index, + seq_positions=start_positions, + page_ids=seq_block_ids, + ) + + # Restore from the cache. 
+ xk, xv = cache.read( + cache_state, + read_into_partitions=[ + xk_temp[:, 0:kv_seq_len, ...], + xv_temp[:, 0:kv_seq_len, ...], + ], + transformer_block_index=self.block_index, + page_ids=seq_block_ids, + seq_len=kv_seq_len, + ) - # For computation, we create a subview of the xk/xv tensors to have - # a sequence length covering the blocked size. This must include - # the newly added row (the caller is responsible for ensuring that - # every block has at least one row left). We'll compute on this - # ragged view and use an appropriate mask. - xk = xk_temp[:, 0:kv_seq_len, ...] - xv = xv_temp[:, 0:kv_seq_len, ...] - return xk, xv + # For computation, we create a subview of the xk/xv tensors to have + # a sequence length covering the blocked size. This must include + # the newly added row (the caller is responsible for ensuring that + # every block has at least one row left). We'll compute on this + # ragged view and use an appropriate mask. + return xk, xv diff --git a/sharktank/tests/layers/kv_cache_test.py b/sharktank/tests/layers/kv_cache_test.py new file mode 100644 index 000000000..4dd2f2292 --- /dev/null +++ b/sharktank/tests/layers/kv_cache_test.py @@ -0,0 +1,248 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import unittest + +import torch + +from sharktank.layers import * +from sharktank.types import * + + +def test_direct(): + bs = 4 + seq_length = 24 + attn_head_count = 4 + attn_head_dim = 16 + transformer_block_count = 4 + cache = DirectKVCache( + block_seq_stride=4, + transformer_block_count=transformer_block_count, + attn_head_count=attn_head_count, + attn_head_dim=attn_head_dim, + seq_length=seq_length, + dtype=torch.float32, + device=None, + ) + + allocation = cache.allocate(bs=bs) + allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation] + + write_seq_length = seq_length - 5 + + # Write a prefill in: + write_ones = torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), 1.0, dtype=torch.float32 + ) + write_twos = torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), 2.0, dtype=torch.float32 + ) + cache.write( + allocation, cache_partitions=[write_ones, write_twos], transformer_block_index=1 + ) + + # Check the written values have updated: + read_empty = [ + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length, + ) + torch.testing.assert_close(write_ones, read_back[0]) + torch.testing.assert_close(write_twos, read_back[1]) + + # Check the others are still zero: + for i in range(transformer_block_count): + if i == 1: + continue + read_ones = [ + torch.zeros( + (bs, write_seq_length, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_ones = cache.read( + allocation, + read_into_partitions=read_ones, + transformer_block_index=i, + seq_len=write_seq_length, + ) + torch.testing.assert_close(read_ones[0], torch.full(read_ones[0].shape, 0.0)) + torch.testing.assert_close(read_ones[1], torch.full(read_ones[0].shape, 0.0)) + + # Write timestep + write_threes = 
torch.full( + (bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32 + ) + write_fours = torch.full( + (bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32 + ) + write_pos = torch.full((bs,), write_seq_length, dtype=torch.int64) + cache.write_timestep( + allocation, + cache_partitions=[write_threes, write_fours], + transformer_block_index=1, + seq_positions=write_pos, + ) + + read_empty = [ + torch.zeros( + (bs, write_seq_length + 1, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length + 1, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length + 1, + ) + + check_concat_0 = torch.concat([write_ones, write_threes], dim=1) + check_concat_1 = torch.concat([write_twos, write_fours], dim=1) + + torch.testing.assert_close(check_concat_0, read_back[0]) + torch.testing.assert_close(check_concat_1, read_back[1]) + + +def test_paged(): + bs = 4 + seq_length = 24 + attn_head_count = 4 + attn_head_dim = 16 + transformer_block_count = 4 + block_seq_stride = 4 + cache = PagedKVCache( + block_seq_stride=block_seq_stride, + transformer_block_count=transformer_block_count, + attn_head_count=attn_head_count, + attn_head_dim=attn_head_dim, + dtype=torch.float32, + device=None, + ) + + write_seq_length = seq_length - 4 + page_count = bs * seq_length // block_seq_stride + page_ids = torch.arange(page_count, dtype=torch.int64) + page_ids = page_ids.view(bs, seq_length // block_seq_stride) + write_page_ids = page_ids[:, : write_seq_length // block_seq_stride] + + allocation = cache.allocate(page_count=page_count) + allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation] + + # Write a prefill in: + write_ones = torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), 1.0, dtype=torch.float32 + ) + write_twos = torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), 2.0, dtype=torch.float32 + ) + + cache.write( + allocation, + cache_partitions=[write_ones, write_twos], + transformer_block_index=1, + page_ids=write_page_ids, + ) + + # Check the written values have updated: + read_empty = [ + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length, + page_ids=write_page_ids, + ) + torch.testing.assert_close(write_ones, read_back[0]) + torch.testing.assert_close(write_twos, read_back[1]) + + # Check the others are still zero: + for i in range(transformer_block_count): + if i == 1: + continue + read_ones = [ + torch.zeros( + (bs, write_seq_length, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_ones = cache.read( + allocation, + read_into_partitions=read_ones, + transformer_block_index=i, + seq_len=write_seq_length, + page_ids=write_page_ids, + ) + torch.testing.assert_close(read_ones[0], torch.full(read_ones[0].shape, 0.0)) + torch.testing.assert_close(read_ones[1], torch.full(read_ones[0].shape, 0.0)) + + # Write timestep + write_threes = torch.full( + (bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32 + ) + write_fours = torch.full( + (bs, 1, attn_head_count, 
attn_head_dim), 4.0, dtype=torch.float32 + ) + write_pos = torch.full((bs,), write_seq_length, dtype=torch.int64) + cache.write_timestep( + allocation, + cache_partitions=[write_threes, write_fours], + transformer_block_index=1, + seq_positions=write_pos, + page_ids=page_ids, + ) + + read_empty = [ + torch.zeros( + (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length + 1, + page_ids=page_ids, + ) + + check_concat_0 = torch.concat([write_ones, write_threes], dim=1) + check_concat_1 = torch.concat([write_twos, write_fours], dim=1) + + torch.testing.assert_close(check_concat_0, read_back[0]) + torch.testing.assert_close(check_concat_1, read_back[1]) From 08a752f87f6546351ea10c5953bbee7fa995b515 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 14 Oct 2024 21:14:03 -0700 Subject: [PATCH 2/6] fix test --- sharktank/tests/layers/sharded_paged_kv_cache_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sharktank/tests/layers/sharded_paged_kv_cache_test.py b/sharktank/tests/layers/sharded_paged_kv_cache_test.py index d58874f25..2775b3ed0 100644 --- a/sharktank/tests/layers/sharded_paged_kv_cache_test.py +++ b/sharktank/tests/layers/sharded_paged_kv_cache_test.py @@ -123,6 +123,7 @@ def testRead(self): read_into_partitions=read_into_partitions, transformer_block_index=transformer_block_index, page_ids=page_ids, + seq_len=self.block_seq_len * self.block_seq_stride ) sharded_read_into_partitions = deepcopy( [ @@ -136,6 +137,7 @@ def testRead(self): read_into_partitions=sharded_read_into_partitions, transformer_block_index=transformer_block_index, page_ids=sharded_page_ids, + seq_len=self.block_seq_len * self.block_seq_stride ) for unsharded, sharded in zip( read_into_partitions, sharded_read_into_partitions From 8d28ac8026e20a712d030c82cc003d228c828353 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 14 Oct 2024 21:18:12 -0700 Subject: [PATCH 3/6] formatting --- sharktank/tests/layers/sharded_paged_kv_cache_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sharktank/tests/layers/sharded_paged_kv_cache_test.py b/sharktank/tests/layers/sharded_paged_kv_cache_test.py index 2775b3ed0..d7b6a0b33 100644 --- a/sharktank/tests/layers/sharded_paged_kv_cache_test.py +++ b/sharktank/tests/layers/sharded_paged_kv_cache_test.py @@ -123,7 +123,7 @@ def testRead(self): read_into_partitions=read_into_partitions, transformer_block_index=transformer_block_index, page_ids=page_ids, - seq_len=self.block_seq_len * self.block_seq_stride + seq_len=self.block_seq_len * self.block_seq_stride, ) sharded_read_into_partitions = deepcopy( [ @@ -137,7 +137,7 @@ def testRead(self): read_into_partitions=sharded_read_into_partitions, transformer_block_index=transformer_block_index, page_ids=sharded_page_ids, - seq_len=self.block_seq_len * self.block_seq_stride + seq_len=self.block_seq_len * self.block_seq_stride, ) for unsharded, sharded in zip( read_into_partitions, sharded_read_into_partitions From 368661afe2e591732e1a21e395b03425079b77f5 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 14 Oct 2024 22:01:10 -0700 Subject: [PATCH 4/6] add support for sharded direct cache --- sharktank/sharktank/layers/kv_cache.py | 48 ++++++----- sharktank/sharktank/types/tensors.py | 2 +- 
sharktank/tests/layers/kv_cache_test.py | 106 ++++++++++++++++++++++++ 3 files changed, 132 insertions(+), 24 deletions(-) diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index ab4c84c85..991834d09 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -92,6 +92,7 @@ def __init__( attn_head_count: int, attn_head_dim: int, seq_length: int, + shard_count: int = 1, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, ): @@ -100,6 +101,7 @@ def __init__( self.attn_head_count = attn_head_count self.attn_head_dim = attn_head_dim self.seq_length = seq_length + self.shard_count = shard_count self.device = device self.dtype = dtype @@ -113,15 +115,20 @@ def allocate(self, *, bs: int) -> list[torch.Tensor]: Each tensor has shape: [bs, sl, attn_head_count, attn_head_dim] """ - return [ - torch.empty( - [bs, self.seq_length, self.attn_head_count, self.attn_head_dim], + shards = [[torch.empty( + [bs, self.seq_length, self.attn_head_count // self.shard_count, self.attn_head_dim], dtype=self.dtype, device=self.device, - ) + ) for i in range(self.shard_count)] for _ in range(2 * self.transformer_block_count) ] + if self.shard_count == 1: + return [shard[0] for shard in shards] + + return [SplitPrimitiveTensor(ts=shrds, shard_dim=2) for shrds in shards] + + def read( self, state: list[Union[torch.Tensor, SplitPrimitiveTensor]], @@ -149,7 +156,7 @@ def read( read_count = len(read_into_partitions) reads = [] for i in range(read_count): - reads.append(state[transformer_block_index * read_count + i][:, :seq_len]) + reads.append(state[transformer_block_index * read_count + i][:, :seq_len, :, :]) return tuple(reads) @@ -199,7 +206,7 @@ def write( for idx, update_src in enumerate(cache_partitions): cache_dest = state[transformer_block_index * update_count + idx] _, batch_seq_len, _, _ = update_src.shape - cache_dest[:, :batch_seq_len] = update_src + cache_dest[:, :batch_seq_len, :, :] = update_src class PagedKVCache(BaseKVCache): @@ -317,24 +324,19 @@ def allocate( """Allocates tensor state for a page table for the given capacity in pages. """ + shards = [ + torch.empty( + [page_count, self.page_slab_flat_dim], + dtype=self.dtype, + device=self.device, + ) + for _ in range(self.shard_count) + ] + if self.shard_count == 1: - return [ - torch.empty( - [page_count, self.page_slab_flat_dim], - dtype=self.dtype, - device=self.device, - ) - ] - else: - shards = [ - torch.empty( - [page_count, self.page_slab_flat_dim], - dtype=self.dtype, - device=self.device, - ) - for _ in range(self.shard_count) - ] - return [SplitPrimitiveTensor(ts=shards, shard_dim=1)] + return shards + + return [SplitPrimitiveTensor(ts=shards, shard_dim=1)] def read( self, diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 226ffd777..7b3d2e04b 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -990,7 +990,7 @@ def _is_slicing_split_dim(self, key): else: # Any other collection is a indexing only dimension 0. 
return self.shard_dim == 0 - if len(key) < self.shard_dim: + if len(key) <= self.shard_dim: return False if not isinstance(key[self.shard_dim], slice): return True diff --git a/sharktank/tests/layers/kv_cache_test.py b/sharktank/tests/layers/kv_cache_test.py index 4dd2f2292..9d981e5b4 100644 --- a/sharktank/tests/layers/kv_cache_test.py +++ b/sharktank/tests/layers/kv_cache_test.py @@ -8,6 +8,7 @@ import torch +from sharktank.ops import replicate, unshard from sharktank.layers import * from sharktank.types import * @@ -124,6 +125,111 @@ def test_direct(): torch.testing.assert_close(check_concat_1, read_back[1]) +def test_sharded_direct(): + bs = 4 + seq_length = 24 + attn_head_count = 8 + attn_head_dim = 16 + transformer_block_count = 4 + shard_count = 4 + cache = DirectKVCache( + block_seq_stride=4, + transformer_block_count=transformer_block_count, + attn_head_count=attn_head_count, + attn_head_dim=attn_head_dim, + seq_length=seq_length, + shard_count=shard_count, + dtype=torch.float32, + device=None, + ) + + allocation = cache.allocate(bs=bs) + # allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation] + + write_seq_length = seq_length - 5 + + # Write a prefill in: + write_ones = torch.full( + (bs, write_seq_length, attn_head_count // shard_count, attn_head_dim), + 1.0, + dtype=torch.float32, + ) + write_twos = torch.full( + (bs, write_seq_length, attn_head_count // shard_count, attn_head_dim), + 2.0, + dtype=torch.float32, + ) + + write_ones = SplitPrimitiveTensor(ts=[write_ones] * shard_count, shard_dim=2) + write_twos = SplitPrimitiveTensor(ts=[write_twos] * shard_count, shard_dim=2) + + cache.write( + allocation, cache_partitions=[write_ones, write_twos], transformer_block_index=1 + ) + + # Check the written values have updated: + read_empty = [ + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length, + ) + torch.testing.assert_close(unshard(write_ones), unshard(read_back[0])) + torch.testing.assert_close(unshard(write_twos), unshard(read_back[1])) + + # Write timestep + write_threes = torch.full( + (bs, 1, attn_head_count // shard_count, attn_head_dim), 3.0, dtype=torch.float32 + ) + write_fours = torch.full( + (bs, 1, attn_head_count // shard_count, attn_head_dim), 4.0, dtype=torch.float32 + ) + + write_threes = SplitPrimitiveTensor(ts=[write_threes] * shard_count, shard_dim=2) + write_fours = SplitPrimitiveTensor(ts=[write_fours] * shard_count, shard_dim=2) + + write_pos = replicate( + torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count + ) + cache.write_timestep( + allocation, + cache_partitions=[write_threes, write_fours], + transformer_block_index=1, + seq_positions=write_pos, + ) + + read_empty = [ + torch.zeros( + (bs, write_seq_length + 1, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length + 1, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length + 1, + ) + + check_concat_0 = torch.concat([unshard(write_ones), unshard(write_threes)], dim=1) + check_concat_1 = torch.concat([unshard(write_twos), unshard(write_fours)], dim=1) + + torch.testing.assert_close(check_concat_0, 
unshard(read_back[0])) + torch.testing.assert_close(check_concat_1, unshard(read_back[1])) + + def test_paged(): bs = 4 seq_length = 24 From 5076f328554dd31c8ed4989901dcbb60d84ed243 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 14 Oct 2024 22:27:33 -0700 Subject: [PATCH 5/6] add sharded tests --- sharktank/tests/layers/kv_cache_test.py | 186 +++++++++++++++++++++--- 1 file changed, 167 insertions(+), 19 deletions(-) diff --git a/sharktank/tests/layers/kv_cache_test.py b/sharktank/tests/layers/kv_cache_test.py index 9d981e5b4..65b42c986 100644 --- a/sharktank/tests/layers/kv_cache_test.py +++ b/sharktank/tests/layers/kv_cache_test.py @@ -8,7 +8,7 @@ import torch -from sharktank.ops import replicate, unshard +from sharktank.ops import replicate, reshard_split, unshard from sharktank.layers import * from sharktank.types import * @@ -149,19 +149,25 @@ def test_sharded_direct(): write_seq_length = seq_length - 5 # Write a prefill in: - write_ones = torch.full( - (bs, write_seq_length, attn_head_count // shard_count, attn_head_dim), - 1.0, - dtype=torch.float32, - ) - write_twos = torch.full( - (bs, write_seq_length, attn_head_count // shard_count, attn_head_dim), - 2.0, - dtype=torch.float32, + write_ones = reshard_split( + torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), + 1.0, + dtype=torch.float32, + ), + dim=2, + count=shard_count, ) - write_ones = SplitPrimitiveTensor(ts=[write_ones] * shard_count, shard_dim=2) - write_twos = SplitPrimitiveTensor(ts=[write_twos] * shard_count, shard_dim=2) + write_twos = reshard_split( + torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), + 2.0, + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) cache.write( allocation, cache_partitions=[write_ones, write_twos], transformer_block_index=1 @@ -186,16 +192,17 @@ def test_sharded_direct(): torch.testing.assert_close(unshard(write_twos), unshard(read_back[1])) # Write timestep - write_threes = torch.full( - (bs, 1, attn_head_count // shard_count, attn_head_dim), 3.0, dtype=torch.float32 + write_threes = reshard_split( + torch.full((bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32), + dim=2, + count=shard_count, ) - write_fours = torch.full( - (bs, 1, attn_head_count // shard_count, attn_head_dim), 4.0, dtype=torch.float32 + write_fours = reshard_split( + torch.full((bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32), + dim=2, + count=shard_count, ) - write_threes = SplitPrimitiveTensor(ts=[write_threes] * shard_count, shard_dim=2) - write_fours = SplitPrimitiveTensor(ts=[write_fours] * shard_count, shard_dim=2) - write_pos = replicate( torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count ) @@ -352,3 +359,144 @@ def test_paged(): torch.testing.assert_close(check_concat_0, read_back[0]) torch.testing.assert_close(check_concat_1, read_back[1]) + + +def test_sharded_paged(): + bs = 4 + seq_length = 24 + attn_head_count = 8 + attn_head_dim = 16 + transformer_block_count = 4 + block_seq_stride = 4 + shard_count = 4 + cache = PagedKVCache( + block_seq_stride=block_seq_stride, + transformer_block_count=transformer_block_count, + attn_head_count=attn_head_count, + attn_head_dim=attn_head_dim, + shard_count=shard_count, + dtype=torch.float32, + device=None, + ) + + write_seq_length = seq_length - 4 + page_count = bs * seq_length // block_seq_stride + page_ids = torch.arange(page_count, dtype=torch.int64) + page_ids = page_ids.view(bs, seq_length // block_seq_stride) + page_ids = replicate(page_ids, shard_count) + 
write_page_ids = page_ids[:, : write_seq_length // block_seq_stride] + + allocation = cache.allocate(page_count=page_count) + + # Write a prefill in: + write_ones = reshard_split( + torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), + 1.0, + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + write_twos = reshard_split( + torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), + 2.0, + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + + cache.write( + allocation, + cache_partitions=[write_ones, write_twos], + transformer_block_index=1, + page_ids=write_page_ids, + ) + + # Check the written values have updated: + empty_k = reshard_split( + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + dim=2, + count=shard_count, + ) + + empty_v = reshard_split( + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + dim=2, + count=shard_count, + ) + + read_empty = [empty_k, empty_v] + + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length, + page_ids=write_page_ids, + ) + torch.testing.assert_close(unshard(write_ones), unshard(read_back[0])) + torch.testing.assert_close(unshard(write_twos), unshard(read_back[1])) + + # Write timestep + write_threes = reshard_split( + torch.full((bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32), + dim=2, + count=shard_count, + ) + + write_fours = reshard_split( + torch.full((bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32), + dim=2, + count=shard_count, + ) + + write_pos = replicate( + torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count + ) + + cache.write_timestep( + allocation, + cache_partitions=[write_threes, write_fours], + transformer_block_index=1, + seq_positions=write_pos, + page_ids=page_ids, + ) + + empty_k = reshard_split( + torch.zeros( + (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + + empty_v = reshard_split( + torch.zeros( + (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + + read_back = cache.read( + allocation, + read_into_partitions=[empty_k, empty_v], + transformer_block_index=1, + seq_len=write_seq_length + 1, + page_ids=page_ids, + ) + + check_concat_0 = torch.concat([unshard(write_ones), unshard(write_threes)], dim=1) + check_concat_1 = torch.concat([unshard(write_twos), unshard(write_fours)], dim=1) + + torch.testing.assert_close(check_concat_0, unshard(read_back[0])) + torch.testing.assert_close(check_concat_1, unshard(read_back[1])) From 9b9b0885ea7f343ef298add9fd5c3e0c0ee34846 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 14 Oct 2024 22:41:18 -0700 Subject: [PATCH 6/6] cleanup --- sharktank/sharktank/layers/kv_cache.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index 991834d09..048bc364c 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -115,19 +115,27 @@ def allocate(self, *, bs: int) -> list[torch.Tensor]: Each tensor has shape: [bs, sl, attn_head_count, attn_head_dim] """ - shards = [[torch.empty( - [bs, self.seq_length, self.attn_head_count // self.shard_count, self.attn_head_dim], + allocations = [ + 
torch.empty( + [ + bs, + self.seq_length, + self.attn_head_count, + self.attn_head_dim, + ], dtype=self.dtype, device=self.device, - ) for i in range(self.shard_count)] + ) for _ in range(2 * self.transformer_block_count) ] if self.shard_count == 1: - return [shard[0] for shard in shards] - - return [SplitPrimitiveTensor(ts=shrds, shard_dim=2) for shrds in shards] + return allocations + return [ + ops.reshard_split(allocation, dim=2, count=self.shard_count) + for allocation in allocations + ] def read( self, @@ -156,7 +164,9 @@ def read( read_count = len(read_into_partitions) reads = [] for i in range(read_count): - reads.append(state[transformer_block_index * read_count + i][:, :seq_len, :, :]) + reads.append( + state[transformer_block_index * read_count + i][:, :seq_len, :, :] + ) return tuple(reads)
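
Reviewer orientation: the series converges DirectKVCache and PagedKVCache on a common allocate / write / write_timestep / read surface. The sketch below is not part of the patches; it simply mirrors the calls exercised in sharktank/tests/layers/kv_cache_test.py to show how a caller drives the unified interface for a prefill write, a single decode step, and a linearized read back. Shapes and fill values are illustrative, and DirectKVCache is assumed to be importable from sharktank.layers, as the tests' star import implies.

    # Minimal sketch of the unified cache interface (mirrors kv_cache_test.py).
    import torch

    from sharktank.layers import DirectKVCache

    bs, seq_length = 4, 24
    attn_head_count, attn_head_dim = 4, 16

    cache = DirectKVCache(
        block_seq_stride=4,
        transformer_block_count=4,
        attn_head_count=attn_head_count,
        attn_head_dim=attn_head_dim,
        seq_length=seq_length,
        dtype=torch.float32,
        device=None,
    )

    # One K and one V buffer per transformer block.
    state = cache.allocate(bs=bs)

    # Prefill: write the whole prompt's K/V for transformer block 1 in one shot.
    prefill_len = seq_length - 5
    k = torch.ones(bs, prefill_len, attn_head_count, attn_head_dim)
    v = torch.full((bs, prefill_len, attn_head_count, attn_head_dim), 2.0)
    cache.write(state, cache_partitions=[k, v], transformer_block_index=1)

    # Decode: append a single timestep at each sequence's current position.
    k_step = torch.full((bs, 1, attn_head_count, attn_head_dim), 3.0)
    v_step = torch.full((bs, 1, attn_head_count, attn_head_dim), 4.0)
    positions = torch.full((bs,), prefill_len, dtype=torch.int64)
    cache.write_timestep(
        state,
        cache_partitions=[k_step, v_step],
        transformer_block_index=1,
        seq_positions=positions,
    )

    # Read back the linearized K/V up to the new length. The scratch buffers
    # are filled in place by the paged cache; the direct cache simply returns
    # slices of its state.
    scratch = [
        torch.empty(bs, prefill_len + 1, attn_head_count, attn_head_dim)
        for _ in range(2)
    ]
    xk, xv = cache.read(
        state,
        read_into_partitions=scratch,
        transformer_block_index=1,
        seq_len=prefill_len + 1,
    )

The paged path follows the same sequence of calls, with allocate(page_count=...) in place of allocate(bs=...) and a page_ids tensor threaded through write, write_timestep, and read.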