Skip to content

Commit

Permalink
[tuner]: format the code
Browse files Browse the repository at this point in the history
Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu committed Dec 2, 2024
1 parent c8a6542 commit c4f917b
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 38 deletions.
61 changes: 40 additions & 21 deletions tuner/tuner/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,31 +119,30 @@ class Configuration:
waves_per_eu: int

def intrinsic(self) -> Optional[iree_gpu.MMAAttr]:
if self.lowering_config.attributes.__contains__("mma_kind"):
return self.lowering_config.attributes.__getitem__("mma_kind")
return None
if "mma_kind" in self.lowering_config.attributes:
return self.lowering_config.attributes["mma_kind"]

def tilesize_workgroup(self) -> list[int]:
if self.lowering_config.attributes.__contains__("workgroup"):
workgroup_attrs = self.lowering_config.attributes.__getitem__("workgroup")
if "workgroup" in self.lowering_config.attributes:
workgroup_attrs = self.lowering_config.attributes["workgroup"]
return [attr.value for attr in workgroup_attrs]
return []

def tilesize_reduction(self) -> list[int]:
if self.lowering_config.attributes.__contains__("reduction"):
reduction_attrs = self.lowering_config.attributes.__getitem__("reduction")
if "reduction" in self.lowering_config.attributes:
reduction_attrs = self.lowering_config.attributes["reduction"]
return [attr.value for attr in reduction_attrs]
return []

def subgroup_m_count(self) -> Optional[int]:
if self.lowering_config.attributes.__contains__("subgroup_m_count"):
attr = self.lowering_config.attributes.__getitem__("subgroup_m_count")
if "subgroup_m_count" in self.lowering_config.attributes:
attr = self.lowering_config.attributes["subgroup_m_count"]
return attr.value
return None

def subgroup_n_count(self) -> Optional[int]:
if self.lowering_config.attributes.__contains__("subgroup_n_count"):
attr = self.lowering_config.attributes.__getitem__("subgroup_n_count")
if "subgroup_n_count" in self.lowering_config.attributes:
attr = self.lowering_config.attributes["subgroup_n_count"]
return attr.value
return None

Expand All @@ -154,16 +153,36 @@ def get_lowering_config(
) -> iree_gpu.LoweringConfigAttr:
lowering_config_dict = {}
for key, value in kwargs.items():
if isinstance(value, list):
lowering_config_dict[key] = ir.ArrayAttr.get(
[tuner_ctx.type.getI64(x) for x in value]
)
elif isinstance(value, int):
lowering_config_dict[key] = tuner_ctx.type.getI64(value)
elif isinstance(value, iree_gpu.MMAAttr):
lowering_config_dict[key] = value
else:
raise TypeError(f"Unsupported type for key '{key}': {type(value).__name__}")
match key:
case "workgroup" | "reduction":
if isinstance(value, list):
lowering_config_dict[key] = ir.ArrayAttr.get(
[tuner_ctx.type.getI64(x) for x in value]
)
elif isinstance(value, ir.ArrayAttr):
lowering_config_dict[key] = value
else:
raise TypeError(
f"Unsupported type for key '{key}': {type(value).__name__}"
)
case "subgroup_m_count" | "subgroup_n_count":
if isinstance(value, int):
lowering_config_dict[key] = tuner_ctx.type.getI64(value)
elif isinstance(value, tuner_ctx.type.i64):
lowering_config_dict[key] = value
else:
raise TypeError(
f"Unsupported type for key '{key}': {type(value).__name__}"
)
case "mma_kind":
if isinstance(value, iree_gpu.MMAAttr):
lowering_config_dict[key] = value
else:
raise TypeError(
f"Unsupported type for key '{key}': {type(value).__name__}"
)
case _:
raise KeyError(f"Unhandled key in lowering configuration: {key}")
lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
return iree_gpu.LoweringConfigAttr.get(lowering_config_attrs)

