Skip to content

Commit

Permalink
TEST: [GPU] Use FP32 accumulator for QK multiplication for 2nd+ token calculation in PagedAttention
Browse files Browse the repository at this point in the history
  • Loading branch information
sshlyapn committed Jan 24, 2025
1 parent cfbc998 commit af77957
Showing 1 changed file with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ KERNEL(pa_sdpa_opt)(
#endif

// SLM for intermediate QK results
__local OUTPUT_TYPE slm_qk_vals[SEQ_LEN_PARTITION_SIZE];
__local SOFTMAX_ACCUMULATOR_TYPE slm_qk_vals[SEQ_LEN_PARTITION_SIZE];

// SLM buffers for SoftMax calculation and qk_max/qk_sums results aggregation across all WGs
__local SOFTMAX_ACCUMULATOR_TYPE slm_qk_max_vals[SUBGROUPS_PER_WG];
Expand Down Expand Up @@ -166,7 +166,7 @@ KERNEL(pa_sdpa_opt)(
#endif
const uint block_offset = block_indices[start_block_idx + block_num * SUBGROUPS_PER_WG] * HEAD_SIZE * KV_HEADS_NUM * SUBGROUP_SIZE + head_idx * HEAD_SIZE * SUBGROUP_SIZE;

INPUT0_TYPE qk_acc = INPUT0_VAL_ZERO;
SOFTMAX_ACCUMULATOR_TYPE qk_acc = SOFTMAX_ACCUMULATOR_VAL_ZERO;

#define KEY_VEC_SIZE SUBGROUP_SIZE
unroll_for (uint qk_idx = 0; qk_idx < HEAD_SIZE / KEY_VEC_SIZE; qk_idx++) {
Expand All @@ -181,9 +181,9 @@ KERNEL(pa_sdpa_opt)(

unroll_for (uint i = 0; i < KEY_VEC_SIZE; i++) {
#if STORE_QUERY_TO_SLM
qk_acc = mad(sub_group_broadcast(q_val, i), k_vals[i], qk_acc);
qk_acc = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(q_val, i)), TO_SOFTMAX_ACCUMULATOR_TYPE(k_vals[i]), qk_acc);
#else
qk_acc = mad(sub_group_broadcast(q_val[qk_idx], i), k_vals[i], qk_acc);
qk_acc = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(q_val[qk_idx], i)), TO_SOFTMAX_ACCUMULATOR_TYPE(k_vals[i]), qk_acc);
#endif
}
}
Expand All @@ -196,7 +196,7 @@ KERNEL(pa_sdpa_opt)(
#endif

if (token_idx >= seq_len)
qk_acc = INPUT0_VAL_MIN;
qk_acc = SOFTMAX_ACCUMULATOR_VAL_MIN;

qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC(qk_max, TO_SOFTMAX_ACCUMULATOR_TYPE(qk_acc));

Expand Down Expand Up @@ -235,7 +235,7 @@ KERNEL(pa_sdpa_opt)(
if (global_data_idx < seq_len && local_data_idx < SEQ_LEN_PARTITION_SIZE) {
#endif
SOFTMAX_ACCUMULATOR_TYPE qk_new = native_exp(TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[local_data_idx]) - qk_max);
slm_qk_vals[local_data_idx] = TO_OUTPUT_TYPE(qk_new);
slm_qk_vals[local_data_idx] = qk_new;

exp_sum += qk_new;
}
Expand Down Expand Up @@ -266,7 +266,7 @@ KERNEL(pa_sdpa_opt)(
if (global_data_idx < seq_len && local_data_idx < SEQ_LEN_PARTITION_SIZE) {
#endif
SOFTMAX_ACCUMULATOR_TYPE qk_new = TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[local_data_idx]) / exp_sum;
slm_qk_vals[local_data_idx] = TO_OUTPUT_TYPE(qk_new);
slm_qk_vals[local_data_idx] = qk_new;
}
}

Expand Down

0 comments on commit af77957

Please sign in to comment.