
[tuner]: use lowering config binding #629

Merged · 9 commits · Dec 3, 2024
133 changes: 85 additions & 48 deletions tuner/tuner/candidate_gen.py
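In short, this PR moves the tuner's candidate generation off the old flattened `tile_sizes = [[...]]` lowering config and onto separate `workgroup`/`reduction` tile-size bindings, and reads `intrinsic`, `subgroup_m_count`, and `subgroup_n_count` through accessor helpers rather than directly off `Configuration`. A rough before/after sketch of the attribute text being rewritten, with illustrative sizes (the concrete values come from each tuner's configuration):

```python
# Before this PR, templates carried one flattened tile-size list:
before = "lowering_config = #iree_codegen.lowering_config<tile_sizes = [[128, 256, 64]]>"

# After this PR, parallel and reduction tile sizes are bound separately
# (these are the forms the new expr2/expr3 regexes below match):
after = "... workgroup = [128, 256, 0], reduction = [0, 0, 64] ..."
```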
@@ -38,38 +38,48 @@

tune_logger = logging.getLogger("tune")


# TODO: remove the argument 'workgroup_sizes' and 'reduction_sizes'.
def apply_configuration(
template: list[str], configuration: Configuration, tile_sizes: list[int]
template: list[str],
configuration: Configuration,
workgroup_sizes: list[int],
reduction_sizes: list[int],
) -> str:
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)
tune_logger.info(f"Applying: {configuration}")
expr0 = re.compile(
r"<intrinsic = #iree_gpu\.mma_layout<(.+)>, subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>"
)
expr1 = re.compile(
r"LLVMGPUVectorDistribute workgroup_size = \[.+\] subgroup_size = ([0-9]+),"
)
expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]")
expr3 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>")
expr4 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"")
repl0 = f"<intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>"
expr2 = re.compile(r"workgroup = \[([0-9]+)(, ([0-9]+))+\]")
expr3 = re.compile(r"reduction = \[([0-9]+)(, ([0-9]+))+\]")
expr4 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>")
expr5 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"")
repl0 = f"<intrinsic = {intrinsic}, subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>"
repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},'
repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]'
repl3 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}"
repl4 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"'
repl2 = f"workgroup = {workgroup_sizes}"
repl3 = f"reduction = {reduction_sizes}"
repl4 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}"
repl5 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"'

new_mlir = ""
for line in template:
if "intrinsic =" in line:
line = re.sub(expr0, repl0, line)
if "LLVMGPUVectorDistribute " in line:
line = re.sub(expr1, repl1, line)
if "tile_sizes" in line:
if "workgroup" in line:
line = re.sub(expr2, repl2, line)
if "gpu_pipeline_options =" in line:
if "reduction" in line:
line = re.sub(expr3, repl3, line)
if "amdgpu-waves-per-eu" in line:
if "gpu_pipeline_options =" in line:
line = re.sub(expr4, repl4, line)
if "amdgpu-waves-per-eu" in line:
line = re.sub(expr5, repl5, line)
new_mlir += line

return new_mlir
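For a sanity check on the new substitution logic, here is a minimal self-contained sketch, assuming a template line that already carries the new bindings (sizes made up):

```python
import re

# Patterns copied from apply_configuration above.
expr_workgroup = re.compile(r"workgroup = \[([0-9]+)(, ([0-9]+))+\]")
expr_reduction = re.compile(r"reduction = \[([0-9]+)(, ([0-9]+))+\]")

# repl2/repl3 above interpolate Python lists, and str([128, 256, 0])
# yields exactly this bracketed form.
line = "lowering_config<workgroup = [64, 128, 0], reduction = [0, 0, 32]>"
line = expr_workgroup.sub(f"workgroup = {[128, 256, 0]}", line)
line = expr_reduction.sub(f"reduction = {[0, 0, 64]}", line)
print(line)  # lowering_config<workgroup = [128, 256, 0], reduction = [0, 0, 64]>
```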
@@ -115,7 +125,9 @@ class MmtTuner(DispatchTuner, MmtParser):
def get_transform_function_mmt(
self, problem_size: ProblemSize, functionName: str, configuration: Configuration
) -> str:
tile_sizes = ", ".join(map(str, get_mmt_tile_sizes(configuration)))
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
@@ -127,12 +139,12 @@ def get_transform_function_mmt(
transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value
transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
lowering_config = {configuration.lowering_config}>,
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
intrinsic = {intrinsic},
subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
transform.yield %matmul, %config : !transform.any_op, !transform.any_param
@@ -153,7 +165,10 @@ def apply_params(
"// ",
)
modified += apply_configuration(
template, configuration, get_mmt_tile_sizes(configuration)
template,
configuration,
get_mmt_workgroup_sizes(configuration),
get_mmt_reduction_sizes(configuration),
)
embeddable = indent(
self.get_transform_function_mmt(problem_size, f"match_op", configuration),
@@ -163,13 +178,6 @@


class ConvTuner(DispatchTuner, ConvParser):
# int64_t n = outputShape[0];
# int64_t oh = outputShape[1];
# int64_t ow = outputShape[2];
# int64_t oc = outputShape[3];
# int64_t fh = filterShape[0];
# int64_t fw = filterShape[1];
# int64_t ic = filterShape[2];
def get_transform_function_conv(
self, problem_size: ProblemSize, functionName: str, configuration: Configuration
) -> str:
@@ -185,7 +193,15 @@ def get_transform_function_conv(
filter = f"tensor<{problem_size.rhs_type}>"
output = f"tensor<{dynamic_batch_output_ty}>"

tile_sizes = ", ".join(map(str, self.get_conv_tile_sizes(configuration)))
workgroup_sizes = ", ".join(
map(str, self.get_conv_workgroup_sizes(configuration))
)
reduction_sizes = ", ".join(
map(str, self.get_conv_reduction_sizes(configuration))
)
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
@@ -200,12 +216,12 @@
outs(%out : {output}) -> {output}
}} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
lowering_config = {configuration.lowering_config}>,
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
intrinsic = {intrinsic},
subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
transform.yield %conv, %config : !transform.any_op, !transform.any_param
@@ -228,7 +244,10 @@ def apply_params(
"// ",
)
modified += apply_configuration(
template, configuration, self.get_conv_tile_sizes(configuration)
template,
configuration,
self.get_conv_workgroup_sizes(configuration),
self.get_conv_reduction_sizes(configuration),
)
embeddable = indent(
self.get_transform_function_conv(problem_size, f"match_op", configuration),
@@ -244,7 +263,15 @@ def get_transform_function_broadcast_rhs_mmt(
functionName: str,
configuration: Configuration,
) -> str:
tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration)))
workgroup_sizes = ", ".join(
map(str, get_batch_mmt_workgroup_sizes(configuration))
)
reduction_sizes = ", ".join(
map(str, get_batch_mmt_reduction_sizes(configuration))
)
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
@@ -261,12 +288,12 @@ def get_transform_function_broadcast_rhs_mmt(
transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : !transform.any_value
transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
lowering_config = {configuration.lowering_config}>,
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
intrinsic = {intrinsic},
subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
transform.yield %generic, %config : !transform.any_op, !transform.any_param
@@ -287,7 +314,10 @@ def apply_params_broadcast_rhs_mmt(
"// ",
)
modified += apply_configuration(
template, configuration, get_batch_mmt_tile_sizes(configuration)
template,
configuration,
get_batch_mmt_workgroup_sizes(configuration),
get_batch_mmt_reduction_sizes(configuration),
)

embeddable = indent(
@@ -315,7 +345,8 @@ def apply_params(
apply_configuration(
template,
configuration,
get_contract_tile_sizes(configuration, self.tile_dims),
get_contract_workgroup_sizes(configuration, self.tile_dims),
get_contract_reduction_sizes(configuration, self.tile_dims),
),
"",
)
@@ -328,7 +359,9 @@ def get_transform_function_batch_mmt(
functionName: str,
configuration: Configuration,
) -> str:
tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration)))
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
@@ -341,12 +374,12 @@
transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value
transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
lowering_config = {configuration.lowering_config}>,
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
intrinsic = {intrinsic},
subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
transform.yield %generic, %config : !transform.any_op, !transform.any_param
@@ -368,7 +401,10 @@ def apply_params(
"// ",
)
modified += apply_configuration(
template, configuration, get_batch_mmt_tile_sizes(configuration)
template,
configuration,
get_batch_mmt_workgroup_sizes(configuration),
get_batch_mmt_reduction_sizes(configuration),
)

embeddable = indent(
@@ -392,9 +428,9 @@ def get_transform_function_batch_matmul(
input1 = f"tensor<{problem_size.rhs_type}>"
output = f"tensor<{problem_size.res_type}>"

tile_sizes = ", ".join(
map(str, get_contract_tile_sizes(configuration, tile_dims))
)
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
@@ -409,12 +445,12 @@
outs(%out : {output}) -> {output}
}} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
lowering_config = {configuration.lowering_config}>,
translation_info = #iree_codegen.translation_info<LLVMGPUPadAndVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
intrinsic = {intrinsic},
subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param
@@ -440,7 +476,8 @@ def apply_params(
modified += apply_configuration(
template,
configuration,
get_contract_tile_sizes(configuration, self.tile_dims),
get_contract_workgroup_sizes(configuration, self.tile_dims),
get_contract_reduction_sizes(configuration, self.tile_dims),
)

embeddable = indent(
@@ -548,7 +585,7 @@ def tune(
tune_logger.debug(str(problem_size))
configs = []
for i, config in enumerate(
generate_solutions(tune_logger, problem_size, num_subgroups, mma_list)
generate_solutions(tuner_context, problem_size, num_subgroups, mma_list)
):
if i >= limit:
break
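The final hunk is a call-site update in `tune()`: `generate_solutions` now receives the tuner context instead of the bare module-level logger. A hypothetical sketch of the shape of that change (the real `TunerContext` and `generate_solutions` live elsewhere in the tuner package; the `logger` field is an assumption):

```python
import logging
from dataclasses import dataclass

@dataclass
class TunerContext:  # hypothetical stand-in for the tuner's real context type
    logger: logging.Logger  # assumed field that generate_solutions logs through

tuner_context = TunerContext(logger=logging.getLogger("tune"))

# Old call (removed):  generate_solutions(tune_logger, ...)
# New call (added):    generate_solutions(tuner_context, ...)
```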