Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Combine MISA and Winograd configs #27

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 0 additions & 21 deletions benchmark-scheduled-unet-misa.sh

This file was deleted.

2 changes: 1 addition & 1 deletion compile-scheduled-unet-tk.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash

# Usage: PATH=/path/to/iree/build/tools:$PATH ./compile-scheduled-unet.sh <target-chip> <default|winograd|misa> [extra flags]
# Usage: PATH=/path/to/iree/build/tools:$PATH ./compile-scheduled-unet-tk.sh <target-chip> <default|winograd|misa|hybrid> [extra flags]

set -euo pipefail

Expand Down
13 changes: 13 additions & 0 deletions compile-unet-base.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@ readonly CHIP="$2"
# The third positional argument selects the compilation mode. The patterns are
# substring matches tried in order (winograd, then misa, then hybrid); the
# first one found in MODE sets its switch, and no match leaves all switches 0.
readonly MODE="$3"
USE_WINOGRAD=0
USE_MISA=0
USE_WINOGRAD_AND_MISA=0
case "$MODE" in
  *winograd*) USE_WINOGRAD=1 ;;
  *misa*) USE_MISA=1 ;;
  *hybrid*) USE_WINOGRAD_AND_MISA=1 ;;
esac

readonly ATTENTION_SPEC="$(realpath "$4")"
Expand Down Expand Up @@ -67,11 +70,21 @@ readonly MISA_FLAGS=(
"--iree-preprocessing-transform-spec-filename=${SPEC_DIR}/misa_unet_spec.mlir"
)

# Flags used when both Winograd and MISA are enabled ("hybrid" mode): the
# Winograd preprocessing pass pipeline together with the MISA transform spec
# and the search path for its executable objects. The pass-pipeline flag is
# listed once; repeating the identical flag adds nothing.
readonly WINOGRAD_AND_MISA_FLAGS=(
  "--iree-opt-const-expr-max-size-increase-threshold=1000000000000000"
  "--iree-preprocessing-pass-pipeline=${WINOGRAD_PIPELINE[*]}"
  "--iree-hal-executable-object-search-path=${SPEC_DIR}"
  "--iree-preprocessing-transform-spec-filename=${SPEC_DIR}/misa_unet_spec.mlir"
)

# Pick the flag set for the selected mode, falling back to the default flags
# when no mode switch is set. Every array is copied with "[@]" so each flag
# stays its own element; "[*]" would join them into one space-separated word.
declare -a FLAGS=("${DEFAULT_FLAGS[@]}")
if [ "$USE_WINOGRAD" = 1 ] ; then
  FLAGS=("${WINOGRAD_FLAGS[@]}")
elif [ "$USE_MISA" = 1 ] ; then
  FLAGS=("${MISA_FLAGS[@]}")
elif [ "$USE_WINOGRAD_AND_MISA" = 1 ] ; then
  FLAGS=("${WINOGRAD_AND_MISA_FLAGS[@]}")
fi

