Skip to content

Commit

Permalink
Address comments #2
Browse files Browse the repository at this point in the history
  • Loading branch information
harsh-nod committed Dec 4, 2024
1 parent 807635c commit c6996a5
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 17 deletions.
1 change: 1 addition & 0 deletions iree/turbine/kernel/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)

from .._support.dtype import (
DataType,
bf16,
bool,
i4,
Expand Down
18 changes: 3 additions & 15 deletions iree/turbine/kernel/lang/global_symbols.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .._support.indexing import index_symbol
import sys

# Global symbols used throughout the code.

Expand All @@ -14,25 +13,14 @@
WORKGROUP_2 = index_symbol("$WG2")


def create_additional_workgroup_symbols():
    """
    Since we can have a large number of workgroups, we create
    symbols for them dynamically. However, only WORKGROUP_0,
    WORKGROUP_1, and WORKGROUP_2 will persist during code generation,
    so we generate those symbols statically.
    """
    max_workgroups = 5
    module_ns = globals()
    for dim in range(3, max_workgroups):
        module_ns[f"WORKGROUP_{dim}"] = index_symbol(f"$WG{dim}")


def get_workgroup_symbol(i: int):
    """Return the index symbol for workgroup dimension *i*.

    Creates and caches a module-level ``WORKGROUP_<i>`` alias on first
    use, so repeated calls for the same dimension hand back the cached
    symbol instead of rebuilding it.

    Args:
        i: Workgroup dimension index; must be non-negative.

    Returns:
        The index symbol ``$WG<i>`` (presumably a sympy-style symbol
        from ``index_symbol`` — symbols of equal name are expected to
        compare equal, matching the original double-construction).
    """
    # NOTE: assert is stripped under `python -O`; kept for interface
    # compatibility with callers that may rely on AssertionError.
    assert i >= 0, "Workgroup index must be non-negative."
    symbol_name = f"WORKGROUP_{i}"
    if symbol_name not in globals():
        globals()[symbol_name] = index_symbol(f"$WG{i}")
    # Return the cached symbol rather than rebuilding it — single
    # source of truth for the WORKGROUP_<i> alias.
    return globals()[symbol_name]


create_additional_workgroup_symbols()

THREAD_0 = index_symbol("$T0")
THREAD_1 = index_symbol("$T1")
THREAD_2 = index_symbol("$T2")
Expand Down
3 changes: 1 addition & 2 deletions tests/kernel/wave/wave_evoformer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from iree.turbine.kernel.wave.constraints import MMAType
from iree.turbine.kernel.wave.templates.evoformer import get_evoformer_kernel
from iree.turbine.kernel._support.dtype import DataType
from iree.turbine.kernel.lang import DataType
import os

_run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0))
Expand Down Expand Up @@ -151,7 +151,6 @@ def testEvoformerAttentionForward(
f.write(mb.module_op.get_asm())

eps = 1e-2 if output.dtype == torch.float16 else 5e-2
print(f"Max diff: {torch.max(torch.abs(torch_ref - output)).item()}")
assert (
torch.max(torch.abs(torch_ref - output)).item() < eps
), f"out eps: {torch.max(torch.abs(torch_ref - output))}"

0 comments on commit c6996a5

Please sign in to comment.