diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index 27dfe67c2..b33d5845e 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -119,31 +119,30 @@ class Configuration: waves_per_eu: int def intrinsic(self) -> Optional[iree_gpu.MMAAttr]: - if self.lowering_config.attributes.__contains__("mma_kind"): - return self.lowering_config.attributes.__getitem__("mma_kind") - return None + if "mma_kind" in self.lowering_config.attributes: + return self.lowering_config.attributes["mma_kind"] def tilesize_workgroup(self) -> list[int]: - if self.lowering_config.attributes.__contains__("workgroup"): - workgroup_attrs = self.lowering_config.attributes.__getitem__("workgroup") + if "workgroup" in self.lowering_config.attributes: + workgroup_attrs = self.lowering_config.attributes["workgroup"] return [attr.value for attr in workgroup_attrs] return [] def tilesize_reduction(self) -> list[int]: - if self.lowering_config.attributes.__contains__("reduction"): - reduction_attrs = self.lowering_config.attributes.__getitem__("reduction") + if "reduction" in self.lowering_config.attributes: + reduction_attrs = self.lowering_config.attributes["reduction"] return [attr.value for attr in reduction_attrs] return [] def subgroup_m_count(self) -> Optional[int]: - if self.lowering_config.attributes.__contains__("subgroup_m_count"): - attr = self.lowering_config.attributes.__getitem__("subgroup_m_count") + if "subgroup_m_count" in self.lowering_config.attributes: + attr = self.lowering_config.attributes["subgroup_m_count"] return attr.value return None def subgroup_n_count(self) -> Optional[int]: - if self.lowering_config.attributes.__contains__("subgroup_n_count"): - attr = self.lowering_config.attributes.__getitem__("subgroup_n_count") + if "subgroup_n_count" in self.lowering_config.attributes: + attr = self.lowering_config.attributes["subgroup_n_count"] return attr.value return None @@ -154,16 +153,36 @@ def get_lowering_config( ) -> iree_gpu.LoweringConfigAttr: lowering_config_dict = {} for key, value in kwargs.items(): - if isinstance(value, list): - lowering_config_dict[key] = ir.ArrayAttr.get( - [tuner_ctx.type.getI64(x) for x in value] - ) - elif isinstance(value, int): - lowering_config_dict[key] = tuner_ctx.type.getI64(value) - elif isinstance(value, iree_gpu.MMAAttr): - lowering_config_dict[key] = value - else: - raise TypeError(f"Unsupported type for key '{key}': {type(value).__name__}") + match key: + case "workgroup" | "reduction": + if isinstance(value, list): + lowering_config_dict[key] = ir.ArrayAttr.get( + [tuner_ctx.type.getI64(x) for x in value] + ) + elif isinstance(value, ir.ArrayAttr): + lowering_config_dict[key] = value + else: + raise TypeError( + f"Unsupported type for key '{key}': {type(value).__name__}" + ) + case "subgroup_m_count" | "subgroup_n_count": + if isinstance(value, int): + lowering_config_dict[key] = tuner_ctx.type.getI64(value) + elif isinstance(value, tuner_ctx.type.i64): + lowering_config_dict[key] = value + else: + raise TypeError( + f"Unsupported type for key '{key}': {type(value).__name__}" + ) + case "mma_kind": + if isinstance(value, iree_gpu.MMAAttr): + lowering_config_dict[key] = value + else: + raise TypeError( + f"Unsupported type for key '{key}': {type(value).__name__}" + ) + case _: + raise KeyError(f"Unhandled key in lowering configuration: {key}") lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) return iree_gpu.LoweringConfigAttr.get(lowering_config_attrs) diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index 056224458..1dfb6ff7b 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -78,7 +78,7 @@ def test_get_pipeline_config(tuner_ctx: common.TunerContext) -> None: mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, - mma_attr=mma_attr, + mma_kind=mma_attr, workgroup=[4, 8, 0], reduction=[0, 0, 16], subgroup_m_count=1, @@ -201,6 +201,12 @@ def test_get_lowering_config(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=1, subgroup_n_count=1, ) + + assert ( + str(lowering_config) + == "#iree_gpu.lowering_config<{reduction = [0, 0, 16], subgroup_m_count = 1 : i64, subgroup_n_count = 1 : i64, workgroup = [4, 8, 0]}>" + ) + config = common.Configuration( subgroup_size=32, workgroup_size=[16, 16, 1], diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index 39ccec523..915e84711 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -248,7 +248,7 @@ def generate_solutions( ir.IntegerAttr.get(int_type, 0), ir.IntegerAttr.get(int_type, lookup(k)), ] - ), # placeholder now to be consistent with iree + ), # Placeholder now to be consistent with iree. "subgroup_m_count": ir.IntegerAttr.get(int_type, lookup(sg_m_cnt)), "subgroup_n_count": ir.IntegerAttr.get(int_type, lookup(sg_n_cnt)), } diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py index bc7788f44..0c5209ccd 100644 --- a/tuner/tuner/dispatch_parser.py +++ b/tuner/tuner/dispatch_parser.py @@ -171,12 +171,12 @@ def get_conv_workgroup_sizes(self, configuration: Configuration) -> list[int]: oh = 1 - ow, oc, _ = configuration.tilesize_workgroup() + ow, oc, _ic = configuration.tilesize_workgroup() return [batch, oh, ow, oc, fh, fw, 0] def get_conv_reduction_sizes(self, configuration: Configuration) -> list[int]: - _, _, ic = configuration.tilesize_reduction() + _ow, _oc, ic = configuration.tilesize_reduction() return [0, 0, 0, 0, 0, 0, ic] diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py index 8e99188d0..650540c63 100644 --- a/tuner/tuner/dispatch_parser_test.py +++ b/tuner/tuner/dispatch_parser_test.py @@ -44,7 +44,7 @@ def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None: mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, - mma_attr=mma_attr, + mma_kind=mma_attr, workgroup=[128, 320, 0], reduction=[0, 0, 32], subgroup_m_count=1, @@ -66,7 +66,7 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None: mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, - mma_attr=mma_attr, + mma_kind=mma_attr, workgroup=[464, 320, 0], reduction=[0, 0, 16], subgroup_m_count=1, @@ -104,7 +104,7 @@ def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None: mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) lowering_config = common.get_lowering_config( tuner_ctx=tuner_ctx, - mma_attr=mma_attr, + mma_kind=mma_attr, workgroup=[4, 8, 0], reduction=[0, 0, 16], subgroup_m_count=1, @@ -123,16 +123,8 @@ def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None: assert dispatch_parser.get_contract_reduction_sizes(config, "nmk") == [0, 0, 16] assert dispatch_parser.get_contract_workgroup_sizes(config, "knm") == [0, 8, 4] assert dispatch_parser.get_contract_reduction_sizes(config, "knm") == [16, 0, 0] - assert dispatch_parser.get_contract_workgroup_sizes(config, "kkk") == [ - 0, - 0, - 0, - ] - assert dispatch_parser.get_contract_reduction_sizes(config, "kkk") == [ - 16, - 16, - 16, - ] + assert dispatch_parser.get_contract_workgroup_sizes(config, "kkk") == [0, 0, 0] + assert dispatch_parser.get_contract_reduction_sizes(config, "kkk") == [16, 16, 16] def test_get_shapes_mmt(tuner_ctx: common.TunerContext) -> None: