Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tuner] Add tuner files #158

Merged
merged 22 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions tuner/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
2 changes: 1 addition & 1 deletion tuner/candidate_gen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"""

import pytest
import candidate_gen
from . import candidate_gen


def test_get_shaped_type_element_bitwidth():
Expand Down
5 changes: 5 additions & 0 deletions tuner/examples/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3 changes: 3 additions & 0 deletions tuner/examples/punet/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Test files/dirs recommended by README.md.
dump-mmt
test-benchmark.mlir
46 changes: 46 additions & 0 deletions tuner/examples/punet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Punet Tuner

## Environments
Follow instructions in [`/tuner/README.md`](../README.md)

## Shell Scripts

The required shell scripts can be downloaded from:
[sdxl-scripts](https://github.com/nod-ai/sdxl-scripts).

These scripts include:
1. `compile-punet-base.sh` - Used for compiling model candidates.
2. `compile_candidate.sh` - Used for compiling dispatch candidates.
3. `punet.sh` - Invoked by `compile_candidate.sh`.

Add the parent directories of these scripts to your `PATH` environment variable,
so that they can be picked up by `punet_autotune.py`.

## Running the Tuner

### [Optional] Generate a tunable mlir
Use
[`punet.sh`](https://github.com/nod-ai/sdxl-scripts/blob/main/tuning/punet.sh)
to compile the sample matmul `mmt.mlir` (can also find here:
[`mmt_unet.mlir`](https://github.com/nod-ai/sdxl-scripts/blob/main/tuning/mmt_unet.mlir)):
```shell
punet.sh mmt.mlir -o mmt.vmfb --iree-hal-dump-executable-files-to=dump-mmt
cp ./dump-mmt/module_main_0_dispatch_0_rocm_hsaco_fb_benchmark.mlir test-benchmark.mlir
```

### Recommended Trial Run
For an initial trial to test the tuning loop, use:
```shell
python -m tuner.examples.punet.punet_autotune test-benchmark.mlir --num-candidates=10
```

### Dry Run Test
RattataKing marked this conversation as resolved.
Show resolved Hide resolved
To perform a dry run (no GPU required), use:
```shell
python -m tuner.examples.punet.punet_autotune test-benchmark.mlir --num-candidates=64 --num-model-candidates=10 --dry-run
```

### Basic Usage
```shell
python -m tuner.examples.punet.punet_autotune test-benchmark.mlir
```
5 changes: 5 additions & 0 deletions tuner/examples/punet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11 changes: 11 additions & 0 deletions tuner/examples/punet/mmt.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Type aliases for the sample matmul: f16 operands with an f32 result.
// A is 2048x1280, B is 10240x1280 (stored transposed), C is 2048x10240.
!matA_0 = tensor<2048x1280xf16>
!matB_0 = tensor<10240x1280xf16>
!matC_0 = tensor<2048x10240xf32>

// Computes C = A * B^T into a zero-filled accumulator.
func.func @main_0(%arg0: !matA_0, %arg1: !matB_0) -> !matC_0 {
// Zero constant used to initialize the accumulator via linalg.fill.
%cst = arith.constant 0.000000e+00 : f16
%5 = tensor.empty() : !matC_0
%6 = linalg.fill ins(%cst : f16) outs(%5 : !matC_0) -> !matC_0
// B is passed in row-major transposed layout, hence matmul_transpose_b.
%8 = linalg.matmul_transpose_b ins(%arg0, %arg1 : !matA_0, !matB_0) outs(%6 : !matC_0) -> !matC_0
return %8 : !matC_0
}
191 changes: 191 additions & 0 deletions tuner/examples/punet/punet_autotune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""
Sample Usage:

python -m tuner.examples.punet.punet_autotune benchmark.mlir --lhs-dims=bmk --rhs-dims=bkn --tile-dims=*mnk --devices=hip://0,hip://1 --num-candidates=64


Recommended Trial Run:

python -m tuner.examples.punet.punet_autotune benchmark.mlir --num-candidates=1


Dry Run Test (no GPU required):

python -m tuner.examples.punet.punet_autotune benchmark.mlir --num-candidates=64 --num-model-candidates=10 --dry-run

"""

from ... import libtuner
from pathlib import Path


