diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 40eb27a82..ce1a86d9b 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -27,7 +27,7 @@ from dataclasses import astuple, dataclass from enum import Enum from os import mkdir, path, makedirs -from typing import Optional +from typing import Optional, Generator from textwrap import indent from abc import ABC, abstractmethod @@ -46,6 +46,7 @@ class DispatchKind(Enum): batch_mmt = 4 batch_matmul = 5 broadcast_rhs_mmt = 6 + mmt4d = 7 class ElementType(Enum): @@ -187,20 +188,169 @@ def __str__(self) -> str: @dataclass -class Configuration: +class BaseConfiguration: + tile_sizes: list[int] + + @abstractmethod + def get_mmt_tile_sizes(self) -> list[int]: + return self.tile_sizes + + @abstractmethod + def get_batch_mmt_tile_sizes(self) -> list[int]: + return [1] + self.tile_sizes + + def get_contract_tile_sizes(self, tile_dims: list[str]) -> list[int]: + m, n, k = self.tile_sizes + tile_size = [1] * len(tile_dims) + for idx, dim in enumerate(tile_dims): + if dim == "m": + tile_size[idx] = m + if dim == "n": + tile_size[idx] = n + if dim == "k": + tile_size[idx] = k + return tile_size + + @abstractmethod + def get_pipeline_config(self) -> str: + pass + + @abstractmethod + def get_intrinsic_config(self) -> str: + pass + + def get_base_mlir_config(self, tile_sizes: list[int]) -> str: + tile_sizes_str = ", ".join(map(str, tile_sizes)) + return f""" + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + """ + + @abstractmethod + def get_mlir_config(self, tile_sizes: list[int]) -> str: + base_config = self.get_base_mlir_config(tile_sizes) + full_config = f""" + translation_info = #iree_codegen.translation_info + > -> !transform.any_param + """ + + return base_config + full_config + + @abstractmethod + def apply_configuration(self, template: list[str], tile_sizes: list[int]) -> str: + pass + + +@dataclass +class LLVMGPUConfiguration(BaseConfiguration): subgroup_size: int workgroup_size: list[int] intrinsic: MfmaIntrinsic - tile_sizes: list[int] subgroup_m_count: int subgroup_n_count: int gpu_pipeline_options: GpuPipelineOptions waves_per_eu: int + def get_pipeline_config(self) -> str: + extra_config = "" + if not self.gpu_pipeline_options.all_default(): + extra_config += f", gpu_pipeline_options = {self.gpu_pipeline_options}" + if self.waves_per_eu != 2: + extra_config += ( + f', llvm_func_attrs = {{"amdgpu-waves-per-eu" = "{self.waves_per_eu}"}}' + ) + + return extra_config + + def get_intrinsic_config(self) -> str: + return str(self.intrinsic) + + def get_mlir_config(self, tile_sizes: list[int]) -> str: + base_config = self.get_base_mlir_config(tile_sizes) + wg_x, wg_y, wg_z = self.workgroup_size + + backend_config = f""" + translation_info = #iree_codegen.translation_info, + subgroup_m_count = {self.subgroup_m_count}, subgroup_n_count = {self.subgroup_n_count}> + {self.get_pipeline_config()}}}> + > -> !transform.any_param + """ + + return base_config + backend_config + + def apply_configuration(self, template: list[str], tile_sizes: list[int]) -> str: + tune_logger.info(f"Applying: {self}") + + expr0 = re.compile( + r", subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>" + ) + expr1 = re.compile( + r"LLVMGPUVectorDistribute workgroup_size = \[.+\] subgroup_size = ([0-9]+)," + ) + expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]") + expr3 = re.compile( + r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>" + ) + expr4 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") + + repl0 = f", subgroup_m_count = {self.subgroup_m_count}, subgroup_n_count = {self.subgroup_n_count}>" + repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, self.workgroup_size))}] subgroup_size = {self.subgroup_size},' + repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]' + repl3 = f"gpu_pipeline_options = {self.gpu_pipeline_options}" + repl4 = f'"amdgpu-waves-per-eu" = "{self.waves_per_eu}"' + + new_mlir = "" + for line in template: + if "intrinsic =" in line: + line = re.sub(expr0, repl0, line) + if "LLVMGPUVectorDistribute " in line: + line = re.sub(expr1, repl1, line) + if "tile_sizes" in line: + line = re.sub(expr2, repl2, line) + if "gpu_pipeline_options =" in line: + line = re.sub(expr3, repl3, line) + if "amdgpu-waves-per-eu" in line: + line = re.sub(expr4, repl4, line) + new_mlir += line + + return new_mlir + + +@dataclass +class LLVMCPUConfiguration(BaseConfiguration): + def get_mlir_config(self, tile_sizes: list[int]) -> str: + base_config = self.get_base_mlir_config(tile_sizes) + + backend_config = f""" + translation_info = #iree_codegen.translation_info + > -> !transform.any_param + """ + + return base_config + backend_config + + def apply_configuration(self, template: list[str], tile_sizes: list[int]) -> str: + tune_logger.info(f"Applying: {self}") + + expr0 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]") + + repl0 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]' + + new_mlir = "" + for line in template: + if "tile_sizes" in line: + line = re.sub(expr0, repl0, line) + new_mlir += line + + return new_mlir + class MlirRegex(Enum): ssa_value = r"%[a-zA-Z0-9-_]+" tensor_type = r"tensor<(([0-9]+x)+((f|i)[0-9]+))>" + device_target = r'#hal\.device\.target<"(?P[a-zA-Z0-9-_]+)"' def __str__(self) -> str: return self.value @@ -219,10 +369,6 @@ def read_input_mlir(filename: str) -> list[str]: return f.readlines() -def get_mmt_tile_sizes(configuration: Configuration): - return configuration.tile_sizes - - @dataclass class ConvDimInfo: n: int @@ -244,68 +390,6 @@ def from_problem_size(problem_size: ProblemSize): return ConvDimInfo.from_rhs_res(problem_size.rhs_type, problem_size.res_type) -def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> list[int]: - m, n, k = configuration.tile_sizes - tile_size = [1] * len(tile_dims) - for idx, dim in enumerate(tile_dims): - if dim == "m": - tile_size[idx] = m - if dim == "n": - tile_size[idx] = n - if dim == "k": - tile_size[idx] = k - return tile_size - - -def get_batch_mmt_tile_sizes(configuration: Configuration) -> list[int]: - return [1] + configuration.tile_sizes - - -def get_pipeline_config(configuration: Configuration) -> str: - extra_config = "" - if not configuration.gpu_pipeline_options.all_default(): - extra_config += f", gpu_pipeline_options = {configuration.gpu_pipeline_options}" - if configuration.waves_per_eu != 2: - extra_config += f', llvm_func_attrs = {{"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"}}' - return extra_config - - -def apply_configuration( - template: list[str], configuration: Configuration, tile_sizes: list[int] -) -> str: - tune_logger.info(f"Applying: {configuration}") - expr0 = re.compile( - r", subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>" - ) - expr1 = re.compile( - r"LLVMGPUVectorDistribute workgroup_size = \[.+\] subgroup_size = ([0-9]+)," - ) - expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]") - expr3 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>") - expr4 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") - repl0 = f", subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>" - repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},' - repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]' - repl3 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}" - repl4 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"' - - new_mlir = "" - for line in template: - if "intrinsic =" in line: - line = re.sub(expr0, repl0, line) - if "LLVMGPUVectorDistribute " in line: - line = re.sub(expr1, repl1, line) - if "tile_sizes" in line: - line = re.sub(expr2, repl2, line) - if "gpu_pipeline_options =" in line: - line = re.sub(expr3, repl3, line) - if "amdgpu-waves-per-eu" in line: - line = re.sub(expr4, repl4, line) - new_mlir += line - - return new_mlir - - def parse_tensor_type(tensor_type: str) -> ShapedType: shape_match = re.search(str(MlirRegex.tensor_type), tensor_type) assert shape_match @@ -377,141 +461,239 @@ def calculate_shared_memory_usage_in_bytes( return lhs_memory + rhs_memory -def generate_constraints( - problem_size: ProblemSize, - tile_sizes, - num_subgroups, - subgroup_size, - intrinsic_size, - workgroup_size, - subgroup_m_count, - subgroup_n_count, - waves_per_eu, -): - M, N, K = ( - problem_size.matmul_size.M, - problem_size.matmul_size.N, - problem_size.matmul_size.K, - ) - m, n, k = tile_sizes - intrinsic_mn, intrinsic_k = intrinsic_size - wg_x, wg_y, wg_z = workgroup_size - wg_threads = z3.Int("wg_threads") - constraints = [wg_threads == wg_x * wg_y * wg_z] - constraints += [subgroup_size == 64, wg_threads <= 1024] - constraints += [ - get_mfma_intrinsic_constraints( - problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k +class SolutionGenerationStrategy(ABC): + @abstractmethod + def generate_solutions( + self, problem_size: ProblemSize, num_subgroups: int + ) -> Generator[BaseConfiguration, None, None]: + pass + + +class LLVMGPUSolutionStrategy(SolutionGenerationStrategy): + def generate_solutions( + self, + problem_size: ProblemSize, + num_subgroups: int, + ): + M, N, K = problem_size.MNK + tune_logger.info(f"{M},{N},{K}") + m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k") + subgroup_size = z3.Int("subgroup_size") + intrinsic_mn = z3.Int("intrinsic_mn") + intrinsic_k = z3.Int("intrinsic_k") + wg_x, wg_y, wg_z = z3.Int("wg_x"), z3.Int("wg_y"), z3.Int("wg_z") + sg_m_cnt = z3.Int("sg_m_cnt") + sg_n_cnt = z3.Int("sg_n_cnt") + waves_per_eu = z3.Int("waves_per_eu") + all_vars = [ + m, + n, + k, + subgroup_size, + intrinsic_mn, + intrinsic_k, + wg_x, + wg_y, + wg_z, + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + ] + + solver = z3.Solver() + constraints = self.generate_constraints( + problem_size, + [m, n, k], + num_subgroups, + subgroup_size, + [intrinsic_mn, intrinsic_k], + [wg_x, wg_y, wg_z], + sg_m_cnt, + sg_n_cnt, + waves_per_eu, ) - ] - subgroup_k_count = 1 - constraints += [ - m >= intrinsic_mn, - m <= 512, - m <= M, - ] - constraints += [n >= intrinsic_mn, n <= 512, n <= N, N % n == 0] - constraints += [k >= intrinsic_k, k <= 512, k <= K, K % k == 0] - for x in (subgroup_m_count, subgroup_n_count): - constraints += [x >= 1, x <= 32] - - subgroup_m_tile_count = z3.Int("sg_m_tcnt") - subgroup_n_tile_count = z3.Int("sg_n_tcnt") - subgroup_k_tile_count = z3.Int("sg_k_tcnt") - for x in (subgroup_m_tile_count, subgroup_n_tile_count, subgroup_k_tile_count): - constraints += [x >= 1, x <= 32] - - constraints += [m == subgroup_m_count * subgroup_m_tile_count * intrinsic_mn] - constraints += [n == subgroup_n_count * subgroup_n_tile_count * intrinsic_mn] - constraints += [k == subgroup_k_count * subgroup_k_tile_count * intrinsic_k] - constraints += [wg_x == subgroup_size * subgroup_n_count] - constraints += [wg_y == subgroup_m_count] - constraints += [wg_z == subgroup_k_count] - constraints += [z3.Or(wg_x <= n, wg_x <= m)] - constraints += [k % intrinsic_mn == 0] - constraints += [(k * n) % wg_threads == 0] - constraints += [(k * m) % wg_threads == 0] - subgroups = subgroup_m_count * subgroup_n_count - if num_subgroups > 0: - constraints += [subgroups == num_subgroups] - else: - constraints += [subgroups >= 1, subgroups <= 10] - - constraints += [waves_per_eu == 2] - # constraints += [z3.Or(waves_per_eu == 2, waves_per_eu == 3, waves_per_eu == 4)] - - shared_memory = calculate_shared_memory_usage_in_bytes(problem_size, m, n, k) - constraints += [shared_memory <= 65536] - - constraints += get_dispatch_constraints(problem_size, m, n, k) - - return constraints - - -def generate_solutions(problem_size: ProblemSize, num_subgrups: int): - M, N, K = problem_size.MNK - tune_logger.info(f"{M},{N},{K}") - m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k") - subgroup_size = z3.Int("subgroup_size") - intrinsic_mn = z3.Int("intrinsic_mn") - intrinsic_k = z3.Int("intrinsic_k") - wg_x, wg_y, wg_z = z3.Int("wg_x"), z3.Int("wg_y"), z3.Int("wg_z") - sg_m_cnt = z3.Int("sg_m_cnt") - sg_n_cnt = z3.Int("sg_n_cnt") - waves_per_eu = z3.Int("waves_per_eu") - all_vars = [ - m, - n, - k, - subgroup_size, - intrinsic_mn, - intrinsic_k, - wg_x, - wg_y, - wg_z, - sg_m_cnt, - sg_n_cnt, - waves_per_eu, - ] + solver.add(z3.simplify(z3.And(constraints))) + tune_logger.debug(f"Initial constraints: {solver}") + i = 0 + while solver.check() == z3.sat: + model = solver.model() + lookup = lambda var: model[var].as_long() + + config = LLVMGPUConfiguration( + subgroup_size=lookup(subgroup_size), + workgroup_size=[lookup(wg_x), lookup(wg_y), lookup(wg_z)], + tile_sizes=[lookup(m), lookup(n), lookup(k)], + subgroup_m_count=lookup(sg_m_cnt), + subgroup_n_count=lookup(sg_n_cnt), + gpu_pipeline_options=GpuPipelineOptions(), + waves_per_eu=lookup(waves_per_eu), + intrinsic=MfmaIntrinsic( + problem_size.res_type.element_type, + lookup(intrinsic_mn), + lookup(intrinsic_mn), + lookup(intrinsic_k), + problem_size.lhs_type.element_type, + ), + ) + + solver.add( + z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars)))) + ) + i += 1 + yield config - solver = z3.Solver() - constraints = generate_constraints( - problem_size, - [m, n, k], - num_subgrups, + def generate_constraints( + self, + problem_size: ProblemSize, + tile_sizes, + num_subgroups, subgroup_size, - [intrinsic_mn, intrinsic_k], - [wg_x, wg_y, wg_z], - sg_m_cnt, - sg_n_cnt, + intrinsic_size, + workgroup_size, + subgroup_m_count, + subgroup_n_count, waves_per_eu, - ) - solver.add(z3.simplify(z3.And(constraints))) - tune_logger.debug(f"Initial constraints: {solver}") - i = 0 - while solver.check() == z3.sat: - model = solver.model() - lookup = lambda var: model[var].as_long() - - config = Configuration( - lookup(subgroup_size), - [lookup(wg_x), lookup(wg_y), lookup(wg_z)], - MfmaIntrinsic( - problem_size.res_type.element_type, - lookup(intrinsic_mn), - lookup(intrinsic_mn), - lookup(intrinsic_k), - problem_size.lhs_type.element_type, - ), - [lookup(m), lookup(n), lookup(k)], - lookup(sg_m_cnt), - lookup(sg_n_cnt), - GpuPipelineOptions(), - lookup(waves_per_eu), + ): + M, N, K = ( + problem_size.matmul_size.M, + problem_size.matmul_size.N, + problem_size.matmul_size.K, + ) + m, n, k = tile_sizes + intrinsic_mn, intrinsic_k = intrinsic_size + wg_x, wg_y, wg_z = workgroup_size + wg_threads = z3.Int("wg_threads") + constraints = [wg_threads == wg_x * wg_y * wg_z] + constraints += [subgroup_size == 64, wg_threads <= 1024] + constraints += [ + get_mfma_intrinsic_constraints( + problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k + ) + ] + subgroup_k_count = 1 + constraints += [ + m >= intrinsic_mn, + m <= 512, + m <= M, + ] + constraints += [n >= intrinsic_mn, n <= 512, n <= N, N % n == 0] + constraints += [k >= intrinsic_k, k <= 512, k <= K, K % k == 0] + for x in (subgroup_m_count, subgroup_n_count): + constraints += [x >= 1, x <= 32] + + subgroup_m_tile_count = z3.Int("sg_m_tcnt") + subgroup_n_tile_count = z3.Int("sg_n_tcnt") + subgroup_k_tile_count = z3.Int("sg_k_tcnt") + for x in (subgroup_m_tile_count, subgroup_n_tile_count, subgroup_k_tile_count): + constraints += [x >= 1, x <= 32] + + constraints += [m == subgroup_m_count * subgroup_m_tile_count * intrinsic_mn] + constraints += [n == subgroup_n_count * subgroup_n_tile_count * intrinsic_mn] + constraints += [k == subgroup_k_count * subgroup_k_tile_count * intrinsic_k] + constraints += [wg_x == subgroup_size * subgroup_n_count] + constraints += [wg_y == subgroup_m_count] + constraints += [wg_z == subgroup_k_count] + constraints += [z3.Or(wg_x <= n, wg_x <= m)] + constraints += [k % intrinsic_mn == 0] + constraints += [(k * n) % wg_threads == 0] + constraints += [(k * m) % wg_threads == 0] + subgroups = subgroup_m_count * subgroup_n_count + if num_subgroups > 0: + constraints += [subgroups == num_subgroups] + else: + constraints += [subgroups >= 1, subgroups <= 10] + + constraints += [waves_per_eu == 2] + # constraints += [z3.Or(waves_per_eu == 2, waves_per_eu == 3, waves_per_eu == 4)] + + shared_memory = calculate_shared_memory_usage_in_bytes(problem_size, m, n, k) + constraints += [shared_memory <= 65536] + + constraints += get_dispatch_constraints(problem_size, m, n, k) + + return constraints + + +class LLVMCPUSolutionStrategy(SolutionGenerationStrategy): + def generate_solutions( + self, + problem_size: ProblemSize, + num_subgroups: int, + ): + M, N, K = problem_size.MNK + tune_logger.info(f"{M},{N},{K}") + m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k") + all_vars = [ + m, + n, + k, + ] + + solver = z3.Solver() + constraints = self.generate_constraints( + problem_size, + [m, n, k], + ) + solver.add(z3.simplify(z3.And(constraints))) + tune_logger.debug(f"Initial constraints: {solver}") + i = 0 + while solver.check() == z3.sat: + model = solver.model() + lookup = lambda var: model[var].as_long() + + config = LLVMCPUConfiguration( + tile_sizes=[lookup(m), lookup(n), lookup(k)], + ) + + solver.add( + z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars)))) + ) + i += 1 + yield config + + def generate_constraints( + self, + problem_size: ProblemSize, + tile_sizes, + ) -> list: + M, N, K = ( + problem_size.matmul_size.M, + problem_size.matmul_size.N, + problem_size.matmul_size.K, ) - solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars))))) - i += 1 - yield config + m, n, k = tile_sizes + + constraints = [] + constraints += [m >= 1, m <= 512, m <= M] + constraints += [n >= 1, n <= 512, n <= N, N % n == 0] + constraints += [n >= 1, k <= 512, k <= K, K % k == 0] + + shared_memory = calculate_shared_memory_usage_in_bytes(problem_size, m, n, k) + constraints += [shared_memory <= 65536] + + constraints += get_dispatch_constraints(problem_size, m, n, k) + + return constraints + + +@dataclass +class SolutionStrategyFactory: + @staticmethod + def create_strategy(mlir_text: str) -> SolutionGenerationStrategy: + match = re.search(MlirRegex.device_target.value, mlir_text) + + if not match: + raise ValueError("No target found") + + target = match.group("target") + return SolutionStrategyFactory.get_strategy(target) + + @staticmethod + def get_strategy(target: str) -> SolutionGenerationStrategy: + if target == "local": + return LLVMCPUSolutionStrategy() + else: + return LLVMGPUSolutionStrategy() def get_default_output_dir() -> str: @@ -558,7 +740,7 @@ def apply_params( self, problem_size: ProblemSize, template: list[str], - configuration: Configuration, + configuration: BaseConfiguration, ) -> MLIRTransformation: """Apply parameter transformations to the operation.""" pass @@ -584,6 +766,12 @@ def validate_translation(self, attrs: list[ir.NamedAttribute]) -> bool: "LLVMGPUVectorDistribute" in str(attr.attr) ): return True + + if (attr.name == "translation_info") and ( + "Mmt4dTilingExpert" in str(attr.attr) + ): + return True + assert False, "Translation info not supported" def find_handler(self, op_name: str) -> DispatchTuner: @@ -601,6 +789,7 @@ def get_shapes(self, template: list[str]) -> ProblemSize: mmt_re = None dps = None for line in template: + if "linalg.generic" not in line: continue if r'iterator_types = ["parallel", "parallel", "reduction"]' not in line: @@ -644,16 +833,18 @@ def get_shapes(self, template: list[str]) -> ProblemSize: res_type=res_shaped_type, dispatch_kind=DispatchKind.mmt, ) + assert mmt_re assert dps, f"'{mmt_re}' not found in given context" def get_transform_function_mmt( - self, problem_size: ProblemSize, functionName: str, configuration: Configuration + self, + problem_size: ProblemSize, + functionName: str, + configuration: BaseConfiguration, ) -> str: - tile_sizes = ", ".join(map(str, get_mmt_tile_sizes(configuration))) - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) + tile_sizes = configuration.get_mmt_tile_sizes() + config_mlir = configuration.get_mlir_config(tile_sizes) return f""" transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ @@ -662,15 +853,7 @@ def get_transform_function_mmt( %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> - {extra_config}}}> - > -> !transform.any_param + {config_mlir} transform.yield %matmul, %config : !transform.any_op, !transform.any_param }} """ @@ -679,7 +862,7 @@ def apply_params( self, problem_size: ProblemSize, template: list[str], - configuration: Configuration, + configuration: BaseConfiguration, ) -> MLIRTransformation: M, N, K = problem_size.MNK modified = indent( @@ -688,8 +871,110 @@ def apply_params( ), "// ", ) - modified += apply_configuration( - template, configuration, get_mmt_tile_sizes(configuration) + modified += configuration.apply_configuration( + template, configuration.get_mmt_tile_sizes() + ) + embeddable = indent( + self.get_transform_function_mmt(problem_size, f"match_op", configuration), + " ", + ) + return MLIRTransformation(template, modified, embeddable) + + +class Mmt4dTuner(DispatchTuner): + def supports(self, op_name: str) -> bool: + return "mmt4d" in op_name + + def get_shapes(self, template: list[str]) -> ProblemSize: + mmt_re = None + dps = None + for line in template: + + if "linalg.mmt4d" not in line: + continue + # ins(%3, %4 : tensor<256x1280x8x1xf16>, tensor<1280x1280x8x1xf16>) outs(%6 : tensor<256x1280x8x8xf32>) + mmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + + dps = re.search(mmt_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 4 + lhs_M = lhs_shaped_type.shape[0] * lhs_shaped_type.shape[2] + lhs_K = lhs_shaped_type.shape[1] * lhs_shaped_type.shape[3] + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 4 + rhs_N = rhs_shaped_type.shape[0] * rhs_shaped_type.shape[2] + rhs_K = rhs_shaped_type.shape[1] * rhs_shaped_type.shape[3] + assert lhs_shaped_type.element_type == rhs_shaped_type.element_type + assert lhs_K == rhs_K + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 4 + res_M = res_shaped_type.shape[0] * res_shaped_type.shape[2] + res_N = res_shaped_type.shape[1] * res_shaped_type.shape[3] + + assert lhs_K == rhs_K + assert lhs_M == res_M + assert rhs_N == res_N + + matmul_size = MatmulSize( + lhs_M, + rhs_N, + lhs_K, + ) + return ProblemSize( + matmul_size, + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.mmt4d, + ) + + assert mmt_re + assert dps, f"'{mmt_re}' not found in given context" + + def get_transform_function_mmt( + self, + problem_size: ProblemSize, + functionName: str, + configuration: BaseConfiguration, + ) -> str: + tile_sizes = configuration.get_mmt_tile_sizes() + config_mlir = configuration.get_mlir_config(tile_sizes) + + return f""" + transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ + %mmt = transform.include @match_mmt4d_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value + {config_mlir} + transform.yield %matmul, %config : !transform.any_op, !transform.any_param + }} + """ + + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + configuration: BaseConfiguration, + ) -> MLIRTransformation: + M, N, K = problem_size.MNK + modified = indent( + self.get_transform_function_mmt( + problem_size, f"match_mmt4d_{M}x{N}x{K}", configuration + ), + "// ", + ) + modified += configuration.apply_configuration( + template, configuration.get_mmt_tile_sizes() ) embeddable = indent( self.get_transform_function_mmt(problem_size, f"match_op", configuration), @@ -702,7 +987,7 @@ class ConvTuner(DispatchTuner): def supports(self, op_name: str) -> bool: return "conv_2d_nhwc_hwcf" in op_name - def get_conv_tile_sizes(self, configuration: Configuration) -> list[int]: + def get_conv_tile_sizes(self, configuration: BaseConfiguration) -> list[int]: m, n, k = configuration.tile_sizes batch = 1 fh = 1 @@ -771,7 +1056,10 @@ def get_shapes(self, template: list[str]) -> ProblemSize: # int64_t fw = filterShape[1]; # int64_t ic = filterShape[2]; def get_transform_function_conv( - self, problem_size: ProblemSize, functionName: str, configuration: Configuration + self, + problem_size: ProblemSize, + functionName: str, + configuration: BaseConfiguration, ) -> str: dynamic_batch_input_ty = problem_size.lhs_type dynamic_batch_input_ty.shape = dynamic_batch_input_ty.shape.copy() @@ -785,10 +1073,8 @@ def get_transform_function_conv( filter = f"tensor<{problem_size.rhs_type}>" output = f"tensor<{dynamic_batch_output_ty}>" - tile_sizes = ", ".join(map(str, self.get_conv_tile_sizes(configuration))) - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) + tile_sizes = self.get_conv_tile_sizes(configuration) + config_mlir = configuration.get_mlir_config(tile_sizes) return f""" transform.named_sequence @{functionName}(%conv: !transform.any_op {{transform.readonly}}) @@ -799,15 +1085,7 @@ def get_transform_function_conv( ins(%lhs, %rhs : {input}, {filter}) outs(%out : {output}) -> {output} }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> - {extra_config}}}> - > -> !transform.any_param + {config_mlir} transform.yield %conv, %config : !transform.any_op, !transform.any_param }} """ @@ -816,7 +1094,7 @@ def apply_params( self, problem_size: ProblemSize, template: list[str], - configuration: Configuration, + configuration: BaseConfiguration, ) -> MLIRTransformation: conv_dims = ConvDimInfo.from_problem_size(problem_size) modified = indent( @@ -827,8 +1105,8 @@ def apply_params( ), "// ", ) - modified += apply_configuration( - template, configuration, self.get_conv_tile_sizes(configuration) + modified += configuration.apply_configuration( + template, self.get_conv_tile_sizes(configuration) ) embeddable = indent( self.get_transform_function_conv(problem_size, f"match_op", configuration), @@ -970,12 +1248,10 @@ def get_transform_function_broadcast_rhs_mmt( self, problem_size: ProblemSize, functionName: str, - configuration: Configuration, + configuration: BaseConfiguration, ) -> str: - tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration))) - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) + tile_sizes = configuration.get_batch_mmt_tile_sizes() + config_mlir = configuration.get_mlir_config(tile_sizes) lhs_dynamic_batch = problem_size.lhs_type lhs_dynamic_batch.shape = lhs_dynamic_batch.shape.copy() @@ -988,15 +1264,7 @@ def get_transform_function_broadcast_rhs_mmt( %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value -%config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> - {extra_config}}}> - > -> !transform.any_param +{config_mlir} transform.yield %generic, %config : !transform.any_op, !transform.any_param }} """ @@ -1005,7 +1273,7 @@ def apply_params_broadcast_rhs_mmt( self, problem_size: ProblemSize, template: list[str], - configuration: Configuration, + configuration: BaseConfiguration, ) -> MLIRTransformation: M, N, K = problem_size.MNK modified = indent( @@ -1014,8 +1282,8 @@ def apply_params_broadcast_rhs_mmt( ), "// ", ) - modified += apply_configuration( - template, configuration, get_batch_mmt_tile_sizes(configuration) + modified += configuration.apply_configuration( + template, configuration.get_batch_mmt_tile_sizes() ) embeddable = indent( @@ -1030,7 +1298,7 @@ def apply_params( self, problem_size: ProblemSize, template: list[str], - configuration: Configuration, + configuration: BaseConfiguration, ) -> MLIRTransformation: if self.is_broadcast_rhs_mmt(template): return self.apply_params_broadcast_rhs_mmt( @@ -1040,10 +1308,9 @@ def apply_params( # TODO: Generate transform function. return MLIRTransformation( template, - apply_configuration( + configuration.apply_configuration( template, - configuration, - get_contract_tile_sizes(configuration, self.tile_dims), + configuration.get_contract_tile_sizes(self.tile_dims), ), "", ) @@ -1104,12 +1371,10 @@ def get_transform_function_batch_mmt( self, problem_size: ProblemSize, functionName: str, - configuration: Configuration, + configuration: BaseConfiguration, ) -> str: - tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration))) - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) + tile_sizes = configuration.get_batch_mmt_tile_sizes() + config_mlir = configuration.get_mlir_config(tile_sizes) return f""" transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ @@ -1118,15 +1383,7 @@ def get_transform_function_batch_mmt( %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value -%config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> - {extra_config}}}> - > -> !transform.any_param +{config_mlir} transform.yield %generic, %config : !transform.any_op, !transform.any_param }} """ @@ -1135,7 +1392,7 @@ def apply_params( self, problem_size: ProblemSize, template: list[str], - configuration: Configuration, + configuration: BaseConfiguration, ) -> MLIRTransformation: M, N, K = problem_size.MNK B = problem_size.matmul_size.B @@ -1145,8 +1402,8 @@ def apply_params( ), "// ", ) - modified += apply_configuration( - template, configuration, get_batch_mmt_tile_sizes(configuration) + modified += configuration.apply_configuration( + template, configuration.get_batch_mmt_tile_sizes() ) embeddable = indent( @@ -1235,18 +1492,14 @@ def get_transform_function_batch_matmul( problem_size: ProblemSize, tile_dims: str, functionName: str, - configuration: Configuration, + configuration: BaseConfiguration, ) -> str: input0 = f"tensor<{problem_size.lhs_type}>" input1 = f"tensor<{problem_size.rhs_type}>" output = f"tensor<{problem_size.res_type}>" - tile_sizes = ", ".join( - map(str, get_contract_tile_sizes(configuration, tile_dims)) - ) - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) + tile_sizes = configuration.get_contract_tile_sizes(tile_dims) + config_mlir = configuration.get_mlir_config(tile_sizes) return f""" transform.named_sequence @{functionName}(%batch_matmul: !transform.any_op {{transform.readonly}}) @@ -1257,15 +1510,7 @@ def get_transform_function_batch_matmul( ins(%lhs, %rhs : {input0}, {input1}) outs(%out : {output}) -> {output} }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> - {extra_config}}}> - > -> !transform.any_param + {config_mlir} transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param }} """ @@ -1274,7 +1519,7 @@ def apply_params( self, problem_size: ProblemSize, template: list[str], - configuration: Configuration, + configuration: BaseConfiguration, ) -> MLIRTransformation: M, N, K = problem_size.MNK modified = indent( @@ -1286,10 +1531,9 @@ def apply_params( ), "// ", ) - modified += apply_configuration( + modified += configuration.apply_configuration( template, - configuration, - get_contract_tile_sizes(configuration, self.tile_dims), + configuration.get_contract_tile_sizes(self.tile_dims), ) embeddable = indent( @@ -1362,6 +1606,7 @@ def tune( dispatch_tuner_registry.register( [ MmtTuner(), + Mmt4dTuner(), ConvTuner(), ContractionTuner(lhs_dims, rhs_dims, tile_dims), BatchMmtTuner(), @@ -1374,8 +1619,12 @@ def tune( dispatch_tuner = walk_result.dispatch_tuner problem_size = dispatch_tuner.get_shapes(mlir_template) tune_logger.debug(str(problem_size)) + + solutions_strategy = SolutionStrategyFactory.create_strategy(mlir_text) configs = [] - for i, config in enumerate(generate_solutions(problem_size, num_subgroups)): + for i, config in enumerate( + solutions_strategy.generate_solutions(problem_size, num_subgroups) + ): if i >= limit: break tune_logger.info(f"Solution #{i+1}: {config}") diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py index 2924db75b..c259a2096 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -60,29 +60,15 @@ def test_parse_tensor_type(): def test_get_mmt_tile_sizes(): - config = candidate_gen.Configuration( - subgroup_size=0, - workgroup_size=[], - intrinsic="", + config = candidate_gen.BaseConfiguration( tile_sizes=[128, 320, 32], - subgroup_m_count=0, - subgroup_n_count=0, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), - waves_per_eu=0, ) - assert candidate_gen.get_mmt_tile_sizes(config) == [128, 320, 32] + assert config.get_mmt_tile_sizes() == [128, 320, 32] def test_get_conv_tile_sizes(): - config = candidate_gen.Configuration( - subgroup_size=64, - workgroup_size=[256, 1, 1], - intrinsic="#iree_gpu.mma_layout", + config = candidate_gen.BaseConfiguration( tile_sizes=[464, 320, 16], - subgroup_m_count=1, - subgroup_n_count=4, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), - waves_per_eu=1, ) assert candidate_gen.ConvTuner().get_conv_tile_sizes(config) == [ 1, @@ -122,28 +108,26 @@ def test_gpu_pipeline_options(): def test_get_contract_tile_sizes(): - config = candidate_gen.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], - intrinsic="", + config = candidate_gen.BaseConfiguration( tile_sizes=[4, 8, 16], - subgroup_m_count=1, - subgroup_n_count=1, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), - waves_per_eu=2, ) - assert candidate_gen.get_contract_tile_sizes(config, ["m", "n", "k"]) == [4, 8, 16] - assert candidate_gen.get_contract_tile_sizes(config, ["n", "m", "k"]) == [8, 4, 16] - assert candidate_gen.get_contract_tile_sizes(config, ["k", "n", "m"]) == [16, 8, 4] - assert candidate_gen.get_contract_tile_sizes(config, ["k", "k", "k"]) == [ - 16, - 16, - 16, - ] + assert config.get_contract_tile_sizes(["m", "n", "k"]) == [4, 8, 16] + assert config.get_contract_tile_sizes(["n", "m", "k"]) == [8, 4, 16] + assert config.get_contract_tile_sizes(["k", "n", "m"]) == [16, 8, 4] + assert config.get_contract_tile_sizes(["k", "k", "k"]) == [16, 16, 16] -def test_get_pipeline_config(): - config = candidate_gen.Configuration( +def test_get_pipeline_config_base(): + config = candidate_gen.BaseConfiguration( + tile_sizes=[4, 8, 16], + ) + + config1_str: str = config.get_pipeline_config() + assert config1_str == None + + +def test_get_pipeline_config_llvmgpu(): + config = candidate_gen.LLVMGPUConfiguration( subgroup_size=32, workgroup_size=[16, 16, 1], intrinsic="", @@ -153,21 +137,29 @@ def test_get_pipeline_config(): gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), waves_per_eu=2, ) - config1_str: str = candidate_gen.get_pipeline_config(config) + config1_str: str = config.get_pipeline_config() assert config1_str == "" config.waves_per_eu = 4 - config2_str: str = candidate_gen.get_pipeline_config(config) + config2_str: str = config.get_pipeline_config() assert config2_str == ', llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' config.gpu_pipeline_options.prefetch_shared_memory = True - config3_str = candidate_gen.get_pipeline_config(config) + config3_str = config.get_pipeline_config() assert ( config3_str == ', gpu_pipeline_options = #iree_gpu.pipeline_options, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' ) +def test_get_pipeline_config_llvmcpu(): + config = candidate_gen.LLVMCPUConfiguration( + tile_sizes=[4, 8, 16], + ) + config1_str: str = config.get_pipeline_config() + assert config1_str == None + + def test_get_shapes_mmt(): template = [ r"%18 = tensor.empty() : tensor<2048x1280xf32>", @@ -184,6 +176,27 @@ def test_get_shapes_mmt(): ) +def test_get_shapes_mmt4d(): + template = [ + r"%5 = tensor.empty() : tensor<256x1280x8x8xf32>", + r"%6 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f16) outs(%5 : tensor<256x1280x8x8xf32>) -> tensor<256x1280x8x8xf32>", + r"%7 = linalg.mmt4d {lowering_config = #iree_codegen.lowering_config} ins(%3, %4 : tensor<256x1280x8x1xf16>, tensor<1280x1280x8x1xf16>) outs(%6 : tensor<256x1280x8x8xf32>) -> tensor<256x1280x8x8xf32>", + ] + assert candidate_gen.Mmt4dTuner().get_shapes(template) == candidate_gen.ProblemSize( + matmul_size=candidate_gen.MatmulSize(2048, 10240, 1280), + lhs_type=candidate_gen.ShapedType( + [256, 1280, 8, 1], candidate_gen.ElementType.f16 + ), + rhs_type=candidate_gen.ShapedType( + [1280, 1280, 8, 1], candidate_gen.ElementType.f16 + ), + res_type=candidate_gen.ShapedType( + [256, 1280, 8, 8], candidate_gen.ElementType.f32 + ), + dispatch_kind=candidate_gen.DispatchKind.mmt4d, + ) + + def test_get_shapes_conv(): template = [ r"%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", @@ -303,7 +316,22 @@ def test_get_compatible_mfma_intrinsics(): ] -def test_generate_solutions(): +def test_generate_solutions_llvmgpu(): + matmul_size = candidate_gen.MatmulSize(2048, 3840, 1280) + lhs_type = candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16) + rhs_type = candidate_gen.ShapedType([3840, 1280], candidate_gen.ElementType.f16) + res_type = candidate_gen.ShapedType([2048, 3840], candidate_gen.ElementType.f32) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + ) + + gpu_strategy = candidate_gen.LLVMGPUSolutionStrategy() + configs = gpu_strategy.generate_solutions(problem_size, 4) + + assert configs is not None + + +def test_generate_solutions_llvmcpu(): matmul_size = candidate_gen.MatmulSize(2048, 3840, 1280) lhs_type = candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16) rhs_type = candidate_gen.ShapedType([3840, 1280], candidate_gen.ElementType.f16) @@ -311,7 +339,10 @@ def test_generate_solutions(): problem_size = candidate_gen.ProblemSize( matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt ) - configs = candidate_gen.generate_solutions(problem_size, 4) + + cpu_strategy = candidate_gen.LLVMCPUSolutionStrategy() + configs = cpu_strategy.generate_solutions(problem_size, 4) + assert configs is not None @@ -347,7 +378,7 @@ def test_calculate_shared_memory_usage_in_bytes(): ) -def test_generate_constraints_valid_input(): +def test_generate_constraints_valid_input_llvmgpu(): matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) @@ -373,7 +404,8 @@ def test_generate_constraints_valid_input(): sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt") waves_per_eu = candidate_gen.z3.Int("waves_per_eu") - constraints = candidate_gen.generate_constraints( + solution_strategy = candidate_gen.LLVMGPUSolutionStrategy() + constraints = solution_strategy.generate_constraints( problem_size, [m, n, k], 4, @@ -392,7 +424,33 @@ def test_generate_constraints_valid_input(): assert solver.check() == candidate_gen.z3.sat -def test_generate_constraints_invalid_input(): +def test_generate_constraints_valid_input_llvmcpu(): + matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) + lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + ) + m, n, k = ( + candidate_gen.z3.Int("m"), + candidate_gen.z3.Int("n"), + candidate_gen.z3.Int("k"), + ) + + solution_strategy = candidate_gen.LLVMCPUSolutionStrategy() + constraints = solution_strategy.generate_constraints( + problem_size, + [m, n, k], + ) + + solver = candidate_gen.z3.Solver() + solver.add(constraints) + + assert solver.check() == candidate_gen.z3.sat + + +def test_generate_constraints_invalid_input_llvmgpu(): # Define input parameters that should lead to unsatisfiable constraints matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) @@ -418,7 +476,8 @@ def test_generate_constraints_invalid_input(): sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt") waves_per_eu = candidate_gen.z3.Int("waves_per_eu") - constraints = candidate_gen.generate_constraints( + solution_strategy = candidate_gen.LLVMGPUSolutionStrategy() + constraints = solution_strategy.generate_constraints( problem_size, [m, n, k], 4, @@ -438,13 +497,42 @@ def test_generate_constraints_invalid_input(): assert solver.check() == candidate_gen.z3.unsat +def test_generate_constraints_invalid_input_llvmcpu(): + matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) + lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) + res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) + problem_size = candidate_gen.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + ) + m, n, k = ( + candidate_gen.z3.Int("m"), + candidate_gen.z3.Int("n"), + candidate_gen.z3.Int("k"), + ) + + solution_strategy = candidate_gen.LLVMCPUSolutionStrategy() + constraints = solution_strategy.generate_constraints( + problem_size, + [m, n, k], + ) + + constraints.append(m > 1000) # Tile size should never be negative + + solver = candidate_gen.z3.Solver() + solver.add(constraints) + + # Check if the constraints are unsatisfiable + assert solver.check() == candidate_gen.z3.unsat + + def remove_comments(mlir: str) -> str: return "\n".join( filter(lambda x: not x.lstrip().startswith("//"), mlir.splitlines()) ) -def test_apply_params_mmt(): +def test_apply_params_mmt_llvmgpu(): mlir_template = [ ", subgroup_m_count = 16, subgroup_n_count = 16>", "", + ] + + M, N, K = 2048, 1280, 1280 + + config = candidate_gen.LLVMCPUConfiguration( + tile_sizes=[8, 8, 8], + ) + + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(M, N, K), + candidate_gen.ShapedType([M, K], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([N, K], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([M, N], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.mmt, + ) + tf_mlir = candidate_gen.MmtTuner().apply_params(problem_size, mlir_template, config) + + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + + assert modified + modified = remove_comments(modified) + assert embeddable + assert "Mmt4dTilingExpert" in modified + assert "tile_sizes = [[8, 8, 8]]" in modified + + +def test_apply_params_mmt4d_llvmgpu(): + mlir_template = [ + ", subgroup_m_count = 16, subgroup_n_count = 16>", + "", + "gpu_pipeline_options = #iree_gpu.pipeline_options", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', + ] + + M, N, K = 2048, 1280, 1280 + + config = candidate_gen.LLVMGPUConfiguration( + subgroup_size=16, + workgroup_size=[16, 16, 1], + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + tile_sizes=[8, 8, 8], + subgroup_m_count=16, + subgroup_n_count=16, + gpu_pipeline_options=candidate_gen.GpuPipelineOptions( + prefetch_shared_memory=True + ), + waves_per_eu=8, + ) + + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(M, N, K), + candidate_gen.ShapedType([M, K], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([N, K], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([M, N], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.mmt, + ) + tf_mlir = candidate_gen.Mmt4dTuner().apply_params( + problem_size, mlir_template, config + ) + + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + + assert modified + modified = remove_comments(modified) + assert embeddable + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 16, subgroup_n_count = 16" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [16, 16, 1] subgroup_size = 16" + in modified + ) + assert "tile_sizes = [[8, 8, 8]]" in modified + assert ( + "gpu_pipeline_options = #iree_gpu.pipeline_options" + in modified + ) + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "8"}' in modified + + +def test_apply_params_mmt4d_llvmcpu(): + mlir_template = [ + "", + ] + + M, N, K = 2048, 1280, 1280 + + config = candidate_gen.LLVMCPUConfiguration( + tile_sizes=[8, 8, 8], + ) + + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(M, N, K), + candidate_gen.ShapedType([M, K], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([N, K], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([M, N], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.mmt, + ) + tf_mlir = candidate_gen.Mmt4dTuner().apply_params( + problem_size, mlir_template, config + ) + + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + + assert modified + modified = remove_comments(modified) + assert embeddable + assert "Mmt4dTilingExpert" in modified + assert "tile_sizes = [[8, 8, 8]]" in modified + + +def test_apply_params_conv_llvmgpu(): mlir_template = [ ", subgroup_m_count = 16, subgroup_n_count = 16>", "", + ] + + n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640 + + config = candidate_gen.LLVMCPUConfiguration( + tile_sizes=[464, 320, 16], + ) + + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(oh * ow, oc, fh * fw * ic), + candidate_gen.ShapedType( + [n, oh + 2, ow + 2, oc], candidate_gen.ElementType.f16 + ), + candidate_gen.ShapedType([fh, fw, ic, oc], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([n, oh, ow, oc], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.conv, + ) + tf_mlir = candidate_gen.ConvTuner().apply_params( + problem_size, mlir_template, config + ) + + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + + assert modified + modified = remove_comments(modified) + + assert embeddable + assert "Mmt4dTilingExpert" in modified + assert "tile_sizes = [[1, 1, 464, 320, 1, 1, 16]]" in modified + + +def test_apply_params_contract_llvmgpu(): mlir_template = [ ", subgroup_m_count = 2, subgroup_n_count = 2>}>", "", + ] + + tile_dims = "*mnk" + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(2048, 3840, 1280), + candidate_gen.ShapedType([2, 1024, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([3, 20, 64, 1280], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([3, 2, 20, 1024, 64], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.contraction, + ) + + config = candidate_gen.LLVMCPUConfiguration( + tile_sizes=[480, 384, 32], + ) + + tf_mlir = candidate_gen.ContractionTuner("mk", "nk", tile_dims).apply_params( + problem_size, mlir_template, config + ) + + new_mlir = tf_mlir.modified + + assert new_mlir + assert "Mmt4dTilingExpert" in new_mlir + assert "tile_sizes = [[1, 480, 384, 32]]" in new_mlir + + +def test_apply_params_batch_matmul_llvmgpu(): mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", "", + ] + + tile_dims = "bmnk" + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(968, 320, 640, 64), + candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([64, 968, 320], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.batch_matmul, + ) + + config = candidate_gen.LLVMCPUConfiguration( + tile_sizes=[416, 320, 128], + ) + + tf_mlir = candidate_gen.BatchMatmulTuner("mk", "nk", tile_dims).apply_params( + problem_size, mlir_template, config + ) + + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + + assert modified + modified = remove_comments(modified) + + assert embeddable + assert "Mmt4dTilingExpert" in modified + assert "tile_sizes = [[1, 416, 320, 128]]" in modified + + +def test_apply_params_batch_mmt_float_llvmgpu(): mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", "", + ] + + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(4096, 640, 640, 2), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.f16), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f32), + candidate_gen.DispatchKind.batch_mmt, + ) + + config = candidate_gen.LLVMCPUConfiguration( + tile_sizes=[128, 64, 128], + ) + + tf_mlir = candidate_gen.BatchMmtTuner().apply_params( + problem_size, mlir_template, config + ) + + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + + assert embeddable + assert modified + assert "Mmt4dTilingExpert" in modified + assert "tile_sizes = [[1, 128, 64, 128]]" in modified + + +def test_apply_params_batch_mmt_int_llvmgpu(): mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", "", + ] + + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(4096, 640, 640, 2), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), + candidate_gen.DispatchKind.batch_mmt, + ) + + config = candidate_gen.LLVMCPUConfiguration( + tile_sizes=[128, 64, 128], + ) + + tf_mlir = candidate_gen.BatchMmtTuner().apply_params( + problem_size, mlir_template, config + ) + + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + + assert modified + assert "// transform.named_sequence @match_batch_mmt_2x4096x640x640(" in modified + modified = remove_comments(modified) + + assert "Mmt4dTilingExpert" in modified + assert "tile_sizes = [[1, 128, 64, 128]]" in modified + + assert embeddable + assert "transform.named_sequence @match_op(" in embeddable + assert ( + "transform.include @match_batch_mmt_i8_i8_i32 failures(propagate)" in embeddable + ) + assert ( + "transform.iree.match.cast_compatible_type %lhs = tensor<2x4096x640xi8> : !transform.any_value" + in embeddable + ) + assert ( + "transform.iree.match.cast_compatible_type %rhs = tensor<2x640x640xi8> : !transform.any_value" + in embeddable + ) + assert ( + "%config = transform.param.constant #iree_codegen.compilation_info<" + in embeddable + ) + assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable + + +def test_apply_params_broadcast_rhs_mmt_llvmgpu(): mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", "", + ] + + problem_size = candidate_gen.ProblemSize( + candidate_gen.MatmulSize(4096, 640, 640, 2), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([640, 640], candidate_gen.ElementType.i8), + candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), + candidate_gen.DispatchKind.broadcast_rhs_mmt, + ) + + config = candidate_gen.LLVMCPUConfiguration( + tile_sizes=[128, 64, 128], + ) + + tf_mlir = candidate_gen.ContractionTuner( + "mk", "nk", "mnk" + ).apply_params_broadcast_rhs_mmt(problem_size, mlir_template, config) + + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + + assert modified + assert ( + "// transform.named_sequence @match_broadcast_rhs_mmt_Bx4096x640x640(" + in modified + ) + modified = remove_comments(modified) + + assert "Mmt4dTilingExpert" in modified + assert "tile_sizes = [[1, 128, 64, 128]]" in modified + + assert embeddable + assert "transform.named_sequence @match_op(" in embeddable + assert ( + "transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate)" + in embeddable + ) + assert ( + "transform.iree.match.cast_compatible_type %lhs = tensor : !transform.any_value" + in embeddable + ) + assert ( + "transform.iree.match.cast_compatible_type %rhs = tensor<640x640xi8> : !transform.any_value" + in embeddable + ) + assert ( + "%config = transform.param.constant #iree_codegen.compilation_info<" + in embeddable + ) + assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable + + def test_detect_broadcast_rhs_mmt(): mlir_lines = [ r"%18 = tensor.empty() : tensor<2x1024x10240xi32>", diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py index 91c7b417a..022cfad2a 100644 --- a/tuner/tuner/libtuner.py +++ b/tuner/tuner/libtuner.py @@ -64,7 +64,7 @@ class CandidateTracker: candidate_id: int dispatch_mlir_path: Optional[Path] = None dispatch_config_path: Optional[Path] = None - configuration: Optional[candidate_gen.Configuration] = None + configuration: Optional[candidate_gen.BaseConfiguration] = None compilation_successful: Optional[bool] = None compiled_dispatch_path: Optional[Path] = None compiled_dispatch_hash: Optional[str] = None