diff --git a/benchmark-scheduled-unet-misa.sh b/benchmark-scheduled-unet-misa.sh deleted file mode 100755 index 189592f..0000000 --- a/benchmark-scheduled-unet-misa.sh +++ /dev/null @@ -1,21 +0,0 @@ -#!/bin/bash - -# Usage: PATH=/path/to/iree/build/tools:$PATH ./benchmark-unet.sh N - -set -xeu - -IRPA_PATH_PREFIX="${2:-/data/shark}" - -iree-benchmark-module \ - --device=rocm://$1 \ - --device_allocator=caching \ - --module="$PWD"/tmp/scheduled_unet_misa.vmfb \ - --parameters=model=${IRPA_PATH_PREFIX}/scheduled_unet.irpa \ - --function=run_forward \ - --input=1x4x128x128xf16 \ - --input=2x64x2048xf16 \ - --input=2x1280xf16 \ - --input=2x6xf16 \ - --input=1xf16 \ - --input=1xi64 \ - --benchmark_repetitions=3 diff --git a/compile-scheduled-unet-tk.sh b/compile-scheduled-unet-tk.sh index 7c5b116..e275966 100755 --- a/compile-scheduled-unet-tk.sh +++ b/compile-scheduled-unet-tk.sh @@ -1,6 +1,6 @@ #!/bin/bash -# Usage: PATH=/path/to/iree/build/tools:$PATH ./compile-scheduled-unet.sh [extra flags] +# Usage: PATH=/path/to/iree/build/tools:$PATH ./compile-scheduled-unet.sh [extra flags] set -euo pipefail diff --git a/compile-unet-base.sh b/compile-unet-base.sh index 9ded04b..9e3d315 100755 --- a/compile-unet-base.sh +++ b/compile-unet-base.sh @@ -19,10 +19,13 @@ readonly CHIP="$2" readonly MODE="$3" USE_WINOGRAD=0 USE_MISA=0 +USE_WINOGRAD_AND_MISA=0 if [[ $MODE =~ "winograd" ]] ; then USE_WINOGRAD=1 elif [[ $MODE =~ "misa" ]] ; then USE_MISA=1 +elif [[ $MODE =~ "hybrid" ]] ; then + USE_WINOGRAD_AND_MISA=1 fi readonly ATTENTION_SPEC="$(realpath "$4")" @@ -67,11 +70,21 @@ readonly MISA_FLAGS=( "--iree-preprocessing-transform-spec-filename=${SPEC_DIR}/misa_unet_spec.mlir" ) +readonly WINOGRAD_AND_MISA_FLAGS=( + "--iree-opt-const-expr-max-size-increase-threshold=1000000000000000" + "--iree-preprocessing-pass-pipeline=${WINOGRAD_PIPELINE[*]}" + "--iree-hal-executable-object-search-path=${SPEC_DIR}" + "--iree-preprocessing-transform-spec-filename=${SPEC_DIR}/misa_unet_spec.mlir" + "--iree-preprocessing-pass-pipeline=${WINOGRAD_PIPELINE[*]}" +) + declare -a FLAGS=("${DEFAULT_FLAGS[*]}") if [ "$USE_WINOGRAD" = 1 ] ; then FLAGS=("${WINOGRAD_FLAGS[@]}") elif [ "$USE_MISA" = 1 ] ; then FLAGS=("${MISA_FLAGS[@]}") +elif [ "$USE_WINOGRAD_AND_MISA" = 1 ] ; then + FLAGS=("${WINOGRAD_AND_MISA_FLAGS[@]}") fi set -x diff --git a/specs/misa_unet_spec.mlir b/specs/misa_unet_spec.mlir index e883d76..968481f 100644 --- a/specs/misa_unet_spec.mlir +++ b/specs/misa_unet_spec.mlir @@ -857,15 +857,16 @@ module attributes {transform.with_named_sequence} { transform.foreach %funcs : !transform.any_op { ^bb1(%func: !transform.any_op): transform.foreach_match in %func - @match_conv_k1 -> @cast_and_call_dag_k1, - @match_conv_k2 -> @cast_and_call_dag_k2, - @match_conv_k3 -> @cast_and_call_dag_k3, - @match_conv_k4 -> @cast_and_call_dag_k4, - @match_conv_k5 -> @cast_and_call_dag_k5, - @match_conv_k6 -> @cast_and_call_dag_k6, - @match_conv_k7 -> @cast_and_call_dag_k7, - @match_conv_k8 -> @cast_and_call_dag_k8, - @match_conv_k9 -> @cast_and_call_dag_k9 + // base - 63.5 ms + @match_conv_k1 -> @cast_and_call_dag_k1, // -1.4 ms + // @match_conv_k2 -> @cast_and_call_dag_k2, // -0.1 ms + // @match_conv_k3 -> @cast_and_call_dag_k3, // -0.1 ms + // @match_conv_k4 -> @cast_and_call_dag_k4, // 0 ms + // @match_conv_k5 -> @cast_and_call_dag_k5, // -0.1 ms + // @match_conv_k6 -> @cast_and_call_dag_k6, // -0.1 ms + @match_conv_k7 -> @cast_and_call_dag_k7 // -0.5 ms + // @match_conv_k8 -> @cast_and_call_dag_k8, // -0.1 ms + // @match_conv_k9 -> @cast_and_call_dag_k9 // -0.2 ms : (!transform.any_op) -> (!transform.any_op) } transform.apply_dce to %module : !transform.any_op diff --git a/specs/winograd_conv_spec.mlir b/specs/winograd_conv_spec.mlir index e8a7024..e555c45 100644 --- a/specs/winograd_conv_spec.mlir +++ b/specs/winograd_conv_spec.mlir @@ -453,24 +453,24 @@ module attributes { transform.with_named_sequence } { transform.named_sequence @__transform_main(%func: !transform.any_op {transform.consumed}) { - transform.foreach_match in %func // Base: 69.4ms // Best: 67.3ms + transform.foreach_match in %func // Base: 63.5ms // Best: 60.5ms // @match_conv2x4x128x128x3x3x320 -> @annotate_op, // fail to compile - @match_conv2x1280x64x64x3x3x1280 -> @annotate_op, // 68.8ms -0.6 - @match_conv2x640x128x128x3x3x320 -> @annotate_op, // 68.9ms -0.5 - @match_conv2x1920x64x64x3x3x640 -> @annotate_op, // 69.0ms -0.4 - @match_conv2x960x128x128x3x3x320 -> @annotate_op, // 69.1ms -0.3 - @match_conv2x640x128x128x3x3x640 -> @annotate_op, // 69.1ms -0.3 - @match_conv2x1280x64x64x3x3x640 -> @annotate_op, // 69.2ms -0.2 - @match_conv2x960x64x64x3x3x640 -> @annotate_op, // 69.2ms -0.2 - // @match_conv2x320x64x64x3x3x640 -> @annotate_op, // 69.5ms -0.1 - // @match_conv2x640x64x64x3x3x640 -> @annotate_op, // 69.3ms -0.1 - // @match_conv2x2560x32x32x3x3x1280 -> @annotate_op, // 69.3ms -0.1 - // @match_conv2x1920x32x32x3x3x1280 -> @annotate_op, // 69.4ms +0.0 - // @match_conv2x640x32x32x3x3x1280 -> @annotate_op, // 69.5ms +0.1 - // @match_conv2x320x128x128x3x3x320 -> @annotate_op, // 69.6ms +0.2 - // @match_conv2x1280x32x32x3x3x1280 -> @annotate_op, // 69.7ms +0.3 - // @match_conv2x320x128x128x3x3x4 -> @annotate_op, // 69.7ms +0.3 + @match_conv2x1280x64x64x3x3x1280 -> @annotate_op, // 63.0ms -0.5 + @match_conv2x640x128x128x3x3x320 -> @annotate_op, // 63.0ms -0.5 + @match_conv2x1920x64x64x3x3x640 -> @annotate_op, // 63.2ms -0.3 + @match_conv2x960x128x128x3x3x320 -> @annotate_op, // 63.1ms -0.4 + @match_conv2x640x128x128x3x3x640 -> @annotate_op, // 63.0ms -0.5 + @match_conv2x1280x64x64x3x3x640 -> @annotate_op, // 63.2ms -0.3 + @match_conv2x960x64x64x3x3x640 -> @annotate_op, // 63.2ms -0.3 + // @match_conv2x320x64x64x3x3x640 -> @annotate_op, + // @match_conv2x640x64x64x3x3x640 -> @annotate_op, + // @match_conv2x2560x32x32x3x3x1280 -> @annotate_op, + // @match_conv2x1920x32x32x3x3x1280 -> @annotate_op, + // @match_conv2x640x32x32x3x3x1280 -> @annotate_op, + // @match_conv2x320x128x128x3x3x320 -> @annotate_op, + // @match_conv2x1280x32x32x3x3x1280 -> @annotate_op, + // @match_conv2x320x128x128x3x3x4 -> @annotate_op, @placeholder -> @annotate_op : (!transform.any_op) -> (!transform.any_op) transform.yield diff --git a/tuning/autotune.py b/tuning/autotune.py index c57b66d..a7ef63a 100644 --- a/tuning/autotune.py +++ b/tuning/autotune.py @@ -66,7 +66,7 @@ def parse_arguments() -> argparse.Namespace: # Required arguments parser.add_argument( - "mode", choices=["default", "winograd"], help="Compilation mode" + "mode", choices=["default", "winograd", "misa", "hybrid"], help="Compilation mode" ) parser.add_argument( "input_file", type=Path, help="Path to the input benchmark file (.mlir)" @@ -543,7 +543,7 @@ def compile_unet_candidates( f"Hash value '{hash_val}' collided at candidate {indices}." ) unique_unet_candidates.append( - candidate_trackers[indices[0]].unet_candidate_path + candidate_trackers[indices[0]].unet_candidate_path # If collision occurs, use the first candidate index in the list ) return unique_unet_candidates if collision_detected else unet_candidates @@ -626,7 +626,7 @@ def main(): best_log = benchmark_top_candidates( args, base_dir, candidates_dir, compiled_files, candidate_trackers ) - print(f"Top candidates results are stored in {best_log}\n") + print(f"Top20 candidates results are stored in {best_log}\n") print("Compiling unet candidates...") unet_candidates = compile_unet_candidates( @@ -634,7 +634,7 @@ def main(): ) print(f"Unet candidates compiled in {base_dir}\n") - print("Bnechmarking unet candidates...") + print(f"Bnechmarking [{len(unet_candidates)}] unet candidates...") unet_result_log = benchmark_unet( args, base_dir, unet_candidates, candidate_trackers ) diff --git a/tuning/tune.py b/tuning/tune.py index 6517a6f..7756542 100755 --- a/tuning/tune.py +++ b/tuning/tune.py @@ -557,7 +557,9 @@ def generate_constraints( constraints += [k == intrinsic_mn * z3.FreshInt()] constraints += [k * n % (wg_x * wg_y * wg_z) == 0] constraints += [k * m % (wg_x * wg_y * wg_z) == 0] - constraints += [subgroup_m_count * subgroup_n_count == 4] + # constraints += [subgroup_m_count * subgroup_n_count == 4] # splat + # constraints += [subgroup_m_count * subgroup_n_count == 2] # real_weights + constraints += [subgroup_m_count * subgroup_n_count == 5] # conv constraints += [z3.Or(waves_per_eu == 1, waves_per_eu == 2, waves_per_eu == 4)]