From 21d95781e34fc0c58e0cb35ed3403fd2758440dc Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Thu, 28 Nov 2024 16:57:05 -0600 Subject: [PATCH 1/9] [tuner]: use lowering config binding Signed-off-by: Bangtian Liu --- tuner/tuner/candidate_gen.py | 40 +++++--- tuner/tuner/candidate_gen_test.py | 140 ++++++++++++++++++++++------ tuner/tuner/common.py | 34 ++++--- tuner/tuner/common_test.py | 21 ++++- tuner/tuner/dispatch_constraints.py | 36 +++++-- tuner/tuner/dispatch_parser.py | 8 +- tuner/tuner/dispatch_parser_test.py | 57 ++++++++--- 7 files changed, 255 insertions(+), 81 deletions(-) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index f09e08888..01a6ed2aa 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -42,6 +42,9 @@ def apply_configuration( template: list[str], configuration: Configuration, tile_sizes: list[int] ) -> str: + intrinsic = configuration.intrinsic + subgroup_m_count = configuration.subgroup_m_count + subgroup_n_count = configuration.subgroup_n_count tune_logger.info(f"Applying: {configuration}") expr0 = re.compile( r", subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>" @@ -52,7 +55,7 @@ def apply_configuration( expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]") expr3 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>") expr4 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") - repl0 = f"" + repl0 = f"" repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},' repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]' repl3 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}" @@ -116,6 +119,9 @@ def get_transform_function_mmt( self, problem_size: ProblemSize, functionName: str, configuration: Configuration ) -> str: tile_sizes = ", ".join(map(str, get_mmt_tile_sizes(configuration))) + intrinsic = configuration.intrinsic + subgroup_m_count = configuration.subgroup_m_count + subgroup_n_count = configuration.subgroup_n_count wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) @@ -131,8 +137,8 @@ def get_transform_function_mmt( translation_info = #iree_codegen.translation_info + intrinsic = {intrinsic}, + subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}> {extra_config}}}> > -> !transform.any_param transform.yield %matmul, %config : !transform.any_op, !transform.any_param @@ -186,6 +192,9 @@ def get_transform_function_conv( output = f"tensor<{dynamic_batch_output_ty}>" tile_sizes = ", ".join(map(str, self.get_conv_tile_sizes(configuration))) + intrinsic = configuration.intrinsic + subgroup_m_count = configuration.subgroup_m_count + subgroup_n_count = configuration.subgroup_n_count wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) @@ -204,8 +213,8 @@ def get_transform_function_conv( translation_info = #iree_codegen.translation_info + intrinsic = {intrinsic}, + subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}> {extra_config}}}> > -> !transform.any_param transform.yield %conv, %config : !transform.any_op, !transform.any_param @@ -245,6 +254,9 @@ def get_transform_function_broadcast_rhs_mmt( configuration: Configuration, ) -> str: tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration))) + intrinsic = configuration.intrinsic + subgroup_m_count = configuration.subgroup_m_count + subgroup_n_count = configuration.subgroup_n_count wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) @@ -265,8 +277,8 @@ def get_transform_function_broadcast_rhs_mmt( translation_info = #iree_codegen.translation_info + intrinsic = {intrinsic}, + subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}> {extra_config}}}> > -> !transform.any_param transform.yield %generic, %config : !transform.any_op, !transform.any_param @@ -329,6 +341,9 @@ def get_transform_function_batch_mmt( configuration: Configuration, ) -> str: tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration))) + intrinsic = configuration.intrinsic + subgroup_m_count = configuration.subgroup_m_count + subgroup_n_count = configuration.subgroup_n_count wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) @@ -345,8 +360,8 @@ def get_transform_function_batch_mmt( translation_info = #iree_codegen.translation_info + intrinsic = {intrinsic}, + subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}> {extra_config}}}> > -> !transform.any_param transform.yield %generic, %config : !transform.any_op, !transform.any_param @@ -395,6 +410,9 @@ def get_transform_function_batch_matmul( tile_sizes = ", ".join( map(str, get_contract_tile_sizes(configuration, tile_dims)) ) + intrinsic = configuration.intrinsic + subgroup_m_count = configuration.subgroup_m_count + subgroup_n_count = configuration.subgroup_n_count wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) @@ -413,8 +431,8 @@ def get_transform_function_batch_matmul( translation_info = #iree_codegen.translation_info + intrinsic = {intrinsic}, + subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}> {extra_config}}}> > -> !transform.any_param transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py index 19b6e1fe7..98055eb3d 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -48,13 +48,25 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config_dict = { + "mma_kind": mma_attr, + "workgroup": ir.ArrayAttr.get( + [ + ir.IntegerAttr.get(tuner_ctx.type.i32, 8), + ir.IntegerAttr.get(tuner_ctx.type.i32, 8), + ir.IntegerAttr.get(tuner_ctx.type.i32, 8), + ] + ), + "reduction": ir.ArrayAttr.get([]), + "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 16), + "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 16), + } + + lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) config = common.Configuration( subgroup_size=16, workgroup_size=[16, 16, 1], - intrinsic=mma_attr, - tile_sizes=[8, 8, 8], - subgroup_m_count=16, - subgroup_n_count=16, + lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get( prefetch_shared_memory=True ), @@ -104,13 +116,24 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config_dict = { + "mma_kind": mma_attr, + "workgroup": ir.ArrayAttr.get( + [ + ir.IntegerAttr.get(tuner_ctx.type.i32, 464), + ir.IntegerAttr.get(tuner_ctx.type.i32, 320), + ir.IntegerAttr.get(tuner_ctx.type.i32, 16), + ] + ), + "reduction": ir.ArrayAttr.get([]), + "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1), + "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 4), + } + lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) config = common.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - intrinsic=mma_attr, - tile_sizes=[464, 320, 16], - subgroup_m_count=1, - subgroup_n_count=4, + lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get( reorder_workgroups_strategy=iree_gpu.ReorderWorkgroupsStrategyAttr.get( iree_gpu.ReorderWorkgroupsStrategy.Transpose @@ -171,13 +194,25 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config_dict = { + "mma_kind": mma_attr, + "workgroup": ir.ArrayAttr.get( + [ + ir.IntegerAttr.get(tuner_ctx.type.i32, 480), + ir.IntegerAttr.get(tuner_ctx.type.i32, 384), + ir.IntegerAttr.get(tuner_ctx.type.i32, 32), + ] + ), + "reduction": ir.ArrayAttr.get([]), + "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1), + "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 4), + } + + lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) config = common.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - intrinsic=mma_attr, - tile_sizes=[480, 384, 32], - subgroup_m_count=1, - subgroup_n_count=4, + lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=2, ) @@ -220,13 +255,26 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config_dict = { + "mma_kind": mma_attr, + "workgroup": ir.ArrayAttr.get( + [ + ir.IntegerAttr.get(tuner_ctx.type.i32, 416), + ir.IntegerAttr.get(tuner_ctx.type.i32, 320), + ir.IntegerAttr.get(tuner_ctx.type.i32, 128), + ] + ), + "reduction": ir.ArrayAttr.get([]), + "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2), + "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2), + } + + lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) + config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=mma_attr, - tile_sizes=[416, 320, 128], - subgroup_m_count=2, - subgroup_n_count=2, + lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=2, ) @@ -272,13 +320,25 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config_dict = { + "mma_kind": mma_attr, + "workgroup": ir.ArrayAttr.get( + [ + ir.IntegerAttr.get(tuner_ctx.type.i32, 128), + ir.IntegerAttr.get(tuner_ctx.type.i32, 64), + ir.IntegerAttr.get(tuner_ctx.type.i32, 128), + ] + ), + "reduction": ir.ArrayAttr.get([]), + "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2), + "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2), + } + + lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=mma_attr, - tile_sizes=[128, 64, 128], - subgroup_m_count=2, - subgroup_n_count=2, + lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=2, ) @@ -322,13 +382,25 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config_dict = { + "mma_kind": mma_attr, + "workgroup": ir.ArrayAttr.get( + [ + ir.IntegerAttr.get(tuner_ctx.type.i32, 128), + ir.IntegerAttr.get(tuner_ctx.type.i32, 64), + ir.IntegerAttr.get(tuner_ctx.type.i32, 128), + ] + ), + "reduction": ir.ArrayAttr.get([]), + "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2), + "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2), + } + + lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=mma_attr, - tile_sizes=[128, 64, 128], - subgroup_m_count=2, - subgroup_n_count=2, + lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=4, ) @@ -395,13 +467,25 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config_dict = { + "mma_kind": mma_attr, + "workgroup": ir.ArrayAttr.get( + [ + ir.IntegerAttr.get(tuner_ctx.type.i32, 128), + ir.IntegerAttr.get(tuner_ctx.type.i32, 64), + ir.IntegerAttr.get(tuner_ctx.type.i32, 128), + ] + ), + "reduction": ir.ArrayAttr.get([]), + "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2), + "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2), + } + + lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=mma_attr, - tile_sizes=[128, 64, 128], - subgroup_m_count=2, - subgroup_n_count=2, + lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=4, ) diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index 80c755aa7..ba79f3be6 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -105,26 +105,34 @@ def is_comptible(mma_intrinsic: iree_gpu.MMAIntrinsic) -> bool: return list(filter(is_comptible, mma_intrinsics)) -class ReorderWorkgroupsStrategy(Enum): - NONE = 0 - SWIZZLE = 1 - TRANSPOSE = 2 - - def __str__(self) -> str: - return self.name.title() - - @dataclass class Configuration: subgroup_size: int workgroup_size: list[int] - intrinsic: iree_gpu.MMAAttr - tile_sizes: list[int] - subgroup_m_count: int - subgroup_n_count: int + lowering_config: iree_gpu.LoweringConfigAttr gpu_pipeline_options: iree_gpu.PipelineOptionsAttr waves_per_eu: int + @property + def intrinsic(self) -> iree_gpu.MMAAttr: + return self.lowering_config.attributes["mma_kind"] + + @property + def tilesize_workgroup(self) -> list[int]: + return [attr.value for attr in self.lowering_config.attributes["workgroup"]] + + @property + def tilesize_reduction(self) -> list[int]: + return [attr.value for attr in self.lowering_config.attributes["reduction"]] + + @property + def subgroup_m_count(self) -> int: + return self.lowering_config.attributes["subgroup_m_count"].value + + @property + def subgroup_n_count(self) -> int: + return self.lowering_config.attributes["subgroup_n_count"].value + def get_pipeline_config(configuration: Configuration) -> str: extra_config = "" diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index 73d3f04e3..bbb241980 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -73,16 +73,27 @@ def test_gpu_pipeline_options(tuner_ctx: common.TunerContext) -> None: ) -def test_get_pipeline_config(mlir_ctx: ir.Context) -> None: +def test_get_pipeline_config(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config_dict = { + "mma_kind": mma_attr, + "workgroup": ir.ArrayAttr.get( + [ + ir.IntegerAttr.get(tuner_ctx.type.i32, 4), + ir.IntegerAttr.get(tuner_ctx.type.i32, 8), + ir.IntegerAttr.get(tuner_ctx.type.i32, 16), + ] + ), + "reduction": ir.ArrayAttr.get([]), + "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1), + "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1), + } + lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) config = common.Configuration( subgroup_size=32, workgroup_size=[16, 16, 1], - intrinsic=mma_attr, - tile_sizes=[4, 8, 16], - subgroup_m_count=1, - subgroup_n_count=1, + lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=2, ) diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index cdfb1bd50..5a775bdcd 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -10,8 +10,10 @@ import z3 # type: ignore from typing import Iterator +from iree.compiler import ir # type: ignore from iree.compiler.dialects import iree_gpu # type: ignore +from iree.compiler.dialects import iree_codegen # type: ignore from .common import * @@ -217,15 +219,15 @@ def generate_solutions( ) solver.add(z3.simplify(z3.And(constraints))) logger.debug(f"Initial constraints: {solver}") + + int_type = ir.IntegerType.get_signless(32) + i = 0 while solver.check() == z3.sat: model = solver.model() lookup = lambda var: model[var].as_long() - - config = Configuration( - lookup(subgroup_size), - [lookup(wg_x), lookup(wg_y), lookup(wg_z)], - getMMAAttr( + lowering_config_dict = { + "mma_kind": getMMAAttr( problem_size.res_type.element_type, lookup(intrinsic_mn), lookup(intrinsic_mn), @@ -233,9 +235,27 @@ def generate_solutions( problem_size.lhs_type.element_type, problem_size.rhs_type.element_type, ), - [lookup(m), lookup(n), lookup(k)], - lookup(sg_m_cnt), - lookup(sg_n_cnt), + "workgroup": ir.ArrayAttr.get( + [ + ir.IntegerAttr.get(int_type, lookup(m)), + ir.IntegerAttr.get(int_type, lookup(n)), + ir.IntegerAttr.get(int_type, lookup(k)), + ] + ), + "reduction": ir.ArrayAttr.get( + [] + ), # placeholder now to be consistent with iree + "subgroup_m_count": ir.IntegerAttr.get(int_type, lookup(sg_m_cnt)), + "subgroup_n_count": ir.IntegerAttr.get(int_type, lookup(sg_n_cnt)), + } + + lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) + lowering_config = iree_gpu.LoweringConfigAttr.get(lowering_config_attrs) + + config = Configuration( + lookup(subgroup_size), + [lookup(wg_x), lookup(wg_y), lookup(wg_z)], + lowering_config, iree_gpu.PipelineOptionsAttr.get(), lookup(waves_per_eu), ) diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py index c4b4b9ad5..421e2c7ef 100644 --- a/tuner/tuner/dispatch_parser.py +++ b/tuner/tuner/dispatch_parser.py @@ -21,11 +21,11 @@ def parse_tensor_type(tensor_type: str) -> ShapedType: def get_mmt_tile_sizes(configuration: Configuration): - return configuration.tile_sizes + return configuration.tilesize_workgroup def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> list[int]: - m, n, k = configuration.tile_sizes + m, n, k = configuration.tilesize_workgroup tile_size = [1] * len(tile_dims) for idx, dim in enumerate(tile_dims): if dim == "m": @@ -38,7 +38,7 @@ def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> lis def get_batch_mmt_tile_sizes(configuration: Configuration) -> list[int]: - return [1] + configuration.tile_sizes + return [1] + configuration.tilesize_workgroup class MlirRegex(Enum): @@ -141,7 +141,7 @@ def supports(self, op_name: str) -> bool: return "conv_2d_nhwc_hwcf" in op_name def get_conv_tile_sizes(self, configuration: Configuration) -> list[int]: - m, n, k = configuration.tile_sizes + m, n, k = configuration.tilesize_workgroup batch = 1 fh = 1 fw = 1 diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py index 529559f83..8318ca9c1 100644 --- a/tuner/tuner/dispatch_parser_test.py +++ b/tuner/tuner/dispatch_parser_test.py @@ -42,13 +42,24 @@ def test_parse_tensor_type(tuner_ctx: common.TunerContext) -> None: def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config_dict = { + "mma_kind": mma_attr, + "workgroup": ir.ArrayAttr.get( + [ + ir.IntegerAttr.get(tuner_ctx.type.i32, 128), + ir.IntegerAttr.get(tuner_ctx.type.i32, 320), + ir.IntegerAttr.get(tuner_ctx.type.i32, 32), + ] + ), + "reduction": ir.ArrayAttr.get([]), + "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 0), + "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 0), + } + lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) config = dispatch_parser.Configuration( subgroup_size=0, workgroup_size=[], - intrinsic=mma_attr, - tile_sizes=[128, 320, 32], - subgroup_m_count=0, - subgroup_n_count=0, + lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=0, ) @@ -58,13 +69,24 @@ def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None: def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config_dict = { + "mma_kind": mma_attr, + "workgroup": ir.ArrayAttr.get( + [ + ir.IntegerAttr.get(tuner_ctx.type.i32, 464), + ir.IntegerAttr.get(tuner_ctx.type.i32, 320), + ir.IntegerAttr.get(tuner_ctx.type.i32, 16), + ] + ), + "reduction": ir.ArrayAttr.get([]), + "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1), + "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 4), + } + lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) config = dispatch_parser.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - intrinsic=mma_attr, - tile_sizes=[464, 320, 16], - subgroup_m_count=1, - subgroup_n_count=4, + lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=1, ) @@ -82,13 +104,24 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None: def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config_dict = { + "mma_kind": mma_attr, + "workgroup": ir.ArrayAttr.get( + [ + ir.IntegerAttr.get(tuner_ctx.type.i32, 4), + ir.IntegerAttr.get(tuner_ctx.type.i32, 8), + ir.IntegerAttr.get(tuner_ctx.type.i32, 16), + ] + ), + "reduction": ir.ArrayAttr.get([]), + "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1), + "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1), + } + lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) config = dispatch_parser.Configuration( subgroup_size=32, workgroup_size=[16, 16, 1], - intrinsic=mma_attr, - tile_sizes=[4, 8, 16], - subgroup_m_count=1, - subgroup_n_count=1, + lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=2, ) From 005278a973e7f98f23a6c6ec302e1fbe684bb183 Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Fri, 29 Nov 2024 02:47:47 -0600 Subject: [PATCH 2/9] [tuner]: address comments Signed-off-by: Bangtian Liu --- tuner/tuner/candidate_gen.py | 137 ++++++++++------- tuner/tuner/candidate_gen_test.py | 218 ++++++++++++---------------- tuner/tuner/common.py | 79 ++++++++-- tuner/tuner/common_test.py | 45 ++++-- tuner/tuner/dispatch_parser.py | 69 ++++++--- tuner/tuner/dispatch_parser_test.py | 102 ++++++------- 6 files changed, 364 insertions(+), 286 deletions(-) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 01a6ed2aa..a560a9c70 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -40,11 +40,14 @@ def apply_configuration( - template: list[str], configuration: Configuration, tile_sizes: list[int] + template: list[str], + configuration: Configuration, + workgroup_sizes: list[int], + reduction_sizes: list[int], ) -> str: - intrinsic = configuration.intrinsic - subgroup_m_count = configuration.subgroup_m_count - subgroup_n_count = configuration.subgroup_n_count + intrinsic = configuration.intrinsic() + subgroup_m_count = configuration.subgroup_m_count() + subgroup_n_count = configuration.subgroup_n_count() tune_logger.info(f"Applying: {configuration}") expr0 = re.compile( r", subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>" @@ -52,14 +55,16 @@ def apply_configuration( expr1 = re.compile( r"LLVMGPUVectorDistribute workgroup_size = \[.+\] subgroup_size = ([0-9]+)," ) - expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]") - expr3 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>") - expr4 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") + expr2 = re.compile(r"workgroup = \[\[([0-9]+)(, ([0-9]+))+\]\]") + expr3 = re.compile(r"reduction = \[\[([0-9]+)(, ([0-9]+))+\]\]") + expr4 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>") + expr5 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") repl0 = f"" repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},' - repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]' - repl3 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}" - repl4 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"' + repl2 = f'workgroup = [[{", ".join(map(str, workgroup_sizes))}]]' + repl3 = f'reduction = [[{", ".join(map(str, reduction_sizes))}]]' + repl4 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}" + repl5 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"' new_mlir = "" for line in template: @@ -67,12 +72,14 @@ def apply_configuration( line = re.sub(expr0, repl0, line) if "LLVMGPUVectorDistribute " in line: line = re.sub(expr1, repl1, line) - if "tile_sizes" in line: + if "workgroup" in line: line = re.sub(expr2, repl2, line) - if "gpu_pipeline_options =" in line: + if "reduction" in line: line = re.sub(expr3, repl3, line) - if "amdgpu-waves-per-eu" in line: + if "gpu_pipeline_options =" in line: line = re.sub(expr4, repl4, line) + if "amdgpu-waves-per-eu" in line: + line = re.sub(expr5, repl5, line) new_mlir += line return new_mlir @@ -118,10 +125,11 @@ class MmtTuner(DispatchTuner, MmtParser): def get_transform_function_mmt( self, problem_size: ProblemSize, functionName: str, configuration: Configuration ) -> str: - tile_sizes = ", ".join(map(str, get_mmt_tile_sizes(configuration))) - intrinsic = configuration.intrinsic - subgroup_m_count = configuration.subgroup_m_count - subgroup_n_count = configuration.subgroup_n_count + workgroup_sizes = ", ".join(map(str, get_mmt_workgroup_sizes(configuration))) + reduction_sizes = ", ".join(map(str, get_mmt_reduction_sizes(configuration))) + intrinsic = configuration.intrinsic() + subgroup_m_count = configuration.subgroup_m_count() + subgroup_n_count = configuration.subgroup_n_count() wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) @@ -133,7 +141,7 @@ def get_transform_function_mmt( transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, + lowering_config = #iree_codegen.lowering_config, translation_info = #iree_codegen.translation_info str: @@ -191,10 +195,15 @@ def get_transform_function_conv( filter = f"tensor<{problem_size.rhs_type}>" output = f"tensor<{dynamic_batch_output_ty}>" - tile_sizes = ", ".join(map(str, self.get_conv_tile_sizes(configuration))) - intrinsic = configuration.intrinsic - subgroup_m_count = configuration.subgroup_m_count - subgroup_n_count = configuration.subgroup_n_count + workgroup_sizes = ", ".join( + map(str, self.get_conv_workgroup_sizes(configuration)) + ) + reduction_sizes = ", ".join( + map(str, self.get_conv_reduction_sizes(configuration)) + ) + intrinsic = configuration.intrinsic() + subgroup_m_count = configuration.subgroup_m_count() + subgroup_n_count = configuration.subgroup_n_count() wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) @@ -209,7 +218,7 @@ def get_transform_function_conv( outs(%out : {output}) -> {output} }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, + lowering_config = #iree_codegen.lowering_config, translation_info = #iree_codegen.translation_info str: - tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration))) - intrinsic = configuration.intrinsic - subgroup_m_count = configuration.subgroup_m_count - subgroup_n_count = configuration.subgroup_n_count + workgroup_sizes = ", ".join( + map(str, get_batch_mmt_workgroup_sizes(configuration)) + ) + reduction_sizes = ", ".join( + map(str, get_batch_mmt_reduction_sizes(configuration)) + ) + intrinsic = configuration.intrinsic() + subgroup_m_count = configuration.subgroup_m_count() + subgroup_n_count = configuration.subgroup_n_count() wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) @@ -273,7 +290,7 @@ def get_transform_function_broadcast_rhs_mmt( transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, + lowering_config = #iree_codegen.lowering_config, translation_info = #iree_codegen.translation_info str: - tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration))) - intrinsic = configuration.intrinsic - subgroup_m_count = configuration.subgroup_m_count - subgroup_n_count = configuration.subgroup_n_count + workgroup_sizes = ", ".join( + map(str, get_batch_mmt_workgroup_sizes(configuration)) + ) + reduction_sizes = ", ".join( + map(str, get_batch_mmt_reduction_sizes(configuration)) + ) + intrinsic = configuration.intrinsic() + subgroup_m_count = configuration.subgroup_m_count() + subgroup_n_count = configuration.subgroup_n_count() wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) @@ -356,7 +382,7 @@ def get_transform_function_batch_mmt( transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, + lowering_config = #iree_codegen.lowering_config, translation_info = #iree_codegen.translation_info" output = f"tensor<{problem_size.res_type}>" - tile_sizes = ", ".join( - map(str, get_contract_tile_sizes(configuration, tile_dims)) + workgroup_sizes = ", ".join( + map(str, get_contract_workgroup_sizes(configuration, tile_dims)) + ) + reduction_sizes = ", ".join( + map(str, get_contract_reduction_sizes(configuration, tile_dims)) ) - intrinsic = configuration.intrinsic - subgroup_m_count = configuration.subgroup_m_count - subgroup_n_count = configuration.subgroup_n_count + intrinsic = configuration.intrinsic() + subgroup_m_count = configuration.subgroup_m_count() + subgroup_n_count = configuration.subgroup_n_count() wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) @@ -427,7 +459,7 @@ def get_transform_function_batch_matmul( outs(%out : {output}) -> {output} }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, + lowering_config = #iree_codegen.lowering_config, translation_info = #iree_codegen.translation_info None: mlir_template = [ ", subgroup_m_count = 16, subgroup_n_count = 16>", "", + "", "gpu_pipeline_options = #iree_gpu.pipeline_options", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', ] @@ -48,25 +48,18 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) - lowering_config_dict = { - "mma_kind": mma_attr, - "workgroup": ir.ArrayAttr.get( - [ - ir.IntegerAttr.get(tuner_ctx.type.i32, 8), - ir.IntegerAttr.get(tuner_ctx.type.i32, 8), - ir.IntegerAttr.get(tuner_ctx.type.i32, 8), - ] - ), - "reduction": ir.ArrayAttr.get([]), - "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 16), - "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 16), - } - - lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_attr=mma_attr, + workgroup=[8, 8, 0], + reduction=[0, 0, 8], + subgroup_m_count=16, + subgroup_n_count=16, + ) config = common.Configuration( subgroup_size=16, workgroup_size=[16, 16, 1], - lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get( prefetch_shared_memory=True ), @@ -96,7 +89,8 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [16, 16, 1] subgroup_size = 16" in modified ) - assert "tile_sizes = [[8, 8, 8]]" in modified + assert "workgroup = [[8, 8, 0]]" in modified + assert "reduction = [[0, 0, 8]]" in modified assert ( "gpu_pipeline_options = #iree_gpu.pipeline_options" in modified @@ -108,7 +102,7 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 16, subgroup_n_count = 16>", "", + "", 'gpu_pipeline_options = #iree_gpu.pipeline_options, {llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', ] @@ -116,24 +110,18 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) - lowering_config_dict = { - "mma_kind": mma_attr, - "workgroup": ir.ArrayAttr.get( - [ - ir.IntegerAttr.get(tuner_ctx.type.i32, 464), - ir.IntegerAttr.get(tuner_ctx.type.i32, 320), - ir.IntegerAttr.get(tuner_ctx.type.i32, 16), - ] - ), - "reduction": ir.ArrayAttr.get([]), - "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1), - "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 4), - } - lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_attr=mma_attr, + workgroup=[464, 320, 0], + reduction=[0, 0, 16], + subgroup_m_count=1, + subgroup_n_count=4, + ) config = common.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get( reorder_workgroups_strategy=iree_gpu.ReorderWorkgroupsStrategyAttr.get( iree_gpu.ReorderWorkgroupsStrategy.Transpose @@ -167,7 +155,8 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64" in modified ) - assert "tile_sizes = [[1, 1, 464, 320, 1, 1, 16]]" in modified + assert "workgroup = [[1, 1, 464, 320, 1, 1, 0]]" in modified + assert "reduction = [[0, 0, 0, 0, 0, 0, 16]]" in modified assert ( "gpu_pipeline_options = #iree_gpu.pipeline_options>" in modified @@ -179,7 +168,7 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 2, subgroup_n_count = 2>}>", "", + "", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] @@ -194,25 +183,18 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) - lowering_config_dict = { - "mma_kind": mma_attr, - "workgroup": ir.ArrayAttr.get( - [ - ir.IntegerAttr.get(tuner_ctx.type.i32, 480), - ir.IntegerAttr.get(tuner_ctx.type.i32, 384), - ir.IntegerAttr.get(tuner_ctx.type.i32, 32), - ] - ), - "reduction": ir.ArrayAttr.get([]), - "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1), - "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 4), - } - - lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_attr=mma_attr, + workgroup=[480, 384, 0], + reduction=[0, 0, 32], + subgroup_m_count=1, + subgroup_n_count=4, + ) config = common.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=2, ) @@ -232,7 +214,8 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64" in new_mlir ) - assert "tile_sizes = [[1, 480, 384, 32]]" in new_mlir + assert "workgroup = [[1, 480, 384, 0]]" in new_mlir + assert "reduction = [[0, 0, 0, 32]]" in new_mlir assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in new_mlir @@ -240,7 +223,7 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", "", + "", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] @@ -255,26 +238,18 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) - lowering_config_dict = { - "mma_kind": mma_attr, - "workgroup": ir.ArrayAttr.get( - [ - ir.IntegerAttr.get(tuner_ctx.type.i32, 416), - ir.IntegerAttr.get(tuner_ctx.type.i32, 320), - ir.IntegerAttr.get(tuner_ctx.type.i32, 128), - ] - ), - "reduction": ir.ArrayAttr.get([]), - "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2), - "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2), - } - - lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) - + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_attr=mma_attr, + workgroup=[416, 320, 0], + reduction=[0, 0, 128], + subgroup_m_count=2, + subgroup_n_count=2, + ) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=2, ) @@ -298,7 +273,8 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" in modified ) - assert "tile_sizes = [[1, 416, 320, 128]]" in modified + assert "workgroup = [[1, 416, 320, 0]]" in modified + assert "reduction = [[0, 0, 0, 128]]" in modified assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified @@ -306,7 +282,7 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", "", + "", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] @@ -320,25 +296,18 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) - lowering_config_dict = { - "mma_kind": mma_attr, - "workgroup": ir.ArrayAttr.get( - [ - ir.IntegerAttr.get(tuner_ctx.type.i32, 128), - ir.IntegerAttr.get(tuner_ctx.type.i32, 64), - ir.IntegerAttr.get(tuner_ctx.type.i32, 128), - ] - ), - "reduction": ir.ArrayAttr.get([]), - "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2), - "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2), - } - - lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_attr=mma_attr, + workgroup=[128, 64, 0], + reduction=[0, 0, 128], + subgroup_m_count=2, + subgroup_n_count=2, + ) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=2, ) @@ -360,7 +329,8 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" in modified ) - assert "tile_sizes = [[1, 128, 64, 128]]" in modified + assert "workgroup = [[1, 128, 64, 0]]" in modified + assert "reduction = [[0, 0, 0, 128]]" in modified assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified @@ -368,7 +338,7 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", "", + "", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] @@ -382,25 +352,18 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) - lowering_config_dict = { - "mma_kind": mma_attr, - "workgroup": ir.ArrayAttr.get( - [ - ir.IntegerAttr.get(tuner_ctx.type.i32, 128), - ir.IntegerAttr.get(tuner_ctx.type.i32, 64), - ir.IntegerAttr.get(tuner_ctx.type.i32, 128), - ] - ), - "reduction": ir.ArrayAttr.get([]), - "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2), - "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2), - } - - lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_attr=mma_attr, + workgroup=[128, 64, 0], + reduction=[0, 0, 128], + subgroup_m_count=2, + subgroup_n_count=2, + ) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=4, ) @@ -424,7 +387,8 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" in modified ) - assert "tile_sizes = [[1, 128, 64, 128]]" in modified + assert "workgroup = [[1, 128, 64, 0]]" in modified + assert "reduction = [[0, 0, 0, 128]]" in modified assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified assert embeddable @@ -444,7 +408,8 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: "%config = transform.param.constant #iree_codegen.compilation_info<" in embeddable ) - assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable + assert "workgroup = [[1, 128, 64, 0]]" in embeddable + assert "reduction = [[0, 0, 0, 128]]" in embeddable assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable @@ -453,7 +418,7 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", "", + "", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] @@ -467,25 +432,18 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) - lowering_config_dict = { - "mma_kind": mma_attr, - "workgroup": ir.ArrayAttr.get( - [ - ir.IntegerAttr.get(tuner_ctx.type.i32, 128), - ir.IntegerAttr.get(tuner_ctx.type.i32, 64), - ir.IntegerAttr.get(tuner_ctx.type.i32, 128), - ] - ), - "reduction": ir.ArrayAttr.get([]), - "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2), - "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2), - } - - lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_attr=mma_attr, + workgroup=[128, 64, 0], + reduction=[0, 0, 128], + subgroup_m_count=2, + subgroup_n_count=2, + ) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=4, ) @@ -512,7 +470,8 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" in modified ) - assert "tile_sizes = [[1, 128, 64, 128]]" in modified + assert "workgroup = [[1, 128, 64, 0]]" in modified + assert "reduction = [[0, 0, 0, 128]]" in modified assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified assert embeddable @@ -533,7 +492,8 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: "%config = transform.param.constant #iree_codegen.compilation_info<" in embeddable ) - assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable + assert "workgroup = [[1, 128, 64, 0]]" in embeddable + assert "reduction = [[0, 0, 0, 128]]" in embeddable assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable @@ -541,7 +501,7 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: def test_detect_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: mlir_lines = [ r"%18 = tensor.empty() : tensor<2x1024x10240xi32>", - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>", + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>", r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', ] assert candidate_gen.ContractionTuner("mk", "nk", "mnk").is_broadcast_rhs_mmt( diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index ba79f3be6..a6fd4dd42 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -30,6 +30,9 @@ def __init__(self, ctx: ir.Context): self.bf16 = ir.BF16Type.get(ctx) + def getI32(self, value: int) -> ir.IntegerAttr: + return ir.IntegerAttr.get(self.i32, value) + class TunerContext: def __init__(self, mlir_ctx: ir.Context, logger: logging.Logger): @@ -113,25 +116,71 @@ class Configuration: gpu_pipeline_options: iree_gpu.PipelineOptionsAttr waves_per_eu: int - @property - def intrinsic(self) -> iree_gpu.MMAAttr: - return self.lowering_config.attributes["mma_kind"] + def intrinsic(self) -> Optional[iree_gpu.MMAAttr]: + if self.lowering_config.attributes.__contains__("mma_kind"): + return self.lowering_config.attributes.__getitem__("mma_kind") + return None - @property def tilesize_workgroup(self) -> list[int]: - return [attr.value for attr in self.lowering_config.attributes["workgroup"]] + if self.lowering_config.attributes.__contains__("workgroup"): + workgroup_attrs = self.lowering_config.attributes.__getitem__("workgroup") + return [attr.value for attr in workgroup_attrs] + return [] - @property def tilesize_reduction(self) -> list[int]: - return [attr.value for attr in self.lowering_config.attributes["reduction"]] - - @property - def subgroup_m_count(self) -> int: - return self.lowering_config.attributes["subgroup_m_count"].value - - @property - def subgroup_n_count(self) -> int: - return self.lowering_config.attributes["subgroup_n_count"].value + if self.lowering_config.attributes.__contains__("reduction"): + reduction_attrs = self.lowering_config.attributes.__getitem__("reduction") + return [attr.value for attr in reduction_attrs] + return [] + + def subgroup_m_count(self) -> Optional[int]: + if self.lowering_config.attributes.__contains__("subgroup_m_count"): + attr = self.lowering_config.attributes.__getitem__("subgroup_m_count") + return attr.value + return None + + def subgroup_n_count(self) -> Optional[int]: + if self.lowering_config.attributes.__contains__("subgroup_n_count"): + attr = self.lowering_config.attributes.__getitem__("subgroup_n_count") + return attr.value + return None + + +def get_lowering_config( + tuner_ctx: TunerContext, + mma_attr: Optional[iree_gpu.MMAAttr] = None, + workgroup: Optional[list[int]] = None, + reduction: Optional[list[int]] = None, + subgroup_m_count: Optional[int] = None, + subgroup_n_count: Optional[int] = None, +) -> iree_gpu.LoweringConfigAttr: + lowering_config_dict = {} + if workgroup is not None: + lowering_config_dict["workgroup"] = ir.ArrayAttr.get( + [tuner_ctx.type.getI32(x) for x in workgroup] + ) + if reduction is not None: + lowering_config_dict["reduction"] = ir.ArrayAttr.get( + [tuner_ctx.type.getI32(x) for x in reduction] + ) + if subgroup_m_count is not None: + lowering_config_dict["subgroup_m_count"] = tuner_ctx.type.getI32( + subgroup_m_count + ) + if subgroup_n_count is not None: + lowering_config_dict["subgroup_n_count"] = tuner_ctx.type.getI32( + subgroup_n_count + ) + # lowering_config_dict = { + # "workgroup": ir.ArrayAttr.get([tuner_ctx.type.getI32(x) for x in workgroup]), + # "reduction": ir.ArrayAttr.get([tuner_ctx.type.getI32(x) for x in reduction]), + # "subgroup_m_count": tuner_ctx.type.getI32(subgroup_m_count), + # "subgroup_n_count": tuner_ctx.type.getI32(subgroup_n_count), + # } + if mma_attr is not None: + lowering_config_dict["mma_kind"] = mma_attr + lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) + return iree_gpu.LoweringConfigAttr.get(lowering_config_attrs) def get_pipeline_config(configuration: Configuration) -> str: diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index bbb241980..056224458 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -76,24 +76,18 @@ def test_gpu_pipeline_options(tuner_ctx: common.TunerContext) -> None: def test_get_pipeline_config(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) - lowering_config_dict = { - "mma_kind": mma_attr, - "workgroup": ir.ArrayAttr.get( - [ - ir.IntegerAttr.get(tuner_ctx.type.i32, 4), - ir.IntegerAttr.get(tuner_ctx.type.i32, 8), - ir.IntegerAttr.get(tuner_ctx.type.i32, 16), - ] - ), - "reduction": ir.ArrayAttr.get([]), - "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1), - "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1), - } - lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_attr=mma_attr, + workgroup=[4, 8, 0], + reduction=[0, 0, 16], + subgroup_m_count=1, + subgroup_n_count=1, + ) config = common.Configuration( subgroup_size=32, workgroup_size=[16, 16, 1], - lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=2, ) @@ -197,3 +191,24 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: ) == [] ) + + +def test_get_lowering_config(tuner_ctx: common.TunerContext) -> None: + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + workgroup=[4, 8, 0], + reduction=[0, 0, 16], + subgroup_m_count=1, + subgroup_n_count=1, + ) + config = common.Configuration( + subgroup_size=32, + workgroup_size=[16, 16, 1], + lowering_config=lowering_config, + gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), + waves_per_eu=2, + ) + + assert config.intrinsic() is None + assert config.subgroup_m_count() == 1 + assert config.subgroup_n_count() == 1 diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py index 421e2c7ef..dd36f5075 100644 --- a/tuner/tuner/dispatch_parser.py +++ b/tuner/tuner/dispatch_parser.py @@ -20,25 +20,49 @@ def parse_tensor_type(tensor_type: str) -> ShapedType: return ShapedType(shaped_ty.shape, shaped_ty.element_type) -def get_mmt_tile_sizes(configuration: Configuration): - return configuration.tilesize_workgroup +def get_mmt_workgroup_sizes(configuration: Configuration): + return configuration.tilesize_workgroup() -def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> list[int]: - m, n, k = configuration.tilesize_workgroup - tile_size = [1] * len(tile_dims) +def get_mmt_reduction_sizes(configuration: Configuration): + return configuration.tilesize_reduction() + + +def get_contract_workgroup_sizes( + configuration: Configuration, tile_dims: str +) -> list[int]: + m, n, _ = configuration.tilesize_workgroup() + + workgroup_size = [1] * len(tile_dims) for idx, dim in enumerate(tile_dims): if dim == "m": - tile_size[idx] = m + workgroup_size[idx] = m if dim == "n": - tile_size[idx] = n + workgroup_size[idx] = n + if dim == "k": + workgroup_size[idx] = 0 + + return workgroup_size + + +def get_contract_reduction_sizes( + configuration: Configuration, tile_dims: str +) -> list[int]: + _, _, k = configuration.tilesize_reduction() + reduction_size = [0] * len(tile_dims) + for idx, dim in enumerate(tile_dims): if dim == "k": - tile_size[idx] = k - return tile_size + reduction_size[idx] = k + + return reduction_size + + +def get_batch_mmt_workgroup_sizes(configuration: Configuration) -> list[int]: + return [1] + configuration.tilesize_workgroup() -def get_batch_mmt_tile_sizes(configuration: Configuration) -> list[int]: - return [1] + configuration.tilesize_workgroup +def get_batch_mmt_reduction_sizes(configuration: Configuration) -> list[int]: + return [0] + configuration.tilesize_reduction() class MlirRegex(Enum): @@ -140,18 +164,22 @@ class ConvParser(DispatchParser): def supports(self, op_name: str) -> bool: return "conv_2d_nhwc_hwcf" in op_name - def get_conv_tile_sizes(self, configuration: Configuration) -> list[int]: - m, n, k = configuration.tilesize_workgroup + def get_conv_workgroup_sizes(self, configuration: Configuration) -> list[int]: batch = 1 fh = 1 fw = 1 oh = 1 - oc = n - ow = m - ic = k - return [batch, oh, ow, oc, fh, fw, ic] + # oc = configuration.tilesize_workgroup()[1] + ow, oc, _ = configuration.tilesize_workgroup() + + return [batch, oh, ow, oc, fh, fw, 0] + + def get_conv_reduction_sizes(self, configuration: Configuration) -> list[int]: + _, _, ic = configuration.tilesize_reduction() + + return [0, 0, 0, 0, 0, 0, ic] def get_shapes(self, template: list[str]) -> ProblemSize: for line in template: @@ -178,13 +206,6 @@ def get_shapes(self, template: list[str]) -> ProblemSize: res_shaped_type = parse_tensor_type(res_tensor_type) assert res_shaped_type.rank() == 4 - # int64_t n = outputShape[0]; - # int64_t oh = outputShape[1]; - # int64_t ow = outputShape[2]; - # int64_t oc = outputShape[3]; - # int64_t fh = filterShape[0]; - # int64_t fw = filterShape[1]; - # int64_t ic = filterShape[2]; dim_info = ConvDimInfo.from_rhs_res(rhs_shaped_type, res_shaped_type) return ProblemSize( MatmulSize( diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py index 8318ca9c1..8e99188d0 100644 --- a/tuner/tuner/dispatch_parser_test.py +++ b/tuner/tuner/dispatch_parser_test.py @@ -42,61 +42,59 @@ def test_parse_tensor_type(tuner_ctx: common.TunerContext) -> None: def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) - lowering_config_dict = { - "mma_kind": mma_attr, - "workgroup": ir.ArrayAttr.get( - [ - ir.IntegerAttr.get(tuner_ctx.type.i32, 128), - ir.IntegerAttr.get(tuner_ctx.type.i32, 320), - ir.IntegerAttr.get(tuner_ctx.type.i32, 32), - ] - ), - "reduction": ir.ArrayAttr.get([]), - "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 0), - "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 0), - } - lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_attr=mma_attr, + workgroup=[128, 320, 0], + reduction=[0, 0, 32], + subgroup_m_count=1, + subgroup_n_count=4, + ) config = dispatch_parser.Configuration( subgroup_size=0, workgroup_size=[], - lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=0, ) - assert dispatch_parser.get_mmt_tile_sizes(config) == [128, 320, 32] + assert dispatch_parser.get_mmt_workgroup_sizes(config) == [128, 320, 0] + assert dispatch_parser.get_mmt_reduction_sizes(config) == [0, 0, 32] def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) - lowering_config_dict = { - "mma_kind": mma_attr, - "workgroup": ir.ArrayAttr.get( - [ - ir.IntegerAttr.get(tuner_ctx.type.i32, 464), - ir.IntegerAttr.get(tuner_ctx.type.i32, 320), - ir.IntegerAttr.get(tuner_ctx.type.i32, 16), - ] - ), - "reduction": ir.ArrayAttr.get([]), - "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1), - "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 4), - } - lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_attr=mma_attr, + workgroup=[464, 320, 0], + reduction=[0, 0, 16], + subgroup_m_count=1, + subgroup_n_count=4, + ) config = dispatch_parser.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=1, ) - assert dispatch_parser.ConvParser().get_conv_tile_sizes(config) == [ + assert dispatch_parser.ConvParser().get_conv_workgroup_sizes(config) == [ 1, 1, 464, 320, 1, 1, + 0, + ] + assert dispatch_parser.ConvParser().get_conv_reduction_sizes(config) == [ + 0, + 0, + 0, + 0, + 0, + 0, 16, ] @@ -104,31 +102,33 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None: def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) - lowering_config_dict = { - "mma_kind": mma_attr, - "workgroup": ir.ArrayAttr.get( - [ - ir.IntegerAttr.get(tuner_ctx.type.i32, 4), - ir.IntegerAttr.get(tuner_ctx.type.i32, 8), - ir.IntegerAttr.get(tuner_ctx.type.i32, 16), - ] - ), - "reduction": ir.ArrayAttr.get([]), - "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1), - "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1), - } - lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_attr=mma_attr, + workgroup=[4, 8, 0], + reduction=[0, 0, 16], + subgroup_m_count=1, + subgroup_n_count=1, + ) config = dispatch_parser.Configuration( subgroup_size=32, workgroup_size=[16, 16, 1], - lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=2, ) - assert dispatch_parser.get_contract_tile_sizes(config, "mnk") == [4, 8, 16] - assert dispatch_parser.get_contract_tile_sizes(config, "nmk") == [8, 4, 16] - assert dispatch_parser.get_contract_tile_sizes(config, "knm") == [16, 8, 4] - assert dispatch_parser.get_contract_tile_sizes(config, "kkk") == [ + assert dispatch_parser.get_contract_workgroup_sizes(config, "mnk") == [4, 8, 0] + assert dispatch_parser.get_contract_reduction_sizes(config, "mnk") == [0, 0, 16] + assert dispatch_parser.get_contract_workgroup_sizes(config, "nmk") == [8, 4, 0] + assert dispatch_parser.get_contract_reduction_sizes(config, "nmk") == [0, 0, 16] + assert dispatch_parser.get_contract_workgroup_sizes(config, "knm") == [0, 8, 4] + assert dispatch_parser.get_contract_reduction_sizes(config, "knm") == [16, 0, 0] + assert dispatch_parser.get_contract_workgroup_sizes(config, "kkk") == [ + 0, + 0, + 0, + ] + assert dispatch_parser.get_contract_reduction_sizes(config, "kkk") == [ 16, 16, 16, From 16734d6cfa8ed8c93726af80a58bcc52919bd26b Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Fri, 29 Nov 2024 23:59:37 -0600 Subject: [PATCH 3/9] [tuner]: fixed the format of lowering_config Signed-off-by: Bangtian Liu --- tuner/tuner/candidate_gen.py | 32 ++++----------- tuner/tuner/candidate_gen_test.py | 64 ++++++++++++++--------------- tuner/tuner/common.py | 47 ++++++++------------- tuner/tuner/dispatch_constraints.py | 11 +++-- tuner/tuner/dispatch_parser.py | 1 - 5 files changed, 64 insertions(+), 91 deletions(-) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index a560a9c70..6f90891e8 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -55,14 +55,14 @@ def apply_configuration( expr1 = re.compile( r"LLVMGPUVectorDistribute workgroup_size = \[.+\] subgroup_size = ([0-9]+)," ) - expr2 = re.compile(r"workgroup = \[\[([0-9]+)(, ([0-9]+))+\]\]") - expr3 = re.compile(r"reduction = \[\[([0-9]+)(, ([0-9]+))+\]\]") + expr2 = re.compile(r"workgroup = \[([0-9]+)(, ([0-9]+))+\]") + expr3 = re.compile(r"reduction = \[([0-9]+)(, ([0-9]+))+\]") expr4 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>") expr5 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") repl0 = f"" repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},' - repl2 = f'workgroup = [[{", ".join(map(str, workgroup_sizes))}]]' - repl3 = f'reduction = [[{", ".join(map(str, reduction_sizes))}]]' + repl2 = f'workgroup = [{", ".join(map(str, workgroup_sizes))}]' + repl3 = f'reduction = [{", ".join(map(str, reduction_sizes))}]' repl4 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}" repl5 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"' @@ -125,8 +125,6 @@ class MmtTuner(DispatchTuner, MmtParser): def get_transform_function_mmt( self, problem_size: ProblemSize, functionName: str, configuration: Configuration ) -> str: - workgroup_sizes = ", ".join(map(str, get_mmt_workgroup_sizes(configuration))) - reduction_sizes = ", ".join(map(str, get_mmt_reduction_sizes(configuration))) intrinsic = configuration.intrinsic() subgroup_m_count = configuration.subgroup_m_count() subgroup_n_count = configuration.subgroup_n_count() @@ -141,7 +139,7 @@ def get_transform_function_mmt( transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, + lowering_config = {configuration.lowering_config}>, translation_info = #iree_codegen.translation_info {output} }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, + lowering_config = {configuration.lowering_config}>, translation_info = #iree_codegen.translation_info : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, + lowering_config = {configuration.lowering_config}>, translation_info = #iree_codegen.translation_info str: - workgroup_sizes = ", ".join( - map(str, get_batch_mmt_workgroup_sizes(configuration)) - ) - reduction_sizes = ", ".join( - map(str, get_batch_mmt_reduction_sizes(configuration)) - ) intrinsic = configuration.intrinsic() subgroup_m_count = configuration.subgroup_m_count() subgroup_n_count = configuration.subgroup_n_count() @@ -382,7 +374,7 @@ def get_transform_function_batch_mmt( transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, + lowering_config = {configuration.lowering_config}>, translation_info = #iree_codegen.translation_info" output = f"tensor<{problem_size.res_type}>" - workgroup_sizes = ", ".join( - map(str, get_contract_workgroup_sizes(configuration, tile_dims)) - ) - reduction_sizes = ", ".join( - map(str, get_contract_reduction_sizes(configuration, tile_dims)) - ) intrinsic = configuration.intrinsic() subgroup_m_count = configuration.subgroup_m_count() subgroup_n_count = configuration.subgroup_n_count() @@ -459,7 +445,7 @@ def get_transform_function_batch_matmul( outs(%out : {output}) -> {output} }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, + lowering_config = {configuration.lowering_config}>, translation_info = #iree_codegen.translation_info None: mlir_template = [ ", subgroup_m_count = 16, subgroup_n_count = 16>", "", + "", "gpu_pipeline_options = #iree_gpu.pipeline_options", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', ] @@ -50,7 +50,7 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, - mma_attr=mma_attr, + mma_kind=mma_attr, workgroup=[8, 8, 0], reduction=[0, 0, 8], subgroup_m_count=16, @@ -89,8 +89,8 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [16, 16, 1] subgroup_size = 16" in modified ) - assert "workgroup = [[8, 8, 0]]" in modified - assert "reduction = [[0, 0, 8]]" in modified + assert "workgroup = [8, 8, 0]" in modified + assert "reduction = [0, 0, 8]" in modified assert ( "gpu_pipeline_options = #iree_gpu.pipeline_options" in modified @@ -102,7 +102,7 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 16, subgroup_n_count = 16>", "", + "", 'gpu_pipeline_options = #iree_gpu.pipeline_options, {llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', ] @@ -112,7 +112,7 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, - mma_attr=mma_attr, + mma_kind=mma_attr, workgroup=[464, 320, 0], reduction=[0, 0, 16], subgroup_m_count=1, @@ -155,8 +155,8 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64" in modified ) - assert "workgroup = [[1, 1, 464, 320, 1, 1, 0]]" in modified - assert "reduction = [[0, 0, 0, 0, 0, 0, 16]]" in modified + assert "workgroup = [1, 1, 464, 320, 1, 1, 0]" in modified + assert "reduction = [0, 0, 0, 0, 0, 0, 16]" in modified assert ( "gpu_pipeline_options = #iree_gpu.pipeline_options>" in modified @@ -168,7 +168,7 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 2, subgroup_n_count = 2>}>", "", + "", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] @@ -185,7 +185,7 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, - mma_attr=mma_attr, + mma_kind=mma_attr, workgroup=[480, 384, 0], reduction=[0, 0, 32], subgroup_m_count=1, @@ -214,8 +214,8 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64" in new_mlir ) - assert "workgroup = [[1, 480, 384, 0]]" in new_mlir - assert "reduction = [[0, 0, 0, 32]]" in new_mlir + assert "workgroup = [1, 480, 384, 0]" in new_mlir + assert "reduction = [0, 0, 0, 32]" in new_mlir assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in new_mlir @@ -223,7 +223,7 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", "", + "", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] @@ -240,7 +240,7 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, - mma_attr=mma_attr, + mma_kind=mma_attr, workgroup=[416, 320, 0], reduction=[0, 0, 128], subgroup_m_count=2, @@ -273,8 +273,8 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" in modified ) - assert "workgroup = [[1, 416, 320, 0]]" in modified - assert "reduction = [[0, 0, 0, 128]]" in modified + assert "workgroup = [1, 416, 320, 0]" in modified + assert "reduction = [0, 0, 0, 128]" in modified assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified @@ -282,7 +282,7 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", "", + "", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] @@ -298,7 +298,7 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, - mma_attr=mma_attr, + mma_kind=mma_attr, workgroup=[128, 64, 0], reduction=[0, 0, 128], subgroup_m_count=2, @@ -329,8 +329,8 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" in modified ) - assert "workgroup = [[1, 128, 64, 0]]" in modified - assert "reduction = [[0, 0, 0, 128]]" in modified + assert "workgroup = [1, 128, 64, 0]" in modified + assert "reduction = [0, 0, 0, 128]" in modified assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified @@ -338,7 +338,7 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", "", + "", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] @@ -354,7 +354,7 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, - mma_attr=mma_attr, + mma_kind=mma_attr, workgroup=[128, 64, 0], reduction=[0, 0, 128], subgroup_m_count=2, @@ -387,8 +387,8 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" in modified ) - assert "workgroup = [[1, 128, 64, 0]]" in modified - assert "reduction = [[0, 0, 0, 128]]" in modified + assert "workgroup = [1, 128, 64, 0]" in modified + assert "reduction = [0, 0, 0, 128]" in modified assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified assert embeddable @@ -408,8 +408,8 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: "%config = transform.param.constant #iree_codegen.compilation_info<" in embeddable ) - assert "workgroup = [[1, 128, 64, 0]]" in embeddable - assert "reduction = [[0, 0, 0, 128]]" in embeddable + assert "workgroup = [128, 64, 0]" in embeddable + assert "reduction = [0, 0, 128]" in embeddable assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable @@ -418,7 +418,7 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", "", + "", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] @@ -434,7 +434,7 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, - mma_attr=mma_attr, + mma_kind=mma_attr, workgroup=[128, 64, 0], reduction=[0, 0, 128], subgroup_m_count=2, @@ -470,8 +470,8 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" in modified ) - assert "workgroup = [[1, 128, 64, 0]]" in modified - assert "reduction = [[0, 0, 0, 128]]" in modified + assert "workgroup = [1, 128, 64, 0]" in modified + assert "reduction = [0, 0, 0, 128]" in modified assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified assert embeddable @@ -492,8 +492,8 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: "%config = transform.param.constant #iree_codegen.compilation_info<" in embeddable ) - assert "workgroup = [[1, 128, 64, 0]]" in embeddable - assert "reduction = [[0, 0, 0, 128]]" in embeddable + assert "workgroup = [128, 64, 0]" in embeddable + assert "reduction = [0, 0, 128]" in embeddable assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index a6fd4dd42..27dfe67c2 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -9,6 +9,7 @@ from dataclasses import astuple, dataclass from enum import Enum from typing import Optional +from typing import Any from iree.compiler import ir # type: ignore @@ -22,6 +23,7 @@ def __init__(self, ctx: ir.Context): self.i8 = ir.IntegerType.get_signless(8, ctx) self.i16 = ir.IntegerType.get_signless(16, ctx) self.i32 = ir.IntegerType.get_signless(32, ctx) + self.i64 = ir.IntegerType.get_signless(64, ctx) self.f8E4M3FNUZ = ir.Float8E4M3FNUZType.get(ctx) self.f8E5M2FNUZ = ir.Float8E5M2FNUZType.get(ctx) @@ -30,8 +32,8 @@ def __init__(self, ctx: ir.Context): self.bf16 = ir.BF16Type.get(ctx) - def getI32(self, value: int) -> ir.IntegerAttr: - return ir.IntegerAttr.get(self.i32, value) + def getI64(self, value: int) -> ir.IntegerAttr: + return ir.IntegerAttr.get(self.i64, value) class TunerContext: @@ -148,37 +150,20 @@ def subgroup_n_count(self) -> Optional[int]: def get_lowering_config( tuner_ctx: TunerContext, - mma_attr: Optional[iree_gpu.MMAAttr] = None, - workgroup: Optional[list[int]] = None, - reduction: Optional[list[int]] = None, - subgroup_m_count: Optional[int] = None, - subgroup_n_count: Optional[int] = None, + **kwargs: Any, ) -> iree_gpu.LoweringConfigAttr: lowering_config_dict = {} - if workgroup is not None: - lowering_config_dict["workgroup"] = ir.ArrayAttr.get( - [tuner_ctx.type.getI32(x) for x in workgroup] - ) - if reduction is not None: - lowering_config_dict["reduction"] = ir.ArrayAttr.get( - [tuner_ctx.type.getI32(x) for x in reduction] - ) - if subgroup_m_count is not None: - lowering_config_dict["subgroup_m_count"] = tuner_ctx.type.getI32( - subgroup_m_count - ) - if subgroup_n_count is not None: - lowering_config_dict["subgroup_n_count"] = tuner_ctx.type.getI32( - subgroup_n_count - ) - # lowering_config_dict = { - # "workgroup": ir.ArrayAttr.get([tuner_ctx.type.getI32(x) for x in workgroup]), - # "reduction": ir.ArrayAttr.get([tuner_ctx.type.getI32(x) for x in reduction]), - # "subgroup_m_count": tuner_ctx.type.getI32(subgroup_m_count), - # "subgroup_n_count": tuner_ctx.type.getI32(subgroup_n_count), - # } - if mma_attr is not None: - lowering_config_dict["mma_kind"] = mma_attr + for key, value in kwargs.items(): + if isinstance(value, list): + lowering_config_dict[key] = ir.ArrayAttr.get( + [tuner_ctx.type.getI64(x) for x in value] + ) + elif isinstance(value, int): + lowering_config_dict[key] = tuner_ctx.type.getI64(value) + elif isinstance(value, iree_gpu.MMAAttr): + lowering_config_dict[key] = value + else: + raise TypeError(f"Unsupported type for key '{key}': {type(value).__name__}") lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) return iree_gpu.LoweringConfigAttr.get(lowering_config_attrs) diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index 5a775bdcd..39ccec523 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -220,7 +220,7 @@ def generate_solutions( solver.add(z3.simplify(z3.And(constraints))) logger.debug(f"Initial constraints: {solver}") - int_type = ir.IntegerType.get_signless(32) + int_type = ir.IntegerType.get_signless(64) i = 0 while solver.check() == z3.sat: @@ -239,11 +239,15 @@ def generate_solutions( [ ir.IntegerAttr.get(int_type, lookup(m)), ir.IntegerAttr.get(int_type, lookup(n)), - ir.IntegerAttr.get(int_type, lookup(k)), + ir.IntegerAttr.get(int_type, 0), ] ), "reduction": ir.ArrayAttr.get( - [] + [ + ir.IntegerAttr.get(int_type, 0), + ir.IntegerAttr.get(int_type, 0), + ir.IntegerAttr.get(int_type, lookup(k)), + ] ), # placeholder now to be consistent with iree "subgroup_m_count": ir.IntegerAttr.get(int_type, lookup(sg_m_cnt)), "subgroup_n_count": ir.IntegerAttr.get(int_type, lookup(sg_n_cnt)), @@ -251,7 +255,6 @@ def generate_solutions( lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) lowering_config = iree_gpu.LoweringConfigAttr.get(lowering_config_attrs) - config = Configuration( lookup(subgroup_size), [lookup(wg_x), lookup(wg_y), lookup(wg_z)], diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py index dd36f5075..bc7788f44 100644 --- a/tuner/tuner/dispatch_parser.py +++ b/tuner/tuner/dispatch_parser.py @@ -171,7 +171,6 @@ def get_conv_workgroup_sizes(self, configuration: Configuration) -> list[int]: oh = 1 - # oc = configuration.tilesize_workgroup()[1] ow, oc, _ = configuration.tilesize_workgroup() return [batch, oh, ow, oc, fh, fw, 0] From 1cab36c1537f624608a9d10999fb3f39eeb4dccf Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Mon, 2 Dec 2024 01:03:02 -0600 Subject: [PATCH 4/9] [tuner]: format the code Signed-off-by: Bangtian Liu --- tuner/tuner/common.py | 61 +++++++++++++++++++---------- tuner/tuner/common_test.py | 8 +++- tuner/tuner/dispatch_constraints.py | 2 +- tuner/tuner/dispatch_parser.py | 4 +- tuner/tuner/dispatch_parser_test.py | 18 +++------ 5 files changed, 55 insertions(+), 38 deletions(-) diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index 27dfe67c2..b33d5845e 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -119,31 +119,30 @@ class Configuration: waves_per_eu: int def intrinsic(self) -> Optional[iree_gpu.MMAAttr]: - if self.lowering_config.attributes.__contains__("mma_kind"): - return self.lowering_config.attributes.__getitem__("mma_kind") - return None + if "mma_kind" in self.lowering_config.attributes: + return self.lowering_config.attributes["mma_kind"] def tilesize_workgroup(self) -> list[int]: - if self.lowering_config.attributes.__contains__("workgroup"): - workgroup_attrs = self.lowering_config.attributes.__getitem__("workgroup") + if "workgroup" in self.lowering_config.attributes: + workgroup_attrs = self.lowering_config.attributes["workgroup"] return [attr.value for attr in workgroup_attrs] return [] def tilesize_reduction(self) -> list[int]: - if self.lowering_config.attributes.__contains__("reduction"): - reduction_attrs = self.lowering_config.attributes.__getitem__("reduction") + if "reduction" in self.lowering_config.attributes: + reduction_attrs = self.lowering_config.attributes["reduction"] return [attr.value for attr in reduction_attrs] return [] def subgroup_m_count(self) -> Optional[int]: - if self.lowering_config.attributes.__contains__("subgroup_m_count"): - attr = self.lowering_config.attributes.__getitem__("subgroup_m_count") + if "subgroup_m_count" in self.lowering_config.attributes: + attr = self.lowering_config.attributes["subgroup_m_count"] return attr.value return None def subgroup_n_count(self) -> Optional[int]: - if self.lowering_config.attributes.__contains__("subgroup_n_count"): - attr = self.lowering_config.attributes.__getitem__("subgroup_n_count") + if "subgroup_n_count" in self.lowering_config.attributes: + attr = self.lowering_config.attributes["subgroup_n_count"] return attr.value return None @@ -154,16 +153,36 @@ def get_lowering_config( ) -> iree_gpu.LoweringConfigAttr: lowering_config_dict = {} for key, value in kwargs.items(): - if isinstance(value, list): - lowering_config_dict[key] = ir.ArrayAttr.get( - [tuner_ctx.type.getI64(x) for x in value] - ) - elif isinstance(value, int): - lowering_config_dict[key] = tuner_ctx.type.getI64(value) - elif isinstance(value, iree_gpu.MMAAttr): - lowering_config_dict[key] = value - else: - raise TypeError(f"Unsupported type for key '{key}': {type(value).__name__}") + match key: + case "workgroup" | "reduction": + if isinstance(value, list): + lowering_config_dict[key] = ir.ArrayAttr.get( + [tuner_ctx.type.getI64(x) for x in value] + ) + elif isinstance(value, ir.ArrayAttr): + lowering_config_dict[key] = value + else: + raise TypeError( + f"Unsupported type for key '{key}': {type(value).__name__}" + ) + case "subgroup_m_count" | "subgroup_n_count": + if isinstance(value, int): + lowering_config_dict[key] = tuner_ctx.type.getI64(value) + elif isinstance(value, tuner_ctx.type.i64): + lowering_config_dict[key] = value + else: + raise TypeError( + f"Unsupported type for key '{key}': {type(value).__name__}" + ) + case "mma_kind": + if isinstance(value, iree_gpu.MMAAttr): + lowering_config_dict[key] = value + else: + raise TypeError( + f"Unsupported type for key '{key}': {type(value).__name__}" + ) + case _: + raise KeyError(f"Unhandled key in lowering configuration: {key}") lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) return iree_gpu.LoweringConfigAttr.get(lowering_config_attrs) diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index 056224458..1dfb6ff7b 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -78,7 +78,7 @@ def test_get_pipeline_config(tuner_ctx: common.TunerContext) -> None: mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, - mma_attr=mma_attr, + mma_kind=mma_attr, workgroup=[4, 8, 0], reduction=[0, 0, 16], subgroup_m_count=1, @@ -201,6 +201,12 @@ def test_get_lowering_config(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=1, subgroup_n_count=1, ) + + assert ( + str(lowering_config) + == "#iree_gpu.lowering_config<{reduction = [0, 0, 16], subgroup_m_count = 1 : i64, subgroup_n_count = 1 : i64, workgroup = [4, 8, 0]}>" + ) + config = common.Configuration( subgroup_size=32, workgroup_size=[16, 16, 1], diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index 39ccec523..915e84711 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -248,7 +248,7 @@ def generate_solutions( ir.IntegerAttr.get(int_type, 0), ir.IntegerAttr.get(int_type, lookup(k)), ] - ), # placeholder now to be consistent with iree + ), # Placeholder now to be consistent with iree. "subgroup_m_count": ir.IntegerAttr.get(int_type, lookup(sg_m_cnt)), "subgroup_n_count": ir.IntegerAttr.get(int_type, lookup(sg_n_cnt)), } diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py index bc7788f44..0c5209ccd 100644 --- a/tuner/tuner/dispatch_parser.py +++ b/tuner/tuner/dispatch_parser.py @@ -171,12 +171,12 @@ def get_conv_workgroup_sizes(self, configuration: Configuration) -> list[int]: oh = 1 - ow, oc, _ = configuration.tilesize_workgroup() + ow, oc, _ic = configuration.tilesize_workgroup() return [batch, oh, ow, oc, fh, fw, 0] def get_conv_reduction_sizes(self, configuration: Configuration) -> list[int]: - _, _, ic = configuration.tilesize_reduction() + _ow, _oc, ic = configuration.tilesize_reduction() return [0, 0, 0, 0, 0, 0, ic] diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py index 8e99188d0..650540c63 100644 --- a/tuner/tuner/dispatch_parser_test.py +++ b/tuner/tuner/dispatch_parser_test.py @@ -44,7 +44,7 @@ def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None: mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, - mma_attr=mma_attr, + mma_kind=mma_attr, workgroup=[128, 320, 0], reduction=[0, 0, 32], subgroup_m_count=1, @@ -66,7 +66,7 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None: mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, - mma_attr=mma_attr, + mma_kind=mma_attr, workgroup=[464, 320, 0], reduction=[0, 0, 16], subgroup_m_count=1, @@ -104,7 +104,7 @@ def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None: mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, - mma_attr=mma_attr, + mma_kind=mma_attr, workgroup=[4, 8, 0], reduction=[0, 0, 16], subgroup_m_count=1, @@ -123,16 +123,8 @@ def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None: assert dispatch_parser.get_contract_reduction_sizes(config, "nmk") == [0, 0, 16] assert dispatch_parser.get_contract_workgroup_sizes(config, "knm") == [0, 8, 4] assert dispatch_parser.get_contract_reduction_sizes(config, "knm") == [16, 0, 0] - assert dispatch_parser.get_contract_workgroup_sizes(config, "kkk") == [ - 0, - 0, - 0, - ] - assert dispatch_parser.get_contract_reduction_sizes(config, "kkk") == [ - 16, - 16, - 16, - ] + assert dispatch_parser.get_contract_workgroup_sizes(config, "kkk") == [0, 0, 0] + assert dispatch_parser.get_contract_reduction_sizes(config, "kkk") == [16, 16, 16] def test_get_shapes_mmt(tuner_ctx: common.TunerContext) -> None: From a2094ac8187bfde94869cb311a0bbf31d8168410 Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Mon, 2 Dec 2024 01:12:12 -0600 Subject: [PATCH 5/9] [tuner]: fix the ci error Signed-off-by: Bangtian Liu --- tuner/tuner/common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index b33d5845e..3253dd077 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -121,6 +121,7 @@ class Configuration: def intrinsic(self) -> Optional[iree_gpu.MMAAttr]: if "mma_kind" in self.lowering_config.attributes: return self.lowering_config.attributes["mma_kind"] + return None def tilesize_workgroup(self) -> list[int]: if "workgroup" in self.lowering_config.attributes: @@ -151,7 +152,7 @@ def get_lowering_config( tuner_ctx: TunerContext, **kwargs: Any, ) -> iree_gpu.LoweringConfigAttr: - lowering_config_dict = {} + lowering_config_dict: dict[str, Any] = {} for key, value in kwargs.items(): match key: case "workgroup" | "reduction": From ac9f941b2759decd21edfd73e02988c4da930b03 Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Mon, 2 Dec 2024 14:03:24 -0600 Subject: [PATCH 6/9] [tuner] move methods out of configuration as free functions Signed-off-by: Bangtian Liu --- tuner/tuner/candidate_gen.py | 36 ++++++------- tuner/tuner/common.py | 96 +++++++++++++++++----------------- tuner/tuner/common_test.py | 6 +-- tuner/tuner/dispatch_parser.py | 16 +++--- 4 files changed, 77 insertions(+), 77 deletions(-) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 6f90891e8..1c6ef5c8d 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -45,9 +45,9 @@ def apply_configuration( workgroup_sizes: list[int], reduction_sizes: list[int], ) -> str: - intrinsic = configuration.intrinsic() - subgroup_m_count = configuration.subgroup_m_count() - subgroup_n_count = configuration.subgroup_n_count() + intrinsic = get_intrinsic(configuration) + subgroup_m_count = get_subgroup_m_count(configuration) + subgroup_n_count = get_subgroup_n_count(configuration) tune_logger.info(f"Applying: {configuration}") expr0 = re.compile( r", subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>" @@ -125,9 +125,9 @@ class MmtTuner(DispatchTuner, MmtParser): def get_transform_function_mmt( self, problem_size: ProblemSize, functionName: str, configuration: Configuration ) -> str: - intrinsic = configuration.intrinsic() - subgroup_m_count = configuration.subgroup_m_count() - subgroup_n_count = configuration.subgroup_n_count() + intrinsic = get_intrinsic(configuration) + subgroup_m_count = get_subgroup_m_count(configuration) + subgroup_n_count = get_subgroup_n_count(configuration) wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) @@ -199,9 +199,9 @@ def get_transform_function_conv( reduction_sizes = ", ".join( map(str, self.get_conv_reduction_sizes(configuration)) ) - intrinsic = configuration.intrinsic() - subgroup_m_count = configuration.subgroup_m_count() - subgroup_n_count = configuration.subgroup_n_count() + intrinsic = get_intrinsic(configuration) + subgroup_m_count = get_subgroup_m_count(configuration) + subgroup_n_count = get_subgroup_n_count(configuration) wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) @@ -269,9 +269,9 @@ def get_transform_function_broadcast_rhs_mmt( reduction_sizes = ", ".join( map(str, get_batch_mmt_reduction_sizes(configuration)) ) - intrinsic = configuration.intrinsic() - subgroup_m_count = configuration.subgroup_m_count() - subgroup_n_count = configuration.subgroup_n_count() + intrinsic = get_intrinsic(configuration) + subgroup_m_count = get_subgroup_m_count(configuration) + subgroup_n_count = get_subgroup_n_count(configuration) wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) @@ -359,9 +359,9 @@ def get_transform_function_batch_mmt( functionName: str, configuration: Configuration, ) -> str: - intrinsic = configuration.intrinsic() - subgroup_m_count = configuration.subgroup_m_count() - subgroup_n_count = configuration.subgroup_n_count() + intrinsic = get_intrinsic(configuration) + subgroup_m_count = get_subgroup_m_count(configuration) + subgroup_n_count = get_subgroup_n_count(configuration) wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) @@ -428,9 +428,9 @@ def get_transform_function_batch_matmul( input1 = f"tensor<{problem_size.rhs_type}>" output = f"tensor<{problem_size.res_type}>" - intrinsic = configuration.intrinsic() - subgroup_m_count = configuration.subgroup_m_count() - subgroup_n_count = configuration.subgroup_n_count() + intrinsic = get_intrinsic(configuration) + subgroup_m_count = get_subgroup_m_count(configuration) + subgroup_n_count = get_subgroup_n_count(configuration) wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index 3253dd077..a839ad4c4 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -118,34 +118,39 @@ class Configuration: gpu_pipeline_options: iree_gpu.PipelineOptionsAttr waves_per_eu: int - def intrinsic(self) -> Optional[iree_gpu.MMAAttr]: - if "mma_kind" in self.lowering_config.attributes: - return self.lowering_config.attributes["mma_kind"] - return None - - def tilesize_workgroup(self) -> list[int]: - if "workgroup" in self.lowering_config.attributes: - workgroup_attrs = self.lowering_config.attributes["workgroup"] - return [attr.value for attr in workgroup_attrs] - return [] - - def tilesize_reduction(self) -> list[int]: - if "reduction" in self.lowering_config.attributes: - reduction_attrs = self.lowering_config.attributes["reduction"] - return [attr.value for attr in reduction_attrs] - return [] - - def subgroup_m_count(self) -> Optional[int]: - if "subgroup_m_count" in self.lowering_config.attributes: - attr = self.lowering_config.attributes["subgroup_m_count"] - return attr.value - return None - - def subgroup_n_count(self) -> Optional[int]: - if "subgroup_n_count" in self.lowering_config.attributes: - attr = self.lowering_config.attributes["subgroup_n_count"] - return attr.value - return None + +def get_intrinsic(config: Configuration) -> Optional[iree_gpu.MMAAttr]: + if "mma_kind" in config.lowering_config.attributes: + return config.lowering_config.attributes["mma_kind"] + return None + + +def get_tilesize_workgroup(config: Configuration) -> list[int]: + if "workgroup" in config.lowering_config.attributes: + workgroup_attrs = config.lowering_config.attributes["workgroup"] + return [attr.value for attr in workgroup_attrs] + return [] + + +def get_tilesize_reduction(config: Configuration) -> list[int]: + if "reduction" in config.lowering_config.attributes: + reduction_attrs = config.lowering_config.attributes["reduction"] + return [attr.value for attr in reduction_attrs] + return [] + + +def get_subgroup_m_count(config: Configuration) -> Optional[int]: + if "subgroup_m_count" in config.lowering_config.attributes: + attr = config.lowering_config.attributes["subgroup_m_count"] + return attr.value + return None + + +def get_subgroup_n_count(config: Configuration) -> Optional[int]: + if "subgroup_n_count" in config.lowering_config.attributes: + attr = config.lowering_config.attributes["subgroup_n_count"] + return attr.value + return None def get_lowering_config( @@ -154,36 +159,31 @@ def get_lowering_config( ) -> iree_gpu.LoweringConfigAttr: lowering_config_dict: dict[str, Any] = {} for key, value in kwargs.items(): + # A local variable to hold the transformed value. + promoted_value = value match key: case "workgroup" | "reduction": + assert isinstance( + value, (list, ir.ArrayAttr) + ), f"Unsupported type for key '{key}': {type(value).__name__}" if isinstance(value, list): - lowering_config_dict[key] = ir.ArrayAttr.get( + promoted_value = ir.ArrayAttr.get( [tuner_ctx.type.getI64(x) for x in value] ) - elif isinstance(value, ir.ArrayAttr): - lowering_config_dict[key] = value - else: - raise TypeError( - f"Unsupported type for key '{key}': {type(value).__name__}" - ) case "subgroup_m_count" | "subgroup_n_count": + assert isinstance( + value, (int, tuner_ctx.type.i64) + ), f"Unsupported type for key '{key}': {type(value).__name__}" if isinstance(value, int): - lowering_config_dict[key] = tuner_ctx.type.getI64(value) - elif isinstance(value, tuner_ctx.type.i64): - lowering_config_dict[key] = value - else: - raise TypeError( - f"Unsupported type for key '{key}': {type(value).__name__}" - ) + promoted_value = tuner_ctx.type.getI64(value) case "mma_kind": - if isinstance(value, iree_gpu.MMAAttr): - lowering_config_dict[key] = value - else: - raise TypeError( - f"Unsupported type for key '{key}': {type(value).__name__}" - ) + assert isinstance( + value, iree_gpu.MMAAttr + ), f"Unsupported type for key '{key}': {type(value).__name__}" case _: raise KeyError(f"Unhandled key in lowering configuration: {key}") + # Single assignment after the match. + lowering_config_dict[key] = promoted_value lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) return iree_gpu.LoweringConfigAttr.get(lowering_config_attrs) diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index 1dfb6ff7b..f13aed3d7 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -215,6 +215,6 @@ def test_get_lowering_config(tuner_ctx: common.TunerContext) -> None: waves_per_eu=2, ) - assert config.intrinsic() is None - assert config.subgroup_m_count() == 1 - assert config.subgroup_n_count() == 1 + assert common.get_intrinsic(config) is None + assert common.get_subgroup_m_count(config) == 1 + assert common.get_subgroup_n_count(config) == 1 diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py index 0c5209ccd..503ece345 100644 --- a/tuner/tuner/dispatch_parser.py +++ b/tuner/tuner/dispatch_parser.py @@ -21,17 +21,17 @@ def parse_tensor_type(tensor_type: str) -> ShapedType: def get_mmt_workgroup_sizes(configuration: Configuration): - return configuration.tilesize_workgroup() + return get_tilesize_workgroup(configuration) def get_mmt_reduction_sizes(configuration: Configuration): - return configuration.tilesize_reduction() + return get_tilesize_reduction(configuration) def get_contract_workgroup_sizes( configuration: Configuration, tile_dims: str ) -> list[int]: - m, n, _ = configuration.tilesize_workgroup() + m, n, _k = get_tilesize_workgroup(configuration) workgroup_size = [1] * len(tile_dims) for idx, dim in enumerate(tile_dims): @@ -48,7 +48,7 @@ def get_contract_workgroup_sizes( def get_contract_reduction_sizes( configuration: Configuration, tile_dims: str ) -> list[int]: - _, _, k = configuration.tilesize_reduction() + _m, _n, k = get_tilesize_reduction(configuration) reduction_size = [0] * len(tile_dims) for idx, dim in enumerate(tile_dims): if dim == "k": @@ -58,11 +58,11 @@ def get_contract_reduction_sizes( def get_batch_mmt_workgroup_sizes(configuration: Configuration) -> list[int]: - return [1] + configuration.tilesize_workgroup() + return [1] + get_tilesize_workgroup(configuration) def get_batch_mmt_reduction_sizes(configuration: Configuration) -> list[int]: - return [0] + configuration.tilesize_reduction() + return [0] + get_tilesize_reduction(configuration) class MlirRegex(Enum): @@ -171,12 +171,12 @@ def get_conv_workgroup_sizes(self, configuration: Configuration) -> list[int]: oh = 1 - ow, oc, _ic = configuration.tilesize_workgroup() + ow, oc, _ic = get_tilesize_workgroup(configuration) return [batch, oh, ow, oc, fh, fw, 0] def get_conv_reduction_sizes(self, configuration: Configuration) -> list[int]: - _ow, _oc, ic = configuration.tilesize_reduction() + _ow, _oc, ic = get_tilesize_reduction(configuration) return [0, 0, 0, 0, 0, 0, ic] From 06c1b51687e13f2dcac92b1f3e2735f066c6b8be Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Mon, 2 Dec 2024 15:31:15 -0600 Subject: [PATCH 7/9] [tuner]: rename functions and assert False Signed-off-by: Bangtian Liu --- tuner/tuner/candidate_gen.py | 4 ++-- tuner/tuner/common.py | 29 ++++++++++++++++------------- tuner/tuner/dispatch_parser.py | 16 ++++++++-------- 3 files changed, 26 insertions(+), 23 deletions(-) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 1c6ef5c8d..2a544ef55 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -61,8 +61,8 @@ def apply_configuration( expr5 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") repl0 = f"" repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},' - repl2 = f'workgroup = [{", ".join(map(str, workgroup_sizes))}]' - repl3 = f'reduction = [{", ".join(map(str, reduction_sizes))}]' + repl2 = f"workgroup = {workgroup_sizes}" + repl3 = f"reduction = {reduction_sizes}" repl4 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}" repl5 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"' diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index a839ad4c4..702008f5e 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -125,14 +125,14 @@ def get_intrinsic(config: Configuration) -> Optional[iree_gpu.MMAAttr]: return None -def get_tilesize_workgroup(config: Configuration) -> list[int]: +def get_workgroup_tile_sizes(config: Configuration) -> list[int]: if "workgroup" in config.lowering_config.attributes: workgroup_attrs = config.lowering_config.attributes["workgroup"] return [attr.value for attr in workgroup_attrs] return [] -def get_tilesize_reduction(config: Configuration) -> list[int]: +def get_reduction_tile_sizes(config: Configuration) -> list[int]: if "reduction" in config.lowering_config.attributes: reduction_attrs = config.lowering_config.attributes["reduction"] return [attr.value for attr in reduction_attrs] @@ -163,26 +163,29 @@ def get_lowering_config( promoted_value = value match key: case "workgroup" | "reduction": - assert isinstance( - value, (list, ir.ArrayAttr) - ), f"Unsupported type for key '{key}': {type(value).__name__}" if isinstance(value, list): promoted_value = ir.ArrayAttr.get( [tuner_ctx.type.getI64(x) for x in value] ) + elif not isinstance(value, ir.ArrayAttr): + assert ( + False + ), f"Unsupported type for key '{key}': {type(value).__name__}" case "subgroup_m_count" | "subgroup_n_count": - assert isinstance( - value, (int, tuner_ctx.type.i64) - ), f"Unsupported type for key '{key}': {type(value).__name__}" if isinstance(value, int): promoted_value = tuner_ctx.type.getI64(value) + elif not isinstance(value, tuner_ctx.type.i64): + assert ( + False + ), f"Unsupported type for key '{key}': {type(value).__name__}" case "mma_kind": - assert isinstance( - value, iree_gpu.MMAAttr - ), f"Unsupported type for key '{key}': {type(value).__name__}" + if not isinstance(value, iree_gpu.MMAAttr): + assert ( + False + ), f"Unsupported type for key '{key}': {type(value).__name__}" case _: - raise KeyError(f"Unhandled key in lowering configuration: {key}") - # Single assignment after the match. + assert False, f"Unhandled key in lowering configuration: {key}" + lowering_config_dict[key] = promoted_value lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) return iree_gpu.LoweringConfigAttr.get(lowering_config_attrs) diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py index 503ece345..ad63ba815 100644 --- a/tuner/tuner/dispatch_parser.py +++ b/tuner/tuner/dispatch_parser.py @@ -21,17 +21,17 @@ def parse_tensor_type(tensor_type: str) -> ShapedType: def get_mmt_workgroup_sizes(configuration: Configuration): - return get_tilesize_workgroup(configuration) + return get_workgroup_tile_sizes(configuration) def get_mmt_reduction_sizes(configuration: Configuration): - return get_tilesize_reduction(configuration) + return get_reduction_tile_sizes(configuration) def get_contract_workgroup_sizes( configuration: Configuration, tile_dims: str ) -> list[int]: - m, n, _k = get_tilesize_workgroup(configuration) + m, n, _k = get_workgroup_tile_sizes(configuration) workgroup_size = [1] * len(tile_dims) for idx, dim in enumerate(tile_dims): @@ -48,7 +48,7 @@ def get_contract_workgroup_sizes( def get_contract_reduction_sizes( configuration: Configuration, tile_dims: str ) -> list[int]: - _m, _n, k = get_tilesize_reduction(configuration) + _m, _n, k = get_reduction_tile_sizes(configuration) reduction_size = [0] * len(tile_dims) for idx, dim in enumerate(tile_dims): if dim == "k": @@ -58,11 +58,11 @@ def get_contract_reduction_sizes( def get_batch_mmt_workgroup_sizes(configuration: Configuration) -> list[int]: - return [1] + get_tilesize_workgroup(configuration) + return [1] + get_workgroup_tile_sizes(configuration) def get_batch_mmt_reduction_sizes(configuration: Configuration) -> list[int]: - return [0] + get_tilesize_reduction(configuration) + return [0] + get_reduction_tile_sizes(configuration) class MlirRegex(Enum): @@ -171,12 +171,12 @@ def get_conv_workgroup_sizes(self, configuration: Configuration) -> list[int]: oh = 1 - ow, oc, _ic = get_tilesize_workgroup(configuration) + ow, oc, _ic = get_workgroup_tile_sizes(configuration) return [batch, oh, ow, oc, fh, fw, 0] def get_conv_reduction_sizes(self, configuration: Configuration) -> list[int]: - _ow, _oc, ic = get_tilesize_reduction(configuration) + _ow, _oc, ic = get_reduction_tile_sizes(configuration) return [0, 0, 0, 0, 0, 0, ic] From 98b4add4eddcc9395948b9f8e4a74f125c97dc02 Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Mon, 2 Dec 2024 18:42:13 -0600 Subject: [PATCH 8/9] [tuner]: use helper function Signed-off-by: Bangtian Liu --- tuner/tuner/candidate_gen.py | 2 +- tuner/tuner/dispatch_constraints.py | 51 +++++++++++++---------------- 2 files changed, 23 insertions(+), 30 deletions(-) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 2a544ef55..ca5d035e9 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -38,7 +38,7 @@ tune_logger = logging.getLogger("tune") - +# TODO: remove the argument 'workgroup_sizes' and 'reduction_sizes'. def apply_configuration( template: list[str], configuration: Configuration, diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index 915e84711..ea5c66f76 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -222,39 +222,32 @@ def generate_solutions( int_type = ir.IntegerType.get_signless(64) + tuner_ctx = TunerContext(ir.Context(), logger) + i = 0 while solver.check() == z3.sat: model = solver.model() lookup = lambda var: model[var].as_long() - lowering_config_dict = { - "mma_kind": getMMAAttr( - problem_size.res_type.element_type, - lookup(intrinsic_mn), - lookup(intrinsic_mn), - lookup(intrinsic_k), - problem_size.lhs_type.element_type, - problem_size.rhs_type.element_type, - ), - "workgroup": ir.ArrayAttr.get( - [ - ir.IntegerAttr.get(int_type, lookup(m)), - ir.IntegerAttr.get(int_type, lookup(n)), - ir.IntegerAttr.get(int_type, 0), - ] - ), - "reduction": ir.ArrayAttr.get( - [ - ir.IntegerAttr.get(int_type, 0), - ir.IntegerAttr.get(int_type, 0), - ir.IntegerAttr.get(int_type, lookup(k)), - ] - ), # Placeholder now to be consistent with iree. - "subgroup_m_count": ir.IntegerAttr.get(int_type, lookup(sg_m_cnt)), - "subgroup_n_count": ir.IntegerAttr.get(int_type, lookup(sg_n_cnt)), - } - - lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) - lowering_config = iree_gpu.LoweringConfigAttr.get(lowering_config_attrs) + mma_attr = getMMAAttr( + problem_size.res_type.element_type, + lookup(intrinsic_mn), + lookup(intrinsic_mn), + lookup(intrinsic_k), + problem_size.lhs_type.element_type, + problem_size.rhs_type.element_type, + ) + lowering_config = get_lowering_config( + tuner_ctx=tuner_ctx, + mma_kind=mma_attr, + workgroup=[lookup(m), lookup(n), 0], + reduction=[ + 0, + 0, + lookup(k), + ], + subgroup_m_count=lookup(sg_m_cnt), + subgroup_n_count=lookup(sg_n_cnt), + ) config = Configuration( lookup(subgroup_size), [lookup(wg_x), lookup(wg_y), lookup(wg_z)], From 07919e457492f05c4393a6cb7f823b6a157eae1e Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Mon, 2 Dec 2024 19:23:05 -0600 Subject: [PATCH 9/9] [tuner]: use tuner_context Signed-off-by: Bangtian Liu --- tuner/tuner/candidate_gen.py | 2 +- tuner/tuner/dispatch_constraints.py | 10 +++------- tuner/tuner/dispatch_constraints_test.py | 2 +- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index ca5d035e9..c903ec85f 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -585,7 +585,7 @@ def tune( tune_logger.debug(str(problem_size)) configs = [] for i, config in enumerate( - generate_solutions(tune_logger, problem_size, num_subgroups, mma_list) + generate_solutions(tuner_context, problem_size, num_subgroups, mma_list) ): if i >= limit: break diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index ea5c66f76..f86523389 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -174,13 +174,13 @@ def getMMAAttr( def generate_solutions( - logger: logging.Logger, + tuner_ctx: TunerContext, problem_size: ProblemSize, num_subgrups: int, mma_intrinsics: list[iree_gpu.MMAIntrinsic], ) -> Iterator[Configuration]: M, N, K = problem_size.MNK - logger.info(f"{M},{N},{K}") + tuner_ctx.logger.info(f"{M},{N},{K}") m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k") subgroup_size = z3.Int("subgroup_size") intrinsic_mn = z3.Int("intrinsic_mn") @@ -218,11 +218,7 @@ def generate_solutions( mma_intrinsics, ) solver.add(z3.simplify(z3.And(constraints))) - logger.debug(f"Initial constraints: {solver}") - - int_type = ir.IntegerType.get_signless(64) - - tuner_ctx = TunerContext(ir.Context(), logger) + tuner_ctx.logger.debug(f"Initial constraints: {solver}") i = 0 while solver.check() == z3.sat: diff --git a/tuner/tuner/dispatch_constraints_test.py b/tuner/tuner/dispatch_constraints_test.py index 9de4beeee..842ea8509 100644 --- a/tuner/tuner/dispatch_constraints_test.py +++ b/tuner/tuner/dispatch_constraints_test.py @@ -39,7 +39,7 @@ def test_generate_solutions(tuner_ctx: common.TunerContext) -> None: matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) configs = dispatch_constraints.generate_solutions( - tuner_ctx.logger, + tuner_ctx, problem_size, 4, [