From 82809c900a907a9e48abc96790e8caf8f5f81f2f Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Fri, 8 Sep 2023 02:50:44 +0000
Subject: [PATCH 01/56] refresh decoder attention kernel
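
Add a standalone decoder multi-head attention kernel under
src/turbomind/kernels/decoder_mha: the templated kernel and its launcher,
a tiled gmem->smem iterator with KV/Q thread maps, block/linear KV-cache
layout conversion helpers, and unit tests.

Minimal host-side launch sketch (not part of this patch; the head/batch
sizes and the fused-QKV layout below mirror the included unit test and are
assumptions, not requirements of the kernel):

    #include "src/turbomind/kernels/decoder_mha/decoder_multihead_attention.h"

    #include <cmath>
    #include <cuda_fp16.h>

    using namespace turbomind;

    // All device pointers are assumed to be valid, pre-filled buffers.
    void ExampleLaunch(half* out, half* qkv, int* lengths, void** k_cache_ptrs, void** v_cache_ptrs)
    {
        constexpr int kHeadNum = 32, kHeadDim = 128, kContextLen = 8192;

        DecoderMultiHeadAttentionParams<half> params{};
        params.out    = out;
        params.q      = qkv;                        // assumes fused [B, qH + 2kvH, D] layout
        params.k      = qkv + kHeadNum * kHeadDim;
        params.v      = qkv + 2 * kHeadNum * kHeadDim;
        params.stride = 3 * kHeadNum * kHeadDim;

        params.per_sample_length  = lengths;
        params.per_sample_k_cache = k_cache_ptrs;   // [H, S, D] per sample
        params.per_sample_v_cache = v_cache_ptrs;

        params.batch_size    = 1;
        params.num_heads     = kHeadNum;
        params.num_kv_heads  = kHeadNum;
        params.size_per_head = kHeadDim;
        params.max_seq_len   = kContextLen + 1;
        params.max_timestep  = kContextLen;
        params.inv_sqrt_dh   = 1.f / std::sqrt((float)kHeadDim);

        params.rotary_embedding_dim  = kHeadDim;
        params.rotary_embedding_base = 10000.f;     // typical RoPE base; an assumption here

        // Only the <half, 128> instantiation is provided by this patch.
        LaunchDecoderMultiheadAttention<half, kHeadDim>(params);
    }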

---
 src/turbomind/kernels/decoder_mha/array_ops.h | 317 +++++++++
 .../decoder_multihead_attention.cu            |  47 ++
 .../decoder_mha/decoder_multihead_attention.h |   9 +
 .../decoder_multihead_attention_params.h      |  49 ++
 .../decoder_multihead_attention_template.h    | 671 ++++++++++++++++++
 src/turbomind/kernels/decoder_mha/iterator.h  | 340 +++++++++
 src/turbomind/kernels/decoder_mha/kv_cache.cu | 149 ++++
 src/turbomind/kernels/decoder_mha/kv_cache.h  |  25 +
 .../test_decoder_multihead_attention.cu       | 199 ++++++
 .../kernels/decoder_mha/test_utils.cu         | 236 ++++++
 .../kernels/decoder_mha/test_utils.h          |  36 +
 .../kernels/decoder_mha/thread_map.h          |  96 +++
 12 files changed, 2174 insertions(+)
 create mode 100644 src/turbomind/kernels/decoder_mha/array_ops.h
 create mode 100644 src/turbomind/kernels/decoder_mha/decoder_multihead_attention.cu
 create mode 100644 src/turbomind/kernels/decoder_mha/decoder_multihead_attention.h
 create mode 100644 src/turbomind/kernels/decoder_mha/decoder_multihead_attention_params.h
 create mode 100644 src/turbomind/kernels/decoder_mha/decoder_multihead_attention_template.h
 create mode 100644 src/turbomind/kernels/decoder_mha/iterator.h
 create mode 100644 src/turbomind/kernels/decoder_mha/kv_cache.cu
 create mode 100644 src/turbomind/kernels/decoder_mha/kv_cache.h
 create mode 100644 src/turbomind/kernels/decoder_mha/test_decoder_multihead_attention.cu
 create mode 100644 src/turbomind/kernels/decoder_mha/test_utils.cu
 create mode 100644 src/turbomind/kernels/decoder_mha/test_utils.h
 create mode 100644 src/turbomind/kernels/decoder_mha/thread_map.h

diff --git a/src/turbomind/kernels/decoder_mha/array_ops.h b/src/turbomind/kernels/decoder_mha/array_ops.h
new file mode 100644
index 0000000000..a157d15cac
--- /dev/null
+++ b/src/turbomind/kernels/decoder_mha/array_ops.h
@@ -0,0 +1,317 @@
+#pragma once
+
+#include "../gemm_s_f16/common.h"
+#include <cfloat>
+#include <limits>
+
+namespace turbomind {
+
+namespace ops {
+
+template<typename T>
+struct plus {
+    __device__ T operator()(T a, T b)
+    {
+        return a + b;
+    }
+};
+
+template<typename T>
+struct minus {
+    __device__ T operator()(T a, T b)
+    {
+        return a - b;
+    }
+};
+
+template<typename T>
+struct multiplies {
+    __device__ T operator()(T a, T b)
+    {
+        return a * b;
+    }
+};
+
+template<typename T, int N, typename Op>
+inline __device__ Array<T, N> binary_op_vv(const Array<T, N>& a, const Array<T, N>& b, Op op)
+{
+    Array<T, N> c;
+    PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+        c[i] = op(a[i], b[i]);
+    }
+    return c;
+}
+
+template<typename T, int N, typename Op>
+inline __device__ Array<T, N> binary_op_sv(const T& a, const Array<T, N>& b, Op op)
+{
+    Array<T, N> c;
+    PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+        c[i] = op(a, b[i]);
+    }
+    return c;
+}
+
+template<typename T, int N, typename Op>
+inline __device__ Array<T, N> binary_op_vs(const Array<T, N>& a, const T& b, Op op)
+{
+    Array<T, N> c;
+    PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+        c[i] = op(a[i], b);
+    }
+    return c;
+}
+
+template<typename T, int N>
+inline __device__ Array<T, N> operator+(const Array<T, N>& a, const Array<T, N>& b)
+{
+    return binary_op_vv(a, b, plus<T>{});
+}
+
+template<typename T, int N>
+inline __device__ Array<T, N> operator*(const Array<T, N>& a, const Array<T, N>& b)
+{
+    return binary_op_vv(a, b, multiplies<T>{});
+}
+
+template<typename T, int N>
+inline __device__ Array<T, N> operator*(const Array<T, N>& a, const T& b)
+{
+    return binary_op_vs(a, b, multiplies<T>{});
+}
+
+}  // namespace ops
+
+template<int N>
+struct RotaryEmbedding {
+
+    static_assert(N % 2 == 0);
+
+    Array<float, N> inv_freqs_;
+
+    __device__ RotaryEmbedding(float base, int dims, int timestep, int2 offset)
+    {
+        PRAGMA_UNROLL
+        for (int i = 0; i < N; i += 2) {
+            const float2 tmp  = rotary_embedding_coefficient(offset.x + i, dims, base, timestep);
+            inv_freqs_[i]     = tmp.x;
+            inv_freqs_[i + 1] = tmp.y;
+        }
+    }
+
+    inline __device__ float2 rotary_embedding_coefficient(int idx, int dims, float base, int timestep)
+    {
+        const float inv_freq = timestep / powf(base, idx / (float)dims);
+        return {cos(inv_freq), sin(inv_freq)};
+    }
+
+    template<typename T>
+    __device__ void apply(Array<T, N>& x)
+    {
+        PRAGMA_UNROLL
+        for (int i = 0; i < N; i += 2) {
+            float tmp0 = inv_freqs_[i] * (float)x[i] - inv_freqs_[i + 1] * (float)x[i + 1];
+            float tmp1 = inv_freqs_[i] * (float)x[i + 1] + inv_freqs_[i + 1] * (float)x[i];
+            x[i]       = (T)tmp0;
+            x[i + 1]   = (T)tmp1;
+        }
+    }
+};
+
+template<typename VecQk, typename ThreadMap>
+struct LogNScaling {
+    __device__ void apply(VecQk& x)
+    {
+        PRAGMA_UNROLL
+        for (int i = 0; i < VecQk::kSize; ++i) {
+            // TODO:
+        }
+    }
+};
+
+template<typename To, typename From, int N>
+inline __device__ Array<To, N> cast(const Array<From, N>& src)
+{
+    Array<To, N> dst;
+    PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+        dst[i] = (To)src[i];
+    }
+    return dst;
+}
+
+template<typename T, int N>
+inline __device__ void Store(T* dst, const Array<T, N>& src)
+{
+    static_assert(sizeof(Array<T, N>) <= sizeof(uint4));
+
+    if constexpr (sizeof(Array<T, N>) == sizeof(uint4)) {
+        *(uint4*)dst = (const uint4&)src;
+    }
+    else if constexpr (sizeof(Array<T, N>) == sizeof(uint2)) {
+        *(uint2*)dst = (const uint2&)src;
+    }
+    else if constexpr (sizeof(Array<T, N>) == sizeof(uint1)) {
+        *(uint1*)dst = (const uint1&)src;
+    }
+    else {
+        static_assert(!std::is_same_v<T, T>);
+    }
+}
+
+template<typename T, int N>
+inline __device__ void Ldg(Array<T, N>& dst, const T* src)
+{
+    static_assert(sizeof(Array<T, N>) <= sizeof(uint4));
+
+    if constexpr (sizeof(Array<T, N>) == sizeof(uint4)) {
+        (uint4&)dst = __ldg((const uint4*)src);
+    }
+    else if constexpr (sizeof(Array<T, N>) == sizeof(uint2)) {
+        (uint2&)dst = __ldg((const uint2*)src);
+    }
+    else if constexpr (sizeof(Array<T, N>) == sizeof(uint)) {
+        (uint&)dst = __ldg((const uint*)src);
+    }
+    else {
+        static_assert(!std::is_same_v<T, T>);
+    }
+}
+
+template<typename T, int N>
+inline __device__ void Lds(Array<T, N>& dst, const T* src)
+{
+    static_assert(sizeof(Array<T, N>) <= sizeof(uint4));
+
+    if constexpr (sizeof(Array<T, N>) == sizeof(uint4)) {
+        (uint4&)dst = *(const uint4*)src;
+    }
+    else if constexpr (sizeof(Array<T, N>) == sizeof(uint2)) {
+        (uint2&)dst = *(const uint2*)src;
+    }
+    else if constexpr (sizeof(Array<T, N>) == sizeof(uint)) {
+        (uint&)dst = *(const uint*)src;
+    }
+    else {
+        static_assert(!std::is_same_v<T, T>);
+    }
+}
+
+template<typename Accum, typename Compute, int kThreadGroupSize, typename T, int N, int V>
+inline __device__ Accum qk_dot(const Array<T, N> (&q)[V], const Array<T, N> (&k)[V])
+{
+    Accum accum{};
+
+    PRAGMA_UNROLL
+    for (int vi = 0; vi < V; ++vi) {
+        PRAGMA_UNROLL
+        for (int i = 0; i < N; ++i) {
+            accum += Accum(Compute(q[vi][i]) * Compute(k[vi][i]));
+        }
+    }
+
+    PRAGMA_UNROLL
+    for (int mask = kThreadGroupSize / 2; mask >= 1; mask /= 2) {
+        accum += __shfl_xor_sync((uint32_t)-1, accum, mask);
+    }
+
+    return accum;
+}
+
+template<typename Accum, typename Compute, int kThreadGroupSize, typename T, int N>
+inline __device__ Accum qk_dot(const Array<T, N>& q, const Array<T, N>& k)
+{
+    Accum accum{};
+
+    PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+        accum += Accum(Compute(q[i]) * Compute(k[i]));
+    }
+
+    PRAGMA_UNROLL
+    for (int mask = kThreadGroupSize / 2; mask >= 1; mask /= 2) {
+        accum += __shfl_xor_sync((uint32_t)-1, accum, mask);
+    }
+
+    return accum;
+}
+
+template<typename ComputeType, typename Tp, typename Tv, typename To, int N, int M>
+inline __device__ void fma_pv(Tp pr, const Array<Tv, N> (&v)[M], Array<To, N> (&o)[M])
+{
+    PRAGMA_UNROLL
+    for (int m = 0; m < M; ++m) {
+        PRAGMA_UNROLL
+        for (int n = 0; n < N; ++n) {
+            o[m][n] += To(ComputeType(v[m][n]) * ComputeType(pr));
+        }
+    }
+}
+
+template<typename ThreadMap, typename T, int N>
+inline __device__ Array<T, N> qk_max(Array<T, N> val, T* smem_red, int warp_id, int lane_id)
+{
+    constexpr int kWarpCount = ThreadMap::kWarpCount;
+
+    // warp maximum
+    PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+        PRAGMA_UNROLL
+        for (int mask = WARP_SIZE / 2; mask >= ThreadMap::kWarpThreadC; mask /= 2) {
+            val[i] = fmaxf(val[i], __shfl_xor_sync((uint32_t)-1, val[i], mask));
+        }
+        if (lane_id == 0) {
+            smem_red[i * kWarpCount + warp_id] = val[i];
+        }
+    }
+
+    __syncthreads();
+
+    // block maximum
+    PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+        val[i] = lane_id < kWarpCount ? smem_red[i * kWarpCount + lane_id] : -FLT_MAX;
+        PRAGMA_UNROLL
+        for (int mask = kWarpCount >> 1; mask >= 1; mask >>= 1) {
+            val[i] = fmaxf(val[i], __shfl_xor_sync((uint32_t)-1, val[i], mask));
+        }
+        // broadcast to all threads
+        val[i] = __shfl_sync((uint32_t)-1, val[i], 0);
+    }
+
+    return val;
+}
+
+template<int kWarpCount, typename T, int N>
+inline __device__ Array<T, N> blockSum(Array<T, N> val, T* smem_red, int warp_id, int lane_id)
+{
+    PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+        PRAGMA_UNROLL
+        for (int mask = WARP_SIZE >> 1; mask >= 1; mask >>= 1) {
+            val[i] += __shfl_xor_sync((uint32_t)-1, val[i], mask);
+        }
+        if (lane_id == 0) {
+            smem_red[i * kWarpCount + warp_id] = val[i];
+        }
+    }
+
+    __syncthreads();
+
+    PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+        val[i] = lane_id < kWarpCount ? smem_red[i * kWarpCount + lane_id] : T{};
+        PRAGMA_UNROLL
+        for (int mask = kWarpCount >> 1; mask >= 1; mask >>= 1) {
+            val[i] += __shfl_xor_sync((uint32_t)-1, val[i], mask);
+        }
+        val[i] = __shfl_sync((uint32_t)-1, val[i], 0);
+    }
+
+    return val;
+}
+
+}  // namespace turbomind
\ No newline at end of file
diff --git a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention.cu
new file mode 100644
index 0000000000..f3a246d66f
--- /dev/null
+++ b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention.cu
@@ -0,0 +1,47 @@
+#include "decoder_multihead_attention_template.h"
+
+#include <iostream>
+
+namespace turbomind {
+
+template<typename MHAType>
+bool Dump()
+{
+    using MapKv = typename MHAType::MapKv;
+
+    std::cout << "     warps: " << MapKv::kWarpCount << "\n";
+    std::cout << "     shape: (" << MapKv::kC << ", " << MapKv::kS << ")\n";
+    std::cout << "    access: (" << MapKv::kAccessC << ", " << 1 << ")\n";
+    std::cout << "warpThread: (" << MapKv::kWarpThreadC << ", " << MapKv::kWarpThreadS << ")\n";
+    std::cout << "warpAccess: (" << MapKv::kWarpAccessC << ", " << MapKv::kWarpAccessS << ")\n";
+    std::cout << "  warpIter: (" << MapKv::kWarpIterC << ", " << MapKv::kWarpIterS << ")\n";
+    std::cout << "      warp: (" << MapKv::kWarpC << ", " << MapKv::kWarpS << ")\n";
+    std::cout << "      iter: (" << MapKv::kIterC << ", " << MapKv::kIterS << ")\n";
+    std::cout << " footprint: (" << MapKv::kFootprintC << ", " << MapKv::kFootprintS << ")\n";
+    std::cout << "     delta: (" << MapKv::kDeltaC << ", " << MapKv::kDeltaS << ")\n";
+
+    return true;
+}
+
+template<typename T, int HeadDim>
+void LaunchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& params)
+{
+    using MHAType = DecoderMultiHeadAttentionKernel<T, 1, HeadDim, 16, HeadDim, 1024, 5>;
+
+    [[maybe_unused]] static const bool init = Dump<MHAType>();
+
+    dim3 block(MHAType::kWarpCount * WARP_SIZE);
+    dim3 grid(params.num_kv_heads, params.batch_size);
+
+    const size_t kDynamicSmemSize = MHAType::GetDynamicSmemSize(params.max_timestep);
+    std::cout << "dynamic shared memory size: " << kDynamicSmemSize << "\n";
+
+    cudaFuncSetAttribute(
+        decoder_multihead_attention<MHAType>, cudaFuncAttributeMaxDynamicSharedMemorySize, kDynamicSmemSize);
+
+    decoder_multihead_attention<MHAType><<<grid, block, kDynamicSmemSize>>>(params);
+}
+
+template void LaunchDecoderMultiheadAttention<half, 128>(const DecoderMultiHeadAttentionParams<half>& params);
+
+}  // namespace turbomind
diff --git a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention.h b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention.h
new file mode 100644
index 0000000000..cdee4af1cd
--- /dev/null
+++ b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention.h
@@ -0,0 +1,9 @@
+
+#include "decoder_multihead_attention_params.h"
+
+namespace turbomind {
+
+template<typename T, int HeadDim>
+void LaunchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& params);
+
+}
\ No newline at end of file
diff --git a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_params.h b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_params.h
new file mode 100644
index 0000000000..4a6c6e8541
--- /dev/null
+++ b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_params.h
@@ -0,0 +1,49 @@
+#pragma once
+
+namespace turbomind {
+
+template<typename T>
+struct DecoderMultiHeadAttentionParams {
+    // token-level buffers, [B, qH + 2kvH, D] or [B, kvH, D]
+    T*  out;
+    T*  q;
+    T*  k;
+    T*  v;
+    int stride;
+
+    // bias, [qH, D] or [kvH, D]
+    T* q_bias;
+    T* k_bias;
+    T* v_bias;
+
+    // sequence-level buffers
+    int*  per_sample_length;
+    bool* finished;
+
+    // kv cache
+    void** per_sample_k_cache;  // [H, S, D]
+    void** per_sample_v_cache;  // [H, S, D]
+    size_t per_sample_kv_cache_offset;
+
+    // batch-level params
+    int batch_size;
+    int max_seq_len;
+    int max_timestep;  // max_timestep in the batch, used to compute smem sizes
+
+    // instance-level params
+    int num_heads;
+    int num_kv_heads;
+    int size_per_head;
+    float inv_sqrt_dh;
+
+    // rotary embedding
+    int   rotary_embedding_dim;
+    float rotary_embedding_base;
+    int   max_position_embeddings;
+    bool  use_dynamic_ntk;
+
+    // log(n) attention
+    bool use_logn_attn;
+};
+
+}  // namespace turbomind
\ No newline at end of file
diff --git a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_template.h b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_template.h
new file mode 100644
index 0000000000..68f188eca2
--- /dev/null
+++ b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_template.h
@@ -0,0 +1,671 @@
+#pragma once
+
+#include "array_ops.h"
+#include "iterator.h"
+#include "src/turbomind/kernels/gemm_s_f16/common.h"
+#include "thread_map.h"
+#include <cuda_pipeline_primitives.h>
+
+#include "decoder_multihead_attention_params.h"
+
+namespace turbomind {
+
+template<typename T, int HeadPerCta, int MaxHeadDim, int KeyPerIter, int HeadDim, int SliceLen, int Stages>
+struct DecoderMultiHeadAttentionKernel {
+    using Dtype     = T;
+    using ParamType = DecoderMultiHeadAttentionParams<T>;
+
+    static constexpr int kWarpCount  = 4;
+    static constexpr int kHeadPerCta = HeadPerCta;
+    static constexpr int kMaxHeadDim = MaxHeadDim;
+    static constexpr int kKeyPerIter = KeyPerIter;
+    static constexpr int kHeadDim    = HeadDim;
+    static constexpr int kStages     = Stages;
+
+    static constexpr int kSliceLen     = SliceLen;
+    static constexpr int kIterPerSlice = kSliceLen / kKeyPerIter;
+
+    static constexpr int kVecKvSize    = sizeof(uint4) / sizeof(T);
+    static constexpr int kThreadPerKey = 8;
+
+    using VecKv      = Array<Dtype, kVecKvSize>;
+    using VecKvFloat = Array<float, kVecKvSize>;
+
+    static constexpr bool kChainedKv = false;
+
+    using MapKv  = ThreadMapKv<kMaxHeadDim, kKeyPerIter, kVecKvSize, kThreadPerKey, kWarpCount>;
+    using IterKv = turbomind::Iterator<T, MapKv, SliceLen, kStages, kChainedKv>;
+
+    static size_t GetDynamicSmemSize(int max_timestep)
+    {
+        size_t smem_kv_cache = IterKv::kSmemByteSize;
+        size_t smem_kv_align = 128;
+        size_t smem_qk       = sizeof(float) * kHeadPerCta * kSliceLen;
+        size_t smem_pr       = sizeof(float) * kHeadPerCta * kSliceLen;
+        // size_t smem_ctrl_lut = ((max_timestep + KeyPerIter - 1) / KeyPerIter * 2 + 31) / 32 * sizeof(uint32_t);
+        return smem_kv_align + smem_kv_cache + std::max(smem_qk, smem_pr);
+    }
+
+    // using AccumType   = float;
+    // using ComputeType = float;
+
+    using QkAccumType   = float;
+    using QkComputeType = float;
+
+    using PvAccumType   = float;
+    using PvComputeType = float;
+
+    struct SharedStorage {
+        __align__(16) Dtype Q[kHeadPerCta * kMaxHeadDim];
+        __align__(16) float O[kHeadPerCta * kMaxHeadDim];
+        float M[kHeadPerCta];  // max{dot(Q,  K^T  )}
+        float L[kHeadPerCta];  // sum{exp(s - S_max)}
+        float red_max[kHeadPerCta * kWarpCount];
+        float red_sum[kHeadPerCta * kWarpCount];
+    };
+
+    const ParamType& params_;
+
+    int head_idx_;
+    int batch_idx_;
+    int warp_id_;
+    int lane_id_;
+
+    int timestep_;
+    T*  k_cache_;  // [S, D]
+    T*  v_cache_;  // [S, D]
+
+    Dtype*    smem_Kv_;
+    float*    smem_S_;
+    float*    smem_P_;
+    Dtype*    smem_Q_;
+    float*    smem_M_;
+    float*    smem_L_;
+    float*    smem_O_;
+    float*    smem_red_max_;
+    float*    smem_red_sum_;
+    unsigned* smem_ctrl_;
+
+    __device__ bool thread0()
+    {
+        return blockIdx.x == 0 && threadIdx.x == 0;
+    }
+
+    __device__ DecoderMultiHeadAttentionKernel(const ParamType& params, SharedStorage& smem, uint8_t* dsmem):
+        params_(params)
+    {
+        smem_Kv_ = (Dtype*)dsmem;
+        smem_S_  = (float*)(smem_Kv_ + IterKv::kSizePerTile * kStages);  // [HeadPerCta * kSliceLen]
+        // smem_P_       = (float*)(smem_S_ + kHeadPerCta * kSliceLen);          // [HeadPerCta * kSliceLen]
+        smem_P_       = smem_S_;
+        smem_ctrl_    = (unsigned*)(smem_P_ + kHeadPerCta * kSliceLen);  // [max_timestep / kKeyPerIter / 32]
+        smem_Q_       = smem.Q;
+        smem_M_       = smem.M;
+        smem_L_       = smem.L;
+        smem_O_       = smem.O;
+        smem_red_max_ = smem.red_max;
+        smem_red_sum_ = smem.red_sum;
+
+        head_idx_  = blockIdx.x;
+        batch_idx_ = blockIdx.y;
+        warp_id_   = threadIdx.x / WARP_SIZE;
+        lane_id_   = threadIdx.x % WARP_SIZE;
+
+        timestep_ = params_.per_sample_length[batch_idx_];
+
+        /// TODO: block level kv cache
+        k_cache_ = (T*)params_.per_sample_k_cache[batch_idx_] + params.per_sample_kv_cache_offset
+                   + head_idx_ * params_.max_seq_len * params_.size_per_head;
+        v_cache_ = (T*)params_.per_sample_v_cache[batch_idx_] + params.per_sample_kv_cache_offset
+                   + head_idx_ * params_.max_seq_len * params_.size_per_head;
+    }
+
+    // [kkkk][vvvv][kkkk][vvvv][kkkk][vvvv][k][v]
+    // __device__ int is_last_iter_of_slice(int iter, int full, int partial)
+    // {
+    //     if (iter < full) {
+    //         return (iter + 1) % kIterPerSlice == 0;
+    //     }
+    //     else {
+    //         return (iter - full + 1) % partial == 0;
+    //     }
+    // }
+
+    __device__ void Prolugue()
+    {
+        // - Each warp is handling a row of Q
+        // - K/V are loaded redundantly only for the current step
+        static_assert(kMaxHeadDim % WARP_SIZE == 0);
+        static constexpr int kVecQSize = kMaxHeadDim / WARP_SIZE;
+
+        using VecQ = Array<T, kVecQSize>;
+
+        using MapQ = ThreadMapQ<kMaxHeadDim, kHeadPerCta, kVecQSize, kWarpCount>;
+
+        static constexpr int kQVecPerThread  = MapQ::kIterC;
+        static constexpr int kQHeadPerThread = MapQ::kIterS;  // > 1 when #warp < #head
+
+        static_assert(kQVecPerThread == 1);
+
+        int2 offset   = MapQ::get_offset(warp_id_, lane_id_);
+        bool is_valid = offset.x < kMaxHeadDim && offset.y < kHeadPerCta;
+
+        if (!is_valid) {
+            return;
+        }
+
+        VecQ frag_Q[kQHeadPerThread];
+        VecQ frag_K;
+        VecQ frag_V;
+
+        // load qkv
+        PRAGMA_UNROLL
+        for (int s = 0; s < kQHeadPerThread; ++s) {
+            int di = offset.x;
+            int qi = offset.y + s;
+            Ldg(frag_Q[s], &params_.q[batch_idx_ * params_.stride + (head_idx_ + qi) * kHeadDim + di]);
+        }
+        Ldg(frag_K, &params_.k[batch_idx_ * params_.stride + head_idx_ * kHeadDim + offset.x]);
+        Ldg(frag_V, &params_.v[batch_idx_ * params_.stride + head_idx_ * kHeadDim + offset.x]);
+
+        if (params_.q_bias) {
+            // load biases
+            VecQ bias_Q[kQHeadPerThread];
+            PRAGMA_UNROLL
+            for (int s = 0; s < kQHeadPerThread; ++s) {
+                int di = offset.x;
+                int qi = offset.y + s;
+                Ldg(bias_Q[s], &params_.q_bias[(head_idx_ + qi) * kHeadDim + di]);
+            }
+            VecQ bias_K;
+            VecQ bias_V;
+            Ldg(bias_K, &params_.k_bias[head_idx_ * kHeadDim + offset.x]);
+            Ldg(bias_V, &params_.v_bias[head_idx_ * kHeadDim + offset.x]);
+
+            using namespace ops;
+            // apply biases
+            PRAGMA_UNROLL
+            for (int s = 0; s < kQHeadPerThread; ++s) {
+                frag_Q[s] = frag_Q[s] + bias_Q[s];
+            }
+            frag_K = frag_K + bias_K;
+            frag_V = frag_V + bias_V;
+        }
+
+        // Apply rotary embedding
+        RotaryEmbedding<kVecQSize> rotary_emb(
+            params_.rotary_embedding_base, params_.rotary_embedding_dim, timestep_, offset);
+
+        PRAGMA_UNROLL
+        for (int s = 0; s < kQHeadPerThread; ++s) {
+            rotary_emb.apply(frag_Q[s]);
+        }
+        rotary_emb.apply(frag_K);
+
+        PRAGMA_UNROLL
+        for (int s = 0; s < kQHeadPerThread; ++s) {
+            int         qi = offset.y + s;
+            QkAccumType qk = qk_dot<QkAccumType, QkComputeType, WARP_SIZE>(frag_Q[s], frag_K);
+            if (lane_id_ == 0) {
+                qk *= params_.inv_sqrt_dh;
+                // printf("qk_last[%d]=%f\n", head_idx_, qk);
+                smem_M_[qi] = qk;
+                smem_L_[qi] = 1.f;
+            }
+            // write Q and O
+            Store(&smem_Q_[qi * kMaxHeadDim + offset.x], frag_Q[s]);
+            Store(&smem_O_[qi * kMaxHeadDim + offset.x], cast<float>(frag_V));
+        }
+
+        // store
+        if (warp_id_ == 0) {
+            Store(&k_cache_[timestep_ * kMaxHeadDim + offset.x], frag_K);
+            Store(&v_cache_[timestep_ * kMaxHeadDim + offset.x], frag_V);
+        }
+    }
+
+    __device__ void PrefetchKvCache(IterKv& iter)
+    {
+        PRAGMA_UNROLL
+        for (int stage = 0; stage < kStages - 1; ++stage) {
+            iter.PrefetchStage();
+            CpAsyncCommit();
+        }
+    }
+
+    __device__ void CpAsyncWait()
+    {
+        __pipeline_wait_prior(kStages - 2);
+        // __syncwarp();
+        // __syncthreads();
+    }
+
+    __device__ void CpAsyncCommit()
+    {
+        __pipeline_commit();
+    }
+
+    __device__ void CpAsyncFlush()
+    {
+        __pipeline_commit();
+        __pipeline_wait_prior(0);
+    }
+
+    static constexpr int kKvVecPerThread = MapKv::kIterC;
+    static constexpr int kKvKeyPerThread = MapKv::kIterS;
+
+    struct FragmentQ {
+        VecKv data[kHeadPerCta][kKvVecPerThread];
+    };
+
+    struct State {
+        // Double buffering to hide smem/dequant latency
+        VecKv frag_Kv_buf[2][kKvVecPerThread];
+    };
+
+    static constexpr int kPrefetchCount = (IterKv::kIterCount + MapKv::kIterS - 1) / MapKv::kIterS;
+
+    __device__ void ComputeSlice(FragmentQ& frag_Q, State& state, const int2& offset, int step, int iter_length)
+    {
+
+        Array<float, kHeadPerCta> frag_M;
+        PRAGMA_UNROLL
+        for (int i = 0; i < kHeadPerCta; ++i) {
+            frag_M[i] = smem_M_[i];
+        }
+
+        IterKv iter_K(k_cache_, smem_Kv_, step, step + iter_length, warp_id_, lane_id_);
+        PrefetchKvCache(iter_K);
+        CpAsyncWait();
+
+        iter_K.Load(state.frag_Kv_buf[0]);
+        iter_K.PrefetchBatch(0, kPrefetchCount);
+        if (kKvKeyPerThread == 1) {
+            CpAsyncCommit();
+            CpAsyncWait();
+            iter_K.AdvancePrefetchStage();
+            iter_K.AdvanceComputeStage();
+        }
+
+        ///////////////////////////////////////////////////////////////////////////////////////////
+        /// Compute QK(Q, S) = Q(Q, D) * K^T(D, S)
+
+        PRAGMA_NO_UNROLL
+        for (int _it = 0; _it < iter_length; _it += kKeyPerIter) {
+            PRAGMA_UNROLL
+            for (int si = 0; si < kKvKeyPerThread; ++si) {
+                // smem -> rmem for next iter
+                iter_K.Load(state.frag_Kv_buf[(si + 1) % 2]);
+
+                // current iter's K fragment
+                auto& frag_K = state.frag_Kv_buf[si % 2];
+
+                const int local_offset = offset.y + _it + si * MapKv::kWarpAccessS;
+
+                PRAGMA_UNROLL
+                for (int qi = 0; qi < kHeadPerCta; ++qi) {
+
+                    auto qk = qk_dot<QkAccumType, QkComputeType, kThreadPerKey>(frag_Q.data[qi], frag_K);
+
+                    // if (ti == 16) {
+                    //     for (int vi = 0; vi < kKvVecPerThread; ++vi) {
+                    //         for (int i = 0; i < kVecKvSize; ++i) {
+                    //             printf("frag_Q = %f, frag_K[%d] = %f\n",
+                    //                    (float)frag_Q.data[qi][vi][i],
+                    //                    offset.x + vi * kVecKvSize + i,
+                    //                    (float)frag_K[vi][i]);
+                    //         }
+                    //     }
+                    // }
+
+                    qk *= params_.inv_sqrt_dh;
+
+                    if (step + local_offset < timestep_) {
+
+                        // group leader writes to smem
+                        if (threadIdx.x % kThreadPerKey == 0) {
+                            // printf("qk_%d = %f\n", step + local_offset, (float)qk);
+
+                            smem_S_[kSliceLen * qi + local_offset] = qk;
+
+                            // local max
+                            frag_M[qi] = fmaxf(frag_M[qi], qk);
+                        }
+                    }
+                }
+
+                iter_K.PrefetchBatch((si + 1) % kKvKeyPerThread, kPrefetchCount);
+
+                if (kKvKeyPerThread == 1 || si == kKvKeyPerThread - 2) {
+                    CpAsyncCommit();
+                    CpAsyncWait();
+                    iter_K.AdvancePrefetchStage();
+                    iter_K.AdvanceComputeStage();
+                }
+            }
+
+            // handle special case
+            if (kKvKeyPerThread == 1) {
+                for (int vi = 0; vi < kKvVecPerThread; ++vi) {
+                    state.frag_Kv_buf[0][vi] = state.frag_Kv_buf[1][vi];
+                }
+            }
+        }
+
+        CpAsyncFlush();
+
+        __syncthreads();
+
+        Array<float, kHeadPerCta> exp_M_diff;
+        PRAGMA_UNROLL
+        for (int i = 0; i < kHeadPerCta; ++i) {
+            exp_M_diff[i] = smem_M_[i];
+        }
+
+        /// block synchronization
+        frag_M = qk_max<MapKv>(frag_M, smem_red_max_, warp_id_, lane_id_);
+
+        if (threadIdx.x == 0 && step == timestep_ - kSliceLen) {
+            // printf("frag_M[%d] = %f\n", head_idx_, (float)frag_M[0]);
+        }
+
+        // wait while smem_red_ is being used.
+        // __syncthreads();
+
+        PRAGMA_UNROLL
+        for (int i = 0; i < kHeadPerCta; ++i) {
+            // if (thread0()) {
+            //     printf("%f %f %f\n", (float)exp_M_diff[i], (float)frag_M[i], (float)__expf(exp_M_diff[i] -
+            //     frag_M[i]));
+            // }
+            // exp(m1 - m2)
+            exp_M_diff[i] = __expf(exp_M_diff[i] - frag_M[i]);
+
+            if (threadIdx.x == 0) {
+                smem_M_[i] = frag_M[i];
+            }
+        }
+
+        // __syncthreads();  // DEBUG
+
+        /////////////////////////////////////////////////////////////////////////////////////////
+        // / Compute softmax P(Q, S)
+        Array<float, kHeadPerCta> frag_L{};
+
+        for (int ti = threadIdx.x; ti < iter_length; ti += kWarpCount * WARP_SIZE) {
+            PRAGMA_UNROLL
+            for (int qi = 0; qi < kHeadPerCta; ++qi) {
+                int   idx = qi * kSliceLen + ti;
+                float qk  = smem_S_[idx];
+                float pr  = expf(qk - frag_M[qi]);
+                // printf("smem_P[%d] = %f\n", ti, pr);
+                smem_P_[idx] = pr;
+                frag_L[qi] += pr;
+            }
+        }
+
+        // if (thread0()) {
+        // printf("frag_L0 = %f\n", (float)frag_L[0]);
+        // }
+
+        /// block synchronization
+        frag_L = blockSum<kWarpCount>(frag_L, smem_red_sum_, warp_id_, lane_id_);
+
+        if (thread0()) {
+            // printf("frag_L = %f\n", (float)frag_L[0]);
+        }
+
+        for (int qi = 0; qi < kHeadPerCta; ++qi) {
+            // exp(m1 - m2) * l1
+            frag_L[qi] += exp_M_diff[qi] * smem_L_[qi];
+        }
+
+        __syncthreads();
+
+        for (int qi = 0; qi < kHeadPerCta; ++qi) {
+            if (threadIdx.x == 0) {
+                smem_L_[qi] = frag_L[qi];
+            }
+        }
+
+        if (threadIdx.x == 0 && step == timestep_ - kSliceLen) {
+            // printf("frag_L'[%d] = %f\n", head_idx_, (float)frag_L[0]);
+        }
+
+        /////////////////////////////////////////////////////////////////////////////////////////
+        // / Compute O[H,D] = P[H,S] * V[S,D]
+        VecKvFloat frag_O[kHeadPerCta][kKvVecPerThread]{};  // value initialize
+                                                            // float      frag_Pr_buf[2][kHeadPerCta];
+
+        // ti = step + offset.y;
+
+        // int ti = step + offset.y;
+
+        // PRAGMA_UNROLL
+        // for (int qi = 0; qi < kHeadPerCta; ++qi) {
+        //     // prefetch Pr for first warp iter
+        //     frag_Pr_buf[0][qi] = smem_P_[qi * kSliceLen + ti];
+        // }
+        IterKv iter_V(v_cache_, smem_Kv_, step, step + iter_length, warp_id_, lane_id_);
+        PrefetchKvCache(iter_V);
+        CpAsyncWait();
+
+        iter_V.Load(state.frag_Kv_buf[0]);
+        iter_V.PrefetchBatch(0, kPrefetchCount);
+        if (kKvKeyPerThread == 1) {
+            CpAsyncCommit();
+            CpAsyncWait();
+            iter_V.AdvancePrefetchStage();
+            iter_V.AdvanceComputeStage();
+        }
+
+        PRAGMA_NO_UNROLL
+        for (int _it = 0; _it < iter_length; _it += kKeyPerIter) {
+            PRAGMA_UNROLL
+            for (int si = 0; si < kKvKeyPerThread; ++si) {
+                // Load value cache for next warp iter
+                iter_V.Load(state.frag_Kv_buf[(si + 1) % 2]);
+
+                // Load Pr for next warp iter
+                // PRAGMA_UNROLL
+                // for (int qi = 0; qi < kHeadPerCta; ++qi) {
+                //     frag_Pr_buf[(si + 1) % 2][qi] = smem_P_[qi * kSliceLen + (ti + MapKv::kWarpAccessS)];
+                // }
+
+                auto& frag_V = state.frag_Kv_buf[si % 2];
+                // auto& frag_P = frag_Pr_buf[si % 2];
+
+                const int local_offset = offset.y + _it + si * MapKv::kWarpAccessS;
+
+                float frag_P[kHeadPerCta];
+                PRAGMA_UNROLL
+                for (int qi = 0; qi < kHeadPerCta; ++qi) {
+                    frag_P[qi] = smem_P_[qi * kSliceLen + local_offset];
+                }
+
+                if (step + local_offset < timestep_) {
+                    PRAGMA_UNROLL
+                    for (int qi = 0; qi < kHeadPerCta; ++qi) {
+                        fma_pv<PvComputeType>(frag_P[qi], frag_V, frag_O[qi]);
+                    }
+                    // for (int i = 0; i < kKvVecPerThread; ++i) {
+                    //     for (int j = 0; j < kVecKvSize; ++j) {
+                    //         printf("frag_V %f\n", (float)frag_V[i][j]);
+                    //     }
+                    // }
+                    // if (threadIdx.x % MapKv::kWarpThreadC == 0) {
+                    //     printf("frag_P[%d] %f\n", ti, frag_P[0]);
+                    // }
+                }
+
+                iter_V.PrefetchBatch((si + 1) % kKvKeyPerThread, kPrefetchCount);
+
+                if (kKvKeyPerThread == 1 || si == kKvKeyPerThread - 2) {
+                    CpAsyncCommit();
+                    CpAsyncWait();
+                    iter_V.AdvancePrefetchStage();
+                    iter_V.AdvanceComputeStage();
+                }
+            }
+
+            // handle special case
+            if (kKvKeyPerThread == 1) {
+                for (int vi = 0; vi < kKvVecPerThread; ++vi) {
+                    state.frag_Kv_buf[0][vi] = state.frag_Kv_buf[1][vi];
+                }
+                // PRAGMA_UNROLL
+                // for (int qi = 0; qi < kHeadPerCta; ++qi) {
+                //     frag_Pr_buf[0][qi] = frag_Pr_buf[1][qi];
+                // }
+            }
+        }
+
+        /// warp reduce over S dim
+        PRAGMA_UNROLL
+        for (int qi = 0; qi < kHeadPerCta; ++qi) {
+            PRAGMA_UNROLL
+            for (int vi = 0; vi < kKvVecPerThread; ++vi) {
+                PRAGMA_UNROLL
+                for (int i = 0; i < kVecKvSize; ++i) {
+                    // reduce over warp thread S
+                    PRAGMA_UNROLL
+                    for (int mask = WARP_SIZE / 2; mask >= MapKv::kWarpThreadC; mask /= 2) {
+                        frag_O[qi][vi][i] += __shfl_xor_sync(uint32_t(-1), frag_O[qi][vi][i], mask);
+                    }
+                }
+            }
+        }
+
+        // __syncthreads();
+
+        PRAGMA_UNROLL
+        for (int gi = 0; gi < MapKv::kS; gi += MapKv::kFootprintS) {
+            PRAGMA_UNROLL
+            for (int qi = 0; qi < kHeadPerCta; ++qi) {
+                PRAGMA_UNROLL
+                for (int vi = 0; vi < kKvVecPerThread; ++vi) {
+                    if (offset.y == gi) {
+                        // ! 2-way bank conflict
+                        auto& smem_O = (VecKvFloat&)smem_O_[qi * kMaxHeadDim + offset.x + vi * MapKv::kDeltaC];
+                        using namespace ops;
+                        auto tmp_O = smem_O;
+                        if (offset.y == 0) {
+                            tmp_O = tmp_O * exp_M_diff[qi];
+                        }
+                        // ! 2-way bank conflict
+                        smem_O = tmp_O + frag_O[qi][vi];
+                    }
+                }
+            }
+            __syncthreads();
+        }
+
+        CpAsyncFlush();
+    }
+
+    __device__ void LoopKv()
+    {
+        const int2 offset = MapKv::get_offset(warp_id_, lane_id_);
+
+        ///////////////////////////////////////////////////////////////////////////////////////////
+        /// Load Q from shared memory.
+        /// NOTE: There will be bank conflicts when sizeof(VecKv) > 16 (e.g. when KV is quantized)
+        FragmentQ frag_Q;
+
+        PRAGMA_UNROLL
+        for (int qi = 0; qi < kHeadPerCta; ++qi) {
+            PRAGMA_UNROLL
+            for (int c = 0; c < kKvVecPerThread; ++c) {
+                const int di       = offset.x + MapKv::kDeltaC * c;
+                frag_Q.data[qi][c] = (VecKv&)smem_Q_[qi * kMaxHeadDim + di];
+            }
+        }
+
+        State state;
+
+        PRAGMA_NO_UNROLL
+        for (int step = 0; step < params_.max_timestep; step += kSliceLen) {
+            int iter_length = min(params_.max_timestep - step, kSliceLen);
+            ComputeSlice(frag_Q, state, offset, step, iter_length);
+        }
+    }
+
+    __device__ void Run()
+    {
+        if constexpr (0) {
+            for (int i = threadIdx.x; i < kStages * IterKv::kSizePerTile; i += blockDim.x) {
+                smem_Kv_[i] = Dtype(0);
+            }
+            __syncthreads();
+        }
+
+        // Compute attention for current step
+        Prolugue();
+
+        __syncthreads();
+
+        // Iterate over K/V
+        LoopKv();
+
+        __syncthreads();
+
+        // Normalize outputs & write to device memory
+        Epilogue();
+    }
+
+    __device__ void Epilogue()
+    {
+        static constexpr int kVecQSize = kMaxHeadDim / WARP_SIZE;
+
+        using VecQ      = Array<T, kVecQSize>;
+        using VecQFloat = Array<float, kVecQSize>;
+
+        using MapQ = ThreadMapQ<kMaxHeadDim, kHeadPerCta, kVecQSize, kWarpCount>;
+
+        static constexpr int kQkvHeadPerThread = MapQ::kIterS;
+        static_assert(kQkvHeadPerThread == 1);
+
+        int2 offset = MapQ::get_offset(warp_id_, lane_id_);
+
+        bool is_valid = offset.x < kMaxHeadDim && offset.y < kHeadPerCta;
+        if (!is_valid) {
+            return;
+        }
+
+        PRAGMA_UNROLL
+        for (int s = 0; s < kQkvHeadPerThread; ++s) {
+            int   di    = offset.x;
+            int   qi    = offset.y + s;
+            float scale = __fdividef(1.f, smem_L_[qi] + 1e-6f);
+            // float scale = 1.f;
+            using namespace ops;
+            VecQFloat frag_O = (VecQFloat&)smem_O_[qi * kMaxHeadDim + di] * scale;
+            /// FIXME: `(head_idx_ + qi)` doesn't look right
+            Store(&params_.out[batch_idx_ * params_.num_heads * kHeadDim + (head_idx_ + qi) * kHeadDim + di],
+                  cast<Dtype>(frag_O));
+        }
+    }
+};
+
+extern __shared__ uint8_t dynamic_smem[];
+
+template<typename MHAType, typename ParamType = typename MHAType::ParamType>
+__global__ void decoder_multihead_attention(ParamType params)
+{
+    __shared__ typename MHAType::SharedStorage shared_storage;
+
+    uint8_t* smem_ptr = dynamic_smem;
+
+    // Align dynamic smem ptr to a 128-byte boundary; this eliminates excessive wavefronts from smem to L1,
+    // but it does not improve performance
+    if constexpr (0) {
+        int misalign = (uintptr_t)smem_ptr % 128;
+        if (misalign) {
+            smem_ptr += 128 - misalign;
+        }
+    }
+
+    MHAType{params, shared_storage, smem_ptr}.Run();
+}
+
+}  // namespace turbomind
diff --git a/src/turbomind/kernels/decoder_mha/iterator.h b/src/turbomind/kernels/decoder_mha/iterator.h
new file mode 100644
index 0000000000..190ce4820f
--- /dev/null
+++ b/src/turbomind/kernels/decoder_mha/iterator.h
@@ -0,0 +1,340 @@
+#pragma once
+
+#include "../gemm_s_f16/common.h"
+#include "array_ops.h"
+
+namespace turbomind {
+
+// k0,k1,k2,v0,v1,v2,k3,k4,k5,v3,v4,v5
+
+template<int HeadDim, int ElemSize, int SliceLen>
+struct BlockIterator {
+    const void* kv_cache_[2];
+
+    static constexpr int kStride = HeadDim * ElemSize * SliceLen;
+
+    BlockIterator() = default;
+
+    __device__ BlockIterator(const void* k_cache, void* v_cache)
+    {
+        kv_cache_[0] = k_cache;
+        kv_cache_[1] = v_cache;
+    }
+
+    __device__ const void* Next()
+    {
+        // if (blockIdx.x == 0 && threadIdx.x == 0) {
+        //     printf("Next()\n");
+        // }
+        const void* ret = kv_cache_[0];
+        const void* tmp = (const uint8_t*)kv_cache_[0] + kStride;
+        kv_cache_[0]    = kv_cache_[1];
+        kv_cache_[1]    = tmp;
+        return ret;
+    }
+};
+
+struct BlockIterator2 {
+    const void*  prefetch_data_;
+    const void** block_ptrs_;
+
+    __device__ BlockIterator2(const void** block_ptrs): block_ptrs_{block_ptrs}
+    {
+        // prefetch first ptr
+        prefetch_data_ = *block_ptrs_++;
+    }
+
+    __device__ const void* Next()
+    {
+        // return prefetched ptr
+        const void* ret = prefetch_data_;
+        // prefetch next ptr
+        prefetch_data_ = *block_ptrs_++;
+
+        return ret;
+    }
+};
+
+template<typename T, typename ThreadMap, int BlockLen, int Stages, bool Chained>
+struct Iterator {
+
+    using ElementType = T;
+    using AccessType  = Array<T, ThreadMap::kAccessC>;
+
+    static constexpr int kElementSize = sizeof(ElementType);
+    static constexpr int kAccessSize  = sizeof(AccessType);
+
+    static constexpr int kSizePerTile  = ThreadMap::kS * ThreadMap::kC;
+    static constexpr int kSmemByteSize = kElementSize * Stages * kSizePerTile;
+
+    BlockIterator<ThreadMap::kC, sizeof(T), BlockLen> block_iterator_;
+    // SignalIterator                                    signal_iterator_;
+
+    static constexpr int kIterCount = ThreadMap::kIterS * ThreadMap::kIterC;
+
+    static constexpr int kStepC = ThreadMap::kDeltaC;
+    static constexpr int kStepS = ThreadMap::kDeltaS * ThreadMap::kC - ThreadMap::kIterC * kStepC;
+    static constexpr int kStepK =
+        ThreadMap::kS * ThreadMap::kC - ThreadMap::kIterS * ThreadMap::kDeltaS * ThreadMap::kC;
+
+    // (C, S, K) = (64, 384, 1536)
+
+    // initial offset, used to reset src_offset when switching to a new block
+    int init_offset_;
+
+    int src_offset_;
+    int dst_offset_;
+
+    int iter_c_;
+    int iter_b_;
+
+    int  seq_len_;
+    int  offset_s_;
+    bool is_valid_s_;
+
+    const T* src_;
+    T*       smem_;
+
+    int smem_read_offset_;
+
+    struct __align__(sizeof(AccessType)) SharedStorage
+    {
+        T smem_[Stages][kSizePerTile];
+    };
+
+    __device__
+    Iterator(void* k_cache, void* v_cache, T* smem, uint32_t* smem_signal, int seq_len, int warp_id, int lane_id):
+        block_iterator_(k_cache, v_cache)  //, signal_iterator_(smem_signal)
+    {
+        src_  = (const T*)block_iterator_.Next();
+        smem_ = smem;
+
+        int2 init_offset = ThreadMap::get_offset(warp_id, lane_id);
+
+        init_offset_ = init_offset.x + init_offset.y * ThreadMap::kC;
+
+        // printf("%d\n", init_offset.x);
+
+        src_offset_       = init_offset_;
+        dst_offset_       = init_offset_;
+        smem_read_offset_ = init_offset_;
+
+        iter_c_ = 0;
+        iter_b_ = 0;
+
+        seq_len_    = seq_len;
+        offset_s_   = init_offset.y;
+        is_valid_s_ = offset_s_ < seq_len;
+    }
+
+    Iterator() = default;
+
+    __device__ Iterator(T* src, T* smem, int step, int seq_len, int warp_id, int lane_id)
+    {
+        src_  = src;
+        smem_ = smem;
+
+        int2 init_offset_cs = ThreadMap::get_offset(warp_id, lane_id);
+
+        init_offset_ = init_offset_cs.x + init_offset_cs.y * ThreadMap::kC;
+
+        src_offset_       = init_offset_ + step * ThreadMap::kC;
+        dst_offset_       = init_offset_;
+        smem_read_offset_ = init_offset_;
+
+        iter_c_ = 0;
+        iter_b_ = 0;
+
+        seq_len_    = seq_len;
+        offset_s_   = init_offset_cs.y + step;
+        is_valid_s_ = offset_s_ < seq_len;
+    }
+
+    __device__ void PrefetchStage()
+    {
+        PRAGMA_UNROLL
+        for (int i = 0; i < kIterCount; ++i) {
+            Prefetch(is_valid_s_);
+            ++(*this);
+        }
+        AdvancePrefetchStage();
+    }
+
+    __device__ void PrefetchBatch(int batch_idx, int batch_size)
+    {
+        PRAGMA_UNROLL
+        for (int i = 0; i < batch_size; ++i) {
+            if (batch_idx * batch_size + i < kIterCount) {
+                Prefetch(is_valid_s_);
+                ++(*this);
+            }
+        }
+    }
+
+    __device__ Iterator& operator++()
+    {
+        src_offset_ += kStepC;
+        dst_offset_ += kStepC;
+        ++iter_c_;
+        if (iter_c_ < ThreadMap::kIterC) {
+            return *this;
+        }
+
+        iter_c_ = 0;
+        src_offset_ += kStepS;
+        dst_offset_ += kStepS;
+
+        offset_s_ += ThreadMap::kDeltaS;
+        is_valid_s_ = offset_s_ < seq_len_;
+
+        return *this;
+    }
+
+    __device__ void AdvancePrefetchStage()
+    {
+        src_offset_ += kStepK;
+        dst_offset_ += kStepK;
+
+        offset_s_ += ThreadMap::kS - ThreadMap::kIterS * ThreadMap::kDeltaS;
+
+        is_valid_s_ = offset_s_ < seq_len_;
+
+        // if (init_offset_ / ThreadMap::kC == 0) {
+        //     int k = dst_offset_ / (ThreadMap::kS * ThreadMap::kC);
+        //     int s = dst_offset_ % (ThreadMap::kS * ThreadMap::kC) / ThreadMap::kC;
+        //     int c = dst_offset_ % ThreadMap::kC;
+        //     printf("tid=%d, k=%d, s=%d, c=%d, offset_s=%d, valid_s=%d, init_s=%d\n",
+        //            threadIdx.x,
+        //            k,
+        //            s,
+        //            c,
+        //            offset_s_,
+        //            (int)is_valid_s_,
+        //            init_offset_ / ThreadMap::kC);
+        // }
+
+        // if (threadIdx.x == 0 && blockIdx.x == 0) {
+        //     printf("next stage %d\n", offset_s_);
+        // }
+
+        if (dst_offset_ >= Stages * kSizePerTile) {
+            dst_offset_ -= Stages * kSizePerTile;
+        }
+
+        // if constexpr (Chained) {
+        //     bool is_last_stage = *signal_iterator_;
+
+        //     ++signal_iterator_;
+
+        //     if (is_last_stage) {
+        //         AdvancePrefetchSlice();
+        //     }
+        // }
+    }
+
+#if 0
+    __device__ void AdvancePrefetchSlice()
+    {
+        src_        = (const T*)block_iterator_.Next();
+        src_offset_ = init_offset_;
+
+        ++iter_b_;
+        offset_s_   = iter_b_ / 2 * BlockLen + init_offset_ / ThreadMap::kC;
+        is_valid_s_ = offset_s_ < seq_len_;
+    }
+#endif
+
+    static __device__ void CpAsync(T* dst, const T* src, bool mask)
+    {
+        const int     smem_int_ptr = cast_smem_ptr_to_uint(dst);
+        constexpr int cp_size      = sizeof(AccessType);
+        static_assert(cp_size == 16);
+        // predicated cp.async.ca.shared.global.L2::128B
+        asm volatile("{\n"
+                     "  .reg .pred p;\n"
+                     "  setp.ne.b32 p, %0, 0;\n"
+                     "  @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;\n"
+                     "}\n" ::"r"((int)mask),
+                     "r"(smem_int_ptr),
+                     "l"(src),
+                     "n"(cp_size));
+    }
+
+    static __device__ void Copy(T* dst, const T* src, bool mask)
+    {
+        if (mask) {
+            Ldg(*(AccessType*)dst, src);
+        }
+    }
+
+    __device__ void Prefetch(bool mask)
+    {
+        // if (blockIdx.x == 0 && threadIdx.x == 0 && mask) {
+        //     int  c    = src_offset_ % ThreadMap::kC;
+        //     int  s    = src_offset_ / ThreadMap::kC;
+        //     bool oob = src_offset_ >= 128 * 4096;
+        //     printf("%d %d %d %d %s\n", (int)threadIdx.x, c, s, offset_s_, oob ? "OOB" : "");
+        // }
+
+        // if (blockIdx.x == 0 && threadIdx.x == 0) {
+        //     int  c    = dst_offset_ % ThreadMap::kC;
+        //     int  s    = dst_offset_ / ThreadMap::kC;
+        //     bool oob = (dst_offset_ >= Stages * kSizePerTile);
+        //     printf("%d %d %d %s\n", c, s, dst_offset_, oob ? "OOB" : "");
+        // }
+
+        // if (init_offset_ / ThreadMap::kC == 0) {
+        //     int k = dst_offset_ / (ThreadMap::kS * ThreadMap::kC);
+        //     int s = dst_offset_ % (ThreadMap::kS * ThreadMap::kC) / ThreadMap::kC;
+        //     int c = dst_offset_ % ThreadMap::kC;
+        //     printf("tid=%d, k=%d, s=%d, c=%d, offset_s=%d, valid_s=%d, init_s=%d, mask=%d\n",
+        //            threadIdx.x,
+        //            k,
+        //            s,
+        //            c,
+        //            offset_s_,
+        //            (int)is_valid_s_,
+        //            init_offset_ / ThreadMap::kC,
+        //            (int)mask);
+        // }
+
+        // CpAsync(smem_ + dst_offset_, src_ + src_offset_, mask);
+        Copy(smem_ + dst_offset_, src_ + src_offset_, mask);
+    }
+
+    __device__ void Load(AccessType (&frag)[ThreadMap::kIterC])
+    {
+
+        // if (init_offset_ / ThreadMap::kC == 0) {
+        //     int k = smem_read_offset_ / (ThreadMap::kS * ThreadMap::kC);
+        //     int s = smem_read_offset_ % (ThreadMap::kS * ThreadMap::kC) / ThreadMap::kC;
+        //     int c = smem_read_offset_ % ThreadMap::kC;
+        //     printf("tid=%d, k=%d, s=%d, c=%d, init_s=%d\n", threadIdx.x, k, s, c, init_offset_ / ThreadMap::kC);
+        // }
+
+        for (int vi = 0; vi < ThreadMap::kIterC; ++vi) {
+
+            // int offset = smem_read_offset_ + vi * ThreadMap::kDeltaC;
+            // if (offset >= Stages * kSizePerTile || offset % sizeof(AccessType)) {
+            //     int c = offset % ThreadMap::kC;
+            //     int s = offset / ThreadMap::kC;
+            //     printf("%d %d %d\n", c, s, offset);
+            // }
+
+            Lds(frag[vi], smem_ + smem_read_offset_ + vi * ThreadMap::kDeltaC);
+        }
+
+        smem_read_offset_ += ThreadMap::kDeltaS * ThreadMap::kC;
+    }
+
+    __device__ void AdvanceComputeStage()
+    {
+        smem_read_offset_ += kStepK;
+
+        if (smem_read_offset_ >= Stages * kSizePerTile) {
+            smem_read_offset_ -= Stages * kSizePerTile;
+        }
+    }
+};
+
+}  // namespace turbomind
\ No newline at end of file
diff --git a/src/turbomind/kernels/decoder_mha/kv_cache.cu b/src/turbomind/kernels/decoder_mha/kv_cache.cu
new file mode 100644
index 0000000000..ad106db5f8
--- /dev/null
+++ b/src/turbomind/kernels/decoder_mha/kv_cache.cu
@@ -0,0 +1,149 @@
+#include "../gemm_s_f16/common.h"
+// #include "cute/tensor.hpp"
+#include <cuda_fp16.h>
+
+namespace turbomind {
+
+// [S/x, H, x, D] <-> [S/y, H, y, D]
+
+template<typename T>
+__device__ void ConvertBlockSize(const T** src_block_ptrs,
+                                 T**       dst_block_ptrs,
+                                 int       src_block_size,
+                                 int       dst_block_size,
+                                 int       heads,
+                                 int       dims,
+                                 int       seq_len)
+{
+    constexpr int kVecSize = sizeof(uint4) / sizeof(T);
+
+    size_t count = (size_t)heads * seq_len * dims;
+
+    for (size_t i = (threadIdx.x + blockIdx.x * blockDim.x) * kVecSize; i < count;
+         i += blockDim.x * gridDim.x * kVecSize) {
+        // get coords from [H, S, D]
+        int di = i % dims;
+        int ii = i / dims;
+
+        int si = ii % seq_len;
+        int hi = ii / seq_len;
+
+        // compute indices into src
+        int src_block_index  = si / src_block_size;
+        int src_block_offset = hi * src_block_size * dims + si % src_block_size * dims + di;
+
+        // compute indices into dst
+        int dst_block_index  = si / dst_block_size;
+        int dst_block_offset = hi * dst_block_size * dims + si % dst_block_size * dims + di;
+
+        const T* src_block = src_block_ptrs[src_block_index];
+        T*       dst_block = dst_block_ptrs[dst_block_index];
+
+        uint4 data = __ldg(reinterpret_cast<const uint4*>(src_block + src_block_offset));
+
+        // __stcg(reinterpret_cast<uint4*>(dst_block + dst_block_offset), data);
+        *reinterpret_cast<uint4*>(dst_block + dst_block_offset) = data;
+    }
+}
+
+template<typename T>
+__global__ void
+LinearToBlocksKernel(const T* src, T** dst_block_ptrs, int dst_block_size, int heads, int dims, int seq_len)
+{
+    __shared__ const T* src_block_ptr[1];
+
+    if (threadIdx.x == 0) {
+        src_block_ptr[0] = src;
+    }
+
+    __syncthreads();
+
+    ConvertBlockSize(src_block_ptr, dst_block_ptrs, seq_len, dst_block_size, heads, dims, seq_len);
+}
+
+template<typename T>
+__global__ void
+BlocksToLinearKernel(const T** src_block_ptrs, T* dst, int src_block_size, int heads, int dims, int seq_len)
+{
+    __shared__ T* dst_block_ptr[1];
+
+    if (threadIdx.x == 0) {
+        dst_block_ptr[0] = dst;
+    }
+
+    __syncthreads();
+
+    ConvertBlockSize(src_block_ptrs, dst_block_ptr, src_block_size, seq_len, heads, dims, seq_len);
+}
+
+template<typename T>
+__global__ void BlocksToBlocksKernel(const T** src_block_ptrs,
+                                     T**       dst_block_ptrs,
+                                     int       src_block_size,
+                                     int       dst_block_size,
+                                     int       heads,
+                                     int       dims,
+                                     int       seq_len)
+{
+    ConvertBlockSize(src_block_ptrs, dst_block_ptrs, src_block_size, dst_block_size, heads, dims, seq_len);
+}
+
+template<typename T>
+void ConvertLinearToBlocks(
+    const T* src, T** dst_block_ptrs, int dst_block_size, int heads, int dims, int seq_len, cudaStream_t st)
+{
+    constexpr int kVecSize = sizeof(uint4) / sizeof(T);
+
+    int threads = 512;
+    int blocks  = std::min(512, (heads * seq_len * dims / kVecSize + threads - 1) / threads);
+
+    LinearToBlocksKernel<<<blocks, threads, 0, st>>>(src, dst_block_ptrs, dst_block_size, heads, dims, seq_len);
+}
+
+template<typename T>
+void ConvertBlocksToLinear(
+    const T** src_block_ptrs, T* dst, int src_block_size, int heads, int dims, int seq_len, cudaStream_t st)
+{
+    constexpr int kVecSize = sizeof(uint4) / sizeof(T);
+
+    int threads = 256;
+    int blocks  = (heads * seq_len * dims / kVecSize + threads - 1) / threads;
+
+    BlocksToLinearKernel<<<blocks, threads, 0, st>>>(src_block_ptrs, dst, src_block_size, heads, dims, seq_len);
+}
+
+template<typename T>
+void ConvertBlocksToBlocks(const T**    src_block_ptrs,
+                           T**          dst_block_ptrs,
+                           int          src_block_size,
+                           int          dst_block_size,
+                           int          heads,
+                           int          dims,
+                           int          seq_len,
+                           cudaStream_t st)
+{
+    constexpr int kVecSize = sizeof(uint4) / sizeof(T);
+
+    int threads = 512;
+    int blocks  = std::min(512, (heads * seq_len * dims / kVecSize + threads - 1) / threads);
+
+    BlocksToBlocksKernel<<<blocks, threads, 0, st>>>(
+        src_block_ptrs, dst_block_ptrs, src_block_size, dst_block_size, heads, dims, seq_len);
+}
+
+template void ConvertLinearToBlocks(
+    const half* src, half** dst_block_ptrs, int dst_block_size, int heads, int dims, int seq_len, cudaStream_t st);
+
+template void ConvertBlocksToLinear(
+    const half** src_block_ptrs, half* dst, int src_block_size, int heads, int dims, int seq_len, cudaStream_t st);
+
+template void ConvertBlocksToBlocks(const half** src_block_ptrs,
+                                    half**       dst_block_ptrs,
+                                    int          src_block_size,
+                                    int          dst_block_size,
+                                    int          heads,
+                                    int          dims,
+                                    int          seq_len,
+                                    cudaStream_t st);
+
+}  // namespace turbomind
diff --git a/src/turbomind/kernels/decoder_mha/kv_cache.h b/src/turbomind/kernels/decoder_mha/kv_cache.h
new file mode 100644
index 0000000000..72758e4b08
--- /dev/null
+++ b/src/turbomind/kernels/decoder_mha/kv_cache.h
@@ -0,0 +1,25 @@
+#pragma once
+
+#include <cuda_runtime.h>
+
+namespace turbomind {
+
+template<typename T>
+void ConvertLinearToBlocks(
+    const T* src, T** dst_block_ptrs, int dst_block_size, int heads, int dims, int seq_len, cudaStream_t st);
+
+template<typename T>
+void ConvertBlocksToLinear(
+    const T** src_block_ptrs, T* dst, int src_block_size, int heads, int dims, int seq_len, cudaStream_t st);
+
+template<typename T>
+void ConvertBlocksToBlocks(const T**    src_block_ptrs,
+                           T**          dst_block_ptrs,
+                           int          src_block_size,
+                           int          dst_block_size,
+                           int          heads,
+                           int          dims,
+                           int          seq_len,
+                           cudaStream_t st);
+
+}  // namespace turbomind
\ No newline at end of file
diff --git a/src/turbomind/kernels/decoder_mha/test_decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_mha/test_decoder_multihead_attention.cu
new file mode 100644
index 0000000000..055c66ce48
--- /dev/null
+++ b/src/turbomind/kernels/decoder_mha/test_decoder_multihead_attention.cu
@@ -0,0 +1,199 @@
+
+
+#include "decoder_multihead_attention.h"
+#include "kv_cache.h"
+#include "test_utils.h"
+#include <cmath>
+#include <iostream>
+#include <thrust/universal_vector.h>
+
+#include <numeric>
+
+using namespace turbomind;
+
+template<typename T>
+T* align(T* ptr, size_t alignment)
+{
+    size_t misalign = (uintptr_t)ptr % alignment;
+    std::cout << "misalignment: " << misalign << "\n";
+    if (misalign) {
+        return (T*)((uint8_t*)ptr + alignment - misalign);
+    }
+    return ptr;
+}
+
+// [S/S, H, S, D] <-> [S/b, H, b, D]
+
+void TestBlocks(thrust::universal_vector<half>& linear,
+                thrust::universal_vector<half>& _blocks,
+                thrust::universal_vector<half*> _ptrs,
+                int                             head_num,
+                int                             head_dim,
+                int                             block_size)
+{
+    int seq_len  = linear.size() / head_num / head_dim;
+    int n_blocks = (seq_len + block_size - 1) / block_size;
+
+    std::cout << "seq_len = " << seq_len << ", block_num = " << n_blocks << ", block_size = " << block_size << "\n";
+
+    thrust::universal_vector<half>  blocks(n_blocks * head_num * block_size * head_dim);
+    thrust::universal_vector<half*> ptrs(n_blocks);
+
+    std::vector<size_t> idxs(n_blocks);
+    std::iota(idxs.begin(), idxs.end(), 0);
+
+    std::random_shuffle(idxs.begin(), idxs.end());
+
+    for (int i = 0; i < n_blocks; ++i) {
+        ptrs[i] = blocks.data().get() + idxs[i] * head_num * block_size * head_dim;
+    }
+
+    for (int i = 0; i < 10; ++i) {
+        ConvertLinearToBlocks(
+            (const half*)linear.data().get(), ptrs.data().get(), block_size, head_num, head_dim, seq_len, 0);
+    }
+    thrust::universal_vector<half> _linear(linear.size());
+
+    for (int i = 0; i < 10; ++i) {
+        ConvertBlocksToLinear(
+            (const half**)ptrs.data().get(), _linear.data().get(), block_size, head_num, head_dim, seq_len, 0);
+    }
+    cudaDeviceSynchronize();
+
+    Compare(_linear.data().get(), linear.data().get(), head_dim, head_num * seq_len);
+    exit(0);
+}
+
+int main(int argc, char* argv[])
+{
+    DecoderMultiHeadAttentionParams<half> params{};
+
+    // constexpr int kHeadNum = 108 * 4;
+    constexpr int kHeadNum    = 32;
+    constexpr int kHeadDim    = 128;
+    constexpr int kBatchSize  = 1;
+    constexpr int kContextLen = 8192;
+    constexpr int kTestIter   = 1;
+
+    RNG rng{};
+
+    thrust::universal_vector<half>  output(kBatchSize * kHeadNum * kHeadDim);
+    thrust::universal_vector<half>  qkv(kBatchSize * kHeadNum * 3 * kHeadDim);
+    thrust::universal_vector<bool>  finished(kBatchSize);
+    thrust::universal_vector<half>  k_cache(kBatchSize * (kContextLen + 1) * kHeadNum * kHeadDim);
+    thrust::universal_vector<half>  v_cache(kBatchSize * (kContextLen + 1) * kHeadNum * kHeadDim);
+    thrust::universal_vector<int>   sequence_lengths(kBatchSize);
+    thrust::universal_vector<void*> k_cache_ptrs(kBatchSize);
+    thrust::universal_vector<void*> v_cache_ptrs(kBatchSize);
+
+    rng.GenerateNormal(qkv.data().get(), qkv.size(), 1.f, 0.f);
+
+    if (kContextLen) {
+        rng.GenerateNormal(k_cache.data().get(), kContextLen * kHeadNum * kHeadDim);
+        rng.GenerateNormal(v_cache.data().get(), kContextLen * kHeadNum * kHeadDim);
+    }
+
+    thrust::universal_vector<half>  k_blocks;
+    thrust::universal_vector<half*> k_ptrs;
+
+    TestBlocks(k_cache, k_blocks, k_ptrs, kHeadNum, kHeadDim, 128);
+
+    thrust::universal_vector<half>  k_cache_ref = k_cache;
+    thrust::universal_vector<half>  v_cache_ref = v_cache;
+    thrust::universal_vector<half>  output_ref  = output;
+    thrust::universal_vector<void*> k_cache_ref_ptrs(kBatchSize);
+    thrust::universal_vector<void*> v_cache_ref_ptrs(kBatchSize);
+
+    cudaDeviceSynchronize();
+
+    for (int i = 0; i < kBatchSize; ++i) {
+        sequence_lengths[i] = kContextLen;
+        k_cache_ptrs[i]     = k_cache.data().get() + i * k_cache.size() / kBatchSize;
+        v_cache_ptrs[i]     = v_cache.data().get() + i * v_cache.size() / kBatchSize;
+        k_cache_ref_ptrs[i] = k_cache_ref.data().get() + i * k_cache_ref.size() / kBatchSize;
+        v_cache_ref_ptrs[i] = v_cache_ref.data().get() + i * v_cache_ref.size() / kBatchSize;
+
+        align(k_cache_ptrs[i], 256);
+        align(v_cache_ptrs[i], 256);
+    }
+
+    // getchar();
+
+    params.out    = output_ref.data().get();
+    params.q      = qkv.data().get();
+    params.k      = params.q + kHeadNum * kHeadDim;
+    params.v      = params.k + kHeadNum * kHeadDim;
+    params.stride = 3 * kHeadNum * kHeadDim;
+
+    params.batch_size   = kBatchSize;
+    params.max_seq_len  = kContextLen + 1;
+    params.max_timestep = kContextLen;
+
+    params.finished           = finished.data().get();
+    params.per_sample_length  = sequence_lengths.data().get();
+    params.per_sample_k_cache = k_cache_ref_ptrs.data().get();
+    params.per_sample_v_cache = v_cache_ref_ptrs.data().get();
+
+    params.per_sample_kv_cache_offset = 0;
+
+    params.num_heads     = kHeadNum;
+    params.num_kv_heads  = kHeadNum;
+    params.size_per_head = kHeadDim;
+    params.inv_sqrt_dh   = 1.f / std::sqrt((float)params.size_per_head);
+
+    params.rotary_embedding_dim  = kHeadDim;
+    params.rotary_embedding_base = 10000.f;
+
+    for (int i = 0; i < kTestIter; ++i) {
+        mmha_ft_reference(params, cudaStream_t{});
+    }
+
+    cudaDeviceSynchronize();
+    // if (auto err = cudaGetLastError(); err != cudaSuccess) {
+    //     std::cout << cudaGetErrorString(err) << "\n";
+    //     return -1;
+    // }
+    std::cout << "---------------------------------------------------\n";
+
+    params.out                = output.data().get();
+    params.per_sample_k_cache = k_cache_ptrs.data().get();
+    params.per_sample_v_cache = v_cache_ptrs.data().get();
+
+    std::vector<thrust::universal_vector<half>> outputs;
+
+    for (int i = 0; i < std::max(kTestIter, 10); ++i) {
+        LaunchDecoderMultiheadAttention<half, 128>(params);
+        if (auto err = cudaGetLastError(); err != cudaSuccess) {
+            std::cout << cudaGetErrorString(err) << "\n";
+            return -1;
+        }
+        if (1) {
+            outputs.push_back(output);
+        }
+    }
+
+    cudaDeviceSynchronize();
+
+    if (outputs.size() > 1) {
+        std::cout << "Evaluating consistency..." << std::endl;
+        for (size_t i = 1; i < outputs.size(); ++i) {
+            Compare(outputs[i].data().get(), outputs[0].data().get(), kHeadDim, kHeadNum);
+        }
+    }
+
+    std::cout << "---------------------------------------------------\n";
+
+    Compare(output.data().get(), output_ref.data().get(), kHeadDim, kHeadNum, 0);
+
+    Compare(v_cache.data().get() + (kContextLen - 0) * kHeadNum * kHeadDim,
+            v_cache_ref.data().get() + (kContextLen - 0) * kHeadNum * kHeadDim,
+            kHeadDim,
+            kHeadNum);
+
+    Compare(k_cache.data().get() + (kContextLen - 0) * kHeadNum * kHeadDim,
+            k_cache_ref.data().get() + (kContextLen - 0) * kHeadNum * kHeadDim,
+            kHeadDim,
+            kHeadNum);
+
+    return 0;
+}
\ No newline at end of file
diff --git a/src/turbomind/kernels/decoder_mha/test_utils.cu b/src/turbomind/kernels/decoder_mha/test_utils.cu
new file mode 100644
index 0000000000..3cf4262179
--- /dev/null
+++ b/src/turbomind/kernels/decoder_mha/test_utils.cu
@@ -0,0 +1,236 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#include "test_utils.h"
+#include <cublas_v2.h>
+#include <curand.h>
+#include <curand_kernel.h>
+#include <fstream>
+#include <iostream>
+
+#define _CG_ABI_EXPERIMENTAL
+#include <cooperative_groups.h>
+#include <cooperative_groups/memcpy_async.h>
+#include <cooperative_groups/reduce.h>
+
+#include "src/turbomind/kernels/decoder_masked_multihead_attention.h"
+
+namespace turbomind {
+
+cublasHandle_t cublas_handle{};
+cudaStream_t   cublas_stream{};
+
+template<typename T>
+void Compare(const T* c, const T* c_ref, int m, int n, bool show, float rtol, float atol)
+{
+    float asums{};
+    float rsums{};
+    int   outliers{};
+    for (int nn = 0; nn < n; ++nn) {
+        float abs_diff_sum{};
+        float rel_diff_sum{};
+        for (int mm = 0; mm < m; ++mm) {
+            auto x = float(c[nn * m + mm]);
+            auto y = float(c_ref[nn * m + mm]);
+            // if (show) {
+            //     std::cout << x << "\t" << y << std::endl;
+            // }
+            auto abs_diff = std::abs(x - y);
+            auto rel_diff = abs_diff / (std::abs(y) + 1e-6f);
+            if (abs_diff > atol + rtol * std::abs(y)) {
+                ++outliers;
+                if (show) {
+                    std::cout << nn << "," << mm << "\t" << x << "\t" << y << std::endl;
+                }
+            }
+            abs_diff_sum += abs_diff;
+            rel_diff_sum += rel_diff;
+        }
+        asums += abs_diff_sum / m;
+        rsums += rel_diff_sum / m;
+    }
+    std::cout << "abs_diff = " << asums / n << " rel_diff = " << rsums / n << " outliers = " << outliers / (float)n
+              << std::endl;
+}
+
+template void Compare(const half* c, const half* c_ref, int m, int n, bool show, float rtol, float atol);
+template void Compare(const float* c, const float* c_ref, int m, int n, bool show, float rtol, float atol);
+
+void LoadBinary(const std::string& path, size_t size, void* dst)
+{
+    std::ifstream ifs(path, std::ios::binary | std::ios::in);
+    if (!ifs.is_open()) {
+        std::cerr << "failed to open " << path << "\n";
+        std::abort();
+    }
+    ifs.seekg(0, ifs.end);
+    auto actual_size_in_bytes = ifs.tellg();
+    ifs.seekg(0, ifs.beg);
+    if (size != actual_size_in_bytes) {
+        std::cerr << "[warning] file " << path << " has " << actual_size_in_bytes << " bytes, while " << size
+                  << " bytes is requested\n";
+    }
+    ifs.read((char*)dst, size);
+    std::cerr << "[info] " << path << " " << size << "\n";
+}
+
+namespace cg = cooperative_groups;
+
+__global__ void curand_init(curandState* state)
+{
+    auto tid = cg::this_grid().thread_rank();
+    curand_init(0xe4c45822e90461ddULL, tid, 0, state + tid);
+}
+
+template<typename T>
+__global__ void curand_uniform(curandState* state, size_t count, T* result, float scale, float shift)
+{
+    auto grid = cg::this_grid();
+    for (auto i = grid.thread_rank(); i < count; i += grid.size()) {
+        float tmp = curand_uniform(state + grid.thread_rank());
+        result[i] = T(scale * tmp + shift);
+    }
+}
+
+template<typename T>
+__global__ void curand_normal(curandState* state, size_t count, T* result, float scale, float shift)
+{
+    auto grid = cg::this_grid();
+    for (auto i = grid.thread_rank(); i < count; i += grid.size()) {
+        float tmp = curand_normal(state + grid.thread_rank());
+        result[i] = T(scale * tmp + shift);
+    }
+}
+
+__global__ void curand_bytes(curandState* state, size_t count, uint* result)
+{
+    auto grid = cg::this_grid();
+    for (auto i = grid.thread_rank(); i < count; i += grid.size()) {
+        result[i] = curand(state + grid.thread_rank());
+    }
+}
+
+struct RNG::Impl {
+
+    curandState* states{};
+
+    Impl()
+    {
+        cudaMalloc(&states, sizeof(curandState) * 64 * 64);
+        curand_init<<<64, 64>>>(states);
+    }
+
+    ~Impl()
+    {
+        cudaFree(states);
+    }
+
+    void GenerateUInt(uint* out, size_t count)
+    {
+        curand_bytes<<<64, 64>>>(states, count, out);
+    }
+
+    template<typename T>
+    void GenerateUniform(T* out, size_t count, float scale, float shift)
+    {
+        curand_uniform<<<64, 64>>>(states, count, out, scale, shift);
+    }
+
+    template<typename T>
+    void GenerateNormal(T* out, size_t count, float scale, float shift)
+    {
+        curand_normal<<<64, 64>>>(states, count, out, scale, shift);
+    }
+};
+
+RNG::RNG(): impl_(std::make_unique<Impl>()) {}
+
+RNG::~RNG() = default;
+
+void RNG::GenerateUInt(uint* out, size_t count)
+{
+    impl_->GenerateUInt(out, count);
+}
+
+template<typename T>
+void RNG::GenerateUniform(T* out, size_t count, float scale, float shift)
+{
+    std::cout << count << std::endl;
+    impl_->GenerateUniform(out, count, scale, shift);
+}
+
+template<typename T>
+void RNG::GenerateNormal(T* out, size_t count, float scale, float shift)
+{
+    impl_->GenerateNormal(out, count, scale, shift);
+}
+
+template void RNG::GenerateUniform(half* out, size_t count, float scale, float shift);
+template void RNG::GenerateUniform(float* out, size_t count, float scale, float shift);
+
+template void RNG::GenerateNormal(half* out, size_t count, float scale, float shift);
+template void RNG::GenerateNormal(float* out, size_t count, float scale, float shift);
+
+template<typename T>
+struct SATypeConverter {
+    using Type = T;
+};
+
+template<>
+struct SATypeConverter<half> {
+    using Type = uint16_t;
+};
+
+template<typename T>
+void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& p, cudaStream_t st)
+{
+    using DataType = typename SATypeConverter<T>::Type;
+
+    // Prepare the parameters.
+    Masked_multihead_attention_params<DataType> params{};
+    params.q_bias = reinterpret_cast<const DataType*>(p.q_bias);
+    params.k_bias = reinterpret_cast<const DataType*>(p.k_bias);
+    params.v_bias = reinterpret_cast<const DataType*>(p.v_bias);
+
+    // Set the output buffer.
+    params.out = reinterpret_cast<DataType*>(p.out);
+
+    // Set the input buffers.
+    // [B, nH + kvH, D]
+    params.q = reinterpret_cast<const DataType*>(p.q);
+    params.k = reinterpret_cast<const DataType*>(p.k);
+    params.v = reinterpret_cast<const DataType*>(p.v);
+
+    params.stride   = p.stride;
+    params.finished = p.finished;
+
+    params.k_cache_per_sample         = reinterpret_cast<DataType**>(p.per_sample_k_cache);
+    params.v_cache_per_sample         = reinterpret_cast<DataType**>(p.per_sample_v_cache);
+    params.kv_cache_per_sample_offset = p.per_sample_kv_cache_offset;
+    params.batch_size                 = p.batch_size;
+    params.beam_width                 = 1;
+    params.memory_max_len             = p.max_seq_len;
+    params.prefix_prompt_lengths      = 0;
+    params.max_prefix_prompt_length   = 0;
+    params.length_per_sample          = p.per_sample_length;  // max_input_length + current output length
+    // timestep adding max_prefix_prompt_length for shared memory size calculation and rotary embedding computation
+    params.timestep     = p.max_timestep;  // was step - 1
+    params.num_heads    = p.num_heads;
+    params.num_kv_heads = p.num_kv_heads;
+
+    params.hidden_size_per_head    = p.size_per_head;
+    params.rotary_embedding_dim    = p.rotary_embedding_dim;
+    params.max_position_embeddings = p.max_position_embeddings;
+    params.use_dynamic_ntk         = p.use_dynamic_ntk;
+    params.use_logn_attn           = p.use_logn_attn;
+
+    // Note: keep norm factor (sqrt(K_dim)) when adopting megatron T5 structure (may adjust)
+    params.inv_sqrt_dh = 1.F / (sqrtf((float)params.hidden_size_per_head) * 1.f);
+
+    params.int8_mode = 0;
+
+    masked_multihead_attention(params, st);
+}
+
+template void mmha_ft_reference(const DecoderMultiHeadAttentionParams<half>& params, cudaStream_t st);
+
+}  // namespace turbomind
\ No newline at end of file
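
The outlier count reported by Compare uses a combined absolute/relative tolerance. A scalar restatement of the criterion, shown here for illustration only, with the default tolerances declared in test_utils.h:

    #include <cmath>

    // An element x is counted as an outlier against reference y when
    //     |x - y| > atol + rtol * |y|
    // i.e. the usual mixed absolute/relative closeness test.
    inline bool IsOutlier(float x, float y, float rtol = 1e-2f, float atol = 1e-4f)
    {
        return std::fabs(x - y) > atol + rtol * std::fabs(y);
    }
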
diff --git a/src/turbomind/kernels/decoder_mha/test_utils.h b/src/turbomind/kernels/decoder_mha/test_utils.h
new file mode 100644
index 0000000000..16cd1fd69e
--- /dev/null
+++ b/src/turbomind/kernels/decoder_mha/test_utils.h
@@ -0,0 +1,36 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#pragma once
+
+#include "decoder_multihead_attention.h"
+#include <cuda_fp16.h>
+#include <memory>
+
+namespace turbomind {
+
+template<typename T>
+void Compare(const T* c, const T* c_ref, int m, int n, bool show = false, float rtol = 1e-2, float atol = 1e-4);
+
+void LoadBinary(const std::string& path, size_t size, void* dst);
+
+class RNG {
+public:
+    RNG();
+    ~RNG();
+    void GenerateUInt(uint* out, size_t count);
+
+    template<typename T>
+    void GenerateUniform(T* out, size_t count, float scale = 1.f, float shift = 0.f);
+
+    template<typename T>
+    void GenerateNormal(T* out, size_t count, float scale = 1.f, float shift = 0.f);
+
+private:
+    struct Impl;
+    std::unique_ptr<Impl> impl_;
+};
+
+template<typename T>
+void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& params, cudaStream_t st);
+
+}  // namespace turbomind
diff --git a/src/turbomind/kernels/decoder_mha/thread_map.h b/src/turbomind/kernels/decoder_mha/thread_map.h
new file mode 100644
index 0000000000..0968681c77
--- /dev/null
+++ b/src/turbomind/kernels/decoder_mha/thread_map.h
@@ -0,0 +1,96 @@
+#pragma once
+#include "../gemm_s_f16/common.h"
+#include "src/turbomind/kernels/custom_ar_kernels.h"
+
+namespace turbomind {
+
+template<int C, int S, int AccessC, int WarpCount>
+struct ThreadMapQ {
+    static constexpr int kWarpCount = WarpCount;
+    static constexpr int kAccessC   = AccessC;
+
+    static constexpr int kWarpThreadC = C / kAccessC;
+    static constexpr int kWarpThreadS = WARP_SIZE / kWarpThreadC;
+
+    static_assert(kWarpThreadC <= WARP_SIZE);
+
+    static constexpr int kWarpAccessC = kWarpThreadC * kAccessC;  // C
+    static constexpr int kWarpAccessS = kWarpThreadS;
+
+    static constexpr int kWarpIterC = C / kWarpAccessC;  // 1
+    static constexpr int kWarpIterS = S / kWarpAccessS;
+
+    static constexpr int kWarpC = 1;
+    static constexpr int kWarpS = kWarpCount;
+
+    static constexpr int kIterC = kWarpIterC / kWarpC;  // 1
+    static constexpr int kIterS = std::max(kWarpIterS / kWarpS, 1);
+
+    static constexpr int kFootprintC = kWarpAccessC * kIterC;  // C
+    static constexpr int kFootprintS = kWarpAccessS * kIterS;
+
+    static constexpr int kDeltaC = kWarpAccessC;
+    static constexpr int kDeltaS = kWarpAccessS;
+
+    __device__ static int2 get_offset(int warp_id, int lane_id)
+    {
+        int warp_offset_c = warp_id % kWarpC;
+        int warp_offset_s = warp_id / kWarpC;
+
+        int warp_thread_offset_c = lane_id % kWarpThreadC;
+        int warp_thread_offset_s = lane_id / kWarpThreadC;
+
+        int cta_thread_offset_c = kFootprintC * warp_offset_c + warp_thread_offset_c * kAccessC;
+        int cta_thread_offset_s = kFootprintS * warp_offset_s + warp_thread_offset_s;
+
+        return {cta_thread_offset_c, cta_thread_offset_s};
+    }
+};
+
+template<int C, int S, int AccessC, int WarpThreadC, int WarpCount>
+struct ThreadMapKv {
+    static constexpr int kC = C;
+    static constexpr int kS = S;
+
+    static constexpr int kWarpCount = WarpCount;
+    static constexpr int kAccessC   = AccessC;
+
+    static constexpr int kWarpThreadC = WarpThreadC;
+    static constexpr int kWarpThreadS = WARP_SIZE / kWarpThreadC;
+
+    static_assert(kWarpThreadC <= WARP_SIZE);
+
+    static constexpr int kWarpAccessC = kWarpThreadC * kAccessC;
+    static constexpr int kWarpAccessS = kWarpThreadS;
+
+    static constexpr int kWarpIterC = C / kWarpAccessC;
+    static constexpr int kWarpIterS = S / kWarpAccessS;
+
+    static constexpr int kWarpC = 1;
+    static constexpr int kWarpS = kWarpCount;
+
+    static constexpr int kIterC = kWarpIterC / kWarpC;
+    static constexpr int kIterS = std::max(kWarpIterS / kWarpS, 1);
+
+    static constexpr int kFootprintC = kWarpAccessC * kIterC;
+    static constexpr int kFootprintS = kWarpAccessS * kIterS;
+
+    static constexpr int kDeltaC = kWarpAccessC;
+    static constexpr int kDeltaS = kWarpAccessS;
+
+    __device__ static int2 get_offset(int warp_id, int lane_id)
+    {
+        int warp_offset_c = warp_id % kWarpC;
+        int warp_offset_s = warp_id / kWarpC;
+
+        int warp_thread_offset_c = lane_id % kWarpThreadC;
+        int warp_thread_offset_s = lane_id / kWarpThreadC;
+
+        int cta_thread_offset_c = kFootprintC * warp_offset_c + warp_thread_offset_c * kAccessC;
+        int cta_thread_offset_s = kFootprintS * warp_offset_s + warp_thread_offset_s;
+
+        return {cta_thread_offset_c, cta_thread_offset_s};
+    }
+};
+
+}  // namespace turbomind
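
To make the mapping concrete, the derived constants are spelled out below for one plausible instantiation of ThreadMapKv (head dim C = 128, a tile of S = 32 keys, 8 half elements per access, 16 threads per key, 4 warps). These template arguments are assumptions chosen for illustration and compile only inside the repo; they are not necessarily the values the attention kernel instantiates.

    #include "src/turbomind/kernels/decoder_mha/thread_map.h"

    using Map = turbomind::ThreadMapKv<128, 32, 8, 16, 4>;

    // 16 threads x 8 elements cover the full 128-wide head dim, leaving 2 thread
    // rows per warp along S (2 keys per access); with kIterS = 4 each warp's
    // footprint spans 8 consecutive keys, so 4 warps tile the 32-key slice.
    static_assert(Map::kWarpAccessC == 128 && Map::kWarpAccessS == 2, "per-warp access shape");
    static_assert(Map::kIterC == 1 && Map::kIterS == 4, "iterations per thread");
    static_assert(Map::kFootprintS == 8 && Map::kDeltaS == 2, "warp footprint and step along S");
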

From d1e1c486a67c40c93e7e873c3edbaad24135720f Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Mon, 11 Sep 2023 09:19:24 +0000
Subject: [PATCH 02/56] block-level kv cache

---
 .../decoder_multihead_attention.cu            |   2 +-
 .../decoder_multihead_attention_params.h      |  15 ++-
 .../decoder_multihead_attention_template.h    | 114 +++++++++++++-----
 src/turbomind/kernels/decoder_mha/iterator.h  | 107 ++++++++--------
 src/turbomind/kernels/decoder_mha/kv_cache.cu |   1 -
 .../test_decoder_multihead_attention.cu       | 102 +++++++++++-----
 .../kernels/decoder_mha/test_utils.cu         |  20 +--
 .../kernels/decoder_mha/test_utils.h          |   3 +-
 8 files changed, 233 insertions(+), 131 deletions(-)

diff --git a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention.cu
index f3a246d66f..ff81f88ce6 100644
--- a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention.cu
@@ -33,7 +33,7 @@ void LaunchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& p
     dim3 block(MHAType::kWarpCount * WARP_SIZE);
     dim3 grid(params.num_kv_heads, params.batch_size);
 
-    const size_t kDynamicSmemSize = MHAType::GetDynamicSmemSize(params.max_timestep);
+    const size_t kDynamicSmemSize = MHAType::GetDynamicSmemSize(0);
     std::cout << "dynamic shared memory size: " << kDynamicSmemSize << "\n";
 
     cudaFuncSetAttribute(
diff --git a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_params.h b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_params.h
index 4a6c6e8541..2cb2184643 100644
--- a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_params.h
+++ b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_params.h
@@ -25,15 +25,22 @@ struct DecoderMultiHeadAttentionParams {
     void** per_sample_v_cache;  // [H, S, D]
     size_t per_sample_kv_cache_offset;
 
+    /// cache layout M,[N,H,x,D]
+    /// S: [s0/x, s1/x, s2/x, ..., sn-1/x], si <- block
+    /// 1. [L,sum(S),H,x,D]
+    void** k_cache_block_ptrs;  // X,[H,x,D]
+    void** v_cache_block_ptrs;  // X,[H,x,D]
+    int*   cu_ctxlens;          // [B+1]
+    int    kv_cache_block_size;
+
     // batch-level params
     int batch_size;
     int max_seq_len;
-    int max_timestep;  // max_timestep in the batch, used to compute smem sizes
 
     // instance-level params
-    int num_heads;
-    int num_kv_heads;
-    int size_per_head;
+    int   num_heads;
+    int   num_kv_heads;
+    int   size_per_head;
     float inv_sqrt_dh;
 
     // rotary embedding
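
A sketch of the address arithmetic implied by the layout comment above, where each block stores an [H, block_size, D] chunk of one sequence. The helper name and signature are illustrative, distilled from the store path added to the kernel later in this patch; it is not code from the patch itself.

    // Resolve one (head, timestep, dim) element inside a block-structured cache.
    template<typename T>
    __device__ T* BlockCacheEntry(void** block_ptrs,  // this sequence's first block pointer
                                  int    block_size,  // kv_cache_block_size
                                  int    head_idx,
                                  int    timestep,
                                  int    head_dim,
                                  int    dim_idx)
    {
        int block_index  = timestep / block_size;  // which block along the sequence
        int block_offset = timestep % block_size;  // position inside that block
        T*  block        = (T*)block_ptrs[block_index] + head_idx * block_size * head_dim;
        return block + block_offset * head_dim + dim_idx;
    }
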
diff --git a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_template.h b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_template.h
index 68f188eca2..3dabd9042a 100644
--- a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_template.h
+++ b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_template.h
@@ -31,18 +31,17 @@ struct DecoderMultiHeadAttentionKernel {
     using VecKv      = Array<Dtype, kVecKvSize>;
     using VecKvFloat = Array<float, kVecKvSize>;
 
-    static constexpr bool kChainedKv = false;
+    static constexpr bool kUseBlockIter = true;
 
     using MapKv  = ThreadMapKv<kMaxHeadDim, kKeyPerIter, kVecKvSize, kThreadPerKey, kWarpCount>;
-    using IterKv = turbomind::Iterator<T, MapKv, SliceLen, kStages, kChainedKv>;
+    using IterKv = turbomind::Iterator<T, MapKv, SliceLen, kStages, kUseBlockIter>;
 
-    static size_t GetDynamicSmemSize(int max_timestep)
+    static size_t GetDynamicSmemSize(int)
     {
         size_t smem_kv_cache = IterKv::kSmemByteSize;
         size_t smem_kv_align = 128;
         size_t smem_qk       = sizeof(float) * kHeadPerCta * kSliceLen;
         size_t smem_pr       = sizeof(float) * kHeadPerCta * kSliceLen;
-        // size_t smem_ctrl_lut = ((max_timestep + KeyPerIter - 1) / KeyPerIter * 2 + 31) / 32 * sizeof(uint32_t);
         return smem_kv_align + smem_kv_cache + std::max(smem_qk, smem_pr);
     }
 
@@ -75,16 +74,18 @@ struct DecoderMultiHeadAttentionKernel {
     T*  k_cache_;  // [S, D]
     T*  v_cache_;  // [S, D]
 
-    Dtype*    smem_Kv_;
-    float*    smem_S_;
-    float*    smem_P_;
-    Dtype*    smem_Q_;
-    float*    smem_M_;
-    float*    smem_L_;
-    float*    smem_O_;
-    float*    smem_red_max_;
-    float*    smem_red_sum_;
-    unsigned* smem_ctrl_;
+    const void** k_cache_ptrs_;
+    const void** v_cache_ptrs_;
+
+    Dtype* smem_Kv_;
+    float* smem_S_;
+    float* smem_P_;
+    Dtype* smem_Q_;
+    float* smem_M_;
+    float* smem_L_;
+    float* smem_O_;
+    float* smem_red_max_;
+    float* smem_red_sum_;
 
     __device__ bool thread0()
     {
@@ -94,11 +95,9 @@ struct DecoderMultiHeadAttentionKernel {
     __device__ DecoderMultiHeadAttentionKernel(const ParamType& params, SharedStorage& smem, uint8_t* dsmem):
         params_(params)
     {
-        smem_Kv_ = (Dtype*)dsmem;
-        smem_S_  = (float*)(smem_Kv_ + IterKv::kSizePerTile * kStages);  // [HeadPerCta * kSliceLen]
-        // smem_P_       = (float*)(smem_S_ + kHeadPerCta * kSliceLen);          // [HeadPerCta * kSliceLen]
-        smem_P_       = smem_S_;
-        smem_ctrl_    = (unsigned*)(smem_P_ + kHeadPerCta * kSliceLen);  // [max_timestep / kKeyPerStep / 32]
+        smem_Kv_      = (Dtype*)dsmem;
+        smem_S_       = (float*)(smem_Kv_ + IterKv::kSizePerTile * kStages);  // [HeadPerCta * kSliceLen]
+        smem_P_       = smem_S_;  // ! reusing only works when S and P have the same dtype
         smem_Q_       = smem.Q;
         smem_M_       = smem.M;
         smem_L_       = smem.L;
@@ -113,11 +112,22 @@ struct DecoderMultiHeadAttentionKernel {
 
         timestep_ = params_.per_sample_length[batch_idx_];
 
-        /// TODO: block level kv cache
-        k_cache_ = (T*)params_.per_sample_k_cache[batch_idx_] + params.per_sample_kv_cache_offset
-                   + head_idx_ * params_.max_seq_len * params_.size_per_head;
-        v_cache_ = (T*)params_.per_sample_v_cache[batch_idx_] + params.per_sample_kv_cache_offset
-                   + head_idx_ * params_.max_seq_len * params_.size_per_head;
+        if constexpr (kUseBlockIter) {
+            k_cache_ptrs_ = params_.k_cache_block_ptrs + params_.cu_ctxlens[batch_idx_];
+            v_cache_ptrs_ = params_.v_cache_block_ptrs + params_.cu_ctxlens[batch_idx_];
+            // if (thread0()) {
+            //     printf("%d %p %p\n",
+            //            params_.cu_ctxlens[batch_idx_],
+            //            params_.k_cache_block_ptrs,
+            //            params_.v_cache_block_ptrs);
+            // }
+        }
+        else {
+            k_cache_ = (T*)params_.per_sample_k_cache[batch_idx_] + params.per_sample_kv_cache_offset
+                       + head_idx_ * params_.max_seq_len * params_.size_per_head;
+            v_cache_ = (T*)params_.per_sample_v_cache[batch_idx_] + params.per_sample_kv_cache_offset
+                       + head_idx_ * params_.max_seq_len * params_.size_per_head;
+        }
     }
 
     // [kkkk][vvvv][kkkk][vvvv][kkkk][vvvv][k][v]
@@ -219,8 +229,21 @@ struct DecoderMultiHeadAttentionKernel {
 
         // store
         if (warp_id_ == 0) {
-            Store(&k_cache_[timestep_ * kMaxHeadDim + offset.x], frag_K);
-            Store(&v_cache_[timestep_ * kMaxHeadDim + offset.x], frag_V);
+            if constexpr (kUseBlockIter) {
+                int block_index  = timestep_ / params_.kv_cache_block_size;
+                int block_offset = timestep_ % params_.kv_cache_block_size;
+                // if (thread0()) {
+                //     printf("%d %d %p %p\n", block_index, block_offset, k_cache_ptrs_, v_cache_ptrs_);
+                // }
+                k_cache_ = (T*)k_cache_ptrs_[block_index] + head_idx_ * params_.kv_cache_block_size * kHeadDim;
+                v_cache_ = (T*)v_cache_ptrs_[block_index] + head_idx_ * params_.kv_cache_block_size * kHeadDim;
+                Store(&k_cache_[block_offset * kHeadDim + offset.x], frag_K);
+                Store(&v_cache_[block_offset * kHeadDim + offset.x], frag_V);
+            }
+            else {
+                Store(&k_cache_[timestep_ * kHeadDim + offset.x], frag_K);
+                Store(&v_cache_[timestep_ * kHeadDim + offset.x], frag_V);
+            }
         }
     }
 
@@ -274,7 +297,22 @@ struct DecoderMultiHeadAttentionKernel {
             frag_M[i] = smem_M_[i];
         }
 
-        IterKv iter_K(k_cache_, smem_Kv_, step, step + iter_length, warp_id_, lane_id_);
+        IterKv iter_K;
+
+        if constexpr (kUseBlockIter) {
+            iter_K = {k_cache_ptrs_,
+                      params_.kv_cache_block_size,
+                      head_idx_,
+                      smem_Kv_,
+                      step,
+                      step + iter_length,
+                      warp_id_,
+                      lane_id_};
+        }
+        else {
+            iter_K = {k_cache_, smem_Kv_, step, step + iter_length, warp_id_, lane_id_};
+        }
+
         PrefetchKvCache(iter_K);
         CpAsyncWait();
 
@@ -446,7 +484,23 @@ struct DecoderMultiHeadAttentionKernel {
         //     // prefetch Pr for first warp iter
         //     frag_Pr_buf[0][qi] = smem_P_[qi * kSliceLen + ti];
         // }
-        IterKv iter_V(v_cache_, smem_Kv_, step, step + iter_length, warp_id_, lane_id_);
+
+        IterKv iter_V;
+
+        if constexpr (kUseBlockIter) {
+            iter_V = {v_cache_ptrs_,
+                      params_.kv_cache_block_size,
+                      head_idx_,
+                      smem_Kv_,
+                      step,
+                      step + iter_length,
+                      warp_id_,
+                      lane_id_};
+        }
+        else {
+            iter_V = {v_cache_, smem_Kv_, step, step + iter_length, warp_id_, lane_id_};
+        }
+
         PrefetchKvCache(iter_V);
         CpAsyncWait();
 
@@ -584,8 +638,8 @@ struct DecoderMultiHeadAttentionKernel {
         State state;
 
         PRAGMA_NO_UNROLL
-        for (int step = 0; step < params_.max_timestep; step += kSliceLen) {
-            int iter_length = min(params_.max_timestep - step, kSliceLen);
+        for (int step = 0; step < timestep_; step += kSliceLen) {
+            int iter_length = min(timestep_ - step, kSliceLen);
             ComputeSlice(frag_Q, state, offset, step, iter_length);
         }
     }
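
The main loop now walks each sequence's own cached length rather than the batch maximum, one kSliceLen-sized slice at a time with a possibly partial final slice. A small host-side restatement (the slice length and timestep are assumed values for illustration):

    #include <algorithm>
    #include <cstdio>

    int main()
    {
        const int kSliceLen = 1024;      // assumed slice length
        const int timestep  = 8192 + 1;  // this sequence's cached length
        for (int step = 0; step < timestep; step += kSliceLen) {
            int iter_length = std::min(timestep - step, kSliceLen);
            std::printf("slice [%d, %d)\n", step, step + iter_length);
        }
        return 0;
    }
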
diff --git a/src/turbomind/kernels/decoder_mha/iterator.h b/src/turbomind/kernels/decoder_mha/iterator.h
index 190ce4820f..532cc23a22 100644
--- a/src/turbomind/kernels/decoder_mha/iterator.h
+++ b/src/turbomind/kernels/decoder_mha/iterator.h
@@ -5,57 +5,30 @@
 
 namespace turbomind {
 
-// k0,k1,k2,v0,v1,v2,k3,k4,k5,v3,v4,v5
-
-template<int HeadDim, int ElemSize, int SliceLen>
 struct BlockIterator {
-    const void* kv_cache_[2];
-
-    static constexpr int kStride = HeadDim * ElemSize * SliceLen;
+    const void** ptrs_;
+    const void*  prefetch_;
 
     BlockIterator() = default;
 
-    __device__ BlockIterator(const void* k_cache, void* v_cache)
-    {
-        kv_cache_[0] = k_cache;
-        kv_cache_[1] = v_cache;
-    }
-
-    __device__ const void* Next()
-    {
-        // if (blockIdx.x == 0 && threadIdx.x == 0) {
-        //     printf("Next()\n");
-        // }
-        const void* ret = kv_cache_[0];
-        const void* tmp = (const uint8_t*)kv_cache_[0] + kStride;
-        kv_cache_[0]    = kv_cache_[1];
-        kv_cache_[1]    = tmp;
-        return ret;
-    }
-};
-
-struct BlockIterator2 {
-    const void*  prefetch_data_;
-    const void** block_ptrs_;
-
-    __device__ BlockIterator2(const void** block_ptrs): block_ptrs_{block_ptrs}
+    __device__ BlockIterator(const void** block_ptrs): ptrs_{block_ptrs}
     {
         // prefetch first ptr
-        prefetch_data_ = *block_ptrs_++;
+        prefetch_ = *ptrs_++;
     }
 
     __device__ const void* Next()
     {
         // return prefetched ptr
-        const void* ret = prefetch_data_;
+        const void* ret = prefetch_;
         // prefetch next ptr
-        prefetch_data_ = *block_ptrs_++;
+        prefetch_ = *ptrs_++;
 
         return ret;
     }
 };
 
-template<typename T, typename ThreadMap, int BlockLen, int Stages, bool Chained>
+template<typename T, typename ThreadMap, int BlockLen, int Stages, bool kUseBlockIter>
 struct Iterator {
 
     using ElementType = T;
@@ -67,8 +40,7 @@ struct Iterator {
     static constexpr int kSizePerTile  = ThreadMap::kS * ThreadMap::kC;
     static constexpr int kSmemByteSize = kElementSize * Stages * kSizePerTile;
 
-    BlockIterator<ThreadMap::kC, sizeof(T), BlockLen> block_iterator_;
-    // SignalIterator                                    signal_iterator_;
+    BlockIterator block_iterator_;
 
     static constexpr int kIterCount = ThreadMap::kIterS * ThreadMap::kIterC;
 
@@ -92,6 +64,11 @@ struct Iterator {
     int  offset_s_;
     bool is_valid_s_;
 
+    int block_size_;
+    int block_k_;
+
+    int head_idx_;
+
     const T* src_;
     T*       smem_;
 
@@ -102,20 +79,18 @@ struct Iterator {
         T smem_[Stages][kSizePerTile];
     };
 
-    __device__
-    Iterator(void* k_cache, void* v_cache, T* smem, uint32_t* smem_signal, int seq_len, int warp_id, int lane_id):
-        block_iterator_(k_cache, v_cache)  //, signal_iterator_(smem_signal)
+    Iterator() = default;
+
+    __device__ Iterator(T* src, T* smem, int step, int seq_len, int warp_id, int lane_id)
     {
-        src_  = (const T*)block_iterator_.Next();
+        src_  = src;
         smem_ = smem;
 
-        int2 init_offset = ThreadMap::get_offset(warp_id, lane_id);
-
-        init_offset_ = init_offset.x + init_offset.y * ThreadMap::kC;
+        int2 init_offset_cs = ThreadMap::get_offset(warp_id, lane_id);
 
-        // printf("%d\n", init_offset.x);
+        init_offset_ = init_offset_cs.x + init_offset_cs.y * ThreadMap::kC;
 
-        src_offset_       = init_offset_;
+        src_offset_       = init_offset_ + step * ThreadMap::kC;
         dst_offset_       = init_offset_;
         smem_read_offset_ = init_offset_;
 
@@ -123,31 +98,39 @@ struct Iterator {
         iter_b_ = 0;
 
         seq_len_    = seq_len;
-        offset_s_   = init_offset.y;
+        offset_s_   = init_offset_cs.y + step;
         is_valid_s_ = offset_s_ < seq_len;
     }
 
-    Iterator() = default;
-
-    __device__ Iterator(T* src, T* smem, int step, int seq_len, int warp_id, int lane_id)
+    __device__ Iterator(
+        const void** block_ptrs, int block_size, int head_idx, T* smem, int step, int seqlen, int warp_id, int lane_id)
     {
-        src_  = src;
+        // src_  = src;
+        int block_index = step / block_size;
+        block_size_     = block_size;
+        block_k_        = (block_index + 1) * block_size - step;  // keys left before the next block boundary
+        head_idx_       = head_idx;
+
+        block_iterator_ = BlockIterator(block_ptrs + block_index);
+
+        src_ = (const T*)block_iterator_.Next() + head_idx_ * block_size_ * ThreadMap::kC;
+
         smem_ = smem;
 
         int2 init_offset_cs = ThreadMap::get_offset(warp_id, lane_id);
 
         init_offset_ = init_offset_cs.x + init_offset_cs.y * ThreadMap::kC;
 
-        src_offset_       = init_offset_ + step * ThreadMap::kC;
+        src_offset_       = init_offset_ + (step - block_index * block_size) * ThreadMap::kC;
         dst_offset_       = init_offset_;
         smem_read_offset_ = init_offset_;
 
         iter_c_ = 0;
         iter_b_ = 0;
 
-        seq_len_    = seq_len;
+        seq_len_    = seqlen;
         offset_s_   = init_offset_cs.y + step;
-        is_valid_s_ = offset_s_ < seq_len;
+        is_valid_s_ = offset_s_ < seqlen;
     }
 
     __device__ void PrefetchStage()
@@ -199,6 +182,20 @@ struct Iterator {
 
         is_valid_s_ = offset_s_ < seq_len_;
 
+        if constexpr (kUseBlockIter) {
+            if (is_valid_s_) {
+                block_k_ -= ThreadMap::kS;
+                if (block_k_ == 0) {
+                    src_        = (const T*)block_iterator_.Next() + head_idx_ * block_size_ * ThreadMap::kC;
+                    block_k_    = block_size_;
+                    src_offset_ = init_offset_;
+                }
+            }
+            // if (blockIdx.x == 0 && threadIdx.x == 0) {
+            //     printf("%d %d %d\n", offset_s_, src_offset_ / ThreadMap::kC, block_k_);
+            // }
+        }
+
         // if (init_offset_ / ThreadMap::kC == 0) {
         //     int k = dst_offset_ / (ThreadMap::kS * ThreadMap::kC);
         //     int s = dst_offset_ % (ThreadMap::kS * ThreadMap::kC) / ThreadMap::kC;
@@ -298,8 +295,8 @@ struct Iterator {
         //            (int)mask);
         // }
 
-        // CpAsync(smem_ + dst_offset_, src_ + src_offset_, mask);
-        Copy(smem_ + dst_offset_, src_ + src_offset_, mask);
+        CpAsync(smem_ + dst_offset_, src_ + src_offset_, mask);
+        // Copy(smem_ + dst_offset_, src_ + src_offset_, mask);
     }
 
     __device__ void Load(AccessType (&frag)[ThreadMap::kIterC])
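
A host-side model of the block-aware advance above: starting inside a block at `step`, block_k_ holds the keys left in the current block; each tile of ThreadMap::kS keys decrements it, and when it reaches zero the iterator switches to the next block pointer and resets its source offset. The block size, tile size, and starting step below are assumptions for illustration (and, as in the kernel, the block size is a multiple of the tile size):

    #include <cstdio>

    int main()
    {
        const int block_size = 128;  // kv_cache_block_size (assumed)
        const int tile_s     = 32;   // ThreadMap::kS (assumed)
        int       step       = 192;  // starting timestep, inside block 1
        int       block      = step / block_size;
        int       block_k    = (block + 1) * block_size - step;  // keys left in the current block

        for (int s = step; s < 512; s += tile_s) {
            std::printf("keys [%d, %d) read from block %d\n", s, s + tile_s, block);
            block_k -= tile_s;
            if (block_k == 0) {  // crossed a block boundary
                ++block;
                block_k = block_size;
            }
        }
        return 0;
    }
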
diff --git a/src/turbomind/kernels/decoder_mha/kv_cache.cu b/src/turbomind/kernels/decoder_mha/kv_cache.cu
index ad106db5f8..4e73d04b26 100644
--- a/src/turbomind/kernels/decoder_mha/kv_cache.cu
+++ b/src/turbomind/kernels/decoder_mha/kv_cache.cu
@@ -41,7 +41,6 @@ __device__ void ConvertBlockSize(const T** src_block_ptrs,
 
         uint4 data = __ldg(reinterpret_cast<const uint4*>(src_block + src_block_offset));
 
-        // __stcg(reinterpret_cast<uint4*>(dst_block + dst_block_offset), data);
         *reinterpret_cast<uint4*>(dst_block + dst_block_offset) = data;
     }
 }
diff --git a/src/turbomind/kernels/decoder_mha/test_decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_mha/test_decoder_multihead_attention.cu
index 055c66ce48..400df2547b 100644
--- a/src/turbomind/kernels/decoder_mha/test_decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_mha/test_decoder_multihead_attention.cu
@@ -24,12 +24,12 @@ T* align(T* ptr, size_t alignment)
 
 // [S/S, H, S, D] <-> [S/b, H, b, D]
 
-void TestBlocks(thrust::universal_vector<half>& linear,
-                thrust::universal_vector<half>& _blocks,
-                thrust::universal_vector<half*> _ptrs,
-                int                             head_num,
-                int                             head_dim,
-                int                             block_size)
+void TestBlocks(thrust::universal_vector<half>&  linear,
+                thrust::universal_vector<half>&  _blocks,
+                thrust::universal_vector<half*>& _ptrs,
+                int                              head_num,
+                int                              head_dim,
+                int                              block_size)
 {
     int seq_len  = linear.size() / head_num / head_dim;
     int n_blocks = (seq_len + block_size - 1) / block_size;
@@ -60,20 +60,24 @@ void TestBlocks(thrust::universal_vector<half>& linear,
     }
     cudaDeviceSynchronize();
 
-    Compare(_linear.data().get(), linear.data().get(), head_dim, head_num * seq_len);
-    exit(0);
+    // Compare(_linear.data().get(), linear.data().get(), head_dim, head_dim, head_num * seq_len);
+
+    _blocks.swap(blocks);
+    _ptrs.swap(ptrs);
 }
 
 int main(int argc, char* argv[])
 {
     DecoderMultiHeadAttentionParams<half> params{};
 
-    // constexpr int kHeadNum = 108 * 4;
-    constexpr int kHeadNum    = 32;
-    constexpr int kHeadDim    = 128;
-    constexpr int kBatchSize  = 1;
-    constexpr int kContextLen = 8192;
-    constexpr int kTestIter   = 1;
+    constexpr int kHeadNum = 108 * 6;
+    // constexpr int kHeadNum     = 32 * 4;
+    constexpr int kHeadDim     = 128;
+    constexpr int kBatchSize   = 1;
+    constexpr int kContextLen  = 8192;
+    constexpr int kSequenceLen = kContextLen + 1;
+    constexpr int kBlockSz     = 128;
+    constexpr int kTestIter    = 1;
 
     RNG rng{};
 
@@ -85,18 +89,36 @@ int main(int argc, char* argv[])
     thrust::universal_vector<int>   sequence_lengths(kBatchSize);
     thrust::universal_vector<void*> k_cache_ptrs(kBatchSize);
     thrust::universal_vector<void*> v_cache_ptrs(kBatchSize);
+    thrust::universal_vector<int>   cu_ctxlens(kBatchSize + 1);
 
     rng.GenerateNormal(qkv.data().get(), qkv.size(), 1.f, 0.f);
 
     if (kContextLen) {
-        rng.GenerateNormal(k_cache.data().get(), kContextLen * kHeadNum * kHeadDim);
-        rng.GenerateNormal(v_cache.data().get(), kContextLen * kHeadNum * kHeadDim);
+        rng.GenerateNormal(k_cache.data().get(), kHeadNum * kSequenceLen * kHeadDim);
+        rng.GenerateNormal(v_cache.data().get(), kHeadNum * kSequenceLen * kHeadDim);
+
+        cudaMemset2DAsync(k_cache.data().get() + kContextLen * kHeadDim,
+                          sizeof(half) * kSequenceLen * kHeadDim,
+                          0,
+                          sizeof(half) * kHeadDim,
+                          kHeadNum);
+
+        cudaMemset2DAsync(v_cache.data().get() + kContextLen * kHeadDim,
+                          sizeof(half) * kSequenceLen * kHeadDim,
+                          0,
+                          sizeof(half) * kHeadDim,
+                          kHeadNum);
     }
 
     thrust::universal_vector<half>  k_blocks;
     thrust::universal_vector<half*> k_ptrs;
 
-    TestBlocks(k_cache, k_blocks, k_ptrs, kHeadNum, kHeadDim, 128);
+    TestBlocks(k_cache, k_blocks, k_ptrs, kHeadNum, kHeadDim, kBlockSz);
+
+    thrust::universal_vector<half>  v_blocks;
+    thrust::universal_vector<half*> v_ptrs;
+
+    TestBlocks(v_cache, v_blocks, v_ptrs, kHeadNum, kHeadDim, kBlockSz);
 
     thrust::universal_vector<half>  k_cache_ref = k_cache;
     thrust::universal_vector<half>  v_cache_ref = v_cache;
@@ -112,9 +134,10 @@ int main(int argc, char* argv[])
         v_cache_ptrs[i]     = v_cache.data().get() + i * v_cache.size() / kBatchSize;
         k_cache_ref_ptrs[i] = k_cache_ref.data().get() + i * k_cache_ref.size() / kBatchSize;
         v_cache_ref_ptrs[i] = v_cache_ref.data().get() + i * v_cache_ref.size() / kBatchSize;
+        cu_ctxlens[i + 1]   = cu_ctxlens[i] + kContextLen;
 
-        align(k_cache_ptrs[i], 256);
-        align(v_cache_ptrs[i], 256);
+        // align(k_cache_ptrs[i], 256);
+        // align(v_cache_ptrs[i], 256);
     }
 
     // getchar();
@@ -125,15 +148,19 @@ int main(int argc, char* argv[])
     params.v      = params.k + kHeadNum * kHeadDim;
     params.stride = 3 * kHeadNum * kHeadDim;
 
-    params.batch_size   = kBatchSize;
-    params.max_seq_len  = kContextLen + 1;
-    params.max_timestep = kContextLen;
+    params.batch_size  = kBatchSize;
+    params.max_seq_len = kContextLen + 1;
+    params.cu_ctxlens  = cu_ctxlens.data().get();
 
-    params.finished           = finished.data().get();
-    params.per_sample_length  = sequence_lengths.data().get();
-    params.per_sample_k_cache = k_cache_ref_ptrs.data().get();
-    params.per_sample_v_cache = v_cache_ref_ptrs.data().get();
+    printf("%d %d\n", (int)k_ptrs.size(), (int)v_ptrs.size());
+    params.k_cache_block_ptrs  = (void**)k_ptrs.data().get();
+    params.v_cache_block_ptrs  = (void**)v_ptrs.data().get();
+    params.kv_cache_block_size = kBlockSz;
 
+    params.finished                   = finished.data().get();
+    params.per_sample_length          = sequence_lengths.data().get();
+    params.per_sample_k_cache         = k_cache_ref_ptrs.data().get();
+    params.per_sample_v_cache         = v_cache_ref_ptrs.data().get();
     params.per_sample_kv_cache_offset = 0;
 
     params.num_heads     = kHeadNum;
@@ -172,26 +199,37 @@ int main(int argc, char* argv[])
         }
     }
 
+    if (1) {
+        ConvertBlocksToLinear(
+            (const half**)k_ptrs.data().get(), k_cache.data().get(), kBlockSz, kHeadNum, kHeadDim, kSequenceLen, 0);
+        ConvertBlocksToLinear(
+            (const half**)v_ptrs.data().get(), v_cache.data().get(), kBlockSz, kHeadNum, kHeadDim, kSequenceLen, 0);
+    }
+
     cudaDeviceSynchronize();
 
     if (outputs.size() > 1) {
         std::cout << "Evaluating consistency..." << std::endl;
         for (size_t i = 1; i < outputs.size(); ++i) {
-            Compare(outputs[i].data().get(), outputs[0].data().get(), kHeadDim, kHeadNum);
+            Compare(outputs[i].data().get(), outputs[0].data().get(), kHeadDim, kHeadDim, kHeadNum);
         }
     }
 
     std::cout << "---------------------------------------------------\n";
 
-    Compare(output.data().get(), output_ref.data().get(), kHeadDim, kHeadNum, 0);
+    Compare(output.data().get(), output_ref.data().get(), kHeadDim, kHeadDim, kHeadNum, 0);
+
+    // [H, S, D]
 
-    Compare(v_cache.data().get() + (kContextLen - 0) * kHeadNum * kHeadDim,
-            v_cache_ref.data().get() + (kContextLen - 0) * kHeadNum * kHeadDim,
+    Compare(k_cache.data().get() + kContextLen * kHeadDim,
+            k_cache_ref.data().get() + kContextLen * kHeadDim,
+            kSequenceLen * kHeadDim,
             kHeadDim,
             kHeadNum);
 
-    Compare(k_cache.data().get() + (kContextLen - 0) * kHeadNum * kHeadDim,
-            k_cache_ref.data().get() + (kContextLen - 0) * kHeadNum * kHeadDim,
+    Compare(v_cache.data().get() + kContextLen * kHeadDim,
+            v_cache_ref.data().get() + kContextLen * kHeadDim,
+            kSequenceLen * kHeadDim,
             kHeadDim,
             kHeadNum);
 
diff --git a/src/turbomind/kernels/decoder_mha/test_utils.cu b/src/turbomind/kernels/decoder_mha/test_utils.cu
index 3cf4262179..46f582bc0a 100644
--- a/src/turbomind/kernels/decoder_mha/test_utils.cu
+++ b/src/turbomind/kernels/decoder_mha/test_utils.cu
@@ -20,7 +20,7 @@ cublasHandle_t cublas_handle{};
 cudaStream_t   cublas_stream{};
 
 template<typename T>
-void Compare(const T* c, const T* c_ref, int m, int n, bool show, float rtol, float atol)
+void Compare(const T* src, const T* ref, size_t stride, int m, int n, bool show, float rtol, float atol)
 {
     float asums{};
     float rsums{};
@@ -29,8 +29,8 @@ void Compare(const T* c, const T* c_ref, int m, int n, bool show, float rtol, fl
         float abs_diff_sum{};
         float rel_diff_sum{};
         for (int mm = 0; mm < m; ++mm) {
-            auto x = float(c[nn * m + mm]);
-            auto y = float(c_ref[nn * m + mm]);
+            auto x = float(src[nn * stride + mm]);
+            auto y = float(ref[nn * stride + mm]);
             // if (show) {
             //     std::cout << x << "\t" << y << std::endl;
             // }
@@ -52,8 +52,9 @@ void Compare(const T* c, const T* c_ref, int m, int n, bool show, float rtol, fl
               << std::endl;
 }
 
-template void Compare(const half* c, const half* c_ref, int m, int n, bool show, float rtol, float atol);
-template void Compare(const float* c, const float* c_ref, int m, int n, bool show, float rtol, float atol);
+template void Compare(const half* src, const half* ref, size_t stride, int m, int n, bool show, float rtol, float atol);
+template void
+Compare(const float* src, const float* ref, size_t stride, int m, int n, bool show, float rtol, float atol);
 
 void LoadBinary(const std::string& path, size_t size, void* dst)
 {
@@ -212,8 +213,13 @@ void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& p, cudaStream_t
     params.prefix_prompt_lengths      = 0;
     params.max_prefix_prompt_length   = 0;
     params.length_per_sample          = p.per_sample_length;  // max_input_length + current output length
-    // timestep adding max_prefix_prompt_length for shared memory size calculation and rotary embedding computation
-    params.timestep     = p.max_timestep;  // was step - 1
+
+    for (int i = 0; i < p.batch_size; ++i) {
+        params.timestep = std::max(params.timestep, p.cu_ctxlens[i + 1] - p.cu_ctxlens[i]);
+    }
+
+    std::cout << "timestep = " << params.timestep << "\n";
+
     params.num_heads    = p.num_heads;
     params.num_kv_heads = p.num_kv_heads;
 
diff --git a/src/turbomind/kernels/decoder_mha/test_utils.h b/src/turbomind/kernels/decoder_mha/test_utils.h
index 16cd1fd69e..ecfedcb53f 100644
--- a/src/turbomind/kernels/decoder_mha/test_utils.h
+++ b/src/turbomind/kernels/decoder_mha/test_utils.h
@@ -9,7 +9,8 @@
 namespace turbomind {
 
 template<typename T>
-void Compare(const T* c, const T* c_ref, int m, int n, bool show = false, float rtol = 1e-2, float atol = 1e-4);
+void Compare(
+    const T* src, const T* ref, size_t stride, int m, int n, bool show = false, float rtol = 1e-2, float atol = 1e-4);
 
 void LoadBinary(const std::string& path, size_t size, void* dst);
 

From a9ff3ce16c05ab6167358faee092f5826645052e Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Thu, 21 Sep 2023 13:49:42 +0000
Subject: [PATCH 03/56] `BlockManager` & `SequenceManager`

---
 src/turbomind/kernels/CMakeLists.txt          |    1 +
 .../decoder_masked_multihead_attention_128.cu |   26 +-
 ...er_masked_multihead_attention_template.cuh |   29 +-
 .../decoder_multihead_attention.cu            |    2 +-
 .../test_decoder_multihead_attention.cu       |    4 +-
 src/turbomind/models/llama/BlockManager.cc    |  187 +++
 src/turbomind/models/llama/BlockManager.h     |  122 ++
 src/turbomind/models/llama/CMakeLists.txt     |    2 +
 src/turbomind/models/llama/LlamaBatch.cc      | 1183 +++++++++--------
 src/turbomind/models/llama/LlamaBatch.h       |  156 ++-
 .../llama/LlamaContextAttentionLayer.cc       |   10 +-
 .../models/llama/LlamaContextDecoder.cc       |    1 +
 src/turbomind/models/llama/LlamaDecoder.cc    |    1 +
 src/turbomind/models/llama/LlamaV2.cc         |  137 +-
 src/turbomind/models/llama/LlamaV2.h          |   47 +-
 src/turbomind/models/llama/Request.h          |   27 +-
 src/turbomind/models/llama/SequenceManager.cc |  368 +++++
 src/turbomind/models/llama/SequenceManager.h  |   96 ++
 18 files changed, 1623 insertions(+), 776 deletions(-)
 create mode 100644 src/turbomind/models/llama/BlockManager.cc
 create mode 100644 src/turbomind/models/llama/BlockManager.h
 create mode 100644 src/turbomind/models/llama/SequenceManager.cc
 create mode 100644 src/turbomind/models/llama/SequenceManager.h

diff --git a/src/turbomind/kernels/CMakeLists.txt b/src/turbomind/kernels/CMakeLists.txt
index 7c014845dd..473e579c45 100644
--- a/src/turbomind/kernels/CMakeLists.txt
+++ b/src/turbomind/kernels/CMakeLists.txt
@@ -71,3 +71,4 @@ set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE  ON)
 set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)
 
 add_subdirectory(gemm_s_f16)
+add_subdirectory(decoder_mha)
\ No newline at end of file
diff --git a/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu b/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu
index 329bcd6484..39b88a016b 100644
--- a/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu
+++ b/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu
@@ -28,19 +28,27 @@
 
 #define MMHA_LAUNCH_KERNEL(                                                                                            \
     T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, QUANT_POLICY, stream)                      \
+    auto   func    = &mmha::masked_multihead_attention_kernel<T,                                                       \
+                                                         Dh,                                                      \
+                                                         Dh_MAX,                                                  \
+                                                         THDS_PER_KEY,                                            \
+                                                         THDS_PER_VALUE,                                          \
+                                                         THDS_PER_BLOCK,                                          \
+                                                         HAS_BEAMS,                                               \
+                                                         QUANT_POLICY>;                                           \
     size_t smem_sz = mmha::smem_size_in_bytes<T>(params, THDS_PER_VALUE, THDS_PER_BLOCK);                              \
-    dim3   grid(params.num_heads, params.batch_size);                                                                  \
-    mmha::masked_multihead_attention_kernel<T,                                                                         \
-                                            Dh,                                                                        \
-                                            Dh_MAX,                                                                    \
-                                            THDS_PER_KEY,                                                              \
-                                            THDS_PER_VALUE,                                                            \
-                                            THDS_PER_BLOCK,                                                            \
-                                            HAS_BEAMS,                                                                 \
-                                            QUANT_POLICY><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
+    cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz);                                  \
+    dim3 grid(params.num_heads, params.batch_size);                                                                    \
+    func<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+// cudaFuncAttributes attr{};                                                                                         \
+// cudaFuncGetAttributes(&attr, func);                                                                                \
+// std::cout << "static_smem_sz: " << attr.sharedSizeBytes << std::endl;                                              \
+// std::cout << "max_dynamic_smem: " << attr.maxDynamicSharedSizeBytes << std::endl;                                  \
+// std::cout << "dynamic_smem_sz: " << smem_sz << std::endl;                                                          \
+
 template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
 void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
 {
diff --git a/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh b/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh
index 6b9101abb0..8a247e4a55 100644
--- a/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh
+++ b/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh
@@ -79,8 +79,7 @@ namespace mmha {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template<typename T, int Dh>
-struct Qk_vec_m_ {
-};
+struct Qk_vec_m_ {};
 
 template<>
 struct Qk_vec_m_<float, 32> {
@@ -180,8 +179,7 @@ struct Qk_vec_k_<__nv_fp8_e4m3, 256> {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template<typename T, int THREADS_PER_KEY>
-struct K_vec_m_ {
-};
+struct K_vec_m_ {};
 
 template<>
 struct K_vec_m_<float, 4> {
@@ -262,8 +260,7 @@ struct K_vec_k_<__nv_fp8_e4m3, 1> {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template<typename T, int V_VEC_SIZE>
-struct V_vec_m_ {
-};
+struct V_vec_m_ {};
 
 template<>
 struct V_vec_m_<float, 1> {
@@ -343,8 +340,7 @@ struct V_vec_k_<__nv_fp8_e4m3, 16> {
 
 #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
 template<typename T>
-struct Qk_vec_acum_fp32_ {
-};
+struct Qk_vec_acum_fp32_ {};
 
 template<>
 struct Qk_vec_acum_fp32_<float> {
@@ -426,8 +422,7 @@ struct Qk_vec_acum_fp32_<fp8_4_t> {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template<typename T>
-struct K_vec_acum_fp32_ {
-};
+struct K_vec_acum_fp32_ {};
 
 template<>
 struct K_vec_acum_fp32_<float> {
@@ -489,8 +484,7 @@ struct K_vec_acum_fp32_<fp8_4_t> {
 
 #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
 template<typename T>
-struct V_vec_acum_fp32_ {
-};
+struct V_vec_acum_fp32_ {};
 
 template<>
 struct V_vec_acum_fp32_<float> {
@@ -1471,6 +1465,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
         }
         // We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0.
 
+        printf("QK_last[%d] = %f\n", hi, qk);
+
         qk_max                        = qk;
         qk_smem[tlength - first_step] = qk;
         // qk_smem[params.timestep] = qk;
@@ -1595,6 +1591,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
 
                 qk += mul<float, T, float>(params.linear_bias_slopes[hi], dist);
             }
+            // printf("QK_%d = %f\n", (int)ti, qk);
             qk_max                   = is_mask ? qk_max : fmaxf(qk_max, qk);
             qk_smem[ti - first_step] = qk;
         }
@@ -1631,6 +1628,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
     // Broadcast to all the threads in the warp.
     qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
 
+    if (threadIdx.x == 0) {
+        printf("QK_MAX[%d] = %f\n", hi, (float)qk_max);
+    }
+
     // Compute the logits and start the sum.
     float sum = 0.f;
     // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
@@ -1656,6 +1657,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
     // Compute the sum.
     sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);
 
+    if (threadIdx.x == 0) {
+        printf("SUM[%d] = %f\n", hi, (float)sum);
+    }
+
     // Normalize the logits.
     float inv_sum = __fdividef(1.f, sum + 1.e-6f);
 
diff --git a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention.cu
index ff81f88ce6..cc6edaf230 100644
--- a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention.cu
@@ -26,7 +26,7 @@ bool Dump()
 template<typename T, int HeadDim>
 void LaunchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& params)
 {
-    using MHAType = DecoderMultiHeadAttentionKernel<T, 1, HeadDim, 16, HeadDim, 1024, 5>;
+    using MHAType = DecoderMultiHeadAttentionKernel<T, 1, HeadDim, 16, HeadDim, 2048, 6>;
 
     [[maybe_unused]] static const bool init = Dump<MHAType>();
 
diff --git a/src/turbomind/kernels/decoder_mha/test_decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_mha/test_decoder_multihead_attention.cu
index 400df2547b..8a5bc46d5f 100644
--- a/src/turbomind/kernels/decoder_mha/test_decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_mha/test_decoder_multihead_attention.cu
@@ -70,13 +70,13 @@ int main(int argc, char* argv[])
 {
     DecoderMultiHeadAttentionParams<half> params{};
 
-    constexpr int kHeadNum = 108 * 6;
+    constexpr int kHeadNum = 108 * 4;
     // constexpr int kHeadNum     = 32 * 4;
     constexpr int kHeadDim     = 128;
     constexpr int kBatchSize   = 1;
     constexpr int kContextLen  = 8192;
     constexpr int kSequenceLen = kContextLen + 1;
-    constexpr int kBlockSz     = 128;
+    constexpr int kBlockSz     = 256;
     constexpr int kTestIter    = 1;
 
     RNG rng{};
diff --git a/src/turbomind/models/llama/BlockManager.cc b/src/turbomind/models/llama/BlockManager.cc
new file mode 100644
index 0000000000..e46c0fc073
--- /dev/null
+++ b/src/turbomind/models/llama/BlockManager.cc
@@ -0,0 +1,187 @@
+#include "src/turbomind/models/llama/BlockManager.h"
+// #include "src/turbomind/models/llama/utility.h"
+#include <algorithm>
+#include <cmath>
+#include <cstddef>
+#include <iterator>
+#include <stdexcept>
+
+namespace turbomind {
+
+BlockManager::BlockManager(size_t block_size, double block_count, int chunk_size, IAllocator* allocator):
+    block_size_(block_size), allocator_(allocator)
+{
+    if (block_count < 1.) {
+        max_block_count_ = GetBlockCount(block_size, block_count);
+    }
+    else {
+        max_block_count_ = block_count;
+    }
+
+    if (chunk_size == 0) {
+        chunk_size_ = static_cast<int>(std::sqrt(max_block_count_));
+    }
+    else if (chunk_size < 0) {
+        chunk_size_ = max_block_count_;
+    }
+    else {
+        chunk_size_ = chunk_size;
+    }
+
+    blocks_.reserve(max_block_count_);
+
+    active_ids_.reserve(max_block_count_);
+    cached_ids_.reserve(max_block_count_);
+    free_ids_.reserve(max_block_count_);
+
+    // pre-allocate first chunk
+    Malloc();
+}
+
+BlockManager::~BlockManager()
+{
+    for (auto& chunk : chunks_) {
+        allocator_->free(&chunk);
+    }
+}
+
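+// Grows the pool by one chunk of up to `chunk_size_` blocks (never exceeding `max_block_count_`);
+// returns false when the pool is already full or the device allocation fails.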
+bool BlockManager::Malloc()
+{
+    auto chunk_size = std::min<int>(chunk_size_, max_block_count_ - blocks_.size());
+
+    if (!chunk_size) {
+        return false;
+    }
+
+    auto ptr = (std::byte*)allocator_->malloc(block_size_ * chunk_size);
+    if (!ptr) {
+        return false;
+    }
+
+    chunks_.push_back(ptr);
+
+    for (int i = 0; i < chunk_size; ++i, ptr += block_size_) {
+        auto& block     = blocks_.emplace_back();
+        block.ref_count = 0;
+        block.id        = (int)blocks_.size() - 1;
+        block.timestamp = 0;
+        block.data      = ptr;
+
+        free_ids_.push_back(block.id);
+    }
+
+    return true;
+}
+
+size_t BlockManager::GetBlockCount(size_t block_size, double ratio)
+{
+    size_t free{};
+    size_t total{};
+    check_cuda_error(cudaMemGetInfo(&free, &total));
+    return static_cast<size_t>(free * ratio / block_size);
+}
+
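+// Moves the sorted id set `delta` from `src` to `dst`. All three vectors are kept sorted and
+// `delta` is assumed to be a subset of `src` (the set_difference/set_union below rely on this).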
+void BlockManager::Move(std::vector<int>& src, const std::vector<int>& delta, std::vector<int>& dst)
+{
+    std::vector<int> src1(src.size() - delta.size());
+    std::set_difference(src.begin(), src.end(), delta.begin(), delta.end(), src1.begin());
+    src.swap(src1);
+
+    std::vector<int> dst1(dst.size() + delta.size());
+    std::set_union(dst.begin(), dst.end(), delta.begin(), delta.end(), dst1.begin());
+    dst.swap(dst1);
+}
+
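+// free -> active: hands out `count` blocks, growing the pool chunk by chunk on demand and
+// throwing once the allocator cannot provide more memory.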
+std::vector<const Block*> BlockManager::Allocate(int count)
+{
+    while (free_ids_.size() < count) {
+        if (!Malloc()) {
+            throw std::runtime_error("out of memory");
+        }
+    }
+
+    std::vector<const Block*> ret;
+
+    std::vector<int> idxs(count);
+
+    for (int i = 0; i < count; ++i) {
+        int idx     = free_ids_[i];
+        idxs[i]     = idx;
+        auto& block = blocks_[idx];
+        FT_CHECK(block.ref_count == 0);
+        FT_CHECK(block.timestamp == 0);
+        block.ref_count = 1;
+        block.unique_id = unique_id_++;
+        ret.push_back(&block);
+    }
+
+    Move(free_ids_, idxs, active_ids_);
+
+    return ret;
+}
+
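+// cached -> free: reclaims the `count` least recently used cached blocks (smallest timestamps);
+// callers are expected to pass `count <= cached_ids_.size()`.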
+void BlockManager::Evict(int count)
+{
+    std::vector<int> idxs(cached_ids_);
+    // get first `count` cached ids according to timestamp
+    std::nth_element(idxs.begin(), idxs.begin() + count, idxs.end(), [&](int i, int j) {
+        return blocks_[i].timestamp < blocks_[j].timestamp;
+    });
+    idxs.resize(count);
+
+    // sort the retrieved ids
+    std::sort(idxs.begin(), idxs.end());
+
+    // set as free
+    for (const auto& idx : idxs) {
+        blocks_[idx].timestamp = 0;
+    }
+
+    Move(cached_ids_, idxs, free_ids_);
+}
+
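+// active -> cached: decrements reference counts; blocks dropping to zero keep their data and
+// become eviction candidates ordered by their last `Touch` timestamp.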
+void BlockManager::Release(const std::vector<const Block*>& bs)
+{
+    std::vector<int> cached;
+
+    for (const auto& p : bs) {
+        auto& block = blocks_[p->id];
+        if (--block.ref_count == 0) {
+            cached.push_back(block.id);
+        }
+    }
+
+    std::sort(cached.begin(), cached.end());
+
+    Move(active_ids_, cached, cached_ids_);
+}
+
+void BlockManager::Retain(const std::vector<const Block*>& bs)
+{
+    for (const auto& p : bs) {
+        FT_CHECK(is_active(*p));
+        ++const_cast<Block*>(p)->ref_count;
+    }
+}
+
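+// Timestamps are assigned in reverse order so that the leading blocks of a sequence receive the
+// newest values and are therefore the last to be evicted.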
+void BlockManager::Touch(const std::vector<const Block*>& bs)
+{
+    std::for_each(bs.crbegin(), bs.crend(), [this](const Block* p) {
+        FT_CHECK(is_active(*p));
+        const_cast<Block*>(p)->timestamp = timestamp_++;
+    });
+}
+
+Snapshot BlockManager::TakeSnapshot()
+{
+    std::vector<int> ref_count(blocks_.size());
+    for (const auto& idx : active_ids_) {
+        ref_count[idx] = blocks_[idx].ref_count;
+    }
+    return {(int)active_ids_.size(), (int)cached_ids_.size(), (int)free_ids_.size(), std::move(ref_count)};
+}
+
+std::ostream& operator<<(std::ostream& os, const Block& block)
+{
+    os << "Block[id=" << block.id << ",ref_count=" << block.ref_count << ",unique_id=" << block.unique_id
+       << ",timestamp=" << block.timestamp << ",data=" << block.data << "]";
+    return os;
+}
+
+}  // namespace turbomind
\ No newline at end of file
diff --git a/src/turbomind/models/llama/BlockManager.h b/src/turbomind/models/llama/BlockManager.h
new file mode 100644
index 0000000000..362b4cfb16
--- /dev/null
+++ b/src/turbomind/models/llama/BlockManager.h
@@ -0,0 +1,122 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#pragma once
+
+#include "src/turbomind/utils/allocator.h"
+#include "src/turbomind/utils/cuda_utils.h"
+#include "src/turbomind/utils/logger.h"
+#include <algorithm>
+#include <cstdint>
+#include <cuda_runtime.h>
+#include <iterator>
+#include <numeric>
+#include <ostream>
+#include <queue>
+#include <unordered_map>
+#include <vector>
+
+namespace turbomind {
+
+// KV cache layouts (L: layers, H: heads, S: sequence length, D: head dim, x: block length):
+//   contiguous: [L, H, S, D]
+//   blocked:    [L, S/x, H, x, D]
+
+struct Block {
+    int      id;  // fixed linear id in the pool
+    int      ref_count;
+    uint64_t unique_id;  // unique for every block allocation
+    uint64_t timestamp;
+    void*    data;
+
+    friend std::ostream& operator<<(std::ostream& os, const Block& block);
+};
+
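+// A block is in exactly one of three states:
+//   free:   ref_count == 0 && timestamp == 0  (never used, or evicted)
+//   active: ref_count >  0                    (referenced by at least one sequence)
+//   cached: ref_count == 0 && timestamp >  0  (data retained, eligible for eviction)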
+inline bool is_active(const Block& block)
+{
+    return block.ref_count > 0;
+}
+
+inline bool is_cached(const Block& block)
+{
+    return block.ref_count == 0 && block.timestamp > 0;
+}
+
+inline bool is_free(const Block& block)
+{
+    return block.ref_count == 0 && block.timestamp == 0;
+}
+
+struct Snapshot {
+    int              active;
+    int              cached;
+    int              free;
+    std::vector<int> ref_count;
+};
+
+class BlockManager {
+public:
+    explicit BlockManager(size_t block_size, double block_count, int chunk_size, IAllocator* allocator);
+
+    ~BlockManager();
+
+    // free -> active
+    std::vector<const Block*> Allocate(int count);
+
+    // active -> cached
+    void Release(const std::vector<const Block*>& bs);
+
+    // cached -> free
+    void Evict(int count);
+
+    // active -> active
+    void Retain(const std::vector<const Block*>& bs);
+
+    // increase timestamp in reversed order
+    void Touch(const std::vector<const Block*>& bs);
+
+    Snapshot TakeSnapshot();
+
+    int max_block_count() const noexcept
+    {
+        return max_block_count_;
+    }
+
+    int active_count() const noexcept
+    {
+        return active_ids_.size();
+    }
+
+    int cached_count() const noexcept
+    {
+        return cached_ids_.size();
+    }
+
+    int free_count() const noexcept
+    {
+        return free_ids_.size();
+    }
+
+private:
+    static size_t GetBlockCount(size_t block_size, double ratio);
+
+    // move indices between sets
+    static void Move(std::vector<int>& src, const std::vector<int>& delta, std::vector<int>& dst);
+
+    // allocate a chunk of blocks
+    bool Malloc();
+
+private:
+    size_t      block_size_;
+    int         max_block_count_{};
+    int         chunk_size_{};
+    IAllocator* allocator_;
+
+    std::vector<void*> chunks_;
+
+    std::vector<int> active_ids_;
+    std::vector<int> cached_ids_;
+    std::vector<int> free_ids_;
+
+    std::vector<Block> blocks_;  // < 100k
+
+    uint64_t unique_id_{1ULL << 63};
+    uint64_t timestamp_{1};
+};
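+// Minimal usage sketch (illustrative only):
+//   BlockManager mgr(block_size, 0.9, 0, allocator);  // 90% of free memory, chunk = sqrt(pool size)
+//   auto blocks = mgr.Allocate(n);                     // free -> active
+//   mgr.Touch(blocks);                                 // refresh eviction timestamps
+//   mgr.Release(blocks);                               // active -> cached once ref_count reaches 0
+//   mgr.Evict(k);                                      // cached -> free, oldest timestamps first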
+
+}  // namespace turbomind
\ No newline at end of file
diff --git a/src/turbomind/models/llama/CMakeLists.txt b/src/turbomind/models/llama/CMakeLists.txt
index d7751e7d63..1da6fcec16 100644
--- a/src/turbomind/models/llama/CMakeLists.txt
+++ b/src/turbomind/models/llama/CMakeLists.txt
@@ -10,6 +10,8 @@ add_library(Llama STATIC
         LlamaV2.cc
         LlamaBatch.cc
         LlamaCacheManager.cc
+        BlockManager.cc
+        SequenceManager.cc
         LlamaContextDecoder.cc
         LlamaContextAttentionLayer.cc
         LlamaDecoderSelfAttentionLayer.cc
diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 83db7ad65d..b105be6d56 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -6,36 +6,39 @@
 #include "src/turbomind/models/llama/LlamaNcclGuard.h"
 #include "src/turbomind/models/llama/LlamaV2.h"
 #include "src/turbomind/models/llama/Request.h"
+#include "src/turbomind/models/llama/SequenceManager.h"
 #include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/Tensor.h"
+#include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/logger.h"
+#include <algorithm>
 #include <cstdint>
 #include <iomanip>
+#include <cmath>
 #include <sstream>
 #include <unordered_map>
 
 namespace turbomind {
 
 template<typename T>
-void LlamaBatch<T>::verifyRequests(std::vector<std::shared_ptr<Request>>& stop_reqs,
-                                   std::vector<std::shared_ptr<Request>>& infer_reqs)
+void LlamaBatch<T>::RejectInvalidRequests(Requests& stop_reqs, Requests& infer_reqs)
 {
     std::unordered_map<uint64_t, int> occurrence;
 
-    auto count_occurrence = [&occurrence](const std::vector<std::shared_ptr<Request>>& rs) {
+    auto count_occurrence = [&occurrence](const Requests& rs) {
         for (const auto& r : rs) {
             ++occurrence[r->id];
         }
     };
 
-    auto invalidate = [](const char* type, std::shared_ptr<Request>& req, int ec) {
-        TM_LOG_WARNING("[verifyRequests] Skipping invalid %s request for id %ld, code = %d", type, (long)req->id, ec);
+    auto reject = [](const char* type, std::shared_ptr<Request>& req, int ec) {
+        TM_LOG_WARNING(
+            "[RejectInvalidRequests] Skipping invalid %s request for id %ld, code = %d", type, (long)req->id, ec);
         req->signal.set_value(ec);
         req.reset();
     };
 
-    auto handle_conflict_or_invalid = [this, &occurrence, &invalidate](std::vector<std::shared_ptr<Request>>& rs,
-                                                                       const char*                            type) {
+    auto handle_conflict_or_invalid = [this, &occurrence, &reject](Requests& rs, const char* type) {
         for (auto& r : rs) {
             if (r) {
                 int ec = 0;
@@ -46,18 +49,18 @@ void LlamaBatch<T>::verifyRequests(std::vector<std::shared_ptr<Request>>& stop_r
                 else if (r->start_flag && r->stop_flag) {
                     ec = Request::kInvalid;
                 }
-                else if (!r->start_flag && !llama_->kv_cache_mgr_->contains(r->id)) {
+                else if (!r->start_flag && !sequence_manager_->Contains(r->id)) {
                     ec = Request::kInvalid;
                 }
 
                 if (ec) {
-                    invalidate(type, r, ec);
+                    reject(type, r, ec);
                 }
             }
         }
     };
 
-    auto drop_invalid = [](std::vector<std::shared_ptr<Request>>& rs) {
+    auto drop_invalid = [](Requests& rs) {
         int count = 0;
         for (int i = 0; i < rs.size(); ++i) {
             if (rs[i]) {
@@ -77,14 +80,14 @@ void LlamaBatch<T>::verifyRequests(std::vector<std::shared_ptr<Request>>& stop_r
         for (auto& r : stop_reqs) {
             if (r && r->end_flag == false) {
                 int ec = Request::kInactive;
-                for (int i = 0; i < batch_size_; ++i) {
-                    if (requests_[i] && requests_[i]->id == r->id) {
+                for (int i = 0; i < state_->size; ++i) {
+                    if (state_->requests[i] && state_->requests[i]->id == r->id) {
                         ec = 0;
                         break;
                     }
                 }
                 if (ec) {
-                    invalidate("stop", r, ec);
+                    reject("stop", r, ec);
                 }
             }
         }
@@ -98,9 +101,9 @@ void LlamaBatch<T>::verifyRequests(std::vector<std::shared_ptr<Request>>& stop_r
         // invalidate requests for busy sequences
         for (auto& r : infer_reqs) {
             if (r) {
-                for (int i = 0; i < batch_size_; ++i) {
-                    if (requests_[i] && requests_[i]->id == r->id) {
-                        invalidate("infer", r, Request::kBusy);
+                for (int i = 0; i < state_->size; ++i) {
+                    if (state_->requests[i] && state_->requests[i]->id == r->id) {
+                        reject("infer", r, Request::kBusy);
                         break;
                     }
                 }
@@ -112,31 +115,30 @@ void LlamaBatch<T>::verifyRequests(std::vector<std::shared_ptr<Request>>& stop_r
 }
 
 template<typename T>
-void LlamaBatch<T>::handleStopRequests(const std::vector<std::shared_ptr<Request>>& requests)
+void LlamaBatch<T>::ProcessStopRequests(const Requests& requests)
 {
     for (const auto& r : requests) {
         int ec = Request::kFail;
         // find matching active sequence
-        for (int i = 0; i < batch_size_; ++i) {
+        for (int i = 0; i < state_->size; ++i) {
             // stop & optionally erase active sequence
-            if (requests_[i] && requests_[i]->id == r->id) {
+            if (state_->requests[i] && state_->requests[i]->id == r->id) {
                 ec = 0;
-                finishRequest(i, r->end_flag);
+                FinishRequest(i, r->end_flag);
                 break;
             }
         }
-        // mismatch, try erase inactive sequence
+        // no matching active sequence; erase the inactive sequence instead (there is no active request to finish)
         if (ec && r->end_flag) {
             ec = 0;
-            llama_->kv_cache_mgr_->erase(r->id);
+            sequence_manager_->Erase(r->id);
         }
         // clear output buffers (prevent leaking conversations) if request is successful
         if (ec == 0) {
             auto& output_ids      = r->outputs[rank_].at("output_ids");
             auto& sequence_length = r->outputs[rank_].at("sequence_length");
-            check_cuda_error(
-                cudaMemsetAsync(output_ids.getPtr<int>(), 0, sizeof(int) * output_ids.shape.at(2), stream_));
-            check_cuda_error(cudaMemsetAsync(sequence_length.getPtr<int>(), 0, sizeof(int), stream_));
+            Clear(output_ids.getPtr<int>(), output_ids.shape.at(2));
+            Clear(sequence_length.getPtr<int>(), 1);
             check_cuda_error(cudaStreamSynchronize(stream_));
         }
         if (rank_ == 0) {
@@ -146,13 +148,258 @@ void LlamaBatch<T>::handleStopRequests(const std::vector<std::shared_ptr<Request
 }
 
 template<typename T>
-void LlamaBatch<T>::allocateBuffer(size_t batch_size, size_t session_len)
+void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
+{
+    auto& state = *incoming_;
+
+    state.size = state.active_size = 0;
+
+    int i = 0;
+    for (const auto& r : requests) {
+
+        // sanity check: incoming requests from the previous iteration should have been moved to `state_`
+        FT_CHECK(state.sequences[i] == nullptr);
+
+        state.requests[i] = r;
+
+        // get sequence for the request
+        state.sequences[i] = r->start_flag ? sequence_manager_->Create(r->id) : sequence_manager_->Fetch(r->id);
+
+        auto& seq = *state.sequences[i];
+
+        if (int step = r->inputs[rank_].getVal<int>("step", -1); step >= 0) {
+            /// TODO: revise step setting
+            if (step <= seq.tokens.size()) {
+                seq.tokens.resize(step);
+                seq.cache_len = std::min(seq.cache_len, step);
+            }
+            else if (rank_ == 0) {
+                TM_LOG_WARNING(
+                    "[ProcessInferRequests] Skipping invalid step (%d) setting for ID %ld", step, (long)seq.id);
+            }
+        }
+
+        const int  input_length = r->inputs[rank_].getVal<int>("input_lengths");
+        const int* input_ids    = r->inputs[rank_].getPtr<int>("input_ids");
+
+        // `output_ids` contains all token ids of the sequences
+        const auto output_ids_base = state.output_ids + session_len_ * i;
+        auto       output_ids      = output_ids_base;
+
+        // copy history tokens
+        if (!seq.tokens.empty()) {
+            output_ids = Copy(seq.tokens.data(), seq.tokens.size(), output_ids);
+        }
+
+        // copy input tokens
+        if (input_length) {
+            output_ids = Copy(input_ids, input_length, output_ids);
+        }
+
+        // total context length (history + input)
+        state.h_context_length[i] = output_ids - output_ids_base;
+        state.h_finished[i]       = false;
+
+        const int request_output_len = state.requests[i]->inputs[rank_].getVal<int>("request_output_len");
+        state.seq_len_limit[i]       = state.h_context_length[i] + request_output_len;
+        // `length_criterion` sets the finish flag when step >= seq_limit_len; but when step == seq_limit_len
+        // the actual sequence length is seq_limit_len + 1, so seq_limit_len must be truncated to session_len - 1
+        if (state.seq_len_limit[i] >= session_len_) {
+            state.seq_len_limit[i] = session_len_ - 1;
+            if (rank_ == 0) {
+                const int trunc_output_len = state.seq_len_limit[i] - state.h_context_length[i];
+                TM_LOG_WARNING(
+                    "[ProcessInferRequests] [%ld] total sequence length (%d + %d) exceeds `session_len` (%d), `request_output_len` is truncated to %d",
+                    (long)seq.id,
+                    state.h_context_length[i],
+                    request_output_len,
+                    (int)session_len_,
+                    trunc_output_len);
+            }
+        }
+
+        // recover random state HtoD if not a new sequence
+        if (!r->start_flag) {
+            Copy((curandState_t*)seq.random_state.data() + 0, 1, (curandState_t*)state.top_k_curand_state);
+            Copy((curandState_t*)seq.random_state.data() + 1, 1, (curandState_t*)state.top_p_curand_state);
+        }
+
+        // assign priority based on arrival time
+        r->priority = request_count_++;
+
+        // advance to the next batch slot
+        i++;
+    }
+
+    incoming_->size = i;
+}
+
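+// Merges `incoming_` requests into the running batch: materializes KV blocks through the sequence
+// manager, reorders the slots (active sequences first, swap-ins last), compacts everything into
+// `back_` and swaps it with `state_`. Returns true when the buffer layout changed and the
+// generation / sampling state must be re-initialized.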
+template<typename T>
+bool LlamaBatch<T>::Initialize()
+{
+    std::vector<const Sequence*>             sequences;
+    std::vector<Sequence::Status>            status;
+    std::vector<uint64_t>                    priorities;
+    std::vector<int>                         context_lengths;
+    std::vector<std::pair<BatchState*, int>> coords;
+
+    // count the holes introduced by requests that finished in the previous iteration or were stopped in
+    // the current iteration
+    int holes{};
+    int active_holes{};
+    for (int i = 0; i < state_->size; ++i) {
+        if (!state_->requests[i]) {
+            ++holes;
+            if (i < state_->active_size) {
+                ++active_holes;
+            }
+        }
+    }
+
+    auto add = [&](BatchState* state) {
+        for (int i = 0; i < state->size; ++i) {
+            if (auto& r = state->requests[i]) {
+                sequences.push_back(state->sequences[i]);
+                status.push_back(state->sequences[i]->status);
+                priorities.push_back(r->priority);
+                coords.emplace_back(state, i);
+            }
+        }
+    };
+
+    add(state_);
+    add(incoming_);
+
+    bool modified = sequence_manager_->Materialize(sequences, context_lengths, priorities, llama_->step_length_);
+
+    // no swap-in/swap-out & no holes in the buffers & no new requests -> nothing changed
+    if (!modified && !holes && !incoming_->size) {
+        return false;
+    }
+
+    std::vector<int> idxs(sequences.size());
+    std::iota(idxs.begin(), idxs.end(), 0);
+
+    if (modified) {
+        // put active ones first
+        auto active_end = std::stable_partition(idxs.begin(), idxs.end(), [&](int idx) {
+            return sequences[idx]->status == Sequence::kActive;  // present status
+        });
+
+        // move swap-ins to the back
+        auto swapin_beg = std::stable_partition(idxs.begin(), active_end, [&](int idx) {
+            return status[idx] == Sequence::kActive;  // past status
+        });
+
+        // sort swap-ins according to missing length
+        if (swapin_beg != active_end) {
+            std::vector<int> missing_len(sequences.size());
+            for (int i = 0; i < sequences.size(); ++i) {
+                missing_len[i] = (int)sequences[i]->tokens.size() - sequences[i]->cache_len;
+            }
+            std::stable_sort(swapin_beg, active_end, [&](int i, int j) { return missing_len[i] < missing_len[j]; });
+        }
+    }
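+    // when modified, the resulting order is [still active | swap-ins by missing length | inactive],
+    // so sequences that need context decoding form a contiguous tail of the active region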
+
+    // Copy sequence states to the back state buffer
+    back_->size = back_->active_size = 0;
+    for (const auto& i : idxs) {
+        auto& s = *sequences[i];
+        if (modified) {
+            // backup random states from dynamic decode layers for swap-outs
+            if (status[i] == Sequence::kActive && s.status != Sequence::kActive) {
+                SaveRandomState(*coords[i].first, coords[i].second);
+            }
+            // restore random states to dynamic decode layers for swap-ins
+            if (status[i] != Sequence::kActive && s.status == Sequence::kActive) {
+                LoadRandomState(*coords[i].first, coords[i].second);
+            }
+        }
+        if (s.status == Sequence::kActive) {
+            ++back_->active_size;
+        }
+        CopyState(coords[i], {back_, back_->size++});
+    }
+    // Swap the buffers
+    std::swap(state_, back_);
+
+    const int batch_size = state_->active_size;
+
+    // Prepare intermediate buffers
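+    // `h_cu_block_counts_` is an exclusive prefix sum of per-sequence block counts; `h_k_block_ptrs_`
+    // and `h_v_block_ptrs_` hold the flattened per-block K/V base pointers that it indexes into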
+    h_cu_block_counts_[0] = 0;
+
+    auto k_ptrs = h_k_block_ptrs_;
+    auto v_ptrs = h_v_block_ptrs_;
+
+    for (int i = 0; i < batch_size; ++i) {
+        const auto& seq = *state_->sequences[i];
+
+        // cumulative num of blocks
+        h_cu_block_counts_[i + 1] = h_cu_block_counts_[i] + seq.blocks.size();
+
+        k_ptrs = std::transform(seq.blocks.begin(), seq.blocks.end(), k_ptrs, [&](const Block* p) {
+            return reinterpret_cast<uintptr_t>(sequence_manager_->OffsetKey(p->data));
+        });
+        v_ptrs = std::transform(seq.blocks.begin(), seq.blocks.end(), v_ptrs, [&](auto p) {
+            return reinterpret_cast<uintptr_t>(sequence_manager_->OffsetVal(p->data));
+        });
+    }
+
+    Copy(state_->h_context_length, batch_size, context_length_buf_);
+
+    Copy(h_cu_block_counts_, batch_size + 1, cu_block_counts_);
+    Copy(h_k_block_ptrs_, h_cu_block_counts_[batch_size], k_block_ptrs_);
+    Copy(h_v_block_ptrs_, h_cu_block_counts_[batch_size], v_block_ptrs_);
+
+    // if sequences were swapped in/out or there are holes in the active buffer, the buffer layout has
+    // changed and generation & sampling must be re-initialized for correctness
+    return modified || active_holes;
+}
+
+template<typename T>
+void LlamaBatch<T>::CopyState(const std::pair<BatchState*, int> _src, const std::pair<BatchState*, int>& _dst)
+{
+    const auto& [src, i] = _src;
+    const auto& [dst, j] = _dst;
+
+    FT_CHECK((bool)src->requests[i]);
+    FT_CHECK(!(bool)dst->requests[j]);
+
+    dst->h_context_length[j] = src->h_context_length[i];
+    dst->h_finished[j]       = src->h_finished[i];
+    dst->seq_len_limit[j]    = src->seq_len_limit[i];
+    dst->sequences[j]        = src->sequences[i];
+    dst->requests[j]         = std::move(src->requests[i]);
+
+    Copy(src->output_ids + i * session_len_, src->h_context_length[i], dst->output_ids + j * session_len_);
+
+    Copy((curandState_t*)src->top_k_curand_state + i, 1, (curandState_t*)dst->top_k_curand_state + j);
+    Copy((curandState_t*)src->top_p_curand_state + i, 1, (curandState_t*)dst->top_p_curand_state + j);
+}
+
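+// The per-slot top-k / top-p curand states live inside the dynamic decode layer; they are saved on
+// swap-out and restored on swap-in so that a sequence's sampling stream survives re-batching.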
+template<typename T>
+void LlamaBatch<T>::SaveRandomState(BatchState& state, int idx)
+{
+    Copy(llama_->GetTopKState(idx), 1, (curandState_t*)state.top_k_curand_state + idx);
+    Copy(llama_->GetTopPState(idx), 1, (curandState_t*)state.top_p_curand_state + idx);
+}
+
+template<typename T>
+void LlamaBatch<T>::LoadRandomState(BatchState& state, int idx)
+{
+    Copy((curandState_t*)state.top_k_curand_state + idx, 1, llama_->GetTopKState(idx));
+    Copy((curandState_t*)state.top_p_curand_state + idx, 1, llama_->GetTopPState(idx));
+}
+
+template<typename T>
+void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len)
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
     const size_t batchxbeam = batch_size;
 
-    const size_t hidden_units = llama_->hidden_units_;
-    const size_t vocab_size   = llama_->vocab_size_padded_;
+    const size_t hidden_units    = llama_->hidden_units_;
+    const size_t vocab_size      = llama_->vocab_size_padded_;
+    const size_t max_block_count = sequence_manager_->max_block_count();
 
     context_decoder_input_buf_ =
         (T*)allocator_->reMalloc(context_decoder_input_buf_, sizeof(T) * max_context_token_num_ * hidden_units, false);
@@ -169,11 +416,11 @@ void LlamaBatch<T>::allocateBuffer(size_t batch_size, size_t session_len)
     history_length_buf_ = (int*)allocator_->reMalloc(history_length_buf_, sizeof(int) * batchxbeam);
     context_length_buf_ = (int*)allocator_->reMalloc(context_length_buf_, sizeof(int) * batchxbeam);
 
-    total_padding_count_ = (int*)allocator_->reMalloc(total_padding_count_, sizeof(int) * batchxbeam, false);
-    sequence_lengths_    = (int*)allocator_->reMalloc(sequence_lengths_, sizeof(int) * batchxbeam, false);
+    sequence_lengths_ = (int*)allocator_->reMalloc(sequence_lengths_, sizeof(int) * batchxbeam, false);
 
-    k_cache_ptr_buf_ = (uint64_t*)allocator_->reMalloc(k_cache_ptr_buf_, sizeof(uint64_t) * batchxbeam);
-    v_cache_ptr_buf_ = (uint64_t*)allocator_->reMalloc(v_cache_ptr_buf_, sizeof(uint64_t) * batchxbeam);
+    cu_block_counts_ = (int*)allocator_->reMalloc(cu_block_counts_, sizeof(int) * (batch_size + 1));
+    k_block_ptrs_    = (uintptr_t*)allocator_->reMalloc(k_block_ptrs_, sizeof(uintptr_t) * max_block_count);
+    v_block_ptrs_    = (uintptr_t*)allocator_->reMalloc(v_block_ptrs_, sizeof(uintptr_t) * max_block_count);
 
     logits_buf_       = (float*)allocator_->reMalloc(logits_buf_, sizeof(float) * batchxbeam * vocab_size, false);
     local_logits_buf_ = (float*)allocator_->reMalloc(local_logits_buf_, sizeof(float) * batchxbeam * vocab_size, false);
@@ -188,10 +435,8 @@ void LlamaBatch<T>::allocateBuffer(size_t batch_size, size_t session_len)
 }
 
 template<typename T>
-void LlamaBatch<T>::allocatePersistantBuffer(size_t max_batch_size)
+void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
 {
-    output_ids_buf_ = (int*)allocator_->reMalloc(output_ids_buf_, sizeof(int) * max_batch_size * session_len_, true);
-
     stop_words_buf_ =
         (int*)allocator_->reMalloc(stop_words_buf_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true);
     bad_words_buf_ =
@@ -212,8 +457,13 @@ void LlamaBatch<T>::allocatePersistantBuffer(size_t max_batch_size)
                         {"repetition_penalty", h_repetition_penalty_},
                         {"random_seed", h_random_seed_}};
 
-    topk_curandstate_buf_ = allocator_->reMalloc(topk_curandstate_buf_, sizeof(curandState_t) * max_batch_size, true);
-    topp_curandstate_buf_ = allocator_->reMalloc(topp_curandstate_buf_, sizeof(curandState_t) * max_batch_size, true);
+    for (auto& s : states_) {
+        s.output_ids = (int*)allocator_->reMalloc(s.output_ids, sizeof(int) * max_batch_size * session_len_, true);
+        s.top_k_curand_state = allocator_->reMalloc(s.top_k_curand_state, sizeof(curandState_t) * max_batch_size, true);
+        s.top_p_curand_state = allocator_->reMalloc(s.top_p_curand_state, sizeof(curandState_t) * max_batch_size, true);
+    }
+
+    const size_t max_block_count = sequence_manager_->max_block_count();
 
     {
         NcclGuard barrier(llama_->tensor_para_, stream_, true);
@@ -223,15 +473,20 @@ void LlamaBatch<T>::allocatePersistantBuffer(size_t max_batch_size)
             (int*)allocator_->reMalloc(h_input_length_buf_, sizeof(int) * max_batch_size, false, true);
         h_history_length_buf_ =
             (int*)allocator_->reMalloc(h_history_length_buf_, sizeof(int) * max_batch_size, false, true);
-        h_context_length_buf_ =
-            (int*)allocator_->reMalloc(h_context_length_buf_, sizeof(int) * max_batch_size, false, true);
-        h_sequence_lengths_ =
-            (int*)allocator_->reMalloc(h_sequence_lengths_, sizeof(int) * max_batch_size, false, true);
-        h_k_cache_ptr_buf_ =
-            (uintptr_t*)allocator_->reMalloc(h_k_cache_ptr_buf_, sizeof(uintptr_t) * max_batch_size, true, true);
-        h_v_cache_ptr_buf_ =
-            (uintptr_t*)allocator_->reMalloc(h_v_cache_ptr_buf_, sizeof(uintptr_t) * max_batch_size, true, true);
-        h_finished_buf_ = (bool*)allocator_->reMalloc(h_finished_buf_, sizeof(bool) * max_batch_size, false, true);
+
+        h_cu_block_counts_ =
+            (int*)allocator_->reMalloc(h_cu_block_counts_, sizeof(int) * (max_batch_size + 1), false, true);
+        h_k_block_ptrs_ =
+            (uintptr_t*)allocator_->reMalloc(h_k_block_ptrs_, sizeof(uintptr_t) * max_block_count, false, true);
+        h_v_block_ptrs_ =
+            (uintptr_t*)allocator_->reMalloc(h_v_block_ptrs_, sizeof(uintptr_t) * max_block_count, false, true);
+
+        for (auto& s : states_) {
+            s.h_context_length =
+                (int*)allocator_->reMalloc(s.h_context_length, sizeof(int) * max_batch_size, false, true);
+            s.h_finished = (bool*)allocator_->reMalloc(s.h_finished, sizeof(bool) * max_batch_size * 2, false, true);
+        }
+
         h_seq_limit_len_ =
             (uint32_t*)allocator_->reMalloc(h_seq_limit_len_, sizeof(uint32_t) * max_batch_size, false, true);
     }
@@ -240,7 +495,7 @@ void LlamaBatch<T>::allocatePersistantBuffer(size_t max_batch_size)
 }
 
 template<typename T>
-void LlamaBatch<T>::freeBuffer()
+void LlamaBatch<T>::FreeBuffer()
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
     if (is_allocate_buffer_) {
@@ -256,11 +511,11 @@ void LlamaBatch<T>::freeBuffer()
         allocator_->free((void**)&history_length_buf_);
         allocator_->free((void**)&context_length_buf_);
 
-        allocator_->free((void**)&total_padding_count_);
         allocator_->free((void**)&sequence_lengths_);
 
-        allocator_->free((void**)&k_cache_ptr_buf_);
-        allocator_->free((void**)&v_cache_ptr_buf_);
+        allocator_->free((void**)&cu_block_counts_);
+        allocator_->free((void**)&k_block_ptrs_);
+        allocator_->free((void**)&v_block_ptrs_);
 
         allocator_->free((void**)&logits_buf_);
         allocator_->free((void**)&local_logits_buf_);
@@ -282,29 +537,35 @@ void LlamaBatch<T>::freeBuffer()
     }
 
     if (is_allocate_persistant_buffer_) {
+        for (auto& s : states_) {
+            allocator_->free((void**)&s.h_context_length, true);
+            allocator_->free((void**)&s.h_finished, true);
+            allocator_->free((void**)&s.output_ids);
+        }
+        allocator_->free((void**)&h_cu_block_counts_, true);
+        allocator_->free((void**)&h_k_block_ptrs_, true);
+        allocator_->free((void**)&h_v_block_ptrs_, true);
         allocator_->free((void**)&h_input_ids_buf_, true);
         allocator_->free((void**)&h_input_length_buf_, true);
         allocator_->free((void**)&h_history_length_buf_, true);
-        allocator_->free((void**)&h_context_length_buf_, true);
-        allocator_->free((void**)&h_sequence_lengths_, true);
-        allocator_->free((void**)&h_k_cache_ptr_buf_, true);
-        allocator_->free((void**)&h_v_cache_ptr_buf_, true);
         allocator_->free((void**)&h_seq_limit_len_, true);
-        allocator_->free((void**)&h_finished_buf_, true);
-
-        allocator_->free((void**)&output_ids_buf_);
-
         is_allocate_persistant_buffer_ = false;
     }
 }
 
 template<typename T>
-LlamaBatch<T>::LlamaBatch(int max_batch_size, int max_context_token_num, int session_len, LlamaV2<T>* llama):
+LlamaBatch<T>::LlamaBatch(int                              max_batch_size,
+                          int                              max_context_token_num,
+                          int                              session_len,
+                          std::unique_ptr<SequenceManager> sequence_manager,
+                          LlamaV2<T>*                      llama):
     max_batch_size_(max_batch_size),
     max_context_token_num_(max_context_token_num),
     session_len_(session_len),
     rank_(llama->tensor_para_.rank_),
     debug_(llama->debug_),
+    step_length_(llama->step_length_),
+    sequence_manager_(std::move(sequence_manager)),
     llama_(llama),
     data_type_(getTensorType<T>())
 {
@@ -312,41 +573,47 @@ LlamaBatch<T>::LlamaBatch(int max_batch_size, int max_context_token_num, int ses
     allocator_      = llama_->allocator_;
     cublas_wrapper_ = llama_->cublas_wrapper_;
 
-    requests_.resize(max_batch_size);
-    request_seq_len_limit_.resize(max_batch_size);
-    cached_seq_.resize(max_batch_size);
+    for (auto& s : states_) {
+        s.requests.resize(max_batch_size);
+        s.sequences.resize(max_batch_size);
+        s.seq_len_limit.resize(max_batch_size);
+    }
 
-    allocatePersistantBuffer(max_batch_size);
+    state_    = &states_[0];
+    back_     = &states_[1];
+    incoming_ = &states_[2];
+
+    AllocateBuffer(max_batch_size, session_len_);
+    AllocatePersistantBuffer(max_batch_size);
 }
 
 template<typename T>
-void LlamaBatch<T>::initializeSampling(int infer_request_count)
+void LlamaBatch<T>::InitializeSampling()
 {
+    const int batch_size = state_->size;
     TensorMap inputs;
     for (const auto& param : sampling_params_) {
+        // find an exemplar that matches the param name
         const Tensor* ptr{};
-        for (int i = 0; i < batch_size_; ++i) {
-            if (requests_[i]->inputs[rank_].isExist(param.first)) {
-                ptr = &requests_[i]->inputs[rank_].at(param.first);
+        for (int i = 0; i < batch_size; ++i) {
+            if (state_->requests[i]->inputs[rank_].isExist(param.first)) {
+                ptr = &state_->requests[i]->inputs[rank_].at(param.first);
                 break;
             }
         }
+        // fill the batch of the param
         if (ptr) {
             const auto& ref   = *ptr;
             auto        shape = ref.shape;
             FT_CHECK(shape[0] == 1);
-            shape[0]                = batch_size_;
+            shape[0]                = batch_size;
             const int size_in_bytes = ref.sizeBytes();
-            check_cuda_error(cudaMemsetAsync(param.second, 0, size_in_bytes * batch_size_, stream_));
-            for (int i = 0; i < batch_size_; ++i) {
-                if (requests_[i]->inputs[rank_].isExist(param.first)) {
-                    auto& src = requests_[i]->inputs[rank_].at(param.first);
+            Clear((std::byte*)param.second, size_in_bytes * batch_size);
+            for (int i = 0; i < batch_size; ++i) {
+                if (state_->requests[i]->inputs[rank_].isExist(param.first)) {
+                    auto& src = state_->requests[i]->inputs[rank_].at(param.first);
                     FT_CHECK(ref.shape == src.shape);
-                    check_cuda_error(cudaMemcpyAsync((uint8_t*)param.second + size_in_bytes * i,
-                                                     src.getPtr<void>(),
-                                                     size_in_bytes,
-                                                     cudaMemcpyDefault,
-                                                     stream_));
+                    Copy(src.getPtr<std::byte>(), size_in_bytes, (std::byte*)param.second + size_in_bytes * i);
                 }
             }
             inputs.insert({param.first, {ref.where, ref.type, shape, param.second}});
@@ -358,35 +625,28 @@ void LlamaBatch<T>::initializeSampling(int infer_request_count)
 
     inputs_ = std::move(inputs);
 
-    llama_->dynamic_decode_layer_->setup(batch_size_, 1, &inputs_);
-
-    for (int i = 0; i < batch_size_; ++i) {
-        // recover random states if not a new request or new request w/o "random_seed"
-        if (i < batch_size_ - infer_request_count || !requests_[i]->inputs[rank_].isExist("random_seed")) {
-            check_cuda_error(cudaMemcpyAsync(llama_->dynamic_decode_layer_->topk_curandstate_buf() + i,
-                                             (curandState_t*)topk_curandstate_buf_ + i,
-                                             sizeof(curandState_t),
-                                             cudaMemcpyDefault,
-                                             stream_));
-            check_cuda_error(cudaMemcpyAsync(llama_->dynamic_decode_layer_->topp_curandstate_buf() + i,
-                                             (curandState_t*)topp_curandstate_buf_ + i,
-                                             sizeof(curandState_t),
-                                             cudaMemcpyDefault,
-                                             stream_));
+    llama_->dynamic_decode_layer_->setup(batch_size, 1, &inputs_);
+
+    // recover random states if not a new request
+    for (int i = 0; i < batch_size; ++i) {
+        if (!state_->requests[i]->start_flag) {
+            LoadRandomState(*state_, i);
         }
     }
 
-    handleOptArg(&inputs_, "end_id", end_ids_buf_, llama_->end_id_, batch_size_);
+    handleOptArg(&inputs_, "end_id", end_ids_buf_, llama_->end_id_, batch_size);
     cudaStreamSynchronize(0);
 }
 
 template<typename T>
-void LlamaBatch<T>::initializeGeneration()
+void LlamaBatch<T>::InitializeGeneration()
 {
-    max_context_len_ = *std::max_element(h_context_length_buf_, h_context_length_buf_ + batch_size_);
+    const int batch_size = state_->size;
+
+    max_context_len_ = *std::max_element(state_->h_context_length, state_->h_context_length + batch_size);
 
-    check_cuda_error(cudaMemsetAsync(token_ids_buf_, 0, sizeof(int) * batch_size_ * session_len_ * 2, stream_));
-    invokeTransposeAxis01(token_ids_buf_, output_ids_buf_, batch_size_, session_len_, 1, stream_);
+    Clear(token_ids_buf_, batch_size * session_len_);
+    invokeTransposeAxis01(token_ids_buf_, state_->output_ids, batch_size, session_len_, 1, stream_);
     sync_check_cuda_error();
 
     // token_ids_buf_[s, b]
@@ -395,56 +655,31 @@ void LlamaBatch<T>::initializeGeneration()
     // ABCDEFGHi    ->  ABCDEFGHi i
     // ABCDEFGh         ABCDEFGh  h
     // ABCd             ABCd      d
-    for (int i = 0; i < batch_size_; ++i) {
+    for (int i = 0; i < batch_size; ++i) {
         auto token_ids = token_ids_buf_ + i;
-        auto p_src     = h_context_length_buf_[i] - 1;
+        auto p_src     = state_->h_context_length[i] - 1;
         auto p_dst     = max_context_len_ - 1;
         if (p_src != p_dst) {  // dst and src of `cudaMemcpyAsync` must not overlap
-            check_cuda_error(cudaMemcpyAsync(token_ids + p_dst * batch_size_,
-                                             token_ids + p_src * batch_size_,
-                                             sizeof(int),
-                                             cudaMemcpyDefault,
-                                             stream_));
+            Copy(token_ids + p_src * batch_size, 1, token_ids + p_dst * batch_size);
         }
     }
 
-    check_cuda_error(cudaMemcpyAsync(
-        context_length_buf_, h_context_length_buf_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_));
-    check_cuda_error(cudaMemcpyAsync(
-        k_cache_ptr_buf_, h_k_cache_ptr_buf_, sizeof(uintptr_t) * batch_size_, cudaMemcpyDefault, stream_));
-    check_cuda_error(cudaMemcpyAsync(
-        v_cache_ptr_buf_, h_v_cache_ptr_buf_, sizeof(uintptr_t) * batch_size_, cudaMemcpyDefault, stream_));
-
-    check_cuda_error(
-        cudaMemcpyAsync(sequence_lengths_, context_length_buf_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_));
+    Copy(context_length_buf_, batch_size, sequence_lengths_);
     // `sequence_lengths_` will be increased by dynamic decode
     // note that in decoder and in output "sequence length" has different semantic
     // - in decoder it means length of sequence that has kv cache already computed
     // - in output it means length of all tokens (the last generated token does not have k/v cache computed yet)
-    invokePlusScalar(sequence_lengths_, -1, batch_size_, stream_);
-    sync_check_cuda_error();
-
-    // total_padding_count_
-    // decoding starts at max_context_len
-    check_cuda_error(cudaMemsetAsync(total_padding_count_, 0, sizeof(int) * batch_size_, stream_));
-    invokeUpdatePaddingCount(total_padding_count_,  //
-                             context_length_buf_,
-                             max_context_len_,
-                             batch_size_,
-                             1,
-                             stream_);
+    invokePlusScalar(sequence_lengths_, -1, batch_size, stream_);
     sync_check_cuda_error();
 
     // seq_limit_len_, will be compared to `step` instead of `sequence_length`, so padding len should be accounted for
-    for (int i = 0; i < batch_size_; ++i) {
-        h_seq_limit_len_[i] = request_seq_len_limit_[i] + (max_context_len_ - h_context_length_buf_[i]);
+    for (int i = 0; i < batch_size; ++i) {
+        h_seq_limit_len_[i] = state_->seq_len_limit[i] + (max_context_len_ - state_->h_context_length[i]);
         // mask finished sequences
-        h_finished_buf_[i] = max_context_len_ >= h_seq_limit_len_[i];
+        state_->h_finished[i] = max_context_len_ >= h_seq_limit_len_[i];
     }
-    check_cuda_error(
-        cudaMemcpyAsync(seq_limit_len_, h_seq_limit_len_, sizeof(uint32_t) * batch_size_, cudaMemcpyDefault, stream_));
-    check_cuda_error(
-        cudaMemcpyAsync(finished_buf_, h_finished_buf_, sizeof(bool) * batch_size_, cudaMemcpyDefault, stream_));
+    Copy(h_seq_limit_len_, batch_size, seq_limit_len_);
+    Copy(state_->h_finished, batch_size, finished_buf_);
 
     // ! range of step_ [1, 2 * session_len]
     // consider a sequence with context_len == session_len and another sequence with context_len == 1 and
@@ -452,24 +687,26 @@ void LlamaBatch<T>::initializeGeneration()
     step_ = max_context_len_;
 
     if (rank_ == 0) {
-        TM_LOG_INFO("[initGen] batch_size = %d", (int)batch_size_);
+        TM_LOG_INFO("[initGen] batch_size = %d", (int)batch_size);
         TM_LOG_INFO("[initGen] max_context_len = %d", (int)max_context_len_);
 
         TM_LOG_INFO("[initGen] slot  sequence_id  context_len  seq_limit_len  finished");
-        for (int i = 0; i < batch_size_; ++i) {
+        for (int i = 0; i < batch_size; ++i) {
             TM_LOG_INFO("[initGen] %4d  %11ld  %11d  %13d  %8d",
                         i,
-                        (long)cached_seq_[i].id,
-                        h_context_length_buf_[i],
+                        (long)state_->sequences[i]->id,
+                        state_->h_context_length[i],
                         (int)h_seq_limit_len_[i],
-                        (int)h_finished_buf_[i]);
+                        (int)state_->h_finished[i]);
         }
     }
 }
 
 template<typename T>
-bool LlamaBatch<T>::generate()
+bool LlamaBatch<T>::Generate()
 {
+    const int batch_size = state_->active_size;
+
     constexpr int kLogInterval = 10;
     if (rank_ == 0 && (step_ - 1) % kLogInterval == 0) {
         TM_LOG_INFO("------------------------- step = %d -------------------------", step_ - 1);
@@ -479,36 +716,32 @@ bool LlamaBatch<T>::generate()
 
     std::vector<int> prev;
     if (debug_ && rank_ == 0 && is_first_step) {
-        prev.resize(batch_size_);
-        cudaMemcpyAsync(prev.data(),
-                        token_ids_buf_ + (step_ - 1) * batch_size_,
-                        sizeof(int) * batch_size_,
-                        cudaMemcpyDefault,
-                        stream_);
+        prev.resize(batch_size);
+        Copy(token_ids_buf_ + (step_ - 1) * batch_size, batch_size, prev.data());
     }
 
     // embeddingLookup(step_ - 1);
     llama_->embeddingLookup(decoder_input_buf_,  //
                             token_ids_buf_,
-                            batch_size_,
+                            batch_size,
                             step_ - 1);
 
     llama_->decoderForward(decoder_output_buf_,
-                           k_cache_ptr_buf_,
-                           v_cache_ptr_buf_,
+                           k_block_ptrs_,
+                           v_block_ptrs_,
                            decoder_input_buf_,
                            sequence_lengths_,
-                           total_padding_count_,
                            finished_buf_,
+                           cu_block_counts_,
                            step_,
                            0,
                            session_len_,
-                           batch_size_);
+                           batch_size);
 
     llama_->postDecodeEmbedding(logits_buf_,  //
                                 local_logits_buf_,
                                 decoder_output_buf_,
-                                batch_size_);
+                                batch_size);
 
     // stop-words & bad-words require the matched tokens to be contiguous, so item size > 1 is
     // not supported yet.
@@ -527,13 +760,12 @@ bool LlamaBatch<T>::generate()
                           0,
                           max_context_len_,
                           session_len_ * 2,
-                          batch_size_);
+                          batch_size);
 
     if (debug_ && rank_ == 0) {
-        std::vector<int> curr(batch_size_);
+        std::vector<int> curr(batch_size);
 
-        cudaMemcpyAsync(
-            curr.data(), token_ids_buf_ + step_ * batch_size_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_);
+        Copy(token_ids_buf_ + step_ * batch_size, batch_size, curr.data());
         cudaStreamSynchronize(stream_);
 
         if (is_first_step) {
@@ -559,324 +791,131 @@ bool LlamaBatch<T>::generate()
 }
 
 template<typename T>
-void LlamaBatch<T>::initialize(const std::vector<std::shared_ptr<Request>>& infer_requests)
+void LlamaBatch<T>::ContextDecode()
 {
-    FT_CHECK(batch_size_ + infer_requests.size() <= max_batch_size_);
-
-    const int infer_request_count = infer_requests.size();
-
-    allocateBuffer(batch_size_ + infer_request_count, session_len_);
-
-    // handle infer requests
-    std::vector<int>       tmp_input_length(infer_request_count);
-    std::vector<CachedSeq> tmp_cached_seq;
-    tmp_cached_seq.reserve(infer_request_count);
-
-    int tmp_max_input_length = 0;
-    for (int i = 0; i < infer_request_count; ++i) {
-        auto& r = *infer_requests[i];
-
-        LlamaCacheManager::Sequence seq{};
-        if (r.start_flag) {
-            seq = llama_->kv_cache_mgr_->create(r.id, stream_);
-        }
-        else {
-            seq = llama_->kv_cache_mgr_->fetch(r.id, stream_);
-        }
-
-        const int step = r.inputs[rank_].getVal<int>("step", -1);
-        if (step >= 0) {
-            if (step <= seq.token_ids.size()) {
-                seq.token_ids.resize(step);
-                seq.cache_len = std::min(seq.cache_len, (size_t)step);
-            }
-            else if (rank_ == 0) {
-                TM_LOG_WARNING("[initialize] Skipping invalid step (%d) setting for ID %ld", step, (long)seq.id);
-            }
-        }
+    const auto batch_size = state_->active_size;
 
-        // input length with missing cache accounted for
-        int actual_input_len = r.inputs[rank_].getVal<int>("input_lengths") + (seq.token_ids.size() - seq.cache_len);
-
-        // insert `start_id` for empty sequences
-        if (seq.token_ids.empty() && actual_input_len == 0) {
-            seq.token_ids.push_back(llama_->start_id_);
-            seq.cache_len    = 0;
-            actual_input_len = seq.token_ids.size() - seq.cache_len;
+    int base = -1;
+    for (int i = 0; i < batch_size; ++i) {
+        if (h_input_length_buf_[i] > 1) {
+            base = i;
+            break;
         }
-
-        tmp_input_length[i] = actual_input_len;
-
-        tmp_max_input_length = std::max((int)tmp_max_input_length, actual_input_len);
-        tmp_cached_seq.push_back(std::move(seq));
     }
-
-    FT_CHECK(tmp_max_input_length > 0);
-    const int max_input_length = tmp_max_input_length;
-
-    // arrange requests in ascending order w.r.t actual input lengths, so that requests need context decoding will
-    // be together
-    {
-        std::vector<int> idxs(tmp_input_length.size());
-        std::iota(idxs.begin(), idxs.end(), 0);
-        std::sort(idxs.begin(), idxs.end(), [&](int i, int j) { return tmp_input_length[i] < tmp_input_length[j]; });
-        for (int i = 0; i < idxs.size(); ++i) {
-            requests_[batch_size_ + i]   = infer_requests[idxs[i]];
-            cached_seq_[batch_size_ + i] = tmp_cached_seq[idxs[i]];
-        }
+    if (base == -1) {
+        TM_LOG_INFO("[decodeContext] Context decoding is not needed.");
+        return;
     }
 
-    const int count = batch_size_ + infer_requests.size();
-
-    std::vector<int> tmp_input_len(count);
-
-    for (int i = batch_size_; i < count; ++i) {
-        const auto& seq = cached_seq_[i];
-
-        h_input_length_buf_[i] = requests_[i]->inputs[rank_].getVal<int>("input_lengths");
-        tmp_input_len[i]       = h_input_length_buf_[i];
-        // prepare output ids
-        // <--------> max_context_len
-        // aaaAAAA
-        // bbbbBBBBBB
-        // ccCCC
-        auto output_ids_ptr = output_ids_buf_ + i * session_len_;
-
-        // clear the persistent buffer to prevent leaking previous conversation
-        check_cuda_error(cudaMemsetAsync(output_ids_ptr, 0, sizeof(int) * session_len_, stream_));
+    for (int i = base; i < batch_size; ++i) {
+        const auto& seq     = *state_->sequences[i];
+        const int   missing = state_->h_context_length[i] - seq.cache_len;
+        FT_CHECK(missing > 1);
+        Copy(state_->output_ids + i * session_len_ + seq.cache_len, missing, input_ids_buf_ + i * session_len_);
+        h_input_length_buf_[i]   = missing;
+        h_history_length_buf_[i] = seq.cache_len;
+    }
 
-        if (!seq.token_ids.empty()) {
-            check_cuda_error(cudaMemcpyAsync(output_ids_ptr,  //
-                                             seq.token_ids.data(),
-                                             sizeof(int) * seq.token_ids.size(),
-                                             cudaMemcpyDefault,
-                                             stream_));
-            output_ids_ptr += seq.token_ids.size();
-        }
+    Copy(h_input_length_buf_, batch_size, input_length_buf_);
+    Copy(h_history_length_buf_, batch_size, history_length_buf_);
 
-        if (h_input_length_buf_[i]) {
-            auto input_ids_ptr = requests_[i]->inputs[rank_].getPtr<int>("input_ids");
-            check_cuda_error(cudaMemcpyAsync(output_ids_ptr,  //
-                                             input_ids_ptr,
-                                             sizeof(int) * h_input_length_buf_[i],
-                                             cudaMemcpyDefault,
-                                             stream_));
-        }
+    check_cuda_error(cudaStreamSynchronize(stream_));
+    const auto tick = std::chrono::high_resolution_clock::now();
 
-        if (!requests_[i]->start_flag && !seq.random_state_.empty()) {
-            check_cuda_error(cudaMemcpyAsync((curandState_t*)topk_curandstate_buf_ + i,
-                                             seq.random_state_.data(),
-                                             sizeof(curandState_t),
-                                             cudaMemcpyDefault,
-                                             stream_));
-            check_cuda_error(cudaMemcpyAsync((curandState_t*)topp_curandstate_buf_ + i,
-                                             seq.random_state_.data() + sizeof(curandState_t),
-                                             sizeof(curandState_t),
-                                             cudaMemcpyDefault,
-                                             stream_));
-        }
+    const int context_decode_count = batch_size - base;
+    if (rank_ == 0) {
+        TM_LOG_INFO("[decodeContext] base = %d, count = %d", base, context_decode_count);
     }
-
-    for (int i = batch_size_; i < count; ++i) {
-        const auto& seq           = cached_seq_[i];
-        const int   missed        = (int)seq.token_ids.size() - seq.cache_len;
-        auto        input_ids_buf = input_ids_buf_ + i * session_len_;
-        FT_CHECK(missed >= 0);
-        if (missed > 0) {
-            check_cuda_error(cudaMemcpyAsync(input_ids_buf,  //
-                                             seq.token_ids.data() + seq.cache_len,
-                                             sizeof(int) * missed,
-                                             cudaMemcpyDefault,
-                                             stream_));
-            input_ids_buf += missed;
-        }
-        auto& input_ids = requests_[i]->inputs[rank_].at("input_ids");
-        check_cuda_error(cudaMemcpyAsync(input_ids_buf,  //
-                                         input_ids.getPtr<int>(),
-                                         sizeof(int) * h_input_length_buf_[i],
-                                         cudaMemcpyDefault,
-                                         stream_));
-        h_input_length_buf_[i] += missed;
-        h_history_length_buf_[i] = seq.cache_len;
-        h_context_length_buf_[i] = h_input_length_buf_[i] + h_history_length_buf_[i];
-
-        const int request_output_len = requests_[i]->inputs[rank_].getVal<int>("request_output_len");
-        request_seq_len_limit_[i]    = h_context_length_buf_[i] + request_output_len;
-        // `length_criterion` sets finish flag when step >= seq_limit_len, however when step == seq_limit_len
-        // the actual sequence length is seq_limit_len + 1, hence seq_limit_len must truncated to session_len - 1
-        if (request_seq_len_limit_[i] >= session_len_) {
-            request_seq_len_limit_[i] = session_len_ - 1;
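+    // Temporarily decrement the device-side input/context lengths by one: the last input
+    // token is not prefilled here but consumed by the first generation step; the matching
+    // +1 calls after the sub-batch loop restore the original values.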
+    invokePlusScalar(input_length_buf_ + base, -1, context_decode_count, stream_);
+    invokePlusScalar(context_length_buf_ + base, -1, context_decode_count, stream_);
+
+    auto get_input_len   = [this](int index) { return h_input_length_buf_[index] - 1; };
+    auto get_context_len = [this](int index) { return state_->h_context_length[index] - 1; };
+
+    std::vector<int> decode_indices{base};
+    std::vector<int> decode_lengths{get_input_len(base)};
+
+    auto token_num       = get_input_len(base);
+    auto max_input_len   = get_input_len(base);
+    auto max_context_len = get_context_len(base);
+    auto offset          = base;
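+    // Greedily split [base, batch_size) into sub-batches whose accumulated token count fits
+    // the `max_context_token_num_` budget; each sub-batch is context-decoded in one pass.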
+    for (int i = offset + 1; i <= batch_size; ++i) {
+        if (i == batch_size || token_num + state_->h_context_length[i] > max_context_token_num_) {
+            const int context_decode_batch_size = i - offset;
             if (rank_ == 0) {
-                const int trunc_output_len = request_seq_len_limit_[i] - h_context_length_buf_[i];
-                TM_LOG_WARNING(
-                    "[initialize] [%ld] total sequence length (%d + %d) exceeds session_len (%d), request_output_len is truncated to %d",
-                    (long)seq.id,
-                    h_context_length_buf_[i],
-                    request_output_len,
-                    (int)session_len_,
-                    trunc_output_len);
+                TM_LOG_INFO(
+                    "[decodeContext] offset = %d, batch_size = %d, token_num = %d, max_input_len = %d, max_context_len = %d",
+                    offset,
+                    context_decode_batch_size,
+                    token_num,
+                    max_input_len,
+                    max_context_len);
             }
-        }
-
-        h_k_cache_ptr_buf_[i] = (uint64_t)seq.k_cache;
-        h_v_cache_ptr_buf_[i] = (uint64_t)seq.v_cache;
-    }
-
-    const int max_context_len = *std::max_element(h_context_length_buf_ + batch_size_, h_context_length_buf_ + count);
-
-    batch_size_      = count;
-    max_context_len_ = max_context_len;
-    step_            = max_context_len;
-
-    check_cuda_error(
-        cudaMemcpyAsync(input_length_buf_, h_input_length_buf_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_));
-    check_cuda_error(cudaMemcpyAsync(
-        history_length_buf_, h_history_length_buf_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_));
-    check_cuda_error(cudaMemcpyAsync(
-        context_length_buf_, h_context_length_buf_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_));
-    check_cuda_error(cudaMemcpyAsync(
-        k_cache_ptr_buf_, h_k_cache_ptr_buf_, sizeof(uintptr_t) * batch_size_, cudaMemcpyDefault, stream_));
-    check_cuda_error(cudaMemcpyAsync(
-        v_cache_ptr_buf_, h_v_cache_ptr_buf_, sizeof(uintptr_t) * batch_size_, cudaMemcpyDefault, stream_));
-
-    if (llama_->tensor_para_.rank_ == 0) {
-        TM_LOG_INFO("[init] infer_request_count = %d", (int)infer_request_count);
-        TM_LOG_INFO("[init] batch_size = %d", (int)batch_size_);
-        TM_LOG_INFO("[init] session_len = %d", (int)session_len_);
-        TM_LOG_INFO("[init] max_input_length = %d", (int)max_input_length);
-        TM_LOG_INFO("[init] max_context_len = %d", (int)max_context_len);
-        TM_LOG_INFO(
-            "[init] slot  sequence_id  history_len  input_len  context_len  tmp_input_len  token_ids.size  cache_len");
-        for (int i = batch_size_ - infer_request_count; i < batch_size_; ++i) {
-            TM_LOG_INFO("[init] %4d  %11ld  %11d  %9d  %11d  %13d  %14d  %9d",
-                        i,
-                        (int)cached_seq_[i].id,
-                        h_history_length_buf_[i],
-                        h_input_length_buf_[i],
-                        h_context_length_buf_[i],
-                        tmp_input_len[i],
-                        (int)cached_seq_[i].token_ids.size(),
-                        (int)cached_seq_[i].cache_len);
-        }
-    }
-}
-
-template<typename T>
-void LlamaBatch<T>::contextDecode()
-{
-    int base = -1;
-    for (int i = 0; i < batch_size_; ++i) {
-        if (h_input_length_buf_[i] > 1) {
-            base = i;
-            break;
-        }
-    }
-    if (base >= 0) {
-        check_cuda_error(cudaStreamSynchronize(stream_));
-        const auto tick = std::chrono::high_resolution_clock::now();
-
-        const int context_decode_count = batch_size_ - base;
-        if (rank_ == 0) {
-            TM_LOG_INFO("[decodeContext] base = %d, count = %d", base, context_decode_count);
-        }
-        invokePlusScalar(input_length_buf_ + base, -1, context_decode_count, stream_);
-        invokePlusScalar(context_length_buf_ + base, -1, context_decode_count, stream_);
-
-        auto get_input_len   = [this](int index) { return h_input_length_buf_[index] - 1; };
-        auto get_context_len = [this](int index) { return h_context_length_buf_[index] - 1; };
-
-        std::vector<int> decode_indices{base};
-        std::vector<int> decode_lengths{get_input_len(base)};
-
-        auto token_num       = get_input_len(base);
-        auto max_input_len   = get_input_len(base);
-        auto max_context_len = get_context_len(base);
-        auto offset          = base;
-        for (int i = offset + 1; i <= batch_size_; ++i) {
-            if (i == batch_size_ || token_num + h_context_length_buf_[i] > max_context_token_num_) {
-                const int context_decode_batch_size = i - offset;
-                if (rank_ == 0) {
-                    TM_LOG_INFO(
-                        "[decodeContext] offset = %d, batch_size = %d, token_num = %d, max_input_len = %d, max_context_len = %d",
-                        base,
-                        context_decode_batch_size,
-                        token_num,
-                        max_input_len,
-                        max_context_len);
-                }
-                // construct context_decoder_ids w/o padding
-                // aaaa____
-                // bb______ -> aaaabbcccccccc
-                // cccccccc
-                auto context_decoder_ids = context_decoder_ids_buf_;
-                for (int j = offset; j < i; ++j) {
-                    check_cuda_error(cudaMemcpyAsync(context_decoder_ids,
-                                                     input_ids_buf_ + j * session_len_,
-                                                     sizeof(int) * get_input_len(j),
-                                                     cudaMemcpyDefault,
-                                                     stream_));
-                    context_decoder_ids += get_input_len(j);
-                }
-                llama_->contextDecode(nullptr,
-                                      k_cache_ptr_buf_ + offset,
-                                      v_cache_ptr_buf_ + offset,
-                                      context_decoder_input_buf_,
-                                      context_decoder_output_buf_,
-                                      context_decoder_ids_buf_,
-                                      input_length_buf_ + offset,
-                                      history_length_buf_ + offset,
-                                      context_length_buf_ + offset,
-                                      token_num,
-                                      max_input_len,
-                                      max_context_len,
-                                      session_len_,
-                                      context_decode_batch_size);
-
-                // compute logits of inputs if requested
-                outputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths);
-
-                if (i < batch_size_) {
-                    // initialize next sub-batch
-                    token_num       = get_input_len(i);
-                    max_input_len   = get_input_len(i);
-                    max_context_len = get_context_len(i);
-                    offset          = i;
-
-                    decode_indices = {i};
-                    decode_lengths = {get_input_len(i)};
-                }
+            // construct context_decoder_ids w/o padding
+            // aaaa____
+            // bb______ -> aaaabbcccccccc
+            // cccccccc
+            auto context_decoder_ids = context_decoder_ids_buf_;
+            for (int j = offset; j < i; ++j) {
+                context_decoder_ids = Copy(input_ids_buf_ + j * session_len_, get_input_len(j), context_decoder_ids);
             }
-            else {
-                // add to current sub-batch
-                token_num += get_input_len(i);
-                max_input_len   = std::max(max_input_len, get_input_len(i));
-                max_context_len = std::max(max_context_len, get_context_len(i));
-
-                decode_indices.push_back(i);
-                decode_lengths.push_back(get_input_len(i));
+            llama_->contextDecode(nullptr,
+                                  k_block_ptrs_,
+                                  v_block_ptrs_,
+                                  context_decoder_input_buf_,
+                                  context_decoder_output_buf_,
+                                  context_decoder_ids_buf_,
+                                  input_length_buf_ + offset,
+                                  history_length_buf_ + offset,
+                                  context_length_buf_ + offset,
+                                  cu_block_counts_ + offset,
+                                  token_num,
+                                  max_input_len,
+                                  max_context_len,
+                                  session_len_,
+                                  context_decode_batch_size);
+
+            // compute logits of inputs if requested
+            OutputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths);
+
+            if (i < batch_size) {
+                // initialize next sub-batch
+                token_num       = get_input_len(i);
+                max_input_len   = get_input_len(i);
+                max_context_len = get_context_len(i);
+                offset          = i;
+
+                decode_indices = {i};
+                decode_lengths = {get_input_len(i)};
             }
         }
+        else {
+            // add to current sub-batch
+            token_num += get_input_len(i);
+            max_input_len   = std::max(max_input_len, get_input_len(i));
+            max_context_len = std::max(max_context_len, get_context_len(i));
 
-        invokePlusScalar(context_length_buf_ + base, 1, context_decode_count, stream_);
-        invokePlusScalar(input_length_buf_ + base, 1, context_decode_count, stream_);
-
-        for (int i = offset; i < batch_size_; ++i) {
-            h_input_length_buf_[i] = 0;
+            decode_indices.push_back(i);
+            decode_lengths.push_back(get_input_len(i));
         }
+    }
 
-        check_cuda_error(cudaStreamSynchronize(stream_));
-        const auto tock = std::chrono::high_resolution_clock::now();
-        if (rank_ == 0) {
-            TM_LOG_INFO("[decodeContext] %.2f ms", std::chrono::duration<float, std::milli>(tock - tick).count());
-        }
+    invokePlusScalar(context_length_buf_ + base, 1, context_decode_count, stream_);
+    invokePlusScalar(input_length_buf_ + base, 1, context_decode_count, stream_);
+
+    for (int i = offset; i < batch_size; ++i) {
+        h_input_length_buf_[i] = 0;
     }
-    else if (rank_ == 0) {
-        TM_LOG_INFO("[decodeContext] Context decoding is not needed.");
+
+    check_cuda_error(cudaStreamSynchronize(stream_));
+    const auto tock = std::chrono::high_resolution_clock::now();
+    if (rank_ == 0) {
+        TM_LOG_INFO("[decodeContext] %.2f ms", std::chrono::duration<float, std::milli>(tock - tick).count());
     }
 }
 
 template<typename T>
-void LlamaBatch<T>::outputContextLogits(T*                      context_decoder_output,
+void LlamaBatch<T>::OutputContextLogits(T*                      context_decoder_output,
                                         const std::vector<int>& indices,
                                         const std::vector<int>& lengths)
 {
@@ -885,7 +924,7 @@ void LlamaBatch<T>::outputContextLogits(T*                      context_decoder_
     {
         bool is_return_logits = false;
         for (int k = 0; k < indices.size(); ++k) {
-            auto& request = requests_[indices[k]];
+            auto& request = state_->requests[indices[k]];
             output_logits.push_back(request->outputs[rank_].getPtr<float>("logits", nullptr));
             num_token += lengths[k];
             if (output_logits.back()) {
@@ -899,8 +938,9 @@ void LlamaBatch<T>::outputContextLogits(T*                      context_decoder_
 
     if (context_logits_buf_ == nullptr) {
         NcclGuard guard(llama_->tensor_para_, stream_, true);
-        context_logits_buf_ = (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_padded_ * max_context_token_num_);
-        const auto tp       = llama_->tensor_para_.world_size_;
+        context_logits_buf_ =
+            (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_padded_ * max_context_token_num_);
+        const auto tp = llama_->tensor_para_.world_size_;
         if (tp > 1) {
             FT_CHECK(llama_->vocab_size_padded_ % tp == 0);
             const auto local_vocab_size = llama_->vocab_size_padded_ / tp;
@@ -915,129 +955,72 @@ void LlamaBatch<T>::outputContextLogits(T*                      context_decoder_
 
     for (int k = 0; k < indices.size(); ++k) {
         if (output_logits[k]) {
-            check_cuda_error(cudaMemcpyAsync(output_logits[k],
-                                             logits,
-                                             sizeof(float) * llama_->vocab_size_ * lengths[k],
-                                             cudaMemcpyDefault,
-                                             stream_));
+            Copy(logits, llama_->vocab_size_ * lengths[k], output_logits[k]);
         }
         logits += llama_->vocab_size_padded_ * lengths[k];
     }
 }
 
 template<typename T>
-void LlamaBatch<T>::finish()
+int LlamaBatch<T>::Finish()
 {
+    const int batch_size = state_->active_size;
+
     // secure info needed by `synchronize()`
-    check_cuda_error(
-        cudaMemcpyAsync(h_finished_buf_, finished_buf_, sizeof(bool) * batch_size_, cudaMemcpyDefault, stream_));
-    check_cuda_error(
-        cudaMemcpyAsync(h_sequence_lengths_, sequence_lengths_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_));
+    Copy(finished_buf_, batch_size, state_->h_finished);
+    Copy(sequence_lengths_, batch_size, h_sequence_lengths_);
 
-    setOutputTensors(step_);
+    SetOutputTensors(step_);
 
     check_cuda_error(cudaStreamSynchronize(stream_));
 
-    for (int i = 0; i < batch_size_; ++i) {
-        FT_CHECK(requests_[i] != nullptr);
-        if (requests_[i]->stream_cb && rank_ == 0) {
-            requests_[i]->stream_cb(&requests_[i]->outputs[rank_].get());
+    for (int i = 0; i < batch_size; ++i) {
+        FT_CHECK(state_->requests[i] != nullptr);
+        if (state_->requests[i]->stream_cb && rank_ == 0) {
+            state_->requests[i]->stream_cb(&state_->requests[i]->outputs[rank_].get());
         }
     }
 
     if (debug_ && rank_ == 0) {
         std::stringstream ss;
-        for (int i = 0; i < batch_size_; ++i) {
-            ss << (i ? ", " : "") << "(" << h_sequence_lengths_[i] << "," << h_finished_buf_[i] << ")";
+        for (int i = 0; i < batch_size; ++i) {
+            ss << (i ? ", " : "") << "(" << h_sequence_lengths_[i] << "," << state_->h_finished[i] << ")";
         }
         TM_LOG_INFO("[finish] [%s]", ss.str().c_str());
     }
 
-    for (int i = 0; i < batch_size_; ++i) {
-        if (h_finished_buf_[i]) {
-            finishRequest(i, false);
-            ++finished_count_;
-        }
-    }
-}
-
-template<typename T>
-void LlamaBatch<T>::synchronize()
-{
-    // compact
-    int idx = 0;
-    for (int i = 0; i < batch_size_; ++i) {
-        if (requests_[i]) {
-            h_input_length_buf_[idx]   = 0;
-            h_history_length_buf_[idx] = 0;
-
-            h_context_length_buf_[idx] = h_sequence_lengths_[i] + 1;
-            h_sequence_lengths_[idx]   = h_context_length_buf_[idx];
-
-            check_cuda_error(cudaMemcpyAsync((curandState_t*)topk_curandstate_buf_ + idx,
-                                             llama_->dynamic_decode_layer_->topk_curandstate_buf() + i,
-                                             sizeof(curandState_t),
-                                             cudaMemcpyDefault,
-                                             stream_));
-            check_cuda_error(cudaMemcpyAsync((curandState_t*)topp_curandstate_buf_ + idx,
-                                             llama_->dynamic_decode_layer_->topp_curandstate_buf() + i,
-                                             sizeof(curandState_t),
-                                             cudaMemcpyDefault,
-                                             stream_));
-
-            if (i != idx) {
-                h_finished_buf_[idx]        = h_finished_buf_[i];
-                request_seq_len_limit_[idx] = request_seq_len_limit_[i];
-
-                h_k_cache_ptr_buf_[idx] = h_k_cache_ptr_buf_[i];
-                h_v_cache_ptr_buf_[idx] = h_v_cache_ptr_buf_[i];
-
-                requests_[idx]   = std::move(requests_[i]);
-                cached_seq_[idx] = std::move(cached_seq_[i]);
-                check_cuda_error(cudaMemcpyAsync(output_ids_buf_ + idx * session_len_,
-                                                 output_ids_buf_ + i * session_len_,
-                                                 sizeof(int) * h_context_length_buf_[idx],
-                                                 cudaMemcpyDefault,
-                                                 stream_));
-            }
-            ++idx;
+    int finished_count{};
+    for (int i = 0; i < batch_size; ++i) {
+        if (state_->requests[i] && state_->h_finished[i]) {
+            FinishRequest(i, false);
+            ++finished_count;
         }
     }
-    batch_size_ = idx;
-
-    if (rank_ == 0) {
-        TM_LOG_INFO("[synchronize] batch_size = %d", (int)batch_size_);
-    }
-
-    finished_count_ = 0;
+    return finished_count;
 }
 
 template<typename T>
-void LlamaBatch<T>::setOutputTensors(int max_gen_step)
+void LlamaBatch<T>::SetOutputTensors(int max_gen_step)
 {
+    const auto batch_size = state_->active_size;
     // [s,b] -> [b,s] and skip padding in [context_len, max_context_len)
-    invokeGatherOutput(output_ids_buf_,
+    invokeGatherOutput(state_->output_ids,
                        token_ids_buf_,
                        context_length_buf_,
                        max_context_len_,
                        max_gen_step,
                        session_len_,
-                       batch_size_,
+                       batch_size,
                        stream_);
     sync_check_cuda_error();
 
     /// TODO: fuse the loop into a single kernel
-    for (int i = 0; i < batch_size_; ++i) {
-        if (requests_[i]) {
-            auto& output_ids      = requests_[i]->outputs[rank_].at("output_ids");
-            auto& sequence_length = requests_[i]->outputs[rank_].at("sequence_length");
-            check_cuda_error(cudaMemcpyAsync(output_ids.getPtr<int>(),
-                                             output_ids_buf_ + i * session_len_,
-                                             sizeof(int) * output_ids.shape.at(2),
-                                             cudaMemcpyDefault,
-                                             stream_));
-            check_cuda_error(cudaMemcpyAsync(
-                sequence_length.getPtr<int>(), sequence_lengths_ + i, sizeof(int), cudaMemcpyDefault, stream_));
+    for (int i = 0; i < batch_size; ++i) {
+        if (state_->requests[i]) {
+            auto& output_ids      = state_->requests[i]->outputs[rank_].at("output_ids");
+            auto& sequence_length = state_->requests[i]->outputs[rank_].at("sequence_length");
+            Copy(state_->output_ids + i * session_len_, output_ids.shape.at(2), output_ids.getPtr<int>());
+            Copy(sequence_lengths_ + i, 1, sequence_length.getPtr<int>());
             if (max_gen_step > max_context_len_) {  // +1 for newly generated token
                 invokePlusScalar(sequence_length.getPtr<int>(), 1, 1, stream_);
             }
@@ -1046,19 +1029,15 @@ void LlamaBatch<T>::setOutputTensors(int max_gen_step)
 }
 
 template<typename T>
-void LlamaBatch<T>::finishRequest(int index, bool force_end)
+void LlamaBatch<T>::FinishRequest(int index, bool force_end)
 {
     if (rank_ == 0) {
-        TM_LOG_INFO("[finishRequest] slot = %d, id = %lu", index, (long)requests_[index]->id);
+        TM_LOG_INFO("[finishRequest] slot = %d, id = %lu", index, (long)state_->requests[index]->id);
     }
 
     if (debug_ && rank_ == 0) {
         std::vector<int> tokens(h_sequence_lengths_[index] + 1);
-        cudaMemcpyAsync(tokens.data(),
-                        output_ids_buf_ + index * session_len_,
-                        sizeof(int) * tokens.size(),
-                        cudaMemcpyDefault,
-                        stream_);
+        Copy(state_->output_ids + index * session_len_, tokens.size(), tokens.data());
         cudaStreamSynchronize(stream_);
         std::stringstream ss;
         for (const auto& t : tokens) {
@@ -1067,10 +1046,8 @@ void LlamaBatch<T>::finishRequest(int index, bool force_end)
         TM_LOG_INFO("[finishRequest] slot %d, tokens [%s]", index, ss.str().c_str());
     }
 
-    auto&      output_ids_tensor = requests_[index]->outputs[rank_].at("output_ids");
-    const auto output_ids_data   = output_ids_tensor.getPtr<int>();
-    if (requests_[index]->end_flag || force_end) {
-        llama_->kv_cache_mgr_->erase(requests_[index]->id);
+    if (state_->requests[index]->end_flag || force_end) {
+        sequence_manager_->Erase(state_->requests[index]->id);
     }
     else {
         // the last generated token is not processed by decoder thus dont have k/v cache
@@ -1078,38 +1055,110 @@ void LlamaBatch<T>::finishRequest(int index, bool force_end)
         const int cache_len  = h_sequence_lengths_[index];
         const int output_len = n_steps > 0 ? cache_len + 1 : cache_len;
 
-        auto& seq = cached_seq_[index];
+        auto& seq = *state_->sequences[index];
 
         seq.cache_len = cache_len;
 
         // update token IDs
-        seq.token_ids.resize(output_len);
-        check_cuda_error(cudaMemcpyAsync(
-            seq.token_ids.data(), output_ids_data, sizeof(int) * output_len, cudaMemcpyDefault, stream_));
+        seq.tokens.resize(output_len);
+
+        const auto output_ids_data = state_->requests[index]->outputs[rank_].at("output_ids").getPtr<int>();
+        Copy(output_ids_data, output_len, seq.tokens.data());
 
         // update random states
-        seq.random_state_.resize(sizeof(curandState_t) * 2);
-        check_cuda_error(cudaMemcpyAsync(seq.random_state_.data(),
-                                         llama_->dynamic_decode_layer_->topk_curandstate_buf() + index,
-                                         sizeof(curandState_t),
-                                         cudaMemcpyDefault,
-                                         stream_));
-        check_cuda_error(cudaMemcpyAsync(seq.random_state_.data() + sizeof(curandState_t),
-                                         llama_->dynamic_decode_layer_->topp_curandstate_buf() + index,
-                                         sizeof(curandState_t),
-                                         cudaMemcpyDefault,
-                                         stream_));
+        seq.random_state.resize(sizeof(curandState_t) * 2);
+
+        // save random state in host memory
+        if (auto ptr = (curandState_t*)seq.random_state.data()) {
+            Copy(llama_->GetTopKState(index), 1, ptr++);
+            Copy(llama_->GetTopPState(index), 1, ptr++);
+        }
 
         check_cuda_error(cudaStreamSynchronize(stream_));
 
-        llama_->kv_cache_mgr_->update(cached_seq_[index], stream_);
+        sequence_manager_->Update(seq);
     }
 
+    // Notify request completion
     if (rank_ == 0) {
-        requests_[index]->signal.set_value(0);
+        state_->requests[index]->signal.set_value(0);
+    }
+
+    state_->requests[index]  = nullptr;
+    state_->sequences[index] = nullptr;
+}
+
+template<typename T>
+void LlamaBatch<T>::InternalThreadEntry(int device_id)
+{
+    TM_LOG_INFO("[InternalThreadEntry] %d", (int)rank_);
+    check_cuda_error(cudaSetDevice(device_id));
+
+    auto& shared_state = llama_->shared_state_;
+
+    auto& request_queue  = shared_state->request_queue;
+    auto& infer_requests = shared_state->infer_requests;
+    auto& stop_requests  = shared_state->stop_requests;
+
+    int finished_count = 0;
+
+    while (1) {
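+    // Serving loop: rank 0 dequeues stop/infer requests, all ranks synchronize at the
+    // barrier, then the batch runs context decoding followed by up to `step_length_`
+    // generation steps before collecting finished requests.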
+        if (rank_ == 0) {
+            const int  free_slot_count = max_batch_size_ - state_->size + finished_count;
+            const bool is_empty        = (free_slot_count == max_batch_size_);
+
+            // will block if state is empty
+            request_queue.dequeue(stop_requests, infer_requests, free_slot_count, is_empty, shared_state->abort);
+
+            if (!shared_state->abort) {
+                RejectInvalidRequests(stop_requests, infer_requests);
+            }
+        }
+
+        // wait while rank-0 is dequeueing
+        shared_state->barrier->wait();
+
+        if (shared_state->abort) {
+            if (state_->size && rank_ == 0) {
+                TM_LOG_WARNING("Active request(s) present (%d) while aborting.", state_->size);
+            }
+            return;
+        }
+
+        ProcessStopRequests(stop_requests);
+
+        ProcessInferRequests(infer_requests);
+
+        // wait while shared stop/infer_requests is being used
+        shared_state->barrier->wait();
+
+        auto modified = Initialize();
+
+        ContextDecode();
+
+        if (state_->active_size) {
+            if (modified) {
+                InitializeGeneration();
+                InitializeSampling();
+            }
+            for (int i = 0; i < step_length_; ++i) {
+                if (!Generate()) {
+                    break;
+                }
+            }
+            finished_count = Finish();
+        }
     }
 
-    requests_[index] = nullptr;
+    FT_CHECK(0);
+}
+
+template<typename T>
+void LlamaBatch<T>::Start()
+{
+    int device_id = -1;
+    check_cuda_error(cudaGetDevice(&device_id));
+    internal_thread_ = std::thread(&LlamaBatch::InternalThreadEntry, this, device_id);
 }
 
 template class LlamaBatch<half>;
diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h
index 280562ffb1..a8bfad6729 100644
--- a/src/turbomind/models/llama/LlamaBatch.h
+++ b/src/turbomind/models/llama/LlamaBatch.h
@@ -2,66 +2,108 @@
 
 #pragma once
 
-#include "src/turbomind/models/llama/LlamaCacheManager.h"
+// #include "src/turbomind/models/llama/LlamaCacheManager.h"
 #include "src/turbomind/models/llama/LlamaNcclGuard.h"
 #include "src/turbomind/models/llama/Request.h"
+#include "src/turbomind/models/llama/SequenceManager.h"
 #include "src/turbomind/utils/allocator.h"
 #include "src/turbomind/utils/cublasMMWrapper.h"
 
 namespace turbomind {
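+// Host-side bookkeeping for one batch of sequences: per-slot context lengths, finish flags,
+// sampling RNG state, output ids and the associated Request/Sequence handles.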
 
+struct BatchState {
+    int*  h_context_length;
+    bool* h_finished;
+
+    void* top_k_curand_state;
+    void* top_p_curand_state;
+    int*  output_ids;  // output ids in [B, S]
+
+    std::vector<int> seq_len_limit;
+
+    std::vector<const Sequence*>          sequences;
+    std::vector<std::shared_ptr<Request>> requests;
+
+    // |<-- existing -->|<-- swap-in -->|<-- inactive -->|
+    int size;
+    int active_size;
+};
+
 template<typename T>
 class LlamaV2;
 
 template<typename T>
 class LlamaBatch {
 public:
-    int size() const noexcept
-    {
-        return batch_size_;
-    };
+    void AllocateBuffer(size_t batch_size, size_t session_len);
+    void AllocatePersistantBuffer(size_t max_batch_size);
+    void FreeBuffer();
 
-    int maxSize() const noexcept
-    {
-        return max_batch_size_;
-    }
+    using Requests = std::vector<std::shared_ptr<Request>>;
 
-    int finishedCount() const noexcept
-    {
-        return finished_count_;
-    }
+    void RejectInvalidRequests(Requests& stop_reqs, Requests& infer_reqs);
 
-    void verifyRequests(std::vector<std::shared_ptr<Request>>& stop_reqs,
-                        std::vector<std::shared_ptr<Request>>& infer_reqs);
-    void handleStopRequests(const std::vector<std::shared_ptr<Request>>& requests);
+    void ProcessStopRequests(const Requests& requests);
 
-    void allocateBuffer(size_t batch_size, size_t session_len);
-    void allocatePersistantBuffer(size_t max_batch_size);
-    void freeBuffer();
+    void ProcessInferRequests(const Requests& requests);
 
-    void initializeSampling(int infer_request_count);
+    bool Initialize();
 
-    void initialize(const std::vector<std::shared_ptr<Request>>& infer_requests);
-    void contextDecode();
+    void ContextDecode();
 
-    void initializeGeneration();
-    bool generate();
+    void InitializeSampling();
+    void InitializeGeneration();
+    bool Generate();
 
-    void finish();
-    void finishRequest(int index, bool force_end);
+    int  Finish();
+    void FinishRequest(int index, bool force_end);
 
-    void synchronize();
-
-    void setOutputTensors(int max_gen_step);
+    void SetOutputTensors(int max_gen_step);
 
     void
-    outputContextLogits(T* context_decoder_output, const std::vector<int>& indices, const std::vector<int>& lengths);
+    OutputContextLogits(T* context_decoder_output, const std::vector<int>& indices, const std::vector<int>& lengths);
 
-    explicit LlamaBatch(int max_batch_size, int max_context_token_num, int session_len, LlamaV2<T>* llama);
+    explicit LlamaBatch(int                              max_batch_size,
+                        int                              max_context_token_num,
+                        int                              session_len,
+                        std::unique_ptr<SequenceManager> sequence_manager,
+                        LlamaV2<T>*                      llama);
 
     ~LlamaBatch()
     {
-        freeBuffer();
+        llama_->shared_state_->request_queue.Abort();
+
+        internal_thread_.join();
+
+        FreeBuffer();
+    }
+
+    void Start();
+
+private:
+    void InternalThreadEntry(int device_id);
+
+    void UpdateSequenceStates(BatchState& state, int index);
+
+    void CopyState(const std::pair<BatchState*, int> _src, const std::pair<BatchState*, int>& _dst);
+
+    void SaveRandomState(BatchState& state, int idx);
+
+    void LoadRandomState(BatchState& state, int idx);
+
+    // analogous to `std::copy_n`: asynchronously copies `count` elements on `stream_` and
+    // returns the advanced destination pointer
+    template<typename U>
+    U* Copy(const U* src, size_t count, U* dst)
+    {
+        check_cuda_error(cudaMemcpyAsync(dst, src, sizeof(U) * count, cudaMemcpyDefault, stream_));
+        return dst + count;
+    }
+
+    template<typename U>
+    U* Clear(U* data, size_t count)
+    {
+        check_cuda_error(cudaMemsetAsync(data, 0, sizeof(U) * count, stream_));
+        return data + count;
     }
 
 private:
@@ -70,11 +112,11 @@ class LlamaBatch {
     const int  session_len_;
     const int  rank_;
     const bool debug_;
+    const int  step_length_;
 
     LlamaV2<T>* const llama_;
 
-    // active requests
-    std::vector<std::shared_ptr<Request>> requests_;
+    std::unique_ptr<SequenceManager> sequence_manager_;
 
     T*   context_decoder_input_buf_{};   // CTXDEC
     T*   context_decoder_output_buf_{};  // CTXDEC
@@ -83,16 +125,14 @@ class LlamaBatch {
     T* decoder_input_buf_{};   // CTXDEC, GENERATE
     T* decoder_output_buf_{};  // CTXDEC, GENERATE
 
-    int* input_ids_buf_{};       // input token ids + cache missed token ids, CTXDEC
-    int* input_length_buf_{};    // input + cache missed length, CTXDEC, GENERATE
-    int* history_length_buf_{};  // history length, CTXDEC
-    int* context_length_buf_{};  // history length + input_length, CTXDEC, GENERATE
-
-    int* total_padding_count_{};  // GENERATE
-    int* sequence_lengths_{};     // current sequence length
-
-    uint64_t* k_cache_ptr_buf_{};
-    uint64_t* v_cache_ptr_buf_{};
+    int*       input_ids_buf_{};       // input token ids + cache missed token ids, CTXDEC
+    int*       input_length_buf_{};    // input + cache missed length, CTXDEC, GENERATE
+    int*       history_length_buf_{};  // history length, CTXDEC
+    int*       context_length_buf_{};  // history length + input_length, CTXDEC, GENERATE
+    int*       sequence_lengths_{};    // current sequence length
+    int*       cu_block_counts_{};
+    uintptr_t* k_block_ptrs_{};
+    uintptr_t* v_block_ptrs_{};
 
     float* logits_buf_{};        // combined logits
     float* local_logits_buf_{};  // tensor parallel local logits
@@ -100,8 +140,7 @@ class LlamaBatch {
     float* local_context_logits_buf_{};
 
     // used by dynamic decoder
-    int*      token_ids_buf_{};   // all token IDs in [S, B], indexed using `step`
-    int*      output_ids_buf_{};  // output ids in [B, S]
+    int*      token_ids_buf_{};  // all token IDs in [S, B], indexed using `step`
     int*      end_ids_buf_{};
     bool*     finished_buf_{};
     uint32_t* seq_limit_len_{};
@@ -110,12 +149,11 @@ class LlamaBatch {
     int*       h_input_ids_buf_{};
     int*       h_input_length_buf_{};
     int*       h_history_length_buf_{};
-    int*       h_context_length_buf_{};
     int*       h_sequence_lengths_{};
-    bool*      h_finished_buf_{};
-    uintptr_t* h_k_cache_ptr_buf_{};
-    uintptr_t* h_v_cache_ptr_buf_{};
     uint32_t*  h_seq_limit_len_{};
+    int*       h_cu_block_counts_{};
+    uintptr_t* h_k_block_ptrs_{};
+    uintptr_t* h_v_block_ptrs_{};
 
     int*      stop_words_buf_{};  // [batch_size, 2, kMaxStopWordsLen]
     int*      bad_words_buf_{};
@@ -125,23 +163,21 @@ class LlamaBatch {
     float*    h_repetition_penalty_{};
     uint64_t* h_random_seed_{};
 
-    void* topk_curandstate_buf_{};
-    void* topp_curandstate_buf_{};
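+    // Triple-buffered batch state: `state_` below is the active batch, `incoming_` receives
+    // newly accepted requests and `back_` acts as the spare when slots are swapped (roles
+    // inferred from their usage in LlamaBatch.cc).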
+    BatchState states_[3];
 
-    // hard limits for persistent buffers
-    static constexpr int kMaxStopBadWordsLen = 32;
+    BatchState* state_{};
+    BatchState* back_{};
+    BatchState* incoming_{};
 
-    using CachedSeq = LlamaCacheManager::Sequence;
+    uint64_t request_count_{0};
 
-    std::vector<CachedSeq> cached_seq_;
-    std::vector<int>       request_seq_len_limit_;
+    // hard limits for persistent buffers
+    static constexpr int kMaxStopBadWordsLen = 32;
 
     const DataType data_type_{};
 
-    int batch_size_{};
     int max_context_len_{};
     int step_{};
-    int finished_count_{};
 
     bool is_allocate_persistant_buffer_ = false;
     bool is_allocate_buffer_            = false;
@@ -154,6 +190,8 @@ class LlamaBatch {
     cudaStream_t     stream_{};
     cublasMMWrapper* cublas_wrapper_{};
     IAllocator*      allocator_{};
+
+    std::thread internal_thread_;
 };
 
 }  // namespace turbomind
diff --git a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
index 66bcf7570f..912589df09 100644
--- a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
@@ -116,6 +116,7 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
      *   \param history_lengths [batch_size], int
      *   \param context_lengths [batch_size], int
      *   \param cu_seqlens [batch_size+1], int
+     *   \param cu_block_counts [batch_size+1], int
      *   \param max_seq_len [1], int on cpu
      *   \param is_final_layer [1], bool on cpu
      *   \param layer_id [1], int on cpu
@@ -141,10 +142,11 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
     T* attention_input = input_tensors->at("input_query").getPtr<T>();
     T* attention_mask  = input_tensors->at("attention_mask").getPtr<T>();
 
-    const auto input_length   = input_tensors->at("input_lengths").getPtr<const int>();
-    const auto history_length = input_tensors->at("history_lengths").getPtr<const int>();
-    const auto context_length = input_tensors->at("context_lengths").getPtr<const int>();
-    int*       cu_seqlens     = input_tensors->at("cu_seqlens").getPtr<int>();
+    const auto input_length    = input_tensors->at("input_lengths").getPtr<const int>();
+    const auto history_length  = input_tensors->at("history_lengths").getPtr<const int>();
+    const auto context_length  = input_tensors->at("context_lengths").getPtr<const int>();
+    int*       cu_seqlens      = input_tensors->at("cu_seqlens").getPtr<int>();
+    int*       cu_block_counts = input_tensors->at("cu_block_counts").getPtr<int>();
 
     const auto padding_offset = input_tensors->at("padding_offset").getPtr<int>();
 
diff --git a/src/turbomind/models/llama/LlamaContextDecoder.cc b/src/turbomind/models/llama/LlamaContextDecoder.cc
index f914063a70..def55f41a8 100644
--- a/src/turbomind/models/llama/LlamaContextDecoder.cc
+++ b/src/turbomind/models/llama/LlamaContextDecoder.cc
@@ -109,6 +109,7 @@ void LlamaContextDecoder<T>::forwardSelfAttn(const Session&
         {"input_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.input_length}},
         {"history_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.history_length}},
         {"context_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.context_length}},
+        {"cu_block_counts", input_tensors->at("cu_block_counts")},
         {"max_seq_len", input_tensors->at("max_seq_len")}};
 
     auto& k_cache = *sess.k_cache;
diff --git a/src/turbomind/models/llama/LlamaDecoder.cc b/src/turbomind/models/llama/LlamaDecoder.cc
index 73e95b1353..88dd76b935 100644
--- a/src/turbomind/models/llama/LlamaDecoder.cc
+++ b/src/turbomind/models/llama/LlamaDecoder.cc
@@ -124,6 +124,7 @@ void LlamaDecoder<T>::forwardSelfAttn(const LlamaDecoder::Session&
                                         {MEMORY_GPU, data_type_, {sess.batch_size, hidden_units_}, attn_io});
     const int layer_id = layer;
     self_attention_input_tensors.insert("layer_id", {MEMORY_CPU, TYPE_INT32, {1}, &layer_id});
+    self_attention_input_tensors.insert("cu_block_counts", input_tensors->at("cu_block_counts"));
     auto& k_cache = *sess.k_cache;
     auto& v_cache = *sess.v_cache;
 
diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc
index beaf3c3f6d..a1893736b6 100644
--- a/src/turbomind/models/llama/LlamaV2.cc
+++ b/src/turbomind/models/llama/LlamaV2.cc
@@ -28,6 +28,7 @@
 #include "src/turbomind/models/llama/LlamaNcclGuard.h"
 #include "src/turbomind/models/llama/LlamaWeight.h"
 #include "src/turbomind/models/llama/Request.h"
+#include "src/turbomind/models/llama/SequenceManager.h"
 #include "src/turbomind/models/llama/llama_params.h"
 #include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/Tensor.h"
@@ -86,14 +87,15 @@ LlamaV2<T>::LlamaV2(size_t                       head_num,
     cuda_device_prop_(cuda_device_prop),
     debug_(isDebug()),
     step_length_(step_length),
-    batch_(max_batch_size, max_context_token_num, session_len, this),
+    // batch_(max_batch_size, max_context_token_num, session_len, this),
     shared_state_(shared_state)
 
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
     TM_LOG_INFO("NCCL group_id = %d", tensor_para_.group_id_);
 
-    vocab_size_padded_ = (vocab_size_padded_ + tensor_para_.world_size_ - 1) / tensor_para_.world_size_ * tensor_para_.world_size_;
+    vocab_size_padded_ =
+        (vocab_size_padded_ + tensor_para_.world_size_ - 1) / tensor_para_.world_size_ * tensor_para_.world_size_;
 
     size_t elem_bits = 0;
     if (quant_policy & QuantPolicy::kCacheKVInt8) {
@@ -109,23 +111,36 @@ LlamaV2<T>::LlamaV2(size_t                       head_num,
 
     const size_t local_kv_head_num = kv_head_num / tensor_para.world_size_;
 
-    kv_cache_mgr_ = std::make_unique<LlamaCacheManager>(num_layer_,
-                                                        local_kv_head_num,
-                                                        size_per_head_,
-                                                        session_len,
-                                                        elem_bits,
-                                                        cache_max_entry_count,
-                                                        cache_chunk_size,
-                                                        tensor_para.rank_,
-                                                        allocator);
+    // kv_cache_mgr_     = std::make_unique<LlamaCacheManager>(num_layer_,
+    //                                                     local_kv_head_num,
+    //                                                     size_per_head_,
+    //                                                     session_len,
+    //                                                     elem_bits,
+    //                                                     cache_max_entry_count,
+    //                                                     cache_chunk_size,
+    //                                                     tensor_para.rank_,
+    //                                                     allocator);
+    auto sequence_manager = std::make_unique<SequenceManager>(num_layer,
+                                                              local_kv_head_num,
+                                                              size_per_head_,
+                                                              128,  // block_len: tokens per KV cache block
+                                                              cache_max_entry_count,
+                                                              cache_chunk_size,
+                                                              elem_bits,
+                                                              tensor_para_.rank_,
+                                                              allocator);
+    batch_                = std::make_unique<LlamaBatch<T>>(
+        max_batch_size, max_context_token_num, session_len, std::move(sequence_manager), this);
+
     initialize(attn_params, kv_head_num, use_context_fmha, quant_policy);
-    start();
+
+    /// TODO: decouple Llama model and batch inference
+    batch_->Start();
 }
 
 template<typename T>
 LlamaV2<T>::~LlamaV2()
 {
-    internal_thread_.join();
 
     delete decoder_;
     delete dynamic_decode_layer_;
@@ -171,7 +186,7 @@ void LlamaV2<T>::initialize(const LlamaAttentionParams& attn_params,
 
     dynamic_decode_layer_ = new DynamicDecodeLayer<float>(vocab_size_,
                                                           vocab_size_padded_,
-                                                          0,            // end_id, deprecated
+                                                          0,  // end_id, deprecated
                                                           stream_,
                                                           cublas_wrapper_,
                                                           allocator_,
@@ -209,6 +224,7 @@ void LlamaV2<T>::contextDecode(T*         deocder_output,
                                const int* input_length,
                                const int* history_length,
                                const int* context_length,
+                               const int* cu_block_counts,
                                size_t     token_num,
                                size_t     max_input_len,
                                size_t     max_context_len,
@@ -251,7 +267,7 @@ void LlamaV2<T>::contextDecode(T*         deocder_output,
         {"max_q_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_q_len}},
         {"max_kv_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_kv_len}},
         {"max_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_seq_len}},
-    };
+        {"cu_block_counts", {MEMORY_GPU, TYPE_INT32, {batch_size}, cu_block_counts}}};
 
     std::unordered_map<std::string, Tensor> decoder_output_tensors{
         {"decoder_output", {MEMORY_GPU, dtype, {bsz, max_input_len, hidden_units_}, context_decoder_output_buf}},
@@ -267,17 +283,17 @@ void LlamaV2<T>::contextDecode(T*         deocder_output,
 }
 
 template<typename T>
-void LlamaV2<T>::decoderForward(T*         decoder_output,
-                                uintptr_t* k_cache_ptr,
-                                uintptr_t* v_cache_ptr,
-                                T*         decoder_input,
-                                const int* sequence_length,
-                                const int* total_padding_count,
-                                bool*      finished,
-                                int        step,
-                                int        ite,
-                                size_t     session_len,
-                                size_t     batch_size)
+void LlamaV2<T>::decoderForward(T*          decoder_output,
+                                uintptr_t*  k_cache_ptr,
+                                uintptr_t*  v_cache_ptr,
+                                T*          decoder_input,
+                                const int*  sequence_length,
+                                const bool* finished,
+                                const int*  cu_block_counts,
+                                int         step,
+                                int         ite,
+                                size_t      session_len,
+                                size_t      batch_size)
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
 
@@ -289,7 +305,7 @@ void LlamaV2<T>::decoderForward(T*         decoder_output,
     std::unordered_map<std::string, Tensor> decoder_input_tensors{
         {"decoder_input", {MEMORY_GPU, dtype, {batch_size, hidden_units_}, decoder_input}},
         {"sequence_lengths", {MEMORY_GPU, TYPE_INT32, {batch_size}, sequence_length}},
-        {"total_padding_tokens", {MEMORY_GPU, TYPE_INT32, {batch_size}, total_padding_count}},
+        {"cu_block_counts", {MEMORY_GPU, TYPE_INT32, {batch_size}, cu_block_counts}},
         {"max_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_seq_len}},
         {"finished", {MEMORY_GPU, TYPE_BOOL, {batch_size}, finished}},
         {"output_norm_weight", {MEMORY_GPU, dtype, {hidden_units_}, weights_->output_norm_weight}},
@@ -430,73 +446,6 @@ void LlamaV2<T>::dynamicDecode(int*            token_ids,
     dynamic_decode_layer_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors);
 }
 
-template<typename T>
-void LlamaV2<T>::internalThreadEntry(int device_id)
-{
-    TM_LOG_INFO("[internalThreadEntry] %d", (int)tensor_para_.rank_);
-    check_cuda_error(cudaSetDevice(device_id));
-
-    auto& request_queue  = shared_state_->request_queue;
-    auto& infer_requests = shared_state_->infer_requests;
-    auto& stop_requests  = shared_state_->stop_requests;
-
-    while (1) {
-        if (tensor_para_.rank_ == 0) {
-            const int  free_slot_count = batch_.maxSize() - batch_.size() + batch_.finishedCount();
-            const bool is_empty        = free_slot_count == batch_.maxSize();
-
-            request_queue.dequeue(stop_requests, infer_requests, free_slot_count, is_empty);
-
-            batch_.verifyRequests(stop_requests, infer_requests);
-        }
-
-        // wait while rank-0 is dequeueing
-        shared_state_->barrier->wait();
-
-        bool modified = false;
-
-        if (!(batch_.finishedCount() == 0 && stop_requests.empty() && infer_requests.empty())) {
-            batch_.handleStopRequests(stop_requests);
-            batch_.synchronize();
-            modified = true;
-        }
-
-        const int infer_request_count = infer_requests.size();
-
-        if (!infer_requests.empty()) {
-            batch_.initialize(infer_requests);  // reinitialize when new requests come, possible buffer allocation
-            batch_.contextDecode();
-            modified = true;
-        }
-
-        // wait while shared stop/infer_requests is being used
-        shared_state_->barrier->wait();
-
-        if (batch_.size()) {
-            if (modified) {
-                batch_.initializeGeneration();
-                batch_.initializeSampling(infer_request_count);
-            }
-            for (int i = 0; i < step_length_; ++i) {
-                if (!batch_.generate()) {
-                    break;
-                }
-            }
-            batch_.finish();
-        }
-    }
-
-    FT_CHECK(0);
-}
-
-template<typename T>
-void LlamaV2<T>::start()
-{
-    int device_id = -1;
-    check_cuda_error(cudaGetDevice(&device_id));
-    internal_thread_ = std::thread(&LlamaV2<T>::internalThreadEntry, this, device_id);
-}
-
 static inline Tensor slice(const Tensor& tensor, int index)
 {
     auto shape = tensor.shape;
diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h
index ed13aa40f4..31e4bf42d7 100644
--- a/src/turbomind/models/llama/LlamaV2.h
+++ b/src/turbomind/models/llama/LlamaV2.h
@@ -28,6 +28,7 @@
 #include "src/turbomind/models/llama/LlamaDecoder.h"
 #include "src/turbomind/models/llama/LlamaWeight.h"
 #include "src/turbomind/models/llama/Request.h"
+#include "src/turbomind/models/llama/SequenceManager.h"
 #include "src/turbomind/utils/allocator.h"
 #include "src/turbomind/utils/cublasMMWrapper.h"
 #include "src/turbomind/utils/instance_comm.h"
@@ -44,6 +45,7 @@ class LlamaV2 {
         std::vector<std::shared_ptr<Request>> stop_requests;
         RequestQueue                          request_queue;
         std::shared_ptr<Barrier>              barrier;
+        bool                                  abort{false};
     };
 
     ~LlamaV2();
@@ -94,8 +96,6 @@ class LlamaV2 {
 private:
     friend class Batch;
 
-    void internalThreadEntry(int device_id);
-
     void
     initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, bool use_context_fmha, int quant_policy);
 
@@ -110,23 +110,24 @@ class LlamaV2 {
                        const int* input_length,
                        const int* history_length,
                        const int* context_length,
+                       const int* cu_block_counts,
                        size_t     token_num,
                        size_t     max_input_len,
                        size_t     max_context_len,
                        size_t     session_len,
                        size_t     batch_size);
 
-    void decoderForward(T*         decoder_output,
-                        uintptr_t* k_cache_ptr,
-                        uintptr_t* v_cache_ptr,
-                        T*         decoder_input,
-                        const int* sequence_length,
-                        const int* total_padding_count,
-                        bool*      finished,
-                        int        step,
-                        int        ite,
-                        size_t     session_len,
-                        size_t     batch_size);
+    void decoderForward(T*          decoder_output,
+                        uintptr_t*  k_cache_ptr,
+                        uintptr_t*  v_cache_ptr,
+                        T*          decoder_input,
+                        const int*  sequence_length,
+                        const bool* finished,
+                        const int*  cu_block_counts,
+                        int         step,
+                        int         ite,
+                        size_t      session_len,
+                        size_t      batch_size);
 
     void postDecodeEmbedding(float* logits, float* local_logits, const T* decoder_output, int batch_size);
 
@@ -146,7 +147,15 @@ class LlamaV2 {
                        size_t          token_ids_len,
                        size_t          batch_size);
 
-    void start();
+    curandState_t* GetTopKState(int index)
+    {
+        return dynamic_decode_layer_->topk_curandstate_buf() + index;
+    }
+
+    curandState_t* GetTopPState(int index)
+    {
+        return dynamic_decode_layer_->topp_curandstate_buf() + index;
+    }
 
 private:
     friend class LlamaBatch<T>;
@@ -176,18 +185,14 @@ class LlamaV2 {
 
     const bool debug_{false};
 
-    std::unique_ptr<LlamaCacheManager> kv_cache_mgr_;
-
     LlamaWeight<T>*            weights_{};
     LlamaDecoder<T>*           decoder_{};
     LlamaContextDecoder<T>*    context_decoder_{};
     DynamicDecodeLayer<float>* dynamic_decode_layer_{};
 
-    const int                    step_length_;
-    LlamaBatch<T>                batch_;
-    std::shared_ptr<SharedState> shared_state_;
-
-    std::thread internal_thread_;
+    const int                      step_length_;
+    std::unique_ptr<LlamaBatch<T>> batch_;
+    std::shared_ptr<SharedState>   shared_state_;
 };
 
 }  // namespace turbomind
diff --git a/src/turbomind/models/llama/Request.h b/src/turbomind/models/llama/Request.h
index cb2d1858a3..46badf98a5 100644
--- a/src/turbomind/models/llama/Request.h
+++ b/src/turbomind/models/llama/Request.h
@@ -14,9 +14,11 @@ namespace turbomind {
 
 struct Request {
     uint64_t id;
-    bool     start_flag;
-    bool     end_flag;
-    bool     stop_flag;
+    uint64_t priority;
+
+    bool start_flag;
+    bool end_flag;
+    bool stop_flag;
 
     // per rank inputs/outputs
     std::vector<TensorMap> inputs;
@@ -25,8 +27,7 @@ struct Request {
     using Callback = std::function<void(std::unordered_map<std::string, Tensor>*)>;
     Callback stream_cb;
 
-    enum
-    {
+    enum {
         kInvalid  = 1,
         kConflict = 2,
         kBusy     = 3,
@@ -61,11 +62,16 @@ class RequestQueue {
     void dequeue(std::vector<std::shared_ptr<Request>>& stop_requests,
                  std::vector<std::shared_ptr<Request>>& infer_requests,
                  unsigned                               max_infer_count,
-                 bool                                   blocking)
+                 bool                                   blocking,
+                 bool&                                  abort)
     {
         std::unique_lock<std::mutex> lock(mutex_);
         if (blocking) {
-            cv_.wait(lock, [this] { return !(stop_queue_.empty() && infer_queue_.empty()); });
+            cv_.wait(lock, [this] { return !(stop_queue_.empty() && infer_queue_.empty()) || abort_; });
+            if (abort_) {
+                abort = true;
+                return;
+            }
         }
 
         stop_requests.clear();
@@ -81,11 +87,18 @@ class RequestQueue {
         }
     }
 
+    void Abort()
+    {
+        {
+            std::lock_guard<std::mutex> lock(mutex_);
+            abort_ = true;
+        }
+        // wake up a dequeue() that may be blocked on the condition variable
+        cv_.notify_all();
+    }
+
 private:
     std::queue<std::shared_ptr<Request>> stop_queue_;
     std::queue<std::shared_ptr<Request>> infer_queue_;
     std::mutex                           mutex_;
     std::condition_variable              cv_;
+    bool                                 abort_{false};
 };
 
 }  // namespace turbomind
diff --git a/src/turbomind/models/llama/SequenceManager.cc b/src/turbomind/models/llama/SequenceManager.cc
new file mode 100644
index 0000000000..7c03a5e0e5
--- /dev/null
+++ b/src/turbomind/models/llama/SequenceManager.cc
@@ -0,0 +1,368 @@
+#include "src/turbomind/models/llama/SequenceManager.h"
+#include "src/turbomind/utils/logger.h"
+#include <ctime>
+
+namespace turbomind {
+
+SequenceManager::SequenceManager(size_t      layer_num,
+                                 size_t      head_num,
+                                 size_t      head_dim,
+                                 size_t      block_len,
+                                 double      block_count,
+                                 int         chunk_size,
+                                 size_t      elem_bits,
+                                 int         rank,
+                                 IAllocator* allocator):
+    block_len_(block_len), rank_(rank)
+{
+    constexpr int kBitsPerByte = 8;
+
+    size_t block_size = layer_num * head_num * block_len * head_dim * elem_bits / kBitsPerByte * 2;
+
+    block_manager_ = std::make_unique<BlockManager>(block_size, block_count, chunk_size, allocator);
+
+    val_offset_ = block_size / 2;
+}
+
+const Sequence* SequenceManager::Create(uint64_t id)
+{
+    Sequence sequence{id, {}, {}, {}, {}, {}};
+
+    auto it = sequences_.find(id);
+    if (it != sequences_.end()) {
+        if (rank_ == 0) {
+            TM_LOG_WARNING("[SequenceManager][Create] Removing conflicting ID %ld", (long)id);
+        }
+        block_manager_->Release(it->second.blocks);
+        it->second = std::move(sequence);
+    }
+    else {
+        it = sequences_.emplace_hint(it, id, std::move(sequence));
+    }
+
+    return &it->second;
+}
+
+void SequenceManager::VerifyBlocks(Sequence& seq)
+{
+    FT_CHECK(seq.blocks.size() == seq.block_unique_ids.size());
+    for (int i = 0; i < seq.blocks.size(); ++i) {
+        if (seq.blocks[i]->unique_id != seq.block_unique_ids[i]) {
+            seq.blocks.resize(i);
+            seq.block_unique_ids.resize(i);
+            break;
+        }
+    }
+    seq.cache_len = std::min<int>(seq.cache_len, seq.blocks.size() * block_len_);
+}
+
+const Sequence* SequenceManager::Fetch(uint64_t id)
+{
+    if (auto it = sequences_.find(id); it != sequences_.end()) {
+        return &it->second;
+    }
+
+    return nullptr;
+}
+
+bool SequenceManager::Erase(uint64_t id)
+{
+    if (auto it = sequences_.find(id); it != sequences_.end()) {
+        auto& seq = it->second;
+        if (seq.status != Sequence::kCached) {
+            if (released_.empty()) {
+                released_ = std::move(seq.blocks);
+            }
+            else {
+                released_.insert(released_.end(), seq.blocks.begin(), seq.blocks.end());
+            }
+        }
+        sequences_.erase(it);
+        return true;
+    }
+
+    return false;
+}
+
+void SequenceManager::Update(const Sequence& sequence)
+{
+    block_manager_->Touch(sequence.blocks);
+}
+
+bool SequenceManager::Contains(uint64_t id)
+{
+    return sequences_.find(id) != sequences_.end();
+}
+
+namespace {
+
+struct Schedule {
+    int free;
+    int cached;
+
+    int allocate;
+    int evict;
+
+    std::vector<int> victims;
+
+    std::vector<int> active;
+    std::vector<int> block_counts;
+
+    std::vector<int> inactive;
+};
+
+class Simulator {
+public:
+    explicit Simulator(const std::vector<const Sequence*>& seqs,
+                       const std::vector<int>&             idxs,
+                       std::vector<int>&                   ref_count):
+        seqs_(seqs), idxs_(idxs), ref_count_(ref_count)
+    {
+        released_.resize(seqs.size());
+        ptr_ = released_.size();
+    }
+
+    int Release(int order)
+    {
+        while (order < ptr_) {
+            --ptr_;
+            int count = 0;
+            for (const auto& p : seqs_[idxs_[ptr_]]->blocks) {
+                if (--ref_count_[p->id] == 0) {
+                    ++count;
+                }
+            }
+            released_[ptr_] = count;
+        }
+
+        return released_[order];
+    }
+
+private:
+    const std::vector<const Sequence*>& seqs_;
+    const std::vector<int>&             idxs_;
+
+    std::vector<int>& ref_count_;
+
+    std::vector<int> released_;
+    int              ptr_;
+};
+
+struct Transaction {
+    int index_;
+    int block_count_;
+
+    int allocate_{};
+    int evict_{};
+    int preempt_{};
+
+    std::vector<int> victims_;
+
+    Schedule&  sched_;
+    Simulator& simulator_;
+
+    explicit Transaction(Schedule& sched, int index, int block_count, Simulator& simulator):
+        sched_(sched), index_(index), block_count_(block_count), simulator_(simulator)
+    {
+    }
+
+    int Allocate(int count)
+    {
+        allocate_ += count;
+        return count;
+    }
+
+    int Evict(int count)
+    {
+        evict_ += count;
+        return count;
+    }
+
+    int Preempt(int order, int idx)
+    {
+        victims_.push_back(idx);
+        preempt_ += simulator_.Release(order);
+        return preempt_;
+    }
+
+    void Commit()
+    {
+        sched_.free -= allocate_;
+        sched_.cached += preempt_ - evict_;
+
+        sched_.allocate += allocate_;
+        sched_.evict += evict_;
+
+        sched_.victims.insert(sched_.victims.end(), victims_.begin(), victims_.end());
+
+        sched_.active.push_back(index_);
+        sched_.block_counts.push_back(block_count_);
+    }
+};
+
+}  // namespace
+
+std::ostream& operator<<(std::ostream& os, const Sequence& seq)
+{
+    os << "Sequence[id=" << seq.id << ",status=" << seq.status << ",size(blocks)=" << seq.blocks.size()
+       << ",cache_len=" << seq.cache_len << ",size(random_state)=" << seq.random_state.size() << "]";
+    return os;
+}
+
+bool SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
+                                  const std::vector<int>&             context_lengths,
+                                  const std::vector<uint64_t>&        priorities,
+                                  int                                 step_length)
+{
+    ////////////////////////////////////////////////////////////////////////////////
+    /// Schedule the assignment of blocks to sequences
+    auto seqs = const_cast<Sequence* const*>(sequences.data());
+
+    // check validity of cached blocks (blocks of active & locked seqs are always valid)
+    if (need_verification_) {
+        for (int i = 0; i < sequences.size(); ++i) {
+            if (seqs[i]->status == Sequence::kCached) {
+                VerifyBlocks(*seqs[i]);
+            }
+        }
+        need_verification_ = false;
+    }
+
+    // count required blocks based on block validity
+    std::vector<int> required(sequences.size());
+    int              total_required{};
+    for (int i = 0; i < sequences.size(); ++i) {
+        int seq_len = context_lengths[i] + step_length;
+        int count   = (seq_len + block_len_ - 1) / block_len_ - static_cast<int>(seqs[i]->blocks.size());
+        required[i] = std::max(0, count);
+        total_required += required[i];
+    }
+
+    // no new blocks required, exit early
+    if (total_required == 0) {
+        return false;
+    }
+
+    /// TODO: more early exit heuristics
+
+    // sort according to priority
+    std::vector<int> idxs(sequences.size());
+    std::iota(idxs.begin(), idxs.end(), 0);
+    std::sort(idxs.begin(), idxs.end(), [&](int i, int j) { return priorities[i] < priorities[j]; });
+
+    Snapshot snapshot = block_manager_->TakeSnapshot();
+
+    Schedule schedule{snapshot.free, snapshot.cached};
+    schedule.cached += released_.size();
+
+    Simulator simulator(sequences, idxs, snapshot.ref_count);
+
+    bool modified = false;
+
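+    // greedy pass in priority order: satisfy each sequence from free blocks first, then by evicting
+    // cached blocks, and finally by preempting lower-priority sequences from the tail of the order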
+    for (int i = 0, j = idxs.size(); i < j; ++i) {
+        const int idx = idxs[i];
+
+        const auto& seq         = *sequences[idx];
+        auto        block_count = required[idx];
+
+        Transaction trans{schedule, idx, block_count, simulator};
+
+        // allocate from free blocks
+        if (block_count) {
+            block_count -= trans.Allocate(std::min(block_count, schedule.free));
+        }
+        // evict cached blocks
+        if (block_count) {
+            block_count -= trans.Evict(std::min(block_count, schedule.cached));
+        }
+
+        for (int v = j - 1; block_count && v > i; --v) {
+            if (sequences[idxs[v]]->status == Sequence::kCached) {
+                continue;
+            }
+            int preempt = trans.Preempt(v, idxs[v]);
+            // Commit only when preemption actually frees enough blocks for the sequence to run
+            if (block_count <= preempt) {
+                // preempted blocks are in cached state
+                block_count -= trans.Evict(block_count);
+                j = v + 1;
+                break;
+            }
+        }
+
+        if (block_count == 0) {
+            trans.Commit();
+            if (seq.status != Sequence::kActive) {
+                modified = true;
+            }
+        }
+        else {
+            // failed to collect enough blocks for the sequence, so the transaction is aborted. An active
+            // sequence stays locked unless it is preempted by a sequence with higher priority
+            schedule.inactive.push_back(idx);
+            if (seq.status == Sequence::kActive) {
+                modified = true;
+            }
+        }
+    }
+
+    // Verify the schedule
+    FT_CHECK(schedule.allocate <= snapshot.free);
+    FT_CHECK(schedule.evict <= snapshot.cached);
+    // FT_CHECK(schedule.allocate + schedule.evict + schedule.preempt == total_block_count);
+
+    ////////////////////////////////////////////////////////////////////////////////
+    /// Schedule is ready, time to execute it. (locked -> cached -> free -> locked)
+    schedule.allocate += schedule.evict;
+
+    // release preempted blocks -> cached
+    {
+        std::vector<const Block*> blocks;
+        for (const auto& v : schedule.victims) {
+            auto& seq = *seqs[v];
+            block_manager_->Touch(seq.blocks);
+            seq.status = Sequence::kCached;
+            blocks.insert(blocks.end(), seq.blocks.begin(), seq.blocks.end());
+        }
+        block_manager_->Release(blocks);
+    }
+
+    // evict cached blocks -> free
+    if (schedule.evict) {
+        need_verification_ = true;
+        block_manager_->Evict(schedule.evict);
+    }
+
+    // allocate & assign blocks
+    auto blocks = block_manager_->Allocate(schedule.allocate);  // evicted blocks already folded into `allocate` above
+    auto first  = blocks.begin();
+
+    for (const auto& idx : schedule.active) {
+        auto& sequence  = *seqs[idx];
+        sequence.status = Sequence::kActive;
+
+        auto last = first + required[idx];
+        std::for_each(first, last, [&sequence](const Block* b) {
+            sequence.blocks.push_back(b);
+            sequence.block_unique_ids.push_back(b->unique_id);
+        });
+
+        first = last;
+    }
+
+    block_manager_->Touch(blocks);
+
+    for (const auto& idx : schedule.inactive) {
+        if (seqs[idx]->status == Sequence::kActive) {
+            seqs[idx]->status = Sequence::kLocked;
+        }
+    }
+
+    for (const auto& idx : schedule.victims) {
+        seqs[idx]->status = Sequence::kCached;
+    }
+
+    return modified;
+}
+
+}  // namespace turbomind
\ No newline at end of file
diff --git a/src/turbomind/models/llama/SequenceManager.h b/src/turbomind/models/llama/SequenceManager.h
new file mode 100644
index 0000000000..0ce057b95e
--- /dev/null
+++ b/src/turbomind/models/llama/SequenceManager.h
@@ -0,0 +1,96 @@
+#pragma once
+
+#include "src/turbomind/models/llama/BlockManager.h"
+
+namespace turbomind {
+
+// |<-- active -->|<-- pending -->|<-- new -->|
+
+struct Sequence {
+
+    enum Status {
+        kCached = 0,
+        kLocked,
+        kActive
+    };
+
+    uint64_t id;
+    Status   status;
+
+    std::vector<const Block*> blocks;
+    std::vector<uint64_t>     block_unique_ids;
+
+    mutable std::vector<int> tokens;  // updated by the user
+
+    mutable int cache_len;
+
+    // additional data kept round-to-round
+    mutable std::vector<std::byte> random_state;  // updated by the user
+
+    friend std::ostream& operator<<(std::ostream& os, const Sequence& seq);
+};
+
+class SequenceManager {
+public:
+    // allocate slack blocks to reduce block manager overhead
+    static constexpr int kSlackBlockNum = 1;
+
+    explicit SequenceManager(size_t      layer_num,
+                             size_t      head_num,
+                             size_t      head_dim,
+                             size_t      block_len,
+                             double      block_count,
+                             int         chunk_size,
+                             size_t      elem_bits,
+                             int         rank,
+                             IAllocator* allocator);
+
+    const Sequence* Create(uint64_t id);
+
+    const Sequence* Fetch(uint64_t id);
+
+    void Update(const Sequence& seq);
+
+    bool Erase(uint64_t id);
+
+    bool Contains(uint64_t id);
+
+    bool Materialize(const std::vector<const Sequence*>& sequences,
+                     const std::vector<int>&             context_lengths,
+                     const std::vector<uint64_t>&        priorities,
+                     int                                 step_length);
+
+    void* OffsetKey(void* block_ptr)
+    {
+        return block_ptr;
+    }
+
+    void* OffsetVal(void* block_ptr)
+    {
+        return (std::byte*)block_ptr + val_offset_;
+    }
+
+    int max_block_count() const noexcept
+    {
+        return block_manager_->max_block_count();
+    }
+
+private:
+    void VerifyBlocks(Sequence& seq);
+
+private:
+    int    block_len_;
+    int    rank_;
+    size_t val_offset_{};
+
+    bool need_verification_{};
+
+    // Use `std::map` to avoid reference invalidation
+    std::map<uint64_t, Sequence> sequences_;
+
+    std::unique_ptr<BlockManager> block_manager_;
+
+    std::vector<const Block*> released_;
+};
+
+}  // namespace turbomind
\ No newline at end of file

From a7e31c59a7fc0c4d70e4947a2ae2ea6c931e1137 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Mon, 25 Sep 2023 02:02:35 +0000
Subject: [PATCH 04/56] update

---
 src/turbomind/kernels/decoder_mha/kv_cache.cu | 342 +++++++++++-----
 src/turbomind/kernels/decoder_mha/kv_cache.h  |  48 ++-
 .../test_decoder_multihead_attention.cu       | 108 ++++--
 .../llama/LlamaContextAttentionLayer.cc       |  18 +-
 .../models/llama/LlamaContextAttentionLayer.h |   2 +
 src/turbomind/models/llama/SequenceManager.cc |   4 +-
 src/turbomind/models/llama/SequenceManager.h  |   8 +
 src/turbomind/models/llama/llama_kernels.cu   | 365 +++++-------------
 src/turbomind/models/llama/llama_kernels.h    |  21 +-
 9 files changed, 502 insertions(+), 414 deletions(-)

diff --git a/src/turbomind/kernels/decoder_mha/kv_cache.cu b/src/turbomind/kernels/decoder_mha/kv_cache.cu
index 4e73d04b26..0fdad8e063 100644
--- a/src/turbomind/kernels/decoder_mha/kv_cache.cu
+++ b/src/turbomind/kernels/decoder_mha/kv_cache.cu
@@ -1,148 +1,312 @@
 #include "../gemm_s_f16/common.h"
 // #include "cute/tensor.hpp"
 #include <cuda_fp16.h>
+#include <type_traits>
 
 namespace turbomind {
 
 // [S/x, H, x, D] <-> [S/y, H, y, D]
 
-template<typename T>
-__device__ void ConvertBlockSize(const T** src_block_ptrs,
-                                 T**       dst_block_ptrs,
-                                 int       src_block_size,
-                                 int       dst_block_size,
-                                 int       heads,
-                                 int       dims,
-                                 int       seq_len)
+// [S, H, 1, D] <-> [1, H, S, D]
+
+namespace {
+
+struct ThreadMap {};
+
+}  // namespace
+
+template<typename T, typename SrcBlockLen, typename DstBlockLen, typename HeadDim>
+__inline__ __device__ void ConvertBlockSize(const T** __restrict__ src_block_ptrs,
+                                            T** __restrict__ dst_block_ptrs,
+                                            const int* __restrict__ src_cu_block_cnts,
+                                            const int* __restrict__ dst_cu_block_cnts,
+                                            const int* __restrict__ seq_lens,
+                                            SrcBlockLen src_block_len,
+                                            DstBlockLen dst_block_len,
+                                            HeadDim     head_dim)
 {
     constexpr int kVecSize = sizeof(uint4) / sizeof(T);
 
-    size_t count = (size_t)heads * seq_len * dims;
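+    // grid mapping: blockIdx.x covers the (seq, head_dim) elements in uint4-sized vectors,
+    // blockIdx.y is the head index and blockIdx.z is the batch index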
+    const int hi = blockIdx.y;
+    const int bi = blockIdx.z;
 
-    for (size_t i = (threadIdx.x + blockIdx.x * blockDim.x) * kVecSize; i < count;
-         i += blockDim.x * gridDim.x * kVecSize) {
-        // get coords from [H, S, D]
-        int di = i % dims;
-        int ii = i / dims;
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int di  = idx * kVecSize % head_dim;
+    const int si  = idx * kVecSize / head_dim;
 
-        int si = ii % seq_len;
-        int hi = ii / seq_len;
+    if (si >= seq_lens[bi]) {
+        return;
+    }
 
-        // compute indices into src
-        int src_block_index  = si / src_block_size;
-        int src_block_offset = hi * src_block_size * dims + si % src_block_size * dims + di;
+    // compute indices into src
+    int src_block_index  = si / src_block_len + src_cu_block_cnts[bi];
+    int src_block_offset = hi * src_block_len * head_dim + si % src_block_len * head_dim + di;
 
-        // compute indices into dst
-        int dst_block_index  = si / dst_block_size;
-        int dst_block_offset = hi * dst_block_size * dims + si % dst_block_size * dims + di;
+    // compute indices into dst
+    int dst_block_index  = si / dst_block_len + dst_cu_block_cnts[bi];
+    int dst_block_offset = hi * dst_block_len * head_dim + si % dst_block_len * head_dim + di;
 
-        const T* src_block = src_block_ptrs[src_block_index];
-        T*       dst_block = dst_block_ptrs[dst_block_index];
+    // printf("%d %d\n", src_block_index, dst_block_index);
 
-        uint4 data = __ldg(reinterpret_cast<const uint4*>(src_block + src_block_offset));
+    const T* __restrict__ src_block = src_block_ptrs[src_block_index];
+    T* __restrict__ dst_block       = dst_block_ptrs[dst_block_index];
 
-        *reinterpret_cast<uint4*>(dst_block + dst_block_offset) = data;
-    }
+    uint4 data = __ldg(reinterpret_cast<const uint4*>(src_block + src_block_offset));
+
+    *reinterpret_cast<uint4*>(dst_block + dst_block_offset) = data;
+}
+
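+// shared-memory scratch for the helper kernels: one block pointer and one cumulative block count
+// per sequence in the batch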
+static inline size_t get_helper_smem_size(int batch_size)
+{
+    return (sizeof(void*) + sizeof(int)) * batch_size;
 }
 
 template<typename T>
-__global__ void
-LinearToBlocksKernel(const T* src, T** dst_block_ptrs, int dst_block_size, int heads, int dims, int seq_len)
+__global__ void LinearToBlocksKernel(const T*   src,
+                                     T**        dst_block_ptrs,
+                                     const int* dst_cu_block_cnts,
+                                     const int* seq_lens,
+                                     int        src_block_len,
+                                     int        dst_block_len,
+                                     int        head_num,
+                                     int        head_dim,
+                                     int        batch_size)
 {
-    __shared__ const T* src_block_ptr[1];
+    extern __shared__ void* smem[];
+
+    const T** src_block_ptrs    = (const T**)smem;
+    int*      src_cu_block_cnts = (int*)(src_block_ptrs + batch_size);
 
-    if (threadIdx.x == 0) {
-        src_block_ptr[0] = src;
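+    // fill a fake block-pointer table for the contiguous source: each batch owns a single "block"
+    // (its whole linear slice); only entry blockIdx.z is ever read by this thread block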
+    for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
+        src_cu_block_cnts[i] = i;
+        src_block_ptrs[i]    = src + blockIdx.z * head_num * src_block_len * head_dim;
     }
 
     __syncthreads();
 
-    ConvertBlockSize(src_block_ptr, dst_block_ptrs, seq_len, dst_block_size, heads, dims, seq_len);
+    ConvertBlockSize(src_block_ptrs,
+                     dst_block_ptrs,
+                     src_cu_block_cnts,
+                     dst_cu_block_cnts,
+                     seq_lens,
+                     src_block_len,
+                     dst_block_len,
+                     head_dim);
 }
 
-template<typename T>
-__global__ void
-BlocksToLinearKernel(const T** src_block_ptrs, T* dst, int src_block_size, int heads, int dims, int seq_len)
+template<typename T, typename HeadDim>
+__global__ void BlocksToLinearKernel(const T**  src_block_ptrs,
+                                     T*         dst,
+                                     const int* src_cu_block_cnts,
+                                     const int* seq_lens,
+                                     int        src_block_len,
+                                     int        dst_block_len,
+                                     int        head_num,
+                                     HeadDim    head_dim,
+                                     int        batch_size)
 {
-    __shared__ T* dst_block_ptr[1];
+    extern __shared__ void* smem[];
+
+    T**  dst_block_ptrs    = (T**)smem;
+    int* dst_cu_block_cnts = (int*)(dst_block_ptrs + batch_size);
 
-    if (threadIdx.x == 0) {
-        dst_block_ptr[0] = dst;
+    for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
+        dst_cu_block_cnts[i] = i;
+        dst_block_ptrs[i]    = dst + blockIdx.z * head_num * dst_block_len * head_dim;
     }
 
     __syncthreads();
 
-    ConvertBlockSize(src_block_ptrs, dst_block_ptr, src_block_size, seq_len, heads, dims, seq_len);
+    ConvertBlockSize(src_block_ptrs,
+                     dst_block_ptrs,
+                     src_cu_block_cnts,
+                     dst_cu_block_cnts,
+                     seq_lens,
+                     src_block_len,
+                     dst_block_len,
+                     head_dim);
 }
 
-template<typename T>
-__global__ void BlocksToBlocksKernel(const T** src_block_ptrs,
-                                     T**       dst_block_ptrs,
-                                     int       src_block_size,
-                                     int       dst_block_size,
-                                     int       heads,
-                                     int       dims,
-                                     int       seq_len)
+template<typename T, typename SrcBlockLen, typename DstBlockLen, typename HeadDim>
+__global__ void KvCacheBlocksToLinearKernel(const T**   src_k_block_ptrs,
+                                            const T**   src_v_block_ptrs,
+                                            T*          dst_k,
+                                            T*          dst_v,
+                                            const int*  src_cu_block_cnts,
+                                            const int*  seq_lens,
+                                            SrcBlockLen src_block_len,
+                                            DstBlockLen dst_block_len,
+                                            int         head_num,
+                                            HeadDim     head_dim,
+                                            int         batch_size)
 {
-    ConvertBlockSize(src_block_ptrs, dst_block_ptrs, src_block_size, dst_block_size, heads, dims, seq_len);
-}
+    extern __shared__ void* smem[];
 
-template<typename T>
-void ConvertLinearToBlocks(
-    const T* src, T** dst_block_ptrs, int dst_block_size, int heads, int dims, int seq_len, cudaStream_t st)
-{
-    constexpr int kVecSize = sizeof(uint4) / sizeof(T);
+    T**  dst_k_block_ptrs  = (T**)smem;
+    T**  dst_v_block_ptrs  = dst_k_block_ptrs + batch_size;
+    int* dst_cu_block_cnts = (int*)(dst_v_block_ptrs + batch_size);
 
-    int threads = 512;
-    int blocks  = std::min(512, (heads * seq_len * dims / kVecSize + threads - 1) / threads);
+    for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
+        dst_cu_block_cnts[i] = i;
+        dst_k_block_ptrs[i]  = dst_k + blockIdx.z * head_num * dst_block_len * head_dim;
+        dst_v_block_ptrs[i]  = dst_v + blockIdx.z * head_num * dst_block_len * head_dim;
+    }
 
-    LinearToBlocksKernel<<<blocks, threads, 0, st>>>(src, dst_block_ptrs, dst_block_size, heads, dims, seq_len);
+    __syncthreads();
+
+    ConvertBlockSize(src_k_block_ptrs,
+                     dst_k_block_ptrs,
+                     src_cu_block_cnts,
+                     dst_cu_block_cnts,
+                     seq_lens,
+                     src_block_len,
+                     dst_block_len,
+                     head_dim);
+
+    ConvertBlockSize(src_v_block_ptrs,
+                     dst_v_block_ptrs,
+                     src_cu_block_cnts,
+                     dst_cu_block_cnts,
+                     seq_lens,
+                     src_block_len,
+                     dst_block_len,
+                     head_dim);
 }
 
 template<typename T>
-void ConvertBlocksToLinear(
-    const T** src_block_ptrs, T* dst, int src_block_size, int heads, int dims, int seq_len, cudaStream_t st)
+void ConvertLinearToBlocks(const T*     src,
+                           T**          dst_block_ptrs,
+                           const int*   dst_cu_block_cnts,
+                           const int*   seq_lens,
+                           int          src_max_len,
+                           int          dst_block_len,
+                           int          head_num,
+                           int          head_dim,
+                           int          batch_size,
+                           cudaStream_t st)
 {
     constexpr int kVecSize = sizeof(uint4) / sizeof(T);
 
-    int threads = 256;
-    int blocks  = (heads * seq_len * dims / kVecSize + threads - 1) / threads;
-
-    BlocksToLinearKernel<<<blocks, threads, 0, st>>>(src_block_ptrs, dst, src_block_size, heads, dims, seq_len);
+    constexpr int threads = 128;
+    const dim3    blocks((src_max_len * head_dim / kVecSize + threads - 1) / threads, head_num, batch_size);
+
+    const auto smem_sz = get_helper_smem_size(batch_size);
+
+    auto fn = [&](auto head_dim) {
+        LinearToBlocksKernel<<<blocks, threads, smem_sz, st>>>(src,
+                                                               dst_block_ptrs,
+                                                               dst_cu_block_cnts,
+                                                               seq_lens,
+                                                               src_max_len,
+                                                               dst_block_len,
+                                                               head_num,
+                                                               head_dim,
+                                                               batch_size);
+    };
+
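+    // specialize on the common head_dim of 128 so the per-thread index math can be folded at
+    // compile time; other sizes take the generic runtime path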
+    switch (head_dim) {
+        case 128:
+            fn(std::integral_constant<int, 128>{});
+            break;
+        default:
+            fn(head_dim);
+    }
 }
 
+template void ConvertLinearToBlocks(const half*  src,
+                                    half**       dst_block_ptrs,
+                                    const int*   dst_cu_block_cnts,
+                                    const int*   seq_lens,
+                                    int          src_seq_len,
+                                    int          dst_block_len,
+                                    int          head_num,
+                                    int          head_dim,
+                                    int          batch_size,
+                                    cudaStream_t st);
+
 template<typename T>
-void ConvertBlocksToBlocks(const T**    src_block_ptrs,
-                           T**          dst_block_ptrs,
-                           int          src_block_size,
-                           int          dst_block_size,
-                           int          heads,
-                           int          dims,
-                           int          seq_len,
+void ConvertBlocksToLinear(const T**    src_block_ptrs,
+                           T*           dst,
+                           const int*   src_cu_block_cnts,
+                           const int*   seq_lens,
+                           int          src_block_len,
+                           int          dst_max_len,
+                           int          head_num,
+                           int          head_dim,
+                           int          batch_size,
                            cudaStream_t st)
 {
     constexpr int kVecSize = sizeof(uint4) / sizeof(T);
 
-    int threads = 512;
-    int blocks  = std::min(512, (heads * seq_len * dims / kVecSize + threads - 1) / threads);
-
-    BlocksToBlocksKernel<<<blocks, threads, 0, st>>>(
-        src_block_ptrs, dst_block_ptrs, src_block_size, dst_block_size, heads, dims, seq_len);
+    constexpr int threads = 256;
+    const dim3    blocks((dst_max_len * head_dim / kVecSize + threads - 1) / threads, head_num, batch_size);
+
+    const auto smem_sz = get_helper_smem_size(batch_size);
+
+    auto fn = [&](auto head_dim) {
+        BlocksToLinearKernel<<<blocks, threads, smem_sz, st>>>(src_block_ptrs,
+                                                               dst,
+                                                               src_cu_block_cnts,
+                                                               seq_lens,
+                                                               src_block_len,
+                                                               dst_max_len,
+                                                               head_num,
+                                                               head_dim,
+                                                               batch_size);
+    };
+
+    switch (head_dim) {
+        case 128:
+            fn(std::integral_constant<int, 128>{});
+            break;
+        default:
+            fn(head_dim);
+    }
 }
 
-template void ConvertLinearToBlocks(
-    const half* src, half** dst_block_ptrs, int dst_block_size, int heads, int dims, int seq_len, cudaStream_t st);
-
-template void ConvertBlocksToLinear(
-    const half** src_block_ptrs, half* dst, int src_block_size, int heads, int dims, int seq_len, cudaStream_t st);
-
-template void ConvertBlocksToBlocks(const half** src_block_ptrs,
-                                    half**       dst_block_ptrs,
-                                    int          src_block_size,
-                                    int          dst_block_size,
-                                    int          heads,
-                                    int          dims,
-                                    int          seq_len,
+template void ConvertBlocksToLinear(const half** src_block_ptrs,
+                                    half*        dst,
+                                    const int*   src_cu_block_cnts,
+                                    const int*   seq_lens,
+                                    int          src_block_len,
+                                    int          dst_max_seq_len,
+                                    int          head_num,
+                                    int          head_dim,
+                                    int          batch_size,
                                     cudaStream_t st);
 
+// template<typename T>
+// void ConvertBlocksToBlocks(const T**    src_block_ptrs,
+//                            T**          dst_block_ptrs,
+//                            int          src_block_len,
+//                            int          dst_block_len,
+//                            int          heads,
+//                            int          dims,
+//                            int          seq_len,
+//                            cudaStream_t st)
+// {
+//     constexpr int kVecSize = sizeof(uint4) / sizeof(T);
+
+//     int threads = 512;
+//     int blocks  = std::min(512, (heads * seq_len * dims / kVecSize + threads - 1) / threads);
+
+//     BlocksToBlocksKernel<<<blocks, threads, 0, st>>>(
+//         src_block_ptrs, dst_block_ptrs, src_block_len, dst_block_len, heads, dims, seq_len);
+// }
+
+// template void ConvertLinearToBlocks(
+//     const half* src, half** dst_block_ptrs, int dst_block_len, int heads, int dims, int seq_len, cudaStream_t st);
+
+// template void ConvertBlocksToLinear(
+//     const half** src_block_ptrs, half* dst, int src_block_len, int heads, int dims, int seq_len, cudaStream_t st);
+
+// template void ConvertBlocksToBlocks(const half** src_block_ptrs,
+//                                     half**       dst_block_ptrs,
+//                                     int          src_block_len,
+//                                     int          dst_block_len,
+//                                     int          heads,
+//                                     int          dims,
+//                                     int          seq_len,
+//                                     cudaStream_t st);
+
 }  // namespace turbomind
diff --git a/src/turbomind/kernels/decoder_mha/kv_cache.h b/src/turbomind/kernels/decoder_mha/kv_cache.h
index 72758e4b08..6b29c53ae7 100644
--- a/src/turbomind/kernels/decoder_mha/kv_cache.h
+++ b/src/turbomind/kernels/decoder_mha/kv_cache.h
@@ -4,22 +4,46 @@
 
 namespace turbomind {
 
-template<typename T>
-void ConvertLinearToBlocks(
-    const T* src, T** dst_block_ptrs, int dst_block_size, int heads, int dims, int seq_len, cudaStream_t st);
+// template<typename T>
+// void ConvertLinearToBlocks(
+//     const T* src, T** dst_block_ptrs, int dst_block_size, int heads, int dims, int seq_len, cudaStream_t st);
 
-template<typename T>
-void ConvertBlocksToLinear(
-    const T** src_block_ptrs, T* dst, int src_block_size, int heads, int dims, int seq_len, cudaStream_t st);
+// template<typename T>
+// void ConvertBlocksToLinear(
+//     const T** src_block_ptrs, T* dst, int src_block_size, int heads, int dims, int seq_len, cudaStream_t st);
+
+// template<typename T>
+// void ConvertBlocksToBlocks(const T**    src_block_ptrs,
+//                            T**          dst_block_ptrs,
+//                            int          src_block_size,
+//                            int          dst_block_size,
+//                            int          heads,
+//                            int          dims,
+//                            int          seq_len,
+//                            cudaStream_t st);
 
 template<typename T>
-void ConvertBlocksToBlocks(const T**    src_block_ptrs,
+void ConvertLinearToBlocks(const T*     src,
                            T**          dst_block_ptrs,
-                           int          src_block_size,
-                           int          dst_block_size,
-                           int          heads,
-                           int          dims,
-                           int          seq_len,
+                           const int*   dst_cu_block_cnts,
+                           const int*   seq_lens,
+                           int          src_seq_len,
+                           int          dst_block_len,
+                           int          head_num,
+                           int          head_dim,
+                           int          batch_size,
+                           cudaStream_t st);
+
+template<typename T>
+void ConvertBlocksToLinear(const T**    src_block_ptrs,
+                           T*           dst,
+                           const int*   src_cu_block_cnts,
+                           const int*   seq_lens,
+                           int          src_block_len,
+                           int          dst_max_seq_len,
+                           int          head_num,
+                           int          head_dim,
+                           int          batch_size,
                            cudaStream_t st);
 
 }  // namespace turbomind
\ No newline at end of file
diff --git a/src/turbomind/kernels/decoder_mha/test_decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_mha/test_decoder_multihead_attention.cu
index 8a5bc46d5f..7d71718bf7 100644
--- a/src/turbomind/kernels/decoder_mha/test_decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_mha/test_decoder_multihead_attention.cu
@@ -4,9 +4,11 @@
 #include "kv_cache.h"
 #include "test_utils.h"
 #include <cmath>
+#include <ios>
 #include <iostream>
 #include <thrust/universal_vector.h>
 
+#include <iomanip>
 #include <numeric>
 
 using namespace turbomind;
@@ -29,38 +31,65 @@ void TestBlocks(thrust::universal_vector<half>&  linear,
                 thrust::universal_vector<half*>& _ptrs,
                 int                              head_num,
                 int                              head_dim,
-                int                              block_size)
+                int                              block_size,
+                int                              batch_size)
 {
-    int seq_len  = linear.size() / head_num / head_dim;
+    int seq_len  = linear.size() / (head_dim * head_num * batch_size);
     int n_blocks = (seq_len + block_size - 1) / block_size;
 
-    std::cout << "seq_len = " << seq_len << ", block_num = " << n_blocks << ", block_size = " << block_size << "\n";
+    std::cout << "batch_size = " << batch_size << ", seq_len = " << seq_len << ", block_num = " << n_blocks
+              << ", block_size = " << block_size << "\n";
 
-    thrust::universal_vector<half>  blocks(n_blocks * head_num * block_size * head_dim);
-    thrust::universal_vector<half*> ptrs(n_blocks);
+    thrust::universal_vector<half>  blocks(batch_size * n_blocks * head_num * block_size * head_dim);
+    thrust::universal_vector<half*> ptrs(batch_size * n_blocks);
 
-    std::vector<size_t> idxs(n_blocks);
+    std::vector<size_t> idxs(batch_size * n_blocks);
     std::iota(idxs.begin(), idxs.end(), 0);
 
     std::random_shuffle(idxs.begin(), idxs.end());
 
-    for (int i = 0; i < n_blocks; ++i) {
+    for (int i = 0; i < idxs.size(); ++i) {
         ptrs[i] = blocks.data().get() + idxs[i] * head_num * block_size * head_dim;
     }
 
+    thrust::universal_vector<int> seq_lens(batch_size);
+    thrust::fill(seq_lens.begin(), seq_lens.end(), seq_len);
+
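+    // cumulative block counts per sequence: [0, n_blocks, 2 * n_blocks, ...]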
+    std::vector<int>              n_blocks_vec(batch_size + 1, n_blocks);
+    thrust::universal_vector<int> cu_block_cnts(batch_size + 1);
+    std::exclusive_scan(n_blocks_vec.begin(), n_blocks_vec.end(), cu_block_cnts.begin(), 0);
+
     for (int i = 0; i < 10; ++i) {
-        ConvertLinearToBlocks(
-            (const half*)linear.data().get(), ptrs.data().get(), block_size, head_num, head_dim, seq_len, 0);
+        ConvertLinearToBlocks((const half*)linear.data().get(),
+                              ptrs.data().get(),
+                              cu_block_cnts.data().get(),
+                              seq_lens.data().get(),
+                              seq_len,
+                              block_size,
+                              head_num,
+                              head_dim,
+                              batch_size,
+                              0);
     }
     thrust::universal_vector<half> _linear(linear.size());
 
     for (int i = 0; i < 10; ++i) {
-        ConvertBlocksToLinear(
-            (const half**)ptrs.data().get(), _linear.data().get(), block_size, head_num, head_dim, seq_len, 0);
+        ConvertBlocksToLinear((const half**)ptrs.data().get(),
+                              _linear.data().get(),
+                              cu_block_cnts.data().get(),
+                              seq_lens.data().get(),
+                              block_size,
+                              seq_len,
+                              head_num,
+                              head_dim,
+                              batch_size,
+                              0);
     }
     cudaDeviceSynchronize();
-
-    // Compare(_linear.data().get(), linear.data().get(), head_dim, head_dim, head_num * seq_len);
+    std::cout << ">>> Compare\n";
+    Compare(_linear.data().get(), linear.data().get(), head_dim, head_dim, batch_size * head_num * seq_len);
+    std::cout << "<<< Compare\n";
+    std::exit(0);
 
     _blocks.swap(blocks);
     _ptrs.swap(ptrs);
@@ -70,15 +99,20 @@ int main(int argc, char* argv[])
 {
     DecoderMultiHeadAttentionParams<half> params{};
 
-    constexpr int kHeadNum = 108 * 4;
-    // constexpr int kHeadNum     = 32 * 4;
+    // constexpr int kHeadNum = 108 * 4;
+    constexpr int kHeadNum     = 32;
     constexpr int kHeadDim     = 128;
-    constexpr int kBatchSize   = 1;
-    constexpr int kContextLen  = 8192;
+    constexpr int kBatchSize   = 64;
+    constexpr int kContextLen  = 511;
     constexpr int kSequenceLen = kContextLen + 1;
-    constexpr int kBlockSz     = 256;
+    constexpr int kBlockSz     = 128;
     constexpr int kTestIter    = 1;
 
+    // constexpr int kHeadNum     = 3;
+    // constexpr int kHeadDim     = 4;
+    // constexpr int kContextLen  = 7;
+    // constexpr int kSequenceLen = kContextLen + 1;
+
     RNG rng{};
 
     thrust::universal_vector<half>  output(kBatchSize * kHeadNum * kHeadDim);
@@ -94,31 +128,49 @@ int main(int argc, char* argv[])
     rng.GenerateNormal(qkv.data().get(), qkv.size(), 1.f, 0.f);
 
     if (kContextLen) {
-        rng.GenerateNormal(k_cache.data().get(), kHeadNum * kSequenceLen * kHeadDim);
-        rng.GenerateNormal(v_cache.data().get(), kHeadNum * kSequenceLen * kHeadDim);
+        rng.GenerateNormal(k_cache.data().get(), kBatchSize * kHeadNum * kSequenceLen * kHeadDim);
+        rng.GenerateNormal(v_cache.data().get(), kBatchSize * kHeadNum * kSequenceLen * kHeadDim);
 
         cudaMemset2DAsync(k_cache.data().get() + kContextLen * kHeadDim,
                           sizeof(half) * kSequenceLen * kHeadDim,
                           0,
                           sizeof(half) * kHeadDim,
-                          kHeadNum);
+                          kBatchSize * kHeadNum);
+        if constexpr (0) {
+            for (int b = 0; b < kBatchSize; ++b) {
+                for (int h = 0; h < kHeadNum; ++h) {
+                    for (int s = 0; s < kSequenceLen; ++s) {
+                        for (int d = 0; d < kHeadDim; ++d) {
+                            std::cout << std::setw(7) << std::setprecision(4) << std::fixed
+                                      << (float)k_cache[b * kHeadNum * kSequenceLen * kHeadDim
+                                                        + h * kSequenceLen * kHeadDim + s * kHeadDim + d]
+                                      << " ";
+                        }
+                        std::cout << "\n";
+                    }
+                    std::cout << "\n";
+                }
+                std::cout << "\n";
+            }
+            std::exit(0);
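+// Lazily simulates preemption in reverse priority order: Release(order) returns how many blocks
+// would be freed by releasing the sequence at position `order`, given that all lower-priority
+// sequences (larger order) have already been released (a block is freed when its ref count hits 0)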
+        }
 
         cudaMemset2DAsync(v_cache.data().get() + kContextLen * kHeadDim,
                           sizeof(half) * kSequenceLen * kHeadDim,
                           0,
                           sizeof(half) * kHeadDim,
-                          kHeadNum);
+                          kBatchSize * kHeadNum);
     }
 
     thrust::universal_vector<half>  k_blocks;
     thrust::universal_vector<half*> k_ptrs;
 
-    TestBlocks(k_cache, k_blocks, k_ptrs, kHeadNum, kHeadDim, kBlockSz);
+    TestBlocks(k_cache, k_blocks, k_ptrs, kHeadNum, kHeadDim, kBlockSz, kBatchSize);
 
     thrust::universal_vector<half>  v_blocks;
     thrust::universal_vector<half*> v_ptrs;
 
-    TestBlocks(v_cache, v_blocks, v_ptrs, kHeadNum, kHeadDim, kBlockSz);
+    TestBlocks(v_cache, v_blocks, v_ptrs, kHeadNum, kHeadDim, kBlockSz, kBatchSize);
 
     thrust::universal_vector<half>  k_cache_ref = k_cache;
     thrust::universal_vector<half>  v_cache_ref = v_cache;
@@ -200,10 +252,10 @@ int main(int argc, char* argv[])
     }
 
     if (1) {
-        ConvertBlocksToLinear(
-            (const half**)k_ptrs.data().get(), k_cache.data().get(), kBlockSz, kHeadNum, kHeadDim, kSequenceLen, 0);
-        ConvertBlocksToLinear(
-            (const half**)v_ptrs.data().get(), v_cache.data().get(), kBlockSz, kHeadNum, kHeadDim, kSequenceLen, 0);
+        // ConvertBlocksToLinear(
+        //     (const half**)k_ptrs.data().get(), k_cache.data().get(), kBlockSz, kHeadNum, kHeadDim, kSequenceLen, 0);
+        // ConvertBlocksToLinear(
+        //     (const half**)v_ptrs.data().get(), v_cache.data().get(), kBlockSz, kHeadNum, kHeadDim, kSequenceLen, 0);
     }
 
     cudaDeviceSynchronize();
diff --git a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
index 912589df09..f985cff5a5 100644
--- a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
@@ -26,6 +26,7 @@
 #include "src/turbomind/models/llama/LlamaNcclGuard.h"
 #include "src/turbomind/models/llama/llama_kernels.h"
 #include "src/turbomind/models/llama/llama_utils.h"
+#include "src/turbomind/kernels/decoder_mha/kv_cache.h"
 #include "src/turbomind/utils/Tensor.h"
 #include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/logger.h"
@@ -183,7 +184,7 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
                                    stream_);
     sync_check_cuda_error();
 
-    const size_t layer_offset = layer_id * local_kv_head_num_ * max_seq_len * size_per_head_;
+    const size_t layer_offset = layer_id * local_kv_head_num_ * kv_cache_block_len_ * size_per_head_;
 
     auto k_cache_ptrs = output_tensors->getPtr<T*>("key_cache");
     auto v_cache_ptrs = output_tensors->getPtr<T*>("value_cache");
@@ -195,19 +196,22 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
     // v_buf_2 [B, kvH, s, D] -> val_cache [B, kvH, S[t:t+s], D/x, x]
     invokeExtendKVCache(k_cache_ptrs,
                         v_cache_ptrs,
-                        layer_offset,
                         k_buf_2_,
                         v_buf_2_,
-                        batch_size,
+                        cu_block_counts,
                         input_length,
-                        max_q_len,
                         history_length,
-                        max_seq_len,
+                        batch_size,
+                        kv_cache_block_len_,
+                        layer_offset,
+                        max_q_len,
                         size_per_head_,
                         local_kv_head_num_,
-                        stream_,
                         quant_policy_,
-                        weights->past_kv_scale.data());
+                        weights->past_kv_scale.data(),
+                        stream_);
+
+    // WIP sketch: arguments here are assumptions (seq_lens <- history_length); the call only mirrors the declared signature
+    ConvertBlocksToLinear((const T**)k_cache_ptrs, k_cache_, cu_block_counts, history_length, kv_cache_block_len_, max_seq_len, local_kv_head_num_, size_per_head_, batch_size, stream_);
 
     sync_check_cuda_error();
     if (use_fmha_) {
diff --git a/src/turbomind/models/llama/LlamaContextAttentionLayer.h b/src/turbomind/models/llama/LlamaContextAttentionLayer.h
index 235b575b8e..6d3363e8aa 100644
--- a/src/turbomind/models/llama/LlamaContextAttentionLayer.h
+++ b/src/turbomind/models/llama/LlamaContextAttentionLayer.h
@@ -58,6 +58,7 @@ class LlamaContextAttentionLayer {
         cublas_wrapper_(cublas_wrapper),
         linear_(cublas_wrapper, stream),
         allocator_(allocator),
+        kv_cache_block_len_(128),
         is_free_buffer_after_forward_(is_free_buffer_after_forward),
         use_fmha_(use_fmha),
         quant_policy_(quant_policy)
@@ -98,6 +99,7 @@ class LlamaContextAttentionLayer {
     const size_t local_kv_head_num_;
     const size_t local_head_num_;
     const size_t head_n_rep_;
+    const size_t kv_cache_block_len_;
     const bool   is_free_buffer_after_forward_;
 
     const LlamaAttentionParams params_;
diff --git a/src/turbomind/models/llama/SequenceManager.cc b/src/turbomind/models/llama/SequenceManager.cc
index 7c03a5e0e5..1c05340048 100644
--- a/src/turbomind/models/llama/SequenceManager.cc
+++ b/src/turbomind/models/llama/SequenceManager.cc
@@ -17,7 +17,8 @@ SequenceManager::SequenceManager(size_t      layer_num,
 {
     constexpr int kBitsPerByte = 8;
 
-    size_t block_size = layer_num * head_num * block_len * head_dim * elem_bits / kBitsPerByte * 2;
+    // [2, L, H, block_len, D]
+    size_t block_size = 2UL * layer_num * head_num * block_len * head_dim * elem_bits / kBitsPerByte;
 
     block_manager_ = std::make_unique<BlockManager>(block_size, block_count, chunk_size, allocator);
 
@@ -216,6 +217,7 @@ bool SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
 {
     ////////////////////////////////////////////////////////////////////////////////
     /// Schedule the assignment of blocks to sequences
+
     auto seqs = const_cast<Sequence* const*>(sequences.data());
 
     // check validity of cached blocks (blocks of active & locked seqs are always valid)
diff --git a/src/turbomind/models/llama/SequenceManager.h b/src/turbomind/models/llama/SequenceManager.h
index 0ce057b95e..331d79990c 100644
--- a/src/turbomind/models/llama/SequenceManager.h
+++ b/src/turbomind/models/llama/SequenceManager.h
@@ -93,4 +93,12 @@ class SequenceManager {
     std::vector<const Block*> released_;
 };
 
+// cu_block_cnts(seq_idx) -> block_idx_offset
+// block_idxs(block_idx_offset) -> (seq_idx, seq_offset)
+
+
+inline void func(const std::vector<int>& block_cnts, std::vector<int>& cu_block_cnts, std::vector<int>& inv_block_idxs) {
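+    // hedged sketch of the assumed intent (see the mapping comments above): cu_block_cnts is the
+    // exclusive prefix sum of block_cnts; inv_block_idxs maps a flat block index to its seq_idx,
+    // from which seq_offset = block_idx_offset - cu_block_cnts[seq_idx] can be recovered
+    cu_block_cnts.assign(1, 0);
+    inv_block_idxs.clear();
+    for (size_t i = 0; i < block_cnts.size(); ++i) {
+        cu_block_cnts.push_back(cu_block_cnts.back() + block_cnts[i]);
+        inv_block_idxs.insert(inv_block_idxs.end(), block_cnts[i], static_cast<int>(i));
+    }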
+
+}
+
 }  // namespace turbomind
\ No newline at end of file
diff --git a/src/turbomind/models/llama/llama_kernels.cu b/src/turbomind/models/llama/llama_kernels.cu
index ebbfa7ee26..66c1202918 100644
--- a/src/turbomind/models/llama/llama_kernels.cu
+++ b/src/turbomind/models/llama/llama_kernels.cu
@@ -200,300 +200,131 @@ template void invokeCreateCausalMasks(float* mask, const int*, const int*, int,
 template void invokeCreateCausalMasks(half* mask, const int*, const int*, int, int, int, cudaStream_t);
 
 template<typename T>
-__global__ void extend_key_cache(T**          k_dst,
-                                 const size_t dst_offset,
-                                 const T*     k_src,
-                                 const int    head_num,
-                                 const int    size_per_head,
-                                 const int*   query_length,
-                                 const int*   history_length,
-                                 const int    max_q_len,
-                                 const int    max_seq_len)
+__global__ void extend_kv_cache(T**          k_dst_ptrs,
+                                T**          v_dst_ptrs,
+                                const T*     k_src,
+                                const T*     v_src,
+                                const int*   cu_block_counts,
+                                const int*   query_length,
+                                const int*   history_length,
+                                const int    block_length,
+                                const size_t dst_layer_offset,
+                                const int    max_q_len,
+                                const int    head_num,
+                                const int    head_dim)
 {
-    const int     batch_id = blockIdx.y;
-    const int     head_id  = blockIdx.z;
-    constexpr int X_ELEMS  = (sizeof(T) == 4) ? 4 : 8;
-
-    const int idx                 = blockIdx.x * blockDim.x + threadIdx.x;
-    int       size_per_head_div_x = size_per_head / X_ELEMS;
-
-    // x dim is now handled by uint4 type
-    const auto key_src = reinterpret_cast<const uint4*>(k_src);
-    const auto key_dst = reinterpret_cast<uint4*>(k_dst[batch_id] + dst_offset);
-
-    const auto seq_len  = query_length[batch_id];
-    const auto t_offset = history_length[batch_id];
-
-    const int k_head_size_id = idx % size_per_head_div_x;
-    const int k_seq_len_id   = idx / size_per_head_div_x;
+    const int batch_id     = blockIdx.y;
+    const int query_len    = query_length[batch_id];
+    const int history_len  = history_length[batch_id];
+    const int cu_block_cnt = cu_block_counts[batch_id];
 
-    if (k_seq_len_id < seq_len) {
-        // [B, H, s, D/x] -> [H, D/x, S[t:t+s]]
-
-        const int64_t dst_idx = head_id * size_per_head_div_x * max_seq_len +  // H
-                                k_head_size_id * max_seq_len +                 // D/x
-                                t_offset + k_seq_len_id;                       // s + offset
-
-        const int64_t src_idx = batch_id * head_num * size_per_head_div_x * max_q_len +  // B
-                                head_id * size_per_head_div_x * max_q_len +              // H
-                                k_seq_len_id * size_per_head_div_x +                     // s
-                                k_head_size_id;                                          // D/x
-
-        key_dst[dst_idx] = key_src[src_idx];
-    }
-}
-
-template<typename T>
-__global__ void extend_value_cache(T**          v_dst,
-                                   const size_t dst_offset,
-                                   const T*     v_src,
-                                   const int    head_num,
-                                   const int    size_per_head,
-                                   const int*   query_length,
-                                   const int*   history_length,
-                                   const int    max_q_len,
-                                   const int    max_seq_len)
-{
-    const int     batch_id = blockIdx.y;
-    const int     head_id  = blockIdx.z;
-    constexpr int X_ELEMS  = (sizeof(T) == 4) ? 4 : 8;
+    const int     head_id = blockIdx.z;
+    constexpr int X_ELEMS = (sizeof(T) == 4) ? 4 : 8;
 
+    const int size_per_head_div_x = head_dim / X_ELEMS;
     const int idx                 = blockIdx.x * blockDim.x + threadIdx.x;
-    int       size_per_head_div_x = size_per_head / X_ELEMS;
+    const int head_size_id        = idx % size_per_head_div_x;
+    const int seq_len_id          = idx / size_per_head_div_x;
 
-    // x dim is now handled by uint4 type
-    const auto val_src = reinterpret_cast<const uint4*>(v_src);
-    const auto val_dst = reinterpret_cast<uint4*>(v_dst[batch_id] + dst_offset);
-
-    const auto seq_len  = query_length[batch_id];
-    const auto t_offset = history_length[batch_id];
-
-    const int v_head_size_id = idx % size_per_head_div_x;
-    const int v_seq_len_id   = idx / size_per_head_div_x;
-
-    if (v_seq_len_id < seq_len) {
-        // [B, H, s, D/x] -> [H, S[t:t+s], D/x]
-        const int64_t dst_idx = head_id * size_per_head_div_x * max_seq_len +      // H
-                                (v_seq_len_id + t_offset) * size_per_head_div_x +  // s + offset
-                                v_head_size_id;                                    // D/x
-
-        const int64_t src_idx = batch_id * head_num * size_per_head_div_x * max_q_len +  // B
-                                head_id * size_per_head_div_x * max_q_len +              // H
-                                v_seq_len_id * size_per_head_div_x +                     // s
-                                v_head_size_id;                                          // D/x
-
-        val_dst[dst_idx] = val_src[src_idx];
-    }
-}
-
-inline __device__ float2 float2div(float a, float2 b)
-{
-    float2 c;
-    c.x = b.x / a;
-    c.y = b.y / a;
-    return c;
-}
-
-inline __device__ float2 float2sub(float zp, float2 val)
-{
-    float2 ret;
-    ret.x = val.x - zp;
-    ret.y = val.y - zp;
-    return ret;
-}
-
-static inline __device__ half4 char4_scale_to_half4(char4 value, const float scale, const float zp)
-{
-    half4 dst;
-    dst.x = __float2half(value.x * scale + zp);
-    dst.y = __float2half(value.y * scale + zp);
-    dst.z = __float2half(value.z * scale + zp);
-    dst.w = __float2half(value.w * scale + zp);
-    return dst;
-}
-
-static inline __device__ uint32_t float4_to_char4(float x, float y, float z, float w)
-{
-    uint32_t dst;
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 720
-    uint32_t a;
-    asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(a) : "f"(x));
-    uint32_t b;
-    asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(b) : "f"(y));
-    uint32_t c;
-    asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(c) : "f"(z));
-    uint32_t d;
-    asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(d) : "f"(w));
-
-    asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2,  0;\n" : "=r"(dst) : "r"(d), "r"(c));
-    asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, %0;\n" : "+r"(dst) : "r"(b), "r"(a));
-#else
-    char4 tmp;
-    tmp.x       = x;
-    tmp.y       = y;
-    tmp.z       = z;
-    tmp.w       = w;
-    dst         = reinterpret_cast<const uint32_t&>(tmp);
-#endif
-    return dst;
-}
-
-template<typename T>
-__global__ void extend_value_cache_int8(int8_t**     v_dst,
-                                        const size_t dst_offset,
-                                        const T*     v_src,
-                                        const int    head_num,
-                                        const int    size_per_head,
-                                        const int*   query_length,
-                                        const int*   history_length,
-                                        const int    max_q_len,
-                                        const int    max_seq_len,
-                                        const float  v_scale,
-                                        const float  v_zp)
-{
-    const int     batch_id = blockIdx.y;
-    const int     head_id  = blockIdx.z;
-    constexpr int X_ELEMS  = (sizeof(T) == 4) ? 4 : 8;
-
-    const int idx                 = blockIdx.x * blockDim.x + threadIdx.x;
-    int       size_per_head_div_x = size_per_head / X_ELEMS;
+    const int cache_block_index  = (seq_len_id + history_len) / block_length;
+    const int cache_block_offset = (seq_len_id + history_len) % block_length;
 
     // x dim is now handled by uint4 type
-    const auto val_src = reinterpret_cast<const uint4*>(v_src);
-    const auto val_dst = reinterpret_cast<uint2*>(v_dst[batch_id] + dst_offset);
+    const auto k_val_src = reinterpret_cast<const uint4*>(k_src);
+    const auto v_val_src = reinterpret_cast<const uint4*>(v_src);
 
-    const auto seq_len  = query_length[batch_id];
-    const auto t_offset = history_length[batch_id];
+    // const auto val_dst = reinterpret_cast<uint4*>(v_dst[batch_id] + dst_layer_offset);
+    const auto k_val_dst = (uint4*)((k_dst_ptrs + cu_block_cnt)[cache_block_index] + dst_layer_offset);
+    const auto v_val_dst = (uint4*)((v_dst_ptrs + cu_block_cnt)[cache_block_index] + dst_layer_offset);
 
-    const int v_head_size_id = idx % size_per_head_div_x;
-    const int v_seq_len_id   = idx / size_per_head_div_x;
-
-    if (v_seq_len_id < seq_len) {
+    if (seq_len_id < query_len) {
         // [B, H, s, D/x] -> [H, S[t:t+s], D/x]
-        const int64_t dst_idx = head_id * size_per_head_div_x * max_seq_len +      // H
-                                (v_seq_len_id + t_offset) * size_per_head_div_x +  // s + offset
-                                v_head_size_id;                                    // D/x
+        const int64_t dst_idx = head_id * size_per_head_div_x * block_length +  // H
+                                cache_block_offset * size_per_head_div_x +      // s in block
+                                head_size_id;                                   // D/x
 
         const int64_t src_idx = batch_id * head_num * size_per_head_div_x * max_q_len +  // B
                                 head_id * size_per_head_div_x * max_q_len +              // H
-                                v_seq_len_id * size_per_head_div_x +                     // s
-                                v_head_size_id;                                          // D/x
+                                seq_len_id * size_per_head_div_x +                       // s
+                                head_size_id;                                            // D/x
 
-        // scale to int8 and write
-        const auto value  = val_src[src_idx];
-        auto       to_ptr = reinterpret_cast<uint32_t*>(val_dst + dst_idx);
-
-        float2 float2_0 = float2div(v_scale, float2sub(v_zp, mmha::half2_to_float2(value.x)));
-        float2 float2_1 = float2div(v_scale, float2sub(v_zp, mmha::half2_to_float2(value.y)));
-        to_ptr[0]       = float4_to_char4(float2_0.x, float2_0.y, float2_1.x, float2_1.y);
-
-        float2_0  = float2div(v_scale, float2sub(v_zp, mmha::half2_to_float2(value.z)));
-        float2_1  = float2div(v_scale, float2sub(v_zp, mmha::half2_to_float2(value.w)));
-        to_ptr[1] = float4_to_char4(float2_0.x, float2_0.y, float2_1.x, float2_1.y);
+        k_val_dst[dst_idx] = k_val_src[src_idx];
+        v_val_dst[dst_idx] = v_val_src[src_idx];
     }
 }
 
 template<typename T>
-void invokeExtendKVCache(T**          k_dst,
-                         T**          v_dst,
-                         size_t       dst_offset,
+void invokeExtendKVCache(T**          k_dst_ptrs,
+                         T**          v_dst_ptrs,
                          const T*     k_src,
                          const T*     v_src,
-                         int          local_batch_size,
+                         const int*   cu_block_counts,
                          const int*   query_length,
-                         int          max_q_len,
                          const int*   history_length,
-                         int          max_seq_len,
-                         int          size_per_head,
-                         int          local_head_num,
-                         cudaStream_t stream,
+                         int          batch_size,
+                         int          block_length,
+                         size_t       dst_layer_offset,
+                         int          max_q_len,
+                         int          head_dim,
+                         int          head_num,
                          int          quant,
-                         const float* kv_scale)
+                         const float* kv_scale,
+                         cudaStream_t stream)
 {
     constexpr int block_sz = 128;
     constexpr int x        = (sizeof(T) == 4) ? 4 : 8;
 
-    dim3 grid((max_q_len * size_per_head / x + block_sz - 1) / block_sz, local_batch_size, local_head_num);
-
-    if (quant & QuantPolicy::kCacheKVInt8) {
-        extend_value_cache_int8<<<grid, block_sz, 0, stream>>>(reinterpret_cast<int8_t**>(k_dst),
-                                                               dst_offset,
-                                                               k_src,
-                                                               local_head_num,
-                                                               size_per_head,
-                                                               query_length,
-                                                               history_length,
-                                                               max_q_len,
-                                                               max_seq_len,
-                                                               kv_scale[0],
-                                                               kv_scale[1]);
-
-        extend_value_cache_int8<<<grid, block_sz, 0, stream>>>(reinterpret_cast<int8_t**>(v_dst),
-                                                               dst_offset,
-                                                               v_src,
-                                                               local_head_num,
-                                                               size_per_head,
-                                                               query_length,
-                                                               history_length,
-                                                               max_q_len,
-                                                               max_seq_len,
-                                                               kv_scale[2],
-                                                               kv_scale[3]);
-    }
-    else {
-        extend_value_cache<<<grid, block_sz, 0, stream>>>(k_dst,
-                                                          dst_offset,
-                                                          k_src,
-                                                          local_head_num,
-                                                          size_per_head,
-                                                          query_length,
-                                                          history_length,
-                                                          max_q_len,
-                                                          max_seq_len);
-
-        extend_value_cache<<<grid, block_sz, 0, stream>>>(v_dst,
-                                                          dst_offset,
-                                                          v_src,
-                                                          local_head_num,
-                                                          size_per_head,
-                                                          query_length,
-                                                          history_length,
-                                                          max_q_len,
-                                                          max_seq_len);
-    }
+    dim3 grid((max_q_len * head_dim / x + block_sz - 1) / block_sz, batch_size, head_num);
+
+    FT_CHECK(quant == 0);
+
+    extend_kv_cache<<<grid, block_sz, 0, stream>>>(k_dst_ptrs,
+                                                   v_dst_ptrs,
+                                                   k_src,
+                                                   v_src,
+                                                   cu_block_counts,
+                                                   query_length,
+                                                   history_length,
+                                                   block_length,
+                                                   dst_layer_offset,
+                                                   max_q_len,
+                                                   head_num,
+                                                   head_dim);
 }
 
-template void invokeExtendKVCache(float**,
-                                  float**,
-                                  size_t,
-                                  const float*,
-                                  const float*,
-                                  int,
-                                  const int*,
-                                  int,
-                                  const int*,
-                                  int,
-                                  int,
-                                  int,
-                                  cudaStream_t stream,
-                                  int,
-                                  const float*);
-
-template void invokeExtendKVCache(half**,
-                                  half**,
-                                  size_t,
-                                  const half*,
-                                  const half*,
-                                  int,
-                                  const int*,
-                                  int,
-                                  const int*,
-                                  int,
-                                  int,
-                                  int,
-                                  cudaStream_t stream,
-                                  int,
-                                  const float*);
+template void invokeExtendKVCache(float**      k_dst_ptrs,
+                                  float**      v_dst_ptrs,
+                                  const float* k_src,
+                                  const float* v_src,
+                                  const int*   cu_block_counts,
+                                  const int*   query_length,
+                                  const int*   history_length,
+                                  int          batch_size,
+                                  int          block_length,
+                                  size_t       dst_layer_offset,
+                                  int          max_q_len,
+                                  int          head_dim,
+                                  int          head_num,
+                                  int          quant,
+                                  const float* kv_scale,
+                                  cudaStream_t stream);
+
+template void invokeExtendKVCache(half**       k_dst_ptrs,
+                                  half**       v_dst_ptrs,
+                                  const half*  k_src,
+                                  const half*  v_src,
+                                  const int*   cu_block_counts,
+                                  const int*   query_length,
+                                  const int*   history_length,
+                                  int          batch_size,
+                                  int          block_length,
+                                  size_t       dst_layer_offset,
+                                  int          max_q_len,
+                                  int          head_dim,
+                                  int          head_num,
+                                  int          quant,
+                                  const float* kv_scale,
+                                  cudaStream_t stream);
 
 template<typename T>
 __global__ void transpose_value_cache(T*           v_dst,  //
@@ -581,8 +412,8 @@ __global__ void transpose_value_cache_int8(T*             v_dst,  //
         const auto from_ptr = reinterpret_cast<const char4*>(val_src + src_idx);
         auto       to_ptr   = reinterpret_cast<half4*>(val_dst + dst_idx);
 
-        to_ptr[0] = char4_scale_to_half4(from_ptr[0], v_scale, v_zp);
-        to_ptr[1] = char4_scale_to_half4(from_ptr[1], v_scale, v_zp);
+        // to_ptr[0] = char4_scale_to_half4(from_ptr[0], v_scale, v_zp);
+        // to_ptr[1] = char4_scale_to_half4(from_ptr[1], v_scale, v_zp);
     }
 }
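Note on the paged write in extend_kv_cache above: each token's absolute position (history_len + seq_len_id) is split into a cache block index and an in-block offset, and the per-sequence slice of the flat block-pointer table is selected through the cumulative block counts. Below is a minimal standalone sketch of that addressing, assuming cu_block_counts is the exclusive prefix sum over per-sequence block counts (as built later in LlamaBatch); the names are illustrative, not the kernel's actual locals.

    #include <cassert>
    #include <cstdio>

    // Illustrative only: find where token `pos` of sequence `b` lives in a paged KV
    // cache. Blocks of sequence b occupy block_ptrs[cu_block_counts[b] .. cu_block_counts[b+1]).
    struct PagedIndex {
        int block;   // index into the flattened block-pointer array
        int offset;  // token offset inside that block
    };

    PagedIndex locate(const int* cu_block_counts, int b, int pos, int block_length)
    {
        PagedIndex idx;
        idx.block  = cu_block_counts[b] + pos / block_length;
        idx.offset = pos % block_length;
        assert(idx.block < cu_block_counts[b + 1]);  // the sequence must own enough blocks
        return idx;
    }

    int main()
    {
        const int        cu_block_counts[] = {0, 2, 5};  // seq 0 owns 2 blocks, seq 1 owns 3
        const PagedIndex idx               = locate(cu_block_counts, 1, 200, 128);
        std::printf("block %d, offset %d\n", idx.block, idx.offset);  // block 3, offset 72
        return 0;
    }

Inside the kernel the same in-block offset is then scaled by size_per_head_div_x to form the uint4-vectorized dst_idx.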
 
diff --git a/src/turbomind/models/llama/llama_kernels.h b/src/turbomind/models/llama/llama_kernels.h
index 6bd4644f0d..96385e5763 100644
--- a/src/turbomind/models/llama/llama_kernels.h
+++ b/src/turbomind/models/llama/llama_kernels.h
@@ -34,21 +34,22 @@ void invokeCreateCausalMasks(
     T* mask, const int* q_lens, const int* k_lens, int max_q_len, int max_k_len, int batch_size, cudaStream_t stream);
 
 template<typename T>
-void invokeExtendKVCache(T**          k_dst,
-                         T**          v_dst,
-                         size_t       layer_offset,
+void invokeExtendKVCache(T**          k_dst_ptrs,
+                         T**          v_dst_ptrs,
                          const T*     k_src,
                          const T*     v_src,
-                         int          batch_size,
+                         const int*   cu_block_counts,
                          const int*   query_length,
-                         int          max_q_len,
                          const int*   history_length,
-                         int          max_seq_len,
-                         int          size_per_head,
-                         int          local_head_num,
-                         cudaStream_t stream,
+                         int          batch_size,
+                         int          block_length,
+                         size_t       dst_layer_offset,
+                         int          max_q_len,
+                         int          head_dim,
+                         int          head_num,
                          int          quant,
-                         const float* kv_scale);
+                         const float* kv_scale,
+                         cudaStream_t stream);
 
 template<typename T>
 void invokeTransposeKVCache(T*           key_cache_trans,

From 3ed61766009e123d38671ce0f4ba105df065d947 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Mon, 25 Sep 2023 14:10:49 +0000
Subject: [PATCH 05/56] update block-based KV cache, BlockManager and batch scheduling

---
 src/turbomind/kernels/decoder_mha/kv_cache.cu | 264 ++++++++-------
 src/turbomind/kernels/decoder_mha/kv_cache.h  |  32 +-
 src/turbomind/models/llama/BlockManager.cc    |  49 ++-
 src/turbomind/models/llama/BlockManager.h     |  20 +-
 src/turbomind/models/llama/CMakeLists.txt     |   6 +
 src/turbomind/models/llama/LlamaBatch.cc      | 300 ++++++++++--------
 src/turbomind/models/llama/LlamaBatch.h       |  26 +-
 .../llama/LlamaContextAttentionLayer.cc       |  30 +-
 src/turbomind/models/llama/LlamaV2.cc         |   5 +
 src/turbomind/models/llama/LlamaV2.h          |   7 +-
 src/turbomind/models/llama/SequenceManager.cc |  97 ++++--
 src/turbomind/models/llama/SequenceManager.h  |  31 +-
 .../models/llama/test_cache_manager.cc        |  94 ++++++
 13 files changed, 624 insertions(+), 337 deletions(-)
 create mode 100644 src/turbomind/models/llama/test_cache_manager.cc

diff --git a/src/turbomind/kernels/decoder_mha/kv_cache.cu b/src/turbomind/kernels/decoder_mha/kv_cache.cu
index 0fdad8e063..7ebe25271b 100644
--- a/src/turbomind/kernels/decoder_mha/kv_cache.cu
+++ b/src/turbomind/kernels/decoder_mha/kv_cache.cu
@@ -7,14 +7,6 @@ namespace turbomind {
 
 // [S/x, H, x, D] <-> [S/y, H, y, D]
 
-// [S, H, 1, D] <-> [1, H, S, D]
-
-namespace {
-
-struct ThreadMap {};
-
-}  // namespace
-
 template<typename T, typename SrcBlockLen, typename DstBlockLen, typename HeadDim>
 __inline__ __device__ void ConvertBlockSize(const T** __restrict__ src_block_ptrs,
                                             T** __restrict__ dst_block_ptrs,
@@ -56,10 +48,10 @@ __inline__ __device__ void ConvertBlockSize(const T** __restrict__ src_block_ptr
     *reinterpret_cast<uint4*>(dst_block + dst_block_offset) = data;
 }
 
-static inline size_t get_helper_smem_size(int batch_size)
-{
-    return (sizeof(void*) + sizeof(int)) * batch_size;
-}
+// static inline size_t get_helper_smem_size(int batch_size)
+// {
+//     return (sizeof(void*) + sizeof(int)) * batch_size;
+// }
 
 template<typename T>
 __global__ void LinearToBlocksKernel(const T*   src,
@@ -94,85 +86,6 @@ __global__ void LinearToBlocksKernel(const T*   src,
                      head_dim);
 }
 
-template<typename T, typename HeadDim>
-__global__ void BlocksToLinearKernel(const T**  src_block_ptrs,
-                                     T*         dst,
-                                     const int* src_cu_block_cnts,
-                                     const int* seq_lens,
-                                     int        src_block_len,
-                                     int        dst_block_len,
-                                     int        head_num,
-                                     HeadDim    head_dim,
-                                     int        batch_size)
-{
-    extern __shared__ void* smem[];
-
-    T**  dst_block_ptrs    = (T**)smem;
-    int* dst_cu_block_cnts = (int*)(dst_block_ptrs + batch_size);
-
-    for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
-        dst_cu_block_cnts[i] = i;
-        dst_block_ptrs[i]    = dst + blockIdx.z * head_num * dst_block_len * head_dim;
-    }
-
-    __syncthreads();
-
-    ConvertBlockSize(src_block_ptrs,
-                     dst_block_ptrs,
-                     src_cu_block_cnts,
-                     dst_cu_block_cnts,
-                     seq_lens,
-                     src_block_len,
-                     dst_block_len,
-                     head_dim);
-}
-
-template<typename T, typename SrcBlockLen, typename DstBlockLen, typename HeadDim>
-__global__ void KvCacheBlocksToLinearKernel(const T**   src_k_block_ptrs,
-                                            const T**   src_v_block_ptrs,
-                                            T*          dst_k,
-                                            T*          dst_v,
-                                            const int*  src_cu_block_cnts,
-                                            const int*  seq_lens,
-                                            SrcBlockLen src_block_len,
-                                            DstBlockLen dst_block_len,
-                                            int         head_num,
-                                            HeadDim     head_dim,
-                                            int         batch_size)
-{
-    extern __shared__ void* smem[];
-
-    T**  dst_k_block_ptrs  = (T**)smem;
-    T**  dst_v_block_ptrs  = dst_k_block_ptrs + batch_size;
-    int* dst_cu_block_cnts = (int*)(dst_v_block_ptrs + batch_size);
-
-    for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
-        dst_cu_block_cnts[i] = i;
-        dst_k_block_ptrs[i]  = dst_k + blockIdx.z * head_num * dst_block_len * head_dim;
-        dst_v_block_ptrs[i]  = dst_v + blockIdx.z * head_num * dst_block_len * head_dim;
-    }
-
-    __syncthreads();
-
-    ConvertBlockSize(src_k_block_ptrs,
-                     dst_k_block_ptrs,
-                     src_cu_block_cnts,
-                     dst_cu_block_cnts,
-                     seq_lens,
-                     src_block_len,
-                     dst_block_len,
-                     head_dim);
-
-    ConvertBlockSize(src_v_block_ptrs,
-                     dst_v_block_ptrs,
-                     src_cu_block_cnts,
-                     dst_cu_block_cnts,
-                     seq_lens,
-                     src_block_len,
-                     dst_block_len,
-                     head_dim);
-}
-
 template<typename T>
 void ConvertLinearToBlocks(const T*     src,
                            T**          dst_block_ptrs,
@@ -190,7 +103,7 @@ void ConvertLinearToBlocks(const T*     src,
     constexpr int threads = 128;
     const dim3    blocks((src_max_len * head_dim / kVecSize + threads - 1) / threads, head_num, batch_size);
 
-    const auto smem_sz = get_helper_smem_size(batch_size);
+    const auto smem_sz = (sizeof(void*) + sizeof(int)) * batch_size;
 
     auto fn = [&](auto head_dim) {
         LinearToBlocksKernel<<<blocks, threads, smem_sz, st>>>(src,
@@ -224,6 +137,39 @@ template void ConvertLinearToBlocks(const half*  src,
                                     int          batch_size,
                                     cudaStream_t st);
 
+template<typename T, typename HeadDim>
+__global__ void BlocksToLinearKernel(const T**  src_block_ptrs,
+                                     T*         dst,
+                                     const int* src_cu_block_cnts,
+                                     const int* seq_lens,
+                                     int        src_block_len,
+                                     int        dst_block_len,
+                                     int        head_num,
+                                     HeadDim    head_dim,
+                                     int        batch_size)
+{
+    extern __shared__ void* smem[];
+
+    T**  dst_block_ptrs    = (T**)smem;
+    int* dst_cu_block_cnts = (int*)(dst_block_ptrs + batch_size);
+
+    for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
+        dst_cu_block_cnts[i] = i;
+        dst_block_ptrs[i]    = dst + blockIdx.z * head_num * dst_block_len * head_dim;
+    }
+
+    __syncthreads();
+
+    ConvertBlockSize(src_block_ptrs,
+                     dst_block_ptrs,
+                     src_cu_block_cnts,
+                     dst_cu_block_cnts,
+                     seq_lens,
+                     src_block_len,
+                     dst_block_len,
+                     head_dim);
+}
+
 template<typename T>
 void ConvertBlocksToLinear(const T**    src_block_ptrs,
                            T*           dst,
@@ -241,7 +187,7 @@ void ConvertBlocksToLinear(const T**    src_block_ptrs,
     constexpr int threads = 256;
     const dim3    blocks((dst_max_len * head_dim / kVecSize + threads - 1) / threads, head_num, batch_size);
 
-    const auto smem_sz = get_helper_smem_size(batch_size);
+    const auto smem_sz = (sizeof(void*) + sizeof(int)) * batch_size;
 
     auto fn = [&](auto head_dim) {
         BlocksToLinearKernel<<<blocks, threads, smem_sz, st>>>(src_block_ptrs,
@@ -275,38 +221,114 @@ template void ConvertBlocksToLinear(const half** src_block_ptrs,
                                     int          batch_size,
                                     cudaStream_t st);
 
-// template<typename T>
-// void ConvertBlocksToBlocks(const T**    src_block_ptrs,
-//                            T**          dst_block_ptrs,
-//                            int          src_block_len,
-//                            int          dst_block_len,
-//                            int          heads,
-//                            int          dims,
-//                            int          seq_len,
-//                            cudaStream_t st)
-// {
-//     constexpr int kVecSize = sizeof(uint4) / sizeof(T);
+template<typename T, typename SrcBlockLen, typename DstBlockLen, typename HeadDim>
+__global__ void KvCacheBlocksToLinearKernel(const T**   src_k_block_ptrs,
+                                            const T**   src_v_block_ptrs,
+                                            T**         dst_k_ptrs,
+                                            T**         dst_v_ptrs,
+                                            const int*  src_cu_block_cnts,
+                                            const int*  seq_lens,
+                                            SrcBlockLen src_block_len,
+                                            DstBlockLen dst_block_len,
+                                            int         head_num,
+                                            HeadDim     head_dim,
+                                            int         batch_size)
+{
+    extern __shared__ int dst_cu_block_cnts[];
 
-//     int threads = 512;
-//     int blocks  = std::min(512, (heads * seq_len * dims / kVecSize + threads - 1) / threads);
+    for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
+        dst_cu_block_cnts[i] = i;
+    }
 
-//     BlocksToBlocksKernel<<<blocks, threads, 0, st>>>(
-//         src_block_ptrs, dst_block_ptrs, src_block_len, dst_block_len, heads, dims, seq_len);
-// }
+    __syncthreads();
+
+    ConvertBlockSize(src_k_block_ptrs,
+                     dst_k_ptrs,
+                     src_cu_block_cnts,
+                     dst_cu_block_cnts,
+                     seq_lens,
+                     src_block_len,
+                     dst_block_len,
+                     head_dim);
+
+    ConvertBlockSize(src_v_block_ptrs,
+                     dst_v_ptrs,
+                     src_cu_block_cnts,
+                     dst_cu_block_cnts,
+                     seq_lens,
+                     src_block_len,
+                     dst_block_len,
+                     head_dim);
+}
+
+template<typename T>
+void ConvertKvCacheBlocksToLinear(const T**    src_k_block_ptrs,
+                                  const T**    src_v_block_ptrs,
+                                  T**          dst_k_ptrs,
+                                  T**          dst_v_ptrs,
+                                  const int*   src_cu_block_cnts,
+                                  const int*   seq_lens,
+                                  int          src_block_len,
+                                  int          dst_block_len,
+                                  int          head_num,
+                                  int          head_dim,
+                                  int          batch_size,
+                                  cudaStream_t st)
+{
+    constexpr int kVecSize = sizeof(uint4) / sizeof(T);
+
+    constexpr int threads = 256;
+    const dim3    blocks((dst_block_len * head_dim / kVecSize + threads - 1) / threads, head_num, batch_size);
+
+    const auto smem_sz = sizeof(int) * batch_size;
 
-// template void ConvertLinearToBlocks(
-//     const half* src, half** dst_block_ptrs, int dst_block_len, int heads, int dims, int seq_len, cudaStream_t st);
+    auto fn = [&](auto head_dim) {
+        KvCacheBlocksToLinearKernel<<<blocks, threads, smem_sz, st>>>(src_k_block_ptrs,
+                                                                      src_v_block_ptrs,
+                                                                      dst_k_ptrs,
+                                                                      dst_v_ptrs,
+                                                                      src_cu_block_cnts,
+                                                                      seq_lens,
+                                                                      src_block_len,
+                                                                      dst_block_len,
+                                                                      head_num,
+                                                                      head_dim,
+                                                                      batch_size);
+    };
 
-// template void ConvertBlocksToLinear(
-//     const half** src_block_ptrs, half* dst, int src_block_len, int heads, int dims, int seq_len, cudaStream_t st);
+    switch (head_dim) {
+        case 128:
+            fn(std::integral_constant<int, 128>{});
+            break;
+        default:
+            fn(head_dim);
+    }
+}
 
-// template void ConvertBlocksToBlocks(const half** src_block_ptrs,
-//                                     half**       dst_block_ptrs,
-//                                     int          src_block_len,
-//                                     int          dst_block_len,
-//                                     int          heads,
-//                                     int          dims,
-//                                     int          seq_len,
-//                                     cudaStream_t st);
+template void ConvertKvCacheBlocksToLinear(const half** src_k_block_ptrs,
+                                           const half** src_v_block_ptrs,
+                                           half**       dst_k_ptrs,
+                                           half**       dst_v_ptrs,
+                                           const int*   src_cu_block_cnts,
+                                           const int*   seq_lens,
+                                           int          src_block_len,
+                                           int          dst_block_len,
+                                           int          head_num,
+                                           int          head_dim,
+                                           int          batch_size,
+                                           cudaStream_t st);
+
+template void ConvertKvCacheBlocksToLinear(const float** src_k_block_ptrs,
+                                           const float** src_v_block_ptrs,
+                                           float**       dst_k_ptrs,
+                                           float**       dst_v_ptrs,
+                                           const int*    src_cu_block_cnts,
+                                           const int*    seq_lens,
+                                           int           src_block_len,
+                                           int           dst_block_len,
+                                           int           head_num,
+                                           int           head_dim,
+                                           int           batch_size,
+                                           cudaStream_t  st);
 
 }  // namespace turbomind
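The gather path above reuses the generic ConvertBlockSize routine for paged-to-linear copies by writing an identity prefix (dst_cu_block_cnts[i] = i) into shared memory, so each destination sequence looks like exactly one block of length dst_block_len. The following is a scalar host-side sketch of the underlying re-blocking copy (head and vector dimensions omitted); it is a reading aid, not the device code.

    #include <cstdio>
    #include <vector>

    // Scalar model of the re-blocking copy: move the tokens of each sequence from
    // source blocks of length src_len into destination blocks of length dst_len.
    void convert_block_size(const std::vector<const int*>& src_blocks,
                            const std::vector<int*>&       dst_blocks,
                            const int*                     src_cu,   // prefix sum of src blocks per sequence
                            const int*                     dst_cu,   // prefix sum of dst blocks per sequence
                            const int*                     seq_lens,
                            int                            batch_size,
                            int                            src_len,
                            int                            dst_len)
    {
        for (int b = 0; b < batch_size; ++b) {
            for (int t = 0; t < seq_lens[b]; ++t) {
                const int* src   = src_blocks[src_cu[b] + t / src_len];
                int*       dst   = dst_blocks[dst_cu[b] + t / dst_len];
                dst[t % dst_len] = src[t % src_len];
            }
        }
    }

    int main()
    {
        int s0[4] = {0, 1, 2, 3}, s1[4] = {4, 5, -1, -1};  // two source blocks of length 4
        int d0[8] = {};                                    // one destination block of length 8

        std::vector<const int*> src{s0, s1};
        std::vector<int*>       dst{d0};
        const int               src_cu[] = {0}, dst_cu[] = {0}, seq_lens[] = {6};

        convert_block_size(src, dst, src_cu, dst_cu, seq_lens, 1, 4, 8);

        for (int x : d0) {
            std::printf("%d ", x);  // 0 1 2 3 4 5 0 0
        }
        std::printf("\n");
        return 0;
    }

Setting dst_cu[b] = b with dst_len at least seq_lens[b] is exactly the degenerate case KvCacheBlocksToLinearKernel sets up in shared memory.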
diff --git a/src/turbomind/kernels/decoder_mha/kv_cache.h b/src/turbomind/kernels/decoder_mha/kv_cache.h
index 6b29c53ae7..305971479b 100644
--- a/src/turbomind/kernels/decoder_mha/kv_cache.h
+++ b/src/turbomind/kernels/decoder_mha/kv_cache.h
@@ -4,24 +4,6 @@
 
 namespace turbomind {
 
-// template<typename T>
-// void ConvertLinearToBlocks(
-//     const T* src, T** dst_block_ptrs, int dst_block_size, int heads, int dims, int seq_len, cudaStream_t st);
-
-// template<typename T>
-// void ConvertBlocksToLinear(
-//     const T** src_block_ptrs, T* dst, int src_block_size, int heads, int dims, int seq_len, cudaStream_t st);
-
-// template<typename T>
-// void ConvertBlocksToBlocks(const T**    src_block_ptrs,
-//                            T**          dst_block_ptrs,
-//                            int          src_block_size,
-//                            int          dst_block_size,
-//                            int          heads,
-//                            int          dims,
-//                            int          seq_len,
-//                            cudaStream_t st);
-
 template<typename T>
 void ConvertLinearToBlocks(const T*     src,
                            T**          dst_block_ptrs,
@@ -46,4 +28,18 @@ void ConvertBlocksToLinear(const T**    src_block_ptrs,
                            int          batch_size,
                            cudaStream_t st);
 
+template<typename T>
+void ConvertKvCacheBlocksToLinear(const T**    src_k_block_ptrs,
+                                  const T**    src_v_block_ptrs,
+                                  T**          dst_k_ptrs,
+                                  T**          dst_v_ptrs,
+                                  const int*   src_cu_block_cnts,
+                                  const int*   seq_lens,
+                                  int          src_block_len,
+                                  int          dst_block_len,  // max{seq_lens}
+                                  int          head_num,
+                                  int          head_dim,
+                                  int          batch_size,
+                                  cudaStream_t st);
+
 }  // namespace turbomind
\ No newline at end of file
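Per the max{seq_lens} note on dst_block_len, each destination pointer is expected to address one contiguous per-sequence scratch slice, and dst_block_len only has to be large enough that a single "block" spans the longest sequence. Here is a host-side sketch of carving such a pointer table out of one scratch buffer, assuming a packed [head_num, seq_len, head_dim] slice per sequence (mirroring what LlamaBatch later does with h_tmp_k_ptrs_/h_tmp_v_ptrs_); the names are placeholders.

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Illustrative: carve one packed [head_num, seq_len[i], head_dim] slice per
    // sequence out of a single scratch buffer and return the dst_block_len to use
    // (it must cover the longest sequence so each slice acts as a single block).
    template<typename T>
    int build_linear_kv_ptrs(
        T* scratch, const std::vector<int>& seq_lens, int head_num, int head_dim, std::vector<T*>& ptrs)
    {
        ptrs.clear();
        for (int len : seq_lens) {
            ptrs.push_back(scratch);
            scratch += static_cast<size_t>(len) * head_num * head_dim;
        }
        return *std::max_element(seq_lens.begin(), seq_lens.end());
    }

    int main()
    {
        std::vector<int>    seq_lens{5, 3, 9};
        std::vector<float>  scratch((5 + 3 + 9) * 32 * 128);  // 32 heads, head_dim 128
        std::vector<float*> ptrs;

        const int dst_block_len = build_linear_kv_ptrs(scratch.data(), seq_lens, 32, 128, ptrs);
        std::printf("dst_block_len = %d, sequences = %zu\n", dst_block_len, ptrs.size());
        return 0;
    }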
diff --git a/src/turbomind/models/llama/BlockManager.cc b/src/turbomind/models/llama/BlockManager.cc
index e46c0fc073..e87e868509 100644
--- a/src/turbomind/models/llama/BlockManager.cc
+++ b/src/turbomind/models/llama/BlockManager.cc
@@ -1,5 +1,7 @@
 #include "src/turbomind/models/llama/BlockManager.h"
-// #include "src/turbomind/models/llama/utility.h"
+#include "src/turbomind/utils/cuda_utils.h"
+#include "src/turbomind/utils/dbg.h"
+#include "src/turbomind/utils/logger.h"
 #include <algorithm>
 #include <iterator>
 #include <stdexcept>
@@ -22,6 +24,13 @@ BlockManager::BlockManager(size_t block_size, double block_count, int chunk_size
     else if (chunk_size < 0) {
         chunk_size_ = max_block_count_;
     }
+    else {
+        chunk_size_ = chunk_size;
+    }
+
+    TM_LOG_INFO("[BlockManager] block_size = %d", block_size_);
+    TM_LOG_INFO("[BlockManager] max_block_count = %d", max_block_count_);
+    TM_LOG_INFO("[BlockManager] chunk_size = %d", chunk_size_);
 
     blocks_.reserve(max_block_count_);
 
@@ -97,15 +106,13 @@ std::vector<const Block*> BlockManager::Allocate(int count)
 
     std::vector<const Block*> ret;
 
-    std::vector<int> idxs;
-    idxs.reserve(count);
+    std::vector<int> idxs(count);
 
     for (int i = 0; i < count; ++i) {
         int idx     = free_ids_[i];
         idxs[i]     = idx;
         auto& block = blocks_[idx];
-        FT_CHECK(block.ref_count == 0);
-        FT_CHECK(block.timestamp == 0);
+        FT_CHECK(is_free(block));
         block.ref_count = 1;
         block.unique_id = unique_id_++;
         ret.push_back(&block);
@@ -113,6 +120,10 @@ std::vector<const Block*> BlockManager::Allocate(int count)
 
     Move(free_ids_, idxs, active_ids_);
 
+    Touch(ret);
+
+    dbg("[Allocate]", free_ids_, active_ids_);
+
     return ret;
 }
 
@@ -130,18 +141,22 @@ void BlockManager::Evict(int count)
 
     // set as free
     for (const auto& idx : idxs) {
+        FT_CHECK(is_cached(blocks_[idx]));
         blocks_[idx].timestamp = 0;
     }
 
     Move(cached_ids_, idxs, free_ids_);
+
+    dbg("[Evict]", free_ids_);
 }
 
-void BlockManager::Release(const std::vector<const Block*>& bs)
+int BlockManager::Release(const std::vector<const Block*>& bs)
 {
     std::vector<int> cached;
 
     for (const auto& p : bs) {
         auto& block = blocks_[p->id];
+        FT_CHECK(is_active(block));
         if (--block.ref_count == 0) {
             cached.push_back(block.id);
         }
@@ -150,6 +165,10 @@ void BlockManager::Release(const std::vector<const Block*>& bs)
     std::sort(cached.begin(), cached.end());
 
     Move(active_ids_, cached, cached_ids_);
+
+    dbg("[Release]", cached_ids_);
+
+    return cached.size();
 }
 
 void BlockManager::Retain(const std::vector<const Block*>& bs)
@@ -174,7 +193,23 @@ Snapshot BlockManager::TakeSnapshot()
     for (const auto& idx : active_ids_) {
         ref_count[idx] = blocks_[idx].ref_count;
     }
-    return {(int)active_ids_.size(), (int)cached_ids_.size(), (int)free_ids_.size(), std::move(ref_count)};
+    return {active_count(), cached_count(), free_count(), std::move(ref_count)};
+}
+
+std::ostream& operator<<(std::ostream& os, const BlockManager& manager)
+{
+    os << "block_size: " << manager.block_size_ << "\n";
+    os << "max_block_count: " << manager.max_block_count_ << "\n";
+    os << "chunk_size: " << manager.chunk_size_ << "\n";
+    os << "allocator: " << manager.allocator_ << "\n";
+    os << "chunks: " << manager.chunks_.size() << "\n";
+    os << "active_ids: " << manager.active_ids_.size() << "\n";
+    os << "cached_ids: " << manager.cached_ids_.size() << "\n";
+    os << "free_ids: " << manager.free_ids_.size() << "\n";
+    os << "blocks: " << manager.blocks_.size() << "\n";
+    os << "unique_id: " << manager.unique_id_ << "\n";
+    os << "timestamp: " << manager.timestamp_ << "\n";
+    return os;
 }
 
 std::ostream& operator<<(std::ostream& os, const Block& block)
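The new FT_CHECKs in Allocate/Evict/Release rely on is_free/is_active/is_cached predicates whose definitions are not shown in this hunk. The sketch below is an assumed reading of them, based on the ref_count/timestamp fields the old checks inspected, together with the legal transitions; it is only a reading aid.

    #include <cassert>
    #include <cstdint>

    // Assumed shape: only the two fields the checks above inspect are modeled here.
    struct Block {
        int      ref_count{0};
        uint64_t timestamp{0};
    };

    // free:   unreferenced and holding no reusable data
    bool is_free(const Block& b)   { return b.ref_count == 0 && b.timestamp == 0; }
    // active: referenced by at least one sequence
    bool is_active(const Block& b) { return b.ref_count > 0; }
    // cached: unreferenced but still holding data that may be reused until evicted
    bool is_cached(const Block& b) { return b.ref_count == 0 && b.timestamp != 0; }

    int main()
    {
        Block b;
        assert(is_free(b));
        b.ref_count = 1;    // Allocate: free -> active (Touch stamps a timestamp)
        b.timestamp = 42;
        assert(is_active(b));
        b.ref_count = 0;    // Release drops the last reference: active -> cached
        assert(is_cached(b));
        b.timestamp = 0;    // Evict clears the timestamp: cached -> free
        assert(is_free(b));
        return 0;
    }

Under this reading, Release moves a block to the cached list only when its last reference is dropped, and Evict clears the timestamp to return it to the free list, which matches the Move(...) calls above.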
diff --git a/src/turbomind/models/llama/BlockManager.h b/src/turbomind/models/llama/BlockManager.h
index 362b4cfb16..a6f7dc50a5 100644
--- a/src/turbomind/models/llama/BlockManager.h
+++ b/src/turbomind/models/llama/BlockManager.h
@@ -59,10 +59,10 @@ class BlockManager {
     ~BlockManager();
 
     // free -> active
-    std::vector<const Block*> Allocate(int count);
+    [[nodiscard]] std::vector<const Block*> Allocate(int count);
 
     // active -> cached
-    void Release(const std::vector<const Block*>& bs);
+    [[maybe_unused]] int Release(const std::vector<const Block*>& bs);
 
     // cached -> free
     void Evict(int count);
@@ -80,18 +80,23 @@ class BlockManager {
         return max_block_count_;
     }
 
-    int active_count() const noexcept {
+    int active_count() const noexcept
+    {
         return active_ids_.size();
     }
 
-    int cached_count() const noexcept {
+    int cached_count() const noexcept
+    {
         return cached_ids_.size();
     }
 
-    int free_count() const noexcept {
-        return free_ids_.size();
+    int free_count() const noexcept
+    {
+        return (max_block_count_ - blocks_.size()) + free_ids_.size();
     }
 
+    friend std::ostream& operator<<(std::ostream& os, const BlockManager&);
+
 private:
     static size_t GetBlockCount(size_t block_size, double ratio);
 
@@ -115,7 +120,8 @@ class BlockManager {
 
     std::vector<Block> blocks_;  // < 100k
 
-    uint64_t unique_id_{1UL << 63};
+    // uint64_t unique_id_{1UL << 63};
+    uint64_t unique_id_{0};
     uint64_t timestamp_{1};
 };
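The revised free_count() also counts capacity that has not yet been materialized into Block objects (blocks appear to be carved out of chunks lazily, chunk_size_ at a time), so callers see the total headroom rather than only the recycled free list. A toy accounting model with made-up numbers:

    #include <cassert>
    #include <cstdio>

    // Toy accounting: max_block_count is total capacity, `materialized` blocks have
    // been carved out of allocated chunks so far, free_ids is the recycled free list.
    struct Counts {
        int max_block_count;
        int materialized;  // == blocks_.size()
        int active, cached, free_ids;
    };

    int free_count(const Counts& c)
    {
        // capacity not yet materialized is still allocatable, so it counts as free
        return (c.max_block_count - c.materialized) + c.free_ids;
    }

    int main()
    {
        const Counts c{/*max_block_count=*/1000, /*materialized=*/256, /*active=*/200, /*cached=*/40, /*free_ids=*/16};
        assert(c.materialized == c.active + c.cached + c.free_ids);
        std::printf("free = %d\n", free_count(c));  // 760 = 744 unmaterialized + 16 recycled
        return 0;
    }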
 
diff --git a/src/turbomind/models/llama/CMakeLists.txt b/src/turbomind/models/llama/CMakeLists.txt
index 1da6fcec16..0d3ea9ba7a 100644
--- a/src/turbomind/models/llama/CMakeLists.txt
+++ b/src/turbomind/models/llama/CMakeLists.txt
@@ -30,6 +30,7 @@ target_link_libraries(Llama PUBLIC CUDA::cudart
         DynamicDecodeLayer
         activation_kernels
         decoder_masked_multihead_attention
+        decoder_multihead_attention
         bert_preprocess_kernels
         decoding_kernels
         unfused_attention_kernels
@@ -50,3 +51,8 @@ endif()
 
 add_executable(llama_gemm llama_gemm.cc)
 target_link_libraries(llama_gemm PUBLIC CUDA::cudart gpt_gemm_func memory_utils cuda_utils logger)
+
+find_package(Catch2 3 REQUIRED)
+
+add_executable(test_cache_manager test_cache_manager.cc)
+target_link_libraries(test_cache_manager PRIVATE Llama Catch2::Catch2WithMain)
\ No newline at end of file
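test_cache_manager.cc itself is not shown in this listing; since the target links Catch2::Catch2WithMain, the file only needs TEST_CASEs and no main. As a placeholder, here is the shape of a Catch2 v3 test exercising the free -> active -> cached -> free accounting with a toy stand-in; the real test presumably drives BlockManager directly, whose full constructor signature is not visible here, so everything below is illustrative.

    #include <catch2/catch_test_macros.hpp>

    #include <algorithm>

    // Toy stand-in for the block accounting; not the real BlockManager.
    struct ToyBlockManager {
        int free_{16}, active_{0}, cached_{0};

        int  Allocate(int n) { n = std::min(n, free_); free_ -= n; active_ += n; return n; }
        int  Release(int n)  { active_ -= n; cached_ += n; return n; }  // last refs dropped
        void Evict(int n)    { cached_ -= n; free_ += n; }
    };

    TEST_CASE("block state accounting", "[cache_manager]")
    {
        ToyBlockManager m;

        REQUIRE(m.Allocate(4) == 4);  // free -> active
        REQUIRE(m.active_ == 4);

        REQUIRE(m.Release(4) == 4);   // active -> cached
        REQUIRE(m.cached_ == 4);

        m.Evict(4);                   // cached -> free
        REQUIRE(m.free_ == 16);
    }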
diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index b105be6d56..77b9e26f65 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -115,8 +115,9 @@ void LlamaBatch<T>::RejectInvalidRequests(Requests& stop_reqs, Requests& infer_r
 }
 
 template<typename T>
-void LlamaBatch<T>::ProcessStopRequests(const Requests& requests)
+auto LlamaBatch<T>::ProcessStopRequests(const Requests& requests) -> std::vector<Signal>
 {
+    std::vector<Signal> signals;
     for (const auto& r : requests) {
         int ec = Request::kFail;
         // find matching active sequence
@@ -141,10 +142,9 @@ void LlamaBatch<T>::ProcessStopRequests(const Requests& requests)
             Clear(sequence_length.getPtr<int>(), 1);
             check_cuda_error(cudaStreamSynchronize(stream_));
         }
-        if (rank_ == 0) {
-            r->signal.set_value(ec);
-        }
+        signals.push_back([=] { r->signal.set_value(ec); });
     }
+    return signals;
 }
 
 template<typename T>
@@ -270,17 +270,14 @@ bool LlamaBatch<T>::Initialize()
     add(state_);
     add(incoming_);
 
-    bool modified = sequence_manager_->Materialize(sequences, context_lengths, priorities, llama_->step_length_);
+    auto outcome = sequence_manager_->Materialize(sequences, context_lengths, priorities, llama_->step_length_);
 
-    // no swap-in/swap-out & no holes in the buffers & no new requests -> nothing changed
-    if (!modified && !holes && !incoming_->size) {
-        return false;
-    }
+    bool exchange = outcome.swap_in + outcome.swap_out > 0;
 
     std::vector<int> idxs(sequences.size());
     std::iota(idxs.begin(), idxs.end(), 0);
 
-    if (modified) {
+    if (exchange) {
         // put active ones first
         auto active_end = std::stable_partition(idxs.begin(), idxs.end(), [&](int idx) {
             return sequences[idx]->status == Sequence::kActive;  // present status
@@ -301,59 +298,63 @@ bool LlamaBatch<T>::Initialize()
         }
     }
 
-    // Copy sequence states to the back state buffer
-    back_->size = back_->active_size = 0;
-    for (const auto& i : idxs) {
-        auto& s = *sequences[i];
-        if (modified) {
-            // backup random states from dynamic decode layers for swap-outs
-            if (status[i] == Sequence::kActive && s.status != Sequence::kActive) {
-                SaveRandomState(*coords[i].first, coords[i].second);
+    if (exchange || holes) {
+        // Copy sequence states to the back state buffer
+        back_->size = back_->active_size = 0;
+        for (const auto& i : idxs) {
+            auto& s = *sequences[i];
+            if (exchange) {
+                // backup random states from dynamic decode layers for swap-outs
+                if (status[i] == Sequence::kActive && s.status != Sequence::kActive) {
+                    SaveRandomState(*coords[i].first, coords[i].second);
+                }
+                // restore random states to dynamic decode layers for swap-ins
+                if (status[i] != Sequence::kActive && s.status == Sequence::kActive) {
+                    LoadRandomState(*coords[i].first, coords[i].second);
+                }
             }
-            // restore random states to dynamic decode layers for swap-ins
-            if (status[i] != Sequence::kActive && s.status == Sequence::kActive) {
-                LoadRandomState(*coords[i].first, coords[i].second);
+            if (s.status == Sequence::kActive) {
+                ++back_->active_size;
             }
+            CopyState(coords[i], {back_, back_->size++});
         }
-        if (s.status == Sequence::kActive) {
-            ++back_->active_size;
-        }
-        CopyState(coords[i], {back_, back_->size++});
+        // Swap the buffers
+        std::swap(state_, back_);
     }
-    // Swap the buffers
-    std::swap(state_, back_);
 
     const int batch_size = state_->active_size;
 
-    // Prepare intermediate buffers
-    h_cu_block_counts_[0] = 0;
+    if (exchange || outcome.allocation) {
+        // Prepare intermediate buffers
+        h_cu_block_counts_[0] = 0;
 
-    auto k_ptrs = h_k_block_ptrs_;
-    auto v_ptrs = h_v_block_ptrs_;
+        auto k_ptrs = h_k_block_ptrs_;
+        auto v_ptrs = h_v_block_ptrs_;
 
-    for (int i = 0; i < batch_size; ++i) {
-        const auto& seq = *state_->sequences[i];
+        for (int i = 0; i < batch_size; ++i) {
+            const auto& seq = *state_->sequences[i];
 
-        // cumulative num of blocks
-        h_cu_block_counts_[i + 1] = h_cu_block_counts_[i] + seq.blocks.size();
+            // cumulative num of blocks
+            h_cu_block_counts_[i + 1] = h_cu_block_counts_[i] + seq.blocks.size();
 
-        k_ptrs = std::transform(seq.blocks.begin(), seq.blocks.end(), k_ptrs, [&](const Block* p) {
-            return reinterpret_cast<uintptr_t>(sequence_manager_->OffsetKey(p->data));
-        });
-        v_ptrs = std::transform(seq.blocks.begin(), seq.blocks.end(), v_ptrs, [&](auto p) {
-            return reinterpret_cast<uintptr_t>(sequence_manager_->OffsetVal(p->data));
-        });
-    }
+            k_ptrs = std::transform(seq.blocks.begin(), seq.blocks.end(), k_ptrs, [&](const Block* p) {
+                return reinterpret_cast<uintptr_t>(sequence_manager_->OffsetKey(p->data));
+            });
+            v_ptrs = std::transform(seq.blocks.begin(), seq.blocks.end(), v_ptrs, [&](auto p) {
+                return reinterpret_cast<uintptr_t>(sequence_manager_->OffsetVal(p->data));
+            });
+        }
 
-    Copy(state_->h_context_length, batch_size, context_length_buf_);
+        Copy(state_->h_context_length, batch_size, context_length_buf_);
 
-    Copy(h_cu_block_counts_, batch_size + 1, cu_block_counts_);
-    Copy(h_k_block_ptrs_, h_cu_block_counts_[batch_size], k_block_ptrs_);
-    Copy(h_v_block_ptrs_, h_cu_block_counts_[batch_size], v_block_ptrs_);
+        Copy(h_cu_block_counts_, batch_size + 1, cu_block_counts_);
+        Copy(h_k_block_ptrs_, h_cu_block_counts_[batch_size], k_block_ptrs_);
+        Copy(h_v_block_ptrs_, h_cu_block_counts_[batch_size], v_block_ptrs_);
+    }
 
     // in case of swap-in/swap-out or there are holes in active buffer, layout of the buffers is changed
     // generation & sampling need to be re-initialized for correctness
-    return modified || active_holes;
+    return exchange || active_holes;
 }
 
 template<typename T>
@@ -397,9 +398,11 @@ void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len)
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
     const size_t batchxbeam = batch_size;
 
-    const size_t hidden_units    = llama_->hidden_units_;
-    const size_t vocab_size      = llama_->vocab_size_padded_;
-    const size_t max_block_count = sequence_manager_->max_block_count();
+    const size_t hidden_units      = llama_->hidden_units_;
+    const size_t vocab_size        = llama_->vocab_size_padded_;
+    const size_t head_dim          = llama_->size_per_head_;
+    const size_t local_kv_head_num = llama_->local_kv_head_num_;
+    const size_t max_block_count   = sequence_manager_->max_block_count();
 
     context_decoder_input_buf_ =
         (T*)allocator_->reMalloc(context_decoder_input_buf_, sizeof(T) * max_context_token_num_ * hidden_units, false);
@@ -408,6 +411,14 @@ void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len)
     context_decoder_ids_buf_ =
         (int*)allocator_->reMalloc(context_decoder_ids_buf_, sizeof(int) * max_context_token_num_, false);
 
+    tmp_k_cache_buf_ = (T*)allocator_->reMalloc(
+        tmp_k_cache_buf_, sizeof(T) * max_context_token_num_ * local_kv_head_num * head_dim, false);
+    tmp_v_cache_buf_ = (T*)allocator_->reMalloc(
+        tmp_v_cache_buf_, sizeof(T) * max_context_token_num_ * local_kv_head_num * head_dim, false);
+
+    tmp_k_ptrs_ = (void**)allocator_->reMalloc(tmp_k_ptrs_, sizeof(void*) * batch_size, false);
+    tmp_v_ptrs_ = (void**)allocator_->reMalloc(tmp_v_ptrs_, sizeof(void*) * batch_size, false);
+
     decoder_input_buf_  = (T*)allocator_->reMalloc(decoder_input_buf_, sizeof(T) * batchxbeam * hidden_units, false);
     decoder_output_buf_ = (T*)allocator_->reMalloc(decoder_output_buf_, sizeof(T) * batchxbeam * hidden_units, false);
 
@@ -474,6 +485,9 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
         h_history_length_buf_ =
             (int*)allocator_->reMalloc(h_history_length_buf_, sizeof(int) * max_batch_size, false, true);
 
+        h_tmp_k_ptrs_ = (void**)allocator_->reMalloc(h_tmp_k_ptrs_, sizeof(void*) * max_batch_size, false, true);
+        h_tmp_v_ptrs_ = (void**)allocator_->reMalloc(h_tmp_v_ptrs_, sizeof(void*) * max_batch_size, false, true);
+
         h_cu_block_counts_ =
             (int*)allocator_->reMalloc(h_cu_block_counts_, sizeof(int) * (max_batch_size + 1), false, true);
         h_k_block_ptrs_ =
@@ -503,6 +517,11 @@ void LlamaBatch<T>::FreeBuffer()
         allocator_->free((void**)&context_decoder_output_buf_);
         allocator_->free((void**)&context_decoder_ids_buf_);
 
+        allocator_->free((void**)&tmp_k_cache_buf_);
+        allocator_->free((void**)&tmp_v_cache_buf_);
+        allocator_->free((void**)&tmp_k_ptrs_);
+        allocator_->free((void**)&tmp_v_ptrs_);
+
         allocator_->free((void**)&decoder_input_buf_);
         allocator_->free((void**)&decoder_output_buf_);
 
@@ -542,6 +561,8 @@ void LlamaBatch<T>::FreeBuffer()
             allocator_->free((void**)&s.h_finished, true);
             allocator_->free((void**)&s.output_ids);
         }
+        allocator_->free((void**)&h_tmp_k_ptrs_, true);
+        allocator_->free((void**)&h_tmp_v_ptrs_, true);
         allocator_->free((void**)&h_cu_block_counts_, true);
         allocator_->free((void**)&h_k_block_ptrs_, true);
         allocator_->free((void**)&h_v_block_ptrs_, true);
@@ -590,7 +611,7 @@ LlamaBatch<T>::LlamaBatch(int                              max_batch_size,
 template<typename T>
 void LlamaBatch<T>::InitializeSampling()
 {
-    const int batch_size = state_->size;
+    const int batch_size = state_->active_size;
     TensorMap inputs;
     for (const auto& param : sampling_params_) {
         // find an exemplar that matches the param name
@@ -826,86 +847,93 @@ void LlamaBatch<T>::ContextDecode()
     if (rank_ == 0) {
         TM_LOG_INFO("[decodeContext] base = %d, count = %d", base, context_decode_count);
     }
+    // decrement input/context lengths by 1 to skip the last input token (it is processed by the decoder later)
     invokePlusScalar(input_length_buf_ + base, -1, context_decode_count, stream_);
     invokePlusScalar(context_length_buf_ + base, -1, context_decode_count, stream_);
 
-    auto get_input_len   = [this](int index) { return h_input_length_buf_[index] - 1; };
-    auto get_context_len = [this](int index) { return state_->h_context_length[index] - 1; };
-
-    std::vector<int> decode_indices{base};
-    std::vector<int> decode_lengths{get_input_len(base)};
-
-    auto token_num       = get_input_len(base);
-    auto max_input_len   = get_input_len(base);
-    auto max_context_len = get_context_len(base);
-    auto offset          = base;
-    for (int i = offset + 1; i <= batch_size; ++i) {
-        if (i == batch_size || token_num + state_->h_context_length[i] > max_context_token_num_) {
-            const int context_decode_batch_size = i - offset;
-            if (rank_ == 0) {
-                TM_LOG_INFO(
-                    "[decodeContext] offset = %d, batch_size = %d, token_num = %d, max_input_len = %d, max_context_len = %d",
-                    base,
-                    context_decode_batch_size,
-                    token_num,
-                    max_input_len,
-                    max_context_len);
-            }
-            // construct context_decoder_ids w/o padding
-            // aaaa____
-            // bb______ -> aaaabbcccccccc
-            // cccccccc
-            auto context_decoder_ids = context_decoder_ids_buf_;
-            for (int j = offset; j < i; ++j) {
-                context_decoder_ids = Copy(input_ids_buf_ + j * session_len_, get_input_len(j), context_decoder_ids);
-            }
-            llama_->contextDecode(nullptr,
-                                  k_block_ptrs_,
-                                  v_block_ptrs_,
-                                  context_decoder_input_buf_,
-                                  context_decoder_output_buf_,
-                                  context_decoder_ids_buf_,
-                                  input_length_buf_ + offset,
-                                  history_length_buf_ + offset,
-                                  context_length_buf_ + offset,
-                                  cu_block_counts_ + offset,
-                                  token_num,
-                                  max_input_len,
-                                  max_context_len,
-                                  session_len_,
-                                  context_decode_batch_size);
-
-            // compute logits of inputs if requested
-            OutputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths);
-
-            if (i < batch_size) {
-                // initialize next sub-batch
-                token_num       = get_input_len(i);
-                max_input_len   = get_input_len(i);
-                max_context_len = get_context_len(i);
-                offset          = i;
-
-                decode_indices = {i};
-                decode_lengths = {get_input_len(i)};
-            }
+    // find sub-batch offsets
+    std::vector<int> offsets{base};
+    int              accum_input_count   = 0;
+    int              accum_context_count = 0;
+    for (int i = base; i < batch_size; ++i) {
+        int input_count   = accum_input_count + h_input_length_buf_[i] - 1;
+        int context_count = accum_context_count + state_->h_context_length[i] - 1;
+        if (input_count <= max_context_token_num_ && context_count <= max_context_token_num_) {
+            accum_input_count   = input_count;
+            accum_context_count = context_count;
         }
         else {
-            // add to current sub-batch
-            token_num += get_input_len(i);
-            max_input_len   = std::max(max_input_len, get_input_len(i));
-            max_context_len = std::max(max_context_len, get_context_len(i));
-
+            offsets.push_back(i);
+            accum_input_count   = h_input_length_buf_[i] - 1;
+            accum_context_count = state_->h_context_length[i] - 1;
+        }
+    }
+    offsets.push_back(batch_size);
+
+    // context decode on sub-batches
+    for (int k = 0; k < offsets.size() - 1; ++k) {
+        int              first          = offsets[k];
+        int              last           = offsets[k + 1];
+        int              sub_batch_size = last - first;
+        T*               k_ptr          = tmp_k_cache_buf_;
+        T*               v_ptr          = tmp_v_cache_buf_;
+        std::vector<int> decode_indices{};
+        std::vector<int> decode_lengths{};
+        int              max_input_len{};
+        int              max_context_len{};
+        auto             input_ids = context_decoder_ids_buf_;
+        for (int i = first; i < last; ++i) {
+            input_ids        = Copy(input_ids_buf_ + i * session_len_, h_input_length_buf_[i] - 1, input_ids);
+            h_tmp_k_ptrs_[i] = k_ptr;
+            h_tmp_v_ptrs_[i] = v_ptr;
+            k_ptr += (state_->h_context_length[i] - 1) * llama_->local_kv_head_num_ * llama_->size_per_head_;
+            v_ptr += (state_->h_context_length[i] - 1) * llama_->local_kv_head_num_ * llama_->size_per_head_;
             decode_indices.push_back(i);
-            decode_lengths.push_back(get_input_len(i));
+            decode_lengths.push_back(h_input_length_buf_[i] - 1);
+            max_input_len   = std::max(max_input_len, h_input_length_buf_[i] - 1);
+            max_context_len = std::max(max_context_len, state_->h_context_length[i] - 1);
         }
+        int token_count = input_ids - context_decoder_ids_buf_;
+
+        Copy(h_tmp_k_ptrs_ + first, sub_batch_size, tmp_k_ptrs_ + first);
+        Copy(h_tmp_v_ptrs_ + first, sub_batch_size, tmp_v_ptrs_ + first);
+
+        if (rank_ == 0) {
+            TM_LOG_INFO(
+                "[decodeContext] offset = %d, batch_size = %d, token_num = %d, max_input_len = %d, max_context_len = %d",
+                base,
+                sub_batch_size,
+                token_count,
+                max_input_len,
+                max_context_len);
+        }
+
+        llama_->contextDecode(nullptr,
+                              k_block_ptrs_,
+                              v_block_ptrs_,
+                              tmp_k_ptrs_ + first,
+                              tmp_v_ptrs_ + first,
+                              context_decoder_input_buf_,
+                              context_decoder_output_buf_,
+                              context_decoder_ids_buf_,
+                              input_length_buf_ + first,
+                              history_length_buf_ + first,
+                              context_length_buf_ + first,
+                              cu_block_counts_ + first,
+                              token_count,
+                              max_input_len,
+                              max_context_len,
+                              session_len_,
+                              sub_batch_size);
+
+        // compute logits of inputs if requested
+        OutputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths);
     }
 
     invokePlusScalar(context_length_buf_ + base, 1, context_decode_count, stream_);
     invokePlusScalar(input_length_buf_ + base, 1, context_decode_count, stream_);
 
-    for (int i = offset; i < batch_size; ++i) {
-        h_input_length_buf_[i] = 0;
-    }
+    std::fill(h_input_length_buf_ + base, h_input_length_buf_ + batch_size, 0);
 
     check_cuda_error(cudaStreamSynchronize(stream_));
     const auto tock = std::chrono::high_resolution_clock::now();
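
A minimal host-side sketch of the sub-batch split used in the hunk above: sequences are packed greedily while both the accumulated input-token and context-token counts stay within a budget, and the sequence that overflows starts the next sub-batch and seeds its counters. `SplitSubBatches` and its parameters are illustrative names, not turbomind API.

    #include <cstdio>
    #include <vector>

    std::vector<int> SplitSubBatches(const std::vector<int>& input_lens,
                                     const std::vector<int>& context_lens,
                                     int                     base,
                                     int                     budget)
    {
        std::vector<int> offsets{base};
        int accum_input = 0, accum_context = 0;
        for (int i = base; i < (int)input_lens.size(); ++i) {
            if (accum_input + input_lens[i] <= budget && accum_context + context_lens[i] <= budget) {
                accum_input += input_lens[i];
                accum_context += context_lens[i];
            }
            else {
                offsets.push_back(i);            // sequence i opens the next sub-batch
                accum_input   = input_lens[i];   // and seeds its accumulators
                accum_context = context_lens[i];
            }
        }
        offsets.push_back((int)input_lens.size());
        return offsets;
    }

    int main()
    {
        for (int o : SplitSubBatches({100, 200, 300, 50}, {150, 250, 400, 60}, 0, 500)) {
            std::printf("%d ", o);  // prints: 0 2 4
        }
        std::printf("\n");
    }
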
@@ -962,7 +990,7 @@ void LlamaBatch<T>::OutputContextLogits(T*                      context_decoder_
 }
 
 template<typename T>
-int LlamaBatch<T>::Finish()
+auto LlamaBatch<T>::Finish() -> std::vector<Signal>
 {
     const int batch_size = state_->active_size;
 
@@ -989,14 +1017,14 @@ int LlamaBatch<T>::Finish()
         TM_LOG_INFO("[finish] [%s]", ss.str().c_str());
     }
 
-    int finished_count{};
+    std::vector<Signal> signals;
     for (int i = 0; i < batch_size; ++i) {
         if (state_->requests[i] && state_->h_finished[i]) {
             FinishRequest(i, false);
-            ++finished_count;
+            signals.push_back([r = std::move(state_->requests[i])] { r->signal.set_value(0); });
         }
     }
-    return finished_count;
+    return signals;
 }
 
 template<typename T>
@@ -1079,12 +1107,6 @@ void LlamaBatch<T>::FinishRequest(int index, bool force_end)
         sequence_manager_->Update(seq);
     }
 
-    // Notify request completion
-    if (rank_ == 0) {
-        state_->requests[index]->signal.set_value(0);
-    }
-
-    state_->requests[index]  = nullptr;
     state_->sequences[index] = nullptr;
 }
 
@@ -1125,7 +1147,8 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id)
             return;
         }
 
-        ProcessStopRequests(stop_requests);
+        auto signals = ProcessStopRequests(stop_requests);
+        BarrierSignalRequests(*shared_state->barrier, signals);
 
         ProcessInferRequests(infer_requests);
 
@@ -1146,13 +1169,28 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id)
                     break;
                 }
             }
-            finished_count = Finish();
+            auto signals = Finish();
+            BarrierSignalRequests(*shared_state->barrier, signals);
         }
     }
 
     FT_CHECK(0);
 }
 
+template<typename T>
+void LlamaBatch<T>::BarrierSignalRequests(Barrier& barrier, const std::vector<Signal>& signals)
+{
+    if (!signals.empty()) {
+        barrier.wait();
+        if (rank_ == 0) {
+            for (const auto& s : signals) {
+                s();
+            }
+        }
+        barrier.wait();
+    }
+}
+
 template<typename T>
 void LlamaBatch<T>::Start()
 {
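
The change above defers request completion into `Signal` closures and fires them on rank 0 only, between two barrier waits, so every TP rank is done touching a request before its client is notified. A reduced sketch of the same pattern, with `std::promise` standing in for the request's signal and the barrier itself elided:

    #include <cstdio>
    #include <functional>
    #include <future>
    #include <vector>

    using Signal = std::function<void()>;

    int main()
    {
        std::vector<std::promise<int>> requests(3);

        // Finish(): package each completion as a deferred callback instead of signalling inline.
        std::vector<Signal> signals;
        for (auto& r : requests) {
            signals.push_back([&r] { r.set_value(0); });
        }

        // BarrierSignalRequests(): all ranks would wait on a barrier here; only rank 0 fires.
        const int rank = 0;
        if (rank == 0) {
            for (const auto& s : signals) {
                s();
            }
        }

        for (auto& r : requests) {
            std::printf("request done: %d\n", r.get_future().get());
        }
    }

In the real code the closure owns the request via the moved `shared_ptr`, which is why `FinishRequest` no longer needs to null out `state_->requests[index]`.
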
diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h
index a8bfad6729..7cc430eb9d 100644
--- a/src/turbomind/models/llama/LlamaBatch.h
+++ b/src/turbomind/models/llama/LlamaBatch.h
@@ -3,6 +3,7 @@
 #pragma once
 
 // #include "src/turbomind/models/llama/LlamaCacheManager.h"
+#include "src/turbomind/models/llama/Barrier.h"
 #include "src/turbomind/models/llama/LlamaNcclGuard.h"
 #include "src/turbomind/models/llama/Request.h"
 #include "src/turbomind/models/llama/SequenceManager.h"
@@ -24,9 +25,10 @@ struct BatchState {
     std::vector<const Sequence*>          sequences;
     std::vector<std::shared_ptr<Request>> requests;
 
-    // |<-- existing -->|<-- swap-in -->|<-- inactive -->|
-    int size;
+    // |<-- existing -->|<-- swap-in -->|
+    // |<----------- active ----------->|<-- inactive -->|
     int active_size;
+    int size;
 };
 
 template<typename T>
@@ -40,10 +42,11 @@ class LlamaBatch {
     void FreeBuffer();
 
     using Requests = std::vector<std::shared_ptr<Request>>;
+    using Signal   = std::function<void()>;
 
     void RejectInvalidRequests(Requests& stop_reqs, Requests& infer_reqs);
 
-    void ProcessStopRequests(const Requests& requests);
+    [[nodiscard]] auto ProcessStopRequests(const Requests& requests) -> std::vector<Signal>;
 
     void ProcessInferRequests(const Requests& requests);
 
@@ -53,10 +56,11 @@ class LlamaBatch {
 
     void InitializeSampling();
     void InitializeGeneration();
-    bool Generate();
 
-    int  Finish();
-    void FinishRequest(int index, bool force_end);
+    [[nodiscard]] bool Generate();
+
+    [[nodiscard]] auto Finish() -> std::vector<Signal>;
+    void               FinishRequest(int index, bool force_end);
 
     void SetOutputTensors(int max_gen_step);
 
@@ -91,6 +95,8 @@ class LlamaBatch {
 
     void LoadRandomState(BatchState& state, int idx);
 
+    void BarrierSignalRequests(Barrier& barrier, const std::vector<Signal>& signals);
+
     // analogs to `std::copy_n`
     template<typename U>
     U* Copy(const U* src, size_t count, U* dst)
@@ -125,6 +131,14 @@ class LlamaBatch {
     T* decoder_input_buf_{};   // CTXDEC, GENERATE
     T* decoder_output_buf_{};  // CTXDEC, GENERATE
 
+    // temp buffers used for block->linear kv-cache conversion
+    T*     tmp_k_cache_buf_{};
+    T*     tmp_v_cache_buf_{};
+    void** tmp_k_ptrs_{};
+    void** tmp_v_ptrs_{};
+    void** h_tmp_k_ptrs_{};
+    void** h_tmp_v_ptrs_{};
+
     int*       input_ids_buf_{};       // input token ids + cache missed token ids, CTXDEC
     int*       input_length_buf_{};    // input + cache missed length, CTXDEC, GENERATE
     int*       history_length_buf_{};  // history length, CTXDEC
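
The `Copy` helper declared above ("analogs to `std::copy_n`") is used in `decodeContext` to pack ragged per-sequence inputs into one contiguous buffer, with the returned cursor doubling as the token counter. A host-only sketch of that usage pattern; the real member presumably wraps an asynchronous device copy.

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Host-only stand-in for the member template above.
    template<typename U>
    U* Copy(const U* src, size_t count, U* dst)
    {
        return std::copy_n(src, count, dst);  // returns the advanced destination cursor
    }

    int main()
    {
        const std::vector<int> a{1, 2, 3}, b{4, 5};
        std::vector<int>       packed(a.size() + b.size());

        int* cursor = packed.data();
        cursor      = Copy(a.data(), a.size(), cursor);
        cursor      = Copy(b.data(), b.size(), cursor);

        // The cursor difference is exactly the token count used by decodeContext.
        std::printf("packed %td tokens\n", cursor - packed.data());  // 5
    }
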
diff --git a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
index f985cff5a5..8fc54b9922 100644
--- a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
@@ -21,12 +21,12 @@
 
 #include "src/turbomind/models/llama/LlamaContextAttentionLayer.h"
 #include "src/turbomind/kernels/bert_preprocess_kernels.h"
+#include "src/turbomind/kernels/decoder_mha/kv_cache.h"
 #include "src/turbomind/kernels/unfused_attention_kernels.h"
 #include "src/turbomind/macro.h"
 #include "src/turbomind/models/llama/LlamaNcclGuard.h"
 #include "src/turbomind/models/llama/llama_kernels.h"
 #include "src/turbomind/models/llama/llama_utils.h"
-#include "src/turbomind/kernels/decoder_mha/kv_cache.h"
 #include "src/turbomind/utils/Tensor.h"
 #include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/logger.h"
@@ -188,6 +188,10 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
 
     auto k_cache_ptrs = output_tensors->getPtr<T*>("key_cache");
     auto v_cache_ptrs = output_tensors->getPtr<T*>("value_cache");
+
+    auto tmp_k_ptrs = output_tensors->getPtr<T*>("tmp_k");
+    auto tmp_v_ptrs = output_tensors->getPtr<T*>("tmp_v");
+
     //////////////////////////////////////////////////////////
     /// insert the k/v computed from inputs into k/v cache
     /// transpose kv -> kv cache
@@ -210,13 +214,25 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
                         quant_policy_,
                         weights->past_kv_scale.data(),
                         stream_);
+    sync_check_cuda_error();
 
-    ConvertBlocksToLinear(k_cache_ptrs, k_cache_, cu_block_counts, max_seq_len, kv_cache_block_len_, int dst_max_seq_len, int head_num, int head_dim, int batch_size, cudaStream_t st)
-
+    ConvertKvCacheBlocksToLinear((const T**)k_cache_ptrs,
+                                 (const T**)v_cache_ptrs,
+                                 tmp_k_ptrs,
+                                 tmp_v_ptrs,
+                                 cu_block_counts,
+                                 history_length,
+                                 kv_cache_block_len_,
+                                 max_seq_len,
+                                 local_kv_head_num_,
+                                 size_per_head_,
+                                 batch_size,
+                                 stream_);
     sync_check_cuda_error();
+
     if (use_fmha_) {
-        fusedMultiHeadAttention(k_cache_ptrs,
-                                v_cache_ptrs,
+        fusedMultiHeadAttention(tmp_k_ptrs,
+                                tmp_v_ptrs,
                                 layer_offset,
                                 attention_mask,
                                 cu_seqlens,
@@ -226,8 +242,8 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
                                 max_seq_len);
     }
     else {
-        unfusedMultiHeadAttention(k_cache_ptrs,
-                                  v_cache_ptrs,
+        unfusedMultiHeadAttention(tmp_k_ptrs,
+                                  tmp_v_ptrs,
                                   layer_offset,
                                   attention_mask,
                                   padding_offset,
diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc
index a1893736b6..cc8f64f7fa 100644
--- a/src/turbomind/models/llama/LlamaV2.cc
+++ b/src/turbomind/models/llama/LlamaV2.cc
@@ -78,6 +78,7 @@ LlamaV2<T>::LlamaV2(size_t                       head_num,
     end_id_(end_id),
     hidden_units_(head_num * size_per_head),
     local_head_num_(head_num / tensor_para.world_size_),
+    local_kv_head_num_(head_num / tensor_para.world_size_),
     weights_(weights),
     tensor_para_(tensor_para),
     stream_(stream),
@@ -218,6 +219,8 @@ template<typename T>
 void LlamaV2<T>::contextDecode(T*         deocder_output,
                                uintptr_t* k_cache_ptr,
                                uintptr_t* v_cache_ptr,
+                               void**     tmp_k_ptrs,
+                               void**     tmp_v_ptrs,
                                T*         context_decoder_input_buf,
                                T*         context_decoder_output_buf,
                                const int* input_ids,
@@ -273,6 +276,8 @@ void LlamaV2<T>::contextDecode(T*         deocder_output,
         {"decoder_output", {MEMORY_GPU, dtype, {bsz, max_input_len, hidden_units_}, context_decoder_output_buf}},
         {"key_cache", {MEMORY_GPU, TYPE_UINT64, {bsz}, k_cache_ptr}},
         {"value_cache", {MEMORY_GPU, TYPE_UINT64, {bsz}, v_cache_ptr}},
+        {"tmp_k_ptrs", {MEMORY_GPU, TYPE_UINT64, {bsz}, tmp_k_ptrs}},
+        {"tmp_v_ptrs", {MEMORY_GPU, TYPE_UINT64, {bsz}, tmp_v_ptrs}},
         {"last_token_hidden_units", {MEMORY_GPU, dtype, {bsz, hidden_units_}, deocder_output}}};
 
     context_decoder_->forward(&decoder_output_tensors, &decoder_input_tensors, &weights_->decoder_layer_weights);
diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h
index 31e4bf42d7..09f49ceaaf 100644
--- a/src/turbomind/models/llama/LlamaV2.h
+++ b/src/turbomind/models/llama/LlamaV2.h
@@ -102,8 +102,10 @@ class LlamaV2 {
     void embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step);
 
     void contextDecode(T*         deocder_output,
-                       uintptr_t* k_cache_ptr,
-                       uintptr_t* v_cache_ptr,
+                       uintptr_t* k_block_ptrs,
+                       uintptr_t* v_block_ptrs,
+                       void**     k_tmp_ptrs,
+                       void**     v_tmp_ptrs,
                        T*         context_decoder_input_buf,
                        T*         context_decoder_output_buf,
                        const int* input_ids,
@@ -175,6 +177,7 @@ class LlamaV2 {
     const size_t hidden_units_;
 
     const size_t local_head_num_;
+    const size_t local_kv_head_num_;
     NcclParam    tensor_para_;
 
     cudaStream_t     stream_;
diff --git a/src/turbomind/models/llama/SequenceManager.cc b/src/turbomind/models/llama/SequenceManager.cc
index 1c05340048..808dcfd754 100644
--- a/src/turbomind/models/llama/SequenceManager.cc
+++ b/src/turbomind/models/llama/SequenceManager.cc
@@ -1,6 +1,9 @@
 #include "src/turbomind/models/llama/SequenceManager.h"
+#include "src/turbomind/utils/allocator.h"
+#include "src/turbomind/utils/dbg.h"
 #include "src/turbomind/utils/logger.h"
 #include <ctime>
+#include <stdexcept>
 
 namespace turbomind {
 
@@ -81,6 +84,9 @@ bool SequenceManager::Erase(uint64_t id)
         }
         sequences_.erase(it);
     }
+    else {
+        throw std::out_of_range(std::to_string(id));
+    }
 
     return false;
 }
@@ -103,6 +109,7 @@ struct Schedule {
 
     int allocate;
     int evict;
+    int preempt;
 
     std::vector<int> victims;
 
@@ -112,6 +119,25 @@ struct Schedule {
     std::vector<int> inactive;
 };
 
+template<typename T>
+std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
+{
+    os << "[";
+    for (int i = 0; i < v.size(); ++i) {
+        os << (i ? "," : "") << v[i];
+    }
+    os << "]";
+    return os;
+}
+
+std::ostream& operator<<(std::ostream& os, const Schedule& s)
+{
+    os << "Schedule { free=" << s.free << ", cached=" << s.cached << ", allocate=" << s.allocate
+       << ", evict=" << s.evict << ", preempt=" << s.preempt << ", active=" << s.active << ", victims=" << s.victims
+       << ", block_counts=" << s.block_counts << ", inactive=" << s.inactive << " }";
+    return os;
+}
+
 class Simulator {
 public:
     explicit Simulator(const std::vector<const Sequence*>& seqs,
@@ -119,6 +145,7 @@ class Simulator {
                        std::vector<int>&                   ref_count):
         seqs_(seqs), idxs_(idxs), ref_count_(ref_count)
     {
+        dbg(seqs.size());
         released_.resize(seqs.size());
         ptr_ = released_.size();
     }
@@ -189,10 +216,15 @@ struct Transaction {
     void Commit()
     {
         sched_.free -= allocate_;
-        sched_.cached += preempt_ - evict_;
+        FT_CHECK(sched_.free >= 0);
+
+        sched_.cached += preempt_;
+        sched_.cached -= evict_;
+        FT_CHECK(sched_.cached >= 0);
 
         sched_.allocate += allocate_;
         sched_.evict += evict_;
+        sched_.preempt += preempt_;
 
         sched_.victims.insert(sched_.victims.end(), victims_.begin(), victims_.end());
 
@@ -201,24 +233,32 @@ struct Transaction {
     }
 };
 
+std::ostream& operator<<(std::ostream& os, const Transaction& trans)
+{
+    os << "Transaction { index=" << trans.index_ << ", block_count=" << trans.block_count_
+       << ", allocate=" << trans.allocate_ << ", evict=" << trans.evict_ << ", preempt=" << trans.preempt_
+       << ", victims=" << trans.victims_ << " }";
+    return os;
+}
+
 }  // namespace
 
 std::ostream& operator<<(std::ostream& os, const Sequence& seq)
 {
-    os << "Sequence[id=" << seq.id << ",status=" << seq.status << ",size(blocks)=" << seq.blocks.size()
-       << ",cache_len=" << seq.cache_len << ",size(random_state)=" << seq.random_state.size() << "]";
+    os << "Sequence { id=" << seq.id << ", status=" << seq.status << ", size(blocks)=" << seq.blocks.size()
+       << ", cache_len=" << seq.cache_len << ", size(random_state)=" << seq.random_state.size() << " }";
     return os;
 }
 
-bool SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
+auto SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
                                   const std::vector<int>&             context_lengths,
                                   const std::vector<uint64_t>&        priorities,
-                                  int                                 step_length)
+                                  int                                 step_length) -> Outcome
 {
     ////////////////////////////////////////////////////////////////////////////////
     /// Schedule the assignment of blocks to sequences
-
-    auto seqs = const_cast<Sequence* const*>(sequences.data());
+    auto    seqs = const_cast<Sequence* const*>(sequences.data());
+    Outcome outcome{};
 
     // check validity of cached blocks (blocks of active & locked seqs are always valid)
     if (need_verification_) {
@@ -236,13 +276,15 @@ bool SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
     for (int i = 0; i < sequences.size(); ++i) {
         int seq_len = context_lengths[i] + step_length;
         int count   = (seq_len + block_len_ - 1) / block_len_ - static_cast<int>(seqs[i]->blocks.size());
-        required.push_back(std::max(0, count));
-        total_required += required.back();
+        required[i] = std::max(0, count);
+        total_required += required[i];
     }
 
+    dbg(required);
+
     // no new blocks required, exit early
     if (total_required == 0) {
-        return false;
+        return outcome;
     }
 
     /// TODO: more early exit heuristics
@@ -259,7 +301,8 @@ bool SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
 
     Simulator simulator(sequences, idxs, snapshot.ref_count);
 
-    bool modified = false;
+    std::vector<int> active(idxs.size());
+    std::vector<int> victim(idxs.size());
 
     for (int i = 0, j = idxs.size(); i < j; ++i) {
         const int idx = idxs[i];
@@ -275,7 +318,7 @@ bool SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
         }
         // evict cached blocks
         if (block_count) {
-            block_count -= trans.Evict(std::min(block_count, schedule.free));
+            block_count -= trans.Evict(std::min(block_count, schedule.cached));
         }
 
         for (int v = j - 1; block_count && v > i; --v) {
@@ -283,6 +326,7 @@ bool SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
                 continue;
             }
             int preempt = trans.Preempt(v, idxs[v]);
+            dbg(preempt);
             // Commit only when preemption actually frees enough blocks for the sequence to run
             if (block_count <= preempt) {
                 // preempted blocks are in cached state
@@ -292,31 +336,34 @@ bool SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
             }
         }
 
+        dbg(block_count, trans);
+
         if (block_count == 0) {
             trans.Commit();
+            active[idx] = 1;
             if (seq.status != Sequence::kActive) {
-                modified = true;
+                ++outcome.swap_in;
             }
         }
-        else {
-            // failed to collect enough block for the sequence, transaction aborted. Active sequence will be kept
-            // locked if not preempted by seq with higher priority
-            schedule.inactive.push_back(idx);
-            if (seq.status == Sequence::kActive) {
-                modified = true;
+    }
+
+    for (const auto& i : idxs) {
+        if (!active[i]) {
+            schedule.inactive.push_back(i);
+            if (seqs[i]->status == Sequence::kActive) {
+                ++outcome.swap_out;
             }
         }
     }
 
-    // Verify the schedule
-    FT_CHECK(schedule.allocate <= snapshot.free);
-    FT_CHECK(schedule.evict <= snapshot.cached);
-    // FT_CHECK(schedule.allocate + schedule.evict + schedule.preempt == total_block_count);
+    dbg(schedule);
 
     ////////////////////////////////////////////////////////////////////////////////
     /// Schedule is ready, time to execute it. (locked -> cached -> free -> locked)
     schedule.allocate += schedule.evict;
 
+    outcome.allocation = schedule.allocate;
+
     // release preempted blocks -> cached
     {
         std::vector<const Block*> blocks;
@@ -352,8 +399,6 @@ bool SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
         first = last;
     }
 
-    block_manager_->Touch(blocks);
-
     for (const auto& idx : schedule.inactive) {
         if (seqs[idx]->status == Sequence::kActive) {
             seqs[idx]->status = Sequence::kLocked;
@@ -364,7 +409,7 @@ bool SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
         seqs[idx]->status = Sequence::kCached;
     }
 
-    return modified;
+    return outcome;
 }
 
 }  // namespace turbomind
\ No newline at end of file
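
The block requirement computed near the top of `Materialize()` is plain ceiling arithmetic: a sequence needs `ceil((context_len + step_length) / block_len)` blocks in total, minus the blocks it already holds. A small sketch, assuming a 128-token block length as in the cache-manager test added below:

    #include <algorithm>
    #include <cstdio>

    int RequiredBlocks(int context_len, int step_length, int block_len, int held_blocks)
    {
        const int seq_len = context_len + step_length;
        const int needed  = (seq_len + block_len - 1) / block_len;  // ceil(seq_len / block_len)
        return std::max(0, needed - held_blocks);
    }

    int main()
    {
        // 128 cached tokens + 1 decode step with 128-token blocks -> 2 blocks in total
        std::printf("%d\n", RequiredBlocks(128, 1, 128, /*held_blocks=*/0));  // 2
        std::printf("%d\n", RequiredBlocks(128, 1, 128, /*held_blocks=*/2));  // 0
    }
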
diff --git a/src/turbomind/models/llama/SequenceManager.h b/src/turbomind/models/llama/SequenceManager.h
index 331d79990c..1b3b680784 100644
--- a/src/turbomind/models/llama/SequenceManager.h
+++ b/src/turbomind/models/llama/SequenceManager.h
@@ -45,6 +45,9 @@ class SequenceManager {
                              int         rank,
                              IAllocator* allocator);
 
+    SequenceManager(const SequenceManager&)     = delete;
+    SequenceManager(SequenceManager&&) noexcept = default;
+
     const Sequence* Create(uint64_t id);
 
     const Sequence* Fetch(uint64_t id);
@@ -55,10 +58,16 @@ class SequenceManager {
 
     bool Contains(uint64_t id);
 
-    bool Materialize(const std::vector<const Sequence*>& sequences,
-                     const std::vector<int>&             context_lengths,
-                     const std::vector<uint64_t>&        priorities,
-                     int                                 step_length);
+    struct Outcome {
+        int allocation;
+        int swap_in;
+        int swap_out;
+    };
+
+    Outcome Materialize(const std::vector<const Sequence*>& sequences,
+                        const std::vector<int>&             context_lengths,
+                        const std::vector<uint64_t>&        priorities,
+                        int                                 step_length);
 
     void* OffsetKey(void* block_ptr)
     {
@@ -75,6 +84,12 @@ class SequenceManager {
         return block_manager_->max_block_count();
     }
 
+    friend std::ostream& operator<<(std::ostream& os, const Outcome& oc)
+    {
+        os << "allocation: " << oc.allocation << ", swap-in: " << oc.swap_in << ", swap-out: " << oc.swap_out;
+        return os;
+    }
+
 private:
     void VerifyBlocks(Sequence& seq);
 
@@ -93,12 +108,4 @@ class SequenceManager {
     std::vector<const Block*> released_;
 };
 
-// cu_block_cnts(seq_idx) -> block_idx_offset
-// block_idxs(block_idx_offset) -> (seq_idx, seq_offset)
-
-
-inline void func(const std::vector<int>& block_cnts, std::vector<int>& cu_block_cnts, std::vector<int>& inv_block_idxs) {
-
-}
-
 }  // namespace turbomind
\ No newline at end of file
diff --git a/src/turbomind/models/llama/test_cache_manager.cc b/src/turbomind/models/llama/test_cache_manager.cc
new file mode 100644
index 0000000000..c306b1e7cc
--- /dev/null
+++ b/src/turbomind/models/llama/test_cache_manager.cc
@@ -0,0 +1,94 @@
+#include "BlockManager.h"
+#include "SequenceManager.h"
+
+#include "src/turbomind/utils/allocator.h"
+
+#include "src/turbomind/utils/dbg.h"
+#include <catch2/catch_test_macros.hpp>
+#include <iterator>
+
+using namespace turbomind;
+
+std::ostream& operator<<(std::ostream& os, const Block* b)
+{
+    os << "(" << b->id << "," << b->timestamp << ")";
+    return os;
+}
+
+TEST_CASE("BlockManager")
+{
+    Allocator<AllocatorType::CUDA> allocator(0);
+
+    BlockManager m(1024, 32, 8, &allocator);
+    REQUIRE(m.max_block_count() == 32);
+    REQUIRE(m.free_count() == 32);
+
+    auto blocks1 = m.Allocate(10);
+
+    dbg(blocks1);
+
+    REQUIRE(blocks1.size() == 10);
+    REQUIRE(m.active_count() == blocks1.size());
+    REQUIRE(m.free_count() == 22);
+
+    auto blocks2 = m.Allocate(6);
+    REQUIRE(blocks2.size() == 6);
+    REQUIRE(m.active_count() == blocks1.size() + blocks2.size());
+    REQUIRE(m.free_count() == 16);
+
+    auto blocks3 = m.Allocate(16);
+    REQUIRE(blocks3.size() == 16);
+    REQUIRE(m.active_count() == 32);
+    REQUIRE(m.free_count() == 0);
+
+    std::copy(blocks3.begin(), blocks3.end(), std::back_inserter(blocks1));
+    std::copy(blocks2.begin(), blocks2.end(), std::back_inserter(blocks1));
+
+    REQUIRE(m.Release(blocks1) == 32);
+    REQUIRE(m.active_count() == 0);
+    REQUIRE(m.free_count() == 0);
+    REQUIRE(m.cached_count() == 32);
+
+    m.Evict(16);
+    REQUIRE(m.active_count() == 0);
+    REQUIRE(m.free_count() == 16);
+    REQUIRE(m.cached_count() == 16);
+
+    auto blocks4 = m.Allocate(14);
+    REQUIRE(m.active_count() == 14);
+    REQUIRE(m.free_count() == 2);
+    REQUIRE(m.cached_count() == 16);
+}
+
+TEST_CASE("SequenceManager")
+{
+    Allocator<AllocatorType::CUDA> allocator(0);
+
+    SequenceManager manager(32, 32, 128, 128, 20, 4, 16, 0, &allocator);
+
+    REQUIRE(manager.max_block_count() == 20);
+    REQUIRE(manager.Contains(1) == false);
+
+    auto s1 = manager.Create(1);
+    dbg(*s1);
+    REQUIRE(manager.Contains(1) == true);
+
+    manager.Erase(1);
+    REQUIRE(manager.Contains(1) == false);
+
+    s1 = manager.Create(1);
+    REQUIRE(manager.Contains(1) == true);
+
+    auto outcome = manager.Materialize({s1}, {128}, {100}, 1);
+    dbg(s1->blocks);
+    REQUIRE(s1->blocks.size() == 2);
+
+    auto s2 = manager.Create(2);
+    REQUIRE(manager.Contains(2));
+
+    outcome = manager.Materialize({s1, s2}, {128, 2560-1}, {2, 1}, 1);
+    dbg(outcome);
+
+    // outcome = manager.Materialize({s1, s2}, {128, 12800}, {1, 2}, 1);
+    // dbg(outcome);
+}
\ No newline at end of file
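
The `BlockManager` test above exercises a simple conservation law: blocks only move between the free, active and cached pools, so the three counts always sum to `max_block_count`. A toy mirror of those transitions (not turbomind's `BlockManager`) replaying the same sequence of operations:

    #include <cassert>

    struct Counts {  // toy mirror of the three pools
        int  free, active, cached;
        void Allocate(int n) { free -= n; active += n; }    // free   -> active
        void Release(int n)  { active -= n; cached += n; }  // active -> cached
        void Evict(int n)    { cached -= n; free += n; }    // cached -> free
        int  Total() const   { return free + active + cached; }
    };

    int main()
    {
        Counts c{32, 0, 0};
        c.Allocate(10); c.Allocate(6); c.Allocate(16);
        assert(c.active == 32 && c.free == 0);
        c.Release(32);
        c.Evict(16);
        c.Allocate(14);
        assert(c.free == 2 && c.active == 14 && c.cached == 16);
        assert(c.Total() == 32);  // the invariant holds after every transition
        return 0;
    }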

From 79dab4ca4cd38e091fe0b09145dfcfcc7c7209e8 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Thu, 28 Sep 2023 08:22:56 +0000
Subject: [PATCH 06/56] update

---
 CMakeLists.txt                                |  16 ++
 examples/cpp/llama/llama_triton_example.cc    |   4 +-
 .../decoder_multihead_attention.cu            |   3 +-
 .../decoder_multihead_attention_params.h      |   8 +-
 .../decoder_multihead_attention_template.h    |  22 ++-
 src/turbomind/kernels/decoder_mha/iterator.h  |  17 +-
 src/turbomind/kernels/decoder_mha/kv_cache.cu |  37 ++++-
 src/turbomind/kernels/decoder_mha/kv_cache.h  |   3 +
 .../test_decoder_multihead_attention.cu       |  78 ++++++---
 .../kernels/decoder_mha/test_utils.cu         |   6 +-
 src/turbomind/models/llama/BlockManager.cc    |  29 ++--
 src/turbomind/models/llama/BlockManager.h     |   8 +-
 src/turbomind/models/llama/LlamaBatch.cc      | 148 ++++++++++++------
 src/turbomind/models/llama/LlamaBatch.h       |   5 +-
 .../llama/LlamaContextAttentionLayer.cc       |  46 ++++--
 .../models/llama/LlamaContextDecoder.cc       |  26 +--
 .../models/llama/LlamaContextDecoder.h        |  17 +-
 src/turbomind/models/llama/LlamaDecoder.cc    |   1 -
 .../llama/LlamaDecoderSelfAttentionLayer.cc   | 132 ++++++++++------
 .../llama/LlamaDecoderSelfAttentionLayer.h    |   2 +
 src/turbomind/models/llama/LlamaV2.cc         |  14 +-
 src/turbomind/models/llama/LlamaWeight.cc     |  10 +-
 src/turbomind/models/llama/Request.h          |   7 +-
 src/turbomind/models/llama/SequenceManager.cc | 115 +++++++-------
 src/turbomind/models/llama/SequenceManager.h  |  25 ++-
 src/turbomind/models/llama/llama_kernels.cu   |   7 +-
 .../models/llama/test_cache_manager.cc        |  26 ++-
 .../triton_backend/llama/LlamaTritonModel.cc  |   6 +-
 .../llama/LlamaTritonModelInstance.h          |   4 +-
 .../transformer_triton_backend.hpp            |   2 +
 src/turbomind/utils/allocator.h               |  33 ++--
 src/turbomind/utils/cuda_utils.h              |   2 +-
 32 files changed, 544 insertions(+), 315 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9b1979ffce..80c767d1a2 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -61,6 +61,22 @@ option(SPARSITY_SUPPORT "Build project with Ampere sparsity feature support" OFF
 
 option(BUILD_FAST_MATH "Build in fast math mode" ON)
 
+# the environment variable 
+#   ASAN_OPTIONS=protect_shadow_gap=0,intercept_tls_get_addr=0
+# must be set at runtime
+# https://github.com/google/sanitizers/issues/1322
+if (LMDEPLOY_ASAN_ENABLE)
+    add_compile_options($<$<COMPILE_LANGUAGE:CXX>:-fsanitize=address>)
+    add_link_options(-fsanitize=address)
+endif ()
+
+# notice that ubsan has linker issues for ubuntu < 18.04, see
+# https://stackoverflow.com/questions/50024731/ld-unrecognized-option-push-state-no-as-needed
+if (LMDEPLOY_UBSAN_ENABLE)
+    add_compile_options($<$<COMPILE_LANGUAGE:CXX>:-fsanitize=undefined>)
+    add_link_options(-fsanitize=undefined)
+endif ()
+
 if(BUILD_MULTI_GPU)
   message(STATUS "Add DBUILD_MULTI_GPU, requires MPI and NCCL")
   add_definitions("-DBUILD_MULTI_GPU")
diff --git a/examples/cpp/llama/llama_triton_example.cc b/examples/cpp/llama/llama_triton_example.cc
index 2f50ae19a0..07e88a508d 100644
--- a/examples/cpp/llama/llama_triton_example.cc
+++ b/examples/cpp/llama/llama_triton_example.cc
@@ -80,7 +80,9 @@ broadCastRequest(const std::vector<int>& v_start_ids,
     if (node_id == 0) {
         memcpy(v_input_ids.data(), v_start_ids.data(), size_1 * sizeof(int));
         memcpy(v_input_lengths.data(), v_start_lengths.data(), size_2 * sizeof(int));
-        memcpy(v_input_bad_words.data(), v_bad_words.data(), size_bad_words * sizeof(int));
+        if (!v_input_bad_words.empty()) {
+            memcpy(v_input_bad_words.data(), v_bad_words.data(), size_bad_words * sizeof(int));
+        }
     }
     if (kUSE_MPI) {
         ft::mpi::barrier();
diff --git a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention.cu
index cc6edaf230..113780d65b 100644
--- a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention.cu
@@ -34,7 +34,7 @@ void LaunchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& p
     dim3 grid(params.num_kv_heads, params.batch_size);
 
     const size_t kDynamicSmemSize = MHAType::GetDynamicSmemSize(0);
-    std::cout << "dynamic shared memory size: " << kDynamicSmemSize << "\n";
+    // std::cout << "dynamic shared memory size: " << kDynamicSmemSize << "\n";
 
     cudaFuncSetAttribute(
         decoder_multihead_attention<MHAType>, cudaFuncAttributeMaxDynamicSharedMemorySize, kDynamicSmemSize);
@@ -43,5 +43,6 @@ void LaunchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& p
 }
 
 template void LaunchDecoderMultiheadAttention<half, 128>(const DecoderMultiHeadAttentionParams<half>& params);
+template void LaunchDecoderMultiheadAttention<float, 128>(const DecoderMultiHeadAttentionParams<float>& params);
 
 }  // namespace turbomind
diff --git a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_params.h b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_params.h
index 2cb2184643..b055961289 100644
--- a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_params.h
+++ b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_params.h
@@ -17,20 +17,20 @@ struct DecoderMultiHeadAttentionParams {
     T* v_bias;
 
     // sequence-level buffers
-    int*  per_sample_length;
-    bool* finished;
+    const int*  per_sample_length;
+    const bool* finished;
 
     // kv cache
     void** per_sample_k_cache;  // [H, S, D]
     void** per_sample_v_cache;  // [H, S, D]
-    size_t per_sample_kv_cache_offset;
+    size_t layer_offset;
 
     /// cache layout M,[N,H,x,D]
     /// S: [s0/x, s1/x, s2/x, ..., sn-1/x], si <- block
     /// 1. [L,sum(S),H,x,D]
     void** k_cache_block_ptrs;  // X,[H,x,D]
     void** v_cache_block_ptrs;  // X,[H,x,D]
-    int*   cu_ctxlens;          // [B+1]
+    int*   cu_block_cnts;       // [B+1]
     int    kv_cache_block_size;
 
     // batch-level params
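
Given the paged layout documented above, resolving one (batch, head, timestep) position takes a prefix-sum lookup into the flattened block list plus a division by the block size. A host-side sketch of that address computation, with illustrative names and toy sizes:

    #include <cstddef>
    #include <cstdio>

    template<typename T>
    T* KvBlockAddress(void**     block_ptrs,     // blocks of all sequences, flattened
                      const int* cu_block_cnts,  // [B+1] exclusive prefix sum of block counts
                      size_t     layer_offset,   // where this layer starts inside a block
                      int batch_idx, int head_idx, int timestep, int block_size, int head_dim)
    {
        const int block_index  = timestep / block_size;  // which block of this sequence
        const int block_offset = timestep % block_size;  // row inside that block
        T* block = (T*)block_ptrs[cu_block_cnts[batch_idx] + block_index];
        return block + layer_offset + (size_t)head_idx * block_size * head_dim
               + (size_t)block_offset * head_dim;
    }

    int main()
    {
        // toy sizes: 2 sequences holding 3 + 2 blocks, 2 heads, block_size 8, head_dim 4
        static float storage[5][2 * 8 * 4] = {};
        void*        block_ptrs[5];
        for (int i = 0; i < 5; ++i) {
            block_ptrs[i] = storage[i];
        }
        const int cu_block_cnts[] = {0, 3, 5};

        // timestep 10 of sequence 1 lands in its 2nd block; head 1, row 10 % 8 = 2
        float* p = KvBlockAddress<float>(block_ptrs, cu_block_cnts, /*layer_offset=*/0,
                                         /*batch_idx=*/1, /*head_idx=*/1, /*timestep=*/10,
                                         /*block_size=*/8, /*head_dim=*/4);
        std::printf("offset within block: %td\n", p - (float*)block_ptrs[4]);  // 1*8*4 + 2*4 = 40
    }
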
diff --git a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_template.h b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_template.h
index 3dabd9042a..a13925818b 100644
--- a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_template.h
+++ b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_template.h
@@ -113,19 +113,13 @@ struct DecoderMultiHeadAttentionKernel {
         timestep_ = params_.per_sample_length[batch_idx_];
 
         if constexpr (kUseBlockIter) {
-            k_cache_ptrs_ = params_.k_cache_block_ptrs + params_.cu_ctxlens[batch_idx_];
-            v_cache_ptrs_ = params_.v_cache_block_ptrs + params_.cu_ctxlens[batch_idx_];
-            // if (thread0()) {
-            //     printf("%d %p %p\n",
-            //            params_.cu_ctxlens[batch_idx_],
-            //            params_.k_cache_block_ptrs,
-            //            params_.v_cache_block_ptrs);
-            // }
+            k_cache_ptrs_ = params_.k_cache_block_ptrs + params_.cu_block_cnts[batch_idx_];
+            v_cache_ptrs_ = params_.v_cache_block_ptrs + params_.cu_block_cnts[batch_idx_];
         }
         else {
-            k_cache_ = (T*)params_.per_sample_k_cache[batch_idx_] + params.per_sample_kv_cache_offset
+            k_cache_ = (T*)params_.per_sample_k_cache[batch_idx_] + params.layer_offset
                        + head_idx_ * params_.max_seq_len * params_.size_per_head;
-            v_cache_ = (T*)params_.per_sample_v_cache[batch_idx_] + params.per_sample_kv_cache_offset
+            v_cache_ = (T*)params_.per_sample_v_cache[batch_idx_] + params.layer_offset
                        + head_idx_ * params_.max_seq_len * params_.size_per_head;
         }
     }
@@ -235,8 +229,10 @@ struct DecoderMultiHeadAttentionKernel {
                 // if (thread0()) {
                 //     printf("%d %d %p %p\n", block_index, block_offset, k_cache_ptrs_, v_cache_ptrs_);
                 // }
-                k_cache_ = (T*)k_cache_ptrs_[block_index] + head_idx_ * params_.kv_cache_block_size * kHeadDim;
-                v_cache_ = (T*)v_cache_ptrs_[block_index] + head_idx_ * params_.kv_cache_block_size * kHeadDim;
+                k_cache_ = (T*)k_cache_ptrs_[block_index] + params_.layer_offset
+                           + head_idx_ * params_.kv_cache_block_size * kHeadDim;
+                v_cache_ = (T*)v_cache_ptrs_[block_index] + params_.layer_offset
+                           + head_idx_ * params_.kv_cache_block_size * kHeadDim;
                 Store(&k_cache_[block_offset * kHeadDim + offset.x], frag_K);
                 Store(&v_cache_[block_offset * kHeadDim + offset.x], frag_V);
             }
@@ -302,6 +298,7 @@ struct DecoderMultiHeadAttentionKernel {
         if constexpr (kUseBlockIter) {
             iter_K = {k_cache_ptrs_,
                       params_.kv_cache_block_size,
+                      params_.layer_offset,
                       head_idx_,
                       smem_Kv_,
                       step,
@@ -490,6 +487,7 @@ struct DecoderMultiHeadAttentionKernel {
         if constexpr (kUseBlockIter) {
             iter_V = {v_cache_ptrs_,
                       params_.kv_cache_block_size,
+                      params_.layer_offset,
                       head_idx_,
                       smem_Kv_,
                       step,
diff --git a/src/turbomind/kernels/decoder_mha/iterator.h b/src/turbomind/kernels/decoder_mha/iterator.h
index 532cc23a22..08f939827c 100644
--- a/src/turbomind/kernels/decoder_mha/iterator.h
+++ b/src/turbomind/kernels/decoder_mha/iterator.h
@@ -66,6 +66,7 @@ struct Iterator {
 
     int block_size_;
     int block_k_;
+    int layer_offset_;
 
     int head_idx_;
 
@@ -102,18 +103,26 @@ struct Iterator {
         is_valid_s_ = offset_s_ < seq_len;
     }
 
-    __device__ Iterator(
-        const void** block_ptrs, int block_size, int head_idx, T* smem, int step, int seqlen, int warp_id, int lane_id)
+    __device__ Iterator(const void** block_ptrs,
+                        int          block_size,
+                        int          layer_offset,
+                        int          head_idx,
+                        T*           smem,
+                        int          step,
+                        int          seqlen,
+                        int          warp_id,
+                        int          lane_id)
     {
         // src_  = src;
         int block_index = step / block_size;
         block_size_     = block_size;
         block_k_        = (block_index + 1) * block_size - step;  // offset to next block
+        layer_offset_   = layer_offset;
         head_idx_       = head_idx;
 
         block_iterator_ = BlockIterator(block_ptrs + block_index);
 
-        src_ = (const T*)block_iterator_.Next() + head_idx_ * block_size_ * ThreadMap::kC;
+        src_ = (const T*)block_iterator_.Next() + layer_offset_ + head_idx_ * block_size_ * ThreadMap::kC;
 
         smem_ = smem;
 
@@ -186,7 +195,7 @@ struct Iterator {
             if (is_valid_s_) {
                 block_k_ -= ThreadMap::kS;
                 if (block_k_ == 0) {
-                    src_        = (const T*)block_iterator_.Next() + head_idx_ * block_size_ * ThreadMap::kC;
+                    src_ = (const T*)block_iterator_.Next() + layer_offset_ + head_idx_ * block_size_ * ThreadMap::kC;
                     block_k_    = block_size_;
                     src_offset_ = init_offset_;
                 }
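
The iterator change above threads `layer_offset` through the boundary-crossing logic it already had: `block_k_` counts down to the next block edge and, when it reaches zero, the next block pointer is fetched and the per-block offset re-applied. A host-side simulation of that countdown with toy sizes (names illustrative):

    #include <cstdio>

    int main()
    {
        const int block_size = 128, tile = 32, seq_len = 300;
        const int step        = 64;                                      // starting timestep
        int       block_index = step / block_size;
        int       block_k     = (block_index + 1) * block_size - step;   // distance to next block edge

        for (int s = step; s < seq_len; s += tile) {
            std::printf("tile [%d, %d) reads block %d\n", s, s + tile, block_index);
            block_k -= tile;
            if (block_k == 0) {  // crossed a block boundary: fetch the next block pointer
                ++block_index;
                block_k = block_size;
            }
        }
    }
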
diff --git a/src/turbomind/kernels/decoder_mha/kv_cache.cu b/src/turbomind/kernels/decoder_mha/kv_cache.cu
index 7ebe25271b..18a0754260 100644
--- a/src/turbomind/kernels/decoder_mha/kv_cache.cu
+++ b/src/turbomind/kernels/decoder_mha/kv_cache.cu
@@ -1,5 +1,6 @@
 #include "../gemm_s_f16/common.h"
 // #include "cute/tensor.hpp"
+#include "src/turbomind/utils/dbg.h"
 #include <cuda_fp16.h>
 #include <type_traits>
 
@@ -13,6 +14,8 @@ __inline__ __device__ void ConvertBlockSize(const T** __restrict__ src_block_ptr
                                             const int* __restrict__ src_cu_block_cnts,
                                             const int* __restrict__ dst_cu_block_cnts,
                                             const int* __restrict__ seq_lens,
+                                            int         src_offset,
+                                            int         dst_offset,
                                             SrcBlockLen src_block_len,
                                             DstBlockLen dst_block_len,
                                             HeadDim     head_dim)
@@ -32,11 +35,11 @@ __inline__ __device__ void ConvertBlockSize(const T** __restrict__ src_block_ptr
 
     // compute indices into src
     int src_block_index  = si / src_block_len + src_cu_block_cnts[bi];
-    int src_block_offset = hi * src_block_len * head_dim + si % src_block_len * head_dim + di;
+    int src_block_offset = src_offset + hi * src_block_len * head_dim + si % src_block_len * head_dim + di;
 
     // compute indices into dst
     int dst_block_index  = si / dst_block_len + dst_cu_block_cnts[bi];
-    int dst_block_offset = hi * dst_block_len * head_dim + si % dst_block_len * head_dim + di;
+    int dst_block_offset = dst_offset + hi * dst_block_len * head_dim + si % dst_block_len * head_dim + di;
 
     // printf("%d %d\n", src_block_index, dst_block_index);
 
@@ -48,16 +51,12 @@ __inline__ __device__ void ConvertBlockSize(const T** __restrict__ src_block_ptr
     *reinterpret_cast<uint4*>(dst_block + dst_block_offset) = data;
 }
 
-// static inline size_t get_helper_smem_size(int batch_size)
-// {
-//     return (sizeof(void*) + sizeof(int)) * batch_size;
-// }
-
 template<typename T>
 __global__ void LinearToBlocksKernel(const T*   src,
                                      T**        dst_block_ptrs,
                                      const int* dst_cu_block_cnts,
                                      const int* seq_lens,
+                                     int        dst_offset,
                                      int        src_block_len,
                                      int        dst_block_len,
                                      int        head_num,
@@ -81,6 +80,8 @@ __global__ void LinearToBlocksKernel(const T*   src,
                      src_cu_block_cnts,
                      dst_cu_block_cnts,
                      seq_lens,
+                     0,
+                     dst_offset,
                      src_block_len,
                      dst_block_len,
                      head_dim);
@@ -91,6 +92,7 @@ void ConvertLinearToBlocks(const T*     src,
                            T**          dst_block_ptrs,
                            const int*   dst_cu_block_cnts,
                            const int*   seq_lens,
+                           int          dst_offset,
                            int          src_max_len,
                            int          dst_block_len,
                            int          head_num,
@@ -110,6 +112,7 @@ void ConvertLinearToBlocks(const T*     src,
                                                                dst_block_ptrs,
                                                                dst_cu_block_cnts,
                                                                seq_lens,
+                                                               dst_offset,
                                                                src_max_len,
                                                                dst_block_len,
                                                                head_num,
@@ -130,6 +133,7 @@ template void ConvertLinearToBlocks(const half*  src,
                                     half**       dst_block_ptrs,
                                     const int*   dst_cu_block_cnts,
                                     const int*   seq_lens,
+                                    int          dst_offset,
                                     int          src_seq_len,
                                     int          dst_block_len,
                                     int          head_num,
@@ -142,6 +146,7 @@ __global__ void BlocksToLinearKernel(const T**  src_block_ptrs,
                                      T*         dst,
                                      const int* src_cu_block_cnts,
                                      const int* seq_lens,
+                                     int        src_offset,
                                      int        src_block_len,
                                      int        dst_block_len,
                                      int        head_num,
@@ -165,6 +170,8 @@ __global__ void BlocksToLinearKernel(const T**  src_block_ptrs,
                      src_cu_block_cnts,
                      dst_cu_block_cnts,
                      seq_lens,
+                     src_offset,
+                     0,
                      src_block_len,
                      dst_block_len,
                      head_dim);
@@ -175,6 +182,7 @@ void ConvertBlocksToLinear(const T**    src_block_ptrs,
                            T*           dst,
                            const int*   src_cu_block_cnts,
                            const int*   seq_lens,
+                           int          src_offset,
                            int          src_block_len,
                            int          dst_max_len,
                            int          head_num,
@@ -194,7 +202,8 @@ void ConvertBlocksToLinear(const T**    src_block_ptrs,
                                                                dst,
                                                                src_cu_block_cnts,
                                                                seq_lens,
-                                                               std::integral_constant<int, 128>{},
+                                                               src_offset,
+                                                               src_block_len,
                                                                dst_max_len,
                                                                head_num,
                                                                head_dim,
@@ -214,6 +223,7 @@ template void ConvertBlocksToLinear(const half** src_block_ptrs,
                                     half*        dst,
                                     const int*   src_cu_block_cnts,
                                     const int*   seq_lens,
+                                    int          src_offset,
                                     int          src_block_len,
                                     int          dst_max_seq_len,
                                     int          head_num,
@@ -228,6 +238,7 @@ __global__ void KvCacheBlocksToLinearKernel(const T**   src_k_block_ptrs,
                                             T**         dst_v_ptrs,
                                             const int*  src_cu_block_cnts,
                                             const int*  seq_lens,
+                                            int         src_offset,
                                             SrcBlockLen src_block_len,
                                             DstBlockLen dst_block_len,
                                             int         head_num,
@@ -247,6 +258,8 @@ __global__ void KvCacheBlocksToLinearKernel(const T**   src_k_block_ptrs,
                      src_cu_block_cnts,
                      dst_cu_block_cnts,
                      seq_lens,
+                     src_offset,
+                     0,
                      src_block_len,
                      dst_block_len,
                      head_dim);
@@ -256,6 +269,8 @@ __global__ void KvCacheBlocksToLinearKernel(const T**   src_k_block_ptrs,
                      src_cu_block_cnts,
                      dst_cu_block_cnts,
                      seq_lens,
+                     src_offset,
+                     0,
                      src_block_len,
                      dst_block_len,
                      head_dim);
@@ -268,6 +283,7 @@ void ConvertKvCacheBlocksToLinear(const T**    src_k_block_ptrs,
                                   T**          dst_v_ptrs,
                                   const int*   src_cu_block_cnts,
                                   const int*   seq_lens,
+                                  int          src_offset,
                                   int          src_block_len,
                                   int          dst_block_len,
                                   int          head_num,
@@ -282,6 +298,8 @@ void ConvertKvCacheBlocksToLinear(const T**    src_k_block_ptrs,
 
     const auto smem_sz = sizeof(int) * batch_size;
 
+    // dbg(src_block_len, dst_block_len, head_num, head_dim, batch_size);
+
     auto fn = [&](auto head_dim) {
         KvCacheBlocksToLinearKernel<<<blocks, threads, smem_sz, st>>>(src_k_block_ptrs,
                                                                       src_v_block_ptrs,
@@ -289,6 +307,7 @@ void ConvertKvCacheBlocksToLinear(const T**    src_k_block_ptrs,
                                                                       dst_v_ptrs,
                                                                       src_cu_block_cnts,
                                                                       seq_lens,
+                                                                      src_offset,
                                                                       src_block_len,
                                                                       dst_block_len,
                                                                       head_num,
@@ -311,6 +330,7 @@ template void ConvertKvCacheBlocksToLinear(const half** src_k_block_ptrs,
                                            half**       dst_v_ptrs,
                                            const int*   src_cu_block_cnts,
                                            const int*   seq_lens,
+                                           int          src_offset,
                                            int          src_block_len,
                                            int          dst_block_len,
                                            int          head_num,
@@ -324,6 +344,7 @@ template void ConvertKvCacheBlocksToLinear(const float** src_k_block_ptrs,
                                            float**       dst_v_ptrs,
                                            const int*    src_cu_block_cnts,
                                            const int*    seq_lens,
+                                           int           src_offset,
                                            int           src_block_len,
                                            int           dst_block_len,
                                            int           head_num,
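
All three conversion entry points index blocks through a cumulative per-sequence block count (`src_cu_block_cnts` / `dst_cu_block_cnts`), in the same spirit as `cu_seqlens`. A sketch of how a host might build that array from sequence lengths; `CumulativeBlockCounts` is an illustrative name:

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    std::vector<int> CumulativeBlockCounts(const std::vector<int>& seq_lens, int block_len)
    {
        std::vector<int> cu(seq_lens.size() + 1, 0);
        for (size_t i = 0; i < seq_lens.size(); ++i) {
            cu[i + 1] = cu[i] + (seq_lens[i] + block_len - 1) / block_len;  // ceil division
        }
        return cu;
    }

    int main()
    {
        for (int x : CumulativeBlockCounts({511, 128, 1}, 128)) {
            std::printf("%d ", x);  // prints: 0 4 5 6
        }
        std::printf("\n");
    }
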
diff --git a/src/turbomind/kernels/decoder_mha/kv_cache.h b/src/turbomind/kernels/decoder_mha/kv_cache.h
index 305971479b..798851fded 100644
--- a/src/turbomind/kernels/decoder_mha/kv_cache.h
+++ b/src/turbomind/kernels/decoder_mha/kv_cache.h
@@ -9,6 +9,7 @@ void ConvertLinearToBlocks(const T*     src,
                            T**          dst_block_ptrs,
                            const int*   dst_cu_block_cnts,
                            const int*   seq_lens,
+                           int          dst_offset,
                            int          src_seq_len,
                            int          dst_block_len,
                            int          head_num,
@@ -21,6 +22,7 @@ void ConvertBlocksToLinear(const T**    src_block_ptrs,
                            T*           dst,
                            const int*   src_cu_block_cnts,
                            const int*   seq_lens,
+                           int          src_offset,
                            int          src_block_len,
                            int          dst_max_seq_len,
                            int          head_num,
@@ -35,6 +37,7 @@ void ConvertKvCacheBlocksToLinear(const T**    src_k_block_ptrs,
                                   T**          dst_v_ptrs,
                                   const int*   src_cu_block_cnts,
                                   const int*   seq_lens,
+                                  int          src_offset,
                                   int          src_block_len,
                                   int          dst_block_len,  // max{seq_lens}
                                   int          head_num,
diff --git a/src/turbomind/kernels/decoder_mha/test_decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_mha/test_decoder_multihead_attention.cu
index 7d71718bf7..7397443301 100644
--- a/src/turbomind/kernels/decoder_mha/test_decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_mha/test_decoder_multihead_attention.cu
@@ -26,9 +26,10 @@ T* align(T* ptr, size_t alignment)
 
 // [S/S, H, S, D] <-> [S/b, H, b, D]
 
-void TestBlocks(thrust::universal_vector<half>&  linear,
-                thrust::universal_vector<half>&  _blocks,
-                thrust::universal_vector<half*>& _ptrs,
+void TestBlocks(thrust::universal_vector<half>&  linear,          // linear data
+                thrust::universal_vector<half>&  _blocks,         // block data
+                thrust::universal_vector<half*>& _ptrs,           // block ptrs
+                thrust::universal_vector<int>&   _cu_block_cnts,  // cumulative block counts
                 int                              head_num,
                 int                              head_dim,
                 int                              block_size,
@@ -41,7 +42,7 @@ void TestBlocks(thrust::universal_vector<half>&  linear,
               << ", block_size = " << block_size << "\n";
 
     thrust::universal_vector<half>  blocks(batch_size * n_blocks * head_num * block_size * head_dim);
-    thrust::universal_vector<half*> ptrs(batch_size * n_blocks);
+    thrust::universal_vector<half*> ptrs(batch_size * n_blocks + 1);  // +1 padding
 
     std::vector<size_t> idxs(batch_size * n_blocks);
     std::iota(idxs.begin(), idxs.end(), 0);
@@ -64,6 +65,7 @@ void TestBlocks(thrust::universal_vector<half>&  linear,
                               ptrs.data().get(),
                               cu_block_cnts.data().get(),
                               seq_lens.data().get(),
+                              0,
                               seq_len,
                               block_size,
                               head_num,
@@ -78,6 +80,7 @@ void TestBlocks(thrust::universal_vector<half>&  linear,
                               _linear.data().get(),
                               cu_block_cnts.data().get(),
                               seq_lens.data().get(),
+                              0,
                               block_size,
                               seq_len,
                               head_num,
@@ -89,10 +92,11 @@ void TestBlocks(thrust::universal_vector<half>&  linear,
     std::cout << ">>> Compare\n";
     Compare(_linear.data().get(), linear.data().get(), head_dim, head_dim, batch_size * head_num * seq_len);
     std::cout << "<<< Compare\n";
-    std::exit(0);
+    // std::exit(0);
 
     _blocks.swap(blocks);
     _ptrs.swap(ptrs);
+    _cu_block_cnts.swap(cu_block_cnts);
 }
 
 int main(int argc, char* argv[])
@@ -102,7 +106,7 @@ int main(int argc, char* argv[])
     // constexpr int kHeadNum = 108 * 4;
     constexpr int kHeadNum     = 32;
     constexpr int kHeadDim     = 128;
-    constexpr int kBatchSize   = 64;
+    constexpr int kBatchSize   = 1;
     constexpr int kContextLen  = 511;
     constexpr int kSequenceLen = kContextLen + 1;
     constexpr int kBlockSz     = 128;
@@ -123,7 +127,6 @@ int main(int argc, char* argv[])
     thrust::universal_vector<int>   sequence_lengths(kBatchSize);
     thrust::universal_vector<void*> k_cache_ptrs(kBatchSize);
     thrust::universal_vector<void*> v_cache_ptrs(kBatchSize);
-    thrust::universal_vector<int>   cu_ctxlens(kBatchSize + 1);
 
     rng.GenerateNormal(qkv.data().get(), qkv.size(), 1.f, 0.f);
 
@@ -164,13 +167,14 @@ int main(int argc, char* argv[])
 
     thrust::universal_vector<half>  k_blocks;
     thrust::universal_vector<half*> k_ptrs;
+    thrust::universal_vector<int>   cu_block_cnts;
 
-    TestBlocks(k_cache, k_blocks, k_ptrs, kHeadNum, kHeadDim, kBlockSz, kBatchSize);
+    TestBlocks(k_cache, k_blocks, k_ptrs, cu_block_cnts, kHeadNum, kHeadDim, kBlockSz, kBatchSize);
 
     thrust::universal_vector<half>  v_blocks;
     thrust::universal_vector<half*> v_ptrs;
 
-    TestBlocks(v_cache, v_blocks, v_ptrs, kHeadNum, kHeadDim, kBlockSz, kBatchSize);
+    TestBlocks(v_cache, v_blocks, v_ptrs, cu_block_cnts, kHeadNum, kHeadDim, kBlockSz, kBatchSize);
 
     thrust::universal_vector<half>  k_cache_ref = k_cache;
     thrust::universal_vector<half>  v_cache_ref = v_cache;
@@ -186,7 +190,6 @@ int main(int argc, char* argv[])
         v_cache_ptrs[i]     = v_cache.data().get() + i * v_cache.size() / kBatchSize;
         k_cache_ref_ptrs[i] = k_cache_ref.data().get() + i * k_cache_ref.size() / kBatchSize;
         v_cache_ref_ptrs[i] = v_cache_ref.data().get() + i * v_cache_ref.size() / kBatchSize;
-        cu_ctxlens[i + 1]   = cu_ctxlens[i] + kContextLen;
 
         // align(k_cache_ptrs[i], 256);
         // align(v_cache_ptrs[i], 256);
@@ -200,20 +203,20 @@ int main(int argc, char* argv[])
     params.v      = params.k + kHeadNum * kHeadDim;
     params.stride = 3 * kHeadNum * kHeadDim;
 
-    params.batch_size  = kBatchSize;
-    params.max_seq_len = kContextLen + 1;
-    params.cu_ctxlens  = cu_ctxlens.data().get();
+    params.batch_size    = kBatchSize;
+    params.max_seq_len   = kContextLen + 1;
+    params.cu_block_cnts = cu_block_cnts.data().get();
 
     printf("%d %d\n", (int)k_ptrs.size(), (int)v_ptrs.size());
     params.k_cache_block_ptrs  = (void**)k_ptrs.data().get();
     params.v_cache_block_ptrs  = (void**)v_ptrs.data().get();
     params.kv_cache_block_size = kBlockSz;
 
-    params.finished                   = finished.data().get();
-    params.per_sample_length          = sequence_lengths.data().get();
-    params.per_sample_k_cache         = k_cache_ref_ptrs.data().get();
-    params.per_sample_v_cache         = v_cache_ref_ptrs.data().get();
-    params.per_sample_kv_cache_offset = 0;
+    params.finished           = finished.data().get();
+    params.per_sample_length  = sequence_lengths.data().get();
+    params.per_sample_k_cache = k_cache_ref_ptrs.data().get();
+    params.per_sample_v_cache = v_cache_ref_ptrs.data().get();
+    params.layer_offset       = 0;
 
     params.num_heads     = kHeadNum;
     params.num_kv_heads  = kHeadNum;
@@ -228,10 +231,10 @@ int main(int argc, char* argv[])
     }
 
     cudaDeviceSynchronize();
-    // if (auto err = cudaGetLastError(); err != cudaSuccess) {
-    //     std::cout << cudaGetErrorString(err) << "\n";
-    //     return -1;
-    // }
+    if (auto err = cudaGetLastError(); err != cudaSuccess) {
+        std::cout << cudaGetErrorString(err) << "\n";
+        return -1;
+    }
     std::cout << "---------------------------------------------------\n";
 
     params.out                = output.data().get();
@@ -251,11 +254,34 @@ int main(int argc, char* argv[])
         }
     }
 
+    thrust::universal_vector<int> seq_lens(kBatchSize);
+    for (auto& x : seq_lens) {
+        x = kContextLen + 1;
+    }
+
     if (1) {
-        // ConvertBlocksToLinear(
-        //     (const half**)k_ptrs.data().get(), k_cache.data().get(), kBlockSz, kHeadNum, kHeadDim, kSequenceLen, 0);
-        // ConvertBlocksToLinear(
-        //     (const half**)v_ptrs.data().get(), v_cache.data().get(), kBlockSz, kHeadNum, kHeadDim, kSequenceLen, 0);
+        ConvertBlocksToLinear((const half**)k_ptrs.data().get(),
+                              k_cache.data().get(),
+                              cu_block_cnts.data().get(),
+                              seq_lens.data().get(),
+                              0,
+                              kBlockSz,
+                              kSequenceLen,
+                              kHeadNum,
+                              kHeadDim,
+                              kBatchSize,
+                              0);
+        ConvertBlocksToLinear((const half**)v_ptrs.data().get(),
+                              v_cache.data().get(),
+                              cu_block_cnts.data().get(),
+                              seq_lens.data().get(),
+                              0,
+                              kBlockSz,
+                              kSequenceLen,
+                              kHeadNum,
+                              kHeadDim,
+                              kBatchSize,
+                              0);
     }
 
     cudaDeviceSynchronize();
diff --git a/src/turbomind/kernels/decoder_mha/test_utils.cu b/src/turbomind/kernels/decoder_mha/test_utils.cu
index 46f582bc0a..883f0fc3d0 100644
--- a/src/turbomind/kernels/decoder_mha/test_utils.cu
+++ b/src/turbomind/kernels/decoder_mha/test_utils.cu
@@ -202,11 +202,11 @@ void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& p, cudaStream_t
     params.v = reinterpret_cast<const DataType*>(p.v);
 
     params.stride   = p.stride;
-    params.finished = p.finished;
+    params.finished = (bool*)p.finished;
 
     params.k_cache_per_sample         = reinterpret_cast<DataType**>(p.per_sample_k_cache);
     params.v_cache_per_sample         = reinterpret_cast<DataType**>(p.per_sample_v_cache);
-    params.kv_cache_per_sample_offset = p.per_sample_kv_cache_offset;
+    params.kv_cache_per_sample_offset = p.layer_offset;
     params.batch_size                 = p.batch_size;
     params.beam_width                 = 1;
     params.memory_max_len             = p.max_seq_len;
@@ -215,7 +215,7 @@ void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& p, cudaStream_t
     params.length_per_sample          = p.per_sample_length;  // max_input_length + current output length
 
     for (int i = 0; i < p.batch_size; ++i) {
-        params.timestep = std::max(params.timestep, p.cu_ctxlens[i + 1] - p.cu_ctxlens[i]);
+        params.timestep = std::max(p.per_sample_length[i], params.timestep);
     }
 
     std::cout << "timestep = " << params.timestep << "\n";
diff --git a/src/turbomind/models/llama/BlockManager.cc b/src/turbomind/models/llama/BlockManager.cc
index e87e868509..bda8cb98cf 100644
--- a/src/turbomind/models/llama/BlockManager.cc
+++ b/src/turbomind/models/llama/BlockManager.cc
@@ -152,31 +152,42 @@ void BlockManager::Evict(int count)
 
 int BlockManager::Release(const std::vector<const Block*>& bs)
 {
-    std::vector<int> cached;
+    std::vector<int> idxs;
 
     for (const auto& p : bs) {
         auto& block = blocks_[p->id];
         FT_CHECK(is_active(block));
         if (--block.ref_count == 0) {
-            cached.push_back(block.id);
+            idxs.push_back(block.id);
         }
     }
 
-    std::sort(cached.begin(), cached.end());
+    std::sort(idxs.begin(), idxs.end());
 
-    Move(active_ids_, cached, cached_ids_);
+    Move(active_ids_, idxs, cached_ids_);
 
     dbg("[Release]", cached_ids_);
 
-    return cached.size();
+    return idxs.size();
 }
 
 void BlockManager::Retain(const std::vector<const Block*>& bs)
 {
+    std::vector<int> idxs;
+
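+    // collect blocks whose ref count goes from 0 to 1; they move from the cached list back to the active list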
     for (const auto& p : bs) {
-        FT_CHECK(is_active(*p));
-        ++const_cast<Block*>(p)->ref_count;
+        auto& block = blocks_[p->id];
+        FT_CHECK(is_cached(block));
+        if (++block.ref_count == 1) {
+            idxs.push_back(p->id);
+        }
     }
+
+    std::sort(idxs.begin(), idxs.end());
+
+    Move(cached_ids_, idxs, active_ids_);
+
+    dbg("[Retain]", active_ids_);
 }
 
 void BlockManager::Touch(const std::vector<const Block*>& bs)
@@ -214,8 +225,8 @@ std::ostream& operator<<(std::ostream& os, const BlockManager& manager)
 
 std::ostream& operator<<(std::ostream& os, const Block& block)
 {
-    os << "Block[id=" << block.id << ",ref_count=" << block.ref_count << ",unique_id=" << block.unique_id
-       << ",timestamp=" << block.timestamp << ",data=" << block.data << "]";
+    os << "id=" << block.id << ", ref_count=" << block.ref_count << ", unique_id=" << block.unique_id
+       << ", timestamp=" << block.timestamp << ", data=" << block.data;
     return os;
 }
 
diff --git a/src/turbomind/models/llama/BlockManager.h b/src/turbomind/models/llama/BlockManager.h
index a6f7dc50a5..dc14b32a15 100644
--- a/src/turbomind/models/llama/BlockManager.h
+++ b/src/turbomind/models/llama/BlockManager.h
@@ -61,15 +61,17 @@ class BlockManager {
     // free -> active
     [[nodiscard]] std::vector<const Block*> Allocate(int count);
 
+    // decrease ref count
     // active -> cached
     [[maybe_unused]] int Release(const std::vector<const Block*>& bs);
 
+    // increase ref count
+    // cached -> active
+    void Retain(const std::vector<const Block*>& bs);
+
     // cached -> free
     void Evict(int count);
 
-    // active -> active
-    void Retain(const std::vector<const Block*>& bs);
-
     // increase timestamp in reversed order
     void Touch(const std::vector<const Block*>& bs);
 
diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 77b9e26f65..35e89235a9 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -10,6 +10,7 @@
 #include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/Tensor.h"
 #include "src/turbomind/utils/cuda_utils.h"
+#include "src/turbomind/utils/dbg.h"
 #include "src/turbomind/utils/logger.h"
 #include <algorithm>
 #include <cstdint>
@@ -158,12 +159,12 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
     for (const auto& r : requests) {
 
         // sanity check, incoming request in previous iter should have been moved to `state_`
-        FT_CHECK(state.sequences[i] == nullptr);
+        FT_CHECK(!state.requests[i]);
 
         state.requests[i] = r;
 
         // get sequence for the request
-        state.sequences[i] = r->start_flag ? sequence_manager_->Create(r->id) : sequence_manager_->Fetch(r->id);
+        state.sequences[i] = r->start_flag ? sequence_manager_->Create(r->id) : sequence_manager_->Get(r->id);
 
         auto& seq = *state.sequences[i];
 
@@ -256,21 +257,33 @@ bool LlamaBatch<T>::Initialize()
         }
     }
 
-    auto add = [&](BatchState* state) {
+    auto process = [&](BatchState* state) {
+        // dbg(state->size);
         for (int i = 0; i < state->size; ++i) {
             if (auto& r = state->requests[i]) {
                 sequences.push_back(state->sequences[i]);
                 status.push_back(state->sequences[i]->status);
                 priorities.push_back(r->priority);
+                context_lengths.push_back(state->h_context_length[i]);
                 coords.emplace_back(state, i);
+                // clear swap-in flags
+                state->is_swap_in[i] = 0;
             }
         }
     };
 
-    add(state_);
-    add(incoming_);
+    process(state_);
+    process(incoming_);
 
-    auto outcome = sequence_manager_->Materialize(sequences, context_lengths, priorities, llama_->step_length_);
+    // dbg(sequences);
+    // dbg(context_lengths);
+    // dbg(priorities);
+    // dbg(step_length_);
+
+    auto outcome = sequence_manager_->Materialize(sequences, context_lengths, priorities, step_length_);
+    if (outcome.allocation || outcome.swap_in || outcome.swap_out) {
+        dbg(outcome);
+    }
 
     bool exchange = outcome.swap_in + outcome.swap_out > 0;
 
@@ -304,13 +317,13 @@ bool LlamaBatch<T>::Initialize()
         for (const auto& i : idxs) {
             auto& s = *sequences[i];
             if (exchange) {
+                const auto& [state, idx] = coords[i];
                 // backup random states from dynamic decode layers for swap-outs
                 if (status[i] == Sequence::kActive && s.status != Sequence::kActive) {
-                    SaveRandomState(*coords[i].first, coords[i].second);
+                    SaveRandomState(*state, idx);
                 }
-                // restore random states to dynamic decode layers for swap-ins
                 if (status[i] != Sequence::kActive && s.status == Sequence::kActive) {
-                    LoadRandomState(*coords[i].first, coords[i].second);
+                    state->is_swap_in[idx] = 1;
                 }
             }
             if (s.status == Sequence::kActive) {
@@ -347,9 +360,22 @@ bool LlamaBatch<T>::Initialize()
 
         Copy(state_->h_context_length, batch_size, context_length_buf_);
 
+        dbg(std::vector(h_cu_block_counts_, h_cu_block_counts_ + batch_size + 1));
+        dbg(std::vector(h_k_block_ptrs_, h_k_block_ptrs_ + h_cu_block_counts_[batch_size]));
+        dbg(std::vector(h_v_block_ptrs_, h_v_block_ptrs_ + h_cu_block_counts_[batch_size]));
+        dbg(h_cu_block_counts_[batch_size]);
+
         Copy(h_cu_block_counts_, batch_size + 1, cu_block_counts_);
         Copy(h_k_block_ptrs_, h_cu_block_counts_[batch_size], k_block_ptrs_);
         Copy(h_v_block_ptrs_, h_cu_block_counts_[batch_size], v_block_ptrs_);
+
+        static_assert(sizeof(uintptr_t) == sizeof(void*));
+
+        // copy the device-side k block pointers back to host for inspection
+        std::vector<void*> k_block_ptrs_copy(h_cu_block_counts_[batch_size]);
+        Copy((void**)k_block_ptrs_, k_block_ptrs_copy.size(), k_block_ptrs_copy.data());
+        cudaStreamSynchronize(stream_);
+
+        dbg(k_block_ptrs_copy);
     }
 
     // in case of swap-in/swap-out or there are holes in active buffer, layout of the buffers is changed
@@ -370,6 +396,7 @@ void LlamaBatch<T>::CopyState(const std::pair<BatchState*, int> _src, const std:
     dst->h_finished[j]       = src->h_finished[i];
     dst->seq_len_limit[j]    = src->seq_len_limit[i];
     dst->sequences[j]        = src->sequences[i];
+    dst->is_swap_in[j]       = src->is_swap_in[i];
     dst->requests[j]         = std::move(src->requests[i]);
 
     Copy(src->output_ids + i * session_len_, src->h_context_length[i], dst->output_ids + j * session_len_);
@@ -388,6 +415,7 @@ void LlamaBatch<T>::SaveRandomState(BatchState& state, int idx)
 template<typename T>
 void LlamaBatch<T>::LoadRandomState(BatchState& state, int idx)
 {
+    dbg(idx);
     Copy((curandState_t*)state.top_k_curand_state + idx, 1, llama_->GetTopKState(idx));
     Copy((curandState_t*)state.top_p_curand_state + idx, 1, llama_->GetTopPState(idx));
 }
@@ -402,7 +430,8 @@ void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len)
     const size_t vocab_size        = llama_->vocab_size_padded_;
     const size_t head_dim          = llama_->size_per_head_;
     const size_t local_kv_head_num = llama_->local_kv_head_num_;
-    const size_t max_block_count   = sequence_manager_->max_block_count();
+    // +1 padding since BlockIterator does not use a predicate
+    const size_t max_block_count = sequence_manager_->max_block_count() + 1;
 
     context_decoder_input_buf_ =
         (T*)allocator_->reMalloc(context_decoder_input_buf_, sizeof(T) * max_context_token_num_ * hidden_units, false);
@@ -484,6 +513,8 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
             (int*)allocator_->reMalloc(h_input_length_buf_, sizeof(int) * max_batch_size, false, true);
         h_history_length_buf_ =
             (int*)allocator_->reMalloc(h_history_length_buf_, sizeof(int) * max_batch_size, false, true);
+        h_sequence_lengths_ =
+            (int*)allocator_->reMalloc(h_sequence_lengths_, sizeof(int) * max_batch_size, false, true);
 
         h_tmp_k_ptrs_ = (void**)allocator_->reMalloc(h_tmp_k_ptrs_, sizeof(void*) * max_batch_size, false, true);
         h_tmp_v_ptrs_ = (void**)allocator_->reMalloc(h_tmp_v_ptrs_, sizeof(void*) * max_batch_size, false, true);
@@ -569,6 +600,7 @@ void LlamaBatch<T>::FreeBuffer()
         allocator_->free((void**)&h_input_ids_buf_, true);
         allocator_->free((void**)&h_input_length_buf_, true);
         allocator_->free((void**)&h_history_length_buf_, true);
+        allocator_->free((void**)&h_sequence_lengths_, true);
         allocator_->free((void**)&h_seq_limit_len_, true);
         is_allocate_persistant_buffer_ = false;
     }
@@ -598,6 +630,7 @@ LlamaBatch<T>::LlamaBatch(int                              max_batch_size,
         s.requests.resize(max_batch_size);
         s.sequences.resize(max_batch_size);
         s.seq_len_limit.resize(max_batch_size);
+        s.is_swap_in.resize(max_batch_size);
     }
 
     state_    = &states_[0];
@@ -650,7 +683,7 @@ void LlamaBatch<T>::InitializeSampling()
 
     // recover random states if not a new request
     for (int i = 0; i < batch_size; ++i) {
-        if (!state_->requests[i]->start_flag) {
+        if (!state_->requests[i]->start_flag && state_->is_swap_in[i]) {
             LoadRandomState(*state_, i);
         }
     }
@@ -693,7 +726,8 @@ void LlamaBatch<T>::InitializeGeneration()
     invokePlusScalar(sequence_lengths_, -1, batch_size, stream_);
     sync_check_cuda_error();
 
-    // seq_limit_len_, will be compared to `step` instead of `sequence_length`, so padding len should be accounted for
+    // seq_limit_len_, will be compared to `step` instead of `sequence_length`, so padding len should be accounted
+    // for
     for (int i = 0; i < batch_size; ++i) {
         h_seq_limit_len_[i] = state_->seq_len_limit[i] + (max_context_len_ - state_->h_context_length[i]);
         // mask finished sequences
@@ -818,24 +852,24 @@ void LlamaBatch<T>::ContextDecode()
 
     int base = -1;
     for (int i = 0; i < batch_size; ++i) {
-        if (h_input_length_buf_[i] > 1) {
-            base = i;
-            break;
+        if (state_->is_swap_in[i]) {
+            const auto& seq = *state_->sequences[i];
+            dbg(state_->h_context_length[i], seq.cache_len);
+            if (const int missing = state_->h_context_length[i] - seq.cache_len; missing > 1) {
+                base = base < 0 ? i : base;
+                Copy(state_->output_ids + i * session_len_ + seq.cache_len, missing, input_ids_buf_ + i * session_len_);
+                // subtract input/context len by 1 to skip last input token (will process with decoder later)
+                h_input_length_buf_[i]   = missing - 1;
+                h_history_length_buf_[i] = seq.cache_len;
+            }
         }
     }
-    if (base == -1) {
+    if (base < 0) {
         TM_LOG_INFO("[decodeContext] Context decoding is not needed.");
         return;
     }
 
-    for (int i = base; i < batch_size; ++i) {
-        const auto& seq     = *state_->sequences[i];
-        const int   missing = state_->h_context_length[i] - seq.cache_len;
-        FT_CHECK(missing > 1);
-        Copy(state_->output_ids + i * session_len_ + seq.cache_len, missing, input_ids_buf_ + i * session_len_);
-        h_input_length_buf_[i]   = missing;
-        h_history_length_buf_[i] = seq.cache_len;
-    }
+    const int context_decode_count = batch_size - base;
 
     Copy(h_input_length_buf_, batch_size, input_length_buf_);
     Copy(h_history_length_buf_, batch_size, history_length_buf_);
@@ -843,32 +877,39 @@ void LlamaBatch<T>::ContextDecode()
     check_cuda_error(cudaStreamSynchronize(stream_));
     const auto tick = std::chrono::high_resolution_clock::now();
 
-    const int context_decode_count = batch_size - base;
     if (rank_ == 0) {
         TM_LOG_INFO("[decodeContext] base = %d, count = %d", base, context_decode_count);
     }
     // subtract input/context len by 1 to skip last input token (will process with decoder later)
-    invokePlusScalar(input_length_buf_ + base, -1, context_decode_count, stream_);
     invokePlusScalar(context_length_buf_ + base, -1, context_decode_count, stream_);
 
     // find sub-batch offsets
     std::vector<int> offsets{base};
-    int              accum_input_count   = 0;
-    int              accum_context_count = 0;
+    std::vector<int> max_context_cnts;
+    int              accum_size        = 0;
+    int              accum_input_count = 0;
+    int              max_context_count = 0;
     for (int i = base; i < batch_size; ++i) {
-        int input_count   = accum_input_count + h_input_length_buf_[i] - 1;
-        int context_count = accum_context_count + state_->h_context_length[i] - 1;
-        if (input_count <= max_context_token_num_ && context_count <= max_context_token_num_) {
-            accum_input_count   = input_count;
-            accum_context_count = context_count;
+        int size          = accum_size + 1;
+        int input_count   = accum_input_count + h_input_length_buf_[i];
+        int context_count = std::max(max_context_count, state_->h_context_length[i] - 1);
+        // we have `cu_seqlens` on q so no padding for input is needed
+        // kernels are expecting uniform k/v cache length -> `max_context_count * size <= max_context_token_num_`
+        if (input_count <= max_context_token_num_ && context_count * size <= max_context_token_num_) {
+            accum_size        = size;
+            accum_input_count = input_count;
+            max_context_count = context_count;
         }
         else {
             offsets.push_back(i);
-            accum_input_count   = 0;
-            accum_context_count = 0;
+            max_context_cnts.push_back(max_context_count);
+            accum_size        = 0;
+            accum_input_count = 0;
+            max_context_count = 0;
         }
     }
     offsets.push_back(batch_size);
+    max_context_cnts.push_back(max_context_count);
 
     // context decode on sub-batches
     for (int k = 0; k < offsets.size() - 1; ++k) {
@@ -880,20 +921,19 @@ void LlamaBatch<T>::ContextDecode()
         std::vector<int> decode_indices{};
         std::vector<int> decode_lengths{};
         int              max_input_len{};
-        int              max_context_len{};
         auto             input_ids = context_decoder_ids_buf_;
         for (int i = first; i < last; ++i) {
-            input_ids        = Copy(input_ids_buf_ + i * session_len_, h_input_ids_buf_[i] - 1, input_ids);
+            input_ids        = Copy(input_ids_buf_ + i * session_len_, h_input_length_buf_[i], input_ids);
             h_tmp_k_ptrs_[i] = k_ptr;
             h_tmp_v_ptrs_[i] = v_ptr;
-            k_ptr += (state_->h_context_length[i] - 1) * llama_->local_kv_head_num_ * llama_->size_per_head_;
-            v_ptr += (state_->h_context_length[i] - 1) * llama_->local_kv_head_num_ * llama_->size_per_head_;
+            k_ptr += llama_->local_kv_head_num_ * max_context_cnts[k] * llama_->size_per_head_;
+            v_ptr += llama_->local_kv_head_num_ * max_context_cnts[k] * llama_->size_per_head_;
             decode_indices.push_back(i);
-            decode_lengths.push_back(h_input_length_buf_[i] - 1);
-            max_input_len   = std::max(max_input_len, h_input_length_buf_[i] - 1);
-            max_context_len = std::max(max_context_len, state_->h_context_length[i] - 1);
+            decode_lengths.push_back(h_input_length_buf_[i]);
+            max_input_len = std::max(max_input_len, h_input_length_buf_[i]);
         }
         int token_count = input_ids - context_decoder_ids_buf_;
+        dbg(token_count, max_input_len, max_context_cnts[k]);
 
         Copy(h_tmp_k_ptrs_ + first, sub_batch_szie, tmp_k_ptrs_ + first);
         Copy(h_tmp_v_ptrs_ + first, sub_batch_szie, tmp_v_ptrs_ + first);
@@ -905,7 +945,19 @@ void LlamaBatch<T>::ContextDecode()
                 sub_batch_szie,
                 token_count,
                 max_input_len,
-                max_context_len);
+                max_context_cnts[k]);
+        }
+
+        dbg(first, last);
+        dbg(k_block_ptrs_, v_block_ptrs_);
+
+        if (1) {
+            int a, b, c;
+            Copy(input_length_buf_, 1, &a);
+            Copy(history_length_buf_, 1, &b);
+            Copy(context_length_buf_, 1, &c);
+            cudaStreamSynchronize(stream_);
+            dbg(a, b, c);
         }
 
         llama_->contextDecode(nullptr,
@@ -922,8 +974,8 @@ void LlamaBatch<T>::ContextDecode()
                               cu_block_counts_ + first,
                               token_count,
                               max_input_len,
-                              max_context_len,
-                              session_len_,
+                              max_context_cnts[k],
+                              max_context_cnts[k],
                               sub_batch_szie);
 
         // compute logits of inputs if requested
@@ -931,7 +983,6 @@ void LlamaBatch<T>::ContextDecode()
     }
 
     invokePlusScalar(context_length_buf_ + base, 1, context_decode_count, stream_);
-    invokePlusScalar(input_length_buf_ + base, 1, context_decode_count, stream_);
 
     std::fill(h_input_length_buf_ + base, h_input_length_buf_ + batch_size, 0);
 
@@ -996,7 +1047,8 @@ auto LlamaBatch<T>::Finish() -> std::vector<Signal>
 
     // secure info needed by `synchronize()`
     Copy(finished_buf_, batch_size, state_->h_finished);
-    Copy(sequence_lengths_, batch_size, h_sequence_lengths_);
+    Copy(sequence_lengths_, batch_size, state_->h_context_length);
+    Copy(sequence_lengths_, batch_size, context_length_buf_);
 
     SetOutputTensors(step_);
 
@@ -1104,7 +1156,7 @@ void LlamaBatch<T>::FinishRequest(int index, bool force_end)
 
         check_cuda_error(cudaStreamSynchronize(stream_));
 
-        sequence_manager_->Update(seq);
+        sequence_manager_->Release(seq);
     }
 
     state_->sequences[index] = nullptr;
diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h
index 7cc430eb9d..1faa7004e2 100644
--- a/src/turbomind/models/llama/LlamaBatch.h
+++ b/src/turbomind/models/llama/LlamaBatch.h
@@ -21,6 +21,7 @@ struct BatchState {
     int*  output_ids;  // output ids in [B, S]
 
     std::vector<int> seq_len_limit;
+    std::vector<int> is_swap_in;
 
     std::vector<const Sequence*>          sequences;
     std::vector<std::shared_ptr<Request>> requests;
@@ -101,7 +102,7 @@ class LlamaBatch {
     template<typename U>
     U* Copy(const U* src, size_t count, U* dst)
     {
-        check_cuda_error(cudaMemcpyAsync(dst, src, sizeof(T) * count, cudaMemcpyDefault, stream_));
+        check_cuda_error(cudaMemcpyAsync(dst, src, sizeof(U) * count, cudaMemcpyDefault, stream_));
         return dst += count;
     }
 
@@ -177,7 +178,7 @@ class LlamaBatch {
     float*    h_repetition_penalty_{};
     uint64_t* h_random_seed_{};
 
-    BatchState states_[3];
+    std::array<BatchState, 3> states_{};
 
     BatchState* state_{};
     BatchState* back_{};
diff --git a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
index 8fc54b9922..2505669dec 100644
--- a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
@@ -29,6 +29,7 @@
 #include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/Tensor.h"
 #include "src/turbomind/utils/cuda_utils.h"
+#include "src/turbomind/utils/dbg.h"
 #include "src/turbomind/utils/logger.h"
 
 namespace turbomind {
@@ -151,6 +152,14 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
 
     const auto padding_offset = input_tensors->at("padding_offset").getPtr<int>();
 
+    auto Show = [&](const T* x, size_t n) {
+        std::vector<T> vec(n);
+        cudaMemcpyAsync(vec.data(), x, sizeof(T) * n, cudaMemcpyDefault, stream_);
+        cudaStreamSynchronize(stream_);
+        std::vector<float> float_vec(vec.begin(), vec.end());
+        dbg(float_vec);
+    };
+
     /////////////////////////////////////////////
     /// allocate buffers
     allocateBuffer(batch_size, num_token, max_q_len, max_k_len);
@@ -184,6 +193,7 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
                                    stream_);
     sync_check_cuda_error();
 
+    // [2, L, H, s, D]
     const size_t layer_offset = layer_id * local_kv_head_num_ * kv_cache_block_len_ * size_per_head_;
 
     auto k_cache_ptrs = output_tensors->getPtr<T*>("key_cache");
@@ -219,9 +229,10 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
     ConvertKvCacheBlocksToLinear((const T**)k_cache_ptrs,
                                  (const T**)v_cache_ptrs,
                                  tmp_k_ptrs,
-                                 tmp_k_ptrs,
+                                 tmp_v_ptrs,
                                  cu_block_counts,
-                                 history_length,
+                                 context_length,
+                                 layer_offset,
                                  kv_cache_block_len_,
                                  max_seq_len,
                                  local_kv_head_num_,
@@ -230,21 +241,22 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
                                  stream_);
     sync_check_cuda_error();
 
+    // dbg(kv_cache_block_len_, max_seq_len, local_kv_head_num_, size_per_head_, batch_size);
+    // void *kk, *vv;
+    // cudaMemcpyAsync(&kk, tmp_k_ptrs, sizeof(void*), cudaMemcpyDefault, stream_);
+    // cudaMemcpyAsync(&vv, tmp_v_ptrs, sizeof(void*), cudaMemcpyDefault, stream_);
+    // cudaStreamSynchronize(stream_);
+    // Show((const T*)kk, local_kv_head_num_ * max_seq_len * size_per_head_);
+    // Show((const T*)vv, local_kv_head_num_ * max_seq_len * size_per_head_);
+
     if (use_fmha_) {
-        fusedMultiHeadAttention(tmp_k_ptrs,
-                                tmp_k_ptrs,
-                                layer_offset,
-                                attention_mask,
-                                cu_seqlens,
-                                batch_size,
-                                max_q_len,
-                                max_k_len,
-                                max_seq_len);
+        fusedMultiHeadAttention(
+            tmp_k_ptrs, tmp_v_ptrs, 0, attention_mask, cu_seqlens, batch_size, max_q_len, max_k_len, max_seq_len);
     }
     else {
         unfusedMultiHeadAttention(tmp_k_ptrs,
-                                  tmp_k_ptrs,
-                                  layer_offset,
+                                  tmp_v_ptrs,
+                                  0,
                                   attention_mask,
                                   padding_offset,
                                   context_length,
@@ -257,6 +269,14 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
                                   weights->past_kv_scale.data());
     }
 
+    Compare(qkv_buf_3_, num_token * hidden_units_, Concat("qkv_buf_3", layer_id), kCmpRead, stream_);
+
+    // dbg(max_seq_len);
+
+    if (0) {
+        Show(qkv_buf_3_, num_token * hidden_units_);
+    }
+
     //////////////////////////////////////////////
     /// output gemm <Bs,HD> -> <Bs,HD>
     linear_.forward(attention_out, qkv_buf_3_, num_token, weights->output);
diff --git a/src/turbomind/models/llama/LlamaContextDecoder.cc b/src/turbomind/models/llama/LlamaContextDecoder.cc
index def55f41a8..e763082733 100644
--- a/src/turbomind/models/llama/LlamaContextDecoder.cc
+++ b/src/turbomind/models/llama/LlamaContextDecoder.cc
@@ -25,7 +25,9 @@
 #include "src/turbomind/models/llama/LlamaContextDecoder.h"
 #include "src/turbomind/models/llama/llama_decoder_kernels.h"
 #include "src/turbomind/models/llama/llama_kernels.h"
+#include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/Tensor.h"
+#include "src/turbomind/utils/dbg.h"
 
 namespace turbomind {
 
@@ -93,6 +95,7 @@ void LlamaContextDecoder<T>::initialize(const LlamaAttentionParams& attn_params,
 template<typename T>
 void LlamaContextDecoder<T>::forwardSelfAttn(const Session&                                 sess,
                                              T*                                             attn_io,
+                                             std::unordered_map<std::string, Tensor>*       output_tensors,
                                              const std::unordered_map<std::string, Tensor>* input_tensors,
                                              int                                            layer,
                                              bool                                           is_final)
@@ -112,14 +115,15 @@ void LlamaContextDecoder<T>::forwardSelfAttn(const Session&
         {"cu_block_counts", input_tensors->at("cu_block_counts")},
         {"max_seq_len", input_tensors->at("max_seq_len")}};
 
-    auto& k_cache = *sess.k_cache;
-    auto& v_cache = *sess.v_cache;
+    // auto& k_cache = *sess.k_cache;
+    // auto& v_cache = *sess.v_cache;
 
     TensorMap self_attention_output_tensors{
         {"hidden_features", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, attn_io}},
-        {"key_cache", k_cache},
-        {"value_cache", v_cache},
-    };
+        {"key_cache", output_tensors->at("key_cache")},
+        {"value_cache", output_tensors->at("value_cache")},
+        {"tmp_k", output_tensors->at("tmp_k")},
+        {"tmp_v", output_tensors->at("tmp_v")}};
 
     context_attention_layer_->forward(&self_attention_output_tensors,  //
                                       &self_attention_input_tensors,
@@ -208,11 +212,11 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
     T* decoder_input_output = input_tensors->at("decoder_input").getPtr<T>();
     T* decoder_output       = output_tensors->at("decoder_output").getPtr<T>();
 
-    sess.k_cache = &output_tensors->at("key_cache");
-    sess.v_cache = &output_tensors->at("value_cache");
-
     allocateBuffer(sess.batch_size, sess.token_num, sess.max_query_len, sess.max_key_len);
 
+    FT_CHECK(padding_offset_);
+    dbg(padding_offset_);
+
     size_t tmp_token_num{};
     invokeGetPaddingOffsetAndCuSeqLens(h_pinned_token_num_ptr_,
                                        &tmp_token_num,  // updated token num
@@ -223,6 +227,7 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
                                        sess.max_query_len,
                                        stream_);
     sync_check_cuda_error();
+    dbg(tmp_token_num, sess.token_num);
     FT_CHECK(tmp_token_num == sess.token_num);
 
     invokeCreateCausalMasks(attention_mask_,
@@ -234,6 +239,9 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
                             stream_);
     sync_check_cuda_error();
 
+    Compare(
+        decoder_input_output, sess.token_num * hidden_units_, Concat("context_decoder_input", 0), kCmpRead, stream_);
+
     /////////////////////////////////////////////
     /// RMSNorm
     invokeRootMeanSquareNorm(decoder_output,
@@ -248,7 +256,7 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
     for (size_t layer = 0; layer < num_layer_; ++layer) {
         /////////////////////////////////////////////
         /// self-attention
-        forwardSelfAttn(sess, decoder_output, input_tensors, layer, false);
+        forwardSelfAttn(sess, decoder_output, output_tensors, input_tensors, layer, false);
 
         invokeFusedAddBiasResidualRMSNorm(decoder_input_output,
                                           decoder_output,
diff --git a/src/turbomind/models/llama/LlamaContextDecoder.h b/src/turbomind/models/llama/LlamaContextDecoder.h
index da6264176f..4f3613c38c 100644
--- a/src/turbomind/models/llama/LlamaContextDecoder.h
+++ b/src/turbomind/models/llama/LlamaContextDecoder.h
@@ -63,21 +63,20 @@ class LlamaContextDecoder: public BaseLayer {
     const DataType data_type_;
 
     struct Session {
-        size_t  batch_size;
-        size_t  token_num;
-        size_t  max_query_len;
-        size_t  max_key_len;
-        Tensor* k_cache;
-        Tensor* v_cache;
-        int*    input_length{};
-        int*    history_length{};
-        int*    context_length{};
+        size_t batch_size;
+        size_t token_num;
+        size_t max_query_len;
+        size_t max_key_len;
+        int*   input_length{};
+        int*   history_length{};
+        int*   context_length{};
 
         const std::vector<LlamaDecoderLayerWeight<T>*>* weights;
     };
 
     void forwardSelfAttn(const Session&                                 sess,
                          T*                                             attn_io,
+                         std::unordered_map<std::string, Tensor>*       output_tensors,
                          const std::unordered_map<std::string, Tensor>* input_tensors,
                          int                                            layer,
                          bool                                           is_final);
diff --git a/src/turbomind/models/llama/LlamaDecoder.cc b/src/turbomind/models/llama/LlamaDecoder.cc
index 88dd76b935..73e95b1353 100644
--- a/src/turbomind/models/llama/LlamaDecoder.cc
+++ b/src/turbomind/models/llama/LlamaDecoder.cc
@@ -124,7 +124,6 @@ void LlamaDecoder<T>::forwardSelfAttn(const LlamaDecoder::Session&
                                         {MEMORY_GPU, data_type_, {sess.batch_size, hidden_units_}, attn_io});
     const int layer_id = layer;
     self_attention_input_tensors.insert("layer_id", {MEMORY_CPU, TYPE_INT32, {1}, &layer_id});
-    self_attention_input_tensors.insert("cu_block_counts", input_tensors->at("cu_block_counts"));
     auto& k_cache = *sess.k_cache;
     auto& v_cache = *sess.v_cache;
 
diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
index eec9a7fbd4..9ad3908ea6 100644
--- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
@@ -19,6 +19,7 @@
 // https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/layers/attention_layers/DecoderSelfAttentionLayer.cc
 #include "src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h"
 #include "src/turbomind/kernels/decoder_masked_multihead_attention.h"
+#include "src/turbomind/kernels/decoder_mha/decoder_multihead_attention.h"
 #include "src/turbomind/macro.h"
 #include "src/turbomind/models/llama/LlamaNcclGuard.h"
 #include "src/turbomind/models/llama/llama_kernels.h"
@@ -211,17 +212,19 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
      *    \param value_cache [batch, local_head_num, memory_max_len, size_per_head]
      */
 
-    const T*    input_query_data      = input_tensors->getPtr<T>("input_query");
-    const int*  sequence_lengths_data = input_tensors->getPtr<int>("sequence_lengths");
-    const int*  total_padding_len     = input_tensors->getPtr<int>("total_padding_tokens");
-    const bool* finished_data         = input_tensors->getPtr<bool>("finished", nullptr);
-    const bool* masked_tokens_data    = input_tensors->getPtr<bool>("masked_tokens", nullptr);
-    const int*  cache_indir           = input_tensors->getPtr<int>("cache_indirection", nullptr);
+    const T*   input_query_data      = input_tensors->getPtr<T>("input_query");
+    const int* sequence_lengths_data = input_tensors->getPtr<int>("sequence_lengths");
+    // const int*  total_padding_len     = input_tensors->getPtr<int>("total_padding_tokens");
+    const bool* finished_data      = input_tensors->getPtr<bool>("finished", nullptr);
+    const bool* masked_tokens_data = input_tensors->getPtr<bool>("masked_tokens", nullptr);
+    const int*  cache_indir        = input_tensors->getPtr<int>("cache_indirection", nullptr);
 
     T*  hidden_features_data = output_tensors->getPtr<T>("attention_output");
     T** key_cache_ptrs       = output_tensors->getPtr<T*>("key_cache");
     T** value_cache_ptrs     = output_tensors->getPtr<T*>("value_cache");
 
+    int* cu_block_counts = input_tensors->at("cu_block_counts").getPtr<int>();
+
     const int layer_id = input_tensors->getVal<int>("layer_id");
 
     const int max_seq_len = input_tensors->getVal<int>("max_seq_len");
@@ -238,51 +241,78 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
     linear_.forward(qkv_buf_, input_query_data, batch_size, weights->qkv);
     POP_RANGE;
 
-    const auto kv_cache_layer_offset = layer_id * local_kv_head_num_ * max_seq_len * size_per_head_;
-    const int  memory_len            = max_seq_len;
-
-    fusedQKV_masked_attention_dispatch<T>(
-        qkv_buf_,
-        weights->qkv.bias,  // query_weight.bias,
-        nullptr,            // relative_attention_bias,
-        nullptr,
-        nullptr,
-        key_cache_ptrs,
-        value_cache_ptrs,
-        kv_cache_layer_offset,
-        cache_indir,
-        context_buf_,
-        finished_data,
-        sequence_lengths_data,  // NOTE: current seq len including padding (fixed after meeting the finished id)
-        batch_size,
-        batch_size,
-        beam_width,
-        local_head_num_,
-        local_kv_head_num_,
-        size_per_head_,
-        params_.rotray_embedding_dim,
-        params_.max_position_embeddings,
-        params_.use_dynamic_ntk,
-        params_.use_logn_attn,
-        memory_len,
-        nullptr,  // prefix_prompt_lengths
-        0,        // max_prefix_prompt_length
-        0,        // max_input_length, not used w/o linear_bias_slopes
-        input_tensors->getPtr<int>("total_padding_tokens", nullptr),
-        step,
-        1.f,                            // q_scaling
-        0,                              // relative_attention_bias_stride
-        nullptr,                        // linear_bias_slopes
-        nullptr,                        //  masked_tokens_data,
-        nullptr,                        // ia3_tasks
-        nullptr,                        // ia3_key_weights
-        nullptr,                        // ia3_value_weights
-        nullptr,                        // qkv_scale_out
-        nullptr,                        // attention_out_scale
-        quant_policy_,                  // int8_mode
-        weights->past_kv_scale.data(),  // attention kv scale
-        stream_);
-    sync_check_cuda_error();
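+    // offset of this layer's K/V slice within each cache block (see the [2, L, H, s, D] layout note in LlamaContextAttentionLayer)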
+    const auto layer_offset = layer_id * local_kv_head_num_ * kv_cache_block_len_ * size_per_head_;
+    // const int  memory_len   = max_seq_len;
+
+    DecoderMultiHeadAttentionParams<T> params{};
+
+    params.out    = context_buf_;
+    params.q      = qkv_buf_;
+    params.k      = params.q + local_head_num_ * size_per_head_;
+    params.v      = params.k + local_kv_head_num_ * size_per_head_;
+    params.stride = (local_head_num_ + 2 * local_kv_head_num_) * size_per_head_;
+
+    params.batch_size    = batch_size;
+    params.cu_block_cnts = cu_block_counts;  /// TODO
+
+    params.k_cache_block_ptrs  = (void**)key_cache_ptrs;
+    params.v_cache_block_ptrs  = (void**)value_cache_ptrs;
+    params.kv_cache_block_size = kv_cache_block_len_;
+
+    params.finished          = finished_data;
+    params.per_sample_length = sequence_lengths_data;
+
+    params.layer_offset = layer_offset;
+
+    params.num_heads     = local_head_num_;
+    params.num_kv_heads  = local_kv_head_num_;
+    params.size_per_head = size_per_head_;
+    params.inv_sqrt_dh   = 1.f / std::sqrt((float)params.size_per_head);
+
+    params.rotary_embedding_dim  = size_per_head_;
+    params.rotary_embedding_base = 10000.f;
+
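+    // run the new block-based decoder attention kernel; the original fusedQKV_masked_attention_dispatch call is kept below, commented out, for reference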
+    LaunchDecoderMultiheadAttention<T, 128>(params);
+
+    // fusedQKV_masked_attention_dispatch<T>(
+    //     qkv_buf_,
+    //     weights->qkv.bias,  // query_weight.bias,
+    //     nullptr,            // relative_attention_bias,
+    //     nullptr,
+    //     nullptr,
+    //     key_cache_ptrs,
+    //     value_cache_ptrs,
+    //     kv_cache_layer_offset,
+    //     cache_indir,
+    //     context_buf_,
+    //     finished_data,
+    //     sequence_lengths_data,  // NOTE: current seq len including padding (fixed after meeting the finished id)
+    //     batch_size,
+    //     batch_size,
+    //     beam_width,
+    //     local_head_num_,
+    //     local_kv_head_num_,
+    //     size_per_head_,
+    //     params_.rotray_embedding_dim,
+    //     params_.max_position_embeddings,
+    //     params_.use_dynamic_ntk,
+    //     params_.use_logn_attn,
+    //     memory_len,
+    //     nullptr,  // prefix_prompt_lengths
+    //     0,        // max_prefix_prompt_length
+    //     0,        // max_input_length, not used w/o linear_bias_slopes
+    //     input_tensors->getPtr<int>("total_padding_tokens", nullptr),
+    //     step,
+    //     1.f,                            // q_scaling
+    //     0,                              // relative_attention_bias_stride
+    //     nullptr,                        // linear_bias_slopes
+    //     nullptr,                        //  masked_tokens_data,
+    //     nullptr,                        // ia3_tasks
+    //     nullptr,                        // ia3_key_weights
+    //     nullptr,                        // ia3_value_weights
+    //     nullptr,                        // qkv_scale_out
+    //     nullptr,                        // attention_out_scale
+    //     quant_policy_,                  // int8_mode
+    //     weights->past_kv_scale.data(),  // attention kv scale
+    //     stream_);
+    // sync_check_cuda_error();
 
     linear_.forward(hidden_features_data, context_buf_, batch_size, weights->output);
 
diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h
index 89afe3f964..ac1c02caac 100644
--- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h
+++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h
@@ -56,6 +56,7 @@ class LlamaDecoderSelfAttentionLayer {
         stream_(stream),
         linear_(cublas_wrapper, stream),
         allocator_(allocator),
+        kv_cache_block_len_(128),  ///
         is_free_buffer_after_forward_(is_free_buffer_after_forward),
         quant_policy_(quant_policy)
     {
@@ -76,6 +77,7 @@ class LlamaDecoderSelfAttentionLayer {
     const size_t local_head_num_;
     const size_t local_kv_head_num_;
     const size_t local_hidden_units_;
+    const size_t kv_cache_block_len_;
     const bool   is_free_buffer_after_forward_;
     const int    quant_policy_;
 
diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc
index cc8f64f7fa..3532bf4216 100644
--- a/src/turbomind/models/llama/LlamaV2.cc
+++ b/src/turbomind/models/llama/LlamaV2.cc
@@ -112,15 +112,6 @@ LlamaV2<T>::LlamaV2(size_t                       head_num,
 
     const size_t local_kv_head_num = kv_head_num / tensor_para.world_size_;
 
-    // kv_cache_mgr_     = std::make_unique<LlamaCacheManager>(num_layer_,
-    //                                                     local_kv_head_num,
-    //                                                     size_per_head_,
-    //                                                     session_len,
-    //                                                     elem_bits,
-    //                                                     cache_max_entry_count,
-    //                                                     cache_chunk_size,
-    //                                                     tensor_para.rank_,
-    //                                                     allocator);
     auto sequence_manager = std::make_unique<SequenceManager>(num_layer,
                                                               local_kv_head_num,
                                                               size_per_head_,
@@ -142,7 +133,6 @@ LlamaV2<T>::LlamaV2(size_t                       head_num,
 template<typename T>
 LlamaV2<T>::~LlamaV2()
 {
-
     delete decoder_;
     delete dynamic_decode_layer_;
     delete context_decoder_;
@@ -276,8 +266,8 @@ void LlamaV2<T>::contextDecode(T*         deocder_output,
         {"decoder_output", {MEMORY_GPU, dtype, {bsz, max_input_len, hidden_units_}, context_decoder_output_buf}},
         {"key_cache", {MEMORY_GPU, TYPE_UINT64, {bsz}, k_cache_ptr}},
         {"value_cache", {MEMORY_GPU, TYPE_UINT64, {bsz}, v_cache_ptr}},
-        {"tmp_k_ptrs", {MEMORY_GPU, TYPE_UINT64, {bsz}, tmp_k_ptrs}},
-        {"tmp_v_ptrs", {MEMORY_GPU, TYPE_UINT64, {bsz}, tmp_v_ptrs}},
+        {"tmp_k", {MEMORY_GPU, TYPE_UINT64, {bsz}, tmp_k_ptrs}},
+        {"tmp_v", {MEMORY_GPU, TYPE_UINT64, {bsz}, tmp_v_ptrs}},
         {"last_token_hidden_units", {MEMORY_GPU, dtype, {bsz, hidden_units_}, deocder_output}}};
 
     context_decoder_->forward(&decoder_output_tensors, &decoder_input_tensors, &weights_->decoder_layer_weights);
diff --git a/src/turbomind/models/llama/LlamaWeight.cc b/src/turbomind/models/llama/LlamaWeight.cc
index 511cbe5bbf..e1287f471b 100644
--- a/src/turbomind/models/llama/LlamaWeight.cc
+++ b/src/turbomind/models/llama/LlamaWeight.cc
@@ -72,6 +72,10 @@ LlamaWeight<T>::~LlamaWeight()
 
     pre_decoder_embedding_table   = nullptr;
     post_decoder_embedding_kernel = nullptr;
+
+    for (auto& p : decoder_layer_weights) {
+        delete p;
+    }
 }
 
 template<typename T>
@@ -95,8 +99,10 @@ void LlamaWeight<T>::loadModel(std::string dir_path)
 
     loadWeightFromBin((T*)output_norm_weight, {hidden_units_}, dir_path + "norm.weight", model_file_type);
 
-    loadWeightFromBin(
-        (T*)post_decoder_embedding_kernel, {hidden_units_ * vocab_size_padded_}, dir_path + "output.weight", model_file_type);
+    loadWeightFromBin((T*)post_decoder_embedding_kernel,
+                      {hidden_units_ * vocab_size_padded_},
+                      dir_path + "output.weight",
+                      model_file_type);
 
     for (unsigned layer = 0; layer < num_layer_; ++layer) {
         decoder_layer_weights[layer]->loadModel(dir_path + "layers." + std::to_string(layer), model_file_type);
diff --git a/src/turbomind/models/llama/Request.h b/src/turbomind/models/llama/Request.h
index 46badf98a5..9e5e5dbea9 100644
--- a/src/turbomind/models/llama/Request.h
+++ b/src/turbomind/models/llama/Request.h
@@ -89,8 +89,11 @@ class RequestQueue {
 
     void Abort()
     {
-        std::lock_guard<std::mutex> lock(mutex_);
-        abort_ = true;
+        {
+            std::lock_guard<std::mutex> lock(mutex_);
+            abort_ = true;
+        }
+        cv_.notify_all();
     }
 
 private:
diff --git a/src/turbomind/models/llama/SequenceManager.cc b/src/turbomind/models/llama/SequenceManager.cc
index 808dcfd754..34aa240524 100644
--- a/src/turbomind/models/llama/SequenceManager.cc
+++ b/src/turbomind/models/llama/SequenceManager.cc
@@ -16,7 +16,7 @@ SequenceManager::SequenceManager(size_t      layer_num,
                                  size_t      elem_bits,
                                  int         rank,
                                  IAllocator* allocator):
-    block_len_(block_len), rank_(rank)
+    block_len_(block_len)
 {
     constexpr int kBitsPerByte = 8;
 
@@ -37,8 +37,11 @@ const Sequence* SequenceManager::Create(uint64_t id)
         if (rank_ == 0) {
             TM_LOG_WARNING("[SequenceManager][Create] Removing conflicting ID %ld", (long)id);
         }
-        block_manager_->Release(it->second.blocks);
-        it->second = std::move(sequence);
+        auto& seq = it->second;
+        if (seq.status != Sequence::kCached) {
+            released_.insert(released_.end(), seq.blocks.begin(), seq.blocks.end());
+        }
+        seq = std::move(sequence);
     }
     else {
         it = sequences_.emplace_hint(it, id, std::move(sequence));
@@ -47,58 +50,60 @@ const Sequence* SequenceManager::Create(uint64_t id)
     return &it->second;
 }
 
-void SequenceManager::VerifyBlocks(Sequence& seq)
-{
-    FT_CHECK(seq.blocks.size() == seq.block_unique_ids.size());
-    for (int i = 0; i < seq.blocks.size(); ++i) {
-        if (seq.blocks[i]->unique_id != seq.block_unique_ids[i]) {
-            seq.blocks.resize(i);
-            seq.block_unique_ids.resize(i);
-            break;
-        }
-    }
-    seq.cache_len = std::min<int>(seq.cache_len, seq.blocks.size() * block_len_);
-}
-
-const Sequence* SequenceManager::Fetch(uint64_t id)
+const Sequence* SequenceManager::Get(uint64_t id)
 {
     if (auto it = sequences_.find(id); it != sequences_.end()) {
         auto& sequence = it->second;
         return &it->second;
     }
-
     return nullptr;
 }
 
+bool SequenceManager::Contains(uint64_t id)
+{
+    return sequences_.find(id) != sequences_.end();
+}
+
 bool SequenceManager::Erase(uint64_t id)
 {
     if (auto it = sequences_.find(id); it != sequences_.end()) {
         auto& seq = it->second;
         if (seq.status != Sequence::kCached) {
-            if (released_.empty()) {
-                released_ = std::move(seq.blocks);
-            }
-            else {
-                released_.insert(released_.end(), seq.blocks.begin(), seq.blocks.end());
-            }
+            released_.insert(released_.end(), seq.blocks.begin(), seq.blocks.end());
         }
         sequences_.erase(it);
     }
     else {
         throw std::out_of_range(std::to_string(id));
     }
-
     return false;
 }
 
-void SequenceManager::Update(const Sequence& sequence)
+void SequenceManager::Verify(Sequence& seq, std::vector<const Block*>& retain)
 {
-    block_manager_->Touch(sequence.blocks);
+    FT_CHECK(seq.blocks.size() == seq.block_unique_ids.size());
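+    // truncate at the first unique-id mismatch; the surviving prefix is still valid and gets re-locked below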
+    for (int i = 0; i < seq.blocks.size(); ++i) {
+        if (seq.blocks[i]->unique_id != seq.block_unique_ids[i]) {
+            seq.blocks.resize(i);
+            seq.block_unique_ids.resize(i);
+            break;
+        }
+    }
+    retain.insert(retain.end(), seq.blocks.begin(), seq.blocks.end());
+    seq.status    = Sequence::kLocked;
+    seq.cache_len = std::min<int>(seq.cache_len, seq.blocks.size() * block_len_);
 }
 
-bool SequenceManager::Contains(uint64_t id)
+void SequenceManager::Release(const Sequence& sequence)
 {
-    return sequences_.find(id) != sequences_.end();
+    auto& seq = const_cast<Sequence&>(sequence);
+    if (seq.status == Sequence::kActive) {
+        block_manager_->Touch(seq.blocks);
+    }
+    if (seq.status != Sequence::kCached) {
+        released_.insert(released_.end(), seq.blocks.begin(), seq.blocks.end());
+    }
+    seq.status = Sequence::kCached;
 }
 
 namespace {
@@ -132,9 +137,9 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
 
 std::ostream& operator<<(std::ostream& os, const Schedule& s)
 {
-    os << "Schedule { free=" << s.free << ", cached=" << s.cached << ", allocate=" << s.allocate
-       << ", evict=" << s.evict << ", preempt=" << s.preempt << ", active=" << s.active << ", victims=" << s.victims
-       << ", block_counts=" << s.block_counts << ", inactive=" << s.inactive << " }";
+    os << "free=" << s.free << ", cached=" << s.cached << ", allocate=" << s.allocate << ", evict=" << s.evict
+       << ", preempt=" << s.preempt << ", active=" << s.active << ", victims=" << s.victims
+       << ", block_counts=" << s.block_counts << ", inactive=" << s.inactive;
     return os;
 }
 
@@ -235,9 +240,8 @@ struct Transaction {
 
 std::ostream& operator<<(std::ostream& os, const Transaction& trans)
 {
-    os << "Transaction { index=" << trans.index_ << ", block_count=" << trans.block_count_
-       << ", allocate=" << trans.allocate_ << ", evict=" << trans.evict_ << ", preempt=" << trans.preempt_
-       << ", victims=" << trans.victims_ << " }";
+    os << "index=" << trans.index_ << ", block_count=" << trans.block_count_ << ", allocate=" << trans.allocate_
+       << ", evict=" << trans.evict_ << ", preempt=" << trans.preempt_ << ", victims=" << trans.victims_;
     return os;
 }
 
@@ -245,8 +249,8 @@ std::ostream& operator<<(std::ostream& os, const Transaction& trans)
 
 std::ostream& operator<<(std::ostream& os, const Sequence& seq)
 {
-    os << "Sequence { id=" << seq.id << ", status=" << seq.status << ", size(blocks)=" << seq.blocks.size()
-       << ", cache_len=" << seq.cache_len << ", size(random_state)=" << seq.random_state.size() << " }";
+    os << "id=" << seq.id << ", status=" << seq.status << ", size(blocks)=" << seq.blocks.size()
+       << ", cache_len=" << seq.cache_len << ", size(random_state)=" << seq.random_state.size();
     return os;
 }
 
@@ -260,14 +264,21 @@ auto SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
     auto    seqs = const_cast<Sequence* const*>(sequences.data());
     Outcome outcome{};
 
+    if (!released_.empty()) {
+        block_manager_->Release(released_);
+        released_.clear();
+    }
+
     // check validity of cached blocks (blocks of active & locked seqs are always valid)
     if (need_verification_) {
+        need_verification_ = false;
+        std::vector<const Block*> retain;
         for (int i = 0; i < sequences.size(); ++i) {
             if (seqs[i]->status == Sequence::kCached) {
-                VerifyBlocks(*seqs[i]);
+                Verify(*seqs[i], retain);
             }
         }
-        need_verification_ = false;
+        block_manager_->Retain(retain);
     }
 
     // count required blocks based on block validity
@@ -280,7 +291,7 @@ auto SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
         total_required += required[i];
     }
 
-    dbg(required);
+    // dbg(required);
 
     // no new blocks required, exit early
     if (total_required == 0) {
@@ -325,13 +336,14 @@ auto SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
             if (sequences[idxs[v]]->status == Sequence::kCached) {
                 continue;
             }
+            dbg(v, idxs[v]);
             int preempt = trans.Preempt(v, idxs[v]);
             dbg(preempt);
             // Commit only when preemption actually free enough blocks for the sequence to run
             if (block_count <= preempt) {
                 // preempted blocks are in cached state
                 block_count -= trans.Evict(block_count);
-                j = v + 1;
+                j = v;
                 break;
             }
         }
@@ -340,7 +352,7 @@ auto SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
 
         if (block_count == 0) {
             trans.Commit();
-            active[i] = 1;
+            active[idx] = 1;
             if (seq.status != Sequence::kActive) {
                 ++outcome.swap_in;
             }
@@ -365,25 +377,20 @@ auto SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
     outcome.allocation = schedule.allocate;
 
     // release preempted blocks -> cached
-    {
-        std::vector<const Block*> blocks;
-        for (const auto& v : schedule.victims) {
-            auto& seq = *seqs[v];
-            block_manager_->Touch(seq.blocks);
-            seq.status = Sequence::kCached;
-            blocks.insert(blocks.end(), seq.blocks.begin(), seq.blocks.end());
-        }
-        block_manager_->Release(blocks);
+    for (const auto& v : schedule.victims) {
+        Release(*sequences[v]);
     }
+    block_manager_->Release(released_);
+    released_.clear();
 
     // evict cached blocks -> free
     if (schedule.evict) {
-        need_verification_ = true;
         block_manager_->Evict(schedule.evict);
+        need_verification_ = true;
     }
 
     // allocate & assign blocks
-    auto blocks = block_manager_->Allocate(schedule.allocate + schedule.evict);
+    auto blocks = block_manager_->Allocate(schedule.allocate);
     auto first  = blocks.begin();
 
     for (const auto& idx : schedule.active) {
@@ -391,7 +398,7 @@ auto SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
         sequence.status = Sequence::kActive;
 
         auto last = first + required[idx];
-        std::for_each(first, last, [&sequence](const Block* b) {
+        std::for_each(first, last, [&](const Block* b) {
             sequence.blocks.push_back(b);
             sequence.block_unique_ids.push_back(b->unique_id);
         });
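
The Materialize() rework above reorders the block-management steps: blocks released since the last call are returned first, cached sequences are verified and only their still-valid blocks retained, preemption victims are released through the new Release() helper, evicted blocks are moved to the free list, and allocation then requests only `schedule.allocate` blocks (the former `+ schedule.evict` term is gone because evicted blocks are already back in the free pool). A minimal sketch of that ordering, using hypothetical stand-in types rather than the real BlockManager/SequenceManager API:

    // materialize_order_sketch.cc -- hypothetical stand-ins, not the real SequenceManager API
    #include <cstdio>
    #include <vector>

    struct Block { int id; };

    struct FakeBlockManager {
        void Release(const std::vector<const Block*>& b) { std::printf("Release  %zu blocks\n", b.size()); }
        void Retain (const std::vector<const Block*>& b) { std::printf("Retain   %zu blocks\n", b.size()); }
        void Evict  (int n)                              { std::printf("Evict    %d blocks\n", n); }
        std::vector<const Block*> Allocate(int n)        { std::printf("Allocate %d blocks\n", n); return {}; }
    };

    int main() {
        FakeBlockManager block_manager;
        std::vector<const Block*> released;  // blocks released by callers since the last Materialize()
        std::vector<const Block*> retain;    // cached blocks whose unique ids still match after Verify()

        block_manager.Release(released);         // 1. return previously released blocks to the cached pool
        block_manager.Retain(retain);            // 2. keep only the verified blocks of cached sequences
        block_manager.Release({});               // 3. release the blocks of preemption victims
        block_manager.Evict(2);                  // 4. schedule.evict: cached -> free
        auto fresh = block_manager.Allocate(4);  // 5. schedule.allocate only; evicted blocks are already free
        (void)fresh;
        return 0;
    }
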
diff --git a/src/turbomind/models/llama/SequenceManager.h b/src/turbomind/models/llama/SequenceManager.h
index 1b3b680784..5fdbf2054d 100644
--- a/src/turbomind/models/llama/SequenceManager.h
+++ b/src/turbomind/models/llama/SequenceManager.h
@@ -4,8 +4,6 @@
 
 namespace turbomind {
 
-// |<-- active -->|<-- pending -->|<-- new -->|
-
 struct Sequence {
 
     enum Status {
@@ -32,9 +30,6 @@ struct Sequence {
 
 class SequenceManager {
 public:
-    // allocate slack blocks to reduce block manager overhead
-    static constexpr int kSlackBlockNum = 1;
-
     explicit SequenceManager(size_t      layer_num,
                              size_t      head_num,
                              size_t      head_dim,
@@ -50,13 +45,13 @@ class SequenceManager {
 
     const Sequence* Create(uint64_t id);
 
-    const Sequence* Fetch(uint64_t id);
+    const Sequence* Get(uint64_t id);
 
-    void Update(const Sequence& seq);
+    bool Contains(uint64_t id);
 
     bool Erase(uint64_t id);
 
-    bool Contains(uint64_t id);
+    void Release(const Sequence& seq);
 
     struct Outcome {
         int allocation;
@@ -84,14 +79,8 @@ class SequenceManager {
         return block_manager_->max_block_count();
     }
 
-    friend std::ostream& operator<<(std::ostream& os, const Outcome& oc)
-    {
-        os << "allocation: " << oc.allocation << ", swap-in: " << oc.swap_in << ", swap-out: " << oc.swap_out;
-        return os;
-    }
-
 private:
-    void VerifyBlocks(Sequence& seq);
+    void Verify(Sequence& seq, std::vector<const Block*>& retain);
 
 private:
     int    block_len_;
@@ -108,4 +97,10 @@ class SequenceManager {
     std::vector<const Block*> released_;
 };
 
+inline std::ostream& operator<<(std::ostream& os, const SequenceManager::Outcome& oc)
+{
+    os << "allocation: " << oc.allocation << ", swap-in: " << oc.swap_in << ", swap-out: " << oc.swap_out;
+    return os;
+}
+
 }  // namespace turbomind
\ No newline at end of file
diff --git a/src/turbomind/models/llama/llama_kernels.cu b/src/turbomind/models/llama/llama_kernels.cu
index 66c1202918..822fd1bc0d 100644
--- a/src/turbomind/models/llama/llama_kernels.cu
+++ b/src/turbomind/models/llama/llama_kernels.cu
@@ -233,18 +233,17 @@ __global__ void extend_kv_cache(T**          k_dst_ptrs,
     const auto k_val_src = reinterpret_cast<const uint4*>(k_src);
     const auto v_val_src = reinterpret_cast<const uint4*>(v_src);
 
-    // const auto val_dst = reinterpret_cast<uint4*>(v_dst[batch_id] + dst_layer_offset);
     const auto k_val_dst = (uint4*)((k_dst_ptrs + cu_block_cnt)[cache_block_index] + dst_layer_offset);
     const auto v_val_dst = (uint4*)((v_dst_ptrs + cu_block_cnt)[cache_block_index] + dst_layer_offset);
 
     if (seq_len_id < query_len) {
         // [B, H, s, D/x] -> [H, S[t:t+s], D/x]
-        const int64_t dst_idx = head_id * size_per_head_div_x * block_length +  // H
+        const int64_t dst_idx = head_id * block_length * size_per_head_div_x +  // H
                                 cache_block_offset * size_per_head_div_x +      // s + offset
                                 head_size_id;                                   // D/x
 
-        const int64_t src_idx = batch_id * head_num * size_per_head_div_x * max_q_len +  // B
-                                head_id * size_per_head_div_x * max_q_len +              // H
+        const int64_t src_idx = batch_id * head_num * max_q_len * size_per_head_div_x +  // B
+                                head_id * max_q_len * size_per_head_div_x +              // H
                                 seq_len_id * size_per_head_div_x +                       // s
                                 head_size_id;                                            // D/x
 
diff --git a/src/turbomind/models/llama/test_cache_manager.cc b/src/turbomind/models/llama/test_cache_manager.cc
index c306b1e7cc..184f1f2bf8 100644
--- a/src/turbomind/models/llama/test_cache_manager.cc
+++ b/src/turbomind/models/llama/test_cache_manager.cc
@@ -60,7 +60,7 @@ TEST_CASE("BlockManager")
     REQUIRE(m.cached_count() == 16);
 }
 
-TEST_CASE("SequenceManager")
+TEST_CASE("SequenceManager basic test")
 {
     Allocator<AllocatorType::CUDA> allocator(0);
 
@@ -86,9 +86,27 @@ TEST_CASE("SequenceManager")
     auto s2 = manager.Create(2);
     REQUIRE(manager.Contains(2));
 
-    outcome = manager.Materialize({s1, s2}, {128, 2560-1}, {2, 1}, 1);
+    outcome = manager.Materialize({s1, s2}, {128, 2559}, {2, 1}, 1);
     dbg(outcome);
+    REQUIRE(outcome.allocation == 20);
+    REQUIRE(outcome.swap_in == 1);
+    REQUIRE(outcome.swap_out == 1);
 
-    // outcome = manager.Materialize({s1, s2}, {128, 12800}, {1, 2}, 1);
-    // dbg(outcome);
+    auto s3 = manager.Create(3);
+    outcome = manager.Materialize({s1, s2, s3}, {127, 2559, 255}, {1, 100, 2}, 1);
+    dbg(outcome);
+}
+
+TEST_CASE("SequenceManager functional test")
+{
+    Allocator<AllocatorType::CUDA> allocator(0);
+    SequenceManager                manager(32, 32, 128, 128, 20, 4, 16, 0, &allocator);
+
+    auto seq = manager.Create(1);
+    for (int i = 0; i < 1024; ++i) {
+        auto outcome = manager.Materialize({seq}, {i}, {0}, 1);
+        if (outcome.allocation) {
+            dbg(i, outcome);
+        }
+    }
 }
\ No newline at end of file
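
The expected numbers in the basic test follow from block arithmetic, assuming the 128-token block length used elsewhere in this series: 2559 context tokens need ceil(2559 / 128) = 20 blocks, so if the pool created earlier in the test holds 20 blocks, s2 fills it completely (allocation == 20) and s1 has to be preempted, giving one swap-out and one swap-in. A compile-time sanity check of the block count:

    // block_count_check.cc -- arithmetic behind the REQUIREs above (block_len of 128 is an assumption)
    int main() {
        constexpr int block_len = 128;
        constexpr int context   = 2559;
        static_assert((context + block_len - 1) / block_len == 20, "2559 tokens fit in exactly 20 blocks");
        return 0;
    }
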
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
index 169d6cbdba..0f552f6063 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -248,13 +248,13 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh
                                                   cuda_device_prop_ptr.get());
 
     return std::make_unique<LlamaTritonSharedModelInstance<T>>(
-        LlamaTritonSharedModelInstance<T>{std::move(llama),
-                                          shared_weights_[device_id],
-                                          std::move(allocator),
+        LlamaTritonSharedModelInstance<T>{std::move(allocator),
                                           std::move(cublas_algo_map),
                                           std::move(cublas_wrapper_mutex),
                                           std::move(cublas_wrapper),
                                           std::move(cuda_device_prop_ptr),
+                                          shared_weights_[device_id],
+                                          std::move(llama),
                                           session_len_});
 }
 
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.h b/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.h
index 1713d96bef..4dff6eb24c 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.h
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.h
@@ -29,13 +29,13 @@ namespace ft = turbomind;
 
 template<typename T>
 struct LlamaTritonSharedModelInstance {
-    std::unique_ptr<ft::LlamaV2<T>>                         llm;
-    std::shared_ptr<ft::LlamaWeight<T>>                     llm_weight;
     std::unique_ptr<ft::Allocator<ft::AllocatorType::CUDA>> allocator;
     std::unique_ptr<ft::cublasAlgoMap>                      cublas_algo_map;
     std::unique_ptr<std::mutex>                             cublas_wrapper_mutex;
     std::unique_ptr<ft::cublasMMWrapper>                    cublas_wrapper;
     std::unique_ptr<cudaDeviceProp>                         cuda_device_prop_ptr;
+    std::shared_ptr<ft::LlamaWeight<T>>                     llm_weight;
+    std::unique_ptr<ft::LlamaV2<T>>                         llm;
     const int                                               session_len;
 };
 
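
The member reordering in LlamaTritonSharedModelInstance (and the matching reorder of the brace initializer in LlamaTritonModel.cc, since aggregate initialization follows declaration order) relies on the rule that non-static data members are destroyed in reverse declaration order: declaring `llm` last guarantees it is destroyed before the allocator and cuBLAS objects it may still be using. A minimal standalone illustration of that rule, with stand-in types:

    // destruction_order_demo.cc -- minimal illustration of why `llm` is now declared last
    #include <cstdio>
    #include <memory>

    struct Allocator { ~Allocator() { std::puts("~Allocator"); } };
    struct Model     { ~Model()     { std::puts("~Model"); } };

    struct Instance {
        std::unique_ptr<Allocator> allocator;  // declared first -> destroyed last
        std::unique_ptr<Model>     model;      // declared last  -> destroyed first
    };

    int main() {
        Instance inst{std::make_unique<Allocator>(), std::make_unique<Model>()};
        // prints "~Model" then "~Allocator": the model can still rely on the allocator
        // while it is being torn down, mirroring LlamaTritonSharedModelInstance.
        return 0;
    }
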
diff --git a/src/turbomind/triton_backend/transformer_triton_backend.hpp b/src/turbomind/triton_backend/transformer_triton_backend.hpp
index 4026048e31..483651b8db 100644
--- a/src/turbomind/triton_backend/transformer_triton_backend.hpp
+++ b/src/turbomind/triton_backend/transformer_triton_backend.hpp
@@ -271,6 +271,8 @@ struct AbstractTransformerModel;
 struct AbstractTransformerModelInstance;
 
 struct AbstractTransformerModelInstance {
+    virtual ~AbstractTransformerModelInstance() = default;
+
     virtual std::shared_ptr<std::vector<triton::Tensor>>
     forward(std::shared_ptr<std::vector<triton::Tensor>> input_tensors) = 0;
 
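
Adding `virtual ~AbstractTransformerModelInstance() = default` makes it safe to destroy concrete model instances through a pointer to the abstract base (for example via `std::unique_ptr` or a plain `delete`); without it, such a deletion is undefined behavior and the derived destructor is not guaranteed to run. A small standalone example of the pattern:

    // virtual_dtor_demo.cc -- why the abstract base now declares a virtual destructor
    #include <cstdio>
    #include <memory>

    struct Base {
        virtual ~Base() = default;   // without this, deleting through Base* is undefined behavior
        virtual void forward() = 0;
    };

    struct Derived : Base {
        ~Derived() override { std::puts("~Derived: resources released"); }
        void forward() override {}
    };

    int main() {
        std::unique_ptr<Base> p = std::make_unique<Derived>();
        p->forward();
        return 0;   // unique_ptr<Base> deletes via Base*; the virtual destructor ensures ~Derived runs
    }
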
diff --git a/src/turbomind/utils/allocator.h b/src/turbomind/utils/allocator.h
index a87efcd73b..1cebb33a00 100644
--- a/src/turbomind/utils/allocator.h
+++ b/src/turbomind/utils/allocator.h
@@ -125,9 +125,15 @@ class Allocator;
 template<>
 class Allocator<AllocatorType::CUDA>: public IAllocator {
 private:
-    const int                          device_id_;
-    cudaStream_t                       stream_ = 0;  // initialize as default stream
-    std::unordered_map<void*, size_t>* pointer_mapping_;
+    enum class MemoryType
+    {
+        HOST,
+        DEVICE
+    };
+
+    const int                                                 device_id_;
+    cudaStream_t                                              stream_ = 0;  // initialize as default stream
+    std::unordered_map<void*, std::pair<size_t, MemoryType>>* pointer_mapping_;
 
     bool isExist(void* address) const
     {
@@ -136,10 +142,10 @@ class Allocator<AllocatorType::CUDA>: public IAllocator {
     ReallocType isReMalloc(void* address, size_t size) const
     {
         FT_CHECK(isExist(address));
-        if (pointer_mapping_->at(address) < size) {
+        if (pointer_mapping_->at(address).first < size) {
             return ReallocType::INCREASE;
         }
-        else if (pointer_mapping_->at(address) == size) {
+        else if (pointer_mapping_->at(address).first == size) {
             return ReallocType::REUSE;
         }
         else {
@@ -151,7 +157,7 @@ class Allocator<AllocatorType::CUDA>: public IAllocator {
     Allocator(int device_id): device_id_(device_id)
     {
         TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-        pointer_mapping_ = new std::unordered_map<void*, size_t>();
+        pointer_mapping_ = new std::unordered_map<void*, std::pair<size_t, MemoryType>>();
 #if defined(CUDA_MEMORY_POOL_DISABLED)
         TM_LOG_WARNING(
             "Async cudaMalloc/Free is not supported before CUDA 11.2. Using Sync cudaMalloc/Free."
@@ -188,7 +194,9 @@ class Allocator<AllocatorType::CUDA>: public IAllocator {
     {
         TM_LOG_DEBUG(__PRETTY_FUNCTION__);
         while (!pointer_mapping_->empty()) {
-            free((void**)(&pointer_mapping_->begin()->first));
+            auto ptr           = pointer_mapping_->begin()->first;
+            auto size_and_type = pointer_mapping_->begin()->second;
+            free(&ptr, size_and_type.second == MemoryType::HOST);
         }
         delete pointer_mapping_;
     }
@@ -229,18 +237,19 @@ class Allocator<AllocatorType::CUDA>: public IAllocator {
         check_cuda_error(getSetDevice(o_device));
         TM_LOG_DEBUG("malloc buffer %p with size %ld", ptr, size);
 
-        pointer_mapping_->insert({getAddress(ptr), size});
+        pointer_mapping_->insert({getAddress(ptr), {size, is_host ? MemoryType::HOST : MemoryType::DEVICE}});
 
         return ptr;
     }
 
-    void free(void** ptr, bool is_host = false) const
+    void free(void** ptr, bool _ = false) const
     {
         TM_LOG_DEBUG(__PRETTY_FUNCTION__);
         void* address = getAddress(*ptr);
         if (*ptr != nullptr) {
             int o_device = 0;
             if (pointer_mapping_->count(address)) {
+                const auto is_host = pointer_mapping_->at(address).second == MemoryType::HOST;
                 TM_LOG_DEBUG("Free buffer %p", address);
                 check_cuda_error(getSetDevice(device_id_, &o_device));
                 if (is_host) {
@@ -361,7 +370,7 @@ class Allocator<AllocatorType::TF>: public IAllocator {
     {
         while (!pointer_mapping_->empty()) {
             void* ptr = pointer_mapping_->begin()->second.flat<uint8>().data();
-            free((void**)(&ptr));
+            free(&ptr);
         }
         pointer_mapping_->clear();
         delete pointer_mapping_;
@@ -454,7 +463,7 @@ class Allocator<AllocatorType::TH>: public IAllocator {
         TM_LOG_DEBUG(__PRETTY_FUNCTION__);
         while (!pointer_mapping_->empty()) {
             void* ptr = pointer_mapping_->begin()->second.data_ptr();
-            free((void**)(&ptr));
+            free(&ptr);
         }
         pointer_mapping_->clear();
         delete pointer_mapping_;
@@ -466,4 +475,4 @@ class Allocator<AllocatorType::TH>: public IAllocator {
     }
 };
 #endif
-}  // namespace turbomind
+}  // namespace turbomind
\ No newline at end of file
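
The allocator change above stores the memory type next to the allocation size, so `free()` can look up whether a pointer was a host or a device allocation instead of trusting the caller, and the destructor can drain whatever is left in the map through the right deallocation path. A simplified, hypothetical sketch of that bookkeeping pattern (plain `std::malloc`/`std::free` stand in for the CUDA calls):

    // tracked_free_sketch.cc -- simplified, hypothetical version of the bookkeeping pattern
    #include <cstddef>
    #include <cstdlib>
    #include <unordered_map>
    #include <utility>

    enum class MemoryType { HOST, DEVICE };

    class TrackedAllocator {
        std::unordered_map<void*, std::pair<std::size_t, MemoryType>> mapping_;

    public:
        void* malloc(std::size_t size, bool is_host) {
            void* p = std::malloc(size);  // stand-in for cudaMallocHost / cudaMallocAsync
            mapping_.insert({p, {size, is_host ? MemoryType::HOST : MemoryType::DEVICE}});
            return p;
        }

        void free(void* p) {
            auto it = mapping_.find(p);
            if (it == mapping_.end()) {
                return;
            }
            // the real allocator picks cudaFreeHost vs cudaFreeAsync based on this flag
            const bool is_host = (it->second.second == MemoryType::HOST);
            (void)is_host;
            std::free(p);
            mapping_.erase(it);
        }

        ~TrackedAllocator() {
            // drain remaining allocations; the type is looked up, callers no longer pass it
            while (!mapping_.empty()) {
                free(mapping_.begin()->first);
            }
        }
    };

    int main() {
        TrackedAllocator alloc;
        void* host_buf = alloc.malloc(1024, /*is_host=*/true);
        alloc.free(host_buf);
        return 0;
    }
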
diff --git a/src/turbomind/utils/cuda_utils.h b/src/turbomind/utils/cuda_utils.h
index be0b85d69a..f066a0c25b 100644
--- a/src/turbomind/utils/cuda_utils.h
+++ b/src/turbomind/utils/cuda_utils.h
@@ -131,7 +131,7 @@ void check(T result, char const* const func, const char* const file, int const l
 inline void syncAndCheck(const char* const file, int const line)
 {
     // When FT_DEBUG_LEVEL=DEBUG, must check error
-    static char* level_name = std::getenv("FT_DEBUG_LEVEL");
+    static char* level_name = std::getenv("TM_DEBUG_LEVEL");
     if (level_name != nullptr) {
         static std::string level = std::string(level_name);
         if (level == "DEBUG") {

From ac8a50bdb5ba2c074dbbd2ad9bdfff590d4b7106 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Mon, 9 Oct 2023 06:06:09 +0000
Subject: [PATCH 07/56] update

---
 examples/cpp/llama/llama_triton_example.cc    |  31 +--
 .../decoder_multihead_attention_template.h    |   5 +
 .../kernels/unfused_attention_kernels.cu      |  10 +-
 .../kernels/unfused_attention_kernels.h       |   2 +-
 src/turbomind/models/llama/LlamaBatch.cc      | 232 ++++++++----------
 src/turbomind/models/llama/LlamaBatch.h       |  66 ++---
 .../llama/LlamaContextAttentionLayer.cc       |   5 +-
 .../models/llama/LlamaContextDecoder.cc       |   5 -
 .../models/llama/LlamaContextDecoder.h        |   1 -
 src/turbomind/models/llama/LlamaDecoder.cc    |   3 +
 .../llama/LlamaDecoderSelfAttentionLayer.cc   | 192 +--------------
 .../llama/LlamaDecoderSelfAttentionLayer.h    |   2 +-
 src/turbomind/models/llama/LlamaV2.cc         |   3 -
 src/turbomind/models/llama/LlamaV2.h          |   1 -
 src/turbomind/models/llama/llama_kernels.cu   |   8 +-
 src/turbomind/models/llama/llama_kernels.h    |   2 +-
 16 files changed, 187 insertions(+), 381 deletions(-)

diff --git a/examples/cpp/llama/llama_triton_example.cc b/examples/cpp/llama/llama_triton_example.cc
index 07e88a508d..51d5a2182a 100644
--- a/examples/cpp/llama/llama_triton_example.cc
+++ b/examples/cpp/llama/llama_triton_example.cc
@@ -433,6 +433,8 @@ int main(int argc, char* argv[])
     const int  beam_width   = output_tensors_lists[0].get()->at("output_ids").shape[1];
     const int  seq_len      = output_tensors_lists[0].get()->at("output_ids").shape[2];
 
+    ft::FT_CHECK(beam_width == 1);
+
     std::vector<int> seq_lens(batch_size);
     // step 6: check results
     if (node_id == 0) {
@@ -442,32 +444,25 @@ int main(int argc, char* argv[])
             printf("[WARNING] Cannot write results into output file %s \n", fName.c_str());
         }
         else {
-            size_t outCount = batch_size * beam_width * seq_len;
-            // int*   hBuf     = new int[outCount];
+            const size_t outCount = batch_size * beam_width * seq_len;
+
             std::vector<int> hBuf(outCount);
+
             ft::cudaD2Hcpy(hBuf.data(), d_output_ids, outCount);
             ft::cudaD2Hcpy(seq_lens.data(), d_seq_lens, batch_size);
+
             std::cout << "sequence length: ";
             for (int i = 0; i < batch_size; ++i) {
                 std::cout << (i ? ", " : "") << seq_lens[i];
             }
             std::cout << "\n";
-            {
-                std::cout << "Writing " << outCount << " elements\n";
-                int zeroCount = 0;
-                for (size_t i = 0; i < outCount; i++) {
-                    if (hBuf[i] == int(0))
-                        zeroCount++;
-                    outFile << hBuf[i] << " ";
-                    if ((i + 1) % (seq_len) == 0)
-                        outFile << std::endl;
-
-                    if (i < 10)
-                        printf("%5d ", hBuf[i]);
-                    if ((i + 1) % (seq_len) == 0 && i < 10)
-                        std::cout << std::endl;
+
+            for (int i = 0; i < batch_size; ++i) {
+                outFile << (i ? "\n" : "");
+                auto buf = hBuf.data() + seq_len * i;
+                for (int j = 0; j < seq_lens[i]; ++j) {
+                    outFile << buf[j] << " ";
                 }
-                std::cout << std::endl << "zeroCount = " << zeroCount << std::endl;
             }
         }
     }
@@ -477,7 +472,7 @@ int main(int argc, char* argv[])
     }
     cudaDeviceSynchronize();
 
-    if (1) {
+    if (0) {
         // test time
         auto start = std::chrono::high_resolution_clock::now();
 
diff --git a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_template.h b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_template.h
index a13925818b..20669dad91 100644
--- a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_template.h
+++ b/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_template.h
@@ -651,6 +651,11 @@ struct DecoderMultiHeadAttentionKernel {
             __syncthreads();
         }
 
+        // early exit if finished flag is set
+        if (params_.finished[batch_idx_]) {
+            return;
+        }
+
         // Compute attention for current step
         Prolugue();
 
diff --git a/src/turbomind/kernels/unfused_attention_kernels.cu b/src/turbomind/kernels/unfused_attention_kernels.cu
index 536175ccf8..a81b4b0c5c 100644
--- a/src/turbomind/kernels/unfused_attention_kernels.cu
+++ b/src/turbomind/kernels/unfused_attention_kernels.cu
@@ -855,7 +855,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
                                                    T* QKV,
                                                    const T* __restrict qkv_bias,
                                                    const int* padding_offset,
-                                                   const int* history_length,
+                                                   const int* context_length,
                                                    const int* input_length,
                                                    int        batch_size,
                                                    int        seq_len,
@@ -927,8 +927,8 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
         }
     }
 
-    const int history_len = history_length[batch_idx];
-    const int context_len = history_len + input_length[batch_idx];
+    const int context_len = context_length[batch_idx];
+    const int history_len = context_len - input_length[batch_idx];
     const int timestep    = history_len + seq_idx;
 
     float rotary_emb_base = 10000.f;
@@ -982,7 +982,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
                                                                                              QKV,                      \
                                                                                              qkv_bias,                 \
                                                                                              padding_offset,           \
-                                                                                             history_length,           \
+                                                                                             context_length,           \
                                                                                              input_length,             \
                                                                                              batch_size,               \
                                                                                              seq_len,                  \
@@ -1001,7 +1001,7 @@ void invokeAddFusedQKVBiasTranspose(T*           q_buf,
                                     T*           QKV,
                                     const T*     qkv_bias,
                                     const int*   padding_offset,
-                                    const int*   history_length,
+                                    const int*   context_length,
                                     const int*   input_length,
                                     const int    batch_size,
                                     const int    seq_len,
diff --git a/src/turbomind/kernels/unfused_attention_kernels.h b/src/turbomind/kernels/unfused_attention_kernels.h
index 50069fc33a..479b55ebe6 100644
--- a/src/turbomind/kernels/unfused_attention_kernels.h
+++ b/src/turbomind/kernels/unfused_attention_kernels.h
@@ -70,7 +70,7 @@ void invokeAddFusedQKVBiasTranspose(T*           q_buf,
                                     T*           QKV,
                                     const T*     qkv_bias,
                                     const int*   padding_offset,
-                                    const int*   history_length,
+                                    const int*   context_length,
                                     const int*   input_length,
                                     const int    batch_size,
                                     const int    seq_len,
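
Passing `context_length` instead of `history_length` lets the kernels keep a single length per sequence and derive the history length on the fly, as the updated add_fusedQKV_bias_transpose_kernel does: `history_len = context_len - input_len`, and the rotary-embedding position is `timestep = history_len + seq_idx`. A tiny standalone example of that bookkeeping with made-up values:

    // length_bookkeeping.cc -- illustration of the history/context relationship used above
    #include <cassert>

    int main() {
        // hypothetical sample: 7 cached (history) tokens plus 5 new input tokens
        const int input_len   = 5;
        const int context_len = 12;                       // what the kernel now receives
        const int history_len = context_len - input_len;  // recovered inside the kernel
        const int seq_idx     = 3;                        // position within the new input
        const int timestep    = history_len + seq_idx;    // rotary-embedding position

        assert(history_len == 7 && timestep == 10);
        return 0;
    }
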
diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 35e89235a9..882da30e22 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -126,7 +126,7 @@ auto LlamaBatch<T>::ProcessStopRequests(const Requests& requests) -> std::vector
             // stop & optionally erase active sequence
             if (state_->requests[i] && state_->requests[i]->id == r->id) {
                 ec = 0;
-                FinishRequest(i, r->end_flag);
+                CompleteRequest(i, true, r->end_flag);
                 break;
             }
         }
@@ -305,7 +305,7 @@ bool LlamaBatch<T>::Initialize()
         if (swapin_beg != active_end) {
             std::vector<int> missing_len(sequences.size());
             for (int i = 0; i < sequences.size(); ++i) {
-                missing_len[i] = (int)sequences[i]->tokens.size() - sequences[i]->cache_len;
+                missing_len[i] = context_lengths[i] - sequences[i]->cache_len;
             }
             std::stable_sort(swapin_beg, active_end, [&](int i, int j) { return missing_len[i] < missing_len[j]; });
         }
@@ -350,32 +350,29 @@ bool LlamaBatch<T>::Initialize()
             // cumulative num of blocks
             h_cu_block_counts_[i + 1] = h_cu_block_counts_[i] + seq.blocks.size();
 
-            k_ptrs = std::transform(seq.blocks.begin(), seq.blocks.end(), k_ptrs, [&](const Block* p) {
+            k_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), k_ptrs, [&](auto p) {
                 return reinterpret_cast<uintptr_t>(sequence_manager_->OffsetKey(p->data));
             });
-            v_ptrs = std::transform(seq.blocks.begin(), seq.blocks.end(), v_ptrs, [&](auto p) {
+            v_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), v_ptrs, [&](auto p) {
                 return reinterpret_cast<uintptr_t>(sequence_manager_->OffsetVal(p->data));
             });
         }
 
         Copy(state_->h_context_length, batch_size, context_length_buf_);
 
-        dbg(std::vector(h_cu_block_counts_, h_cu_block_counts_ + batch_size + 1));
-        dbg(std::vector(h_k_block_ptrs_, h_k_block_ptrs_ + h_cu_block_counts_[batch_size]));
-        dbg(std::vector(h_v_block_ptrs_, h_v_block_ptrs_ + h_cu_block_counts_[batch_size]));
-        dbg(h_cu_block_counts_[batch_size]);
+        if (1) {
+            std::vector cu_block_cnts(h_cu_block_counts_, h_cu_block_counts_ + batch_size + 1);
+            dbg(cu_block_cnts);
+        }
+        // dbg(std::vector(h_k_block_ptrs_, h_k_block_ptrs_ + h_cu_block_counts_[batch_size]));
+        // dbg(std::vector(h_v_block_ptrs_, h_v_block_ptrs_ + h_cu_block_counts_[batch_size]));
+        // dbg(h_cu_block_counts_[batch_size]);
 
         Copy(h_cu_block_counts_, batch_size + 1, cu_block_counts_);
         Copy(h_k_block_ptrs_, h_cu_block_counts_[batch_size], k_block_ptrs_);
         Copy(h_v_block_ptrs_, h_cu_block_counts_[batch_size], v_block_ptrs_);
 
         static_assert(sizeof(uintptr_t) == sizeof(void*));
-
-        std::vector<void*> fuck(h_cu_block_counts_[batch_size]);
-        Copy((void**)k_block_ptrs_, fuck.size(), fuck.data());
-        cudaStreamSynchronize(stream_);
-
-        dbg(fuck);
     }
 
     // in case of swap-in/swap-out, or when there are holes in the active buffer, the buffer layout changes
@@ -408,16 +405,16 @@ void LlamaBatch<T>::CopyState(const std::pair<BatchState*, int> _src, const std:
 template<typename T>
 void LlamaBatch<T>::SaveRandomState(BatchState& state, int idx)
 {
-    Copy(llama_->GetTopKState(idx), 1, (curandState_t*)state.top_k_curand_state + idx);
-    Copy(llama_->GetTopPState(idx), 1, (curandState_t*)state.top_k_curand_state + idx);
+    Copy(model_->GetTopKState(idx), 1, (curandState_t*)state.top_k_curand_state + idx);
+    Copy(model_->GetTopPState(idx), 1, (curandState_t*)state.top_k_curand_state + idx);
 }
 
 template<typename T>
 void LlamaBatch<T>::LoadRandomState(BatchState& state, int idx)
 {
     dbg(idx);
-    Copy((curandState_t*)state.top_k_curand_state + idx, 1, llama_->GetTopKState(idx));
-    Copy((curandState_t*)state.top_p_curand_state + idx, 1, llama_->GetTopPState(idx));
+    Copy((curandState_t*)state.top_k_curand_state + idx, 1, model_->GetTopKState(idx));
+    Copy((curandState_t*)state.top_p_curand_state + idx, 1, model_->GetTopPState(idx));
 }
 
 template<typename T>
@@ -426,10 +423,10 @@ void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len)
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
     const size_t batchxbeam = batch_size;
 
-    const size_t hidden_units      = llama_->hidden_units_;
-    const size_t vocab_size        = llama_->vocab_size_padded_;
-    const size_t head_dim          = llama_->size_per_head_;
-    const size_t local_kv_head_num = llama_->local_kv_head_num_;
+    const size_t hidden_units      = model_->hidden_units_;
+    const size_t vocab_size        = model_->vocab_size_padded_;
+    const size_t head_dim          = model_->size_per_head_;
+    const size_t local_kv_head_num = model_->local_kv_head_num_;
     // +1 padding, BlockIterator does not use predicate
     const size_t max_block_count = sequence_manager_->max_block_count() + 1;
 
@@ -453,7 +450,6 @@ void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len)
 
     input_ids_buf_      = (int*)allocator_->reMalloc(input_ids_buf_, sizeof(int) * batchxbeam * session_len, true);
     input_length_buf_   = (int*)allocator_->reMalloc(input_length_buf_, sizeof(int) * batchxbeam);
-    history_length_buf_ = (int*)allocator_->reMalloc(history_length_buf_, sizeof(int) * batchxbeam);
     context_length_buf_ = (int*)allocator_->reMalloc(context_length_buf_, sizeof(int) * batchxbeam);
 
     sequence_lengths_ = (int*)allocator_->reMalloc(sequence_lengths_, sizeof(int) * batchxbeam, false);
@@ -506,15 +502,11 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
     const size_t max_block_count = sequence_manager_->max_block_count();
 
     {
-        NcclGuard barrier(llama_->tensor_para_, stream_, true);
+        NcclGuard barrier(model_->tensor_para_, stream_, true);
         h_input_ids_buf_ =
             (int*)allocator_->reMalloc(h_input_ids_buf_, sizeof(int) * max_batch_size * session_len_, false, true);
         h_input_length_buf_ =
             (int*)allocator_->reMalloc(h_input_length_buf_, sizeof(int) * max_batch_size, false, true);
-        h_history_length_buf_ =
-            (int*)allocator_->reMalloc(h_history_length_buf_, sizeof(int) * max_batch_size, false, true);
-        h_sequence_lengths_ =
-            (int*)allocator_->reMalloc(h_sequence_lengths_, sizeof(int) * max_batch_size, false, true);
 
         h_tmp_k_ptrs_ = (void**)allocator_->reMalloc(h_tmp_k_ptrs_, sizeof(void*) * max_batch_size, false, true);
         h_tmp_v_ptrs_ = (void**)allocator_->reMalloc(h_tmp_v_ptrs_, sizeof(void*) * max_batch_size, false, true);
@@ -558,7 +550,6 @@ void LlamaBatch<T>::FreeBuffer()
 
         allocator_->free((void**)&input_ids_buf_);
         allocator_->free((void**)&input_length_buf_);
-        allocator_->free((void**)&history_length_buf_);
         allocator_->free((void**)&context_length_buf_);
 
         allocator_->free((void**)&sequence_lengths_);
@@ -599,8 +590,6 @@ void LlamaBatch<T>::FreeBuffer()
         allocator_->free((void**)&h_v_block_ptrs_, true);
         allocator_->free((void**)&h_input_ids_buf_, true);
         allocator_->free((void**)&h_input_length_buf_, true);
-        allocator_->free((void**)&h_history_length_buf_, true);
-        allocator_->free((void**)&h_sequence_lengths_, true);
         allocator_->free((void**)&h_seq_limit_len_, true);
         is_allocate_persistant_buffer_ = false;
     }
@@ -611,20 +600,20 @@ LlamaBatch<T>::LlamaBatch(int                              max_batch_size,
                           int                              max_context_token_num,
                           int                              session_len,
                           std::unique_ptr<SequenceManager> sequence_manager,
-                          LlamaV2<T>*                      llama):
+                          LlamaV2<T>*                      model):
     max_batch_size_(max_batch_size),
     max_context_token_num_(max_context_token_num),
     session_len_(session_len),
-    rank_(llama->tensor_para_.rank_),
-    debug_(llama->debug_),
-    step_length_(llama->step_length_),
+    rank_(model->tensor_para_.rank_),
+    debug_(model->debug_),
+    step_length_(model->step_length_),
     sequence_manager_(std::move(sequence_manager)),
-    llama_(llama),
+    model_(model),
     data_type_(getTensorType<T>())
 {
-    stream_         = llama_->stream_;
-    allocator_      = llama_->allocator_;
-    cublas_wrapper_ = llama_->cublas_wrapper_;
+    stream_         = model_->stream_;
+    allocator_      = model_->allocator_;
+    cublas_wrapper_ = model_->cublas_wrapper_;
 
     for (auto& s : states_) {
         s.requests.resize(max_batch_size);
@@ -679,7 +668,7 @@ void LlamaBatch<T>::InitializeSampling()
 
     inputs_ = std::move(inputs);
 
-    llama_->dynamic_decode_layer_->setup(batch_size, 1, &inputs_);
+    model_->dynamic_decode_layer_->setup(batch_size, 1, &inputs_);
 
     // recover random states if not a new request
     for (int i = 0; i < batch_size; ++i) {
@@ -688,16 +677,15 @@ void LlamaBatch<T>::InitializeSampling()
         }
     }
 
-    handleOptArg(&inputs_, "end_id", end_ids_buf_, llama_->end_id_, batch_size);
+    handleOptArg(&inputs_, "end_id", end_ids_buf_, model_->end_id_, batch_size);
     cudaStreamSynchronize(0);
 }
 
 template<typename T>
-void LlamaBatch<T>::InitializeGeneration()
+auto LlamaBatch<T>::InitializeGeneration() -> GenerationState
 {
-    const int batch_size = state_->size;
-
-    max_context_len_ = *std::max_element(state_->h_context_length, state_->h_context_length + batch_size);
+    const int batch_size      = state_->active_size;
+    const int max_context_len = *std::max_element(state_->h_context_length, state_->h_context_length + batch_size);
 
     Clear(token_ids_buf_, batch_size * session_len_);
     invokeTransposeAxis01(token_ids_buf_, state_->output_ids, batch_size, session_len_, 1, stream_);
@@ -712,7 +700,7 @@ void LlamaBatch<T>::InitializeGeneration()
     for (int i = 0; i < batch_size; ++i) {
         auto token_ids = token_ids_buf_ + i;
         auto p_src     = state_->h_context_length[i] - 1;
-        auto p_dst     = max_context_len_ - 1;
+        auto p_dst     = max_context_len - 1;
         if (p_src != p_dst) {  // dst and src of `cudaMemcpyAsync` must not overlap
             Copy(token_ids + p_src * batch_size, 1, token_ids + p_dst * batch_size);
         }
@@ -729,9 +717,9 @@ void LlamaBatch<T>::InitializeGeneration()
     // seq_limit_len_ will be compared to `step` instead of `sequence_length`, so the padding length should be
     // accounted for
     for (int i = 0; i < batch_size; ++i) {
-        h_seq_limit_len_[i] = state_->seq_len_limit[i] + (max_context_len_ - state_->h_context_length[i]);
+        h_seq_limit_len_[i] = state_->seq_len_limit[i] + (max_context_len - state_->h_context_length[i]);
         // mask finished sequences
-        state_->h_finished[i] = max_context_len_ >= h_seq_limit_len_[i];
+        state_->h_finished[i] = max_context_len >= h_seq_limit_len_[i];
     }
     Copy(h_seq_limit_len_, batch_size, seq_limit_len_);
     Copy(state_->h_finished, batch_size, finished_buf_);
@@ -739,11 +727,11 @@ void LlamaBatch<T>::InitializeGeneration()
     // ! range of step_ [1, 2 * session_len]
     // consider a sequence with context_len == session_len and another sequence with context_len == 1 and
     // request_output_len == session_len - 1 => step_ will loop in [session_len, 2 * session_len)
-    step_ = max_context_len_;
+    const int start_step = max_context_len;
 
     if (rank_ == 0) {
         TM_LOG_INFO("[initGen] batch_size = %d", (int)batch_size);
-        TM_LOG_INFO("[initGen] max_context_len = %d", (int)max_context_len_);
+        TM_LOG_INFO("[initGen] max_context_len = %d", (int)max_context_len);
 
         TM_LOG_INFO("[initGen] slot  sequence_id  context_len  seq_limit_len  finished");
         for (int i = 0; i < batch_size; ++i) {
@@ -755,45 +743,46 @@ void LlamaBatch<T>::InitializeGeneration()
                         (int)state_->h_finished[i]);
         }
     }
+    return GenerationState{max_context_len, start_step};
 }
 
 template<typename T>
-bool LlamaBatch<T>::Generate()
+bool LlamaBatch<T>::Generate(GenerationState& g)
 {
     const int batch_size = state_->active_size;
 
     constexpr int kLogInterval = 10;
-    if (rank_ == 0 && (step_ - 1) % kLogInterval == 0) {
-        TM_LOG_INFO("------------------------- step = %d -------------------------", step_ - 1);
+    if (rank_ == 0 && (g.step - 1) % kLogInterval == 0) {
+        TM_LOG_INFO("------------------------- step = %d -------------------------", g.step - 1);
     }
 
-    const bool is_first_step = step_ == max_context_len_;
+    const bool is_first_step = (g.step == g.max_init_ctx_len);
 
     std::vector<int> prev;
     if (debug_ && rank_ == 0 && is_first_step) {
         prev.resize(batch_size);
-        Copy(token_ids_buf_ + (step_ - 1) * batch_size, batch_size, prev.data());
+        Copy(token_ids_buf_ + (g.step - 1) * batch_size, batch_size, prev.data());
     }
 
     // embeddingLookup(step_ - 1);
-    llama_->embeddingLookup(decoder_input_buf_,  //
+    model_->embeddingLookup(decoder_input_buf_,  //
                             token_ids_buf_,
                             batch_size,
-                            step_ - 1);
+                            g.step - 1);
 
-    llama_->decoderForward(decoder_output_buf_,
+    model_->decoderForward(decoder_output_buf_,
                            k_block_ptrs_,
                            v_block_ptrs_,
                            decoder_input_buf_,
                            sequence_lengths_,
                            finished_buf_,
                            cu_block_counts_,
-                           step_,
+                           g.step,
                            0,
                            session_len_,
                            batch_size);
 
-    llama_->postDecodeEmbedding(logits_buf_,  //
+    model_->postDecodeEmbedding(logits_buf_,  //
                                 local_logits_buf_,
                                 decoder_output_buf_,
                                 batch_size);
@@ -801,7 +790,7 @@ bool LlamaBatch<T>::Generate()
     // stop-words & bad-words require the matched tokens to be contiguous, so item size > 1 is
     // not supported yet.
     bool should_stop{};
-    llama_->dynamicDecode(token_ids_buf_,
+    model_->dynamicDecode(token_ids_buf_,
                           finished_buf_,
                           sequence_lengths_,
                           &should_stop,
@@ -811,16 +800,16 @@ bool LlamaBatch<T>::Generate()
                           seq_limit_len_,
                           context_length_buf_,
                           end_ids_buf_,
-                          step_,
+                          g.step,
                           0,
-                          max_context_len_,
+                          g.max_init_ctx_len,
                           session_len_ * 2,
                           batch_size);
 
     if (debug_ && rank_ == 0) {
         std::vector<int> curr(batch_size);
 
-        Copy(token_ids_buf_ + step_ * batch_size, batch_size, curr.data());
+        Copy(token_ids_buf_ + g.step * batch_size, batch_size, curr.data());
         cudaStreamSynchronize(stream_);
 
         if (is_first_step) {
@@ -828,19 +817,19 @@ bool LlamaBatch<T>::Generate()
             for (int k = 0; k < prev.size(); ++k) {
                 sprev << std::setw(6) << prev[k];
             }
-            TM_LOG_INFO("[ lookup ] step = %d, [%s]", step_ - 1, sprev.str().c_str());
+            TM_LOG_INFO("[ lookup ] step = %d, [%s]", g.step - 1, sprev.str().c_str());
         }
 
         std::stringstream scurr;
         for (int k = 0; k < curr.size(); ++k) {
             scurr << std::setw(6) << curr[k];
         }
-        TM_LOG_INFO("[generate] step = %d, [%s]", step_ - 1, scurr.str().c_str());
+        TM_LOG_INFO("[generate] step = %d, [%s]", g.step - 1, scurr.str().c_str());
     }
 
     ////////////////////////////////////////////////
     /// ! increase the step counter
-    ++step_;
+    ++g.step;
 
     return !should_stop;
 }
@@ -857,22 +846,21 @@ void LlamaBatch<T>::ContextDecode()
             dbg(state_->h_context_length[i], seq.cache_len);
             if (const int missing = state_->h_context_length[i] - seq.cache_len; missing > 1) {
                 base = base < 0 ? i : base;
+                dbg(seq.tokens, seq.cache_len);
                 Copy(state_->output_ids + i * session_len_ + seq.cache_len, missing, input_ids_buf_ + i * session_len_);
                 // subtract 1 from the input/context len to skip the last input token (it will be processed by the decoder later)
-                h_input_length_buf_[i]   = missing - 1;
-                h_history_length_buf_[i] = seq.cache_len;
+                h_input_length_buf_[i] = missing - 1;
             }
         }
     }
     if (base < 0) {
-        TM_LOG_INFO("[decodeContext] Context decoding is not needed.");
+        // TM_LOG_INFO("[decodeContext] Context decoding is not needed.");
         return;
     }
 
     const int context_decode_count = batch_size - base;
 
     Copy(h_input_length_buf_, batch_size, input_length_buf_);
-    Copy(h_history_length_buf_, batch_size, history_length_buf_);
 
     check_cuda_error(cudaStreamSynchronize(stream_));
     const auto tick = std::chrono::high_resolution_clock::now();
@@ -915,7 +903,7 @@ void LlamaBatch<T>::ContextDecode()
     for (int k = 0; k < offsets.size() - 1; ++k) {
         int              first          = offsets[k];
         int              last           = offsets[k + 1];
-        int              sub_batch_szie = last - first;
+        int              sub_batch_size = last - first;
         T*               k_ptr          = tmp_k_cache_buf_;
         T*               v_ptr          = tmp_v_cache_buf_;
         std::vector<int> decode_indices{};
@@ -926,8 +914,8 @@ void LlamaBatch<T>::ContextDecode()
             input_ids        = Copy(input_ids_buf_ + i * session_len_, h_input_length_buf_[i], input_ids);
             h_tmp_k_ptrs_[i] = k_ptr;
             h_tmp_v_ptrs_[i] = v_ptr;
-            k_ptr += llama_->local_kv_head_num_ * max_context_cnts[k] * llama_->size_per_head_;
-            v_ptr += llama_->local_kv_head_num_ * max_context_cnts[k] * llama_->size_per_head_;
+            k_ptr += model_->local_kv_head_num_ * max_context_cnts[k] * model_->size_per_head_;
+            v_ptr += model_->local_kv_head_num_ * max_context_cnts[k] * model_->size_per_head_;
             decode_indices.push_back(i);
             decode_lengths.push_back(h_input_length_buf_[i]);
             max_input_len = std::max(max_input_len, h_input_length_buf_[i]);
@@ -935,14 +923,14 @@ void LlamaBatch<T>::ContextDecode()
         int token_count = input_ids - context_decoder_ids_buf_;
         dbg(token_count, max_input_len, max_context_cnts[k]);
 
-        Copy(h_tmp_k_ptrs_ + first, sub_batch_szie, tmp_k_ptrs_ + first);
-        Copy(h_tmp_v_ptrs_ + first, sub_batch_szie, tmp_v_ptrs_ + first);
+        Copy(h_tmp_k_ptrs_ + first, sub_batch_size, tmp_k_ptrs_ + first);
+        Copy(h_tmp_v_ptrs_ + first, sub_batch_size, tmp_v_ptrs_ + first);
 
         if (rank_ == 0) {
             TM_LOG_INFO(
                 "[decodeContext] offset = %d, batch_size = %d, token_num = %d, max_input_len = %d, max_context_len = %d",
                 base,
-                sub_batch_szie,
+                sub_batch_size,
                 token_count,
                 max_input_len,
                 max_context_cnts[k]);
@@ -952,15 +940,15 @@ void LlamaBatch<T>::ContextDecode()
         dbg(k_block_ptrs_, v_block_ptrs_);
 
         if (1) {
-            int a, b, c;
-            Copy(input_length_buf_, 1, &a);
-            Copy(history_length_buf_, 1, &b);
-            Copy(context_length_buf_, 1, &c);
+            std::vector<int> input_len(sub_batch_size);
+            std::vector<int> context_len(sub_batch_size);
+            Copy(input_length_buf_ + first, sub_batch_size, input_len.data());
+            Copy(context_length_buf_ + first, sub_batch_size, context_len.data());
             cudaStreamSynchronize(stream_);
-            dbg(a, b, c);
+            dbg(input_len, context_len);
         }
 
-        llama_->contextDecode(nullptr,
+        model_->contextDecode(nullptr,
                               k_block_ptrs_,
                               v_block_ptrs_,
                               tmp_k_ptrs_ + first,
@@ -969,14 +957,13 @@ void LlamaBatch<T>::ContextDecode()
                               context_decoder_output_buf_,
                               context_decoder_ids_buf_,
                               input_length_buf_ + first,
-                              history_length_buf_ + first,
                               context_length_buf_ + first,
                               cu_block_counts_ + first,
                               token_count,
                               max_input_len,
                               max_context_cnts[k],
                               max_context_cnts[k],
-                              sub_batch_szie);
+                              sub_batch_size);
 
         // compute logits of inputs if requested
         OutputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths);
@@ -1016,41 +1003,40 @@ void LlamaBatch<T>::OutputContextLogits(T*                      context_decoder_
     }
 
     if (context_logits_buf_ == nullptr) {
-        NcclGuard guard(llama_->tensor_para_, stream_, true);
+        NcclGuard guard(model_->tensor_para_, stream_, true);
         context_logits_buf_ =
-            (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_padded_ * max_context_token_num_);
-        const auto tp = llama_->tensor_para_.world_size_;
+            (float*)allocator_->malloc(sizeof(float) * model_->vocab_size_padded_ * max_context_token_num_);
+        const auto tp = model_->tensor_para_.world_size_;
         if (tp > 1) {
-            FT_CHECK(llama_->vocab_size_padded_ % tp == 0);
-            const auto local_vocab_size = llama_->vocab_size_padded_ / tp;
+            FT_CHECK(model_->vocab_size_padded_ % tp == 0);
+            const auto local_vocab_size = model_->vocab_size_padded_ / tp;
             local_context_logits_buf_ =
                 (float*)allocator_->malloc(sizeof(float) * local_vocab_size * max_context_token_num_);
         }
     }
 
-    llama_->postDecodeEmbedding(context_logits_buf_, local_context_logits_buf_, context_decoder_output, num_token);
+    model_->postDecodeEmbedding(context_logits_buf_, local_context_logits_buf_, context_decoder_output, num_token);
 
     auto logits = context_logits_buf_;
 
     for (int k = 0; k < indices.size(); ++k) {
         if (output_logits[k]) {
-            Copy(logits, llama_->vocab_size_ * lengths[k], output_logits[k]);
+            Copy(logits, model_->vocab_size_ * lengths[k], output_logits[k]);
         }
-        logits += llama_->vocab_size_padded_ * lengths[k];
+        logits += model_->vocab_size_padded_ * lengths[k];
     }
 }
 
 template<typename T>
-auto LlamaBatch<T>::Finish() -> std::vector<Signal>
+auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
 {
     const int batch_size = state_->active_size;
 
-    // secure info needed by `synchronize()`
+    // secure info needed by `Initialize()`
     Copy(finished_buf_, batch_size, state_->h_finished);
     Copy(sequence_lengths_, batch_size, state_->h_context_length);
-    Copy(sequence_lengths_, batch_size, context_length_buf_);
 
-    SetOutputTensors(step_);
+    SetOutputTensors(g);
 
     check_cuda_error(cudaStreamSynchronize(stream_));
 
@@ -1064,7 +1050,7 @@ auto LlamaBatch<T>::Finish() -> std::vector<Signal>
     if (debug_ && rank_ == 0) {
         std::stringstream ss;
         for (int i = 0; i < batch_size; ++i) {
-            ss << (i ? ", " : "") << "(" << h_sequence_lengths_[i] << "," << state_->h_finished[i] << ")";
+            ss << (i ? ", " : "") << "(" << state_->h_context_length[i] << "," << state_->h_finished[i] << ")";
         }
         TM_LOG_INFO("[finish] [%s]", ss.str().c_str());
     }
@@ -1072,7 +1058,7 @@ auto LlamaBatch<T>::Finish() -> std::vector<Signal>
     std::vector<Signal> signals;
     for (int i = 0; i < batch_size; ++i) {
         if (state_->requests[i] && state_->h_finished[i]) {
-            FinishRequest(i, false);
+            CompleteRequest(i, false, false);
             signals.push_back([r = std::move(state_->requests[i])] { r->signal.set_value(0); });
         }
     }
@@ -1080,15 +1066,16 @@ auto LlamaBatch<T>::Finish() -> std::vector<Signal>
 }
 
 template<typename T>
-void LlamaBatch<T>::SetOutputTensors(int max_gen_step)
+void LlamaBatch<T>::SetOutputTensors(const GenerationState& g)
 {
+    // dbg(g.max_init_ctx_len);
     const auto batch_size = state_->active_size;
     // [s,b] -> [b,s] and skip padding in [context_len, max_context_len)
     invokeGatherOutput(state_->output_ids,
                        token_ids_buf_,
                        context_length_buf_,
-                       max_context_len_,
-                       max_gen_step,
+                       g.max_init_ctx_len,
+                       g.step,
                        session_len_,
                        batch_size,
                        stream_);
@@ -1101,7 +1088,7 @@ void LlamaBatch<T>::SetOutputTensors(int max_gen_step)
             auto& sequence_length = state_->requests[i]->outputs[rank_].at("sequence_length");
             Copy(state_->output_ids + i * session_len_, output_ids.shape.at(2), output_ids.getPtr<int>());
             Copy(sequence_lengths_ + i, 1, sequence_length.getPtr<int>());
-            if (max_gen_step > max_context_len_) {  // +1 for newly generated token
+            if (g.step > g.max_init_ctx_len) {  // +1 for newly generated token
                 invokePlusScalar(sequence_length.getPtr<int>(), 1, 1, stream_);
             }
         }
@@ -1109,31 +1096,29 @@ void LlamaBatch<T>::SetOutputTensors(int max_gen_step)
 }
 
 template<typename T>
-void LlamaBatch<T>::FinishRequest(int index, bool force_end)
+void LlamaBatch<T>::CompleteRequest(int index, bool is_stop_request, bool is_force_end)
 {
     if (rank_ == 0) {
-        TM_LOG_INFO("[finishRequest] slot = %d, id = %lu", index, (long)state_->requests[index]->id);
+        TM_LOG_INFO("[CompleteRequest] slot = %d, id = %lu", index, (long)state_->requests[index]->id);
     }
 
     if (debug_ && rank_ == 0) {
-        std::vector<int> tokens(h_sequence_lengths_[index] + 1);
+        std::vector<int> tokens(state_->h_context_length[index] + 1);
         Copy(state_->output_ids + index * session_len_, tokens.size(), tokens.data());
         cudaStreamSynchronize(stream_);
         std::stringstream ss;
         for (const auto& t : tokens) {
             ss << " " << t;
         }
-        TM_LOG_INFO("[finishRequest] slot %d, tokens [%s]", index, ss.str().c_str());
+        TM_LOG_INFO("[CompleteRequest] slot %d, tokens [%s]", index, ss.str().c_str());
     }
 
-    if (state_->requests[index]->end_flag || force_end) {
+    if (state_->requests[index]->end_flag || is_force_end) {
         sequence_manager_->Erase(state_->requests[index]->id);
     }
     else {
-        // the last generated token is not processed by decoder thus dont have k/v cache
-        const int n_steps    = step_ - max_context_len_;
-        const int cache_len  = h_sequence_lengths_[index];
-        const int output_len = n_steps > 0 ? cache_len + 1 : cache_len;
+        const int cache_len  = state_->h_context_length[index];
+        const int output_len = !is_stop_request ? cache_len + 1 : cache_len;
 
         auto& seq = *state_->sequences[index];
 
@@ -1150,8 +1135,8 @@ void LlamaBatch<T>::FinishRequest(int index, bool force_end)
 
         // save random state in host memory
         if (auto ptr = (curandState_t*)seq.random_state.data()) {
-            Copy(llama_->GetTopKState(index), 1, ptr++);
-            Copy(llama_->GetTopPState(index), 1, ptr++);
+            ptr = Copy(model_->GetTopKState(index), 1, ptr);
+            ptr = Copy(model_->GetTopPState(index), 1, ptr);
         }
 
         check_cuda_error(cudaStreamSynchronize(stream_));
@@ -1168,7 +1153,7 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id)
     TM_LOG_INFO("[InternalThreadEntry] %d", (int)rank_);
     check_cuda_error(cudaSetDevice(device_id));
 
-    auto& shared_state = llama_->shared_state_;
+    auto& shared_state = model_->shared_state_;
 
     auto& request_queue  = shared_state->request_queue;
     auto& infer_requests = shared_state->infer_requests;
@@ -1176,12 +1161,14 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id)
 
     int finished_count = 0;
 
+    GenerationState g{};
+
     while (1) {
         if (rank_ == 0) {
             const int  free_slot_count = max_batch_size_ - state_->size + finished_count;
             const bool is_empty        = (free_slot_count == max_batch_size_);
 
-            // will block if state is empty
+            // will block if batch is empty
             request_queue.dequeue(stop_requests, infer_requests, free_slot_count, is_empty, shared_state->abort);
 
             if (!shared_state->abort) {
@@ -1194,7 +1181,7 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id)
 
         if (shared_state->abort) {
             if (state_->size && rank_ == 0) {
-                TM_LOG_WARNING("Active request(s) present (%d) while aborting.", state_->size);
+                TM_LOG_WARNING("Active request(s) present (%d) while exiting.", state_->size);
             }
             return;
         }
@@ -1213,15 +1200,16 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id)
 
         if (state_->active_size) {
             if (modified) {
-                InitializeGeneration();
+                g = InitializeGeneration();
                 InitializeSampling();
             }
             for (int i = 0; i < step_length_; ++i) {
-                if (!Generate()) {
+                if (!Generate(g)) {
                     break;
                 }
             }
-            auto signals = Finish();
+            auto signals   = Finish(g);
+            finished_count = signals.size();
             BarrierSignalRequests(*shared_state->barrier, signals);
         }
     }
@@ -1235,9 +1223,7 @@ void LlamaBatch<T>::BarrierSignalRequests(Barrier& barrier, const std::vector<Si
     if (!signals.empty()) {
         barrier.wait();
         if (rank_ == 0) {
-            for (const auto& s : signals) {
-                s();
-            }
+            std::for_each(signals.cbegin(), signals.cend(), [](auto& s) { s(); });
         }
         barrier.wait();
     }
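
The `step_`/`max_context_len_` members are replaced by a GenerationState value that InitializeGeneration() returns and Generate()/Finish()/SetOutputTensors() receive explicitly, so per-batch generation state no longer leaks across Initialize() calls. A hypothetical stand-in sketch of the resulting control flow, not the real LlamaBatch API:

    // generation_state_sketch.cc -- hypothetical stand-ins, not the real LlamaBatch API
    #include <cstdio>

    struct GenerationState {
        int max_init_ctx_len;  // longest context length when the batch was (re)initialized
        int step;              // current decode step, starts at max_init_ctx_len
    };

    GenerationState InitializeGeneration(int max_context_len) {
        return GenerationState{max_context_len, max_context_len};
    }

    bool Generate(GenerationState& g, int step_limit) {
        std::printf("decode step %d\n", g.step);
        ++g.step;                     // the step counter now lives in `g`, not in the batch object
        return g.step < step_limit;   // `false` plays the role of `should_stop`
    }

    int main() {
        GenerationState g = InitializeGeneration(/*max_context_len=*/8);
        for (int i = 0; i < 4; ++i) {     // mirrors the `step_length_` loop in InternalThreadEntry
            if (!Generate(g, /*step_limit=*/10)) {
                break;
            }
        }
        return 0;
    }
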
diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h
index 1faa7004e2..8e1c1f5d36 100644
--- a/src/turbomind/models/llama/LlamaBatch.h
+++ b/src/turbomind/models/llama/LlamaBatch.h
@@ -51,19 +51,25 @@ class LlamaBatch {
 
     void ProcessInferRequests(const Requests& requests);
 
-    bool Initialize();
+    [[nodiscard]] bool Initialize();
 
     void ContextDecode();
 
-    void InitializeSampling();
-    void InitializeGeneration();
+    struct GenerationState {
+        int max_init_ctx_len;
+        int step;
+    };
 
-    [[nodiscard]] bool Generate();
+    void            InitializeSampling();
+    GenerationState InitializeGeneration();
 
-    [[nodiscard]] auto Finish() -> std::vector<Signal>;
-    void               FinishRequest(int index, bool force_end);
+    [[nodiscard]] bool Generate(GenerationState& g);
 
-    void SetOutputTensors(int max_gen_step);
+    [[nodiscard]] auto Finish(GenerationState& g) -> std::vector<Signal>;
+
+    void CompleteRequest(int index, bool is_stop_request, bool is_force_end);
+
+    void SetOutputTensors(const GenerationState& g);
 
     void
     OutputContextLogits(T* context_decoder_output, const std::vector<int>& indices, const std::vector<int>& lengths);
@@ -76,7 +82,7 @@ class LlamaBatch {
 
     ~LlamaBatch()
     {
-        llama_->shared_state_->request_queue.Abort();
+        model_->shared_state_->request_queue.Abort();
 
         internal_thread_.join();
 
@@ -121,17 +127,25 @@ class LlamaBatch {
     const bool debug_;
     const int  step_length_;
 
-    LlamaV2<T>* const llama_;
+    LlamaV2<T>* const model_;
 
     std::unique_ptr<SequenceManager> sequence_manager_;
 
-    T*   context_decoder_input_buf_{};   // CTXDEC
-    T*   context_decoder_output_buf_{};  // CTXDEC
-    int* context_decoder_ids_buf_{};
-
-    T* decoder_input_buf_{};   // CTXDEC, GENERATE
-    T* decoder_output_buf_{};  // CTXDEC, GENERATE
+    ///////////////////////////////////////////////////////////////////
+    // k/v cache block buffers
+    int*       cu_block_counts_{};
+    uintptr_t* k_block_ptrs_{};
+    uintptr_t* v_block_ptrs_{};
 
+    ////////////////////////////////////////////////////////////////////
+    // context decoding temp buffers
+    T*   context_decoder_input_buf_{};
+    T*   context_decoder_output_buf_{};
+    int* context_decoder_ids_buf_{};
+    int* input_ids_buf_{};
+    // lengths
+    int* input_length_buf_{};    // input + cache missed length
+    int* context_length_buf_{};  // history length + input_length
     // temp buffers used for block->linear kv-cache conversion
     T*     tmp_k_cache_buf_{};
     T*     tmp_v_cache_buf_{};
@@ -140,14 +154,10 @@ class LlamaBatch {
     void** h_tmp_k_ptrs_{};
     void** h_tmp_v_ptrs_{};
 
-    int*       input_ids_buf_{};       // input token ids + cache missed token ids, CTXDEC
-    int*       input_length_buf_{};    // input + cache missed length, CTXDEC, GENERATE
-    int*       history_length_buf_{};  // history length, CTXDEC
-    int*       context_length_buf_{};  // history length + input_length, CTXDEC, GENERATE
-    int*       sequence_lengths_{};    // current sequence length
-    int*       cu_block_counts_{};
-    uintptr_t* k_block_ptrs_{};
-    uintptr_t* v_block_ptrs_{};
+    T*   decoder_input_buf_{};
+    T*   decoder_output_buf_{};
+    int* sequence_lengths_{};  // current sequence length
+    int* init_ctx_lens_{};
 
     float* logits_buf_{};        // combined logits
     float* local_logits_buf_{};  // tensor parallel local logits
@@ -161,10 +171,9 @@ class LlamaBatch {
     uint32_t* seq_limit_len_{};
 
     // pinned buffers
-    int*       h_input_ids_buf_{};
-    int*       h_input_length_buf_{};
-    int*       h_history_length_buf_{};
-    int*       h_sequence_lengths_{};
+    int* h_input_ids_buf_{};
+    int* h_input_length_buf_{};
+    // int*       h_sequence_lengths_{};
     uint32_t*  h_seq_limit_len_{};
     int*       h_cu_block_counts_{};
     uintptr_t* h_k_block_ptrs_{};
@@ -191,9 +200,6 @@ class LlamaBatch {
 
     const DataType data_type_{};
 
-    int max_context_len_{};
-    int step_{};
-
     bool is_allocate_persistant_buffer_ = false;
     bool is_allocate_buffer_            = false;
 
diff --git a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
index 2505669dec..a3b358d9ed 100644
--- a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
@@ -145,7 +145,6 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
     T* attention_mask  = input_tensors->at("attention_mask").getPtr<T>();
 
     const auto input_length    = input_tensors->at("input_lengths").getPtr<const int>();
-    const auto history_length  = input_tensors->at("history_lengths").getPtr<const int>();
     const auto context_length  = input_tensors->at("context_lengths").getPtr<const int>();
     int*       cu_seqlens      = input_tensors->at("cu_seqlens").getPtr<int>();
     int*       cu_block_counts = input_tensors->at("cu_block_counts").getPtr<int>();
@@ -178,7 +177,7 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
                                    qkv_buf_,
                                    weights->qkv.bias,
                                    padding_offset,  // padding_offset,
-                                   history_length,  // used for applying rotary embedding
+                                   context_length,  // used for applying rotary embedding
                                    input_length,
                                    batch_size,
                                    max_q_len,  // seq_len
@@ -214,7 +213,7 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
                         v_buf_2_,
                         cu_block_counts,
                         input_length,
-                        history_length,
+                        context_length,
                         batch_size,
                         kv_cache_block_len_,
                         layer_offset,
diff --git a/src/turbomind/models/llama/LlamaContextDecoder.cc b/src/turbomind/models/llama/LlamaContextDecoder.cc
index e763082733..c7231c1aa8 100644
--- a/src/turbomind/models/llama/LlamaContextDecoder.cc
+++ b/src/turbomind/models/llama/LlamaContextDecoder.cc
@@ -110,14 +110,10 @@ void LlamaContextDecoder<T>::forwardSelfAttn(const Session&
         {"padding_offset", {MEMORY_GPU, TYPE_INT32, {sess.token_num}, padding_offset_}},
         {"cu_seqlens", {MEMORY_GPU, TYPE_INT32, {sess.batch_size + 1}, cu_seqlens_}},
         {"input_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.input_length}},
-        {"history_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.history_length}},
         {"context_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.context_length}},
         {"cu_block_counts", input_tensors->at("cu_block_counts")},
         {"max_seq_len", input_tensors->at("max_seq_len")}};
 
-    // auto& k_cache = *sess.k_cache;
-    // auto& v_cache = *sess.v_cache;
-
     TensorMap self_attention_output_tensors{
         {"hidden_features", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, attn_io}},
         {"key_cache", output_tensors->at("key_cache")},
@@ -206,7 +202,6 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
     sess.weights       = decoder_layer_weights;
 
     sess.input_length   = input_tensors->at("input_lengths").getPtr<int>();
-    sess.history_length = input_tensors->at("history_lengths").getPtr<int>();
     sess.context_length = input_tensors->at("context_lengths").getPtr<int>();
 
     T* decoder_input_output = input_tensors->at("decoder_input").getPtr<T>();
diff --git a/src/turbomind/models/llama/LlamaContextDecoder.h b/src/turbomind/models/llama/LlamaContextDecoder.h
index 4f3613c38c..6750614c5e 100644
--- a/src/turbomind/models/llama/LlamaContextDecoder.h
+++ b/src/turbomind/models/llama/LlamaContextDecoder.h
@@ -68,7 +68,6 @@ class LlamaContextDecoder: public BaseLayer {
         size_t max_query_len;
         size_t max_key_len;
         int*   input_length{};
-        int*   history_length{};
         int*   context_length{};
 
         const std::vector<LlamaDecoderLayerWeight<T>*>* weights;
diff --git a/src/turbomind/models/llama/LlamaDecoder.cc b/src/turbomind/models/llama/LlamaDecoder.cc
index 73e95b1353..926b442429 100644
--- a/src/turbomind/models/llama/LlamaDecoder.cc
+++ b/src/turbomind/models/llama/LlamaDecoder.cc
@@ -195,6 +195,9 @@ void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>*        ou
     T* decoder_input  = input_tensors->at("decoder_input").getPtr<T>();
     T* decoder_output = output_tensors->at("decoder_output").getPtr<T>();
 
+    int step = input_tensors->at("step").getVal<int>();
+    // Compare(decoder_input, sess.batch_size * hidden_units_, Concat("decoder_input", 0, step), kCmpRead, stream_);
+
     ////////////////////////////////////////////
     /// RMSNorm
     invokeRootMeanSquareNorm(decoder_output,
diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
index 9ad3908ea6..135ed8f074 100644
--- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
@@ -33,139 +33,7 @@
 namespace turbomind {
 
 template<typename T>
-struct SATypeConverter {
-    using Type = T;
-};
-
-template<>
-struct SATypeConverter<half> {
-    using Type = uint16_t;
-};
-
-template<typename T>
-static inline void fusedQKV_masked_attention_dispatch(const T*     qkv_buf,
-                                                      const T*     qkv_bias,
-                                                      const T*     relative_attention_bias,
-                                                      T*           key_cache,
-                                                      T*           value_cache,
-                                                      T**          k_cache_per_sample,
-                                                      T**          v_cache_per_sample,
-                                                      size_t       kv_cache_per_sample_offset,
-                                                      const int*   cache_indir,
-                                                      T*           context_buf,
-                                                      const bool*  finished,
-                                                      const int*   sequence_lengths,
-                                                      const int    max_batch_size,
-                                                      const int    inference_batch_size,
-                                                      const int    beam_width,
-                                                      const int    head_num,
-                                                      const int    kv_head_num,
-                                                      const int    size_per_head,
-                                                      const int    rotary_embedding_dim,
-                                                      const int    max_position_embeddings,
-                                                      const bool   use_dynamic_ntk,
-                                                      const bool   use_logn_attn,
-                                                      const int    memory_max_len,
-                                                      const int*   prefix_prompt_lengths,
-                                                      const int    max_prefix_prompt_length,
-                                                      const int    max_input_len,
-                                                      const int*   total_padding_tokens,
-                                                      const int    step,
-                                                      const float  q_scaling,
-                                                      const int    relative_attention_bias_stride,
-                                                      const T*     linear_bias_slopes,
-                                                      const bool*  masked_tokens,
-                                                      const int*   ia3_tasks,
-                                                      const T*     ia3_key_weights,
-                                                      const T*     ia3_value_weights,
-                                                      const float* qkv_scale_out,
-                                                      const float* attention_out_scale,
-                                                      const int    int8_mode,
-                                                      const float* attention_kv_scale,
-                                                      cudaStream_t stream)
-{
-    using DataType = typename SATypeConverter<T>::Type;
-    // Prepare the parameters.
-    Masked_multihead_attention_params<DataType> params;
-    memset(&params, 0, sizeof(params));
-    // int hidden_units = head_num * size_per_head;
-    if (qkv_bias != nullptr) {
-        params.q_bias = reinterpret_cast<const DataType*>(qkv_bias);
-        params.k_bias = reinterpret_cast<const DataType*>(qkv_bias) + head_num * size_per_head;
-        params.v_bias = reinterpret_cast<const DataType*>(qkv_bias) + (head_num + kv_head_num) * size_per_head;
-    }
-    else {
-        params.q_bias = nullptr;
-        params.k_bias = nullptr;
-        params.v_bias = nullptr;
-    }
-
-    // Set the output buffer.
-    params.out = reinterpret_cast<DataType*>(context_buf);
-
-    // Set the input buffers.
-    // [B, nH + kvH, D]
-    params.q = reinterpret_cast<const DataType*>(qkv_buf);
-    params.k = reinterpret_cast<const DataType*>(qkv_buf) + head_num * size_per_head;
-    params.v = reinterpret_cast<const DataType*>(qkv_buf) + (head_num + kv_head_num) * size_per_head;
-
-    params.stride   = (head_num + 2 * kv_head_num) * size_per_head;
-    params.finished = const_cast<bool*>(finished);
-
-    FT_CHECK(k_cache_per_sample && v_cache_per_sample);
-
-    params.k_cache_per_sample         = reinterpret_cast<DataType**>(k_cache_per_sample);
-    params.v_cache_per_sample         = reinterpret_cast<DataType**>(v_cache_per_sample);
-    params.kv_cache_per_sample_offset = kv_cache_per_sample_offset;
-    params.batch_size                 = inference_batch_size;
-    params.beam_width                 = beam_width;
-    params.memory_max_len             = memory_max_len;
-    params.prefix_prompt_lengths      = prefix_prompt_lengths;
-    params.max_prefix_prompt_length   = max_prefix_prompt_length;
-    params.length_per_sample          = sequence_lengths;  // max_input_length + current output length
-    // timestep adding max_prefix_prompt_length for shared memory size calculation and rotary embedding computation
-    params.timestep     = step + max_prefix_prompt_length - 1;
-    params.num_heads    = head_num;
-    params.num_kv_heads = kv_head_num;
-
-    params.hidden_size_per_head    = size_per_head;
-    params.rotary_embedding_dim    = rotary_embedding_dim;
-    params.max_position_embeddings = max_position_embeddings;
-    params.use_dynamic_ntk         = use_dynamic_ntk;
-    params.use_logn_attn           = use_logn_attn;
-
-    // Note: keep norm factor (sqrt(K_dim)) when adopting megatron T5 structure (may adjust)
-    params.inv_sqrt_dh = 1.F / (sqrtf((float)params.hidden_size_per_head) * q_scaling);
-
-    params.total_padding_tokens = total_padding_tokens;
-    if (relative_attention_bias != nullptr) {
-        params.relative_attention_bias = reinterpret_cast<const DataType*>(relative_attention_bias);
-    }
-    params.relative_attention_bias_stride = relative_attention_bias_stride;
-    params.masked_tokens                  = masked_tokens;
-
-    // The slope of linear position bias per head, e.g., ALiBi.
-    if (linear_bias_slopes != nullptr) {
-        params.linear_bias_slopes = reinterpret_cast<const DataType*>(linear_bias_slopes);
-    }
-    params.max_input_length = max_input_len;
-
-    params.int8_mode = int8_mode;
-
-    if (int8_mode & QuantPolicy::kCacheKVInt8) {
-        params.attention_k_scale = attention_kv_scale[0];
-        params.attention_k_zp    = attention_kv_scale[1];
-        params.attention_v_scale = attention_kv_scale[2];
-        params.attention_v_zp    = attention_kv_scale[3];
-    }
-
-    PUSH_RANGE("scaled dot-product fusion");
-    masked_multihead_attention(params, stream);
-    POP_RANGE;
-}
-
-template<typename T>
-void LlamaDecoderSelfAttentionLayer<T>::allocateBuffer(size_t batch_size, int key_len, int max_memory_len)
+void LlamaDecoderSelfAttentionLayer<T>::allocateBuffer(size_t batch_size)
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
 
@@ -212,12 +80,9 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
      *    \param value_cache [batch, local_head_num, memory_max_len, size_per_head]
      */
 
-    const T*   input_query_data      = input_tensors->getPtr<T>("input_query");
-    const int* sequence_lengths_data = input_tensors->getPtr<int>("sequence_lengths");
-    // const int*  total_padding_len     = input_tensors->getPtr<int>("total_padding_tokens");
-    const bool* finished_data      = input_tensors->getPtr<bool>("finished", nullptr);
-    const bool* masked_tokens_data = input_tensors->getPtr<bool>("masked_tokens", nullptr);
-    const int*  cache_indir        = input_tensors->getPtr<int>("cache_indirection", nullptr);
+    const T*    input_query_data      = input_tensors->getPtr<T>("input_query");
+    const int*  sequence_lengths_data = input_tensors->getPtr<int>("sequence_lengths");
+    const bool* finished_data         = input_tensors->getPtr<bool>("finished");
 
     T*  hidden_features_data = output_tensors->getPtr<T>("attention_output");
     T** key_cache_ptrs       = output_tensors->getPtr<T*>("key_cache");
@@ -227,15 +92,12 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
 
     const int layer_id = input_tensors->getVal<int>("layer_id");
 
-    const int max_seq_len = input_tensors->getVal<int>("max_seq_len");
-    const int step        = input_tensors->getVal<int>("step");
-
-    const int step_1 = step - 1;
+    // const int step        = input_tensors->getVal<int>("step");
+    // const int step_1 = step - 1;
 
     const int batch_size = input_tensors->at("input_query").shape[0];
-    const int beam_width = cache_indir != nullptr ? input_tensors->at("cache_indirection").shape[1] : 1;
 
-    allocateBuffer(batch_size, step, max_seq_len);
+    allocateBuffer(batch_size);
 
     PUSH_RANGE("qkv_gemm");
     linear_.forward(qkv_buf_, input_query_data, batch_size, weights->qkv);
@@ -274,46 +136,6 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
 
     LaunchDecoderMultiheadAttention<T, 128>(params);
 
-    // fusedQKV_masked_attention_dispatch<T>(
-    //     qkv_buf_,
-    //     weights->qkv.bias,  // query_weight.bias,
-    //     nullptr,            // relative_attention_bias,
-    //     nullptr,
-    //     nullptr,
-    //     key_cache_ptrs,
-    //     value_cache_ptrs,
-    //     kv_cache_layer_offset,
-    //     cache_indir,
-    //     context_buf_,
-    //     finished_data,
-    //     sequence_lengths_data,  // NOTE: current seq len including padding (fixed after
-    //     meeting the finished id) batch_size, batch_size, beam_width, local_head_num_,
-    //     local_kv_head_num_,
-    //     size_per_head_,
-    //     params_.rotray_embedding_dim,
-    //     params_.max_position_embeddings,
-    //     params_.use_dynamic_ntk,
-    //     params_.use_logn_attn,
-    //     memory_len,
-    //     nullptr,  // prefix_prompt_lengths
-    //     0,        // max_prefix_prompt_length
-    //     0,        // max_input_length, not used w/o linear_bias_slopes
-    //     input_tensors->getPtr<int>("total_padding_tokens", nullptr),
-    //     step,
-    //     1.f,                            // q_scaling
-    //     0,                              // relative_attention_bias_stride
-    //     nullptr,                        // linear_bias_slopes
-    //     nullptr,                        //  masked_tokens_data,
-    //     nullptr,                        // ia3_tasks
-    //     nullptr,                        // ia3_key_weights
-    //     nullptr,                        // ia3_value_weights
-    //     nullptr,                        // qkv_scale_out
-    //     nullptr,                        // attention_out_scale
-    //     quant_policy_,                  // int8_mode
-    //     weights->past_kv_scale.data(),  // attention kv scale
-    //     stream_);
-    // sync_check_cuda_error();
-
     linear_.forward(hidden_features_data, context_buf_, batch_size, weights->output);
 
     if (tensor_para_.world_size_ > 1) {
diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h
index ac1c02caac..73c9674d23 100644
--- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h
+++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h
@@ -32,7 +32,7 @@ template<typename T>
 class LlamaDecoderSelfAttentionLayer {
 public:
     void freeBuffer();
-    void allocateBuffer(size_t batch_size, int key_len, int max_memory_len);
+    void allocateBuffer(size_t batch_size);
 
     LlamaDecoderSelfAttentionLayer(size_t                      head_num,
                                    size_t                      kv_head_num,
diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc
index 3532bf4216..01e10ab260 100644
--- a/src/turbomind/models/llama/LlamaV2.cc
+++ b/src/turbomind/models/llama/LlamaV2.cc
@@ -88,7 +88,6 @@ LlamaV2<T>::LlamaV2(size_t                       head_num,
     cuda_device_prop_(cuda_device_prop),
     debug_(isDebug()),
     step_length_(step_length),
-    // batch_(max_batch_size, max_context_token_num, session_len, this),
     shared_state_(shared_state)
 
 {
@@ -215,7 +214,6 @@ void LlamaV2<T>::contextDecode(T*         deocder_output,
                                T*         context_decoder_output_buf,
                                const int* input_ids,
                                const int* input_length,
-                               const int* history_length,
                                const int* context_length,
                                const int* cu_block_counts,
                                size_t     token_num,
@@ -255,7 +253,6 @@ void LlamaV2<T>::contextDecode(T*         deocder_output,
         {"decoder_input", {MEMORY_GPU, dtype, {token_num, hidden_units_}, context_decoder_input_buf}},
         {"output_norm_weight", {MEMORY_GPU, dtype, {hidden_units_}, weights_->output_norm_weight}},
         {"input_lengths", {MEMORY_GPU, TYPE_INT32, {bsz}, input_length}},
-        {"history_lengths", {MEMORY_GPU, TYPE_INT32, {bsz}, history_length}},
         {"context_lengths", {MEMORY_GPU, TYPE_INT32, {bsz}, context_length}},
         {"max_q_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_q_len}},
         {"max_kv_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_kv_len}},
diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h
index 09f49ceaaf..c2b2b34eea 100644
--- a/src/turbomind/models/llama/LlamaV2.h
+++ b/src/turbomind/models/llama/LlamaV2.h
@@ -110,7 +110,6 @@ class LlamaV2 {
                        T*         context_decoder_output_buf,
                        const int* input_ids,
                        const int* input_length,
-                       const int* history_length,
                        const int* context_length,
                        const int* cu_block_counts,
                        size_t     token_num,
diff --git a/src/turbomind/models/llama/llama_kernels.cu b/src/turbomind/models/llama/llama_kernels.cu
index 822fd1bc0d..fe5dc2f44d 100644
--- a/src/turbomind/models/llama/llama_kernels.cu
+++ b/src/turbomind/models/llama/llama_kernels.cu
@@ -206,7 +206,7 @@ __global__ void extend_kv_cache(T**          k_dst_ptrs,
                                 const T*     v_src,
                                 const int*   cu_block_counts,
                                 const int*   query_length,
-                                const int*   history_length,
+                                const int*   context_length,
                                 const int    block_length,
                                 const size_t dst_layer_offset,
                                 const int    max_q_len,
@@ -215,7 +215,7 @@ __global__ void extend_kv_cache(T**          k_dst_ptrs,
 {
     const int batch_id     = blockIdx.y;
     const int query_len    = query_length[batch_id];
-    const int history_len  = history_length[batch_id];
+    const int history_len  = context_length[batch_id] - query_len;
     const int cu_block_cnt = cu_block_counts[batch_id];
 
     const int     head_id = blockIdx.z;
@@ -259,7 +259,7 @@ void invokeExtendKVCache(T**          k_dst_ptrs,
                          const T*     v_src,
                          const int*   cu_block_counts,
                          const int*   query_length,
-                         const int*   history_length,
+                         const int*   context_length,
                          int          batch_size,
                          int          block_length,
                          size_t       dst_layer_offset,
@@ -283,7 +283,7 @@ void invokeExtendKVCache(T**          k_dst_ptrs,
                                                    v_src,
                                                    cu_block_counts,
                                                    query_length,
-                                                   history_length,
+                                                   context_length,
                                                    block_length,
                                                    dst_layer_offset,
                                                    max_q_len,
diff --git a/src/turbomind/models/llama/llama_kernels.h b/src/turbomind/models/llama/llama_kernels.h
index 96385e5763..d6226baf50 100644
--- a/src/turbomind/models/llama/llama_kernels.h
+++ b/src/turbomind/models/llama/llama_kernels.h
@@ -40,7 +40,7 @@ void invokeExtendKVCache(T**          k_dst_ptrs,
                          const T*     v_src,
                          const int*   cu_block_counts,
                          const int*   query_length,
-                         const int*   history_length,
+                         const int*   context_length,
                          int          batch_size,
                          int          block_length,
                          size_t       dst_layer_offset,
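
The history_lengths tensor is removed throughout this patch; wherever a kernel still needs the history it is derived from the other two lengths, as in the extend_kv_cache hunk above. The identity being relied on, written as a trivial helper for illustration (not code from the patch):

// context_length = history_length + input_length, so with history_lengths gone:
inline int HistoryLength(int context_length, int input_length)
{
    return context_length - input_length;  // matches `context_length[batch_id] - query_len` above
}
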

From fac21cd8b1337e0b2799345f26d59a3e4d2718f2 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Mon, 9 Oct 2023 06:17:15 +0000
Subject: [PATCH 08/56] rename

---
 .../{decoder_mha => decoder_multihead_attention}/array_ops.h      | 0
 .../decoder_multihead_attention.cu                                | 0
 .../decoder_multihead_attention.h                                 | 0
 .../decoder_multihead_attention_params.h                          | 0
 .../decoder_multihead_attention_template.h                        | 0
 .../{decoder_mha => decoder_multihead_attention}/iterator.h       | 0
 .../{decoder_mha => decoder_multihead_attention}/kv_cache.cu      | 0
 .../{decoder_mha => decoder_multihead_attention}/kv_cache.h       | 0
 .../test_decoder_multihead_attention.cu                           | 0
 .../{decoder_mha => decoder_multihead_attention}/test_utils.cu    | 0
 .../{decoder_mha => decoder_multihead_attention}/test_utils.h     | 0
 .../{decoder_mha => decoder_multihead_attention}/thread_map.h     | 0
 12 files changed, 0 insertions(+), 0 deletions(-)
 rename src/turbomind/kernels/{decoder_mha => decoder_multihead_attention}/array_ops.h (100%)
 rename src/turbomind/kernels/{decoder_mha => decoder_multihead_attention}/decoder_multihead_attention.cu (100%)
 rename src/turbomind/kernels/{decoder_mha => decoder_multihead_attention}/decoder_multihead_attention.h (100%)
 rename src/turbomind/kernels/{decoder_mha => decoder_multihead_attention}/decoder_multihead_attention_params.h (100%)
 rename src/turbomind/kernels/{decoder_mha => decoder_multihead_attention}/decoder_multihead_attention_template.h (100%)
 rename src/turbomind/kernels/{decoder_mha => decoder_multihead_attention}/iterator.h (100%)
 rename src/turbomind/kernels/{decoder_mha => decoder_multihead_attention}/kv_cache.cu (100%)
 rename src/turbomind/kernels/{decoder_mha => decoder_multihead_attention}/kv_cache.h (100%)
 rename src/turbomind/kernels/{decoder_mha => decoder_multihead_attention}/test_decoder_multihead_attention.cu (100%)
 rename src/turbomind/kernels/{decoder_mha => decoder_multihead_attention}/test_utils.cu (100%)
 rename src/turbomind/kernels/{decoder_mha => decoder_multihead_attention}/test_utils.h (100%)
 rename src/turbomind/kernels/{decoder_mha => decoder_multihead_attention}/thread_map.h (100%)

diff --git a/src/turbomind/kernels/decoder_mha/array_ops.h b/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
similarity index 100%
rename from src/turbomind/kernels/decoder_mha/array_ops.h
rename to src/turbomind/kernels/decoder_multihead_attention/array_ops.h
diff --git a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
similarity index 100%
rename from src/turbomind/kernels/decoder_mha/decoder_multihead_attention.cu
rename to src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
diff --git a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.h
similarity index 100%
rename from src/turbomind/kernels/decoder_mha/decoder_multihead_attention.h
rename to src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.h
diff --git a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_params.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
similarity index 100%
rename from src/turbomind/kernels/decoder_mha/decoder_multihead_attention_params.h
rename to src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
diff --git a/src/turbomind/kernels/decoder_mha/decoder_multihead_attention_template.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
similarity index 100%
rename from src/turbomind/kernels/decoder_mha/decoder_multihead_attention_template.h
rename to src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
diff --git a/src/turbomind/kernels/decoder_mha/iterator.h b/src/turbomind/kernels/decoder_multihead_attention/iterator.h
similarity index 100%
rename from src/turbomind/kernels/decoder_mha/iterator.h
rename to src/turbomind/kernels/decoder_multihead_attention/iterator.h
diff --git a/src/turbomind/kernels/decoder_mha/kv_cache.cu b/src/turbomind/kernels/decoder_multihead_attention/kv_cache.cu
similarity index 100%
rename from src/turbomind/kernels/decoder_mha/kv_cache.cu
rename to src/turbomind/kernels/decoder_multihead_attention/kv_cache.cu
diff --git a/src/turbomind/kernels/decoder_mha/kv_cache.h b/src/turbomind/kernels/decoder_multihead_attention/kv_cache.h
similarity index 100%
rename from src/turbomind/kernels/decoder_mha/kv_cache.h
rename to src/turbomind/kernels/decoder_multihead_attention/kv_cache.h
diff --git a/src/turbomind/kernels/decoder_mha/test_decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
similarity index 100%
rename from src/turbomind/kernels/decoder_mha/test_decoder_multihead_attention.cu
rename to src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
diff --git a/src/turbomind/kernels/decoder_mha/test_utils.cu b/src/turbomind/kernels/decoder_multihead_attention/test_utils.cu
similarity index 100%
rename from src/turbomind/kernels/decoder_mha/test_utils.cu
rename to src/turbomind/kernels/decoder_multihead_attention/test_utils.cu
diff --git a/src/turbomind/kernels/decoder_mha/test_utils.h b/src/turbomind/kernels/decoder_multihead_attention/test_utils.h
similarity index 100%
rename from src/turbomind/kernels/decoder_mha/test_utils.h
rename to src/turbomind/kernels/decoder_multihead_attention/test_utils.h
diff --git a/src/turbomind/kernels/decoder_mha/thread_map.h b/src/turbomind/kernels/decoder_multihead_attention/thread_map.h
similarity index 100%
rename from src/turbomind/kernels/decoder_mha/thread_map.h
rename to src/turbomind/kernels/decoder_multihead_attention/thread_map.h

From a0f24509e663559b38006bc6c1119d173cce5d5b Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Mon, 9 Oct 2023 10:31:03 +0000
Subject: [PATCH 09/56] GQA support

---
 src/turbomind/kernels/CMakeLists.txt          |  2 +-
 .../decoder_multihead_attention.cu            | 41 +++++++++++----
 .../decoder_multihead_attention.h             |  4 +-
 .../decoder_multihead_attention_template.h    | 52 ++++++++-----------
 .../decoder_multihead_attention/iterator.h    |  1 -
 .../test_decoder_multihead_attention.cu       |  2 +-
 .../llama/LlamaContextAttentionLayer.cc       |  2 +-
 .../llama/LlamaDecoderSelfAttentionLayer.cc   |  4 +-
 8 files changed, 61 insertions(+), 47 deletions(-)

diff --git a/src/turbomind/kernels/CMakeLists.txt b/src/turbomind/kernels/CMakeLists.txt
index 473e579c45..f96da72200 100644
--- a/src/turbomind/kernels/CMakeLists.txt
+++ b/src/turbomind/kernels/CMakeLists.txt
@@ -71,4 +71,4 @@ set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE  ON)
 set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)
 
 add_subdirectory(gemm_s_f16)
-add_subdirectory(decoder_mha)
\ No newline at end of file
+add_subdirectory(decoder_multihead_attention)
\ No newline at end of file
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
index 113780d65b..ad9acf38ea 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
@@ -23,26 +23,49 @@ bool Dump()
     return true;
 }
 
-template<typename T, int HeadDim>
-void LaunchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& params)
+template<typename T, int HeadDim, int HeadPerCta>
+void InvokeDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& params)
 {
-    using MHAType = DecoderMultiHeadAttentionKernel<T, 1, HeadDim, 16, HeadDim, 2048, 6>;
+    using MHAType = DecoderMultiHeadAttentionKernel<T, HeadPerCta, HeadDim, 16, HeadDim, 2048, 6>;
 
     [[maybe_unused]] static const bool init = Dump<MHAType>();
 
     dim3 block(MHAType::kWarpCount * WARP_SIZE);
-    dim3 grid(params.num_kv_heads, params.batch_size);
+    dim3 grid(params.num_heads / HeadPerCta, params.batch_size);
 
-    const size_t kDynamicSmemSize = MHAType::GetDynamicSmemSize(0);
+    static const size_t kDynSmemSize = MHAType::GetDynamicSmemSize();
     // std::cout << "dynamic shared memory size: " << kDynamicSmemSize << "\n";
 
     cudaFuncSetAttribute(
-        decoder_multihead_attention<MHAType>, cudaFuncAttributeMaxDynamicSharedMemorySize, kDynamicSmemSize);
+        decoder_multihead_attention<MHAType>, cudaFuncAttributeMaxDynamicSharedMemorySize, kDynSmemSize);
 
-    decoder_multihead_attention<MHAType><<<grid, block, kDynamicSmemSize>>>(params);
+    decoder_multihead_attention<MHAType><<<grid, block, kDynSmemSize>>>(params);
 }
 
-template void LaunchDecoderMultiheadAttention<half, 128>(const DecoderMultiHeadAttentionParams<half>& params);
-template void LaunchDecoderMultiheadAttention<float, 128>(const DecoderMultiHeadAttentionParams<float>& params);
+template<typename T>
+void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& params)
+{
+    static constexpr int HeadDim = 128;
+
+    FT_CHECK(params.size_per_head == HeadDim);
+
+    int group_size = params.num_heads / params.num_kv_heads;
+
+    if (group_size % 8 == 0) {
+        InvokeDecoderMultiheadAttention<T, HeadDim, 8>(params);
+    }
+    else if (group_size % 4 == 0) {
+        InvokeDecoderMultiheadAttention<T, HeadDim, 4>(params);
+    }
+    else if (group_size % 2 == 0) {
+        InvokeDecoderMultiheadAttention<T, HeadDim, 2>(params);
+    }
+    else {
+        InvokeDecoderMultiheadAttention<T, HeadDim, 1>(params);
+    }
+}
+
+template void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<half>& params);
+template void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<float>& params);
 
 }  // namespace turbomind
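
The dispatcher above picks the largest HeadPerCta that divides the GQA group size (num_heads / num_kv_heads), so all query heads sharing a KV head can be packed into one CTA. A small self-contained illustration of that selection; the helper name and head counts are examples, not code from this patch:

#include <cstdio>

static int PickHeadPerCta(int num_heads, int num_kv_heads)
{
    const int group_size = num_heads / num_kv_heads;
    if (group_size % 8 == 0) return 8;
    if (group_size % 4 == 0) return 4;
    if (group_size % 2 == 0) return 2;
    return 1;  // plain MHA (num_heads == num_kv_heads) ends up here
}

int main()
{
    std::printf("%d\n", PickHeadPerCta(32, 32));  // MHA, group size 1 -> 1
    std::printf("%d\n", PickHeadPerCta(64, 8));   // GQA, group size 8 -> 8
    std::printf("%d\n", PickHeadPerCta(40, 8));   // GQA, group size 5 -> 1
    return 0;
}
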
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.h
index cdee4af1cd..f4eca0617c 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.h
@@ -3,7 +3,7 @@
 
 namespace turbomind {
 
-template<typename T, int HeadDim>
-void LaunchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& params);
+template<typename T>
+void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& params);
 
 }
\ No newline at end of file
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
index 20669dad91..9ca22226c4 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
@@ -36,7 +36,7 @@ struct DecoderMultiHeadAttentionKernel {
     using MapKv  = ThreadMapKv<kMaxHeadDim, kKeyPerIter, kVecKvSize, kThreadPerKey, kWarpCount>;
     using IterKv = turbomind::Iterator<T, MapKv, SliceLen, kStages, kUseBlockIter>;
 
-    static size_t GetDynamicSmemSize(int)
+    static constexpr size_t GetDynamicSmemSize()
     {
         size_t smem_kv_cache = IterKv::kSmemByteSize;
         size_t smem_kv_align = 128;
@@ -45,9 +45,6 @@ struct DecoderMultiHeadAttentionKernel {
         return smem_kv_align + smem_kv_cache + std::max(smem_qk, smem_pr);
     }
 
-    // using AccumType   = float;
-    // using ComputeType = float;
-
     using QkAccumType   = float;
     using QkComputeType = float;
 
@@ -70,6 +67,9 @@ struct DecoderMultiHeadAttentionKernel {
     int warp_id_;
     int lane_id_;
 
+    int  kv_head_idx_;
+    bool is_gqa_leader_;
+
     int timestep_;
     T*  k_cache_;  // [S, D]
     T*  v_cache_;  // [S, D]
@@ -105,11 +105,15 @@ struct DecoderMultiHeadAttentionKernel {
         smem_red_max_ = smem.red_max;
         smem_red_sum_ = smem.red_sum;
 
-        head_idx_  = blockIdx.x;
+        head_idx_  = blockIdx.x * kHeadPerCta;
         batch_idx_ = blockIdx.y;
         warp_id_   = threadIdx.x / WARP_SIZE;
         lane_id_   = threadIdx.x % WARP_SIZE;
 
+        const int gqa_group_size = params.num_heads / params.num_kv_heads;
+        kv_head_idx_             = head_idx_ / gqa_group_size;
+        is_gqa_leader_           = head_idx_ % gqa_group_size == 0;
+
         timestep_ = params_.per_sample_length[batch_idx_];
 
         if constexpr (kUseBlockIter) {
@@ -118,23 +122,12 @@ struct DecoderMultiHeadAttentionKernel {
         }
         else {
             k_cache_ = (T*)params_.per_sample_k_cache[batch_idx_] + params.layer_offset
-                       + head_idx_ * params_.max_seq_len * params_.size_per_head;
+                       + kv_head_idx_ * params_.max_seq_len * params_.size_per_head;
             v_cache_ = (T*)params_.per_sample_v_cache[batch_idx_] + params.layer_offset
-                       + head_idx_ * params_.max_seq_len * params_.size_per_head;
+                       + kv_head_idx_ * params_.max_seq_len * params_.size_per_head;
         }
     }
 
-    // [kkkk][vvvv][kkkk][vvvv][kkkk][vvvv][k][v]
-    // __device__ int is_last_iter_of_slice(int iter, int full, int partial)
-    // {
-    //     if (iter < full) {
-    //         return (iter + 1) % kIterPerSlice == 0;
-    //     }
-    //     else {
-    //         return (iter - full + 1) % partial == 0;
-    //     }
-    // }
-
     __device__ void Prolugue()
     {
         // - Each warp is handling a row of Q
@@ -147,7 +140,7 @@ struct DecoderMultiHeadAttentionKernel {
         using MapQ = ThreadMapQ<kMaxHeadDim, kHeadPerCta, kVecQSize, kWarpCount>;
 
         static constexpr int kQVecPerThread  = MapQ::kIterC;
-        static constexpr int kQHeadPerThread = MapQ::kIterS;  // > 1 when #warp < #head
+        static constexpr int kQHeadPerThread = MapQ::kIterS;  // > 1 when #warp < kHeadPerCta
 
         static_assert(kQVecPerThread == 1);
 
@@ -169,8 +162,8 @@ struct DecoderMultiHeadAttentionKernel {
             int qi = offset.y + s;
             Ldg(frag_Q[s], &params_.q[batch_idx_ * params_.stride + (head_idx_ + qi) * kHeadDim + di]);
         }
-        Ldg(frag_K, &params_.k[batch_idx_ * params_.stride + head_idx_ * kHeadDim + offset.x]);
-        Ldg(frag_V, &params_.v[batch_idx_ * params_.stride + head_idx_ * kHeadDim + offset.x]);
+        Ldg(frag_K, &params_.k[batch_idx_ * params_.stride + kv_head_idx_ * kHeadDim + offset.x]);
+        Ldg(frag_V, &params_.v[batch_idx_ * params_.stride + kv_head_idx_ * kHeadDim + offset.x]);
 
         if (params_.q_bias) {
             // load biases
@@ -183,8 +176,8 @@ struct DecoderMultiHeadAttentionKernel {
             }
             VecQ bias_K;
             VecQ bias_V;
-            Ldg(bias_K, &params_.k_bias[head_idx_ * kHeadDim + offset.x]);
-            Ldg(bias_V, &params_.v_bias[head_idx_ * kHeadDim + offset.x]);
+            Ldg(bias_K, &params_.k_bias[kv_head_idx_ * kHeadDim + offset.x]);
+            Ldg(bias_V, &params_.v_bias[kv_head_idx_ * kHeadDim + offset.x]);
 
             using namespace ops;
             // apply biases
@@ -222,7 +215,7 @@ struct DecoderMultiHeadAttentionKernel {
         }
 
         // store
-        if (warp_id_ == 0) {
+        if (warp_id_ == 0 && is_gqa_leader_) {
             if constexpr (kUseBlockIter) {
                 int block_index  = timestep_ / params_.kv_cache_block_size;
                 int block_offset = timestep_ % params_.kv_cache_block_size;
@@ -230,9 +223,9 @@ struct DecoderMultiHeadAttentionKernel {
                 //     printf("%d %d %p %p\n", block_index, block_offset, k_cache_ptrs_, v_cache_ptrs_);
                 // }
                 k_cache_ = (T*)k_cache_ptrs_[block_index] + params_.layer_offset
-                           + head_idx_ * params_.kv_cache_block_size * kHeadDim;
+                           + kv_head_idx_ * params_.kv_cache_block_size * kHeadDim;
                 v_cache_ = (T*)v_cache_ptrs_[block_index] + params_.layer_offset
-                           + head_idx_ * params_.kv_cache_block_size * kHeadDim;
+                           + kv_head_idx_ * params_.kv_cache_block_size * kHeadDim;
                 Store(&k_cache_[block_offset * kHeadDim + offset.x], frag_K);
                 Store(&v_cache_[block_offset * kHeadDim + offset.x], frag_V);
             }
@@ -299,7 +292,7 @@ struct DecoderMultiHeadAttentionKernel {
             iter_K = {k_cache_ptrs_,
                       params_.kv_cache_block_size,
                       params_.layer_offset,
-                      head_idx_,
+                      kv_head_idx_,
                       smem_Kv_,
                       step,
                       step + iter_length,
@@ -488,7 +481,7 @@ struct DecoderMultiHeadAttentionKernel {
             iter_V = {v_cache_ptrs_,
                       params_.kv_cache_block_size,
                       params_.layer_offset,
-                      head_idx_,
+                      kv_head_idx_,
                       smem_Kv_,
                       step,
                       step + iter_length,
@@ -680,7 +673,6 @@ struct DecoderMultiHeadAttentionKernel {
         using MapQ = ThreadMapQ<kMaxHeadDim, kHeadPerCta, kVecQSize, kWarpCount>;
 
         static constexpr int kQkvHeadPerThread = MapQ::kIterS;
-        static_assert(kQkvHeadPerThread == 1);
 
         int2 offset = MapQ::get_offset(warp_id_, lane_id_);
 
@@ -697,7 +689,7 @@ struct DecoderMultiHeadAttentionKernel {
             // float scale = 1.f;
             using namespace ops;
             VecQFloat frag_O = (VecQFloat&)smem_O_[qi * kMaxHeadDim + di] * scale;
-            /// FIXME: `(head_idx_ + qi)` doesn't look right
+
             Store(&params_.out[batch_idx_ * params_.num_heads * kHeadDim + (head_idx_ + qi) * kHeadDim + di],
                   cast<Dtype>(frag_O));
         }
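
With kHeadPerCta query heads per block, head_idx_ now starts at blockIdx.x * kHeadPerCta, every head in a GQA group reads the same KV head, and only the group leader appends the new K/V vectors to the cache. The index bookkeeping as a stand-alone sketch; the head counts in the comment are an assumption for illustration:

struct GqaMapping {
    int  kv_head_idx;  // KV head shared by the whole query-head group
    bool is_leader;    // exactly one query head per group writes the cache
};

inline GqaMapping MapHead(int head_idx, int num_heads, int num_kv_heads)
{
    const int group_size = num_heads / num_kv_heads;
    return {head_idx / group_size, head_idx % group_size == 0};
}

// e.g. num_heads = 32, num_kv_heads = 8: query heads 0..3 all map to KV head 0,
// and only query head 0 stores the freshly computed K/V entry.
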
diff --git a/src/turbomind/kernels/decoder_multihead_attention/iterator.h b/src/turbomind/kernels/decoder_multihead_attention/iterator.h
index 08f939827c..deb2939488 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/iterator.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/iterator.h
@@ -255,7 +255,6 @@ struct Iterator {
         const int     smem_int_ptr = cast_smem_ptr_to_uint(dst);
         constexpr int cp_size      = sizeof(AccessType);
         static_assert(cp_size == 16);
-        // cp.async.cg.shared.global.L2::256B
         asm volatile("{\n"
                      "  .reg .pred p;\n"
                      "  setp.ne.b32 p, %0, 0;\n"
diff --git a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
index 7397443301..67f2984f18 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
@@ -244,7 +244,7 @@ int main(int argc, char* argv[])
     std::vector<thrust::universal_vector<half>> outputs;
 
     for (int i = 0; i < std::max(kTestIter, 10); ++i) {
-        LaunchDecoderMultiheadAttention<half, 128>(params);
+        DispatchDecoderMultiheadAttention<half>(params);
         if (auto err = cudaGetLastError(); err != cudaSuccess) {
             std::cout << cudaGetErrorString(err) << "\n";
             return -1;
diff --git a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
index a3b358d9ed..a71d4e7b7c 100644
--- a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
@@ -21,7 +21,7 @@
 
 #include "src/turbomind/models/llama/LlamaContextAttentionLayer.h"
 #include "src/turbomind/kernels/bert_preprocess_kernels.h"
-#include "src/turbomind/kernels/decoder_mha/kv_cache.h"
+#include "src/turbomind/kernels/decoder_multihead_attention/kv_cache.h"
 #include "src/turbomind/kernels/unfused_attention_kernels.h"
 #include "src/turbomind/macro.h"
 #include "src/turbomind/models/llama/LlamaNcclGuard.h"
diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
index 135ed8f074..ad4d4fafdc 100644
--- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
@@ -19,7 +19,7 @@
 // https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/layers/attention_layers/DecoderSelfAttentionLayer.cc
 #include "src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h"
 #include "src/turbomind/kernels/decoder_masked_multihead_attention.h"
-#include "src/turbomind/kernels/decoder_mha/decoder_multihead_attention.h"
+#include "src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.h"
 #include "src/turbomind/macro.h"
 #include "src/turbomind/models/llama/LlamaNcclGuard.h"
 #include "src/turbomind/models/llama/llama_kernels.h"
@@ -134,7 +134,7 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
     params.rotary_embedding_dim  = size_per_head_;
     params.rotary_embedding_base = 10000.f;
 
-    LaunchDecoderMultiheadAttention<T, 128>(params);
+    DispatchDecoderMultiheadAttention<T>(params);
 
     linear_.forward(hidden_features_data, context_buf_, batch_size, weights->output);
 

From 139f71d2b4a18dc04802ab2aa91563e86b2a596e Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Mon, 9 Oct 2023 11:23:25 +0000
Subject: [PATCH 10/56] fix context length

---
 src/turbomind/models/llama/LlamaBatch.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 882da30e22..e828869eae 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -358,8 +358,6 @@ bool LlamaBatch<T>::Initialize()
             });
         }
 
-        Copy(state_->h_context_length, batch_size, context_length_buf_);
-
         if (1) {
             std::vector cu_block_cnts(h_cu_block_counts_, h_cu_block_counts_ + batch_size + 1);
             dbg(cu_block_cnts);
@@ -706,6 +704,7 @@ auto LlamaBatch<T>::InitializeGeneration() -> GenerationState
         }
     }
 
+    Copy(state_->h_context_length, batch_size, context_length_buf_);  // also referenced in `SetOutputTensors`
     Copy(context_length_buf_, batch_size, sequence_lengths_);
     // `sequence_lengths_` will be increased by dynamic decode
     // note that in decoder and in output "sequence length" has different semantic
@@ -860,6 +859,7 @@ void LlamaBatch<T>::ContextDecode()
 
     const int context_decode_count = batch_size - base;
 
+    Copy(state_->h_context_length, batch_size, context_length_buf_);
     Copy(h_input_length_buf_, batch_size, input_length_buf_);
 
     check_cuda_error(cudaStreamSynchronize(stream_));

From 94a5d4a7e94ce07f91f910ef0ff4508a0c84835c Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Mon, 9 Oct 2023 12:29:02 +0000
Subject: [PATCH 11/56] GQA dispatch

---
 .../decoder_multihead_attention.cu            | 26 ++++++++++++-------
 1 file changed, 16 insertions(+), 10 deletions(-)

diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
index ad9acf38ea..de0b9dd248 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
@@ -49,16 +49,22 @@ void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>&
 
     FT_CHECK(params.size_per_head == HeadDim);
 
-    int group_size = params.num_heads / params.num_kv_heads;
-
-    if (group_size % 8 == 0) {
-        InvokeDecoderMultiheadAttention<T, HeadDim, 8>(params);
-    }
-    else if (group_size % 4 == 0) {
-        InvokeDecoderMultiheadAttention<T, HeadDim, 4>(params);
-    }
-    else if (group_size % 2 == 0) {
-        InvokeDecoderMultiheadAttention<T, HeadDim, 2>(params);
+    if constexpr (std::is_same_v<T, half>) {
+
+        int group_size = params.num_heads / params.num_kv_heads;
+
+        if (group_size % 8 == 0) {
+            InvokeDecoderMultiheadAttention<T, HeadDim, 8>(params);
+        }
+        else if (group_size % 4 == 0) {
+            InvokeDecoderMultiheadAttention<T, HeadDim, 4>(params);
+        }
+        else if (group_size % 2 == 0) {
+            InvokeDecoderMultiheadAttention<T, HeadDim, 2>(params);
+        }
+        else {
+            InvokeDecoderMultiheadAttention<T, HeadDim, 1>(params);
+        }
     }
     else {
         InvokeDecoderMultiheadAttention<T, HeadDim, 1>(params);
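
Wrapping the group-size dispatch in `if constexpr (std::is_same_v<T, half>)` means the multi-head-per-CTA instantiations are only generated for half; other element types keep the single-head path. A minimal sketch of the pattern with a toy launcher, assumed here only for illustration:

#include <cuda_fp16.h>
#include <type_traits>

template<typename T, int HeadPerCta>
void LaunchToy(int /*group_size*/) {}  // stand-in for InvokeDecoderMultiheadAttention

template<typename T>
void DispatchToy(int group_size)
{
    if constexpr (std::is_same_v<T, half>) {
        if (group_size % 8 == 0)      LaunchToy<T, 8>(group_size);
        else if (group_size % 4 == 0) LaunchToy<T, 4>(group_size);
        else if (group_size % 2 == 0) LaunchToy<T, 2>(group_size);
        else                          LaunchToy<T, 1>(group_size);
    }
    else {
        LaunchToy<T, 1>(group_size);  // e.g. float only ever uses one head per CTA
    }
}
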

From 68aa13560460dfd92b418128bd42b14390fbf57e Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Wed, 11 Oct 2023 10:24:07 +0000
Subject: [PATCH 12/56] kv8

---
 .../decoder_multihead_attention/array_ops.h   | 155 ++++++-
 .../decoder_multihead_attention.cu            |  41 +-
 .../decoder_multihead_attention_params.h      |  33 +-
 .../decoder_multihead_attention_template.h    | 131 ++++--
 .../decoder_multihead_attention/iterator.h    |  10 +-
 .../decoder_multihead_attention/kv_cache.cu   | 261 ++++++++---
 .../decoder_multihead_attention/kv_cache.h    |  27 +-
 .../test_decoder_multihead_attention.cu       |  24 +-
 .../llama/LlamaContextAttentionLayer.cc       |  36 +-
 .../models/llama/LlamaDecoderLayerWeight.cc   |   2 +-
 .../llama/LlamaDecoderSelfAttentionLayer.cc   |   3 +
 src/turbomind/models/llama/llama_kernels.cu   | 437 +++++++++---------
 src/turbomind/models/llama/llama_kernels.h    |   4 +-
 13 files changed, 759 insertions(+), 405 deletions(-)

diff --git a/src/turbomind/kernels/decoder_multihead_attention/array_ops.h b/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
index a157d15cac..99d6135b82 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
@@ -1,6 +1,6 @@
 #pragma once
 
-#include "../gemm_s_f16/common.h"
+#include "src/turbomind/kernels/gemm_s_f16/common.h"
 #include <cfloat>
 #include <limits>
 
@@ -200,8 +200,8 @@ inline __device__ void Lds(Array<T, N>& dst, const T* src)
     }
 }
 
-template<typename Accum, typename Compute, int kThreadGroupSize, typename T, int N, int V>
-inline __device__ Accum qk_dot(const Array<T, N> (&q)[V], const Array<T, N> (&k)[V])
+template<typename Accum, typename Compute, int kThreadGroupSize, typename Tq, typename Tk, int N, int V>
+inline __device__ Accum qk_dot(const Array<Tq, N> (&q)[V], const Array<Tk, N> (&k)[V])
 {
     Accum accum{};
 
@@ -221,8 +221,8 @@ inline __device__ Accum qk_dot(const Array<T, N> (&q)[V], const Array<T, N> (&k)
     return accum;
 }
 
-template<typename Accum, typename Compute, int kThreadGroupSize, typename T, int N>
-inline __device__ Accum qk_dot(const Array<T, N>& q, const Array<T, N>& k)
+template<typename Accum, typename Compute, int kThreadGroupSize, typename Tq, typename Tk, int N>
+inline __device__ Accum qk_dot(const Array<Tq, N>& q, const Array<Tk, N>& k)
 {
     Accum accum{};
 
@@ -314,4 +314,149 @@ inline __device__ Array<T, N> blockSum(Array<T, N> val, T* smem_red, int warp_id
     return val;
 }
 
+//////////////////////////////////////////////////////////////////////////////////////////////////
+
+// generic case for floating point -> floating point / integer -> integer conversion
+template<typename Ti, typename To, typename = void>
+struct ConvertKvCache {
+    __device__ __host__ ConvertKvCache(float, float) {}
+    template<int N>
+    inline __device__ auto operator()(const Array<Ti, N>& vi) const -> Array<To, N>
+    {
+        Array<To, N> vo;
+        PRAGMA_UNROLL
+        for (int i = 0; i < N; ++i) {
+            vo[i] = (To)vi[i];
+        }
+        return vo;
+    }
+};
+
+// generic case for converting to same type, bypass
+template<typename T>
+struct ConvertKvCache<T, T> {
+    __device__ __host__ ConvertKvCache(float, float) {}
+    template<int N>
+    inline __device__ auto operator()(const Array<T, N>& v) const -> Array<T, N>
+    {
+        return v;
+    }
+};
+
+template<typename Ti>
+struct ConvertKvCache<Ti, int8_t> {
+
+    float scale_;
+    float zero_;
+
+    __device__ __host__ ConvertKvCache(float scale, float zero): scale_(scale), zero_(zero) {}
+
+    inline __device__ uint8_t round(float x) const
+    {
+        uint32_t y;
+        asm("cvt.rni.sat.u8.f32 %0, %1;\n" : "=r"(y) : "f"(x));
+        return y;
+    }
+
+    template<int N>
+    inline __device__ auto operator()(const Array<Ti, N>& vi) const -> Array<int8_t, N>
+    {
+        Array<int8_t, N> vo;
+        PRAGMA_UNROLL
+        for (int i = 0; i < N; ++i) {
+            // convert to unsigned int by offsetting +128

+            (uint8_t&)vo[i] = round(((float)vi[i] - zero_) / scale_ + 128.f);
+        }
+        return vo;
+    }
+};
+
+inline __device__ Array<float, 4> fast_i2f_f32_s8(const Array<int8_t, 4>& x)
+{
+    union {
+        Array<float, 4>    f32x4;
+        Array<uint32_t, 4> u32x4;
+    };
+
+    auto& i8s = (const uint32_t&)x;
+
+    // 00000000111111112222222233333333
+    // 01234567012345670123456701234567
+    // SEEEEEEEEMMMMMMMMMMMMMMMMMMMMMMM
+    // 0????????_______XXXXXXXX________
+    // (1 + x / 2^15) * 2^(e - 127) -> e - 127 == 15 -> e = 142
+    //                                       7 6 5 4
+    static constexpr uint32_t f32_magic = 0x47000000;  // 2^15 = 32768
+    static constexpr uint32_t m0        = 0x7604;
+    static constexpr uint32_t m1        = 0x7614;
+    static constexpr uint32_t m2        = 0x7624;
+    static constexpr uint32_t m3        = 0x7634;
+
+    asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[0]) : "r"(i8s), "n"(f32_magic), "n"(m0));
+    asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[1]) : "r"(i8s), "n"(f32_magic), "n"(m1));
+    asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[2]) : "r"(i8s), "n"(f32_magic), "n"(m2));
+    asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[3]) : "r"(i8s), "n"(f32_magic), "n"(m3));
+
+    if (0) {  // fused with dequantization
+        PRAGMA_UNROLL
+        for (int i = 0; i < 4; ++i) {
+            f32x4[i] -= 32896.f;  // 32768 + 128
+        }
+    }
+
+    return f32x4;
+}
+
+template<>
+struct ConvertKvCache<int8_t, float> {
+
+    float scale_;
+    float zero_;
+
+    __device__ __host__ ConvertKvCache(float scale, float zero): scale_(scale), zero_(zero) {}
+
+    template<int N>
+    inline __device__ auto operator()(const Array<int8_t, N>& vi) const -> Array<float, N>
+    {
+        Array<float, N> vo;
+        PRAGMA_UNROLL
+        for (int i = 0; i < N; i += 4) {
+            auto& vec = (Array<float, 4>&)vo[i];
+            vec       = fast_i2f_f32_s8((const Array<int8_t, 4>&)vi[i]);
+            PRAGMA_UNROLL
+            for (int j = 0; j < 4; ++j) {
+                // vec[j] = vec[j] * scale + zero;
+                vec[j] = vec[j] * scale_ + (zero_ - 32896.f * scale_);
+            }
+        }
+        return vo;
+    }
+};
+
+template<>
+struct ConvertKvCache<int8_t, half> {
+
+    float scale_;
+    float zero_;
+
+    __device__ __host__ ConvertKvCache(float scale, float zero): scale_(scale), zero_(zero) {}
+
+    template<int N>
+    inline __device__ auto operator()(const Array<int8_t, N>& vi) const -> Array<half, N>
+    {
+        Array<half, N> vo;
+        PRAGMA_UNROLL
+        for (int i = 0; i < N; i += 4) {
+            auto& vec = (Array<half, 4>&)vo[i];
+            auto  tmp = fast_i2f_f32_s8((const Array<int8_t, 4>&)vi[i]);
+            PRAGMA_UNROLL
+            for (int j = 0; j < 4; ++j) {
+                // vec[j] = half(tmp[j] * scale + zero);
+                vec[j] = half(tmp[j] * scale_ + (zero_ - 32896.f * scale_));
+            }
+        }
+        return vo;
+    }
+};
+
 }  // namespace turbomind
\ No newline at end of file
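
The new ConvertKvCache specializations form a matched pair: the store side biases values by +128 into unsigned 8-bit range, and the load side rebuilds floats with a PRMT/magic-exponent trick that yields 32768 + byte, folding the combined 32768 + 128 = 32896 bias into the affine correction. A host-side sketch of the arithmetic (no PRMT; rounding-mode details are approximated), included only to make the round trip explicit:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
    const float scale = 0.05f, zero = 0.f;
    const float x = 1.234f;

    // quantize: mirrors the cvt.rni.sat.u8.f32 path up to rounding-mode details
    const float   biased = (x - zero) / scale + 128.f;
    const uint8_t q      = (uint8_t)std::min(255.f, std::max(0.f, std::nearbyint(biased)));

    // dequantize: the magic-number path produces 32768 + stored byte as a float
    const float as_f32 = 32768.f + (float)q;
    const float y      = as_f32 * scale + (zero - 32896.f * scale);

    std::printf("x=%f  q=%u  y=%f\n", x, (unsigned)q, y);  // y ~= x within one quantization step
    return 0;
}
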
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
index de0b9dd248..91dc207876 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
@@ -1,4 +1,5 @@
 #include "decoder_multihead_attention_template.h"
+#include "src/turbomind/models/llama/llama_utils.h"
 
 #include <iostream>
 
@@ -23,10 +24,11 @@ bool Dump()
     return true;
 }
 
-template<typename T, int HeadDim, int HeadPerCta>
-void InvokeDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& params)
+template<typename T, typename Tkv, int HeadDim, int HeadPerCta>
+void invokeDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& params)
 {
-    using MHAType = DecoderMultiHeadAttentionKernel<T, HeadPerCta, HeadDim, 16, HeadDim, 2048, 6>;
+    // 2048_32x6 ~ 64k smem
+    using MHAType = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 2048, 6>;
 
     [[maybe_unused]] static const bool init = Dump<MHAType>();
 
@@ -51,24 +53,29 @@ void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>&
 
     if constexpr (std::is_same_v<T, half>) {
 
-        int group_size = params.num_heads / params.num_kv_heads;
-
-        if (group_size % 8 == 0) {
-            InvokeDecoderMultiheadAttention<T, HeadDim, 8>(params);
-        }
-        else if (group_size % 4 == 0) {
-            InvokeDecoderMultiheadAttention<T, HeadDim, 4>(params);
-        }
-        else if (group_size % 2 == 0) {
-            InvokeDecoderMultiheadAttention<T, HeadDim, 2>(params);
+        //     int group_size = params.num_heads / params.num_kv_heads;
+
+        //     if (group_size % 8 == 0) {
+        //         invokeDecoderMultiheadAttention<T, HeadDim, 8>(params);
+        //     }
+        //     else if (group_size % 4 == 0) {
+        //         invokeDecoderMultiheadAttention<T, HeadDim, 4>(params);
+        //     }
+        //     else if (group_size % 2 == 0) {
+        //         invokeDecoderMultiheadAttention<T, HeadDim, 2>(params);
+        //     }
+        //     else {
+        //         invokeDecoderMultiheadAttention<T, HeadDim, 1>(params);
+        //     }
+        // }
+        // else {
+        if (params.quant_policy & QuantPolicy::kCacheKVInt8) {
+            invokeDecoderMultiheadAttention<T, int8_t, HeadDim, 1>(params);
         }
         else {
-            InvokeDecoderMultiheadAttention<T, HeadDim, 1>(params);
+            invokeDecoderMultiheadAttention<T, T, HeadDim, 1>(params);
         }
     }
-    else {
-        InvokeDecoderMultiheadAttention<T, HeadDim, 1>(params);
-    }
 }
 
 template void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<half>& params);
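
The dispatch above turns the runtime `quant_policy` bit into a compile-time KV element type: `Tkv = int8_t` when `QuantPolicy::kCacheKVInt8` is set, otherwise `Tkv = T`. A minimal host-only sketch of that dispatch shape (the policy constant's value and the `RunKernel`/`Dispatch` names are placeholders, not the library API):

    #include <cstdint>
    #include <cstdio>

    constexpr int kCacheKVInt8 = 0x4;  // placeholder bit; the real value lives in llama_utils.h

    // Stand-in for invokeDecoderMultiheadAttention<T, Tkv, ...>: only the Tkv choice matters here.
    template<class T, class Tkv>
    void RunKernel()
    {
        std::printf("compute type %zu B, kv storage type %zu B\n", sizeof(T), sizeof(Tkv));
    }

    // Same shape as the dispatch above: a runtime bit test selects a compile-time KV element type.
    template<class T>
    void Dispatch(int quant_policy)
    {
        if (quant_policy & kCacheKVInt8) {
            RunKernel<T, std::int8_t>();  // int8 KV cache, dequantized on load
        }
        else {
            RunKernel<T, T>();            // KV cache kept in the compute type
        }
    }

    int main()
    {
        Dispatch<float>(0);             // -> Tkv = float
        Dispatch<float>(kCacheKVInt8);  // -> Tkv = int8_t
    }
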
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
index b055961289..199d0f4667 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
@@ -5,33 +5,33 @@ namespace turbomind {
 template<typename T>
 struct DecoderMultiHeadAttentionParams {
     // token-level buffers, [B, qH + 2kvH, D] or [B, kvH, D]
-    T*  out;
-    T*  q;
-    T*  k;
-    T*  v;
+    T* __restrict__ out;
+    T* __restrict__ q;
+    T* __restrict__ k;
+    T* __restrict__ v;
     int stride;
 
     // bias, [qH, D] or [kvH, D]
-    T* q_bias;
-    T* k_bias;
-    T* v_bias;
+    T* __restrict__ q_bias;
+    T* __restrict__ k_bias;
+    T* __restrict__ v_bias;
 
     // sequence-level buffers
-    const int*  per_sample_length;
-    const bool* finished;
+    const int* __restrict__ per_sample_length;
+    const bool* __restrict__ finished;
 
     // kv cache
-    void** per_sample_k_cache;  // [H, S, D]
-    void** per_sample_v_cache;  // [H, S, D]
+    void** __restrict__ per_sample_k_cache;  // [H, S, D]
+    void** __restrict__ per_sample_v_cache;  // [H, S, D]
     size_t layer_offset;
 
     /// cache layout M,[N,H,x,D]
     /// S: [s0/x, s1/x, s2/x, ..., sn-1/x], si <- block
     /// 1. [L,sum(S),H,x,D]
-    void** k_cache_block_ptrs;  // X,[H,x,D]
-    void** v_cache_block_ptrs;  // X,[H,x,D]
-    int*   cu_block_cnts;       // [B+1]
-    int    kv_cache_block_size;
+    void** __restrict__ k_cache_block_ptrs;  // X,[H,x,D]
+    void** __restrict__ v_cache_block_ptrs;  // X,[H,x,D]
+    int* __restrict__ cu_block_cnts;         // [B+1]
+    int kv_cache_block_size;
 
     // batch-level params
     int batch_size;
@@ -51,6 +51,9 @@ struct DecoderMultiHeadAttentionParams {
 
     // log(n) attention
     bool use_logn_attn;
+
+    int   quant_policy;
+    float kv_quant_params[4];
 };
 
 }  // namespace turbomind
\ No newline at end of file
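
The two fields added at the end of the struct carry the cache quantization setup into the kernel; judging from the converter construction in the kernel and the copy in `LlamaDecoderSelfAttentionLayer`, `kv_quant_params` holds {k_scale, k_zero, v_scale, v_zero}. A minimal sketch of populating them on the host (the helper name is hypothetical; only the two struct fields come from the header above):

    #include "src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h"

    // Sketch: fill the fields added above. `scales` is assumed to hold the per-layer
    // values in the order used by the kernel: {k_scale, k_zero, v_scale, v_zero}.
    template<typename T>
    void SetKvQuantParams(turbomind::DecoderMultiHeadAttentionParams<T>& params,
                          int                                            quant_policy,
                          const float (&scales)[4])
    {
        params.quant_policy = quant_policy;
        for (int i = 0; i < 4; ++i) {
            params.kv_quant_params[i] = scales[i];
        }
    }
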
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
index 9ca22226c4..a569e95144 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
@@ -4,15 +4,23 @@
 #include "iterator.h"
 #include "src/turbomind/kernels/gemm_s_f16/common.h"
 #include "thread_map.h"
+#include <cstdint>
 #include <cuda_pipeline_primitives.h>
+#include <type_traits>
 
 #include "decoder_multihead_attention_params.h"
 
 namespace turbomind {
 
-template<typename T, int HeadPerCta, int MaxHeadDim, int KeyPerIter, int HeadDim, int SliceLen, int Stages>
+template<typename T,
+         typename Tkv,
+         int HeadPerCta,
+         int MaxHeadDim,
+         int KeyPerIter,
+         int HeadDim,
+         int SliceLen,
+         int Stages>
 struct DecoderMultiHeadAttentionKernel {
-    using Dtype     = T;
     using ParamType = DecoderMultiHeadAttentionParams<T>;
 
     static constexpr int kWarpCount  = 4;
@@ -25,16 +33,16 @@ struct DecoderMultiHeadAttentionKernel {
     static constexpr int kSliceLen     = SliceLen;
     static constexpr int kIterPerSlice = kSliceLen / kKeyPerIter;
 
-    static constexpr int kVecKvSize    = sizeof(uint4) / sizeof(T);
+    static constexpr int kVecKvSize    = sizeof(uint4) / sizeof(Tkv);
     static constexpr int kThreadPerKey = 8;
 
-    using VecKv      = Array<Dtype, kVecKvSize>;
+    using VecKv      = Array<T, kVecKvSize>;
     using VecKvFloat = Array<float, kVecKvSize>;
 
     static constexpr bool kUseBlockIter = true;
 
     using MapKv  = ThreadMapKv<kMaxHeadDim, kKeyPerIter, kVecKvSize, kThreadPerKey, kWarpCount>;
-    using IterKv = turbomind::Iterator<T, MapKv, SliceLen, kStages, kUseBlockIter>;
+    using IterKv = turbomind::Iterator<Tkv, MapKv, SliceLen, kStages, kUseBlockIter>;
 
     static constexpr size_t GetDynamicSmemSize()
     {
@@ -52,7 +60,7 @@ struct DecoderMultiHeadAttentionKernel {
     using PvComputeType = float;
 
     struct SharedStorage {
-        __align__(16) Dtype Q[kHeadPerCta * kMaxHeadDim];
+        __align__(16) T Q[kHeadPerCta * kMaxHeadDim];
         __align__(16) float O[kHeadPerCta * kMaxHeadDim];
         float M[kHeadPerCta];  // max{dot(Q,  K^T  )}
         float L[kHeadPerCta];  // sum{exp(s - S_max)}
@@ -71,21 +79,30 @@ struct DecoderMultiHeadAttentionKernel {
     bool is_gqa_leader_;
 
     int timestep_;
-    T*  k_cache_;  // [S, D]
-    T*  v_cache_;  // [S, D]
-
-    const void** k_cache_ptrs_;
-    const void** v_cache_ptrs_;
-
-    Dtype* smem_Kv_;
-    float* smem_S_;
-    float* smem_P_;
-    Dtype* smem_Q_;
-    float* smem_M_;
-    float* smem_L_;
-    float* smem_O_;
-    float* smem_red_max_;
-    float* smem_red_sum_;
+    Tkv* __restrict__ k_cache_;  // [S, D]
+    Tkv* __restrict__ v_cache_;  // [S, D]
+
+    const void** __restrict__ k_cache_ptrs_;
+    const void** __restrict__ v_cache_ptrs_;
+
+    Tkv* __restrict__ smem_Kv_;
+    float* __restrict__ smem_S_;
+    float* __restrict__ smem_P_;
+    T* __restrict__ smem_Q_;
+    float* __restrict__ smem_M_;
+    float* __restrict__ smem_L_;
+    float* __restrict__ smem_O_;
+    float* __restrict__ smem_red_max_;
+    float* __restrict__ smem_red_sum_;
+
+    // avoid redundant type cast for KV8
+    using KLoadType = std::conditional_t<std::is_same_v<Tkv, int8_t>, float, T>;
+    using VLoadType = std::conditional_t<std::is_same_v<Tkv, int8_t>, float, T>;
+
+    ConvertKvCache<T, Tkv>         conv_k_store_;
+    ConvertKvCache<T, Tkv>         conv_v_store_;
+    ConvertKvCache<Tkv, KLoadType> conv_k_;
+    ConvertKvCache<Tkv, VLoadType> conv_v_;
 
     __device__ bool thread0()
     {
@@ -93,9 +110,13 @@ struct DecoderMultiHeadAttentionKernel {
     }
 
     __device__ DecoderMultiHeadAttentionKernel(const ParamType& params, SharedStorage& smem, uint8_t* dsmem):
-        params_(params)
+        params_(params),
+        conv_k_store_{params_.kv_quant_params[0], params_.kv_quant_params[1]},
+        conv_v_store_{params_.kv_quant_params[2], params_.kv_quant_params[3]},
+        conv_k_{params_.kv_quant_params[0], params_.kv_quant_params[1]},
+        conv_v_{params_.kv_quant_params[2], params_.kv_quant_params[3]}
     {
-        smem_Kv_      = (Dtype*)dsmem;
+        smem_Kv_      = (Tkv*)dsmem;
         smem_S_       = (float*)(smem_Kv_ + IterKv::kSizePerTile * kStages);  // [HeadPerCta * kSliceLen]
         smem_P_       = smem_S_;  // ! reusing only works when S and P has same dtype
         smem_Q_       = smem.Q;
@@ -214,6 +235,9 @@ struct DecoderMultiHeadAttentionKernel {
             Store(&smem_O_[qi * kMaxHeadDim + offset.x], cast<float>(frag_V));
         }
 
+        auto frag_K_store = conv_k_store_(frag_K);
+        auto frag_V_store = conv_v_store_(frag_V);
+
         // store
         if (warp_id_ == 0 && is_gqa_leader_) {
             if constexpr (kUseBlockIter) {
@@ -222,16 +246,16 @@ struct DecoderMultiHeadAttentionKernel {
                 // if (thread0()) {
                 //     printf("%d %d %p %p\n", block_index, block_offset, k_cache_ptrs_, v_cache_ptrs_);
                 // }
-                k_cache_ = (T*)k_cache_ptrs_[block_index] + params_.layer_offset
+                k_cache_ = (Tkv*)k_cache_ptrs_[block_index] + params_.layer_offset
                            + kv_head_idx_ * params_.kv_cache_block_size * kHeadDim;
-                v_cache_ = (T*)v_cache_ptrs_[block_index] + params_.layer_offset
+                v_cache_ = (Tkv*)v_cache_ptrs_[block_index] + params_.layer_offset
                            + kv_head_idx_ * params_.kv_cache_block_size * kHeadDim;
-                Store(&k_cache_[block_offset * kHeadDim + offset.x], frag_K);
-                Store(&v_cache_[block_offset * kHeadDim + offset.x], frag_V);
+                Store(&k_cache_[block_offset * kHeadDim + offset.x], frag_K_store);
+                Store(&v_cache_[block_offset * kHeadDim + offset.x], frag_V_store);
             }
             else {
-                Store(&k_cache_[timestep_ * kHeadDim + offset.x], frag_K);
-                Store(&v_cache_[timestep_ * kHeadDim + offset.x], frag_V);
+                Store(&k_cache_[timestep_ * kHeadDim + offset.x], frag_K_store);
+                Store(&v_cache_[timestep_ * kHeadDim + offset.x], frag_V_store);
             }
         }
     }
@@ -272,7 +296,10 @@ struct DecoderMultiHeadAttentionKernel {
 
     struct State {
         // Double buffering to hide smem/dequant latency
-        VecKv frag_Kv_buf[2][kKvVecPerThread];
+        Array<KLoadType, kVecKvSize> frag_K_buf[2][kKvVecPerThread];
+        Array<VLoadType, kVecKvSize> frag_V_buf[2][kKvVecPerThread];
+
+        Array<Tkv, kVecKvSize> frag_Kv_tmp_buf[2][kKvVecPerThread];
     };
 
     static constexpr int kPrefetchCount = (IterKv::kIterCount + MapKv::kIterS - 1) / MapKv::kIterS;
@@ -306,7 +333,12 @@ struct DecoderMultiHeadAttentionKernel {
         PrefetchKvCache(iter_K);
         CpAsyncWait();
 
-        iter_K.Load(state.frag_Kv_buf[0]);
+        iter_K.Load(state.frag_Kv_tmp_buf[0]);
+        PRAGMA_UNROLL
+        for (int vi = 0; vi < kKvVecPerThread; ++vi) {
+            state.frag_K_buf[0][vi] = conv_k_(state.frag_Kv_tmp_buf[0][vi]);
+        }
+
         iter_K.PrefetchBatch(0, kPrefetchCount);
         if (kKvKeyPerThread == 1) {
             CpAsyncCommit();
@@ -322,11 +354,16 @@ struct DecoderMultiHeadAttentionKernel {
         for (int _it = 0; _it < iter_length; _it += kKeyPerIter) {
             PRAGMA_UNROLL
             for (int si = 0; si < kKvKeyPerThread; ++si) {
+                const int next = (si + 1) % 2;
                 // smem -> rmem for next iter
-                iter_K.Load(state.frag_Kv_buf[(si + 1) % 2]);
+                iter_K.Load(state.frag_Kv_tmp_buf[next]);
+                PRAGMA_UNROLL
+                for (int vi = 0; vi < kKvVecPerThread; ++vi) {
+                    state.frag_K_buf[next][vi] = conv_k_(state.frag_Kv_tmp_buf[next][vi]);
+                }
 
                 // current iter's K fragment
-                auto& frag_K = state.frag_Kv_buf[si % 2];
+                auto& frag_K = state.frag_K_buf[si % 2];
 
                 const int local_offset = offset.y + _it + si * MapKv::kWarpAccessS;
 
@@ -375,7 +412,7 @@ struct DecoderMultiHeadAttentionKernel {
             // handle special case
             if (kKvKeyPerThread == 1) {
                 for (int vi = 0; vi < kKvVecPerThread; ++vi) {
-                    state.frag_Kv_buf[0][vi] = state.frag_Kv_buf[1][vi];
+                    state.frag_K_buf[0][vi] = state.frag_K_buf[1][vi];
                 }
             }
         }
@@ -495,7 +532,12 @@ struct DecoderMultiHeadAttentionKernel {
         PrefetchKvCache(iter_V);
         CpAsyncWait();
 
-        iter_V.Load(state.frag_Kv_buf[0]);
+        iter_V.Load(state.frag_Kv_tmp_buf[0]);
+        PRAGMA_UNROLL
+        for (int vi = 0; vi < kKvVecPerThread; ++vi) {
+            state.frag_V_buf[0][vi] = conv_v_(state.frag_Kv_tmp_buf[0][vi]);
+        }
+
         iter_V.PrefetchBatch(0, kPrefetchCount);
         if (kKvKeyPerThread == 1) {
             CpAsyncCommit();
@@ -508,8 +550,13 @@ struct DecoderMultiHeadAttentionKernel {
         for (int _it = 0; _it < iter_length; _it += kKeyPerIter) {
             PRAGMA_UNROLL
             for (int si = 0; si < kKvKeyPerThread; ++si) {
+                const int next = (si + 1) % 2;
                 // Load value cache for next warp iter
-                iter_V.Load(state.frag_Kv_buf[(si + 1) % 2]);
+                iter_V.Load(state.frag_Kv_tmp_buf[next]);
+                PRAGMA_UNROLL
+                for (int vi = 0; vi < kKvVecPerThread; ++vi) {
+                    state.frag_V_buf[next][vi] = conv_v_(state.frag_Kv_tmp_buf[next][vi]);
+                }
 
                 // Load Pr for next warp iter
                 // PRAGMA_UNROLL
@@ -517,7 +564,7 @@ struct DecoderMultiHeadAttentionKernel {
                 //     frag_Pr_buf[(si + 1) % 2][qi] = smem_P_[qi * kSliceLen + (ti + MapKv::kWarpAccessS)];
                 // }
 
-                auto& frag_V = state.frag_Kv_buf[si % 2];
+                auto& frag_V = state.frag_V_buf[si % 2];
                 // auto& frag_P = frag_Pr_buf[si % 2];
 
                 const int local_offset = offset.y + _it + si * MapKv::kWarpAccessS;
@@ -556,7 +603,7 @@ struct DecoderMultiHeadAttentionKernel {
             // handle special case
             if (kKvKeyPerThread == 1) {
                 for (int vi = 0; vi < kKvVecPerThread; ++vi) {
-                    state.frag_Kv_buf[0][vi] = state.frag_Kv_buf[1][vi];
+                    state.frag_V_buf[0][vi] = state.frag_V_buf[1][vi];
                 }
                 // PRAGMA_UNROLL
                 // for (int qi = 0; qi < kHeadPerCta; ++qi) {
@@ -590,14 +637,14 @@ struct DecoderMultiHeadAttentionKernel {
                 PRAGMA_UNROLL
                 for (int vi = 0; vi < kKvVecPerThread; ++vi) {
                     if (offset.y == gi) {
-                        // ! 2-way bank conflict
+                        // bank conflict
                         auto& smem_O = (VecKvFloat&)smem_O_[qi * kMaxHeadDim + offset.x + vi * MapKv::kDeltaC];
                         using namespace ops;
                         auto tmp_O = smem_O;
                         if (offset.y == 0) {
                             tmp_O = tmp_O * exp_M_diff[qi];
                         }
-                        // ! 2-way bank conflict
+                        // bank conflict
                         smem_O = tmp_O + frag_O[qi][vi];
                     }
                 }
@@ -639,7 +686,7 @@ struct DecoderMultiHeadAttentionKernel {
     {
         if constexpr (0) {
             for (int i = threadIdx.x; i < kStages * IterKv::kSizePerTile; i += blockDim.x) {
-                smem_Kv_[i] = Dtype(0);
+                smem_Kv_[i] = Tkv(0);
             }
             __syncthreads();
         }
@@ -691,7 +738,7 @@ struct DecoderMultiHeadAttentionKernel {
             VecQFloat frag_O = (VecQFloat&)smem_O_[qi * kMaxHeadDim + di] * scale;
 
             Store(&params_.out[batch_idx_ * params_.num_heads * kHeadDim + (head_idx_ + qi) * kHeadDim + di],
-                  cast<Dtype>(frag_O));
+                  cast<T>(frag_O));
         }
     }
 };
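
The `State` struct above now keeps raw `Tkv` fragments (`frag_Kv_tmp_buf`) next to converted `KLoadType`/`VLoadType` fragments, and every `iter_K.Load`/`iter_V.Load` is followed immediately by `conv_k_`/`conv_v_`, so the QK and PV math always runs on float/half operands and the int8 conversion is paid once per fragment. A stripped-down host-side sketch of that load-then-convert double buffer, assuming placeholder `Load`/`Convert`/`Compute` callables rather than the iterator API:

    #include <array>
    #include <cstdint>

    // Double-buffered "load raw, convert once" loop, assuming placeholder Load /
    // Convert / Compute callables instead of the real iterator and ConvertKvCache.
    template<class LoadFn, class ConvertFn, class ComputeFn>
    void Pipeline(int iters, LoadFn&& Load, ConvertFn&& Convert, ComputeFn&& Compute)
    {
        std::array<std::array<std::int8_t, 16>, 2> raw{};   // Tkv fragments (storage type)
        std::array<std::array<float, 16>, 2>       conv{};  // KLoadType fragments (compute type)

        raw[0]  = Load(0);
        conv[0] = Convert(raw[0]);           // dequantize once, up front

        for (int it = 0; it < iters; ++it) {
            const int cur  = it % 2;
            const int next = (it + 1) % 2;
            raw[next]  = Load(it + 1);       // prefetch the next fragment in storage type
            conv[next] = Convert(raw[next]); // convert it while the current one is consumed
            Compute(conv[cur]);              // the math only ever sees converted operands
        }
    }

    int main()
    {
        auto load    = [](int) { return std::array<std::int8_t, 16>{}; };
        auto convert = [](const std::array<std::int8_t, 16>& r) {
            std::array<float, 16> f{};
            for (int i = 0; i < 16; ++i) {
                f[i] = float(r[i]);  // stand-in for ConvertKvCache<int8_t, float>
            }
            return f;
        };
        auto compute = [](const std::array<float, 16>&) {};
        Pipeline(8, load, convert, compute);
    }
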
diff --git a/src/turbomind/kernels/decoder_multihead_attention/iterator.h b/src/turbomind/kernels/decoder_multihead_attention/iterator.h
index deb2939488..0e5158283f 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/iterator.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/iterator.h
@@ -70,8 +70,8 @@ struct Iterator {
 
     int head_idx_;
 
-    const T* src_;
-    T*       smem_;
+    const T* __restrict__ src_;
+    T* __restrict__ smem_;
 
     int smem_read_offset_;
 
@@ -250,11 +250,11 @@ struct Iterator {
     }
 #endif
 
-    static __device__ void CpAsync(T* dst, const T* src, bool mask)
+    static __device__ void CpAsync(T* __restrict__ dst, const T* __restrict__ src, bool mask)
     {
         const int     smem_int_ptr = cast_smem_ptr_to_uint(dst);
         constexpr int cp_size      = sizeof(AccessType);
-        static_assert(cp_size == 16);
+        // static_assert(cp_size == 16);
         asm volatile("{\n"
                      "  .reg .pred p;\n"
                      "  setp.ne.b32 p, %0, 0;\n"
@@ -265,7 +265,7 @@ struct Iterator {
                      "n"(cp_size));
     }
 
-    static __device__ void Copy(T* dst, const T* src, bool mask)
+    static __device__ void Copy(T* __restrict__ dst, const T* __restrict__ src, bool mask)
     {
         if (mask) {
             Ldg(*(AccessType*)dst, src);
diff --git a/src/turbomind/kernels/decoder_multihead_attention/kv_cache.cu b/src/turbomind/kernels/decoder_multihead_attention/kv_cache.cu
index 18a0754260..64fcb26ce4 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/kv_cache.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/kv_cache.cu
@@ -1,5 +1,7 @@
 #include "../gemm_s_f16/common.h"
 // #include "cute/tensor.hpp"
+#include "src/turbomind/kernels/decoder_multihead_attention/array_ops.h"
+#include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/dbg.h"
 #include <cuda_fp16.h>
 #include <type_traits>
@@ -8,9 +10,14 @@ namespace turbomind {
 
 // [S/x, H, x, D] <-> [S/y, H, y, D]
 
-template<typename T, typename SrcBlockLen, typename DstBlockLen, typename HeadDim>
-__inline__ __device__ void ConvertBlockSize(const T** __restrict__ src_block_ptrs,
-                                            T** __restrict__ dst_block_ptrs,
+template<typename Tin,
+         typename Tout,
+         typename SrcBlockLen,
+         typename DstBlockLen,
+         typename HeadDim,
+         typename Transform = ConvertKvCache<Tin, Tout>>
+__inline__ __device__ void ConvertBlockSize(const Tin** __restrict__ src_block_ptrs,
+                                            Tout** __restrict__ dst_block_ptrs,
                                             const int* __restrict__ src_cu_block_cnts,
                                             const int* __restrict__ dst_cu_block_cnts,
                                             const int* __restrict__ seq_lens,
@@ -18,16 +25,18 @@ __inline__ __device__ void ConvertBlockSize(const T** __restrict__ src_block_ptr
                                             int         dst_offset,
                                             SrcBlockLen src_block_len,
                                             DstBlockLen dst_block_len,
-                                            HeadDim     head_dim)
+                                            HeadDim     head_dim,
+                                            Transform   transform = {1.f, 0.f})
 {
-    constexpr int kVecSize = sizeof(uint4) / sizeof(T);
+    constexpr int kVecSize = sizeof(uint4) / std::max(sizeof(Tin), sizeof(Tout));
 
     const int hi = blockIdx.y;
     const int bi = blockIdx.z;
 
     const int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    const int di  = idx * kVecSize % head_dim;
-    const int si  = idx * kVecSize / head_dim;
+    /// TODO: use cutlass fast div/mod
+    const int di = idx * kVecSize % head_dim;
+    const int si = idx * kVecSize / head_dim;
 
     if (si >= seq_lens[bi]) {
         return;
@@ -43,12 +52,18 @@ __inline__ __device__ void ConvertBlockSize(const T** __restrict__ src_block_ptr
 
     // printf("%d %d\n", src_block_index, dst_block_index);
 
-    const T* __restrict__ src_block = src_block_ptrs[src_block_index];
-    T* __restrict__ dst_block       = dst_block_ptrs[dst_block_index];
+    const Tin* __restrict__ src_block = src_block_ptrs[src_block_index];
+    Tout* __restrict__ dst_block      = dst_block_ptrs[dst_block_index];
+
+    // uint4 data = __ldg(reinterpret_cast<const uint4*>(src_block + src_block_offset));
 
-    uint4 data = __ldg(reinterpret_cast<const uint4*>(src_block + src_block_offset));
+    Array<Tin, kVecSize> src_vec;
+    Ldg(src_vec, src_block + src_block_offset);
 
-    *reinterpret_cast<uint4*>(dst_block + dst_block_offset) = data;
+    Array<Tout, kVecSize> dst_vec = transform(src_vec);
+    Store(dst_block + dst_block_offset, dst_vec);
+
+    // *reinterpret_cast<uint4*>(dst_block + dst_block_offset) = data;
 }
 
 template<typename T>
@@ -276,11 +291,10 @@ __global__ void KvCacheBlocksToLinearKernel(const T**   src_k_block_ptrs,
                      head_dim);
 }
 
-template<typename T>
-void ConvertKvCacheBlocksToLinear(const T**    src_k_block_ptrs,
-                                  const T**    src_v_block_ptrs,
-                                  T**          dst_k_ptrs,
-                                  T**          dst_v_ptrs,
+void ConvertKvCacheBlocksToLinear(const void** src_k_block_ptrs,
+                                  const void** src_v_block_ptrs,
+                                  void**       dst_k_ptrs,
+                                  void**       dst_v_ptrs,
                                   const int*   src_cu_block_cnts,
                                   const int*   seq_lens,
                                   int          src_offset,
@@ -289,67 +303,178 @@ void ConvertKvCacheBlocksToLinear(const T**    src_k_block_ptrs,
                                   int          head_num,
                                   int          head_dim,
                                   int          batch_size,
+                                  int          elem_bits,
                                   cudaStream_t st)
 {
-    constexpr int kVecSize = sizeof(uint4) / sizeof(T);
+    auto fn = [&](auto value) {
+        using T = decltype(value);
+
+        constexpr int kVecSize = sizeof(uint4) / sizeof(T);
+        constexpr int kThreads = 256;
+
+        const dim3 blocks((dst_block_len * head_dim / kVecSize + kThreads - 1) / kThreads, head_num, batch_size);
+        const auto smem_sz = sizeof(int) * batch_size;
+
+        KvCacheBlocksToLinearKernel<<<blocks, kThreads, smem_sz, st>>>((const T**)src_k_block_ptrs,
+                                                                       (const T**)src_v_block_ptrs,
+                                                                       (T**)dst_k_ptrs,
+                                                                       (T**)dst_v_ptrs,
+                                                                       src_cu_block_cnts,
+                                                                       seq_lens,
+                                                                       src_offset,
+                                                                       src_block_len,
+                                                                       dst_block_len,
+                                                                       head_num,
+                                                                       head_dim,
+                                                                       batch_size);
+    };
 
-    constexpr int threads = 256;
-    const dim3    blocks((dst_block_len * head_dim / kVecSize + threads - 1) / threads, head_num, batch_size);
+    switch (elem_bits) {
+        case 8:
+            fn(uint8_t{});
+            break;
+        case 16:
+            fn(uint16_t{});
+            break;
+        case 32:
+            fn(uint32_t{});
+            break;
+        default:
+            fprintf(stderr, "unsupported elem bits: %d\n", elem_bits);
+    }
+}
 
-    const auto smem_sz = sizeof(int) * batch_size;
+template<typename Tin,
+         typename Tout,
+         typename SrcBlockLen,
+         typename DstBlockLen,
+         typename HeadDim,
+         typename TransformK,
+         typename TransformV>
+__global__ void KvCacheBlocksToLinearKernel2(const Tin** src_k_block_ptrs,
+                                             const Tin** src_v_block_ptrs,
+                                             Tout**      dst_k_ptrs,
+                                             Tout**      dst_v_ptrs,
+                                             const int*  src_cu_block_cnts,
+                                             const int*  seq_lens,
+                                             int         src_offset,
+                                             SrcBlockLen src_block_len,
+                                             DstBlockLen dst_block_len,
+                                             int         head_num,
+                                             HeadDim     head_dim,
+                                             int         batch_size,
+                                             TransformK  transform_k,
+                                             TransformV  transform_v)
+{
+    extern __shared__ int dst_cu_block_cnts[];
 
-    // dbg(src_block_len, dst_block_len, head_num, head_dim, batch_size);
+    for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
+        dst_cu_block_cnts[i] = i;
+    }
 
-    auto fn = [&](auto head_dim) {
-        KvCacheBlocksToLinearKernel<<<blocks, threads, smem_sz, st>>>(src_k_block_ptrs,
-                                                                      src_v_block_ptrs,
-                                                                      dst_k_ptrs,
-                                                                      dst_v_ptrs,
-                                                                      src_cu_block_cnts,
-                                                                      seq_lens,
-                                                                      src_offset,
-                                                                      src_block_len,
-                                                                      dst_block_len,
-                                                                      head_num,
-                                                                      head_dim,
-                                                                      batch_size);
+    __syncthreads();
+
+    ConvertBlockSize(src_k_block_ptrs,
+                     dst_k_ptrs,
+                     src_cu_block_cnts,
+                     dst_cu_block_cnts,
+                     seq_lens,
+                     src_offset,
+                     0,
+                     src_block_len,
+                     dst_block_len,
+                     head_dim,
+                     transform_k);
+
+    ConvertBlockSize(src_v_block_ptrs,
+                     dst_v_ptrs,
+                     src_cu_block_cnts,
+                     dst_cu_block_cnts,
+                     seq_lens,
+                     src_offset,
+                     0,
+                     src_block_len,
+                     dst_block_len,
+                     head_dim,
+                     transform_v);
+}
+
+template<typename T>
+void ConvertKvCacheBlocksToLinear2(const void** src_k_block_ptrs,
+                                   const void** src_v_block_ptrs,
+                                   T**          dst_k_ptrs,
+                                   T**          dst_v_ptrs,
+                                   const int*   src_cu_block_cnts,
+                                   const int*   seq_lens,
+                                   int          src_offset,
+                                   int          src_block_len,
+                                   int          dst_block_len,
+                                   int          head_num,
+                                   int          head_dim,
+                                   int          batch_size,
+                                   int          quant_policy,
+                                   const float* kv_params,
+                                   cudaStream_t st)
+{
+    auto fn = [&](auto tin) {
+        using Tin = decltype(tin);
+
+        constexpr int kVecSize = sizeof(uint4) / sizeof(T);
+        constexpr int kThreads = 256;
+
+        const dim3 blocks((dst_block_len * head_dim / kVecSize + kThreads - 1) / kThreads, head_num, batch_size);
+        const auto smem_sz = sizeof(int) * batch_size;
+
+        KvCacheBlocksToLinearKernel2<<<blocks, kThreads, smem_sz, st>>>(
+            (const Tin**)src_k_block_ptrs,
+            (const Tin**)src_v_block_ptrs,
+            (T**)dst_k_ptrs,
+            (T**)dst_v_ptrs,
+            src_cu_block_cnts,
+            seq_lens,
+            src_offset,
+            src_block_len,
+            dst_block_len,
+            head_num,
+            head_dim,
+            batch_size,
+            ConvertKvCache<Tin, T>{kv_params[0], kv_params[1]},
+            ConvertKvCache<Tin, T>{kv_params[2], kv_params[3]});
     };
 
-    switch (head_dim) {
-        case 128:
-            fn(std::integral_constant<int, 128>{});
-            break;
-        default:
-            fn(head_dim);
-    }
+    (quant_policy & QuantPolicy::kCacheKVInt8) ? fn(int8_t{}) : fn(T{});
 }
 
-template void ConvertKvCacheBlocksToLinear(const half** src_k_block_ptrs,
-                                           const half** src_v_block_ptrs,
-                                           half**       dst_k_ptrs,
-                                           half**       dst_v_ptrs,
-                                           const int*   src_cu_block_cnts,
-                                           const int*   seq_lens,
-                                           int          src_offset,
-                                           int          src_block_len,
-                                           int          dst_block_len,
-                                           int          head_num,
-                                           int          head_dim,
-                                           int          batch_size,
-                                           cudaStream_t st);
-
-template void ConvertKvCacheBlocksToLinear(const float** src_k_block_ptrs,
-                                           const float** src_v_block_ptrs,
-                                           float**       dst_k_ptrs,
-                                           float**       dst_v_ptrs,
-                                           const int*    src_cu_block_cnts,
-                                           const int*    seq_lens,
-                                           int           src_offset,
-                                           int           src_block_len,
-                                           int           dst_block_len,
-                                           int           head_num,
-                                           int           head_dim,
-                                           int           batch_size,
-                                           cudaStream_t  st);
+template void ConvertKvCacheBlocksToLinear2(const void** src_k_block_ptrs,
+                                            const void** src_v_block_ptrs,
+                                            float**      dst_k_ptrs,
+                                            float**      dst_v_ptrs,
+                                            const int*   src_cu_block_cnts,
+                                            const int*   seq_lens,
+                                            int          src_offset,
+                                            int          src_block_len,
+                                            int          dst_block_len,
+                                            int          head_num,
+                                            int          head_dim,
+                                            int          batch_size,
+                                            int          quant_policy,
+                                            const float* kv_params,
+                                            cudaStream_t st);
+
+template void ConvertKvCacheBlocksToLinear2(const void** src_k_block_ptrs,
+                                            const void** src_v_block_ptrs,
+                                            half**       dst_k_ptrs,
+                                            half**       dst_v_ptrs,
+                                            const int*   src_cu_block_cnts,
+                                            const int*   seq_lens,
+                                            int          src_offset,
+                                            int          src_block_len,
+                                            int          dst_block_len,
+                                            int          head_num,
+                                            int          head_dim,
+                                            int          batch_size,
+                                            int          quant_policy,
+                                            const float* kv_params,
+                                            cudaStream_t st);
 
 }  // namespace turbomind
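
`ConvertBlockSize` now converts element types while it re-blocks the cache: the vector width is 16 bytes divided by the larger of the two element sizes, so the int8-to-half path loads an 8-byte `Array<int8_t, 8>`, runs it through the `ConvertKvCache` transform, and stores a 16-byte `Array<half, 8>`. A small host-side check of that width rule (`uint16_t` stands in for `half` so the snippet compiles without CUDA headers):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // Sketch of the vector-width rule in ConvertBlockSize: one 16-byte access on the
    // wider side, a narrower access on the int8 side.
    template<class Tin, class Tout>
    constexpr int KvVecSize()
    {
        constexpr std::size_t kAccessBytes = 16;  // sizeof(uint4)
        return int(kAccessBytes / std::max(sizeof(Tin), sizeof(Tout)));
    }

    int main()
    {
        static_assert(KvVecSize<std::int8_t, std::uint16_t>() == 8, "int8 -> half: 8 elems, 8 B load / 16 B store");
        static_assert(KvVecSize<std::uint16_t, std::uint16_t>() == 8, "half -> half: 8 elems, 16 B both ways");
        static_assert(KvVecSize<std::int8_t, float>() == 4, "int8 -> float: 4 elems, 4 B load / 16 B store");
        std::printf("vector widths: %d %d %d\n",
                    KvVecSize<std::int8_t, std::uint16_t>(),
                    KvVecSize<std::uint16_t, std::uint16_t>(),
                    KvVecSize<std::int8_t, float>());
    }
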
diff --git a/src/turbomind/kernels/decoder_multihead_attention/kv_cache.h b/src/turbomind/kernels/decoder_multihead_attention/kv_cache.h
index 798851fded..d84c991ac3 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/kv_cache.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/kv_cache.h
@@ -30,11 +30,10 @@ void ConvertBlocksToLinear(const T**    src_block_ptrs,
                            int          batch_size,
                            cudaStream_t st);
 
-template<typename T>
-void ConvertKvCacheBlocksToLinear(const T**    src_k_block_ptrs,
-                                  const T**    src_v_block_ptrs,
-                                  T**          dst_k_ptrs,
-                                  T**          dst_v_ptrs,
+void ConvertKvCacheBlocksToLinear(const void** src_k_block_ptrs,
+                                  const void** src_v_block_ptrs,
+                                  void**       dst_k_ptrs,
+                                  void**       dst_v_ptrs,
                                   const int*   src_cu_block_cnts,
                                   const int*   seq_lens,
                                   int          src_offset,
@@ -43,6 +42,24 @@ void ConvertKvCacheBlocksToLinear(const T**    src_k_block_ptrs,
                                   int          head_num,
                                   int          head_dim,
                                   int          batch_size,
+                                  int          elem_bits,
                                   cudaStream_t st);
 
+template<typename T>
+void ConvertKvCacheBlocksToLinear2(const void** src_k_block_ptrs,
+                                   const void** src_v_block_ptrs,
+                                   T**          dst_k_ptrs,
+                                   T**          dst_v_ptrs,
+                                   const int*   src_cu_block_cnts,
+                                   const int*   seq_lens,
+                                   int          src_offset,
+                                   int          src_block_len,
+                                   int          dst_block_len,
+                                   int          head_num,
+                                   int          head_dim,
+                                   int          batch_size,
+                                   int          quant_policy,
+                                   const float* kv_params,
+                                   cudaStream_t st);
+
 }  // namespace turbomind
\ No newline at end of file
diff --git a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
index 67f2984f18..744f2fd342 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
@@ -89,10 +89,12 @@ void TestBlocks(thrust::universal_vector<half>&  linear,          // linear data
                               0);
     }
     cudaDeviceSynchronize();
-    std::cout << ">>> Compare\n";
-    Compare(_linear.data().get(), linear.data().get(), head_dim, head_dim, batch_size * head_num * seq_len);
-    std::cout << "<<< Compare\n";
-    // std::exit(0);
+
+    if (0) {
+        std::cout << ">>> Compare\n";
+        Compare(_linear.data().get(), linear.data().get(), head_dim, head_dim, batch_size * head_num * seq_len);
+        std::cout << "<<< Compare\n";
+    }
 
     _blocks.swap(blocks);
     _ptrs.swap(ptrs);
@@ -101,22 +103,18 @@ void TestBlocks(thrust::universal_vector<half>&  linear,          // linear data
 
 int main(int argc, char* argv[])
 {
+
     DecoderMultiHeadAttentionParams<half> params{};
 
-    // constexpr int kHeadNum = 108 * 4;
-    constexpr int kHeadNum     = 32;
+    constexpr int kHeadNum = 108;
+    // constexpr int kHeadNum     = 32;
     constexpr int kHeadDim     = 128;
-    constexpr int kBatchSize   = 1;
-    constexpr int kContextLen  = 511;
+    constexpr int kBatchSize   = 64;
+    constexpr int kContextLen  = 2047;
     constexpr int kSequenceLen = kContextLen + 1;
     constexpr int kBlockSz     = 128;
     constexpr int kTestIter    = 1;
 
-    // constexpr int kHeadNum     = 3;
-    // constexpr int kHeadDim     = 4;
-    // constexpr int kContextLen  = 7;
-    // constexpr int kSequenceLen = kContextLen + 1;
-
     RNG rng{};
 
     thrust::universal_vector<half>  output(kBatchSize * kHeadNum * kHeadDim);
diff --git a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
index a71d4e7b7c..acc98d942f 100644
--- a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
@@ -195,8 +195,8 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
     // [2, L, H, s, D]
     const size_t layer_offset = layer_id * local_kv_head_num_ * kv_cache_block_len_ * size_per_head_;
 
-    auto k_cache_ptrs = output_tensors->getPtr<T*>("key_cache");
-    auto v_cache_ptrs = output_tensors->getPtr<T*>("value_cache");
+    auto k_cache_ptrs = output_tensors->getPtr<void*>("key_cache");
+    auto v_cache_ptrs = output_tensors->getPtr<void*>("value_cache");
 
     auto tmp_k_ptrs = output_tensors->getPtr<T*>("tmp_k");
     auto tmp_v_ptrs = output_tensors->getPtr<T*>("tmp_v");
@@ -225,19 +225,23 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
                         stream_);
     sync_check_cuda_error();
 
-    ConvertKvCacheBlocksToLinear((const T**)k_cache_ptrs,
-                                 (const T**)v_cache_ptrs,
-                                 tmp_k_ptrs,
-                                 tmp_v_ptrs,
-                                 cu_block_counts,
-                                 context_length,
-                                 layer_offset,
-                                 kv_cache_block_len_,
-                                 max_seq_len,
-                                 local_kv_head_num_,
-                                 size_per_head_,
-                                 batch_size,
-                                 stream_);
+    const int kv_cache_elem_bits = quant_policy_ & QuantPolicy::kCacheKVInt8 ? 8 : sizeof(T) * 8;
+
+    ConvertKvCacheBlocksToLinear2((const void**)k_cache_ptrs,
+                                  (const void**)v_cache_ptrs,
+                                  (T**)tmp_k_ptrs,
+                                  (T**)tmp_v_ptrs,
+                                  cu_block_counts,
+                                  context_length,
+                                  layer_offset,
+                                  kv_cache_block_len_,
+                                  max_seq_len,
+                                  local_kv_head_num_,
+                                  size_per_head_,
+                                  batch_size,
+                                  quant_policy_,
+                                  weights->past_kv_scale.data(),
+                                  stream_);
     sync_check_cuda_error();
 
     // dbg(kv_cache_block_len_, max_seq_len, local_kv_head_num_, size_per_head_, batch_size);
@@ -378,7 +382,7 @@ void LlamaContextAttentionLayer<T>::unfusedMultiHeadAttention(T**          key_c
                            local_head_num_,
                            head_n_rep_,
                            stream_,
-                           quant,
+                           0,  // dequant handled in block->linear conversion
                            kv_scale);
     sync_check_cuda_error();
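
With this change the block-to-linear conversion performs the dequantization itself, so `unfusedMultiHeadAttention` is invoked with a quant flag of 0, and the width of a cached element is derived from the policy: 8 bits under `kCacheKVInt8`, otherwise `sizeof(T) * 8`. A one-function sketch of that width rule (the policy bit value is a placeholder):

    #include <cstdio>

    constexpr int kCacheKVInt8 = 0x4;  // placeholder for QuantPolicy::kCacheKVInt8

    // Sketch of the element-width rule used for the cached KV blocks above.
    template<class T>
    constexpr int KvCacheElemBits(int quant_policy)
    {
        return (quant_policy & kCacheKVInt8) ? 8 : int(sizeof(T) * 8);
    }

    int main()
    {
        // fp16 cache: 16 bits without int8, 8 bits with it (uint16_t stands in for half)
        std::printf("%d %d\n",
                    KvCacheElemBits<unsigned short>(0),
                    KvCacheElemBits<unsigned short>(kCacheKVInt8));
    }
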
 
diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
index 0f65eb1e12..275c89ddff 100644
--- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
+++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
@@ -302,7 +302,7 @@ void LlamaDecoderLayerWeight<T>::loadModel(std::string dir_path, FtCudaDataType
         self_attn_weights.past_kv_scale = loadArrayFromBin({4}, scale_path);
     }
     else {
-        self_attn_weights.past_kv_scale = {};
+        self_attn_weights.past_kv_scale = {1.f, 0.f, 1.f, 0.f};
     }
 }
 
diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
index ad4d4fafdc..8fd3eb3ede 100644
--- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
@@ -134,6 +134,9 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
     params.rotary_embedding_dim  = size_per_head_;
     params.rotary_embedding_base = 10000.f;
 
+    params.quant_policy = quant_policy_;
+    std::copy(weights->past_kv_scale.begin(), weights->past_kv_scale.end(), std::begin(params.kv_quant_params));
+
     DispatchDecoderMultiheadAttention<T>(params);
 
     linear_.forward(hidden_features_data, context_buf_, batch_size, weights->output);
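
The `std::copy` into `params.kv_quant_params` reads exactly four floats, which is why the `LlamaDecoderLayerWeight` change above replaces the empty default for `past_kv_scale` with {1.f, 0.f, 1.f, 0.f}: the copy always has four source values, and scale 1 / zero 0 makes the dequantization an identity when KV int8 is disabled. A small host-side check of both points (plain C++, with `std::vector` standing in for the weight container):

    #include <algorithm>
    #include <cassert>
    #include <iterator>
    #include <vector>

    int main()
    {
        std::vector<float> past_kv_scale{1.f, 0.f, 1.f, 0.f};  // the new default when no scale file exists
        float kv_quant_params[4] = {};

        assert(past_kv_scale.size() == 4);  // the copy below would overrun with the old empty default
        std::copy(past_kv_scale.begin(), past_kv_scale.end(), std::begin(kv_quant_params));

        // scale = 1, zero = 0  =>  x * scale + zero == x: identity dequantization when KV int8 is off
        const float xs[] = {-3.f, 0.f, 7.5f};
        for (float x : xs) {
            assert(x * kv_quant_params[0] + kv_quant_params[1] == x);
            assert(x * kv_quant_params[2] + kv_quant_params[3] == x);
        }
        return 0;
    }
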
diff --git a/src/turbomind/models/llama/llama_kernels.cu b/src/turbomind/models/llama/llama_kernels.cu
index fe5dc2f44d..1f334b40b3 100644
--- a/src/turbomind/models/llama/llama_kernels.cu
+++ b/src/turbomind/models/llama/llama_kernels.cu
@@ -1,11 +1,15 @@
 // Copyright (c) OpenMMLab. All rights reserved.
 
 #include "src/turbomind/kernels/decoder_masked_multihead_attention_utils.h"
+#include "src/turbomind/kernels/decoder_multihead_attention/array_ops.h"
+#include "src/turbomind/kernels/gemm_s_f16/common.h"
 #include "src/turbomind/kernels/reduce_kernel_utils.cuh"
 #include "src/turbomind/macro.h"
 #include "src/turbomind/models/llama/llama_kernels.h"
 #include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/cuda_type_utils.cuh"
+#include "src/turbomind/utils/logger.h"
+#include <type_traits>
 
 namespace turbomind {
 
@@ -199,62 +203,97 @@ void invokeCreateCausalMasks(
 template void invokeCreateCausalMasks(float* mask, const int*, const int*, int, int, int, cudaStream_t);
 template void invokeCreateCausalMasks(half* mask, const int*, const int*, int, int, int, cudaStream_t);
 
-template<typename T>
-__global__ void extend_kv_cache(T**          k_dst_ptrs,
-                                T**          v_dst_ptrs,
-                                const T*     k_src,
-                                const T*     v_src,
-                                const int*   cu_block_counts,
-                                const int*   query_length,
-                                const int*   context_length,
-                                const int    block_length,
-                                const size_t dst_layer_offset,
-                                const int    max_q_len,
-                                const int    head_num,
-                                const int    head_dim)
-{
-    const int batch_id     = blockIdx.y;
-    const int query_len    = query_length[batch_id];
-    const int history_len  = context_length[batch_id] - query_len;
-    const int cu_block_cnt = cu_block_counts[batch_id];
-
-    const int     head_id = blockIdx.z;
-    constexpr int X_ELEMS = (sizeof(T) == 4) ? 4 : 8;
-
-    const int size_per_head_div_x = head_dim / X_ELEMS;
-    const int idx                 = blockIdx.x * blockDim.x + threadIdx.x;
-    const int head_size_id        = idx % size_per_head_div_x;
-    const int seq_len_id          = idx / size_per_head_div_x;
-
-    const int cache_block_index  = (seq_len_id + history_len) / block_length;
-    const int cache_block_offset = (seq_len_id + history_len) % block_length;
-
-    // x dim is now handled by uint4 type
-    const auto k_val_src = reinterpret_cast<const uint4*>(k_src);
-    const auto v_val_src = reinterpret_cast<const uint4*>(v_src);
-
-    const auto k_val_dst = (uint4*)((k_dst_ptrs + cu_block_cnt)[cache_block_index] + dst_layer_offset);
-    const auto v_val_dst = (uint4*)((v_dst_ptrs + cu_block_cnt)[cache_block_index] + dst_layer_offset);
-
-    if (seq_len_id < query_len) {
-        // [B, H, s, D/x] -> [H, S[t:t+s], D/x]
-        const int64_t dst_idx = head_id * block_length * size_per_head_div_x +  // H
-                                cache_block_offset * size_per_head_div_x +      // s + offset
-                                head_size_id;                                   // D/x
-
-        const int64_t src_idx = batch_id * head_num * max_q_len * size_per_head_div_x +  // B
-                                head_id * max_q_len * size_per_head_div_x +              // H
-                                seq_len_id * size_per_head_div_x +                       // s
-                                head_size_id;                                            // D/x
-
-        k_val_dst[dst_idx] = k_val_src[src_idx];
-        v_val_dst[dst_idx] = v_val_src[src_idx];
+template<typename Ti, typename To>
+struct ExtendKvCache {
+
+    static constexpr int MaxElemSize = std::max(sizeof(Ti), sizeof(To));
+    static constexpr int X_ELEMS     = 16 / MaxElemSize;
+
+    using Vi = Array<Ti, X_ELEMS>;
+    using Vo = Array<To, X_ELEMS>;
+
+    using Transform = ConvertKvCache<Ti, To>;
+
+    struct Params {
+        To**       k_dst_ptrs;
+        To**       v_dst_ptrs;
+        const Ti*  k_src;
+        const Ti*  v_src;
+        const int* cu_block_counts;
+        const int* query_length;
+        const int* context_length;
+        int        block_length;
+        size_t     dst_layer_offset;
+        int        max_q_len;
+        int        head_num;
+        int        head_dim;
+        Transform  transform_k;
+        Transform  transform_v;
+    };
+
+    __device__ void operator()(const Params& params) const
+    {
+        const int batch_id = blockIdx.y;
+
+        const int query_len    = params.query_length[batch_id];
+        const int history_len  = params.context_length[batch_id] - query_len;
+        const int cu_block_cnt = params.cu_block_counts[batch_id];
+
+        const int head_id = blockIdx.z;
+
+        const int size_per_head_div_x = params.head_dim / X_ELEMS;
+        const int idx                 = blockIdx.x * blockDim.x + threadIdx.x;
+        const int head_size_id        = idx % size_per_head_div_x;
+        const int seq_len_id          = idx / size_per_head_div_x;
+
+        const int cache_block_index  = (seq_len_id + history_len) / params.block_length;
+        const int cache_block_offset = (seq_len_id + history_len) % params.block_length;
+
+        const auto k_val_src = params.k_src;
+        const auto v_val_src = params.v_src;
+
+        const auto k_val_dst = (params.k_dst_ptrs + cu_block_cnt)[cache_block_index] + params.dst_layer_offset;
+        const auto v_val_dst = (params.v_dst_ptrs + cu_block_cnt)[cache_block_index] + params.dst_layer_offset;
+
+        if (seq_len_id < query_len) {
+            // [B, H, s, D/x] -> [H, S[t:t+s], D/x]
+            const int64_t dst_idx = head_id * params.block_length * size_per_head_div_x +  // H
+                                    cache_block_offset * size_per_head_div_x +             // s + offset
+                                    head_size_id;                                          // D/x
+
+            const int64_t src_idx = batch_id * params.head_num * params.max_q_len * size_per_head_div_x +  // B
+                                    head_id * params.max_q_len * size_per_head_div_x +                     // H
+                                    seq_len_id * size_per_head_div_x +                                     // s
+                                    head_size_id;                                                          // D/x
+
+            Vi k_vi;
+            Vi v_vi;
+
+            Ldg(k_vi, k_val_src + src_idx * X_ELEMS);
+            Ldg(v_vi, v_val_src + src_idx * X_ELEMS);
+
+            Vo k_vo = params.transform_k(k_vi);
+            Vo v_vo = params.transform_v(v_vi);
+
+            Store(k_val_dst + dst_idx * X_ELEMS, k_vo);
+            Store(v_val_dst + dst_idx * X_ELEMS, v_vo);
+        }
     }
-}
+};
+
+namespace {
+
+template<class Kernel, class Params>
+__global__ void KernelWrapper(Params params)
+{
+    Kernel{}(params);
+}
+
+}  // namespace
 
 template<typename T>
-void invokeExtendKVCache(T**          k_dst_ptrs,
-                         T**          v_dst_ptrs,
+void invokeExtendKVCache(void**       k_dst_ptrs,
+                         void**       v_dst_ptrs,
                          const T*     k_src,
                          const T*     v_src,
                          const int*   cu_block_counts,
@@ -267,32 +306,40 @@ void invokeExtendKVCache(T**          k_dst_ptrs,
                          int          head_dim,
                          int          head_num,
                          int          quant,
-                         const float* kv_scale,
+                         const float* kv_params,
                          cudaStream_t stream)
 {
     constexpr int block_sz = 128;
-    constexpr int x        = (sizeof(T) == 4) ? 4 : 8;
-
-    dim3 grid((max_q_len * head_dim / x + block_sz - 1) / block_sz, batch_size, head_num);
-
-    FT_CHECK(quant == 0);
-
-    extend_kv_cache<<<grid, block_sz, 0, stream>>>(k_dst_ptrs,
-                                                   v_dst_ptrs,
-                                                   k_src,
-                                                   v_src,
-                                                   cu_block_counts,
-                                                   query_length,
-                                                   context_length,
-                                                   block_length,
-                                                   dst_layer_offset,
-                                                   max_q_len,
-                                                   head_num,
-                                                   head_dim);
+
+    auto fn = [&](auto value) {
+        using Tout   = decltype(value);
+        using Kernel = ExtendKvCache<T, Tout>;
+
+        dim3 grid((max_q_len * head_dim / Kernel::X_ELEMS + block_sz - 1) / block_sz, batch_size, head_num);
+
+        typename Kernel::Params params{(Tout**)k_dst_ptrs,
+                                       (Tout**)v_dst_ptrs,
+                                       k_src,
+                                       v_src,
+                                       cu_block_counts,
+                                       query_length,
+                                       context_length,
+                                       block_length,
+                                       dst_layer_offset,
+                                       max_q_len,
+                                       head_num,
+                                       head_dim,
+                                       {kv_params[0], kv_params[1]},
+                                       {kv_params[2], kv_params[3]}};
+
+        KernelWrapper<Kernel><<<grid, block_sz, 0, stream>>>(params);
+    };
+
+    (quant & QuantPolicy::kCacheKVInt8) ? fn(int8_t{}) : fn(T{});
 }
 
-template void invokeExtendKVCache(float**      k_dst_ptrs,
-                                  float**      v_dst_ptrs,
+template void invokeExtendKVCache(void**       k_dst_ptrs,
+                                  void**       v_dst_ptrs,
                                   const float* k_src,
                                   const float* v_src,
                                   const int*   cu_block_counts,
@@ -308,8 +355,8 @@ template void invokeExtendKVCache(float**      k_dst_ptrs,
                                   const float* kv_scale,
                                   cudaStream_t stream);
 
-template void invokeExtendKVCache(half**       k_dst_ptrs,
-                                  half**       v_dst_ptrs,
+template void invokeExtendKVCache(void**       k_dst_ptrs,
+                                  void**       v_dst_ptrs,
                                   const half*  k_src,
                                   const half*  v_src,
                                   const int*   cu_block_counts,
@@ -325,96 +372,79 @@ template void invokeExtendKVCache(half**       k_dst_ptrs,
                                   const float* kv_scale,
                                   cudaStream_t stream);
 
-template<typename T>
-__global__ void transpose_value_cache(T*           v_dst,  //
-                                      const T**    v_src,
-                                      const size_t src_offset,
-                                      const int    head_num,
-                                      const int    head_n_rep,
-                                      const int    size_per_head,
-                                      const int*   seq_length,
-                                      const int    max_kv_len,
-                                      const int    max_seq_len)
-{
-    const int     batch_id = blockIdx.y;
-    const int     head_id  = blockIdx.z;
-    constexpr int X_ELEMS  = (sizeof(T) == 4) ? 4 : 8;
-
-    const int idx                 = blockIdx.x * blockDim.x + threadIdx.x;
-    int       size_per_head_div_x = size_per_head / X_ELEMS;
-
-    // x dim is now handled by uint4 type
-    const auto val_src = reinterpret_cast<const uint4*>(v_src[batch_id] + src_offset);
-    const auto val_dst = reinterpret_cast<uint4*>(v_dst);
-
-    const auto seq_len = seq_length[batch_id];
-
-    const int v_head_size_id = idx % size_per_head_div_x;
-    const int v_seq_len_id   = idx / size_per_head_div_x;
-
-    if (v_seq_len_id < seq_len) {
-        // [B, H, s, D/x] <- [B, H, S[:s], D/x]
-        const int64_t src_idx = head_id / head_n_rep * size_per_head_div_x * max_seq_len +  // H
-                                v_seq_len_id * size_per_head_div_x +                        // s
-                                v_head_size_id;                                             // D/x
-
-        const int64_t dst_idx = batch_id * head_num * size_per_head_div_x * max_kv_len +  // B
-                                head_id * size_per_head_div_x * max_kv_len +              // H
-                                v_seq_len_id * size_per_head_div_x +                      // s
-                                v_head_size_id;                                           // D/x
-
-        val_dst[dst_idx] = val_src[src_idx];
-    }
-}
-
-template<typename T>
-__global__ void transpose_value_cache_int8(T*             v_dst,  //
-                                           const int8_t** v_src,
-                                           const size_t   src_offset,
-                                           const int      head_num,
-                                           const int      head_n_rep,
-                                           const int      size_per_head,
-                                           const int*     seq_length,
-                                           const int      max_kv_len,
-                                           const int      max_seq_len,
-                                           const float    v_scale,
-                                           const float    v_zp)
-{
-    const int     batch_id = blockIdx.y;
-    const int     head_id  = blockIdx.z;
-    constexpr int X_ELEMS  = (sizeof(T) == 4) ? 4 : 8;
-
-    const int idx                 = blockIdx.x * blockDim.x + threadIdx.x;
-    int       size_per_head_div_x = size_per_head / X_ELEMS;
-
-    // x dim is now handled by uint4 type
-    const auto val_src = reinterpret_cast<const uint2*>(v_src[batch_id] + src_offset);
-    const auto val_dst = reinterpret_cast<uint4*>(v_dst);
-
-    const auto seq_len = seq_length[batch_id];
-
-    const int v_head_size_id = idx % size_per_head_div_x;
-    const int v_seq_len_id   = idx / size_per_head_div_x;
-
-    if (v_seq_len_id < seq_len) {
-        // [B, H, s, D/x] <- [B, H, S[:s], D/x]
-        const int64_t src_idx = head_id / head_n_rep * size_per_head_div_x * max_seq_len +  // H
-                                v_seq_len_id * size_per_head_div_x +                        // s
-                                v_head_size_id;                                             // D/x
-
-        const int64_t dst_idx = batch_id * head_num * size_per_head_div_x * max_kv_len +  // B
-                                head_id * size_per_head_div_x * max_kv_len +              // H
-                                v_seq_len_id * size_per_head_div_x +                      // s
-                                v_head_size_id;                                           // D/x
-
-        // int8x8 -> fp16x8
-        const auto from_ptr = reinterpret_cast<const char4*>(val_src + src_idx);
-        auto       to_ptr   = reinterpret_cast<half4*>(val_dst + dst_idx);
-
-        // to_ptr[0] = char4_scale_to_half4(from_ptr[0], v_scale, v_zp);
-        // to_ptr[1] = char4_scale_to_half4(from_ptr[1], v_scale, v_zp);
+template<typename Ti, typename To>
+struct TransposeKvCache {
+    static constexpr int MaxElemSize = std::max(sizeof(Ti), sizeof(To));
+    static constexpr int X_ELEMS     = 16 / MaxElemSize;
+
+    using Vi = Array<Ti, X_ELEMS>;
+    using Vo = Array<To, X_ELEMS>;
+
+    using Transform = ConvertKvCache<Ti, To>;
+
+    struct Params {
+        To*        k_dst;
+        To*        v_dst;
+        const Ti** k_src;
+        const Ti** v_src;
+        size_t     src_offset;
+        int        head_num;
+        int        head_n_rep;
+        int        size_per_head;
+        const int* seq_length;
+        int        max_kv_len;
+        int        max_seq_len;
+        Transform  transform_k;
+        Transform  transform_v;
+        // float      k_scale;
+        // float      k_zp;
+        // float      v_scale;
+        // float      v_zp;
+    };
+
+    __device__ void operator()(const Params& params) const
+    {
+        const int batch_id = blockIdx.y;
+        const int head_id  = blockIdx.z;
+
+        const int idx                 = blockIdx.x * blockDim.x + threadIdx.x;
+        const int size_per_head_div_x = params.size_per_head / X_ELEMS;
+
+        const auto k_src = params.k_src[batch_id] + params.src_offset;
+        const auto v_src = params.v_src[batch_id] + params.src_offset;
+        const auto k_dst = params.k_dst;
+        const auto v_dst = params.v_dst;
+
+        const auto seq_len = params.seq_length[batch_id];
+
+        const int v_head_size_id = idx % size_per_head_div_x;
+        const int v_seq_len_id   = idx / size_per_head_div_x;
+
+        if (v_seq_len_id < seq_len) {
+            // [B, H, s, D/x] <- [B, H, S[:s], D/x]
+            const int64_t src_idx = head_id / params.head_n_rep * size_per_head_div_x * params.max_seq_len +  // H
+                                    v_seq_len_id * size_per_head_div_x +                                      // s
+                                    v_head_size_id;                                                           // D/x
+
+            const int64_t dst_idx = batch_id * params.head_num * size_per_head_div_x * params.max_kv_len +  // B
+                                    head_id * size_per_head_div_x * params.max_kv_len +                     // H
+                                    v_seq_len_id * size_per_head_div_x +                                    // s
+                                    v_head_size_id;                                                         // D/x
+
+            Vi k_vi;
+            Vi v_vi;
+
+            Ldg(k_vi, k_src + src_idx * X_ELEMS);
+            Ldg(v_vi, v_src + src_idx * X_ELEMS);
+
+            Vo k_vo = params.transform_k(k_vi);
+            Vo v_vo = params.transform_v(v_vi);
+
+            Store(k_dst + dst_idx * X_ELEMS, k_vo);
+            Store(v_dst + dst_idx * X_ELEMS, v_vo);
+        }
     }
-}
+};
 
 template<typename T>
 void invokeTransposeKVCache(T*           key_cache_trans,
@@ -431,59 +461,34 @@ void invokeTransposeKVCache(T*           key_cache_trans,
                             int          head_n_rep,
                             cudaStream_t stream,
                             int          quant,
-                            const float* kv_scale)
+                            const float* kv_params)
 {
     constexpr int block_sz = 128;
-    constexpr int x        = (sizeof(T) == 4) ? 4 : 8;
-
-    dim3 grid((max_kv_len * size_per_head / x + block_sz - 1) / block_sz, batch_size, head_num);
-
-    if (quant & QuantPolicy::kCacheKVInt8) {
-        transpose_value_cache_int8<<<grid, block_sz, 0, stream>>>(key_cache_trans,
-                                                                  reinterpret_cast<const int8_t**>(key_cache),
-                                                                  src_offset,
-                                                                  head_num,
-                                                                  head_n_rep,
-                                                                  size_per_head,
-                                                                  key_length,
-                                                                  max_kv_len,
-                                                                  max_seq_len,
-                                                                  kv_scale[0],
-                                                                  kv_scale[1]);
-
-        transpose_value_cache_int8<<<grid, block_sz, 0, stream>>>(val_cache_trans,
-                                                                  reinterpret_cast<const int8_t**>(val_cache),
-                                                                  src_offset,
-                                                                  head_num,
-                                                                  head_n_rep,
-                                                                  size_per_head,
-                                                                  key_length,
-                                                                  max_kv_len,
-                                                                  max_seq_len,
-                                                                  kv_scale[2],
-                                                                  kv_scale[3]);
-    }
-    else {
-        transpose_value_cache<<<grid, block_sz, 0, stream>>>(key_cache_trans,
-                                                             key_cache,
-                                                             src_offset,
-                                                             head_num,
-                                                             head_n_rep,
-                                                             size_per_head,
-                                                             key_length,
-                                                             max_kv_len,
-                                                             max_seq_len);
-
-        transpose_value_cache<<<grid, block_sz, 0, stream>>>(val_cache_trans,
-                                                             val_cache,
-                                                             src_offset,
-                                                             head_num,
-                                                             head_n_rep,
-                                                             size_per_head,
-                                                             key_length,
-                                                             max_kv_len,
-                                                             max_seq_len);
-    }
+
+    auto fn = [&](auto value) {
+        using Tin    = decltype(value);
+        using Kernel = TransposeKvCache<Tin, T>;
+
+        dim3 grid((max_kv_len * size_per_head / Kernel::X_ELEMS + block_sz - 1) / block_sz, batch_size, head_num);
+
+        typename Kernel::Params params{key_cache_trans,
+                                       val_cache_trans,
+                                       (const Tin**)key_cache,
+                                       (const Tin**)val_cache,
+                                       src_offset,
+                                       head_num,
+                                       head_n_rep,
+                                       size_per_head,
+                                       key_length,
+                                       max_kv_len,
+                                       max_seq_len,
+                                       {kv_params[0], kv_params[1]},
+                                       {kv_params[2], kv_params[3]}};
+
+        KernelWrapper<Kernel><<<grid, block_sz, 0, stream>>>(params);
+    };
+
+    (quant & QuantPolicy::kCacheKVInt8) ? fn(int8_t{}) : fn(T{});
 }
 
 template void invokeTransposeKVCache(float*,
diff --git a/src/turbomind/models/llama/llama_kernels.h b/src/turbomind/models/llama/llama_kernels.h
index d6226baf50..7bbcd06376 100644
--- a/src/turbomind/models/llama/llama_kernels.h
+++ b/src/turbomind/models/llama/llama_kernels.h
@@ -34,8 +34,8 @@ void invokeCreateCausalMasks(
     T* mask, const int* q_lens, const int* k_lens, int max_q_len, int max_k_len, int batch_size, cudaStream_t stream);
 
 template<typename T>
-void invokeExtendKVCache(T**          k_dst_ptrs,
-                         T**          v_dst_ptrs,
+void invokeExtendKVCache(void**       k_dst_ptrs,
+                         void**       v_dst_ptrs,
                          const T*     k_src,
                          const T*     v_src,
                          const int*   cu_block_counts,

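For reference, the rewritten invokeTransposeKVCache above launches the device functor through KernelWrapper<Kernel>, whose definition is not part of this hunk. A minimal sketch of one plausible shape for that wrapper, assuming it simply default-constructs the functor on the device and forwards the parameter struct (an illustration, not the patch's actual definition):

    template<class Kernel>
    __global__ void KernelWrapper(typename Kernel::Params params)
    {
        // Default-construct the device-side functor and run it with the
        // kernel parameters passed by value from the host.
        Kernel kernel;
        kernel(params);
    }
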
From b269d535e1b1983107b847b6ccafef883f98f37a Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Wed, 11 Oct 2023 10:57:21 +0000
Subject: [PATCH 13/56] tune

---
 .../decoder_multihead_attention/array_ops.h    | 18 ++++++++++++------
 .../decoder_multihead_attention.cu             |  8 +++++++-
 2 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/src/turbomind/kernels/decoder_multihead_attention/array_ops.h b/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
index 99d6135b82..37946d39d9 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
@@ -413,7 +413,10 @@ struct ConvertKvCache<int8_t, float> {
     float scale_;
     float zero_;
 
-    __device__ __host__ ConvertKvCache(float scale, float zero): scale_(scale), zero_(zero) {}
+    __device__ __host__ ConvertKvCache(float scale, float zero): scale_(scale), zero_(zero)
+    {
+        zero_ = zero_ - 32896.f * scale_;
+    }
 
     template<int N>
     inline __device__ auto operator()(const Array<int8_t, N>& vi) const -> Array<float, N>
@@ -425,8 +428,8 @@ struct ConvertKvCache<int8_t, float> {
             vec       = fast_i2f_f32_s8((const Array<int8_t, 4>&)vi[i]);
             PRAGMA_UNROLL
             for (int j = 0; j < 4; ++j) {
-                // vec[j] = vec[j] * scale + zero;
-                vec[j] = vec[j] * scale_ + (zero_ - 32896.f * scale_);
+                vec[j] = vec[j] * scale_ + zero_;
+                // vec[j] = vec[j] * scale_ + (zero_ - 32896.f * scale_);
             }
         }
         return vo;
@@ -439,7 +442,10 @@ struct ConvertKvCache<int8_t, half> {
     float scale_;
     float zero_;
 
-    __device__ __host__ ConvertKvCache(float scale, float zero): scale_(scale), zero_(zero) {}
+    __device__ __host__ ConvertKvCache(float scale, float zero): scale_(scale), zero_(zero)
+    {
+        zero_ = zero_ - 32896.f * scale_;
+    }
 
     template<int N>
     inline __device__ auto operator()(const Array<int8_t, N>& vi) const -> Array<half, N>
@@ -451,8 +457,8 @@ struct ConvertKvCache<int8_t, half> {
             auto  tmp = fast_i2f_f32_s8((const Array<int8_t, 4>&)vi[i]);
             PRAGMA_UNROLL
             for (int j = 0; j < 4; ++j) {
-                // vec[j] = half(tmp[j] * scale + zero);
-                vec[j] = half(tmp[j] * scale_ + (zero_ - 32896.f * scale_));
+                vec[j] = half(tmp[j] * scale_ + zero_);
+                // vec[j] = half(tmp[j] * scale_ + (zero_ - 32896.f * scale_));
             }
         }
         return vo;
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
index 91dc207876..2a5f641e20 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
@@ -27,9 +27,15 @@ bool Dump()
 template<typename T, typename Tkv, int HeadDim, int HeadPerCta>
 void invokeDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& params)
 {
-    // 2048_32x6 ~ 64k smem
+    // cpasync_2048_32x6 ~ 64k smem
     using MHAType = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 2048, 6>;
 
+    // ld_kv16_2048_32x3 ~ 34k smem
+    // using MHAType = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 2048, 3>;
+
+    // ld_kv8_2048_64x3 ~ 34k smem
+    // using MHAType = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 64, HeadDim, 2048, 3>;
+
     [[maybe_unused]] static const bool init = Dump<MHAType>();
 
     dim3 block(MHAType::kWarpCount * WARP_SIZE);

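This patch folds the constant part of the int8 dequantization into the zero point once, in the ConvertKvCache constructor, instead of recomputing it for every element inside the inner loop. A minimal host-side sketch of the identity being exploited, assuming the magic-number trick in fast_i2f_f32_s8 produces values offset by 32896 (function names below are illustrative only):

    // Form used before this patch: the constant is re-evaluated per element.
    inline float dequant_unfused(float q_biased, float scale, float zero)
    {
        return q_biased * scale + (zero - 32896.f * scale);
    }

    // Form after this patch: zero_folded = zero - 32896.f * scale is computed
    // once at construction time, saving work in the per-element loop.
    inline float dequant_fused(float q_biased, float scale, float zero_folded)
    {
        return q_biased * scale + zero_folded;
    }
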
From d7110e44946cc8f7fb8823e7d095a769cefc118f Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Wed, 11 Oct 2023 13:40:42 +0000
Subject: [PATCH 14/56] async stream cb

---
 src/turbomind/models/llama/LlamaBatch.cc | 76 ++++++++++++++++++++++--
 src/turbomind/models/llama/LlamaBatch.h  | 20 +++++++
 2 files changed, 90 insertions(+), 6 deletions(-)

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index e828869eae..649c5901c4 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -16,6 +16,7 @@
 #include <cstdint>
 #include <iomanip>
 #include <math.h>
+#include <mutex>
 #include <sstream>
 #include <unordered_map>
 
@@ -137,6 +138,12 @@ auto LlamaBatch<T>::ProcessStopRequests(const Requests& requests) -> std::vector
         }
         // clear output buffers (prevent leaking conversations) if request is successful
         if (ec == 0) {
+
+            if (rank_ == 0) {
+                std::unique_lock lock{output_mutex_};
+                output_cv_.wait(lock, [&] { return output_reqs_.empty(); });
+            }
+
             auto& output_ids      = r->outputs[rank_].at("output_ids");
             auto& sequence_length = r->outputs[rank_].at("sequence_length");
             Clear(output_ids.getPtr<int>(), output_ids.shape.at(2));
@@ -1036,14 +1043,41 @@ auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
     Copy(finished_buf_, batch_size, state_->h_finished);
     Copy(sequence_lengths_, batch_size, state_->h_context_length);
 
-    SetOutputTensors(g);
+    if (1) {
+        std::unique_lock<std::mutex> lock;
+        if (rank_ == 0) {
+            // wait for previous output operations
+            lock = std::unique_lock{output_mutex_};
+            output_cv_.wait(lock, [&] { return output_reqs_.empty(); });
+        }
 
-    check_cuda_error(cudaStreamSynchronize(stream_));
+        SetOutputTensors(g);
+        check_cuda_error(cudaStreamSynchronize(stream_));
 
-    for (int i = 0; i < batch_size; ++i) {
-        FT_CHECK(state_->requests[i] != nullptr);
-        if (state_->requests[i]->stream_cb && rank_ == 0) {
-            state_->requests[i]->stream_cb(&state_->requests[i]->outputs[rank_].get());
+        if (rank_ == 0) {
+            // enqueue new output requests
+            for (int i = 0; i < batch_size; ++i) {
+                FT_CHECK(state_->requests[i] != nullptr);
+                if (state_->requests[i]->stream_cb) {
+                    output_reqs_.push_back(state_->requests[i]);
+                }
+            }
+            lock.unlock();
+            // notify the output thread only when there are stream callbacks to invoke
+            if (!output_reqs_.empty()) {
+                output_cv_.notify_one();
+            }
+        }
+    }
+    else {
+        SetOutputTensors(g);
+        check_cuda_error(cudaStreamSynchronize(stream_));
+
+        for (int i = 0; i < batch_size; ++i) {
+            FT_CHECK(state_->requests[i] != nullptr);
+            if (state_->requests[i]->stream_cb && rank_ == 0) {
+                state_->requests[i]->stream_cb(&state_->requests[i]->outputs[rank_].get());
+            }
         }
     }
 
@@ -1235,6 +1269,36 @@ void LlamaBatch<T>::Start()
     int device_id = -1;
     check_cuda_error(cudaGetDevice(&device_id));
     internal_thread_ = std::thread(&LlamaBatch::InternalThreadEntry, this, device_id);
+    if (rank_ == 0) {
+        output_thread_ = std::thread(&LlamaBatch::OutputThreadEntry, this);
+    }
+}
+
+template<typename T>
+void LlamaBatch<T>::OutputThreadEntry()
+{
+    while (true) {
+        {
+            // wait for requests with stream cbs
+            std::unique_lock lock(output_mutex_);
+            output_cv_.wait(lock, [&] { return !output_reqs_.empty() || output_stop_token_; });
+
+            // invoke stream cbs
+            for (const auto& r : output_reqs_) {
+                r->stream_cb(&r->outputs[rank_].get());
+            }
+            output_reqs_.clear();
+
+            // stop requested
+            if (output_stop_token_) {
+                TM_LOG_INFO("[OutputThreadEntry] stop requested.");
+                break;
+            }
+        }
+        FT_CHECK(output_reqs_.empty());
+        // notify infer thread 0
+        output_cv_.notify_one();
+    }
 }
 
 template class LlamaBatch<half>;
diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h
index 8e1c1f5d36..6419683937 100644
--- a/src/turbomind/models/llama/LlamaBatch.h
+++ b/src/turbomind/models/llama/LlamaBatch.h
@@ -9,6 +9,8 @@
 #include "src/turbomind/models/llama/SequenceManager.h"
 #include "src/turbomind/utils/allocator.h"
 #include "src/turbomind/utils/cublasMMWrapper.h"
+#include <condition_variable>
+#include <mutex>
 
 namespace turbomind {
 
@@ -86,6 +88,15 @@ class LlamaBatch {
 
         internal_thread_.join();
 
+        if (output_thread_.joinable()) {
+            {
+                std::lock_guard lock{output_mutex_};
+                output_stop_token_ = true;
+            }
+            output_cv_.notify_one();
+            output_thread_.join();
+        }
+
         FreeBuffer();
     }
 
@@ -94,6 +105,8 @@ class LlamaBatch {
 private:
     void InternalThreadEntry(int device_id);
 
+    void OutputThreadEntry();
+
     void UpdateSequenceStates(BatchState& state, int index);
 
     void CopyState(const std::pair<BatchState*, int> _src, const std::pair<BatchState*, int>& _dst);
@@ -213,6 +226,13 @@ class LlamaBatch {
     IAllocator*      allocator_{};
 
     std::thread internal_thread_;
+
+    // async stream callback utils
+    std::thread             output_thread_;
+    std::mutex              output_mutex_;
+    std::condition_variable output_cv_;
+    Requests                output_reqs_;
+    bool                    output_stop_token_{false};
 };
 
 }  // namespace turbomind

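The output thread added here follows a standard mutex / condition_variable handshake: rank 0 waits until the previous batch of callbacks has drained, enqueues the new requests, and notifies; the output thread drains the queue under the lock and notifies back. A standalone sketch of the same pattern, reduced to a plain work queue (all names in this sketch are illustrative, not the LlamaBatch members):

    #include <condition_variable>
    #include <functional>
    #include <mutex>
    #include <vector>

    struct OutputQueue {
        std::mutex                         mutex;
        std::condition_variable            cv;
        std::vector<std::function<void()>> work;
        bool                               stop{false};

        // Consumer loop: wait for work or a stop request, run the callbacks,
        // then wake the producer that may be waiting for the queue to drain.
        void Consume()
        {
            while (true) {
                std::unique_lock lock{mutex};
                cv.wait(lock, [&] { return !work.empty() || stop; });
                if (stop) {
                    break;
                }
                for (auto& fn : work) {
                    fn();
                }
                work.clear();
                lock.unlock();
                cv.notify_one();
            }
        }

        // Producer: wait until the previous batch is consumed, then enqueue
        // the next one and wake the consumer.
        void Produce(std::vector<std::function<void()>> batch)
        {
            std::unique_lock lock{mutex};
            cv.wait(lock, [&] { return work.empty(); });
            work = std::move(batch);
            lock.unlock();
            cv.notify_one();
        }
    };

A single condition variable suffices because, as in the patch, at most one side can be blocked at a time: the producer only waits while the queue is non-empty and the consumer only waits while it is empty.
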
From 498e9a3aeacb5dc7f8fccf7501f8d3365c4b7e82 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Fri, 13 Oct 2023 04:32:07 +0000
Subject: [PATCH 15/56] nvtx

---
 .../decoder_multihead_attention.cu            |  2 +-
 .../decoder_multihead_attention_params.h      |  3 +
 src/turbomind/models/llama/LlamaBatch.cc      |  4 ++
 src/turbomind/models/llama/LlamaDecoder.cc    | 70 +++++++++++--------
 .../llama/LlamaDecoderSelfAttentionLayer.cc   | 19 +++--
 src/turbomind/models/llama/LlamaFfnLayer.cc   | 26 ++++---
 src/turbomind/models/llama/LlamaV2.cc         |  3 +
 src/turbomind/models/llama/llama_utils.h      | 21 ++++--
 8 files changed, 100 insertions(+), 48 deletions(-)

diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
index 2a5f641e20..cd5b5a908f 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
@@ -47,7 +47,7 @@ void invokeDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& p
     cudaFuncSetAttribute(
         decoder_multihead_attention<MHAType>, cudaFuncAttributeMaxDynamicSharedMemorySize, kDynSmemSize);
 
-    decoder_multihead_attention<MHAType><<<grid, block, kDynSmemSize>>>(params);
+    decoder_multihead_attention<MHAType><<<grid, block, kDynSmemSize, params.stream>>>(params);
 }
 
 template<typename T>
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
index 199d0f4667..de84526ec0 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
@@ -1,4 +1,5 @@
 #pragma once
+#include <cuda_runtime.h>
 
 namespace turbomind {
 
@@ -54,6 +55,8 @@ struct DecoderMultiHeadAttentionParams {
 
     int   quant_policy;
     float kv_quant_params[4];
+
+    cudaStream_t stream;
 };
 
 }  // namespace turbomind
\ No newline at end of file
diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 3e25fd9191..64e6557862 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -245,6 +245,7 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
 template<typename T>
 bool LlamaBatch<T>::Initialize()
 {
+    NvtxScope                                scope("initialize");
     std::vector<const Sequence*>             sequences;
     std::vector<Sequence::Status>            status;
     std::vector<uint64_t>                    priorities;
@@ -755,6 +756,7 @@ auto LlamaBatch<T>::InitializeGeneration() -> GenerationState
 template<typename T>
 bool LlamaBatch<T>::Generate(GenerationState& g)
 {
+    NvtxScope scope("Generate");
     const int batch_size = state_->active_size;
 
     constexpr int kLogInterval = 10;
@@ -1216,6 +1218,8 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id)
             }
         }
 
+        NvtxScope scope("mainloop");
+
         // wait while rank-0 is dequeueing
         shared_state->barrier->wait();
 
diff --git a/src/turbomind/models/llama/LlamaDecoder.cc b/src/turbomind/models/llama/LlamaDecoder.cc
index 926b442429..bb76bd205e 100644
--- a/src/turbomind/models/llama/LlamaDecoder.cc
+++ b/src/turbomind/models/llama/LlamaDecoder.cc
@@ -118,6 +118,7 @@ void LlamaDecoder<T>::forwardSelfAttn(const LlamaDecoder::Session&
                                       const std::unordered_map<std::string, Tensor>* input_tensors,
                                       size_t                                         layer)
 {
+    NvtxScope scope("self_attn");
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
     TensorMap self_attention_input_tensors(*input_tensors);
     self_attention_input_tensors.insert("input_query",
@@ -180,6 +181,8 @@ void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>*        ou
 
     // for the shape of key cache, refer to decoder_masked_multihead_attention_template.hpp
 
+    NvtxScope forward_scope("decoder_forward");
+
     Session sess{};
     sess.batch_size = input_tensors->at("decoder_input").shape[0];
     sess.weights    = decoder_layer_weights;
@@ -200,43 +203,54 @@ void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>*        ou
 
     ////////////////////////////////////////////
     /// RMSNorm
-    invokeRootMeanSquareNorm(decoder_output,
-                             decoder_input,
-                             decoder_layer_weights->at(0)->self_attn_norm_weights,
-                             rmsnorm_eps_,
-                             sess.batch_size,
-                             hidden_units_,
-                             stream_);
-    sync_check_cuda_error();
+    {
+        NvtxScope rms_norm_scope("rms_norm_0");
+        invokeRootMeanSquareNorm(decoder_output,
+                                 decoder_input,
+                                 decoder_layer_weights->at(0)->self_attn_norm_weights,
+                                 rmsnorm_eps_,
+                                 sess.batch_size,
+                                 hidden_units_,
+                                 stream_);
+        sync_check_cuda_error();
+    }
 
     for (size_t layer = 0; layer < num_layer_; ++layer) {
+        NvtxScope layer_scope("decode_layer");
+
         // output: self_attn_output_, k_cache, v_cache = self_attn(decoder_normed_input_)
         forwardSelfAttn(sess, decoder_output, input_tensors, layer);
 
-        invokeFusedAddBiasResidualRMSNorm(decoder_input,
-                                          decoder_output,
-                                          decoder_layer_weights->at(layer)->self_attn_weights.output.bias,
-                                          decoder_layer_weights->at(layer)->ffn_norm_weights,
-                                          rmsnorm_eps_,
-                                          sess.batch_size,
-                                          hidden_units_,
-                                          stream_);
-        sync_check_cuda_error();
+        {
+            NvtxScope rms_norm_scope("rms_norm_1");
+            invokeFusedAddBiasResidualRMSNorm(decoder_input,
+                                              decoder_output,
+                                              decoder_layer_weights->at(layer)->self_attn_weights.output.bias,
+                                              decoder_layer_weights->at(layer)->ffn_norm_weights,
+                                              rmsnorm_eps_,
+                                              sess.batch_size,
+                                              hidden_units_,
+                                              stream_);
+            sync_check_cuda_error();
+        }
 
         // decoder_layer_output_ = ffn(decoder_normed_input_)
         forwardFfn(sess, decoder_output, layer);
 
-        auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights :
-                                                     input_tensors->at("output_norm_weight").getPtr<T>();
-        invokeFusedAddBiasResidualRMSNorm(decoder_input,  //
-                                          decoder_output,
-                                          decoder_layer_weights->at(layer)->ffn_weights.output.bias,
-                                          scale_weight,
-                                          rmsnorm_eps_,
-                                          sess.batch_size,
-                                          hidden_units_,
-                                          stream_);
-        sync_check_cuda_error();
+        {
+            NvtxScope rms_norm_scope("rms_norm_2");
+            auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights :
+                                                         input_tensors->at("output_norm_weight").getPtr<T>();
+            invokeFusedAddBiasResidualRMSNorm(decoder_input,  //
+                                              decoder_output,
+                                              decoder_layer_weights->at(layer)->ffn_weights.output.bias,
+                                              scale_weight,
+                                              rmsnorm_eps_,
+                                              sess.batch_size,
+                                              hidden_units_,
+                                              stream_);
+            sync_check_cuda_error();
+        }
     }
 
     if (is_free_buffer_after_forward_) {
diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
index 8fd3eb3ede..fe7fdab7ad 100644
--- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
@@ -99,9 +99,10 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
 
     allocateBuffer(batch_size);
 
-    PUSH_RANGE("qkv_gemm");
-    linear_.forward(qkv_buf_, input_query_data, batch_size, weights->qkv);
-    POP_RANGE;
+    {
+        NvtxScope scope("qkv_gemm");
+        linear_.forward(qkv_buf_, input_query_data, batch_size, weights->qkv);
+    }
 
     const auto layer_offset = layer_id * local_kv_head_num_ * kv_cache_block_len_ * size_per_head_;
     // const int  memory_len   = max_seq_len;
@@ -134,12 +135,20 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
     params.rotary_embedding_dim  = size_per_head_;
     params.rotary_embedding_base = 10000.f;
 
+    params.stream = stream_;
+
     params.quant_policy = quant_policy_;
     std::copy(weights->past_kv_scale.begin(), weights->past_kv_scale.end(), std::begin(params.kv_quant_params));
 
-    DispatchDecoderMultiheadAttention<T>(params);
+    {
+        NvtxScope scope("decoder_multihead_attention");
+        DispatchDecoderMultiheadAttention<T>(params);
+    }
 
-    linear_.forward(hidden_features_data, context_buf_, batch_size, weights->output);
+    {
+        NvtxScope scope("o_gemm");
+        linear_.forward(hidden_features_data, context_buf_, batch_size, weights->output);
+    }
 
     if (tensor_para_.world_size_ > 1) {
         NcclGuard nccl_guard(tensor_para_, stream_);
diff --git a/src/turbomind/models/llama/LlamaFfnLayer.cc b/src/turbomind/models/llama/LlamaFfnLayer.cc
index f605d8f27b..0d78dc4e80 100644
--- a/src/turbomind/models/llama/LlamaFfnLayer.cc
+++ b/src/turbomind/models/llama/LlamaFfnLayer.cc
@@ -20,6 +20,7 @@
 #include "src/turbomind/models/llama/LlamaFfnLayer.h"
 #include "src/turbomind/kernels/activation_kernels.h"
 #include "src/turbomind/models/llama/LlamaNcclGuard.h"
+#include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/nvtx_utils.h"
 // #include <glog/logging.h>
 
@@ -46,6 +47,7 @@ void LlamaFfnLayer<T>::freeBuffer()
 template<typename T>
 void LlamaFfnLayer<T>::activation(int num_token)
 {
+    NvtxScope scope("activation");
     invokeGenericActivation<SiluActivation>(gating_buf_,
                                             (const T*)nullptr,  // bias
                                             inter_buf_,
@@ -76,6 +78,8 @@ void LlamaFfnLayer<T>::forward(TensorMap*               output_tensors,
      *   \param ffn_output [token_num, hidden_dimension]
      */
 
+    NvtxScope scope("ffn");
+
     const size_t num_token = input_tensors->at("ffn_input").shape[0];
     // LOG(WARNING);
 
@@ -84,24 +88,28 @@ void LlamaFfnLayer<T>::forward(TensorMap*               output_tensors,
     const T* ffn_input_data  = input_tensors->at("ffn_input").getPtr<T>();
     T*       ffn_output_data = output_tensors->at("ffn_output").getPtr<T>();
 
-    PUSH_RANGE("ffn");
-
     if (weights->fused_gating_intermediate.kernel) {
+        NvtxScope scope("fused_silu_ffn");
         linear_.forward(
             gating_buf_, ffn_input_data, num_token, weights->fused_gating_intermediate, LlamaLinear<T>::kFusedSiluFfn);
     }
     else {
-        // w1(x)
-        linear_.forward(gating_buf_, ffn_input_data, num_token, weights->gating);
-        // w3(x)
-        linear_.forward(inter_buf_, ffn_input_data, num_token, weights->intermediate);
+        {  // w1(x)
+            NvtxScope scope("w1");
+            linear_.forward(gating_buf_, ffn_input_data, num_token, weights->gating);
+        }
+        {  // w3(x)
+            NvtxScope scope("w3");
+            linear_.forward(inter_buf_, ffn_input_data, num_token, weights->intermediate);
+        }
         // silu(w1(x)) * w3(x)
         activation(num_token);
     }
 
-    // w2(x)
-    linear_.forward(ffn_output_data, gating_buf_, num_token, weights->output);
-    POP_RANGE;
+    {  // w2(x)
+        NvtxScope scope("w2");
+        linear_.forward(ffn_output_data, gating_buf_, num_token, weights->output);
+    }
 
     if (tensor_para_.world_size_ > 1) {
         NcclGuard nccl_guard(tensor_para_, stream_);
diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc
index e70e89f378..3e9d34072c 100644
--- a/src/turbomind/models/llama/LlamaV2.cc
+++ b/src/turbomind/models/llama/LlamaV2.cc
@@ -187,6 +187,7 @@ void LlamaV2<T>::initialize(const LlamaAttentionParams& attn_params,
 template<typename T>
 void LlamaV2<T>::embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step)
 {
+    NvtxScope scope("embeddingLookup");
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
     // ! This kernel can't be used in context decoding
     invokeEmbeddingLookupPosEncodingPadCount(embeddings,
@@ -318,6 +319,7 @@ void LlamaV2<T>::decoderForward(T*          decoder_output,
 template<typename T>
 void LlamaV2<T>::postDecodeEmbedding(float* logits, float* local_logits, const T* decoder_output, int batch_size)
 {
+    NvtxScope scope("postDecodeEmbedding");
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
     cudaDataType_t data_type = getCudaDataType<T>();
     float          alpha     = 1.f;
@@ -395,6 +397,7 @@ void LlamaV2<T>::dynamicDecode(int*            token_ids,
                                size_t          token_ids_len,
                                size_t          batch_size)
 {
+    NvtxScope scope("dynamicDecode");
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
     int local_batch_size = (int)batch_size;
 
diff --git a/src/turbomind/models/llama/llama_utils.h b/src/turbomind/models/llama/llama_utils.h
index 05c10be80b..60942560d3 100644
--- a/src/turbomind/models/llama/llama_utils.h
+++ b/src/turbomind/models/llama/llama_utils.h
@@ -2,6 +2,7 @@
 
 #pragma once
 #include "src/turbomind/utils/Tensor.h"
+#include "src/turbomind/utils/nvtx_utils.h"
 #include <cuda_runtime.h>
 #include <sstream>
 #include <string>
@@ -9,8 +10,7 @@
 
 namespace turbomind {
 
-enum QuantPolicy
-{
+enum QuantPolicy {
     kNone = 0x00,
     // reserve 0x01 and 0x02 for backward compatibility
     kReserve1 = 0x01,
@@ -19,8 +19,7 @@ enum QuantPolicy
     kCacheKVInt8 = 0x04,
 };
 
-enum CmpMode
-{
+enum CmpMode {
     kCmpNone,
     kCmpRead,
     kCmpWrite,
@@ -52,7 +51,7 @@ inline std::string to_string(std::string x)
 template<typename... Args>
 std::string Concat(std::string key, Args&&... args)
 {
-    std::vector<std::string> args_str{detail::to_string((Args &&) args)...};
+    std::vector<std::string> args_str{detail::to_string((Args&&)args)...};
     for (const auto& s : args_str) {
         key.append("_");
         key.append(s);
@@ -66,4 +65,16 @@ size_t curandStateGetSize();
 
 bool isDebug();
 
+struct NvtxScope {
+    explicit NvtxScope(const std::string& name)
+    {
+        PUSH_RANGE(name.c_str());
+    }
+
+    ~NvtxScope()
+    {
+        POP_RANGE;
+    }
+};
+
 }  // namespace turbomind

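NvtxScope, added above, is a small RAII wrapper: the constructor pushes an NVTX range and the destructor pops it, so every annotated block shows up as one (possibly nested) range in a profiler timeline. A short usage sketch with illustrative function and range names:

    void decode_step()
    {
        NvtxScope step_scope("decode_step");      // range opens here
        {
            NvtxScope gemm_scope("qkv_gemm");     // nested range
            // ... launch the QKV GEMM ...
        }                                         // "qkv_gemm" range closes
        {
            NvtxScope attn_scope("attention");
            // ... launch the attention kernel ...
        }                                         // "attention" range closes
    }                                             // "decode_step" range closes
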
From 6d47a7a1f2702210291dc874b372fe030bc8b3b9 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Fri, 13 Oct 2023 07:14:43 +0000
Subject: [PATCH 16/56] config parsing

---
 src/turbomind/models/llama/BlockManager.cc    |  2 +-
 src/turbomind/models/llama/LlamaV2.cc         |  7 +--
 src/turbomind/models/llama/LlamaV2.h          |  3 +-
 src/turbomind/models/llama/SequenceManager.cc | 12 ++---
 src/turbomind/models/llama/SequenceManager.h  |  4 +-
 .../triton_backend/llama/LlamaTritonModel.cc  | 44 +++++++++++--------
 .../triton_backend/llama/LlamaTritonModel.h   |  3 +-
 7 files changed, 43 insertions(+), 32 deletions(-)

diff --git a/src/turbomind/models/llama/BlockManager.cc b/src/turbomind/models/llama/BlockManager.cc
index bda8cb98cf..37384ea61e 100644
--- a/src/turbomind/models/llama/BlockManager.cc
+++ b/src/turbomind/models/llama/BlockManager.cc
@@ -82,7 +82,7 @@ size_t BlockManager::GetBlockCount(size_t block_size, double ratio)
     size_t free{};
     size_t total{};
     check_cuda_error(cudaMemGetInfo(&free, &total));
-    return static_cast<size_t>(free * ratio / block_size);
+    return static_cast<size_t>(total * ratio) / block_size;
 }
 
 void BlockManager::Move(std::vector<int>& src, const std::vector<int>& delta, std::vector<int>& dst)
diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc
index 3e9d34072c..bb0118ca44 100644
--- a/src/turbomind/models/llama/LlamaV2.cc
+++ b/src/turbomind/models/llama/LlamaV2.cc
@@ -55,7 +55,8 @@ LlamaV2<T>::LlamaV2(size_t                       head_num,
                     int                          step_length,
                     int                          start_id,
                     int                          end_id,
-                    int                          cache_max_entry_count,
+                    float                        cache_max_block_count,
+                    int                          cache_block_seq_len,
                     int                          cache_chunk_size,
                     int                          quant_policy,
                     bool                         use_context_fmha,
@@ -114,8 +115,8 @@ LlamaV2<T>::LlamaV2(size_t                       head_num,
     auto sequence_manager = std::make_unique<SequenceManager>(num_layer,
                                                               local_kv_head_num,
                                                               size_per_head_,
-                                                              128,
-                                                              cache_max_entry_count,
+                                                              cache_block_seq_len,
+                                                              cache_max_block_count,
                                                               cache_chunk_size,
                                                               elem_bits,
                                                               tensor_para_.rank_,
diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h
index b2bb74d9a3..229674f599 100644
--- a/src/turbomind/models/llama/LlamaV2.h
+++ b/src/turbomind/models/llama/LlamaV2.h
@@ -66,7 +66,8 @@ class LlamaV2 {
             int                          step_length,
             int                          start_id,
             int                          end_id,
-            int                          cache_max_entry_count,
+            float                        cache_max_block_count,
+            int                          cache_block_seq_len,
             int                          cache_chunk_size,
             int                          quant_policy,
             bool                         use_context_fmha,
diff --git a/src/turbomind/models/llama/SequenceManager.cc b/src/turbomind/models/llama/SequenceManager.cc
index 34aa240524..db99ea9d0e 100644
--- a/src/turbomind/models/llama/SequenceManager.cc
+++ b/src/turbomind/models/llama/SequenceManager.cc
@@ -10,18 +10,18 @@ namespace turbomind {
 SequenceManager::SequenceManager(size_t      layer_num,
                                  size_t      head_num,
                                  size_t      head_dim,
-                                 size_t      block_len,
+                                 size_t      block_seq_len,
                                  double      block_count,
                                  int         chunk_size,
                                  size_t      elem_bits,
                                  int         rank,
                                  IAllocator* allocator):
-    block_len_(block_len)
+    block_seq_len_(block_seq_len)
 {
     constexpr int kBitsPerByte = 8;
 
-    // [2, L, H, block_len, D]
-    size_t block_size = 2UL * layer_num * head_num * block_len * head_dim * elem_bits / kBitsPerByte;
+    // [2, L, H, block_seq_len, D]
+    size_t block_size = 2UL * layer_num * head_num * block_seq_len * head_dim * elem_bits / kBitsPerByte;
 
     block_manager_ = std::make_unique<BlockManager>(block_size, block_count, chunk_size, allocator);
 
@@ -91,7 +91,7 @@ void SequenceManager::Verify(Sequence& seq, std::vector<const Block*>& retain)
     }
     retain.insert(retain.end(), seq.blocks.begin(), seq.blocks.end());
     seq.status    = Sequence::kLocked;
-    seq.cache_len = std::min<int>(seq.cache_len, seq.blocks.size() * block_len_);
+    seq.cache_len = std::min<int>(seq.cache_len, seq.blocks.size() * block_seq_len_);
 }
 
 void SequenceManager::Release(const Sequence& sequence)
@@ -286,7 +286,7 @@ auto SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
     int              total_required{};
     for (int i = 0; i < sequences.size(); ++i) {
         int seq_len = context_lengths[i] + step_length;
-        int count   = (seq_len + block_len_ - 1) / block_len_ - static_cast<int>(seqs[i]->blocks.size());
+        int count   = (seq_len + block_seq_len_ - 1) / block_seq_len_ - static_cast<int>(seqs[i]->blocks.size());
         required[i] = std::max(0, count);
         total_required += required[i];
     }
diff --git a/src/turbomind/models/llama/SequenceManager.h b/src/turbomind/models/llama/SequenceManager.h
index 5fdbf2054d..190b29feaf 100644
--- a/src/turbomind/models/llama/SequenceManager.h
+++ b/src/turbomind/models/llama/SequenceManager.h
@@ -33,7 +33,7 @@ class SequenceManager {
     explicit SequenceManager(size_t      layer_num,
                              size_t      head_num,
                              size_t      head_dim,
-                             size_t      block_len,
+                             size_t      block_seq_len,
                              double      block_count,
                              int         chunk_size,
                              size_t      elem_bits,
@@ -83,7 +83,7 @@ class SequenceManager {
     void Verify(Sequence& seq, std::vector<const Block*>& retain);
 
 private:
-    int    block_len_;
+    int    block_seq_len_;
     int    rank_;
     size_t val_offset_{};
 
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
index 8a7674a2ab..beab5d7d94 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -65,7 +65,7 @@ void LlamaTritonModel<T>::handleMissingParams()
     }
 
     if (!max_batch_size_) {
-        max_batch_size_ = 32;
+        max_batch_size_ = 64;
         TM_LOG_WARNING("[LlamaTritonModel] `max_batch_size` is not set, default to %d.", (int)max_batch_size_);
     }
 
@@ -85,14 +85,18 @@ void LlamaTritonModel<T>::handleMissingParams()
         TM_LOG_WARNING("[LlamaTritonModel] `step_length` is not set, default to %d.", (int)step_length_);
     }
 
-    if (!cache_max_entry_count_) {
-        cache_max_entry_count_ = 32;
-        TM_LOG_WARNING("[LlamaTritonModel] `cache_max_entry_count` is not set, default to %d.",
-                       (int)cache_max_entry_count_);
+    if (!cache_max_block_count_) {
+        cache_max_block_count_ = .95f;
+        TM_LOG_WARNING("[LlamaTritonModel] `cache_max_entry_count` is not set, default to %f.", cache_max_block_count_);
+    }
+
+    if (!cache_block_seq_len_) {
+        cache_block_seq_len_ = 128;
+        TM_LOG_WARNING("[LlamaTritonModel] `cache_block_seq_len` is not set, default to %d.", cache_block_seq_len_);
     }
 
     if (!cache_chunk_size_) {
-        cache_chunk_size_ = cache_max_entry_count_;
+        cache_chunk_size_ = cache_max_block_count_;
         TM_LOG_WARNING("[LlamaTritonModel] `cache_chunk_size` is not set, default to %d.", (int)cache_chunk_size_);
     }
 }
@@ -129,12 +133,14 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t      tensor_para_size,
     max_context_token_num_ = reader.GetInteger("llama", "max_context_token_num", 0);
     session_len_           = reader.GetInteger("llama", "session_len", 0);
     step_length_           = reader.GetInteger("llama", "step_length", 0);
-    cache_max_entry_count_ = reader.GetInteger("llama", "cache_max_entry_count", 0);
-    use_context_fmha_      = reader.GetInteger("llama", "use_context_fmha", 1);
+    cache_max_block_count_ = reader.GetFloat("llama", "cache_max_entry_count", 0);
+    cache_block_seq_len_   = reader.GetInteger("llama", "cache_block_seq_len", 0);
     cache_chunk_size_      = reader.GetInteger("llama", "cache_chunk_size", 0);
-    attn_bias_             = reader.GetInteger("llama", "attn_bias", 0);
-    quant_policy_          = reader.GetInteger("llama", "quant_policy", 0);
-    group_size_            = reader.GetInteger("llama", "group_size", 0);
+    use_context_fmha_      = reader.GetInteger("llama", "use_context_fmha", 1);
+
+    attn_bias_    = reader.GetInteger("llama", "attn_bias", 0);
+    quant_policy_ = reader.GetInteger("llama", "quant_policy", 0);
+    group_size_   = reader.GetInteger("llama", "group_size", 0);
 
     attn_params_.rotray_embedding_dim    = reader.GetInteger("llama", "rotary_embedding");
     attn_params_.rotary_embedding_base   = reader.GetFloat("llama", "rope_theta", 10000.0f);
@@ -235,7 +241,8 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh
                                                   step_length_,
                                                   start_id_,
                                                   end_id_,
-                                                  cache_max_entry_count_,
+                                                  cache_max_block_count_,
+                                                  cache_block_seq_len_,
                                                   cache_chunk_size_,
                                                   quant_policy_,
                                                   use_context_fmha_,
@@ -320,12 +327,13 @@ std::string LlamaTritonModel<T>::toString()
        << "\ninter_size: " << inter_size_ << "\nnum_layer: " << num_layer_ << "\nvocab_size: " << vocab_size_
        << "\nattn_bias: " << attn_bias_ << "\nmax_batch_size: " << max_batch_size_
        << "\nmax_context_token_num: " << max_context_token_num_ << "\nsession_len: " << session_len_
-       << "\nstep_length: " << step_length_ << "\ncache_max_entry_count: " << cache_max_entry_count_
-       << "\ncache_chunk_size: " << cache_chunk_size_ << "\nuse_context_fmha: " << use_context_fmha_
-       << "\nstart_id: " << start_id_ << "\ntensor_para_size: " << tensor_para_size_
-       << "\npipeline_para_size: " << pipeline_para_size_ << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_
-       << "\nmodel_name: " << model_name_ << "\nmodel_dir: " << model_dir_ << "\nquant_policy: " << quant_policy_
-       << "\ngroup_size: " << group_size_ << std::endl;
+       << "\nstep_length: " << step_length_ << "\ncache_max_entry_count: " << cache_max_block_count_
+       << "\ncache_block_seq_len: " << cache_block_seq_len_ << "\ncache_chunk_size: " << cache_chunk_size_
+       << "\nuse_context_fmha: " << use_context_fmha_ << "\nstart_id: " << start_id_
+       << "\ntensor_para_size: " << tensor_para_size_ << "\npipeline_para_size: " << pipeline_para_size_
+       << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_ << "\nmodel_name: " << model_name_
+       << "\nmodel_dir: " << model_dir_ << "\nquant_policy: " << quant_policy_ << "\ngroup_size: " << group_size_
+       << std::endl;
 
     return ss.str();
 }
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.h b/src/turbomind/triton_backend/llama/LlamaTritonModel.h
index b7d8f439ca..0e2b89bff8 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.h
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.h
@@ -93,7 +93,8 @@ struct LlamaTritonModel: public AbstractTransformerModel {
     int                             step_length_;
     int                             start_id_;
     int                             end_id_;
-    int                             cache_max_entry_count_;
+    float                           cache_max_block_count_;
+    int                             cache_block_seq_len_;
     int                             cache_chunk_size_;
     int                             use_context_fmha_;
     size_t                          tensor_para_size_;

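With this patch, cache_max_entry_count is read as a fraction of total device memory and cache_block_seq_len fixes how many tokens one KV block holds, so the number of blocks follows from the per-block byte size. A sketch of the arithmetic, mirroring the SequenceManager and BlockManager formulas above (helper names and example figures are illustrative):

    #include <cstddef>

    // Bytes per KV block: [2 (K and V), layer_num, kv_head_num, block_seq_len,
    // head_dim] elements at elem_bits bits each.
    inline size_t kv_block_bytes(
        size_t layer_num, size_t kv_head_num, size_t block_seq_len, size_t head_dim, size_t elem_bits)
    {
        return 2UL * layer_num * kv_head_num * block_seq_len * head_dim * elem_bits / 8;
    }

    // Number of blocks that fit into `ratio` of total device memory.
    inline size_t kv_block_budget(size_t total_device_bytes, double ratio, size_t block_bytes)
    {
        return static_cast<size_t>(total_device_bytes * ratio) / block_bytes;
    }

    // Example with 7B-like numbers and an fp16 KV cache:
    //   kv_block_bytes(32, 32, 128, 128, 16) == 64 MiB per block, so on an
    //   80 GiB device with ratio 0.95 roughly 1216 blocks of 128 tokens each
    //   are reserved.
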
From b49e84ebdbf12dbf4b3f268211170b85eafd0b3c Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Fri, 13 Oct 2023 08:25:48 +0000
Subject: [PATCH 17/56] debug

---
 src/turbomind/models/llama/LlamaBatch.cc | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 64e6557862..cea564441c 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -138,12 +138,10 @@ auto LlamaBatch<T>::ProcessStopRequests(const Requests& requests) -> std::vector
         }
         // clear output buffers (prevent leaking conversations) if request is successful
         if (ec == 0) {
-
             if (rank_ == 0) {
                 std::unique_lock lock{output_mutex_};
                 output_cv_.wait(lock, [&] { return output_reqs_.empty(); });
             }
-
             auto& output_ids      = r->outputs[rank_].at("output_ids");
             auto& sequence_length = r->outputs[rank_].at("sequence_length");
             Clear(output_ids.getPtr<int>(), output_ids.shape.at(2));
@@ -1294,23 +1292,27 @@ void LlamaBatch<T>::OutputThreadEntry()
             std::unique_lock lock(output_mutex_);
             output_cv_.wait(lock, [&] { return !output_reqs_.empty() || output_stop_token_; });
 
+            // stop requested
+            if (output_stop_token_) {
+                TM_LOG_INFO("[OutputThreadEntry] stop requested.");
+                break;
+            }
+
             if (rank_ == 0 && model_->ffi_lock_) {
+                TM_LOG_INFO("acquire GIL");
                 model_->ffi_lock_(1);
+                TM_LOG_INFO("acquire GIL success");
             }
             // invoke stream cbs
             for (const auto& r : output_reqs_) {
                 r->stream_cb(&r->outputs[rank_].get());
             }
             if (rank_ == 0 && model_->ffi_lock_) {
+                TM_LOG_INFO("release GIL");
                 model_->ffi_lock_(0);
+                TM_LOG_INFO("release GIL success");
             }
             output_reqs_.clear();
-
-            // stop requested
-            if (output_stop_token_) {
-                TM_LOG_INFO("[OutputThreadEntry] stop requested.");
-                break;
-            }
         }
         FT_CHECK(output_reqs_.empty());
         // notify infer thread 0

From b4e8bf115036bff64d3405fc834e62f8b854641b Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Mon, 16 Oct 2023 04:35:59 +0000
Subject: [PATCH 18/56] optimize output cost

---
 src/turbomind/models/llama/LlamaBatch.cc      | 123 +++++++++++++-----
 src/turbomind/models/llama/LlamaBatch.h       |   8 ++
 .../llama/LlamaContextAttentionLayer.cc       |   2 +-
 .../models/llama/LlamaContextDecoder.cc       |   6 +-
 src/turbomind/models/llama/SequenceManager.cc |  10 +-
 src/turbomind/models/llama/llama_kernels.cu   |  49 ++++++-
 src/turbomind/models/llama/llama_kernels.h    |  10 ++
 7 files changed, 168 insertions(+), 40 deletions(-)

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index cea564441c..33d2ce998e 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -7,6 +7,7 @@
 #include "src/turbomind/models/llama/LlamaV2.h"
 #include "src/turbomind/models/llama/Request.h"
 #include "src/turbomind/models/llama/SequenceManager.h"
+#include "src/turbomind/models/llama/llama_kernels.h"
 #include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/Tensor.h"
 #include "src/turbomind/utils/cuda_utils.h"
@@ -471,6 +472,10 @@ void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len)
     finished_buf_  = (bool*)allocator_->reMalloc(finished_buf_, sizeof(bool) * batchxbeam, false);
     seq_limit_len_ = (uint32_t*)allocator_->reMalloc(seq_limit_len_, sizeof(uint32_t) * batch_size, false);
 
+    request_output_ids_ptrs_ = (int**)allocator_->reMalloc(request_output_ids_ptrs_, sizeof(int*) * batch_size, true);
+    request_output_ids_lens_ = (int*)allocator_->reMalloc(request_output_ids_lens_, sizeof(int) * batch_size, true);
+    request_seqlen_ptrs_     = (int**)allocator_->reMalloc(request_seqlen_ptrs_, sizeof(int*) * batch_size, true);
+
     is_allocate_buffer_ = true;
 }
 
@@ -530,6 +535,13 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
 
         h_seq_limit_len_ =
             (uint32_t*)allocator_->reMalloc(h_seq_limit_len_, sizeof(uint32_t) * max_batch_size, false, true);
+
+        h_request_output_ids_ptrs_ =
+            (int**)allocator_->reMalloc(h_request_output_ids_ptrs_, sizeof(int*) * max_batch_size, true, true);
+        h_request_output_ids_lens_ =
+            (int*)allocator_->reMalloc(h_request_output_ids_lens_, sizeof(int) * max_batch_size, true, true);
+        h_request_seqlen_ptrs_ =
+            (int**)allocator_->reMalloc(h_request_seqlen_ptrs_, sizeof(int*) * max_batch_size, true, true);
     }
 
     is_allocate_persistant_buffer_ = true;
@@ -578,6 +590,10 @@ void LlamaBatch<T>::FreeBuffer()
         allocator_->free((void**)&finished_buf_);
         allocator_->free((void**)&seq_limit_len_);
 
+        allocator_->free((void**)&request_output_ids_ptrs_);
+        allocator_->free((void**)&request_output_ids_lens_);
+        allocator_->free((void**)&request_seqlen_ptrs_);
+
         is_allocate_buffer_ = false;
     }
 
@@ -595,6 +611,11 @@ void LlamaBatch<T>::FreeBuffer()
         allocator_->free((void**)&h_input_ids_buf_, true);
         allocator_->free((void**)&h_input_length_buf_, true);
         allocator_->free((void**)&h_seq_limit_len_, true);
+
+        allocator_->free((void**)&h_request_output_ids_ptrs_, true);
+        allocator_->free((void**)&h_request_output_ids_lens_, true);
+        allocator_->free((void**)&h_request_seqlen_ptrs_, true);
+
         is_allocate_persistant_buffer_ = false;
     }
 }
@@ -729,6 +750,23 @@ auto LlamaBatch<T>::InitializeGeneration() -> GenerationState
     Copy(h_seq_limit_len_, batch_size, seq_limit_len_);
     Copy(state_->h_finished, batch_size, finished_buf_);
 
+    for (int i = 0; i < batch_size; ++i) {
+        Tensor& output_ids         = state_->requests[i]->outputs[rank_].at("output_ids");
+        int*    req_output_ids_ptr = output_ids.getPtr<int>();
+        int*    req_seqlen_ptr     = state_->requests[i]->outputs[rank_].getPtr<int>("sequence_length");
+
+        h_request_output_ids_ptrs_[i] = req_output_ids_ptr;
+        h_request_output_ids_lens_[i] = output_ids.shape.at(2);
+        h_request_seqlen_ptrs_[i]     = req_seqlen_ptr;
+
+        FT_CHECK(h_request_output_ids_ptrs_[i]);
+        FT_CHECK(h_request_output_ids_lens_[i]);
+        FT_CHECK(h_request_seqlen_ptrs_[i]);
+    }
+    Copy(h_request_output_ids_ptrs_, batch_size, request_output_ids_ptrs_);
+    Copy(h_request_output_ids_lens_, batch_size, request_output_ids_lens_);
+    Copy(h_request_seqlen_ptrs_, batch_size, request_seqlen_ptrs_);
+
     // ! range of step_ [1, 2 * session_len]
     // consider a sequence with context_len == session_len and another sequence with context_len == 1 and
     // request_output_len == session_len - 1 => step_ will loop in [session_len, 2 * session_len)
@@ -849,7 +887,7 @@ void LlamaBatch<T>::ContextDecode()
     for (int i = 0; i < batch_size; ++i) {
         if (state_->is_swap_in[i]) {
             const auto& seq = *state_->sequences[i];
-            dbg(state_->h_context_length[i], seq.cache_len);
+            dbg(std::tuple(i, state_->h_context_length[i], seq.cache_len));
             if (const int missing = state_->h_context_length[i] - seq.cache_len; missing > 1) {
                 base = base < 0 ? i : base;
                 dbg(seq.tokens, seq.cache_len);
@@ -918,7 +956,8 @@ void LlamaBatch<T>::ContextDecode()
         int              max_input_len{};
         auto             input_ids = context_decoder_ids_buf_;
         for (int i = first; i < last; ++i) {
-            input_ids        = Copy(input_ids_buf_ + i * session_len_, h_input_length_buf_[i], input_ids);
+            input_ids = Copy(input_ids_buf_ + i * session_len_, h_input_length_buf_[i], input_ids);
+            dbg(i, h_input_length_buf_[i]);
             h_tmp_k_ptrs_[i] = k_ptr;
             h_tmp_v_ptrs_[i] = v_ptr;
             k_ptr += model_->local_kv_head_num_ * max_context_cnts[k] * model_->size_per_head_;
@@ -1037,15 +1076,17 @@ void LlamaBatch<T>::OutputContextLogits(T*                      context_decoder_
 template<typename T>
 auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
 {
+    NvtxScope scope("Finish");
     const int batch_size = state_->active_size;
 
     // secure info needed by `Initialize()`
     Copy(finished_buf_, batch_size, state_->h_finished);
     Copy(sequence_lengths_, batch_size, state_->h_context_length);
 
-    if (1) {
+    if constexpr (0) {
         std::unique_lock<std::mutex> lock;
         if (rank_ == 0) {
+            NvtxScope _("acquire_outputs");
             // wait for previous output operations
             lock = std::unique_lock{output_mutex_};
             output_cv_.wait(lock, [&] { return output_reqs_.empty(); });
@@ -1055,6 +1096,7 @@ auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
         check_cuda_error(cudaStreamSynchronize(stream_));
 
         if (rank_ == 0) {
+            NvtxScope _("signal_output_thread");
             // enqueue new output requests
             for (int i = 0; i < batch_size; ++i) {
                 FT_CHECK(state_->requests[i] != nullptr);
@@ -1073,17 +1115,20 @@ auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
         SetOutputTensors(g);
         check_cuda_error(cudaStreamSynchronize(stream_));
 
-        if (rank_ == 0 && model_->ffi_lock_) {
-            model_->ffi_lock_(1);
-        }
-        for (int i = 0; i < batch_size; ++i) {
-            FT_CHECK(state_->requests[i] != nullptr);
-            if (state_->requests[i]->stream_cb && rank_ == 0) {
-                state_->requests[i]->stream_cb(&state_->requests[i]->outputs[rank_].get());
+        {
+            NvtxScope _("output_cb");
+            if (rank_ == 0 && model_->ffi_lock_) {
+                model_->ffi_lock_(1);
+            }
+            for (int i = 0; i < batch_size; ++i) {
+                FT_CHECK(state_->requests[i] != nullptr);
+                if (state_->requests[i]->stream_cb && rank_ == 0) {
+                    state_->requests[i]->stream_cb(&state_->requests[i]->outputs[rank_].get());
+                }
+            }
+            if (rank_ == 0 && model_->ffi_lock_) {
+                model_->ffi_lock_(0);
             }
-        }
-        if (rank_ == 0 && model_->ffi_lock_) {
-            model_->ffi_lock_(0);
         }
     }
 
@@ -1096,10 +1141,13 @@ auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
     }
 
     std::vector<Signal> signals;
-    for (int i = 0; i < batch_size; ++i) {
-        if (state_->requests[i] && state_->h_finished[i]) {
-            CompleteRequest(i, false, false);
-            signals.push_back([r = std::move(state_->requests[i])] { r->signal.set_value(0); });
+    {
+        NvtxScope _("prepare_completion_signal");
+        for (int i = 0; i < batch_size; ++i) {
+            if (state_->requests[i] && state_->h_finished[i]) {
+                CompleteRequest(i, false, false);
+                signals.push_back([r = std::move(state_->requests[i])] { r->signal.set_value(0); });
+            }
         }
     }
     return signals;
@@ -1108,6 +1156,7 @@ auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
 template<typename T>
 void LlamaBatch<T>::SetOutputTensors(const GenerationState& g)
 {
+    NvtxScope scope("SetOutputTensors");
     // dbg(g.max_init_ctx_len);
     const auto batch_size = state_->active_size;
     // [s,b] -> [b,s] and skip padding in [context_len, max_context_len)
@@ -1121,17 +1170,30 @@ void LlamaBatch<T>::SetOutputTensors(const GenerationState& g)
                        stream_);
     sync_check_cuda_error();
 
-    /// TODO: fuse the loop into a single kernel
-    for (int i = 0; i < batch_size; ++i) {
-        if (state_->requests[i]) {
-            auto& output_ids      = state_->requests[i]->outputs[rank_].at("output_ids");
-            auto& sequence_length = state_->requests[i]->outputs[rank_].at("sequence_length");
-            Copy(state_->output_ids + i * session_len_, output_ids.shape.at(2), output_ids.getPtr<int>());
-            Copy(sequence_lengths_ + i, 1, sequence_length.getPtr<int>());
-            if (g.step > g.max_init_ctx_len) {  // +1 for newly generated token
-                invokePlusScalar(sequence_length.getPtr<int>(), 1, 1, stream_);
-            }
-        }
+    if constexpr (1) {
+        invokeUpdateOutput(request_output_ids_ptrs_,
+                           request_seqlen_ptrs_,
+                           state_->output_ids,
+                           sequence_lengths_,
+                           request_output_ids_lens_,
+                           session_len_,
+                           g.step > g.max_init_ctx_len,
+                           batch_size,
+                           stream_);
+        sync_check_cuda_error();
+    }
+    else {
+        // for (int i = 0; i < batch_size; ++i) {
+        //     if (state_->requests[i]) {
+        //         auto& output_ids      = state_->requests[i]->outputs[rank_].at("output_ids");
+        //         auto& sequence_length = state_->requests[i]->outputs[rank_].at("sequence_length");
+        //         Copy(state_->output_ids + i * session_len_, output_ids.shape.at(2), output_ids.getPtr<int>());
+        //         Copy(sequence_lengths_ + i, 1, sequence_length.getPtr<int>());
+        //         if (g.step > g.max_init_ctx_len) {  // +1 for newly generated token
+        //             invokePlusScalar(sequence_length.getPtr<int>(), 1, 1, stream_);
+        //         }
+        //     }
+        // }
     }
 }
 
@@ -1275,6 +1337,7 @@ void LlamaBatch<T>::BarrierSignalRequests(Barrier& barrier, const std::vector<Si
 template<typename T>
 void LlamaBatch<T>::Start()
 {
+    TM_LOG_ERROR("LlamaBatch<T>::Start()");
     int device_id = -1;
     check_cuda_error(cudaGetDevice(&device_id));
     internal_thread_ = std::thread(&LlamaBatch::InternalThreadEntry, this, device_id);
@@ -1291,11 +1354,11 @@ void LlamaBatch<T>::OutputThreadEntry()
             // wait for requests with stream cbs
             std::unique_lock lock(output_mutex_);
             output_cv_.wait(lock, [&] { return !output_reqs_.empty() || output_stop_token_; });
-
+            // NvtxScope _("output_callback");
             // stop requested
             if (output_stop_token_) {
                 TM_LOG_INFO("[OutputThreadEntry] stop requested.");
-                break;
+                return;
             }
 
             if (rank_ == 0 && model_->ffi_lock_) {
diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h
index 7cf408f4d8..f0d3b4d303 100644
--- a/src/turbomind/models/llama/LlamaBatch.h
+++ b/src/turbomind/models/llama/LlamaBatch.h
@@ -84,6 +84,7 @@ class LlamaBatch {
 
     ~LlamaBatch()
     {
+        TM_LOG_ERROR("~LlamaBatch()");
         model_->shared_state_->request_queue.close();
 
         internal_thread_.join();
@@ -183,6 +184,13 @@ class LlamaBatch {
     bool*     finished_buf_{};
     uint32_t* seq_limit_len_{};
 
+    int** request_output_ids_ptrs_{};
+    int*  request_output_ids_lens_{};
+    int** request_seqlen_ptrs_{};
+    int** h_request_output_ids_ptrs_{};
+    int*  h_request_output_ids_lens_{};
+    int** h_request_seqlen_ptrs_{};
+
     // pinned buffers
     int* h_input_ids_buf_{};
     int* h_input_length_buf_{};
diff --git a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
index 68e743ba8d..65b2be0aac 100644
--- a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
@@ -281,7 +281,7 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
                                   weights->past_kv_scale.data());
     }
 
-    Compare(qkv_buf_3_, num_token * hidden_units_, Concat("qkv_buf_3", layer_id), kCmpRead, stream_);
+    // Compare(qkv_buf_3_, num_token * hidden_units_, Concat("qkv_buf_3", layer_id), kCmpRead, stream_);
 
     // dbg(max_seq_len);
 
diff --git a/src/turbomind/models/llama/LlamaContextDecoder.cc b/src/turbomind/models/llama/LlamaContextDecoder.cc
index c7231c1aa8..20c4437f33 100644
--- a/src/turbomind/models/llama/LlamaContextDecoder.cc
+++ b/src/turbomind/models/llama/LlamaContextDecoder.cc
@@ -209,8 +209,8 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
 
     allocateBuffer(sess.batch_size, sess.token_num, sess.max_query_len, sess.max_key_len);
 
+    // dbg(padding_offset_);
     FT_CHECK(padding_offset_);
-    dbg(padding_offset_);
 
     size_t tmp_token_num{};
     invokeGetPaddingOffsetAndCuSeqLens(h_pinned_token_num_ptr_,
@@ -234,8 +234,8 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
                             stream_);
     sync_check_cuda_error();
 
-    Compare(
-        decoder_input_output, sess.token_num * hidden_units_, Concat("context_decoder_input", 0), kCmpRead, stream_);
+    // Compare(
+    //     decoder_input_output, sess.token_num * hidden_units_, Concat("context_decoder_input", 0), kCmpRead, stream_);
 
     /////////////////////////////////////////////
     /// RMSNorm
diff --git a/src/turbomind/models/llama/SequenceManager.cc b/src/turbomind/models/llama/SequenceManager.cc
index db99ea9d0e..8c9375f76c 100644
--- a/src/turbomind/models/llama/SequenceManager.cc
+++ b/src/turbomind/models/llama/SequenceManager.cc
@@ -150,7 +150,7 @@ class Simulator {
                        std::vector<int>&                   ref_count):
         seqs_(seqs), idxs_(idxs), ref_count_(ref_count)
     {
-        dbg(seqs.size());
+        // dbg(seqs.size());
         released_.resize(seqs.size());
         ptr_ = released_.size();
     }
@@ -336,9 +336,9 @@ auto SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
             if (sequences[idxs[v]]->status == Sequence::kCached) {
                 continue;
             }
-            dbg(v, idxs[v]);
+            // dbg(v, idxs[v]);
             int preempt = trans.Preempt(v, idxs[v]);
-            dbg(preempt);
+            // dbg(preempt);
             // Commit only when preemption actually free enough blocks for the sequence to run
             if (block_count <= preempt) {
                 // preempted blocks are in cached state
@@ -348,7 +348,7 @@ auto SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
             }
         }
 
-        dbg(block_count, trans);
+        // dbg(block_count, trans);
 
         if (block_count == 0) {
             trans.Commit();
@@ -368,7 +368,7 @@ auto SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
         }
     }
 
-    dbg(schedule);
+    // dbg(schedule);
 
     ////////////////////////////////////////////////////////////////////////////////
     /// Schedule is ready, time to execute it. (locked -> cached -> free -> locked)
diff --git a/src/turbomind/models/llama/llama_kernels.cu b/src/turbomind/models/llama/llama_kernels.cu
index 1f334b40b3..9abcfc9d0d 100644
--- a/src/turbomind/models/llama/llama_kernels.cu
+++ b/src/turbomind/models/llama/llama_kernels.cu
@@ -553,12 +553,59 @@ void invokeGatherOutput(int*         output_ids,
                         int          batch_size,
                         cudaStream_t stream)
 {
-    int block_size = 512;
+    int block_size = 128;
     int grid_size  = batch_size;
     gatherOutput<<<grid_size, block_size, 0, stream>>>(
         output_ids, ids, context_length, max_context_len, max_gen_step, max_output_len, batch_size);
 }
 
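+// Copies each sequence's tokens from the contiguous internal buffer `output_ids`
+// ([batch, max_session_len]) into the per-request output tensor and writes back the
+// sequence length. One thread block handles one sequence.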
+__global__ void updateOutput(int**      request_output_ids_ptrs,
+                             int**      request_seqlen_ptrs,
+                             const int* output_ids,
+                             const int* sequence_lengths,
+                             const int* request_output_ids_lens,
+                             int        max_session_len,
+                             bool       token_generated)
+{
+    const int batch_id = blockIdx.x;
+
+    auto request_output_ids = request_output_ids_ptrs[batch_id];
+    auto request_seqlen     = request_seqlen_ptrs[batch_id];
+
+    output_ids += max_session_len * batch_id;
+
+    const int seqlen     = sequence_lengths[batch_id] + (int)token_generated;
+    const int output_len = min(seqlen, request_output_ids_lens[batch_id]);
+
+    for (int i = threadIdx.x; i < output_len; i += blockDim.x) {
+        request_output_ids[i] = output_ids[i];
+    }
+
+    *request_seqlen = seqlen;
+}
+
+void invokeUpdateOutput(int**        request_output_ids_ptrs,
+                        int**        request_seqlen_ptrs,
+                        const int*   output_ids,
+                        const int*   sequence_lengths,
+                        const int*   request_output_ids_lens,
+                        int          max_session_len,
+                        bool         token_generated,
+                        int          batch_size,
+                        cudaStream_t stream)
+{
+    constexpr int block_size = 128;
+    const int     grid_size  = batch_size;
+
+    updateOutput<<<grid_size, block_size, 0, stream>>>(request_output_ids_ptrs,
+                                                       request_seqlen_ptrs,
+                                                       output_ids,
+                                                       sequence_lengths,
+                                                       request_output_ids_lens,
+                                                       max_session_len,
+                                                       token_generated);
+}
+
 #define VERSION_SWITCH(VERSION, CONST_NAME, ...)                                                                       \
     [&] {                                                                                                              \
         if (VERSION == 2) {                                                                                            \
diff --git a/src/turbomind/models/llama/llama_kernels.h b/src/turbomind/models/llama/llama_kernels.h
index 7ce432d204..e629263a59 100644
--- a/src/turbomind/models/llama/llama_kernels.h
+++ b/src/turbomind/models/llama/llama_kernels.h
@@ -77,6 +77,16 @@ void invokeGatherOutput(int*         output_ids,
                         int          batch_size,
                         cudaStream_t stream);
 
+void invokeUpdateOutput(int**        request_output_ids_ptrs,
+                        int**        request_seqlen_ptrs,
+                        const int*   output_ids,
+                        const int*   sequence_lengths,
+                        const int*   request_output_ids_lens,
+                        int          max_session_len,
+                        bool         token_generated,
+                        int          batch_size,
+                        cudaStream_t stream);
+
 void invokeMyCopyInt(int* dst, const int* src, size_t count, cudaStream_t st);
 
 template<typename T>

From bdf0b41ad78d6d6056a1ff3d3b4fff902cb36afb Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Thu, 19 Oct 2023 09:59:53 +0000
Subject: [PATCH 19/56] split-k decoding
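
Split the key/value sequence across CTAs along gridDim.z: each split walks a
contiguous range of kSliceLen-sized slices and writes partial results
(partial_O, partial_M, partial_L) to a float workspace; a second kernel pass
merges them per (batch, head) with the usual log-sum-exp rescaling.

For reference, a minimal host-side sketch of that merge step (illustrative
only, not part of the patch; the function name is made up):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Merge per-split softmax statistics (M_k, L_k) and partial outputs O_k
    // into the final attention output for one (batch, head); the same math
    // the reduce kernel performs with warp shuffles.
    std::vector<float> merge_splits(const std::vector<float>&              M,
                                    const std::vector<float>&              L,
                                    const std::vector<std::vector<float>>& O)
    {
        float global_M = -INFINITY;
        for (float m : M) {
            global_M = std::max(global_M, m);
        }

        float global_L = 0.f;
        for (size_t k = 0; k < M.size(); ++k) {
            global_L += L[k] * std::exp(M[k] - global_M);
        }

        std::vector<float> out(O[0].size(), 0.f);
        for (size_t k = 0; k < M.size(); ++k) {
            const float scale = std::exp(M[k] - global_M) / (global_L + 1e-8f);
            for (size_t d = 0; d < out.size(); ++d) {
                out[d] += scale * O[k][d];
            }
        }
        return out;
    }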

---
 .../decoder_multihead_attention.cu            |  63 ++--
 .../decoder_multihead_attention_params.h      |   5 +
 .../decoder_multihead_attention_template.h    | 349 ++++++++++++++++--
 .../test_decoder_multihead_attention.cu       |  61 +--
 src/turbomind/models/llama/LlamaBatch.cc      |  16 +-
 src/turbomind/models/llama/LlamaBatch.h       |   2 +
 src/turbomind/models/llama/LlamaDecoder.cc    |   3 -
 src/turbomind/models/llama/LlamaDecoder.h     |   2 -
 .../llama/LlamaDecoderSelfAttentionLayer.cc   |  28 +-
 .../llama/LlamaDecoderSelfAttentionLayer.h    |   3 +
 src/turbomind/models/llama/LlamaV2.cc         |   6 +-
 src/turbomind/models/llama/LlamaV2.h          |   2 +
 12 files changed, 442 insertions(+), 98 deletions(-)

diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
index cd5b5a908f..abf24ea8f7 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
@@ -5,8 +5,10 @@
 
 namespace turbomind {
 
+namespace {
+
 template<typename MHAType>
-bool Dump()
+bool Print(size_t dynamic_smem_size)
 {
     using MapKv = typename MHAType::MapKv;
 
@@ -20,34 +22,50 @@ bool Dump()
     std::cout << "      iter: (" << MapKv::kIterC << ", " << MapKv::kIterS << ")\n";
     std::cout << " footprint: (" << MapKv::kFootprintC << ", " << MapKv::kFootprintS << ")\n";
     std::cout << "     delta: (" << MapKv::kDeltaC << ", " << MapKv::kDeltaS << ")\n";
+    std::cout << "dynamic smem size: " << dynamic_smem_size << "\n";
 
     return true;
 }
 
+}  // namespace
+
 template<typename T, typename Tkv, int HeadDim, int HeadPerCta>
 void invokeDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& params)
 {
     // cpasync_2048_32x6 ~ 64k smem
-    using MHAType = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 2048, 6>;
+    // using MHAType = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 2048, 6>;
+
+    using MHAType = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 1024, 5, true>;
 
     // ld_kv16_2048_32x3 ~ 34k smem
     // using MHAType = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 2048, 3>;
 
     // ld_kv8_2048_64x3 ~ 34k smem
-    // using MHAType = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 64, HeadDim, 2048, 3>;
+    // using MHAType = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 64, HeadDim, 2048, 3>;
+
+    static const size_t kDynSmemSize = MHAType::GetDynamicSmemSize();
+
+    [[maybe_unused]] static const bool _ = Print<MHAType>(kDynSmemSize);
 
-    [[maybe_unused]] static const bool init = Dump<MHAType>();
+    const int slice_count = (params.max_seq_len + MHAType::kSliceLen - 1) / MHAType::kSliceLen;
+    const int max_split_k = std::min(params.max_split_k, std::max(1, slice_count));
 
     dim3 block(MHAType::kWarpCount * WARP_SIZE);
-    dim3 grid(params.num_heads / HeadPerCta, params.batch_size);
+    dim3 grid(params.num_heads / HeadPerCta, params.batch_size, max_split_k);
 
-    static const size_t kDynSmemSize = MHAType::GetDynamicSmemSize();
-    // std::cout << "dynamic shared memory size: " << kDynamicSmemSize << "\n";
+    // if (params.layer_offset == 0) {
+    //     std::cout << "max_split_k' = " << max_split_k << "\n";
+    // }
 
     cudaFuncSetAttribute(
         decoder_multihead_attention<MHAType>, cudaFuncAttributeMaxDynamicSharedMemorySize, kDynSmemSize);
 
     decoder_multihead_attention<MHAType><<<grid, block, kDynSmemSize, params.stream>>>(params);
+
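+    // when the sequence is split (gridDim.z > 1), each split writes only partial
+    // (O, M, L); the second pass below merges them per (head, batch)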
+    if (max_split_k > 1) {
+        dim3 grid(params.num_heads, params.batch_size);
+        decoder_multihead_attention_reduce<MHAType><<<grid, block, 0, params.stream>>>(params);
+    }
 }
 
 template<typename T>
@@ -58,25 +76,22 @@ void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>&
     FT_CHECK(params.size_per_head == HeadDim);
 
     if constexpr (std::is_same_v<T, half>) {
-
-        //     int group_size = params.num_heads / params.num_kv_heads;
-
-        //     if (group_size % 8 == 0) {
-        //         invokeDecoderMultiheadAttention<T, HeadDim, 8>(params);
-        //     }
-        //     else if (group_size % 4 == 0) {
-        //         invokeDecoderMultiheadAttention<T, HeadDim, 4>(params);
-        //     }
-        //     else if (group_size % 2 == 0) {
-        //         invokeDecoderMultiheadAttention<T, HeadDim, 2>(params);
-        //     }
-        //     else {
-        //         invokeDecoderMultiheadAttention<T, HeadDim, 1>(params);
-        //     }
-        // }
-        // else {
         if (params.quant_policy & QuantPolicy::kCacheKVInt8) {
             invokeDecoderMultiheadAttention<T, int8_t, HeadDim, 1>(params);
+            return;
+        }
+
+        int group_size = params.num_heads / params.num_kv_heads;
+
+        if (0) {}
+        // else if (group_size % 8 == 0) {
+        //     invokeDecoderMultiheadAttention<T, T, HeadDim, 8>(params);
+        // }
+        else if (group_size % 4 == 0) {
+            invokeDecoderMultiheadAttention<T, T, HeadDim, 4>(params);
+        }
+        else if (group_size % 2 == 0) {
+            invokeDecoderMultiheadAttention<T, T, HeadDim, 2>(params);
         }
         else {
             invokeDecoderMultiheadAttention<T, T, HeadDim, 1>(params);
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
index de84526ec0..1e8c556c41 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
@@ -56,6 +56,11 @@ struct DecoderMultiHeadAttentionParams {
     int   quant_policy;
     float kv_quant_params[4];
 
+    int    max_split_k;
+    float* partial_O;
+    float* partial_M;
+    float* partial_L;
+
     cudaStream_t stream;
 };
 
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
index a569e95144..e129f260c6 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
@@ -4,6 +4,8 @@
 #include "iterator.h"
 #include "src/turbomind/kernels/gemm_s_f16/common.h"
 #include "thread_map.h"
+#include <climits>
+#include <cmath>
 #include <cstdint>
 #include <cuda_pipeline_primitives.h>
 #include <type_traits>
@@ -14,21 +16,23 @@ namespace turbomind {
 
 template<typename T,
          typename Tkv,
-         int HeadPerCta,
-         int MaxHeadDim,
-         int KeyPerIter,
-         int HeadDim,
-         int SliceLen,
-         int Stages>
+         int  HeadPerCta,
+         int  MaxHeadDim,
+         int  KeyPerIter,
+         int  HeadDim,
+         int  SliceLen,
+         int  Stages,
+         bool SplitK>
 struct DecoderMultiHeadAttentionKernel {
     using ParamType = DecoderMultiHeadAttentionParams<T>;
 
-    static constexpr int kWarpCount  = 4;
-    static constexpr int kHeadPerCta = HeadPerCta;
-    static constexpr int kMaxHeadDim = MaxHeadDim;
-    static constexpr int kKeyPerIter = KeyPerIter;
-    static constexpr int kHeadDim    = HeadDim;
-    static constexpr int kStages     = Stages;
+    static constexpr int  kWarpCount  = 4;
+    static constexpr int  kHeadPerCta = HeadPerCta;
+    static constexpr int  kMaxHeadDim = MaxHeadDim;
+    static constexpr int  kKeyPerIter = KeyPerIter;
+    static constexpr int  kHeadDim    = HeadDim;
+    static constexpr int  kStages     = Stages;
+    static constexpr bool kSplitK     = SplitK;
 
     static constexpr int kSliceLen     = SliceLen;
     static constexpr int kIterPerSlice = kSliceLen / kKeyPerIter;
@@ -47,7 +51,8 @@ struct DecoderMultiHeadAttentionKernel {
     static constexpr size_t GetDynamicSmemSize()
     {
         size_t smem_kv_cache = IterKv::kSmemByteSize;
-        size_t smem_kv_align = 128;
+        // size_t smem_kv_align = 128;
+        size_t smem_kv_align = 0;
         size_t smem_qk       = sizeof(float) * kHeadPerCta * kSliceLen;
         size_t smem_pr       = sizeof(float) * kHeadPerCta * kSliceLen;
         return smem_kv_align + smem_kv_cache + std::max(smem_qk, smem_pr);
@@ -78,6 +83,9 @@ struct DecoderMultiHeadAttentionKernel {
     int  kv_head_idx_;
     bool is_gqa_leader_;
 
+    int step_begin_;
+    int step_end_;
+
     int timestep_;
     Tkv* __restrict__ k_cache_;  // [S, D]
     Tkv* __restrict__ v_cache_;  // [S, D]
@@ -137,6 +145,18 @@ struct DecoderMultiHeadAttentionKernel {
 
         timestep_ = params_.per_sample_length[batch_idx_];
 
+        if (kSplitK && params.max_split_k > 1) {
+            const int slice_count     = (timestep_ + kSliceLen - 1) / kSliceLen;
+            const int slice_per_split = (slice_count + params_.max_split_k - 1) / params_.max_split_k;
+
+            step_begin_ = slice_per_split * get_split_k_idx() * kSliceLen;
+            step_end_   = min(timestep_, step_begin_ + slice_per_split * kSliceLen);
+        }
+        else {
+            step_begin_ = 0;
+            step_end_   = timestep_;
+        }
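+        // e.g. (illustrative numbers): timestep_ = 3000, kSliceLen = 1024, max_split_k = 4
+        //   -> slice_count = 3, slice_per_split = 1; splits 0..2 cover [0, 1024),
+        //      [1024, 2048), [2048, 3000) and any extra split exits early in Run()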
+
         if constexpr (kUseBlockIter) {
             k_cache_ptrs_ = params_.k_cache_block_ptrs + params_.cu_block_cnts[batch_idx_];
             v_cache_ptrs_ = params_.v_cache_block_ptrs + params_.cu_block_cnts[batch_idx_];
@@ -156,7 +176,8 @@ struct DecoderMultiHeadAttentionKernel {
         static_assert(kMaxHeadDim % WARP_SIZE == 0);
         static constexpr int kVecQSize = kMaxHeadDim / WARP_SIZE;
 
-        using VecQ = Array<T, kVecQSize>;
+        using VecQ      = Array<T, kVecQSize>;
+        using VecQFloat = Array<float, kVecQSize>;
 
         using MapQ = ThreadMapQ<kMaxHeadDim, kHeadPerCta, kVecQSize, kWarpCount>;
 
@@ -220,13 +241,29 @@ struct DecoderMultiHeadAttentionKernel {
         }
         rotary_emb.apply(frag_K);
 
+        if (kSplitK && step_begin_) {  // Split idx > 0
+            PRAGMA_UNROLL
+            for (int s = 0; s < kQHeadPerThread; ++s) {
+                int qi = offset.y + s;
+                if (lane_id_ == 0) {
+                    smem_M_[qi] = -std::numeric_limits<float>::infinity();
+                    smem_L_[qi] = 0.f;
+                }
+                Store(&smem_Q_[qi * kMaxHeadDim + offset.x], frag_Q[s]);
+                Store(&smem_O_[qi * kMaxHeadDim + offset.x], VecQFloat{});
+            }
+            return;
+        }
+
+        ////////////////////////////////////////////////////////
+        // Split 0 computes last step and stores to k/v cache
+
         PRAGMA_UNROLL
         for (int s = 0; s < kQHeadPerThread; ++s) {
             int         qi = offset.y + s;
             QkAccumType qk = qk_dot<QkAccumType, QkComputeType, WARP_SIZE>(frag_Q[s], frag_K);
             if (lane_id_ == 0) {
                 qk *= params_.inv_sqrt_dh;
-                // printf("qk_last[%d]=%f\n", head_idx_, qk);
                 smem_M_[qi] = qk;
                 smem_L_[qi] = 1.f;
             }
@@ -272,8 +309,6 @@ struct DecoderMultiHeadAttentionKernel {
     __device__ void CpAsyncWait()
     {
         __pipeline_wait_prior(kStages - 2);
-        // __syncwarp();
-        // __syncthreads();
     }
 
     __device__ void CpAsyncCommit()
@@ -676,9 +711,9 @@ struct DecoderMultiHeadAttentionKernel {
         State state;
 
         PRAGMA_NO_UNROLL
-        for (int step = 0; step < timestep_; step += kSliceLen) {
-            int iter_length = min(timestep_ - step, kSliceLen);
-            ComputeSlice(frag_Q, state, offset, step, iter_length);
+        for (int step = step_begin_; step < step_end_; step += kSliceLen) {
+            int iter_count = min(step_end_ - step, kSliceLen);
+            ComputeSlice(frag_Q, state, offset, step, iter_count);
         }
     }
 
@@ -691,6 +726,11 @@ struct DecoderMultiHeadAttentionKernel {
             __syncthreads();
         }
 
+        // early exit if this split is out of bounds
+        if (kSplitK && step_begin_ >= step_end_) {
+            return;
+        }
+
         // early exit if finished flag is set
         if (params_.finished[batch_idx_]) {
             return;
@@ -714,7 +754,6 @@ struct DecoderMultiHeadAttentionKernel {
     {
         static constexpr int kVecQSize = kMaxHeadDim / WARP_SIZE;
 
-        using VecQ      = Array<T, kVecQSize>;
         using VecQFloat = Array<float, kVecQSize>;
 
         using MapQ = ThreadMapQ<kMaxHeadDim, kHeadPerCta, kVecQSize, kWarpCount>;
@@ -723,24 +762,259 @@ struct DecoderMultiHeadAttentionKernel {
 
         int2 offset = MapQ::get_offset(warp_id_, lane_id_);
 
-        bool is_valid = offset.x < kMaxHeadDim && offset.y < kHeadPerCta;
-        if (!is_valid) {
+        if (offset.x >= kMaxHeadDim || offset.y >= kHeadPerCta) {
+            return;
+        }
+
+        using namespace ops;
+
+        if (!kSplitK || (step_begin_ == 0 && step_end_ == timestep_)) {  // non-split-k
+            PRAGMA_UNROLL
+            for (int s = 0; s < kQkvHeadPerThread; ++s) {
+                const int di = offset.x;
+                const int qi = offset.y + s;
+
+                const float     scale  = __fdividef(1.f, smem_L_[qi] + 1e-8f);
+                const VecQFloat frag_O = (VecQFloat&)smem_O_[qi * kMaxHeadDim + di] * scale;
+
+                Store(&params_.out[batch_idx_ * params_.num_heads * kHeadDim + (head_idx_ + qi) * kHeadDim + di],
+                      cast<T>(frag_O));
+            }
+        }
+        else {
+            PRAGMA_UNROLL
+            for (int s = 0; s < kQkvHeadPerThread; ++s) {  // split-k
+                const int di = offset.x;
+                const int qi = offset.y + s;
+
+                const VecQFloat frag_O = (VecQFloat&)smem_O_[qi * kMaxHeadDim + di];
+
+                // [B, H, k, D]
+                const int index = batch_idx_ * params_.num_heads * params_.max_split_k
+                                  + (head_idx_ + qi) * params_.max_split_k + get_split_k_idx();
+                Store(&params_.partial_O[index * kHeadDim + di], cast<float>(frag_O));
+
+                if (di == 0) {
+                    params_.partial_M[index] = smem_M_[qi];
+                    params_.partial_L[index] = smem_L_[qi];
+                }
+            }
+        }
+    }
+
+    static __device__ void Reduce(const ParamType& params)
+    {
+        static constexpr int kVecQSize = kMaxHeadDim / WARP_SIZE;
+
+        // using VecQ      = Array<T, kVecQSize>;
+        using VecQFloat = Array<float, kVecQSize>;
+        using MapQ      = ThreadMapQ<kMaxHeadDim, kHeadPerCta, kVecQSize, kWarpCount>;
+
+        static constexpr int kQkvHeadPerThread = MapQ::kIterS;
+
+        const int batch_idx = get_batch_idx();
+        const int head_idx  = get_head_idx() * kHeadPerCta;
+        const int warp_id_  = threadIdx.x / WARP_SIZE;
+        const int lane_id_  = threadIdx.x % WARP_SIZE;
+        const int split_k   = (params.per_sample_length[batch_idx] + kSliceLen - 1) / kSliceLen;
+
+        int2 offset = MapQ::get_offset(warp_id_, lane_id_);
+
+        if (offset.x >= kMaxHeadDim || offset.y >= kHeadPerCta) {
             return;
         }
 
+        auto get_index = [&](int k, int qi) {
+            return batch_idx * params.num_heads * split_k + (head_idx + qi) * split_k + k;
+        };
+
+        Array<float, kQkvHeadPerThread> global_M;
+        PRAGMA_UNROLL
+        for (int s = 0; s < kQkvHeadPerThread; ++s) {
+            const int qi = offset.y + s;
+            global_M[s]  = params.partial_M[get_index(0, qi)];
+        }
+        for (int k = 1; k < split_k; ++k) {
+            PRAGMA_UNROLL
+            for (int s = 0; s < kQkvHeadPerThread; ++s) {
+                const int qi = offset.y + s;
+                global_M[s]  = max(global_M[s], params.partial_M[get_index(k, qi)]);
+            }
+        }
+
+        Array<float, kQkvHeadPerThread> global_L{};
+        for (int k = 0; k < split_k; ++k) {
+            PRAGMA_UNROLL
+            for (int s = 0; s < kQkvHeadPerThread; ++s) {
+                const int qi = offset.y + s;
+                global_L[s] +=
+                    params.partial_L[get_index(k, qi)] * expf(params.partial_M[get_index(k, qi)] - global_M[s]);
+            }
+        }
+
+        VecQFloat frag_O[kQkvHeadPerThread]{};
+        for (int k = 0; k < split_k; ++k) {
+            PRAGMA_UNROLL
+            for (int s = 0; s < kQkvHeadPerThread; ++s) {
+                const int di = offset.x;
+                const int qi = offset.y + s;
+
+                float scale = expf(params.partial_M[get_index(k, qi)] - global_M[s]) / (global_L[s] + 1e-8f);
+
+                VecQFloat partial_O;
+                Ldg(partial_O, &params.partial_O[get_index(k, qi) * kHeadDim + di]);
+
+                using namespace ops;
+                frag_O[s] = frag_O[s] + partial_O * scale;
+            }
+        }
+
         PRAGMA_UNROLL
         for (int s = 0; s < kQkvHeadPerThread; ++s) {
-            int   di    = offset.x;
-            int   qi    = offset.y + s;
-            float scale = __fdividef(1.f, smem_L_[qi] + 1e-6f);
-            // float scale = 1.f;
+            const int di = offset.x;
+            const int qi = offset.y + s;
+            Store(&params.out[batch_idx * params.num_heads * kHeadDim + (head_idx + qi) * kHeadDim + di],
+                  cast<T>(frag_O[s]));
+        }
+    }
+
+    static __device__ void Reduce2(const ParamType& params)
+    {
+        const int batch_idx = get_batch_idx();
+        const int head_idx  = get_head_idx();
+        const int lane_id   = threadIdx.x % WARP_SIZE;
+        const int warp_id   = threadIdx.x / WARP_SIZE;
+
+        const int timestep        = params.per_sample_length[batch_idx];
+        const int max_split_k     = params.max_split_k;
+        const int slice_count     = get_slice_count(timestep);
+        const int slice_per_split = (slice_count + max_split_k - 1) / max_split_k;
+        const int split_k         = (slice_count + slice_per_split - 1) / slice_per_split;
+
+        if (split_k == 1) {
+            return;
+        }
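+        // merge rule across splits (computed below with warp shuffles):
+        //   M = max_k M_k,  L = sum_k L_k * exp(M_k - M),  O = sum_k O_k * exp(M_k - M) / L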
+
+        // [B, H, k, D]
+        const int index = batch_idx * params.num_heads * max_split_k + head_idx * max_split_k + threadIdx.x;
+
+        __shared__ float smem_global_M;
+        __shared__ float smem_global_L;
+        __shared__ __align__(16) float smem_expdiff_M[WARP_SIZE];
+        __shared__ __align__(16) float smem_scale_O[WARP_SIZE];
+
+        {
+            float global_M = threadIdx.x < split_k ? params.partial_M[index] : -std::numeric_limits<float>::infinity();
+            PRAGMA_UNROLL
+            for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+                global_M = fmaxf(global_M, __shfl_xor_sync((uint32_t)-1, global_M, mask));
+            }
+
+            if (threadIdx.x == 0) {
+                smem_global_M = global_M;
+            }
+        }
+
+        __syncthreads();
+
+        {
+            float global_L = threadIdx.x < split_k ? params.partial_L[index] : 0.f;
+
+            if (threadIdx.x < split_k) {
+                auto expdiff_M = expf(params.partial_M[index] - smem_global_M);
+                global_L *= expdiff_M;
+                smem_expdiff_M[threadIdx.x] = expdiff_M;
+            }
+
+            PRAGMA_UNROLL
+            for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+                global_L += __shfl_xor_sync((uint32_t)-1, global_L, mask);
+            }
+
+            if (threadIdx.x == 0) {
+                smem_global_L = global_L;
+            }
+        }
+
+        __syncthreads();
+
+        if (threadIdx.x < split_k) {
+            smem_scale_O[threadIdx.x] = smem_expdiff_M[threadIdx.x] / (smem_global_L + 1e-8f);
+        }
+
+        __syncthreads();
+
+        if constexpr (1) {
+            int   idx = (batch_idx * params.num_heads * max_split_k + head_idx * max_split_k) * kHeadDim + threadIdx.x;
+            float accum_O{};
+            const bool is_valid = threadIdx.x < kHeadDim;
+            for (int k = 0; k < split_k; ++k) {
+                if (is_valid) {
+                    accum_O += smem_scale_O[k] * params.partial_O[idx];
+                }
+                idx += kHeadDim;
+            }
+            if (is_valid) {
+                params.out[batch_idx * params.num_heads * kHeadDim + head_idx * kHeadDim + threadIdx.x] = (T)accum_O;
+            }
+        }
+        else {  // vectorized version, not beneficial
             using namespace ops;
-            VecQFloat frag_O = (VecQFloat&)smem_O_[qi * kMaxHeadDim + di] * scale;
+            constexpr int kVecSize = 4;
+
+            using VecO = Array<float, kVecSize>;
+            VecO accum_O{};
+
+            int idx =
+                batch_idx * params.num_heads * max_split_k + head_idx * max_split_k + warp_id;  // offset by warp_id
+            idx = idx * kHeadDim + lane_id * kVecSize;                                          // offset by lane_id
+
+            for (int k = warp_id; k < split_k; k += kWarpCount) {
+                VecO frag_O;
+                Ldg(frag_O, &params.partial_O[idx]);
+                accum_O = accum_O + frag_O * smem_scale_O[k];
+                idx += kWarpCount * kHeadDim;
+            }
 
-            Store(&params_.out[batch_idx_ * params_.num_heads * kHeadDim + (head_idx_ + qi) * kHeadDim + di],
-                  cast<T>(frag_O));
+            __shared__ __align__(16) VecO reduce_O[kWarpCount][WARP_SIZE];
+
+            reduce_O[warp_id][lane_id] = accum_O;
+
+            PRAGMA_UNROLL
+            for (int mask = kWarpCount / 2; mask >= 1; mask /= 2) {
+                __syncthreads();
+                if (warp_id < mask) {
+                    reduce_O[warp_id][lane_id] = reduce_O[warp_id][lane_id] + reduce_O[warp_id + mask][lane_id];
+                }
+            }
+
+            // no need to sync, last loop ends with warp 0
+            if (warp_id == 0 && lane_id * kVecSize < kHeadDim) {
+                Store(&params.out[batch_idx * params.num_heads * kHeadDim + head_idx * kHeadDim + lane_id * kVecSize],
+                      cast<T>(reduce_O[warp_id][lane_id]));
+            }
         }
     }
+
+    static __device__ int get_slice_count(int timestep)
+    {
+        return (timestep + kSliceLen - 1) / kSliceLen;
+    }
+
+    static __device__ int get_head_idx()
+    {
+        return blockIdx.x;
+    }
+
+    static __device__ int get_batch_idx()
+    {
+        return blockIdx.y;
+    }
+
+    static __device__ int get_split_k_idx()
+    {
+        return blockIdx.z;
+    }
 };
 
 extern __shared__ uint8_t dynamic_smem[];
@@ -752,16 +1026,13 @@ __global__ void decoder_multihead_attention(ParamType params)
 
     uint8_t* smem_ptr = dynamic_smem;
 
-    // Align dynamic smem ptr to 128 byte boundary, this eliminates excessive wavefronts from smem to L1
-    // but it does not improve performance
-    if constexpr (0) {
-        int misalign = (uintptr_t)smem_ptr % 128;
-        if (misalign) {
-            smem_ptr += 128 - misalign;
-        }
-    }
-
     MHAType{params, shared_storage, smem_ptr}.Run();
 }
 
+template<typename MHAType, typename ParamType = typename MHAType::ParamType>
+__global__ void decoder_multihead_attention_reduce(ParamType params)
+{
+    MHAType::Reduce2(params);
+}
+
 }  // namespace turbomind
diff --git a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
index 744f2fd342..6d7a634df5 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
@@ -106,44 +106,50 @@ int main(int argc, char* argv[])
 
     DecoderMultiHeadAttentionParams<half> params{};
 
-    constexpr int kHeadNum = 108;
-    // constexpr int kHeadNum     = 32;
-    constexpr int kHeadDim     = 128;
-    constexpr int kBatchSize   = 64;
-    constexpr int kContextLen  = 2047;
+    constexpr int kHeadNum    = 32;
+    constexpr int kHeadDim    = 128;
+    constexpr int KvHeadNum   = 32;
+    constexpr int kBatchSize  = 1;
+    constexpr int kContextLen = 1024;
     constexpr int kSequenceLen = kContextLen + 1;
     constexpr int kBlockSz     = 128;
     constexpr int kTestIter    = 1;
+    constexpr int kMaxSplitK   = 4;
 
     RNG rng{};
 
     thrust::universal_vector<half>  output(kBatchSize * kHeadNum * kHeadDim);
-    thrust::universal_vector<half>  qkv(kBatchSize * kHeadNum * 3 * kHeadDim);
+    thrust::universal_vector<half>  qkv(kBatchSize * (kHeadNum + KvHeadNum * 2) * kHeadDim);
     thrust::universal_vector<bool>  finished(kBatchSize);
-    thrust::universal_vector<half>  k_cache(kBatchSize * (kContextLen + 1) * kHeadNum * kHeadDim);
-    thrust::universal_vector<half>  v_cache(kBatchSize * (kContextLen + 1) * kHeadNum * kHeadDim);
+    thrust::universal_vector<half>  k_cache(kBatchSize * kSequenceLen * KvHeadNum * kHeadDim);
+    thrust::universal_vector<half>  v_cache(kBatchSize * kSequenceLen * KvHeadNum * kHeadDim);
     thrust::universal_vector<int>   sequence_lengths(kBatchSize);
     thrust::universal_vector<void*> k_cache_ptrs(kBatchSize);
     thrust::universal_vector<void*> v_cache_ptrs(kBatchSize);
 
+    thrust::universal_vector<float> partial_M(kBatchSize * kHeadNum * kMaxSplitK);
+    thrust::universal_vector<float> partial_L(kBatchSize * kHeadNum * kMaxSplitK);
+    thrust::universal_vector<float> partial_O(kBatchSize * kHeadNum * kMaxSplitK * kHeadDim);
+
     rng.GenerateNormal(qkv.data().get(), qkv.size(), 1.f, 0.f);
 
     if (kContextLen) {
-        rng.GenerateNormal(k_cache.data().get(), kBatchSize * kHeadNum * kSequenceLen * kHeadDim);
-        rng.GenerateNormal(v_cache.data().get(), kBatchSize * kHeadNum * kSequenceLen * kHeadDim);
+        rng.GenerateNormal(k_cache.data().get(), kBatchSize * KvHeadNum * kSequenceLen * kHeadDim);
+        rng.GenerateNormal(v_cache.data().get(), kBatchSize * KvHeadNum * kSequenceLen * kHeadDim);
 
         cudaMemset2DAsync(k_cache.data().get() + kContextLen * kHeadDim,
                           sizeof(half) * kSequenceLen * kHeadDim,
                           0,
                           sizeof(half) * kHeadDim,
-                          kBatchSize * kHeadNum);
+                          kBatchSize * KvHeadNum);
         if constexpr (0) {
             for (int b = 0; b < kBatchSize; ++b) {
-                for (int h = 0; h < kHeadNum; ++h) {
+                for (int h = 0; h < KvHeadNum; ++h) {
                     for (int s = 0; s < kSequenceLen; ++s) {
                         for (int d = 0; d < kHeadDim; ++d) {
                             std::cout << std::setw(7) << std::setprecision(4) << std::fixed
-                                      << (float)k_cache[b * kHeadNum * kSequenceLen * kHeadDim
+                                      << (float)k_cache[b * KvHeadNum * kSequenceLen * kHeadDim
                                                         + h * kSequenceLen * kHeadDim + s * kHeadDim + d]
                                       << " ";
                         }
@@ -160,19 +166,19 @@ int main(int argc, char* argv[])
                           sizeof(half) * kSequenceLen * kHeadDim,
                           0,
                           sizeof(half) * kHeadDim,
-                          kBatchSize * kHeadNum);
+                          kBatchSize * KvHeadNum);
     }
 
     thrust::universal_vector<half>  k_blocks;
     thrust::universal_vector<half*> k_ptrs;
     thrust::universal_vector<int>   cu_block_cnts;
 
-    TestBlocks(k_cache, k_blocks, k_ptrs, cu_block_cnts, kHeadNum, kHeadDim, kBlockSz, kBatchSize);
+    TestBlocks(k_cache, k_blocks, k_ptrs, cu_block_cnts, KvHeadNum, kHeadDim, kBlockSz, kBatchSize);
 
     thrust::universal_vector<half>  v_blocks;
     thrust::universal_vector<half*> v_ptrs;
 
-    TestBlocks(v_cache, v_blocks, v_ptrs, cu_block_cnts, kHeadNum, kHeadDim, kBlockSz, kBatchSize);
+    TestBlocks(v_cache, v_blocks, v_ptrs, cu_block_cnts, KvHeadNum, kHeadDim, kBlockSz, kBatchSize);
 
     thrust::universal_vector<half>  k_cache_ref = k_cache;
     thrust::universal_vector<half>  v_cache_ref = v_cache;
@@ -198,8 +204,8 @@ int main(int argc, char* argv[])
     params.out    = output_ref.data().get();
     params.q      = qkv.data().get();
     params.k      = params.q + kHeadNum * kHeadDim;
-    params.v      = params.k + kHeadNum * kHeadDim;
-    params.stride = 3 * kHeadNum * kHeadDim;
+    params.v      = params.k + KvHeadNum * kHeadDim;
+    params.stride = (kHeadNum + 2 * KvHeadNum) * kHeadDim;
 
     params.batch_size    = kBatchSize;
     params.max_seq_len   = kContextLen + 1;
@@ -217,13 +223,17 @@ int main(int argc, char* argv[])
     params.layer_offset       = 0;
 
     params.num_heads     = kHeadNum;
-    params.num_kv_heads  = kHeadNum;
+    params.num_kv_heads  = KvHeadNum;
     params.size_per_head = kHeadDim;
     params.inv_sqrt_dh   = 1.f / std::sqrt((float)params.size_per_head);
 
     params.rotary_embedding_dim  = kHeadDim;
     params.rotary_embedding_base = 10000.f;
 
+    params.partial_L = partial_L.data().get();
+    params.partial_M = partial_M.data().get();
+    params.partial_O = partial_O.data().get();
+
     for (int i = 0; i < kTestIter; ++i) {
         mmha_ft_reference(params, cudaStream_t{});
     }
@@ -239,6 +249,9 @@ int main(int argc, char* argv[])
     params.per_sample_k_cache = k_cache_ptrs.data().get();
     params.per_sample_v_cache = v_cache_ptrs.data().get();
 
+    params.max_split_k = kMaxSplitK;
+    params.max_seq_len = kContextLen;
+
     std::vector<thrust::universal_vector<half>> outputs;
 
     for (int i = 0; i < std::max(kTestIter, 10); ++i) {
@@ -265,7 +278,7 @@ int main(int argc, char* argv[])
                               0,
                               kBlockSz,
                               kSequenceLen,
-                              kHeadNum,
+                              KvHeadNum,
                               kHeadDim,
                               kBatchSize,
                               0);
@@ -276,7 +289,7 @@ int main(int argc, char* argv[])
                               0,
                               kBlockSz,
                               kSequenceLen,
-                              kHeadNum,
+                              KvHeadNum,
                               kHeadDim,
                               kBatchSize,
                               0);
@@ -293,7 +306,7 @@ int main(int argc, char* argv[])
 
     std::cout << "---------------------------------------------------\n";
 
-    Compare(output.data().get(), output_ref.data().get(), kHeadDim, kHeadDim, kHeadNum, 0);
+    Compare(output.data().get(), output_ref.data().get(), kHeadDim, kHeadDim, kHeadNum, false);
 
     // [H, S, D]
 
@@ -301,13 +314,13 @@ int main(int argc, char* argv[])
             k_cache_ref.data().get() + kContextLen * kHeadDim,
             kSequenceLen * kHeadDim,
             kHeadDim,
-            kHeadNum);
+            KvHeadNum);
 
     Compare(v_cache.data().get() + kContextLen * kHeadDim,
             v_cache_ref.data().get() + kContextLen * kHeadDim,
             kSequenceLen * kHeadDim,
             kHeadDim,
-            kHeadNum);
+            KvHeadNum);
 
     return 0;
 }
\ No newline at end of file
diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 33d2ce998e..5acafabef5 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -18,6 +18,7 @@
 #include <iomanip>
 #include <math.h>
 #include <mutex>
+#include <numeric>
 #include <sstream>
 #include <unordered_map>
 
@@ -740,6 +741,11 @@ auto LlamaBatch<T>::InitializeGeneration() -> GenerationState
     invokePlusScalar(sequence_lengths_, -1, batch_size, stream_);
     sync_check_cuda_error();
 
+    // used for dispatching split-k decoding kernels
+    const int sum_seq_len =
+        std::accumulate(state_->h_context_length, state_->h_context_length + batch_size, -batch_size);
+    const int max_seq_len = *std::max_element(state_->h_context_length, state_->h_context_length + batch_size) - 1;
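+    // e.g. (illustrative): h_context_length = {5, 9, 3}, batch_size = 3
+    //   -> sum_seq_len = 17 - 3 = 14, max_seq_len = 9 - 1 = 8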
+
     // seq_limit_len_, will be compared to `step` instead of `sequence_length`, so padding len should be accounted
     // for
     for (int i = 0; i < batch_size; ++i) {
@@ -786,7 +792,7 @@ auto LlamaBatch<T>::InitializeGeneration() -> GenerationState
                         (int)state_->h_finished[i]);
         }
     }
-    return GenerationState{max_context_len, start_step};
+    return GenerationState{max_context_len, start_step, sum_seq_len, max_seq_len};
 }
 
 template<typename T>
@@ -823,6 +829,8 @@ bool LlamaBatch<T>::Generate(GenerationState& g)
                            cu_block_counts_,
                            g.step,
                            0,
+                           g.sum_seq_len,
+                           g.max_seq_len,
                            session_len_,
                            batch_size);
 
@@ -872,8 +880,10 @@ bool LlamaBatch<T>::Generate(GenerationState& g)
     }
 
     ////////////////////////////////////////////////
-    /// ! increase the step counter
-    ++g.step;
+    /// ! increase the counters
+    g.step += 1;
+    g.max_seq_len += 1;
+    g.sum_seq_len += batch_size;
 
     return !should_stop;
 }
diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h
index f0d3b4d303..4c8f8154be 100644
--- a/src/turbomind/models/llama/LlamaBatch.h
+++ b/src/turbomind/models/llama/LlamaBatch.h
@@ -60,6 +60,8 @@ class LlamaBatch {
     struct GenerationState {
         int max_init_ctx_len;
         int step;
+        int sum_seq_len;
+        int max_seq_len;
     };
 
     void            InitializeSampling();
diff --git a/src/turbomind/models/llama/LlamaDecoder.cc b/src/turbomind/models/llama/LlamaDecoder.cc
index bb76bd205e..6d633cd829 100644
--- a/src/turbomind/models/llama/LlamaDecoder.cc
+++ b/src/turbomind/models/llama/LlamaDecoder.cc
@@ -189,12 +189,9 @@ void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>*        ou
 
     allocateBuffer(sess.batch_size);
 
-    sess.ite     = input_tensors->at("ite").getVal<const int>();
     sess.k_cache = &output_tensors->at("key_cache");
     sess.v_cache = &output_tensors->at("value_cache");
 
-    sess.max_memory_len = input_tensors->at("max_seq_len").getVal<int>();
-
     T* decoder_input  = input_tensors->at("decoder_input").getPtr<T>();
     T* decoder_output = output_tensors->at("decoder_output").getPtr<T>();
 
diff --git a/src/turbomind/models/llama/LlamaDecoder.h b/src/turbomind/models/llama/LlamaDecoder.h
index 091c2ba55a..e12214617b 100644
--- a/src/turbomind/models/llama/LlamaDecoder.h
+++ b/src/turbomind/models/llama/LlamaDecoder.h
@@ -53,8 +53,6 @@ class LlamaDecoder: public BaseLayer {
 
     struct Session {
         size_t                                          batch_size;
-        int                                             ite;
-        size_t                                          max_memory_len;
         Tensor*                                         k_cache;
         Tensor*                                         v_cache;
         const std::vector<LlamaDecoderLayerWeight<T>*>* weights;
diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
index fe7fdab7ad..ba31d90a0d 100644
--- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
@@ -44,6 +44,9 @@ void LlamaDecoderSelfAttentionLayer<T>::allocateBuffer(size_t batch_size)
     context_buf_ =
         reinterpret_cast<T*>(allocator_->reMalloc(context_buf_, sizeof(T) * batch_size * local_hidden_units_, false));
 
+    workspace_ = (float*)allocator_->reMalloc(
+        workspace_, sizeof(float) * batch_size * local_head_num_ * kMaxSplitK * (size_per_head_ + 2));
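+    // workspace_ layout (floats): partial_O [B, H, kMaxSplitK, D], followed by
+    // partial_M and partial_L, each [B, H, kMaxSplitK]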
+
     is_allocate_buffer_ = true;
 }
 
@@ -84,6 +87,9 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
     const int*  sequence_lengths_data = input_tensors->getPtr<int>("sequence_lengths");
     const bool* finished_data         = input_tensors->getPtr<bool>("finished");
 
+    const int sum_seq_len = input_tensors->getVal<int>("sum_seq_len");
+    const int max_seq_len = input_tensors->getVal<int>("max_seq_len");
+
     T*  hidden_features_data = output_tensors->getPtr<T>("attention_output");
     T** key_cache_ptrs       = output_tensors->getPtr<T*>("key_cache");
     T** value_cache_ptrs     = output_tensors->getPtr<T*>("value_cache");
@@ -116,7 +122,7 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
     params.stride = (local_head_num_ + 2 * local_kv_head_num_) * size_per_head_;
 
     params.batch_size    = batch_size;
-    params.cu_block_cnts = cu_block_counts;  /// TODO
+    params.cu_block_cnts = cu_block_counts;
 
     params.k_cache_block_ptrs  = (void**)key_cache_ptrs;
     params.v_cache_block_ptrs  = (void**)value_cache_ptrs;
@@ -135,6 +141,26 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
     params.rotary_embedding_dim  = size_per_head_;
     params.rotary_embedding_base = 10000.f;
 
+    params.partial_O = workspace_;
+    params.partial_M = params.partial_O + batch_size * local_head_num_ * kMaxSplitK * size_per_head_;
+    params.partial_L = params.partial_M + batch_size * local_head_num_ * kMaxSplitK;
+
+    // avg_batch_size = sum_seq_len / max_seq_len
+    // max_split_k    = kMaxSplitK  / avg_batch_size
+    // max_split_k'   = min(max_split_k, max_seq_lens / kSliceLen)
+
+    const float avg_batch_size = max_seq_len ? (float)sum_seq_len / max_seq_len : 1;
+    FT_CHECK(avg_batch_size >= 1.f);
+
+    const int max_split_k = std::max(1, (int)std::ceil(kMaxSplitK / avg_batch_size));
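+    // e.g. (illustrative): kMaxSplitK = 16, sum_seq_len = 4096, max_seq_len = 512
+    //   -> avg_batch_size = 8, max_split_k = 2; a single short batch gives
+    //      avg_batch_size ~ 1 and max_split_k = 16 (later clipped by the slice count)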
+
+    // if (layer_id == 0) {
+    //     TM_LOG_INFO("avg_batch_size = %.1f, max_split_k = %d", avg_batch_size, max_split_k);
+    // }
+
+    params.max_split_k = max_split_k;
+    params.max_seq_len = max_seq_len;
+
     params.stream = stream_;
 
     params.quant_policy = quant_policy_;
diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h
index 73c9674d23..9f80edc462 100644
--- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h
+++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h
@@ -92,6 +92,9 @@ class LlamaDecoderSelfAttentionLayer {
     T* qkv_buf_     = nullptr;
     T* context_buf_ = nullptr;
 
+    static constexpr int kMaxSplitK = 16;  // must be <= WARP_SIZE
+    float*               workspace_ = nullptr;
+
     bool is_allocate_buffer_{};
 };
 
diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc
index bb0118ca44..c5ba88e4cf 100644
--- a/src/turbomind/models/llama/LlamaV2.cc
+++ b/src/turbomind/models/llama/LlamaV2.cc
@@ -286,13 +286,14 @@ void LlamaV2<T>::decoderForward(T*          decoder_output,
                                 const int*  cu_block_counts,
                                 int         step,
                                 int         ite,
+                                int         sum_seq_len,
+                                int         max_seq_len,
                                 size_t      session_len,
                                 size_t      batch_size)
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
 
-    const int  max_seq_len = session_len;
-    const auto dtype       = getTensorType<T>();
+    const auto dtype = getTensorType<T>();
 
     // max_input_length is not used w/o linear_bias_slopes
     // sequence_lengths_ will be incremented in dynamic decode
@@ -300,6 +301,7 @@ void LlamaV2<T>::decoderForward(T*          decoder_output,
         {"decoder_input", {MEMORY_GPU, dtype, {batch_size, hidden_units_}, decoder_input}},
         {"sequence_lengths", {MEMORY_GPU, TYPE_INT32, {batch_size}, sequence_length}},
         {"cu_block_counts", {MEMORY_GPU, TYPE_INT32, {batch_size}, cu_block_counts}},
+        {"sum_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &sum_seq_len}},
         {"max_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_seq_len}},
         {"finished", {MEMORY_GPU, TYPE_BOOL, {batch_size}, finished}},
         {"output_norm_weight", {MEMORY_GPU, dtype, {hidden_units_}, weights_->output_norm_weight}},
diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h
index 229674f599..c7463da652 100644
--- a/src/turbomind/models/llama/LlamaV2.h
+++ b/src/turbomind/models/llama/LlamaV2.h
@@ -135,6 +135,8 @@ class LlamaV2 {
                         const int*  cu_block_counts,
                         int         step,
                         int         ite,
+                        int         sum_seq_len,
+                        int         max_seq_len,
                         size_t      session_len,
                         size_t      batch_size);
 

From 7a7e7018338550489177fbb03ad37adfe02f75ab Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Thu, 19 Oct 2023 10:12:25 +0000
Subject: [PATCH 20/56] minor

---
 src/turbomind/models/llama/LlamaBatch.cc | 3 ++-
 src/turbomind/models/llama/LlamaV2.cc    | 1 -
 src/turbomind/models/llama/LlamaV2.h     | 1 -
 3 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 5acafabef5..dfc12b8a7d 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -304,6 +304,8 @@ bool LlamaBatch<T>::Initialize()
             return sequences[idx]->status == Sequence::kActive;  // present status
         });
 
+        FT_CHECK_WITH_INFO(active_end != idxs.begin(), "Not enough blocks.");
+
         // move swap-ins to the back
         auto swapin_beg = std::stable_partition(idxs.begin(), active_end, [&](int idx) {
             return status[idx] == Sequence::kActive;  // past status
@@ -831,7 +833,6 @@ bool LlamaBatch<T>::Generate(GenerationState& g)
                            0,
                            g.sum_seq_len,
                            g.max_seq_len,
-                           session_len_,
                            batch_size);
 
     model_->postDecodeEmbedding(logits_buf_,  //
diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc
index c5ba88e4cf..ccaf846a07 100644
--- a/src/turbomind/models/llama/LlamaV2.cc
+++ b/src/turbomind/models/llama/LlamaV2.cc
@@ -288,7 +288,6 @@ void LlamaV2<T>::decoderForward(T*          decoder_output,
                                 int         ite,
                                 int         sum_seq_len,
                                 int         max_seq_len,
-                                size_t      session_len,
                                 size_t      batch_size)
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h
index c7463da652..ee659b7373 100644
--- a/src/turbomind/models/llama/LlamaV2.h
+++ b/src/turbomind/models/llama/LlamaV2.h
@@ -137,7 +137,6 @@ class LlamaV2 {
                         int         ite,
                         int         sum_seq_len,
                         int         max_seq_len,
-                        size_t      session_len,
                         size_t      batch_size);
 
     void postDecodeEmbedding(float* logits, float* local_logits, const T* decoder_output, int batch_size);

From 48761d7b804df7cbece3c2e2e06c1f5e9774667d Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Thu, 19 Oct 2023 10:24:06 +0000
Subject: [PATCH 21/56] truncate `session_len` by available blocks

---
 src/turbomind/models/llama/LlamaV2.cc | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc
index ccaf846a07..070fd3fa65 100644
--- a/src/turbomind/models/llama/LlamaV2.cc
+++ b/src/turbomind/models/llama/LlamaV2.cc
@@ -33,6 +33,7 @@
 #include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/Tensor.h"
 #include "src/turbomind/utils/cuda_utils.h"
+#include "src/turbomind/utils/logger.h"
 #include <functional>
 #include <memory>
 #include <sstream>
@@ -121,7 +122,18 @@ LlamaV2<T>::LlamaV2(size_t                       head_num,
                                                               elem_bits,
                                                               tensor_para_.rank_,
                                                               allocator);
-    batch_                = std::make_unique<LlamaBatch<T>>(
+
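+    // a single session can never use more tokens than the whole KV cache can hold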
+    const size_t max_session_len = sequence_manager->max_block_count() * cache_block_seq_len;
+    if (max_session_len < session_len) {
+        if (tensor_para.rank_ == 0) {
+            TM_LOG_WARNING("Not enough blocks for `session_len` (%d), `session_len` truncated to %d.",
+                           (int)session_len,
+                           (int)max_session_len);
+        }
+        session_len = max_session_len;
+    }
+
+    batch_ = std::make_unique<LlamaBatch<T>>(
         max_batch_size, max_context_token_num, session_len, std::move(sequence_manager), this);
 
     initialize(attn_params, kv_head_num, use_context_fmha, quant_policy);

From f9410a9be2070f118527487f8abd0b2c22f65b01 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Thu, 19 Oct 2023 12:58:55 +0000
Subject: [PATCH 22/56] minor

---
 .../decoder_multihead_attention_template.h    | 141 ++----------------
 .../decoder_multihead_attention/iterator.h    |  29 ----
 .../decoder_multihead_attention/kv_cache.cu   |   2 +-
 src/turbomind/models/llama/BlockManager.cc    |   2 +-
 src/turbomind/models/llama/LlamaBatch.cc      |   2 +-
 .../llama/LlamaContextAttentionLayer.cc       |   2 +-
 .../models/llama/LlamaContextDecoder.cc       |   2 +-
 src/turbomind/models/llama/SequenceManager.cc |   2 +-
 .../models/llama/test_cache_manager.cc        |   2 +-
 src/turbomind/utils/debug_utils.h             |   7 +
 10 files changed, 27 insertions(+), 164 deletions(-)
 create mode 100644 src/turbomind/utils/debug_utils.h

diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
index e129f260c6..09088843cc 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
@@ -804,87 +804,8 @@ struct DecoderMultiHeadAttentionKernel {
 
     static __device__ void Reduce(const ParamType& params)
     {
-        static constexpr int kVecQSize = kMaxHeadDim / WARP_SIZE;
-
-        // using VecQ      = Array<T, kVecQSize>;
-        using VecQFloat = Array<float, kVecQSize>;
-        using MapQ      = ThreadMapQ<kMaxHeadDim, kHeadPerCta, kVecQSize, kWarpCount>;
-
-        static constexpr int kQkvHeadPerThread = MapQ::kIterS;
-
-        const int batch_idx = get_batch_idx();
-        const int head_idx  = get_head_idx() * kHeadPerCta;
-        const int warp_id_  = threadIdx.x / WARP_SIZE;
-        const int lane_id_  = threadIdx.x % WARP_SIZE;
-        const int split_k   = (params.per_sample_length[batch_idx] + kSliceLen - 1) / kSliceLen;
-
-        int2 offset = MapQ::get_offset(warp_id_, lane_id_);
-
-        if (offset.x >= kMaxHeadDim || offset.y >= kHeadPerCta) {
-            return;
-        }
-
-        auto get_index = [&](int k, int qi) {
-            return batch_idx * params.num_heads * split_k + (head_idx + qi) * split_k + k;
-        };
-
-        Array<float, kQkvHeadPerThread> global_M;
-        PRAGMA_UNROLL
-        for (int s = 0; s < kQkvHeadPerThread; ++s) {
-            const int qi = offset.y + s;
-            global_M[s]  = params.partial_M[get_index(0, qi)];
-        }
-        for (int k = 1; k < split_k; ++k) {
-            PRAGMA_UNROLL
-            for (int s = 0; s < kQkvHeadPerThread; ++s) {
-                const int qi = offset.y + s;
-                global_M[s]  = max(global_M[qi], params.partial_M[get_index(k, qi)]);
-            }
-        }
-
-        Array<float, kQkvHeadPerThread> global_L{};
-        for (int k = 0; k < split_k; ++k) {
-            PRAGMA_UNROLL
-            for (int s = 0; s < kQkvHeadPerThread; ++s) {
-                const int qi = offset.y + s;
-                global_L[s] +=
-                    params.partial_L[get_index(k, qi)] * expf(params.partial_M[get_index(k, qi)] - global_M[s]);
-            }
-        }
-
-        VecQFloat frag_O[kQkvHeadPerThread]{};
-        for (int k = 0; k < split_k; ++k) {
-            PRAGMA_UNROLL
-            for (int s = 0; s < kQkvHeadPerThread; ++s) {
-                const int di = offset.x;
-                const int qi = offset.y + s;
-
-                float scale = expf(params.partial_M[get_index(k, qi)] - global_M[s]) / (global_L[s] + 1e-8f);
-
-                VecQFloat partial_O;
-                Ldg(partial_O, &params.partial_O[get_index(k, qi) * kHeadDim + di]);
-
-                using namespace ops;
-                frag_O[s] = frag_O[s] + partial_O * scale;
-            }
-        }
-
-        PRAGMA_UNROLL
-        for (int s = 0; s < kQkvHeadPerThread; ++s) {
-            const int di = offset.x;
-            const int qi = offset.y + s;
-            Store(&params.out[batch_idx * params.num_heads * kHeadDim + (head_idx + qi) * kHeadDim + di],
-                  cast<T>(frag_O[s]));
-        }
-    }
-
-    static __device__ void Reduce2(const ParamType& params)
-    {
-        const int batch_idx = get_batch_idx();
-        const int head_idx  = get_head_idx();
-        const int lane_id   = threadIdx.x % WARP_SIZE;
-        const int warp_id   = threadIdx.x / WARP_SIZE;
-
+        const int batch_idx       = get_batch_idx();
+        const int head_idx        = get_head_idx();
         const int timestep        = params.per_sample_length[batch_idx];
         const int max_split_k     = params.max_split_k;
         const int slice_count     = get_slice_count(timestep);
@@ -944,55 +865,19 @@ struct DecoderMultiHeadAttentionKernel {
 
         __syncthreads();
 
-        if constexpr (1) {
-            int   idx = (batch_idx * params.num_heads * max_split_k + head_idx * max_split_k) * kHeadDim + threadIdx.x;
-            float accum_O{};
-            const bool is_valid = threadIdx.x < kHeadDim;
-            for (int k = 0; k < split_k; ++k) {
-                if (is_valid) {
-                    accum_O += smem_scale_O[k] * params.partial_O[idx];
-                }
-                idx += kHeadDim;
-            }
-            if (is_valid) {
-                params.out[batch_idx * params.num_heads * kHeadDim + head_idx * kHeadDim + threadIdx.x] = (T)accum_O;
-            }
-        }
-        else {  // vectorized version, not benificial
-            using namespace ops;
-            constexpr int kVecSize = 4;
-
-            using VecO = Array<float, kVecSize>;
-            VecO accum_O{};
-
-            int idx =
-                batch_idx * params.num_heads * max_split_k + head_idx * max_split_k + warp_id;  // offset by warp_id
-            idx = idx * kHeadDim + lane_id * kVecSize;                                          // offset by lane_id
-
-            for (int k = warp_id; k < split_k; k += kWarpCount) {
-                VecO frag_O;
-                Ldg(frag_O, &params.partial_O[idx]);
-                accum_O = accum_O + frag_O * smem_scale_O[k];
-                idx += kWarpCount * kHeadDim;
-            }
-
-            __shared__ __align__(16) VecO reduce_O[kWarpCount][WARP_SIZE];
-
-            reduce_O[warp_id][lane_id] = accum_O;
+        int   idx = (batch_idx * params.num_heads * max_split_k + head_idx * max_split_k) * kHeadDim + threadIdx.x;
+        float accum_O{};
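+        // each valid thread owns one element of O and sums the per-split partials, rescaled by smem_scale_O[k]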
 
-            PRAGMA_UNROLL
-            for (int mask = kWarpCount / 2; mask >= 1; mask /= 2) {
-                __syncthreads();
-                if (warp_id < mask) {
-                    reduce_O[warp_id][lane_id] = reduce_O[warp_id][lane_id] + reduce_O[warp_id + mask][lane_id];
-                }
-            }
+        const bool is_valid = threadIdx.x < kHeadDim;
 
-            // no need to sync, last loop ends with warp 0
-            if (warp_id == 0 && lane_id * kVecSize < kHeadDim) {
-                Store(&params.out[batch_idx * params.num_heads * kHeadDim + head_idx * kHeadDim + lane_id * kVecSize],
-                      cast<T>(reduce_O[warp_id][lane_id]));
+        for (int k = 0; k < split_k; ++k) {
+            if (is_valid) {
+                accum_O += smem_scale_O[k] * params.partial_O[idx];
             }
+            idx += kHeadDim;
+        }
+        if (is_valid) {
+            params.out[batch_idx * params.num_heads * kHeadDim + head_idx * kHeadDim + threadIdx.x] = (T)accum_O;
         }
     }
 
@@ -1032,7 +917,7 @@ __global__ void decoder_multihead_attention(ParamType params)
 template<typename MHAType, typename ParamType = typename MHAType::ParamType>
 __global__ void decoder_multihead_attention_reduce(ParamType params)
 {
-    MHAType::Reduce2(params);
+    MHAType::Reduce(params);
 }
 
 }  // namespace turbomind
diff --git a/src/turbomind/kernels/decoder_multihead_attention/iterator.h b/src/turbomind/kernels/decoder_multihead_attention/iterator.h
index 0e5158283f..0d1cc433ab 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/iterator.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/iterator.h
@@ -274,35 +274,6 @@ struct Iterator {
 
     __device__ void Prefetch(bool mask)
     {
-        // if (blockIdx.x == 0 && threadIdx.x == 0 && mask) {
-        //     int  c    = src_offset_ % ThreadMap::kC;
-        //     int  s    = src_offset_ / ThreadMap::kC;
-        //     bool fuck = src_offset_ >= 128 * 4096;
-        //     printf("%d %d %d %d %s\n", (int)threadIdx.x, c, s, offset_s_, fuck ? "FUCK" : "");
-        // }
-
-        // if (blockIdx.x == 0 && threadIdx.x == 0) {
-        //     int  c    = dst_offset_ % ThreadMap::kC;
-        //     int  s    = dst_offset_ / ThreadMap::kC;
-        //     bool fuck = (dst_offset_ >= Stages * kSizePerTile);
-        //     printf("%d %d %d %s\n", c, s, dst_offset_, fuck ? "FUCK" : "");
-        // }
-
-        // if (init_offset_ / ThreadMap::kC == 0) {
-        //     int k = dst_offset_ / (ThreadMap::kS * ThreadMap::kC);
-        //     int s = dst_offset_ % (ThreadMap::kS * ThreadMap::kC) / ThreadMap::kC;
-        //     int c = dst_offset_ % ThreadMap::kC;
-        //     printf("tid=%d, k=%d, s=%d, c=%d, offset_s=%d, valid_s=%d, init_s=%d, mask=%d\n",
-        //            threadIdx.x,
-        //            k,
-        //            s,
-        //            c,
-        //            offset_s_,
-        //            (int)is_valid_s_,
-        //            init_offset_ / ThreadMap::kC,
-        //            (int)mask);
-        // }
-
         CpAsync(smem_ + dst_offset_, src_ + src_offset_, mask);
         // Copy(smem_ + dst_offset_, src_ + src_offset_, mask);
     }
diff --git a/src/turbomind/kernels/decoder_multihead_attention/kv_cache.cu b/src/turbomind/kernels/decoder_multihead_attention/kv_cache.cu
index 64fcb26ce4..23382dc2c4 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/kv_cache.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/kv_cache.cu
@@ -2,7 +2,7 @@
 // #include "cute/tensor.hpp"
 #include "src/turbomind/kernels/decoder_multihead_attention/array_ops.h"
 #include "src/turbomind/models/llama/llama_utils.h"
-#include "src/turbomind/utils/dbg.h"
+#include "src/turbomind/utils/debug_utils.h"
 #include <cuda_fp16.h>
 #include <type_traits>
 
diff --git a/src/turbomind/models/llama/BlockManager.cc b/src/turbomind/models/llama/BlockManager.cc
index 37384ea61e..2c73f68ee1 100644
--- a/src/turbomind/models/llama/BlockManager.cc
+++ b/src/turbomind/models/llama/BlockManager.cc
@@ -1,6 +1,6 @@
 #include "src/turbomind/models/llama/BlockManager.h"
 #include "src/turbomind/utils/cuda_utils.h"
-#include "src/turbomind/utils/dbg.h"
+#include "src/turbomind/utils/debug_utils.h"
 #include "src/turbomind/utils/logger.h"
 #include <algorithm>
 #include <iterator>
diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index dfc12b8a7d..3f8593421d 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -11,7 +11,7 @@
 #include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/Tensor.h"
 #include "src/turbomind/utils/cuda_utils.h"
-#include "src/turbomind/utils/dbg.h"
+#include "src/turbomind/utils/debug_utils.h"
 #include "src/turbomind/utils/logger.h"
 #include <algorithm>
 #include <cstdint>
diff --git a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
index 65b2be0aac..1a62e2fb77 100644
--- a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
@@ -29,7 +29,7 @@
 #include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/Tensor.h"
 #include "src/turbomind/utils/cuda_utils.h"
-#include "src/turbomind/utils/dbg.h"
+#include "src/turbomind/utils/debug_utils.h"
 #include "src/turbomind/utils/logger.h"
 
 namespace turbomind {
diff --git a/src/turbomind/models/llama/LlamaContextDecoder.cc b/src/turbomind/models/llama/LlamaContextDecoder.cc
index 20c4437f33..fb1e2e8c79 100644
--- a/src/turbomind/models/llama/LlamaContextDecoder.cc
+++ b/src/turbomind/models/llama/LlamaContextDecoder.cc
@@ -27,7 +27,7 @@
 #include "src/turbomind/models/llama/llama_kernels.h"
 #include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/Tensor.h"
-#include "src/turbomind/utils/dbg.h"
+#include "src/turbomind/utils/debug_utils.h"
 
 namespace turbomind {
 
diff --git a/src/turbomind/models/llama/SequenceManager.cc b/src/turbomind/models/llama/SequenceManager.cc
index 8c9375f76c..ae0ce626ac 100644
--- a/src/turbomind/models/llama/SequenceManager.cc
+++ b/src/turbomind/models/llama/SequenceManager.cc
@@ -1,6 +1,6 @@
 #include "src/turbomind/models/llama/SequenceManager.h"
 #include "src/turbomind/utils/allocator.h"
-#include "src/turbomind/utils/dbg.h"
+#include "src/turbomind/utils/debug_utils.h"
 #include "src/turbomind/utils/logger.h"
 #include <ctime>
 #include <stdexcept>
diff --git a/src/turbomind/models/llama/test_cache_manager.cc b/src/turbomind/models/llama/test_cache_manager.cc
index 184f1f2bf8..ffcf5b6b92 100644
--- a/src/turbomind/models/llama/test_cache_manager.cc
+++ b/src/turbomind/models/llama/test_cache_manager.cc
@@ -3,7 +3,7 @@
 
 #include "src/turbomind/utils/allocator.h"
 
-#include "src/turbomind/utils/dbg.h"
+#include "src/turbomind/utils/debug_utils.h"
 #include <catch2/catch_test_macros.hpp>
 #include <iterator>
 
diff --git a/src/turbomind/utils/debug_utils.h b/src/turbomind/utils/debug_utils.h
new file mode 100644
index 0000000000..0e577d5a78
--- /dev/null
+++ b/src/turbomind/utils/debug_utils.h
@@ -0,0 +1,7 @@
+#pragma once
+
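+// `dbg(...)` degrades to a no-op when the optional 3rdparty/dbg.h header is not present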
+#if __has_include("3rdparty/dbg.h")
+#include "3rdparty/dbg.h"
+#else
+#define dbg(...)
+#endif
\ No newline at end of file

From 96b7f4b1642d58c02c084aae9f44eaeb350b5692 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Thu, 19 Oct 2023 13:07:51 +0000
Subject: [PATCH 23/56] license

---
 .../kernels/decoder_multihead_attention/array_ops.h      | 2 ++
 .../decoder_multihead_attention.cu                       | 2 ++
 .../decoder_multihead_attention.h                        | 3 +++
 .../decoder_multihead_attention_params.h                 | 2 ++
 .../decoder_multihead_attention_template.h               | 2 ++
 .../kernels/decoder_multihead_attention/iterator.h       | 2 ++
 .../kernels/decoder_multihead_attention/kv_cache.cu      | 3 ++-
 .../kernels/decoder_multihead_attention/kv_cache.h       | 2 ++
 .../test_decoder_multihead_attention.cu                  | 2 +-
 .../kernels/decoder_multihead_attention/thread_map.h     | 3 +++
 src/turbomind/models/llama/BlockManager.cc               | 2 ++
 src/turbomind/models/llama/CMakeLists.txt                | 9 +++++----
 src/turbomind/models/llama/SequenceManager.cc            | 2 ++
 src/turbomind/models/llama/SequenceManager.h             | 2 ++
 src/turbomind/models/llama/test_cache_manager.cc         | 2 ++
 15 files changed, 34 insertions(+), 6 deletions(-)

diff --git a/src/turbomind/kernels/decoder_multihead_attention/array_ops.h b/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
index 37946d39d9..eff27b2fda 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
@@ -1,3 +1,5 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
 #pragma once
 
 #include "src/turbomind/kernels/gemm_s_f16/common.h"
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
index abf24ea8f7..b83de12bd3 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
@@ -1,3 +1,5 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
 #include "decoder_multihead_attention_template.h"
 #include "src/turbomind/models/llama/llama_utils.h"
 
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.h
index f4eca0617c..984dde4fe2 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.h
@@ -1,3 +1,6 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#pragma once
 
 #include "decoder_multihead_attention_params.h"
 
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
index 1e8c556c41..993001ee89 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
@@ -1,3 +1,5 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
 #pragma once
 #include <cuda_runtime.h>
 
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
index 09088843cc..dfeb86e568 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
@@ -1,3 +1,5 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
 #pragma once
 
 #include "array_ops.h"
diff --git a/src/turbomind/kernels/decoder_multihead_attention/iterator.h b/src/turbomind/kernels/decoder_multihead_attention/iterator.h
index 0d1cc433ab..006d1e5cc6 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/iterator.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/iterator.h
@@ -1,3 +1,5 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
 #pragma once
 
 #include "../gemm_s_f16/common.h"
diff --git a/src/turbomind/kernels/decoder_multihead_attention/kv_cache.cu b/src/turbomind/kernels/decoder_multihead_attention/kv_cache.cu
index 23382dc2c4..d9a46c40a7 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/kv_cache.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/kv_cache.cu
@@ -1,5 +1,6 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
 #include "../gemm_s_f16/common.h"
-// #include "cute/tensor.hpp"
 #include "src/turbomind/kernels/decoder_multihead_attention/array_ops.h"
 #include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/debug_utils.h"
diff --git a/src/turbomind/kernels/decoder_multihead_attention/kv_cache.h b/src/turbomind/kernels/decoder_multihead_attention/kv_cache.h
index d84c991ac3..7ca12db3b5 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/kv_cache.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/kv_cache.h
@@ -1,3 +1,5 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
 #pragma once
 
 #include <cuda_runtime.h>
diff --git a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
index 6d7a634df5..912bff1ae4 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
@@ -1,4 +1,4 @@
-
+// Copyright (c) OpenMMLab. All rights reserved.
 
 #include "decoder_multihead_attention.h"
 #include "kv_cache.h"
diff --git a/src/turbomind/kernels/decoder_multihead_attention/thread_map.h b/src/turbomind/kernels/decoder_multihead_attention/thread_map.h
index 0968681c77..47b2636f6d 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/thread_map.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/thread_map.h
@@ -1,4 +1,7 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
 #pragma once
+
 #include "../gemm_s_f16/common.h"
 #include "src/turbomind/kernels/custom_ar_kernels.h"
 
diff --git a/src/turbomind/models/llama/BlockManager.cc b/src/turbomind/models/llama/BlockManager.cc
index 2c73f68ee1..7baec4f378 100644
--- a/src/turbomind/models/llama/BlockManager.cc
+++ b/src/turbomind/models/llama/BlockManager.cc
@@ -1,3 +1,5 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
 #include "src/turbomind/models/llama/BlockManager.h"
 #include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/debug_utils.h"
diff --git a/src/turbomind/models/llama/CMakeLists.txt b/src/turbomind/models/llama/CMakeLists.txt
index 3519cde42d..0b083ad33e 100644
--- a/src/turbomind/models/llama/CMakeLists.txt
+++ b/src/turbomind/models/llama/CMakeLists.txt
@@ -54,8 +54,9 @@ target_link_libraries(llama_gemm PUBLIC CUDA::cudart gpt_gemm_func memory_utils
 
 install(TARGETS llama_gemm DESTINATION ${CMAKE_SOURCE_DIR}/lmdeploy/bin)
 
-find_package(Catch2 3 REQUIRED)
-
-add_executable(test_cache_manager test_cache_manager.cc)
-target_link_libraries(test_cache_manager PRIVATE Llama Catch2::Catch2WithMain)
+find_package(Catch2 3 QUIET)
+if (Catch2_FOUND)
+    add_executable(test_cache_manager test_cache_manager.cc)
+    target_link_libraries(test_cache_manager PRIVATE Llama Catch2::Catch2WithMain)
+endif ()
 
diff --git a/src/turbomind/models/llama/SequenceManager.cc b/src/turbomind/models/llama/SequenceManager.cc
index ae0ce626ac..7e82db822c 100644
--- a/src/turbomind/models/llama/SequenceManager.cc
+++ b/src/turbomind/models/llama/SequenceManager.cc
@@ -1,3 +1,5 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
 #include "src/turbomind/models/llama/SequenceManager.h"
 #include "src/turbomind/utils/allocator.h"
 #include "src/turbomind/utils/debug_utils.h"
diff --git a/src/turbomind/models/llama/SequenceManager.h b/src/turbomind/models/llama/SequenceManager.h
index 190b29feaf..19278a0fbd 100644
--- a/src/turbomind/models/llama/SequenceManager.h
+++ b/src/turbomind/models/llama/SequenceManager.h
@@ -1,3 +1,5 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
 #pragma once
 
 #include "src/turbomind/models/llama/BlockManager.h"
diff --git a/src/turbomind/models/llama/test_cache_manager.cc b/src/turbomind/models/llama/test_cache_manager.cc
index ffcf5b6b92..75d9f039dc 100644
--- a/src/turbomind/models/llama/test_cache_manager.cc
+++ b/src/turbomind/models/llama/test_cache_manager.cc
@@ -1,3 +1,5 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
 #include "BlockManager.h"
 #include "SequenceManager.h"
 

From f8020e3b45c8d6f458c3951d4b873a25d380c32f Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Thu, 19 Oct 2023 14:15:15 +0000
Subject: [PATCH 24/56] fix

---
 .gitignore                                       |  2 ++
 .../decoder_multihead_attention/CMakeLists.txt   | 16 ++++++++++++++++
 .../decoder_multihead_attention/array_ops.h      |  2 +-
 3 files changed, 19 insertions(+), 1 deletion(-)
 create mode 100644 src/turbomind/kernels/decoder_multihead_attention/CMakeLists.txt

diff --git a/.gitignore b/.gitignore
index ccfad036dc..1ed335fe88 100644
--- a/.gitignore
+++ b/.gitignore
@@ -72,3 +72,5 @@ work_dir*/
 *.out
 *.csv
 *.pkl
+
+!CMakeLists.txt
\ No newline at end of file
diff --git a/src/turbomind/kernels/decoder_multihead_attention/CMakeLists.txt b/src/turbomind/kernels/decoder_multihead_attention/CMakeLists.txt
new file mode 100644
index 0000000000..7176017671
--- /dev/null
+++ b/src/turbomind/kernels/decoder_multihead_attention/CMakeLists.txt
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+add_library(decoder_multihead_attention STATIC decoder_multihead_attention.cu kv_cache.cu)
+target_compile_options(decoder_multihead_attention PRIVATE
+  --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr --keep)
+set_property(TARGET decoder_multihead_attention PROPERTY POSITION_INDEPENDENT_CODE ON)
+set_property(TARGET decoder_multihead_attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
+target_link_libraries(decoder_multihead_attention PRIVATE nvidia::cutlass::cutlass)
+
+add_executable(test_decoder_multihead_attention test_utils.cu test_decoder_multihead_attention.cu)
+target_compile_options(test_decoder_multihead_attention PRIVATE
+  --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr)
+target_link_libraries(test_decoder_multihead_attention PRIVATE 
+    decoder_multihead_attention 
+    decoder_masked_multihead_attention
+    cublas)
diff --git a/src/turbomind/kernels/decoder_multihead_attention/array_ops.h b/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
index eff27b2fda..a847ada855 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
@@ -22,7 +22,7 @@ template<typename T>
 struct minus {
     __device__ T operator()(T a, T b)
     {
-        return a + b;
+        return a - b;
     }
 };
 

From 90f5b8fef7009a2cccf2993216da5c3d32db3fb1 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Fri, 20 Oct 2023 05:07:41 +0000
Subject: [PATCH 25/56] dispatch `cp.async`

---
 .../decoder_multihead_attention.cu            | 52 +++++++++++--------
 .../decoder_multihead_attention_params.h      |  1 +
 .../decoder_multihead_attention/iterator.h    | 23 ++++++--
 .../test_decoder_multihead_attention.cu       |  2 +
 .../llama/LlamaDecoderSelfAttentionLayer.cc   |  3 +-
 .../llama/LlamaDecoderSelfAttentionLayer.h    |  3 ++
 6 files changed, 57 insertions(+), 27 deletions(-)

diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
index b83de12bd3..709db6ebc0 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
@@ -2,6 +2,7 @@
 
 #include "decoder_multihead_attention_template.h"
 #include "src/turbomind/models/llama/llama_utils.h"
+#include "src/turbomind/utils/cuda_utils.h"
 
 #include <iostream>
 
@@ -34,39 +35,46 @@ bool Print(size_t dynamic_smem_size)
 template<typename T, typename Tkv, int HeadDim, int HeadPerCta>
 void invokeDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& params)
 {
-    // cpasync_2048_32x6 ~ 64k smem
-    // using MHAType = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 2048, 6>;
+    auto invoke = [&](auto* type) {
+        using Attn = std::remove_reference_t<decltype(*type)>;
 
-    using MHAType = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 1024, 5, true>;
+        static const size_t kDynSmemSize = Attn::GetDynamicSmemSize();
 
-    // ld_kv16_2048_32x3 ~ 34k smem
-    // using MHAType = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 2048, 3>;
+        [[maybe_unused]] static const bool _ = Print<Attn>(kDynSmemSize);
 
-    // ld_kv8_2048_64x3 ~ 34k smem
-    // using MHAType = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HedDim, 64, HeadDim, 2048, 3>;
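+        // never split into more parts than there are kSliceLen-sized slices in the longest sequence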
+        const int slice_count = (params.max_seq_len + Attn::kSliceLen - 1) / Attn::kSliceLen;
+        const int max_split_k = std::min(params.max_split_k, std::max(1, slice_count));
 
-    static const size_t kDynSmemSize = MHAType::GetDynamicSmemSize();
+        dim3 block(Attn::kWarpCount * WARP_SIZE);
+        dim3 grid(params.num_heads / HeadPerCta, params.batch_size, max_split_k);
 
-    [[maybe_unused]] static const bool _ = Print<MHAType>(kDynSmemSize);
+        // if (params.layer_offset == 0) {
+        //     std::cout << "max_split_k' = " << max_split_k << ", arch = " << params.arch << "\n";
+        // }
 
-    const int slice_count = (params.max_seq_len + MHAType::kSliceLen - 1) / MHAType::kSliceLen;
-    const int max_split_k = std::min(params.max_split_k, std::max(1, slice_count));
+        cudaFuncSetAttribute(
+            decoder_multihead_attention<Attn>, cudaFuncAttributeMaxDynamicSharedMemorySize, kDynSmemSize);
 
-    dim3 block(MHAType::kWarpCount * WARP_SIZE);
-    dim3 grid(params.num_heads / HeadPerCta, params.batch_size, max_split_k);
+        decoder_multihead_attention<Attn><<<grid, block, kDynSmemSize, params.stream>>>(params);
 
-    // if (params.layer_offset == 0) {
-    //     std::cout << "max_split_k' = " << max_split_k << "\n";
-    // }
+        if (max_split_k > 1) {
+            dim3 grid(params.num_heads, params.batch_size);
+            decoder_multihead_attention_reduce<Attn><<<grid, block, 0, params.stream>>>(params);
+        }
+    };
 
-    cudaFuncSetAttribute(
-        decoder_multihead_attention<MHAType>, cudaFuncAttributeMaxDynamicSharedMemorySize, kDynSmemSize);
+    if (params.arch >= 80) {
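+        // sm_80+: cp.async is available, so a deeper software pipeline is used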
+        // DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 2048, 6>;  // 64k
 
-    decoder_multihead_attention<MHAType><<<grid, block, kDynSmemSize, params.stream>>>(params);
+        using Type = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 1024, 5, true>;
+        invoke((Type*)0);
+    }
+    else {
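+        // pre-sm_80: cp.async is unavailable, use a shallower pipeline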
+        // DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 2048, 3>; // 34k
+        // DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 64, HeadDim, 2048, 3>;  // 34k
 
-    if (max_split_k > 1) {
-        dim3 grid(params.num_heads, params.batch_size);
-        decoder_multihead_attention_reduce<MHAType><<<grid, block, 0, params.stream>>>(params);
+        using Type = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 64, HeadDim, 1024, 3, true>;
+        invoke((Type*)0);
     }
 }
 
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
index 993001ee89..5f18b45216 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
@@ -63,6 +63,7 @@ struct DecoderMultiHeadAttentionParams {
     float* partial_M;
     float* partial_L;
 
+    int          arch;
     cudaStream_t stream;
 };
 
diff --git a/src/turbomind/kernels/decoder_multihead_attention/iterator.h b/src/turbomind/kernels/decoder_multihead_attention/iterator.h
index 006d1e5cc6..5e0ba7f885 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/iterator.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/iterator.h
@@ -7,6 +7,12 @@
 
 namespace turbomind {
 
+#if (__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 4)
+#define L2_CACHEHINT(size) ".L2::" #size "B"
+#else
+#define L2_CACHEHINT(size)
+#endif
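+// the ".L2::<size>B" cache hint above requires CUDA 11.4 or newer; older toolkits simply omit it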
+
 struct BlockIterator {
     const void** ptrs_;
     const void*  prefetch_;
@@ -256,15 +262,20 @@ struct Iterator {
     {
         const int     smem_int_ptr = cast_smem_ptr_to_uint(dst);
         constexpr int cp_size      = sizeof(AccessType);
-        // static_assert(cp_size == 16);
+#if TURBOMIND_ARCH_SM80
+        // clang-format off
         asm volatile("{\n"
                      "  .reg .pred p;\n"
                      "  setp.ne.b32 p, %0, 0;\n"
-                     "  @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;\n"
+                     "  @p cp.async.ca.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;\n"
                      "}\n" ::"r"((int)mask),
                      "r"(smem_int_ptr),
                      "l"(src),
                      "n"(cp_size));
+        // clang-format on
+#else
+        assert(TURBOMIND_ARCH_SM80);
+#endif
     }
 
     static __device__ void Copy(T* __restrict__ dst, const T* __restrict__ src, bool mask)
@@ -276,8 +287,12 @@ struct Iterator {
 
     __device__ void Prefetch(bool mask)
     {
-        CpAsync(smem_ + dst_offset_, src_ + src_offset_, mask);
-        // Copy(smem_ + dst_offset_, src_ + src_offset_, mask);
+        if constexpr (TURBOMIND_ARCH_SM80) {
+            CpAsync(smem_ + dst_offset_, src_ + src_offset_, mask);
+        }
+        else {
+            Copy(smem_ + dst_offset_, src_ + src_offset_, mask);
+        }
     }
 
     __device__ void Load(AccessType (&frag)[ThreadMap::kIterC])
diff --git a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
index 912bff1ae4..b5249f31c2 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
@@ -252,6 +252,8 @@ int main(int argc, char* argv[])
     params.max_split_k = kMaxSplitK;
     params.max_seq_len = kContextLen;
 
+    params.arch = 80;
+
     std::vector<thrust::universal_vector<half>> outputs;
 
     for (int i = 0; i < std::max(kTestIter, 10); ++i) {
diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
index ba31d90a0d..78ced5dff8 100644
--- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
@@ -151,7 +151,7 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
 
     const float avg_batch_size = max_seq_len ? (float)sum_seq_len / max_seq_len : 1;
     FT_CHECK(avg_batch_size >= 1.f);
-    
+
     const int max_split_k = std::max(1, (int)std::ceil(kMaxSplitK / avg_batch_size));
 
     // if (layer_id == 0) {
@@ -161,6 +161,7 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
     params.max_split_k = max_split_k;
     params.max_seq_len = max_seq_len;
 
+    params.arch   = arch_;
     params.stream = stream_;
 
     params.quant_policy = quant_policy_;
diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h
index 9f80edc462..95556cd30b 100644
--- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h
+++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h
@@ -24,6 +24,7 @@
 #include "src/turbomind/models/llama/LlamaLinear.h"
 #include "src/turbomind/models/llama/llama_params.h"
 #include "src/turbomind/utils/Tensor.h"
+#include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/nccl_utils.h"
 
 namespace turbomind {
@@ -60,6 +61,7 @@ class LlamaDecoderSelfAttentionLayer {
         is_free_buffer_after_forward_(is_free_buffer_after_forward),
         quant_policy_(quant_policy)
     {
+        arch_ = getSMVersion();
     }
 
     ~LlamaDecoderSelfAttentionLayer()
@@ -96,6 +98,7 @@ class LlamaDecoderSelfAttentionLayer {
     float*               workspace_ = nullptr;
 
     bool is_allocate_buffer_{};
+    int  arch_{};
 };
 
 }  // namespace turbomind

From 0fe3ab9dbb624ee44ed7fdb3998a7cd2fc4970cc Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Fri, 20 Oct 2023 06:35:45 +0000
Subject: [PATCH 26/56] fix linking

---
 CMakeLists.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 80c767d1a2..f3f1c7b171 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -343,6 +343,7 @@ add_library(transformer-shared SHARED
   $<TARGET_OBJECTS:cuda_utils>
   $<TARGET_OBJECTS:custom_ar_comm>
   $<TARGET_OBJECTS:custom_ar_kernels>
+  $<TARGET_OBJECTS:decoder_multihead_attention>
   $<TARGET_OBJECTS:decoder_masked_multihead_attention>
   $<TARGET_OBJECTS:decoding_kernels>
   $<TARGET_OBJECTS:gpt_kernels>

From 333ce08477a298308a5bee84697e995ae81d4f75 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Fri, 20 Oct 2023 09:07:35 +0000
Subject: [PATCH 27/56] fix

---
 src/turbomind/models/llama/LlamaBatch.cc      | 12 +++++++----
 src/turbomind/models/llama/SequenceManager.cc | 20 +++++++++++++------
 2 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 3f8593421d..35eed403fe 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -265,8 +265,10 @@ bool LlamaBatch<T>::Initialize()
         }
     }
 
+    dbg(holes, active_holes);
+
     auto process = [&](BatchState* state) {
-        // dbg(state->size);
+        dbg(state->size);
         for (int i = 0; i < state->size; ++i) {
             if (auto& r = state->requests[i]) {
                 sequences.push_back(state->sequences[i]);
@@ -289,9 +291,11 @@ bool LlamaBatch<T>::Initialize()
     // dbg(step_length_);
 
     auto outcome = sequence_manager_->Materialize(sequences, context_lengths, priorities, step_length_);
-    if (outcome.allocation || outcome.swap_in || outcome.swap_out) {
-        dbg(outcome);
-    }
+
+    dbg(outcome);
+    // if (outcome.allocation || outcome.swap_in || outcome.swap_out) {
+    //     dbg(outcome);
+    // }
 
     bool exchange = outcome.swap_in + outcome.swap_out > 0;
 
diff --git a/src/turbomind/models/llama/SequenceManager.cc b/src/turbomind/models/llama/SequenceManager.cc
index 7e82db822c..daff172a15 100644
--- a/src/turbomind/models/llama/SequenceManager.cc
+++ b/src/turbomind/models/llama/SequenceManager.cc
@@ -261,6 +261,7 @@ auto SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
                                   const std::vector<uint64_t>&        priorities,
                                   int                                 step_length) -> Outcome
 {
+    dbg(__PRETTY_FUNCTION__);
     ////////////////////////////////////////////////////////////////////////////////
     /// Schedule the assignment of blocks to sequences
     auto    seqs = const_cast<Sequence* const*>(sequences.data());
@@ -285,20 +286,21 @@ auto SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
 
     // count required blocks based on block validity
     std::vector<int> required(sequences.size());
-    int              total_required{};
+    // int              total_required{};
     for (int i = 0; i < sequences.size(); ++i) {
         int seq_len = context_lengths[i] + step_length;
         int count   = (seq_len + block_seq_len_ - 1) / block_seq_len_ - static_cast<int>(seqs[i]->blocks.size());
         required[i] = std::max(0, count);
-        total_required += required[i];
+        // total_required += required[i];
     }
 
     // dbg(required);
 
     // no new blocks required, exit early
-    if (total_required == 0) {
-        return outcome;
-    }
+    // if (total_required == 0) {
+    //     dbg("early exit");
+    //     return outcome;
+    // }
 
     /// TODO: more early exit heuristics
 
@@ -396,7 +398,13 @@ auto SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
     auto first  = blocks.begin();
 
     for (const auto& idx : schedule.active) {
-        auto& sequence  = *seqs[idx];
+        auto& sequence = *seqs[idx];
+
+        // retain blocks for swap-in sequences
+        if (sequence.status == Sequence::kCached) {
+            block_manager_->Retain(sequence.blocks);
+        }
+
         sequence.status = Sequence::kActive;
 
         auto last = first + required[idx];

From abaca3ec576d6be0d96e0874c74167daec0c2cc3 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Fri, 20 Oct 2023 10:21:59 +0000
Subject: [PATCH 28/56] fix deadlock

---
 src/turbomind/models/llama/LlamaV2.h | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h
index ee659b7373..198ad4b09c 100644
--- a/src/turbomind/models/llama/LlamaV2.h
+++ b/src/turbomind/models/llama/LlamaV2.h
@@ -202,10 +202,9 @@ class LlamaV2 {
     DynamicDecodeLayer<float>* dynamic_decode_layer_{};
 
     const int                      step_length_;
-    std::unique_ptr<LlamaBatch<T>> batch_;
     std::shared_ptr<SharedState>   shared_state_;
-
-    ffi_api_lock_ctrl_t ffi_lock_ = nullptr;
+    ffi_api_lock_ctrl_t            ffi_lock_;
+    std::unique_ptr<LlamaBatch<T>> batch_;
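+    // declared after `shared_state_` and `ffi_lock_` so it is destroyed before them (members are destructed in reverse declaration order)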
 };
 
 }  // namespace turbomind

From 290e0871db56e3bc4c6be625ff636d5c44d5b84a Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Fri, 20 Oct 2023 11:11:13 +0000
Subject: [PATCH 29/56] guard input length

---
 src/turbomind/models/llama/LlamaBatch.cc | 12 +++++++++++-
 src/turbomind/models/llama/Request.h     |  3 ++-
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 35eed403fe..e6e1994b77 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -47,15 +47,25 @@ void LlamaBatch<T>::RejectInvalidRequests(Requests& stop_reqs, Requests& infer_r
             if (r) {
                 int ec = 0;
 
+                const int input_length = r->inputs[rank_].getVal<int>("input_lengths", 0);
+
                 if (occurrence[r->id] != 1) {
                     ec = Request::kConflict;
                 }
                 else if (r->start_flag && r->stop_flag) {
                     ec = Request::kInvalid;
                 }
-                else if (!r->start_flag && !sequence_manager_->Contains(r->id)) {
+                else if (input_length > session_len_) {
                     ec = Request::kInvalid;
                 }
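+                // continuation request: the session must exist and its history plus the new input must fit in session_len_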
+                else if (!r->start_flag) {
+                    if (auto seq = sequence_manager_->Get(r->id); seq == nullptr) {
+                        ec = Request::kTooLong;
+                    }
+                    else if (seq->tokens.size() + input_length > session_len_) {
+                        ec = Request::kTooLong;
+                    }
+                }
 
                 if (ec) {
                     reject(type, r, ec);
diff --git a/src/turbomind/models/llama/Request.h b/src/turbomind/models/llama/Request.h
index 2970ff247c..a33fdd9ca1 100644
--- a/src/turbomind/models/llama/Request.h
+++ b/src/turbomind/models/llama/Request.h
@@ -32,7 +32,8 @@ struct Request {
         kConflict = 2,
         kBusy     = 3,
         kInactive = 4,
-        kFail     = 5
+        kFail     = 5,
+        kTooLong  = 6
     };
     std::promise<int> signal;
 };

From ca700336ff9d285dcd7d95cf0678ba7365cccae3 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Fri, 20 Oct 2023 11:26:39 +0000
Subject: [PATCH 30/56] correct start offset

---
 src/turbomind/models/llama/LlamaBatch.cc | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index e6e1994b77..2a4c3f7d94 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -47,7 +47,10 @@ void LlamaBatch<T>::RejectInvalidRequests(Requests& stop_reqs, Requests& infer_r
             if (r) {
                 int ec = 0;
 
-                const int input_length = r->inputs[rank_].getVal<int>("input_lengths", 0);
+                const int  input_length = r->inputs[rank_].getVal<int>("input_lengths", 0);
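+                // resume offset: the user-provided "step", clamped into [0, token_count]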
+                const auto get_offset   = [&](int token_count) {
+                    return std::max(0, std::min(token_count, r->inputs[rank_].getVal<int>("step", token_count)));
+                };
 
                 if (occurrence[r->id] != 1) {
                     ec = Request::kConflict;
@@ -62,7 +65,7 @@ void LlamaBatch<T>::RejectInvalidRequests(Requests& stop_reqs, Requests& infer_r
                     if (auto seq = sequence_manager_->Get(r->id); seq == nullptr) {
                         ec = Request::kTooLong;
                     }
-                    else if (seq->tokens.size() + input_length > session_len_) {
+                    else if (get_offset(seq->tokens.size()) + input_length > session_len_) {
                         ec = Request::kTooLong;
                     }
                 }

From 32037fd464f89fa10a3ea5c9693e6f49dffb2fb6 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Fri, 20 Oct 2023 13:23:55 +0000
Subject: [PATCH 31/56] fix prefill chunking

---
 src/turbomind/models/llama/LlamaBatch.cc | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 2a4c3f7d94..04f6fde775 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -964,9 +964,9 @@ void LlamaBatch<T>::ContextDecode()
         else {
             offsets.push_back(i);
             max_context_cnts.push_back(max_context_count);
-            accum_size        = 0;
-            accum_input_count = 0;
-            max_context_count = 0;
+            accum_size        = 1;
+            accum_input_count = h_input_length_buf_[i];
+            max_context_count = state_->h_context_length[i] - 1;
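+            // seed the next chunk with the current sequence instead of dropping it from the accumulators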
         }
     }
     offsets.push_back(batch_size);
@@ -983,7 +983,9 @@ void LlamaBatch<T>::ContextDecode()
         std::vector<int> decode_lengths{};
         int              max_input_len{};
         auto             input_ids = context_decoder_ids_buf_;
+        TM_LOG_INFO("first = %d, last = %d", first, last);
         for (int i = first; i < last; ++i) {
+            TM_LOG_INFO("session_len = %d, input_length = %d", session_len_, h_input_length_buf_[i]);
             input_ids = Copy(input_ids_buf_ + i * session_len_, h_input_length_buf_[i], input_ids);
             dbg(i, h_input_length_buf_[i]);
             h_tmp_k_ptrs_[i] = k_ptr;

From 03138666f74c3b28682a0ce469ae7cfd3a951172 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Sat, 21 Oct 2023 03:20:52 +0000
Subject: [PATCH 32/56] fix `cache_block_seq_len` param passing

---
 src/turbomind/models/llama/LlamaContextAttentionLayer.h  | 3 ++-
 src/turbomind/models/llama/LlamaContextDecoder.cc        | 5 ++++-
 src/turbomind/models/llama/LlamaContextDecoder.h         | 7 ++++++-
 src/turbomind/models/llama/LlamaDecoder.cc               | 9 +++++++--
 src/turbomind/models/llama/LlamaDecoder.h                | 4 +++-
 .../models/llama/LlamaDecoderSelfAttentionLayer.h        | 3 ++-
 src/turbomind/models/llama/LlamaV2.cc                    | 5 ++++-
 src/turbomind/models/llama/LlamaV2.h                     | 7 +++++--
 8 files changed, 33 insertions(+), 10 deletions(-)

diff --git a/src/turbomind/models/llama/LlamaContextAttentionLayer.h b/src/turbomind/models/llama/LlamaContextAttentionLayer.h
index 67cbfa3e30..5e88255579 100644
--- a/src/turbomind/models/llama/LlamaContextAttentionLayer.h
+++ b/src/turbomind/models/llama/LlamaContextAttentionLayer.h
@@ -45,6 +45,7 @@ class LlamaContextAttentionLayer {
                                IAllocator*          allocator,
                                bool                 is_free_buffer_after_forward,
                                bool                 use_fmha,
+                               int                  cache_block_seq_len,
                                int                  quant_policy):
         head_num_(head_num),
         size_per_head_(size_per_head),
@@ -58,7 +59,7 @@ class LlamaContextAttentionLayer {
         cublas_wrapper_(cublas_wrapper),
         linear_(cublas_wrapper, stream),
         allocator_(allocator),
-        kv_cache_block_len_(128), /// 
+        kv_cache_block_len_(cache_block_seq_len),
         is_free_buffer_after_forward_(is_free_buffer_after_forward),
         use_fmha_(use_fmha),
         quant_policy_(quant_policy)
diff --git a/src/turbomind/models/llama/LlamaContextDecoder.cc b/src/turbomind/models/llama/LlamaContextDecoder.cc
index fb1e2e8c79..2047ffa050 100644
--- a/src/turbomind/models/llama/LlamaContextDecoder.cc
+++ b/src/turbomind/models/llama/LlamaContextDecoder.cc
@@ -66,6 +66,7 @@ template<typename T>
 void LlamaContextDecoder<T>::initialize(const LlamaAttentionParams& attn_params,
                                         size_t                      kv_head_num,
                                         bool                        use_fmha,
+                                        int                         cache_block_seq_len,
                                         int                         quant_policy)
 {
     h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
@@ -80,6 +81,7 @@ void LlamaContextDecoder<T>::initialize(const LlamaAttentionParams& attn_params,
                                                                  allocator_,
                                                                  is_free_buffer_after_forward_,
                                                                  use_fmha,
+                                                                 cache_block_seq_len,
                                                                  quant_policy);
 
     silu_ffn_layer_ = new LlamaFfnLayer<T>(head_num_,
@@ -140,6 +142,7 @@ LlamaContextDecoder<T>::LlamaContextDecoder(size_t                      head_num
                                             IAllocator*                 allocator,
                                             bool                        is_free_buffer_after_forward,
                                             bool                        use_fmha,
+                                            int                         cache_block_seq_len,
                                             int                         quant_policy):
     BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward),
     head_num_(head_num),
@@ -151,7 +154,7 @@ LlamaContextDecoder<T>::LlamaContextDecoder(size_t                      head_num
     tensor_para_(tensor_para),
     data_type_(getTensorType<T>())
 {
-    initialize(attn_params, kv_head_num, use_fmha, quant_policy);
+    initialize(attn_params, kv_head_num, use_fmha, cache_block_seq_len, quant_policy);
 }
 
 template<typename T>
diff --git a/src/turbomind/models/llama/LlamaContextDecoder.h b/src/turbomind/models/llama/LlamaContextDecoder.h
index 6750614c5e..9be0ba7025 100644
--- a/src/turbomind/models/llama/LlamaContextDecoder.h
+++ b/src/turbomind/models/llama/LlamaContextDecoder.h
@@ -40,7 +40,11 @@ class LlamaContextDecoder: public BaseLayer {
     void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len);
     void freeBuffer() override;
 
-    void initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, bool use_fmha, int quant_policy);
+    void initialize(const LlamaAttentionParams& attn_params,
+                    size_t                      kv_head_num,
+                    bool                        use_fmha,
+                    int                         cache_block_seq_len,
+                    int                         quant_policy);
 
     size_t head_num_;
     size_t size_per_head_;
@@ -94,6 +98,7 @@ class LlamaContextDecoder: public BaseLayer {
                         IAllocator*                 allocator,
                         bool                        is_free_buffer_after_forward,
                         bool                        use_fmha,
+                        int                         cache_block_seq_len,
                         int                         quant_policy);
 
     ~LlamaContextDecoder() override;
diff --git a/src/turbomind/models/llama/LlamaDecoder.cc b/src/turbomind/models/llama/LlamaDecoder.cc
index 6d633cd829..071e2e0e0c 100644
--- a/src/turbomind/models/llama/LlamaDecoder.cc
+++ b/src/turbomind/models/llama/LlamaDecoder.cc
@@ -41,6 +41,7 @@ LlamaDecoder<T>::LlamaDecoder(size_t                      head_num,
                               cublasMMWrapper*            cublas_wrapper,
                               IAllocator*                 allocator,
                               bool                        is_free_buffer_after_forward,
+                              int                         cache_block_seq_len,
                               int                         quant_policy):
     BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward),
     head_num_(head_num),
@@ -53,7 +54,7 @@ LlamaDecoder<T>::LlamaDecoder(size_t                      head_num,
     data_type_(getTensorType<T>())
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-    initialize(attn_params, kv_head_num, quant_policy);
+    initialize(attn_params, kv_head_num, cache_block_seq_len, quant_policy);
 }
 
 template<typename T>
@@ -65,7 +66,10 @@ LlamaDecoder<T>::~LlamaDecoder()
 }
 
 template<typename T>
-void LlamaDecoder<T>::initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, int quant_policy)
+void LlamaDecoder<T>::initialize(const LlamaAttentionParams& attn_params,
+                                 size_t                      kv_head_num,
+                                 int                         cache_block_seq_len,
+                                 int                         quant_policy)
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
 
@@ -78,6 +82,7 @@ void LlamaDecoder<T>::initialize(const LlamaAttentionParams& attn_params, size_t
                                                                   cublas_wrapper_,
                                                                   allocator_,
                                                                   is_free_buffer_after_forward_,
+                                                                  cache_block_seq_len,
                                                                   quant_policy);
 
     silu_ffn_layer_ = new LlamaFfnLayer<T>(head_num_,
diff --git a/src/turbomind/models/llama/LlamaDecoder.h b/src/turbomind/models/llama/LlamaDecoder.h
index e12214617b..70b1ee2706 100644
--- a/src/turbomind/models/llama/LlamaDecoder.h
+++ b/src/turbomind/models/llama/LlamaDecoder.h
@@ -35,7 +35,8 @@ class LlamaDecoder: public BaseLayer {
     void allocateBuffer() override;  // deprecated
     void allocateBuffer(size_t batch_size);
     void freeBuffer() override;
-    void initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, int quant_policy);
+    void
+    initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, int cache_block_seq_len, int quant_policy);
 
     size_t head_num_;
     size_t size_per_head_;
@@ -78,6 +79,7 @@ class LlamaDecoder: public BaseLayer {
                  cublasMMWrapper*            cublas_wrapper,
                  IAllocator*                 allocator,
                  bool                        is_free_buffer_after_forward,
+                 int                         cache_block_seq_len,
                  int                         quant_policy);
 
     ~LlamaDecoder() override;
diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h
index 95556cd30b..bcb538fb85 100644
--- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h
+++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h
@@ -44,6 +44,7 @@ class LlamaDecoderSelfAttentionLayer {
                                    cublasMMWrapper*            cublas_wrapper,
                                    IAllocator*                 allocator,
                                    bool                        is_free_buffer_after_forward,
+                                   int                         cache_block_seq_len,
                                    int                         quant_policy):
         head_num_(head_num),
         kv_head_num_(kv_head_num),
@@ -57,7 +58,7 @@ class LlamaDecoderSelfAttentionLayer {
         stream_(stream),
         linear_(cublas_wrapper, stream),
         allocator_(allocator),
-        kv_cache_block_len_(128),  ///
+        kv_cache_block_len_(cache_block_seq_len),
         is_free_buffer_after_forward_(is_free_buffer_after_forward),
         quant_policy_(quant_policy)
     {
diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc
index 070fd3fa65..53160c8ede 100644
--- a/src/turbomind/models/llama/LlamaV2.cc
+++ b/src/turbomind/models/llama/LlamaV2.cc
@@ -136,7 +136,7 @@ LlamaV2<T>::LlamaV2(size_t                       head_num,
     batch_ = std::make_unique<LlamaBatch<T>>(
         max_batch_size, max_context_token_num, session_len, std::move(sequence_manager), this);
 
-    initialize(attn_params, kv_head_num, use_context_fmha, quant_policy);
+    initialize(attn_params, kv_head_num, use_context_fmha, cache_block_seq_len, quant_policy);
 
     /// TODO: decouple Llama model and batch inference
     batch_->Start();
@@ -154,6 +154,7 @@ template<typename T>
 void LlamaV2<T>::initialize(const LlamaAttentionParams& attn_params,
                             size_t                      kv_head_num,
                             bool                        use_context_fmha,
+                            int                         cache_block_seq_len,
                             int                         quant_policy)
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
@@ -171,6 +172,7 @@ void LlamaV2<T>::initialize(const LlamaAttentionParams& attn_params,
                                                   allocator_,
                                                   is_free_buffer_after_forward_,
                                                   use_context_fmha,
+                                                  cache_block_seq_len,
                                                   quant_policy);
 
     decoder_ = new LlamaDecoder<T>(head_num_,
@@ -185,6 +187,7 @@ void LlamaV2<T>::initialize(const LlamaAttentionParams& attn_params,
                                    cublas_wrapper_,
                                    allocator_,
                                    is_free_buffer_after_forward_,
+                                   cache_block_seq_len,
                                    quant_policy);
 
     dynamic_decode_layer_ = new DynamicDecodeLayer<float>(vocab_size_,
diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h
index 198ad4b09c..99d5352746 100644
--- a/src/turbomind/models/llama/LlamaV2.h
+++ b/src/turbomind/models/llama/LlamaV2.h
@@ -104,8 +104,11 @@ class LlamaV2 {
 private:
     friend class Batch;
 
-    void
-    initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, bool use_context_fmha, int quant_policy);
+    void initialize(const LlamaAttentionParams& attn_params,
+                    size_t                      kv_head_num,
+                    bool                        use_context_fmha,
+                    int                         cache_block_seq_len,
+                    int                         quant_policy);
 
     void embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step);
 

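The hunks above thread a `cache_block_seq_len` argument from `LlamaV2` down through the decoders into both attention layers, replacing the hard-coded block length of 128. A minimal, hypothetical sketch of the resulting wiring (constructors heavily simplified; the real classes take many more parameters):

    // Illustrative only: the block length now flows in from the model instead of
    // being fixed inside the attention layer.
    struct LlamaDecoderSelfAttentionLayer {
        explicit LlamaDecoderSelfAttentionLayer(int cache_block_seq_len):
            kv_cache_block_len_(cache_block_seq_len)  // was kv_cache_block_len_(128)
        {
        }
        int kv_cache_block_len_;
    };

    struct LlamaV2 {
        explicit LlamaV2(int cache_block_seq_len): self_attn_(cache_block_seq_len) {}
        LlamaDecoderSelfAttentionLayer self_attn_;
    };
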
From b70a4f6f60e8911abbb6b87f3ddb5a31246d0e6c Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Sat, 21 Oct 2023 04:11:25 +0000
Subject: [PATCH 33/56] fix `block_size` fmtstr

---
 src/turbomind/models/llama/BlockManager.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/turbomind/models/llama/BlockManager.cc b/src/turbomind/models/llama/BlockManager.cc
index 7baec4f378..3efa8901dd 100644
--- a/src/turbomind/models/llama/BlockManager.cc
+++ b/src/turbomind/models/llama/BlockManager.cc
@@ -30,7 +30,7 @@ BlockManager::BlockManager(size_t block_size, double block_count, int chunk_size
         chunk_size_ = chunk_size;
     }
 
-    TM_LOG_INFO("[BlockManager] block_size = %d", block_size_);
+    TM_LOG_INFO("[BlockManager] block_size = %lu MB", (unsigned long)block_size_ >> 20);
     TM_LOG_INFO("[BlockManager] max_block_count = %d", max_block_count_);
     TM_LOG_INFO("[BlockManager] chunk_size = %d", chunk_size_);
 

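For reference, a standalone sketch of the formatting fix above, assuming (as the new format string suggests) that `block_size_` is a byte count stored in a `size_t`: printing it with `%d` was a format mismatch, and the raw byte count was hard to read; casting to `unsigned long` and shifting right by 20 prints the size in MiB.

    #include <cstddef>
    #include <cstdio>

    int main()
    {
        std::size_t block_size = 132u << 20;  // e.g. a 132 MiB block
        // matches the TM_LOG_INFO format in the hunk above
        std::printf("[BlockManager] block_size = %lu MB\n", (unsigned long)block_size >> 20);
        return 0;
    }
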
From 2290461833b88b6b08795250c6a4011be8aa50b0 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Sat, 21 Oct 2023 17:11:26 +0000
Subject: [PATCH 34/56] fix output tokens

---
 src/turbomind/models/llama/llama_kernels.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/turbomind/models/llama/llama_kernels.cu b/src/turbomind/models/llama/llama_kernels.cu
index 9abcfc9d0d..7adc34ff51 100644
--- a/src/turbomind/models/llama/llama_kernels.cu
+++ b/src/turbomind/models/llama/llama_kernels.cu
@@ -574,7 +574,7 @@ __global__ void updateOutput(int**      request_output_ids_ptrs,
 
     output_ids += max_session_len * batch_id;
 
-    const int seqlen     = sequence_lengths[batch_id];
+    const int seqlen     = sequence_lengths[batch_id] + (int)token_generated;
     const int output_len = min(seqlen, request_output_ids_lens[batch_id]);
 
     for (int i = threadIdx.x; i < output_len; i += blockDim.x) {

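A simplified sketch of the off-by-one fixed above, assuming (as the kernel suggests) that `sequence_lengths` counts the tokens already in the cache and excludes the token generated in the current step; the helper name is illustrative.

    // token_generated is true once decoding has produced a token this step;
    // the freshly generated token must then be counted toward the output.
    int visible_output_len(int sequence_length, bool token_generated, int request_output_len)
    {
        const int seqlen = sequence_length + static_cast<int>(token_generated);
        return seqlen < request_output_len ? seqlen : request_output_len;
    }
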
From 66fa64becb5cb7887ada1959bc58f57dae6ed1e4 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Mon, 23 Oct 2023 06:36:36 +0000
Subject: [PATCH 35/56] fix batch resizing

---
 src/turbomind/models/llama/LlamaBatch.cc | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 04f6fde775..2edd23d938 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -321,6 +321,7 @@ bool LlamaBatch<T>::Initialize()
             return sequences[idx]->status == Sequence::kActive;  // present status
         });
 
+        // even the entire block pool is not enough to hold a single sequence
         FT_CHECK_WITH_INFO(active_end != idxs.begin(), "No enough blocks.");
 
         // move swap-ins to the back
@@ -364,7 +365,7 @@ bool LlamaBatch<T>::Initialize()
 
     const int batch_size = state_->active_size;
 
-    if (exchange || outcome.allocation) {
+    if (exchange || active_holes || outcome.allocation) {
         // Prepare intermediate buffers
         h_cu_block_counts_[0] = 0;
 
@@ -385,10 +386,10 @@ bool LlamaBatch<T>::Initialize()
             });
         }
 
-        if (1) {
-            std::vector cu_block_cnts(h_cu_block_counts_, h_cu_block_counts_ + batch_size + 1);
-            dbg(cu_block_cnts);
-        }
+        // if (1) {
+        //     std::vector cu_block_cnts(h_cu_block_counts_, h_cu_block_counts_ + batch_size + 1);
+        //     dbg(cu_block_cnts);
+        // }
         // dbg(std::vector(h_k_block_ptrs_, h_k_block_ptrs_ + h_cu_block_counts_[batch_size]));
         // dbg(std::vector(h_v_block_ptrs_, h_v_block_ptrs_ + h_cu_block_counts_[batch_size]));
         // dbg(h_cu_block_counts_[batch_size]);
@@ -418,7 +419,7 @@ void LlamaBatch<T>::CopyState(const std::pair<BatchState*, int> _src, const std:
     dst->h_finished[j]       = src->h_finished[i];
     dst->seq_len_limit[j]    = src->seq_len_limit[i];
     dst->sequences[j]        = src->sequences[i];
-    dst->is_swap_in[i]       = src->is_swap_in[i];
+    dst->is_swap_in[j]       = src->is_swap_in[i];
     dst->requests[j]         = std::move(src->requests[i]);
 
     Copy(src->output_ids + i * session_len_, src->h_context_length[i], dst->output_ids + j * session_len_);
@@ -972,6 +973,8 @@ void LlamaBatch<T>::ContextDecode()
     offsets.push_back(batch_size);
     max_context_cnts.push_back(max_context_count);
 
+    dbg(offsets, max_context_cnts);
+
     // context decode on sub-batches
     for (int k = 0; k < offsets.size() - 1; ++k) {
         int              first          = offsets[k];
@@ -1111,7 +1114,11 @@ auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
 
     // secure info needed by `Initialize()`
     Copy(finished_buf_, batch_size, state_->h_finished);
+
+    // invariant: context_length = sequence_length + 1
+    invokePlusScalar(sequence_lengths_, 1, batch_size, stream_);
     Copy(sequence_lengths_, batch_size, state_->h_context_length);
+    invokePlusScalar(sequence_lengths_, -1, batch_size, stream_);
 
     if constexpr (0) {
         std::unique_lock<std::mutex> lock;

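The last hunk above temporarily adds 1 on the device, copies the result into `h_context_length`, then subtracts 1 again, so the host copy reflects the stated invariant while the device buffer is left unchanged. A small host-side sketch of that invariant, using hypothetical host mirrors of the two buffers:

    #include <cstddef>
    #include <vector>

    // Returns true when every slot satisfies context_length == sequence_length + 1.
    bool check_length_invariant(const std::vector<int>& h_context_length,
                                const std::vector<int>& h_sequence_length)
    {
        for (std::size_t i = 0; i < h_context_length.size(); ++i) {
            if (h_context_length[i] != h_sequence_length[i] + 1) {
                return false;
            }
        }
        return true;
    }
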
From 18001cddcb267347cd7fb7ec7ff669b2612b18b4 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Mon, 23 Oct 2023 07:08:22 +0000
Subject: [PATCH 36/56] fix masking of finished sequences

---
 src/turbomind/models/llama/LlamaBatch.cc | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 2edd23d938..fa9a148a0e 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -770,8 +770,9 @@ auto LlamaBatch<T>::InitializeGeneration() -> GenerationState
     // for
     for (int i = 0; i < batch_size; ++i) {
         h_seq_limit_len_[i] = state_->seq_len_limit[i] + (max_context_len - state_->h_context_length[i]);
-        // mask finished sequences
-        state_->h_finished[i] = max_context_len >= h_seq_limit_len_[i];
+        if (max_context_len >= h_seq_limit_len_[i]) {  // mask finished sequences
+            state_->h_finished[i] = true;
+        }
     }
     Copy(h_seq_limit_len_, batch_size, seq_limit_len_);
     Copy(state_->h_finished, batch_size, finished_buf_);

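The behavioural difference of the hunk above, in isolation: the old assignment could overwrite a `finished` flag that was already set elsewhere (e.g. by an end-of-sequence token), whereas the new form only ever sets it. A minimal sketch with illustrative names:

    void mask_old(bool& finished, int max_context_len, int seq_limit_len)
    {
        finished = max_context_len >= seq_limit_len;  // may clear a previously set flag
    }

    void mask_new(bool& finished, int max_context_len, int seq_limit_len)
    {
        if (max_context_len >= seq_limit_len) {
            finished = true;  // sticky: never cleared here
        }
    }
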
From 8705131dfdfbf5627700351a53931ecfacc005d6 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Mon, 23 Oct 2023 10:09:34 +0000
Subject: [PATCH 37/56] add debug util

---
 src/turbomind/models/llama/LlamaBatch.cc      |  5 ++++
 .../llama/LlamaDecoderSelfAttentionLayer.cc   | 28 ++++++++++++++++++-
 src/turbomind/models/llama/llama_utils.cu     |  9 ++++++
 src/turbomind/models/llama/llama_utils.h      |  2 ++
 4 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index fa9a148a0e..e1c1d9fcad 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -813,6 +813,11 @@ auto LlamaBatch<T>::InitializeGeneration() -> GenerationState
                         (int)state_->h_finished[i]);
         }
     }
+
+    // for (int i = 0; i < batch_size; ++i) {
+    //     gSequenceIds(i) = state_->requests[i]->id;
+    // }
+
     return GenerationState{max_context_len, start_step, sum_seq_len, max_seq_len};
 }
 
diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
index 78ced5dff8..d411f3f412 100644
--- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
@@ -98,13 +98,29 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
 
     const int layer_id = input_tensors->getVal<int>("layer_id");
 
-    // const int step        = input_tensors->getVal<int>("step");
+    // const int step = input_tensors->getVal<int>("step");
     // const int step_1 = step - 1;
 
     const int batch_size = input_tensors->at("input_query").shape[0];
 
     allocateBuffer(batch_size);
 
+    // std::vector<int> seqlens(batch_size);
+    // check_cuda_error(
+    //     cudaMemcpyAsync(seqlens.data(), sequence_lengths_data, sizeof(int) * batch_size, cudaMemcpyDefault,
+    //     stream_));
+    // check_cuda_error(cudaStreamSynchronize(stream_));
+
+    // for (int i = 0; i < batch_size; ++i) {
+    //     if (gSequenceIds(i) == 1) {
+    //         Compare((T*)input_query_data + hidden_units_ * i,
+    //                 hidden_units_,
+    //                 Concat("query", gSequenceIds(i), seqlens[i], layer_id),
+    //                 compare_mode,
+    //                 stream_);
+    //     }
+    // }
+
     {
         NvtxScope scope("qkv_gemm");
         linear_.forward(qkv_buf_, input_query_data, batch_size, weights->qkv);
@@ -172,6 +188,16 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
         DispatchDecoderMultiheadAttention<T>(params);
     }
 
+    // for (int i = 0; i < batch_size; ++i) {
+    //     if (gSequenceIds(i) == 1) {
+    //         Compare((T*)context_buf_ + hidden_units_ * i,
+    //                 hidden_units_,
+    //                 Concat("context_buf", gSequenceIds(i), seqlens[i], layer_id),
+    //                 compare_mode,
+    //                 stream_);
+    //     }
+    // }
+
     {
         NvtxScope scope("o_gemm");
         linear_.forward(hidden_features_data, context_buf_, batch_size, weights->output);
diff --git a/src/turbomind/models/llama/llama_utils.cu b/src/turbomind/models/llama/llama_utils.cu
index 7050d2d13f..93de6afd58 100644
--- a/src/turbomind/models/llama/llama_utils.cu
+++ b/src/turbomind/models/llama/llama_utils.cu
@@ -157,4 +157,13 @@ bool isDebug()
     return is_debug;
 }
 
+int64_t& gSequenceIds(int batch_idx)
+{
+    thread_local std::vector<int64_t> ids{};
+    if (batch_idx >= ids.size()) {
+        ids.resize(batch_idx + 1, -1);
+    }
+    return ids.at(batch_idx);
+}
+
 }  // namespace turbomind
diff --git a/src/turbomind/models/llama/llama_utils.h b/src/turbomind/models/llama/llama_utils.h
index 60942560d3..0e31f64c7c 100644
--- a/src/turbomind/models/llama/llama_utils.h
+++ b/src/turbomind/models/llama/llama_utils.h
@@ -77,4 +77,6 @@ struct NvtxScope {
     }
 };
 
+int64_t& gSequenceIds(int batch_idx);
+
 }  // namespace turbomind

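A hypothetical usage sketch of the new `gSequenceIds()` helper, mirroring the commented-out `Compare` calls above: record the request id of every batch slot once per step, then gate expensive per-sequence dumps on a single id of interest. Since the storage is thread_local, the ids are visible anywhere on the same thread.

    #include <cstdint>

    namespace turbomind {
    int64_t& gSequenceIds(int batch_idx);  // declared in llama_utils.h
    }

    void debug_one_sequence(int batch_size, const uint64_t* request_ids)
    {
        using turbomind::gSequenceIds;
        // remember which request occupies each batch slot
        for (int i = 0; i < batch_size; ++i) {
            gSequenceIds(i) = static_cast<int64_t>(request_ids[i]);
        }
        // later, dump tensors for request 1 only
        for (int i = 0; i < batch_size; ++i) {
            if (gSequenceIds(i) == 1) {
                // e.g. Compare(..., Concat("query", gSequenceIds(i), ...), ...);
            }
        }
    }
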
From 64de1cd693cd6d24a886b0cac8f5ceab04a97884 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Wed, 25 Oct 2023 12:10:49 +0000
Subject: [PATCH 38/56] free unused blocks early

---
 src/turbomind/models/llama/BlockManager.cc    |  83 +++-
 src/turbomind/models/llama/BlockManager.h     |  32 +-
 src/turbomind/models/llama/LlamaBatch.cc      |  36 +-
 src/turbomind/models/llama/SequenceManager.cc | 467 +++++++++---------
 src/turbomind/models/llama/SequenceManager.h  |  41 +-
 .../models/llama/test_cache_manager.cc        |   4 +-
 6 files changed, 368 insertions(+), 295 deletions(-)

diff --git a/src/turbomind/models/llama/BlockManager.cc b/src/turbomind/models/llama/BlockManager.cc
index 3efa8901dd..d04fd604b0 100644
--- a/src/turbomind/models/llama/BlockManager.cc
+++ b/src/turbomind/models/llama/BlockManager.cc
@@ -42,6 +42,7 @@ BlockManager::BlockManager(size_t block_size, double block_count, int chunk_size
 
     // pre-allocate first chunk
     Malloc();
+    dbg(free_ids_);
 }
 
 BlockManager::~BlockManager()
@@ -68,6 +69,7 @@ bool BlockManager::Malloc()
 
     for (int i = 0; i < chunk_size; ++i, ptr += block_size_) {
         auto& block     = blocks_.emplace_back();
+        block.use_count = 0;
         block.ref_count = 0;
         block.id        = (int)blocks_.size() - 1;
         block.timestamp = 0;
@@ -116,15 +118,14 @@ std::vector<const Block*> BlockManager::Allocate(int count)
         auto& block = blocks_[idx];
         FT_CHECK(is_free(block));
         block.ref_count = 1;
+        block.use_count = 1;
         block.unique_id = unique_id_++;
         ret.push_back(&block);
     }
 
     Move(free_ids_, idxs, active_ids_);
 
-    Touch(ret);
-
-    dbg("[Allocate]", free_ids_, active_ids_);
+    dbg(free_ids_, active_ids_);
 
     return ret;
 }
@@ -143,23 +144,49 @@ void BlockManager::Evict(int count)
 
     // set as free
     for (const auto& idx : idxs) {
-        FT_CHECK(is_cached(blocks_[idx]));
-        blocks_[idx].timestamp = 0;
+        auto& b = blocks_[idx];
+        FT_CHECK(is_cached(b));
+        b.ref_count = 0;
+        b.unique_id = 0;
+        b.timestamp = 0;
+    }
+
+    Move(cached_ids_, idxs, free_ids_);
+
+    dbg(cached_ids_, free_ids_);
+}
+
+int BlockManager::Free(const std::vector<const Block*>& bs)
+{
+    std::vector<int> idxs;
+
+    for (const auto& p : bs) {
+        auto& b = blocks_[p->id];
+        FT_CHECK(is_cached(b));
+        if (--b.ref_count == 0) {
+            b.unique_id = 0;
+            b.timestamp = 0;
+            idxs.push_back(b.id);
+        }
     }
 
+    std::sort(idxs.begin(), idxs.end());
+
     Move(cached_ids_, idxs, free_ids_);
 
-    dbg("[Evict]", free_ids_);
+    dbg(cached_ids_, free_ids_);
+
+    return idxs.size();
 }
 
-int BlockManager::Release(const std::vector<const Block*>& bs)
+int BlockManager::Unlock(const std::vector<const Block*>& bs)
 {
     std::vector<int> idxs;
 
     for (const auto& p : bs) {
         auto& block = blocks_[p->id];
         FT_CHECK(is_active(block));
-        if (--block.ref_count == 0) {
+        if (--block.use_count == 0) {
             idxs.push_back(block.id);
         }
     }
@@ -168,19 +195,19 @@ int BlockManager::Release(const std::vector<const Block*>& bs)
 
     Move(active_ids_, idxs, cached_ids_);
 
-    dbg("[Release]", cached_ids_);
+    dbg(active_ids_, cached_ids_);
 
     return idxs.size();
 }
 
-void BlockManager::Retain(const std::vector<const Block*>& bs)
+int BlockManager::Lock(const std::vector<const Block*>& bs)
 {
     std::vector<int> idxs;
 
     for (const auto& p : bs) {
         auto& block = blocks_[p->id];
         FT_CHECK(is_cached(block));
-        if (++block.ref_count == 1) {
+        if (++block.use_count == 1) {
             idxs.push_back(p->id);
         }
     }
@@ -189,7 +216,9 @@ void BlockManager::Retain(const std::vector<const Block*>& bs)
 
     Move(cached_ids_, idxs, active_ids_);
 
-    dbg("[Retain]", active_ids_);
+    dbg(cached_ids_, active_ids_);
+
+    return idxs.size();
 }
 
 void BlockManager::Touch(const std::vector<const Block*>& bs)
@@ -202,32 +231,32 @@ void BlockManager::Touch(const std::vector<const Block*>& bs)
 
 Snapshot BlockManager::TakeSnapshot()
 {
-    std::vector<int> ref_count(blocks_.size());
+    std::vector<int> use_count(blocks_.size());
     for (const auto& idx : active_ids_) {
-        ref_count[idx] = blocks_[idx].ref_count;
+        use_count[idx] = blocks_[idx].use_count;
     }
-    return {active_count(), cached_count(), free_count(), std::move(ref_count)};
+    return {active_count(), cached_count(), free_count(), std::move(use_count)};
 }
 
 std::ostream& operator<<(std::ostream& os, const BlockManager& manager)
 {
-    os << "block_size: " << manager.block_size_ << "\n";
-    os << "max_block_count: " << manager.max_block_count_ << "\n";
-    os << "chunk_size: " << manager.chunk_size_ << "\n";
-    os << "allocator: " << manager.allocator_ << "\n";
-    os << "chunks: " << manager.chunks_.size() << "\n";
-    os << "active_ids: " << manager.active_ids_.size() << "\n";
-    os << "cached_ids: " << manager.cached_ids_.size() << "\n";
-    os << "free_ids: " << manager.free_ids_.size() << "\n";
-    os << "blocks: " << manager.blocks_.size() << "\n";
-    os << "unique_id: " << manager.unique_id_ << "\n";
-    os << "timestamp: " << manager.timestamp_ << "\n";
+    os << "block_size: " << manager.block_size_ << ", ";
+    os << "max_block_count: " << manager.max_block_count_ << ", ";
+    os << "chunk_size: " << manager.chunk_size_ << ", ";
+    os << "chunks: " << manager.chunks_.size() << ", ";
+    os << "active_ids: " << manager.active_ids_.size() << ", ";
+    os << "cached_ids: " << manager.cached_ids_.size() << ", ";
+    os << "free_ids: " << manager.free_ids_.size() << ", ";
+    os << "blocks: " << manager.blocks_.size() << ", ";
+    os << "unique_id: " << manager.unique_id_ << ", ";
+    os << "timestamp: " << manager.timestamp_ << ", ";
+    os << "allocator: " << manager.allocator_;
     return os;
 }
 
 std::ostream& operator<<(std::ostream& os, const Block& block)
 {
-    os << "id=" << block.id << ", ref_count=" << block.ref_count << ", unique_id=" << block.unique_id
+    os << "id=" << block.id << ", use_count=" << block.use_count << ", unique_id=" << block.unique_id
        << ", timestamp=" << block.timestamp << ", data=" << block.data;
     return os;
 }
diff --git a/src/turbomind/models/llama/BlockManager.h b/src/turbomind/models/llama/BlockManager.h
index dc14b32a15..984b0446a0 100644
--- a/src/turbomind/models/llama/BlockManager.h
+++ b/src/turbomind/models/llama/BlockManager.h
@@ -21,8 +21,9 @@ namespace turbomind {
 // [L, S/x, H, x, D]
 
 struct Block {
-    int      id;  // fixed linear id in the pool
-    int      ref_count;
+    int      id;         // fixed linear id in the pool
+    int      ref_count;  // all sequences referencing the block
+    int      use_count;  // active sequences using the block
     uint64_t unique_id;  // unique for every block allocation
     uint64_t timestamp;
     void*    data;
@@ -32,24 +33,24 @@ struct Block {
 
 inline bool is_active(const Block& block)
 {
-    return block.ref_count > 0;
+    return block.ref_count > 0 && block.use_count > 0;
 }
 
 inline bool is_cached(const Block& block)
 {
-    return block.ref_count == 0 && block.timestamp > 0;
+    return block.ref_count > 0 && block.use_count == 0;
 }
 
 inline bool is_free(const Block& block)
 {
-    return block.ref_count == 0 && block.timestamp == 0;
+    return block.ref_count == 0 && block.use_count == 0 && block.timestamp == 0;
 }
 
 struct Snapshot {
     int              active;
     int              cached;
     int              free;
-    std::vector<int> ref_count;
+    std::vector<int> use_count;
 };
 
 class BlockManager {
@@ -58,20 +59,21 @@ class BlockManager {
 
     ~BlockManager();
 
-    // free -> active
+    // free -> active (use_count = 1, ref_count = 1)
     [[nodiscard]] std::vector<const Block*> Allocate(int count);
 
-    // decrease ref count
-    // active -> cached
-    [[maybe_unused]] int Release(const std::vector<const Block*>& bs);
+    // cached -> active (use_count += 1)
+    [[maybe_unused]] int Lock(const std::vector<const Block*>& bs);
 
-    // increase ref count
-    // cached -> active
-    void Retain(const std::vector<const Block*>& bs);
+    // active -> cached (use_count -= 1)
+    [[maybe_unused]] int Unlock(const std::vector<const Block*>& bs);
 
-    // cached -> free
+    // cached -> free (ref_count = 0)
     void Evict(int count);
 
+    // cached -> free (ref_count -= 1, freed when it reaches 0)
+    [[maybe_unused]] int Free(const std::vector<const Block*>& bs);
+
     // increase timestamp in reversed order
     void Touch(const std::vector<const Block*>& bs);
 
@@ -123,7 +125,7 @@ class BlockManager {
     std::vector<Block> blocks_;  // < 100k
 
     // uint64_t unique_id_{1UL << 63};
-    uint64_t unique_id_{0};
+    uint64_t unique_id_{1};
     uint64_t timestamp_{1};
 };
 
diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index e1c1d9fcad..d6d94be43d 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -278,10 +278,9 @@ bool LlamaBatch<T>::Initialize()
         }
     }
 
-    dbg(holes, active_holes);
+    // dbg(holes, active_holes);
 
     auto process = [&](BatchState* state) {
-        dbg(state->size);
         for (int i = 0; i < state->size; ++i) {
             if (auto& r = state->requests[i]) {
                 sequences.push_back(state->sequences[i]);
@@ -305,10 +304,9 @@ bool LlamaBatch<T>::Initialize()
 
     auto outcome = sequence_manager_->Materialize(sequences, context_lengths, priorities, step_length_);
 
-    dbg(outcome);
-    // if (outcome.allocation || outcome.swap_in || outcome.swap_out) {
-    //     dbg(outcome);
-    // }
+    if (outcome.allocation || outcome.swap_in || outcome.swap_out) {
+        dbg(outcome);
+    }
 
     bool exchange = outcome.swap_in + outcome.swap_out > 0;
 
@@ -1058,6 +1056,14 @@ void LlamaBatch<T>::ContextDecode()
 
     std::fill(h_input_length_buf_ + base, h_input_length_buf_ + batch_size, 0);
 
+    // `SequenceManager` needs the up-to-date cache length
+    for (int i = base; i < batch_size; ++i) {
+        if (state_->requests[i]) {
+            FT_CHECK(state_->sequences[i]);
+            state_->sequences[i]->cache_len = state_->h_context_length[i] - 1;  // -1 since we skip last token
+        }
+    }
+
     check_cuda_error(cudaStreamSynchronize(stream_));
     const auto tock = std::chrono::high_resolution_clock::now();
     if (rank_ == 0) {
@@ -1183,6 +1189,14 @@ auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
         TM_LOG_INFO("[finish] [%s]", ss.str().c_str());
     }
 
+    // `SequenceManager` needs the up-to-date cache length
+    for (int i = 0; i < batch_size; ++i) {
+        if (state_->requests[i]) {
+            FT_CHECK(state_->sequences[i]);
+            state_->sequences[i]->cache_len = state_->h_context_length[i];
+        }
+    }
+
     std::vector<Signal> signals;
     {
         NvtxScope _("prepare_completion_signal");
@@ -1248,7 +1262,7 @@ void LlamaBatch<T>::CompleteRequest(int index, bool is_stop_request, bool is_for
     }
 
     if (debug_ && rank_ == 0) {
-        std::vector<int> tokens(state_->h_context_length[index] + 1);
+        std::vector<int> tokens(state_->h_context_length[index]);
         Copy(state_->output_ids + index * session_len_, tokens.size(), tokens.data());
         cudaStreamSynchronize(stream_);
         std::stringstream ss;
@@ -1262,13 +1276,11 @@ void LlamaBatch<T>::CompleteRequest(int index, bool is_stop_request, bool is_for
         sequence_manager_->Erase(state_->requests[index]->id);
     }
     else {
-        const int cache_len  = state_->h_context_length[index];
-        const int output_len = !is_stop_request ? cache_len + 1 : cache_len;
+        // account for the last generated token unless this is a stop request (stop requests generate no token)
+        const int output_len = state_->h_context_length[index] + 1 - static_cast<int>(is_stop_request);
 
         auto& seq = *state_->sequences[index];
 
-        seq.cache_len = cache_len;
-
         // update token IDs
         seq.tokens.resize(output_len);
 
@@ -1286,7 +1298,7 @@ void LlamaBatch<T>::CompleteRequest(int index, bool is_stop_request, bool is_for
 
         check_cuda_error(cudaStreamSynchronize(stream_));
 
-        sequence_manager_->Release(seq);
+        sequence_manager_->UpdateAndSetUnlock(seq);
     }
 
     state_->sequences[index] = nullptr;
diff --git a/src/turbomind/models/llama/SequenceManager.cc b/src/turbomind/models/llama/SequenceManager.cc
index daff172a15..12f982be26 100644
--- a/src/turbomind/models/llama/SequenceManager.cc
+++ b/src/turbomind/models/llama/SequenceManager.cc
@@ -1,10 +1,14 @@
 // Copyright (c) OpenMMLab. All rights reserved.
 
 #include "src/turbomind/models/llama/SequenceManager.h"
+#include "src/turbomind/models/llama/BlockManager.h"
 #include "src/turbomind/utils/allocator.h"
 #include "src/turbomind/utils/debug_utils.h"
 #include "src/turbomind/utils/logger.h"
+#include <cstddef>
+#include <cstdlib>
 #include <ctime>
+#include <numeric>
 #include <stdexcept>
 
 namespace turbomind {
@@ -41,7 +45,7 @@ const Sequence* SequenceManager::Create(uint64_t id)
         }
         auto& seq = it->second;
         if (seq.status != Sequence::kCached) {
-            released_.insert(released_.end(), seq.blocks.begin(), seq.blocks.end());
+            unlocked_.insert(unlocked_.end(), seq.blocks.begin(), seq.blocks.end());
         }
         seq = std::move(sequence);
     }
@@ -71,40 +75,72 @@ bool SequenceManager::Erase(uint64_t id)
     if (auto it = sequences_.find(id); it != sequences_.end()) {
         auto& seq = it->second;
         if (seq.status != Sequence::kCached) {
-            released_.insert(released_.end(), seq.blocks.begin(), seq.blocks.end());
+            unlocked_.insert(unlocked_.end(), seq.blocks.begin(), seq.blocks.end());
+            freed_.insert(freed_.end(), seq.blocks.begin(), seq.blocks.end());
+        }
+        else {
+            for (int i = 0; i < seq.blocks.size(); ++i) {
+                // filter invalidated blocks
+                if (seq.blocks[i]->unique_id == seq.block_unique_ids[i]) {
+                    freed_.push_back(seq.blocks[i]);
+                }
+            }
         }
         sequences_.erase(it);
     }
     else {
         throw std::out_of_range(std::to_string(id));
     }
+
     return false;
 }
 
-void SequenceManager::Verify(Sequence& seq, std::vector<const Block*>& retain)
+void SequenceManager::VerifyAndLockCached(const Sequences& sequences)
 {
-    FT_CHECK(seq.blocks.size() == seq.block_unique_ids.size());
-    for (int i = 0; i < seq.blocks.size(); ++i) {
-        if (seq.blocks[i]->unique_id != seq.block_unique_ids[i]) {
-            seq.blocks.resize(i);
-            seq.block_unique_ids.resize(i);
-            break;
+    if (!need_verify_) {
+        return;
+    }
+    std::vector<const Block*> blocks;
+    for (const auto& p : sequences) {
+        auto& seq = const_cast<Sequence&>(*p);
+        if (seq.status != Sequence::kCached) {
+            continue;
+        }
+        FT_CHECK(seq.blocks.size() == seq.block_unique_ids.size());
+        for (int i = 0; i < seq.blocks.size(); ++i) {
+            if (seq.blocks[i]->unique_id != seq.block_unique_ids[i]) {
+                seq.blocks.resize(i);
+                seq.block_unique_ids.resize(i);
+                break;
+            }
         }
+        blocks.insert(blocks.end(), seq.blocks.begin(), seq.blocks.end());
+        seq.cache_len = std::min<int>(seq.cache_len, seq.blocks.size() * block_seq_len_);
+        seq.status    = Sequence::kLocked;
     }
-    retain.insert(retain.end(), seq.blocks.begin(), seq.blocks.end());
-    seq.status    = Sequence::kLocked;
-    seq.cache_len = std::min<int>(seq.cache_len, seq.blocks.size() * block_seq_len_);
+    block_manager_->Lock(blocks);
+    need_verify_ = false;
 }
 
-void SequenceManager::Release(const Sequence& sequence)
+void SequenceManager::CommitUnlockAndFree()
 {
-    auto& seq = const_cast<Sequence&>(sequence);
-    if (seq.status == Sequence::kActive) {
-        block_manager_->Touch(seq.blocks);
+    if (!unlocked_.empty()) {
+        block_manager_->Unlock(unlocked_);
+        unlocked_.clear();
     }
-    if (seq.status != Sequence::kCached) {
-        released_.insert(released_.end(), seq.blocks.begin(), seq.blocks.end());
+
+    if (!freed_.empty()) {
+        block_manager_->Free(freed_);
+        freed_.clear();
     }
+}
+
+void SequenceManager::UpdateAndSetUnlock(const Sequence& sequence)
+{
+    FT_CHECK(sequence.status != Sequence::kCached);
+    auto& seq = const_cast<Sequence&>(sequence);
+    block_manager_->Touch(seq.blocks);
+    unlocked_.insert(unlocked_.end(), seq.blocks.begin(), seq.blocks.end());
     seq.status = Sequence::kCached;
 }
 
@@ -114,16 +150,44 @@ struct Schedule {
     int free;
     int cached;
 
-    int allocate;
-    int evict;
-    int preempt;
+    int allocate{};
+    int evict{};
+    int preempt{};
 
-    std::vector<int> victims;
+    int last;
 
-    std::vector<int> active;
+    Sequences        active;
     std::vector<int> block_counts;
+    Sequences        inactive;
+    Sequences        victims;
+
+    Schedule(Snapshot snapshot, int size):
+        free(snapshot.free),
+        cached(snapshot.cached),
+        last(size),
+        use_count_(std::move(snapshot.use_count)),
+        unlocked_(size),
+        it_(size)
+    {
+    }
+
+    int Unlock(const Sequences& seqs, int vidx)
+    {
+        while (vidx < it_) {
+            const auto& blocks = seqs[--it_]->blocks;
+            int         count  = 0;
+            for (const auto& p : blocks) {
+                count += static_cast<int>(--use_count_[p->id] == 0);
+            }
+            unlocked_[it_] = count;
+        }
+        return unlocked_[vidx];
+    }
 
-    std::vector<int> inactive;
+private:
+    std::vector<int> use_count_;
+    std::vector<int> unlocked_;
+    int              it_;
 };
 
 template<typename T>
@@ -145,44 +209,6 @@ std::ostream& operator<<(std::ostream& os, const Schedule& s)
     return os;
 }
 
-class Simulator {
-public:
-    explicit Simulator(const std::vector<const Sequence*>& seqs,
-                       const std::vector<int>&             idxs,
-                       std::vector<int>&                   ref_count):
-        seqs_(seqs), idxs_(idxs), ref_count_(ref_count)
-    {
-        // dbg(seqs.size());
-        released_.resize(seqs.size());
-        ptr_ = released_.size();
-    }
-
-    int Release(int order)
-    {
-        while (order < ptr_) {
-            --ptr_;
-            int count = 0;
-            for (const auto& p : seqs_[idxs_[ptr_]]->blocks) {
-                if (--ref_count_[p->id] == 0) {
-                    ++count;
-                }
-            }
-            released_[ptr_] = count;
-        }
-
-        return released_[order];
-    }
-
-private:
-    const std::vector<const Sequence*>& seqs_;
-    const std::vector<int>&             idxs_;
-
-    std::vector<int>& ref_count_;
-
-    std::vector<int> released_;
-    int              ptr_;
-};
-
 struct Transaction {
     int index_;
     int block_count_;
@@ -191,52 +217,69 @@ struct Transaction {
     int evict_{};
     int preempt_{};
 
-    std::vector<int> victims_;
-
-    Schedule&  sched_;
-    Simulator& simulator_;
+    Sequences victims_;
 
-    explicit Transaction(Schedule& sched, int index, int block_count, Simulator& simulator):
-        sched_(sched), index_(index), block_count_(block_count), simulator_(simulator)
-    {
-    }
+    const Sequences& sequences_;
+    Schedule&        schedule_;
 
-    int Allocate(int count)
+    explicit Transaction(const Sequences& sequences, int index, int block_count, Schedule& sched):
+        sequences_(sequences), schedule_(sched), index_(index), block_count_(block_count)
     {
-        allocate_ += count;
-        return count;
     }
 
-    int Evict(int count)
+    void Process()
     {
-        evict_ += count;
-        return count;
-    }
+        int count = block_count_;
 
-    int Preempt(int order, int idx)
-    {
-        victims_.push_back(idx);
-        preempt_ += simulator_.Release(order);
-        return preempt_;
-    }
+        int tmp = std::min(schedule_.free, count);
+        count -= tmp;
+        allocate_ += tmp;
 
-    void Commit()
-    {
-        sched_.free -= allocate_;
-        FT_CHECK(sched_.free >= 0);
+        tmp = std::min(schedule_.cached, count);
+        count -= tmp;
+        evict_ += tmp;
 
-        sched_.cached += preempt_;
-        sched_.cached -= evict_;
-        FT_CHECK(sched_.cached >= 0);
+        for (int vidx = schedule_.last - 1; count && vidx > index_; --vidx) {
+            if (sequences_[vidx]->status == Sequence::kCached) {
+                continue;
+            }
+            victims_.push_back(sequences_[vidx]);
+            preempt_ += schedule_.Unlock(sequences_, vidx);
 
-        sched_.allocate += allocate_;
-        sched_.evict += evict_;
-        sched_.preempt += preempt_;
+            if (count <= preempt_) {
+                evict_ += count;
+                count -= count;
+                schedule_.last = vidx;  // ! modifying `schedule_.last` is part of the commit
+                break;
+            }
+        }
 
-        sched_.victims.insert(sched_.victims.end(), victims_.begin(), victims_.end());
+        if (count == 0) {
+            Commit();
+        }
+        else {
+            schedule_.inactive.push_back(sequences_[index_]);
+        }
+    }
 
-        sched_.active.push_back(index_);
-        sched_.block_counts.push_back(block_count_);
+    void Commit()
+    {
+        // update available resources
+        schedule_.free -= allocate_;
+        FT_CHECK(schedule_.free >= 0);
+        schedule_.cached += preempt_;
+        schedule_.cached -= evict_;
+        FT_CHECK(schedule_.cached >= 0);
+
+        // update scheduled operations
+        schedule_.allocate += allocate_;
+        schedule_.evict += evict_;
+        schedule_.preempt += preempt_;
+        schedule_.victims.insert(schedule_.victims.end(), victims_.begin(), victims_.end());
+
+        // update active sequences
+        schedule_.active.push_back(sequences_[index_]);
+        schedule_.block_counts.push_back(block_count_);
     }
 };
 
@@ -249,183 +292,143 @@ std::ostream& operator<<(std::ostream& os, const Transaction& trans)
 
 }  // namespace
 
-std::ostream& operator<<(std::ostream& os, const Sequence& seq)
+void SequenceManager::SortByPriority(Sequences&                   sequences,
+                                     std::vector<int>&            context_lengths,
+                                     const std::vector<uint64_t>& priorities)
 {
-    os << "id=" << seq.id << ", status=" << seq.status << ", size(blocks)=" << seq.blocks.size()
-       << ", cache_len=" << seq.cache_len << ", size(random_state)=" << seq.random_state.size();
-    return os;
+    // sort according to priority
+    std::vector<int> idxs(sequences.size());
+    std::iota(idxs.begin(), idxs.end(), 0);
+    std::sort(idxs.begin(), idxs.end(), [&](int i, int j) {
+        return priorities[i] < priorities[j];  //
+    });
+    Sequences        tmp_sequences(sequences.size());
+    std::vector<int> tmp_lengths(context_lengths.size());
+    for (int i = 0; i < sequences.size(); ++i) {
+        tmp_sequences[i] = sequences[idxs[i]];
+        tmp_lengths[i]   = context_lengths[idxs[i]];
+    }
+    sequences.swap(tmp_sequences);
+    context_lengths.swap(tmp_lengths);
 }
 
-auto SequenceManager::Materialize(const std::vector<const Sequence*>& sequences,
-                                  const std::vector<int>&             context_lengths,
-                                  const std::vector<uint64_t>&        priorities,
-                                  int                                 step_length) -> Outcome
+std::vector<int> SequenceManager::CountRequiredBlocks(const Sequences&        sequences,
+                                                      const std::vector<int>& context_lengths,
+                                                      int                     step_length)
 {
-    dbg(__PRETTY_FUNCTION__);
-    ////////////////////////////////////////////////////////////////////////////////
-    /// Schedule the assignment of blocks to sequences
-    auto    seqs = const_cast<Sequence* const*>(sequences.data());
-    Outcome outcome{};
-
-    if (!released_.empty()) {
-        block_manager_->Release(released_);
-        released_.clear();
-    }
-
-    // check validity of of cached blocks (blocks of active & locked seqs are always valid)
-    if (need_verification_) {
-        need_verification_ = false;
-        std::vector<const Block*> retain;
-        for (int i = 0; i < sequences.size(); ++i) {
-            if (seqs[i]->status == Sequence::kCached) {
-                Verify(*seqs[i], retain);
-            }
-        }
-        block_manager_->Retain(retain);
-    }
-
-    // count required blocks based on block validity
     std::vector<int> required(sequences.size());
-    // int              total_required{};
     for (int i = 0; i < sequences.size(); ++i) {
         int seq_len = context_lengths[i] + step_length;
-        int count   = (seq_len + block_seq_len_ - 1) / block_seq_len_ - static_cast<int>(seqs[i]->blocks.size());
+        int count   = (seq_len + block_seq_len_ - 1) / block_seq_len_ - static_cast<int>(sequences[i]->blocks.size());
         required[i] = std::max(0, count);
-        // total_required += required[i];
     }
+    return required;
+}
 
-    // dbg(required);
-
-    // no new blocks required, exit early
-    // if (total_required == 0) {
-    //     dbg("early exit");
-    //     return outcome;
-    // }
-
-    /// TODO: more early exit heuristics
-
-    // sort according to priority
-    std::vector<int> idxs(sequences.size());
-    std::iota(idxs.begin(), idxs.end(), 0);
-    std::sort(idxs.begin(), idxs.end(), [&](int i, int j) { return priorities[i] < priorities[j]; });
-
-    Snapshot snapshot = block_manager_->TakeSnapshot();
-
-    Schedule schedule{snapshot.free, snapshot.cached};
-    schedule.cached += released_.size();
-
-    Simulator simulator(sequences, idxs, snapshot.ref_count);
+void SequenceManager::AssignAndActivate(const Sequences&                 sequences,  //
+                                        const std::vector<int>&          counts,
+                                        const std::vector<const Block*>& blocks)
+{
+    FT_CHECK(sequences.size() == counts.size());
+    auto first = blocks.begin();
+    for (int i = 0; i < sequences.size(); ++i) {
+        auto& s     = const_cast<Sequence&>(*sequences[i]);
+        auto  count = counts[i];
+        dbg(count);
+        auto last = first + count;
+        std::for_each(first, last, [&](const Block* b) {
+            s.blocks.push_back(b);
+            s.block_unique_ids.push_back(b->unique_id);
+        });
+        s.status = Sequence::kActive;
+        first    = last;
+    }
+}
 
-    std::vector<int> active(idxs.size());
-    std::vector<int> victim(idxs.size());
+auto SequenceManager::Materialize(Sequences                    sequences,
+                                  std::vector<int>             context_lengths,
+                                  const std::vector<uint64_t>& priorities,
+                                  int                          step_length) -> Outcome
+{
+    ////////////////////////////////////////////////////////////////////////////////
+    /// Schedule the assignment of blocks to sequences
 
-    for (int i = 0, j = idxs.size(); i < j; ++i) {
-        const int idx = idxs[i];
+    // process deferred unlock and free operations
+    CommitUnlockAndFree();
 
-        const auto& seq         = *sequences[idx];
-        auto        block_count = required[idx];
+    SortByPriority(sequences, context_lengths, priorities);
 
-        Transaction trans{schedule, idx, block_count, simulator};
+    // Verify and lock cached sequences so that their blocks are not evicted unnoticed;
+    // the blocks can still be preempted later
+    VerifyAndLockCached(sequences);
 
-        // allocate from free blocks
-        if (block_count) {
-            block_count -= trans.Allocate(std::min(block_count, schedule.free));
-        }
-        // evict cached blocks
-        if (block_count) {
-            block_count -= trans.Evict(std::min(block_count, schedule.cached));
-        }
-
-        for (int v = j - 1; block_count && v > i; --v) {
-            if (sequences[idxs[v]]->status == Sequence::kCached) {
-                continue;
-            }
-            // dbg(v, idxs[v]);
-            int preempt = trans.Preempt(v, idxs[v]);
-            // dbg(preempt);
-            // Commit only when preemption actually free enough blocks for the sequence to run
-            if (block_count <= preempt) {
-                // preempted blocks are in cached state
-                block_count -= trans.Evict(block_count);
-                j = v;
-                break;
-            }
-        }
+    std::vector<int> required = CountRequiredBlocks(sequences, context_lengths, step_length);
+    // dbg(required);
 
-        // dbg(block_count, trans);
+    Schedule schedule(block_manager_->TakeSnapshot(), sequences.size());
 
-        if (block_count == 0) {
-            trans.Commit();
-            active[idx] = 1;
-            if (seq.status != Sequence::kActive) {
-                ++outcome.swap_in;
-            }
-        }
+    // `schedule.last` may decrease as the loop runs
+    for (int i = 0; i < schedule.last; ++i) {
+        Transaction{sequences, i, required[i], schedule}.Process();
     }
 
-    for (const auto& i : idxs) {
-        if (!active[i]) {
-            schedule.inactive.push_back(i);
-            if (seqs[i]->status == Sequence::kActive) {
-                ++outcome.swap_out;
-            }
-        }
+    // mark the remaining sequences inactive
+    for (int i = schedule.last; i < sequences.size(); ++i) {
+        schedule.inactive.push_back(sequences[i]);
     }
 
-    // dbg(schedule);
-
     ////////////////////////////////////////////////////////////////////////////////
     /// Schedule is ready, time to execute it. (locked -> cached -> free -> locked)
+
+    // combine allocate and evict since evicted blocks are reused by allocation
     schedule.allocate += schedule.evict;
 
+    if (schedule.allocate) {
+        dbg(*block_manager_);
+    }
+
+    Outcome outcome{};
     outcome.allocation = schedule.allocate;
+    outcome.swap_in    = std::count_if(schedule.active.begin(), schedule.active.end(), [](auto p) {
+        if (p->status != Sequence::kActive) {
+            dbg(*p);
+        }
+        return p->status != Sequence::kActive;  //
+    });
+    outcome.swap_out   = std::count_if(schedule.inactive.begin(), schedule.inactive.end(), [](auto p) {
+        if (p->status == Sequence::kActive) {
+            dbg(*p);
+        }
+        return p->status == Sequence::kActive;  //
+    });
 
     // release preempted blocks -> cached
-    for (const auto& v : schedule.victims) {
-        Release(*sequences[v]);
+    if (!schedule.victims.empty()) {
+        for (const auto& p : schedule.victims) {
+            UpdateAndSetUnlock(*p);
+        }
+        CommitUnlockAndFree();
     }
-    block_manager_->Release(released_);
-    released_.clear();
 
     // evict cached blocks -> free
     if (schedule.evict) {
         block_manager_->Evict(schedule.evict);
-        need_verification_ = true;
+        need_verify_ = true;
     }
 
     // allocate & assign blocks
-    auto blocks = block_manager_->Allocate(schedule.allocate);
-    auto first  = blocks.begin();
-
-    for (const auto& idx : schedule.active) {
-        auto& sequence = *seqs[idx];
-
-        // retain blocks for swap-in sequences
-        if (sequence.status == Sequence::kCached) {
-            block_manager_->Retain(sequence.blocks);
-        }
-
-        sequence.status = Sequence::kActive;
-
-        auto last = first + required[idx];
-        std::for_each(first, last, [&](const Block* b) {
-            sequence.blocks.push_back(b);
-            sequence.block_unique_ids.push_back(b->unique_id);
-        });
-
-        first = last;
+    if (schedule.allocate) {
+        auto blocks = block_manager_->Allocate(schedule.allocate);
+        AssignAndActivate(schedule.active, schedule.block_counts, blocks);
     }
 
-    for (const auto& idx : schedule.inactive) {
-        if (seqs[idx]->status == Sequence::kActive) {
-            seqs[idx]->status = Sequence::kLocked;
+    // active -> locked
+    for (const auto& p : schedule.inactive) {
+        if (p->status == Sequence::kActive) {
+            const_cast<Sequence*>(p)->status = Sequence::kLocked;
         }
     }
 
-    for (const auto& idx : schedule.victims) {
-        seqs[idx]->status = Sequence::kCached;
-    }
-
     return outcome;
 }
 
diff --git a/src/turbomind/models/llama/SequenceManager.h b/src/turbomind/models/llama/SequenceManager.h
index 19278a0fbd..8800149bf1 100644
--- a/src/turbomind/models/llama/SequenceManager.h
+++ b/src/turbomind/models/llama/SequenceManager.h
@@ -30,6 +30,16 @@ struct Sequence {
     friend std::ostream& operator<<(std::ostream& os, const Sequence& seq);
 };
 
+using Sequences = std::vector<const Sequence*>;
+
+inline std::ostream& operator<<(std::ostream& os, const Sequence& seq)
+{
+    os << "id=" << seq.id << ", status=" << seq.status << ", token_count=" << seq.tokens.size()
+       << ", block_count=" << seq.blocks.size() << ", cache_len=" << seq.cache_len
+       << ", random_state_size=" << seq.random_state.size();
+    return os;
+}
+
 class SequenceManager {
 public:
     explicit SequenceManager(size_t      layer_num,
@@ -53,7 +63,7 @@ class SequenceManager {
 
     bool Erase(uint64_t id);
 
-    void Release(const Sequence& seq);
+    void UpdateAndSetUnlock(const Sequence& seq);
 
     struct Outcome {
         int allocation;
@@ -61,10 +71,10 @@ class SequenceManager {
         int swap_out;
     };
 
-    Outcome Materialize(const std::vector<const Sequence*>& sequences,
-                        const std::vector<int>&             context_lengths,
-                        const std::vector<uint64_t>&        priorities,
-                        int                                 step_length);
+    Outcome Materialize(Sequences                    sequences,
+                        std::vector<int>             context_lengths,
+                        const std::vector<uint64_t>& priorities,
+                        int                          step_length);
 
     void* OffsetKey(void* block_ptr)
     {
@@ -82,21 +92,36 @@ class SequenceManager {
     }
 
 private:
-    void Verify(Sequence& seq, std::vector<const Block*>& retain);
+    void CommitUnlockAndFree();
+
+    void VerifyAndLockCached(const Sequences& sequences);
+
+    std::vector<int> CountRequiredBlocks(const Sequences&        sequences,  //
+                                         const std::vector<int>& context_lengths,
+                                         int                     step_length);
+
+    static void SortByPriority(Sequences&                   sequences,  //
+                               std::vector<int>&            context_lengths,
+                               const std::vector<uint64_t>& priorities);
+
+    static void AssignAndActivate(const Sequences&                 sequences,  //
+                                  const std::vector<int>&          block_counts,
+                                  const std::vector<const Block*>& blocks);
 
 private:
     int    block_seq_len_;
     int    rank_;
     size_t val_offset_{};
 
-    bool need_verification_{};
+    bool need_verify_{};
 
     // Use `std::map` to avoid reference invalidation
     std::map<uint64_t, Sequence> sequences_;
 
     std::unique_ptr<BlockManager> block_manager_;
 
-    std::vector<const Block*> released_;
+    std::vector<const Block*> unlocked_;
+    std::vector<const Block*> freed_;
 };
 
 inline std::ostream& operator<<(std::ostream& os, const SequenceManager::Outcome& oc)
diff --git a/src/turbomind/models/llama/test_cache_manager.cc b/src/turbomind/models/llama/test_cache_manager.cc
index 75d9f039dc..99fbc34c5f 100644
--- a/src/turbomind/models/llama/test_cache_manager.cc
+++ b/src/turbomind/models/llama/test_cache_manager.cc
@@ -46,7 +46,9 @@ TEST_CASE("BlockManager")
     std::copy(blocks3.begin(), blocks3.end(), std::back_inserter(blocks1));
     std::copy(blocks2.begin(), blocks2.end(), std::back_inserter(blocks1));
 
-    REQUIRE(m.Release(blocks1) == 32);
+    m.Touch(blocks1);
+
+    REQUIRE(m.Unlock(blocks1) == 32);
     REQUIRE(m.active_count() == 0);
     REQUIRE(m.free_count() == 0);
     REQUIRE(m.cached_count() == 32);

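To summarize the refactoring in this patch, a simplified model of the block life cycle after the `ref_count`/`use_count` split; the struct and function names below are illustrative (the real `BlockManager` methods are `Allocate`, `Unlock`, `Lock`, `Free` and `Evict`, and it additionally tracks ids, unique ids, timestamps and chunked allocation).

    #include <cassert>

    struct BlockCounts {
        int ref_count = 0;  // sequences that still own the block (active or cached)
        int use_count = 0;  // active sequences currently using the block
    };

    // free (0,0) -> active (1,1), as in BlockManager::Allocate
    void allocate(BlockCounts& b)
    {
        assert(b.ref_count == 0 && b.use_count == 0);
        b.ref_count = b.use_count = 1;
    }

    // active -> cached: the sequence stops running but keeps ownership (Unlock)
    void unlock(BlockCounts& b)
    {
        assert(b.ref_count > 0 && b.use_count > 0);
        --b.use_count;
    }

    // cached -> active: a swapped-in sequence reuses its cached blocks (Lock)
    void lock(BlockCounts& b)
    {
        assert(b.ref_count > 0 && b.use_count == 0);
        ++b.use_count;
    }

    // cached -> free once the last owning sequence drops its reference (Free)
    void free_block(BlockCounts& b)
    {
        assert(b.ref_count > 0 && b.use_count == 0);
        --b.ref_count;
    }
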
From 699b0bfee07ec2c32889f304f9f734bbf322f705 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Mon, 30 Oct 2023 13:27:54 +0000
Subject: [PATCH 39/56] add ntk scaling and logn scaling

---
 CMakeLists.txt                                | 11 ++-
 .../CMakeLists.txt                            |  8 +-
 .../decoder_multihead_attention/array_ops.h   | 66 +++++++++++------
 .../decoder_multihead_attention.cu            |  2 +-
 .../decoder_multihead_attention_params.h      |  3 +-
 .../decoder_multihead_attention_template.h    | 35 +++++----
 .../test_decoder_multihead_attention.cu       |  6 +-
 .../decoder_multihead_attention/test_utils.cu |  2 +-
 .../decoder_multihead_attention/thread_map.h  |  1 -
 .../kernels/unfused_attention_kernels.cu      | 73 ++++++++++---------
 .../kernels/unfused_attention_kernels.h       |  1 +
 src/turbomind/models/llama/LlamaBatch.cc      | 53 ++++++++++++--
 src/turbomind/models/llama/LlamaBatch.h       |  9 ++-
 .../llama/LlamaContextAttentionLayer.cc       |  7 +-
 .../models/llama/LlamaContextDecoder.cc       |  1 +
 .../llama/LlamaDecoderSelfAttentionLayer.cc   | 30 +++++---
 src/turbomind/models/llama/LlamaV2.cc         | 61 +++++++++-------
 src/turbomind/models/llama/LlamaV2.h          | 63 ++++++++--------
 src/turbomind/models/llama/SequenceManager.cc |  2 +-
 src/turbomind/models/llama/SequenceManager.h  |  2 +
 src/turbomind/models/llama/llama_params.h     |  5 +-
 .../triton_backend/llama/LlamaTritonModel.cc  | 12 ++-
 22 files changed, 287 insertions(+), 166 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index f3f1c7b171..a004d76af5 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -198,10 +198,13 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++${CXX_STD} -DCUDA_PTX_FP8_F2F
 
 set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
 # set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 --ptxas-options=--verbose")
-set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 -DCUDA_PTX_FP8_F2FP_ENABLED")
+set(CMAKE_CUDA_FLAGS_RELEASE        "${CMAKE_CUDA_FLAGS_RELEASE}        -Xcompiler -O3 -DCUDA_PTX_FP8_F2FP_ENABLED")
+set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -Xcompiler -O3 -DCUDA_PTX_FP8_F2FP_ENABLED")
+
 if(BUILD_FAST_MATH)
-set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} --use_fast_math")
-message("CMAKE_CUDA_FLAGS_RELEASE: ${CMAKE_CUDA_FLAGS_RELEASE}")
+    set(CMAKE_CUDA_FLAGS_RELEASE        "${CMAKE_CUDA_FLAGS_RELEASE}        --use_fast_math")
+    set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} --use_fast_math")
+    message("Release build CUDA flags: ${CMAKE_CUDA_FLAGS_RELEASE}")
 endif()
 
 set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
@@ -268,11 +271,13 @@ print(torch._C._GLIBCXX_USE_CXX11_ABI,end='');"
                   OUTPUT_VARIABLE USE_CXX11_ABI)
   message("-- USE_CXX11_ABI=${USE_CXX11_ABI}")
   if (USE_CXX11_ABI)
+    set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=1")
     set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=1")
     set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=1")
     set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=1")
     set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=1")
   else()
+    set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=0")
     set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=0")
     set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=0")
     set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=0")
diff --git a/src/turbomind/kernels/decoder_multihead_attention/CMakeLists.txt b/src/turbomind/kernels/decoder_multihead_attention/CMakeLists.txt
index 7176017671..fe67d11f0a 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/CMakeLists.txt
+++ b/src/turbomind/kernels/decoder_multihead_attention/CMakeLists.txt
@@ -1,15 +1,15 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 
 add_library(decoder_multihead_attention STATIC decoder_multihead_attention.cu kv_cache.cu)
-target_compile_options(decoder_multihead_attention PRIVATE
-  --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr --keep)
+# target_compile_options(decoder_multihead_attention PRIVATE
+#   --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr --keep)
 set_property(TARGET decoder_multihead_attention PROPERTY POSITION_INDEPENDENT_CODE ON)
 set_property(TARGET decoder_multihead_attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
 target_link_libraries(decoder_multihead_attention PRIVATE nvidia::cutlass::cutlass)
 
 add_executable(test_decoder_multihead_attention test_utils.cu test_decoder_multihead_attention.cu)
-target_compile_options(test_decoder_multihead_attention PRIVATE
-  --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr)
+# target_compile_options(test_decoder_multihead_attention PRIVATE
+#   --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr)
 target_link_libraries(test_decoder_multihead_attention PRIVATE 
     decoder_multihead_attention 
     decoder_masked_multihead_attention
diff --git a/src/turbomind/kernels/decoder_multihead_attention/array_ops.h b/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
index a847ada855..209da7e71d 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
@@ -87,27 +87,40 @@ inline __device__ Array<T, N> operator*(const Array<T, N>& a, const T& b)
 
 }  // namespace ops
 
+template<typename To, typename From, int N>
+inline __device__ Array<To, N> cast(const Array<From, N>& src)
+{
+    Array<To, N> dst;
+    PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+        dst[i] = (To)src[i];
+    }
+    return dst;
+}
+
 template<int N>
 struct RotaryEmbedding {
 
     static_assert(N % 2 == 0);
 
-    Array<float, N> inv_freqs_;
+    Array<float, N> cs_;
 
     __device__ RotaryEmbedding(float base, int dims, int timestep, int2 offset)
     {
         PRAGMA_UNROLL
         for (int i = 0; i < N; i += 2) {
-            const float2 tmp  = rotary_embedding_coefficient(offset.x + i, dims, base, timestep);
-            inv_freqs_[i]     = tmp.x;
-            inv_freqs_[i + 1] = tmp.y;
+            const float2 tmp = get_coefficient(offset.x + i, dims, base, timestep);
+            cs_[i]           = tmp.x;
+            cs_[i + 1]       = tmp.y;
         }
     }
 
-    inline __device__ float2 rotary_embedding_coefficient(int idx, int dims, float base, int timestep)
+    static __device__ inline float2 get_coefficient(int idx, int dims, float base, int timestep)
     {
         const float inv_freq = timestep / powf(base, idx / (float)dims);
-        return {cos(inv_freq), sin(inv_freq)};
+        float2      cs;
+        sincosf(inv_freq, &cs.y, &cs.x);
+        return cs;
     }
 
     template<typename T>
@@ -115,35 +128,42 @@ struct RotaryEmbedding {
     {
         PRAGMA_UNROLL
         for (int i = 0; i < N; i += 2) {
-            float tmp0 = inv_freqs_[i] * (float)x[i] - inv_freqs_[i + 1] * (float)x[i + 1];
-            float tmp1 = inv_freqs_[i] * (float)x[i + 1] + inv_freqs_[i + 1] * (float)x[i];
+            float tmp0 = cs_[i] * (float)x[i] - cs_[i + 1] * (float)x[i + 1];
+            float tmp1 = cs_[i] * (float)x[i + 1] + cs_[i + 1] * (float)x[i];
             x[i]       = (T)tmp0;
             x[i + 1]   = (T)tmp1;
         }
     }
 };
 
-template<typename VecQk, typename ThreadMap>
 struct LogNScaling {
-    __device__ void apply(VecQk& x)
+
+    float scale_;
+
+    __device__ static float get_scale(int seq_len, int max_position_embeddings)
     {
-        PRAGMA_UNROLL
-        for (int i = 0; i < VecQk::kSize; ++i) {
-            // TODO:
+        if (seq_len <= max_position_embeddings) {
+            return 1.f;
+        }
+        else {
+            return log2(seq_len) / log2(max_position_embeddings);
         }
     }
-};
 
-template<typename To, typename From, int N>
-inline __device__ Array<To, N> cast(const Array<From, N>& src)
-{
-    Array<To, N> dst;
-    PRAGMA_UNROLL
-    for (int i = 0; i < N; ++i) {
-        dst[i] = (To)src[i];
+    __device__ LogNScaling(int seq_len, int max_position_embeddings)
+    {
+        scale_ = get_scale(seq_len, max_position_embeddings);
     }
-    return dst;
-}
+
+    template<typename T, int N>
+    __device__ void apply(Array<T, N>& x) const
+    {
+        PRAGMA_UNROLL
+        for (int i = 0; i < N; ++i) {
+            x[i] = (T)((float)x[i] * scale_);
+        }
+    }
+};
 
 template<typename T, int N>
 inline __device__ void Store(T* dst, const Array<T, N>& src)
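
For reference, `get_coefficient` above computes the standard RoPE rotation pair: theta = timestep / base^(idx / dims), and the returned float2 is (cos theta, sin theta) produced by a single sincosf. `LogNScaling::get_scale` is the identity until the sequence outgrows `max_position_embeddings`. The following is a minimal host-side C++ sketch of the same arithmetic, using standalone names instead of the device-side Array types, useful only for sanity-checking values against the kernel:

    #include <cmath>
    #include <cstdio>
    #include <utility>

    // Host-side reference of RotaryEmbedding<>::get_coefficient:
    // theta = timestep / base^(idx / dims); returns {cos(theta), sin(theta)}.
    std::pair<float, float> rope_coefficient(int idx, int dims, float base, int timestep)
    {
        const float theta = timestep / std::pow(base, idx / (float)dims);
        return {std::cos(theta), std::sin(theta)};
    }

    // Rotate one (even, odd) channel pair the way RotaryEmbedding::apply does.
    void rope_rotate(float& x0, float& x1, float cos_t, float sin_t)
    {
        const float t0 = cos_t * x0 - sin_t * x1;
        const float t1 = cos_t * x1 + sin_t * x0;
        x0 = t0;
        x1 = t1;
    }

    // Host-side reference of LogNScaling::get_scale.
    float logn_scale(int seq_len, int max_position_embeddings)
    {
        return seq_len <= max_position_embeddings ?
                   1.f :
                   std::log2((float)seq_len) / std::log2((float)max_position_embeddings);
    }

    int main()
    {
        auto [c, s] = rope_coefficient(/*idx=*/0, /*dims=*/128, /*base=*/10000.f, /*timestep=*/7);
        float x0 = 1.f, x1 = 0.f;
        rope_rotate(x0, x1, c, s);
        std::printf("rotated: (%f, %f), logn scale(8192, 2048) = %f\n", x0, x1, logn_scale(8192, 2048));
        return 0;
    }
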
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
index 709db6ebc0..02cc827694 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
@@ -40,7 +40,7 @@ void invokeDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& p
 
         static const size_t kDynSmemSize = Attn::GetDynamicSmemSize();
 
-        [[maybe_unused]] static const bool _ = Print<Attn>(kDynSmemSize);
+        // [[maybe_unused]] static const bool _ = Print<Attn>(kDynSmemSize);
 
         const int slice_count = (params.max_seq_len + Attn::kSliceLen - 1) / Attn::kSliceLen;
         const int max_split_k = std::min(params.max_split_k, std::max(1, slice_count));
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
index 5f18b45216..add5a7161c 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
@@ -22,6 +22,7 @@ struct DecoderMultiHeadAttentionParams {
     // sequence-level buffers
     const int* __restrict__ per_sample_length;
     const bool* __restrict__ finished;
+    const float* __restrict__ rope_theta;
 
     // kv cache
     void** __restrict__ per_sample_k_cache;  // [H, S, D]
@@ -50,7 +51,7 @@ struct DecoderMultiHeadAttentionParams {
     int   rotary_embedding_dim;
     float rotary_embedding_base;
     int   max_position_embeddings;
-    bool  use_dynamic_ntk;
+    // bool  use_dynamic_ntk;
 
     // log(n) attention
     bool use_logn_attn;
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
index dfeb86e568..ae82a8b786 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
@@ -233,9 +233,15 @@ struct DecoderMultiHeadAttentionKernel {
             frag_V = frag_V + bias_V;
         }
 
+        // for (int i = 0; i < kVecQSize; ++i) {
+        //     printf("q[%2d][%3d] = %f\n", (int)head_idx_, (int)(offset.x + i), (float)frag_Q[0][i]);
+        // }
+
+        float rotary_embedding_base =
+            params_.rope_theta ? params_.rope_theta[batch_idx_] : params_.rotary_embedding_base;
+
         // Apply rotary embedding
-        RotaryEmbedding<kVecQSize> rotary_emb(
-            params_.rotary_embedding_base, params_.rotary_embedding_dim, timestep_, offset);
+        RotaryEmbedding<kVecQSize> rotary_emb(rotary_embedding_base, params_.rotary_embedding_dim, timestep_, offset);
 
         PRAGMA_UNROLL
         for (int s = 0; s < kQHeadPerThread; ++s) {
@@ -243,6 +249,14 @@ struct DecoderMultiHeadAttentionKernel {
         }
         rotary_emb.apply(frag_K);
 
+        if (params_.use_logn_attn) {
+            LogNScaling logn_scaling(timestep_ + 1, params_.max_position_embeddings);
+            PRAGMA_UNROLL
+            for (int s = 0; s < kQHeadPerThread; ++s) {
+                logn_scaling.apply(frag_Q[s]);
+            }
+        }
+
         if (kSplitK && step_begin_) {  // Split idx > 0
             PRAGMA_UNROLL
             for (int s = 0; s < kQHeadPerThread; ++s) {
@@ -268,6 +282,7 @@ struct DecoderMultiHeadAttentionKernel {
                 qk *= params_.inv_sqrt_dh;
                 smem_M_[qi] = qk;
                 smem_L_[qi] = 1.f;
+                // printf("qk[%2d] = %f\n", head_idx_, qk);
             }
             // write Q and O
             Store(&smem_Q_[qi * kMaxHeadDim + offset.x], frag_Q[s]);
@@ -467,10 +482,6 @@ struct DecoderMultiHeadAttentionKernel {
         /// block synchronization
         frag_M = qk_max<MapKv>(frag_M, smem_red_max_, warp_id_, lane_id_);
 
-        if (threadIdx.x == 0 && step == timestep_ - kSliceLen) {
-            // printf("frag_M[%d] = %f\n", head_idx_, (float)frag_M[0]);
-        }
-
         // wait while smem_red_ is being used.
         // __syncthreads();
 
@@ -488,6 +499,10 @@ struct DecoderMultiHeadAttentionKernel {
             }
         }
 
+        // if (threadIdx.x == 0 && step + iter_length == timestep_) {
+        //     printf("frag_M[%2d] = %f\n", head_idx_, (float)frag_M[0]);
+        // }
+
         // __syncthreads();  // DEBUG
 
         /////////////////////////////////////////////////////////////////////////////////////////
@@ -506,17 +521,9 @@ struct DecoderMultiHeadAttentionKernel {
             }
         }
 
-        // if (thread0()) {
-        // printf("frag_L0 = %f\n", (float)frag_L[0]);
-        // }
-
         /// block synchronization
         frag_L = blockSum<kWarpCount>(frag_L, smem_red_sum_, warp_id_, lane_id_);
 
-        if (thread0()) {
-            // printf("frag_L = %f\n", (float)frag_L[0]);
-        }
-
         for (int qi = 0; qi < kHeadPerCta; ++qi) {
             // exp(m1 - m2) * l1
             frag_L[qi] += exp_M_diff[qi] * smem_L_[qi];
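
The `frag_M` / `frag_L` bookkeeping above is the streaming-softmax merge used by flash-attention style decoding: each slice of scores contributes its own maximum and exp-sum, and the previously accumulated sum is rescaled by exp(m_old - m_new), which is exactly the `exp_M_diff[qi] * smem_L_[qi]` term. A small host-side C++ sketch of that merge rule with illustrative names (the kernel performs the same update using warp/block reductions):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Running softmax statistics: m = max of all scores seen so far,
    // l = sum of exp(score - m) over all scores seen so far.
    struct SoftmaxState {
        float m = -INFINITY;
        float l = 0.f;
    };

    // Merge one slice of attention scores into the running state.
    // Mirrors: m_new = max(m_old, slice_max); l_new = exp(m_old - m_new) * l_old + slice_sum.
    void merge_slice(SoftmaxState& st, const std::vector<float>& scores)
    {
        float slice_max = -INFINITY;
        for (float v : scores) slice_max = std::max(slice_max, v);

        const float m_new      = std::max(st.m, slice_max);
        const float exp_m_diff = std::exp(st.m - m_new);  // rescale factor for the old accumulator

        float slice_sum = 0.f;
        for (float v : scores) slice_sum += std::exp(v - m_new);

        st.l = exp_m_diff * st.l + slice_sum;
        st.m = m_new;
    }

    int main()
    {
        SoftmaxState st;
        merge_slice(st, {1.f, 2.f, 3.f});  // first slice
        merge_slice(st, {5.f, 0.5f});      // second slice; the new maximum rescales the old sum
        std::printf("m = %f, l = %f\n", st.m, st.l);
        return 0;
    }
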
diff --git a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
index b5249f31c2..e4636bcea0 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
@@ -109,13 +109,13 @@ int main(int argc, char* argv[])
     constexpr int kHeadNum    = 32;
     constexpr int kHeadDim    = 128;
     constexpr int KvHeadNum   = 32;
-    constexpr int kBatchSize  = 1;
-    constexpr int kContextLen = 1024;
+    constexpr int kBatchSize  = 32;
+    constexpr int kContextLen = 7306;
     // constexpr int kContextLen  = 1024;
     constexpr int kSequenceLen = kContextLen + 1;
     constexpr int kBlockSz     = 128;
     constexpr int kTestIter    = 1;
-    constexpr int kMaxSplitK   = 4;
+    constexpr int kMaxSplitK   = 1;
 
     RNG rng{};
 
diff --git a/src/turbomind/kernels/decoder_multihead_attention/test_utils.cu b/src/turbomind/kernels/decoder_multihead_attention/test_utils.cu
index 883f0fc3d0..c3fb0d77bc 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/test_utils.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/test_utils.cu
@@ -226,7 +226,7 @@ void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& p, cudaStream_t
     params.hidden_size_per_head    = p.size_per_head;
     params.rotary_embedding_dim    = p.rotary_embedding_dim;
     params.max_position_embeddings = p.max_position_embeddings;
-    params.use_dynamic_ntk         = p.use_dynamic_ntk;
+    params.use_dynamic_ntk         = false;
     params.use_logn_attn           = p.use_logn_attn;
 
     // Note: keep norm factor (sqrt(K_dim)) when adopting megatron T5 structure (may adjust)
diff --git a/src/turbomind/kernels/decoder_multihead_attention/thread_map.h b/src/turbomind/kernels/decoder_multihead_attention/thread_map.h
index 47b2636f6d..f4c2be1da2 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/thread_map.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/thread_map.h
@@ -3,7 +3,6 @@
 #pragma once
 
 #include "../gemm_s_f16/common.h"
-#include "src/turbomind/kernels/custom_ar_kernels.h"
 
 namespace turbomind {
 
diff --git a/src/turbomind/kernels/unfused_attention_kernels.cu b/src/turbomind/kernels/unfused_attention_kernels.cu
index abbbfd5562..040f7204bf 100644
--- a/src/turbomind/kernels/unfused_attention_kernels.cu
+++ b/src/turbomind/kernels/unfused_attention_kernels.cu
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-#include "src/turbomind/kernels/decoder_masked_multihead_attention_utils.h"
+#include "src/turbomind/kernels/decoder_multihead_attention/array_ops.h"
 #include "src/turbomind/kernels/reduce_kernel_utils.cuh"
 #include "src/turbomind/kernels/unfused_attention_kernels.h"
 #include "src/turbomind/utils/cuda_type_utils.cuh"
@@ -854,19 +854,20 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
                                                    T* v_buf,
                                                    T* QKV,
                                                    const T* __restrict qkv_bias,
-                                                   const int* padding_offset,
-                                                   const int* context_length,
-                                                   const int* input_length,
-                                                   int        batch_size,
-                                                   int        seq_len,
-                                                   int        head_num,
-                                                   int        kv_head_num,
-                                                   int        size_per_head,
-                                                   int        rotary_embedding_dim,
-                                                   float      rotary_embedding_base,
-                                                   int        max_position_embeddings,
-                                                   bool       use_dynamic_ntk,
-                                                   bool       use_logn_attn)
+                                                   const int*   padding_offset,
+                                                   const int*   context_length,
+                                                   const int*   input_length,
+                                                   const float* rope_theta,
+                                                   int          batch_size,
+                                                   int          seq_len,
+                                                   int          head_num,
+                                                   int          kv_head_num,
+                                                   int          size_per_head,
+                                                   int          rotary_embedding_dim,
+                                                   float        rotary_embedding_base,
+                                                   int          max_position_embeddings,
+                                                   bool         use_dynamic_ntk,
+                                                   bool         use_logn_attn)
 {
     // This kernel add bias to QKV, which has shape [batch_size, seq_len, 3, head_num, size_per_head], and
     // QKV split to 3 split buffer q, k, v and transpose them to [batch_size, head_num, seq_len, size_per_head].
@@ -907,12 +908,18 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
     Vec_t q, k, v;
     Vec_t q_bias, k_bias, v_bias;
 
+    using Vec = Array<T, vec_size>;
+
+    static_assert(sizeof(Vec_t) == sizeof(Vec));
+
+    using namespace ops;
+
     // load Q and apply bias
     if (!is_masked) {
         q = *reinterpret_cast<const Vec_t*>(&QKV[src_q_idx]);
         if (qkv_bias) {
-            q_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx]);
-            q      = mmha::add(q, q_bias);
+            q_bias  = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx]);
+            (Vec&)q = (Vec&)q + (Vec&)q_bias;
         }
     }
 
@@ -921,10 +928,10 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
         k = *reinterpret_cast<const Vec_t*>(&QKV[src_k_idx]);
         v = *reinterpret_cast<const Vec_t*>(&QKV[src_v_idx]);
         if (qkv_bias) {
-            k_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + k_offset]);
-            v_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + v_offset]);
-            k      = mmha::add(k, k_bias);
-            v      = mmha::add(v, v_bias);
+            k_bias  = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + k_offset]);
+            v_bias  = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + v_offset]);
+            (Vec&)k = (Vec&)k + (Vec&)k_bias;
+            (Vec&)v = (Vec&)v + (Vec&)v_bias;
         }
     }
 
@@ -932,24 +939,21 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
     const int history_len = context_len - input_length[batch_idx];
     const int timestep    = history_len + seq_idx;
 
-    if (use_dynamic_ntk) {
-        rotary_embedding_base = mmha::rotary_embedding_get_base(
-            context_len, max_position_embeddings, rotary_embedding_dim, rotary_embedding_base);
+    if (rope_theta) {
+        rotary_embedding_base = rope_theta[batch_idx];
     }
 
-    // TODO: unused computation on k if GQA is used
-    mmha::apply_rotary_embedding(q, k, tidx, rotary_embedding_dim, rotary_embedding_base, timestep);
+    RotaryEmbedding<vec_size> rotary_emb(rotary_embedding_base, rotary_embedding_dim, timestep, {tidx * vec_size, 0});
+    rotary_emb.apply((Array<T, vec_size>&)q);
+
+    if (head_idx < kv_head_num) {
+        rotary_emb.apply((Array<T, vec_size>&)k);
+    }
 
     if (use_logn_attn) {
         // +1 to convert to context length at the timestep
-        float logn_scaling = mmha::logn_attn_get_scaling(timestep + 1, max_position_embeddings);
-        if constexpr (std::is_same_v<T, float>) {
-            q = mmha::mul<Vec_t, float, Vec_t>(logn_scaling, q);
-        }
-        else if constexpr (std::is_same_v<T, half>) {
-            half tmp = __float2half(logn_scaling);
-            q        = mmha::mul<Vec_t, uint16_t, Vec_t>((uint16_t&)tmp, q);
-        }
+        LogNScaling logn_scaling(timestep + 1, max_position_embeddings);
+        logn_scaling.apply((Array<T, vec_size>&)q);
     }
 
     if (!is_masked && !q_buf) {  // also skip modifying QKV if q/k/v_buf are present
@@ -984,6 +988,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
                                                                                              padding_offset,           \
                                                                                              context_length,           \
                                                                                              input_length,             \
+                                                                                             rope_theta,               \
                                                                                              batch_size,               \
                                                                                              seq_len,                  \
                                                                                              head_num,                 \
@@ -1004,6 +1009,7 @@ void invokeAddFusedQKVBiasTranspose(T*           q_buf,
                                     const int*   padding_offset,
                                     const int*   context_length,
                                     const int*   input_length,
+                                    const float* rope_theta,
                                     const int    batch_size,
                                     const int    seq_len,
                                     const int    token_num,
@@ -1034,6 +1040,7 @@ void invokeAddFusedQKVBiasTranspose(T*           q_buf,
                                                  const int*   padding_offset,                                          \
                                                  const int*   history_length,                                          \
                                                  const int*   input_length,                                            \
+                                                 const float* rope_theta,                                              \
                                                  const int    batch_size,                                              \
                                                  const int    seq_len,                                                 \
                                                  const int    token_num,                                               \
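
A note on the `(Vec&)q = (Vec&)q + (Vec&)q_bias` lines in this file: the packed load type `Vec_t` is reinterpreted as `Array<T, vec_size>` so that the element-wise operators from the `ops` namespace can be reused, and the `static_assert(sizeof(Vec_t) == sizeof(Vec))` is what makes the reinterpretation legitimate. A self-contained C++ sketch of the pattern with a stand-in Array type (not the turbomind one):

    #include <cstdio>

    // Stand-in for turbomind::Array<T, N>: a trivially-copyable fixed-size vector.
    template<typename T, int N>
    struct Array {
        T data[N];
        T&       operator[](int i)       { return data[i]; }
        const T& operator[](int i) const { return data[i]; }
    };

    // Element-wise add, like ops::operator+ in array_ops.h.
    template<typename T, int N>
    Array<T, N> operator+(const Array<T, N>& a, const Array<T, N>& b)
    {
        Array<T, N> c;
        for (int i = 0; i < N; ++i) c[i] = a[i] + b[i];
        return c;
    }

    int main()
    {
        // Vec_t stands in for the packed register type used for the load (e.g. float4 / uint2).
        struct Vec_t { float x, y, z, w; };
        using Vec = Array<float, 4>;
        static_assert(sizeof(Vec_t) == sizeof(Vec), "reinterpretation requires identical size/layout");

        Vec_t q      = {1.f, 2.f, 3.f, 4.f};
        Vec_t q_bias = {10.f, 20.f, 30.f, 40.f};

        // Same move as in the kernel: view the packed type as an element array and add component-wise.
        (Vec&)q = (Vec&)q + (Vec&)q_bias;

        std::printf("%g %g %g %g\n", q.x, q.y, q.z, q.w);
        return 0;
    }
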
diff --git a/src/turbomind/kernels/unfused_attention_kernels.h b/src/turbomind/kernels/unfused_attention_kernels.h
index 846a1b7371..758fe7fba0 100644
--- a/src/turbomind/kernels/unfused_attention_kernels.h
+++ b/src/turbomind/kernels/unfused_attention_kernels.h
@@ -72,6 +72,7 @@ void invokeAddFusedQKVBiasTranspose(T*           q_buf,
                                     const int*   padding_offset,
                                     const int*   context_length,
                                     const int*   input_length,
+                                    const float* rope_theta,
                                     const int    batch_size,
                                     const int    seq_len,
                                     const int    token_num,
diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index d6d94be43d..f46d7ebe35 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -12,11 +12,12 @@
 #include "src/turbomind/utils/Tensor.h"
 #include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/debug_utils.h"
+#include "src/turbomind/utils/gemm_test/gemm_func.h"
 #include "src/turbomind/utils/logger.h"
 #include <algorithm>
+#include <cmath>
 #include <cstdint>
 #include <iomanip>
-#include <math.h>
 #include <mutex>
 #include <numeric>
 #include <sstream>
@@ -59,11 +60,11 @@ void LlamaBatch<T>::RejectInvalidRequests(Requests& stop_reqs, Requests& infer_r
                     ec = Request::kInvalid;
                 }
                 else if (input_length > session_len_) {
-                    ec = Request::kInvalid;
+                    ec = Request::kTooLong;
                 }
                 else if (!r->start_flag) {
                     if (auto seq = sequence_manager_->Get(r->id); seq == nullptr) {
-                        ec = Request::kTooLong;
+                        ec = Request::kInvalid;
                     }
                     else if (get_offset(seq->tokens.size()) + input_length > session_len_) {
                         ec = Request::kTooLong;
@@ -230,7 +231,7 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
             if (rank_ == 0) {
                 const int trunc_output_len = state.seq_len_limit[i] - state.h_context_length[i];
                 TM_LOG_WARNING(
-                    "[initialize] [%ld] total sequence length (%d + %d) exceeds `session_len` (%d), `request_output_len` is truncated to %d",
+                    "[ProcessInferRequests] [%ld] total sequence length (%d + %d) exceeds `session_len` (%d), `request_output_len` is truncated to %d",
                     (long)seq.id,
                     state.h_context_length[i],
                     request_output_len,
@@ -239,7 +240,35 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
             }
         }
 
-        // recover random state HtoD if not a new sequence
+        // compute rope scaling factor
+        if (r->start_flag) {
+            seq.rope_theta      = model_->attn_params_.rotary_embedding_base;
+            auto scaling_factor = 1.f;
+            if (r->inputs[rank_].isExist("rope_scaling_factor")) {  // runtime scaling factor
+                scaling_factor = r->inputs[rank_].getVal<float>("rope_scaling_factor");
+            }
+            else if (model_->attn_params_.rope_scaling_factor >= 1.f) {  // infer by `seq_len_limit`
+                scaling_factor   = model_->attn_params_.rope_scaling_factor;
+                auto max_seq_len = state.seq_len_limit[i];
+                auto max_pos_emb = model_->attn_params_.max_position_embeddings;
+                if (max_seq_len > max_pos_emb) {
+                    scaling_factor = scaling_factor * max_seq_len / max_pos_emb - (scaling_factor - 1);
+                    // scaling_factor = std::max(exp2f(ceilf(log2f((float)max_seq_len / max_pos_emb) + 1.f))
+                    // - 1.f, 1.f);
+                }
+            }
+            if (scaling_factor != 1.f) {
+                float rope_dim = model_->attn_params_.rotary_embedding_dim;
+                seq.rope_theta *= powf(scaling_factor, rope_dim / (rope_dim - 2.f));
+                TM_LOG_INFO("[ProcessInferRequests] %ld rope_scaling_factor: %f, rope_theta = %f",
+                            (long)seq.id,
+                            scaling_factor,
+                            seq.rope_theta);
+            }
+        }
+        state.h_rope_theta[i] = seq.rope_theta;
+
+        // recover device states if not a new sequence
         if (!r->start_flag) {
             Copy((curandState_t*)seq.random_state.data() + 0, 1, (curandState_t*)state.top_k_curand_state);
             Copy((curandState_t*)seq.random_state.data() + 1, 1, (curandState_t*)state.top_p_curand_state);
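
The block above is the NTK-aware RoPE adjustment: a scaling factor s comes either from the per-request `rope_scaling_factor` input or, for the configured path, is inferred from the sequence-length limit as s * L / L_max - (s - 1), and it is then folded into the rotary base as theta' = theta * s^(d / (d - 2)). A small host-side C++ sketch of the same arithmetic, collapsing the two paths into one illustrative helper (names are not from the codebase):

    #include <cmath>
    #include <cstdio>

    // NTK-aware adjustment of the RoPE base, mirroring the rope_theta computation above:
    //   s_eff  = s * max_seq_len / max_pos_emb - (s - 1)   when max_seq_len > max_pos_emb
    //   theta' = theta * s_eff^(d / (d - 2))
    float ntk_scaled_theta(float rope_theta, float scaling_factor, int max_seq_len, int max_pos_emb, float rope_dim)
    {
        if (max_seq_len > max_pos_emb) {
            scaling_factor = scaling_factor * max_seq_len / max_pos_emb - (scaling_factor - 1.f);
        }
        if (scaling_factor != 1.f) {
            rope_theta *= std::pow(scaling_factor, rope_dim / (rope_dim - 2.f));
        }
        return rope_theta;
    }

    int main()
    {
        // e.g. base 10000, configured factor 1.0, 16k requested tokens vs 4k trained positions, head_dim 128
        std::printf("rope_theta = %f\n", ntk_scaled_theta(10000.f, 1.f, 16384, 4096, 128.f));
        return 0;
    }
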
@@ -415,6 +444,7 @@ void LlamaBatch<T>::CopyState(const std::pair<BatchState*, int> _src, const std:
 
     dst->h_context_length[j] = src->h_context_length[i];
     dst->h_finished[j]       = src->h_finished[i];
+    dst->h_rope_theta[j]     = src->h_rope_theta[i];
     dst->seq_len_limit[j]    = src->seq_len_limit[i];
     dst->sequences[j]        = src->sequences[i];
     dst->is_swap_in[j]       = src->is_swap_in[i];
@@ -495,6 +525,8 @@ void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len)
     request_output_ids_lens_ = (int*)allocator_->reMalloc(request_output_ids_lens_, sizeof(int) * batch_size, true);
     request_seqlen_ptrs_     = (int**)allocator_->reMalloc(request_seqlen_ptrs_, sizeof(int*) * batch_size, true);
 
+    rope_theta_ = (float*)allocator_->reMalloc(rope_theta_, sizeof(float) * batch_size, false);
+
     is_allocate_buffer_ = true;
 }
 
@@ -549,7 +581,8 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
         for (auto& s : states_) {
             s.h_context_length =
                 (int*)allocator_->reMalloc(s.h_context_length, sizeof(int) * max_batch_size, false, true);
-            s.h_finished = (bool*)allocator_->reMalloc(s.h_finished, sizeof(bool) * max_batch_size * 2, false, true);
+            s.h_finished   = (bool*)allocator_->reMalloc(s.h_finished, sizeof(bool) * max_batch_size * 2, false, true);
+            s.h_rope_theta = (float*)allocator_->reMalloc(s.h_rope_theta, sizeof(float) * max_batch_size, false, true);
         }
 
         h_seq_limit_len_ =
@@ -613,6 +646,8 @@ void LlamaBatch<T>::FreeBuffer()
         allocator_->free((void**)&request_output_ids_lens_);
         allocator_->free((void**)&request_seqlen_ptrs_);
 
+        allocator_->free((void**)&rope_theta_);
+
         is_allocate_buffer_ = false;
     }
 
@@ -620,6 +655,7 @@ void LlamaBatch<T>::FreeBuffer()
         for (auto& s : states_) {
             allocator_->free((void**)&s.h_context_length, true);
             allocator_->free((void**)&s.h_finished, true);
+            allocator_->free((void**)&s.h_rope_theta, true);
             allocator_->free((void**)&s.output_ids);
         }
         allocator_->free((void**)&h_tmp_k_ptrs_, true);
@@ -792,6 +828,8 @@ auto LlamaBatch<T>::InitializeGeneration() -> GenerationState
     Copy(h_request_output_ids_lens_, batch_size, request_output_ids_lens_);
     Copy(h_request_seqlen_ptrs_, batch_size, request_seqlen_ptrs_);
 
+    Copy(state_->h_rope_theta, batch_size, rope_theta_);
+
     // ! range of step_ [1, 2 * session_len]
     // consider a sequence with context_len == session_len and another sequence with context_len == 1 and
     // request_output_len == session_len - 1 => step_ will loop in [session_len, 2 * session_len)
@@ -851,6 +889,7 @@ bool LlamaBatch<T>::Generate(GenerationState& g)
                            sequence_lengths_,
                            finished_buf_,
                            cu_block_counts_,
+                           rope_theta_,
                            g.step,
                            0,
                            g.sum_seq_len,
@@ -938,6 +977,7 @@ void LlamaBatch<T>::ContextDecode()
     const int context_decode_count = batch_size - base;
 
     Copy(state_->h_context_length, batch_size, context_length_buf_);
+    Copy(state_->h_rope_theta, batch_size, rope_theta_);
     Copy(h_input_length_buf_, batch_size, input_length_buf_);
 
     check_cuda_error(cudaStreamSynchronize(stream_));
@@ -1042,6 +1082,7 @@ void LlamaBatch<T>::ContextDecode()
                               input_length_buf_ + first,
                               context_length_buf_ + first,
                               cu_block_counts_ + first,
+                              rope_theta_ + first,
                               token_count,
                               max_input_len,
                               max_context_cnts[k],
diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h
index 4c8f8154be..4e7c2e7b11 100644
--- a/src/turbomind/models/llama/LlamaBatch.h
+++ b/src/turbomind/models/llama/LlamaBatch.h
@@ -22,6 +22,8 @@ struct BatchState {
     void* top_p_curand_state;
     int*  output_ids;  // output ids in [B, S]
 
+    float* h_rope_theta;
+
     std::vector<int> seq_len_limit;
     std::vector<int> is_swap_in;
 
@@ -180,6 +182,8 @@ class LlamaBatch {
     float* context_logits_buf_{};
     float* local_context_logits_buf_{};
 
+    float* rope_theta_{};
+
     // used by dynamic decoder
     int*      token_ids_buf_{};  // all token IDs in [S, B], indexed using `step`
     int*      end_ids_buf_{};
@@ -194,9 +198,8 @@ class LlamaBatch {
     int** h_request_seqlen_ptrs_{};
 
     // pinned buffers
-    int* h_input_ids_buf_{};
-    int* h_input_length_buf_{};
-    // int*       h_sequence_lengths_{};
+    int*       h_input_ids_buf_{};
+    int*       h_input_length_buf_{};
     uint32_t*  h_seq_limit_len_{};
     int*       h_cu_block_counts_{};
     uintptr_t* h_k_block_ptrs_{};
diff --git a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
index 1a62e2fb77..92fe00dc56 100644
--- a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
@@ -149,6 +149,8 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
     int*       cu_seqlens      = input_tensors->at("cu_seqlens").getPtr<int>();
     int*       cu_block_counts = input_tensors->at("cu_block_counts").getPtr<int>();
 
+    const float* rope_theta = input_tensors->getPtr<const float>("rope_theta", nullptr);
+
     const auto padding_offset = input_tensors->at("padding_offset").getPtr<int>();
 
     auto Show = [&](const T* x, size_t n) {
@@ -179,16 +181,17 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
                                    padding_offset,  // padding_offset,
                                    context_length,  // used for applying rotary embedding
                                    input_length,
+                                   rope_theta,
                                    batch_size,
                                    max_q_len,  // seq_len
                                    num_token,  // batch_size * seq_len
                                    local_head_num_,
                                    local_kv_head_num_,
                                    size_per_head_,
-                                   params_.rotray_embedding_dim,
+                                   params_.rotary_embedding_dim,
                                    params_.rotary_embedding_base,
                                    params_.max_position_embeddings,
-                                   params_.use_dynamic_ntk,
+                                   false,  // params_.use_dynamic_ntk,
                                    params_.use_logn_attn,
                                    stream_);
     sync_check_cuda_error();
diff --git a/src/turbomind/models/llama/LlamaContextDecoder.cc b/src/turbomind/models/llama/LlamaContextDecoder.cc
index 2047ffa050..268ff7ab58 100644
--- a/src/turbomind/models/llama/LlamaContextDecoder.cc
+++ b/src/turbomind/models/llama/LlamaContextDecoder.cc
@@ -114,6 +114,7 @@ void LlamaContextDecoder<T>::forwardSelfAttn(const Session&
         {"input_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.input_length}},
         {"context_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.context_length}},
         {"cu_block_counts", input_tensors->at("cu_block_counts")},
+        {"rope_theta", input_tensors->at("rope_theta")},
         {"max_seq_len", input_tensors->at("max_seq_len")}};
 
     TensorMap self_attention_output_tensors{
diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
index d411f3f412..ecce30072c 100644
--- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
@@ -98,18 +98,14 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
 
     const int layer_id = input_tensors->getVal<int>("layer_id");
 
-    // const int step = input_tensors->getVal<int>("step");
+    const int step = input_tensors->getVal<int>("step");
     // const int step_1 = step - 1;
 
     const int batch_size = input_tensors->at("input_query").shape[0];
 
-    allocateBuffer(batch_size);
+    const float* rope_theta = input_tensors->getPtr<const float>("rope_theta", nullptr);
 
-    // std::vector<int> seqlens(batch_size);
-    // check_cuda_error(
-    //     cudaMemcpyAsync(seqlens.data(), sequence_lengths_data, sizeof(int) * batch_size, cudaMemcpyDefault,
-    //     stream_));
-    // check_cuda_error(cudaStreamSynchronize(stream_));
+    allocateBuffer(batch_size);
 
     // for (int i = 0; i < batch_size; ++i) {
     //     if (gSequenceIds(i) == 1) {
@@ -126,6 +122,10 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
         linear_.forward(qkv_buf_, input_query_data, batch_size, weights->qkv);
     }
 
+    // if (layer_id == 0) {
+    //     Compare(qkv_buf_, batch_size * 3 * hidden_units_, Concat("qkv_buf", step, layer_id), kCmpRead, stream_);
+    // }
+
     const auto layer_offset = layer_id * local_kv_head_num_ * kv_cache_block_len_ * size_per_head_;
     // const int  memory_len   = max_seq_len;
 
@@ -137,6 +137,10 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
     params.v      = params.k + local_kv_head_num_ * size_per_head_;
     params.stride = (local_head_num_ + 2 * local_kv_head_num_) * size_per_head_;
 
+    params.q_bias = weights->qkv.bias;
+    params.k_bias = params.q_bias + local_head_num_ * size_per_head_;
+    params.v_bias = params.k_bias + local_kv_head_num_ * size_per_head_;
+
     params.batch_size    = batch_size;
     params.cu_block_cnts = cu_block_counts;
 
@@ -146,6 +150,7 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
 
     params.finished          = finished_data;
     params.per_sample_length = sequence_lengths_data;
+    params.rope_theta        = rope_theta;
 
     params.layer_offset = layer_offset;
 
@@ -154,8 +159,11 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
     params.size_per_head = size_per_head_;
     params.inv_sqrt_dh   = 1.f / std::sqrt((float)params.size_per_head);
 
-    params.rotary_embedding_dim  = size_per_head_;
-    params.rotary_embedding_base = 10000.f;
+    params.rotary_embedding_dim    = size_per_head_;
+    params.rotary_embedding_base   = params_.rotary_embedding_base;
+    params.max_position_embeddings = params_.max_position_embeddings;
+    // params.use_dynamic_ntk = params_.use_dynamic_ntk;
+    params.use_logn_attn = params_.use_logn_attn;
 
     params.partial_O = workspace_;
     params.partial_M = params.partial_O + batch_size * local_head_num_ * kMaxSplitK * size_per_head_;
@@ -198,6 +206,10 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     o
     //     }
     // }
 
+    // if (layer_id == 0) {
+    //     Compare(context_buf_, batch_size * hidden_units_, Concat("context_buf", step, layer_id), kCmpRead, stream_);
+    // }
+
     {
         NvtxScope scope("o_gemm");
         linear_.forward(hidden_features_data, context_buf_, batch_size, weights->output);
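
For context on the pointer setup above: the fused QKV GEMM writes each token as [Q | K | V], so `params.k`, `params.v` and the matching bias pointers are fixed offsets from the Q pointer, and the per-token stride is (head_num + 2 * kv_head_num) * size_per_head. A tiny C++ sketch of that layout computation (illustrative helper, not part of the codebase):

    #include <cstdio>

    // Offsets (in elements) into a fused QKV buffer laid out per token as [Q | K | V]:
    //   Q occupies head_num    * size_per_head
    //   K occupies kv_head_num * size_per_head
    //   V occupies kv_head_num * size_per_head
    struct QkvLayout {
        int q_offset;  // always 0
        int k_offset;
        int v_offset;
        int stride;  // elements per token
    };

    QkvLayout make_qkv_layout(int head_num, int kv_head_num, int size_per_head)
    {
        QkvLayout l;
        l.q_offset = 0;
        l.k_offset = head_num * size_per_head;
        l.v_offset = l.k_offset + kv_head_num * size_per_head;
        l.stride   = (head_num + 2 * kv_head_num) * size_per_head;
        return l;
    }

    int main()
    {
        // e.g. 32 query heads, 8 KV heads (GQA), head_dim 128
        const QkvLayout l = make_qkv_layout(32, 8, 128);
        std::printf("k_offset = %d, v_offset = %d, stride = %d\n", l.k_offset, l.v_offset, l.stride);
        return 0;
    }
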
diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc
index 53160c8ede..fca323f0ad 100644
--- a/src/turbomind/models/llama/LlamaV2.cc
+++ b/src/turbomind/models/llama/LlamaV2.cc
@@ -74,6 +74,7 @@ LlamaV2<T>::LlamaV2(size_t                       head_num,
     inter_size_(inter_size),
     num_layer_(num_layer),
     vocab_size_(vocab_size),
+    attn_params_(attn_params),
     vocab_size_padded_(vocab_size),
     rmsnorm_eps_(norm_eps),
     start_id_(start_id),
@@ -222,22 +223,23 @@ void LlamaV2<T>::embeddingLookup(T* embeddings, const int* token_ids_buf, int ba
 }
 
 template<typename T>
-void LlamaV2<T>::contextDecode(T*         deocder_output,
-                               uintptr_t* k_cache_ptr,
-                               uintptr_t* v_cache_ptr,
-                               void**     tmp_k_ptrs,
-                               void**     tmp_v_ptrs,
-                               T*         context_decoder_input_buf,
-                               T*         context_decoder_output_buf,
-                               const int* input_ids,
-                               const int* input_length,
-                               const int* context_length,
-                               const int* cu_block_counts,
-                               size_t     token_num,
-                               size_t     max_input_len,
-                               size_t     max_context_len,
-                               size_t     session_len,
-                               size_t     batch_size)
+void LlamaV2<T>::contextDecode(T*           deocder_output,
+                               uintptr_t*   k_cache_ptr,
+                               uintptr_t*   v_cache_ptr,
+                               void**       tmp_k_ptrs,
+                               void**       tmp_v_ptrs,
+                               T*           context_decoder_input_buf,
+                               T*           context_decoder_output_buf,
+                               const int*   input_ids,
+                               const int*   input_length,
+                               const int*   context_length,
+                               const int*   cu_block_counts,
+                               const float* rope_theta,
+                               size_t       token_num,
+                               size_t       max_input_len,
+                               size_t       max_context_len,
+                               size_t       session_len,
+                               size_t       batch_size)
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
 
@@ -274,6 +276,7 @@ void LlamaV2<T>::contextDecode(T*         deocder_output,
         {"max_q_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_q_len}},
         {"max_kv_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_kv_len}},
         {"max_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_seq_len}},
+        {"rope_theta", {MEMORY_GPU, TYPE_FP32, {hidden_units_}, rope_theta}},
         {"cu_block_counts", {MEMORY_GPU, TYPE_INT32, {batch_size}, cu_block_counts}}};
 
     std::unordered_map<std::string, Tensor> decoder_output_tensors{
@@ -292,18 +295,19 @@ void LlamaV2<T>::contextDecode(T*         deocder_output,
 }
 
 template<typename T>
-void LlamaV2<T>::decoderForward(T*          decoder_output,
-                                uintptr_t*  k_cache_ptr,
-                                uintptr_t*  v_cache_ptr,
-                                T*          decoder_input,
-                                const int*  sequence_length,
-                                const bool* finished,
-                                const int*  cu_block_counts,
-                                int         step,
-                                int         ite,
-                                int         sum_seq_len,
-                                int         max_seq_len,
-                                size_t      batch_size)
+void LlamaV2<T>::decoderForward(T*           decoder_output,
+                                uintptr_t*   k_cache_ptr,
+                                uintptr_t*   v_cache_ptr,
+                                T*           decoder_input,
+                                const int*   sequence_length,
+                                const bool*  finished,
+                                const int*   cu_block_counts,
+                                const float* rope_theta,
+                                int          step,
+                                int          ite,
+                                int          sum_seq_len,
+                                int          max_seq_len,
+                                size_t       batch_size)
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
 
@@ -319,6 +323,7 @@ void LlamaV2<T>::decoderForward(T*          decoder_output,
         {"max_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_seq_len}},
         {"finished", {MEMORY_GPU, TYPE_BOOL, {batch_size}, finished}},
         {"output_norm_weight", {MEMORY_GPU, dtype, {hidden_units_}, weights_->output_norm_weight}},
+        {"rope_theta", {MEMORY_GPU, TYPE_FP32, {batch_size}, rope_theta}},
         {"step", {MEMORY_CPU, TYPE_INT32, {1}, &step}},
         {"ite", {MEMORY_CPU, TYPE_INT32, {1}, &ite}},
     };
diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h
index 99d5352746..f26900eaa0 100644
--- a/src/turbomind/models/llama/LlamaV2.h
+++ b/src/turbomind/models/llama/LlamaV2.h
@@ -29,6 +29,7 @@
 #include "src/turbomind/models/llama/LlamaWeight.h"
 #include "src/turbomind/models/llama/Request.h"
 #include "src/turbomind/models/llama/SequenceManager.h"
+#include "src/turbomind/models/llama/llama_params.h"
 #include "src/turbomind/utils/allocator.h"
 #include "src/turbomind/utils/cublasMMWrapper.h"
 #include "src/turbomind/utils/instance_comm.h"
@@ -112,35 +113,37 @@ class LlamaV2 {
 
     void embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step);
 
-    void contextDecode(T*         deocder_output,
-                       uintptr_t* k_block_ptrs,
-                       uintptr_t* v_block_ptrs,
-                       void**     k_tmp_ptrs,
-                       void**     v_tmp_ptrs,
-                       T*         context_decoder_input_buf,
-                       T*         context_decoder_output_buf,
-                       const int* input_ids,
-                       const int* input_length,
-                       const int* context_length,
-                       const int* cu_block_counts,
-                       size_t     token_num,
-                       size_t     max_input_len,
-                       size_t     max_context_len,
-                       size_t     session_len,
-                       size_t     batch_size);
-
-    void decoderForward(T*          decoder_output,
-                        uintptr_t*  k_cache_ptr,
-                        uintptr_t*  v_cache_ptr,
-                        T*          decoder_input,
-                        const int*  sequence_length,
-                        const bool* finished,
-                        const int*  cu_block_counts,
-                        int         step,
-                        int         ite,
-                        int         sum_seq_len,
-                        int         max_seq_len,
-                        size_t      batch_size);
+    void contextDecode(T*           deocder_output,
+                       uintptr_t*   k_block_ptrs,
+                       uintptr_t*   v_block_ptrs,
+                       void**       k_tmp_ptrs,
+                       void**       v_tmp_ptrs,
+                       T*           context_decoder_input_buf,
+                       T*           context_decoder_output_buf,
+                       const int*   input_ids,
+                       const int*   input_length,
+                       const int*   context_length,
+                       const int*   cu_block_counts,
+                       const float* rope_theta,
+                       size_t       token_num,
+                       size_t       max_input_len,
+                       size_t       max_context_len,
+                       size_t       session_len,
+                       size_t       batch_size);
+
+    void decoderForward(T*           decoder_output,
+                        uintptr_t*   k_cache_ptr,
+                        uintptr_t*   v_cache_ptr,
+                        T*           decoder_input,
+                        const int*   sequence_length,
+                        const bool*  finished,
+                        const int*   cu_block_counts,
+                        const float* rope_theta,
+                        int          step,
+                        int          ite,
+                        int          sum_seq_len,
+                        int          max_seq_len,
+                        size_t       batch_size);
 
     void postDecodeEmbedding(float* logits, float* local_logits, const T* decoder_output, int batch_size);
 
@@ -181,6 +184,8 @@ class LlamaV2 {
     size_t       vocab_size_padded_;
     float        rmsnorm_eps_ = 1e-6f;
 
+    const LlamaAttentionParams attn_params_;
+
     static constexpr bool neox_rotary_style_ = false;
 
     const int    start_id_;
diff --git a/src/turbomind/models/llama/SequenceManager.cc b/src/turbomind/models/llama/SequenceManager.cc
index 12f982be26..6c2778daa1 100644
--- a/src/turbomind/models/llama/SequenceManager.cc
+++ b/src/turbomind/models/llama/SequenceManager.cc
@@ -36,7 +36,7 @@ SequenceManager::SequenceManager(size_t      layer_num,
 
 const Sequence* SequenceManager::Create(uint64_t id)
 {
-    Sequence sequence{id, {}, {}, {}, {}, {}};
+    Sequence sequence{id, {}, {}, {}, {}, {}, {}, 0.f};
 
     auto it = sequences_.find(id);
     if (it != sequences_.end()) {
diff --git a/src/turbomind/models/llama/SequenceManager.h b/src/turbomind/models/llama/SequenceManager.h
index 8800149bf1..be99e120e3 100644
--- a/src/turbomind/models/llama/SequenceManager.h
+++ b/src/turbomind/models/llama/SequenceManager.h
@@ -27,6 +27,8 @@ struct Sequence {
     // additional data kept round-to-round
     mutable std::vector<std::byte> random_state;  // update by user
 
+    mutable float rope_theta;
+
     friend std::ostream& operator<<(std::ostream& os, const Sequence& seq);
 };
 
diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h
index 8f8c96837b..78b1570f02 100644
--- a/src/turbomind/models/llama/llama_params.h
+++ b/src/turbomind/models/llama/llama_params.h
@@ -5,10 +5,11 @@
 namespace turbomind {
 
 struct LlamaAttentionParams {
-    int   rotray_embedding_dim;
+    int   rotary_embedding_dim;
     float rotary_embedding_base;
     int   max_position_embeddings;
-    bool  use_dynamic_ntk;
+    float rope_scaling_factor;
+    // bool  use_dynamic_ntk;
     bool  use_logn_attn;
 };
 
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
index beab5d7d94..3a60896a59 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -74,6 +74,12 @@ void LlamaTritonModel<T>::handleMissingParams()
         TM_LOG_WARNING("[LlamaTritonModel] `session_len` is not set, default to %d.", (int)session_len_);
     }
 
+    if (!attn_params_.max_position_embeddings) {
+        attn_params_.max_position_embeddings = session_len_;
+        TM_LOG_WARNING("[LlamaTritonModel] `max_position_embeddings` is not set, default to `session_len` (%d).",
+                       (int)attn_params_.max_position_embeddings);
+    }
+
     if (!max_context_token_num_) {
         max_context_token_num_ = (int)std::sqrt(max_batch_size_);
         TM_LOG_WARNING("[LlamaTritonModel] `max_context_token_num` is not set, default to %d.",
@@ -142,10 +148,12 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t      tensor_para_size,
     quant_policy_ = reader.GetInteger("llama", "quant_policy", 0);
     group_size_   = reader.GetInteger("llama", "group_size", 0);
 
-    attn_params_.rotray_embedding_dim    = reader.GetInteger("llama", "rotary_embedding");
+    // rotary embedding parameters
+    attn_params_.rotary_embedding_dim    = reader.GetInteger("llama", "rotary_embedding");
     attn_params_.rotary_embedding_base   = reader.GetFloat("llama", "rope_theta", 10000.0f);
+    attn_params_.rope_scaling_factor     = reader.GetFloat("llama", "rope_scaling_factor", 0.f);
     attn_params_.max_position_embeddings = reader.GetInteger("llama", "max_position_embeddings", 0);
-    attn_params_.use_dynamic_ntk         = reader.GetInteger("llama", "use_dynamic_ntk", 0);
+    // attn_params_.use_dynamic_ntk         = reader.GetInteger("llama", "use_dynamic_ntk", 0);
     attn_params_.use_logn_attn           = reader.GetInteger("llama", "use_logn_attn", 0);
 
     handleMissingParams();

From 2e08a0bdb04fadb8de790d26c2e46ad3fa7141e2 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Tue, 31 Oct 2023 05:00:50 +0000
Subject: [PATCH 40/56] cmake flags

---
 CMakeLists.txt | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index a004d76af5..2372875c41 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -196,7 +196,8 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda")
 set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
 set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++${CXX_STD} -DCUDA_PTX_FP8_F2FP_ENABLED")
 
-set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
+set(CMAKE_CXX_FLAGS_RELEASE        "${CMAKE_CXX_FLAGS_RELEASE}        -O3")
+set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -O3")
 # set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 --ptxas-options=--verbose")
 set(CMAKE_CUDA_FLAGS_RELEASE        "${CMAKE_CUDA_FLAGS_RELEASE}        -Xcompiler -O3 -DCUDA_PTX_FP8_F2FP_ENABLED")
 set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -Xcompiler -O3 -DCUDA_PTX_FP8_F2FP_ENABLED")
@@ -273,12 +274,14 @@ print(torch._C._GLIBCXX_USE_CXX11_ABI,end='');"
   if (USE_CXX11_ABI)
     set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=1")
     set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=1")
+    set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=1")
     set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=1")
     set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=1")
     set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=1")
   else()
     set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=0")
     set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=0")
+    set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=0")
     set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=0")
     set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=0")
     set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=0")

From 44782a1c03700eea414b29a13b9a534f88523745 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Thu, 2 Nov 2023 04:24:35 +0000
Subject: [PATCH 41/56] fix typo

---
 .../kernels/decoder_multihead_attention/array_ops.h  |  2 +-
 .../decoder_multihead_attention_template.h           | 12 ++++++------
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/turbomind/kernels/decoder_multihead_attention/array_ops.h b/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
index 209da7e71d..4885284a54 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
@@ -146,7 +146,7 @@ struct LogNScaling {
             return 1.f;
         }
         else {
-            return log2(seq_len) / log2(max_position_embeddings);
+            return log2f(seq_len) / log2f(max_position_embeddings);
         }
     }
 
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
index ae82a8b786..7844232045 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
@@ -289,8 +289,8 @@ struct DecoderMultiHeadAttentionKernel {
             Store(&smem_O_[qi * kMaxHeadDim + offset.x], cast<float>(frag_V));
         }
 
-        auto farg_K_store = conv_k_store_(frag_K);
-        auto farg_V_store = conv_v_store_(frag_V);
+        auto frag_K_store = conv_k_store_(frag_K);
+        auto frag_V_store = conv_v_store_(frag_V);
 
         // store
         if (warp_id_ == 0 && is_gqa_leader_) {
@@ -304,12 +304,12 @@ struct DecoderMultiHeadAttentionKernel {
                            + kv_head_idx_ * params_.kv_cache_block_size * kHeadDim;
                 v_cache_ = (Tkv*)v_cache_ptrs_[block_index] + params_.layer_offset
                            + kv_head_idx_ * params_.kv_cache_block_size * kHeadDim;
-                Store(&k_cache_[block_offset * kHeadDim + offset.x], farg_K_store);
-                Store(&v_cache_[block_offset * kHeadDim + offset.x], farg_V_store);
+                Store(&k_cache_[block_offset * kHeadDim + offset.x], frag_K_store);
+                Store(&v_cache_[block_offset * kHeadDim + offset.x], frag_V_store);
             }
             else {
-                Store(&k_cache_[timestep_ * kHeadDim + offset.x], farg_K_store);
-                Store(&v_cache_[timestep_ * kHeadDim + offset.x], farg_V_store);
+                Store(&k_cache_[timestep_ * kHeadDim + offset.x], frag_K_store);
+                Store(&v_cache_[timestep_ * kHeadDim + offset.x], frag_V_store);
             }
         }
     }

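Note on the `LogNScaling` factor touched above: this is the standard log-n attention scaling, where a query past the trained window `max_position_embeddings` is scaled by log(seq_len) / log(max_position_embeddings); the base cancels in the ratio, so switching from `log2` to `log2f` only avoids an implicit double-precision round trip on the device. A minimal host-side sketch of the rule, assuming the guard not shown in the hunk is the usual "still inside the trained window" check (`log_n_scale` is an illustrative name, not the kernel's interface):

#include <cmath>

// Hedged sketch of the LogN scaling rule; not the kernel code.
inline float log_n_scale(int seq_len, int max_position_embeddings)
{
    if (seq_len <= max_position_embeddings) {
        return 1.f;  // assumed guard: no rescaling inside the trained window
    }
    return std::log2(float(seq_len)) / std::log2(float(max_position_embeddings));
}
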
From 39c1a87a8dd8fb672c8537f20459e4b13ae21301 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Thu, 2 Nov 2023 05:49:06 +0000
Subject: [PATCH 42/56] w4a16 for sm75

---
 .../kernels/gemm_s_f16/cta_iterator.h         | 19 ++++++++++++++++
 .../kernels/gemm_s_f16/gemm_template.h        | 22 ++++++++++++++++++-
 2 files changed, 40 insertions(+), 1 deletion(-)

diff --git a/src/turbomind/kernels/gemm_s_f16/cta_iterator.h b/src/turbomind/kernels/gemm_s_f16/cta_iterator.h
index 48cf9ace2c..0c13ae3116 100644
--- a/src/turbomind/kernels/gemm_s_f16/cta_iterator.h
+++ b/src/turbomind/kernels/gemm_s_f16/cta_iterator.h
@@ -3,6 +3,7 @@
 #pragma once
 
 #include "common.h"
+#include <cstddef>
 #include <cstdint>
 
 namespace turbomind {
@@ -236,7 +237,13 @@ struct IteratorA {
 
     __device__ void prefetch(bool mask)
     {
+#if TURBOMIND_ARCH_SM80
         cp_async_cg_A(smem_int_ptr_ + dst_offset_, (const AccessType*)src_ + src_offset_, mask);
+#else
+        if (mask) {
+            *(AccessType*)((uint8_t*)smem_ + dst_offset_) = __ldg((const AccessType*)src_ + src_offset_);
+        }
+#endif
     }
 };
 
@@ -417,7 +424,13 @@ struct IteratorQ {
 
     __device__ void prefetch(bool mask)
     {
+#if TURBOMIND_ARCH_SM80
         cp_async_ca(smem_int_ptr_ + dst_offset_, (const AccessType*)src_ + src_offset_, mask);
+#else
+        if (mask) {
+            *(AccessType*)((uint8_t*)smem_ + dst_offset_) = __ldg((const AccessType*)src_ + src_offset_);
+        }
+#endif
     }
 };
 
@@ -613,8 +626,14 @@ struct IteratorB {
 
     __device__ void prefetch(bool mask)
     {
+#if TURBOMIND_ARCH_SM80
         cp_async_cg_B(
             smem_int_ptr_ + tmp_dst_offset_, (const AccessType*)(src_ + tmp_src_offset_), is_valid_n_ && mask);
+#else
+        if (is_valid_n_ && mask) {
+            *(AccessType*)((uint8_t*)smem_ + tmp_dst_offset_) = __ldg((const AccessType*)(src_ + tmp_src_offset_));
+        }
+#endif
     }
 };
 
diff --git a/src/turbomind/kernels/gemm_s_f16/gemm_template.h b/src/turbomind/kernels/gemm_s_f16/gemm_template.h
index a429ba8536..0e3e9bca9d 100644
--- a/src/turbomind/kernels/gemm_s_f16/gemm_template.h
+++ b/src/turbomind/kernels/gemm_s_f16/gemm_template.h
@@ -9,6 +9,23 @@
 
 namespace turbomind {
 
+__inline__ __device__ void
+mma_m16n8k8_row_col(Array<float, 4>& d, const Array<half, 4>& a, const Array<half, 2>& b, Array<float, 4>& c)
+{
+#if TURBOMIND_ARCH_SM75
+    uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);
+    uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);
+    float const*    C = reinterpret_cast<float const*>(&c);
+    float*          D = reinterpret_cast<float*>(&d);
+    asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32  {%0,%1,%2,%3}, "
+        "{%4,%5}, {%6}, {%7,%8,%9,%10};\n"
+        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
+        : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
+#else
+    assert(TURBOMIND_ARCH_SM75);
+#endif
+}
+
 __inline__ __device__ void
 mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<half, 4>& b, Array<float, 4>& c)
 {
@@ -22,7 +39,10 @@ mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<ha
         : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
         : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
 #else
-    assert(TURBOMIND_ARCH_SM80);
+    const Array<half, 4>* _a = (const Array<half, 4>*)&a;
+    const Array<half, 2>* _b = (const Array<half, 2>*)&b;
+    mma_m16n8k8_row_col(d, _a[0], _b[0], c);
+    mma_m16n8k8_row_col(d, _a[1], _b[1], d);
 #endif
 }
 

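Two pre-SM80 fallbacks land in this patch: the iterators' `prefetch` replaces `cp.async` with a plain predicated copy, and `mma_m16n8k16_row_col` is emulated on SM75 by chaining two m16n8k8 MMAs over the split A/B fragments. The equivalence is plain split-K accumulation, D = A[:, 0:8] * B[0:8, :] + A[:, 8:16] * B[8:16, :] + C. A hedged host-side reference (row-major floats and the illustrative name `gemm_acc`, not tensor-core code) demonstrating that identity:

// Accumulating GEMM reference: D += A (M x K, leading dimension lda) * B (K x N, leading dimension ldb).
void gemm_acc(const float* A, const float* B, float* D, int M, int N, int K, int lda, int ldb)
{
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            float acc = D[m * N + n];  // D starts out holding C
            for (int k = 0; k < K; ++k) {
                acc += A[m * lda + k] * B[k * ldb + n];
            }
            D[m * N + n] = acc;
        }
    }
}

// With D initialised to C, two K=8 passes give the same result as one K=16 pass, which is
// what the SM75 branch does with two chained mma_m16n8k8_row_col calls:
//   gemm_acc(A,     B,         D, 16, 8, 8, /*lda=*/16, /*ldb=*/8);   // K rows 0..7
//   gemm_acc(A + 8, B + 8 * 8, D, 16, 8, 8, /*lda=*/16, /*ldb=*/8);   // K rows 8..15
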
From c8eedeff92a13bd910a14e5348ff2e59a3a8c9c1 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Thu, 2 Nov 2023 05:52:28 +0000
Subject: [PATCH 43/56] fix msvc build

---
 src/turbomind/models/llama/Barrier.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/turbomind/models/llama/Barrier.h b/src/turbomind/models/llama/Barrier.h
index 6eb0df9585..e34c42e6ce 100644
--- a/src/turbomind/models/llama/Barrier.h
+++ b/src/turbomind/models/llama/Barrier.h
@@ -2,6 +2,7 @@
 
 #pragma once
 
+#include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/logger.h"
 #ifndef _MSC_VER
 #include <pthread.h>

From 6de4a376e1be46837d1533f2bf785a7e1de7e7f9 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Thu, 2 Nov 2023 06:48:29 +0000
Subject: [PATCH 44/56] fix msvc build

---
 .../test_decoder_multihead_attention.cu         | 17 +++++++++--------
 .../decoder_multihead_attention/test_utils.h    |  1 +
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
index e4636bcea0..f93f8ce466 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
@@ -10,6 +10,7 @@
 
 #include <iomanip>
 #include <numeric>
+#include <random>
 
 using namespace turbomind;
 
@@ -106,15 +107,15 @@ int main(int argc, char* argv[])
 
     DecoderMultiHeadAttentionParams<half> params{};
 
-    constexpr int kHeadNum    = 32;
-    constexpr int kHeadDim    = 128;
-    constexpr int KvHeadNum   = 32;
-    constexpr int kBatchSize  = 32;
-    constexpr int kContextLen = 7306;
-    // constexpr int kContextLen  = 1024;
+    constexpr int kHeadNum   = 32;
+    constexpr int kHeadDim   = 128;
+    constexpr int KvHeadNum  = 32;
+    constexpr int kBatchSize = 1;
+    // constexpr int kContextLen = 7306;
+    constexpr int kContextLen  = 1024;
     constexpr int kSequenceLen = kContextLen + 1;
     constexpr int kBlockSz     = 128;
-    constexpr int kTestIter    = 1;
+    constexpr int kTestIter    = 10;
     constexpr int kMaxSplitK   = 1;
 
     RNG rng{};
@@ -256,7 +257,7 @@ int main(int argc, char* argv[])
 
     std::vector<thrust::universal_vector<half>> outputs;
 
-    for (int i = 0; i < std::max(kTestIter, 10); ++i) {
+    for (int i = 0; i < std::max(kTestIter, 1); ++i) {
         DispatchDecoderMultiheadAttention<half>(params);
         if (auto err = cudaGetLastError(); err != cudaSuccess) {
             std::cout << cudaGetErrorString(err) << "\n";
diff --git a/src/turbomind/kernels/decoder_multihead_attention/test_utils.h b/src/turbomind/kernels/decoder_multihead_attention/test_utils.h
index ecfedcb53f..35caf5f036 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/test_utils.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/test_utils.h
@@ -3,6 +3,7 @@
 #pragma once
 
 #include "decoder_multihead_attention.h"
+#include "src/turbomind/macro.h"
 #include <cuda_fp16.h>
 #include <memory>
 

From 86f60c345be3efafbe9dbeb39008cdf6caac11e3 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Thu, 2 Nov 2023 07:44:18 +0000
Subject: [PATCH 45/56] fix block verification

---
 src/turbomind/models/llama/BlockManager.cc    |  2 +-
 src/turbomind/models/llama/SequenceManager.cc | 24 ++++++++++---------
 2 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/src/turbomind/models/llama/BlockManager.cc b/src/turbomind/models/llama/BlockManager.cc
index d04fd604b0..6ce8497eae 100644
--- a/src/turbomind/models/llama/BlockManager.cc
+++ b/src/turbomind/models/llama/BlockManager.cc
@@ -216,7 +216,7 @@ int BlockManager::Lock(const std::vector<const Block*>& bs)
 
     Move(cached_ids_, idxs, active_ids_);
 
-    dbg(cached_ids_, active_ids_);
+    // dbg(cached_ids_, active_ids_);
 
     return idxs.size();
 }
diff --git a/src/turbomind/models/llama/SequenceManager.cc b/src/turbomind/models/llama/SequenceManager.cc
index 6c2778daa1..c4d97ff539 100644
--- a/src/turbomind/models/llama/SequenceManager.cc
+++ b/src/turbomind/models/llama/SequenceManager.cc
@@ -97,9 +97,6 @@ bool SequenceManager::Erase(uint64_t id)
 
 void SequenceManager::VerifyAndLockCached(const Sequences& sequences)
 {
-    if (!need_verify_) {
-        return;
-    }
     std::vector<const Block*> blocks;
     for (const auto& p : sequences) {
         auto& seq = const_cast<Sequence&>(*p);
@@ -107,11 +104,13 @@ void SequenceManager::VerifyAndLockCached(const Sequences& sequences)
             continue;
         }
         FT_CHECK(seq.blocks.size() == seq.block_unique_ids.size());
-        for (int i = 0; i < seq.blocks.size(); ++i) {
-            if (seq.blocks[i]->unique_id != seq.block_unique_ids[i]) {
-                seq.blocks.resize(i);
-                seq.block_unique_ids.resize(i);
-                break;
+        if (need_verify_) {
+            for (int i = 0; i < seq.blocks.size(); ++i) {
+                if (seq.blocks[i]->unique_id != seq.block_unique_ids[i]) {
+                    seq.blocks.resize(i);
+                    seq.block_unique_ids.resize(i);
+                    break;
+                }
             }
         }
         blocks.insert(blocks.end(), seq.blocks.begin(), seq.blocks.end());
@@ -334,7 +333,7 @@ void SequenceManager::AssignAndActivate(const Sequences&                 sequenc
     for (int i = 0; i < sequences.size(); ++i) {
         auto& s     = const_cast<Sequence&>(*sequences[i]);
         auto  count = counts[i];
-        dbg(count);
+        // dbg(count);
         auto last = first + count;
         std::for_each(first, last, [&](const Block* b) {
             s.blocks.push_back(b);
@@ -417,8 +416,11 @@ auto SequenceManager::Materialize(Sequences                    sequences,
     }
 
     // allocate & assign blocks
-    if (schedule.allocate) {
-        auto blocks = block_manager_->Allocate(schedule.allocate);
+    {
+        std::vector<const Block*> blocks;
+        if (schedule.allocate) {
+            blocks = block_manager_->Allocate(schedule.allocate);
+        }
         AssignAndActivate(schedule.active, schedule.block_counts, blocks);
     }
 

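The rule behind the restructured VerifyAndLockCached is that a sequence re-entering the batch may only keep the longest prefix of its cached blocks whose recorded unique ids still match; anything after the first mismatch has been recycled by the block manager and is dropped before the remaining blocks are locked. The Materialize change additionally lets AssignAndActivate run even when nothing new is allocated. A hedged sketch of the truncation step, reusing the patch's field names but not its code:

#include <cstdint>
#include <vector>

struct Block {  // reduced to the one field the check needs
    uint64_t unique_id;
};

// Keep only the longest verified prefix of a sequence's cached blocks.
inline void truncate_at_first_mismatch(std::vector<const Block*>& blocks,
                                       std::vector<uint64_t>&     block_unique_ids)
{
    size_t keep = 0;
    while (keep < blocks.size() && blocks[keep]->unique_id == block_unique_ids[keep]) {
        ++keep;
    }
    blocks.resize(keep);
    block_unique_ids.resize(keep);
}
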
From bce90b398dbac0a45fda3d74a7046087729a1efe Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Thu, 2 Nov 2023 08:41:49 +0000
Subject: [PATCH 46/56] fix msvc build

---
 .../test_decoder_multihead_attention.cu                         | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
index f93f8ce466..14c64dba57 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
@@ -8,9 +8,9 @@
 #include <iostream>
 #include <thrust/universal_vector.h>
 
+#include <algorithm>
 #include <iomanip>
 #include <numeric>
-#include <random>
 
 using namespace turbomind;
 

From 683b1b9157a45061fcc094008af4beee72837175 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Thu, 2 Nov 2023 10:24:23 +0000
Subject: [PATCH 47/56] use `std::shuffle`

---
 .../test_decoder_multihead_attention.cu                      | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
index 14c64dba57..b8148b325d 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
@@ -11,6 +11,7 @@
 #include <algorithm>
 #include <iomanip>
 #include <numeric>
+#include <random>
 
 using namespace turbomind;
 
@@ -48,7 +49,9 @@ void TestBlocks(thrust::universal_vector<half>&  linear,          // linear data
     std::vector<size_t> idxs(batch_size * n_blocks);
     std::iota(idxs.begin(), idxs.end(), 0);
 
-    std::random_shuffle(idxs.begin(), idxs.end());
+    std::random_device rd;
+    std::mt19937       g(rd());
+    std::shuffle(idxs.begin(), idxs.end(), g);
 
     for (int i = 0; i < idxs.size(); ++i) {
         ptrs[i] = blocks.data().get() + idxs[i] * head_num * block_size * head_dim;

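`std::random_shuffle` was deprecated in C++14 and removed in C++17, hence the switch to `std::shuffle` with an explicit engine. If deterministic test runs were ever preferred, the engine could be seeded with a constant instead of `std::random_device`; a small sketch under that assumption (the fixed seed and helper name are illustrative, not what the test does):

#include <algorithm>
#include <numeric>
#include <random>
#include <vector>

std::vector<size_t> make_shuffled_indices(size_t n)
{
    std::vector<size_t> idxs(n);
    std::iota(idxs.begin(), idxs.end(), size_t{0});
    std::mt19937 g{42};  // fixed seed: the same block permutation on every run
    std::shuffle(idxs.begin(), idxs.end(), g);
    return idxs;
}
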
From 5563b2644be6b0f3db256687488f56f8395af922 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Thu, 2 Nov 2023 11:32:57 +0000
Subject: [PATCH 48/56] fix lint

---
 .gitignore                                                    | 2 +-
 CMakeLists.txt                                                | 2 +-
 .../kernels/decoder_multihead_attention/CMakeLists.txt        | 4 ++--
 src/turbomind/kernels/decoder_multihead_attention/array_ops.h | 4 ++--
 .../decoder_multihead_attention/decoder_multihead_attention.h | 2 +-
 .../decoder_multihead_attention_params.h                      | 2 +-
 .../decoder_multihead_attention_template.h                    | 2 +-
 src/turbomind/kernels/decoder_multihead_attention/iterator.h  | 2 +-
 src/turbomind/kernels/decoder_multihead_attention/kv_cache.h  | 2 +-
 .../test_decoder_multihead_attention.cu                       | 2 +-
 .../kernels/decoder_multihead_attention/test_utils.cu         | 2 +-
 src/turbomind/models/llama/BlockManager.cc                    | 2 +-
 src/turbomind/models/llama/BlockManager.h                     | 2 +-
 src/turbomind/models/llama/CMakeLists.txt                     | 1 -
 src/turbomind/models/llama/SequenceManager.cc                 | 2 +-
 src/turbomind/models/llama/SequenceManager.h                  | 2 +-
 src/turbomind/models/llama/test_cache_manager.cc              | 2 +-
 src/turbomind/utils/allocator.h                               | 2 +-
 18 files changed, 19 insertions(+), 20 deletions(-)

diff --git a/.gitignore b/.gitignore
index 1ed335fe88..79a716bd9d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -73,4 +73,4 @@ work_dir*/
 *.csv
 *.pkl
 
-!CMakeLists.txt
\ No newline at end of file
+!CMakeLists.txt
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 2372875c41..5a074ecf22 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -61,7 +61,7 @@ option(SPARSITY_SUPPORT "Build project with Ampere sparsity feature support" OFF
 
 option(BUILD_FAST_MATH "Build in fast math mode" ON)
 
-# the environment variable 
+# the environment variable
 #   ASAN_OPTIONS=protect_shadow_gap=0,intercept_tls_get_addr=0
 # must be set at runtime
 # https://github.com/google/sanitizers/issues/1322
diff --git a/src/turbomind/kernels/decoder_multihead_attention/CMakeLists.txt b/src/turbomind/kernels/decoder_multihead_attention/CMakeLists.txt
index fe67d11f0a..61e5245ffc 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/CMakeLists.txt
+++ b/src/turbomind/kernels/decoder_multihead_attention/CMakeLists.txt
@@ -10,7 +10,7 @@ target_link_libraries(decoder_multihead_attention PRIVATE nvidia::cutlass::cutla
 add_executable(test_decoder_multihead_attention test_utils.cu test_decoder_multihead_attention.cu)
 # target_compile_options(test_decoder_multihead_attention PRIVATE
 #   --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr)
-target_link_libraries(test_decoder_multihead_attention PRIVATE 
-    decoder_multihead_attention 
+target_link_libraries(test_decoder_multihead_attention PRIVATE
+    decoder_multihead_attention
     decoder_masked_multihead_attention
     cublas)
diff --git a/src/turbomind/kernels/decoder_multihead_attention/array_ops.h b/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
index 4885284a54..5a1300ff2d 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/array_ops.h
@@ -386,7 +386,7 @@ struct ConvertKvCache<Ti, int8_t> {
         Array<int8_t, N> vo;
         PRAGMA_UNROLL
         for (int i = 0; i < N; ++i) {
-            // convert to unsigned int by offseting +128
+            // convert to unsigned int by offsetting +128
             (uint8_t&)vo[i] = round(((float)vi[i] - zero_) / scale_ + 128.f);
         }
         return vo;
@@ -487,4 +487,4 @@ struct ConvertKvCache<int8_t, half> {
     }
 };
 
-}  // namespace turbomind
\ No newline at end of file
+}  // namespace turbomind
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.h
index 984dde4fe2..5f7024c49c 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.h
@@ -9,4 +9,4 @@ namespace turbomind {
 template<typename T>
 void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& params);
 
-}
\ No newline at end of file
+}
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
index add5a7161c..e77d38d759 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
@@ -68,4 +68,4 @@ struct DecoderMultiHeadAttentionParams {
     cudaStream_t stream;
 };
 
-}  // namespace turbomind
\ No newline at end of file
+}  // namespace turbomind
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
index 7844232045..8977cc07a1 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
@@ -116,7 +116,7 @@ struct DecoderMultiHeadAttentionKernel {
 
     __device__ bool thread0()
     {
-        return blockIdx.x == 0 && threadIdx.x == 0;
+        return blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0;
     }
 
     __device__ DecoderMultiHeadAttentionKernel(const ParamType& params, SharedStorage& smem, uint8_t* dsmem):
diff --git a/src/turbomind/kernels/decoder_multihead_attention/iterator.h b/src/turbomind/kernels/decoder_multihead_attention/iterator.h
index 5e0ba7f885..683d95b589 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/iterator.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/iterator.h
@@ -330,4 +330,4 @@ struct Iterator {
     }
 };
 
-}  // namespace turbomind
\ No newline at end of file
+}  // namespace turbomind
diff --git a/src/turbomind/kernels/decoder_multihead_attention/kv_cache.h b/src/turbomind/kernels/decoder_multihead_attention/kv_cache.h
index 7ca12db3b5..f72d58c135 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/kv_cache.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/kv_cache.h
@@ -64,4 +64,4 @@ void ConvertKvCacheBlocksToLinear2(const void** src_k_block_ptrs,
                                    const float* kv_params,
                                    cudaStream_t st);
 
-}  // namespace turbomind
\ No newline at end of file
+}  // namespace turbomind
diff --git a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
index b8148b325d..b9ba215864 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
@@ -329,4 +329,4 @@ int main(int argc, char* argv[])
             KvHeadNum);
 
     return 0;
-}
\ No newline at end of file
+}
diff --git a/src/turbomind/kernels/decoder_multihead_attention/test_utils.cu b/src/turbomind/kernels/decoder_multihead_attention/test_utils.cu
index c3fb0d77bc..7660d3860b 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/test_utils.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/test_utils.cu
@@ -239,4 +239,4 @@ void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& p, cudaStream_t
 
 template void mmha_ft_reference(const DecoderMultiHeadAttentionParams<half>& params, cudaStream_t st);
 
-}  // namespace turbomind
\ No newline at end of file
+}  // namespace turbomind
diff --git a/src/turbomind/models/llama/BlockManager.cc b/src/turbomind/models/llama/BlockManager.cc
index 6ce8497eae..2738e674e3 100644
--- a/src/turbomind/models/llama/BlockManager.cc
+++ b/src/turbomind/models/llama/BlockManager.cc
@@ -261,4 +261,4 @@ std::ostream& operator<<(std::ostream& os, const Block& block)
     return os;
 }
 
-}  // namespace turbomind
\ No newline at end of file
+}  // namespace turbomind
diff --git a/src/turbomind/models/llama/BlockManager.h b/src/turbomind/models/llama/BlockManager.h
index 984b0446a0..da3e53ee54 100644
--- a/src/turbomind/models/llama/BlockManager.h
+++ b/src/turbomind/models/llama/BlockManager.h
@@ -129,4 +129,4 @@ class BlockManager {
     uint64_t timestamp_{1};
 };
 
-}  // namespace turbomind
\ No newline at end of file
+}  // namespace turbomind
diff --git a/src/turbomind/models/llama/CMakeLists.txt b/src/turbomind/models/llama/CMakeLists.txt
index 0b083ad33e..a8058dce85 100644
--- a/src/turbomind/models/llama/CMakeLists.txt
+++ b/src/turbomind/models/llama/CMakeLists.txt
@@ -59,4 +59,3 @@ if (Catch2_FOUND)
         add_executable(test_cache_manager test_cache_manager.cc)
         target_link_libraries(test_cache_manager PRIVATE Llama Catch2::Catch2WithMain)
 endif ()
-
diff --git a/src/turbomind/models/llama/SequenceManager.cc b/src/turbomind/models/llama/SequenceManager.cc
index c4d97ff539..1c7a221b04 100644
--- a/src/turbomind/models/llama/SequenceManager.cc
+++ b/src/turbomind/models/llama/SequenceManager.cc
@@ -434,4 +434,4 @@ auto SequenceManager::Materialize(Sequences                    sequences,
     return outcome;
 }
 
-}  // namespace turbomind
\ No newline at end of file
+}  // namespace turbomind
diff --git a/src/turbomind/models/llama/SequenceManager.h b/src/turbomind/models/llama/SequenceManager.h
index be99e120e3..8095cd7ae2 100644
--- a/src/turbomind/models/llama/SequenceManager.h
+++ b/src/turbomind/models/llama/SequenceManager.h
@@ -132,4 +132,4 @@ inline std::ostream& operator<<(std::ostream& os, const SequenceManager::Outcome
     return os;
 }
 
-}  // namespace turbomind
\ No newline at end of file
+}  // namespace turbomind
diff --git a/src/turbomind/models/llama/test_cache_manager.cc b/src/turbomind/models/llama/test_cache_manager.cc
index 99fbc34c5f..16629565f1 100644
--- a/src/turbomind/models/llama/test_cache_manager.cc
+++ b/src/turbomind/models/llama/test_cache_manager.cc
@@ -113,4 +113,4 @@ TEST_CASE("SequenceManager functional test")
             dbg(i, outcome);
         }
     }
-}
\ No newline at end of file
+}
diff --git a/src/turbomind/utils/allocator.h b/src/turbomind/utils/allocator.h
index 1cebb33a00..1ba191d211 100644
--- a/src/turbomind/utils/allocator.h
+++ b/src/turbomind/utils/allocator.h
@@ -475,4 +475,4 @@ class Allocator<AllocatorType::TH>: public IAllocator {
     }
 };
 #endif
-}  // namespace turbomind
\ No newline at end of file
+}  // namespace turbomind

From 8936413491457d69f2b2792288dcf318a7b9da31 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Thu, 2 Nov 2023 11:36:42 +0000
Subject: [PATCH 49/56] fix lint

---
 src/turbomind/kernels/CMakeLists.txt | 2 +-
 src/turbomind/utils/debug_utils.h    | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/turbomind/kernels/CMakeLists.txt b/src/turbomind/kernels/CMakeLists.txt
index f96da72200..a7593e3de9 100644
--- a/src/turbomind/kernels/CMakeLists.txt
+++ b/src/turbomind/kernels/CMakeLists.txt
@@ -71,4 +71,4 @@ set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE  ON)
 set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)
 
 add_subdirectory(gemm_s_f16)
-add_subdirectory(decoder_multihead_attention)
\ No newline at end of file
+add_subdirectory(decoder_multihead_attention)
diff --git a/src/turbomind/utils/debug_utils.h b/src/turbomind/utils/debug_utils.h
index 0e577d5a78..f07af38db2 100644
--- a/src/turbomind/utils/debug_utils.h
+++ b/src/turbomind/utils/debug_utils.h
@@ -4,4 +4,4 @@
 #include "3rdparty/dbg.h"
 #else
 #define dbg(...)
-#endif
\ No newline at end of file
+#endif

From bd6b89c38e1f22ef48e920c52c85ed18798f8579 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Thu, 2 Nov 2023 11:41:24 +0000
Subject: [PATCH 50/56] fix lint

---
 ...der_masked_multihead_attention_template.cuh | 18 ++++++++++++------
 src/turbomind/models/llama/Request.h           |  3 ++-
 src/turbomind/models/llama/SequenceManager.h   |  3 ++-
 src/turbomind/models/llama/llama_params.h      |  2 +-
 src/turbomind/models/llama/llama_utils.h       |  8 +++++---
 5 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh b/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh
index f1f83e2add..85ece1fa99 100644
--- a/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh
+++ b/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh
@@ -79,7 +79,8 @@ namespace mmha {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template<typename T, int Dh>
-struct Qk_vec_m_ {};
+struct Qk_vec_m_ {
+};
 
 template<>
 struct Qk_vec_m_<float, 32> {
@@ -179,7 +180,8 @@ struct Qk_vec_k_<__nv_fp8_e4m3, 256> {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template<typename T, int THREADS_PER_KEY>
-struct K_vec_m_ {};
+struct K_vec_m_ {
+};
 
 template<>
 struct K_vec_m_<float, 4> {
@@ -260,7 +262,8 @@ struct K_vec_k_<__nv_fp8_e4m3, 1> {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template<typename T, int V_VEC_SIZE>
-struct V_vec_m_ {};
+struct V_vec_m_ {
+};
 
 template<>
 struct V_vec_m_<float, 1> {
@@ -340,7 +343,8 @@ struct V_vec_k_<__nv_fp8_e4m3, 16> {
 
 #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
 template<typename T>
-struct Qk_vec_acum_fp32_ {};
+struct Qk_vec_acum_fp32_ {
+};
 
 template<>
 struct Qk_vec_acum_fp32_<float> {
@@ -422,7 +426,8 @@ struct Qk_vec_acum_fp32_<fp8_4_t> {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template<typename T>
-struct K_vec_acum_fp32_ {};
+struct K_vec_acum_fp32_ {
+};
 
 template<>
 struct K_vec_acum_fp32_<float> {
@@ -484,7 +489,8 @@ struct K_vec_acum_fp32_<fp8_4_t> {
 
 #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
 template<typename T>
-struct V_vec_acum_fp32_ {};
+struct V_vec_acum_fp32_ {
+};
 
 template<>
 struct V_vec_acum_fp32_<float> {
diff --git a/src/turbomind/models/llama/Request.h b/src/turbomind/models/llama/Request.h
index a33fdd9ca1..ebce5c9ce9 100644
--- a/src/turbomind/models/llama/Request.h
+++ b/src/turbomind/models/llama/Request.h
@@ -27,7 +27,8 @@ struct Request {
     using Callback = std::function<void(std::unordered_map<std::string, Tensor>*)>;
     Callback stream_cb;
 
-    enum {
+    enum
+    {
         kInvalid  = 1,
         kConflict = 2,
         kBusy     = 3,
diff --git a/src/turbomind/models/llama/SequenceManager.h b/src/turbomind/models/llama/SequenceManager.h
index 8095cd7ae2..b4a5f3ba44 100644
--- a/src/turbomind/models/llama/SequenceManager.h
+++ b/src/turbomind/models/llama/SequenceManager.h
@@ -8,7 +8,8 @@ namespace turbomind {
 
 struct Sequence {
 
-    enum Status {
+    enum Status
+    {
         kCached = 0,
         kLocked,
         kActive
diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h
index 78b1570f02..ecb611a7d5 100644
--- a/src/turbomind/models/llama/llama_params.h
+++ b/src/turbomind/models/llama/llama_params.h
@@ -10,7 +10,7 @@ struct LlamaAttentionParams {
     int   max_position_embeddings;
     float rope_scaling_factor;
     // bool  use_dynamic_ntk;
-    bool  use_logn_attn;
+    bool use_logn_attn;
 };
 
 }  // namespace turbomind
diff --git a/src/turbomind/models/llama/llama_utils.h b/src/turbomind/models/llama/llama_utils.h
index 0e31f64c7c..acfe5054ca 100644
--- a/src/turbomind/models/llama/llama_utils.h
+++ b/src/turbomind/models/llama/llama_utils.h
@@ -10,7 +10,8 @@
 
 namespace turbomind {
 
-enum QuantPolicy {
+enum QuantPolicy
+{
     kNone = 0x00,
     // reserve 0x01 and 0x02 for backward compatibility
     kReserve1 = 0x01,
@@ -19,7 +20,8 @@ enum QuantPolicy {
     kCacheKVInt8 = 0x04,
 };
 
-enum CmpMode {
+enum CmpMode
+{
     kCmpNone,
     kCmpRead,
     kCmpWrite,
@@ -51,7 +53,7 @@ inline std::string to_string(std::string x)
 template<typename... Args>
 std::string Concat(std::string key, Args&&... args)
 {
-    std::vector<std::string> args_str{detail::to_string((Args&&)args)...};
+    std::vector<std::string> args_str{detail::to_string((Args &&) args)...};
     for (const auto& s : args_str) {
         key.append("_");
         key.append(s);

From 8c8d8bfc45aac49f44bc9dbddaaff4cf6adae737 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Mon, 6 Nov 2023 14:29:08 +0000
Subject: [PATCH 51/56] clear incoming buffer

---
 src/turbomind/models/llama/LlamaBatch.cc | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index f46d7ebe35..1afa591828 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -428,6 +428,10 @@ bool LlamaBatch<T>::Initialize()
         static_assert(sizeof(uintptr_t) == sizeof(void*));
     }
 
+    // clear incoming buffer
+    std::fill(incoming_->requests.begin(), incoming_->requests.end(), nullptr);
+    std::fill(incoming_->sequences.begin(), incoming_->sequences.end(), nullptr);
+
     // in case of swap-in/swap-out or there are holes in active buffer, layout of the buffers is changed
     // generation & sampling need to be re-initialized for correctness
     return exchange || active_holes;

From d3a1356c2de3a61d07a716afdd5ff03fb3aba2b4 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Tue, 7 Nov 2023 04:14:01 +0000
Subject: [PATCH 52/56] clear finished requests

---
 src/turbomind/models/llama/LlamaBatch.cc | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 1afa591828..9ca3fdc761 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -144,6 +144,7 @@ auto LlamaBatch<T>::ProcessStopRequests(const Requests& requests) -> std::vector
             if (state_->requests[i] && state_->requests[i]->id == r->id) {
                 ec = 0;
                 CompleteRequest(i, true, r->end_flag);
+                state_->requests[i].reset();
                 break;
             }
         }
@@ -174,7 +175,8 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
 {
     auto& state = *incoming_;
 
-    state.size = state.active_size = 0;
+    FT_CHECK(state.size == 0);
+    FT_CHECK(state.active_size == 0);
 
     int i = 0;
     for (const auto& r : requests) {
@@ -429,8 +431,9 @@ bool LlamaBatch<T>::Initialize()
     }
 
     // clear incoming buffer
-    std::fill(incoming_->requests.begin(), incoming_->requests.end(), nullptr);
-    std::fill(incoming_->sequences.begin(), incoming_->sequences.end(), nullptr);
+    std::fill_n(incoming_->requests.begin(), incoming_->size, nullptr);
+    std::fill_n(incoming_->sequences.begin(), incoming_->size, nullptr);
+    incoming_->size = 0;
 
     // in case of swap-in/swap-out or there are holes in active buffer, layout of the buffers is changed
     // generation & sampling need to be re-initialized for correctness

From 55dcb8bd7d3d37c4011cbef3663dd9a6a47cece8 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Tue, 7 Nov 2023 12:24:14 +0000
Subject: [PATCH 53/56] fix batch initialization

---
 src/turbomind/models/llama/LlamaBatch.cc | 55 ++++++++++++------------
 src/turbomind/models/llama/LlamaV2.cc    | 15 +++++--
 2 files changed, 39 insertions(+), 31 deletions(-)

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 9ca3fdc761..a0109b91e4 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -25,6 +25,13 @@
 
 namespace turbomind {
 
+void ClearState(BatchState& s)
+{
+    std::fill_n(s.requests.begin(), s.size, nullptr);
+    std::fill_n(s.sequences.begin(), s.size, nullptr);
+    s.size = s.active_size = 0;
+}
+
 template<typename T>
 void LlamaBatch<T>::RejectInvalidRequests(Requests& stop_reqs, Requests& infer_reqs)
 {
@@ -184,6 +191,8 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
         // sanity check, incoming request in previous iter should have been moved to `state_`
         FT_CHECK(!state.requests[i]);
 
+        TM_LOG_WARNING("[ProcessInferRequests] Request for %ld received.", (long)r->id);
+
         state.requests[i] = r;
 
         // get sequence for the request
@@ -328,11 +337,6 @@ bool LlamaBatch<T>::Initialize()
     process(state_);
     process(incoming_);
 
-    // dbg(sequences);
-    // dbg(context_lengths);
-    // dbg(priorities);
-    // dbg(step_length_);
-
     auto outcome = sequence_manager_->Materialize(sequences, context_lengths, priorities, step_length_);
 
     if (outcome.allocation || outcome.swap_in || outcome.swap_out) {
@@ -344,7 +348,7 @@ bool LlamaBatch<T>::Initialize()
     std::vector<int> idxs(sequences.size());
     std::iota(idxs.begin(), idxs.end(), 0);
 
-    if (exchange) {
+    if (exchange || holes || incoming_->size) {
         // put active ones first
         auto active_end = std::stable_partition(idxs.begin(), idxs.end(), [&](int idx) {
             return sequences[idx]->status == Sequence::kActive;  // present status
@@ -366,11 +370,9 @@ bool LlamaBatch<T>::Initialize()
             }
             std::stable_sort(swapin_beg, active_end, [&](int i, int j) { return missing_len[i] < missing_len[j]; });
         }
-    }
 
-    if (exchange || holes) {
-        // Copy sequence states to the back state buffer
-        back_->size = back_->active_size = 0;
+        // Copy sequence states to back buffer
+        FT_CHECK(back_->size == 0 && back_->active_size == 0);
         for (const auto& i : idxs) {
             auto& s = *sequences[i];
             if (exchange) {
@@ -379,6 +381,7 @@ bool LlamaBatch<T>::Initialize()
                 if (status[i] == Sequence::kActive && s.status != Sequence::kActive) {
                     SaveRandomState(*state, idx);
                 }
+                // mark swap-ins
                 if (status[i] != Sequence::kActive && s.status == Sequence::kActive) {
                     state->is_swap_in[idx] = 1;
                 }
@@ -390,10 +393,15 @@ bool LlamaBatch<T>::Initialize()
         }
         // Swap the buffers
         std::swap(state_, back_);
-    }
 
-    const int batch_size = state_->active_size;
+        ClearState(*back_);
+        ClearState(*incoming_);
+    }
 
+    /// Update block ptrs when there were
+    //  1. swap-in or swap-out
+    //  2. holes in the active buffer
+    //  3. new allocations (for exsiting active sequences)
     if (exchange || active_holes || outcome.allocation) {
         // Prepare intermediate buffers
         h_cu_block_counts_[0] = 0;
@@ -401,6 +409,8 @@ bool LlamaBatch<T>::Initialize()
         auto k_ptrs = h_k_block_ptrs_;
         auto v_ptrs = h_v_block_ptrs_;
 
+        const int batch_size = state_->active_size;
+
         for (int i = 0; i < batch_size; ++i) {
             const auto& seq = *state_->sequences[i];
 
@@ -415,28 +425,17 @@ bool LlamaBatch<T>::Initialize()
             });
         }
 
-        // if (1) {
-        //     std::vector cu_block_cnts(h_cu_block_counts_, h_cu_block_counts_ + batch_size + 1);
-        //     dbg(cu_block_cnts);
-        // }
-        // dbg(std::vector(h_k_block_ptrs_, h_k_block_ptrs_ + h_cu_block_counts_[batch_size]));
-        // dbg(std::vector(h_v_block_ptrs_, h_v_block_ptrs_ + h_cu_block_counts_[batch_size]));
-        // dbg(h_cu_block_counts_[batch_size]);
+        static_assert(sizeof(uintptr_t) == sizeof(void*));
 
         Copy(h_cu_block_counts_, batch_size + 1, cu_block_counts_);
         Copy(h_k_block_ptrs_, h_cu_block_counts_[batch_size], k_block_ptrs_);
         Copy(h_v_block_ptrs_, h_cu_block_counts_[batch_size], v_block_ptrs_);
-
-        static_assert(sizeof(uintptr_t) == sizeof(void*));
     }
 
-    // clear incoming buffer
-    std::fill_n(incoming_->requests.begin(), incoming_->size, nullptr);
-    std::fill_n(incoming_->sequences.begin(), incoming_->size, nullptr);
-    incoming_->size = 0;
-
-    // in case of swap-in/swap-out or there are holes in active buffer, layout of the buffers is changed
-    // generation & sampling need to be re-initialized for correctness
+    /// Layout of the buffers is changed, generation & sampling need to be re-initialized for correctness when there
+    /// were
+    //  1. swap-in or swap-out
+    //  2. holes in the active buffer
     return exchange || active_holes;
 }
 
diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc
index fca323f0ad..f7a6e784f4 100644
--- a/src/turbomind/models/llama/LlamaV2.cc
+++ b/src/turbomind/models/llama/LlamaV2.cc
@@ -37,7 +37,6 @@
 #include <functional>
 #include <memory>
 #include <sstream>
-#include <stdexcept>
 
 namespace turbomind {
 
@@ -545,15 +544,25 @@ void LlamaV2<T>::forward(std::unordered_map<std::string, Tensor>*       outputs,
     bool             has_error = 0;
     if (rank == 0) {
         TM_LOG_INFO("[forward] Enqueue requests");
+
+        std::vector<uint64_t> ids;
+        for (const auto& r : requests) {
+            ids.push_back(r->id);
+        }
+
         auto futures = shared_state_->request_queue.enqueue(std::move(requests));
 
+        FT_CHECK_WITH_INFO(ids.size() == futures.size(), "check failed");
+
         TM_LOG_INFO("[forward] Wait for requests to complete ...");
-        for (auto& f : futures) {
-            auto ec = f.get();
+
+        for (int i = 0; i < futures.size(); ++i) {
+            auto ec = futures[i].get();
             error_codes.push_back(ec);
             if (ec) {
                 has_error = true;
             }
+            TM_LOG_INFO("[forward] Request complete for %ld, ec = %d", (long)ids[i], (int)ec);
         }
     }
 

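The restructuring above turns Initialize() into a double-buffer hand-off: whenever sequences are reordered (swap-in/out, holes in the active buffer, or new arrivals) the batch is rebuilt into the spare `back_` buffer, the buffers are swapped, and both `back_` and `incoming_` are reset through `ClearState`, which is what keeps the `FT_CHECK(state.size == 0)` sanity checks from the previous patch valid on the next iteration. A condensed sketch of that hand-off (placeholder element types; not the actual LlamaBatch code):

#include <algorithm>
#include <memory>
#include <vector>

struct BatchState {  // reduced to the fields the hand-off touches
    std::vector<std::shared_ptr<int>> requests;   // placeholder element type
    std::vector<const int*>           sequences;  // placeholder element type
    int                               size{0};
    int                               active_size{0};
};

void ClearState(BatchState& s)
{
    std::fill_n(s.requests.begin(), s.size, nullptr);
    std::fill_n(s.sequences.begin(), s.size, nullptr);
    s.size = s.active_size = 0;
}

void HandOff(BatchState*& state, BatchState*& back, BatchState*& incoming)
{
    // ... rebuild the reordered batch into *back (omitted) ...
    std::swap(state, back);   // the rebuilt buffer becomes the live state
    ClearState(*back);        // the spare buffer is empty for the next rebuild
    ClearState(*incoming);    // incoming requests were consumed during the rebuild
}
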
From b7bf3d76875f80aef0e8d814028a68089c1136e5 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Tue, 7 Nov 2023 12:26:39 +0000
Subject: [PATCH 54/56] fix typo

---
 src/turbomind/models/llama/LlamaBatch.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index a0109b91e4..d1c096b1aa 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -401,7 +401,7 @@ bool LlamaBatch<T>::Initialize()
     /// Update block ptrs when there were
     //  1. swap-in or swap-out
     //  2. holes in the active buffer
-    //  3. new allocations (for exsiting active sequences)
+    //  3. new allocations (for existing active sequences)
     if (exchange || active_holes || outcome.allocation) {
         // Prepare intermediate buffers
         h_cu_block_counts_[0] = 0;

From 6b1c38b61aa22e903323fe84be79a65a92c35fb6 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Thu, 9 Nov 2023 04:29:59 +0000
Subject: [PATCH 55/56] fix typo

---
 CMakeLists.txt | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5a074ecf22..53e0eb2471 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -273,15 +273,15 @@ print(torch._C._GLIBCXX_USE_CXX11_ABI,end='');"
   message("-- USE_CXX11_ABI=${USE_CXX11_ABI}")
   if (USE_CXX11_ABI)
     set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=1")
+    set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=1")
     set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=1")
-    set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=1")
     set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=1")
     set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=1")
     set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=1")
   else()
     set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=0")
+    set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=0")
     set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=0")
-    set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=0")
     set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=0")
     set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=0")
     set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=0")

From 15b49219274618db674c61e0a9274d0cfef99607 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Thu, 9 Nov 2023 04:33:22 +0000
Subject: [PATCH 56/56] fix comparison

---
 src/turbomind/triton_backend/llama/LlamaTritonModel.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
index 3a60896a59..fb54346ac0 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -233,7 +233,7 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh
     ft::NcclParam pipeline_para = nccl_params.second[comms_rank];
 
     ft::FT_CHECK(tensor_para.world_size_ == tensor_para_size_);
-    ft::FT_CHECK(pipeline_para.world_size_ = pipeline_para_size_);
+    ft::FT_CHECK(pipeline_para.world_size_ == pipeline_para_size_);
 
     auto llama = std::make_unique<ft::LlamaV2<T>>(head_num_,
                                                   kv_head_num_,