Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add evoformer example #307

Merged
merged 5 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion iree/turbine/kernel/_support/dtype.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
__all__ = [
"DataType",
"bf16",
"bool",
"i4",
"i8",
Expand All @@ -17,7 +18,16 @@
]

# Names of the supported integer element types.
_INT_TYPES = ["i1", "i4", "i8", "i16", "i32", "i64"]
# Names of the supported floating-point element types (including bf16 and
# the 8-bit float variants).
_FLOAT_TYPES = [
    "bf16",
    "f16",
    "f32",
    "f64",
    "f8E5M2",
    "f8E5M2FNUZ",
    "f8E4M3FN",
    "f8E4M3FNUZ",
]
# Names of the supported index types.
_INDEX_TYPES = ["index"]


Expand Down Expand Up @@ -55,9 +65,12 @@ def bitwidth(self):
return 64
if "f8" in self._name:
return 8
if "bf16" in self._name:
return 16
return int(self._name[1:])


bf16 = DataType("bf16")
bool = DataType("bool", "i1")
i4 = DataType("i4")
i8 = DataType("i8")
Expand Down
2 changes: 2 additions & 0 deletions iree/turbine/kernel/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
)

from .._support.dtype import (
DataType,
bf16,
harsh-nod marked this conversation as resolved.
Show resolved Hide resolved
bool,
i4,
i8,
Expand Down
10 changes: 10 additions & 0 deletions iree/turbine/kernel/lang/global_symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,21 @@
GLOBAL_ADDRESS_SPACE = index_symbol("$GLOBAL_ADDRESS_SPACE")
SHARED_ADDRESS_SPACE = index_symbol("$SHARED_ADDRESS_SPACE")


# Distribution symbols.
WORKGROUP_0 = index_symbol("$WG0")
WORKGROUP_1 = index_symbol("$WG1")
WORKGROUP_2 = index_symbol("$WG2")


def get_workgroup_symbol(i: int):
    """Return the index symbol ``$WG<i>`` for workgroup dimension ``i``.

    If this module has no ``WORKGROUP_<i>`` global yet, one is registered as
    a side effect so that it sits alongside the predefined ``WORKGROUP_0``
    through ``WORKGROUP_2`` constants.

    Args:
        i: Workgroup dimension index; must be non-negative.

    Returns:
        The result of ``index_symbol(f"$WG{i}")``.
    """
    assert i >= 0, "Workgroup index must be non-negative."
    module_scope = globals()
    global_name = f"WORKGROUP_{i}"
    wg_symbol_name = f"$WG{i}"
    # Publish the symbol at module level the first time this index is seen.
    if global_name not in module_scope:
        module_scope[global_name] = index_symbol(wg_symbol_name)
    return index_symbol(wg_symbol_name)


THREAD_0 = index_symbol("$T0")
THREAD_1 = index_symbol("$T1")
THREAD_2 = index_symbol("$T2")
Expand Down
45 changes: 31 additions & 14 deletions iree/turbine/kernel/wave/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,17 @@ def mma_matrix_shapes(self, mma_type: Optional[MMAType]) -> tuple[int]:
return (16, 16, 16)
case MMAType.F32_32x32x8_F16 | MMAType.I32_32x32x8_I8:
return (32, 32, 8)
case MMAType.F32_16x16x32_F8 | MMAType.F32_16x16x32_K4_F8 | MMAType.I32_16x16x32_I8:
case (
MMAType.F32_16x16x32_F8
| MMAType.F32_16x16x32_K4_F8
| MMAType.I32_16x16x32_I8
):
return (16, 16, 32)
case MMAType.F32_32x32x16_F8 | MMAType.F32_32x32x16_K4_F8 | MMAType.I32_32x32x16_I8:
case (
MMAType.F32_32x32x16_F8
| MMAType.F32_32x32x16_K4_F8
| MMAType.I32_32x32x16_I8
):
return (32, 32, 16)
case _:
return ()
Expand Down Expand Up @@ -234,7 +242,11 @@ def apply(
1, # N
1, # K
]
case MMAType.F32_16x16x32_F8 | MMAType.F32_16x16x32_K4_F8 | MMAType.I32_16x16x32_I8:
case (
MMAType.F32_16x16x32_F8
| MMAType.F32_16x16x32_K4_F8
| MMAType.I32_16x16x32_I8
):
offset = [
Piecewise(
(lane % 16, ~MMA_ACC), (4 * floor(lane / 16), MMA_ACC)
Expand Down Expand Up @@ -262,7 +274,11 @@ def apply(
+ 4 * floor(lane / 16)
+ (GPR_NUM % 4), # K
]
case MMAType.F32_32x32x16_F8 | MMAType.F32_32x32x16_K4_F8 | MMAType.I32_32x32x16_I8:
case (
MMAType.F32_32x32x16_F8
| MMAType.F32_32x32x16_K4_F8
| MMAType.I32_32x32x16_I8
):
offset = [
Piecewise(
(lane % 32, ~MMA_ACC),
Expand Down Expand Up @@ -328,6 +344,16 @@ class WorkgroupConstraint(Constraint):
tile_size: IndexExpr
workgroup_dim: int

def __post_init__(self):
    # Validate the requested workgroup dimension once, at construction time,
    # and cache the matching workgroup symbol on the instance as `wg_dim`.
    self.wg_dim = None
    if self.workgroup_dim not in (0, 1, 2, 3, 4):
        raise ValueError("Invalid workgroup dimension. Expected 0, 1, 2, 3 or 4.")
    self.wg_dim = get_workgroup_symbol(self.workgroup_dim)

@property
def count(self) -> IndexExpr:
"""
Expand All @@ -336,16 +362,7 @@ def count(self) -> IndexExpr:
return ceiling(self.dim / self.tile_size)

def apply(self) -> IndexSequence:
    """Return the index sequence contributed by this workgroup constraint.

    The sequence is built from ``self.wg_dim * self.tile_size`` and ``1``.
    ``self.wg_dim`` is the workgroup symbol resolved (and validated) in
    ``__post_init__``, so every dimension accepted there (0 through 4) is
    handled here without a second, narrower dimension check.
    """
    return IndexSequence(self.wg_dim * self.tile_size, 1)


def get_grid_shape(wg_constraints: list[WorkgroupConstraint]) -> list[IndexExpr]:
Expand Down
23 changes: 19 additions & 4 deletions iree/turbine/kernel/wave/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import itertools
import torch.fx as fx
from typing import Any, TypeAlias
from typing import Any, TypeAlias, Sequence, Type, Callable
from functools import partial

from .constraints import (
Expand All @@ -15,8 +15,23 @@
WorkgroupConstraint,
TilingConstraint,
)
from ..ops.wave_ops import *
from .._support.indexing import IndexingContext, IndexSequence
from ..ops.wave_ops import (
Allocate,
CustomOp,
GetResult,
Getitem,
IterArg,
MMA,
Output,
Placeholder,
Read,
ReduceOp,
Reduction,
Reshape,
Write,
get_custom,
)
from .._support.indexing import IndexingContext, IndexSymbol
from ...support.logging import get_logger
from .._support.tracing import CapturedTrace
from .utils import (
Expand Down Expand Up @@ -440,7 +455,7 @@ def _expand_mma_reduction(
for dim in mma.indexing_dims:
if dim not in dim_scaling and mma.vector_shapes[dim] > 0:
tile_size = idxc.get_static_value(dim)
dim_scaling[dim] = tile_size // mma.vector_shapes[dim]
dim_scaling[dim] = max(tile_size // mma.vector_shapes[dim], 1)

# Store the original mma node and accumulator value for expansion.
# When we begin expansion, we have a single mma node with the correct accumulator.
Expand Down
197 changes: 197 additions & 0 deletions iree/turbine/kernel/wave/templates/evoformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel.wave.constraints import MMAType
from iree.turbine.kernel._support.dtype import DataType
from iree.turbine.kernel.wave.utils import (
get_mfma_load_elems_per_thread,
get_mfma_store_elems_per_thread,
)


def get_evoformer_kernel(
    batch: tuple[int, int],
    n: tuple[int, int],
    kv_seq_len: tuple[int, int],
    heads: tuple[int, int],
    head_dim: tuple[int, int],
    q_seq_len: tuple[int, int],
    v_dim: tuple[int, int],
    mfma_variant: MMAType,
    datatype: DataType,
):
    """Build the evoformer forward wave kernel and its symbol bindings.

    Every size argument is a ``(shape, tile_size)`` pair. Element 0 is bound
    to the problem-size symbol and element 1 to the matching ``BLOCK_*``
    workgroup tile symbol.

    Args:
        batch: Bound to ``B`` / ``BLOCK_B``.
        n: Bound to ``BN`` / ``BLOCK_BN``.
        kv_seq_len: Bound to ``K2`` / ``BLOCK_K2`` (the tiled reduction dim).
        heads: Bound to ``H`` / ``BLOCK_H``.
        head_dim: Bound to ``K1``; only the shape element is used.
        q_seq_len: Bound to ``M`` / ``BLOCK_M``.
        v_dim: Bound to ``N`` / ``BLOCK_N``.
        mfma_variant: MMA instruction type. Only ``F32_16x16x16_F16`` and
            ``F32_32x32x8_F16`` are supported.
        datatype: Element type of all kernel buffers; ``tkl.f16`` or
            ``tkl.bf16``.

    Returns:
        A tuple ``(evoformer_fwd, symbols)``: the ``tkw.wave``-decorated
        kernel and the dict mapping each symbol to its concrete value.

    Raises:
        AssertionError: If ``datatype`` is neither ``tkl.f16`` nor
            ``tkl.bf16``.
        ValueError: If ``mfma_variant`` is not one of the supported variants.
    """
    assert datatype in [tkl.f16, tkl.bf16], f"Unsupported datatype: {datatype}"

    # Input sizes
    B = tkl.sym.B
    BN = tkl.sym.BN
    M = tkl.sym.M
    H = tkl.sym.H
    N = tkl.sym.N
    K1 = tkl.sym.K1
    K2 = tkl.sym.K2
    # Workgroup tile sizes
    BLOCK_B = tkl.sym.BLOCK_B
    BLOCK_BN = tkl.sym.BLOCK_BN
    BLOCK_H = tkl.sym.BLOCK_H
    BLOCK_M = tkl.sym.BLOCK_M
    BLOCK_N = tkl.sym.BLOCK_N
    BLOCK_K2 = tkl.sym.BLOCK_K2
    # Address space (for GPU, shared(1) or global(0))
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
    # Other hyperparameters
    LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
    STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

    # Expose user-constraints
    ratio_m = 2
    ratio_n = 1
    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
    constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)]
    constraints += [tkw.WorkgroupConstraint(BN, BLOCK_BN, 3)]
    constraints += [tkw.WorkgroupConstraint(H, BLOCK_H, 4)]
    constraints += [tkw.TilingConstraint(K2, BLOCK_K2)]
    constraints += [tkw.WaveConstraint(M, BLOCK_M / ratio_m)]
    constraints += [tkw.WaveConstraint(N, BLOCK_N / ratio_n)]

    # Pick the M/N vector shapes for the selected MMA variant. Reject any
    # other variant up front instead of failing later on unbound names.
    if mfma_variant == MMAType.F32_16x16x16_F16:
        Mvec = 16
        Nvec = 16
    elif mfma_variant == MMAType.F32_32x32x8_F16:
        Mvec = 32
        Nvec = 32
    else:
        raise ValueError(f"Unsupported MFMA variant: {mfma_variant}")

    constraints += [
        tkw.HardwareConstraint(
            threads_per_wave=64,
            waves_per_block=(ratio_m, ratio_n, 1),
            mma_type=mfma_variant,
            vector_shapes={B: 0, BN: 0, H: 0, M: Mvec, N: Nvec},
        )
    ]

    i = tkw.IndexMapping.iterator(0)
    j = tkw.IndexMapping.iterator(1)
    k = tkw.IndexMapping.iterator(2)
    l = tkw.IndexMapping.iterator(3)
    m = tkw.IndexMapping.iterator(4)
    # [B, BN, M, H, K1] -> [B, BN, H, M, K1]
    q_mapping = tkw.IndexMapping(
        num_iterators=5,
        inputs={B: i, BN: j, H: k, M: l, K1: m},
        outputs={B: i, BN: j, H: k, M: l, K1: m},
    )
    # [B, BN, K2, H, K1] -> [B, BN, H, K2, K1]
    k_mapping = tkw.IndexMapping(
        num_iterators=5,
        inputs={B: i, BN: j, H: k, K2: l, K1: m},
        outputs={B: i, BN: j, H: k, K2: l, K1: m},
    )
    # [B, BN, N, H, K2] -> [B, BN, H, N, K2]
    v_mapping = tkw.IndexMapping(
        num_iterators=5,
        inputs={B: i, BN: j, H: k, N: l, K2: m},
        outputs={B: i, BN: j, H: k, N: l, K2: m},
    )
    # [B, BN, H, N, M] -> [B, BN, M, H, N]
    o_mapping = tkw.IndexMapping(
        num_iterators=5,
        inputs={B: i, BN: j, H: k, N: l, M: m},
        outputs={B: i, BN: j, H: k, N: l, M: m},
    )

    @tkw.wave(constraints)
    def evoformer_fwd(
        q: tkl.Memory[B, BN, M, H, K1, GLOBAL_ADDRESS_SPACE, datatype],
        k: tkl.Memory[B, BN, K2, H, K1, ADDRESS_SPACE, datatype],
        v: tkl.Memory[B, BN, N, H, K2, ADDRESS_SPACE, datatype],
        mask: tkl.Memory[B, BN, K2, GLOBAL_ADDRESS_SPACE, datatype],
        bias: tkl.Memory[B, H, M, K2, GLOBAL_ADDRESS_SPACE, datatype],
        c: tkl.Memory[B, BN, M, H, N, GLOBAL_ADDRESS_SPACE, datatype],
    ):
        # Initial values carried through the K2 reduction: output accumulator,
        # running sum and running max.
        c_reg = tkl.Register[B, BN, H, N, M, tkl.f32](0.0)
        init_sum = tkl.Register[B, BN, H, M, tkl.f32](0.0)
        init_max = tkl.Register[B, BN, H, M, tkl.f32](-1e6)

        @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg])
        def repeat(
            partial_max: tkl.Register[B, BN, H, M, tkl.f32],
            partial_sum: tkl.Register[B, BN, H, M, tkl.f32],
            acc: tkl.Register[B, BN, H, N, M, tkl.f32],
        ) -> (
            tkl.Register[B, BN, H, M, tkl.f32],
            tkl.Register[B, BN, H, M, tkl.f32],
            tkl.Register[B, BN, H, N, M, tkl.f32],
        ):
            imm_reg = tkl.Register[B, BN, H, K2, M, tkl.f32](0.0)
            q_reg = tkw.read(
                q, mapping=q_mapping, elements_per_thread=LOAD_ELEMS_PER_THREAD
            )
            # bf16 inputs are routed through f32 down to f16 before the mma.
            if datatype == tkl.bf16:
                q_reg = tkw.cast(tkw.cast(q_reg, tkl.f32), tkl.f16)
            k_reg = tkw.read(
                k, mapping=k_mapping, elements_per_thread=LOAD_ELEMS_PER_THREAD
            )
            if datatype == tkl.bf16:
                k_reg = tkw.cast(tkw.cast(k_reg, tkl.f32), tkl.f16)
            inner_acc = tkw.mma(k_reg, q_reg, imm_reg)
            x_j = tkw.permute(inner_acc, target_shape=[B, BN, H, M, K2])
            # Add mask and bias (both promoted to f32) to the first mma result.
            mask_reg = tkw.read(mask, elements_per_thread=STORE_ELEMS_PER_THREAD)
            casted_mask_reg = tkw.cast(mask_reg, tkl.f32)
            y_j = x_j + casted_mask_reg
            bias_reg = tkw.read(bias, elements_per_thread=STORE_ELEMS_PER_THREAD)
            casted_bias_reg = tkw.cast(bias_reg, tkl.f32)
            z_j = y_j + casted_bias_reg
            # Update the running max over K2, then rescale the previous
            # partial sum by exp2(old_max - new_max) before accumulating.
            m_j = tkw.max(z_j, partial_max, dim=K2)
            e_delta_max = tkw.exp2(partial_max - m_j)
            e_delta = tkw.exp2(z_j - m_j)
            e_init = partial_sum * e_delta_max
            d_j = tkw.sum(e_delta, e_init, dim=K2)
            imm_f16 = tkw.cast(e_delta, tkl.f16)
            v_reg = tkw.read(
                v, mapping=v_mapping, elements_per_thread=LOAD_ELEMS_PER_THREAD
            )
            if datatype == tkl.bf16:
                v_reg = tkw.cast(tkw.cast(v_reg, tkl.f32), tkl.f16)
            # The carried accumulator gets the same rescale before the
            # second mma adds this tile's contribution.
            new_acc = acc * e_delta_max
            acc = tkw.mma(v_reg, imm_f16, new_acc)
            return m_j, d_j, acc

        # repeat represents the results of the loop
        res_max, res_sum, res_mm = repeat
        # Normalise by the final running sum and cast back to the I/O dtype.
        res = res_mm / res_sum
        casted = tkw.cast(res, datatype)
        tkw.write(
            casted, c, mapping=o_mapping, elements_per_thread=STORE_ELEMS_PER_THREAD
        )

    # Positions inside each (shape, tile_size) argument pair.
    SHAPE = 0
    TILE_SIZE = 1

    symbols = {
        ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
        LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant),
        STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant),
        B: batch[SHAPE],
        BN: n[SHAPE],
        K2: kv_seq_len[SHAPE],
        H: heads[SHAPE],
        K1: head_dim[SHAPE],
        M: q_seq_len[SHAPE],
        N: v_dim[SHAPE],
        BLOCK_B: batch[TILE_SIZE],
        BLOCK_BN: n[TILE_SIZE],
        BLOCK_H: heads[TILE_SIZE],
        BLOCK_M: q_seq_len[TILE_SIZE],
        BLOCK_N: v_dim[TILE_SIZE],
        BLOCK_K2: kv_seq_len[TILE_SIZE],
    }

    return evoformer_fwd, symbols
Loading
Loading