diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py
index 01a6ed2aa..a560a9c70 100644
--- a/tuner/tuner/candidate_gen.py
+++ b/tuner/tuner/candidate_gen.py
@@ -40,11 +40,14 @@ def apply_configuration(
-    template: list[str], configuration: Configuration, tile_sizes: list[int]
+    template: list[str],
+    configuration: Configuration,
+    workgroup_sizes: list[int],
+    reduction_sizes: list[int],
 ) -> str:
-    intrinsic = configuration.intrinsic
-    subgroup_m_count = configuration.subgroup_m_count
-    subgroup_n_count = configuration.subgroup_n_count
+    intrinsic = configuration.intrinsic()
+    subgroup_m_count = configuration.subgroup_m_count()
+    subgroup_n_count = configuration.subgroup_n_count()
     tune_logger.info(f"Applying: {configuration}")
     expr0 = re.compile(
         r", subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>"
@@ -52,14 +55,16 @@ def apply_configuration(
     )
     expr1 = re.compile(
         r"LLVMGPUVectorDistribute workgroup_size = \[.+\] subgroup_size = ([0-9]+),"
     )
-    expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]")
-    expr3 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>")
-    expr4 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"")
+    expr2 = re.compile(r"workgroup = \[\[([0-9]+)(, ([0-9]+))+\]\]")
+    expr3 = re.compile(r"reduction = \[\[([0-9]+)(, ([0-9]+))+\]\]")
+    expr4 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>")
+    expr5 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"")
     repl0 = f""
     repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},'
-    repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]'
-    repl3 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}"
-    repl4 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"'
+    repl2 = f'workgroup = [[{", ".join(map(str, workgroup_sizes))}]]'
+    repl3 = f'reduction = [[{", ".join(map(str, reduction_sizes))}]]'
+    repl4 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}"
+    repl5 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"'

     new_mlir = ""
     for line in template:
@@ -67,12 +72,14 @@ def apply_configuration(
             line = re.sub(expr0, repl0, line)
         if "LLVMGPUVectorDistribute " in line:
             line = re.sub(expr1, repl1, line)
-        if "tile_sizes" in line:
+        if "workgroup" in line:
             line = re.sub(expr2, repl2, line)
-        if "gpu_pipeline_options =" in line:
+        if "reduction" in line:
             line = re.sub(expr3, repl3, line)
-        if "amdgpu-waves-per-eu" in line:
+        if "gpu_pipeline_options =" in line:
             line = re.sub(expr4, repl4, line)
+        if "amdgpu-waves-per-eu" in line:
+            line = re.sub(expr5, repl5, line)
         new_mlir += line

     return new_mlir
@@ -118,10 +125,11 @@ class MmtTuner(DispatchTuner, MmtParser):
     def get_transform_function_mmt(
         self, problem_size: ProblemSize, functionName: str, configuration: Configuration
     ) -> str:
-        tile_sizes = ", ".join(map(str, get_mmt_tile_sizes(configuration)))
-        intrinsic = configuration.intrinsic
-        subgroup_m_count = configuration.subgroup_m_count
-        subgroup_n_count = configuration.subgroup_n_count
+        workgroup_sizes = ", ".join(map(str, get_mmt_workgroup_sizes(configuration)))
+        reduction_sizes = ", ".join(map(str, get_mmt_reduction_sizes(configuration)))
+        intrinsic = configuration.intrinsic()
+        subgroup_m_count = configuration.subgroup_m_count()
+        subgroup_n_count = configuration.subgroup_n_count()

         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -133,7 +141,7 @@ def get_transform_function_mmt(
     transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value
     transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
     %config = transform.param.constant #iree_codegen.compilation_info<
-        lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
+        lowering_config = #iree_codegen.lowering_config<workgroup = [[{workgroup_sizes}]], reduction = [[{reduction_sizes}]]>,
         translation_info = #iree_codegen.translation_info
     ) -> str:
@@ -191,10 +195,15 @@ def get_transform_function_conv(
         filter = f"tensor<{problem_size.rhs_type}>"
         output = f"tensor<{dynamic_batch_output_ty}>"

-        tile_sizes = ", ".join(map(str, self.get_conv_tile_sizes(configuration)))
-        intrinsic = configuration.intrinsic
-        subgroup_m_count = configuration.subgroup_m_count
-        subgroup_n_count = configuration.subgroup_n_count
+        workgroup_sizes = ", ".join(
+            map(str, self.get_conv_workgroup_sizes(configuration))
+        )
+        reduction_sizes = ", ".join(
+            map(str, self.get_conv_reduction_sizes(configuration))
+        )
+        intrinsic = configuration.intrinsic()
+        subgroup_m_count = configuration.subgroup_m_count()
+        subgroup_n_count = configuration.subgroup_n_count()

         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -209,7 +218,7 @@ def get_transform_function_conv(
         outs(%out : {output}) -> {output}
     }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
     %config = transform.param.constant #iree_codegen.compilation_info<
-        lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
+        lowering_config = #iree_codegen.lowering_config<workgroup = [[{workgroup_sizes}]], reduction = [[{reduction_sizes}]]>,
         translation_info = #iree_codegen.translation_info
     ) -> str:
-        tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration)))
-        intrinsic = configuration.intrinsic
-        subgroup_m_count = configuration.subgroup_m_count
-        subgroup_n_count = configuration.subgroup_n_count
+        workgroup_sizes = ", ".join(
+            map(str, get_batch_mmt_workgroup_sizes(configuration))
+        )
+        reduction_sizes = ", ".join(
+            map(str, get_batch_mmt_reduction_sizes(configuration))
+        )
+        intrinsic = configuration.intrinsic()
+        subgroup_m_count = configuration.subgroup_m_count()
+        subgroup_n_count = configuration.subgroup_n_count()

         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -273,7 +290,7 @@ def get_transform_function_broadcast_rhs_mmt(
     transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : !transform.any_value
     transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
     %config = transform.param.constant #iree_codegen.compilation_info<
-        lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
+        lowering_config = #iree_codegen.lowering_config<workgroup = [[{workgroup_sizes}]], reduction = [[{reduction_sizes}]]>,
         translation_info = #iree_codegen.translation_info
     ) -> str:
