From 86c187a3f15e669a9e6dae38e96d3ab8b284daa0 Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Fri, 29 Nov 2024 23:59:37 -0600 Subject: [PATCH] [tuner]: fixed the format of lowering_config Signed-off-by: Bangtian Liu --- tuner/tuner/candidate_gen.py | 32 ++++----------- tuner/tuner/candidate_gen_test.py | 64 ++++++++++++++--------------- tuner/tuner/common.py | 47 ++++++++------------- tuner/tuner/dispatch_constraints.py | 11 +++-- tuner/tuner/dispatch_parser.py | 1 - 5 files changed, 64 insertions(+), 91 deletions(-) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index a560a9c70..6f90891e8 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -55,14 +55,14 @@ def apply_configuration( expr1 = re.compile( r"LLVMGPUVectorDistribute workgroup_size = \[.+\] subgroup_size = ([0-9]+)," ) - expr2 = re.compile(r"workgroup = \[\[([0-9]+)(, ([0-9]+))+\]\]") - expr3 = re.compile(r"reduction = \[\[([0-9]+)(, ([0-9]+))+\]\]") + expr2 = re.compile(r"workgroup = \[([0-9]+)(, ([0-9]+))+\]") + expr3 = re.compile(r"reduction = \[([0-9]+)(, ([0-9]+))+\]") expr4 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>") expr5 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") repl0 = f"" repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},' - repl2 = f'workgroup = [[{", ".join(map(str, workgroup_sizes))}]]' - repl3 = f'reduction = [[{", ".join(map(str, reduction_sizes))}]]' + repl2 = f'workgroup = [{", ".join(map(str, workgroup_sizes))}]' + repl3 = f'reduction = [{", ".join(map(str, reduction_sizes))}]' repl4 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}" repl5 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"' @@ -125,8 +125,6 @@ class MmtTuner(DispatchTuner, MmtParser): def get_transform_function_mmt( self, problem_size: ProblemSize, functionName: str, configuration: Configuration ) -> str: - workgroup_sizes = ", ".join(map(str, get_mmt_workgroup_sizes(configuration))) - reduction_sizes = ", ".join(map(str, get_mmt_reduction_sizes(configuration))) intrinsic = configuration.intrinsic() subgroup_m_count = configuration.subgroup_m_count() subgroup_n_count = configuration.subgroup_n_count() @@ -141,7 +139,7 @@ def get_transform_function_mmt( transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, + lowering_config = {configuration.lowering_config}>, translation_info = #iree_codegen.translation_info {output} }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, + lowering_config = {configuration.lowering_config}>, translation_info = #iree_codegen.translation_info : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, + lowering_config = {configuration.lowering_config}>, translation_info = #iree_codegen.translation_info str: - workgroup_sizes = ", ".join( - map(str, get_batch_mmt_workgroup_sizes(configuration)) - ) - 
reduction_sizes = ", ".join( - map(str, get_batch_mmt_reduction_sizes(configuration)) - ) intrinsic = configuration.intrinsic() subgroup_m_count = configuration.subgroup_m_count() subgroup_n_count = configuration.subgroup_n_count() @@ -382,7 +374,7 @@ def get_transform_function_batch_mmt( transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, + lowering_config = {configuration.lowering_config}>, translation_info = #iree_codegen.translation_info" output = f"tensor<{problem_size.res_type}>" - workgroup_sizes = ", ".join( - map(str, get_contract_workgroup_sizes(configuration, tile_dims)) - ) - reduction_sizes = ", ".join( - map(str, get_contract_reduction_sizes(configuration, tile_dims)) - ) intrinsic = configuration.intrinsic() subgroup_m_count = configuration.subgroup_m_count() subgroup_n_count = configuration.subgroup_n_count() @@ -459,7 +445,7 @@ def get_transform_function_batch_matmul( outs(%out : {output}) -> {output} }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, + lowering_config = {configuration.lowering_config}>, translation_info = #iree_codegen.translation_info None: mlir_template = [ ", subgroup_m_count = 16, subgroup_n_count = 16>", "", + "", "gpu_pipeline_options = #iree_gpu.pipeline_options", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', ] @@ -50,7 +50,7 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, - mma_attr=mma_attr, + mma_kind=mma_attr, workgroup=[8, 8, 0], reduction=[0, 0, 8], subgroup_m_count=16, @@ -89,8 +89,8 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [16, 16, 1] subgroup_size = 16" in modified ) - assert "workgroup = [[8, 8, 0]]" in modified - assert "reduction = [[0, 0, 8]]" in modified + assert "workgroup = [8, 8, 0]" in modified + assert "reduction = [0, 0, 8]" in modified assert ( "gpu_pipeline_options = #iree_gpu.pipeline_options" in modified @@ -102,7 +102,7 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 16, subgroup_n_count = 16>", "", + "", 'gpu_pipeline_options = #iree_gpu.pipeline_options, {llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', ] @@ -112,7 +112,7 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, - mma_attr=mma_attr, + mma_kind=mma_attr, workgroup=[464, 320, 0], reduction=[0, 0, 16], subgroup_m_count=1, @@ -155,8 +155,8 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64" in modified ) - assert "workgroup = [[1, 1, 464, 320, 1, 1, 0]]" in modified - assert "reduction = [[0, 0, 0, 0, 0, 0, 16]]" in modified + assert "workgroup = [1, 1, 464, 320, 1, 1, 0]" in modified + assert "reduction = [0, 0, 0, 0, 0, 0, 16]" in modified assert ( "gpu_pipeline_options = #iree_gpu.pipeline_options>" in modified @@ -168,7 +168,7 @@ def 
test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 2, subgroup_n_count = 2>}>", "", + "", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] @@ -185,7 +185,7 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, - mma_attr=mma_attr, + mma_kind=mma_attr, workgroup=[480, 384, 0], reduction=[0, 0, 32], subgroup_m_count=1, @@ -214,8 +214,8 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64" in new_mlir ) - assert "workgroup = [[1, 480, 384, 0]]" in new_mlir - assert "reduction = [[0, 0, 0, 32]]" in new_mlir + assert "workgroup = [1, 480, 384, 0]" in new_mlir + assert "reduction = [0, 0, 0, 32]" in new_mlir assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in new_mlir @@ -223,7 +223,7 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", "", + "", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] @@ -240,7 +240,7 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, - mma_attr=mma_attr, + mma_kind=mma_attr, workgroup=[416, 320, 0], reduction=[0, 0, 128], subgroup_m_count=2, @@ -273,8 +273,8 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" in modified ) - assert "workgroup = [[1, 416, 320, 0]]" in modified - assert "reduction = [[0, 0, 0, 128]]" in modified + assert "workgroup = [1, 416, 320, 0]" in modified + assert "reduction = [0, 0, 0, 128]" in modified assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified @@ -282,7 +282,7 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", "", + "", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] @@ -298,7 +298,7 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, - mma_attr=mma_attr, + mma_kind=mma_attr, workgroup=[128, 64, 0], reduction=[0, 0, 128], subgroup_m_count=2, @@ -329,8 +329,8 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" in modified ) - assert "workgroup = [[1, 128, 64, 0]]" in modified - assert "reduction = [[0, 0, 0, 128]]" in modified + assert "workgroup = [1, 128, 64, 0]" in modified + assert "reduction = [0, 0, 0, 128]" in modified assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified @@ -338,7 +338,7 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", "", + "", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] @@ -354,7 +354,7 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, - mma_attr=mma_attr, + mma_kind=mma_attr, workgroup=[128, 64, 0], reduction=[0, 0, 128], subgroup_m_count=2, @@ 
-387,8 +387,8 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" in modified ) - assert "workgroup = [[1, 128, 64, 0]]" in modified - assert "reduction = [[0, 0, 0, 128]]" in modified + assert "workgroup = [1, 128, 64, 0]" in modified + assert "reduction = [0, 0, 0, 128]" in modified assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified assert embeddable @@ -408,8 +408,8 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: "%config = transform.param.constant #iree_codegen.compilation_info<" in embeddable ) - assert "workgroup = [[1, 128, 64, 0]]" in embeddable - assert "reduction = [[0, 0, 0, 128]]" in embeddable + assert "workgroup = [128, 64, 0]" in embeddable + assert "reduction = [0, 0, 128]" in embeddable assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable @@ -418,7 +418,7 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", "", + "", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] @@ -434,7 +434,7 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, - mma_attr=mma_attr, + mma_kind=mma_attr, workgroup=[128, 64, 0], reduction=[0, 0, 128], subgroup_m_count=2, @@ -470,8 +470,8 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" in modified ) - assert "workgroup = [[1, 128, 64, 0]]" in modified - assert "reduction = [[0, 0, 0, 128]]" in modified + assert "workgroup = [1, 128, 64, 0]" in modified + assert "reduction = [0, 0, 0, 128]" in modified assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified assert embeddable @@ -492,8 +492,8 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: "%config = transform.param.constant #iree_codegen.compilation_info<" in embeddable ) - assert "workgroup = [[1, 128, 64, 0]]" in embeddable - assert "reduction = [[0, 0, 0, 128]]" in embeddable + assert "workgroup = [128, 64, 0]" in embeddable + assert "reduction = [0, 0, 128]" in embeddable assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index a6fd4dd42..27dfe67c2 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -9,6 +9,7 @@ from dataclasses import astuple, dataclass from enum import Enum from typing import Optional +from typing import Any from iree.compiler import ir # type: ignore @@ -22,6 +23,7 @@ def __init__(self, ctx: ir.Context): self.i8 = ir.IntegerType.get_signless(8, ctx) self.i16 = ir.IntegerType.get_signless(16, ctx) self.i32 = ir.IntegerType.get_signless(32, ctx) + self.i64 = ir.IntegerType.get_signless(64, ctx) self.f8E4M3FNUZ = ir.Float8E4M3FNUZType.get(ctx) self.f8E5M2FNUZ = ir.Float8E5M2FNUZType.get(ctx) @@ -30,8 +32,8 @@ def __init__(self, ctx: ir.Context): self.bf16 = ir.BF16Type.get(ctx) - def getI32(self, value: int) -> ir.IntegerAttr: - return ir.IntegerAttr.get(self.i32, value) + def getI64(self, value: int) -> ir.IntegerAttr: + return ir.IntegerAttr.get(self.i64, value) class TunerContext: 
@@ -148,37 +150,20 @@ def subgroup_n_count(self) -> Optional[int]: def get_lowering_config( tuner_ctx: TunerContext, - mma_attr: Optional[iree_gpu.MMAAttr] = None, - workgroup: Optional[list[int]] = None, - reduction: Optional[list[int]] = None, - subgroup_m_count: Optional[int] = None, - subgroup_n_count: Optional[int] = None, + **kwargs: Any, ) -> iree_gpu.LoweringConfigAttr: lowering_config_dict = {} - if workgroup is not None: - lowering_config_dict["workgroup"] = ir.ArrayAttr.get( - [tuner_ctx.type.getI32(x) for x in workgroup] - ) - if reduction is not None: - lowering_config_dict["reduction"] = ir.ArrayAttr.get( - [tuner_ctx.type.getI32(x) for x in reduction] - ) - if subgroup_m_count is not None: - lowering_config_dict["subgroup_m_count"] = tuner_ctx.type.getI32( - subgroup_m_count - ) - if subgroup_n_count is not None: - lowering_config_dict["subgroup_n_count"] = tuner_ctx.type.getI32( - subgroup_n_count - ) - # lowering_config_dict = { - # "workgroup": ir.ArrayAttr.get([tuner_ctx.type.getI32(x) for x in workgroup]), - # "reduction": ir.ArrayAttr.get([tuner_ctx.type.getI32(x) for x in reduction]), - # "subgroup_m_count": tuner_ctx.type.getI32(subgroup_m_count), - # "subgroup_n_count": tuner_ctx.type.getI32(subgroup_n_count), - # } - if mma_attr is not None: - lowering_config_dict["mma_kind"] = mma_attr + for key, value in kwargs.items(): + if isinstance(value, list): + lowering_config_dict[key] = ir.ArrayAttr.get( + [tuner_ctx.type.getI64(x) for x in value] + ) + elif isinstance(value, int): + lowering_config_dict[key] = tuner_ctx.type.getI64(value) + elif isinstance(value, iree_gpu.MMAAttr): + lowering_config_dict[key] = value + else: + raise TypeError(f"Unsupported type for key '{key}': {type(value).__name__}") lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) return iree_gpu.LoweringConfigAttr.get(lowering_config_attrs) diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index 5a775bdcd..39ccec523 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -220,7 +220,7 @@ def generate_solutions( solver.add(z3.simplify(z3.And(constraints))) logger.debug(f"Initial constraints: {solver}") - int_type = ir.IntegerType.get_signless(32) + int_type = ir.IntegerType.get_signless(64) i = 0 while solver.check() == z3.sat: @@ -239,11 +239,15 @@ def generate_solutions( [ ir.IntegerAttr.get(int_type, lookup(m)), ir.IntegerAttr.get(int_type, lookup(n)), - ir.IntegerAttr.get(int_type, lookup(k)), + ir.IntegerAttr.get(int_type, 0), ] ), "reduction": ir.ArrayAttr.get( - [] + [ + ir.IntegerAttr.get(int_type, 0), + ir.IntegerAttr.get(int_type, 0), + ir.IntegerAttr.get(int_type, lookup(k)), + ] ), # placeholder now to be consistent with iree "subgroup_m_count": ir.IntegerAttr.get(int_type, lookup(sg_m_cnt)), "subgroup_n_count": ir.IntegerAttr.get(int_type, lookup(sg_n_cnt)), @@ -251,7 +255,6 @@ def generate_solutions( lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) lowering_config = iree_gpu.LoweringConfigAttr.get(lowering_config_attrs) - config = Configuration( lookup(subgroup_size), [lookup(wg_x), lookup(wg_y), lookup(wg_z)], diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py index dd36f5075..bc7788f44 100644 --- a/tuner/tuner/dispatch_parser.py +++ b/tuner/tuner/dispatch_parser.py @@ -171,7 +171,6 @@ def get_conv_workgroup_sizes(self, configuration: Configuration) -> list[int]: oh = 1 - # oc = configuration.tilesize_workgroup()[1] ow, oc, _ = 
configuration.tilesize_workgroup()
         return [batch, oh, ow, oc, fh, fw, 0]
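
Usage sketch (illustrative, hedged): the snippet below mirrors the updated test_apply_params_mmt call pattern to show how the keyword-based get_lowering_config helper is meant to be driven after this change. It assumes a live common.TunerContext (here the tuner_ctx pytest fixture from candidate_gen_test.py) and the import paths used by the tests; the intrinsic name is the one appearing in the test templates.

    # Assumed imports, matching candidate_gen_test.py:
    # from iree.compiler.dialects import iree_gpu  # type: ignore
    # from tuner import common

    # tuner_ctx is assumed to be an existing common.TunerContext (e.g. the test fixture).
    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
    mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
    lowering_config = common.get_lowering_config(
        tuner_ctx=tuner_ctx,
        mma_kind=mma_attr,       # keyword renamed from mma_attr=; the kwarg name becomes the attribute key
        workgroup=[8, 8, 0],     # lists become i64 ArrayAttrs, printed flat: workgroup = [8, 8, 0]
        reduction=[0, 0, 8],     # printed flat: reduction = [0, 0, 8] (no extra [[...]] nesting)
        subgroup_m_count=16,     # ints become i64 IntegerAttrs
        subgroup_n_count=16,
    )
    # Unsupported value types raise TypeError, and the resulting attribute text matches
    # the new test asserts, e.g. "workgroup = [8, 8, 0]" and "reduction = [0, 0, 8]".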