diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl index 7a300aaee1a16a..811d148608483c 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl @@ -107,7 +107,7 @@ KERNEL(pa_sdpa_opt)( #endif // SLM for intermediate QK results - __local OUTPUT_TYPE slm_qk_vals[SEQ_LEN_PARTITION_SIZE]; + __local SOFTMAX_ACCUMULATOR_TYPE slm_qk_vals[SEQ_LEN_PARTITION_SIZE]; // SLM buffers for SoftMax calculation and qk_max/qk_sums results aggregation across all WGs __local SOFTMAX_ACCUMULATOR_TYPE slm_qk_max_vals[SUBGROUPS_PER_WG]; @@ -166,7 +166,7 @@ KERNEL(pa_sdpa_opt)( #endif const uint block_offset = block_indices[start_block_idx + block_num * SUBGROUPS_PER_WG] * HEAD_SIZE * KV_HEADS_NUM * SUBGROUP_SIZE + head_idx * HEAD_SIZE * SUBGROUP_SIZE; - INPUT0_TYPE qk_acc = INPUT0_VAL_ZERO; + SOFTMAX_ACCUMULATOR_TYPE qk_acc = SOFTMAX_ACCUMULATOR_VAL_ZERO; #define KEY_VEC_SIZE SUBGROUP_SIZE unroll_for (uint qk_idx = 0; qk_idx < HEAD_SIZE / KEY_VEC_SIZE; qk_idx++) { @@ -181,9 +181,9 @@ KERNEL(pa_sdpa_opt)( unroll_for (uint i = 0; i < KEY_VEC_SIZE; i++) { #if STORE_QUERY_TO_SLM - qk_acc = mad(sub_group_broadcast(q_val, i), k_vals[i], qk_acc); + qk_acc = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(q_val, i)), TO_SOFTMAX_ACCUMULATOR_TYPE(k_vals[i]), qk_acc); #else - qk_acc = mad(sub_group_broadcast(q_val[qk_idx], i), k_vals[i], qk_acc); + qk_acc = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(q_val[qk_idx], i)), TO_SOFTMAX_ACCUMULATOR_TYPE(k_vals[i]), qk_acc); #endif } } @@ -196,7 +196,7 @@ KERNEL(pa_sdpa_opt)( #endif if (token_idx >= seq_len) - qk_acc = INPUT0_VAL_MIN; + qk_acc = SOFTMAX_ACCUMULATOR_VAL_MIN; qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC(qk_max, TO_SOFTMAX_ACCUMULATOR_TYPE(qk_acc)); @@ -235,7 +235,7 @@ KERNEL(pa_sdpa_opt)( if (global_data_idx < seq_len && local_data_idx < SEQ_LEN_PARTITION_SIZE) { #endif SOFTMAX_ACCUMULATOR_TYPE qk_new = native_exp(TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[local_data_idx]) - qk_max); - slm_qk_vals[local_data_idx] = TO_OUTPUT_TYPE(qk_new); + slm_qk_vals[local_data_idx] = qk_new; exp_sum += qk_new; } @@ -266,7 +266,7 @@ KERNEL(pa_sdpa_opt)( if (global_data_idx < seq_len && local_data_idx < SEQ_LEN_PARTITION_SIZE) { #endif SOFTMAX_ACCUMULATOR_TYPE qk_new = TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[local_data_idx]) / exp_sum; - slm_qk_vals[local_data_idx] = TO_OUTPUT_TYPE(qk_new); + slm_qk_vals[local_data_idx] = qk_new; } }