Expand Down
8 changes: 7 additions & 1 deletion tuner/tuner/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_get_pipeline_config(tuner_ctx: common.TunerContext) -> None:
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
lowering_config = common.get_lowering_config(
tuner_ctx=tuner_ctx,
mma_attr=mma_attr,
mma_kind=mma_attr,
workgroup=[4, 8, 0],
reduction=[0, 0, 16],
subgroup_m_count=1,
Expand Down Expand Up @@ -201,6 +201,12 @@ def test_get_lowering_config(tuner_ctx: common.TunerContext) -> None:
subgroup_m_count=1,
subgroup_n_count=1,
)

assert (
str(lowering_config)
== "#iree_gpu.lowering_config<{reduction = [0, 0, 16], subgroup_m_count = 1 : i64, subgroup_n_count = 1 : i64, workgroup = [4, 8, 0]}>"
)

config = common.Configuration(
subgroup_size=32,
workgroup_size=[16, 16, 1],
Expand Down
2 changes: 1 addition & 1 deletion tuner/tuner/dispatch_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def generate_solutions(
ir.IntegerAttr.get(int_type, 0),
ir.IntegerAttr.get(int_type, lookup(k)),
]
), # placeholder now to be consistent with iree
), # Placeholder now to be consistent with iree.
"subgroup_m_count": ir.IntegerAttr.get(int_type, lookup(sg_m_cnt)),
"subgroup_n_count": ir.IntegerAttr.get(int_type, lookup(sg_n_cnt)),
}
Expand Down
4 changes: 2 additions & 2 deletions tuner/tuner/dispatch_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,12 @@ def get_conv_workgroup_sizes(self, configuration: Configuration) -> list[int]:

oh = 1

ow, oc, _ = configuration.tilesize_workgroup()
ow, oc, _ic = configuration.tilesize_workgroup()

return [batch, oh, ow, oc, fh, fw, 0]

def get_conv_reduction_sizes(self, configuration: Configuration) -> list[int]:
_, _, ic = configuration.tilesize_reduction()
_ow, _oc, ic = configuration.tilesize_reduction()

return [0, 0, 0, 0, 0, 0, ic]

Expand Down
18 changes: 5 additions & 13 deletions tuner/tuner/dispatch_parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None:
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
lowering_config = common.get_lowering_config(
tuner_ctx=tuner_ctx,
mma_attr=mma_attr,
mma_kind=mma_attr,
workgroup=[128, 320, 0],
reduction=[0, 0, 32],
subgroup_m_count=1,
Expand All @@ -66,7 +66,7 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None:
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
lowering_config = common.get_lowering_config(
tuner_ctx=tuner_ctx,
mma_attr=mma_attr,
mma_kind=mma_attr,
workgroup=[464, 320, 0],
reduction=[0, 0, 16],
subgroup_m_count=1,
Expand Down Expand Up @@ -104,7 +104,7 @@ def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None:
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
lowering_config = common.get_lowering_config(
tuner_ctx=tuner_ctx,
mma_attr=mma_attr,
mma_kind=mma_attr,
workgroup=[4, 8, 0],
reduction=[0, 0, 16],
subgroup_m_count=1,
Expand All @@ -123,16 +123,8 @@ def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None:
assert dispatch_parser.get_contract_reduction_sizes(config, "nmk") == [0, 0, 16]
assert dispatch_parser.get_contract_workgroup_sizes(config, "knm") == [0, 8, 4]
assert dispatch_parser.get_contract_reduction_sizes(config, "knm") == [16, 0, 0]
assert dispatch_parser.get_contract_workgroup_sizes(config, "kkk") == [
0,
0,
0,
]
assert dispatch_parser.get_contract_reduction_sizes(config, "kkk") == [
16,
16,
16,
]
assert dispatch_parser.get_contract_workgroup_sizes(config, "kkk") == [0, 0, 0]
assert dispatch_parser.get_contract_reduction_sizes(config, "kkk") == [16, 16, 16]


def test_get_shapes_mmt(tuner_ctx: common.TunerContext) -> None:
Expand Down

0 comments on commit c4f917b

Please sign in to comment.