From ada84fe4623b0fd476d28af5e0b10166415c0de4 Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Thu, 26 Dec 2024 14:23:40 +0800 Subject: [PATCH] pytorch ring attention implementation (#115) --- README.md | 16 ++-- test/test_hybrid_attn.py | 22 ++++- yunchang/comm/extract_local.py | 1 + yunchang/hybrid/attn_layer.py | 2 +- yunchang/hybrid/utils.py | 2 + yunchang/kernels/__init__.py | 14 ++- yunchang/kernels/attention.py | 128 +++++++++++++++++++++++----- yunchang/ring/__init__.py | 4 + yunchang/ring/ring_flash_attn.py | 12 --- yunchang/ring/ring_pytorch_attn.py | 132 +++++++++++++++++++++++++++++ yunchang/ring/utils.py | 1 + 11 files changed, 285 insertions(+), 49 deletions(-) create mode 100644 yunchang/ring/ring_pytorch_attn.py diff --git a/README.md b/README.md index 3548a17..bd5f06d 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,8 @@ Furthermore, Ring-Attention utilizes asynchronous peer-to-peer communication, wh ### 1. Installation -FlashAttention is the most important external dependency and is often the cause of errors when installing and using yunchang. Yunchang supports flash_attn 2.6.x and 2.7.x, both v3 and v2 versions. Additionally, yunchang supports using torch's SDPA for sequence parallelism without installing flash_attn. +FlashAttention is the most important external dependency and is often the cause of errors when installing and using yunchang. +Yunchang supports flash_attn 2.6.x and 2.7.x, both v3 and v2 versions. Additionally, yunchang supports runs without flash_attn, which is suitable for NPUs. As shown in the figure below, there are three usage methods based on the flash_attn situation: @@ -45,11 +46,7 @@ As shown in the figure below, there are three usage methods based on the flash_a 2. For A100, L40, hardware that supports FA v2, ring_flash_attn uses FA v2. -3. For hardware such as NPUs that does not support FA, use torch's SDPA. In this case, there is no need to install `flash_attn`, and you should apply `UlyssesAttention(sp_pg, attn_type=FlashAttentionImpl.TORCH)`. - -

+3. For hardware such as NPUs that does not support FA, use torch to implement attention computation. In this case, there is no need to install `flash_attn`, and you should apply `LongContextAttention(ring_impl_type="basic", attn_type=FlashAttentionImpl.TORCH)`. Option 1: pip install @@ -99,8 +96,8 @@ set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size) # attn_type could be FA, FA3, TORCH. longctx_attn = LongContextAttention(ring_impl_type="zigzag", attn_type=FlashAttentionImpl.FA) -# if you use Ulysses, where no flash_attn is supported, you can use the following code. -# UlyssesAttention(sp_pg, attn_type=FlashAttentionImpl.TORCH) +# if you use NPUs, where no flash_attn is supported, you can use the following code. +# LongContextAttention(ring_impl_type="zigzag", attn_type=FlashAttentionImpl.TORCH) # extract a local shard for the global Q, K, V. local_q = EXTRACT_FUNC_DICT["zigzag"]( @@ -126,7 +123,8 @@ local_out = usp_attn( ### 3.Test ```bash -torchrun --nproc_per_node=4 --master_port=12346 test/test_hybrid_attn.py --sp_ulysses_degree 2 --seqlen 1024 --use_bwd +torchrun --nproc_per_node=4 ./test/test_hybrid_attn.py --sp_ulysses_degree 2 --use_bwd --ring_impl_type "zigzag" --causal --attn_impl fa +torchrun --nproc_per_node=4 ./test/test_hybrid_attn.py --sp_ulysses_degree 2 --use_bwd --ring_impl_type "zigzag" --causal --attn_impl torch torchrun --nproc_per_node 8 test/test_hybrid_qkvpacked_attn.py ``` diff --git a/test/test_hybrid_attn.py b/test/test_hybrid_attn.py index 3d03be6..2568322 100644 --- a/test/test_hybrid_attn.py +++ b/test/test_hybrid_attn.py @@ -18,6 +18,14 @@ def parse_args(): help='whether to test backward pass (default: False)') parser.add_argument('--sp_ulysses_degree', type=int, default=None, help='sp_ulysses_degree (default: world_size)') + parser.add_argument('--ring_impl_type', type=str, default='basic', + choices=['basic', 'zigzag'], + help='ring implementation type (default: basic)') + parser.add_argument('--causal', action='store_true', + help='whether to use causal attention (default: False)') + parser.add_argument('--attn_impl', type=str, default='torch', + choices=['torch', 'fa', 'fa3'], + help='attention implementation type (default: torch)') return parser.parse_args() def log(msg, a, rank0_only=False): @@ -66,7 +74,7 @@ def log(msg, a, rank0_only=False): nheads = 32 d = 1280 // 32 dropout_p = 0 - causal = True + causal = args.causal deterministic = False use_bwd = args.use_bwd @@ -74,7 +82,7 @@ def log(msg, a, rank0_only=False): assert seqlen % world_size == 0 assert d % 8 == 0 - ring_impl_type = "basic" # You can change this to "basic" or "zigzag" if needed + ring_impl_type = args.ring_impl_type # Prepare inputs q = torch.randn( @@ -125,7 +133,15 @@ def log(msg, a, rank0_only=False): local_k.requires_grad = True local_v.requires_grad = True - usp_attn = LongContextAttention(ring_impl_type=ring_impl_type, attn_type=FlashAttentionImpl.FA) + # Map argument to FlashAttentionImpl enum + attn_impl_map = { + 'torch': FlashAttentionImpl.TORCH, + 'fa': FlashAttentionImpl.FA, + 'fa3': FlashAttentionImpl.FA3 + } + + usp_attn = LongContextAttention(ring_impl_type=ring_impl_type, + attn_type=attn_impl_map[args.attn_impl]) if rank == 0: print("#" * 30) diff --git a/yunchang/comm/extract_local.py b/yunchang/comm/extract_local.py index e4e8c9b..4396649 100644 --- a/yunchang/comm/extract_local.py +++ b/yunchang/comm/extract_local.py @@ -54,4 +54,5 @@ def zigzag_extract_local(value, rank, world_size, rd, ud, dim=1, *args, **kwargs "basic": 
basic_extract_local, "strip": stripe_extract_local, "zigzag": zigzag_extract_local, + "basic_pytorch": basic_extract_local, } diff --git a/yunchang/hybrid/attn_layer.py b/yunchang/hybrid/attn_layer.py index 68b2683..7cfcfa4 100644 --- a/yunchang/hybrid/attn_layer.py +++ b/yunchang/hybrid/attn_layer.py @@ -108,7 +108,7 @@ def forward( value_layer = SeqAllToAll4D.apply( self.ulysses_pg, value, self.scatter_idx, self.gather_idx, self.use_sync ) - + out = self.ring_attn_fn( query_layer, key_layer, diff --git a/yunchang/hybrid/utils.py b/yunchang/hybrid/utils.py index 52fb745..1c09aab 100644 --- a/yunchang/hybrid/utils.py +++ b/yunchang/hybrid/utils.py @@ -5,12 +5,14 @@ zigzag_ring_flash_attn_qkvpacked_func, stripe_flash_attn_func, stripe_flash_attn_qkvpacked_func, + ring_pytorch_attn_func, ) RING_IMPL_DICT = { "basic": ring_flash_attn_func, "zigzag": zigzag_ring_flash_attn_func, "strip": stripe_flash_attn_func, + "basic_pytorch": ring_pytorch_attn_func, } RING_IMPL_QKVPACKED_DICT = { diff --git a/yunchang/kernels/__init__.py b/yunchang/kernels/__init__.py index 871addd..4c40a82 100644 --- a/yunchang/kernels/__init__.py +++ b/yunchang/kernels/__init__.py @@ -3,7 +3,8 @@ flash_attn_backward, flash_attn3_func_forward, flash_attn3_func_backward, - torch_attn, + pytorch_attn_forward, + pytorch_attn_backward, HAS_FLASH_ATTN_HOPPER ) from enum import Enum, auto @@ -64,10 +65,15 @@ def fn(q, raise ValueError(f"Unknown stage: {stage}") elif impl_type == FlashAttentionImpl.TORCH: - if stage == "fwd-bwd" or stage == "fwd-only": - return torch_attn + if stage == "fwd-only": + return pytorch_attn_forward + elif stage == "bwd-only": + return pytorch_attn_backward + elif stage == "fwd-bwd": + from yunchang.ring.ring_pytorch_attn import pytorch_attn_func + return pytorch_attn_func else: - raise ValueError(f"FlashAttentionImpl.TORCH: bwd-only is not supported") + raise ValueError(f"Unknown stage: {stage}") else: raise ValueError(f"Unknown flash attention implementation: {impl_type}") diff --git a/yunchang/kernels/attention.py b/yunchang/kernels/attention.py index f8c9e95..7527302 100644 --- a/yunchang/kernels/attention.py +++ b/yunchang/kernels/attention.py @@ -1,5 +1,7 @@ +from typing import Optional, Tuple from yunchang.globals import HAS_FLASH_ATTN, HAS_FLASH_ATTN_HOPPER - +import math +import torch if HAS_FLASH_ATTN: import flash_attn from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward @@ -16,26 +18,112 @@ import torch.nn.functional as F -def torch_attn(q, k, v, dropout_p = 0.0, - softmax_scale = None, - causal=False, - *args, **kwargs): - batch_size, seq_len, hs, hd = q.size() - query = q.view(batch_size, -1, hs, hd).transpose(1, 2) - key = k.view(batch_size, -1, hs, hd).transpose(1, 2) - value = v.view(batch_size, -1, hs, hd).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, dropout_p=dropout_p, is_causal=causal - ) - hidden_states = hidden_states.transpose(1, 2).reshape( - batch_size, -1, hs, hd - ) - hidden_states = hidden_states.to(query.dtype) - return hidden_states, +import torch +aten = torch.ops.aten + + +def pytorch_attn_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p=0.0, + softmax_scale=None, + causal=True, + window_size=(-1, -1), + softcap=None, + alibi_slopes=None, + return_softmax=False, +): + """ + q shape (bs, seqlen, nhead, hs) + k shape (bs, seqlen, nhead, hs) + v 
shape (bs, seqlen, nhead, hs) + """ + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + out, lse = aten._scaled_dot_product_efficient_attention( + q, + k, + v, + attn_bias=None, + compute_log_sumexp=True, + dropout_p=dropout_p, + is_causal=causal, + scale=softmax_scale, + )[:2] + out = out.transpose(1, 2) + lse = lse.to(q.dtype) + return out, lse + +def pytorch_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + block_dq_buffer=None, # Add new parameters with default values + block_dk_buffer=None, + block_dv_buffer=None, + dropout_p=0.0, + softmax_scale=None, + bwd_causal=None, # This will replace the original causal parameter + window_size=None, + softcap=None, + alibi_slopes=None, + deterministic=True, + rng_state=None, + *args, + **kwargs, +): + # TODO(optim): use pytorch _scaled_dot_product_efficient_attention_backward + # Use efficient attention backward + # https://github.com/pytorch/pytorch/blob/main/tools/autograd/derivatives.yaml#L2874 + + # preprocess to reuse the original code + # from https://github.com/huggingface/picotron/blob/main/picotron/context_parallel/context_parallel.py + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + out = out.transpose(1, 2) + dout = dout.transpose(1, 2) + + batch_size, nheads, seqlen, d = q.shape + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + # Recreate S and P from log_sum_exp + S = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale + if bwd_causal: + causal_mask = torch.triu(torch.ones(seqlen, seqlen, device=q.device, dtype=torch.bool), diagonal=1) + S = S.masked_fill(causal_mask.unsqueeze(0).unsqueeze(1), float('-inf')) + + P = torch.exp(S - softmax_lse.unsqueeze(-1)) + # Step 1: Compute dV + dV = torch.matmul(P.transpose(-2, -1), dout) + # Step 2: Compute dP + dP = torch.matmul(dout, v.transpose(-2, -1)) + # Step 3: Compute D + D = torch.sum(dout * out, dim=-1, keepdim=True) + # Step 4: Compute dS + dS = P * (dP - D) + # Apply causal mask to dS if is_causal is True + if bwd_causal: + dS = dS.masked_fill(causal_mask.unsqueeze(0).unsqueeze(1), 0) + # Step 5: Compute dQ + dQ = torch.matmul(dS, k) * softmax_scale + # Step 6: Compute dK + dK = torch.matmul(dS.transpose(-2, -1), q) * softmax_scale + + # TODO() post process to reuse origina; code + dQ = dQ.transpose(1, 2) + dK = dK.transpose(1, 2) + dV = dV.transpose(1, 2) + + return dQ, dK, dV def flash_attn_forward(q, k, v, dropout_p = 0.0, diff --git a/yunchang/ring/__init__.py b/yunchang/ring/__init__.py index ffe7608..d9c7643 100644 --- a/yunchang/ring/__init__.py +++ b/yunchang/ring/__init__.py @@ -23,3 +23,7 @@ stripe_flash_attn_kvpacked_func, stripe_flash_attn_qkvpacked_func, ) + +from .ring_pytorch_attn import ( + ring_pytorch_attn_func, +) diff --git a/yunchang/ring/ring_flash_attn.py b/yunchang/ring/ring_flash_attn.py index 7f9ccec..bfd375c 100644 --- a/yunchang/ring/ring_flash_attn.py +++ b/yunchang/ring/ring_flash_attn.py @@ -32,18 +32,6 @@ def ring_flash_attn_forward( comm.commit() if not causal or step <= comm.rank: - # block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( - # q, - # k, - # v, - # dropout_p, - # softmax_scale, - # causal=causal and step == 0, - # window_size=window_size, - # softcap=softcap, - # alibi_slopes=alibi_slopes, - # return_softmax=True and dropout_p > 0, - # ) fn = select_flash_attn_impl(attn_type, stage="fwd-only") block_out, block_lse = fn( q, diff --git a/yunchang/ring/ring_pytorch_attn.py b/yunchang/ring/ring_pytorch_attn.py new file mode 100644 index 
0000000..9464b08 --- /dev/null +++ b/yunchang/ring/ring_pytorch_attn.py @@ -0,0 +1,132 @@ +# adapted from https://github.com/huggingface/picotron/blob/main/picotron/context_parallel/context_parallel.py +# Copyright 2024 The HuggingFace Inc. team and Jiarui Fang. + +import math +import torch +import torch.nn.functional as F +from typing import Any, Optional, Tuple +from yunchang.kernels import select_flash_attn_impl, FlashAttentionImpl +from .utils import RingComm, update_out_and_lse +from yunchang.kernels.attention import pytorch_attn_forward, pytorch_attn_backward + +def ring_pytorch_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, + attn_type: FlashAttentionImpl = FlashAttentionImpl.FA, +): + return RingAttentionFunc.apply(group, q, k, v, softmax_scale, causal) + +class RingAttentionFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, group, q, k, v, sm_scale, is_causal): + + comm = RingComm(group) + #TODO(fmom): add flex attention + #TODO(fmom): add flash attention + #TODO(fmom): Find a better to save these tensors without cloning + k_og = k.clone() + v_og = v.clone() + out, lse = None, None + next_k, next_v = None, None + + if sm_scale is None: + sm_scale = q.shape[-1] ** -0.5 + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k = comm.send_recv(k) + next_v = comm.send_recv(v) + comm.commit() + + if not is_causal or step <= comm.rank: + block_out, block_lse = pytorch_attn_forward( + q, k, v, softmax_scale = sm_scale, causal = is_causal and step == 0 + ) + print(f"block_out {block_out.shape} block_lse {block_lse.shape}") + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + + if step + 1 != comm.world_size: + comm.wait() + k = next_k + v = next_v + + out = out.to(q.dtype) + + ctx.save_for_backward(q, k_og, v_og, out, lse.squeeze(-1)) + ctx.sm_scale = sm_scale + ctx.is_causal = is_causal + ctx.group = group + + return out + + @staticmethod + def backward(ctx, dout, *args): + + + q, k, v, out, softmax_lse = ctx.saved_tensors + sm_scale = ctx.sm_scale + is_causal = ctx.is_causal + + kv_comm = RingComm(ctx.group) + d_kv_comm = RingComm(ctx.group) + + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + + block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + + next_dk, next_dv = None, None + next_k, next_v = None, None + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k = kv_comm.send_recv(k) + next_v = kv_comm.send_recv(v) + kv_comm.commit() + + if step <= kv_comm.rank or not is_causal: + bwd_causal = is_causal and step == 0 + + block_dq_buffer, block_dk_buffer, block_dv_buffer = pytorch_attn_backward( + dout, q, k, v, out, softmax_lse = softmax_lse, softmax_scale = sm_scale, causal = bwd_causal + ) + + if dq is None: + dq = block_dq_buffer.to(torch.float32) + dk = block_dk_buffer.to(torch.float32) + dv = block_dv_buffer.to(torch.float32) + else: + dq += block_dq_buffer + d_kv_comm.wait() + dk = block_dk_buffer + next_dk + dv = block_dv_buffer + next_dv + elif step != 0: + d_kv_comm.wait() + dk = next_dk + dv = next_dv + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k = next_k + v = next_v + + next_dk = d_kv_comm.send_recv(dk) + next_dv = d_kv_comm.send_recv(dv) 
+            d_kv_comm.commit()
+
+        d_kv_comm.wait()
+
+        # one gradient per forward input: (group, q, k, v, sm_scale, is_causal)
+        return None, dq, next_dk, next_dv, None, None
diff --git a/yunchang/ring/utils.py b/yunchang/ring/utils.py
index 8eb617c..f8fc446 100644
--- a/yunchang/ring/utils.py
+++ b/yunchang/ring/utils.py
@@ -92,6 +92,7 @@ def send_recv(
     ) -> torch.Tensor:
         if recv_tensor is None:
             res = torch.empty_like(to_send)
+            # print(f"send_recv: empty_like {to_send.shape}")
         else:
             res = recv_tensor
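
For readers trying the new backend end to end, below is a minimal usage sketch that mirrors the README example in this patch. It assumes a `torchrun --nproc_per_node=4` launch and a CUDA (or otherwise SDPA-capable) device; the top-level imports of `LongContextAttention` and `set_seq_parallel_pg` follow the README snippet, and the shard helper follows the `(value, rank, world_size, rd, ud, ...)` signature shown in `yunchang/comm/extract_local.py`. This is a sketch, not part of the patch itself.

```python
# Minimal sketch (not part of this patch): sequence-parallel attention without
# flash_attn. Launch with: torchrun --nproc_per_node=4 this_script.py
import torch
import torch.distributed as dist

from yunchang import LongContextAttention, set_seq_parallel_pg  # assumed top-level exports
from yunchang.comm.extract_local import EXTRACT_FUNC_DICT
from yunchang.kernels import FlashAttentionImpl

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

sp_ulysses_degree, sp_ring_degree = 2, 2
set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)

# The torch backend added in this patch: no flash_attn install required.
usp_attn = LongContextAttention(ring_impl_type="basic",
                                attn_type=FlashAttentionImpl.TORCH)

bs, seqlen, nheads, head_dim = 2, 1024, 32, 40
q, k, v = (torch.randn(bs, seqlen, nheads, head_dim,
                       device="cuda", dtype=torch.bfloat16) for _ in range(3))

# Extract this rank's shard of the global Q, K, V
# (positional args follow the signature in extract_local.py).
local_q, local_k, local_v = (
    EXTRACT_FUNC_DICT["basic"](t, rank, world_size, sp_ring_degree, sp_ulysses_degree)
    for t in (q, k, v)
)

local_out = usp_attn(local_q, local_k, local_v, causal=True)
print(rank, local_out.shape)
```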
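The changes to `yunchang/kernels/__init__.py` split the TORCH backend into separate forward and backward entry points selected by `stage`. A small sketch of calling the forward path directly, assuming a CUDA device (the aten memory-efficient attention kernel used by `pytorch_attn_forward` is a GPU kernel):

```python
import torch
from yunchang.kernels import FlashAttentionImpl, select_flash_attn_impl

# "fwd-only" now resolves to pytorch_attn_forward when the TORCH backend is chosen.
fwd = select_flash_attn_impl(FlashAttentionImpl.TORCH, stage="fwd-only")

q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

out, lse = fwd(q, k, v, softmax_scale=q.shape[-1] ** -0.5, causal=True)
# out keeps the (bs, seqlen, nheads, head_dim) layout; lse is the per-row
# log-sum-exp that the ring loop later merges via update_out_and_lse.
```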
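`pytorch_attn_backward` re-derives the attention gradients from the saved output and log-sum-exp instead of calling an autograd kernel. The standalone sketch below reproduces its formulas (S, P, dV, dP, D, dS, dQ, dK) on a tiny non-causal example and checks them against autograd; it illustrates the math only and does not call into yunchang.

```python
import torch

torch.manual_seed(0)
bs, nheads, seqlen, d = 1, 2, 8, 16
scale = d ** -0.5

q, k, v = (torch.randn(bs, nheads, seqlen, d, dtype=torch.float64, requires_grad=True)
           for _ in range(3))

# Reference forward and backward via autograd.
S_ref = (q @ k.transpose(-2, -1)) * scale
P_ref = S_ref.softmax(dim=-1)
out = P_ref @ v
dout = torch.randn_like(out)
out.backward(dout)
dq_ref, dk_ref, dv_ref = q.grad, k.grad, v.grad

# Manual backward, mirroring pytorch_attn_backward (non-causal case).
with torch.no_grad():
    S = (q @ k.transpose(-2, -1)) * scale
    lse = torch.logsumexp(S, dim=-1)              # what the forward pass saves
    P = torch.exp(S - lse.unsqueeze(-1))          # recreate softmax from lse
    dV = P.transpose(-2, -1) @ dout               # Step 1
    dP = dout @ v.transpose(-2, -1)               # Step 2
    D = (dout * out).sum(dim=-1, keepdim=True)    # Step 3
    dS = P * (dP - D)                             # Step 4
    dQ = (dS @ k) * scale                         # Step 5
    dK = (dS.transpose(-2, -1) @ q) * scale       # Step 6

for manual, ref in [(dQ, dq_ref), (dK, dk_ref), (dV, dv_ref)]:
    assert torch.allclose(manual, ref, atol=1e-8)
print("manual attention backward matches autograd")
```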
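The ring forward in `ring_pytorch_attn.py` folds each block's `(block_out, block_lse)` into running accumulators with `update_out_and_lse`. That helper is not part of this diff; the sketch below shows the standard log-sum-exp merge such a helper performs, written as a plain-PyTorch assumption about its behavior rather than the library's actual code.

```python
import torch

def merge_attn_blocks(out, lse, block_out, block_lse):
    """Merge one KV block's partial attention into the running (out, lse).

    Assumed shapes:
      out, block_out: (bs, seqlen, nheads, head_dim)
      lse, block_lse: (bs, nheads, seqlen)
    """
    if out is None:
        return block_out, block_lse
    # new_lse = log(exp(lse) + exp(block_lse)), computed stably
    new_lse = torch.logaddexp(lse, block_lse)
    # reweight each partial output by its share of the total softmax mass
    w_old = torch.exp(lse - new_lse).transpose(1, 2).unsqueeze(-1)
    w_new = torch.exp(block_lse - new_lse).transpose(1, 2).unsqueeze(-1)
    return w_old * out + w_new * block_out, new_lse
```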
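Both ring loops overlap block computation with `RingComm.send_recv` / `commit` / `wait`. The sketch below shows one way such a ring exchange can be built from `torch.distributed` point-to-point primitives over the default process group; it is a schematic stand-in, and yunchang's actual `RingComm` in `yunchang/ring/utils.py` may differ in details such as subgroup handling.

```python
import torch
import torch.distributed as dist

class SimpleRingComm:
    """Schematic ring communicator: each rank sends to rank+1, receives from rank-1."""

    def __init__(self):
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        self._ops, self._reqs = [], []

    def send_recv(self, to_send, recv_tensor=None):
        res = torch.empty_like(to_send) if recv_tensor is None else recv_tensor
        self._ops.append(dist.P2POp(dist.isend, to_send, (self.rank + 1) % self.world_size))
        self._ops.append(dist.P2POp(dist.irecv, res, (self.rank - 1) % self.world_size))
        return res  # filled only after commit() + wait()

    def commit(self):
        # launch all queued point-to-point ops together
        self._reqs = dist.batch_isend_irecv(self._ops)
        self._ops = []

    def wait(self):
        for req in self._reqs:
            req.wait()
        self._reqs = []
```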