From 51236c913107f2f6098ac039a4aaa4841a443c25 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 16 Dec 2024 05:08:53 -0800 Subject: [PATCH] perf: Dense and sparse customizable flashattention-3 template (#667) This PR adds flashattention-3 template for improving prefill performance on hopper. Block/Vector-sparse support in FlashInfer early version are ported to FA-3 template with CustomStride abstraction in CuTE so that we can support PageAttention with any page size. The programming interface for FA3 template is exactly the same as our previous FA2 template while we add an argument `backend` to allow user to select their own backend. Functionalities that are missing in current template include custom mask and we plan to support it using JIT instead of AOT. H100 Reference performance on variable-length dense and sparse attention kernels (exposed through [BatchPrefillWithRaggedKVCacheWrapper](https://docs.flashinfer.ai/api/prefill.html#flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper) and [BatchDecodeWithPagedKVCacheWrapper](https://docs.flashinfer.ai/api/decode.html#flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper) API correspondingly, for sparse attention workload, we use PageAttention with `page_size=1`. ![image](https://github.com/user-attachments/assets/7e989f8c-8b0f-4c99-ad11-6102c2dc5090) FlashInfer's vector sparse (page_size=1) attention implementation can get 90% percent of the dense equivalent, reference benchmark: https://github.com/flashinfer-ai/flashinfer/blob/04ee9bceb5ab0a66c612c1abaee8fa28de2b2349/benchmarks/bench_hopper_attention . JIT support is left to the next PR because this PR is already heavy. For fp8 support, we will incorporate SageAttention-2 algorithm for numerical stability, and it's left to v0.2.1. Currently there is some discrepancy in attention variant interface for our FA2 and FA3 template and we will gradually fix the gap. cc @merrymercy @zhyncs @youkaichao @woosukkwon @jason-huang03 --- LICENSE | 22 + .../generate_batch_paged_prefill_sm90_inst.py | 96 ++++ ...generate_batch_ragged_prefill_sm90_inst.py | 97 ++++ .../generate_single_prefill_sm90_inst.py | 85 +++ aot_build_utils/generate_sm90.py | 200 +++++++ benchmarks/bench_hopper_attention.py | 201 +++++++ csrc/aot_extension_utils.h | 18 +- csrc/batch_decode.cu | 4 +- csrc/batch_prefill.cu | 4 +- csrc/batch_prefill_sm90.cu | 281 +++++++++ csrc/dispatch_utils.h | 18 +- csrc/flashinfer_gemm_sm90_ops.cu | 27 - csrc/flashinfer_ops.cu | 10 + csrc/flashinfer_ops_sm90.cu | 65 +++ csrc/flashinfer_page_ops.cu | 10 + csrc/group_gemm.cu | 2 +- csrc/group_gemm_sm90.cu | 2 +- csrc/page.cu | 27 + csrc/rope.cu | 2 - csrc/single_decode.cu | 2 +- csrc/single_prefill.cu | 2 +- csrc/single_prefill_sm90.cu | 112 ++++ flashinfer/jit/__init__.py | 11 +- flashinfer/jit/attention.py | 102 ++++ flashinfer/jit/batch_prefill_sm90_templ.py | 19 + flashinfer/jit/single_prefill_sm90_templ.py | 19 + flashinfer/page.py | 38 ++ flashinfer/prefill.py | 534 ++++++++++++++++-- flashinfer/sparse.py | 149 ++++- flashinfer/utils.py | 92 +++ include/flashinfer/attention/heap.h | 64 +++ .../attention/hopper/attention_updater.cuh | 257 +++++++++ .../attention/hopper/block_sparse_gather.cuh | 196 +++++++ .../flashinfer/attention/hopper/epilogue.cuh | 259 +++++++++ .../attention/hopper/kernel_traits.cuh | 120 ++++ .../flashinfer/attention/hopper/mainloop.cuh | 266 +++++++++ .../attention/hopper/mainloop_mma.cuh | 265 +++++++++ .../attention/hopper/named_barrier.cuh | 112 ++++ .../flashinfer/attention/hopper/params.cuh | 154 +++++ .../attention/hopper/prefill_sm90.cuh | 524 +++++++++++++++++ .../attention/hopper/sparse_mainloop.cuh | 327 +++++++++++ .../attention/hopper/tile_scheduler.cuh | 196 +++++++ include/flashinfer/attention/hopper/utils.cuh | 165 ++++++ .../flashinfer/attention/hopper/variants.cuh | 63 +++ include/flashinfer/attention/scheduler.cuh | 193 ++++++- include/flashinfer/cutlass_utils.cuh | 18 +- include/flashinfer/page.cuh | 51 ++ licenses/LICENSE.cutlass.txt | 27 + licenses/LICENSE.flashattention3.txt | 29 + setup.py | 23 +- ...sparse_indices_to_vector_sparse_offsets.py | 84 +++ tests/test_hopper.py | 218 +++++++ 52 files changed, 5730 insertions(+), 132 deletions(-) create mode 100644 aot_build_utils/generate_batch_paged_prefill_sm90_inst.py create mode 100644 aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py create mode 100644 aot_build_utils/generate_single_prefill_sm90_inst.py create mode 100644 aot_build_utils/generate_sm90.py create mode 100644 benchmarks/bench_hopper_attention.py create mode 100644 csrc/batch_prefill_sm90.cu delete mode 100644 csrc/flashinfer_gemm_sm90_ops.cu create mode 100644 csrc/flashinfer_ops_sm90.cu create mode 100644 csrc/single_prefill_sm90.cu create mode 100644 flashinfer/jit/batch_prefill_sm90_templ.py create mode 100644 flashinfer/jit/single_prefill_sm90_templ.py create mode 100644 include/flashinfer/attention/heap.h create mode 100644 include/flashinfer/attention/hopper/attention_updater.cuh create mode 100644 include/flashinfer/attention/hopper/block_sparse_gather.cuh create mode 100644 include/flashinfer/attention/hopper/epilogue.cuh create mode 100644 include/flashinfer/attention/hopper/kernel_traits.cuh create mode 100644 include/flashinfer/attention/hopper/mainloop.cuh create mode 100644 include/flashinfer/attention/hopper/mainloop_mma.cuh create mode 100644 include/flashinfer/attention/hopper/named_barrier.cuh create mode 100644 include/flashinfer/attention/hopper/params.cuh create mode 100644 include/flashinfer/attention/hopper/prefill_sm90.cuh create mode 100644 include/flashinfer/attention/hopper/sparse_mainloop.cuh create mode 100644 include/flashinfer/attention/hopper/tile_scheduler.cuh create mode 100644 include/flashinfer/attention/hopper/utils.cuh create mode 100644 include/flashinfer/attention/hopper/variants.cuh create mode 100644 licenses/LICENSE.cutlass.txt create mode 100644 licenses/LICENSE.flashattention3.txt create mode 100644 tests/test_block_sparse_indices_to_vector_sparse_offsets.py create mode 100644 tests/test_hopper.py diff --git a/LICENSE b/LICENSE index 261eeb9e9..7c8f7e140 100644 --- a/LICENSE +++ b/LICENSE @@ -199,3 +199,25 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + +------------------------------------------------------------------------------------------------- +Some of the code in this project are adapted from other open-source projects with different +licenses. This product also bundles some third-party components under other open source licenses. +This section summarizes those components and their licenses. +See licenses/ for text of these licenses. + +BSD 3-Clause License +-------------------- + +include/flashinfer/attention/hopper/epilogue.cuh +include/flashinfer/attention/hopper/mainloop.cuh +include/flashinfer/attention/hopper/kernel_traits.cuh +include/flashinfer/attention/hopper/named_barrier.cuh +include/flashinfer/attention/hopper/tile_scheduler.cuh +include/flashinfer/attention/hopper/utils.cuh + +BSD 3-Clause "New" License +-------------------------- + +3rdparty/cutlass +include/flashinfer/attention/hopper/block_sparse_gather.cuh diff --git a/aot_build_utils/generate_batch_paged_prefill_sm90_inst.py b/aot_build_utils/generate_batch_paged_prefill_sm90_inst.py new file mode 100644 index 000000000..80310a745 --- /dev/null +++ b/aot_build_utils/generate_batch_paged_prefill_sm90_inst.py @@ -0,0 +1,96 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import re +import sys +from pathlib import Path + +from .literal_map import ( + dtype_literal, + idtype_literal, + mask_mode_literal, + pos_encoding_mode_literal, +) + + +def get_cu_file_str( + head_dim, + pos_encoding_mode, + allow_fp16_qk_reduction, + mask_mode, + dtype_q, + dtype_kv, + dtype_out, + idtype, +): + def get_insts(attention_variant): + return "\n".join( + [ + """template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, {attention_variant}>( + Params& params, + cudaStream_t stream); + +template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, {attention_variant}>( + Params& params, + cudaStream_t stream); + """.format( + head_dim=head_dim, + pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], + allow_fp16_qk_reduction=allow_fp16_qk_reduction, + mask_mode=mask_mode_literal[int(mask_mode)], + attention_variant=attention_variant, + ) + ] + ) + + dtype_q = dtype_literal[dtype_q] + dtype_kv = dtype_literal[dtype_kv] + dtype_out = dtype_literal[dtype_out] + idtype = idtype_literal[idtype] + + content = f"""#include +#include +#include + + +namespace flashinfer {{ + +using DTypeQ = cutlass_dtype_t<{dtype_q}>; +using DTypeKV = cutlass_dtype_t<{dtype_kv}>; +using DTypeO = cutlass_dtype_t<{dtype_out}>; + +using Params = BatchPrefillPagedParams; + +{get_insts("LogitsSoftCap")} + +{get_insts("StandardAttention")} + +}}""" + return content + + +if __name__ == "__main__": + pattern = ( + r"batch_paged_prefill_head_([0-9]+)_posenc_([0-9]+)_" + r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu" + ) + compiled_pattern = re.compile(pattern) + path = Path(sys.argv[1]) + fname = path.name + match = compiled_pattern.match(fname) + + with open(path, "w") as f: + f.write(get_cu_file_str(*match.groups())) diff --git a/aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py b/aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py new file mode 100644 index 000000000..e26a7389b --- /dev/null +++ b/aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py @@ -0,0 +1,97 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import re +import sys +from pathlib import Path + +from .literal_map import ( + dtype_literal, + idtype_literal, + mask_mode_literal, + pos_encoding_mode_literal, +) + + +def get_cu_file_str( + head_dim, + pos_encoding_mode, + allow_fp16_qk_reduction, + mask_mode, + dtype_q, + dtype_kv, + dtype_out, + idtype, +): + + def get_insts(attention_variant): + return "\n".join( + [ + """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, {attention_variant}>( + Params& params, + cudaStream_t stream); + +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, {attention_variant}>( + Params& params, + cudaStream_t stream); + """.format( + head_dim=head_dim, + pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], + allow_fp16_qk_reduction=allow_fp16_qk_reduction, + mask_mode=mask_mode_literal[int(mask_mode)], + attention_variant=attention_variant, + ) + ] + ) + + dtype_q = dtype_literal[dtype_q] + dtype_kv = dtype_literal[dtype_kv] + dtype_out = dtype_literal[dtype_out] + idtype = idtype_literal[idtype] + + content = f"""#include +#include +#include + + +namespace flashinfer {{ + +using DTypeQ = cutlass_dtype_t<{dtype_q}>; +using DTypeKV = cutlass_dtype_t<{dtype_kv}>; +using DTypeO = cutlass_dtype_t<{dtype_out}>; + +using Params = BatchPrefillRaggedParams; + +{get_insts("LogitsSoftCap")} + +{get_insts("StandardAttention")} + +}} + """ + return content + + +if __name__ == "__main__": + pattern = ( + r"batch_ragged_prefill_head_([0-9]+)_posenc_([0-9]+)_" + r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu" + ) + compiled_pattern = re.compile(pattern) + path = Path(sys.argv[1]) + fname = path.name + match = compiled_pattern.match(fname) + with open(path, "w") as f: + f.write(get_cu_file_str(*match.groups())) diff --git a/aot_build_utils/generate_single_prefill_sm90_inst.py b/aot_build_utils/generate_single_prefill_sm90_inst.py new file mode 100644 index 000000000..13e579994 --- /dev/null +++ b/aot_build_utils/generate_single_prefill_sm90_inst.py @@ -0,0 +1,85 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import re +import sys +from pathlib import Path + +from .literal_map import dtype_literal, mask_mode_literal, pos_encoding_mode_literal + + +def get_cu_file_str( + head_dim, + pos_encoding_mode, + allow_fp16_qk_reduction, + mask_mode, + dtype_q, + dtype_kv, + dtype_out, +): + content = """#include +#include +#include + +namespace flashinfer {{ + +using DTypeQ = cutlass_dtype_t<{dtype_q}>; +using DTypeKV = cutlass_dtype_t<{dtype_kv}>; +using DTypeO = cutlass_dtype_t<{dtype_out}>; + +using Params = SinglePrefillParams; + +template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, LogitsSoftCap>( + Params& params, + cudaStream_t stream); + +template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, LogitsSoftCap>( + Params& params, + cudaStream_t stream); + +template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, StandardAttention>( + Params& params, + cudaStream_t stream); + +template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, StandardAttention>( + Params& params, + cudaStream_t stream); +}} + """.format( + head_dim=head_dim, + pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], + allow_fp16_qk_reduction=allow_fp16_qk_reduction, + mask_mode=mask_mode_literal[int(mask_mode)], + dtype_q=dtype_literal[dtype_q], + dtype_kv=dtype_literal[dtype_kv], + dtype_out=dtype_literal[dtype_out], + use_custom_mask="true" if int(mask_mode) == 2 else "false", + ) + return content + + +if __name__ == "__main__": + pattern = ( + r"single_prefill_head_([0-9]+)_posenc_([0-9]+)_" + r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_sm90\.cu" + ) + + compiled_pattern = re.compile(pattern) + path = Path(sys.argv[1]) + fname = path.name + match = compiled_pattern.match(fname) + with open(path, "w") as f: + f.write(get_cu_file_str(*match.groups())) diff --git a/aot_build_utils/generate_sm90.py b/aot_build_utils/generate_sm90.py new file mode 100644 index 000000000..f87f34e53 --- /dev/null +++ b/aot_build_utils/generate_sm90.py @@ -0,0 +1,200 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import argparse +from itertools import product +from pathlib import Path +from typing import List + +from . import ( + generate_batch_paged_prefill_sm90_inst, + generate_batch_ragged_prefill_sm90_inst, + generate_single_prefill_sm90_inst, +) + + +def get_sm90_instantiation_cu(args: argparse.Namespace) -> List[str]: + def write_if_different(path: Path, content: str) -> None: + if path.exists() and path.read_text() == content: + return + path.write_text(content) + + path: Path = args.path + head_dims: List[int] = args.head_dims + pos_encoding_modes: List[int] = args.pos_encoding_modes + allow_fp16_qk_reductions: List[int] = args.allow_fp16_qk_reductions + mask_modes: List[int] = args.mask_modes + enable_bf16: bool = args.enable_bf16 + + path.mkdir(parents=True, exist_ok=True) + + idtypes = ["i32"] + prefill_dtypes = ["f16"] + decode_dtypes = ["f16"] + fp16_dtypes = ["f16"] + if enable_bf16: + prefill_dtypes.append("bf16") + decode_dtypes.append("bf16") + fp16_dtypes.append("bf16") + + # single prefill files + single_prefill_sm90_uris = [] + for ( + head_dim, + pos_encoding_mode, + allow_fp16_qk_reduction, + mask_mode, + ) in product( + head_dims, + pos_encoding_modes, + allow_fp16_qk_reductions, + mask_modes, + ): + for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)): + fname = f"single_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_sm90.cu" + content = generate_single_prefill_sm90_inst.get_cu_file_str( + head_dim, + pos_encoding_mode, + allow_fp16_qk_reduction, + mask_mode, + dtype_q, # dtype_q + dtype_kv, # dtype_kv + dtype_q, # dtype_out + ) + for use_sliding_window in [True, False]: + for use_logits_soft_cap in [True, False]: + if ( + mask_mode == 0 + ): # NOTE(Zihao): uri do not contain mask, avoid duplicate uris + single_prefill_sm90_uris.append( + f"single_prefill_with_kv_cache_dtype_q_{dtype_q}_" + f"dtype_kv_{dtype_kv}_" + f"dtype_o_{dtype_q}_" + f"head_dim_{head_dim}_" + f"posenc_{pos_encoding_mode}_" + f"use_swa_{use_sliding_window}_" + f"use_logits_cap_{use_logits_soft_cap}_" + f"f16qk_{bool(allow_fp16_qk_reduction)}_sm90" + ) + write_if_different(path / fname, content) + + # batch prefill files + batch_prefill_sm90_uris = [] + for ( + head_dim, + pos_encoding_mode, + allow_fp16_qk_reduction, + mask_mode, + idtype, + ) in product( + head_dims, + pos_encoding_modes, + allow_fp16_qk_reductions, + mask_modes, + idtypes, + ): + for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)): + fname = f"batch_paged_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu" + content = generate_batch_paged_prefill_sm90_inst.get_cu_file_str( + head_dim, + pos_encoding_mode, + allow_fp16_qk_reduction, + mask_mode, + dtype_q, # dtype_q + dtype_kv, # dtype_kv + dtype_q, # dtype_out + idtype, + ) + write_if_different(path / fname, content) + + fname = f"batch_ragged_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}_sm90.cu" + content = generate_batch_ragged_prefill_sm90_inst.get_cu_file_str( + head_dim, + pos_encoding_mode, + allow_fp16_qk_reduction, + mask_mode, + dtype_q, # dtype_q + dtype_kv, # dtype_kv + dtype_q, # dtype_out + idtype, + ) + write_if_different(path / fname, content) + + for sliding_window in [True, False]: + for logits_soft_cap in [True, False]: + if ( + mask_mode == 0 + ): # NOTE(Zihao): uri do not contain mask, avoid duplicate uris + batch_prefill_sm90_uris.append( + f"batch_prefill_with_kv_cache_dtype_q_{dtype_q}_" + f"dtype_kv_{dtype_kv}_" + f"dtype_o_{dtype_q}_" + f"dtype_idx_{idtype}_" + f"head_dim_{head_dim}_" + f"posenc_{pos_encoding_mode}_" + f"use_swa_{sliding_window}_" + f"use_logits_cap_{logits_soft_cap}_" + f"f16qk_{bool(allow_fp16_qk_reduction)}_sm90" + ) + + return single_prefill_sm90_uris + batch_prefill_sm90_uris + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Generate cuda files") + parser.add_argument( + "--path", type=Path, required=True, help="Path to the dispatch inc file" + ) + parser.add_argument( + "--head_dims", type=int, required=True, nargs="+", help="Head dimensions" + ) + parser.add_argument( + "--pos_encoding_modes", + type=int, + required=True, + nargs="+", + help="Position encoding modes", + ) + parser.add_argument( + "--allow_fp16_qk_reductions", + type=lambda x: x if isinstance(x, int) else int(x.lower() == "true"), + required=True, + nargs="+", + help="Allow fp16 qk reductions", + ) + parser.add_argument( + "--mask_modes", + type=int, + required=True, + nargs="+", + help="Mask modes", + ) + parser.add_argument( + "--enable_bf16", + type=lambda x: x if isinstance(x, int) else x.lower() == "true", + required=True, + nargs="+", + help="Enable bf16", + ) + parser.add_argument( + "--enable_fp8", + type=lambda x: x if isinstance(x, int) else x.lower() == "true", + default=True, + nargs="+", + help="Enable fp8", + ) + args = parser.parse_args() + get_sm90_instantiation_cu(args) diff --git a/benchmarks/bench_hopper_attention.py b/benchmarks/bench_hopper_attention.py new file mode 100644 index 000000000..f5bcc19ea --- /dev/null +++ b/benchmarks/bench_hopper_attention.py @@ -0,0 +1,201 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import triton + +import flashinfer + + +def bench_single_prefill(seq_len, num_heads, causal, head_dim): + num_qo_heads = num_kv_heads = num_heads + q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda") + k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") + v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") + + sm80_ms, sm90_ms = ( + triton.testing.do_bench( + lambda: flashinfer.single_prefill_with_kv_cache_return_lse( + q, k, v, causal=causal, backend=backend + ), + warmup=100, + rep=1000, + ) + for backend in ["fa2", "fa3"] + ) + + def flops(ms): + if causal: + return seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9 + else: + return seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9 + + print( + f"bench_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s" + ) + + +def bench_batch_ragged_prefill(batch_size, num_heads, seq_len, causal, head_dim): + num_qo_heads = num_kv_heads = num_heads + q = torch.randn( + batch_size * seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda" + ) + k = torch.randn( + batch_size * seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + ) + v = torch.randn( + batch_size * seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + ) + + sm80_wrapper, sm90_wrapper = ( + flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda:0"), + kv_layout="NHD", + backend=backend, + ) + for backend in ["fa2", "fa3"] + ) + + qo_indptr = torch.arange(0, batch_size * seq_len + 1, seq_len).int() + kv_indptr = torch.arange(0, batch_size * seq_len + 1, seq_len).int() + + for wrapper in [sm80_wrapper, sm90_wrapper]: + wrapper.plan( + qo_indptr, + kv_indptr, + num_qo_heads, + num_kv_heads, + head_dim, + causal=causal, + ) + + sm80_ms, sm90_ms = ( + triton.testing.do_bench( + lambda: wrapper.run(q, k, v), + warmup=100, + rep=1000, + ) + for wrapper in [sm80_wrapper, sm90_wrapper] + ) + + def flops(ms): + if causal: + return ( + batch_size * seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9 + ) + else: + return ( + batch_size * seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9 + ) + + print( + f"bench_batch_ragged_prefill (batch_size={batch_size}, num_heads={num_heads}, seq_len={seq_len}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s" + ) + + +def bench_batch_paged_prefill( + page_size, batch_size, num_heads, seq_len, causal, head_dim +): + num_qo_heads = num_kv_heads = num_heads + q = torch.randn( + batch_size * seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda" + ) + k = torch.randn( + batch_size * seq_len // page_size, + page_size, + num_kv_heads, + head_dim, + dtype=torch.half, + device="cuda", + ) + v = torch.randn( + batch_size * seq_len // page_size, + page_size, + num_kv_heads, + head_dim, + dtype=torch.half, + device="cuda", + ) + + sm80_wrapper, sm90_wrapper = ( + flashinfer.BatchPrefillWithPagedKVCacheWrapper( + torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda:0"), + kv_layout="NHD", + backend=backend, + ) + for backend in ["fa2", "fa3"] + ) + + qo_indptr = torch.arange(0, batch_size * seq_len + 1, seq_len).int() + kv_indptr = torch.arange( + 0, batch_size * (seq_len // page_size) + 1, (seq_len // page_size) + ).int() + kv_indices = torch.arange(0, batch_size * (seq_len // page_size)).int() + last_page_len = torch.ones(batch_size, dtype=torch.int32) * page_size + + for wrapper in [sm80_wrapper, sm90_wrapper]: + wrapper.plan( + qo_indptr, + kv_indptr, + kv_indices, + last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, # page_size + causal=causal, + ) + + sm80_ms, sm90_ms = ( + triton.testing.do_bench( + lambda: wrapper.run(q, (k, v)), + warmup=100, + rep=1000, + ) + for wrapper in [sm80_wrapper, sm90_wrapper] + ) + + def flops(ms): + if causal: + return ( + batch_size * seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9 + ) + else: + return ( + batch_size * seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9 + ) + + print( + f"bench_batch_paged_prefill (page_size={page_size} batch_size={batch_size}, num_heads={num_heads}, seq_len={seq_len}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s" + ) + + +if __name__ == "__main__": + bench_batch_paged_prefill(1, 128, 32, 1024, True, 128) + bench_batch_paged_prefill(1, 64, 32, 2048, True, 128) + bench_batch_paged_prefill(1, 32, 32, 4096, True, 128) + bench_batch_paged_prefill(1, 16, 32, 8192, True, 128) + bench_batch_paged_prefill(1, 1, 32, 32768, True, 128) + bench_batch_paged_prefill(16, 128, 32, 1024, True, 128) + bench_batch_paged_prefill(16, 64, 32, 2048, True, 128) + bench_batch_paged_prefill(16, 32, 32, 4096, True, 128) + bench_batch_paged_prefill(16, 16, 32, 8192, True, 128) + bench_batch_paged_prefill(16, 1, 32, 32768, True, 128) + bench_batch_ragged_prefill(128, 32, 1024, True, 128) + bench_batch_ragged_prefill(64, 32, 2048, True, 128) + bench_batch_ragged_prefill(32, 32, 4096, True, 128) + bench_batch_ragged_prefill(16, 32, 8192, True, 128) + bench_batch_ragged_prefill(1, 32, 32768, True, 128) diff --git a/csrc/aot_extension_utils.h b/csrc/aot_extension_utils.h index 76db0168d..b701c2898 100644 --- a/csrc/aot_extension_utils.h +++ b/csrc/aot_extension_utils.h @@ -30,15 +30,15 @@ #define DISPATCH_mask_mode(expr, const_expr, ...) \ _DISPATCH_SWITCH("mask_mode", expr, _DISPATCH_CASES_mask_mode(const_expr, __VA_ARGS__)) -#define DISPATCH_LOGITS_SOFT_CAP(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, ...) \ - [&]() -> bool { \ - if (use_logits_soft_cap) { \ - constexpr bool USE_LOGITS_SOFT_CAP = true; \ - return __VA_ARGS__(); \ - } else { \ - constexpr bool USE_LOGITS_SOFT_CAP = false; \ - return __VA_ARGS__(); \ - } \ +#define DISPATCH_BOOL(expr, const_expr, ...) \ + [&]() -> bool { \ + if (expr) { \ + constexpr bool const_expr = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool const_expr = false; \ + return __VA_ARGS__(); \ + } \ }() #define DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_dtype, kv_dtype, c_type_q, c_type_kv, ...) \ diff --git a/csrc/batch_decode.cu b/csrc/batch_decode.cu index 19cea2c8b..95b79b22d 100644 --- a/csrc/batch_decode.cu +++ b/csrc/batch_decode.cu @@ -56,7 +56,7 @@ std::vector BatchDecodeWithPagedKVCachePlan( using DTypeKV = kv_type; using DTypeO = DTypeQ; return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_LOGITS_SOFT_CAP(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] { + return DISPATCH_BOOL(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] { using ParamsT = BatchDecodeParams; using AttentionVariant = ComposedAttention 0, USE_LOGITS_SOFT_CAP, [&] { + return DISPATCH_BOOL(logits_soft_cap > 0, USE_LOGITS_SOFT_CAP, [&] { using ParamsT = BatchDecodeParams; using AttentionVariant = ComposedAttention; using RaggedAttentionVariant = ComposedAttention paged_kv( num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout, static_cast(paged_k_cache.data_ptr()), diff --git a/csrc/batch_prefill_sm90.cu b/csrc/batch_prefill_sm90.cu new file mode 100644 index 000000000..0d358e908 --- /dev/null +++ b/csrc/batch_prefill_sm90.cu @@ -0,0 +1,281 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "aot_extension_utils.h" + +namespace flashinfer { + +template +cudaError_t BatchPrefillWithRaggedKVCacheDispatched( + BatchPrefillRaggedParams& params, cudaStream_t stream); + +template +cudaError_t BatchPrefillWithPagedKVCacheDispatched( + BatchPrefillPagedParams& params, cudaStream_t stream); + +} // namespace flashinfer + +using namespace flashinfer; + +std::vector BatchPrefillWithKVCacheSM90Plan( + unsigned int head_dim, bool causal, at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer, + at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int batch_size, + unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, + bool enable_cuda_graph, int64_t cuda_stream) { + size_t float_workspace_size_in_bytes = + float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); + + PrefillPlanSM90Info plan_info; + + using IdType = int32_t; + + cudaStream_t stream = reinterpret_cast(cuda_stream); + + cudaError_t status = PrefillSM90Plan( + float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, + int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), + int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr(), + kv_indptr.data_ptr(), kv_len_arr.data_ptr(), batch_size, num_qo_heads, + num_kv_heads, head_dim, page_size, causal, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream); + + TORCH_CHECK(status == cudaSuccess, + "PrefillSM90Plan failed with error: ", cudaGetErrorString(status)); + + return plan_info.ToVector(); +} + +void BatchPrefillWithRaggedKVCacheSM90Run( + unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + std::vector plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v, + std::optional maybe_custom_mask, std::optional maybe_alibi_slopes, + at::Tensor qo_indptr, at::Tensor kv_indptr, std::optional maybe_qk_indptr, + at::Tensor o, unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, std::optional maybe_lse, int64_t cuda_stream) { + PrefillPlanSM90Info plan_info; + plan_info.FromVector(plan_info_vec); + + if (maybe_lse) { + const auto& lse = *maybe_lse; + TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); + TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); + } + + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); + + unsigned int head_dim = q.size(2); + + auto q_scalar_type = q.scalar_type(); + + QKVLayout kv_layout = static_cast(layout); + cudaStream_t stream = reinterpret_cast(cuda_stream); + const MaskMode mask_mode = static_cast(mask_mode_code); + bool use_logits_soft_cap = logits_soft_cap > 0.f; + bool use_swa = window_left != -1; + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, qkv_type, [&] { + using DTypeQ = cutlass_dtype_t; + using DTypeKV = DTypeQ; + using DTypeO = DTypeQ; + using IdType = int32_t; + + BatchPrefillRaggedParams params; + + params.q_ptr = static_cast(q.data_ptr()); + params.k_ptr = static_cast(k.data_ptr()); + params.v_ptr = static_cast(v.data_ptr()); + params.o_ptr = static_cast(o.data_ptr()); + params.lse_ptr = maybe_lse ? static_cast(maybe_lse->data_ptr()) : nullptr; + params.q_stride_n = q.stride(0); + params.q_stride_h = q.stride(1); + params.o_stride_n = o.stride(0); + params.o_stride_h = o.stride(1); + if (kv_layout == QKVLayout::kNHD) { + params.k_stride_n = k.stride(0); + params.k_stride_h = k.stride(1); + params.v_stride_n = v.stride(0); + params.v_stride_h = v.stride(1); + } else { + params.k_stride_h = k.stride(0); + params.k_stride_n = k.stride(1); + params.v_stride_h = v.stride(0); + params.v_stride_n = v.stride(1); + } + params.nnz_qo = q.size(0); + params.nnz_kv = k.size(0); + params.head_dim = head_dim; + params.num_qo_heads = q.size(1); + params.num_kv_heads = k.size(1); + params.group_size = params.num_qo_heads / params.num_kv_heads; + params.window_left = window_left; + params.logits_soft_cap = logits_soft_cap; + params.sm_scale_log2 = sm_scale * math::log2e; + params.causal = mask_mode_code == 1; + params.qo_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); + params.qo_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_indptr_offset); + params.kv_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_indptr_offset); + params.qo_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_len_offset); + params.kv_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_len_offset); + params.head_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.head_indices_offset); + params.work_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset); + + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { + return DISPATCH_BOOL(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] { + return DISPATCH_BOOL(use_swa, USE_SWA, [&] { + using AttentionVariant = + std::conditional_t; + cudaError_t status = + BatchPrefillWithRaggedKVCacheDispatched(params, stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithRaggedKVCacheSM90Run failed with error: ", + cudaGetErrorString(status)); + return true; + }); + }); + }); + }); + }); +} + +void BatchPrefillWithPagedKVCacheSM90Run( + unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + std::vector plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, + at::Tensor paged_v_cache, std::optional maybe_custom_mask, + std::optional maybe_alibi_slopes, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, + at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, + std::optional maybe_qk_indptr, at::Tensor o, unsigned int layout, + int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + std::optional maybe_lse, int64_t cuda_stream) { + PrefillPlanSM90Info plan_info; + plan_info.FromVector(plan_info_vec); + + if (maybe_lse) { + const auto& lse = *maybe_lse; + TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); + TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); + } + QKVLayout kv_layout = static_cast(layout); + unsigned int num_kv_heads, page_size; + unsigned int head_dim = q.size(2); + if (kv_layout == QKVLayout::kHND) { + num_kv_heads = paged_k_cache.size(1); + page_size = paged_k_cache.size(2); + } else { + page_size = paged_k_cache.size(1); + num_kv_heads = paged_k_cache.size(2); + } + + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); + + auto q_scalar_type = q.scalar_type(); + + cudaStream_t stream = reinterpret_cast(cuda_stream); + const MaskMode mask_mode = static_cast(mask_mode_code); + bool use_logits_soft_cap = logits_soft_cap > 0.f; + bool use_swa = window_left != -1; + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, qkv_type, [&] { + using DTypeQ = cutlass_dtype_t; + using DTypeKV = DTypeQ; + using DTypeO = DTypeQ; + using IdType = int32_t; + + BatchPrefillPagedParams params; + + params.q_ptr = static_cast(q.data_ptr()); + params.k_ptr = static_cast(paged_k_cache.data_ptr()); + params.v_ptr = static_cast(paged_v_cache.data_ptr()); + params.o_ptr = static_cast(o.data_ptr()); + params.lse_ptr = maybe_lse ? static_cast(maybe_lse->data_ptr()) : nullptr; + params.q_stride_n = q.stride(0); + params.q_stride_h = q.stride(1); + params.o_stride_n = o.stride(0); + params.o_stride_h = o.stride(1); + if (kv_layout == QKVLayout::kNHD) { + // (num_pages, page_size, num_heads, head_dim) + params.k_stride_n = paged_k_cache.stride(1); + params.k_stride_h = paged_k_cache.stride(2); + params.v_stride_n = paged_v_cache.stride(1); + params.v_stride_h = paged_v_cache.stride(2); + } else { + // (num_pages, num_heads, page_size, head_dim) + params.k_stride_h = paged_k_cache.stride(1); + params.k_stride_n = paged_k_cache.stride(2); + params.v_stride_h = paged_v_cache.stride(1); + params.v_stride_n = paged_v_cache.stride(2); + } + params.nnz_qo = q.size(0); + params.head_dim = head_dim; + params.num_qo_heads = q.size(1); + params.num_kv_heads = num_kv_heads; + params.group_size = params.num_qo_heads / num_kv_heads; + params.page_size = page_size; + params.window_left = window_left; + params.logits_soft_cap = logits_soft_cap; + params.sm_scale_log2 = sm_scale * math::log2e; + params.causal = mask_mode_code == 1; + params.qo_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); + params.qo_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_indptr_offset); + params.kv_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_indptr_offset); + params.qo_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_len_offset); + params.kv_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_len_offset); + params.head_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.head_indices_offset); + params.work_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset); + params.kv_indices = static_cast(paged_kv_indices.data_ptr()); + + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { + return DISPATCH_BOOL(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] { + return DISPATCH_BOOL(use_swa, USE_SWA, [&] { + using AttentionVariant = + std::conditional_t; + cudaError_t status = + BatchPrefillWithPagedKVCacheDispatched(params, stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithPagedKVCacheSM90Run failed with error: ", + cudaGetErrorString(status)); + return true; + }); + }); + }); + }); + }); +} diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index b34253e23..7885cd728 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -59,13 +59,13 @@ using namespace flashinfer; #define DISPATCH_mask_mode(expr, const_expr, ...) \ _DISPATCH_SWITCH("mask_mode", expr, _DISPATCH_CASES_mask_mode(const_expr, __VA_ARGS__)) -#define DISPATCH_LOGITS_SOFT_CAP(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, ...) \ - [&]() -> bool { \ - if (use_logits_soft_cap) { \ - constexpr bool USE_LOGITS_SOFT_CAP = true; \ - return __VA_ARGS__(); \ - } else { \ - constexpr bool USE_LOGITS_SOFT_CAP = false; \ - return __VA_ARGS__(); \ - } \ +#define DISPATCH_BOOL(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, ...) \ + [&]() -> bool { \ + if (use_logits_soft_cap) { \ + constexpr bool USE_LOGITS_SOFT_CAP = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool USE_LOGITS_SOFT_CAP = false; \ + return __VA_ARGS__(); \ + } \ }() diff --git a/csrc/flashinfer_gemm_sm90_ops.cu b/csrc/flashinfer_gemm_sm90_ops.cu deleted file mode 100644 index b6802e424..000000000 --- a/csrc/flashinfer_gemm_sm90_ops.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pytorch_extension_utils.h" - -void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - at::Tensor all_problems, at::Tensor x_ptr, at::Tensor w_ptr, - at::Tensor y_ptr, at::Tensor x_stride, at::Tensor weight_stride, - at::Tensor y_stride, at::Tensor empty_x_data, bool weight_column_major, - int64_t cuda_stream); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, - "Cutlass Segment GEMM operator for SM90"); -} diff --git a/csrc/flashinfer_ops.cu b/csrc/flashinfer_ops.cu index e6676e3a7..3d84bfc4e 100644 --- a/csrc/flashinfer_ops.cu +++ b/csrc/flashinfer_ops.cu @@ -86,6 +86,14 @@ void append_paged_kv_cache(at::Tensor append_key, at::Tensor append_value, at::T at::Tensor kv_indices, at::Tensor kv_indptr, at::Tensor kv_last_page_len, unsigned int layout, int64_t cuda_stream); +void block_sparse_indices_to_vector_sparse_offsets(at::Tensor block_sparse_indices, + at::Tensor block_sparse_indptr, + at::Tensor vector_sparse_offsets, + at::Tensor vector_sparse_indptr, + at::Tensor kv_len_arr, unsigned int stride_block, + unsigned int stride_n, unsigned int batch_size, + unsigned int block_size, int64_t cuda_stream); + //========== prefill ========== void single_prefill_with_kv_cache(unsigned int mask_mode_code, at::Tensor q, at::Tensor k, @@ -226,6 +234,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // page m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator"); + m.def("block_sparse_indices_to_vector_sparse_offsets", + &block_sparse_indices_to_vector_sparse_offsets, "Precompute block sparse offsets"); // prefill m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache, diff --git a/csrc/flashinfer_ops_sm90.cu b/csrc/flashinfer_ops_sm90.cu new file mode 100644 index 000000000..cc3ac8695 --- /dev/null +++ b/csrc/flashinfer_ops_sm90.cu @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "aot_extension_utils.h" + +void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor all_problems, at::Tensor x_ptr, at::Tensor w_ptr, + at::Tensor y_ptr, at::Tensor x_stride, at::Tensor weight_stride, + at::Tensor y_stride, at::Tensor empty_x_data, bool weight_column_major, + int64_t cuda_stream); + +void single_prefill_with_kv_cache_sm90(unsigned int mask_mode_code, at::Tensor q, at::Tensor k, + at::Tensor v, + std::optional maybe_packed_custom_mask, + std::optional maybe_alibi_slopes, at::Tensor o, + unsigned int layout, int32_t window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta, std::optional maybe_lse, + int64_t cuda_stream); + +std::vector BatchPrefillWithKVCacheSM90Plan( + unsigned int head_dim, bool causal, at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer, + at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int batch_size, + unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, + bool enable_cuda_graph, int64_t cuda_stream); + +void BatchPrefillWithRaggedKVCacheSM90Run( + unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + std::vector plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v, + std::optional maybe_custom_mask, std::optional maybe_alibi_slopes, + at::Tensor qo_indptr, at::Tensor kv_indptr, std::optional maybe_qk_indptr, + at::Tensor o, unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, std::optional maybe_lse, int64_t cuda_stream); + +void BatchPrefillWithPagedKVCacheSM90Run( + unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + std::vector plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, + at::Tensor paged_v_cache, std::optional maybe_custom_mask, + std::optional maybe_alibi_slopes, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, + at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, + std::optional maybe_qk_indptr, at::Tensor o, unsigned int layout, + int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + std::optional maybe_lse, int64_t cuda_stream); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, + "Cutlass Segment GEMM operator for SM90"); + m.def("single_prefill_with_kv_cache_sm90", &single_prefill_with_kv_cache_sm90); + m.def("batch_prefill_with_kv_cache_sm90_plan", &BatchPrefillWithKVCacheSM90Plan); + m.def("batch_prefill_with_ragged_kv_cache_sm90_run", &BatchPrefillWithRaggedKVCacheSM90Run); + m.def("batch_prefill_with_paged_kv_cache_sm90_run", &BatchPrefillWithPagedKVCacheSM90Run); +} diff --git a/csrc/flashinfer_page_ops.cu b/csrc/flashinfer_page_ops.cu index d78d4ac00..e365eb629 100644 --- a/csrc/flashinfer_page_ops.cu +++ b/csrc/flashinfer_page_ops.cu @@ -20,6 +20,16 @@ void append_paged_kv_cache(at::Tensor append_key, at::Tensor append_value, at::T at::Tensor kv_indices, at::Tensor kv_indptr, at::Tensor kv_last_page_len, unsigned int layout, int64_t cuda_stream); +void block_sparse_indices_to_vector_sparse_offsets(at::Tensor block_sparse_indices, + at::Tensor block_sparse_indptr, + at::Tensor vector_sparse_offsets, + at::Tensor vector_sparse_indptr, + at::Tensor kv_len_arr, unsigned int stride_block, + unsigned int stride_n, unsigned int batch_size, + unsigned int block_size, int64_t cuda_stream); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator"); + m.def("block_sparse_indices_to_vector_sparse_offsets", + &block_sparse_indices_to_vector_sparse_offsets, "Precompute block sparse offsets"); } diff --git a/csrc/group_gemm.cu b/csrc/group_gemm.cu index 78779fe5d..8fae8b9bd 100644 --- a/csrc/group_gemm.cu +++ b/csrc/group_gemm.cu @@ -28,7 +28,7 @@ void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_x_data.scalar_type(), c_type, [&] { - using cutlass_t = typename cutlass_dtype::value; + using cutlass_t = cutlass_dtype_t; auto status = CutlassSegmentGEMMRun( workspace_buffer.data_ptr(), workspace_buffer.element_size() * workspace_buffer.size(0), all_problems.data_ptr(), batch_size, x_ptr.data_ptr(), w_ptr.data_ptr(), y_ptr.data_ptr(), diff --git a/csrc/group_gemm_sm90.cu b/csrc/group_gemm_sm90.cu index 3710cf2fd..5341a204d 100644 --- a/csrc/group_gemm_sm90.cu +++ b/csrc/group_gemm_sm90.cu @@ -30,7 +30,7 @@ void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_wo cudaStream_t stream = reinterpret_cast(cuda_stream); DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_x_data.scalar_type(), c_type, [&] { - using cutlass_t = typename cutlass_dtype::value; + using cutlass_t = cutlass_dtype_t; auto status = CutlassSegmentGEMMSM90Run( float_workspace_buffer.data_ptr(), float_workspace_buffer.element_size() * float_workspace_buffer.size(0), diff --git a/csrc/page.cu b/csrc/page.cu index 644a7dc62..db6841944 100644 --- a/csrc/page.cu +++ b/csrc/page.cu @@ -110,3 +110,30 @@ void append_paged_kv_cache(at::Tensor append_key, at::Tensor append_value, at::T TORCH_CHECK(success, "AppendPagedKVCache failed to dispatch with dtype ", kv_scalar_dtype); } + +void block_sparse_indices_to_vector_sparse_offsets(at::Tensor block_sparse_indices, + at::Tensor block_sparse_indptr, + at::Tensor vector_sparse_offsets, + at::Tensor vector_sparse_indptr, + at::Tensor kv_len_arr, unsigned int stride_block, + unsigned int stride_n, unsigned int batch_size, + unsigned int block_size, int64_t cuda_stream) { + CHECK_INPUT(block_sparse_indices); + CHECK_INPUT(block_sparse_indptr); + CHECK_INPUT(vector_sparse_offsets); + CHECK_INPUT(vector_sparse_indptr); + CHECK_INPUT(kv_len_arr); + + cudaStream_t stream = reinterpret_cast(cuda_stream); + + cudaError_t status = BlockSparseIndicesToVectorSparseOffset( + static_cast(block_sparse_indices.data_ptr()), + static_cast(block_sparse_indptr.data_ptr()), + static_cast(vector_sparse_offsets.data_ptr()), + static_cast(vector_sparse_indptr.data_ptr()), + static_cast(kv_len_arr.data_ptr()), stride_block, stride_n, batch_size, block_size, + stream); + + TORCH_CHECK(status == cudaSuccess, "BlockSparseIndicesToVectorSparseOffset failed with error: ", + cudaGetErrorString(status)); +} diff --git a/csrc/rope.cu b/csrc/rope.cu index a1018dbca..3f10357bf 100644 --- a/csrc/rope.cu +++ b/csrc/rope.cu @@ -139,7 +139,6 @@ void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_r size_t k_rope_stride_h = k_rope.stride(1); cudaStream_t stream = reinterpret_cast(cuda_stream); - cudaStream_t torch_current_stream(nullptr); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache( static_cast(q.data_ptr()), static_cast(k.data_ptr()), @@ -231,7 +230,6 @@ void apply_llama31_rope_pos_ids(at::Tensor q, at::Tensor k, at::Tensor q_rope, a size_t k_rope_stride_h = k_rope.stride(1); cudaStream_t stream = reinterpret_cast(cuda_stream); - cudaStream_t torch_current_stream(nullptr); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { cudaError_t status = BatchQKApplyLlama31RotaryPosIds( static_cast(q.data_ptr()), static_cast(k.data_ptr()), diff --git a/csrc/single_decode.cu b/csrc/single_decode.cu index 60f9114bb..60a2bd765 100644 --- a/csrc/single_decode.cu +++ b/csrc/single_decode.cu @@ -76,7 +76,7 @@ void single_decode_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::T using DTypeKV = kv_type; using DTypeO = DTypeQ; return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_LOGITS_SOFT_CAP(logits_soft_cap > 0, USE_LOGITS_SOFT_CAP, [&] { + return DISPATCH_BOOL(logits_soft_cap > 0, USE_LOGITS_SOFT_CAP, [&] { using ParamsT = SingleDecodeParams; using AttentionVariant = ComposedAttention; using AttentionVariant = ComposedAttention + +#include +#include +#include +#include +#include +#include +#include + +#include "aot_extension_utils.h" + +namespace flashinfer { + +template +cudaError_t SinglePrefillWithKVCacheDispatched(SinglePrefillParams& params, + cudaStream_t stream); + +} // namespace flashinfer + +using namespace flashinfer; + +void single_prefill_with_kv_cache_sm90(unsigned int mask_mode_code, at::Tensor q, at::Tensor k, + at::Tensor v, + std::optional maybe_packed_custom_mask, + std::optional maybe_alibi_slopes, at::Tensor o, + unsigned int layout, int32_t window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta, std::optional maybe_lse, + int64_t cuda_stream) { + unsigned int head_dim = q.size(2); + unsigned int num_qo_heads = q.size(1); + unsigned int qo_len = q.size(0); + + auto q_scalar_type = q.scalar_type(); + + QKVLayout kv_layout = static_cast(layout); + cudaStream_t stream = reinterpret_cast(cuda_stream); + const MaskMode mask_mode = static_cast(mask_mode_code); + bool use_logits_soft_cap = logits_soft_cap > 0.0f; + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, q_type, [&] { + using DTypeQ = cutlass_dtype_t; + using DTypeKV = DTypeQ; + using DTypeO = DTypeQ; + SinglePrefillParams params; + params.q_ptr = static_cast(q.data_ptr()); + params.k_ptr = static_cast(k.data_ptr()); + params.v_ptr = static_cast(v.data_ptr()); + params.o_ptr = static_cast(o.data_ptr()); + params.lse_ptr = maybe_lse ? (static_cast(maybe_lse->data_ptr())) : nullptr; + params.q_stride_n = q.stride(0); + params.q_stride_h = q.stride(1); + params.o_stride_n = o.stride(0); + params.o_stride_h = o.stride(1); + if (kv_layout == QKVLayout::kNHD) { + params.k_stride_n = k.stride(0); + params.k_stride_h = k.stride(1); + params.v_stride_n = v.stride(0); + params.v_stride_h = v.stride(1); + } else { + params.k_stride_h = k.stride(0); + params.k_stride_n = k.stride(1); + params.v_stride_h = v.stride(0); + params.v_stride_n = v.stride(1); + } + params.qo_len = q.size(0); + params.kv_len = k.size(0); + params.head_dim = head_dim; + params.num_qo_heads = q.size(1); + params.num_kv_heads = k.size(1); + params.causal = mask_mode == MaskMode::kCausal; + params.group_size = params.num_qo_heads / params.num_kv_heads; + params.window_left = window_left; + params.logits_soft_cap = logits_soft_cap; + params.sm_scale_log2 = sm_scale * math::log2e; + bool use_swa = window_left != -1; + return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_BOOL(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] { + return DISPATCH_BOOL(use_swa, USE_SWA, [&] { + using AttentionVariant = + std::conditional_t; + cudaError_t status = + SinglePrefillWithKVCacheDispatched( + params, stream); + TORCH_CHECK(status == cudaSuccess, + "single_prefill_with_kv_cache_sm90 failed with error: " + + std::string(cudaGetErrorString(status))); + return true; + }); + }); + }); + }); + }); +} diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index 49265c3b6..6c418366a 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -17,9 +17,11 @@ # Re-export from .activation import gen_act_and_mul_module as gen_act_and_mul_module from .activation import get_act_and_mul_cu_str as get_act_and_mul_cu_str +from .aot_config import prebuilt_ops_uri as prebuilt_ops_uri from .attention import gen_batch_decode_mla_module as gen_batch_decode_mla_module from .attention import gen_batch_decode_module as gen_batch_decode_module from .attention import gen_batch_prefill_module as gen_batch_prefill_module +from .attention import gen_batch_prefill_sm90_module as gen_batch_prefill_sm90_module from .attention import ( gen_customize_single_decode_module as gen_customize_single_decode_module, ) @@ -28,20 +30,21 @@ ) from .attention import gen_single_decode_module as gen_single_decode_module from .attention import gen_single_prefill_module as gen_single_prefill_module +from .attention import gen_single_prefill_sm90_module as gen_single_prefill_sm90_module from .attention import get_batch_decode_mla_uri as get_batch_decode_mla_uri from .attention import get_batch_decode_uri as get_batch_decode_uri +from .attention import get_batch_prefill_sm90_uri as get_batch_prefill_sm90_uri from .attention import get_batch_prefill_uri as get_batch_prefill_uri from .attention import get_single_decode_uri as get_single_decode_uri +from .attention import get_single_prefill_sm90_uri as get_single_prefill_sm90_uri from .attention import get_single_prefill_uri as get_single_prefill_uri from .core import clear_cache_dir, load_cuda_ops from .env import * from .utils import parallel_load_modules as parallel_load_modules -from .aot_config import prebuilt_ops_uri as prebuilt_ops_uri - try: - from .. import _kernels - from .. import _kernels_sm90 + from .. import _kernels, _kernels_sm90 + has_prebuilt_ops = True except ImportError: has_prebuilt_ops = False diff --git a/flashinfer/jit/attention.py b/flashinfer/jit/attention.py index 72c2fca26..ba24f7c19 100644 --- a/flashinfer/jit/attention.py +++ b/flashinfer/jit/attention.py @@ -23,6 +23,10 @@ from .batch_decode_mla_templ import batch_decode_mla_suffix, batch_decode_mla_templ from .batch_decode_templ import batch_decode_suffix, batch_decode_templ +from .batch_prefill_sm90_templ import ( + batch_prefill_sm90_suffix, + batch_prefill_sm90_templ, +) from .batch_prefill_templ import batch_prefill_suffix, batch_prefill_templ from .core import load_cuda_ops from .env import FLASHINFER_GEN_SRC_DIR @@ -31,6 +35,10 @@ single_decode_suffix, single_decode_templ, ) +from .single_prefill_sm90_templ import ( + single_prefill_sm90_suffix, + single_prefill_sm90_templ, +) from .single_prefill_templ import ( customizable_single_prefill_templ, single_prefill_suffix, @@ -247,6 +255,35 @@ def get_single_prefill_sources( ) +def get_single_prefill_sm90_sources( + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + head_dim: int, + pos_encoding_mode: int, + use_sliding_window: bool, + use_logits_soft_cap: bool, + use_fp16_qk_reduction: bool, +) -> List[str]: + assert not use_fp16_qk_reduction, "fp16 qk reduction is not supported on sm90" + assert ( + pos_encoding_mode == 0 + ), "Currently we only support pos_encoding_mode=0 on sm90" + return render_templates( + single_prefill_sm90_templ, + { + "dtype_q": dtype_map[dtype_q], + "dtype_kv": dtype_map[dtype_kv], + "dtype_o": dtype_map[dtype_o], + "head_dim": head_dim, + "pos_encoding_mode": pos_encoding_mode_literal[pos_encoding_mode], + "use_sliding_window": "true" if use_sliding_window else "false", + "use_logits_soft_cap": "true" if use_logits_soft_cap else "false", + "use_fp16_qk_reduction": "true" if use_fp16_qk_reduction else "false", + }, + ) + + def get_single_prefill_uri( dtype_q: torch.dtype, dtype_kv: torch.dtype, @@ -269,6 +306,10 @@ def get_single_prefill_uri( ) +def get_single_prefill_sm90_uri(*args): + return get_single_prefill_uri(*args) + "_sm90" + + def gen_single_prefill_module(*args): gen_directory = FLASHINFER_GEN_SRC_DIR uri = get_single_prefill_uri(*args) @@ -282,6 +323,19 @@ def gen_single_prefill_module(*args): return load_cuda_ops(uri, source_paths) +def gen_single_prefill_sm90_module(*args): + gen_directory = FLASHINFER_GEN_SRC_DIR + uri = get_single_prefill_sm90_uri(*args) + sources = get_single_prefill_sm90_sources(*args) + source_paths = [] + for suffix, source in zip(single_prefill_sm90_suffix, sources): + path = gen_directory / f"{uri}{suffix}" + source_paths.append(path) + write_if_different(path, source) + + return load_cuda_ops(uri, source_paths) + + def get_batch_prefill_sources( dtype_q: torch.dtype, dtype_kv: torch.dtype, @@ -309,6 +363,37 @@ def get_batch_prefill_sources( ) +def get_batch_prefill_sm90_sources( + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + dtype_idx: torch.dtype, + head_dim: int, + pos_encoding_mode: int, + use_sliding_window: bool, + use_logits_soft_cap: bool, + use_fp16_qk_reduction: bool, +) -> List[str]: + assert not use_fp16_qk_reduction, "fp16 qk reduction is not supported on sm90" + assert ( + pos_encoding_mode == 0 + ), "Currently we only support pos_encoding_mode=0 on sm90" + return render_templates( + batch_prefill_sm90_templ, + { + "dtype_q": dtype_map[dtype_q], + "dtype_kv": dtype_map[dtype_kv], + "dtype_o": dtype_map[dtype_o], + "dtype_idx": dtype_map[dtype_idx], + "head_dim": head_dim, + "pos_encoding_mode": pos_encoding_mode_literal[pos_encoding_mode], + "use_sliding_window": "true" if use_sliding_window else "false", + "use_logits_soft_cap": "true" if use_logits_soft_cap else "false", + "use_fp16_qk_reduction": "true" if use_fp16_qk_reduction else "false", + }, + ) + + def get_batch_prefill_uri( dtype_q: torch.dtype, dtype_kv: torch.dtype, @@ -333,6 +418,10 @@ def get_batch_prefill_uri( ) +def get_batch_prefill_sm90_uri(*args): + return get_batch_prefill_uri(*args) + "_sm90" + + def gen_batch_prefill_module(*args): gen_directory = FLASHINFER_GEN_SRC_DIR uri = get_batch_prefill_uri(*args) @@ -346,6 +435,19 @@ def gen_batch_prefill_module(*args): return load_cuda_ops(uri, source_paths) +def gen_batch_prefill_sm90_module(*args): + gen_directory = FLASHINFER_GEN_SRC_DIR + uri = get_batch_prefill_sm90_uri(*args) + sources = get_batch_prefill_sm90_sources(*args) + source_paths = [] + for suffix, source in zip(batch_prefill_sm90_suffix, sources): + path = gen_directory / f"{uri}{suffix}" + source_paths.append(path) + write_if_different(path, source) + + return load_cuda_ops(uri, source_paths) + + def get_customize_single_decode_sources( dtype_q: torch.dtype, dtype_kv: torch.dtype, diff --git a/flashinfer/jit/batch_prefill_sm90_templ.py b/flashinfer/jit/batch_prefill_sm90_templ.py new file mode 100644 index 000000000..c06c4aac0 --- /dev/null +++ b/flashinfer/jit/batch_prefill_sm90_templ.py @@ -0,0 +1,19 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +batch_prefill_sm90_suffix = [".cu", "_pybind.cc"] + +batch_prefill_sm90_templ = [r"""""", r""""""] diff --git a/flashinfer/jit/single_prefill_sm90_templ.py b/flashinfer/jit/single_prefill_sm90_templ.py new file mode 100644 index 000000000..917cf3684 --- /dev/null +++ b/flashinfer/jit/single_prefill_sm90_templ.py @@ -0,0 +1,19 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +single_prefill_sm90_suffix = [".cu", "_pybind.cc"] + +single_prefill_sm90_templ = [r"""""", r""""""] diff --git a/flashinfer/page.py b/flashinfer/page.py index a008fb085..b0f80903d 100644 --- a/flashinfer/page.py +++ b/flashinfer/page.py @@ -51,6 +51,44 @@ def get_page_module(): return _page_module +def block_sparse_indices_to_vector_sparse_offsets( + block_sparse_indices: torch.Tensor, + block_sparse_indptr: torch.Tensor, + vector_sparse_offsets: torch.Tensor, + vector_sparse_indptr: torch.Tensor, + kv_lens: torch.Tensor, + stride_block: int, + stride_n: int, + block_size: int, +) -> torch.Tensor: + if block_size == 1: + if stride_block == 1: + return block_sparse_indices + else: + return block_sparse_indices * stride_block + + with block_sparse_indices.device as device: + assert block_sparse_indices.dtype == torch.int32 + assert block_sparse_indptr.dtype == torch.int32 + assert vector_sparse_offsets.dtype == torch.int32 + assert vector_sparse_indptr.dtype == torch.int32 + assert kv_lens.dtype == torch.int32 + batch_size = block_sparse_indptr.size(0) - 1 + get_page_module().block_sparse_indices_to_vector_sparse_offsets( + block_sparse_indices, + block_sparse_indptr, + vector_sparse_offsets, + vector_sparse_indptr, + kv_lens, + stride_block, + stride_n, + batch_size, + block_size, + get_cuda_stream(device), + ) + return vector_sparse_offsets + + @register_custom_op( "flashinfer::append_paged_kv_cache", mutates_args=("paged_k_cache", "paged_v_cache"), diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 829447613..d5e8f5a3f 100644 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -24,13 +24,18 @@ from .jit import ( gen_batch_prefill_module, + gen_batch_prefill_sm90_module, gen_single_prefill_module, + gen_single_prefill_sm90_module, + get_batch_prefill_sm90_uri, get_batch_prefill_uri, + get_single_prefill_sm90_uri, get_single_prefill_uri, has_prebuilt_ops, load_cuda_ops, prebuilt_ops_uri, ) +from .page import block_sparse_indices_to_vector_sparse_offsets, get_seq_lens from .quantization import packbits, segment_packbits from .utils import ( MaskMode, @@ -43,14 +48,95 @@ _get_cache_buf, _unpack_paged_kv_cache, canonicalize_torch_dtype, + determine_attention_backend, get_cuda_stream, is_float8, + log2e, register_custom_op, register_fake_op, ) _single_prefill_modules = {} +_single_prefill_sm90_modules = {} _batch_prefill_modules = {} +_batch_prefill_sm90_modules = {} + + +def get_single_prefill_sm90_module(*args): + global _single_prefill_sm90_modules + if args not in _single_prefill_sm90_modules: + uri = get_single_prefill_sm90_uri(*args) + # if has_prebuilt_ops and uri in prebuilt_ops_uri: + from . import _kernels_sm90 + + run_func = _kernels_sm90.single_prefill_with_kv_cache_sm90 + # else: + # run_func = gen_single_prefill_sm90_module(*args).run + + @register_custom_op(f"flashinfer::{uri}_run", mutates_args=("tmp", "maybe_lse")) + def run_single_prefill_sm90( + mask_mode: int, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + maybe_packed_custom_mask: Optional[torch.Tensor], + tmp: torch.Tensor, + maybe_alibi_slopes: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + with q.device as device: # device guard + o = torch.empty_like(q) + run_func( + mask_mode, + q, + k, + v, + maybe_packed_custom_mask, + # tmp, + maybe_alibi_slopes, + o, + layout, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + maybe_lse, + get_cuda_stream(device), + ) + return o + + @register_fake_op(f"flashinfer::{uri}_run") + def _fake_run_single_prefill_sm90( + mask_mode: int, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + maybe_packed_custom_mask: Optional[torch.Tensor], + tmp: torch.Tensor, + maybe_alibi_slopes: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + return torch.empty_like(q) + + # Register the module + _single_prefill_sm90_modules[args] = SimpleNamespace( + run=run_single_prefill_sm90 + ) + + return _single_prefill_sm90_modules[args] def get_single_prefill_module(*args): @@ -130,6 +216,207 @@ def _fake_run_single_prefill( return _single_prefill_modules[args] +def get_batch_prefill_sm90_module(*args): + global _batch_prefill_sm90_modules + if args not in _batch_prefill_sm90_modules: + uri = get_batch_prefill_sm90_uri(*args) + + from . import _kernels_sm90 + + head_dim = args[4] + plan_func = ( + lambda *plan_args: _kernels_sm90.batch_prefill_with_kv_cache_sm90_plan( + head_dim, + *plan_args, + ) + ) + ragged_run_func = _kernels_sm90.batch_prefill_with_ragged_kv_cache_sm90_run + paged_run_func = _kernels_sm90.batch_prefill_with_paged_kv_cache_sm90_run + + # torch library for ragged_run + + @register_custom_op( + f"flashinfer::{uri}_ragged_run", + mutates_args=( + "float_workspace_buffer", + "int_workspace_buffer", + "maybe_lse", + ), + ) + def ragged_run( + mask_mode: int, + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + maybe_custom_mask: Optional[torch.Tensor], + maybe_alibi_slopes: Optional[torch.Tensor], + qo_indptr: torch.Tensor, + kv_indptr: torch.Tensor, + maybe_qk_indptr: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + with q.device as device: # device guard + o = torch.empty_like(q) + ragged_run_func( + mask_mode, + float_workspace_buffer, + int_workspace_buffer, + plan_info_vec, + q, + k, + v, + maybe_custom_mask, + maybe_alibi_slopes, + qo_indptr, + kv_indptr, + maybe_qk_indptr, + o, + layout, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + maybe_lse, + get_cuda_stream(device), + ) + return o + + @register_fake_op(f"flashinfer::{uri}_ragged_run") + def _fake_ragged_run( + mask_mode: int, + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + maybe_custom_mask: Optional[torch.Tensor], + maybe_alibi_slopes: Optional[torch.Tensor], + qo_indptr: torch.Tensor, + kv_indptr: torch.Tensor, + maybe_qk_indptr: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + return torch.empty_like(q) + + # torch library for paged_run + + @register_custom_op( + f"flashinfer::{uri}_paged_run", + mutates_args=( + "float_workspace_buffer", + "int_workspace_buffer", + "paged_k_cache", + "paged_v_cache", + "maybe_lse", + ), + ) + def paged_run( + mask_mode: int, + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + paged_k_cache: torch.Tensor, + paged_v_cache: torch.Tensor, + maybe_custom_mask: Optional[torch.Tensor], + maybe_alibi_slopes: Optional[torch.Tensor], + qo_indptr: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + maybe_qk_indptr: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + with q.device as device: # device guard + o = torch.empty_like(q) + paged_run_func( + mask_mode, + float_workspace_buffer, + int_workspace_buffer, + plan_info_vec, + q, + paged_k_cache, + paged_v_cache, + maybe_custom_mask, + maybe_alibi_slopes, + qo_indptr, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + maybe_qk_indptr, + o, + layout, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + maybe_lse, + get_cuda_stream(device), + ) + return o + + @register_fake_op(f"flashinfer::{uri}_paged_run") + def _fake_paged_run( + mask_mode: int, + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + paged_k_cache: torch.Tensor, + paged_v_cache: torch.Tensor, + maybe_custom_mask: Optional[torch.Tensor], + maybe_alibi_slopes: Optional[torch.Tensor], + qo_indptr: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + maybe_qk_indptr: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + return torch.empty_like(q) + + # Register the module. + # + # Note that plan is not part of model logic. It should not be included in + # Cuda Graph or torch.compile. So, we don't provide a torch library for plan. + _batch_prefill_sm90_modules[args] = SimpleNamespace( + plan=plan_func, + ragged_run=ragged_run, + paged_run=paged_run, + ) + return _batch_prefill_sm90_modules[args] + + def get_batch_prefill_module(*args): global _batch_prefill_modules if args not in _batch_prefill_modules: @@ -236,7 +523,7 @@ def _fake_ragged_run( # torch library for paged_run @register_custom_op( - f"flashinfer::{get_batch_prefill_uri(*args)}_paged_run", + f"flashinfer::{uri}_paged_run", mutates_args=( "float_workspace_buffer", "int_workspace_buffer", @@ -428,6 +715,7 @@ def single_prefill_with_kv_cache( rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, return_lse: bool = False, + backend: str = "auto", ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Prefill/Append attention with KV cache for single request, return the attention output. @@ -485,6 +773,10 @@ def single_prefill_with_kv_cache( The theta used in RoPE, if not provided, will be set to 1e4. return_lse : bool Whether to return the log sum exp value of the attention logits. + backend : str + The implementation backend, could be ``auto``/``fa2`` or ``fa3``. Defaults to ``auto``. + If set to ``auto``, the function will automatically choose the backend based on the + device architecture and kernel availability. Returns ------- @@ -563,7 +855,21 @@ def single_prefill_with_kv_cache( if return_lse: lse = torch.empty((q.size(0), q.size(1)), dtype=torch.float32, device=q.device) - out = get_single_prefill_module( + if backend == "auto": + backend = determine_attention_backend( + q.device, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + packed_custom_mask is not None, # use_custom_mask + q.dtype, + k.dtype, + ) + if backend == "fa2": + module_getter = get_single_prefill_module + elif backend == "fa3": + module_getter = get_single_prefill_sm90_module + + out = module_getter( q.dtype, k.dtype, q.dtype, @@ -733,6 +1039,7 @@ def __init__( paged_kv_last_page_len_buf: Optional[torch.Tensor] = None, custom_mask_buf: Optional[torch.Tensor] = None, qk_indptr_buf: Optional[torch.Tensor] = None, + backend: str = "auto", ) -> None: r"""Constructor of :class:`BatchPrefillWithPagedKVCacheWrapper`. @@ -783,14 +1090,35 @@ def __init__( should be ``[batch_size + 1]``. This argument is only effective when ``use_cuda_graph`` is ``True`` and the custom mask will be used in attention computation. + + backend : str + The implementation backend, could be ``auto``/``fa2`` or ``fa3``. Defaults to ``auto``. + If set to ``auto``, the function will automatically choose the backend based on the + device architecture and kernel availability. """ _check_kv_layout(kv_layout) self._kv_layout = kv_layout self._float_workspace_buffer = float_workspace_buffer self.device = float_workspace_buffer.device - self._int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device - ) + if backend in ["fa3", "auto"]: + self._int_workspace_buffer = torch.empty( + (64 * 1024 * 1024,), dtype=torch.uint8, device=self.device + ) + # NOTE(Zihao): assume maximum accumulate kv length is 16M + self._vector_sparse_indices_buffer = torch.empty( + (16 * 1024 * 1024,), dtype=torch.int32, device=self.device + ) + # NOTE(Zihao): assume maximum batch size is 32768 + self._vector_sparse_indptr_buffer = torch.empty( + (32768,), dtype=torch.int32, device=self.device + ) + self._kv_lens_buffer = torch.empty( + (32768,), dtype=torch.int32, device=self.device + ) + else: + self._int_workspace_buffer = torch.empty( + (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device + ) self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape, dtype=self._int_workspace_buffer.dtype, @@ -834,6 +1162,7 @@ def __init__( self._custom_mask_buf = custom_mask_buf self._qk_indptr_buf = qk_indptr_buf self._max_total_num_rows = None + self._backend = backend @property def is_cuda_graph_enabled(self) -> bool: @@ -1068,7 +1397,18 @@ def plan( self._cached_q_data_type = q_data_type self._cached_kv_data_type = kv_data_type - self._cached_module = get_batch_prefill_module( + + if self._backend == "auto": + self._backend = determine_attention_backend( + self.device, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + self._custom_mask_buf is not None, # use_custom_mask + q_data_type, + kv_data_type, + ) + + get_module_args = ( q_data_type, kv_data_type, q_data_type, @@ -1079,21 +1419,63 @@ def plan( logits_soft_cap > 0, # use_logits_soft_cap allow_fp16_qk_reduction, ) - with self.device as device: - self._plan_info = self._cached_module.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - qo_indptr_host, - paged_kv_indptr_host, - self._max_total_num_rows or total_num_rows, - batch_size, - num_qo_heads, - num_kv_heads, - page_size, - self.is_cuda_graph_enabled, - get_cuda_stream(device), + if self._backend == "fa2": + self._cached_module = get_batch_prefill_module(*get_module_args) + with self.device as device: + self._plan_info = self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + paged_kv_indptr_host, + self._max_total_num_rows or total_num_rows, + batch_size, + num_qo_heads, + num_kv_heads, + page_size, + self.is_cuda_graph_enabled, + get_cuda_stream(device), + ) + else: + self._cached_module = get_batch_prefill_sm90_module(*get_module_args) + paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu") + kv_lens_arr_host = get_seq_lens( + paged_kv_indptr_host, paged_kv_last_page_len_host, page_size + ) + self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_( + kv_lens_arr_host, non_blocking=non_blocking ) + if page_size != 1: + vector_sparse_indptr_host = torch.cat( + [ + torch.tensor([0], dtype=torch.int32), + torch.cumsum(kv_lens_arr_host, dim=0, dtype=torch.int32), + ], + dim=0, + ) + self._vector_sparse_indptr_buffer[ + : len(vector_sparse_indptr_host) + ].copy_(vector_sparse_indptr_host, non_blocking=non_blocking) + else: + vector_sparse_indptr_host = paged_kv_indptr_host + + with self.device as device: + self._plan_info = self._cached_module.plan( + causal, + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + vector_sparse_indptr_host, + kv_lens_arr_host, + batch_size, + num_qo_heads, + num_kv_heads, + page_size, + self.is_cuda_graph_enabled, + get_cuda_stream(device), + ) + self._causal = causal self._pos_encoding_mode = pos_encoding_mode self._allow_fp16_qk_reduction = allow_fp16_qk_reduction @@ -1199,6 +1581,13 @@ def run( _check_cached_qkv_data_type( q, k_cache, self._cached_q_data_type, self._cached_kv_data_type ) + stride_block = k_cache.stride(0) + if self._kv_layout == "NHD": + page_size = k_cache.shape[1] + stride_n = k_cache.stride(1) + else: + page_size = k_cache.shape[2] + stride_n = k_cache.stride(2) window_left = self._window_left logits_soft_cap = self._logits_soft_cap sm_scale = self._sm_scale @@ -1228,6 +1617,22 @@ def run( else: mask_mode = MaskMode.NON_CAUSAL.value + if self._backend == "fa3": + # NOTE(Zihao): we divide both stride_block and stride_n by stride_n + # because we will multiply stride_n back in the kernel + sparse_indices = block_sparse_indices_to_vector_sparse_offsets( + self._paged_kv_indices_buf, + self._paged_kv_indptr_buf, + self._vector_sparse_indices_buffer, # output + self._vector_sparse_indptr_buffer, + self._kv_lens_buffer, + stride_block // stride_n, + 1, # stride_n // stride_n + page_size, + ) + else: + sparse_indices = self._paged_kv_indices_buf + out = self._cached_module.paged_run( mask_mode, self._float_workspace_buffer, @@ -1240,7 +1645,7 @@ def run( _get_cache_alibi_slopes_buf(q.shape[1], q.device), self._qo_indptr_buf, self._paged_kv_indptr_buf, - self._paged_kv_indices_buf, + sparse_indices, # self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, self._qk_indptr_buf, TensorLayout[self._kv_layout].value, @@ -1404,6 +1809,7 @@ def __init__( kv_indptr_buf: Optional[torch.Tensor] = None, custom_mask_buf: Optional[torch.Tensor] = None, qk_indptr_buf: Optional[torch.Tensor] = None, + backend: str = "auto", ) -> None: r"""Constructor of :class:`BatchPrefillWithRaggedKVCacheWrapper`. @@ -1442,16 +1848,26 @@ def __init__( should be ``[batch_size]``. This argument is only effective when ``use_cuda_graph`` is ``True`` and custom mask will be used in attention computation. + + backend : str + The implementation backend, could be ``auto``/``fa2`` or ``fa3``. Defaults to ``auto``. + If set to ``auto``, the function will automatically choose the backend based on the + device architecture and kernel availability. """ _check_kv_layout(kv_layout) self._kv_layout = kv_layout self._float_workspace_buffer = float_workspace_buffer self.device = float_workspace_buffer.device - self._int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device - ) + if backend in ["fa3", "auto"]: + self._int_workspace_buffer = torch.empty( + (64 * 1024 * 1024,), dtype=torch.uint8, device=self.device + ) + else: + self._int_workspace_buffer = torch.empty( + (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device + ) self._pin_memory_int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, pin_memory=True + self._int_workspace_buffer.shape, dtype=torch.uint8, pin_memory=True ) self._use_cuda_graph = use_cuda_graph if use_cuda_graph: @@ -1478,6 +1894,7 @@ def __init__( self._custom_mask_buf = custom_mask_buf self._qk_indptr_buf = qk_indptr_buf self._max_total_num_rows = None + self._backend = backend @property def is_cuda_graph_enabled(self) -> bool: @@ -1671,7 +2088,18 @@ def plan( self._cached_q_data_type = q_data_type self._cached_kv_data_type = kv_data_type - self._cached_module = get_batch_prefill_module( + + if self._backend == "auto": + self._backend = determine_attention_backend( + self.device, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + self._custom_mask_buf is not None, # use_custom_mask + q_data_type, + kv_data_type, + ) + + get_module_args = ( q_data_type, kv_data_type, q_data_type, @@ -1682,21 +2110,45 @@ def plan( logits_soft_cap > 0, # use_logits_soft_cap allow_fp16_qk_reduction, ) - with self.device as device: - self._plan_info = self._cached_module.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - qo_indptr_host, - kv_indptr_host, - self._max_total_num_rows or total_num_rows, - batch_size, - num_qo_heads, - num_kv_heads, - 1, # page_size - self.is_cuda_graph_enabled, - get_cuda_stream(device), - ) + if self._backend == "fa2": + self._cached_module = get_batch_prefill_module(*get_module_args) + with self.device as device: + self._plan_info = self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + kv_indptr_host, + self._max_total_num_rows or total_num_rows, + batch_size, + num_qo_heads, + num_kv_heads, + 1, # page_size + self.is_cuda_graph_enabled, + get_cuda_stream(device), + ) + else: + self._cached_module = get_batch_prefill_sm90_module(*get_module_args) + kv_len_arr = kv_indptr_host[1:] - kv_indptr_host[:-1] + with self.device as device: + # NOTE(Zihao): there are some interface differences between fa2 and fa3 + # we should align the interface in the future + self._plan_info = self._cached_module.plan( + causal, + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + kv_indptr_host, + kv_len_arr, + batch_size, + num_qo_heads, + num_kv_heads, + 1, # page_size + self.is_cuda_graph_enabled, + get_cuda_stream(device), + ) + self._causal = causal self._pos_encoding_mode = pos_encoding_mode self._allow_fp16_qk_reduction = allow_fp16_qk_reduction diff --git a/flashinfer/sparse.py b/flashinfer/sparse.py index 4bc92d0d2..7732ce057 100644 --- a/flashinfer/sparse.py +++ b/flashinfer/sparse.py @@ -21,7 +21,12 @@ import torch from .decode import get_batch_decode_module -from .prefill import _compute_page_qk_indptr, get_batch_prefill_module +from .page import block_sparse_indices_to_vector_sparse_offsets, get_seq_lens +from .prefill import ( + _compute_page_qk_indptr, + get_batch_prefill_module, + get_batch_prefill_sm90_module, +) from .quantization import segment_packbits from .utils import ( MaskMode, @@ -30,6 +35,7 @@ _check_pos_encoding_mode, _get_cache_alibi_slopes_buf, canonicalize_torch_dtype, + determine_attention_backend, get_cuda_stream, ) @@ -107,6 +113,7 @@ class BlockSparseAttentionWrapper: def __init__( self, float_workspace_buffer: torch.Tensor, + backend: str = "auto", ) -> None: r"""Constructs of :class:`BlockSparseAttentionWrapper`. @@ -116,12 +123,34 @@ def __init__( The user reserved float workspace buffer used to store intermediate attention results in the split-k algorithm. The recommended size is 128MB, the device of the workspace buffer should be the same as the device of the input tensors. + backend : str + The implementation backend, could be ``auto``/``fa2`` or ``fa3``. Defaults to ``auto``. + If set to ``auto``, the function will automatically choose the backend based on the + device architecture and kernel availability. """ self._float_workspace_buffer = float_workspace_buffer self.device = float_workspace_buffer.device - self._int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, device=float_workspace_buffer.device - ) + if backend in ["fa3", "auto"]: + self._int_workspace_buffer = torch.empty( + (64 * 1024 * 1024,), dtype=torch.uint8, device=self.device + ) + # NOTE(Zihao): assume maximum accumulate kv length is 16M + self._vector_sparse_indices_buffer = torch.empty( + (16 * 1024 * 1024,), dtype=torch.int32, device=self.device + ) + # NOTE(Zihao): assume maximum batch size is 32768 + self._vector_sparse_indptr_buffer = torch.empty( + (32768,), dtype=torch.int32, device=self.device + ) + self._kv_lens_buffer = torch.empty( + (32768,), dtype=torch.int32, device=self.device + ) + else: + self._int_workspace_buffer = torch.empty( + (8 * 1024 * 1024,), + dtype=torch.uint8, + device=float_workspace_buffer.device, + ) self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape, dtype=self._int_workspace_buffer.dtype, @@ -139,6 +168,7 @@ def __init__( self.C: Optional[int] = None self.M: Optional[int] = None self.N: Optional[int] = None + self._backend = backend def reset_workspace_buffer( self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor @@ -176,6 +206,7 @@ def plan( head_dim: int, mask: Optional[torch.Tensor] = None, packed_mask: Optional[torch.Tensor] = None, + causal: bool = False, pos_encoding_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, logits_soft_cap: Optional[float] = None, @@ -217,6 +248,10 @@ def plan( packed_mask : torch.Tensor, optional The 1D packed mask tensor, if provided, the :attr:`custom_mask` will be ignored. The packed mask tensor is generated by :func:`flashinfer.quantization.packbits`. + causal : bool + Whether to apply causal mask to the attention matrix. + This is only effective when :attr:`custom_mask` is not provided in + :meth:`plan`. pos_encoding_mode : str, optional The position encoding applied inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. @@ -303,7 +338,7 @@ def plan( else: self._packed_mask_buf = None self._qk_indptr_buf = None - mask_mode = MaskMode.NON_CAUSAL.value + mask_mode = MaskMode.CAUSAL.value if causal else MaskMode.NON_CAUSAL.value self._mask_mode = mask_mode self.M = M @@ -318,7 +353,7 @@ def plan( # at this moment, when mask is provided, we use the tensor-core implementation if ( R * (num_qo_heads // num_kv_heads) < 4 - and mask_mode == MaskMode.NON_CAUSAL.value + and mask_mode != MaskMode.CUSTOM.value ): # If the operation is not compute-bound, we use the cuda-core implementation self._use_tensor_cores = False @@ -349,7 +384,18 @@ def plan( else: # if the operation is compute-bound, we use the tensor-core implementation self._use_tensor_cores = True - self._cached_module = get_batch_prefill_module( + + if self._backend == "auto": + self._backend = determine_attention_backend( + self.device, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + mask_mode == MaskMode.CUSTOM.value, # use_custom_mask + q_data_type, + kv_data_type, + ) + + get_module_args = ( q_data_type, kv_data_type, q_data_type, @@ -360,21 +406,60 @@ def plan( logits_soft_cap > 0, # use_logits_soft_cap allow_fp16_qk_reduction, ) - - with self.device as device: - self._plan_info = self._cached_module.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - qo_indptr_host, - kv_indptr_host, - num_blocks_row, - num_qo_heads, - num_kv_heads, - C, - False, # is_cuda_graph_enabled - get_cuda_stream(device), + if self._backend == "fa2": + self._cached_module = get_batch_prefill_module(*get_module_args) + + with self.device as device: + self._plan_info = self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + kv_indptr_host, + M, # total_num_rows + num_blocks_row, + num_qo_heads, + num_kv_heads, + C, + False, # is_cuda_graph_enabled + get_cuda_stream(device), + ) + else: + self._cached_module = get_batch_prefill_sm90_module(*get_module_args) + kv_lens_arr_host = (kv_indptr_host[1:] - kv_indptr_host[:-1]) * self.C + self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_( + kv_lens_arr_host, non_blocking=non_blocking ) + if self.C != 1: + vector_sparse_indptr_host = torch.cat( + [ + torch.tensor([0], dtype=torch.int32), + torch.cumsum(kv_lens_arr_host, dim=0, dtype=torch.int32), + ], + dim=0, + ) + self._vector_sparse_indptr_buffer[ + : len(vector_sparse_indptr_host) + ].copy_(vector_sparse_indptr_host, non_blocking=non_blocking) + else: + vector_sparse_indptr_host = kv_indptr_host + + with self.device as device: + self._plan_info = self._cached_module.plan( + causal, + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + vector_sparse_indptr_host, + kv_lens_arr_host, + num_blocks_row, # batch_size + num_qo_heads, + num_kv_heads, + self.C, # page_size + False, # is_cuda_graph_enabled, + get_cuda_stream(device), + ) self._pos_encoding_mode = pos_encoding_mode self._allow_fp16_qk_reduction = allow_fp16_qk_reduction @@ -426,7 +511,6 @@ def run( return_lse : bool Whether to return the logsumexp of attention output - Returns ------- Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] @@ -454,6 +538,10 @@ def run( k = k.reshape(-1, self.C, *k.shape[-2:]) v = v.reshape(-1, self.C, *v.shape[-2:]) + stride_block = k.stride(0) + stride_n = k.stride(1) + print(k.shape, stride_block, stride_n) + lse = None if return_lse: lse = torch.empty( @@ -461,6 +549,21 @@ def run( ) if self._use_tensor_cores: + if self._backend == "fa3": + sparse_indices = block_sparse_indices_to_vector_sparse_offsets( + self._paged_kv_indices_buf, + self._paged_kv_indptr_buf, + self._vector_sparse_indices_buffer, # output + self._vector_sparse_indptr_buffer, + self._kv_lens_buffer, + stride_block // stride_n, + 1, # stride_n // stride_n + self.C, # block_size + ) + print(self.C, sparse_indices, self._vector_sparse_indices_buffer) + else: + sparse_indices = self._paged_kv_indices_buf + out = self._cached_module.paged_run( self._mask_mode, self._float_workspace_buffer, @@ -473,7 +576,7 @@ def run( _get_cache_alibi_slopes_buf(q.shape[1], self.device), self._qo_indptr, self._paged_kv_indptr_buf, - self._paged_kv_indices_buf, + sparse_indices, # self._paged_kv_indices_buf, self._paged_kv_last_page_len, self._qk_indptr_buf, TensorLayout[self._kv_layout].value, diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 4abce3741..d38af8271 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -260,3 +260,95 @@ def determine_gemm_backend(device: torch.device) -> str: return "sm90" else: return "sm80" + + +def is_fa3_backend_supported( + pos_encoding_mode: int, + allow_fp16_qk_reductions: bool, + use_custom_mask: bool, + dtype_q: torch.dtype, + dtype_kv: torch.dtype, +) -> bool: + """ + Check if the FA3 backend is supported based on the given parameters. + NOTE(Zihao): this function is a workaround for the lack of support for certain features in + our FA3 backend, and will be removed once the backend is fully supported. + + Parameters + ---------- + pos_encoding_mode : int + The positional encoding mode. + allow_fp16_qk_reductions : bool + Whether FP16 QK reductions are allowed. + use_custom_mask : bool + Whether a custom mask is used. + dtype_q : torch.dtype + The data type of the query tensor. + dtype_kv : torch.dtype + The data type of the key-value tensor. + + Returns + ------- + bool + True if the FA3 backend is supported, False otherwise. + """ + if use_custom_mask: + return False + if pos_encoding_mode != PosEncodingMode.NONE.value: + return False + if allow_fp16_qk_reductions: + return False + # NOTE: currently fp8 is not supported in our FA3 backend + # will add support soon + if dtype_q in [torch.float8_e4m3fn, torch.float8_e5m2]: + return False + if dtype_kv in [torch.float8_e4m3fn, torch.float8_e5m2]: + return False + return True + + +def determine_attention_backend( + device: torch.device, + pos_encoding_mode: int, + allow_fp16_qk_reductions: bool, + use_custom_mask: bool, + dtype_q: torch.dtype, + dtype_kv: torch.dtype, +) -> str: + """ + Determine the appropriate attention backend based on the device and parameters. + + Parameters + ---------- + device : torch.device + The device to be used. + mask_mode : int + The mask mode. + pos_encoding_mode : int + The positional encoding mode. + allow_fp16_qk_reductions : bool + Whether FP16 QK reductions are allowed. + use_custom_mask : bool + Whether a custom mask is used. + dtype_q : torch.dtype + The data type of the query tensor. + dtype_kv : torch.dtype + The data type of the key-value tensor. + + Returns + ------- + str + The name of the attention backend to be used. + """ + major, _ = get_compute_capability(device) + + if major >= 9 and is_fa3_backend_supported( + pos_encoding_mode, + allow_fp16_qk_reductions, + use_custom_mask, + dtype_q, + dtype_kv, + ): + return "fa3" + else: + return "fa2" diff --git a/include/flashinfer/attention/heap.h b/include/flashinfer/attention/heap.h new file mode 100644 index 000000000..e742c9b6a --- /dev/null +++ b/include/flashinfer/attention/heap.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_ATTENTION_HEAP_H +#define FLASHINFER_ATTENTION_HEAP_H + +#include +#include +#include +#include + +namespace flashinfer { + +/*! + * \brief Heap data structure for (index, value) pairs + * \note minimal element on top + */ +class CTACostHeap { + public: + // first: index, second: cost + using Element = std::pair; + + CTACostHeap(int capacity) : heap_(capacity) { + for (int i = 0; i < capacity; ++i) { + heap_[i] = std::make_pair(i, 0.f); + } + } + + void insert(const Element& element) { + heap_.push_back(element); + std::push_heap(heap_.begin(), heap_.end(), compare); + } + + Element pop() { + std::pop_heap(heap_.begin(), heap_.end(), compare); + Element minElement = heap_.back(); + heap_.pop_back(); + return minElement; + } + + private: + // Custom comparator for the min-heap: compare based on 'val' in the pair + static bool compare(const Element& a, const Element& b) { + return a.second > b.second; // create a min-heap based on val + } + + std::vector heap_; +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HEAP_H diff --git a/include/flashinfer/attention/hopper/attention_updater.cuh b/include/flashinfer/attention/hopper/attention_updater.cuh new file mode 100644 index 000000000..f9fc1abb7 --- /dev/null +++ b/include/flashinfer/attention/hopper/attention_updater.cuh @@ -0,0 +1,257 @@ +/* + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri + * Dao. Licensed under the BSD 3-Clause. + * + * Modified by the FlashInfer team. + */ +#ifndef FLASHINFER_ATTENTION_HOPPER_ATTENTION_UPDATER_CUH_ +#define FLASHINFER_ATTENTION_HOPPER_ATTENTION_UPDATER_CUH_ + +#include +#include + +namespace flashinfer { + +using namespace cute; + +template +struct MaxOp { + __device__ __forceinline__ T operator()(T const& x, T const& y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp { + // This is slightly faster + __device__ __forceinline__ float operator()(float const& x, float const& y) { return max(x, y); } +}; + +template +struct SumOp { + __device__ __forceinline__ T operator()(T const& x, T const& y) { return x + y; } +}; + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ __forceinline__ T run(T x, Operator& op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +template <> +struct Allreduce<2> { + template + static __device__ __forceinline__ T run(T x, Operator& op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; + } +}; + +template +__device__ __forceinline__ void thread_reduce_(Tensor const& tensor, + Tensor& summary, Operator& op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); +#pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ __forceinline__ void quad_allreduce_(Tensor& dst, + Tensor& src, Operator& op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); +#pragma unroll + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ __forceinline__ void reduce_(Tensor const& tensor, + Tensor& summary, Operator& op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ __forceinline__ void reduce_max(Tensor const& tensor, + Tensor& max) { + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, + Tensor& sum) { + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); + if constexpr (warp_reduce) { + quad_allreduce_(sum, sum, sum_op); + } +} + +template +__forceinline__ __device__ void apply_exp2(Tensor& tensor, + Tensor const& max) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + auto row_max = max(mi); +#pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + tensor(mi, ni) = exp2f(tensor(mi, ni) - row_max); + } + } +} + +template +__forceinline__ __device__ void scale_apply_exp2(Tensor& tensor, + Tensor const& max, + const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + auto row_max = max(mi); +#pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // row_max * scale is a constant for each row, so we can use fma here + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - row_max * scale); + } + } +} + +template +struct DefaultUpdater { + using TensorT = decltype(make_tensor(Shape>{})); + CUTLASS_DEVICE DefaultUpdater(float scale_ = 1.f) {}; + + __forceinline__ __device__ TensorT get_lse() { return TensorT(); } + + template + __forceinline__ __device__ void update(Tensor0& acc_s) { + // NOTE(Zihao): nothing to do here + }; + + template + __forceinline__ __device__ void finalize(Tensor1& acc_s) { + // NOTE(Zihao): nothing to do here + }; + + template + __forceinline__ __device__ void rescale_o(Tensor1& acc_o) { + // NOTE(Zihao): nothing to do here + }; +}; + +template +struct OnlineSoftmax { + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum, scores_scale; + const float sm_scale_log2; + + CUTLASS_DEVICE OnlineSoftmax(float scale_ = 1.f) : sm_scale_log2(scale_) { clear(scores_scale); }; + + __forceinline__ __device__ TensorT get_lse() const { return row_sum; } + + template + __forceinline__ __device__ void update(Tensor0& acc_s) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout())); + + static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD); + if constexpr (init) { + reduce_max(scores, row_max); + if constexpr (WITH_SCALE) { + scale_apply_exp2(scores, row_max, sm_scale_log2); + } else { + apply_exp2(scores, row_max); + } + reduce_sum(scores, row_sum); + } else { + // update row_max + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + reduce_max(scores, row_max); + // update scores_scale and scale row_sum +#pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = row_max(mi); + if constexpr (WITH_SCALE) { + scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * sm_scale_log2); + } else { + scores_scale(mi) = exp2f(scores_max_prev(mi) - scores_max_cur); + } + row_sum(mi) *= scores_scale(mi); + } + // perform exp2 on scores + if constexpr (WITH_SCALE) { + scale_apply_exp2(scores, row_max, sm_scale_log2); + } else { + apply_exp2(scores, row_max); + } + // update row_sum + reduce_sum(scores, row_sum); + } + }; + + template + __forceinline__ __device__ void finalize(Tensor0& acc_s) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD); + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); +#pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float sum = row_sum(mi); + float inv_sum = 1.f / sum; + scores_scale(mi) = inv_sum; + if constexpr (WITH_SCALE) { + row_sum(mi) = row_max(mi) * sm_scale_log2 + math::ptx_log2(sum); + } else { + row_sum(mi) = row_max(mi) + math::ptx_log2(sum); + } + } + }; + + template + __forceinline__ __device__ void rescale_o(Tensor1& acc_o) { + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == NUM_ROWS_PER_THREAD); +#pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { +#pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scores_scale(mi); + } + } + }; +}; + +template +using OnlineSoftmaxWithScale = OnlineSoftmax; + +template +using OnlineSoftmaxWithoutScale = OnlineSoftmax; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_ATTENTION_UPDATER_CUH_ diff --git a/include/flashinfer/attention/hopper/block_sparse_gather.cuh b/include/flashinfer/attention/hopper/block_sparse_gather.cuh new file mode 100644 index 000000000..29988a5bf --- /dev/null +++ b/include/flashinfer/attention/hopper/block_sparse_gather.cuh @@ -0,0 +1,196 @@ +/* + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * Modified by the FlashInfer team. + */ +#ifndef FLASHINFER_ATTENTION_HOPPER_BLOCK_SPARSE_GATHER_CUH +#define FLASHINFER_ATTENTION_HOPPER_BLOCK_SPARSE_GATHER_CUH + +#include + +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cute/util/print.hpp" +#include "cutlass/fast_math.h" + +namespace flashinfer { + +using namespace cute; + +template +struct BlockSparseIndexedGather { + CUTE_HOST_DEVICE constexpr BlockSparseIndexedGather(IdType const* indices) : indices_(indices) {} + + template + CUTE_HOST_DEVICE constexpr IdType operator()(I i) const { + // NOTE(Zihao): there is a risk of out-of-bound access, adding boundary check here + // would degrade performance significantly. It is the user's responsibility to ensure + // that (indptr[-2] + TILE_KV) is less than the size of the indices tensor. + return indices_[i]; + } + + CUTE_HOST_DEVICE friend void print(BlockSparseIndexedGather const& s) { + cute::print("BlockSparseIndexedGather"); + } + + IdType const* indices_; +}; + +/// Custom stride object that applies a function followed by a stride +template +struct CustomStride { + CUTE_HOST_DEVICE constexpr CustomStride(Func const& func, int stride_n) + : func_(func), stride_n_(stride_n) {} + + template + CUTE_HOST_DEVICE friend auto operator*(I i, CustomStride const& s) { + // uint64_t ret; + // #if defined(__CUDA_ARCH__) + // asm("{\n\t" + // "mul.wide.u32 %0, %1, %2;\n\t" + // "}" : "=l"(ret) : "r"(s.func_(i)), "r"(s.stride_n_)); + // #else + // ret = uint64_t(s.func_(i)) * uint64_t(s.stride_n_); + // #endif + // return ret; + + // NOTE(Zihao): if the tensor is larger than 64GB ((2 ** 32) * 16byte), we use + // 64-bit multiplication to avoid overflow. Otherwise, 32-bit multiplication is + // sufficient. + // There is a 20+ TFLOPs/s gap between 32-bit and 64-bit multiplication on H100. + return uint32_t(s.func_(i)) * s.stride_n_; + } + + template + CUTE_HOST_DEVICE friend auto operator*(CustomStride const& s, I i) { + // uint64_t ret; + // #if defined(__CUDA_ARCH__) + // asm("{\n\t" + // "mul.wide.u32 %0, %1, %2;\n\t" + // "}" : "=l"(ret) : "r"(s.func_(i)), "r"(s.stride_n_)); + // #else + // ret = uint64_t(s.func_(i)) * uint64_t(s.stride_n_); + // #endif + // return ret; + + // NOTE(Zihao): if the tensor is larger than 64GB = (2 ** 32) * 16byte (16byte is the + // element size after upcasting), we use 64-bit multiplication to avoid overflow. Otherwise, + // 32-bit multiplication is sufficient. + // There is a 20+ TFLOPs/s gap between 32-bit and 64-bit multiplication on H100. + return uint32_t(s.func_(i)) * s.stride_n_; + } + + CUTE_HOST_DEVICE friend void print(CustomStride const& s) { + cute::print("BlockSparseStride{"); + print(s.func_); + cute::print(","); + print(s.stride_n_); + cute::print("}"); + } + + template + CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div) { + return CustomStride(s.func_, safe_div(s.stride_n_, div)); + } + + // Circumvent the requirement on make_layout that shape and stride are integral + template + CUTE_HOST_DEVICE constexpr friend auto make_layout(Shape const& shape, + CustomStride const& stride) { + return Layout(shape, stride); + } + + Func func_; + uint32_t stride_n_; +}; + +template +CUTLASS_HOST_DEVICE auto make_custom_stride_layout(int stride_n, Func&& func) { + return make_layout(make_shape(_1{}, _1{}), + make_stride(CustomStride(static_cast(func), stride_n), _1{})); +} + +/// Helper function to optionally create a block sparse gather tensor +template +CUTLASS_HOST_DEVICE auto make_block_sparse_tensor(Iterator iter, Shape const& shape, int stride_n, + Func&& func) { + Layout matrix_layout = make_identity_layout(shape); + auto offset = as_arithmetic_tuple(repeat_like(shape, _0{})); + Layout gather_layout = make_custom_stride_layout(stride_n, static_cast(func)); + + return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout}); +} + +} // namespace flashinfer + +namespace cute { + +template +CUTE_HOST_DEVICE constexpr auto upcast(Shape const& shape, Stride const& stride) { + if constexpr (is_tuple::value) { + return transform_layout(shape, stride, + [](auto const& s, auto const& d) { return upcast(s, d); }); + } else if constexpr (is_scaled_basis::value) { + if constexpr (Stride::mode() == I) { + return make_layout(shape_div(shape, Int{}), shape_div(stride, Int{})); + } else { + return make_layout(shape, stride); + } + } else { + return upcast(shape, stride); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr auto upcast( + ComposedLayout, Offset, Layout> const& layout) { + // Find index of the stride-1 mode - that is the only one that requires updating inner shape and + // offset + auto idx = + find_if(layout.layout_a().stride(), [](auto x) { return is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + + // Upcast the outer layout (works as expected) + auto outer = upcast(layout.layout_a()); + + // Upcast the accumulated offset along stride-1 mode + auto offset = + as_arithmetic_tuple(replace(layout.offset(), upcast(get(layout.offset())))); + + // Upcast the inner layout's shape along stride-1 mode + auto inner = upcast(layout.layout_b().shape(), layout.layout_b().stride()); + + return composition(outer, offset, inner); +} + +} // namespace cute + +#endif // FLASHINFER_ATTENTION_HOPPER_BLOCK_SPARSE_GATHER_CUH diff --git a/include/flashinfer/attention/hopper/epilogue.cuh b/include/flashinfer/attention/hopper/epilogue.cuh new file mode 100644 index 000000000..7f8b5a32c --- /dev/null +++ b/include/flashinfer/attention/hopper/epilogue.cuh @@ -0,0 +1,259 @@ +/* + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri + * Dao. Licensed under the BSD 3-Clause. + * + * Modified by the FlashInfer team. + */ +#ifndef FLASHINFER_ATTENTION_HOPPER_EPILOGUE_CUH_ +#define FLASHINFER_ATTENTION_HOPPER_EPILOGUE_CUH_ + +#include + +#include "../../math.cuh" +#include "cute/tensor.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "named_barrier.cuh" +#include "utils.cuh" + +namespace flashinfer { + +using namespace cute; + +template +__forceinline__ __device__ void write_tiled(DTypeO* O, const TiledCopyO& tiled_copy_O, + const LayoutO& layout_O, const TileShapeO& tile_shape_O, + const SMemO& sO, int thread_idx, int qo_tile_idx, + int qo_head_idx, int qo_indptr, int64_t qo_len) { + Tensor mO = make_tensor(make_gmem_ptr(O + qo_indptr * stride<0>(layout_O)), layout_O); + Tensor gO = + get_local_tile_tensor(mO, tile_shape_O, qo_head_idx, 0, qo_len)(_, _, qo_tile_idx); // (O, D) + Tensor cO = cute::make_identity_tensor(gO.shape()); // (O, D) -> (o_idx, d_idx) + + ThrCopy thr_copy_O = tiled_copy_O.get_slice(thread_idx); + Tensor tOgO = thr_copy_O.partition_D(gO); // (CPY, CPY_O, CPY_D) + Tensor tOsO = thr_copy_O.partition_S(sO); // (CPY, CPY_O, CPY_D) + Tensor tOcO = thr_copy_O.partition_D(cO); // (CPY, CPY_O, CPY_D) + Tensor tOsOGroup = flatten_1(tOsO); // (CPY, (CPY_O, CPY_D)) + Tensor tOgOGroup = flatten_1(tOgO); // (CPY, (CPY_O, CPY_D)) + Tensor tOcOGroup = flatten_1(tOcO); // (CPY, (CPY_O, CPY_D)) + + const int qo_tile_size = get<0>(tile_shape_O); + int valid_qo_tile_size = std::min(qo_len - qo_tile_idx * qo_tile_size, qo_tile_size); + if (valid_qo_tile_size == qo_tile_size) { + copy(tiled_copy_O, tOsOGroup, tOgOGroup); + } else { + // copy if not out of bound + auto predicate_fn = [&](auto coords) { + auto s_coords = tOcOGroup(_0{}, coords); + return elem_less(get<0>(s_coords), valid_qo_tile_size); + }; + copy_if(tiled_copy_O, predicate_fn, tOsOGroup, tOgOGroup); + } +} + +template +__forceinline__ __device__ void write_O(ElemO* O, const TiledCopyO& tiled_copy_O, + const LayoutO& layout_O, const TileShapeO& tile_shape_O, + const SMemO& sO, int thread_idx, int qo_tile_idx, + int qo_head_idx, int qo_indptr, int qo_len, + int write_warp_idx) { + write_tiled(O, tiled_copy_O, layout_O, tile_shape_O, sO, thread_idx, + qo_tile_idx, qo_head_idx, qo_indptr, qo_len); +} + +template +struct CollectiveEpilogue { + using DTypeO = typename Ktraits::DTypeO; + static constexpr int CTA_Q = Ktraits::CTA_Q; + static constexpr int CTA_KV = Ktraits::CTA_KV; + static constexpr int HEAD_DIM = Ktraits::HEAD_DIM; + using TileShape_QKD = Shape, Int, Int>; + + static constexpr int NUM_WARPS = Ktraits::NUM_WARPS; + static constexpr int NUM_THREADS = NUM_WARPS * cutlass::NumThreadsPerWarp; + + static constexpr int NUM_COPY_THREADS = cutlass::NumThreadsPerWarpGroup; + static constexpr int NUM_MMA_THREADS = NUM_THREADS - NUM_COPY_THREADS; + + using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_QKD{})), + decltype(cute::get<2>(TileShape_QKD{}))>()); + using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_QKD{}))); + + using SmemCopyAtomO = Copy_Atom; + using SharedStorage = cute::array_aligned>; + + using ShapeT = cute::Shape; + using StrideT = cute::Shape; + using LayoutT = cute::Layout; + + using ShapeLseT = cute::Shape; + using StrideLseT = cute::Shape<_1, int64_t>; + using LayoutLseT = cute::Layout; + + using GmemTiledCopyOTMA = cute::SM90_TMA_STORE; + using TMA_O = decltype(make_tma_copy( + GmemTiledCopyOTMA{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeT{}, StrideT{}), SmemLayoutO{}, + select<0, 2>(TileShape_QKD{}), _1{})); // no mcast for O + + static constexpr int VEC_SIZE = cute::ceil_div(128, sizeof_bits_v); + static_assert(HEAD_DIM % VEC_SIZE == 0); + static constexpr int NUM_THREADS_PER_ROW = HEAD_DIM / VEC_SIZE; + static_assert(NUM_MMA_THREADS % NUM_THREADS_PER_ROW == 0); + static constexpr int NUM_ROWS = NUM_MMA_THREADS / NUM_THREADS_PER_ROW; + using TiledCopyOAtom = cute::Copy_Atom, DTypeO>; + using TiledCopyOThrLayout = decltype(cute::make_layout( + cute::make_shape(Int{}, Int{}), LayoutRight{})); + using TiledCopyOValLayout = + decltype(cute::make_layout(cute::make_shape(_1{}, Int{}), LayoutRight{})); + using TiledCopyO = + decltype(make_tiled_copy(TiledCopyOAtom{}, TiledCopyOThrLayout{}, // Thr layout + TiledCopyOValLayout{} // Val layout + )); + + // used for rmem -> smem O copy in fp8 kernel to undo column permutation + using ThreadLayoutrO = Layout, _4, _1>, Stride<_4, _32, _1, _0>>; + using ValueLayoutrO = + Layout, Int>, Stride<_0, _2, Stride<_4, _1>, _8>>; + using TiledCopyrO = decltype(make_tiled_copy(Copy_Atom, DTypeO>{}, + ThreadLayoutrO{}, ValueLayoutrO{})); + using TiledCopyShaperO = Shape<_8, Int, _16, Int>; + using SmemLayoutrO = decltype(composition(SmemLayoutO{}, Layout{})); + + // Host side kernel arguments + struct Arguments { + DTypeO* O_ptr; + LayoutT const layout_O; + float* lse_ptr; + LayoutLseT const layout_LSE; + }; + + // Device side kernel params + struct Params { + DTypeO* O_ptr; + LayoutT const layout_O; + float* lse_ptr; + LayoutLseT const layout_LSE; + }; + + static Params to_underlying_arguments(Arguments const& args) { + Tensor mO = make_tensor(make_gmem_ptr(args.O_ptr), args.layout_O); + return {args.O_ptr, args.layout_O, args.lse_ptr, args.layout_LSE}; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& epilogue_params) {} + + template + CUTLASS_DEVICE void store(Params const& epilogue_params, FrgTensorO const& tOrO, + FrgTensorLSE const& lse, SharedStorage& shared_storage, + TiledMma tiled_mma, int thread_idx, BlockCoord const& block_coord) { + auto [qo_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = + block_coord; + Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{}); + auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); + + Tensor tOrO_out = convert_type(tOrO); + Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // Make sure all WGs have finished reading V + cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS, + /*id=*/static_cast(NamedBarriers::kValueEmpty)); + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + cutlass::arch::NamedBarrier::arrive(NUM_MMA_THREADS + Ktraits::NUM_PRODUCER_THREADS, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.lse_ptr), epilogue_params.layout_LSE); + Tensor gLSE = get_lse_local_tile_tensor(mLSE, Shape>{}, qo_head_idx, qo_indptr, + qo_len)(_, qo_tile_idx); + Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_QKD{})); + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0, 0>(taccOcO))::value == 2); + static_assert(decltype(size<0, 1>(taccOcO))::value == 2); + // taccOcO has shape ((2, 2, V), MMA_M, MMA_K), we only take only the row indices. + Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{}); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (epilogue_params.lse_ptr) { // don't write to LSE if it's nullptr + if (get<1>(taccOcO_row(_0{})) == 0) { +#pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < qo_len - qo_tile_idx * CTA_Q) { + gLSE(row) = lse(mi); + } + } + } + } + + int write_warp_idx = NUM_WARPS - 1; + if (cutlass::canonical_warp_idx_sync() == write_warp_idx) { + cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS + Ktraits::NUM_PRODUCER_THREADS, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } + TiledCopyO gmem_tiled_copy_O; + write_O(epilogue_params.O_ptr, gmem_tiled_copy_O, epilogue_params.layout_O, + select<0, 2>(TileShape_QKD{}), sO, thread_idx, qo_tile_idx, + qo_head_idx, qo_indptr, qo_len, write_warp_idx); + } + + CUTLASS_DEVICE void store_tail() { + // tma_store_wait<0>(); + } + + // Write 0 to output and -inf to LSE + template + CUTLASS_DEVICE void store_zero(Params const& epilogue_params, SharedStorage& shared_storage, + int thread_idx, BlockCoord const& block_coord) { + auto [qo_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = + block_coord; + Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.O_ptr), epilogue_params.layout_O); + Tensor gO = get_local_tile_tensor(mO, select<0, 2>(TileShape_QKD{}), qo_head_idx, qo_indptr, + qo_len)(_, _, qo_tile_idx); // (O, D) + Tensor cO = cute::make_identity_tensor(gO.shape()); // (O, D) -> (o_idx, d_idx) + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.lse_ptr), epilogue_params.layout_LSE); + Tensor gLSE = get_lse_local_tile_tensor(mLSE, Shape>{}, qo_head_idx, qo_indptr, + qo_len)(_, qo_tile_idx); + + TiledCopyO tiled_copy_O; + auto thr_copy_O = tiled_copy_O.get_thread_slice(thread_idx); + Tensor tOgO = thr_copy_O.partition_D(gO); // (CPY, CPY_O, CPY_D) + Tensor tOrO = make_fragment_like(tOgO); // (CPY, CPY_O, CPY_D) + clear(tOrO); + Tensor tOcO = thr_copy_O.partition_D(cO); // (CPY, CPY_O, CPY_D) + Tensor tOgOGroup = flatten_1(tOgO); // (CPY, (CPY_O, CPY_D)) + Tensor tOrOGroup = flatten_1(tOrO); // (CPY, (CPY_O, CPY_D)) + Tensor tOcOGroup = flatten_1(tOcO); // (CPY, (CPY_O, CPY_D)) + + const int qo_tile_size = get<0>(TileShape_QKD{}); + int valid_qo_tile_size = std::min(qo_len - qo_tile_idx * qo_tile_size, qo_tile_size); + if (valid_qo_tile_size == qo_tile_size) { + copy(tiled_copy_O, tOrOGroup, tOgOGroup); + } else { + auto predicate_fn = [&](auto coords) { + auto s_coords = tOcOGroup(_0{}, coords); + return elem_less(get<0>(s_coords), valid_qo_tile_size); + }; + copy_if(tiled_copy_O, predicate_fn, tOrOGroup, tOgOGroup); + } + + static_assert(CTA_Q <= NUM_MMA_THREADS); + if (epilogue_params.lse_ptr) { // don't write to LSE if it's nullptr + if (thread_idx < qo_len - qo_tile_idx * CTA_Q) { + gLSE(thread_idx) = -math::inf; + } + } + } +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_EPILOGUE_CUH_ diff --git a/include/flashinfer/attention/hopper/kernel_traits.cuh b/include/flashinfer/attention/hopper/kernel_traits.cuh new file mode 100644 index 000000000..a144b708f --- /dev/null +++ b/include/flashinfer/attention/hopper/kernel_traits.cuh @@ -0,0 +1,120 @@ +/* + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri + * Dao. Licensed under the BSD 3-Clause. + * + * Modified by the FlashInfer team. + */ +#ifndef FLASHINFER_ATTENTION_HOPPER_KERNEL_TRAITS_CUH_ +#define FLASHINFER_ATTENTION_HOPPER_KERNEL_TRAITS_CUH_ + +#include + +#include "../../cutlass_utils.cuh" +#include "cute/algorithm/copy.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/layout/layout.h" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" + +namespace flashinfer { + +using namespace cute; + +template +struct SharedStorageQKVO { + cute::array_aligned> smem_q; + cute::array_aligned> smem_k; + union { + cute::array_aligned> smem_v; + cute::array_aligned> smem_o; + }; + struct { + cutlass::arch::ClusterTransactionBarrier barrier_Q; + cutlass::arch::ClusterBarrier barrier_O; + typename MainloopPipeline::SharedStorage pipeline_k; + typename MainloopPipeline::SharedStorage pipeline_v; + }; +}; + +template +struct AttentionKernelTraits { + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using IdType = IdType_; + using DTypeQKAccum = float; + + static constexpr int CTA_Q = CTA_Q_; + static_assert(CTA_Q % 64 == 0); + static constexpr int CTA_KV = CTA_KV_; + static constexpr int HEAD_DIM = HEAD_DIM_; + static_assert(HEAD_DIM % 32 == 0); + + static constexpr int NUM_WARPS = ((CTA_Q / 64) + 1) * 4; + static constexpr int NUM_THREADS = NUM_WARPS * cutlass::NumThreadsPerWarp; + // NOTE(Zihao): the following constant should only be used when TMA is enabled, + // where only one warp inside a warp group is used for TMA. + static constexpr int NUM_PRODUCER_THREADS = cutlass::NumThreadsPerWarp; + + using AttentionVariant = AttentionVariant_; + using TileShape_QKD = Shape, Int, Int>; + + static constexpr int NUM_STAGES = NUM_STAGES_; + + using AtomLayoutQKD = Layout, _1, _1>>; + using TiledMmaQK = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), AtomLayoutQKD{})); + using TiledMmaPV = decltype(cute::make_tiled_mma( + cute::GMMA::rs_op_selector(TileShape_QKD{})), GMMA::Major::K, + GMMA::Major::MN>(), + AtomLayoutQKD{})); + + static constexpr int NUM_MMA_THREADS = size(TiledMmaQK{}); + + using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, DTypeQ, decltype(cute::get<0>(TileShape_QKD{})), + decltype(cute::get<2>(TileShape_QKD{}))>()); + using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_QKD{}))); + + using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, DTypeKV, decltype(cute::get<1>(TileShape_QKD{})), + decltype(cute::get<2>(TileShape_QKD{}))>()); + using SmemLayoutK = decltype(tile_to_shape( + SmemLayoutAtomK{}, + make_shape(shape<1>(TileShape_QKD{}), shape<2>(TileShape_QKD{}), Int{}))); + + using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, DTypeKV, decltype(cute::get<1>(TileShape_QKD{})), + decltype(cute::get<2>(TileShape_QKD{}))>()); + using SmemLayoutV = decltype(tile_to_shape( + SmemLayoutAtomV{}, + make_shape(get<1>(TileShape_QKD{}), get<2>(TileShape_QKD{}), Int{}))); + + // Note this is the transpose in terms of the view, not in terms of memory. + using SmemLayoutVt = decltype(composition( + SmemLayoutV{}, make_ordered_layout(make_shape(get<2>(TileShape_QKD{}), + get<1>(TileShape_QKD{}), Int{}), + Step<_2, _1, _3>{}))); + + using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_QKD{})), + decltype(cute::get<2>(TileShape_QKD{}))>()); + using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_QKD{}))); + using MainloopPipeline = + std::conditional_t, + typename cutlass::PipelineAsync>; + using PipelineState = typename cutlass::PipelineState; + + using SharedStorage = SharedStorageQKVO; +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_KERNEL_TRAITS_CUH_ diff --git a/include/flashinfer/attention/hopper/mainloop.cuh b/include/flashinfer/attention/hopper/mainloop.cuh new file mode 100644 index 000000000..a6b561e5f --- /dev/null +++ b/include/flashinfer/attention/hopper/mainloop.cuh @@ -0,0 +1,266 @@ +/* + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri + * Dao. Licensed under the BSD 3-Clause. + * + * Modified by the FlashInfer team. + */ +#ifndef FLASHINFER_ATTENTION_HOPPER_MAINLOOP_CUH_ +#define FLASHINFER_ATTENTION_HOPPER_MAINLOOP_CUH_ + +#include +#include +#include +#include + +#include "../../math.cuh" +#include "cute/tensor.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "mainloop_mma.cuh" +#include "named_barrier.cuh" +#include "utils.cuh" + +namespace flashinfer { + +using namespace cute; + +template +struct CollectiveMainloop { + using DTypeQ = typename Ktraits::DTypeQ; + using DTypeKV = typename Ktraits::DTypeKV; + using TileShape_QKD = typename Ktraits::TileShape_QKD; + static constexpr int CTA_Q = get<0>(TileShape_QKD{}); + static constexpr int CTA_KV = get<1>(TileShape_QKD{}); + + static constexpr int NUM_STAGES = Ktraits::NUM_STAGES; + static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS; + static constexpr int HEAD_DIM = Ktraits::HEAD_DIM; + + using GmemTiledCopyQ = cute::SM90_TMA_LOAD; + using GmemTiledCopyKV = cute::SM90_TMA_LOAD; + + using SmemLayoutQ = typename Ktraits::SmemLayoutQ; + using SmemLayoutK = typename Ktraits::SmemLayoutK; + using SmemLayoutV = typename Ktraits::SmemLayoutV; + using SmemLayoutVt = typename Ktraits::SmemLayoutVt; + + using ShapeT = cute::Shape; + using StrideT = cute::Shape; // (N, D, H) + using LayoutT = cute::Layout; + + using ShapeLseT = cute::Shape; + using StrideLseT = cute::Shape<_1, int64_t>; + using LayoutLseT = cute::Layout; + + using TMA_Q = decltype(make_tma_copy( + GmemTiledCopyQ{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), + repeat_like(StrideT{}, int32_t(0)), StrideT{}), + SmemLayoutQ{}, select<0, 2>(TileShape_QKD{}), _1{})); // no mcast for Q + + using TMA_K = decltype(make_tma_copy( + GmemTiledCopyKV{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), + repeat_like(StrideT{}, int32_t(0)), StrideT{}), + take<0, 2>(SmemLayoutK{}), select<1, 2>(TileShape_QKD{}), _1{})); // no mcast + + using TMA_V = decltype(make_tma_copy( + GmemTiledCopyKV{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), + repeat_like(StrideT{}, int32_t(0)), StrideT{}), + take<0, 2>(SmemLayoutV{}), select<1, 2>(TileShape_QKD{}), _1{})); // no mcast + + static constexpr bool USE_TMA_LOAD_KV = true; + using MainloopPipeline = typename Ktraits::MainloopPipeline; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; + + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + static constexpr uint32_t TmaTransactionBytesQ = + static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesK = + static_cast(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v / 8); + + // Whether use scheduler barrier or hardware warp scheduler, using heuristic based on data type + // and head dim + static constexpr bool UseSchedulerBarrier = + cutlass::sizeof_bits_v == 8 ? HEAD_DIM >= 128 : HEAD_DIM <= 128; + using WarpScheduler = WarpScheduler; + + // Host side kernel arguments + struct Arguments { + DTypeQ const* Q_ptr; + LayoutT layout_Q; + DTypeKV const* K_ptr; + LayoutT layout_K; + DTypeKV const* V_ptr; + LayoutT layout_V; + int window_left; + float const logits_soft_cap; + float const sm_scale_log2; + }; + + // Device side kernel params + struct Params { + LayoutT layout_Q; + LayoutT layout_K; + LayoutT layout_V; + TMA_Q tma_load_Q; + TMA_K tma_load_K; + TMA_V tma_load_V; + int window_left; + float const logits_soft_cap; + float const sm_scale_log2; + }; + + static Params to_underlying_arguments(Arguments const& args) { + Tensor mQ = make_tensor(make_gmem_ptr(args.Q_ptr), args.layout_Q); + TMA_Q tma_load_Q = make_tma_copy(GmemTiledCopyQ{}, mQ, SmemLayoutQ{}, + select<0, 2>(TileShape_QKD{}), _1{}); // no mcast for Q + Tensor mK = make_tensor(make_gmem_ptr(args.K_ptr), args.layout_K); + TMA_K tma_load_K = make_tma_copy(GmemTiledCopyKV{}, mK, SmemLayoutK{}(_, _, _0{}), + select<1, 2>(TileShape_QKD{}), _1{}); // no mcast + Tensor mV = make_tensor(make_gmem_ptr(args.V_ptr), args.layout_V); + TMA_V tma_load_V = make_tma_copy(GmemTiledCopyKV{}, mV, SmemLayoutV{}(_, _, _0{}), + select<1, 2>(TileShape_QKD{}), _1{}); // no mcast + return {args.layout_Q, args.layout_K, args.layout_V, tma_load_Q, tma_load_K, + tma_load_V, args.window_left, args.logits_soft_cap, args.sm_scale_log2}; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_K.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_V.get_tma_descriptor()); + } + + CUTLASS_DEVICE + int get_num_kv_tiles(Params const& mainloop_params, int q_tile_idx, const int qo_len, + const int kv_len) { + static constexpr int CTA_Q = get<0>(TileShape_QKD{}); + static constexpr int CTA_KV = get<1>(TileShape_QKD{}); + int num_kv_tiles = cute::ceil_div(kv_len, CTA_KV); + if constexpr (CAUSAL) { + num_kv_tiles = std::min(num_kv_tiles, + cute::ceil_div((q_tile_idx + 1) * CTA_Q + kv_len - qo_len, CTA_KV)); + } + + return num_kv_tiles; + } + + template + CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline_k, + MainloopPipeline pipeline_v, PipelineState& smem_pipe_write_k, + PipelineState& smem_pipe_write_v, SharedStorage& shared_storage, + Scheduler& scheduler, typename Scheduler::Params const& scheduler_params, + typename Scheduler::WorkTileInfo& work_tile_info, + BlockCoord const& block_coord, int work_idx) { + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{}); + + Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape()); + Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape()); + Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape()); + + auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = block_coord; + + // Prepare the TMA loads + Tensor gQ = get_local_tile_tensor(mQ, select<0, 2>(TileShape_QKD{}), qo_head_idx, qo_indptr, + qo_len)(_, _, q_tile_idx); // (Q, D) + Tensor gK = get_local_tile_tensor(mK, select<1, 2>(TileShape_QKD{}), kv_head_idx, kv_indptr, + kv_len); // (K, D, _) + Tensor gV = get_local_tile_tensor(mV, select<1, 2>(TileShape_QKD{}), kv_head_idx, kv_indptr, + kv_len); // (K, D, _) + + Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{})); + Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{})); + auto [tQgQ, tQsQ] = + tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{}, group_modes<0, 2>(sQ_x), + group_modes<0, 2>(gQ_x)); // (TMA), (TMA) + auto [tKgK, tKsK] = + tma_partition(mainloop_params.tma_load_K, _0{}, Layout<_1>{}, group_modes<0, 2>(sK), + group_modes<0, 2>(gK)); // (TMA, k), (TMA, PIPE) + auto [tVgV, tVsV] = + tma_partition(mainloop_params.tma_load_V, _0{}, Layout<_1>{}, group_modes<0, 2>(sV), + group_modes<0, 2>(gV)); // (TMA, k), (TMA, PIPE) + + int num_kv_tiles = get_num_kv_tiles(mainloop_params, q_tile_idx, qo_len, kv_len); + int kv_tile_idx = num_kv_tiles - 1; + int swa_begin_kv_tile_idx = 0; + if constexpr (LEFT_SLIDING_WINDOW) { + swa_begin_kv_tile_idx = get_swa_begin_kv_tile_idx(mainloop_params.window_left, + q_tile_idx, qo_len, kv_len); + } + + int lane_predicate = cute::elect_one_sync(); + if (lane_predicate) { + pipeline_k.producer_acquire(smem_pipe_write_k); + copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), + /*mcast_mask=*/0), + tKgK(_, kv_tile_idx), tKsK(_, smem_pipe_write_k.index())); + ++smem_pipe_write_k; + } + + // Wait for the MMA warpgroups to say that smem_q is ready + cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS + Ktraits::NUM_PRODUCER_THREADS, + static_cast(NamedBarriers::kQueryEmpty)); + + if (lane_predicate) { + shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); + copy(mainloop_params.tma_load_Q.with( + reinterpret_cast( + shared_storage.barrier_Q), + /*mcast_mask=*/0), + tQgQ, tQsQ); + } + + // Wait for warp 1 to signal that smem_v are ready and V can be copied from gmem + // Need ClusterBarrier, not just NamedBarrier. Otherwise we might have CTA 0 finishing the + // TMA store on O first, call TMA multicast load on V, before CTA 1 can finishing TMA store on + // O. + shared_storage.barrier_O.wait((work_idx + 1) % 2); + + if (lane_predicate) { +#pragma unroll 2 + for (; kv_tile_idx > swa_begin_kv_tile_idx; --kv_tile_idx) { + pipeline_k.producer_acquire(smem_pipe_write_k); + copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), + /*mcast_mask=*/0), + tKgK(_, kv_tile_idx - 1), tKsK(_, smem_pipe_write_k.index())); + ++smem_pipe_write_k; + pipeline_v.producer_acquire(smem_pipe_write_v); + copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), + /*mcast_mask=*/0), + tVgV(_, kv_tile_idx), tVsV(_, smem_pipe_write_v.index())); + ++smem_pipe_write_v; + } + } + scheduler.prefetch_next_work(scheduler_params, work_tile_info); + if (lane_predicate) { + pipeline_v.producer_acquire(smem_pipe_write_v); + copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), + /*mcast_mask=*/0), + tVgV(_, kv_tile_idx), tVsV(_, smem_pipe_write_v.index())); + ++smem_pipe_write_v; + } + scheduler.broadcast_next_work(work_tile_info); + } + + CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v, + PipelineState& smem_pipe_write_k, + PipelineState& smem_pipe_write_v) { + int lane_predicate = cute::elect_one_sync(); + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + if (warp_idx_in_warpgroup == 0 && lane_predicate) { + pipeline_k.producer_tail(smem_pipe_write_k); + pipeline_v.producer_tail(smem_pipe_write_v); + } + } +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_MAINLOOP_CUH_ diff --git a/include/flashinfer/attention/hopper/mainloop_mma.cuh b/include/flashinfer/attention/hopper/mainloop_mma.cuh new file mode 100644 index 000000000..b98df9e0c --- /dev/null +++ b/include/flashinfer/attention/hopper/mainloop_mma.cuh @@ -0,0 +1,265 @@ +/* + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri + * Dao. Licensed under the BSD 3-Clause. + * + * Modified by the FlashInfer team. + */ +#ifndef FLASHINFER_ATTENTION_HOPPER_MAINLOOP_MMA_CUH_ +#define FLASHINFER_ATTENTION_HOPPER_MAINLOOP_MMA_CUH_ + +#include +#include +#include +#include + +namespace flashinfer { + +template +CUTLASS_DEVICE void mma_f16(const Params& mainloop_params, AttentionVariant& variant, + MainloopPipeline pipeline_k, MainloopPipeline pipeline_v, + PipelineState& smem_pipe_read_k, PipelineState& smem_pipe_read_v, + FrgTensorO& tOrO, AttentionUpdater& attention_updater, + int kv_tile_idx_count, int swa_begin_kv_tile_idx, + int swa_end_kv_tile_idx, int thread_idx, int work_idx, int q_tile_idx, + SharedStorage& shared_storage, const int32_t qo_len, + const int32_t kv_len, const int32_t qo_head_idx, + const int32_t kv_head_idx) { + using DTypeQ = typename Ktraits::DTypeQ; + using DTypeKV = typename Ktraits::DTypeKV; + using IdType = typename Ktraits::IdType; + using TileShape_QKD = typename Ktraits::TileShape_QKD; + static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS; + using SmemLayoutQ = typename Ktraits::SmemLayoutQ; + using SmemLayoutK = typename Ktraits::SmemLayoutK; + using SmemLayoutV = typename Ktraits::SmemLayoutV; + using SmemLayoutVt = typename Ktraits::SmemLayoutVt; + static_assert(is_rmem::value, "O tensor must be rmem resident."); + + static constexpr int CTA_Q = get<0>(TileShape_QKD{}); + static constexpr int CTA_KV = get<1>(TileShape_QKD{}); + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); + Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{}); + + typename Ktraits::TiledMmaQK tiled_mma_qk; + typename Ktraits::TiledMmaPV tiled_mma_pv; + auto threadMmaQK = tiled_mma_qk.get_thread_slice(thread_idx); + auto threadMmaPV = tiled_mma_pv.get_thread_slice(thread_idx); + + Tensor tSrQ = threadMmaQK.partition_fragment_A(sQ); + Tensor tSrK = threadMmaQK.partition_fragment_B(sK); + Tensor tOrV = threadMmaPV.partition_fragment_B(sVt); + + auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + }; + + tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; + int kv_tile_idx = kv_tile_idx_count - 1; + + cutlass::ConsumerToken barrier_token = + static_cast(shared_storage.barrier_Q.try_wait(work_idx % 2)); + if (barrier_token == cutlass::BarrierStatus::WaitAgain) { + shared_storage.barrier_Q.wait(work_idx % 2); + } + + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{})); + consumer_wait(pipeline_k, smem_pipe_read_k); + + WarpScheduler::barrier_sync(); + gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), + tSrS); + WarpScheduler::barrier_arrive(); + + if (work_idx != 0) { + int lane_predicate = cute::elect_one_sync(); + if (cutlass::canonical_warp_idx_sync() == Ktraits::NUM_WARPS - 1 && lane_predicate) { +#pragma unroll + for (uint32_t cta_id = 0; cta_id < 1; ++cta_id) { + shared_storage.barrier_O.arrive(cta_id, lane_predicate); + } + } + } + warpgroup_wait<0>(); + pipeline_k.consumer_release(smem_pipe_read_k); + ++smem_pipe_read_k; + + auto col_limit_right = [&](int qo_idx) { return qo_idx + 1 + kv_len - qo_len; }; + auto col_limit_left = [&](int qo_idx) { + return qo_idx + kv_len - qo_len - mainloop_params.window_left; + }; + { + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{})); + Tensor tScS = threadMmaQK.partition_C(cS); +#pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; + int kv_idx = get<1>(tScS(i)) + kv_tile_idx * CTA_KV; + tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, + qo_head_idx, kv_head_idx); + if constexpr (!CAUSAL) { // Just masking based on col + if (kv_idx >= kv_len) { + tSrS(i) = -math::inf; + } + } else { + if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) { + tSrS(i) = -math::inf; + } + } + if constexpr (LEFT_SLIDING_WINDOW) { + if (kv_idx < col_limit_left(qo_idx)) { + tSrS(i) = -math::inf; + } + } + } + } + + attention_updater.update(tSrS); + Tensor tOrP = make_tensor(convert_type(tSrS).data(), + convert_layout_acc_Aregs(tSrS.layout())); + + constexpr int n_masking_steps = CAUSAL ? cute::ceil_div(CTA_Q, CTA_KV) : 0; + // masking loops +#pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps && kv_tile_idx > swa_begin_kv_tile_idx; + ++masking_step, --kv_tile_idx) { + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{})); + consumer_wait(pipeline_k, smem_pipe_read_k); + WarpScheduler::barrier_sync(); + gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), + tSrS); + if (masking_step > 0) { + attention_updater.rescale_o(tOrO); + } + consumer_wait(pipeline_v, smem_pipe_read_v); + gemm(tiled_mma_pv, tOrP, + tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + WarpScheduler::barrier_arrive(); + warpgroup_wait<1>(); + pipeline_k.consumer_release(smem_pipe_read_k); // release K + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{})); + Tensor tScS = threadMmaQK.partition_C(cS); +#pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; + int kv_idx = get<1>(tScS(i)) + (kv_tile_idx - 1) * CTA_KV; + tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, + qo_head_idx, kv_head_idx); + if (kv_idx >= col_limit_right(qo_idx)) { + tSrS(i) = -math::inf; + } + if constexpr (LEFT_SLIDING_WINDOW) { + if (kv_idx < col_limit_left(qo_idx)) { + tSrS(i) = -math::inf; + } + } + } + attention_updater.update(tSrS); + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V + ++smem_pipe_read_k; + ++smem_pipe_read_v; + cute::copy(make_tensor(convert_type(tSrS).data(), + convert_layout_acc_Aregs(tSrS.layout())), + tOrP); + } + +#pragma unroll 1 + for (; kv_tile_idx > swa_end_kv_tile_idx + 1; --kv_tile_idx) { + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{})); + consumer_wait(pipeline_k, smem_pipe_read_k); + WarpScheduler::barrier_sync(); + gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), + tSrS); + attention_updater.rescale_o(tOrO); + consumer_wait(pipeline_v, smem_pipe_read_v); + gemm(tiled_mma_pv, tOrP, + tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + WarpScheduler::barrier_arrive(); + warpgroup_wait<1>(); + pipeline_k.consumer_release(smem_pipe_read_k); // release K + // #pragma unroll + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{})); + Tensor tScS = threadMmaQK.partition_C(cS); +#pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; + int kv_idx = get<1>(tScS(i)) + (kv_tile_idx - 1) * CTA_KV; + tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, + qo_head_idx, kv_head_idx); + } + attention_updater.update(tSrS); + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V + ++smem_pipe_read_k; + ++smem_pipe_read_v; + cute::copy(make_tensor(convert_type(tSrS).data(), + convert_layout_acc_Aregs(tSrS.layout())), + tOrP); + } + + if constexpr (LEFT_SLIDING_WINDOW) { + constexpr int n_swa_masking_steps = cute::ceil_div(CTA_Q, CTA_KV) + 1; +#pragma unroll + for (int masking_step = 0; + masking_step < n_swa_masking_steps && kv_tile_idx > swa_begin_kv_tile_idx; + ++masking_step, --kv_tile_idx) { + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{})); + consumer_wait(pipeline_k, smem_pipe_read_k); + WarpScheduler::barrier_sync(); + gemm(tiled_mma_qk, tSrQ, + tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); + attention_updater.rescale_o(tOrO); + consumer_wait(pipeline_v, smem_pipe_read_v); + gemm(tiled_mma_pv, tOrP, + tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + WarpScheduler::barrier_arrive(); + warpgroup_wait<1>(); + pipeline_k.consumer_release(smem_pipe_read_k); // release K + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{})); + Tensor tScS = threadMmaQK.partition_C(cS); +#pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; + int kv_idx = get<1>(tScS(i)) + (kv_tile_idx - 1) * CTA_KV; + tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, + qo_head_idx, kv_head_idx); + if (kv_idx < col_limit_left(qo_idx)) { + tSrS(i) = -math::inf; + } + } + attention_updater.update(tSrS); + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V + ++smem_pipe_read_k; + ++smem_pipe_read_v; + cute::copy(make_tensor(convert_type(tSrS).data(), + convert_layout_acc_Aregs(tSrS.layout())), + tOrP); + } + } + + // Tell warp 0 that smem_q is ready + cutlass::arch::NamedBarrier::arrive(NUM_MMA_THREADS + Ktraits::NUM_PRODUCER_THREADS, + /*id=*/static_cast(NamedBarriers::kQueryEmpty)); + attention_updater.rescale_o(tOrO); + consumer_wait(pipeline_v, smem_pipe_read_v); + gemm(tiled_mma_pv, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), + tOrO); + attention_updater.finalize(tSrS); + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V, otherwise producers will hang + ++smem_pipe_read_v; + + attention_updater.rescale_o(tOrO); + return; +} + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_MAINLOOP_MMA_CUH_ diff --git a/include/flashinfer/attention/hopper/named_barrier.cuh b/include/flashinfer/attention/hopper/named_barrier.cuh new file mode 100644 index 000000000..8ba3b3a08 --- /dev/null +++ b/include/flashinfer/attention/hopper/named_barrier.cuh @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri + * Dao. Licensed under the BSD 3-Clause. + * + * Modified by the FlashInfer team. + */ +#ifndef FLASHINFER_ATTENTION_HOPPER_NAMED_BARRIERS_CUH_ +#define FLASHINFER_ATTENTION_HOPPER_NAMED_BARRIERS_CUH_ + +#include + +#include "cutlass/arch/barrier.h" +#include "cutlass/cutlass.h" + +namespace flashinfer { + +// Enumerates the reserved named barriers to avoid potential conflicts + +enum class NamedBarriers { + kQueryEmpty = 0, + kValueEmpty = 1, + kWarpSchedulerWG1 = 2, + kWarpSchedulerWG2 = 3, + kWarpSchedulerWG3 = 4, + kPrefetchIndices = 5, +}; + +__device__ __forceinline__ int get_warp_group_barrier_idx(int warp_group_idx) { + return static_cast(NamedBarriers::kWarpSchedulerWG1) + warp_group_idx - 1; +} + +template +__device__ __forceinline__ int get_next_consumer_warp_group_idx() { + static_assert(num_consumer_warp_groups == 2 || num_consumer_warp_groups == 3); + int warp_group_idx = cutlass::canonical_warp_group_idx(); + if constexpr (num_consumer_warp_groups == 2) { + // 1 -> 2, 2 -> 1 + return 3 - warp_group_idx; + } else { + // num_consumer_warp_groups == 3 + // 1 -> 2, 2 -> 3, 3 -> 1 + return (warp_group_idx % 3) + 1; + } +} + +template +__device__ __forceinline__ int get_prev_consumer_warp_group_idx() { + static_assert(num_consumer_warp_groups == 2 || num_consumer_warp_groups == 3); + int warp_group_idx = cutlass::canonical_warp_group_idx(); + if constexpr (num_consumer_warp_groups == 2) { + // 1 -> 2, 2 -> 1 + return 3 - warp_group_idx; + } else { + // num_consumer_warp_groups == 3 + // 1 -> 3, 2 -> 1, 3 -> 2 + return ((warp_group_idx + 1) % 3) + 1; + } +} + +template +struct WarpScheduler { + constexpr static int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS; + static CUTLASS_DEVICE void barrier_sync() { + if constexpr (UseSchedulerBarrier) { + cutlass::arch::NamedBarrier::sync( + NUM_MMA_THREADS, get_warp_group_barrier_idx(cutlass::canonical_warp_group_idx())); + } + } + + static CUTLASS_DEVICE void barrier_arrive() { + if constexpr (!UseSchedulerBarrier) { + return; + } + static_assert(NUM_MMA_THREADS == 2 * cutlass::NumThreadsPerWarpGroup || + NUM_MMA_THREADS == 3 * cutlass::NumThreadsPerWarpGroup); + if constexpr (NUM_MMA_THREADS == 2 * cutlass::NumThreadsPerWarpGroup) { + cutlass::arch::NamedBarrier::arrive( + NUM_MMA_THREADS, get_warp_group_barrier_idx(get_next_consumer_warp_group_idx<2>())); + } else { + cutlass::arch::NamedBarrier::arrive( + NUM_MMA_THREADS, get_warp_group_barrier_idx(get_next_consumer_warp_group_idx<3>())); + cutlass::arch::NamedBarrier::arrive( + NUM_MMA_THREADS, get_warp_group_barrier_idx(get_prev_consumer_warp_group_idx<3>())); + } + } + + static CUTLASS_DEVICE void mma_init() { + // Tell producer (warp 0) that smem_q is ready + cutlass::arch::NamedBarrier::arrive(NUM_MMA_THREADS + Ktraits::NUM_PRODUCER_THREADS, + /*id=*/static_cast(NamedBarriers::kQueryEmpty)); + if constexpr (!UseSchedulerBarrier) { + return; + } + static_assert(NUM_MMA_THREADS == 2 * cutlass::NumThreadsPerWarpGroup || + NUM_MMA_THREADS == 3 * cutlass::NumThreadsPerWarpGroup); + if (cutlass::canonical_warp_group_idx() > 1) { + cutlass::arch::NamedBarrier::arrive( + NUM_MMA_THREADS, /*id=*/static_cast(NamedBarriers::kWarpSchedulerWG1)); + } + if constexpr (NUM_MMA_THREADS == 3 * cutlass::NumThreadsPerWarpGroup) { + if (cutlass::canonical_warp_group_idx() > 2) { + cutlass::arch::NamedBarrier::arrive( + NUM_MMA_THREADS, /*id=*/static_cast(NamedBarriers::kWarpSchedulerWG2)); + } + } + } + +}; // struct WarpScheduler + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_NAMED_BARRIERS_CUH_ diff --git a/include/flashinfer/attention/hopper/params.cuh b/include/flashinfer/attention/hopper/params.cuh new file mode 100644 index 000000000..fcd80a956 --- /dev/null +++ b/include/flashinfer/attention/hopper/params.cuh @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_ATTENTION_HOPPER_PARAMS_CUH +#define FLASHINFER_ATTENTION_HOPPER_PARAMS_CUH + +#include + +#include + +namespace flashinfer { + +template +struct SinglePrefillParams { + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + // The QKV matrices. + DTypeQ* q_ptr; + DTypeKV* k_ptr; + DTypeKV* v_ptr; + DTypeO* o_ptr; + float* lse_ptr; + + int64_t q_stride_n; + int64_t k_stride_n; + int64_t v_stride_n; + int64_t o_stride_n; + int64_t q_stride_h; + int64_t k_stride_h; + int64_t v_stride_h; + int64_t o_stride_h; + + int qo_len; + int kv_len; + int head_dim; + int num_qo_heads; + int num_kv_heads; + int group_size; + int window_left; + + float logits_soft_cap; + float sm_scale_log2; + bool causal; + + struct AdditionalParams {}; +}; + +template +struct BatchPrefillRaggedParams { + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using IdType = IdType_; + // The QKV matrices. + DTypeQ* q_ptr; + DTypeKV* k_ptr; + DTypeKV* v_ptr; + DTypeO* o_ptr; + float* lse_ptr; + + IdType* qo_tile_indices; + IdType* qo_indptr; + IdType* kv_indptr; + IdType* qo_lens; + IdType* kv_lens; + IdType* head_indices; + IdType* work_indptr; + + int64_t q_stride_n; + int64_t k_stride_n; + int64_t v_stride_n; + int64_t o_stride_n; + int64_t q_stride_h; + int64_t k_stride_h; + int64_t v_stride_h; + int64_t o_stride_h; + int64_t nnz_qo; + int64_t nnz_kv; + + int head_dim; + int num_qo_heads; + int num_kv_heads; + int group_size; + int window_left; + + float logits_soft_cap; + float sm_scale_log2; + bool causal; + + struct AdditionalParams {}; +}; + +template +struct BatchPrefillPagedParams { + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using IdType = IdType_; + // The QKV matrices. + DTypeQ* q_ptr; + DTypeKV* k_ptr; + DTypeKV* v_ptr; + DTypeO* o_ptr; + float* lse_ptr; + + IdType* qo_tile_indices; + IdType* qo_indptr; + IdType* kv_indptr; + IdType* kv_indices; + IdType* qo_lens; + IdType* kv_lens; + IdType* head_indices; + IdType* work_indptr; + + int64_t q_stride_n; + int64_t k_stride_n; + int64_t v_stride_n; + int64_t o_stride_n; + int64_t q_stride_h; + int64_t k_stride_h; + int64_t v_stride_h; + int64_t o_stride_h; + int64_t nnz_qo; + + int head_dim; + int num_qo_heads; + int num_kv_heads; + int group_size; + int page_size; + int window_left; + + float logits_soft_cap; + float sm_scale_log2; + bool causal; + + struct AdditionalParams {}; +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_PARAMS_CUH diff --git a/include/flashinfer/attention/hopper/prefill_sm90.cuh b/include/flashinfer/attention/hopper/prefill_sm90.cuh new file mode 100644 index 000000000..708f80f35 --- /dev/null +++ b/include/flashinfer/attention/hopper/prefill_sm90.cuh @@ -0,0 +1,524 @@ +/* + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri + * Dao. Licensed under the BSD 3-Clause. + * + * Modified by the FlashInfer team. + */ +#ifndef FLASHINFER_ATTENTION_HOPPER_PREFILL_SM90_CUH_ +#define FLASHINFER_ATTENTION_HOPPER_PREFILL_SM90_CUH_ + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../../cutlass_utils.cuh" +#include "../../exception.h" +#include "../mask.cuh" +#include "cute/tensor.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "epilogue.cuh" +#include "kernel_traits.cuh" +#include "mainloop.cuh" +#include "mainloop_mma.cuh" +#include "params.cuh" +#include "sparse_mainloop.cuh" +#include "tile_scheduler.cuh" +#include "utils.cuh" + +namespace flashinfer { + +using namespace cute; + +template +__global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp, 1) + PrefillWithKVCacheKernel(CUTE_GRID_CONSTANT + typename CollectiveMainloop::Params const mainloop_params, + CUTE_GRID_CONSTANT + typename CollectiveEpilogue::Params const epilogue_params, + CUTE_GRID_CONSTANT + typename TileScheduler::Params const scheduler_params) { + using DTypeQ = typename Ktraits::DTypeQ; + using DTypeKV = typename Ktraits::DTypeKV; + using DTypeO = typename Ktraits::DTypeO; + using DTypeQKAccum = typename Ktraits::DTypeQKAccum; + using TileShape_QKD = typename Ktraits::TileShape_QKD; + using AttentionVariant = typename Ktraits::AttentionVariant; + AttentionVariant variant(mainloop_params); + + static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS; + static constexpr int NUM_COPY_THREADS = cutlass::NumThreadsPerWarpGroup; + static constexpr int CTA_Q = Ktraits::CTA_Q; + static constexpr int CTA_KV = Ktraits::CTA_KV; + + static constexpr bool use_tma_load_kv = CollectiveMainloop::USE_TMA_LOAD_KV; + + using AttentionUpdater = + typename AttentionVariant::template Updater<2 * (2 * CTA_Q / NUM_MMA_THREADS)>; + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; + + extern __shared__ char shared_memory[]; + auto& shared_storage = *reinterpret_cast(shared_memory); + + int const lane_predicate = cute::elect_one_sync(); + int const warp_idx = cutlass::canonical_warp_idx_sync(); + + // Issue Tma Descriptor Prefetch from a single thread + if (warp_idx == 0 && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(mainloop_params); + CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params); + } + + // Obtain warp index + int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + + PipelineParams pipeline_params; + int warp_group_idx = cutlass::canonical_warp_group_idx(); + pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer + : MainloopPipeline::ThreadCategory::Consumer; + if constexpr (use_tma_load_kv) { + pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; + pipeline_params.is_leader = warp_group_thread_idx == 0; + pipeline_params.num_consumers = NUM_MMA_THREADS; + } else { + pipeline_params.producer_arv_count = NUM_COPY_THREADS; + pipeline_params.consumer_arv_count = NUM_MMA_THREADS; + } + + if (warp_idx == 0 && lane_predicate) { + shared_storage.barrier_Q.init(/*num_threads=*/1); + shared_storage.barrier_O.init(/*num_threads=*/1); + } + // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init(); + MainloopPipeline pipeline_k = [&] { + if constexpr (use_tma_load_kv) { + return MainloopPipeline(shared_storage.pipeline_k, pipeline_params, + /*cluster_shape=*/Shape<_1, _1, _1>{}); + } else { + return MainloopPipeline(shared_storage.pipeline_k, pipeline_params); + } + }(); + + MainloopPipeline pipeline_v = [&] { + if constexpr (use_tma_load_kv) { + return MainloopPipeline(shared_storage.pipeline_v, pipeline_params, + /*cluster_shape=*/Shape<_1, _1, _1>{}); + } else { + return MainloopPipeline(shared_storage.pipeline_v, pipeline_params); + } + }(); + + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue; + + // We need this to guarantee that the Pipeline init is visible to all producers and consumer + // blocks in the Cluster + __syncthreads(); + + if (warp_group_idx == 0) { // Producer + if constexpr (use_tma_load_kv) { + cutlass::arch::warpgroup_reg_dealloc(); + } else { + cutlass::arch::warpgroup_reg_dealloc<72>(); + } + + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + if (!use_tma_load_kv || warp_idx_in_warpgroup == 0) { // Load Q, K, V + PipelineState smem_pipe_write_k = cutlass::make_producer_start_state(); + PipelineState smem_pipe_write_v = cutlass::make_producer_start_state(); + + int work_idx = 0; + + TileScheduler scheduler; + for (auto work_tile_info = scheduler.get_initial_work(scheduler_params); + work_tile_info.is_valid(scheduler_params); + work_tile_info = scheduler.template get_next_work( + scheduler_params, work_tile_info)) { + auto block_coord = work_tile_info.get_block_coord(scheduler_params); + auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = + block_coord; + + if (q_tile_idx * CTA_Q >= qo_len) { + continue; + } + int num_kv_tiles = + collective_mainloop.get_num_kv_tiles(mainloop_params, q_tile_idx, qo_len, kv_len); + if (num_kv_tiles <= 0) { + scheduler.prefetch_next_work(scheduler_params, work_tile_info); + scheduler.broadcast_next_work(work_tile_info); + continue; + } + collective_mainloop.load( + mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v, + shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx); + ++work_idx; + } + collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v); + } + } else { // Consumer + if constexpr (use_tma_load_kv) { + cutlass::arch::warpgroup_reg_alloc(); + } else { + cutlass::arch::warpgroup_reg_alloc(); + } + + TileScheduler scheduler; + // Initialize matmul objects. + typename Ktraits::TiledMmaPV tiled_mma_pv; + + PipelineState smem_pipe_read_k, smem_pipe_read_v; + // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v + // (like in Cutlass's gemm) because the read and release pipeline states are always the same. + + CollectiveMainloop::WarpScheduler::mma_init(); + scheduler.init_consumer(); + + int work_idx = 0; + CUTLASS_PRAGMA_NO_UNROLL + for (auto work_tile_info = scheduler.get_initial_work(scheduler_params); + work_tile_info.is_valid(scheduler_params); + work_tile_info = scheduler.template get_next_work(scheduler_params, + work_tile_info)) { + // Attention output (GEMM-II) accumulator. + Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 2>(TileShape_QKD{})); + AttentionUpdater attention_updater(mainloop_params.sm_scale_log2); + + auto block_coord = work_tile_info.get_block_coord(scheduler_params); + auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = + block_coord; + + if (q_tile_idx * CTA_Q >= qo_len) { + continue; + } + int num_kv_tiles = + collective_mainloop.get_num_kv_tiles(mainloop_params, q_tile_idx, qo_len, kv_len); + if (num_kv_tiles <= 0) { // We exit early and write 0 to gO and -inf to gLSE. + collective_epilogue.store_zero(epilogue_params, shared_storage, + threadIdx.x - NUM_COPY_THREADS, block_coord); + continue; + } + + int swa_begin_kv_tile_idx = 0; + int swa_end_kv_tile_idx = -1; + if constexpr (LEFT_SLIDING_WINDOW) { + swa_begin_kv_tile_idx = get_swa_begin_kv_tile_idx( + mainloop_params.window_left, q_tile_idx, qo_len, kv_len); + swa_end_kv_tile_idx = get_swa_end_kv_tile_idx(mainloop_params.window_left, + q_tile_idx, qo_len, kv_len); + } + + mma_f16( + mainloop_params, variant, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v, + tOrO, attention_updater, num_kv_tiles, swa_begin_kv_tile_idx, swa_end_kv_tile_idx, + threadIdx.x - NUM_COPY_THREADS, work_idx, q_tile_idx, shared_storage, qo_len, kv_len, + qo_head_idx, kv_head_idx); + collective_epilogue.store(epilogue_params, tOrO, attention_updater.get_lse(), shared_storage, + tiled_mma_pv, threadIdx.x - NUM_COPY_THREADS, block_coord); + + ++work_idx; + } + collective_epilogue.store_tail(); + } +} + +template +cudaError_t SinglePrefillWithKVCacheKernelTraitsDispatched( + SinglePrefillParams& params, + cudaStream_t stream) { + using DTypeQ = typename KernelTraits::DTypeQ; + using DTypeKV = typename KernelTraits::DTypeKV; + using DTypeO = typename KernelTraits::DTypeO; + using TileShape_QKD = typename KernelTraits::TileShape_QKD; + + using CollectiveMainloop = CollectiveMainloop; + using CollectiveEpilogue = CollectiveEpilogue; + using Scheduler = SingleTileScheduler; + typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments( + {params.q_ptr, + get_gmem_layout(params.qo_len, params.num_qo_heads, params.head_dim, params.q_stride_n, + params.q_stride_h), // layout_Q + params.k_ptr, + get_gmem_layout(params.kv_len, params.num_kv_heads, params.head_dim, params.k_stride_n, + params.k_stride_h), // layout_K + params.v_ptr, + get_gmem_layout(params.kv_len, params.num_kv_heads, params.head_dim, params.v_stride_n, + params.v_stride_h), // layout_V + params.window_left, params.logits_soft_cap, params.sm_scale_log2}); + typename CollectiveEpilogue::Params epilogue_params = + CollectiveEpilogue::to_underlying_arguments({ + static_cast(params.o_ptr), + get_gmem_layout(params.qo_len, params.num_qo_heads, params.head_dim, params.o_stride_n, + params.o_stride_h), // layout_O + static_cast(params.lse_ptr), + get_lse_gmem_layout(params.qo_len, params.num_qo_heads), // layout_LSE + }); + + int num_tiles_q = cutlass::ceil_div(params.qo_len, KernelTraits::CTA_Q); + // TODO(Zihao): also support kv-head major + typename Scheduler::Arguments scheduler_args = { + num_tiles_q, params.num_qo_heads, params.qo_len, params.kv_len, + cutlass::FastDivmod(params.num_qo_heads / params.num_kv_heads)}; + typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args); + + auto kernel = + (void*)PrefillWithKVCacheKernel; + int smem_size = sizeof(typename KernelTraits::SharedStorage); + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + int device; + cudaGetDevice(&device); + int multiprocessor_count; + FLASHINFER_CUDA_CALL( + cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device)); + dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count); + static constexpr int num_ctas = KernelTraits::NUM_WARPS * 32; + dim3 block_dims(num_ctas); + void* args[] = {&mainloop_params, &epilogue_params, &scheduler_params}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel(kernel, grid_dims, block_dims, args, smem_size, stream)); + + return cudaSuccess; +} + +template +cudaError_t BatchPrefillWithPagedKVCacheKernelTraitsDispatched( + BatchPrefillPagedParams& params, + cudaStream_t stream) { + using DTypeQ = typename KernelTraits::DTypeQ; + using DTypeKV = typename KernelTraits::DTypeKV; + using DTypeO = typename KernelTraits::DTypeO; + using TileShape_QKD = typename KernelTraits::TileShape_QKD; + + using CollectiveMainloop = SparseCollectiveMainloop; + using CollectiveEpilogue = CollectiveEpilogue; + using Scheduler = BatchPrefillTileScheduler; + + typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments( + {params.q_ptr, + get_gmem_layout(params.nnz_qo, params.num_qo_heads, params.head_dim, params.q_stride_n, + params.q_stride_h), // layout_Q + params.k_ptr, + // NOTE(Zihao): nnz was useless here, we can just pass 0 + get_gmem_layout(/*nnz=*/0, params.num_kv_heads, params.head_dim, params.k_stride_n, + params.k_stride_h), // layout_K + params.v_ptr, + get_gmem_layout(/*nnz=*/0, params.num_kv_heads, params.head_dim, params.v_stride_n, + params.v_stride_h), // layout_V + params.kv_indices, params.window_left, params.logits_soft_cap, params.sm_scale_log2}); + typename CollectiveEpilogue::Params epilogue_params = + CollectiveEpilogue::to_underlying_arguments({ + params.o_ptr, + get_gmem_layout(params.nnz_qo, params.num_qo_heads, params.head_dim, params.o_stride_n, + params.o_stride_h), // layout_O + params.lse_ptr, get_lse_gmem_layout(params.nnz_qo, params.num_qo_heads), // layout_LSE + }); + + typename Scheduler::Arguments scheduler_args = { + params.work_indptr, params.head_indices, + params.qo_tile_indices, params.qo_indptr, + params.kv_indptr, params.qo_lens, + params.kv_lens, cutlass::FastDivmod(params.num_qo_heads / params.num_kv_heads)}; + typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args); + + // Get the ptr to kernel function. + auto kernel = + (void*)PrefillWithKVCacheKernel; + int smem_size = sizeof(typename KernelTraits::SharedStorage); + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + int device; + cudaGetDevice(&device); + int multiprocessor_count; + FLASHINFER_CUDA_CALL( + cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device)); + dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count); + static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32; + dim3 block_dims(ctaSize); + void* args[] = {&mainloop_params, &epilogue_params, &scheduler_params}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel(kernel, grid_dims, block_dims, args, smem_size, stream)); + + return cudaSuccess; +} + +template +cudaError_t BatchPrefillWithRaggedKVCacheKernelTraitsDispatched( + BatchPrefillRaggedParams& params, + cudaStream_t stream) { + using DTypeQ = typename KernelTraits::DTypeQ; + using DTypeKV = typename KernelTraits::DTypeKV; + using DTypeO = typename KernelTraits::DTypeO; + using TileShape_QKD = typename KernelTraits::TileShape_QKD; + + using CollectiveMainloop = CollectiveMainloop; + using CollectiveEpilogue = CollectiveEpilogue; + using Scheduler = BatchPrefillTileScheduler; + typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments( + {params.q_ptr, + get_gmem_layout(params.nnz_qo, params.num_qo_heads, params.head_dim, params.q_stride_n, + params.q_stride_h), // layout_Q + params.k_ptr, + // NOTE(Zihao): nnz was useless here, we can just pass 0 + get_gmem_layout(params.nnz_kv, params.num_kv_heads, params.head_dim, params.k_stride_n, + params.k_stride_h), // layout_K + params.v_ptr, + get_gmem_layout(params.nnz_kv, params.num_kv_heads, params.head_dim, params.v_stride_n, + params.v_stride_h), // layout_V + params.window_left, params.logits_soft_cap, params.sm_scale_log2}); + typename CollectiveEpilogue::Params epilogue_params = + CollectiveEpilogue::to_underlying_arguments({ + params.o_ptr, + get_gmem_layout(params.nnz_qo, params.num_qo_heads, params.head_dim, params.o_stride_n, + params.o_stride_h), // layout_O + params.lse_ptr, get_lse_gmem_layout(params.nnz_qo, params.num_qo_heads), // layout_LSE + }); + + // NOTE(Zihao): add support for kv head-major later + typename Scheduler::Arguments scheduler_args = { + params.work_indptr, params.head_indices, + params.qo_tile_indices, params.qo_indptr, + params.kv_indptr, params.qo_lens, + params.kv_lens, cutlass::FastDivmod(params.num_qo_heads / params.num_kv_heads)}; + typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args); + + // Get the ptr to kernel function. + auto kernel = + (void*)PrefillWithKVCacheKernel; + int smem_size = sizeof(typename KernelTraits::SharedStorage); + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + int device; + cudaGetDevice(&device); + int multiprocessor_count; + FLASHINFER_CUDA_CALL( + cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device)); + dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count); + static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32; + dim3 block_dims(ctaSize); + void* args[] = {&mainloop_params, &epilogue_params, &scheduler_params}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel(kernel, grid_dims, block_dims, args, smem_size, stream)); + + return cudaSuccess; +} + +template +cudaError_t SinglePrefillWithKVCacheDispatched(SinglePrefillParams& params, + cudaStream_t stream) { + static_assert(HEAD_DIM == 64 || HEAD_DIM == 128 || HEAD_DIM == 256); + if (MASK_MODE == MaskMode::kCustom) { + return cudaErrorNotSupported; // Not supported yet. + } + constexpr bool CAUSAL = MASK_MODE == MaskMode::kCausal; + if constexpr (HEAD_DIM == 64) { + SinglePrefillWithKVCacheKernelTraitsDispatched< + AttentionKernelTraits, + LEFT_SLINDING_WINDOW, CAUSAL>(params, stream); + } else if constexpr (HEAD_DIM == 128) { + SinglePrefillWithKVCacheKernelTraitsDispatched< + AttentionKernelTraits, + LEFT_SLINDING_WINDOW, CAUSAL>(params, stream); + } else { + // HEAD_DIM == 256; + SinglePrefillWithKVCacheKernelTraitsDispatched< + AttentionKernelTraits, + LEFT_SLINDING_WINDOW, CAUSAL>(params, stream); + } + cudaError_t status = cudaGetLastError(); + return status; +} + +template +cudaError_t BatchPrefillWithRaggedKVCacheDispatched( + BatchPrefillRaggedParams& params, cudaStream_t stream) { + static_assert(HEAD_DIM == 64 || HEAD_DIM == 128 || HEAD_DIM == 256); + if (MASK_MODE == MaskMode::kCustom) { + return cudaErrorNotSupported; // Not supported yet. + } + constexpr bool CAUSAL = MASK_MODE == MaskMode::kCausal; + if constexpr (HEAD_DIM == 64) { + BatchPrefillWithRaggedKVCacheKernelTraitsDispatched< + AttentionKernelTraits, + LEFT_SLINDING_WINDOW, CAUSAL>(params, stream); + } else if constexpr (HEAD_DIM == 128) { + BatchPrefillWithRaggedKVCacheKernelTraitsDispatched< + AttentionKernelTraits, + LEFT_SLINDING_WINDOW, CAUSAL>(params, stream); + } else { + // HEAD_DIM == 256; + BatchPrefillWithRaggedKVCacheKernelTraitsDispatched< + AttentionKernelTraits, + LEFT_SLINDING_WINDOW, CAUSAL>(params, stream); + } + cudaError_t status = cudaGetLastError(); + return status; +} + +template +cudaError_t BatchPrefillWithPagedKVCacheDispatched( + BatchPrefillPagedParams& params, cudaStream_t stream) { + static_assert(HEAD_DIM == 64 || HEAD_DIM == 128 || HEAD_DIM == 256); + if (MASK_MODE == MaskMode::kCustom) { + return cudaErrorNotSupported; // Not supported yet. + } + constexpr bool CAUSAL = MASK_MODE == MaskMode::kCausal; + if constexpr (HEAD_DIM == 64) { + // NOTE(Zihao): CTA_KV not tuned for HEAD_DIM == 64, need to optimize later + BatchPrefillWithPagedKVCacheKernelTraitsDispatched< + AttentionKernelTraits, + LEFT_SLINDING_WINDOW, CAUSAL>(params, stream); + } else if constexpr (HEAD_DIM == 128) { + BatchPrefillWithPagedKVCacheKernelTraitsDispatched< + AttentionKernelTraits, + LEFT_SLINDING_WINDOW, CAUSAL>(params, stream); + } else { + // HEAD_DIM == 256; + // NOTE(Zihao): CTA_KV not tuned for HEAD_DIM == 256, need to optimize later + BatchPrefillWithPagedKVCacheKernelTraitsDispatched< + AttentionKernelTraits, + LEFT_SLINDING_WINDOW, CAUSAL>(params, stream); + } + cudaError_t status = cudaGetLastError(); + return status; +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_PREFILL_SM90_CUH_ diff --git a/include/flashinfer/attention/hopper/sparse_mainloop.cuh b/include/flashinfer/attention/hopper/sparse_mainloop.cuh new file mode 100644 index 000000000..263a8c742 --- /dev/null +++ b/include/flashinfer/attention/hopper/sparse_mainloop.cuh @@ -0,0 +1,327 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_ATTENTION_HOPPER_SPARSE_MAINLOOP_CUH_ +#define FLASHINFER_ATTENTION_HOPPER_SPARSE_MAINLOOP_CUH_ + +#include +#include +#include +#include + +#include "../../math.cuh" +#include "block_sparse_gather.cuh" +#include "cute/tensor.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "named_barrier.cuh" +#include "utils.cuh" + +namespace flashinfer { + +using namespace cute; + +template +struct SparseCollectiveMainloop { + using DTypeQ = typename Ktraits::DTypeQ; + using DTypeKV = typename Ktraits::DTypeKV; + using IdType = typename Ktraits::IdType; + using TileShape_QKD = typename Ktraits::TileShape_QKD; + static constexpr int CTA_Q = get<0>(TileShape_QKD{}); + static constexpr int CTA_KV = get<1>(TileShape_QKD{}); + + static constexpr int NUM_STAGES = Ktraits::NUM_STAGES; + static constexpr int HEAD_DIM = Ktraits::HEAD_DIM; + static constexpr int NUM_COPY_THREADS = cutlass::NumThreadsPerWarpGroup; + + using GmemTiledCopyQ = cute::SM90_TMA_LOAD; + static constexpr auto AlignmentKV = 128 / cutlass::sizeof_bits::value; + using AlignmentTypeKV = cute::uint_byte_t(sizeof(DTypeKV)) * AlignmentKV>; + // NOTE(Zihao): use SM80_CP_ASYNC for sparse loading of KV-cache + using GmemCopyAtomKV = cute::Copy_Atom, DTypeKV>; + using GmemTiledCopyKV = + decltype(cutlass::gemm::collective::detail::make_simt_gmem_tiled_copy< + GmemCopyAtomKV, NUM_COPY_THREADS, AlignmentKV, + cutlass::detail::TagToStrideB_t, + decltype(cute::get<1>(TileShape_QKD{})), decltype(cute::get<2>(TileShape_QKD{}))>()); + + using SmemLayoutQ = typename Ktraits::SmemLayoutQ; + using SmemLayoutK = typename Ktraits::SmemLayoutK; + using SmemLayoutV = typename Ktraits::SmemLayoutV; + using SmemLayoutVt = typename Ktraits::SmemLayoutVt; + + using ShapeT = cute::Shape; + using StrideT = cute::Shape; // (N, D, H) + using LayoutT = cute::Layout; + + using ShapeLseT = cute::Shape; + using StrideLseT = cute::Shape<_1, int64_t>; + using LayoutLseT = cute::Layout; + + using TMA_Q = decltype(make_tma_copy( + GmemTiledCopyQ{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), + repeat_like(StrideT{}, int32_t(0)), StrideT{}), + SmemLayoutQ{}, select<0, 2>(TileShape_QKD{}), _1{})); // no mcast for Q + + static constexpr bool USE_TMA_LOAD_KV = false; + static constexpr int NUM_MMA_THREADS = size(typename Ktraits::TiledMmaQK{}); + using MainloopPipeline = typename Ktraits::MainloopPipeline; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; + + static constexpr uint32_t TmaTransactionBytesQ = + static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); + + static constexpr bool UseSchedulerBarrier = + cutlass::sizeof_bits_v == 8 ? HEAD_DIM >= 128 : HEAD_DIM <= 128; + using WarpScheduler = WarpScheduler; + + // Host side kernel arguments + struct Arguments { + DTypeQ const* Q_ptr; + LayoutT layout_Q; + DTypeKV const* K_ptr; + LayoutT layout_K; + DTypeKV const* V_ptr; + LayoutT layout_V; + IdType const* kv_indices; + int window_left; + float const logits_soft_cap; + float const sm_scale_log2; + }; + + // Device side kernel params + struct Params { + LayoutT layout_Q; + LayoutT layout_K; + LayoutT layout_V; + TMA_Q tma_load_Q; + DTypeKV* K_ptr; + DTypeKV* V_ptr; + IdType* kv_indices; + int window_left; + float const logits_soft_cap; + float const sm_scale_log2; + }; + + static Params to_underlying_arguments(Arguments const& args) { + Tensor mQ = make_tensor(make_gmem_ptr(args.Q_ptr), args.layout_Q); + TMA_Q tma_load_Q = + make_tma_copy(GmemTiledCopyQ{}, mQ, SmemLayoutQ{}, select<0, 2>(TileShape_QKD{}), _1{}); + return {args.layout_Q, + args.layout_K, + args.layout_V, + tma_load_Q, + const_cast(args.K_ptr), + const_cast(args.V_ptr), + const_cast(args.kv_indices), + args.window_left, + args.logits_soft_cap, + args.sm_scale_log2}; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor()); + } + + CUTLASS_DEVICE + int get_num_kv_tiles(Params const& mainloop_params, int q_tile_idx, const int qo_len, + const int kv_len) { + static constexpr int CTA_Q = get<0>(TileShape_QKD{}); + static constexpr int CTA_KV = get<1>(TileShape_QKD{}); + int num_kv_tiles = cute::ceil_div(kv_len, CTA_KV); + if constexpr (CAUSAL) { + num_kv_tiles = std::min(num_kv_tiles, + cute::ceil_div((q_tile_idx + 1) * CTA_Q + kv_len - qo_len, CTA_KV)); + } + + return num_kv_tiles; + } + + template + CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline_k, + MainloopPipeline pipeline_v, PipelineState& smem_pipe_write_k, + PipelineState& smem_pipe_write_v, SharedStorage& shared_storage, + Scheduler& scheduler, typename Scheduler::Params const& scheduler_params, + typename Scheduler::WorkTileInfo& work_tile_info, + BlockCoord const& block_coord, int work_idx) { + int thread_idx = threadIdx.x; + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (thread_idx / 32) % 4, 0); + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{}); + + Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape()); + + auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = block_coord; + + // Prepare the TMA loads + Tensor gQ = get_local_tile_tensor(mQ, select<0, 2>(TileShape_QKD{}), qo_head_idx, qo_indptr, + qo_len)(_, _, q_tile_idx); // (Q, D) + + Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{})); + Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{})); + auto [tQgQ, tQsQ] = + tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{}, group_modes<0, 2>(sQ_x), + group_modes<0, 2>(gQ_x)); // (TMA), (TMA) + + int num_kv_tiles = get_num_kv_tiles(mainloop_params, q_tile_idx, qo_len, kv_len); + int kv_tile_idx = num_kv_tiles - 1; + int swa_begin_kv_tile_idx = 0; + if constexpr (LEFT_SLIDING_WINDOW) { + swa_begin_kv_tile_idx = get_swa_begin_kv_tile_idx(mainloop_params.window_left, + q_tile_idx, qo_len, kv_len); + } + + constexpr int HEAD_DIM = get<2>(TileShape_QKD{}); + constexpr int CTA_KV = get<1>(TileShape_QKD{}); + auto indexed_gather = BlockSparseIndexedGather(mainloop_params.kv_indices + kv_indptr); + + Tensor mK = make_block_sparse_tensor( // (kv_len, D) + make_gmem_ptr(mainloop_params.K_ptr + kv_head_idx * stride<2>(mainloop_params.layout_K)), + make_shape(kv_len, HEAD_DIM), stride<0>(mainloop_params.layout_K), indexed_gather); + Tensor mV = make_block_sparse_tensor( // (kv_len, D) + make_gmem_ptr(mainloop_params.V_ptr + kv_head_idx * stride<2>(mainloop_params.layout_V)), + make_shape(kv_len, HEAD_DIM), stride<0>(mainloop_params.layout_V), indexed_gather); + + Tensor gK = local_tile(mK, select<1, 2>(TileShape_QKD{}), make_coord(_, _0{})); // (KV, D, kv) + Tensor gV = local_tile(mV, select<1, 2>(TileShape_QKD{}), make_coord(_, _0{})); // (KV, D, kv) + Tensor cKV = cute::make_identity_tensor(gK.shape()); + + GmemTiledCopyKV gmem_tiled_copy_kv; + auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_slice(thread_idx); + + Tensor tKgK = gmem_thr_copy_kv.partition_S(gK); // (CPY, CPY_KV, CPY_D, kv) + Tensor tKsK = gmem_thr_copy_kv.partition_D(sK); // (CPY, CPY_KV, CPY_D, PIPE) + Tensor tVgV = gmem_thr_copy_kv.partition_S(gV); // (CPY, CPY_KV, CPY_D, kv) + Tensor tVsV = gmem_thr_copy_kv.partition_D(sV); // (CPY, CPY_KV, CPY_D, PIPE) + Tensor tKVcKV = gmem_thr_copy_kv.partition_D(cKV); // (CPY, CPY_KV, CPY_D) + Tensor tKVcKVGroup = flatten_1(tKVcKV); // (CPY, (CPY_KV, CPY_D)) + + int valid_last_kv_tile_size = std::min(kv_len - kv_tile_idx * CTA_KV, CTA_KV); + auto predicate_fn = [&](auto coords) { + auto s_coords = tKVcKVGroup(_0{}, coords); + return elem_less(get<0>(s_coords), valid_last_kv_tile_size); + }; + + // load last k-tile + { + pipeline_k.producer_acquire(smem_pipe_write_k); + Tensor tKgKiGroup = flatten_1(tKgK(_, _, _, kv_tile_idx)); // (CPY, (CPY_KV, CPY_D)) + Tensor tKsKiGroup = + flatten_1(tKsK(_, _, _, smem_pipe_write_k.index())); // (CPY, (CPY_KV, CPY_D)) + copy_if(gmem_tiled_copy_kv, predicate_fn, tKgKiGroup, tKsKiGroup); + + pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive); + ++smem_pipe_write_k; + } + + // load Q tile + if (warp_idx_in_warpgroup == 0) { + cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS + cutlass::NumThreadsPerWarp, + static_cast(NamedBarriers::kQueryEmpty)); + + int lane_predicate = cute::elect_one_sync(); + if (lane_predicate) { + shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); + copy(mainloop_params.tma_load_Q.with( + reinterpret_cast( + shared_storage.barrier_Q), + /*mcast_mask=*/0), + tQgQ, tQsQ); + } + } + + shared_storage.barrier_O.wait((work_idx + 1) % 2); + + if (kv_tile_idx == swa_begin_kv_tile_idx) { + pipeline_v.producer_acquire(smem_pipe_write_v); + Tensor tVgViGroup = flatten_1(tVgV(_, _, _, kv_tile_idx)); // (CPY, (CPY_KV, CPY_D)) + Tensor tVsViGroup = + flatten_1(tVsV(_, _, _, smem_pipe_write_v.index())); // (CPY, (CPY_KV, CPY_D)) + copy_if(gmem_tiled_copy_kv, predicate_fn, tVgViGroup, tVsViGroup); + + pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); + ++smem_pipe_write_v; + } else { + // load second last k-tile and last v-tile + pipeline_k.producer_acquire(smem_pipe_write_k); + Tensor tKgKi = tKgK(_, _, _, kv_tile_idx - 1); // (CPY, CPY_KV, CPY_D) + Tensor tKsKi = tKsK(_, _, _, smem_pipe_write_k.index()); // (CPY, CPY_KV, CPY_D) + copy(gmem_tiled_copy_kv, tKgKi, tKsKi); + + pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive); + ++smem_pipe_write_k; + + pipeline_v.producer_acquire(smem_pipe_write_v); + Tensor tVgViGroup = flatten_1(tVgV(_, _, _, kv_tile_idx)); // (CPY, (CPY_KV, CPY_D)) + Tensor tVsViGroup = + flatten_1(tVsV(_, _, _, smem_pipe_write_v.index())); // (CPY, (CPY_KV, CPY_D)) + copy_if(gmem_tiled_copy_kv, predicate_fn, tVgViGroup, tVsViGroup); + + pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); + --kv_tile_idx; + ++smem_pipe_write_v; + + // load remaining k/v tiles +#pragma unroll 2 + for (; kv_tile_idx > swa_begin_kv_tile_idx; --kv_tile_idx) { + pipeline_k.producer_acquire(smem_pipe_write_k); + + Tensor tKgKi = tKgK(_, _, _, kv_tile_idx - 1); // (CPY, CPY_KV, CPY_D) + Tensor tKsKi = tKsK(_, _, _, smem_pipe_write_k.index()); // (CPY, CPY_KV, CPY_D) + copy(gmem_tiled_copy_kv, tKgKi, tKsKi); + + pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive); + ++smem_pipe_write_k; + + pipeline_v.producer_acquire(smem_pipe_write_v); + Tensor tVgVi = tVgV(_, _, _, kv_tile_idx); // (CPY, CPY_KV, CPY_D) + Tensor tVsVi = tVsV(_, _, _, smem_pipe_write_v.index()); // (CPY, CPY_KV, CPY_D) + copy(gmem_tiled_copy_kv, tVgVi, tVsVi); + + pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); + ++smem_pipe_write_v; + } + scheduler.prefetch_next_work(scheduler_params, work_tile_info); + + // load first v tile + { + pipeline_v.producer_acquire(smem_pipe_write_v); + Tensor tVgVi = tVgV(_, _, _, 0); // (CPY, (CPY_KV, CPY_D)) + Tensor tVsVi = tVsV(_, _, _, smem_pipe_write_v.index()); // (CPY, (CPY_KV, CPY_D)) + copy(gmem_tiled_copy_kv, tVgVi, tVsVi); + pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); + ++smem_pipe_write_v; + } + } + + scheduler.broadcast_next_work(work_tile_info); + } + + CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v, + PipelineState& smem_pipe_write_k, + PipelineState& smem_pipe_write_v) { + pipeline_k.producer_tail(smem_pipe_write_k); + pipeline_v.producer_tail(smem_pipe_write_v); + } +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_SPARSE_MAINLOOP_CUH_ diff --git a/include/flashinfer/attention/hopper/tile_scheduler.cuh b/include/flashinfer/attention/hopper/tile_scheduler.cuh new file mode 100644 index 000000000..396102713 --- /dev/null +++ b/include/flashinfer/attention/hopper/tile_scheduler.cuh @@ -0,0 +1,196 @@ +/* + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri + * Dao. Licensed under the BSD 3-Clause. + * + * Modified by the FlashInfer team. + */ +#ifndef FLASHINFER_ATTENTION_HOPPER_TILE_SCHEDULER_CUH_ +#define FLASHINFER_ATTENTION_HOPPER_TILE_SCHEDULER_CUH_ + +#include "cutlass/arch/barrier.h" +#include "cutlass/fast_math.h" +#include "named_barrier.cuh" + +namespace flashinfer { + +struct SingleTileScheduler { + public: + // Host side kernel arguments + struct Arguments { + int const num_qo_tiles, num_qo_heads, qo_len, kv_len; + cutlass::FastDivmod group_size_fastdiv; + }; + + // Device side kernel params + struct Params { + int const qo_len, kv_len; + cutlass::FastDivmod group_size_fastdiv; + }; + + static Params to_underlying_arguments(Arguments const& args) { + return {args.qo_len, args.kv_len, args.group_size_fastdiv}; + } + + static dim3 get_grid_dim(Arguments const& args, int num_sm) { + return {uint32_t(args.num_qo_tiles), uint32_t(args.num_qo_heads)}; + } + + struct WorkTileInfo { + int q_tile_idx = 0; + int qo_head_idx = 0; + int kv_head_idx = 0; + bool is_valid_tile = false; + + CUTLASS_DEVICE + bool is_valid(Params const& params) const { return is_valid_tile; } + + CUTLASS_DEVICE + auto get_block_coord(Params const& params) const { + return cute::tuple{q_tile_idx, qo_head_idx, kv_head_idx, /*qo_indptr=*/0, + /*kv_indptr=*/0, params.qo_len, params.kv_len}; + } + }; + + CUTLASS_DEVICE + SingleTileScheduler() {} + + CUTLASS_DEVICE + WorkTileInfo get_initial_work(Params const& params) const { + int qo_head_idx = blockIdx.y; + int kv_head_idx = params.group_size_fastdiv.divide(qo_head_idx); + return {/*q_tile_idx=*/int(blockIdx.x), qo_head_idx, kv_head_idx, /*is_valid_tile*/ true}; + } + + CUTLASS_DEVICE + void init_consumer() const {} + + CUTLASS_DEVICE + void prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} + + CUTLASS_DEVICE + void broadcast_next_work(WorkTileInfo& current_work) const {} + + template + CUTLASS_DEVICE WorkTileInfo get_next_work(Params const& params, + WorkTileInfo const& current_work) const { + return {-1, -1, false}; + } +}; + +template +struct BatchPrefillTileScheduler { + public: + // Host side kernel arguments + struct Arguments { + IdType *work_indptr, *head_indices, *qo_tile_indices, *qo_indptr, *kv_indptr, *qo_lens, + *kv_lens; + cutlass::FastDivmod group_size_fastdiv; + }; + + // Device side kernel params + struct Params { + IdType *work_indptr, *head_indices, *qo_tile_indices, *qo_indptr, *kv_indptr, *qo_lens, + *kv_lens; + cutlass::FastDivmod group_size_fastdiv; + }; + + static Params to_underlying_arguments(Arguments const& args) { + return {args.work_indptr, args.head_indices, args.qo_tile_indices, args.qo_indptr, + args.kv_indptr, args.qo_lens, args.kv_lens, args.group_size_fastdiv}; + } + + static dim3 get_grid_dim(Arguments const& args, int num_sm) { + return {132U}; // 132 + } + + struct WorkTileInfo { + int q_tile_idx = 0; + int qo_head_idx = 0; + int kv_head_idx = 0; + int qo_indptr = 0; + int kv_indptr = 0; + int qo_len = 0; + int kv_len = 0; + int counter = 0; + int ptr_begin = 0; + int ptr_end = 0; + + CUTLASS_DEVICE + bool is_valid(Params const& params) const { return counter + ptr_begin < ptr_end; } + + CUTLASS_DEVICE + auto get_block_coord(Params const& params) const { + return cute::tuple{q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, + kv_indptr, qo_len, kv_len}; + } + }; + + CUTLASS_DEVICE + BatchPrefillTileScheduler() {} + + CUTLASS_DEVICE + WorkTileInfo get_initial_work(Params const& params) const { + int ptr_begin = params.work_indptr[blockIdx.x]; + int ptr_end = params.work_indptr[blockIdx.x + 1]; + if (ptr_begin < ptr_end) { + int work_idx = ptr_begin; + int qo_head_idx = params.head_indices[work_idx]; + int kv_head_idx = params.group_size_fastdiv.divide(qo_head_idx); + return {params.qo_tile_indices[work_idx], + qo_head_idx, + kv_head_idx, + params.qo_indptr[work_idx], + params.kv_indptr[work_idx], + params.qo_lens[work_idx], + params.kv_lens[work_idx], + 0, + ptr_begin, + ptr_end}; + } else { + return {-1, -1, -1, -1, -1, -1, 0, ptr_begin, ptr_end}; + } + } + + CUTLASS_DEVICE + void init_consumer() const {} + + CUTLASS_DEVICE + void prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} + + CUTLASS_DEVICE + void broadcast_next_work(WorkTileInfo& current_work) const {} + + template + CUTLASS_DEVICE WorkTileInfo get_next_work(Params const& params, + WorkTileInfo const& current_work) const { + int work_idx = current_work.ptr_begin + current_work.counter + 1; + if (work_idx < current_work.ptr_end) { + int qo_head_idx = params.head_indices[work_idx]; + int kv_head_idx = params.group_size_fastdiv.divide(qo_head_idx); + return {params.qo_tile_indices[work_idx], + qo_head_idx, + kv_head_idx, + params.qo_indptr[work_idx], + params.kv_indptr[work_idx], + params.qo_lens[work_idx], + params.kv_lens[work_idx], + current_work.counter + 1, + current_work.ptr_begin, + current_work.ptr_end}; + } else { + return {-1, + -1, + -1, + -1, + -1, + -1, + current_work.counter + 1, + current_work.ptr_begin, + current_work.ptr_end}; + } + } +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_TILE_SCHEDULER_CUH_ diff --git a/include/flashinfer/attention/hopper/utils.cuh b/include/flashinfer/attention/hopper/utils.cuh new file mode 100644 index 000000000..0441cbd1f --- /dev/null +++ b/include/flashinfer/attention/hopper/utils.cuh @@ -0,0 +1,165 @@ +/* + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri + * Dao. Licensed under the BSD 3-Clause. + * + * Modified by the FlashInfer team. + */ +#ifndef FLASHINFER_ATTENTION_HOPPER_UTILS_CUH_ +#define FLASHINFER_ATTENTION_HOPPER_UTILS_CUH_ + +#include +#include +#include +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../../math.cuh" +#include "../../utils.cuh" +#include "cutlass/fast_math.h" + +namespace flashinfer { + +using namespace cute; + +template +CUTLASS_DEVICE int get_swa_begin_kv_tile_idx(int window_left, int q_tile_idx, const int qo_len, + const int kv_len) { + return std::max((q_tile_idx * CTA_Q + kv_len - qo_len - window_left) / CTA_KV - 1, 0); +} + +template +CUTLASS_DEVICE int get_swa_end_kv_tile_idx(int window_left, int q_tile_idx, const int qo_len, + const int kv_len) { + return std::max(((q_tile_idx + 1) * CTA_Q + kv_len - qo_len - window_left) / CTA_KV, -1); +} + +template +CUTLASS_HOST_DEVICE auto flatten_1(TensorT tensor) { + Tensor tensor_flatten = cute::flatten(tensor); + return cute::group_modes<1, rank(tensor_flatten)>(tensor_flatten); +} + +CUTLASS_HOST_DEVICE auto get_gmem_layout(int nnz, int num_heads, int head_dim, int64_t n_stride, + int64_t h_stride) { + return make_layout(make_shape(nnz, head_dim, num_heads), + make_stride(n_stride, cute::_1{}, h_stride)); +} + +CUTLASS_HOST_DEVICE auto get_lse_gmem_layout(int nnz, int num_heads) { + return make_layout(make_shape(num_heads, nnz), make_stride(cute::_1{}, int64_t(num_heads))); +} + +template +CUTLASS_DEVICE auto get_local_tile_tensor(const MTensor& m_tensor, const Shape& tile_shape, + int head_idx, int offset, int seq_len) { + auto g_offset = local_tile(m_tensor(_, _, head_idx), cute::make_shape(1, get<1>(tile_shape)), + make_coord(offset, _0{})); + auto g_sequence = + make_tensor(g_offset.data(), + make_layout(cute::make_shape(seq_len, get<1>(tile_shape)), g_offset.stride())); + auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{})); + return g_tensor; +} + +template +CUTLASS_DEVICE auto get_lse_local_tile_tensor(const MTensor& m_tensor, const Shape& tile_shape, + int head_idx, int offset, int seq_len) { + auto g_offset = local_tile(m_tensor(head_idx, _), cute::make_shape(_1{}), make_coord(offset)); + + auto g_sequence = make_tensor(g_offset.data(), make_layout(cute::make_shape(seq_len), + cute::make_shape(shape<0>(m_tensor)))); + auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_)); + return g_tensor; +} + +// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, +// MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = acc_layout; + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), + make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l))); +}; + +// For SM90, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, +// MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) { + using X = Underscore; + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + static_assert(decltype(rank(get<0>(acc_layout)))::value == 3); + auto l = logical_divide(get<0>(acc_layout), Shape{}); // (2, 2, (2, N / 16))) + return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout), + make_layout(get<2, 1>(l), get<2>(acc_layout))); +}; + +template +__forceinline__ __device__ auto convert_type(Tensor const& tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast*>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +template +__forceinline__ __device__ void gemm(TiledMma& tiled_mma, TensorA const& tCrA, TensorB const& tCrB, + TensorC& tCrC) { + constexpr bool Is_RS = + !cute::is_base_of::value; + // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const + if constexpr (Is_RS) { + warpgroup_fence_operand(const_cast(tCrA)); + } + warpgroup_fence_operand(tCrC); + warpgroup_arrive(); + if constexpr (init) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } else { + // cute::gemm(tiled_mma, tCrA, tCrB, tCrC); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } + warpgroup_commit_batch(); + if constexpr (wg_wait >= 0) { + warpgroup_wait(); + } + warpgroup_fence_operand(tCrC); + if constexpr (Is_RS) { + warpgroup_fence_operand(const_cast(tCrA)); + } +} + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_UTILS_CUH_ diff --git a/include/flashinfer/attention/hopper/variants.cuh b/include/flashinfer/attention/hopper/variants.cuh new file mode 100644 index 000000000..75d7c7bc9 --- /dev/null +++ b/include/flashinfer/attention/hopper/variants.cuh @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// NOTE(Zihao): we should merge this with include/flashinfer/attention/variants.cuh in the future +#ifndef FLASHINFER_ATTENTION_HOPPER_VARIANTS_CUH_ +#define FLASHINFER_ATTENTION_HOPPER_VARIANTS_CUH_ +#include + +#include "../../math.cuh" +#include "attention_updater.cuh" + +namespace flashinfer { + +struct StandardAttention { + template + using Updater = OnlineSoftmaxWithScale; + + template + __device__ StandardAttention(const ParamsT& params) {} + + template + __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx, + uint32_t qo_idx, uint32_t kv_idx, + uint32_t qo_head_idx, uint32_t kv_head_idx) { + return logits; + } +}; + +struct LogitsSoftCap { + float pre_tanh_scale; + float post_tanh_scale; + template + using Updater = OnlineSoftmaxWithoutScale; + + template + __device__ LogitsSoftCap(const ParamsT& params) { + pre_tanh_scale = (params.sm_scale_log2 * math::loge2) * math::ptx_rcp(params.logits_soft_cap); + post_tanh_scale = math::log2e * params.logits_soft_cap; + } + + template + __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx, + uint32_t qo_idx, uint32_t kv_idx, + uint32_t qo_head_idx, uint32_t kv_head_idx) { + return math::tanh(logits * pre_tanh_scale) * post_tanh_scale; + } +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_VARIANTS_CUH_ diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 64cf106fb..f80231714 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -22,13 +22,13 @@ #include #include #include -#include #include #include "../allocator.h" #include "../exception.h" #include "../pos_enc.cuh" #include "../utils.cuh" +#include "heap.h" namespace flashinfer { @@ -720,5 +720,196 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i return cudaSuccess; } +inline float cost_function(int qo_len, int kv_len, int group_size) { + return 2 * float(qo_len) * float(group_size) + kv_len; +} + +template +std::vector flatten(const std::vector>& vec, int size_after_flatten) { + std::vector result; + result.reserve(size_after_flatten); + for (const auto& inner_vec : vec) { + result.insert(result.end(), inner_vec.begin(), inner_vec.end()); + } + return std::move(result); +} + +struct PrefillPlanSM90Info { + int64_t qo_tile_indices_offset; + int64_t qo_indptr_offset; + int64_t kv_indptr_offset; + int64_t qo_len_offset; + int64_t kv_len_offset; + int64_t head_indices_offset; + int64_t work_indptr_offset; + + PrefillPlanSM90Info() + : qo_tile_indices_offset(0), + qo_indptr_offset(0), + kv_indptr_offset(0), + qo_len_offset(0), + kv_len_offset(0), + head_indices_offset(0), + work_indptr_offset(0) {} + + // convert PrefillPlanSM90Info to std::vector + std::vector ToVector() const { + return {qo_tile_indices_offset, qo_indptr_offset, kv_indptr_offset, qo_len_offset, + kv_len_offset, head_indices_offset, work_indptr_offset}; + } + + // From std::vector to PrefillPlanSM90Info + void FromVector(const std::vector& vec) { + if (vec.size() != 7) { + std::ostringstream err_msg; + err_msg << "PrefillPlanSM90Info::FromVector: vec.size() should be 8, but got " << vec.size(); + FLASHINFER_ERROR(err_msg.str()); + } + qo_tile_indices_offset = vec[0]; + qo_indptr_offset = vec[1]; + kv_indptr_offset = vec[2]; + qo_len_offset = vec[3]; + kv_len_offset = vec[4]; + head_indices_offset = vec[5]; + work_indptr_offset = vec[6]; + } +}; + +template +cudaError_t PrefillSM90Plan(void* float_buffer, size_t float_workspace_size_in_bytes, + void* int_buffer, void* page_locked_int_buffer, + size_t int_workspace_size_in_bytes, PrefillPlanSM90Info& plan_info, + IdType* qo_indptr_h, IdType* kv_indptr_h, IdType* kv_len_arr_h, + uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t head_dim, uint32_t page_size, bool causal, + bool enable_cuda_graph, uint32_t sizeof_dtype_o, cudaStream_t stream) { + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads " + << num_kv_heads; + FLASHINFER_ERROR(err_msg.str()); + } + + std::vector> idx_qo_kv_len_vec; + for (uint32_t i = 0; i < batch_size; ++i) { + int qo_len = qo_indptr_h[i + 1] - qo_indptr_h[i]; + int kv_len = kv_len_arr_h[i]; + if (kv_len < 0) { + std::ostringstream err_msg; + err_msg << "kv_len[" << i << "]" << kv_len << " should be non-negative"; + FLASHINFER_ERROR(err_msg.str()); + } + if (qo_len < 0) { + std::ostringstream err_msg; + err_msg << "qo_indptr[" << i + 1 << "]" << qo_indptr_h[i + 1] << " - qo_indptr[" << i << "]" + << qo_indptr_h[i] << " should be non-negative"; + FLASHINFER_ERROR(err_msg.str()); + } + idx_qo_kv_len_vec.push_back({i, qo_len, kv_len}); + } + + std::sort(idx_qo_kv_len_vec.begin(), idx_qo_kv_len_vec.end(), + [](const auto& a, const auto& b) { return std::get<2>(a) > std::get<2>(b); }); + int cta_tile_q = 128; + if (head_dim == 64) { + cta_tile_q = 192; + } + + const int num_sm90_ctas = 132; // for sm90, the num_ctas is fixed + + CTACostHeap cta_cost_heap(num_sm90_ctas); + std::vector> cta_qo_tile_indices(num_sm90_ctas, std::vector()), + cta_qo_indptr(num_sm90_ctas, std::vector()), + cta_kv_indptr(num_sm90_ctas, std::vector()), + cta_qo_len(num_sm90_ctas, std::vector()), + cta_kv_len(num_sm90_ctas, std::vector()), + cta_head_indices(num_sm90_ctas, std::vector()); + + for (int qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { + for (auto& [i, qo_len, kv_len] : idx_qo_kv_len_vec) { + int num_qo_tiles = ceil_div(qo_len, cta_tile_q); + for (int qo_tile_idx = num_qo_tiles - 1; qo_tile_idx >= 0; --qo_tile_idx) { + auto [cta_idx, accum_cost] = cta_cost_heap.pop(); + // NOTE(Zihao): our current FA3 implementation do not fuse query and group heads + // so the group_size in cost_function is always 1 + cta_cost_heap.insert( + {cta_idx, + accum_cost + cost_function(cta_tile_q, + causal + ? kv_len - (num_qo_tiles - qo_tile_idx - 1) * cta_tile_q + : kv_len, + /*group_size=*/1)}); + cta_qo_tile_indices[cta_idx].push_back(qo_tile_idx); + cta_qo_indptr[cta_idx].push_back(qo_indptr_h[i]); + cta_qo_len[cta_idx].push_back(qo_len); + cta_kv_indptr[cta_idx].push_back(kv_indptr_h[i]); + cta_kv_len[cta_idx].push_back(kv_len); + cta_head_indices[cta_idx].push_back(qo_head_idx); + } + } + } + + std::vector work_indptr_vec(num_sm90_ctas + 1, 0); + for (uint32_t i = 0; i < num_sm90_ctas; ++i) { + work_indptr_vec[i + 1] = work_indptr_vec[i] + cta_qo_tile_indices[i].size(); + } + IdType total_num_works = work_indptr_vec[num_sm90_ctas]; + auto qo_tile_indices_vec = flatten(cta_qo_tile_indices, total_num_works); + auto qo_indptr_vec = flatten(cta_qo_indptr, total_num_works); + auto kv_indptr_vec = flatten(cta_kv_indptr, total_num_works); + auto qo_len_vec = flatten(cta_qo_len, total_num_works); + auto kv_len_vec = flatten(cta_kv_len, total_num_works); + auto head_indices_vec = flatten(cta_head_indices, total_num_works); + + AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); + const int max_total_num_works = 1048576; + if (total_num_works > max_total_num_works) { + std::ostringstream err_msg; + err_msg << "total_num_works " << total_num_works << " should be less than " + << max_total_num_works; + FLASHINFER_ERROR(err_msg.str()); + } + plan_info.qo_tile_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * max_total_num_works, 16, "batch_prefill_sm90_qo_tile_indices"); + plan_info.qo_indptr_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * max_total_num_works, 16, "batch_prefill_sm90_qo_offset"); + plan_info.kv_indptr_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * max_total_num_works, 16, "batch_prefill_sm90_kv_offset"); + plan_info.qo_len_offset = int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, + 16, "batch_prefill_sm90_qo_len"); + plan_info.kv_len_offset = int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, + 16, "batch_prefill_sm90_kv_len"); + plan_info.head_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * max_total_num_works, 16, "batch_prefill_sm90_head_indices"); + plan_info.work_indptr_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * (num_sm90_ctas + 1), 16, "batch_prefill_sm90_work_indptr"); + + IdType* qo_tile_indices_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.qo_tile_indices_offset); + IdType* qo_offset_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.qo_indptr_offset); + IdType* kv_offset_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_indptr_offset); + IdType* qo_len_h = GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.qo_len_offset); + IdType* kv_len_h = GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_len_offset); + IdType* head_indices_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.head_indices_offset); + IdType* work_indptr_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.work_indptr_offset); + + std::copy(qo_tile_indices_vec.begin(), qo_tile_indices_vec.end(), qo_tile_indices_h); + std::copy(qo_indptr_vec.begin(), qo_indptr_vec.end(), qo_offset_h); + std::copy(kv_indptr_vec.begin(), kv_indptr_vec.end(), kv_offset_h); + std::copy(qo_len_vec.begin(), qo_len_vec.end(), qo_len_h); + std::copy(kv_len_vec.begin(), kv_len_vec.end(), kv_len_h); + std::copy(head_indices_vec.begin(), head_indices_vec.end(), head_indices_h); + std::copy(work_indptr_vec.begin(), work_indptr_vec.end(), work_indptr_h); + + size_t num_bytes_to_copy = int_allocator.num_allocated_bytes(); + FLASHINFER_CUDA_CALL(cudaMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy, + cudaMemcpyHostToDevice, stream)); + return cudaSuccess; +} + } // namespace flashinfer #endif // FLASHINFER_ATTENTION_SCHEDULER_CUH_ diff --git a/include/flashinfer/cutlass_utils.cuh b/include/flashinfer/cutlass_utils.cuh index f6d3ef03a..5102756af 100644 --- a/include/flashinfer/cutlass_utils.cuh +++ b/include/flashinfer/cutlass_utils.cuh @@ -44,29 +44,37 @@ namespace flashinfer { template struct cutlass_dtype { - using value = T; + using type = T; }; template <> struct cutlass_dtype { - using value = cutlass::half_t; + using type = cutlass::half_t; }; template <> struct cutlass_dtype { - using value = cutlass::bfloat16_t; + using type = cutlass::bfloat16_t; }; template <> struct cutlass_dtype<__nv_fp8_e4m3> { - using value = cutlass::float_e4m3_t; + using type = cutlass::float_e4m3_t; }; template <> struct cutlass_dtype<__nv_fp8_e5m2> { - using value = cutlass::float_e5m2_t; + using type = cutlass::float_e5m2_t; }; +template +using cutlass_dtype_t = typename cutlass_dtype::type; + +template +void compileTimeDebug(T&&) { + static_assert(sizeof(T) == 0, "Compile time debug"); +} + } // namespace flashinfer #endif // FLASHINFER_CUTLASS_UTILS_CUH_ diff --git a/include/flashinfer/page.cuh b/include/flashinfer/page.cuh index fa256be3f..04034f194 100644 --- a/include/flashinfer/page.cuh +++ b/include/flashinfer/page.cuh @@ -16,6 +16,8 @@ #ifndef FLASHINFER_PAGE_CUH_ #define FLASHINFER_PAGE_CUH_ +#include + #include #include "fastdiv.cuh" @@ -280,6 +282,55 @@ __global__ void AppendPagedKVCacheKernel(paged_kv_t paged_kv, } } +template +__global__ void BlockSparseIndicesToVectorSparseOffsetsKernel( + IdType* __restrict__ block_sparse_indices, IdType* __restrict__ block_sparse_indptr, + IdType* __restrict__ vector_sparse_offsets, IdType* __restrict__ vector_sparse_indptr, + IdType* __restrict__ kv_lens, const uint32_t stride_block, const uint32_t stride_n, + const uint32_t batch_size, const uint_fastdiv block_size) { +#pragma unroll 1 + for (int b = blockIdx.x; b < batch_size; ++b) { +#pragma unroll 2 + for (int pos = threadIdx.x; pos < kv_lens[b]; pos += blockDim.x) { + uint32_t q, r; + block_size.divmod(pos, q, r); + vector_sparse_offsets[vector_sparse_indptr[b] + pos] = + block_sparse_indices[block_sparse_indptr[b] + q] * stride_block + r * stride_n; + } + } +} + +template +cudaError_t BlockSparseIndicesToVectorSparseOffset( + IdType* block_sparse_indices, IdType* block_sparse_indptr, IdType* vector_sparse_offsets, + IdType* vector_sparse_indptr, IdType* kv_lens, const int64_t stride_block, + const int64_t stride_n, const int64_t batch_size, const uint32_t block_size, + cudaStream_t stream = nullptr) { + int dev_id = 0; + int num_sms = 0; + FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); + + uint32_t num_threads = 512; + + uint_fastdiv block_size_fastdiv(block_size); + + auto kernel = BlockSparseIndicesToVectorSparseOffsetsKernel; + void* args[] = {(void*)&block_sparse_indices, + (void*)&block_sparse_indptr, + (void*)&vector_sparse_offsets, + (void*)&vector_sparse_indptr, + (void*)&kv_lens, + (void*)&stride_block, + (void*)&stride_n, + (void*)&batch_size, + (void*)&block_size_fastdiv}; + + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, num_sms, num_threads, args, 0, stream)); + + return cudaSuccess; +} + /*! * \brief Append new keys/values to the paged key-value cache in the decode phase * \tparam DType The data type of the key-value cache diff --git a/licenses/LICENSE.cutlass.txt b/licenses/LICENSE.cutlass.txt new file mode 100644 index 000000000..525500841 --- /dev/null +++ b/licenses/LICENSE.cutlass.txt @@ -0,0 +1,27 @@ +Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/licenses/LICENSE.flashattention3.txt b/licenses/LICENSE.flashattention3.txt new file mode 100644 index 000000000..5860e4b33 --- /dev/null +++ b/licenses/LICENSE.flashattention3.txt @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/setup.py b/setup.py index 9bbf9895a..7cb292f2a 100644 --- a/setup.py +++ b/setup.py @@ -36,7 +36,9 @@ head_dims = list(map(int, head_dims)) pos_encoding_modes = list(map(int, pos_encoding_modes)) +pos_encoding_modes_sm90 = [mode for mode in pos_encoding_modes if mode != 2] allow_fp16_qk_reductions = list(map(int, allow_fp16_qk_reductions)) +allow_fp16_qk_reductions_sm90 = [mode for mode in allow_fp16_qk_reductions if mode != 1] mask_modes = list(map(int, mask_modes)) enable_aot = os.environ.get("FLASHINFER_ENABLE_AOT", "0") == "1" @@ -66,6 +68,7 @@ def generate_cuda() -> None: try: # no aot_build_utils in sdist sys.path.append(str(root)) from aot_build_utils.generate import get_instantiation_cu + from aot_build_utils.generate_sm90 import get_sm90_instantiation_cu except ImportError: return @@ -79,6 +82,15 @@ def generate_cuda() -> None: enable_bf16=enable_bf16, enable_fp8=enable_fp8, ) + ) + get_sm90_instantiation_cu( + argparse.Namespace( + path=gen_dir, + head_dims=head_dims, + pos_encoding_modes=pos_encoding_modes_sm90, + allow_fp16_qk_reductions=allow_fp16_qk_reductions_sm90, + mask_modes=mask_modes, + enable_bf16=enable_bf16, + ) ) aot_config_str = f"""prebuilt_ops_uri = set({aot_kernel_uris})""" (root / "flashinfer" / "jit" / "aot_config.py").write_text(aot_config_str) @@ -185,10 +197,15 @@ def __init__(self, *args, **kwargs) -> None: ] kernel_sm90_sources = [ "csrc/group_gemm_sm90.cu", - "csrc/flashinfer_gemm_sm90_ops.cu", + "csrc/single_prefill_sm90.cu", + "csrc/batch_prefill_sm90.cu", + "csrc/flashinfer_ops_sm90.cu", ] decode_sources = list(gen_dir.glob("*decode_head*.cu")) - prefill_sources = list(gen_dir.glob("*prefill_head*.cu")) + prefill_sources = [ + f for f in gen_dir.glob("*prefill_head*.cu") if "_sm90" not in f.name + ] + prefill_sm90_sources = list(gen_dir.glob("*prefill_head*_sm90.cu")) ext_modules = [ torch_cpp_ext.CUDAExtension( name="flashinfer._kernels", @@ -202,7 +219,7 @@ def __init__(self, *args, **kwargs) -> None: ), torch_cpp_ext.CUDAExtension( name="flashinfer._kernels_sm90", - sources=kernel_sm90_sources, + sources=kernel_sm90_sources + prefill_sm90_sources, include_dirs=include_dirs, extra_compile_args={ "cxx": cxx_flags, diff --git a/tests/test_block_sparse_indices_to_vector_sparse_offsets.py b/tests/test_block_sparse_indices_to_vector_sparse_offsets.py new file mode 100644 index 000000000..cf2ef003c --- /dev/null +++ b/tests/test_block_sparse_indices_to_vector_sparse_offsets.py @@ -0,0 +1,84 @@ +""" +Copyright (c) 2023 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import pytest +import torch + +import flashinfer.page + + +@pytest.mark.parametrize("batch_size", [1, 7, 19, 128, 517]) +@pytest.mark.parametrize("kv_len", [97, 199, 2049, 31791]) +@pytest.mark.parametrize("block_size", [1, 3, 7, 16, 64, 79, 128]) +@pytest.mark.parametrize("stride_block", [128]) +@pytest.mark.parametrize("stride_n", [1]) +def test_block_sparse_indices_to_vector_sparse_offsets( + batch_size, kv_len, block_size, stride_block, stride_n +): + if batch_size * kv_len > 1048576: + pytest.skip("skip large test") + num_blocks_per_row = (kv_len + block_size - 1) // block_size + + block_sparse_indices = torch.arange( + batch_size * num_blocks_per_row, device="cuda", dtype=torch.int32 + ) + block_sparse_indptr = torch.arange( + 0, + batch_size * num_blocks_per_row + 1, + num_blocks_per_row, + device="cuda", + dtype=torch.int32, + ) + vector_sparse_offsets_buf = torch.zeros( + batch_size * kv_len, device="cuda", dtype=torch.int32 + ) + vector_sparse_indptr = torch.arange( + 0, batch_size * kv_len + 1, kv_len, device="cuda", dtype=torch.int32 + ) + kv_lens = torch.full((batch_size,), kv_len, device="cuda", dtype=torch.int32) + + vector_sparse_offsets = ( + flashinfer.page.block_sparse_indices_to_vector_sparse_offsets( + block_sparse_indices, + block_sparse_indptr, + vector_sparse_offsets_buf, + vector_sparse_indptr, + kv_lens, + stride_block, + stride_n, + block_size, + ) + ) + + # Check that the output is correct + for i in range(batch_size): + indices_i = block_sparse_indices[ + i * num_blocks_per_row : (i + 1) * num_blocks_per_row + ].cpu() + output_i = vector_sparse_offsets[ + vector_sparse_indptr[i] : vector_sparse_indptr[i + 1] + ].cpu() + + output_ref_i = ( + indices_i[torch.arange(0, kv_len, dtype=torch.int32) // block_size] + * stride_block + + (torch.arange(0, kv_len, dtype=torch.int32) % block_size) * stride_n + ) + torch.testing.assert_close(output_i, output_ref_i) + + +if __name__ == "__main__": + pass diff --git a/tests/test_hopper.py b/tests/test_hopper.py new file mode 100644 index 000000000..1fbad5ffd --- /dev/null +++ b/tests/test_hopper.py @@ -0,0 +1,218 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import pytest +import torch + +import flashinfer + + +@pytest.mark.parametrize("seq_len", [11, 99, 1763, 9999, 32767]) +@pytest.mark.parametrize("num_qo_heads", [1, 4, 8]) +@pytest.mark.parametrize("num_kv_heads", [1, 4, 8]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +@pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0]) +def test_single_prefill( + seq_len, num_qo_heads, num_kv_heads, causal, head_dim, logits_soft_cap +): + if num_qo_heads % num_kv_heads != 0: + pytest.skip("num_qo_heads must be divisible by num_kv_heads") + torch.random.manual_seed(123) + q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda") + k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") + v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") + + o_sm80, lse_sm80 = flashinfer.single_prefill_with_kv_cache_return_lse( + q, + k, + v, + causal=causal, + logits_soft_cap=logits_soft_cap, + backend="fa2", + ) + + o_sm90, lse_sm90 = flashinfer.single_prefill_with_kv_cache_return_lse( + q, k, v, causal=causal, logits_soft_cap=logits_soft_cap, backend="fa3" + ) + torch.testing.assert_close(lse_sm80, lse_sm90, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(o_sm80, o_sm90, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [1, 4, 8, 16]) +@pytest.mark.parametrize("seq_len", [11, 99, 1763, 9999, 32767]) +@pytest.mark.parametrize("num_qo_heads", [1, 4, 8]) +@pytest.mark.parametrize("num_kv_heads", [1, 4, 8]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("head_dim", [128]) # [64, 128, 256]) +@pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0]) +def test_batch_ragged_prefill( + batch_size, seq_len, num_qo_heads, num_kv_heads, causal, head_dim, logits_soft_cap +): + if num_qo_heads % num_kv_heads != 0: + pytest.skip("num_qo_heads must be divisible by num_kv_heads") + torch.random.manual_seed(42) + q = torch.randn( + batch_size * seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda" + ) + k = torch.randn( + batch_size * seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + ) + v = torch.randn( + batch_size * seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + ) + + workspace_buffer = torch.empty( + 256 * 1024 * 1024, dtype=torch.uint8, device="cuda:0" + ) + + wrapper_sm80 = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, backend="fa2" + ) + + wrapper_sm90 = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, backend="fa3" + ) + + qo_indptr = torch.arange(0, batch_size * seq_len + 1, seq_len).int() + kv_indptr = torch.arange(0, batch_size * seq_len + 1, seq_len).int() + + wrapper_sm80.plan( + qo_indptr, + kv_indptr, + num_qo_heads, + num_kv_heads, + head_dim, + causal=causal, + logits_soft_cap=logits_soft_cap, + ) + o_sm80, lse_sm80 = wrapper_sm80.run_return_lse(q, k, v) + + wrapper_sm90.plan( + qo_indptr, + kv_indptr, + num_qo_heads, + num_kv_heads, + head_dim, + causal=causal, + logits_soft_cap=logits_soft_cap, + ) + o_sm90, lse_sm90 = wrapper_sm90.run_return_lse(q, k, v) + + torch.testing.assert_close(lse_sm80, lse_sm90, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(o_sm80, o_sm90, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [1, 4, 8, 16]) +@pytest.mark.parametrize("seq_len", [11, 12, 99, 1763, 9999, 32767]) +@pytest.mark.parametrize("page_size", [1]) # [1, 16]) +@pytest.mark.parametrize("num_qo_heads", [1, 4, 8]) +@pytest.mark.parametrize("num_kv_heads", [1, 4, 8]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +@pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0]) +def test_batch_paged_prefill( + batch_size, + seq_len, + page_size, + num_qo_heads, + num_kv_heads, + causal, + head_dim, + logits_soft_cap, +): + if num_qo_heads % num_kv_heads != 0: + pytest.skip("num_qo_heads must be divisible by num_kv_heads") + torch.random.manual_seed(42) + q = torch.randn( + batch_size * seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda" + ) + num_pages_per_request = (seq_len + page_size - 1) // page_size + k = torch.randn( + batch_size * num_pages_per_request, + page_size, + num_kv_heads, + head_dim, + dtype=torch.half, + device="cuda", + ) + v = torch.randn( + batch_size * num_pages_per_request, + page_size, + num_kv_heads, + head_dim, + dtype=torch.half, + device="cuda", + ) + + workspace_buffer = torch.empty( + 256 * 1024 * 1024, dtype=torch.uint8, device="cuda:0" + ) + + wrapper_sm80 = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, backend="fa2" + ) + + wrapper_sm90 = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, backend="fa3" + ) + + last_page_len = seq_len - (num_pages_per_request - 1) * page_size + qo_indptr = torch.arange(0, batch_size * seq_len + 1, seq_len).int() + kv_indptr = torch.arange( + 0, batch_size * num_pages_per_request + 1, num_pages_per_request + ).int() + kv_indices = torch.arange(0, batch_size * num_pages_per_request).int() + last_page_len = torch.full((batch_size,), last_page_len, dtype=torch.int32) + + wrapper_sm80.plan( + qo_indptr, + kv_indptr, + kv_indices, + last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + causal=causal, + logits_soft_cap=logits_soft_cap, + ) + o_sm80, lse_sm80 = wrapper_sm80.run_return_lse(q, (k, v)) + + wrapper_sm90.plan( + qo_indptr, + kv_indptr, + kv_indices, + last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + causal=causal, + logits_soft_cap=logits_soft_cap, + ) + o_sm90, lse_sm90 = wrapper_sm90.run_return_lse(q, (k, v)) + + torch.testing.assert_close(lse_sm80, lse_sm90, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(o_sm80, o_sm90, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + # test_batch_prefill(14, 64, 32, 32, False, 128) + # test_batch_prefill(1, 32767, 8, 8, True, 128) + # test_single_prefill(64, 1, 1, False, 256) + # test_batch_paged_prefill(2, 32768, 1, 1, 1, False, 128) + test_batch_paged_prefill(16, 32767, 1, 8, 8, True, 128)