
Commit

pytorch ring attention implementation (#115)
feifeibear authored Dec 26, 2024
1 parent 28f8c97 commit ada84fe
Showing 11 changed files with 285 additions and 49 deletions.
16 changes: 7 additions & 9 deletions README.md
@@ -37,19 +37,16 @@ Furthermore, Ring-Attention utilizes asynchronous peer-to-peer communication, wh

### 1. Installation

FlashAttention is the most important external dependency and is often the cause of errors when installing and using yunchang. Yunchang supports flash_attn 2.6.x and 2.7.x, both v3 and v2 versions. Additionally, yunchang supports using torch's SDPA for sequence parallelism without installing flash_attn.
FlashAttention is the most important external dependency and is often the cause of errors when installing and using yunchang.
Yunchang supports flash_attn 2.6.x and 2.7.x, both v3 and v2 versions. Additionally, yunchang supports running without flash_attn, which is suitable for NPUs.

There are three usage modes, depending on which flash_attn version (if any) is available:

1. For H100, B100, and other hardware that supports FA v3, ring_flash_attn uses FA v3.

2. For A100, L40, and other hardware that supports FA v2, ring_flash_attn uses FA v2.

3. For hardware such as NPUs that does not support FA, use torch's SDPA. In this case, there is no need to install `flash_attn`, and you should apply `UlyssesAttention(sp_pg, attn_type=FlashAttentionImpl.TORCH)`.

<p align="center">
<img src="./media/usp_fa.png">
</p>
3. For hardware such as NPUs that does not support FA, use torch's SDPA for the attention computation. In this case, there is no need to install `flash_attn`, and you should apply `LongContextAttention(ring_impl_type="basic", attn_type=FlashAttentionImpl.TORCH)`.
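
A minimal sketch of this hardware-to-backend mapping (the import paths, the `pick_attn_type` helper, and the capability flags are illustrative assumptions, not part of yunchang's API):

```python
from yunchang import LongContextAttention
from yunchang.kernels import FlashAttentionImpl

def pick_attn_type(has_fa3: bool, has_fa2: bool) -> FlashAttentionImpl:
    """Hypothetical helper: map hardware capability to an attention backend."""
    if has_fa3:                      # e.g. H100 / B100 with FA v3 installed
        return FlashAttentionImpl.FA3
    if has_fa2:                      # e.g. A100 / L40 with FA v2 installed
        return FlashAttentionImpl.FA
    return FlashAttentionImpl.TORCH  # e.g. NPUs, no flash_attn installed

# After set_seq_parallel_pg(...) has initialized the sequence-parallel groups:
attn = LongContextAttention(ring_impl_type="basic",
                            attn_type=pick_attn_type(has_fa3=False, has_fa2=False))
```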

Option 1: pip install

@@ -99,8 +96,8 @@ set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)
# attn_type could be FA, FA3, TORCH.
longctx_attn = LongContextAttention(ring_impl_type="zigzag", attn_type=FlashAttentionImpl.FA)

# if you use Ulysses, where no flash_attn is supported, you can use the following code.
# UlyssesAttention(sp_pg, attn_type=FlashAttentionImpl.TORCH)
# if you use NPUs, where no flash_attn is supported, you can use the following code.
# LongContextAttention(ring_impl_type="zigzag", attn_type=FlashAttentionImpl.TORCH)

# extract a local shard for the global Q, K, V.
local_q = EXTRACT_FUNC_DICT["zigzag"](
@@ -126,7 +123,8 @@ local_out = usp_attn(
### 3. Test

```bash
torchrun --nproc_per_node=4 --master_port=12346 test/test_hybrid_attn.py --sp_ulysses_degree 2 --seqlen 1024 --use_bwd
torchrun --nproc_per_node=4 ./test/test_hybrid_attn.py --sp_ulysses_degree 2 --use_bwd --ring_impl_type "zigzag" --causal --attn_impl fa
torchrun --nproc_per_node=4 ./test/test_hybrid_attn.py --sp_ulysses_degree 2 --use_bwd --ring_impl_type "zigzag" --causal --attn_impl torch
torchrun --nproc_per_node 8 test/test_hybrid_qkvpacked_attn.py
```

22 changes: 19 additions & 3 deletions test/test_hybrid_attn.py
@@ -18,6 +18,14 @@ def parse_args():
help='whether to test backward pass (default: False)')
parser.add_argument('--sp_ulysses_degree', type=int, default=None,
help='sp_ulysses_degree (default: world_size)')
parser.add_argument('--ring_impl_type', type=str, default='basic',
choices=['basic', 'zigzag'],
help='ring implementation type (default: basic)')
parser.add_argument('--causal', action='store_true',
help='whether to use causal attention (default: False)')
parser.add_argument('--attn_impl', type=str, default='torch',
choices=['torch', 'fa', 'fa3'],
help='attention implementation type (default: torch)')
return parser.parse_args()

def log(msg, a, rank0_only=False):
@@ -66,15 +74,15 @@ def log(msg, a, rank0_only=False):
nheads = 32
d = 1280 // 32
dropout_p = 0
causal = True
causal = args.causal
deterministic = False

use_bwd = args.use_bwd

assert seqlen % world_size == 0
assert d % 8 == 0

ring_impl_type = "basic" # You can change this to "basic" or "zigzag" if needed
ring_impl_type = args.ring_impl_type

# Prepare inputs
q = torch.randn(
@@ -125,7 +133,15 @@ def log(msg, a, rank0_only=False):
local_k.requires_grad = True
local_v.requires_grad = True

usp_attn = LongContextAttention(ring_impl_type=ring_impl_type, attn_type=FlashAttentionImpl.FA)
# Map argument to FlashAttentionImpl enum
attn_impl_map = {
'torch': FlashAttentionImpl.TORCH,
'fa': FlashAttentionImpl.FA,
'fa3': FlashAttentionImpl.FA3
}

usp_attn = LongContextAttention(ring_impl_type=ring_impl_type,
attn_type=attn_impl_map[args.attn_impl])

if rank == 0:
print("#" * 30)
1 change: 1 addition & 0 deletions yunchang/comm/extract_local.py
@@ -54,4 +54,5 @@ def zigzag_extract_local(value, rank, world_size, rd, ud, dim=1, *args, **kwargs
"basic": basic_extract_local,
"strip": stripe_extract_local,
"zigzag": zigzag_extract_local,
"basic_pytorch": basic_extract_local,
}
2 changes: 1 addition & 1 deletion yunchang/hybrid/attn_layer.py
@@ -108,7 +108,7 @@ def forward(
value_layer = SeqAllToAll4D.apply(
self.ulysses_pg, value, self.scatter_idx, self.gather_idx, self.use_sync
)

out = self.ring_attn_fn(
query_layer,
key_layer,
2 changes: 2 additions & 0 deletions yunchang/hybrid/utils.py
@@ -5,12 +5,14 @@
zigzag_ring_flash_attn_qkvpacked_func,
stripe_flash_attn_func,
stripe_flash_attn_qkvpacked_func,
ring_pytorch_attn_func,
)

RING_IMPL_DICT = {
"basic": ring_flash_attn_func,
"zigzag": zigzag_ring_flash_attn_func,
"strip": stripe_flash_attn_func,
"basic_pytorch": ring_pytorch_attn_func,
}

RING_IMPL_QKVPACKED_DICT = {
14 changes: 10 additions & 4 deletions yunchang/kernels/__init__.py
@@ -3,7 +3,8 @@
flash_attn_backward,
flash_attn3_func_forward,
flash_attn3_func_backward,
torch_attn,
pytorch_attn_forward,
pytorch_attn_backward,
HAS_FLASH_ATTN_HOPPER
)
from enum import Enum, auto
@@ -64,10 +65,15 @@ def fn(q,
raise ValueError(f"Unknown stage: {stage}")

elif impl_type == FlashAttentionImpl.TORCH:
if stage == "fwd-bwd" or stage == "fwd-only":
return torch_attn
if stage == "fwd-only":
return pytorch_attn_forward
elif stage == "bwd-only":
return pytorch_attn_backward
elif stage == "fwd-bwd":
from yunchang.ring.ring_pytorch_attn import pytorch_attn_func
return pytorch_attn_func
else:
raise ValueError(f"FlashAttentionImpl.TORCH: bwd-only is not supported")
raise ValueError(f"Unknown stage: {stage}")
else:
raise ValueError(f"Unknown flash attention implementation: {impl_type}")

128 changes: 108 additions & 20 deletions yunchang/kernels/attention.py
@@ -1,5 +1,7 @@
from typing import Optional, Tuple
from yunchang.globals import HAS_FLASH_ATTN, HAS_FLASH_ATTN_HOPPER

import math
import torch
if HAS_FLASH_ATTN:
import flash_attn
from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward
@@ -16,26 +18,112 @@

import torch.nn.functional as F

def torch_attn(q, k, v, dropout_p = 0.0,
softmax_scale = None,
causal=False,
*args, **kwargs):
batch_size, seq_len, hs, hd = q.size()
query = q.view(batch_size, -1, hs, hd).transpose(1, 2)
key = k.view(batch_size, -1, hs, hd).transpose(1, 2)
value = v.view(batch_size, -1, hs, hd).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=dropout_p, is_causal=causal
)

hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, hs, hd
)
hidden_states = hidden_states.to(query.dtype)
return hidden_states,
import torch
aten = torch.ops.aten


def pytorch_attn_forward(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
dropout_p=0.0,
softmax_scale=None,
causal=True,
window_size=(-1, -1),
softcap=None,
alibi_slopes=None,
return_softmax=False,
):
"""
q shape (bs, seqlen, nhead, hs)
k shape (bs, seqlen, nhead, hs)
v shape (bs, seqlen, nhead, hs)
"""
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

out, lse = aten._scaled_dot_product_efficient_attention(
q,
k,
v,
attn_bias=None,
compute_log_sumexp=True,
dropout_p=dropout_p,
is_causal=causal,
scale=softmax_scale,
)[:2]
out = out.transpose(1, 2)
lse = lse.to(q.dtype)
return out, lse

def pytorch_attn_backward(
dout,
q,
k,
v,
out,
softmax_lse,
block_dq_buffer=None, # Add new parameters with default values
block_dk_buffer=None,
block_dv_buffer=None,
dropout_p=0.0,
softmax_scale=None,
bwd_causal=None, # This will replace the original causal parameter
window_size=None,
softcap=None,
alibi_slopes=None,
deterministic=True,
rng_state=None,
*args,
**kwargs,
):
# TODO(optim): use pytorch _scaled_dot_product_efficient_attention_backward
# Use efficient attention backward
# https://github.com/pytorch/pytorch/blob/main/tools/autograd/derivatives.yaml#L2874

# preprocess to reuse the original code
# from https://github.com/huggingface/picotron/blob/main/picotron/context_parallel/context_parallel.py
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = out.transpose(1, 2)
dout = dout.transpose(1, 2)

batch_size, nheads, seqlen, d = q.shape
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)

# Recreate S and P from log_sum_exp
S = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
if bwd_causal:
causal_mask = torch.triu(torch.ones(seqlen, seqlen, device=q.device, dtype=torch.bool), diagonal=1)
S = S.masked_fill(causal_mask.unsqueeze(0).unsqueeze(1), float('-inf'))

P = torch.exp(S - softmax_lse.unsqueeze(-1))
# Step 1: Compute dV
dV = torch.matmul(P.transpose(-2, -1), dout)
# Step 2: Compute dP
dP = torch.matmul(dout, v.transpose(-2, -1))
# Step 3: Compute D
D = torch.sum(dout * out, dim=-1, keepdim=True)
# Step 4: Compute dS
dS = P * (dP - D)
# Apply causal mask to dS if is_causal is True
if bwd_causal:
dS = dS.masked_fill(causal_mask.unsqueeze(0).unsqueeze(1), 0)
# Step 5: Compute dQ
dQ = torch.matmul(dS, k) * softmax_scale
# Step 6: Compute dK
dK = torch.matmul(dS.transpose(-2, -1), q) * softmax_scale

# TODO: post-process to reuse the original code
dQ = dQ.transpose(1, 2)
dK = dK.transpose(1, 2)
dV = dV.transpose(1, 2)

return dQ, dK, dV
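
# Standalone sanity check, not part of this commit: verifies the softmax backward
# identity dS = P * (dP - D) used in Steps 2-4 above (ignoring softmax_scale and
# the causal mask for brevity), where D = sum(dout * out, dim=-1) equals the
# row-wise sum of dP * P.
import torch

torch.manual_seed(0)
S = torch.randn(2, 4, 8, 8, dtype=torch.float64, requires_grad=True)
V = torch.randn(2, 4, 8, 16, dtype=torch.float64)
dout = torch.randn(2, 4, 8, 16, dtype=torch.float64)

P = torch.softmax(S, dim=-1)
out = P @ V
out.backward(dout)                                # autograd reference for dL/dS

dP = dout @ V.transpose(-2, -1)                   # Step 2: Compute dP
D = torch.sum(dout * out, dim=-1, keepdim=True)   # Step 3: Compute D
dS_manual = P * (dP - D)                          # Step 4: Compute dS

assert torch.allclose(S.grad, dS_manual)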

def flash_attn_forward(q, k, v,
dropout_p = 0.0,
4 changes: 4 additions & 0 deletions yunchang/ring/__init__.py
@@ -23,3 +23,7 @@
stripe_flash_attn_kvpacked_func,
stripe_flash_attn_qkvpacked_func,
)

from .ring_pytorch_attn import (
ring_pytorch_attn_func,
)
12 changes: 0 additions & 12 deletions yunchang/ring/ring_flash_attn.py
@@ -32,18 +32,6 @@ def ring_flash_attn_forward(
comm.commit()

if not causal or step <= comm.rank:
# block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
# q,
# k,
# v,
# dropout_p,
# softmax_scale,
# causal=causal and step == 0,
# window_size=window_size,
# softcap=softcap,
# alibi_slopes=alibi_slopes,
# return_softmax=True and dropout_p > 0,
# )
fn = select_flash_attn_impl(attn_type, stage="fwd-only")
block_out, block_lse = fn(
q,
