From c6996a580b75d1449ab4b7efe8d3391f62f2768d Mon Sep 17 00:00:00 2001
From: Harsh Menon <harsh@nod-labs.com>
Date: Tue, 3 Dec 2024 21:36:02 -0800
Subject: [PATCH] Address comments #2

---
 iree/turbine/kernel/lang/__init__.py       |  1 +
 iree/turbine/kernel/lang/global_symbols.py | 18 +++---------------
 tests/kernel/wave/wave_evoformer_test.py   |  3 +--
 3 files changed, 5 insertions(+), 17 deletions(-)

diff --git a/iree/turbine/kernel/lang/__init__.py b/iree/turbine/kernel/lang/__init__.py
index bc03f196..e95c2084 100644
--- a/iree/turbine/kernel/lang/__init__.py
+++ b/iree/turbine/kernel/lang/__init__.py
@@ -13,6 +13,7 @@
 )
 
 from .._support.dtype import (
+    DataType,
     bf16,
     bool,
     i4,
diff --git a/iree/turbine/kernel/lang/global_symbols.py b/iree/turbine/kernel/lang/global_symbols.py
index 652a714d..9e9432c8 100644
--- a/iree/turbine/kernel/lang/global_symbols.py
+++ b/iree/turbine/kernel/lang/global_symbols.py
@@ -1,5 +1,4 @@
 from .._support.indexing import index_symbol
-import sys
 
 # Global symbols used throughout the code.
 
@@ -14,25 +13,14 @@
 WORKGROUP_2 = index_symbol("$WG2")
 
 
-def create_additional_workgroup_symbols():
-    """
-    Since we can have a large number of workgroups, we create
-    symbols for them dynamically. However, only WORKGROUP_0,
-    WORKGROUP_1, and WORKGROUP_2 will persist during code generation,
-    so we generate those symbols statically.
-    """
-    max_workgroups = 5
-    for i in range(3, max_workgroups):
-        globals()[f"WORKGROUP_{i}"] = index_symbol(f"$WG{i}")
-
-
 def get_workgroup_symbol(i: int):
     assert i >= 0, "Workgroup index must be non-negative."
+    symbol_name = f"WORKGROUP_{i}"
+    if symbol_name not in globals():
+        globals()[symbol_name] = index_symbol(f"$WG{i}")
     return index_symbol(f"$WG{i}")
 
 
-create_additional_workgroup_symbols()
-
 THREAD_0 = index_symbol("$T0")
 THREAD_1 = index_symbol("$T1")
 THREAD_2 = index_symbol("$T2")
diff --git a/tests/kernel/wave/wave_evoformer_test.py b/tests/kernel/wave/wave_evoformer_test.py
index d3e33b2e..2ecb44c5 100644
--- a/tests/kernel/wave/wave_evoformer_test.py
+++ b/tests/kernel/wave/wave_evoformer_test.py
@@ -22,7 +22,7 @@
 )
 from iree.turbine.kernel.wave.constraints import MMAType
 from iree.turbine.kernel.wave.templates.evoformer import get_evoformer_kernel
-from iree.turbine.kernel._support.dtype import DataType
+from iree.turbine.kernel.lang import DataType
 import os
 
 _run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0))
@@ -151,7 +151,6 @@ def testEvoformerAttentionForward(
                 f.write(mb.module_op.get_asm())
 
         eps = 1e-2 if output.dtype == torch.float16 else 5e-2
-        print(f"Max diff: {torch.max(torch.abs(torch_ref - output)).item()}")
         assert (
             torch.max(torch.abs(torch_ref - output)).item() < eps
         ), f"out eps: {torch.max(torch.abs(torch_ref - output))}"