Skip to content

Commit

Permalink
[TKW] Add create_vmfb_file option (#300)
Browse files Browse the repository at this point in the history
Add a `create_vmfb_file` option (a destination path) that runs the full IREE
compilation pipeline and writes the resulting vmfb file to that path, without
actually invoking the kernel.

This is a quick hack to get iree-kernel-benchmark going until we have a
proper AOT compilation API.

---------

Signed-off-by: Ivan Butygin <[email protected]>
  • Loading branch information
Hardcode84 authored Dec 3, 2024
1 parent a366891 commit 6c36741
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
11 changes: 8 additions & 3 deletions iree/turbine/kernel/wave/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from .assumptions import Assumption
import torch.fx as fx
import iree.turbine.kernel.lang as tkl
from pathlib import Path


import tempfile
Expand Down Expand Up @@ -541,6 +542,7 @@ def compile_and_invoke(
kernel_dynamic_dims: list[int] = [],
run: bool = False,
run_bench: bool = False,
create_vmfb_file: Optional[Path] = None,
inplace: bool = False,
):
backend = config["backend"]
Expand Down Expand Up @@ -588,9 +590,12 @@ def compile_and_invoke(

res = compile_str(asm, target_backends=[backend], extra_args=flags)

dump_vmfb_file = config.get("dump_vmfb_file", None)
if dump_vmfb_file is not None:
_write_file(dump_vmfb_file, "wb", res)
if create_vmfb_file is not None:
_write_file(create_vmfb_file, "wb", res)

if not (run or run_bench):
return

if inplace:
# Select device as the GPU, where input tensors are coming from.
device_uuid = get_device_uuid(kernel_inputs + kernel_outputs)
Expand Down
4 changes: 3 additions & 1 deletion iree/turbine/kernel/wave/wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,8 @@ def test_execute(self, args, kwargs):

run = kwargs.get("run", False)
run_bench = kwargs.get("run_bench", False)
if run or run_bench:
create_vmfb_file = kwargs.get("create_vmfb_file", None)
if run or run_bench or create_vmfb_file:
# TODO: cache compiled code
dynamic_symbols = kwargs.get("dynamic_symbols", [])
host_codegen.isolated_test_call(
Expand Down Expand Up @@ -372,6 +373,7 @@ def test_execute(self, args, kwargs):
kernel_dynamic_dims,
run,
run_bench,
create_vmfb_file=create_vmfb_file,
inplace=True,
)

Expand Down

0 comments on commit 6c36741

Please sign in to comment.