From 5bdc4d3ae253850fb11bf7f1c095008672ecb5f4 Mon Sep 17 00:00:00 2001
From: Antoni Baum <antoni.baum@protonmail.com>
Date: Wed, 24 Jul 2024 11:36:52 -0700
Subject: [PATCH] Add fp8 support to `reshape_and_cache_flash` (#6667)
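
reshape_and_cache_flash() previously rejected every kv_cache_dtype other
than "auto". This patch templates the kernel on the cache storage type and
the Fp8KVCacheDataType, dispatches through DISPATCH_BY_KV_CACHE_DTYPE, and
threads k_scale/v_scale through the custom op, the torch binding, and the
flash-attn/flashinfer backends. The cache test and
create_kv_caches_with_random_flash() are extended to cover the fp8 path.

A rough sketch of the new Python-level call (illustrative only: the shapes,
the uint8 storage for the fp8 cache, and the unit scales below follow the
existing tests and are not part of this patch):

    import torch
    from vllm import _custom_ops as ops

    num_tokens, num_heads, head_size = 16, 8, 128
    num_blocks, block_size = 32, 16

    key = torch.randn(num_tokens, num_heads, head_size,
                      dtype=torch.float16, device="cuda")
    value = torch.randn_like(key)
    # fp8 caches are stored as uint8, matching the fp8 helpers in the tests.
    key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size,
                            dtype=torch.uint8, device="cuda")
    value_cache = torch.zeros_like(key_cache)
    # One slot per token; slots index the flattened (block, offset) space.
    slot_mapping = torch.randperm(num_blocks * block_size,
                                  device="cuda")[:num_tokens]

    # k_scale/v_scale now follow kv_cache_dtype; 1.0 is the default scale.
    ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                slot_mapping, "fp8", 1.0, 1.0)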

---
 csrc/cache.h                          |  3 +-
 csrc/cache_kernels.cu                 | 75 ++++++++++++++++-----------
 csrc/torch_bindings.cpp               |  3 +-
 tests/kernels/test_cache.py           | 42 ++++++++++++---
 vllm/_custom_ops.py                   |  5 +-
 vllm/attention/backends/flash_attn.py |  2 +
 vllm/attention/backends/flashinfer.py |  2 +
 vllm/utils.py                         |  9 +++-
 8 files changed, 98 insertions(+), 43 deletions(-)

diff --git a/csrc/cache.h b/csrc/cache.h
index 52177e8901a89..11c4c5001daaa 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -25,7 +25,8 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                              torch::Tensor& key_cache,
                              torch::Tensor& value_cache,
                              torch::Tensor& slot_mapping,
-                             const std::string& kv_cache_dtype);
+                             const std::string& kv_cache_dtype,
+                             const double k_scale, const double v_scale);
 
 // Just for unittest
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index caef7f5e18630..1be806bbfa43c 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -203,17 +203,18 @@ __global__ void reshape_and_cache_kernel(
   }
 }
 
-template <typename scalar_t>
+template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_flash_kernel(
     const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
     const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
-    scalar_t* __restrict__ k_cache,      // [num_blocks, block_size, num_heads,
+    cache_t* __restrict__ key_cache,     // [num_blocks, block_size, num_heads,
                                          // head_size]
-    scalar_t* __restrict__ v_cache,      // [num_blocks, block_size, num_heads,
+    cache_t* __restrict__ value_cache,   // [num_blocks, block_size, num_heads,
                                          // head_size]
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
     const int block_stride, const int key_stride, const int value_stride,
-    const int num_heads, const int head_size, const int block_size) {
+    const int num_heads, const int head_size, const int block_size,
+    const float k_scale, const float v_scale) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   // NOTE: slot_idx can be -1 if the token is padded
@@ -228,11 +229,20 @@ __global__ void reshape_and_cache_flash_kernel(
     const int64_t src_value_idx = token_idx * value_stride + i;
     const int head_idx = i / head_size;
     const int head_offset = i % head_size;
-    const int64_t tgt_value_idx = block_idx * block_stride +
-                                  block_offset * num_heads * head_size +
-                                  head_idx * head_size + head_offset;
-    k_cache[tgt_value_idx] = key[src_key_idx];
-    v_cache[tgt_value_idx] = value[src_value_idx];
+    const int64_t tgt_key_value_idx = block_idx * block_stride +
+                                      block_offset * num_heads * head_size +
+                                      head_idx * head_size + head_offset;
+    scalar_t tgt_key = key[src_key_idx];
+    scalar_t tgt_value = value[src_value_idx];
+    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+      key_cache[tgt_key_value_idx] = tgt_key;
+      value_cache[tgt_key_value_idx] = tgt_value;
+    } else {
+      key_cache[tgt_key_value_idx] =
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
+      value_cache[tgt_key_value_idx] =
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
+    }
   }
 }
 }  // namespace vllm
@@ -278,40 +288,45 @@ void reshape_and_cache(
                              CALL_RESHAPE_AND_CACHE)
 }
 
+// KV_T is the data type of the key and value tensors.
+// CACHE_T is the stored data type of the kv-cache.
+// KV_DTYPE is the real data type of the kv-cache.
+#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE)         \
+  vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE>       \
+      <<<grid, block, 0, stream>>>(                                   \
+          reinterpret_cast<KV_T*>(key.data_ptr()),                    \
+          reinterpret_cast<KV_T*>(value.data_ptr()),                  \
+          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),           \
+          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),         \
+          slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
+          value_stride, num_heads, head_size, block_size, k_scale, v_scale);
+
 void reshape_and_cache_flash(
-    torch::Tensor& key,      // [num_tokens, num_heads, head_size]
-    torch::Tensor& value,    // [num_tokens, num_heads, head_size]
-    torch::Tensor& k_cache,  // [num_blocks, block_size, num_heads, head_size]
-    torch::Tensor& v_cache,  // [num_blocks, block_size, num_heads, head_size]
+    torch::Tensor& key,        // [num_tokens, num_heads, head_size]
+    torch::Tensor& value,      // [num_tokens, num_heads, head_size]
+    torch::Tensor& key_cache,  // [num_blocks, block_size, num_heads, head_size]
+    torch::Tensor&
+        value_cache,  // [num_blocks, block_size, num_heads, head_size]
     torch::Tensor& slot_mapping,  // [num_tokens]
-    const std::string& kv_cache_dtype) {
-  // FIXME: only support auto datatype, does not support fp8
-  if (kv_cache_dtype != "auto") {
-    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
-  }
+    const std::string& kv_cache_dtype, const double k_scale,
+    const double v_scale) {
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
-  int block_size = k_cache.size(1);
+  int block_size = key_cache.size(1);
 
   int key_stride = key.stride(0);
   int value_stride = value.stride(0);
-  int block_stride = k_cache.stride(0);
-  TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));
+  int block_stride = key_cache.stride(0);
+  TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
 
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
-      key.scalar_type(), "reshape_and_cache_flash", [&] {
-        vllm::reshape_and_cache_flash_kernel<scalar_t>
-            <<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
-                k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
-                slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
-                value_stride, num_heads, head_size, block_size);
-      });
+
+  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
+                             CALL_RESHAPE_AND_CACHE_FLASH);
 }
 
 namespace vllm {
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 0df9bdb75018f..3027b63ba2b33 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -248,7 +248,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       "                        Tensor! key_cache,"
       "                        Tensor! value_cache,"
       "                        Tensor slot_mapping,"
-      "                        str kv_cache_dtype) -> ()");
+      "                        str kv_cache_dtype,"
+      "                        float k_scale, float v_scale) -> ()");
   cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
                  &reshape_and_cache_flash);
 
diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py
index 70ae3d0c6e0c3..f9a609464abfc 100644
--- a/tests/kernels/test_cache.py
+++ b/tests/kernels/test_cache.py
@@ -215,8 +215,6 @@ def test_reshape_and_cache_flash(
     device: str,
     kv_cache_dtype: str,
 ) -> None:
-    if kv_cache_dtype == "fp8":
-        pytest.skip()
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
@@ -248,15 +246,33 @@ def test_reshape_and_cache_flash(
         dtype,
         device=device,
     )
-    key_cache, value_cache = key_caches[0], value_caches[0]
+    key_cache = key_caches[0].contiguous()
+    value_cache = value_caches[0].contiguous()
+    del key_caches
+    del value_caches
 
     # Clone the KV caches.
-    cloned_key_cache = key_cache.clone()
-    cloned_value_cache = value_cache.clone()
+    if kv_cache_dtype == "fp8":
+        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
+        ops.convert_fp8(cloned_key_cache, key_cache)
+        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
+        ops.convert_fp8(cloned_value_cache, value_cache)
+    else:
+        cloned_key_cache = key_cache.clone()
+        cloned_value_cache = value_cache.clone()
+
+    # Use the default k_scale/v_scale of 1.0.
+    k_scale = v_scale = 1.0
 
     # Call the reshape_and_cache kernel.
     ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
-                                slot_mapping, kv_cache_dtype)
+                                slot_mapping, kv_cache_dtype, k_scale, v_scale)
+
+    if kv_cache_dtype == "fp8":
+        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
+        ops.convert_fp8(result_key_cache, key_cache)
+        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
+        ops.convert_fp8(result_value_cache, value_cache)
 
     # Run the reference implementation.
     block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
@@ -269,8 +285,18 @@ def test_reshape_and_cache_flash(
         cloned_key_cache[block_idx, block_offset, :, :] = key[i]
         cloned_value_cache[block_idx, block_offset, :, :] = value[i]
 
-    assert torch.allclose(key_cache, cloned_key_cache)
-    assert torch.allclose(value_cache, cloned_value_cache)
+    if kv_cache_dtype == "fp8":
+        assert torch.allclose(result_key_cache,
+                              cloned_key_cache,
+                              atol=0.001,
+                              rtol=0.1)
+        assert torch.allclose(result_value_cache,
+                              cloned_value_cache,
+                              atol=0.001,
+                              rtol=0.1)
+    else:
+        assert torch.allclose(key_cache, cloned_key_cache)
+        assert torch.allclose(value_cache, cloned_value_cache)
 
 
 @pytest.mark.parametrize("direction", COPYING_DIRECTION)
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index e5151c070f2f7..0186594656cc1 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -426,10 +426,13 @@ def reshape_and_cache_flash(
     value_cache: torch.Tensor,
     slot_mapping: torch.Tensor,
     kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
 ) -> None:
     torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
                                                    value_cache, slot_mapping,
-                                                   kv_cache_dtype)
+                                                   kv_cache_dtype, k_scale,
+                                                   v_scale)
 
 
 def copy_blocks(key_caches: List[torch.Tensor],
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index b16a204c8f44e..949bd973cf3c4 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -478,6 +478,8 @@ def forward(
                 value_cache,
                 attn_metadata.slot_mapping.flatten(),
                 self.kv_cache_dtype,
+                k_scale,
+                v_scale,
             )
 
         num_prefill_tokens = attn_metadata.num_prefill_tokens
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 9dac12d3b906d..2a4900489df35 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -489,6 +489,8 @@ def forward(
                 kv_cache[:, 1],
                 attn_metadata.slot_mapping.flatten(),
                 self.kv_cache_dtype,
+                k_scale,
+                v_scale,
             )
 
         query = query.contiguous(
diff --git a/vllm/utils.py b/vllm/utils.py
index 83605631b5bd6..876c3bf90b02c 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -491,7 +491,6 @@ def create_kv_caches_with_random_flash(
     seed: int = 0,
     device: Optional[str] = "cuda",
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
-    assert cache_dtype != "fp8"
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed(seed)
@@ -507,7 +506,13 @@ def create_kv_caches_with_random_flash(
         key_value_cache = torch.empty(size=key_value_cache_shape,
                                       dtype=torch_dtype,
                                       device=device)
-        key_value_cache.uniform_(-scale, scale)
+        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
+            key_value_cache.uniform_(-scale, scale)
+        elif cache_dtype == "fp8":
+            _generate_random_fp8(key_value_cache, -scale, scale)
+        else:
+            raise ValueError(
+                f"Does not support key cache of type {cache_dtype}")
         key_caches.append(key_value_cache[:, 0])
         value_caches.append(key_value_cache[:, 1])
     return key_caches, value_caches