forked from pytorch-labs/attention-gym
-
Notifications
You must be signed in to change notification settings - Fork 0
/
benchmark.py
256 lines (210 loc) · 8.1 KB
/
benchmark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
from functools import lru_cache
from typing import Optional, List
import torch
import torch.nn.functional as F
from tabulate import tabulate
from torch.nn.attention.flex_attention import (
_DEFAULT_SPARSE_BLOCK_SIZE,
create_block_mask,
create_mask,
flex_attention,
_score_mod_signature,
_mask_mod_signature,
)
from triton.testing import do_bench
from attn_gym.masks.document_mask import length_to_offsets
from attn_gym.masks import (
causal_mask,
generate_sliding_window,
generate_prefix_lm_mask,
generate_doc_mask_mod,
)
from attn_gym.mods import generate_alibi_bias, generate_tanh_softcap
# Create tensors on the GPU by default and seed torch's RNG so the random
# benchmark inputs are the same on every run.
torch.set_default_device("cuda")
torch.manual_seed(0)
# Raise TorchDynamo's cache size limit.
# NOTE(review): presumably so the many distinct score_mod / mask_mod callables
# benchmarked by this script do not hit the default recompilation limit - confirm.
torch._dynamo.config.cache_size_limit = 1000
# Compile the flex_attention function (rebinds the imported name to the compiled version)
flex_attention = torch.compile(flex_attention, dynamic=False)
# For better performance, you can instead compile with:
# flex_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs")
# dtype of the random query/key/value tensors created by the benchmark
data_type = torch.float16
# The kernels will utilize block sparsity to increase performance
print(f"Using the default sparsity block size: {_DEFAULT_SPARSE_BLOCK_SIZE}")
@lru_cache
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda"):
block_mask = create_block_mask(score_mod, B, H, M, N, device=device)
return block_mask
def calculate_tflops(flops: float, time_ms: float, multiplier: int) -> float:
    """Convert an operation count and a runtime into TFLOP/s.

    Args:
        flops: Base operation count for one call.
        time_ms: Measured runtime of one call, in milliseconds.
        multiplier: Factor applied to ``flops`` before converting.

    Returns:
        ``multiplier * flops`` per second, divided by 1e12.

    Raises:
        ZeroDivisionError: If ``time_ms`` is zero.
    """
    scaled_flops = multiplier * flops
    calls_per_second = 1e3 / time_ms  # milliseconds -> calls per second
    return scaled_flops * calls_per_second / 1e12
def print_header(text):
    """Print ``text`` centered inside a three-line box drawn with box characters.

    The box is 91 characters wide: a top border, one line holding the text
    centered in the 87 columns between the padded side bars, and a bottom
    border.

    Args:
        text: String to display; ``str.center`` leaves it unchanged if it is
            longer than the available width.
    """
    width = 91
    rule = "═" * (width - 2)
    centered = text.center(width - 4)
    print(f"╔{rule}╗")
    print("║ " + centered + " ║")
    print(f"╚{rule}╝")
