CUDA/HIP: ssm-scan: switch from shared memory to registers, fixes indexing problem on warp64 devices #15101

Open
wants to merge 2 commits into master

Conversation

IMbackK (Collaborator) commented Aug 5, 2025

After upgrading LLVM, ssm_scan_f32 is failing again on warp64 devices.

Taking another, closer look at the code, I don't get how this could possibly ever have worked. I also don't get what

smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
the ternary is trying to achieve here at all, as it's not avoiding a bank conflict like the comment seems to suggest was the intent.
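
For comparison, the conventional way to actually avoid shared-memory bank conflicts is to pad the row stride, not to bump individual lane indices. A minimal sketch with hypothetical names, not the kernel's actual code:

    // Padding the row stride from N to N + 1 shifts each row of shared memory
    // by one bank, so column-wise accesses across lanes hit distinct banks.
    __global__ void staging_example(const float * __restrict__ src) {
        constexpr int N = 16;
        __shared__ float smem_A[64][N + 1]; // N + 1, not N: the padding column
        const int row = threadIdx.x;        // hypothetical: one row per lane
    #pragma unroll
        for (int i = 0; i < N; ++i) {
            smem_A[row][i] = src[row * N + i];
        }
        __syncthreads();
        // ... column-wise reads like smem_A[i][row % N] are now conflict-free ...
    }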

Further, given that the kernel is restricted to N == 16 anyhow (and if you were to invoke it with N != 16, it simply doesn't load the data from global memory into shared memory and works on it uninitialized, huh?), I don't understand why we are bothering to load from global memory into shared memory here at all: loading directly into registers makes the kernel use just 63 vector registers at N == 16, and we are not sharing any data between threads.
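
To make "loading directly into registers" concrete, here is a minimal sketch with hypothetical names (the PR's actual indexing differs): each thread pulls its 16 values straight into a local array, which the compiler keeps in vector registers because every index is a compile-time constant after full unrolling.

    // Minimal sketch, hypothetical names: no shared-memory staging at all.
    __device__ void load_A_to_registers(const float * __restrict__ A_block,
                                        float (&regA)[16]) {
    #pragma unroll
        for (int n = 0; n < 16; ++n) {
            regA[n] = A_block[threadIdx.x * 16 + n]; // hypothetical indexing
        }
    }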

As far as I know, GCN/CDNA is the most register-starved of all modern architectures, and even there we have 256 64-wide vector registers, resulting in an occupancy of 4.
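
For concreteness, that occupancy figure follows directly from the numbers above: at 63 VGPRs per thread, ⌊256 / 63⌋ = 4 wavefronts fit in each SIMD's 256-entry vector register file.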

So that's what this PR does.

The github-actions bot added the "Nvidia GPU" (Issues specific to Nvidia GPUs) and "ggml" (changes relating to the ggml tensor library for machine learning) labels on Aug 5, 2025.
JohannesGaessler (Collaborator) commented

Sorry, there is a concurrent PR that I forgot about: #13291. Can you check whether that one would also fix the issue?

IMbackK (Collaborator, Author) commented Aug 6, 2025

Both PRs fix the correctness issue, and performance is the same on CDNA, so either PR is fine with me.

IMbackK (Collaborator, Author) commented Aug 6, 2025

I'm not sure the complexity of the other PR is worth it given the same performance, but perhaps the CUB path makes it worthwhile for NVIDIA devices.

For the record, I am testing with this added to test-backend-ops:

for (int seq_len : {1, 32, 128, 512, 1024}) {
    for (int batch_size : {1, 2, 8}) {
        test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, seq_len, batch_size));
    }
}
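
(For anyone reproducing this: the numbers below can then be gathered with the perf mode of test-backend-ops, e.g. something like ./bin/test-backend-ops perf -o SSM_SCAN.)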

IMbackK (Collaborator, Author) commented Aug 6, 2025

For RDNA2 the other PR is more favorable, winning in the small n_seq_tokens cases while losing in the large n_seq_tokens cases:

#13291:
Backend 1/2: ROCm0
  Device description: AMD Radeon RX 6800 XT
  Device memory: 16368 MB (16304 MB free)

  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=1,n_seqs=1):     229376 runs -   4.43 us/run -    204 kB/run -  43.95 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=1,n_seqs=2):     221184 runs -   4.55 us/run -    344 kB/run -  72.12 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=1,n_seqs=8):     188416 runs -   5.46 us/run -   1185 kB/run - 206.89 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=1):     32768 runs -  32.88 us/run -    580 kB/run -  16.82 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=2):     32768 runs -  32.87 us/run -   1096 kB/run -  31.80 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=8):     32020 runs -  35.79 us/run -   4192 kB/run - 111.71 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=128,n_seqs=1):    16384 runs -  97.19 us/run -   1744 kB/run -  17.11 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=128,n_seqs=2):    16384 runs -  95.75 us/run -   3424 kB/run -  34.10 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=128,n_seqs=8):     9940 runs - 120.74 us/run -  13504 kB/run - 106.66 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=512,n_seqs=1):     5243 runs - 440.24 us/run -   6400 kB/run -  13.86 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=512,n_seqs=2):     2635 runs - 439.45 us/run -  12736 kB/run -  27.64 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=512,n_seqs=8):     2648 runs - 458.07 us/run -  50752 kB/run - 105.66 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=1024,n_seqs=1):    2662 runs - 870.29 us/run -  12608 kB/run -  13.82 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=1024,n_seqs=2):    1335 runs - 870.47 us/run -  25152 kB/run -  27.56 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=1024,n_seqs=8):    1340 runs - 914.04 us/run - 100416 kB/run - 104.77 GB/s
#15101:
Backend 1/2: ROCm0
  Device description: AMD Radeon RX 6800 XT
  Device memory: 16368 MB (16304 MB free)

  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=1,n_seqs=1):     188416 runs -   5.35 us/run -    204 kB/run -  36.36 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=1,n_seqs=2):     188416 runs -   5.49 us/run -    344 kB/run -  59.82 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=1,n_seqs=8):     163840 runs -   6.30 us/run -   1185 kB/run - 179.51 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=1):     49152 runs -  21.07 us/run -    580 kB/run -  26.25 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=2):     49152 runs -  21.41 us/run -   1096 kB/run -  48.82 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=8):     40025 runs -  30.84 us/run -   4192 kB/run - 129.61 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=128,n_seqs=1):    16384 runs -  74.84 us/run -   1744 kB/run -  22.22 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=128,n_seqs=2):    16384 runs -  75.17 us/run -   3424 kB/run -  43.44 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=128,n_seqs=8):    12425 runs -  99.43 us/run -  13504 kB/run - 129.52 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=512,n_seqs=1):     5243 runs - 354.38 us/run -   6400 kB/run -  17.22 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=512,n_seqs=2):     5270 runs - 341.87 us/run -  12736 kB/run -  35.53 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=512,n_seqs=8):     3310 runs - 370.42 us/run -  50752 kB/run - 130.66 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=1024,n_seqs=1):    2662 runs - 699.56 us/run -  12608 kB/run -  17.19 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=1024,n_seqs=2):    2670 runs - 674.34 us/run -  25152 kB/run -  35.57 GB/s
  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=1024,n_seqs=8):    1675 runs - 732.02 us/run - 100416 kB/run - 130.82 GB/s
