From d080ce1a1530f062f85cc1c7f5e5d1fb124b7e0e Mon Sep 17 00:00:00 2001
From: Jakub Kuderski
Date: Thu, 14 Nov 2024 10:23:27 -0500
Subject: [PATCH] [tuner] Clean up candidate generation code (#508)

* Pass context explicitly to the parse function
* Fix mypy typing violations
* Add pyproject.toml
---
 tuner/pyproject.toml              |  24 +++++++
 tuner/setup.py                    |  35 ++++++++++
 tuner/tuner/candidate_gen.py      | 108 +++++++++++++++---------------
 tuner/tuner/candidate_gen_test.py | 102 ++++++++++++++--------------
 tuner/tuner/py.typed              |   0
 tuner/version.json                |   3 +
 6 files changed, 167 insertions(+), 105 deletions(-)
 create mode 100644 tuner/pyproject.toml
 create mode 100644 tuner/setup.py
 create mode 100644 tuner/tuner/py.typed
 create mode 100644 tuner/version.json

diff --git a/tuner/pyproject.toml b/tuner/pyproject.toml
new file mode 100644
index 000000000..1661a7744
--- /dev/null
+++ b/tuner/pyproject.toml
@@ -0,0 +1,24 @@
+[project]
+name = "SHARK Tuner"
+authors = [
+  {name = "SHARK Authors"},
+]
+description = "IREE Dispatch Tuner"
+readme = "README.md"
+license = {text = "Apache-2.0"}
+classifiers = [
+  "Development Status :: 3 - Alpha",
+  "License :: OSI Approved :: Apache Software License",
+  "Programming Language :: Python :: 3",
+  "Programming Language :: Python :: 3.10",
+  "Programming Language :: Python :: 3.11",
+  "Programming Language :: Python :: 3.12",
+  "Programming Language :: Python :: 3.13",
+]
+requires-python = ">= 3.10"
+
+# Version is set via the `setup.py`.
+dynamic = ["version"]
+
+[project.urls]
+Repository = "https://github.com/nod-ai/SHARK-Platform"
diff --git a/tuner/setup.py b/tuner/setup.py
new file mode 100644
index 000000000..aa450eaee
--- /dev/null
+++ b/tuner/setup.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import json
+import os
+
+from setuptools import setup
+
+SETUPPY_DIR = os.path.realpath(os.path.dirname(__file__))
+
+# Setup and get version information.
+VERSION_FILE = os.path.join(SETUPPY_DIR, "version.json")
+VERSION_FILE_LOCAL = os.path.join(SETUPPY_DIR, "version_local.json")
+
+
+def load_version_info(version_file):
+    with open(version_file, "rt") as f:
+        return json.load(f)
+
+
+try:
+    version_info = load_version_info(VERSION_FILE_LOCAL)
+except FileNotFoundError:
+    print("version_local.json not found. Default to dev build")
+    version_info = load_version_info(VERSION_FILE)
+
+PACKAGE_VERSION = version_info.get("package-version")
+print(f"Using PACKAGE_VERSION: '{PACKAGE_VERSION}'")
+
+setup(
+    version=f"{PACKAGE_VERSION}",
+)
diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py
index 40eb27a82..96bfc7146 100644
--- a/tuner/tuner/candidate_gen.py
+++ b/tuner/tuner/candidate_gen.py
@@ -23,18 +23,15 @@
 import math
 import pickle
 import re
-import z3
+import z3  # type: ignore
 from dataclasses import astuple, dataclass
 from enum import Enum
-from os import mkdir, path, makedirs
+from os import path, makedirs
 from typing import Optional
 from textwrap import indent
 from abc import ABC, abstractmethod
 
-import iree.compiler as ireec
-from iree.compiler import ir
-from iree.compiler.dialects import _linalg_ops_gen, _util_ops_gen
-
+from iree.compiler import ir  # type: ignore
 
 
 tune_logger = logging.getLogger("tune")
@@ -520,15 +517,14 @@ def get_default_output_dir() -> str:
     return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M")
 
 
-def parse_mlir(mlir_text: str) -> ir.Module:
+def parse_mlir(mlir_text: str, ctx: ir.Context) -> ir.Module:
     mlir_module = None
-    with ireec.ir.Context() as context:
-        try:
-            mlir_module = ireec.ir.Module.parse(mlir_text)
-            tune_logger.info("MLIR parsing successful!")
-        except ireec.ir.MLIRError as e:
-            tune_logger.error(f"Error parsing MLIR: {e}")
-            raise RuntimeError(f"Error parsing MLIR: {e}")
+    try:
+        mlir_module = ir.Module.parse(mlir_text)
+        tune_logger.info("MLIR parsing successful!")
+    except ir.MLIRError as e:
+        tune_logger.error(f"Error parsing MLIR: {e}")
+        raise RuntimeError(f"Error parsing MLIR: {e}")
 
     return mlir_module
 
@@ -537,7 +533,7 @@
 class MLIRTransformation:
     """Transformation of MLIR context"""
 
-    template: str
+    template: list[str]
     modified: str
     embeddable: str
 
@@ -550,7 +546,7 @@ def supports(self, op_name: str) -> bool:
 
     @abstractmethod
     def get_shapes(self, template: list[str]) -> ProblemSize:
-        """Extract problem size of thge operation."""
+        """Extract problem size of the operation."""
         pass
 
     @abstractmethod
@@ -645,7 +641,7 @@ def get_shapes(self, template: list[str]) -> ProblemSize:
             dispatch_kind=DispatchKind.mmt,
         )
         assert mmt_re
-        assert dps, f"'{mmt_re}' not found in given context"
+        assert False, f"'{mmt_re}' not found in given context"
 
     def get_transform_function_mmt(
         self, problem_size: ProblemSize, functionName: str, configuration: Configuration
@@ -1353,45 +1349,47 @@ def tune(
     mlir_template = read_input_mlir(input_file)
     mlir_text = "".join(mlir_template)
 
-    mlir_module = parse_mlir(mlir_text)
-    # Save the input file as the first candidate.
-    with open(path.join(output, f"0.mlir"), "w") as f:
-        f.write(mlir_text)
-
-    dispatch_tuner_registry = DispatchTunerRegistry()
-    dispatch_tuner_registry.register(
-        [
-            MmtTuner(),
-            ConvTuner(),
-            ContractionTuner(lhs_dims, rhs_dims, tile_dims),
-            BatchMmtTuner(),
-            BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims),
-        ]
-    )
-
-    walk_result = walk_mlir_op(mlir_module, dispatch_tuner_registry)
-
-    dispatch_tuner = walk_result.dispatch_tuner
-    problem_size = dispatch_tuner.get_shapes(mlir_template)
-    tune_logger.debug(str(problem_size))
-    configs = []
-    for i, config in enumerate(generate_solutions(problem_size, num_subgroups)):
-        if i >= limit:
-            break
-        tune_logger.info(f"Solution #{i+1}: {config}")
-        configs.append(config)
-        tf_mlir = dispatch_tuner.apply_params(problem_size, mlir_template, config)
-
-        with open(path.join(output, f"{i+1}.mlir"), "w") as f:
-            f.write(tf_mlir.modified)
-        with open(path.join(output, f"{i+1}_config.mlir"), "w") as f:
-            f.write(tf_mlir.embeddable)
-
-    with open(path.join(output, "configs.pkl"), "wb") as file:
-        pickle.dump(configs, file)
+    with ir.Context() as ctx:
+        mlir_module: ir.Module = parse_mlir(mlir_text, ctx)
+        # Save the input file as the first candidate.
+        with open(path.join(output, f"0.mlir"), "w") as f:
+            f.write(mlir_text)
+
+        dispatch_tuner_registry = DispatchTunerRegistry()
+        dispatch_tuner_registry.register(
+            [
+                MmtTuner(),
+                ConvTuner(),
+                ContractionTuner(lhs_dims, rhs_dims, tile_dims),
+                BatchMmtTuner(),
+                BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims),
+            ]
+        )
 
-    tune_logger.info(f"Generated {len(configs)} candidates")
-    tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl")
+        walk_result: OpWalkResult = walk_mlir_op(mlir_module, dispatch_tuner_registry)
+
+        dispatch_tuner = walk_result.dispatch_tuner
+        assert dispatch_tuner, "No suitable dispatch tuner found"
+        problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template)
+        tune_logger.debug(str(problem_size))
+        configs = []
+        for i, config in enumerate(generate_solutions(problem_size, num_subgroups)):
+            if i >= limit:
+                break
+            tune_logger.info(f"Solution #{i+1}: {config}")
+            configs.append(config)
+            tf_mlir = dispatch_tuner.apply_params(problem_size, mlir_template, config)
+
+            with open(path.join(output, f"{i+1}.mlir"), "w") as f:
+                f.write(tf_mlir.modified)
+            with open(path.join(output, f"{i+1}_config.mlir"), "w") as f:
+                f.write(tf_mlir.embeddable)
+
+        with open(path.join(output, "configs.pkl"), "wb") as file:
+            pickle.dump(configs, file)
+
+        tune_logger.info(f"Generated {len(configs)} candidates")
+        tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl")
 
 
 def main():
diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py
index 2924db75b..a1a3a3e49 100644
--- a/tuner/tuner/candidate_gen_test.py
+++ b/tuner/tuner/candidate_gen_test.py
@@ -11,8 +11,11 @@
 import pytest
 
 from . import candidate_gen
+from iree.compiler import ir  # type: ignore
+from iree.compiler.dialects import func  # type: ignore
 
-def test_get_shaped_type_element_bitwidth():
+
+def test_get_shaped_type_element_bitwidth() -> None:
     assert (
         candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8).bitwidth
         == 8
@@ -31,7 +34,7 @@ def test_get_shaped_type_element_bitwidth():
     )
 
 
-def test_get_shaped_type_to_str():
+def test_get_shaped_type_to_str() -> None:
     assert (
         str(candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8))
         == "1024x2048xi8"
     )
@@ -50,7 +53,7 @@ def test_get_shaped_type_to_str():
     )
 
 
-def test_parse_tensor_type():
+def test_parse_tensor_type() -> None:
     assert candidate_gen.parse_tensor_type(
         "tensor<1x2x3xf32>"
     ) == candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f32)
@@ -59,11 +62,11 @@ def test_parse_tensor_type():
     ) == candidate_gen.ShapedType([123], candidate_gen.ElementType.i8)
 
 
-def test_get_mmt_tile_sizes():
+def test_get_mmt_tile_sizes() -> None:
     config = candidate_gen.Configuration(
         subgroup_size=0,
         workgroup_size=[],
-        intrinsic="",
+        intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
         tile_sizes=[128, 320, 32],
         subgroup_m_count=0,
         subgroup_n_count=0,
@@ -73,11 +76,11 @@ def test_get_mmt_tile_sizes():
     assert candidate_gen.get_mmt_tile_sizes(config) == [128, 320, 32]
 
 
-def test_get_conv_tile_sizes():
+def test_get_conv_tile_sizes() -> None:
     config = candidate_gen.Configuration(
         subgroup_size=64,
         workgroup_size=[256, 1, 1],
-        intrinsic="#iree_gpu.mma_layout",
+        intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
         tile_sizes=[464, 320, 16],
         subgroup_m_count=1,
         subgroup_n_count=4,
@@ -95,7 +98,7 @@ def test_get_conv_tile_sizes():
     ]
 
 
-def test_gpu_pipeline_options():
+def test_gpu_pipeline_options() -> None:
     options = candidate_gen.GpuPipelineOptions()
     assert options.all_default()
     assert str(options) == "#iree_gpu.pipeline_options<>"
@@ -121,32 +124,32 @@ def test_gpu_pipeline_options():
     )
 
 
-def test_get_contract_tile_sizes():
+def test_get_contract_tile_sizes() -> None:
     config = candidate_gen.Configuration(
         subgroup_size=32,
         workgroup_size=[16, 16, 1],
-        intrinsic="",
+        intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
         tile_sizes=[4, 8, 16],
         subgroup_m_count=1,
         subgroup_n_count=1,
         gpu_pipeline_options=candidate_gen.GpuPipelineOptions(),
         waves_per_eu=2,
     )
-    assert candidate_gen.get_contract_tile_sizes(config, ["m", "n", "k"]) == [4, 8, 16]
-    assert candidate_gen.get_contract_tile_sizes(config, ["n", "m", "k"]) == [8, 4, 16]
-    assert candidate_gen.get_contract_tile_sizes(config, ["k", "n", "m"]) == [16, 8, 4]
-    assert candidate_gen.get_contract_tile_sizes(config, ["k", "k", "k"]) == [
+    assert candidate_gen.get_contract_tile_sizes(config, "mnk") == [4, 8, 16]
+    assert candidate_gen.get_contract_tile_sizes(config, "nmk") == [8, 4, 16]
+    assert candidate_gen.get_contract_tile_sizes(config, "knm") == [16, 8, 4]
+    assert candidate_gen.get_contract_tile_sizes(config, "kkk") == [
         16,
         16,
         16,
     ]
 
 
-def test_get_pipeline_config():
+def test_get_pipeline_config() -> None:
     config = candidate_gen.Configuration(
         subgroup_size=32,
         workgroup_size=[16, 16, 1],
-        intrinsic="",
+        intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
         tile_sizes=[4, 8, 16],
         subgroup_m_count=1,
         subgroup_n_count=1,
@@ -168,7 +171,7 @@ def test_get_pipeline_config():
     )
 
 
-def test_get_shapes_mmt():
+def test_get_shapes_mmt() -> None:
     template = [
         r"%18 = tensor.empty() : tensor<2048x1280xf32>",
         r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>",
@@ -184,7 +187,7 @@ def test_get_shapes_mmt():
     )
 
 
-def test_get_shapes_conv():
+def test_get_shapes_conv() -> None:
     template = [
         r"%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>",
         r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>",
@@ -199,7 +202,7 @@ def test_get_shapes_conv():
     )
 
 
-def test_get_shapes_contract():
+def test_get_shapes_contract() -> None:
     template = [
         r"%18 = tensor.empty() : tensor<2048x1280xf32>",
         r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>",
@@ -217,7 +220,7 @@ def test_get_shapes_contract():
     )
 
 
-def test_get_shapes_batch_matmul():
+def test_get_shapes_batch_matmul() -> None:
     template = [
         "%10 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>",
         "%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>",
@@ -234,7 +237,7 @@ def test_get_shapes_batch_matmul():
     )
 
 
-def test_get_shapes_batch_mmt():
+def test_get_shapes_batch_mmt() -> None:
     template = [
         r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>",
         r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {',
@@ -251,7 +254,7 @@ def test_get_shapes_batch_mmt():
     )
 
 
-def test_mfma_intrinsic_to_str():
+def test_mfma_intrinsic_to_str() -> None:
     assert (
         str(candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16())
         == "MFMA_F32_16x16x16_F16"
     )
@@ -262,7 +265,7 @@ def test_mfma_intrinsic_to_str():
     )
 
 
-def test_get_compatible_mfma_intrinsics():
+def test_get_compatible_mfma_intrinsics() -> None:
     assert candidate_gen.get_compatible_mfma_intrinsics(
         candidate_gen.ProblemSize(
             candidate_gen.MatmulSize(2048, 1280, 1280),
@@ -303,7 +306,7 @@ def test_get_compatible_mfma_intrinsics():
     ]
 
 
-def test_generate_solutions():
+def test_generate_solutions() -> None:
     matmul_size = candidate_gen.MatmulSize(2048, 3840, 1280)
     lhs_type = candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16)
     rhs_type = candidate_gen.ShapedType([3840, 1280], candidate_gen.ElementType.f16)
@@ -315,7 +318,7 @@ def test_generate_solutions():
     assert configs is not None
 
 
-def test_calculate_shared_memory_usage_in_bytes():
+def test_calculate_shared_memory_usage_in_bytes() -> None:
     matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024)
     lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16)
     rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16)
@@ -347,7 +350,7 @@ def test_calculate_shared_memory_usage_in_bytes():
     )
 
 
-def test_generate_constraints_valid_input():
+def test_generate_constraints_valid_input() -> None:
     matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024)
     lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16)
     rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16)
@@ -392,7 +395,7 @@
     assert solver.check() == candidate_gen.z3.sat
 
 
-def test_generate_constraints_invalid_input():
+def test_generate_constraints_invalid_input() -> None:
     # Define input parameters that should lead to unsatisfiable constraints
     matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024)
     lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16)
@@ -444,7 +447,7 @@ def remove_comments(mlir: str) -> str:
     )
 
 
-def test_apply_params_mmt():
+def test_apply_params_mmt() -> None:
     mlir_template = [
         ", subgroup_m_count = 16, subgroup_n_count = 16>",
         " None: mlir_template = [ ", subgroup_m_count = 16, subgroup_n_count = 16>", " None: mlir_template = [ ", subgroup_m_count = 2, subgroup_n_count = 2>}>", " None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", " None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", " None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", " None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", " None: mlir_lines = [
         r"%18 = tensor.empty() : tensor<2x1024x10240xi32>",
         r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>",
@@ -861,18 +864,17 @@ def test_detect_broadcast_rhs_mmt():
     )
 
 
-def test_parse_mlir():
-    mlir_str = r"""
-    builtin.module {
-        func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
-            %0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
-            return %0 : tensor<4xf32>
-        }
-    }
-    """
-    mlir_module = candidate_gen.parse_mlir(mlir_str)
-    assert mlir_module != None
-    assert isinstance(mlir_module, candidate_gen.ireec._mlir_libs._mlir.ir.Module)
-    assert isinstance(
-        mlir_module.body.operations[0], candidate_gen.ireec.dialects.func.FuncOp
-    )
+def test_parse_mlir() -> None:
+    with ir.Context() as ctx:
+        mlir_str = r"""
+        builtin.module {
+            func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+                %0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
+                return %0 : tensor<4xf32>
+            }
+        }
+        """
+        mlir_module = candidate_gen.parse_mlir(mlir_str, ctx)
+        assert mlir_module is not None
+        assert isinstance(mlir_module, ir.Module)
+        assert isinstance(mlir_module.body.operations[0], func.FuncOp)
diff --git a/tuner/tuner/py.typed b/tuner/tuner/py.typed
new file mode 100644
index 000000000..e69de29bb
diff --git a/tuner/version.json b/tuner/version.json
new file mode 100644
index 000000000..794a2de28
--- /dev/null
+++ b/tuner/version.json
@@ -0,0 +1,3 @@
+{
+  "package-version": "2.9.1.dev"
+}
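
A minimal usage sketch of the parse_mlir change above: the caller now creates the ir.Context and passes it in explicitly, mirroring the tune() and test_parse_mlir() changes in this patch. The import path tuner.candidate_gen and the placeholder module string are assumptions for illustration, not taken from the patch.

    from iree.compiler import ir  # type: ignore

    from tuner import candidate_gen

    # Placeholder input; the tuner normally reads this text from a dispatch .mlir file.
    mlir_text = "builtin.module {}"

    # The MLIR context is owned by the caller and passed to parse_mlir explicitly.
    with ir.Context() as ctx:
        mlir_module = candidate_gen.parse_mlir(mlir_text, ctx)
        assert mlir_module is not None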