CUDA/HIP: ssm-scan: switch from shared memory to registers, fixes indexing problem on warp64 devices #15101
base: master
Conversation
Sorry, there is a concurrent PR that I forgot about: #13291. Can you check whether that one would also fix the issue?
Both PRs fix the correctness issue, and performance is the same on CDNA, so either PR is fine with me.
I'm not sure the complexity of the other PR is worth it given the same performance, but perhaps the CUB path makes it worth it for NV devices. For the record, I am testing with this added to test-backend-ops:
For RDNA2 the other PR is more favorable, winning in the small n_seqs case while losing in the large n_seqs case:
After upgrading LLVM, ssm_scan_f32 is failing again on warp64 devices.
Taking another, closer look at the code, I don't get how this could have possibly ever worked. I also don't get what this line is supposed to do:

llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu, line 52 at fd1234c
Further, given the kernel is restricted to N == 16 anyhow (but if you were to invoke it with N != 16, it simply doesn't load the data from global memory into shared memory and works on it uninitialized, huh?), I don't understand why we are bothering with loading from global memory into shared memory at all here: loading directly into registers makes the kernel use just 63 vector registers at N == 16, and we are not sharing any data between threads.
As far as I know, GCN/CDNA is the most register-starved of all modern architectures, and even there we have 256 64-wide vector registers, resulting in an occupancy of 4 there.
So that's what this PR does.