diff --git a/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py b/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py index 90b135743..3a7716404 100644 --- a/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py +++ b/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py @@ -31,7 +31,7 @@ def _flatten_kv_cache( stride_vos: tl.constexpr, stride_vod: tl.constexpr, stride_boff, - OUT_SIZE: tl.constexpr, + OUT_SIZE, HEAD_DIM_K: tl.constexpr, HEAD_DIM_V: tl.constexpr, BLOCK_BS: tl.constexpr, @@ -124,7 +124,7 @@ def _flatten_kv_cache_quant( stride_vod: tl.constexpr, stride_boff, quant_policy: tl.constexpr, - OUT_SIZE: tl.constexpr, + OUT_SIZE, HEAD_DIM_K: tl.constexpr, HEAD_DIM_V: tl.constexpr, BLOCK_BS: tl.constexpr,