From 1b0f2f0d05a1028f45e9ae6251aca1a5ef0b07f2 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Wed, 30 Oct 2024 15:16:31 -0500 Subject: [PATCH 1/2] Introduce CausalLMModelABC We do not have a clearly defined interface for LMs. Decode and prefill have different signatures when exporting to IREE. This change adds a new ABC, CausalLMModelABC, that makes the distinction between the two variants explicit. BaseCausalLMModel provides a default implementation for the new prefill_from_seq_lens and decode_from_seq_lens methods. The export script export_paged_llm_v1 does too much in its exported functions: it computes the attention mask there, shards its arguments, and unshards its result. This change lets it be a thinner wrapper around the new interface functions. Make paged_llm_v1.TorchGenerator use the new interface methods. --- .../sharktank/examples/export_paged_llm_v1.py | 67 ++---- sharktank/sharktank/examples/paged_llm_v1.py | 29 +-- sharktank/sharktank/layers/__init__.py | 5 +- sharktank/sharktank/layers/causal_llm.py | 197 +++++++++++++++++- sharktank/sharktank/models/grok/grok.py | 6 +- sharktank/sharktank/models/llama/llama.py | 8 +- sharktank/sharktank/models/mixtral/mixtral.py | 6 +- sharktank/tests/models/llama/kv_cache_test.py | 4 +- 8 files changed, 233 insertions(+), 89 deletions(-) diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 40cfea94f..d5aad1e7f 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -174,7 +174,7 @@ def generate_batch_prefill(bs: int): "tokens": {1: sl_dim}, "seq_lens": {}, "seq_block_ids": {1: block_dim}, - "cs": cache_dynamic_shapes, + "cache_state": cache_dynamic_shapes, } print(f"Exporting prefill_bs{bs}") @@ -186,40 +186,22 @@ def generate_batch_prefill(bs: int): strict=args.strict, arg_device=arg_affinities, ) - def _(model, tokens, seq_lens, seq_block_ids, cs): + def _(model, tokens, seq_lens, seq_block_ids, cache_state): if ( model.config.tensor_parallelism_size == 1 and model.config.kv_cache_type == "direct" ): - cache_tensors = torch.unbind(cs) - else: - cache_tensors = cs - - sl = tokens.shape[1] - input_mask = model.input_mask(seq_lens, sl) - attention_mask = model.attention_mask(input_mask) - - if llama_config.tensor_parallelism_size != 1: - shard_count = llama_config.tensor_parallelism_size - - tokens = ops.replicate(tokens, count=shard_count) - attention_mask = ops.replicate(attention_mask, count=shard_count) - seq_block_ids = ops.replicate(seq_block_ids, count=shard_count) - - cache_tensors = repack_cache(cs, cache_shard_dim) + cache_state = torch.unbind(cache_state) + if model.config.tensor_parallelism_size != 1: + cache_state = repack_cache(cache_state, cache_shard_dim) - logits = model.prefill( - tokens, - attention_mask=attention_mask, + return model.prefill_from_seq_lens( + tokens=tokens, + seq_lens=seq_lens, seq_block_ids=seq_block_ids, - cache_state=cache_tensors, + cache_state=cache_state, ) - if llama_config.tensor_parallelism_size != 1: - logits = ops.unshard(logits) - - return logits - def generate_batch_decode(bs: int): tokens = torch.ones(bs, 1, dtype=torch.int64) seq_lens = torch.ones(bs, dtype=torch.int64) @@ -274,34 +256,21 @@ def _( seq_block_ids, cache_state, ): - input_mask = model.input_mask( - seq_lens, seq_block_ids.shape[1] * model.cache.block_seq_stride - ) - attention_mask = model.decode_attention_mask(input_mask) - - if llama_config.tensor_parallelism_size != 1: - shard_count = 
llama_config.tensor_parallelism_size - - tokens = ops.replicate(tokens, count=shard_count) - attention_mask = ops.replicate(attention_mask, count=shard_count) - start_positions = ops.replicate(start_positions, count=shard_count) - seq_block_ids = ops.replicate(seq_block_ids, count=shard_count) - + if ( + model.config.tensor_parallelism_size == 1 + and model.config.kv_cache_type == "direct" + ): + cache_state = torch.unbind(cache_state) + if model.config.tensor_parallelism_size != 1: cache_state = repack_cache(cache_state, cache_shard_dim) - - logits = model.decode( - tokens, - attention_mask=attention_mask, + return model.decode_from_seq_lens( + tokens=tokens, + seq_lens=seq_lens, start_positions=start_positions, seq_block_ids=seq_block_ids, cache_state=cache_state, ) - if llama_config.tensor_parallelism_size != 1: - logits = ops.unshard(logits) - - return logits - bsizes = [] for bs in args.bs: generate_batch_prefill(bs) diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py index 2ff664b19..0cb76f09e 100644 --- a/sharktank/sharktank/examples/paged_llm_v1.py +++ b/sharktank/sharktank/examples/paged_llm_v1.py @@ -32,7 +32,7 @@ class TorchGenerator: def __init__( self, - model: PagedLlamaModelV1, + model: CausalLMModelABC, tokenizer: InferenceTokenizer, page_cache_size: int = 128, # Need to look at the model more for this. @@ -162,17 +162,14 @@ def compute_prefill_logits( def prefill(self): model = self.parent.model - attention_mask = model.attention_mask( - model.input_mask(self.seq_lens, self.token_ids.shape[1]) - ) seq_block_ids_tensor = self.pad_block_ids() print(f":: Invoke prefill:") trace_tensor("prefill.token_ids", self.token_ids) + trace_tensor("prefill.seq_lens", self.seq_lens) trace_tensor("prefill.seq_block_ids", seq_block_ids_tensor) - trace_tensor("prefill.attention_mask", attention_mask) - logits = model.prefill( + self.logits = model.prefill_from_seq_lens( self.token_ids, - attention_mask=attention_mask, + seq_lens=self.seq_lens, seq_block_ids=seq_block_ids_tensor, cache_state=self.cache_state, ) @@ -181,7 +178,7 @@ def prefill(self): # TODO: Normalize the output of extract_tokens_from_logits into # tensor [bs, 1]. tokens = torch.tensor( - model.extract_tokens_from_logits(logits, self.seq_lens) + model.extract_tokens_from_logits(self.logits, self.seq_lens) ).unsqueeze(1) print(f":: Prefill results:\n{tokens.tolist()}") self.add_result_token(tokens) @@ -194,28 +191,22 @@ def decode(self): self.allocate_seq_block_ids() # TODO: Allocate more blocks on overflow. seq_block_ids_tensor = self.pad_block_ids() - decode_attention_mask = model.decode_attention_mask( - model.input_mask( - self.seq_lens, - seq_block_ids_tensor.shape[1] * self.parent.block_seq_stride, - ) - ) trace_tensor("decode.token_ids", self.next_tokens) + trace_tensor("decode.seq_lens", self.seq_lens) trace_tensor("decode.start_positions", start_positions) trace_tensor("decode.seq_block_ids", seq_block_ids_tensor) - trace_tensor("decode.attention_mask", decode_attention_mask) - logits = model.decode( + self.logits = model.decode_from_seq_lens( self.next_tokens, - attention_mask=decode_attention_mask, + seq_lens=self.seq_lens, start_positions=start_positions, seq_block_ids=seq_block_ids_tensor, cache_state=self.cache_state, ) - trace_tensor("decode.logits", logits) + trace_tensor("decode.logits", self.logits) # TODO: Normalize the output of extract_tokens_from_logits into # tensor [bs, 1]. 
tokens = torch.tensor( - model.extract_tokens_from_logits(logits, [1] * self.bs), + model.extract_tokens_from_logits(self.logits, [1] * self.bs), device=self.parent.model.device, ).unsqueeze(1) self.add_result_token(tokens) diff --git a/sharktank/sharktank/layers/__init__.py b/sharktank/sharktank/layers/__init__.py index fd56ec872..aeb5f591c 100644 --- a/sharktank/sharktank/layers/__init__.py +++ b/sharktank/sharktank/layers/__init__.py @@ -7,7 +7,10 @@ from .base import BaseLayer, ThetaLayer from .conv import Conv2DLayer from .kv_cache import BaseKVCache, DirectKVCache, PagedKVCache -from .causal_llm import BaseCausalLMModel +from .causal_llm import ( + CausalLMModelABC, + BaseCausalLMModel, +) from .linear import LinearLayer from .norm import RMSNormLayer from .rotary_embedding import RotaryEmbeddingLayer diff --git a/sharktank/sharktank/layers/causal_llm.py b/sharktank/sharktank/layers/causal_llm.py index 7a09995a8..548c06862 100644 --- a/sharktank/sharktank/layers/causal_llm.py +++ b/sharktank/sharktank/layers/causal_llm.py @@ -4,17 +4,136 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Optional +from typing import Optional, Union +from abc import ABC, abstractmethod import torch -from ..types import Theta +from ..types import SplitPrimitiveTensor, ReplicatedTensor +from .. import ops from .base import ( - ThetaLayer, + BaseLayer, ) -class BaseCausalLMModel(ThetaLayer): +class CausalLMModelABC(ABC): + """Interface for causal LM models.""" + + @abstractmethod + def generate_causal_context_mask(self) -> torch.Tensor: + raise NotImplementedError() + + def input_mask( + self, + # [bs] of integers + seq_lens: torch.Tensor, + batch_seqlen: int, + ): + """Compute a boolean input mask for a batch of sequence lengths. + + The mask will be [bs, batch_seqlen] with True at any position that is + masked. + """ + raise NotImplementedError() + + def decode_attention_mask(self, boolean_input_mask: torch.Tensor): + raise NotImplementedError() + + def attention_mask( + self, + input_mask: torch.Tensor, + *, + causal_context_mask: Optional[torch.Tensor] = None, + ): + """Generates a causal attention mask of [1, 1, sl, sl] of activation dtype. + + All masked positions are -inf and unmasked are 0.0. + + The pre-initialized causal context mask can be passed in. If not, then + it will either be generated or use the initialization time buffer. + Since this is a bool tensor of context_length^2, different deployment + scenarios can benefit from managing this in different ways. + """ + raise NotImplementedError() + + def extract_tokens_from_logits( + self, logits: torch.Tensor, seq_lens: list[int] + ) -> list[int]: + """Extracts tokens from a batch of logits (B, S, D). + + The length of seq_lens must be equal to the batch size. + Note that there are ways to code the indexing as tensor operations + but it just creates a bunch of weirdly shaped little work on the + accelerator. Statically looping like this is more efficient. 
+ """ + raise NotImplementedError() + + @abstractmethod + def prefill( + self, + # [bs, batch_seq_len] + tokens: Union[torch.Tensor, ReplicatedTensor], + *, + # [1, 1, batch_seq_len, batch_seq_len] + attention_mask: Union[torch.Tensor, ReplicatedTensor], + # [bs, batch_seq_len // block_seq_stride] + seq_block_ids: Union[torch.Tensor, ReplicatedTensor], + cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + ): + raise NotImplementedError() + + @abstractmethod + def prefill_from_seq_lens( + self, + # [bs, batch_seq_len] + tokens: torch.Tensor, + *, + # [bs] + seq_lens: torch.Tensor, + # [bs, batch_seq_len // block_seq_stride] + seq_block_ids: torch.Tensor, + cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + ): + """This prefill variant accepts seq_lens instead of an attention_mask. + It also does not support sharded arguments other than the cache state.""" + raise NotImplementedError() + + @abstractmethod + def decode( + self, + # [bs, 1] + tokens: Union[torch.Tensor, ReplicatedTensor], + *, + # [bs, 1, 1, batch_seq_len] + attention_mask: Union[torch.Tensor, ReplicatedTensor], + # [bs] of starting positions + start_positions: Union[torch.Tensor, ReplicatedTensor], + # [bs, batch_seq_len // block_seq_stride] + seq_block_ids: Union[torch.Tensor, ReplicatedTensor], + cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + ): + raise NotImplementedError() + + @abstractmethod + def decode_from_seq_lens( + self, + # [bs, 1] + tokens: torch.Tensor, + *, + # [bs] + seq_lens: torch.Tensor, + # [bs] of starting positions + start_positions: torch.Tensor, + # [bs, batch_seq_len // block_seq_stride] + seq_block_ids: torch.Tensor, + cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + ): + """This decode variant accepts seq_lens instead of an attention_mask. + It also does not support sharded arguments other than the cache state.""" + raise NotImplementedError() + + +class BaseCausalLMModel(BaseLayer): """Base class for causal LM models. 
This provides some utilities and common API surface related to masking @@ -25,16 +144,14 @@ class BaseCausalLMModel(ThetaLayer): def __init__( self, - theta: Theta, *, context_length: int, static_tables: bool = True, - static_context_mask: bool = False, device: Optional[torch.device] = None, activation_dtype: torch.dtype = torch.float32, attention_dtype: torch.dtype = torch.float32, ): - super().__init__(theta) + super().__init__() self.device = device self.activation_dtype = activation_dtype self.attention_dtype = attention_dtype @@ -149,3 +266,69 @@ def extract_tokens_from_logits( step_logits = logits[batch, seq_len - 1] results.append(torch.argmax(step_logits)) return results + + def prefill_from_seq_lens( + self, + tokens: torch.Tensor, + *, + seq_lens: torch.Tensor, + seq_block_ids: torch.Tensor, + cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + ): + batch_seq_len = tokens.shape[1] + input_mask = self.input_mask(seq_lens, batch_seq_len) + attention_mask = self.attention_mask(input_mask) + + if self.config.tensor_parallelism_size != 1: + shard_count = self.config.tensor_parallelism_size + + tokens = ops.replicate(tokens, count=shard_count) + attention_mask = ops.replicate(attention_mask, count=shard_count) + seq_block_ids = ops.replicate(seq_block_ids, count=shard_count) + + logits = self.prefill( + tokens, + attention_mask=attention_mask, + seq_block_ids=seq_block_ids, + cache_state=cache_state, + ) + + if self.config.tensor_parallelism_size != 1: + logits = ops.unshard(logits) + + return logits + + def decode_from_seq_lens( + self, + tokens: torch.Tensor, + *, + seq_lens: torch.Tensor, + start_positions: torch.Tensor, + seq_block_ids: torch.Tensor, + cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + ): + input_mask = self.input_mask( + seq_lens, seq_block_ids.shape[1] * self.cache.block_seq_stride + ) + attention_mask = self.decode_attention_mask(input_mask) + + if self.config.tensor_parallelism_size != 1: + shard_count = self.config.tensor_parallelism_size + + tokens = ops.replicate(tokens, count=shard_count) + attention_mask = ops.replicate(attention_mask, count=shard_count) + start_positions = ops.replicate(start_positions, count=shard_count) + seq_block_ids = ops.replicate(seq_block_ids, count=shard_count) + + logits = self.decode( + tokens, + attention_mask=attention_mask, + start_positions=start_positions, + seq_block_ids=seq_block_ids, + cache_state=cache_state, + ) + + if self.config.tensor_parallelism_size != 1: + logits = ops.unshard(logits) + + return logits diff --git a/sharktank/sharktank/models/grok/grok.py b/sharktank/sharktank/models/grok/grok.py index 077e4e064..13b67d2bd 100644 --- a/sharktank/sharktank/models/grok/grok.py +++ b/sharktank/sharktank/models/grok/grok.py @@ -26,7 +26,7 @@ ################################################################################ -class PagedGrokModelV1(BaseCausalLMModel): +class PagedGrokModelV1(BaseCausalLMModel, CausalLMModelABC): """Grok model with a paged KV cache and supporting variable sequence length batched inference. 
@@ -50,8 +50,8 @@ class PagedGrokModelV1(BaseCausalLMModel): def __init__(self, theta: Theta, config: LlamaModelConfig): hp = config.hp - super().__init__( - theta, + BaseCausalLMModel.__init__( + self, context_length=config.hp.context_length, device=config.device, activation_dtype=config.activation_dtype, diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index 656b4432b..ee03dbb76 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -7,7 +7,7 @@ from typing import Optional from dataclasses import dataclass -from typing import Union +from typing import Any, Union import torch import torch.nn as nn @@ -27,7 +27,7 @@ ################################################################################ -class PagedLlamaModelV1(BaseCausalLMModel): +class PagedLlamaModelV1(BaseCausalLMModel, CausalLMModelABC): """LlamaModel with a paged KV cache and supporting variable sequence length batched inference. @@ -64,8 +64,8 @@ class PagedLlamaModelV1(BaseCausalLMModel): def __init__(self, theta: Theta, config: LlamaModelConfig): hp = config.hp - super().__init__( - theta, + BaseCausalLMModel.__init__( + self, context_length=config.hp.context_length, static_tables=config.static_tables, device=config.device, diff --git a/sharktank/sharktank/models/mixtral/mixtral.py b/sharktank/sharktank/models/mixtral/mixtral.py index e2995dfde..e51d8dc74 100644 --- a/sharktank/sharktank/models/mixtral/mixtral.py +++ b/sharktank/sharktank/models/mixtral/mixtral.py @@ -28,7 +28,7 @@ ################################################################################ -class PagedMixtralModelV1(BaseCausalLMModel): +class PagedMixtralModelV1(BaseCausalLMModel, CausalLMModelABC): """MixtralModel with a paged KV cache and supporting variable sequence length batched inference. 
@@ -52,8 +52,8 @@ class PagedMixtralModelV1(BaseCausalLMModel): def __init__(self, theta: Theta, config: LlamaModelConfig): hp = config.hp - super().__init__( - theta, + BaseCausalLMModel.__init__( + self, context_length=config.hp.context_length, device=config.device, activation_dtype=config.activation_dtype, diff --git a/sharktank/tests/models/llama/kv_cache_test.py b/sharktank/tests/models/llama/kv_cache_test.py index a80575951..fce0b7cc6 100644 --- a/sharktank/tests/models/llama/kv_cache_test.py +++ b/sharktank/tests/models/llama/kv_cache_test.py @@ -107,9 +107,7 @@ def setUp(self): self.embedding_batch_mask = self.attention_embedding.compute_batch_mask( self.start_positions, batch_seq_len=1 ) - self.model = causal_llm.BaseCausalLMModel( - self.attention_block_theta, context_length=self.max_seq_len - ) + self.model = causal_llm.BaseCausalLMModel(context_length=self.max_seq_len) self.prefill_attention_mask = self.model.attention_mask( self.model.input_mask(self.start_positions, self.seq_len) ) From d0a946666b7d8e4429ae7964053385c47d5d1458 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 7 Nov 2024 07:09:15 -0600 Subject: [PATCH 2/2] Remove CausalLMModelABC --- sharktank/sharktank/examples/paged_llm_v1.py | 2 +- sharktank/sharktank/layers/__init__.py | 1 - sharktank/sharktank/layers/causal_llm.py | 145 ++++-------------- sharktank/sharktank/models/grok/grok.py | 5 +- sharktank/sharktank/models/llama/llama.py | 7 +- sharktank/sharktank/models/mixtral/mixtral.py | 5 +- 6 files changed, 36 insertions(+), 129 deletions(-) diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py index 0cb76f09e..98dc872de 100644 --- a/sharktank/sharktank/examples/paged_llm_v1.py +++ b/sharktank/sharktank/examples/paged_llm_v1.py @@ -32,7 +32,7 @@ class TorchGenerator: def __init__( self, - model: CausalLMModelABC, + model: BaseCausalLMModel, tokenizer: InferenceTokenizer, page_cache_size: int = 128, # Need to look at the model more for this. diff --git a/sharktank/sharktank/layers/__init__.py b/sharktank/sharktank/layers/__init__.py index aeb5f591c..e0c9ea0ed 100644 --- a/sharktank/sharktank/layers/__init__.py +++ b/sharktank/sharktank/layers/__init__.py @@ -8,7 +8,6 @@ from .conv import Conv2DLayer from .kv_cache import BaseKVCache, DirectKVCache, PagedKVCache from .causal_llm import ( - CausalLMModelABC, BaseCausalLMModel, ) from .linear import LinearLayer diff --git a/sharktank/sharktank/layers/causal_llm.py b/sharktank/sharktank/layers/causal_llm.py index 548c06862..8eb8b0d2d 100644 --- a/sharktank/sharktank/layers/causal_llm.py +++ b/sharktank/sharktank/layers/causal_llm.py @@ -16,123 +16,6 @@ ) -class CausalLMModelABC(ABC): - """Interface for causal LM models.""" - - @abstractmethod - def generate_causal_context_mask(self) -> torch.Tensor: - raise NotImplementedError() - - def input_mask( - self, - # [bs] of integers - seq_lens: torch.Tensor, - batch_seqlen: int, - ): - """Compute a boolean input mask for a batch of sequence lengths. - - The mask will be [bs, batch_seqlen] with True at any position that is - masked. - """ - raise NotImplementedError() - - def decode_attention_mask(self, boolean_input_mask: torch.Tensor): - raise NotImplementedError() - - def attention_mask( - self, - input_mask: torch.Tensor, - *, - causal_context_mask: Optional[torch.Tensor] = None, - ): - """Generates a causal attention mask of [1, 1, sl, sl] of activation dtype. - - All masked positions are -inf and unmasked are 0.0. 
- - The pre-initialized causal context mask can be passed in. If not, then - it will either be generated or use the initialization time buffer. - Since this is a bool tensor of context_length^2, different deployment - scenarios can benefit from managing this in different ways. - """ - raise NotImplementedError() - - def extract_tokens_from_logits( - self, logits: torch.Tensor, seq_lens: list[int] - ) -> list[int]: - """Extracts tokens from a batch of logits (B, S, D). - - The length of seq_lens must be equal to the batch size. - Note that there are ways to code the indexing as tensor operations - but it just creates a bunch of weirdly shaped little work on the - accelerator. Statically looping like this is more efficient. - """ - raise NotImplementedError() - - @abstractmethod - def prefill( - self, - # [bs, batch_seq_len] - tokens: Union[torch.Tensor, ReplicatedTensor], - *, - # [1, 1, batch_seq_len, batch_seq_len] - attention_mask: Union[torch.Tensor, ReplicatedTensor], - # [bs, batch_seq_len // block_seq_stride] - seq_block_ids: Union[torch.Tensor, ReplicatedTensor], - cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]], - ): - raise NotImplementedError() - - @abstractmethod - def prefill_from_seq_lens( - self, - # [bs, batch_seq_len] - tokens: torch.Tensor, - *, - # [bs] - seq_lens: torch.Tensor, - # [bs, batch_seq_len // block_seq_stride] - seq_block_ids: torch.Tensor, - cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]], - ): - """This prefill variant accepts seq_lens instead of an attention_mask. - It also does not support sharded arguments other than the cache state.""" - raise NotImplementedError() - - @abstractmethod - def decode( - self, - # [bs, 1] - tokens: Union[torch.Tensor, ReplicatedTensor], - *, - # [bs, 1, 1, batch_seq_len] - attention_mask: Union[torch.Tensor, ReplicatedTensor], - # [bs] of starting positions - start_positions: Union[torch.Tensor, ReplicatedTensor], - # [bs, batch_seq_len // block_seq_stride] - seq_block_ids: Union[torch.Tensor, ReplicatedTensor], - cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]], - ): - raise NotImplementedError() - - @abstractmethod - def decode_from_seq_lens( - self, - # [bs, 1] - tokens: torch.Tensor, - *, - # [bs] - seq_lens: torch.Tensor, - # [bs] of starting positions - start_positions: torch.Tensor, - # [bs, batch_seq_len // block_seq_stride] - seq_block_ids: torch.Tensor, - cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]], - ): - """This decode variant accepts seq_lens instead of an attention_mask. - It also does not support sharded arguments other than the cache state.""" - raise NotImplementedError() - - class BaseCausalLMModel(BaseLayer): """Base class for causal LM models. 
@@ -267,6 +150,19 @@ def extract_tokens_from_logits( results.append(torch.argmax(step_logits)) return results + def prefill( + self, + # [bs, batch_seq_len] + tokens: Union[torch.Tensor, ReplicatedTensor], + *, + # [1, 1, batch_seq_len, batch_seq_len] + attention_mask: Union[torch.Tensor, ReplicatedTensor], + # [bs, batch_seq_len // block_seq_stride] + seq_block_ids: Union[torch.Tensor, ReplicatedTensor], + cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + ): + raise NotImplementedError() + def prefill_from_seq_lens( self, tokens: torch.Tensor, @@ -298,6 +194,21 @@ def prefill_from_seq_lens( return logits + def decode( + self, + # [bs, 1] + tokens: Union[torch.Tensor, ReplicatedTensor], + *, + # [bs, 1, 1, batch_seq_len] + attention_mask: Union[torch.Tensor, ReplicatedTensor], + # [bs] of starting positions + start_positions: Union[torch.Tensor, ReplicatedTensor], + # [bs, batch_seq_len // block_seq_stride] + seq_block_ids: Union[torch.Tensor, ReplicatedTensor], + cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + ): + raise NotImplementedError() + def decode_from_seq_lens( self, tokens: torch.Tensor, diff --git a/sharktank/sharktank/models/grok/grok.py b/sharktank/sharktank/models/grok/grok.py index 13b67d2bd..a93b1145b 100644 --- a/sharktank/sharktank/models/grok/grok.py +++ b/sharktank/sharktank/models/grok/grok.py @@ -26,7 +26,7 @@ ################################################################################ -class PagedGrokModelV1(BaseCausalLMModel, CausalLMModelABC): +class PagedGrokModelV1(BaseCausalLMModel): """Grok model with a paged KV cache and supporting variable sequence length batched inference. @@ -50,8 +50,7 @@ class PagedGrokModelV1(BaseCausalLMModel, CausalLMModelABC): def __init__(self, theta: Theta, config: LlamaModelConfig): hp = config.hp - BaseCausalLMModel.__init__( - self, + super().__init__( context_length=config.hp.context_length, device=config.device, activation_dtype=config.activation_dtype, diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index ee03dbb76..3ebf81cbd 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -7,7 +7,7 @@ from typing import Optional from dataclasses import dataclass -from typing import Any, Union +from typing import Union import torch import torch.nn as nn @@ -27,7 +27,7 @@ ################################################################################ -class PagedLlamaModelV1(BaseCausalLMModel, CausalLMModelABC): +class PagedLlamaModelV1(BaseCausalLMModel): """LlamaModel with a paged KV cache and supporting variable sequence length batched inference. @@ -64,8 +64,7 @@ class PagedLlamaModelV1(BaseCausalLMModel, CausalLMModelABC): def __init__(self, theta: Theta, config: LlamaModelConfig): hp = config.hp - BaseCausalLMModel.__init__( - self, + super().__init__( context_length=config.hp.context_length, static_tables=config.static_tables, device=config.device, diff --git a/sharktank/sharktank/models/mixtral/mixtral.py b/sharktank/sharktank/models/mixtral/mixtral.py index e51d8dc74..910984671 100644 --- a/sharktank/sharktank/models/mixtral/mixtral.py +++ b/sharktank/sharktank/models/mixtral/mixtral.py @@ -28,7 +28,7 @@ ################################################################################ -class PagedMixtralModelV1(BaseCausalLMModel, CausalLMModelABC): +class PagedMixtralModelV1(BaseCausalLMModel): """MixtralModel with a paged KV cache and supporting variable sequence length batched inference. 
@@ -52,8 +52,7 @@ class PagedMixtralModelV1(BaseCausalLMModel, CausalLMModelABC): def __init__(self, theta: Theta, config: LlamaModelConfig): hp = config.hp - BaseCausalLMModel.__init__( - self, + super().__init__( context_length=config.hp.context_length, device=config.device, activation_dtype=config.activation_dtype,