From d39d78fa08769ccef8a1a04deec3c37a5f89d524 Mon Sep 17 00:00:00 2001
From: Izzy Putterman
Date: Mon, 10 Jul 2023 18:52:59 -0700
Subject: [PATCH] [OPS] Add more perf-tests, new features to FA (#1849)

This adds new tests across the board for float32, bfloat16, and
non-power-of-2 shapes (to exercise the masking), plus tests of the
sequence-parallel backward pass.

It also ports the sequence-parallel features from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py.

I am not sure of the best way to obtain baseline benchmarking numbers. I have
access to V100s and A100s, but the tests note: "# A100 in the CI server is
slow-ish for some reason. # On some other servers, we are getting about 90%
peak for 8kx8x8k float16". The current plan is to run CI here and use those
numbers as the baseline, then compare against my GPUs as a sanity check.

---------

Co-authored-by: Phil Tillet
---
 python/test/regression/test_performance.py |  134 ++++--
 .../unit/operators/test_flash_attention.py |   19 +-
 python/triton/ops/flash_attention.py       |  404 ++++++++++++------
 python/tutorials/06-fused-attention.py     |    4 +-
 4 files changed, 379 insertions(+), 182 deletions(-)

diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py
index 4d5e4f5df5c0..3aa537b3ec8e 100644
--- a/python/test/regression/test_performance.py
+++ b/python/test/regression/test_performance.py
@@ -55,26 +55,39 @@ def nvsmi(attrs):
         (1024, 64, 1024): {'float16': 0.0692},
         (4096, 64, 4096): {'float16': 0.264},
         (8192, 64, 8192): {'float16': 0.452},
+        # Non pow 2 shapes
+        (1000, 200, 100): {'float16': 0.084},
+        (1000, 200, 700): {'float16': 0.084},
+        (994, 136, 402): {'float16': 0.084},
+        (995, 135, 409): {'float16': 0.084},
+        (99, 1357, 409): {'float16': 0.084},
     },
     # NOTE:
     # A100 in the CI server is slow-ish for some reason.
# On some other servers, we are getting about 90% peak for 8kx8x8k float16 'a100': { - (512, 512, 512): {'float16': 0.084, 'float32': 0.13, 'int8': 0.05}, - (1024, 1024, 1024): {'float16': 0.332, 'float32': 0.35, 'int8': 0.169}, - (2048, 2048, 2048): {'float16': 0.641, 'float32': 0.57, 'int8': 0.34}, - (4096, 4096, 4096): {'float16': 0.785, 'float32': 0.75, 'int8': 0.46}, - (8192, 8192, 8192): {'float16': 0.805, 'float32': 0.85, 'int8': 0.51}, + # square + (512, 512, 512): {'float16': 0.084, 'float32': 0.12, 'int8': 0.05}, + (1024, 1024, 1024): {'float16': 0.332, 'float32': 0.352, 'int8': 0.169}, + (2048, 2048, 2048): {'float16': 0.635, 'float32': 0.522, 'int8': 0.34}, + (4096, 4096, 4096): {'float16': 0.750, 'float32': 0.810, 'int8': 0.46}, + (8192, 8192, 8192): {'float16': 0.760, 'float32': 0.760, 'int8': 0.51}, # tall-skinny - (16, 1024, 1024): {'float16': 0.0077, 'float32': 0.0127, 'int8': 0.005}, - (16, 4096, 4096): {'float16': 0.044, 'float32': 0.0457, 'int8': 0.0259}, - (16, 8192, 8192): {'float16': 0.07, 'float32': 0.0648, 'int8': 0.0431}, - (64, 1024, 1024): {'float16': 0.028, 'float32': 0.0509, 'int8': 0.0169}, - (64, 4096, 4096): {'float16': 0.163, 'float32': 0.162, 'int8': 0.097}, - (64, 8192, 8192): {'float16': 0.285, 'float32': 0.257, 'int8': 0.174}, - (1024, 64, 1024): {'float16': 0.033, 'float32': 0.0458, 'int8': 0.017}, - (4096, 64, 4096): {'float16': 0.16, 'float32': 0.177, 'int8': 0.102}, - (8192, 64, 8192): {'float16': 0.254, 'float32': 0.230, 'int8': 0.177}, + (16, 1024, 1024): {'float16': 0.008, 'float32': 0.009, 'int8': 0.005}, + (16, 4096, 4096): {'float16': 0.036, 'float32': 0.038, 'int8': 0.026}, + (16, 8192, 8192): {'float16': 0.056, 'float32': 0.061, 'int8': 0.043}, + (64, 1024, 1024): {'float16': 0.020, 'float32': 0.030, 'int8': 0.017}, + (64, 4096, 4096): {'float16': 0.160, 'float32': 0.162, 'int8': 0.097}, + (64, 8192, 8192): {'float16': 0.280, 'float32': 0.257, 'int8': 0.174}, + (1024, 64, 1024): {'float16': 0.040, 'float32': 0.050, 'int8': 0.017}, + (4096, 64, 4096): {'float16': 0.160, 'float32': 0.200, 'int8': 0.102}, + (8192, 64, 8192): {'float16': 0.250, 'float32': 0.23, 'int8': 0.177}, + # Non pow 2 shapes + (1000, 200, 100): {'float16': 0.011, 'float32': 0.017, 'int8': 0.05}, + (1000, 200, 700): {'float16': 0.027, 'float32': 0.047, 'int8': 0.05}, + (994, 136, 402): {'float16': 0.015, 'float32': 0.024, 'int8': 0.05}, + (995, 135, 409): {'float16': 0.015, 'float32': 0.025, 'int8': 0.05}, + (99, 1357, 409): {'float16': 0.011, 'float32': 0.036, 'int8': 0.05} } } @@ -82,10 +95,12 @@ def nvsmi(attrs): @pytest.mark.parametrize('M, N, K, dtype_str', [(M, N, K, dtype_str) for M, N, K in matmul_data[DEVICE_NAME].keys() - for dtype_str in ['float16']]) + for dtype_str in ['float16', 'float32']]) def test_matmul(M, N, K, dtype_str): if dtype_str in ['float32', 'int8'] and DEVICE_NAME != 'a100': pytest.skip('Only test float32 & int8 on a100') + if (M, N, K) in [(64, 4096, 4096), (64, 8192, 8192), (8192, 64, 8192)] and dtype_str == 'float32': + pytest.skip('Out of shared memory in float32') dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str] torch.manual_seed(0) ref_gpu_util = matmul_data[DEVICE_NAME][(M, N, K)][dtype_str] @@ -126,32 +141,44 @@ def _add(x_ptr, y_ptr, output_ptr, n_elements, elementwise_data = { 'v100': { - 1024 * 16: 0.0219, - 1024 * 64: 0.0791, - 1024 * 256: 0.243, - 1024 * 1024: 0.530, - 1024 * 4096: 0.796, - 1024 * 16384: 0.905, - 1024 * 65536: 0.939, + 1024 * 16: {'float16': 0.0219, 'float32': 0.010}, + 
1024 * 64: {'float16': 0.0791, 'float32': 0.010}, + 1024 * 256: {'float16': 0.243, 'float32': 0.010}, + 1024 * 1024: {'float16': 0.530, 'float32': 0.010}, + 1024 * 4096: {'float16': 0.796, 'float32': 0.010}, + 1024 * 16384: {'float16': 0.905, 'float32': 0.010}, + 1024 * 65536: {'float16': 0.939, 'float32': 0.010}, + # Non pow 2 + 1020 * 100: {'float16': 0.010, 'float32': 0.010}, + 995 * 125: {'float16': 0.010, 'float32': 0.010}, + 10003 * 7007: {'float16': 0.010, 'float32': 0.010}, }, 'a100': { - 1024 * 16: 0.010, - 1024 * 64: 0.040, - 1024 * 256: 0.132, - 1024 * 1024: 0.353, - 1024 * 4096: 0.605, - 1024 * 16384: 0.758, - 1024 * 65536: 0.850, + 1024 * 16: {'float16': 0.010, 'bfloat16': 0.010, 'float32': 0.020}, + 1024 * 64: {'float16': 0.040, 'bfloat16': 0.040, 'float32': 0.066}, + 1024 * 256: {'float16': 0.132, 'bfloat16': 0.132, 'float32': 0.227}, + 1024 * 1024: {'float16': 0.353, 'bfloat16': 0.353, 'float32': 0.488}, + 1024 * 4096: {'float16': 0.605, 'bfloat16': 0.605, 'float32': 0.705}, + 1024 * 16384: {'float16': 0.758, 'bfloat16': 0.750, 'float32': 0.819}, + 1024 * 65536: {'float16': 0.850, 'bfloat16': 0.850, 'float32': 0.870}, + # Non pow 2 + 1020 * 100: {'float16': 0.051, 'bfloat16': 0.051, 'float32': 0.103}, + 995 * 125: {'float16': 0.063, 'bfloat16': 0.063, 'float32': 0.126}, + 10003 * 7007: {'float16': 0.544, 'bfloat16': 0.541, 'float32': 0.861}, } } @pytest.mark.parametrize('N', elementwise_data[DEVICE_NAME].keys()) -def test_elementwise(N): +@pytest.mark.parametrize("dtype_str", ['float16', 'bfloat16', 'float32']) +def test_elementwise(N, dtype_str): torch.manual_seed(0) - ref_gpu_util = elementwise_data[DEVICE_NAME][N] + if dtype_str in ['bfloat16'] and DEVICE_NAME != 'a100': + pytest.skip('Only test bfloat16 on a100') + dtype = {'float16': torch.float16, 'bfloat16': torch.bfloat16, 'float32': torch.float32}[dtype_str] + ref_gpu_util = elementwise_data[DEVICE_NAME][N][dtype_str] max_gpu_perf = get_dram_gbps() - z = torch.empty((N, ), dtype=torch.float16, device='cuda') + z = torch.empty((N, ), dtype=dtype, device='cuda') x = torch.randn_like(z) y = torch.randn_like(z) grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), ) @@ -169,29 +196,56 @@ def test_elementwise(N): flash_attention_data = { "a100": { - (4, 48, 4096, 64, 'forward', 'float16'): 0.37, - (4, 48, 4096, 64, 'backward', 'float16'): 0.25, + (4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.420, + (4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.202, + (4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.355, + (4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.201, + (4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.099, + (4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.087, + (4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.238, + (4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.135, + (4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.211, + (4, 48, 4096, 64, True, False, 'backward', 'bfloat16'): 0.135, + (4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.062, + (4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.052, + (4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.424, + (4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.262, + (4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.370, + (4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.254, + (4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.099, + (4, 48, 1024, 16, False, True, 'backward', 'float32'): 0.125, + (4, 48, 4096, 64, False, False, 'forward', 
'float16'): 0.238, + (4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.158, + (4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.211, + (4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.134, + (4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.062, + (4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.075, } } -@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [[4, 48, 4096, 64]]) +@pytest.mark.parametrize("dtype_str", ['float16', 'bfloat16', 'float32']) @pytest.mark.parametrize("mode", ['forward', 'backward']) -@pytest.mark.parametrize("dtype_str", ['float16']) -def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str): +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("seq_par", [True, False]) +@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [[4, 48, 4096, 64]]) +def test_flash_attention(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str): is_backward = mode == 'backward' capability = torch.cuda.get_device_capability() if capability[0] < 8: pytest.skip("Flash attention only supported for compute capability < 80") torch.manual_seed(20) - dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str] + dtype = {'float16': torch.float16, 'bfloat16': torch.bfloat16, 'float32': torch.float32}[dtype_str] # init data + if dtype_str == 'float32': + N_CTX = 1024 + D_HEAD = 16 q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_() k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_() v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_() sm_scale = 0.2 # benchmark - fn = lambda: triton.ops.attention(q, k, v, sm_scale) + fn = lambda: triton.ops.attention(q, k, v, causal, sm_scale, seq_par) if is_backward: o = fn() do = torch.randn_like(o) @@ -207,6 +261,6 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str): cur_sm_clock = nvsmi(['clocks.current.sm'])[0] max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3) cur_gpu_util = cur_gpu_perf / max_gpu_perf - ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, mode, dtype_str)] + ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str)] print_perf(ms, cur_gpu_util, ref_gpu_util) triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05) diff --git a/python/test/unit/operators/test_flash_attention.py b/python/test/unit/operators/test_flash_attention.py index 3c3e8a568839..55ff30774294 100644 --- a/python/test/unit/operators/test_flash_attention.py +++ b/python/test/unit/operators/test_flash_attention.py @@ -10,22 +10,23 @@ (4, 48, 1024, 64), (4, 48, 1024, 128)]) @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) -def test_op(Z, H, N_CTX, D_HEAD, dtype): +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('seq_par', [True, False]) +def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par): capability = torch.cuda.get_device_capability() if capability[0] < 8: pytest.skip("Flash attention only supported for compute capability < 80") torch.manual_seed(20) - q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_() - k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_() - v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, 
device="cuda").normal_(mean=0.3, std=0.2).requires_grad_() - sm_scale = 0.2 + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + sm_scale = 0.5 dout = torch.randn_like(q) # reference implementation M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) p = torch.matmul(q, k.transpose(2, 3)) * sm_scale - for z in range(Z): - for h in range(H): - p[:, :, M == 0] = float("-inf") + if causal: + p[:, :, M == 0] = float("-inf") p = torch.softmax(p.float(), dim=-1).to(dtype) # p = torch.exp(p) ref_out = torch.matmul(p, v) @@ -34,7 +35,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype): ref_dk, k.grad = k.grad.clone(), None ref_dq, q.grad = q.grad.clone(), None # # triton implementation - tri_out = triton.ops.attention(q, k, v, sm_scale) + tri_out = triton.ops.attention(q, k, v, causal, sm_scale, seq_par) # print(ref_out) # print(tri_out) tri_out.backward(dout) diff --git a/python/triton/ops/flash_attention.py b/python/triton/ops/flash_attention.py index d17fe515e317..6d55e6e73bea 100644 --- a/python/triton/ops/flash_attention.py +++ b/python/triton/ops/flash_attention.py @@ -3,6 +3,9 @@ =============== This is a Triton implementation of the Flash Attention algorithm (see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf) + +Sequence Parallel implementation inspired by HazyResearch +(see https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py) """ import torch @@ -23,68 +26,113 @@ def _fwd_kernel( Z, H, N_CTX, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + MODE: tl.constexpr, ): start_m = tl.program_id(0) off_hz = tl.program_id(1) + qvk_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0) + ) + O_block_ptr = tl.make_block_ptr( + base=Out + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) # initialize offsets offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk - off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk - off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk - # Initialize pointers to Q, K, V - q_ptrs = Q + off_q - k_ptrs = K + off_k - v_ptrs = V + off_v # initialize pointer to m and l - m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_prev = tl.zeros([BLOCK_M], dtype=tl.float32) + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = 
tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # causal check on every loop iteration can be expensive + # and peeling the last iteration of the loop does not work well with ptxas + # so we have a mode to do the causal check in a separate kernel entirely + if MODE == 0: # entire non-causal attention + lo, hi = 0, N_CTX + if MODE == 1: # entire causal attention + lo, hi = 0, (start_m + 1) * BLOCK_M + if MODE == 2: # off band-diagonal + lo, hi = 0, start_m * BLOCK_M + if MODE == 3: # on band-diagonal + l_ptrs = L + off_hz * N_CTX + offs_m + m_ptrs = M + off_hz * N_CTX + offs_m + m_i = tl.load(m_ptrs) + l_i = tl.load(l_ptrs) + acc += tl.load(O_block_ptr).to(tl.float32) + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + # credits to: Adam P. Goucher (https://github.com/apgoucher): + # scale sm_scale by 1/log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 # load q: it will stay in SRAM throughout - q = tl.load(q_ptrs) + q = tl.load(Q_block_ptr) + q = (q * qk_scale).to(K.dtype.element_ty) # loop over k, v and update accumulator - for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N): + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- - k = tl.load(k_ptrs) + k = tl.load(tl.advance(K_block_ptr, (0, start_n))) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - # compute new m - m_curr = tl.maximum(tl.max(qk, 1), m_prev) - # correct old l - l_prev *= tl.exp(m_prev - m_curr) - # attention weights - p = tl.exp(qk - m_curr[:, None]) - l_curr = tl.sum(p, 1) + l_prev - # rescale operands of matmuls - l_rcp = 1. 
/ l_curr - p *= l_rcp[:, None] - acc *= (l_prev * l_rcp)[:, None] + qk += tl.dot(q, k, allow_tf32=True) + if MODE == 1 or MODE == 3: + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.math.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.math.exp2(m_i - m_i_new) + beta = tl.math.exp2(m_ij - m_i_new) + l_i *= alpha + l_i_new = l_i + beta * l_ij + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new + acc = acc * acc_scale[:, None] # update acc - p = p.to(Q.dtype.element_ty) - v = tl.load(v_ptrs) - acc += tl.dot(p, v) + v = tl.load(tl.advance(V_block_ptr, (start_n, 0))) + p = p.to(V.dtype.element_ty) + acc += tl.dot(p, v, allow_tf32=True) # update m_i and l_i - l_prev = l_curr - m_prev = m_curr - # update pointers - k_ptrs += BLOCK_N * stride_kn - v_ptrs += BLOCK_N * stride_vk - # rematerialize offsets to save registers - start_m = tl.program_id(0) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + l_i = l_i_new + m_i = m_i_new # write back l and m l_ptrs = L + off_hz * N_CTX + offs_m m_ptrs = M + off_hz * N_CTX + offs_m - tl.store(l_ptrs, l_prev) - tl.store(m_ptrs, m_prev) - # initialize pointers to output - offs_n = tl.arange(0, BLOCK_DMODEL) - off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on - out_ptrs = Out + off_o - tl.store(out_ptrs, acc) + tl.store(l_ptrs, l_i) + tl.store(m_ptrs, m_i) + # write back O + tl.store(O_block_ptr, acc.to(K.dtype.element_ty)) @jit @@ -107,94 +155,168 @@ def _bwd_preprocess( tl.store(Delta + off_m, delta) +@jit +def _bwd_kernel_one_col_block( + Q, K, V, sm_scale, qk_scale, + Out, DO, + DQ, DK, DV, + L, M, + D, + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + Z, H, N_CTX, + off_hz, start_n, num_block, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + SEQUENCE_PARALLEL: tl.constexpr, + MODE: tl.constexpr, +): + if SEQUENCE_PARALLEL: + DQ += stride_dqa.to(tl.int64) * start_n + if MODE == 0: + lo = 0 + else: + lo = start_n * BLOCK_M + # initialize row/col offsets + offs_qm = lo + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_DMODEL) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + v_ptrs = V + (offs_n[:, None] * stride_vk + offs_k[None, :] * stride_vn) + do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + m_ptrs = M + off_hz * N_CTX + # initialize dv amd dk + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + # loop over rows + for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(q_ptrs) + # recompute p = softmax(qk, dim=-1).T + # NOTE: `do` is pre-divided by `l`; no normalization here + if MODE == 1: + qk = tl.where(offs_m_curr[:, None] 
>= (offs_n[None, :]), float(0.), float("-inf")) + else: + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, tl.trans(k)) + qk *= qk_scale + m = tl.load(m_ptrs + offs_m_curr) + p = tl.math.exp2(qk - m[:, None]) + # compute dv + do = tl.load(do_ptrs) + dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do, allow_tf32=True) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + # dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp = tl.dot(do, tl.trans(v), allow_tf32=True) + # compute ds = p * (dp - delta[:, None]) + ds = (p * (dp - Di[:, None]) * sm_scale).to(Q.dtype.element_ty) + # compute dk = dot(ds.T, q) + dk += tl.dot(tl.trans(ds), q, allow_tf32=True) + # compute dq + if not SEQUENCE_PARALLEL: + dq = tl.load(dq_ptrs) + dq += tl.dot(ds, k, allow_tf32=True) + tl.store(dq_ptrs, dq) + elif SEQUENCE_PARALLEL: + # dq = tl.dot(ds, k, allow_tf32=True) + dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds), allow_tf32=True)) + tl.store(dq_ptrs, dq) + + # increment pointers + dq_ptrs += BLOCK_M * stride_qm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_qm + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_vk + offs_k[None, :] * stride_vn) + dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + + @jit def _bwd_kernel( - Q, K, V, sm_scale, Out, DO, + # fmt: off + Q, K, V, sm_scale, + Out, DO, DQ, DK, DV, L, M, D, - stride_qz, stride_qh, stride_qm, stride_qk, + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, Z, H, N_CTX, - num_block, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + SEQUENCE_PARALLEL: tl.constexpr, + MODE: tl.constexpr, + # fmt: on ): + qk_scale = sm_scale * 1.44269504 off_hz = tl.program_id(0) off_z = off_hz // H off_h = off_hz % H # offset pointers for batch/head Q += off_z * stride_qz + off_h * stride_qh - K += off_z * stride_qz + off_h * stride_qh - V += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_kz + off_h * stride_kh + V += off_z * stride_vz + off_h * stride_vh DO += off_z * stride_qz + off_h * stride_qh DQ += off_z * stride_qz + off_h * stride_qh - DK += off_z * stride_qz + off_h * stride_qh - DV += off_z * stride_qz + off_h * stride_qh - for start_n in range(0, num_block): - lo = start_n * BLOCK_M - # initialize row/col offsets - offs_qm = lo + tl.arange(0, BLOCK_M) - offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) - offs_m = tl.arange(0, BLOCK_N) - offs_k = tl.arange(0, BLOCK_DMODEL) - # initialize pointers to value-like data - q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) - k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) - v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) - do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) - dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) - # pointer to row-wise quantities in value-like data - D_ptrs = D + off_hz * N_CTX - m_ptrs = M + off_hz * N_CTX - # initialize dv amd dk - dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # k and v stay in SRAM throughout - k = tl.load(k_ptrs) - v = tl.load(v_ptrs) - # loop over rows - for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): - offs_m_curr = start_m + offs_m - # load q, k, v, do on-chip - q = tl.load(q_ptrs) - # 
recompute p = softmax(qk, dim=-1).T - # NOTE: `do` is pre-divided by `l`; no normalization here - qk = tl.dot(q, tl.trans(k)) - qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) - m = tl.load(m_ptrs + offs_m_curr) - p = tl.exp(qk * sm_scale - m[:, None]) - # compute dv - do = tl.load(do_ptrs) - dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do) - # compute dp = dot(v, do) - Di = tl.load(D_ptrs + offs_m_curr) - dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] - dp += tl.dot(do, tl.trans(v)) - # compute ds = p * (dp - delta[:, None]) - ds = p * dp * sm_scale - # compute dk = dot(ds.T, q) - dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q) - # compute dq - dq = tl.load(dq_ptrs) - dq += tl.dot(ds.to(Q.dtype.element_ty), k) - tl.store(dq_ptrs, dq) - # increment pointers - dq_ptrs += BLOCK_M * stride_qm - q_ptrs += BLOCK_M * stride_qm - do_ptrs += BLOCK_M * stride_qm - # write-back - dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) - dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) - tl.store(dv_ptrs, dv) - tl.store(dk_ptrs, dk) + DK += off_z * stride_kz + off_h * stride_kh + DV += off_z * stride_vz + off_h * stride_vh + + num_block_n = tl.cdiv(N_CTX, BLOCK_N) + if not SEQUENCE_PARALLEL: + for start_n in range(0, num_block_n): + _bwd_kernel_one_col_block( + Q, K, V, sm_scale, qk_scale, Out, DO, + DQ, DK, DV, + L, M, + D, + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + Z, H, N_CTX, + off_hz, start_n, num_block_n, + BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_N=BLOCK_N, + SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, + MODE=MODE, + ) + else: + start_n = tl.program_id(1) + _bwd_kernel_one_col_block( + Q, K, V, sm_scale, qk_scale, Out, DO, + DQ, DK, DV, + L, M, + D, + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + Z, H, N_CTX, + off_hz, start_n, num_block_n, + BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_N=BLOCK_N, + SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, + MODE=MODE, + ) class _attention(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v, sm_scale): + def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): # only support for Ampere now capability = torch.cuda.get_device_capability() if capability[0] < 8: @@ -209,58 +331,80 @@ def forward(ctx, q, k, v, sm_scale): L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) num_warps = 4 if Lk <= 64 else 8 - - _fwd_kernel[grid]( - q, k, v, sm_scale, - L, m, - o, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - o.stride(0), o.stride(1), o.stride(2), o.stride(3), - q.shape[0], q.shape[1], q.shape[2], - BLOCK_M=BLOCK, BLOCK_N=BLOCK, - BLOCK_DMODEL=Lk, num_warps=num_warps, - num_stages=2, - ) + if causal: + modes = [1] if q.shape[2] <= 2048 else [2, 3] + else: + modes = [0] + for mode in modes: + _fwd_kernel[grid]( + q, k, v, sm_scale, + L, m, + o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], 
q.shape[1], q.shape[2], + BLOCK_M=128, BLOCK_N=BLOCK, BLOCK_DMODEL=Lk, + MODE=mode, + num_warps=num_warps, + num_stages=2) ctx.save_for_backward(q, k, v, o, L, m) ctx.grid = grid ctx.sm_scale = sm_scale ctx.BLOCK_DMODEL = Lk + ctx.causal = causal + ctx.sequence_parallel = sequence_parallel return o @staticmethod def backward(ctx, do): BLOCK = 128 q, k, v, o, l, m = ctx.saved_tensors + sequence_parallel = ctx.sequence_parallel + seq_len_kv = k.shape[2] do = do.contiguous() - dq = torch.zeros_like(q, dtype=torch.float32) + if sequence_parallel: + replicas = cdiv(seq_len_kv, BLOCK) + new_dq_shape = (replicas,) + q.shape + dq = torch.zeros(new_dq_shape, device=q.device, dtype=q.dtype) + else: + dq = torch.zeros_like(q, dtype=torch.float32) dk = torch.empty_like(k) dv = torch.empty_like(v) do_scaled = torch.empty_like(do) delta = torch.empty_like(l) + if ctx.causal: + mode = 1 + else: + mode = 0 _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )]( o, do, l, do_scaled, delta, BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL, ) - _bwd_kernel[(ctx.grid[1],)]( + _bwd_kernel[(ctx.grid[1], cdiv(seq_len_kv, BLOCK) if sequence_parallel else 1)]( q, k, v, ctx.sm_scale, o, do_scaled, dq, dk, dv, l, m, delta, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), + o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), k.stride(3), v.stride(0), v.stride(1), v.stride(2), v.stride(3), q.shape[0], q.shape[1], q.shape[2], - ctx.grid[0], BLOCK_M=BLOCK, BLOCK_N=BLOCK, - BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, + SEQUENCE_PARALLEL=sequence_parallel, + MODE=mode, + num_warps=8, num_stages=1, ) - return dq, dk, dv, None + + if len(dq.shape) == 5: + dq = dq.sum(dim=0) + return dq, dk, dv, None, None, None attention = _attention.apply diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 24ba46e4e90e..2c7254de04f6 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -346,9 +346,7 @@ def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) p = torch.matmul(q, k.transpose(2, 3)) * sm_scale if causal: - for z in range(Z): - for h in range(H): - p[:, :, M == 0] = float("-inf") + p[:, :, M == 0] = float("-inf") p = torch.softmax(p.float(), dim=-1).half() # p = torch.exp(p) ref_out = torch.matmul(p, v)
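
For illustration only (not part of the patch above): a minimal sketch of how the
updated triton.ops.attention entry point is called with the new causal and
sequence_parallel arguments. The shapes and softmax scale mirror the perf test;
an Ampere-class GPU (compute capability >= 8.0) and the float16 path are assumed.

import torch
import triton
import triton.ops

# Shapes mirror the perf test above: batch Z, heads H, context N_CTX, head dim D_HEAD.
Z, H, N_CTX, D_HEAD = 4, 48, 4096, 64
dtype = torch.float16
q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
sm_scale = 0.5
causal = True             # apply the lower-triangular (causal) mask in the forward kernel
sequence_parallel = True  # parallelize the backward pass over key/value blocks

# New positional signature: attention(q, k, v, causal, sm_scale, sequence_parallel).
o = triton.ops.attention(q, k, v, causal, sm_scale, sequence_parallel)

# With sequence_parallel=True, each key/value block program writes its own dq
# replica, and the replicas are summed before dq is returned.
o.backward(torch.randn_like(o))

Since the autograd Function now takes six inputs, its backward returns dq, dk, dv
followed by three None placeholders for causal, sm_scale, and sequence_parallel.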