-        tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration)))
-        intrinsic = configuration.intrinsic
-        subgroup_m_count = configuration.subgroup_m_count
-        subgroup_n_count = configuration.subgroup_n_count
+        workgroup_sizes = ", ".join(
+            map(str, get_batch_mmt_workgroup_sizes(configuration))
+        )
+        reduction_sizes = ", ".join(
+            map(str, get_batch_mmt_reduction_sizes(configuration))
+        )
+        intrinsic = configuration.intrinsic()
+        subgroup_m_count = configuration.subgroup_m_count()
+        subgroup_n_count = configuration.subgroup_n_count()

         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -356,7 +382,7 @@ def get_transform_function_batch_mmt(
     transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value
     transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
     %config = transform.param.constant #iree_codegen.compilation_info<
-        lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
+        lowering_config = #iree_codegen.lowering_config<workgroup = [[{workgroup_sizes}]], reduction = [[{reduction_sizes}]]>,
         translation_info = #iree_codegen.translation_info
         output = f"tensor<{problem_size.res_type}>"

-        tile_sizes = ", ".join(
-            map(str, get_contract_tile_sizes(configuration, tile_dims))
+        workgroup_sizes = ", ".join(
+            map(str, get_contract_workgroup_sizes(configuration, tile_dims))
+        )
+        reduction_sizes = ", ".join(
+            map(str, get_contract_reduction_sizes(configuration, tile_dims))
         )
-        intrinsic = configuration.intrinsic
-        subgroup_m_count = configuration.subgroup_m_count
-        subgroup_n_count = configuration.subgroup_n_count
+        intrinsic = configuration.intrinsic()
+        subgroup_m_count = configuration.subgroup_m_count()
+        subgroup_n_count = configuration.subgroup_n_count()

         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -427,7 +459,7 @@ def get_transform_function_batch_matmul(
         outs(%out : {output}) -> {output}
     }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
     %config = transform.param.constant #iree_codegen.compilation_info<
-        lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
+        lowering_config = #iree_codegen.lowering_config<workgroup = [[{workgroup_sizes}]], reduction = [[{reduction_sizes}]]>,
         translation_info = #iree_codegen.translation_info
diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py
--- a/tuner/tuner/candidate_gen_test.py
+++ b/tuner/tuner/candidate_gen_test.py
@@ ... @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None:
     mlir_template = [
         ", subgroup_m_count = 16, subgroup_n_count = 16>",
         "",
+        "",
         "gpu_pipeline_options = #iree_gpu.pipeline_options",
         '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}',
     ]
@@ -48,25 +48,18 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None:
     mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
-    lowering_config_dict = {
-        "mma_kind": mma_attr,
-        "workgroup": ir.ArrayAttr.get(
-            [
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 8),
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 8),
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 8),
-            ]
-        ),
-        "reduction": ir.ArrayAttr.get([]),
-        "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 16),
-        "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 16),
-    }
-
-    lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
+    lowering_config = common.get_lowering_config(
+        tuner_ctx=tuner_ctx,
+        mma_attr=mma_attr,
+        workgroup=[8, 8, 0],
+        reduction=[0, 0, 8],
+        subgroup_m_count=16,
+        subgroup_n_count=16,
+    )
     config = common.Configuration(
         subgroup_size=16,
         workgroup_size=[16, 16, 1],
-        lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
+        lowering_config=lowering_config,
         gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(
             prefetch_shared_memory=True
         ),
@@ -96,7 +89,8 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None:
         "LLVMGPUVectorDistribute workgroup_size = [16, 16, 1] subgroup_size = 16"
         in modified
     )
