From eeb53a950c670adfbe48628c9084f9e1926b7de0 Mon Sep 17 00:00:00 2001
From: Amily Wu
Date: Thu, 22 Aug 2024 12:57:05 -0500
Subject: [PATCH 01/23] Add candidate_gen.py

---
 sharktank/sharktank/tools/candidate_gen.py | 1312 ++++++++++++++++++++
 1 file changed, 1312 insertions(+)
 create mode 100755 sharktank/sharktank/tools/candidate_gen.py

diff --git a/sharktank/sharktank/tools/candidate_gen.py b/sharktank/sharktank/tools/candidate_gen.py
new file mode 100755
index 000000000..02f8eb9d9
--- /dev/null
+++ b/sharktank/sharktank/tools/candidate_gen.py
@@ -0,0 +1,1312 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Given an input dispatch, this code modifies the hyperparameters
+# in the code and runs it.
+
+import argparse
+import logging
+import math
+import pickle
+import re
+import z3
+from dataclasses import asdict, dataclass
+from enum import Enum
+from os import mkdir, path, makedirs
+from typing import Callable
+from textwrap import indent
+
+import iree.compiler as ireec
+from iree.compiler import ir
+from iree.compiler.dialects import _linalg_ops_gen, _util_ops_gen
+
+"""
+Usage: ./candidate_gen.py 121.mlir -o "tuning/candidates" -l 1024 --lhs-dims=mk --rhs-dims=nk --tile-dims=mnk
+"""
+
+tune_logger = logging.getLogger("tune")
+
+
+class DispatchKind(Enum):
+    conv = 1
+    mmt = 2
+    contraction = 3
+    batch_mmt = 4
+    batch_matmul = 5
+    broadcast_rhs_mmt = 6
+
+
+@dataclass
+class OpWalkResult:
+    was_interrupted: bool = False
+    dispatch_kind: DispatchKind | None = None
+
+
+class ElementType(Enum):
+    i8 = 1
+    i32 = 2
+    f8 = 3
+    f16 = 4
+    f32 = 5
+
+    @property
+    def bitwidth(self) -> int:
+        match self:
+            case ElementType.i8 | ElementType.f8:
+                return 8
+            case ElementType.f16:
+                return 16
+            case ElementType.i32 | ElementType.f32:
+                return 32
+            case _:
+                assert False, "unhandled case"
+
+    def __str__(self) -> str:
+        return self.name
+
+
+@dataclass
+class ShapedType:
+    shape: list[int]
+    element_type: ElementType
+
+    def rank(self) -> int:
+        return len(self.shape)
+
+    @property
+    def bitwidth(self) -> int:
+        return self.element_type.bitwidth
+
+    def __str__(self) -> str:
+        dim_to_str = lambda dim: str(dim) if dim != -1 else "?"
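+        # Dynamic dims are stored as -1 and printed as '?' to match MLIR tensor syntax, e.g. "?x2x3xf16".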
+        return "x".join(map(dim_to_str, self.shape)) + "x" + str(self.element_type)
+
+
+@dataclass
+class MatmulSize:
+    M: int
+    N: int
+    K: int
+    B: int = 1
+
+
+@dataclass
+class ProblemSize:
+    matmul_size: MatmulSize
+    lhs_type: ShapedType
+    rhs_type: ShapedType
+    res_type: ShapedType
+    dispatch_kind: DispatchKind
+
+    @property
+    def MNK(self) -> tuple[int, int, int]:
+        return (self.matmul_size.M, self.matmul_size.N, self.matmul_size.K)
+
+
+@dataclass
+class MfmaIntrinsic:
+    input_type: ElementType
+    m: int
+    n: int
+    k: int
+    output_type: ElementType
+
+    def __str__(self) -> str:
+        input = str(self.input_type).upper()
+        output = str(self.output_type).upper()
+        return f"MFMA_{input}_{self.m}x{self.n}x{self.k}_{output}"
+
+    @staticmethod
+    def mfma_f16_16x16x16_f32():
+        return MfmaIntrinsic(ElementType.f16, 16, 16, 16, ElementType.f32)
+
+    @staticmethod
+    def mfma_f16_32x32x8_f32():
+        return MfmaIntrinsic(ElementType.f16, 32, 32, 8, ElementType.f32)
+
+    @staticmethod
+    def mfma_i8_16x16x32_i32():
+        return MfmaIntrinsic(ElementType.i8, 16, 16, 32, ElementType.i32)
+
+    @staticmethod
+    def mfma_i8_32x32x16_i32():
+        return MfmaIntrinsic(ElementType.i8, 32, 32, 16, ElementType.i32)
+
+    @staticmethod
+    def all():
+        return [
+            MfmaIntrinsic.mfma_f16_16x16x16_f32(),
+            MfmaIntrinsic.mfma_f16_32x32x8_f32(),
+            MfmaIntrinsic.mfma_i8_16x16x32_i32(),
+            MfmaIntrinsic.mfma_i8_32x32x16_i32(),
+        ]
+
+
+@dataclass
+class Configuration:
+    subgroup_size: int
+    workgroup_size: list[int]
+    intrinsic: MfmaIntrinsic
+    tile_sizes: list[int]
+    subgroup_m_count: int
+    subgroup_n_count: int
+    waves_per_eu: int
+
+
+class MlirRegex(str, Enum):
+    ssa_value = r"%[a-zA-Z0-9-_]+"
+    tensor_type = r"tensor<(([0-9]+x)+((f|i)[0-9]+))>"
+
+    @staticmethod
+    def dps_ins_two_args() -> str:
+        return rf"ins\({MlirRegex.ssa_value}, {MlirRegex.ssa_value} : (?P<LHS>{MlirRegex.tensor_type}), (?P<RHS>{MlirRegex.tensor_type})\)"
+
+    @staticmethod
+    def dps_outs_one_arg() -> str:
+        return rf"outs\({MlirRegex.ssa_value} : (?P<RES>{MlirRegex.tensor_type})\)"
+
+
+def read_input_mlir(filename: str) -> list[str]:
+    with open(filename, "r") as f:
+        return f.readlines()
+
+
+def get_mmt_tile_sizes(configuration: Configuration):
+    return configuration.tile_sizes
+
+
+@dataclass
+class ConvDimInfo:
+    n: int
+    oh: int
+    ow: int
+    oc: int
+    fh: int
+    fw: int
+    ic: int
+
+    @staticmethod
+    def from_rhs_res(rhs_shaped_type: ShapedType, res_shaped_type: ShapedType):
+        n, oh, ow, oc = res_shaped_type.shape
+        fh, fw, ic, _ = rhs_shaped_type.shape
+        return ConvDimInfo(n, oh, ow, oc, fh, fw, ic)
+
+    @staticmethod
+    def from_problem_size(problem_size: ProblemSize):
+        return ConvDimInfo.from_rhs_res(problem_size.rhs_type, problem_size.res_type)
+
+
+def get_conv_tile_sizes(configuration: Configuration) -> list[int]:
+    m, n, k = configuration.tile_sizes
+    batch = 1
+    fh = 1
+    fw = 1
+
+    oh = 1
+
+    oc = n
+    ow = m
+    ic = k
+    return [batch, oh, ow, oc, fh, fw, ic]
+
+
+def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> list[int]:
+    m, n, k = configuration.tile_sizes
+    tile_size = [1] * len(tile_dims)
+    for idx, dim in enumerate(tile_dims):
+        if dim == "m":
+            tile_size[idx] = m
+        if dim == "n":
+            tile_size[idx] = n
+        if dim == "k":
+            tile_size[idx] = k
+    return tile_size
+
+
+def get_batch_mmt_tile_sizes(configuration: Configuration) -> list[int]:
+    return [1] + configuration.tile_sizes
+
+
+def get_pipeline_config(configuration: Configuration) -> str:
+    extra_config = ", prefetch_shared_memory"
+    if configuration.waves_per_eu != 2:
+        extra_config += f', llvm_func_attrs = {{"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"}}'
+    return extra_config
+
+
+def get_transform_function_mmt(
+    problem_size: ProblemSize, functionName: str, configuration: Configuration
+) -> str:
+    tile_sizes = ", ".join(map(str, get_mmt_tile_sizes(configuration)))
+
+    wg_x, wg_y, wg_z = configuration.workgroup_size
+    extra_config = get_pipeline_config(configuration)
+
+    return f"""
+transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
+  %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
+  %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
+  %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
+  transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value
+  transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
+  %config = transform.param.constant #iree_codegen.compilation_info<
+    lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
+    translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+      workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
+      {{mma_schedule = #iree_gpu.mma_schedule<
+        intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
+        subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
+      {extra_config}}}>
+  > -> !transform.any_param
+  transform.yield %matmul, %config : !transform.any_op, !transform.any_param
+}}
+"""
+
+
+# int64_t n = outputShape[0];
+# int64_t oh = outputShape[1];
+# int64_t ow = outputShape[2];
+# int64_t oc = outputShape[3];
+# int64_t fh = filterShape[0];
+# int64_t fw = filterShape[1];
+# int64_t ic = filterShape[2];
+def get_transform_function_conv(
+    problem_size: ProblemSize, functionName: str, configuration: Configuration
+) -> str:
+    dynamic_batch_input_ty = problem_size.lhs_type
+    dynamic_batch_input_ty.shape = dynamic_batch_input_ty.shape.copy()
+    dynamic_batch_input_ty.shape[0] = -1
+
+    dynamic_batch_output_ty = problem_size.res_type
+    dynamic_batch_output_ty.shape = dynamic_batch_output_ty.shape.copy()
+    dynamic_batch_output_ty.shape[0] = -1
+
+    input = f"tensor<{dynamic_batch_input_ty}>"
+    filter = f"tensor<{problem_size.rhs_type}>"
+    output = f"tensor<{dynamic_batch_output_ty}>"
+
+    tile_sizes = ", ".join(map(str, get_conv_tile_sizes(configuration)))
+
+    wg_x, wg_y, wg_z = configuration.workgroup_size
+    extra_config = get_pipeline_config(configuration)
+
+    return f"""
+transform.named_sequence @{functionName}(%conv: !transform.any_op {{transform.readonly}})
+  -> (!transform.any_op, !transform.any_param) {{
+  %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv {{
+  ^bb0(%lhs: {input}, %rhs: {filter}, %out: {output}):
+    %13 = linalg.conv_2d_nhwc_hwcf {{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}}
+      ins(%lhs, %rhs : {input}, {filter})
+      outs(%out : {output}) -> {output}
+  }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+  %config = transform.param.constant #iree_codegen.compilation_info<
+    lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
+    translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+      workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
+      {{mma_schedule = #iree_gpu.mma_schedule<
+        intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
+        subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
+      {extra_config}}}>
+  > -> !transform.any_param
+  transform.yield %conv, %config : !transform.any_op, !transform.any_param
+}}
+"""
+
+
+def get_transform_function_batch_matmul(
+    problem_size: ProblemSize,
+    tile_dims: str,
+    functionName: str,
+    configuration: Configuration,
+) -> str:
+    input0 = f"tensor<{problem_size.lhs_type}>"
+    input1 = f"tensor<{problem_size.rhs_type}>"
+    output = f"tensor<{problem_size.res_type}>"
+
+    tile_sizes = ", ".join(map(str, get_contract_tile_sizes(configuration, tile_dims)))
+
+    wg_x, wg_y, wg_z = configuration.workgroup_size
+    extra_config = get_pipeline_config(configuration)
+
+    return f"""
+transform.named_sequence @{functionName}(%batch_matmul: !transform.any_op {{transform.readonly}})
+  -> (!transform.any_op, !transform.any_param) {{
+  %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %batch_matmul {{
+  ^bb0(%lhs: {input0}, %rhs: {input1}, %out: {output}):
+    %13 = linalg.batch_matmul
+      ins(%lhs, %rhs : {input0}, {input1})
+      outs(%out : {output}) -> {output}
+  }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+  %config = transform.param.constant #iree_codegen.compilation_info<
+    lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
+    translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+      workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
+      {{mma_schedule = #iree_gpu.mma_schedule<
+        intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
+        subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
+      {extra_config}}}>
+  > -> !transform.any_param
+  transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param
+}}
+"""
+
+
+def get_transform_function_batch_mmt(
+    problem_size: ProblemSize,
+    functionName: str,
+    configuration: Configuration,
+) -> str:
+    tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration)))
+
+    wg_x, wg_y, wg_z = configuration.workgroup_size
+    extra_config = get_pipeline_config(configuration)
+
+    return f"""
+transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
+  %mmt = transform.include @match_batch_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op
+  %lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value
+  %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value
+  transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value
+  transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
+  %config = transform.param.constant #iree_codegen.compilation_info<
+    lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
+    translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+      workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
+      {{mma_schedule = #iree_gpu.mma_schedule<
+        intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
+        subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
+      {extra_config}}}>
+  > -> !transform.any_param
+  transform.yield %generic, %config : !transform.any_op, !transform.any_param
+}}
+"""
+
+
+def get_transform_function_broadcast_rhs_mmt(
+    problem_size: ProblemSize,
+    functionName: str,
+    configuration: Configuration,
+) -> str:
+    tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration)))
+
+    wg_x, wg_y, wg_z = configuration.workgroup_size
+    extra_config = get_pipeline_config(configuration)
+
+    lhs_dynamic_batch = problem_size.lhs_type
+    lhs_dynamic_batch.shape = lhs_dynamic_batch.shape.copy()
+    lhs_dynamic_batch.shape[0] = -1
+
+    return f"""
+transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
+  %mmt = transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op
+  %lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value
+  %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value
+  transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : !transform.any_value
+  transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
+  %config = transform.param.constant #iree_codegen.compilation_info<
+    lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
+    translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+      workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
+      {{mma_schedule = #iree_gpu.mma_schedule<
+        intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
+        subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
+      {extra_config}}}>
+  > -> !transform.any_param
+  transform.yield %generic, %config : !transform.any_op, !transform.any_param
+}}
+"""
+
+
+def apply_configuration(
+    template: list[str], configuration: Configuration, tile_sizes: list[int]
+) -> str:
+    tune_logger.info(f"Applying: {configuration}")
+    expr0 = re.compile(
+        r"<intrinsic = #iree_gpu.mma_layout<(.+)>, subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>"
+    )
+    expr1 = re.compile(
+        r"LLVMGPUVectorDistribute workgroup_size = \[.+\] subgroup_size = ([0-9]+),"
+    )
+    expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]")
+    expr3 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"")
+    repl0 = f"<intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>"
+    repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},'
+    repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]'
+    repl3 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"'
+
+    new_mlir = ""
+    for line in template:
+        if "intrinsic =" in line:
+            line = re.sub(expr0, repl0, line)
+        if "LLVMGPUVectorDistribute " in line:
+            line = re.sub(expr1, repl1, line)
+        if "tile_sizes" in line:
+            line = re.sub(expr2, repl2, line)
+        if "amdgpu-waves-per-eu" in line:
+            line = re.sub(expr3, repl3, line)
+        new_mlir += line
+
+    return new_mlir
+
+
+def apply_params_mmt(
+    problem_size: ProblemSize, template: list[str], configuration: Configuration
+) -> tuple[str, str]:
+    M, N, K = problem_size.MNK
+    modified = indent(
+        get_transform_function_mmt(
+            problem_size, f"match_mmt_{M}x{N}x{K}", configuration
+        ),
+        "// ",
+    )
+    modified += apply_configuration(
+        template, configuration, get_mmt_tile_sizes(configuration)
+    )
+    embeddable = indent(
+        get_transform_function_mmt(problem_size, f"match_op", configuration), "  "
+    )
+    return modified, embeddable
+
+
+def apply_params_conv(
+    problem_size: ProblemSize, template: list[str], configuration: Configuration
+) -> tuple[str, str]:
+    conv_dims = ConvDimInfo.from_problem_size(problem_size)
+    modified = indent(
+        get_transform_function_conv(
+            problem_size,
+            f"match_conv_2d_nhwc_hwcf_Bx{conv_dims.oh}x{conv_dims.ow}x{conv_dims.oc}x{conv_dims.fh}x{conv_dims.fw}x{conv_dims.ic}",
+            configuration,
+        ),
+        "// ",
+    )
+    modified += apply_configuration(
+        template, configuration, get_conv_tile_sizes(configuration)
+    )
+    embeddable = indent(
+        get_transform_function_conv(problem_size, f"match_op", configuration),
+        "  ",
+    )
+    return modified, embeddable
+
+
+def apply_params_contract(
+    problem_size: ProblemSize,
+    tile_dims: str,
+    template: list[str],
+    configuration: Configuration,
+) -> tuple[str, str]:
+    # TODO: Generate transform function.
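+    # Until then, only the config attributes in the template are rewritten in
+    # place, and the embeddable spec (second tuple element) is returned empty.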
+ return ( + apply_configuration( + template, configuration, get_contract_tile_sizes(configuration, tile_dims) + ), + "", + ) + + +def apply_params_batch_matmul( + problem_size: ProblemSize, + tile_dims: str, + template: list[str], + configuration: Configuration, +) -> tuple[str, str]: + tune_logger.info(f"{configuration}") + M, N, K = problem_size.MNK + modified = indent( + get_transform_function_batch_matmul( + problem_size, + tile_dims, + f"match_batch_matmul_{problem_size.matmul_size.B}x{M}x{N}x{K}", + configuration, + ), + "// ", + ) + modified += apply_configuration( + template, configuration, get_contract_tile_sizes(configuration, tile_dims) + ) + + embeddable = indent( + get_transform_function_batch_matmul( + problem_size, tile_dims, f"match_op", configuration + ), + " ", + ) + return modified, embeddable + + +def apply_params_batch_mmt( + problem_size: ProblemSize, template: list[str], configuration: Configuration +) -> tuple[str, str]: + M, N, K = problem_size.MNK + B = problem_size.matmul_size.B + modified = indent( + get_transform_function_batch_mmt( + problem_size, f"match_batch_mmt_{B}x{M}x{N}x{K}", configuration + ), + "// ", + ) + modified += apply_configuration( + template, configuration, get_batch_mmt_tile_sizes(configuration) + ) + + embeddable = indent( + get_transform_function_batch_mmt(problem_size, f"match_op", configuration), + " ", + ) + return modified, embeddable + + +def apply_params_broadcast_rhs_mmt( + problem_size: ProblemSize, template: list[str], configuration: Configuration +) -> tuple[str, str]: + M, N, K = problem_size.MNK + modified = indent( + get_transform_function_broadcast_rhs_mmt( + problem_size, f"match_broadcast_rhs_mmt_Bx{M}x{N}x{K}", configuration + ), + "// ", + ) + modified += apply_configuration( + template, configuration, get_batch_mmt_tile_sizes(configuration) + ) + + embeddable = indent( + get_transform_function_broadcast_rhs_mmt( + problem_size, f"match_op", configuration + ), + " ", + ) + return modified, embeddable + + +def parse_tensor_type(tensor_type: str) -> ShapedType: + shape_match = re.search(MlirRegex.tensor_type, tensor_type) + assert shape_match + + shape_str = shape_match.group(1) + dims_and_elem = shape_str.split("x") + dims = [int(x) for x in dims_and_elem[:-1]] + elem = dims_and_elem[-1] + str_to_elem_ty = {x.name: x for x in ElementType} + return ShapedType(dims, str_to_elem_ty[elem]) + + +def get_shapes_mmt(template: list[str]) -> ProblemSize: + for line in template: + if "linalg.generic" not in line: + continue + if r'iterator_types = ["parallel", "parallel", "reduction"]' not in line: + continue + # ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) + mmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + dps = re.search(mmt_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 2 + lhs_M, lhs_K = lhs_shaped_type.shape + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 2 + rhs_N, rhs_K = rhs_shaped_type.shape + + assert lhs_shaped_type.element_type == rhs_shaped_type.element_type + assert lhs_K == rhs_K + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 2 + res_M, res_N = res_shaped_type.shape + + assert lhs_M == res_M + assert rhs_N == res_N + + matmul_size = MatmulSize( + 
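+            # mmt shapes: lhs is MxK and rhs is NxK (transposed), so M and N come from the row dims and K is shared.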
lhs_shaped_type.shape[0], rhs_shaped_type.shape[0], lhs_shaped_type.shape[1] + ) + return ProblemSize( + matmul_size, + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.mmt, + ) + + assert False, "Shape not found" + + +def get_shapes_conv(template: list[str]) -> ProblemSize: + for line in template: + if "linalg.conv_2d_nhwc_hwcf" not in line: + continue + + # ins(%19, %20 : tensor<2x34x34x1280xf16>, tensor<3x3x1280x1280xf16>) outs (%27 : tensor<2x32x32x1280xf32>) + conv_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + dps = re.search(conv_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 4 + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 4 + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 4 + + # int64_t n = outputShape[0]; + # int64_t oh = outputShape[1]; + # int64_t ow = outputShape[2]; + # int64_t oc = outputShape[3]; + # int64_t fh = filterShape[0]; + # int64_t fw = filterShape[1]; + # int64_t ic = filterShape[2]; + dim_info = ConvDimInfo.from_rhs_res(rhs_shaped_type, res_shaped_type) + return ProblemSize( + MatmulSize( + M=dim_info.oh * dim_info.ow, + N=dim_info.oc, + K=dim_info.fh * dim_info.fw * dim_info.ic, + B=dim_info.n, + ), + lhs_shaped_type, + rhs_shaped_type, + res_shaped_type, + DispatchKind.conv, + ) + + assert False, "Shape not found" + + +def get_shapes_contract( + template: list[str], lhs_dims: str, rhs_dims: str +) -> ProblemSize: + for line in template: + if "linalg.generic" not in line: + continue + if "lowering_config =" not in line: + continue + if '"reduction"' not in line: + continue + + # ins(%7, %8 : tensor<2x1024x1280xf16>, tensor<20x64x1280xf16>) + cont_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + dps = re.search(cont_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == len(lhs_dims) + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == len(rhs_dims) + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() >= 2 + + M = math.prod( + val if dim == "m" else 1 + for dim, val in zip(lhs_dims, lhs_shaped_type.shape) + ) + N = math.prod( + val if dim == "n" else 1 + for dim, val in zip(rhs_dims, rhs_shaped_type.shape) + ) + K0 = math.prod( + val if dim == "k" else 1 + for dim, val in zip(lhs_dims, lhs_shaped_type.shape) + ) + K1 = math.prod( + val if dim == "k" else 1 + for dim, val in zip(rhs_dims, rhs_shaped_type.shape) + ) + assert K0 == K1 + + return ProblemSize( + MatmulSize(M, N, K0), + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.contraction, + ) + + assert False, "Shape not found" + + +def get_shapes_batch_matmul( + template: list[str], lhs_dims: str, rhs_dims: str +) -> ProblemSize: + for line in template: + if "linalg.batch_matmul" not in line: + continue + # ins(%9, %10 : tensor<64x72x1280xf16>, tensor<64x1280x1280xf16>) + # outs(%12 : tensor<64x72x1280xf32>) + cont_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + dps = 
re.search(cont_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == len(lhs_dims) + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == len(rhs_dims) + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == lhs_shaped_type.rank() + + LHS = lhs_shaped_type.shape + RHS = rhs_shaped_type.shape + RES = res_shaped_type.shape + + B = math.prod(val if dim == "b" else 1 for dim, val in zip(lhs_dims, LHS)) + B0 = math.prod(val if dim == "b" else 1 for dim, val in zip(lhs_dims, RHS)) + B1 = math.prod(val if dim == "b" else 1 for dim, val in zip(lhs_dims, RES)) + M = math.prod(val if dim == "m" else 1 for dim, val in zip(lhs_dims, LHS)) + N = math.prod(val if dim == "n" else 1 for dim, val in zip(rhs_dims, RHS)) + K0 = math.prod(val if dim == "k" else 1 for dim, val in zip(lhs_dims, LHS)) + K1 = math.prod(val if dim == "k" else 1 for dim, val in zip(rhs_dims, RHS)) + assert B == B0 and B == B1 + assert K0 == K1 + + return ProblemSize( + MatmulSize(M, N, K0, B), + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.batch_matmul, + ) + + assert False, "Shape not found" + + +def get_shapes_batch_mmt(template: list[str]) -> ProblemSize: + for line in template: + if "linalg.generic" not in line: + continue + if ( + r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' + not in line + ): + continue + # ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) + bmmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + dps = re.search(bmmt_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 3 + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 3 + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 3 + + B0, M0, K0 = lhs_shaped_type.shape + B1, N1, K1 = rhs_shaped_type.shape + B2, M2, N2 = res_shaped_type.shape + assert B0 == B1 + assert B0 == B2 + assert M0 == M2 + assert N1 == N2 + assert K0 == K1 + return ProblemSize( + MatmulSize(M0, N1, K0, B0), + lhs_shaped_type, + rhs_shaped_type, + res_shaped_type, + DispatchKind.batch_mmt, + ) + + assert False, "Shape not found" + + +def is_broadcast_rhs_mmt_op(line: str) -> bool: + if "linalg.generic" not in line: + return False + if ( + r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' + not in line + ): + return False + if ( + r"indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>" + not in line + ): + return False + return True + + +def is_broadcast_rhs_mmt(template: list[str]) -> bool: + return any(is_broadcast_rhs_mmt_op(line) for line in template) + + +def get_shapes_broadcast_rhs_mmt(template: list[str]) -> ProblemSize: + for line in template: + if not is_broadcast_rhs_mmt_op(line): + continue + + # ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) + bmmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + dps = 
re.search(bmmt_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 3 + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 2 + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 3 + + B0, M0, K0 = lhs_shaped_type.shape + N1, K1 = rhs_shaped_type.shape + B2, M2, N2 = res_shaped_type.shape + assert B0 == B2 + assert M0 == M2 + assert N1 == N2 + assert K0 == K1 + return ProblemSize( + MatmulSize(M0, N1, K0, B0), + lhs_shaped_type, + rhs_shaped_type, + res_shaped_type, + DispatchKind.broadcast_rhs_mmt, + ) + + assert False, "Shape not found" + + +def get_compatible_mfma_intrinsics(problem_size: ProblemSize) -> list[MfmaIntrinsic]: + def is_compatible(intrinsic: MfmaIntrinsic) -> bool: + if problem_size.res_type.element_type != intrinsic.output_type: + return False + if problem_size.dispatch_kind != DispatchKind.batch_matmul: + if problem_size.lhs_type.element_type != intrinsic.input_type: + return False + if problem_size.rhs_type.element_type != intrinsic.input_type: + return False + return True + + return list(filter(is_compatible, MfmaIntrinsic.all())) + + +def get_mfma_intrinsic_constraints( + problem_size: ProblemSize, + intrinsic_m: z3.ArithRef, + intrinsic_n: z3.ArithRef, + intrinsic_k: z3.ArithRef, +) -> z3.BoolRef: + compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size) + assert len(compatible_intrinsics) > 0, "No compatible intrinsics found" + return z3.Or( + *( + z3.And(intrinsic_m == mfma.m, intrinsic_n == mfma.n, intrinsic_k == mfma.k) + for mfma in compatible_intrinsics + ) + ) + + +def get_dispatch_constraints( + problem_size: ProblemSize, + tile_m: z3.ArithRef, + tile_n: z3.ArithRef, + tile_k: z3.ArithRef, +) -> list[z3.BoolRef]: + if problem_size.dispatch_kind != DispatchKind.conv: + return [] + + dim_info = ConvDimInfo.from_problem_size(problem_size) + conv_constraints = [] + # WARNING: This sometimes makes the constraints UNSAT for some reason. 
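+    # Cap each tile size at the corresponding conv extent (ow/oc/ic) so a tile never exceeds the problem.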
+    conv_constraints += [tile_m <= dim_info.ow]
+    conv_constraints += [tile_n <= dim_info.oc]
+    conv_constraints += [tile_k <= dim_info.ic]
+    return conv_constraints
+
+
+def calculate_shared_memory_usage_in_bytes(
+    problem_size: ProblemSize,
+    m: int | z3.ArithRef,
+    n: int | z3.ArithRef,
+    k: int | z3.ArithRef,
+) -> int | z3.ArithRef:
+    lhs_memory = m * k * (problem_size.lhs_type.bitwidth // 8)
+    rhs_memory = k * n * (problem_size.rhs_type.bitwidth // 8)
+    return lhs_memory + rhs_memory
+
+
+def generate_constraints(
+    problem_size: ProblemSize,
+    tile_sizes,
+    num_subgroups,
+    subgroup_size,
+    intrinsic_size,
+    workgroup_size,
+    subgroup_m_count,
+    subgroup_n_count,
+    waves_per_eu,
+):
+    M, N, K = (
+        problem_size.matmul_size.M,
+        problem_size.matmul_size.N,
+        problem_size.matmul_size.K,
+    )
+    m, n, k = tile_sizes
+    intrinsic_mn, intrinsic_k = intrinsic_size
+    wg_x, wg_y, wg_z = workgroup_size
+    wg_threads = z3.Int("wg_threads")
+    constraints = [wg_threads == wg_x * wg_y * wg_z]
+    constraints += [subgroup_size == 64, wg_threads <= 1024]
+    constraints += [
+        get_mfma_intrinsic_constraints(
+            problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k
+        )
+    ]
+    subgroup_k_count = 1
+    constraints += [
+        m >= intrinsic_mn,
+        m <= 512,
+        m <= M,
+    ]
+    constraints += [n >= intrinsic_mn, n <= 512, n <= N, N % n == 0]
+    constraints += [k >= intrinsic_k, k <= 512, k <= K, K % k == 0]
+    for x in (subgroup_m_count, subgroup_n_count):
+        constraints += [x >= 1, x <= 32]
+
+    subgroup_m_tile_count = z3.Int("sg_m_tcnt")
+    subgroup_n_tile_count = z3.Int("sg_n_tcnt")
+    subgroup_k_tile_count = z3.Int("sg_k_tcnt")
+    for x in (subgroup_m_tile_count, subgroup_n_tile_count, subgroup_k_tile_count):
+        constraints += [x >= 1, x <= 32]
+
+    constraints += [m == subgroup_m_count * subgroup_m_tile_count * intrinsic_mn]
+    constraints += [n == subgroup_n_count * subgroup_n_tile_count * intrinsic_mn]
+    constraints += [k == subgroup_k_count * subgroup_k_tile_count * intrinsic_k]
+    constraints += [wg_x == subgroup_size * subgroup_n_count]
+    constraints += [wg_y == subgroup_m_count]
+    constraints += [wg_z == subgroup_k_count]
+    constraints += [z3.Or(wg_x <= n, wg_x <= m)]
+    constraints += [k % intrinsic_mn == 0]
+    constraints += [(k * n) % wg_threads == 0]
+    constraints += [(k * m) % wg_threads == 0]
+    subgroups = subgroup_m_count * subgroup_n_count
+    if num_subgroups > 0:
+        constraints += [subgroups == num_subgroups]
+    else:
+        constraints += [subgroups >= 1, subgroups <= 10]
+
+    constraints += [waves_per_eu == 2]
+    # constraints += [z3.Or(waves_per_eu == 2, waves_per_eu == 3, waves_per_eu == 4)]
+
+    shared_memory = calculate_shared_memory_usage_in_bytes(problem_size, m, n, k)
+    constraints += [shared_memory <= 65536]
+
+    constraints += get_dispatch_constraints(problem_size, m, n, k)
+
+    return constraints
+
+
+def generate_solutions(problem_size: ProblemSize, num_subgroups: int):
+    M, N, K = problem_size.MNK
+    tune_logger.info(f"{M},{N},{K}")
+    m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k")
+    subgroup_size = z3.Int("subgroup_size")
+    intrinsic_mn = z3.Int("intrinsic_mn")
+    intrinsic_k = z3.Int("intrinsic_k")
+    wg_x, wg_y, wg_z = z3.Int("wg_x"), z3.Int("wg_y"), z3.Int("wg_z")
+    sg_m_cnt = z3.Int("sg_m_cnt")
+    sg_n_cnt = z3.Int("sg_n_cnt")
+    waves_per_eu = z3.Int("waves_per_eu")
+    all_vars = [
+        m,
+        n,
+        k,
+        subgroup_size,
+        intrinsic_mn,
+        intrinsic_k,
+        wg_x,
+        wg_y,
+        wg_z,
+        sg_m_cnt,
+        sg_n_cnt,
+        waves_per_eu,
+    ]
+
+    solver = z3.Solver()
+    constraints = generate_constraints(
+        problem_size,
+        [m, n, k],
+        num_subgroups,
+        subgroup_size,
+        [intrinsic_mn, intrinsic_k],
+        [wg_x, wg_y, wg_z],
+        sg_m_cnt,
+        sg_n_cnt,
+        waves_per_eu,
+    )
+    solver.add(z3.simplify(z3.And(constraints)))
+    tune_logger.debug(f"Initial constraints: {solver}")
+    i = 0
+    while solver.check() == z3.sat:
+        model = solver.model()
+        lookup = lambda var: model[var].as_long()
+
+        config = Configuration(
+            lookup(subgroup_size),
+            [lookup(wg_x), lookup(wg_y), lookup(wg_z)],
+            MfmaIntrinsic(
+                problem_size.lhs_type.element_type,
+                lookup(intrinsic_mn),
+                lookup(intrinsic_mn),
+                lookup(intrinsic_k),
+                problem_size.res_type.element_type,
+            ),
+            [lookup(m), lookup(n), lookup(k)],
+            lookup(sg_m_cnt),
+            lookup(sg_n_cnt),
+            lookup(waves_per_eu),
+        )
+        solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars)))))
+        i += 1
+        yield config
+
+
+def get_default_output_dir() -> str:
+    from datetime import datetime
+
+    return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M")
+
+
+def parse_mlir(mlir_text: str) -> ir.Module:
+    mlir_module = None
+    with ireec.ir.Context() as context:
+        try:
+            mlir_module = ireec.ir.Module.parse(mlir_text)
+            tune_logger.info("MLIR parsing successful!")
+        except ireec.ir.MLIRError as e:
+            tune_logger.error(f"Error parsing MLIR: {e}")
+            raise RuntimeError(f"Error parsing MLIR: {e}")
+
+    return mlir_module
+
+
+def walk_callback_detect_type(
+    op: ir.Operation, walk_result: OpWalkResult
+) -> ir.WalkResult:
+    if op.name == "linalg.conv_2d_nhwc_hwcf":
+        walk_result.was_interrupted = True
+        walk_result.dispatch_kind = DispatchKind.conv
+        return ir.WalkResult.INTERRUPT
+
+    if op.name == "util.func":
+        func_name = str(op.opview.sym_name)
+        if "batch_matmul_transpose_b" in func_name:
+            walk_result.was_interrupted = True
+            walk_result.dispatch_kind = DispatchKind.batch_mmt
+            return ir.WalkResult.INTERRUPT
+        if "batch_matmul" in func_name:
+            walk_result.was_interrupted = True
+            walk_result.dispatch_kind = DispatchKind.batch_matmul
+            return ir.WalkResult.INTERRUPT
+        if "matmul_transpose_b" in func_name:
+            walk_result.was_interrupted = True
+            walk_result.dispatch_kind = DispatchKind.mmt
+            return ir.WalkResult.INTERRUPT
+        if "matmul_like" in func_name:
+            walk_result.was_interrupted = True
+            walk_result.dispatch_kind = DispatchKind.contraction
+            return ir.WalkResult.INTERRUPT
+    return ir.WalkResult.ADVANCE
+
+
+def walk_mlir_op(mlir_module: ir.Module) -> OpWalkResult:
+    walk_result = OpWalkResult()
+    for op in mlir_module.body.operations:
+        op.walk(
+            lambda op: walk_callback_detect_type(op, walk_result),
+            ir.WalkOrder.POST_ORDER,
+        )
+        if walk_result.was_interrupted:
+            break
+
+    return walk_result
+
+
+def tune(
+    input: str,
+    output: str = "",
+    limit: int = 4096,
+    num_subgroups: int = 4,
+    lhs_dims: str = "mk",
+    rhs_dims: str = "nk",
+    tile_dims: str = "mnk",
+):
+    input_file = str(input)
+
+    if not output:
+        output = get_default_output_dir()
+
+    # Create the directory if it does not exist
+    makedirs(str(output), exist_ok=True)
+
+    tune_logger.debug(f"Output directory {output}")
+    tune_logger.debug(f"Processing {input_file}")
+    mlir_template = read_input_mlir(input_file)
+    mlir_text = "".join(mlir_template)
+
+    mlir_module = parse_mlir(mlir_text)
+    walk_result = walk_mlir_op(mlir_module)
+    assert walk_result.dispatch_kind is not None
+
+    # Save the input file as the first candidate.
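+    # Candidate 0 is the unmodified input; tuned candidates are numbered from 1.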
+    with open(path.join(output, f"0.mlir"), "w") as f:
+        f.write(mlir_text)
+
+    get_shapes_fn: Callable[[list[str]], ProblemSize] | None = None
+    apply_params_fn: (
+        Callable[[ProblemSize, list[str], Configuration], tuple[str, str]] | None
+    ) = None
+    if walk_result.dispatch_kind == DispatchKind.conv:
+        get_shapes_fn = get_shapes_conv
+        apply_params_fn = apply_params_conv
+    elif walk_result.dispatch_kind == DispatchKind.mmt:
+        get_shapes_fn = get_shapes_mmt
+        apply_params_fn = apply_params_mmt
+    elif walk_result.dispatch_kind == DispatchKind.contraction:
+        if is_broadcast_rhs_mmt(mlir_template):
+            get_shapes_fn = get_shapes_broadcast_rhs_mmt
+            apply_params_fn = apply_params_broadcast_rhs_mmt
+        else:
+            get_shapes_fn = lambda template: get_shapes_contract(
+                template, lhs_dims, rhs_dims
+            )
+            apply_params_fn = lambda ps, template, config: apply_params_contract(
+                ps, tile_dims, template, config
+            )
+    elif walk_result.dispatch_kind == DispatchKind.batch_matmul:
+        get_shapes_fn = lambda template: get_shapes_batch_matmul(
+            template, lhs_dims, rhs_dims
+        )
+        apply_params_fn = lambda ps, template, config: apply_params_batch_matmul(
+            ps, tile_dims, template, config
+        )
+    elif walk_result.dispatch_kind == DispatchKind.batch_mmt:
+        get_shapes_fn = get_shapes_batch_mmt
+        apply_params_fn = apply_params_batch_mmt
+    else:
+        assert False, f"Unhandled dispatch kind: {walk_result.dispatch_kind}"
+
+    problem_size = get_shapes_fn(mlir_template)
+    tune_logger.debug(str(problem_size))
+
+    configs = []
+    for i, config in enumerate(generate_solutions(problem_size, num_subgroups)):
+        if i >= limit:
+            break
+        tune_logger.info(f"Solution #{i+1}: {config}")
+        configs.append(config)
+        new_mlir, embeddable_tuning = apply_params_fn(
+            problem_size, mlir_template, config
+        )
+
+        with open(path.join(output, f"{i+1}.mlir"), "w") as f:
+            f.write(new_mlir)
+        with open(path.join(output, f"{i+1}_config.mlir"), "w") as f:
+            f.write(embeddable_tuning)
+
+    with open(path.join(output, "configs.pkl"), "wb") as file:
+        pickle.dump(configs, file)
+
+    tune_logger.info(f"Generated {len(configs)} candidates")
+    tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl")
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("input", help="Input mlir file", type=str)
+    parser.add_argument(
+        "-o", "--output", help="Output dir", type=str, default=get_default_output_dir()
+    )
+    parser.add_argument(
+        "-l",
+        "--limit",
+        help="Max number of candidates generated",
+        type=int,
+        default=4096,
+    )
+    parser.add_argument(
+        "--num-subgroups",
+        help="Number of subgroups per workgroup to use. (-1 == unconstrained)",
+        type=int,
+        default=-1,
+    )
+    parser.add_argument(
+        "--lhs-dims", help="Map of LHS matmul dims", type=str, default="mk"
+    )
+    parser.add_argument(
+        "--rhs-dims", help="Map of RHS matmul dims", type=str, default="nk"
+    )
+    parser.add_argument(
+        "--tile-dims", help="Map of tile size matmul dims", type=str, default="mnk"
+    )
+    parser.add_argument(
+        "--verbose", "-v", action="store_true", help="Enable verbose output to stdout"
+    )
+
+    args = parser.parse_args()
+    tune_logger.setLevel(logging.DEBUG if args.verbose else logging.INFO)
+
+    # Create printing formatter for logging info
+    formatter = logging.Formatter("%(message)s")
+
+    # Create a handler to print to console
+    console_handler = logging.StreamHandler()
+    console_handler.setFormatter(formatter)
+    tune_logger.addHandler(console_handler)
+
+    # # Optionally, add a file handler to log to a file
+    # file_handler = logging.FileHandler("tune.log")
+    # file_handler.setFormatter(formatter)
+    # tune_logger.addHandler(file_handler)
+
+    tune(
+        args.input,
+        args.output,
+        args.limit,
+        args.num_subgroups,
+        args.lhs_dims,
+        args.rhs_dims,
+        args.tile_dims,
+    )
+
+
+if __name__ == "__main__":
+    main()

From 5b69c41420cf98ef7f74028329d1062fa0bbcb11 Mon Sep 17 00:00:00 2001
From: Amily Wu
Date: Thu, 22 Aug 2024 13:34:18 -0500
Subject: [PATCH 02/23] Create tuner dir and move tuning file

---
 sharktank/sharktank/tools/{ => tuner}/candidate_gen.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename sharktank/sharktank/tools/{ => tuner}/candidate_gen.py (100%)

diff --git a/sharktank/sharktank/tools/candidate_gen.py b/sharktank/sharktank/tools/tuner/candidate_gen.py
similarity index 100%
rename from sharktank/sharktank/tools/candidate_gen.py
rename to sharktank/sharktank/tools/tuner/candidate_gen.py

From 996c42be6ef33215b7c4c8788b5584bcebef48ee Mon Sep 17 00:00:00 2001
From: Amily Wu
Date: Thu, 22 Aug 2024 13:37:14 -0500
Subject: [PATCH 03/23] Add candidate_gen test

---
 .../tools/tuner/candidate_gen_test.py         | 784 ++++++++++++++++++
 1 file changed, 784 insertions(+)
 create mode 100644 sharktank/sharktank/tools/tuner/candidate_gen_test.py

diff --git a/sharktank/sharktank/tools/tuner/candidate_gen_test.py b/sharktank/sharktank/tools/tuner/candidate_gen_test.py
new file mode 100644
index 000000000..ad9b97e0a
--- /dev/null
+++ b/sharktank/sharktank/tools/tuner/candidate_gen_test.py
@@ -0,0 +1,784 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import pytest +import candidate_gen + +""" +Usage: python -m pytest test_tune.py +""" + + +def test_get_shaped_type_element_bitwidth(): + assert ( + candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8).bitwidth + == 8 + ) + assert ( + candidate_gen.ShapedType([2048], candidate_gen.ElementType.i32).bitwidth == 32 + ) + assert ( + candidate_gen.ShapedType( + [2048, 512, 384], candidate_gen.ElementType.f8 + ).bitwidth + == 8 + ) + assert ( + candidate_gen.ShapedType([1, 1], candidate_gen.ElementType.f16).bitwidth == 16 + ) + + +def test_get_shaped_type_to_str(): + assert ( + str(candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8)) + == "1024x2048xi8" + ) + assert ( + str(candidate_gen.ShapedType([1024], candidate_gen.ElementType.f32)) + == "1024xf32" + ) + assert ( + str(candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f16)) + == "1x2x3xf16" + ) + assert ( + str(candidate_gen.ShapedType([-1, 2, 3], candidate_gen.ElementType.f16)) + == "?x2x3xf16" + ) + + +def test_parse_tensor_type(): + assert candidate_gen.parse_tensor_type( + "tensor<1x2x3xf32>" + ) == candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f32) + assert candidate_gen.parse_tensor_type( + "tensor<123xi8>" + ) == candidate_gen.ShapedType([123], candidate_gen.ElementType.i8) + + +def test_get_mmt_tile_sizes(): + config = candidate_gen.Configuration( + subgroup_size=0, + workgroup_size=[], + intrinsic="", + tile_sizes=[128, 320, 32], + subgroup_m_count=0, + subgroup_n_count=0, + waves_per_eu=0, + ) + assert candidate_gen.get_mmt_tile_sizes(config) == [128, 320, 32] + + +def test_get_conv_tile_sizes(): + config = candidate_gen.Configuration( + subgroup_size=64, + workgroup_size=[256, 1, 1], + intrinsic="#iree_gpu.mma_layout", + tile_sizes=[464, 320, 16], + subgroup_m_count=1, + subgroup_n_count=4, + waves_per_eu=1, + ) + assert candidate_gen.get_conv_tile_sizes(config) == [1, 1, 464, 320, 1, 1, 16] + + +def test_get_contract_tile_sizes(): + config = candidate_gen.Configuration( + subgroup_size=32, + workgroup_size=[16, 16, 1], + intrinsic="", + tile_sizes=[4, 8, 16], + subgroup_m_count=1, + subgroup_n_count=1, + waves_per_eu=2, + ) + assert candidate_gen.get_contract_tile_sizes(config, ["m", "n", "k"]) == [4, 8, 16] + assert candidate_gen.get_contract_tile_sizes(config, ["n", "m", "k"]) == [8, 4, 16] + assert candidate_gen.get_contract_tile_sizes(config, ["k", "n", "m"]) == [16, 8, 4] + assert candidate_gen.get_contract_tile_sizes(config, ["k", "k", "k"]) == [ + 16, + 16, + 16, + ] + + +def test_get_pipeline_config(): + config1 = candidate_gen.Configuration( + subgroup_size=32, + workgroup_size=[16, 16, 1], + intrinsic="", + tile_sizes=[4, 8, 16], + subgroup_m_count=1, + subgroup_n_count=1, + waves_per_eu=2, + ) + config2 = candidate_gen.Configuration( + subgroup_size=32, + workgroup_size=[16, 16, 1], + intrinsic="", + tile_sizes=[4, 8, 16], + subgroup_m_count=1, + subgroup_n_count=1, + waves_per_eu=4, + ) + assert candidate_gen.get_pipeline_config(config1) == ", prefetch_shared_memory" + assert ( + candidate_gen.get_pipeline_config(config2) + == ', prefetch_shared_memory, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' + ) + + +def test_get_shapes_mmt(): + template = [ + r"%18 = tensor.empty() : tensor<2048x1280xf32>", + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", + r'%20 = linalg.generic {indexing_maps = 
[affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + r"^bb0(%in: f16, %in_0: f16, %out: f32):", + ] + assert candidate_gen.get_shapes_mmt(template) == candidate_gen.ProblemSize( + candidate_gen.MatmulSize(2048, 1280, 1280), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.mmt, + ) + + +def test_get_shapes_conv(): + template = [ + r"%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", + r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", + r"flow.dispatch.tensor.store %8, %2, offsets = [%workgroup_id_z, %workgroup_id_y, 0, %3], sizes = [1, 1, 32, 256], strides = [1, 1, 1, 1] : tensor<1x1x32x256xf32> -> !flow.dispatch.tensor>", + ] + assert candidate_gen.get_shapes_conv(template) == candidate_gen.ProblemSize( + candidate_gen.MatmulSize(32, 256, 11520), + candidate_gen.ShapedType([1, 3, 34, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([3, 3, 1280, 256], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([1, 1, 32, 256], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.conv, + ) + + +def test_get_shapes_contract(): + template = [ + r"%18 = tensor.empty() : tensor<2048x1280xf32>", + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", + r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + r"^bb0(%in: f16, %in_0: f16, %out: f32):", + ] + assert candidate_gen.get_shapes_contract( + template, "mk", "nk" + ) == candidate_gen.ProblemSize( + candidate_gen.MatmulSize(2048, 1280, 1280), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.contraction, + ) + + +def test_get_shapes_batch_matmul(): + template = [ + "%10 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", + "%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", + "flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 32, 32], strides = [1, 1, 1] : tensor<1x32x32xf32> -> !flow.dispatch.tensor>", + ] + assert candidate_gen.get_shapes_batch_matmul( + template, "bmk", "bkn" + ) == candidate_gen.ProblemSize( + candidate_gen.MatmulSize(32, 32, 1024, 1), + 
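+        # 1x32x1024 ("bmk") times 1x1024x32 ("bkn") gives M=32, N=32, K=1024, B=1.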
candidate_gen.ShapedType([1, 32, 1024], candidate_gen.ElementType.f32), + candidate_gen.ShapedType([1, 1024, 32], candidate_gen.ElementType.f32), + candidate_gen.ShapedType([1, 32, 32], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.batch_matmul, + ) + + +def test_get_shapes_batch_mmt(): + template = [ + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>", + r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + r"flow.dispatch.tensor.store %21, %10, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : tensor<2x4096x640xf16> -> !flow.dispatch.tensor>", + ] + assert candidate_gen.get_shapes_batch_mmt(template) == candidate_gen.ProblemSize( + candidate_gen.MatmulSize(4096, 640, 640, 2), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), + candidate_gen.DispatchKind.batch_mmt, + ) + + +def test_mfma_intrinsic_to_str(): + assert ( + str(candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32()) + == "MFMA_F16_16x16x16_F32" + ) + assert ( + str(candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32()) + == "MFMA_I8_32x32x16_I32" + ) + + +def test_get_compatible_mfma_intrinsics(): + assert candidate_gen.get_compatible_mfma_intrinsics( + candidate_gen.ProblemSize( + candidate_gen.MatmulSize(2048, 1280, 1280), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.mmt, + ) + ) == [ + candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), + candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), + ] + + assert candidate_gen.get_compatible_mfma_intrinsics( + candidate_gen.ProblemSize( + candidate_gen.MatmulSize(2048, 1280, 1280), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i32), + candidate_gen.DispatchKind.mmt, + ) + ) == [ + candidate_gen.MfmaIntrinsic.mfma_i8_16x16x32_i32(), + candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(), + ] + + assert candidate_gen.get_compatible_mfma_intrinsics( + candidate_gen.ProblemSize( + candidate_gen.MatmulSize(968, 320, 640, 64), + candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f32), + candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f32), + candidate_gen.ShapedType([64, 968, 320], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.batch_matmul, + ) + ) == [ + candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), + candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), + ] + + +def test_generate_solutions(): + matmul_size = candidate_gen.MatmulSize(2048, 3840, 1280) + lhs_type = candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16) + rhs_type = candidate_gen.ShapedType([3840, 1280], candidate_gen.ElementType.f16) + res_type = 
candidate_gen.ShapedType([2048, 3840], candidate_gen.ElementType.f32) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + ) + configs = candidate_gen.generate_solutions(problem_size, 4) + assert configs is not None + + +def test_calculate_shared_memory_usage_in_bytes(): + matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) + lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + ) + assert ( + candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) + == 147456 + ) + + lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i8) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + ) + assert ( + candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) + == 81920 + ) + + rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i32) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + ) + assert ( + candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 128, 64, 32) + == 12288 + ) + + +def test_generate_constraints_valid_input(): + matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) + lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + ) + # Define input parameters as z3 Ints + m, n, k = ( + candidate_gen.z3.Int("m"), + candidate_gen.z3.Int("n"), + candidate_gen.z3.Int("k"), + ) + subgroup_size = candidate_gen.z3.Int("subgroup_size") + intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn") + intrinsic_k = candidate_gen.z3.Int("intrinsic_k") + wg_x, wg_y, wg_z = ( + candidate_gen.z3.Int("wg_x"), + candidate_gen.z3.Int("wg_y"), + candidate_gen.z3.Int("wg_z"), + ) + sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt") + sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt") + waves_per_eu = candidate_gen.z3.Int("waves_per_eu") + + constraints = candidate_gen.generate_constraints( + problem_size, + [m, n, k], + 4, + subgroup_size, + [intrinsic_mn, intrinsic_k], + [wg_x, wg_y, wg_z], + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + ) + + solver = candidate_gen.z3.Solver() + solver.add(constraints) + + # Check if the constraints are satisfiable + assert solver.check() == candidate_gen.z3.sat + + +def test_generate_constraints_invalid_input(): + # Define input parameters that should lead to unsatisfiable constraints + matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) + lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + ) + m, n, k = ( + candidate_gen.z3.Int("m"), + 
candidate_gen.z3.Int("n"), + candidate_gen.z3.Int("k"), + ) + subgroup_size = candidate_gen.z3.Int("subgroup_size") + intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn") + intrinsic_k = candidate_gen.z3.Int("intrinsic_k") + wg_x, wg_y, wg_z = ( + candidate_gen.z3.Int("wg_x"), + candidate_gen.z3.Int("wg_y"), + candidate_gen.z3.Int("wg_z"), + ) + sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt") + sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt") + waves_per_eu = candidate_gen.z3.Int("waves_per_eu") + + constraints = candidate_gen.generate_constraints( + problem_size, + [m, n, k], + 4, + subgroup_size, + [intrinsic_mn, intrinsic_k], + [wg_x, wg_y, wg_z], + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + ) + constraints.append(m > 1000) # Adding an additional unsatisfiable constraint + + solver = candidate_gen.z3.Solver() + solver.add(constraints) + + # Check if the constraints are unsatisfiable + assert solver.check() == candidate_gen.z3.unsat + + +def test_apply_params_mmt(): + mlir_template = [ + ", subgroup_m_count = 16, subgroup_n_count = 16>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', + ] + + M, N, K = 2048, 1280, 1280 + + config = candidate_gen.Configuration( + subgroup_size=16, + workgroup_size=[16, 16, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), + tile_sizes=[8, 8, 8], + subgroup_m_count=16, + subgroup_n_count=16, + waves_per_eu=8, + ) + + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(M, N, K), + candidate_gen.ShapedType([M, K], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([N, K], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([M, N], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.mmt, + ) + modified, embeddable = candidate_gen.apply_params_mmt( + problem_size, mlir_template, config + ) + + assert modified + assert embeddable + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 16, subgroup_n_count = 16" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [16, 16, 1] subgroup_size = 16" + in modified + ) + assert "tile_sizes = [[8, 8, 8]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "8"}' in modified + + +def test_apply_params_conv(): + mlir_template = [ + ", subgroup_m_count = 16, subgroup_n_count = 16>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', + ] + + n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640 + + config = candidate_gen.Configuration( + subgroup_size=64, + workgroup_size=[256, 1, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), + tile_sizes=[464, 320, 16], + subgroup_m_count=1, + subgroup_n_count=4, + waves_per_eu=2, + ) + + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(oh * ow, oc, fh * fw * ic), + candidate_gen.ShapedType( + [n, oh + 2, ow + 2, oc], candidate_gen.ElementType.f16 + ), + candidate_gen.ShapedType([fh, fw, ic, oc], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([n, oh, ow, oc], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.conv, + ) + modified, embeddable = candidate_gen.apply_params_conv( + problem_size, mlir_template, config + ) + + assert modified + assert embeddable + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 1, subgroup_n_count = 4" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64" + in modified + ) + assert "tile_sizes = [[1, 1, 464, 320, 1, 1, 16]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified + + +def 
test_apply_params_contract(): + mlir_template = [ + ", subgroup_m_count = 2, subgroup_n_count = 2>}>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', + ] + + tile_dims = "*mnk" + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(2048, 3840, 1280), + candidate_gen.ShapedType([2, 1024, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([3, 20, 64, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([3, 2, 20, 1024, 64], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.contraction, + ) + + config = candidate_gen.Configuration( + subgroup_size=64, + workgroup_size=[256, 1, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), + tile_sizes=[480, 384, 32], + subgroup_m_count=1, + subgroup_n_count=4, + waves_per_eu=2, + ) + + new_mlir, _embeddable = candidate_gen.apply_params_contract( + problem_size, tile_dims, mlir_template, config + ) + + assert new_mlir + assert ( + "intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>, subgroup_m_count = 1, subgroup_n_count = 4" + in new_mlir + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64" + in new_mlir + ) + assert "tile_sizes = [[1, 480, 384, 32]]" in new_mlir + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in new_mlir + + +def test_apply_params_batch_matmul(): + mlir_template = [ + ", subgroup_m_count = 4, subgroup_n_count = 1>}>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', + ] + + tile_dims = "bmnk" + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(968, 320, 640, 64), + candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([64, 968, 320], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.batch_matmul, + ) + + config = candidate_gen.Configuration( + subgroup_size=64, + workgroup_size=[128, 2, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), + tile_sizes=[416, 320, 128], + subgroup_m_count=2, + subgroup_n_count=2, + waves_per_eu=2, + ) + + modified, embeddable = candidate_gen.apply_params_batch_matmul( + problem_size, tile_dims, mlir_template, config + ) + + assert modified + assert embeddable + assert ( + "intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>, subgroup_m_count = 2, subgroup_n_count = 2" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" + in modified + ) + assert "tile_sizes = [[1, 416, 320, 128]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified + + +def test_apply_params_batch_mmt_float(): + mlir_template = [ + ", subgroup_m_count = 4, subgroup_n_count = 1>}>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', + ] + + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(4096, 640, 640, 2), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.batch_mmt, + ) + + config = candidate_gen.Configuration( + subgroup_size=64, + workgroup_size=[128, 2, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), + tile_sizes=[128, 64, 128], + subgroup_m_count=2, + subgroup_n_count=2, + waves_per_eu=2, + ) + + modified, embeddable = candidate_gen.apply_params_batch_mmt( + problem_size, mlir_template, config + ) + + assert embeddable + assert 
modified + assert ( + "intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 2, subgroup_n_count = 2" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" + in modified + ) + assert "tile_sizes = [[1, 128, 64, 128]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified + + +def test_apply_params_batch_mmt_int(): + mlir_template = [ + ", subgroup_m_count = 4, subgroup_n_count = 1>}>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', + ] + + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(4096, 640, 640, 2), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), + candidate_gen.DispatchKind.batch_mmt, + ) + + config = candidate_gen.Configuration( + subgroup_size=64, + workgroup_size=[128, 2, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(), + tile_sizes=[128, 64, 128], + subgroup_m_count=2, + subgroup_n_count=2, + waves_per_eu=4, + ) + + modified, embeddable = candidate_gen.apply_params_batch_mmt( + problem_size, mlir_template, config + ) + + assert modified + assert "// transform.named_sequence @match_batch_mmt_2x4096x640x640(" in modified + assert ( + "intrinsic = #iree_gpu.mma_layout<MFMA_I8_32x32x16_I32>, subgroup_m_count = 2, subgroup_n_count = 2" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" + in modified + ) + assert "tile_sizes = [[1, 128, 64, 128]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified + + assert embeddable + assert "transform.named_sequence @match_op(" in embeddable + assert ( + "transform.include @match_batch_mmt_i8_i8_i32 failures(propagate)" in embeddable + ) + assert ( + "transform.iree.match.cast_compatible_type %lhs = tensor<2x4096x640xi8> : !transform.any_value" + in embeddable + ) + assert ( + "transform.iree.match.cast_compatible_type %rhs = tensor<2x640x640xi8> : !transform.any_value" + in embeddable + ) + assert ( + "%config = transform.param.constant #iree_codegen.compilation_info<" + in embeddable + ) + assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable + assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable + assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable + + +def test_apply_params_broadcast_rhs_mmt(): + mlir_template = [ + ", subgroup_m_count = 4, subgroup_n_count = 1>}>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', + ] + + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(4096, 640, 640, 2), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([640, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), + candidate_gen.DispatchKind.broadcast_rhs_mmt, + ) + + config = candidate_gen.Configuration( + subgroup_size=64, + workgroup_size=[128, 2, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(), + tile_sizes=[128, 64, 128], + subgroup_m_count=2, + subgroup_n_count=2, + waves_per_eu=4, + ) + + modified, embeddable = candidate_gen.apply_params_broadcast_rhs_mmt( + problem_size, mlir_template, config + ) + + assert modified + assert ( + "// transform.named_sequence @match_broadcast_rhs_mmt_Bx4096x640x640(" + in modified + ) + assert ( + "intrinsic = #iree_gpu.mma_layout<MFMA_I8_32x32x16_I32>, subgroup_m_count = 2, 
subgroup_n_count = 2" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" + in modified + ) + assert "tile_sizes = [[1, 128, 64, 128]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified + + assert embeddable + assert "transform.named_sequence @match_op(" in embeddable + assert ( + "transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate)" + in embeddable + ) + assert ( + "transform.iree.match.cast_compatible_type %lhs = tensor : !transform.any_value" + in embeddable + ) + assert ( + "transform.iree.match.cast_compatible_type %rhs = tensor<640x640xi8> : !transform.any_value" + in embeddable + ) + assert ( + "%config = transform.param.constant #iree_codegen.compilation_info<" + in embeddable + ) + assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable + assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable + assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable + + +def test_detect_broadcast_rhs_mmt(): + mlir_lines = [ + r"%18 = tensor.empty() : tensor<2x1024x10240xi32>", + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>", + r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + ] + assert candidate_gen.is_broadcast_rhs_mmt(mlir_lines) + + +def test_parse_mlir(): + mlir_str = r""" + builtin.module { + func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + %0 = arith.mulf %arg0, %arg1 : tensor<4xf32> + return %0 : tensor<4xf32> + } + } + """ + mlir_module = candidate_gen.parse_mlir(mlir_str) + assert mlir_module != None + assert isinstance(mlir_module, candidate_gen.ireec._mlir_libs._mlir.ir.Module) + assert isinstance( + mlir_module.body.operations[0], candidate_gen.ireec.dialects.func.FuncOp + ) From d9c859e6b6898eb86e3e1f9d7f5b97ca735f8cd8 Mon Sep 17 00:00:00 2001 From: Amily Wu Date: Thu, 22 Aug 2024 14:28:11 -0500 Subject: [PATCH 04/23] Move test --- .../tools/tuner/candidate_gen_test.py | 784 ------------------ tuner-requirements.txt | 3 + 2 files changed, 3 insertions(+), 784 deletions(-) delete mode 100644 sharktank/sharktank/tools/tuner/candidate_gen_test.py create mode 100644 tuner-requirements.txt diff --git a/sharktank/sharktank/tools/tuner/candidate_gen_test.py b/sharktank/sharktank/tools/tuner/candidate_gen_test.py deleted file mode 100644 index ad9b97e0a..000000000 --- a/sharktank/sharktank/tools/tuner/candidate_gen_test.py +++ /dev/null @@ -1,784 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import pytest -import candidate_gen - -""" -Usage: python -m pytest test_tune.py -""" - - -def test_get_shaped_type_element_bitwidth(): - assert ( - candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8).bitwidth - == 8 - ) - assert ( - candidate_gen.ShapedType([2048], candidate_gen.ElementType.i32).bitwidth == 32 - ) - assert ( - candidate_gen.ShapedType( - [2048, 512, 384], candidate_gen.ElementType.f8 - ).bitwidth - == 8 - ) - assert ( - candidate_gen.ShapedType([1, 1], candidate_gen.ElementType.f16).bitwidth == 16 - ) - - -def test_get_shaped_type_to_str(): - assert ( - str(candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8)) - == "1024x2048xi8" - ) - assert ( - str(candidate_gen.ShapedType([1024], candidate_gen.ElementType.f32)) - == "1024xf32" - ) - assert ( - str(candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f16)) - == "1x2x3xf16" - ) - assert ( - str(candidate_gen.ShapedType([-1, 2, 3], candidate_gen.ElementType.f16)) - == "?x2x3xf16" - ) - - -def test_parse_tensor_type(): - assert candidate_gen.parse_tensor_type( - "tensor<1x2x3xf32>" - ) == candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f32) - assert candidate_gen.parse_tensor_type( - "tensor<123xi8>" - ) == candidate_gen.ShapedType([123], candidate_gen.ElementType.i8) - - -def test_get_mmt_tile_sizes(): - config = candidate_gen.Configuration( - subgroup_size=0, - workgroup_size=[], - intrinsic="", - tile_sizes=[128, 320, 32], - subgroup_m_count=0, - subgroup_n_count=0, - waves_per_eu=0, - ) - assert candidate_gen.get_mmt_tile_sizes(config) == [128, 320, 32] - - -def test_get_conv_tile_sizes(): - config = candidate_gen.Configuration( - subgroup_size=64, - workgroup_size=[256, 1, 1], - intrinsic="#iree_gpu.mma_layout", - tile_sizes=[464, 320, 16], - subgroup_m_count=1, - subgroup_n_count=4, - waves_per_eu=1, - ) - assert candidate_gen.get_conv_tile_sizes(config) == [1, 1, 464, 320, 1, 1, 16] - - -def test_get_contract_tile_sizes(): - config = candidate_gen.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], - intrinsic="", - tile_sizes=[4, 8, 16], - subgroup_m_count=1, - subgroup_n_count=1, - waves_per_eu=2, - ) - assert candidate_gen.get_contract_tile_sizes(config, ["m", "n", "k"]) == [4, 8, 16] - assert candidate_gen.get_contract_tile_sizes(config, ["n", "m", "k"]) == [8, 4, 16] - assert candidate_gen.get_contract_tile_sizes(config, ["k", "n", "m"]) == [16, 8, 4] - assert candidate_gen.get_contract_tile_sizes(config, ["k", "k", "k"]) == [ - 16, - 16, - 16, - ] - - -def test_get_pipeline_config(): - config1 = candidate_gen.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], - intrinsic="", - tile_sizes=[4, 8, 16], - subgroup_m_count=1, - subgroup_n_count=1, - waves_per_eu=2, - ) - config2 = candidate_gen.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], - intrinsic="", - tile_sizes=[4, 8, 16], - subgroup_m_count=1, - subgroup_n_count=1, - waves_per_eu=4, - ) - assert candidate_gen.get_pipeline_config(config1) == ", prefetch_shared_memory" - assert ( - candidate_gen.get_pipeline_config(config2) - == ', prefetch_shared_memory, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' - ) - - -def test_get_shapes_mmt(): - template = [ - r"%18 = tensor.empty() : tensor<2048x1280xf32>", - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", - r'%20 = linalg.generic {indexing_maps = 
[affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"^bb0(%in: f16, %in_0: f16, %out: f32):", - ] - assert candidate_gen.get_shapes_mmt(template) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.mmt, - ) - - -def test_get_shapes_conv(): - template = [ - r"%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", - r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", - r"flow.dispatch.tensor.store %8, %2, offsets = [%workgroup_id_z, %workgroup_id_y, 0, %3], sizes = [1, 1, 32, 256], strides = [1, 1, 1, 1] : tensor<1x1x32x256xf32> -> !flow.dispatch.tensor>", - ] - assert candidate_gen.get_shapes_conv(template) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(32, 256, 11520), - candidate_gen.ShapedType([1, 3, 34, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([3, 3, 1280, 256], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1, 1, 32, 256], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.conv, - ) - - -def test_get_shapes_contract(): - template = [ - r"%18 = tensor.empty() : tensor<2048x1280xf32>", - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"^bb0(%in: f16, %in_0: f16, %out: f32):", - ] - assert candidate_gen.get_shapes_contract( - template, "mk", "nk" - ) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.contraction, - ) - - -def test_get_shapes_batch_matmul(): - template = [ - "%10 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", - "%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", - "flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 32, 32], strides = [1, 1, 1] : tensor<1x32x32xf32> -> !flow.dispatch.tensor>", - ] - assert candidate_gen.get_shapes_batch_matmul( - template, "bmk", "bkn" - ) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(32, 32, 1024, 1), - 
candidate_gen.ShapedType([1, 32, 1024], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([1, 1024, 32], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([1, 32, 32], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.batch_matmul, - ) - - -def test_get_shapes_batch_mmt(): - template = [ - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"flow.dispatch.tensor.store %21, %10, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : tensor<2x4096x640xf16> -> !flow.dispatch.tensor>", - ] - assert candidate_gen.get_shapes_batch_mmt(template) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(4096, 640, 640, 2), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), - candidate_gen.DispatchKind.batch_mmt, - ) - - -def test_mfma_intrinsic_to_str(): - assert ( - str(candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32()) - == "MFMA_F16_16x16x16_F32" - ) - assert ( - str(candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32()) - == "MFMA_I8_32x32x16_I32" - ) - - -def test_get_compatible_mfma_intrinsics(): - assert candidate_gen.get_compatible_mfma_intrinsics( - candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.mmt, - ) - ) == [ - candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), - candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), - ] - - assert candidate_gen.get_compatible_mfma_intrinsics( - candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i32), - candidate_gen.DispatchKind.mmt, - ) - ) == [ - candidate_gen.MfmaIntrinsic.mfma_i8_16x16x32_i32(), - candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(), - ] - - assert candidate_gen.get_compatible_mfma_intrinsics( - candidate_gen.ProblemSize( - candidate_gen.MatmulSize(968, 320, 640, 64), - candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([64, 968, 320], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.batch_matmul, - ) - ) == [ - candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), - candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), - ] - - -def test_generate_solutions(): - matmul_size = candidate_gen.MatmulSize(2048, 3840, 1280) - lhs_type = candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([3840, 1280], candidate_gen.ElementType.f16) - res_type = 
candidate_gen.ShapedType([2048, 3840], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt - ) - configs = candidate_gen.generate_solutions(problem_size, 4) - assert configs is not None - - -def test_calculate_shared_memory_usage_in_bytes(): - matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt - ) - assert ( - candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) - == 147456 - ) - - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i8) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt - ) - assert ( - candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) - == 81920 - ) - - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt - ) - assert ( - candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 128, 64, 32) - == 12288 - ) - - -def test_generate_constraints_valid_input(): - matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt - ) - # Define input parameters as z3 Ints - m, n, k = ( - candidate_gen.z3.Int("m"), - candidate_gen.z3.Int("n"), - candidate_gen.z3.Int("k"), - ) - subgroup_size = candidate_gen.z3.Int("subgroup_size") - intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn") - intrinsic_k = candidate_gen.z3.Int("intrinsic_k") - wg_x, wg_y, wg_z = ( - candidate_gen.z3.Int("wg_x"), - candidate_gen.z3.Int("wg_y"), - candidate_gen.z3.Int("wg_z"), - ) - sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt") - sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt") - waves_per_eu = candidate_gen.z3.Int("waves_per_eu") - - constraints = candidate_gen.generate_constraints( - problem_size, - [m, n, k], - 4, - subgroup_size, - [intrinsic_mn, intrinsic_k], - [wg_x, wg_y, wg_z], - sg_m_cnt, - sg_n_cnt, - waves_per_eu, - ) - - solver = candidate_gen.z3.Solver() - solver.add(constraints) - - # Check if the constraints are satisfiable - assert solver.check() == candidate_gen.z3.sat - - -def test_generate_constraints_invalid_input(): - # Define input parameters that should lead to unsatisfiable constraints - matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt - ) - m, n, k = ( - candidate_gen.z3.Int("m"), - 
candidate_gen.z3.Int("n"), - candidate_gen.z3.Int("k"), - ) - subgroup_size = candidate_gen.z3.Int("subgroup_size") - intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn") - intrinsic_k = candidate_gen.z3.Int("intrinsic_k") - wg_x, wg_y, wg_z = ( - candidate_gen.z3.Int("wg_x"), - candidate_gen.z3.Int("wg_y"), - candidate_gen.z3.Int("wg_z"), - ) - sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt") - sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt") - waves_per_eu = candidate_gen.z3.Int("waves_per_eu") - - constraints = candidate_gen.generate_constraints( - problem_size, - [m, n, k], - 4, - subgroup_size, - [intrinsic_mn, intrinsic_k], - [wg_x, wg_y, wg_z], - sg_m_cnt, - sg_n_cnt, - waves_per_eu, - ) - constraints.append(m > 1000) # Adding an additional unsatisfiable constraint - - solver = candidate_gen.z3.Solver() - solver.add(constraints) - - # Check if the constraints are unsatisfiable - assert solver.check() == candidate_gen.z3.unsat - - -def test_apply_params_mmt(): - mlir_template = [ - ", subgroup_m_count = 16, subgroup_n_count = 16>", - "", - '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', - ] - - M, N, K = 2048, 1280, 1280 - - config = candidate_gen.Configuration( - subgroup_size=16, - workgroup_size=[16, 16, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), - tile_sizes=[8, 8, 8], - subgroup_m_count=16, - subgroup_n_count=16, - waves_per_eu=8, - ) - - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(M, N, K), - candidate_gen.ShapedType([M, K], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([N, K], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([M, N], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.mmt, - ) - modified, embeddable = candidate_gen.apply_params_mmt( - problem_size, mlir_template, config - ) - - assert modified - assert embeddable - assert ( - "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 16, subgroup_n_count = 16" - in modified - ) - assert ( - "LLVMGPUVectorDistribute workgroup_size = [16, 16, 1] subgroup_size = 16" - in modified - ) - assert "tile_sizes = [[8, 8, 8]]" in modified - assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "8"}' in modified - - -def test_apply_params_conv(): - mlir_template = [ - ", subgroup_m_count = 16, subgroup_n_count = 16>", - "", - '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', - ] - - n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640 - - config = candidate_gen.Configuration( - subgroup_size=64, - workgroup_size=[256, 1, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), - tile_sizes=[464, 320, 16], - subgroup_m_count=1, - subgroup_n_count=4, - waves_per_eu=2, - ) - - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(oh * ow, oc, fh * fw * ic), - candidate_gen.ShapedType( - [n, oh + 2, ow + 2, oc], candidate_gen.ElementType.f16 - ), - candidate_gen.ShapedType([fh, fw, ic, oc], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([n, oh, ow, oc], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.conv, - ) - modified, embeddable = candidate_gen.apply_params_conv( - problem_size, mlir_template, config - ) - - assert modified - assert embeddable - assert ( - "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 1, subgroup_n_count = 4" - in modified - ) - assert ( - "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64" - in modified - ) - assert "tile_sizes = [[1, 1, 464, 320, 1, 1, 16]]" in modified - assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified - - -def 
test_apply_params_contract(): - mlir_template = [ - ", subgroup_m_count = 2, subgroup_n_count = 2>}>", - "", - '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', - ] - - tile_dims = "*mnk" - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 3840, 1280), - candidate_gen.ShapedType([2, 1024, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([3, 20, 64, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([3, 2, 20, 1024, 64], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.contraction, - ) - - config = candidate_gen.Configuration( - subgroup_size=64, - workgroup_size=[256, 1, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), - tile_sizes=[480, 384, 32], - subgroup_m_count=1, - subgroup_n_count=4, - waves_per_eu=2, - ) - - new_mlir, _embeddable = candidate_gen.apply_params_contract( - problem_size, tile_dims, mlir_template, config - ) - - assert new_mlir - assert ( - "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 1, subgroup_n_count = 4" - in new_mlir - ) - assert ( - "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64" - in new_mlir - ) - assert "tile_sizes = [[1, 480, 384, 32]]" in new_mlir - assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in new_mlir - - -def test_apply_params_batch_matmul(): - mlir_template = [ - ", subgroup_m_count = 4, subgroup_n_count = 1>}>", - "", - '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', - ] - - tile_dims = "bmnk" - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(968, 320, 640, 64), - candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([64, 968, 320], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.batch_matmul, - ) - - config = candidate_gen.Configuration( - subgroup_size=64, - workgroup_size=[128, 2, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), - tile_sizes=[416, 320, 128], - subgroup_m_count=2, - subgroup_n_count=2, - waves_per_eu=2, - ) - - modified, embeddable = candidate_gen.apply_params_batch_matmul( - problem_size, tile_dims, mlir_template, config - ) - - assert modified - assert embeddable - assert ( - "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" - in modified - ) - assert ( - "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" - in modified - ) - assert "tile_sizes = [[1, 416, 320, 128]]" in modified - assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified - - -def test_apply_params_batch_mmt_float(): - mlir_template = [ - ", subgroup_m_count = 4, subgroup_n_count = 1>}>", - "", - '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', - ] - - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(4096, 640, 640, 2), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.batch_mmt, - ) - - config = candidate_gen.Configuration( - subgroup_size=64, - workgroup_size=[128, 2, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), - tile_sizes=[128, 64, 128], - subgroup_m_count=2, - subgroup_n_count=2, - waves_per_eu=2, - ) - - modified, embeddable = candidate_gen.apply_params_batch_mmt( - problem_size, mlir_template, config - ) - - assert embeddable - assert 
modified - assert ( - "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" - in modified - ) - assert ( - "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" - in modified - ) - assert "tile_sizes = [[1, 128, 64, 128]]" in modified - assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified - - -def test_apply_params_batch_mmt_int(): - mlir_template = [ - ", subgroup_m_count = 4, subgroup_n_count = 1>}>", - "", - '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', - ] - - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(4096, 640, 640, 2), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), - candidate_gen.DispatchKind.batch_mmt, - ) - - config = candidate_gen.Configuration( - subgroup_size=64, - workgroup_size=[128, 2, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(), - tile_sizes=[128, 64, 128], - subgroup_m_count=2, - subgroup_n_count=2, - waves_per_eu=4, - ) - - modified, embeddable = candidate_gen.apply_params_batch_mmt( - problem_size, mlir_template, config - ) - - assert modified - assert "// transform.named_sequence @match_batch_mmt_2x4096x640x640(" in modified - assert ( - "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" - in modified - ) - assert ( - "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" - in modified - ) - assert "tile_sizes = [[1, 128, 64, 128]]" in modified - assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified - - assert embeddable - assert "transform.named_sequence @match_op(" in embeddable - assert ( - "transform.include @match_batch_mmt_i8_i8_i32 failures(propagate)" in embeddable - ) - assert ( - "transform.iree.match.cast_compatible_type %lhs = tensor<2x4096x640xi8> : !transform.any_value" - in embeddable - ) - assert ( - "transform.iree.match.cast_compatible_type %rhs = tensor<2x640x640xi8> : !transform.any_value" - in embeddable - ) - assert ( - "%config = transform.param.constant #iree_codegen.compilation_info<" - in embeddable - ) - assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable - assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable - assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable - - -def test_apply_params_broadcast_rhs_mmt(): - mlir_template = [ - ", subgroup_m_count = 4, subgroup_n_count = 1>}>", - "", - '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', - ] - - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(4096, 640, 640, 2), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([640, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), - candidate_gen.DispatchKind.broadcast_rhs_mmt, - ) - - config = candidate_gen.Configuration( - subgroup_size=64, - workgroup_size=[128, 2, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(), - tile_sizes=[128, 64, 128], - subgroup_m_count=2, - subgroup_n_count=2, - waves_per_eu=4, - ) - - modified, embeddable = candidate_gen.apply_params_broadcast_rhs_mmt( - problem_size, mlir_template, config - ) - - assert modified - assert ( - "// transform.named_sequence @match_broadcast_rhs_mmt_Bx4096x640x640(" - in modified - ) - assert ( - "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, 
subgroup_n_count = 2" - in modified - ) - assert ( - "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" - in modified - ) - assert "tile_sizes = [[1, 128, 64, 128]]" in modified - assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified - - assert embeddable - assert "transform.named_sequence @match_op(" in embeddable - assert ( - "transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate)" - in embeddable - ) - assert ( - "transform.iree.match.cast_compatible_type %lhs = tensor : !transform.any_value" - in embeddable - ) - assert ( - "transform.iree.match.cast_compatible_type %rhs = tensor<640x640xi8> : !transform.any_value" - in embeddable - ) - assert ( - "%config = transform.param.constant #iree_codegen.compilation_info<" - in embeddable - ) - assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable - assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable - assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable - - -def test_detect_broadcast_rhs_mmt(): - mlir_lines = [ - r"%18 = tensor.empty() : tensor<2x1024x10240xi32>", - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - ] - assert candidate_gen.is_broadcast_rhs_mmt(mlir_lines) - - -def test_parse_mlir(): - mlir_str = r""" - builtin.module { - func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - %0 = arith.mulf %arg0, %arg1 : tensor<4xf32> - return %0 : tensor<4xf32> - } - } - """ - mlir_module = candidate_gen.parse_mlir(mlir_str) - assert mlir_module != None - assert isinstance(mlir_module, candidate_gen.ireec._mlir_libs._mlir.ir.Module) - assert isinstance( - mlir_module.body.operations[0], candidate_gen.ireec.dialects.func.FuncOp - ) diff --git a/tuner-requirements.txt b/tuner-requirements.txt new file mode 100644 index 000000000..26ef8f6c7 --- /dev/null +++ b/tuner-requirements.txt @@ -0,0 +1,3 @@ +tqdm==4.66.4 +types-tqdm==4.66.0.20240417 +z3_solver==4.13.0.0 From 16d108708042b23a6a456258a9766608b78c1e02 Mon Sep 17 00:00:00 2001 From: Amily Wu Date: Fri, 23 Aug 2024 10:25:04 -0500 Subject: [PATCH 05/23] Add missing test file --- sharktank/tests/tuner/candidate_gen_test.py | 784 ++++++++++++++++++++ 1 file changed, 784 insertions(+) create mode 100644 sharktank/tests/tuner/candidate_gen_test.py diff --git a/sharktank/tests/tuner/candidate_gen_test.py b/sharktank/tests/tuner/candidate_gen_test.py new file mode 100644 index 000000000..ad9b97e0a --- /dev/null +++ b/sharktank/tests/tuner/candidate_gen_test.py @@ -0,0 +1,784 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import pytest +import candidate_gen + +""" +Usage: python -m pytest candidate_gen_test.py +""" + + +def test_get_shaped_type_element_bitwidth(): + assert ( + candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8).bitwidth + == 8 + ) + assert ( + candidate_gen.ShapedType([2048], candidate_gen.ElementType.i32).bitwidth == 32 + ) + assert ( + candidate_gen.ShapedType( + [2048, 512, 384], candidate_gen.ElementType.f8 + ).bitwidth + == 8 + ) + assert ( + candidate_gen.ShapedType([1, 1], candidate_gen.ElementType.f16).bitwidth == 16 + ) + + +def test_get_shaped_type_to_str(): + assert ( + str(candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8)) + == "1024x2048xi8" + ) + assert ( + str(candidate_gen.ShapedType([1024], candidate_gen.ElementType.f32)) + == "1024xf32" + ) + assert ( + str(candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f16)) + == "1x2x3xf16" + ) + assert ( + str(candidate_gen.ShapedType([-1, 2, 3], candidate_gen.ElementType.f16)) + == "?x2x3xf16" + ) + + +def test_parse_tensor_type(): + assert candidate_gen.parse_tensor_type( + "tensor<1x2x3xf32>" + ) == candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f32) + assert candidate_gen.parse_tensor_type( + "tensor<123xi8>" + ) == candidate_gen.ShapedType([123], candidate_gen.ElementType.i8) + + +def test_get_mmt_tile_sizes(): + config = candidate_gen.Configuration( + subgroup_size=0, + workgroup_size=[], + intrinsic="", + tile_sizes=[128, 320, 32], + subgroup_m_count=0, + subgroup_n_count=0, + waves_per_eu=0, + ) + assert candidate_gen.get_mmt_tile_sizes(config) == [128, 320, 32] + + +def test_get_conv_tile_sizes(): + config = candidate_gen.Configuration( + subgroup_size=64, + workgroup_size=[256, 1, 1], + intrinsic="#iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>", + tile_sizes=[464, 320, 16], + subgroup_m_count=1, + subgroup_n_count=4, + waves_per_eu=1, + ) + assert candidate_gen.get_conv_tile_sizes(config) == [1, 1, 464, 320, 1, 1, 16] + + +def test_get_contract_tile_sizes(): + config = candidate_gen.Configuration( + subgroup_size=32, + workgroup_size=[16, 16, 1], + intrinsic="", + tile_sizes=[4, 8, 16], + subgroup_m_count=1, + subgroup_n_count=1, + waves_per_eu=2, + ) + assert candidate_gen.get_contract_tile_sizes(config, ["m", "n", "k"]) == [4, 8, 16] + assert candidate_gen.get_contract_tile_sizes(config, ["n", "m", "k"]) == [8, 4, 16] + assert candidate_gen.get_contract_tile_sizes(config, ["k", "n", "m"]) == [16, 8, 4] + assert candidate_gen.get_contract_tile_sizes(config, ["k", "k", "k"]) == [ + 16, + 16, + 16, + ] + + +def test_get_pipeline_config(): + config1 = candidate_gen.Configuration( + subgroup_size=32, + workgroup_size=[16, 16, 1], + intrinsic="", + tile_sizes=[4, 8, 16], + subgroup_m_count=1, + subgroup_n_count=1, + waves_per_eu=2, + ) + config2 = candidate_gen.Configuration( + subgroup_size=32, + workgroup_size=[16, 16, 1], + intrinsic="", + tile_sizes=[4, 8, 16], + subgroup_m_count=1, + subgroup_n_count=1, + waves_per_eu=4, + ) + assert candidate_gen.get_pipeline_config(config1) == ", prefetch_shared_memory" + assert ( + candidate_gen.get_pipeline_config(config2) + == ', prefetch_shared_memory, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' + ) + + +def test_get_shapes_mmt(): + template = [ + r"%18 = tensor.empty() : tensor<2048x1280xf32>", + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", + r'%20 = linalg.generic {indexing_maps = 
[affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + r"^bb0(%in: f16, %in_0: f16, %out: f32):", + ] + assert candidate_gen.get_shapes_mmt(template) == candidate_gen.ProblemSize( + candidate_gen.MatmulSize(2048, 1280, 1280), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.mmt, + ) + + +def test_get_shapes_conv(): + template = [ + r"%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", + r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", + r"flow.dispatch.tensor.store %8, %2, offsets = [%workgroup_id_z, %workgroup_id_y, 0, %3], sizes = [1, 1, 32, 256], strides = [1, 1, 1, 1] : tensor<1x1x32x256xf32> -> !flow.dispatch.tensor>", + ] + assert candidate_gen.get_shapes_conv(template) == candidate_gen.ProblemSize( + candidate_gen.MatmulSize(32, 256, 11520), + candidate_gen.ShapedType([1, 3, 34, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([3, 3, 1280, 256], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([1, 1, 32, 256], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.conv, + ) + + +def test_get_shapes_contract(): + template = [ + r"%18 = tensor.empty() : tensor<2048x1280xf32>", + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", + r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + r"^bb0(%in: f16, %in_0: f16, %out: f32):", + ] + assert candidate_gen.get_shapes_contract( + template, "mk", "nk" + ) == candidate_gen.ProblemSize( + candidate_gen.MatmulSize(2048, 1280, 1280), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.contraction, + ) + + +def test_get_shapes_batch_matmul(): + template = [ + "%10 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", + "%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", + "flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 32, 32], strides = [1, 1, 1] : tensor<1x32x32xf32> -> !flow.dispatch.tensor>", + ] + assert candidate_gen.get_shapes_batch_matmul( + template, "bmk", "bkn" + ) == candidate_gen.ProblemSize( + candidate_gen.MatmulSize(32, 32, 1024, 1), + 
candidate_gen.ShapedType([1, 32, 1024], candidate_gen.ElementType.f32), + candidate_gen.ShapedType([1, 1024, 32], candidate_gen.ElementType.f32), + candidate_gen.ShapedType([1, 32, 32], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.batch_matmul, + ) + + +def test_get_shapes_batch_mmt(): + template = [ + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>", + r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + r"flow.dispatch.tensor.store %21, %10, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : tensor<2x4096x640xf16> -> !flow.dispatch.tensor>", + ] + assert candidate_gen.get_shapes_batch_mmt(template) == candidate_gen.ProblemSize( + candidate_gen.MatmulSize(4096, 640, 640, 2), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), + candidate_gen.DispatchKind.batch_mmt, + ) + + +def test_mfma_intrinsic_to_str(): + assert ( + str(candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32()) + == "MFMA_F16_16x16x16_F32" + ) + assert ( + str(candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32()) + == "MFMA_I8_32x32x16_I32" + ) + + +def test_get_compatible_mfma_intrinsics(): + assert candidate_gen.get_compatible_mfma_intrinsics( + candidate_gen.ProblemSize( + candidate_gen.MatmulSize(2048, 1280, 1280), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.mmt, + ) + ) == [ + candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), + candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), + ] + + assert candidate_gen.get_compatible_mfma_intrinsics( + candidate_gen.ProblemSize( + candidate_gen.MatmulSize(2048, 1280, 1280), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i32), + candidate_gen.DispatchKind.mmt, + ) + ) == [ + candidate_gen.MfmaIntrinsic.mfma_i8_16x16x32_i32(), + candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(), + ] + + assert candidate_gen.get_compatible_mfma_intrinsics( + candidate_gen.ProblemSize( + candidate_gen.MatmulSize(968, 320, 640, 64), + candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f32), + candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f32), + candidate_gen.ShapedType([64, 968, 320], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.batch_matmul, + ) + ) == [ + candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), + candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), + ] + + +def test_generate_solutions(): + matmul_size = candidate_gen.MatmulSize(2048, 3840, 1280) + lhs_type = candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16) + rhs_type = candidate_gen.ShapedType([3840, 1280], candidate_gen.ElementType.f16) + res_type = 
candidate_gen.ShapedType([2048, 3840], candidate_gen.ElementType.f32) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + ) + configs = candidate_gen.generate_solutions(problem_size, 4) + assert configs is not None + + +def test_calculate_shared_memory_usage_in_bytes(): + matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) + lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + ) + assert ( + candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) + == 147456 + ) + + lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i8) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + ) + assert ( + candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) + == 81920 + ) + + rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i32) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + ) + assert ( + candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 128, 64, 32) + == 12288 + ) + + +def test_generate_constraints_valid_input(): + matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) + lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + ) + # Define input parameters as z3 Ints + m, n, k = ( + candidate_gen.z3.Int("m"), + candidate_gen.z3.Int("n"), + candidate_gen.z3.Int("k"), + ) + subgroup_size = candidate_gen.z3.Int("subgroup_size") + intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn") + intrinsic_k = candidate_gen.z3.Int("intrinsic_k") + wg_x, wg_y, wg_z = ( + candidate_gen.z3.Int("wg_x"), + candidate_gen.z3.Int("wg_y"), + candidate_gen.z3.Int("wg_z"), + ) + sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt") + sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt") + waves_per_eu = candidate_gen.z3.Int("waves_per_eu") + + constraints = candidate_gen.generate_constraints( + problem_size, + [m, n, k], + 4, + subgroup_size, + [intrinsic_mn, intrinsic_k], + [wg_x, wg_y, wg_z], + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + ) + + solver = candidate_gen.z3.Solver() + solver.add(constraints) + + # Check if the constraints are satisfiable + assert solver.check() == candidate_gen.z3.sat + + +def test_generate_constraints_invalid_input(): + # Define input parameters that should lead to unsatisfiable constraints + matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) + lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + ) + m, n, k = ( + candidate_gen.z3.Int("m"), + 
candidate_gen.z3.Int("n"), + candidate_gen.z3.Int("k"), + ) + subgroup_size = candidate_gen.z3.Int("subgroup_size") + intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn") + intrinsic_k = candidate_gen.z3.Int("intrinsic_k") + wg_x, wg_y, wg_z = ( + candidate_gen.z3.Int("wg_x"), + candidate_gen.z3.Int("wg_y"), + candidate_gen.z3.Int("wg_z"), + ) + sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt") + sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt") + waves_per_eu = candidate_gen.z3.Int("waves_per_eu") + + constraints = candidate_gen.generate_constraints( + problem_size, + [m, n, k], + 4, + subgroup_size, + [intrinsic_mn, intrinsic_k], + [wg_x, wg_y, wg_z], + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + ) + constraints.append(m > 1000) # Adding an additional unsatisfiable constraint + + solver = candidate_gen.z3.Solver() + solver.add(constraints) + + # Check if the constraints are unsatisfiable + assert solver.check() == candidate_gen.z3.unsat + + +def test_apply_params_mmt(): + mlir_template = [ + ", subgroup_m_count = 16, subgroup_n_count = 16>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', + ] + + M, N, K = 2048, 1280, 1280 + + config = candidate_gen.Configuration( + subgroup_size=16, + workgroup_size=[16, 16, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), + tile_sizes=[8, 8, 8], + subgroup_m_count=16, + subgroup_n_count=16, + waves_per_eu=8, + ) + + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(M, N, K), + candidate_gen.ShapedType([M, K], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([N, K], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([M, N], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.mmt, + ) + modified, embeddable = candidate_gen.apply_params_mmt( + problem_size, mlir_template, config + ) + + assert modified + assert embeddable + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 16, subgroup_n_count = 16" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [16, 16, 1] subgroup_size = 16" + in modified + ) + assert "tile_sizes = [[8, 8, 8]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "8"}' in modified + + +def test_apply_params_conv(): + mlir_template = [ + ", subgroup_m_count = 16, subgroup_n_count = 16>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', + ] + + n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640 + + config = candidate_gen.Configuration( + subgroup_size=64, + workgroup_size=[256, 1, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), + tile_sizes=[464, 320, 16], + subgroup_m_count=1, + subgroup_n_count=4, + waves_per_eu=2, + ) + + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(oh * ow, oc, fh * fw * ic), + candidate_gen.ShapedType( + [n, oh + 2, ow + 2, oc], candidate_gen.ElementType.f16 + ), + candidate_gen.ShapedType([fh, fw, ic, oc], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([n, oh, ow, oc], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.conv, + ) + modified, embeddable = candidate_gen.apply_params_conv( + problem_size, mlir_template, config + ) + + assert modified + assert embeddable + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 1, subgroup_n_count = 4" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64" + in modified + ) + assert "tile_sizes = [[1, 1, 464, 320, 1, 1, 16]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified + + +def 
test_apply_params_contract(): + mlir_template = [ + ", subgroup_m_count = 2, subgroup_n_count = 2>}>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', + ] + + tile_dims = "*mnk" + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(2048, 3840, 1280), + candidate_gen.ShapedType([2, 1024, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([3, 20, 64, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([3, 2, 20, 1024, 64], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.contraction, + ) + + config = candidate_gen.Configuration( + subgroup_size=64, + workgroup_size=[256, 1, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), + tile_sizes=[480, 384, 32], + subgroup_m_count=1, + subgroup_n_count=4, + waves_per_eu=2, + ) + + new_mlir, _embeddable = candidate_gen.apply_params_contract( + problem_size, tile_dims, mlir_template, config + ) + + assert new_mlir + assert ( + "intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>, subgroup_m_count = 1, subgroup_n_count = 4" + in new_mlir + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64" + in new_mlir + ) + assert "tile_sizes = [[1, 480, 384, 32]]" in new_mlir + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in new_mlir + + +def test_apply_params_batch_matmul(): + mlir_template = [ + ", subgroup_m_count = 4, subgroup_n_count = 1>}>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', + ] + + tile_dims = "bmnk" + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(968, 320, 640, 64), + candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([64, 968, 320], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.batch_matmul, + ) + + config = candidate_gen.Configuration( + subgroup_size=64, + workgroup_size=[128, 2, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), + tile_sizes=[416, 320, 128], + subgroup_m_count=2, + subgroup_n_count=2, + waves_per_eu=2, + ) + + modified, embeddable = candidate_gen.apply_params_batch_matmul( + problem_size, tile_dims, mlir_template, config + ) + + assert modified + assert embeddable + assert ( + "intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>, subgroup_m_count = 2, subgroup_n_count = 2" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" + in modified + ) + assert "tile_sizes = [[1, 416, 320, 128]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified + + +def test_apply_params_batch_mmt_float(): + mlir_template = [ + ", subgroup_m_count = 4, subgroup_n_count = 1>}>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', + ] + + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(4096, 640, 640, 2), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.batch_mmt, + ) + + config = candidate_gen.Configuration( + subgroup_size=64, + workgroup_size=[128, 2, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), + tile_sizes=[128, 64, 128], + subgroup_m_count=2, + subgroup_n_count=2, + waves_per_eu=2, + ) + + modified, embeddable = candidate_gen.apply_params_batch_mmt( + problem_size, mlir_template, config + ) + + assert embeddable + assert 
modified + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" + in modified + ) + assert "tile_sizes = [[1, 128, 64, 128]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified + + +def test_apply_params_batch_mmt_int(): + mlir_template = [ + ", subgroup_m_count = 4, subgroup_n_count = 1>}>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', + ] + + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(4096, 640, 640, 2), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), + candidate_gen.DispatchKind.batch_mmt, + ) + + config = candidate_gen.Configuration( + subgroup_size=64, + workgroup_size=[128, 2, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(), + tile_sizes=[128, 64, 128], + subgroup_m_count=2, + subgroup_n_count=2, + waves_per_eu=4, + ) + + modified, embeddable = candidate_gen.apply_params_batch_mmt( + problem_size, mlir_template, config + ) + + assert modified + assert "// transform.named_sequence @match_batch_mmt_2x4096x640x640(" in modified + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" + in modified + ) + assert "tile_sizes = [[1, 128, 64, 128]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified + + assert embeddable + assert "transform.named_sequence @match_op(" in embeddable + assert ( + "transform.include @match_batch_mmt_i8_i8_i32 failures(propagate)" in embeddable + ) + assert ( + "transform.iree.match.cast_compatible_type %lhs = tensor<2x4096x640xi8> : !transform.any_value" + in embeddable + ) + assert ( + "transform.iree.match.cast_compatible_type %rhs = tensor<2x640x640xi8> : !transform.any_value" + in embeddable + ) + assert ( + "%config = transform.param.constant #iree_codegen.compilation_info<" + in embeddable + ) + assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable + assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable + assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable + + +def test_apply_params_broadcast_rhs_mmt(): + mlir_template = [ + ", subgroup_m_count = 4, subgroup_n_count = 1>}>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', + ] + + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(4096, 640, 640, 2), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([640, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), + candidate_gen.DispatchKind.broadcast_rhs_mmt, + ) + + config = candidate_gen.Configuration( + subgroup_size=64, + workgroup_size=[128, 2, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(), + tile_sizes=[128, 64, 128], + subgroup_m_count=2, + subgroup_n_count=2, + waves_per_eu=4, + ) + + modified, embeddable = candidate_gen.apply_params_broadcast_rhs_mmt( + problem_size, mlir_template, config + ) + + assert modified + assert ( + "// transform.named_sequence @match_broadcast_rhs_mmt_Bx4096x640x640(" + in modified + ) + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, 
subgroup_n_count = 2" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" + in modified + ) + assert "tile_sizes = [[1, 128, 64, 128]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified + + assert embeddable + assert "transform.named_sequence @match_op(" in embeddable + assert ( + "transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate)" + in embeddable + ) + assert ( + "transform.iree.match.cast_compatible_type %lhs = tensor : !transform.any_value" + in embeddable + ) + assert ( + "transform.iree.match.cast_compatible_type %rhs = tensor<640x640xi8> : !transform.any_value" + in embeddable + ) + assert ( + "%config = transform.param.constant #iree_codegen.compilation_info<" + in embeddable + ) + assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable + assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable + assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable + + +def test_detect_broadcast_rhs_mmt(): + mlir_lines = [ + r"%18 = tensor.empty() : tensor<2x1024x10240xi32>", + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>", + r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + ] + assert candidate_gen.is_broadcast_rhs_mmt(mlir_lines) + + +def test_parse_mlir(): + mlir_str = r""" + builtin.module { + func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + %0 = arith.mulf %arg0, %arg1 : tensor<4xf32> + return %0 : tensor<4xf32> + } + } + """ + mlir_module = candidate_gen.parse_mlir(mlir_str) + assert mlir_module != None + assert isinstance(mlir_module, candidate_gen.ireec._mlir_libs._mlir.ir.Module) + assert isinstance( + mlir_module.body.operations[0], candidate_gen.ireec.dialects.func.FuncOp + ) From 3931a99705d199fc3d47ad5a6f159a66572b6913 Mon Sep 17 00:00:00 2001 From: Amily Wu Date: Fri, 23 Aug 2024 11:15:29 -0500 Subject: [PATCH 06/23] Create separate ci tuner test --- .github/workflows/ci-tuner.yml | 49 +++++++++++++++++++ .../tools/tuner/requirements-dev.txt | 2 + .../tools/tuner/requirements-tuner.txt | 3 +- 3 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/ci-tuner.yml create mode 100644 sharktank/sharktank/tools/tuner/requirements-dev.txt rename tuner-requirements.txt => sharktank/sharktank/tools/tuner/requirements-tuner.txt (81%) diff --git a/.github/workflows/ci-tuner.yml b/.github/workflows/ci-tuner.yml new file mode 100644 index 000000000..9a735965f --- /dev/null +++ b/.github/workflows/ci-tuner.yml @@ -0,0 +1,49 @@ +name: CI - Tuner + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +permissions: + contents: read + +jobs: + pre-commit-and-test: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4.1.7 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10.12' + + - name: Install dev dependencies + run: | + python -m pip 
install --upgrade pip + pip install -r sharktank/sharktank/tools/tuner/requirements-dev.txt + + - name: Run pre-commit test + run: pre-commit run --all-files --show-diff-on-failure --color=always + + - name: Install tuner dependencies + run: | + pip install -r sharktank/sharktank/tools/tuner/requirements-tuner.txt + python -m pip install \ + --find-links https://iree.dev/pip-release-links.html \ + --upgrade \ + iree-compiler iree-runtime + + - name: Run pytest + working-directory: sharktank/sharktank/tests/tuner + run: | + python -m pytest diff --git a/sharktank/sharktank/tools/tuner/requirements-dev.txt b/sharktank/sharktank/tools/tuner/requirements-dev.txt new file mode 100644 index 000000000..51d5b9ba0 --- /dev/null +++ b/sharktank/sharktank/tools/tuner/requirements-dev.txt @@ -0,0 +1,2 @@ +pre-commit==3.8.0 +virtualenv==20.13.0 diff --git a/tuner-requirements.txt b/sharktank/sharktank/tools/tuner/requirements-tuner.txt similarity index 81% rename from tuner-requirements.txt rename to sharktank/sharktank/tools/tuner/requirements-tuner.txt index 26ef8f6c7..f3484c921 100644 --- a/tuner-requirements.txt +++ b/sharktank/sharktank/tools/tuner/requirements-tuner.txt @@ -1,3 +1,4 @@ +pytest==8.2.2 tqdm==4.66.4 -types-tqdm==4.66.0.20240417 z3_solver==4.13.0.0 +types-tqdm==4.66.0.20240417 From 3b92d26aeb12980607caa87befaed28ee112f106 Mon Sep 17 00:00:00 2001 From: Amily Wu Date: Fri, 23 Aug 2024 11:46:06 -0500 Subject: [PATCH 07/23] Edit code comments --- sharktank/sharktank/tools/tuner/candidate_gen.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/sharktank/sharktank/tools/tuner/candidate_gen.py b/sharktank/sharktank/tools/tuner/candidate_gen.py index 02f8eb9d9..33214425f 100755 --- a/sharktank/sharktank/tools/tuner/candidate_gen.py +++ b/sharktank/sharktank/tools/tuner/candidate_gen.py @@ -7,6 +7,17 @@ # Given an input dispatch, this code modifies the hyperparameters # in the code and runs it. +""" +Generate candidates by tweaking op configuration for tuning. + +It can be invoked in two ways: + 1. From another python script, import and call `tune()` + 2. 
Run this script directly from the command + +Usage: ./candidate_gen.py 121.mlir -o "tuning/candidates" -l 1024 --lhs-dims=mk --rhs-dims=nk --tile-dims=mnk + +""" + import argparse import logging import math @@ -23,9 +34,6 @@ from iree.compiler import ir from iree.compiler.dialects import _linalg_ops_gen, _util_ops_gen -""" -Usage: ./candidate_gen.py 121.mlir -o "tuning/candidates" -l 1024 --lhs-dims=mk --rhs-dims=nk --tile-dims=mnk -""" tune_logger = logging.getLogger("tune") From d8bf23c71dad61b181640bc311c4aef4b2e32834 Mon Sep 17 00:00:00 2001 From: Amily Wu Date: Fri, 23 Aug 2024 11:49:20 -0500 Subject: [PATCH 08/23] Fix dir path in ci --- .github/workflows/ci-tuner.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-tuner.yml b/.github/workflows/ci-tuner.yml index 9a735965f..d83ffce83 100644 --- a/.github/workflows/ci-tuner.yml +++ b/.github/workflows/ci-tuner.yml @@ -44,6 +44,6 @@ jobs: iree-compiler iree-runtime - name: Run pytest - working-directory: sharktank/sharktank/tests/tuner + working-directory: sharktank/tests/tuner run: | python -m pytest From 152aa53f3a67cf04a7506916faa718ca7cfa2f2b Mon Sep 17 00:00:00 2001 From: Amily Wu Date: Fri, 23 Aug 2024 11:58:14 -0500 Subject: [PATCH 09/23] Edit comments and fix import --- sharktank/tests/tuner/candidate_gen_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sharktank/tests/tuner/candidate_gen_test.py b/sharktank/tests/tuner/candidate_gen_test.py index ad9b97e0a..3b86c0bf8 100644 --- a/sharktank/tests/tuner/candidate_gen_test.py +++ b/sharktank/tests/tuner/candidate_gen_test.py @@ -4,13 +4,13 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import pytest -import candidate_gen - """ -Usage: python -m pytest test_tune.py +Usage: python -m pytest candidate_gen_test.py """ +import pytest +from sharktank.tools.tuner import candidate_gen + def test_get_shaped_type_element_bitwidth(): assert ( From 47d1c784c78c2043f51848b846fb201337a2d8d0 Mon Sep 17 00:00:00 2001 From: Amily Wu Date: Fri, 23 Aug 2024 17:35:57 -0500 Subject: [PATCH 10/23] Add dispatch kind registry --- .../sharktank/tools/tuner/candidate_gen.py | 202 ++++++++++++------ 1 file changed, 137 insertions(+), 65 deletions(-) diff --git a/sharktank/sharktank/tools/tuner/candidate_gen.py b/sharktank/sharktank/tools/tuner/candidate_gen.py index 33214425f..aa61fc152 100755 --- a/sharktank/sharktank/tools/tuner/candidate_gen.py +++ b/sharktank/sharktank/tools/tuner/candidate_gen.py @@ -27,8 +27,9 @@ from dataclasses import asdict, dataclass from enum import Enum from os import mkdir, path, makedirs -from typing import Callable +from typing import Callable, Optional from textwrap import indent +from abc import ABC, abstractmethod import iree.compiler as ireec from iree.compiler import ir @@ -1126,47 +1127,138 @@ def parse_mlir(mlir_text: str) -> ir.Module: return mlir_module -def walk_callback_detect_type( - op: ir.Operation, walk_result: OpWalkResult -) -> ir.WalkResult: - if op.name == "linalg.conv_2d_nhwc_hwcf": - walk_result.was_interrupted = True - walk_result.dispatch_kind = DispatchKind.conv - return ir.WalkResult.INTERRUPT +@dataclass +class CandidateGenFn: + get_shapes_fn: Optional[Callable[[list[str]], ProblemSize]] = None + apply_params_fn: Optional[ + Callable[[ProblemSize, list[str], Configuration], tuple[str, str]] + ] = None + + +class DispatchTuner(ABC): + @abstractmethod + def supports(self, mlir: str) -> bool: + pass + + 
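+    # Concrete tuners implement supports() as a substring test on the
+    # dispatch function name and pair it with the matching shape-extraction
+    # and parameter-application functions. A minimal sketch of the intended
+    # lookup (the dispatch name, mlir_template, and config below are
+    # placeholders, not values from a real dispatch):
+    #
+    #   registry = DispatchTunerRegistry()
+    #   registry.register([MmtTuner(), ConvTuner()])
+    #   fn = registry.get_candidate_gen_fn("dispatch_matmul_transpose_b")
+    #   problem_size = fn.get_shapes_fn(mlir_template)
+    #   modified, embeddable = fn.apply_params_fn(problem_size, mlir_template, config)
+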
@abstractmethod + def get_candidate_gen_fn(self) -> CandidateGenFn: + pass + + +class DispatchTunerRegistry: + def __init__(self): + self.registry = set() + + def register(self, dispatch_tuners: list[DispatchTuner]) -> None: + for dispatch_tuner in dispatch_tuners: + self.registry.add(dispatch_tuner) + + def get_candidate_gen_fn(self, mlir: str) -> CandidateGenFn: + for dispatch_tuner in self.registry: + if dispatch_tuner.supports(mlir): + return dispatch_tuner.get_candidate_gen_fn() + + assert False, "Not supported" + + +class MmtTuner(DispatchTuner): + def supports(self, mlir: str) -> bool: + return "matmul_transpose_b" in mlir + + def get_candidate_gen_fn(self) -> CandidateGenFn: + return CandidateGenFn(get_shapes_mmt, apply_params_mmt) + + +class ConvTuner(DispatchTuner): + def supports(self, mlir: str) -> bool: + return "conv_2d_nhwc_hwcf" in mlir + + def get_candidate_gen_fn(self) -> CandidateGenFn: + return CandidateGenFn(get_shapes_conv, apply_params_conv) + + +class ContractionTuner(DispatchTuner): + def __init__( + self, lhs_dims: str, rhs_dims: str, tile_dims: str, mlir_template: str + ): + self.lhs_dims = lhs_dims + self.rhs_dims = rhs_dims + self.tile_dims = tile_dims + self.mlir_template = mlir_template + + def supports(self, mlir: str) -> bool: + return "matmul_like" in mlir + + def get_candidate_gen_fn(self) -> CandidateGenFn: + if is_broadcast_rhs_mmt(self.mlir_template): + get_shapes_fn = get_shapes_broadcast_rhs_mmt + apply_params_fn = apply_params_broadcast_rhs_mmt + else: + get_shapes_fn = lambda template: get_shapes_contract( + template, self.lhs_dims, self.rhs_dims + ) + apply_params_fn = lambda ps, template, config: apply_params_contract( + ps, self.tile_dims, template, config + ) + return CandidateGenFn(get_shapes_fn, apply_params_fn) + + +class BatchMmtTuner(DispatchTuner): + def supports(self, mlir: str) -> bool: + return "batch_matmul_transpose_b" in mlir + + def get_candidate_gen_fn(self) -> CandidateGenFn: + return CandidateGenFn(get_shapes_batch_mmt, apply_params_batch_mmt) + +class BatchMatmulTuner(DispatchTuner): + def __init__(self, lhs_dims: str, rhs_dims: str): + self.lhs_dims = lhs_dims + self.rhs_dims = rhs_dims + + def supports(self, mlir: str) -> bool: + return "batch_matmul" in mlir + + def get_candidate_gen_fn(self) -> CandidateGenFn: + get_shapes_fn = lambda template: get_shapes_batch_matmul( + template, self.lhs_dims, self.rhs_dims + ) + apply_params_fn = lambda ps, template, config: apply_params_batch_matmul( + ps, self.tile_dims, template, config + ) + return CandidateGenFn(get_shapes_fn, apply_params_fn) + + +def walk_callback_get_fn( + op: ir.Operation, + candidate_gen_fn: CandidateGenFn, + dispatch_tuner_registry: DispatchTunerRegistry, +) -> ir.WalkResult: if op.name == "util.func": func_name = str(op.opview.sym_name) - if "batch_matmul_transpose_b" in func_name: - walk_result.was_interrupted = True - walk_result.dispatch_kind = DispatchKind.batch_mmt - return ir.WalkResult.INTERRUPT - if "batch_matmul" in func_name: - walk_result.was_interrupted = True - walk_result.dispatch_kind = DispatchKind.batch_matmul - return ir.WalkResult.INTERRUPT - if "matmul_transpose_b" in func_name: - walk_result.was_interrupted = True - walk_result.dispatch_kind = DispatchKind.mmt - return ir.WalkResult.INTERRUPT - if "matmul_like" in func_name: - walk_result.was_interrupted = True - walk_result.dispatch_kind = DispatchKind.contraction + searched_fn = dispatch_tuner_registry.get_candidate_gen_fn(func_name) + candidate_gen_fn.get_shapes_fn = 
searched_fn.get_shapes_fn + candidate_gen_fn.apply_params_fn = searched_fn.apply_params_fn + if candidate_gen_fn.apply_params_fn and candidate_gen_fn.get_shapes_fn: return ir.WalkResult.INTERRUPT return ir.WalkResult.ADVANCE -def walk_mlir_op(mlir_module: ir.Module) -> OpWalkResult: - walk_result = OpWalkResult() +def walk_mlir_op( + mlir_module: ir.Module, + candidate_gen_fn: CandidateGenFn, + dispatch_tuner_registry: DispatchTunerRegistry, +): for op in mlir_module.body.operations: op.walk( - lambda op: walk_callback_detect_type(op, walk_result), + lambda op: walk_callback_get_fn( + op, candidate_gen_fn, dispatch_tuner_registry + ), ir.WalkOrder.POST_ORDER, ) - if walk_result.was_interrupted: + if candidate_gen_fn.apply_params_fn and candidate_gen_fn.get_shapes_fn: break - return walk_result - def tune( input: str, @@ -1191,46 +1283,26 @@ def tune( mlir_text = "".join(mlir_template) mlir_module = parse_mlir(mlir_text) - walk_result = walk_mlir_op(mlir_module) - assert walk_result.dispatch_kind != None - # Save the input file as the first candidate. with open(path.join(output, f"0.mlir"), "w") as f: f.write(mlir_text) - get_shapes_fn: Callable[[list[str]], ProblemSize] | None = None - apply_params_fn: ( - Callable[[ProblemSize, list[str], Configuration], tuple[str, str]] | None - ) = None - if walk_result.dispatch_kind == DispatchKind.conv: - get_shapes_fn = get_shapes_conv - apply_params_fn = apply_params_conv - elif walk_result.dispatch_kind == DispatchKind.mmt: - get_shapes_fn = get_shapes_mmt - apply_params_fn = apply_params_mmt - elif walk_result.dispatch_kind == DispatchKind.contraction: - if is_broadcast_rhs_mmt(mlir_template): - get_shapes_fn = get_shapes_broadcast_rhs_mmt - apply_params_fn = apply_params_broadcast_rhs_mmt - else: - get_shapes_fn = lambda template: get_shapes_contract( - template, lhs_dims, rhs_dims - ) - apply_params_fn = lambda ps, template, config: apply_params_contract( - ps, tile_dims, template, config - ) - elif walk_result.dispatch_kind == DispatchKind.batch_matmul: - get_shapes_fn = lambda template: get_shapes_batch_matmul( - template, lhs_dims, rhs_dims - ) - apply_params_fn = lambda ps, template, config: apply_params_batch_matmul( - ps, tile_dims, template, config - ) - elif walk_result.dispatch_kind == DispatchKind.batch_mmt: - get_shapes_fn = get_shapes_batch_mmt - apply_params_fn = apply_params_batch_mmt - else: - assert False, f"Unhandled dispatch kind: {walk_result.dispatch_kind}" + candidate_gen_fn = CandidateGenFn() + dispatch_tuner_registry = DispatchTunerRegistry() + dispatch_tuner_registry.register( + [ + MmtTuner(), + ConvTuner(), + ContractionTuner(lhs_dims, rhs_dims, tile_dims, mlir_template), + BatchMmtTuner(), + BatchMatmulTuner(lhs_dims, rhs_dims), + ] + ) + + walk_mlir_op(mlir_module, candidate_gen_fn, dispatch_tuner_registry) + + get_shapes_fn = candidate_gen_fn.get_shapes_fn + apply_params_fn = candidate_gen_fn.apply_params_fn problem_size = get_shapes_fn(mlir_template) tune_logger.debug(str(problem_size)) From e498173a48fda3440231af72e2a2fb41cd314cf3 Mon Sep 17 00:00:00 2001 From: Amily Wu Date: Fri, 23 Aug 2024 17:40:41 -0500 Subject: [PATCH 11/23] Remove unused class --- sharktank/sharktank/tools/tuner/candidate_gen.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/sharktank/sharktank/tools/tuner/candidate_gen.py b/sharktank/sharktank/tools/tuner/candidate_gen.py index aa61fc152..a17eb0628 100755 --- a/sharktank/sharktank/tools/tuner/candidate_gen.py +++ b/sharktank/sharktank/tools/tuner/candidate_gen.py @@ -48,12 +48,6 @@ 
class DispatchKind(Enum): broadcast_rhs_mmt = 6 -@dataclass -class OpWalkResult: - was_interrupted: bool = False - dispatch_kind: DispatchKind | None = None - - class ElementType(Enum): i8 = 1 i32 = 2 From 339bfb379a9b14843197f4754575fe335addc296 Mon Sep 17 00:00:00 2001 From: Amily Wu Date: Fri, 23 Aug 2024 17:53:50 -0500 Subject: [PATCH 12/23] Edit ci-tuner --- .github/workflows/ci-tuner.yml | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/.github/workflows/ci-tuner.yml b/.github/workflows/ci-tuner.yml index d83ffce83..b6552bfd5 100644 --- a/.github/workflows/ci-tuner.yml +++ b/.github/workflows/ci-tuner.yml @@ -15,7 +15,7 @@ permissions: contents: read jobs: - pre-commit-and-test: + test: runs-on: ubuntu-latest steps: @@ -32,9 +32,6 @@ jobs: python -m pip install --upgrade pip pip install -r sharktank/sharktank/tools/tuner/requirements-dev.txt - - name: Run pre-commit test - run: pre-commit run --all-files --show-diff-on-failure --color=always - - name: Install tuner dependencies run: | pip install -r sharktank/sharktank/tools/tuner/requirements-tuner.txt @@ -43,7 +40,5 @@ jobs: --upgrade \ iree-compiler iree-runtime - - name: Run pytest - working-directory: sharktank/tests/tuner - run: | - python -m pytest + - name: Run tuner tests + run: pytest sharktank/tests/tuner/ From 6d6a39fbe1a7192d6081f9d248fb8f9a50f700ee Mon Sep 17 00:00:00 2001 From: Amily Wu Date: Mon, 26 Aug 2024 15:12:00 -0500 Subject: [PATCH 13/23] Move get_shapes_*() and apply_params_*() to DispatchTuner --- .../sharktank/tools/tuner/candidate_gen.py | 1009 +++++++++-------- 1 file changed, 509 insertions(+), 500 deletions(-) diff --git a/sharktank/sharktank/tools/tuner/candidate_gen.py b/sharktank/sharktank/tools/tuner/candidate_gen.py index a17eb0628..ac05983cc 100755 --- a/sharktank/sharktank/tools/tuner/candidate_gen.py +++ b/sharktank/sharktank/tools/tuner/candidate_gen.py @@ -455,137 +455,6 @@ def apply_configuration( return new_mlir -def apply_params_mmt( - problem_size: ProblemSize, template: list[str], configuration: Configuration -) -> tuple[str, str]: - M, N, K = problem_size.MNK - modified = indent( - get_transform_function_mmt( - problem_size, f"match_mmt_{M}x{N}x{K}", configuration - ), - "// ", - ) - modified += apply_configuration( - template, configuration, get_mmt_tile_sizes(configuration) - ) - embeddable = indent( - get_transform_function_mmt(problem_size, f"match_op", configuration), " " - ) - return modified, embeddable - - -def apply_params_conv( - problem_size: ProblemSize, template: list[str], configuration: Configuration -) -> tuple[str, str]: - conv_dims = ConvDimInfo.from_problem_size(problem_size) - modified = indent( - get_transform_function_conv( - problem_size, - f"match_conv_2d_nhwc_hwcf_Bx{conv_dims.oh}x{conv_dims.ow}x{conv_dims.oc}x{conv_dims.fh}x{conv_dims.fw}x{conv_dims.ic}", - configuration, - ), - "// ", - ) - modified += apply_configuration( - template, configuration, get_conv_tile_sizes(configuration) - ) - embeddable = indent( - get_transform_function_conv(problem_size, f"match_op", configuration), - " ", - ) - return modified, embeddable - - -def apply_params_contract( - problem_size: ProblemSize, - tile_dims: str, - template: list[str], - configuration: Configuration, -) -> tuple[str, str]: - # TODO: Generate transform function. 
- return ( - apply_configuration( - template, configuration, get_contract_tile_sizes(configuration, tile_dims) - ), - "", - ) - - -def apply_params_batch_matmul( - problem_size: ProblemSize, - tile_dims: str, - template: list[str], - configuration: Configuration, -) -> tuple[str, str]: - tune_logger.info(f"{configuration}") - M, N, K = problem_size.MNK - modified = indent( - get_transform_function_batch_matmul( - problem_size, - tile_dims, - f"match_batch_matmul_{problem_size.matmul_size.B}x{M}x{N}x{K}", - configuration, - ), - "// ", - ) - modified += apply_configuration( - template, configuration, get_contract_tile_sizes(configuration, tile_dims) - ) - - embeddable = indent( - get_transform_function_batch_matmul( - problem_size, tile_dims, f"match_op", configuration - ), - " ", - ) - return modified, embeddable - - -def apply_params_batch_mmt( - problem_size: ProblemSize, template: list[str], configuration: Configuration -) -> tuple[str, str]: - M, N, K = problem_size.MNK - B = problem_size.matmul_size.B - modified = indent( - get_transform_function_batch_mmt( - problem_size, f"match_batch_mmt_{B}x{M}x{N}x{K}", configuration - ), - "// ", - ) - modified += apply_configuration( - template, configuration, get_batch_mmt_tile_sizes(configuration) - ) - - embeddable = indent( - get_transform_function_batch_mmt(problem_size, f"match_op", configuration), - " ", - ) - return modified, embeddable - - -def apply_params_broadcast_rhs_mmt( - problem_size: ProblemSize, template: list[str], configuration: Configuration -) -> tuple[str, str]: - M, N, K = problem_size.MNK - modified = indent( - get_transform_function_broadcast_rhs_mmt( - problem_size, f"match_broadcast_rhs_mmt_Bx{M}x{N}x{K}", configuration - ), - "// ", - ) - modified += apply_configuration( - template, configuration, get_batch_mmt_tile_sizes(configuration) - ) - - embeddable = indent( - get_transform_function_broadcast_rhs_mmt( - problem_size, f"match_op", configuration - ), - " ", - ) - return modified, embeddable - - def parse_tensor_type(tensor_type: str) -> ShapedType: shape_match = re.search(MlirRegex.tensor_type, tensor_type) assert shape_match @@ -598,254 +467,6 @@ def parse_tensor_type(tensor_type: str) -> ShapedType: return ShapedType(dims, str_to_elem_ty[elem]) -def get_shapes_mmt(template: list[str]) -> ProblemSize: - for line in template: - if "linalg.generic" not in line: - continue - if r'iterator_types = ["parallel", "parallel", "reduction"]' not in line: - continue - # ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) - mmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - dps = re.search(mmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 2 - lhs_M, lhs_K = lhs_shaped_type.shape - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 2 - rhs_N, rhs_K = rhs_shaped_type.shape - - assert lhs_shaped_type.element_type == rhs_shaped_type.element_type - assert lhs_K == rhs_K - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 2 - res_M, res_N = res_shaped_type.shape - - assert lhs_M == res_M - assert rhs_N == res_N - - matmul_size = MatmulSize( - lhs_shaped_type.shape[0], rhs_shaped_type.shape[0], lhs_shaped_type.shape[1] - ) - return ProblemSize( - matmul_size, - lhs_type=lhs_shaped_type, - 
rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.mmt, - ) - - assert False, "Shape not found" - - -def get_shapes_conv(template: list[str]) -> ProblemSize: - for line in template: - if "linalg.conv_2d_nhwc_hwcf" not in line: - continue - - # ins(%19, %20 : tensor<2x34x34x1280xf16>, tensor<3x3x1280x1280xf16>) outs (%27 : tensor<2x32x32x1280xf32>) - conv_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - dps = re.search(conv_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 4 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 4 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 4 - - # int64_t n = outputShape[0]; - # int64_t oh = outputShape[1]; - # int64_t ow = outputShape[2]; - # int64_t oc = outputShape[3]; - # int64_t fh = filterShape[0]; - # int64_t fw = filterShape[1]; - # int64_t ic = filterShape[2]; - dim_info = ConvDimInfo.from_rhs_res(rhs_shaped_type, res_shaped_type) - return ProblemSize( - MatmulSize( - M=dim_info.oh * dim_info.ow, - N=dim_info.oc, - K=dim_info.fh * dim_info.fw * dim_info.ic, - B=dim_info.n, - ), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.conv, - ) - - assert False, "Shape not found" - - -def get_shapes_contract( - template: list[str], lhs_dims: str, rhs_dims: str -) -> ProblemSize: - for line in template: - if "linalg.generic" not in line: - continue - if "lowering_config =" not in line: - continue - if '"reduction"' not in line: - continue - - # ins(%7, %8 : tensor<2x1024x1280xf16>, tensor<20x64x1280xf16>) - cont_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - dps = re.search(cont_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == len(lhs_dims) - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == len(rhs_dims) - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() >= 2 - - M = math.prod( - val if dim == "m" else 1 - for dim, val in zip(lhs_dims, lhs_shaped_type.shape) - ) - N = math.prod( - val if dim == "n" else 1 - for dim, val in zip(rhs_dims, rhs_shaped_type.shape) - ) - K0 = math.prod( - val if dim == "k" else 1 - for dim, val in zip(lhs_dims, lhs_shaped_type.shape) - ) - K1 = math.prod( - val if dim == "k" else 1 - for dim, val in zip(rhs_dims, rhs_shaped_type.shape) - ) - assert K0 == K1 - - return ProblemSize( - MatmulSize(M, N, K0), - lhs_type=lhs_shaped_type, - rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.contraction, - ) - - assert False, "Shape not found" - - -def get_shapes_batch_matmul( - template: list[str], lhs_dims: str, rhs_dims: str -) -> ProblemSize: - for line in template: - if "linalg.batch_matmul" not in line: - continue - # ins(%9, %10 : tensor<64x72x1280xf16>, tensor<64x1280x1280xf16>) - # outs(%12 : tensor<64x72x1280xf32>) - cont_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - dps = re.search(cont_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = 
parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == len(lhs_dims) - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == len(rhs_dims) - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == lhs_shaped_type.rank() - - LHS = lhs_shaped_type.shape - RHS = rhs_shaped_type.shape - RES = res_shaped_type.shape - - B = math.prod(val if dim == "b" else 1 for dim, val in zip(lhs_dims, LHS)) - B0 = math.prod(val if dim == "b" else 1 for dim, val in zip(lhs_dims, RHS)) - B1 = math.prod(val if dim == "b" else 1 for dim, val in zip(lhs_dims, RES)) - M = math.prod(val if dim == "m" else 1 for dim, val in zip(lhs_dims, LHS)) - N = math.prod(val if dim == "n" else 1 for dim, val in zip(rhs_dims, RHS)) - K0 = math.prod(val if dim == "k" else 1 for dim, val in zip(lhs_dims, LHS)) - K1 = math.prod(val if dim == "k" else 1 for dim, val in zip(rhs_dims, RHS)) - assert B == B0 and B == B1 - assert K0 == K1 - - return ProblemSize( - MatmulSize(M, N, K0, B), - lhs_type=lhs_shaped_type, - rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.batch_matmul, - ) - - assert False, "Shape not found" - - -def get_shapes_batch_mmt(template: list[str]) -> ProblemSize: - for line in template: - if "linalg.generic" not in line: - continue - if ( - r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' - not in line - ): - continue - # ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) - bmmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - dps = re.search(bmmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 3 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 3 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 3 - - B0, M0, K0 = lhs_shaped_type.shape - B1, N1, K1 = rhs_shaped_type.shape - B2, M2, N2 = res_shaped_type.shape - assert B0 == B1 - assert B0 == B2 - assert M0 == M2 - assert N1 == N2 - assert K0 == K1 - return ProblemSize( - MatmulSize(M0, N1, K0, B0), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.batch_mmt, - ) - - assert False, "Shape not found" - - def is_broadcast_rhs_mmt_op(line: str) -> bool: if "linalg.generic" not in line: return False @@ -862,51 +483,6 @@ def is_broadcast_rhs_mmt_op(line: str) -> bool: return True -def is_broadcast_rhs_mmt(template: list[str]) -> bool: - return any(is_broadcast_rhs_mmt_op(line) for line in template) - - -def get_shapes_broadcast_rhs_mmt(template: list[str]) -> ProblemSize: - for line in template: - if not is_broadcast_rhs_mmt_op(line): - continue - - # ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) - bmmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - dps = re.search(bmmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 3 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 2 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) 
- assert res_shaped_type.rank() == 3 - - B0, M0, K0 = lhs_shaped_type.shape - N1, K1 = rhs_shaped_type.shape - B2, M2, N2 = res_shaped_type.shape - assert B0 == B2 - assert M0 == M2 - assert N1 == N2 - assert K0 == K1 - return ProblemSize( - MatmulSize(M0, N1, K0, B0), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.broadcast_rhs_mmt, - ) - - assert False, "Shape not found" - - def get_compatible_mfma_intrinsics(problem_size: ProblemSize) -> list[MfmaIntrinsic]: def is_compatible(intrinsic: MfmaIntrinsic) -> bool: if problem_size.res_type.element_type != intrinsic.output_type: @@ -1122,22 +698,41 @@ def parse_mlir(mlir_text: str) -> ir.Module: @dataclass -class CandidateGenFn: - get_shapes_fn: Optional[Callable[[list[str]], ProblemSize]] = None - apply_params_fn: Optional[ - Callable[[ProblemSize, list[str], Configuration], tuple[str, str]] - ] = None +class TFMLIR: + """Transformation of MLIR context""" + + template: str + modified: str + embeddable: str class DispatchTuner(ABC): @abstractmethod - def supports(self, mlir: str) -> bool: + def supports(self, op_name: str) -> bool: + """Check if the tuner can handle the type of operation represented by the input string.""" pass @abstractmethod - def get_candidate_gen_fn(self) -> CandidateGenFn: + def get_shapes(self, template: list[str]) -> ProblemSize: + """Extract problem size of thge operation.""" pass + @abstractmethod + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> TFMLIR: + """Apply parameter transformations to the operation.""" + pass + + +@dataclass +class OpWalkResult: + was_interrupted: bool = False + dispatch_tuner: Optional[DispatchTuner] = None + class DispatchTunerRegistry: def __init__(self): @@ -1147,111 +742,531 @@ def register(self, dispatch_tuners: list[DispatchTuner]) -> None: for dispatch_tuner in dispatch_tuners: self.registry.add(dispatch_tuner) - def get_candidate_gen_fn(self, mlir: str) -> CandidateGenFn: + def find_handler(self, op_name: str) -> DispatchTuner: for dispatch_tuner in self.registry: - if dispatch_tuner.supports(mlir): - return dispatch_tuner.get_candidate_gen_fn() - + if dispatch_tuner.supports(op_name): + return dispatch_tuner assert False, "Not supported" class MmtTuner(DispatchTuner): - def supports(self, mlir: str) -> bool: - return "matmul_transpose_b" in mlir + def supports(self, op_name: str) -> bool: + return "matmul_transpose_b" in op_name + + def get_shapes(self, template: list[str]) -> ProblemSize: + for line in template: + if "linalg.generic" not in line: + continue + if r'iterator_types = ["parallel", "parallel", "reduction"]' not in line: + continue + # ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) + mmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + dps = re.search(mmt_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 2 + lhs_M, lhs_K = lhs_shaped_type.shape + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 2 + rhs_N, rhs_K = rhs_shaped_type.shape + + assert lhs_shaped_type.element_type == rhs_shaped_type.element_type + assert lhs_K == rhs_K + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 2 + res_M, res_N = res_shaped_type.shape + + assert lhs_M 
== res_M + assert rhs_N == res_N + + matmul_size = MatmulSize( + lhs_shaped_type.shape[0], + rhs_shaped_type.shape[0], + lhs_shaped_type.shape[1], + ) + return ProblemSize( + matmul_size, + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.mmt, + ) - def get_candidate_gen_fn(self) -> CandidateGenFn: - return CandidateGenFn(get_shapes_mmt, apply_params_mmt) + assert False, "Shape not found" + + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> TFMLIR: + M, N, K = problem_size.MNK + modified = indent( + get_transform_function_mmt( + problem_size, f"match_mmt_{M}x{N}x{K}", configuration + ), + "// ", + ) + modified += apply_configuration( + template, configuration, get_mmt_tile_sizes(configuration) + ) + embeddable = indent( + get_transform_function_mmt(problem_size, f"match_op", configuration), " " + ) + return TFMLIR(template, modified, embeddable) class ConvTuner(DispatchTuner): - def supports(self, mlir: str) -> bool: - return "conv_2d_nhwc_hwcf" in mlir + def supports(self, op_name: str) -> bool: + return "conv_2d_nhwc_hwcf" in op_name + + def get_shapes(self, template: list[str]) -> ProblemSize: + for line in template: + if "linalg.conv_2d_nhwc_hwcf" not in line: + continue + + # ins(%19, %20 : tensor<2x34x34x1280xf16>, tensor<3x3x1280x1280xf16>) outs (%27 : tensor<2x32x32x1280xf32>) + conv_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(conv_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 4 + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 4 + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 4 + + # int64_t n = outputShape[0]; + # int64_t oh = outputShape[1]; + # int64_t ow = outputShape[2]; + # int64_t oc = outputShape[3]; + # int64_t fh = filterShape[0]; + # int64_t fw = filterShape[1]; + # int64_t ic = filterShape[2]; + dim_info = ConvDimInfo.from_rhs_res(rhs_shaped_type, res_shaped_type) + return ProblemSize( + MatmulSize( + M=dim_info.oh * dim_info.ow, + N=dim_info.oc, + K=dim_info.fh * dim_info.fw * dim_info.ic, + B=dim_info.n, + ), + lhs_shaped_type, + rhs_shaped_type, + res_shaped_type, + DispatchKind.conv, + ) - def get_candidate_gen_fn(self) -> CandidateGenFn: - return CandidateGenFn(get_shapes_conv, apply_params_conv) + assert False, "Shape not found" + + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> TFMLIR: + conv_dims = ConvDimInfo.from_problem_size(problem_size) + modified = indent( + get_transform_function_conv( + problem_size, + f"match_conv_2d_nhwc_hwcf_Bx{conv_dims.oh}x{conv_dims.ow}x{conv_dims.oc}x{conv_dims.fh}x{conv_dims.fw}x{conv_dims.ic}", + configuration, + ), + "// ", + ) + modified += apply_configuration( + template, configuration, get_conv_tile_sizes(configuration) + ) + embeddable = indent( + get_transform_function_conv(problem_size, f"match_op", configuration), + " ", + ) + return TFMLIR(template, modified, embeddable) class ContractionTuner(DispatchTuner): - def __init__( - self, lhs_dims: str, rhs_dims: str, tile_dims: str, mlir_template: str - ): + def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): self.lhs_dims = lhs_dims 
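+        # The dim strings label each tensor dimension with its role ("m",
+        # "n", "k", or any other character for batch/broadcast dims),
+        # matching the CLI flags in the usage string at the top of this file
+        # (--lhs-dims=mk --rhs-dims=nk --tile-dims=mnk); get_shapes() below
+        # multiplies out all dims that share a label to recover the
+        # flattened M, N, and K.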
self.rhs_dims = rhs_dims self.tile_dims = tile_dims - self.mlir_template = mlir_template - - def supports(self, mlir: str) -> bool: - return "matmul_like" in mlir - - def get_candidate_gen_fn(self) -> CandidateGenFn: - if is_broadcast_rhs_mmt(self.mlir_template): - get_shapes_fn = get_shapes_broadcast_rhs_mmt - apply_params_fn = apply_params_broadcast_rhs_mmt - else: - get_shapes_fn = lambda template: get_shapes_contract( - template, self.lhs_dims, self.rhs_dims + + def supports(self, op_name: str) -> bool: + return "matmul_like" in op_name + + def is_broadcast_rhs_mmt(self, template: list[str]) -> bool: + return any(is_broadcast_rhs_mmt_op(line) for line in template) + + def get_shapes_broadcast_rhs_mmt(self, template: list[str]) -> ProblemSize: + for line in template: + if not is_broadcast_rhs_mmt_op(line): + continue + + # ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) + bmmt_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(bmmt_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 3 + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 2 + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 3 + + B0, M0, K0 = lhs_shaped_type.shape + N1, K1 = rhs_shaped_type.shape + B2, M2, N2 = res_shaped_type.shape + assert B0 == B2 + assert M0 == M2 + assert N1 == N2 + assert K0 == K1 + return ProblemSize( + MatmulSize(M0, N1, K0, B0), + lhs_shaped_type, + rhs_shaped_type, + res_shaped_type, + DispatchKind.broadcast_rhs_mmt, + ) + + assert False, "Shape not found" + + def get_shapes(self, template: list[str]) -> ProblemSize: + if self.is_broadcast_rhs_mmt(template): + return self.get_shapes_broadcast_rhs_mmt(template) + + for line in template: + if "linalg.generic" not in line: + continue + if "lowering_config =" not in line: + continue + if '"reduction"' not in line: + continue + + # ins(%7, %8 : tensor<2x1024x1280xf16>, tensor<20x64x1280xf16>) + cont_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" ) - apply_params_fn = lambda ps, template, config: apply_params_contract( - ps, self.tile_dims, template, config + dps = re.search(cont_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == len(self.lhs_dims) + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == len(self.rhs_dims) + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() >= 2 + + M = math.prod( + val if dim == "m" else 1 + for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) + ) + N = math.prod( + val if dim == "n" else 1 + for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape) + ) + K0 = math.prod( + val if dim == "k" else 1 + for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) + ) + K1 = math.prod( + val if dim == "k" else 1 + for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape) + ) + assert K0 == K1 + + return ProblemSize( + MatmulSize(M, N, K0), + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.contraction, + ) + + 
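+        # Fall through: no linalg.generic with a "reduction" iterator and a
+        # lowering_config matched the ins/outs pattern above, so the problem
+        # shapes cannot be recovered from this template.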
assert False, "Shape not found" + + def apply_params_broadcast_rhs_mmt( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> tuple[str, str]: + M, N, K = problem_size.MNK + modified = indent( + get_transform_function_broadcast_rhs_mmt( + problem_size, f"match_broadcast_rhs_mmt_Bx{M}x{N}x{K}", configuration + ), + "// ", + ) + modified += apply_configuration( + template, configuration, get_batch_mmt_tile_sizes(configuration) + ) + + embeddable = indent( + get_transform_function_broadcast_rhs_mmt( + problem_size, f"match_op", configuration + ), + " ", + ) + return TFMLIR(template, modified, embeddable) + + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> TFMLIR: + if self.is_broadcast_rhs_mmt(template): + return self.apply_params_broadcast_rhs_mmt( + problem_size, template, configuration ) - return CandidateGenFn(get_shapes_fn, apply_params_fn) + + # TODO: Generate transform function. + return TFMLIR( + template, + apply_configuration( + template, + configuration, + get_contract_tile_sizes(configuration, self.tile_dims), + ), + "", + ) class BatchMmtTuner(DispatchTuner): - def supports(self, mlir: str) -> bool: - return "batch_matmul_transpose_b" in mlir + def supports(self, op_name: str) -> bool: + return "batch_matmul_transpose_b" in op_name + + def get_shapes(self, template: list[str]) -> ProblemSize: + for line in template: + if "linalg.generic" not in line: + continue + if ( + r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' + not in line + ): + continue + # ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) + bmmt_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(bmmt_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 3 + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 3 + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 3 + + B0, M0, K0 = lhs_shaped_type.shape + B1, N1, K1 = rhs_shaped_type.shape + B2, M2, N2 = res_shaped_type.shape + assert B0 == B1 + assert B0 == B2 + assert M0 == M2 + assert N1 == N2 + assert K0 == K1 + return ProblemSize( + MatmulSize(M0, N1, K0, B0), + lhs_shaped_type, + rhs_shaped_type, + res_shaped_type, + DispatchKind.batch_mmt, + ) - def get_candidate_gen_fn(self) -> CandidateGenFn: - return CandidateGenFn(get_shapes_batch_mmt, apply_params_batch_mmt) + assert False, "Shape not found" + + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> TFMLIR: + M, N, K = problem_size.MNK + B = problem_size.matmul_size.B + modified = indent( + get_transform_function_batch_mmt( + problem_size, f"match_batch_mmt_{B}x{M}x{N}x{K}", configuration + ), + "// ", + ) + modified += apply_configuration( + template, configuration, get_batch_mmt_tile_sizes(configuration) + ) + + embeddable = indent( + get_transform_function_batch_mmt(problem_size, f"match_op", configuration), + " ", + ) + return TFMLIR(template, modified, embeddable) class BatchMatmulTuner(DispatchTuner): - def __init__(self, lhs_dims: str, rhs_dims: str): + def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): self.lhs_dims = lhs_dims self.rhs_dims 
= rhs_dims + self.tile_dims = tile_dims + + def supports(self, op_name: str) -> bool: + return "batch_matmul" in op_name + + def get_shapes(self, template: list[str]) -> ProblemSize: + for line in template: + if "linalg.batch_matmul" not in line: + continue + # ins(%9, %10 : tensor<64x72x1280xf16>, tensor<64x1280x1280xf16>) + # outs(%12 : tensor<64x72x1280xf32>) + cont_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(cont_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == len(self.lhs_dims) + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == len(self.rhs_dims) - def supports(self, mlir: str) -> bool: - return "batch_matmul" in mlir + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == lhs_shaped_type.rank() - def get_candidate_gen_fn(self) -> CandidateGenFn: - get_shapes_fn = lambda template: get_shapes_batch_matmul( - template, self.lhs_dims, self.rhs_dims + LHS = lhs_shaped_type.shape + RHS = rhs_shaped_type.shape + RES = res_shaped_type.shape + + B = math.prod( + val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, LHS) + ) + B0 = math.prod( + val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RHS) + ) + B1 = math.prod( + val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RES) + ) + M = math.prod( + val if dim == "m" else 1 for dim, val in zip(self.lhs_dims, LHS) + ) + N = math.prod( + val if dim == "n" else 1 for dim, val in zip(self.rhs_dims, RHS) + ) + K0 = math.prod( + val if dim == "k" else 1 for dim, val in zip(self.lhs_dims, LHS) + ) + K1 = math.prod( + val if dim == "k" else 1 for dim, val in zip(self.rhs_dims, RHS) + ) + assert B == B0 and B == B1 + assert K0 == K1 + + return ProblemSize( + MatmulSize(M, N, K0, B), + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.batch_matmul, + ) + + assert False, "Shape not found" + + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> TFMLIR: + M, N, K = problem_size.MNK + modified = indent( + get_transform_function_batch_matmul( + problem_size, + self.tile_dims, + f"match_batch_matmul_{problem_size.matmul_size.B}x{M}x{N}x{K}", + configuration, + ), + "// ", + ) + modified += apply_configuration( + template, + configuration, + get_contract_tile_sizes(configuration, self.tile_dims), ) - apply_params_fn = lambda ps, template, config: apply_params_batch_matmul( - ps, self.tile_dims, template, config + + embeddable = indent( + get_transform_function_batch_matmul( + problem_size, self.tile_dims, f"match_op", configuration + ), + " ", ) - return CandidateGenFn(get_shapes_fn, apply_params_fn) + return TFMLIR(template, modified, embeddable) def walk_callback_get_fn( op: ir.Operation, - candidate_gen_fn: CandidateGenFn, + walk_result: OpWalkResult, dispatch_tuner_registry: DispatchTunerRegistry, ) -> ir.WalkResult: if op.name == "util.func": func_name = str(op.opview.sym_name) - searched_fn = dispatch_tuner_registry.get_candidate_gen_fn(func_name) - candidate_gen_fn.get_shapes_fn = searched_fn.get_shapes_fn - candidate_gen_fn.apply_params_fn = searched_fn.apply_params_fn - if candidate_gen_fn.apply_params_fn and candidate_gen_fn.get_shapes_fn: - return ir.WalkResult.INTERRUPT + 
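+        # The dispatch kind is keyed off the util.func symbol name (e.g. a
+        # name containing "conv_2d_nhwc_hwcf" selects ConvTuner);
+        # find_handler() asserts if no registered tuner claims the name.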
walk_result.was_interrupted = True + walk_result.dispatch_tuner = dispatch_tuner_registry.find_handler(func_name) + return ir.WalkResult.INTERRUPT return ir.WalkResult.ADVANCE def walk_mlir_op( mlir_module: ir.Module, - candidate_gen_fn: CandidateGenFn, dispatch_tuner_registry: DispatchTunerRegistry, -): +) -> OpWalkResult: + walk_result = OpWalkResult() for op in mlir_module.body.operations: op.walk( - lambda op: walk_callback_get_fn( - op, candidate_gen_fn, dispatch_tuner_registry - ), + lambda op: walk_callback_get_fn(op, walk_result, dispatch_tuner_registry), ir.WalkOrder.POST_ORDER, ) - if candidate_gen_fn.apply_params_fn and candidate_gen_fn.get_shapes_fn: + if walk_result.was_interrupted: break + return walk_result def tune( @@ -1281,40 +1296,34 @@ def tune( with open(path.join(output, f"0.mlir"), "w") as f: f.write(mlir_text) - candidate_gen_fn = CandidateGenFn() dispatch_tuner_registry = DispatchTunerRegistry() dispatch_tuner_registry.register( [ MmtTuner(), ConvTuner(), - ContractionTuner(lhs_dims, rhs_dims, tile_dims, mlir_template), + ContractionTuner(lhs_dims, rhs_dims, tile_dims), BatchMmtTuner(), - BatchMatmulTuner(lhs_dims, rhs_dims), + BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims), ] ) - walk_mlir_op(mlir_module, candidate_gen_fn, dispatch_tuner_registry) - - get_shapes_fn = candidate_gen_fn.get_shapes_fn - apply_params_fn = candidate_gen_fn.apply_params_fn + walk_result = walk_mlir_op(mlir_module, dispatch_tuner_registry) - problem_size = get_shapes_fn(mlir_template) + dispatch_tuner = walk_result.dispatch_tuner + problem_size = dispatch_tuner.get_shapes(mlir_template) tune_logger.debug(str(problem_size)) - configs = [] for i, config in enumerate(generate_solutions(problem_size, num_subgroups)): if i >= limit: break tune_logger.info(f"Solution #{i+1}: {config}") configs.append(config) - new_mlir, embeddable_tuning = apply_params_fn( - problem_size, mlir_template, config - ) + tf_mlir = dispatch_tuner.apply_params(problem_size, mlir_template, config) with open(path.join(output, f"{i+1}.mlir"), "w") as f: - f.write(new_mlir) + f.write(tf_mlir.modified) with open(path.join(output, f"{i+1}_config.mlir"), "w") as f: - f.write(embeddable_tuning) + f.write(tf_mlir.embeddable) with open(path.join(output, "configs.pkl"), "wb") as file: pickle.dump(configs, file) From 07e312b79d0f3477954a78c708ea7ede150eb80d Mon Sep 17 00:00:00 2001 From: Amily Wu Date: Mon, 26 Aug 2024 16:35:16 -0500 Subject: [PATCH 14/23] Move more dispatch-specific func to its own class --- .../sharktank/tools/tuner/candidate_gen.py | 453 +++++++++--------- 1 file changed, 227 insertions(+), 226 deletions(-) diff --git a/sharktank/sharktank/tools/tuner/candidate_gen.py b/sharktank/sharktank/tools/tuner/candidate_gen.py index ac05983cc..bdde59848 100755 --- a/sharktank/sharktank/tools/tuner/candidate_gen.py +++ b/sharktank/sharktank/tools/tuner/candidate_gen.py @@ -202,20 +202,6 @@ def from_problem_size(problem_size: ProblemSize): return ConvDimInfo.from_rhs_res(problem_size.rhs_type, problem_size.res_type) -def get_conv_tile_sizes(configuration: Configuration) -> list[int]: - m, n, k = configuration.tile_sizes - batch = 1 - fh = 1 - fw = 1 - - oh = 1 - - oc = n - ow = m - ic = k - return [batch, oh, ow, oc, fh, fw, ic] - - def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> list[int]: m, n, k = configuration.tile_sizes tile_size = [1] * len(tile_dims) @@ -240,189 +226,6 @@ def get_pipeline_config(configuration: Configuration) -> str: return extra_config -def get_transform_function_mmt( 
- problem_size: ProblemSize, functionName: str, configuration: Configuration -) -> str: - tile_sizes = ", ".join(map(str, get_mmt_tile_sizes(configuration))) - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) - - return f""" -transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ - %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op - %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value - %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value - transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value - transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> - {extra_config}}}> - > -> !transform.any_param - transform.yield %matmul, %config : !transform.any_op, !transform.any_param -}} -""" - - -# int64_t n = outputShape[0]; -# int64_t oh = outputShape[1]; -# int64_t ow = outputShape[2]; -# int64_t oc = outputShape[3]; -# int64_t fh = filterShape[0]; -# int64_t fw = filterShape[1]; -# int64_t ic = filterShape[2]; -def get_transform_function_conv( - problem_size: ProblemSize, functionName: str, configuration: Configuration -) -> str: - dynamic_batch_input_ty = problem_size.lhs_type - dynamic_batch_input_ty.shape = dynamic_batch_input_ty.shape.copy() - dynamic_batch_input_ty.shape[0] = -1 - - dynamic_batch_output_ty = problem_size.res_type - dynamic_batch_output_ty.shape = dynamic_batch_output_ty.shape.copy() - dynamic_batch_output_ty.shape[0] - 1 - - input = f"tensor<{dynamic_batch_input_ty}>" - filter = f"tensor<{problem_size.rhs_type}>" - output = f"tensor<{dynamic_batch_output_ty}>" - - tile_sizes = ", ".join(map(str, get_conv_tile_sizes(configuration))) - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) - - return f""" -transform.named_sequence @{functionName}(%conv: !transform.any_op {{transform.readonly}}) - -> (!transform.any_op, !transform.any_param) {{ - %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv {{ - ^bb0(%lhs: {input}, %rhs: {filter}, %out: {output}): - %13 = linalg.conv_2d_nhwc_hwcf {{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}} - ins(%lhs, %rhs : {input}, {filter}) - outs(%out : {output}) -> {output} - }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> - {extra_config}}}> - > -> !transform.any_param - transform.yield %conv, %config : !transform.any_op, !transform.any_param -}} -""" - - -def get_transform_function_batch_matmul( - problem_size: ProblemSize, - tile_dims: str, - functionName: str, - configuration: Configuration, -) -> str: - input0 = f"tensor<{problem_size.lhs_type}>" - input1 = f"tensor<{problem_size.rhs_type}>" - output = 
f"tensor<{problem_size.res_type}>" - - tile_sizes = ", ".join(map(str, get_contract_tile_sizes(configuration, tile_dims))) - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) - - return f""" -transform.named_sequence @{functionName}(%batch_matmul: !transform.any_op {{transform.readonly}}) - -> (!transform.any_op, !transform.any_param) {{ - %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %batch_matmul {{ - ^bb0(%lhs: {input0}, %rhs: {input1}, %out: {output}): - %13 = linalg.batch_matmul - ins(%lhs, %rhs : {input0}, {input1}) - outs(%out : {output}) -> {output} - }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> - {extra_config}}}> - > -> !transform.any_param - transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param -}} -""" - - -def get_transform_function_batch_mmt( - problem_size: ProblemSize, - functionName: str, - configuration: Configuration, -) -> str: - tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration))) - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) - - return f""" -transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ - %mmt = transform.include @match_batch_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op - %lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value - %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value - transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value - transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> - {extra_config}}}> - > -> !transform.any_param - transform.yield %generic, %config : !transform.any_op, !transform.any_param -}} -""" - - -def get_transform_function_broadcast_rhs_mmt( - problem_size: ProblemSize, - functionName: str, - configuration: Configuration, -) -> str: - tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration))) - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) - - lhs_dynamic_batch = problem_size.lhs_type - lhs_dynamic_batch.shape = lhs_dynamic_batch.shape.copy() - lhs_dynamic_batch.shape[0] = -1 - - return f""" -transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ - %mmt = transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op - %lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value - %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value - transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : 
!transform.any_value - transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> - {extra_config}}}> - > -> !transform.any_param - transform.yield %generic, %config : !transform.any_op, !transform.any_param -}} -""" - - def apply_configuration( template: list[str], configuration: Configuration, tile_sizes: list[int] ) -> str: @@ -467,22 +270,6 @@ def parse_tensor_type(tensor_type: str) -> ShapedType: return ShapedType(dims, str_to_elem_ty[elem]) -def is_broadcast_rhs_mmt_op(line: str) -> bool: - if "linalg.generic" not in line: - return False - if ( - r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' - not in line - ): - return False - if ( - r"indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>" - not in line - ): - return False - return True - - def get_compatible_mfma_intrinsics(problem_size: ProblemSize) -> list[MfmaIntrinsic]: def is_compatible(intrinsic: MfmaIntrinsic) -> bool: if problem_size.res_type.element_type != intrinsic.output_type: @@ -801,6 +588,34 @@ def get_shapes(self, template: list[str]) -> ProblemSize: assert False, "Shape not found" + def get_transform_function_mmt( + self, problem_size: ProblemSize, functionName: str, configuration: Configuration + ) -> str: + tile_sizes = ", ".join(map(str, get_mmt_tile_sizes(configuration))) + + wg_x, wg_y, wg_z = configuration.workgroup_size + extra_config = get_pipeline_config(configuration) + + return f""" + transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ + %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> + {extra_config}}}> + > -> !transform.any_param + transform.yield %matmul, %config : !transform.any_op, !transform.any_param + }} + """ + def apply_params( self, problem_size: ProblemSize, @@ -809,7 +624,7 @@ def apply_params( ) -> TFMLIR: M, N, K = problem_size.MNK modified = indent( - get_transform_function_mmt( + self.get_transform_function_mmt( problem_size, f"match_mmt_{M}x{N}x{K}", configuration ), "// ", @@ -818,7 +633,8 @@ def apply_params( template, configuration, get_mmt_tile_sizes(configuration) ) embeddable = indent( - get_transform_function_mmt(problem_size, f"match_op", configuration), " " + self.get_transform_function_mmt(problem_size, f"match_op", configuration), + " ", ) return TFMLIR(template, modified, embeddable) @@ -827,6 +643,19 @@ class 
ConvTuner(DispatchTuner): def supports(self, op_name: str) -> bool: return "conv_2d_nhwc_hwcf" in op_name + def get_conv_tile_sizes(self, configuration: Configuration) -> list[int]: + m, n, k = configuration.tile_sizes + batch = 1 + fh = 1 + fw = 1 + + oh = 1 + + oc = n + ow = m + ic = k + return [batch, oh, ow, oc, fh, fw, ic] + def get_shapes(self, template: list[str]) -> ProblemSize: for line in template: if "linalg.conv_2d_nhwc_hwcf" not in line: @@ -875,6 +704,55 @@ def get_shapes(self, template: list[str]) -> ProblemSize: assert False, "Shape not found" + # int64_t n = outputShape[0]; + # int64_t oh = outputShape[1]; + # int64_t ow = outputShape[2]; + # int64_t oc = outputShape[3]; + # int64_t fh = filterShape[0]; + # int64_t fw = filterShape[1]; + # int64_t ic = filterShape[2]; + def get_transform_function_conv( + self, problem_size: ProblemSize, functionName: str, configuration: Configuration + ) -> str: + dynamic_batch_input_ty = problem_size.lhs_type + dynamic_batch_input_ty.shape = dynamic_batch_input_ty.shape.copy() + dynamic_batch_input_ty.shape[0] = -1 + + dynamic_batch_output_ty = problem_size.res_type + dynamic_batch_output_ty.shape = dynamic_batch_output_ty.shape.copy() + dynamic_batch_output_ty.shape[0] - 1 + + input = f"tensor<{dynamic_batch_input_ty}>" + filter = f"tensor<{problem_size.rhs_type}>" + output = f"tensor<{dynamic_batch_output_ty}>" + + tile_sizes = ", ".join(map(str, self.get_conv_tile_sizes(configuration))) + + wg_x, wg_y, wg_z = configuration.workgroup_size + extra_config = get_pipeline_config(configuration) + + return f""" + transform.named_sequence @{functionName}(%conv: !transform.any_op {{transform.readonly}}) + -> (!transform.any_op, !transform.any_param) {{ + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv {{ + ^bb0(%lhs: {input}, %rhs: {filter}, %out: {output}): + %13 = linalg.conv_2d_nhwc_hwcf {{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}} + ins(%lhs, %rhs : {input}, {filter}) + outs(%out : {output}) -> {output} + }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> + {extra_config}}}> + > -> !transform.any_param + transform.yield %conv, %config : !transform.any_op, !transform.any_param + }} + """ + def apply_params( self, problem_size: ProblemSize, @@ -883,7 +761,7 @@ def apply_params( ) -> TFMLIR: conv_dims = ConvDimInfo.from_problem_size(problem_size) modified = indent( - get_transform_function_conv( + self.get_transform_function_conv( problem_size, f"match_conv_2d_nhwc_hwcf_Bx{conv_dims.oh}x{conv_dims.ow}x{conv_dims.oc}x{conv_dims.fh}x{conv_dims.fw}x{conv_dims.ic}", configuration, @@ -891,10 +769,10 @@ def apply_params( "// ", ) modified += apply_configuration( - template, configuration, get_conv_tile_sizes(configuration) + template, configuration, self.get_conv_tile_sizes(configuration) ) embeddable = indent( - get_transform_function_conv(problem_size, f"match_op", configuration), + self.get_transform_function_conv(problem_size, f"match_op", configuration), " ", ) return TFMLIR(template, modified, embeddable) @@ -909,12 +787,27 @@ def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): def supports(self, op_name: str) -> bool: return "matmul_like" in op_name + def 
is_broadcast_rhs_mmt_op(self, line: str) -> bool: + if "linalg.generic" not in line: + return False + if ( + r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' + not in line + ): + return False + if ( + r"indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>" + not in line + ): + return False + return True + def is_broadcast_rhs_mmt(self, template: list[str]) -> bool: - return any(is_broadcast_rhs_mmt_op(line) for line in template) + return any(self.is_broadcast_rhs_mmt_op(line) for line in template) def get_shapes_broadcast_rhs_mmt(self, template: list[str]) -> ProblemSize: for line in template: - if not is_broadcast_rhs_mmt_op(line): + if not self.is_broadcast_rhs_mmt_op(line): continue # ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) @@ -1014,6 +907,41 @@ def get_shapes(self, template: list[str]) -> ProblemSize: assert False, "Shape not found" + def get_transform_function_broadcast_rhs_mmt( + self, + problem_size: ProblemSize, + functionName: str, + configuration: Configuration, + ) -> str: + tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration))) + + wg_x, wg_y, wg_z = configuration.workgroup_size + extra_config = get_pipeline_config(configuration) + + lhs_dynamic_batch = problem_size.lhs_type + lhs_dynamic_batch.shape = lhs_dynamic_batch.shape.copy() + lhs_dynamic_batch.shape[0] = -1 + + return f""" + transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ + %mmt = transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> + {extra_config}}}> + > -> !transform.any_param + transform.yield %generic, %config : !transform.any_op, !transform.any_param + }} + """ + def apply_params_broadcast_rhs_mmt( self, problem_size: ProblemSize, @@ -1022,7 +950,7 @@ def apply_params_broadcast_rhs_mmt( ) -> tuple[str, str]: M, N, K = problem_size.MNK modified = indent( - get_transform_function_broadcast_rhs_mmt( + self.get_transform_function_broadcast_rhs_mmt( problem_size, f"match_broadcast_rhs_mmt_Bx{M}x{N}x{K}", configuration ), "// ", @@ -1032,7 +960,7 @@ def apply_params_broadcast_rhs_mmt( ) embeddable = indent( - get_transform_function_broadcast_rhs_mmt( + self.get_transform_function_broadcast_rhs_mmt( problem_size, f"match_op", configuration ), " ", @@ -1113,6 +1041,37 @@ def get_shapes(self, template: list[str]) -> ProblemSize: assert False, "Shape not found" + def get_transform_function_batch_mmt( + self, + problem_size: ProblemSize, + functionName: str, + configuration: Configuration, + ) -> str: + tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration))) + + wg_x, wg_y, wg_z = 
configuration.workgroup_size + extra_config = get_pipeline_config(configuration) + + return f""" + transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ + %mmt = transform.include @match_batch_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> + {extra_config}}}> + > -> !transform.any_param + transform.yield %generic, %config : !transform.any_op, !transform.any_param + }} + """ + def apply_params( self, problem_size: ProblemSize, @@ -1122,7 +1081,7 @@ def apply_params( M, N, K = problem_size.MNK B = problem_size.matmul_size.B modified = indent( - get_transform_function_batch_mmt( + self.get_transform_function_batch_mmt( problem_size, f"match_batch_mmt_{B}x{M}x{N}x{K}", configuration ), "// ", @@ -1132,7 +1091,9 @@ def apply_params( ) embeddable = indent( - get_transform_function_batch_mmt(problem_size, f"match_op", configuration), + self.get_transform_function_batch_mmt( + problem_size, f"match_op", configuration + ), " ", ) return TFMLIR(template, modified, embeddable) @@ -1210,6 +1171,46 @@ def get_shapes(self, template: list[str]) -> ProblemSize: assert False, "Shape not found" + def get_transform_function_batch_matmul( + self, + problem_size: ProblemSize, + tile_dims: str, + functionName: str, + configuration: Configuration, + ) -> str: + input0 = f"tensor<{problem_size.lhs_type}>" + input1 = f"tensor<{problem_size.rhs_type}>" + output = f"tensor<{problem_size.res_type}>" + + tile_sizes = ", ".join( + map(str, get_contract_tile_sizes(configuration, tile_dims)) + ) + + wg_x, wg_y, wg_z = configuration.workgroup_size + extra_config = get_pipeline_config(configuration) + + return f""" + transform.named_sequence @{functionName}(%batch_matmul: !transform.any_op {{transform.readonly}}) + -> (!transform.any_op, !transform.any_param) {{ + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %batch_matmul {{ + ^bb0(%lhs: {input0}, %rhs: {input1}, %out: {output}): + %13 = linalg.batch_matmul + ins(%lhs, %rhs : {input0}, {input1}) + outs(%out : {output}) -> {output} + }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> + {extra_config}}}> + > -> !transform.any_param + transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param + }} + """ + def apply_params( self, problem_size: ProblemSize, @@ -1218,7 +1219,7 @@ def apply_params( ) -> TFMLIR: M, N, K = problem_size.MNK modified = indent( - get_transform_function_batch_matmul( + self.get_transform_function_batch_matmul( problem_size, 
self.tile_dims, f"match_batch_matmul_{problem_size.matmul_size.B}x{M}x{N}x{K}", @@ -1233,7 +1234,7 @@ def apply_params( ) embeddable = indent( - get_transform_function_batch_matmul( + self.get_transform_function_batch_matmul( problem_size, self.tile_dims, f"match_op", configuration ), " ", From eaf12fec0f5369ed528b2f4e53fe494e852b507c Mon Sep 17 00:00:00 2001 From: Amily Wu Date: Mon, 26 Aug 2024 17:32:57 -0500 Subject: [PATCH 15/23] Fix mlir indent err and update pytest --- .../sharktank/tools/tuner/candidate_gen.py | 74 +++++++++---------- sharktank/tests/tuner/candidate_gen_test.py | 74 +++++++++++++------ 2 files changed, 89 insertions(+), 59 deletions(-) diff --git a/sharktank/sharktank/tools/tuner/candidate_gen.py b/sharktank/sharktank/tools/tuner/candidate_gen.py index bdde59848..c86922205 100755 --- a/sharktank/sharktank/tools/tuner/candidate_gen.py +++ b/sharktank/sharktank/tools/tuner/candidate_gen.py @@ -923,31 +923,31 @@ def get_transform_function_broadcast_rhs_mmt( lhs_dynamic_batch.shape[0] = -1 return f""" - transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ - %mmt = transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op - %lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value - %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value - transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : !transform.any_value - transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> - {extra_config}}}> - > -> !transform.any_param - transform.yield %generic, %config : !transform.any_op, !transform.any_param - }} - """ +transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ +%mmt = transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op +%lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value +%rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value +transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : !transform.any_value +transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value +%config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> + {extra_config}}}> + > -> !transform.any_param +transform.yield %generic, %config : !transform.any_op, !transform.any_param +}} +""" def apply_params_broadcast_rhs_mmt( self, problem_size: ProblemSize, template: list[str], configuration: Configuration, - ) -> tuple[str, str]: + ) -> TFMLIR: M, N, K = problem_size.MNK modified = indent( self.get_transform_function_broadcast_rhs_mmt( @@ -1053,24 +1053,24 @@ def get_transform_function_batch_mmt( extra_config = 
get_pipeline_config(configuration) return f""" - transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ - %mmt = transform.include @match_batch_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op - %lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value - %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value - transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value - transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> - {extra_config}}}> - > -> !transform.any_param - transform.yield %generic, %config : !transform.any_op, !transform.any_param - }} - """ +transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ +%mmt = transform.include @match_batch_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op +%lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value +%rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value +transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value +transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value +%config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> + {extra_config}}}> + > -> !transform.any_param +transform.yield %generic, %config : !transform.any_op, !transform.any_param +}} +""" def apply_params( self, diff --git a/sharktank/tests/tuner/candidate_gen_test.py b/sharktank/tests/tuner/candidate_gen_test.py index 3b86c0bf8..4fc21aa63 100644 --- a/sharktank/tests/tuner/candidate_gen_test.py +++ b/sharktank/tests/tuner/candidate_gen_test.py @@ -82,7 +82,15 @@ def test_get_conv_tile_sizes(): subgroup_n_count=4, waves_per_eu=1, ) - assert candidate_gen.get_conv_tile_sizes(config) == [1, 1, 464, 320, 1, 1, 16] + assert candidate_gen.ConvTuner().get_conv_tile_sizes(config) == [ + 1, + 1, + 464, + 320, + 1, + 1, + 16, + ] def test_get_contract_tile_sizes(): @@ -138,7 +146,7 @@ def test_get_shapes_mmt(): r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', r"^bb0(%in: f16, %in_0: f16, %out: f32):", ] - assert candidate_gen.get_shapes_mmt(template) == candidate_gen.ProblemSize( + assert candidate_gen.MmtTuner().get_shapes(template) == candidate_gen.ProblemSize( candidate_gen.MatmulSize(2048, 1280, 1280), candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), candidate_gen.ShapedType([1280, 1280], 
candidate_gen.ElementType.f16), @@ -153,7 +161,7 @@ def test_get_shapes_conv(): r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", r"flow.dispatch.tensor.store %8, %2, offsets = [%workgroup_id_z, %workgroup_id_y, 0, %3], sizes = [1, 1, 32, 256], strides = [1, 1, 1, 1] : tensor<1x1x32x256xf32> -> !flow.dispatch.tensor>", ] - assert candidate_gen.get_shapes_conv(template) == candidate_gen.ProblemSize( + assert candidate_gen.ConvTuner().get_shapes(template) == candidate_gen.ProblemSize( candidate_gen.MatmulSize(32, 256, 11520), candidate_gen.ShapedType([1, 3, 34, 1280], candidate_gen.ElementType.f16), candidate_gen.ShapedType([3, 3, 1280, 256], candidate_gen.ElementType.f16), @@ -169,8 +177,8 @@ def test_get_shapes_contract(): r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', r"^bb0(%in: f16, %in_0: f16, %out: f32):", ] - assert candidate_gen.get_shapes_contract( - template, "mk", "nk" + assert candidate_gen.ContractionTuner("mk", "nk", "mnk").get_shapes( + template ) == candidate_gen.ProblemSize( candidate_gen.MatmulSize(2048, 1280, 1280), candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), @@ -186,8 +194,8 @@ def test_get_shapes_batch_matmul(): "%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", "flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 32, 32], strides = [1, 1, 1] : tensor<1x32x32xf32> -> !flow.dispatch.tensor>", ] - assert candidate_gen.get_shapes_batch_matmul( - template, "bmk", "bkn" + assert candidate_gen.BatchMatmulTuner("bmk", "bkn", "mnk").get_shapes( + template ) == candidate_gen.ProblemSize( candidate_gen.MatmulSize(32, 32, 1024, 1), candidate_gen.ShapedType([1, 32, 1024], candidate_gen.ElementType.f32), @@ -203,7 +211,9 @@ def test_get_shapes_batch_mmt(): r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', r"flow.dispatch.tensor.store %21, %10, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : tensor<2x4096x640xf16> -> !flow.dispatch.tensor>", ] - assert candidate_gen.get_shapes_batch_mmt(template) == candidate_gen.ProblemSize( + assert candidate_gen.BatchMmtTuner().get_shapes( + template + ) == candidate_gen.ProblemSize( candidate_gen.MatmulSize(4096, 640, 640, 2), candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8), @@ -426,9 +436,10 @@ def test_apply_params_mmt(): candidate_gen.ShapedType([M, N], candidate_gen.ElementType.f32), candidate_gen.DispatchKind.mmt, ) - modified, embeddable = candidate_gen.apply_params_mmt( - problem_size, 
mlir_template, config - ) + tf_mlir = candidate_gen.MmtTuner().apply_params(problem_size, mlir_template, config) + + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable assert modified assert embeddable @@ -473,10 +484,13 @@ def test_apply_params_conv(): candidate_gen.ShapedType([n, oh, ow, oc], candidate_gen.ElementType.f32), candidate_gen.DispatchKind.conv, ) - modified, embeddable = candidate_gen.apply_params_conv( + tf_mlir = candidate_gen.ConvTuner().apply_params( problem_size, mlir_template, config ) + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + assert modified assert embeddable assert ( @@ -518,10 +532,12 @@ def test_apply_params_contract(): waves_per_eu=2, ) - new_mlir, _embeddable = candidate_gen.apply_params_contract( - problem_size, tile_dims, mlir_template, config + tf_mlir = candidate_gen.ContractionTuner("mk", "nk", tile_dims).apply_params( + problem_size, mlir_template, config ) + new_mlir = tf_mlir.modified + assert new_mlir assert ( "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 1, subgroup_n_count = 4" @@ -562,10 +578,13 @@ def test_apply_params_batch_matmul(): waves_per_eu=2, ) - modified, embeddable = candidate_gen.apply_params_batch_matmul( - problem_size, tile_dims, mlir_template, config + tf_mlir = candidate_gen.BatchMatmulTuner("mk", "nk", tile_dims).apply_params( + problem_size, mlir_template, config ) + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + assert modified assert embeddable assert ( @@ -606,10 +625,13 @@ def test_apply_params_batch_mmt_float(): waves_per_eu=2, ) - modified, embeddable = candidate_gen.apply_params_batch_mmt( + tf_mlir = candidate_gen.BatchMmtTuner().apply_params( problem_size, mlir_template, config ) + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + assert embeddable assert modified assert ( @@ -650,10 +672,13 @@ def test_apply_params_batch_mmt_int(): waves_per_eu=4, ) - modified, embeddable = candidate_gen.apply_params_batch_mmt( + tf_mlir = candidate_gen.BatchMmtTuner().apply_params( problem_size, mlir_template, config ) + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + assert modified assert "// transform.named_sequence @match_batch_mmt_2x4096x640x640(" in modified assert ( @@ -715,9 +740,12 @@ def test_apply_params_broadcast_rhs_mmt(): waves_per_eu=4, ) - modified, embeddable = candidate_gen.apply_params_broadcast_rhs_mmt( - problem_size, mlir_template, config - ) + tf_mlir = candidate_gen.ContractionTuner( + "mk", "nk", "mnk" + ).apply_params_broadcast_rhs_mmt(problem_size, mlir_template, config) + + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable assert modified assert ( @@ -764,7 +792,9 @@ def test_detect_broadcast_rhs_mmt(): r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>", r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', ] - assert candidate_gen.is_broadcast_rhs_mmt(mlir_lines) + assert candidate_gen.ContractionTuner("mk", "nk", "mnk").is_broadcast_rhs_mmt( + mlir_lines + ) def test_parse_mlir(): From 96355519bc8475d46e7c500bfb765de60d99ca7f Mon Sep 17 00:00:00 2001 From: Amily 
Wu
Date: Mon, 26 Aug 2024 17:54:27 -0500
Subject: [PATCH 16/23] Add docstrings

---
 .../sharktank/tools/tuner/candidate_gen.py | 20 ++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/sharktank/sharktank/tools/tuner/candidate_gen.py b/sharktank/sharktank/tools/tuner/candidate_gen.py
index c86922205..be31f8914 100755
--- a/sharktank/sharktank/tools/tuner/candidate_gen.py
+++ b/sharktank/sharktank/tools/tuner/candidate_gen.py
@@ -541,6 +541,8 @@ def supports(self, op_name: str) -> bool:
         return "matmul_transpose_b" in op_name

     def get_shapes(self, template: list[str]) -> ProblemSize:
+        mmt_re = None
+        dps = None
         for line in template:
             if "linalg.generic" not in line:
                 continue
@@ -585,8 +587,8 @@ def get_shapes(self, template: list[str]) -> ProblemSize:
             res_type=res_shaped_type,
             dispatch_kind=DispatchKind.mmt,
         )
-
-        assert False, "Shape not found"
+        assert mmt_re
+        assert dps, f"'{mmt_re}' not found in given context"

     def get_transform_function_mmt(
         self, problem_size: ProblemSize, functionName: str, configuration: Configuration
@@ -1271,13 +1273,13 @@ def walk_mlir_op(


 def tune(
-    input: str,
-    output: str = "",
-    limit: int = 4096,
-    num_subgroups: int = 4,
-    lhs_dims: str = "mk",
-    rhs_dims: str = "nk",
-    tile_dims: str = "mnk",
+    input: str,  # Path to the mlir file to be tuned
+    output: str = "",  # Path to the output directory, auto creates one if not given
+    limit: int = 4096,  # Max candidates to be generated
+    num_subgroups: int = 4,  # GPU spec, used to determine candidate generation constraints
+    lhs_dims: str = "mk",  # Dimensions for the left-hand side operand in matrix operations
+    rhs_dims: str = "nk",  # Dimensions for the right-hand side operand in matrix operations
+    tile_dims: str = "mnk",  # Dimensions for the tile size
 ):
     input_file = str(input)

From 789614fc8bc7bf9d38be08a637adf0b33ec8c950 Mon Sep 17 00:00:00 2001
From: Amily Wu
Date: Tue, 27 Aug 2024 11:42:39 -0500
Subject: [PATCH 17/23] Add LLVMGPUVectorDistribute check for dispatch registry

---
 sharktank/sharktank/tools/tuner/candidate_gen.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/sharktank/sharktank/tools/tuner/candidate_gen.py b/sharktank/sharktank/tools/tuner/candidate_gen.py
index be31f8914..d19551621 100755
--- a/sharktank/sharktank/tools/tuner/candidate_gen.py
+++ b/sharktank/sharktank/tools/tuner/candidate_gen.py
@@ -529,11 +529,19 @@ def register(self, dispatch_tuners: list[DispatchTuner]) -> None:
         for dispatch_tuner in dispatch_tuners:
             self.registry.add(dispatch_tuner)

+    def validate_translation(self, attrs: list[ir.NamedAttribute]) -> bool:
+        for attr in attrs:
+            if (attr.name == "translation_info") and (
+                "LLVMGPUVectorDistribute" in str(attr.attr)
+            ):
+                return True
+        assert False, "Translation info not supported"
+
     def find_handler(self, op_name: str) -> DispatchTuner:
         for dispatch_tuner in self.registry:
             if dispatch_tuner.supports(op_name):
                 return dispatch_tuner
-        assert False, "Not supported"
+        assert False, "Dispatch kind not supported"


 class MmtTuner(DispatchTuner):
@@ -1249,6 +1257,8 @@ def walk_callback_get_fn(
     walk_result: OpWalkResult,
     dispatch_tuner_registry: DispatchTunerRegistry,
 ) -> ir.WalkResult:
+    if op.name == "func.func":
+        dispatch_tuner_registry.validate_translation([a for a in op.opview.attributes])
     if op.name == "util.func":
         func_name = str(op.opview.sym_name)
         walk_result.was_interrupted = True

From 7b4c68736c124e164bda5d105877cebc2197bde7 Mon Sep 17 00:00:00 2001
From: Amily Wu
Date: Tue, 27 Aug 2024 13:11:43 -0500
Subject: [PATCH 18/23] Rename TFMLIR class

---
 .../sharktank/tools/tuner/candidate_gen.py | 28 +++++++++----------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/sharktank/sharktank/tools/tuner/candidate_gen.py b/sharktank/sharktank/tools/tuner/candidate_gen.py
index d19551621..8a8315afb 100755
--- a/sharktank/sharktank/tools/tuner/candidate_gen.py
+++ b/sharktank/sharktank/tools/tuner/candidate_gen.py
@@ -485,7 +485,7 @@ def parse_mlir(mlir_text: str) -> ir.Module:

 @dataclass
-class TFMLIR:
+class MLIRTransformation:
     """Transformation of MLIR context"""

     template: str
@@ -510,7 +510,7 @@ def apply_params(
         problem_size: ProblemSize,
         template: list[str],
         configuration: Configuration,
-    ) -> TFMLIR:
+    ) -> MLIRTransformation:
         """Apply parameter transformations to the operation."""
         pass

@@ -631,7 +631,7 @@ def apply_params(
         problem_size: ProblemSize,
         template: list[str],
         configuration: Configuration,
-    ) -> TFMLIR:
+    ) -> MLIRTransformation:
         M, N, K = problem_size.MNK
         modified = indent(
             self.get_transform_function_mmt(
@@ -646,7 +646,7 @@ def apply_params(
             self.get_transform_function_mmt(problem_size, f"match_op", configuration),
             " ",
         )
-        return TFMLIR(template, modified, embeddable)
+        return MLIRTransformation(template, modified, embeddable)


 class ConvTuner(DispatchTuner):
@@ -768,7 +768,7 @@ def apply_params(
         problem_size: ProblemSize,
         template: list[str],
         configuration: Configuration,
-    ) -> TFMLIR:
+    ) -> MLIRTransformation:
         conv_dims = ConvDimInfo.from_problem_size(problem_size)
         modified = indent(
             self.get_transform_function_conv(
@@ -785,7 +785,7 @@ def apply_params(
             self.get_transform_function_conv(problem_size, f"match_op", configuration),
             " ",
         )
-        return TFMLIR(template, modified, embeddable)
+        return MLIRTransformation(template, modified, embeddable)


 class ContractionTuner(DispatchTuner):
@@ -957,7 +957,7 @@ def apply_params_broadcast_rhs_mmt(
         problem_size: ProblemSize,
         template: list[str],
         configuration: Configuration,
-    ) -> TFMLIR:
+    ) -> MLIRTransformation:
         M, N, K = problem_size.MNK
         modified = indent(
             self.get_transform_function_broadcast_rhs_mmt(
@@ -975,21 +975,21 @@ def apply_params_broadcast_rhs_mmt(
             ),
             " ",
         )
-        return TFMLIR(template, modified, embeddable)
+        return MLIRTransformation(template, modified, embeddable)

     def apply_params(
         self,
         problem_size: ProblemSize,
         template: list[str],
         configuration: Configuration,
-    ) -> TFMLIR:
+    ) -> MLIRTransformation:
         if self.is_broadcast_rhs_mmt(template):
             return self.apply_params_broadcast_rhs_mmt(
                 problem_size, template, configuration
             )

         # TODO: Generate transform function.
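         # NOTE (editor's sketch, not part of this patch): the TODO above could be
         # resolved with a hypothetical get_transform_function_contraction() that
         # mirrors the mmt/batch_matmul variants but takes its tile list from
         # get_contract_tile_sizes(configuration, self.tile_dims).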
-        return TFMLIR(
+        return MLIRTransformation(
             template,
             apply_configuration(
                 template,
                 configuration,
                 get_contract_tile_sizes(configuration, self.tile_dims),
             ),
             "",
         )
@@ -1087,7 +1087,7 @@ def apply_params(
         problem_size: ProblemSize,
         template: list[str],
         configuration: Configuration,
-    ) -> TFMLIR:
+    ) -> MLIRTransformation:
         M, N, K = problem_size.MNK
         B = problem_size.matmul_size.B
         modified = indent(
@@ -1106,7 +1106,7 @@ def apply_params(
             ),
             " ",
         )
-        return TFMLIR(template, modified, embeddable)
+        return MLIRTransformation(template, modified, embeddable)


 class BatchMatmulTuner(DispatchTuner):
@@ -1226,7 +1226,7 @@ def apply_params(
         problem_size: ProblemSize,
         template: list[str],
         configuration: Configuration,
-    ) -> TFMLIR:
+    ) -> MLIRTransformation:
         M, N, K = problem_size.MNK
         modified = indent(
             self.get_transform_function_batch_matmul(
@@ -1249,7 +1249,7 @@ def apply_params(
             ),
             " ",
         )
-        return TFMLIR(template, modified, embeddable)
+        return MLIRTransformation(template, modified, embeddable)


 def walk_callback_get_fn(

From 6cb0c80029a47f9de3f5e9fba345ade194eca5cd Mon Sep 17 00:00:00 2001
From: Amily Wu
Date: Tue, 27 Aug 2024 14:13:38 -0500
Subject: [PATCH 19/23] Add README

---
 sharktank/sharktank/tools/tuner/README.md | 67 +++++++++++++++++++++++
 1 file changed, 67 insertions(+)
 create mode 100644 sharktank/sharktank/tools/tuner/README.md

diff --git a/sharktank/sharktank/tools/tuner/README.md b/sharktank/sharktank/tools/tuner/README.md
new file mode 100644
index 000000000..a68512cd6
--- /dev/null
+++ b/sharktank/sharktank/tools/tuner/README.md
@@ -0,0 +1,67 @@
+# IREE dispatch auto-tuning scripts
+`libtuner.py` is the core Python script that provides the fundamental functions for the tuning loop. It imports `candidate_gen.py` for candidate generation. To implement the full tuning loop, `libtuner.py` requires a separate Python script that uses the provided `TuningClient` API from `libtuner.py`.
+
+## Prerequisites
+[Optional] Using virtual environments:
+```shell
+cd tuning
+python -m venv .venv
+source .venv/bin/activate
+```
+Install Python dependencies:
+```shell
+pip install -r ./requirements-tuner.txt
+```
+Using IREE's Python bindings:
+   - Building with CMake
+   ```shell
+   -DIREE_BUILD_PYTHON_BINDINGS=ON \
+   -DPython3_EXECUTABLE="$(which python)"
+   ```
+   - Set environment
+   ```shell
+   source ../iree-build/.env && export PYTHONPATH
+   ```
+For more information, refer to the [IREE documentation](https://iree.dev/building-from-source/getting-started/#python-bindings)
+
+### Overall flow
+
+1. Symlink all scripts and mlir/irpa files in your build dir.
+   - Symlink `iree-build-dir/tools` inside `tuning`.
+   - Symlink ML model MLIR and weights based on `unet.sh`.
+
+2. Copy the attention/matmul spec as `config.mlir` in the tuning dir.
+
+3. Temporarily comment out all the existing configs in `config.mlir`.
+   - Example:
+   ```mlir
+   // , @match_mmt_2048x10240x1280 -> @apply_op_config
+   // , @match_mmt_2048x1280x5120 -> @apply_op_config
+   // , @match_mmt_2048x1280x1280 -> @apply_op_config
+   ```
+
+4. Compile a baseline unet
+```shell
+./unet.sh winograd unet.mlir -o unet_baseline.vmfb --iree-hal-dump-executable-files-to=dump-winograd
+```
+
+5. Find the matmul to tune and copy the `*_benchmark.mlir` file to the build dir.
+```shell
+cp dump-winograd/*_141_*benchmark.mlir ./141.mlir
+```
+
+6. Run the tuning script.
+   - Example:
+   ```shell
+   python punet_autotune.py 141.mlir --devices=hip://GPU-0,hip://GPU-4 --num-candidates=1024
+   ```
+
+7. Check the winner candidate in `result_summary.log`, find and copy the transform spec.
+
+8. Paste the transform spec into the `config.mlir` and uncomment them.
+
+9. Add the match function to the entry point in `config.mlir`
+   - Example:
+   ```mlir
+   @match_something -> @apply_op_config
+   ```
\ No newline at end of file

From 3b2fba32603f305cc9c1f973b8154d4420136308 Mon Sep 17 00:00:00 2001
From: Amily Wu
Date: Tue, 27 Aug 2024 14:16:09 -0500
Subject: [PATCH 20/23] Fix lint err

---
 sharktank/sharktank/tools/tuner/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sharktank/sharktank/tools/tuner/README.md b/sharktank/sharktank/tools/tuner/README.md
index a68512cd6..69821496e 100644
--- a/sharktank/sharktank/tools/tuner/README.md
+++ b/sharktank/sharktank/tools/tuner/README.md
@@ -64,4 +64,4 @@ cp dump-winograd/*_141_*benchmark.mlir ./141.mlir
    - Example:
    ```mlir
    @match_something -> @apply_op_config
-   ```
\ No newline at end of file
+   ```

From 4a42b5b961971e51e139c9b360116353e7fbf6a4 Mon Sep 17 00:00:00 2001
From: Amily Wu
Date: Wed, 28 Aug 2024 09:27:48 -0500
Subject: [PATCH 21/23] Move tuner dir

---
 .github/workflows/ci-tuner.yml |    6 +-
 tuner/README.md                |   67 ++
 tuner/candidate_gen.py         | 1408 ++++++++++++++++++++++++++++
 tuner/candidate_gen_test.py    |  814 ++++++++++++++++++
 tuner/requirements-dev.txt     |    2 +
 tuner/requirements-tuner.txt   |    4 +
 6 files changed, 2298 insertions(+), 3 deletions(-)
 create mode 100644 tuner/README.md
 create mode 100755 tuner/candidate_gen.py
 create mode 100644 tuner/candidate_gen_test.py
 create mode 100644 tuner/requirements-dev.txt
 create mode 100644 tuner/requirements-tuner.txt

diff --git a/.github/workflows/ci-tuner.yml b/.github/workflows/ci-tuner.yml
index b6552bfd5..9855a2779 100644
--- a/.github/workflows/ci-tuner.yml
+++ b/.github/workflows/ci-tuner.yml
@@ -30,15 +30,15 @@ jobs:
       - name: Install dev dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install -r sharktank/sharktank/tools/tuner/requirements-dev.txt
+          pip install -r sharktank/tuner/requirements-dev.txt

       - name: Install tuner dependencies
         run: |
-          pip install -r sharktank/sharktank/tools/tuner/requirements-tuner.txt
+          pip install -r sharktank/tuner/requirements-tuner.txt
           python -m pip install \
             --find-links https://iree.dev/pip-release-links.html \
             --upgrade \
             iree-compiler iree-runtime

       - name: Run tuner tests
-        run: pytest sharktank/tests/tuner/
+        run: pytest sharktank/tuner/

diff --git a/tuner/README.md b/tuner/README.md
new file mode 100644
index 000000000..69821496e
--- /dev/null
+++ b/tuner/README.md
@@ -0,0 +1,67 @@
+# IREE dispatch auto-tuning scripts
+`libtuner.py` is the core Python script that provides the fundamental functions for the tuning loop. It imports `candidate_gen.py` for candidate generation. To implement the full tuning loop, `libtuner.py` requires a separate Python script that uses the provided `TuningClient` API from `libtuner.py`.
+
+## Prerequisites
+[Optional] Using virtual environments:
+```shell
+cd tuning
+python -m venv .venv
+source .venv/bin/activate
+```
+Install Python dependencies:
+```shell
+pip install -r ./requirements-tuner.txt
+```
+Using IREE's Python bindings:
+   - Building with CMake
+   ```shell
+   -DIREE_BUILD_PYTHON_BINDINGS=ON \
+   -DPython3_EXECUTABLE="$(which python)"
+   ```
+   - Set environment
+   ```shell
+   source ../iree-build/.env && export PYTHONPATH
+   ```
+For more information, refer to the [IREE documentation](https://iree.dev/building-from-source/getting-started/#python-bindings)
+
+### Overall flow
+
+1. Symlink all scripts and mlir/irpa files in your build dir.
+ - Symlink `iree-build-dir/tools` inside `tuning`. + - Symlink ML model MLIR and weights based on `unet.sh`. + +2. Copy the attention/matmul spec as `config.mlir` in the tuning dir. + +3. Temporarily comment out all the existing configs in `config.mlir`. + - Example: + ```mlir + // , @match_mmt_2048x10240x1280 -> @apply_op_config + // , @match_mmt_2048x1280x5120 -> @apply_op_config + // , @match_mmt_2048x1280x1280 -> @apply_op_config + ``` + +4. Compile a baseline unet +```shell +./unet.sh winograd unet.mlir -o unet_baseline.vmfb --iree-hal-dump-executable-files-to=dump-winograd +``` + +5. Find the matmul to tune and copy the `*_benchmark.mlir` file to the build dir. +```shell +cp dump-winograd/*_141_*benchmark.mlir ./141.mlir +``` + +6. Run the tuning script. + - Example: + ```shell + python punet_autotune.py 141.mlir --devices=hip://GPU-0,hip://GPU-4 --num-candidates=1024 + ``` + +7. Check the winner candidate in `result_summary.log`, find and copy the transform spec. + +8. Paste the transform spec into the `config.mlir` and uncomment them. + +9. Add the match function to the entry point in `config.mlir` + - Example: + ```mlir + @match_something -> @apply_op_config + ``` diff --git a/tuner/candidate_gen.py b/tuner/candidate_gen.py new file mode 100755 index 000000000..8a8315afb --- /dev/null +++ b/tuner/candidate_gen.py @@ -0,0 +1,1408 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Given an input dispatch, this code modifies the hyperparameters +# in the code and runs it. + +""" +Generate candidates by tweaking op configuration for tuning. + +It can be invoked in two ways: + 1. From another python script, import and call `tune()` + 2. Run this script directly from the command + +Usage: ./candidate_gen.py 121.mlir -o "tuning/candidates" -l 1024 --lhs-dims=mk --rhs-dims=nk --tile-dims=mnk + +""" + +import argparse +import logging +import math +import pickle +import re +import z3 +from dataclasses import asdict, dataclass +from enum import Enum +from os import mkdir, path, makedirs +from typing import Callable, Optional +from textwrap import indent +from abc import ABC, abstractmethod + +import iree.compiler as ireec +from iree.compiler import ir +from iree.compiler.dialects import _linalg_ops_gen, _util_ops_gen + + +tune_logger = logging.getLogger("tune") + + +class DispatchKind(Enum): + conv = 1 + mmt = 2 + contraction = 3 + batch_mmt = 4 + batch_matmul = 5 + broadcast_rhs_mmt = 6 + + +class ElementType(Enum): + i8 = 1 + i32 = 2 + f8 = 3 + f16 = 4 + f32 = 5 + + @property + def bitwidth(self) -> int: + match self: + case ElementType.i8 | ElementType.f8: + return 8 + case ElementType.f16: + return 16 + case ElementType.i32 | ElementType.f32: + return 32 + case _: + assert False, "unhandled case" + + def __str__(self) -> str: + return self.name + + +@dataclass +class ShapedType: + shape: list[int] + element_type: ElementType + + def rank(self) -> int: + return len(self.shape) + + @property + def bitwidth(self) -> int: + return self.element_type.bitwidth + + def __str__(self) -> str: + dim_to_str = lambda dim: str(dim) if dim != -1 else "?" 
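+        # NOTE (editor's annotation, not in the original patch): this renders
+        # MLIR-style shape strings, e.g. ShapedType([2048, 1280], ElementType.f16)
+        # prints as "2048x1280xf16", and a -1 dim prints as "?" (dynamic).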
+ return "x".join(map(dim_to_str, self.shape)) + "x" + str(self.element_type) + + +@dataclass +class MatmulSize: + M: int + N: int + K: int + B: int = 1 + + +@dataclass +class ProblemSize: + matmul_size: MatmulSize + lhs_type: ShapedType + rhs_type: ShapedType + res_type: ShapedType + dispatch_kind: DispatchKind + + @property + def MNK(self) -> tuple[int, int, int]: + return (self.matmul_size.M, self.matmul_size.N, self.matmul_size.K) + + +@dataclass +class MfmaIntrinsic: + input_type: ElementType + m: int + n: int + k: int + output_type: ElementType + + def __str__(self) -> str: + input = str(self.input_type).upper() + output = str(self.output_type).upper() + return f"MFMA_{input}_{self.m}x{self.n}x{self.k}_{output}" + + @staticmethod + def mfma_f16_16x16x16_f32(): + return MfmaIntrinsic(ElementType.f16, 16, 16, 16, ElementType.f32) + + @staticmethod + def mfma_f16_32x32x8_f32(): + return MfmaIntrinsic(ElementType.f16, 32, 32, 8, ElementType.f32) + + @staticmethod + def mfma_i8_16x16x32_i32(): + return MfmaIntrinsic(ElementType.i8, 16, 16, 32, ElementType.i32) + + @staticmethod + def mfma_i8_32x32x16_i32(): + return MfmaIntrinsic(ElementType.i8, 32, 32, 16, ElementType.i32) + + @staticmethod + def all(): + return [ + MfmaIntrinsic.mfma_f16_16x16x16_f32(), + MfmaIntrinsic.mfma_f16_32x32x8_f32(), + MfmaIntrinsic.mfma_i8_16x16x32_i32(), + MfmaIntrinsic.mfma_i8_32x32x16_i32(), + ] + + +@dataclass +class Configuration: + subgroup_size: int + workgroup_size: list[int] + intrinsic: MfmaIntrinsic + tile_sizes: list[int] + subgroup_m_count: int + subgroup_n_count: int + waves_per_eu: int + + +class MlirRegex(str, Enum): + ssa_value = r"%[a-zA-Z0-9-_]+" + tensor_type = r"tensor<(([0-9]+x)+((f|i)[0-9]+))>" + + @staticmethod + def dps_ins_two_args() -> str: + return rf"ins\({MlirRegex.ssa_value}, {MlirRegex.ssa_value} : (?P{MlirRegex.tensor_type}), (?P{MlirRegex.tensor_type})\)" + + @staticmethod + def dps_outs_one_arg() -> str: + return rf"outs\({MlirRegex.ssa_value} : (?P{MlirRegex.tensor_type})\)" + + +def read_input_mlir(filename: str) -> list[str]: + with open(filename, "r") as f: + return f.readlines() + + +def get_mmt_tile_sizes(configuration: Configuration): + return configuration.tile_sizes + + +@dataclass +class ConvDimInfo: + n: int + oh: int + ow: int + oc: int + fh: int + fw: int + ic: int + + @staticmethod + def from_rhs_res(rhs_shaped_type: ShapedType, res_shaped_type: ShapedType): + n, oh, ow, oc = res_shaped_type.shape + fh, fw, ic, _ = rhs_shaped_type.shape + return ConvDimInfo(n, oh, ow, oc, fh, fw, ic) + + @staticmethod + def from_problem_size(problem_size: ProblemSize): + return ConvDimInfo.from_rhs_res(problem_size.rhs_type, problem_size.res_type) + + +def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> list[int]: + m, n, k = configuration.tile_sizes + tile_size = [1] * len(tile_dims) + for idx, dim in enumerate(tile_dims): + if dim == "m": + tile_size[idx] = m + if dim == "n": + tile_size[idx] = n + if dim == "k": + tile_size[idx] = k + return tile_size + + +def get_batch_mmt_tile_sizes(configuration: Configuration) -> list[int]: + return [1] + configuration.tile_sizes + + +def get_pipeline_config(configuration: Configuration) -> str: + extra_config = ", prefetch_shared_memory" + if configuration.waves_per_eu != 2: + extra_config += f', llvm_func_attrs = {{"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"}}' + return extra_config + + +def apply_configuration( + template: list[str], configuration: Configuration, tile_sizes: list[int] +) -> str: + 
tune_logger.info(f"Applying: {configuration}") + expr0 = re.compile( + r", subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>" + ) + expr1 = re.compile( + r"LLVMGPUVectorDistribute workgroup_size = \[.+\] subgroup_size = ([0-9]+)," + ) + expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]") + expr3 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") + repl0 = f", subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>" + repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},' + repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]' + repl3 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"' + + new_mlir = "" + for line in template: + if "intrinsic =" in line: + line = re.sub(expr0, repl0, line) + if "LLVMGPUVectorDistribute " in line: + line = re.sub(expr1, repl1, line) + if "tile_sizes" in line: + line = re.sub(expr2, repl2, line) + if "amdgpu-waves-per-eu" in line: + line = re.sub(expr3, repl3, line) + new_mlir += line + + return new_mlir + + +def parse_tensor_type(tensor_type: str) -> ShapedType: + shape_match = re.search(MlirRegex.tensor_type, tensor_type) + assert shape_match + + shape_str = shape_match.group(1) + dims_and_elem = shape_str.split("x") + dims = [int(x) for x in dims_and_elem[:-1]] + elem = dims_and_elem[-1] + str_to_elem_ty = {x.name: x for x in ElementType} + return ShapedType(dims, str_to_elem_ty[elem]) + + +def get_compatible_mfma_intrinsics(problem_size: ProblemSize) -> list[MfmaIntrinsic]: + def is_compatible(intrinsic: MfmaIntrinsic) -> bool: + if problem_size.res_type.element_type != intrinsic.output_type: + return False + if problem_size.dispatch_kind != DispatchKind.batch_matmul: + if problem_size.lhs_type.element_type != intrinsic.input_type: + return False + if problem_size.rhs_type.element_type != intrinsic.input_type: + return False + return True + + return list(filter(is_compatible, MfmaIntrinsic.all())) + + +def get_mfma_intrinsic_constraints( + problem_size: ProblemSize, + intrinsic_m: z3.ArithRef, + intrinsic_n: z3.ArithRef, + intrinsic_k: z3.ArithRef, +) -> z3.BoolRef: + compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size) + assert len(compatible_intrinsics) > 0, "No compatible intrinsics found" + return z3.Or( + *( + z3.And(intrinsic_m == mfma.m, intrinsic_n == mfma.n, intrinsic_k == mfma.k) + for mfma in compatible_intrinsics + ) + ) + + +def get_dispatch_constraints( + problem_size: ProblemSize, + tile_m: z3.ArithRef, + tile_n: z3.ArithRef, + tile_k: z3.ArithRef, +) -> list[z3.BoolRef]: + if problem_size.dispatch_kind != DispatchKind.conv: + return [] + + dim_info = ConvDimInfo.from_problem_size(problem_size) + conv_constraints = [] + # WARNING: This sometimes makes the constraints UNSAT for some reason. 
+ conv_constraints += [tile_m <= dim_info.ow] + conv_constraints += [tile_n <= dim_info.oc] + conv_constraints += [tile_k <= dim_info.ic] + return conv_constraints + + +def calculate_shared_memory_usage_in_bytes( + problem_size: ProblemSize, + m: int | z3.ArithRef, + n: int | z3.ArithRef, + k: int | z3.ArithRef, +) -> int | z3.ArithRef: + lhs_memory = m * k * (problem_size.lhs_type.bitwidth // 8) + rhs_memory = k * n * (problem_size.rhs_type.bitwidth // 8) + return lhs_memory + rhs_memory + + +def generate_constraints( + problem_size: ProblemSize, + tile_sizes, + num_subgroups, + subgroup_size, + intrinsic_size, + workgroup_size, + subgroup_m_count, + subgroup_n_count, + waves_per_eu, +): + M, N, K = ( + problem_size.matmul_size.M, + problem_size.matmul_size.N, + problem_size.matmul_size.K, + ) + m, n, k = tile_sizes + intrinsic_mn, intrinsic_k = intrinsic_size + wg_x, wg_y, wg_z = workgroup_size + wg_threads = z3.Int("wg_threads") + constraints = [wg_threads == wg_x * wg_y * wg_z] + constraints += [subgroup_size == 64, wg_threads <= 1024] + constraints += [ + get_mfma_intrinsic_constraints( + problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k + ) + ] + subgroup_k_count = 1 + constraints += [ + m >= intrinsic_mn, + m <= 512, + m <= M, + ] + constraints += [n >= intrinsic_mn, n <= 512, n <= N, N % n == 0] + constraints += [k >= intrinsic_k, k <= 512, k <= K, K % k == 0] + for x in (subgroup_m_count, subgroup_n_count): + constraints += [x >= 1, x <= 32] + + subgroup_m_tile_count = z3.Int("sg_m_tcnt") + subgroup_n_tile_count = z3.Int("sg_n_tcnt") + subgroup_k_tile_count = z3.Int("sg_k_tcnt") + for x in (subgroup_m_tile_count, subgroup_n_tile_count, subgroup_k_tile_count): + constraints += [x >= 1, x <= 32] + + constraints += [m == subgroup_m_count * subgroup_m_tile_count * intrinsic_mn] + constraints += [n == subgroup_n_count * subgroup_n_tile_count * intrinsic_mn] + constraints += [k == subgroup_k_count * subgroup_k_tile_count * intrinsic_k] + constraints += [wg_x == subgroup_size * subgroup_n_count] + constraints += [wg_y == subgroup_m_count] + constraints += [wg_z == subgroup_k_count] + constraints += [z3.Or(wg_x <= n, wg_x <= m)] + constraints += [k % intrinsic_mn == 0] + constraints += [(k * n) % wg_threads == 0] + constraints += [(k * m) % wg_threads == 0] + subgroups = subgroup_m_count * subgroup_n_count + if num_subgroups > 0: + constraints += [subgroups == num_subgroups] + else: + constraints += [subgroups >= 1, subgroups <= 10] + + constraints += [waves_per_eu == 2] + # constraints += [z3.Or(waves_per_eu == 2, waves_per_eu == 3, waves_per_eu == 4)] + + shared_memory = calculate_shared_memory_usage_in_bytes(problem_size, m, n, k) + constraints += [shared_memory <= 65536] + + constraints += get_dispatch_constraints(problem_size, m, n, k) + + return constraints + + +def generate_solutions(problem_size: ProblemSize, num_subgrups: int): + M, N, K = problem_size.MNK + tune_logger.info(f"{M},{N},{K}") + m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k") + subgroup_size = z3.Int("subgroup_size") + intrinsic_mn = z3.Int("intrinsic_mn") + intrinsic_k = z3.Int("intrinsic_k") + wg_x, wg_y, wg_z = z3.Int("wg_x"), z3.Int("wg_y"), z3.Int("wg_z") + sg_m_cnt = z3.Int("sg_m_cnt") + sg_n_cnt = z3.Int("sg_n_cnt") + waves_per_eu = z3.Int("waves_per_eu") + all_vars = [ + m, + n, + k, + subgroup_size, + intrinsic_mn, + intrinsic_k, + wg_x, + wg_y, + wg_z, + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + ] + + solver = z3.Solver() + constraints = generate_constraints( + problem_size, + [m, n, k], + num_subgrups, 
+        subgroup_size,
+        [intrinsic_mn, intrinsic_k],
+        [wg_x, wg_y, wg_z],
+        sg_m_cnt,
+        sg_n_cnt,
+        waves_per_eu,
+    )
+    solver.add(z3.simplify(z3.And(constraints)))
+    tune_logger.debug(f"Initial constraints: {solver}")
+    i = 0
+    while solver.check() == z3.sat:
+        model = solver.model()
+        lookup = lambda var: model[var].as_long()
+
+        config = Configuration(
+            lookup(subgroup_size),
+            [lookup(wg_x), lookup(wg_y), lookup(wg_z)],
+            MfmaIntrinsic(
+                problem_size.lhs_type.element_type,
+                lookup(intrinsic_mn),
+                lookup(intrinsic_mn),
+                lookup(intrinsic_k),
+                problem_size.res_type.element_type,
+            ),
+            [lookup(m), lookup(n), lookup(k)],
+            lookup(sg_m_cnt),
+            lookup(sg_n_cnt),
+            lookup(waves_per_eu),
+        )
+        solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars)))))
+        i += 1
+        yield config
+
+
+def get_default_output_dir() -> str:
+    from datetime import datetime
+
+    return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M")
+
+
+def parse_mlir(mlir_text: str) -> ir.Module:
+    mlir_module = None
+    with ireec.ir.Context() as context:
+        try:
+            mlir_module = ireec.ir.Module.parse(mlir_text)
+            tune_logger.info("MLIR parsing successful!")
+        except ireec.ir.MLIRError as e:
+            tune_logger.error(f"Error parsing MLIR: {e}")
+            raise RuntimeError(f"Error parsing MLIR: {e}")
+
+    return mlir_module
+
+
+@dataclass
+class MLIRTransformation:
+    """Transformation of MLIR context"""
+
+    template: str
+    modified: str
+    embeddable: str
+
+
+class DispatchTuner(ABC):
+    @abstractmethod
+    def supports(self, op_name: str) -> bool:
+        """Check if the tuner can handle the type of operation represented by the input string."""
+        pass
+
+    @abstractmethod
+    def get_shapes(self, template: list[str]) -> ProblemSize:
+        """Extract the problem size of the operation."""
+        pass
+
+    @abstractmethod
+    def apply_params(
+        self,
+        problem_size: ProblemSize,
+        template: list[str],
+        configuration: Configuration,
+    ) -> MLIRTransformation:
+        """Apply parameter transformations to the operation."""
+        pass
+
+
+@dataclass
+class OpWalkResult:
+    was_interrupted: bool = False
+    dispatch_tuner: Optional[DispatchTuner] = None
+
+
+class DispatchTunerRegistry:
+    def __init__(self):
+        self.registry = set()
+
+    def register(self, dispatch_tuners: list[DispatchTuner]) -> None:
+        for dispatch_tuner in dispatch_tuners:
+            self.registry.add(dispatch_tuner)
+
+    def validate_translation(self, attrs: list[ir.NamedAttribute]) -> bool:
+        for attr in attrs:
+            if (attr.name == "translation_info") and (
+                "LLVMGPUVectorDistribute" in str(attr.attr)
+            ):
+                return True
+        assert False, "Translation info not supported"
+
+    def find_handler(self, op_name: str) -> DispatchTuner:
+        for dispatch_tuner in self.registry:
+            if dispatch_tuner.supports(op_name):
+                return dispatch_tuner
+        assert False, "Dispatch kind not supported"
+
+
+class MmtTuner(DispatchTuner):
+    def supports(self, op_name: str) -> bool:
+        return "matmul_transpose_b" in op_name
+
+    def get_shapes(self, template: list[str]) -> ProblemSize:
+        mmt_re = None
+        dps = None
+        for line in template:
+            if "linalg.generic" not in line:
+                continue
+            if r'iterator_types = ["parallel", "parallel", "reduction"]' not in line:
+                continue
+            # ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>)
+            mmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}"
+            dps = re.search(mmt_re, line)
+            if dps is None:
+                continue
+
+            lhs_tensor_type = dps.group("LHS")
+            rhs_tensor_type = dps.group("RHS")
+            lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
+            assert lhs_shaped_type.rank() == 2
+            lhs_M, lhs_K = lhs_shaped_type.shape
+
+            rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
+            assert rhs_shaped_type.rank() == 2
+            rhs_N, rhs_K = rhs_shaped_type.shape
+
+            assert lhs_shaped_type.element_type == rhs_shaped_type.element_type
+            assert lhs_K == rhs_K
+
+            res_tensor_type = dps.group("RES")
+            res_shaped_type = parse_tensor_type(res_tensor_type)
+            assert res_shaped_type.rank() == 2
+            res_M, res_N = res_shaped_type.shape
+
+            assert lhs_M == res_M
+            assert rhs_N == res_N
+
+            matmul_size = MatmulSize(
+                lhs_shaped_type.shape[0],
+                rhs_shaped_type.shape[0],
+                lhs_shaped_type.shape[1],
+            )
+            return ProblemSize(
+                matmul_size,
+                lhs_type=lhs_shaped_type,
+                rhs_type=rhs_shaped_type,
+                res_type=res_shaped_type,
+                dispatch_kind=DispatchKind.mmt,
+            )
+        assert mmt_re
+        assert dps, f"'{mmt_re}' not found in given context"
+
+    def get_transform_function_mmt(
+        self, problem_size: ProblemSize, functionName: str, configuration: Configuration
+    ) -> str:
+        tile_sizes = ", ".join(map(str, get_mmt_tile_sizes(configuration)))
+
+        wg_x, wg_y, wg_z = configuration.workgroup_size
+        extra_config = get_pipeline_config(configuration)
+
+        return f"""
+    transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
+        %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
+        %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
+        %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
+        transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value
+        transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
+        %config = transform.param.constant #iree_codegen.compilation_info<
+            lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
+            translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+                workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
+                {{mma_schedule = #iree_gpu.mma_schedule<
+                    intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
+                    subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
+            {extra_config}}}>
+        > -> !transform.any_param
+        transform.yield %matmul, %config : !transform.any_op, !transform.any_param
+    }}
+    """
+
+    def apply_params(
+        self,
+        problem_size: ProblemSize,
+        template: list[str],
+        configuration: Configuration,
+    ) -> MLIRTransformation:
+        M, N, K = problem_size.MNK
+        modified = indent(
+            self.get_transform_function_mmt(
+                problem_size, f"match_mmt_{M}x{N}x{K}", configuration
+            ),
+            "// ",
+        )
+        modified += apply_configuration(
+            template, configuration, get_mmt_tile_sizes(configuration)
+        )
+        embeddable = indent(
+            self.get_transform_function_mmt(problem_size, f"match_op", configuration),
+            "  ",
+        )
+        return MLIRTransformation(template, modified, embeddable)
+
+
+class ConvTuner(DispatchTuner):
+    def supports(self, op_name: str) -> bool:
+        return "conv_2d_nhwc_hwcf" in op_name
+
+    def get_conv_tile_sizes(self, configuration: Configuration) -> list[int]:
+        m, n, k = configuration.tile_sizes
+        batch = 1
+        fh = 1
+        fw = 1
+
+        oh = 1
+
+        oc = n
+        ow = m
+        ic = k
+        return [batch, oh, ow, oc, fh, fw, ic]
+
+    def get_shapes(self, template: list[str]) -> ProblemSize:
+        for line in template:
+            if "linalg.conv_2d_nhwc_hwcf" not in line:
+                continue
+
+            # ins(%19, %20 : tensor<2x34x34x1280xf16>, tensor<3x3x1280x1280xf16>) outs (%27 : tensor<2x32x32x1280xf32>)
+            conv_re = (
+                rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}"
+            )
+            dps = re.search(conv_re, line)
+            if dps is None:
+                continue
+            lhs_tensor_type = dps.group("LHS")
+            rhs_tensor_type = dps.group("RHS")
+            lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
+            assert lhs_shaped_type.rank() == 4
+
+            rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
+            assert rhs_shaped_type.rank() == 4
+
+            res_tensor_type = dps.group("RES")
+            res_shaped_type = parse_tensor_type(res_tensor_type)
+            assert res_shaped_type.rank() == 4
+
+            # int64_t n = outputShape[0];
+            # int64_t oh = outputShape[1];
+            # int64_t ow = outputShape[2];
+            # int64_t oc = outputShape[3];
+            # int64_t fh = filterShape[0];
+            # int64_t fw = filterShape[1];
+            # int64_t ic = filterShape[2];
+            dim_info = ConvDimInfo.from_rhs_res(rhs_shaped_type, res_shaped_type)
+            return ProblemSize(
+                MatmulSize(
+                    M=dim_info.oh * dim_info.ow,
+                    N=dim_info.oc,
+                    K=dim_info.fh * dim_info.fw * dim_info.ic,
+                    B=dim_info.n,
+                ),
+                lhs_shaped_type,
+                rhs_shaped_type,
+                res_shaped_type,
+                DispatchKind.conv,
+            )
+
+        assert False, "Shape not found"
+
+    # int64_t n = outputShape[0];
+    # int64_t oh = outputShape[1];
+    # int64_t ow = outputShape[2];
+    # int64_t oc = outputShape[3];
+    # int64_t fh = filterShape[0];
+    # int64_t fw = filterShape[1];
+    # int64_t ic = filterShape[2];
+    def get_transform_function_conv(
+        self, problem_size: ProblemSize, functionName: str, configuration: Configuration
+    ) -> str:
+        dynamic_batch_input_ty = problem_size.lhs_type
+        dynamic_batch_input_ty.shape = dynamic_batch_input_ty.shape.copy()
+        dynamic_batch_input_ty.shape[0] = -1
+
+        dynamic_batch_output_ty = problem_size.res_type
+        dynamic_batch_output_ty.shape = dynamic_batch_output_ty.shape.copy()
+        dynamic_batch_output_ty.shape[0] = -1
+
+        input = f"tensor<{dynamic_batch_input_ty}>"
+        filter = f"tensor<{problem_size.rhs_type}>"
+        output = f"tensor<{dynamic_batch_output_ty}>"
+
+        tile_sizes = ", ".join(map(str, self.get_conv_tile_sizes(configuration)))
+
+        wg_x, wg_y, wg_z = configuration.workgroup_size
+        extra_config = get_pipeline_config(configuration)
+
+        return f"""
+    transform.named_sequence @{functionName}(%conv: !transform.any_op {{transform.readonly}})
+      -> (!transform.any_op, !transform.any_param) {{
+        %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv {{
+        ^bb0(%lhs: {input}, %rhs: {filter}, %out: {output}):
+            %13 = linalg.conv_2d_nhwc_hwcf {{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}}
+                ins(%lhs, %rhs : {input}, {filter})
+                outs(%out : {output}) -> {output}
+        }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+        %config = transform.param.constant #iree_codegen.compilation_info<
+            lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
+            translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+                workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
+                {{mma_schedule = #iree_gpu.mma_schedule<
+                    intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
+                    subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
+            {extra_config}}}>
+        > -> !transform.any_param
+        transform.yield %conv, %config : !transform.any_op, !transform.any_param
+    }}
+    """
+
+    def apply_params(
+        self,
+        problem_size: ProblemSize,
+        template: list[str],
+        configuration: Configuration,
+    ) -> MLIRTransformation:
+        conv_dims = ConvDimInfo.from_problem_size(problem_size)
+        modified = indent(
+            self.get_transform_function_conv(
+                problem_size,
+                f"match_conv_2d_nhwc_hwcf_Bx{conv_dims.oh}x{conv_dims.ow}x{conv_dims.oc}x{conv_dims.fh}x{conv_dims.fw}x{conv_dims.ic}",
+                configuration,
+            ),
+            "// ",
+        )
+        modified += apply_configuration(
+            template, configuration, self.get_conv_tile_sizes(configuration)
+        )
+        embeddable = indent(
+            self.get_transform_function_conv(problem_size, f"match_op", configuration),
+            "  ",
+        )
+        return MLIRTransformation(template, modified, embeddable)
+
+
+class ContractionTuner(DispatchTuner):
+    def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str):
+        self.lhs_dims = lhs_dims
+        self.rhs_dims = rhs_dims
+        self.tile_dims = tile_dims
+
+    def supports(self, op_name: str) -> bool:
+        return "matmul_like" in op_name
+
+    def is_broadcast_rhs_mmt_op(self, line: str) -> bool:
+        if "linalg.generic" not in line:
+            return False
+        if (
+            r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]'
+            not in line
+        ):
+            return False
+        if (
+            r"indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>"
+            not in line
+        ):
+            return False
+        return True
+
+    def is_broadcast_rhs_mmt(self, template: list[str]) -> bool:
+        return any(self.is_broadcast_rhs_mmt_op(line) for line in template)
+
+    def get_shapes_broadcast_rhs_mmt(self, template: list[str]) -> ProblemSize:
+        for line in template:
+            if not self.is_broadcast_rhs_mmt_op(line):
+                continue
+
+            # ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>)
+            bmmt_re = (
+                rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}"
+            )
+            dps = re.search(bmmt_re, line)
+            if dps is None:
+                continue
+
+            lhs_tensor_type = dps.group("LHS")
+            rhs_tensor_type = dps.group("RHS")
+            lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
+            assert lhs_shaped_type.rank() == 3
+
+            rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
+            assert rhs_shaped_type.rank() == 2
+
+            res_tensor_type = dps.group("RES")
+            res_shaped_type = parse_tensor_type(res_tensor_type)
+            assert res_shaped_type.rank() == 3
+
+            B0, M0, K0 = lhs_shaped_type.shape
+            N1, K1 = rhs_shaped_type.shape
+            B2, M2, N2 = res_shaped_type.shape
+            assert B0 == B2
+            assert M0 == M2
+            assert N1 == N2
+            assert K0 == K1
+            return ProblemSize(
+                MatmulSize(M0, N1, K0, B0),
+                lhs_shaped_type,
+                rhs_shaped_type,
+                res_shaped_type,
+                DispatchKind.broadcast_rhs_mmt,
+            )
+
+        assert False, "Shape not found"
+
+    def get_shapes(self, template: list[str]) -> ProblemSize:
+        if self.is_broadcast_rhs_mmt(template):
+            return self.get_shapes_broadcast_rhs_mmt(template)
+
+        for line in template:
+            if "linalg.generic" not in line:
+                continue
+            if "lowering_config =" not in line:
+                continue
+            if '"reduction"' not in line:
+                continue
+
+            # ins(%7, %8 : tensor<2x1024x1280xf16>, tensor<20x64x1280xf16>)
+            cont_re = (
+                rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}"
+            )
+            dps = re.search(cont_re, line)
+            if dps is None:
+                continue
+
+            lhs_tensor_type = dps.group("LHS")
+            rhs_tensor_type = dps.group("RHS")
+            lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
+            assert lhs_shaped_type.rank() == len(self.lhs_dims)
+
+            rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
+            assert rhs_shaped_type.rank() == len(self.rhs_dims)
+
+            res_tensor_type = dps.group("RES")
+            res_shaped_type = parse_tensor_type(res_tensor_type)
+            assert res_shaped_type.rank() >= 2
+
+            M = math.prod(
+                val if dim == "m" else 1
+                for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape)
+            )
+            N = math.prod(
+                val if dim == "n" else 1
+                for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape)
+            )
+            K0 = math.prod(
+                val if dim == "k" else 1
+                for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape)
+            )
+            K1 = math.prod(
+                val if dim == "k" else 1
+                for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape)
+            )
+            assert K0 == K1
+
+            return ProblemSize(
+                MatmulSize(M, N, K0),
+                lhs_type=lhs_shaped_type,
+                rhs_type=rhs_shaped_type,
+                res_type=res_shaped_type,
+                dispatch_kind=DispatchKind.contraction,
+            )
+
+        assert False, "Shape not found"
+
+    def get_transform_function_broadcast_rhs_mmt(
+        self,
+        problem_size: ProblemSize,
+        functionName: str,
+        configuration: Configuration,
+    ) -> str:
+        tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration)))
+
+        wg_x, wg_y, wg_z = configuration.workgroup_size
+        extra_config = get_pipeline_config(configuration)
+
+        lhs_dynamic_batch = problem_size.lhs_type
+        lhs_dynamic_batch.shape = lhs_dynamic_batch.shape.copy()
+        lhs_dynamic_batch.shape[0] = -1
+
+        return f"""
+transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
+%mmt = transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op
+%lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value
+%rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value
+transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : !transform.any_value
+transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
+%config = transform.param.constant #iree_codegen.compilation_info<
+    lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
+    translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+        workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
+        {{mma_schedule = #iree_gpu.mma_schedule<
+            intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
+            subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
+    {extra_config}}}>
+    > -> !transform.any_param
+transform.yield %generic, %config : !transform.any_op, !transform.any_param
+}}
+"""
+
+    def apply_params_broadcast_rhs_mmt(
+        self,
+        problem_size: ProblemSize,
+        template: list[str],
+        configuration: Configuration,
+    ) -> MLIRTransformation:
+        M, N, K = problem_size.MNK
+        modified = indent(
+            self.get_transform_function_broadcast_rhs_mmt(
+                problem_size, f"match_broadcast_rhs_mmt_Bx{M}x{N}x{K}", configuration
+            ),
+            "// ",
+        )
+        modified += apply_configuration(
+            template, configuration, get_batch_mmt_tile_sizes(configuration)
+        )
+
+        embeddable = indent(
+            self.get_transform_function_broadcast_rhs_mmt(
+                problem_size, f"match_op", configuration
+            ),
+            "  ",
+        )
+        return MLIRTransformation(template, modified, embeddable)
+
+    def apply_params(
+        self,
+        problem_size: ProblemSize,
+        template: list[str],
+        configuration: Configuration,
+    ) -> MLIRTransformation:
+        if self.is_broadcast_rhs_mmt(template):
+            return self.apply_params_broadcast_rhs_mmt(
+                problem_size, template, configuration
+            )
+
+        # TODO: Generate transform function.
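+        # Until a dedicated matcher is generated here, only the annotated
+        # module is produced; the embeddable spec is left empty ("" below).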
+        return MLIRTransformation(
+            template,
+            apply_configuration(
+                template,
+                configuration,
+                get_contract_tile_sizes(configuration, self.tile_dims),
+            ),
+            "",
+        )
+
+
+class BatchMmtTuner(DispatchTuner):
+    def supports(self, op_name: str) -> bool:
+        return "batch_matmul_transpose_b" in op_name
+
+    def get_shapes(self, template: list[str]) -> ProblemSize:
+        for line in template:
+            if "linalg.generic" not in line:
+                continue
+            if (
+                r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]'
+                not in line
+            ):
+                continue
+            # ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>)
+            bmmt_re = (
+                rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}"
+            )
+            dps = re.search(bmmt_re, line)
+            if dps is None:
+                continue
+
+            lhs_tensor_type = dps.group("LHS")
+            rhs_tensor_type = dps.group("RHS")
+            lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
+            assert lhs_shaped_type.rank() == 3
+
+            rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
+            assert rhs_shaped_type.rank() == 3
+
+            res_tensor_type = dps.group("RES")
+            res_shaped_type = parse_tensor_type(res_tensor_type)
+            assert res_shaped_type.rank() == 3
+
+            B0, M0, K0 = lhs_shaped_type.shape
+            B1, N1, K1 = rhs_shaped_type.shape
+            B2, M2, N2 = res_shaped_type.shape
+            assert B0 == B1
+            assert B0 == B2
+            assert M0 == M2
+            assert N1 == N2
+            assert K0 == K1
+            return ProblemSize(
+                MatmulSize(M0, N1, K0, B0),
+                lhs_shaped_type,
+                rhs_shaped_type,
+                res_shaped_type,
+                DispatchKind.batch_mmt,
+            )
+
+        assert False, "Shape not found"
+
+    def get_transform_function_batch_mmt(
+        self,
+        problem_size: ProblemSize,
+        functionName: str,
+        configuration: Configuration,
+    ) -> str:
+        tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration)))
+
+        wg_x, wg_y, wg_z = configuration.workgroup_size
+        extra_config = get_pipeline_config(configuration)
+
+        return f"""
+transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
+%mmt = transform.include @match_batch_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op
+%lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value
+%rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value
+transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value
+transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
+%config = transform.param.constant #iree_codegen.compilation_info<
+    lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
+    translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+        workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
+        {{mma_schedule = #iree_gpu.mma_schedule<
+            intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
+            subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
+    {extra_config}}}>
+    > -> !transform.any_param
+transform.yield %generic, %config : !transform.any_op, !transform.any_param
+}}
+"""
+
+    def apply_params(
+        self,
+        problem_size: ProblemSize,
+        template: list[str],
+        configuration: Configuration,
+    ) -> MLIRTransformation:
+        M, N, K = problem_size.MNK
+        B = problem_size.matmul_size.B
+        modified = indent(
+            self.get_transform_function_batch_mmt(
+                problem_size, f"match_batch_mmt_{B}x{M}x{N}x{K}", configuration
+            ),
+            "// ",
+        )
+        modified += apply_configuration(
+            template, configuration, get_batch_mmt_tile_sizes(configuration)
+        )
+
+        embeddable = indent(
+            self.get_transform_function_batch_mmt(
+                problem_size, f"match_op", configuration
+            ),
+            "  ",
+        )
+        return MLIRTransformation(template, modified, embeddable)
+
+
+class BatchMatmulTuner(DispatchTuner):
+    def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str):
+        self.lhs_dims = lhs_dims
+        self.rhs_dims = rhs_dims
+        self.tile_dims = tile_dims
+
+    def supports(self, op_name: str) -> bool:
+        return "batch_matmul" in op_name
+
+    def get_shapes(self, template: list[str]) -> ProblemSize:
+        for line in template:
+            if "linalg.batch_matmul" not in line:
+                continue
+            # ins(%9, %10 : tensor<64x72x1280xf16>, tensor<64x1280x1280xf16>)
+            # outs(%12 : tensor<64x72x1280xf32>)
+            cont_re = (
+                rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}"
+            )
+            dps = re.search(cont_re, line)
+            if dps is None:
+                continue
+
+            lhs_tensor_type = dps.group("LHS")
+            rhs_tensor_type = dps.group("RHS")
+            lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
+            assert lhs_shaped_type.rank() == len(self.lhs_dims)
+
+            rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
+            assert rhs_shaped_type.rank() == len(self.rhs_dims)
+
+            res_tensor_type = dps.group("RES")
+            res_shaped_type = parse_tensor_type(res_tensor_type)
+            assert res_shaped_type.rank() == lhs_shaped_type.rank()
+
+            LHS = lhs_shaped_type.shape
+            RHS = rhs_shaped_type.shape
+            RES = res_shaped_type.shape
+
+            B = math.prod(
+                val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, LHS)
+            )
+            B0 = math.prod(
+                val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RHS)
+            )
+            B1 = math.prod(
+                val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RES)
+            )
+            M = math.prod(
+                val if dim == "m" else 1 for dim, val in zip(self.lhs_dims, LHS)
+            )
+            N = math.prod(
+                val if dim == "n" else 1 for dim, val in zip(self.rhs_dims, RHS)
+            )
+            K0 = math.prod(
+                val if dim == "k" else 1 for dim, val in zip(self.lhs_dims, LHS)
+            )
+            K1 = math.prod(
+                val if dim == "k" else 1 for dim, val in zip(self.rhs_dims, RHS)
+            )
+            assert B == B0 and B == B1
+            assert K0 == K1
+
+            return ProblemSize(
+                MatmulSize(M, N, K0, B),
+                lhs_type=lhs_shaped_type,
+                rhs_type=rhs_shaped_type,
+                res_type=res_shaped_type,
+                dispatch_kind=DispatchKind.batch_matmul,
+            )
+
+        assert False, "Shape not found"
+
+    def get_transform_function_batch_matmul(
+        self,
+        problem_size: ProblemSize,
+        tile_dims: str,
+        functionName: str,
+        configuration: Configuration,
+    ) -> str:
+        input0 = f"tensor<{problem_size.lhs_type}>"
+        input1 = f"tensor<{problem_size.rhs_type}>"
+        output = f"tensor<{problem_size.res_type}>"
+
+        tile_sizes = ", ".join(
+            map(str, get_contract_tile_sizes(configuration, tile_dims))
+        )
+
+        wg_x, wg_y, wg_z = configuration.workgroup_size
+        extra_config = get_pipeline_config(configuration)
+
+        return f"""
+    transform.named_sequence @{functionName}(%batch_matmul: !transform.any_op {{transform.readonly}})
+      -> (!transform.any_op, !transform.any_param) {{
+        %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %batch_matmul {{
+        ^bb0(%lhs: {input0}, %rhs: {input1}, %out: {output}):
+            %13 = linalg.batch_matmul
+                ins(%lhs, %rhs : {input0}, {input1})
+                outs(%out : {output}) -> {output}
+        }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+        %config = transform.param.constant #iree_codegen.compilation_info<
+            lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
+            translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+                workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
+                {{mma_schedule = #iree_gpu.mma_schedule<
+                    intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
+                    subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
+            {extra_config}}}>
+        > -> !transform.any_param
+        transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param
+    }}
+    """
+
+    def apply_params(
+        self,
+        problem_size: ProblemSize,
+        template: list[str],
+        configuration: Configuration,
+    ) -> MLIRTransformation:
+        M, N, K = problem_size.MNK
+        modified = indent(
+            self.get_transform_function_batch_matmul(
+                problem_size,
+                self.tile_dims,
+                f"match_batch_matmul_{problem_size.matmul_size.B}x{M}x{N}x{K}",
+                configuration,
+            ),
+            "// ",
+        )
+        modified += apply_configuration(
+            template,
+            configuration,
+            get_contract_tile_sizes(configuration, self.tile_dims),
+        )
+
+        embeddable = indent(
+            self.get_transform_function_batch_matmul(
+                problem_size, self.tile_dims, f"match_op", configuration
+            ),
+            "  ",
+        )
+        return MLIRTransformation(template, modified, embeddable)
+
+
+def walk_callback_get_fn(
+    op: ir.Operation,
+    walk_result: OpWalkResult,
+    dispatch_tuner_registry: DispatchTunerRegistry,
+) -> ir.WalkResult:
+    if op.name == "func.func":
+        dispatch_tuner_registry.validate_translation([a for a in op.opview.attributes])
+    if op.name == "util.func":
+        func_name = str(op.opview.sym_name)
+        walk_result.was_interrupted = True
+        walk_result.dispatch_tuner = dispatch_tuner_registry.find_handler(func_name)
+        return ir.WalkResult.INTERRUPT
+    return ir.WalkResult.ADVANCE
+
+
+def walk_mlir_op(
+    mlir_module: ir.Module,
+    dispatch_tuner_registry: DispatchTunerRegistry,
+) -> OpWalkResult:
+    walk_result = OpWalkResult()
+    for op in mlir_module.body.operations:
+        op.walk(
+            lambda op: walk_callback_get_fn(op, walk_result, dispatch_tuner_registry),
+            ir.WalkOrder.POST_ORDER,
+        )
+        if walk_result.was_interrupted:
+            break
+    return walk_result
+
+
+def tune(
+    input: str,  # Path to the mlir file to be tuned
+    output: str = "",  # Path to the output directory, auto creates one if not given
+    limit: int = 4096,  # Max candidates to be generated
+    num_subgroups: int = 4,  # GPU spec, used to determine candidate generation constraints
+    lhs_dims: str = "mk",  # Dimensions for the left-hand side operand in matrix operations
+    rhs_dims: str = "nk",  # Dimensions for the right-hand side operand in matrix operations
+    tile_dims: str = "mnk",  # Dimensions for the tile size
+):
+    input_file = str(input)
+
+    if not output:
+        output = get_default_output_dir()
+
+    # Create the directory if it does not exist
+    makedirs(str(output), exist_ok=True)
+
+    tune_logger.debug(f"Output directory {output}")
+    tune_logger.debug(f"Processing {input_file}")
+    mlir_template = read_input_mlir(input_file)
+    mlir_text = "".join(mlir_template)
+
+    mlir_module = parse_mlir(mlir_text)
+    # Save the input file as the first candidate.
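+    # (Candidate 0 is the unmodified dispatch, which gives later benchmark
+    # runs a baseline to compare the tuned candidates against.)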
+    with open(path.join(output, f"0.mlir"), "w") as f:
+        f.write(mlir_text)
+
+    dispatch_tuner_registry = DispatchTunerRegistry()
+    dispatch_tuner_registry.register(
+        [
+            MmtTuner(),
+            ConvTuner(),
+            ContractionTuner(lhs_dims, rhs_dims, tile_dims),
+            BatchMmtTuner(),
+            BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims),
+        ]
+    )
+
+    walk_result = walk_mlir_op(mlir_module, dispatch_tuner_registry)
+
+    dispatch_tuner = walk_result.dispatch_tuner
+    assert dispatch_tuner is not None, "No suitable dispatch tuner found"
+    problem_size = dispatch_tuner.get_shapes(mlir_template)
+    tune_logger.debug(str(problem_size))
+    configs = []
+    for i, config in enumerate(generate_solutions(problem_size, num_subgroups)):
+        if i >= limit:
+            break
+        tune_logger.info(f"Solution #{i+1}: {config}")
+        configs.append(config)
+        tf_mlir = dispatch_tuner.apply_params(problem_size, mlir_template, config)
+
+        with open(path.join(output, f"{i+1}.mlir"), "w") as f:
+            f.write(tf_mlir.modified)
+        with open(path.join(output, f"{i+1}_config.mlir"), "w") as f:
+            f.write(tf_mlir.embeddable)
+
+    with open(path.join(output, "configs.pkl"), "wb") as file:
+        pickle.dump(configs, file)
+
+    tune_logger.info(f"Generated {len(configs)} candidates")
+    tune_logger.info(f"Wrote the configurations to {output}/configs.pkl")
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("input", help="Input mlir file", type=str)
+    parser.add_argument(
+        "-o", "--output", help="Output dir", type=str, default=get_default_output_dir()
+    )
+    parser.add_argument(
+        "-l",
+        "--limit",
+        help="Max number of candidates generated",
+        type=int,
+        default=4096,
+    )
+    parser.add_argument(
+        "--num-subgroups",
+        help="Number of subgroups per workgroup to use. (-1 == unconstrained)",
+        type=int,
+        default=-1,
+    )
+    parser.add_argument(
+        "--lhs-dims", help="Map of LHS matmul dims", type=str, default="mk"
+    )
+    parser.add_argument(
+        "--rhs-dims", help="Map of RHS matmul dims", type=str, default="nk"
+    )
+    parser.add_argument(
+        "--tile-dims", help="Map of tile size matmul dims", type=str, default="mnk"
+    )
+    parser.add_argument(
+        "--verbose", "-v", action="store_true", help="Enable verbose output to stdout"
+    )
+
+    args = parser.parse_args()
+    tune_logger.setLevel(logging.DEBUG if args.verbose else logging.INFO)
+
+    # Create printing formatter for logging info
+    formatter = logging.Formatter("%(message)s")
+
+    # Create a handler to print to console
+    console_handler = logging.StreamHandler()
+    console_handler.setFormatter(formatter)
+    tune_logger.addHandler(console_handler)
+
+    # # Optionally, add a file handler to log to a file
+    # file_handler = logging.FileHandler("tune.log")
+    # file_handler.setFormatter(formatter)
+    # tune_logger.addHandler(file_handler)
+
+    tune(
+        args.input,
+        args.output,
+        args.limit,
+        args.num_subgroups,
+        args.lhs_dims,
+        args.rhs_dims,
+        args.tile_dims,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tuner/candidate_gen_test.py b/tuner/candidate_gen_test.py
new file mode 100644
index 000000000..52b01869b
--- /dev/null
+++ b/tuner/candidate_gen_test.py
@@ -0,0 +1,814 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""
+Usage: python -m pytest candidate_gen_test.py
+"""
+
+import pytest
+import candidate_gen
+
+
+def test_get_shaped_type_element_bitwidth():
+    assert (
+        candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8).bitwidth
+        == 8
+    )
+    assert (
+        candidate_gen.ShapedType([2048], candidate_gen.ElementType.i32).bitwidth == 32
+    )
+    assert (
+        candidate_gen.ShapedType(
+            [2048, 512, 384], candidate_gen.ElementType.f8
+        ).bitwidth
+        == 8
+    )
+    assert (
+        candidate_gen.ShapedType([1, 1], candidate_gen.ElementType.f16).bitwidth == 16
+    )
+
+
+def test_get_shaped_type_to_str():
+    assert (
+        str(candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8))
+        == "1024x2048xi8"
+    )
+    assert (
+        str(candidate_gen.ShapedType([1024], candidate_gen.ElementType.f32))
+        == "1024xf32"
+    )
+    assert (
+        str(candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f16))
+        == "1x2x3xf16"
+    )
+    assert (
+        str(candidate_gen.ShapedType([-1, 2, 3], candidate_gen.ElementType.f16))
+        == "?x2x3xf16"
+    )
+
+
+def test_parse_tensor_type():
+    assert candidate_gen.parse_tensor_type(
+        "tensor<1x2x3xf32>"
+    ) == candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f32)
+    assert candidate_gen.parse_tensor_type(
+        "tensor<123xi8>"
+    ) == candidate_gen.ShapedType([123], candidate_gen.ElementType.i8)
+
+
+def test_get_mmt_tile_sizes():
+    config = candidate_gen.Configuration(
+        subgroup_size=0,
+        workgroup_size=[],
+        intrinsic="",
+        tile_sizes=[128, 320, 32],
+        subgroup_m_count=0,
+        subgroup_n_count=0,
+        waves_per_eu=0,
+    )
+    assert candidate_gen.get_mmt_tile_sizes(config) == [128, 320, 32]
+
+
+def test_get_conv_tile_sizes():
+    config = candidate_gen.Configuration(
+        subgroup_size=64,
+        workgroup_size=[256, 1, 1],
+        intrinsic="#iree_gpu.mma_layout",
+        tile_sizes=[464, 320, 16],
+        subgroup_m_count=1,
+        subgroup_n_count=4,
+        waves_per_eu=1,
+    )
+    assert candidate_gen.ConvTuner().get_conv_tile_sizes(config) == [
+        1,
+        1,
+        464,
+        320,
+        1,
+        1,
+        16,
+    ]
+
+
+def test_get_contract_tile_sizes():
+    config = candidate_gen.Configuration(
+        subgroup_size=32,
+        workgroup_size=[16, 16, 1],
+        intrinsic="",
+        tile_sizes=[4, 8, 16],
+        subgroup_m_count=1,
+        subgroup_n_count=1,
+        waves_per_eu=2,
+    )
+    assert candidate_gen.get_contract_tile_sizes(config, ["m", "n", "k"]) == [4, 8, 16]
+    assert candidate_gen.get_contract_tile_sizes(config, ["n", "m", "k"]) == [8, 4, 16]
+    assert candidate_gen.get_contract_tile_sizes(config, ["k", "n", "m"]) == [16, 8, 4]
+    assert candidate_gen.get_contract_tile_sizes(config, ["k", "k", "k"]) == [
+        16,
+        16,
+        16,
+    ]
+
+
+def test_get_pipeline_config():
+    config1 = candidate_gen.Configuration(
+        subgroup_size=32,
+        workgroup_size=[16, 16, 1],
+        intrinsic="",
+        tile_sizes=[4, 8, 16],
+        subgroup_m_count=1,
+        subgroup_n_count=1,
+        waves_per_eu=2,
+    )
+    config2 = candidate_gen.Configuration(
+        subgroup_size=32,
+        workgroup_size=[16, 16, 1],
+        intrinsic="",
+        tile_sizes=[4, 8, 16],
+        subgroup_m_count=1,
+        subgroup_n_count=1,
+        waves_per_eu=4,
+    )
+    assert candidate_gen.get_pipeline_config(config1) == ", prefetch_shared_memory"
+    assert (
+        candidate_gen.get_pipeline_config(config2)
+        == ', prefetch_shared_memory, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}'
+    )
+
+
+def test_get_shapes_mmt():
+    template = [
+        r"%18 = tensor.empty() : tensor<2048x1280xf32>",
+        r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>",
+        r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {',
+        r"^bb0(%in: f16, %in_0: f16, %out: f32):",
+    ]
+    assert candidate_gen.MmtTuner().get_shapes(template) == candidate_gen.ProblemSize(
+        candidate_gen.MatmulSize(2048, 1280, 1280),
+        candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16),
+        candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16),
+        candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32),
+        candidate_gen.DispatchKind.mmt,
+    )
+
+
+def test_get_shapes_conv():
+    template = [
+        r"%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>",
+        r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>",
+        r"flow.dispatch.tensor.store %8, %2, offsets = [%workgroup_id_z, %workgroup_id_y, 0, %3], sizes = [1, 1, 32, 256], strides = [1, 1, 1, 1] : tensor<1x1x32x256xf32> -> !flow.dispatch.tensor>",
+    ]
+    assert candidate_gen.ConvTuner().get_shapes(template) == candidate_gen.ProblemSize(
+        candidate_gen.MatmulSize(32, 256, 11520),
+        candidate_gen.ShapedType([1, 3, 34, 1280], candidate_gen.ElementType.f16),
+        candidate_gen.ShapedType([3, 3, 1280, 256], candidate_gen.ElementType.f16),
+        candidate_gen.ShapedType([1, 1, 32, 256], candidate_gen.ElementType.f32),
+        candidate_gen.DispatchKind.conv,
+    )
+
+
+def test_get_shapes_contract():
+    template = [
+        r"%18 = tensor.empty() : tensor<2048x1280xf32>",
+        r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>",
+        r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {',
+        r"^bb0(%in: f16, %in_0: f16, %out: f32):",
+    ]
+    assert candidate_gen.ContractionTuner("mk", "nk", "mnk").get_shapes(
+        template
+    ) == candidate_gen.ProblemSize(
+        candidate_gen.MatmulSize(2048, 1280, 1280),
+        candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16),
+        candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16),
+        candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32),
+        candidate_gen.DispatchKind.contraction,
+    )
+
+
+def test_get_shapes_batch_matmul():
+    template = [
+        "%10 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>",
+        "%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>",
+        "flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 32, 32], strides = [1, 1, 1] : tensor<1x32x32xf32> -> !flow.dispatch.tensor>",
+    ]
+    assert candidate_gen.BatchMatmulTuner("bmk", "bkn", "mnk").get_shapes(
+        template
+    ) == candidate_gen.ProblemSize(
+        candidate_gen.MatmulSize(32, 32, 1024, 1),
+        candidate_gen.ShapedType([1, 32, 1024], candidate_gen.ElementType.f32),
+        candidate_gen.ShapedType([1, 1024, 32], candidate_gen.ElementType.f32),
+        candidate_gen.ShapedType([1, 32, 32], candidate_gen.ElementType.f32),
+        candidate_gen.DispatchKind.batch_matmul,
+    )
+
+
+def test_get_shapes_batch_mmt():
+    template = [
+        r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>",
+        r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {',
+        r"flow.dispatch.tensor.store %21, %10, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : tensor<2x4096x640xf16> -> !flow.dispatch.tensor>",
+    ]
+    assert candidate_gen.BatchMmtTuner().get_shapes(
+        template
+    ) == candidate_gen.ProblemSize(
+        candidate_gen.MatmulSize(4096, 640, 640, 2),
+        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8),
+        candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8),
+        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32),
+        candidate_gen.DispatchKind.batch_mmt,
+    )
+
+
+def test_mfma_intrinsic_to_str():
+    assert (
+        str(candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32())
+        == "MFMA_F16_16x16x16_F32"
+    )
+    assert (
+        str(candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32())
+        == "MFMA_I8_32x32x16_I32"
+    )
+
+
+def test_get_compatible_mfma_intrinsics():
+    assert candidate_gen.get_compatible_mfma_intrinsics(
+        candidate_gen.ProblemSize(
+            candidate_gen.MatmulSize(2048, 1280, 1280),
+            candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16),
+            candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16),
+            candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32),
+            candidate_gen.DispatchKind.mmt,
+        )
+    ) == [
+        candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(),
+        candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(),
+    ]
+
+    assert candidate_gen.get_compatible_mfma_intrinsics(
+        candidate_gen.ProblemSize(
+            candidate_gen.MatmulSize(2048, 1280, 1280),
+            candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i8),
+            candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.i8),
+            candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i32),
+            candidate_gen.DispatchKind.mmt,
+        )
+    ) == [
+        candidate_gen.MfmaIntrinsic.mfma_i8_16x16x32_i32(),
+        candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(),
+    ]
+
+    assert candidate_gen.get_compatible_mfma_intrinsics(
+        candidate_gen.ProblemSize(
+            candidate_gen.MatmulSize(968, 320, 640, 64),
+            candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f32),
+            candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f32),
+            candidate_gen.ShapedType([64, 968, 320], candidate_gen.ElementType.f32),
+            candidate_gen.DispatchKind.batch_matmul,
+        )
+    ) == [
+        candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(),
+        candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(),
+    ]
+
+
+def test_generate_solutions():
+    matmul_size = candidate_gen.MatmulSize(2048, 3840, 1280)
+    lhs_type = candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16)
+    rhs_type = candidate_gen.ShapedType([3840, 1280], candidate_gen.ElementType.f16)
+    res_type = candidate_gen.ShapedType([2048, 3840], candidate_gen.ElementType.f32)
+    problem_size = candidate_gen.ProblemSize(
+        matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt
+    )
+    configs = candidate_gen.generate_solutions(problem_size, 4)
+    assert configs is not None
+
+
+def test_calculate_shared_memory_usage_in_bytes():
+    matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024)
+    lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16)
+    rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16)
+    res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32)
+    problem_size = candidate_gen.ProblemSize(
+        matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt
+    )
+    assert (
+        candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128)
+        == 147456
+    )
+
+    lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i8)
+    problem_size = candidate_gen.ProblemSize(
+        matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt
+    )
+    assert (
+        candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128)
+        == 81920
+    )
+
+    rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i32)
+    problem_size = candidate_gen.ProblemSize(
+        matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt
+    )
+    assert (
+        candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 128, 64, 32)
+        == 12288
+    )
+
+
+def test_generate_constraints_valid_input():
+    matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024)
+    lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16)
+    rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16)
+    res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32)
+    problem_size = candidate_gen.ProblemSize(
+        matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt
+    )
+    # Define input parameters as z3 Ints
+    m, n, k = (
+        candidate_gen.z3.Int("m"),
+        candidate_gen.z3.Int("n"),
+        candidate_gen.z3.Int("k"),
+    )
+    subgroup_size = candidate_gen.z3.Int("subgroup_size")
+    intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn")
+    intrinsic_k = candidate_gen.z3.Int("intrinsic_k")
+    wg_x, wg_y, wg_z = (
+        candidate_gen.z3.Int("wg_x"),
+        candidate_gen.z3.Int("wg_y"),
+        candidate_gen.z3.Int("wg_z"),
+    )
+    sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt")
+    sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt")
+    waves_per_eu = candidate_gen.z3.Int("waves_per_eu")
+
+    constraints = candidate_gen.generate_constraints(
+        problem_size,
+        [m, n, k],
+        4,
+        subgroup_size,
+        [intrinsic_mn, intrinsic_k],
+        [wg_x, wg_y, wg_z],
+        sg_m_cnt,
+        sg_n_cnt,
+        waves_per_eu,
+    )
+
+    solver = candidate_gen.z3.Solver()
+    solver.add(constraints)
+
+    # Check if the constraints are satisfiable
+    assert solver.check() == candidate_gen.z3.sat
+
+
+def test_generate_constraints_invalid_input():
+    # Define input parameters that should lead to unsatisfiable constraints
+    matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024)
+    lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16)
+    rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16)
+    res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32)
+    problem_size = candidate_gen.ProblemSize(
+        matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt
+    )
+    m, n, k = (
+        candidate_gen.z3.Int("m"),
+        candidate_gen.z3.Int("n"),
+        candidate_gen.z3.Int("k"),
+    )
+    subgroup_size = candidate_gen.z3.Int("subgroup_size")
+    intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn")
+    intrinsic_k = candidate_gen.z3.Int("intrinsic_k")
+    wg_x, wg_y, wg_z = (
+        candidate_gen.z3.Int("wg_x"),
+        candidate_gen.z3.Int("wg_y"),
+        candidate_gen.z3.Int("wg_z"),
+    )
+    sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt")
+    sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt")
+    waves_per_eu = candidate_gen.z3.Int("waves_per_eu")
+
+    constraints = candidate_gen.generate_constraints(
+        problem_size,
+        [m, n, k],
+        4,
+        subgroup_size,
+        [intrinsic_mn, intrinsic_k],
+        [wg_x, wg_y, wg_z],
+        sg_m_cnt,
+        sg_n_cnt,
+        waves_per_eu,
+    )
+    constraints.append(m > 1000)  # Adding an additional unsatisfiable constraint
+
+    solver = candidate_gen.z3.Solver()
+    solver.add(constraints)
+
+    # Check if the constraints are unsatisfiable
+    assert solver.check() == candidate_gen.z3.unsat
+
+
+def test_apply_params_mmt():
+    mlir_template = [
+        ", subgroup_m_count = 16, subgroup_n_count = 16>",
+        "",
+        '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}',
+    ]
+
+    M, N, K = 2048, 1280, 1280
+
+    config = candidate_gen.Configuration(
+        subgroup_size=16,
+        workgroup_size=[16, 16, 1],
+        intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(),
+        tile_sizes=[8, 8, 8],
+        subgroup_m_count=16,
+        subgroup_n_count=16,
+        waves_per_eu=8,
+    )
+
+    problem_size = candidate_gen.ProblemSize(
+        candidate_gen.MatmulSize(M, N, K),
+        candidate_gen.ShapedType([M, K], candidate_gen.ElementType.f16),
+        candidate_gen.ShapedType([N, K], candidate_gen.ElementType.f16),
+        candidate_gen.ShapedType([M, N], candidate_gen.ElementType.f32),
+        candidate_gen.DispatchKind.mmt,
+    )
+    tf_mlir = candidate_gen.MmtTuner().apply_params(problem_size, mlir_template, config)
+
+    modified = tf_mlir.modified
+    embeddable = tf_mlir.embeddable
+
+    assert modified
+    assert embeddable
+    assert (
+        "intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 16, subgroup_n_count = 16"
+        in modified
+    )
+    assert (
+        "LLVMGPUVectorDistribute workgroup_size = [16, 16, 1] subgroup_size = 16"
+        in modified
+    )
+    assert "tile_sizes = [[8, 8, 8]]" in modified
+    assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "8"}' in modified
+
+
+def test_apply_params_conv():
+    mlir_template = [
+        ", subgroup_m_count = 16, subgroup_n_count = 16>",
+        "",
+        '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}',
+    ]
+
+    n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640
+
+    config = candidate_gen.Configuration(
+        subgroup_size=64,
+        workgroup_size=[256, 1, 1],
+        intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(),
+        tile_sizes=[464, 320, 16],
+        subgroup_m_count=1,
+        subgroup_n_count=4,
+        waves_per_eu=2,
+    )
+
+    problem_size = candidate_gen.ProblemSize(
+        candidate_gen.MatmulSize(oh * ow, oc, fh * fw * ic),
+        candidate_gen.ShapedType(
+            [n, oh + 2, ow + 2, oc], candidate_gen.ElementType.f16
+        ),
+        candidate_gen.ShapedType([fh, fw, ic, oc], candidate_gen.ElementType.f16),
+        candidate_gen.ShapedType([n, oh, ow, oc], candidate_gen.ElementType.f32),
+        candidate_gen.DispatchKind.conv,
+    )
+    tf_mlir = candidate_gen.ConvTuner().apply_params(
+        problem_size, mlir_template, config
+    )
+
+    modified = tf_mlir.modified
+    embeddable = tf_mlir.embeddable
+
+    assert modified
+    assert embeddable
+    assert (
+        "intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 4"
+        in modified
+    )
+    assert (
+        "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64"
+        in modified
+    )
1, 1] subgroup_size = 64" + in modified + ) + assert "tile_sizes = [[1, 1, 464, 320, 1, 1, 16]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified + + +def test_apply_params_contract(): + mlir_template = [ + ", subgroup_m_count = 2, subgroup_n_count = 2>}>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', + ] + + tile_dims = "*mnk" + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(2048, 3840, 1280), + candidate_gen.ShapedType([2, 1024, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([3, 20, 64, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([3, 2, 20, 1024, 64], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.contraction, + ) + + config = candidate_gen.Configuration( + subgroup_size=64, + workgroup_size=[256, 1, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), + tile_sizes=[480, 384, 32], + subgroup_m_count=1, + subgroup_n_count=4, + waves_per_eu=2, + ) + + tf_mlir = candidate_gen.ContractionTuner("mk", "nk", tile_dims).apply_params( + problem_size, mlir_template, config + ) + + new_mlir = tf_mlir.modified + + assert new_mlir + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 1, subgroup_n_count = 4" + in new_mlir + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64" + in new_mlir + ) + assert "tile_sizes = [[1, 480, 384, 32]]" in new_mlir + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in new_mlir + + +def test_apply_params_batch_matmul(): + mlir_template = [ + ", subgroup_m_count = 4, subgroup_n_count = 1>}>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', + ] + + tile_dims = "bmnk" + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(968, 320, 640, 64), + candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([64, 968, 320], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.batch_matmul, + ) + + config = candidate_gen.Configuration( + subgroup_size=64, + workgroup_size=[128, 2, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), + tile_sizes=[416, 320, 128], + subgroup_m_count=2, + subgroup_n_count=2, + waves_per_eu=2, + ) + + tf_mlir = candidate_gen.BatchMatmulTuner("mk", "nk", tile_dims).apply_params( + problem_size, mlir_template, config + ) + + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + + assert modified + assert embeddable + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" + in modified + ) + assert "tile_sizes = [[1, 416, 320, 128]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified + + +def test_apply_params_batch_mmt_float(): + mlir_template = [ + ", subgroup_m_count = 4, subgroup_n_count = 1>}>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', + ] + + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(4096, 640, 640, 2), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.batch_mmt, + ) + + config = candidate_gen.Configuration( + subgroup_size=64, + workgroup_size=[128, 2, 1], + 
+        intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(),
+        tile_sizes=[128, 64, 128],
+        subgroup_m_count=2,
+        subgroup_n_count=2,
+        waves_per_eu=2,
+    )
+
+    tf_mlir = candidate_gen.BatchMmtTuner().apply_params(
+        problem_size, mlir_template, config
+    )
+
+    modified = tf_mlir.modified
+    embeddable = tf_mlir.embeddable
+
+    assert embeddable
+    assert modified
+    assert (
+        "intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 2, subgroup_n_count = 2"
+        in modified
+    )
+    assert (
+        "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64"
+        in modified
+    )
+    assert "tile_sizes = [[1, 128, 64, 128]]" in modified
+    assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified
+
+
+def test_apply_params_batch_mmt_int():
+    mlir_template = [
+        ", subgroup_m_count = 4, subgroup_n_count = 1>}>",
+        "",
+        '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}',
+    ]
+
+    problem_size = candidate_gen.ProblemSize(
+        candidate_gen.MatmulSize(4096, 640, 640, 2),
+        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8),
+        candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8),
+        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32),
+        candidate_gen.DispatchKind.batch_mmt,
+    )
+
+    config = candidate_gen.Configuration(
+        subgroup_size=64,
+        workgroup_size=[128, 2, 1],
+        intrinsic=candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(),
+        tile_sizes=[128, 64, 128],
+        subgroup_m_count=2,
+        subgroup_n_count=2,
+        waves_per_eu=4,
+    )
+
+    tf_mlir = candidate_gen.BatchMmtTuner().apply_params(
+        problem_size, mlir_template, config
+    )
+
+    modified = tf_mlir.modified
+    embeddable = tf_mlir.embeddable
+
+    assert modified
+    assert "// transform.named_sequence @match_batch_mmt_2x4096x640x640(" in modified
+    assert (
+        "intrinsic = #iree_gpu.mma_layout<MFMA_I8_32x32x16_I32>, subgroup_m_count = 2, subgroup_n_count = 2"
+        in modified
+    )
+    assert (
+        "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64"
+        in modified
+    )
+    assert "tile_sizes = [[1, 128, 64, 128]]" in modified
+    assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified
+
+    assert embeddable
+    assert "transform.named_sequence @match_op(" in embeddable
+    assert (
+        "transform.include @match_batch_mmt_i8_i8_i32 failures(propagate)" in embeddable
+    )
+    assert (
+        "transform.iree.match.cast_compatible_type %lhs = tensor<2x4096x640xi8> : !transform.any_value"
+        in embeddable
+    )
+    assert (
+        "transform.iree.match.cast_compatible_type %rhs = tensor<2x640x640xi8> : !transform.any_value"
+        in embeddable
+    )
+    assert (
+        "%config = transform.param.constant #iree_codegen.compilation_info<"
+        in embeddable
+    )
+    assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable
+    assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable
+    assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable
+
+
+def test_apply_params_broadcast_rhs_mmt():
+    mlir_template = [
+        ", subgroup_m_count = 4, subgroup_n_count = 1>}>",
+        "",
+        '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}',
+    ]
+
+    problem_size = candidate_gen.ProblemSize(
+        candidate_gen.MatmulSize(4096, 640, 640, 2),
+        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8),
+        candidate_gen.ShapedType([640, 640], candidate_gen.ElementType.i8),
+        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32),
+        candidate_gen.DispatchKind.broadcast_rhs_mmt,
+    )
+
+    config = candidate_gen.Configuration(
+        subgroup_size=64,
+        workgroup_size=[128, 2, 1],
+        intrinsic=candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(),
+        tile_sizes=[128, 64, 128],
+        subgroup_m_count=2,
+        subgroup_n_count=2,
+        waves_per_eu=4,
+    )
+
+    tf_mlir = candidate_gen.ContractionTuner(
+        "mk", "nk", "mnk"
+    ).apply_params_broadcast_rhs_mmt(problem_size, mlir_template, config)
+
+    modified = tf_mlir.modified
+    embeddable = tf_mlir.embeddable
+
+    assert modified
+    assert (
+        "// transform.named_sequence @match_broadcast_rhs_mmt_Bx4096x640x640("
+        in modified
+    )
+    assert (
+        "intrinsic = #iree_gpu.mma_layout<MFMA_I8_32x32x16_I32>, subgroup_m_count = 2, subgroup_n_count = 2"
+        in modified
+    )
+    assert (
+        "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64"
+        in modified
+    )
+    assert "tile_sizes = [[1, 128, 64, 128]]" in modified
+    assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified
+
+    assert embeddable
+    assert "transform.named_sequence @match_op(" in embeddable
+    assert (
+        "transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate)"
+        in embeddable
+    )
+    assert (
+        "transform.iree.match.cast_compatible_type %lhs = tensor<?x4096x640xi8> : !transform.any_value"
+        in embeddable
+    )
+    assert (
+        "transform.iree.match.cast_compatible_type %rhs = tensor<640x640xi8> : !transform.any_value"
+        in embeddable
+    )
+    assert (
+        "%config = transform.param.constant #iree_codegen.compilation_info<"
+        in embeddable
+    )
+    assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable
+    assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable
+    assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable
+
+
+def test_detect_broadcast_rhs_mmt():
+    mlir_lines = [
+        r"%18 = tensor.empty() : tensor<2x1024x10240xi32>",
+        r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>",
+        r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {',
+    ]
+    assert candidate_gen.ContractionTuner("mk", "nk", "mnk").is_broadcast_rhs_mmt(
+        mlir_lines
+    )
+
+
+def test_parse_mlir():
+    mlir_str = r"""
+    builtin.module {
+        func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+            %0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
+            return %0 : tensor<4xf32>
+        }
+    }
+    """
+    mlir_module = candidate_gen.parse_mlir(mlir_str)
+    assert mlir_module is not None
+    assert isinstance(mlir_module, candidate_gen.ireec._mlir_libs._mlir.ir.Module)
+    assert isinstance(
+        mlir_module.body.operations[0], candidate_gen.ireec.dialects.func.FuncOp
+    )
diff --git a/tuner/requirements-dev.txt b/tuner/requirements-dev.txt
new file mode 100644
index 000000000..51d5b9ba0
--- /dev/null
+++ b/tuner/requirements-dev.txt
@@ -0,0 +1,2 @@
+pre-commit==3.8.0
+virtualenv==20.13.0
diff --git a/tuner/requirements-tuner.txt b/tuner/requirements-tuner.txt
new file mode 100644
index 000000000..f3484c921
--- /dev/null
+++ b/tuner/requirements-tuner.txt
@@ -0,0 +1,4 @@
+pytest==8.2.2
+tqdm==4.66.4
+z3_solver==4.13.0.0
+types-tqdm==4.66.0.20240417

From 28f2a96ecca93c2c49ab55a726f5818bce92346c Mon Sep 17 00:00:00 2001
From: Amily Wu
Date: Wed, 28 Aug 2024 09:31:02 -0500
Subject: [PATCH 22/23] Update path in tuner-ci

---
 .github/workflows/ci-tuner.yml | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)
changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci-tuner.yml b/.github/workflows/ci-tuner.yml index 9855a2779..5de7d4182 100644 --- a/.github/workflows/ci-tuner.yml +++ b/.github/workflows/ci-tuner.yml @@ -30,15 +30,15 @@ jobs: - name: Install dev dependencies run: | python -m pip install --upgrade pip - pip install -r sharktank/tuner/requirements-dev.txt + pip install -r tuner/requirements-dev.txt - name: Install tuner dependencies run: | - pip install -r sharktank/tuner/requirements-tuner.txt + pip install -r tuner/requirements-tuner.txt python -m pip install \ --find-links https://iree.dev/pip-release-links.html \ --upgrade \ iree-compiler iree-runtime - name: Run tuner tests - run: pytest sharktank/tuner/ + run: pytest tuner/ From 213e6bec5135736c325a7a71c3f6228fdcae5038 Mon Sep 17 00:00:00 2001 From: Amily Wu Date: Wed, 28 Aug 2024 09:35:00 -0500 Subject: [PATCH 23/23] Remove old tuner dir --- sharktank/sharktank/tools/tuner/README.md | 67 - .../sharktank/tools/tuner/candidate_gen.py | 1408 ----------------- .../tools/tuner/requirements-dev.txt | 2 - .../tools/tuner/requirements-tuner.txt | 4 - sharktank/tests/tuner/candidate_gen_test.py | 814 ---------- 5 files changed, 2295 deletions(-) delete mode 100644 sharktank/sharktank/tools/tuner/README.md delete mode 100755 sharktank/sharktank/tools/tuner/candidate_gen.py delete mode 100644 sharktank/sharktank/tools/tuner/requirements-dev.txt delete mode 100644 sharktank/sharktank/tools/tuner/requirements-tuner.txt delete mode 100644 sharktank/tests/tuner/candidate_gen_test.py diff --git a/sharktank/sharktank/tools/tuner/README.md b/sharktank/sharktank/tools/tuner/README.md deleted file mode 100644 index 69821496e..000000000 --- a/sharktank/sharktank/tools/tuner/README.md +++ /dev/null @@ -1,67 +0,0 @@ -# IREE dispatch auto-tuning scripts -`libtuner.py` is the core Python script that provides the fundamental functions for the tuning loop. It imports `candidate_gen.py` for candidate generation. To implement the full tuning loop, `libtuner.py` requires a separate Python script that uses the provided `TuningClient` API from `libtuner.py`. - -## Prerequisites -[Optional] Using virtual environments: -```shell -cd tuning -python -m venv .venv -source .venv/bin/activate -``` -Install python dependencies: -```shell -pip install -r ./requirements-tuner.txt -``` -Using the IREE's Python bindings: - - Building with CMake - ```shell - -DIREE_BUILD_PYTHON_BINDINGS=ON \ - -DPython3_EXECUTABLE="$(which python)" - ``` - - Set environment - ```shell - source ../iree-build/.env && export PYTHONPATH - ``` -For more information, refer to the [IREE documentation](https://iree.dev/building-from-source/getting-started/#python-bindings) - -### Overall flow - -1. Symlink all scripts and mlir/irpa files in your build dir. - - Symlink `iree-build-dir/tools` inside `tuning`. - - Symlink ML model MLIR and weights based on `unet.sh`. - -2. Copy the attention/matmul spec as `config.mlir` in the tuning dir. - -3. Temporarily comment out all the existing configs in `config.mlir`. - - Example: - ```mlir - // , @match_mmt_2048x10240x1280 -> @apply_op_config - // , @match_mmt_2048x1280x5120 -> @apply_op_config - // , @match_mmt_2048x1280x1280 -> @apply_op_config - ``` - -4. Compile a baseline unet -```shell -./unet.sh winograd unet.mlir -o unet_baseline.vmfb --iree-hal-dump-executable-files-to=dump-winograd -``` - -5. Find the matmul to tune and copy the `*_benchmark.mlir` file to the build dir. 
-```shell
-cp dump-winograd/*_141_*benchmark.mlir ./141.mlir
-```
-
-6. Run the tuning script.
-    - Example:
-    ```shell
-      python punet_autotune.py 141.mlir --devices=hip://GPU-0,hip://GPU-4 --num-candidates=1024
-    ```
-
-7. Check the winner candidate in `result_summary.log`, find and copy the transform spec.
-
-8. Paste the transform spec into `config.mlir` and uncomment it.
-
-9. Add the match function to the entry point in `config.mlir`
-    - Example:
-    ```mlir
-      @match_something -> @apply_op_config
-    ```
diff --git a/sharktank/sharktank/tools/tuner/candidate_gen.py b/sharktank/sharktank/tools/tuner/candidate_gen.py
deleted file mode 100755
index 8a8315afb..000000000
--- a/sharktank/sharktank/tools/tuner/candidate_gen.py
+++ /dev/null
@@ -1,1408 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-# Given an input dispatch, this code modifies the hyperparameters
-# in the code and runs it.
-
-"""
-Generate candidates by tweaking op configuration for tuning.
-
-It can be invoked in two ways:
-    1. From another python script, import and call `tune()`
-    2. Run this script directly from the command line
-
-Usage: ./candidate_gen.py 121.mlir -o "tuning/candidates" -l 1024 --lhs-dims=mk --rhs-dims=nk --tile-dims=mnk
-
-"""
-
-import argparse
-import logging
-import math
-import pickle
-import re
-import z3
-from dataclasses import asdict, dataclass
-from enum import Enum
-from os import mkdir, path, makedirs
-from typing import Callable, Optional
-from textwrap import indent
-from abc import ABC, abstractmethod
-
-import iree.compiler as ireec
-from iree.compiler import ir
-from iree.compiler.dialects import _linalg_ops_gen, _util_ops_gen
-
-
-tune_logger = logging.getLogger("tune")
-
-
-class DispatchKind(Enum):
-    conv = 1
-    mmt = 2
-    contraction = 3
-    batch_mmt = 4
-    batch_matmul = 5
-    broadcast_rhs_mmt = 6
-
-
-class ElementType(Enum):
-    i8 = 1
-    i32 = 2
-    f8 = 3
-    f16 = 4
-    f32 = 5
-
-    @property
-    def bitwidth(self) -> int:
-        match self:
-            case ElementType.i8 | ElementType.f8:
-                return 8
-            case ElementType.f16:
-                return 16
-            case ElementType.i32 | ElementType.f32:
-                return 32
-            case _:
-                assert False, "unhandled case"
-
-    def __str__(self) -> str:
-        return self.name
-
-
-@dataclass
-class ShapedType:
-    shape: list[int]
-    element_type: ElementType
-
-    def rank(self) -> int:
-        return len(self.shape)
-
-    @property
-    def bitwidth(self) -> int:
-        return self.element_type.bitwidth
-
-    def __str__(self) -> str:
-        dim_to_str = lambda dim: str(dim) if dim != -1 else "?"
- return "x".join(map(dim_to_str, self.shape)) + "x" + str(self.element_type) - - -@dataclass -class MatmulSize: - M: int - N: int - K: int - B: int = 1 - - -@dataclass -class ProblemSize: - matmul_size: MatmulSize - lhs_type: ShapedType - rhs_type: ShapedType - res_type: ShapedType - dispatch_kind: DispatchKind - - @property - def MNK(self) -> tuple[int, int, int]: - return (self.matmul_size.M, self.matmul_size.N, self.matmul_size.K) - - -@dataclass -class MfmaIntrinsic: - input_type: ElementType - m: int - n: int - k: int - output_type: ElementType - - def __str__(self) -> str: - input = str(self.input_type).upper() - output = str(self.output_type).upper() - return f"MFMA_{input}_{self.m}x{self.n}x{self.k}_{output}" - - @staticmethod - def mfma_f16_16x16x16_f32(): - return MfmaIntrinsic(ElementType.f16, 16, 16, 16, ElementType.f32) - - @staticmethod - def mfma_f16_32x32x8_f32(): - return MfmaIntrinsic(ElementType.f16, 32, 32, 8, ElementType.f32) - - @staticmethod - def mfma_i8_16x16x32_i32(): - return MfmaIntrinsic(ElementType.i8, 16, 16, 32, ElementType.i32) - - @staticmethod - def mfma_i8_32x32x16_i32(): - return MfmaIntrinsic(ElementType.i8, 32, 32, 16, ElementType.i32) - - @staticmethod - def all(): - return [ - MfmaIntrinsic.mfma_f16_16x16x16_f32(), - MfmaIntrinsic.mfma_f16_32x32x8_f32(), - MfmaIntrinsic.mfma_i8_16x16x32_i32(), - MfmaIntrinsic.mfma_i8_32x32x16_i32(), - ] - - -@dataclass -class Configuration: - subgroup_size: int - workgroup_size: list[int] - intrinsic: MfmaIntrinsic - tile_sizes: list[int] - subgroup_m_count: int - subgroup_n_count: int - waves_per_eu: int - - -class MlirRegex(str, Enum): - ssa_value = r"%[a-zA-Z0-9-_]+" - tensor_type = r"tensor<(([0-9]+x)+((f|i)[0-9]+))>" - - @staticmethod - def dps_ins_two_args() -> str: - return rf"ins\({MlirRegex.ssa_value}, {MlirRegex.ssa_value} : (?P{MlirRegex.tensor_type}), (?P{MlirRegex.tensor_type})\)" - - @staticmethod - def dps_outs_one_arg() -> str: - return rf"outs\({MlirRegex.ssa_value} : (?P{MlirRegex.tensor_type})\)" - - -def read_input_mlir(filename: str) -> list[str]: - with open(filename, "r") as f: - return f.readlines() - - -def get_mmt_tile_sizes(configuration: Configuration): - return configuration.tile_sizes - - -@dataclass -class ConvDimInfo: - n: int - oh: int - ow: int - oc: int - fh: int - fw: int - ic: int - - @staticmethod - def from_rhs_res(rhs_shaped_type: ShapedType, res_shaped_type: ShapedType): - n, oh, ow, oc = res_shaped_type.shape - fh, fw, ic, _ = rhs_shaped_type.shape - return ConvDimInfo(n, oh, ow, oc, fh, fw, ic) - - @staticmethod - def from_problem_size(problem_size: ProblemSize): - return ConvDimInfo.from_rhs_res(problem_size.rhs_type, problem_size.res_type) - - -def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> list[int]: - m, n, k = configuration.tile_sizes - tile_size = [1] * len(tile_dims) - for idx, dim in enumerate(tile_dims): - if dim == "m": - tile_size[idx] = m - if dim == "n": - tile_size[idx] = n - if dim == "k": - tile_size[idx] = k - return tile_size - - -def get_batch_mmt_tile_sizes(configuration: Configuration) -> list[int]: - return [1] + configuration.tile_sizes - - -def get_pipeline_config(configuration: Configuration) -> str: - extra_config = ", prefetch_shared_memory" - if configuration.waves_per_eu != 2: - extra_config += f', llvm_func_attrs = {{"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"}}' - return extra_config - - -def apply_configuration( - template: list[str], configuration: Configuration, tile_sizes: list[int] -) -> str: - 
tune_logger.info(f"Applying: {configuration}") - expr0 = re.compile( - r", subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>" - ) - expr1 = re.compile( - r"LLVMGPUVectorDistribute workgroup_size = \[.+\] subgroup_size = ([0-9]+)," - ) - expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]") - expr3 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") - repl0 = f", subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>" - repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},' - repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]' - repl3 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"' - - new_mlir = "" - for line in template: - if "intrinsic =" in line: - line = re.sub(expr0, repl0, line) - if "LLVMGPUVectorDistribute " in line: - line = re.sub(expr1, repl1, line) - if "tile_sizes" in line: - line = re.sub(expr2, repl2, line) - if "amdgpu-waves-per-eu" in line: - line = re.sub(expr3, repl3, line) - new_mlir += line - - return new_mlir - - -def parse_tensor_type(tensor_type: str) -> ShapedType: - shape_match = re.search(MlirRegex.tensor_type, tensor_type) - assert shape_match - - shape_str = shape_match.group(1) - dims_and_elem = shape_str.split("x") - dims = [int(x) for x in dims_and_elem[:-1]] - elem = dims_and_elem[-1] - str_to_elem_ty = {x.name: x for x in ElementType} - return ShapedType(dims, str_to_elem_ty[elem]) - - -def get_compatible_mfma_intrinsics(problem_size: ProblemSize) -> list[MfmaIntrinsic]: - def is_compatible(intrinsic: MfmaIntrinsic) -> bool: - if problem_size.res_type.element_type != intrinsic.output_type: - return False - if problem_size.dispatch_kind != DispatchKind.batch_matmul: - if problem_size.lhs_type.element_type != intrinsic.input_type: - return False - if problem_size.rhs_type.element_type != intrinsic.input_type: - return False - return True - - return list(filter(is_compatible, MfmaIntrinsic.all())) - - -def get_mfma_intrinsic_constraints( - problem_size: ProblemSize, - intrinsic_m: z3.ArithRef, - intrinsic_n: z3.ArithRef, - intrinsic_k: z3.ArithRef, -) -> z3.BoolRef: - compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size) - assert len(compatible_intrinsics) > 0, "No compatible intrinsics found" - return z3.Or( - *( - z3.And(intrinsic_m == mfma.m, intrinsic_n == mfma.n, intrinsic_k == mfma.k) - for mfma in compatible_intrinsics - ) - ) - - -def get_dispatch_constraints( - problem_size: ProblemSize, - tile_m: z3.ArithRef, - tile_n: z3.ArithRef, - tile_k: z3.ArithRef, -) -> list[z3.BoolRef]: - if problem_size.dispatch_kind != DispatchKind.conv: - return [] - - dim_info = ConvDimInfo.from_problem_size(problem_size) - conv_constraints = [] - # WARNING: This sometimes makes the constraints UNSAT for some reason. 
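-    # Keep each tile dimension within its convolution extent: tile_m covers
-    # the output width, tile_n the output channels, and tile_k the input
-    # channels.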
-    conv_constraints += [tile_m <= dim_info.ow]
-    conv_constraints += [tile_n <= dim_info.oc]
-    conv_constraints += [tile_k <= dim_info.ic]
-    return conv_constraints
-
-
-def calculate_shared_memory_usage_in_bytes(
-    problem_size: ProblemSize,
-    m: int | z3.ArithRef,
-    n: int | z3.ArithRef,
-    k: int | z3.ArithRef,
-) -> int | z3.ArithRef:
-    lhs_memory = m * k * (problem_size.lhs_type.bitwidth // 8)
-    rhs_memory = k * n * (problem_size.rhs_type.bitwidth // 8)
-    return lhs_memory + rhs_memory
-
-
-def generate_constraints(
-    problem_size: ProblemSize,
-    tile_sizes,
-    num_subgroups,
-    subgroup_size,
-    intrinsic_size,
-    workgroup_size,
-    subgroup_m_count,
-    subgroup_n_count,
-    waves_per_eu,
-):
-    M, N, K = (
-        problem_size.matmul_size.M,
-        problem_size.matmul_size.N,
-        problem_size.matmul_size.K,
-    )
-    m, n, k = tile_sizes
-    intrinsic_mn, intrinsic_k = intrinsic_size
-    wg_x, wg_y, wg_z = workgroup_size
-    wg_threads = z3.Int("wg_threads")
-    constraints = [wg_threads == wg_x * wg_y * wg_z]
-    constraints += [subgroup_size == 64, wg_threads <= 1024]
-    constraints += [
-        get_mfma_intrinsic_constraints(
-            problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k
-        )
-    ]
-    subgroup_k_count = 1
-    constraints += [
-        m >= intrinsic_mn,
-        m <= 512,
-        m <= M,
-    ]
-    constraints += [n >= intrinsic_mn, n <= 512, n <= N, N % n == 0]
-    constraints += [k >= intrinsic_k, k <= 512, k <= K, K % k == 0]
-    for x in (subgroup_m_count, subgroup_n_count):
-        constraints += [x >= 1, x <= 32]
-
-    subgroup_m_tile_count = z3.Int("sg_m_tcnt")
-    subgroup_n_tile_count = z3.Int("sg_n_tcnt")
-    subgroup_k_tile_count = z3.Int("sg_k_tcnt")
-    for x in (subgroup_m_tile_count, subgroup_n_tile_count, subgroup_k_tile_count):
-        constraints += [x >= 1, x <= 32]
-
-    constraints += [m == subgroup_m_count * subgroup_m_tile_count * intrinsic_mn]
-    constraints += [n == subgroup_n_count * subgroup_n_tile_count * intrinsic_mn]
-    constraints += [k == subgroup_k_count * subgroup_k_tile_count * intrinsic_k]
-    constraints += [wg_x == subgroup_size * subgroup_n_count]
-    constraints += [wg_y == subgroup_m_count]
-    constraints += [wg_z == subgroup_k_count]
-    constraints += [z3.Or(wg_x <= n, wg_x <= m)]
-    constraints += [k % intrinsic_mn == 0]
-    constraints += [(k * n) % wg_threads == 0]
-    constraints += [(k * m) % wg_threads == 0]
-    subgroups = subgroup_m_count * subgroup_n_count
-    if num_subgroups > 0:
-        constraints += [subgroups == num_subgroups]
-    else:
-        constraints += [subgroups >= 1, subgroups <= 10]
-
-    constraints += [waves_per_eu == 2]
-    # constraints += [z3.Or(waves_per_eu == 2, waves_per_eu == 3, waves_per_eu == 4)]
-
-    shared_memory = calculate_shared_memory_usage_in_bytes(problem_size, m, n, k)
-    constraints += [shared_memory <= 65536]
-
-    constraints += get_dispatch_constraints(problem_size, m, n, k)
-
-    return constraints
-
-
-def generate_solutions(problem_size: ProblemSize, num_subgroups: int):
-    M, N, K = problem_size.MNK
-    tune_logger.info(f"{M},{N},{K}")
-    m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k")
-    subgroup_size = z3.Int("subgroup_size")
-    intrinsic_mn = z3.Int("intrinsic_mn")
-    intrinsic_k = z3.Int("intrinsic_k")
-    wg_x, wg_y, wg_z = z3.Int("wg_x"), z3.Int("wg_y"), z3.Int("wg_z")
-    sg_m_cnt = z3.Int("sg_m_cnt")
-    sg_n_cnt = z3.Int("sg_n_cnt")
-    waves_per_eu = z3.Int("waves_per_eu")
-    all_vars = [
-        m,
-        n,
-        k,
-        subgroup_size,
-        intrinsic_mn,
-        intrinsic_k,
-        wg_x,
-        wg_y,
-        wg_z,
-        sg_m_cnt,
-        sg_n_cnt,
-        waves_per_eu,
-    ]
-
-    solver = z3.Solver()
-    constraints = generate_constraints(
-        problem_size,
-        [m, n, k],
-        num_subgroups,
-        subgroup_size,
-        [intrinsic_mn, intrinsic_k],
-        [wg_x, wg_y, wg_z],
-        sg_m_cnt,
-        sg_n_cnt,
-        waves_per_eu,
-    )
-    solver.add(z3.simplify(z3.And(constraints)))
-    tune_logger.debug(f"Initial constraints: {solver}")
-    i = 0
-    while solver.check() == z3.sat:
-        model = solver.model()
-        lookup = lambda var: model[var].as_long()
-
-        config = Configuration(
-            lookup(subgroup_size),
-            [lookup(wg_x), lookup(wg_y), lookup(wg_z)],
-            MfmaIntrinsic(
-                problem_size.lhs_type.element_type,
-                lookup(intrinsic_mn),
-                lookup(intrinsic_mn),
-                lookup(intrinsic_k),
-                problem_size.res_type.element_type,
-            ),
-            [lookup(m), lookup(n), lookup(k)],
-            lookup(sg_m_cnt),
-            lookup(sg_n_cnt),
-            lookup(waves_per_eu),
-        )
-        solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars)))))
-        i += 1
-        yield config
-
-
-def get_default_output_dir() -> str:
-    from datetime import datetime
-
-    return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M")
-
-
-def parse_mlir(mlir_text: str) -> ir.Module:
-    mlir_module = None
-    with ireec.ir.Context() as context:
-        try:
-            mlir_module = ireec.ir.Module.parse(mlir_text)
-            tune_logger.info("MLIR parsing successful!")
-        except ireec.ir.MLIRError as e:
-            tune_logger.error(f"Error parsing MLIR: {e}")
-            raise RuntimeError(f"Error parsing MLIR: {e}")
-
-    return mlir_module
-
-
-@dataclass
-class MLIRTransformation:
-    """Transformation of MLIR context"""
-
-    template: str
-    modified: str
-    embeddable: str
-
-
-class DispatchTuner(ABC):
-    @abstractmethod
-    def supports(self, op_name: str) -> bool:
-        """Check if the tuner can handle the type of operation represented by the input string."""
-        pass
-
-    @abstractmethod
-    def get_shapes(self, template: list[str]) -> ProblemSize:
-        """Extract the problem size of the operation."""
-        pass
-
-    @abstractmethod
-    def apply_params(
-        self,
-        problem_size: ProblemSize,
-        template: list[str],
-        configuration: Configuration,
-    ) -> MLIRTransformation:
-        """Apply parameter transformations to the operation."""
-        pass
-
-
-@dataclass
-class OpWalkResult:
-    was_interrupted: bool = False
-    dispatch_tuner: Optional[DispatchTuner] = None
-
-
-class DispatchTunerRegistry:
-    def __init__(self):
-        self.registry = set()
-
-    def register(self, dispatch_tuners: list[DispatchTuner]) -> None:
-        for dispatch_tuner in dispatch_tuners:
-            self.registry.add(dispatch_tuner)
-
-    def validate_translation(self, attrs: list[ir.NamedAttribute]) -> bool:
-        for attr in attrs:
-            if (attr.name == "translation_info") and (
-                "LLVMGPUVectorDistribute" in str(attr.attr)
-            ):
-                return True
-        assert False, "Translation info not supported"
-
-    def find_handler(self, op_name: str) -> DispatchTuner:
-        for dispatch_tuner in self.registry:
-            if dispatch_tuner.supports(op_name):
-                return dispatch_tuner
-        assert False, "Dispatch kind not supported"
-
-
-class MmtTuner(DispatchTuner):
-    def supports(self, op_name: str) -> bool:
-        return "matmul_transpose_b" in op_name
-
-    def get_shapes(self, template: list[str]) -> ProblemSize:
-        mmt_re = None
-        dps = None
-        for line in template:
-            if "linalg.generic" not in line:
-                continue
-            if r'iterator_types = ["parallel", "parallel", "reduction"]' not in line:
-                continue
-            # ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>)
-            mmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}"
-            dps = re.search(mmt_re, line)
-            if dps is None:
-                continue
-
-            lhs_tensor_type = dps.group("LHS")
-            rhs_tensor_type = dps.group("RHS")
-            lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
-            assert
lhs_shaped_type.rank() == 2 - lhs_M, lhs_K = lhs_shaped_type.shape - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 2 - rhs_N, rhs_K = rhs_shaped_type.shape - - assert lhs_shaped_type.element_type == rhs_shaped_type.element_type - assert lhs_K == rhs_K - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 2 - res_M, res_N = res_shaped_type.shape - - assert lhs_M == res_M - assert rhs_N == res_N - - matmul_size = MatmulSize( - lhs_shaped_type.shape[0], - rhs_shaped_type.shape[0], - lhs_shaped_type.shape[1], - ) - return ProblemSize( - matmul_size, - lhs_type=lhs_shaped_type, - rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.mmt, - ) - assert mmt_re - assert dps, f"'{mmt_re}' not found in given context" - - def get_transform_function_mmt( - self, problem_size: ProblemSize, functionName: str, configuration: Configuration - ) -> str: - tile_sizes = ", ".join(map(str, get_mmt_tile_sizes(configuration))) - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) - - return f""" - transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ - %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op - %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value - %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value - transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value - transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> - {extra_config}}}> - > -> !transform.any_param - transform.yield %matmul, %config : !transform.any_op, !transform.any_param - }} - """ - - def apply_params( - self, - problem_size: ProblemSize, - template: list[str], - configuration: Configuration, - ) -> MLIRTransformation: - M, N, K = problem_size.MNK - modified = indent( - self.get_transform_function_mmt( - problem_size, f"match_mmt_{M}x{N}x{K}", configuration - ), - "// ", - ) - modified += apply_configuration( - template, configuration, get_mmt_tile_sizes(configuration) - ) - embeddable = indent( - self.get_transform_function_mmt(problem_size, f"match_op", configuration), - " ", - ) - return MLIRTransformation(template, modified, embeddable) - - -class ConvTuner(DispatchTuner): - def supports(self, op_name: str) -> bool: - return "conv_2d_nhwc_hwcf" in op_name - - def get_conv_tile_sizes(self, configuration: Configuration) -> list[int]: - m, n, k = configuration.tile_sizes - batch = 1 - fh = 1 - fw = 1 - - oh = 1 - - oc = n - ow = m - ic = k - return [batch, oh, ow, oc, fh, fw, ic] - - def get_shapes(self, template: list[str]) -> ProblemSize: - for line in template: - if "linalg.conv_2d_nhwc_hwcf" not in line: - continue - - # ins(%19, %20 : tensor<2x34x34x1280xf16>, tensor<3x3x1280x1280xf16>) outs (%27 : tensor<2x32x32x1280xf32>) - conv_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(conv_re, line) - if 
dps is None:
-                continue
-
-            lhs_tensor_type = dps.group("LHS")
-            rhs_tensor_type = dps.group("RHS")
-            lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
-            assert lhs_shaped_type.rank() == 4
-
-            rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
-            assert rhs_shaped_type.rank() == 4
-
-            res_tensor_type = dps.group("RES")
-            res_shaped_type = parse_tensor_type(res_tensor_type)
-            assert res_shaped_type.rank() == 4
-
-            # int64_t n = outputShape[0];
-            # int64_t oh = outputShape[1];
-            # int64_t ow = outputShape[2];
-            # int64_t oc = outputShape[3];
-            # int64_t fh = filterShape[0];
-            # int64_t fw = filterShape[1];
-            # int64_t ic = filterShape[2];
-            dim_info = ConvDimInfo.from_rhs_res(rhs_shaped_type, res_shaped_type)
-            return ProblemSize(
-                MatmulSize(
-                    M=dim_info.oh * dim_info.ow,
-                    N=dim_info.oc,
-                    K=dim_info.fh * dim_info.fw * dim_info.ic,
-                    B=dim_info.n,
-                ),
-                lhs_shaped_type,
-                rhs_shaped_type,
-                res_shaped_type,
-                DispatchKind.conv,
-            )
-
-        assert False, "Shape not found"
-
-    # int64_t n = outputShape[0];
-    # int64_t oh = outputShape[1];
-    # int64_t ow = outputShape[2];
-    # int64_t oc = outputShape[3];
-    # int64_t fh = filterShape[0];
-    # int64_t fw = filterShape[1];
-    # int64_t ic = filterShape[2];
-    def get_transform_function_conv(
-        self, problem_size: ProblemSize, functionName: str, configuration: Configuration
-    ) -> str:
-        dynamic_batch_input_ty = problem_size.lhs_type
-        dynamic_batch_input_ty.shape = dynamic_batch_input_ty.shape.copy()
-        dynamic_batch_input_ty.shape[0] = -1
-
-        dynamic_batch_output_ty = problem_size.res_type
-        dynamic_batch_output_ty.shape = dynamic_batch_output_ty.shape.copy()
-        dynamic_batch_output_ty.shape[0] = -1
-
-        input = f"tensor<{dynamic_batch_input_ty}>"
-        filter = f"tensor<{problem_size.rhs_type}>"
-        output = f"tensor<{dynamic_batch_output_ty}>"
-
-        tile_sizes = ", ".join(map(str, self.get_conv_tile_sizes(configuration)))
-
-        wg_x, wg_y, wg_z = configuration.workgroup_size
-        extra_config = get_pipeline_config(configuration)
-
-        return f"""
-    transform.named_sequence @{functionName}(%conv: !transform.any_op {{transform.readonly}})
-        -> (!transform.any_op, !transform.any_param) {{
-        %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv {{
-        ^bb0(%lhs: {input}, %rhs: {filter}, %out: {output}):
-        %13 = linalg.conv_2d_nhwc_hwcf {{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}}
-            ins(%lhs, %rhs : {input}, {filter})
-            outs(%out : {output}) -> {output}
-        }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
-        %config = transform.param.constant #iree_codegen.compilation_info<
-        lowering_config = #iree_codegen.lowering_config,
-        translation_info = #iree_codegen.translation_info,
-            subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
-        {extra_config}}}>
-        > -> !transform.any_param
-        transform.yield %conv, %config : !transform.any_op, !transform.any_param
-    }}
-    """
-
-    def apply_params(
-        self,
-        problem_size: ProblemSize,
-        template: list[str],
-        configuration: Configuration,
-    ) -> MLIRTransformation:
-        conv_dims = ConvDimInfo.from_problem_size(problem_size)
-        modified = indent(
-            self.get_transform_function_conv(
-                problem_size,
-                f"match_conv_2d_nhwc_hwcf_Bx{conv_dims.oh}x{conv_dims.ow}x{conv_dims.oc}x{conv_dims.fh}x{conv_dims.fw}x{conv_dims.ic}",
-                configuration,
-            ),
-            "// ",
-        )
-        modified += apply_configuration(
-            template, configuration, self.get_conv_tile_sizes(configuration)
-        )
-        embeddable = indent(
self.get_transform_function_conv(problem_size, f"match_op", configuration), - " ", - ) - return MLIRTransformation(template, modified, embeddable) - - -class ContractionTuner(DispatchTuner): - def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): - self.lhs_dims = lhs_dims - self.rhs_dims = rhs_dims - self.tile_dims = tile_dims - - def supports(self, op_name: str) -> bool: - return "matmul_like" in op_name - - def is_broadcast_rhs_mmt_op(self, line: str) -> bool: - if "linalg.generic" not in line: - return False - if ( - r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' - not in line - ): - return False - if ( - r"indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>" - not in line - ): - return False - return True - - def is_broadcast_rhs_mmt(self, template: list[str]) -> bool: - return any(self.is_broadcast_rhs_mmt_op(line) for line in template) - - def get_shapes_broadcast_rhs_mmt(self, template: list[str]) -> ProblemSize: - for line in template: - if not self.is_broadcast_rhs_mmt_op(line): - continue - - # ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) - bmmt_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(bmmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 3 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 2 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 3 - - B0, M0, K0 = lhs_shaped_type.shape - N1, K1 = rhs_shaped_type.shape - B2, M2, N2 = res_shaped_type.shape - assert B0 == B2 - assert M0 == M2 - assert N1 == N2 - assert K0 == K1 - return ProblemSize( - MatmulSize(M0, N1, K0, B0), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.broadcast_rhs_mmt, - ) - - assert False, "Shape not found" - - def get_shapes(self, template: list[str]) -> ProblemSize: - if self.is_broadcast_rhs_mmt(template): - return self.get_shapes_broadcast_rhs_mmt(template) - - for line in template: - if "linalg.generic" not in line: - continue - if "lowering_config =" not in line: - continue - if '"reduction"' not in line: - continue - - # ins(%7, %8 : tensor<2x1024x1280xf16>, tensor<20x64x1280xf16>) - cont_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(cont_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == len(self.lhs_dims) - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == len(self.rhs_dims) - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() >= 2 - - M = math.prod( - val if dim == "m" else 1 - for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) - ) - N = math.prod( - val if dim == "n" else 1 - for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape) - ) - K0 = math.prod( - val if dim == "k" else 1 - for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) - ) - K1 = math.prod( - val if dim == "k" else 1 - for dim, val in zip(self.rhs_dims, 
rhs_shaped_type.shape) - ) - assert K0 == K1 - - return ProblemSize( - MatmulSize(M, N, K0), - lhs_type=lhs_shaped_type, - rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.contraction, - ) - - assert False, "Shape not found" - - def get_transform_function_broadcast_rhs_mmt( - self, - problem_size: ProblemSize, - functionName: str, - configuration: Configuration, - ) -> str: - tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration))) - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) - - lhs_dynamic_batch = problem_size.lhs_type - lhs_dynamic_batch.shape = lhs_dynamic_batch.shape.copy() - lhs_dynamic_batch.shape[0] = -1 - - return f""" -transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ -%mmt = transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op -%lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value -%rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value -transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : !transform.any_value -transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value -%config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> - {extra_config}}}> - > -> !transform.any_param -transform.yield %generic, %config : !transform.any_op, !transform.any_param -}} -""" - - def apply_params_broadcast_rhs_mmt( - self, - problem_size: ProblemSize, - template: list[str], - configuration: Configuration, - ) -> MLIRTransformation: - M, N, K = problem_size.MNK - modified = indent( - self.get_transform_function_broadcast_rhs_mmt( - problem_size, f"match_broadcast_rhs_mmt_Bx{M}x{N}x{K}", configuration - ), - "// ", - ) - modified += apply_configuration( - template, configuration, get_batch_mmt_tile_sizes(configuration) - ) - - embeddable = indent( - self.get_transform_function_broadcast_rhs_mmt( - problem_size, f"match_op", configuration - ), - " ", - ) - return MLIRTransformation(template, modified, embeddable) - - def apply_params( - self, - problem_size: ProblemSize, - template: list[str], - configuration: Configuration, - ) -> MLIRTransformation: - if self.is_broadcast_rhs_mmt(template): - return self.apply_params_broadcast_rhs_mmt( - problem_size, template, configuration - ) - - # TODO: Generate transform function. 
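-        # Until a matcher is generated, plain contractions only get the
-        # in-place rewrite of the template; the embeddable spec is left empty.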
- return MLIRTransformation( - template, - apply_configuration( - template, - configuration, - get_contract_tile_sizes(configuration, self.tile_dims), - ), - "", - ) - - -class BatchMmtTuner(DispatchTuner): - def supports(self, op_name: str) -> bool: - return "batch_matmul_transpose_b" in op_name - - def get_shapes(self, template: list[str]) -> ProblemSize: - for line in template: - if "linalg.generic" not in line: - continue - if ( - r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' - not in line - ): - continue - # ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) - bmmt_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(bmmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 3 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 3 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 3 - - B0, M0, K0 = lhs_shaped_type.shape - B1, N1, K1 = rhs_shaped_type.shape - B2, M2, N2 = res_shaped_type.shape - assert B0 == B1 - assert B0 == B2 - assert M0 == M2 - assert N1 == N2 - assert K0 == K1 - return ProblemSize( - MatmulSize(M0, N1, K0, B0), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.batch_mmt, - ) - - assert False, "Shape not found" - - def get_transform_function_batch_mmt( - self, - problem_size: ProblemSize, - functionName: str, - configuration: Configuration, - ) -> str: - tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration))) - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) - - return f""" -transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ -%mmt = transform.include @match_batch_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op -%lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value -%rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value -transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value -transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value -%config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> - {extra_config}}}> - > -> !transform.any_param -transform.yield %generic, %config : !transform.any_op, !transform.any_param -}} -""" - - def apply_params( - self, - problem_size: ProblemSize, - template: list[str], - configuration: Configuration, - ) -> MLIRTransformation: - M, N, K = problem_size.MNK - B = problem_size.matmul_size.B - modified = indent( - self.get_transform_function_batch_mmt( - problem_size, f"match_batch_mmt_{B}x{M}x{N}x{K}", configuration - ), - "// ", - ) - modified += apply_configuration( - template, configuration, get_batch_mmt_tile_sizes(configuration) - ) - - embeddable = indent( - self.get_transform_function_batch_mmt( - problem_size, f"match_op", configuration - ), - " ", - ) - 
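-        # "modified" carries the matcher as a comment plus the rewritten
-        # dispatch; "embeddable" is the indented spec intended for pasting
-        # into a tuning config.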
return MLIRTransformation(template, modified, embeddable) - - -class BatchMatmulTuner(DispatchTuner): - def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): - self.lhs_dims = lhs_dims - self.rhs_dims = rhs_dims - self.tile_dims = tile_dims - - def supports(self, op_name: str) -> bool: - return "batch_matmul" in op_name - - def get_shapes(self, template: list[str]) -> ProblemSize: - for line in template: - if "linalg.batch_matmul" not in line: - continue - # ins(%9, %10 : tensor<64x72x1280xf16>, tensor<64x1280x1280xf16>) - # outs(%12 : tensor<64x72x1280xf32>) - cont_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(cont_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == len(self.lhs_dims) - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == len(self.rhs_dims) - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == lhs_shaped_type.rank() - - LHS = lhs_shaped_type.shape - RHS = rhs_shaped_type.shape - RES = res_shaped_type.shape - - B = math.prod( - val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, LHS) - ) - B0 = math.prod( - val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RHS) - ) - B1 = math.prod( - val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RES) - ) - M = math.prod( - val if dim == "m" else 1 for dim, val in zip(self.lhs_dims, LHS) - ) - N = math.prod( - val if dim == "n" else 1 for dim, val in zip(self.rhs_dims, RHS) - ) - K0 = math.prod( - val if dim == "k" else 1 for dim, val in zip(self.lhs_dims, LHS) - ) - K1 = math.prod( - val if dim == "k" else 1 for dim, val in zip(self.rhs_dims, RHS) - ) - assert B == B0 and B == B1 - assert K0 == K1 - - return ProblemSize( - MatmulSize(M, N, K0, B), - lhs_type=lhs_shaped_type, - rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.batch_matmul, - ) - - assert False, "Shape not found" - - def get_transform_function_batch_matmul( - self, - problem_size: ProblemSize, - tile_dims: str, - functionName: str, - configuration: Configuration, - ) -> str: - input0 = f"tensor<{problem_size.lhs_type}>" - input1 = f"tensor<{problem_size.rhs_type}>" - output = f"tensor<{problem_size.res_type}>" - - tile_sizes = ", ".join( - map(str, get_contract_tile_sizes(configuration, tile_dims)) - ) - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) - - return f""" - transform.named_sequence @{functionName}(%batch_matmul: !transform.any_op {{transform.readonly}}) - -> (!transform.any_op, !transform.any_param) {{ - %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %batch_matmul {{ - ^bb0(%lhs: {input0}, %rhs: {input1}, %out: {output}): - %13 = linalg.batch_matmul - ins(%lhs, %rhs : {input0}, {input1}) - outs(%out : {output}) -> {output} - }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> - {extra_config}}}> - > -> !transform.any_param - transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param - 
}} - """ - - def apply_params( - self, - problem_size: ProblemSize, - template: list[str], - configuration: Configuration, - ) -> MLIRTransformation: - M, N, K = problem_size.MNK - modified = indent( - self.get_transform_function_batch_matmul( - problem_size, - self.tile_dims, - f"match_batch_matmul_{problem_size.matmul_size.B}x{M}x{N}x{K}", - configuration, - ), - "// ", - ) - modified += apply_configuration( - template, - configuration, - get_contract_tile_sizes(configuration, self.tile_dims), - ) - - embeddable = indent( - self.get_transform_function_batch_matmul( - problem_size, self.tile_dims, f"match_op", configuration - ), - " ", - ) - return MLIRTransformation(template, modified, embeddable) - - -def walk_callback_get_fn( - op: ir.Operation, - walk_result: OpWalkResult, - dispatch_tuner_registry: DispatchTunerRegistry, -) -> ir.WalkResult: - if op.name == "func.func": - dispatch_tuner_registry.validate_translation([a for a in op.opview.attributes]) - if op.name == "util.func": - func_name = str(op.opview.sym_name) - walk_result.was_interrupted = True - walk_result.dispatch_tuner = dispatch_tuner_registry.find_handler(func_name) - return ir.WalkResult.INTERRUPT - return ir.WalkResult.ADVANCE - - -def walk_mlir_op( - mlir_module: ir.Module, - dispatch_tuner_registry: DispatchTunerRegistry, -) -> OpWalkResult: - walk_result = OpWalkResult() - for op in mlir_module.body.operations: - op.walk( - lambda op: walk_callback_get_fn(op, walk_result, dispatch_tuner_registry), - ir.WalkOrder.POST_ORDER, - ) - if walk_result.was_interrupted: - break - return walk_result - - -def tune( - input: str, # Path to the mlir file to be tuned - output: str = "", # Path to the output directory, auto creates one if not given - limit: int = 4096, # Max candidates to be generated - num_subgroups: int = 4, # GPU spec, used to determine candidate generation constraints - lhs_dims: str = "mk", # Dimensions for the left-hand side operand in matrix operations - rhs_dims: str = "nk", # Dimensions for the right-hand side operand in matrix operations - tile_dims: str = "mnk", # Dimensions for the tile size -): - input_file = str(input) - - if not output: - output = get_default_output_dir() - - # Create the directory if it does not exist - makedirs(str(output), exist_ok=True) - - tune_logger.debug(f"Output directory {output}") - tune_logger.debug(f"Processing {input_file}") - mlir_template = read_input_mlir(input_file) - mlir_text = "".join(mlir_template) - - mlir_module = parse_mlir(mlir_text) - # Save the input file as the first candidate. 
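-    # Candidate 0 is the unmodified dispatch; generated candidates are
-    # numbered from 1 so they can be compared against this baseline.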
- with open(path.join(output, f"0.mlir"), "w") as f: - f.write(mlir_text) - - dispatch_tuner_registry = DispatchTunerRegistry() - dispatch_tuner_registry.register( - [ - MmtTuner(), - ConvTuner(), - ContractionTuner(lhs_dims, rhs_dims, tile_dims), - BatchMmtTuner(), - BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims), - ] - ) - - walk_result = walk_mlir_op(mlir_module, dispatch_tuner_registry) - - dispatch_tuner = walk_result.dispatch_tuner - problem_size = dispatch_tuner.get_shapes(mlir_template) - tune_logger.debug(str(problem_size)) - configs = [] - for i, config in enumerate(generate_solutions(problem_size, num_subgroups)): - if i >= limit: - break - tune_logger.info(f"Solution #{i+1}: {config}") - configs.append(config) - tf_mlir = dispatch_tuner.apply_params(problem_size, mlir_template, config) - - with open(path.join(output, f"{i+1}.mlir"), "w") as f: - f.write(tf_mlir.modified) - with open(path.join(output, f"{i+1}_config.mlir"), "w") as f: - f.write(tf_mlir.embeddable) - - with open(path.join(output, "configs.pkl"), "wb") as file: - pickle.dump(configs, file) - - tune_logger.info(f"Generated {len(configs)} candidates") - tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl") - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("input", help="Input mlir file", type=str) - parser.add_argument( - "-o", "--output", help="Output dir", type=str, default=get_default_output_dir() - ) - parser.add_argument( - "-l", - "--limit", - help="Max number of candidates generated", - type=int, - default=4096, - ) - parser.add_argument( - "--num-subgroups", - help="Number of subgroups per workgroup to use. (-1 == unconstrained)", - type=int, - default=-1, - ) - parser.add_argument( - "--lhs-dims", help="Map of LHS matmul dims", type=str, default="mk" - ) - parser.add_argument( - "--rhs-dims", help="Map of RHS matmul dims", type=str, default="nk" - ) - parser.add_argument( - "--tile-dims", help="Map of tile size matmul dims", type=str, default="mnk" - ) - parser.add_argument( - "--verbose", "-v", action="store_true", help="Enable verbose output to stdout" - ) - - args = parser.parse_args() - tune_logger.setLevel(logging.DEBUG if args.verbose else logging.INFO) - - # Create printing formatter for logging info - formatter = logging.Formatter("%(message)s") - - # Create a handler to print to console - console_handler = logging.StreamHandler() - console_handler.setFormatter(formatter) - tune_logger.addHandler(console_handler) - - # # Optionally, add a file handler to log to a file - # file_handler = logging.FileHandler("tune.log") - # file_handler.setFormatter(formatter) - # tune_logger.addHandler(file_handler) - - tune( - args.input, - args.output, - args.limit, - args.num_subgroups, - args.lhs_dims, - args.rhs_dims, - args.tile_dims, - ) - - -if __name__ == "__main__": - args = main() diff --git a/sharktank/sharktank/tools/tuner/requirements-dev.txt b/sharktank/sharktank/tools/tuner/requirements-dev.txt deleted file mode 100644 index 51d5b9ba0..000000000 --- a/sharktank/sharktank/tools/tuner/requirements-dev.txt +++ /dev/null @@ -1,2 +0,0 @@ -pre-commit==3.8.0 -virtualenv==20.13.0 diff --git a/sharktank/sharktank/tools/tuner/requirements-tuner.txt b/sharktank/sharktank/tools/tuner/requirements-tuner.txt deleted file mode 100644 index f3484c921..000000000 --- a/sharktank/sharktank/tools/tuner/requirements-tuner.txt +++ /dev/null @@ -1,4 +0,0 @@ -pytest==8.2.2 -tqdm==4.66.4 -z3_solver==4.13.0.0 -types-tqdm==4.66.0.20240417 diff --git 
a/sharktank/tests/tuner/candidate_gen_test.py b/sharktank/tests/tuner/candidate_gen_test.py deleted file mode 100644 index 4fc21aa63..000000000 --- a/sharktank/tests/tuner/candidate_gen_test.py +++ /dev/null @@ -1,814 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -""" -Usage: python -m pytest candidate_gen_test.py -""" - -import pytest -from sharktank.tools.tuner import candidate_gen - - -def test_get_shaped_type_element_bitwidth(): - assert ( - candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8).bitwidth - == 8 - ) - assert ( - candidate_gen.ShapedType([2048], candidate_gen.ElementType.i32).bitwidth == 32 - ) - assert ( - candidate_gen.ShapedType( - [2048, 512, 384], candidate_gen.ElementType.f8 - ).bitwidth - == 8 - ) - assert ( - candidate_gen.ShapedType([1, 1], candidate_gen.ElementType.f16).bitwidth == 16 - ) - - -def test_get_shaped_type_to_str(): - assert ( - str(candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8)) - == "1024x2048xi8" - ) - assert ( - str(candidate_gen.ShapedType([1024], candidate_gen.ElementType.f32)) - == "1024xf32" - ) - assert ( - str(candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f16)) - == "1x2x3xf16" - ) - assert ( - str(candidate_gen.ShapedType([-1, 2, 3], candidate_gen.ElementType.f16)) - == "?x2x3xf16" - ) - - -def test_parse_tensor_type(): - assert candidate_gen.parse_tensor_type( - "tensor<1x2x3xf32>" - ) == candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f32) - assert candidate_gen.parse_tensor_type( - "tensor<123xi8>" - ) == candidate_gen.ShapedType([123], candidate_gen.ElementType.i8) - - -def test_get_mmt_tile_sizes(): - config = candidate_gen.Configuration( - subgroup_size=0, - workgroup_size=[], - intrinsic="", - tile_sizes=[128, 320, 32], - subgroup_m_count=0, - subgroup_n_count=0, - waves_per_eu=0, - ) - assert candidate_gen.get_mmt_tile_sizes(config) == [128, 320, 32] - - -def test_get_conv_tile_sizes(): - config = candidate_gen.Configuration( - subgroup_size=64, - workgroup_size=[256, 1, 1], - intrinsic="#iree_gpu.mma_layout", - tile_sizes=[464, 320, 16], - subgroup_m_count=1, - subgroup_n_count=4, - waves_per_eu=1, - ) - assert candidate_gen.ConvTuner().get_conv_tile_sizes(config) == [ - 1, - 1, - 464, - 320, - 1, - 1, - 16, - ] - - -def test_get_contract_tile_sizes(): - config = candidate_gen.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], - intrinsic="", - tile_sizes=[4, 8, 16], - subgroup_m_count=1, - subgroup_n_count=1, - waves_per_eu=2, - ) - assert candidate_gen.get_contract_tile_sizes(config, ["m", "n", "k"]) == [4, 8, 16] - assert candidate_gen.get_contract_tile_sizes(config, ["n", "m", "k"]) == [8, 4, 16] - assert candidate_gen.get_contract_tile_sizes(config, ["k", "n", "m"]) == [16, 8, 4] - assert candidate_gen.get_contract_tile_sizes(config, ["k", "k", "k"]) == [ - 16, - 16, - 16, - ] - - -def test_get_pipeline_config(): - config1 = candidate_gen.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], - intrinsic="", - tile_sizes=[4, 8, 16], - subgroup_m_count=1, - subgroup_n_count=1, - waves_per_eu=2, - ) - config2 = candidate_gen.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], - intrinsic="", - tile_sizes=[4, 8, 16], - subgroup_m_count=1, - subgroup_n_count=1, - waves_per_eu=4, - ) - assert candidate_gen.get_pipeline_config(config1) 
== ", prefetch_shared_memory" - assert ( - candidate_gen.get_pipeline_config(config2) - == ', prefetch_shared_memory, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' - ) - - -def test_get_shapes_mmt(): - template = [ - r"%18 = tensor.empty() : tensor<2048x1280xf32>", - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"^bb0(%in: f16, %in_0: f16, %out: f32):", - ] - assert candidate_gen.MmtTuner().get_shapes(template) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.mmt, - ) - - -def test_get_shapes_conv(): - template = [ - r"%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", - r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", - r"flow.dispatch.tensor.store %8, %2, offsets = [%workgroup_id_z, %workgroup_id_y, 0, %3], sizes = [1, 1, 32, 256], strides = [1, 1, 1, 1] : tensor<1x1x32x256xf32> -> !flow.dispatch.tensor>", - ] - assert candidate_gen.ConvTuner().get_shapes(template) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(32, 256, 11520), - candidate_gen.ShapedType([1, 3, 34, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([3, 3, 1280, 256], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1, 1, 32, 256], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.conv, - ) - - -def test_get_shapes_contract(): - template = [ - r"%18 = tensor.empty() : tensor<2048x1280xf32>", - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"^bb0(%in: f16, %in_0: f16, %out: f32):", - ] - assert candidate_gen.ContractionTuner("mk", "nk", "mnk").get_shapes( - template - ) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.contraction, - ) - - -def test_get_shapes_batch_matmul(): - template = [ - "%10 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", - 
"%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", - "flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 32, 32], strides = [1, 1, 1] : tensor<1x32x32xf32> -> !flow.dispatch.tensor>", - ] - assert candidate_gen.BatchMatmulTuner("bmk", "bkn", "mnk").get_shapes( - template - ) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(32, 32, 1024, 1), - candidate_gen.ShapedType([1, 32, 1024], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([1, 1024, 32], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([1, 32, 32], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.batch_matmul, - ) - - -def test_get_shapes_batch_mmt(): - template = [ - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"flow.dispatch.tensor.store %21, %10, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : tensor<2x4096x640xf16> -> !flow.dispatch.tensor>", - ] - assert candidate_gen.BatchMmtTuner().get_shapes( - template - ) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(4096, 640, 640, 2), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), - candidate_gen.DispatchKind.batch_mmt, - ) - - -def test_mfma_intrinsic_to_str(): - assert ( - str(candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32()) - == "MFMA_F16_16x16x16_F32" - ) - assert ( - str(candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32()) - == "MFMA_I8_32x32x16_I32" - ) - - -def test_get_compatible_mfma_intrinsics(): - assert candidate_gen.get_compatible_mfma_intrinsics( - candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.mmt, - ) - ) == [ - candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), - candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), - ] - - assert candidate_gen.get_compatible_mfma_intrinsics( - candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i32), - candidate_gen.DispatchKind.mmt, - ) - ) == [ - candidate_gen.MfmaIntrinsic.mfma_i8_16x16x32_i32(), - candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(), - ] - - assert candidate_gen.get_compatible_mfma_intrinsics( - candidate_gen.ProblemSize( - candidate_gen.MatmulSize(968, 320, 640, 64), - candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([64, 968, 
320], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.batch_matmul, - ) - ) == [ - candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), - candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), - ] - - -def test_generate_solutions(): - matmul_size = candidate_gen.MatmulSize(2048, 3840, 1280) - lhs_type = candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([3840, 1280], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([2048, 3840], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt - ) - configs = candidate_gen.generate_solutions(problem_size, 4) - assert configs is not None - - -def test_calculate_shared_memory_usage_in_bytes(): - matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt - ) - assert ( - candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) - == 147456 - ) - - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i8) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt - ) - assert ( - candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) - == 81920 - ) - - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt - ) - assert ( - candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 128, 64, 32) - == 12288 - ) - - -def test_generate_constraints_valid_input(): - matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt - ) - # Define input parameters as z3 Ints - m, n, k = ( - candidate_gen.z3.Int("m"), - candidate_gen.z3.Int("n"), - candidate_gen.z3.Int("k"), - ) - subgroup_size = candidate_gen.z3.Int("subgroup_size") - intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn") - intrinsic_k = candidate_gen.z3.Int("intrinsic_k") - wg_x, wg_y, wg_z = ( - candidate_gen.z3.Int("wg_x"), - candidate_gen.z3.Int("wg_y"), - candidate_gen.z3.Int("wg_z"), - ) - sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt") - sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt") - waves_per_eu = candidate_gen.z3.Int("waves_per_eu") - - constraints = candidate_gen.generate_constraints( - problem_size, - [m, n, k], - 4, - subgroup_size, - [intrinsic_mn, intrinsic_k], - [wg_x, wg_y, wg_z], - sg_m_cnt, - sg_n_cnt, - waves_per_eu, - ) - - solver = candidate_gen.z3.Solver() - solver.add(constraints) - - # Check if the constraints are satisfiable - assert solver.check() == candidate_gen.z3.sat - - -def test_generate_constraints_invalid_input(): - # Define input parameters that should lead to unsatisfiable constraints - 
-    matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024)
-    lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16)
-    rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16)
-    res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32)
-    problem_size = candidate_gen.ProblemSize(
-        matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt
-    )
-    m, n, k = (
-        candidate_gen.z3.Int("m"),
-        candidate_gen.z3.Int("n"),
-        candidate_gen.z3.Int("k"),
-    )
-    subgroup_size = candidate_gen.z3.Int("subgroup_size")
-    intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn")
-    intrinsic_k = candidate_gen.z3.Int("intrinsic_k")
-    wg_x, wg_y, wg_z = (
-        candidate_gen.z3.Int("wg_x"),
-        candidate_gen.z3.Int("wg_y"),
-        candidate_gen.z3.Int("wg_z"),
-    )
-    sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt")
-    sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt")
-    waves_per_eu = candidate_gen.z3.Int("waves_per_eu")
-
-    constraints = candidate_gen.generate_constraints(
-        problem_size,
-        [m, n, k],
-        4,
-        subgroup_size,
-        [intrinsic_mn, intrinsic_k],
-        [wg_x, wg_y, wg_z],
-        sg_m_cnt,
-        sg_n_cnt,
-        waves_per_eu,
-    )
-    # Add an extra constraint that cannot be satisfied together with the generated ones
-    constraints.append(m > 1000)
-
-    solver = candidate_gen.z3.Solver()
-    solver.add(constraints)
-
-    # Check that the constraints are unsatisfiable
-    assert solver.check() == candidate_gen.z3.unsat
-
-
-def test_apply_params_mmt():
-    mlir_template = [
-        ", subgroup_m_count = 16, subgroup_n_count = 16>",
-        "",
-        '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}',
-    ]
-
-    M, N, K = 2048, 1280, 1280
-
-    config = candidate_gen.Configuration(
-        subgroup_size=16,
-        workgroup_size=[16, 16, 1],
-        intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(),
-        tile_sizes=[8, 8, 8],
-        subgroup_m_count=16,
-        subgroup_n_count=16,
-        waves_per_eu=8,
-    )
-
-    problem_size = candidate_gen.ProblemSize(
-        candidate_gen.MatmulSize(M, N, K),
-        candidate_gen.ShapedType([M, K], candidate_gen.ElementType.f16),
-        candidate_gen.ShapedType([N, K], candidate_gen.ElementType.f16),
-        candidate_gen.ShapedType([M, N], candidate_gen.ElementType.f32),
-        candidate_gen.DispatchKind.mmt,
-    )
-    tf_mlir = candidate_gen.MmtTuner().apply_params(problem_size, mlir_template, config)
-
-    modified = tf_mlir.modified
-    embeddable = tf_mlir.embeddable
-
-    assert modified
-    assert embeddable
-    assert (
-        "intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 16, subgroup_n_count = 16"
-        in modified
-    )
-    assert (
-        "LLVMGPUVectorDistribute workgroup_size = [16, 16, 1] subgroup_size = 16"
-        in modified
-    )
-    assert "tile_sizes = [[8, 8, 8]]" in modified
-    assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "8"}' in modified
-
-
-def test_apply_params_conv():
-    mlir_template = [
-        ", subgroup_m_count = 16, subgroup_n_count = 16>",
-        "",
-        '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}',
-    ]
-
-    n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640
-
-    config = candidate_gen.Configuration(
-        subgroup_size=64,
-        workgroup_size=[256, 1, 1],
-        intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(),
-        tile_sizes=[464, 320, 16],
-        subgroup_m_count=1,
-        subgroup_n_count=4,
-        waves_per_eu=2,
-    )
-
-    problem_size = candidate_gen.ProblemSize(
-        candidate_gen.MatmulSize(oh * ow, oc, fh * fw * ic),
-        candidate_gen.ShapedType(
-            [n, oh + 2, ow + 2, ic], candidate_gen.ElementType.f16
-        ),
-        candidate_gen.ShapedType([fh, fw, ic, oc], candidate_gen.ElementType.f16),
-        candidate_gen.ShapedType([n, oh, ow, oc], candidate_gen.ElementType.f32),
-        candidate_gen.DispatchKind.conv,
-    )
-    tf_mlir = candidate_gen.ConvTuner().apply_params(
-        problem_size, mlir_template, config
-    )
-
-    modified = tf_mlir.modified
-    embeddable = tf_mlir.embeddable
-
-    assert modified
-    assert embeddable
-    assert (
-        "intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 4"
-        in modified
-    )
-    assert (
-        "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64"
-        in modified
-    )
-    assert "tile_sizes = [[1, 1, 464, 320, 1, 1, 16]]" in modified
-    assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified
-
-
-def test_apply_params_contract():
-    mlir_template = [
-        ", subgroup_m_count = 2, subgroup_n_count = 2>}>",
-        "",
-        '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}',
-    ]
-
-    tile_dims = "*mnk"
-    problem_size = candidate_gen.ProblemSize(
-        candidate_gen.MatmulSize(2048, 3840, 1280),
-        candidate_gen.ShapedType([2, 1024, 1280], candidate_gen.ElementType.f16),
-        candidate_gen.ShapedType([3, 20, 64, 1280], candidate_gen.ElementType.f16),
-        candidate_gen.ShapedType([3, 2, 20, 1024, 64], candidate_gen.ElementType.f32),
-        candidate_gen.DispatchKind.contraction,
-    )
-
-    config = candidate_gen.Configuration(
-        subgroup_size=64,
-        workgroup_size=[256, 1, 1],
-        intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(),
-        tile_sizes=[480, 384, 32],
-        subgroup_m_count=1,
-        subgroup_n_count=4,
-        waves_per_eu=2,
-    )
-
-    tf_mlir = candidate_gen.ContractionTuner("mk", "nk", tile_dims).apply_params(
-        problem_size, mlir_template, config
-    )
-
-    new_mlir = tf_mlir.modified
-
-    assert new_mlir
-    assert (
-        "intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>, subgroup_m_count = 1, subgroup_n_count = 4"
-        in new_mlir
-    )
-    assert (
-        "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64"
-        in new_mlir
-    )
-    assert "tile_sizes = [[1, 480, 384, 32]]" in new_mlir
-    assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in new_mlir
-
-
-def test_apply_params_batch_matmul():
-    mlir_template = [
-        ", subgroup_m_count = 4, subgroup_n_count = 1>}>",
-        "",
-        '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}',
-    ]
-
-    tile_dims = "bmnk"
-    problem_size = candidate_gen.ProblemSize(
-        candidate_gen.MatmulSize(968, 320, 640, 64),
-        candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f16),
-        candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f16),
-        candidate_gen.ShapedType([64, 968, 320], candidate_gen.ElementType.f32),
-        candidate_gen.DispatchKind.batch_matmul,
-    )
-
-    config = candidate_gen.Configuration(
-        subgroup_size=64,
-        workgroup_size=[128, 2, 1],
-        intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(),
-        tile_sizes=[416, 320, 128],
-        subgroup_m_count=2,
-        subgroup_n_count=2,
-        waves_per_eu=2,
-    )
-
-    tf_mlir = candidate_gen.BatchMatmulTuner("mk", "nk", tile_dims).apply_params(
-        problem_size, mlir_template, config
-    )
-
-    modified = tf_mlir.modified
-    embeddable = tf_mlir.embeddable
-
-    assert modified
-    assert embeddable
-    assert (
-        "intrinsic = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>, subgroup_m_count = 2, subgroup_n_count = 2"
-        in modified
-    )
-    assert (
-        "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64"
-        in modified
-    )
-    assert "tile_sizes = [[1, 416, 320, 128]]" in modified
-    assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified
-
-
-def test_apply_params_batch_mmt_float():
-    mlir_template = [
-        ", subgroup_m_count = 4, subgroup_n_count = 1>}>",
-        "",
-        '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}',
-    ]
-
-    problem_size = candidate_gen.ProblemSize(
-        candidate_gen.MatmulSize(4096, 640, 640, 2),
-        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f16),
-        candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.f16),
-        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f32),
-        candidate_gen.DispatchKind.batch_mmt,
-    )
-
-    config = candidate_gen.Configuration(
-        subgroup_size=64,
-        workgroup_size=[128, 2, 1],
-        intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(),
-        tile_sizes=[128, 64, 128],
-        subgroup_m_count=2,
-        subgroup_n_count=2,
-        waves_per_eu=2,
-    )
-
-    tf_mlir = candidate_gen.BatchMmtTuner().apply_params(
-        problem_size, mlir_template, config
-    )
-
-    modified = tf_mlir.modified
-    embeddable = tf_mlir.embeddable
-
-    assert embeddable
-    assert modified
-    assert (
-        "intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 2, subgroup_n_count = 2"
-        in modified
-    )
-    assert (
-        "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64"
-        in modified
-    )
-    assert "tile_sizes = [[1, 128, 64, 128]]" in modified
-    assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified
-
-
-def test_apply_params_batch_mmt_int():
-    mlir_template = [
-        ", subgroup_m_count = 4, subgroup_n_count = 1>}>",
-        "",
-        '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}',
-    ]
-
-    problem_size = candidate_gen.ProblemSize(
-        candidate_gen.MatmulSize(4096, 640, 640, 2),
-        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8),
-        candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8),
-        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32),
-        candidate_gen.DispatchKind.batch_mmt,
-    )
-
-    config = candidate_gen.Configuration(
-        subgroup_size=64,
-        workgroup_size=[128, 2, 1],
-        intrinsic=candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(),
-        tile_sizes=[128, 64, 128],
-        subgroup_m_count=2,
-        subgroup_n_count=2,
-        waves_per_eu=4,
-    )
-
-    tf_mlir = candidate_gen.BatchMmtTuner().apply_params(
-        problem_size, mlir_template, config
-    )
-
-    modified = tf_mlir.modified
-    embeddable = tf_mlir.embeddable
-
-    assert modified
-    assert "// transform.named_sequence @match_batch_mmt_2x4096x640x640(" in modified
-    assert (
-        "intrinsic = #iree_gpu.mma_layout<MFMA_I8_32x32x16_I32>, subgroup_m_count = 2, subgroup_n_count = 2"
-        in modified
-    )
-    assert (
-        "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64"
-        in modified
-    )
-    assert "tile_sizes = [[1, 128, 64, 128]]" in modified
-    assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified
-
-    assert embeddable
-    assert "transform.named_sequence @match_op(" in embeddable
-    assert (
-        "transform.include @match_batch_mmt_i8_i8_i32 failures(propagate)" in embeddable
-    )
-    assert (
-        "transform.iree.match.cast_compatible_type %lhs = tensor<2x4096x640xi8> : !transform.any_value"
-        in embeddable
-    )
-    assert (
-        "transform.iree.match.cast_compatible_type %rhs = tensor<2x640x640xi8> : !transform.any_value"
-        in embeddable
-    )
-    assert (
-        "%config = transform.param.constant #iree_codegen.compilation_info<"
-        in embeddable
-    )
-    assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable
-    assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable
-    assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable
-
-
-def test_apply_params_broadcast_rhs_mmt():
-    mlir_template = [
-        ", subgroup_m_count = 4, subgroup_n_count = 1>}>",
-        "",
-        '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}',
-    ]
-
-    problem_size = candidate_gen.ProblemSize(
-        candidate_gen.MatmulSize(4096, 640, 640, 2),
-        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8),
-        candidate_gen.ShapedType([640, 640], candidate_gen.ElementType.i8),
-        candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32),
-        candidate_gen.DispatchKind.broadcast_rhs_mmt,
-    )
-
-    config = candidate_gen.Configuration(
-        subgroup_size=64,
-        workgroup_size=[128, 2, 1],
-        intrinsic=candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(),
-        tile_sizes=[128, 64, 128],
-        subgroup_m_count=2,
-        subgroup_n_count=2,
-        waves_per_eu=4,
-    )
-
-    tf_mlir = candidate_gen.ContractionTuner(
-        "mk", "nk", "mnk"
-    ).apply_params_broadcast_rhs_mmt(problem_size, mlir_template, config)
-
-    modified = tf_mlir.modified
-    embeddable = tf_mlir.embeddable
-
-    assert modified
-    assert (
-        "// transform.named_sequence @match_broadcast_rhs_mmt_Bx4096x640x640("
-        in modified
-    )
-    assert (
-        "intrinsic = #iree_gpu.mma_layout<MFMA_I8_32x32x16_I32>, subgroup_m_count = 2, subgroup_n_count = 2"
-        in modified
-    )
-    assert (
-        "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64"
-        in modified
-    )
-    assert "tile_sizes = [[1, 128, 64, 128]]" in modified
-    assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified
-
-    assert embeddable
-    assert "transform.named_sequence @match_op(" in embeddable
-    assert (
-        "transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate)"
-        in embeddable
-    )
-    assert (
-        "transform.iree.match.cast_compatible_type %lhs = tensor<?x4096x640xi8> : !transform.any_value"
-        in embeddable
-    )
-    assert (
-        "transform.iree.match.cast_compatible_type %rhs = tensor<640x640xi8> : !transform.any_value"
-        in embeddable
-    )
-    assert (
-        "%config = transform.param.constant #iree_codegen.compilation_info<"
-        in embeddable
-    )
-    assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable
-    assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable
-    assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable
-
-
-def test_detect_broadcast_rhs_mmt():
-    mlir_lines = [
-        r"%18 = tensor.empty() : tensor<2x1024x10240xi32>",
-        r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>",
-        r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {',
-    ]
-    assert candidate_gen.ContractionTuner("mk", "nk", "mnk").is_broadcast_rhs_mmt(
-        mlir_lines
-    )
-
-
-def test_parse_mlir():
-    mlir_str = r"""
-    builtin.module {
-        func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
-            %0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
-            return %0 : tensor<4xf32>
-        }
-    }
-    """
-    mlir_module = candidate_gen.parse_mlir(mlir_str)
-    assert mlir_module is not None
-    assert isinstance(mlir_module, candidate_gen.ireec._mlir_libs._mlir.ir.Module)
-    assert isinstance(
-        mlir_module.body.operations[0], candidate_gen.ireec.dialects.func.FuncOp
-    )