diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index e47100059..06ccae0e3 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -20,7 +20,6 @@ import argparse import logging -import math import pickle import re import z3 # type: ignore @@ -28,36 +27,16 @@ from os import path, makedirs from typing import Optional from textwrap import indent -from abc import ABC, abstractmethod +from abc import abstractmethod from iree.compiler import ir # type: ignore from .common import * +from .dispatch_parser import * tune_logger = logging.getLogger("tune") -def get_mmt_tile_sizes(configuration: Configuration): - return configuration.tile_sizes - - -def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> list[int]: - m, n, k = configuration.tile_sizes - tile_size = [1] * len(tile_dims) - for idx, dim in enumerate(tile_dims): - if dim == "m": - tile_size[idx] = m - if dim == "n": - tile_size[idx] = n - if dim == "k": - tile_size[idx] = k - return tile_size - - -def get_batch_mmt_tile_sizes(configuration: Configuration) -> list[int]: - return [1] + configuration.tile_sizes - - def apply_configuration( template: list[str], configuration: Configuration, tile_sizes: list[int] ) -> str: @@ -282,29 +261,8 @@ def get_default_output_dir() -> str: return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M") -def parse_mlir(mlir_text: str, ctx: ir.Context) -> ir.Module: - mlir_module = None - try: - mlir_module = ir.Module.parse(mlir_text) - tune_logger.info("MLIR parsing successful!") - except ir.MLIRError as e: - tune_logger.error(f"Error parsing MLIR: {e}") - raise RuntimeError(f"Error parsing MLIR: {e}") - - return mlir_module - - -class DispatchTuner(ABC): - @abstractmethod - def supports(self, op_name: str) -> bool: - """Check if the tuner can handle the type of operation represented by the input string.""" - pass - - @abstractmethod - def get_shapes(self, template: list[str]) -> ProblemSize: - """Extract problem size of the operation.""" - pass - +class DispatchTuner(DispatchParser): + # TODO(https://github.com/nod-ai/SHARK-Platform/issues/453): Remove this in favor of configuring using transform dialect. 
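+    # A DispatchTuner is both a parser and a rewriter: the DispatchParser side +    # (moved to dispatch_parser.py) extracts the ProblemSize from a dispatch, +    # while apply_params() below applies a candidate Configuration to the MLIR template.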
@abstractmethod def apply_params( self, @@ -316,12 +274,6 @@ def apply_params( pass -@dataclass -class OpWalkResult: - was_interrupted: bool = False - dispatch_tuner: Optional[DispatchTuner] = None - - class DispatchTunerRegistry: def __init__(self): self.registry = set() @@ -345,60 +297,7 @@ def find_handler(self, op_name: str) -> DispatchTuner: assert False, "Dispatch kind not supported" -class MmtTuner(DispatchTuner): - def supports(self, op_name: str) -> bool: - return "matmul_transpose_b" in op_name - - def get_shapes(self, template: list[str]) -> ProblemSize: - mmt_re = None - dps = None - for line in template: - if "linalg.generic" not in line: - continue - if r'iterator_types = ["parallel", "parallel", "reduction"]' not in line: - continue - # ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) - mmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - dps = re.search(mmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 2 - lhs_M, lhs_K = lhs_shaped_type.shape - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 2 - rhs_N, rhs_K = rhs_shaped_type.shape - - assert lhs_shaped_type.element_type == rhs_shaped_type.element_type - assert lhs_K == rhs_K - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 2 - res_M, res_N = res_shaped_type.shape - - assert lhs_M == res_M - assert rhs_N == res_N - - matmul_size = MatmulSize( - lhs_shaped_type.shape[0], - rhs_shaped_type.shape[0], - lhs_shaped_type.shape[1], - ) - return ProblemSize( - matmul_size, - lhs_type=lhs_shaped_type, - rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.mmt, - ) - assert mmt_re - assert False, f"'{mmt_re}' not found in given context" - +class MmtTuner(DispatchTuner, MmtParser): def get_transform_function_mmt( self, problem_size: ProblemSize, functionName: str, configuration: Configuration ) -> str: @@ -450,71 +349,7 @@ def apply_params( return MLIRTransformation(template, modified, embeddable) -class ConvTuner(DispatchTuner): - def supports(self, op_name: str) -> bool: - return "conv_2d_nhwc_hwcf" in op_name - - def get_conv_tile_sizes(self, configuration: Configuration) -> list[int]: - m, n, k = configuration.tile_sizes - batch = 1 - fh = 1 - fw = 1 - - oh = 1 - - oc = n - ow = m - ic = k - return [batch, oh, ow, oc, fh, fw, ic] - - def get_shapes(self, template: list[str]) -> ProblemSize: - for line in template: - if "linalg.conv_2d_nhwc_hwcf" not in line: - continue - - # ins(%19, %20 : tensor<2x34x34x1280xf16>, tensor<3x3x1280x1280xf16>) outs (%27 : tensor<2x32x32x1280xf32>) - conv_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(conv_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 4 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 4 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 4 - - # int64_t n = outputShape[0]; - # int64_t oh = outputShape[1]; - # int64_t ow = outputShape[2]; - # int64_t oc = outputShape[3]; - # int64_t 
fh = filterShape[0]; - # int64_t fw = filterShape[1]; - # int64_t ic = filterShape[2]; - dim_info = ConvDimInfo.from_rhs_res(rhs_shaped_type, res_shaped_type) - return ProblemSize( - MatmulSize( - M=dim_info.oh * dim_info.ow, - N=dim_info.oc, - K=dim_info.fh * dim_info.fw * dim_info.ic, - B=dim_info.n, - ), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.conv, - ) - - assert False, "Shape not found" - +class ConvTuner(DispatchTuner, ConvParser): # int64_t n = outputShape[0]; # int64_t oh = outputShape[1]; # int64_t ow = outputShape[2]; @@ -589,135 +424,7 @@ def apply_params( return MLIRTransformation(template, modified, embeddable) -class ContractionTuner(DispatchTuner): - def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): - self.lhs_dims = lhs_dims - self.rhs_dims = rhs_dims - self.tile_dims = tile_dims - - def supports(self, op_name: str) -> bool: - return "matmul_like" in op_name - - def is_broadcast_rhs_mmt_op(self, line: str) -> bool: - if "linalg.generic" not in line: - return False - if ( - r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' - not in line - ): - return False - if ( - r"indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>" - not in line - ): - return False - return True - - def is_broadcast_rhs_mmt(self, template: list[str]) -> bool: - return any(self.is_broadcast_rhs_mmt_op(line) for line in template) - - def get_shapes_broadcast_rhs_mmt(self, template: list[str]) -> ProblemSize: - for line in template: - if not self.is_broadcast_rhs_mmt_op(line): - continue - - # ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) - bmmt_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(bmmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 3 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 2 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 3 - - B0, M0, K0 = lhs_shaped_type.shape - N1, K1 = rhs_shaped_type.shape - B2, M2, N2 = res_shaped_type.shape - assert B0 == B2 - assert M0 == M2 - assert N1 == N2 - assert K0 == K1 - return ProblemSize( - MatmulSize(M0, N1, K0, B0), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.broadcast_rhs_mmt, - ) - - assert False, "Shape not found" - - def get_shapes(self, template: list[str]) -> ProblemSize: - if self.is_broadcast_rhs_mmt(template): - return self.get_shapes_broadcast_rhs_mmt(template) - - for line in template: - if "linalg.generic" not in line: - continue - if "lowering_config =" not in line: - continue - if '"reduction"' not in line: - continue - - # ins(%7, %8 : tensor<2x1024x1280xf16>, tensor<20x64x1280xf16>) - cont_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(cont_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == len(self.lhs_dims) - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == len(self.rhs_dims) - - res_tensor_type = dps.group("RES") - 
res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() >= 2 - - M = math.prod( - val if dim == "m" else 1 - for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) - ) - N = math.prod( - val if dim == "n" else 1 - for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape) - ) - K0 = math.prod( - val if dim == "k" else 1 - for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) - ) - K1 = math.prod( - val if dim == "k" else 1 - for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape) - ) - assert K0 == K1 - - return ProblemSize( - MatmulSize(M, N, K0), - lhs_type=lhs_shaped_type, - rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.contraction, - ) - - assert False, "Shape not found" - +class ContractionTuner(DispatchTuner, ContractionParser): def get_transform_function_broadcast_rhs_mmt( self, problem_size: ProblemSize, @@ -801,57 +508,7 @@ def apply_params( ) -class BatchMmtTuner(DispatchTuner): - def supports(self, op_name: str) -> bool: - return "batch_matmul_transpose_b" in op_name - - def get_shapes(self, template: list[str]) -> ProblemSize: - for line in template: - if "linalg.generic" not in line: - continue - if ( - r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' - not in line - ): - continue - # ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) - bmmt_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(bmmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 3 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 3 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 3 - - B0, M0, K0 = lhs_shaped_type.shape - B1, N1, K1 = rhs_shaped_type.shape - B2, M2, N2 = res_shaped_type.shape - assert B0 == B1 - assert B0 == B2 - assert M0 == M2 - assert N1 == N2 - assert K0 == K1 - return ProblemSize( - MatmulSize(M0, N1, K0, B0), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.batch_mmt, - ) - - assert False, "Shape not found" - +class BatchMmtTuner(DispatchTuner, BatchMmtParser): def get_transform_function_batch_mmt( self, problem_size: ProblemSize, @@ -910,78 +567,7 @@ def apply_params( return MLIRTransformation(template, modified, embeddable) -class BatchMatmulTuner(DispatchTuner): - def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): - self.lhs_dims = lhs_dims - self.rhs_dims = rhs_dims - self.tile_dims = tile_dims - - def supports(self, op_name: str) -> bool: - return "batch_matmul" in op_name - - def get_shapes(self, template: list[str]) -> ProblemSize: - for line in template: - if "linalg.batch_matmul" not in line: - continue - # ins(%9, %10 : tensor<64x72x1280xf16>, tensor<64x1280x1280xf16>) - # outs(%12 : tensor<64x72x1280xf32>) - cont_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(cont_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == len(self.lhs_dims) - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == len(self.rhs_dims) - - res_tensor_type = dps.group("RES") - 
res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == lhs_shaped_type.rank() - - LHS = lhs_shaped_type.shape - RHS = rhs_shaped_type.shape - RES = res_shaped_type.shape - - B = math.prod( - val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, LHS) - ) - B0 = math.prod( - val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RHS) - ) - B1 = math.prod( - val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RES) - ) - M = math.prod( - val if dim == "m" else 1 for dim, val in zip(self.lhs_dims, LHS) - ) - N = math.prod( - val if dim == "n" else 1 for dim, val in zip(self.rhs_dims, RHS) - ) - K0 = math.prod( - val if dim == "k" else 1 for dim, val in zip(self.lhs_dims, LHS) - ) - K1 = math.prod( - val if dim == "k" else 1 for dim, val in zip(self.rhs_dims, RHS) - ) - assert B == B0 and B == B1 - assert K0 == K1 - - return ProblemSize( - MatmulSize(M, N, K0, B), - lhs_type=lhs_shaped_type, - rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.batch_matmul, - ) - - assert False, "Shape not found" - +class BatchMatmulTuner(DispatchTuner, BatchMatmulParser): def get_transform_function_batch_matmul( self, problem_size: ProblemSize, @@ -1053,6 +639,12 @@ def apply_params( return MLIRTransformation(template, modified, embeddable) +@dataclass +class OpWalkResult: + was_interrupted: bool = False + dispatch_tuner: Optional[DispatchTuner] = None + + def walk_callback_get_fn( op: ir.Operation, walk_result: OpWalkResult, @@ -1106,7 +698,8 @@ def tune( mlir_text = "".join(mlir_template) with ir.Context() as ctx: - mlir_module: ir.Module = parse_mlir(mlir_text, ctx) + tuner_context = TunerContext(ctx, tune_logger) + mlir_module: ir.Module = parse_mlir(mlir_text, tuner_context) # Save the input file as the first candidate. with open(path.join(output, f"0.mlir"), "w") as f: f.write(mlir_text) diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py index 5f7b9c848..63e8441d1 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -9,189 +9,48 @@ """ import pytest -from . 
import candidate_gen - -from iree.compiler import ir # type: ignore -from iree.compiler.dialects import func # type: ignore - - -def test_get_mmt_tile_sizes() -> None: - config = candidate_gen.Configuration( - subgroup_size=0, - workgroup_size=[], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), - tile_sizes=[128, 320, 32], - subgroup_m_count=0, - subgroup_n_count=0, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), - waves_per_eu=0, - ) - assert candidate_gen.get_mmt_tile_sizes(config) == [128, 320, 32] - - -def test_get_conv_tile_sizes() -> None: - config = candidate_gen.Configuration( - subgroup_size=64, - workgroup_size=[256, 1, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), - tile_sizes=[464, 320, 16], - subgroup_m_count=1, - subgroup_n_count=4, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), - waves_per_eu=1, - ) - assert candidate_gen.ConvTuner().get_conv_tile_sizes(config) == [ - 1, - 1, - 464, - 320, - 1, - 1, - 16, - ] - - -def test_get_contract_tile_sizes() -> None: - config = candidate_gen.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), - tile_sizes=[4, 8, 16], - subgroup_m_count=1, - subgroup_n_count=1, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), - waves_per_eu=2, - ) - assert candidate_gen.get_contract_tile_sizes(config, "mnk") == [4, 8, 16] - assert candidate_gen.get_contract_tile_sizes(config, "nmk") == [8, 4, 16] - assert candidate_gen.get_contract_tile_sizes(config, "knm") == [16, 8, 4] - assert candidate_gen.get_contract_tile_sizes(config, "kkk") == [ - 16, - 16, - 16, - ] - - -def test_get_shapes_mmt() -> None: - template = [ - r"%18 = tensor.empty() : tensor<2048x1280xf32>", - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"^bb0(%in: f16, %in_0: f16, %out: f32):", - ] - assert candidate_gen.MmtTuner().get_shapes(template) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.mmt, - ) - - -def test_get_shapes_conv() -> None: - template = [ - r"%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", - r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", - r"flow.dispatch.tensor.store %8, %2, offsets = [%workgroup_id_z, %workgroup_id_y, 0, %3], sizes = [1, 1, 32, 256], strides = [1, 1, 1, 1] : tensor<1x1x32x256xf32> -> !flow.dispatch.tensor>", - ] - assert candidate_gen.ConvTuner().get_shapes(template) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(32, 256, 11520), - 
candidate_gen.ShapedType([1, 3, 34, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([3, 3, 1280, 256], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1, 1, 32, 256], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.conv, - ) - -def test_get_shapes_contract() -> None: - template = [ - r"%18 = tensor.empty() : tensor<2048x1280xf32>", - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"^bb0(%in: f16, %in_0: f16, %out: f32):", - ] - assert candidate_gen.ContractionTuner("mk", "nk", "mnk").get_shapes( - template - ) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.contraction, - ) - - -def test_get_shapes_batch_matmul() -> None: - template = [ - "%10 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", - "%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", - "flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 32, 32], strides = [1, 1, 1] : tensor<1x32x32xf32> -> !flow.dispatch.tensor>", - ] - assert candidate_gen.BatchMatmulTuner("bmk", "bkn", "mnk").get_shapes( - template - ) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(32, 32, 1024, 1), - candidate_gen.ShapedType([1, 32, 1024], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([1, 1024, 32], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([1, 32, 32], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.batch_matmul, - ) - - -def test_get_shapes_batch_mmt() -> None: - template = [ - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"flow.dispatch.tensor.store %21, %10, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : tensor<2x4096x640xf16> -> !flow.dispatch.tensor>", - ] - assert candidate_gen.BatchMmtTuner().get_shapes( - template - ) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(4096, 640, 640, 2), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), - candidate_gen.DispatchKind.batch_mmt, - ) +from . import candidate_gen +from . 
import common def test_generate_solutions() -> None: - matmul_size = candidate_gen.MatmulSize(2048, 3840, 1280) - lhs_type = candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([3840, 1280], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([2048, 3840], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + matmul_size = common.MatmulSize(2048, 3840, 1280) + lhs_type = common.ShapedType([2048, 1280], common.ElementType.f16) + rhs_type = common.ShapedType([3840, 1280], common.ElementType.f16) + res_type = common.ShapedType([2048, 3840], common.ElementType.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) configs = candidate_gen.generate_solutions(problem_size, 4) assert configs is not None def test_calculate_shared_memory_usage_in_bytes() -> None: - matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + matmul_size = common.MatmulSize(1024, 1024, 1024) + lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + res_type = common.ShapedType([1024, 1024], common.ElementType.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) assert ( candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) == 147456 ) - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i8) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + lhs_type = common.ShapedType([1024, 1024], common.ElementType.i8) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) assert ( candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) == 81920 ) - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + rhs_type = common.ShapedType([1024, 1024], common.ElementType.i32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) assert ( candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 128, 64, 32) @@ -200,12 +59,12 @@ def test_calculate_shared_memory_usage_in_bytes() -> None: def test_generate_constraints_valid_input() -> None: - matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + matmul_size = common.MatmulSize(1024, 1024, 1024) + lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + rhs_type = 
common.ShapedType([1024, 1024], common.ElementType.f16) + res_type = common.ShapedType([1024, 1024], common.ElementType.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) # Define input parameters as z3 Ints m, n, k = ( @@ -246,12 +105,12 @@ def test_generate_constraints_valid_input() -> None: def test_generate_constraints_invalid_input() -> None: # Define input parameters that should lead to unsatisfiable constraints - matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + matmul_size = common.MatmulSize(1024, 1024, 1024) + lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + res_type = common.ShapedType([1024, 1024], common.ElementType.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) m, n, k = ( candidate_gen.z3.Int("m"), @@ -307,25 +166,23 @@ def test_apply_params_mmt() -> None: M, N, K = 2048, 1280, 1280 - config = candidate_gen.Configuration( + config = common.Configuration( subgroup_size=16, workgroup_size=[16, 16, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), tile_sizes=[8, 8, 8], subgroup_m_count=16, subgroup_n_count=16, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions( - prefetch_shared_memory=True - ), + gpu_pipeline_options=common.GpuPipelineOptions(prefetch_shared_memory=True), waves_per_eu=8, ) - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(M, N, K), - candidate_gen.ShapedType([M, K], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([N, K], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([M, N], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.mmt, + problem_size = common.ProblemSize( + common.MatmulSize(M, N, K), + common.ShapedType([M, K], common.ElementType.f16), + common.ShapedType([N, K], common.ElementType.f16), + common.ShapedType([M, N], common.ElementType.f32), + common.DispatchKind.mmt, ) tf_mlir = candidate_gen.MmtTuner().apply_params(problem_size, mlir_template, config) @@ -361,27 +218,25 @@ def test_apply_params_conv() -> None: n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640 - config = candidate_gen.Configuration( + config = common.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), tile_sizes=[464, 320, 16], subgroup_m_count=1, subgroup_n_count=4, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions( - reorder_workgroups_strategy=candidate_gen.ReorderWorkgroupsStrategy.TRANSPOSE + gpu_pipeline_options=common.GpuPipelineOptions( + reorder_workgroups_strategy=common.ReorderWorkgroupsStrategy.TRANSPOSE ), waves_per_eu=2, ) - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(oh * ow, oc, fh * fw * ic), - candidate_gen.ShapedType( - [n, oh + 2, ow + 2, oc], candidate_gen.ElementType.f16 - ), - candidate_gen.ShapedType([fh, fw, ic, oc], candidate_gen.ElementType.f16), - 
candidate_gen.ShapedType([n, oh, ow, oc], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.conv, + problem_size = common.ProblemSize( + common.MatmulSize(oh * ow, oc, fh * fw * ic), + common.ShapedType([n, oh + 2, ow + 2, oc], common.ElementType.f16), + common.ShapedType([fh, fw, ic, oc], common.ElementType.f16), + common.ShapedType([n, oh, ow, oc], common.ElementType.f32), + common.DispatchKind.conv, ) tf_mlir = candidate_gen.ConvTuner().apply_params( problem_size, mlir_template, config @@ -419,22 +274,22 @@ def test_apply_params_contract() -> None: ] tile_dims = "*mnk" - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 3840, 1280), - candidate_gen.ShapedType([2, 1024, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([3, 20, 64, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([3, 2, 20, 1024, 64], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.contraction, + problem_size = common.ProblemSize( + common.MatmulSize(2048, 3840, 1280), + common.ShapedType([2, 1024, 1280], common.ElementType.f16), + common.ShapedType([3, 20, 64, 1280], common.ElementType.f16), + common.ShapedType([3, 2, 20, 1024, 64], common.ElementType.f32), + common.DispatchKind.contraction, ) - config = candidate_gen.Configuration( + config = common.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), tile_sizes=[480, 384, 32], subgroup_m_count=1, subgroup_n_count=4, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), + gpu_pipeline_options=common.GpuPipelineOptions(), waves_per_eu=2, ) @@ -466,22 +321,22 @@ def test_apply_params_batch_matmul() -> None: ] tile_dims = "bmnk" - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(968, 320, 640, 64), - candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([64, 968, 320], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.batch_matmul, + problem_size = common.ProblemSize( + common.MatmulSize(968, 320, 640, 64), + common.ShapedType([64, 968, 640], common.ElementType.f16), + common.ShapedType([64, 640, 320], common.ElementType.f16), + common.ShapedType([64, 968, 320], common.ElementType.f32), + common.DispatchKind.batch_matmul, ) - config = candidate_gen.Configuration( + config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), tile_sizes=[416, 320, 128], subgroup_m_count=2, subgroup_n_count=2, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), + gpu_pipeline_options=common.GpuPipelineOptions(), waves_per_eu=2, ) @@ -516,22 +371,22 @@ def test_apply_params_batch_mmt_float() -> None: '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(4096, 640, 640, 2), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.batch_mmt, + problem_size = common.ProblemSize( + common.MatmulSize(4096, 640, 640, 2), + common.ShapedType([2, 4096, 640], common.ElementType.f16), + common.ShapedType([2, 640, 640], 
common.ElementType.f16), + common.ShapedType([2, 4096, 640], common.ElementType.f32), + common.DispatchKind.batch_mmt, ) - config = candidate_gen.Configuration( + config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), tile_sizes=[128, 64, 128], subgroup_m_count=2, subgroup_n_count=2, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), + gpu_pipeline_options=common.GpuPipelineOptions(), waves_per_eu=2, ) @@ -564,22 +419,22 @@ def test_apply_params_batch_mmt_int() -> None: '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(4096, 640, 640, 2), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), - candidate_gen.DispatchKind.batch_mmt, + problem_size = common.ProblemSize( + common.MatmulSize(4096, 640, 640, 2), + common.ShapedType([2, 4096, 640], common.ElementType.i8), + common.ShapedType([2, 640, 640], common.ElementType.i8), + common.ShapedType([2, 4096, 640], common.ElementType.i32), + common.DispatchKind.batch_mmt, ) - config = candidate_gen.Configuration( + config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_i32_32x32x16_i8(), + intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(), tile_sizes=[128, 64, 128], subgroup_m_count=2, subgroup_n_count=2, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), + gpu_pipeline_options=common.GpuPipelineOptions(), waves_per_eu=4, ) @@ -635,22 +490,22 @@ def test_apply_params_broadcast_rhs_mmt() -> None: '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(4096, 640, 640, 2), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([640, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), - candidate_gen.DispatchKind.broadcast_rhs_mmt, + problem_size = common.ProblemSize( + common.MatmulSize(4096, 640, 640, 2), + common.ShapedType([2, 4096, 640], common.ElementType.i8), + common.ShapedType([640, 640], common.ElementType.i8), + common.ShapedType([2, 4096, 640], common.ElementType.i32), + common.DispatchKind.broadcast_rhs_mmt, ) - config = candidate_gen.Configuration( + config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_i32_32x32x16_i8(), + intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(), tile_sizes=[128, 64, 128], subgroup_m_count=2, subgroup_n_count=2, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), + gpu_pipeline_options=common.GpuPipelineOptions(), waves_per_eu=4, ) @@ -711,19 +566,3 @@ def test_detect_broadcast_rhs_mmt() -> None: assert candidate_gen.ContractionTuner("mk", "nk", "mnk").is_broadcast_rhs_mmt( mlir_lines ) - - -def test_parse_mlir() -> None: - with ir.Context() as ctx: - mlir_str = r""" - builtin.module { - func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - %0 = arith.mulf %arg0, %arg1 : tensor<4xf32> - return %0 : tensor<4xf32> - } - } - """ - mlir_module = candidate_gen.parse_mlir(mlir_str, ctx) - assert mlir_module is not None - assert isinstance(mlir_module, ir.Module) - 
assert isinstance(mlir_module.body.operations[0], func.FuncOp) diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index 1fbeb9910..7b295cdb0 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -5,10 +5,19 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import re +import logging from dataclasses import astuple, dataclass from enum import Enum from typing import Optional +from iree.compiler import ir # type: ignore + + +class TunerContext: + def __init__(self, mlir_ctx: ir.Context, logger: logging.Logger): + self.mlir_ctx = mlir_ctx + self.logger = logger + class DispatchKind(Enum): conv = 1 diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py new file mode 100644 index 000000000..670f8c3f7 --- /dev/null +++ b/tuner/tuner/dispatch_parser.py @@ -0,0 +1,435 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Given an input dispatch, this code parses the dispatch operation and +# extracts its problem size and dispatch kind. + +import math +import re +from abc import ABCMeta, abstractmethod + +from .common import * + + +def get_mmt_tile_sizes(configuration: Configuration): + return configuration.tile_sizes + + +def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> list[int]: + m, n, k = configuration.tile_sizes + tile_size = [1] * len(tile_dims) + for idx, dim in enumerate(tile_dims): + if dim == "m": + tile_size[idx] = m + if dim == "n": + tile_size[idx] = n + if dim == "k": + tile_size[idx] = k + return tile_size + + +def get_batch_mmt_tile_sizes(configuration: Configuration) -> list[int]: + return [1] + configuration.tile_sizes + + +def parse_mlir(mlir_text: str, ctx: TunerContext) -> ir.Module: + mlir_module = None + try: + mlir_module = ir.Module.parse(mlir_text, ctx.mlir_ctx) + ctx.logger.info("MLIR parsing successful!") + except ir.MLIRError as e: + ctx.logger.error(f"Error parsing MLIR: {e}") + raise RuntimeError(f"Error parsing MLIR: {e}") + + return mlir_module + + +class DispatchParser(metaclass=ABCMeta): + @abstractmethod + def supports(self, op_name: str) -> bool: + """Check if the tuner can handle the type of operation represented by the input string.""" + pass + + @abstractmethod + def get_shapes(self, template: list[str]) -> ProblemSize: + """Extract problem size of the operation.""" + pass + + +class MmtParser(DispatchParser): + def supports(self, op_name: str) -> bool: + return "matmul_transpose_b" in op_name + + def get_shapes(self, template: list[str]) -> ProblemSize: + mmt_re = None + dps = None + for line in template: + if "linalg.generic" not in line: + continue + if r'iterator_types = ["parallel", "parallel", "reduction"]' not in line: + continue + # ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) + mmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + dps = re.search(mmt_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 2 + lhs_M, lhs_K = lhs_shaped_type.shape + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 2 + rhs_N, rhs_K = rhs_shaped_type.shape + + assert lhs_shaped_type.element_type == rhs_shaped_type.element_type + assert lhs_K == rhs_K + + res_tensor_type = 
dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 2 + res_M, res_N = res_shaped_type.shape + + assert lhs_M == res_M + assert rhs_N == res_N + + matmul_size = MatmulSize( + lhs_shaped_type.shape[0], + rhs_shaped_type.shape[0], + lhs_shaped_type.shape[1], + ) + return ProblemSize( + matmul_size, + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.mmt, + ) + assert mmt_re + assert False, f"'{mmt_re}' not found in given context" + + +class ConvParser(DispatchParser): + def supports(self, op_name: str) -> bool: + return "conv_2d_nhwc_hwcf" in op_name + + def get_conv_tile_sizes(self, configuration: Configuration) -> list[int]: + m, n, k = configuration.tile_sizes + batch = 1 + fh = 1 + fw = 1 + + oh = 1 + + oc = n + ow = m + ic = k + return [batch, oh, ow, oc, fh, fw, ic] + + def get_shapes(self, template: list[str]) -> ProblemSize: + for line in template: + if "linalg.conv_2d_nhwc_hwcf" not in line: + continue + + # ins(%19, %20 : tensor<2x34x34x1280xf16>, tensor<3x3x1280x1280xf16>) outs (%27 : tensor<2x32x32x1280xf32>) + conv_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(conv_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 4 + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 4 + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 4 + + # int64_t n = outputShape[0]; + # int64_t oh = outputShape[1]; + # int64_t ow = outputShape[2]; + # int64_t oc = outputShape[3]; + # int64_t fh = filterShape[0]; + # int64_t fw = filterShape[1]; + # int64_t ic = filterShape[2]; + dim_info = ConvDimInfo.from_rhs_res(rhs_shaped_type, res_shaped_type) + return ProblemSize( + MatmulSize( + M=dim_info.oh * dim_info.ow, + N=dim_info.oc, + K=dim_info.fh * dim_info.fw * dim_info.ic, + B=dim_info.n, + ), + lhs_shaped_type, + rhs_shaped_type, + res_shaped_type, + DispatchKind.conv, + ) + + assert False, "Shape not found" + + +class ContractionParser(DispatchParser): + def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): + self.lhs_dims = lhs_dims + self.rhs_dims = rhs_dims + self.tile_dims = tile_dims + + def supports(self, op_name: str) -> bool: + return "matmul_like" in op_name + + def is_broadcast_rhs_mmt_op(self, line: str) -> bool: + if "linalg.generic" not in line: + return False + if ( + r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' + not in line + ): + return False + if ( + r"indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>" + not in line + ): + return False + return True + + def is_broadcast_rhs_mmt(self, template: list[str]) -> bool: + return any(self.is_broadcast_rhs_mmt_op(line) for line in template) + + def get_shapes_broadcast_rhs_mmt(self, template: list[str]) -> ProblemSize: + for line in template: + if not self.is_broadcast_rhs_mmt_op(line): + continue + + # ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) + bmmt_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(bmmt_re, line) + if dps is None: + continue + + lhs_tensor_type = 
dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 3 + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 2 + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 3 + + B0, M0, K0 = lhs_shaped_type.shape + N1, K1 = rhs_shaped_type.shape + B2, M2, N2 = res_shaped_type.shape + assert B0 == B2 + assert M0 == M2 + assert N1 == N2 + assert K0 == K1 + return ProblemSize( + MatmulSize(M0, N1, K0, B0), + lhs_shaped_type, + rhs_shaped_type, + res_shaped_type, + DispatchKind.broadcast_rhs_mmt, + ) + + assert False, "Shape not found" + + def get_shapes(self, template: list[str]) -> ProblemSize: + if self.is_broadcast_rhs_mmt(template): + return self.get_shapes_broadcast_rhs_mmt(template) + + for line in template: + if "linalg.generic" not in line: + continue + if "lowering_config =" not in line: + continue + if '"reduction"' not in line: + continue + + # ins(%7, %8 : tensor<2x1024x1280xf16>, tensor<20x64x1280xf16>) + cont_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(cont_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == len(self.lhs_dims) + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == len(self.rhs_dims) + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() >= 2 + + M = math.prod( + val if dim == "m" else 1 + for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) + ) + N = math.prod( + val if dim == "n" else 1 + for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape) + ) + K0 = math.prod( + val if dim == "k" else 1 + for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) + ) + K1 = math.prod( + val if dim == "k" else 1 + for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape) + ) + assert K0 == K1 + + return ProblemSize( + MatmulSize(M, N, K0), + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.contraction, + ) + + assert False, "Shape not found" + + +class BatchMmtParser(DispatchParser): + def supports(self, op_name: str) -> bool: + return "batch_matmul_transpose_b" in op_name + + def get_shapes(self, template: list[str]) -> ProblemSize: + for line in template: + if "linalg.generic" not in line: + continue + if ( + r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' + not in line + ): + continue + # ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) + bmmt_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(bmmt_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 3 + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 3 + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 3 + + B0, M0, K0 = lhs_shaped_type.shape + B1, N1, K1 = rhs_shaped_type.shape + B2, M2, N2 = res_shaped_type.shape + assert B0 == B1 + assert B0 == B2 + assert M0 == M2 + 
assert N1 == N2 + assert K0 == K1 + return ProblemSize( + MatmulSize(M0, N1, K0, B0), + lhs_shaped_type, + rhs_shaped_type, + res_shaped_type, + DispatchKind.batch_mmt, + ) + + assert False, "Shape not found" + + +class BatchMatmulParser(DispatchParser): + def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): + self.lhs_dims = lhs_dims + self.rhs_dims = rhs_dims + self.tile_dims = tile_dims + + def supports(self, op_name: str) -> bool: + return "batch_matmul" in op_name + + def get_shapes(self, template: list[str]) -> ProblemSize: + for line in template: + if "linalg.batch_matmul" not in line: + continue + # ins(%9, %10 : tensor<64x72x1280xf16>, tensor<64x1280x1280xf16>) + # outs(%12 : tensor<64x72x1280xf32>) + cont_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(cont_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == len(self.lhs_dims) + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == len(self.rhs_dims) + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == lhs_shaped_type.rank() + + LHS = lhs_shaped_type.shape + RHS = rhs_shaped_type.shape + RES = res_shaped_type.shape + + B = math.prod( + val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, LHS) + ) + B0 = math.prod( + val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RHS) + ) + B1 = math.prod( + val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RES) + ) + M = math.prod( + val if dim == "m" else 1 for dim, val in zip(self.lhs_dims, LHS) + ) + N = math.prod( + val if dim == "n" else 1 for dim, val in zip(self.rhs_dims, RHS) + ) + K0 = math.prod( + val if dim == "k" else 1 for dim, val in zip(self.lhs_dims, LHS) + ) + K1 = math.prod( + val if dim == "k" else 1 for dim, val in zip(self.rhs_dims, RHS) + ) + assert B == B0 and B == B1 + assert K0 == K1 + + return ProblemSize( + MatmulSize(M, N, K0, B), + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.batch_matmul, + ) + + assert False, "Shape not found" diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py new file mode 100644 index 000000000..bcdee240c --- /dev/null +++ b/tuner/tuner/dispatch_parser_test.py @@ -0,0 +1,176 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Usage: python -m pytest dispatch_parser_test.py +""" + +import pytest + +from logging import Logger +from unittest.mock import MagicMock + +from iree.compiler import ir # type: ignore +from iree.compiler.dialects import func # type: ignore + +from . import common +from . 
import dispatch_parser + + +def test_get_mmt_tile_sizes() -> None: + config = dispatch_parser.Configuration( + subgroup_size=0, + workgroup_size=[], + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + tile_sizes=[128, 320, 32], + subgroup_m_count=0, + subgroup_n_count=0, + gpu_pipeline_options=common.GpuPipelineOptions(), + waves_per_eu=0, + ) + assert dispatch_parser.get_mmt_tile_sizes(config) == [128, 320, 32] + + +def test_get_conv_tile_sizes() -> None: + config = dispatch_parser.Configuration( + subgroup_size=64, + workgroup_size=[256, 1, 1], + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + tile_sizes=[464, 320, 16], + subgroup_m_count=1, + subgroup_n_count=4, + gpu_pipeline_options=common.GpuPipelineOptions(), + waves_per_eu=1, + ) + assert dispatch_parser.ConvParser().get_conv_tile_sizes(config) == [ + 1, + 1, + 464, + 320, + 1, + 1, + 16, + ] + + +def test_get_contract_tile_sizes() -> None: + config = dispatch_parser.Configuration( + subgroup_size=32, + workgroup_size=[16, 16, 1], + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + tile_sizes=[4, 8, 16], + subgroup_m_count=1, + subgroup_n_count=1, + gpu_pipeline_options=common.GpuPipelineOptions(), + waves_per_eu=2, + ) + assert dispatch_parser.get_contract_tile_sizes(config, "mnk") == [4, 8, 16] + assert dispatch_parser.get_contract_tile_sizes(config, "nmk") == [8, 4, 16] + assert dispatch_parser.get_contract_tile_sizes(config, "knm") == [16, 8, 4] + assert dispatch_parser.get_contract_tile_sizes(config, "kkk") == [ + 16, + 16, + 16, + ] + + +def test_get_shapes_mmt() -> None: + template = [ + r"%18 = tensor.empty() : tensor<2048x1280xf32>", + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", + r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + r"^bb0(%in: f16, %in_0: f16, %out: f32):", + ] + assert dispatch_parser.MmtParser().get_shapes(template) == common.ProblemSize( + common.MatmulSize(2048, 1280, 1280), + common.ShapedType([2048, 1280], common.ElementType.f16), + common.ShapedType([1280, 1280], common.ElementType.f16), + common.ShapedType([2048, 1280], common.ElementType.f32), + dispatch_parser.DispatchKind.mmt, + ) + + +def test_get_shapes_conv() -> None: + template = [ + r"%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", + r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", + r"flow.dispatch.tensor.store %8, %2, offsets = [%workgroup_id_z, %workgroup_id_y, 0, %3], sizes = [1, 1, 32, 256], strides = [1, 1, 1, 1] : tensor<1x1x32x256xf32> -> !flow.dispatch.tensor>", + ] + assert dispatch_parser.ConvParser().get_shapes(template) == common.ProblemSize( + common.MatmulSize(32, 256, 11520), + common.ShapedType([1, 3, 34, 1280], common.ElementType.f16), + common.ShapedType([3, 3, 1280, 256], common.ElementType.f16), + common.ShapedType([1, 1, 32, 256], common.ElementType.f32), 
+ dispatch_parser.DispatchKind.conv, + ) + + +def test_get_shapes_contract() -> None: + template = [ + r"%18 = tensor.empty() : tensor<2048x1280xf32>", + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", + r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + r"^bb0(%in: f16, %in_0: f16, %out: f32):", + ] + assert dispatch_parser.ContractionParser("mk", "nk", "mnk").get_shapes( + template + ) == common.ProblemSize( + common.MatmulSize(2048, 1280, 1280), + common.ShapedType([2048, 1280], common.ElementType.f16), + common.ShapedType([1280, 1280], common.ElementType.f16), + common.ShapedType([2048, 1280], common.ElementType.f32), + dispatch_parser.DispatchKind.contraction, + ) + + +def test_get_shapes_batch_matmul() -> None: + template = [ + "%10 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", + "%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", + "flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 32, 32], strides = [1, 1, 1] : tensor<1x32x32xf32> -> !flow.dispatch.tensor>", + ] + assert dispatch_parser.BatchMatmulParser("bmk", "bkn", "mnk").get_shapes( + template + ) == common.ProblemSize( + common.MatmulSize(32, 32, 1024, 1), + common.ShapedType([1, 32, 1024], common.ElementType.f32), + common.ShapedType([1, 1024, 32], common.ElementType.f32), + common.ShapedType([1, 32, 32], common.ElementType.f32), + dispatch_parser.DispatchKind.batch_matmul, + ) + + +def test_get_shapes_batch_mmt() -> None: + template = [ + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>", + r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + r"flow.dispatch.tensor.store %21, %10, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : tensor<2x4096x640xf16> -> !flow.dispatch.tensor>", + ] + assert dispatch_parser.BatchMmtParser().get_shapes(template) == common.ProblemSize( + common.MatmulSize(4096, 640, 640, 2), + common.ShapedType([2, 4096, 640], common.ElementType.i8), + common.ShapedType([2, 640, 640], common.ElementType.i8), + common.ShapedType([2, 4096, 640], common.ElementType.i32), + dispatch_parser.DispatchKind.batch_mmt, + ) + + +def test_parse_mlir() -> None: + with ir.Context() as ctx: + mlir_str = r""" + builtin.module { + func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + %0 = arith.mulf %arg0, %arg1 : tensor<4xf32> + return %0 : tensor<4xf32> + } + } + """ + logger: Logger = MagicMock(spec=Logger) + tuner_context = common.TunerContext(ctx, logger) + mlir_module = dispatch_parser.parse_mlir(mlir_str, tuner_context) + assert mlir_module is not None + 
assert isinstance(mlir_module, ir.Module) + assert isinstance(mlir_module.body.operations[0], func.FuncOp)
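
For reference, a minimal sketch of how the refactored pieces fit together (the top-level import path `tuner` and the `0.mlir` input file are assumptions for illustration; the calls mirror the APIs shown in this diff and its tests):

import logging

from iree.compiler import ir  # type: ignore

from tuner import common, dispatch_parser  # assumed path; the tests use relative imports

logger = logging.getLogger("tune")
with open("0.mlir") as f:  # e.g. the first candidate saved by candidate_gen
    mlir_text = f.read()
template = mlir_text.splitlines(keepends=True)

with ir.Context() as mlir_ctx:
    # TunerContext (new in common.py) bundles the MLIR context with a logger.
    tuner_ctx = common.TunerContext(mlir_ctx, logger)
    # parse_mlir now takes a TunerContext instead of a raw ir.Context.
    mlir_module = dispatch_parser.parse_mlir(mlir_text, tuner_ctx)

# Each DispatchParser matches ops by name and extracts a ProblemSize
# from the textual template.
parser = dispatch_parser.MmtParser()
assert parser.supports("matmul_transpose_b")
problem_size = parser.get_shapes(template)
print(problem_size)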