Skip to content

Commit

Permalink
[tuner] move methods out of configuration as free functions
Browse files Browse the repository at this point in the history
Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu committed Dec 2, 2024
1 parent 4aa640c commit 2f1a638
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 77 deletions.
36 changes: 18 additions & 18 deletions tuner/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def apply_configuration(
workgroup_sizes: list[int],
reduction_sizes: list[int],
) -> str:
intrinsic = configuration.intrinsic()
subgroup_m_count = configuration.subgroup_m_count()
subgroup_n_count = configuration.subgroup_n_count()
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)
tune_logger.info(f"Applying: {configuration}")
expr0 = re.compile(
r"<intrinsic = #iree_gpu\.mma_layout<(.+)>, subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>"
Expand Down Expand Up @@ -125,9 +125,9 @@ class MmtTuner(DispatchTuner, MmtParser):
def get_transform_function_mmt(
self, problem_size: ProblemSize, functionName: str, configuration: Configuration
) -> str:
intrinsic = configuration.intrinsic()
subgroup_m_count = configuration.subgroup_m_count()
subgroup_n_count = configuration.subgroup_n_count()
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
Expand Down Expand Up @@ -199,9 +199,9 @@ def get_transform_function_conv(
reduction_sizes = ", ".join(
map(str, self.get_conv_reduction_sizes(configuration))
)
intrinsic = configuration.intrinsic()
subgroup_m_count = configuration.subgroup_m_count()
subgroup_n_count = configuration.subgroup_n_count()
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
Expand Down Expand Up @@ -269,9 +269,9 @@ def get_transform_function_broadcast_rhs_mmt(
reduction_sizes = ", ".join(
map(str, get_batch_mmt_reduction_sizes(configuration))
)
intrinsic = configuration.intrinsic()
subgroup_m_count = configuration.subgroup_m_count()
subgroup_n_count = configuration.subgroup_n_count()
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
Expand Down Expand Up @@ -359,9 +359,9 @@ def get_transform_function_batch_mmt(
functionName: str,
configuration: Configuration,
) -> str:
intrinsic = configuration.intrinsic()
subgroup_m_count = configuration.subgroup_m_count()
subgroup_n_count = configuration.subgroup_n_count()
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
Expand Down Expand Up @@ -428,9 +428,9 @@ def get_transform_function_batch_matmul(
input1 = f"tensor<{problem_size.rhs_type}>"
output = f"tensor<{problem_size.res_type}>"

intrinsic = configuration.intrinsic()
subgroup_m_count = configuration.subgroup_m_count()
subgroup_n_count = configuration.subgroup_n_count()
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
Expand Down
96 changes: 48 additions & 48 deletions tuner/tuner/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,34 +118,39 @@ class Configuration:
gpu_pipeline_options: iree_gpu.PipelineOptionsAttr
waves_per_eu: int

def intrinsic(self) -> Optional[iree_gpu.MMAAttr]:
if "mma_kind" in self.lowering_config.attributes:
return self.lowering_config.attributes["mma_kind"]
return None

def tilesize_workgroup(self) -> list[int]:
if "workgroup" in self.lowering_config.attributes:
workgroup_attrs = self.lowering_config.attributes["workgroup"]
return [attr.value for attr in workgroup_attrs]
return []

def tilesize_reduction(self) -> list[int]:
if "reduction" in self.lowering_config.attributes:
reduction_attrs = self.lowering_config.attributes["reduction"]
return [attr.value for attr in reduction_attrs]
return []

def subgroup_m_count(self) -> Optional[int]:
if "subgroup_m_count" in self.lowering_config.attributes:
attr = self.lowering_config.attributes["subgroup_m_count"]
return attr.value
return None

def subgroup_n_count(self) -> Optional[int]:
if "subgroup_n_count" in self.lowering_config.attributes:
attr = self.lowering_config.attributes["subgroup_n_count"]
return attr.value
return None

def get_intrinsic(config: Configuration) -> Optional[iree_gpu.MMAAttr]:
if "mma_kind" in config.lowering_config.attributes:
return config.lowering_config.attributes["mma_kind"]
return None


def get_tilesize_workgroup(config: Configuration) -> list[int]:
if "workgroup" in config.lowering_config.attributes:
workgroup_attrs = config.lowering_config.attributes["workgroup"]
return [attr.value for attr in workgroup_attrs]
return []


def get_tilesize_reduction(config: Configuration) -> list[int]:
if "reduction" in config.lowering_config.attributes:
reduction_attrs = config.lowering_config.attributes["reduction"]
return [attr.value for attr in reduction_attrs]
return []


def get_subgroup_m_count(config: Configuration) -> Optional[int]:
if "subgroup_m_count" in config.lowering_config.attributes:
attr = config.lowering_config.attributes["subgroup_m_count"]
return attr.value
return None


def get_subgroup_n_count(config: Configuration) -> Optional[int]:
if "subgroup_n_count" in config.lowering_config.attributes:
attr = config.lowering_config.attributes["subgroup_n_count"]
return attr.value
return None


def get_lowering_config(
Expand All @@ -154,36 +159,31 @@ def get_lowering_config(
) -> iree_gpu.LoweringConfigAttr:
lowering_config_dict: dict[str, Any] = {}
for key, value in kwargs.items():
# A local variable to hold the transformed value.
promoted_value = value
match key:
case "workgroup" | "reduction":
assert isinstance(
value, (list, ir.ArrayAttr)
), f"Unsupported type for key '{key}': {type(value).__name__}"
if isinstance(value, list):
lowering_config_dict[key] = ir.ArrayAttr.get(
promoted_value = ir.ArrayAttr.get(
[tuner_ctx.type.getI64(x) for x in value]
)
elif isinstance(value, ir.ArrayAttr):
lowering_config_dict[key] = value
else:
raise TypeError(
f"Unsupported type for key '{key}': {type(value).__name__}"
)
case "subgroup_m_count" | "subgroup_n_count":
assert isinstance(
value, (int, tuner_ctx.type.i64)
), f"Unsupported type for key '{key}': {type(value).__name__}"
if isinstance(value, int):
lowering_config_dict[key] = tuner_ctx.type.getI64(value)
elif isinstance(value, tuner_ctx.type.i64):
lowering_config_dict[key] = value
else:
raise TypeError(
f"Unsupported type for key '{key}': {type(value).__name__}"
)
promoted_value = tuner_ctx.type.getI64(value)
case "mma_kind":
if isinstance(value, iree_gpu.MMAAttr):
lowering_config_dict[key] = value
else:
raise TypeError(
f"Unsupported type for key '{key}': {type(value).__name__}"
)
assert isinstance(
value, iree_gpu.MMAAttr
), f"Unsupported type for key '{key}': {type(value).__name__}"
case _:
raise KeyError(f"Unhandled key in lowering configuration: {key}")
# Single assignment after the match.
lowering_config_dict[key] = promoted_value
lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
return iree_gpu.LoweringConfigAttr.get(lowering_config_attrs)

Expand Down
6 changes: 3 additions & 3 deletions tuner/tuner/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,6 @@ def test_get_lowering_config(tuner_ctx: common.TunerContext) -> None:
waves_per_eu=2,
)

assert config.intrinsic() is None
assert config.subgroup_m_count() == 1
assert config.subgroup_n_count() == 1
assert common.get_intrinsic(config) is None
assert common.get_subgroup_m_count(config) == 1
assert common.get_subgroup_n_count(config) == 1
16 changes: 8 additions & 8 deletions tuner/tuner/dispatch_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@ def parse_tensor_type(tensor_type: str) -> ShapedType:


def get_mmt_workgroup_sizes(configuration: Configuration):
return configuration.tilesize_workgroup()
return get_tilesize_workgroup(configuration)


def get_mmt_reduction_sizes(configuration: Configuration):
return configuration.tilesize_reduction()
return get_tilesize_reduction(configuration)


def get_contract_workgroup_sizes(
configuration: Configuration, tile_dims: str
) -> list[int]:
m, n, _ = configuration.tilesize_workgroup()
m, n, _k = get_tilesize_workgroup(configuration)

workgroup_size = [1] * len(tile_dims)
for idx, dim in enumerate(tile_dims):
Expand All @@ -48,7 +48,7 @@ def get_contract_workgroup_sizes(
def get_contract_reduction_sizes(
configuration: Configuration, tile_dims: str
) -> list[int]:
_, _, k = configuration.tilesize_reduction()
_m, _n, k = get_tilesize_reduction(configuration)
reduction_size = [0] * len(tile_dims)
for idx, dim in enumerate(tile_dims):
if dim == "k":
Expand All @@ -58,11 +58,11 @@ def get_contract_reduction_sizes(


def get_batch_mmt_workgroup_sizes(configuration: Configuration) -> list[int]:
return [1] + configuration.tilesize_workgroup()
return [1] + get_tilesize_workgroup(configuration)


def get_batch_mmt_reduction_sizes(configuration: Configuration) -> list[int]:
return [0] + configuration.tilesize_reduction()
return [0] + get_tilesize_reduction(configuration)


class MlirRegex(Enum):
Expand Down Expand Up @@ -171,12 +171,12 @@ def get_conv_workgroup_sizes(self, configuration: Configuration) -> list[int]:

oh = 1

ow, oc, _ic = configuration.tilesize_workgroup()
ow, oc, _ic = get_tilesize_workgroup(configuration)

return [batch, oh, ow, oc, fh, fw, 0]

def get_conv_reduction_sizes(self, configuration: Configuration) -> list[int]:
_ow, _oc, ic = configuration.tilesize_reduction()
_ow, _oc, ic = get_tilesize_reduction(configuration)

return [0, 0, 0, 0, 0, 0, ic]

Expand Down

0 comments on commit 2f1a638

Please sign in to comment.