-    assert "tile_sizes = [[8, 8, 8]]" in modified
+    assert "workgroup = [[8, 8, 0]]" in modified
+    assert "reduction = [[0, 0, 8]]" in modified
     assert (
         "gpu_pipeline_options = #iree_gpu.pipeline_options"
         in modified
     )
@@ -108,7 +102,7 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:
     mlir_template = [
         ", subgroup_m_count = 16, subgroup_n_count = 16>",
         "",
+        "",
         'gpu_pipeline_options = #iree_gpu.pipeline_options, {llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}',
     ]
@@ -116,24 +110,18 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:
     mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
-    lowering_config_dict = {
-        "mma_kind": mma_attr,
-        "workgroup": ir.ArrayAttr.get(
-            [
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 464),
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 320),
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 16),
-            ]
-        ),
-        "reduction": ir.ArrayAttr.get([]),
-        "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1),
-        "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 4),
-    }
-    lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
+    lowering_config = common.get_lowering_config(
+        tuner_ctx=tuner_ctx,
+        mma_attr=mma_attr,
+        workgroup=[464, 320, 0],
+        reduction=[0, 0, 16],
+        subgroup_m_count=1,
+        subgroup_n_count=4,
+    )
     config = common.Configuration(
         subgroup_size=64,
         workgroup_size=[256, 1, 1],
-        lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
+        lowering_config=lowering_config,
         gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(
             reorder_workgroups_strategy=iree_gpu.ReorderWorkgroupsStrategyAttr.get(
                 iree_gpu.ReorderWorkgroupsStrategy.Transpose
             )
@@ -167,7 +155,8 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:
         "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64"
         in modified
     )
-    assert "tile_sizes = [[1, 1, 464, 320, 1, 1, 16]]" in modified
+    assert "workgroup = [[1, 1, 464, 320, 1, 1, 0]]" in modified
+    assert "reduction = [[0, 0, 0, 0, 0, 0, 16]]" in modified
     assert (
         "gpu_pipeline_options = #iree_gpu.pipeline_options>"
         in modified
     )
@@ -179,7 +168,7 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None:
     mlir_template = [
         ", subgroup_m_count = 2, subgroup_n_count = 2>}>",
         "",
+        "",
         '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}',
     ]
@@ -194,25 +183,18 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None:
     mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
-    lowering_config_dict = {
-        "mma_kind": mma_attr,
-        "workgroup": ir.ArrayAttr.get(
-            [
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 480),
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 384),
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 32),
-            ]
-        ),
-        "reduction": ir.ArrayAttr.get([]),
-        "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1),
-        "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 4),
-    }
-
-    lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
+    lowering_config = common.get_lowering_config(
+        tuner_ctx=tuner_ctx,
+        mma_attr=mma_attr,
+        workgroup=[480, 384, 0],
+        reduction=[0, 0, 32],
+        subgroup_m_count=1,
+        subgroup_n_count=4,
+    )
     config = common.Configuration(
         subgroup_size=64,
         workgroup_size=[256, 1, 1],
-        lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
+        lowering_config=lowering_config,
         gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
         waves_per_eu=2,
     )
