diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py
index 5786a9fff..38696e6db 100644
--- a/tuner/tuner/candidate_gen.py
+++ b/tuner/tuner/candidate_gen.py
@@ -30,6 +30,8 @@
 
 from iree.compiler import ir  # type: ignore
 
+from iree.compiler.dialects import iree_codegen  # type: ignore
+
 from .common import *
 from .dispatch_constraints import *
 from .dispatch_parser import *
@@ -535,13 +537,19 @@ def tune(
 
     walk_result: OpWalkResult = walk_mlir_op(mlir_module, dispatch_tuner_registry)
 
+    variant_op_list = iree_codegen.get_executable_variant_ops(mlir_module)
+    assert len(variant_op_list) == 1, "Expect one executable variant op"
+    variant_op = variant_op_list[0]
+    # Get the MMA intrinsic instructions supported by the target.
+    mma_list = iree_codegen.query_mma_intrinsics(variant_op)
+
     dispatch_tuner = walk_result.dispatch_tuner
     assert dispatch_tuner, "No suitable dispatch tuner found"
     problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template)
     tune_logger.debug(str(problem_size))
     configs = []
     for i, config in enumerate(
-        generate_solutions(tune_logger, problem_size, num_subgroups)
+        generate_solutions(tune_logger, problem_size, num_subgroups, mma_list)
     ):
         if i >= limit:
             break
diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py
index a34f172eb..b6e31768e 100644
--- a/tuner/tuner/common.py
+++ b/tuner/tuner/common.py
@@ -12,6 +12,8 @@
 
 from iree.compiler import ir  # type: ignore
 
+from iree.compiler.dialects import iree_gpu  # type: ignore
+
 
 class CommonTypes:
     def __init__(self, ctx: ir.Context):
@@ -130,7 +132,12 @@ def all():
         ]
 
 
-def get_compatible_mfma_intrinsics(problem_size: ProblemSize) -> list[MfmaIntrinsic]:
+def get_compatible_mfma_intrinsics(
+    problem_size: ProblemSize,
+    mma_intrinsics: list[iree_gpu.MMAIntrinsic],
+) -> list[MfmaIntrinsic]:
+    available_mma_intrinsics = [str(mma) for mma in mma_intrinsics]
+
     def is_compatible(intrinsic: MfmaIntrinsic) -> bool:
         if problem_size.res_type.element_type != intrinsic.output_type:
             return False
@@ -139,6 +146,10 @@ def is_compatible(intrinsic: MfmaIntrinsic) -> bool:
                 return False
             if problem_size.rhs_type.element_type != intrinsic.input_type:
                 return False
+
+        if str(intrinsic) not in available_mma_intrinsics:
+            return False
+
         return True
 
     return list(filter(is_compatible, MfmaIntrinsic.all()))
diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py
index 891d703e2..297ac95a2 100644
--- a/tuner/tuner/common_test.py
+++ b/tuner/tuner/common_test.py
@@ -14,6 +14,7 @@
 from typing import Generator
 
 from iree.compiler import ir  # type: ignore
+from iree.compiler.dialects import iree_gpu  # type: ignore
 
 
 @pytest.fixture
@@ -109,7 +110,11 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
             common.ShapedType([1280, 1280], tuner_ctx.type.f16),
             common.ShapedType([2048, 1280], tuner_ctx.type.f32),
             common.DispatchKind.mmt,
-        )
+        ),
+        [
+            iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
+            iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
+        ],
     ) == [
         common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
         common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
@@ -122,7 +127,11 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
             common.ShapedType([1280, 1280], tuner_ctx.type.i8),
             common.ShapedType([2048, 1280], tuner_ctx.type.i32),
             common.DispatchKind.mmt,
-        )
+        ),
+        [
+            iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
+            iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
+        ],
     ) == [
         common.MfmaIntrinsic.mfma_i32_16x16x32_i8(),
         common.MfmaIntrinsic.mfma_i32_32x32x16_i8(),
@@ -135,8 +144,44 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
             common.ShapedType([64, 640, 320], tuner_ctx.type.f32),
             common.ShapedType([64, 968, 320], tuner_ctx.type.f32),
             common.DispatchKind.batch_matmul,
-        )
+        ),
+        [
+            iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
+            iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
+        ],
     ) == [
         common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
         common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
     ]
+
+    assert common.get_compatible_mfma_intrinsics(
+        common.ProblemSize(
+            common.MatmulSize(968, 320, 640, 64),
+            common.ShapedType([64, 968, 640], tuner_ctx.type.f32),
+            common.ShapedType([64, 640, 320], tuner_ctx.type.f32),
+            common.ShapedType([64, 968, 320], tuner_ctx.type.f32),
+            common.DispatchKind.batch_matmul,
+        ),
+        [
+            iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
+        ],
+    ) == [
+        common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
+    ]
+
+    assert (
+        common.get_compatible_mfma_intrinsics(
+            common.ProblemSize(
+                common.MatmulSize(968, 320, 640, 64),
+                common.ShapedType([64, 968, 640], tuner_ctx.type.f32),
+                common.ShapedType([64, 640, 320], tuner_ctx.type.f32),
+                common.ShapedType([64, 968, 320], tuner_ctx.type.f32),
+                common.DispatchKind.batch_matmul,
+            ),
+            [
+                iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
+                iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
+            ],
+        )
+        == []
+    )
diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py
index edd7ccc38..85039a1e8 100644
--- a/tuner/tuner/dispatch_constraints.py
+++ b/tuner/tuner/dispatch_constraints.py
@@ -10,6 +10,9 @@
 
 import z3  # type: ignore
 from typing import Iterator
+
+from iree.compiler.dialects import iree_gpu  # type: ignore
+
 from .common import *
 
 
@@ -18,8 +21,9 @@ def get_mfma_intrinsic_constraints(
     intrinsic_m: z3.ArithRef,
     intrinsic_n: z3.ArithRef,
     intrinsic_k: z3.ArithRef,
+    mma_intrinsics: list[iree_gpu.MMAIntrinsic],
 ) -> z3.BoolRef:
-    compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size)
+    compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size, mma_intrinsics)
     assert len(compatible_intrinsics) > 0, "No compatible intrinsics found"
     return z3.Or(
         *(
@@ -68,6 +72,7 @@ def generate_constraints(
     subgroup_m_count,
     subgroup_n_count,
     waves_per_eu,
+    mma_intrinsics: list[iree_gpu.MMAIntrinsic],
 ):
     M, N, K = (
         problem_size.matmul_size.M,
@@ -82,7 +87,7 @@ def generate_constraints(
     constraints += [subgroup_size == 64, wg_threads <= 1024]
     constraints += [
         get_mfma_intrinsic_constraints(
-            problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k
+            problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k, mma_intrinsics
         )
     ]
     subgroup_k_count = 1
@@ -130,7 +135,10 @@ def generate_constraints(
 
 
 def generate_solutions(
-    logger: logging.Logger, problem_size: ProblemSize, num_subgrups: int
+    logger: logging.Logger,
+    problem_size: ProblemSize,
+    num_subgrups: int,
+    mma_intrinsics: list[iree_gpu.MMAIntrinsic],
 ) -> Iterator[Configuration]:
     M, N, K = problem_size.MNK
     logger.info(f"{M},{N},{K}")
@@ -168,6 +176,7 @@ def generate_solutions(
         sg_m_cnt,
         sg_n_cnt,
         waves_per_eu,
+        mma_intrinsics,
     )
     solver.add(z3.simplify(z3.And(constraints)))
     logger.debug(f"Initial constraints: {solver}")
diff --git a/tuner/tuner/dispatch_constraints_test.py b/tuner/tuner/dispatch_constraints_test.py
index 7e1a5c55d..9de4beeee 100644
--- a/tuner/tuner/dispatch_constraints_test.py
+++ b/tuner/tuner/dispatch_constraints_test.py
@@ -14,6 +14,7 @@
 from typing import Generator
 
 from iree.compiler import ir  # type: ignore
+from iree.compiler.dialects import iree_gpu  # type: ignore
 
 from . import common
 from . import dispatch_constraints
@@ -37,7 +38,18 @@ def test_generate_solutions(tuner_ctx: common.TunerContext) -> None:
     problem_size = common.ProblemSize(
         matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
     )
-    configs = dispatch_constraints.generate_solutions(tuner_ctx.logger, problem_size, 4)
+    configs = dispatch_constraints.generate_solutions(
+        tuner_ctx.logger,
+        problem_size,
+        4,
+        [
+            iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
+            iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
+            iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
+            iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
+        ],
+    )
+
     assert configs is not None
 
 
@@ -115,6 +127,12 @@ def test_generate_constraints_valid_input(tuner_ctx: common.TunerContext) -> Non
         sg_m_cnt,
         sg_n_cnt,
         waves_per_eu,
+        [
+            iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
+            iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
+            iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
+            iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
+        ],
     )
 
     solver = z3.Solver()
@@ -160,6 +178,12 @@ def test_generate_constraints_invalid_input(tuner_ctx: common.TunerContext) -> N
         sg_m_cnt,
         sg_n_cnt,
         waves_per_eu,
+        [
+            iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
+            iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
+            iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
+            iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
+        ],
     )
 
     constraints.append(m > 1000)  # Adding an additional unsatisfiable constraint
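Note (not part of the patch): the change threads the target's supported MMA intrinsics, queried via `iree_codegen.query_mma_intrinsics`, through `generate_solutions` into `get_compatible_mfma_intrinsics`, so candidate configurations only use intrinsics the target actually provides. The sketch below restates that filtering behavior as a standalone pytest-style check; it is illustrative only, the test name is hypothetical, and it reuses the existing `tuner_ctx` fixture and the values already asserted in the updated `common_test.py`.

```python
# Illustrative sketch only: mirrors the new assertions in common_test.py.
# The test name is hypothetical; `tuner_ctx` is the existing pytest fixture.
from iree.compiler.dialects import iree_gpu  # type: ignore

from . import common


def test_mma_intrinsic_filter_sketch(tuner_ctx: common.TunerContext) -> None:
    problem_size = common.ProblemSize(
        common.MatmulSize(968, 320, 640, 64),
        common.ShapedType([64, 968, 640], tuner_ctx.type.f32),
        common.ShapedType([64, 640, 320], tuner_ctx.type.f32),
        common.ShapedType([64, 968, 320], tuner_ctx.type.f32),
        common.DispatchKind.batch_matmul,
    )

    # Only intrinsics present in the supported list can be returned, even if
    # more intrinsics would pass the shape/type compatibility checks.
    assert common.get_compatible_mfma_intrinsics(
        problem_size,
        [iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16],
    ) == [common.MfmaIntrinsic.mfma_f32_32x32x8_f16()]

    # A supported list with no type-compatible entries yields no candidates.
    assert (
        common.get_compatible_mfma_intrinsics(
            problem_size,
            [iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8],
        )
        == []
    )
```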