Skip to content

Commit

Permalink
Add flag to insert host codegen call
Browse files Browse the repository at this point in the history
This is useful for autotuning applications where
we need the IR with the host codegen call.

Signed-off-by: Harsh Menon <[email protected]>
  • Loading branch information
harsh-nod committed Oct 5, 2024
1 parent b0ef345 commit d395865
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions iree/turbine/kernel/wave/wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from ..compiler import builder, dispatch_codegen, kernel_codegen, host_codegen
from ..compiler.ir import Context, Operation
from ..compiler.builder import ModuleBuilder
from .codegen import WaveEmitter
from .constraints import (
Constraint,
Expand Down Expand Up @@ -41,7 +42,7 @@
from .shared_memory_indexing import apply_shared_memory_indexing_corrections
from .thread_shape_analysis import determine_thread_shapes
from .scheduling.schedule import schedule_graph
from .._support.indexing import IndexingContext, IndexExpr
from .._support.indexing import IndexingContext, IndexExpr, IndexSymbol
import iree.turbine.kernel.lang as tkl
from .._support.tracing import (
CapturedTrace,
Expand Down Expand Up @@ -288,8 +289,24 @@ def _trace_and_get_kernel_signature(
if kwargs.get("canonicalize", False):
canonicalize_module(mb.module_op)

if kwargs.get("add_host_codegen_call", False):
    # Forward only the caller-supplied dynamic symbols. The previous call
    # passed the entire kwargs dict in this position, leaving the
    # dynamic_symbols local unused, while add_host_codegen_call declares
    # this parameter as list[IndexSymbol].
    dynamic_symbols = kwargs.get("dynamic_symbols", [])
    self.add_host_codegen_call(
        mb, exe, kernel_sig, entrypoint_name, dynamic_symbols
    )

return mb, graph, exe, kernel_sig, entrypoint_name

def add_host_codegen_call(
    self,
    mb: ModuleBuilder,
    exe: dispatch_codegen.StreamExecutable,
    kernel_sig: kernel_codegen.KernelSignature,
    entrypoint_name: str,
    dynamic_symbols: list[IndexSymbol],
) -> None:
    """Delegate to ``host_codegen.isolated_test_call`` for this kernel.

    All arguments are forwarded unchanged, in the order given. This method
    does no work of its own beyond that single call.

    Args:
        mb: Module builder handed to the host codegen helper.
        exe: Stream executable handed to the host codegen helper.
        kernel_sig: Kernel signature handed to the host codegen helper.
        entrypoint_name: Name of the entry point, forwarded as-is.
        dynamic_symbols: Index symbols forwarded as-is. Callers in this
            file obtain them from ``kwargs.get("dynamic_symbols", [])``.

    Returns:
        None. The result of ``isolated_test_call`` is discarded.

    NOTE(review): presumably the helper inserts the host-side call into
    the module held by ``mb`` -- confirm against ``host_codegen``.
    """
    host_codegen.isolated_test_call(
        mb, exe, kernel_sig, entrypoint_name, dynamic_symbols
    )

def test_execute(self, args, kwargs):
(
mb,
Expand All @@ -303,8 +320,13 @@ def test_execute(self, args, kwargs):
run_bench = kwargs.get("run_bench", False)
if run or run_bench:
# TODO: cache compiled code
if kwargs.get("add_host_codegen_call", False):
raise ValueError(
"Cannot set add_host_codegen_call when attempting to run or benchmark the kernel."
)

dynamic_symbols = kwargs.get("dynamic_symbols", [])
host_codegen.isolated_test_call(
self.add_host_codegen_call(
mb, exe, kernel_sig, entrypoint_name, dynamic_symbols
)
asm = mb.module_op.get_asm()
Expand Down

0 comments on commit d395865

Please sign in to comment.