@@ -232,7 +214,8 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None:
         "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64"
         in new_mlir
     )
-    assert "tile_sizes = [[1, 480, 384, 32]]" in new_mlir
+    assert "workgroup = [[1, 480, 384, 0]]" in new_mlir
+    assert "reduction = [[0, 0, 0, 32]]" in new_mlir
     assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in new_mlir


@@ -240,7 +223,7 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None:
     mlir_template = [
         ", subgroup_m_count = 4, subgroup_n_count = 1>}>",
         "",
+        "",
         '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}',
     ]
@@ -255,26 +238,18 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None:
     mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
-    lowering_config_dict = {
-        "mma_kind": mma_attr,
-        "workgroup": ir.ArrayAttr.get(
-            [
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 416),
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 320),
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 128),
-            ]
-        ),
-        "reduction": ir.ArrayAttr.get([]),
-        "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2),
-        "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2),
-    }
-
-    lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
-
+    lowering_config = common.get_lowering_config(
+        tuner_ctx=tuner_ctx,
+        mma_attr=mma_attr,
+        workgroup=[416, 320, 0],
+        reduction=[0, 0, 128],
+        subgroup_m_count=2,
+        subgroup_n_count=2,
+    )
     config = common.Configuration(
         subgroup_size=64,
         workgroup_size=[128, 2, 1],
-        lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
+        lowering_config=lowering_config,
         gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
         waves_per_eu=2,
     )
@@ -298,7 +273,8 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None:
         "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64"
         in modified
     )
-    assert "tile_sizes = [[1, 416, 320, 128]]" in modified
+    assert "workgroup = [[1, 416, 320, 0]]" in modified
+    assert "reduction = [[0, 0, 0, 128]]" in modified
     assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified


@@ -306,7 +282,7 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None:
     mlir_template = [
         ", subgroup_m_count = 4, subgroup_n_count = 1>}>",
         "",
+        "",
         '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}',
     ]
@@ -320,25 +296,18 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None:
     mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
-    lowering_config_dict = {
-        "mma_kind": mma_attr,
-        "workgroup": ir.ArrayAttr.get(
-            [
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 128),
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 64),
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 128),
-            ]
-        ),
-        "reduction": ir.ArrayAttr.get([]),
-        "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2),
-        "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2),
-    }
-
-    lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
+    lowering_config = common.get_lowering_config(
+        tuner_ctx=tuner_ctx,
+        mma_attr=mma_attr,
+        workgroup=[128, 64, 0],
+        reduction=[0, 0, 128],
+        subgroup_m_count=2,
+        subgroup_n_count=2,
+    )
     config = common.Configuration(
         subgroup_size=64,
         workgroup_size=[128, 2, 1],
-        lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
+        lowering_config=lowering_config,
         gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
         waves_per_eu=2,
     )
@@ -360,7 +329,8 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None:
         "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64"
         in modified
     )
-    assert "tile_sizes = [[1, 128, 64, 128]]" in modified
+    assert "workgroup = [[1, 128, 64, 0]]" in modified
+    assert "reduction = [[0, 0, 0, 128]]" in modified
     assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified


