CUDA: Optimize reduce_rows_f32 kernel, yielding up to a 25x kernel-level performance improvement and a 10% performance increase for Gemma3n
#15132
Open
ORippler wants to merge 15 commits into ggml-org:master from ORippler:osimons/optimize_reduce_rows_f32
+155 −36
Commits (15), all authored by ORippler:
3deb3b1  Factor out `reduce_rows_f32` from common.cuh
c270ffe  Hide memory-latency by loop unrolling in reduce_rows_f32
ece608a  Further optimizations to `reduce_rows_f32`
9070af8  Add perf tests for `reduce_rows_f32` kernel
80de672  Add heuristic to toggle 128/512 threads based on sm count
8e04242  Ensure perf gains also for small ncols and large nrows
8fc2c03  Modify perf and unit-tests
9296d1f  Apply auto-formatting by clang
a6fe4dd  Fix CI build failure
4a1c5bc  Remove sm_count property from `ggml_backend_cuda_context`
7c7413e  Add CUB-based implementation for GGML_OP_MEAN
48cf9e4  Add heuristics to execute CUB branch only when it brings perf
e8373bf  Add unit-test for CUB-based mean
0e9a5d8  Rename `USE_CUB` to `GGML_CUDA_USE_CUB`
d647028  Unindent Preprocessor directives
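Commit 80de672 above toggles between 128 and 512 threads per block based on the device's streaming-multiprocessor count. The PR's actual thresholds are not shown in this excerpt, so the sketch below is only an illustration of the idea; the function name and the crossover condition are assumed placeholders, not the PR's code.

    #include <cstdint>
    #include <cuda_runtime.h>

    // Illustrative sketch (assumed logic): choose the block size for reduce_rows_f32
    // from the SM count, so that launches with few rows still keep every SM busy.
    static int pick_reduce_rows_block_size(int device, int64_t nrows) {
        int sm_count = 0;
        cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);

        // Assumed heuristic: with few rows (i.e. few blocks) relative to the number
        // of SMs, use wider 512-thread blocks so each row exposes more parallelism;
        // with many rows, 128-thread blocks already saturate the device.
        return (nrows < 2 * sm_count) ? 512 : 128;
    }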
@@ -0,0 +1,53 @@
#include "common.cuh"

// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
template <bool norm>
static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
    const int row = blockIdx.x;
    const int col = threadIdx.x;

    float sum = 0.0f;
    const int num_unroll = 8;
    float temp[num_unroll];
    float sum_temp[num_unroll] = { 0.0f };
    for (int i = col; i < ncols;) {
        for (int j = 0; j < num_unroll; ++j) {
            if (i < ncols) {
                temp[j] = x[row * ncols + i];
            } else {
                temp[j] = 0;
            }
            i += blockDim.x;
        }
        for (int j = 0; j < num_unroll; ++j) {
            sum_temp[j] += temp[j];
        }
    }
    for (int j = 0; j < num_unroll; ++j) {
        sum += sum_temp[j];
    }

    // sum up partial sums
    sum = warp_reduce_sum(sum);
    if (blockDim.x > WARP_SIZE) {
        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
        __shared__ float s_sum[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = sum;
        }
        __syncthreads();
        sum = 0.0f;
        if (lane_id < (blockDim.x / WARP_SIZE)) {
            sum = s_sum[lane_id];
        }
        sum = warp_reduce_sum(sum);
    }

    if (col != 0) {
        return;
    }

    dst[row] = norm ? sum / ncols : sum;
}
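For context, a minimal host-side launch of this kernel could look as follows. Only the kernel signature, the one-block-per-row mapping, and the sum/mean selection via the `norm` template parameter come from the code above; the wrapper name, the fixed block size, and the omission of error handling are assumptions made for illustration.

    // Hypothetical launcher: one block per row; the block size must be a multiple
    // of WARP_SIZE and at most 1024, as asserted inside the kernel.
    template <bool norm>
    static void launch_reduce_rows_f32(const float * x, float * dst,
                                       int nrows, int ncols, cudaStream_t stream) {
        const int  block_size = 256;         // assumed; the PR picks 128 or 512 via its SM-count heuristic
        const dim3 grid(nrows, 1, 1);        // blockIdx.x indexes the row
        const dim3 block(block_size, 1, 1);  // threadIdx.x strides over the columns
        reduce_rows_f32<norm><<<grid, block, 0, stream>>>(x, dst, ncols);
    }

    // Usage (device pointers assumed):
    //   launch_reduce_rows_f32<false>(d_x, d_sums,  nrows, ncols, stream);  // per-row sums
    //   launch_reduce_rows_f32<true >(d_x, d_means, nrows, ncols, stream);  // per-row means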