Add prefill/decode from seq lens in BaseCausalLMModel #383

Status: Open. Wants to merge 2 commits into main; the diff below shows the changes from 1 commit.
sharktank/sharktank/examples/export_paged_llm_v1.py (18 additions, 49 deletions)
@@ -174,7 +174,7 @@ def generate_batch_prefill(bs: int):
             "tokens": {1: sl_dim},
             "seq_lens": {},
             "seq_block_ids": {1: block_dim},
-            "cs": cache_dynamic_shapes,
+            "cache_state": cache_dynamic_shapes,
         }

         print(f"Exporting prefill_bs{bs}")
@@ -186,40 +186,22 @@ def generate_batch_prefill(bs: int):
             strict=args.strict,
             arg_device=arg_affinities,
         )
-        def _(model, tokens, seq_lens, seq_block_ids, cs):
+        def _(model, tokens, seq_lens, seq_block_ids, cache_state):
             if (
                 model.config.tensor_parallelism_size == 1
                 and model.config.kv_cache_type == "direct"
             ):
-                cache_tensors = torch.unbind(cs)
-            else:
-                cache_tensors = cs
-
-            sl = tokens.shape[1]
-            input_mask = model.input_mask(seq_lens, sl)
-            attention_mask = model.attention_mask(input_mask)
-
-            if llama_config.tensor_parallelism_size != 1:
-                shard_count = llama_config.tensor_parallelism_size
-
-                tokens = ops.replicate(tokens, count=shard_count)
-                attention_mask = ops.replicate(attention_mask, count=shard_count)
-                seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)
-
-                cache_tensors = repack_cache(cs, cache_shard_dim)
+                cache_state = torch.unbind(cache_state)
+            if model.config.tensor_parallelism_size != 1:
+                cache_state = repack_cache(cache_state, cache_shard_dim)

-            logits = model.prefill(
-                tokens,
-                attention_mask=attention_mask,
+            return model.prefill_from_seq_lens(
+                tokens=tokens,
+                seq_lens=seq_lens,
                 seq_block_ids=seq_block_ids,
-                cache_state=cache_tensors,
+                cache_state=cache_state,
             )
-
-            if llama_config.tensor_parallelism_size != 1:
-                logits = ops.unshard(logits)
-
-            return logits

     def generate_batch_decode(bs: int):
         tokens = torch.ones(bs, 1, dtype=torch.int64)
         seq_lens = torch.ones(bs, dtype=torch.int64)
@@ -274,34 +256,21 @@ def _(
             seq_block_ids,
             cache_state,
         ):
-            input_mask = model.input_mask(
-                seq_lens, seq_block_ids.shape[1] * model.cache.block_seq_stride
-            )
-            attention_mask = model.decode_attention_mask(input_mask)
-
-            if llama_config.tensor_parallelism_size != 1:
-                shard_count = llama_config.tensor_parallelism_size
-
-                tokens = ops.replicate(tokens, count=shard_count)
-                attention_mask = ops.replicate(attention_mask, count=shard_count)
-                start_positions = ops.replicate(start_positions, count=shard_count)
-                seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)
-
+            if (
+                model.config.tensor_parallelism_size == 1
+                and model.config.kv_cache_type == "direct"
+            ):
+                cache_state = torch.unbind(cache_state)
+            if model.config.tensor_parallelism_size != 1:
                 cache_state = repack_cache(cache_state, cache_shard_dim)
-
-            logits = model.decode(
-                tokens,
-                attention_mask=attention_mask,
+            return model.decode_from_seq_lens(
+                tokens=tokens,
+                seq_lens=seq_lens,
                 start_positions=start_positions,
                 seq_block_ids=seq_block_ids,
                 cache_state=cache_state,
             )
-
-            if llama_config.tensor_parallelism_size != 1:
-                logits = ops.unshard(logits)
-
-            return logits

     bsizes = []
     for bs in args.bs:
         generate_batch_prefill(bs)
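For context: the body of prefill_from_seq_lens is not part of the visible diff. Judging from the logic removed above, it presumably folds the attention-mask construction, tensor-parallel replication, and unsharding into BaseCausalLMModel, leaving only the cache unbind/repack at the call site. A minimal sketch under that assumption; the signature is inferred from the call sites, and `ops` is taken to be the same sharktank ops module the removed export code used:

# A minimal sketch, assuming prefill_from_seq_lens is a method on
# BaseCausalLMModel and that `ops` provides replicate/unshard as in the
# removed export code. Not the actual implementation from this PR.
def prefill_from_seq_lens(self, tokens, *, seq_lens, seq_block_ids, cache_state):
    # Build the input mask and causal attention mask from per-sequence
    # lengths, as the export script previously did inline.
    sl = tokens.shape[1]
    input_mask = self.input_mask(seq_lens, sl)
    attention_mask = self.attention_mask(input_mask)

    # Replicate inputs across shards when running tensor-parallel.
    if self.config.tensor_parallelism_size != 1:
        shard_count = self.config.tensor_parallelism_size
        tokens = ops.replicate(tokens, count=shard_count)
        attention_mask = ops.replicate(attention_mask, count=shard_count)
        seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)

    logits = self.prefill(
        tokens,
        attention_mask=attention_mask,
        seq_block_ids=seq_block_ids,
        cache_state=cache_state,
    )

    # Gather sharded logits back to a single tensor before returning.
    if self.config.tensor_parallelism_size != 1:
        logits = ops.unshard(logits)
    return logits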
sharktank/sharktank/examples/paged_llm_v1.py (10 additions, 19 deletions)
@@ -32,7 +32,7 @@ class TorchGenerator:

     def __init__(
         self,
-        model: PagedLlamaModelV1,
+        model: CausalLMModelABC,
         tokenizer: InferenceTokenizer,
         page_cache_size: int = 128,
         # Need to look at the model more for this.
@@ -162,17 +162,14 @@ def compute_prefill_logits(

     def prefill(self):
         model = self.parent.model
-        attention_mask = model.attention_mask(
-            model.input_mask(self.seq_lens, self.token_ids.shape[1])
-        )
         seq_block_ids_tensor = self.pad_block_ids()
         print(f":: Invoke prefill:")
         trace_tensor("prefill.token_ids", self.token_ids)
+        trace_tensor("prefill.seq_lens", self.seq_lens)
         trace_tensor("prefill.seq_block_ids", seq_block_ids_tensor)
-        trace_tensor("prefill.attention_mask", attention_mask)
-        logits = model.prefill(
+        self.logits = model.prefill_from_seq_lens(
             self.token_ids,
-            attention_mask=attention_mask,
+            seq_lens=self.seq_lens,
             seq_block_ids=seq_block_ids_tensor,
             cache_state=self.cache_state,
         )
@@ -181,7 +178,7 @@ def prefill(self):
         # TODO: Normalize the output of extract_tokens_from_logits into
         # tensor [bs, 1].
         tokens = torch.tensor(
-            model.extract_tokens_from_logits(logits, self.seq_lens)
+            model.extract_tokens_from_logits(self.logits, self.seq_lens)
         ).unsqueeze(1)
         print(f":: Prefill results:\n{tokens.tolist()}")
         self.add_result_token(tokens)
@@ -194,28 +191,22 @@ def decode(self):
         self.allocate_seq_block_ids()
         # TODO: Allocate more blocks on overflow.
         seq_block_ids_tensor = self.pad_block_ids()
-        decode_attention_mask = model.decode_attention_mask(
-            model.input_mask(
-                self.seq_lens,
-                seq_block_ids_tensor.shape[1] * self.parent.block_seq_stride,
-            )
-        )
         trace_tensor("decode.token_ids", self.next_tokens)
+        trace_tensor("decode.seq_lens", self.seq_lens)
         trace_tensor("decode.start_positions", start_positions)
         trace_tensor("decode.seq_block_ids", seq_block_ids_tensor)
-        trace_tensor("decode.attention_mask", decode_attention_mask)
-        logits = model.decode(
+        self.logits = model.decode_from_seq_lens(
             self.next_tokens,
-            attention_mask=decode_attention_mask,
+            seq_lens=self.seq_lens,
             start_positions=start_positions,
             seq_block_ids=seq_block_ids_tensor,
             cache_state=self.cache_state,
         )
-        trace_tensor("decode.logits", logits)
+        trace_tensor("decode.logits", self.logits)
         # TODO: Normalize the output of extract_tokens_from_logits into
         # tensor [bs, 1].
         tokens = torch.tensor(
-            model.extract_tokens_from_logits(logits, [1] * self.bs),
+            model.extract_tokens_from_logits(self.logits, [1] * self.bs),
             device=self.parent.model.device,
         ).unsqueeze(1)
         self.add_result_token(tokens)
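As with prefill, the implementation of decode_from_seq_lens is not shown in this diff. A sketch of what it presumably wraps, combining the decode-mask construction removed from this file with the replication/unsharding removed from the export script; the signature and placement on BaseCausalLMModel are assumptions:

# A minimal sketch, assuming decode_from_seq_lens is a method on
# BaseCausalLMModel; the actual implementation is not part of this diff.
def decode_from_seq_lens(
    self, tokens, *, seq_lens, start_positions, seq_block_ids, cache_state
):
    # The decode mask covers the full paged extent of each sequence.
    input_mask = self.input_mask(
        seq_lens, seq_block_ids.shape[1] * self.cache.block_seq_stride
    )
    attention_mask = self.decode_attention_mask(input_mask)

    if self.config.tensor_parallelism_size != 1:
        shard_count = self.config.tensor_parallelism_size
        tokens = ops.replicate(tokens, count=shard_count)
        attention_mask = ops.replicate(attention_mask, count=shard_count)
        start_positions = ops.replicate(start_positions, count=shard_count)
        seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)

    logits = self.decode(
        tokens,
        attention_mask=attention_mask,
        start_positions=start_positions,
        seq_block_ids=seq_block_ids,
        cache_state=cache_state,
    )

    if self.config.tensor_parallelism_size != 1:
        logits = ops.unshard(logits)
    return logits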
sharktank/sharktank/layers/__init__.py (4 additions, 1 deletion)
@@ -7,7 +7,10 @@
 from .base import BaseLayer, ThetaLayer
 from .conv import Conv2DLayer
 from .kv_cache import BaseKVCache, DirectKVCache, PagedKVCache
-from .causal_llm import BaseCausalLMModel
+from .causal_llm import (
+    CausalLMModelABC,
+    BaseCausalLMModel,
+)
 from .linear import LinearLayer
 from .norm import RMSNormLayer
 from .rotary_embedding import RotaryEmbeddingLayer
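CausalLMModelABC is the new abstraction that TorchGenerator and the export script now program against. Its definition in causal_llm.py is not included in the visible diff; inferred from the call sites above, the interface presumably looks roughly like the following, where the exact type hints and docstrings are assumptions:

# Hypothetical sketch of CausalLMModelABC, reconstructed from how the PR's
# call sites use it; the real definition in causal_llm.py may differ.
from abc import ABC, abstractmethod
from typing import List

import torch


class CausalLMModelABC(ABC):
    """Interface for causal LM models driven by per-sequence lengths."""

    @abstractmethod
    def prefill_from_seq_lens(
        self,
        tokens: torch.Tensor,
        *,
        seq_lens: torch.Tensor,
        seq_block_ids: torch.Tensor,
        cache_state: List[torch.Tensor],
    ) -> torch.Tensor:
        """Builds the prefill attention mask from seq_lens and returns logits."""
        ...

    @abstractmethod
    def decode_from_seq_lens(
        self,
        tokens: torch.Tensor,
        *,
        seq_lens: torch.Tensor,
        start_positions: torch.Tensor,
        seq_block_ids: torch.Tensor,
        cache_state: List[torch.Tensor],
    ) -> torch.Tensor:
        """Builds the decode attention mask from seq_lens and returns logits
        for one decode step."""
        ...

    @abstractmethod
    def extract_tokens_from_logits(
        self, logits: torch.Tensor, seq_lens: List[int]
    ) -> List[int]:
        """Extracts the next-token prediction at each sequence's last position."""
        ...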