Skip to content

Commit

Permalink
[TKW] Fix broken create_vmfb_file option (#332)
Browse files Browse the repository at this point in the history
Signed-off-by: Ivan Butygin <[email protected]>
  • Loading branch information
Hardcode84 authored Dec 13, 2024
1 parent ebc0fa4 commit 7586479
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
4 changes: 4 additions & 0 deletions iree/turbine/kernel/wave/wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
remove_chained_extractslice,
subs_idxc,
delinearize_index,
_write_file,
)
from .minimize_global_loads import minimize_global_loads
from .decompose_reduce_ops import decompose_reduce_ops
Expand Down Expand Up @@ -435,6 +436,9 @@ def test_execute(self, args, kwargs):
raise ValueError("no config provided")

compiled_wave_vmfb = compile_to_vmfb(asm, config, run_bench)
if create_vmfb_file is not None:
_write_file(create_vmfb_file, "wb", compiled_wave_vmfb)

kernel_usages = [
binding.kernel_buffer_type.usage
for binding in kernel_sig.kernel_buffer_bindings
Expand Down
55 changes: 55 additions & 0 deletions tests/kernel/wave/wave_e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,61 @@ def wrapper(shape):
return wrapper


@require_e2e
@pytest.mark.parametrize("shape", get_test_shapes("test_copy")[:1])
def test_dump_vmfb(shape, tmp_path, request):
M = tkl.sym.M
N = tkl.sym.N
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

# Each workgroup works on single row of input data, and rows are further
# split into blocks of size up to 256. We have single wave per WG,
# and with default wave size of 64, each thread is operating on up to 4
# elements.
wave_size = 64
BLOCK_M = 1
# Tile size cannot be dynamic, so we use a fixed value here.
BLOCK_N = sympy.Max(sympy.Min(shape[1], 256), wave_size)
ELEMS_PER_THREAD = BLOCK_N / wave_size

constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(
threads_per_wave=wave_size,
waves_per_block=(1, 1, 1),
vector_shapes={M: BLOCK_M, N: BLOCK_N},
)
]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)]
constraints += [tkw.WaveConstraint(M, BLOCK_M)]
constraints += [tkw.WaveConstraint(N, BLOCK_N)]

@tkw.wave(constraints)
def test(
a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
):
res = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD)
tkw.write(res, b, elements_per_thread=ELEMS_PER_THREAD)

config = get_default_run_config()

vmfb_file = tmp_path / "test.vmfb"
with tk.gen.TestLaunchContext(
{
M: shape[0],
N: shape[1],
ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value,
},
canonicalize=True,
create_vmfb_file=vmfb_file,
run_config=config,
):
assert not os.path.exists(vmfb_file)
test()
assert os.path.exists(vmfb_file)


@require_e2e
@pytest.mark.parametrize("shape", get_test_shapes("test_copy"))
def test_copy(shape, request):
Expand Down

0 comments on commit 7586479

Please sign in to comment.