def test_mask(
    score_mod: Optional[_score_mod_signature] = None,
    mask_mod: Optional[_mask_mod_signature] = None,
    B: int = 16,
    H: int = 16,
    S: int = 8192,
    D: int = 64,
    skip_correctness: bool = False,
    print_mask: bool = True,
    device: str = "cuda",
):
    """Benchmark FlexAttention against two SDPA baselines and print a results table.

    Three attention calls are timed with ``do_bench``, forward and backward:
    ``F.scaled_dot_product_attention(is_causal=True)``, the same function with
    a materialized ``attn_mask``, and the compiled ``flex_attention`` called
    with ``score_mod`` and the block mask built from ``mask_mod``.

    Args:
        score_mod: Score modification passed to ``flex_attention``. When no
            ``mask_mod`` is given it is also used to build the dense SDPA mask.
        mask_mod: Mask modification used to build both the block mask for
            ``flex_attention`` and the dense mask for the SDPA baseline.
        B, H, S, D: Sizes of the random q/k/v tensors, created with shape
            ``(B, H, S, D)``. Masks are built with sizes ``(1, 1, S, S)``.
        skip_correctness: If true, skip comparing FlexAttention's output and
            input gradients against the masked SDPA baseline.
        print_mask: If true, print the block mask after the table (only when a
            block mask exists, i.e. ``mask_mod`` was given).
        device: Device for the masks and the random tensors.

    Raises:
        AssertionError: If neither ``score_mod`` nor ``mask_mod`` is given, or
            if the correctness comparison fails.
    """
    assert score_mod is not None or mask_mod is not None, "Must provide a score_mod or mask_mod"

    if mask_mod is not None:
        block_mask = create_block_mask_cached(mask_mod, 1, 1, S, S, device=device)
    else:
        block_mask = None
    sdpa_mask_fn = mask_mod if mask_mod is not None else score_mod
    mask = create_mask(sdpa_mask_fn, 1, 1, S, S, device=device)

    qkv = [
        torch.randn(B, H, S, D, device=device, dtype=data_type, requires_grad=True)
        for _ in range(3)
    ]
    # Upstream gradient for the backward passes; same dtype as q/k/v.
    grad_out = torch.randn(B, H, S, D, device=device, dtype=data_type)

    def run_causal_fa2():
        return F.scaled_dot_product_attention(*qkv, is_causal=True)

    def run_sdpa_mask():
        return F.scaled_dot_product_attention(*qkv, attn_mask=mask)

    def run_flex():
        return flex_attention(*qkv, score_mod=score_mod, block_mask=block_mask)

    # block_mask.sparsity() is treated as a percentage; density is the
    # complementary fraction in [0, 1].
    if block_mask is not None:
        density = (100 - block_mask.sparsity()) / 100
    else:
        density = 1.0
    # The causal baseline is counted at half of the full S x S work.
    causal_fav2_flops = 0.5 * B * H * D * S * S
    flops = density * B * H * D * S * S

    # Forward pass
    causal_fa2_time = do_bench(run_causal_fa2)
    sdpa_mask_time = do_bench(run_sdpa_mask)
    flex_ms = do_bench(run_flex)

    # Backward pass: one forward each, then time repeated backward calls.
    # retain_graph=True keeps the graph alive across do_bench's repetitions.
    causal_fa2_out = run_causal_fa2()
    sdpa_mask_out = run_sdpa_mask()
    flex_out = run_flex()

    causal_fa2_bw_time = do_bench(lambda: causal_fa2_out.backward(grad_out, retain_graph=True))
    sdpa_mask_bw_time = do_bench(lambda: sdpa_mask_out.backward(grad_out, retain_graph=True))
    flex_bw_ms = do_bench(lambda: flex_out.backward(grad_out, retain_graph=True))

    # Header is derived from the callable's name; fall back to its type name
    # for callables that have no __name__ (e.g. functools.partial objects).
    mod = score_mod if score_mod is not None else mask_mod
    mod_name = getattr(mod, "__name__", type(mod).__name__)
    print_header(mod_name.replace("_", " ").title())

    # Inline correctness check
    if not skip_correctness:

        def forward_backward(attention_fn):
            """Return [output, dq, dk, dv] from a fresh forward/backward run."""
            # Clear gradients accumulated by earlier backward calls.
            for tensor in qkv:
                tensor.grad = None
            out = attention_fn()
            out.backward(grad_out)
            return [out] + [tensor.grad for tensor in qkv]

        sdpa_mask_outs = forward_backward(run_sdpa_mask)
        flex_outs = forward_backward(run_flex)
        for flex_tensor, sdpa_tensor in zip(flex_outs, sdpa_mask_outs):
            torch.testing.assert_close(flex_tensor, sdpa_tensor, atol=1e-1, rtol=1e-2)

        print("Correctness check passed ✅")

    # Forward rows use a multiplier of 4 and backward rows a multiplier of 10
    # on the base FLOP count.
    results = [
        [
            "causal FA2",
            f"{causal_fa2_time:.4f}",
            f"{calculate_tflops(causal_fav2_flops, causal_fa2_time, 4):.2f}",
            f"{causal_fa2_bw_time:.4f}",
            f"{calculate_tflops(causal_fav2_flops, causal_fa2_bw_time, 10):.2f}",
        ],
        [
            "F.sdpa + mask",
            f"{sdpa_mask_time:.4f}",
            f"{calculate_tflops(flops, sdpa_mask_time, 4):.2f}",
            f"{sdpa_mask_bw_time:.4f}",
            f"{calculate_tflops(flops, sdpa_mask_bw_time, 10):.2f}",
        ],
        [
            "flexattention",
            f"{flex_ms:.4f}",
            f"{calculate_tflops(flops, flex_ms, 4):.2f}",
            f"{flex_bw_ms:.4f}",
            f"{calculate_tflops(flops, flex_bw_ms, 10):.2f}",
        ],
    ]
    print(
        tabulate(
            results,
            headers=[
                "Operation",
                "FW Time (ms)",
                "FW FLOPS (TF/s)",
                "BW Time (ms)",
                "BW FLOPS (TF/s)",
            ],
            tablefmt="grid",
        )
    )
    # Nothing useful to show when only a score_mod was given (block_mask is None).
    if print_mask and block_mask is not None:
        print(f"\nBlock Mask:\n{block_mask}")
