From 611dfd1b0d9e0c126792282b4ac6984d61246c2b Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Sun, 26 Nov 2023 14:06:01 -0800 Subject: [PATCH] Vendor punica kernels (#63) --- .github/workflows/build.yaml | 2 + .gitmodules | 3 + Dockerfile | 10 +- server/Makefile-punica | 13 - server/lorax_server/utils/layers.py | 2 +- server/lorax_server/utils/sgmv.py | 36 + server/punica_kernels/README.md | 1 + .../punica_kernels/bgmv/bgmv_all.cu | 5 + .../punica_kernels/bgmv/bgmv_config.h | 38 + .../punica_kernels/bgmv/bgmv_impl.cuh | 217 +++ .../punica_kernels/flashinfer/.clang-format | 5 + .../punica_kernels/flashinfer/cp_async.cuh | 124 ++ .../punica_kernels/flashinfer/decode.cuh | 1135 +++++++++++++ .../punica_kernels/flashinfer/layout.cuh | 108 ++ .../punica_kernels/flashinfer/math.cuh | 48 + .../punica_kernels/flashinfer/mma.cuh | 153 ++ .../punica_kernels/flashinfer/page.cuh | 366 +++++ .../flashinfer/permuted_smem.cuh | 99 ++ .../punica_kernels/flashinfer/prefill.cuh | 932 +++++++++++ .../punica_kernels/flashinfer/rope.cuh | 51 + .../punica_kernels/flashinfer/state.cuh | 122 ++ .../punica_kernels/flashinfer/utils.cuh | 120 ++ .../punica_kernels/flashinfer/vec_dtypes.cuh | 1420 +++++++++++++++++ .../flashinfer_adapter/flashinfer_all.cu | 89 ++ .../flashinfer_adapter/flashinfer_config.h | 30 + .../punica_kernels/punica_ops.cc | 403 +++++ .../punica_kernels/rms_norm/rms_norm.h | 4 + .../rms_norm/rms_norm_cutlass.cu | 189 +++ .../punica_kernels/punica_kernels/sgmv/sgmv.h | 5 + .../punica_kernels/sgmv/sgmv_cutlass.cu | 12 + .../punica_kernels/sgmv/sgmv_cutlass.cuh | 153 ++ .../sgmv_flashinfer/cp_async.cuh | 119 ++ .../punica_kernels/sgmv_flashinfer/mma.cuh | 123 ++ .../sgmv_flashinfer/permuted_smem.cuh | 83 + .../sgmv_flashinfer/sgmv_all.cu | 67 + .../sgmv_flashinfer/sgmv_config.h | 15 + .../sgmv_flashinfer/sgmv_flashinfer.cuh | 312 ++++ .../sgmv_flashinfer/vec_dtypes.cuh | 1420 +++++++++++++++++ server/punica_kernels/setup.py | 42 + server/punica_kernels/third_party/cutlass | 1 + 40 files changed, 8057 insertions(+), 20 deletions(-) create mode 100644 .gitmodules delete mode 100644 server/Makefile-punica create mode 100644 server/lorax_server/utils/sgmv.py create mode 100644 server/punica_kernels/README.md create mode 100644 server/punica_kernels/punica_kernels/bgmv/bgmv_all.cu create mode 100644 server/punica_kernels/punica_kernels/bgmv/bgmv_config.h create mode 100644 server/punica_kernels/punica_kernels/bgmv/bgmv_impl.cuh create mode 100644 server/punica_kernels/punica_kernels/flashinfer/.clang-format create mode 100644 server/punica_kernels/punica_kernels/flashinfer/cp_async.cuh create mode 100644 server/punica_kernels/punica_kernels/flashinfer/decode.cuh create mode 100644 server/punica_kernels/punica_kernels/flashinfer/layout.cuh create mode 100644 server/punica_kernels/punica_kernels/flashinfer/math.cuh create mode 100644 server/punica_kernels/punica_kernels/flashinfer/mma.cuh create mode 100644 server/punica_kernels/punica_kernels/flashinfer/page.cuh create mode 100644 server/punica_kernels/punica_kernels/flashinfer/permuted_smem.cuh create mode 100644 server/punica_kernels/punica_kernels/flashinfer/prefill.cuh create mode 100644 server/punica_kernels/punica_kernels/flashinfer/rope.cuh create mode 100644 server/punica_kernels/punica_kernels/flashinfer/state.cuh create mode 100644 server/punica_kernels/punica_kernels/flashinfer/utils.cuh create mode 100644 server/punica_kernels/punica_kernels/flashinfer/vec_dtypes.cuh create mode 100644 
server/punica_kernels/punica_kernels/flashinfer_adapter/flashinfer_all.cu create mode 100644 server/punica_kernels/punica_kernels/flashinfer_adapter/flashinfer_config.h create mode 100644 server/punica_kernels/punica_kernels/punica_ops.cc create mode 100644 server/punica_kernels/punica_kernels/rms_norm/rms_norm.h create mode 100644 server/punica_kernels/punica_kernels/rms_norm/rms_norm_cutlass.cu create mode 100644 server/punica_kernels/punica_kernels/sgmv/sgmv.h create mode 100644 server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cu create mode 100644 server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cuh create mode 100644 server/punica_kernels/punica_kernels/sgmv_flashinfer/cp_async.cuh create mode 100644 server/punica_kernels/punica_kernels/sgmv_flashinfer/mma.cuh create mode 100644 server/punica_kernels/punica_kernels/sgmv_flashinfer/permuted_smem.cuh create mode 100644 server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_all.cu create mode 100644 server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_config.h create mode 100644 server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_flashinfer.cuh create mode 100644 server/punica_kernels/punica_kernels/sgmv_flashinfer/vec_dtypes.cuh create mode 100644 server/punica_kernels/setup.py create mode 160000 server/punica_kernels/third_party/cutlass diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 06ba18e95..49030f7ea 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -30,6 +30,8 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v3 + with: + submodules: recursive - name: Docker meta id: meta diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..3b3c4bc90 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "server/punica_kernels/third_party/cutlass"] + path = server/punica_kernels/third_party/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/Dockerfile b/Dockerfile index 4b29b8f93..5c78ac040 100644 --- a/Dockerfile +++ b/Dockerfile @@ -144,15 +144,13 @@ RUN make build-vllm # Build punica CUDA kernels FROM kernel-builder as punica-builder -RUN /opt/conda/bin/conda install packaging - WORKDIR /usr/src -COPY server/Makefile-punica Makefile +COPY server/punica_kernels/ . 
-ENV TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" # Build specific version of punica -RUN make build-punica +ENV TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" +RUN python setup.py build # Text Generation Inference base image FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base @@ -195,7 +193,7 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 / COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages # Copy builds artifacts from punica builder -COPY --from=punica-builder /usr/src/punica/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages +COPY --from=punica-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages # Install flash-attention dependencies RUN pip install einops --no-cache-dir diff --git a/server/Makefile-punica b/server/Makefile-punica deleted file mode 100644 index 71471e975..000000000 --- a/server/Makefile-punica +++ /dev/null @@ -1,13 +0,0 @@ -punica_commit := 5ccb1d62ede179bab6c91dfb2f6f320cc1c6b76d - -punica: - # Clone punica - git clone https://github.com/predibase/punica.git --recurse - -build-punica: punica - cd punica && git fetch && git checkout $(punica_commit) - cd punica && python setup.py build - -install-punica: build-punica - pip uninstall punica -y || true - cd punica && python setup.py install \ No newline at end of file diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index e19746246..65043b39b 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -19,7 +19,7 @@ from accelerate import init_empty_weights try: - from punica.ops import add_lora_sgmv_cutlass + from lorax_server.utils.sgmv import add_lora_sgmv_cutlass HAS_SGMV = True except ImportError: warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.") diff --git a/server/lorax_server/utils/sgmv.py b/server/lorax_server/utils/sgmv.py new file mode 100644 index 000000000..ffcdfb1cc --- /dev/null +++ b/server/lorax_server/utils/sgmv.py @@ -0,0 +1,36 @@ +import torch + + +import punica_kernels as _kernels + + +# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py +def add_lora_sgmv_cutlass( + y: torch.Tensor, + x: torch.Tensor, + wa_ptr: torch.Tensor, + wb_ptr: torch.Tensor, + s: torch.IntTensor, + layer_idx: int, + lora_rank: int, +): + """ + Semantics: + y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]) @ deref(wb_ptr[i]) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\ + Weight matrix shape: `[num_layers, H1, R]`. + wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\ + Weight matrix shape: `[num_layers, R, H2]`. + s: Shape: `[S+1]`, DType: torch.int32. Indptr of the weight matrices.\ + `s[0] == 0`, `s[-1] == B`. + layer_idx: Layer index of the weight matrices. 
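+        lora_rank: Rank `R` of the LoRA weight matrices.
+
+    Example (illustrative sketch only; tensor names are hypothetical and the
+    shapes/dtypes simply follow the descriptions above, with a single segment
+    covering the whole batch):
+
+        B, H1, H2, R, num_layers = 4, 4096, 4096, 16, 1
+        x = torch.randn(B, H1, dtype=torch.float16, device="cuda")
+        y = torch.zeros(B, H2, dtype=torch.float16, device="cuda")
+        wa = torch.randn(num_layers, H1, R, dtype=torch.float16, device="cuda")
+        wb = torch.randn(num_layers, R, H2, dtype=torch.float16, device="cuda")
+        wa_ptr = torch.tensor([wa.data_ptr()], dtype=torch.int64, device="cuda")
+        wb_ptr = torch.tensor([wb.data_ptr()], dtype=torch.int64, device="cuda")
+        s = torch.tensor([0, B], dtype=torch.int32, device="cuda")
+        add_lora_sgmv_cutlass(y, x, wa_ptr, wb_ptr, s, layer_idx=0, lora_rank=R)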
+ """ + tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0)) + tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device) + v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) + _kernels.sgmv_cutlass(v, x, wa_ptr, s, tmp, layer_idx) + _kernels.sgmv_cutlass(y, v, wb_ptr, s, tmp, layer_idx) diff --git a/server/punica_kernels/README.md b/server/punica_kernels/README.md new file mode 100644 index 000000000..8cedbdf8a --- /dev/null +++ b/server/punica_kernels/README.md @@ -0,0 +1 @@ +These kernels are forked from the [Punica](https://github.com/punica-ai/punica) project. \ No newline at end of file diff --git a/server/punica_kernels/punica_kernels/bgmv/bgmv_all.cu b/server/punica_kernels/punica_kernels/bgmv/bgmv_all.cu new file mode 100644 index 000000000..d76fc97eb --- /dev/null +++ b/server/punica_kernels/punica_kernels/bgmv/bgmv_all.cu @@ -0,0 +1,5 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16) diff --git a/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h b/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h new file mode 100644 index 000000000..26edcf486 --- /dev/null +++ b/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h @@ -0,0 +1,38 @@ +#pragma once + +template +void bgmv_kernel(T* __restrict__ Y, const T* __restrict__ X, + const T* __restrict__ W, const int64_t* __restrict__ indicies, + int64_t batch_size, int64_t num_layers, int64_t layer_idx, + float scale); + +// clang-format off + +#define FOR_BGMV_WIDE(f, T, narrow) \ + f(T, narrow, 768) \ + f(T, narrow, 1024) \ + f(T, narrow, 2048) \ + f(T, narrow, 2560) \ + f(T, narrow, 3072) \ + f(T, narrow, 4096) \ + f(T, narrow, 5120) \ + f(T, narrow, 7168) \ + f(T, narrow, 8192) \ + f(T, narrow, 9216) \ + f(T, narrow, 10240) \ + f(T, narrow, 11008) \ + f(T, narrow, 12288) \ + f(T, narrow, 13824) \ + f(T, narrow, 16384) \ + f(T, narrow, 20480) \ + f(T, narrow, 28672) \ + f(T, narrow, 36864) \ + f(T, narrow, 49152) \ + +#define FOR_BGMV_WIDE_NARROW(f, T) \ + FOR_BGMV_WIDE(f, T, 8) \ + FOR_BGMV_WIDE(f, T, 16) \ + FOR_BGMV_WIDE(f, T, 32) \ + FOR_BGMV_WIDE(f, T, 64) + +// clang-format on diff --git a/server/punica_kernels/punica_kernels/bgmv/bgmv_impl.cuh b/server/punica_kernels/punica_kernels/bgmv/bgmv_impl.cuh new file mode 100644 index 000000000..4164e94c6 --- /dev/null +++ b/server/punica_kernels/punica_kernels/bgmv/bgmv_impl.cuh @@ -0,0 +1,217 @@ +#pragma once + +#include +#include + +#include +#include + +#include "../flashinfer/vec_dtypes.cuh" + +namespace cg = cooperative_groups; + +// nthrs = (32, 4) +template +__global__ void bgmv_shrink_kernel(T* __restrict__ Y, const T* __restrict__ X, + const T* __restrict__ W, + const int64_t* __restrict__ indicies, + int64_t num_layers, int64_t layer_idx, + float scale) { + auto block = cg::this_thread_block(); + size_t j = blockIdx.x; + size_t batch_idx = blockIdx.y; + constexpr size_t vec_size = 16 / sizeof(T); + constexpr size_t tx = 32; + constexpr size_t ty = 4; + constexpr size_t num_pipeline_stages = 2; + constexpr size_t tile_size = tx * ty * vec_size; + __shared__ T W_shared[num_pipeline_stages * tile_size]; + __shared__ T X_shared[num_pipeline_stages * tile_size]; + __shared__ float y_warpwise[ty]; + + int64_t idx = indicies[batch_idx] * num_layers + layer_idx; + + size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; + size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; + auto pipe = 
cuda::make_pipeline(); + + // pipeline load W/X and compute WX; + pipe.producer_acquire(); + cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, + W + (idx * feat_out + j) * feat_in + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t<16>(16), pipe); + cuda::memcpy_async( + X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, + X + (batch_idx * feat_in) + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t<16>(16), pipe); + pipe.producer_commit(); + size_t copy_idx, compute_idx; + float y = 0.f; + flashinfer::vec_t x_vec, w_vec; + size_t tile_idx; + +#pragma unroll + for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size; + ++tile_idx) { + copy_idx = tile_idx % num_pipeline_stages; + // pipeline stage: async copy W fragment + pipe.producer_acquire(); + if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) { + cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size, + W + (idx * feat_out + j) * feat_in + + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t<16>(16), pipe); + cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size, + X + (batch_idx * feat_in) + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t<16>(16), pipe); + } + pipe.producer_commit(); + + compute_idx = (tile_idx - 1) % num_pipeline_stages; + // pipeline stage: compute WX + pipe.consumer_wait(); + block.sync(); + x_vec.load(X_shared + X_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + w_vec.load(W_shared + W_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + sum += float(w_vec[i]) * float(x_vec[i]) * scale; + } +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) { + sum += __shfl_down_sync(0xffffffff, sum, offset); + } + y_warpwise[threadIdx.y] = sum; + block.sync(); +#pragma unroll + for (size_t i = 0; i < ty; ++i) { + y += y_warpwise[i]; + } + + block.sync(); + pipe.consumer_release(); + } + + compute_idx = (tile_idx - 1) % num_pipeline_stages; + // final pipeline stage + pipe.consumer_wait(); + block.sync(); + x_vec.load(X_shared + X_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + w_vec.load(W_shared + W_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + sum += float(w_vec[i]) * float(x_vec[i]) * scale; + } +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) { + sum += __shfl_down_sync(0xffffffff, sum, offset); + } + y_warpwise[threadIdx.y] = + ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in) + ? 
sum + : 0.f; + block.sync(); +#pragma unroll + for (size_t i = 0; i < ty; ++i) { + y += y_warpwise[i]; + } + + block.sync(); + pipe.consumer_release(); + + // write Y; + if (block.thread_rank() == 0) { + Y[batch_idx * feat_out + j] += y; + } +} + +// nthrs = (2, 16, 4) +template +__global__ void bgmv_expand_kernel(T* __restrict__ Y, const T* __restrict__ X, + const T* __restrict__ W, + const int64_t* __restrict__ indicies, + int64_t num_layers, int64_t layer_idx, + float scale) { + auto block = cg::this_thread_block(); + constexpr size_t vec_size = 16 / sizeof(T); + constexpr size_t tx = feat_in / vec_size; + static_assert(feat_in % vec_size == 0); + constexpr size_t ty = 32 / tx; + static_assert(32 % tx == 0); + constexpr size_t tz = 4; + size_t tile_idx = blockIdx.x; + size_t batch_idx = blockIdx.y; + int64_t idx = indicies[batch_idx] * num_layers + layer_idx; + + // load X; + flashinfer::vec_t x_vec; + x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size); + + // load W; + flashinfer::vec_t w_vec; + w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in + + block.thread_rank() * vec_size); + + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + sum += float(w_vec[i]) * float(x_vec[i]) * scale; + } + + cg::thread_block_tile g = cg::tiled_partition(block); +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) { + sum += g.shfl_down(sum, offset); + } + sum = g.shfl(sum, 0); + + if (threadIdx.x == 0) { + Y[batch_idx * feat_out + tile_idx * (tz * ty) + threadIdx.z * ty + + threadIdx.y] += sum; + } +} + +template +void bgmv_kernel(T* __restrict__ Y, const T* __restrict__ X, + const T* __restrict__ W, const int64_t* __restrict__ indicies, + int64_t batch_size, int64_t num_layers, int64_t layer_idx, + float scale) { + size_t vec_size = 16 / sizeof(T); + if constexpr (feat_in < feat_out) { + size_t tx = feat_in / vec_size; + size_t ty = 32 / tx; + size_t tz = 4; + dim3 nblks(feat_out / (ty * tz), batch_size); + dim3 nthrs(tx, ty, tz); + + bgmv_expand_kernel + <<>>(Y, X, W, indicies, num_layers, layer_idx, scale); + } else { + assert(feat_in % (vec_size * 32) == 0); + dim3 nblks(feat_out, batch_size); + dim3 nthrs(32, 4); + bgmv_shrink_kernel + <<>>(Y, X, W, indicies, num_layers, layer_idx, scale); + } +} + +#define INST_BGMV(feat_in, feat_out, T) \ + template void bgmv_kernel( \ + T* __restrict__ Y, const T* __restrict__ X, const T* __restrict__ W, \ + const int64_t* __restrict__ indicies, int64_t batch_size, \ + int64_t num_layers, int64_t layer_idx, float scale); + +#define INST_BGMV_TWOSIDE(T, narrow, wide) \ + INST_BGMV(narrow, wide, T) \ + INST_BGMV(wide, narrow, T) diff --git a/server/punica_kernels/punica_kernels/flashinfer/.clang-format b/server/punica_kernels/punica_kernels/flashinfer/.clang-format new file mode 100644 index 000000000..9c656d3fb --- /dev/null +++ b/server/punica_kernels/punica_kernels/flashinfer/.clang-format @@ -0,0 +1,5 @@ +# https://github.com/yzh119/flashinfer/blob/main/.clang-format +BasedOnStyle: Google +DerivePointerAlignment: false +ColumnLimit: 100 +PointerAlignment: Left diff --git a/server/punica_kernels/punica_kernels/flashinfer/cp_async.cuh b/server/punica_kernels/punica_kernels/flashinfer/cp_async.cuh new file mode 100644 index 000000000..f3974eb92 --- /dev/null +++ b/server/punica_kernels/punica_kernels/flashinfer/cp_async.cuh @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2023 by FlashInfer team. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_CP_ASYNC_CUH_ +#define FLASHINFER_CP_ASYNC_CUH_ + +#include + +namespace flashinfer { + +namespace cp_async { + +__device__ __forceinline__ void commit_group() { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && \ + (__CUDACC_VER_MAJOR__ >= 11) + asm volatile("cp.async.commit_group;\n" ::); +#endif +} + +template +__device__ __forceinline__ void wait_group() { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && \ + (__CUDACC_VER_MAJOR__ >= 11) + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +#endif +} + +template +__device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && \ + (__CUDACC_VER_MAJOR__ >= 11) + + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (prefetch) { + asm volatile( + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"( + smem_int_ptr), + "l"(gmem_ptr), "n"(16), "r"(16)); + } else { + asm volatile( + "cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "n"(16), "r"(16)); + } +#else + *((uint4*)smem_ptr) = *((uint4*)gmem_ptr); +#endif +} + +template +__device__ __forceinline__ void pred_load_128b(T* smem_ptr, const T* gmem_ptr, + bool predicate) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && \ + (__CUDACC_VER_MAJOR__ >= 11) + + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (prefetch) { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), "l"(gmem_ptr), "n"(16)); + } else { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), "l"(gmem_ptr), "n"(16)); + } +#else + if (predicate) { + *((uint4*)smem_ptr) = *((uint4*)gmem_ptr); + } +#endif +} + +template +__device__ __forceinline__ void load(T* smem_ptr, const T* gmem_ptr) { + static_assert(num_bits == 128 || num_bits == 256, + "num_bits must be 128 or 256"); + if constexpr (num_bits == 128) { + load_128b(smem_ptr, gmem_ptr); + } else { + load_128b(smem_ptr, gmem_ptr); + load_128b(smem_ptr + 16 / sizeof(T), gmem_ptr + 16 / sizeof(T)); + } +} + +template +__device__ __forceinline__ void pred_load(T* smem_ptr, const T* gmem_ptr, + bool predicate) { + static_assert(num_bits == 128 || num_bits == 256, + "num_bits must be 128 or 256"); + if constexpr (num_bits == 128) { + pred_load_128b(smem_ptr, gmem_ptr, predicate); + } else { + pred_load_128b(smem_ptr, gmem_ptr, predicate); + pred_load_128b(smem_ptr + 16 / sizeof(T), + gmem_ptr + 16 / sizeof(T), predicate); + } +} + +} // namespace cp_async + +} // namespace flashinfer + +#endif // FLASHINFER_CP_ASYNC_CUH_ diff --git a/server/punica_kernels/punica_kernels/flashinfer/decode.cuh 
b/server/punica_kernels/punica_kernels/flashinfer/decode.cuh new file mode 100644 index 000000000..6783e37f9 --- /dev/null +++ b/server/punica_kernels/punica_kernels/flashinfer/decode.cuh @@ -0,0 +1,1135 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_DECODE_CUH_ +#define FLASHINFER_DECODE_CUH_ +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "cp_async.cuh" +#include "layout.cuh" +#include "math.cuh" +#include "page.cuh" +#include "rope.cuh" +#include "state.cuh" +#include "utils.cuh" +#include "vec_dtypes.cuh" + +namespace flashinfer { + +namespace cg = cooperative_groups; + +namespace { + +/*! + * \brief Apply RoPE (Rotary Positional Embeddings) to x[0: head_dim], + * return thread-local vector + * \tparam vec_size A template integer indicates the vector size used + * in the kernel + * \tparam bdx A template integer indicates the blockDim.x + * \tparam T A template type indicates the x data type + * \param x A pointer to the start of x data + * \param freq A vector of float indicates the thread-local rope frequency + * \param offset A integer indicates the offset of the position in RoPE + */ +template +__device__ __forceinline__ vec_t apply_llama_rope( + const T* x, const vec_t& freq, uint32_t offset) { + constexpr uint32_t head_dim = vec_size * bdx; + vec_t permuted_vec, vec; + vec.cast_load(x + threadIdx.x * vec_size); + permuted_vec.cast_load(x + ((threadIdx.x * vec_size < head_dim / 2) + ? threadIdx.x * vec_size + head_dim / 2 + : threadIdx.x * vec_size - head_dim / 2)); + +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + float embed = float(offset) * freq[i]; + float cos, sin; + __sincosf(embed, &sin, &cos); + vec[i] = vec[i] * cos + ((threadIdx.x * vec_size < head_dim / 2) + ? -permuted_vec[i] + : permuted_vec[i]) * + sin; + } + return vec; +} + +/*! + * \brief Load k tile from smem and compute qk. + * \tparam rotary_mode The rotary mode used in the kernel + * \tparam head_dim A template integer indicates the head dimension + * \tparam vec_size A template integer indicates the vector size + * \tparam bdx A template integer indicates the block size in x dimension + * \tparam T A template type indicates the input data type + * \param smem A pointer to the start of shared memory + * \param q_vec A vector of float indicates the thread-local query vector + * \param freq A vector of float indicates the thread-local rope frequency + * \param kv_shared_offset An array of uint32_t indicates the k/v tiles offset + * in shared memory of different pipeline stages \param kv_idx A integer + * indicates the thread-local kv position in kv-cache. 
\param compute_stage_idx + * A integer indicates the compute stage index in the pipeline \param sm_scale A + * float indicates the scale applied to pre-softmax logits \param x A float + * indicates the thread-local result of qk + */ +template +__device__ __forceinline__ void compute_qk(const T* smem, + const vec_t& q_vec, + const vec_t& freq, + uint32_t kv_idx_base, + uint32_t compute_stage_idx, + float sm_scale, float* x) { + uint32_t tx = threadIdx.x, tz = threadIdx.z; +#pragma unroll + for (uint32_t iy = 0; iy < bdy; ++iy) { + vec_t k_vec; + if constexpr (rotary_mode == RotaryMode::kLlama) { + // apply rotary embedding for all rows in k matrix of kv-cache + k_vec = apply_llama_rope(smem + iy * bdx * vec_size, freq, + kv_idx_base + tz * bdy + iy); + } else { + // do not apply rotary embedding + k_vec.cast_load(smem + (iy * bdx + tx) * vec_size); + } + x[iy] = 0.f; +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + x[iy] += q_vec[i] * k_vec[i] * sm_scale; + } +#pragma unroll + for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) { + x[iy] += math::shfl_xor_sync(x[iy], offset); + } + } +} + +/*! + * \brief Load v tile from shared memory and update partial state. + * \tparam vec_size A template integer indicates the vector size + * \tparam bdx A template integer indicates the block size in x dimension + * \tparam T A template type indicates the input data type + * \tparam norm_on_the_fly Whether to normalize on the fly or not + * \param smem A pointer to the start of shared memory + * \param x A float indicates the pre-softmax logits + * \param kv_shared_offset An array of uint32_t indicates the k/v tiles offset + * in shared memory of different pipeline stages. \param compute_stage_idx A + * integer indicates the compute stage index in the pipeline \param pred_guard A + * boolean indicates whether the current thread is in the valid range \param s + * The flashattention state to be updated + */ +template +__device__ __forceinline__ void update_partial_state( + const T* smem, const float* x, uint32_t compute_stage_idx, + uint32_t kv_idx_base, uint32_t kv_idx_bound, + state_t& s) { + uint32_t tx = threadIdx.x, tz = threadIdx.z; +#pragma unroll + for (uint32_t iy = 0; iy < bdy; ++iy) { + vec_t v_vec; + v_vec.cast_load(smem + (iy * bdx + tx) * vec_size); + if (kv_idx_base + tz * bdy + iy < kv_idx_bound) { + s.merge(v_vec, x[iy]); + } + } +} + +/*! + * \brief Synchronize the state of all warps inside a threadblock. 
+ * \tparam vec_size A template integer indicates the vector size + * \tparam bdx A template integer indicates the block size in x dimension + * \tparam bdy A template integer indicates the block size in y dimension + * \tparam norm_on_the_fly Whether to normalize on the fly or not + * \param s The warp local state + * \param smem The pointer to shared memory buffer for o + * \param smem_md The pointer to shared memory buffer for m/d + */ +template +__device__ __forceinline__ void sync_state( + state_t& s, float* smem, float* smem_md) { + if constexpr (bdz > 1) { + constexpr uint32_t head_dim = bdx * vec_size; + auto block = cg::this_thread_block(); + uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; + s.o.store(smem + (tz * bdy + ty) * head_dim + tx * vec_size); + smem_md[(tz * bdy + ty) * 2] = s.m; + smem_md[(tz * bdy + ty) * 2 + 1] = s.d; + block.sync(); + s.init(); +#pragma unroll + for (uint32_t iz = 0; iz < bdz; ++iz) { + float mz = smem_md[(iz * bdy + ty) * 2], + dz = smem_md[(iz * bdy + ty) * 2 + 1]; + vec_t oz; + oz.load(smem + (iz * bdy + ty) * head_dim + tx * vec_size); + s.merge(oz, mz, dz); + } + } +} + +} // namespace + +/*! + * \brief FlashAttention decoding cuda kernel with kv-cache for a single + * sequence, fused with RoPE. + * \tparam layout The layout of k/v matrices (NHD or HND) + * \tparam cooperative Whether to use cooperative kernel or not + * \tparam norm_on_the_fly Whether to normalize on the fly or not + * \tparam rotary_mode The rotary mode + * \tparam vec_size A template integer indicates the vector size + * \tparam bdx A template integer indicates the block size in x dimension + * \tparam bdy A template integer indicates the block size in y dimension + * \tparam DTypeIn A template type indicates the input data type + * \tparam DTypeOut A template type indicates the output data type + * \param q [num_qo_heads, head_dim] The query matrix + * \param k [seq_len, num_kv_heads, head_dim] The key matrix in kv-cache + * \param v [seq_len, num_kv_heads, head_dim] The value matrix in kv-cache + * \param o [num_qo_heads, head_dim] The output matrix + * \param tmp Used-allocated temporary buffer + * \param info The tensor info of k/v matrices + * \param sm_scale A float indicates the scale applied to pre-softmax logits + * \param head_dim A integer indicates the head dimension + * \param rope_inv_scale A floating number indicate the multiplicative inverse + * of scaling ratio used in PI(Position Interpolation) for RoPE (Rotary + * Positional Embeddings) + * \param rope_inv_theta A floating number indicate the multiplicative inverse + * of "theta" used in RoPE (Rotary Positional Embeddings) + * \param kv_chunk_size A integer indicates the kv-chunk size + */ +template +__global__ void SingleDecodeWithKVCacheKernel( + DTypeIn* __restrict__ q, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v, + DTypeOut* __restrict__ o, float* __restrict__ tmp, + tensor_info_t info, float sm_scale, float rope_inv_scale, + float rope_inv_theta, uint32_t kv_chunk_size) { + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + sm_scale *= math::log2e; + + constexpr uint32_t head_dim = bdx * vec_size; + uint32_t kv_head_idx = blockIdx.y; + uint32_t qo_head_idx = kv_head_idx * bdy + threadIdx.y; + uint32_t kv_chunk_idx = blockIdx.x; + uint32_t num_kv_chunks = gridDim.x; + uint32_t num_qo_heads = info.get_num_qo_heads(); + uint32_t seq_len = info.kv_len; + + extern __shared__ uint8_t smem[]; + DTypeIn* k_smem = (DTypeIn*)smem; + DTypeIn* v_smem = (DTypeIn*)(smem + 
num_stages_smem * bdy * bdz * head_dim * + sizeof(DTypeIn)); + float* smem_md = (float*)(smem + 2 * num_stages_smem * bdy * bdz * head_dim * + sizeof(DTypeIn)); + + uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; + vec_t q_vec; + vec_t freq; + if constexpr (rotary_mode == RotaryMode::kLlama) { +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + freq[i] = rope_inv_scale * + powf(rope_inv_theta, + float(2 * ((tx * vec_size + i) % (head_dim / 2))) / + float(head_dim)); + } + // apply rotary embedding to q matrix + q_vec = apply_llama_rope( + q + info.get_qo_elem_offset(0, qo_head_idx, 0), freq, seq_len - 1); + } else { + // do not apply rotary embedding to q matrix + q_vec.cast_load(q + info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size)); + } + block.sync(); + + uint32_t chunk_start = kv_chunk_idx * kv_chunk_size; + kv_chunk_size = min(kv_chunk_size, seq_len - chunk_start); + uint32_t chunk_end = chunk_start + kv_chunk_size; + + // preload k tiles and v tiles + uint32_t producer_kv_idx_base = chunk_start; + constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8; +#pragma unroll + for (uint32_t iter = 0; iter < num_stages_smem; ++iter) { + cp_async::pred_load( + k_smem + ((iter * bdz + tz) * bdy + ty) * head_dim + tx * vec_size, + k + info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, + kv_head_idx, tx * vec_size), + producer_kv_idx_base + tz * bdy + ty < chunk_end); + cp_async::commit_group(); + cp_async::pred_load( + v_smem + ((iter * bdz + tz) * bdy + ty) * head_dim + tx * vec_size, + v + info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, + kv_head_idx, tx * vec_size), + producer_kv_idx_base + tz * bdy + ty < chunk_end); + cp_async::commit_group(); + producer_kv_idx_base += bdy * bdz; + } + + // pipelining k/v tiles loading and state updating + uint32_t consumer_kv_idx_base = chunk_start, stage_idx = 0; + state_t s_partial; + float x[bdy]; + +#pragma unroll 4 + for (uint32_t iter = 0; iter < (kv_chunk_size + bdy * bdz - 1) / (bdy * bdz); + ++iter) { + // compute qk + cp_async::wait_group<2 * num_stages_smem - 1>(); + block.sync(); + compute_qk( + k_smem + (stage_idx * bdz + tz) * bdy * head_dim, q_vec, freq, + consumer_kv_idx_base, stage_idx, sm_scale, x); + block.sync(); + // load k + cp_async::pred_load( + k_smem + ((stage_idx * bdz + tz) * bdy + ty) * head_dim + tx * vec_size, + k + info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, + kv_head_idx, tx * vec_size), + producer_kv_idx_base + tz * bdy + ty < chunk_end); + cp_async::commit_group(); + + // update m/d/o state + cp_async::wait_group<2 * num_stages_smem - 1>(); + block.sync(); + update_partial_state( + v_smem + (stage_idx * bdz + tz) * bdy * head_dim, x, stage_idx, + consumer_kv_idx_base, chunk_end, s_partial); + block.sync(); + + // load v + cp_async::pred_load( + v_smem + ((stage_idx * bdz + tz) * bdy + ty) * head_dim + tx * vec_size, + v + info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, + kv_head_idx, tx * vec_size), + producer_kv_idx_base + tz * bdy + ty < chunk_end); + cp_async::commit_group(); + + stage_idx = (stage_idx + 1) % num_stages_smem; + producer_kv_idx_base += bdy * bdz; + consumer_kv_idx_base += bdy * bdz; + } + cp_async::wait_group<0>(); + block.sync(); + + // sync partial state of all warps inside a threadblock + sync_state(s_partial, reinterpret_cast(smem), + smem_md); + + if constexpr (cooperative) { + // update tmp buffer + s_partial.o.store(tmp + + (qo_head_idx * num_kv_chunks + kv_chunk_idx) * head_dim + + tx * vec_size); + 
float* tmp_md = tmp + num_qo_heads * num_kv_chunks * head_dim; + *(float2*)&tmp_md[(qo_head_idx * num_kv_chunks + kv_chunk_idx) * 2] = + make_float2(s_partial.m, s_partial.d); + grid.sync(); + + // sync global states + if (kv_chunk_idx == 0) { + state_t s_global; +#pragma unroll 4 + for (uint32_t iter = 0; iter < (num_kv_chunks + bdz - 1) / bdz; ++iter) { + uint32_t kv_chunk_idx = iter * bdz + tz; + if (kv_chunk_idx < num_kv_chunks) { + float2 md = *( + float2*)&tmp_md[(qo_head_idx * num_kv_chunks + kv_chunk_idx) * 2]; + s_partial.m = md.x; + s_partial.d = md.y; + s_partial.o.load( + tmp + (qo_head_idx * num_kv_chunks + kv_chunk_idx) * head_dim + + tx * vec_size); + s_global.merge(s_partial); + } + } + block.sync(); + // sync partial state of all warps inside a threadblock + sync_state( + s_global, reinterpret_cast(smem), smem_md); + s_global.normalize(); + s_global.o.cast_store( + o + info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size)); + } + } else { + s_partial.normalize(); + s_partial.o.cast_store( + o + info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size)); + } +} + +template +__forceinline__ __device__ void AdvancePageIterator( + paged_kv_t paged_kv, uint32_t* kv_idx_base, + uint32_t* valid_page_size, uint32_t& producer_valid_page_size, + uint32_t& producer_entry_base, uint32_t& producer_page_iter, + uint32_t& producer_page_idx, uint32_t cur_page_indptr_begin, + uint32_t cur_page_indptr_end, uint32_t batch_idx, uint32_t stage_idx) { + if (producer_entry_base >= producer_valid_page_size) { + producer_entry_base = 0; + producer_page_iter += 1; + if (producer_page_iter < cur_page_indptr_end) { + producer_page_idx = paged_kv.indices[producer_page_iter]; + producer_valid_page_size = + paged_kv.get_valid_page_size(batch_idx, producer_page_iter); + } else { + producer_valid_page_size = 0; + } + } + kv_idx_base[stage_idx] = + producer_entry_base + + (producer_page_iter - cur_page_indptr_begin) * paged_kv.page_size; + valid_page_size[stage_idx] = producer_valid_page_size; +} + +/*! + * \brief FlashAttention decoding cuda kernel with PagedKVCcache for batch + * requests, fused with RoPE. 
\tparam cooperative Whether to use cooperative + * kernel or not \tparam rotary_mode The rotary mode \tparam norm_on_the_fly + * Whether to normalize on the fly or not \tparam vec_size A template integer + * indicates the vector size \tparam bdx A template integer indicates the block + * size in x dimension \tparam bdy A template integer indicates the block size + * in y dimension \tparam bdz A template integer indicates the block size in z + * dimension \tparam DTypeIn A template type indicates the input data type + * \tparam DTypeOut A template type indicates the output data type + * \tparam IdType A template type indicates the index data type + * \param q [batch_size, num_qo_heads, head_dim] The query matrix + * \param paged_kv The PagedKVCache data structure + * \param o [num_qo_heads, head_dim] The output matrix + * \param sm_scale A float indicates the scale applied to pre-softmax logits + * \param rope_inv_scale A floating number indicate the multiplicative inverse + * of scaling ratio used in PI(Position Interpolation) for RoPE (Rotary + * Positional Embeddings) + * \param rope_inv_theta A floating number indicate the multiplicative inverse + * of "theta" used in RoPE (Rotary Positional Embeddings) + */ +template +__global__ void BatchDecodeWithPagedKVCacheKernel( + DTypeIn* __restrict__ q, paged_kv_t paged_kv, + DTypeOut* __restrict__ o, float* __restrict__ tmp, float sm_scale, + float rope_inv_scale, float rope_inv_theta) { + auto block = cg::this_thread_block(); + sm_scale *= math::log2e; + + constexpr uint32_t head_dim = bdx * vec_size; + const uint32_t batch_idx = blockIdx.x; + const uint32_t kv_head_idx = blockIdx.y; + const uint32_t qo_head_idx = kv_head_idx * bdy + threadIdx.y; + const uint32_t num_qo_heads = gridDim.y * bdy; + const uint32_t cur_chunk_start = + cooperative ? paged_kv.chunk_start[batch_idx] : 0U; + const uint32_t cur_page_indptr_begin = paged_kv.indptr[batch_idx], + cur_page_indptr_end = paged_kv.indptr[batch_idx + 1]; + const uint32_t cur_last_page_offset = paged_kv.last_page_offset[batch_idx]; + const uint32_t seq_len = + cooperative ? 
paged_kv.seq_lens_before_split[batch_idx] + : (cur_page_indptr_end - cur_page_indptr_begin - 1) * + paged_kv.page_size + + cur_last_page_offset; + + extern __shared__ uint8_t smem[]; + DTypeIn* k_smem = (DTypeIn*)smem; + DTypeIn* v_smem = (DTypeIn*)(smem + num_stages_smem * bdy * bdz * head_dim * + sizeof(DTypeIn)); + float* smem_md = (float*)(smem + 2 * num_stages_smem * bdy * bdz * head_dim * + sizeof(DTypeIn)); + + const uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; + vec_t q_vec; + vec_t freq; + if constexpr (rotary_mode == RotaryMode::kLlama) { +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + freq[i] = rope_inv_scale * + __powf(rope_inv_theta, + float(2 * ((tx * vec_size + i) % (head_dim / 2))) / + float(head_dim)); + } + // apply rotary embedding to q matrix + if constexpr (cooperative) { + q_vec = apply_llama_rope( + q + (paged_kv.batch_idx_map[batch_idx] * num_qo_heads + qo_head_idx) * + head_dim, + freq, seq_len - 1); + } else { + q_vec = apply_llama_rope( + q + (batch_idx * num_qo_heads + qo_head_idx) * head_dim, freq, + seq_len - 1); + } + } else { + // do not apply rotary embedding to q matrix + if constexpr (cooperative) { + q_vec.cast_load( + q + + (paged_kv.batch_idx_map[batch_idx] * num_qo_heads + qo_head_idx) * + head_dim + + tx * vec_size); + } else { + q_vec.cast_load(q + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + + tx * vec_size); + } + } + block.sync(); + + // preload k/v tiles + uint32_t producer_entry_base = 0, stage_idx = 0; + constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8; + uint32_t producer_page_iter = cur_page_indptr_begin; + uint32_t producer_page_idx = paged_kv.indices[producer_page_iter]; + uint32_t producer_valid_page_size = + paged_kv.get_valid_page_size(batch_idx, producer_page_iter); + uint32_t kv_idx_base[num_stages_smem]{0}; + uint32_t valid_page_size[num_stages_smem]{0}; +#pragma unroll + for (uint32_t iter = 0; iter < num_stages_smem; ++iter) { + AdvancePageIterator( + paged_kv, kv_idx_base, valid_page_size, producer_valid_page_size, + producer_entry_base, producer_page_iter, producer_page_idx, + cur_page_indptr_begin, cur_page_indptr_end, batch_idx, stage_idx); + bool producer_pred_guard = + (producer_entry_base + tz * bdy + ty < producer_valid_page_size) && + (producer_page_iter < cur_page_indptr_end); + cp_async::pred_load( + k_smem + ((stage_idx * bdz + tz) * bdy + ty) * head_dim + tx * vec_size, + paged_kv.data + paged_kv.get_k_elem_offset( + producer_page_idx, kv_head_idx, + producer_entry_base + tz * bdy + ty, tx * vec_size), + producer_pred_guard); + cp_async::commit_group(); + cp_async::pred_load( + v_smem + ((stage_idx * bdz + tz) * bdy + ty) * head_dim + tx * vec_size, + paged_kv.data + paged_kv.get_v_elem_offset( + producer_page_idx, kv_head_idx, + producer_entry_base + tz * bdy + ty, tx * vec_size), + producer_pred_guard); + cp_async::commit_group(); + stage_idx = (stage_idx + 1) % num_stages_smem; + producer_entry_base += bdy * bdz; + } + + state_t s; + float x[bdy]; + uint32_t consumer_kv_idx_base = 0; + + for (uint32_t consumer_page_iter = cur_page_indptr_begin; + consumer_page_iter < cur_page_indptr_end; ++consumer_page_iter) { + uint32_t consumer_valid_page_size = valid_page_size[stage_idx]; +#pragma unroll + for (uint32_t iter = 0; + iter < (consumer_valid_page_size + (bdy * bdz) - 1) / (bdy * bdz); + ++iter) { + consumer_kv_idx_base = kv_idx_base[stage_idx]; + AdvancePageIterator( + paged_kv, kv_idx_base, valid_page_size, producer_valid_page_size, + producer_entry_base, 
producer_page_iter, producer_page_idx, + cur_page_indptr_begin, cur_page_indptr_end, batch_idx, stage_idx); + bool producer_pred_guard = + (producer_entry_base + tz * bdy + ty < producer_valid_page_size) && + (producer_page_iter < cur_page_indptr_end); + // compute qk + cp_async::wait_group<2 * num_stages_smem - 1>(); + block.sync(); + compute_qk( + k_smem + (stage_idx * bdz + tz) * bdy * head_dim, q_vec, freq, + cur_chunk_start + consumer_kv_idx_base, stage_idx, sm_scale, x); + block.sync(); + + // load k tiles + cp_async::pred_load( + k_smem + ((stage_idx * bdz + tz) * bdy + ty) * head_dim + + tx * vec_size, + paged_kv.data + + paged_kv.get_k_elem_offset(producer_page_idx, kv_head_idx, + producer_entry_base + tz * bdy + ty, + tx * vec_size), + producer_pred_guard); + cp_async::commit_group(); + + // update m/d/o states + cp_async::wait_group<2 * num_stages_smem - 1>(); + block.sync(); + update_partial_state( + v_smem + (stage_idx * bdz + tz) * bdy * head_dim, x, stage_idx, + iter * bdy * bdz, consumer_valid_page_size, s); + block.sync(); + + // load v tiles + cp_async::pred_load( + v_smem + ((stage_idx * bdz + tz) * bdy + ty) * head_dim + + tx * vec_size, + paged_kv.data + + paged_kv.get_v_elem_offset(producer_page_idx, kv_head_idx, + producer_entry_base + tz * bdy + ty, + tx * vec_size), + producer_pred_guard); + cp_async::commit_group(); + + stage_idx = (stage_idx + 1) % num_stages_smem; + producer_entry_base += bdy * bdz; + } + } + cp_async::wait_group<0>(); + block.sync(); + + // sync partial state of all warps inside a threadblock + sync_state(s, reinterpret_cast(smem), + smem_md); + + if constexpr (cooperative) { + auto grid = cg::this_grid(); + // update tmp buffer + s.o.store(tmp + (qo_head_idx * paged_kv.batch_size + batch_idx) * head_dim + + tx * vec_size); + float* tmp_md = tmp + num_qo_heads * paged_kv.batch_size * head_dim; + *(float2*)&tmp_md[(qo_head_idx * paged_kv.batch_size + batch_idx) * 2] = + make_float2(s.m, s.d); + grid.sync(); + + // sync global states + const uint32_t cooperative_indptr_begin = + paged_kv.cooperative_indptr[batch_idx], + cooperative_indptr_end = + paged_kv.cooperative_indptr[batch_idx + 1]; + if (cooperative_indptr_begin < cooperative_indptr_end) { + state_t s_global; + const uint32_t num_pages = + cooperative_indptr_end - cooperative_indptr_begin; +#pragma unroll 4 + for (uint32_t iter = 0; iter < (num_pages + bdz - 1) / bdz; ++iter) { + uint32_t kv_chunk_idx = cooperative_indptr_begin + iter * bdz + tz; + if (kv_chunk_idx < cooperative_indptr_end) { + float2 md = *(float2*)&tmp_md[(qo_head_idx * paged_kv.batch_size + + kv_chunk_idx) * + 2]; + s.m = md.x; + s.d = md.y; + s.o.load(tmp + + (qo_head_idx * paged_kv.batch_size + kv_chunk_idx) * + head_dim + + tx * vec_size); + s_global.merge(s); + } + } + block.sync(); + // sync partial state of all warps inside a threadblock + sync_state( + s_global, reinterpret_cast(smem), smem_md); + s_global.normalize(); + s_global.o.cast_store( + o + + (paged_kv.batch_idx_map[batch_idx] * num_qo_heads + qo_head_idx) * + head_dim + + tx * vec_size); + } + } else { + s.normalize(); + s.o.cast_store(o + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + + tx * vec_size); + } +} + +constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, + uint32_t sizeof_dtype) { + if (group_size == 8U) { + if (sizeof_dtype == 1U) { + return 256U; // not enough registers for 512 threads + } else { + return 512U; + } + } else { + return 128U; + } +} + +template +cudaError_t SingleDecodeWithKVCacheWorkEstimation( + 
uint32_t& tmp_size, uint32_t& max_grid_size, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t seq_len, uint32_t head_dim, + QKVLayout layout = QKVLayout::kNHD, + RotaryMode rotary_mode = RotaryMode::kNone, cudaStream_t stream = nullptr, + uint32_t dev_id = 0) { + if (seq_len <= 128U) { + tmp_size = 0; + } else { + SWITCH_GQA_GROUP_SIZE( + num_qo_heads / num_kv_heads, GROUP_SIZE, + {SWITCH_HEAD_DIM( + head_dim, HEAD_DIM, + {SWITCH_ROTARY_MODE( + rotary_mode, ROTARY_MODE, {SWITCH_LAYOUT(layout, QKV_LAYOUT, { + constexpr uint32_t vec_size = + std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL); + constexpr uint32_t num_stages_smem = 2U; + constexpr uint32_t bdx = HEAD_DIM / vec_size; + static_assert(bdx <= 32U); + constexpr uint32_t bdy = GROUP_SIZE; + constexpr uint32_t num_threads = + get_heuristic_num_threads(GROUP_SIZE, sizeof(DTypeIn)); + constexpr uint32_t bdz = num_threads / (bdx * bdy); + constexpr bool norm_on_the_fly = false; + const uint32_t smem_size = 2U * num_stages_smem * bdy * bdz * + head_dim * sizeof(DTypeIn) + + 2U * bdy * bdz * sizeof(float); + + auto kernel = SingleDecodeWithKVCacheKernel< + QKV_LAYOUT, true, norm_on_the_fly, ROTARY_MODE, + num_stages_smem, vec_size, bdx, bdy, bdz, DTypeIn, + DTypeOut>; + int num_blocks_per_sm = 0; + int num_sm = 0; + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute( + &num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + FLASHINFER_CUDA_CALL( + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, kernel, num_threads, smem_size)); + max_grid_size = + uint32_t(num_blocks_per_sm) * uint32_t(num_sm); + uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads; + uint32_t kv_chunk_size = max( + (seq_len + max_num_kv_chunks - 1U) / max_num_kv_chunks, + min(num_threads, max(num_threads / 8, + seq_len / max(1U, (num_threads / + num_kv_heads))))); + uint32_t num_kv_chunks = + (seq_len + kv_chunk_size - 1) / kv_chunk_size; + tmp_size = num_qo_heads * num_kv_chunks * (head_dim + 2); + })})})}); + } + return cudaSuccess; +} + +/*! + * \brief FlashAttention decoding with kv-cache for a single sequence + * \tparam DTypeIn A template type indicates the input data type + * \tparam DTypeOut A template type indicates the output data type + * \param q The query matrix, shape: [num_qo_heads, head_dim] + * \param k The key matrix in kv-cache, shape: [seq_len, num_kv_heads, head_dim] + * for NHD layout, [num_kv_heads, head_dim, seq_len] for HND layout + * \param v The value matrix in kv-cache, shape: [seq_len, num_kv_heads, + * head_dim] for NHD layout, [num_kv_heads, head_dim, seq_len] for HND layout + * \param o The output matrix, shape: [num_qo_heads, head_dim] + * \param tmp Used-allocated temporary buffer + * \param num_qo_heads A integer indicates the number of heads of query and + * output \param num_kv_heads A integer indicates the number of heads of key and + * value \param seq_len A integer indicates the sequence length \param head_dim + * A integer indicates the head dimension \param layout The layout of q/k/v + * matrices. \param rotary_mode The rotary mode \param rope_scale A floating + * point number indicate the scaling ratio used in RoPE Interpolation. 
\param + * rope_theta A floating point number indicate the "theta" used in RoPE \param + * stream The cuda stream to launch the kernel + */ +template +cudaError_t SingleDecodeWithKVCache( + DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, float* tmp, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t seq_len, + uint32_t head_dim, QKVLayout layout = QKVLayout::kNHD, + RotaryMode rotary_mode = RotaryMode::kNone, float rope_scale = 1.f, + float rope_theta = 1e4, cudaStream_t stream = nullptr, + uint32_t dev_id = 0) { + const float sm_scale = 1.f / std::sqrt(float(head_dim)); + const float rope_inv_scale = 1.f / rope_scale; + const float rope_inv_theta = 1.f / rope_theta; + constexpr bool norm_on_the_fly = false; + assert(num_qo_heads % num_kv_heads == 0); + + FLASHINFER_CUDA_CALL(cudaSetDevice(dev_id)); + + SWITCH_GQA_GROUP_SIZE( + num_qo_heads / num_kv_heads, GROUP_SIZE, + {SWITCH_HEAD_DIM( + head_dim, HEAD_DIM, + {SWITCH_ROTARY_MODE( + rotary_mode, ROTARY_MODE, {SWITCH_LAYOUT(layout, QKV_LAYOUT, { + constexpr uint32_t vec_size = + std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL); + constexpr uint32_t num_stages_smem = 2U; + constexpr uint32_t bdx = HEAD_DIM / vec_size; + static_assert(bdx <= 32U); + constexpr uint32_t bdy = GROUP_SIZE; + constexpr uint32_t num_threads = + get_heuristic_num_threads(GROUP_SIZE, sizeof(DTypeIn)); + constexpr uint32_t bdz = num_threads / (bdx * bdy); + tensor_info_t info( + 1, seq_len, num_kv_heads, head_dim); + const uint32_t smem_size = 2U * num_stages_smem * bdy * bdz * + head_dim * sizeof(DTypeIn) + + 2U * bdy * bdz * sizeof(float); + if (seq_len <= 128U || tmp == nullptr) { + // no need to use cooperative kernel + auto kernel = SingleDecodeWithKVCacheKernel< + QKV_LAYOUT, false, norm_on_the_fly, ROTARY_MODE, + num_stages_smem, vec_size, bdx, bdy, bdz, DTypeIn, + DTypeOut>; + FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + + dim3 nblks = dim3(1, num_kv_heads); + dim3 nthrs = dim3(bdx, bdy, bdz); + void* args[] = {(void*)&q, + (void*)&k, + (void*)&v, + (void*)&o, + (void*)&tmp, + (void*)&info, + (void*)&sm_scale, + (void*)&rope_inv_scale, + (void*)&rope_inv_theta, + (void*)&seq_len}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel( + (void*)kernel, nblks, nthrs, args, smem_size, stream)); + } else { + // use cooperative kernel + auto kernel = SingleDecodeWithKVCacheKernel< + QKV_LAYOUT, true, norm_on_the_fly, ROTARY_MODE, + num_stages_smem, vec_size, bdx, bdy, bdz, DTypeIn, + DTypeOut>; + FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + + int num_blocks_per_sm = 0; + int num_sm = 0; + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute( + &num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + FLASHINFER_CUDA_CALL( + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, kernel, num_threads, smem_size)); + uint32_t max_grid_size = + uint32_t(num_blocks_per_sm) * uint32_t(num_sm); + uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads; + uint32_t kv_chunk_size = max( + (seq_len + max_num_kv_chunks - 1U) / max_num_kv_chunks, + min(num_threads, max(num_threads / 8, + seq_len / max(1U, (num_threads / + num_kv_heads))))); + dim3 nblks = + dim3((seq_len + kv_chunk_size - 1) / kv_chunk_size, + num_kv_heads); + assert(nblks.x > 0 && nblks.y > 0); + dim3 nthrs = dim3(bdx, bdy, bdz); + void* args[] = {(void*)&q, + (void*)&k, + (void*)&v, + (void*)&o, + (void*)&tmp, + (void*)&info, + (void*)&sm_scale, + (void*)&rope_inv_scale, + 
(void*)&rope_inv_theta, + (void*)&kv_chunk_size}; + FLASHINFER_CUDA_CALL(cudaLaunchCooperativeKernel( + (void*)kernel, nblks, nthrs, args, smem_size, stream)); + } + })})})}); + return cudaSuccess; +} + +template +cudaError_t SplitPagedKVCache(uint32_t old_batch_size, + const IdType* old_page_indptr_h, + const IdType* old_last_page_offset_h, + uint32_t max_num_pages_per_batch, + paged_kv_t* new_paged_kv_d, + cudaStream_t stream = nullptr) { + std::vector new_page_indptr_h{0}, new_last_page_offset_h, + cooperative_indptr_h{0}, batch_idx_map_h, chunk_start_h, + seq_lens_before_split_h; + + for (uint32_t batch_idx = 0; batch_idx < old_batch_size; batch_idx++) { + uint32_t cooperative_indptr_delta = + (old_page_indptr_h[batch_idx + 1] - old_page_indptr_h[batch_idx] + + max_num_pages_per_batch - 1) / + max_num_pages_per_batch; + uint32_t seq_len_before_split = + (old_page_indptr_h[batch_idx + 1] - old_page_indptr_h[batch_idx] - 1) * + new_paged_kv_d->page_size + + old_last_page_offset_h[batch_idx]; + for (uint32_t j = 0; j < cooperative_indptr_delta; ++j) { + bool is_last = (j + 1) == cooperative_indptr_delta; + new_page_indptr_h.push_back( + min(old_page_indptr_h[batch_idx] + (j + 1) * max_num_pages_per_batch, + old_page_indptr_h[batch_idx + 1])); + new_last_page_offset_h.push_back(is_last + ? old_last_page_offset_h[batch_idx] + : new_paged_kv_d->page_size); + batch_idx_map_h.push_back(batch_idx); + if (j == 0) { + cooperative_indptr_h.push_back(cooperative_indptr_h.back() + + cooperative_indptr_delta); + } else { + cooperative_indptr_h.push_back(cooperative_indptr_h.back()); + } + chunk_start_h.push_back(j * max_num_pages_per_batch * + new_paged_kv_d->page_size); + seq_lens_before_split_h.push_back(seq_len_before_split); + } + } + + FLASHINFER_CUDA_CALL( + cudaMemcpyAsync(new_paged_kv_d->indptr, new_page_indptr_h.data(), + sizeof(IdType) * new_page_indptr_h.size(), + cudaMemcpyHostToDevice, stream)); + FLASHINFER_CUDA_CALL(cudaMemcpyAsync( + new_paged_kv_d->last_page_offset, new_last_page_offset_h.data(), + sizeof(IdType) * new_last_page_offset_h.size(), cudaMemcpyHostToDevice, + stream)); + FLASHINFER_CUDA_CALL(cudaMemcpyAsync( + new_paged_kv_d->cooperative_indptr, cooperative_indptr_h.data(), + sizeof(IdType) * cooperative_indptr_h.size(), cudaMemcpyHostToDevice, + stream)); + FLASHINFER_CUDA_CALL(cudaMemcpyAsync( + new_paged_kv_d->batch_idx_map, batch_idx_map_h.data(), + sizeof(IdType) * batch_idx_map_h.size(), cudaMemcpyHostToDevice, stream)); + FLASHINFER_CUDA_CALL(cudaMemcpyAsync( + new_paged_kv_d->chunk_start, chunk_start_h.data(), + sizeof(IdType) * chunk_start_h.size(), cudaMemcpyHostToDevice, stream)); + FLASHINFER_CUDA_CALL(cudaMemcpyAsync( + new_paged_kv_d->seq_lens_before_split, seq_lens_before_split_h.data(), + sizeof(IdType) * seq_lens_before_split_h.size(), cudaMemcpyHostToDevice, + stream)); + return cudaSuccess; +} + +/*! + * \brief Compute the maximum number of pages per batch and the new batch size + * after we split Paged KV-Cache into multiple chunks on KV sequence length + * dimension. \param max_grid_size The maximum grid size of the kernel \param + * num_kv_heads The number of KV heads \param num_pages The number of pages per + * request in the batch \param max_num_pages_per_batch_lb The pre-set lower + * bound of maximum number of pages per batch, default to 1 \return + * (max_num_pages_per_batch, new_batch_size) The number of pages per batch and + * the new batch size after the split. 
+ */ +template +std::pair SplitPagedKVCacheBinarySearchMinNumPagePerBatch( + const uint32_t max_grid_size, const uint32_t num_kv_heads, + const std::vector& num_pages, + const uint32_t min_num_pages_per_batch = 1) { + uint32_t low = min_num_pages_per_batch, high = 0; + for (const IdType& elem : num_pages) { + high = max(high, elem); + } + uint32_t new_batch_size; + while (low < high) { + uint32_t mid = (low + high) / 2; + new_batch_size = 0; + for (const IdType& elem : num_pages) { + new_batch_size += (elem + mid - 1) / mid; + } + if (new_batch_size * num_kv_heads > max_grid_size) { + low = mid + 1; + } else { + high = mid; + } + } + new_batch_size = 0; + for (const IdType& elem : num_pages) { + new_batch_size += (elem + low - 1) / low; + } + return {low, new_batch_size}; +} + +template +cudaError_t BatchDecodeWithPagedKVCacheWorkEstimation( + uint32_t& tmp_size, uint32_t& max_grid_size, + uint32_t& max_num_pages_per_batch, uint32_t& new_batch_size, + const paged_kv_t& paged_kv, uint32_t num_qo_heads, + RotaryMode rotary_mode = RotaryMode::kNone, cudaStream_t stream = nullptr, + uint32_t dev_id = 0) { + constexpr bool norm_on_the_fly = false; + const uint32_t head_dim = paged_kv.head_dim; + const uint32_t batch_size = paged_kv.batch_size; + const uint32_t num_kv_heads = paged_kv.num_heads; + SWITCH_GQA_GROUP_SIZE( + num_qo_heads / num_kv_heads, GROUP_SIZE, + {SWITCH_HEAD_DIM( + head_dim, HEAD_DIM, {SWITCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, { + constexpr uint32_t vec_size = + std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL); + constexpr uint32_t num_stages_smem = 2; + constexpr uint32_t bdx = HEAD_DIM / vec_size; + static_assert(bdx <= 32); + constexpr uint32_t bdy = GROUP_SIZE; + constexpr uint32_t num_threads = + get_heuristic_num_threads(GROUP_SIZE, sizeof(DTypeIn)); + constexpr uint32_t bdz = num_threads / (bdx * bdy); + const uint32_t smem_size = + 2 * num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeIn) + + 2 * bdy * bdz * sizeof(float); + + auto cooperative_kernel = BatchDecodeWithPagedKVCacheKernel< + true, ROTARY_MODE, norm_on_the_fly, num_stages_smem, vec_size, + bdx, bdy, bdz, DTypeIn, DTypeOut, IdType>; + int num_blocks_per_sm = 0; + int num_sm = 0; + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute( + &num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, cooperative_kernel, num_threads, + smem_size)); + max_grid_size = num_blocks_per_sm * num_sm; + const uint32_t num_kv_heads = paged_kv.num_heads; + if (batch_size * num_kv_heads >= max_grid_size) { + // do not use cooperative kernel + tmp_size = 0; + } else { + // compute max_num_pages_per_batch and new_batch_size + std::vector page_indptr_h(batch_size + 1), + num_pages(batch_size); + FLASHINFER_CUDA_CALL( + cudaMemcpyAsync(page_indptr_h.data(), paged_kv.indptr, + sizeof(IdType) * (batch_size + 1), + cudaMemcpyDeviceToHost, stream)); + FLASHINFER_CUDA_CALL(cudaStreamSynchronize(stream)); + for (uint32_t batch_idx = 0; batch_idx < batch_size; + ++batch_idx) { + num_pages[batch_idx] = + page_indptr_h[batch_idx + 1] - page_indptr_h[batch_idx]; + } + std::tie(max_num_pages_per_batch, new_batch_size) = + SplitPagedKVCacheBinarySearchMinNumPagePerBatch( + max_grid_size, num_kv_heads, num_pages, + 128 / paged_kv.page_size); + tmp_size = num_qo_heads * new_batch_size * (head_dim + 2); + } + })})}); + return cudaSuccess; +} + +/*! 
+ * \brief FlashAttention decoding cuda kernel with paged kv-cache for batched + * requests \tparam DTypeIn A template type indicates the input data type + * \tparam DTypeOut A template type indicates the output data type + * \tparam IdType A template type indicates the index data type used in paged + * kv-cache \param q [batch_size, num_qo_heads, head_dim] The query matrix + * \param paged_kv The paged kv cache data structure + * \param o [batch_size, num_qo_heads, head_dim] The output matrix + * \param tmp Used-allocated temporary buffer + * \param num_qo_heads A integer indicates the number of heads of query and + * output \param rotary_mode The rotary mode \param rope_scale A floating point + * number indicate the scaling ratio used in RoPE Interpolation. \param + * rope_theta A floating point number indicate the "theta" used in RoPE \param + * stream The cuda stream to launch the kernel \param dev_id The device id + */ +template +cudaError_t BatchDecodeWithPagedKVCache( + DTypeIn* q, paged_kv_t paged_kv, DTypeOut* o, float* tmp, + uint32_t num_qo_heads, RotaryMode rotary_mode = RotaryMode::kNone, + float rope_scale = 1.f, float rope_theta = 1e4, + cudaStream_t stream = nullptr, uint32_t dev_id = 0) { + const float sm_scale = 1.f / std::sqrt(float(paged_kv.head_dim)); + const float rope_inv_scale = 1.f / rope_scale; + const float rope_inv_theta = 1.f / rope_theta; + constexpr bool norm_on_the_fly = false; + const uint32_t num_kv_heads = paged_kv.num_heads; + const uint32_t head_dim = paged_kv.head_dim; + const uint32_t batch_size = paged_kv.batch_size; + assert(num_qo_heads % num_kv_heads == 0); + + FLASHINFER_CUDA_CALL(cudaSetDevice(dev_id)); + + SWITCH_GQA_GROUP_SIZE( + num_qo_heads / num_kv_heads, GROUP_SIZE, + {SWITCH_HEAD_DIM( + head_dim, HEAD_DIM, {SWITCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, { + constexpr uint32_t vec_size = + std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL); + constexpr uint32_t num_stages_smem = 2; + constexpr uint32_t bdx = HEAD_DIM / vec_size; + static_assert(bdx <= 32); + constexpr uint32_t bdy = GROUP_SIZE; + constexpr uint32_t num_threads = + get_heuristic_num_threads(GROUP_SIZE, sizeof(DTypeIn)); + constexpr uint32_t bdz = num_threads / (bdx * bdy); + const uint32_t smem_size = + 2 * num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeIn) + + 2 * bdy * bdz * sizeof(float); + + auto cooperative_kernel = BatchDecodeWithPagedKVCacheKernel< + true, ROTARY_MODE, norm_on_the_fly, num_stages_smem, vec_size, + bdx, bdy, bdz, DTypeIn, DTypeOut, IdType>; + int num_blocks_per_sm = 0; + int num_sm = 0; + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute( + &num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, cooperative_kernel, num_threads, + smem_size)); + uint32_t max_grid_size = num_blocks_per_sm * num_sm; + + if (batch_size * num_kv_heads >= max_grid_size || tmp == nullptr) { + // do not use cooperative kernel + dim3 nblks(batch_size, num_kv_heads); + dim3 nthrs(bdx, bdy, bdz); + auto kernel = BatchDecodeWithPagedKVCacheKernel< + false, ROTARY_MODE, norm_on_the_fly, num_stages_smem, + vec_size, bdx, bdy, bdz, DTypeIn, DTypeOut, IdType>; + FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + void* args[] = {(void*)&q, + (void*)&paged_kv, + (void*)&o, + (void*)&tmp, + (void*)&sm_scale, + (void*)&rope_inv_scale, + (void*)&rope_inv_theta}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, + args, 
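+                  // Added note (not in the upstream source): this
+                  // non-cooperative path launches one thread block per
+                  // (request, kv-head) pair -- nblks = (batch_size,
+                  // num_kv_heads) -- with bdx * bdy * bdz threads per block.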
smem_size, stream)); + } else { + // use cooperative kernel + assert(paged_kv.cooperative_indptr != nullptr); + assert(paged_kv.batch_idx_map != nullptr); + assert(paged_kv.chunk_start != nullptr); + assert(paged_kv.seq_lens_before_split != nullptr); + FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( + cooperative_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + void* args[] = {(void*)&q, + (void*)&paged_kv, + (void*)&o, + (void*)&tmp, + (void*)&sm_scale, + (void*)&rope_inv_scale, + (void*)&rope_inv_theta}; + dim3 nblks(batch_size, num_kv_heads); + dim3 nthrs(bdx, bdy, bdz); + FLASHINFER_CUDA_CALL( + cudaLaunchCooperativeKernel((void*)cooperative_kernel, nblks, + nthrs, args, smem_size, stream)); + } + })})}); + + return cudaSuccess; +} + +} // namespace flashinfer + +#endif // FLASHINFER_DECODE_CUH_ diff --git a/server/punica_kernels/punica_kernels/flashinfer/layout.cuh b/server/punica_kernels/punica_kernels/flashinfer/layout.cuh new file mode 100644 index 000000000..6af37dd19 --- /dev/null +++ b/server/punica_kernels/punica_kernels/flashinfer/layout.cuh @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_LAYOUT_CUH_ +#define FLASHINFER_LAYOUT_CUH_ + +#include + +namespace flashinfer { + +/*! 
+ * \brief The Layout of QKV matrices + */ +enum class QKVLayout { + // [seq_len, num_heads, head_dim] + kNHD = 0U, + // [num_heads, seq_len, head_dim] + kHND = 1U, +}; + +template +__host__ __device__ __forceinline__ size_t +get_elem_offset_impl(size_t elem_idx, size_t head_idx, size_t feat_idx, + size_t seq_len, size_t num_heads, size_t head_dim) { + if constexpr (layout == QKVLayout::kHND) { + return (head_idx * seq_len + elem_idx) * head_dim + feat_idx; + } else { + return (elem_idx * num_heads + head_idx) * head_dim + feat_idx; + } +} + +template +struct tensor_info_t { + uint32_t qo_len; + uint32_t kv_len; + uint32_t num_kv_heads; + uint32_t head_dim; + __host__ __device__ __forceinline__ tensor_info_t(uint32_t qo_len, + uint32_t kv_len, + uint32_t num_kv_heads, + uint32_t head_dim) + : qo_len(qo_len), + kv_len(kv_len), + num_kv_heads(num_kv_heads), + head_dim(head_dim) {} + + __host__ __device__ __forceinline__ uint32_t get_num_kv_heads() const { + return num_kv_heads; + } + + __host__ __device__ __forceinline__ uint32_t get_num_qo_heads() const { + return num_kv_heads * group_size; + } + + __host__ __device__ __forceinline__ size_t get_qo_elem_offset( + uint32_t qo_idx, uint32_t qo_head_idx, uint32_t feat_idx) const { + return get_elem_offset_impl(qo_idx, qo_head_idx, feat_idx, qo_len, + get_num_qo_heads(), head_dim); + } + + __host__ __device__ __forceinline__ size_t get_kv_elem_offset( + uint32_t kv_idx, uint32_t kv_head_idx, uint32_t feat_idx) const { + return get_elem_offset_impl(kv_idx, kv_head_idx, feat_idx, kv_len, + num_kv_heads, head_dim); + } + + __host__ __device__ __forceinline__ size_t get_n_stride() const { + return get_kv_n_stride(); + } + + __host__ __device__ __forceinline__ size_t get_qo_n_stride() const { + return layout == QKVLayout::kHND ? head_dim : get_num_qo_heads() * head_dim; + } + + __host__ __device__ __forceinline__ size_t get_kv_n_stride() const { + return layout == QKVLayout::kHND ? head_dim : num_kv_heads * head_dim; + } +}; + +/*! + * \brief Convert QKVLayout to string + * \param qkv_layout The QKVLayout to convert + */ +inline std::string QKVLayoutToString(const QKVLayout& qkv_layout) { + switch (qkv_layout) { + case QKVLayout::kNHD: + return "NHD"; + case QKVLayout::kHND: + return "HND"; + default: + return "Unknown"; + } +} + +} // namespace flashinfer +#endif // FLASHINFER_LAYOUT_CUH_ \ No newline at end of file diff --git a/server/punica_kernels/punica_kernels/flashinfer/math.cuh b/server/punica_kernels/punica_kernels/flashinfer/math.cuh new file mode 100644 index 000000000..83338412e --- /dev/null +++ b/server/punica_kernels/punica_kernels/flashinfer/math.cuh @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef FLASHINFER_MATH_CUH_ +#define FLASHINFER_MATH_CUH_ + +#include + +namespace flashinfer { +namespace math { + +constexpr float log2e = 1.44269504088896340736f; + +__forceinline__ __device__ float ptx_exp2(float x) { + float y; + asm volatile("ex2.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x)); + return y; +} + +__forceinline__ __device__ float ptx_lg2(float x) { + float y; + asm volatile("lg2.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x)); + return y; +} + +__forceinline__ __device__ float shfl_xor_sync(float x, int delta) { + float y; + asm volatile("shfl.sync.bfly.b32 %0, %1, %2, 0x1f, 0xffffffff;" + : "=f"(y) + : "f"(x), "r"(delta)); + return y; +} + +} // namespace math +} // namespace flashinfer +#endif // FLASHINFER_MATH_CUH_ diff --git a/server/punica_kernels/punica_kernels/flashinfer/mma.cuh b/server/punica_kernels/punica_kernels/flashinfer/mma.cuh new file mode 100644 index 000000000..add7e4e12 --- /dev/null +++ b/server/punica_kernels/punica_kernels/flashinfer/mma.cuh @@ -0,0 +1,153 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_MMA_CUH_ +#define FLASHINFER_MMA_CUH_ + +#include +#include +#include + +#include + +#include "vec_dtypes.cuh" + +namespace flashinfer { + +namespace mma { + +// template +__device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t* R, uint4* smem_ptr) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) + : "r"(smem_int_ptr)); +} + +template +__device__ __forceinline__ void ldmatrix_m8n8x4(vec_t* v, + uint4* smem_ptr) { + static_assert(sizeof(T) == 2, "T must be half/bfloat16"); + ldmatrix_m8n8x4((uint32_t*)v->ptr(), smem_ptr); +} + +// template +__device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t* R, + uint4* smem_ptr) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.trans.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) + : "r"(smem_int_ptr)); +} + +template +__device__ __forceinline__ void ldmatrix_m8n8x4_trans(vec_t* v, + uint4* smem_ptr) { + static_assert(sizeof(T) == 2, "T must be half/bfloat16"); + ldmatrix_m8n8x4_trans((uint32_t*)v->ptr(), smem_ptr); +} + +// template +__device__ __forceinline__ void stmatrix_m8n8x4(uint32_t* R, uint4* smem_ptr) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \ + (__CUDACC_VER_MAJOR__ >= 11) + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "stmatrix.sync.aligned.m8n8.x4.shared.b16 [%0], {%1, %2, %3, %4};\n" + : "r"(smem_int_ptr), "r"(R[0]), "r"(R[1]), "r"(R[2]), "r"(R[3])); +#else + const uint32_t tx = threadIdx.x; + uint4 word; +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 4; ++reg_id) { + word.x = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4); + word.y = 
__shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4 + 1); + word.z = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4 + 2); + word.w = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4 + 3); + if (tx / 8 == reg_id) { + *(uint4*)smem_ptr = word; + } + } +#endif +} + +template +__device__ __forceinline__ void stmatrix_m8n8x4(vec_t* v, + uint4* smem_ptr) { + static_assert(sizeof(T) == 2, "T must be half/bfloat16"); + stmatrix_m8n8x4((uint32_t*)v->ptr(), smem_ptr); +} + +template +__device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32( + float* C, uint32_t* A, uint32_t* B) { + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), + "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); + } else { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), + "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); + } +} + +template +__device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32( + vec_t* C, vec_t* A, vec_t* B) { + mma_sync_m16n16k16_row_col_f16f16f32(C->ptr(), (uint32_t*)A->ptr(), + (uint32_t*)B->ptr()); +} + +} // namespace mma + +} // namespace flashinfer + +#endif // FLASHINFER_MMA_CUH_ diff --git a/server/punica_kernels/punica_kernels/flashinfer/page.cuh b/server/punica_kernels/punica_kernels/flashinfer/page.cuh new file mode 100644 index 000000000..b0239fefb --- /dev/null +++ b/server/punica_kernels/punica_kernels/flashinfer/page.cuh @@ -0,0 +1,366 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_PAGE_CUH_ +#define FLASHINFER_PAGE_CUH_ + +#include "layout.cuh" +#include "utils.cuh" +#include "vec_dtypes.cuh" + +namespace flashinfer { + +/*! 
+ * \brief Paged key-value cache
+ * \tparam DType The data type of the key-value cache
+ * \tparam IdType The index data type of the kv-cache
+ * \note layout: [max_num_pages, num_layers, 2, num_heads, page_size, head_dim]
+ */
+template <typename DType, typename IdType>
+struct paged_kv_t {
+  uint32_t num_layers;
+  uint32_t layer_idx;
+  uint32_t num_heads;
+  uint32_t page_size;
+  uint32_t head_dim;
+  uint32_t batch_size;
+
+  // [max_num_pages * num_layers * 2 * num_heads * page_size * head_dim]
+  // The flattened key-value cache
+  DType* data;
+  // [batch_size + 1] The page indptr array, with the first element 0
+  IdType* indptr;
+  // [nnz_pages] The page indices array
+  IdType* indices;
+  // [batch_size] The offset of the last page for each request in the batch
+  IdType* last_page_offset;
+
+  /* ------------ Auxiliary Information Used in Cooperative Kernels ------------
+   */
+  IdType* cooperative_indptr;
+  IdType* batch_idx_map;
+  IdType* chunk_start;
+  IdType* seq_lens_before_split;
+
+  /*!
+   * \brief Construct a paged key-value cache
+   * \param num_layers The number of layers
+   * \param layer_idx The index of the layer
+   * \param num_heads The number of heads
+   * \param page_size The size of each page
+   * \param head_dim The dimension of each head
+   * \param batch_size The batch size
+   * \param data The flattened key-value cache
+   * \param indptr The page indptr array
+   * \param indices The page indices array
+   * \param last_page_offset The offset of the last page for each request in
+   *   the batch
+   */
+  __host__ __device__ __forceinline__ paged_kv_t(
+      uint32_t num_layers, uint32_t layer_idx, uint32_t num_heads,
+      uint32_t page_size, uint32_t head_dim, uint32_t batch_size, DType* data,
+      IdType* indptr, IdType* indices, IdType* last_page_offset)
+      : num_layers(num_layers),
+        layer_idx(layer_idx),
+        num_heads(num_heads),
+        page_size(page_size),
+        head_dim(head_dim),
+        batch_size(batch_size),
+        data(data),
+        indptr(indptr),
+        indices(indices),
+        last_page_offset(last_page_offset),
+        cooperative_indptr(nullptr),
+        batch_idx_map(nullptr),
+        chunk_start(nullptr),
+        seq_lens_before_split(nullptr) {}
+
+  /*!
+ * \brief Construct a paged key-value cache with auxiliary information for + * cooperative kernels \param num_layers The number of layers \param layer_idx + * The index of the layer \param num_heads The number of heads \param + * page_size The size of each page \param head_dim The dimension of each head + * \param batch_size The batch size + * \param data The flattened key-value cache + * \param indptr The page indptr array + * \param indices The page indices array + * \param last_page_offset The offset of the last page for each request in the + * batch + */ + __host__ __device__ __forceinline__ paged_kv_t( + uint32_t num_layers, uint32_t layer_idx, uint32_t num_heads, + uint32_t page_size, uint32_t head_dim, uint32_t batch_size, DType* data, + IdType* indptr, IdType* indices, IdType* last_page_offset, + IdType* cooperative_indptr, IdType* batch_idx_map, IdType* chunk_start, + IdType* seq_lens_before_split) + : num_layers(num_layers), + layer_idx(layer_idx), + num_heads(num_heads), + page_size(page_size), + head_dim(head_dim), + batch_size(batch_size), + data(data), + indptr(indptr), + indices(indices), + last_page_offset(last_page_offset), + cooperative_indptr(cooperative_indptr), + batch_idx_map(batch_idx_map), + chunk_start(chunk_start), + seq_lens_before_split(seq_lens_before_split) {} + + __host__ __device__ __forceinline__ size_t get_k_elem_offset( + size_t page_idx, size_t head_idx, size_t entry_idx, size_t feat_idx) { + return (((page_idx * num_layers + layer_idx) * 2 * num_heads + head_idx) * + page_size + + entry_idx) * + head_dim + + feat_idx; + } + + __host__ __device__ __forceinline__ size_t get_v_elem_offset( + size_t page_idx, size_t head_idx, size_t entry_idx, size_t feat_idx) { + return ((((page_idx * num_layers + layer_idx) * 2 + 1) * num_heads + + head_idx) * + page_size + + entry_idx) * + head_dim + + feat_idx; + } + + __host__ __device__ __forceinline__ uint32_t + get_valid_page_size(uint32_t batch_idx, uint32_t page_iter) { + if (page_iter == indptr[batch_idx + 1] - 1) { + return last_page_offset[batch_idx]; + } else { + return page_size; + } + } +}; + +template +__global__ void AppendPagedKVCacheDecodeKernel( + paged_kv_t paged_kv, DType* __restrict__ key, + DType* __restrict__ value) { + uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t num_heads = paged_kv.num_heads; + uint32_t batch_idx = blockIdx.x / (num_heads / bdy); + uint32_t head_idx = (blockIdx.x % (num_heads / bdy)) * bdy + ty; + + uint32_t seq_len = + (paged_kv.indptr[batch_idx + 1] - paged_kv.indptr[batch_idx] - 1) * + paged_kv.page_size + + paged_kv.last_page_offset[batch_idx]; + + uint32_t page_idx = paged_kv.indices[paged_kv.indptr[batch_idx] + + (seq_len - 1) / paged_kv.page_size]; + uint32_t entry_idx = (seq_len - 1) % paged_kv.page_size; + + vec_t::memcpy( + paged_kv.data + paged_kv.get_k_elem_offset(page_idx, head_idx, entry_idx, + tx * vec_size), + key + (batch_idx * num_heads + head_idx) * head_dim + tx * vec_size); + + vec_t::memcpy( + paged_kv.data + paged_kv.get_v_elem_offset(page_idx, head_idx, entry_idx, + tx * vec_size), + value + (batch_idx * num_heads + head_idx) * head_dim + tx * vec_size); +} + +template +__global__ void AppendPagedKVCachePrefillKernel( + paged_kv_t paged_kv, DType* __restrict__ key, + DType* __restrict__ value, IdType* __restrict__ append_indptr) { + uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t num_heads = paged_kv.num_heads; + uint32_t batch_idx = blockIdx.x / (num_heads / bdy); + uint32_t head_idx = (blockIdx.x % (num_heads / bdy)) * bdy + ty; + + 
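+  // Added descriptive note: each block covers bdy consecutive heads of one
+  // request, so blockIdx.x encodes (batch_idx, head group), threadIdx.y picks
+  // the head within the group, and threadIdx.x addresses a vec_size-wide
+  // slice of head_dim.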
uint32_t seq_len = + (paged_kv.indptr[batch_idx + 1] - paged_kv.indptr[batch_idx] - 1) * + paged_kv.page_size + + paged_kv.last_page_offset[batch_idx]; + uint32_t append_seq_len = + append_indptr[batch_idx + 1] - append_indptr[batch_idx]; + uint32_t append_start = seq_len - append_seq_len; + +#pragma unroll 2 + for (uint32_t j = 0; j < append_seq_len; ++j) { + uint32_t page_seq_idx = j + append_start; + uint32_t page_idx = paged_kv.indices[paged_kv.indptr[batch_idx] + + page_seq_idx / paged_kv.page_size]; + uint32_t entry_idx = page_seq_idx % paged_kv.page_size; + + vec_t::memcpy( + paged_kv.data + paged_kv.get_k_elem_offset(page_idx, head_idx, + entry_idx, tx * vec_size), + key + + ((append_indptr[batch_idx] + j) * num_heads + head_idx) * head_dim + + tx * vec_size); + + vec_t::memcpy( + paged_kv.data + paged_kv.get_v_elem_offset(page_idx, head_idx, + entry_idx, tx * vec_size), + value + + ((append_indptr[batch_idx] + j) * num_heads + head_idx) * head_dim + + tx * vec_size); + } +} + +template +__global__ void PagedKVCacheToRaggedTensorKernel( + paged_kv_t paged_kv, DType* __restrict__ key, + DType* __restrict__ value, IdType* __restrict__ kv_indptr) { + uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t num_heads = paged_kv.num_heads; + uint32_t batch_idx = blockIdx.x / (num_heads / bdy); + uint32_t head_idx = (blockIdx.x % (num_heads / bdy)) * bdy + ty; + +#pragma unroll 2 + for (uint32_t j = 0; j < kv_indptr[batch_idx + 1] - kv_indptr[batch_idx]; + ++j) { + uint32_t page_idx = + paged_kv.indices[paged_kv.indptr[batch_idx] + j / paged_kv.page_size]; + uint32_t entry_idx = j % paged_kv.page_size; + vec_t::memcpy( + key + ((kv_indptr[batch_idx] + j) * num_heads + head_idx) * head_dim + + tx * vec_size, + paged_kv.data + paged_kv.get_k_elem_offset(page_idx, head_idx, + entry_idx, tx * vec_size)); + vec_t::memcpy( + value + ((kv_indptr[batch_idx] + j) * num_heads + head_idx) * head_dim + + tx * vec_size, + paged_kv.data + paged_kv.get_v_elem_offset(page_idx, head_idx, + entry_idx, tx * vec_size)); + } +} + +template +cudaError_t AppendPagedKVCacheDecode(paged_kv_t paged_kv, + DType* key, DType* value, + cudaStream_t stream = nullptr, + uint32_t dev_id = 0) { + FLASHINFER_CUDA_CALL(cudaSetDevice(dev_id)); + uint32_t head_dim = paged_kv.head_dim; + uint32_t batch_size = paged_kv.batch_size; + uint32_t num_heads = paged_kv.num_heads; + SWITCH_HEAD_DIM(head_dim, HEAD_DIM, { + constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); + constexpr uint32_t bdx = HEAD_DIM / vec_size; + constexpr uint32_t bdy = 128 / bdx; + assert(num_heads % bdy == 0); + dim3 nblks(batch_size * num_heads / bdy); + dim3 nthrs(bdx, bdy); + auto kernel = AppendPagedKVCacheDecodeKernel; + void* args[] = {(void*)&paged_kv, (void*)&key, (void*)&value}; + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + }); + return cudaSuccess; +} + +template +cudaError_t AppendPagedKVCachePrefill(paged_kv_t paged_kv, + DType* key, DType* value, + IdType* append_indptr, + cudaStream_t stream = nullptr, + uint32_t dev_id = 0) { + FLASHINFER_CUDA_CALL(cudaSetDevice(dev_id)); + uint32_t head_dim = paged_kv.head_dim; + uint32_t batch_size = paged_kv.batch_size; + uint32_t num_heads = paged_kv.num_heads; + SWITCH_HEAD_DIM(head_dim, HEAD_DIM, { + constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); + constexpr uint32_t bdx = HEAD_DIM / vec_size; + constexpr uint32_t bdy = 128 / bdx; + assert(num_heads % bdy == 0); + dim3 nblks(batch_size * num_heads / bdy); + 
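+    // Added note: 128 threads per block, arranged as bdx lanes across
+    // head_dim times bdy heads, so the grid needs batch_size * num_heads / bdy
+    // blocks to cover every (request, head) pair.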
dim3 nthrs(bdx, bdy); + auto kernel = AppendPagedKVCachePrefillKernel; + void* args[] = {(void*)&paged_kv, (void*)&key, (void*)&value, + (void*)&append_indptr}; + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + }); + return cudaSuccess; +} + +template +cudaError_t PagedKVCacheToRaggedTensorComputeIndptr( + paged_kv_t paged_kv, std::vector& kv_indptr_host, + cudaStream_t stream = nullptr, uint32_t dev_id = 0) { + const uint32_t batch_size = paged_kv.batch_size; + const uint32_t page_size = paged_kv.page_size; + std::vector paged_kv_indptr_host(batch_size + 1), + paged_kv_last_page_offset_host(batch_size); + kv_indptr_host.resize(batch_size + 1); + + FLASHINFER_CUDA_CALL(cudaMemcpyAsync( + paged_kv_indptr_host.data(), paged_kv.indptr, + sizeof(IdType) * (batch_size + 1), cudaMemcpyDeviceToHost, stream)); + FLASHINFER_CUDA_CALL(cudaMemcpyAsync( + paged_kv_last_page_offset_host.data(), paged_kv.last_page_offset, + sizeof(IdType) * batch_size, cudaMemcpyDeviceToHost, stream)); + + FLASHINFER_CUDA_CALL(cudaStreamSynchronize(stream)); + + kv_indptr_host[0] = 0; + for (uint32_t i = 0; i < batch_size; ++i) { + kv_indptr_host[i + 1] = + kv_indptr_host[i] + + (paged_kv_indptr_host[i + 1] - paged_kv_indptr_host[i] - 1) * + page_size + + paged_kv_last_page_offset_host[i]; + } + + return cudaSuccess; +} + +template +cudaError_t PagedKVCacheToRaggedTensor(paged_kv_t paged_kv, + DType* key, DType* value, + IdType* kv_indptr, + cudaStream_t stream = nullptr, + uint32_t dev_id = 0) { + FLASHINFER_CUDA_CALL(cudaSetDevice(dev_id)); + const uint32_t head_dim = paged_kv.head_dim; + const uint32_t batch_size = paged_kv.batch_size; + const uint32_t num_heads = paged_kv.num_heads; + const uint32_t page_size = paged_kv.page_size; + + SWITCH_HEAD_DIM(head_dim, HEAD_DIM, { + constexpr uint32_t vec_size = std::max(16U / sizeof(DType), HEAD_DIM / 32U); + constexpr uint32_t bdx = HEAD_DIM / vec_size; + constexpr uint32_t bdy = 128 / bdx; + assert(num_heads % bdy == 0); + dim3 nblks(batch_size * num_heads / bdy); + dim3 nthrs(bdx, bdy); + auto kernel = PagedKVCacheToRaggedTensorKernel; + void* args[] = {(void*)&paged_kv, (void*)&key, (void*)&value, + (void*)&kv_indptr}; + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + }); + return cudaSuccess; +} + +} // namespace flashinfer + +#endif // FLAHSINFER_PAGE_CUH_ \ No newline at end of file diff --git a/server/punica_kernels/punica_kernels/flashinfer/permuted_smem.cuh b/server/punica_kernels/punica_kernels/flashinfer/permuted_smem.cuh new file mode 100644 index 000000000..993fbf92d --- /dev/null +++ b/server/punica_kernels/punica_kernels/flashinfer/permuted_smem.cuh @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef FLASHINFER_PERMUTED_SMEM_CUH_ +#define FLASHINFER_PERMUTED_SMEM_CUH_ + +#include +#include +#include + +#include + +#include "cp_async.cuh" +#include "mma.cuh" + +namespace flashinfer { + +// Each cell is 4 bytes. +using cell_t = uint4; + +template +constexpr __host__ __device__ __forceinline__ uint32_t cell_capacity() { + return sizeof(cell_t) / sizeof(T); +} + +struct smem_t { + cell_t* base; + uint32_t offset; + __device__ __forceinline__ smem_t() : base(nullptr) {} + template + __device__ __forceinline__ smem_t(T* base) : base((cell_t*)base) {} + + template + static __device__ __forceinline__ uint32_t get_permuted_offset(uint32_t i, + uint32_t j) { + return (i / 2) * stride * 2 + (j / 4) * 8 + (i % 2) * 4 + + ((j % 4) ^ ((i / 2) % 4)); + } + __device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t* R) { + cell_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x4(R, smem_ptr); + } + __device__ __forceinline__ void stmatrix_m8n8x4(uint32_t* R) { + cell_t* smem_ptr = base + offset; + mma::stmatrix_m8n8x4(R, smem_ptr); + } + __device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t* R) { + cell_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x4_trans(R, smem_ptr); + } + template + __device__ __forceinline__ void ldmatrix_m8n8x4(vec_t* v) { + cell_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x4(v, smem_ptr); + } + template + __device__ __forceinline__ void stmatrix_m8n8x4(vec_t* v) { + cell_t* smem_ptr = base + offset; + mma::stmatrix_m8n8x4(v, smem_ptr); + } + template + __device__ __forceinline__ void ldmatrix_m8n8x4_trans(vec_t* v) { + cell_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x4_trans(v, smem_ptr); + } + + template + __device__ __forceinline__ void load_128b_async(const T* gptr, + bool predicate) { + cell_t* smem_ptr = base + offset; + cp_async::pred_load_128b( + smem_ptr, reinterpret_cast(gptr), predicate); + } + template + __device__ __forceinline__ void load_128b_async(const T* gptr) { + cell_t* smem_ptr = base + offset; + cp_async::load_128b(smem_ptr, reinterpret_cast(gptr)); + } + template + __device__ __forceinline__ void store_128b(T* gptr) { + *reinterpret_cast(gptr) = *(base + offset); + } +}; + +} // namespace flashinfer + +#endif // FLASHINFER_PERMUTED_SMEM_CUH_ diff --git a/server/punica_kernels/punica_kernels/flashinfer/prefill.cuh b/server/punica_kernels/punica_kernels/flashinfer/prefill.cuh new file mode 100644 index 000000000..e2bb52c8e --- /dev/null +++ b/server/punica_kernels/punica_kernels/flashinfer/prefill.cuh @@ -0,0 +1,932 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef FLASHINFER_PREFILL_CUH_ +#define FLASHINFER_PREFILL_CUH_ +#include +#include +#include +#include +#include + +#include "cp_async.cuh" +#include "layout.cuh" +#include "math.cuh" +#include "mma.cuh" +#include "page.cuh" +#include "permuted_smem.cuh" +#include "rope.cuh" +#include "state.cuh" +#include "utils.cuh" + +namespace flashinfer { + +namespace cg = cooperative_groups; + +constexpr uint32_t warp_size = 32; + +namespace { + +__device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x, + uint32_t y) { + return (x > y) ? x - y : 0U; +} + +template +__device__ __forceinline__ void apply_llama_rope( + vec_t* x_first_half, vec_t* x_second_half, + const vec_t& rope_freq, uint32_t offset, float scale = 1.f) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + float cos, sin, tmp; + uint32_t i, j; + if constexpr (row_major) { + // 0 1 | 4 5 + // --------- + // 2 3 | 6 7 + i = ((reg_id % 4) / 2); + j = (reg_id / 4); + } else { + // 0 1 | 2 3 + // --------- + // 4 5 | 6 7 + i = reg_id / 4; + j = (reg_id % 4) / 2; + } + __sincosf(float(offset + 8 * i) * rope_freq[2 * j + reg_id % 2], &sin, + &cos); + tmp = (*x_first_half)[reg_id]; + (*x_first_half)[reg_id] = + (tmp * cos - (float)(*x_second_half)[reg_id] * sin) * scale; + (*x_second_half)[reg_id] = + ((float)(*x_second_half)[reg_id] * cos + tmp * sin) * scale; + } +} + +} // namespace + +template +__device__ __forceinline__ void produce_kv( + smem_t* smem, T* gptr, const tensor_info_t& qkv_info, + const uint32_t kv_idx_base, const uint32_t kv_len, const uint32_t head_idx, + const uint32_t tx, const uint32_t ty) { + constexpr uint32_t num_cells_per_head_in = + num_frags_y * 16 / cell_capacity(); + + uint32_t kv_idx = kv_idx_base + ty * 4 + (tx % 16) / 4; + smem->offset = smem_t::get_permuted_offset( + ty * 4 + (tx % 16) / 4, (tx / 16) * 4 + tx % 4); + gptr += qkv_info.get_kv_elem_offset( + kv_idx, head_idx, ((tx / 16) * 4 + tx % 4) * cell_capacity()); + +#pragma unroll + for (uint32_t i = 0; i < num_frags_z * 4 / num_warps; ++i) { +#pragma unroll + for (uint32_t j = 0; j < num_frags_y / 4; ++j) { + smem->load_128b_async(gptr, kv_idx < kv_len); + smem->offset += 16; + gptr += 8 * cell_capacity(); + } + kv_idx += num_warps * 4; + smem->offset += num_warps * 4 * num_cells_per_head_in - 4 * num_frags_y; + gptr += num_warps * 4 * qkv_info.get_n_stride() - + 2 * num_frags_y * cell_capacity(); + } +} + +template +__global__ void SinglePrefillWithKVCacheKernel( + DTypeIn* __restrict__ q, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v, + DTypeOut* __restrict__ o, float* __restrict__ tmp, + const tensor_info_t qkv_info, float sm_scale, + const float log2_rope_inv_scale, const float log2_rope_inv_theta) { + sm_scale *= math::log2e; + const uint32_t qo_len = qkv_info.qo_len; + const uint32_t kv_len = qkv_info.kv_len; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + const uint32_t bx = blockIdx.x, chunk_idx = blockIdx.y, + qo_head_idx = blockIdx.z, + kv_head_idx = qo_head_idx / group_size; + const uint32_t num_chunks = gridDim.y; + const uint32_t chunk_size = + cooperative ? (kv_len + num_chunks - 1) / num_chunks : kv_len; + const uint32_t chunk_start = cooperative ? chunk_idx * chunk_size : 0; + const uint32_t chunk_end = + cooperative ? 
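+          // Added note: in the cooperative (split-KV) path the KV positions
+          // are partitioned into gridDim.y chunks of roughly
+          // kv_len / num_chunks entries; this block handles
+          // [chunk_start, chunk_end) and the per-chunk partial results are
+          // merged through `tmp` after grid.sync().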
min((chunk_idx + 1) * chunk_size, kv_len) : kv_len; + auto block = cg::this_thread_block(); + + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_cells_per_head_in = + head_dim / cell_capacity(); + constexpr uint32_t num_cells_per_head_out = + head_dim / cell_capacity(); + + static_assert(num_frags_z * num_frags_y % num_warps == 0); + + extern __shared__ uint8_t smem[]; + + // q_frag will be used only when pin_q_in_reg is true + vec_t q_frag[num_frags_x][num_frags_y]; + vec_t x_frag[num_frags_x][num_frags_z]; + vec_t o_frag[num_frags_x][num_frags_y]; + vec_t m[num_frags_x]; + vec_t d[num_frags_x]; + vec_t rope_freq[num_frags_y / 2]; + if constexpr (rotary_mode == RotaryMode::kLlama) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y / 2; ++fy) { +#pragma unroll + for (uint32_t j = 0; j < 4; ++j) { + rope_freq[fy][j] = math::ptx_exp2( + log2_rope_inv_scale + + log2_rope_inv_theta * + float(2 * ((fy * 16 + (j / 2) * 8 + (tx % 4) * 2 + (j % 2)) % + (head_dim / 2))) / + float(head_dim)); + } + } + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + o_frag[fx][fy].fill(0.); + } + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + m[fx].fill(-5e4); + d[fx].fill(0.); + } + + // cooperative fetch q fragment from gmem to reg + smem_t q_smem(smem); + uint32_t q_idx = (bx * num_warps + ty) * num_frags_x * 16 + tx / 4; + q_smem.offset = smem_t::get_permuted_offset( + ty * num_frags_x * 16 + tx / 4, tx % 4); + DTypeIn* q_ptr = + q + qkv_info.get_qo_elem_offset(q_idx, qo_head_idx, + (tx % 4) * cell_capacity()); +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { +#pragma unroll + for (uint32_t fyo = 0; fyo < num_frags_y / 2; ++fyo) { + // load q fragment from gmem to smem + q_smem.load_128b_async(q_ptr, q_idx < qo_len); + q_smem.offset += 8; + q_ptr += 4 * cell_capacity(); + } + q_idx += 8; + q_smem.offset += 8 * num_cells_per_head_in - 4 * num_frags_y; + q_ptr += 8 * qkv_info.get_n_stride() - + 2 * num_frags_y * cell_capacity(); + } + } + cp_async::commit_group(); + cp_async::wait_group<0>(); + block.sync(); + + // preprocess q fragment, multiply sm_scale and apply rotary + if constexpr (pin_q_in_reg) { + // pin q fragment in reg + q_smem.offset = smem_t::get_permuted_offset( + ty * num_frags_x * 16 + tx % 16, tx / 16); + q_idx = (bx * num_warps + ty) * num_frags_x * 16 + tx / 4; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + q_smem.ldmatrix_m8n8x4(&q_frag[fx][fy]); + q_smem.offset = (q_smem.offset ^ 0x2) + (fy & 0x1) * 8; + } + q_smem.offset += 16 * num_cells_per_head_in - 4 * num_frags_y; + if constexpr (rotary_mode == RotaryMode::kLlama) { +#pragma unroll + for (uint32_t fyi = 0; fyi < num_frags_y / 2; ++fyi) { + apply_llama_rope( + &q_frag[fx][fyi], &q_frag[fx][num_frags_y / 2 + fyi], + rope_freq[fyi], q_idx + kv_len - qo_len, sm_scale); + } + } else { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + q_frag[fx][fy][reg_id] *= sm_scale; + } + } + } + q_idx += 16; + } + } else { + // do not pin q fragment in reg + if constexpr (rotary_mode == RotaryMode::kLlama) { + q_smem.offset = smem_t::get_permuted_offset( + ty * num_frags_x * 16 + tx % 16, tx / 16); + q_idx = (bx * num_warps + ty) * num_frags_x * 16 + tx / 4; +#pragma unroll + 
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fyi = 0; fyi < num_frags_y / 2; ++fyi) { + vec_t q_frag_local[2]; + q_smem.ldmatrix_m8n8x4(&q_frag_local[0]); + q_smem.offset += num_frags_y * 2; + q_smem.ldmatrix_m8n8x4(&q_frag_local[1]); + apply_llama_rope(&q_frag_local[0], &q_frag_local[1], + rope_freq[fyi], + q_idx + kv_len - qo_len, sm_scale); + q_smem.stmatrix_m8n8x4(&q_frag_local[1]); + q_smem.offset -= num_frags_y * 2; + q_smem.stmatrix_m8n8x4(&q_frag_local[0]); + q_smem.offset = (q_smem.offset ^ 0x2) + (fyi & 0x1) * 8; + } + q_smem.offset += 16 * num_cells_per_head_in - 2 * num_frags_y; + q_idx += 16; + } + } else { +#pragma unroll + for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 256; ++i) { + vec_t tmp; + tmp.load((DTypeIn*)(q_smem.base + + ty * num_frags_x * 16 * num_cells_per_head_in) + + i * 256 + tx * 8); +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + tmp[reg_id] *= sm_scale; + } + tmp.store((DTypeIn*)(q_smem.base + + ty * num_frags_x * 16 * num_cells_per_head_in) + + i * 256 + tx * 8); + } + } + } + + smem_t k_smem[num_stages_smem]; + smem_t v_smem[num_stages_smem]; +#pragma unroll + for (uint32_t i = 0; i < num_stages_smem; ++i) { + if constexpr (pin_q_in_reg) { + k_smem[i].base = + (cell_t*)(smem + (i * num_frags_z) * 16 * head_dim * sizeof(DTypeIn)); + v_smem[i].base = (cell_t*)(smem + ((num_stages_smem + i) * num_frags_z) * + 16 * head_dim * sizeof(DTypeIn)); + } else { + k_smem[i].base = + (cell_t*)(smem + (num_warps * num_frags_x + i * num_frags_z) * 16 * + head_dim * sizeof(DTypeIn)); + v_smem[i].base = (cell_t*)(smem + (num_warps * num_frags_x + + (num_stages_smem + i) * num_frags_z) * + 16 * head_dim * sizeof(DTypeIn)); + } + } + + const uint32_t num_iterations = + ((causal ? min(chunk_end - chunk_start, + sub_if_greater_or_zero( + kv_len - qo_len + + ((bx + 1) * num_frags_x * num_warps) * 16, + chunk_start)) + : chunk_end - chunk_start) + + 16 * num_frags_z - 1) / + (16 * num_frags_z); + + const uint32_t mask_iteration = + (causal ? 
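+          // Added note: iterations before `mask_iteration` fall entirely
+          // inside the visible KV range, so the causal / out-of-range mask is
+          // only applied below once iter >= mask_iteration.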
min(chunk_end - chunk_start, + sub_if_greater_or_zero( + kv_len + bx * num_warps * num_frags_x - qo_len, + chunk_start)) + : (chunk_end - chunk_start)) / + (16 * num_frags_z); + +#pragma unroll + for (uint32_t iter = 0; iter < num_stages_smem; ++iter) { + const uint32_t stage_idx = iter; + produce_kv( + k_smem + stage_idx, k, qkv_info, chunk_start + iter * 16 * num_frags_z, + chunk_end, kv_head_idx, tx, ty); + cp_async::commit_group(); + produce_kv( + v_smem + stage_idx, v, qkv_info, chunk_start + iter * 16 * num_frags_z, + chunk_end, kv_head_idx, tx, ty); + cp_async::commit_group(); + } + +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + const uint32_t stage_idx = iter % num_stages_smem; + cp_async::wait_group<2 * num_stages_smem - 1>(); + block.sync(); + // init x_frag with 0 +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + vec_t x_frag_local[num_frags_z]; +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + x_frag_local[fz].fill(0.); + } + + if constexpr (rotary_mode == RotaryMode::kLlama) { + // apply rotary on the fly + if constexpr (!pin_q_in_reg) { + q_smem.offset = smem_t::get_permuted_offset( + ty * num_frags_x * 16 + tx % 16, tx / 16); + } + + k_smem[stage_idx].offset = + smem_t::get_permuted_offset( + 8 * (tx / 16) + tx % 8, (tx % 16) / 8); +#pragma unroll + for (uint32_t fyi = 0; fyi < num_frags_y / 2; ++fyi) { + vec_t a_frag[2]; + if constexpr (!pin_q_in_reg) { + q_smem.ldmatrix_m8n8x4(&a_frag[0]); + q_smem.offset += num_frags_y * 2; + q_smem.ldmatrix_m8n8x4(&a_frag[1]); + q_smem.offset -= num_frags_y * 2; + q_smem.offset = (q_smem.offset ^ 0x2) + (fyi & 0x1) * 8; + } + uint32_t kv_idx = chunk_start + iter * 16 * num_frags_z + tx / 4; +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + vec_t k_frag[2]; + k_smem[stage_idx].ldmatrix_m8n8x4(&k_frag[0]); + k_smem[stage_idx].offset += num_frags_y * 2; + k_smem[stage_idx].ldmatrix_m8n8x4(&k_frag[1]); + k_smem[stage_idx].offset += + 16 * num_cells_per_head_in - num_frags_y * 2; + apply_llama_rope(&k_frag[0], &k_frag[1], + rope_freq[fyi], kv_idx); + kv_idx += 16; + mma::mma_sync_m16n16k16_row_col_f16f16f32( + &x_frag_local[fz], pin_q_in_reg ? &q_frag[fx][fyi] : &a_frag[0], + &k_frag[0]); + mma::mma_sync_m16n16k16_row_col_f16f16f32( + &x_frag_local[fz], + pin_q_in_reg ? &q_frag[fx][fyi + num_frags_y / 2] : &a_frag[1], + &k_frag[1]); + } + k_smem[stage_idx].offset = (k_smem[stage_idx].offset ^ 0x2) + + (fyi & 0x1) * 8 - + num_frags_z * 16 * num_cells_per_head_in; + } + q_smem.offset += 16 * num_cells_per_head_in - 2 * num_frags_y; + } else { + if constexpr (!pin_q_in_reg) { + q_smem.offset = smem_t::get_permuted_offset( + ty * num_frags_x * 16 + tx % 16, tx / 16); + } + + k_smem[stage_idx].offset = + smem_t::get_permuted_offset( + 8 * (tx / 16) + tx % 8, (tx % 16) / 8); +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + vec_t a_frag; + if constexpr (!pin_q_in_reg) { + q_smem.ldmatrix_m8n8x4(&a_frag); + q_smem.offset = (q_smem.offset ^ 0x2) + (fy & 0x1) * 8; + } + +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + vec_t b_frag; + k_smem[stage_idx].ldmatrix_m8n8x4(&b_frag); + k_smem[stage_idx].offset += 16 * num_cells_per_head_in; + + mma::mma_sync_m16n16k16_row_col_f16f16f32( + &x_frag_local[fz], pin_q_in_reg ? 
&q_frag[fx][fy] : &a_frag, + &b_frag); + } + + k_smem[stage_idx].offset = (k_smem[stage_idx].offset ^ 0x2) + + (fy & 0x1) * 8 - + num_frags_z * 16 * num_cells_per_head_in; + } + q_smem.offset += 16 * num_cells_per_head_in - 4 * num_frags_y; + } + + // apply mask + if (iter >= mask_iteration) { + uint32_t q_idx_base = + ((bx * num_warps + ty) * num_frags_x) * 16 + tx / 4, + kv_idx_base = + chunk_start + iter * 16 * num_frags_z + 2 * (tx % 4); +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + const uint32_t q_idx = q_idx_base + 8 * ((reg_id % 4) / 2), + kv_idx = kv_idx_base + 8 * (reg_id / 4) + reg_id % 2; + const bool out_of_boundary = + (causal ? (kv_idx > kv_len + q_idx - qo_len || + (cooperative && kv_idx >= chunk_end)) + : kv_idx >= chunk_end); + x_frag_local[fz][reg_id] = + out_of_boundary ? -5e4 : x_frag_local[fz][reg_id]; + } + kv_idx_base += 16; + } + kv_idx_base -= num_frags_z * 16; + q_idx_base += 16; + } + + // compute m,d states in online softmax +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + float m_prev = m[fx][j]; +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + float m_local = max( + max(x_frag_local[fz][j * 2 + 0], x_frag_local[fz][j * 2 + 1]), + max(x_frag_local[fz][j * 2 + 4], x_frag_local[fz][j * 2 + 5])); + m[fx][j] = max(m[fx][j], m_local); + } + m[fx][j] = max(m[fx][j], math::shfl_xor_sync(m[fx][j], 0x2)); + m[fx][j] = max(m[fx][j], math::shfl_xor_sync(m[fx][j], 0x1)); + float o_scale = math::ptx_exp2(m_prev - m[fx][j]); + d[fx][j] *= o_scale; +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + o_frag[fx][fy][j * 2 + 0] *= o_scale; + o_frag[fx][fy][j * 2 + 1] *= o_scale; + o_frag[fx][fy][j * 2 + 4] *= o_scale; + o_frag[fx][fy][j * 2 + 5] *= o_scale; + } + +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + x_frag_local[fz][j * 2 + 0] = + math::ptx_exp2(x_frag_local[fz][j * 2 + 0] - m[fx][j]); + x_frag_local[fz][j * 2 + 1] = + math::ptx_exp2(x_frag_local[fz][j * 2 + 1] - m[fx][j]); + x_frag_local[fz][j * 2 + 4] = + math::ptx_exp2(x_frag_local[fz][j * 2 + 4] - m[fx][j]); + x_frag_local[fz][j * 2 + 5] = + math::ptx_exp2(x_frag_local[fz][j * 2 + 5] - m[fx][j]); + x_frag[fx][fz].cast_from(x_frag_local[fz]); + d[fx][j] += x_frag_local[fz][j * 2 + 0] + + x_frag_local[fz][j * 2 + 1] + + x_frag_local[fz][j * 2 + 4] + x_frag_local[fz][j * 2 + 5]; + } + } + } + + block.sync(); + produce_kv( + k_smem + stage_idx, k, qkv_info, + chunk_start + (iter + num_stages_smem) * 16 * num_frags_z, kv_len, + kv_head_idx, tx, ty); + cp_async::commit_group(); + cp_async::wait_group<2 * num_stages_smem - 1>(); + block.sync(); + + // load v tile from smem to reg + v_smem[stage_idx].offset = + smem_t::get_permuted_offset(tx % 16, tx / 16); + + // compute sfm*v +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + vec_t v_frag; + v_smem[stage_idx].ldmatrix_m8n8x4_trans(&v_frag); + v_smem[stage_idx].offset += 16 * num_cells_per_head_in; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + mma::mma_sync_m16n16k16_row_col_f16f16f32( + &o_frag[fx][fy], &x_frag[fx][fz], &v_frag); + } + } + v_smem[stage_idx].offset = (v_smem[stage_idx].offset ^ 0x2) + + (fy & 0x1) * 8 - + num_frags_z * 16 * num_cells_per_head_in; + } + block.sync(); + produce_kv( + v_smem + stage_idx, v, qkv_info, + chunk_start + (iter + num_stages_smem) * 16 * num_frags_z, kv_len, + kv_head_idx, tx, 
ty); + cp_async::commit_group(); + } + cp_async::wait_group<0>(); + block.sync(); + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + d[fx][j] += math::shfl_xor_sync(d[fx][j], 0x2); + d[fx][j] += math::shfl_xor_sync(d[fx][j], 0x1); + } + } + + if constexpr (cooperative) { + // aggregate global state + auto grid = cg::this_grid(); +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + o_frag[fx][fy].store( + tmp + + ((fx * num_frags_y + fy) * grid.size() + grid.thread_rank()) * 8); + o_frag[fx][fy].fill(0.f); + } + } + float* tmp_md = tmp + num_frags_x * num_frags_y * 8 * grid.size(); +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + *(float2*)&tmp_md[((fx * 2 + j) * grid.size() + grid.thread_rank()) * + 2] = make_float2(m[fx][j], d[fx][j]); + m[fx][j] = -5e4; + d[fx][j] = 0.f; + } + } + + grid.sync(); + + for (uint32_t iter = 0; iter < num_chunks; ++iter) { + float other_scale[num_frags_x][2]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + float2 md = *( + float2*)&tmp_md[((fx * 2 + j) * grid.size() + + ((qo_head_idx * num_chunks + iter) * gridDim.x + + bx) * + block.num_threads() + + block.thread_rank()) * + 2]; + float mi = md.x, di = md.y, m_prev = m[fx][j]; + m[fx][j] = max(m_prev, mi); + float o_scale = math::ptx_exp2(m_prev - m[fx][j]); + other_scale[fx][j] = math::ptx_exp2(mi - m[fx][j]); + d[fx][j] = d[fx][j] * o_scale + di * other_scale[fx][j]; +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + o_frag[fx][fy][j * 2 + 0] *= o_scale; + o_frag[fx][fy][j * 2 + 1] *= o_scale; + o_frag[fx][fy][j * 2 + 4] *= o_scale; + o_frag[fx][fy][j * 2 + 5] *= o_scale; + } + } +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + vec_t o_frag_i; + o_frag_i.load(tmp + + ((fx * num_frags_y + fy) * grid.size() + + ((qo_head_idx * num_chunks + iter) * gridDim.x + bx) * + block.num_threads() + + block.thread_rank()) * + 8); +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + o_frag[fx][fy][reg_id] += + o_frag_i[reg_id] * other_scale[fx][(reg_id % 4) / 2]; + } + } + } + } + } + + // divide d +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + o_frag[fx][fy][reg_id] = + __fdividef(o_frag[fx][fy][reg_id], d[fx][(reg_id % 4) / 2]); + } + } + } + + // write back + smem_t o_smem(smem); + if constexpr (std::is_same::value) { + // TODO(Zihao) + } else if constexpr (sizeof(DTypeOut) == 2) { + o_smem.offset = smem_t::get_permuted_offset( + ty * num_frags_x * 16 + tx / 4, 0); +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + vec_cast( + (DTypeOut*)(o_smem.base + o_smem.offset) + (tx % 4) * 2, + &o_frag[fx][fy][0]); + vec_cast((DTypeOut*)(o_smem.base + o_smem.offset + + 8 * num_cells_per_head_out) + + (tx % 4) * 2, + &o_frag[fx][fy][2]); + vec_cast( + (DTypeOut*)(o_smem.base + (o_smem.offset ^ 0x1)) + (tx % 4) * 2, + &o_frag[fx][fy][4]); + vec_cast( + (DTypeOut*)(o_smem.base + (o_smem.offset ^ 0x1) + + 8 * num_cells_per_head_out) + + (tx % 4) * 2, + &o_frag[fx][fy][6]); + o_smem.offset = (o_smem.offset ^ 0x2) + (fy & 0x1) * 8; + } + o_smem.offset += 
16 * num_cells_per_head_out - num_frags_y * 4; + } + + o_smem.offset = smem_t::get_permuted_offset( + ty * num_frags_x * 16 + tx % 16, tx / 16); + uint32_t o_idx = (bx * num_warps + ty) * num_frags_x * 16 + tx % 16; + DTypeOut* o_ptr = + o + qkv_info.get_qo_elem_offset(o_idx, qo_head_idx, + tx / 16 * cell_capacity()); +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + if (o_idx < qo_len) { + o_smem.store_128b(o_ptr); + } + o_ptr += 2 * cell_capacity(); + o_smem.offset = (o_smem.offset ^ 0x2) + (fy & 0x1) * 8; + } + o_idx += 16; + o_ptr += qkv_info.get_n_stride() * 16 - + 2 * num_frags_y * cell_capacity(); + o_smem.offset += 16 * num_cells_per_head_out - num_frags_y * 4; + } + } else { + // NOTE(Zihao): Not implemented yet. + } +} + +template +cudaError_t SinglePrefillWithKVCacheWorkEstimation( + uint32_t& tmp_size, uint32_t& max_grid_size, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, uint32_t head_dim, + bool causal = true, QKVLayout layout = QKVLayout::kNHD, + RotaryMode rotary_mode = RotaryMode::kNone, cudaStream_t stream = nullptr, + uint32_t dev_id = 0) { + SWITCH_NUM_FRAGS_X( + qo_len > 64, NUM_FRAGS_X, + {SWITCH_GQA_GROUP_SIZE( + num_qo_heads / num_kv_heads, GROUP_SIZE, + {SWITCH_CAUSAL( + causal, CAUSAL, + {SWITCH_HEAD_DIM( + head_dim, HEAD_DIM, + {SWITCH_ROTARY_MODE( + rotary_mode, ROTARY_MODE, {SWITCH_LAYOUT(layout, LAYOUT, { + constexpr bool pin_q_in_reg = false; + constexpr uint32_t num_frags_x = NUM_FRAGS_X; + constexpr uint32_t num_frags_y = HEAD_DIM / 16; + constexpr uint32_t num_frags_z = 2; + constexpr uint32_t num_warps = 4UL; + constexpr uint32_t num_stages_smem = 1; + constexpr uint32_t num_threads = num_warps * warp_size; + constexpr uint32_t num_rows_per_cta = + num_frags_x * num_warps * 16; + auto cooperative_kernel = + SinglePrefillWithKVCacheKernel< + pin_q_in_reg, true, GROUP_SIZE, CAUSAL, LAYOUT, + ROTARY_MODE, num_frags_x, num_frags_y, + num_frags_z, num_stages_smem, num_warps, + DTypeIn, DTypeOut>; + uint32_t smem_size = + (pin_q_in_reg + ? 
max(num_frags_x * num_warps, + 2 * num_stages_smem * num_frags_z) + : (num_frags_x * num_warps + + 2 * num_stages_smem * num_frags_z)) * + 16 * head_dim * sizeof(DTypeIn); + int num_blocks_per_sm = 0; + int num_sm = 0; + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute( + &num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + FLASHINFER_CUDA_CALL( + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, cooperative_kernel, + num_threads, smem_size)); + max_grid_size = num_blocks_per_sm * num_sm; + uint32_t num_chunks = + min((num_blocks_per_sm * num_sm) / + (num_qo_heads * + (qo_len + (num_rows_per_cta - 1)) / + num_rows_per_cta), + kv_len / 512); + if (num_chunks > 1) { + uint32_t grid_size = + 32 * num_warps * + ((qo_len + (num_rows_per_cta - 1)) / + num_rows_per_cta) * + num_chunks * num_qo_heads; + tmp_size = sizeof(float) * + (4 * num_frags_x + + num_frags_x * num_frags_y * 8) * + grid_size; + } else { + tmp_size = 0; + } + })})})})})}); + return cudaSuccess; +} + +template +cudaError_t SinglePrefillWithKVCache( + DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, float* tmp, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, + uint32_t kv_len, uint32_t head_dim, bool causal = true, + QKVLayout layout = QKVLayout::kNHD, + RotaryMode rotary_mode = RotaryMode::kNone, float rope_scale = 1.f, + float rope_theta = 1e4, cudaStream_t stream = nullptr, + uint32_t dev_id = 0) { + const float sm_scale = 1.f / std::sqrt(float(head_dim)); + const float log2_rope_inv_scale = -std::log2f(rope_scale); + const float log2_rope_inv_theta = -std::log2f(rope_theta); + FLASHINFER_CUDA_CALL(cudaSetDevice(dev_id)); + assert(kv_len >= qo_len); + + SWITCH_NUM_FRAGS_X( + qo_len > 64, NUM_FRAGS_X, + {SWITCH_GQA_GROUP_SIZE( + num_qo_heads / num_kv_heads, GROUP_SIZE, + {SWITCH_CAUSAL( + causal, CAUSAL, + {SWITCH_HEAD_DIM( + head_dim, HEAD_DIM, + {SWITCH_ROTARY_MODE( + rotary_mode, ROTARY_MODE, {SWITCH_LAYOUT(layout, LAYOUT, { + constexpr bool pin_q_in_reg = false; + constexpr uint32_t num_frags_x = NUM_FRAGS_X; + constexpr uint32_t num_frags_y = HEAD_DIM / 16; + constexpr uint32_t num_frags_z = 3; + constexpr uint32_t num_warps = 4UL; + constexpr uint32_t num_stages_smem = 1; + constexpr uint32_t num_threads = num_warps * warp_size; + constexpr uint32_t num_rows_per_cta = + num_frags_x * num_warps * 16; + auto cooperative_kernel = + SinglePrefillWithKVCacheKernel< + pin_q_in_reg, true, GROUP_SIZE, CAUSAL, LAYOUT, + ROTARY_MODE, num_frags_x, num_frags_y, + num_frags_z, num_stages_smem, num_warps, + DTypeIn, DTypeOut>; + tensor_info_t qkv_info( + qo_len, kv_len, num_kv_heads, HEAD_DIM); + uint32_t smem_size = + (pin_q_in_reg + ? 
max(num_frags_x * num_warps, + 2 * num_stages_smem * num_frags_z) + : (num_frags_x * num_warps + + 2 * num_stages_smem * num_frags_z)) * + 16 * head_dim * sizeof(DTypeIn); + int num_blocks_per_sm = 0; + int num_sm = 0; + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute( + &num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + FLASHINFER_CUDA_CALL( + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, cooperative_kernel, + num_threads, smem_size)); + uint32_t num_chunks = + min((num_blocks_per_sm * num_sm) / + (num_qo_heads * + (qo_len + (num_rows_per_cta - 1)) / + num_rows_per_cta), + kv_len / 512); + + if (num_chunks <= 1 || tmp == nullptr) { + // Enough parallelism, do not use cooperative groups + auto kernel = SinglePrefillWithKVCacheKernel< + pin_q_in_reg, false, GROUP_SIZE, CAUSAL, LAYOUT, + ROTARY_MODE, num_frags_x, num_frags_y, + num_frags_z, num_stages_smem, num_warps, DTypeIn, + DTypeOut>; + void* args[] = {(void*)&q, + (void*)&k, + (void*)&v, + (void*)&o, + (void*)&tmp, + (void*)&qkv_info, + (void*)&sm_scale, + (void*)&log2_rope_inv_scale, + (void*)&log2_rope_inv_theta}; + dim3 nblks((qo_len + (num_rows_per_cta - 1)) / + num_rows_per_cta, + 1, num_qo_heads); + dim3 nthrs(32, num_warps); + FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( + kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, + args, smem_size, stream)); + } else { + // Use cooperative groups to increase occupancy + void* args[] = {(void*)&q, + (void*)&k, + (void*)&v, + (void*)&o, + (void*)&tmp, + (void*)&qkv_info, + (void*)&sm_scale, + (void*)&log2_rope_inv_scale, + (void*)&log2_rope_inv_theta}; + dim3 nblks((qo_len + (num_rows_per_cta - 1)) / + num_rows_per_cta, + num_chunks, num_qo_heads); + dim3 nthrs(32, num_warps); + FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( + cooperative_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + FLASHINFER_CUDA_CALL(cudaLaunchCooperativeKernel( + (void*)cooperative_kernel, nblks, nthrs, args, + smem_size, stream)); + } + })})})})})}); + return cudaSuccess; +} + +template +cudaError_t BatchPrefillWithPagedKVCache( + DTypeIn* q, paged_kv_t paged_kv, IdType* q_indptr, + DTypeOut* o, float* tmp, uint32_t num_qo_heads, bool causal = true, + RotaryMode rotary_mode = RotaryMode::kNone, float rope_scale = 1.f, + float rope_theta = 1e4, cudaStream_t stream = nullptr, + uint32_t dev_id = 0) { + const uint32_t num_kv_heads = paged_kv.num_heads; + const uint32_t head_dim = paged_kv.head_dim; + const uint32_t batch_size = paged_kv.batch_size; + + FLASHINFER_CUDA_CALL(cudaSetDevice(dev_id)); + std::vector q_indptr_h(paged_kv.batch_size + 1); + std::vector kv_indptr_h(paged_kv.batch_size + 1); + + FLASHINFER_CUDA_CALL(PagedKVCacheToRaggedTensorComputeIndptr( + paged_kv, kv_indptr_h, stream, dev_id)); + uint32_t nnz = kv_indptr_h.back(); + + DTypeIn *keys = nullptr, *values = nullptr; + IdType* kv_indptr = nullptr; + FLASHINFER_CUDA_CALL(cudaMallocAsync( + &keys, nnz * num_kv_heads * head_dim * sizeof(DTypeIn), stream)); + FLASHINFER_CUDA_CALL(cudaMallocAsync( + &values, nnz * num_kv_heads * head_dim * sizeof(DTypeIn), stream)); + FLASHINFER_CUDA_CALL( + cudaMallocAsync(&kv_indptr, (batch_size + 1) * sizeof(IdType), stream)); + FLASHINFER_CUDA_CALL(cudaMemcpyAsync( + q_indptr_h.data(), q_indptr, sizeof(IdType) * (paged_kv.batch_size + 1), + cudaMemcpyDeviceToHost, stream)); + FLASHINFER_CUDA_CALL(cudaMemcpyAsync( + kv_indptr, kv_indptr_h.data(), sizeof(IdType) * (paged_kv.batch_size + 
1), + cudaMemcpyHostToDevice, stream)); + FLASHINFER_CUDA_CALL(cudaStreamSynchronize(stream)); + FLASHINFER_CUDA_CALL( + PagedKVCacheToRaggedTensor(paged_kv, keys, values, kv_indptr, stream)); + FLASHINFER_CUDA_CALL(cudaStreamSynchronize(stream)); + + for (uint32_t batch_idx = 0; batch_idx < paged_kv.batch_size; ++batch_idx) { + SinglePrefillWithKVCache( + q + q_indptr_h[batch_idx] * num_qo_heads * head_dim, + keys + kv_indptr_h[batch_idx] * num_kv_heads * head_dim, + values + kv_indptr_h[batch_idx] * num_kv_heads * head_dim, + o + q_indptr_h[batch_idx] * num_qo_heads * head_dim, nullptr, + num_qo_heads, num_kv_heads, + q_indptr_h[batch_idx + 1] - q_indptr_h[batch_idx], + kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx], head_dim, causal, + QKVLayout::kNHD, rotary_mode, rope_scale, rope_theta, stream, dev_id); + } + FLASHINFER_CUDA_CALL(cudaFreeAsync(keys, stream)); + FLASHINFER_CUDA_CALL(cudaFreeAsync(values, stream)); + FLASHINFER_CUDA_CALL(cudaFreeAsync(kv_indptr, stream)); + FLASHINFER_CUDA_CALL(cudaStreamSynchronize(stream)); + + return cudaSuccess; +} + +} // namespace flashinfer + +#endif // FLASHINFER_PREFILL_CUH_ diff --git a/server/punica_kernels/punica_kernels/flashinfer/rope.cuh b/server/punica_kernels/punica_kernels/flashinfer/rope.cuh new file mode 100644 index 000000000..93e9e9d5b --- /dev/null +++ b/server/punica_kernels/punica_kernels/flashinfer/rope.cuh @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_ROPE_CUH_ +#define FLASHINFER_ROPE_CUH_ + +#include + +namespace flashinfer { + +/*! + * \brief An enumeration class that defines different modes for applying RoPE + * (Rotary Positional Embeddings). + */ +enum class RotaryMode { + // No rotary positional embeddings + kNone = 0U, + // Apply Llama-style rope. + kLlama = 1U, +}; + +/*! + * \brief Convert RotaryMode to string + * \param rotary_mode A RotaryMode value + */ +inline std::string RotaryModeToString(const RotaryMode& rotary_mode) { + switch (rotary_mode) { + case RotaryMode::kNone: + return "None"; + case RotaryMode::kLlama: + return "Llama"; + default: + return "Unknown"; + } +} + +} // namespace flashinfer + +#endif // FLASHINFER_ROPE_CUH_ \ No newline at end of file diff --git a/server/punica_kernels/punica_kernels/flashinfer/state.cuh b/server/punica_kernels/punica_kernels/flashinfer/state.cuh new file mode 100644 index 000000000..866077385 --- /dev/null +++ b/server/punica_kernels/punica_kernels/flashinfer/state.cuh @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FLASHINFER_STATE_CUH_
+#define FLASHINFER_STATE_CUH_
+
+#include "math.cuh"
+#include "vec_dtypes.cuh"
+
+namespace flashinfer {
+
+/*!
+ * \brief The flashattention state.
+ * \tparam vec_size The size of the vector used in o.
+ * \tparam norm_on_the_fly Whether to normalize the state on the fly. If true,
+ * the state will be normalized when merge() is called. If false, the state will
+ * be normalized when normalize() is called.
+ */
+template <size_t vec_size, bool norm_on_the_fly>
+struct state_t {
+  /* the weighted sum of v: exp(pre-softmax logit - m) * v / d */
+  vec_t<float, vec_size> o;
+  /* maximum value of pre-softmax logits */
+  float m;
+  /* sum of exp(pre-softmax logits - m) */
+  float d;
+
+  __device__ __forceinline__ void init() {
+    o.fill(0.f);
+    m = -5e4;
+    d = 0.f;
+  }
+
+  __device__ __forceinline__ state_t() { init(); }
+
+  /*!
+   * \brief Merge the state with another state.
+   * \param other_m The maximum value of pre-softmax logits of the other state.
+   * \param other_d The sum of exp(pre-softmax logits - m) of the other state.
+   * \param other_o The weighted sum of v of the other state.
+   */
+  __device__ __forceinline__ void merge(const vec_t<float, vec_size>& other_o,
+                                        float other_m, float other_d) {
+    float m_prev = m, d_prev = d;
+    m = max(m_prev, other_m);
+    d = d_prev * math::ptx_exp2(m_prev - m) +
+        other_d * math::ptx_exp2(other_m - m);
+    if constexpr (norm_on_the_fly) {
+#pragma unroll
+      for (size_t i = 0; i < vec_size; ++i) {
+        o[i] = o[i] * math::ptx_exp2(m_prev - m) * (d_prev / d) +
+               other_o[i] * math::ptx_exp2(other_m - m) * (other_d / d);
+      }
+    } else {
+#pragma unroll
+      for (size_t i = 0; i < vec_size; ++i) {
+        o[i] = o[i] * math::ptx_exp2(m_prev - m) +
+               other_o[i] * math::ptx_exp2(other_m - m);
+      }
+    }
+  }
+
+  /*!
+   * \brief Merge the state with another state.
+   * \param other The other state.
+   */
+  __device__ __forceinline__ void merge(
+      const state_t<vec_size, norm_on_the_fly>& other) {
+    merge(other.o, other.m, other.d);
+  }
+
+  /*!
+   * \brief Merge the state with a single pre-softmax logit and value vector.
+   * \param x The pre-softmax logit.
+   * \param other_o The value vector.
+   */
+  __device__ __forceinline__ void merge(const vec_t<float, vec_size>& other_o,
+                                        float x) {
+    float m_prev = m, d_prev = d;
+    m = max(m_prev, x);
+    d = d * math::ptx_exp2(m_prev - m) + math::ptx_exp2(x - m);
+    if constexpr (norm_on_the_fly) {
+#pragma unroll
+      for (size_t i = 0; i < vec_size; ++i) {
+        o[i] = o[i] * (math::ptx_exp2(m_prev - m) * d_prev / d) +
+               other_o[i] * (math::ptx_exp2(x - m) / d);
+      }
+    } else {
+#pragma unroll
+      for (size_t i = 0; i < vec_size; ++i) {
+        o[i] = o[i] * math::ptx_exp2(m_prev - m) +
+               other_o[i] * math::ptx_exp2(x - m);
+      }
+    }
+  }
+
+  __device__ __forceinline__ void normalize() {
+    if constexpr (!norm_on_the_fly) {
+      // only normalize by d when not normalized on the fly
+#pragma unroll
+      for (size_t i = 0; i < vec_size; ++i) {
+        o[i] = __fdividef(o[i], d);
+      }
+    }
+  }
+};
+
+}  // namespace flashinfer
+
+#endif  // FLASHINFER_STATE_CUH_
diff --git a/server/punica_kernels/punica_kernels/flashinfer/utils.cuh b/server/punica_kernels/punica_kernels/flashinfer/utils.cuh
new file mode 100644
index 000000000..853e797d6
--- /dev/null
+++ b/server/punica_kernels/punica_kernels/flashinfer/utils.cuh
@@ -0,0 +1,120 @@
+/*
+ * Copyright (c) 2023 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FLASHINFER_UTILS_CUH_
+#define FLASHINFER_UTILS_CUH_
+#include <iostream>
+
+#include "layout.cuh"
+#include "rope.cuh"
+
+#define FLASHINFER_CUDA_CALL(func, ...) \
+  {                                     \
+    cudaError_t e = (func);             \
+    if (e != cudaSuccess) {             \
+      return e;                         \
+    }                                   \
+  }
+
+#define SWITCH_NUM_FRAGS_X(greater_than_64, NUM_FRAGS_X, ...) \
+  if (greater_than_64) {                                      \
+    constexpr size_t NUM_FRAGS_X = 2;                         \
+    __VA_ARGS__                                               \
+  } else {                                                    \
+    constexpr size_t NUM_FRAGS_X = 1;                         \
+    __VA_ARGS__                                               \
+  }
+
+#define SWITCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...)              \
+  if (group_size == 1) {                                                \
+    constexpr size_t GROUP_SIZE = 1;                                    \
+    __VA_ARGS__                                                         \
+  } else if (group_size == 8) {                                         \
+    constexpr size_t GROUP_SIZE = 8;                                    \
+    __VA_ARGS__                                                         \
+  } else {                                                              \
+    std::cerr << "Unsupported group_size: " << group_size << std::endl; \
+  }
+
+#define SWITCH_CAUSAL(causal, CAUSAL, ...) \
+  if (causal) {                            \
+    constexpr bool CAUSAL = true;          \
+    __VA_ARGS__                            \
+  } else {                                 \
+    constexpr bool CAUSAL = false;         \
+    __VA_ARGS__                            \
+  }
+
+#define SWITCH_LAYOUT(layout, LAYOUT, ...)                                 \
+  switch (layout) {                                                        \
+    case QKVLayout::kNHD: {                                                \
+      constexpr QKVLayout LAYOUT = QKVLayout::kNHD;                        \
+      __VA_ARGS__                                                          \
+      break;                                                               \
+    }                                                                      \
+    case QKVLayout::kHND: {                                                \
+      constexpr QKVLayout LAYOUT = QKVLayout::kHND;                        \
+      __VA_ARGS__                                                          \
+      break;                                                               \
+    }                                                                      \
+    default: {                                                             \
+      std::cerr << "Unsupported qkv_layout: " << int(layout) << std::endl; \
+      abort();                                                             \
+    }                                                                      \
+  }
+
+#define SWITCH_HEAD_DIM(head_dim, HEAD_DIM, ...)
\ + switch (head_dim) { \ + case 64: { \ + constexpr size_t HEAD_DIM = 64; \ + __VA_ARGS__ \ + break; \ + } \ + case 128: { \ + constexpr size_t HEAD_DIM = 128; \ + __VA_ARGS__ \ + break; \ + } \ + case 256: { \ + constexpr size_t HEAD_DIM = 256; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::cerr << "Unsupported head_dim: " << head_dim << std::endl; \ + abort(); \ + } \ + } + +#define SWITCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, ...) \ + switch (rotary_mode) { \ + case RotaryMode::kNone: { \ + constexpr RotaryMode ROTARY_MODE = RotaryMode::kNone; \ + __VA_ARGS__ \ + break; \ + } \ + case RotaryMode::kLlama: { \ + constexpr RotaryMode ROTARY_MODE = RotaryMode::kLlama; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::cerr << "Unsupported rotary_mode: " << int(rotary_mode) \ + << std::endl; \ + abort(); \ + } \ + } + +#endif // FLASHINFER_UTILS_CUH_ diff --git a/server/punica_kernels/punica_kernels/flashinfer/vec_dtypes.cuh b/server/punica_kernels/punica_kernels/flashinfer/vec_dtypes.cuh new file mode 100644 index 000000000..82ad2dd8a --- /dev/null +++ b/server/punica_kernels/punica_kernels/flashinfer/vec_dtypes.cuh @@ -0,0 +1,1420 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef VEC_DTYPES_CUH_ +#define VEC_DTYPES_CUH_ + +#include +#include +#ifdef FLASHINFER_ENABLE_FP8 +#include +#endif +#include + +#include + +namespace flashinfer { + +#define FLASHINFER_INLINE \ + inline __attribute__((always_inline)) __device__ __host__ + +template +struct vec_t { + FLASHINFER_INLINE float_t& operator[](size_t i); + FLASHINFER_INLINE const float_t& operator[](size_t i) const; + FLASHINFER_INLINE void fill(float_t val); + FLASHINFER_INLINE void load(const float_t* ptr); + FLASHINFER_INLINE void store(float_t* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src); + template + FLASHINFER_INLINE void cast_load(const T* ptr); + template + FLASHINFER_INLINE void cast_store(T* ptr) const; + FLASHINFER_INLINE static void memcpy(float_t* dst, const float_t* src); + FLASHINFER_INLINE float_t* ptr(); +}; + +template +FLASHINFER_INLINE void cast_from_impl(vec_t& dst, + const vec_t& src) { +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = tgt_float_t(src[i]); + } +} + +template +FLASHINFER_INLINE void cast_load_impl(vec_t& dst, + const src_float_t* src_ptr) { + if constexpr (std::is_same::value) { + dst.load(src_ptr); + } else { + vec_t tmp; + tmp.load(src_ptr); + dst.cast_from(tmp); + } +} + +template +FLASHINFER_INLINE void cast_store_impl( + tgt_float_t* dst_ptr, const vec_t& src) { + if constexpr (std::is_same::value) { + src.store(dst_ptr); + } else { + vec_t tmp; + tmp.cast_from(src); + tmp.store(dst_ptr); + } +} + +#ifdef FLASHINFER_ENABLE_FP8 +/******************* vec_t<__nv_fp8_e4m3> *******************/ + +// __nv_fp8_e4m3 x 1 +template <> +struct vec_t<__nv_fp8_e4m3, 1> { + __nv_fp8_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3& operator[](size_t i) { + return ((__nv_fp8_e4m3*)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3& operator[](size_t i) const { + return ((const __nv_fp8_e4m3*)(&data))[i]; + } + FLASHINFER_INLINE __nv_fp8_e4m3* ptr() { + return reinterpret_cast<__nv_fp8_e4m3*>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3* ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3* dst, + const __nv_fp8_e4m3* src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::fill(__nv_fp8_e4m3 val) { + data = val; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::load(const __nv_fp8_e4m3* ptr) { + data = *ptr; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::store( + __nv_fp8_e4m3* ptr) const { + *ptr = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::memcpy( + __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) { + *dst = *src; +} + +// __nv_fp8_e4m3 x 2 +template <> +struct vec_t<__nv_fp8_e4m3, 2> { + __nv_fp8x2_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3& operator[](size_t i) { + return ((__nv_fp8_e4m3*)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3& operator[](size_t i) const { + return ((const __nv_fp8_e4m3*)(&data))[i]; + } + FLASHINFER_INLINE __nv_fp8_e4m3* ptr() { + return reinterpret_cast<__nv_fp8_e4m3*>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3* ptr); + FLASHINFER_INLINE void 
store(__nv_fp8_e4m3* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3* dst, + const __nv_fp8_e4m3* src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::fill(__nv_fp8_e4m3 val) { + data.__x = + (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::load(const __nv_fp8_e4m3* ptr) { + data = *((__nv_fp8x2_e4m3*)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::store( + __nv_fp8_e4m3* ptr) const { + *((__nv_fp8x2_e4m3*)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::memcpy( + __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) { + *((__nv_fp8x2_e4m3*)dst) = *((__nv_fp8x2_e4m3*)src); +} + +// __nv_fp8_e4m3 x 4 + +template <> +struct vec_t<__nv_fp8_e4m3, 4> { + __nv_fp8x4_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3& operator[](size_t i) { + return ((__nv_fp8_e4m3*)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3& operator[](size_t i) const { + return ((const __nv_fp8_e4m3*)(&data))[i]; + } + FLASHINFER_INLINE __nv_fp8_e4m3* ptr() { + return reinterpret_cast<__nv_fp8_e4m3*>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3* ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3* dst, + const __nv_fp8_e4m3* src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::fill(__nv_fp8_e4m3 val) { + data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::load(const __nv_fp8_e4m3* ptr) { + data = *((__nv_fp8x4_e4m3*)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::store( + __nv_fp8_e4m3* ptr) const { + *((__nv_fp8x4_e4m3*)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy( + __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) { + *((__nv_fp8x4_e4m3*)dst) = *((__nv_fp8x4_e4m3*)src); +} + +// __nv_fp8_e4m3 x 8 + +template <> +struct vec_t<__nv_fp8_e4m3, 8> { + uint2 data; + + FLASHINFER_INLINE __nv_fp8_e4m3& operator[](size_t i) { + return ((__nv_fp8_e4m3*)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3& operator[](size_t i) const { + return ((const __nv_fp8_e4m3*)(&data))[i]; + } + FLASHINFER_INLINE __nv_fp8_e4m3* ptr() { + return reinterpret_cast<__nv_fp8_e4m3*>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3* ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3* 
dst, + const __nv_fp8_e4m3* src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::fill(__nv_fp8_e4m3 val) { + ((__nv_fp8x4_e4m3*)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3*)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::load(const __nv_fp8_e4m3* ptr) { + data = *((uint2*)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::store( + __nv_fp8_e4m3* ptr) const { + *((uint2*)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::memcpy( + __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) { + *((uint2*)dst) = *((uint2*)src); +} + +// __nv_fp8_e4m3 x 16 or more +template +struct vec_t<__nv_fp8_e4m3, vec_size> { + uint4 data[vec_size / 16]; + + FLASHINFER_INLINE __nv_fp8_e4m3& operator[](size_t i) { + return ((__nv_fp8_e4m3*)data)[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3& operator[](size_t i) const { + return ((const __nv_fp8_e4m3*)data)[i]; + } + FLASHINFER_INLINE __nv_fp8_e4m3* ptr() { + return reinterpret_cast<__nv_fp8_e4m3*>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((__nv_fp8x4_e4m3*)(&(data[i].x)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3*)(&(data[i].y)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3*)(&(data[i].z)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3*)(&(data[i].w)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + } + } + FLASHINFER_INLINE void load(const __nv_fp8_e4m3* ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + data[i] = ((uint4*)ptr)[i]; + } + } + FLASHINFER_INLINE void store(__nv_fp8_e4m3* ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4*)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3* dst, + const __nv_fp8_e4m3* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4*)dst)[i] = ((uint4*)src)[i]; + } + } +}; + +/******************* vec_t<__nv_fp8_e5m2> *******************/ + +// __nv_fp8_e5m2 x 1 +template <> +struct vec_t<__nv_fp8_e5m2, 1> { + __nv_fp8_e5m2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2& operator[](size_t i) { + return ((__nv_fp8_e5m2*)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2& operator[](size_t i) const { + return ((const __nv_fp8_e5m2*)(&data))[i]; + } + FLASHINFER_INLINE __nv_fp8_e5m2* ptr() { + return reinterpret_cast<__nv_fp8_e5m2*>(&data); + } + FLASHINFER_INLINE void 
fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2* ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2* dst, + const __nv_fp8_e5m2* src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::fill(__nv_fp8_e5m2 val) { + data = val; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::load(const __nv_fp8_e5m2* ptr) { + data = *ptr; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::store( + __nv_fp8_e5m2* ptr) const { + *ptr = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::memcpy( + __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) { + *dst = *src; +} + +// __nv_fp8_e5m2 x 2 +template <> +struct vec_t<__nv_fp8_e5m2, 2> { + __nv_fp8x2_e5m2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2& operator[](size_t i) { + return ((__nv_fp8_e5m2*)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2& operator[](size_t i) const { + return ((const __nv_fp8_e5m2*)(&data))[i]; + } + FLASHINFER_INLINE __nv_fp8_e5m2* ptr() { + return reinterpret_cast<__nv_fp8_e5m2*>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2* ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2* dst, + const __nv_fp8_e5m2* src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::fill(__nv_fp8_e5m2 val) { + data.__x = + (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::load(const __nv_fp8_e5m2* ptr) { + data = *((__nv_fp8x2_e5m2*)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::store( + __nv_fp8_e5m2* ptr) const { + *((__nv_fp8x2_e5m2*)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::memcpy( + __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) { + *((__nv_fp8x2_e5m2*)dst) = *((__nv_fp8x2_e5m2*)src); +} + +// __nv_fp8_e5m2 x 4 + +template <> +struct vec_t<__nv_fp8_e5m2, 4> { + __nv_fp8x4_e5m2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2& operator[](size_t i) { + return ((__nv_fp8_e5m2*)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2& operator[](size_t i) const { + return ((const __nv_fp8_e5m2*)(&data))[i]; + } + FLASHINFER_INLINE __nv_fp8_e5m2* ptr() { + return reinterpret_cast<__nv_fp8_e5m2*>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2* ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2* dst, + const __nv_fp8_e5m2* src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::fill(__nv_fp8_e5m2 val) { + 
data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::load(const __nv_fp8_e5m2* ptr) { + data = *((__nv_fp8x4_e5m2*)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::store( + __nv_fp8_e5m2* ptr) const { + *((__nv_fp8x4_e5m2*)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy( + __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) { + *((__nv_fp8x4_e5m2*)dst) = *((__nv_fp8x4_e5m2*)src); +} + +// __nv_fp8_e5m2 x 8 + +template <> +struct vec_t<__nv_fp8_e5m2, 8> { + uint2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2& operator[](size_t i) { + return ((__nv_fp8_e5m2*)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2& operator[](size_t i) const { + return ((const __nv_fp8_e5m2*)(&data))[i]; + } + FLASHINFER_INLINE __nv_fp8_e5m2* ptr() { + return reinterpret_cast<__nv_fp8_e5m2*>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2* ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2* dst, + const __nv_fp8_e5m2* src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::fill(__nv_fp8_e5m2 val) { + ((__nv_fp8x4_e5m2*)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2*)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::load(const __nv_fp8_e5m2* ptr) { + data = *((uint2*)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::store( + __nv_fp8_e5m2* ptr) const { + *((uint2*)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::memcpy( + __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) { + *((uint2*)dst) = *((uint2*)src); +} + +// __nv_fp8_e5m2 x 16 or more + +template +struct vec_t<__nv_fp8_e5m2, vec_size> { + uint4 data[vec_size / 16]; + + FLASHINFER_INLINE __nv_fp8_e5m2& operator[](size_t i) { + return ((__nv_fp8_e5m2*)data)[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2& operator[](size_t i) const { + return ((const __nv_fp8_e5m2*)data)[i]; + } + FLASHINFER_INLINE __nv_fp8_e5m2* ptr() { + return reinterpret_cast<__nv_fp8_e5m2*>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((__nv_fp8x4_e5m2*)(&(data[i].x)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2*)(&(data[i].y)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2*)(&(data[i].z)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + 
((__nv_fp8x4_e5m2*)(&(data[i].w)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + } + } + FLASHINFER_INLINE void load(const __nv_fp8_e5m2* ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + data[i] = ((uint4*)ptr)[i]; + } + } + FLASHINFER_INLINE void store(__nv_fp8_e5m2* ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4*)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2* dst, + const __nv_fp8_e5m2* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4*)dst)[i] = ((uint4*)src)[i]; + } + } +}; +#endif + +/******************* vec_t *******************/ + +// half x 1 +template <> +struct vec_t { + half data; + + FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; } + FLASHINFER_INLINE const half& operator[](size_t i) const { + return ((const half*)(&data))[i]; + } + FLASHINFER_INLINE half* ptr() { return reinterpret_cast(&data); } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half* ptr); + FLASHINFER_INLINE void store(half* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(half* dst, const half* src); +}; + +FLASHINFER_INLINE void vec_t::fill(half val) { data = val; } + +FLASHINFER_INLINE void vec_t::load(const half* ptr) { data = *ptr; } + +FLASHINFER_INLINE void vec_t::store(half* ptr) const { *ptr = data; } + +FLASHINFER_INLINE void vec_t::memcpy(half* dst, const half* src) { + *dst = *src; +} + +// half x 2 +template <> +struct vec_t { + half2 data; + + FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; } + FLASHINFER_INLINE const half& operator[](size_t i) const { + return ((const half*)(&data))[i]; + } + FLASHINFER_INLINE half* ptr() { return reinterpret_cast(&data); } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half* ptr); + FLASHINFER_INLINE void store(half* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(half* dst, const half* src); +}; + +FLASHINFER_INLINE void vec_t::fill(half val) { + data = make_half2(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const half* ptr) { + data = *((half2*)ptr); +} + +FLASHINFER_INLINE void vec_t::store(half* ptr) const { + *((half2*)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(half* dst, const half* src) { + *((half2*)dst) = *((half2*)src); +} + +// half x 4 + +template <> +struct vec_t { + uint2 data; + + FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; } + FLASHINFER_INLINE const half& 
operator[](size_t i) const { + return ((const half*)(&data))[i]; + } + FLASHINFER_INLINE half* ptr() { return reinterpret_cast(&data); } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half* ptr); + FLASHINFER_INLINE void store(half* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(half* dst, const half* src); +}; + +FLASHINFER_INLINE void vec_t::fill(half val) { + *(half2*)(&data.x) = make_half2(val, val); + *(half2*)(&data.y) = make_half2(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const half* ptr) { + data = *((uint2*)ptr); +} + +FLASHINFER_INLINE void vec_t::store(half* ptr) const { + *((uint2*)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(half* dst, const half* src) { + *((uint2*)dst) = *((uint2*)src); +} + +// half x 8 or more + +template +struct vec_t { + uint4 data[vec_size / 8]; + FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)data)[i]; } + FLASHINFER_INLINE const half& operator[](size_t i) const { + return ((const half*)data)[i]; + } + FLASHINFER_INLINE half* ptr() { return reinterpret_cast(&data); } + FLASHINFER_INLINE void fill(half val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + *(half2*)(&(data[i].x)) = make_half2(val, val); + *(half2*)(&(data[i].y)) = make_half2(val, val); + *(half2*)(&(data[i].z)) = make_half2(val, val); + *(half2*)(&(data[i].w)) = make_half2(val, val); + } + } + FLASHINFER_INLINE void load(const half* ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + data[i] = ((uint4*)ptr)[i]; + } + } + FLASHINFER_INLINE void store(half* ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4*)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(half* dst, const half* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4*)dst)[i] = ((uint4*)src)[i]; + } + } +}; + +/******************* vec_t *******************/ + +// nv_bfloat16 x 1 +template <> +struct vec_t { + nv_bfloat16 data; + FLASHINFER_INLINE nv_bfloat16& operator[](size_t i) { + return ((nv_bfloat16*)(&data))[i]; + } + FLASHINFER_INLINE const nv_bfloat16& operator[](size_t i) const { + return ((const nv_bfloat16*)(&data))[i]; + } + FLASHINFER_INLINE nv_bfloat16* ptr() { + return reinterpret_cast(&data); + } + FLASHINFER_INLINE void fill(nv_bfloat16 val); + FLASHINFER_INLINE void load(const nv_bfloat16* ptr); + FLASHINFER_INLINE void store(nv_bfloat16* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(nv_bfloat16* dst, + const nv_bfloat16* src); +}; + +FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { + data = val; +} + +FLASHINFER_INLINE void vec_t::load(const 
nv_bfloat16* ptr) { + data = *ptr; +} + +FLASHINFER_INLINE void vec_t::store(nv_bfloat16* ptr) const { + *ptr = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16* dst, + const nv_bfloat16* src) { + *dst = *src; +} + +// nv_bfloat16 x 2 +template <> +struct vec_t { + nv_bfloat162 data; + + FLASHINFER_INLINE nv_bfloat16& operator[](size_t i) { + return ((nv_bfloat16*)(&data))[i]; + } + FLASHINFER_INLINE const nv_bfloat16& operator[](size_t i) const { + return ((const nv_bfloat16*)(&data))[i]; + } + FLASHINFER_INLINE nv_bfloat16* ptr() { + return reinterpret_cast(&data); + } + FLASHINFER_INLINE void fill(nv_bfloat16 val); + FLASHINFER_INLINE void load(const nv_bfloat16* ptr); + FLASHINFER_INLINE void store(nv_bfloat16* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(nv_bfloat16* dst, + const nv_bfloat16* src); +}; + +FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { + data = make_bfloat162(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const nv_bfloat16* ptr) { + data = *((nv_bfloat162*)ptr); +} + +FLASHINFER_INLINE void vec_t::store(nv_bfloat16* ptr) const { + *((nv_bfloat162*)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16* dst, + const nv_bfloat16* src) { + *((nv_bfloat162*)dst) = *((nv_bfloat162*)src); +} + +// nv_bfloat16 x 4 + +template <> +struct vec_t { + uint2 data; + + FLASHINFER_INLINE nv_bfloat16& operator[](size_t i) { + return ((nv_bfloat16*)(&data))[i]; + } + FLASHINFER_INLINE const nv_bfloat16& operator[](size_t i) const { + return ((const nv_bfloat16*)(&data))[i]; + } + FLASHINFER_INLINE nv_bfloat16* ptr() { + return reinterpret_cast(&data); + } + FLASHINFER_INLINE void fill(nv_bfloat16 val); + FLASHINFER_INLINE void load(const nv_bfloat16* ptr); + FLASHINFER_INLINE void store(nv_bfloat16* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(nv_bfloat16* dst, + const nv_bfloat16* src); +}; + +FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { + *(nv_bfloat162*)(&data.x) = make_bfloat162(val, val); + *(nv_bfloat162*)(&data.y) = make_bfloat162(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const nv_bfloat16* ptr) { + data = *((uint2*)ptr); +} + +FLASHINFER_INLINE void vec_t::store(nv_bfloat16* ptr) const { + *((uint2*)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16* dst, + const nv_bfloat16* src) { + *((uint2*)dst) = *((uint2*)src); +} + +// nv_bfloat16 x 8 or more + +template +struct vec_t { + uint4 data[vec_size / 8]; + + FLASHINFER_INLINE nv_bfloat16& operator[](size_t i) { + return ((nv_bfloat16*)data)[i]; + } + FLASHINFER_INLINE const nv_bfloat16& operator[](size_t i) const { + return ((const nv_bfloat16*)data)[i]; + } + FLASHINFER_INLINE nv_bfloat16* ptr() { + return reinterpret_cast(&data); + } + FLASHINFER_INLINE void fill(nv_bfloat16 val) { +#pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + *(nv_bfloat162*)(&(data[i].x)) = make_bfloat162(val, val); + *(nv_bfloat162*)(&(data[i].y)) = make_bfloat162(val, val); + 
*(nv_bfloat162*)(&(data[i].z)) = make_bfloat162(val, val);
+      *(nv_bfloat162*)(&(data[i].w)) = make_bfloat162(val, val);
+    }
+  }
+  FLASHINFER_INLINE void load(const nv_bfloat16* ptr) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 8; ++i) {
+      data[i] = ((uint4*)ptr)[i];
+    }
+  }
+  FLASHINFER_INLINE void store(nv_bfloat16* ptr) const {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 8; ++i) {
+      ((uint4*)ptr)[i] = data[i];
+    }
+  }
+  template <typename T>
+  FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  FLASHINFER_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  FLASHINFER_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  FLASHINFER_INLINE static void memcpy(nv_bfloat16* dst,
+                                       const nv_bfloat16* src) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 8; ++i) {
+      ((uint4*)dst)[i] = ((uint4*)src)[i];
+    }
+  }
+};
+
+/******************* vec_t<float> *******************/
+
+// float x 1
+
+template <>
+struct vec_t<float, 1> {
+  float data;
+
+  FLASHINFER_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; }
+  FLASHINFER_INLINE const float& operator[](size_t i) const {
+    return ((const float*)(&data))[i];
+  }
+  FLASHINFER_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
+  FLASHINFER_INLINE void fill(float val);
+  FLASHINFER_INLINE void load(const float* ptr);
+  FLASHINFER_INLINE void store(float* ptr) const;
+  template <typename T>
+  FLASHINFER_INLINE void cast_from(const vec_t<T, 1>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  FLASHINFER_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  FLASHINFER_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  FLASHINFER_INLINE static void memcpy(float* dst, const float* src);
+};
+
+FLASHINFER_INLINE void vec_t<float, 1>::fill(float val) { data = val; }
+
+FLASHINFER_INLINE void vec_t<float, 1>::load(const float* ptr) { data = *ptr; }
+
+FLASHINFER_INLINE void vec_t<float, 1>::store(float* ptr) const { *ptr = data; }
+
+FLASHINFER_INLINE void vec_t<float, 1>::memcpy(float* dst, const float* src) {
+  *dst = *src;
+}
+
+// float x 2
+
+template <>
+struct vec_t<float, 2> {
+  float2 data;
+
+  FLASHINFER_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; }
+  FLASHINFER_INLINE const float& operator[](size_t i) const {
+    return ((const float*)(&data))[i];
+  }
+  FLASHINFER_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
+  FLASHINFER_INLINE void fill(float val);
+  FLASHINFER_INLINE void load(const float* ptr);
+  FLASHINFER_INLINE void store(float* ptr) const;
+  template <typename T>
+  FLASHINFER_INLINE void cast_from(const vec_t<T, 2>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  FLASHINFER_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  FLASHINFER_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  FLASHINFER_INLINE static void memcpy(float* dst, const float* src);
+};
+
+FLASHINFER_INLINE void vec_t<float, 2>::fill(float val) {
+  data = make_float2(val, val);
+}
+
+FLASHINFER_INLINE void vec_t<float, 2>::load(const float* ptr) {
+  data = *((float2*)ptr);
+}
+
+FLASHINFER_INLINE void vec_t<float, 2>::store(float* ptr) const {
+  *((float2*)ptr) = data;
+}
+
+FLASHINFER_INLINE void vec_t<float, 2>::memcpy(float* dst, const float* src) {
+  *((float2*)dst) = *((float2*)src);
+}
+
+// float x 4 or more
+template <size_t vec_size>
+struct vec_t<float, vec_size> {
+  float4 data[vec_size / 4];
+
+  FLASHINFER_INLINE float& operator[](size_t i) { return ((float*)(data))[i]; }
+  FLASHINFER_INLINE const float& operator[](size_t i) const {
+    return ((const
float*)(data))[i]; + } + FLASHINFER_INLINE float* ptr() { return reinterpret_cast(&data); } + FLASHINFER_INLINE void fill(float val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = make_float4(val, val, val, val); + } + } + FLASHINFER_INLINE void load(const float* ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = ((float4*)ptr)[i]; + } + } + FLASHINFER_INLINE void store(float* ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4*)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(float* dst, const float* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4*)dst)[i] = ((float4*)src)[i]; + } + } +}; + +/******************* vec_t type cast *******************/ + +template +FLASHINFER_INLINE void vec_cast(dst_t* dst, const src_t* src) { +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = src[i]; + } +} + +template +FLASHINFER_INLINE void vec_cast(float* dst, const half* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __half22float2(((half2*)src)[i]); + } +} + +template +FLASHINFER_INLINE void vec_cast(half* dst, const float* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]); + } +} + +template +FLASHINFER_INLINE void vec_cast(float* dst, + const nv_bfloat16* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __bfloat1622float2(((nv_bfloat162*)src)[i]); + } +} + +template +FLASHINFER_INLINE void vec_cast(nv_bfloat16* dst, + const float* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((nv_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]); + } +} + +template +FLASHINFER_INLINE void cast_from_impl(vec_t& dst, + const vec_t& src) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)(&dst.data))[i] = __half22float2(((half2*)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(vec_t& dst, + const vec_t& src) { + if constexpr (vec_size == 1) { + dst.data = half(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2*)(&dst.data))[i] = __float22half2_rn(((float2*)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(vec_t& dst, + const vec_t& src) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)(&dst.data))[i] = + __bfloat1622float2(((nv_bfloat162*)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(vec_t& dst, + const vec_t& src) { + if constexpr (vec_size == 1) { + dst.data = nv_bfloat16(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((nv_bfloat162*)(&dst.data))[i] = + __float22bfloat162_rn(((float2*)(&src.data))[i]); + } + } +} + +#ifdef FLASHINFER_ENABLE_FP8 + +template +FLASHINFER_INLINE void cast_from_impl( + vec_t& dst, const vec_t<__nv_fp8_e4m3, vec_size>& src) { + if constexpr (vec_size == 1) { + dst.data = 
float(src.data); + } else if constexpr (vec_size == 2) { + *(float2*)(&dst.data) = float2(*(__nv_fp8x2_e4m3*)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4*)(&dst.data))[i] = float4(((__nv_fp8x4_e4m3*)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl( + vec_t& dst, const vec_t<__nv_fp8_e4m3, vec_size>& src) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2*)(&dst.data))[i] = half2(((__nv_fp8x2_e4m3*)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(vec_t<__nv_fp8_e4m3, vec_size>& dst, + const vec_t& src) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e4m3(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e4m3*)(&dst.data) = __nv_fp8x2_e4m3(*(float2*)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((__nv_fp8x4_e4m3*)(&dst.data))[i] = + __nv_fp8x4_e4m3(((float4*)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(vec_t<__nv_fp8_e4m3, vec_size>& dst, + const vec_t& src) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e4m3(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e4m3*)(&dst.data) = __nv_fp8x2_e4m3(*(half2*)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + // NOTE(Zihao): need to double check if we properly handle flo and fhi + ((__nv_fp8x4_e4m3*)(&dst.data))[i] = __nv_fp8x4_e4m3( + ((half2*)(&src.data))[i * 2], ((half2*)(&src.data))[i * 2 + 1]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl( + vec_t& dst, const vec_t<__nv_fp8_e5m2, vec_size>& src) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else if constexpr (vec_size == 2) { + *(float2*)(&dst.data) = float2(*(__nv_fp8x2_e5m2*)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4*)(&dst.data))[i] = float4(((__nv_fp8x4_e5m2*)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl( + vec_t& dst, const vec_t<__nv_fp8_e5m2, vec_size>& src) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2*)(&dst.data))[i] = half2(((__nv_fp8x2_e5m2*)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(vec_t<__nv_fp8_e5m2, vec_size>& dst, + const vec_t& src) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e5m2(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e5m2*)(&dst.data) = __nv_fp8x2_e5m2(*(float2*)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((__nv_fp8x4_e5m2*)(&dst.data))[i] = + __nv_fp8x4_e5m2(((float4*)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(vec_t<__nv_fp8_e5m2, vec_size>& dst, + const vec_t& src) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e5m2(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e5m2*)(&dst.data) = __nv_fp8x2_e5m2(*(half2*)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + // NOTE(Zihao): need to double check if we properly handle flo and fhi + ((__nv_fp8x4_e5m2*)(&dst.data))[i] = __nv_fp8x4_e5m2( + ((half2*)(&src.data))[i * 2], ((half2*)(&src.data))[i * 2 + 1]); + } + } +} + +#endif // FLASHINFER_ENABLE_FP8 + +} // namespace flashinfer + +#endif // VEC_DTYPES_CUH_ \ 
No newline at end of file diff --git a/server/punica_kernels/punica_kernels/flashinfer_adapter/flashinfer_all.cu b/server/punica_kernels/punica_kernels/flashinfer_adapter/flashinfer_all.cu new file mode 100644 index 000000000..0ccfe7e0e --- /dev/null +++ b/server/punica_kernels/punica_kernels/flashinfer_adapter/flashinfer_all.cu @@ -0,0 +1,89 @@ +#include +#include +#include + +#include "../flashinfer/decode.cuh" +#include "../flashinfer/page.cuh" +#include "flashinfer_config.h" + +template +void FlashInferBatchDecodeKernel(T* o, T* q, T* kv_data, int32_t* kv_indptr, + int32_t* kv_indicies, + int32_t* last_page_offset, int head_dim, + int num_layers, int layer_idx, + int num_qo_heads, int num_kv_heads, + int page_size, int batch_size) { + flashinfer::paged_kv_t paged_kv( + num_layers, layer_idx, num_kv_heads, page_size, head_dim, batch_size, + kv_data, kv_indptr, kv_indicies, last_page_offset); + flashinfer::BatchDecodeWithPagedKVCache(q, paged_kv, o, nullptr, num_qo_heads, + flashinfer::RotaryMode::kLlama); +} + +template +void FlashInferInitKvKernel(T* kv_data, int32_t* kv_indptr, + int32_t* kv_indicies, int32_t* last_page_offset, + T* key, T* value, int32_t* seqlen_indptr, + int num_layers, int layer_idx, int num_kv_heads, + int page_size, int batch_size) { + flashinfer::paged_kv_t paged_kv( + num_layers, layer_idx, num_kv_heads, page_size, head_dim, batch_size, + kv_data, kv_indptr, kv_indicies, last_page_offset); + + constexpr size_t vec_size = + std::max(16 / sizeof(T), static_cast(head_dim / 32)); + constexpr size_t bdx = head_dim / vec_size; + constexpr size_t bdy = 128 / bdx; + dim3 nblks(paged_kv.batch_size * paged_kv.num_heads / bdy); + dim3 nthrs(bdx, bdy); + flashinfer::AppendPagedKVCachePrefillKernel + <<>>(paged_kv, key, value, seqlen_indptr); +} + +template +void FlashInferAppendKvKernel(T* kv_data, int32_t* kv_indptr, + int32_t* kv_indicies, int32_t* last_page_offset, + T* key, T* value, int num_layers, int layer_idx, + int num_kv_heads, int page_size, int batch_size) { + flashinfer::paged_kv_t paged_kv( + num_layers, layer_idx, num_kv_heads, page_size, head_dim, batch_size, + kv_data, kv_indptr, kv_indicies, last_page_offset); + + constexpr size_t vec_size = + std::max(16 / sizeof(T), static_cast(head_dim / 32)); + constexpr size_t bdx = head_dim / vec_size; + constexpr size_t bdy = 128 / bdx; + dim3 nblks(paged_kv.batch_size * paged_kv.num_heads / bdy); + dim3 nthrs(bdx, bdy); + flashinfer::AppendPagedKVCacheDecodeKernel + <<>>(paged_kv, key, value); +} + +#define INST_FlashInferBatchDecodeKernel(T) \ + template void FlashInferBatchDecodeKernel( \ + T * o, T * q, T * kv_data, int32_t * kv_indptr, int32_t * kv_indicies, \ + int32_t * last_page_offset, int head_dim, int num_layers, int layer_idx, \ + int num_qo_heads, int num_kv_heads, int page_size, int batch_size); + +INST_FlashInferBatchDecodeKernel(nv_half); +INST_FlashInferBatchDecodeKernel(nv_bfloat16); + +#define INST_FlashInferInitKvKernel(head_dim, T) \ + template void FlashInferInitKvKernel( \ + T * kv_data, int32_t * kv_indptr, int32_t * kv_indicies, \ + int32_t * last_page_offset, T * key, T * value, int32_t * seqlen_indptr, \ + int num_layers, int layer_idx, int num_kv_heads, int page_size, \ + int batch_size); + +FOR_FlashInferBatchDecode_D(INST_FlashInferInitKvKernel, nv_half); +FOR_FlashInferBatchDecode_D(INST_FlashInferInitKvKernel, nv_bfloat16); + +#define INST_FlashInferAppendKvKernel(head_dim, T) \ + template void FlashInferAppendKvKernel( \ + T * kv_data, int32_t * kv_indptr, int32_t * kv_indicies, \ + 
int32_t * last_page_offset, T * key, T * value, int num_layers, \ + int layer_idx, int num_kv_heads, int page_size, int batch_size); +FOR_FlashInferBatchDecode_D(INST_FlashInferAppendKvKernel, nv_half); +FOR_FlashInferBatchDecode_D(INST_FlashInferAppendKvKernel, nv_bfloat16); diff --git a/server/punica_kernels/punica_kernels/flashinfer_adapter/flashinfer_config.h b/server/punica_kernels/punica_kernels/flashinfer_adapter/flashinfer_config.h new file mode 100644 index 000000000..035f37374 --- /dev/null +++ b/server/punica_kernels/punica_kernels/flashinfer_adapter/flashinfer_config.h @@ -0,0 +1,30 @@ +#pragma once + +template +void FlashInferBatchDecodeKernel(T* o, T* q, T* kv_data, int32_t* kv_indptr, + int32_t* kv_indicies, + int32_t* last_page_offset, int head_dim, + int num_layers, int layer_idx, + int num_qo_heads, int num_kv_heads, + int page_size, int batch_size); + +template +void FlashInferInitKvKernel(T* kv_data, int32_t* kv_indptr, + int32_t* kv_indicies, int32_t* last_page_offset, + T* key, T* value, int32_t* seqlen_indptr, + int num_layers, int layer_idx, int num_kv_heads, + int page_size, int batch_size); + +template +void FlashInferAppendKvKernel(T* kv_data, int32_t* kv_indptr, + int32_t* kv_indicies, int32_t* last_page_offset, + T* key, T* value, int num_layers, int layer_idx, + int num_kv_heads, int page_size, int batch_size); + +// clang-format off + +#define FOR_FlashInferBatchDecode_D(f, ...) \ + f(64, __VA_ARGS__) \ + f(128, __VA_ARGS__) + +// clang-format on diff --git a/server/punica_kernels/punica_kernels/punica_ops.cc b/server/punica_kernels/punica_kernels/punica_ops.cc new file mode 100644 index 000000000..f86dec025 --- /dev/null +++ b/server/punica_kernels/punica_kernels/punica_ops.cc @@ -0,0 +1,403 @@ +#include +#include +#include + +#include + +#include "bgmv/bgmv_config.h" +#include "flashinfer_adapter/flashinfer_config.h" +#include "rms_norm/rms_norm.h" +#include "sgmv/sgmv.h" +#include "sgmv_flashinfer/sgmv_config.h" + +namespace { + +//====== utils ====== + +inline void check_shape(const torch::Tensor& a, const torch::Tensor& b, + const char* a_name, const char* b_name) { + TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", + a.dim(), " vs ", b.dim()); + for (int i = 0; i < a.dim(); ++i) { + TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, + ".size(", i, ")"); + } +} + +inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { + return (uint32_t(a) << 16) | uint32_t(b); +} + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +#define CHECK_DIM(d, x) \ + TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") + +#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b) + +#define CHECK_EQ(a, b) \ + TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) + +//====== dispatch pytorch dtype ====== + +#define _DISPATCH_SWITCH(scalar_type, ...) \ + [&]() -> bool { \ + switch (scalar_type) { \ + __VA_ARGS__ \ + default: \ + return false; \ + } \ + }() + +#define _DISPATCH_CASE(enum_type, c_type_, ...) \ + case enum_type: { \ + using c_type = c_type_; \ + return __VA_ARGS__(); \ + } + +#define _DISPATCH_CASES(...) \ + _DISPATCH_CASE(at::ScalarType::Half, nv_half, __VA_ARGS__) \ + _DISPATCH_CASE(at::ScalarType::BFloat16, nv_bfloat16, __VA_ARGS__) + +#define DISPATCH_TORCH_DTYPE(scalar_type, ...) 
\ + _DISPATCH_SWITCH(scalar_type, _DISPATCH_CASES(__VA_ARGS__)) + +//====== flashinfer ====== + +void batch_decode(torch::Tensor o, torch::Tensor q, torch::Tensor kv_data, + torch::Tensor kv_indptr, torch::Tensor kv_indicies, + torch::Tensor last_page_offset, int layer_idx) { + CHECK_INPUT(o); + CHECK_INPUT(q); + CHECK_INPUT(kv_data); + CHECK_INPUT(kv_indptr); + CHECK_INPUT(kv_indicies); + CHECK_INPUT(last_page_offset); + + CHECK_DIM(3, o); // [B, N, D] + CHECK_DIM(3, q); // [B, N, D] + CHECK_DIM(6, kv_data); // [None, L, 2, N, P, D] + CHECK_DIM(1, kv_indptr); // [B+1] + CHECK_DIM(1, kv_indicies); // [None] + CHECK_DIM(1, last_page_offset); // [B] + + int num_layers = static_cast(kv_data.size(1)); + int num_kv_heads = static_cast(kv_data.size(3)); + int page_size = static_cast(kv_data.size(4)); + int head_dim = static_cast(kv_data.size(5)); + int batch_size = static_cast(o.size(0)); + int num_qo_heads = static_cast(o.size(1)); + CHECK_SHAPE(o, q); + CHECK_EQ(kv_indptr.size(0), batch_size + 1); + CHECK_EQ(last_page_offset.size(0), batch_size); + + bool ok = DISPATCH_TORCH_DTYPE(q.scalar_type(), [&] { + FlashInferBatchDecodeKernel( + static_cast(o.data_ptr()), static_cast(q.data_ptr()), + static_cast(kv_data.data_ptr()), kv_indptr.data_ptr(), + kv_indicies.data_ptr(), last_page_offset.data_ptr(), + head_dim, num_layers, layer_idx, num_qo_heads, num_kv_heads, page_size, + batch_size); + return true; + }); + TORCH_CHECK(ok, "No suitable kernel.", " dtype=", q.scalar_type(), + " head_dim=", head_dim); + +#undef CASE +} + +void init_kv(torch::Tensor kv_data, torch::Tensor kv_indptr, + torch::Tensor kv_indicies, torch::Tensor last_page_offset, + torch::Tensor k, torch::Tensor v, torch::Tensor seqlen_indptr, + int layer_idx) { + CHECK_INPUT(kv_data); + CHECK_INPUT(kv_indptr); + CHECK_INPUT(kv_indicies); + CHECK_INPUT(last_page_offset); + CHECK_INPUT(k); + CHECK_INPUT(v); + CHECK_INPUT(seqlen_indptr); + + CHECK_DIM(6, kv_data); // [None, L, 2, N, P, D] + CHECK_DIM(1, kv_indptr); // [B+1] + CHECK_DIM(1, kv_indicies); // [None] + CHECK_DIM(1, last_page_offset); // [B] + CHECK_DIM(3, k); // [sum(seqlen_i), N, D] + CHECK_DIM(3, v); // [sum(seqlen_i), N, D] + CHECK_DIM(1, seqlen_indptr); // [B+1] + + int num_layers = static_cast(kv_data.size(1)); + int num_kv_heads = static_cast(kv_data.size(3)); + int page_size = static_cast(kv_data.size(4)); + int head_dim = static_cast(kv_data.size(5)); + int batch_size = static_cast(last_page_offset.size(0)); + CHECK_EQ(kv_indptr.size(0), batch_size + 1); + CHECK_EQ(seqlen_indptr.size(0), batch_size + 1); + +#define CASE(dim, _) \ + case dim: \ + FlashInferInitKvKernel( \ + static_cast(kv_data.data_ptr()), \ + kv_indptr.data_ptr(), kv_indicies.data_ptr(), \ + last_page_offset.data_ptr(), \ + static_cast(k.data_ptr()), \ + static_cast(v.data_ptr()), seqlen_indptr.data_ptr(), \ + num_layers, layer_idx, num_kv_heads, page_size, batch_size); \ + return true; + + bool ok = DISPATCH_TORCH_DTYPE(k.scalar_type(), [&] { + switch (head_dim) { + FOR_FlashInferBatchDecode_D(CASE); + default: + return false; + } + }); + TORCH_CHECK(ok, "No suitable kernel.", " dtype=", k.scalar_type(), + " head_dim=", head_dim); +#undef CASE +} + +void append_kv(torch::Tensor kv_data, torch::Tensor kv_indptr, + torch::Tensor kv_indicies, torch::Tensor last_page_offset, + torch::Tensor k, torch::Tensor v, int layer_idx) { + CHECK_INPUT(kv_data); + CHECK_INPUT(kv_indptr); + CHECK_INPUT(kv_indicies); + CHECK_INPUT(last_page_offset); + CHECK_INPUT(k); + CHECK_INPUT(v); + + CHECK_DIM(6, kv_data); // 
[None, L, 2, N, P, D] + CHECK_DIM(1, kv_indptr); // [B+1] + CHECK_DIM(1, kv_indicies); // [None] + CHECK_DIM(1, last_page_offset); // [B] + CHECK_DIM(3, k); // [B, N, D] + CHECK_DIM(3, v); // [B, N, D] + + int num_layers = static_cast(kv_data.size(1)); + int num_kv_heads = static_cast(kv_data.size(3)); + int page_size = static_cast(kv_data.size(4)); + int head_dim = static_cast(kv_data.size(5)); + int batch_size = static_cast(k.size(0)); + CHECK_EQ(kv_indptr.size(0), batch_size + 1); + CHECK_EQ(last_page_offset.size(0), batch_size); + CHECK_SHAPE(k, v); + +#define CASE(dim, _) \ + case dim: \ + FlashInferAppendKvKernel( \ + static_cast(kv_data.data_ptr()), \ + kv_indptr.data_ptr(), kv_indicies.data_ptr(), \ + last_page_offset.data_ptr(), \ + static_cast(k.data_ptr()), \ + static_cast(v.data_ptr()), num_layers, layer_idx, \ + num_kv_heads, page_size, batch_size); \ + return true; + + bool ok = DISPATCH_TORCH_DTYPE(k.scalar_type(), [&] { + switch (head_dim) { + FOR_FlashInferBatchDecode_D(CASE); + default: + return false; + } + }); + TORCH_CHECK(ok, "No suitable kernel.", " dtype=", k.scalar_type(), + " head_dim=", head_dim); +#undef CASE +} + +//====== bgmv ====== + +template +inline bool launch_bgmv_kernel(T* Y, const T* X, const T* W, + const int64_t* lora_indices, + uint16_t in_features, uint16_t out_features, + int64_t batch_size, int64_t num_layers, + int64_t layer_idx, float scale) { + switch (pack_u16(in_features, out_features)) { +#define CASE_ONESIDE(_T, feat_in, feat_out) \ + case pack_u16(feat_in, feat_out): \ + bgmv_kernel(Y, X, W, lora_indices, batch_size, \ + num_layers, layer_idx, scale); \ + break; +#define CASE(_T, narrow, wide) \ + CASE_ONESIDE(T, narrow, wide) \ + CASE_ONESIDE(T, wide, narrow) + + FOR_BGMV_WIDE_NARROW(CASE, _) +#undef CASE +#undef CASE_ONESIDE + default: + return false; + } + + return true; +} + +void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, float scale) { + CHECK_INPUT(y); + CHECK_INPUT(x); + CHECK_INPUT(w); + CHECK_INPUT(indicies); + + CHECK_DIM(2, y); + CHECK_DIM(2, x); + CHECK_DIM(4, w); + CHECK_DIM(1, indicies); + + int64_t B = x.size(0); + int64_t h_in = x.size(1); + int64_t h_out = y.size(1); + int64_t num_layers = w.size(1); + CHECK_EQ(w.size(3), h_in); + CHECK_EQ(w.size(2), h_out); + CHECK_EQ(indicies.size(0), x.size(0)); + CHECK_EQ(y.size(0), x.size(0)); + bool ok = false; + if (h_in < 65536 && h_out < 65536) { + switch (x.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, B, + num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, B, + num_layers, layer_idx, scale); + break; + default: + break; + } + } + TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, + " dtype=", x.scalar_type()); +} + +//====== sgmv ====== + +void dispatch_sgmv_cutlass(torch::Tensor y, torch::Tensor x, + torch::Tensor w_ptr, torch::Tensor s, + torch::Tensor tmp, int layer_idx) { + CHECK_INPUT(y); + CHECK_INPUT(x); + CHECK_INPUT(w_ptr); + CHECK_INPUT(s); + CHECK_INPUT(tmp); + + CHECK_DIM(2, y); + CHECK_DIM(2, x); + CHECK_DIM(1, w_ptr); + CHECK_DIM(1, s); + CHECK_DIM(1, tmp); + + int num_problems = s.size(0) - 1; + int d_in = x.size(1); + int d_out = y.size(1); + CHECK_EQ(tmp.size(0), 
static_cast<int64_t>(sgmv_tmp_size(num_problems)));
+  bool ok = DISPATCH_TORCH_DTYPE(x.scalar_type(), [&] {
+    return sgmv<c_type>((c_type*)y.data_ptr(), (c_type*)x.data_ptr(),
+                        (c_type**)w_ptr.data_ptr(), s.data_ptr<int32_t>(),
+                        tmp.data_ptr(), num_problems, d_in, d_out,
+                        layer_idx);
+  });
+  TORCH_CHECK(ok, "No suitable kernel.", " dtype=", x.scalar_type());
+}
+
+void dispatch_sgmv_shrink(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr,
+                          torch::Tensor s, torch::Tensor tmp, int layer_idx) {
+  CHECK_INPUT(y);
+  CHECK_INPUT(x);
+  CHECK_INPUT(w_ptr);
+  CHECK_INPUT(s);
+  CHECK_INPUT(tmp);
+
+  CHECK_DIM(2, y);
+  CHECK_DIM(2, x);
+  CHECK_DIM(1, w_ptr);
+  CHECK_DIM(1, s);
+  CHECK_DIM(1, tmp);
+
+  uint32_t num_problems = s.size(0) - 1;
+  uint32_t d_in = x.size(1);
+  uint32_t d_out = y.size(1);
+  CHECK_EQ(tmp.scalar_type(), at::ScalarType::Byte);
+  CHECK_EQ(tmp.size(0), 8 * 1024 * 1024);
+
+#define CASE(_T, D_OUT)                                      \
+  case D_OUT:                                                \
+    return sgmv_shrink<c_type, D_OUT>(                       \
+        (c_type*)y.data_ptr(), (c_type*)x.data_ptr(),        \
+        (c_type**)w_ptr.data_ptr(), s.data_ptr<int32_t>(),   \
+        tmp.data_ptr(), num_problems, d_in, layer_idx);
+
+  bool ok = DISPATCH_TORCH_DTYPE(x.scalar_type(), [&] {
+    switch (d_out) {
+      FOR_SGMV_NARROW(CASE, c_type);
+      default:
+        return false;
+    }
+  });
+
+#undef CASE
+  TORCH_CHECK(ok, "No suitable kernel.", " dtype=", x.scalar_type(),
+              " d_out=", d_out);
+}
+
+//====== rms_norm ======
+
+void dispatch_rms_norm(torch::Tensor output, torch::Tensor input,
+                       torch::Tensor weight, float epsilon) {
+  CHECK_INPUT(output);
+  CHECK_INPUT(input);
+  CHECK_INPUT(weight);
+
+  CHECK_DIM(2, input);
+  CHECK_DIM(1, weight);
+  CHECK_SHAPE(output, input);
+  CHECK_EQ(input.size(input.dim() - 1), weight.size(0));
+  CHECK_EQ(input.scalar_type(), weight.scalar_type());
+  CHECK_EQ(input.scalar_type(), output.scalar_type());
+
+  int rows = input.size(0);
+  int columns = input.size(1);
+
+  bool ok = DISPATCH_TORCH_DTYPE(input.scalar_type(), [&] {
+    return rms_norm<c_type>(static_cast<c_type*>(output.data_ptr()),
+                            static_cast<c_type*>(input.data_ptr()),
+                            static_cast<c_type*>(weight.data_ptr()), rows,
+                            columns, epsilon);
+  });
+
+  TORCH_CHECK(ok, "No suitable kernel.", " dtype=", input.scalar_type(),
+              " columns=", columns);
+}
+
+}  // namespace
+
+//====== pybind ======
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("batch_decode", &batch_decode, "");
+  m.def("init_kv", &init_kv, "");
+  m.def("append_kv", &append_kv, "");
+
+  m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
+
+  m.def("sgmv_cutlass", &dispatch_sgmv_cutlass, "");
+  m.def("sgmv_cutlass_tmp_size", &sgmv_tmp_size, "");
+  m.def("sgmv_shrink", &dispatch_sgmv_shrink, "");
+  m.def("rms_norm", &dispatch_rms_norm, "");
+}
diff --git a/server/punica_kernels/punica_kernels/rms_norm/rms_norm.h b/server/punica_kernels/punica_kernels/rms_norm/rms_norm.h
new file mode 100644
index 000000000..1231c90ea
--- /dev/null
+++ b/server/punica_kernels/punica_kernels/rms_norm/rms_norm.h
@@ -0,0 +1,4 @@
+template <typename T>
+bool rms_norm(T *__restrict__ output, const T *__restrict__ input,
+              const T *__restrict__ weight, int rows, int columns,
+              float epsilon);
diff --git a/server/punica_kernels/punica_kernels/rms_norm/rms_norm_cutlass.cu b/server/punica_kernels/punica_kernels/rms_norm/rms_norm_cutlass.cu
new file mode 100644
index 000000000..aab718fdb
--- /dev/null
+++ b/server/punica_kernels/punica_kernels/rms_norm/rms_norm_cutlass.cu
@@ -0,0 +1,189 @@
+// Adapted from cutlass
+// https://github.com/NVIDIA/cutlass/blob/7d8317a63e0a978a8dbb3c1fb7af4dbe4f286616/tools/util/include/cutlass/util/device_rmsnorm.h
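For orientation, the kernel in the file that follows normalizes each row by the root-mean-square of its elements and applies a per-column weight; it vectorizes the work over float4 loads of eight halves and reduces with warp shuffles. A minimal scalar sketch of that contract, illustrative only and not part of the vendored file (function name is ours):

#include <cmath>

// y[r, c] = x[r, c] * rsqrt(mean_c(x[r, c]^2) + epsilon) * w[c]
void rms_norm_reference(float* out, const float* in, const float* w,
                        int rows, int columns, float epsilon) {
  for (int r = 0; r < rows; ++r) {
    const float* x = in + r * columns;
    float sum_sq = 0.f;
    for (int c = 0; c < columns; ++c) sum_sq += x[c] * x[c];
    const float scale = 1.f / std::sqrt(sum_sq / columns + epsilon);
    for (int c = 0; c < columns; ++c) out[r * columns + c] = x[c] * scale * w[c];
  }
}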
+/****************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#include +#include +#include + +#include +#include + +template +__inline__ __device__ T warpReduceSum(T *val) { +#pragma unroll + for (int i = 0; i < NUM; i++) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] += __shfl_xor_sync(0xffffffff, val[i], mask, 32); + } + return (T)(0.0f); +} + +template +__inline__ __device__ T blockReduceSum(T *val) { + __shared__ T shared[NUM][33]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduceSum(val); + + if (lane == 0) { +#pragma unroll + for (int i = 0; i < NUM; i++) { + shared[i][wid] = val[i]; + } + } + + __syncthreads(); + + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) { + val[i] = is_mask ? 
shared[i][lane] : (T)(0.0f); + } + warpReduceSum(val); + return (T)0.0f; +} + +template +__global__ void rmsnorm_twoPassAlgo_e8(float4 *__restrict__ output, + const float4 *__restrict__ input, + const float4 *__restrict__ weight, int m, + int n, float epsilon) { + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean; + float local_sums[1] = {0.0f}; + const int n_8 = n / 8; + int offset = m_idx * n_8; + input += offset; + output += offset; + + for (int index = tid; index < n_8; index += bdimx) { + const float4 local_val = input[index]; + const half2 *h1 = (half2 *)&local_val.x; + const half2 *h2 = (half2 *)&local_val.y; + const half2 *h3 = (half2 *)&local_val.z; + const half2 *h4 = (half2 *)&local_val.w; + local_sums[0] += static_cast(h1->x) * static_cast(h1->x) + + static_cast(h1->y) * static_cast(h1->y) + + static_cast(h2->x) * static_cast(h2->x) + + static_cast(h2->y) * static_cast(h2->y) + + static_cast(h3->x) * static_cast(h3->x) + + static_cast(h3->y) * static_cast(h3->y) + + static_cast(h4->x) * static_cast(h4->x) + + static_cast(h4->y) * static_cast(h4->y); + } + + blockReduceSum(local_sums); + if (threadIdx.x == 0) { + s_mean = rsqrtf(local_sums[0] / n + epsilon); + } + __syncthreads(); + + for (int index = tid; index < n_8; index += bdimx) { + const float4 local_val = input[index]; + const float4 weight_val = weight[index]; + + const half2 *l1 = (half2 *)&local_val.x; + const half2 *l2 = (half2 *)&local_val.y; + const half2 *l3 = (half2 *)&local_val.z; + const half2 *l4 = (half2 *)&local_val.w; + + const half2 *g1 = (half2 *)&weight_val.x; + const half2 *g2 = (half2 *)&weight_val.y; + const half2 *g3 = (half2 *)&weight_val.z; + const half2 *g4 = (half2 *)&weight_val.w; + + float4 tmp; + half2 *h1 = (half2 *)&tmp.x; + half2 *h2 = (half2 *)&tmp.y; + half2 *h3 = (half2 *)&tmp.z; + half2 *h4 = (half2 *)&tmp.w; + + h1->x = static_cast(static_cast(l1->x) * s_mean * + static_cast(g1->x)); + h1->y = static_cast(static_cast(l1->y) * s_mean * + static_cast(g1->y)); + h2->x = static_cast(static_cast(l2->x) * s_mean * + static_cast(g2->x)); + h2->y = static_cast(static_cast(l2->y) * s_mean * + static_cast(g2->y)); + h3->x = static_cast(static_cast(l3->x) * s_mean * + static_cast(g3->x)); + h3->y = static_cast(static_cast(l3->y) * s_mean * + static_cast(g3->y)); + h4->x = static_cast(static_cast(l4->x) * s_mean * + static_cast(g4->x)); + h4->y = static_cast(static_cast(l4->y) * s_mean * + static_cast(g4->y)); + + output[index] = tmp; + } +} + +template +bool rms_norm(T *__restrict__ output, const T *__restrict__ input, + const T *__restrict__ weight, int rows, int columns, + float epsilon) { + if (columns % 8 != 0) { + return false; + } + + dim3 grid(rows); + dim3 block(std::min(1024, (columns / 8 + 31) / 32 * 32)); + + if (std::is_same::value) { + rmsnorm_twoPassAlgo_e8 + <<>>((float4 *)output, (float4 *)input, (float4 *)weight, + rows, columns, epsilon); + return true; + } else if (std::is_same::value) { + rmsnorm_twoPassAlgo_e8 + <<>>((float4 *)output, (float4 *)input, (float4 *)weight, + rows, columns, epsilon); + return true; + } + return false; +} + +template bool rms_norm(nv_half *__restrict__ output, + const nv_half *__restrict__ input, + const nv_half *__restrict__ weight, int rows, + int columns, float epsilon); +template bool rms_norm(nv_bfloat16 *__restrict__ output, + const nv_bfloat16 *__restrict__ input, + const nv_bfloat16 *__restrict__ weight, int rows, + int columns, float epsilon); diff --git 
a/server/punica_kernels/punica_kernels/sgmv/sgmv.h b/server/punica_kernels/punica_kernels/sgmv/sgmv.h new file mode 100644 index 000000000..ea0aafea7 --- /dev/null +++ b/server/punica_kernels/punica_kernels/sgmv/sgmv.h @@ -0,0 +1,5 @@ +template +bool sgmv(DType *y, DType *x, DType **w, int32_t *s, void *tmp_d, + int num_problems, int d_in, int d_out, int layer_idx); + +size_t sgmv_tmp_size(int num_problems); diff --git a/server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cu b/server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cu new file mode 100644 index 000000000..f6be221d1 --- /dev/null +++ b/server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cu @@ -0,0 +1,12 @@ +#include +#include + +#include "sgmv_cutlass.cuh" + +template bool sgmv(nv_half *y, nv_half *x, nv_half **w, int32_t *s, + void *tmp_d, int num_problems, int d_in, int d_out, + int layer_idx); + +template bool sgmv(nv_bfloat16 *y, nv_bfloat16 *x, nv_bfloat16 **w, + int32_t *s, void *tmp_d, int num_problems, + int d_in, int d_out, int layer_idx); diff --git a/server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cuh b/server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cuh new file mode 100644 index 000000000..521a0136e --- /dev/null +++ b/server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cuh @@ -0,0 +1,153 @@ +#pragma once +#include +#include +#include + +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +template +struct cutlass_dtype { + using type = T; +}; + +template <> +struct cutlass_dtype { + using type = cutlass::half_t; +}; + +template <> +struct cutlass_dtype { + using type = cutlass::bfloat16_t; +}; + +template +__global__ void precompute_sgmv_args(cutlass::gemm::GemmCoord *all_problems, + T **ptr_y, T **ptr_x, T **ptr_w, + int64_t *ld_y, int64_t *ld_x, + int64_t *ld_w, T *y, T *x, T **w, + int32_t *s, int d_in, int d_out, + int layer_idx) { + int i = blockIdx.x; + int m = s[i + 1] - s[i], k = d_in, n = d_out; + all_problems[i] = cutlass::gemm::GemmCoord(m, n, k); + ptr_w[i] = w[i] + layer_idx * d_in * d_out; + ptr_x[i] = x + s[i] * d_in; + ptr_y[i] = y + s[i] * d_out; + ld_x[i] = k; + ld_w[i] = n; + ld_y[i] = n; +} + +size_t sgmv_tmp_size(int num_problems) { + constexpr auto sz = sizeof(void *) * 3 + sizeof(int64_t) * 3 + + sizeof(cutlass::gemm::GemmCoord); + return sz * num_problems; +} + +template +inline T *alloc_from_buf(void **buf, int n) { + auto *p = (T *)*buf; + *buf = (void *)(p + n); + return p; +} + +template +bool sgmv(DType *y, DType *x, DType **w, int32_t *s, void *tmp_d, + int num_problems, int d_in, int d_out, int layer_idx) { + using cutlass_t = typename cutlass_dtype::type; + + auto ptr_Y = alloc_from_buf(&tmp_d, num_problems); + auto ptr_X = alloc_from_buf(&tmp_d, num_problems); + auto ptr_W = alloc_from_buf(&tmp_d, num_problems); + auto ld_Y = alloc_from_buf(&tmp_d, num_problems); + auto ld_X = alloc_from_buf(&tmp_d, num_problems); + auto ld_W = alloc_from_buf(&tmp_d, num_problems); + auto all_problems = + alloc_from_buf(&tmp_d, num_problems); + + precompute_sgmv_args<<>>( + all_problems, ptr_Y, ptr_X, ptr_W, ld_Y, ld_X, ld_W, (cutlass_t *)y, + (cutlass_t *)x, (cutlass_t **)w, s, d_in, d_out, layer_idx); + + using cutlass::epilogue::thread::LinearCombination; + using cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle; + if (d_in < d_out) { + // Expand + using GemmKernel = typename 
cutlass::gemm::kernel::DefaultGemmGrouped< + cutlass_t, // Element A + cutlass::layout::RowMajor, // Layout A + cutlass::ComplexTransform::kNone, // + 8, // Granularity A + cutlass_t, // Element B + cutlass::layout::RowMajor, // Layout B + cutlass::ComplexTransform::kNone, // + 8, // Granularity B + cutlass_t, // Element C&D + cutlass::layout::RowMajor, // Layout C&D + float, // Element Accumulator + cutlass::arch::OpClassTensorOp, // Operator Class Tag + cutlass::arch::Sm80, // Architecture + cutlass::gemm::GemmShape<32, 128, 16>, // Thread Block Shape + cutlass::gemm::GemmShape<32, 64, 16>, // Warp Shape + cutlass::gemm::GemmShape<16, 8, 8>, // Instruction Shape + LinearCombination, // Epilogue + GemmIdentityThreadblockSwizzle<1>, // Swizzling Operator + 2 // Stages + >::GemmKernel; + + using EpilogueOutputOp = typename GemmKernel::Epilogue::OutputOp; + typename EpilogueOutputOp::Params epilogue_op(1.0, 1.0); + + using GemmGrouped = cutlass::gemm::device::GemmGrouped; + typename GemmGrouped::Arguments args(all_problems, num_problems, 512, + epilogue_op, ptr_X, ptr_W, ptr_Y, + ptr_Y, ld_X, ld_W, ld_Y, ld_Y); + + GemmGrouped gemm; + if (gemm.initialize(args) != cutlass::Status::kSuccess) return false; + if (gemm.run() != cutlass::Status::kSuccess) return false; + } else { + // Shrink + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< + cutlass_t, // Element A + cutlass::layout::RowMajor, // Layout A + cutlass::ComplexTransform::kNone, // + 8, // Granularity A + cutlass_t, // Element B + cutlass::layout::RowMajor, // Layout B + cutlass::ComplexTransform::kNone, // + 8, // Granularity B + cutlass_t, // Element C&D + cutlass::layout::RowMajor, // Layout C&D + float, // Element Accumulator + cutlass::arch::OpClassTensorOp, // Operator Class Tag + cutlass::arch::Sm80, // Architecture + cutlass::gemm::GemmShape<16, 64, 64>, // Thread Block Shape + cutlass::gemm::GemmShape<16, 16, 64>, // Warp Shape + cutlass::gemm::GemmShape<16, 8, 16>, // Instruction Shape + LinearCombination, // Epilogue + GemmIdentityThreadblockSwizzle<2>, // Swizzling Operator + 2 // Stages + >::GemmKernel; + + using EpilogueOutputOp = typename GemmKernel::Epilogue::OutputOp; + typename EpilogueOutputOp::Params epilogue_op(1.0, 1.0); + + using GemmGrouped = cutlass::gemm::device::GemmGrouped; + typename GemmGrouped::Arguments args(all_problems, num_problems, 512, + epilogue_op, ptr_X, ptr_W, ptr_Y, + ptr_Y, ld_X, ld_W, ld_Y, ld_Y); + + GemmGrouped gemm; + if (gemm.initialize(args) != cutlass::Status::kSuccess) return false; + if (gemm.run() != cutlass::Status::kSuccess) return false; + } + return true; +} diff --git a/server/punica_kernels/punica_kernels/sgmv_flashinfer/cp_async.cuh b/server/punica_kernels/punica_kernels/sgmv_flashinfer/cp_async.cuh new file mode 100644 index 000000000..81c3396e6 --- /dev/null +++ b/server/punica_kernels/punica_kernels/sgmv_flashinfer/cp_async.cuh @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
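The scratch buffer handed to sgmv_cutlass is opaque on the Python side: sgmv_tmp_size above reserves, per problem, three device pointers (y, x, w), three leading dimensions, and one GemmCoord, and alloc_from_buf carves them out in that order before precompute_sgmv_args fills them on-device, one block per problem with m = s[i+1] - s[i]. A host-side sketch of the same arithmetic, with illustrative names and GemmCoord approximated as three ints:

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Mirrors sgmv_tmp_size: per problem, 3 pointers + 3 leading dimensions +
// one GemmCoord worth of grouped-GEMM metadata.
size_t sgmv_tmp_size_sketch(int num_problems) {
  const size_t per_problem =
      3 * sizeof(void*) + 3 * sizeof(int64_t) + 3 * sizeof(int);
  return per_problem * num_problems;
}

int main() {
  // Segment indptr s = {0, 3, 7, 12} describes 3 problems with m = 3, 4, 5;
  // each becomes one GemmCoord(m, d_out, d_in) in the grouped GEMM.
  printf("tmp bytes for 3 problems: %zu\n", sgmv_tmp_size_sketch(3));
  return 0;
}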
+ */ +#ifndef FLASHINFER_CP_ASYNC_CUH_ +#define FLASHINFER_CP_ASYNC_CUH_ + +#include + +namespace flashinfer { + +namespace cp_async { + +__device__ __forceinline__ void commit_group() { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && \ + (__CUDACC_VER_MAJOR__ >= 11) + asm volatile("cp.async.commit_group;\n" ::); +#endif +} + +template +__device__ __forceinline__ void wait_group() { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && \ + (__CUDACC_VER_MAJOR__ >= 11) + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +#endif +} + +template +__device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && \ + (__CUDACC_VER_MAJOR__ >= 11) + + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (prefetch) { + asm volatile( + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"( + smem_int_ptr), + "l"(gmem_ptr), "n"(16), "r"(16)); + } else { + asm volatile( + "cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "n"(16), "r"(16)); + } +#else + *((uint4*)smem_ptr) = *((uint4*)gmem_ptr); +#endif +} + +template +__device__ __forceinline__ void pred_load_128b(T* smem_ptr, const T* gmem_ptr, + bool predicate) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && \ + (__CUDACC_VER_MAJOR__ >= 11) + + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + int src_in_bytes = (predicate ? 8 * sizeof(T) : 0); + if constexpr (prefetch) { + asm volatile( + "{\n" + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" + "}\n" ::"r"(smem_int_ptr), "l"(gmem_ptr), "n"(16), "r"(src_in_bytes)); + } else { + asm volatile( + "{\n" + "cp.async.cg.shared.global [%0], [%1], %2, %3;\n" + "}\n" ::"r"(smem_int_ptr), "l"(gmem_ptr), "n"(16), "r"(src_in_bytes)); + } +#else + if (predicate) { + *((uint4*)smem_ptr) = *((uint4*)gmem_ptr); + } +#endif +} + +template +__device__ __forceinline__ void load(T* smem_ptr, const T* gmem_ptr) { + static_assert(num_bits == 128 || num_bits == 256, + "num_bits must be 128 or 256"); + if constexpr (num_bits == 128) { + load_128b(smem_ptr, gmem_ptr); + } else { + load_128b(smem_ptr, gmem_ptr); + load_128b(smem_ptr + 16 / sizeof(T), gmem_ptr + 16 / sizeof(T)); + } +} + +template +__device__ __forceinline__ void pred_load(T* smem_ptr, const T* gmem_ptr, + bool predicate) { + static_assert(num_bits == 128 || num_bits == 256, + "num_bits must be 128 or 256"); + if constexpr (num_bits == 128) { + pred_load_128b(smem_ptr, gmem_ptr, predicate); + } else { + pred_load_128b(smem_ptr, gmem_ptr, predicate); + pred_load_128b(smem_ptr + 16 / sizeof(T), + gmem_ptr + 16 / sizeof(T), predicate); + } +} + +} // namespace cp_async + +} // namespace flashinfer + +#endif // FLASHINFER_CP_ASYNC_CUH_ diff --git a/server/punica_kernels/punica_kernels/sgmv_flashinfer/mma.cuh b/server/punica_kernels/punica_kernels/sgmv_flashinfer/mma.cuh new file mode 100644 index 000000000..d926eb64c --- /dev/null +++ b/server/punica_kernels/punica_kernels/sgmv_flashinfer/mma.cuh @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
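The cp_async helpers above wrap Ampere's cp.async instructions: 16-byte loads are issued into shared memory, batched with commit_group(), and completed with wait_group<N>(), with a plain uint4 copy as the pre-SM80 fallback. A self-contained sketch of the staging pattern they serve, written against the synchronous fallback semantics rather than the helpers themselves (kernel name and shapes are illustrative):

__global__ void stage_tile(const uint4* __restrict__ gmem,
                           float* __restrict__ out, int n16) {
  extern __shared__ uint4 smem[];  // dynamic shared memory, one uint4 per thread
  int i = threadIdx.x;
  if (i < n16) {
    smem[i] = gmem[i];   // what load_128b lowers to when cp.async is unavailable
  }
  __syncthreads();       // stands in for wait_group<0>() plus a block barrier
  if (i < n16) {
    const float* f = reinterpret_cast<const float*>(&smem[i]);
    out[i] = f[0] + f[1] + f[2] + f[3];
  }
}
// launch sketch: stage_tile<<<1, 128, 128 * sizeof(uint4)>>>(d_gmem, d_out, 128);

On SM80+ the inline-PTX path overlaps these copies with compute, which is what the double-buffered prefetch loops in sgmv_flashinfer.cuh rely on.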
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_MMA_CUH_ +#define FLASHINFER_MMA_CUH_ + +#include +#include +#include + +#include + +namespace flashinfer { + +namespace mma { + +template +__device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t* R, T* smem_ptr) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) + : "r"(smem_int_ptr)); +} + +template +__device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t* R, + T* smem_ptr) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.trans.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) + : "r"(smem_int_ptr)); +} + +template +__device__ __forceinline__ void stmatrix_m8n8x4(uint32_t* R, T* smem_ptr) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \ + (__CUDACC_VER_MAJOR__ >= 11) + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "stmatrix.sync.aligned.m8n8.x4.shared.b16 [%0], {%1, %2, %3, %4};\n" + : "r"(smem_int_ptr), "r"(R[0]), "r"(R[1]), "r"(R[2]), "r"(R[3])); +#else + const uint32_t tx = threadIdx.x; + uint4 word; +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 4; ++reg_id) { + word.x = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4); + word.y = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4 + 1); + word.z = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4 + 2); + word.w = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4 + 3); + if (tx / 8 == reg_id) { + *(uint4*)smem_ptr = word; + } + } +#endif +} + +template +__device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32( + float* C, uint32_t* A, uint32_t* B) { + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), + "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); + } else { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), + 
"f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); + } +} + +} // namespace mma + +} // namespace flashinfer + +#endif // FLASHINFER_MMA_CUH_ diff --git a/server/punica_kernels/punica_kernels/sgmv_flashinfer/permuted_smem.cuh b/server/punica_kernels/punica_kernels/sgmv_flashinfer/permuted_smem.cuh new file mode 100644 index 000000000..ec2201d54 --- /dev/null +++ b/server/punica_kernels/punica_kernels/sgmv_flashinfer/permuted_smem.cuh @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_PERMUTED_SMEM_CUH_ +#define FLASHINFER_PERMUTED_SMEM_CUH_ + +#include +#include +#include + +#include + +#include "cp_async.cuh" +#include "mma.cuh" + +namespace flashinfer { + +// Each cell is 4 bytes. +using cell_t = uint4; + +template +constexpr __host__ __device__ __forceinline__ uint32_t cell_capacity() { + return sizeof(cell_t) / sizeof(T); +} + +struct smem_t { + cell_t* base; + uint32_t offset; + __device__ __forceinline__ smem_t() : base(nullptr) {} + template + __device__ __forceinline__ smem_t(T* base) : base((cell_t*)base) {} + + template + static __device__ __forceinline__ uint32_t get_permuted_offset(uint32_t i, + uint32_t j) { + return (i / 2) * stride * 2 + (j / 4) * 8 + (i % 2) * 4 + + ((j % 4) ^ ((i / 2) % 4)); + } + __device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t* R) { + cell_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x4(R, smem_ptr); + } + __device__ __forceinline__ void stmatrix_m8n8x4(uint32_t* R) { + cell_t* smem_ptr = base + offset; + mma::stmatrix_m8n8x4(R, smem_ptr); + } + __device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t* R) { + cell_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x4_trans(R, smem_ptr); + } + template + __device__ __forceinline__ void load_128b_async(const T* gptr, + bool predicate) { + cell_t* smem_ptr = base + offset; + cp_async::pred_load_128b( + smem_ptr, reinterpret_cast(gptr), predicate); + } + template + __device__ __forceinline__ void load_128b_async(const T* gptr) { + cell_t* smem_ptr = base + offset; + cp_async::load_128b(smem_ptr, reinterpret_cast(gptr)); + } + template + __device__ __forceinline__ void store_128b(T* gptr) { + *reinterpret_cast(gptr) = *(base + offset); + } +}; + +} // namespace flashinfer + +#endif // FLASHINFER_PERMUTED_SMEM_CUH_ diff --git a/server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_all.cu b/server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_all.cu new file mode 100644 index 000000000..7d45e5fcd --- /dev/null +++ b/server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_all.cu @@ -0,0 +1,67 @@ +#include +#include +#include + +#include + +#include "sgmv_config.h" +#include "sgmv_flashinfer.cuh" + +template +bool sgmv_shrink(T* y, T* x, T** w, int32_t* s, void* tmp, + uint32_t num_problems, uint32_t d_in, uint32_t layer_idx) { + static_assert(d_out % 16 == 0); + + constexpr uint32_t num_warps = 4; + constexpr uint32_t num_stages = 2; + constexpr uint32_t num_k_frags_per_stage = 8; + constexpr 
uint32_t num_blocks_n = d_out / 16; + uint32_t smem = num_stages * sizeof(T) * num_k_frags_per_stage * 16 * 16 * + (num_warps + num_blocks_n); + cudaStream_t stream = nullptr; + auto cooperative_kernel = + flashinfer::sgmv::sgmv_shrink; + auto kernel = flashinfer::sgmv::sgmv_shrink; + + uint32_t dev_id = 0; + int num_blocks_per_sm = 0; + int num_sm = 0; + bool use_cooperative = true; + cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id); + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, cooperative_kernel, num_warps * 32, smem); + + const uint32_t max_grid_size = num_sm * num_blocks_per_sm; + + uint32_t chunk_size = 256; + uint32_t num_chunks = (d_in + chunk_size - 1) / chunk_size; + if (num_chunks * num_problems > max_grid_size) { + use_cooperative = false; + chunk_size = d_in; + num_chunks = 1; + } + + dim3 nthrs(32, num_warps); + dim3 nblks(num_chunks, num_problems); + + void* args[] = {(void*)&y, (void*)&x, (void*)&w, + (void*)&s, (void*)&tmp, (void*)&num_problems, + (void*)&d_in, (void*)&layer_idx, (void*)&chunk_size}; + + cudaError_t status; + if (use_cooperative) { + status = cudaLaunchCooperativeKernel((void*)cooperative_kernel, nblks, + nthrs, args, smem, stream); + } else { + status = cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem, stream); + } + return status == cudaSuccess; +} + +#define INST(T, d_out) \ + template bool sgmv_shrink(T * y, T * x, T * *w, int32_t * s, \ + void* tmp, uint32_t num_problems, \ + uint32_t d_in, uint32_t layer_idx); + +FOR_SGMV_NARROW(INST, nv_half); +FOR_SGMV_NARROW(INST, nv_bfloat16); \ No newline at end of file diff --git a/server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_config.h b/server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_config.h new file mode 100644 index 000000000..90652c3aa --- /dev/null +++ b/server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_config.h @@ -0,0 +1,15 @@ +#pragma once +#include + +template +bool sgmv_shrink(T* y, T* x, T** w, int32_t* s, void* tmp, + uint32_t num_problems, uint32_t d_in, uint32_t layer_idx); + +// clang-format off + +#define FOR_SGMV_NARROW(f, T) \ + f(T, 16) \ + f(T, 32) \ + f(T, 64) + +// clang-format on diff --git a/server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_flashinfer.cuh b/server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_flashinfer.cuh new file mode 100644 index 000000000..43215322d --- /dev/null +++ b/server/punica_kernels/punica_kernels/sgmv_flashinfer/sgmv_flashinfer.cuh @@ -0,0 +1,312 @@ +#pragma once +#include + +#include "cp_async.cuh" +#include "mma.cuh" +#include "permuted_smem.cuh" +#include "vec_dtypes.cuh" + +namespace flashinfer { + +namespace sgmv { + +template +__global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp, uint32_t num_problems, + uint32_t d_in, uint32_t layer_idx, uint32_t chunk_size) { + auto block = cooperative_groups::this_thread_block(); + auto grid = cooperative_groups::this_grid(); + const uint32_t problem_id = blockIdx.y; + const uint32_t bx = blockIdx.x; + const uint32_t s_start = s[problem_id], s_end = s[problem_id + 1]; + constexpr uint32_t num_stages = 2; + constexpr uint32_t num_k_frags = 8; + constexpr uint32_t num_cells_k = (num_k_frags * 16) / cell_capacity(); + constexpr uint32_t num_blocks_n = d_out / 16; + const uint32_t num_chunks = gridDim.x; + const uint32_t chunk_start = chunk_size * bx; + const uint32_t num_iterations = (chunk_size + (num_k_frags * 16 - 1)) / (num_k_frags * 16); + constexpr uint32_t num_cells_n = (d_out < 32 ? 
32 : d_out) / cell_capacity(); + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + + extern __shared__ uint8_t smem[]; + + smem_t x_smem[2]{smem, smem + sizeof(T) * num_warps * 16 * 16 * num_k_frags}; + smem_t w_smem[2]{smem + sizeof(T) * 2 * num_warps * 16 * 16 * num_k_frags, + smem + sizeof(T) * 16 * 16 * num_k_frags * (2 * num_warps + num_blocks_n)}; + smem_t y_smem(smem); + + uint32_t x_frag[num_k_frags][4]; + uint32_t w_frag[num_k_frags][num_blocks_n][4]; + float y_frag[num_blocks_n][8]; + + for (uint32_t i = 0; i < (s_end - s_start + (num_warps * 16 - 1)) / (num_warps * 16); ++i) { + // init y_frag + if (bx == 0) { + if constexpr (num_blocks_n == 1) { + uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 2; + T* y_ptr = y + row_idx * d_out + (tx % 2) * cell_capacity(); + y_smem.offset = smem_t::get_permuted_offset(ty * 16 + tx / 2, tx % 2); + y_smem.load_128b_async(y_ptr, row_idx < s_end); + } else { + uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 4; + T* y_ptr = y + row_idx * d_out + (tx % 4) * cell_capacity(); + y_smem.offset = smem_t::get_permuted_offset(ty * 16 + tx / 4, tx % 4); +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { +#pragma unroll + for (uint32_t fno = 0; fno < num_blocks_n / 2; ++fno) { + y_smem.load_128b_async(y_ptr, row_idx < s_end); + y_ptr += 4 * cell_capacity(); + y_smem.offset += 8; + } + row_idx += 8; + y_ptr += 8 * d_out; + y_smem.offset += 8 * num_cells_n - 4 * num_blocks_n; + } + } + cp_async::commit_group(); + cp_async::wait_group<0>(); + block.sync(); + + y_smem.offset = smem_t::get_permuted_offset(ty * 16 + tx % 16, tx / 16); +#pragma unroll + for (uint32_t fn = 0; fn < num_blocks_n; ++fn) { + uint32_t tmp[4]; + y_smem.ldmatrix_m8n8x4(tmp); + vec_cast(y_frag[fn], (T*)tmp); + y_smem.offset = (y_smem.offset ^ 0x2) + (fn & 0x1) * 8; + } + } else { +#pragma unroll + for (uint32_t fn = 0; fn < num_blocks_n; ++fn) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + y_frag[fn][reg_id] = 0.f; + } + } + } + + // preload x_smem, w_smem +#pragma unroll + for (uint32_t iter = 0; iter < num_stages; ++iter) { + uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 4; + T* x_ptr = + x + row_idx * d_in + chunk_start + (2 * num_k_frags * iter + tx % 4) * cell_capacity(); + T* x_ptr_max = x + row_idx * d_in + min(d_in, chunk_start + chunk_size); + x_smem[iter].offset = smem_t::get_permuted_offset(ty * 16 + tx / 4, tx % 4); + // pre-load x_smem, w_smem +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { +#pragma unroll + for (uint32_t fko = 0; fko < num_k_frags / 2; ++fko) { + x_smem[iter].load_128b_async(x_ptr, row_idx < s_end && x_ptr < x_ptr_max); + x_ptr += 4 * cell_capacity(); + x_smem[iter].offset += 8; + } + row_idx += 8; + x_ptr += 8 * d_in - 2 * cell_capacity() * num_k_frags; + x_ptr_max += 8 * d_in; + x_smem[iter].offset += 8 * num_cells_k - 4 * num_k_frags; + } + row_idx -= 8; + + static_assert(num_k_frags % (num_warps * 2) == 0); + constexpr uint32_t num_fko_iters_per_warp = num_k_frags / (num_warps * 2); +#pragma unroll + for (uint32_t fn = 0; fn < num_blocks_n; ++fn) { + T* w_ptr = w[problem_id] + layer_idx * d_in * d_out + (fn * 16 + tx / 4) * d_in + + chunk_start + + (2 * num_k_frags * iter + ty * num_fko_iters_per_warp * 4 + tx % 4) * + cell_capacity(); + T* w_ptr_max = w[problem_id] + layer_idx * d_in * d_out + + min((fn * 16 + tx / 4 + 1) * d_in, + (fn * 16 + tx / 4) * d_in + chunk_start + chunk_size); + w_smem[iter].offset = smem_t::get_permuted_offset( + fn * 16 + tx / 4, ty * 
num_fko_iters_per_warp * 4 + tx % 4); +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { +#pragma unroll + for (uint32_t fko = 0; fko < num_fko_iters_per_warp; ++fko) { + w_smem[iter].load_128b_async(w_ptr, w_ptr < w_ptr_max); + w_ptr += 4 * cell_capacity(); + w_smem[iter].offset += 8; + } + w_ptr += 8 * d_in - 4 * cell_capacity() * num_fko_iters_per_warp; + w_ptr_max += 8 * d_in; + w_smem[iter].offset += 8 * num_cells_k - 8 * num_fko_iters_per_warp; + } + } + cp_async::commit_group(); + } + +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + const uint32_t stage_idx = iter % 2; + cp_async::wait_group<1>(); + block.sync(); + + x_smem[stage_idx].offset = + smem_t::get_permuted_offset(ty * 16 + tx % 16, tx / 16); +#pragma unroll + for (uint32_t fk = 0; fk < num_k_frags; ++fk) { + x_smem[stage_idx].ldmatrix_m8n8x4(x_frag[fk]); + x_smem[stage_idx].offset = (x_smem[stage_idx].offset ^ 0x2) + (fk & 0x1) * 8; + } + +#pragma unroll + for (uint32_t fn = 0; fn < num_blocks_n; ++fn) { + w_smem[stage_idx].offset = smem_t::get_permuted_offset( + fn * 16 + 8 * (tx / 16) + tx % 8, (tx % 16) / 8); +#pragma unroll + for (uint32_t fk = 0; fk < num_k_frags; ++fk) { + w_smem[stage_idx].ldmatrix_m8n8x4(w_frag[fk][fn]); + w_smem[stage_idx].offset = (w_smem[stage_idx].offset ^ 0x2) + (fk & 0x1) * 8; + } + w_smem[stage_idx].offset += 16 * num_cells_k - 4 * num_k_frags; + } + + // compute y_frag +#pragma unroll + for (uint32_t fk = 0; fk < num_k_frags; ++fk) { +#pragma unroll + for (uint32_t fn = 0; fn < num_blocks_n; ++fn) { + mma::mma_sync_m16n16k16_row_col_f16f16f32(y_frag[fn], x_frag[fk], w_frag[fk][fn]); + } + } + block.sync(); + + // load next stage + if (iter + num_stages < num_iterations) { + uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 4; + T* x_ptr = x + row_idx * d_in + chunk_start + + (2 * num_k_frags * (iter + num_stages) + tx % 4) * cell_capacity(); + T* x_ptr_max = x + row_idx * d_in + min(d_in, chunk_start + chunk_size); + x_smem[stage_idx].offset = + smem_t::get_permuted_offset(ty * 16 + tx / 4, tx % 4); + // pre-load x_smem, w_smem +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { +#pragma unroll + for (uint32_t fko = 0; fko < num_k_frags / 2; ++fko) { + x_smem[stage_idx].load_128b_async(x_ptr, row_idx < s_end && x_ptr < x_ptr_max); + x_ptr += 4 * cell_capacity(); + x_smem[stage_idx].offset += 8; + } + row_idx += 8; + x_ptr += 8 * d_in - 2 * cell_capacity() * num_k_frags; + x_ptr_max += 8 * d_in; + x_smem[stage_idx].offset += 8 * num_cells_k - 4 * num_k_frags; + } + row_idx -= 8; + + constexpr uint32_t num_fko_iters_per_warp = num_k_frags / (num_warps * 2); +#pragma unroll + for (uint32_t fn = 0; fn < num_blocks_n; ++fn) { + T* w_ptr = + w[problem_id] + layer_idx * d_in * d_out + (fn * 16 + tx / 4) * d_in + chunk_start + + (2 * num_k_frags * (iter + num_stages) + ty * num_fko_iters_per_warp * 4 + tx % 4) * + cell_capacity(); + T* w_ptr_max = w[problem_id] + layer_idx * d_in * d_out + + min((fn * 16 + tx / 4 + 1) * d_in, + (fn * 16 + tx / 4) * d_in + chunk_start + chunk_size); + w_smem[stage_idx].offset = smem_t::get_permuted_offset( + fn * 16 + tx / 4, ty * num_fko_iters_per_warp * 4 + tx % 4); +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { +#pragma unroll + for (uint32_t fko = 0; fko < num_fko_iters_per_warp; ++fko) { + w_smem[stage_idx].load_128b_async(w_ptr, w_ptr < w_ptr_max); + w_ptr += 4 * cell_capacity(); + w_smem[stage_idx].offset += 8; + } + w_ptr += 8 * d_in - 4 * cell_capacity() * num_fko_iters_per_warp; + w_ptr_max += 8 * d_in; + 
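// Editorial note (not in the vendored file): the adjustments around this point
// step w_ptr / w_ptr_max down by 8 rows of W and rewind the column advance made
// by the inner fko loop; the matching w_smem offset update performs the same
// walk inside the permuted (bank-conflict-avoiding) shared-memory tile.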
w_smem[stage_idx].offset += 8 * num_cells_k - 8 * num_fko_iters_per_warp; + } + } + } + cp_async::commit_group(); + } + cp_async::wait_group<0>(); + block.sync(); + + if constexpr (cooperative) { +#pragma unroll + for (uint32_t fn = 0; fn < num_blocks_n; ++fn) { + vec_t::memcpy( + tmp + (fn * grid.size() + (problem_id * num_chunks + bx) * block.num_threads() + + block.thread_rank()) * + 8, + y_frag[fn]); + } + grid.sync(); + +#pragma unroll + for (uint32_t fn = 0; fn < num_blocks_n; ++fn) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + y_frag[fn][reg_id] = 0.f; + } + for (uint32_t chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { + vec_t y_other; + y_other.load(tmp + (fn * grid.size() + + (problem_id * num_chunks + chunk_idx) * block.num_threads() + + block.thread_rank()) * + 8); +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + y_frag[fn][reg_id] += y_other[reg_id]; + } + } + } + } + + if (bx == 0) { + // store y_frag + y_smem.offset = smem_t::get_permuted_offset(ty * 16 + tx / 4, 0); +#pragma unroll + for (uint32_t fn = 0; fn < num_blocks_n; ++fn) { + vec_cast((T*)(y_smem.base + y_smem.offset) + (tx % 4) * 2, &y_frag[fn][0]); + vec_cast((T*)(y_smem.base + y_smem.offset + 8 * num_cells_n) + (tx % 4) * 2, + &y_frag[fn][2]); + vec_cast((T*)(y_smem.base + (y_smem.offset ^ 0x1)) + (tx % 4) * 2, + &y_frag[fn][4]); + vec_cast( + (T*)(y_smem.base + (y_smem.offset ^ 0x1) + 8 * num_cells_n) + (tx % 4) * 2, + &y_frag[fn][6]); + y_smem.offset = (y_smem.offset ^ 0x2) + (fn & 0x1) * 8; + } + + // store y + if constexpr (num_blocks_n == 1) { + uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 2; + T* y_ptr = y + row_idx * d_out + (tx % 2) * cell_capacity(); + y_smem.offset = smem_t::get_permuted_offset(ty * 16 + tx / 2, tx % 2); + if (row_idx < s_end) { + y_smem.store_128b(y_ptr); + } + } else { + uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 4; + T* y_ptr = y + row_idx * d_out + (tx % 4) * cell_capacity(); + y_smem.offset = smem_t::get_permuted_offset(ty * 16 + tx / 4, tx % 4); +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { +#pragma unroll + for (uint32_t fno = 0; fno < num_blocks_n / 2; ++fno) { + if (row_idx < s_end) { + y_smem.store_128b(y_ptr); + } + y_ptr += 4 * cell_capacity(); + y_smem.offset += 8; + } + row_idx += 8; + y_ptr += 8 * d_out; + y_smem.offset += 8 * num_cells_n - 4 * num_blocks_n; + } + } + } + } +} + +} // namespace sgmv +} // namespace flashinfer \ No newline at end of file diff --git a/server/punica_kernels/punica_kernels/sgmv_flashinfer/vec_dtypes.cuh b/server/punica_kernels/punica_kernels/sgmv_flashinfer/vec_dtypes.cuh new file mode 100644 index 000000000..82ad2dd8a --- /dev/null +++ b/server/punica_kernels/punica_kernels/sgmv_flashinfer/vec_dtypes.cuh @@ -0,0 +1,1420 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
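Before the vec_dtypes header, it may help to pin down what the kernel above computes. The indexing implies that w[p] points at the p-th adapter's weight, stored row-major as [d_out, d_in] per layer at offset layer_idx * d_in * d_out, that rows s[p]..s[p+1] of x belong to segment p, and that results accumulate into y. A plain-float reference inferred from that layout (name and signature are ours; the kernel itself runs on half/bfloat16 with a float accumulator):

#include <cstdint>

void sgmv_shrink_reference(float* y, const float* x, float* const* w,
                           const int32_t* s, int num_problems, int d_in,
                           int d_out, int layer_idx) {
  for (int p = 0; p < num_problems; ++p) {
    const float* W = w[p] + static_cast<int64_t>(layer_idx) * d_in * d_out;
    for (int r = s[p]; r < s[p + 1]; ++r) {   // rows of segment p
      for (int j = 0; j < d_out; ++j) {
        float acc = y[r * d_out + j];         // accumulate into existing y
        for (int k = 0; k < d_in; ++k) {
          acc += x[r * d_in + k] * W[j * d_in + k];
        }
        y[r * d_out + j] = acc;
      }
    }
  }
}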
+ */ +#ifndef VEC_DTYPES_CUH_ +#define VEC_DTYPES_CUH_ + +#include +#include +#ifdef FLASHINFER_ENABLE_FP8 +#include +#endif +#include + +#include + +namespace flashinfer { + +#define FLASHINFER_INLINE \ + inline __attribute__((always_inline)) __device__ __host__ + +template +struct vec_t { + FLASHINFER_INLINE float_t& operator[](size_t i); + FLASHINFER_INLINE const float_t& operator[](size_t i) const; + FLASHINFER_INLINE void fill(float_t val); + FLASHINFER_INLINE void load(const float_t* ptr); + FLASHINFER_INLINE void store(float_t* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src); + template + FLASHINFER_INLINE void cast_load(const T* ptr); + template + FLASHINFER_INLINE void cast_store(T* ptr) const; + FLASHINFER_INLINE static void memcpy(float_t* dst, const float_t* src); + FLASHINFER_INLINE float_t* ptr(); +}; + +template +FLASHINFER_INLINE void cast_from_impl(vec_t& dst, + const vec_t& src) { +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = tgt_float_t(src[i]); + } +} + +template +FLASHINFER_INLINE void cast_load_impl(vec_t& dst, + const src_float_t* src_ptr) { + if constexpr (std::is_same::value) { + dst.load(src_ptr); + } else { + vec_t tmp; + tmp.load(src_ptr); + dst.cast_from(tmp); + } +} + +template +FLASHINFER_INLINE void cast_store_impl( + tgt_float_t* dst_ptr, const vec_t& src) { + if constexpr (std::is_same::value) { + src.store(dst_ptr); + } else { + vec_t tmp; + tmp.cast_from(src); + tmp.store(dst_ptr); + } +} + +#ifdef FLASHINFER_ENABLE_FP8 +/******************* vec_t<__nv_fp8_e4m3> *******************/ + +// __nv_fp8_e4m3 x 1 +template <> +struct vec_t<__nv_fp8_e4m3, 1> { + __nv_fp8_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3& operator[](size_t i) { + return ((__nv_fp8_e4m3*)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3& operator[](size_t i) const { + return ((const __nv_fp8_e4m3*)(&data))[i]; + } + FLASHINFER_INLINE __nv_fp8_e4m3* ptr() { + return reinterpret_cast<__nv_fp8_e4m3*>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3* ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3* dst, + const __nv_fp8_e4m3* src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::fill(__nv_fp8_e4m3 val) { + data = val; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::load(const __nv_fp8_e4m3* ptr) { + data = *ptr; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::store( + __nv_fp8_e4m3* ptr) const { + *ptr = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::memcpy( + __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) { + *dst = *src; +} + +// __nv_fp8_e4m3 x 2 +template <> +struct vec_t<__nv_fp8_e4m3, 2> { + __nv_fp8x2_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3& operator[](size_t i) { + return ((__nv_fp8_e4m3*)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3& operator[](size_t i) const { + return ((const __nv_fp8_e4m3*)(&data))[i]; + } + FLASHINFER_INLINE __nv_fp8_e4m3* ptr() { + return reinterpret_cast<__nv_fp8_e4m3*>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3* ptr); + FLASHINFER_INLINE void 
store(__nv_fp8_e4m3* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3* dst, + const __nv_fp8_e4m3* src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::fill(__nv_fp8_e4m3 val) { + data.__x = + (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::load(const __nv_fp8_e4m3* ptr) { + data = *((__nv_fp8x2_e4m3*)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::store( + __nv_fp8_e4m3* ptr) const { + *((__nv_fp8x2_e4m3*)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::memcpy( + __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) { + *((__nv_fp8x2_e4m3*)dst) = *((__nv_fp8x2_e4m3*)src); +} + +// __nv_fp8_e4m3 x 4 + +template <> +struct vec_t<__nv_fp8_e4m3, 4> { + __nv_fp8x4_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3& operator[](size_t i) { + return ((__nv_fp8_e4m3*)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3& operator[](size_t i) const { + return ((const __nv_fp8_e4m3*)(&data))[i]; + } + FLASHINFER_INLINE __nv_fp8_e4m3* ptr() { + return reinterpret_cast<__nv_fp8_e4m3*>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3* ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3* dst, + const __nv_fp8_e4m3* src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::fill(__nv_fp8_e4m3 val) { + data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::load(const __nv_fp8_e4m3* ptr) { + data = *((__nv_fp8x4_e4m3*)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::store( + __nv_fp8_e4m3* ptr) const { + *((__nv_fp8x4_e4m3*)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy( + __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) { + *((__nv_fp8x4_e4m3*)dst) = *((__nv_fp8x4_e4m3*)src); +} + +// __nv_fp8_e4m3 x 8 + +template <> +struct vec_t<__nv_fp8_e4m3, 8> { + uint2 data; + + FLASHINFER_INLINE __nv_fp8_e4m3& operator[](size_t i) { + return ((__nv_fp8_e4m3*)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3& operator[](size_t i) const { + return ((const __nv_fp8_e4m3*)(&data))[i]; + } + FLASHINFER_INLINE __nv_fp8_e4m3* ptr() { + return reinterpret_cast<__nv_fp8_e4m3*>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3* ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3* 
dst, + const __nv_fp8_e4m3* src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::fill(__nv_fp8_e4m3 val) { + ((__nv_fp8x4_e4m3*)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3*)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::load(const __nv_fp8_e4m3* ptr) { + data = *((uint2*)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::store( + __nv_fp8_e4m3* ptr) const { + *((uint2*)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::memcpy( + __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) { + *((uint2*)dst) = *((uint2*)src); +} + +// __nv_fp8_e4m3 x 16 or more +template +struct vec_t<__nv_fp8_e4m3, vec_size> { + uint4 data[vec_size / 16]; + + FLASHINFER_INLINE __nv_fp8_e4m3& operator[](size_t i) { + return ((__nv_fp8_e4m3*)data)[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3& operator[](size_t i) const { + return ((const __nv_fp8_e4m3*)data)[i]; + } + FLASHINFER_INLINE __nv_fp8_e4m3* ptr() { + return reinterpret_cast<__nv_fp8_e4m3*>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((__nv_fp8x4_e4m3*)(&(data[i].x)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3*)(&(data[i].y)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3*)(&(data[i].z)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3*)(&(data[i].w)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + } + } + FLASHINFER_INLINE void load(const __nv_fp8_e4m3* ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + data[i] = ((uint4*)ptr)[i]; + } + } + FLASHINFER_INLINE void store(__nv_fp8_e4m3* ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4*)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3* dst, + const __nv_fp8_e4m3* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4*)dst)[i] = ((uint4*)src)[i]; + } + } +}; + +/******************* vec_t<__nv_fp8_e5m2> *******************/ + +// __nv_fp8_e5m2 x 1 +template <> +struct vec_t<__nv_fp8_e5m2, 1> { + __nv_fp8_e5m2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2& operator[](size_t i) { + return ((__nv_fp8_e5m2*)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2& operator[](size_t i) const { + return ((const __nv_fp8_e5m2*)(&data))[i]; + } + FLASHINFER_INLINE __nv_fp8_e5m2* ptr() { + return reinterpret_cast<__nv_fp8_e5m2*>(&data); + } + FLASHINFER_INLINE void 
fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2* ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2* dst, + const __nv_fp8_e5m2* src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::fill(__nv_fp8_e5m2 val) { + data = val; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::load(const __nv_fp8_e5m2* ptr) { + data = *ptr; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::store( + __nv_fp8_e5m2* ptr) const { + *ptr = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::memcpy( + __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) { + *dst = *src; +} + +// __nv_fp8_e5m2 x 2 +template <> +struct vec_t<__nv_fp8_e5m2, 2> { + __nv_fp8x2_e5m2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2& operator[](size_t i) { + return ((__nv_fp8_e5m2*)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2& operator[](size_t i) const { + return ((const __nv_fp8_e5m2*)(&data))[i]; + } + FLASHINFER_INLINE __nv_fp8_e5m2* ptr() { + return reinterpret_cast<__nv_fp8_e5m2*>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2* ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2* dst, + const __nv_fp8_e5m2* src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::fill(__nv_fp8_e5m2 val) { + data.__x = + (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::load(const __nv_fp8_e5m2* ptr) { + data = *((__nv_fp8x2_e5m2*)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::store( + __nv_fp8_e5m2* ptr) const { + *((__nv_fp8x2_e5m2*)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::memcpy( + __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) { + *((__nv_fp8x2_e5m2*)dst) = *((__nv_fp8x2_e5m2*)src); +} + +// __nv_fp8_e5m2 x 4 + +template <> +struct vec_t<__nv_fp8_e5m2, 4> { + __nv_fp8x4_e5m2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2& operator[](size_t i) { + return ((__nv_fp8_e5m2*)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2& operator[](size_t i) const { + return ((const __nv_fp8_e5m2*)(&data))[i]; + } + FLASHINFER_INLINE __nv_fp8_e5m2* ptr() { + return reinterpret_cast<__nv_fp8_e5m2*>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2* ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2* dst, + const __nv_fp8_e5m2* src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::fill(__nv_fp8_e5m2 val) { + 
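// Editorial note (not in the vendored file): fill() broadcasts the single
// 8-bit payload val.__x into every byte of the packed 4-lane storage word,
// so all four fp8 lanes end up holding the same value.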
data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::load(const __nv_fp8_e5m2* ptr) { + data = *((__nv_fp8x4_e5m2*)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::store( + __nv_fp8_e5m2* ptr) const { + *((__nv_fp8x4_e5m2*)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy( + __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) { + *((__nv_fp8x4_e5m2*)dst) = *((__nv_fp8x4_e5m2*)src); +} + +// __nv_fp8_e5m2 x 8 + +template <> +struct vec_t<__nv_fp8_e5m2, 8> { + uint2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2& operator[](size_t i) { + return ((__nv_fp8_e5m2*)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2& operator[](size_t i) const { + return ((const __nv_fp8_e5m2*)(&data))[i]; + } + FLASHINFER_INLINE __nv_fp8_e5m2* ptr() { + return reinterpret_cast<__nv_fp8_e5m2*>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2* ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2* dst, + const __nv_fp8_e5m2* src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::fill(__nv_fp8_e5m2 val) { + ((__nv_fp8x4_e5m2*)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2*)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::load(const __nv_fp8_e5m2* ptr) { + data = *((uint2*)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::store( + __nv_fp8_e5m2* ptr) const { + *((uint2*)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::memcpy( + __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) { + *((uint2*)dst) = *((uint2*)src); +} + +// __nv_fp8_e5m2 x 16 or more + +template +struct vec_t<__nv_fp8_e5m2, vec_size> { + uint4 data[vec_size / 16]; + + FLASHINFER_INLINE __nv_fp8_e5m2& operator[](size_t i) { + return ((__nv_fp8_e5m2*)data)[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2& operator[](size_t i) const { + return ((const __nv_fp8_e5m2*)data)[i]; + } + FLASHINFER_INLINE __nv_fp8_e5m2* ptr() { + return reinterpret_cast<__nv_fp8_e5m2*>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((__nv_fp8x4_e5m2*)(&(data[i].x)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2*)(&(data[i].y)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2*)(&(data[i].z)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + 
((__nv_fp8x4_e5m2*)(&(data[i].w)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + } + } + FLASHINFER_INLINE void load(const __nv_fp8_e5m2* ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + data[i] = ((uint4*)ptr)[i]; + } + } + FLASHINFER_INLINE void store(__nv_fp8_e5m2* ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4*)ptr)[i] = data[i]; + } + } + template <typename T> + FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size>& src) { + cast_from_impl(*this, src); + } + template <typename T> + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template <typename T> + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2* dst, + const __nv_fp8_e5m2* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4*)dst)[i] = ((uint4*)src)[i]; + } + } +}; +#endif + +/******************* vec_t<half> *******************/ + +// half x 1 +template <> +struct vec_t<half, 1> { + half data; + + FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; } + FLASHINFER_INLINE const half& operator[](size_t i) const { + return ((const half*)(&data))[i]; + } + FLASHINFER_INLINE half* ptr() { return reinterpret_cast<half*>(&data); } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half* ptr); + FLASHINFER_INLINE void store(half* ptr) const; + template <typename T> + FLASHINFER_INLINE void cast_from(const vec_t<T, 1>& src) { + cast_from_impl(*this, src); + } + template <typename T> + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template <typename T> + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(half* dst, const half* src); +}; + +FLASHINFER_INLINE void vec_t<half, 1>::fill(half val) { data = val; } + +FLASHINFER_INLINE void vec_t<half, 1>::load(const half* ptr) { data = *ptr; } + +FLASHINFER_INLINE void vec_t<half, 1>::store(half* ptr) const { *ptr = data; } + +FLASHINFER_INLINE void vec_t<half, 1>::memcpy(half* dst, const half* src) { + *dst = *src; +} + +// half x 2 +template <> +struct vec_t<half, 2> { + half2 data; + + FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; } + FLASHINFER_INLINE const half& operator[](size_t i) const { + return ((const half*)(&data))[i]; + } + FLASHINFER_INLINE half* ptr() { return reinterpret_cast<half*>(&data); } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half* ptr); + FLASHINFER_INLINE void store(half* ptr) const; + template <typename T> + FLASHINFER_INLINE void cast_from(const vec_t<T, 2>& src) { + cast_from_impl(*this, src); + } + template <typename T> + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template <typename T> + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(half* dst, const half* src); +}; + +FLASHINFER_INLINE void vec_t<half, 2>::fill(half val) { + data = make_half2(val, val); +} + +FLASHINFER_INLINE void vec_t<half, 2>::load(const half* ptr) { + data = *((half2*)ptr); +} + +FLASHINFER_INLINE void vec_t<half, 2>::store(half* ptr) const { + *((half2*)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<half, 2>::memcpy(half* dst, const half* src) { + *((half2*)dst) = *((half2*)src); +} + +// half x 4 + +template <> +struct vec_t<half, 4> { + uint2 data; + + FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; } + FLASHINFER_INLINE const half&
operator[](size_t i) const { + return ((const half*)(&data))[i]; + } + FLASHINFER_INLINE half* ptr() { return reinterpret_cast<half*>(&data); } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half* ptr); + FLASHINFER_INLINE void store(half* ptr) const; + template <typename T> + FLASHINFER_INLINE void cast_from(const vec_t<T, 4>& src) { + cast_from_impl(*this, src); + } + template <typename T> + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template <typename T> + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(half* dst, const half* src); +}; + +FLASHINFER_INLINE void vec_t<half, 4>::fill(half val) { + *(half2*)(&data.x) = make_half2(val, val); + *(half2*)(&data.y) = make_half2(val, val); +} + +FLASHINFER_INLINE void vec_t<half, 4>::load(const half* ptr) { + data = *((uint2*)ptr); +} + +FLASHINFER_INLINE void vec_t<half, 4>::store(half* ptr) const { + *((uint2*)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<half, 4>::memcpy(half* dst, const half* src) { + *((uint2*)dst) = *((uint2*)src); +} + +// half x 8 or more + +template <size_t vec_size> +struct vec_t<half, vec_size> { + uint4 data[vec_size / 8]; + FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)data)[i]; } + FLASHINFER_INLINE const half& operator[](size_t i) const { + return ((const half*)data)[i]; + } + FLASHINFER_INLINE half* ptr() { return reinterpret_cast<half*>(&data); } + FLASHINFER_INLINE void fill(half val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + *(half2*)(&(data[i].x)) = make_half2(val, val); + *(half2*)(&(data[i].y)) = make_half2(val, val); + *(half2*)(&(data[i].z)) = make_half2(val, val); + *(half2*)(&(data[i].w)) = make_half2(val, val); + } + } + FLASHINFER_INLINE void load(const half* ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + data[i] = ((uint4*)ptr)[i]; + } + } + FLASHINFER_INLINE void store(half* ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4*)ptr)[i] = data[i]; + } + } + template <typename T> + FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size>& src) { + cast_from_impl(*this, src); + } + template <typename T> + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template <typename T> + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(half* dst, const half* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4*)dst)[i] = ((uint4*)src)[i]; + } + } +}; + +/******************* vec_t<nv_bfloat16> *******************/ + +// nv_bfloat16 x 1 +template <> +struct vec_t<nv_bfloat16, 1> { + nv_bfloat16 data; + FLASHINFER_INLINE nv_bfloat16& operator[](size_t i) { + return ((nv_bfloat16*)(&data))[i]; + } + FLASHINFER_INLINE const nv_bfloat16& operator[](size_t i) const { + return ((const nv_bfloat16*)(&data))[i]; + } + FLASHINFER_INLINE nv_bfloat16* ptr() { + return reinterpret_cast<nv_bfloat16*>(&data); + } + FLASHINFER_INLINE void fill(nv_bfloat16 val); + FLASHINFER_INLINE void load(const nv_bfloat16* ptr); + FLASHINFER_INLINE void store(nv_bfloat16* ptr) const; + template <typename T> + FLASHINFER_INLINE void cast_from(const vec_t<T, 1>& src) { + cast_from_impl(*this, src); + } + template <typename T> + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template <typename T> + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(nv_bfloat16* dst, + const nv_bfloat16* src); +}; + +FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::fill(nv_bfloat16 val) { + data = val; +} + +FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::load(const
nv_bfloat16* ptr) { + data = *ptr; +} + +FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::store(nv_bfloat16* ptr) const { + *ptr = data; +} + +FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::memcpy(nv_bfloat16* dst, + const nv_bfloat16* src) { + *dst = *src; +} + +// nv_bfloat16 x 2 +template <> +struct vec_t<nv_bfloat16, 2> { + nv_bfloat162 data; + + FLASHINFER_INLINE nv_bfloat16& operator[](size_t i) { + return ((nv_bfloat16*)(&data))[i]; + } + FLASHINFER_INLINE const nv_bfloat16& operator[](size_t i) const { + return ((const nv_bfloat16*)(&data))[i]; + } + FLASHINFER_INLINE nv_bfloat16* ptr() { + return reinterpret_cast<nv_bfloat16*>(&data); + } + FLASHINFER_INLINE void fill(nv_bfloat16 val); + FLASHINFER_INLINE void load(const nv_bfloat16* ptr); + FLASHINFER_INLINE void store(nv_bfloat16* ptr) const; + template <typename T> + FLASHINFER_INLINE void cast_from(const vec_t<T, 2>& src) { + cast_from_impl(*this, src); + } + template <typename T> + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template <typename T> + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(nv_bfloat16* dst, + const nv_bfloat16* src); +}; + +FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::fill(nv_bfloat16 val) { + data = make_bfloat162(val, val); +} + +FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::load(const nv_bfloat16* ptr) { + data = *((nv_bfloat162*)ptr); +} + +FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::store(nv_bfloat16* ptr) const { + *((nv_bfloat162*)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::memcpy(nv_bfloat16* dst, + const nv_bfloat16* src) { + *((nv_bfloat162*)dst) = *((nv_bfloat162*)src); +} + +// nv_bfloat16 x 4 + +template <> +struct vec_t<nv_bfloat16, 4> { + uint2 data; + + FLASHINFER_INLINE nv_bfloat16& operator[](size_t i) { + return ((nv_bfloat16*)(&data))[i]; + } + FLASHINFER_INLINE const nv_bfloat16& operator[](size_t i) const { + return ((const nv_bfloat16*)(&data))[i]; + } + FLASHINFER_INLINE nv_bfloat16* ptr() { + return reinterpret_cast<nv_bfloat16*>(&data); + } + FLASHINFER_INLINE void fill(nv_bfloat16 val); + FLASHINFER_INLINE void load(const nv_bfloat16* ptr); + FLASHINFER_INLINE void store(nv_bfloat16* ptr) const; + template <typename T> + FLASHINFER_INLINE void cast_from(const vec_t<T, 4>& src) { + cast_from_impl(*this, src); + } + template <typename T> + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template <typename T> + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(nv_bfloat16* dst, + const nv_bfloat16* src); +}; + +FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::fill(nv_bfloat16 val) { + *(nv_bfloat162*)(&data.x) = make_bfloat162(val, val); + *(nv_bfloat162*)(&data.y) = make_bfloat162(val, val); +} + +FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::load(const nv_bfloat16* ptr) { + data = *((uint2*)ptr); +} + +FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::store(nv_bfloat16* ptr) const { + *((uint2*)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::memcpy(nv_bfloat16* dst, + const nv_bfloat16* src) { + *((uint2*)dst) = *((uint2*)src); +} + +// nv_bfloat16 x 8 or more + +template <size_t vec_size> +struct vec_t<nv_bfloat16, vec_size> { + uint4 data[vec_size / 8]; + + FLASHINFER_INLINE nv_bfloat16& operator[](size_t i) { + return ((nv_bfloat16*)data)[i]; + } + FLASHINFER_INLINE const nv_bfloat16& operator[](size_t i) const { + return ((const nv_bfloat16*)data)[i]; + } + FLASHINFER_INLINE nv_bfloat16* ptr() { + return reinterpret_cast<nv_bfloat16*>(&data); + } + FLASHINFER_INLINE void fill(nv_bfloat16 val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + *(nv_bfloat162*)(&(data[i].x)) = make_bfloat162(val, val); + *(nv_bfloat162*)(&(data[i].y)) = make_bfloat162(val, val); +
*(nv_bfloat162*)(&(data[i].z)) = make_bfloat162(val, val); + *(nv_bfloat162*)(&(data[i].w)) = make_bfloat162(val, val); + } + } + FLASHINFER_INLINE void load(const nv_bfloat16* ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + data[i] = ((uint4*)ptr)[i]; + } + } + FLASHINFER_INLINE void store(nv_bfloat16* ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4*)ptr)[i] = data[i]; + } + } + template <typename T> + FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size>& src) { + cast_from_impl(*this, src); + } + template <typename T> + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template <typename T> + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(nv_bfloat16* dst, + const nv_bfloat16* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4*)dst)[i] = ((uint4*)src)[i]; + } + } +}; + +/******************* vec_t<float> *******************/ + +// float x 1 + +template <> +struct vec_t<float, 1> { + float data; + + FLASHINFER_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; } + FLASHINFER_INLINE const float& operator[](size_t i) const { + return ((const float*)(&data))[i]; + } + FLASHINFER_INLINE float* ptr() { return reinterpret_cast<float*>(&data); } + FLASHINFER_INLINE void fill(float val); + FLASHINFER_INLINE void load(const float* ptr); + FLASHINFER_INLINE void store(float* ptr) const; + template <typename T> + FLASHINFER_INLINE void cast_from(const vec_t<T, 1>& src) { + cast_from_impl(*this, src); + } + template <typename T> + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template <typename T> + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(float* dst, const float* src); +}; + +FLASHINFER_INLINE void vec_t<float, 1>::fill(float val) { data = val; } + +FLASHINFER_INLINE void vec_t<float, 1>::load(const float* ptr) { data = *ptr; } + +FLASHINFER_INLINE void vec_t<float, 1>::store(float* ptr) const { *ptr = data; } + +FLASHINFER_INLINE void vec_t<float, 1>::memcpy(float* dst, const float* src) { + *dst = *src; +} + +// float x 2 + +template <> +struct vec_t<float, 2> { + float2 data; + + FLASHINFER_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; } + FLASHINFER_INLINE const float& operator[](size_t i) const { + return ((const float*)(&data))[i]; + } + FLASHINFER_INLINE float* ptr() { return reinterpret_cast<float*>(&data); } + FLASHINFER_INLINE void fill(float val); + FLASHINFER_INLINE void load(const float* ptr); + FLASHINFER_INLINE void store(float* ptr) const; + template <typename T> + FLASHINFER_INLINE void cast_from(const vec_t<T, 2>& src) { + cast_from_impl(*this, src); + } + template <typename T> + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template <typename T> + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(float* dst, const float* src); +}; + +FLASHINFER_INLINE void vec_t<float, 2>::fill(float val) { + data = make_float2(val, val); +} + +FLASHINFER_INLINE void vec_t<float, 2>::load(const float* ptr) { + data = *((float2*)ptr); +} + +FLASHINFER_INLINE void vec_t<float, 2>::store(float* ptr) const { + *((float2*)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<float, 2>::memcpy(float* dst, const float* src) { + *((float2*)dst) = *((float2*)src); +} + +// float x 4 or more +template <size_t vec_size> +struct vec_t<float, vec_size> { + float4 data[vec_size / 4]; + + FLASHINFER_INLINE float& operator[](size_t i) { return ((float*)(data))[i]; } + FLASHINFER_INLINE const float& operator[](size_t i) const { + return ((const
float*)(data))[i]; + } + FLASHINFER_INLINE float* ptr() { return reinterpret_cast<float*>(&data); } + FLASHINFER_INLINE void fill(float val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = make_float4(val, val, val, val); + } + } + FLASHINFER_INLINE void load(const float* ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = ((float4*)ptr)[i]; + } + } + FLASHINFER_INLINE void store(float* ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4*)ptr)[i] = data[i]; + } + } + template <typename T> + FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size>& src) { + cast_from_impl(*this, src); + } + template <typename T> + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template <typename T> + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(float* dst, const float* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4*)dst)[i] = ((float4*)src)[i]; + } + } +}; + +/******************* vec_t type cast *******************/ + +template <typename dst_t, typename src_t, size_t vec_size> +FLASHINFER_INLINE void vec_cast(dst_t* dst, const src_t* src) { +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = src[i]; + } +} + +template <size_t vec_size> +FLASHINFER_INLINE void vec_cast(float* dst, const half* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __half22float2(((half2*)src)[i]); + } +} + +template <size_t vec_size> +FLASHINFER_INLINE void vec_cast(half* dst, const float* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]); + } +} + +template <size_t vec_size> +FLASHINFER_INLINE void vec_cast(float* dst, + const nv_bfloat16* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __bfloat1622float2(((nv_bfloat162*)src)[i]); + } +} + +template <size_t vec_size> +FLASHINFER_INLINE void vec_cast(nv_bfloat16* dst, + const float* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((nv_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]); + } +} + +template <size_t vec_size> +FLASHINFER_INLINE void cast_from_impl(vec_t<float, vec_size>& dst, + const vec_t<half, vec_size>& src) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)(&dst.data))[i] = __half22float2(((half2*)(&src.data))[i]); + } + } +} + +template <size_t vec_size> +FLASHINFER_INLINE void cast_from_impl(vec_t<half, vec_size>& dst, + const vec_t<float, vec_size>& src) { + if constexpr (vec_size == 1) { + dst.data = half(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2*)(&dst.data))[i] = __float22half2_rn(((float2*)(&src.data))[i]); + } + } +} + +template <size_t vec_size> +FLASHINFER_INLINE void cast_from_impl(vec_t<float, vec_size>& dst, + const vec_t<nv_bfloat16, vec_size>& src) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)(&dst.data))[i] = + __bfloat1622float2(((nv_bfloat162*)(&src.data))[i]); + } + } +} + +template <size_t vec_size> +FLASHINFER_INLINE void cast_from_impl(vec_t<nv_bfloat16, vec_size>& dst, + const vec_t<float, vec_size>& src) { + if constexpr (vec_size == 1) { + dst.data = nv_bfloat16(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((nv_bfloat162*)(&dst.data))[i] = + __float22bfloat162_rn(((float2*)(&src.data))[i]); + } + } +} + +#ifdef FLASHINFER_ENABLE_FP8 + +template <size_t vec_size> +FLASHINFER_INLINE void cast_from_impl( + vec_t<float, vec_size>& dst, const vec_t<__nv_fp8_e4m3, vec_size>& src) { + if constexpr (vec_size == 1) { + dst.data =
float(src.data); + } else if constexpr (vec_size == 2) { + *(float2*)(&dst.data) = float2(*(__nv_fp8x2_e4m3*)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4*)(&dst.data))[i] = float4(((__nv_fp8x4_e4m3*)(&src.data))[i]); + } + } +} + +template <size_t vec_size> +FLASHINFER_INLINE void cast_from_impl( + vec_t<half, vec_size>& dst, const vec_t<__nv_fp8_e4m3, vec_size>& src) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2*)(&dst.data))[i] = half2(((__nv_fp8x2_e4m3*)(&src.data))[i]); + } + } +} + +template <size_t vec_size> +FLASHINFER_INLINE void cast_from_impl(vec_t<__nv_fp8_e4m3, vec_size>& dst, + const vec_t<float, vec_size>& src) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e4m3(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e4m3*)(&dst.data) = __nv_fp8x2_e4m3(*(float2*)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((__nv_fp8x4_e4m3*)(&dst.data))[i] = + __nv_fp8x4_e4m3(((float4*)(&src.data))[i]); + } + } +} + +template <size_t vec_size> +FLASHINFER_INLINE void cast_from_impl(vec_t<__nv_fp8_e4m3, vec_size>& dst, + const vec_t<half, vec_size>& src) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e4m3(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e4m3*)(&dst.data) = __nv_fp8x2_e4m3(*(half2*)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + // NOTE(Zihao): need to double check if we properly handle flo and fhi + ((__nv_fp8x4_e4m3*)(&dst.data))[i] = __nv_fp8x4_e4m3( + ((half2*)(&src.data))[i * 2], ((half2*)(&src.data))[i * 2 + 1]); + } + } +} + +template <size_t vec_size> +FLASHINFER_INLINE void cast_from_impl( + vec_t<float, vec_size>& dst, const vec_t<__nv_fp8_e5m2, vec_size>& src) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else if constexpr (vec_size == 2) { + *(float2*)(&dst.data) = float2(*(__nv_fp8x2_e5m2*)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4*)(&dst.data))[i] = float4(((__nv_fp8x4_e5m2*)(&src.data))[i]); + } + } +} + +template <size_t vec_size> +FLASHINFER_INLINE void cast_from_impl( + vec_t<half, vec_size>& dst, const vec_t<__nv_fp8_e5m2, vec_size>& src) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2*)(&dst.data))[i] = half2(((__nv_fp8x2_e5m2*)(&src.data))[i]); + } + } +} + +template <size_t vec_size> +FLASHINFER_INLINE void cast_from_impl(vec_t<__nv_fp8_e5m2, vec_size>& dst, + const vec_t<float, vec_size>& src) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e5m2(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e5m2*)(&dst.data) = __nv_fp8x2_e5m2(*(float2*)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((__nv_fp8x4_e5m2*)(&dst.data))[i] = + __nv_fp8x4_e5m2(((float4*)(&src.data))[i]); + } + } +} + +template <size_t vec_size> +FLASHINFER_INLINE void cast_from_impl(vec_t<__nv_fp8_e5m2, vec_size>& dst, + const vec_t<half, vec_size>& src) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e5m2(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e5m2*)(&dst.data) = __nv_fp8x2_e5m2(*(half2*)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + // NOTE(Zihao): need to double check if we properly handle flo and fhi + ((__nv_fp8x4_e5m2*)(&dst.data))[i] = __nv_fp8x4_e5m2( + ((half2*)(&src.data))[i * 2], ((half2*)(&src.data))[i * 2 + 1]); + } + } +} + +#endif // FLASHINFER_ENABLE_FP8 + +} // namespace flashinfer + +#endif // VEC_DTYPES_CUH_ \
No newline at end of file diff --git a/server/punica_kernels/setup.py b/server/punica_kernels/setup.py new file mode 100644 index 000000000..ef557d98f --- /dev/null +++ b/server/punica_kernels/setup.py @@ -0,0 +1,42 @@ +import pathlib + +import setuptools +import torch.utils.cpp_extension as torch_cpp_ext + +root = pathlib.Path(__name__).parent + + +def remove_unwanted_pytorch_nvcc_flags(): + REMOVE_NVCC_FLAGS = [ + '-D__CUDA_NO_HALF_OPERATORS__', + '-D__CUDA_NO_HALF_CONVERSIONS__', + '-D__CUDA_NO_BFLOAT16_CONVERSIONS__', + '-D__CUDA_NO_HALF2_OPERATORS__', + ] + for flag in REMOVE_NVCC_FLAGS: + try: + torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag) + except ValueError: + pass + + +remove_unwanted_pytorch_nvcc_flags() + +setuptools.setup( + name="punica_kernels", + ext_modules=[ + torch_cpp_ext.CUDAExtension( + name="punica_kernels", + sources=[ + "punica_kernels/punica_ops.cc", + "punica_kernels/bgmv/bgmv_all.cu", + "punica_kernels/flashinfer_adapter/flashinfer_all.cu", + "punica_kernels/rms_norm/rms_norm_cutlass.cu", + "punica_kernels/sgmv/sgmv_cutlass.cu", + "punica_kernels/sgmv_flashinfer/sgmv_all.cu", + ], + include_dirs=[str(root.resolve() / "third_party/cutlass/include")], + ) + ], + cmdclass={"build_ext": torch_cpp_ext.BuildExtension}, +) diff --git a/server/punica_kernels/third_party/cutlass b/server/punica_kernels/third_party/cutlass new file mode 160000 index 000000000..5cd735c48 --- /dev/null +++ b/server/punica_kernels/third_party/cutlass @@ -0,0 +1 @@ +Subproject commit 5cd735c48ec194039edb647aa7b37f4ff470738e
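
For reference, a minimal sketch of how the vendored extension can be built and smoke-tested once this patch is applied; the build commands and the import-only check below are illustrative assumptions, not part of the patch itself:

# Sketch: build the vendored kernels and confirm the extension loads.
# Assumes the cutlass submodule has been fetched and the local CUDA
# toolkit matches the one used to build the installed torch wheel.
#
#   git submodule update --init --recursive
#   pip install -v ./server/punica_kernels
#
import torch
import punica_kernels  # CUDAExtension defined in server/punica_kernels/setup.py

assert torch.cuda.is_available(), "punica_kernels requires a CUDA-capable GPU"
# punica_ops.cc registers the kernel entry points; list whatever is exposed
# without assuming any particular op name.
print(sorted(name for name in dir(punica_kernels) if not name.startswith("_")))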