set -x
Expand Down
19 changes: 10 additions & 9 deletions specs/misa_unet_spec.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -857,15 +857,16 @@ module attributes {transform.with_named_sequence} {
transform.foreach %funcs : !transform.any_op {
^bb1(%func: !transform.any_op):
transform.foreach_match in %func
@match_conv_k1 -> @cast_and_call_dag_k1,
@match_conv_k2 -> @cast_and_call_dag_k2,
@match_conv_k3 -> @cast_and_call_dag_k3,
@match_conv_k4 -> @cast_and_call_dag_k4,
@match_conv_k5 -> @cast_and_call_dag_k5,
@match_conv_k6 -> @cast_and_call_dag_k6,
@match_conv_k7 -> @cast_and_call_dag_k7,
@match_conv_k8 -> @cast_and_call_dag_k8,
@match_conv_k9 -> @cast_and_call_dag_k9
// base - 63.5 ms
@match_conv_k1 -> @cast_and_call_dag_k1, // -1.4 ms
// @match_conv_k2 -> @cast_and_call_dag_k2, // -0.1 ms
// @match_conv_k3 -> @cast_and_call_dag_k3, // -0.1 ms
// @match_conv_k4 -> @cast_and_call_dag_k4, // 0 ms
// @match_conv_k5 -> @cast_and_call_dag_k5, // -0.1 ms
// @match_conv_k6 -> @cast_and_call_dag_k6, // -0.1 ms
@match_conv_k7 -> @cast_and_call_dag_k7 // -0.5 ms
// @match_conv_k8 -> @cast_and_call_dag_k8, // -0.1 ms
// @match_conv_k9 -> @cast_and_call_dag_k9 // -0.2 ms
: (!transform.any_op) -> (!transform.any_op)
}
transform.apply_dce to %module : !transform.any_op
Expand Down
32 changes: 16 additions & 16 deletions specs/winograd_conv_spec.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -453,24 +453,24 @@ module attributes { transform.with_named_sequence } {


transform.named_sequence @__transform_main(%func: !transform.any_op {transform.consumed}) {
transform.foreach_match in %func // Base: 69.4ms // Best: 67.3ms
transform.foreach_match in %func // Base: 63.5ms // Best: 60.5ms
// @match_conv2x4x128x128x3x3x320 -> @annotate_op, // fail to compile

@match_conv2x1280x64x64x3x3x1280 -> @annotate_op, // 68.8ms -0.6
@match_conv2x640x128x128x3x3x320 -> @annotate_op, // 68.9ms -0.5
@match_conv2x1920x64x64x3x3x640 -> @annotate_op, // 69.0ms -0.4
@match_conv2x960x128x128x3x3x320 -> @annotate_op, // 69.1ms -0.3
@match_conv2x640x128x128x3x3x640 -> @annotate_op, // 69.1ms -0.3
@match_conv2x1280x64x64x3x3x640 -> @annotate_op, // 69.2ms -0.2
@match_conv2x960x64x64x3x3x640 -> @annotate_op, // 69.2ms -0.2
// @match_conv2x320x64x64x3x3x640 -> @annotate_op, // 69.5ms -0.1
// @match_conv2x640x64x64x3x3x640 -> @annotate_op, // 69.3ms -0.1
// @match_conv2x2560x32x32x3x3x1280 -> @annotate_op, // 69.3ms -0.1
// @match_conv2x1920x32x32x3x3x1280 -> @annotate_op, // 69.4ms +0.0
// @match_conv2x640x32x32x3x3x1280 -> @annotate_op, // 69.5ms +0.1
// @match_conv2x320x128x128x3x3x320 -> @annotate_op, // 69.6ms +0.2
// @match_conv2x1280x32x32x3x3x1280 -> @annotate_op, // 69.7ms +0.3
// @match_conv2x320x128x128x3x3x4 -> @annotate_op, // 69.7ms +0.3
@match_conv2x1280x64x64x3x3x1280 -> @annotate_op, // 63.0ms -0.5
@match_conv2x640x128x128x3x3x320 -> @annotate_op, // 63.0ms -0.5
@match_conv2x1920x64x64x3x3x640 -> @annotate_op, // 63.2ms -0.3
@match_conv2x960x128x128x3x3x320 -> @annotate_op, // 63.1ms -0.4
@match_conv2x640x128x128x3x3x640 -> @annotate_op, // 63.0ms -0.5
@match_conv2x1280x64x64x3x3x640 -> @annotate_op, // 63.2ms -0.3
@match_conv2x960x64x64x3x3x640 -> @annotate_op, // 63.2ms -0.3
// @match_conv2x320x64x64x3x3x640 -> @annotate_op,
// @match_conv2x640x64x64x3x3x640 -> @annotate_op,
// @match_conv2x2560x32x32x3x3x1280 -> @annotate_op,
// @match_conv2x1920x32x32x3x3x1280 -> @annotate_op,
// @match_conv2x640x32x32x3x3x1280 -> @annotate_op,
// @match_conv2x320x128x128x3x3x320 -> @annotate_op,
// @match_conv2x1280x32x32x3x3x1280 -> @annotate_op,
// @match_conv2x320x128x128x3x3x4 -> @annotate_op,
@placeholder -> @annotate_op
: (!transform.any_op) -> (!transform.any_op)
transform.yield
Expand Down
8 changes: 4 additions & 4 deletions tuning/autotune.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def parse_arguments() -> argparse.Namespace:

# Required arguments
parser.add_argument(
"mode", choices=["default", "winograd"], help="Compilation mode"
"mode", choices=["default", "winograd", "misa", "hybrid"], help="Compilation mode"
)
parser.add_argument(
"input_file", type=Path, help="Path to the input benchmark file (.mlir)"
Expand Down Expand Up @@ -543,7 +543,7 @@ def compile_unet_candidates(
f"Hash value '{hash_val}' collided at candidate {indices}."
)
unique_unet_candidates.append(
candidate_trackers[indices[0]].unet_candidate_path
candidate_trackers[indices[0]].unet_candidate_path # If collision occurs, use the first candidate index in the list
)

return unique_unet_candidates if collision_detected else unet_candidates
Expand Down Expand Up @@ -626,15 +626,15 @@ def main():
best_log = benchmark_top_candidates(
args, base_dir, candidates_dir, compiled_files, candidate_trackers
)
print(f"Top candidates results are stored in {best_log}\n")
print(f"Top20 candidates results are stored in {best_log}\n")

print("Compiling unet candidates...")
unet_candidates = compile_unet_candidates(
args, base_dir, best_log, candidate_trackers
)
print(f"Unet candidates compiled in {base_dir}\n")

print("Bnechmarking unet candidates...")
print(f"Bnechmarking [{len(unet_candidates)}] unet candidates...")
unet_result_log = benchmark_unet(
args, base_dir, unet_candidates, candidate_trackers
)
Expand Down
4 changes: 3 additions & 1 deletion tuning/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,9 @@ def generate_constraints(
constraints += [k == intrinsic_mn * z3.FreshInt()]
constraints += [k * n % (wg_x * wg_y * wg_z) == 0]
constraints += [k * m % (wg_x * wg_y * wg_z) == 0]
constraints += [subgroup_m_count * subgroup_n_count == 4]
# constraints += [subgroup_m_count * subgroup_n_count == 4] # splat
# constraints += [subgroup_m_count * subgroup_n_count == 2] # real_weights
constraints += [subgroup_m_count * subgroup_n_count == 5] # conv

constraints += [z3.Or(waves_per_eu == 1, waves_per_eu == 2, waves_per_eu == 4)]

Expand Down