diff --git a/.github/workflows/ci-tuner.yml b/.github/workflows/ci-tuner.yml index b6552bfd5..9855a2779 100644 --- a/.github/workflows/ci-tuner.yml +++ b/.github/workflows/ci-tuner.yml @@ -30,15 +30,15 @@ jobs: - name: Install dev dependencies run: | python -m pip install --upgrade pip - pip install -r sharktank/sharktank/tools/tuner/requirements-dev.txt + pip install -r sharktank/tuner/requirements-dev.txt - name: Install tuner dependencies run: | - pip install -r sharktank/sharktank/tools/tuner/requirements-tuner.txt + pip install -r sharktank/tuner/requirements-tuner.txt python -m pip install \ --find-links https://iree.dev/pip-release-links.html \ --upgrade \ iree-compiler iree-runtime - name: Run tuner tests - run: pytest sharktank/tests/tuner/ + run: pytest sharktank/tuner/ diff --git a/tuner/README.md b/tuner/README.md new file mode 100644 index 000000000..69821496e --- /dev/null +++ b/tuner/README.md @@ -0,0 +1,67 @@ +# IREE dispatch auto-tuning scripts +`libtuner.py` is the core Python script that provides the fundamental functions for the tuning loop. It imports `candidate_gen.py` for candidate generation. To implement the full tuning loop, `libtuner.py` requires a separate Python script that uses the provided `TuningClient` API from `libtuner.py`. + +## Prerequisites +[Optional] Using virtual environments: +```shell +cd tuning +python -m venv .venv +source .venv/bin/activate +``` +Install python dependencies: +```shell +pip install -r ./requirements-tuner.txt +``` +Using the IREE's Python bindings: + - Building with CMake + ```shell + -DIREE_BUILD_PYTHON_BINDINGS=ON \ + -DPython3_EXECUTABLE="$(which python)" + ``` + - Set environment + ```shell + source ../iree-build/.env && export PYTHONPATH + ``` +For more information, refer to the [IREE documentation](https://iree.dev/building-from-source/getting-started/#python-bindings) + +### Overall flow + +1. Symlink all scripts and mlir/irpa files in your build dir. 
+ - Symlink `iree-build-dir/tools` inside `tuning`. + - Symlink ML model MLIR and weights based on `unet.sh`. + +2. Copy the attention/matmul spec as `config.mlir` in the tuning dir. + +3. Temporarily comment out all the existing configs in `config.mlir`. + - Example: + ```mlir + // , @match_mmt_2048x10240x1280 -> @apply_op_config + // , @match_mmt_2048x1280x5120 -> @apply_op_config + // , @match_mmt_2048x1280x1280 -> @apply_op_config + ``` + +4. Compile a baseline unet +```shell +./unet.sh winograd unet.mlir -o unet_baseline.vmfb --iree-hal-dump-executable-files-to=dump-winograd +``` + +5. Find the matmul to tune and copy the `*_benchmark.mlir` file to the build dir. +```shell +cp dump-winograd/*_141_*benchmark.mlir ./141.mlir +``` + +6. Run the tuning script. + - Example: + ```shell + python punet_autotune.py 141.mlir --devices=hip://GPU-0,hip://GPU-4 --num-candidates=1024 + ``` + +7. Check the winner candidate in `result_summary.log`, find and copy the transform spec. + +8. Paste the transform spec into the `config.mlir` and uncomment them. + +9. Add the match function to the entry point in `config.mlir` + - Example: + ```mlir + @match_something -> @apply_op_config + ``` diff --git a/tuner/candidate_gen.py b/tuner/candidate_gen.py new file mode 100755 index 000000000..8a8315afb --- /dev/null +++ b/tuner/candidate_gen.py @@ -0,0 +1,1408 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Given an input dispatch, this code modifies the hyperparameters +# in the code and runs it. + +""" +Generate candidates by tweaking op configuration for tuning. + +It can be invoked in two ways: + 1. From another python script, import and call `tune()` + 2. 
class ElementType(Enum):
    """Scalar element types the tuner recognizes.

    The member name is the textual spelling of the type (``str()`` returns
    it); the integer value is only an enum tag and is unrelated to the
    bit width.
    """

    i8 = 1
    i32 = 2
    f8 = 3
    f16 = 4
    f32 = 5

    @property
    def bitwidth(self) -> int:
        """Return the width of one element in bits."""
        if self in (ElementType.i8, ElementType.f8):
            return 8
        if self is ElementType.f16:
            return 16
        if self in (ElementType.i32, ElementType.f32):
            return 32
        # Reached only if a member is added without a width above.
        assert False, "unhandled case"

    def __str__(self) -> str:
        """Return the bare member name, e.g. ``f16``."""
        return self.name
@dataclass
class MfmaIntrinsic:
    """Shape and element types of one MFMA intrinsic variant.

    ``str()`` produces the ``MFMA_<IN>_<m>x<n>x<k>_<OUT>`` spelling, with the
    element type names upper-cased.
    """

    input_type: ElementType  # compared against the lhs/rhs element type
    m: int
    n: int
    k: int
    output_type: ElementType  # compared against the result element type

    def __str__(self) -> str:
        operand_name = str(self.input_type).upper()
        result_name = str(self.output_type).upper()
        return "_".join(
            ("MFMA", operand_name, f"{self.m}x{self.n}x{self.k}", result_name)
        )

    @staticmethod
    def mfma_f16_16x16x16_f32():
        """f16 operands, 16x16x16 shape, f32 result."""
        return MfmaIntrinsic(
            input_type=ElementType.f16, m=16, n=16, k=16, output_type=ElementType.f32
        )

    @staticmethod
    def mfma_f16_32x32x8_f32():
        """f16 operands, 32x32x8 shape, f32 result."""
        return MfmaIntrinsic(
            input_type=ElementType.f16, m=32, n=32, k=8, output_type=ElementType.f32
        )

    @staticmethod
    def mfma_i8_16x16x32_i32():
        """i8 operands, 16x16x32 shape, i32 result."""
        return MfmaIntrinsic(
            input_type=ElementType.i8, m=16, n=16, k=32, output_type=ElementType.i32
        )

    @staticmethod
    def mfma_i8_32x32x16_i32():
        """i8 operands, 32x32x16 shape, i32 result."""
        return MfmaIntrinsic(
            input_type=ElementType.i8, m=32, n=32, k=16, output_type=ElementType.i32
        )

    @staticmethod
    def all():
        """Return a fresh list with every known intrinsic, f16 variants first."""
        factories = (
            MfmaIntrinsic.mfma_f16_16x16x16_f32,
            MfmaIntrinsic.mfma_f16_32x32x8_f32,
            MfmaIntrinsic.mfma_i8_16x16x32_i32,
            MfmaIntrinsic.mfma_i8_32x32x16_i32,
        )
        return [make() for make in factories]
