Skip to content

Commit

Permalink
[GPU] Fix accuracy issue in PagedAttention kernel for large prompts (…
Browse files Browse the repository at this point in the history
…4K/8K tokens) by correcting index calculation in sub_group_broadcast function to ensure accurate data broadcasting within the subgroup
  • Loading branch information
sshlyapn committed Nov 6, 2024
1 parent 5833781 commit f471153
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ KERNEL(pa_sdpa_finalization_stage)(
partition_num * HEAD_SIZE +
head_size_idx;
OUTPUT_TYPE out_val = tmp_out[tmp_out_offset];
acc += TO_SOFTMAX_ACCUMULATOR_TYPE(out_val) * TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(exp_sum[partition_num / SUBGROUP_SIZE], partition_num)) / TO_SOFTMAX_ACCUMULATOR_TYPE(global_sum);
acc += TO_SOFTMAX_ACCUMULATOR_TYPE(out_val) * TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(exp_sum[partition_num / SUBGROUP_SIZE], partition_num % SUBGROUP_SIZE)) / TO_SOFTMAX_ACCUMULATOR_TYPE(global_sum);
}
const uint out_offset = seq_idx * (HEADS_NUM * HEAD_SIZE) +
head_num_idx * HEAD_SIZE +
Expand Down

0 comments on commit f471153

Please sign in to comment.