
Enable flash attention by default and make block size 32
rsuderman committed Dec 12, 2024
1 parent d279aff commit 0469c25
Showing 4 changed files with 5 additions and 7 deletions.
sharktank/sharktank/examples/export_paged_llm_v1.py: 2 changes (1 addition, 1 deletion)
@@ -74,7 +74,7 @@ def main():
         tensor_parallelism_size=tensor_parallelism_size,
         use_hf=False,
         static_tables=False,  # Rely on the compiler for hoisting tables.
-        kv_cache_type="direct" if args.bs == [1] else "paged",
+        kv_cache_type="paged" if args.bs == [1] else "paged",
         attention_kernel=args.attention_kernel,
     )
     llama_config.fake_quant = args.fake_quant
sharktank/sharktank/layers/configs/llm_configs.py: 4 changes (2 additions, 2 deletions)
@@ -144,7 +144,7 @@ class LlamaModelConfig:
 
     # Block sequence stride for a paged KV cache. This must divide evenly
     # into the context length.
-    block_seq_stride: int = 16
+    block_seq_stride: int = 32
 
     # Either "paged" or "direct".
     kv_cache_type: str = "paged"
@@ -167,7 +167,7 @@ class LlamaModelConfig:
     tensor_parallelism_size: int = 1
 
     # Which attention kernel to use.
-    attention_kernel: str = "decomposed"
+    attention_kernel: str = "torch"
 
     # Indicates if running with HuggingFace implementation and ensures
     # numerical equivalency to HuggingFace's LLaMa if true (by modifying
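The first hunk above raises the default block_seq_stride from 16 to 32. As the comment notes, the stride must divide evenly into the context length, which then allows context_length // block_seq_stride cache pages per sequence. A minimal sketch of that arithmetic with an illustrative context length (4096 is an assumption, not a value from this commit):

    context_length = 4096        # hypothetical context window, for illustration only
    block_seq_stride = 32        # new default
    assert context_length % block_seq_stride == 0, "stride must divide the context length"
    blocks_per_sequence = context_length // block_seq_stride
    print(blocks_per_sequence)   # 128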
sharktank/sharktank/layers/paged_llama_attention_block.py: 4 changes (1 addition, 3 deletions)
@@ -216,14 +216,12 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
                 attn_weights, values
             )  # (bs, heads, slen, head_dim)
         else:
-            is_causal = True
-            attention_mask = None
             attn_output = ops.scaled_dot_product_attention(
                 q=xq,  # [bs, ..., sl, dim]
                 k=keys,  # [bs, ..., sl, dim]
                 v=values,  # [bs, ..., sl, dim]
                 a=attention_mask,  # [bs, ..., sl, sl]
-                is_causal=is_causal,  # assumes causal masking when true
+                is_causal=False,  # assumes causal masking when true
                 scale=None,  # defaults to 1/sqrt(dim)
             )

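The hunk above stops synthesizing is_causal=True with a null mask and instead forwards the block's existing attention_mask with is_causal=False. A minimal sketch in plain PyTorch (torch.nn.functional, not sharktank's ops wrapper) of why the two forms agree when the mask is the standard additive causal mask; tensor shapes are arbitrary illustrations:

    import torch
    import torch.nn.functional as F

    bs, heads, sl, dim = 2, 4, 8, 16
    q = torch.randn(bs, heads, sl, dim)
    k = torch.randn(bs, heads, sl, dim)
    v = torch.randn(bs, heads, sl, dim)

    # Additive causal mask: 0 where attention is allowed, -inf above the diagonal.
    causal_mask = torch.full((sl, sl), float("-inf")).triu(1)

    out_explicit = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask, is_causal=False)
    out_builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    torch.testing.assert_close(out_explicit, out_builtin)  # same result up to numerics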
sharktank/sharktank/utils/cli.py: 2 changes (1 addition, 1 deletion)
@@ -66,7 +66,7 @@ def add_model_options(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--attention-kernel",
         type=str,
-        default="decomposed",
+        default="torch",
         choices=["decomposed", "torch"],
     )
     parser.add_argument(
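With the CLI default flipped to "torch", running the exporter without --attention-kernel now selects the torch kernel (the scaled_dot_product_attention path shown above), and the decomposed kernel must be requested explicitly. A minimal standalone sketch of how the argument defined in the hunk above resolves (plain argparse, detached from sharktank's add_model_options):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--attention-kernel",
        type=str,
        default="torch",
        choices=["decomposed", "torch"],
    )

    print(parser.parse_args([]).attention_kernel)                                    # torch
    print(parser.parse_args(["--attention-kernel", "decomposed"]).attention_kernel)  # decomposed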
