diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 5a5f6d40c..60617c9aa 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -73,7 +73,6 @@ def main(): if "tensor_parallelism_size" in dataset.properties else args.tensor_parallelism_size ) - llama_config = LlamaModelConfig( hp, tensor_parallelism_size=tensor_parallelism_size, diff --git a/sharktank/sharktank/kernels/attention.py b/sharktank/sharktank/kernels/attention.py index 3e2ef4a57..9dae002df 100644 --- a/sharktank/sharktank/kernels/attention.py +++ b/sharktank/sharktank/kernels/attention.py @@ -10,9 +10,110 @@ __all__ = [ "flash_attention", + "masked_flash_attention", ] +@CustomOp.register(library=LIBRARY) +class masked_flash_attention(CustomOp): + + signature = "masked_flash_attention(Tensor q, Tensor k, Tensor v, Tensor? a, Tensor scale) -> (Tensor)" + + def select(self, ksel: KernelSelection): + q_desc = ksel.arg_tensor(0) # Shape b, l, d + k_desc = ksel.arg_tensor(1) # Shape b, s, d + v_desc = ksel.arg_tensor(2) # Shape b, s, e + a_desc = ksel.arg_tensor(3) # Shape b, l, s + s_desc = ksel.arg_tensor(4) + + q_bs = q_desc.t.shape[:-2] + k_bs = k_desc.t.shape[:-2] + v_bs = v_desc.t.shape[:-2] + a_bs = a_desc.t.shape[:-2] + + bs = len(q_bs) + + # Note: kernel does collapse dims to get to a single batch/head dim + torch._check(len(q_bs) == 2, lambda: f"TODO: batch dims {bs} not supported") + + q_l, q_d = q_desc.t.shape[-2:] + k_s, k_d = k_desc.t.shape[-2:] + v_s, v_e = v_desc.t.shape[-2:] + + torch._check( + q_desc.t.dtype.is_floating_point + and k_desc.t.dtype.is_floating_point + and v_desc.t.dtype.is_floating_point + and s_desc.t.dtype.is_floating_point, + lambda: f"masked_flash_attention: Expected floating point", + ) + + for q_b, k_b, v_b in zip(q_bs, k_bs, v_bs): + torch._check( + q_b == k_b and q_b == v_b, + lambda: f"expected matching batch dims: {q_b}, {k_b}, {v_b}", + ) + + torch._check(q_d == k_d, lambda: f"expected matching qk features: {q_d}, {k_d}") + + torch._check(k_s == v_s, lambda: f"expected matching kv length: {k_s}, {v_s}") + + q_desc.specialize_dims(0, 1, -1) + k_desc.specialize_dims(0, 1, -1) + v_desc.specialize_dims(0, 1, -1) + + # Result 0: Shape batch..., m, n + ksel.return_new_tensor((*q_bs, q_l, v_e), dtype=torch.float32).specialize_dims( + 0, 1, -1 + ) + + def generate(self, ksel: KernelSelection, kb: KernelBuilder): + q = kb.arg_value(0) + k = kb.arg_value(1) + v = kb.arg_value(2) + a = kb.arg_value(3) + scale = kb.arg_value(4) + + q_tensor_type = RankedTensorType(q.type) + scale_tensor_type = RankedTensorType(scale.type) + v_tensor_type = RankedTensorType(v.type) + + b1, b2, l, d = q_tensor_type.shape + _, _, s, e = v_tensor_type.shape + + # Unspecialized dims will be negative + l = l if l >= 0 else "?" + s = s if s >= 0 else "?"
+ b = str(int(b1) * int(b2)) + i_type_str = str(q_tensor_type.element_type) + scale_type_str = str(scale_tensor_type.element_type) + o_type_str = "f32" + + target_function_name = f"sharktank_masked_flash_attention_{b1}_{b2}_{d}_{e}_{i_type_str}_{scale_type_str}_{o_type_str}" + kwargs = { + "b": b, + "b1": b1, + "b2": b2, + "l": l, + "d": d, + "s": s, + "e": e, + "i_dtype": i_type_str, + "scale_dtype": scale_type_str, + "o_dtype": o_type_str, + "func_name": target_function_name, + } + template_file = "masked_flash_attention.mlir" + target_function = inline_template_function( + kb, + template_file, + target_function_name, + **kwargs, + ) + kb.yield_results(*call_function(target_function, q, k, v, scale, a)) + pass + + @CustomOp.register(library=LIBRARY) class flash_attention(CustomOp): diff --git a/sharktank/sharktank/kernels/batch_matmul_transpose_b.py b/sharktank/sharktank/kernels/batch_matmul_transpose_b.py index 21f9e9ed4..a55d6654b 100644 --- a/sharktank/sharktank/kernels/batch_matmul_transpose_b.py +++ b/sharktank/sharktank/kernels/batch_matmul_transpose_b.py @@ -8,7 +8,7 @@ import torch -from iree.compiler.ir import IntegerType +from iree.compiler.ir import IntegerType, FloatType __all__ = [ "batch_matmul_transpose_b", @@ -59,9 +59,7 @@ def select(self, ksel: KernelSelection): lambda: f"batch_matmul_transpose_b: Batch dims must match ({lhs_desc.t.shape} vs {rhs_desc.t.shape})", ) # Shape batch, m, n - c_desc = ksel.return_new_tensor( - [lhs_batch, lhs_m, rhs_n], dtype=lhs_desc.t.dtype - ) + c_desc = ksel.return_new_tensor([lhs_batch, lhs_m, rhs_n], dtype=torch.float32) specialize_all_known_dims(lhs_desc) specialize_all_known_dims(rhs_desc) specialize_all_known_dims(c_desc) @@ -77,8 +75,9 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): result_desc = ksel.result_descs[0] # Generate specialization signature and types. - a_asm_type, a_ident, accum_type = unpack_tensor_type(lhs.type) + a_asm_type, a_ident, _ = unpack_tensor_type(lhs.type) b_asm_type, b_ident, _ = unpack_tensor_type(rhs.type) + accum_type = FloatType.parse("f32") spec_sig = f"L{a_ident}_R{b_ident}" template_file = "batch_matmul_transpose_b.mlir" target_function_name = f"sharktank_batch_matmul_transpose_b_{spec_sig}" diff --git a/sharktank/sharktank/kernels/templates/masked_flash_attention.mlir b/sharktank/sharktank/kernels/templates/masked_flash_attention.mlir new file mode 100644 index 000000000..a11fb4787 --- /dev/null +++ b/sharktank/sharktank/kernels/templates/masked_flash_attention.mlir @@ -0,0 +1,62 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +!q_type = tensor<{{b1}}x{{b2}}x{{l}}x{{d}}x{{i_dtype}}> +!k_type = tensor<{{b1}}x{{b2}}x{{s}}x{{d}}x{{i_dtype}}> +!v_type = tensor<{{b1}}x{{b2}}x{{s}}x{{e}}x{{i_dtype}}> +!a_type = tensor<{{l}}x{{s}}x{{i_dtype}}> +!trans_v_type = tensor<{{b1}}x{{b2}}x{{e}}x{{s}}x{{i_dtype}}> +!o_type = tensor<{{b1}}x{{b2}}x{{l}}x{{e}}x{{o_dtype}}> +!o_dyn_type = tensor<?x?x?x{{o_dtype}}> +!o_collapsed_type = tensor<{{b}}x{{l}}x{{e}}x{{o_dtype}}> +!q_collapsed_type = tensor<{{b}}x{{l}}x{{d}}x{{i_dtype}}> +!k_collapsed_type = tensor<{{b}}x{{s}}x{{d}}x{{i_dtype}}> +!v_collapsed_type = tensor<{{b}}x{{s}}x{{e}}x{{i_dtype}}> +!a_collapsed_type = tensor<{{l}}x{{s}}x{{i_dtype}}> +!s_type = tensor<{{scale_dtype}}> + +module { + +util.func private @{{func_name}}( + %q: !q_type, + %k: !k_type, + %v: !v_type, + %s: !s_type, + %a: !a_type) -> !o_type { + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %b0 = arith.constant {{b}} : index + + + %l = tensor.dim %q, %c2 : !q_type + %e = tensor.dim %v, %c3 : !v_type + + %scale = tensor.extract %s[] : !s_type + %empty_dyn = tensor.empty(%b0, %l, %e) : !o_dyn_type + %empty = tensor.cast %empty_dyn : !o_dyn_type to !o_collapsed_type + + %collapsed_q = tensor.collapse_shape %q [[0, 1], [2], [3]] : !q_type into !q_collapsed_type + %collapsed_k = tensor.collapse_shape %k [[0, 1], [2], [3]] : !k_type into !k_collapsed_type + %collapsed_v = tensor.collapse_shape %v [[0, 1], [2], [3]] : !v_type into !v_collapsed_type + + %atten = iree_linalg_ext.attention {indexing_maps = [ + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, + affine_map<(d0, d1, d2, d3, d4) -> (d1, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>]} + ins(%collapsed_q, %collapsed_k, %collapsed_v, %scale, %a : !q_collapsed_type, !k_collapsed_type, !v_collapsed_type, {{scale_dtype}}, !a_collapsed_type) outs(%empty : !o_collapsed_type) { + ^bb0(%score: f32): + iree_linalg_ext.yield %score : f32 + } -> !o_collapsed_type + %expanded_o = tensor.expand_shape %atten [[0,1], [2], [3]] output_shape [{{b1}}, {{b2}}, %l, {{e}}] : !o_collapsed_type into !o_type + util.return %expanded_o : !o_type + } +} diff --git a/sharktank/sharktank/layers/causal_llm.py b/sharktank/sharktank/layers/causal_llm.py index 8ace77981..6cd6eda13 100644 --- a/sharktank/sharktank/layers/causal_llm.py +++ b/sharktank/sharktank/layers/causal_llm.py @@ -126,6 +126,8 @@ def attention_mask( # Combine the causal context mask and input mask.
dtype = self.attention_dtype + print("attention dtype") + print(self.attention_dtype) _, batch_seq_len = input_mask.shape causal_mask = causal_context_mask[:, :, :batch_seq_len, :batch_seq_len] boolean_mask = torch.logical_or(causal_mask, input_mask[:, None, None, :]) diff --git a/sharktank/sharktank/layers/linear.py b/sharktank/sharktank/layers/linear.py index a1f1366ab..9f522f7d5 100644 --- a/sharktank/sharktank/layers/linear.py +++ b/sharktank/sharktank/layers/linear.py @@ -59,6 +59,7 @@ def __init__( if self.q_input is not None and self.qdq_input is not None: raise AssertionError(f"LinearLayer cannot have both q_input and qdq_input") self.qdq_output: Optional[QuantizedTensor] = theta.optional_tensor("qdq_output") + self.q_output: Optional[QuantizerTensor] = theta.optional_tensor("q_output") def forward(self, x): weight = self.weight @@ -79,14 +80,17 @@ def forward(self, x): y = ops.linear(x, weight, bias) # Unconditionally dequantize. + if self.q_output is not None: + y = self.q_output.quantize(y) + return y.unpack().qs if isinstance(y, QuantizedTensor): y = y.unpack().dequant() # Note that f8_e4m3fnuz types on AMD GPUs accumulate to fp32. # We can truncate to fp16 in iree, so we do a cast here # to account for this in the IR. This is may not be the right # level to do this, but for now its here. - if not isinstance(y, QuantizedTensor): - if y.dtype == torch.float8_e4m3fnuz: + if not isinstance(y, QuantizedTensor) and isinstance(x, QuantizedTensor): + if x.unpack().qs.dtype == torch.float8_e4m3fnuz: y = ops.to(y, torch.bfloat16) return y if qdq_output is not None: diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index 69b011cc4..7b8cf9b8b 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -17,6 +17,7 @@ from .rotary_embedding import RotaryEmbeddingLayer from .kv_cache import PagedKVCache from .. import ops +from .. 
import kernels __all__ = [ "PagedLlamaAttentionBlock", ] @@ -74,6 +75,9 @@ def __init__( self.cache_quantizer: Optional[QuantizerTensor] = theta.optional_tensor( "kv_cache.quantizer" ) + self.attention_scale = None + if "attn_scale" in theta.keys: + self.attention_scale = theta("attn_scale").as_torch() if theta.optional_tensor("attn_output_norm") is None: self.add_module( @@ -197,6 +201,18 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: attn_output = ops.matmul( attn_weights, values ) # (bs, heads, slen, head_dim) + elif self.attention_kernel == "sharktank": + assert self.attention_scale is not None + if attention_mask is not None: + attn_output = kernels.masked_flash_attention( + xq, + keys, + values, + attention_mask.squeeze(0).squeeze(0), + self.attention_scale, + ) + else: + attn_output = kernels.flash_attention(xq, keys, values) else: attn_output = ops.scaled_dot_product_attention( q=xq, # [bs, ..., sl, dim] diff --git a/sharktank/sharktank/models/llama/tools/import_quark_dataset.py b/sharktank/sharktank/models/llama/tools/import_quark_dataset.py index 6ca19d5bd..ac2e6a3a5 100644 --- a/sharktank/sharktank/models/llama/tools/import_quark_dataset.py +++ b/sharktank/sharktank/models/llama/tools/import_quark_dataset.py @@ -210,6 +210,7 @@ def quantize_weight( weight_quant_zero_point, ) # we explicitly provide the reciprocal scale because converting from float16 to float8 after doing 1/scale results in significant numerical differences + # scales are multiplied by two to account for the difference between fnuz and fn if input_quant_scale is not None: updated_tensors[new_layer_name + ".q_input"] = StaticScaledQuantizer( name=new_layer_name + ".q_input", @@ -218,10 +219,10 @@ def quantize_weight( dtype=torch.float8_e4m3fnuz, ) if output_quant_scale is not None: - updated_tensors[new_layer_name + ".qdq_output"] = StaticScaledQuantizer( - name=new_layer_name + ".qdq_output", - scale=1.0 / output_quant_scale, - reciprocal_scale=output_quant_scale, + updated_tensors[new_layer_name + ".q_output"] = StaticScaledQuantizer( + name=new_layer_name + ".q_output", + scale=1.0 / (output_quant_scale * 2.0), + reciprocal_scale=output_quant_scale * 2.0, dtype=torch.float8_e4m3fnuz, ) @@ -258,15 +259,29 @@ def update_norm_layer( sub_name = layer_name + "."
+ sub new_name = hf_to_gguf(sub_name) + ".weight" single_replace(quant_theta, sub_name, new_name, updated_tensors) - kv_cache_scale = quant_theta(layer_name, "self_attn").tensor("kv_scale").as_torch() layer_idx = layer_name.split(".")[-1] - new_name = f"blk.{layer_idx}.kv_cache" - updated_tensors[new_name] = StaticScaledQuantizer( - name=new_name + ".quantizer", - scale=1.0 / (kv_cache_scale * 2.0), - reciprocal_scale=kv_cache_scale * 2.0, - dtype=torch.float8_e4m3fnuz, - ) + if "kv_cache_scale" in quant_theta(layer_name, "self_attn").keys: + kv_cache_scale = ( + quant_theta(layer_name, "self_attn").tensor("kv_scale").as_torch() + ) + new_name = f"blk.{layer_idx}.kv_cache" + updated_tensors[new_name] = StaticScaledQuantizer( + name=new_name + ".quantizer", + scale=1.0 / (kv_cache_scale * 2.0), + reciprocal_scale=kv_cache_scale * 2.0, + dtype=torch.float8_e4m3fnuz, + ) + if "prob_output_scale" in quant_theta(layer_name, "self_attn").keys: + prob_output_scale = ( + quant_theta(layer_name, "self_attn").tensor("prob_output_scale").as_torch() + * 2.0 + ) + new_name = f"blk.{layer_idx}.attn_scale" + updated_tensors[new_name] = DefaultPrimitiveTensor( + name=new_name, data=prob_output_scale + ) + print("added attn_scale", new_name) + print(prob_output_scale) def single_replace( @@ -298,7 +313,7 @@ def main(argv): type=str, default="7b", help="Base model to use for split sizes to decompose the qkv tensor. Default is 7b, 70b is also supported.", - choices=["7b", "70b"], + choices=["7b", "70b", "405b"], ) args = cli.parse(parser, args=argv) @@ -306,8 +321,8 @@ def main(argv): params_path: Path = args.params # TODO: find a way to get this programatically so we don't have to flag for it split_sizes = [4096, 4096, 4096] if args.model_base == "7b" else [8192, 1024, 1024] - num_layers = 32 if args.model_base == "7b" else 80 - + layers_per_base = {"7b": 32, "70b": 80, "405b": 125} + num_layers = layers_per_base[args.model_base] # Construct the pre-transform dataset.
dataset_props = _get_dataset_props(_load_json(config_json_path)) with safetensors.safe_open(params_path, framework="pt", device="cpu") as st: diff --git a/sharktank/sharktank/ops/attention_impls.py b/sharktank/sharktank/ops/attention_impls.py index d1353daaa..29d7c6869 100644 --- a/sharktank/sharktank/ops/attention_impls.py +++ b/sharktank/sharktank/ops/attention_impls.py @@ -24,6 +24,7 @@ AnyTensor, PlanarQuantizedTensor, ) +from ..kernels import flash_attention, masked_flash_attention from ..types.layouts import TensorScaledLayout @@ -47,7 +48,25 @@ def _extract_linear_scale(t): return unbox_tensor(t), None -def flash_attention(q, k, v, a, is_causal, scale): +def register_attention_override_by_name(name: str): + """Provides a way to override available attention kernels + based on something other than a global flag""" + if name == "flash_attention": + scaled_dot_product_attention.override( + PlanarQuantizedTensor, + PlanarQuantizedTensor, + PlanarQuantizedTensor, + NoneType, + )(flash_attention) + elif name == "masked_flash_attention": + scaled_dot_product_attention.override( + AnyTensor, AnyTensor, AnyTensor, AnyTensor + )(masked_flash_attention) + else: + assert False, f"{name} not a registerable override" + + +def prepare_args(q, k, v, scale): scale = torch.scalar_tensor(1.0 / math.sqrt(q.shape[-1]), dtype=torch.float32) q, qscale = _extract_linear_scale(q) @@ -66,7 +85,7 @@ def flash_attention(q, k, v, a, is_causal, scale): if v.dtype == torch.float32: v = v.to(torch.float16) - atten = kernels.flash_attention(q, k, v, scale) + atten = kernels.flash_attention(q, k, v, a, scale) atten = atten * vscale if vscale is not None else atten return atten @@ -76,3 +95,7 @@ def flash_attention(q, k, v, a, is_causal, scale): scaled_dot_product_attention.override( PlanarQuantizedTensor, PlanarQuantizedTensor, PlanarQuantizedTensor, NoneType )(flash_attention) +if debugging.flags.use_custom_generic_attention: + scaled_dot_product_attention.override(AnyTensor, AnyTensor, AnyTensor, AnyTensor)( + masked_flash_attention + ) diff --git a/sharktank/sharktank/ops/qlinear_impls.py b/sharktank/sharktank/ops/qlinear_impls.py index f88684273..df6d74b15 100644 --- a/sharktank/sharktank/ops/qlinear_impls.py +++ b/sharktank/sharktank/ops/qlinear_impls.py @@ -50,10 +50,10 @@ def qlinear_tensor_scaled( # Handle only integer and fp8 quantizations. if x_layout.qs.dtype.is_floating_point or weight_layout.qs.dtype.is_floating_point: - if x_layout.qs.dtype == torch.float8_e4m3fnuz: - # assume quark - return matmul(x_layout.qs, weight_layout.qs, transpose_rhs=True) - else: + if ( + x_layout.qs.dtype != torch.float8_e4m3fnuz + or weight_layout.qs.dtype != torch.float8_e4m3fnuz + ): return NotImplemented # Bias. @@ -93,6 +93,7 @@ def qlinear_tensor_scaled( # Fall back to automatic fusion based on integer, high precision matmul. y_qs = _invoke_mmt_kernel(x_qs, weight_qs, accum_dtype=accum_dtype) + return y_qs # Offset correction. By applying the offset correction in post, it is # set up to fuse with its consumer, which is already doing additional @@ -187,9 +188,8 @@ def _invoke_mmt_kernel(lhs, rhs, *, accum_dtype): rhs_size = [lhs.shape[0]] + list(rhs.shape) rhs = rhs.unsqueeze(0).expand(rhs_size) rhs_rank = len(rhs.shape) - y_qs = kernels.batch_matmul_transpose_b( - lhs.to(accum_dtype), rhs.to(accum_dtype) - ) + y_qs = kernels.batch_matmul_transpose_b(lhs, rhs) + return y_qs # Squeeze the batch dimension to maintain shape parity with other # layers. 
if len(y_qs.shape) > 2: diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index a698ccb06..dce1646cc 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -859,7 +859,7 @@ def _scaled_dot_product_attention( ): tensors = (q, k, v, a) for override in d.find_overrides(tensors): - result = override(q, k, v, a, is_causal=is_causal, scale=scale) + result = override(q, k, v, a, scale=scale) if result is not NotImplemented: return override, result else: diff --git a/sharktank/sharktank/utils/cli.py b/sharktank/sharktank/utils/cli.py index e3dba31fa..2a5d80956 100644 --- a/sharktank/sharktank/utils/cli.py +++ b/sharktank/sharktank/utils/cli.py @@ -73,7 +73,7 @@ def add_model_options(parser: argparse.ArgumentParser): "--attention-kernel", type=str, default="torch", - choices=["decomposed", "torch"], + choices=["decomposed", "torch", "sharktank"], ) parser.add_argument( "--skip-prefill", diff --git a/sharktank/sharktank/utils/debugging.py b/sharktank/sharktank/utils/debugging.py index dbd9237c6..7769d0676 100644 --- a/sharktank/sharktank/utils/debugging.py +++ b/sharktank/sharktank/utils/debugging.py @@ -37,6 +37,7 @@ class DebugFlags: # certain eager use cases are still having problems with these custom # kernels, so keeping it to unblock progress. use_custom_iree_kernels: bool = True + use_custom_generic_attention: bool = True def set(self, part: str): m = re.match(SETTING_PART_PATTERN, part) @@ -55,6 +56,8 @@ def set(self, part: str): self.trace_path = Path(value) elif name == "use_custom_iree_kernels": self.use_custom_iree_kernels = logical_sense + elif name == "use_custom_generic_attention": + self.use_custom_generic_attention = logical_sense else: logger.warn("Unrecognized %s flag: '%s'", FLAGS_ENV_NAME, name) diff --git a/sharktank/tests/kernels/attention_template_test.py b/sharktank/tests/kernels/attention_template_test.py new file mode 100644 index 000000000..8b68e2b91 --- /dev/null +++ b/sharktank/tests/kernels/attention_template_test.py @@ -0,0 +1,138 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging + +logging.basicConfig(level=logging.DEBUG) + +import unittest +from parameterized import parameterized + +import torch + +from iree.turbine import aot +from sharktank import kernels +from sharktank.types import layout_utils +from sharktank.utils import debugging +from sharktank import ops + + +class custom_attention(unittest.TestCase): + def setUp(self): + torch.manual_seed(420) + + @parameterized.expand( + [ + (torch.float32, 5e-3, 1e-3, True), + (torch.float16, 5e-3, 1e-3, True), + # Currently failing with an error in the unmasked case + # (torch.float32, 5e-3, 1e-3, False), + # (torch.float16, 5e-3, 1e-3, False), + ] + ) + def test_compare_torch_spda(self, dtype, atol, rtol, use_mask): + H = 4 # Head dim + N = 1 # Batch Size + L = 7 # Target Seq Len + S = 6 # Source Seq Len + Eqk = Ev = 64 # embedding dimensions with subscript identifiers + + q = torch.rand([N, H, L, Eqk], dtype=dtype) + k = torch.rand([N, H, S, Eqk], dtype=dtype) + v = torch.rand([N, H, S, Ev], dtype=dtype) + # mask is same dtype as inputs, therefore it's added to the score + mask = None + scale = torch.tensor(1.0, dtype=dtype) + if use_mask: + mask = torch.rand([L, S], dtype=dtype) + + res2 = kernels.masked_flash_attention(q, k, v, mask, scale=scale) + + else: + res2 = kernels.flash_attention(q, k, v, scale) + + ref = torch.nn.functional.scaled_dot_product_attention( + q, k, v, mask, scale=scale + ) + + torch.testing.assert_close(res2.to(dtype), ref, atol=atol, rtol=rtol) + + @parameterized.expand( + [ + # TODO: fix the unmasked case. + # (torch.float32, False, False), + (torch.float32, False, True), + (torch.float16, True, True), + (torch.float8_e4m3fnuz, False, True), + ] + ) + def test_export_dynamic(self, dtype, static, use_mask): + ops.attention_impls.register_attention_override_by_name( + "masked_flash_attention" + ) + cast = False + if dtype == torch.float8_e4m3fnuz: + dtype = torch.float32 + cast = True + H = 4 # Head dim + N = 1 # Batch Size + L = 19 # Target Seq Len + S = 19 # Source Seq Len + Eqk = Ev = 64 # embedding dimensions with subscript identifiers + + q = torch.rand([N, H, L, Eqk], dtype=dtype) + k = torch.rand([N, H, S, Eqk], dtype=dtype) + v = torch.rand([N, H, S, Ev], dtype=dtype) + if use_mask: + # mask is same dtype as inputs, therefore it's added to the score + mask = torch.rand([L, S], dtype=dtype) + if cast: + q = q.to(torch.float8_e4m3fnuz) + k = k.to(torch.float8_e4m3fnuz) + v = v.to(torch.float8_e4m3fnuz) + if use_mask: + mask = mask.to(torch.float8_e4m3fnuz) + scale = torch.tensor(1.0, dtype=dtype) + dynamic_shapes = None + if not static: + L_dim = torch.export.Dim("L") + S_dim = torch.export.Dim("S") + dynamic_shapes = { + "q": {2: L_dim}, + "k": {2: S_dim}, + "v": {2: S_dim}, + "mask": {}, + "scale": {}, + } + if use_mask: + dynamic_shapes["mask"] = {0: L_dim, 1: S_dim} + + class MyModule(torch.nn.Module): + def forward(self, q, k, v, mask, scale): + return ops.scaled_dot_product_attention( + q, k, v, a=mask, is_causal=False, scale=scale + ) + + mod = MyModule() + if use_mask: + ep = torch.export.export( + mod, + args=(q, k, v, mask, scale), + dynamic_shapes=dynamic_shapes, + ) + else: + ep = torch.export.export( + mod, + args=(q, k, v, None, scale), + dynamic_shapes=dynamic_shapes, + ) + output = aot.export(ep) + output.verify() + + +if __name__ == "__main__": + unittest.main()
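Reviewer aid: a minimal eager-mode sketch of how the new masked kernel is exercised, mirroring test_compare_torch_spda above. The shapes, dtype, and tolerances are illustrative only, and the sketch assumes an environment where the sharktank custom IREE kernels can execute eagerly (as the unit test does).

import torch
from sharktank import kernels

# Example shapes (arbitrary): batch N=1, heads H=4, target len L=7, source len S=6, head dim 64.
q = torch.rand([1, 4, 7, 64], dtype=torch.float16)
k = torch.rand([1, 4, 6, 64], dtype=torch.float16)
v = torch.rand([1, 4, 6, 64], dtype=torch.float16)
# The mask is additive and shares the input dtype; the kernel adds it to the attention scores.
mask = torch.rand([7, 6], dtype=torch.float16)
scale = torch.tensor(1.0, dtype=torch.float16)

out = kernels.masked_flash_attention(q, k, v, mask, scale)  # kernel accumulates and returns f32
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, mask, scale=1.0)
torch.testing.assert_close(out.to(torch.float16), ref, atol=5e-3, rtol=1e-3)

At the model level, the same code path is selected with --attention-kernel sharktank (the new choice added in sharktank/sharktank/utils/cli.py), provided the imported dataset carries the blk.*.attn_scale tensors that import_quark_dataset.py now writes.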