@@ -368,7 +338,7 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None:
     mlir_template = [
         ", subgroup_m_count = 4, subgroup_n_count = 1>}>",
         "",
+        "",
         '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}',
{"amdgpu-waves-per-eu" = "1"}', ] @@ -382,25 +352,18 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) - lowering_config_dict = { - "mma_kind": mma_attr, - "workgroup": ir.ArrayAttr.get( - [ - ir.IntegerAttr.get(tuner_ctx.type.i32, 128), - ir.IntegerAttr.get(tuner_ctx.type.i32, 64), - ir.IntegerAttr.get(tuner_ctx.type.i32, 128), - ] - ), - "reduction": ir.ArrayAttr.get([]), - "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2), - "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2), - } - - lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_attr=mma_attr, + workgroup=[128, 64, 0], + reduction=[0, 0, 128], + subgroup_m_count=2, + subgroup_n_count=2, + ) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=4, ) @@ -424,7 +387,8 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" in modified ) - assert "tile_sizes = [[1, 128, 64, 128]]" in modified + assert "workgroup = [[1, 128, 64, 0]]" in modified + assert "reduction = [[0, 0, 0, 128]]" in modified assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified assert embeddable @@ -444,7 +408,8 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: "%config = transform.param.constant #iree_codegen.compilation_info<" in embeddable ) - assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable + assert "workgroup = [[1, 128, 64, 0]]" in embeddable + assert "reduction = [[0, 0, 0, 128]]" in embeddable assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable @@ -453,7 +418,7 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", "", + "", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] @@ -467,25 +432,18 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) - lowering_config_dict = { - "mma_kind": mma_attr, - "workgroup": ir.ArrayAttr.get( - [ - ir.IntegerAttr.get(tuner_ctx.type.i32, 128), - ir.IntegerAttr.get(tuner_ctx.type.i32, 64), - ir.IntegerAttr.get(tuner_ctx.type.i32, 128), - ] - ), - "reduction": ir.ArrayAttr.get([]), - "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2), - "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2), - } - - lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_attr=mma_attr, + workgroup=[128, 64, 0], + reduction=[0, 0, 128], + subgroup_m_count=2, + subgroup_n_count=2, + ) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=4, ) @@ -512,7 +470,8 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> 
         "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64"
         in modified
     )
-    assert "tile_sizes = [[1, 128, 64, 128]]" in modified
+    assert "workgroup = [[1, 128, 64, 0]]" in modified
+    assert "reduction = [[0, 0, 0, 128]]" in modified
     assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified

     assert embeddable
@@ -533,7 +492,8 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
         "%config = transform.param.constant #iree_codegen.compilation_info<"
         in embeddable
     )
-    assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable
+    assert "workgroup = [[1, 128, 64, 0]]" in embeddable
+    assert "reduction = [[0, 0, 0, 128]]" in embeddable
     assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable
     assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable


