[WIP] TMA Version of HSTU (Autotuned)
Based on #57, this version uses the autotuner to toggle the use of TMA (see the sketch below).
plotfi committed Sep 7, 2024
1 parent 2050a7c commit 5368955
Showing 1 changed file with 184 additions and 39 deletions.
223 changes: 184 additions & 39 deletions ops/triton/triton_ragged_hstu_attention.py
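
The core mechanism, in isolation: the forward-config list is doubled so that `enable_tma` becomes one more autotuning knob, and the kernel specializes on it as a `tl.constexpr`. A minimal, runnable sketch of that config-doubling step (the helper name `_double_with_tma` is illustrative, not from this file):

import triton

def _double_with_tma(base_configs):
    # Emit each tile config twice, once per enable_tma value; the autotuner
    # then benchmarks the TMA path against the block-pointer path and keeps
    # whichever is faster for the shapes it sees.
    doubled = []
    for cfg in base_configs:
        for use_tma in (False, True):
            doubled.append(
                triton.Config(
                    {**cfg.kwargs, "enable_tma": use_tma},
                    num_stages=cfg.num_stages,
                    num_warps=cfg.num_warps,
                )
            )
    return doubled

base = [triton.Config({"BLOCK_M": 16, "BLOCK_N": 32}, num_stages=2, num_warps=2)]
configs = _double_with_tma(base)  # 2 configs: enable_tma=False and enable_tma=True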
@@ -57,7 +57,7 @@ def _get_fw_configs() -> List[triton.Config]: # noqa: C901
)
)
else:
configs = [
base_configs = [
triton.Config(
{"BLOCK_M": 16, "BLOCK_N": 32},
num_stages=2,
@@ -204,6 +204,20 @@ def _get_fw_configs() -> List[triton.Config]: # noqa: C901
num_warps=8,
),
]

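# Emit every base config twice, once per enable_tma value, so the autotuner
# benchmarks the TMA path against the block-pointer path for each shape.
# NOTE: assumes `configs` is already initialized (e.g., as an empty list)
# earlier in this function.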
for config in base_configs:
for tma_config in [False, True]:
configs.append(
triton.Config(
{
"BLOCK_M": config.kwargs["BLOCK_M"],
"BLOCK_N": config.kwargs["BLOCK_N"],
"enable_tma": tma_config,
},
num_stages=config.num_stages,
num_warps=config.num_warps,
)
)

return configs


@@ -218,6 +232,10 @@ def _ragged_hstu_attn_fwd_one_block( # noqa: C901
q,
K_block_ptr,
V_block_ptr,
K_desc_ptr,
V_desc_ptr,
offset_h,
seq_start,
n_targets,
ts_1_ptrs,
ts_0,
@@ -245,13 +263,27 @@ def _ragged_hstu_attn_fwd_one_block( # noqa: C901
HAS_MAX_ATTN_LEN: tl.constexpr,
IS_DELTA_Q: tl.constexpr,
ALLOW_TF32: tl.constexpr,
BLOCK_D_Q: tl.constexpr,
BLOCK_D_V: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
enable_tma: tl.constexpr,
):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero")
qk = tl.dot(q, k, allow_tf32=ALLOW_TF32) * alpha
k = None
qk = None
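# TMA path: fetch the K tile through the on-device descriptor. Rows are
# indexed absolutely within the jagged batch (seq_start + start_n), and the
# descriptor returns [BLOCK_N, BLOCK_D_Q], so the tile is transposed before
# the dot. The fallback keeps the original block-pointer load.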
if enable_tma:
k = tl._experimental_descriptor_load(
K_desc_ptr,
[(seq_start + start_n).to(tl.int32), offset_h.to(tl.int32)],
[BLOCK_N, BLOCK_D_Q],
tl.bfloat16,
)
qk = tl.dot(q, tl.trans(k), allow_tf32=ALLOW_TF32) * alpha
else:
k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero")
qk = tl.dot(q, k, allow_tf32=ALLOW_TF32) * alpha
invalid_mask = offs_m[:, None] == offs_n[None, :]
if HAS_MULTIPLE_TARGETS:
if INVALID_MASK_TYPE == "lower_triangular":
@@ -335,7 +367,17 @@ def _ragged_hstu_attn_fwd_one_block( # noqa: C901
silu = tl.where(invalid_mask, silu, 0)
if HAS_ATTN_SCALE:
silu = silu * attn_scale[:, None]
v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero")

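# Same split for the V tile: descriptor load on the TMA path, block-pointer
# load otherwise.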
v = None
if enable_tma:
v = tl._experimental_descriptor_load(
V_desc_ptr,
[(seq_start + start_n).to(tl.int32), offset_h.to(tl.int32)],
[BLOCK_N, BLOCK_D_V],
tl.bfloat16,
)
else:
v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero")
silu = silu.to(v.dtype)
return tl.dot(silu, v, allow_tf32=ALLOW_TF32)

@@ -359,6 +401,8 @@ def _ragged_hstu_attn_fwd( # noqa C901
Q,
K,
V,
desc_k,
desc_v,
seq_offsets,
TS,
TW,
@@ -409,6 +453,7 @@ def _ragged_hstu_attn_fwd( # noqa C901
BLOCK_N: tl.constexpr,
max_attn_len: tl.constexpr,
HAS_MAX_ATTN_LEN: tl.constexpr,
enable_tma: tl.constexpr,
):
# M_CTX == N_CTX
off_hz = tl.program_id(1)
@@ -452,22 +497,25 @@ def _ragged_hstu_attn_fwd( # noqa C901
block_shape=(BLOCK_M, BLOCK_D_Q),
order=(1, 0),
)
K_block_ptr = tl.make_block_ptr(
base=K + off_h * stride_kh + seq_start * stride_kn,
shape=(BLOCK_D_Q, seq_len),
strides=(1, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_D_Q, BLOCK_N),
order=(0, 1),
)
V_block_ptr = tl.make_block_ptr(
base=V + off_h * stride_vh + seq_start * stride_vn,
shape=(seq_len, BLOCK_D_V),
strides=(stride_vn, 1),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_D_V),
order=(1, 0),
)
K_block_ptr = None
V_block_ptr = None
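# Block pointers are only materialized on the non-TMA path; with TMA enabled,
# K and V are read through the device-side descriptors instead.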
if not enable_tma:
K_block_ptr = tl.make_block_ptr(
base=K + off_h * stride_kh + seq_start * stride_kn,
shape=(BLOCK_D_Q, seq_len),
strides=(1, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_D_Q, BLOCK_N),
order=(0, 1),
)
V_block_ptr = tl.make_block_ptr(
base=V + off_h * stride_vh + seq_start * stride_vn,
shape=(seq_len, BLOCK_D_V),
strides=(stride_vn, 1),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_D_V),
order=(1, 0),
)
mask_m = offs_m < seq_len
if ATTN_BIAS_TYPE == "fused" and USE_TIME_BIAS:
ts_0_ptrs = TS + off_z * stride_ts + offs_m
@@ -486,6 +534,7 @@ def _ragged_hstu_attn_fwd( # noqa C901
scale_ptrs = Scale + off_z * stride_sz
attn_scale = tl.load(scale_ptrs + offs_m * stride_sm, mask=offs_m < seq_len)

# TODO: convert the q load to TMA as well
q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero")
acc = tl.zeros([BLOCK_M, BLOCK_D_V], dtype=tl.float32)
if INVALID_MASK_TYPE == "lower_triangular":
@@ -511,8 +560,13 @@ def _ragged_hstu_attn_fwd( # noqa C901
elif INVALID_MASK_TYPE == "upper_triangular":
low = start_m
high = seq_len
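# The host fills the TMA descriptors in device global memory; a tensormap
# acquire fence makes those writes visible to this kernel before the first
# descriptor load.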
if enable_tma:
tl.inline_asm_elementwise(
"fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg",
"=r, l",
[desc_k],
dtype=tl.int32,
is_pure=False,
pack=1,
)
tl.inline_asm_elementwise(
"fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg",
"=r, l",
[desc_v],
dtype=tl.int32,
is_pure=False,
pack=1,
)
# pyre-ignore[61]
if low > 0:
if low > 0 and not enable_tma:
# pyre-ignore[61]
K_block_ptr = tl.advance(K_block_ptr, (0, low))
# pyre-ignore[61]
@@ -531,6 +585,11 @@ def _ragged_hstu_attn_fwd( # noqa C901
q=q,
K_block_ptr=K_block_ptr,
V_block_ptr=V_block_ptr,
K_desc_ptr=desc_k,
V_desc_ptr=desc_v,
offset_h=off_h * stride_vh,
seq_start=seq_start,
# pyre-ignore[61]
n_targets=n_targets if HAS_MULTIPLE_TARGETS else None,
ts_1_ptrs=(
# pyre-ignore[61]
@@ -566,20 +625,25 @@ def _ragged_hstu_attn_fwd( # noqa C901
HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN,
IS_DELTA_Q=IS_DELTA_Q,
ALLOW_TF32=ALLOW_TF32,
BLOCK_D_Q=BLOCK_D_Q,
BLOCK_D_V=BLOCK_D_V,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
enable_tma=enable_tma,
)
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
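# Only the block-pointer path needs advancing; TMA loads index absolutely
# via seq_start + start_n.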
if not enable_tma:
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))

if HAS_MULTIPLE_TARGETS and INVALID_MASK_TYPE == "lower_triangular":
# pyre-ignore[61]
if uih_end < start_m:
low_delta = start_m
high_delta = start_m + BLOCK_M
offset = (low_delta - uih_end).to(tl.int32) # pyre-ignore [61]
K_block_ptr = tl.advance(K_block_ptr, (0, offset))
V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
if not enable_tma:
offset = (low_delta - uih_end).to(tl.int32) # pyre-ignore [61]
K_block_ptr = tl.advance(K_block_ptr, (0, offset))
V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
for start_delta in range(low_delta, high_delta, BLOCK_N):
cur_offs_n = offs_n + start_delta
mask_n = cur_offs_n < seq_len
@@ -593,6 +657,10 @@ def _ragged_hstu_attn_fwd( # noqa C901
q=q,
K_block_ptr=K_block_ptr,
V_block_ptr=V_block_ptr,
K_desc_ptr=desc_k,
V_desc_ptr=desc_v,
offset_h=off_h * stride_vh,
seq_start=seq_start,
n_targets=n_targets if HAS_MULTIPLE_TARGETS else None,
ts_1_ptrs=(
# pyre-ignore[61]
@@ -632,11 +700,15 @@ def _ragged_hstu_attn_fwd( # noqa C901
HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN,
IS_DELTA_Q=IS_DELTA_Q,
ALLOW_TF32=ALLOW_TF32,
BLOCK_D_Q=BLOCK_D_Q,
BLOCK_D_V=BLOCK_D_V,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
enable_tma=enable_tma,
)
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
if not enable_tma:
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))

if IS_DELTA_Q:
start_m_delta = tl.program_id(0) * BLOCK_M
@@ -648,6 +720,7 @@ def _ragged_hstu_attn_fwd( # noqa C901
+ offs_v_d[None, :]
)
out_ptrs = Out + off_o
# TODO: convert the output store to TMA as well
tl.store(out_ptrs, acc, mask=(offs_m_delta < DeltaSize)[:, None])
else:
# rematerialize offsets to save registers
@@ -691,10 +764,43 @@ def triton_ragged_attention(
has_attn_scale = attn_scale is not None
has_max_attn_len = max_attn_len is not None

grid = lambda meta: ( # noqa E731
triton.cdiv(N, meta["BLOCK_M"]),
Z * H,
)
TMA_SIZE = 128
BLOCK_D_V, BLOCK_D_Q = DimV, DimQ
desc_k = torch.empty(TMA_SIZE, device="cuda", dtype=torch.int8)
desc_v = torch.empty(TMA_SIZE, device="cuda", dtype=torch.int8)

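# grid_tma doubles as the launch-grid callable: for TMA configs it also
# fills the K/V descriptors on the host once the autotuner has fixed BLOCK_N,
# then copies them into the device buffers the kernel reads.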
def grid_tma(META):
if not META["enable_tma"]:
# The grid callable must return the launch grid itself, not a nested lambda.
return (triton.cdiv(N, META["BLOCK_M"]), Z * H)

nonlocal desc_k
nonlocal desc_v
k_buf = torch.empty_like(desc_k, device="cpu", pin_memory=True)
v_buf = torch.empty_like(desc_v, device="cpu", pin_memory=True)
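# K is described as a 2-D [L, H * DimQ] array whose TMA box is one
# [BLOCK_N, BLOCK_D_Q] tile; V is handled the same way with DimV.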
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
k.data_ptr(),
L,
H * DimQ,
META['BLOCK_N'],
BLOCK_D_Q,
k.element_size(),
k_buf.numpy()
)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
v.data_ptr(),
L,
H * DimV,
META['BLOCK_N'],
BLOCK_D_V,
v.element_size(),
v_buf.numpy()
)
desc_k.copy_(k_buf, non_blocking=True)
desc_v.copy_(v_buf, non_blocking=True)
return (triton.cdiv(N, META["BLOCK_M"]), Z * H, 1)

stride_sz = 0
stride_sm = 0
Expand All @@ -705,10 +811,12 @@ def triton_ragged_attention(
stride_sz = attn_scale.stride(0)
stride_sm = attn_scale.stride(1)

_ragged_hstu_attn_fwd[grid](
_ragged_hstu_attn_fwd[grid_tma](
Q=q,
K=k,
V=v,
desc_k=desc_k,
desc_v=desc_v,
seq_offsets=seq_offsets,
TS=None,
TW=None,
@@ -789,13 +897,48 @@ def triton_ragged_attention_relative_bias(
has_multiple_targets = num_targets is not None
has_max_pos_id = max_pos_ind is not None
has_max_attn_len = max_attn_len is not None
_, H, DimQ = q.shape
L, H, DimQ = q.shape
_, _, DimV = v.shape
out = torch.empty_like(v)
grid = lambda meta: ( # noqa E731
triton.cdiv(N, meta["BLOCK_M"]),
Z * H,
)

TMA_SIZE = 128
BLOCK_D_V, BLOCK_D_Q = DimV, DimQ
desc_k = torch.empty(TMA_SIZE, device="cuda", dtype=torch.int8)
desc_v = torch.empty(TMA_SIZE, device="cuda", dtype=torch.int8)

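# As above: grid_tma computes the launch grid and, for TMA configs, rebuilds
# the K/V descriptors once BLOCK_N is known.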
def grid_tma(META):
if not META["enable_tma"]:
# The grid callable must return the launch grid itself, not a nested lambda.
return (triton.cdiv(N, META["BLOCK_M"]), Z * H)

nonlocal desc_k
nonlocal desc_v
k_buf = torch.empty_like(desc_k, device="cpu", pin_memory=True)
v_buf = torch.empty_like(desc_v, device="cpu", pin_memory=True)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
k.data_ptr(),
L,
H * DimQ,
META['BLOCK_N'],
BLOCK_D_Q,
k.element_size(),
k_buf.numpy()
)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
v.data_ptr(),
L,
H * DimV,
META['BLOCK_N'],
BLOCK_D_V,
v.element_size(),
v_buf.numpy()
)
desc_k.copy_(k_buf, non_blocking=True)
desc_v.copy_(v_buf, non_blocking=True)
return (triton.cdiv(N, META["BLOCK_M"]), Z * H, 1)

stride_sz = 0
stride_sm = 0
if attn_scale is not None:
@@ -807,10 +950,12 @@ def triton_ragged_attention_relative_bias(
use_time_bias = relative_bias_type == "TIME" or relative_bias_type == "ALL"
use_pos_bias = relative_bias_type == "POSITION" or relative_bias_type == "ALL"

_ragged_hstu_attn_fwd[grid](
_ragged_hstu_attn_fwd[grid_tma](
Q=q,
K=k,
V=v,
desc_k=desc_k,
desc_v=desc_v,
seq_offsets=seq_offsets,
TS=timestamps,
TW=ts_weights,