From 43f04c1bb2972df9568a95ce0d47c5dddaee2cbd Mon Sep 17 00:00:00 2001 From: dan Date: Tue, 17 Dec 2024 13:01:58 -0800 Subject: [PATCH 1/6] add llm ver --- .../sharktank/examples/export_paged_llm_v1.py | 2 +- sharktank/sharktank/kernels/attention.py | 38 +++++++++++-------- sharktank/sharktank/ops/attention_impls.py | 2 +- 3 files changed, 25 insertions(+), 17 deletions(-) diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 5a5f6d40c..30e9ad9a7 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -73,7 +73,7 @@ def main(): if "tensor_parallelism_size" in dataset.properties else args.tensor_parallelism_size ) - + attention_dtype = getattr(torch, args.attention_dtype) llama_config = LlamaModelConfig( hp, tensor_parallelism_size=tensor_parallelism_size, diff --git a/sharktank/sharktank/kernels/attention.py b/sharktank/sharktank/kernels/attention.py index 3e2ef4a57..2dfed0e17 100644 --- a/sharktank/sharktank/kernels/attention.py +++ b/sharktank/sharktank/kernels/attention.py @@ -16,15 +16,14 @@ @CustomOp.register(library=LIBRARY) class flash_attention(CustomOp): - signature = ( - "flash_attention(Tensor q, Tensor k, Tensor v, Tensor scale) -> (Tensor)" - ) + signature = "flash_attention(Tensor q, Tensor k, Tensor v, Tensor? a, Tensor scale) -> (Tensor)" def select(self, ksel: KernelSelection): q_desc = ksel.arg_tensor(0) # Shape b, l, d k_desc = ksel.arg_tensor(1) # Shape b, s, d v_desc = ksel.arg_tensor(2) # Shape b, s, e - s_desc = ksel.arg_tensor(3) + a_desc = ksel.arg_tensor(3) # Shape b, l, s + s_desc = ksel.arg_tensor(4) q_bs = q_desc.t.shape[:-2] k_bs = k_desc.t.shape[:-2] @@ -32,7 +31,8 @@ def select(self, ksel: KernelSelection): bs = len(q_bs) - torch._check(len(q_bs) == 2, lambda: f"TODO: batch dims {bs} not supported") + print(q_bs) + # torch._check(len(q_bs) == 2, lambda: f"TODO: batch dims {bs} not supported") q_l, q_d = q_desc.t.shape[-2:] k_s, k_d = k_desc.t.shape[-2:] @@ -56,32 +56,40 @@ def select(self, ksel: KernelSelection): torch._check(k_s == v_s, lambda: f"expected matching kv length: {q_d}, {k_d}") - q_desc.specialize_dims(bs, bs + 1) - k_desc.specialize_dims(bs, bs + 1) - v_desc.specialize_dims(bs, bs + 1) + q_desc.specialize_dims(0, -1) + k_desc.specialize_dims(0, -1) + v_desc.specialize_dims(0, -1) + a_desc.specialize_dims(0) + print(q_desc) + print(a_desc) # Result 0: Shape batch..., m, n ksel.return_new_tensor((*q_bs, q_l, v_e), dtype=torch.float16).specialize_dims( - 1, 2 + 0, 2 ) def generate(self, ksel: KernelSelection, kb: KernelBuilder): q = kb.arg_value(0) k = kb.arg_value(1) v = kb.arg_value(2) - scale = kb.arg_value(3) + a = kb.arg_value(3) + scale = kb.arg_value(4) q_tensor_type = RankedTensorType(q.type) scale_tensor_type = RankedTensorType(scale.type) v_tensor_type = RankedTensorType(v.type) + a_tensor_type = RankedTensorType(a.type) - _, _, l, d = q_tensor_type.shape - _, _, s, e = v_tensor_type.shape + b, l, d = q_tensor_type.shape + b, s, e = v_tensor_type.shape + l = "?" + s = "?" 
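+        # l and s are emitted as "?" so the generated MLIR kernel treats both
+        # sequence lengths as dynamic dims; only the batch and feature dims stay static.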
i_type_str = str(q_tensor_type.element_type) scale_type_str = str(scale_tensor_type.element_type) o_type_str = "f16" kwargs = { + "b": b, "l": l, "d": d, "s": s, @@ -90,13 +98,13 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): "scale_type": scale_type_str, "o_type": o_type_str, } - template_file = "flash_attention.mlir" - target_function_name = f"sharktank_flash_attention_{l}_{s}_{d}_{e}_{i_type_str}_{scale_type_str}_{o_type_str}" + template_file = "flash_attention_llm.mlir" + target_function_name = f"sharktank_llm_flash_attention_{b}_{d}_{e}_{i_type_str}_{scale_type_str}_{o_type_str}" target_function = inline_template_function( kb, template_file, target_function_name, **kwargs, ) - kb.yield_results(*call_function(target_function, q, k, v, scale)) + kb.yield_results(*call_function(target_function, q, k, v, scale, a)) pass diff --git a/sharktank/sharktank/ops/attention_impls.py b/sharktank/sharktank/ops/attention_impls.py index d1353daaa..cb7f6d78a 100644 --- a/sharktank/sharktank/ops/attention_impls.py +++ b/sharktank/sharktank/ops/attention_impls.py @@ -66,7 +66,7 @@ def flash_attention(q, k, v, a, is_causal, scale): if v.dtype == torch.float32: v = v.to(torch.float16) - atten = kernels.flash_attention(q, k, v, scale) + atten = kernels.flash_attention(q, k, v, a, scale) atten = atten * vscale if vscale is not None else atten return atten From 521135ddb9e09a8142199fa606c7a4265bf1b9f5 Mon Sep 17 00:00:00 2001 From: dan Date: Mon, 13 Jan 2025 12:38:58 -0600 Subject: [PATCH 2/6] fix global issues --- sharktank/sharktank/ops/attention_impls.py | 24 ++- sharktank/sharktank/utils/debugging.py | 3 + .../tests/kernels/attention_template_test.py | 138 ++++++++++++++++++ 3 files changed, 164 insertions(+), 1 deletion(-) create mode 100644 sharktank/tests/kernels/attention_template_test.py diff --git a/sharktank/sharktank/ops/attention_impls.py b/sharktank/sharktank/ops/attention_impls.py index cb7f6d78a..d4d757671 100644 --- a/sharktank/sharktank/ops/attention_impls.py +++ b/sharktank/sharktank/ops/attention_impls.py @@ -47,7 +47,25 @@ def _extract_linear_scale(t): return unbox_tensor(t), None -def flash_attention(q, k, v, a, is_causal, scale): +def register_attention_override_by_name(name: str): + """Provides a way to override available attention kernels + based on something other than a global flag""" + if name == "flash_attention": + scaled_dot_product_attention.override( + PlanarQuantizedTensor, + PlanarQuantizedTensor, + PlanarQuantizedTensor, + NoneType, + )(flash_attention) + elif name == "masked_flash_attention": + scaled_dot_product_attention.override( + AnyTensor, AnyTensor, AnyTensor, AnyTensor + )(masked_flash_attention) + else: + assert False, f"{name} not a registerable override" + + +def prepare_args(q, k, v, scale): scale = torch.scalar_tensor(1.0 / math.sqrt(q.shape[-1]), dtype=torch.float32) q, qscale = _extract_linear_scale(q) @@ -76,3 +94,7 @@ def flash_attention(q, k, v, a, is_causal, scale): scaled_dot_product_attention.override( PlanarQuantizedTensor, PlanarQuantizedTensor, PlanarQuantizedTensor, NoneType )(flash_attention) +if debugging.flags.use_custom_generic_attention: + scaled_dot_product_attention.override(AnyTensor, AnyTensor, AnyTensor, AnyTensor)( + masked_flash_attention + ) diff --git a/sharktank/sharktank/utils/debugging.py b/sharktank/sharktank/utils/debugging.py index dbd9237c6..60a9ddca7 100644 --- a/sharktank/sharktank/utils/debugging.py +++ b/sharktank/sharktank/utils/debugging.py @@ -37,6 +37,7 @@ class DebugFlags: # certain eager use cases 
are still having problems with these custom # kernels, so keeping it to unblock progress. use_custom_iree_kernels: bool = True + use_custom_generic_attention: bool = False def set(self, part: str): m = re.match(SETTING_PART_PATTERN, part) @@ -55,6 +56,8 @@ def set(self, part: str): self.trace_path = Path(value) elif name == "use_custom_iree_kernels": self.use_custom_iree_kernels = logical_sense + elif name == "use_custom_generic_attention": + self.use_custom_generic_attention = logical_sense else: logger.warn("Unrecognized %s flag: '%s'", FLAGS_ENV_NAME, name) diff --git a/sharktank/tests/kernels/attention_template_test.py b/sharktank/tests/kernels/attention_template_test.py new file mode 100644 index 000000000..268ffcf49 --- /dev/null +++ b/sharktank/tests/kernels/attention_template_test.py @@ -0,0 +1,138 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging + +logging.basicConfig(level=logging.DEBUG) + +import unittest +from parameterized import parameterized + +import torch + +from iree.turbine import aot +from sharktank import kernels +from sharktank.types import layout_utils +from sharktank.utils import debugging +from sharktank import ops + + +class custom_attention(unittest.TestCase): + def setUp(self): + torch.manual_seed(420) + + @parameterized.expand( + [ + (torch.float32, 5e-3, 1e-3, True), + (torch.float16, 5e-3, 1e-3, True), + # Currently failing on unmasked error + # (torch.float32, 5e-3, 1e-3, False), + # (torch.float16, 5e-3, 1e-3, False), + ] + ) + def test_compare_torch_spda(self, dtype, atol, rtol, use_mask): + H = 4 # Head dim + N = 1 # Batch Size + L = 7 # Target Seq Len + S = 6 # Source Seq Len + Eqk = Ev = 64 # embedding dimensions with subscript identifiers + + q = torch.rand([N, H, L, Eqk], dtype=dtype) + k = torch.rand([N, H, S, Eqk], dtype=dtype) + v = torch.rand([N, H, S, Ev], dtype=dtype) + # mask is same type as inputs, therefore its added to score + mask = None + scale = torch.tensor(1.0, dtype=dtype) + if use_mask: + mask = torch.rand([N, H, L, S], dtype=dtype) + + res2 = kernels.masked_flash_attention(q, k, v, mask, scale=scale) + + else: + res2 = kernels.flash_attention(q, k, v, scale) + + ref = torch.nn.functional.scaled_dot_product_attention( + q, k, v, mask, scale=scale + ) + + torch.testing.assert_close(res2.to(dtype), ref, atol=atol, rtol=rtol) + + @parameterized.expand( + [ + # Todo: fixed unmasked. 
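+            # Unmasked variants below are disabled; the unmasked path currently
+            # fails (see the commented-out cases in test_compare_torch_spda above).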
+ # (torch.float32, False, False), + (torch.float32, False, True), + (torch.float16, True, True), + (torch.float8_e4m3fnuz, False, True), + ] + ) + def test_export_dynamic(self, dtype, static, use_mask): + ops.attention_impls.register_attention_override_by_name( + "masked_flash_attention" + ) + cast = False + if dtype == torch.float8_e4m3fnuz: + dtype = torch.float32 + cast = True + H = 4 # Head dim + N = 1 # Batch Size + L = 19 # Target Seq Len + S = 19 # Source Seq Len + Eqk = Ev = 64 # embedding dimensions with subscript identifiers + + q = torch.rand([N, H, L, Eqk], dtype=dtype) + k = torch.rand([N, H, S, Eqk], dtype=dtype) + v = torch.rand([N, H, S, Ev], dtype=dtype) + if use_mask: + # mask is same type as inputs, therefore its added to score + mask = torch.rand([N, H, L, S], dtype=dtype) + if cast: + q = q.to(torch.float8_e4m3fnuz) + k = q.to(torch.float8_e4m3fnuz) + v = v.to(torch.float8_e4m3fnuz) + if use_mask: + mask = mask.to(torch.float8_e4m3fnuz) + scale = torch.tensor(1.0, dtype=dtype) + dynamic_shapes = None + if not static: + L_dim = torch.export.Dim("L") + S_dim = torch.export.Dim("S") + dynamic_shapes = { + "q": {2: L_dim}, + "k": {2: S_dim}, + "v": {2: S_dim}, + "mask": {}, + "scale": {}, + } + if use_mask: + dynamic_shapes["mask"] = {2: L_dim, 3: S_dim} + + class MyModule(torch.nn.Module): + def forward(self, q, k, v, mask, scale): + return ops.scaled_dot_product_attention( + q, k, v, a=mask, is_causal=False, scale=scale + ) + + mod = MyModule() + dtype = torch.dtype + if use_mask: + ep = torch.export.export( + mod, + args=(q, k, v, mask, scale), + dynamic_shapes=dynamic_shapes, + ) + else: + ep = torch.export.export( + mod, + args=(q, k, v, None, scale), + dynamic_shapes=dynamic_shapes, + ) + output = aot.export(ep) + output.verify() + + +if __name__ == "__main__": + unittest.main() From fd5cdcb728a17b657f3aa144a585af1c48b2a3af Mon Sep 17 00:00:00 2001 From: dan Date: Mon, 3 Feb 2025 17:57:42 -0800 Subject: [PATCH 3/6] wip hell --- .../sharktank/examples/export_paged_llm_v1.py | 1 - sharktank/sharktank/kernels/attention.py | 132 +++++++++++++++--- .../templates/masked_flash_attention.mlir | 63 +++++++++ sharktank/sharktank/layers/linear.py | 4 + .../layers/paged_llama_attention_block.py | 12 ++ .../llama/tools/import_quark_dataset.py | 39 ++++-- sharktank/sharktank/ops/attention_impls.py | 1 + sharktank/sharktank/utils/cli.py | 2 +- sharktank/sharktank/utils/debugging.py | 2 +- 9 files changed, 222 insertions(+), 34 deletions(-) create mode 100644 sharktank/sharktank/kernels/templates/masked_flash_attention.mlir diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 30e9ad9a7..60617c9aa 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -73,7 +73,6 @@ def main(): if "tensor_parallelism_size" in dataset.properties else args.tensor_parallelism_size ) - attention_dtype = getattr(torch, args.attention_dtype) llama_config = LlamaModelConfig( hp, tensor_parallelism_size=tensor_parallelism_size, diff --git a/sharktank/sharktank/kernels/attention.py b/sharktank/sharktank/kernels/attention.py index 2dfed0e17..6bc97120e 100644 --- a/sharktank/sharktank/kernels/attention.py +++ b/sharktank/sharktank/kernels/attention.py @@ -10,13 +10,14 @@ __all__ = [ "flash_attention", + "masked_flash_attention", ] @CustomOp.register(library=LIBRARY) -class flash_attention(CustomOp): +class masked_flash_attention(CustomOp): - signature = 
"flash_attention(Tensor q, Tensor k, Tensor v, Tensor? a, Tensor scale) -> (Tensor)" + signature = "masked_flash_attention(Tensor q, Tensor k, Tensor v, Tensor? a, Tensor scale) -> (Tensor)" def select(self, ksel: KernelSelection): q_desc = ksel.arg_tensor(0) # Shape b, l, d @@ -28,11 +29,12 @@ def select(self, ksel: KernelSelection): q_bs = q_desc.t.shape[:-2] k_bs = k_desc.t.shape[:-2] v_bs = v_desc.t.shape[:-2] + a_bs = a_desc.t.shape[:-2] bs = len(q_bs) - print(q_bs) - # torch._check(len(q_bs) == 2, lambda: f"TODO: batch dims {bs} not supported") + # Note: kernel does collapse dims to get to a single batch/head dim + torch._check(len(q_bs) == 2, lambda: f"TODO: batch dims {bs} not supported") q_l, q_d = q_desc.t.shape[-2:] k_s, k_d = k_desc.t.shape[-2:] @@ -56,16 +58,14 @@ def select(self, ksel: KernelSelection): torch._check(k_s == v_s, lambda: f"expected matching kv length: {q_d}, {k_d}") - q_desc.specialize_dims(0, -1) - k_desc.specialize_dims(0, -1) - v_desc.specialize_dims(0, -1) - a_desc.specialize_dims(0) - print(q_desc) - print(a_desc) + q_desc.specialize_dims(0, 1, -1) + k_desc.specialize_dims(0, 1, -1) + v_desc.specialize_dims(0, 1, -1) + a_desc.specialize_dims(0, 1) # Result 0: Shape batch..., m, n ksel.return_new_tensor((*q_bs, q_l, v_e), dtype=torch.float16).specialize_dims( - 0, 2 + 0, 1, -1 ) def generate(self, ksel: KernelSelection, kb: KernelBuilder): @@ -74,22 +74,116 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): v = kb.arg_value(2) a = kb.arg_value(3) scale = kb.arg_value(4) + q_tensor_type = RankedTensorType(q.type) scale_tensor_type = RankedTensorType(scale.type) v_tensor_type = RankedTensorType(v.type) - a_tensor_type = RankedTensorType(a.type) - b, l, d = q_tensor_type.shape - b, s, e = v_tensor_type.shape - l = "?" - s = "?" + b1, b2, l, d = q_tensor_type.shape + _, _, s, e = v_tensor_type.shape + # Unspecialized dims will be negative + l = l if l >= 0 else "?" + s = s if s >= 0 else "?" 
+ b = str(int(b1) * int(b2)) i_type_str = str(q_tensor_type.element_type) scale_type_str = str(scale_tensor_type.element_type) o_type_str = "f16" + target_function_name = f"sharktank_masked_flash_attention_{b1}_{b2}_{d}_{e}_{i_type_str}_{scale_type_str}_{o_type_str}" kwargs = { "b": b, + "b1": b1, + "b2": b2, + "l": l, + "d": d, + "s": s, + "e": e, + "i_dtype": i_type_str, + "scale_dtype": scale_type_str, + "o_dtype": o_type_str, + "func_name": target_function_name, + } + template_file = "masked_flash_attention.mlir" + target_function = inline_template_function( + kb, + template_file, + target_function_name, + **kwargs, + ) + kb.yield_results(*call_function(target_function, q, k, v, scale, a)) + pass + + +@CustomOp.register(library=LIBRARY) +class flash_attention(CustomOp): + + signature = ( + "flash_attention(Tensor q, Tensor k, Tensor v, Tensor scale) -> (Tensor)" + ) + + def select(self, ksel: KernelSelection): + q_desc = ksel.arg_tensor(0) # Shape b, l, d + k_desc = ksel.arg_tensor(1) # Shape b, s, d + v_desc = ksel.arg_tensor(2) # Shape b, s, e + s_desc = ksel.arg_tensor(3) + + q_bs = q_desc.t.shape[:-2] + k_bs = k_desc.t.shape[:-2] + v_bs = v_desc.t.shape[:-2] + + bs = len(q_bs) + + torch._check(len(q_bs) == 2, lambda: f"TODO: batch dims {bs} not supported") + + q_l, q_d = q_desc.t.shape[-2:] + k_s, k_d = k_desc.t.shape[-2:] + v_s, v_e = v_desc.t.shape[-2:] + + torch._check( + q_desc.t.dtype.is_floating_point + and k_desc.t.dtype.is_floating_point + and v_desc.t.dtype.is_floating_point + and s_desc.t.dtype.is_floating_point, + lambda: f"flash_attention: Expected floating point", + ) + + for q_b, k_b, v_b in zip(q_bs, k_bs, v_bs): + torch._check( + q_b == k_b and q_b == v_b, + lambda: f"expected matching batch dims: {q_b}, {k_b}, {v_b}", + ) + + torch._check(q_d == k_d, lambda: f"expected matching qk features: {q_d}, {k_d}") + + torch._check(k_s == v_s, lambda: f"expected matching kv length: {q_d}, {k_d}") + + q_desc.specialize_dims(bs, bs + 1) + k_desc.specialize_dims(bs, bs + 1) + v_desc.specialize_dims(bs, bs + 1) + + # Result 0: Shape batch..., m, n + ksel.return_new_tensor((*q_bs, q_l, v_e), dtype=torch.float16).specialize_dims( + 1, 2 + ) + + def generate(self, ksel: KernelSelection, kb: KernelBuilder): + q = kb.arg_value(0) + k = kb.arg_value(1) + v = kb.arg_value(2) + scale = kb.arg_value(3) + q_tensor_type = RankedTensorType(q.type) + scale_tensor_type = RankedTensorType(scale.type) + v_tensor_type = RankedTensorType(v.type) + + _, _, l, d = q_tensor_type.shape + _, _, s, e = v_tensor_type.shape + + i_type_str = str(q_tensor_type.element_type) + scale_type_str = str(scale_tensor_type.element_type) + o_type_str = "f16" + + kwargs = { "l": l, "d": d, "s": s, @@ -98,13 +192,13 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): "scale_type": scale_type_str, "o_type": o_type_str, } - template_file = "flash_attention_llm.mlir" - target_function_name = f"sharktank_llm_flash_attention_{b}_{d}_{e}_{i_type_str}_{scale_type_str}_{o_type_str}" + template_file = "flash_attention.mlir" + target_function_name = f"sharktank_flash_attention_{l}_{s}_{d}_{e}_{i_type_str}_{scale_type_str}_{o_type_str}" target_function = inline_template_function( kb, template_file, target_function_name, **kwargs, ) - kb.yield_results(*call_function(target_function, q, k, v, scale, a)) + kb.yield_results(*call_function(target_function, q, k, v, scale)) pass diff --git a/sharktank/sharktank/kernels/templates/masked_flash_attention.mlir 
b/sharktank/sharktank/kernels/templates/masked_flash_attention.mlir new file mode 100644 index 000000000..663a9ea84 --- /dev/null +++ b/sharktank/sharktank/kernels/templates/masked_flash_attention.mlir @@ -0,0 +1,63 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +!q_type = tensor<{{b1}}x{{b2}}x{{l}}x{{d}}x{{i_dtype}}> +!k_type = tensor<{{b1}}x{{b2}}x{{s}}x{{d}}x{{i_dtype}}> +!v_type = tensor<{{b1}}x{{b2}}x{{s}}x{{e}}x{{i_dtype}}> +!a_type = tensor<{{1}}x{{1}}x{{l}}x{{s}}x{{i_dtype}}> +!trans_v_type = tensor<{{b1}}x{{b2}}x{{e}}x{{s}}x{{i_dtype}}> +!o_type = tensor<{{b1}}x{{b2}}x{{l}}x{{e}}x{{o_dtype}}> +!o_dyn_type = tensor +!o_collapsed_type = tensor<{{b}}x{{l}}x{{e}}x{{o_dtype}}> +!q_collapsed_type = tensor<{{b}}x{{l}}x{{d}}x{{i_dtype}}> +!k_collapsed_type = tensor<{{b}}x{{s}}x{{d}}x{{i_dtype}}> +!v_collapsed_type = tensor<{{b}}x{{s}}x{{e}}x{{i_dtype}}> +!a_collapsed_type = tensor<{{1}}x{{l}}x{{s}}x{{i_dtype}}> +!s_type = tensor<{{scale_dtype}}> + +module { + +util.func private @{{func_name}}( + %q: !q_type, + %k: !k_type, + %v: !v_type, + %s: !s_type, + %a: !a_type) -> !o_type { + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %b0 = arith.constant {{b}} : index + + + %l = tensor.dim %q, %c2 : !q_type + %e = tensor.dim %v, %c3 : !v_type + + %scale = tensor.extract %s[] : !s_type + %empty_dyn = tensor.empty(%b0, %l, %e) : !o_dyn_type + %empty = tensor.cast %empty_dyn : !o_dyn_type to !o_collapsed_type + + %collapsed_q = tensor.collapse_shape %q [[0, 1], [2], [3]] : !q_type into !q_collapsed_type + %collapsed_k = tensor.collapse_shape %k [[0, 1], [2], [3]] : !k_type into !k_collapsed_type + %collapsed_v = tensor.collapse_shape %v [[0, 1], [2], [3]] : !v_type into !v_collapsed_type + %collapsed_a = tensor.collapse_shape %a [[0, 1], [2], [3]] : !a_type into !a_collapsed_type + + %atten = iree_linalg_ext.attention {indexing_maps = [ + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, + affine_map<(d0, d1, d2, d3, d4) -> (d1, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>]} + ins(%collapsed_q, %collapsed_k, %collapsed_v, %scale, %collapsed_a : !q_collapsed_type, !k_collapsed_type, !v_collapsed_type, {{scale_dtype}}, !a_collapsed_type) outs(%empty : !o_collapsed_type) { + ^bb0(%score: f32): + iree_linalg_ext.yield %score : f32 + } -> !o_collapsed_type + %expanded_o = tensor.expand_shape %atten [[0,1], [2], [3]] output_shape [{{b1}}, {{b2}}, %l, {{e}}] : !o_collapsed_type into !o_type + util.return %expanded_o : !o_type + } +} diff --git a/sharktank/sharktank/layers/linear.py b/sharktank/sharktank/layers/linear.py index a1f1366ab..5c0c3584a 100644 --- a/sharktank/sharktank/layers/linear.py +++ b/sharktank/sharktank/layers/linear.py @@ -59,6 +59,7 @@ def __init__( if self.q_input is not None and self.qdq_input is not None: raise AssertionError(f"LinearLayer cannot have both q_input and qdq_input") self.qdq_output: Optional[QuantizedTensor] = theta.optional_tensor("qdq_output") + self.q_output: Optional[QuantizerTensor] = theta.optional_tensor("q_output") def forward(self, x): weight = self.weight @@ -79,6 +80,9 @@ def forward(self, x): y = ops.linear(x, weight, 
bias) # Unconditionally dequantize. + if self.q_output is not None: + y = self.q_output.quantize(y) + return y.unpack().qs if isinstance(y, QuantizedTensor): y = y.unpack().dequant() # Note that f8_e4m3fnuz types on AMD GPUs accumulate to fp32. diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index 69b011cc4..60af70ba0 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -17,6 +17,7 @@ from .rotary_embedding import RotaryEmbeddingLayer from .kv_cache import PagedKVCache from .. import ops +from .. import kernels __all__ = [ "PagedLlamaAttentionBlock", @@ -74,6 +75,9 @@ def __init__( self.cache_quantizer: Optional[QuantizerTensor] = theta.optional_tensor( "kv_cache.quantizer" ) + self.attention_scale = None + if "attn_scale" in theta.keys: + self.attention_scale = theta("attn_scale").as_torch() if theta.optional_tensor("attn_output_norm") is None: self.add_module( @@ -197,6 +201,14 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: attn_output = ops.matmul( attn_weights, values ) # (bs, heads, slen, head_dim) + elif self.attention_kernel == "sharktank": + assert self.attention_scale is not None + if attention_mask is not None: + attn_output = kernels.masked_flash_attention( + xq, keys, values, attention_mask, self.attention_scale + ) + else: + attn_output = kernels.flash_attention(xq, keys, values) else: attn_output = ops.scaled_dot_product_attention( q=xq, # [bs, ..., sl, dim] diff --git a/sharktank/sharktank/models/llama/tools/import_quark_dataset.py b/sharktank/sharktank/models/llama/tools/import_quark_dataset.py index 6ca19d5bd..5520a6201 100644 --- a/sharktank/sharktank/models/llama/tools/import_quark_dataset.py +++ b/sharktank/sharktank/models/llama/tools/import_quark_dataset.py @@ -210,6 +210,7 @@ def quantize_weight( weight_quant_zero_point, ) # we explicitly provide the reciprocal scale because converting from float16 to float8 after doing 1/scale results in significant numerical differences + # scales are multipled by two to account for the difference between fnuz and fn if input_quant_scale is not None: updated_tensors[new_layer_name + ".q_input"] = StaticScaledQuantizer( name=new_layer_name + ".q_input", @@ -218,10 +219,10 @@ def quantize_weight( dtype=torch.float8_e4m3fnuz, ) if output_quant_scale is not None: - updated_tensors[new_layer_name + ".qdq_output"] = StaticScaledQuantizer( - name=new_layer_name + ".qdq_output", - scale=1.0 / output_quant_scale, - reciprocal_scale=output_quant_scale, + updated_tensors[new_layer_name + ".q_output"] = StaticScaledQuantizer( + name=new_layer_name + ".q_output", + scale=1.0 / (output_quant_scale * 2.0), + reciprocal_scale=output_quant_scale * 2.0, dtype=torch.float8_e4m3fnuz, ) @@ -258,15 +259,29 @@ def update_norm_layer( sub_name = layer_name + "." 
+ sub new_name = hf_to_gguf(sub_name) + ".weight" single_replace(quant_theta, sub_name, new_name, updated_tensors) - kv_cache_scale = quant_theta(layer_name, "self_attn").tensor("kv_scale").as_torch() layer_idx = layer_name.split(".")[-1] - new_name = f"blk.{layer_idx}.kv_cache" - updated_tensors[new_name] = StaticScaledQuantizer( - name=new_name + ".quantizer", - scale=1.0 / (kv_cache_scale * 2.0), - reciprocal_scale=kv_cache_scale * 2.0, - dtype=torch.float8_e4m3fnuz, - ) + if "kv_cache_scale" in quant_theta(layer_name, "self_attn").keys: + kv_cache_scale = ( + quant_theta(layer_name, "self_attn").tensor("kv_scale").as_torch() + ) + new_name = f"blk.{layer_idx}.kv_cache" + updated_tensors[new_name] = StaticScaledQuantizer( + name=new_name + ".quantizer", + scale=1.0 / (kv_cache_scale * 2.0), + reciprocal_scale=kv_cache_scale * 2.0, + dtype=torch.float8_e4m3fnuz, + ) + if "prob_output_scale" in quant_theta(layer_name, "self_attn").keys: + prob_output_scale = ( + quant_theta(layer_name, "self_attn").tensor("prob_output_scale").as_torch() + * 2.0 + ) + new_name = f"blk.{layer_idx}.attn_scale" + updated_tensors[new_name] = DefaultPrimitiveTensor( + name=new_name, data=prob_output_scale + ) + print("added attn_scale", new_name) + print(prob_output_scale) def single_replace( diff --git a/sharktank/sharktank/ops/attention_impls.py b/sharktank/sharktank/ops/attention_impls.py index d4d757671..29d7c6869 100644 --- a/sharktank/sharktank/ops/attention_impls.py +++ b/sharktank/sharktank/ops/attention_impls.py @@ -24,6 +24,7 @@ AnyTensor, PlanarQuantizedTensor, ) +from ..kernels import flash_attention, masked_flash_attention from ..types.layouts import TensorScaledLayout diff --git a/sharktank/sharktank/utils/cli.py b/sharktank/sharktank/utils/cli.py index e3dba31fa..2a5d80956 100644 --- a/sharktank/sharktank/utils/cli.py +++ b/sharktank/sharktank/utils/cli.py @@ -73,7 +73,7 @@ def add_model_options(parser: argparse.ArgumentParser): "--attention-kernel", type=str, default="torch", - choices=["decomposed", "torch"], + choices=["decomposed", "torch", "sharktank"], ) parser.add_argument( "--skip-prefill", diff --git a/sharktank/sharktank/utils/debugging.py b/sharktank/sharktank/utils/debugging.py index 60a9ddca7..7769d0676 100644 --- a/sharktank/sharktank/utils/debugging.py +++ b/sharktank/sharktank/utils/debugging.py @@ -37,7 +37,7 @@ class DebugFlags: # certain eager use cases are still having problems with these custom # kernels, so keeping it to unblock progress. 
use_custom_iree_kernels: bool = True - use_custom_generic_attention: bool = False + use_custom_generic_attention: bool = True def set(self, part: str): m = re.match(SETTING_PART_PATTERN, part) From f224fbfe4cedcb3e3ae6694803453a848df439eb Mon Sep 17 00:00:00 2001 From: dan Date: Sat, 1 Feb 2025 18:33:28 -0800 Subject: [PATCH 4/6] not mergeable as-is --- sharktank/sharktank/kernels/batch_matmul_transpose_b.py | 9 ++++----- sharktank/sharktank/layers/linear.py | 4 ++-- sharktank/sharktank/ops/qlinear_impls.py | 6 +++--- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/sharktank/sharktank/kernels/batch_matmul_transpose_b.py b/sharktank/sharktank/kernels/batch_matmul_transpose_b.py index 21f9e9ed4..a55d6654b 100644 --- a/sharktank/sharktank/kernels/batch_matmul_transpose_b.py +++ b/sharktank/sharktank/kernels/batch_matmul_transpose_b.py @@ -8,7 +8,7 @@ import torch -from iree.compiler.ir import IntegerType +from iree.compiler.ir import IntegerType, FloatType __all__ = [ "batch_matmul_transpose_b", @@ -59,9 +59,7 @@ def select(self, ksel: KernelSelection): lambda: f"batch_matmul_transpose_b: Batch dims must match ({lhs_desc.t.shape} vs {rhs_desc.t.shape})", ) # Shape batch, m, n - c_desc = ksel.return_new_tensor( - [lhs_batch, lhs_m, rhs_n], dtype=lhs_desc.t.dtype - ) + c_desc = ksel.return_new_tensor([lhs_batch, lhs_m, rhs_n], dtype=torch.float32) specialize_all_known_dims(lhs_desc) specialize_all_known_dims(rhs_desc) specialize_all_known_dims(c_desc) @@ -77,8 +75,9 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): result_desc = ksel.result_descs[0] # Generate specialization signature and types. - a_asm_type, a_ident, accum_type = unpack_tensor_type(lhs.type) + a_asm_type, a_ident, _ = unpack_tensor_type(lhs.type) b_asm_type, b_ident, _ = unpack_tensor_type(rhs.type) + accum_type = FloatType.parse("f32") spec_sig = f"L{a_ident}_R{b_ident}" template_file = "batch_matmul_transpose_b.mlir" target_function_name = f"sharktank_batch_matmul_transpose_b_{spec_sig}" diff --git a/sharktank/sharktank/layers/linear.py b/sharktank/sharktank/layers/linear.py index 5c0c3584a..9f522f7d5 100644 --- a/sharktank/sharktank/layers/linear.py +++ b/sharktank/sharktank/layers/linear.py @@ -89,8 +89,8 @@ def forward(self, x): # We can truncate to fp16 in iree, so we do a cast here # to account for this in the IR. This is may not be the right # level to do this, but for now its here. - if not isinstance(y, QuantizedTensor): - if y.dtype == torch.float8_e4m3fnuz: + if not isinstance(y, QuantizedTensor) and isinstance(x, QuantizedTensor): + if x.unpack().qs.dtype == torch.float8_e4m3fnuz: y = ops.to(y, torch.bfloat16) return y if qdq_output is not None: diff --git a/sharktank/sharktank/ops/qlinear_impls.py b/sharktank/sharktank/ops/qlinear_impls.py index f88684273..bef455412 100644 --- a/sharktank/sharktank/ops/qlinear_impls.py +++ b/sharktank/sharktank/ops/qlinear_impls.py @@ -93,6 +93,7 @@ def qlinear_tensor_scaled( # Fall back to automatic fusion based on integer, high precision matmul. y_qs = _invoke_mmt_kernel(x_qs, weight_qs, accum_dtype=accum_dtype) + return y_qs # Offset correction. 
By applying the offset correction in post, it is # set up to fuse with its consumer, which is already doing additional @@ -187,9 +188,8 @@ def _invoke_mmt_kernel(lhs, rhs, *, accum_dtype): rhs_size = [lhs.shape[0]] + list(rhs.shape) rhs = rhs.unsqueeze(0).expand(rhs_size) rhs_rank = len(rhs.shape) - y_qs = kernels.batch_matmul_transpose_b( - lhs.to(accum_dtype), rhs.to(accum_dtype) - ) + y_qs = kernels.batch_matmul_transpose_b(lhs, rhs) + return y_qs # Squeeze the batch dimension to maintain shape parity with other # layers. if len(y_qs.shape) > 2: From f3c854507486bbf306127182fea09a4baeab11fc Mon Sep 17 00:00:00 2001 From: dan Date: Mon, 3 Feb 2025 18:35:46 -0800 Subject: [PATCH 5/6] more hell --- sharktank/sharktank/kernels/attention.py | 3 +-- .../kernels/templates/masked_flash_attention.mlir | 7 +++---- sharktank/sharktank/layers/causal_llm.py | 2 ++ sharktank/sharktank/layers/paged_llama_attention_block.py | 6 +++++- sharktank/sharktank/ops/qlinear_impls.py | 8 ++++---- 5 files changed, 15 insertions(+), 11 deletions(-) diff --git a/sharktank/sharktank/kernels/attention.py b/sharktank/sharktank/kernels/attention.py index 6bc97120e..4296ab266 100644 --- a/sharktank/sharktank/kernels/attention.py +++ b/sharktank/sharktank/kernels/attention.py @@ -61,7 +61,6 @@ def select(self, ksel: KernelSelection): q_desc.specialize_dims(0, 1, -1) k_desc.specialize_dims(0, 1, -1) v_desc.specialize_dims(0, 1, -1) - a_desc.specialize_dims(0, 1) # Result 0: Shape batch..., m, n ksel.return_new_tensor((*q_bs, q_l, v_e), dtype=torch.float16).specialize_dims( @@ -88,7 +87,7 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): b = str(int(b1) * int(b2)) i_type_str = str(q_tensor_type.element_type) scale_type_str = str(scale_tensor_type.element_type) - o_type_str = "f16" + o_type_str = "f32" target_function_name = f"sharktank_masked_flash_attention_{b1}_{b2}_{d}_{e}_{i_type_str}_{scale_type_str}_{o_type_str}" kwargs = { diff --git a/sharktank/sharktank/kernels/templates/masked_flash_attention.mlir b/sharktank/sharktank/kernels/templates/masked_flash_attention.mlir index 663a9ea84..a11fb4787 100644 --- a/sharktank/sharktank/kernels/templates/masked_flash_attention.mlir +++ b/sharktank/sharktank/kernels/templates/masked_flash_attention.mlir @@ -7,7 +7,7 @@ !q_type = tensor<{{b1}}x{{b2}}x{{l}}x{{d}}x{{i_dtype}}> !k_type = tensor<{{b1}}x{{b2}}x{{s}}x{{d}}x{{i_dtype}}> !v_type = tensor<{{b1}}x{{b2}}x{{s}}x{{e}}x{{i_dtype}}> -!a_type = tensor<{{1}}x{{1}}x{{l}}x{{s}}x{{i_dtype}}> +!a_type = tensor<{{l}}x{{s}}x{{i_dtype}}> !trans_v_type = tensor<{{b1}}x{{b2}}x{{e}}x{{s}}x{{i_dtype}}> !o_type = tensor<{{b1}}x{{b2}}x{{l}}x{{e}}x{{o_dtype}}> !o_dyn_type = tensor @@ -15,7 +15,7 @@ !q_collapsed_type = tensor<{{b}}x{{l}}x{{d}}x{{i_dtype}}> !k_collapsed_type = tensor<{{b}}x{{s}}x{{d}}x{{i_dtype}}> !v_collapsed_type = tensor<{{b}}x{{s}}x{{e}}x{{i_dtype}}> -!a_collapsed_type = tensor<{{1}}x{{l}}x{{s}}x{{i_dtype}}> +!a_collapsed_type = tensor<{{l}}x{{s}}x{{i_dtype}}> !s_type = tensor<{{scale_dtype}}> module { @@ -44,7 +44,6 @@ util.func private @{{func_name}}( %collapsed_q = tensor.collapse_shape %q [[0, 1], [2], [3]] : !q_type into !q_collapsed_type %collapsed_k = tensor.collapse_shape %k [[0, 1], [2], [3]] : !k_type into !k_collapsed_type %collapsed_v = tensor.collapse_shape %v [[0, 1], [2], [3]] : !v_type into !v_collapsed_type - %collapsed_a = tensor.collapse_shape %a [[0, 1], [2], [3]] : !a_type into !a_collapsed_type %atten = iree_linalg_ext.attention {indexing_maps = [ affine_map<(d0, d1, d2, d3, d4) -> (d0, 
d1, d3)>, @@ -53,7 +52,7 @@ util.func private @{{func_name}}( affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d1, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>]} - ins(%collapsed_q, %collapsed_k, %collapsed_v, %scale, %collapsed_a : !q_collapsed_type, !k_collapsed_type, !v_collapsed_type, {{scale_dtype}}, !a_collapsed_type) outs(%empty : !o_collapsed_type) { + ins(%collapsed_q, %collapsed_k, %collapsed_v, %scale, %a : !q_collapsed_type, !k_collapsed_type, !v_collapsed_type, {{scale_dtype}}, !a_collapsed_type) outs(%empty : !o_collapsed_type) { ^bb0(%score: f32): iree_linalg_ext.yield %score : f32 } -> !o_collapsed_type diff --git a/sharktank/sharktank/layers/causal_llm.py b/sharktank/sharktank/layers/causal_llm.py index 8ace77981..6cd6eda13 100644 --- a/sharktank/sharktank/layers/causal_llm.py +++ b/sharktank/sharktank/layers/causal_llm.py @@ -126,6 +126,8 @@ def attention_mask( # Combine the causal context mask and input mask. dtype = self.attention_dtype + print("attention dtype") + print(self.attention_dtype) _, batch_seq_len = input_mask.shape causal_mask = causal_context_mask[:, :, :batch_seq_len, :batch_seq_len] boolean_mask = torch.logical_or(causal_mask, input_mask[:, None, None, :]) diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index 60af70ba0..7b8cf9b8b 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -205,7 +205,11 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: assert self.attention_scale is not None if attention_mask is not None: attn_output = kernels.masked_flash_attention( - xq, keys, values, attention_mask, self.attention_scale + xq, + keys, + values, + attention_mask.squeeze(0).squeeze(0), + self.attention_scale, ) else: attn_output = kernels.flash_attention(xq, keys, values) diff --git a/sharktank/sharktank/ops/qlinear_impls.py b/sharktank/sharktank/ops/qlinear_impls.py index bef455412..df6d74b15 100644 --- a/sharktank/sharktank/ops/qlinear_impls.py +++ b/sharktank/sharktank/ops/qlinear_impls.py @@ -50,10 +50,10 @@ def qlinear_tensor_scaled( # Handle only integer and fp8 quantizations. if x_layout.qs.dtype.is_floating_point or weight_layout.qs.dtype.is_floating_point: - if x_layout.qs.dtype == torch.float8_e4m3fnuz: - # assume quark - return matmul(x_layout.qs, weight_layout.qs, transpose_rhs=True) - else: + if ( + x_layout.qs.dtype != torch.float8_e4m3fnuz + or weight_layout.qs.dtype != torch.float8_e4m3fnuz + ): return NotImplemented # Bias. 
From 96a19f17c4150cd9ffa68fe83f9b1dac8ea392f1 Mon Sep 17 00:00:00 2001 From: dan Date: Wed, 19 Feb 2025 10:44:59 -0800 Subject: [PATCH 6/6] fixes --- sharktank/sharktank/kernels/attention.py | 2 +- .../sharktank/models/llama/tools/import_quark_dataset.py | 6 +++--- sharktank/sharktank/ops/signatures.py | 2 +- sharktank/tests/kernels/attention_template_test.py | 6 +++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sharktank/sharktank/kernels/attention.py b/sharktank/sharktank/kernels/attention.py index 4296ab266..9dae002df 100644 --- a/sharktank/sharktank/kernels/attention.py +++ b/sharktank/sharktank/kernels/attention.py @@ -63,7 +63,7 @@ def select(self, ksel: KernelSelection): v_desc.specialize_dims(0, 1, -1) # Result 0: Shape batch..., m, n - ksel.return_new_tensor((*q_bs, q_l, v_e), dtype=torch.float16).specialize_dims( + ksel.return_new_tensor((*q_bs, q_l, v_e), dtype=torch.float32).specialize_dims( 0, 1, -1 ) diff --git a/sharktank/sharktank/models/llama/tools/import_quark_dataset.py b/sharktank/sharktank/models/llama/tools/import_quark_dataset.py index 5520a6201..ac2e6a3a5 100644 --- a/sharktank/sharktank/models/llama/tools/import_quark_dataset.py +++ b/sharktank/sharktank/models/llama/tools/import_quark_dataset.py @@ -313,7 +313,7 @@ def main(argv): type=str, default="7b", help="Base model to use for split sizes to decompose the qkv tensor. Default is 7b, 70b is also supported.", - choices=["7b", "70b"], + choices=["7b", "70b", "405b"], ) args = cli.parse(parser, args=argv) @@ -321,8 +321,8 @@ def main(argv): params_path: Path = args.params # TODO: find a way to get this programatically so we don't have to flag for it split_sizes = [4096, 4096, 4096] if args.model_base == "7b" else [8192, 1024, 1024] - num_layers = 32 if args.model_base == "7b" else 80 - + layers_per_base = {"7b": 32, "70b": 40, "405b": 125} + num_layers = layers_per_base[args.model_base] # Construct the pre-transform dataset. 
dataset_props = _get_dataset_props(_load_json(config_json_path)) with safetensors.safe_open(params_path, framework="pt", device="cpu") as st: diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index a698ccb06..dce1646cc 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -859,7 +859,7 @@ def _scaled_dot_product_attention( ): tensors = (q, k, v, a) for override in d.find_overrides(tensors): - result = override(q, k, v, a, is_causal=is_causal, scale=scale) + result = override(q, k, v, a, scale=scale) if result is not NotImplemented: return override, result else: diff --git a/sharktank/tests/kernels/attention_template_test.py b/sharktank/tests/kernels/attention_template_test.py index 268ffcf49..8b68e2b91 100644 --- a/sharktank/tests/kernels/attention_template_test.py +++ b/sharktank/tests/kernels/attention_template_test.py @@ -47,7 +47,7 @@ def test_compare_torch_spda(self, dtype, atol, rtol, use_mask): mask = None scale = torch.tensor(1.0, dtype=dtype) if use_mask: - mask = torch.rand([N, H, L, S], dtype=dtype) + mask = torch.rand([L, S], dtype=dtype) res2 = kernels.masked_flash_attention(q, k, v, mask, scale=scale) @@ -88,7 +88,7 @@ def test_export_dynamic(self, dtype, static, use_mask): v = torch.rand([N, H, S, Ev], dtype=dtype) if use_mask: # mask is same type as inputs, therefore its added to score - mask = torch.rand([N, H, L, S], dtype=dtype) + mask = torch.rand([L, S], dtype=dtype) if cast: q = q.to(torch.float8_e4m3fnuz) k = q.to(torch.float8_e4m3fnuz) @@ -108,7 +108,7 @@ def test_export_dynamic(self, dtype, static, use_mask): "scale": {}, } if use_mask: - dynamic_shapes["mask"] = {2: L_dim, 3: S_dim} + dynamic_shapes["mask"] = {0: L_dim, 1: S_dim} class MyModule(torch.nn.Module): def forward(self, q, k, v, mask, scale):