From 72163ac240689574395ccaf79b95a0550f712861 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Thu, 14 Nov 2024 21:13:35 -0500 Subject: [PATCH 1/2] [tuner] Move common utilities to their own file Signed-off-by: Jakub Kuderski --- tuner/tuner/candidate_gen.py | 250 +---------------------------- tuner/tuner/candidate_gen_test.py | 151 ------------------ tuner/tuner/common.py | 255 ++++++++++++++++++++++++++++++ tuner/tuner/common_test.py | 131 +++++++++++++++ 4 files changed, 389 insertions(+), 398 deletions(-) create mode 100644 tuner/tuner/common.py create mode 100644 tuner/tuner/common_test.py diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 96bfc7146..e47100059 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -24,8 +24,7 @@ import pickle import re import z3 # type: ignore -from dataclasses import astuple, dataclass -from enum import Enum +from dataclasses import dataclass from os import path, makedirs from typing import Optional from textwrap import indent @@ -33,214 +32,15 @@ from iree.compiler import ir # type: ignore -tune_logger = logging.getLogger("tune") - - -class DispatchKind(Enum): - conv = 1 - mmt = 2 - contraction = 3 - batch_mmt = 4 - batch_matmul = 5 - broadcast_rhs_mmt = 6 - - -class ElementType(Enum): - i8 = 1 - i32 = 2 - f8 = 3 - f16 = 4 - f32 = 5 - - @property - def bitwidth(self) -> int: - match self: - case ElementType.i8 | ElementType.f8: - return 8 - case ElementType.f16: - return 16 - case ElementType.i32 | ElementType.f32: - return 32 - case _: - assert False, "unhandled case" - - def __str__(self) -> str: - return self.name - - -@dataclass -class ShapedType: - shape: list[int] - element_type: ElementType - - def rank(self) -> int: - return len(self.shape) - - @property - def bitwidth(self) -> int: - return self.element_type.bitwidth - - def __str__(self) -> str: - dim_to_str = lambda dim: str(dim) if dim != -1 else "?" - return "x".join(map(dim_to_str, self.shape)) + "x" + str(self.element_type) - - -@dataclass -class MatmulSize: - M: int - N: int - K: int - B: int = 1 - - -@dataclass -class ProblemSize: - matmul_size: MatmulSize - lhs_type: ShapedType - rhs_type: ShapedType - res_type: ShapedType - dispatch_kind: DispatchKind - - @property - def MNK(self) -> tuple[int, int, int]: - return (self.matmul_size.M, self.matmul_size.N, self.matmul_size.K) - - -@dataclass -class MfmaIntrinsic: - output_type: ElementType - m: int - n: int - k: int - input_type: ElementType - - def __str__(self) -> str: - input = str(self.input_type).upper() - output = str(self.output_type).upper() - return f"MFMA_{output}_{self.m}x{self.n}x{self.k}_{input}" - - @staticmethod - def mfma_f32_16x16x16_f16(): - return MfmaIntrinsic(ElementType.f32, 16, 16, 16, ElementType.f16) - - @staticmethod - def mfma_f32_32x32x8_f16(): - return MfmaIntrinsic(ElementType.f32, 32, 32, 8, ElementType.f16) - - @staticmethod - def mfma_i32_16x16x32_i8(): - return MfmaIntrinsic(ElementType.i32, 16, 16, 32, ElementType.i8) - - @staticmethod - def mfma_i32_32x32x16_i8(): - return MfmaIntrinsic(ElementType.i32, 32, 32, 16, ElementType.i8) - - @staticmethod - def all(): - return [ - MfmaIntrinsic.mfma_f32_16x16x16_f16(), - MfmaIntrinsic.mfma_f32_32x32x8_f16(), - MfmaIntrinsic.mfma_i32_16x16x32_i8(), - MfmaIntrinsic.mfma_i32_32x32x16_i8(), - ] - - -class ReorderWorkgroupsStrategy(Enum): - NONE = 0 - SWIZZLE = 1 - TRANSPOSE = 2 - - def __str__(self) -> str: - return self.name.title() - - -@dataclass -class GpuPipelineOptions: - """Represents the `iree_gpu.pipeline_options` attribute""" +from .common import * - prefetch_shared_memory: Optional[bool] = None - no_reduce_shared_memory_bank_conflicts: Optional[bool] = None - reorder_workgroups_strategy: Optional[ReorderWorkgroupsStrategy] = None - - def all_default(self) -> bool: - return all(x is None for x in astuple(self)) - - def __str__(self) -> str: - options: list[str] = [] - if self.prefetch_shared_memory is not None: - options.append( - f"prefetch_shared_memory = {str(self.prefetch_shared_memory).lower()}" - ) - if self.no_reduce_shared_memory_bank_conflicts is not None: - options.append( - f"no_reduce_shared_memory_bank_conflicts = {str(self.no_reduce_shared_memory_bank_conflicts).lower()}" - ) - if self.reorder_workgroups_strategy is not None: - options.append( - f"reorder_workgroups_strategy = {self.reorder_workgroups_strategy}" - ) - - return f"#iree_gpu.pipeline_options<{', '.join(options)}>" - - -@dataclass -class Configuration: - subgroup_size: int - workgroup_size: list[int] - intrinsic: MfmaIntrinsic - tile_sizes: list[int] - subgroup_m_count: int - subgroup_n_count: int - gpu_pipeline_options: GpuPipelineOptions - waves_per_eu: int - - -class MlirRegex(Enum): - ssa_value = r"%[a-zA-Z0-9-_]+" - tensor_type = r"tensor<(([0-9]+x)+((f|i)[0-9]+))>" - - def __str__(self) -> str: - return self.value - - @staticmethod - def dps_ins_two_args() -> str: - return rf"ins\({MlirRegex.ssa_value}, {MlirRegex.ssa_value} : (?P{MlirRegex.tensor_type}), (?P{MlirRegex.tensor_type})\)" - - @staticmethod - def dps_outs_one_arg() -> str: - return rf"outs\({MlirRegex.ssa_value} : (?P{MlirRegex.tensor_type})\)" - - -def read_input_mlir(filename: str) -> list[str]: - with open(filename, "r") as f: - return f.readlines() +tune_logger = logging.getLogger("tune") def get_mmt_tile_sizes(configuration: Configuration): return configuration.tile_sizes -@dataclass -class ConvDimInfo: - n: int - oh: int - ow: int - oc: int - fh: int - fw: int - ic: int - - @staticmethod - def from_rhs_res(rhs_shaped_type: ShapedType, res_shaped_type: ShapedType): - n, oh, ow, oc = res_shaped_type.shape - fh, fw, ic, _ = rhs_shaped_type.shape - return ConvDimInfo(n, oh, ow, oc, fh, fw, ic) - - @staticmethod - def from_problem_size(problem_size: ProblemSize): - return ConvDimInfo.from_rhs_res(problem_size.rhs_type, problem_size.res_type) - - def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> list[int]: m, n, k = configuration.tile_sizes tile_size = [1] * len(tile_dims) @@ -258,15 +58,6 @@ def get_batch_mmt_tile_sizes(configuration: Configuration) -> list[int]: return [1] + configuration.tile_sizes -def get_pipeline_config(configuration: Configuration) -> str: - extra_config = "" - if not configuration.gpu_pipeline_options.all_default(): - extra_config += f", gpu_pipeline_options = {configuration.gpu_pipeline_options}" - if configuration.waves_per_eu != 2: - extra_config += f', llvm_func_attrs = {{"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"}}' - return extra_config - - def apply_configuration( template: list[str], configuration: Configuration, tile_sizes: list[int] ) -> str: @@ -303,32 +94,6 @@ def apply_configuration( return new_mlir -def parse_tensor_type(tensor_type: str) -> ShapedType: - shape_match = re.search(str(MlirRegex.tensor_type), tensor_type) - assert shape_match - - shape_str = shape_match.group(1) - dims_and_elem = shape_str.split("x") - dims = [int(x) for x in dims_and_elem[:-1]] - elem = dims_and_elem[-1] - str_to_elem_ty = {x.name: x for x in ElementType} - return ShapedType(dims, str_to_elem_ty[elem]) - - -def get_compatible_mfma_intrinsics(problem_size: ProblemSize) -> list[MfmaIntrinsic]: - def is_compatible(intrinsic: MfmaIntrinsic) -> bool: - if problem_size.res_type.element_type != intrinsic.output_type: - return False - if problem_size.dispatch_kind != DispatchKind.batch_matmul: - if problem_size.lhs_type.element_type != intrinsic.input_type: - return False - if problem_size.rhs_type.element_type != intrinsic.input_type: - return False - return True - - return list(filter(is_compatible, MfmaIntrinsic.all())) - - def get_mfma_intrinsic_constraints( problem_size: ProblemSize, intrinsic_m: z3.ArithRef, @@ -529,15 +294,6 @@ def parse_mlir(mlir_text: str, ctx: ir.Context) -> ir.Module: return mlir_module -@dataclass -class MLIRTransformation: - """Transformation of MLIR context""" - - template: list[str] - modified: str - embeddable: str - - class DispatchTuner(ABC): @abstractmethod def supports(self, op_name: str) -> bool: diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py index a1a3a3e49..5f7b9c848 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -15,53 +15,6 @@ from iree.compiler.dialects import func # type: ignore -def test_get_shaped_type_element_bitwidth() -> None: - assert ( - candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8).bitwidth - == 8 - ) - assert ( - candidate_gen.ShapedType([2048], candidate_gen.ElementType.i32).bitwidth == 32 - ) - assert ( - candidate_gen.ShapedType( - [2048, 512, 384], candidate_gen.ElementType.f8 - ).bitwidth - == 8 - ) - assert ( - candidate_gen.ShapedType([1, 1], candidate_gen.ElementType.f16).bitwidth == 16 - ) - - -def test_get_shaped_type_to_str() -> None: - assert ( - str(candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8)) - == "1024x2048xi8" - ) - assert ( - str(candidate_gen.ShapedType([1024], candidate_gen.ElementType.f32)) - == "1024xf32" - ) - assert ( - str(candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f16)) - == "1x2x3xf16" - ) - assert ( - str(candidate_gen.ShapedType([-1, 2, 3], candidate_gen.ElementType.f16)) - == "?x2x3xf16" - ) - - -def test_parse_tensor_type() -> None: - assert candidate_gen.parse_tensor_type( - "tensor<1x2x3xf32>" - ) == candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f32) - assert candidate_gen.parse_tensor_type( - "tensor<123xi8>" - ) == candidate_gen.ShapedType([123], candidate_gen.ElementType.i8) - - def test_get_mmt_tile_sizes() -> None: config = candidate_gen.Configuration( subgroup_size=0, @@ -98,32 +51,6 @@ def test_get_conv_tile_sizes() -> None: ] -def test_gpu_pipeline_options() -> None: - options = candidate_gen.GpuPipelineOptions() - assert options.all_default() - assert str(options) == "#iree_gpu.pipeline_options<>" - - options.prefetch_shared_memory = True - assert not options.all_default() - assert str(options) == "#iree_gpu.pipeline_options" - - options.no_reduce_shared_memory_bank_conflicts = False - assert ( - str(options) - == "#iree_gpu.pipeline_options" - ) - - options = candidate_gen.GpuPipelineOptions() - options.reorder_workgroups_strategy = ( - candidate_gen.ReorderWorkgroupsStrategy.TRANSPOSE - ) - assert not options.all_default() - assert ( - str(options) - == "#iree_gpu.pipeline_options" - ) - - def test_get_contract_tile_sizes() -> None: config = candidate_gen.Configuration( subgroup_size=32, @@ -145,32 +72,6 @@ def test_get_contract_tile_sizes() -> None: ] -def test_get_pipeline_config() -> None: - config = candidate_gen.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), - tile_sizes=[4, 8, 16], - subgroup_m_count=1, - subgroup_n_count=1, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), - waves_per_eu=2, - ) - config1_str: str = candidate_gen.get_pipeline_config(config) - assert config1_str == "" - - config.waves_per_eu = 4 - config2_str: str = candidate_gen.get_pipeline_config(config) - assert config2_str == ', llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' - - config.gpu_pipeline_options.prefetch_shared_memory = True - config3_str = candidate_gen.get_pipeline_config(config) - assert ( - config3_str - == ', gpu_pipeline_options = #iree_gpu.pipeline_options, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' - ) - - def test_get_shapes_mmt() -> None: template = [ r"%18 = tensor.empty() : tensor<2048x1280xf32>", @@ -254,58 +155,6 @@ def test_get_shapes_batch_mmt() -> None: ) -def test_mfma_intrinsic_to_str() -> None: - assert ( - str(candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16()) - == "MFMA_F32_16x16x16_F16" - ) - assert ( - str(candidate_gen.MfmaIntrinsic.mfma_i32_32x32x16_i8()) - == "MFMA_I32_32x32x16_I8" - ) - - -def test_get_compatible_mfma_intrinsics() -> None: - assert candidate_gen.get_compatible_mfma_intrinsics( - candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.mmt, - ) - ) == [ - candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), - candidate_gen.MfmaIntrinsic.mfma_f32_32x32x8_f16(), - ] - - assert candidate_gen.get_compatible_mfma_intrinsics( - candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i32), - candidate_gen.DispatchKind.mmt, - ) - ) == [ - candidate_gen.MfmaIntrinsic.mfma_i32_16x16x32_i8(), - candidate_gen.MfmaIntrinsic.mfma_i32_32x32x16_i8(), - ] - - assert candidate_gen.get_compatible_mfma_intrinsics( - candidate_gen.ProblemSize( - candidate_gen.MatmulSize(968, 320, 640, 64), - candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([64, 968, 320], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.batch_matmul, - ) - ) == [ - candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), - candidate_gen.MfmaIntrinsic.mfma_f32_32x32x8_f16(), - ] - - def test_generate_solutions() -> None: matmul_size = candidate_gen.MatmulSize(2048, 3840, 1280) lhs_type = candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16) diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py new file mode 100644 index 000000000..1fbeb9910 --- /dev/null +++ b/tuner/tuner/common.py @@ -0,0 +1,255 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import re +from dataclasses import astuple, dataclass +from enum import Enum +from typing import Optional + + +class DispatchKind(Enum): + conv = 1 + mmt = 2 + contraction = 3 + batch_mmt = 4 + batch_matmul = 5 + broadcast_rhs_mmt = 6 + + +class ElementType(Enum): + i8 = 1 + i32 = 2 + f8 = 3 + f16 = 4 + f32 = 5 + + @property + def bitwidth(self) -> int: + match self: + case ElementType.i8 | ElementType.f8: + return 8 + case ElementType.f16: + return 16 + case ElementType.i32 | ElementType.f32: + return 32 + case _: + assert False, "unhandled case" + + def __str__(self) -> str: + return self.name + + +@dataclass +class ShapedType: + shape: list[int] + element_type: ElementType + + def rank(self) -> int: + return len(self.shape) + + @property + def bitwidth(self) -> int: + return self.element_type.bitwidth + + def __str__(self) -> str: + dim_to_str = lambda dim: str(dim) if dim != -1 else "?" + return "x".join(map(dim_to_str, self.shape)) + "x" + str(self.element_type) + + +@dataclass +class MatmulSize: + M: int + N: int + K: int + B: int = 1 + + +@dataclass +class ProblemSize: + matmul_size: MatmulSize + lhs_type: ShapedType + rhs_type: ShapedType + res_type: ShapedType + dispatch_kind: DispatchKind + + @property + def MNK(self) -> tuple[int, int, int]: + return (self.matmul_size.M, self.matmul_size.N, self.matmul_size.K) + + +@dataclass +class MfmaIntrinsic: + output_type: ElementType + m: int + n: int + k: int + input_type: ElementType + + def __str__(self) -> str: + input = str(self.input_type).upper() + output = str(self.output_type).upper() + return f"MFMA_{output}_{self.m}x{self.n}x{self.k}_{input}" + + @staticmethod + def mfma_f32_16x16x16_f16(): + return MfmaIntrinsic(ElementType.f32, 16, 16, 16, ElementType.f16) + + @staticmethod + def mfma_f32_32x32x8_f16(): + return MfmaIntrinsic(ElementType.f32, 32, 32, 8, ElementType.f16) + + @staticmethod + def mfma_i32_16x16x32_i8(): + return MfmaIntrinsic(ElementType.i32, 16, 16, 32, ElementType.i8) + + @staticmethod + def mfma_i32_32x32x16_i8(): + return MfmaIntrinsic(ElementType.i32, 32, 32, 16, ElementType.i8) + + @staticmethod + def all(): + return [ + MfmaIntrinsic.mfma_f32_16x16x16_f16(), + MfmaIntrinsic.mfma_f32_32x32x8_f16(), + MfmaIntrinsic.mfma_i32_16x16x32_i8(), + MfmaIntrinsic.mfma_i32_32x32x16_i8(), + ] + + +def get_compatible_mfma_intrinsics(problem_size: ProblemSize) -> list[MfmaIntrinsic]: + def is_compatible(intrinsic: MfmaIntrinsic) -> bool: + if problem_size.res_type.element_type != intrinsic.output_type: + return False + if problem_size.dispatch_kind != DispatchKind.batch_matmul: + if problem_size.lhs_type.element_type != intrinsic.input_type: + return False + if problem_size.rhs_type.element_type != intrinsic.input_type: + return False + return True + + return list(filter(is_compatible, MfmaIntrinsic.all())) + + +class ReorderWorkgroupsStrategy(Enum): + NONE = 0 + SWIZZLE = 1 + TRANSPOSE = 2 + + def __str__(self) -> str: + return self.name.title() + + +@dataclass +class GpuPipelineOptions: + """Represents the `iree_gpu.pipeline_options` attribute""" + + prefetch_shared_memory: Optional[bool] = None + no_reduce_shared_memory_bank_conflicts: Optional[bool] = None + reorder_workgroups_strategy: Optional[ReorderWorkgroupsStrategy] = None + + def all_default(self) -> bool: + return all(x is None for x in astuple(self)) + + def __str__(self) -> str: + options: list[str] = [] + if self.prefetch_shared_memory is not None: + options.append( + f"prefetch_shared_memory = {str(self.prefetch_shared_memory).lower()}" + ) + if self.no_reduce_shared_memory_bank_conflicts is not None: + options.append( + f"no_reduce_shared_memory_bank_conflicts = {str(self.no_reduce_shared_memory_bank_conflicts).lower()}" + ) + if self.reorder_workgroups_strategy is not None: + options.append( + f"reorder_workgroups_strategy = {self.reorder_workgroups_strategy}" + ) + + return f"#iree_gpu.pipeline_options<{', '.join(options)}>" + + +@dataclass +class Configuration: + subgroup_size: int + workgroup_size: list[int] + intrinsic: MfmaIntrinsic + tile_sizes: list[int] + subgroup_m_count: int + subgroup_n_count: int + gpu_pipeline_options: GpuPipelineOptions + waves_per_eu: int + + +def get_pipeline_config(configuration: Configuration) -> str: + extra_config = "" + if not configuration.gpu_pipeline_options.all_default(): + extra_config += f", gpu_pipeline_options = {configuration.gpu_pipeline_options}" + if configuration.waves_per_eu != 2: + extra_config += f', llvm_func_attrs = {{"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"}}' + return extra_config + + +class MlirRegex(Enum): + ssa_value = r"%[a-zA-Z0-9-_]+" + tensor_type = r"tensor<(([0-9]+x)+((f|i)[0-9]+))>" + + def __str__(self) -> str: + return self.value + + @staticmethod + def dps_ins_two_args() -> str: + return rf"ins\({MlirRegex.ssa_value}, {MlirRegex.ssa_value} : (?P{MlirRegex.tensor_type}), (?P{MlirRegex.tensor_type})\)" + + @staticmethod + def dps_outs_one_arg() -> str: + return rf"outs\({MlirRegex.ssa_value} : (?P{MlirRegex.tensor_type})\)" + + +def read_input_mlir(filename: str) -> list[str]: + with open(filename, "r") as f: + return f.readlines() + + +@dataclass +class ConvDimInfo: + n: int + oh: int + ow: int + oc: int + fh: int + fw: int + ic: int + + @staticmethod + def from_rhs_res(rhs_shaped_type: ShapedType, res_shaped_type: ShapedType): + n, oh, ow, oc = res_shaped_type.shape + fh, fw, ic, _ = rhs_shaped_type.shape + return ConvDimInfo(n, oh, ow, oc, fh, fw, ic) + + @staticmethod + def from_problem_size(problem_size: ProblemSize): + return ConvDimInfo.from_rhs_res(problem_size.rhs_type, problem_size.res_type) + + +def parse_tensor_type(tensor_type: str) -> ShapedType: + shape_match = re.search(str(MlirRegex.tensor_type), tensor_type) + assert shape_match + + shape_str = shape_match.group(1) + dims_and_elem = shape_str.split("x") + dims = [int(x) for x in dims_and_elem[:-1]] + elem = dims_and_elem[-1] + str_to_elem_ty = {x.name: x for x in ElementType} + return ShapedType(dims, str_to_elem_ty[elem]) + + +@dataclass +class MLIRTransformation: + """Transformation of MLIR context""" + + template: list[str] + modified: str + embeddable: str diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py new file mode 100644 index 000000000..858d593c9 --- /dev/null +++ b/tuner/tuner/common_test.py @@ -0,0 +1,131 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Usage: python -m pytest candidate_gen_test.py +""" + +import pytest +from . import common + + +def test_get_shaped_type_element_bitwidth() -> None: + assert common.ShapedType([1024, 2048], common.ElementType.i8).bitwidth == 8 + assert common.ShapedType([2048], common.ElementType.i32).bitwidth == 32 + assert common.ShapedType([2048, 512, 384], common.ElementType.f8).bitwidth == 8 + assert common.ShapedType([1, 1], common.ElementType.f16).bitwidth == 16 + + +def test_get_shaped_type_to_str() -> None: + assert str(common.ShapedType([1024, 2048], common.ElementType.i8)) == "1024x2048xi8" + assert str(common.ShapedType([1024], common.ElementType.f32)) == "1024xf32" + assert str(common.ShapedType([1, 2, 3], common.ElementType.f16)) == "1x2x3xf16" + assert str(common.ShapedType([-1, 2, 3], common.ElementType.f16)) == "?x2x3xf16" + + +def test_parse_tensor_type() -> None: + assert common.parse_tensor_type("tensor<1x2x3xf32>") == common.ShapedType( + [1, 2, 3], common.ElementType.f32 + ) + assert common.parse_tensor_type("tensor<123xi8>") == common.ShapedType( + [123], common.ElementType.i8 + ) + + +def test_gpu_pipeline_options() -> None: + options = common.GpuPipelineOptions() + assert options.all_default() + assert str(options) == "#iree_gpu.pipeline_options<>" + + options.prefetch_shared_memory = True + assert not options.all_default() + assert str(options) == "#iree_gpu.pipeline_options" + + options.no_reduce_shared_memory_bank_conflicts = False + assert ( + str(options) + == "#iree_gpu.pipeline_options" + ) + + options = common.GpuPipelineOptions() + options.reorder_workgroups_strategy = common.ReorderWorkgroupsStrategy.TRANSPOSE + assert not options.all_default() + assert ( + str(options) + == "#iree_gpu.pipeline_options" + ) + + +def test_get_pipeline_config() -> None: + config = common.Configuration( + subgroup_size=32, + workgroup_size=[16, 16, 1], + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + tile_sizes=[4, 8, 16], + subgroup_m_count=1, + subgroup_n_count=1, + gpu_pipeline_options=common.GpuPipelineOptions(), + waves_per_eu=2, + ) + config1_str: str = common.get_pipeline_config(config) + assert config1_str == "" + + config.waves_per_eu = 4 + config2_str: str = common.get_pipeline_config(config) + assert config2_str == ', llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' + + config.gpu_pipeline_options.prefetch_shared_memory = True + config3_str = common.get_pipeline_config(config) + assert ( + config3_str + == ', gpu_pipeline_options = #iree_gpu.pipeline_options, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' + ) + + +def test_mfma_intrinsic_to_str() -> None: + assert str(common.MfmaIntrinsic.mfma_f32_16x16x16_f16()) == "MFMA_F32_16x16x16_F16" + assert str(common.MfmaIntrinsic.mfma_i32_32x32x16_i8()) == "MFMA_I32_32x32x16_I8" + + +def test_get_compatible_mfma_intrinsics() -> None: + assert common.get_compatible_mfma_intrinsics( + common.ProblemSize( + common.MatmulSize(2048, 1280, 1280), + common.ShapedType([2048, 1280], common.ElementType.f16), + common.ShapedType([1280, 1280], common.ElementType.f16), + common.ShapedType([2048, 1280], common.ElementType.f32), + common.DispatchKind.mmt, + ) + ) == [ + common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + ] + + assert common.get_compatible_mfma_intrinsics( + common.ProblemSize( + common.MatmulSize(2048, 1280, 1280), + common.ShapedType([2048, 1280], common.ElementType.i8), + common.ShapedType([1280, 1280], common.ElementType.i8), + common.ShapedType([2048, 1280], common.ElementType.i32), + common.DispatchKind.mmt, + ) + ) == [ + common.MfmaIntrinsic.mfma_i32_16x16x32_i8(), + common.MfmaIntrinsic.mfma_i32_32x32x16_i8(), + ] + + assert common.get_compatible_mfma_intrinsics( + common.ProblemSize( + common.MatmulSize(968, 320, 640, 64), + common.ShapedType([64, 968, 640], common.ElementType.f32), + common.ShapedType([64, 640, 320], common.ElementType.f32), + common.ShapedType([64, 968, 320], common.ElementType.f32), + common.DispatchKind.batch_matmul, + ) + ) == [ + common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + ] From db718e79cdb087d849dd0906a7e689a7a1e2bac7 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Thu, 14 Nov 2024 22:31:02 -0500 Subject: [PATCH 2/2] [tuner] Move parsing out of candidate_gen Signed-off-by: Jakub Kuderski --- tuner/tuner/candidate_gen.py | 441 ++-------------------------- tuner/tuner/candidate_gen_test.py | 353 ++++++---------------- tuner/tuner/common.py | 9 + tuner/tuner/dispatch_parser.py | 435 +++++++++++++++++++++++++++ tuner/tuner/dispatch_parser_test.py | 176 +++++++++++ 5 files changed, 733 insertions(+), 681 deletions(-) create mode 100644 tuner/tuner/dispatch_parser.py create mode 100644 tuner/tuner/dispatch_parser_test.py diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index e47100059..06ccae0e3 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -20,7 +20,6 @@ import argparse import logging -import math import pickle import re import z3 # type: ignore @@ -28,36 +27,16 @@ from os import path, makedirs from typing import Optional from textwrap import indent -from abc import ABC, abstractmethod +from abc import abstractmethod from iree.compiler import ir # type: ignore from .common import * +from .dispatch_parser import * tune_logger = logging.getLogger("tune") -def get_mmt_tile_sizes(configuration: Configuration): - return configuration.tile_sizes - - -def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> list[int]: - m, n, k = configuration.tile_sizes - tile_size = [1] * len(tile_dims) - for idx, dim in enumerate(tile_dims): - if dim == "m": - tile_size[idx] = m - if dim == "n": - tile_size[idx] = n - if dim == "k": - tile_size[idx] = k - return tile_size - - -def get_batch_mmt_tile_sizes(configuration: Configuration) -> list[int]: - return [1] + configuration.tile_sizes - - def apply_configuration( template: list[str], configuration: Configuration, tile_sizes: list[int] ) -> str: @@ -282,29 +261,8 @@ def get_default_output_dir() -> str: return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M") -def parse_mlir(mlir_text: str, ctx: ir.Context) -> ir.Module: - mlir_module = None - try: - mlir_module = ir.Module.parse(mlir_text) - tune_logger.info("MLIR parsing successful!") - except ir.MLIRError as e: - tune_logger.error(f"Error parsing MLIR: {e}") - raise RuntimeError(f"Error parsing MLIR: {e}") - - return mlir_module - - -class DispatchTuner(ABC): - @abstractmethod - def supports(self, op_name: str) -> bool: - """Check if the tuner can handle the type of operation represented by the input string.""" - pass - - @abstractmethod - def get_shapes(self, template: list[str]) -> ProblemSize: - """Extract problem size of the operation.""" - pass - +class DispatchTuner(DispatchParser): + # TODO(https://github.com/nod-ai/SHARK-Platform/issues/453): Remove this in favor of configuring using transform dialect. @abstractmethod def apply_params( self, @@ -316,12 +274,6 @@ def apply_params( pass -@dataclass -class OpWalkResult: - was_interrupted: bool = False - dispatch_tuner: Optional[DispatchTuner] = None - - class DispatchTunerRegistry: def __init__(self): self.registry = set() @@ -345,60 +297,7 @@ def find_handler(self, op_name: str) -> DispatchTuner: assert False, "Dispatch kind not supported" -class MmtTuner(DispatchTuner): - def supports(self, op_name: str) -> bool: - return "matmul_transpose_b" in op_name - - def get_shapes(self, template: list[str]) -> ProblemSize: - mmt_re = None - dps = None - for line in template: - if "linalg.generic" not in line: - continue - if r'iterator_types = ["parallel", "parallel", "reduction"]' not in line: - continue - # ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) - mmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - dps = re.search(mmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 2 - lhs_M, lhs_K = lhs_shaped_type.shape - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 2 - rhs_N, rhs_K = rhs_shaped_type.shape - - assert lhs_shaped_type.element_type == rhs_shaped_type.element_type - assert lhs_K == rhs_K - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 2 - res_M, res_N = res_shaped_type.shape - - assert lhs_M == res_M - assert rhs_N == res_N - - matmul_size = MatmulSize( - lhs_shaped_type.shape[0], - rhs_shaped_type.shape[0], - lhs_shaped_type.shape[1], - ) - return ProblemSize( - matmul_size, - lhs_type=lhs_shaped_type, - rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.mmt, - ) - assert mmt_re - assert False, f"'{mmt_re}' not found in given context" - +class MmtTuner(DispatchTuner, MmtParser): def get_transform_function_mmt( self, problem_size: ProblemSize, functionName: str, configuration: Configuration ) -> str: @@ -450,71 +349,7 @@ def apply_params( return MLIRTransformation(template, modified, embeddable) -class ConvTuner(DispatchTuner): - def supports(self, op_name: str) -> bool: - return "conv_2d_nhwc_hwcf" in op_name - - def get_conv_tile_sizes(self, configuration: Configuration) -> list[int]: - m, n, k = configuration.tile_sizes - batch = 1 - fh = 1 - fw = 1 - - oh = 1 - - oc = n - ow = m - ic = k - return [batch, oh, ow, oc, fh, fw, ic] - - def get_shapes(self, template: list[str]) -> ProblemSize: - for line in template: - if "linalg.conv_2d_nhwc_hwcf" not in line: - continue - - # ins(%19, %20 : tensor<2x34x34x1280xf16>, tensor<3x3x1280x1280xf16>) outs (%27 : tensor<2x32x32x1280xf32>) - conv_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(conv_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 4 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 4 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 4 - - # int64_t n = outputShape[0]; - # int64_t oh = outputShape[1]; - # int64_t ow = outputShape[2]; - # int64_t oc = outputShape[3]; - # int64_t fh = filterShape[0]; - # int64_t fw = filterShape[1]; - # int64_t ic = filterShape[2]; - dim_info = ConvDimInfo.from_rhs_res(rhs_shaped_type, res_shaped_type) - return ProblemSize( - MatmulSize( - M=dim_info.oh * dim_info.ow, - N=dim_info.oc, - K=dim_info.fh * dim_info.fw * dim_info.ic, - B=dim_info.n, - ), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.conv, - ) - - assert False, "Shape not found" - +class ConvTuner(DispatchTuner, ConvParser): # int64_t n = outputShape[0]; # int64_t oh = outputShape[1]; # int64_t ow = outputShape[2]; @@ -589,135 +424,7 @@ def apply_params( return MLIRTransformation(template, modified, embeddable) -class ContractionTuner(DispatchTuner): - def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): - self.lhs_dims = lhs_dims - self.rhs_dims = rhs_dims - self.tile_dims = tile_dims - - def supports(self, op_name: str) -> bool: - return "matmul_like" in op_name - - def is_broadcast_rhs_mmt_op(self, line: str) -> bool: - if "linalg.generic" not in line: - return False - if ( - r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' - not in line - ): - return False - if ( - r"indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>" - not in line - ): - return False - return True - - def is_broadcast_rhs_mmt(self, template: list[str]) -> bool: - return any(self.is_broadcast_rhs_mmt_op(line) for line in template) - - def get_shapes_broadcast_rhs_mmt(self, template: list[str]) -> ProblemSize: - for line in template: - if not self.is_broadcast_rhs_mmt_op(line): - continue - - # ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) - bmmt_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(bmmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 3 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 2 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 3 - - B0, M0, K0 = lhs_shaped_type.shape - N1, K1 = rhs_shaped_type.shape - B2, M2, N2 = res_shaped_type.shape - assert B0 == B2 - assert M0 == M2 - assert N1 == N2 - assert K0 == K1 - return ProblemSize( - MatmulSize(M0, N1, K0, B0), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.broadcast_rhs_mmt, - ) - - assert False, "Shape not found" - - def get_shapes(self, template: list[str]) -> ProblemSize: - if self.is_broadcast_rhs_mmt(template): - return self.get_shapes_broadcast_rhs_mmt(template) - - for line in template: - if "linalg.generic" not in line: - continue - if "lowering_config =" not in line: - continue - if '"reduction"' not in line: - continue - - # ins(%7, %8 : tensor<2x1024x1280xf16>, tensor<20x64x1280xf16>) - cont_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(cont_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == len(self.lhs_dims) - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == len(self.rhs_dims) - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() >= 2 - - M = math.prod( - val if dim == "m" else 1 - for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) - ) - N = math.prod( - val if dim == "n" else 1 - for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape) - ) - K0 = math.prod( - val if dim == "k" else 1 - for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) - ) - K1 = math.prod( - val if dim == "k" else 1 - for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape) - ) - assert K0 == K1 - - return ProblemSize( - MatmulSize(M, N, K0), - lhs_type=lhs_shaped_type, - rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.contraction, - ) - - assert False, "Shape not found" - +class ContractionTuner(DispatchTuner, ContractionParser): def get_transform_function_broadcast_rhs_mmt( self, problem_size: ProblemSize, @@ -801,57 +508,7 @@ def apply_params( ) -class BatchMmtTuner(DispatchTuner): - def supports(self, op_name: str) -> bool: - return "batch_matmul_transpose_b" in op_name - - def get_shapes(self, template: list[str]) -> ProblemSize: - for line in template: - if "linalg.generic" not in line: - continue - if ( - r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' - not in line - ): - continue - # ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) - bmmt_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(bmmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 3 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 3 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 3 - - B0, M0, K0 = lhs_shaped_type.shape - B1, N1, K1 = rhs_shaped_type.shape - B2, M2, N2 = res_shaped_type.shape - assert B0 == B1 - assert B0 == B2 - assert M0 == M2 - assert N1 == N2 - assert K0 == K1 - return ProblemSize( - MatmulSize(M0, N1, K0, B0), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.batch_mmt, - ) - - assert False, "Shape not found" - +class BatchMmtTuner(DispatchTuner, BatchMmtParser): def get_transform_function_batch_mmt( self, problem_size: ProblemSize, @@ -910,78 +567,7 @@ def apply_params( return MLIRTransformation(template, modified, embeddable) -class BatchMatmulTuner(DispatchTuner): - def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): - self.lhs_dims = lhs_dims - self.rhs_dims = rhs_dims - self.tile_dims = tile_dims - - def supports(self, op_name: str) -> bool: - return "batch_matmul" in op_name - - def get_shapes(self, template: list[str]) -> ProblemSize: - for line in template: - if "linalg.batch_matmul" not in line: - continue - # ins(%9, %10 : tensor<64x72x1280xf16>, tensor<64x1280x1280xf16>) - # outs(%12 : tensor<64x72x1280xf32>) - cont_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(cont_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == len(self.lhs_dims) - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == len(self.rhs_dims) - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == lhs_shaped_type.rank() - - LHS = lhs_shaped_type.shape - RHS = rhs_shaped_type.shape - RES = res_shaped_type.shape - - B = math.prod( - val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, LHS) - ) - B0 = math.prod( - val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RHS) - ) - B1 = math.prod( - val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RES) - ) - M = math.prod( - val if dim == "m" else 1 for dim, val in zip(self.lhs_dims, LHS) - ) - N = math.prod( - val if dim == "n" else 1 for dim, val in zip(self.rhs_dims, RHS) - ) - K0 = math.prod( - val if dim == "k" else 1 for dim, val in zip(self.lhs_dims, LHS) - ) - K1 = math.prod( - val if dim == "k" else 1 for dim, val in zip(self.rhs_dims, RHS) - ) - assert B == B0 and B == B1 - assert K0 == K1 - - return ProblemSize( - MatmulSize(M, N, K0, B), - lhs_type=lhs_shaped_type, - rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.batch_matmul, - ) - - assert False, "Shape not found" - +class BatchMatmulTuner(DispatchTuner, BatchMatmulParser): def get_transform_function_batch_matmul( self, problem_size: ProblemSize, @@ -1053,6 +639,12 @@ def apply_params( return MLIRTransformation(template, modified, embeddable) +@dataclass +class OpWalkResult: + was_interrupted: bool = False + dispatch_tuner: Optional[DispatchTuner] = None + + def walk_callback_get_fn( op: ir.Operation, walk_result: OpWalkResult, @@ -1106,7 +698,8 @@ def tune( mlir_text = "".join(mlir_template) with ir.Context() as ctx: - mlir_module: ir.Module = parse_mlir(mlir_text, ctx) + tuner_context = TunerContext(ctx, tune_logger) + mlir_module: ir.Module = parse_mlir(mlir_text, tuner_context) # Save the input file as the first candidate. with open(path.join(output, f"0.mlir"), "w") as f: f.write(mlir_text) diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py index 5f7b9c848..63e8441d1 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -9,189 +9,48 @@ """ import pytest -from . import candidate_gen - -from iree.compiler import ir # type: ignore -from iree.compiler.dialects import func # type: ignore - - -def test_get_mmt_tile_sizes() -> None: - config = candidate_gen.Configuration( - subgroup_size=0, - workgroup_size=[], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), - tile_sizes=[128, 320, 32], - subgroup_m_count=0, - subgroup_n_count=0, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), - waves_per_eu=0, - ) - assert candidate_gen.get_mmt_tile_sizes(config) == [128, 320, 32] - - -def test_get_conv_tile_sizes() -> None: - config = candidate_gen.Configuration( - subgroup_size=64, - workgroup_size=[256, 1, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), - tile_sizes=[464, 320, 16], - subgroup_m_count=1, - subgroup_n_count=4, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), - waves_per_eu=1, - ) - assert candidate_gen.ConvTuner().get_conv_tile_sizes(config) == [ - 1, - 1, - 464, - 320, - 1, - 1, - 16, - ] - - -def test_get_contract_tile_sizes() -> None: - config = candidate_gen.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), - tile_sizes=[4, 8, 16], - subgroup_m_count=1, - subgroup_n_count=1, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), - waves_per_eu=2, - ) - assert candidate_gen.get_contract_tile_sizes(config, "mnk") == [4, 8, 16] - assert candidate_gen.get_contract_tile_sizes(config, "nmk") == [8, 4, 16] - assert candidate_gen.get_contract_tile_sizes(config, "knm") == [16, 8, 4] - assert candidate_gen.get_contract_tile_sizes(config, "kkk") == [ - 16, - 16, - 16, - ] - - -def test_get_shapes_mmt() -> None: - template = [ - r"%18 = tensor.empty() : tensor<2048x1280xf32>", - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"^bb0(%in: f16, %in_0: f16, %out: f32):", - ] - assert candidate_gen.MmtTuner().get_shapes(template) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.mmt, - ) - - -def test_get_shapes_conv() -> None: - template = [ - r"%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", - r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", - r"flow.dispatch.tensor.store %8, %2, offsets = [%workgroup_id_z, %workgroup_id_y, 0, %3], sizes = [1, 1, 32, 256], strides = [1, 1, 1, 1] : tensor<1x1x32x256xf32> -> !flow.dispatch.tensor>", - ] - assert candidate_gen.ConvTuner().get_shapes(template) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(32, 256, 11520), - candidate_gen.ShapedType([1, 3, 34, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([3, 3, 1280, 256], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1, 1, 32, 256], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.conv, - ) - -def test_get_shapes_contract() -> None: - template = [ - r"%18 = tensor.empty() : tensor<2048x1280xf32>", - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"^bb0(%in: f16, %in_0: f16, %out: f32):", - ] - assert candidate_gen.ContractionTuner("mk", "nk", "mnk").get_shapes( - template - ) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.contraction, - ) - - -def test_get_shapes_batch_matmul() -> None: - template = [ - "%10 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", - "%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", - "flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 32, 32], strides = [1, 1, 1] : tensor<1x32x32xf32> -> !flow.dispatch.tensor>", - ] - assert candidate_gen.BatchMatmulTuner("bmk", "bkn", "mnk").get_shapes( - template - ) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(32, 32, 1024, 1), - candidate_gen.ShapedType([1, 32, 1024], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([1, 1024, 32], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([1, 32, 32], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.batch_matmul, - ) - - -def test_get_shapes_batch_mmt() -> None: - template = [ - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"flow.dispatch.tensor.store %21, %10, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : tensor<2x4096x640xf16> -> !flow.dispatch.tensor>", - ] - assert candidate_gen.BatchMmtTuner().get_shapes( - template - ) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(4096, 640, 640, 2), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), - candidate_gen.DispatchKind.batch_mmt, - ) +from . import candidate_gen +from . import common def test_generate_solutions() -> None: - matmul_size = candidate_gen.MatmulSize(2048, 3840, 1280) - lhs_type = candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([3840, 1280], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([2048, 3840], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + matmul_size = common.MatmulSize(2048, 3840, 1280) + lhs_type = common.ShapedType([2048, 1280], common.ElementType.f16) + rhs_type = common.ShapedType([3840, 1280], common.ElementType.f16) + res_type = common.ShapedType([2048, 3840], common.ElementType.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) configs = candidate_gen.generate_solutions(problem_size, 4) assert configs is not None def test_calculate_shared_memory_usage_in_bytes() -> None: - matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + matmul_size = common.MatmulSize(1024, 1024, 1024) + lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + res_type = common.ShapedType([1024, 1024], common.ElementType.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) assert ( candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) == 147456 ) - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i8) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + lhs_type = common.ShapedType([1024, 1024], common.ElementType.i8) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) assert ( candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) == 81920 ) - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + rhs_type = common.ShapedType([1024, 1024], common.ElementType.i32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) assert ( candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 128, 64, 32) @@ -200,12 +59,12 @@ def test_calculate_shared_memory_usage_in_bytes() -> None: def test_generate_constraints_valid_input() -> None: - matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + matmul_size = common.MatmulSize(1024, 1024, 1024) + lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + res_type = common.ShapedType([1024, 1024], common.ElementType.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) # Define input parameters as z3 Ints m, n, k = ( @@ -246,12 +105,12 @@ def test_generate_constraints_valid_input() -> None: def test_generate_constraints_invalid_input() -> None: # Define input parameters that should lead to unsatisfiable constraints - matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + matmul_size = common.MatmulSize(1024, 1024, 1024) + lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + res_type = common.ShapedType([1024, 1024], common.ElementType.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) m, n, k = ( candidate_gen.z3.Int("m"), @@ -307,25 +166,23 @@ def test_apply_params_mmt() -> None: M, N, K = 2048, 1280, 1280 - config = candidate_gen.Configuration( + config = common.Configuration( subgroup_size=16, workgroup_size=[16, 16, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), tile_sizes=[8, 8, 8], subgroup_m_count=16, subgroup_n_count=16, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions( - prefetch_shared_memory=True - ), + gpu_pipeline_options=common.GpuPipelineOptions(prefetch_shared_memory=True), waves_per_eu=8, ) - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(M, N, K), - candidate_gen.ShapedType([M, K], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([N, K], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([M, N], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.mmt, + problem_size = common.ProblemSize( + common.MatmulSize(M, N, K), + common.ShapedType([M, K], common.ElementType.f16), + common.ShapedType([N, K], common.ElementType.f16), + common.ShapedType([M, N], common.ElementType.f32), + common.DispatchKind.mmt, ) tf_mlir = candidate_gen.MmtTuner().apply_params(problem_size, mlir_template, config) @@ -361,27 +218,25 @@ def test_apply_params_conv() -> None: n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640 - config = candidate_gen.Configuration( + config = common.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), tile_sizes=[464, 320, 16], subgroup_m_count=1, subgroup_n_count=4, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions( - reorder_workgroups_strategy=candidate_gen.ReorderWorkgroupsStrategy.TRANSPOSE + gpu_pipeline_options=common.GpuPipelineOptions( + reorder_workgroups_strategy=common.ReorderWorkgroupsStrategy.TRANSPOSE ), waves_per_eu=2, ) - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(oh * ow, oc, fh * fw * ic), - candidate_gen.ShapedType( - [n, oh + 2, ow + 2, oc], candidate_gen.ElementType.f16 - ), - candidate_gen.ShapedType([fh, fw, ic, oc], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([n, oh, ow, oc], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.conv, + problem_size = common.ProblemSize( + common.MatmulSize(oh * ow, oc, fh * fw * ic), + common.ShapedType([n, oh + 2, ow + 2, oc], common.ElementType.f16), + common.ShapedType([fh, fw, ic, oc], common.ElementType.f16), + common.ShapedType([n, oh, ow, oc], common.ElementType.f32), + common.DispatchKind.conv, ) tf_mlir = candidate_gen.ConvTuner().apply_params( problem_size, mlir_template, config @@ -419,22 +274,22 @@ def test_apply_params_contract() -> None: ] tile_dims = "*mnk" - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 3840, 1280), - candidate_gen.ShapedType([2, 1024, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([3, 20, 64, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([3, 2, 20, 1024, 64], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.contraction, + problem_size = common.ProblemSize( + common.MatmulSize(2048, 3840, 1280), + common.ShapedType([2, 1024, 1280], common.ElementType.f16), + common.ShapedType([3, 20, 64, 1280], common.ElementType.f16), + common.ShapedType([3, 2, 20, 1024, 64], common.ElementType.f32), + common.DispatchKind.contraction, ) - config = candidate_gen.Configuration( + config = common.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), tile_sizes=[480, 384, 32], subgroup_m_count=1, subgroup_n_count=4, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), + gpu_pipeline_options=common.GpuPipelineOptions(), waves_per_eu=2, ) @@ -466,22 +321,22 @@ def test_apply_params_batch_matmul() -> None: ] tile_dims = "bmnk" - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(968, 320, 640, 64), - candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([64, 968, 320], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.batch_matmul, + problem_size = common.ProblemSize( + common.MatmulSize(968, 320, 640, 64), + common.ShapedType([64, 968, 640], common.ElementType.f16), + common.ShapedType([64, 640, 320], common.ElementType.f16), + common.ShapedType([64, 968, 320], common.ElementType.f32), + common.DispatchKind.batch_matmul, ) - config = candidate_gen.Configuration( + config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), tile_sizes=[416, 320, 128], subgroup_m_count=2, subgroup_n_count=2, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), + gpu_pipeline_options=common.GpuPipelineOptions(), waves_per_eu=2, ) @@ -516,22 +371,22 @@ def test_apply_params_batch_mmt_float() -> None: '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(4096, 640, 640, 2), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.batch_mmt, + problem_size = common.ProblemSize( + common.MatmulSize(4096, 640, 640, 2), + common.ShapedType([2, 4096, 640], common.ElementType.f16), + common.ShapedType([2, 640, 640], common.ElementType.f16), + common.ShapedType([2, 4096, 640], common.ElementType.f32), + common.DispatchKind.batch_mmt, ) - config = candidate_gen.Configuration( + config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), tile_sizes=[128, 64, 128], subgroup_m_count=2, subgroup_n_count=2, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), + gpu_pipeline_options=common.GpuPipelineOptions(), waves_per_eu=2, ) @@ -564,22 +419,22 @@ def test_apply_params_batch_mmt_int() -> None: '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(4096, 640, 640, 2), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), - candidate_gen.DispatchKind.batch_mmt, + problem_size = common.ProblemSize( + common.MatmulSize(4096, 640, 640, 2), + common.ShapedType([2, 4096, 640], common.ElementType.i8), + common.ShapedType([2, 640, 640], common.ElementType.i8), + common.ShapedType([2, 4096, 640], common.ElementType.i32), + common.DispatchKind.batch_mmt, ) - config = candidate_gen.Configuration( + config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_i32_32x32x16_i8(), + intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(), tile_sizes=[128, 64, 128], subgroup_m_count=2, subgroup_n_count=2, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), + gpu_pipeline_options=common.GpuPipelineOptions(), waves_per_eu=4, ) @@ -635,22 +490,22 @@ def test_apply_params_broadcast_rhs_mmt() -> None: '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(4096, 640, 640, 2), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([640, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), - candidate_gen.DispatchKind.broadcast_rhs_mmt, + problem_size = common.ProblemSize( + common.MatmulSize(4096, 640, 640, 2), + common.ShapedType([2, 4096, 640], common.ElementType.i8), + common.ShapedType([640, 640], common.ElementType.i8), + common.ShapedType([2, 4096, 640], common.ElementType.i32), + common.DispatchKind.broadcast_rhs_mmt, ) - config = candidate_gen.Configuration( + config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_i32_32x32x16_i8(), + intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(), tile_sizes=[128, 64, 128], subgroup_m_count=2, subgroup_n_count=2, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), + gpu_pipeline_options=common.GpuPipelineOptions(), waves_per_eu=4, ) @@ -711,19 +566,3 @@ def test_detect_broadcast_rhs_mmt() -> None: assert candidate_gen.ContractionTuner("mk", "nk", "mnk").is_broadcast_rhs_mmt( mlir_lines ) - - -def test_parse_mlir() -> None: - with ir.Context() as ctx: - mlir_str = r""" - builtin.module { - func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - %0 = arith.mulf %arg0, %arg1 : tensor<4xf32> - return %0 : tensor<4xf32> - } - } - """ - mlir_module = candidate_gen.parse_mlir(mlir_str, ctx) - assert mlir_module is not None - assert isinstance(mlir_module, ir.Module) - assert isinstance(mlir_module.body.operations[0], func.FuncOp) diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index 1fbeb9910..7b295cdb0 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -5,10 +5,19 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import re +import logging from dataclasses import astuple, dataclass from enum import Enum from typing import Optional +from iree.compiler import ir # type: ignore + + +class TunerContext: + def __init__(self, mlir_ctx: ir.Context, logger: logging.Logger): + self.mlir_ctx = mlir_ctx + self.logger = logger + class DispatchKind(Enum): conv = 1 diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py new file mode 100644 index 000000000..670f8c3f7 --- /dev/null +++ b/tuner/tuner/dispatch_parser.py @@ -0,0 +1,435 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Given an input dispatch, this code modifies the hyperparameters +# in the code and runs it. + +import math +import re +from abc import ABCMeta, abstractmethod + +from .common import * + + +def get_mmt_tile_sizes(configuration: Configuration): + return configuration.tile_sizes + + +def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> list[int]: + m, n, k = configuration.tile_sizes + tile_size = [1] * len(tile_dims) + for idx, dim in enumerate(tile_dims): + if dim == "m": + tile_size[idx] = m + if dim == "n": + tile_size[idx] = n + if dim == "k": + tile_size[idx] = k + return tile_size + + +def get_batch_mmt_tile_sizes(configuration: Configuration) -> list[int]: + return [1] + configuration.tile_sizes + + +def parse_mlir(mlir_text: str, ctx: TunerContext) -> ir.Module: + mlir_module = None + try: + mlir_module = ir.Module.parse(mlir_text, ctx.mlir_ctx) + ctx.logger.info("MLIR parsing successful!") + except ir.MLIRError as e: + ctx.logger.error(f"Error parsing MLIR: {e}") + raise RuntimeError(f"Error parsing MLIR: {e}") + + return mlir_module + + +class DispatchParser(metaclass=ABCMeta): + @abstractmethod + def supports(self, op_name: str) -> bool: + """Check if the tuner can handle the type of operation represented by the input string.""" + pass + + @abstractmethod + def get_shapes(self, template: list[str]) -> ProblemSize: + """Extract problem size of the operation.""" + pass + + +class MmtParser(DispatchParser): + def supports(self, op_name: str) -> bool: + return "matmul_transpose_b" in op_name + + def get_shapes(self, template: list[str]) -> ProblemSize: + mmt_re = None + dps = None + for line in template: + if "linalg.generic" not in line: + continue + if r'iterator_types = ["parallel", "parallel", "reduction"]' not in line: + continue + # ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) + mmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + dps = re.search(mmt_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 2 + lhs_M, lhs_K = lhs_shaped_type.shape + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 2 + rhs_N, rhs_K = rhs_shaped_type.shape + + assert lhs_shaped_type.element_type == rhs_shaped_type.element_type + assert lhs_K == rhs_K + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 2 + res_M, res_N = res_shaped_type.shape + + assert lhs_M == res_M + assert rhs_N == res_N + + matmul_size = MatmulSize( + lhs_shaped_type.shape[0], + rhs_shaped_type.shape[0], + lhs_shaped_type.shape[1], + ) + return ProblemSize( + matmul_size, + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.mmt, + ) + assert mmt_re + assert False, f"'{mmt_re}' not found in given context" + + +class ConvParser(DispatchParser): + def supports(self, op_name: str) -> bool: + return "conv_2d_nhwc_hwcf" in op_name + + def get_conv_tile_sizes(self, configuration: Configuration) -> list[int]: + m, n, k = configuration.tile_sizes + batch = 1 + fh = 1 + fw = 1 + + oh = 1 + + oc = n + ow = m + ic = k + return [batch, oh, ow, oc, fh, fw, ic] + + def get_shapes(self, template: list[str]) -> ProblemSize: + for line in template: + if "linalg.conv_2d_nhwc_hwcf" not in line: + continue + + # ins(%19, %20 : tensor<2x34x34x1280xf16>, tensor<3x3x1280x1280xf16>) outs (%27 : tensor<2x32x32x1280xf32>) + conv_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(conv_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 4 + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 4 + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 4 + + # int64_t n = outputShape[0]; + # int64_t oh = outputShape[1]; + # int64_t ow = outputShape[2]; + # int64_t oc = outputShape[3]; + # int64_t fh = filterShape[0]; + # int64_t fw = filterShape[1]; + # int64_t ic = filterShape[2]; + dim_info = ConvDimInfo.from_rhs_res(rhs_shaped_type, res_shaped_type) + return ProblemSize( + MatmulSize( + M=dim_info.oh * dim_info.ow, + N=dim_info.oc, + K=dim_info.fh * dim_info.fw * dim_info.ic, + B=dim_info.n, + ), + lhs_shaped_type, + rhs_shaped_type, + res_shaped_type, + DispatchKind.conv, + ) + + assert False, "Shape not found" + + +class ContractionParser(DispatchParser): + def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): + self.lhs_dims = lhs_dims + self.rhs_dims = rhs_dims + self.tile_dims = tile_dims + + def supports(self, op_name: str) -> bool: + return "matmul_like" in op_name + + def is_broadcast_rhs_mmt_op(self, line: str) -> bool: + if "linalg.generic" not in line: + return False + if ( + r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' + not in line + ): + return False + if ( + r"indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>" + not in line + ): + return False + return True + + def is_broadcast_rhs_mmt(self, template: list[str]) -> bool: + return any(self.is_broadcast_rhs_mmt_op(line) for line in template) + + def get_shapes_broadcast_rhs_mmt(self, template: list[str]) -> ProblemSize: + for line in template: + if not self.is_broadcast_rhs_mmt_op(line): + continue + + # ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) + bmmt_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(bmmt_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 3 + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 2 + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 3 + + B0, M0, K0 = lhs_shaped_type.shape + N1, K1 = rhs_shaped_type.shape + B2, M2, N2 = res_shaped_type.shape + assert B0 == B2 + assert M0 == M2 + assert N1 == N2 + assert K0 == K1 + return ProblemSize( + MatmulSize(M0, N1, K0, B0), + lhs_shaped_type, + rhs_shaped_type, + res_shaped_type, + DispatchKind.broadcast_rhs_mmt, + ) + + assert False, "Shape not found" + + def get_shapes(self, template: list[str]) -> ProblemSize: + if self.is_broadcast_rhs_mmt(template): + return self.get_shapes_broadcast_rhs_mmt(template) + + for line in template: + if "linalg.generic" not in line: + continue + if "lowering_config =" not in line: + continue + if '"reduction"' not in line: + continue + + # ins(%7, %8 : tensor<2x1024x1280xf16>, tensor<20x64x1280xf16>) + cont_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(cont_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == len(self.lhs_dims) + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == len(self.rhs_dims) + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() >= 2 + + M = math.prod( + val if dim == "m" else 1 + for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) + ) + N = math.prod( + val if dim == "n" else 1 + for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape) + ) + K0 = math.prod( + val if dim == "k" else 1 + for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) + ) + K1 = math.prod( + val if dim == "k" else 1 + for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape) + ) + assert K0 == K1 + + return ProblemSize( + MatmulSize(M, N, K0), + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.contraction, + ) + + assert False, "Shape not found" + + +class BatchMmtParser(DispatchParser): + def supports(self, op_name: str) -> bool: + return "batch_matmul_transpose_b" in op_name + + def get_shapes(self, template: list[str]) -> ProblemSize: + for line in template: + if "linalg.generic" not in line: + continue + if ( + r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' + not in line + ): + continue + # ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) + bmmt_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(bmmt_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 3 + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 3 + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 3 + + B0, M0, K0 = lhs_shaped_type.shape + B1, N1, K1 = rhs_shaped_type.shape + B2, M2, N2 = res_shaped_type.shape + assert B0 == B1 + assert B0 == B2 + assert M0 == M2 + assert N1 == N2 + assert K0 == K1 + return ProblemSize( + MatmulSize(M0, N1, K0, B0), + lhs_shaped_type, + rhs_shaped_type, + res_shaped_type, + DispatchKind.batch_mmt, + ) + + assert False, "Shape not found" + + +class BatchMatmulParser(DispatchParser): + def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): + self.lhs_dims = lhs_dims + self.rhs_dims = rhs_dims + self.tile_dims = tile_dims + + def supports(self, op_name: str) -> bool: + return "batch_matmul" in op_name + + def get_shapes(self, template: list[str]) -> ProblemSize: + for line in template: + if "linalg.batch_matmul" not in line: + continue + # ins(%9, %10 : tensor<64x72x1280xf16>, tensor<64x1280x1280xf16>) + # outs(%12 : tensor<64x72x1280xf32>) + cont_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(cont_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == len(self.lhs_dims) + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == len(self.rhs_dims) + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == lhs_shaped_type.rank() + + LHS = lhs_shaped_type.shape + RHS = rhs_shaped_type.shape + RES = res_shaped_type.shape + + B = math.prod( + val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, LHS) + ) + B0 = math.prod( + val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RHS) + ) + B1 = math.prod( + val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RES) + ) + M = math.prod( + val if dim == "m" else 1 for dim, val in zip(self.lhs_dims, LHS) + ) + N = math.prod( + val if dim == "n" else 1 for dim, val in zip(self.rhs_dims, RHS) + ) + K0 = math.prod( + val if dim == "k" else 1 for dim, val in zip(self.lhs_dims, LHS) + ) + K1 = math.prod( + val if dim == "k" else 1 for dim, val in zip(self.rhs_dims, RHS) + ) + assert B == B0 and B == B1 + assert K0 == K1 + + return ProblemSize( + MatmulSize(M, N, K0, B), + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.batch_matmul, + ) + + assert False, "Shape not found" diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py new file mode 100644 index 000000000..bcdee240c --- /dev/null +++ b/tuner/tuner/dispatch_parser_test.py @@ -0,0 +1,176 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Usage: python -m pytest candidate_gen_test.py +""" + +import pytest + +from logging import Logger +from unittest.mock import MagicMock + +from iree.compiler import ir # type: ignore +from iree.compiler.dialects import func # type: ignore + +from . import common +from . import dispatch_parser + + +def test_get_mmt_tile_sizes() -> None: + config = dispatch_parser.Configuration( + subgroup_size=0, + workgroup_size=[], + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + tile_sizes=[128, 320, 32], + subgroup_m_count=0, + subgroup_n_count=0, + gpu_pipeline_options=common.GpuPipelineOptions(), + waves_per_eu=0, + ) + assert dispatch_parser.get_mmt_tile_sizes(config) == [128, 320, 32] + + +def test_get_conv_tile_sizes() -> None: + config = dispatch_parser.Configuration( + subgroup_size=64, + workgroup_size=[256, 1, 1], + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + tile_sizes=[464, 320, 16], + subgroup_m_count=1, + subgroup_n_count=4, + gpu_pipeline_options=common.GpuPipelineOptions(), + waves_per_eu=1, + ) + assert dispatch_parser.ConvParser().get_conv_tile_sizes(config) == [ + 1, + 1, + 464, + 320, + 1, + 1, + 16, + ] + + +def test_get_contract_tile_sizes() -> None: + config = dispatch_parser.Configuration( + subgroup_size=32, + workgroup_size=[16, 16, 1], + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + tile_sizes=[4, 8, 16], + subgroup_m_count=1, + subgroup_n_count=1, + gpu_pipeline_options=common.GpuPipelineOptions(), + waves_per_eu=2, + ) + assert dispatch_parser.get_contract_tile_sizes(config, "mnk") == [4, 8, 16] + assert dispatch_parser.get_contract_tile_sizes(config, "nmk") == [8, 4, 16] + assert dispatch_parser.get_contract_tile_sizes(config, "knm") == [16, 8, 4] + assert dispatch_parser.get_contract_tile_sizes(config, "kkk") == [ + 16, + 16, + 16, + ] + + +def test_get_shapes_mmt() -> None: + template = [ + r"%18 = tensor.empty() : tensor<2048x1280xf32>", + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", + r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + r"^bb0(%in: f16, %in_0: f16, %out: f32):", + ] + assert dispatch_parser.MmtParser().get_shapes(template) == common.ProblemSize( + common.MatmulSize(2048, 1280, 1280), + common.ShapedType([2048, 1280], common.ElementType.f16), + common.ShapedType([1280, 1280], common.ElementType.f16), + common.ShapedType([2048, 1280], common.ElementType.f32), + dispatch_parser.DispatchKind.mmt, + ) + + +def test_get_shapes_conv() -> None: + template = [ + r"%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", + r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", + r"flow.dispatch.tensor.store %8, %2, offsets = [%workgroup_id_z, %workgroup_id_y, 0, %3], sizes = [1, 1, 32, 256], strides = [1, 1, 1, 1] : tensor<1x1x32x256xf32> -> !flow.dispatch.tensor>", + ] + assert dispatch_parser.ConvParser().get_shapes(template) == common.ProblemSize( + common.MatmulSize(32, 256, 11520), + common.ShapedType([1, 3, 34, 1280], common.ElementType.f16), + common.ShapedType([3, 3, 1280, 256], common.ElementType.f16), + common.ShapedType([1, 1, 32, 256], common.ElementType.f32), + dispatch_parser.DispatchKind.conv, + ) + + +def test_get_shapes_contract() -> None: + template = [ + r"%18 = tensor.empty() : tensor<2048x1280xf32>", + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", + r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + r"^bb0(%in: f16, %in_0: f16, %out: f32):", + ] + assert dispatch_parser.ContractionParser("mk", "nk", "mnk").get_shapes( + template + ) == common.ProblemSize( + common.MatmulSize(2048, 1280, 1280), + common.ShapedType([2048, 1280], common.ElementType.f16), + common.ShapedType([1280, 1280], common.ElementType.f16), + common.ShapedType([2048, 1280], common.ElementType.f32), + dispatch_parser.DispatchKind.contraction, + ) + + +def test_get_shapes_batch_matmul() -> None: + template = [ + "%10 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", + "%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", + "flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 32, 32], strides = [1, 1, 1] : tensor<1x32x32xf32> -> !flow.dispatch.tensor>", + ] + assert dispatch_parser.BatchMatmulParser("bmk", "bkn", "mnk").get_shapes( + template + ) == common.ProblemSize( + common.MatmulSize(32, 32, 1024, 1), + common.ShapedType([1, 32, 1024], common.ElementType.f32), + common.ShapedType([1, 1024, 32], common.ElementType.f32), + common.ShapedType([1, 32, 32], common.ElementType.f32), + dispatch_parser.DispatchKind.batch_matmul, + ) + + +def test_get_shapes_batch_mmt() -> None: + template = [ + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>", + r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + r"flow.dispatch.tensor.store %21, %10, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : tensor<2x4096x640xf16> -> !flow.dispatch.tensor>", + ] + assert dispatch_parser.BatchMmtParser().get_shapes(template) == common.ProblemSize( + common.MatmulSize(4096, 640, 640, 2), + common.ShapedType([2, 4096, 640], common.ElementType.i8), + common.ShapedType([2, 640, 640], common.ElementType.i8), + common.ShapedType([2, 4096, 640], common.ElementType.i32), + dispatch_parser.DispatchKind.batch_mmt, + ) + + +def test_parse_mlir() -> None: + with ir.Context() as ctx: + mlir_str = r""" + builtin.module { + func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + %0 = arith.mulf %arg0, %arg1 : tensor<4xf32> + return %0 : tensor<4xf32> + } + } + """ + logger: Logger = MagicMock(spec=Logger) + tuner_context = common.TunerContext(ctx, logger) + mlir_module = dispatch_parser.parse_mlir(mlir_str, tuner_context) + assert mlir_module is not None + assert isinstance(mlir_module, ir.Module) + assert isinstance(mlir_module.body.operations[0], func.FuncOp)