[tuner]: use property function from iree lowering config python binding #662

Merged · 2 commits · Dec 9, 2024
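In essence, this change drops the tuner's hand-rolled attribute-dictionary getters (`get_intrinsic`, `get_subgroup_m_count`, `get_subgroup_n_count`, `get_workgroup_tile_sizes`, `get_reduction_tile_sizes`) in favor of the property accessors exposed by the `iree_gpu` lowering config Python binding. A minimal before/after sketch of the access pattern; `configuration` is assumed to be the tuner's `Configuration` dataclass with a populated `lowering_config`:

```python
# Before: helper functions poking into the attribute dictionary by key.
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)

# After: properties on the iree_gpu lowering config binding.
lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind  # None when no MMA intrinsic is set
subgroup_m_count, subgroup_n_count = lowering_config.subgroup_count_mn
workgroup_sizes = lowering_config.workgroup_tile_sizes
reduction_sizes = lowering_config.reduction_tile_sizes
```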
84 changes: 39 additions & 45 deletions tuner/tuner/candidate_gen.py
@@ -38,16 +38,19 @@

tune_logger = logging.getLogger("tune")

# TODO: remove the argument 'workgroup_sizes' and 'reduction_sizes'.

def apply_configuration(
template: list[str],
configuration: Configuration,
workgroup_sizes: list[int],
reduction_sizes: list[int],
) -> str:
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)
lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn
workgroup_sizes = lowering_config.workgroup_tile_sizes
reduction_sizes = lowering_config.reduction_tile_sizes
tune_logger.info(f"Applying: {configuration}")
expr0 = re.compile(
r"<intrinsic = #iree_gpu\.mma_layout<(.+)>, subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>"
@@ -125,9 +128,12 @@ class MmtTuner(DispatchTuner, MmtParser):
def get_transform_function_mmt(
self, problem_size: ProblemSize, functionName: str, configuration: Configuration
) -> str:
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)
lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
@@ -167,8 +173,6 @@ def apply_params(
modified += apply_configuration(
template,
configuration,
get_mmt_workgroup_sizes(configuration),
get_mmt_reduction_sizes(configuration),
)
embeddable = indent(
self.get_transform_function_mmt(problem_size, f"match_op", configuration),
@@ -193,15 +197,12 @@ def get_transform_function_conv(
filter = f"tensor<{problem_size.rhs_type}>"
output = f"tensor<{dynamic_batch_output_ty}>"

workgroup_sizes = ", ".join(
map(str, self.get_conv_workgroup_sizes(configuration))
)
reduction_sizes = ", ".join(
map(str, self.get_conv_reduction_sizes(configuration))
)
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)
lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
@@ -246,8 +247,6 @@ def apply_params(
modified += apply_configuration(
template,
configuration,
self.get_conv_workgroup_sizes(configuration),
self.get_conv_reduction_sizes(configuration),
)
embeddable = indent(
self.get_transform_function_conv(problem_size, f"match_op", configuration),
@@ -263,15 +262,12 @@ def get_transform_function_broadcast_rhs_mmt(
functionName: str,
configuration: Configuration,
) -> str:
workgroup_sizes = ", ".join(
map(str, get_batch_mmt_workgroup_sizes(configuration))
)
reduction_sizes = ", ".join(
map(str, get_batch_mmt_reduction_sizes(configuration))
)
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)
lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
@@ -316,8 +312,6 @@ def apply_params_broadcast_rhs_mmt(
modified += apply_configuration(
template,
configuration,
get_batch_mmt_workgroup_sizes(configuration),
get_batch_mmt_reduction_sizes(configuration),
)

embeddable = indent(
@@ -345,8 +339,6 @@ def apply_params(
apply_configuration(
template,
configuration,
get_contract_workgroup_sizes(configuration, self.tile_dims),
get_contract_reduction_sizes(configuration, self.tile_dims),
),
"",
)
@@ -359,9 +351,12 @@ def get_transform_function_batch_mmt(
functionName: str,
configuration: Configuration,
) -> str:
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)
lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
@@ -403,8 +398,6 @@ def apply_params(
modified += apply_configuration(
template,
configuration,
get_batch_mmt_workgroup_sizes(configuration),
get_batch_mmt_reduction_sizes(configuration),
)

embeddable = indent(
@@ -428,9 +421,12 @@ def get_transform_function_batch_matmul(
input1 = f"tensor<{problem_size.rhs_type}>"
output = f"tensor<{problem_size.res_type}>"

intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)
lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
@@ -476,8 +472,6 @@ def apply_params(
modified += apply_configuration(
template,
configuration,
get_contract_workgroup_sizes(configuration, self.tile_dims),
get_contract_reduction_sizes(configuration, self.tile_dims),
)

embeddable = indent(
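The net effect at every call site in this file: `apply_configuration` now takes only the template and the configuration, because the tile sizes travel inside `configuration.lowering_config` instead of being recomputed by per-dispatch helpers. A hedged sketch of the old versus new call shape, with the surrounding plumbing elided:

```python
# Old call shape, with a different pair of size helpers per dispatch kind:
#   modified += apply_configuration(
#       template, configuration,
#       get_mmt_workgroup_sizes(configuration),
#       get_mmt_reduction_sizes(configuration),
#   )

# New call shape, identical for mmt, conv, contraction, and batch variants:
modified += apply_configuration(template, configuration)
```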
36 changes: 18 additions & 18 deletions tuner/tuner/candidate_gen_test.py
@@ -106,15 +106,15 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:
'gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>, {llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}',
]

n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640
n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 16

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
lowering_config = common.get_lowering_config(
tuner_ctx=tuner_ctx,
mma_kind=mma_attr,
workgroup=[464, 320, 0],
reduction=[0, 0, 16],
workgroup=[n, oh, ow, oc, fh, fw, 0],
reduction=[0, 0, 0, 0, 0, 0, ic],
subgroup_m_count=1,
subgroup_n_count=4,
)
@@ -155,7 +155,7 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:
"LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64"
in modified
)
assert "workgroup = [1, 1, 464, 320, 1, 1, 0]" in modified
assert "workgroup = [2, 64, 64, 640, 3, 3, 0]" in modified
assert "reduction = [0, 0, 0, 0, 0, 0, 16]" in modified
assert (
"gpu_pipeline_options = #iree_gpu.pipeline_options<reorder_workgroups_strategy = <Transpose>>"
@@ -186,8 +186,8 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None:
lowering_config = common.get_lowering_config(
tuner_ctx=tuner_ctx,
mma_kind=mma_attr,
workgroup=[480, 384, 0],
reduction=[0, 0, 32],
workgroup=[1, 480, 384, 0],
reduction=[0, 0, 0, 32],
subgroup_m_count=1,
subgroup_n_count=4,
)
@@ -241,8 +241,8 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None:
lowering_config = common.get_lowering_config(
tuner_ctx=tuner_ctx,
mma_kind=mma_attr,
workgroup=[416, 320, 0],
reduction=[0, 0, 128],
workgroup=[1, 416, 320, 0],
reduction=[0, 0, 0, 128],
subgroup_m_count=2,
subgroup_n_count=2,
)
@@ -299,8 +299,8 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None:
lowering_config = common.get_lowering_config(
tuner_ctx=tuner_ctx,
mma_kind=mma_attr,
workgroup=[128, 64, 0],
reduction=[0, 0, 128],
workgroup=[1, 128, 64, 0],
reduction=[0, 0, 0, 128],
subgroup_m_count=2,
subgroup_n_count=2,
)
@@ -355,8 +355,8 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None:
lowering_config = common.get_lowering_config(
tuner_ctx=tuner_ctx,
mma_kind=mma_attr,
workgroup=[128, 64, 0],
reduction=[0, 0, 128],
workgroup=[1, 128, 64, 0],
reduction=[0, 0, 0, 128],
subgroup_m_count=2,
subgroup_n_count=2,
)
@@ -408,8 +408,8 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None:
"%config = transform.param.constant #iree_codegen.compilation_info<"
in embeddable
)
assert "workgroup = [128, 64, 0]" in embeddable
assert "reduction = [0, 0, 128]" in embeddable
assert "workgroup = [1, 128, 64, 0]" in embeddable
assert "reduction = [0, 0, 0, 128]" in embeddable
assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable
assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable

@@ -435,8 +435,8 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
lowering_config = common.get_lowering_config(
tuner_ctx=tuner_ctx,
mma_kind=mma_attr,
workgroup=[128, 64, 0],
reduction=[0, 0, 128],
workgroup=[1, 128, 64, 0],
reduction=[0, 0, 0, 128],
subgroup_m_count=2,
subgroup_n_count=2,
)
@@ -492,8 +492,8 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
"%config = transform.param.constant #iree_codegen.compilation_info<"
in embeddable
)
assert "workgroup = [128, 64, 0]" in embeddable
assert "reduction = [0, 0, 128]" in embeddable
assert "workgroup = [1, 128, 64, 0]" in embeddable
assert "reduction = [0, 0, 0, 128]" in embeddable
assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable
assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable

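The test updates all follow one convention: tile-size lists now span the op's full iteration space rather than a collapsed `[m, n, k]` triple, with `workgroup` and `reduction` acting as complements of each other. A short sketch of that convention using the contraction test's values:

```python
# Same rank as the op's iteration space (here batch, m, n, k). A dimension
# tiled at workgroup level holds 0 in `reduction`, and vice versa.
workgroup = [1, 480, 384, 0]   # batch, m, n tiled by the workgroup; k left 0
reduction = [0, 0, 0, 32]      # only the k dimension is tiled for reduction
```

The conv test does the same over all seven `conv_2d_nhwc_hwcf` dimensions, putting `[n, oh, ow, oc, fh, fw, 0]` in `workgroup` and `[0, 0, 0, 0, 0, 0, ic]` in `reduction`.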
34 changes: 0 additions & 34 deletions tuner/tuner/common.py
@@ -119,40 +119,6 @@ class Configuration:
waves_per_eu: int


def get_intrinsic(config: Configuration) -> Optional[iree_gpu.MMAAttr]:
if "mma_kind" in config.lowering_config.attributes:
return config.lowering_config.attributes["mma_kind"]
return None


def get_workgroup_tile_sizes(config: Configuration) -> list[int]:
if "workgroup" in config.lowering_config.attributes:
workgroup_attrs = config.lowering_config.attributes["workgroup"]
return [attr.value for attr in workgroup_attrs]
return []


def get_reduction_tile_sizes(config: Configuration) -> list[int]:
if "reduction" in config.lowering_config.attributes:
reduction_attrs = config.lowering_config.attributes["reduction"]
return [attr.value for attr in reduction_attrs]
return []


def get_subgroup_m_count(config: Configuration) -> Optional[int]:
if "subgroup_m_count" in config.lowering_config.attributes:
attr = config.lowering_config.attributes["subgroup_m_count"]
return attr.value
return None


def get_subgroup_n_count(config: Configuration) -> Optional[int]:
if "subgroup_n_count" in config.lowering_config.attributes:
attr = config.lowering_config.attributes["subgroup_n_count"]
return attr.value
return None


def get_lowering_config(
tuner_ctx: TunerContext,
**kwargs: Any,
5 changes: 2 additions & 3 deletions tuner/tuner/common_test.py
@@ -215,6 +215,5 @@ def test_get_lowering_config(tuner_ctx: common.TunerContext) -> None:
waves_per_eu=2,
)

assert common.get_intrinsic(config) is None
assert common.get_subgroup_m_count(config) == 1
assert common.get_subgroup_n_count(config) == 1
assert config.lowering_config.mma_kind is None
assert config.lowering_config.subgroup_count_mn == (1, 1)
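These assertions also pin down a behavioral detail that the deleted helpers used to provide by hand: `mma_kind` reports `None` rather than raising when no intrinsic is set, and `subgroup_count_mn` returns both counts as a single tuple. A minimal sketch of that contract, reusing this test's `config` (built without an `mma_kind`):

```python
# No MMA intrinsic was configured, so the property degrades to None.
assert config.lowering_config.mma_kind is None

# The (m, n) counts come back as one tuple, so both unpack in one line.
subgroup_m_count, subgroup_n_count = config.lowering_config.subgroup_count_mn
assert (subgroup_m_count, subgroup_n_count) == (1, 1)
```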
36 changes: 2 additions & 34 deletions tuner/tuner/dispatch_parser.py
@@ -20,18 +20,10 @@ def parse_tensor_type(tensor_type: str) -> ShapedType:
return ShapedType(shaped_ty.shape, shaped_ty.element_type)


def get_mmt_workgroup_sizes(configuration: Configuration):
return get_workgroup_tile_sizes(configuration)


def get_mmt_reduction_sizes(configuration: Configuration):
return get_reduction_tile_sizes(configuration)


def get_contract_workgroup_sizes(
configuration: Configuration, tile_dims: str
) -> list[int]:
m, n, _k = get_workgroup_tile_sizes(configuration)
m, n, _k = configuration.lowering_config.workgroup_tile_sizes

workgroup_size = [1] * len(tile_dims)
for idx, dim in enumerate(tile_dims):
@@ -48,7 +40,7 @@ def get_contract_workgroup_sizes(
def get_contract_reduction_sizes(
configuration: Configuration, tile_dims: str
) -> list[int]:
_m, _n, k = get_reduction_tile_sizes(configuration)
_m, _n, k = configuration.lowering_config.reduction_tile_sizes
reduction_size = [0] * len(tile_dims)
for idx, dim in enumerate(tile_dims):
if dim == "k":
@@ -57,14 +49,6 @@ def get_contract_reduction_sizes(
return reduction_size


def get_batch_mmt_workgroup_sizes(configuration: Configuration) -> list[int]:
return [1] + get_workgroup_tile_sizes(configuration)


def get_batch_mmt_reduction_sizes(configuration: Configuration) -> list[int]:
return [0] + get_reduction_tile_sizes(configuration)


class MlirRegex(Enum):
ssa_value = r"%[a-zA-Z0-9-_]+"
tensor_type = r"tensor<([^>]+)>"
@@ -164,22 +148,6 @@ class ConvParser(DispatchParser):
def supports(self, op_name: str) -> bool:
return "conv_2d_nhwc_hwcf" in op_name

def get_conv_workgroup_sizes(self, configuration: Configuration) -> list[int]:
batch = 1
fh = 1
fw = 1

oh = 1

ow, oc, _ic = get_workgroup_tile_sizes(configuration)

return [batch, oh, ow, oc, fh, fw, 0]

def get_conv_reduction_sizes(self, configuration: Configuration) -> list[int]:
_ow, _oc, ic = get_reduction_tile_sizes(configuration)

return [0, 0, 0, 0, 0, 0, ic]

def get_shapes(self, template: list[str]) -> ProblemSize:
for line in template:
if "linalg.conv_2d_nhwc_hwcf" not in line:
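Stepping back from this file's hunks: the two contract helpers that survive (`get_contract_workgroup_sizes`, `get_contract_reduction_sizes`) still scatter the rank-3 `[m, n, k]` tile sizes onto the dispatch's `tile_dims` string; only their data source changed to the binding's properties. A standalone sketch of the reduction side, which is fully visible in the hunk above (the workgroup side is assumed to scatter `m` and `n` analogously):

```python
def contract_reduction_sizes_sketch(
    reduction_tile_sizes: list[int], tile_dims: str
) -> list[int]:
    # Mirrors get_contract_reduction_sizes: only the 'k' slots of the
    # dimension string receive the reduction tile size; the rest stay 0.
    _m, _n, k = reduction_tile_sizes
    reduction_size = [0] * len(tile_dims)
    for idx, dim in enumerate(tile_dims):
        if dim == "k":
            reduction_size[idx] = k
    return reduction_size

# With a hypothetical batched contraction whose dims are "bmnk" and reduction
# tiles [0, 0, 32] (m, n, k), the k tile lands in the last slot:
assert contract_reduction_sizes_sketch([0, 0, 32], "bmnk") == [0, 0, 0, 32]
```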