diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py
index 40cfea94f..d5aad1e7f 100644
--- a/sharktank/sharktank/examples/export_paged_llm_v1.py
+++ b/sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -174,7 +174,7 @@ def generate_batch_prefill(bs: int):
             "tokens": {1: sl_dim},
             "seq_lens": {},
             "seq_block_ids": {1: block_dim},
-            "cs": cache_dynamic_shapes,
+            "cache_state": cache_dynamic_shapes,
         }

         print(f"Exporting prefill_bs{bs}")
@@ -186,40 +186,22 @@ def generate_batch_prefill(bs: int):
             strict=args.strict,
             arg_device=arg_affinities,
         )
-        def _(model, tokens, seq_lens, seq_block_ids, cs):
+        def _(model, tokens, seq_lens, seq_block_ids, cache_state):
             if (
                 model.config.tensor_parallelism_size == 1
                 and model.config.kv_cache_type == "direct"
             ):
-                cache_tensors = torch.unbind(cs)
-            else:
-                cache_tensors = cs
-
-            sl = tokens.shape[1]
-            input_mask = model.input_mask(seq_lens, sl)
-            attention_mask = model.attention_mask(input_mask)
-
-            if llama_config.tensor_parallelism_size != 1:
-                shard_count = llama_config.tensor_parallelism_size
-
-                tokens = ops.replicate(tokens, count=shard_count)
-                attention_mask = ops.replicate(attention_mask, count=shard_count)
-                seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)
-
-                cache_tensors = repack_cache(cs, cache_shard_dim)
+                cache_state = torch.unbind(cache_state)
+            if model.config.tensor_parallelism_size != 1:
+                cache_state = repack_cache(cache_state, cache_shard_dim)

-            logits = model.prefill(
-                tokens,
-                attention_mask=attention_mask,
+            return model.prefill_from_seq_lens(
+                tokens=tokens,
+                seq_lens=seq_lens,
                 seq_block_ids=seq_block_ids,
-                cache_state=cache_tensors,
+                cache_state=cache_state,
             )

-            if llama_config.tensor_parallelism_size != 1:
-                logits = ops.unshard(logits)
-
-            return logits
-
     def generate_batch_decode(bs: int):
         tokens = torch.ones(bs, 1, dtype=torch.int64)
         seq_lens = torch.ones(bs, dtype=torch.int64)
@@ -274,34 +256,21 @@ def _(
             seq_block_ids,
             cache_state,
         ):
-            input_mask = model.input_mask(
-                seq_lens, seq_block_ids.shape[1] * model.cache.block_seq_stride
-            )
-            attention_mask = model.decode_attention_mask(input_mask)
-
-            if llama_config.tensor_parallelism_size != 1:
-                shard_count = llama_config.tensor_parallelism_size
-
-                tokens = ops.replicate(tokens, count=shard_count)
-                attention_mask = ops.replicate(attention_mask, count=shard_count)
-                start_positions = ops.replicate(start_positions, count=shard_count)
-                seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)
-
+            if (
+                model.config.tensor_parallelism_size == 1
+                and model.config.kv_cache_type == "direct"
+            ):
+                cache_state = torch.unbind(cache_state)
+            if model.config.tensor_parallelism_size != 1:
                 cache_state = repack_cache(cache_state, cache_shard_dim)
-
-            logits = model.decode(
-                tokens,
-                attention_mask=attention_mask,
+            return model.decode_from_seq_lens(
+                tokens=tokens,
+                seq_lens=seq_lens,
                 start_positions=start_positions,
                 seq_block_ids=seq_block_ids,
                 cache_state=cache_state,
             )

-            if llama_config.tensor_parallelism_size != 1:
-                logits = ops.unshard(logits)
-
-            return logits
-
     bsizes = []
     for bs in args.bs:
         generate_batch_prefill(bs)
diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py
index 2ff664b19..98dc872de 100644
--- a/sharktank/sharktank/examples/paged_llm_v1.py
+++ b/sharktank/sharktank/examples/paged_llm_v1.py
@@ -32,7 +32,7 @@ class TorchGenerator:

     def __init__(
         self,
-        model: PagedLlamaModelV1,
+        model: BaseCausalLMModel,
         tokenizer: InferenceTokenizer,
         page_cache_size: int = 128,
         # Need to look at the model more for this.
@@ -162,17 +162,14 @@ def compute_prefill_logits(

     def prefill(self):
         model = self.parent.model
-        attention_mask = model.attention_mask(
-            model.input_mask(self.seq_lens, self.token_ids.shape[1])
-        )
         seq_block_ids_tensor = self.pad_block_ids()
         print(f":: Invoke prefill:")
         trace_tensor("prefill.token_ids", self.token_ids)
+        trace_tensor("prefill.seq_lens", self.seq_lens)
         trace_tensor("prefill.seq_block_ids", seq_block_ids_tensor)
-        trace_tensor("prefill.attention_mask", attention_mask)
-        logits = model.prefill(
+        self.logits = model.prefill_from_seq_lens(
             self.token_ids,
-            attention_mask=attention_mask,
+            seq_lens=self.seq_lens,
             seq_block_ids=seq_block_ids_tensor,
             cache_state=self.cache_state,
         )
@@ -181,7 +178,7 @@ def prefill(self):
         # TODO: Normalize the output of extract_tokens_from_logits into
         # tensor [bs, 1].
         tokens = torch.tensor(
-            model.extract_tokens_from_logits(logits, self.seq_lens)
+            model.extract_tokens_from_logits(self.logits, self.seq_lens)
         ).unsqueeze(1)
         print(f":: Prefill results:\n{tokens.tolist()}")
         self.add_result_token(tokens)
@@ -194,28 +191,22 @@ def decode(self):
         self.allocate_seq_block_ids()
         # TODO: Allocate more blocks on overflow.
         seq_block_ids_tensor = self.pad_block_ids()
-        decode_attention_mask = model.decode_attention_mask(
-            model.input_mask(
-                self.seq_lens,
-                seq_block_ids_tensor.shape[1] * self.parent.block_seq_stride,
-            )
-        )
         trace_tensor("decode.token_ids", self.next_tokens)
+        trace_tensor("decode.seq_lens", self.seq_lens)
         trace_tensor("decode.start_positions", start_positions)
         trace_tensor("decode.seq_block_ids", seq_block_ids_tensor)
-        trace_tensor("decode.attention_mask", decode_attention_mask)
-        logits = model.decode(
+        self.logits = model.decode_from_seq_lens(
             self.next_tokens,
-            attention_mask=decode_attention_mask,
+            seq_lens=self.seq_lens,
             start_positions=start_positions,
             seq_block_ids=seq_block_ids_tensor,
             cache_state=self.cache_state,
         )
-        trace_tensor("decode.logits", logits)
+        trace_tensor("decode.logits", self.logits)
         # TODO: Normalize the output of extract_tokens_from_logits into
         # tensor [bs, 1].
         tokens = torch.tensor(
-            model.extract_tokens_from_logits(logits, [1] * self.bs),
+            model.extract_tokens_from_logits(self.logits, [1] * self.bs),
             device=self.parent.model.device,
         ).unsqueeze(1)
         self.add_result_token(tokens)
diff --git a/sharktank/sharktank/layers/__init__.py b/sharktank/sharktank/layers/__init__.py
index fd56ec872..e0c9ea0ed 100644
--- a/sharktank/sharktank/layers/__init__.py
+++ b/sharktank/sharktank/layers/__init__.py
@@ -7,7 +7,9 @@
 from .base import BaseLayer, ThetaLayer
 from .conv import Conv2DLayer
 from .kv_cache import BaseKVCache, DirectKVCache, PagedKVCache
-from .causal_llm import BaseCausalLMModel
+from .causal_llm import (
+    BaseCausalLMModel,
+)
 from .linear import LinearLayer
 from .norm import RMSNormLayer
 from .rotary_embedding import RotaryEmbeddingLayer
diff --git a/sharktank/sharktank/layers/causal_llm.py b/sharktank/sharktank/layers/causal_llm.py
index 7a09995a8..8eb8b0d2d 100644
--- a/sharktank/sharktank/layers/causal_llm.py
+++ b/sharktank/sharktank/layers/causal_llm.py
@@ -4,17 +4,19 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

-from typing import Optional
+from typing import Optional, Union
+from abc import ABC, abstractmethod

 import torch

-from ..types import Theta
+from ..types import SplitPrimitiveTensor, ReplicatedTensor
+from .. import ops
 from .base import (
-    ThetaLayer,
+    BaseLayer,
 )


-class BaseCausalLMModel(ThetaLayer):
+class BaseCausalLMModel(BaseLayer):
     """Base class for causal LM models.

     This provides some utilities and common API surface related to masking
@@ -25,16 +27,14 @@ class BaseCausalLMModel(ThetaLayer):

     def __init__(
         self,
-        theta: Theta,
         *,
         context_length: int,
         static_tables: bool = True,
-        static_context_mask: bool = False,
         device: Optional[torch.device] = None,
         activation_dtype: torch.dtype = torch.float32,
         attention_dtype: torch.dtype = torch.float32,
     ):
-        super().__init__(theta)
+        super().__init__()
         self.device = device
         self.activation_dtype = activation_dtype
         self.attention_dtype = attention_dtype
@@ -149,3 +149,97 @@ def extract_tokens_from_logits(
             step_logits = logits[batch, seq_len - 1]
             results.append(torch.argmax(step_logits))
         return results
+
+    def prefill(
+        self,
+        # [bs, batch_seq_len]
+        tokens: Union[torch.Tensor, ReplicatedTensor],
+        *,
+        # [1, 1, batch_seq_len, batch_seq_len]
+        attention_mask: Union[torch.Tensor, ReplicatedTensor],
+        # [bs, batch_seq_len // block_seq_stride]
+        seq_block_ids: Union[torch.Tensor, ReplicatedTensor],
+        cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+    ):
+        raise NotImplementedError()
+
+    def prefill_from_seq_lens(
+        self,
+        tokens: torch.Tensor,
+        *,
+        seq_lens: torch.Tensor,
+        seq_block_ids: torch.Tensor,
+        cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+    ):
+        batch_seq_len = tokens.shape[1]
+        input_mask = self.input_mask(seq_lens, batch_seq_len)
+        attention_mask = self.attention_mask(input_mask)
+
+        if self.config.tensor_parallelism_size != 1:
+            shard_count = self.config.tensor_parallelism_size
+
+            tokens = ops.replicate(tokens, count=shard_count)
+            attention_mask = ops.replicate(attention_mask, count=shard_count)
+            seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)
+
+        logits = self.prefill(
+            tokens,
+            attention_mask=attention_mask,
+            seq_block_ids=seq_block_ids,
+            cache_state=cache_state,
+        )
+
+        if self.config.tensor_parallelism_size != 1:
+            logits = ops.unshard(logits)
+
+        return logits
+
+    def decode(
+        self,
+        # [bs, 1]
+        tokens: Union[torch.Tensor, ReplicatedTensor],
+        *,
+        # [bs, 1, 1, batch_seq_len]
+        attention_mask: Union[torch.Tensor, ReplicatedTensor],
+        # [bs] of starting positions
+        start_positions: Union[torch.Tensor, ReplicatedTensor],
+        # [bs, batch_seq_len // block_seq_stride]
+        seq_block_ids: Union[torch.Tensor, ReplicatedTensor],
+        cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+    ):
+        raise NotImplementedError()
+
+    def decode_from_seq_lens(
+        self,
+        tokens: torch.Tensor,
+        *,
+        seq_lens: torch.Tensor,
+        start_positions: torch.Tensor,
+        seq_block_ids: torch.Tensor,
+        cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+    ):
+        input_mask = self.input_mask(
+            seq_lens, seq_block_ids.shape[1] * self.cache.block_seq_stride
+        )
+        attention_mask = self.decode_attention_mask(input_mask)
+
+        if self.config.tensor_parallelism_size != 1:
+            shard_count = self.config.tensor_parallelism_size
+
+            tokens = ops.replicate(tokens, count=shard_count)
+            attention_mask = ops.replicate(attention_mask, count=shard_count)
+            start_positions = ops.replicate(start_positions, count=shard_count)
+            seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)
+
+        logits = self.decode(
+            tokens,
+            attention_mask=attention_mask,
+            start_positions=start_positions,
+            seq_block_ids=seq_block_ids,
+            cache_state=cache_state,
+        )
+
+        if self.config.tensor_parallelism_size != 1:
+            logits = ops.unshard(logits)
+
+        return logits
diff --git a/sharktank/sharktank/models/grok/grok.py b/sharktank/sharktank/models/grok/grok.py
index 077e4e064..a93b1145b 100644
--- a/sharktank/sharktank/models/grok/grok.py
+++ b/sharktank/sharktank/models/grok/grok.py
@@ -51,7 +51,6 @@ class PagedGrokModelV1(BaseCausalLMModel):
     def __init__(self, theta: Theta, config: LlamaModelConfig):
         hp = config.hp
         super().__init__(
-            theta,
             context_length=config.hp.context_length,
             device=config.device,
             activation_dtype=config.activation_dtype,
diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py
index 656b4432b..3ebf81cbd 100644
--- a/sharktank/sharktank/models/llama/llama.py
+++ b/sharktank/sharktank/models/llama/llama.py
@@ -65,7 +65,6 @@ class PagedLlamaModelV1(BaseCausalLMModel):
     def __init__(self, theta: Theta, config: LlamaModelConfig):
         hp = config.hp
         super().__init__(
-            theta,
             context_length=config.hp.context_length,
             static_tables=config.static_tables,
             device=config.device,
diff --git a/sharktank/sharktank/models/mixtral/mixtral.py b/sharktank/sharktank/models/mixtral/mixtral.py
index e2995dfde..910984671 100644
--- a/sharktank/sharktank/models/mixtral/mixtral.py
+++ b/sharktank/sharktank/models/mixtral/mixtral.py
@@ -53,7 +53,6 @@ class PagedMixtralModelV1(BaseCausalLMModel):
     def __init__(self, theta: Theta, config: LlamaModelConfig):
         hp = config.hp
         super().__init__(
-            theta,
             context_length=config.hp.context_length,
             device=config.device,
             activation_dtype=config.activation_dtype,
diff --git a/sharktank/tests/models/llama/kv_cache_test.py b/sharktank/tests/models/llama/kv_cache_test.py
index a80575951..fce0b7cc6 100644
--- a/sharktank/tests/models/llama/kv_cache_test.py
+++ b/sharktank/tests/models/llama/kv_cache_test.py
@@ -107,9 +107,7 @@ def setUp(self):
         self.embedding_batch_mask = self.attention_embedding.compute_batch_mask(
             self.start_positions, batch_seq_len=1
         )
-        self.model = causal_llm.BaseCausalLMModel(
-            self.attention_block_theta, context_length=self.max_seq_len
-        )
+        self.model = causal_llm.BaseCausalLMModel(context_length=self.max_seq_len)
         self.prefill_attention_mask = self.model.attention_mask(
             self.model.input_mask(self.start_positions, self.seq_len)
         )
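
Usage note (illustrative, not part of the patch): a minimal sketch of how a caller drives the new seq_lens-based entry points introduced on BaseCausalLMModel, assuming `model` is a PagedLlamaModelV1 (or another subclass implementing prefill/decode) and that `cache_state` and `seq_block_ids` were allocated the way TorchGenerator in paged_llm_v1.py does. The helper name `run_prefill_then_one_decode` is hypothetical; mask construction and sharding now happen inside the *_from_seq_lens methods, so the caller only passes tensors.

# Hypothetical driver, mirroring the call pattern used in paged_llm_v1.py.
import torch

def run_prefill_then_one_decode(model, tokens, seq_lens, seq_block_ids, cache_state):
    # Prefill: attention mask and (optional) replication/unsharding are
    # handled inside prefill_from_seq_lens.
    logits = model.prefill_from_seq_lens(
        tokens,
        seq_lens=seq_lens,
        seq_block_ids=seq_block_ids,
        cache_state=cache_state,
    )
    # Greedy-pick the last token of each sequence, shape [bs, 1].
    last_tokens = torch.tensor(
        model.extract_tokens_from_logits(logits, seq_lens)
    ).unsqueeze(1)

    # Single decode step for the newly sampled tokens; positions advance by one.
    start_positions = seq_lens.clone()
    seq_lens = seq_lens + 1
    return model.decode_from_seq_lens(
        last_tokens,
        seq_lens=seq_lens,
        start_positions=start_positions,
        seq_block_ids=seq_block_ids,
        cache_state=cache_state,
    )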