def run_document_masking(max_seq_len: int, num_docs: int):
    """Benchmark a causal document mask over randomly sized documents.

    ``max_seq_len`` tokens are split into ``num_docs`` documents of random
    length (each at least one token, fixed seed), the lengths are converted to
    offsets, and the resulting document-causal mask is benchmarked with
    ``test_mask`` at sequence length ``max_seq_len``.

    Args:
        max_seq_len: Total number of tokens; also the benchmarked sequence length.
        num_docs: Number of documents to split the sequence into.

    Raises:
        ValueError: If ``num_docs`` is not between 1 and ``max_seq_len``.
    """
    import random

    # Each document needs at least one token, so there can be no more
    # documents than tokens (and there must be at least one document).
    if not 1 <= num_docs <= max_seq_len:
        raise ValueError(
            f"num_docs must be between 1 and max_seq_len ({max_seq_len}), got {num_docs}"
        )

    random.seed(0)

    def generate_random_lengths(total_length, num_documents):
        # Initialize all lengths to 1 to ensure each document has at least one token
        lengths = [1] * num_documents
        remaining_length = total_length - num_documents
        # Randomly distribute the remaining length
        for _ in range(remaining_length):
            index = random.randint(0, num_documents - 1)
            lengths[index] += 1
        return lengths

    lengths = generate_random_lengths(max_seq_len, num_docs)
    offsets = length_to_offsets(lengths, "cuda")
    document_causal_mask = generate_doc_mask_mod(causal_mask, offsets)
    # Benchmark at the same sequence length the document lengths were generated for.
    test_mask(mask_mod=document_causal_mask, S=max_seq_len)
def main(examples: Optional[List[str]] = None):
    """Run the benchmark with the given examples.

    Args:
        examples: List of examples to run. If "all" is specified, all examples
            will be run. ``None`` (the default) is treated as ``["all"]``.
            Unknown keys are reported with a warning and skipped.
    """
    # None sentinel instead of a mutable list default shared across calls.
    if examples is None:
        examples = ["all"]

    # Each entry is a zero-argument callable, so nothing runs until selected.
    available_examples = {
        "causal": lambda: test_mask(mask_mod=causal_mask),
        "alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=True),
        "sliding_window": lambda: test_mask(mask_mod=generate_sliding_window(window_size=1024)),
        "prefix_lm": lambda: test_mask(mask_mod=generate_prefix_lm_mask(prefix_length=1024)),
        "document": lambda: run_document_masking(max_seq_len=32768, num_docs=12),
        "softcap": lambda: test_mask(
            score_mod=generate_tanh_softcap(30, approx=False), skip_correctness=True
        ),
        "softcap_approx": lambda: test_mask(
            score_mod=generate_tanh_softcap(30, approx=True), skip_correctness=True
        ),
    }

    if "all" in examples:
        ex_to_run = list(available_examples)
    else:
        ex_to_run = examples

    for ex in ex_to_run:
        if ex in available_examples:
            available_examples[ex]()
        else:
            print(f"Warning: Unknown example key '{ex}'. Skipping.")
if __name__ == "__main__":
    # jsonargparse is only needed for the command-line entry point, so it is
    # imported here; the error message tells the user which install command to run.
    try:
        from jsonargparse import ArgumentParser
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError("Be sure to run: pip install -e .'[viz]'") from err

    parser = ArgumentParser(description="Run specific examples or all examples.")
    parser.add_argument(
        "--examples",
        type=str,
        nargs="+",
        default=["all"],
        help="List of examples to run. Use space to separate multiple examples. "
        "Available options: causal, alibi, sliding_window, prefix_lm, "
        "document, softcap, softcap_approx, or 'all' to run all examples.",
    )

    args = parser.parse_args()
    # The parsed namespace's attributes are forwarded as keyword arguments to main().
    main(**vars(args))