Conversation

@sanchitintel commented Sep 29, 2025

Problem

The existing MoEGEMM implementation uses headers duplicated from cutlass-sycl in order to call cutlass Group GEMM. The only functional change is a commented-out line in a duplicated xe_array_epilogue.hpp, which skips the C matrix; that change breaks the generality of Group GEMM, so we might as well use a separate MoEGEMM implementation.

Performance issues with the MoEGEMM implementation:

  1. The B matrix was being transposed before calling the cutlass kernel
  2. Multiple arrays were being transferred from host to device
  3. This implementation doesn't require transferring num_tokens_per_expert from the GPU to the CPU, so it will be useful for developing a fully fused implementation in the future. The xetla implementation in IPEX also already keeps num_tokens_per_expert on the GPU.

Integration issues:

  1. cutlass headers had to be unnecessarily duplicated. Code maintenance was problematic because more than 5k lines were copied from the cutlass-sycl repo, and any changes made to them were not readily discernible. Most files were changed only to modify include statements.
  2. More cutlass headers would have to be duplicated in the future, so this approach is not scalable

Solution

  1. Do not transpose the B matrix
  2. Only pass the base A, B, C, D matrix pointers to the GPU. For C, it's nullptr.
  3. Prevent D2H & H2D transfers (except for passing a GPU pointer-to-pointer which points to a base matrix).
  4. Compute the tensors' pointer offsets inside the cutlass kernel (see the sketch after this list)
  5. Instead of copy-pasting cutlass headers, use a separate cutlass branch, and regularly merge main-branch commits into it.
    Reference for cutlass changes: https://github.com/intel/cutlass-sycl/tree/vllm_xpu_cutlass
  6. The interface (e.g. passing cutlass kernel arguments) looks cleaner now, at least to me.
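The arithmetic behind item 4 is roughly the following. This is a minimal stand-alone sketch with placeholder names (compute_group_pointers, tokens_per_expert, etc.), not the PR's actual kernel code: each group's A/B/D pointers are derived on the device from the base pointers plus a running count of tokens per expert, so no per-group pointer arrays need to be built on the host.

// Illustrative sketch only (placeholder names, not the PR's kernel code):
// derive each group's A/B/D pointers from the base pointers and the
// per-expert token counts.
#include <cstdint>
#include <cstdio>
#include <vector>

using Element = uint16_t;  // stand-in for a 16-bit dtype such as bf16

void compute_group_pointers(const Element* A, const Element* B, Element* D,
                            const int32_t* tokens_per_expert,
                            int num_experts, int N, int K) {
  int64_t rows_so_far = 0;  // tokens already consumed by preceding experts
  for (int e = 0; e < num_experts; ++e) {
    const Element* A_e = A + rows_so_far * K;     // expert e's activation rows (K columns each)
    const Element* B_e = B + int64_t(e) * N * K;  // expert e's weight matrix
    Element* D_e = D + rows_so_far * N;           // expert e's output rows (N columns each)
    std::printf("group %d: A+%lld B+%lld D+%lld\n", e,
                (long long)(A_e - A), (long long)(B_e - B), (long long)(D_e - D));
    rows_so_far += tokens_per_expert[e];
  }
}

int main() {
  const int num_experts = 2, N = 4, K = 8;
  std::vector<int32_t> tokens_per_expert = {3, 5};
  std::vector<Element> A((3 + 5) * K), B(num_experts * N * K), D((3 + 5) * N);
  compute_group_pointers(A.data(), B.data(), D.data(),
                         tokens_per_expert.data(), num_experts, N, K);
  return 0;
}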

Follow-up

  1. Improve performance further
  2. Consider renaming GroupGEMM in the API to MoEGEMM (perhaps in this PR)?

Test Plan

UTs

  • Measure E2E performance
  • Measure kernel performance (can't be compared to the previous implementation directly, since it pre-transposed B on the CPU and then used RowMajor B in the GEMM computation, whereas B is ColumnMajor in this PR; a 32x32 transpose copy atom has not been added to cutlass-sycl yet, and performance will improve further once it is).

cc @pengzhao-intel @Liangliang-Ma @jikunshang @YizhouZ @baodii @rogerxfeng8

Please review this PR. I'll revise it as per review suggestions.
Please don't create a separate PR that uses code from this PR (including changes in cutlass) without attribution. Thank you!

CMakeLists.txt Outdated

# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
set(CUTLASS_REVISION "9baca2cff3a28590fcd03e55515e2d91ff2cbc8b" CACHE STRING "CUTLASS revision to use")
set(CUTLASS_REVISION "2eeb05da5b801b34114b6b394dcef836fc9a7cc9" CACHE STRING "CUTLASS revision to use")
Author

This commit is from a new branch vllm_xpu_cutlass of https://github.com/intel/cutlass-sycl.

@pengzhao-intel
Collaborator

Any performance data of this PR?

@pengzhao-intel
Collaborator

Prevent D2H & H2D transfers (except for passing a GPU pointer of pointer which points to a base matrix).
how do you avoid these D2H and H2D?

@sanchitintel
Author

sanchitintel commented Sep 29, 2025

Any performance data of this PR?

Hi @pengzhao-intel, the cutlass kernel performance in this PR can't be compared directly to the previous implementation because:

  1. The main branch transposes B on the CPU, but this PR does not, so we save that overhead outside the cutlass kernel.
  2. The main branch uses RowMajor B in the GEMM computation, whereas B is ColumnMajor in this PR. The GEMM kernel performance differs for RowMajor B & ColumnMajor B, especially since a 32x32 transpose copy atom for 16-bit dtypes doesn't currently exist.
    I can measure the standalone cutlass kernel performance with RowMajor B using intel/sycl-tla#520 (MoEGEMM as an extension of GroupGEMM). Then the comparison will be apples-to-apples, because there's no point in comparing RowMajor B in the main branch vs. ColumnMajor B in this branch.

FWIW, the standalone cutlass kernel performance with RowMajor B was 12.8% better with num_experts=16, M_per_expert=256, N=16384, K=5120 on a B580.

Thanks!

@jikunshang
Collaborator

cc @mayuyuace Please take a look for interface.

@sanchitintel
Author

sanchitintel commented Sep 29, 2025

@pengzhao-intel,

how do you avoid these D2H and H2D?

  1. In the main branch, the strides of each expert's A, B, and D matrices were computed on the CPU and then passed to the GPU.
  2. Besides that, the pointers to each expert's A, B, and D matrices were computed on the CPU, and those pointers were then transferred to the GPU.
  3. alpha & beta arrays were unnecessarily transferred H2D.
  4. num_tokens_per_expert no longer needs to be sent to the CPU.

Now, that computation happens on the GPU, with just the base A, B, C, and D pointers being provided (a rough sketch of the difference follows below).
Technically, this implementation still does an H2D copy of 4 pointers-to-pointer-to-tensor, but I can remove that one as well, if you'd like. I didn't do it in this PR because it requires further duplication of cutlass headers - the existing GroupGEMM MMA & epilogue collectives accept a pointer-to-pointer as the argument for A/B/C/D.
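To make the contrast concrete, here is a rough stand-alone sketch (assumed names such as old_flow/new_flow, not the repo's actual code): the old flow materializes per-expert pointer arrays on the host and copies each of them to the device, while the new flow copies only a tiny array of base pointers and leaves all per-expert offset math to the GPU.

// Rough sketch (assumed names, not the repo's actual code) of the H2D difference.
#include <sycl/sycl.hpp>
#include <cstdint>
#include <vector>

using Element = uint16_t;  // stand-in for a 16-bit dtype such as bf16

// Old flow: per-expert pointers are computed on the host (which also forces
// tokens_per_expert to live on the CPU) and each array is copied H2D.
void old_flow(sycl::queue& q, const Element* A, const Element* B, Element* D,
              const std::vector<int>& tokens_per_expert, int N, int K) {
  std::vector<const Element*> a_ptrs, b_ptrs;
  std::vector<Element*> d_ptrs;
  int64_t rows = 0;
  for (size_t e = 0; e < tokens_per_expert.size(); ++e) {
    a_ptrs.push_back(A + rows * K);
    b_ptrs.push_back(B + (int64_t)e * N * K);
    d_ptrs.push_back(D + rows * N);
    rows += tokens_per_expert[e];
  }
  const Element** a_dev = sycl::malloc_device<const Element*>(a_ptrs.size(), q);
  q.memcpy(a_dev, a_ptrs.data(), a_ptrs.size() * sizeof(const Element*)).wait();
  // ... repeated for b_ptrs, d_ptrs, per-expert strides, and alpha/beta arrays ...
  sycl::free(a_dev, q);
}

// New flow: only the base pointers are copied H2D; per-expert offsets are
// computed inside the kernel from the device-resident token counts.
void new_flow(sycl::queue& q, const Element* A, const Element* B, Element* D) {
  const void* bases[4] = {A, B, /*C=*/nullptr, D};
  const void** base_dev = sycl::malloc_device<const void*>(4, q);
  q.memcpy(base_dev, bases, sizeof(bases)).wait();  // the only remaining H2D copy
  sycl::free(base_dev, q);
}

int main() {
  sycl::queue q;
  std::vector<Element> A(64), B(64), D(64);
  old_flow(q, A.data(), B.data(), D.data(), {2, 2}, 4, 4);
  new_flow(q, A.data(), B.data(), D.data());
  return 0;
}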

Thanks!

@mayuyuace
Collaborator

@Liangliang-Ma Please review this.

@Liangliang-Ma
Collaborator

Hi Sanchit, thanks a lot for your contribution. Could you please add your optimization in the cutlass-sycl example first? We can then take care of the framework integration part on our side, since it involves broader validation and alignment with the frontend.

@sanchitintel
Author

Hi @Liangliang-Ma, thanks for taking a look!

Could you please add your optimization in the cutlass-sycl example first?

Please advise: which optimization are you referring to? That example also doesn't use D2H or H2D copies, except for 4 pointers-to-pointers.
It has the same optimizations as listed in #48 (comment).
The example has some dead code (unused code is optimized out by the compiler), but that should be okay, as it's not an integration example (I'll probably have to revise it before landing it in cutlass-sycl).

As for this PR, please compare E2E performance or the kernel performance (but with RowMajor B).
This PR's implementation is faster than the main branch on an Arc B580.
My second PR will improve the performance further.

Thanks!

We can then take care of the framework integration part on our side, since it involves broader validation and alignment with the frontend

Does the framework provide alpha & beta arrays with values for each group/expert?
That's not a requirement, and yet you're creating dummy tensors.

TBH, one of my concerns is your reuse of code without attribution.

@Liangliang-Ma
Collaborator

Liangliang-Ma commented Sep 29, 2025

Which optimization are you referring to?

If there is any optimization in the mainloop or tile scheduler, that would be cool :)
These H2D/D2H transfers, the stride/pointer calculations, and num_tokens_per_expert will be refactored into a pre-GEMM kernel later, as FlashInfer does. Currently we do it on the Python side and integrate grouped_gemm first.
Alpha/beta can indeed be removed.

@sanchitintel
Author

If there is any optimization in the mainloop or tile scheduler

After eliminating D2H & H2D transfers, pointer computation is done on the GPU. But that mostly affects the latency outside the cutlass kernel.
Can you please run the standalone kernel with RowMajor B from intel/sycl-tla#520 (comment)? Please modify it a bit to compare against the cutlass kernel in this repo.
It performs better than the cutlass kernel in the main branch of this repo.

@Liangliang-Ma
Collaborator

TBH, one of my concerns is your reuse of code without attribution.

Thanks for your note. However, my code is derived independently from the original open-source PR (https://github.com/intel/sycl-tla/pull/252), the same source you also used :) Please check the commit times if you have any further concerns.

@sanchitintel
Author

sanchitintel commented Sep 29, 2025

num_tokens_per_expert will be refactored into a pre-GEMM kernel later, as FlashInfer does. Currently we do it on the Python side and integrate grouped_gemm first

Can you please let me know why this specific change has to be delayed? It's independent of the changes to the cutlass implementation.

Please take a look at the diff. With these changes, you need not send tokens_per_expert to the CPU - it stays on the GPU (unlike the main branch, which first sends it to the CPU and then re-creates a GPU tensor).

idxs = topk_ids.argsort()
counts = topk_ids.to(torch.int).bincount()
tokens_per_expert = counts.cumsum(0)
num_per_tok = n_experts_per_token
token_idxs = idxs // num_per_tok
########### gemm1 ##################
input_B = w13
assert (list(input_A.shape)[0] == total_input_size)
gemm_args = prepare_gemm_args(2 * intermediate_size, hidden_size,
                              input_A, input_B, gemm1_output,
                              num_experts)
torch.ops._xpu_C.cutlass_grouped_gemm(tokens_per_expert=tokens_per_expert,

@sanchitintel
Author

Thanks for your note. However, my code is derived independently from the original open-source PR (https://github.com/intel/cutlass-sycl/pull/252), the same source you also used :) Please check the commit times if you have any further concerns.

We can agree to disagree :)
Besides, #22 (comment).

I attributed whatever code I reused

@Liangliang-Ma
Collaborator

After eliminating D2H & H2D transfers, pointer computation is done on the GPU. But that mostly affects the latency outside the cutlass kernel.

Yes, like I said, D2H/H2D will be handled along with another kernel to eliminate them. We would like an overall optimization of fused MoE, not only of the grouped GEMM.

@sanchitintel
Author

sanchitintel commented Sep 29, 2025

After eliminating D2H & H2D transfers, pointer computation is done on the GPU. But that mostly affects the latency outside the cutlass kernel.

Yes, like I said, D2H/H2D will be handled along with another kernel to eliminate them. We would like an overall optimization of fused MoE, not only of the grouped GEMM.

As I mentioned in some earlier comments, the standalone cutlass kernel (not referring to E2E performance) in this PR already performs better than the standalone cutlass kernel in the main branch.

Is performance not the yardstick for deciding which implementation to retain?

Can you please modify & run the standalone kernel with RowMajor B from intel/cutlass-sycl#520 (comment) with LLaMA 4 input shapes?

@sanchitintel
Author

sanchitintel commented Sep 29, 2025

Like I said, D2H/H2D will be handled along with another kernel to eliminate them.
Currently we do it on the Python side

Can you please explain the technical reason for not making this change now?
Is it because it's done that way in the model?

The technical reason you provided for delaying this change - not sending tokens_per_expert to the CPU and back to the GPU, thereby keeping it on the GPU - does not seem reasonable, as the change isn't even related to the cutlass implementation.

@Liangliang-Ma
Collaborator

I attributed whatever code I reused

We can agree to disagree :) like in intel/sycl-tla#520 (comment)
What's the point, when I finished most of the code before you provided anything, and we're still talking about attribution right now?

@sanchitintel
Author

sanchitintel commented Sep 29, 2025

What's the point, when I finished most of the code before you provided anything, and we're still talking about attribution right now?

I was never supposed to provide a vLLM integration example.

I had provided you a cutlass example more than 2 weeks ago, and had requested you to modify it for integration.
I informed you of having resolved the FP32 -> BF16 epilogue issue on Sep 11 (Sep 12 at your end).
You applied it in your PR on Sep 13.
You had been working on your own without informing us, keeping us in the dark, and that wasted our time as well.

Even my branch that computes pointers on GPU (intel/sycl-tla#520) is older than your vLLM group GEMM PR.

My concern about attributions was just that you did not add attributions at #22 (comment), and neither did you bring it up later.

We can agree to disagree :)

Sure! Replied at intel/sycl-tla#520 (comment)

@Liangliang-Ma
Collaborator

I had provided you a cutlass example more than 2 weeks ago, and had requested you to modify it for integration.
Even my branch that computes pointers on GPU (intel/sycl-tla#520) is older than your vLLM group GEMM PR.

You can check my vLLM Grouped GEMM PR, which was started in August; when you provided it 2 weeks ago, I began fixing e2e issues. Why can you not just open the commit list and see the timestamps?

You had been working on your own without informing us, keeping us in the dark, and that wasted our time as well

Why don't you give me the link to the original grouped GEMM example, since you just changed the frontend? When I saw your code and found that what you did was basically the same as reusing the example, I was shocked.

@sanchitintel
Author

sanchitintel commented Sep 29, 2025

You can check my vLLM Grouped GEMM PR, which was started in August; when you provided it 2 weeks ago, I began fixing e2e issues.

intel/sycl-tla#520 does pointer computation on the GPU.

Why can you not just open the commit list and see the timestamps?

I informed you of the FP32 -> BF16 conversion-in-epilogue solution on Sep 11 (Sep 12 at your end). You used it on Sep 13 but didn't inform me.

Why don't you give me the link to the original grouped GEMM example, since you just changed the frontend?

When Patric told me that num_tokens_per_expert MUST be on the GPU, I added the changes to the same branch as intel/sycl-tla#520, but with git reset --soft. Like your implementation, my previous implementation was a wrapper around cutlass GroupGEMM (yours also supported nullptr C, but that one didn't).

When I saw your code and found that what you did was basically the same as reusing the example, I was shocked

See the header of your group GEMM kernel in the vLLM main branch, which is actually for a cutlass-sycl example (so I don't see why you're shocked) and explicitly states that it has been copy-pasted:

This file is almost a complete copy of 04_bmg_grouped_gemm,
except that it's used for FP8 (E5M2 & E4M3) datatype inputs.

I had added that specific header in cutlass.
All the examples in cutlass-sycl/examples are copy-pasted to create more examples.
In fact, the NVIDIA cutlass repo follows the same convention.

@Liangliang-Ma
Collaborator

so I don't see why you're shocked

Because I worked on vLLM and copied the example as a base, and you, as a kernel developer, also copied the example.
My major work lands on integration, and I don't think reusing the example and moving the offset calculation from one place to another would have wasted too much of your time.

We need a performant kernel, not a change that mostly touches frontend files.

@sanchitintel
Author

sanchitintel commented Sep 29, 2025

We need a performant kernel, not a change that mostly touches frontend files.

Again, please refer to #48 (comment).
The standalone cutlass kernel (I'm NOT referring to E2E performance, which is better with this PR anyway) performs better than the one in the main branch. Is that not what you mean by backend changes?
Please specify what you mean by backend changes.

so I don't see why you're shocked

Because I worked on vLLM and copied the example as a base, and you, as a kernel developer, also copied the example.

All examples in cutlass-sycl/examples are reused as subsequent examples.

My major work lands on integration, and I don't think reusing the example and moving the offset calculation from one place to another would have wasted too much of your time.

That's not the point. You didn't inform us that you were also working on it. We also spent time discussing & developing some other optimizations, which I'll share later.

Even when I told you that I had resolved the FP32 -> BF16 conversion-in-epilogue issue, you still didn't inform us.

@Liangliang-Ma
Collaborator

Just profiled the llama4-scout MoE config with 8192 tokens evenly distributed across experts. This PR takes 21.6 ms on the first GEMM and 11.4 ms on the second GEMM, while main's are 16.9/8.6 ms.
The profiling code is at https://github.com/Liangliang-Ma/vllm-xpu-kernels/blob/dev/tests/cutlass/profile_moe.py
@sanchitintel

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg      Self XPU    Self XPU %     XPU total  XPU time avg    # of Calls   Total FLOPs
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                           _xpu_C::cutlass_grouped_gemm         5.54%      54.833ms         6.08%      60.162ms      30.081ms      32.862ms        48.35%      32.862ms      16.431ms             2            --
_ZTSN10syclcompat12experimental6detail13KernelFuncto...         0.00%       0.000us         0.00%       0.000us       0.000us      32.834ms        48.31%      32.834ms      16.417ms             2            --
                                   xetla_fused_moe_gemm         0.20%       1.946ms         0.31%       3.041ms       1.521ms      32.812ms        48.28%      32.812ms      16.406ms             2            --
gpu::xetla::MoEGEMM<sycl::_V1::ext::oneapi::bfloat16...         0.00%       0.000us         0.00%       0.000us       0.000us      32.812ms        48.28%      32.812ms      16.406ms             2            --
                               torch_ipex::silu_and_mul         6.23%      61.708ms         6.31%      62.484ms      31.242ms       2.024ms         2.98%       2.024ms       1.012ms             2            --
at::AtenIpexTypeXPU::impl::op_and_mul_functor<c10::B...         0.00%       0.000us         0.00%       0.000us       0.000us       2.024ms         2.98%       2.024ms       1.012ms             2            --
                 Memcpy D2M (DEVICE -> MEMORY(Unknown))         0.00%       0.000us         0.00%       0.000us       0.000us     109.877us         0.16%     109.877us       2.073us            53            --
                              aten::_local_scalar_dense         1.32%      13.088ms         2.02%      20.019ms     488.278us      82.278us         0.12%      82.278us       2.007us            41            --
                                              aten::cat         1.61%      15.893ms         1.90%      18.796ms       2.685ms      48.015us         0.07%      48.015us       6.859us             7            --
                                            aten::copy_         4.39%      43.469ms         4.88%      48.272ms       3.448ms      38.640us         0.06%      38.640us       2.760us            14            --
                 Memcpy M2D (MEMORY(Unknown) -> DEVICE)         0.00%       0.000us         0.00%       0.000us       0.000us      32.704us         0.05%      32.704us       3.634us             9            --
at::native::xpu::CatArrayBatchedCopyKernelFunctor<c1...         0.00%       0.000us         0.00%       0.000us       0.000us      24.894us         0.04%      24.894us       3.556us             7            --
                            Memcpy H2D (HOST -> DEVICE)         0.00%       0.000us         0.00%       0.000us       0.000us      23.121us         0.03%      23.121us       3.303us             7            --
                                              aten::min         1.78%      17.609ms         1.81%      17.968ms      17.968ms      10.312us         0.02%      10.312us      10.312us             1            --
at::native::xpu::ReduceKernel<1, at::native::xpu::Re...         0.00%       0.000us         0.00%       0.000us       0.000us      10.312us         0.02%      10.312us      10.312us             1            --
                                              aten::max         1.77%      17.491ms         1.80%      17.843ms      17.843ms      10.208us         0.02%      10.208us      10.208us             1            --
at::native::xpu::ReduceKernel<1, at::native::xpu::Re...         0.00%       0.000us         0.00%       0.000us       0.000us      10.208us         0.02%      10.208us      10.208us             1            --
                                               aten::ne        17.99%     178.078ms        18.06%     178.787ms      59.596ms       9.791us         0.01%       9.791us       3.264us             3            --
                                            aten::index        23.36%     231.314ms        25.29%     250.346ms     250.346ms       7.083us         0.01%      25.935us      25.935us             1            --
at::native::xpu::IndexKernelFunctor<at::native::xpu:...         0.00%       0.000us         0.00%       0.000us       0.000us       7.083us         0.01%       7.083us       7.083us             1            --
                                      aten::bitwise_and        10.17%     100.647ms        10.23%     101.257ms     101.257ms       6.979us         0.01%       6.979us       6.979us             1            --
at::native::xpu::VectorizedElementwiseKernel<16, at:...         0.00%       0.000us         0.00%       0.000us       0.000us       6.979us         0.01%       6.979us       6.979us             1            --
at::native::xpu::VectorizedElementwiseKernel<8, at::...         0.00%       0.000us         0.00%       0.000us       0.000us       6.875us         0.01%       6.875us       3.438us             2            --
at::native::xpu::VectorizedElementwiseKernel<2, at::...         0.00%       0.000us         0.00%       0.000us       0.000us       6.769us         0.01%       6.769us       2.256us             3            --
                                            copy_if_xpu         0.35%       3.477ms         1.75%      17.302ms      17.302ms       6.666us         0.01%      14.477us      14.477us             1            --
at::native::xpu::UnrolledElementwiseKernel<at::nativ...         0.00%       0.000us         0.00%       0.000us       0.000us       6.458us         0.01%       6.458us       6.458us             1            --
                                              aten::abs         1.96%      19.379ms         4.00%      39.611ms       9.903ms       6.457us         0.01%      12.914us       3.229us             4            --
at::native::xpu::VectorizedElementwiseKernel<8, at::...         0.00%       0.000us         0.00%       0.000us       0.000us       6.457us         0.01%       6.457us       3.229us             2            --
                                              aten::mul        12.31%     121.909ms        12.39%     122.655ms     122.655ms       6.145us         0.01%       6.145us       6.145us             1        36.000
at::native::xpu::VectorizedElementwiseKernel<16, at:...         0.00%       0.000us         0.00%       0.000us       0.000us       6.145us         0.01%       6.145us       6.145us             1            --
                                              aten::div         0.59%       5.851ms         0.69%       6.859ms       6.859ms       5.937us         0.01%       5.937us       5.937us             1            --
at::native::xpu::VectorizedElementwiseKernel<2, at::...         0.00%       0.000us         0.00%       0.000us       0.000us       5.937us         0.01%       5.937us       5.937us             1            --
at::native::xpu::VectorizedElementwiseKernel<8, at::...         0.00%       0.000us         0.00%       0.000us       0.000us       5.624us         0.01%       5.624us       2.812us             2            --
                                     inclusive_scan_xpu         1.20%      11.897ms         1.24%      12.234ms      12.234ms       5.416us         0.01%       5.416us       5.416us             1            --
at::native::xpu::pstl::KSScanKernelFunctor<1, long*,...         0.00%       0.000us         0.00%       0.000us       0.000us       5.416us         0.01%       5.416us       5.416us             1            --
                                             aten::ceil         2.09%      20.685ms         2.12%      20.992ms      20.992ms       5.000us         0.01%       5.000us       5.000us             1            --
at::native::xpu::VectorizedElementwiseKernel<8, at::...         0.00%       0.000us         0.00%       0.000us       0.000us       5.000us         0.01%       5.000us       5.000us             1            --
                                               aten::gt         0.10%       1.025ms         0.12%       1.151ms     575.524us       4.478us         0.01%       4.478us       2.239us             2            --
                                          aten::nonzero         0.04%     375.541us         1.85%      18.281ms      18.281ms       4.375us         0.01%      18.852us      18.852us             1            --
      at::native::xpu::FlattenIdxtoRealIdxKernelFunctor         0.00%       0.000us         0.00%       0.000us       0.000us       4.375us         0.01%       4.375us       4.375us             1            --
at::native::xpu::pstl::PredictKernelFunctor<long, lo...         0.00%       0.000us         0.00%       0.000us       0.000us       4.062us         0.01%       4.062us       4.062us             1            --
                                               aten::eq         0.03%     256.898us         0.04%     391.620us     391.620us       2.708us         0.00%       2.708us       2.708us             1            --
at::native::xpu::pstl::CopyIfKernelFunctor<long, lon...         0.00%       0.000us         0.00%       0.000us       0.000us       2.604us         0.00%       2.604us       2.604us             1            --
                                               aten::lt         0.01%      77.085us         0.01%     122.278us     122.278us       2.291us         0.00%       2.291us       2.291us             1            --
                                    #####ipex_moegemm_1         1.92%      19.007ms         2.20%      21.809ms      21.809ms       0.000us         0.00%      18.902ms      18.902ms             1            --
                             torch_ipex::fused_moe_gemm         0.00%      33.904us         0.31%       3.075ms       1.538ms       0.000us         0.00%      32.812ms      16.406ms             2            --
                                            aten::empty         0.02%     189.241us         0.02%     189.241us       8.228us       0.000us         0.00%       0.000us       0.000us            23            --
                                  urEnqueueKernelLaunch         1.10%      10.924ms         1.10%      10.924ms     321.301us       0.000us         0.00%       0.000us       0.000us            34            --
                                    #####ipex_moegemm_2         1.43%      14.169ms         1.46%      14.443ms      14.443ms       0.000us         0.00%      13.910ms      13.910ms             1            --
                                        aten::transpose         0.00%      38.969us         0.01%      50.401us      16.800us       0.000us         0.00%       0.000us       0.000us             3            --
                                       aten::as_strided         0.01%     117.465us         0.01%     117.465us       1.013us       0.000us         0.00%       0.000us       0.000us           116            --
                                 @@@@@cutlass_moegemm_1         0.10%       1.008ms         5.91%      58.530ms      58.530ms       0.000us         0.00%      21.642ms      21.642ms             1            --
                                               aten::to         0.00%      45.513us         4.89%      48.440ms       3.027ms       0.000us         0.00%      38.640us       2.415us            16            --
                                         aten::_to_copy         0.01%      87.602us         4.89%      48.394ms       3.457ms       0.000us         0.00%      38.640us       2.760us            14            --
                                    aten::empty_strided         0.00%      34.699us         0.00%      34.699us       2.479us       0.000us         0.00%       0.000us       0.000us            14            --
                                     urEnqueueUSMMemcpy         1.77%      17.559ms         1.77%      17.559ms     254.484us       0.000us         0.00%       0.000us       0.000us            69            --
                                       aten::lift_fresh         0.00%       3.791us         0.00%       3.791us       3.791us       0.000us         0.00%       0.000us       0.000us             1            --
                                          aten::detach_         0.00%       3.507us         0.00%       6.465us       6.465us       0.000us         0.00%       0.000us       0.000us             1            --
                                                detach_         0.00%       2.958us         0.00%       2.958us       2.958us       0.000us         0.00%       0.000us       0.000us             1            --
                                       urUSMDeviceAlloc         0.03%     271.095us         0.03%     271.095us      33.887us       0.000us         0.00%       0.000us       0.000us             8            --
                                 @@@@@cutlass_moegemm_2         0.44%       4.362ms        84.10%     832.643ms     832.643ms       0.000us         0.00%      11.484ms      11.484ms             1            --
                                           aten::select         0.02%     213.297us         0.03%     280.960us       3.267us       0.000us         0.00%       0.000us       0.000us            86            --
                                            aten::slice         0.02%     169.948us         0.02%     201.047us       8.377us       0.000us         0.00%       0.000us       0.000us            24            --
                                         urUSMHostAlloc         0.01%     142.862us         0.01%     142.862us     142.862us       0.000us         0.00%       0.000us       0.000us             1            --
                                            aten::stack         0.00%      32.971us         0.05%     449.183us     449.183us       0.000us         0.00%       9.270us       9.270us             1            --
                                             aten::view         0.00%      17.023us         0.00%      17.023us       5.674us       0.000us         0.00%       0.000us       0.000us             3            --
                                          aten::reshape         0.00%      10.771us         0.00%      21.457us      10.728us       0.000us         0.00%       0.000us       0.000us             2            --
                                         aten::isfinite         0.01%     141.338us        32.44%     321.204ms     321.204ms       0.000us         0.00%      17.290us      17.290us             1            --
                                          aten::resize_         0.01%      60.071us         0.01%      60.071us      15.018us       0.000us         0.00%       0.000us       0.000us             4            --
                                          aten::__and__         0.00%      31.858us        10.23%     101.289ms     101.289ms       0.000us         0.00%       6.979us       6.979us             1            --
                                    aten::masked_select         0.01%     116.955us        25.30%     250.476ms     250.476ms       0.000us         0.00%      25.935us      25.935us             1            --
                                             aten::item         0.01%      90.860us         2.03%      20.110ms     490.494us       0.000us         0.00%      82.278us       2.007us            41            --
                                                aten::t         0.00%       5.865us         0.00%      20.422us      20.422us       0.000us         0.00%       0.000us       0.000us             1            --
                                             aten::set_         0.00%      22.405us         0.00%      22.405us      22.405us       0.000us         0.00%       0.000us       0.000us             1            --
                                           aten::unbind         0.01%      66.861us         0.02%     197.674us      98.837us       0.000us         0.00%       0.000us       0.000us             2            --
                                       aten::is_nonzero         0.00%      13.968us         0.22%       2.141ms     535.170us       0.000us         0.00%      10.206us       2.552us             4            --
                                     aten::resolve_conj         0.00%       5.925us         0.00%       5.925us       0.494us       0.000us         0.00%       0.000us       0.000us            12            --
                                      aten::resolve_neg         0.00%       4.806us         0.00%       4.806us       0.401us       0.000us         0.00%       0.000us       0.000us            12            --
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------

@sanchitintel
Author

sanchitintel commented Sep 29, 2025

@Liangliang-Ma, for the first GEMM, I get 15.04 ms.

I have commented several times already that an apples-to-apples comparison would use RowMajor B or ColumnMajor B for both implementations, and yet you're showing data for ColumnMajor B for this implementation and RowMajor B for the main-branch implementation.

Even after I repeatedly mentioned this, you still didn't do an apples-to-apples comparison. :(

@sanchitintel
Author

sanchitintel commented Sep 29, 2025

@Liangliang-Ma, if you really want an apples-to-apples comparison, transpose B on the CPU, as you already do on the main branch,
and then use RowMajor B for this PR with a 32x32 copy atom ending in _N/_V (again, just as in the main branch).

cc @pengzhao-intel

@Liangliang-Ma
Collaborator

@sanchitintel pls provide code for comparison.

@sanchitintel
Author

sanchitintel commented Sep 29, 2025

@Liangliang-Ma,

Sorry, but you seem to be deliberately stalling an apples-to-apples comparison, which needs 3 changes on top of this PR:

  1. Transpose the B matrices, just as your Python code in the main branch does
  2. Use the same copy atom for B as you use in the main branch
  3. Use RowMajor B, just as you do in the main branch

@Liangliang-Ma
Collaborator

Liangliang-Ma commented Sep 29, 2025

@sanchitintel what do you mean by deliberately stalling it? Why not just provide the code and let me profile it? Give me a branch or something. It's simple.

@sanchitintel
Author

sanchitintel commented Sep 29, 2025

@Liangliang-Ma, here's the diff (open this comment in edit mode without actually modifying anything; then you can copy-paste the diff and apply it):

diff --git a/csrc/xpu/cutlass_kernels/grouped_gemm_kernel.cpp b/csrc/xpu/cutlass_kernels/grouped_gemm_kernel.cpp
index 507a31b..a94954d 100644
--- a/csrc/xpu/cutlass_kernels/grouped_gemm_kernel.cpp
+++ b/csrc/xpu/cutlass_kernels/grouped_gemm_kernel.cpp
@@ -223,19 +223,17 @@ void kernel_functor(sycl::queue& stream, void* ptr_A, void* ptr_B, void* ptr_D,
   using ElementOutput = bfloat16_t;
 
   using LayoutA = cutlass::layout::RowMajor;
-  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutB = cutlass::layout::RowMajor;
   using LayoutC = cutlass::layout::RowMajor;
   using LayoutD = cutlass::layout::RowMajor;
 
   using TileShape = Shape<_256, _256, _32>;
   using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
-  using GmemTiledCopyB = XE_2D_U16x16x16_LD_T;
+  using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
   // This TiledMMA is the default one in intel/cutlass-sycl examples
   using TiledMma =
-      TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
-               Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>,
-               Tile<Layout<Shape<_8, _8, _4>, Stride<_1, _32, _8>>,
-                    Layout<Shape<_16, _4, _4>, Stride<_1, _64, _16>>, _32>>;
+      typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
+                              Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
 
   constexpr int PipelineStages = 2;
   using GEMMDispatchPolicy =
diff --git a/tests/fused_moe/test_fused_moe.py b/tests/fused_moe/test_fused_moe.py
index 9211bdc..fbeab3e 100644
--- a/tests/fused_moe/test_fused_moe.py
+++ b/tests/fused_moe/test_fused_moe.py
@@ -56,6 +56,7 @@ def test_grouped_gemm(m, n, k, e, topk, dtype):
     ref_A = input_A.clone()
 
     # weight
     input_B = torch.randn((num_experts, n, k), dtype=dtype, device=DEVICE)
+    input_B = input_B.transpose(-1, -2).contiguous().transpose(-1, -2)
     # output offset
     output = torch.empty((num_tokens_after_duplication, n), dtype=dtype, device=DEVICE)
diff --git a/vllm_xpu_kernels/fused_moe_interface.py b/vllm_xpu_kernels/fused_moe_interface.py
index 4cdabcd..a599433 100644
--- a/vllm_xpu_kernels/fused_moe_interface.py
+++ b/vllm_xpu_kernels/fused_moe_interface.py
@@ -67,7 +67,7 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids,
     token_idxs = idxs // num_per_tok
 
     ########### gemm1 ##################
-    input_B = w13
+    input_B = w13.transpose(-1, -2).contiguous().transpose(-1, -2)
     assert (list(input_A.shape)[0] == total_input_size)
     gemm_args = prepare_gemm_args(2 * intermediate_size, hidden_size,
                                   input_A, input_B, gemm1_output,
@@ -83,7 +83,7 @@ def cutlass_fused_moe(hidden_states, w13, w2, topk_weights, topk_ids,
 
     ########### gemm2 ##################
     input_A = act_output.contiguous()
-    input_B = w2
+    input_B = w2.transpose(-1, -2).contiguous().transpose(-1, -2)
     gemm_args = prepare_gemm_args(hidden_size, intermediate_size,
                                   input_A, input_B, gemm2_output,
                                   num_experts)

With 100 iterations of your script, I'm getting unexpectedly good numbers with this PR - 9.26 ms & 2.59 ms 😮


Can you please take a look? Thanks!

@Liangliang-Ma
Collaborator

I'm getting unexpectedly good numbers with this PR - 9.26 ms & 2.59 ms

Are we talking about CPU times when we compare kernels now? That's new to me.
Look at the kernel time:

This is this PR's time for the two GEMMs, averaged over 100 iterations:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg      Self XPU    Self XPU %     XPU total  XPU time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                           _xpu_C::cutlass_grouped_gemm        76.20%        1.149s        92.63%        1.397s       6.984ms        2.319s        95.81%        2.319s      11.593ms           200

And this is main:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg      Self XPU    Self XPU %     XPU total  XPU time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                           _xpu_C::cutlass_grouped_gemm         2.17%     365.146ms         3.25%     546.100ms       2.731ms        2.343s        89.98%        2.343s      11.717ms           200

The XPU time avg shows not much difference. The CPU time before the grouped GEMM is our next step. Let's focus on XPU time now. Correct me if you have any concerns about this XPU-time data. @sanchitintel

@sanchitintel
Author

sanchitintel commented Sep 29, 2025

@Liangliang-Ma,

thanks for the data!

Let's focus on XPU time now

I'm consistently getting 11.33 ms for that number when the number of iterations is 200,
e.g. 11.332, 11.334, 11.337.

Are we talking about CPU times when we compare kernels now? That's new to me.

Sorry, my bad! The output got truncated in my VS Code, and the table headers are on two lines :(

@Liangliang-Ma
Collaborator

I'm consistently getting 11.33 ms for that number when the number of iterations is 200,
e.g. 11.332, 11.334, 11.337.

So we can conclude that with your PR we get a 1% to 3% (somewhere in between) improvement on the grouped GEMM kernel compared to the cutlass example, right? We will discuss it after our holidays. Thanks!

@sanchitintel
Author

sanchitintel commented Sep 29, 2025

For the standalone cutlass kernel, I always see more than 3%, based on your baseline data.
As mentioned in the description, E2E performance will see more gains. In fact, the description doesn't even mention the fact that the standalone kernel is a bit faster.
Pointer offset computation on the CPU is not necessary. It's not happening in some other kernel, but in Python code written by you.
So I'm not sure why you say that it happens in another kernel.

FWIW, the standalone kernel performance can also be improved even in this PR, since the absolute performance difference between just the cutlass implementations is higher (it was ~12% for these input shapes).

I'll make more changes.

Happy holidays!

@sanchitintel
Author

sanchitintel commented Sep 30, 2025

@Liangliang-Ma,

Just FYI, if you pass static_cast<const ElementC**>((void*)ptr_D.get()) as the value of the ptr_c argument for the epilogue, instead of nullptr (it's okay to do so, since the C tiles don't actually get used when beta is 0), then it'd only be ~0.1 TFLOPs slower, or maybe even less (a tiny illustration of the cast follows below).

Then you would be able to use the cutlass-sycl main branch (not the cutlass branch used by this PR), and all the headers you duplicated from cutlass could be removed (provided you're okay with using the default SYCL queue, or with passing a queue to a suitable GemmUniversalAdapter::run() API), which would make maintaining the code easier.
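The cast amounts to the following stand-alone illustration (the element types and the ptr_D handle are stand-ins, not the actual cutlass-sycl epilogue arguments): when beta == 0 the epilogue never loads C, so the per-group D pointer array can be aliased as the C pointer array instead of passing nullptr.

// Tiny illustration only; types and names are placeholders, not cutlass-sycl's.
#include <cstdio>

using ElementC = float;
using ElementD = float;

int main() {
  ElementD* d_group_ptrs[2] = {nullptr, nullptr};  // pretend: device array of per-group D pointers
  ElementD** ptr_D = d_group_ptrs;

  // Same cast as suggested above: alias ptr_D as the per-group C pointer array.
  const ElementC** ptr_C =
      static_cast<const ElementC**>(static_cast<void*>(ptr_D));

  // With beta == 0, the epilogue computes D = alpha * acc and never reads the
  // tiles behind ptr_C, so the aliasing is harmless.
  std::printf("ptr_C aliases ptr_D: %p == %p\n",
              static_cast<void*>(ptr_C), static_cast<void*>(ptr_D));
  return 0;
}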

Thanks

@pengzhao-intel
Collaborator

Closing this PR. Sanchit can focus on the MoE kernel and make PRs in cutlass-sycl and PyTorch, and Liangliang can do what vLLM needs on their side.