@@ -541,7 +501,7 @@ def test_detect_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
     mlir_lines = [
         r"%18 = tensor.empty() : tensor<2x1024x10240xi32>",
-        r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>",
+        r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>",
         r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {',
     ]
     assert candidate_gen.ContractionTuner("mk", "nk", "mnk").is_broadcast_rhs_mmt(
diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py
index ba79f3be6..a6fd4dd42 100644
--- a/tuner/tuner/common.py
+++ b/tuner/tuner/common.py
@@ -30,6 +30,9 @@ def __init__(self, ctx: ir.Context):
         self.bf16 = ir.BF16Type.get(ctx)

+    def getI32(self, value: int) -> ir.IntegerAttr:
+        return ir.IntegerAttr.get(self.i32, value)
+

 class TunerContext:
     def __init__(self, mlir_ctx: ir.Context, logger: logging.Logger):
@@ -113,25 +116,71 @@ class Configuration:
     gpu_pipeline_options: iree_gpu.PipelineOptionsAttr
     waves_per_eu: int

-    @property
-    def intrinsic(self) -> iree_gpu.MMAAttr:
-        return self.lowering_config.attributes["mma_kind"]
+    def intrinsic(self) -> Optional[iree_gpu.MMAAttr]:
+        if "mma_kind" in self.lowering_config.attributes:
+            return self.lowering_config.attributes["mma_kind"]
+        return None

-    @property
     def tilesize_workgroup(self) -> list[int]:
-        return [attr.value for attr in self.lowering_config.attributes["workgroup"]]
+        if "workgroup" in self.lowering_config.attributes:
+            workgroup_attrs = self.lowering_config.attributes["workgroup"]
+            return [attr.value for attr in workgroup_attrs]
+        return []

-    @property
     def tilesize_reduction(self) -> list[int]:
-        return [attr.value for attr in self.lowering_config.attributes["reduction"]]
-
-    @property
-    def subgroup_m_count(self) -> int:
-        return self.lowering_config.attributes["subgroup_m_count"].value
-
-    @property
-    def subgroup_n_count(self) -> int:
-        return self.lowering_config.attributes["subgroup_n_count"].value
+        if "reduction" in self.lowering_config.attributes:
+            reduction_attrs = self.lowering_config.attributes["reduction"]
+            return [attr.value for attr in reduction_attrs]
+        return []
+
+    def subgroup_m_count(self) -> Optional[int]:
+        if "subgroup_m_count" in self.lowering_config.attributes:
+            return self.lowering_config.attributes["subgroup_m_count"].value
+        return None
+
+    def subgroup_n_count(self) -> Optional[int]:
+        if "subgroup_n_count" in self.lowering_config.attributes:
+            return self.lowering_config.attributes["subgroup_n_count"].value
+        return None
+
+
+def get_lowering_config(
+    tuner_ctx: TunerContext,
+    mma_attr: Optional[iree_gpu.MMAAttr] = None,
+    workgroup: Optional[list[int]] = None,
+    reduction: Optional[list[int]] = None,
+    subgroup_m_count: Optional[int] = None,
+    subgroup_n_count: Optional[int] = None,
+) -> iree_gpu.LoweringConfigAttr:
+    lowering_config_dict = {}
+    if workgroup is not None:
+        lowering_config_dict["workgroup"] = ir.ArrayAttr.get(
+            [tuner_ctx.type.getI32(x) for x in workgroup]
+        )
+    if reduction is not None:
+        lowering_config_dict["reduction"] = ir.ArrayAttr.get(
+            [tuner_ctx.type.getI32(x) for x in reduction]
+        )
+    if subgroup_m_count is not None:
+        lowering_config_dict["subgroup_m_count"] = tuner_ctx.type.getI32(
+            subgroup_m_count
+        )
+    if subgroup_n_count is not None:
+        lowering_config_dict["subgroup_n_count"] = tuner_ctx.type.getI32(
+            subgroup_n_count
+        )
+    if mma_attr is not None:
+        lowering_config_dict["mma_kind"] = mma_attr
+    lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
+    return iree_gpu.LoweringConfigAttr.get(lowering_config_attrs)


 def get_pipeline_config(configuration: Configuration) -> str:
diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py
index bbb241980..056224458 100644
--- a/tuner/tuner/common_test.py
+++ b/tuner/tuner/common_test.py
@@ -76,24 +76,18 @@ def test_gpu_pipeline_options(tuner_ctx: common.TunerContext) -> None:
 def test_get_pipeline_config(tuner_ctx: common.TunerContext) -> None:
     mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
-    lowering_config_dict = {
-        "mma_kind": mma_attr,
-        "workgroup": ir.ArrayAttr.get(
-            [
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 4),
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 8),
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 16),
-            ]
-        ),
-        "reduction": ir.ArrayAttr.get([]),
-        "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1),
-        "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1),
-    }
-    lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
+    lowering_config = common.get_lowering_config(
+        tuner_ctx=tuner_ctx,
+        mma_attr=mma_attr,
+        workgroup=[4, 8, 0],
+        reduction=[0, 0, 16],
+        subgroup_m_count=1,
+        subgroup_n_count=1,
+    )
     config = common.Configuration(
         subgroup_size=32,
         workgroup_size=[16, 16, 1],
-        lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
+        lowering_config=lowering_config,
         gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
         waves_per_eu=2,
     )
@@ -197,3 +191,24 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
         )
         == []
     )
+
+
+def test_get_lowering_config(tuner_ctx: common.TunerContext) -> None:
+    lowering_config = common.get_lowering_config(
+        tuner_ctx=tuner_ctx,
+        workgroup=[4, 8, 0],
+        reduction=[0, 0, 16],
+        subgroup_m_count=1,
+        subgroup_n_count=1,
+    )
+    config = common.Configuration(
+        subgroup_size=32,
+        workgroup_size=[16, 16, 1],
+        lowering_config=lowering_config,
+        gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
+        waves_per_eu=2,
+    )
+
+    assert config.intrinsic() is None
+    assert config.subgroup_m_count() == 1
+    assert config.subgroup_n_count() == 1
diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py
index 421e2c7ef..dd36f5075 100644
--- a/tuner/tuner/dispatch_parser.py
+++ b/tuner/tuner/dispatch_parser.py
@@ -20,25 +20,49 @@ def parse_tensor_type(tensor_type: str) -> ShapedType:
     return ShapedType(shaped_ty.shape, shaped_ty.element_type)


