Skip to content

Commit

Permalink
[tuner]: use property function from iree lowering config python binding directly
Browse files Browse the repository at this point in the history

Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu committed Dec 9, 2024
1 parent 7f6de06 commit aa4f2d5
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 140 deletions.
84 changes: 39 additions & 45 deletions tuner/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,19 @@

tune_logger = logging.getLogger("tune")

# TODO: remove the argument 'workgroup_sizes' and 'reduction_sizes'.

def apply_configuration(
template: list[str],
configuration: Configuration,
workgroup_sizes: list[int],
reduction_sizes: list[int],
) -> str:
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)
lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn
workgroup_sizes = lowering_config.workgroup_tile_sizes
reduction_sizes = lowering_config.reduction_tile_sizes
tune_logger.info(f"Applying: {configuration}")
expr0 = re.compile(
r"<intrinsic = #iree_gpu\.mma_layout<(.+)>, subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>"
Expand Down Expand Up @@ -125,9 +128,12 @@ class MmtTuner(DispatchTuner, MmtParser):
def get_transform_function_mmt(
self, problem_size: ProblemSize, functionName: str, configuration: Configuration
) -> str:
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)
lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
Expand Down Expand Up @@ -167,8 +173,6 @@ def apply_params(
modified += apply_configuration(
template,
configuration,
get_mmt_workgroup_sizes(configuration),
get_mmt_reduction_sizes(configuration),
)
embeddable = indent(
self.get_transform_function_mmt(problem_size, f"match_op", configuration),
Expand All @@ -193,15 +197,12 @@ def get_transform_function_conv(
filter = f"tensor<{problem_size.rhs_type}>"
output = f"tensor<{dynamic_batch_output_ty}>"

workgroup_sizes = ", ".join(
map(str, self.get_conv_workgroup_sizes(configuration))
)
reduction_sizes = ", ".join(
map(str, self.get_conv_reduction_sizes(configuration))
)
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)
lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
Expand Down Expand Up @@ -246,8 +247,6 @@ def apply_params(
modified += apply_configuration(
template,
configuration,
self.get_conv_workgroup_sizes(configuration),
self.get_conv_reduction_sizes(configuration),
)
embeddable = indent(
self.get_transform_function_conv(problem_size, f"match_op", configuration),
Expand All @@ -263,15 +262,12 @@ def get_transform_function_broadcast_rhs_mmt(
functionName: str,
configuration: Configuration,
) -> str:
workgroup_sizes = ", ".join(
map(str, get_batch_mmt_workgroup_sizes(configuration))
)
reduction_sizes = ", ".join(
map(str, get_batch_mmt_reduction_sizes(configuration))
)
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)
lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
Expand Down Expand Up @@ -316,8 +312,6 @@ def apply_params_broadcast_rhs_mmt(
modified += apply_configuration(
template,
configuration,
get_batch_mmt_workgroup_sizes(configuration),
get_batch_mmt_reduction_sizes(configuration),
)

embeddable = indent(
Expand Down Expand Up @@ -345,8 +339,6 @@ def apply_params(
apply_configuration(
template,
configuration,
get_contract_workgroup_sizes(configuration, self.tile_dims),
get_contract_reduction_sizes(configuration, self.tile_dims),
),
"",
)
Expand All @@ -359,9 +351,12 @@ def get_transform_function_batch_mmt(
functionName: str,
configuration: Configuration,
) -> str:
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)
lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
Expand Down Expand Up @@ -403,8 +398,6 @@ def apply_params(
modified += apply_configuration(
template,
configuration,
get_batch_mmt_workgroup_sizes(configuration),
get_batch_mmt_reduction_sizes(configuration),
)

embeddable = indent(
Expand All @@ -428,9 +421,12 @@ def get_transform_function_batch_matmul(
input1 = f"tensor<{problem_size.rhs_type}>"
output = f"tensor<{problem_size.res_type}>"

intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)
lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
Expand Down Expand Up @@ -476,8 +472,6 @@ def apply_params(
modified += apply_configuration(
template,
configuration,
get_contract_workgroup_sizes(configuration, self.tile_dims),
get_contract_reduction_sizes(configuration, self.tile_dims),
)

embeddable = indent(
Expand Down
36 changes: 18 additions & 18 deletions tuner/tuner/candidate_gen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,15 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:
'gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>, {llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}',
]

n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640
n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 16

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
lowering_config = common.get_lowering_config(
tuner_ctx=tuner_ctx,
mma_kind=mma_attr,
workgroup=[464, 320, 0],
reduction=[0, 0, 16],
workgroup=[n, oh, ow, oc, fh, fw, 0],
reduction=[0, 0, 0, 0, 0, 0, ic],
subgroup_m_count=1,
subgroup_n_count=4,
)
Expand Down Expand Up @@ -155,7 +155,7 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:
"LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64"
in modified
)
assert "workgroup = [1, 1, 464, 320, 1, 1, 0]" in modified
assert "workgroup = [2, 64, 64, 640, 3, 3, 0]" in modified
assert "reduction = [0, 0, 0, 0, 0, 0, 16]" in modified
assert (
"gpu_pipeline_options = #iree_gpu.pipeline_options<reorder_workgroups_strategy = <Transpose>>"
Expand Down Expand Up @@ -186,8 +186,8 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None:
lowering_config = common.get_lowering_config(
tuner_ctx=tuner_ctx,
mma_kind=mma_attr,
workgroup=[480, 384, 0],
reduction=[0, 0, 32],
workgroup=[1, 480, 384, 0],
reduction=[0, 0, 0, 32],
subgroup_m_count=1,
subgroup_n_count=4,
)
Expand Down Expand Up @@ -241,8 +241,8 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None:
lowering_config = common.get_lowering_config(
tuner_ctx=tuner_ctx,
mma_kind=mma_attr,
workgroup=[416, 320, 0],
reduction=[0, 0, 128],
workgroup=[1, 416, 320, 0],
reduction=[0, 0, 0, 128],
subgroup_m_count=2,
subgroup_n_count=2,
)
Expand Down Expand Up @@ -299,8 +299,8 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None:
lowering_config = common.get_lowering_config(
tuner_ctx=tuner_ctx,
mma_kind=mma_attr,
workgroup=[128, 64, 0],
reduction=[0, 0, 128],
workgroup=[1, 128, 64, 0],
reduction=[0, 0, 0, 128],
subgroup_m_count=2,
subgroup_n_count=2,
)
Expand Down Expand Up @@ -355,8 +355,8 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None:
lowering_config = common.get_lowering_config(
tuner_ctx=tuner_ctx,
mma_kind=mma_attr,
workgroup=[128, 64, 0],
reduction=[0, 0, 128],
workgroup=[1, 128, 64, 0],
reduction=[0, 0, 0, 128],
subgroup_m_count=2,
subgroup_n_count=2,
)
Expand Down Expand Up @@ -408,8 +408,8 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None:
"%config = transform.param.constant #iree_codegen.compilation_info<"
in embeddable
)
assert "workgroup = [128, 64, 0]" in embeddable
assert "reduction = [0, 0, 128]" in embeddable
assert "workgroup = [1, 128, 64, 0]" in embeddable
assert "reduction = [0, 0, 0, 128]" in embeddable
assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable
assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable

Expand All @@ -435,8 +435,8 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
lowering_config = common.get_lowering_config(
tuner_ctx=tuner_ctx,
mma_kind=mma_attr,
workgroup=[128, 64, 0],
reduction=[0, 0, 128],
workgroup=[1, 128, 64, 0],
reduction=[0, 0, 0, 128],
subgroup_m_count=2,
subgroup_n_count=2,
)
Expand Down Expand Up @@ -492,8 +492,8 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
"%config = transform.param.constant #iree_codegen.compilation_info<"
in embeddable
)
assert "workgroup = [128, 64, 0]" in embeddable
assert "reduction = [0, 0, 128]" in embeddable
assert "workgroup = [1, 128, 64, 0]" in embeddable
assert "reduction = [0, 0, 0, 128]" in embeddable
assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable
assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable

Expand Down
34 changes: 0 additions & 34 deletions tuner/tuner/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,40 +119,6 @@ class Configuration:
waves_per_eu: int


def get_intrinsic(config: Configuration) -> Optional[iree_gpu.MMAAttr]:
    """Return the "mma_kind" entry of the lowering config, or None when it is unset."""
    attrs = config.lowering_config.attributes
    if "mma_kind" not in attrs:
        return None
    return attrs["mma_kind"]


def get_workgroup_tile_sizes(config: Configuration) -> list[int]:
    """Return the per-dimension "workgroup" tile sizes, or [] when the attribute is absent."""
    attrs = config.lowering_config.attributes
    if "workgroup" not in attrs:
        return []
    sizes: list[int] = []
    for tile_attr in attrs["workgroup"]:
        sizes.append(tile_attr.value)
    return sizes


def get_reduction_tile_sizes(config: Configuration) -> list[int]:
    """Return the per-dimension "reduction" tile sizes, or [] when the attribute is absent."""
    attrs = config.lowering_config.attributes
    if "reduction" not in attrs:
        return []
    sizes: list[int] = []
    for tile_attr in attrs["reduction"]:
        sizes.append(tile_attr.value)
    return sizes


def get_subgroup_m_count(config: Configuration) -> Optional[int]:
    """Return the "subgroup_m_count" value from the lowering config, or None when unset."""
    attrs = config.lowering_config.attributes
    if "subgroup_m_count" not in attrs:
        return None
    return attrs["subgroup_m_count"].value


def get_subgroup_n_count(config: Configuration) -> Optional[int]:
    """Return the "subgroup_n_count" value from the lowering config, or None when unset."""
    attrs = config.lowering_config.attributes
    if "subgroup_n_count" not in attrs:
        return None
    return attrs["subgroup_n_count"].value


def get_lowering_config(
tuner_ctx: TunerContext,
**kwargs: Any,
Expand Down
5 changes: 2 additions & 3 deletions tuner/tuner/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,5 @@ def test_get_lowering_config(tuner_ctx: common.TunerContext) -> None:
waves_per_eu=2,
)

assert common.get_intrinsic(config) is None
assert common.get_subgroup_m_count(config) == 1
assert common.get_subgroup_n_count(config) == 1
assert config.lowering_config.mma_kind is None
assert config.lowering_config.subgroup_count_mn == (1, 1)
36 changes: 2 additions & 34 deletions tuner/tuner/dispatch_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,10 @@ def parse_tensor_type(tensor_type: str) -> ShapedType:
return ShapedType(shaped_ty.shape, shaped_ty.element_type)


def get_mmt_workgroup_sizes(configuration: Configuration):
    """Workgroup tile sizes for an MMT dispatch; simply the generic workgroup tiles."""
    sizes = get_workgroup_tile_sizes(configuration)
    return sizes


def get_mmt_reduction_sizes(configuration: Configuration):
    """Reduction tile sizes for an MMT dispatch; simply the generic reduction tiles."""
    sizes = get_reduction_tile_sizes(configuration)
    return sizes


def get_contract_workgroup_sizes(
configuration: Configuration, tile_dims: str
) -> list[int]:
m, n, _k = get_workgroup_tile_sizes(configuration)
m, n, _k = configuration.lowering_config.workgroup_tile_sizes

workgroup_size = [1] * len(tile_dims)
for idx, dim in enumerate(tile_dims):
Expand All @@ -48,7 +40,7 @@ def get_contract_workgroup_sizes(
def get_contract_reduction_sizes(
    configuration: Configuration, tile_dims: str
) -> list[int]:
    """Scatter the k reduction tile size onto every position of tile_dims labelled "k".

    All other contraction dimensions get a reduction tile size of 0.
    """
    # Only the last (k) element of the lowering config's reduction tiles is used.
    _m, _n, k = configuration.lowering_config.reduction_tile_sizes
    return [k if dim == "k" else 0 for dim in tile_dims]

def get_batch_mmt_workgroup_sizes(configuration: Configuration) -> list[int]:
    """Workgroup tiles for batched MMT: a unit batch dimension plus the generic tiles."""
    sizes = [1]
    sizes.extend(get_workgroup_tile_sizes(configuration))
    return sizes


def get_batch_mmt_reduction_sizes(configuration: Configuration) -> list[int]:
    """Reduction tiles for batched MMT: the batch dimension is never reduced (size 0)."""
    sizes = [0]
    sizes.extend(get_reduction_tile_sizes(configuration))
    return sizes


class MlirRegex(Enum):
ssa_value = r"%[a-zA-Z0-9-_]+"
tensor_type = r"tensor<([^>]+)>"
Expand Down Expand Up @@ -164,22 +148,6 @@ class ConvParser(DispatchParser):
def supports(self, op_name: str) -> bool:
    """True when op_name refers to a 2-D NHWC/HWCF convolution dispatch."""
    return op_name.find("conv_2d_nhwc_hwcf") != -1

def get_conv_workgroup_sizes(self, configuration: Configuration) -> list[int]:
    """Expand the (ow, oc, ic) workgroup tiles into the 7-D conv iteration space.

    Batch, oh, fh and fw are left untiled (size 1); the trailing reduction
    slot is always 0 at the workgroup level.
    """
    ow, oc, _ic = get_workgroup_tile_sizes(configuration)
    return [1, 1, ow, oc, 1, 1, 0]

def get_conv_reduction_sizes(self, configuration: Configuration) -> list[int]:
    """Reduction tiles for conv: only the input-channel (ic) dimension is tiled."""
    _ow, _oc, ic = get_reduction_tile_sizes(configuration)
    sizes = [0] * 7
    sizes[-1] = ic
    return sizes

def get_shapes(self, template: list[str]) -> ProblemSize:
for line in template:
if "linalg.conv_2d_nhwc_hwcf" not in line:
Expand Down
Loading

0 comments on commit aa4f2d5

Please sign in to comment.