Commit: Address comments
harsh-nod committed Dec 3, 2024
1 parent 8dfb73c commit 6ad378b
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 28 deletions.
24 changes: 22 additions & 2 deletions iree/turbine/kernel/lang/global_symbols.py
@@ -1,17 +1,37 @@
from .._support.indexing import index_symbol
import sys

# Global symbols used throughout the code.

# Address spaces.
GLOBAL_ADDRESS_SPACE = index_symbol("$GLOBAL_ADDRESS_SPACE")
SHARED_ADDRESS_SPACE = index_symbol("$SHARED_ADDRESS_SPACE")


# Distribution symbols.
WORKGROUP_0 = index_symbol("$WG0")
WORKGROUP_1 = index_symbol("$WG1")
WORKGROUP_2 = index_symbol("$WG2")
WORKGROUP_3 = index_symbol("$WG3")
WORKGROUP_4 = index_symbol("$WG4")


def create_additional_workgroup_symbols():
"""
Since we can have a large number of workgroups, we create
symbols for them dynamically. However, only WORKGROUP_0,
WORKGROUP_1, and WORKGROUP_2 will persist during code generation,
so we generate those symbols statically.
"""
max_workgroups = 5
for i in range(3, max_workgroups):
globals()[f"WORKGROUP_{i}"] = index_symbol(f"$WG{i}")


def get_workgroup_symbol(i: int):
assert i >= 0, "Workgroup index must be non-negative."
return index_symbol(f"$WG{i}")


create_additional_workgroup_symbols()

THREAD_0 = index_symbol("$T0")
THREAD_1 = index_symbol("$T1")
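
For context, a minimal sketch of how the new helper might be used. This is an illustration only, assuming index_symbol returns symbols that compare equal when created with the same name:

    from iree.turbine.kernel.lang.global_symbols import (
        WORKGROUP_0,
        get_workgroup_symbol,
    )

    # WG0-WG2 are defined statically because they persist through code
    # generation; higher workgroup dimensions are looked up on demand.
    assert get_workgroup_symbol(0) == WORKGROUP_0
    wg4 = get_workgroup_symbol(4)  # resolves to the symbol $WG4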
36 changes: 22 additions & 14 deletions iree/turbine/kernel/wave/constraints.py
@@ -130,9 +130,17 @@ def mma_matrix_shapes(self, mma_type: Optional[MMAType]) -> tuple[int]:
return (16, 16, 16)
case MMAType.F32_32x32x8_F16 | MMAType.I32_32x32x8_I8:
return (32, 32, 8)
case MMAType.F32_16x16x32_F8 | MMAType.F32_16x16x32_K4_F8 | MMAType.I32_16x16x32_I8:
case (
MMAType.F32_16x16x32_F8
| MMAType.F32_16x16x32_K4_F8
| MMAType.I32_16x16x32_I8
):
return (16, 16, 32)
case MMAType.F32_32x32x16_F8 | MMAType.F32_32x32x16_K4_F8 | MMAType.I32_32x32x16_I8:
case (
MMAType.F32_32x32x16_F8
| MMAType.F32_32x32x16_K4_F8
| MMAType.I32_32x32x16_I8
):
return (32, 32, 16)
case _:
return ()
@@ -234,7 +242,11 @@ def apply(
1, # N
1, # K
]
case MMAType.F32_16x16x32_F8 | MMAType.F32_16x16x32_K4_F8 | MMAType.I32_16x16x32_I8:
case (
MMAType.F32_16x16x32_F8
| MMAType.F32_16x16x32_K4_F8
| MMAType.I32_16x16x32_I8
):
offset = [
Piecewise(
(lane % 16, ~MMA_ACC), (4 * floor(lane / 16), MMA_ACC)
@@ -262,7 +274,11 @@ def apply(
+ 4 * floor(lane / 16)
+ (GPR_NUM % 4), # K
]
case MMAType.F32_32x32x16_F8 | MMAType.F32_32x32x16_K4_F8 | MMAType.I32_32x32x16_I8:
case (
MMAType.F32_32x32x16_F8
| MMAType.F32_32x32x16_K4_F8
| MMAType.I32_32x32x16_I8
):
offset = [
Piecewise(
(lane % 32, ~MMA_ACC),
@@ -331,16 +347,8 @@ class WorkgroupConstraint(Constraint):
def __post_init__(self):
self.wg_dim = None
match self.workgroup_dim:
case 0:
self.wg_dim = WORKGROUP_0
case 1:
self.wg_dim = WORKGROUP_1
case 2:
self.wg_dim = WORKGROUP_2
case 3:
self.wg_dim = WORKGROUP_3
case 4:
self.wg_dim = WORKGROUP_4
case 0 | 1 | 2 | 3 | 4:
self.wg_dim = get_workgroup_symbol(self.workgroup_dim)
case _:
raise ValueError(
"Invalid workgroup dimension. Expected 0, 1, 2, 3 or 4."
12 changes: 0 additions & 12 deletions iree/turbine/kernel/wave/templates/evoformer.py
@@ -190,18 +190,6 @@ def repeat(
BLOCK_M: q_seq_len[TILE_SIZE],
BLOCK_N: v_dim[TILE_SIZE],
BLOCK_K2: kv_seq_len[TILE_SIZE],
READ_SHARED_DELAY: 1,
WRITE_SHARED_DELAY: 1,
READ_GLOBAL_DELAY: 2,
WRITE_GLOBAL_DELAY: 2,
MMA_DELAY: 1,
VALU_DELAY: 1,
SHUFFLE_DELAY: 1,
SHARED_MEMORY_UNITS: 4,
GLOBAL_MEMORY_UNITS: 4,
MMA_UNITS: 4,
VALU_UNITS: 2,
SHUFFLE_UNITS: 2,
}

return evoformer_fwd, symbols
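
With the scheduling hyperparameters removed from the template, callers are expected to merge them in themselves, as the test change below does. A minimal sketch of that caller-side pattern (illustrative only; get_default_scheduling_params is imported from the same test utilities module as the other helpers in the test):

    # Illustrative usage: the template now returns only shape and tile symbols;
    # scheduling parameters (READ_SHARED_DELAY, MMA_UNITS, ...) are merged in
    # separately by the caller.
    evoformer_fwd, symbols = get_evoformer_kernel(
        *shapes_and_tile_sizes, mfma_variant, dtype
    )
    symbols.update(get_default_scheduling_params())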
3 changes: 3 additions & 0 deletions tests/kernel/wave/wave_evoformer_test.py
@@ -18,6 +18,7 @@
device_randn,
device_zeros,
device_randint,
get_default_scheduling_params,
)
from iree.turbine.kernel.wave.constraints import MMAType
from iree.turbine.kernel.wave.templates.evoformer import get_evoformer_kernel
@@ -93,6 +94,8 @@ def testEvoformerAttentionForward(
*shapes_and_tile_sizes, mfma_variant, dtype
)

symbols.update(get_default_scheduling_params())

config = get_default_run_config()
if run_bench:
config["benchmark_batch_size"] = 1000
