From 72c2ae3b69334263b84c4122c478357c47a56a66 Mon Sep 17 00:00:00 2001
From: Manupa Karunaratne
Date: Sat, 2 Nov 2024 00:00:14 +0000
Subject: [PATCH] Update attention template (#30)

This commit updates the attention template to include
promote_operands and decomposition_config.

---------

Signed-off-by: Manupa Karunaratne
---
 .github/workflows/run_bench.yml   | 2 +-
 attentionbench/attention_utils.py | 9 +++++++--
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/run_bench.yml b/.github/workflows/run_bench.yml
index ebe07f0..b8c19a8 100644
--- a/.github/workflows/run_bench.yml
+++ b/.github/workflows/run_bench.yml
@@ -16,7 +16,7 @@ concurrency:
 
 jobs:
   benchmark:
-    runs-on: mi300-kernel
+    runs-on: mi300-sdxl-kernel
 
     steps:
     - name: "Checkout Repo"
diff --git a/attentionbench/attention_utils.py b/attentionbench/attention_utils.py
index d8d8d18..7628061 100644
--- a/attentionbench/attention_utils.py
+++ b/attentionbench/attention_utils.py
@@ -70,7 +70,8 @@ def get_lowering_config(self) -> str:
             f"#iree_gpu.lowering_config<"
             + "{ "
             + f"workgroup = [{', '.join(map(str, self.wg_tiles))}], "
-            + f"reduction = [{', '.join(map(str, self.reduction_tiles))}]"
+            + f"reduction = [{', '.join(map(str, self.reduction_tiles))}],"
+            + f"promote_operands = [0, 1, 2]"
             + " }"
             + f">"
         )
@@ -93,7 +94,7 @@ def get_translation_info(self) -> str:
         return (
             f"#iree_codegen.translation_info<"
             + f"LLVMGPUVectorDistribute"
-            + f" workgroup_size = [{self.N_warp * 64}, {self.M_warp}]"
+            + f" workgroup_size = [{self.N_warp * self.M_warp * 64}]"
             + f" subgroup_size = 64"
             + f" ,{{mma_schedule = {self.get_mma_schedule()}"
             + f" , llvm_func_attrs = {{ {','.join(llvm_func_attrs)} }}"
@@ -137,6 +138,10 @@ def generate_mlir(config: AttentionConfig, tuning: Optional[TuningSpec] = None):
   %empty = tensor.empty() : !O
   %O = iree_linalg_ext.attention {{ indexing_maps = [#Q, #K, #V, #S, #O]
+                ,decomposition_config = {{
+                  qk_attrs = {{attention_qk_matmul, lowering_config = #iree_gpu.lowering_config<{{promote_operands = [0, 1]}}>}},
+                  pv_attrs = {{attention_pv_matmul, lowering_config = #iree_gpu.lowering_config<{{promote_operands = [1]}}>}}
+                }}
   {",compilation_info = #tuning" if tuning and config.dtype == "f16" else ""} }}
       ins(%Q, %K, %V, %scale : !Q, !K, !V, !dtype)
       outs(%empty : !O) {{
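
For reference, a minimal standalone sketch of the strings the patched
get_lowering_config and get_translation_info now emit; the tile sizes and
warp counts below are illustrative assumptions, not values taken from the
benchmark suite:

    # Sketch only: re-renders the two changed strings with example values.
    wg_tiles = [1, 128, 0, 0, 32]
    reduction_tiles = [0, 0, 0, 64, 0]
    M_warp, N_warp = 2, 2

    lowering_config = (
        "#iree_gpu.lowering_config<"
        + "{ "
        + f"workgroup = [{', '.join(map(str, wg_tiles))}], "
        + f"reduction = [{', '.join(map(str, reduction_tiles))}],"
        + "promote_operands = [0, 1, 2]"  # new attribute added by this patch
        + " }"
        + ">"
    )
    print(lowering_config)
    # #iree_gpu.lowering_config<{ workgroup = [1, 128, 0, 0, 32],
    #   reduction = [0, 0, 0, 64, 0],promote_operands = [0, 1, 2] }>

    # workgroup_size is now a single flat dimension of M_warp * N_warp * 64
    # threads (64 = the subgroup size), replacing the old 2-D form
    # [N_warp * 64, M_warp]; the total thread count is unchanged.
    print(f"workgroup_size = [{N_warp * M_warp * 64}]")  # workgroup_size = [256]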