[tuner]: remove MfmaIntrinsic and use iree_gpu.MMAAttr instead
Signed-off-by: Bangtian Liu <[email protected]>
bangtianliu committed Nov 25, 2024
1 parent 0e74c39 commit 3a07b21
Showing 6 changed files with 95 additions and 96 deletions.
13 changes: 6 additions & 7 deletions tuner/tuner/candidate_gen.py
@@ -52,7 +52,7 @@ def apply_configuration(
expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]")
expr3 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>")
expr4 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"")
repl0 = f"<intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>"
repl0 = f"<intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>"
repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},'
repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]'
repl3 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}"
@@ -119,7 +119,6 @@ def get_transform_function_mmt(

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)

return f"""
transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
%mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
@@ -132,7 +131,7 @@ def get_transform_function_mmt(
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
@@ -205,7 +204,7 @@ def get_transform_function_conv(
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
@@ -266,7 +265,7 @@ def get_transform_function_broadcast_rhs_mmt(
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
@@ -346,7 +345,7 @@ def get_transform_function_batch_mmt(
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
@@ -414,7 +413,7 @@ def get_transform_function_batch_matmul(
translation_info = #iree_codegen.translation_info<LLVMGPUPadAndVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
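Why the wrapper text could be dropped, for context: an iree_gpu.MMAAttr already prints as its full MLIR assembly form, so interpolating configuration.intrinsic directly yields the same #iree_gpu.mma_layout<...> text the old code assembled by hand around an MfmaIntrinsic. A minimal sketch (not part of this commit), assuming an active MLIR context with the IREE GPU dialect available, as the tuner test fixtures provide:

from iree.compiler.dialects import iree_gpu  # type: ignore

# Assumes an ir.Context is active and the iree_gpu dialect is registered.
mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)

# The attribute stringifies to something like
# "#iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>", so the f-strings above no
# longer need to wrap the value in an explicit #iree_gpu.mma_layout<...>.
print(f"intrinsic = {mma_attr}")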
29 changes: 22 additions & 7 deletions tuner/tuner/candidate_gen_test.py
@@ -13,6 +13,7 @@
from typing import Generator

from iree.compiler import ir # type: ignore
from iree.compiler.dialects import iree_gpu # type: ignore

from . import candidate_gen
from . import common
@@ -45,10 +46,12 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None:

M, N, K = 2048, 1280, 1280

mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16")
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=16,
workgroup_size=[16, 16, 1],
intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
intrinsic=mma_attr,
tile_sizes=[8, 8, 8],
subgroup_m_count=16,
subgroup_n_count=16,
@@ -97,10 +100,12 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:

n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640

mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16")
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[256, 1, 1],
intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
intrinsic=mma_attr,
tile_sizes=[464, 320, 16],
subgroup_m_count=1,
subgroup_n_count=4,
@@ -161,10 +166,12 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.contraction,
)

mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_32x32x8_F16")
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[256, 1, 1],
intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
intrinsic=mma_attr,
tile_sizes=[480, 384, 32],
subgroup_m_count=1,
subgroup_n_count=4,
@@ -208,10 +215,12 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.batch_matmul,
)

mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_32x32x8_F16")
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
intrinsic=mma_attr,
tile_sizes=[416, 320, 128],
subgroup_m_count=2,
subgroup_n_count=2,
@@ -258,10 +267,12 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.batch_mmt,
)

mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16")
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
intrinsic=mma_attr,
tile_sizes=[128, 64, 128],
subgroup_m_count=2,
subgroup_n_count=2,
@@ -306,10 +317,12 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.batch_mmt,
)

mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_I32_32x32x16_I8")
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(),
intrinsic=mma_attr,
tile_sizes=[128, 64, 128],
subgroup_m_count=2,
subgroup_n_count=2,
@@ -377,10 +390,12 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.broadcast_rhs_mmt,
)

mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_I32_32x32x16_I8")
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(),
intrinsic=mma_attr,
tile_sizes=[128, 64, 128],
subgroup_m_count=2,
subgroup_n_count=2,
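A side note, not part of the diff: the getattr lookups in these tests resolve constant names, so direct attribute access on the enum is an equivalent spelling — the same form the common_test.py expectations below use:

# Equivalent to getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16"),
# assuming the bindings expose the enum members directly (as common_test.py
# relies on):
mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)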
72 changes: 11 additions & 61 deletions tuner/tuner/common.py
@@ -85,74 +85,24 @@ def MNK(self) -> tuple[int, int, int]:
return (self.matmul_size.M, self.matmul_size.N, self.matmul_size.K)


@dataclass
class MfmaIntrinsic:
output_type: ir.IntegerType | ir.FloatType
m: int
n: int
k: int
input_type: ir.IntegerType | ir.FloatType

def __str__(self) -> str:
input = str(self.input_type).upper()
output = str(self.output_type).upper()
return f"MFMA_{output}_{self.m}x{self.n}x{self.k}_{input}"

@staticmethod
def mfma_f32_16x16x16_f16():
f16 = ir.F16Type.get()
f32 = ir.F32Type.get()
return MfmaIntrinsic(f32, 16, 16, 16, f16)

@staticmethod
def mfma_f32_32x32x8_f16():
f16 = ir.F16Type.get()
f32 = ir.F32Type.get()
return MfmaIntrinsic(f32, 32, 32, 8, f16)

@staticmethod
def mfma_i32_16x16x32_i8():
i32 = ir.IntegerType.get_signless(32)
i8 = ir.IntegerType.get_signless(8)
return MfmaIntrinsic(i32, 16, 16, 32, i8)

@staticmethod
def mfma_i32_32x32x16_i8():
i32 = ir.IntegerType.get_signless(32)
i8 = ir.IntegerType.get_signless(8)
return MfmaIntrinsic(i32, 32, 32, 16, i8)

@staticmethod
def all():
return [
MfmaIntrinsic.mfma_f32_16x16x16_f16(),
MfmaIntrinsic.mfma_f32_32x32x8_f16(),
MfmaIntrinsic.mfma_i32_16x16x32_i8(),
MfmaIntrinsic.mfma_i32_32x32x16_i8(),
]


def get_compatible_mfma_intrinsics(
problem_size: ProblemSize,
mma_intrinsics: list[iree_gpu.MMAIntrinsic],
) -> list[MfmaIntrinsic]:
available_mma_intrinsics = [str(mma) for mma in mma_intrinsics]

def is_compatible(intrinsic: MfmaIntrinsic) -> bool:
if problem_size.res_type.element_type != intrinsic.output_type:
) -> list[iree_gpu.MMAIntrinsic]:
def is_compatible(mma_intrinsic: iree_gpu.MMAIntrinsic) -> bool:
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
a_type, b_type, c_type = mma_attr.abc_element_types
if problem_size.res_type.element_type != c_type:
return False
if problem_size.dispatch_kind != DispatchKind.batch_matmul:
if problem_size.lhs_type.element_type != intrinsic.input_type:
return False
if problem_size.rhs_type.element_type != intrinsic.input_type:
if (
problem_size.lhs_type.element_type != a_type
or problem_size.rhs_type.element_type != b_type
):
return False

if str(intrinsic) not in available_mma_intrinsics:
return False

return True

return list(filter(is_compatible, MfmaIntrinsic.all()))
return list(filter(is_compatible, mma_intrinsics))


class ReorderWorkgroupsStrategy(Enum):
@@ -197,7 +147,7 @@ def __str__(self) -> str:
class Configuration:
subgroup_size: int
workgroup_size: list[int]
intrinsic: MfmaIntrinsic
intrinsic: iree_gpu.MMAAttr
tile_sizes: list[int]
subgroup_m_count: int
subgroup_n_count: int
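The main simplification in common.py: the operand element types that MfmaIntrinsic's static constructors hard-coded are now read from the attribute itself via abc_element_types. A minimal sketch under the same assumptions as above (active MLIR context; mma_element_types is a hypothetical helper for illustration, and the property is taken to return the (A, B, C) element types in that order, as the new is_compatible check relies on):

from iree.compiler.dialects import iree_gpu  # type: ignore

def mma_element_types(intrinsic: iree_gpu.MMAIntrinsic):
    # Build the MMAAttr for the intrinsic and read back its operand types,
    # e.g. MFMA_I32_32x32x16_I8 -> (i8, i8, i32). This replaces the
    # per-intrinsic constructors the old MfmaIntrinsic class carried.
    mma_attr = iree_gpu.MMAAttr.get(intrinsic)
    return mma_attr.abc_element_types

a_type, b_type, c_type = mma_element_types(iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8)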
25 changes: 11 additions & 14 deletions tuner/tuner/common_test.py
@@ -5,7 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""
Usage: python -m pytest candidate_gen_test.py
Usage: python -m pytest common_test.py
"""

import pytest
@@ -72,10 +72,12 @@ def test_gpu_pipeline_options() -> None:


def test_get_pipeline_config(mlir_ctx: ir.Context) -> None:
mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16")
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=32,
workgroup_size=[16, 16, 1],
intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
intrinsic=mma_attr,
tile_sizes=[4, 8, 16],
subgroup_m_count=1,
subgroup_n_count=1,
@@ -97,11 +99,6 @@ def test_get_pipeline_config(mlir_ctx: ir.Context) -> None:
)


def test_mfma_intrinsic_to_str(mlir_ctx: ir.Context) -> None:
assert str(common.MfmaIntrinsic.mfma_f32_16x16x16_f16()) == "MFMA_F32_16x16x16_F16"
assert str(common.MfmaIntrinsic.mfma_i32_32x32x16_i8()) == "MFMA_I32_32x32x16_I8"


def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
assert common.get_compatible_mfma_intrinsics(
common.ProblemSize(
@@ -116,8 +113,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
],
) == [
common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
]

assert common.get_compatible_mfma_intrinsics(
@@ -133,8 +130,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
],
) == [
common.MfmaIntrinsic.mfma_i32_16x16x32_i8(),
common.MfmaIntrinsic.mfma_i32_32x32x16_i8(),
iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
]

assert common.get_compatible_mfma_intrinsics(
@@ -150,8 +147,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
],
) == [
common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
]

assert common.get_compatible_mfma_intrinsics(
Expand All @@ -166,7 +163,7 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
],
) == [
common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
]

assert (