Add a flag to select between nondecomposed and decomposed attention #296

Merged: 18 commits, Oct 23, 2024
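For context, a minimal sketch of how the new flag is meant to flow from the command line into the model config. The argument name and default are taken from this diff; the simplified parser and the FakeLlamaModelConfig stand-in are assumptions for illustration only, not code from the repository.

import argparse
from dataclasses import dataclass


@dataclass
class FakeLlamaModelConfig:
    # Stand-in for sharktank's LlamaModelConfig; only the new field is shown.
    attention_kernel: str = "decomposed"


parser = argparse.ArgumentParser()
parser.add_argument(
    "--attention_kernel",
    help="Attention implementation: 'decomposed' (explicit matmul/softmax) or 'torch' (fused SDPA)",
    default="decomposed",
)
args = parser.parse_args(["--attention_kernel", "torch"])
config = FakeLlamaModelConfig(attention_kernel=args.attention_kernel)
print(config.attention_kernel)  # -> "torch"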
1 change: 1 addition & 0 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -77,6 +77,7 @@ def main():
llama_config.use_hf = False
llama_config.static_tables = False # Rely on the compiler for hoisting tables.
llama_config.kv_cache_type = "direct" if args.bs == [1] else "paged"
llama_config.attention_kernel = args.attention_kernel

if llama_config.hp.expert_count:
if llama_config.hp.model_arch == "grok":
11 changes: 9 additions & 2 deletions sharktank/sharktank/export_layer/export_paged_attention.py
@@ -168,12 +168,12 @@ def main():
parser.add_argument(
"--output-mlir",
help="Output file path for exported MLIR file",
default="/home/aramalin/sharktank/artifacts/paged_llama.mlir",
default="/tmp/sharktank/artifacts/paged_llama.mlir",
)
parser.add_argument(
"--output-config",
help="Output file path for exported config file",
default="/home/aramalin/sharktank/artifacts/paged_llama.json",
default="/tmp/sharktank/artifacts/paged_llama.json",
)
parser.add_argument(
"--bs",
@@ -192,6 +192,12 @@ def main():
help="Enable Causal attention",
action="store_true",
)
# TODO: move this to CLI to enable re-use with eager
parser.add_argument(
"--attention_kernel",
help="decomposed/torch",
default="decomposed",
)

args = cli.parse(parser)

@@ -235,6 +241,7 @@ def main():
head_dim=llama_config.hp.attn_head_dim,
head_count_kv=llama_config.hp.attention_head_count_kv,
rms_epsilon=llama_config.hp.attention_layer_norm_rms_epsilon,
attention_kernel=args.attention_kernel,
)

def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]):
3 changes: 3 additions & 0 deletions sharktank/sharktank/layers/configs/llm_configs.py
@@ -140,6 +140,9 @@ class LlamaModelConfig:
# arguments.
tensor_parallelism_size: int = 1

# Which attention kernel to use.
attention_kernel: str = "decomposed"

# Indicates if running with HuggingFace implementation and ensures
# numerical equivalency to HuggingFace's LLaMa if true (by modifying
# rotary embedding).
53 changes: 36 additions & 17 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -37,6 +37,7 @@ def __init__(
head_dim: int,
head_count_kv: int,
rms_epsilon: float,
attention_kernel: str = "decomposed",
attention_scale: Optional[float] = None,
softcap: Optional[float] = None,
):
@@ -47,6 +48,7 @@ def __init__(
self.head_count = head_count
self.head_dim = head_dim
self.head_count_kv = head_count_kv
self.attention_kernel = attention_kernel
self.attention_scale = attention_scale
self.softcap = softcap

@@ -154,27 +156,44 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
keys = xk.transpose(1, 2)
values = xv.transpose(1, 2)

attn_weights = ops.matmul(xq, keys.transpose(2, 3))
if self.attention_scale is None:
attn_weights = attn_weights / math.sqrt(self.head_dim)
else:
attn_weights = attn_weights * self.attention_scale
if self.attention_kernel == "decomposed":
attn_weights = ops.matmul(xq, keys.transpose(2, 3))
if self.attention_scale is None:
attn_weights = attn_weights / math.sqrt(self.head_dim)
else:
attn_weights = attn_weights * self.attention_scale

# Flash attention.
if self.softcap is not None:
attn_weights = self.softcap * torch.tanh(attn_weights / self.softcap)

# Flash attention.
if self.softcap is not None:
attn_weights = self.softcap * torch.tanh(attn_weights / self.softcap)
self.assert_not_nan(attn_weights)

self.assert_not_nan(attn_weights)
# Apply attention mask.
self.trace_tensor("attn_weights", attn_weights, values=False)
if attention_mask is not None:
# self.trace_tensor("attn_mask", attention_mask)
attn_weights = attn_weights + attention_mask

# Apply attention mask.
self.trace_tensor("attn_weights", attn_weights, values=False)
if attention_mask is not None:
# self.trace_tensor("attn_mask", attention_mask)
attn_weights = attn_weights + attention_mask
attn_weights = ops.softmax(
ops.to(attn_weights, dtype=torch.float32), dim=-1
)
attn_weights = ops.to(attn_weights, dtype=xq.dtype)
attn_output = ops.matmul(
attn_weights, values
) # (bs, heads, slen, head_dim)
else:
is_causal = attention_mask is None and batch_seq_len == 1
attn_output = torch.nn.functional.scaled_dot_product_attention(
query=xq, # [bs, ..., sl, dim]
key=keys, # [bs, ..., sl, dim]
value=values, # [bs, ..., sl, dim]
attn_mask=attention_mask, # [bs, ..., sl, sl]
dropout_p=0.0,
is_causal=is_causal, # assumes causal masking when true
scale=None, # defaults to 1/sqrt(dim)
)

attn_weights = ops.softmax(ops.to(attn_weights, dtype=torch.float32), dim=-1)
attn_weights = ops.to(attn_weights, dtype=xq.dtype)
attn_output = ops.matmul(attn_weights, values) # (bs, heads, slen, head_dim)
attn_output = attn_output.transpose(1, 2).reshape(bs, batch_seq_len, -1)

# Project.
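The hunk above is the heart of the PR: when attention_kernel is anything other than "decomposed", the block delegates to torch.nn.functional.scaled_dot_product_attention instead of spelling out matmul, scale, mask, softmax, and matmul. A standalone sketch of that equivalence follows; the tensor shapes, the all-zero additive mask, and the tolerances are assumptions chosen to mirror the diff, not code from the repository.

import math

import torch

torch.manual_seed(0)
bs, heads, seqlen, head_dim = 2, 4, 8, 16
q = torch.rand(bs, heads, seqlen, head_dim)
k = torch.rand(bs, heads, seqlen, head_dim)
v = torch.rand(bs, heads, seqlen, head_dim)
mask = torch.zeros(bs, 1, seqlen, seqlen)  # additive mask; all zeros = nothing masked

# Decomposed path: explicit matmul, 1/sqrt(head_dim) scaling, mask add, softmax, matmul.
weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
weights = weights + mask
weights = torch.softmax(weights.to(torch.float32), dim=-1).to(q.dtype)
decomposed = torch.matmul(weights, v)

# Non-decomposed path: one fused op; scale=None defaults to 1/sqrt(head_dim).
fused = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, scale=None
)

torch.testing.assert_close(decomposed, fused, atol=1e-5, rtol=1e-5)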
4 changes: 4 additions & 0 deletions sharktank/sharktank/models/llama/llama.py
@@ -77,6 +77,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
self.cache = create_kv_cache(self.config)
self.activation_dtype = config.activation_dtype
self.use_hf = config.use_hf
self.attention_kernel = config.attention_kernel

self.add_module(
"token_embedding",
@@ -111,6 +112,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
head_dim=hp.attn_head_dim,
head_count_kv=hp.attention_head_count_kv,
rms_epsilon=hp.attention_layer_norm_rms_epsilon,
attention_kernel=self.attention_kernel,
)
for n in range(hp.block_count)
]
@@ -327,6 +329,7 @@ def __init__(
head_dim: int,
head_count_kv: int,
rms_epsilon: float,
attention_kernel: str = "decomposed",
):
super().__init__(theta)
self.add_module(
@@ -339,6 +342,7 @@ def __init__(
head_dim=head_dim,
head_count_kv=head_count_kv,
rms_epsilon=rms_epsilon,
attention_kernel=attention_kernel,
),
)
self.add_module(
2 changes: 1 addition & 1 deletion sharktank/sharktank/ops/default_impls.py
@@ -405,7 +405,7 @@ def rms_norm_default(x, weight, *, epsilon: float) -> Tensor:
variance = x.pow(2).mean(-1, keepdim=True)
output = x * elementwise(torch.rsqrt, variance + epsilon)
# The cast here is to match the hf implementation, affects numerics
output = weight * to(output, weight.dtype)
output = elementwise(torch.mul, weight, to(output, weight.dtype))
return output


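The one-line change above swaps the bare * for the elementwise(torch.mul, ...) wrapper used elsewhere in the same function, while keeping the cast to the weight dtype before the multiply. A plain-torch sketch of that ordering follows; the shapes, dtypes, and epsilon are arbitrary assumptions for illustration.

import torch


def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor:
    # Variance in the input dtype, normalize, then cast to the weight dtype
    # *before* multiplying, mirroring the order kept by the diff above.
    variance = x.pow(2).mean(-1, keepdim=True)
    output = x * torch.rsqrt(variance + epsilon)
    return weight * output.to(weight.dtype)


x = torch.rand(2, 8, dtype=torch.float32)
weight = torch.rand(8, dtype=torch.float16)
print(rms_norm_reference(x, weight, epsilon=1e-6).dtype)  # torch.float16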
196 changes: 196 additions & 0 deletions sharktank/tests/layers/paged_llama_attention_block_test.py
@@ -0,0 +1,196 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import logging

logging.basicConfig(level=logging.DEBUG)

import unittest

import torch

from iree.turbine import aot
from sharktank.layers import (
PagedLlamaAttentionBlock,
PagedKVCache,
RotaryEmbeddingLayer,
)
from sharktank.layers.testing import make_llama_attention_block_theta
from sharktank.types.tensors import DefaultPrimitiveTensor


class PagedLlamaAttentionBlockTest(unittest.TestCase):
def setUp(self):
torch.manual_seed(12345)
self.transformer_block_count = 13
self.block_index = 1
self.shard_count = 3
self.head_count_kv = 2 * self.shard_count
self.attention_head_count = 5 * self.head_count_kv
self.attention_head_dim = 11 * 2
self.rms_epsilon = 0.01
self.block_seq_stride = 17
self.cache_partition_count = 2
self.page_count = 23
self.embedding_length = self.attention_head_count * self.attention_head_dim
self.rope_dimension_count = self.attention_head_dim
self.block_seqlen = 7
self.max_seqlen = self.block_seq_stride * self.block_seqlen
self.rope_freq_base = None
self.batch_size = 3
self.start_index = 0

def testExportDecomposed(self):
dtype = torch.float32

cache = PagedKVCache(
transformer_block_count=self.transformer_block_count,
attn_head_count=self.head_count_kv,
attn_head_dim=self.attention_head_dim,
cache_partition_count=self.cache_partition_count,
block_seq_stride=self.block_seq_stride,
dtype=dtype,
)

cache_state = cache.paged.allocate(self.page_count)
cache_state[0] = torch.rand(cache_state[0].shape, dtype=dtype)

theta = make_llama_attention_block_theta(
head_count=self.attention_head_count,
head_count_kv=self.head_count_kv,
head_dim=self.attention_head_dim,
embedding_length=self.embedding_length,
)
attn = PagedLlamaAttentionBlock(
theta=theta,
block_index=self.block_index,
cache=cache,
head_count=self.attention_head_count,
head_dim=self.attention_head_dim,
head_count_kv=self.head_count_kv,
rms_epsilon=self.rms_epsilon,
attention_kernel="decomposed",
)

seq_block_ids = torch.arange(self.batch_size * self.block_seqlen).view(
self.batch_size, -1
)

embedding_module = RotaryEmbeddingLayer(
rope_dimension_count=self.rope_dimension_count,
max_seqlen=self.max_seqlen,
rope_freq_base=self.rope_freq_base,
)

class MyModule(torch.nn.Module):
def forward(self, h, seq_block_ids, cache_state):
return attn.forward(
h,
seq_block_ids=seq_block_ids,
embedding=embedding_module,
start_index=0,
cache_state=cache_state,
)

mod = MyModule()
h = torch.rand(
[
self.batch_size,
self.max_seqlen,
self.attention_head_count * self.attention_head_dim,
]
)
mod.forward(h, seq_block_ids, cache_state)
ep = torch.export.export(
mod,
args=(
h,
seq_block_ids,
cache_state,
),
)
output = aot.export(ep)
output.verify()
asm = str(output.mlir_module)
self.assertNotIn("scaled_dot_product_attention", asm)

def testExportNondecomposed(self):
dtype = torch.float32

cache = PagedKVCache(
transformer_block_count=self.transformer_block_count,
attn_head_count=self.head_count_kv,
attn_head_dim=self.attention_head_dim,
cache_partition_count=self.cache_partition_count,
block_seq_stride=self.block_seq_stride,
dtype=dtype,
)

cache_state = cache.paged.allocate(self.page_count)
cache_state[0] = torch.rand(cache_state[0].shape, dtype=dtype)

theta = make_llama_attention_block_theta(
head_count=self.attention_head_count,
head_count_kv=self.head_count_kv,
head_dim=self.attention_head_dim,
embedding_length=self.embedding_length,
)
attn = PagedLlamaAttentionBlock(
theta=theta,
block_index=self.block_index,
cache=cache,
head_count=self.attention_head_count,
head_dim=self.attention_head_dim,
head_count_kv=self.head_count_kv,
rms_epsilon=self.rms_epsilon,
attention_kernel="torch",
)

seq_block_ids = torch.arange(self.batch_size * self.block_seqlen).view(
self.batch_size, -1
)

embedding_module = RotaryEmbeddingLayer(
rope_dimension_count=self.rope_dimension_count,
max_seqlen=self.max_seqlen,
rope_freq_base=self.rope_freq_base,
)

class MyModule(torch.nn.Module):
def forward(self, h, seq_block_ids, cache_state):
return attn.forward(
h,
seq_block_ids=seq_block_ids,
embedding=embedding_module,
start_index=0,
cache_state=cache_state,
)

mod = MyModule()
h = torch.rand(
[
self.batch_size,
self.max_seqlen,
self.attention_head_count * self.attention_head_dim,
]
)
mod.forward(h, seq_block_ids, cache_state)
ep = torch.export.export(
mod,
args=(
h,
seq_block_ids,
cache_state,
),
)
output = aot.export(ep)
output.verify()
asm = str(output.mlir_module)
self.assertIn("torch.aten._scaled_dot_product_flash_attention_for_cpu", asm)


if __name__ == "__main__":
unittest.main()
6 changes: 5 additions & 1 deletion sharktank/tests/models/llama/attention_test.py
@@ -25,6 +25,7 @@

class AttentionBlockTest(unittest.TestCase):
def test(self):
torch.manual_seed(123456)
torch.set_default_dtype(torch.float32)
block_index = 0
seq_len = 13
@@ -58,6 +59,7 @@ def test(self):
head_dim=head_dim,
head_count_kv=head_count_kv,
rms_epsilon=rms_epsilon,
attention_kernel="torch",
)
attention_embedding = RotaryEmbeddingLayer(
rope_dimension_count=rope_dimension_count,
@@ -147,7 +149,9 @@ def test(self):
)[0]

assert sharktank_output.shape == huggingface_output.shape
torch.testing.assert_close(sharktank_output, huggingface_output)
torch.testing.assert_close(
sharktank_output, huggingface_output, atol=1e-5, rtol=5e-2
)


if __name__ == "__main__":
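A small illustration of the loosened comparison in the hunk above, assuming two tensors that differ by a few percent; the values are made up for illustration.

import torch

a = torch.tensor([1.000, 2.000, 3.000])
b = torch.tensor([1.020, 1.960, 3.090])  # up to ~3% relative difference

# The default float32 tolerances (rtol=1.3e-6, atol=1e-5) would reject this pair:
# torch.testing.assert_close(a, b)  # raises AssertionError
# The loosened tolerances used in the updated test accept it:
torch.testing.assert_close(a, b, atol=1e-5, rtol=5e-2)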