Skip to content

Commit

Permalink
[OPS] Add more perf-tests, new features to FA (triton-lang#1849)
Browse files Browse the repository at this point in the history
Adding new tests across the board for float32, bfloat16, non-powers-of-2
shapes (to test masks), and tests on sequence parallel for atomics. This
also adds the sequence parallel features from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py.
I am not sure about the best way to grab the baseline benchmarking
numbers. I have access to V100s and A100s, but I saw on the tests it
mentions " # A100 in the CI server is slow-ish for some reason.
# On some other servers, we are getting about 90% peak for 8kx8x8k
float16". Current plan is to run CI here and use those numbers for
baseline, then match against my GPUs as a sanity check.

---------

Co-authored-by: Phil Tillet <[email protected]>
  • Loading branch information
IzzyPutterman and ptillet authored Jul 11, 2023
1 parent 73e18e9 commit d39d78f
Show file tree
Hide file tree
Showing 4 changed files with 379 additions and 182 deletions.
134 changes: 94 additions & 40 deletions python/test/regression/test_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,37 +55,52 @@ def nvsmi(attrs):
(1024, 64, 1024): {'float16': 0.0692},
(4096, 64, 4096): {'float16': 0.264},
(8192, 64, 8192): {'float16': 0.452},
# Non pow 2 shapes
(1000, 200, 100): {'float16': 0.084},
(1000, 200, 700): {'float16': 0.084},
(994, 136, 402): {'float16': 0.084},
(995, 135, 409): {'float16': 0.084},
(99, 1357, 409): {'float16': 0.084},
},
# NOTE:
# A100 in the CI server is slow-ish for some reason.
# On some other servers, we are getting about 90% peak for 8kx8x8k float16
'a100': {
(512, 512, 512): {'float16': 0.084, 'float32': 0.13, 'int8': 0.05},
(1024, 1024, 1024): {'float16': 0.332, 'float32': 0.35, 'int8': 0.169},
(2048, 2048, 2048): {'float16': 0.641, 'float32': 0.57, 'int8': 0.34},
(4096, 4096, 4096): {'float16': 0.785, 'float32': 0.75, 'int8': 0.46},
(8192, 8192, 8192): {'float16': 0.805, 'float32': 0.85, 'int8': 0.51},
# square
(512, 512, 512): {'float16': 0.084, 'float32': 0.12, 'int8': 0.05},
(1024, 1024, 1024): {'float16': 0.332, 'float32': 0.352, 'int8': 0.169},
(2048, 2048, 2048): {'float16': 0.635, 'float32': 0.522, 'int8': 0.34},
(4096, 4096, 4096): {'float16': 0.750, 'float32': 0.810, 'int8': 0.46},
(8192, 8192, 8192): {'float16': 0.760, 'float32': 0.760, 'int8': 0.51},
# tall-skinny
(16, 1024, 1024): {'float16': 0.0077, 'float32': 0.0127, 'int8': 0.005},
(16, 4096, 4096): {'float16': 0.044, 'float32': 0.0457, 'int8': 0.0259},
(16, 8192, 8192): {'float16': 0.07, 'float32': 0.0648, 'int8': 0.0431},
(64, 1024, 1024): {'float16': 0.028, 'float32': 0.0509, 'int8': 0.0169},
(64, 4096, 4096): {'float16': 0.163, 'float32': 0.162, 'int8': 0.097},
(64, 8192, 8192): {'float16': 0.285, 'float32': 0.257, 'int8': 0.174},
(1024, 64, 1024): {'float16': 0.033, 'float32': 0.0458, 'int8': 0.017},
(4096, 64, 4096): {'float16': 0.16, 'float32': 0.177, 'int8': 0.102},
(8192, 64, 8192): {'float16': 0.254, 'float32': 0.230, 'int8': 0.177},
(16, 1024, 1024): {'float16': 0.008, 'float32': 0.009, 'int8': 0.005},
(16, 4096, 4096): {'float16': 0.036, 'float32': 0.038, 'int8': 0.026},
(16, 8192, 8192): {'float16': 0.056, 'float32': 0.061, 'int8': 0.043},
(64, 1024, 1024): {'float16': 0.020, 'float32': 0.030, 'int8': 0.017},
(64, 4096, 4096): {'float16': 0.160, 'float32': 0.162, 'int8': 0.097},
(64, 8192, 8192): {'float16': 0.280, 'float32': 0.257, 'int8': 0.174},
(1024, 64, 1024): {'float16': 0.040, 'float32': 0.050, 'int8': 0.017},
(4096, 64, 4096): {'float16': 0.160, 'float32': 0.200, 'int8': 0.102},
(8192, 64, 8192): {'float16': 0.250, 'float32': 0.23, 'int8': 0.177},
# Non pow 2 shapes
(1000, 200, 100): {'float16': 0.011, 'float32': 0.017, 'int8': 0.05},
(1000, 200, 700): {'float16': 0.027, 'float32': 0.047, 'int8': 0.05},
(994, 136, 402): {'float16': 0.015, 'float32': 0.024, 'int8': 0.05},
(995, 135, 409): {'float16': 0.015, 'float32': 0.025, 'int8': 0.05},
(99, 1357, 409): {'float16': 0.011, 'float32': 0.036, 'int8': 0.05}
}
}


