From 7586479aedcb9d81accd896946a15dda959c906e Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 13 Dec 2024 19:26:46 +0100 Subject: [PATCH] [TKW] Fix broken `create_vmfb_file` option (#332) Signed-off-by: Ivan Butygin --- iree/turbine/kernel/wave/wave.py | 4 +++ tests/kernel/wave/wave_e2e_test.py | 55 ++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/iree/turbine/kernel/wave/wave.py b/iree/turbine/kernel/wave/wave.py index b8fae83d..15ffe7f2 100644 --- a/iree/turbine/kernel/wave/wave.py +++ b/iree/turbine/kernel/wave/wave.py @@ -32,6 +32,7 @@ remove_chained_extractslice, subs_idxc, delinearize_index, + _write_file, ) from .minimize_global_loads import minimize_global_loads from .decompose_reduce_ops import decompose_reduce_ops @@ -435,6 +436,9 @@ def test_execute(self, args, kwargs): raise ValueError("no config provided") compiled_wave_vmfb = compile_to_vmfb(asm, config, run_bench) + if create_vmfb_file is not None: + _write_file(create_vmfb_file, "wb", compiled_wave_vmfb) + kernel_usages = [ binding.kernel_buffer_type.usage for binding in kernel_sig.kernel_buffer_bindings diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index cea5308d..7f5f82aa 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -75,6 +75,61 @@ def wrapper(shape): return wrapper +@require_e2e +@pytest.mark.parametrize("shape", get_test_shapes("test_copy")[:1]) +def test_dump_vmfb(shape, tmp_path, request): + M = tkl.sym.M + N = tkl.sym.N + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + + # Each workgroup works on a single row of input data, and rows are further + # split into blocks of size up to 256. We have a single wave per WG, + # and with the default wave size of 64, each thread is operating on up to 4 + # elements. + wave_size = 64 + BLOCK_M = 1 + # Tile size cannot be dynamic, so we use a fixed value here. 
+ BLOCK_N = sympy.Max(sympy.Min(shape[1], 256), wave_size) + ELEMS_PER_THREAD = BLOCK_N / wave_size + + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=wave_size, + waves_per_block=(1, 1, 1), + vector_shapes={M: BLOCK_M, N: BLOCK_N}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + @tkw.wave(constraints) + def test( + a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + ): + res = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD) + tkw.write(res, b, elements_per_thread=ELEMS_PER_THREAD) + + config = get_default_run_config() + + vmfb_file = tmp_path / "test.vmfb" + with tk.gen.TestLaunchContext( + { + M: shape[0], + N: shape[1], + ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value, + }, + canonicalize=True, + create_vmfb_file=vmfb_file, + run_config=config, + ): + assert not os.path.exists(vmfb_file) + test() + assert os.path.exists(vmfb_file) + + @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_copy")) def test_copy(shape, request):