From c6996a580b75d1449ab4b7efe8d3391f62f2768d Mon Sep 17 00:00:00 2001 From: Harsh Menon <harsh@nod-labs.com> Date: Tue, 3 Dec 2024 21:36:02 -0800 Subject: [PATCH] Address comments #2 --- iree/turbine/kernel/lang/__init__.py | 1 + iree/turbine/kernel/lang/global_symbols.py | 18 +++--------------- tests/kernel/wave/wave_evoformer_test.py | 3 +-- 3 files changed, 5 insertions(+), 17 deletions(-) diff --git a/iree/turbine/kernel/lang/__init__.py b/iree/turbine/kernel/lang/__init__.py index bc03f196..e95c2084 100644 --- a/iree/turbine/kernel/lang/__init__.py +++ b/iree/turbine/kernel/lang/__init__.py @@ -13,6 +13,7 @@ ) from .._support.dtype import ( + DataType, bf16, bool, i4, diff --git a/iree/turbine/kernel/lang/global_symbols.py b/iree/turbine/kernel/lang/global_symbols.py index 652a714d..9e9432c8 100644 --- a/iree/turbine/kernel/lang/global_symbols.py +++ b/iree/turbine/kernel/lang/global_symbols.py @@ -1,5 +1,4 @@ from .._support.indexing import index_symbol -import sys # Global symbols used throughout the code. @@ -14,25 +13,14 @@ WORKGROUP_2 = index_symbol("$WG2") -def create_additional_workgroup_symbols(): - """ - Since we can have a large number of workgroups, we create - symbols for them dynamically. However, only WORKGROUP_0, - WORKGROUP_1, and WORKGROUP_2 will persist during code generation, - so we generate those symbols statically. - """ - max_workgroups = 5 - for i in range(3, max_workgroups): - globals()[f"WORKGROUP_{i}"] = index_symbol(f"$WG{i}") - - def get_workgroup_symbol(i: int): assert i >= 0, "Workgroup index must be non-negative." + symbol_name = f"WORKGROUP_{i}" + if symbol_name not in globals(): + globals()[symbol_name] = index_symbol(f"$WG{i}") return index_symbol(f"$WG{i}") -create_additional_workgroup_symbols() - THREAD_0 = index_symbol("$T0") THREAD_1 = index_symbol("$T1") THREAD_2 = index_symbol("$T2") diff --git a/tests/kernel/wave/wave_evoformer_test.py b/tests/kernel/wave/wave_evoformer_test.py index d3e33b2e..2ecb44c5 100644 --- a/tests/kernel/wave/wave_evoformer_test.py +++ b/tests/kernel/wave/wave_evoformer_test.py @@ -22,7 +22,7 @@ ) from iree.turbine.kernel.wave.constraints import MMAType from iree.turbine.kernel.wave.templates.evoformer import get_evoformer_kernel -from iree.turbine.kernel._support.dtype import DataType +from iree.turbine.kernel.lang import DataType import os _run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0)) @@ -151,7 +151,6 @@ def testEvoformerAttentionForward( f.write(mb.module_op.get_asm()) eps = 1e-2 if output.dtype == torch.float16 else 5e-2 - print(f"Max diff: {torch.max(torch.abs(torch_ref - output)).item()}") assert ( torch.max(torch.abs(torch_ref - output)).item() < eps ), f"out eps: {torch.max(torch.abs(torch_ref - output))}"