diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index 88952c996..762a0fcd6 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -39,6 +39,7 @@ from .assumptions import Assumption import torch.fx as fx import iree.turbine.kernel.lang as tkl +from pathlib import Path import tempfile @@ -541,6 +542,7 @@ def compile_and_invoke( kernel_dynamic_dims: list[int] = [], run: bool = False, run_bench: bool = False, + create_vmfb_file: Optional[Path] = None, inplace: bool = False, ): backend = config["backend"] @@ -588,9 +590,12 @@ def compile_and_invoke( res = compile_str(asm, target_backends=[backend], extra_args=flags) - dump_vmfb_file = config.get("dump_vmfb_file", None) - if dump_vmfb_file is not None: - _write_file(dump_vmfb_file, "wb", res) + if create_vmfb_file is not None: + _write_file(create_vmfb_file, "wb", res) + + if not (run or run_bench): + return + if inplace: # Select device as the GPU, where input tensors are coming from. device_uuid = get_device_uuid(kernel_inputs + kernel_outputs) diff --git a/iree/turbine/kernel/wave/wave.py b/iree/turbine/kernel/wave/wave.py index aaa43e744..611686629 100644 --- a/iree/turbine/kernel/wave/wave.py +++ b/iree/turbine/kernel/wave/wave.py @@ -336,7 +336,8 @@ def test_execute(self, args, kwargs): run = kwargs.get("run", False) run_bench = kwargs.get("run_bench", False) - if run or run_bench: + create_vmfb_file = kwargs.get("create_vmfb_file", None) + if run or run_bench or create_vmfb_file: # TODO: cache compiled code dynamic_symbols = kwargs.get("dynamic_symbols", []) host_codegen.isolated_test_call( @@ -372,6 +373,7 @@ def test_execute(self, args, kwargs): kernel_dynamic_dims, run, run_bench, + create_vmfb_file=create_vmfb_file, inplace=True, )