From 34757e9f1a6be1a59e9d8dea16856dac5b8337b2 Mon Sep 17 00:00:00 2001
From: Bangtian Liu
Date: Mon, 25 Nov 2024 15:34:29 -0600
Subject: [PATCH] [tuner]: address comments

Signed-off-by: Bangtian Liu
---
 tuner/tuner/candidate_gen_test.py   | 14 +++++-----
 tuner/tuner/common.py               |  2 +-
 tuner/tuner/common_test.py          |  2 +-
 tuner/tuner/dispatch_constraints.py | 40 +++++++++++++++++------------
 tuner/tuner/dispatch_parser_test.py |  6 ++---
 5 files changed, 35 insertions(+), 29 deletions(-)

diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py
index ed744d033..d81278e8c 100644
--- a/tuner/tuner/candidate_gen_test.py
+++ b/tuner/tuner/candidate_gen_test.py
@@ -46,7 +46,7 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None:
 
     M, N, K = 2048, 1280, 1280
 
-    mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, f"MFMA_F32_16x16x16_F16")
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = common.Configuration(
         subgroup_size=16,
@@ -100,7 +100,7 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:
 
     n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640
 
-    mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16")
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = common.Configuration(
         subgroup_size=64,
@@ -166,7 +166,7 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None:
         common.DispatchKind.contraction,
     )
 
-    mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_32x32x8_F16")
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = common.Configuration(
         subgroup_size=64,
@@ -215,7 +215,7 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None:
         common.DispatchKind.batch_matmul,
     )
 
-    mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_32x32x8_F16")
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = common.Configuration(
         subgroup_size=64,
@@ -267,7 +267,7 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None:
         common.DispatchKind.batch_mmt,
     )
 
-    mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16")
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = common.Configuration(
         subgroup_size=64,
@@ -317,7 +317,7 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None:
         common.DispatchKind.batch_mmt,
     )
 
-    mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_I32_32x32x16_I8")
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = common.Configuration(
         subgroup_size=64,
@@ -390,7 +390,7 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
         common.DispatchKind.broadcast_rhs_mmt,
     )
 
-    mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_I32_32x32x16_I8")
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = common.Configuration(
         subgroup_size=64,
diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py
index 8baa1594e..45ae48c22 100644
--- a/tuner/tuner/common.py
+++ b/tuner/tuner/common.py
@@ -90,7 +90,7 @@ def get_compatible_mfma_intrinsics(
     mma_intrinsics: list[iree_gpu.MMAIntrinsic],
 ) -> list[iree_gpu.MMAIntrinsic]:
     def is_comptible(mma_intrinsic: iree_gpu.MMAIntrinsic) -> bool:
-        mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
+        mma_attr = iree_gpu.MMAIntrinsicAttr.get(mma_intrinsic).mma
         a_type, b_type, c_type = mma_attr.abc_element_types
         if problem_size.res_type.element_type != c_type:
             return False
diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py
index fc44d80ff..ea0a4573d 100644
--- a/tuner/tuner/common_test.py
+++ b/tuner/tuner/common_test.py
@@ -72,7 +72,7 @@ def test_gpu_pipeline_options() -> None:
 
 
 def test_get_pipeline_config(mlir_ctx: ir.Context) -> None:
-    mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16")
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = common.Configuration(
         subgroup_size=32,
diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py
index 13bf1a5ce..3e8dc6c9f 100644
--- a/tuner/tuner/dispatch_constraints.py
+++ b/tuner/tuner/dispatch_constraints.py
@@ -28,13 +28,14 @@ def get_mfma_intrinsic_constraints(
     return z3.Or(
         *(
             z3.And(
-                intrinsic_m == mma_attr.mnk_shape[0],
-                intrinsic_n == mma_attr.mnk_shape[1],
-                intrinsic_k == mma_attr.mnk_shape[2],
+                intrinsic_m == mnk[0],
+                intrinsic_n == mnk[1],
+                intrinsic_k == mnk[2],
             )
             for mma_attr in (
                 iree_gpu.MMAAttr.get(mfma) for mfma in compatible_intrinsics
             )
+            for mnk in [mma_attr.mnk_shape]
         )
     )
 
@@ -148,20 +149,25 @@ def getMMAAttr(
     lhs_type: ir.IntegerType | ir.FloatType,
     rhs_type: ir.IntegerType | ir.FloatType,
 ) -> iree_gpu.MMAAttr:
-    mma_str = ""
-    if lhs_type == rhs_type:
-        input = str(lhs_type).upper()
-        output = str(output_type).upper()
-        mma_str = f"MFMA_{output}_{m}x{n}x{k}_{input}"
-    else:
-        lhs = str(lhs_type).upper()
-        rhs = str(rhs_type).upper()
-        output = str(output_type).upper()
-        mma_str = f"MFMA_{output}_{m}x{n}x{k}_{lhs}_{rhs}"
-
-    mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, mma_str)
-    mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
-    return mma_attr
+    for mma_intrinsic in iree_gpu.MMAIntrinsic:
+        mma_attr = iree_gpu.MMAIntrinsicAttr.get(mma_intrinsic).mma
+        a_type, b_type, c_type = mma_attr.abc_element_types
+        mnk = mma_attr.mnk_shape
+        if (
+            a_type == lhs_type
+            and b_type == rhs_type
+            and c_type == output_type
+            and m == mnk[0]
+            and n == mnk[1]
+            and k == mnk[2]
+        ):
+            return mma_attr
+    # If no matching intrinsic is found, raise an exception
+    raise ValueError(
+        f"No matching MMA intrinsic found for "
+        f"output_type={output_type}, lhs_type={lhs_type}, rhs_type={rhs_type}, "
+        f"m={m}, n={n}, k={k}."
+    )
 
 
 def generate_solutions(
diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py
index cca9f3606..fb10b04bc 100644
--- a/tuner/tuner/dispatch_parser_test.py
+++ b/tuner/tuner/dispatch_parser_test.py
@@ -40,7 +40,7 @@ def test_parse_tensor_type(tuner_ctx: common.TunerContext) -> None:
 
 
 def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None:
-    mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16")
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = dispatch_parser.Configuration(
         subgroup_size=0,
@@ -56,7 +56,7 @@ def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None:
 
 
 def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None:
-    mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16")
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = dispatch_parser.Configuration(
         subgroup_size=64,
@@ -80,7 +80,7 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None:
 
 
 def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None:
-    mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16")
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = dispatch_parser.Configuration(
         subgroup_size=32,
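
A minimal usage sketch of the new lookup in getMMAAttr, for illustration only. It assumes the tuner package is importable as `tuner`, that the argument order is (output_type, m, n, k, lhs_type, rhs_type), and that an ir.Context from the IREE Python bindings is sufficient for the iree_gpu attribute builders (the tests use their tuner_ctx/mlir_ctx fixtures instead); adjust these details to the local setup if they differ.

    from iree.compiler import ir
    from tuner.dispatch_constraints import getMMAAttr

    with ir.Context():
        f16 = ir.F16Type.get()
        f32 = ir.F32Type.get()

        # The enum scan is expected to resolve this shape/type combination
        # to MFMA_F32_16x16x16_F16 and return the corresponding MMAAttr.
        mma_attr = getMMAAttr(f32, 16, 16, 16, f16, f16)
        a_type, b_type, c_type = mma_attr.abc_element_types
        assert (a_type, b_type, c_type) == (f16, f16, f32)
        mnk = mma_attr.mnk_shape
        assert (mnk[0], mnk[1], mnk[2]) == (16, 16, 16)

        # A shape no intrinsic provides now raises ValueError rather than
        # the AttributeError the old getattr(iree_gpu.MMAIntrinsic, mma_str)
        # lookup would have produced.
        try:
            getMMAAttr(f32, 7, 7, 7, f16, f16)
        except ValueError:
            pass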