-def get_mmt_tile_sizes(configuration: Configuration):
-    return configuration.tilesize_workgroup
+def get_mmt_workgroup_sizes(configuration: Configuration):
+    return configuration.tilesize_workgroup()


-def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> list[int]:
-    m, n, k = configuration.tilesize_workgroup
-    tile_size = [1] * len(tile_dims)
+def get_mmt_reduction_sizes(configuration: Configuration):
+    return configuration.tilesize_reduction()
+
+
+def get_contract_workgroup_sizes(
+    configuration: Configuration, tile_dims: str
+) -> list[int]:
+    m, n, _ = configuration.tilesize_workgroup()
+
+    workgroup_size = [1] * len(tile_dims)
     for idx, dim in enumerate(tile_dims):
         if dim == "m":
-            tile_size[idx] = m
+            workgroup_size[idx] = m
         if dim == "n":
-            tile_size[idx] = n
+            workgroup_size[idx] = n
+        if dim == "k":
+            workgroup_size[idx] = 0
+
+    return workgroup_size
+
+
+def get_contract_reduction_sizes(
+    configuration: Configuration, tile_dims: str
+) -> list[int]:
+    _, _, k = configuration.tilesize_reduction()
+    reduction_size = [0] * len(tile_dims)
+    for idx, dim in enumerate(tile_dims):
         if dim == "k":
-            tile_size[idx] = k
-    return tile_size
+            reduction_size[idx] = k
+
+    return reduction_size
+
+
+def get_batch_mmt_workgroup_sizes(configuration: Configuration) -> list[int]:
+    return [1] + configuration.tilesize_workgroup()


-def get_batch_mmt_tile_sizes(configuration: Configuration) -> list[int]:
-    return [1] + configuration.tilesize_workgroup
+def get_batch_mmt_reduction_sizes(configuration: Configuration) -> list[int]:
+    return [0] + configuration.tilesize_reduction()


 class MlirRegex(Enum):
@@ -140,18 +164,22 @@ class ConvParser(DispatchParser):
     def supports(self, op_name: str) -> bool:
         return "conv_2d_nhwc_hwcf" in op_name

-    def get_conv_tile_sizes(self, configuration: Configuration) -> list[int]:
-        m, n, k = configuration.tilesize_workgroup
+    def get_conv_workgroup_sizes(self, configuration: Configuration) -> list[int]:
         batch = 1
         fh = 1
         fw = 1

         oh = 1

-        oc = n
-        ow = m
-        ic = k
-        return [batch, oh, ow, oc, fh, fw, ic]
+        ow, oc, _ = configuration.tilesize_workgroup()
+
+        return [batch, oh, ow, oc, fh, fw, 0]
+
+    def get_conv_reduction_sizes(self, configuration: Configuration) -> list[int]:
+        _, _, ic = configuration.tilesize_reduction()
+
+        return [0, 0, 0, 0, 0, 0, ic]

     def get_shapes(self, template: list[str]) -> ProblemSize:
         for line in template:
@@ -178,13 +206,6 @@ def get_shapes(self, template: list[str]) -> ProblemSize:
         res_shaped_type = parse_tensor_type(res_tensor_type)
         assert res_shaped_type.rank() == 4

-        # int64_t n = outputShape[0];
-        # int64_t oh = outputShape[1];
-        # int64_t ow = outputShape[2];
-        # int64_t oc = outputShape[3];
-        # int64_t fh = filterShape[0];
-        # int64_t fw = filterShape[1];
-        # int64_t ic = filterShape[2];
         dim_info = ConvDimInfo.from_rhs_res(rhs_shaped_type, res_shaped_type)
         return ProblemSize(
             MatmulSize(
diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py
index 8318ca9c1..8e99188d0 100644
--- a/tuner/tuner/dispatch_parser_test.py
+++ b/tuner/tuner/dispatch_parser_test.py
@@ -42,61 +42,59 @@ def test_parse_tensor_type(tuner_ctx: common.TunerContext) -> None:
 def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None:
     mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
-    lowering_config_dict = {
-        "mma_kind": mma_attr,
-        "workgroup": ir.ArrayAttr.get(
-            [
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 128),
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 320),
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 32),
-            ]
-        ),
-        "reduction": ir.ArrayAttr.get([]),
-        "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 0),
-        "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 0),
-    }
-    lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
+    lowering_config = common.get_lowering_config(
+        tuner_ctx=tuner_ctx,
+        mma_attr=mma_attr,
+        workgroup=[128, 320, 0],
+        reduction=[0, 0, 32],
+        subgroup_m_count=1,
+        subgroup_n_count=4,
+    )
     config = dispatch_parser.Configuration(
         subgroup_size=0,
         workgroup_size=[],
-        lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
+        lowering_config=lowering_config,
         gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
         waves_per_eu=0,
     )