class PunetClient(libtuner.TuningClient):
    """Punet-specific tuning client.

    Supplies the shell commands and timeouts that the generic libtuner
    driver uses to compile and benchmark dispatch and model candidates.
    The compile steps shell out to scripts from the sdxl-scripts repo
    (expected on PATH); the benchmark steps invoke iree-benchmark-module.
    """

    def get_dispatch_compile_timeout_s(self) -> int:
        # Individual dispatch compilations are quick; time out fast.
        return 4

    def get_dispatch_compile_command(
        self, candidate_tracker: libtuner.CandidateTracker
    ) -> list[str]:
        """Build the command that compiles one dispatch candidate."""
        dispatch_mlir = candidate_tracker.dispatch_mlir_path
        assert dispatch_mlir is not None
        return ["compile_candidate.sh", dispatch_mlir.as_posix()]

    def get_dispatch_benchmark_timeout_s(self) -> int:
        return 15

    def get_dispatch_benchmark_command(
        self,
        candidate_tracker: libtuner.CandidateTracker,
    ) -> list[str]:
        """Build the iree-benchmark-module command for one compiled dispatch.

        The device flag is a placeholder that libtuner substitutes with a
        concrete device id; results land in a per-candidate JSON file.
        """
        vmfb = candidate_tracker.compiled_dispatch_path
        assert vmfb is not None

        return [
            "iree-benchmark-module",
            f"--device={libtuner.DEVICE_ID_PLACEHOLDER}",
            f"--module={vmfb.resolve()}",
            "--hip_use_streams=true",
            "--hip_allow_inline_execution=true",
            "--batch_size=1000",
            "--benchmark_repetitions=3",
            f"--benchmark_out=dispatch_{candidate_tracker.candidate_id}_bm.json",
            "--benchmark_out_format=json",
        ]

    def get_model_compile_timeout_s(self) -> int:
        # Whole-model compilation is much slower than a single dispatch.
        return 300

    def get_model_compile_command(
        self, candidate_tracker: libtuner.CandidateTracker
    ) -> list[str]:
        """Build the command that compiles the full model with one spec."""
        spec = candidate_tracker.spec_path
        assert spec is not None
        # Place the output .vmfb three directory levels above the spec file.
        out_dir = spec.resolve().parent.parent.parent
        vmfb_name = f"unet_candidate_{candidate_tracker.candidate_id}.vmfb"
        return [
            "compile-punet-base.sh",
            "iree-compile",
            "gfx942",
            f"{spec.resolve()}",
            "./punet.mlir",
            "-o",
            (out_dir / vmfb_name).as_posix(),
        ]

    def get_model_benchmark_timeout_s(self) -> int:
        return 180

    def get_model_benchmark_command(
        self, candidate_tracker: libtuner.CandidateTracker
    ) -> list[str]:
        """Build the iree-benchmark-module command for one compiled model.

        Inputs are the fixed punet `main` signature; weights come from
        punet.irpa in the working directory.
        """
        model_vmfb = candidate_tracker.compiled_model_path
        assert model_vmfb is not None

        return [
            "iree-benchmark-module",
            f"--device={libtuner.DEVICE_ID_PLACEHOLDER}",
            "--hip_use_streams=true",
            "--hip_allow_inline_execution=true",
            "--device_allocator=caching",
            f"--module={model_vmfb.resolve()}",
            "--parameters=model=punet.irpa",
            "--function=main",
            "--input=1x4x128x128xf16",
            "--input=1xsi32",
            "--input=2x64x2048xf16",
            "--input=2x1280xf16",
            "--input=2x6xf16",
            "--input=1xf16",
            "--benchmark_repetitions=5",
            f"--benchmark_out=model_{candidate_tracker.candidate_id}_bm.json",
            "--benchmark_out_format=json",
        ]


def main():
    """Run the punet tuning pipeline end to end.

    Phases, in order: candidate generation, dispatch compilation,
    dispatch benchmarking, model compilation, model benchmarking,
    then summary/pickle of the tracked candidates. The `--stop-after`
    flag ends the run after the named phase.
    """
    args = libtuner.parse_arguments()
    path_config = libtuner.PathConfig()
    path_config.base_dir.mkdir(parents=True, exist_ok=True)
    path_config.output_unilog.touch()
    candidate_trackers: list[libtuner.CandidateTracker] = []
    punet_client = PunetClient()
    stop_after_phase: str = args.stop_after

    print("Setup logging")
    libtuner.setup_logging(args, path_config)
    print(path_config.run_log, end="\n\n")

    # Device validation requires real GPUs, so skip it on --dry-run.
    if not args.dry_run:
        print("Validating devices")
        libtuner.validate_devices(args.devices)
        print("Validation successful!\n")

    print("Generating candidates...")
    candidates = libtuner.generate_candidates(args, path_config, candidate_trackers)
    print(f"Stored candidates in {path_config.candidates_dir}\n")
    if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
        return

    print("Compiling candidates...")
    compiled_candidates = libtuner.compile_dispatches(
        args, path_config, candidates, candidate_trackers, punet_client
    )
    print(f"Compiled files are stored in {path_config.compiled_dir}\n")
    if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches:
        return

    print("Benchmarking compiled candidates...")
    top_candidates = libtuner.benchmark_dispatches(
        args, path_config, compiled_candidates, candidate_trackers, punet_client
    )
    print(f"Stored results in {path_config.output_unilog}\n")
    if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches:
        return

    print("Compiling top model candidates...")
    punet_candidates = libtuner.compile_models(
        args, path_config, top_candidates, candidate_trackers, punet_client
    )
    print(f"Model candidates compiled in {path_config.base_dir}\n")
    if stop_after_phase == libtuner.ExecutionPhases.compile_models:
        return

    print("Benchmarking model candidates...")
    libtuner.benchmark_models(
        args, path_config, punet_candidates, candidate_trackers, punet_client
    )
    print(f"Stored results in {path_config.output_unilog}")
    if stop_after_phase == libtuner.ExecutionPhases.benchmark_models:
        return

    # NOTE(review): "summerize" is libtuner's API spelling; renaming must
    # happen there, not here.
    libtuner.summerize_top_candidates(path_config, candidate_trackers)
    print(f"Stored top candidates info in {path_config.result_summary_log}\n")

    libtuner.save_pickle(path_config.candidate_trackers_pkl, candidate_trackers)
    print(f"Candidate trackers are saved in {path_config.candidate_trackers_pkl}\n")

    print("Check the detailed execution logs in:")
    print(path_config.run_log)

    # Dump every tracker to the debug log; echo to stdout when verbose.
    for candidate in candidate_trackers:
        libtuner.logging.debug(candidate)
        if args.verbose:
            print(candidate)


if __name__ == "__main__":
    main()
Loading
Loading