-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy pathtest_hybrid_qkvpacked_attn.py
183 lines (148 loc) · 4.73 KB
/
test_hybrid_qkvpacked_attn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import torch
import torch.distributed as dist
from yunchang import (
LongContextAttentionQKVPacked,
set_seq_parallel_pg,
EXTRACT_FUNC_DICT,
RING_IMPL_QKVPACKED_DICT
)
from yunchang.kernels import AttnType
def log(msg, a, rank0_only=False):
    """Print the max and mean absolute value of tensor ``a`` labelled ``msg``.

    With ``rank0_only=True`` only rank 0 prints a single ``"<msg>: max .., mean .."``
    line and no barrier is issued; other ranks return without touching ``a``.
    Otherwise every rank prints ``"[<rank>] max .., mean .."`` in rank order
    (rank 0 first emits a ``"<msg>:"`` header). A ``dist.barrier()`` after each
    turn serializes the output across ranks.

    Args:
        msg: label printed before the statistics.
        a: tensor whose absolute values are summarized.
        rank0_only: restrict printing to rank 0.
    """

    def _stats(tensor):
        # Formats "max <v>, mean <v>" of |tensor|; only called by a rank that prints.
        magnitude = tensor.abs()
        return f"max {magnitude.max().item()}, " f"mean {magnitude.mean().item()}"

    world_size = dist.get_world_size()
    rank = dist.get_rank()

    if rank0_only:
        if rank == 0:
            print(f"{msg}: " + _stats(a), flush=True)
        return

    for turn in range(world_size):
        if turn == rank:
            if rank == 0:
                print(f"{msg}:")
            print(f"[{rank}] " + _stats(a), flush=True)
        # Every rank waits here each turn so that ranks print one at a time.
        dist.barrier()
import os
def get_local_rank():
    """Return the ``LOCAL_RANK`` environment variable as an int, or 0 when it is unset.

    The caller uses this value as the CUDA device index.
    """
    return int(os.environ.get("LOCAL_RANK", "0"))
def test(ring_impl_type="zigzag", sp_ulysses_degree=2):
    """Compare LongContextAttentionQKVPacked against a full-sequence flash-attn reference.

    Every rank builds the same global ``qkv``/``dout`` (broadcast from rank 0).
    Each rank then:

    * runs ``LongContextAttentionQKVPacked`` on its own shard;
    * runs ``flash_attn_qkvpacked_func`` on the whole, unsharded ``qkv`` as a reference;
    * logs max/mean absolute values of the forward outputs and of their difference;
    * does the same for the ``qkv`` gradients after backward.

    Nothing is asserted about accuracy; the printed statistics are the result.

    Args:
        ring_impl_type: key into ``EXTRACT_FUNC_DICT``, also passed as the
            ring implementation to ``LongContextAttentionQKVPacked``
            (the ``__main__`` block uses ``"basic"`` and ``"zigzag"``).
        sp_ulysses_degree: Ulysses degree passed to ``set_seq_parallel_pg``.
            The ring degree is ``world_size // sp_ulysses_degree``. Defaults to
            2, the value previously hard-coded.
    """
    rank = dist.get_rank()
    local_rank = get_local_rank()
    world_size = dist.get_world_size()
    dtype = torch.bfloat16
    device = torch.device(f"cuda:{local_rank}")
    # Bind this process to its GPU before any NCCL collective or barrier runs,
    # so NCCL does not have to guess which device this rank uses.
    torch.cuda.set_device(device)
    print(f"rank {rank} local_rank {local_rank} world_size {world_size}")

    batch_size = 2
    seqlen = 1024
    nheads = 8
    d = 32
    dropout_p = 0.0
    causal = True
    deterministic = False

    assert seqlen % world_size == 0
    assert d % 8 == 0
    # Without this check, world_size=1 would yield a ring degree of 0 and odd
    # world sizes a group that does not cover every rank.
    assert world_size % sp_ulysses_degree == 0, (
        f"world_size {world_size} is not divisible by "
        f"sp_ulysses_degree {sp_ulysses_degree}"
    )
    sp_ring_degree = world_size // sp_ulysses_degree
    set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)

    longctx_attn = LongContextAttentionQKVPacked(
        ring_impl_type=ring_impl_type, attn_type=AttnType.FA
    )

    extract_local = EXTRACT_FUNC_DICT[ring_impl_type]

    def shard(t):
        # This rank's slice of a global tensor, via the extraction function
        # registered for `ring_impl_type`.
        return extract_local(
            t, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
        )

    # Global tensors, made identical on all ranks by broadcasting from rank 0.
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)
    with torch.no_grad():
        dist.broadcast(qkv, src=0)
        dist.broadcast(dout, src=0)

    # Sharded tensors for the long-context attention; detach+clone makes
    # local_qkv an independent leaf so its .grad is populated by backward.
    local_qkv = shard(qkv).detach().clone()
    local_qkv.requires_grad = True
    local_dout = shard(dout).detach().clone()

    dist.barrier()
    if rank == 0:
        print("#" * 30)
        print("# forward:")
        print("#" * 30)
    print(f"local_qkv shape {local_qkv.shape}")

    local_out = longctx_attn(
        local_qkv,
        dropout_p=dropout_p,
        causal=causal,
        window_size=(-1, -1),
        softcap=0.0,
        alibi_slopes=None,
        deterministic=deterministic,
        return_attn_probs=True,
    )

    from flash_attn import flash_attn_qkvpacked_func

    # Reference: plain flash attention over the full, unsharded sequence.
    # Only the output is compared; the other returned values are discarded.
    out, _, _ = flash_attn_qkvpacked_func(
        qkv,
        dropout_p=dropout_p,
        causal=causal,
        window_size=(-1, -1),
        softcap=0.0,
        alibi_slopes=None,
        deterministic=deterministic,
        return_attn_probs=True,
    )
    local_out_ref = shard(out)

    log("out_ref", local_out_ref, rank0_only=True)
    log("out", local_out, rank0_only=True)
    log("out diff", local_out - local_out_ref)
    dist.barrier()

    if rank == 0:
        print("#" * 30)
        print("# backward:")
        print("#" * 30)

    # Long-context attention backward on the local shard.
    local_out.backward(local_dout)
    local_dqkv = local_qkv.grad

    # Reference backward on the full tensor; shard its gradient for comparison.
    out.backward(dout)
    dqkv = qkv.grad
    local_dqkv_ref = shard(dqkv)

    log("dq_ref", local_dqkv_ref)
    log("dq diff", local_dqkv - local_dqkv_ref)
if __name__ == "__main__":
    # Initializes the default process group with the NCCL backend
    # (presumably launched via torchrun or a similar multi-process launcher).
    dist.init_process_group("nccl")
    impl_types = ("basic", "zigzag")
    for impl in impl_types:
        print(f"ring_impl_type: {impl}")
        test(impl)
    if dist.is_initialized():
        dist.destroy_process_group()