Skip to content

Commit

Permalink
[tuner]: address comments
Browse files Browse the repository at this point in the history
Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu committed Nov 25, 2024
1 parent 3a07b21 commit 34757e9
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 29 deletions.
14 changes: 7 additions & 7 deletions tuner/tuner/candidate_gen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None:

M, N, K = 2048, 1280, 1280

mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, f"MFMA_F32_16x16x16_F16")
mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=16,
Expand Down Expand Up @@ -100,7 +100,7 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:

n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640

mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16")
mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
Expand Down Expand Up @@ -166,7 +166,7 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.contraction,
)

mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_32x32x8_F16")
mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
Expand Down Expand Up @@ -215,7 +215,7 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.batch_matmul,
)

mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_32x32x8_F16")
mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
Expand Down Expand Up @@ -267,7 +267,7 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.batch_mmt,
)

mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16")
mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
Expand Down Expand Up @@ -317,7 +317,7 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.batch_mmt,
)

mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_I32_32x32x16_I8")
mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
Expand Down Expand Up @@ -390,7 +390,7 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.broadcast_rhs_mmt,
)

mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_I32_32x32x16_I8")
mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
Expand Down
2 changes: 1 addition & 1 deletion tuner/tuner/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def get_compatible_mfma_intrinsics(
mma_intrinsics: list[iree_gpu.MMAIntrinsic],
) -> list[iree_gpu.MMAIntrinsic]:
def is_comptible(mma_intrinsic: iree_gpu.MMAIntrinsic) -> bool:
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
mma_attr = iree_gpu.MMAIntrinsicAttr.get(mma_intrinsic).mma
a_type, b_type, c_type = mma_attr.abc_element_types
if problem_size.res_type.element_type != c_type:
return False
Expand Down
2 changes: 1 addition & 1 deletion tuner/tuner/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_gpu_pipeline_options() -> None:


def test_get_pipeline_config(mlir_ctx: ir.Context) -> None:
mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16")
mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=32,
Expand Down
40 changes: 23 additions & 17 deletions tuner/tuner/dispatch_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@ def get_mfma_intrinsic_constraints(
return z3.Or(
*(
z3.And(
intrinsic_m == mma_attr.mnk_shape[0],
intrinsic_n == mma_attr.mnk_shape[1],
intrinsic_k == mma_attr.mnk_shape[2],
intrinsic_m == mnk[0],
intrinsic_n == mnk[1],
intrinsic_k == mnk[2],
)
for mma_attr in (
iree_gpu.MMAAttr.get(mfma) for mfma in compatible_intrinsics
)
for mnk in [mma_attr.mnk_shape]
)
)

Expand Down Expand Up @@ -148,20 +149,25 @@ def getMMAAttr(
lhs_type: ir.IntegerType | ir.FloatType,
rhs_type: ir.IntegerType | ir.FloatType,
) -> iree_gpu.MMAAttr:
mma_str = ""
if lhs_type == rhs_type:
input = str(lhs_type).upper()
output = str(output_type).upper()
mma_str = f"MFMA_{output}_{m}x{n}x{k}_{input}"
else:
lhs = str(lhs_type).upper()
rhs = str(rhs_type).upper()
output = str(output_type).upper()
mma_str = f"MFMA_{output}_{m}x{n}x{k}_{lhs}_{rhs}"

mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, mma_str)
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
return mma_attr
for mma_intrinsic in iree_gpu.MMAIntrinsic:
mma_attr = iree_gpu.MMAIntrinsicAttr.get(mma_intrinsic).mma
a_type, b_type, c_type = mma_attr.abc_element_types
mnk = mma_attr.mnk_shape
if (
a_type == lhs_type
and b_type == rhs_type
and c_type == output_type
and m == mnk[0]
and n == mnk[1]
and k == mnk[2]
):
return mma_attr
# If no matching intrinsic is found, raise an exception
raise ValueError(
f"No matching MMA intrinsic found for "
f"output_type={output_type}, lhs_type={lhs_type}, rhs_type={rhs_type}, "
f"m={m}, n={n}, k={k}."
)


def generate_solutions(
Expand Down
6 changes: 3 additions & 3 deletions tuner/tuner/dispatch_parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_parse_tensor_type(tuner_ctx: common.TunerContext) -> None:


def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None:
mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16")
mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = dispatch_parser.Configuration(
subgroup_size=0,
Expand All @@ -56,7 +56,7 @@ def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None:


def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None:
mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16")
mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = dispatch_parser.Configuration(
subgroup_size=64,
Expand All @@ -80,7 +80,7 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None:


def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None:
mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16")
mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = dispatch_parser.Configuration(
subgroup_size=32,
Expand Down

0 comments on commit 34757e9

Please sign in to comment.