def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> list[int]:
    """Expand the (m, n, k) tile sizes to one entry per tiled dimension.

    Args:
        configuration: Candidate whose ``tile_sizes`` holds exactly
            ``[m, n, k]``.
        tile_dims: One character per dimension of the op; ``"m"``, ``"n"``
            and ``"k"`` select the matching tile size.

    Returns:
        A list as long as ``tile_dims``; positions whose character is not
        ``m``/``n``/``k`` get a tile size of 1.
    """
    m, n, k = configuration.tile_sizes
    size_for = {"m": m, "n": n, "k": k}
    return [size_for.get(dim, 1) for dim in tile_dims]
def parse_tensor_type(tensor_type: str) -> ShapedType:
    """Parse the first statically shaped ``tensor<...>`` found in a string.

    Args:
        tensor_type: Text containing a type such as ``tensor<2048x1280xf16>``.

    Returns:
        A ``ShapedType`` with the integer dimensions and the element type
        whose ``ElementType`` member name equals the trailing token.

    Raises:
        AssertionError: If no tensor type matching ``MlirRegex.tensor_type``
            is present.
        KeyError: If the element type token is not an ``ElementType`` name.
    """
    found = re.search(MlirRegex.tensor_type, tensor_type)
    assert found

    # Group 1 is the text between the angle brackets: "<d0>x<d1>x...x<elem>".
    *dim_tokens, elem_name = found.group(1).split("x")
    dims = [int(token) for token in dim_tokens]
    return ShapedType(dims, ElementType[elem_name])
def get_dispatch_constraints(
    problem_size: ProblemSize,
    tile_m: z3.ArithRef,
    tile_n: z3.ArithRef,
    tile_k: z3.ArithRef,
) -> list[z3.BoolRef]:
    """Return extra tile-size constraints specific to the dispatch kind.

    Only convolutions add anything: each tile size is capped by the matching
    convolution dimension (m by output width, n by output channels, k by
    input channels). Every other dispatch kind yields an empty list.
    """
    if problem_size.dispatch_kind == DispatchKind.conv:
        conv_dims = ConvDimInfo.from_problem_size(problem_size)
        # WARNING: This sometimes makes the constraints UNSAT for some reason.
        return [
            tile_m <= conv_dims.ow,
            tile_n <= conv_dims.oc,
            tile_k <= conv_dims.ic,
        ]
    return []
def generate_solutions(problem_size: ProblemSize, num_subgrups: int):
    """Lazily enumerate tuning configurations that satisfy the constraints.

    A z3 solver is seeded with the constraints from ``generate_constraints``.
    Each time the solver reports ``sat`` the model is converted into a
    ``Configuration`` and a blocking clause excluding that exact assignment
    of the tracked variables is added, so the next check must produce a
    different assignment. Iteration ends when the solver no longer reports
    ``sat``.

    Args:
        problem_size: The problem being tuned; its M/N/K are logged and its
            lhs/result element types label the generated intrinsic.
        num_subgrups: Forwarded unchanged to ``generate_constraints``.

    Yields:
        One ``Configuration`` per distinct satisfying assignment.
    """
    M, N, K = problem_size.MNK
    tune_logger.info(f"{M},{N},{K}")
    m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k")
    subgroup_size = z3.Int("subgroup_size")
    intrinsic_mn = z3.Int("intrinsic_mn")
    intrinsic_k = z3.Int("intrinsic_k")
    wg_x, wg_y, wg_z = z3.Int("wg_x"), z3.Int("wg_y"), z3.Int("wg_z")
    sg_m_cnt = z3.Int("sg_m_cnt")
    sg_n_cnt = z3.Int("sg_n_cnt")
    waves_per_eu = z3.Int("waves_per_eu")
    # Variables whose joint assignment identifies a solution; only these are
    # used in the blocking clause below.
    all_vars = [
        m,
        n,
        k,
        subgroup_size,
        intrinsic_mn,
        intrinsic_k,
        wg_x,
        wg_y,
        wg_z,
        sg_m_cnt,
        sg_n_cnt,
        waves_per_eu,
    ]

    solver = z3.Solver()
    constraints = generate_constraints(
        problem_size,
        [m, n, k],
        num_subgrups,
        subgroup_size,
        [intrinsic_mn, intrinsic_k],
        [wg_x, wg_y, wg_z],
        sg_m_cnt,
        sg_n_cnt,
        waves_per_eu,
    )
    solver.add(z3.simplify(z3.And(constraints)))
    tune_logger.debug(f"Initial constraints: {solver}")

    while solver.check() == z3.sat:
        model = solver.model()

        def lookup(var):
            # Concrete Python int assigned to `var` in the current model.
            return model[var].as_long()

        config = Configuration(
            lookup(subgroup_size),
            [lookup(wg_x), lookup(wg_y), lookup(wg_z)],
            MfmaIntrinsic(
                problem_size.lhs_type.element_type,
                # The same variable is used for both the m and n extents.
                lookup(intrinsic_mn),
                lookup(intrinsic_mn),
                lookup(intrinsic_k),
                problem_size.res_type.element_type,
            ),
            [lookup(m), lookup(n), lookup(k)],
            lookup(sg_m_cnt),
            lookup(sg_n_cnt),
            lookup(waves_per_eu),
        )
        # Exclude this exact assignment before handing the result out, so the
        # next check() cannot return it again.
        solver.add(z3.simplify(z3.Not(z3.And([x == model[x] for x in all_vars]))))
        yield config
