diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py
index 900c1a9ae..43242e0c0 100644
--- a/sharktank/sharktank/examples/export_paged_llm_v1.py
+++ b/sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -74,7 +74,7 @@ def main():
         tensor_parallelism_size=tensor_parallelism_size,
         use_hf=False,
         static_tables=False,  # Rely on the compiler for hoisting tables.
-        kv_cache_type="direct" if args.bs == [1] else "paged",
+        kv_cache_type="paged" if args.bs == [1] else "paged",
         attention_kernel=args.attention_kernel,
     )
     llama_config.fake_quant = args.fake_quant
diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py
index 8a443e6ca..6cf79402e 100644
--- a/sharktank/sharktank/layers/configs/llm_configs.py
+++ b/sharktank/sharktank/layers/configs/llm_configs.py
@@ -144,7 +144,7 @@ class LlamaModelConfig:

     # Block sequence stride for a paged KV cache. This must divide evenly
     # into the context length.
-    block_seq_stride: int = 16
+    block_seq_stride: int = 32

     # Either "paged" or "direct".
     kv_cache_type: str = "paged"
@@ -167,7 +167,7 @@ class LlamaModelConfig:
     tensor_parallelism_size: int = 1

     # Which attention kernel to use.
-    attention_kernel: str = "decomposed"
+    attention_kernel: str = "torch"

     # Indicates if running with HuggingFace implementation and ensures
     # numerical equivalency to HuggingFace's LLaMa if true (by modifying
diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py
index 22647bf49..6bd33c93f 100644
--- a/sharktank/sharktank/layers/paged_llama_attention_block.py
+++ b/sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -216,14 +216,12 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
                 attn_weights, values
             )  # (bs, heads, slen, head_dim)
         else:
-            is_causal = True
-            attention_mask = None
             attn_output = ops.scaled_dot_product_attention(
                 q=xq,  # [bs, ..., sl, dim]
                 k=keys,  # [bs, ..., sl, dim]
                 v=values,  # [bs, ..., sl, dim]
                 a=attention_mask,  # [bs, ..., sl, sl]
-                is_causal=is_causal,  # assumes causal masking when true
+                is_causal=False,  # assumes causal masking when true
                 scale=None,  # defaults to 1/sqrt(dim)
             )

diff --git a/sharktank/sharktank/utils/cli.py b/sharktank/sharktank/utils/cli.py
index 99917c2d3..9fefeb66f 100644
--- a/sharktank/sharktank/utils/cli.py
+++ b/sharktank/sharktank/utils/cli.py
@@ -66,7 +66,7 @@ def add_model_options(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--attention-kernel",
         type=str,
-        default="decomposed",
+        default="torch",
         choices=["decomposed", "torch"],
     )
     parser.add_argument(
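
Context for the paged_llama_attention_block.py hunk: the call now passes the explicit attention_mask with is_causal=False, so causality must be carried by the mask tensor rather than the flag. Below is a minimal sketch using plain PyTorch F.scaled_dot_product_attention (not sharktank's ops wrapper, which is assumed here to expose matching semantics); the shapes and names are illustrative only.

# Minimal sketch, assuming ops.scaled_dot_product_attention mirrors
# torch.nn.functional.scaled_dot_product_attention: an explicit causal
# mask with is_causal=False should reproduce the is_causal=True result.
import torch
import torch.nn.functional as F

bs, heads, seq_len, head_dim = 1, 8, 16, 64
q = torch.randn(bs, heads, seq_len, head_dim)
k = torch.randn(bs, heads, seq_len, head_dim)
v = torch.randn(bs, heads, seq_len, head_dim)

# Lower-triangular boolean mask: True means "may attend".
causal_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()

out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_mask = F.scaled_dot_product_attention(
    q, k, v, attn_mask=causal_mask, is_causal=False
)

# Differences should be at floating-point noise level.
print("max abs difference:", (out_flag - out_mask).abs().max().item())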