run pre-commit
Signed-off-by: Max Dawkins <[email protected]>
Max191 committed Dec 12, 2024
1 parent 119833b commit abd2a23
Showing 8 changed files with 77 additions and 44 deletions.
32 changes: 32 additions & 0 deletions tuner/examples/test/README_.md
@@ -0,0 +1,32 @@
# Example Tuner Test

An example of tuning a dispatch and a full model.

## Environments
Follow the instructions in [`/tuner/README.md`](../README.md) to set up the tuner environment.
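
As a rough illustration, setup typically amounts to creating a Python virtual environment and installing the tuner requirements. This is a hypothetical sketch only; the requirements file name is an assumption, so defer to the linked README:
```shell
# Hypothetical sketch; follow ../README.md for the authoritative steps.
python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt  # the exact requirements file name is an assumption
```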

## Running the Tuner

### Choose a model to tune
This example uses the simple `double_mmt.mlir` file.

### Generate a benchmark file
Use the usual `iree-compile` command for your model and add
`--iree-hal-dump-executable-files-to=dump`. For example:
```shell
iree-compile double_mmt.mlir --iree-hal-target-backends=rocm --iree-hip-target=gfx942 --iree-hal-dump-executable-files-to=dump -o /dev/null
```

Next, copy the `*_benchmark.mlir` file from the `dump` directory to a temporary directory of your choice.
This file will be the input to the dispatch tuner. In this example, the provided `mmt_benchmark.mlir` file (generated from `double_mmt.mlir`) can be used.
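
For instance, the copy step might look like the following sketch (the dispatch file name inside `dump` is an assumption and will vary with your model):
```shell
# Assumes a single *_benchmark.mlir was produced; adjust the paths for your model.
cp dump/*_benchmark.mlir ./mmt_benchmark.mlir
```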

### Recommended Trial Run
For an initial trial to test the tuning loop, use:
```shell
python -m examples.test double_mmt.mlir mmt_benchmark.mlir --num-candidates=20
```

### Basic Usage
```shell
python -m examples.test double_mmt.mlir mmt_benchmark.mlir
```
6 changes: 3 additions & 3 deletions tuner/examples/test/conv_benchmark.mlir
@@ -4,7 +4,7 @@ module {
hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>) {
hal.executable.export public @main_0_dispatch_0_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280_i8xi8xi32 ordinal(0) layout(#hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
@@ -52,8 +52,8 @@ module {
%ordinal = hal.executable.export.ordinal target(@main_0_dispatch_0::@rocm_hsaco_fb::@main_0_dispatch_0_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280_i8xi8xi32) : index
scf.for %arg1 = %c0 to %1 step %c1 {
hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%exe : !hal.executable)[%ordinal] workgroups([%workgroup_x, %workgroup_y, %workgroup_z]) bindings([
(%main_0_dispatch_0_rocm_hsaco_fb_main_0_dispatch_0_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280_i8xi8xi32_buffer : !hal.buffer)[%c0, %c2959360],
(%main_0_dispatch_0_rocm_hsaco_fb_main_0_dispatch_0_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280_i8xi8xi32_buffer : !hal.buffer)[%c2959360, %c14745600],
(%main_0_dispatch_0_rocm_hsaco_fb_main_0_dispatch_0_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280_i8xi8xi32_buffer : !hal.buffer)[%c0, %c2959360],
(%main_0_dispatch_0_rocm_hsaco_fb_main_0_dispatch_0_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280_i8xi8xi32_buffer : !hal.buffer)[%c2959360, %c14745600],
(%main_0_dispatch_0_rocm_hsaco_fb_main_0_dispatch_0_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280_i8xi8xi32_buffer : !hal.buffer)[%c17704960, %c10485760]
]) flags("None")
hal.command_buffer.execution_barrier<%cmd : !hal.command_buffer> source("Dispatch|CommandRetire") target("CommandIssue|Dispatch") flags("None")
6 changes: 3 additions & 3 deletions tuner/examples/test/mmt_benchmark.mlir
@@ -4,7 +4,7 @@ module {
hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>) {
hal.executable.export public @main_dispatch_0_matmul_transpose_b_2048x2048x2048_f16xf16xf32 ordinal(0) layout(#hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
@@ -57,8 +57,8 @@ module {
%ordinal = hal.executable.export.ordinal target(@main_dispatch_0::@rocm_hsaco_fb::@main_dispatch_0_matmul_transpose_b_2048x2048x2048_f16xf16xf32) : index
scf.for %arg1 = %c0 to %1 step %c1 {
hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%exe : !hal.executable)[%ordinal] workgroups([%workgroup_x, %workgroup_y, %workgroup_z]) bindings([
(%main_dispatch_0_rocm_hsaco_fb_main_dispatch_0_matmul_transpose_b_2048x2048x2048_f16xf16xf32_buffer : !hal.buffer)[%c0, %c8388608],
(%main_dispatch_0_rocm_hsaco_fb_main_dispatch_0_matmul_transpose_b_2048x2048x2048_f16xf16xf32_buffer : !hal.buffer)[%c8388608, %c8388608],
(%main_dispatch_0_rocm_hsaco_fb_main_dispatch_0_matmul_transpose_b_2048x2048x2048_f16xf16xf32_buffer : !hal.buffer)[%c0, %c8388608],
(%main_dispatch_0_rocm_hsaco_fb_main_dispatch_0_matmul_transpose_b_2048x2048x2048_f16xf16xf32_buffer : !hal.buffer)[%c8388608, %c8388608],
(%main_dispatch_0_rocm_hsaco_fb_main_dispatch_0_matmul_transpose_b_2048x2048x2048_f16xf16xf32_buffer : !hal.buffer)[%c16777216, %c16777216]
]) flags("None")
hal.command_buffer.execution_barrier<%cmd : !hal.command_buffer> source("Dispatch|CommandRetire") target("CommandIssue|Dispatch") flags("None")
13 changes: 7 additions & 6 deletions tuner/tuner/candidate_gen.py
@@ -140,7 +140,6 @@ def find_handler(self, op_name: str) -> DispatchTuner:
assert False, "Dispatch kind not supported"



class ContractionOpInterfaceTuner(DispatchTuner, ContractionOpInterfaceParser):
def apply_params(
self,
@@ -184,7 +183,9 @@ def get_td_spec(
compilation_info: iree_codegen.CompilationInfoAttr,
) -> ir.Module:
conv_op: ir.Operation = self.get_conv_operation(ir_module)
assert conv_op.name == "linalg.conv_2d_nhwc_hwcf", "expected linalg.conv_2d_nhwc_hwcf"
assert (
conv_op.name == "linalg.conv_2d_nhwc_hwcf"
), "expected linalg.conv_2d_nhwc_hwcf"
lhs_type = ir.ShapedType(conv_op.operands[0].type)
rhs_type = ir.ShapedType(conv_op.operands[1].type)
acc_type = ir.ShapedType(conv_op.operands[2].type)
@@ -198,9 +199,7 @@ def get_td_spec(
conv_type = conv_op.name.split(".")[-1]
# TODO(Max191): Get the function name from the func.func in the input module.
func_name = f"match_{conv_type}_{N}x{H}x{W}x{C}x{P}x{Q}x{F}_{lhs_type.element_type}x{rhs_type.element_type}x{acc_type.element_type}"
return build_td_spec(
ir_module.context, conv_op, compilation_info, func_name
)
return build_td_spec(ir_module.context, conv_op, compilation_info, func_name)


class MmtTuner(DispatchTuner, MmtParser):
@@ -654,7 +653,9 @@ def generate_configs_and_td_specs(

dispatch_tuner = walk_result.dispatch_tuner
assert dispatch_tuner, "No suitable dispatch tuner found"
problem_size: ProblemSize = dispatch_tuner.get_shapes(str(input_module).splitlines())
problem_size: ProblemSize = dispatch_tuner.get_shapes(
str(input_module).splitlines()
)
tune_logger.debug(str(problem_size))

# Index 0 is reserved for default config, so it gets no td spec.
26 changes: 15 additions & 11 deletions tuner/tuner/candidate_gen_test.py
@@ -93,18 +93,20 @@ def test_get_td_spec_contraction(tuner_ctx: common.TunerContext) -> None:
tuner = candidate_gen.ContractionOpInterfaceTuner()
td_spec_module = tuner.get_td_spec(ir_module, compilation_info)
assert td_spec_module

named_sequence_ops: list[transform.NamedSequenceOp] = op_matchers.get_ops_from_module(
module=td_spec_module,
fn=lambda op : isinstance(op.opview, transform.NamedSequenceOp)

named_sequence_ops: list[transform.NamedSequenceOp] = (
op_matchers.get_ops_from_module(
module=td_spec_module,
fn=lambda op: isinstance(op.opview, transform.NamedSequenceOp),
)
)
apply_config_sequence = None
matcher_sequence = None
entry_point = None
for op in named_sequence_ops:
if str(op.opview.sym_name) == "\"apply_op_config\"":
if str(op.opview.sym_name) == '"apply_op_config"':
apply_config_sequence = op
elif str(op.opview.sym_name) == "\"__kernel_config\"":
elif str(op.opview.sym_name) == '"__kernel_config"':
entry_point = op
else:
matcher_sequence = op
@@ -174,17 +176,19 @@ def test_get_td_spec_convolution(tuner_ctx: common.TunerContext) -> None:
td_spec_module = tuner.get_td_spec(ir_module, compilation_info)
assert td_spec_module

named_sequence_ops: list[transform.NamedSequenceOp] = op_matchers.get_ops_from_module(
module=td_spec_module,
fn=lambda op : isinstance(op.opview, transform.NamedSequenceOp)
named_sequence_ops: list[transform.NamedSequenceOp] = (
op_matchers.get_ops_from_module(
module=td_spec_module,
fn=lambda op: isinstance(op.opview, transform.NamedSequenceOp),
)
)
apply_config_sequence = None
matcher_sequence = None
entry_point = None
for op in named_sequence_ops:
if str(op.opview.sym_name) == "\"apply_op_config\"":
if str(op.opview.sym_name) == '"apply_op_config"':
apply_config_sequence = op
elif str(op.opview.sym_name) == "\"__kernel_config\"":
elif str(op.opview.sym_name) == '"__kernel_config"':
entry_point = op
else:
matcher_sequence = op
7 changes: 3 additions & 4 deletions tuner/tuner/common.py
@@ -109,10 +109,9 @@ def is_comptible(mma_intrinsic: iree_gpu.MMAIntrinsic) -> bool:
if not isinstance(problem_size.res_type.element_type, type(c_type)):
return False
if problem_size.dispatch_kind != DispatchKind.batch_matmul:
if (
not isinstance(problem_size.lhs_type.element_type, type(a_type))
or not isinstance(problem_size.rhs_type.element_type, type(b_type))
):
if not isinstance(
problem_size.lhs_type.element_type, type(a_type)
) or not isinstance(problem_size.rhs_type.element_type, type(b_type)):
return False
return True

12 changes: 5 additions & 7 deletions tuner/tuner/dispatch_parser.py
@@ -95,10 +95,10 @@ def get_shapes(self, template: list[str]) -> ProblemSize:
class ContractionOpInterfaceParser(DispatchParser):
def supports(self, op_name: str) -> bool:
return (
"matmul_like" in op_name or
"batch_matmul" in op_name or
"batch_matmul_transpose_b" in op_name or
"matmul_transpose_b" in op_name
"matmul_like" in op_name
or "batch_matmul" in op_name
or "batch_matmul_transpose_b" in op_name
or "matmul_transpose_b" in op_name
)

def get_contraction_operation(
@@ -140,9 +140,7 @@ def get_shapes(self, template: list[str]) -> ProblemSize:
# TODO(Max191): Support more convolution types. Only NHWC convs are supported.
class ConvolutionOpInterfaceParser(DispatchParser):
def __init__(self):
self.supported_ops = [
"linalg.conv_2d_nhwc_hwcf"
]
self.supported_ops = ["linalg.conv_2d_nhwc_hwcf"]

def supports(self, op_name: str) -> bool:
for supported_op_name in self.supported_ops:
19 changes: 9 additions & 10 deletions tuner/tuner/libtuner.py
@@ -1017,10 +1017,9 @@ def parse_dispatch_benchmark_results(
benchmark_time = res.get_mean_time_us()
assert benchmark_time is not None
candidate_trackers[candidate_id].first_benchmark_time = benchmark_time
candidate_trackers[
candidate_id
].spec_path = path_config.specs_dir / path_config.get_candidate_spec_filename(
candidate_id
candidate_trackers[candidate_id].spec_path = (
path_config.specs_dir
/ path_config.get_candidate_spec_filename(candidate_id)
)
mlir_path = candidate_trackers[candidate_id].dispatch_mlir_path
spec_path = candidate_trackers[candidate_id].spec_path
@@ -1284,9 +1283,9 @@ def parse_model_benchmark_results(
]

dump_list = []
incomplete_list: list[
tuple[int, Optional[str]]
] = [] # format: [(candidate_id, device_id)]
incomplete_list: list[tuple[int, Optional[str]]] = (
[]
) # format: [(candidate_id, device_id)]

baseline_time = None
for same_device_results in grouped_benchmark_results:
@@ -1338,9 +1337,9 @@
calibrated_benchmark_diff = (
benchmark_time - baseline_time
) / baseline_time
candidate_trackers[
candidate_id
].calibrated_benchmark_diff = calibrated_benchmark_diff
candidate_trackers[candidate_id].calibrated_benchmark_diff = (
calibrated_benchmark_diff
)
else:
calibrated_benchmark_diff = None

