From 3aae9c18c11ce58865374e64640a474877ddee3d Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Thu, 25 Jul 2024 01:28:44 -0700 Subject: [PATCH] Revert "Changes For FP8 (#1075)" This reverts commit 1899c970c8639e82e6b8a78408f4041425e9f900. --- hopper/benchmark_flash_attention.py | 281 ------------------- hopper/benchmark_flash_attention_fp8.py | 339 ----------------------- hopper/epilogue_fwd_sm90_tma.hpp | 3 +- hopper/flash_api.cpp | 45 +-- hopper/flash_attn_interface.py | 4 +- hopper/flash_fwd_hdim128_fp8_sm90.cu | 9 - hopper/flash_fwd_hdim256_fp8_sm90.cu | 9 - hopper/flash_fwd_hdim64_fp8_sm90.cu | 9 - hopper/flash_fwd_launch_template.h | 22 +- hopper/kernel_traits.h | 32 +-- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 117 ++------ hopper/setup.py | 3 - hopper/test_flash_attn.py | 120 +++----- hopper/utils.h | 82 ------ 14 files changed, 89 insertions(+), 986 deletions(-) delete mode 100644 hopper/benchmark_flash_attention.py delete mode 100644 hopper/benchmark_flash_attention_fp8.py delete mode 100644 hopper/flash_fwd_hdim128_fp8_sm90.cu delete mode 100644 hopper/flash_fwd_hdim256_fp8_sm90.cu delete mode 100644 hopper/flash_fwd_hdim64_fp8_sm90.cu diff --git a/hopper/benchmark_flash_attention.py b/hopper/benchmark_flash_attention.py deleted file mode 100644 index 9e8153059..000000000 --- a/hopper/benchmark_flash_attention.py +++ /dev/null @@ -1,281 +0,0 @@ -# Install the newest triton version with -# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python" -import pickle -import math -import time -import torch -import torch.nn as nn -import torch.nn.functional as F - -from einops import rearrange, repeat - -from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward -from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined - -from flash_attn import flash_attn_qkvpacked_func -from flash_attn_interface import flash_attn_func - -try: - from triton.ops.flash_attention import attention as attention_triton -except ImportError: - attention_triton = None - -try: - import xformers.ops as xops -except ImportError: - xops = None - -try: - import cudnn -except ImportError: - cudnn = None - - -def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"): - assert mode in ["fwd", "bwd", "fwd_bwd"] - f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1) - return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f) - -def efficiency(flop, time): - return (flop / time / 10**12) if not math.isnan(time) else 0.0 - - -def convert_to_cudnn_type(torch_type): - if torch_type == torch.float16: - return cudnn.data_type.HALF - elif torch_type == torch.bfloat16: - return cudnn.data_type.BFLOAT16 - elif torch_type == torch.float32: - return cudnn.data_type.FLOAT - elif torch_type == torch.int32: - return cudnn.data_type.INT32 - elif torch_type == torch.int64: - return cudnn.data_type.INT64 - else: - raise ValueError("Unsupported tensor data type.") - - -def cudnn_spda_setup(q, k, v, causal=False): - b, nheads, seqlen_q, headdim = q.shape - _, _, seqlen_k, _ = k.shape - assert v.shape == (b, nheads, seqlen_k, headdim) - assert cudnn is not None, 'CUDNN is not available' - q_gpu, k_gpu, v_gpu = q, k, v - o_gpu = torch.empty_like(q_gpu) - stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device) - graph = cudnn.pygraph( - io_data_type=convert_to_cudnn_type(q.dtype), - intermediate_data_type=cudnn.data_type.FLOAT, - compute_data_type=cudnn.data_type.FLOAT, - ) - q = 
graph.tensor_like(q_gpu.detach()) - k = graph.tensor_like(k_gpu.detach()) - v = graph.tensor_like(v_gpu.detach()) - - o, stats = graph.sdpa( - name="sdpa", - q=q, - k=k, - v=v, - is_inference=False, - attn_scale=1.0 / math.sqrt(headdim), - use_causal_mask=causal, - ) - - o.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride()) - stats.set_output(True).set_data_type(cudnn.data_type.FLOAT) - - graph.validate() - graph.build_operation_graph() - graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) - graph.check_support() - graph.build_plans() - - variant_pack = { - q: q_gpu, - k: k_gpu, - v: v_gpu, - o: o_gpu, - stats: stats_gpu, - } - - workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) - - def run(*args, **kwargs): - graph.execute(variant_pack, workspace) - return o_gpu - - return run - - -def attention_pytorch(qkv, dropout_p=0.0, causal=True): - """ - Arguments: - qkv: (batch_size, seqlen, 3, nheads, head_dim) - dropout_p: float - Output: - output: (batch_size, seqlen, nheads, head_dim) - """ - batch_size, seqlen, _, nheads, d = qkv.shape - q, k, v = qkv.unbind(dim=2) - q = rearrange(q, 'b t h d -> (b h) t d') - k = rearrange(k, 'b s h d -> (b h) d s') - softmax_scale = 1.0 / math.sqrt(d) - # Preallocate attn_weights for `baddbmm` - scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device) - scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale), - '(b h) t s -> b h t s', h=nheads) - if causal: - # "triu_tril_cuda_template" not implemented for 'BFloat16' - # So we have to construct the mask in float - causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) - # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) - scores = scores + causal_mask.to(dtype=scores.dtype) - attention = torch.softmax(scores, dim=-1) - attention_drop = F.dropout(attention, dropout_p) - output = torch.einsum('bhts,bshd->bthd', attention_drop , v) - return output.to(dtype=qkv.dtype) - - -def time_fwd_bwd(func, *args, **kwargs): - time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark - time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs) - return time_f[1].mean, time_b[1].mean - - -repeats = 30 -device = 'cuda' -dtype = torch.float16 - -# Ideally, seq-len should be divisible by 132 to avoid wave quantization. -# However, the existing Triton implementation doesn't support seq-len like 8448. 
-bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192)] -# bs_seqlen_vals = [(2, 8192)] -causal_vals = [False] -# headdim_vals = [64, 128] -headdim_vals = [128] -dim = 128 -dropout_p = 0.0 - -methods = (["Flash2", "Pytorch", "Flash3"] - + (["Triton"] if attention_triton is not None else []) - + (["xformers.c"] if xops is not None else []) - + (["xformers.f"] if xops is not None else []) - + (["cudnn"] if cudnn is not None else [])) - -time_f = {} -time_b = {} -time_f_b = {} -speed_f = {} -speed_b = {} -speed_f_b = {} -for causal in causal_vals: - for headdim in headdim_vals: - for batch_size, seqlen in bs_seqlen_vals: - config = (causal, headdim, batch_size, seqlen) - nheads = dim // headdim - qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, - requires_grad=True) - f, b = time_fwd_bwd( - flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False - ) - time_f[config, "Flash2"] = f - time_b[config, "Flash2"] = b - - try: - qkv = qkv.detach().requires_grad_(True) - f, b = time_fwd_bwd( - attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False - ) - res_baseline = attention_pytorch(qkv, dropout_p, causal=causal) - except: # Skip if OOM - f, b = float('nan'), float('nan') - time_f[config, "Pytorch"] = f - time_b[config, "Pytorch"] = b - - q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, - requires_grad=True) for _ in range(3)] - f, b = time_fwd_bwd(flash_attn_func, q, k, v, causal=causal, repeats=repeats, verbose=False) - res = flash_attn_func(q, k, v, causal=causal) - - time_f[config, "Flash3"] = f - time_b[config, "Flash3"] = b - - if cudnn is not None: - time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark - res = benchmark_forward( - cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal), - repeats=repeats, verbose=False - ) - f = res[1].mean - time_f[config, "cudnn"] = f - time_b[config, "cudnn"] = math.inf - - if attention_triton is not None: - q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype, - requires_grad=True) for _ in range(3)] - # Try both values of sequence_parallel and pick the faster one - try: - f, b = time_fwd_bwd( - attention_triton, q, k, v, causal, headdim**(-0.5), - False, repeats=repeats, verbose=False - ) - except: - f, b = float('nan'), float('inf') - try: - _, b0 = time_fwd_bwd( - attention_triton, q, k, v, causal, headdim**(-0.5), - True, repeats=repeats, verbose=False - ) - except: - b0 = float('inf') - time_f[config, "Triton"] = f - time_b[config, "Triton"] = min(b, b0) if min(b, b0) < float('inf') else float('nan') - - if xops is not None: - q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, - requires_grad=True) for _ in range(3)] - f, b = time_fwd_bwd( - xops.memory_efficient_attention, q, k, v, - attn_bias=xops.LowerTriangularMask() if causal else None, - op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp) - ) - time_f[config, "xformers.c"] = f - time_b[config, "xformers.c"] = b - - if xops is not None: - q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, - requires_grad=True) for _ in range(3)] - f, b = time_fwd_bwd( - xops.memory_efficient_attention, q, k, v, - attn_bias=xops.LowerTriangularMask() if causal else None, - op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp) - ) - time_f[config, "xformers.f"] = f - time_b[config, "xformers.f"] = 
b - - print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###") - for method in methods: - time_f_b[config, method] = time_f[config, method] + time_b[config, method] - speed_f[config, method] = efficiency( - flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"), - time_f[config, method] - ) - speed_b[config, method] = efficiency( - flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"), - time_b[config, method] - ) - speed_f_b[config, method] = efficiency( - flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"), - time_f_b[config, method] - ) - #print (time_f[config,method]) - print( - f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, " - f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, " - f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s" - ) - - -# with open('flash2_attn_time.plk', 'wb') as fp: -# pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL) diff --git a/hopper/benchmark_flash_attention_fp8.py b/hopper/benchmark_flash_attention_fp8.py deleted file mode 100644 index 7d9e234da..000000000 --- a/hopper/benchmark_flash_attention_fp8.py +++ /dev/null @@ -1,339 +0,0 @@ -# Install the newest triton version with -# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python" -import pickle -import math -import time -import torch -import torch.nn as nn -import torch.nn.functional as F - -from einops import rearrange, repeat - -from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward -from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined - -from flash_attn import flash_attn_qkvpacked_func -from flash_attn_interface import flash_attn_func - -try: - from triton_fused_attention import attention as attention_triton -except ImportError: - attention_triton = None - -try: - import xformers.ops as xops -except ImportError: - xops = None - -try: - import cudnn -except ImportError: - cudnn = None - - -def convert_to_cudnn_type(torch_type): - if torch_type == torch.float16: - return cudnn.data_type.HALF - elif torch_type == torch.bfloat16: - return cudnn.data_type.BFLOAT16 - elif torch_type == torch.float32: - return cudnn.data_type.FLOAT - elif torch_type == torch.int32: - return cudnn.data_type.INT32 - elif torch_type == torch.int64: - return cudnn.data_type.INT64 - elif torch_type == torch.float8_e4m3fn: - return cudnn.data_type.FP8_E4M3 - elif torch_type == torch.float8_e4m3fn: - return cudnn.data_type.FP8_E5M2 - else: - raise ValueError("Unsupported tensor data type.") - -def cudnn_spda_setup(qkv, seqlen_q, seqlen_k, causal=False): - b, _, _, nheads, headdim = qkv.shape - assert cudnn is not None, 'CUDNN is not available' - o_gpu = torch.zeros(b, seqlen_q, nheads, headdim, dtype=qkv.dtype, device=qkv.device) - o_gpu_transposed = torch.as_strided( - o_gpu, - [b, nheads, seqlen_q, headdim], - [nheads * seqlen_q * headdim, headdim, nheads * headdim, 1], - ) - stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=qkv.device) - amax_s_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device) - amax_o_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device) - graph = cudnn.pygraph( - io_data_type=convert_to_cudnn_type(qkv.dtype), - intermediate_data_type=cudnn.data_type.FLOAT, - compute_data_type=cudnn.data_type.FLOAT, - ) - new_q = torch.as_strided( - qkv, - [b, nheads, seqlen_q, headdim], - [seqlen_q * nheads * headdim * 3, headdim, headdim * nheads * 3, 1], - 
storage_offset=0, - ) - q = graph.tensor( - name = "Q", - dim = list(new_q.shape), - stride = list(new_q.stride()), - data_type=convert_to_cudnn_type(qkv.dtype) - ) - new_k = torch.as_strided( - qkv, - [b, nheads, seqlen_k, headdim], - [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1], - storage_offset=nheads * headdim, - ) - k = graph.tensor( - name = "K", - dim = list(new_k.shape), - stride = list(new_k.stride()), - data_type=convert_to_cudnn_type(qkv.dtype) - ) - new_v = torch.as_strided( - qkv, - [b, nheads, seqlen_k, headdim], - [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1], - storage_offset=nheads * headdim * 2, - ) - v = graph.tensor( - name = "V", - dim = list(new_v.shape), - stride = list(new_v.stride()), - data_type=convert_to_cudnn_type(qkv.dtype) - ) - - def get_default_scale_tensor(): - return graph.tensor( - dim = [1, 1, 1, 1], - stride = [1, 1, 1, 1], - data_type=cudnn.data_type.FLOAT - ) - - default_scale_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float32, device="cuda") - descale_q = get_default_scale_tensor() - descale_k = get_default_scale_tensor() - descale_v = get_default_scale_tensor() - descale_s = get_default_scale_tensor() - scale_s = get_default_scale_tensor() - scale_o = get_default_scale_tensor() - - o, _, amax_s, amax_o = graph.sdpa_fp8( - q=q, - k=k, - v=v, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_s=descale_s, - scale_s=scale_s, - scale_o=scale_o, - is_inference=True, - attn_scale=1.0 / math.sqrt(headdim), - use_causal_mask=causal, - name="sdpa", - ) - - o.set_output(True).set_dim(o_gpu_transposed.shape).set_stride(o_gpu_transposed.stride()) - - amax_s.set_output(False).set_dim(amax_s_gpu.shape).set_stride(amax_s_gpu.stride()) - amax_o.set_output(False).set_dim(amax_o_gpu.shape).set_stride(amax_o_gpu.stride()) - # stats.set_output(True).set_data_type(cudnn.data_type.FLOAT) - - graph.validate() - graph.build_operation_graph() - graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) - graph.check_support() - graph.build_plans() - - variant_pack = { - q: new_q, - k: new_k, - v: new_v, - descale_q: default_scale_gpu, - descale_k: default_scale_gpu, - descale_v: default_scale_gpu, - descale_s: default_scale_gpu, - scale_s: default_scale_gpu, - scale_o: default_scale_gpu, - o: o_gpu_transposed, - amax_s: amax_s_gpu, - amax_o: amax_o_gpu, - } - - workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) - - def run(*args, **kwargs): - graph.execute(variant_pack, workspace) - return o_gpu, amax_o_gpu - - return run - - -def attention_pytorch(qkv, dropout_p=0.0, causal=True): - """ - Arguments: - qkv: (batch_size, seqlen, 3, nheads, head_dim) - dropout_p: float - Output: - output: (batch_size, seqlen, nheads, head_dim) - """ - batch_size, seqlen, _, nheads, d = qkv.shape - q, k, v = qkv.unbind(dim=2) - q = rearrange(q, 'b t h d -> (b h) t d') - k = rearrange(k, 'b s h d -> (b h) d s') - softmax_scale = 1.0 / math.sqrt(d) - # Preallocate attn_weights for `baddbmm` - scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device) - scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale), - '(b h) t s -> b h t s', h=nheads) - if causal: - # "triu_tril_cuda_template" not implemented for 'BFloat16' - # So we have to construct the mask in float - causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) - # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, 
just better kernel I guess) - scores = scores + causal_mask.to(dtype=scores.dtype) - attention = torch.softmax(scores, dim=-1) - attention_drop = F.dropout(attention, dropout_p) - output = torch.einsum('bhts,bshd->bthd', attention_drop , v) - return output.to(dtype=qkv.dtype) - -def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"): - assert mode in ["fwd", "bwd", "fwd_bwd"] - f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1) - return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f) - -def efficiency(flop, time): - return (flop / time / 10**12) if not math.isnan(time) else 0.0 - -def time_fwd(func, *args, **kwargs): - time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark - time_f = benchmark_forward(func, *args, **kwargs) - return time_f[1].mean - - -torch.manual_seed(0) - -repeats = 30 -device = 'cuda' -# dtype = torch.float16 -dtype = torch.float8_e4m3fn - -#bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4224), (2, 8448), (1, 8448 * 2)] -bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)] -#bs_seqlen_vals = [(4, 4224), (2, 8448), (1, 8448 * 2)] -# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)] -# bs_seqlen_vals = [(4, 8448)] -causal_vals = [False, True] -#headdim_vals = [64, 128, 256] -headdim_vals = [128,256] -dim = 2048 -# dim = 128 -dropout_p = 0.0 - -methods = (["Pytorch","Flash3", "cuDNN"] - + (["Triton"] if attention_triton is not None else []) - # + (["xformers.c"] if xops is not None else []) - # + (["xformers.f"] if xops is not None else []) - ) - -time_f = {} -time_b = {} -time_f_b = {} -speed_f = {} -speed_b = {} -speed_f_b = {} -for causal in causal_vals: - for headdim in headdim_vals: - for batch_size, seqlen in bs_seqlen_vals: - torch.cuda.empty_cache() - config = (causal, headdim, batch_size, seqlen) - nheads = dim // headdim - q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=torch.float16, - requires_grad=False) for _ in range(3)] - qkv = torch.stack([q, k, v], dim=2) - qkv = qkv.to(torch.float16) - - f = time_fwd(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False) - time_f[config, "Pytorch"] = f - res_baseline = attention_pytorch(qkv, dropout_p, causal=causal) - - if attention_triton is not None: - q_transposed = q.transpose(1, 2).contiguous().to(torch.float8_e4m3fn) - k_transposed = k.transpose(1, 2).contiguous().to(torch.float8_e4m3fn) - v_transposed = v.transpose(1, 2).contiguous().permute(0, 1, 3, 2).to(torch.float8_e4m3fn) - scale = 1 / math.sqrt(headdim) - f = time_fwd( - attention_triton, q_transposed, k_transposed, v_transposed, - causal, scale, repeats=5, verbose=False, desc='Triton' - ) - f = time_fwd( - attention_triton, q_transposed, k_transposed, v_transposed, - causal, scale, repeats=repeats, verbose=False, desc='Triton' - ) - time_f[config, "Triton"] = f - res = attention_triton( - q_transposed, k_transposed, v_transposed.permute(0, 1, 3, 2), - causal, scale - ).half().transpose(1, 2) - torch.testing.assert_close(res, res_baseline, atol=0.5, rtol=0.5) - - out = torch.empty_like(q) - q, k, v = q.to(dtype), k.to(dtype), v.to(dtype) - - v_transposed = v.transpose(1,3).contiguous().clone() - #v_transposed = v.transpose(1,3).clone() - time.sleep(1) - f = time_fwd(flash_attn_func, q, k, v_transposed, causal=causal, repeats=repeats, verbose=False) - # res = flash_attn_func(q, k, v, causal=causal, is_fp16_acc=False) - # 
torch.testing.assert_close(res.half(), res_baseline, atol=0.05, rtol=0.05) - - time_f[config, "Flash3"] = f - - if cudnn is not None: - qkv_fp8 = qkv.to(dtype) - time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark - f = time_fwd( - cudnn_spda_setup( - qkv_fp8, seqlen, seqlen, - causal=causal - ), - repeats=repeats, verbose=False - ) - time_f[config, "cuDNN"] = f - # res, amax_o = cudnn_spda_setup( - # qkv_fp8, seqlen, seqlen, - # causal=causal - # )() - # res = res.half() - # TODO: CUDNN has numerics issues when - # num_heads=16, dim=128, seq_len=1024, batch_size=2 - # or larger sizes. - # res_cpu = res.cpu().reshape(-1) - # res_baseline_cpu = res_baseline.cpu().reshape(-1) - # print(amax_o) - # print(res) - # print(res_baseline) - # for i in range(len(res_cpu)): - # item = res_cpu[i] - # item_baseline = res_baseline_cpu[i] - # if abs(item - item_baseline) > 0.5: - # print(i) - # print(item) - # print(item_baseline) - # torch.testing.assert_close(res, res_baseline, atol=0.05, rtol=0.05) - - print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###") - for method in methods: - speed_f[config, method] = efficiency( - flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"), - time_f[config, method] - ) - #print (time_f[config,method]) - print( - f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, {time_f[config, method] * 1e3} ms, " - ) - - -# with open('flash3_attn_time.plk', 'wb') as fp: -# pickle.dump((time_f, time_b, time_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL) diff --git a/hopper/epilogue_fwd_sm90_tma.hpp b/hopper/epilogue_fwd_sm90_tma.hpp index 192e555a5..852343860 100644 --- a/hopper/epilogue_fwd_sm90_tma.hpp +++ b/hopper/epilogue_fwd_sm90_tma.hpp @@ -20,8 +20,7 @@ using namespace cute; template struct CollectiveEpilogueFwd { - using PrecType = typename Ktraits::Element; - using Element = decltype(cute::conditional_return>(cutlass::half_t{}, PrecType{})); + using Element = typename Ktraits::Element; static constexpr int kBlockM = Ktraits::kBlockM; static constexpr int kBlockN = Ktraits::kBlockN; static constexpr int kHeadDim = Ktraits::kHeadDim; diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index c684a1976..397ed4cc3 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -249,13 +249,7 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split } } } else { - if (params.d == 64) { - run_mha_fwd_(params, stream); - } else if (params.d == 128) { - run_mha_fwd_(params, stream); - } else { - run_mha_fwd_(params, stream); - } + // run_mha_fwd_(params, stream); } } @@ -272,8 +266,12 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer."); auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat8_e4m3fn, - "FlashAttention only support fp16, bf16 and fp8 (e4m3) data type for now"); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type for now"); + // TODO: will add e4m3 later + // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kFloat8_e4m3fn, + // "FlashAttention only support fp16 and bf16 data type"); + // "FlashAttention only support fp16 and fp8 (e4m3) data type for now"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same 
dtype"); @@ -303,50 +301,29 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); - if (q_dtype == torch::kFloat8_e4m3fn) { - CHECK_SHAPE(v, batch_size, head_size_og, num_heads_k, seqlen_k); - } else { CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og); - } at::Tensor q_padded, k_padded, v_padded; - if (q_dtype == torch::kFloat8_e4m3fn) - { - if (head_size_og % 16 != 0) { - q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 16 - head_size_og % 16})); - k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 16 - head_size_og % 16})); - } else { - q_padded = q; - k_padded = k; - } - if (seqlen_k % 16 != 0) { - v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 16 - seqlen_k % 16})); - } else { - v_padded = v; - } - } - else { - if (head_size_og % 8 != 0) { + if (head_size_og % 8 != 0) { q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); - } else { + } else { q_padded = q; k_padded = k; v_padded = v; - } } at::Tensor out; if (out_.has_value()) { out = out_.value(); - //TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); CHECK_DEVICE(out); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } } else { - out = q_dtype == torch::kFloat8_e4m3fn ? torch::empty_like(q_padded, at::kHalf) : torch::empty_like(q_padded); + out = torch::empty_like(q_padded); } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 07e3366bb..d88ab78ea 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -15,7 +15,7 @@ def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x def _flash_attn_forward(q, k, v, softmax_scale, causal): - # q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd( q, k, @@ -41,7 +41,7 @@ def _flash_attn_backward( causal ): # dq, dk, dv are allocated by us so they should already be contiguous - #dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] dq, dk, dv, softmax_d, = flashattn_hopper_cuda.bwd( dout, q, diff --git a/hopper/flash_fwd_hdim128_fp8_sm90.cu b/hopper/flash_fwd_hdim128_fp8_sm90.cu deleted file mode 100644 index 68dd61b82..000000000 --- a/hopper/flash_fwd_hdim128_fp8_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. 
- -#include "flash_fwd_launch_template.h" - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim128(params, stream); -} diff --git a/hopper/flash_fwd_hdim256_fp8_sm90.cu b/hopper/flash_fwd_hdim256_fp8_sm90.cu deleted file mode 100644 index 42fe6bb1a..000000000 --- a/hopper/flash_fwd_hdim256_fp8_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. - -#include "flash_fwd_launch_template.h" - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim256(params, stream); -} diff --git a/hopper/flash_fwd_hdim64_fp8_sm90.cu b/hopper/flash_fwd_hdim64_fp8_sm90.cu deleted file mode 100644 index e3312954a..000000000 --- a/hopper/flash_fwd_hdim64_fp8_sm90.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. - -#include "flash_fwd_launch_template.h" - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim64(params, stream); -} diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 0ca2e4c60..cd7adb3bf 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -21,7 +21,6 @@ template void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using Element = typename Kernel_traits::Element; - using ElementO = decltype(cute::conditional_return>(cutlass::half_t{}, Element{})); using TileShape_MNK = typename Kernel_traits::TileShape_MNK; using ClusterShape = typename Kernel_traits::ClusterShape_MNK; @@ -128,14 +127,10 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { // Only use Cluster if number of tiles along seqlen_q is even and not Is_causal BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] { - if constexpr (is_same_v) { - //run_flash_fwd, Is_causal>(params, stream); - //run_flash_fwd, Is_causal>(params, stream); - run_flash_fwd, Is_causal, Seqlen_traits>(params, stream); - //run_flash_fwd, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_causal, Seqlen_traits>(params, stream); - } + run_flash_fwd< + Flash_fwd_kernel_traits, + Is_causal, Seqlen_traits + >(params, stream); }); }); }); @@ -148,11 +143,10 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { // Only use Cluster if number of tiles along seqlen_q is even BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] { - if constexpr (is_same_v) { - run_flash_fwd, Is_causal, Seqlen_traits>(params, stream); - } else { - run_flash_fwd, Is_causal, Seqlen_traits>(params, stream); - } + run_flash_fwd< + Flash_fwd_kernel_traits, + Is_causal, Seqlen_traits + >(params, stream); }); }); }); diff --git a/hopper/kernel_traits.h b/hopper/kernel_traits.h index 0335a9c87..90ee3ccf9 100644 --- a/hopper/kernel_traits.h +++ b/hopper/kernel_traits.h @@ -25,7 +25,6 @@ struct SharedStorageQKVO { cute::array_aligned> smem_o; }; struct { - cute::uint64_t tma_load_mbar[4]; // 4 TMA barriers pre-allocated for usage. 
cutlass::arch::ClusterTransactionBarrier barrier_Q; cutlass::arch::ClusterBarrier barrier_O; typename cutlass::PipelineTmaAsync::SharedStorage pipeline_k; @@ -41,7 +40,6 @@ struct Flash_fwd_kernel_traits { using Element = elem_type; using ElementAccum = float; using index_t = int64_t; - using ElementO = decltype(cute::conditional_return>(cutlass::half_t{}, Element{})); // The number of threads. static constexpr int kNWarps = kNWarps_; @@ -71,11 +69,9 @@ struct Flash_fwd_kernel_traits { decltype(cute::GMMA::ss_op_selector()) >{}, AtomLayoutMNK{})); - using TiledMma1 = decltype(cute::make_tiled_mma( cute::GMMA::rs_op_selector(TileShape_MNK{})), - GMMA::Major::K, cute::conditional_return>( - GMMA::Major::K, GMMA::Major::MN)>(), + GMMA::Major::K, GMMA::Major::MN>(), AtomLayoutMNK{})); using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); - using SmemLayoutAtomVFp16 = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutVFp16 = - decltype(tile_to_shape(SmemLayoutAtomVFp16{}, + using SmemLayoutV = + decltype(tile_to_shape(SmemLayoutAtomV{}, make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); - using SmemLayoutAtomVFp8 = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); - using SmemLayoutVFp8 = - decltype(tile_to_shape(SmemLayoutAtomVFp8{}, - make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}))); - using SmemLayoutV = decltype(cute::conditional_return>(SmemLayoutVFp8{}, SmemLayoutVFp16{})); - // Note this is the transpose in terms of the view, not in terms of memory. - using SmemLayoutVtFp16 = - decltype(cute::composition(SmemLayoutVFp16{}, - make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int{}), - make_stride(get<1>(TileShape_MNK{}), _1{}, Int{})))); - - using SmemLayoutVt = decltype(cute::conditional_return>(SmemLayoutVFp8{}, SmemLayoutVtFp16{})); - - using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{}))); - using SmemCopyAtomQ = Copy_Atom; + using SmemCopyAtomQ = Copy_Atom; - using SharedStorage = SharedStorageQKVO; using MainloopPipeline = typename cutlass::PipelineTmaAsync; diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index f9c8e2a5a..2de15fb9c 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -43,30 +43,12 @@ struct CollectiveMainloopFwd { using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); - - using SmemLayoutAtomVFp8 = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); - using SmemLayoutVFp8 = - decltype(tile_to_shape(SmemLayoutAtomVFp8{}, - make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}))); - - using SmemLayoutVFp16 = SmemLayoutK; + using SmemLayoutV = SmemLayoutK; // Note this is the transpose in terms of the view, not in terms of memory. 
- using SmemLayoutVtFp16 = - decltype(cute::composition(SmemLayoutVFp16{}, + using SmemLayoutVt = + decltype(cute::composition(SmemLayoutV{}, make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int{}), - make_stride(get<1>(TileShape_MNK{}), _1{}, Int{})))); - - using SmemLayoutV = decltype(cute::conditional_return>(SmemLayoutVFp8{}, SmemLayoutVFp16{})); - using SmemLayoutVt = decltype(cute::conditional_return>(SmemLayoutVFp8{}, SmemLayoutVtFp16{})); - - // Dummy S layout for getting the shape for GEMM-II. - using SmemLayoutAtomS = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); - using SmemLayoutS = - decltype(tile_to_shape(SmemLayoutAtomS{}, - make_shape(shape<0>(TileShape_MNK{}), shape<1>(TileShape_MNK{})))); - + make_stride(get<1>(TileShape_MNK{}), _1{}, Int{})))); // using SmemLayoutAtomVt = cute::GMMA::Layout_MN_SW128_Atom; // using SmemLayoutVt = // decltype(tile_to_shape(SmemLayoutAtomVt{}, @@ -103,19 +85,6 @@ struct CollectiveMainloopFwd { take<0, 2>(SmemLayoutK{}), select<1, 2>(TileShape_MNK{}), size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any - // - using TileShapeVFP8 = decltype(make_shape(cute::get<2>(TileShape_MNK{}), cute::get<1>(TileShape_MNK{}))); - using TileShapeVFP16 = decltype(make_shape(cute::get<1>(TileShape_MNK{}), cute::get<2>(TileShape_MNK{}))); - using TileShapeV = decltype(cute::conditional_return>(TileShapeVFP8{}, TileShapeVFP16{})); - using TMA_VFP8 = decltype(make_tma_copy( - GmemTiledCopyKV{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), repeat_like(StrideQKV{}, int32_t(0)), StrideQKV{}), - take<0, 2>(SmemLayoutV{}), - TileShapeV{}, - size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any - - using TMA_V = decltype(cute::conditional_return>(TMA_VFP8{}, TMA_KV{})); - static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{}); using MainloopPipeline = typename Ktraits::MainloopPipeline; @@ -128,7 +97,6 @@ struct CollectiveMainloopFwd { static constexpr bool UseSchedulerBarrier = kHeadDim <= 128; - // Host side kernel arguments struct Arguments { Element const* ptr_Q; @@ -147,8 +115,7 @@ struct CollectiveMainloopFwd { typename Seqlen_traits::LayoutT layout_V; cutlass::FastDivmod qhead_per_khead_divmod; TMA_Q tma_load_Q; - TMA_KV tma_load_K; - TMA_V tma_load_V; + TMA_KV tma_load_K, tma_load_V; float const softmax_scale_log2; }; @@ -169,15 +136,12 @@ struct CollectiveMainloopFwd { SmemLayoutK{}(_, _, _0{}), select<1, 2>(TileShape_MNK{}), size<0>(ClusterShape{})); // mcast along M mode for this N load, if any - auto gmemLayoutVFp16 = args.shape_K; - auto gmemLayoutVFp8 = select<1, 0, 2, 3>(gmemLayoutVFp16); - auto gmemLayoutV = cute::conditional_return>(gmemLayoutVFp8, gmemLayoutVFp16); - Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), gmemLayoutV, args.layout_V.stride()); - TMA_V tma_load_V = make_tma_copy( + Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V); + TMA_KV tma_load_V = make_tma_copy( GmemTiledCopyKV{}, mV, SmemLayoutV{}(_, _, _0{}), - cute::conditional_return>(select<2, 1>(TileShape_MNK{}), select<1, 2>(TileShape_MNK{})), + select<1, 2>(TileShape_MNK{}), size<0>(ClusterShape{})); // mcast along M mode for this N load, if any return {args.layout_Q, args.layout_K, args.layout_V, cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))), @@ -234,10 +198,7 @@ struct CollectiveMainloopFwd { Tensor mQ = 
mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape()); Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape()); - auto gmemLayoutVFp16 = mainloop_params.shape_K; - auto gmemLayoutVFp8 = select<1, 0, 2, 3>(gmemLayoutVFp16); - auto gmemLayoutV = cute::conditional_return>(gmemLayoutVFp8, gmemLayoutVFp16); - Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(gmemLayoutV); + Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape()); auto [m_block, bidh, bidb] = block_coord; int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh); @@ -246,34 +207,12 @@ struct CollectiveMainloopFwd { uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - - Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) - Tensor gK = local_tile(mK(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gV = local_tile(mV(_, _, bidh_kv, bidb), TileShapeV{}, cute::conditional_return>(make_coord(_0{}, _), make_coord(_, _0{}))); // (N, K, _) - -#if 0 - if (threadIdx.x == 0 && blockIdx.x == 0) { - print ("\n"); - print (gV); - print ("\n"); - print (gK); - print ("\n"); - print ("\n"); - print (sV); - print ("\n"); - print (sK); - print ("\n"); - print (gmemLayoutVFp8); - print ("\n"); - print (gmemLayoutVFp16); - } - - // Tensor gQ = seqlen_traits_q.get_local_tile_tensor( - // mQ, select<0, 2>(TileShape_MNK{}), bidh, bidb)(_, _, m_block); // (M, K) - // Tensor gK = seqlen_traits_k.get_local_tile_tensor( - // mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _) - // Tensor gV = seqlen_traits_k.get_local_tile_tensor( - // mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _) + Tensor gQ = seqlen_traits_q.get_local_tile_tensor( + mQ, select<0, 2>(TileShape_MNK{}), bidh, bidb)(_, _, m_block); // (M, K) + Tensor gK = seqlen_traits_k.get_local_tile_tensor( + mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _) + Tensor gV = seqlen_traits_k.get_local_tile_tensor( + mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _) Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{})); Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{})); @@ -430,13 +369,6 @@ struct CollectiveMainloopFwd { // Note: S becomes P. Tensor tOrV = threadMma1.partition_fragment_B(sVt); - // Dummy sS to just get the shape correctly for GEMM-II. 
- Tensor sS = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutS{}); - Tensor tOrS = threadMma1.partition_fragment_A(sS); - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); - ReorgCFp8toAFp8 reg2reg; - auto tOrPLayout = ReshapeTStoTP()(tSrS, tOrS); - auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); pipeline.consumer_wait(smem_pipe_read, barrier_token); @@ -450,6 +382,7 @@ struct CollectiveMainloopFwd { cutlass::ConsumerToken barrier_token = static_cast(shared_storage.barrier_Q.try_wait(work_idx % 2)); if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); } + Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read_k); warp_scheduler_barrier_sync(); flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); @@ -491,11 +424,7 @@ struct CollectiveMainloopFwd { } softmax.template online_softmax(tSrS, mainloop_params.softmax_scale_log2); - auto tSrSPrec = convert_type(tSrS); - if constexpr (is_same_v) { - reg2reg(tSrSPrec); - } - Tensor tOrP = make_tensor(tSrSPrec.data(), tOrPLayout); + Tensor tOrP = make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs(tSrS.layout())); Tensor scores_scale = make_fragment_like(softmax.row_max); clear(scores_scale); @@ -527,11 +456,7 @@ struct CollectiveMainloopFwd { pipeline_v.consumer_release(smem_pipe_read_v); // release V ++smem_pipe_read_k; ++smem_pipe_read_v; - auto tSrSPrec = convert_type(tSrS); - if constexpr (is_same_v) { - reg2reg(tSrSPrec); - } - cute::copy(make_tensor(tSrSPrec.data(), tOrPLayout), tOrP); + cute::copy(make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs(tSrS.layout())), tOrP); } #pragma unroll 1 @@ -554,11 +479,7 @@ struct CollectiveMainloopFwd { ++smem_pipe_read_k; ++smem_pipe_read_v; // softmax.rescale_o(tOrO, scores_scale); - auto tSrSPrec = convert_type(tSrS); - if constexpr (is_same_v) { - reg2reg(tSrSPrec); - } - cute::copy(make_tensor(tSrSPrec.data(), tOrPLayout), tOrP); + cute::copy(make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs(tSrS.layout())), tOrP); } // Tell warp 0 that smem_q is ready cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); diff --git a/hopper/setup.py b/hopper/setup.py index 2d3d01b94..5d01a029a 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -116,9 +116,6 @@ def append_nvcc_threads(nvcc_extra_args): "flash_fwd_hdim128_bf16_sm90.cu", "flash_fwd_hdim256_fp16_sm90.cu", "flash_fwd_hdim256_bf16_sm90.cu", - "flash_fwd_hdim64_fp8_sm90.cu", - "flash_fwd_hdim128_fp8_sm90.cu", - "flash_fwd_hdim256_fp8_sm90.cu", "flash_bwd_hdim64_fp16_sm90.cu", "flash_bwd_hdim128_fp16_sm90.cu", "flash_bwd_hdim256_fp16_sm90.cu", diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index a37954e25..55ec48686 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -170,7 +170,7 @@ def test_flash_attn_output( (113, 211), (108, 256), (256, 512), - (384, 256), + (384, 256), (512, 256), (640, 128), (1024, 1024), @@ -261,87 +261,49 @@ def test_flash_attn_varlen_output( reorder_ops=True, ) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - 
out_ref).abs().mean().item()}") - - -@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) -# @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -# @pytest.mark.parametrize("mha_type", ["gqa"]) -@pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [True]) -# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) -# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) -# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) -# @pytest.mark.parametrize('d', [56, 80]) -@pytest.mark.parametrize("d", [64, 128, 256]) -#@pytest.mark.parametrize("d", [128]) -# @pytest.mark.parametrize("d", [256]) -@pytest.mark.parametrize( - "seqlen_q,seqlen_k", - [ - (64, 128), - (128, 128), - (256, 256), - (113, 203), - (128, 217), - (113, 211), - (108, 256), - (256, 512), - (384, 256), - (640, 128), - (512, 256), - (1024, 1024), - (1023, 1024), - (1024, 1023), - (2048, 2048), - ], -) -def test_flash_attn_output_fp8( - seqlen_q, seqlen_k, d, causal, mha_type, dtype -): - device = "cuda" - # set seed - torch.random.manual_seed(0) - # batch_size = 40 - # nheads = 16 - batch_size = 9 - nheads = 6 - nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1) - # batch_size = 1 - # nheads = 1 - q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=torch.float16, requires_grad=True) - k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=torch.float16, requires_grad=True) - v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=torch.float16, requires_grad=True) - out, lse = flash_attn_func(q.to(dtype), k.to(dtype), v.to(dtype).transpose(1,3).contiguous().clone(), causal=causal) - q = q.to(dtype).to(torch.float16) - k = k.to(dtype).to(torch.float16) - v = v.to(dtype).to(torch.float16) - out_ref, attn_ref = attention_ref( - q, - k, - v, - None, - None, - causal=causal, - ) - out_pt, attn_pt = attention_ref( - q, - k, - v, - None, - None, - causal=causal, - upcast=False, - reorder_ops=True, - ) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # g = torch.randn_like(out) + # if d <= 128: + # ( + # dq_unpad, + # dk_unpad, + # dv_unpad, + # ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g) + # dk = dk_pad_fn(dk_unpad) + # dv = dk_pad_fn(dv_unpad) + # ( + # dq_ref, + # dk_ref, + # dv_ref, + # ) = torch.autograd.grad(out_ref, (q, k, v), g) + # ( + # dq_pt, + # dk_pt, + # dv_pt, + # ) = torch.autograd.grad(out_pt, (q, k, v), g) + # dq = dq_pad_fn(dq_unpad) + # print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + # print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + # print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + # print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + # print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + # print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + # print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + # print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + # print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + # print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + # print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + # 
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + + # if d <= 128: + # assert (dq - dq_ref).abs().max().item() < 1e-4 or (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() + # assert (dk - dk_ref).abs().max().item() < 1e-4 or (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() + # assert (dk - dk_ref).abs().max().item() < 1e-4 or (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() diff --git a/hopper/utils.h b/hopper/utils.h index 21e7cc6b9..90116f8a7 100644 --- a/hopper/utils.h +++ b/hopper/utils.h @@ -228,88 +228,6 @@ __forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor CUTLASS_DEVICE auto operator()(Fragment &accum) { - - using namespace cute; - - // First update `mi` to the max per-row - // - auto VT = shape<0>(accum); // number of vector elements per tile. - auto MT = shape<1>(accum); // number of tiles along M. - auto NT = shape<2>(accum); // number of tiles along N. - - auto data = accum.data(); - int n = 0; - -#pragma unroll - for (int i = 0; i < MT; ++i) { - - // Traverse 2-rows + 2-cols (2x2) simultaneously. - -#pragma unroll - for (int k = 0; k < NT * size<2>(VT) / 2; ++k) { - - auto upper = *reinterpret_cast(&data[n]); - auto lower = *reinterpret_cast(&data[n + 4]); - - auto upper0 = __byte_perm(upper, lower, selectorEx0); - auto lower0 = __byte_perm(upper, lower, selectorEx1); - upper0 = - __shfl_sync(uint32_t(-1), upper0, upper_map[threadIdx.x % 4], 4); - lower0 = - __shfl_sync(uint32_t(-1), lower0, lower_map[threadIdx.x % 4], 4); - - uint32_t *data_32bit = reinterpret_cast(&data[n]); - data_32bit[0] = __byte_perm(upper0, lower0, selectorEx4); - data_32bit[1] = __byte_perm(upper0, lower0, selectorEx5); - n += 8; - } - } - } -}; - - -// Reshape Utility for converting the layout from accumulator of GEMM-I -// to Operand A of GEMM-II. -struct ReshapeTStoTP { - template - CUTLASS_DEVICE auto operator()(FragmentC &&tC, FragmentQ &&tQ) { - - // get the layout of one row of Q. - auto layoutQRow = make_layout_like(tQ(_, 0, _).layout()); - // get the layout of M dimension of C. - auto layoutCM = get<1>(tC.layout()); - return make_layout(get<0>(layoutQRow), layoutCM, get<1>(layoutQRow)); - } -}; template