[WIP] TMA Version of HSTU (Autotuned) #71

Open · wants to merge 1 commit into base: main
223 changes: 184 additions & 39 deletions ops/triton/triton_ragged_hstu_attention.py
@@ -57,7 +57,7 @@ def _get_fw_configs() -> List[triton.Config]: # noqa: C901
)
)
else:
configs = [
base_configs = [
triton.Config(
{"BLOCK_M": 16, "BLOCK_N": 32},
num_stages=2,
@@ -204,6 +204,20 @@ def _get_fw_configs() -> List[triton.Config]: # noqa: C901
num_warps=8,
),
]

for config in base_configs:
for tma_config in [False, True]:
configs.append(
triton.Config(
{
"BLOCK_M": config.kwargs["BLOCK_M"], "BLOCK_N": config.kwargs["BLOCK_N"],
"enable_tma": tma_config
},
num_stages=config.num_stages,
num_warps=config.num_warps,
)
)

return configs
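
Side note: the loop above doubles the autotuning search space by pairing every (BLOCK_M, BLOCK_N, num_stages, num_warps) combination with enable_tma in {False, True}. A minimal standalone sketch of that expansion pattern, with hypothetical block sizes and independent of this kernel:

from typing import List

import triton


def expand_with_flag(base_configs: List[triton.Config], flag: str) -> List[triton.Config]:
    # Emit each base config twice, once with the flag off and once with it on,
    # so the autotuner can pick the best (block shape, flag) combination.
    expanded: List[triton.Config] = []
    for config in base_configs:
        for value in (False, True):
            expanded.append(
                triton.Config(
                    {**dict(config.kwargs), flag: value},
                    num_stages=config.num_stages,
                    num_warps=config.num_warps,
                )
            )
    return expanded


# Example: two block-size configs become four TMA/non-TMA variants.
configs = expand_with_flag(
    [
        triton.Config({"BLOCK_M": 16, "BLOCK_N": 32}, num_stages=2, num_warps=2),
        triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_stages=4, num_warps=4),
    ],
    flag="enable_tma",
)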


@@ -218,6 +232,10 @@ def _ragged_hstu_attn_fwd_one_block( # noqa: C901
q,
K_block_ptr,
V_block_ptr,
K_desc_ptr,
V_desc_ptr,
offset_h,
seq_start,
n_targets,
ts_1_ptrs,
ts_0,
@@ -245,13 +263,27 @@ def _ragged_hstu_attn_fwd_one_block( # noqa: C901
HAS_MAX_ATTN_LEN: tl.constexpr,
IS_DELTA_Q: tl.constexpr,
ALLOW_TF32: tl.constexpr,
BLOCK_D_Q: tl.constexpr,
BLOCK_D_V: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
enable_tma: tl.constexpr,
):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero")
qk = tl.dot(q, k, allow_tf32=ALLOW_TF32) * alpha
k = None
qk = None
if enable_tma:
k = tl._experimental_descriptor_load(
K_desc_ptr,
[(seq_start + start_n).to(tl.int32), offset_h.to(tl.int32)],
[BLOCK_N, BLOCK_D_Q],
tl.bfloat16,
)
qk = tl.dot(q, tl.trans(k), allow_tf32=ALLOW_TF32) * alpha
else:
k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero")
qk = tl.dot(q, k, allow_tf32=ALLOW_TF32) * alpha
invalid_mask = offs_m[:, None] == offs_n[None, :]
if HAS_MULTIPLE_TARGETS:
if INVALID_MASK_TYPE == "lower_triangular":
@@ -335,7 +367,17 @@ def _ragged_hstu_attn_fwd_one_block( # noqa: C901
silu = tl.where(invalid_mask, silu, 0)
if HAS_ATTN_SCALE:
silu = silu * attn_scale[:, None]
v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero")

v = None
if enable_tma:
v = tl._experimental_descriptor_load(
V_desc_ptr,
[(seq_start + start_n).to(tl.int32), offset_h.to(tl.int32)],
[BLOCK_N, BLOCK_D_V],
tl.bfloat16,
)
else:
v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero")
silu = silu.to(v.dtype)
return tl.dot(silu, v, allow_tf32=ALLOW_TF32)
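
For orientation, the inner block function above now has two read paths for K and V: the existing one through tl.make_block_ptr block pointers, and a TMA path that reads through 128-byte tensor-map descriptors filled on the host with fill_2d_tma_descriptor and consumed in-kernel with tl._experimental_descriptor_load. A self-contained sketch of that host/device pattern, mirroring the calls used in this PR (hypothetical kernel and tensor names; assumes a Hopper GPU, bfloat16 data, sizes that are multiples of the block shape, and a Triton build that exposes these experimental TMA utilities):

import torch
import triton
import triton.language as tl

TMA_SIZE = 128  # byte size of one CUtensorMap descriptor


@triton.jit
def tma_copy_kernel(desc_x, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # The descriptor bytes were written by the host, so fence the tensormap proxy
    # before the first descriptor load (same inline asm as in the kernel above).
    tl.inline_asm_elementwise(
        "fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg",
        "=r, l", [desc_x], dtype=tl.int32, is_pure=False, pack=1,
    )
    # TMA load: element offsets into the 2D global tensor plus the tile shape.
    x = tl._experimental_descriptor_load(
        desc_x,
        [pid_m * BLOCK_M, pid_n * BLOCK_N],
        [BLOCK_M, BLOCK_N],
        tl.bfloat16,
    )
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    tl.store(Y + offs_m[:, None] * N + offs_n[None, :], x)


def copy_with_tma(x: torch.Tensor, BLOCK_M: int = 64, BLOCK_N: int = 64) -> torch.Tensor:
    # x: (M, N) bfloat16 CUDA tensor with M and N multiples of the block sizes.
    M, N = x.shape
    y = torch.empty_like(x)
    # Fill the descriptor in pinned host memory, then copy it to the device;
    # the kernel receives it as a plain int8 tensor, like desc_k / desc_v above.
    desc_x = torch.empty(TMA_SIZE, device="cuda", dtype=torch.int8)
    buf = torch.empty(TMA_SIZE, device="cpu", dtype=torch.int8, pin_memory=True)
    triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
        x.data_ptr(), M, N, BLOCK_M, BLOCK_N, x.element_size(), buf.numpy()
    )
    desc_x.copy_(buf, non_blocking=True)
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    tma_copy_kernel[grid](desc_x, y, N, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
    return y


# Usage: y = copy_with_tma(torch.randn(1024, 256, device="cuda", dtype=torch.bfloat16))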

@@ -359,6 +401,8 @@ def _ragged_hstu_attn_fwd( # noqa C901
Q,
K,
V,
desc_k,
desc_v,
seq_offsets,
TS,
TW,
@@ -409,6 +453,7 @@ def _ragged_hstu_attn_fwd( # noqa C901
BLOCK_N: tl.constexpr,
max_attn_len: tl.constexpr,
HAS_MAX_ATTN_LEN: tl.constexpr,
enable_tma: tl.constexpr,
):
# M_CTX == N_CTX
off_hz = tl.program_id(1)
@@ -452,22 +497,25 @@ def _ragged_hstu_attn_fwd( # noqa C901
block_shape=(BLOCK_M, BLOCK_D_Q),
order=(1, 0),
)
K_block_ptr = tl.make_block_ptr(
base=K + off_h * stride_kh + seq_start * stride_kn,
shape=(BLOCK_D_Q, seq_len),
strides=(1, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_D_Q, BLOCK_N),
order=(0, 1),
)
V_block_ptr = tl.make_block_ptr(
base=V + off_h * stride_vh + seq_start * stride_vn,
shape=(seq_len, BLOCK_D_V),
strides=(stride_vn, 1),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_D_V),
order=(1, 0),
)
K_block_ptr = None
V_block_ptr = None
if not enable_tma:
K_block_ptr = tl.make_block_ptr(
base=K + off_h * stride_kh + seq_start * stride_kn,
shape=(BLOCK_D_Q, seq_len),
strides=(1, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_D_Q, BLOCK_N),
order=(0, 1),
)
V_block_ptr = tl.make_block_ptr(
base=V + off_h * stride_vh + seq_start * stride_vn,
shape=(seq_len, BLOCK_D_V),
strides=(stride_vn, 1),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_D_V),
order=(1, 0),
)
mask_m = offs_m < seq_len
if ATTN_BIAS_TYPE == "fused" and USE_TIME_BIAS:
ts_0_ptrs = TS + off_z * stride_ts + offs_m
@@ -486,6 +534,7 @@ def _ragged_hstu_attn_fwd( # noqa C901
scale_ptrs = Scale + off_z * stride_sz
attn_scale = tl.load(scale_ptrs + offs_m * stride_sm, mask=offs_m < seq_len)

# TODO: convert q to TMA
q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero")
acc = tl.zeros([BLOCK_M, BLOCK_D_V], dtype=tl.float32)
if INVALID_MASK_TYPE == "lower_triangular":
@@ -511,8 +560,13 @@ def _ragged_hstu_attn_fwd( # noqa C901
elif INVALID_MASK_TYPE == "upper_triangular":
low = start_m
high = seq_len
if enable_tma:
tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l",
[desc_k], dtype=tl.int32, is_pure=False, pack=1)
tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l",
[desc_v], dtype=tl.int32, is_pure=False, pack=1)
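# The descriptors were written through the generic proxy (host fill + H2D copy),
# so this acquire fence publishes them to the TMA unit before the first descriptor load.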
# pyre-ignore[61]
if low > 0:
if low > 0 and not enable_tma:
# pyre-ignore[61]
K_block_ptr = tl.advance(K_block_ptr, (0, low))
# pyre-ignore[61]
@@ -531,6 +585,11 @@ def _ragged_hstu_attn_fwd( # noqa C901
q=q,
K_block_ptr=K_block_ptr,
V_block_ptr=V_block_ptr,
K_desc_ptr=desc_k,
V_desc_ptr=desc_v,
offset_h=off_h * stride_vh,
seq_start=seq_start,
# pyre-ignore[61]
n_targets=n_targets if HAS_MULTIPLE_TARGETS else None,
ts_1_ptrs=(
# pyre-ignore[61]
@@ -566,20 +625,25 @@ def _ragged_hstu_attn_fwd( # noqa C901
HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN,
IS_DELTA_Q=IS_DELTA_Q,
ALLOW_TF32=ALLOW_TF32,
BLOCK_D_Q=BLOCK_D_Q,
BLOCK_D_V=BLOCK_D_V,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
enable_tma=enable_tma,
)
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
if not enable_tma:
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))

if HAS_MULTIPLE_TARGETS and INVALID_MASK_TYPE == "lower_triangular":
# pyre-ignore[61]
if uih_end < start_m:
low_delta = start_m
high_delta = start_m + BLOCK_M
offset = (low_delta - uih_end).to(tl.int32) # pyre-ignore [61]
K_block_ptr = tl.advance(K_block_ptr, (0, offset))
V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
if not enable_tma:
offset = (low_delta - uih_end).to(tl.int32) # pyre-ignore [61]
K_block_ptr = tl.advance(K_block_ptr, (0, offset))
V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
for start_delta in range(low_delta, high_delta, BLOCK_N):
cur_offs_n = offs_n + start_delta
mask_n = cur_offs_n < seq_len
@@ -593,6 +657,10 @@ def _ragged_hstu_attn_fwd( # noqa C901
q=q,
K_block_ptr=K_block_ptr,
V_block_ptr=V_block_ptr,
K_desc_ptr=desc_k,
V_desc_ptr=desc_v,
offset_h=off_h * stride_vh,
seq_start=seq_start,
n_targets=n_targets if HAS_MULTIPLE_TARGETS else None,
ts_1_ptrs=(
# pyre-ignore[61]
@@ -632,11 +700,15 @@ def _ragged_hstu_attn_fwd( # noqa C901
HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN,
IS_DELTA_Q=IS_DELTA_Q,
ALLOW_TF32=ALLOW_TF32,
BLOCK_D_Q=BLOCK_D_Q,
BLOCK_D_V=BLOCK_D_V,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
enable_tma=enable_tma,
)
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
if not enable_tma:
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))

if IS_DELTA_Q:
start_m_delta = tl.program_id(0) * BLOCK_M
@@ -648,6 +720,7 @@ def _ragged_hstu_attn_fwd( # noqa C901
+ offs_v_d[None, :]
)
out_ptrs = Out + off_o
# todo: convert out to tma
tl.store(out_ptrs, acc, mask=(offs_m_delta < DeltaSize)[:, None])
else:
# rematerialize offsets to save registers
@@ -691,10 +764,43 @@ def triton_ragged_attention(
has_attn_scale = attn_scale is not None
has_max_attn_len = max_attn_len is not None

grid = lambda meta: ( # noqa E731
triton.cdiv(N, meta["BLOCK_M"]),
Z * H,
)
TMA_SIZE = 128
BLOCK_D_V, BLOCK_D_Q = DimV, DimQ
desc_k = torch.empty((TMA_SIZE), device="cuda", dtype=torch.int8)
desc_v = torch.empty((TMA_SIZE), device="cuda", dtype=torch.int8)

def grid_tma(META):
if not META["enable_tma"]:
return (triton.cdiv(N, META["BLOCK_M"]), Z * H, 1)

nonlocal desc_k
nonlocal desc_v
k_buf = torch.empty_like(desc_k, device="cpu", pin_memory=True)
v_buf = torch.empty_like(desc_v, device="cpu", pin_memory=True)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
k.data_ptr(),
L,
H * DimQ,
META['BLOCK_N'],
BLOCK_D_Q,
k.element_size(),
k_buf.numpy()
)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
v.data_ptr(),
L,
H * DimV,
META['BLOCK_N'],
BLOCK_D_V,
v.element_size(),
v_buf.numpy()
)
desc_k.copy_(k_buf, non_blocking=True)
desc_v.copy_(v_buf, non_blocking=True)
return (triton.cdiv(N, META["BLOCK_M"]), Z * H, 1)

stride_sz = 0
stride_sm = 0
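
A callable grid is used here rather than a plain tuple because Triton evaluates a callable grid with the meta-parameters of the config it is about to launch; that is what lets grid_tma see the autotuned enable_tma and BLOCK_N values and fill the TMA descriptors to match before each launch. A stripped-down sketch of that hook, with hypothetical sizes:

import triton

n_rows, batch_heads = 4096, 8  # hypothetical launch dimensions


def grid(meta):
    # `meta` carries the kernel's constexpr arguments, including autotuner-selected
    # values such as BLOCK_M, so per-config setup can run here right before launch.
    return (triton.cdiv(n_rows, meta["BLOCK_M"]), batch_heads)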
@@ -705,10 +811,12 @@ def triton_ragged_attention(
stride_sz = attn_scale.stride(0)
stride_sm = attn_scale.stride(1)

_ragged_hstu_attn_fwd[grid](
_ragged_hstu_attn_fwd[grid_tma](
Q=q,
K=k,
V=v,
desc_k=desc_k,
desc_v=desc_v,
seq_offsets=seq_offsets,
TS=None,
TW=None,
@@ -789,13 +897,48 @@ def triton_ragged_attention_relative_bias(
has_multiple_targets = num_targets is not None
has_max_pos_id = max_pos_ind is not None
has_max_attn_len = max_attn_len is not None
_, H, DimQ = q.shape
L, H, DimQ = q.shape
_, _, DimV = v.shape
out = torch.empty_like(v)
grid = lambda meta: ( # noqa E731
triton.cdiv(N, meta["BLOCK_M"]),
Z * H,
)

TMA_SIZE = 128
BLOCK_D_V, BLOCK_D_Q = DimV, DimQ
desc_k = torch.empty((TMA_SIZE), device="cuda", dtype=torch.int8)
desc_v = torch.empty((TMA_SIZE), device="cuda", dtype=torch.int8)

def grid_tma(META):
if not META["enable_tma"]:
return (triton.cdiv(N, META["BLOCK_M"]), Z * H, 1)

nonlocal desc_k
nonlocal desc_v
k_buf = torch.empty_like(desc_k, device="cpu", pin_memory=True)
v_buf = torch.empty_like(desc_v, device="cpu", pin_memory=True)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
k.data_ptr(),
L,
H * DimQ,
META['BLOCK_N'],
BLOCK_D_Q,
k.element_size(),
k_buf.numpy()
)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
v.data_ptr(),
L,
H * DimV,
META['BLOCK_N'],
BLOCK_D_V,
v.element_size(),
v_buf.numpy()
)
desc_k.copy_(k_buf, non_blocking=True)
desc_v.copy_(v_buf, non_blocking=True)
return (triton.cdiv(N, META["BLOCK_M"]), Z * H, 1)

stride_sz = 0
stride_sm = 0
if attn_scale is not None:
@@ -807,10 +950,12 @@ def triton_ragged_attention_relative_bias(
use_time_bias = relative_bias_type == "TIME" or relative_bias_type == "ALL"
use_pos_bias = relative_bias_type == "POSITION" or relative_bias_type == "ALL"

_ragged_hstu_attn_fwd[grid](
_ragged_hstu_attn_fwd[grid_tma](
Q=q,
K=k,
V=v,
desc_k=desc_k,
desc_v=desc_v,
seq_offsets=seq_offsets,
TS=timestamps,
TW=ts_weights,