diff --git a/example/layers/flash_triton.py b/example/layers/flash_triton.py
deleted file mode 100644
index 1d687378..00000000
--- a/example/layers/flash_triton.py
+++ /dev/null
@@ -1,830 +0,0 @@
-"""
-*Experimental* implementation of FlashAttention in Triton.
-
-We use the FlashAttention implementation from Phil Tillet as a starting point.
-https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
-
-Changes:
-- Implement both causal and non-causal attention.
-- Implement both self-attention and cross-attention.
-- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
-- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
-- Support attention bias.
-- Speed up the forward pass a bit, and only store the LSE instead of m and l.
-- Make the backward for d=128 much faster by reducing register spilling.
-- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
-small batch size * nheads.
-
-Caution:
-- This is an *experimental* implementation. The forward pass should be quite robust, but
-I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
-- This implementation has only been tested on A100.
-- If you plan to use a headdim other than 64 and 128, you should test for race conditions
-(due to the Triton compiler), as done in tests/test_flash_attn.py
-"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
-for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
-that there are none left for other head dimensions.
-
-Differences between this Triton version and the CUDA version:
-- Triton version doesn't support dropout.
-- Triton forward is generally faster than CUDA forward, while Triton backward is
-generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
-than CUDA forward + backward.
-- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
-- Triton version supports attention bias, while CUDA version doesn't.
-"""
-
-import math
-
-import torch
-
-import triton
-import triton.language as tl
-
-
-# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128
-# @triton.autotune(
-#     configs=[
-#         triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
-#         # This config has a race condition when EVEN_M == False, disabling it for now.
-# # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1), -# ], -# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'] -# ) -@triton.heuristics( - { - "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, - "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, - "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], - } -) -@triton.jit -def _fwd_kernel( - Q, K, V, Bias, Out, - Lse, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug - softmax_scale, - stride_qb, stride_qh, stride_qm, - stride_kb, stride_kh, stride_kn, - stride_vb, stride_vh, stride_vn, - stride_bb, stride_bh, stride_bm, - stride_ob, stride_oh, stride_om, - nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, - CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, - BIAS_TYPE: tl.constexpr, - IS_CAUSAL: tl.constexpr, - BLOCK_HEADDIM: tl.constexpr, - EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, -): - start_m = tl.program_id(0) - off_hb = tl.program_id(1) - off_b = off_hb // nheads - off_h = off_hb % nheads - # off_b = tl.program_id(1) - # off_h = tl.program_id(2) - # off_hb = off_b * nheads + off_h - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_HEADDIM) - # Initialize pointers to Q, K, V - # Adding parenthesis around indexing might use int32 math instead of int64 math? - # https://github.com/openai/triton/issues/741 - # I'm seeing a tiny bit of difference (5-7us) - q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) - k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) - v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) - if BIAS_TYPE == 'vector': - b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n - elif BIAS_TYPE == 'matrix': - b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :]) - # initialize pointer to m and l - t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m - lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) - # load q: it will stay in SRAM throughout - # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call - # tl.load(q_ptrs), we get the wrong output! 
- if EVEN_M & EVEN_N: - if EVEN_HEADDIM: - q = tl.load(q_ptrs) - else: - q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) - else: - q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0) - # loop over k, v and update accumulator - end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) - for start_n in range(0, end_n, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition - if EVEN_HEADDIM: - k = tl.load(k_ptrs + start_n * stride_kn) - else: - k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, - other=0.0) - else: - k = tl.load(k_ptrs + start_n * stride_kn, - mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), - other=0.0) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k, trans_b=True) - # Trying to combine the two masks seem to make the result wrong - if not EVEN_N: # Need to mask out otherwise the softmax is wrong - qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) - if IS_CAUSAL: - qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) - if BIAS_TYPE != 'none': - if BIAS_TYPE == 'vector': - if EVEN_N: - bias = tl.load(b_ptrs + start_n).to(tl.float32) - else: - bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32) - bias = bias[None, :] - elif BIAS_TYPE == 'matrix': - if EVEN_M & EVEN_N: - bias = tl.load(b_ptrs + start_n).to(tl.float32) - else: - bias = tl.load(b_ptrs + start_n, - mask=(offs_m[:, None] < seqlen_q) - & ((start_n + offs_n)[None, :] < seqlen_k), - other=0.0).to(tl.float32) - # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler - # can then fuse the mult and add into an fma instruction. But if we have bias we need to - # to multiply with softmax_scale here. 
- qk = qk * softmax_scale + bias - m_ij = tl.maximum(tl.max(qk, 1), lse_i) - - m_ij = tl.where(m_ij==float("-inf"),0,m_ij) - p = tl.exp(qk - m_ij[:, None]) - else: - m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i) - p = tl.exp(qk * softmax_scale - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # p = tl.where(p==float("-inf"), 0, p) - # l_ij = tl.maximum(tl.sum(p, 1),-1e16) - # scale acc_o - acc_o_scale = tl.exp(m_i - m_ij) - # mask_sum = tl.sum(bias == float("-inf"), axis=1) == BLOCK_M - # acc_o_scale = tl.where(mask_sum, 0, acc_o_scale) - # # -- update output accumulator -- - # BUG: have to store and immediately load - tl.store(t_ptrs, acc_o_scale) - acc_o_scale = tl.load(t_ptrs) - acc_o = acc_o * acc_o_scale[:, None] - # update acc_o - if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition - if EVEN_HEADDIM: - v = tl.load(v_ptrs + start_n * stride_vn) - else: - v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, - other=0.0) - else: - v = tl.load(v_ptrs + start_n * stride_vn, - mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), - other=0.0) - p = p.to(v.dtype) - acc_o += tl.dot(p, v) - - # -- update statistics - m_i = m_ij - l_i_new = tl.exp(lse_i - m_ij) + l_ij - lse_i = m_ij + tl.log(l_i_new) - lse_i = tl.where(lse_i == float("-inf"), 0, lse_i) - o_scale = tl.exp(m_i - lse_i) - # BUG: have to store and immediately load - tl.store(t_ptrs, o_scale) - o_scale = tl.load(t_ptrs) - acc_o = acc_o * o_scale[:, None] - # rematerialize offsets to save registers - start_m = tl.program_id(0) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - # write back l and m - lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m - tl.store(lse_ptrs, lse_i) - # initialize pointers to output - offs_d = tl.arange(0, BLOCK_HEADDIM) - out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :]) - if EVEN_M: - if EVEN_HEADDIM: - tl.store(out_ptrs, acc_o) - else: - tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) - else: - if EVEN_HEADDIM: - tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) - else: - tl.store(out_ptrs, acc_o, - mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)) - - -@triton.jit -def _bwd_preprocess_do_o_dot( - Out, DO, Delta, - stride_ob, stride_oh, stride_om, - stride_dob, stride_doh, stride_dom, - nheads, seqlen_q, seqlen_q_rounded, headdim, - BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, -): - start_m = tl.program_id(0) - off_hb = tl.program_id(1) - off_b = off_hb // nheads - off_h = off_hb % nheads - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_d = tl.arange(0, BLOCK_HEADDIM) - # load - o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], - mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32) - do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :], - mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32) - delta = tl.sum(o * do, axis=1) - # write-back - tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) - - -@triton.jit -def _bwd_store_dk_dv( - dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, - EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, -): - # 
[2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False, - # if we just call tl.store(dv_ptrs), there's a race condition - if EVEN_N & EVEN_M: - if EVEN_HEADDIM: - tl.store(dv_ptrs, dv) - tl.store(dk_ptrs, dk) - else: - tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) - tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) - else: - if EVEN_HEADDIM: - tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) - tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) - else: - tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) - tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) - - -@triton.jit -def _bwd_kernel_one_col_block( - start_n, - Q, K, V, Bias, - DO, DQ, DK, DV, - LSE, D, - softmax_scale, - stride_qm, stride_kn, stride_vn, stride_bm, - stride_dom, stride_dqm, stride_dkn, stride_dvn, - seqlen_q, seqlen_k, headdim, - ATOMIC_ADD: tl.constexpr, - BIAS_TYPE: tl.constexpr, - IS_CAUSAL: tl.constexpr, - BLOCK_HEADDIM: tl.constexpr, - EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, -): - # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N) - begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M - # initialize row/col offsets - offs_qm = begin_m + tl.arange(0, BLOCK_M) - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - offs_m = tl.arange(0, BLOCK_M) - offs_d = tl.arange(0, BLOCK_HEADDIM) - # initialize pointers to value-like data - q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :]) - k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) - v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) - do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) - dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :]) - if BIAS_TYPE == 'vector': - b_ptrs = Bias + offs_n - elif BIAS_TYPE == 'matrix': - b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :]) - # initialize dv and dk - dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) - dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) - # There seems to be some problem with Triton pipelining that makes results wrong for - # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop - # may have zero step, and pipelining with the bias matrix could screw it up. - # So we just exit early. - if begin_m >= seqlen_q: - dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) - dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) - _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, - EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM) - return - # k and v stay in SRAM throughout - # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False, - # if we just call tl.load(k_ptrs), we get the wrong output! 
- if EVEN_N & EVEN_M: - if EVEN_HEADDIM: - k = tl.load(k_ptrs) - v = tl.load(v_ptrs) - else: - k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) - v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) - v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) - else: - k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), - other=0.0) - v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), - other=0.0) - # loop over rows - num_block_m = tl.cdiv(seqlen_q, BLOCK_M) - for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M): - start_m = tl.multiple_of(start_m, BLOCK_M) - offs_m_curr = start_m + offs_m - # load q, k, v, do on-chip - # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117) - if EVEN_M & EVEN_HEADDIM: - q = tl.load(q_ptrs) - else: - if EVEN_HEADDIM: - q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) - else: - q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) - & (offs_d[None, :] < headdim), other=0.0) - # recompute p = softmax(qk, dim=-1).T - qk = tl.dot(q, k, trans_b=True) - # Trying to combine the two masks seem to make the result wrong - if not EVEN_N: # Need to mask out otherwise the softmax is wrong - qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf")) - if IS_CAUSAL: - qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) - if BIAS_TYPE != 'none': - tl.debug_barrier() # Race condition otherwise - if BIAS_TYPE == 'vector': - if EVEN_N: - bias = tl.load(b_ptrs).to(tl.float32) - else: - bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32) - bias = bias[None, :] - elif BIAS_TYPE == 'matrix': - if EVEN_M & EVEN_N: - bias = tl.load(b_ptrs).to(tl.float32) - else: - bias = tl.load(b_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) - & (offs_n[None, :] < seqlen_k), - other=0.0).to(tl.float32) - qk = qk * softmax_scale + bias - # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong. - # Also wrong for headdim=64. - if not (EVEN_M & EVEN_HEADDIM): - tl.debug_barrier() - lse_i = tl.load(LSE + offs_m_curr) - if BIAS_TYPE == 'none': - p = tl.exp(qk * softmax_scale - lse_i[:, None]) - else: - p = tl.exp(qk - lse_i[:, None]) - # compute dv - # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call - # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs - # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512, - # the output is correct. - if EVEN_M & EVEN_HEADDIM: - do = tl.load(do_ptrs) - else: - # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask. - do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) - & (offs_d[None, :] < headdim), other=0.0) - # if EVEN_M: - # if EVEN_HEADDIM: - # do = tl.load(do_ptrs) - # else: - # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0) - # else: - # if EVEN_HEADDIM: - # do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) - # else: - # do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) - # & (offs_d[None, :] < headdim), other=0.0) - dv += tl.dot(p.to(do.dtype), do, trans_a=True) - # compute dp = dot(v, do) - # There seems to be a race condition when headdim=48/96, and dq, dk are wrong. 
- # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True - # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False - if not (EVEN_M & EVEN_HEADDIM): - tl.debug_barrier() - dp = tl.dot(do, v, trans_b=True) - # There's a race condition for headdim=48 - if not EVEN_HEADDIM: - tl.debug_barrier() - # compute ds = p * (dp - delta[:, None]) - # Putting the subtraction after the dp matmul (instead of before) is slightly faster - Di = tl.load(D + offs_m_curr) - # Converting ds to q.dtype here reduces register pressure and makes it much faster - # for BLOCK_HEADDIM=128 - ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) - # compute dk = dot(ds.T, q) - dk += tl.dot(ds, q, trans_a=True) - # compute dq - if not (EVEN_M & EVEN_HEADDIM): # Otherewise there's a race condition when BIAS_TYPE='matrix' - tl.debug_barrier() - if not ATOMIC_ADD: - if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M - dq = tl.load(dq_ptrs, eviction_policy="evict_last") - dq += tl.dot(ds, k) - tl.store(dq_ptrs, dq, eviction_policy="evict_last") - else: - if EVEN_HEADDIM: - dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0, - eviction_policy="evict_last") - dq += tl.dot(ds, k) - tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q, - eviction_policy="evict_last") - else: - dq = tl.load(dq_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0, eviction_policy="evict_last") - dq += tl.dot(ds, k) - tl.store(dq_ptrs, dq, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - eviction_policy="evict_last") - else: # If we're parallelizing across the seqlen_k dimension - dq = tl.dot(ds, k) - if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M - tl.atomic_add(dq_ptrs, dq) - else: - if EVEN_HEADDIM: - tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) - else: - tl.atomic_add(dq_ptrs, dq, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim)) - # increment pointers - dq_ptrs += BLOCK_M * stride_dqm - q_ptrs += BLOCK_M * stride_qm - do_ptrs += BLOCK_M * stride_dom - if BIAS_TYPE == 'matrix': - b_ptrs += BLOCK_M * stride_bm - # write-back - dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) - dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) - _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, - EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM) - - -def init_to_zero(name): - return lambda nargs: nargs[name].zero_() - - -@triton.autotune( - configs=[ - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), - # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now - # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4* - # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), - # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), - # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), - # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, 
pre_hook=init_to_zero('DQ')), - ], - key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'], -) -@triton.heuristics( - { - "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, - "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, - "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], - } -) -@triton.jit -def _bwd_kernel( - Q, K, V, Bias, - DO, DQ, DK, DV, - LSE, D, - softmax_scale, - stride_qb, stride_qh, stride_qm, - stride_kb, stride_kh, stride_kn, - stride_vb, stride_vh, stride_vn, - stride_bb, stride_bh, stride_bm, - stride_dob, stride_doh, stride_dom, - stride_dqb, stride_dqh, stride_dqm, - stride_dkb, stride_dkh, stride_dkn, - stride_dvb, stride_dvh, stride_dvn, - nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, - CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, - BIAS_TYPE: tl.constexpr, - IS_CAUSAL: tl.constexpr, - BLOCK_HEADDIM: tl.constexpr, - SEQUENCE_PARALLEL: tl.constexpr, - EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, -): - off_hb = tl.program_id(1) - off_b = off_hb // nheads - off_h = off_hb % nheads - # offset pointers for batch/head - Q += off_b * stride_qb + off_h * stride_qh - K += off_b * stride_kb + off_h * stride_kh - V += off_b * stride_vb + off_h * stride_vh - DO += off_b * stride_dob + off_h * stride_doh - DQ += off_b * stride_dqb + off_h * stride_dqh - DK += off_b * stride_dkb + off_h * stride_dkh - DV += off_b * stride_dvb + off_h * stride_dvh - if BIAS_TYPE != 'none': - Bias += off_b * stride_bb + off_h * stride_bh - # pointer to row-wise quantities in value-like data - D += off_hb * seqlen_q_rounded - LSE += off_hb * seqlen_q_rounded - if not SEQUENCE_PARALLEL: - num_block_n = tl.cdiv(seqlen_k, BLOCK_N) - for start_n in range(0, num_block_n): - _bwd_kernel_one_col_block( - start_n, - Q, K, V, Bias, - DO, DQ, DK, DV, - LSE, D, - softmax_scale, - stride_qm, stride_kn, stride_vn, stride_bm, - stride_dom, stride_dqm, stride_dkn, stride_dvn, - seqlen_q, seqlen_k, headdim, - ATOMIC_ADD=False, - BIAS_TYPE=BIAS_TYPE, - IS_CAUSAL=IS_CAUSAL, - BLOCK_HEADDIM=BLOCK_HEADDIM, - EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N - ) - else: - start_n = tl.program_id(0) - _bwd_kernel_one_col_block( - start_n, - Q, K, V, Bias, - DO, DQ, DK, DV, - LSE, D, - softmax_scale, - stride_qm, stride_kn, stride_vn, stride_bm, - stride_dom, stride_dqm, stride_dkn, stride_dvn, - seqlen_q, seqlen_k, headdim, - ATOMIC_ADD=True, - BIAS_TYPE=BIAS_TYPE, - IS_CAUSAL=IS_CAUSAL, - BLOCK_HEADDIM=BLOCK_HEADDIM, - EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N - ) - - -def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None): - # shape constraints - batch, seqlen_q, nheads, d = q.shape - _, seqlen_k, _, _ = k.shape - assert k.shape == (batch, seqlen_k, nheads, d) - assert v.shape == (batch, seqlen_k, nheads, d) - assert d <= 128, 'FlashAttention only support head dimensions up to 128' - assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type' - assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16' - assert q.is_cuda and k.is_cuda and v.is_cuda - softmax_scale = softmax_scale or 1.0 / math.sqrt(d) - - has_bias = bias is not None - bias_type = 'none' - if has_bias: - assert bias.dtype in [q.dtype, torch.float] - assert bias.is_cuda - assert bias.dim() == 4 - if bias.stride(-1) != 1: - bias = 
bias.contiguous() - if bias.shape[2:] == (1, seqlen_k): - bias_type = 'vector' - elif bias.shape[2:] == (seqlen_q, seqlen_k): - bias_type = 'matrix' - else: - raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)' - ' or (seqlen_q, seqlen_k)') - bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) - bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) - - seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 - lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) - tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) - o = torch.empty_like(q) - - BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) - BLOCK = 128 - num_warps = 4 if d <= 64 else 8 - grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) - _fwd_kernel[grid]( - q, k, v, bias, o, - lse, tmp, - softmax_scale, - q.stride(0), q.stride(2), q.stride(1), - k.stride(0), k.stride(2), k.stride(1), - v.stride(0), v.stride(2), v.stride(1), - *bias_strides, - o.stride(0), o.stride(2), o.stride(1), - nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, - seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations) - # Can't use kwargs here because triton autotune expects key to be args, not kwargs - # IS_CAUSAL=causal, BLOCK_HEADDIM=d, - bias_type, causal, BLOCK_HEADDIM, - BLOCK_M=BLOCK, BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return o, lse, softmax_scale # softmax_scale could have been updated - - -def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None): - # Make sure that the last dimension is contiguous - if do.stride(-1) != 1: - do = do.contiguous() - batch, seqlen_q, nheads, d = q.shape - _, seqlen_k, _, _ = k.shape - # assert d in {16, 32, 64, 128} - assert d <= 128 - seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 - assert lse.shape == (batch, nheads, seqlen_q_rounded) - assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1 - assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1 - softmax_scale = softmax_scale or 1.0 / math.sqrt(d) - # dq_accum = torch.zeros_like(q, dtype=torch.float32) - dq_accum = torch.empty_like(q, dtype=torch.float32) - delta = torch.empty_like(lse) - # delta = torch.zeros_like(lse) - - BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) - grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) - _bwd_preprocess_do_o_dot[grid]( - o, do, delta, - o.stride(0), o.stride(2), o.stride(1), - do.stride(0), do.stride(2), do.stride(1), - nheads, seqlen_q, seqlen_q_rounded, d, - BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM, - ) - - has_bias = bias is not None - bias_type = 'none' - if has_bias: - assert bias.dtype in [q.dtype, torch.float] - assert bias.is_cuda - assert bias.dim() == 4 - assert bias.stride(-1) == 1 - if bias.shape[2:] == (1, seqlen_k): - bias_type = 'vector' - elif bias.shape[2:] == (seqlen_q, seqlen_k): - bias_type = 'matrix' - else: - raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)' - ' or (seqlen_q, seqlen_k)') - bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) - bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) - - # BLOCK_M = 128 - # BLOCK_N = 64 - # num_warps = 4 - grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, - batch * nheads) - _bwd_kernel[grid]( - q, k, v, bias, - do, dq_accum, dk, dv, - lse, delta, - 
softmax_scale, - q.stride(0), q.stride(2), q.stride(1), - k.stride(0), k.stride(2), k.stride(1), - v.stride(0), v.stride(2), v.stride(1), - *bias_strides, - do.stride(0), do.stride(2), do.stride(1), - dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1), - dk.stride(0), dk.stride(2), dk.stride(1), - dv.stride(0), dv.stride(2), dv.stride(1), - nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, - seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations) - # Can't use kwargs here because triton autotune expects key to be args, not kwargs - # IS_CAUSAL=causal, BLOCK_HEADDIM=d, - bias_type, causal, BLOCK_HEADDIM, - # SEQUENCE_PARALLEL=False, - # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - # num_warps=num_warps, - # num_stages=1, - ) - dq.copy_(dq_accum) - - -class FlashAttnQKVPackedFunc(torch.autograd.Function): - - @staticmethod - def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None): - """ - qkv: (batch, seqlen, 3, nheads, headdim) - bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen). - For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen). - ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen) - """ - # Make sure that the last dimension is contiguous - if qkv.stride(-1) != 1: - qkv = qkv.contiguous() - o, lse, ctx.softmax_scale = _flash_attn_forward( - qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal, - softmax_scale=softmax_scale - ) - ctx.save_for_backward(qkv, o, lse, bias) - ctx.causal = causal - return o - - @staticmethod - def backward(ctx, do): - qkv, o, lse, bias = ctx.saved_tensors - assert not ctx.needs_input_grad[1], 'FlashAttention does not support bias gradient yet' - # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd - # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. - with torch.inference_mode(): - dqkv = torch.empty_like(qkv) - _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse, - dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2], - bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale) - return dqkv, None, None, None - - -flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply - - -class FlashAttnKVPackedFunc(torch.autograd.Function): - - @staticmethod - def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None): - """ - q: (batch, seqlen_q, nheads, headdim) - kv: (batch, seqlen_k, 2, nheads, headdim) - bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k). - For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k). - ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) - """ - # Make sure that the last dimension is contiguous - q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]] - o, lse, ctx.softmax_scale = _flash_attn_forward( - q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale - ) - ctx.save_for_backward(q, kv, o, lse, bias) - ctx.causal = causal - return o - - @staticmethod - def backward(ctx, do): - q, kv, o, lse, bias = ctx.saved_tensors - assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet' - # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd - # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. 
- with torch.inference_mode(): - dq = torch.empty_like(q) - dkv = torch.empty_like(kv) - _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse, - dq, dkv[:, :, 0], dkv[:, :, 1], - bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale) - return dq, dkv, None, None, None - - -flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply - - -class FlashAttnFunc(torch.autograd.Function): - - @staticmethod - def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None): - """ - q: (batch_size, seqlen_q, nheads, headdim) - k, v: (batch_size, seqlen_k, nheads, headdim) - bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k). - For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k). - ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) - """ - # Make sure that the last dimension is contiguous - q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]] - o, lse, ctx.softmax_scale = _flash_attn_forward( - q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale - ) - ctx.save_for_backward(q, k, v, o, lse, bias) - ctx.causal = causal - return o - - @staticmethod - def backward(ctx, do): - q, k, v, o, lse, bias = ctx.saved_tensors - assert not ctx.needs_input_grad[3], 'FlashAttention does not support bias gradient yet' - # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd - # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. - with torch.inference_mode(): - dq = torch.empty_like(q) - dk = torch.empty_like(k) - dv = torch.empty_like(v) - _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, - bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale) - return dq, dk, dv, None, None, None - - -flash_attn_func = FlashAttnFunc.apply diff --git a/example/layers/test_linear.py b/example/layers/test_linear.py deleted file mode 100644 index 27568e12..00000000 --- a/example/layers/test_linear.py +++ /dev/null @@ -1,152 +0,0 @@ -import torch -import torch.nn.functional as F -import bmtrain as bmt -from bmtrain.global_var import config -from . 
import TransformerEncoder - - -gb = 1024.0 * 1024.0 * 1024.0 - -class CustomLinear(torch.autograd.Function): - @staticmethod - def forward(ctx, x, weight, bias=None): - ctx.save_for_backward(x, weight, bias) - return F.linear(x, weight, bias) - - @staticmethod - def backward(ctx, grad_output): - x, weight, bias = ctx.saved_tensors - grad_x = grad_weight = grad_bias = None - if x.requires_grad: - grad_x = grad_output.matmul(weight) - if weight.requires_grad: - dim = grad_output.dim() - grad_weight = grad_output.reshape(-1, - grad_output.shape[-1]).t().matmul(x.reshape(-1, x.shape[-1])) - if bias is not None and bias.requires_grad: - grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0) - return grad_x, grad_weight, grad_bias - - -class LinearFunctionForZeroStage3(torch.autograd.Function): - # Note that both forward and backward are @staticmethods - @staticmethod - #@autocast_custom_fwd - # bias is an optional argument - def forward(ctx, input, weight, bias=None): - - ctx.save_for_backward(input, weight, bias) - - if input.dim() == 2 and bias is not None: - # fused op is marginally faster - ret = torch.addmm(bias, input, weight.t()) - else: - output = input.matmul(weight.t()) - if bias is not None: - output += bias - ret = output - - return ret - - # This function has only a single output, so it gets only one gradient - @staticmethod - #@autocast_custom_bwd - def backward(ctx, grad_output): - # This is a pattern that is very convenient - at the top of backward - # unpack saved_tensors and initialize all gradients w.r.t. inputs to - # None. Thanks to the fact that additional trailing Nones are - # ignored, the return statement is simple even when the function has - # optional inputs. - input, weight, bias = ctx.saved_tensors - - grad_input = grad_weight = grad_bias = None - - #print(f"backward shaped grad_output {grad_output.shape}, input {input.shape}, weight {weight.shape} and bias {bias.shape if bias is not None else None}") - # These needs_input_grad checks are optional and there only to - # improve efficiency. If you want to make your code simpler, you can - # skip them. Returning gradients for inputs that don't require it is - # not an error. 
- if ctx.needs_input_grad[0]: - #print(f"Computing grad input weight {weight.shape} grad_output {grad_output.shape}") - grad_input = grad_output.matmul(weight) - #print(f"Computed grad input {grad_input.shape}") - if ctx.needs_input_grad[1]: - #print("Computing grad weight") - dim = grad_output.dim() - if dim > 2: - grad_weight = grad_output.reshape(-1, - grad_output.shape[-1]).t().matmul(input.reshape(-1, input.shape[-1])) - else: - grad_weight = grad_output.t().matmul(input) - #print(f"Computed grad weight grad_weight {grad_weight.shape}") - if bias is not None and ctx.needs_input_grad[2]: - #print("Computing grad bias") - grad_bias = grad_output.sum(0) - #print("Done computing grad bias") - #print("needs bias") - #print(f"backward shaped grad_input {grad_input.shape}, grad_weight {grad_weight.shape}, grad_bias {grad_bias.shape if grad_bias is not None else None}") - return grad_input, grad_weight, grad_bias - - -class Linear(bmt.DistributedModule): - def __init__(self, in_features : int, out_features: int, bias: bool = False, dtype = torch.float16) -> None: - super().__init__() - - self.in_features = in_features - self.out_features = out_features - self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=dtype, device="cuda"), init_method=torch.nn.init.xavier_normal_) - if bias: - self.bias = bmt.DistributedParameter(torch.empty((1, out_features), dtype=dtype, device="cuda"), init_method=torch.nn.init.zeros_) - else: - self.register_parameter('bias', None) - - def forward(self, input): - #return CustomLinear.apply(input, self.weight, self.bias) - return LinearFunctionForZeroStage3.apply(input, self.weight, self.bias) - - def extra_repr(self) -> str: - return 'in_features={}, out_features={}, bias={}'.format( - self.in_features, self.out_features, self.bias is not None - ) - -class Feedforward(bmt.DistributedModule): - def __init__(self, dim_model : int, dim_ff : int, bias : bool = False, dtype = torch.float16) -> None: - super().__init__() - - self.w_in = Linear(dim_model, dim_ff, bias = bias, dtype=dtype) - self.w_out = Linear(dim_ff, dim_model, bias = bias, dtype=dtype) - self.gate = Linear(dim_model, dim_ff, bias = bias, dtype=dtype) - - self.relu = torch.nn.ReLU() - - def forward(self, input : torch.Tensor) -> torch.Tensor: - gate_out = self.relu(self.gate(input)) - return self.w_out(self.w_in(input) * gate_out) - -bmt.init_distributed(zero_level=2) - -linears = [] -for i in range(10): - linears.append(bmt.CheckpointBlock(TransformerEncoder(8192, 20480), use_checkpoint=False)) - -linears = bmt.TransformerBlockList(linears) - -device = torch.device('cuda') -bmt.synchronize() -if config['rank'] == 0: - print('before forward', torch.cuda.memory_allocated(device) / gb) - -x = torch.randn(4096, 8192, dtype=torch.float16, device=device).requires_grad_() -bmt.synchronize() -if config['rank'] == 0: - print('init input', torch.cuda.memory_allocated(device) / gb) - -y = linears(x) -bmt.synchronize() -if config['rank'] == 0: - print('after forward', torch.cuda.memory_allocated(device) / gb) - -y.sum().backward() -bmt.synchronize() -if config['rank'] == 0: - print('after backward', torch.cuda.memory_allocated(device) / gb) diff --git a/example/test_attn.py b/example/test_attn.py deleted file mode 100644 index 642f2a8b..00000000 --- a/example/test_attn.py +++ /dev/null @@ -1,49 +0,0 @@ -import torch -import torch.nn.functional as F -import bmtrain as bmt -from bmtrain.global_var import config -from layers import Attention - - -gb = 1024.0 * 1024.0 * 1024.0 - 
-bmt.init_distributed(zero_level=3)
-
-linears = []
-for i in range(10):
-    linears.append(bmt.CheckpointBlock(Attention(
-        dim_model=8192,
-        dim_head=128,
-        num_heads=64,
-        dropout_p=0.0,
-        use_flash_attn=True,
-        dtype=torch.half
-        ),
-        use_checkpoint=False)
-    )
-
-linears = bmt.TransformerBlockList(linears)
-
-device = torch.device('cuda')
-bmt.synchronize()
-if config['rank'] == 0:
-    print('before forward', torch.cuda.memory_allocated(device) / gb)
-batch_size=1
-seq_len=4096
-x = torch.randn(batch_size, seq_len, 8192, dtype=torch.float16, device=device).requires_grad_()
-bmt.synchronize()
-if config['rank'] == 0:
-    print('init input', torch.cuda.memory_allocated(device) / gb)
-enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda()
-mask = torch.arange(seq_len).unsqueeze(0) <= torch.arange(seq_len).unsqueeze(1)
-mask = mask.unsqueeze(0).unsqueeze(0)
-print(mask.shape)
-y = linears(x,x,mask)
-bmt.synchronize()
-if config['rank'] == 0:
-    print('after forward', torch.cuda.memory_allocated(device) / gb)
-
-y.sum().backward()
-bmt.synchronize()
-if config['rank'] == 0:
-    print('after backward', torch.cuda.memory_allocated(device) / gb)
diff --git a/example/test_block.py b/example/test_block.py
deleted file mode 100644
index 90f57182..00000000
--- a/example/test_block.py
+++ /dev/null
@@ -1,72 +0,0 @@
-import torch
-import torch.nn.functional as F
-import bmtrain as bmt
-from bmtrain.global_var import config
-from layers import TransformerEncoder
-
-
-gb = 1024.0 * 1024.0 * 1024.0
-def reserved(device):
-    return torch.cuda.memory_reserved(device) / gb
-def allocated(device):
-    return torch.cuda.memory_allocated(device) / gb
-def max_allocated(device):
-    return torch.cuda.max_memory_allocated(device) / gb
-
-bmt.init_distributed(zero_level=3)
-
-linears = []
-for i in range(10):
-    linears.append(TransformerEncoder(
-        dim_model=8192,
-        dim_head=128,
-        num_heads=64,
-        dim_ff=20480,
-        bias=False,
-        dtype=torch.half
-        )
-    )
-
-linears = bmt.TransformerBlockList(linears)
-# linears = torch.nn.ModuleList(linears)
-
-optimizer = bmt.optim.AdamOffloadOptimizer(linears.parameters(), weight_decay=1e-2)
-lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0)
-
-optim_manager = bmt.optim.OptimManager(loss_scale=2**20)
-optim_manager.add_optimizer(optimizer, lr_scheduler)
-
-bmt.synchronize()
-
-device = torch.device('cuda')
-bmt.synchronize()
-if config['rank'] == 0:
-    print('before init input', allocated(device), reserved(device))
-batch_size=1
-seq_len=4096
-
-for i in range(4):
-    x = torch.randn(batch_size, seq_len, 8192, dtype=torch.float16, device=device).requires_grad_()
-    bmt.synchronize()
-    if config['rank'] == 0:
-        print('init input', allocated(device), reserved(device))
-    enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda()
-    mask = torch.arange(seq_len).unsqueeze(0) <= torch.arange(seq_len).unsqueeze(1)
-    mask = mask.unsqueeze(0).unsqueeze(0).to(device)
-#y = linears(x,mask)
-    y = x
-    for encoder in linears:
-        y = encoder(y, mask)
-    bmt.synchronize()
-    if config['rank'] == 0:
-        print('after forward', allocated(device), reserved(device), max_allocated(device))
-
-    y.sum().backward()
-    bmt.synchronize()
-    if config['rank'] == 0:
-        print('after backward', allocated(device), reserved(device), max_allocated(device))
-    optim_manager.step()
-    if config['rank'] == 0:
-        print('after optimizer', allocated(device), reserved(device))
-#torch.cuda.empty_cache()
-    optim_manager.zero_grad()
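
A minimal usage sketch of the flash_attn_func entry point from the removed example/layers/flash_triton.py, following the FlashAttnFunc docstring in that file (fp16/bf16 CUDA tensors, head dimension up to 128, optional bias broadcastable to (batch, nheads, seqlen_q, seqlen_k)); the import path and tensor sizes below are illustrative assumptions, not part of the diff:

    import torch
    from layers.flash_triton import flash_attn_func  # assumed import path under example/

    batch, seqlen, nheads, headdim = 2, 1024, 16, 64  # headdim must be <= 128
    q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    # Positional args are (q, k, v, bias, causal, softmax_scale); bias=None here,
    # causal=True, and softmax_scale=None falls back to 1/sqrt(headdim) in the kernel wrapper.
    out = flash_attn_func(q, k, v, None, True, None)  # -> (batch, seqlen, nheads, headdim)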