Skip to content

Commit

Permalink
[tuner]: rename functions and assert False
Browse files Browse the repository at this point in the history
Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu committed Dec 3, 2024
1 parent 33e5eac commit 0a9ad21
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 23 deletions.
4 changes: 2 additions & 2 deletions tuner/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def apply_configuration(
expr5 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"")
repl0 = f"<intrinsic = {intrinsic}, subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>"
repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},'
repl2 = f'workgroup = [{", ".join(map(str, workgroup_sizes))}]'
repl3 = f'reduction = [{", ".join(map(str, reduction_sizes))}]'
repl2 = f"workgroup = {workgroup_sizes}"
repl3 = f"reduction = {reduction_sizes}"
repl4 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}"
repl5 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"'

Expand Down
29 changes: 16 additions & 13 deletions tuner/tuner/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,14 @@ def get_intrinsic(config: Configuration) -> Optional[iree_gpu.MMAAttr]:
return None


def get_tilesize_workgroup(config: Configuration) -> list[int]:
def get_workgroup_tile_sizes(config: Configuration) -> list[int]:
if "workgroup" in config.lowering_config.attributes:
workgroup_attrs = config.lowering_config.attributes["workgroup"]
return [attr.value for attr in workgroup_attrs]
return []


def get_tilesize_reduction(config: Configuration) -> list[int]:
def get_reduction_tile_sizes(config: Configuration) -> list[int]:
if "reduction" in config.lowering_config.attributes:
reduction_attrs = config.lowering_config.attributes["reduction"]
return [attr.value for attr in reduction_attrs]
Expand Down Expand Up @@ -163,26 +163,29 @@ def get_lowering_config(
promoted_value = value
match key:
case "workgroup" | "reduction":
assert isinstance(
value, (list, ir.ArrayAttr)
), f"Unsupported type for key '{key}': {type(value).__name__}"
if isinstance(value, list):
promoted_value = ir.ArrayAttr.get(
[tuner_ctx.type.getI64(x) for x in value]
)
elif not isinstance(value, ir.ArrayAttr):
assert (
False
), f"Unsupported type for key '{key}': {type(value).__name__}"
case "subgroup_m_count" | "subgroup_n_count":
assert isinstance(
value, (int, tuner_ctx.type.i64)
), f"Unsupported type for key '{key}': {type(value).__name__}"
if isinstance(value, int):
promoted_value = tuner_ctx.type.getI64(value)
elif not isinstance(value, tuner_ctx.type.i64):
assert (
False
), f"Unsupported type for key '{key}': {type(value).__name__}"
case "mma_kind":
assert isinstance(
value, iree_gpu.MMAAttr
), f"Unsupported type for key '{key}': {type(value).__name__}"
if not isinstance(value, iree_gpu.MMAAttr):
assert (
False
), f"Unsupported type for key '{key}': {type(value).__name__}"
case _:
raise KeyError(f"Unhandled key in lowering configuration: {key}")
# Single assignment after the match.
assert False, f"Unhandled key in lowering configuration: {key}"

lowering_config_dict[key] = promoted_value
lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
return iree_gpu.LoweringConfigAttr.get(lowering_config_attrs)
Expand Down
16 changes: 8 additions & 8 deletions tuner/tuner/dispatch_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@ def parse_tensor_type(tensor_type: str) -> ShapedType:


def get_mmt_workgroup_sizes(configuration: Configuration):
return get_tilesize_workgroup(configuration)
return get_workgroup_tile_sizes(configuration)


def get_mmt_reduction_sizes(configuration: Configuration):
return get_tilesize_reduction(configuration)
return get_reduction_tile_sizes(configuration)


def get_contract_workgroup_sizes(
configuration: Configuration, tile_dims: str
) -> list[int]:
m, n, _k = get_tilesize_workgroup(configuration)
m, n, _k = get_workgroup_tile_sizes(configuration)

workgroup_size = [1] * len(tile_dims)
for idx, dim in enumerate(tile_dims):
Expand All @@ -48,7 +48,7 @@ def get_contract_workgroup_sizes(
def get_contract_reduction_sizes(
configuration: Configuration, tile_dims: str
) -> list[int]:
_m, _n, k = get_tilesize_reduction(configuration)
_m, _n, k = get_reduction_tile_sizes(configuration)
reduction_size = [0] * len(tile_dims)
for idx, dim in enumerate(tile_dims):
if dim == "k":
Expand All @@ -58,11 +58,11 @@ def get_contract_reduction_sizes(


def get_batch_mmt_workgroup_sizes(configuration: Configuration) -> list[int]:
return [1] + get_tilesize_workgroup(configuration)
return [1] + get_workgroup_tile_sizes(configuration)


def get_batch_mmt_reduction_sizes(configuration: Configuration) -> list[int]:
return [0] + get_tilesize_reduction(configuration)
return [0] + get_reduction_tile_sizes(configuration)


class MlirRegex(Enum):
Expand Down Expand Up @@ -171,12 +171,12 @@ def get_conv_workgroup_sizes(self, configuration: Configuration) -> list[int]:

oh = 1

ow, oc, _ic = get_tilesize_workgroup(configuration)
ow, oc, _ic = get_workgroup_tile_sizes(configuration)

return [batch, oh, ow, oc, fh, fw, 0]

def get_conv_reduction_sizes(self, configuration: Configuration) -> list[int]:
_ow, _oc, ic = get_tilesize_reduction(configuration)
_ow, _oc, ic = get_reduction_tile_sizes(configuration)

return [0, 0, 0, 0, 0, 0, ic]

Expand Down

0 comments on commit 0a9ad21

Please sign in to comment.