@pytest.mark.parametrize('M, N, K, dtype_str',
[(M, N, K, dtype_str)
for M, N, K in matmul_data[DEVICE_NAME].keys()
for dtype_str in ['float16']])
for dtype_str in ['float16', 'float32']])
def test_matmul(M, N, K, dtype_str):
if dtype_str in ['float32', 'int8'] and DEVICE_NAME != 'a100':
pytest.skip('Only test float32 & int8 on a100')
if (M, N, K) in [(64, 4096, 4096), (64, 8192, 8192), (8192, 64, 8192)] and dtype_str == 'float32':
pytest.skip('Out of shared memory in float32')
dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str]
torch.manual_seed(0)
ref_gpu_util = matmul_data[DEVICE_NAME][(M, N, K)][dtype_str]
Expand Down Expand Up @@ -126,32 +141,44 @@ def _add(x_ptr, y_ptr, output_ptr, n_elements,

elementwise_data = {
'v100': {
1024 * 16: 0.0219,
1024 * 64: 0.0791,
1024 * 256: 0.243,
1024 * 1024: 0.530,
1024 * 4096: 0.796,
1024 * 16384: 0.905,
1024 * 65536: 0.939,
1024 * 16: {'float16': 0.0219, 'float32': 0.010},
1024 * 64: {'float16': 0.0791, 'float32': 0.010},
1024 * 256: {'float16': 0.243, 'float32': 0.010},
1024 * 1024: {'float16': 0.530, 'float32': 0.010},
1024 * 4096: {'float16': 0.796, 'float32': 0.010},
1024 * 16384: {'float16': 0.905, 'float32': 0.010},
1024 * 65536: {'float16': 0.939, 'float32': 0.010},
# Non pow 2
1020 * 100: {'float16': 0.010, 'float32': 0.010},
995 * 125: {'float16': 0.010, 'float32': 0.010},
10003 * 7007: {'float16': 0.010, 'float32': 0.010},
},
'a100': {
1024 * 16: 0.010,
1024 * 64: 0.040,
1024 * 256: 0.132,
1024 * 1024: 0.353,
1024 * 4096: 0.605,
1024 * 16384: 0.758,
1024 * 65536: 0.850,
1024 * 16: {'float16': 0.010, 'bfloat16': 0.010, 'float32': 0.020},
1024 * 64: {'float16': 0.040, 'bfloat16': 0.040, 'float32': 0.066},
1024 * 256: {'float16': 0.132, 'bfloat16': 0.132, 'float32': 0.227},
1024 * 1024: {'float16': 0.353, 'bfloat16': 0.353, 'float32': 0.488},
1024 * 4096: {'float16': 0.605, 'bfloat16': 0.605, 'float32': 0.705},
1024 * 16384: {'float16': 0.758, 'bfloat16': 0.750, 'float32': 0.819},
1024 * 65536: {'float16': 0.850, 'bfloat16': 0.850, 'float32': 0.870},
# Non pow 2
1020 * 100: {'float16': 0.051, 'bfloat16': 0.051, 'float32': 0.103},
995 * 125: {'float16': 0.063, 'bfloat16': 0.063, 'float32': 0.126},
10003 * 7007: {'float16': 0.544, 'bfloat16': 0.541, 'float32': 0.861},
}
}


@pytest.mark.parametrize('N', elementwise_data[DEVICE_NAME].keys())
def test_elementwise(N):
@pytest.mark.parametrize("dtype_str", ['float16', 'bfloat16', 'float32'])
def test_elementwise(N, dtype_str):
torch.manual_seed(0)
ref_gpu_util = elementwise_data[DEVICE_NAME][N]
if dtype_str in ['bfloat16'] and DEVICE_NAME != 'a100':
pytest.skip('Only test bfloat16 on a100')
dtype = {'float16': torch.float16, 'bfloat16': torch.bfloat16, 'float32': torch.float32}[dtype_str]
ref_gpu_util = elementwise_data[DEVICE_NAME][N][dtype_str]
max_gpu_perf = get_dram_gbps()
z = torch.empty((N, ), dtype=torch.float16, device='cuda')
z = torch.empty((N, ), dtype=dtype, device='cuda')
x = torch.randn_like(z)
y = torch.randn_like(z)
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
Expand All @@ -169,29 +196,56 @@ def test_elementwise(N):

flash_attention_data = {
"a100": {
(4, 48, 4096, 64, 'forward', 'float16'): 0.37,
(4, 48, 4096, 64, 'backward', 'float16'): 0.25,
(4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.420,
(4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.202,
(4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.355,
(4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.201,
(4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.099,
(4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.087,
(4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.238,
(4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.135,
(4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.211,
(4, 48, 4096, 64, True, False, 'backward', 'bfloat16'): 0.135,
(4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.062,
(4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.052,
(4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.424,
(4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.262,
(4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.370,
(4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.254,
(4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.099,
(4, 48, 1024, 16, False, True, 'backward', 'float32'): 0.125,
(4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.238,
(4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.158,
(4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.211,
(4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.134,
(4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.062,
(4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.075,
}
}


@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [[4, 48, 4096, 64]])
@pytest.mark.parametrize("dtype_str", ['float16', 'bfloat16', 'float32'])
@pytest.mark.parametrize("mode", ['forward', 'backward'])
@pytest.mark.parametrize("dtype_str", ['float16'])
def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str):
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("seq_par", [True, False])
@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [[4, 48, 4096, 64]])
def test_flash_attention(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str):
is_backward = mode == 'backward'
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
pytest.skip("Flash attention only supported for compute capability < 80")
torch.manual_seed(20)
dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str]
dtype = {'float16': torch.float16, 'bfloat16': torch.bfloat16, 'float32': torch.float32}[dtype_str]
# init data
if dtype_str == 'float32':
N_CTX = 1024
D_HEAD = 16
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
sm_scale = 0.2
# benchmark
fn = lambda: triton.ops.attention(q, k, v, sm_scale)
fn = lambda: triton.ops.attention(q, k, v, causal, sm_scale, seq_par)
if is_backward:
o = fn()
do = torch.randn_like(o)
Expand All @@ -207,6 +261,6 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str):
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
cur_gpu_util = cur_gpu_perf / max_gpu_perf
ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, mode, dtype_str)]
ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str)]
print_perf(ms, cur_gpu_util, ref_gpu_util)
triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
19 changes: 10 additions & 9 deletions python/test/unit/operators/test_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,23 @@
(4, 48, 1024, 64),
(4, 48, 1024, 128)])
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
def test_op(Z, H, N_CTX, D_HEAD, dtype):
@pytest.mark.parametrize('causal', [True, False])
@pytest.mark.parametrize('seq_par', [True, False])
def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
pytest.skip("Flash attention only supported for compute capability < 80")
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
sm_scale = 0.2
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
sm_scale = 0.5
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
for z in range(Z):
for h in range(H):
p[:, :, M == 0] = float("-inf")
if causal:
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).to(dtype)
# p = torch.exp(p)
ref_out = torch.matmul(p, v)
Expand All @@ -34,7 +35,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype):
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# # triton implementation
tri_out = triton.ops.attention(q, k, v, sm_scale)
tri_out = triton.ops.attention(q, k, v, causal, sm_scale, seq_par)
# print(ref_out)
# print(tri_out)
tri_out.backward(dout)
Expand Down
Loading

0 comments on commit d39d78f

Please sign in to comment.