-    assert dispatch_parser.get_mmt_tile_sizes(config) == [128, 320, 32]
+    assert dispatch_parser.get_mmt_workgroup_sizes(config) == [128, 320, 0]
+    assert dispatch_parser.get_mmt_reduction_sizes(config) == [0, 0, 32]


 def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None:
     mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
-    lowering_config_dict = {
-        "mma_kind": mma_attr,
-        "workgroup": ir.ArrayAttr.get(
-            [
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 464),
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 320),
-                ir.IntegerAttr.get(tuner_ctx.type.i32, 16),
-            ]
-        ),
-        "reduction": ir.ArrayAttr.get([]),
-        "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1),
-        "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 4),
-    }
-    lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
+    lowering_config = common.get_lowering_config(
+        tuner_ctx=tuner_ctx,
+        mma_attr=mma_attr,
+        workgroup=[464, 320, 0],
+        reduction=[0, 0, 16],
+        subgroup_m_count=1,
+        subgroup_n_count=4,
+    )
     config = dispatch_parser.Configuration(
         subgroup_size=64,
         workgroup_size=[256, 1, 1],
-        lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
+        lowering_config=lowering_config,
         gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
         waves_per_eu=1,
     )
-    assert dispatch_parser.ConvParser().get_conv_tile_sizes(config) == [
+    assert dispatch_parser.ConvParser().get_conv_workgroup_sizes(config) == [
         1,
         1,
         464,
         320,
         1,
         1,
+        0,
+    ]
+    assert dispatch_parser.ConvParser().get_conv_reduction_sizes(config) == [
+        0,
+        0,
+        0,
+        0,
+        0,
+        0,
         16,
     ]


@@ -104,31 +102,33 @@ def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None:
     mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
-    lowering_config_dict = {
-        "mma_kind": mma_attr,
"workgroup": ir.ArrayAttr.get( - [ - ir.IntegerAttr.get(tuner_ctx.type.i32, 4), - ir.IntegerAttr.get(tuner_ctx.type.i32, 8), - ir.IntegerAttr.get(tuner_ctx.type.i32, 16), - ] - ), - "reduction": ir.ArrayAttr.get([]), - "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1), - "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1), - } - lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_attr=mma_attr, + workgroup=[4, 8, 0], + reduction=[0, 0, 16], + subgroup_m_count=1, + subgroup_n_count=1, + ) config = dispatch_parser.Configuration( subgroup_size=32, workgroup_size=[16, 16, 1], - lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs), + lowering_config=lowering_config, gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), waves_per_eu=2, ) - assert dispatch_parser.get_contract_tile_sizes(config, "mnk") == [4, 8, 16] - assert dispatch_parser.get_contract_tile_sizes(config, "nmk") == [8, 4, 16] - assert dispatch_parser.get_contract_tile_sizes(config, "knm") == [16, 8, 4] - assert dispatch_parser.get_contract_tile_sizes(config, "kkk") == [ + assert dispatch_parser.get_contract_workgroup_sizes(config, "mnk") == [4, 8, 0] + assert dispatch_parser.get_contract_reduction_sizes(config, "mnk") == [0, 0, 16] + assert dispatch_parser.get_contract_workgroup_sizes(config, "nmk") == [8, 4, 0] + assert dispatch_parser.get_contract_reduction_sizes(config, "nmk") == [0, 0, 16] + assert dispatch_parser.get_contract_workgroup_sizes(config, "knm") == [0, 8, 4] + assert dispatch_parser.get_contract_reduction_sizes(config, "knm") == [16, 0, 0] + assert dispatch_parser.get_contract_workgroup_sizes(config, "kkk") == [ + 0, + 0, + 0, + ] + assert dispatch_parser.get_contract_reduction_sizes(config, "kkk") == [ 16, 16, 16,