From f57284a93291064aa465e29b07d29f45473dfde3 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Wed, 21 Aug 2024 23:30:49 -0700 Subject: [PATCH] [Triton SMEM] Add not-yet-landed usage of Triton SMEM feature with autotuning The following change requires a private patchset that is not yet available outside of https://github.com/plotfi/triton/pull/4 This patch adds usage of shared memory using the tl.local_copy and tl.gather operations for the TW (time bias) and PW (position bias) tensors for the forward pass kernel. Autotuning is also hooked up to the usage of these shared memory operators --- ops/triton/triton_ragged_hstu_attention.py | 263 ++++++++++++++++++--- 1 file changed, 226 insertions(+), 37 deletions(-) diff --git a/ops/triton/triton_ragged_hstu_attention.py b/ops/triton/triton_ragged_hstu_attention.py index a10ae7a..461aa9d 100644 --- a/ops/triton/triton_ragged_hstu_attention.py +++ b/ops/triton/triton_ragged_hstu_attention.py @@ -58,148 +58,296 @@ def _get_fw_configs() -> List[triton.Config]: # noqa: C901 ) else: configs = [ + ## Comment these out to get the baseline: triton.Config( - {"BLOCK_M": 16, "BLOCK_N": 32}, + {"BLOCK_M": 16, "BLOCK_N": 32, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=2, num_warps=2, ), triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32}, + {"BLOCK_M": 32, "BLOCK_N": 32, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=2, num_warps=2, ), triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32}, + {"BLOCK_M": 32, "BLOCK_N": 32, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=4, num_warps=2, ), triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32}, + {"BLOCK_M": 32, "BLOCK_N": 32, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=2, num_warps=4, ), triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32}, + {"BLOCK_M": 32, "BLOCK_N": 32, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=4, num_warps=4, ), triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 64}, + {"BLOCK_M": 32, "BLOCK_N": 64, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=2, num_warps=4, ), triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 64}, + {"BLOCK_M": 32, "BLOCK_N": 64, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=4, num_warps=4, ), triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 64}, + {"BLOCK_M": 32, "BLOCK_N": 64, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=4, num_warps=8, ), triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 128}, + {"BLOCK_M": 32, "BLOCK_N": 128, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=2, num_warps=4, ), triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 128}, + {"BLOCK_M": 32, "BLOCK_N": 128, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=2, num_warps=8, ), triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32}, + {"BLOCK_M": 64, "BLOCK_N": 32, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=4, num_warps=2, ), triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32}, + {"BLOCK_M": 64, "BLOCK_N": 32, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=2, num_warps=4, ), triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32}, + {"BLOCK_M": 64, "BLOCK_N": 32, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=4, num_warps=4, ), triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32}, + {"BLOCK_M": 64, "BLOCK_N": 32, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=2, num_warps=8, ), triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64}, + {"BLOCK_M": 64, "BLOCK_N": 64, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=2, num_warps=2, ), triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64}, + {"BLOCK_M": 64, "BLOCK_N": 64, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=2, num_warps=4, ), triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64}, + {"BLOCK_M": 64, "BLOCK_N": 64, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=4, num_warps=4, ), triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64}, + {"BLOCK_M": 64, "BLOCK_N": 64, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=4, num_warps=8, ), triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32}, + {"BLOCK_M": 128, "BLOCK_N": 32, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=2, num_warps=2, ), triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32}, + {"BLOCK_M": 128, "BLOCK_N": 32, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=4, num_warps=2, ), triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32}, + {"BLOCK_M": 128, "BLOCK_N": 32, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=2, num_warps=4, ), triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32}, + {"BLOCK_M": 128, "BLOCK_N": 32, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=4, num_warps=4, ), triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32}, + {"BLOCK_M": 128, "BLOCK_N": 32, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=2, num_warps=8, ), triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32}, + {"BLOCK_M": 128, "BLOCK_N": 32, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=4, num_warps=8, ), triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64}, + {"BLOCK_M": 128, "BLOCK_N": 64, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=2, num_warps=4, ), triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64}, + {"BLOCK_M": 128, "BLOCK_N": 64, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=2, num_warps=8, ), triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64}, + {"BLOCK_M": 128, "BLOCK_N": 64, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=4, num_warps=8, ), triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128}, + {"BLOCK_M": 128, "BLOCK_N": 128, "enable_tw_preload": True, "enable_pw_preload": True}, num_stages=4, num_warps=4, ), triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128}, + {"BLOCK_M": 128, "BLOCK_N": 128, "enable_tw_preload": True, "enable_pw_preload": True}, + num_stages=2, + num_warps=8, + ), + + ## Keep these to get the baseline: + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 32, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=2, + num_warps=2, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=2, + num_warps=2, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=4, + num_warps=2, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=4, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 128, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 128, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=2, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=4, + num_warps=2, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=2, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=2, + num_warps=2, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=4, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=2, + num_warps=2, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=4, + num_warps=2, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=2, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=4, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=2, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=4, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "enable_tw_preload": False, "enable_pw_preload": False}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "enable_tw_preload": False, "enable_pw_preload": False}, num_stages=2, num_warps=8, ), @@ -223,6 +371,8 @@ def _ragged_hstu_attn_fwd_one_block( # noqa: C901 ts_0, TW, PW, + TW_PRELOAD, + PW_PRELOAD, alpha, MAX_SEQ_LEN, num_buckets, @@ -247,6 +397,8 @@ def _ragged_hstu_attn_fwd_one_block( # noqa: C901 ALLOW_TF32: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + enable_tw_preload: tl.constexpr, + enable_pw_preload: tl.constexpr, ): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- @@ -302,10 +454,17 @@ def _ragged_hstu_attn_fwd_one_block( # noqa: C901 ts = ts.to(tl.int32) ts = tl.where(ts > 0, ts, 0) ts = tl.where(ts < num_buckets, ts, num_buckets) - ts_w = tl.load( - TW + ts, - mask=mask_m[:, None] and mask_n[None, :], - ) + if enable_tw_preload: + ts_w = tl.gather( + TW_PRELOAD, ts, + # There is a flaky test that sometimes fails when mask is provided + # mask=mask_m[:, None] and mask_n[None, :], + ) + else: + ts_w = tl.load( + TW + ts, + mask=mask_m[:, None] and mask_n[None, :], + ) attn_bias = attn_bias + ts_w if USE_POS_BIAS: if HAS_MAX_POS_IND: @@ -318,10 +477,17 @@ def _ragged_hstu_attn_fwd_one_block( # noqa: C901 ) else: offs_pos_w = offs_n_minus_m + MAX_SEQ_LEN - 1 - pos_w = tl.load( - PW + offs_pos_w, - mask=mask_m[:, None] and mask_n[None, :], - ) + if enable_pw_preload: + pos_w = tl.gather( + PW_PRELOAD, offs_pos_w, + # There is a flaky test that sometimes fails when mask is provided + # mask=mask_m[:, None] and mask_n[None, :], + ) + else: + pos_w = tl.load( + PW + offs_pos_w, + mask=mask_m[:, None] and mask_n[None, :], + ) attn_bias = attn_bias + pos_w qk = qk + attn_bias elif ATTN_BIAS_TYPE == "separate": @@ -409,6 +575,8 @@ def _ragged_hstu_attn_fwd( # noqa C901 BLOCK_N: tl.constexpr, max_attn_len, HAS_MAX_ATTN_LEN: tl.constexpr, + enable_tw_preload: tl.constexpr, + enable_pw_preload: tl.constexpr, ): # M_CTX == N_CTX off_hz = tl.program_id(1) @@ -517,6 +685,19 @@ def _ragged_hstu_attn_fwd( # noqa C901 K_block_ptr = tl.advance(K_block_ptr, (0, low)) # pyre-ignore[61] V_block_ptr = tl.advance(V_block_ptr, (low, 0)) + + TW_PRELOAD = None + if USE_TIME_BIAS and enable_tw_preload: + tw_bucket_range = tl.arange(0, 2048) + TW_PRELOAD = tl.load(TW + tw_bucket_range) + TW_PRELOAD = tl.local_copy(TW_PRELOAD) + + PW_PRELOAD = None + if USE_POS_BIAS and enable_pw_preload: + pw_bucket_range = tl.arange(0, 4096) + PW_PRELOAD = tl.load(PW + pw_bucket_range) + PW_PRELOAD = tl.local_copy(PW_PRELOAD) + # pyre-ignore[61] for start_n in range(low, high, BLOCK_N): cur_offs_n = offs_n + start_n @@ -542,6 +723,8 @@ def _ragged_hstu_attn_fwd( # noqa C901 ts_0=ts_0 if ATTN_BIAS_TYPE == "fused" and USE_TIME_BIAS else None, TW=TW, PW=PW, + TW_PRELOAD=TW_PRELOAD, + PW_PRELOAD=PW_PRELOAD, alpha=alpha, MAX_SEQ_LEN=MAX_SEQ_LEN, num_buckets=num_buckets, @@ -568,6 +751,8 @@ def _ragged_hstu_attn_fwd( # noqa C901 ALLOW_TF32=ALLOW_TF32, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + enable_tw_preload=enable_tw_preload, + enable_pw_preload=enable_pw_preload, ) K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) @@ -608,6 +793,8 @@ def _ragged_hstu_attn_fwd( # noqa C901 ), TW=TW, PW=PW, + TW_PRELOAD=TW_PRELOAD, + PW_PRELOAD=PW_PRELOAD, alpha=alpha, MAX_SEQ_LEN=MAX_SEQ_LEN, num_buckets=num_buckets, @@ -634,6 +821,8 @@ def _ragged_hstu_attn_fwd( # noqa C901 ALLOW_TF32=ALLOW_TF32, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + enable_tw_preload=enable_tw_preload, + enable_pw_preload=enable_pw_preload, ) K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))