
Commit c85ea56

Aya-ZIbra authored and facebook-github-bot committed
iRoPE varseq flag for pre-calculated kv qparams (#4160)
Summary:
Pull Request resolved: #4160
X-link: facebookresearch/FBGEMM#1240

As titled. This is needed to handle this case: https://www.internalfb.com/diff/D73833204?dst_version_fbid=9500286030082255&transaction_fbid=676020828512263

The flag lets us avoid the amax calculation in RoPE for decode and partial-prefill batch lanes. We can also rely on it in Kernel2 to return early and avoid unnecessary quantization.

Reviewed By: y-sq

Differential Revision: D73478483

fbshipit-source-id: 4b549046f12671420631d3bc90bbe51fb3f5d023
1 parent 91c5a79 commit c85ea56
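The new kv_quant_scale_precomputed argument is a per-token boolean mask over the flattened varseq batch: true for tokens in batch lanes whose KV quantization scales already live in the cache qparams (decode and partial prefill), false for lanes that still need an amax pass. A minimal sketch of how a caller could build such a mask is below; the helper name and the seqpos-based rule are assumptions for illustration, not part of this commit.

// Hypothetical host-side helper (not part of this commit): build the
// kv_quant_scale_precomputed mask, assuming a lane's KV qparams are
// pre-calculated exactly when its first query position is > 0, which is the
// condition the old q_seqstarts-based check in quantizeQKVPerHead encoded.
#include <ATen/ATen.h>

at::Tensor make_kv_quant_scale_precomputed(
    const at::Tensor& varseq_seqpos,  // [B_T] int32, cache position per token
    const at::Tensor& q_seqstarts,    // [B+1] int32, first token of each lane
    const at::Tensor& varseq_batch) { // [B_T] int32, token -> lane index
  // first_seqpos[b] = varseq_seqpos[q_seqstarts[b]]
  auto lane_starts =
      q_seqstarts.slice(0, 0, q_seqstarts.size(0) - 1).to(at::kLong);
  auto first_seqpos = varseq_seqpos.index_select(0, lane_starts);
  auto lane_precomputed = first_seqpos.gt(0); // [B] bool
  // Broadcast the per-lane flag back to the per-token layout the ops expect.
  return lane_precomputed.index_select(0, varseq_batch.to(at::kLong)); // [B_T]
}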

File tree

3 files changed: +56 −36 lines changed

fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp

Lines changed: 9 additions & 7 deletions
@@ -31,11 +31,11 @@ namespace fbgemm_gpu {
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def("rope_qkv_varseq_prefill(Tensor XQ, Tensor(a!)? XK, Tensor? XV, Tensor(b!) cache_K, Tensor(c!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
       DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192"
-      ", float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool write_k_back=False, bool k_norm=False,bool update_kv=True, Tensor?amax_qkv=None) -> Tensor");
+      ", float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool write_k_back=False, bool k_norm=False,bool update_kv=True, Tensor?amax_qkv=None, Tensor?kv_quant_scale_precomputed=None) -> Tensor");
   m.def("rope_qkv_decoding(Tensor XQ, Tensor? XK, Tensor? XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
       DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False, bool update_kv=True, Tensor?amax_qkv=None) -> Tensor");
   m.def("nope_qkv_varseq_prefill(Tensor XQ, Tensor? XK, Tensor? XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, Tensor? block_tables=None, int page_size=" STRING(
-      DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, int? num_groups=1, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False, bool update_kv=True, Tensor?amax_qkv=None) -> Tensor");
+      DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, int? num_groups=1, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False, bool update_kv=True, Tensor?amax_qkv=None, Tensor?kv_quant_scale_precomputed=None) -> Tensor");
   m.def("nope_qkv_decoding(Tensor XQ, Tensor? XK, Tensor? XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, Tensor? block_tables=None, int page_size=" STRING(
       DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, int? num_groups=1, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False, bool update_kv=True, Tensor?amax_qkv=None) -> Tensor");
   m.def("xpos_qkv_varseq_prefill(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, float theta, float gamma, float scale_base, float exponent_offset, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
@@ -48,7 +48,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "dequantize_fp8_cache(Tensor cache_K, Tensor cache_V, Tensor kv_seqlen, Tensor? qparam_k=None, Tensor? qparam_v=None, Tensor? block_tables=None, int page_size=" STRING(
       DEFAULT_PAGE_SIZE) ") -> (Tensor, Tensor)");
   m.def(
-      "quantize_qkv_per_head(Tensor amax, Tensor XQKV, Tensor varseq_seqpos, Tensor? varseq_batch, Tensor q_seqstarts, Tensor cache_K, Tensor cache_V, Tensor XQ_O, int max_seq_len, Tensor? qparam_k=None, Tensor? qparam_v=None) -> Tensor");
+      "quantize_qkv_per_head(Tensor amax, Tensor XQKV, Tensor varseq_seqpos, Tensor? varseq_batch, Tensor? is_precalculated_qparam, Tensor cache_K, Tensor cache_V, Tensor XQ_O, int B, Tensor? qparam_k=None, Tensor? qparam_v=None) -> Tensor");
   m.def(
       "convert_e4m3fn_kv_cache_to_e4m3fnuz_inplace(Tensor cache_K, Tensor cache_V, Tensor qparam_K, Tensor qparam_V) -> ()");
 }
@@ -104,7 +104,8 @@ at::Tensor rope_qkv_varseq_prefill_meta(
     bool /* write_k_back */,
     bool /* k_norm */,
     bool /* update_kv */,
-    std::optional<at::Tensor> /* amax_qkv */
+    std::optional<at::Tensor> /* amax_qkv */,
+    std::optional<at::Tensor> /* kv_quant_scale_precomputed */
 ) {
   return at::empty_like(XQ);
 }
@@ -155,7 +156,8 @@ at::Tensor nope_qkv_varseq_prefill_meta(
     std::optional<at::Tensor> /* qparam_v */,
     bool /* k_norm */,
     bool /* update_kv */,
-    std::optional<at::Tensor> /* amax_qkv */
+    std::optional<at::Tensor> /* amax_qkv */,
+    std::optional<at::Tensor> /* kv_quant_scale_precomputed */
 ) {
   return at::empty_like(XQ);
 }
@@ -287,11 +289,11 @@ at::Tensor quantize_qkv_per_head_meta(
     at::Tensor XQKV,
     at::Tensor /* varseq_seqpos */,
     std::optional<at::Tensor> /* varseq_batch */,
-    at::Tensor /* q_seqstarts */,
+    std::optional<at::Tensor> /* is_precalculated_qparam */,
     at::Tensor cache_K /* cache_K */,
     at::Tensor /* cache_V */,
     at::Tensor /* XQ_O */,
-    int64_t /* max_seq_len */,
+    int64_t /* B */,
     std::optional<at::Tensor> /* qparam_k */,
     std::optional<at::Tensor> /* qparam_v */) {
   const at::SymInt B_KV = cache_K.sym_size(0);

fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu

Lines changed: 41 additions & 25 deletions
@@ -1064,7 +1064,8 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_fp8(
     double hi_freq_factor = 32,
     bool write_k_back = false,
     bool k_norm = false,
-    float* amax = nullptr) {
+    float* amax = nullptr,
+    bool* is_precalculated_qparam = nullptr) {
   // Launch b_t_(sum(h)) warps.
   auto b_t_hh = blockIdx.x * blockDim.y +
       threadIdx.y; // Block = [kThreadsPerWarp, kWarpsPerBlock]
@@ -1167,7 +1168,10 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_fp8(
         *reinterpret_cast<uint2*>(&dst_bf16);
   }
   // This kernel does not write to xq_o
-  if (qkv != QKV::Q) {
+  bool is_precalculated_qparam_b_t = is_precalculated_qparam
+      ? is_precalculated_qparam[b_t]
+      : true; // for decode it is true
+  if (qkv != QKV::Q && is_precalculated_qparam_b_t) {
     // only write to cache if batch lane has a pre-calculated qparam
     // quantize and write to dst_row
     CUDA_KERNEL_ASSERT(qparam_k_ptr != nullptr)
@@ -1182,13 +1186,9 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_fp8(
     }
     quantize_fp8_kv<at::Float8_e4m3fn, KVQuantRecipe::perHeadScaling>(
         dst, dst_row_q, qparam_row);
-  }
-  // amax calculation for QKV
-  // TODO: avoid amax calculation for KV if batch lane has a pre-calculated
-  // qparam if (qkv != QKV::Q && is_precalculated_qparam[b_t] ){
-  //   return;
-  // }
-  if (amax != nullptr) {
+  } else {
+    // qkv == Q or qparam is not precalculated
+    CUDA_KERNEL_ASSERT(amax != nullptr);
     // per_row_amax(dst, &amax[b_t * HH + hh]);
     per_head_amax(dst, &amax[b * HH + hh]);
   }
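The hunks above add the gating that the old TODO sketched: K/V rows from lanes with a pre-calculated qparam are quantized with the existing scales and written to the cache, while Q rows and K/V rows without scales only feed the amax reduction. A condensed, standalone restatement of that predicate follows (illustrative only, not the production kernel):

// Condensed restatement of the branch added above; names mirror the kernel
// but this is an illustrative stand-in, not the production code.
__device__ inline bool kv_row_can_skip_amax(
    const bool* is_precalculated_qparam, // [B_T] mask, or nullptr
    int b_t,                             // flattened token index
    bool is_q) {                         // true for Q heads
  // nullptr means every lane already has qparams (pure decode batches).
  const bool precalc =
      is_precalculated_qparam ? is_precalculated_qparam[b_t] : true;
  // K/V rows with a known scale are quantized with the existing qparams and
  // written straight to the FP8 cache (quantize_fp8_kv in the kernel);
  // everything else falls through to per_head_amax so quantize_qkv_per_head
  // can derive the scale in a second pass.
  return !is_q && precalc;
}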
@@ -1211,7 +1211,8 @@ at::Tensor nope_qkv_varseq_prefill(
     std::optional<at::Tensor> qparam_v = std::nullopt,
     bool k_norm = false,
     bool update_kv = true,
-    std::optional<at::Tensor> amax_qkv = std::nullopt) {
+    std::optional<at::Tensor> amax_qkv = std::nullopt,
+    std::optional<at::Tensor> kv_quant_scale_precomputed = std::nullopt) {
   auto B_T = XQ.size(0);
   auto N_H = XQ.size(1);
 
@@ -1292,9 +1293,14 @@ at::Tensor nope_qkv_varseq_prefill(
     CUDA_KERNEL_ASSERT(num_groups_ == 1);
     if (cache_K.dtype() == at::kFloat8_e4m3fn) {
       float* amax_ptr = nullptr;
+      bool* is_precalculated_qparam = nullptr;
       if (amax_qkv.has_value()) {
         amax_ptr = static_cast<float*>(amax_qkv.value().data_ptr());
       }
+      if (kv_quant_scale_precomputed.has_value()) {
+        is_precalculated_qparam =
+            static_cast<bool*>(kv_quant_scale_precomputed.value().data_ptr());
+      }
       rope_xpos_qkv_varseq_prefill_kernel_fp8<
           PositionEmbeddingMode::NOPE,
           CacheLogicalDtype::FP8,
@@ -1331,7 +1337,8 @@ at::Tensor nope_qkv_varseq_prefill(
           0,
           true, // write_k_back and q too if we are doing norm.
           k_norm,
-          amax_ptr);
+          amax_ptr,
+          is_precalculated_qparam);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     } else {
       CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL(
@@ -1530,7 +1537,8 @@ at::Tensor nope_qkv_decoding(
           0,
           true, // write_k_back and q too if we are doing norm.
           k_norm,
-          amax_ptr);
+          amax_ptr,
+          nullptr);
 
       C10_CUDA_KERNEL_LAUNCH_CHECK();
 
@@ -1618,7 +1626,8 @@ at::Tensor rope_qkv_varseq_prefill(
     bool write_k_back = false,
     bool k_norm = false,
     bool update_kv = true,
-    std::optional<at::Tensor> amax_qkv = std::nullopt) {
+    std::optional<at::Tensor> amax_qkv = std::nullopt,
+    std::optional<at::Tensor> kv_quant_scale_precomputed = std::nullopt) {
   auto B_T = XQ.size(0);
   auto N_H = XQ.size(1);
   auto N_KVH = 0;
@@ -1707,9 +1716,14 @@ at::Tensor rope_qkv_varseq_prefill(
     CUDA_KERNEL_ASSERT(num_groups_ == 1);
     if (cache_K.dtype() == at::kFloat8_e4m3fn) {
       float* amax_ptr = nullptr;
+      bool* is_precalculated_qparam = nullptr;
       if (amax_qkv.has_value()) {
         amax_ptr = static_cast<float*>(amax_qkv.value().data_ptr());
       }
+      if (kv_quant_scale_precomputed.has_value()) {
+        is_precalculated_qparam =
+            static_cast<bool*>(kv_quant_scale_precomputed.value().data_ptr());
+      }
       rope_xpos_qkv_varseq_prefill_kernel_fp8<
           PositionEmbeddingMode::ROPE,
           CacheLogicalDtype::FP8,
@@ -1746,7 +1760,8 @@ at::Tensor rope_qkv_varseq_prefill(
           hi_freq_factor,
           true,
           k_norm,
-          amax_ptr);
+          amax_ptr,
+          is_precalculated_qparam);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
 
     } else {
@@ -2113,7 +2128,8 @@ at::Tensor rope_qkv_decoding(
           hi_freq_factor,
           true,
           k_norm,
-          amax_ptr);
+          amax_ptr,
+          nullptr);
 
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     } else {
@@ -2849,8 +2865,7 @@ __global__ void quantizeQKVPerHead(
     at::BFloat16* xqkv, // [B_T, HH, D_H]
     const int32_t* varseq_seqpos, // [B_T]
     const int32_t* varseq_batch, // [B_T]
-    at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
-        q_seqstarts, // [B+1]
+    const bool* is_precalculated_qparam, // [B_T]
     at::PackedTensorAccessor64<at::Float8_e4m3fn, 3, at::RestrictPtrTraits>
         XQ_O, // [B_T][N_H][D]
     at::PackedTensorAccessor64<at::Float8_e4m3fn, 4, at::RestrictPtrTraits>
@@ -2913,8 +2928,9 @@ __global__ void quantizeQKVPerHead(
   }
   // Skip quantization if scale is pre-calculated for K/V
   // as in decode and partial prefill cases
-  int start = q_seqstarts[b];
-  if (qkv != QKV::Q && varseq_seqpos[start] > 0)
+  bool is_precalculated_qparam_b_t =
+      is_precalculated_qparam ? is_precalculated_qparam[b_t] : true;
+  if (qkv != QKV::Q && is_precalculated_qparam_b_t)
     return;
 
   CUDA_KERNEL_ASSERT(uintptr_t(qparam) % 4 == 0);
@@ -2965,11 +2981,11 @@ at::Tensor quantize_qkv_per_head(
     at::Tensor xqkv, // [B_T, HH, D_H]
     at::Tensor varseq_seqpos, // [B_T]
     std::optional<at::Tensor> varseq_batch, // [B_T]
-    at::Tensor q_seqstarts, // [B+1]
+    std::optional<at::Tensor> is_precalculated_qparam, // [B_T]
    at::Tensor cache_K, // [B][MAX_T][N_KVH][D_H]
     at::Tensor cache_V, // [B][MAX_T][N_KVH][D_H]
     at::Tensor XQ_O, // [B_T][N_H][D]
-    int64_t max_seq_length, // Length of the sequence
+    int64_t B, // Batch size
     std::optional<at::Tensor> qparam_k,
     std::optional<at::Tensor> qparam_v) {
   auto B_T = XQ_O.size(0);
@@ -2980,16 +2996,14 @@ at::Tensor quantize_qkv_per_head(
   float* qparam_v_ptr = nullptr;
   if (qparam_k.has_value()) {
     // prefill case
-    // HH += N_KVH_L * 2;
     qparam_k_ptr = qparam_k.value().data_ptr<float>();
     qparam_v_ptr = qparam_v.value().data_ptr<float>();
   }
   auto num_warps = B_T * HH;
   dim3 block_size(kThreadsPerWarp, kWarpsPerBlock);
   dim3 grid_size(cuda_calc_xblock_count(num_warps, kWarpsPerBlock));
 
-  auto scale_q = at::zeros(
-      {q_seqstarts.size(0) - 1, N_KVH_L}, XQ_O.options().dtype(at::kFloat));
+  auto scale_q = at::zeros({B, N_KVH_L}, XQ_O.options().dtype(at::kFloat));
   float* const scale_q_ptr = scale_q.data_ptr<float>();
   // Launch the kernel
   // TODO: Launch the kernel with B_T * N_H_L blocks only in case of decode.
@@ -3005,7 +3019,9 @@ at::Tensor quantize_qkv_per_head(
       varseq_seqpos.data_ptr<int32_t>(),
       varseq_batch.has_value() ? varseq_batch.value().data_ptr<int32_t>()
                                : nullptr, // not needed for decode
-      q_seqstarts.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+      is_precalculated_qparam.has_value()
+          ? is_precalculated_qparam.value().data_ptr<bool>()
+          : nullptr,
       XQ_O.packed_accessor64<at::Float8_e4m3fn, 3, at::RestrictPtrTraits>(),
       cache_K.packed_accessor64<at::Float8_e4m3fn, 4, at::RestrictPtrTraits>(),
       cache_V.packed_accessor64<at::Float8_e4m3fn, 4, at::RestrictPtrTraits>(),
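Downstream, quantize_qkv_per_head consumes the same mask: K/V rows from lanes whose flag is true return early (their cache rows were already written by the RoPE/NoPE kernel), and scale_q is now allocated per batch lane as {B, N_KVH_L}. A hedged sketch of a call site follows; the tensor names, shapes, and the fbgemm_gpu:: qualification are assumptions for illustration, not taken from this commit.

// Hypothetical call site, not from this commit. amax is assumed to be the
// [B, HH] float buffer the prefill kernel filled for rows without
// pre-calculated qparams; all other names are placeholders.
at::Tensor scale_q = fbgemm_gpu::quantize_qkv_per_head(
    amax,                       // [B, HH] float, written by the RoPE kernel
    xqkv,                       // [B_T, HH, D_H] bf16 fused QKV
    varseq_seqpos,              // [B_T] int32
    varseq_batch,               // [B_T] int32, token -> lane
    kv_quant_scale_precomputed, // [B_T] bool, same mask passed to the RoPE op
    cache_K,                    // [B][MAX_T][N_KVH][D_H] fp8 cache
    cache_V,                    // [B][MAX_T][N_KVH][D_H] fp8 cache
    xq_o,                       // [B_T][N_H][D] fp8 quantized Q output
    B,                          // batch size; scale_q is allocated as {B, N_KVH_L}
    qparam_k,                   // optional per-head K qparams (prefill case)
    qparam_v);                  // optional per-head V qparams (prefill case)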

fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.h

Lines changed: 6 additions & 4 deletions
@@ -27,7 +27,8 @@ at::Tensor nope_qkv_varseq_prefill(
     std::optional<at::Tensor> qparam_v,
     bool k_norm,
     bool update_kv,
-    std::optional<at::Tensor> amax_qkv);
+    std::optional<at::Tensor> amax_qkv,
+    std::optional<at::Tensor> kv_quant_scale_precomputed);
 
 at::Tensor nope_qkv_decoding(
     at::Tensor XQ,
@@ -73,7 +74,8 @@ at::Tensor rope_qkv_varseq_prefill(
     bool write_k_back,
     bool k_norm,
     bool update_kv,
-    std::optional<at::Tensor> amax_qkv);
+    std::optional<at::Tensor> amax_qkv,
+    std::optional<at::Tensor> kv_quant_scale_precomputed);
 
 at::Tensor rope_qkv_decoding(
     at::Tensor XQ,
@@ -174,11 +176,11 @@ at::Tensor quantize_qkv_per_head(
     at::Tensor XQKV,
     at::Tensor varseq_seqpos,
     std::optional<at::Tensor> varseq_batch,
-    at::Tensor q_seqstarts,
+    std::optional<at::Tensor> is_precalculated_qparam,
     at::Tensor cache_K,
     at::Tensor cache_V,
     at::Tensor XQ_O,
-    int64_t max_seq_len,
+    int64_t B,
     std::optional<at::Tensor> qparam_k,
     std::optional<at::Tensor> qparam_v);
