Skip to content

Commit

Permalink
Add flag to insert host codegen call
Browse files Browse the repository at this point in the history
This is useful for autotuning applications where
we need the IR with the host codegen call.

Signed-off-by: Harsh Menon <[email protected]>
  • Loading branch information
harsh-nod committed Oct 5, 2024
1 parent da3436d commit 9865cbc
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions shark_turbine/kernel/wave/wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from ..compiler import builder, dispatch_codegen, kernel_codegen, host_codegen
from ..compiler.ir import Context, Operation
from ..compiler.builder import ModuleBuilder
from .codegen import WaveEmitter
from .constraints import (
Constraint,
Expand Down Expand Up @@ -288,8 +289,24 @@ def _trace_and_get_kernel_signature(
if kwargs.get("canonicalize", False):
canonicalize_module(mb.module_op)

if kwargs.get("add_host_codegen_call", False):
self.add_host_codegen_call(mb, exe, kernel_sig, entrypoint_name, kwargs)

return mb, graph, exe, kernel_sig, entrypoint_name

def add_host_codegen_call(
self,
mb: ModuleBuilder,
exe: dispatch_codegen.StreamExecutable,
kernel_sig: kernel_codegen.KernelSignature,
entrypoint_name: str,
kwargs,
):
dynamic_symbols = kwargs.get("dynamic_symbols", [])
host_codegen.isolated_test_call(
mb, exe, kernel_sig, entrypoint_name, dynamic_symbols
)

def test_execute(self, args, kwargs):
(
mb,
Expand All @@ -303,10 +320,7 @@ def test_execute(self, args, kwargs):
run_bench = kwargs.get("run_bench", False)
if run or run_bench:
# TODO: cache compiled code
dynamic_symbols = kwargs.get("dynamic_symbols", [])
host_codegen.isolated_test_call(
mb, exe, kernel_sig, entrypoint_name, dynamic_symbols
)
self.add_host_codegen_call(mb, exe, kernel_sig, entrypoint_name, kwargs)
asm = mb.module_op.get_asm()

kernel_inputs = []
Expand Down

0 comments on commit 9865cbc

Please sign in to comment.