class DispatchTunerRegistry:
    """Holds the available dispatch tuners and picks one by operation name.

    NOTE(review): the tuners are kept in a set, so when more than one
    tuner's ``supports()`` accepts the same name, which one ``find_handler``
    returns depends on set iteration order -- confirm the predicates are
    mutually exclusive for the op names in use.
    """

    def __init__(self):
        self.registry = set()

    def register(self, dispatch_tuners: list[DispatchTuner]) -> None:
        """Add every tuner in ``dispatch_tuners`` to the registry."""
        self.registry.update(dispatch_tuners)

    def validate_translation(self, attrs: list[ir.NamedAttribute]) -> bool:
        """Return True if some ``translation_info`` attribute uses
        ``LLVMGPUVectorDistribute``; otherwise fail an assertion."""
        has_supported_pipeline = any(
            attr.name == "translation_info"
            and "LLVMGPUVectorDistribute" in str(attr.attr)
            for attr in attrs
        )
        if has_supported_pipeline:
            return True
        assert False, "Translation info not supported"

    def find_handler(self, op_name: str) -> DispatchTuner:
        """Return a registered tuner whose ``supports(op_name)`` is true;
        fail an assertion if there is none."""
        handler = next(
            (tuner for tuner in self.registry if tuner.supports(op_name)), None
        )
        if handler is not None:
            return handler
        assert False, "Dispatch kind not supported"
res_shaped_type.rank() == 2 + res_M, res_N = res_shaped_type.shape + + assert lhs_M == res_M + assert rhs_N == res_N + + matmul_size = MatmulSize( + lhs_shaped_type.shape[0], + rhs_shaped_type.shape[0], + lhs_shaped_type.shape[1], + ) + return ProblemSize( + matmul_size, + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.mmt, + ) + assert mmt_re + assert dps, f"'{mmt_re}' not found in given context" + + def get_transform_function_mmt( + self, problem_size: ProblemSize, functionName: str, configuration: Configuration + ) -> str: + tile_sizes = ", ".join(map(str, get_mmt_tile_sizes(configuration))) + + wg_x, wg_y, wg_z = configuration.workgroup_size + extra_config = get_pipeline_config(configuration) + + return f""" + transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ + %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> + {extra_config}}}> + > -> !transform.any_param + transform.yield %matmul, %config : !transform.any_op, !transform.any_param + }} + """ + + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> MLIRTransformation: + 
def get_conv_tile_sizes(self, configuration: Configuration) -> list[int]:
    """Expand the (m, n, k) tile sizes into the seven-entry conv tile list.

    The returned order is ``[batch, oh, ow, oc, fh, fw, ic]``: ``m`` tiles
    ``ow``, ``n`` tiles ``oc`` and ``k`` tiles ``ic``; batch, ``oh``, ``fh``
    and ``fw`` are always tiled by 1.
    """
    m, n, k = configuration.tile_sizes
    # batch, oh, ow, oc, fh, fw, ic
    return [1, 1, m, n, 1, 1, k]
def get_transform_function_conv(
    self, problem_size: ProblemSize, functionName: str, configuration: Configuration
) -> str:
    """Render the transform-dialect sequence that matches this convolution.

    The sequence matches a ``linalg.conv_2d_nhwc_hwcf`` whose input and
    output types are the problem's types with the leading (batch) dimension
    made dynamic (printed as ``?``), and yields the compilation info built
    from ``configuration``.

    Args:
        problem_size: Supplies the input, filter and result tensor types.
            It is left unmodified.
        functionName: Symbol name given to the named sequence.
        configuration: Candidate whose tile sizes, workgroup and subgroup
            sizes, intrinsic and subgroup counts are embedded in the config.

    Returns:
        The MLIR text of the named sequence.
    """
    # Build new ShapedType values instead of rebinding `.shape` on the
    # caller's objects, so `problem_size` keeps its static batch size.
    input_shape = problem_size.lhs_type.shape.copy()
    input_shape[0] = -1  # -1 is printed as "?" by ShapedType.__str__
    output_shape = problem_size.res_type.shape.copy()
    output_shape[0] = -1

    dynamic_batch_input_ty = ShapedType(
        input_shape, problem_size.lhs_type.element_type
    )
    dynamic_batch_output_ty = ShapedType(
        output_shape, problem_size.res_type.element_type
    )

    input_ty = f"tensor<{dynamic_batch_input_ty}>"
    filter_ty = f"tensor<{problem_size.rhs_type}>"
    output_ty = f"tensor<{dynamic_batch_output_ty}>"

    tile_sizes = ", ".join(map(str, self.get_conv_tile_sizes(configuration)))

    wg_x, wg_y, wg_z = configuration.workgroup_size
    extra_config = get_pipeline_config(configuration)

    return f"""
    transform.named_sequence @{functionName}(%conv: !transform.any_op {{transform.readonly}})
      -> (!transform.any_op, !transform.any_param) {{
      %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv {{
      ^bb0(%lhs: {input_ty}, %rhs: {filter_ty}, %out: {output_ty}):
        %13 = linalg.conv_2d_nhwc_hwcf {{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}}
          ins(%lhs, %rhs : {input_ty}, {filter_ty})
          outs(%out : {output_ty}) -> {output_ty}
      }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
      %config = transform.param.constant #iree_codegen.compilation_info<
        lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
        translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
          workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
          {{mma_schedule = #iree_gpu.mma_schedule<
              intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
              subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
          {extra_config}}}>
        > -> !transform.any_param
      transform.yield %conv, %config : !transform.any_op, !transform.any_param
    }}
    """
rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(bmmt_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 3 + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 2 + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 3 + + B0, M0, K0 = lhs_shaped_type.shape + N1, K1 = rhs_shaped_type.shape + B2, M2, N2 = res_shaped_type.shape + assert B0 == B2 + assert M0 == M2 + assert N1 == N2 + assert K0 == K1 + return ProblemSize( + MatmulSize(M0, N1, K0, B0), + lhs_shaped_type, + rhs_shaped_type, + res_shaped_type, + DispatchKind.broadcast_rhs_mmt, + ) + + assert False, "Shape not found" + + def get_shapes(self, template: list[str]) -> ProblemSize: + if self.is_broadcast_rhs_mmt(template): + return self.get_shapes_broadcast_rhs_mmt(template) + + for line in template: + if "linalg.generic" not in line: + continue + if "lowering_config =" not in line: + continue + if '"reduction"' not in line: + continue + + # ins(%7, %8 : tensor<2x1024x1280xf16>, tensor<20x64x1280xf16>) + cont_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(cont_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == len(self.lhs_dims) + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == len(self.rhs_dims) + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() >= 2 + + M = math.prod( + val if dim == "m" else 1 + for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) + ) + N = math.prod( + val if dim == 
def get_transform_function_broadcast_rhs_mmt(
    self,
    problem_size: ProblemSize,
    functionName: str,
    configuration: Configuration,
) -> str:
    """Render the transform-dialect sequence for a broadcast-rhs mmt op.

    The sequence includes ``@match_broadcast_rhs_mmt_i8_i8_i32``, checks that
    the lhs operand is cast-compatible with the problem's lhs type with its
    leading (batch) dimension made dynamic, checks the rhs against the
    problem's rhs type, and yields the compilation info built from
    ``configuration``.

    Args:
        problem_size: Supplies the lhs and rhs tensor types. It is left
            unmodified.
        functionName: Symbol name given to the named sequence.
        configuration: Candidate whose tile sizes, workgroup and subgroup
            sizes, intrinsic and subgroup counts are embedded in the config.

    Returns:
        The MLIR text of the named sequence.
    """
    tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration)))

    wg_x, wg_y, wg_z = configuration.workgroup_size
    extra_config = get_pipeline_config(configuration)

    # Build a new ShapedType instead of rebinding `.shape` on the caller's
    # object, so `problem_size.lhs_type` keeps its static batch size.
    lhs_shape = problem_size.lhs_type.shape.copy()
    lhs_shape[0] = -1  # -1 is printed as "?" by ShapedType.__str__
    lhs_dynamic_batch = ShapedType(lhs_shape, problem_size.lhs_type.element_type)

    return f"""
transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
%mmt = transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op
%lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value
%rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value
transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : !transform.any_value
transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
%config = transform.param.constant #iree_codegen.compilation_info<
    lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
    translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
    workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
    {{mma_schedule = #iree_gpu.mma_schedule<
        intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
        subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
    {extra_config}}}>
    > -> !transform.any_param
transform.yield %generic, %config : !transform.any_op, !transform.any_param
}}
"""
+ return MLIRTransformation( + template, + apply_configuration( + template, + configuration, + get_contract_tile_sizes(configuration, self.tile_dims), + ), + "", + ) + + +class BatchMmtTuner(DispatchTuner): + def supports(self, op_name: str) -> bool: + return "batch_matmul_transpose_b" in op_name + + def get_shapes(self, template: list[str]) -> ProblemSize: + for line in template: + if "linalg.generic" not in line: + continue + if ( + r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' + not in line + ): + continue + # ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) + bmmt_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(bmmt_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 3 + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 3 + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 3 + + B0, M0, K0 = lhs_shaped_type.shape + B1, N1, K1 = rhs_shaped_type.shape + B2, M2, N2 = res_shaped_type.shape + assert B0 == B1 + assert B0 == B2 + assert M0 == M2 + assert N1 == N2 + assert K0 == K1 + return ProblemSize( + MatmulSize(M0, N1, K0, B0), + lhs_shaped_type, + rhs_shaped_type, + res_shaped_type, + DispatchKind.batch_mmt, + ) + + assert False, "Shape not found" + + def get_transform_function_batch_mmt( + self, + problem_size: ProblemSize, + functionName: str, + configuration: Configuration, + ) -> str: + tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration))) + + wg_x, wg_y, wg_z = configuration.workgroup_size + extra_config = get_pipeline_config(configuration) + + return f""" +transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> 
(!transform.any_op, !transform.any_param) {{ +%mmt = transform.include @match_batch_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op +%lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value +%rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value +transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value +transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value +%config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> + {extra_config}}}> + > -> !transform.any_param +transform.yield %generic, %config : !transform.any_op, !transform.any_param +}} +""" + + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> MLIRTransformation: + M, N, K = problem_size.MNK + B = problem_size.matmul_size.B + modified = indent( + self.get_transform_function_batch_mmt( + problem_size, f"match_batch_mmt_{B}x{M}x{N}x{K}", configuration + ), + "// ", + ) + modified += apply_configuration( + template, configuration, get_batch_mmt_tile_sizes(configuration) + ) + + embeddable = indent( + self.get_transform_function_batch_mmt( + problem_size, f"match_op", configuration + ), + " ", + ) + return MLIRTransformation(template, modified, embeddable) + + +class BatchMatmulTuner(DispatchTuner): + def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): + self.lhs_dims = lhs_dims + self.rhs_dims = rhs_dims + self.tile_dims = tile_dims + + def supports(self, op_name: str) -> bool: + return "batch_matmul" in op_name + + def get_shapes(self, template: list[str]) -> ProblemSize: + for line in template: + if 
"linalg.batch_matmul" not in line: + continue + # ins(%9, %10 : tensor<64x72x1280xf16>, tensor<64x1280x1280xf16>) + # outs(%12 : tensor<64x72x1280xf32>) + cont_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(cont_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == len(self.lhs_dims) + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == len(self.rhs_dims) + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == lhs_shaped_type.rank() + + LHS = lhs_shaped_type.shape + RHS = rhs_shaped_type.shape + RES = res_shaped_type.shape + + B = math.prod( + val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, LHS) + ) + B0 = math.prod( + val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RHS) + ) + B1 = math.prod( + val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RES) + ) + M = math.prod( + val if dim == "m" else 1 for dim, val in zip(self.lhs_dims, LHS) + ) + N = math.prod( + val if dim == "n" else 1 for dim, val in zip(self.rhs_dims, RHS) + ) + K0 = math.prod( + val if dim == "k" else 1 for dim, val in zip(self.lhs_dims, LHS) + ) + K1 = math.prod( + val if dim == "k" else 1 for dim, val in zip(self.rhs_dims, RHS) + ) + assert B == B0 and B == B1 + assert K0 == K1 + + return ProblemSize( + MatmulSize(M, N, K0, B), + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.batch_matmul, + ) + + assert False, "Shape not found" + + def get_transform_function_batch_matmul( + self, + problem_size: ProblemSize, + tile_dims: str, + functionName: str, + configuration: Configuration, + ) -> str: + input0 = f"tensor<{problem_size.lhs_type}>" + input1 = f"tensor<{problem_size.rhs_type}>" + output = 
f"tensor<{problem_size.res_type}>" + + tile_sizes = ", ".join( + map(str, get_contract_tile_sizes(configuration, tile_dims)) + ) + + wg_x, wg_y, wg_z = configuration.workgroup_size + extra_config = get_pipeline_config(configuration) + + return f""" + transform.named_sequence @{functionName}(%batch_matmul: !transform.any_op {{transform.readonly}}) + -> (!transform.any_op, !transform.any_param) {{ + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %batch_matmul {{ + ^bb0(%lhs: {input0}, %rhs: {input1}, %out: {output}): + %13 = linalg.batch_matmul + ins(%lhs, %rhs : {input0}, {input1}) + outs(%out : {output}) -> {output} + }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> + {extra_config}}}> + > -> !transform.any_param + transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param + }} + """ + + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> MLIRTransformation: + M, N, K = problem_size.MNK + modified = indent( + self.get_transform_function_batch_matmul( + problem_size, + self.tile_dims, + f"match_batch_matmul_{problem_size.matmul_size.B}x{M}x{N}x{K}", + configuration, + ), + "// ", + ) + modified += apply_configuration( + template, + configuration, + get_contract_tile_sizes(configuration, self.tile_dims), + ) + + embeddable = indent( + self.get_transform_function_batch_matmul( + problem_size, self.tile_dims, f"match_op", configuration + ), + " ", + ) + return MLIRTransformation(template, modified, embeddable) + + +def walk_callback_get_fn( + op: ir.Operation, + walk_result: OpWalkResult, + dispatch_tuner_registry: DispatchTunerRegistry, +) -> ir.WalkResult: + if 
op.name == "func.func": + dispatch_tuner_registry.validate_translation([a for a in op.opview.attributes]) + if op.name == "util.func": + func_name = str(op.opview.sym_name) + walk_result.was_interrupted = True + walk_result.dispatch_tuner = dispatch_tuner_registry.find_handler(func_name) + return ir.WalkResult.INTERRUPT + return ir.WalkResult.ADVANCE + + +def walk_mlir_op( + mlir_module: ir.Module, + dispatch_tuner_registry: DispatchTunerRegistry, +) -> OpWalkResult: + walk_result = OpWalkResult() + for op in mlir_module.body.operations: + op.walk( + lambda op: walk_callback_get_fn(op, walk_result, dispatch_tuner_registry), + ir.WalkOrder.POST_ORDER, + ) + if walk_result.was_interrupted: + break + return walk_result + + +def tune( + input: str, # Path to the mlir file to be tuned + output: str = "", # Path to the output directory, auto creates one if not given + limit: int = 4096, # Max candidates to be generated + num_subgroups: int = 4, # GPU spec, used to determine candidate generation constraints + lhs_dims: str = "mk", # Dimensions for the left-hand side operand in matrix operations + rhs_dims: str = "nk", # Dimensions for the right-hand side operand in matrix operations + tile_dims: str = "mnk", # Dimensions for the tile size +): + input_file = str(input) + + if not output: + output = get_default_output_dir() + + # Create the directory if it does not exist + makedirs(str(output), exist_ok=True) + + tune_logger.debug(f"Output directory {output}") + tune_logger.debug(f"Processing {input_file}") + mlir_template = read_input_mlir(input_file) + mlir_text = "".join(mlir_template) + + mlir_module = parse_mlir(mlir_text) + # Save the input file as the first candidate. 
+ with open(path.join(output, f"0.mlir"), "w") as f: + f.write(mlir_text) + + dispatch_tuner_registry = DispatchTunerRegistry() + dispatch_tuner_registry.register( + [ + MmtTuner(), + ConvTuner(), + ContractionTuner(lhs_dims, rhs_dims, tile_dims), + BatchMmtTuner(), + BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims), + ] + ) + + walk_result = walk_mlir_op(mlir_module, dispatch_tuner_registry) + + dispatch_tuner = walk_result.dispatch_tuner + problem_size = dispatch_tuner.get_shapes(mlir_template) + tune_logger.debug(str(problem_size)) + configs = [] + for i, config in enumerate(generate_solutions(problem_size, num_subgroups)): + if i >= limit: + break + tune_logger.info(f"Solution #{i+1}: {config}") + configs.append(config) + tf_mlir = dispatch_tuner.apply_params(problem_size, mlir_template, config) + + with open(path.join(output, f"{i+1}.mlir"), "w") as f: + f.write(tf_mlir.modified) + with open(path.join(output, f"{i+1}_config.mlir"), "w") as f: + f.write(tf_mlir.embeddable) + + with open(path.join(output, "configs.pkl"), "wb") as file: + pickle.dump(configs, file) + + tune_logger.info(f"Generated {len(configs)} candidates") + tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input", help="Input mlir file", type=str) + parser.add_argument( + "-o", "--output", help="Output dir", type=str, default=get_default_output_dir() + ) + parser.add_argument( + "-l", + "--limit", + help="Max number of candidates generated", + type=int, + default=4096, + ) + parser.add_argument( + "--num-subgroups", + help="Number of subgroups per workgroup to use. 
(-1 == unconstrained)", + type=int, + default=-1, + ) + parser.add_argument( + "--lhs-dims", help="Map of LHS matmul dims", type=str, default="mk" + ) + parser.add_argument( + "--rhs-dims", help="Map of RHS matmul dims", type=str, default="nk" + ) + parser.add_argument( + "--tile-dims", help="Map of tile size matmul dims", type=str, default="mnk" + ) + parser.add_argument( + "--verbose", "-v", action="store_true", help="Enable verbose output to stdout" + ) + + args = parser.parse_args() + tune_logger.setLevel(logging.DEBUG if args.verbose else logging.INFO) + + # Create printing formatter for logging info + formatter = logging.Formatter("%(message)s") + + # Create a handler to print to console + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + tune_logger.addHandler(console_handler) + + # # Optionally, add a file handler to log to a file + # file_handler = logging.FileHandler("tune.log") + # file_handler.setFormatter(formatter) + # tune_logger.addHandler(file_handler) + + tune( + args.input, + args.output, + args.limit, + args.num_subgroups, + args.lhs_dims, + args.rhs_dims, + args.tile_dims, + ) + + +if __name__ == "__main__": + args = main() diff --git a/tuner/candidate_gen_test.py b/tuner/candidate_gen_test.py new file mode 100644 index 000000000..52b01869b --- /dev/null +++ b/tuner/candidate_gen_test.py @@ -0,0 +1,814 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Usage: python -m pytest candidate_gen_test.py +""" + +import pytest +import candidate_gen + + +def test_get_shaped_type_element_bitwidth(): + assert ( + candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8).bitwidth + == 8 + ) + assert ( + candidate_gen.ShapedType([2048], candidate_gen.ElementType.i32).bitwidth == 32 + ) + assert ( + candidate_gen.ShapedType( + [2048, 512, 384], candidate_gen.ElementType.f8 + ).bitwidth + == 8 + ) + assert ( + candidate_gen.ShapedType([1, 1], candidate_gen.ElementType.f16).bitwidth == 16 + ) + + +def test_get_shaped_type_to_str(): + assert ( + str(candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8)) + == "1024x2048xi8" + ) + assert ( + str(candidate_gen.ShapedType([1024], candidate_gen.ElementType.f32)) + == "1024xf32" + ) + assert ( + str(candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f16)) + == "1x2x3xf16" + ) + assert ( + str(candidate_gen.ShapedType([-1, 2, 3], candidate_gen.ElementType.f16)) + == "?x2x3xf16" + ) + + +def test_parse_tensor_type(): + assert candidate_gen.parse_tensor_type( + "tensor<1x2x3xf32>" + ) == candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f32) + assert candidate_gen.parse_tensor_type( + "tensor<123xi8>" + ) == candidate_gen.ShapedType([123], candidate_gen.ElementType.i8) + + +def test_get_mmt_tile_sizes(): + config = candidate_gen.Configuration( + subgroup_size=0, + workgroup_size=[], + intrinsic="", + tile_sizes=[128, 320, 32], + subgroup_m_count=0, + subgroup_n_count=0, + waves_per_eu=0, + ) + assert candidate_gen.get_mmt_tile_sizes(config) == [128, 320, 32] + + +def test_get_conv_tile_sizes(): + config = candidate_gen.Configuration( + subgroup_size=64, + workgroup_size=[256, 1, 1], + intrinsic="#iree_gpu.mma_layout", + tile_sizes=[464, 320, 16], + subgroup_m_count=1, + subgroup_n_count=4, + waves_per_eu=1, + ) + assert candidate_gen.ConvTuner().get_conv_tile_sizes(config) 
== [ + 1, + 1, + 464, + 320, + 1, + 1, + 16, + ] + + +def test_get_contract_tile_sizes(): + config = candidate_gen.Configuration( + subgroup_size=32, + workgroup_size=[16, 16, 1], + intrinsic="", + tile_sizes=[4, 8, 16], + subgroup_m_count=1, + subgroup_n_count=1, + waves_per_eu=2, + ) + assert candidate_gen.get_contract_tile_sizes(config, ["m", "n", "k"]) == [4, 8, 16] + assert candidate_gen.get_contract_tile_sizes(config, ["n", "m", "k"]) == [8, 4, 16] + assert candidate_gen.get_contract_tile_sizes(config, ["k", "n", "m"]) == [16, 8, 4] + assert candidate_gen.get_contract_tile_sizes(config, ["k", "k", "k"]) == [ + 16, + 16, + 16, + ] + + +def test_get_pipeline_config(): + config1 = candidate_gen.Configuration( + subgroup_size=32, + workgroup_size=[16, 16, 1], + intrinsic="", + tile_sizes=[4, 8, 16], + subgroup_m_count=1, + subgroup_n_count=1, + waves_per_eu=2, + ) + config2 = candidate_gen.Configuration( + subgroup_size=32, + workgroup_size=[16, 16, 1], + intrinsic="", + tile_sizes=[4, 8, 16], + subgroup_m_count=1, + subgroup_n_count=1, + waves_per_eu=4, + ) + assert candidate_gen.get_pipeline_config(config1) == ", prefetch_shared_memory" + assert ( + candidate_gen.get_pipeline_config(config2) + == ', prefetch_shared_memory, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' + ) + + +def test_get_shapes_mmt(): + template = [ + r"%18 = tensor.empty() : tensor<2048x1280xf32>", + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", + r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + r"^bb0(%in: f16, %in_0: f16, %out: f32):", + ] + assert 
candidate_gen.MmtTuner().get_shapes(template) == candidate_gen.ProblemSize( + candidate_gen.MatmulSize(2048, 1280, 1280), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.mmt, + ) + + +def test_get_shapes_conv(): + template = [ + r"%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", + r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", + r"flow.dispatch.tensor.store %8, %2, offsets = [%workgroup_id_z, %workgroup_id_y, 0, %3], sizes = [1, 1, 32, 256], strides = [1, 1, 1, 1] : tensor<1x1x32x256xf32> -> !flow.dispatch.tensor>", + ] + assert candidate_gen.ConvTuner().get_shapes(template) == candidate_gen.ProblemSize( + candidate_gen.MatmulSize(32, 256, 11520), + candidate_gen.ShapedType([1, 3, 34, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([3, 3, 1280, 256], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([1, 1, 32, 256], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.conv, + ) + + +def test_get_shapes_contract(): + template = [ + r"%18 = tensor.empty() : tensor<2048x1280xf32>", + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", + r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : 
tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + r"^bb0(%in: f16, %in_0: f16, %out: f32):", + ] + assert candidate_gen.ContractionTuner("mk", "nk", "mnk").get_shapes( + template + ) == candidate_gen.ProblemSize( + candidate_gen.MatmulSize(2048, 1280, 1280), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.contraction, + ) + + +def test_get_shapes_batch_matmul(): + template = [ + "%10 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", + "%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", + "flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 32, 32], strides = [1, 1, 1] : tensor<1x32x32xf32> -> !flow.dispatch.tensor>", + ] + assert candidate_gen.BatchMatmulTuner("bmk", "bkn", "mnk").get_shapes( + template + ) == candidate_gen.ProblemSize( + candidate_gen.MatmulSize(32, 32, 1024, 1), + candidate_gen.ShapedType([1, 32, 1024], candidate_gen.ElementType.f32), + candidate_gen.ShapedType([1, 1024, 32], candidate_gen.ElementType.f32), + candidate_gen.ShapedType([1, 32, 32], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.batch_matmul, + ) + + +def test_get_shapes_batch_mmt(): + template = [ + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>", + r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : 
tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + r"flow.dispatch.tensor.store %21, %10, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : tensor<2x4096x640xf16> -> !flow.dispatch.tensor>", + ] + assert candidate_gen.BatchMmtTuner().get_shapes( + template + ) == candidate_gen.ProblemSize( + candidate_gen.MatmulSize(4096, 640, 640, 2), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), + candidate_gen.DispatchKind.batch_mmt, + ) + + +def test_mfma_intrinsic_to_str(): + assert ( + str(candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32()) + == "MFMA_F16_16x16x16_F32" + ) + assert ( + str(candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32()) + == "MFMA_I8_32x32x16_I32" + ) + + +def test_get_compatible_mfma_intrinsics(): + assert candidate_gen.get_compatible_mfma_intrinsics( + candidate_gen.ProblemSize( + candidate_gen.MatmulSize(2048, 1280, 1280), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.mmt, + ) + ) == [ + candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), + candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), + ] + + assert candidate_gen.get_compatible_mfma_intrinsics( + candidate_gen.ProblemSize( + candidate_gen.MatmulSize(2048, 1280, 1280), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i32), + candidate_gen.DispatchKind.mmt, + ) + ) == [ + candidate_gen.MfmaIntrinsic.mfma_i8_16x16x32_i32(), + candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(), + ] + + assert 
candidate_gen.get_compatible_mfma_intrinsics( + candidate_gen.ProblemSize( + candidate_gen.MatmulSize(968, 320, 640, 64), + candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f32), + candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f32), + candidate_gen.ShapedType([64, 968, 320], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.batch_matmul, + ) + ) == [ + candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), + candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), + ] + + +def test_generate_solutions(): + matmul_size = candidate_gen.MatmulSize(2048, 3840, 1280) + lhs_type = candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16) + rhs_type = candidate_gen.ShapedType([3840, 1280], candidate_gen.ElementType.f16) + res_type = candidate_gen.ShapedType([2048, 3840], candidate_gen.ElementType.f32) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + ) + configs = candidate_gen.generate_solutions(problem_size, 4) + assert configs is not None + + +def test_calculate_shared_memory_usage_in_bytes(): + matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) + lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + ) + assert ( + candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) + == 147456 + ) + + lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i8) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + ) + assert ( + candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) + == 81920 + ) + + 
rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i32) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + ) + assert ( + candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 128, 64, 32) + == 12288 + ) + + +def test_generate_constraints_valid_input(): + matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) + lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + ) + # Define input parameters as z3 Ints + m, n, k = ( + candidate_gen.z3.Int("m"), + candidate_gen.z3.Int("n"), + candidate_gen.z3.Int("k"), + ) + subgroup_size = candidate_gen.z3.Int("subgroup_size") + intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn") + intrinsic_k = candidate_gen.z3.Int("intrinsic_k") + wg_x, wg_y, wg_z = ( + candidate_gen.z3.Int("wg_x"), + candidate_gen.z3.Int("wg_y"), + candidate_gen.z3.Int("wg_z"), + ) + sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt") + sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt") + waves_per_eu = candidate_gen.z3.Int("waves_per_eu") + + constraints = candidate_gen.generate_constraints( + problem_size, + [m, n, k], + 4, + subgroup_size, + [intrinsic_mn, intrinsic_k], + [wg_x, wg_y, wg_z], + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + ) + + solver = candidate_gen.z3.Solver() + solver.add(constraints) + + # Check if the constraints are satisfiable + assert solver.check() == candidate_gen.z3.sat + + +def test_generate_constraints_invalid_input(): + # Define input parameters that should lead to unsatisfiable constraints + matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) + lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) 
+ rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + ) + m, n, k = ( + candidate_gen.z3.Int("m"), + candidate_gen.z3.Int("n"), + candidate_gen.z3.Int("k"), + ) + subgroup_size = candidate_gen.z3.Int("subgroup_size") + intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn") + intrinsic_k = candidate_gen.z3.Int("intrinsic_k") + wg_x, wg_y, wg_z = ( + candidate_gen.z3.Int("wg_x"), + candidate_gen.z3.Int("wg_y"), + candidate_gen.z3.Int("wg_z"), + ) + sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt") + sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt") + waves_per_eu = candidate_gen.z3.Int("waves_per_eu") + + constraints = candidate_gen.generate_constraints( + problem_size, + [m, n, k], + 4, + subgroup_size, + [intrinsic_mn, intrinsic_k], + [wg_x, wg_y, wg_z], + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + ) + constraints.append(m > 1000) # Adding an additional unsatisfiable constraint + + solver = candidate_gen.z3.Solver() + solver.add(constraints) + + # Check if the constraints are unsatisfiable + assert solver.check() == candidate_gen.z3.unsat + + +def test_apply_params_mmt(): + mlir_template = [ + ", subgroup_m_count = 16, subgroup_n_count = 16>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', + ] + + M, N, K = 2048, 1280, 1280 + + config = candidate_gen.Configuration( + subgroup_size=16, + workgroup_size=[16, 16, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), + tile_sizes=[8, 8, 8], + subgroup_m_count=16, + subgroup_n_count=16, + waves_per_eu=8, + ) + + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(M, N, K), + candidate_gen.ShapedType([M, K], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([N, K], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([M, N], 
candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.mmt, + ) + tf_mlir = candidate_gen.MmtTuner().apply_params(problem_size, mlir_template, config) + + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + + assert modified + assert embeddable + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 16, subgroup_n_count = 16" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [16, 16, 1] subgroup_size = 16" + in modified + ) + assert "tile_sizes = [[8, 8, 8]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "8"}' in modified + + +def test_apply_params_conv(): + mlir_template = [ + ", subgroup_m_count = 16, subgroup_n_count = 16>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', + ] + + n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640 + + config = candidate_gen.Configuration( + subgroup_size=64, + workgroup_size=[256, 1, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), + tile_sizes=[464, 320, 16], + subgroup_m_count=1, + subgroup_n_count=4, + waves_per_eu=2, + ) + + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(oh * ow, oc, fh * fw * ic), + candidate_gen.ShapedType( + [n, oh + 2, ow + 2, oc], candidate_gen.ElementType.f16 + ), + candidate_gen.ShapedType([fh, fw, ic, oc], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([n, oh, ow, oc], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.conv, + ) + tf_mlir = candidate_gen.ConvTuner().apply_params( + problem_size, mlir_template, config + ) + + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + + assert modified + assert embeddable + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 1, subgroup_n_count = 4" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64" + in modified + ) + assert "tile_sizes = [[1, 1, 464, 320, 1, 1, 16]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = 
"2"}' in modified + + +def test_apply_params_contract(): + mlir_template = [ + ", subgroup_m_count = 2, subgroup_n_count = 2>}>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', + ] + + tile_dims = "*mnk" + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(2048, 3840, 1280), + candidate_gen.ShapedType([2, 1024, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([3, 20, 64, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([3, 2, 20, 1024, 64], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.contraction, + ) + + config = candidate_gen.Configuration( + subgroup_size=64, + workgroup_size=[256, 1, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), + tile_sizes=[480, 384, 32], + subgroup_m_count=1, + subgroup_n_count=4, + waves_per_eu=2, + ) + + tf_mlir = candidate_gen.ContractionTuner("mk", "nk", tile_dims).apply_params( + problem_size, mlir_template, config + ) + + new_mlir = tf_mlir.modified + + assert new_mlir + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 1, subgroup_n_count = 4" + in new_mlir + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64" + in new_mlir + ) + assert "tile_sizes = [[1, 480, 384, 32]]" in new_mlir + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in new_mlir + + +def test_apply_params_batch_matmul(): + mlir_template = [ + ", subgroup_m_count = 4, subgroup_n_count = 1>}>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', + ] + + tile_dims = "bmnk" + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(968, 320, 640, 64), + candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([64, 968, 320], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.batch_matmul, + ) + + config = candidate_gen.Configuration( + subgroup_size=64, + 
workgroup_size=[128, 2, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), + tile_sizes=[416, 320, 128], + subgroup_m_count=2, + subgroup_n_count=2, + waves_per_eu=2, + ) + + tf_mlir = candidate_gen.BatchMatmulTuner("mk", "nk", tile_dims).apply_params( + problem_size, mlir_template, config + ) + + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + + assert modified + assert embeddable + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" + in modified + ) + assert "tile_sizes = [[1, 416, 320, 128]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified + + +def test_apply_params_batch_mmt_float(): + mlir_template = [ + ", subgroup_m_count = 4, subgroup_n_count = 1>}>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', + ] + + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(4096, 640, 640, 2), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.batch_mmt, + ) + + config = candidate_gen.Configuration( + subgroup_size=64, + workgroup_size=[128, 2, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), + tile_sizes=[128, 64, 128], + subgroup_m_count=2, + subgroup_n_count=2, + waves_per_eu=2, + ) + + tf_mlir = candidate_gen.BatchMmtTuner().apply_params( + problem_size, mlir_template, config + ) + + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + + assert embeddable + assert modified + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" + in modified + ) + assert "tile_sizes 
= [[1, 128, 64, 128]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified + + +def test_apply_params_batch_mmt_int(): + mlir_template = [ + ", subgroup_m_count = 4, subgroup_n_count = 1>}>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', + ] + + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(4096, 640, 640, 2), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), + candidate_gen.DispatchKind.batch_mmt, + ) + + config = candidate_gen.Configuration( + subgroup_size=64, + workgroup_size=[128, 2, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(), + tile_sizes=[128, 64, 128], + subgroup_m_count=2, + subgroup_n_count=2, + waves_per_eu=4, + ) + + tf_mlir = candidate_gen.BatchMmtTuner().apply_params( + problem_size, mlir_template, config + ) + + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + + assert modified + assert "// transform.named_sequence @match_batch_mmt_2x4096x640x640(" in modified + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" + in modified + ) + assert "tile_sizes = [[1, 128, 64, 128]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified + + assert embeddable + assert "transform.named_sequence @match_op(" in embeddable + assert ( + "transform.include @match_batch_mmt_i8_i8_i32 failures(propagate)" in embeddable + ) + assert ( + "transform.iree.match.cast_compatible_type %lhs = tensor<2x4096x640xi8> : !transform.any_value" + in embeddable + ) + assert ( + "transform.iree.match.cast_compatible_type %rhs = tensor<2x640x640xi8> : !transform.any_value" + in embeddable + ) + assert ( + "%config = transform.param.constant 
#iree_codegen.compilation_info<" + in embeddable + ) + assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable + assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable + assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable + + +def test_apply_params_broadcast_rhs_mmt(): + mlir_template = [ + ", subgroup_m_count = 4, subgroup_n_count = 1>}>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', + ] + + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(4096, 640, 640, 2), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([640, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), + candidate_gen.DispatchKind.broadcast_rhs_mmt, + ) + + config = candidate_gen.Configuration( + subgroup_size=64, + workgroup_size=[128, 2, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(), + tile_sizes=[128, 64, 128], + subgroup_m_count=2, + subgroup_n_count=2, + waves_per_eu=4, + ) + + tf_mlir = candidate_gen.ContractionTuner( + "mk", "nk", "mnk" + ).apply_params_broadcast_rhs_mmt(problem_size, mlir_template, config) + + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + + assert modified + assert ( + "// transform.named_sequence @match_broadcast_rhs_mmt_Bx4096x640x640(" + in modified + ) + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" + in modified + ) + assert "tile_sizes = [[1, 128, 64, 128]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified + + assert embeddable + assert "transform.named_sequence @match_op(" in embeddable + assert ( + "transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate)" + in embeddable + ) + assert ( + "transform.iree.match.cast_compatible_type %lhs = tensor : 
!transform.any_value" + in embeddable + ) + assert ( + "transform.iree.match.cast_compatible_type %rhs = tensor<640x640xi8> : !transform.any_value" + in embeddable + ) + assert ( + "%config = transform.param.constant #iree_codegen.compilation_info<" + in embeddable + ) + assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable + assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable + assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable + + +def test_detect_broadcast_rhs_mmt(): + mlir_lines = [ + r"%18 = tensor.empty() : tensor<2x1024x10240xi32>", + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>", + r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + ] + assert candidate_gen.ContractionTuner("mk", "nk", "mnk").is_broadcast_rhs_mmt( + mlir_lines + ) + + +def test_parse_mlir(): + mlir_str = r""" + builtin.module { + func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + %0 = arith.mulf %arg0, %arg1 : tensor<4xf32> + return %0 : tensor<4xf32> + } + } + """ + mlir_module = candidate_gen.parse_mlir(mlir_str) + assert mlir_module is not None + assert isinstance(mlir_module, candidate_gen.ireec._mlir_libs._mlir.ir.Module) + assert isinstance( + mlir_module.body.operations[0], candidate_gen.ireec.dialects.func.FuncOp + ) diff --git a/tuner/requirements-dev.txt b/tuner/requirements-dev.txt new file mode 100644 index 000000000..51d5b9ba0 --- /dev/null +++ b/tuner/requirements-dev.txt @@ -0,0 +1,2 @@ +pre-commit==3.8.0 +virtualenv==20.13.0 diff --git
a/tuner/requirements-tuner.txt b/tuner/requirements-tuner.txt new file mode 100644 index 000000000..f3484c921 --- /dev/null +++ b/tuner/requirements-tuner.txt @@ -0,0 +1,4 @@ +pytest==8.2.2 +tqdm==4.66.4 +z3_solver==4.13.0.0 +types-tqdm==4.66.0.20240417