diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py
index c903ec85f..bc01bb709 100644
--- a/tuner/tuner/candidate_gen.py
+++ b/tuner/tuner/candidate_gen.py
@@ -38,16 +38,19 @@
 tune_logger = logging.getLogger("tune")
 
 
-# TODO: remove the argument 'workgroup_sizes' and 'reduction_sizes'.
+
 def apply_configuration(
     template: list[str],
     configuration: Configuration,
-    workgroup_sizes: list[int],
-    reduction_sizes: list[int],
 ) -> str:
-    intrinsic = get_intrinsic(configuration)
-    subgroup_m_count = get_subgroup_m_count(configuration)
-    subgroup_n_count = get_subgroup_n_count(configuration)
+    lowering_config = configuration.lowering_config
+    intrinsic = lowering_config.mma_kind
+    (
+        subgroup_m_count,
+        subgroup_n_count,
+    ) = lowering_config.subgroup_count_mn
+    workgroup_sizes = lowering_config.workgroup_tile_sizes
+    reduction_sizes = lowering_config.reduction_tile_sizes
     tune_logger.info(f"Applying: {configuration}")
     expr0 = re.compile(
         r", subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>"
@@ -125,9 +128,12 @@ class MmtTuner(DispatchTuner, MmtParser):
     def get_transform_function_mmt(
         self, problem_size: ProblemSize, functionName: str, configuration: Configuration
     ) -> str:
-        intrinsic = get_intrinsic(configuration)
-        subgroup_m_count = get_subgroup_m_count(configuration)
-        subgroup_n_count = get_subgroup_n_count(configuration)
+        lowering_config = configuration.lowering_config
+        intrinsic = lowering_config.mma_kind
+        (
+            subgroup_m_count,
+            subgroup_n_count,
+        ) = lowering_config.subgroup_count_mn
 
         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -167,8 +173,6 @@ def apply_params(
         modified += apply_configuration(
             template,
             configuration,
-            get_mmt_workgroup_sizes(configuration),
-            get_mmt_reduction_sizes(configuration),
         )
         embeddable = indent(
             self.get_transform_function_mmt(problem_size, f"match_op", configuration),
@@ -193,15 +197,12 @@ def get_transform_function_conv(
         filter = f"tensor<{problem_size.rhs_type}>"
         output = f"tensor<{dynamic_batch_output_ty}>"
 
-        workgroup_sizes = ", ".join(
-            map(str, self.get_conv_workgroup_sizes(configuration))
-        )
-        reduction_sizes = ", ".join(
-            map(str, self.get_conv_reduction_sizes(configuration))
-        )
-        intrinsic = get_intrinsic(configuration)
-        subgroup_m_count = get_subgroup_m_count(configuration)
-        subgroup_n_count = get_subgroup_n_count(configuration)
+        lowering_config = configuration.lowering_config
+        intrinsic = lowering_config.mma_kind
+        (
+            subgroup_m_count,
+            subgroup_n_count,
+        ) = lowering_config.subgroup_count_mn
 
         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -246,8 +247,6 @@ def apply_params(
         modified += apply_configuration(
             template,
             configuration,
-            self.get_conv_workgroup_sizes(configuration),
-            self.get_conv_reduction_sizes(configuration),
         )
         embeddable = indent(
             self.get_transform_function_conv(problem_size, f"match_op", configuration),
@@ -263,15 +262,12 @@ def get_transform_function_broadcast_rhs_mmt(
         self,
         problem_size: ProblemSize,
         functionName: str,
         configuration: Configuration,
     ) -> str:
-        workgroup_sizes = ", ".join(
-            map(str, get_batch_mmt_workgroup_sizes(configuration))
-        )
-        reduction_sizes = ", ".join(
-            map(str, get_batch_mmt_reduction_sizes(configuration))
-        )
-        intrinsic = get_intrinsic(configuration)
-        subgroup_m_count = get_subgroup_m_count(configuration)
-        subgroup_n_count = get_subgroup_n_count(configuration)
+        lowering_config = configuration.lowering_config
+        intrinsic = lowering_config.mma_kind
+        (
+            subgroup_m_count,
+            subgroup_n_count,
+        ) = lowering_config.subgroup_count_mn
 
         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -316,8 +312,6 @@ def apply_params_broadcast_rhs_mmt(
         modified += apply_configuration(
             template,
             configuration,
-            get_batch_mmt_workgroup_sizes(configuration),
-            get_batch_mmt_reduction_sizes(configuration),
         )
 
         embeddable = indent(
@@ -345,8 +339,6 @@ def apply_params(
             apply_configuration(
                 template,
                 configuration,
-                get_contract_workgroup_sizes(configuration, self.tile_dims),
-                get_contract_reduction_sizes(configuration, self.tile_dims),
             ),
             "",
         )
@@ -359,9 +351,12 @@ def get_transform_function_batch_mmt(
         self,
         problem_size: ProblemSize,
         functionName: str,
         configuration: Configuration,
     ) -> str:
-        intrinsic = get_intrinsic(configuration)
-        subgroup_m_count = get_subgroup_m_count(configuration)
-        subgroup_n_count = get_subgroup_n_count(configuration)
+        lowering_config = configuration.lowering_config
+        intrinsic = lowering_config.mma_kind
+        (
+            subgroup_m_count,
+            subgroup_n_count,
+        ) = lowering_config.subgroup_count_mn
 
         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -403,8 +398,6 @@ def apply_params(
         modified += apply_configuration(
             template,
             configuration,
-            get_batch_mmt_workgroup_sizes(configuration),
-            get_batch_mmt_reduction_sizes(configuration),
         )
 
         embeddable = indent(
@@ -428,9 +421,12 @@ def get_transform_function_batch_matmul(
         input1 = f"tensor<{problem_size.rhs_type}>"
         output = f"tensor<{problem_size.res_type}>"
 
-        intrinsic = get_intrinsic(configuration)
-        subgroup_m_count = get_subgroup_m_count(configuration)
-        subgroup_n_count = get_subgroup_n_count(configuration)
+        lowering_config = configuration.lowering_config
+        intrinsic = lowering_config.mma_kind
+        (
+            subgroup_m_count,
+            subgroup_n_count,
+        ) = lowering_config.subgroup_count_mn
 
         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -476,8 +472,6 @@ def apply_params(
         modified += apply_configuration(
             template,
             configuration,
-            get_contract_workgroup_sizes(configuration, self.tile_dims),
-            get_contract_reduction_sizes(configuration, self.tile_dims),
         )
 
         embeddable = indent(
diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py
index 11de8a900..45da323c5 100644
--- a/tuner/tuner/candidate_gen_test.py
+++ b/tuner/tuner/candidate_gen_test.py
@@ -106,15 +106,15 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:
         'gpu_pipeline_options = #iree_gpu.pipeline_options, {llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}',
     ]
 
-    n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640
+    n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 16
 
     mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     lowering_config = common.get_lowering_config(
         tuner_ctx=tuner_ctx,
         mma_kind=mma_attr,
-        workgroup=[464, 320, 0],
-        reduction=[0, 0, 16],
+        workgroup=[n, oh, ow, oc, fh, fw, 0],
+        reduction=[0, 0, 0, 0, 0, 0, ic],
         subgroup_m_count=1,
         subgroup_n_count=4,
     )
@@ -155,7 +155,7 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:
         "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64"
         in modified
     )
-    assert "workgroup = [1, 1, 464, 320, 1, 1, 0]" in modified
+    assert "workgroup = [2, 64, 64, 640, 3, 3, 0]" in modified
     assert "reduction = [0, 0, 0, 0, 0, 0, 16]" in modified
     assert (
         "gpu_pipeline_options = #iree_gpu.pipeline_options>"
@@ -186,8 +186,8 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None:
     lowering_config = common.get_lowering_config(
         tuner_ctx=tuner_ctx,
         mma_kind=mma_attr,
-        workgroup=[480, 384, 0],
-        reduction=[0, 0, 32],
+        workgroup=[1, 480, 384, 0],
+        reduction=[0, 0, 0, 32],
         subgroup_m_count=1,
         subgroup_n_count=4,
     )
@@ -241,8 +241,8 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None:
     lowering_config = common.get_lowering_config(
         tuner_ctx=tuner_ctx,
         mma_kind=mma_attr,
-        workgroup=[416, 320, 0],
-        reduction=[0, 0, 128],
+        workgroup=[1, 416, 320, 0],
+        reduction=[0, 0, 0, 128],
         subgroup_m_count=2,
         subgroup_n_count=2,
     )
@@ -299,8 +299,8 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None:
     lowering_config = common.get_lowering_config(
         tuner_ctx=tuner_ctx,
         mma_kind=mma_attr,
-        workgroup=[128, 64, 0],
-        reduction=[0, 0, 128],
+        workgroup=[1, 128, 64, 0],
+        reduction=[0, 0, 0, 128],
         subgroup_m_count=2,
         subgroup_n_count=2,
     )
@@ -355,8 +355,8 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None:
     lowering_config = common.get_lowering_config(
         tuner_ctx=tuner_ctx,
         mma_kind=mma_attr,
-        workgroup=[128, 64, 0],
-        reduction=[0, 0, 128],
+        workgroup=[1, 128, 64, 0],
+        reduction=[0, 0, 0, 128],
         subgroup_m_count=2,
         subgroup_n_count=2,
     )
@@ -408,8 +408,8 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None:
         "%config = transform.param.constant #iree_codegen.compilation_info<"
         in embeddable
     )
-    assert "workgroup = [128, 64, 0]" in embeddable
-    assert "reduction = [0, 0, 128]" in embeddable
+    assert "workgroup = [1, 128, 64, 0]" in embeddable
+    assert "reduction = [0, 0, 0, 128]" in embeddable
     assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable
     assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable
 
@@ -435,8 +435,8 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
     lowering_config = common.get_lowering_config(
         tuner_ctx=tuner_ctx,
         mma_kind=mma_attr,
-        workgroup=[128, 64, 0],
-        reduction=[0, 0, 128],
+        workgroup=[1, 128, 64, 0],
+        reduction=[0, 0, 0, 128],
         subgroup_m_count=2,
         subgroup_n_count=2,
     )
@@ -492,8 +492,8 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
         "%config = transform.param.constant #iree_codegen.compilation_info<"
         in embeddable
     )
-    assert "workgroup = [128, 64, 0]" in embeddable
-    assert "reduction = [0, 0, 128]" in embeddable
+    assert "workgroup = [1, 128, 64, 0]" in embeddable
+    assert "reduction = [0, 0, 0, 128]" in embeddable
     assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable
     assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable
 
diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py
index 702008f5e..0a2b03fd1 100644
--- a/tuner/tuner/common.py
+++ b/tuner/tuner/common.py
@@ -119,40 +119,6 @@ class Configuration:
     waves_per_eu: int
 
 
-def get_intrinsic(config: Configuration) -> Optional[iree_gpu.MMAAttr]:
-    if "mma_kind" in config.lowering_config.attributes:
-        return config.lowering_config.attributes["mma_kind"]
-    return None
-
-
-def get_workgroup_tile_sizes(config: Configuration) -> list[int]:
-    if "workgroup" in config.lowering_config.attributes:
-        workgroup_attrs = config.lowering_config.attributes["workgroup"]
-        return [attr.value for attr in workgroup_attrs]
-    return []
-
-
-def get_reduction_tile_sizes(config: Configuration) -> list[int]:
-    if "reduction" in config.lowering_config.attributes:
-        reduction_attrs = config.lowering_config.attributes["reduction"]
-        return [attr.value for attr in reduction_attrs]
-    return []
-
-
-def get_subgroup_m_count(config: Configuration) -> Optional[int]:
-    if "subgroup_m_count" in config.lowering_config.attributes:
-        attr = config.lowering_config.attributes["subgroup_m_count"]
-        return attr.value
-    return None
-
-
-def get_subgroup_n_count(config: Configuration) -> Optional[int]:
-    if "subgroup_n_count" in config.lowering_config.attributes:
-        attr = config.lowering_config.attributes["subgroup_n_count"]
-        return attr.value
-    return None
-
-
 def get_lowering_config(
     tuner_ctx: TunerContext,
     **kwargs: Any,
diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py
index f13aed3d7..6d76c216f 100644
--- a/tuner/tuner/common_test.py
+++ b/tuner/tuner/common_test.py
@@ -215,6 +215,5 @@ def test_get_lowering_config(tuner_ctx: common.TunerContext) -> None:
         waves_per_eu=2,
     )
 
-    assert common.get_intrinsic(config) is None
-    assert common.get_subgroup_m_count(config) == 1
-    assert common.get_subgroup_n_count(config) == 1
+    assert config.lowering_config.mma_kind is None
+    assert config.lowering_config.subgroup_count_mn == (1, 1)
diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py
index ad63ba815..cc63c89a3 100644
--- a/tuner/tuner/dispatch_parser.py
+++ b/tuner/tuner/dispatch_parser.py
@@ -20,18 +20,10 @@ def parse_tensor_type(tensor_type: str) -> ShapedType:
     return ShapedType(shaped_ty.shape, shaped_ty.element_type)
 
 
-def get_mmt_workgroup_sizes(configuration: Configuration):
-    return get_workgroup_tile_sizes(configuration)
-
-
-def get_mmt_reduction_sizes(configuration: Configuration):
-    return get_reduction_tile_sizes(configuration)
-
-
 def get_contract_workgroup_sizes(
     configuration: Configuration, tile_dims: str
 ) -> list[int]:
-    m, n, _k = get_workgroup_tile_sizes(configuration)
+    m, n, _k = configuration.lowering_config.workgroup_tile_sizes
 
     workgroup_size = [1] * len(tile_dims)
     for idx, dim in enumerate(tile_dims):
@@ -48,7 +40,7 @@ def get_contract_workgroup_sizes(
 def get_contract_reduction_sizes(
     configuration: Configuration, tile_dims: str
 ) -> list[int]:
-    _m, _n, k = get_reduction_tile_sizes(configuration)
+    _m, _n, k = configuration.lowering_config.reduction_tile_sizes
     reduction_size = [0] * len(tile_dims)
     for idx, dim in enumerate(tile_dims):
         if dim == "k":
@@ -57,14 +49,6 @@
             reduction_size[idx] = k
 
     return reduction_size
 
 
-def get_batch_mmt_workgroup_sizes(configuration: Configuration) -> list[int]:
-    return [1] + get_workgroup_tile_sizes(configuration)
-
-
-def get_batch_mmt_reduction_sizes(configuration: Configuration) -> list[int]:
-    return [0] + get_reduction_tile_sizes(configuration)
-
-
 class MlirRegex(Enum):
     ssa_value = r"%[a-zA-Z0-9-_]+"
     tensor_type = r"tensor<([^>]+)>"
@@ -164,22 +148,6 @@ class ConvParser(DispatchParser):
     def supports(self, op_name: str) -> bool:
         return "conv_2d_nhwc_hwcf" in op_name
 
-    def get_conv_workgroup_sizes(self, configuration: Configuration) -> list[int]:
-        batch = 1
-        fh = 1
-        fw = 1
-
-        oh = 1
-
-        ow, oc, _ic = get_workgroup_tile_sizes(configuration)
-
-        return [batch, oh, ow, oc, fh, fw, 0]
-
-    def get_conv_reduction_sizes(self, configuration: Configuration) -> list[int]:
-        _ow, _oc, ic = get_reduction_tile_sizes(configuration)
-
-        return [0, 0, 0, 0, 0, 0, ic]
-
     def get_shapes(self, template: list[str]) -> ProblemSize:
         for line in template:
             if "linalg.conv_2d_nhwc_hwcf" not in line:
diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py
index 650540c63..db8c4a7da 100644
--- a/tuner/tuner/dispatch_parser_test.py
+++ b/tuner/tuner/dispatch_parser_test.py
@@ -57,8 +57,9 @@ def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None:
         gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
         waves_per_eu=0,
     )
-    assert dispatch_parser.get_mmt_workgroup_sizes(config) == [128, 320, 0]
-    assert dispatch_parser.get_mmt_reduction_sizes(config) == [0, 0, 32]
+    lowering_config = config.lowering_config
+    assert lowering_config.workgroup_tile_sizes == [128, 320, 0]
+    assert lowering_config.reduction_tile_sizes == [0, 0, 32]
 
 
 def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None:
@@ -67,8 +68,8 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None:
     lowering_config = common.get_lowering_config(
         tuner_ctx=tuner_ctx,
         mma_kind=mma_attr,
-        workgroup=[464, 320, 0],
-        reduction=[0, 0, 16],
+        workgroup=[1, 1, 464, 320, 1, 1, 0],
+        reduction=[0, 0, 0, 0, 0, 0, 16],
         subgroup_m_count=1,
         subgroup_n_count=4,
     )
@@ -79,24 +80,8 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None:
         gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
         waves_per_eu=1,
     )
-    assert dispatch_parser.ConvParser().get_conv_workgroup_sizes(config) == [
-        1,
-        1,
-        464,
-        320,
-        1,
-        1,
-        0,
-    ]
-    assert dispatch_parser.ConvParser().get_conv_reduction_sizes(config) == [
-        0,
-        0,
-        0,
-        0,
-        0,
-        0,
-        16,
-    ]
+    assert config.lowering_config.workgroup_tile_sizes == [1, 1, 464, 320, 1, 1, 0]
+    assert config.lowering_config.reduction_tile_sizes == [0, 0, 0, 0, 0, 0, 16]
 
 
 def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None:
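Note: the pattern this change converges on everywhere is to read every tuning parameter off configuration.lowering_config in one place, instead of going through the removed per-dispatch helper functions. A minimal sketch of that accessor pattern, assuming a Configuration built via common.get_lowering_config() as in the tests above; the summarize_lowering_config helper itself is hypothetical, for illustration only, and not part of this change:

    # Hypothetical helper, not part of this diff; it only exercises the
    # lowering_config properties this change migrates to. `config` is assumed
    # to be a tuner Configuration (see tuner/tuner/common.py).
    def summarize_lowering_config(config) -> str:
        lowering_config = config.lowering_config
        intrinsic = lowering_config.mma_kind  # None when no MMA intrinsic is set
        subgroup_m_count, subgroup_n_count = lowering_config.subgroup_count_mn
        workgroup_sizes = lowering_config.workgroup_tile_sizes
        reduction_sizes = lowering_config.reduction_tile_sizes
        return (
            f"intrinsic={intrinsic}, "
            f"subgroups=({subgroup_m_count}, {subgroup_n_count}), "
            f"workgroup={workgroup_sizes}, reduction={reduction_sizes}"
        )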